commit 350e2f9a300bf9039f7a057984cdc9539aa13842 Author: hoshikawa2 Date: Sat Mar 7 07:13:48 2026 -0300 first commit diff --git a/README.md b/README.md new file mode 100644 index 0000000..35bdc4e --- /dev/null +++ b/README.md @@ -0,0 +1,441 @@ +# Training Image Generation Models with Diffusion (Stable Diffusion XL + LoRA) + +## Introduction + +Image generation using diffusion models has become one of the most +transformative capabilities of modern artificial intelligence. These +models are capable of generating high‑quality images from natural +language descriptions, enabling applications across multiple industries. + +Real-world use cases include: + +- **Marketing and advertising** -- generating visual assets + automatically. +- **Software documentation and presentations** -- producing diagrams + and illustrations for technical content. +- **Game development** -- generating textures, characters, and + environments. +- **Product design** -- visualizing concepts before prototypes exist. +- **Enterprise automation** -- generating architecture diagrams or + slide illustrations automatically. + +In enterprise environments, diffusion models can be integrated into +automated pipelines. For example, a presentation generator can +automatically produce slide images that represent architectural concepts +such as: + +- API Gateways +- Databases +- Integration architectures +- Cloud infrastructures + +Instead of manually creating diagrams, an AI pipeline can generate them +dynamically from structured prompts. + +This tutorial explains how to train and use a **Stable Diffusion XL +model with LoRA fine‑tuning**, and how to deploy the model to generate +images programmatically. + +------------------------------------------------------------------------ + +# Technologies Involved + +## Diffusion Models + +Diffusion models generate images by **iteratively denoising random +noise**. The training process teaches a neural network how to reverse a +noise process applied to real images. 
+ +The process works as follows: + +1. Start with a real image +2. Gradually add noise until the image becomes pure noise +3. Train a model to reverse this process +4. During inference, start with noise and iteratively remove it + +This allows the model to synthesize new images from text descriptions. + +Popular diffusion models include: + +- Stable Diffusion +- Stable Diffusion XL (SDXL) +- DALL‑E +- Imagen + +In this tutorial we use **Stable Diffusion XL**, which provides: + +- Higher resolution +- Better text understanding +- Dual text encoders +- Micro‑conditioning + +------------------------------------------------------------------------ + +## Stable Diffusion XL (SDXL) + +SDXL is an advanced diffusion architecture that improves generation +quality through: + +- **Two text encoders** +- **Improved conditioning** +- **Higher resolution generation** +- **Better prompt interpretation** + +Unlike earlier diffusion models, SDXL requires: + +- two tokenizers +- two text encoders +- pooled embeddings +- time conditioning parameters + +------------------------------------------------------------------------ + +## LoRA (Low Rank Adaptation) + +Training diffusion models from scratch is extremely expensive. + +Instead, **LoRA** allows fine‑tuning large models efficiently by +training small low‑rank matrices that modify the attention layers of the +network. + +Advantages: + +- Very small training footprint +- Works with limited VRAM +- Easy to merge into the base model +- Fast training + +In this project, LoRA is applied to the **UNet attention layers**. + +------------------------------------------------------------------------ + +## HuggingFace Diffusers + +The **Diffusers library** provides a high‑level API for working with +diffusion models. 
+ +It includes: + +- pipelines +- schedulers +- training utilities +- optimization helpers + +Main components used: + +- `StableDiffusionXLPipeline` +- `DDPMScheduler` +- `DPMSolverMultistepScheduler` + +------------------------------------------------------------------------ + +## PyTorch + +PyTorch is used for: + +- training loops +- GPU acceleration +- tensor operations +- neural network execution + +------------------------------------------------------------------------ + +# Code Walkthrough + +## Dataset Structure + +The training script expects a dataset structured as: + + dataset/ + images/ + image1.png + image2.png + captions/ + image1.txt + image2.txt + +Each caption describes the image. + +Example caption: + + enterprise cloud architecture diagram with API gateway and database + +------------------------------------------------------------------------ + +# Dataset Loader + +The dataset loader reads images and their captions. + +Key operations: + +- resizing images +- converting to tensors +- normalization + +Important section: + +``` python +transforms.Compose([ + transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)), + transforms.ToTensor(), + transforms.Normalize([0.5,0.5,0.5],[0.5,0.5,0.5]) +]) +``` + +Normalization is important because diffusion models expect images in a +**\[-1,1\] range**. + +------------------------------------------------------------------------ + +# Prompt Encoding + +SDXL uses **two text encoders**. + +The function: + + encode_prompt_sdxl() + +performs: + +1. tokenization of captions +2. embedding generation +3. concatenation of embeddings + +Important concept: + + prompt_embeds = torch.cat([prompt_embeds_1, prompt_embeds_2], dim=-1) + +This merges both encoders into a single conditioning representation. + +------------------------------------------------------------------------ + +# Latent Encoding + +Images are encoded into latent space using the **VAE**. 
+ + latents = vae.encode(images).latent_dist.sample() + +The VAE compresses images before training the diffusion process. + +This drastically reduces memory usage. + +------------------------------------------------------------------------ + +# Noise Training + +Diffusion training consists of predicting noise added to images. + + noise = torch.randn_like(latents) + noisy_latents = scheduler.add_noise(latents, noise, timesteps) + +The model learns to predict this noise. + +Loss function: + + loss = F.mse_loss(noise_pred.float(), noise.float()) + +This is the standard diffusion training loss. + +------------------------------------------------------------------------ + +# LoRA Configuration + +LoRA modifies attention layers of the UNet. + + LoraConfig( + r=8, + lora_alpha=16, + target_modules=["to_q","to_k","to_v","to_out.0"] + ) + +Key parameters: + + Parameter Description + ---------------- --------------------------- + r rank of adaptation + alpha scaling factor + target_modules attention layers to adapt + +------------------------------------------------------------------------ + +# Training Loop + +Main training steps: + +1. Encode image into latent space +2. Add noise +3. Encode text prompt +4. Predict noise with UNet +5. Compute loss +6. Backpropagate + +The loop runs for multiple epochs: + + for epoch in range(EPOCHS): + for step,(images,captions) in enumerate(dataloader): + +------------------------------------------------------------------------ + +# Saving the LoRA Model + +After training: + + unet.save_pretrained("sdxl_lora") + +The LoRA weights can later be merged into the base model. 
+ +------------------------------------------------------------------------ + +# Image Generation Pipeline + +The generation script loads: + +- SDXL base model +- trained LoRA +- optimized scheduler + +Key configuration: + + pipe = StableDiffusionXLPipeline.from_pretrained( + "stabilityai/stable-diffusion-xl-base-1.0" + ) + +Scheduler optimization: + + pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config) + +This significantly accelerates generation. + +------------------------------------------------------------------------ + +# Memory Optimization + +To support Mac M‑series GPUs or limited VRAM: + + pipe.enable_attention_slicing() + pipe.enable_vae_slicing() + +These techniques reduce peak memory usage. + +------------------------------------------------------------------------ + +# Prompt Caching + +Images are cached using a hash of the prompt. + + hashlib.sha256(prompt.encode()).hexdigest() + +This prevents regenerating identical images repeatedly. + +------------------------------------------------------------------------ + +# Image Generation + +Image generation call: + + image = pipe( + prompt, + negative_prompt=NEGATIVE, + num_inference_steps=25, + guidance_scale=7, + height=1024, + width=1024 + ).images[0] + +Important parameters: + + Parameter Meaning + --------------------- ---------------------- + num_inference_steps diffusion iterations + guidance_scale prompt strength + height / width image resolution + +------------------------------------------------------------------------ + +# Deployment + +## Install Dependencies + + pip install torch diffusers transformers peft accelerate pillow torchvision + +------------------------------------------------------------------------ + +# Training + +Run: + + python train_diffusion.py + +Training will produce: + + sdxl_lora/ + +containing LoRA weights. 
+ +------------------------------------------------------------------------ + +# Running the Generator + + python generate_image3.py + +Example prompt: + + clean enterprise illustration, + corporate presentation slide, + minimal design, + white background, + oracle integration architecture + +Generated images are saved in: + + images/cache/ + +------------------------------------------------------------------------ + +# Testing + +You can test generation with different prompts: + +Example: + + enterprise microservices architecture diagram + +or + + cloud integration architecture with api gateway + +------------------------------------------------------------------------ + +# Conclusion + +Diffusion models are revolutionizing image generation by allowing AI +systems to synthesize visual content from text descriptions. + +Using Stable Diffusion XL combined with LoRA fine‑tuning enables: + +- efficient training +- domain specialization +- enterprise use cases +- automated content generation + +In practical systems, diffusion models can be integrated into larger +pipelines such as: + +- automated presentation builders +- documentation systems +- AI‑generated diagrams +- marketing automation platforms + +With efficient techniques like LoRA, high‑quality image generation is +now accessible even on consumer hardware such as: + +- RTX GPUs +- Apple Silicon +- workstation GPUs + +As diffusion architectures continue evolving, they will increasingly +become a core component of AI‑driven content generation systems. 
import hashlib
from pathlib import Path

import torch
from diffusers import StableDiffusionXLPipeline
from peft import PeftModel

# NOTE: the original imported `negative_prompt` from `project.generate_image`,
# a module that does not exist in this repository; the name was never used,
# and the import crashed the script at startup. Removed.

# Generated images are cached here, keyed by a hash of the prompt.
CACHE_DIR = Path("images/cache")
CACHE_DIR.mkdir(parents=True, exist_ok=True)

# Pick the best available device instead of hard-coding "cuda", so the
# script also runs on Apple Silicon (MPS) and CPU-only machines, as the
# project README advertises.
if torch.cuda.is_available():
    DEVICE = "cuda"
elif torch.backends.mps.is_available():
    DEVICE = "mps"
else:
    DEVICE = "cpu"

# Load the SDXL base model, then merge the trained LoRA adapter into the
# UNet so inference runs without the PEFT wrapper overhead.
# (`torch_dtype=` is the portable kwarg; `dtype=` only exists in newer
# diffusers releases.)
pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=torch.float32,
).to(DEVICE)
pipe.vae.to(torch.float32)
pipe.unet = PeftModel.from_pretrained(pipe.unet, "sdxl_lora")
pipe.unet = pipe.unet.merge_and_unload()
pipe.unet.eval()

# Reduce peak memory usage (helps on MPS / limited-VRAM GPUs).
pipe.enable_attention_slicing()
pipe.enable_vae_slicing()

NEGATIVE = """
text,
watermark,
people,
photo,
photorealistic,
complex background
"""


def _hash_prompt(prompt: str) -> str:
    """Return a short, stable cache key derived from the prompt text."""
    return hashlib.sha256(prompt.encode()).hexdigest()[:16]


def generate_slide_image(prompt: str) -> str:
    """Generate (or fetch from cache) a 1024x1024 slide image for *prompt*.

    Identical prompts hash to the same file name, so repeated calls reuse
    the cached PNG instead of re-running diffusion.

    Returns:
        The path of the PNG file, as a string.
    """
    h = _hash_prompt(prompt)
    image_path = CACHE_DIR / f"{h}.png"

    if image_path.exists():
        print("Image cache hit")
        return str(image_path)

    print("Generating image...")

    image = pipe(
        prompt,
        negative_prompt=NEGATIVE,
        num_inference_steps=30,
        guidance_scale=7,
        height=1024,
        width=1024,
    ).images[0]

    image.save(image_path)

    return str(image_path)


if __name__ == "__main__":
    prompt = """
    clean enterprise illustration,
    corporate presentation slide,
    minimal design,
    white background,
    blue gradient accents,
    expose oracle integration connected to autonomous database and exposing through api gateway
    """
    generate_slide_image(prompt)
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
from pathlib import Path

from diffusers import StableDiffusionXLPipeline, DDPMScheduler
from peft import LoraConfig, get_peft_model
import torch.nn.functional as F
from peft import get_peft_model_state_dict
from diffusers.utils import convert_state_dict_to_diffusers

# =========================
# CONFIG
# =========================
IMAGES_DIR = Path("dataset/images")
CAPTIONS_DIR = Path("dataset/captions")
OUTPUT_DIR = Path("sdxl_lora")

IMAGE_SIZE = 1024
BATCH_SIZE = 2
EPOCHS = 10
LR = 1e-4
DEVICE = "cuda"

# =========================
# DATASET
# =========================
class ImageCaptionDataset(Dataset):
    """Pairs every image in *images_dir* with a same-stem ``.txt`` caption.

    Images are resized to ``IMAGE_SIZE`` x ``IMAGE_SIZE`` and normalized to
    the [-1, 1] range the SDXL VAE expects.
    """

    def __init__(self, images_dir: Path = IMAGES_DIR, captions_dir: Path = CAPTIONS_DIR):
        # Directories are parameters (with the original globals as defaults)
        # so the dataset can be reused against other folders.
        self.captions_dir = captions_dir
        self.images = sorted(
            p for p in images_dir.glob("*")
            if p.suffix.lower() in {".png", ".jpg", ".jpeg", ".webp"}
        )
        self.transform = transforms.Compose([
            transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
            transforms.ToTensor(),
            # Normalize all 3 RGB channels from [0, 1] to [-1, 1].
            transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
        ])

    def __len__(self) -> int:
        return len(self.images)

    def __getitem__(self, idx):
        img_path = self.images[idx]
        caption_path = self.captions_dir / (img_path.stem + ".txt")
        if not caption_path.exists():
            raise FileNotFoundError(f"Caption não encontrada: {caption_path}")

        # Context manager closes the image file handle promptly; the
        # original left it open until garbage collection.
        with Image.open(img_path) as img:
            image = self.transform(img.convert("RGB"))
        caption = caption_path.read_text(encoding="utf-8").strip()
        return image, caption

dataset = ImageCaptionDataset()
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2, pin_memory=True)

# =========================
# LOAD SDXL
# =========================
pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=torch.float32,
    use_safetensors=True,
).to(DEVICE)
# The VAE must stay in FP32 (FP16 SDXL VAEs are numerically unstable).
pipe.vae.to(dtype=torch.float32)

# UNet is cast to FP16 below.
# Cast the UNet and both text encoders to FP16 to cut VRAM usage; the
# VAE was cast to (and kept in) FP32 above.
pipe.unet.to(dtype=torch.float16)
pipe.text_encoder.to(dtype=torch.float16)
pipe.text_encoder_2.to(dtype=torch.float16)

# pipe.enable_xformers_memory_efficient_attention()

tokenizer_1 = pipe.tokenizer
tokenizer_2 = pipe.tokenizer_2
text_encoder_1 = pipe.text_encoder
text_encoder_2 = pipe.text_encoder_2
vae = pipe.vae
unet = pipe.unet

scheduler = DDPMScheduler.from_config(pipe.scheduler.config)

# Freeze VAE and both text encoders -- only the LoRA adapters on the
# UNet are trained.
vae.eval()
vae.requires_grad_(False)
text_encoder_1.eval()
text_encoder_1.requires_grad_(False)
text_encoder_2.eval()
text_encoder_2.requires_grad_(False)

# =========================
# LoRA CONFIG (UNet)
# =========================
lora_config = LoraConfig(
    r=8,                  # rank of the low-rank update matrices
    lora_alpha=16,        # scaling factor (effective scale = alpha / r)
    target_modules=["to_q", "to_k", "to_v", "to_out.0"],  # attention projections
    lora_dropout=0.1,
    bias="none",
)

unet = get_peft_model(unet, lora_config)
unet.train()

# Only the LoRA parameters require grad; handing just those to the
# optimizer avoids iterating the frozen base weights every step.
optimizer = torch.optim.AdamW(
    (p for p in unet.parameters() if p.requires_grad), lr=LR
)
# NOTE(review): the UNet trains in pure FP16 without a GradScaler; if the
# loss goes NaN, wrap the forward/backward in torch.cuda.amp autocast +
# GradScaler -- confirm on the target hardware.

# =========================
# Helpers SDXL conditioning
# =========================
def encode_prompt_sdxl(captions):
    """Encode a batch of captions with both SDXL text encoders.

    Returns:
        prompt_embeds: token-level embeddings -- the penultimate hidden
            states of both encoders concatenated along the feature axis.
        pooled_prompt_embeds: pooled/projected embedding from text
            encoder 2, used for SDXL micro-conditioning.
    """
    inputs_1 = tokenizer_1(
        captions,
        padding="max_length",
        truncation=True,
        max_length=tokenizer_1.model_max_length,
        return_tensors="pt",
    )
    inputs_2 = tokenizer_2(
        captions,
        padding="max_length",
        truncation=True,
        max_length=tokenizer_2.model_max_length,
        return_tensors="pt",
    )

    input_ids_1 = inputs_1.input_ids.to(DEVICE)
    input_ids_2 = inputs_2.input_ids.to(DEVICE)

    with torch.no_grad():
        out_1 = text_encoder_1(input_ids_1, output_hidden_states=True)
        out_2 = text_encoder_2(input_ids_2, output_hidden_states=True)

    # SDXL conditions on the PENULTIMATE hidden layer of each encoder.
    prompt_embeds_1 = out_1.hidden_states[-2]
    prompt_embeds_2 = out_2.hidden_states[-2]

    # Concatenate the two encoders along the feature dimension.
    prompt_embeds = torch.cat([prompt_embeds_1, prompt_embeds_2], dim=-1)

    # BUGFIX: the pooled embedding is text_encoder_2's projected output
    # (out_2[0], i.e. `text_embeds`), matching diffusers' own SDXL
    # encode_prompt. The old code took the first (BOS) token of the last
    # hidden state, which is not what SDXL was pretrained with.
    pooled_prompt_embeds = out_2[0]

    return prompt_embeds, pooled_prompt_embeds

def make_time_ids(batch_size):
    """Build SDXL micro-conditioning ids:
    [orig_h, orig_w, crop_top, crop_left, target_h, target_w]."""
    return torch.tensor(
        [[IMAGE_SIZE, IMAGE_SIZE, 0, 0, IMAGE_SIZE, IMAGE_SIZE]],
        device=DEVICE,
        dtype=torch.float16,
    ).repeat(batch_size, 1)

# =========================
# TRAIN
# =========================
for epoch in range(EPOCHS):
    for step, (images, captions) in enumerate(dataloader):
        # The VAE was frozen and cast to FP32 once at setup; the per-step
        # re-freeze/re-cast the original did here was redundant.
        images = images.to(DEVICE, dtype=torch.float32)

        # VAE encode -> scaled latents, then cast to FP16 for the UNet.
        with torch.no_grad():
            latents = vae.encode(images).latent_dist.sample()
            latents = latents * vae.config.scaling_factor
            latents = latents.to(torch.float16)

        # Sample noise and a random diffusion timestep per sample.
        noise = torch.randn_like(latents)
        timesteps = torch.randint(
            0,
            scheduler.config.num_train_timesteps,
            (latents.shape[0],),
            device=DEVICE,
        ).long()

        noisy_latents = scheduler.add_noise(latents, noise, timesteps)

        # SDXL conditioning
        prompt_embeds, pooled_prompt_embeds = encode_prompt_sdxl(captions)
        time_ids = make_time_ids(latents.shape[0])

        # UNet predicts the added noise (epsilon-prediction objective).
        noise_pred = unet(
            noisy_latents,
            timesteps,
            encoder_hidden_states=prompt_embeds,
            added_cond_kwargs={
                "text_embeds": pooled_prompt_embeds,
                "time_ids": time_ids,
            },
        ).sample

        loss = F.mse_loss(noise_pred.float(), noise.float())

        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()

        if step % 10 == 0:
            print(f"Epoch {epoch} Step {step} Loss {loss.item():.4f}")

# =========================
# SAVE LoRA
# =========================
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

# Use the configured OUTPUT_DIR (same path) instead of repeating the
# "sdxl_lora" literal.
unet.save_pretrained(OUTPUT_DIR)

print("LoRA saved correctly.")