# image_lora_training/train_diffusion.py
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
from pathlib import Path
from diffusers import StableDiffusionXLPipeline, DDPMScheduler
from diffusers.utils import convert_state_dict_to_diffusers
from peft import LoraConfig, get_peft_model, get_peft_model_state_dict
# =========================
# CONFIG
# =========================
IMAGES_DIR = Path("dataset/images")
CAPTIONS_DIR = Path("dataset/captions")
OUTPUT_DIR = Path("sdxl_lora")
IMAGE_SIZE = 1024
BATCH_SIZE = 2
EPOCHS = 10
LR = 1e-4
DEVICE = "cuda"
# =========================
# DATASET
# =========================
class ImageCaptionDataset(Dataset):
    def __init__(self):
        self.images = sorted(
            p for p in IMAGES_DIR.glob("*")
            if p.suffix.lower() in {".png", ".jpg", ".jpeg", ".webp"}
        )
        self.transform = transforms.Compose([
            transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
            transforms.ToTensor(),
            # RGB: normalize the 3 channels to [-1, 1]
            transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
        ])

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img_path = self.images[idx]
        caption_path = CAPTIONS_DIR / (img_path.stem + ".txt")
        if not caption_path.exists():
            raise FileNotFoundError(f"Caption not found: {caption_path}")
        image = Image.open(img_path).convert("RGB")
        image = self.transform(image)
        caption = caption_path.read_text(encoding="utf-8").strip()
        return image, caption
dataset = ImageCaptionDataset()
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2, pin_memory=True)
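# Fail fast if the image directory is empty or no extension matched
assert len(dataset) > 0, f"No images found in {IMAGES_DIR}"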
# =========================
# LOAD SDXL
# =========================
pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=torch.float32,
    use_safetensors=True,
).to(DEVICE)
# The VAE must stay in FP32 (the stock SDXL VAE is prone to overflow in FP16)
pipe.vae.to(dtype=torch.float32)
# UNet in FP16
pipe.unet.to(dtype=torch.float16)
# Text encoders in FP16
pipe.text_encoder.to(dtype=torch.float16)
pipe.text_encoder_2.to(dtype=torch.float16)
# Optional: pipe.enable_xformers_memory_efficient_attention()
tokenizer_1 = pipe.tokenizer
tokenizer_2 = pipe.tokenizer_2
text_encoder_1 = pipe.text_encoder
text_encoder_2 = pipe.text_encoder_2
vae = pipe.vae
unet = pipe.unet
scheduler = DDPMScheduler.from_config(pipe.scheduler.config)
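# DDPMScheduler is only used to add noise at training time; the pipeline
# keeps its own scheduler for sampling at inference.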
# Freeze the VAE and text encoders (only the LoRA on the UNet is trained)
vae.eval()
vae.requires_grad_(False)
text_encoder_1.eval()
text_encoder_1.requires_grad_(False)
text_encoder_2.eval()
text_encoder_2.requires_grad_(False)
# =========================
# LoRA CONFIG (UNet)
# =========================
lora_config = LoraConfig(
    r=8,
    lora_alpha=16,
    target_modules=["to_q", "to_k", "to_v", "to_out.0"],
    lora_dropout=0.1,
    bias="none",
)
unet = get_peft_model(unet, lora_config)
unet.train()
# Only the injected LoRA parameters require grad; pass just those to AdamW
optimizer = torch.optim.AdamW(
    (p for p in unet.parameters() if p.requires_grad), lr=LR
)
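# Optional sanity check: with everything else frozen, the trainable fraction
# should be a small slice of the UNet (recent PEFT versions also keep the
# adapter weights in FP32 even on an FP16 base model).
unet.print_trainable_parameters()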
# =========================
# Helpers SDXL conditioning
# =========================
def encode_prompt_sdxl(captions):
    # encoder 1 (CLIP ViT-L)
    inputs_1 = tokenizer_1(
        captions,
        padding="max_length",
        truncation=True,
        max_length=tokenizer_1.model_max_length,
        return_tensors="pt",
    )
    # encoder 2 (OpenCLIP ViT-bigG)
    inputs_2 = tokenizer_2(
        captions,
        padding="max_length",
        truncation=True,
        max_length=tokenizer_2.model_max_length,
        return_tensors="pt",
    )
    input_ids_1 = inputs_1.input_ids.to(DEVICE)
    input_ids_2 = inputs_2.input_ids.to(DEVICE)
    with torch.no_grad():
        out_1 = text_encoder_1(input_ids_1, output_hidden_states=True)
        out_2 = text_encoder_2(input_ids_2, output_hidden_states=True)
    # SDXL uses the penultimate hidden layer as token embeddings
    prompt_embeds_1 = out_1.hidden_states[-2]
    prompt_embeds_2 = out_2.hidden_states[-2]
    # concatenate the embeddings of both encoders along the feature axis
    prompt_embeds = torch.cat([prompt_embeds_1, prompt_embeds_2], dim=-1)
    # the pooled embed comes from text_encoder_2 (CLIPTextModelWithProjection),
    # whose projected pooled output is `text_embeds`
    pooled_prompt_embeds = out_2.text_embeds
    return prompt_embeds, pooled_prompt_embeds
def make_time_ids(batch_size):
    # SDXL micro-conditioning:
    # [orig_h, orig_w, crop_y, crop_x, target_h, target_w]
    return torch.tensor(
        [[IMAGE_SIZE, IMAGE_SIZE, 0, 0, IMAGE_SIZE, IMAGE_SIZE]],
        device=DEVICE,
        dtype=torch.float16,
    ).repeat(batch_size, 1)
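# Note: constant time_ids are fine here because every image is resized to a
# fixed square with no cropping; if you add random crops, pass the real
# per-sample crop offsets instead, as the official SDXL training scripts do.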
# =========================
# TRAIN
# =========================
for epoch in range(EPOCHS):
    for step, (images, captions) in enumerate(dataloader):
        images = images.to(DEVICE, dtype=torch.float32)
        # VAE encode -> latents (encode in FP32, then cast down for the FP16 UNet)
        with torch.no_grad():
            latents = vae.encode(images).latent_dist.sample()
            latents = latents * vae.config.scaling_factor
        latents = latents.to(torch.float16)
        # noise + timestep
        noise = torch.randn_like(latents)
        timesteps = torch.randint(
            0,
            scheduler.config.num_train_timesteps,
            (latents.shape[0],),
            device=DEVICE,
        ).long()
        noisy_latents = scheduler.add_noise(latents, noise, timesteps)
        # SDXL conditioning
        prompt_embeds, pooled_prompt_embeds = encode_prompt_sdxl(captions)
        time_ids = make_time_ids(latents.shape[0])
        # UNet prediction
        noise_pred = unet(
            noisy_latents,
            timesteps,
            encoder_hidden_states=prompt_embeds,
            added_cond_kwargs={
                "text_embeds": pooled_prompt_embeds,
                "time_ids": time_ids,
            },
        ).sample
        # MSE against the true noise (assumes epsilon prediction, the default here)
        loss = F.mse_loss(noise_pred.float(), noise.float())
        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()
        if step % 10 == 0:
            print(f"Epoch {epoch} Step {step} Loss {loss.item():.4f}")
# =========================
# SAVE LoRA (diffusers-compatible format)
# =========================
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
# Extract only the LoRA weights and convert them to the diffusers key layout,
# so the file can later be loaded with pipe.load_lora_weights()
unet_lora_state_dict = convert_state_dict_to_diffusers(get_peft_model_state_dict(unet))
StableDiffusionXLPipeline.save_lora_weights(
    save_directory=OUTPUT_DIR,
    unet_lora_layers=unet_lora_state_dict,
)
print(f"LoRA saved to {OUTPUT_DIR}")