mirror of
https://github.com/hoshikawa2/image_lora_training.git
synced 2026-03-11 17:14:57 +00:00
optimized
This commit is contained in:
@@ -1,21 +1,33 @@
|
||||
import hashlib
|
||||
from pathlib import Path
|
||||
import torch
|
||||
from diffusers import DiffusionPipeline
|
||||
|
||||
from diffusers import StableDiffusionXLPipeline
|
||||
from diffusers import DPMSolverMultistepScheduler
|
||||
from peft import PeftModel
|
||||
|
||||
|
||||
from project.generate_image import negative_prompt
|
||||
|
||||
CACHE_DIR = Path("images/cache")
|
||||
CACHE_DIR.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
DEVICE = "cuda"
|
||||
|
||||
pipe = StableDiffusionXLPipeline.from_pretrained(
|
||||
"stabilityai/stable-diffusion-xl-base-1.0",
|
||||
dtype=torch.float32
|
||||
).to("cuda")
|
||||
torch_dtype=torch.float32,
|
||||
use_safetensors=True
|
||||
).to(DEVICE)
|
||||
|
||||
# scheduler mais rápido
|
||||
pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
|
||||
|
||||
# otimizações de memória
|
||||
pipe.enable_attention_slicing()
|
||||
pipe.enable_vae_slicing()
|
||||
|
||||
# VAE precisa FP32
|
||||
pipe.vae.to(torch.float32)
|
||||
|
||||
# carregar LoRA
|
||||
pipe.unet = PeftModel.from_pretrained(pipe.unet, "sdxl_lora")
|
||||
pipe.unet = pipe.unet.merge_and_unload()
|
||||
pipe.unet.eval()
|
||||
@@ -44,10 +56,12 @@ def generate_slide_image(prompt: str) -> str:
|
||||
|
||||
print("Generating image...")
|
||||
|
||||
with torch.no_grad():
|
||||
|
||||
image = pipe(
|
||||
prompt,
|
||||
negative_prompt=NEGATIVE,
|
||||
num_inference_steps=30,
|
||||
num_inference_steps=25,
|
||||
guidance_scale=7,
|
||||
height=1024,
|
||||
width=1024
|
||||
@@ -57,13 +71,16 @@ def generate_slide_image(prompt: str) -> str:
|
||||
|
||||
return str(image_path)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
prompt = """
|
||||
clean enterprise illustration,
|
||||
corporate presentation slide,
|
||||
minimal design,
|
||||
white background,
|
||||
blue gradient accents,
|
||||
expose oracle integration connected to autonomous database and exposing through api gateway
|
||||
oracle integration connected to autonomous database through api gateway
|
||||
"""
|
||||
|
||||
generate_slide_image(prompt)
|
||||
Reference in New Issue
Block a user