# MagicQuillV2 / app.py
import sys
import os
import gradio as gr
import spaces
import tempfile
import numpy as np
import io
import base64
from gradio_client import Client, handle_file
from huggingface_hub import snapshot_download
from gradio_magicquillv2 import MagicQuillV2
from fastapi import FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware
import uvicorn
import requests
from PIL import Image, ImageOps
import random
import time
import torch
import json
# Project-local modules
from edit_space import KontextEditModel
from util import (
load_and_preprocess_image,
read_base64_image as read_base64_image_utils,
create_alpha_mask,
tensor_to_base64,
get_mask_bbox
)
# Initialize models
print("Downloading models...")
hf_token = os.environ.get("hf_token")
snapshot_download(repo_id="LiuZichen/MagicQuillV2-models", repo_type="model", local_dir="models", token=hf_token)
print("Initializing models...")
kontext_model = KontextEditModel()
# Initialize SAM Client
# Replace with your actual SAM Space ID
sam_client = Client("LiuZichen/MagicQuillHelper")
print("Models initialized.")
css = """
.ms {
width: 60%;
margin: auto
}
"""
url = "http://localhost:7860"
@spaces.GPU
def generate(merged_image, total_mask, original_image, add_color_image, add_edge_mask, remove_edge_mask, fill_mask, add_prop_image, positive_prompt, negative_prompt, fine_edge, fix_perspective, grow_size, edge_strength, color_strength, local_strength, seed, steps, cfg):
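    """Run one MagicQuill V2 edit and return the result as a base64 image string.

    All image/mask arguments arrive as base64 strings from the frontend; missing
    masks fall back to all-zero tensors. The editing mode (flag) is inferred from
    which masks are non-empty before calling kontext_model.process.
    """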
print("prompt is:", positive_prompt)
print("other parameters:", negative_prompt, fine_edge, fix_perspective, grow_size, edge_strength, color_strength, local_strength, seed, steps, cfg)
if kontext_model is None:
raise RuntimeError("KontextEditModel not initialized")
    # Preprocess inputs: read_base64_image returns a BytesIO, which both
    # load_and_preprocess_image and create_alpha_mask accept (via Image.open).
merged_image_tensor = load_and_preprocess_image(read_base64_image_utils(merged_image))
total_mask_tensor = create_alpha_mask(read_base64_image_utils(total_mask))
original_image_tensor = load_and_preprocess_image(read_base64_image_utils(original_image))
if add_color_image:
add_color_image_tensor = load_and_preprocess_image(read_base64_image_utils(add_color_image))
else:
add_color_image_tensor = original_image_tensor
add_mask = create_alpha_mask(read_base64_image_utils(add_edge_mask)) if add_edge_mask else torch.zeros_like(total_mask_tensor)
remove_mask = create_alpha_mask(read_base64_image_utils(remove_edge_mask)) if remove_edge_mask else torch.zeros_like(total_mask_tensor)
add_prop_mask = create_alpha_mask(read_base64_image_utils(add_prop_image)) if add_prop_image else torch.zeros_like(total_mask_tensor)
fill_mask_tensor = create_alpha_mask(read_base64_image_utils(fill_mask)) if fill_mask else torch.zeros_like(total_mask_tensor)
    # Pick the editing mode from the active masks (foreground insertion, local fill,
    # instance removal, precise edit, or default Kontext edit) and adjust the prompt.
flag = "kontext"
    if torch.sum(add_prop_mask).item() > 0:
flag = "foreground"
positive_prompt = "Fill in the white region naturally and adapt the foreground into the background. Fix the perspective of the foreground object if necessary. " + positive_prompt
elif torch.sum(fill_mask_tensor).item() > 0:
flag = "local"
elif (torch.sum(remove_mask).item() > 0 and torch.sum(add_mask).item() == 0):
positive_prompt = "remove the instance"
flag = "removal"
elif (torch.sum(add_mask).item() > 0 or torch.sum(remove_mask).item() > 0 or (not torch.equal(original_image_tensor, add_color_image_tensor))):
flag = "precise_edit"
print("positive prompt: ", positive_prompt)
print("current flag: ", flag)
final_image, condition, mask = kontext_model.process(
original_image_tensor,
add_color_image_tensor,
merged_image_tensor,
positive_prompt,
total_mask_tensor,
add_mask,
remove_mask,
add_prop_mask,
fill_mask_tensor,
fine_edge,
fix_perspective,
edge_strength,
color_strength,
local_strength,
grow_size,
seed,
steps,
cfg,
flag,
)
# tensor_to_base64 returns pure base64 string
res_base64 = tensor_to_base64(final_image)
return res_base64
def generate_image_handler(x, negative_prompt, fine_edge, fix_perspective, grow_size, edge_strength, color_strength, local_strength, seed, steps, cfg):
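    """Gradio click handler: unpack the MagicQuillV2 component payload, call
    generate(), and store the result (or None on failure) back into the payload."""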
merged_image = x['from_frontend']['img']
total_mask = x['from_frontend']['total_mask']
original_image = x['from_frontend']['original_image']
add_color_image = x['from_frontend']['add_color_image']
add_edge_mask = x['from_frontend']['add_edge_mask']
remove_edge_mask = x['from_frontend']['remove_edge_mask']
fill_mask = x['from_frontend']['fill_mask']
add_prop_image = x['from_frontend']['add_prop_image']
positive_prompt = x['from_backend']['prompt']
try:
res_base64 = generate(
merged_image,
total_mask,
original_image,
add_color_image,
add_edge_mask,
remove_edge_mask,
fill_mask,
add_prop_image,
positive_prompt,
negative_prompt,
fine_edge,
fix_perspective,
grow_size,
edge_strength,
color_strength,
local_strength,
seed,
steps,
cfg
)
x["from_backend"]["generated_image"] = res_base64
except Exception as e:
print(f"Error in generation: {e}")
x["from_backend"]["generated_image"] = None
return x
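# Gradio UI: the MagicQuillV2 canvas component plus the generation parameters.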
with gr.Blocks(title="MagicQuill V2") as demo:
with gr.Row():
ms = MagicQuillV2()
with gr.Row():
with gr.Column():
btn = gr.Button("Run", variant="primary")
with gr.Column():
            with gr.Accordion("Parameters", open=False):
negative_prompt = gr.Textbox(
label="Negative Prompt",
value="",
interactive=True
)
fine_edge = gr.Radio(
label="Fine Edge",
choices=['enable', 'disable'],
value='disable',
interactive=True
)
fix_perspective = gr.Radio(
label="Fix Perspective",
choices=['enable', 'disable'],
value='disable',
interactive=True
)
grow_size = gr.Slider(
label="Grow Size",
minimum=10,
maximum=100,
value=50,
step=1,
interactive=True
)
edge_strength = gr.Slider(
label="Edge Strength",
minimum=0.0,
maximum=5.0,
value=0.6,
step=0.01,
interactive=True
)
color_strength = gr.Slider(
label="Color Strength",
minimum=0.0,
maximum=5.0,
value=1.5,
step=0.01,
interactive=True
)
local_strength = gr.Slider(
label="Local Strength",
minimum=0.0,
maximum=5.0,
value=1.0,
step=0.01,
interactive=True
)
seed = gr.Number(
label="Seed",
value=-1,
precision=0,
interactive=True
)
                steps = gr.Slider(
                    label="Steps",
                    minimum=0,
                    maximum=50,
                    value=20,
                    step=1,
                    interactive=True
                )
cfg = gr.Slider(
label="CFG",
minimum=0.0,
maximum=20.0,
value=3.5,
step=0.1,
interactive=True
)
btn.click(generate_image_handler, inputs=[ms, negative_prompt, fine_edge, fix_perspective, grow_size, edge_strength, color_strength, local_strength, seed, steps, cfg], outputs=ms)
app = FastAPI()
app.add_middleware(
CORSMiddleware,
allow_origins=['*'],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
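# Patch Gradio's root-URL resolution so the root_path passed to mount_gradio_app
# is used as-is (e.g. when the app is served behind a reverse proxy).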
def get_root_url(
request: Request, route_path: str, root_path: str | None
):
print(root_path)
return root_path
import gradio.route_utils
gr.route_utils.get_root_url = get_root_url
gr.mount_gradio_app(app, demo, path="/demo", root_path="/demo")
@app.post("/magic_quill/generate_image")
async def generate_image(request: Request):
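    """HTTP counterpart of the Gradio handler: expects a JSON body carrying the
    same base64 images/masks and generation parameters, returns {'res': <base64>}."""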
data = await request.json()
res = generate(
data["merged_image"],
data["total_mask"],
data["original_image"],
data["add_color_image"],
data["add_edge_mask"],
data["remove_edge_mask"],
data["fill_mask"],
data["add_prop_image"],
data["positive_prompt"],
data["negative_prompt"],
data["fine_edge"],
data["fix_perspective"],
data["grow_size"],
data["edge_strength"],
data["color_strength"],
data["local_strength"],
data["seed"],
data["steps"],
data["cfg"]
)
return {'res': res}
@app.post("/magic_quill/process_background_img")
async def process_background_img(request: Request):
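    """Resize the posted background image via util.process_background and return
    it as a WebP data URL."""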
img = await request.json()
from util import process_background
# process_background returns tensor [1, H, W, 3] in uint8 or float
resized_img_tensor = process_background(img)
    # Encode the tensor as WebP and wrap it in a data URL for the frontend
resized_img_base64 = "data:image/webp;base64," + tensor_to_base64(
resized_img_tensor,
quality=80,
method=6
)
return resized_img_base64
@app.post("/magic_quill/segmentation")
async def segmentation(request: Request):
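    """Proxy a segmentation request to the remote SAM Space.

    The JSON body carries a base64 image plus optional positive/negative click
    coordinates and bounding boxes. Returns an RGBA image whose alpha channel is
    the predicted mask, along with the mask's bounding box.
    """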
json_data = await request.json()
image_base64 = json_data.get("image", None)
coordinates_positive = json_data.get("coordinates_positive", None)
coordinates_negative = json_data.get("coordinates_negative", None)
bboxes = json_data.get("bboxes", None)
if sam_client is None:
return {"error": "sam client not initialized"}
    # Normalize click coordinates to integers and serialize them as JSON strings
    pos_coordinates = None
    if coordinates_positive and len(coordinates_positive) > 0:
        pos_coordinates = json.dumps([
            {'x': int(round(coord['x'])), 'y': int(round(coord['y']))}
            for coord in coordinates_positive
        ])
    neg_coordinates = None
    if coordinates_negative and len(coordinates_negative) > 0:
        neg_coordinates = json.dumps([
            {'x': int(round(coord['x'])), 'y': int(round(coord['y']))}
            for coord in coordinates_negative
        ])
    bboxes_xyxy = None
    if bboxes and len(bboxes) > 0:
        valid_bboxes = []
        for bbox in bboxes:
            if (bbox.get("startX") is None or
                    bbox.get("startY") is None or
                    bbox.get("endX") is None or
                    bbox.get("endY") is None):
                continue
            # Clamp the top-left corner to be non-negative; clipping against the
            # image dimensions is left to the remote SAM space.
            x_min = max(min(int(bbox["startX"]), int(bbox["endX"])), 0)
            y_min = max(min(int(bbox["startY"]), int(bbox["endY"])), 0)
            x_max = max(int(bbox["startX"]), int(bbox["endX"]))
            y_max = max(int(bbox["startY"]), int(bbox["endY"]))
            valid_bboxes.append((x_min, y_min, x_max, y_max))
        # Serialize as a JSON string for consistency with the click coordinates
        bboxes_xyxy = json.dumps(valid_bboxes) if valid_bboxes else valid_bboxes
print(f"Segmentation request: pos={pos_coordinates}, neg={neg_coordinates}, bboxes={bboxes_xyxy}")
try:
# Save base64 image to temp file
image_bytes = read_base64_image_utils(image_base64)
# Image.open to verify and save as WebP (smaller size)
pil_image = Image.open(image_bytes)
with tempfile.NamedTemporaryFile(suffix=".webp", delete=False) as temp_in:
pil_image.save(temp_in.name, format="WEBP", quality=80)
temp_in_path = temp_in.name
        # Call the remote SAM space; it returns a filepath to the predicted mask image
result_path = sam_client.predict(
handle_file(temp_in_path),
pos_coordinates,
neg_coordinates,
bboxes_xyxy,
api_name="/segment"
)
# Clean up input temp
os.unlink(temp_in_path)
        # The remote space may return a single path or a tuple/list of paths,
        # depending on its implementation; keep only the first one.
        if isinstance(result_path, (list, tuple)):
            result_path = result_path[0]
if not result_path or not os.path.exists(result_path):
raise RuntimeError("Client returned invalid result path")
# result_path is the Mask Image (White=Selected, Black=Background)
mask_pil = Image.open(result_path)
if mask_pil.mode != 'L':
mask_pil = mask_pil.convert('L')
pil_image = pil_image.convert("RGB")
if pil_image.size != mask_pil.size:
mask_pil = mask_pil.resize(pil_image.size, Image.NEAREST)
r, g, b = pil_image.split()
res_pil = Image.merge("RGBA", (r, g, b, mask_pil))
# Extract bbox from mask (alpha)
mask_tensor = torch.from_numpy(np.array(mask_pil) / 255.0).float().unsqueeze(0)
mask_bbox = get_mask_bbox(mask_tensor)
if mask_bbox:
x_min, y_min, x_max, y_max = mask_bbox
seg_bbox = {'startX': x_min, 'startY': y_min, 'endX': x_max, 'endY': y_max}
else:
seg_bbox = {'startX': 0, 'startY': 0, 'endX': 0, 'endY': 0}
print(seg_bbox)
        # Encode the RGBA result as a base64 PNG data URL
buffered = io.BytesIO()
res_pil.save(buffered, format="PNG")
image_base64_res = base64.b64encode(buffered.getvalue()).decode("utf-8")
return {
"error": False,
"segmentation_image": "data:image/png;base64," + image_base64_res,
"segmentation_bbox": seg_bbox
}
except Exception as e:
print(f"Error in segmentation: {e}")
return {"error": str(e)}
app = gr.mount_gradio_app(app, demo, "/")
if __name__ == "__main__":
uvicorn.run(app, host="0.0.0.0", port=7860)
# demo.launch()