mirror of
https://github.com/hoshikawa2/image_lora_training.git
synced 2026-03-11 17:14:57 +00:00
optimized
This commit is contained in:
@@ -1,21 +1,33 @@
|
|||||||
import hashlib
|
import hashlib
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
import torch
|
import torch
|
||||||
from diffusers import DiffusionPipeline
|
|
||||||
from diffusers import StableDiffusionXLPipeline
|
from diffusers import StableDiffusionXLPipeline
|
||||||
|
from diffusers import DPMSolverMultistepScheduler
|
||||||
from peft import PeftModel
|
from peft import PeftModel
|
||||||
|
|
||||||
|
|
||||||
from project.generate_image import negative_prompt
|
|
||||||
|
|
||||||
CACHE_DIR = Path("images/cache")
|
CACHE_DIR = Path("images/cache")
|
||||||
CACHE_DIR.mkdir(parents=True, exist_ok=True)
|
CACHE_DIR.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
DEVICE = "cuda"
|
||||||
|
|
||||||
pipe = StableDiffusionXLPipeline.from_pretrained(
|
pipe = StableDiffusionXLPipeline.from_pretrained(
|
||||||
"stabilityai/stable-diffusion-xl-base-1.0",
|
"stabilityai/stable-diffusion-xl-base-1.0",
|
||||||
dtype=torch.float32
|
torch_dtype=torch.float32,
|
||||||
).to("cuda")
|
use_safetensors=True
|
||||||
|
).to(DEVICE)
|
||||||
|
|
||||||
|
# scheduler mais rápido
|
||||||
|
pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
|
||||||
|
|
||||||
|
# otimizações de memória
|
||||||
|
pipe.enable_attention_slicing()
|
||||||
|
pipe.enable_vae_slicing()
|
||||||
|
|
||||||
|
# VAE precisa FP32
|
||||||
pipe.vae.to(torch.float32)
|
pipe.vae.to(torch.float32)
|
||||||
|
|
||||||
|
# carregar LoRA
|
||||||
pipe.unet = PeftModel.from_pretrained(pipe.unet, "sdxl_lora")
|
pipe.unet = PeftModel.from_pretrained(pipe.unet, "sdxl_lora")
|
||||||
pipe.unet = pipe.unet.merge_and_unload()
|
pipe.unet = pipe.unet.merge_and_unload()
|
||||||
pipe.unet.eval()
|
pipe.unet.eval()
|
||||||
@@ -44,26 +56,31 @@ def generate_slide_image(prompt: str) -> str:
|
|||||||
|
|
||||||
print("Generating image...")
|
print("Generating image...")
|
||||||
|
|
||||||
image = pipe(
|
with torch.no_grad():
|
||||||
prompt,
|
|
||||||
negative_prompt=NEGATIVE,
|
image = pipe(
|
||||||
num_inference_steps=30,
|
prompt,
|
||||||
guidance_scale=7,
|
negative_prompt=NEGATIVE,
|
||||||
height=1024,
|
num_inference_steps=25,
|
||||||
width=1024
|
guidance_scale=7,
|
||||||
).images[0]
|
height=1024,
|
||||||
|
width=1024
|
||||||
|
).images[0]
|
||||||
|
|
||||||
image.save(image_path)
|
image.save(image_path)
|
||||||
|
|
||||||
return str(image_path)
|
return str(image_path)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
||||||
prompt = """
|
prompt = """
|
||||||
clean enterprise illustration,
|
clean enterprise illustration,
|
||||||
corporate presentation slide,
|
corporate presentation slide,
|
||||||
minimal design,
|
minimal design,
|
||||||
white background,
|
white background,
|
||||||
blue gradient accents,
|
blue gradient accents,
|
||||||
expose oracle integration connected to autonomous database and exposing through api gateway
|
oracle integration connected to autonomous database through api gateway
|
||||||
"""
|
"""
|
||||||
|
|
||||||
generate_slide_image(prompt)
|
generate_slide_image(prompt)
|
||||||
Reference in New Issue
Block a user