오늘은 Stable Diffusion 코드를 활용하여 FastAPI를 구현했습니다.
이전에 Grounding-DINO를 이용하여 FastAPI 구현 및 실행하는 글을 작성하였습니다.
2024.12.06 - [Machine Learning] - Grounding-DINO FastAPI 구현
이번에는 Stable Diffusion에 대해서 어떻게 코드를 작성해서 사용하는 지에 대한 글을 작성하겠습니다.
Stable Diffusion은 텍스트 프롬프트를 기반으로 이미지를 생성하는 딥러닝 모델로, Latent Diffusion Model(LDM)을 활용하여 노이즈를 점진적으로 제거하는 방식으로 동작합니다. 이 모델은 GPU 메모리 사용을 최적화하여 비교적 적은 자원으로도 실행할 수 있으며, 오픈소스로 제공되어 다양한 커스텀 모델이 개발되고 있습니다. DreamBooth나 ControlNet 같은 확장 기능을 통해 특정 스타일이나 조건을 반영한 이미지 생성이 가능하고 주로 AI 아트, 디자인, 데이터 증강 등의 분야에서 활용되며, 프롬프트 엔지니어링을 통해 원하는 결과를 보다 정밀하게 조정할 수 있습니다.
한번 만들어놓고 잘 활용한다면 다른 task에 적용할 수도 있고 재미로 만들어둬도 좋을 것 같아 정리합니다.
Stable Diffusion에 대한 논문 정리는 추후 작성할 예정입니다.
제가 참고한 다른 자료는 아래와 같습니다.
https://huggingface.co/blog/OzzyGT/outpainting-differential-diffusion
1. 필요한 패키지 설치하기
어려운 부분이 아니니 추가 설명은 하지 않겠습니다. Conda나 venv와 같은 가상환경에서 하시는 것은 자유입니다.
requirements.txt 파일은 첨부해두겠습니다.
! pip install -r requirements.txt
2. Utils.py 생성
API를 만드는 데 필요한 코드를 생성합니다. utils라는 폴더를 생성 후 utils.py 파일 안에 밑의 코드를 복사 붙여넣기 하시면 됩니다. 함수는 크게 5개가 있고 코드 가독성 및 분리성을 위해서 편하신대로 수정하셔도 좋습니다.
import cv2
import numpy as np
import torch
import random
def merge_images(original, new_image, offset, direction):
if direction in ["left", "right"]:
merged_image = np.zeros((original.shape[0], original.shape[1] + offset, 3), dtype=np.uint8)
elif direction in ["top", "bottom"]:
merged_image = np.zeros((original.shape[0] + offset, original.shape[1], 3), dtype=np.uint8)
if direction == "left":
merged_image[:, offset:] = original
merged_image[:, : new_image.shape[1]] = new_image
elif direction == "right":
merged_image[:, : original.shape[1]] = original
merged_image[:, original.shape[1] + offset - new_image.shape[1] : original.shape[1] + offset] = new_image
elif direction == "top":
merged_image[offset:, :] = original
merged_image[: new_image.shape[0], :] = new_image
elif direction == "bottom":
merged_image[: original.shape[0], :] = original
merged_image[original.shape[0] + offset - new_image.shape[0] : original.shape[0] + offset, :] = new_image
return merged_image
def slice_image(image):
height, width, _ = image.shape
slice_size = min(width // 2, height // 3)
slices = []
for h in range(3):
for w in range(2):
left = w * slice_size
upper = h * slice_size
right = left + slice_size
lower = upper + slice_size
if w == 1 and right > width:
left -= right - width
right = width
if h == 2 and lower > height:
upper -= lower - height
lower = height
slice = image[upper:lower, left:right]
slices.append(slice)
return slices
def process_image(
image,
fill_color=(0, 0, 0),
mask_offset=50,
blur_radius=500,
expand_pixels=256,
direction="left",
inpaint_mask_color=50,
max_size=1024,
):
height, width = image.shape[:2]
new_height = height + (expand_pixels if direction in ["top", "bottom"] else 0)
new_width = width + (expand_pixels if direction in ["left", "right"] else 0)
if new_height > max_size:
# If so, crop the image from the opposite side
if direction == "top":
image = image[:max_size, :]
elif direction == "bottom":
image = image[new_height - max_size :, :]
new_height = max_size
if new_width > max_size:
# If so, crop the image from the opposite side
if direction == "left":
image = image[:, :max_size]
elif direction == "right":
image = image[:, new_width - max_size :]
new_width = max_size
height, width = image.shape[:2]
new_image = np.full((new_height, new_width, 3), fill_color, dtype=np.uint8)
mask = np.full_like(new_image, 255, dtype=np.uint8)
inpaint_mask = np.full_like(new_image, 0, dtype=np.uint8)
mask = cv2.cvtColor(mask, cv2.COLOR_BGR2GRAY)
inpaint_mask = cv2.cvtColor(inpaint_mask, cv2.COLOR_BGR2GRAY)
if direction == "left":
new_image[:, expand_pixels:] = image[:, : max_size - expand_pixels]
mask[:, : expand_pixels + mask_offset] = inpaint_mask_color
inpaint_mask[:, :expand_pixels] = 255
elif direction == "right":
new_image[:, :width] = image
mask[:, width - mask_offset :] = inpaint_mask_color
inpaint_mask[:, width:] = 255
elif direction == "top":
new_image[expand_pixels:, :] = image[: max_size - expand_pixels, :]
mask[: expand_pixels + mask_offset, :] = inpaint_mask_color
inpaint_mask[:expand_pixels, :] = 255
elif direction == "bottom":
new_image[:height, :] = image
mask[height - mask_offset :, :] = inpaint_mask_color
inpaint_mask[height:, :] = 255
# mask blur
if blur_radius % 2 == 0:
blur_radius += 1
mask = cv2.GaussianBlur(mask, (blur_radius, blur_radius), 0)
# telea inpaint
_, mask_np = cv2.threshold(inpaint_mask, 128, 255, cv2.THRESH_BINARY | cv2.THRESH_OTSU)
inpaint = cv2.inpaint(new_image, mask_np, 3, cv2.INPAINT_TELEA)
# convert image to tensor
inpaint = cv2.cvtColor(inpaint, cv2.COLOR_BGR2RGB)
inpaint = torch.from_numpy(inpaint).permute(2, 0, 1).float()
inpaint = inpaint / 127.5 - 1
inpaint = inpaint.unsqueeze(0).to("cuda")
# convert mask to tensor
mask = torch.from_numpy(mask)
mask = mask.unsqueeze(0).float() / 255.0
mask = mask.to("cuda")
return inpaint, mask
def image_resize(image, new_size=1024):
height, width = image.shape[:2]
aspect_ratio = width / height
new_width = new_size
new_height = new_size
if aspect_ratio != 1:
if width > height:
new_height = int(new_size / aspect_ratio)
else:
new_width = int(new_size * aspect_ratio)
image = cv2.resize(image, (new_width, new_height), interpolation=cv2.INTER_LANCZOS4)
return image
def generate_image(pipeline, prompt, negative_prompt, image, mask, ip_adapter_image, seed: int = None):
if seed is None:
seed = random.randint(0, 2**32 - 1)
generator = torch.Generator(device="cuda").manual_seed(seed)
image = pipeline(
prompt=prompt,
negative_prompt=negative_prompt,
width=1024,
height=1024,
guidance_scale=4.0,
num_inference_steps=25,
original_image=image,
image=image,
strength=1.0,
map=mask,
generator=generator,
ip_adapter_image=[ip_adapter_image],
output_type="np",
).images[0]
image = (image * 255).astype(np.uint8)
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
return image
3. app.py 생성
마지막 코드 추가입니다. API는 크게 txt2img와 outpaint에 대해서 정리하였습니다. 코드가 어렵지 않으니 이해하는 것도 어렵지 않으실 껍니다.
여기서 저는 모델을 SG161222/RealVisXL_V5.0 모델을 사용했는데 더 좋은 다른 모델이 있다면 다른 모델을 사용하셔도 좋고 CivitAI라는 페이지에서 모델을 다운로드 받아서 사용하셔도 좋습니다.
import os
import io
import torch
import random
import cv2
import numpy as np
import sys
from fastapi import FastAPI, Form, File, UploadFile
from fastapi.responses import StreamingResponse
from diffusers import DiffusionPipeline, DPMSolverMultistepScheduler, StableDiffusionXLPipeline
from utils.utils import generate_image, image_resize, process_image, slice_image, merge_images
app = FastAPI()
TXT2IMG_MODEL = "SG161222/RealVisXL_V5.0"
OUTPAINT_MODEL = "SG161222/RealVisXL_V5.0"
ADAPTER_MODEL = "h94/IP-Adapter"
device = "cuda" if torch.cuda.is_available() else "cpu"
# txt2img 모델
txt2img_pipe = DiffusionPipeline.from_pretrained(TXT2IMG_MODEL).to("cuda")
# outpaint 모델
outpaint_pipe = StableDiffusionXLPipeline.from_pretrained(
OUTPAINT_MODEL,
custom_pipeline="pipeline_stable_diffusion_xl_differential_img2img",
).to("cuda")
outpaint_pipe.scheduler = DPMSolverMultistepScheduler.from_config(
outpaint_pipe.scheduler.config, use_karras_sigmas=True
)
outpaint_pipe.load_ip_adapter(
ADAPTER_MODEL,
subfolder="sdxl_models",
weight_name=[
"ip-adapter-plus_sdxl_vit-h.safetensors",
],
image_encoder_folder="models/image_encoder",
)
outpaint_pipe.set_ip_adapter_scale(0.1)
outpaint_pipe = outpaint_pipe.to(torch.float16)
@app.post("/stable-diffusion/txt2img")
async def txt2img(
text_prompt: str = Form(..., description="A realistic portrait of a woman with a smile, 8k"),
negative_prompt: str = Form(
"Signature, Poor body structure, Low-quality drawing, \
Incorrect size, Outside the edges, Unclear, Dull background, Logo, \
Cropped, Trimmed, Body parts separated, Uneven size, Twisted, Copy, \
Duplicated elements, Additional arms, fingers, hands, legs, Additional body parts, Flaw, \
Imperfection, Joined fingers, Unpleasant size, Identifying sign, Incorrect structure, \
Wrong proportion, Tacky, Poor quality, Poor clarity, Spot, Absent arms, fingers, hands, \
legs, Error, Damaged, Beyond the image, Badly drawn face, feet, hands, Text on paper, Repulsive, \
Unpleasant size, Shortened, Narrow eyes, Visual plan, Arrangement, Cut off, Unpleasant, \
Blurry, Unattractive, Awkward position, Imaginary framework, Watermark"
),
guidance_scale: float = Form(8.5),
num_inference_steps: int = Form(50),
hires_fix: bool = Form(True),
hires_steps: int = Form(10),
upscaler: str = Form("4x-UltraSharp"),
denoising_strength: float = Form(0.3),
upscale_by: float = Form(1.5),
):
seed = random.randint(0, 2**32 - 1)
generator = torch.manual_seed(seed)
output_image = txt2img_pipe(
prompt=text_prompt,
negative_prompt=negative_prompt,
guidance_scale=guidance_scale,
num_inference_steps=num_inference_steps,
sampler="dpm++ 2m karras",
hires_fix=hires_fix,
hires_steps=hires_steps,
upscaler=upscaler,
denoising_strength=denoising_strength,
upscale_by=upscale_by,
generator=generator,
).images[0]
image_cv = np.array(output_image)
image_cv = cv2.cvtColor(image_cv, cv2.COLOR_RGB2BGR)
_, img_encoded = cv2.imencode(".png", image_cv)
img_bytes = io.BytesIO(img_encoded.tobytes())
return StreamingResponse(img_bytes, media_type="image/png")
@app.post("/stable-diffusion/outpaint")
async def outpaint(
image_file: UploadFile = File(...),
direction: str = Form(..., description="Direction to expand (left, right, top, bottom)"),
expand_pixels: int = Form(32, description="Number of pixels to expand"),
inpaint_mask_color: int = Form(50, description="Inpaint mask color"),
times_to_expand: int = Form(4, description="Number of times to expand"),
prompt: str = Form("8k image", description="Prompt for image generation"),
negative_prompt: str = Form(
"Signature, Poor body structure, Low-quality drawing, \
Incorrect size, Outside the edges, Unclear, Dull background, Logo, \
Cropped, Trimmed, Body parts separated, Uneven size, Twisted, Copy, \
Duplicated elements, Additional arms, fingers, hands, legs, Additional body parts, Flaw, \
Imperfection, Joined fingers, Unpleasant size, Identifying sign, Incorrect structure, \
Wrong proportion, Tacky, Poor quality, Poor clarity, Spot, Absent arms, fingers, hands, \
legs, Error, Damaged, Beyond the image, Badly drawn face, feet, hands, Text on paper, Repulsive, \
Unpleasant size, Shortened, Narrow eyes, Visual plan, Arrangement, Cut off, Unpleasant, \
Blurry, Unattractive, Awkward position, Imaginary framework, Watermark"
),
):
image_data = await image_file.read()
image_array = np.frombuffer(image_data, np.uint8)
original_img = cv2.imdecode(image_array, cv2.IMREAD_COLOR)
image = image_resize(original_img)
expand_pixels_to_square = 1024 - image.shape[1]
image, mask = process_image(
image, expand_pixels=expand_pixels_to_square, direction=direction, inpaint_mask_color=inpaint_mask_color
)
ip_adapter_image = [part for part in slice_image(original_img)]
generated = generate_image(outpaint_pipe, prompt, negative_prompt, image, mask, ip_adapter_image)
final_image = generated
for _ in range(times_to_expand):
image, mask = process_image(
final_image, direction=direction, expand_pixels=expand_pixels, inpaint_mask_color=inpaint_mask_color
)
ip_adapter_image = [part for part in slice_image(generated)]
generated = generate_image(outpaint_pipe, prompt, negative_prompt, image, mask, ip_adapter_image)
final_image = merge_images(final_image, generated, expand_pixels, direction)
_, img_encoded = cv2.imencode(".png", final_image)
img_bytes = io.BytesIO(img_encoded.tobytes())
return StreamingResponse(img_bytes, media_type="image/png")
4. 실행 및 결과 확인
위 방식대로 하셨으면 폴더 구조는 아래처럼 됩니다.
├── api.py
├── requirements.txt
└── utils
└── utils.py
실행하는 방법은 아래 커맨드를 입력해주시면 됩니다. 그러면 웹에서 localhost:4444에 접근을 할 수 있게 됩니다. 근데 이 화면은 아무 UI가 없으니 확인하려면 Swagger로 확인해보시는 게 좋습니다.
! uvicorn app:app --host 0.0.0.0 --port 4444 # port는 자유
그래서 localhost:4444/docs로 접근해보죠. 그러면 아래와 같은 화면이 뜨게 됩니다.
어떻게 사용하는 지 알려드리겠습니다.
일단 두 API 중 위부터 보겠습니다. 화살표를 누르고 "Try it out" 버튼을 누르시고 나머지는 건드릴 필요 없이 text_prompt에 원하시는 프롬프트를 입력하시면 됩니다. 단, 꼭 영어로 하셔야 합니다!
업로드 하시고 Execute를 누르면 됩니다. 나머지 parameter들은 수정하셔도 되고 안하셔도 됩니다.
실행하면 아래와 같이 이미지를 얻을 수 있습니다. 간단하죠?
아래 API도 비슷합니다. 원하는 이미지를 입력하시고 "left", "right", "top", "bottom" 넷 중 하나를 direction에 입력하시고 똑같이 실행해보겠습니다.
속도는 느리지만 간단하게 이미지를 출력하는 것을 보실 수 있습니다.
궁금하신 사항 있으시면 언제든 댓글로 의견 주시면 감사하겠습니다.
'Machine Learning' 카테고리의 다른 글
PDFTranslate FastAPI 구현 (1) | 2025.02.04 |
---|---|
Grounding-DINO FastAPI 구현 (1) | 2024.12.06 |
GLIP : Grounded Language-Image Pre-training (1) | 2024.07.26 |
CNN 기반 모델들 (0) | 2024.06.25 |
딥러닝 기초 지식 (2) | 2024.06.11 |