import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
from pathlib import Path
from diffusers import StableDiffusionXLPipeline, DDPMScheduler
from diffusers.utils import convert_state_dict_to_diffusers
from peft import LoraConfig, get_peft_model, get_peft_model_state_dict

# =========================
# CONFIG
# =========================
IMAGES_DIR = Path("dataset/images")
CAPTIONS_DIR = Path("dataset/captions")
OUTPUT_DIR = Path("sdxl_lora")

IMAGE_SIZE = 1024
BATCH_SIZE = 2
EPOCHS = 10
LR = 1e-4
DEVICE = "cuda"


# =========================
# DATASET
# =========================
class ImageCaptionDataset(Dataset):
    """Pairs each image in IMAGES_DIR with the same-named .txt file in CAPTIONS_DIR."""

    def __init__(self):
        self.images = sorted(
            p for p in IMAGES_DIR.glob("*")
            if p.suffix.lower() in {".png", ".jpg", ".jpeg", ".webp"}
        )
        self.transform = transforms.Compose([
            # Note: a fixed square resize stretches non-square images.
            transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
            transforms.ToTensor(),
            # RGB: normalize all 3 channels from [0, 1] to [-1, 1], as the VAE expects
            transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
        ])

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img_path = self.images[idx]
        caption_path = CAPTIONS_DIR / (img_path.stem + ".txt")
        if not caption_path.exists():
            raise FileNotFoundError(f"Caption not found: {caption_path}")
        image = Image.open(img_path).convert("RGB")
        image = self.transform(image)
        caption = caption_path.read_text(encoding="utf-8").strip()
        return image, caption


dataset = ImageCaptionDataset()
dataloader = DataLoader(
    dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2, pin_memory=True
)

# =========================
# LOAD SDXL
# =========================
pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=torch.float32,
    use_safetensors=True,
).to(DEVICE)

# The SDXL VAE is numerically fragile in FP16, so keep it in FP32
pipe.vae.to(dtype=torch.float32)
# UNet in FP16
pipe.unet.to(dtype=torch.float16)
# Text encoders in FP16
pipe.text_encoder.to(dtype=torch.float16)
pipe.text_encoder_2.to(dtype=torch.float16)

# pipe.enable_xformers_memory_efficient_attention()

tokenizer_1 = pipe.tokenizer
tokenizer_2 = pipe.tokenizer_2
text_encoder_1 = pipe.text_encoder
text_encoder_2 = pipe.text_encoder_2
vae = pipe.vae
unet = pipe.unet
scheduler = DDPMScheduler.from_config(pipe.scheduler.config)

# Freeze the VAE and both text encoders (only the LoRA on the UNet is trained)
vae.eval()
vae.requires_grad_(False)
text_encoder_1.eval()
text_encoder_1.requires_grad_(False)
text_encoder_2.eval()
text_encoder_2.requires_grad_(False)

# =========================
# LoRA CONFIG (UNet)
# =========================
lora_config = LoraConfig(
    r=8,
    lora_alpha=16,
    target_modules=["to_q", "to_k", "to_v", "to_out.0"],  # attention projections
    lora_dropout=0.1,
    bias="none",
)
unet = get_peft_model(unet, lora_config)
unet.train()

# Only the LoRA matrices require gradients; hand the optimizer just those
# instead of every (frozen) UNet weight.
optimizer = torch.optim.AdamW(
    (p for p in unet.parameters() if p.requires_grad), lr=LR
)
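# Optional sanity check before training: PEFT models expose
# print_trainable_parameters(), which should report that only a small
# fraction of the UNet's weights (the LoRA matrices) is trainable.
unet.print_trainable_parameters()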
# =========================
# Helpers: SDXL conditioning
# =========================
def encode_prompt_sdxl(captions):
    # encoder 1
    inputs_1 = tokenizer_1(
        captions,
        padding="max_length",
        truncation=True,
        max_length=tokenizer_1.model_max_length,
        return_tensors="pt",
    )
    # encoder 2
    inputs_2 = tokenizer_2(
        captions,
        padding="max_length",
        truncation=True,
        max_length=tokenizer_2.model_max_length,
        return_tensors="pt",
    )
    input_ids_1 = inputs_1.input_ids.to(DEVICE)
    input_ids_2 = inputs_2.input_ids.to(DEVICE)

    with torch.no_grad():
        out_1 = text_encoder_1(input_ids_1, output_hidden_states=True)
        out_2 = text_encoder_2(input_ids_2, output_hidden_states=True)

    # SDXL uses the penultimate hidden layer of each encoder as token embeddings
    prompt_embeds_1 = out_1.hidden_states[-2]
    prompt_embeds_2 = out_2.hidden_states[-2]

    # concatenate both encoders' embeddings along the feature axis
    prompt_embeds = torch.cat([prompt_embeds_1, prompt_embeds_2], dim=-1)

    # pooled embed comes from text_encoder_2, a CLIPTextModelWithProjection,
    # which exposes the projected pooled output as `text_embeds`
    # (indexing hidden_states[-1][:, 0] would grab the BOS token instead)
    pooled_prompt_embeds = out_2.text_embeds
    return prompt_embeds, pooled_prompt_embeds


def make_time_ids(batch_size):
    # SDXL micro-conditioning:
    # [orig_h, orig_w, crop_top, crop_left, target_h, target_w]
    return torch.tensor(
        [[IMAGE_SIZE, IMAGE_SIZE, 0, 0, IMAGE_SIZE, IMAGE_SIZE]],
        device=DEVICE,
        dtype=torch.float16,
    ).repeat(batch_size, 1)


# =========================
# TRAIN
# =========================
for epoch in range(EPOCHS):
    for step, (images, captions) in enumerate(dataloader):
        # pixel inputs stay in FP32 to match the FP32 VAE
        images = images.to(DEVICE, dtype=torch.float32)

        # VAE encode -> latents (the VAE is already frozen and in FP32)
        with torch.no_grad():
            latents = vae.encode(images).latent_dist.sample()
            latents = latents * vae.config.scaling_factor
        latents = latents.to(torch.float16)  # back to the UNet's dtype

        # noise + timestep
        noise = torch.randn_like(latents)
        timesteps = torch.randint(
            0,
            scheduler.config.num_train_timesteps,
            (latents.shape[0],),
            device=DEVICE,
        ).long()
        noisy_latents = scheduler.add_noise(latents, noise, timesteps)

        # SDXL conditioning
        prompt_embeds, pooled_prompt_embeds = encode_prompt_sdxl(captions)
        time_ids = make_time_ids(latents.shape[0])

        # UNet noise prediction
        noise_pred = unet(
            noisy_latents,
            timesteps,
            encoder_hidden_states=prompt_embeds,
            added_cond_kwargs={
                "text_embeds": pooled_prompt_embeds,
                "time_ids": time_ids,
            },
        ).sample

        # epsilon-prediction objective; compare in FP32 for a stable loss
        loss = F.mse_loss(noise_pred.float(), noise.float())

        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()

        if step % 10 == 0:
            print(f"Epoch {epoch} Step {step} Loss {loss.item():.4f}")

# =========================
# SAVE LoRA
# =========================
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
unet.save_pretrained(OUTPUT_DIR)  # PEFT adapter format (adapter_model.safetensors)
print(f"LoRA adapter saved to {OUTPUT_DIR}")
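# The helpers imported at the top (get_peft_model_state_dict and
# convert_state_dict_to_diffusers) can additionally export the adapter in the
# layout that diffusers' pipe.load_lora_weights() expects. A minimal sketch,
# assuming a recent peft/diffusers: because get_peft_model() wraps the UNet,
# the state-dict keys carry a "base_model.model." prefix, stripped here
# before conversion.
lora_state_dict = get_peft_model_state_dict(unet)
lora_state_dict = {
    k.removeprefix("base_model.model."): v for k, v in lora_state_dict.items()
}
StableDiffusionXLPipeline.save_lora_weights(
    save_directory=OUTPUT_DIR,
    unet_lora_layers=convert_state_dict_to_diffusers(lora_state_dict),
)
# For inference, the exported weights can then be attached to a fresh pipeline:
#   pipe.load_lora_weights(OUTPUT_DIR)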