# Mirror of https://github.com/hoshikawa2/image_lora_training.git
# Synced 2026-03-11 17:14:57 +00:00 (69 lines, 1.6 KiB, Python)
import hashlib
|
|
from pathlib import Path
|
|
import torch
|
|
from diffusers import DiffusionPipeline
|
|
from diffusers import StableDiffusionXLPipeline
|
|
from peft import PeftModel
|
|
|
|
|
|
from project.generate_image import negative_prompt
|
|
|
|
# Directory where generated images are cached, one PNG per prompt hash.
# The path is relative, so it resolves against the current working directory.
CACHE_DIR = Path("images/cache")
# Create the directory (and any missing parents) at import time so later
# writes do not fail on a missing folder.
CACHE_DIR.mkdir(parents=True, exist_ok=True)
|
|
|
|
# Build the SDXL pipeline once at import time and move it to the CUDA device.
# NOTE(review): diffusers documents this keyword as `torch_dtype`; confirm the
# installed version accepts `dtype`. float32 is presumably the default load
# precision, so the result should be the same either way - verify.
pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    dtype=torch.float32,
).to("cuda")

# Explicitly cast the VAE to float32.
pipe.vae.to(torch.float32)

# Wrap the UNet with the adapter stored under "sdxl_lora" (presumably LoRA
# weights, judging by the path name), then merge the adapter into the base
# model so the pipeline holds a plain UNet again.
pipe.unet = PeftModel.from_pretrained(pipe.unet, "sdxl_lora")
pipe.unet = pipe.unet.merge_and_unload()

# Switch the merged UNet to evaluation mode; this module only runs inference.
pipe.unet.eval()
|
|
|
|
# Negative prompt passed to every generation call: the concepts the model is
# asked to steer away from. The line breaks are part of the string.
NEGATIVE = """
text,
watermark,
people,
photo,
photorealistic,
complex background
"""
|
|
|
|
def _hash_prompt(prompt: str) -> str:
    """Return a short, deterministic cache key for *prompt*.

    The key is the first 16 hexadecimal characters of the SHA-256 digest of
    the UTF-8 encoded prompt, so identical text always maps to the same key.
    """
    digest = hashlib.sha256(prompt.encode("utf-8")).hexdigest()
    return digest[:16]
|
|
|
|
|
|
def generate_slide_image(prompt: str) -> str:
    """Return the path of a 1024x1024 PNG rendered for *prompt*.

    Images are cached in ``CACHE_DIR`` under a name derived from the prompt
    hash. If a file for this prompt already exists it is returned without
    running the pipeline; otherwise the image is generated, stored, and its
    path returned.

    Args:
        prompt: Text prompt passed to the diffusion pipeline.

    Returns:
        The cached image path as a string.
    """
    image_path = CACHE_DIR / f"{_hash_prompt(prompt)}.png"

    if image_path.exists():
        print("Image cache hit")
        return str(image_path)

    print("Generating image...")

    image = pipe(
        prompt,
        negative_prompt=NEGATIVE,
        num_inference_steps=30,
        guidance_scale=7,
        height=1024,
        width=1024,
    ).images[0]

    # The cache directory is created at import time, but it may have been
    # removed since; re-creating it is cheap compared with losing the render.
    image_path.parent.mkdir(parents=True, exist_ok=True)

    # Write to a temporary sibling and rename it into place, so an interrupted
    # save never leaves a truncated file that later calls would treat as a
    # cache hit. The temporary name keeps the ".png" suffix so the image
    # format can still be inferred from the file name.
    tmp_path = image_path.with_name(f"{image_path.stem}.tmp{image_path.suffix}")
    try:
        image.save(tmp_path)
        tmp_path.replace(image_path)
    finally:
        # No-op after a successful rename; removes the partial file on failure.
        tmp_path.unlink(missing_ok=True)

    return str(image_path)
|
|
|
|
if __name__ == "__main__":
    # Manual smoke run: render (or fetch from the cache) one sample slide
    # illustration. The prompt is assembled piece by piece so its exact text,
    # including the leading and trailing newline, stays explicit; the cache
    # key is derived from this text.
    prompt = (
        "\n"
        "clean enterprise illustration,\n"
        "corporate presentation slide,\n"
        "minimal design,\n"
        "white background,\n"
        "blue gradient accents,\n"
        "expose oracle integration connected to autonomous database and exposing through api gateway\n"
    )
    generate_slide_image(prompt)