Files
image_lora_training/generate_image3.py
2026-03-07 07:13:48 -03:00

69 lines
1.6 KiB
Python

import hashlib
from pathlib import Path
import torch
from diffusers import DiffusionPipeline
from diffusers import StableDiffusionXLPipeline
from peft import PeftModel
from project.generate_image import negative_prompt
# Directory holding the generated PNG files (one file per prompt hash).
# The path is relative, so it is resolved against the current working directory.
# It is created, together with any missing parents, as a side effect of
# importing this module.
CACHE_DIR = Path("images/cache")
CACHE_DIR.mkdir(parents=True, exist_ok=True)
# Build the shared SDXL pipeline once, at import time, and move it to the CUDA
# device. Importing this module therefore loads the model weights and needs a
# usable "cuda" device.
# NOTE(review): diffusers documents this keyword as `torch_dtype`; confirm that
# the installed version honors `dtype` instead of ignoring it.
pipe = StableDiffusionXLPipeline.from_pretrained(
"stabilityai/stable-diffusion-xl-base-1.0",
dtype=torch.float32
).to("cuda")
# Explicitly keep the VAE in float32.
pipe.vae.to(torch.float32)
# Wrap the UNet with the adapter loaded from "sdxl_lora" (presumably a local
# directory produced by the LoRA training step; verify), then merge the adapter
# into the base model and keep only the merged UNet on the pipeline.
pipe.unet = PeftModel.from_pretrained(pipe.unet, "sdxl_lora")
pipe.unet = pipe.unet.merge_and_unload()
# The pipeline is used for generation only, so put the UNet in evaluation mode.
pipe.unet.eval()
# Negative prompt passed to the pipeline by generate_slide_image() as
# `negative_prompt`. The literal is used exactly as written, including the
# leading/trailing newlines and the line breaks between terms.
# NOTE(review): a `negative_prompt` name is also imported from
# project.generate_image but is not used in the visible code; confirm which
# of the two is intended.
NEGATIVE = """
text,
watermark,
people,
photo,
photorealistic,
complex background
"""
def _hash_prompt(prompt: str) -> str:
    """Return a short, deterministic key for *prompt*.

    The key is the first 16 hexadecimal characters of the SHA-256 digest of
    the prompt encoded with ``str.encode()``'s default codec (UTF-8).
    """
    hasher = hashlib.sha256()
    hasher.update(prompt.encode())
    full_hex = hasher.hexdigest()
    return full_hex[:16]
def generate_slide_image(prompt: str) -> str:
    """Return the path of the slide image for *prompt*, generating it if needed.

    The cache key is derived from the prompt text alone (via ``_hash_prompt``),
    so a given prompt always maps to the same ``<hash>.png`` under
    ``CACHE_DIR``. If that file already exists it is returned without running
    the pipeline; otherwise a 1024x1024 image is generated with the
    module-level ``pipe`` and ``NEGATIVE`` prompt and stored there.

    Args:
        prompt: Positive text prompt handed to the diffusion pipeline.

    Returns:
        The path of the PNG file, as a string.
    """
    h = _hash_prompt(prompt)
    image_path = CACHE_DIR / f"{h}.png"

    if image_path.exists():
        print("Image cache hit")
        return str(image_path)

    print("Generating image...")
    image = pipe(
        prompt,
        negative_prompt=NEGATIVE,
        num_inference_steps=30,
        guidance_scale=7,
        height=1024,
        width=1024,
    ).images[0]

    # Save to a sibling temp file and rename it into place. Writing straight
    # to image_path would leave a truncated PNG behind if the save were
    # interrupted, and the exists() check above would then treat that broken
    # file as a valid cache hit forever. The temp name keeps the ".png"
    # extension so the image format is still inferred from the file name.
    tmp_path = image_path.with_name(f"{h}.tmp.png")
    try:
        image.save(tmp_path)
        tmp_path.replace(image_path)
    finally:
        # No-op after a successful rename; removes the partial file otherwise.
        tmp_path.unlink(missing_ok=True)

    return str(image_path)
if __name__ == "__main__":
    # Manual smoke run: generate (or reuse from cache) one sample slide image.
    # The literal below is passed to the pipeline exactly as written.
    sample_prompt = """
clean enterprise illustration,
corporate presentation slide,
minimal design,
white background,
blue gradient accents,
expose oracle integration connected to autonomous database and exposing through api gateway
"""
    generate_slide_image(sample_prompt)