mirror of https://github.com/hoshikawa2/image_lora_training.git
synced 2026-03-11 17:14:57 +00:00
first commit

README.md (new file)

# Training Image Generation Models with Diffusion (Stable Diffusion XL + LoRA)

## Introduction

Image generation using diffusion models has become one of the most transformative capabilities of modern artificial intelligence. These models can generate high-quality images from natural-language descriptions, enabling applications across many industries.

Real-world use cases include:

- **Marketing and advertising** -- generating visual assets automatically.
- **Software documentation and presentations** -- producing diagrams and illustrations for technical content.
- **Game development** -- generating textures, characters, and environments.
- **Product design** -- visualizing concepts before prototypes exist.
- **Enterprise automation** -- generating architecture diagrams or slide illustrations automatically.

In enterprise environments, diffusion models can be integrated into automated pipelines. For example, a presentation generator can automatically produce slide images that represent architectural concepts such as:

- API Gateways
- Databases
- Integration architectures
- Cloud infrastructures

Instead of manually creating diagrams, an AI pipeline can generate them dynamically from structured prompts.

This tutorial explains how to train and use a **Stable Diffusion XL model with LoRA fine-tuning**, and how to deploy the model to generate images programmatically.

------------------------------------------------------------------------

# Technologies Involved

## Diffusion Models

Diffusion models generate images by **iteratively denoising random noise**. Training teaches a neural network to reverse a noising process applied to real images.

The process works as follows:

1. Start with a real image
2. Gradually add noise until the image becomes pure noise
3. Train a model to reverse this process
4. During inference, start with noise and iteratively remove it

This allows the model to synthesize new images from text descriptions.

Popular diffusion models include:

- Stable Diffusion
- Stable Diffusion XL (SDXL)
- DALL-E
- Imagen

In this tutorial we use **Stable Diffusion XL**, which provides:

- Higher resolution
- Better text understanding
- Dual text encoders
- Micro-conditioning

------------------------------------------------------------------------
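
The forward-noising step described above can be sketched with plain tensors. This is a minimal illustration, not code from the repository; `alpha_bar` stands for the cumulative noise-schedule coefficient that real schedulers compute per timestep.

```python
import torch

# Toy forward-noising step: blend a clean sample with Gaussian noise.
# alpha_bar is the cumulative noise-schedule coefficient at timestep t
# (an assumption for this sketch; real schedulers compute it per step).
def add_noise(x0: torch.Tensor, noise: torch.Tensor, alpha_bar: float) -> torch.Tensor:
    return (alpha_bar ** 0.5) * x0 + ((1.0 - alpha_bar) ** 0.5) * noise

x0 = torch.ones(2, 3)
noise = torch.zeros(2, 3)

print(add_noise(x0, noise, 1.0))  # alpha_bar = 1: the clean sample, untouched
print(add_noise(x0, noise, 0.0))  # alpha_bar = 0: only the noise remains
```

Training then amounts to teaching a network to recover `noise` from the blended result, so that at inference time the process can be run in reverse.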

## Stable Diffusion XL (SDXL)

SDXL is an advanced diffusion architecture that improves generation quality through:

- **Two text encoders**
- **Improved conditioning**
- **Higher resolution generation**
- **Better prompt interpretation**

Unlike earlier diffusion models, SDXL requires:

- two tokenizers
- two text encoders
- pooled embeddings
- time conditioning parameters

------------------------------------------------------------------------

## LoRA (Low-Rank Adaptation)

Training diffusion models from scratch is extremely expensive. Instead, **LoRA** fine-tunes large models efficiently by training small low-rank matrices that modify the attention layers of the network.

Advantages:

- Very small training footprint
- Works with limited VRAM
- Easy to merge into the base model
- Fast training

In this project, LoRA is applied to the **UNet attention layers**.

------------------------------------------------------------------------
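
The low-rank idea can be sketched without any library; the dimensions and the `lora_forward` helper below are illustrative, not part of the project code:

```python
import torch

d, r, alpha = 64, 8, 16  # layer width, LoRA rank, LoRA scaling (illustrative values)

W = torch.randn(d, d)         # frozen pretrained weight
A = torch.randn(r, d) * 0.01  # trainable low-rank factor
B = torch.zeros(d, r)         # zero-initialized, so training starts exactly at W

def lora_forward(x: torch.Tensor) -> torch.Tensor:
    # Effective weight: W + (alpha / r) * B @ A, stored as two small
    # matrices (2*d*r values) instead of a full d x d delta.
    return x @ (W + (alpha / r) * B @ A).T

x = torch.randn(4, d)
print(torch.allclose(lora_forward(x), x @ W.T))  # True while B is still zero
print(2 * d * r, "trainable values vs", d * d)   # 1024 vs 4096
```

Only `A` and `B` receive gradients; the frozen `W` is untouched, which is why the training footprint stays small.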

## HuggingFace Diffusers

The **Diffusers library** provides a high-level API for working with diffusion models.

It includes:

- pipelines
- schedulers
- training utilities
- optimization helpers

Main components used:

- `StableDiffusionXLPipeline`
- `DDPMScheduler`
- `DPMSolverMultistepScheduler`

------------------------------------------------------------------------

## PyTorch

PyTorch is used for:

- training loops
- GPU acceleration
- tensor operations
- neural network execution

------------------------------------------------------------------------

# Code Walkthrough

## Dataset Structure

The training script expects a dataset structured as:

    dataset/
      images/
        image1.png
        image2.png
      captions/
        image1.txt
        image2.txt

Each caption is a plain-text description of its image.

Example caption:

    enterprise cloud architecture diagram with API gateway and database

------------------------------------------------------------------------

# Dataset Loader

The dataset loader reads images and their captions.

Key operations:

- resizing images
- converting to tensors
- normalization

Important section:

```python
transforms.Compose([
    transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
])
```

Normalization is important because diffusion models expect images in the **[-1, 1] range**.

------------------------------------------------------------------------

# Prompt Encoding

SDXL uses **two text encoders**.

The function `encode_prompt_sdxl()` performs:

1. tokenization of captions
2. embedding generation
3. concatenation of embeddings

Important concept:

```python
prompt_embeds = torch.cat([prompt_embeds_1, prompt_embeds_2], dim=-1)
```

This merges both encoders into a single conditioning representation.

------------------------------------------------------------------------
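
The shapes involved can be checked with dummy tensors. The hidden sizes 768 and 1280 correspond to the two SDXL base text encoders; the random tensors below merely stand in for their outputs:

```python
import torch

batch, seq_len = 2, 77  # both CLIP tokenizers pad to 77 tokens

# Dummy stand-ins for the two encoders' penultimate hidden states
prompt_embeds_1 = torch.randn(batch, seq_len, 768)   # text encoder 1 (CLIP ViT-L)
prompt_embeds_2 = torch.randn(batch, seq_len, 1280)  # text encoder 2 (OpenCLIP bigG)

# Concatenate along the feature axis, as in encode_prompt_sdxl()
prompt_embeds = torch.cat([prompt_embeds_1, prompt_embeds_2], dim=-1)

print(prompt_embeds.shape)  # batch x 77 x 2048
```

The resulting 2048-dimensional features are what the SDXL UNet consumes as cross-attention conditioning.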

# Latent Encoding

Images are encoded into latent space using the **VAE**.

```python
latents = vae.encode(images).latent_dist.sample()
```

The VAE compresses images before the diffusion process is trained on them. This drastically reduces memory usage.

------------------------------------------------------------------------
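
The savings can be quantified: the SDXL VAE downsamples each spatial dimension by a factor of 8 and produces 4 latent channels, so a quick back-of-the-envelope calculation gives:

```python
# Pixel space: 3 channels at full resolution
H = W = 1024
pixel_values = 3 * H * W

# Latent space: 4 channels at 1/8 resolution in each dimension
latent_values = 4 * (H // 8) * (W // 8)

print(pixel_values // latent_values)  # 48x fewer values per image
```

The UNet therefore denoises 4 x 128 x 128 latents instead of 3 x 1024 x 1024 pixels.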

# Noise Training

Diffusion training consists of predicting the noise added to images.

```python
noise = torch.randn_like(latents)
noisy_latents = scheduler.add_noise(latents, noise, timesteps)
```

The model learns to predict this noise.

Loss function:

```python
loss = F.mse_loss(noise_pred.float(), noise.float())
```

This is the standard diffusion training loss.

------------------------------------------------------------------------
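
A minimal sketch of this objective, with a hand-rolled noising step and a stand-in prediction instead of the UNet:

```python
import torch
import torch.nn.functional as F

latents = torch.randn(2, 4, 16, 16)  # toy latents (real ones are 4 x 128 x 128)
noise = torch.randn_like(latents)

alpha_bar = 0.7  # assumed schedule coefficient at some timestep t
noisy_latents = (alpha_bar ** 0.5) * latents + ((1 - alpha_bar) ** 0.5) * noise

# A model that recovered the injected noise exactly would reach zero loss;
# training pushes the UNet's prediction toward this target.
noise_pred = noise.clone()
loss = F.mse_loss(noise_pred.float(), noise.float())
print(loss.item())  # 0.0
```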

# LoRA Configuration

LoRA modifies the attention layers of the UNet.

```python
LoraConfig(
    r=8,
    lora_alpha=16,
    target_modules=["to_q", "to_k", "to_v", "to_out.0"],
)
```

Key parameters:

| Parameter        | Description               |
|------------------|---------------------------|
| `r`              | rank of adaptation        |
| `lora_alpha`     | scaling factor            |
| `target_modules` | attention layers to adapt |

------------------------------------------------------------------------

# Training Loop

Main training steps:

1. Encode the image into latent space
2. Add noise
3. Encode the text prompt
4. Predict the noise with the UNet
5. Compute the loss
6. Backpropagate

The loop runs for multiple epochs:

```python
for epoch in range(EPOCHS):
    for step, (images, captions) in enumerate(dataloader):
        ...
```

------------------------------------------------------------------------

# Saving the LoRA Model

After training:

```python
unet.save_pretrained("sdxl_lora")
```

The LoRA weights can later be merged into the base model.

------------------------------------------------------------------------

# Image Generation Pipeline

The generation script loads:

- the SDXL base model
- the trained LoRA
- an optimized scheduler

Key configuration:

```python
pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0"
)
```

Scheduler optimization:

```python
pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
```

This significantly accelerates generation.

------------------------------------------------------------------------

# Memory Optimization

To support Mac M-series GPUs or limited VRAM:

```python
pipe.enable_attention_slicing()
pipe.enable_vae_slicing()
```

These techniques reduce peak memory usage.

------------------------------------------------------------------------

# Prompt Caching

Generated images are cached using a hash of the prompt:

```python
hashlib.sha256(prompt.encode()).hexdigest()
```

This prevents regenerating identical images repeatedly.

------------------------------------------------------------------------
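
Because the hash is deterministic, the same prompt always maps to the same cache file. A small standalone sketch mirroring the `_hash_prompt()` helper from `generate_image3.py`:

```python
import hashlib

def cache_key(prompt: str) -> str:
    # First 16 hex chars of the SHA-256 digest, as in generate_image3.py
    return hashlib.sha256(prompt.encode()).hexdigest()[:16]

a = cache_key("enterprise cloud architecture diagram")
b = cache_key("enterprise cloud architecture diagram")
c = cache_key("cloud integration architecture with api gateway")

print(a == b)  # True: identical prompts hit the same cached file
print(a == c)  # False: different prompts generate fresh images
```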

# Image Generation

Image generation call:

```python
image = pipe(
    prompt,
    negative_prompt=NEGATIVE,
    num_inference_steps=25,
    guidance_scale=7,
    height=1024,
    width=1024,
).images[0]
```

Important parameters:

| Parameter             | Meaning              |
|-----------------------|----------------------|
| `num_inference_steps` | diffusion iterations |
| `guidance_scale`      | prompt strength      |
| `height` / `width`    | image resolution     |

------------------------------------------------------------------------

# Deployment

## Install Dependencies

```shell
pip install torch diffusers transformers peft accelerate pillow torchvision
```

------------------------------------------------------------------------

# Training

Run:

```shell
python train_diffusion.py
```

Training will produce a `sdxl_lora/` directory containing the LoRA weights.

------------------------------------------------------------------------

# Running the Generator

```shell
python generate_image3.py
```

Example prompt:

    clean enterprise illustration,
    corporate presentation slide,
    minimal design,
    white background,
    oracle integration architecture

Generated images are saved in `images/cache/`.

------------------------------------------------------------------------

# Testing

You can test generation with different prompts, for example:

    enterprise microservices architecture diagram

or

    cloud integration architecture with api gateway

------------------------------------------------------------------------

# Conclusion

Diffusion models are revolutionizing image generation by allowing AI systems to synthesize visual content from text descriptions.

Using Stable Diffusion XL combined with LoRA fine-tuning enables:

- efficient training
- domain specialization
- enterprise use cases
- automated content generation

In practical systems, diffusion models can be integrated into larger pipelines such as:

- automated presentation builders
- documentation systems
- AI-generated diagrams
- marketing automation platforms

With efficient techniques like LoRA, high-quality image generation is now accessible even on consumer hardware such as:

- RTX GPUs
- Apple Silicon
- workstation GPUs

As diffusion architectures continue to evolve, they will increasingly become a core component of AI-driven content generation systems.

# Acknowledgments

- **Author** - Cristiano Hoshikawa (Oracle LAD A-Team Solution Engineer)

------------------------------------------------------------------------

generate_image3.py (new file)

```python
import hashlib
from pathlib import Path

import torch
from diffusers import StableDiffusionXLPipeline
from peft import PeftModel

CACHE_DIR = Path("images/cache")
CACHE_DIR.mkdir(parents=True, exist_ok=True)

# Load the SDXL base model and merge the trained LoRA weights into the UNet.
pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=torch.float32,
).to("cuda")
pipe.vae.to(torch.float32)
pipe.unet = PeftModel.from_pretrained(pipe.unet, "sdxl_lora")
pipe.unet = pipe.unet.merge_and_unload()
pipe.unet.eval()

NEGATIVE = """
text,
watermark,
people,
photo,
photorealistic,
complex background
"""


def _hash_prompt(prompt: str) -> str:
    return hashlib.sha256(prompt.encode()).hexdigest()[:16]


def generate_slide_image(prompt: str) -> str:
    h = _hash_prompt(prompt)
    image_path = CACHE_DIR / f"{h}.png"

    if image_path.exists():
        print("Image cache hit")
        return str(image_path)

    print("Generating image...")

    image = pipe(
        prompt,
        negative_prompt=NEGATIVE,
        num_inference_steps=30,
        guidance_scale=7,
        height=1024,
        width=1024,
    ).images[0]

    image.save(image_path)

    return str(image_path)


if __name__ == "__main__":
    prompt = """
    clean enterprise illustration,
    corporate presentation slide,
    minimal design,
    white background,
    blue gradient accents,
    expose oracle integration connected to autonomous database and exposing through api gateway
    """
    generate_slide_image(prompt)
```

train_diffusion.py (new file)

```python
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
from pathlib import Path

from diffusers import StableDiffusionXLPipeline, DDPMScheduler
from peft import LoraConfig, get_peft_model

# =========================
# CONFIG
# =========================
IMAGES_DIR = Path("dataset/images")
CAPTIONS_DIR = Path("dataset/captions")
OUTPUT_DIR = Path("sdxl_lora")

IMAGE_SIZE = 1024
BATCH_SIZE = 2
EPOCHS = 10
LR = 1e-4
DEVICE = "cuda"

# =========================
# DATASET
# =========================
class ImageCaptionDataset(Dataset):
    def __init__(self):
        self.images = sorted(
            p for p in IMAGES_DIR.glob("*")
            if p.suffix.lower() in [".png", ".jpg", ".jpeg", ".webp"]
        )
        self.transform = transforms.Compose([
            transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
            transforms.ToTensor(),
            # RGB: normalize the 3 channels to [-1, 1]
            transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
        ])

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img_path = self.images[idx]
        caption_path = CAPTIONS_DIR / (img_path.stem + ".txt")
        if not caption_path.exists():
            raise FileNotFoundError(f"Caption not found: {caption_path}")

        image = Image.open(img_path).convert("RGB")
        image = self.transform(image)
        caption = caption_path.read_text(encoding="utf-8").strip()
        return image, caption


dataset = ImageCaptionDataset()
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2, pin_memory=True)

# =========================
# LOAD SDXL
# =========================
pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=torch.float32,
    use_safetensors=True,
).to(DEVICE)

# The VAE needs FP32 for numerical stability
pipe.vae.to(dtype=torch.float32)

# UNet in FP16
pipe.unet.to(dtype=torch.float16)

# Text encoders in FP16
pipe.text_encoder.to(dtype=torch.float16)
pipe.text_encoder_2.to(dtype=torch.float16)

# pipe.enable_xformers_memory_efficient_attention()

tokenizer_1 = pipe.tokenizer
tokenizer_2 = pipe.tokenizer_2
text_encoder_1 = pipe.text_encoder
text_encoder_2 = pipe.text_encoder_2
vae = pipe.vae
unet = pipe.unet

scheduler = DDPMScheduler.from_config(pipe.scheduler.config)

# Freeze the VAE and text encoders (only the LoRA on the UNet is trained)
vae.eval()
vae.requires_grad_(False)
text_encoder_1.eval()
text_encoder_1.requires_grad_(False)
text_encoder_2.eval()
text_encoder_2.requires_grad_(False)

# =========================
# LoRA CONFIG (UNet)
# =========================
lora_config = LoraConfig(
    r=8,
    lora_alpha=16,
    target_modules=["to_q", "to_k", "to_v", "to_out.0"],
    lora_dropout=0.1,
    bias="none",
)

unet = get_peft_model(unet, lora_config)
unet.train()

optimizer = torch.optim.AdamW(unet.parameters(), lr=LR)

# =========================
# SDXL conditioning helpers
# =========================
def encode_prompt_sdxl(captions):
    # encoder 1
    inputs_1 = tokenizer_1(
        captions,
        padding="max_length",
        truncation=True,
        max_length=tokenizer_1.model_max_length,
        return_tensors="pt",
    )
    # encoder 2
    inputs_2 = tokenizer_2(
        captions,
        padding="max_length",
        truncation=True,
        max_length=tokenizer_2.model_max_length,
        return_tensors="pt",
    )

    input_ids_1 = inputs_1.input_ids.to(DEVICE)
    input_ids_2 = inputs_2.input_ids.to(DEVICE)

    with torch.no_grad():
        out_1 = text_encoder_1(input_ids_1, output_hidden_states=True)
        out_2 = text_encoder_2(input_ids_2, output_hidden_states=True)

    # SDXL uses the penultimate hidden layer as token embeddings
    prompt_embeds_1 = out_1.hidden_states[-2]
    prompt_embeds_2 = out_2.hidden_states[-2]

    # concatenate the embeddings of both encoders along the feature axis
    prompt_embeds = torch.cat([prompt_embeds_1, prompt_embeds_2], dim=-1)

    # pooled embedding: normally taken from text_encoder_2
    pooled_prompt_embeds = out_2.hidden_states[-1][:, 0]

    return prompt_embeds, pooled_prompt_embeds


def make_time_ids(batch_size):
    # SDXL micro-conditioning:
    # [orig_h, orig_w, crop_y, crop_x, target_h, target_w]
    return torch.tensor(
        [[IMAGE_SIZE, IMAGE_SIZE, 0, 0, IMAGE_SIZE, IMAGE_SIZE]],
        device=DEVICE,
        dtype=torch.float16,
    ).repeat(batch_size, 1)


# =========================
# TRAIN
# =========================
for epoch in range(EPOCHS):
    for step, (images, captions) in enumerate(dataloader):
        images = images.to(DEVICE, dtype=torch.float32)

        # VAE encode -> latents (frozen VAE, FP32)
        with torch.no_grad():
            latents = vae.encode(images).latent_dist.sample()
            latents = latents * vae.config.scaling_factor
            latents = latents.to(torch.float16)

        # noise + timestep
        noise = torch.randn_like(latents)
        timesteps = torch.randint(
            0,
            scheduler.config.num_train_timesteps,
            (latents.shape[0],),
            device=DEVICE,
        ).long()

        noisy_latents = scheduler.add_noise(latents, noise, timesteps)

        # SDXL conditioning
        prompt_embeds, pooled_prompt_embeds = encode_prompt_sdxl(captions)
        time_ids = make_time_ids(latents.shape[0])

        # UNet noise prediction
        noise_pred = unet(
            noisy_latents,
            timesteps,
            encoder_hidden_states=prompt_embeds,
            added_cond_kwargs={
                "text_embeds": pooled_prompt_embeds,
                "time_ids": time_ids,
            },
        ).sample

        loss = F.mse_loss(noise_pred.float(), noise.float())

        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()

        if step % 10 == 0:
            print(f"Epoch {epoch} Step {step} Loss {loss.item():.4f}")

# =========================
# SAVE LoRA
# =========================
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
unet.save_pretrained(OUTPUT_DIR)

print("LoRA saved.")
```