# Mirror of https://github.com/hoshikawa2/image_lora_training.git
# Synced 2026-03-11 17:14:57 +00:00 (218 lines, 6.5 KiB, Python)
import torch
import torch.nn.functional as F
from pathlib import Path

from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

from diffusers import StableDiffusionXLPipeline, DDPMScheduler
from diffusers.utils import convert_state_dict_to_diffusers
from peft import LoraConfig, get_peft_model, get_peft_model_state_dict
# =========================
# CONFIG
# =========================

# Dataset layout: one caption .txt per image, matched by file stem.
IMAGES_DIR = Path("dataset/images")
CAPTIONS_DIR = Path("dataset/captions")
OUTPUT_DIR = Path("sdxl_lora")

# Training hyperparameters.
IMAGE_SIZE = 1024  # SDXL native square resolution
BATCH_SIZE = 2
EPOCHS = 10
LR = 1e-4
DEVICE = "cuda"
|
# =========================
# DATASET
# =========================
class ImageCaptionDataset(Dataset):
    """Pairs images with caption .txt files sharing the same file stem.

    Each image is resized to a square ``image_size``, converted to a tensor
    and normalized from [0, 1] to [-1, 1] (the input range the SDXL VAE
    expects).

    Args:
        images_dir: directory scanned (non-recursively) for images.
        captions_dir: directory holding one ``<stem>.txt`` per image.
        image_size: square side length images are resized to.

    Raises:
        FileNotFoundError: from ``__getitem__`` when an image has no caption.
    """

    _EXTENSIONS = {".png", ".jpg", ".jpeg", ".webp"}

    def __init__(self, images_dir=IMAGES_DIR, captions_dir=CAPTIONS_DIR,
                 image_size=IMAGE_SIZE):
        # Sorted for a deterministic sample order across runs.
        self.images = sorted(
            p for p in Path(images_dir).glob("*")
            if p.suffix.lower() in self._EXTENSIONS
        )
        self.captions_dir = Path(captions_dir)
        self.transform = transforms.Compose([
            # NOTE(review): a plain Resize to a square distorts non-square
            # images; the usual recipe is Resize + CenterCrop — confirm the
            # distortion is intended for this dataset.
            transforms.Resize((image_size, image_size)),
            transforms.ToTensor(),
            # RGB: normalize all 3 channels to [-1, 1].
            transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
        ])

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img_path = self.images[idx]
        caption_path = self.captions_dir / (img_path.stem + ".txt")
        if not caption_path.exists():
            raise FileNotFoundError(f"Caption não encontrada: {caption_path}")

        image = self.transform(Image.open(img_path).convert("RGB"))
        caption = caption_path.read_text(encoding="utf-8").strip()
        return image, caption
|
dataset = ImageCaptionDataset()
# NOTE(review): num_workers > 0 at module level without an
# `if __name__ == "__main__":` guard breaks on spawn-based platforms
# (Windows/macOS) — confirm this script only targets Linux/fork.
dataloader = DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=2,
    pin_memory=True,
)
|
# =========================
# LOAD SDXL
# =========================
pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=torch.float32,
    use_safetensors=True,
).to(DEVICE)

# Mixed-precision layout: the VAE stays in FP32 (kept in full precision
# here), while the UNet and both text encoders are cast to FP16.
pipe.vae.to(dtype=torch.float32)
pipe.unet.to(dtype=torch.float16)
pipe.text_encoder.to(dtype=torch.float16)
pipe.text_encoder_2.to(dtype=torch.float16)

# pipe.enable_xformers_memory_efficient_attention()
|
# Convenient handles to the pipeline components used below.
tokenizer_1, tokenizer_2 = pipe.tokenizer, pipe.tokenizer_2
text_encoder_1, text_encoder_2 = pipe.text_encoder, pipe.text_encoder_2
vae, unet = pipe.vae, pipe.unet

# Train against a DDPM noise schedule built from the pipeline's scheduler
# configuration.
scheduler = DDPMScheduler.from_config(pipe.scheduler.config)
|
# Freeze the VAE and both text encoders — only the LoRA adapters added to
# the UNet below are trained.
for frozen_module in (vae, text_encoder_1, text_encoder_2):
    frozen_module.eval()
    frozen_module.requires_grad_(False)
|
# =========================
# LoRA CONFIG (UNet)
# =========================
lora_config = LoraConfig(
    r=8,
    lora_alpha=16,
    # Attention projection layers of the SDXL UNet.
    target_modules=["to_q", "to_k", "to_v", "to_out.0"],
    lora_dropout=0.1,
    bias="none",
)

unet = get_peft_model(unet, lora_config)
unet.train()

# FIX: optimize only the trainable (LoRA) parameters. PEFT freezes every
# other UNet parameter, so handing the full parameter list to AdamW only
# wastes optimizer-state bookkeeping on tensors that never receive grads.
trainable_params = [p for p in unet.parameters() if p.requires_grad]
optimizer = torch.optim.AdamW(trainable_params, lr=LR)
|
# =========================
# Helpers SDXL conditioning
# =========================
def encode_prompt_sdxl(captions):
    """Encode a batch of caption strings into SDXL conditioning tensors.

    Args:
        captions: list/tuple of prompt strings (one per batch sample).

    Returns:
        prompt_embeds: penultimate hidden states of both CLIP text encoders
            concatenated on the feature axis, as the SDXL UNet expects.
        pooled_prompt_embeds: projected pooled embedding from the second
            text encoder.
    """
    inputs_1 = tokenizer_1(
        captions,
        padding="max_length",
        truncation=True,
        max_length=tokenizer_1.model_max_length,
        return_tensors="pt",
    )
    inputs_2 = tokenizer_2(
        captions,
        padding="max_length",
        truncation=True,
        max_length=tokenizer_2.model_max_length,
        return_tensors="pt",
    )

    input_ids_1 = inputs_1.input_ids.to(DEVICE)
    input_ids_2 = inputs_2.input_ids.to(DEVICE)

    # The encoders are frozen — no gradients need to flow through them.
    with torch.no_grad():
        out_1 = text_encoder_1(input_ids_1, output_hidden_states=True)
        out_2 = text_encoder_2(input_ids_2, output_hidden_states=True)

    # SDXL uses the penultimate hidden layer as token embeddings; the two
    # encoders' features are concatenated along the last axis.
    prompt_embeds = torch.cat(
        [out_1.hidden_states[-2], out_2.hidden_states[-2]], dim=-1
    )

    # BUGFIX: the pooled embedding must be the projected pooled output of
    # text_encoder_2 (`text_embeds`), as in the reference diffusers SDXL
    # training scripts. The previous `hidden_states[-1][:, 0]` took the raw
    # last-layer embedding of the first (BOS) token, which is neither CLIP's
    # pooling (EOS token) nor passed through the text projection.
    pooled_prompt_embeds = out_2.text_embeds

    return prompt_embeds, pooled_prompt_embeds
|
def make_time_ids(batch_size):
    """Build the SDXL micro-conditioning tensor for a batch.

    Per-sample layout: [orig_h, orig_w, crop_top, crop_left, target_h,
    target_w]. Every image here is resized to IMAGE_SIZE x IMAGE_SIZE with
    no cropping, so a single constant row is repeated batch_size times.
    """
    row = [IMAGE_SIZE, IMAGE_SIZE, 0, 0, IMAGE_SIZE, IMAGE_SIZE]
    ids = torch.tensor([row], device=DEVICE, dtype=torch.float16)
    return ids.repeat(batch_size, 1)
|
# =========================
# TRAIN
# =========================
for epoch in range(EPOCHS):
    for step, (images, captions) in enumerate(dataloader):
        # The VAE runs in FP32. (FIX: the original re-froze and re-cast the
        # VAE on every step — that is loop-invariant setup already done
        # above, so it has been removed.)
        images = images.to(DEVICE, dtype=torch.float32)

        # VAE encode -> scaled latents, then cast to FP16 to match the UNet.
        with torch.no_grad():
            latents = vae.encode(images).latent_dist.sample()
            latents = latents * vae.config.scaling_factor
        latents = latents.to(torch.float16)

        # Sample noise and an independent random timestep per sample.
        noise = torch.randn_like(latents)
        timesteps = torch.randint(
            0,
            scheduler.config.num_train_timesteps,
            (latents.shape[0],),
            device=DEVICE,
        ).long()

        noisy_latents = scheduler.add_noise(latents, noise, timesteps)

        # SDXL conditioning: token embeds, pooled embeds, micro-conditioning.
        prompt_embeds, pooled_prompt_embeds = encode_prompt_sdxl(captions)
        time_ids = make_time_ids(latents.shape[0])

        # UNet predicts the added noise (epsilon objective).
        noise_pred = unet(
            noisy_latents,
            timesteps,
            encoder_hidden_states=prompt_embeds,
            added_cond_kwargs={
                "text_embeds": pooled_prompt_embeds,
                "time_ids": time_ids,
            },
        ).sample

        # Compute the MSE loss in FP32 for numerical stability.
        loss = F.mse_loss(noise_pred.float(), noise.float())

        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()

        if step % 10 == 0:
            print(f"Epoch {epoch} Step {step} Loss {loss.item():.4f}")
|
# =========================
# SAVE LoRA
# =========================

OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

# FIX: save to the configured OUTPUT_DIR instead of a duplicated hard-coded
# path string (they happen to match today, but changing the config would
# silently desynchronize them).
unet.save_pretrained(OUTPUT_DIR)

print("LoRA saved correctly.")