Files
image_lora_training/generate_image3.py
2026-03-08 18:14:58 -03:00

86 lines
1.9 KiB
Python

import hashlib
from pathlib import Path
import torch
from diffusers import StableDiffusionXLPipeline
from diffusers import DPMSolverMultistepScheduler
from peft import PeftModel
# Device string handed to the pipeline's ``.to(...)`` call.
DEVICE = "cuda"

# Directory holding generated images, keyed by prompt hash.
# It is created (with parents) as soon as this module is imported.
CACHE_DIR = Path("images/cache")
CACHE_DIR.mkdir(exist_ok=True, parents=True)
# Load the SDXL base checkpoint in float32 and move it to DEVICE.
pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=torch.float32,
    use_safetensors=True,
).to(DEVICE)

# Faster scheduler: rebuild it from the config of the scheduler that shipped
# with the checkpoint.
pipe.scheduler = DPMSolverMultistepScheduler.from_config(
    pipe.scheduler.config
)

# Memory optimizations.
pipe.enable_attention_slicing()
pipe.enable_vae_slicing()

# The VAE needs FP32.
# NOTE(review): the pipeline is already loaded with torch_dtype=torch.float32,
# so this cast looks redundant; kept in case the load dtype is ever changed.
pipe.vae.to(torch.float32)

# Load the LoRA adapter from "sdxl_lora" on top of the UNet, then merge it
# into the base weights and drop the PEFT wrapper.
pipe.unet = PeftModel.from_pretrained(pipe.unet, "sdxl_lora").merge_and_unload()
pipe.unet.eval()
# Negative prompt passed to every generation call. The value keeps the
# leading/trailing newline and the one-term-per-line layout of the original
# triple-quoted literal.
NEGATIVE = (
    "\n"
    "text,\n"
    "watermark,\n"
    "people,\n"
    "photo,\n"
    "photorealistic,\n"
    "complex background\n"
)
def _hash_prompt(prompt: str) -> str:
return hashlib.sha256(prompt.encode()).hexdigest()[:16]
def generate_slide_image(prompt: str) -> str:
    """Return the path of a cached or freshly generated image for *prompt*.

    The file name is ``_hash_prompt(prompt)`` plus ``.png`` under
    ``CACHE_DIR``. If that file already exists it is returned without running
    the pipeline. Otherwise the module-level ``pipe`` is called once with the
    ``NEGATIVE`` negative prompt and the first returned image is stored in
    the cache.

    NOTE: the cache key covers only the prompt text. Changing ``NEGATIVE`` or
    the sampling settings below does not invalidate existing cache entries.

    Args:
        prompt: Text prompt passed to the pipeline; also the cache key.

    Returns:
        The image path as a string.
    """
    h = _hash_prompt(prompt)
    image_path = CACHE_DIR / f"{h}.png"

    if image_path.exists():
        print("Image cache hit")
        return str(image_path)

    print("Generating image...")

    with torch.no_grad():
        image = pipe(
            prompt,
            negative_prompt=NEGATIVE,
            num_inference_steps=25,
            guidance_scale=7,
            height=1024,
            width=1024,
        ).images[0]

    # Save to a sibling temp file and rename it into place. Writing directly
    # to image_path would leave a truncated PNG behind if the save were
    # interrupted, and the exists() check above would then serve that broken
    # file as a cache hit on every later call.
    tmp_path = image_path.with_name(f"{image_path.name}.tmp")
    try:
        # The temp name does not end in ".png", so the format is given
        # explicitly instead of being inferred from the extension.
        image.save(tmp_path, format="PNG")
        tmp_path.replace(image_path)
    finally:
        # No-op after a successful rename; removes the partial file otherwise.
        tmp_path.unlink(missing_ok=True)

    return str(image_path)
if __name__ == "__main__":
    # Manual smoke run: generate (or fetch from cache) one sample slide image.
    # The value is identical to the original triple-quoted literal, newlines
    # included, so it maps to the same cache entry.
    demo_prompt = (
        "\n"
        "clean enterprise illustration,\n"
        "corporate presentation slide,\n"
        "minimal design,\n"
        "white background,\n"
        "blue gradient accents,\n"
        "oracle integration connected to autonomous database through api gateway\n"
    )
    generate_slide_image(demo_prompt)