mirror of https://github.com/hoshikawa2/image_lora_training.git
synced 2026-03-11 17:14:57 +00:00
first commit

README.md
# Training Image Generation Models with Diffusion (Stable Diffusion XL + LoRA)

## Introduction

Image generation using diffusion models has become one of the most transformative capabilities of modern artificial intelligence. These models are capable of generating high-quality images from natural-language descriptions, enabling applications across multiple industries.

Real-world use cases include:

- **Marketing and advertising** -- generating visual assets automatically.
- **Software documentation and presentations** -- producing diagrams and illustrations for technical content.
- **Game development** -- generating textures, characters, and environments.
- **Product design** -- visualizing concepts before prototypes exist.
- **Enterprise automation** -- generating architecture diagrams or slide illustrations automatically.

In enterprise environments, diffusion models can be integrated into automated pipelines. For example, a presentation generator can automatically produce slide images that represent architectural concepts such as:

- API Gateways
- Databases
- Integration architectures
- Cloud infrastructures

Instead of manually creating diagrams, an AI pipeline can generate them dynamically from structured prompts.

This tutorial explains how to train and use a **Stable Diffusion XL model with LoRA fine-tuning**, and how to deploy the model to generate images programmatically.

------------------------------------------------------------------------

# Technologies Involved

## Diffusion Models

Diffusion models generate images by **iteratively denoising random noise**. The training process teaches a neural network how to reverse a noise process applied to real images.

The process works as follows:

1. Start with a real image
2. Gradually add noise until the image becomes pure noise
3. Train a model to reverse this process
4. During inference, start with noise and iteratively remove it

This allows the model to synthesize new images from text descriptions.
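
The four steps above can be sketched numerically. A minimal, framework-free illustration with a toy linear noise schedule and scalar "pixel" values (real schedulers use more elaborate alpha-bar curves, but the mixing formula is the same):

```python
import math
import random

# Hypothetical linear schedule: alpha_bar shrinks from ~1 (clean) to 0 (pure noise).
def alpha_bar(t, T=1000):
    return 1.0 - t / T

def add_noise(x0, noise, t, T=1000):
    # Forward diffusion step: x_t = sqrt(abar) * x0 + sqrt(1 - abar) * noise
    ab = alpha_bar(t, T)
    return math.sqrt(ab) * x0 + math.sqrt(1.0 - ab) * noise

x0 = 0.8                        # a "pixel" from a real image, in [-1, 1]
noise = random.gauss(0.0, 1.0)  # Gaussian noise sample

early = add_noise(x0, noise, t=10)   # mostly signal
late = add_noise(x0, noise, t=990)   # mostly noise
print(early, late)
```

Training then amounts to showing the network `x_t` and asking it to recover the `noise` term, which is exactly the loss used later in this tutorial.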

Popular diffusion models include:

- Stable Diffusion
- Stable Diffusion XL (SDXL)
- DALL-E
- Imagen

In this tutorial we use **Stable Diffusion XL**, which provides:

- Higher resolution
- Better text understanding
- Dual text encoders
- Micro-conditioning

------------------------------------------------------------------------

## Stable Diffusion XL (SDXL)

SDXL is an advanced diffusion architecture that improves generation quality through:

- **Two text encoders**
- **Improved conditioning**
- **Higher resolution generation**
- **Better prompt interpretation**

Unlike earlier diffusion models, SDXL requires:

- two tokenizers
- two text encoders
- pooled embeddings
- time conditioning parameters

------------------------------------------------------------------------

## LoRA (Low-Rank Adaptation)

Training diffusion models from scratch is extremely expensive. Instead, **LoRA** allows fine-tuning large models efficiently by training small low-rank matrices that modify the attention layers of the network.

Advantages:

- Very small training footprint
- Works with limited VRAM
- Easy to merge into the base model
- Fast training

In this project, LoRA is applied to the **UNet attention layers**.
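
The efficiency argument is easy to quantify. A back-of-the-envelope sketch with illustrative dimensions (not SDXL's actual layer sizes): LoRA replaces training a full d×d projection matrix with two thin factors, A (r×d) and B (d×r), applied as W + (alpha/r)·B·A:

```python
d = 1280   # hypothetical attention projection width
r = 8      # LoRA rank, as used later in this tutorial

full_params = d * d            # training the whole projection matrix
lora_params = d * r + r * d    # training only B (d x r) and A (r x d)

print(full_params, lora_params)    # 1638400 vs 20480
print(lora_params / full_params)   # 0.0125 -> 1.25% of the weights
```

This is why LoRA checkpoints are megabytes rather than gigabytes, and why training fits in limited VRAM: gradients and optimizer state exist only for the low-rank factors.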

------------------------------------------------------------------------

## HuggingFace Diffusers

The **Diffusers library** provides a high-level API for working with diffusion models.

It includes:

- pipelines
- schedulers
- training utilities
- optimization helpers

Main components used:

- `StableDiffusionXLPipeline`
- `DDPMScheduler`
- `DPMSolverMultistepScheduler`

------------------------------------------------------------------------

## PyTorch

PyTorch is used for:

- training loops
- GPU acceleration
- tensor operations
- neural network execution

------------------------------------------------------------------------

# Code Walkthrough

## Dataset Structure

The training script expects a dataset structured as:

    dataset/
        images/
            image1.png
            image2.png
        captions/
            image1.txt
            image2.txt

Each caption describes the image. Example caption:

    enterprise cloud architecture diagram with API gateway and database
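
The pairing convention (a caption file shares its image's stem) can be checked with a few lines of standard-library Python. This sketch builds a throwaway one-image dataset in a temp directory and pairs files the same way the loader does:

```python
import tempfile
from pathlib import Path

root = Path(tempfile.mkdtemp())
(root / "images").mkdir()
(root / "captions").mkdir()

# Fake one-sample dataset
(root / "images" / "image1.png").write_bytes(b"\x89PNG fake")
(root / "captions" / "image1.txt").write_text(
    "enterprise cloud architecture diagram with API gateway and database"
)

# Pair each image with the caption file sharing its stem
pairs = []
for img in sorted((root / "images").glob("*.png")):
    cap = root / "captions" / (img.stem + ".txt")
    if cap.exists():
        pairs.append((img.name, cap.read_text().strip()))

print(pairs)
```

Running a check like this before training catches missing or misnamed caption files early, instead of mid-epoch.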

------------------------------------------------------------------------

# Dataset Loader

The dataset loader reads images and their captions.

Key operations:

- resizing images
- converting to tensors
- normalization

Important section:

``` python
transforms.Compose([
    transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
])
```

Normalization is important because diffusion models expect images in a **\[-1,1\] range**.
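
`Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])` performs exactly this range shift: `ToTensor` first maps pixel values into [0, 1], and `(x - mean) / std` with mean = std = 0.5 then maps that interval onto [-1, 1]. A plain-Python check of the formula:

```python
def normalize(x, mean=0.5, std=0.5):
    # Same per-channel formula torchvision's Normalize applies
    return (x - mean) / std

print(normalize(0.0))   # -1.0  (black pixel)
print(normalize(0.5))   #  0.0  (mid gray)
print(normalize(1.0))   #  1.0  (white pixel)
```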

------------------------------------------------------------------------

# Prompt Encoding

SDXL uses **two text encoders**. The function:

    encode_prompt_sdxl()

performs:

1. tokenization of captions
2. embedding generation
3. concatenation of embeddings

Important concept:

    prompt_embeds = torch.cat([prompt_embeds_1, prompt_embeds_2], dim=-1)

This merges both encoders into a single conditioning representation.
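
The concatenation is purely dimensional: SDXL's first encoder produces 768-dimensional token embeddings and the second 1280-dimensional ones, so the fused conditioning is 2048-dimensional per token. A shape sketch using nested lists as stand-ins for (batch, tokens, features) tensors:

```python
batch, tokens = 2, 77      # SDXL pads/truncates prompts to 77 tokens
dim1, dim2 = 768, 1280     # widths of text_encoder and text_encoder_2

e1 = [[[0.0] * dim1] * tokens] * batch   # stand-in for prompt_embeds_1
e2 = [[[0.0] * dim2] * tokens] * batch   # stand-in for prompt_embeds_2

# torch.cat(..., dim=-1) joins along the last (feature) axis
fused = [[t1 + t2 for t1, t2 in zip(s1, s2)] for s1, s2 in zip(e1, e2)]

print(len(fused), len(fused[0]), len(fused[0][0]))   # 2 77 2048
```

The batch and token axes are untouched; only the feature axis grows, which is why the UNet's cross-attention expects 2048-wide conditioning.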

------------------------------------------------------------------------

# Latent Encoding

Images are encoded into latent space using the **VAE**:

    latents = vae.encode(images).latent_dist.sample()

The VAE compresses images before training the diffusion process, which drastically reduces memory usage.
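
The saving is easy to quantify: SDXL's VAE downsamples each spatial dimension by a factor of 8 and produces 4 latent channels, so a 1024×1024 RGB image becomes a 4×128×128 latent:

```python
IMAGE_SIZE = 1024
VAE_FACTOR = 8        # spatial downsampling factor of the SDXL VAE
LATENT_CHANNELS = 4

pixel_values = 3 * IMAGE_SIZE * IMAGE_SIZE        # values in the RGB image
latent_size = IMAGE_SIZE // VAE_FACTOR
latent_values = LATENT_CHANNELS * latent_size * latent_size

print(latent_size)                      # 128
print(pixel_values / latent_values)     # 48.0x fewer values to diffuse over
```

The denoising UNet therefore operates on roughly 2% of the original data volume, which is what makes training feasible on a single GPU.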

------------------------------------------------------------------------

# Noise Training

Diffusion training consists of predicting the noise added to images:

    noise = torch.randn_like(latents)
    noisy_latents = scheduler.add_noise(latents, noise, timesteps)

The model learns to predict this noise.

Loss function:

    loss = F.mse_loss(noise_pred.float(), noise.float())

This is the standard diffusion training loss.
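
The loss is nothing more than the mean of squared differences between true and predicted noise; a plain-Python equivalent of `F.mse_loss` with the default `reduction="mean"`, on illustrative values:

```python
def mse(pred, target):
    # Mean squared error over all elements
    return sum((p - t) ** 2 for p, t in zip(pred, target)) / len(pred)

noise = [0.5, -1.0, 0.25, 0.0]        # "true" noise values
noise_pred = [0.4, -0.8, 0.25, 0.1]   # model prediction

print(mse(noise_pred, noise))   # ≈ 0.015
```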

------------------------------------------------------------------------

# LoRA Configuration

LoRA modifies the attention layers of the UNet:

    LoraConfig(
        r=8,
        lora_alpha=16,
        target_modules=["to_q", "to_k", "to_v", "to_out.0"]
    )

Key parameters:

| Parameter        | Description               |
|------------------|---------------------------|
| `r`              | rank of adaptation        |
| `lora_alpha`     | scaling factor            |
| `target_modules` | attention layers to adapt |

------------------------------------------------------------------------

# Training Loop

Main training steps:

1. Encode image into latent space
2. Add noise
3. Encode text prompt
4. Predict noise with UNet
5. Compute loss
6. Backpropagate

The loop runs for multiple epochs:

    for epoch in range(EPOCHS):
        for step, (images, captions) in enumerate(dataloader):

------------------------------------------------------------------------

# Saving the LoRA Model

After training:

    unet.save_pretrained("sdxl_lora")

The LoRA weights can later be merged into the base model.

------------------------------------------------------------------------

# Image Generation Pipeline

The generation script loads:

- the SDXL base model
- the trained LoRA
- an optimized scheduler

Key configuration:

    pipe = StableDiffusionXLPipeline.from_pretrained(
        "stabilityai/stable-diffusion-xl-base-1.0"
    )

Scheduler optimization:

    pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)

This significantly accelerates generation.

------------------------------------------------------------------------

# Memory Optimization

To support Mac M-series GPUs or limited VRAM:

    pipe.enable_attention_slicing()
    pipe.enable_vae_slicing()

These techniques reduce peak memory usage.

------------------------------------------------------------------------

# Prompt Caching

Images are cached using a hash of the prompt:

    hashlib.sha256(prompt.encode()).hexdigest()

This prevents regenerating identical images repeatedly.
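
Because the key is a deterministic hash, the same prompt always maps to the same file name. A sketch of the scheme (as in `generate_image3.py`, which truncates the digest to 16 hex characters):

```python
import hashlib
from pathlib import Path

def cache_path(prompt: str, cache_dir: Path = Path("images/cache")) -> Path:
    # Deterministic key: identical prompts yield identical file names
    key = hashlib.sha256(prompt.encode()).hexdigest()[:16]
    return cache_dir / f"{key}.png"

a = cache_path("enterprise cloud architecture diagram")
b = cache_path("enterprise cloud architecture diagram")
c = cache_path("something else entirely")

print(a == b)   # True  -> cache hit, no regeneration
print(a == c)   # False -> new prompt, new image
```

Note that even a one-character change in the prompt produces a completely different key, so the cache never returns a stale image for a modified prompt.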

------------------------------------------------------------------------

# Image Generation

Image generation call:

    image = pipe(
        prompt,
        negative_prompt=NEGATIVE,
        num_inference_steps=25,
        guidance_scale=7,
        height=1024,
        width=1024
    ).images[0]

Important parameters:

| Parameter             | Meaning              |
|-----------------------|----------------------|
| `num_inference_steps` | diffusion iterations |
| `guidance_scale`      | prompt strength      |
| `height` / `width`    | image resolution     |
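
`guidance_scale` controls classifier-free guidance: at each denoising step the pipeline forms an unconditional and a text-conditioned noise prediction and pushes the result toward the prompt. A scalar sketch of the standard combination rule, with illustrative values:

```python
def guided(uncond, text, scale):
    # Classifier-free guidance: uncond + scale * (text - uncond)
    return uncond + scale * (text - uncond)

uncond_pred = 0.1   # noise predicted with an empty prompt
text_pred = 0.3     # noise predicted with the user prompt

print(guided(uncond_pred, text_pred, 1.0))  # 0.3   -> prompt only, no boost
print(guided(uncond_pred, text_pred, 7.0))  # ≈ 1.5 -> strongly prompt-driven
```

Higher scales follow the prompt more literally but can oversaturate or distort the image; values around 5-9 are a common range for SDXL.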

------------------------------------------------------------------------

# Deployment

## Install Dependencies

    pip install torch diffusers transformers peft accelerate pillow torchvision

------------------------------------------------------------------------

# Training

Run:

    python train_diffusion.py

Training will produce:

    sdxl_lora/

containing the LoRA weights.

------------------------------------------------------------------------

# Running the Generator

    python generate_image3.py

Example prompt:

    clean enterprise illustration,
    corporate presentation slide,
    minimal design,
    white background,
    oracle integration architecture

Generated images are saved in:

    images/cache/

------------------------------------------------------------------------

# Testing

You can test generation with different prompts. Example:

    enterprise microservices architecture diagram

or

    cloud integration architecture with api gateway

------------------------------------------------------------------------

# Conclusion

Diffusion models are revolutionizing image generation by allowing AI systems to synthesize visual content from text descriptions.

Using Stable Diffusion XL combined with LoRA fine-tuning enables:

- efficient training
- domain specialization
- enterprise use cases
- automated content generation

In practical systems, diffusion models can be integrated into larger pipelines such as:

- automated presentation builders
- documentation systems
- AI-generated diagrams
- marketing automation platforms

With efficient techniques like LoRA, high-quality image generation is now accessible even on consumer hardware such as:

- RTX GPUs
- Apple Silicon
- workstation GPUs

As diffusion architectures continue evolving, they will increasingly become a core component of AI-driven content generation systems.

# Acknowledgments

- **Author** - Cristiano Hoshikawa (Oracle LAD A-Team Solution Engineer)

------------------------------------------------------------------------

generate_image3.py

import hashlib
from pathlib import Path

import torch
from diffusers import StableDiffusionXLPipeline
from peft import PeftModel

CACHE_DIR = Path("images/cache")
CACHE_DIR.mkdir(parents=True, exist_ok=True)

# Load the SDXL base model in FP32
pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=torch.float32
).to("cuda")
pipe.vae.to(torch.float32)

# Attach the trained LoRA weights and merge them into the base UNet
pipe.unet = PeftModel.from_pretrained(pipe.unet, "sdxl_lora")
pipe.unet = pipe.unet.merge_and_unload()
pipe.unet.eval()

NEGATIVE = """
text,
watermark,
people,
photo,
photorealistic,
complex background
"""


def _hash_prompt(prompt: str) -> str:
    return hashlib.sha256(prompt.encode()).hexdigest()[:16]


def generate_slide_image(prompt: str) -> str:
    h = _hash_prompt(prompt)
    image_path = CACHE_DIR / f"{h}.png"

    if image_path.exists():
        print("Image cache hit")
        return str(image_path)

    print("Generating image...")

    image = pipe(
        prompt,
        negative_prompt=NEGATIVE,
        num_inference_steps=30,
        guidance_scale=7,
        height=1024,
        width=1024
    ).images[0]

    image.save(image_path)

    return str(image_path)


if __name__ == "__main__":
    prompt = """
clean enterprise illustration,
corporate presentation slide,
minimal design,
white background,
blue gradient accents,
expose oracle integration connected to autonomous database and exposing through api gateway
"""
    generate_slide_image(prompt)

------------------------------------------------------------------------

train_diffusion.py

import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
from pathlib import Path

from diffusers import StableDiffusionXLPipeline, DDPMScheduler
from peft import LoraConfig, get_peft_model
import torch.nn.functional as F

# =========================
# CONFIG
# =========================
IMAGES_DIR = Path("dataset/images")
CAPTIONS_DIR = Path("dataset/captions")
OUTPUT_DIR = Path("sdxl_lora")

IMAGE_SIZE = 1024
BATCH_SIZE = 2
EPOCHS = 10
LR = 1e-4
DEVICE = "cuda"

# =========================
# DATASET
# =========================
class ImageCaptionDataset(Dataset):
    def __init__(self):
        self.images = sorted([p for p in IMAGES_DIR.glob("*") if p.suffix.lower() in [".png", ".jpg", ".jpeg", ".webp"]])
        self.transform = transforms.Compose([
            transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
            transforms.ToTensor(),
            # RGB: normalize all 3 channels to [-1, 1]
            transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
        ])

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img_path = self.images[idx]
        caption_path = CAPTIONS_DIR / (img_path.stem + ".txt")
        if not caption_path.exists():
            raise FileNotFoundError(f"Caption not found: {caption_path}")

        image = Image.open(img_path).convert("RGB")
        image = self.transform(image)
        caption = caption_path.read_text(encoding="utf-8").strip()
        return image, caption

dataset = ImageCaptionDataset()
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2, pin_memory=True)

# =========================
# LOAD SDXL
# =========================
pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=torch.float32,
    use_safetensors=True,
).to(DEVICE)

# VAE must stay in FP32
pipe.vae.to(dtype=torch.float32)

# UNet in FP16
pipe.unet.to(dtype=torch.float16)

# Text encoders in FP16
pipe.text_encoder.to(dtype=torch.float16)
pipe.text_encoder_2.to(dtype=torch.float16)

# pipe.enable_xformers_memory_efficient_attention()

tokenizer_1 = pipe.tokenizer
tokenizer_2 = pipe.tokenizer_2
text_encoder_1 = pipe.text_encoder
text_encoder_2 = pipe.text_encoder_2
vae = pipe.vae
unet = pipe.unet

scheduler = DDPMScheduler.from_config(pipe.scheduler.config)

# Freeze the VAE and text encoders (only the UNet LoRA is trained)
vae.eval()
vae.requires_grad_(False)
text_encoder_1.eval()
text_encoder_1.requires_grad_(False)
text_encoder_2.eval()
text_encoder_2.requires_grad_(False)

# =========================
# LoRA CONFIG (UNet)
# =========================
lora_config = LoraConfig(
    r=8,
    lora_alpha=16,
    target_modules=["to_q", "to_k", "to_v", "to_out.0"],
    lora_dropout=0.1,
    bias="none",
)

unet = get_peft_model(unet, lora_config)
unet.train()

optimizer = torch.optim.AdamW(unet.parameters(), lr=LR)

# =========================
# SDXL conditioning helpers
# =========================
def encode_prompt_sdxl(captions):
    # encoder 1
    inputs_1 = tokenizer_1(
        captions,
        padding="max_length",
        truncation=True,
        max_length=tokenizer_1.model_max_length,
        return_tensors="pt",
    )
    # encoder 2
    inputs_2 = tokenizer_2(
        captions,
        padding="max_length",
        truncation=True,
        max_length=tokenizer_2.model_max_length,
        return_tensors="pt",
    )

    input_ids_1 = inputs_1.input_ids.to(DEVICE)
    input_ids_2 = inputs_2.input_ids.to(DEVICE)

    with torch.no_grad():
        out_1 = text_encoder_1(input_ids_1, output_hidden_states=True)
        out_2 = text_encoder_2(input_ids_2, output_hidden_states=True)

    # SDXL uses the penultimate hidden layer as token embeddings
    prompt_embeds_1 = out_1.hidden_states[-2]
    prompt_embeds_2 = out_2.hidden_states[-2]

    # concatenate the embeddings of both encoders along the feature axis
    prompt_embeds = torch.cat([prompt_embeds_1, prompt_embeds_2], dim=-1)

    # pooled embedding: normally comes from text_encoder_2
    pooled_prompt_embeds = out_2.hidden_states[-1][:, 0]

    return prompt_embeds, pooled_prompt_embeds

def make_time_ids(batch_size):
    # SDXL micro-conditioning:
    # [orig_h, orig_w, crop_y, crop_x, target_h, target_w]
    return torch.tensor(
        [[IMAGE_SIZE, IMAGE_SIZE, 0, 0, IMAGE_SIZE, IMAGE_SIZE]],
        device=DEVICE,
        dtype=torch.float16,
    ).repeat(batch_size, 1)

# =========================
# TRAIN
# =========================
for epoch in range(EPOCHS):
    for step, (images, captions) in enumerate(dataloader):
        images = images.to(DEVICE, dtype=torch.float32)

        # VAE encode -> latents (frozen, FP32)
        with torch.no_grad():
            latents = vae.encode(images).latent_dist.sample()
            latents = latents * vae.config.scaling_factor
            latents = latents.to(torch.float16)

        # noise + timestep
        noise = torch.randn_like(latents)
        timesteps = torch.randint(
            0,
            scheduler.config.num_train_timesteps,
            (latents.shape[0],),
            device=DEVICE,
        ).long()

        noisy_latents = scheduler.add_noise(latents, noise, timesteps)

        # SDXL conditioning
        prompt_embeds, pooled_prompt_embeds = encode_prompt_sdxl(captions)
        time_ids = make_time_ids(latents.shape[0])

        # UNet prediction
        noise_pred = unet(
            noisy_latents,
            timesteps,
            encoder_hidden_states=prompt_embeds,
            added_cond_kwargs={
                "text_embeds": pooled_prompt_embeds,
                "time_ids": time_ids,
            },
        ).sample

        loss = F.mse_loss(noise_pred.float(), noise.float())

        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()

        if step % 10 == 0:
            print(f"Epoch {epoch} Step {step} Loss {loss.item():.4f}")

# =========================
# SAVE LoRA
# =========================
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

unet.save_pretrained("sdxl_lora")

print("LoRA saved correctly.")