mirror of
https://github.com/hoshikawa2/image_lora_training.git
synced 2026-03-11 17:14:57 +00:00
86 lines
1.9 KiB
Python
86 lines
1.9 KiB
Python
import hashlib
|
|
from pathlib import Path
|
|
import torch
|
|
|
|
from diffusers import StableDiffusionXLPipeline
|
|
from diffusers import DPMSolverMultistepScheduler
|
|
from peft import PeftModel
|
|
|
|
# Generated images are cached as PNG files under images/cache, resolved
# relative to the current working directory. The directory (and any missing
# parents) is created when this module is imported.
CACHE_DIR = Path("images", "cache")
CACHE_DIR.mkdir(exist_ok=True, parents=True)
|
|
|
|
# Run on the GPU when CUDA is available; otherwise fall back to the CPU so the
# pipeline's `.to(DEVICE)` does not fail on machines without CUDA (generation
# will then be much slower).
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
|
|
|
def _load_pipeline(device: str) -> StableDiffusionXLPipeline:
    """Build the SDXL base pipeline with the local LoRA merged into its UNet.

    Loads ``stabilityai/stable-diffusion-xl-base-1.0`` (safetensors, float32)
    onto *device*, swaps in a DPM-Solver multistep scheduler, enables attention
    and VAE slicing, and merges the PEFT adapter stored in ``sdxl_lora`` into
    the UNet weights.
    """
    sdxl = StableDiffusionXLPipeline.from_pretrained(
        "stabilityai/stable-diffusion-xl-base-1.0",
        torch_dtype=torch.float32,
        use_safetensors=True,
    ).to(device)

    # Faster scheduler, built from the stock scheduler's configuration.
    sdxl.scheduler = DPMSolverMultistepScheduler.from_config(sdxl.scheduler.config)

    # Memory optimizations: compute attention and VAE decoding in slices.
    sdxl.enable_attention_slicing()
    sdxl.enable_vae_slicing()

    # The VAE needs FP32. The pipeline is already loaded in float32 above, so
    # this acts as an explicit safeguard rather than a real conversion.
    sdxl.vae.to(torch.float32)

    # Load the LoRA adapter onto the UNet, fold its weights into the base
    # weights, then put the resulting plain UNet in eval mode.
    sdxl.unet = PeftModel.from_pretrained(sdxl.unet, "sdxl_lora")
    sdxl.unet = sdxl.unet.merge_and_unload()
    sdxl.unet.eval()

    return sdxl


# Module-level pipeline shared by every call to generate_slide_image.
pipe = _load_pipeline(DEVICE)
|
|
|
|
# Negative prompt passed to every generation: pushes the output away from
# text, watermarks, people, photographic styles and busy backgrounds.
# One comma-terminated term per line; no separator/artifact lines, since every
# character in this literal is sent to the text encoder.
NEGATIVE = """
text,
watermark,
people,
photo,
photorealistic,
complex background
"""
|
|
|
|
def _hash_prompt(prompt: str) -> str:
|
|
return hashlib.sha256(prompt.encode()).hexdigest()[:16]
|
|
|
|
|
|
def generate_slide_image(
    prompt: str,
    *,
    num_inference_steps: int = 25,
    guidance_scale: float = 7,
    height: int = 1024,
    width: int = 1024,
) -> str:
    """Generate an illustration for *prompt*, reusing a cached PNG when present.

    Images are cached in ``CACHE_DIR`` as ``<key>.png``. The key hashes the
    prompt together with the negative prompt and every generation setting,
    so changing any of them produces a new image instead of silently
    returning a stale one.

    Args:
        prompt: Text prompt passed to the SDXL pipeline.
        num_inference_steps: Number of denoising steps (default 25).
        guidance_scale: Guidance scale passed to the pipeline (default 7).
        height: Output image height in pixels (default 1024).
        width: Output image width in pixels (default 1024).

    Returns:
        Path of the PNG file, as a string.
    """
    # repr() of a tuple of str/int/float values is deterministic and
    # unambiguous, so it is a safe input for the hash. guidance_scale is
    # normalized to float so 7 and 7.0 share a cache entry.
    cache_key = repr(
        (prompt, NEGATIVE, num_inference_steps, float(guidance_scale), height, width)
    )
    h = _hash_prompt(cache_key)
    image_path = CACHE_DIR / f"{h}.png"

    if image_path.exists():
        print("Image cache hit")
        return str(image_path)

    print("Generating image...")

    with torch.no_grad():
        image = pipe(
            prompt,
            negative_prompt=NEGATIVE,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            height=height,
            width=width,
        ).images[0]

    # Save to a temporary file and rename it into place (os.replace semantics,
    # atomic on POSIX within one directory), so an interrupted save can never
    # leave a truncated PNG that later calls would treat as a cache hit.
    tmp_path = image_path.with_name(f"{h}.tmp")
    try:
        image.save(tmp_path, format="PNG")
        tmp_path.replace(image_path)
    finally:
        # No-op after a successful replace; removes the partial file on failure.
        tmp_path.unlink(missing_ok=True)

    return str(image_path)
|
|
|
|
|
|
if __name__ == "__main__":
    # Demo: render a single slide illustration and print where it was saved.
    prompt = """
clean enterprise illustration,
corporate presentation slide,
minimal design,
white background,
blue gradient accents,
oracle integration connected to autonomous database through api gateway
"""
    print(generate_slide_image(prompt))