import sys import os import gradio as gr import spaces import tempfile import numpy as np import io import base64 from gradio_client import Client, handle_file from huggingface_hub import snapshot_download from gradio_magicquillv2 import MagicQuillV2 from fastapi import FastAPI, Request from fastapi.middleware.cors import CORSMiddleware import uvicorn import requests from PIL import Image, ImageOps import random import time import torch import json # Try importing as a package (recommended) from edit_space import KontextEditModel from util import ( load_and_preprocess_image, read_base64_image as read_base64_image_utils, create_alpha_mask, tensor_to_base64, get_mask_bbox ) # Initialize models print("Downloading models...") hf_token = os.environ.get("hf_token") snapshot_download(repo_id="LiuZichen/MagicQuillV2-models", repo_type="model", local_dir="models", token=hf_token) print("Initializing models...") kontext_model = KontextEditModel() # Initialize SAM Client # Replace with your actual SAM Space ID sam_client = Client("LiuZichen/MagicQuillHelper") print("Models initialized.") css = """ .ms { width: 60%; margin: auto } """ url = "http://localhost:7860" @spaces.GPU def generate(merged_image, total_mask, original_image, add_color_image, add_edge_mask, remove_edge_mask, fill_mask, add_prop_image, positive_prompt, negative_prompt, fine_edge, fix_perspective, grow_size, edge_strength, color_strength, local_strength, seed, steps, cfg): print("prompt is:", positive_prompt) print("other parameters:", negative_prompt, fine_edge, fix_perspective, grow_size, edge_strength, color_strength, local_strength, seed, steps, cfg) if kontext_model is None: raise RuntimeError("KontextEditModel not initialized") # Preprocess inputs # utils.read_base64_image returns BytesIO, which create_alpha_mask accepts (via Image.open) # load_and_preprocess_image accepts path, so we might need to check if it accepts file-like object. # utils.load_and_preprocess_image uses Image.open(image_path), so BytesIO works. merged_image_tensor = load_and_preprocess_image(read_base64_image_utils(merged_image)) total_mask_tensor = create_alpha_mask(read_base64_image_utils(total_mask)) original_image_tensor = load_and_preprocess_image(read_base64_image_utils(original_image)) if add_color_image: add_color_image_tensor = load_and_preprocess_image(read_base64_image_utils(add_color_image)) else: add_color_image_tensor = original_image_tensor add_mask = create_alpha_mask(read_base64_image_utils(add_edge_mask)) if add_edge_mask else torch.zeros_like(total_mask_tensor) remove_mask = create_alpha_mask(read_base64_image_utils(remove_edge_mask)) if remove_edge_mask else torch.zeros_like(total_mask_tensor) add_prop_mask = create_alpha_mask(read_base64_image_utils(add_prop_image)) if add_prop_image else torch.zeros_like(total_mask_tensor) fill_mask_tensor = create_alpha_mask(read_base64_image_utils(fill_mask)) if fill_mask else torch.zeros_like(total_mask_tensor) # Determine flag and modify prompt flag = "kontext" if torch.sum(add_prop_mask) > 0: flag = "foreground" positive_prompt = "Fill in the white region naturally and adapt the foreground into the background. Fix the perspective of the foreground object if necessary. " + positive_prompt elif torch.sum(fill_mask_tensor).item() > 0: flag = "local" elif (torch.sum(remove_mask).item() > 0 and torch.sum(add_mask).item() == 0): positive_prompt = "remove the instance" flag = "removal" elif (torch.sum(add_mask).item() > 0 or torch.sum(remove_mask).item() > 0 or (not torch.equal(original_image_tensor, add_color_image_tensor))): flag = "precise_edit" print("positive prompt: ", positive_prompt) print("current flag: ", flag) final_image, condition, mask = kontext_model.process( original_image_tensor, add_color_image_tensor, merged_image_tensor, positive_prompt, total_mask_tensor, add_mask, remove_mask, add_prop_mask, fill_mask_tensor, fine_edge, fix_perspective, edge_strength, color_strength, local_strength, grow_size, seed, steps, cfg, flag, ) # tensor_to_base64 returns pure base64 string res_base64 = tensor_to_base64(final_image) return res_base64 def generate_image_handler(x, negative_prompt, fine_edge, fix_perspective, grow_size, edge_strength, color_strength, local_strength, seed, steps, cfg): merged_image = x['from_frontend']['img'] total_mask = x['from_frontend']['total_mask'] original_image = x['from_frontend']['original_image'] add_color_image = x['from_frontend']['add_color_image'] add_edge_mask = x['from_frontend']['add_edge_mask'] remove_edge_mask = x['from_frontend']['remove_edge_mask'] fill_mask = x['from_frontend']['fill_mask'] add_prop_image = x['from_frontend']['add_prop_image'] positive_prompt = x['from_backend']['prompt'] try: res_base64 = generate( merged_image, total_mask, original_image, add_color_image, add_edge_mask, remove_edge_mask, fill_mask, add_prop_image, positive_prompt, negative_prompt, fine_edge, fix_perspective, grow_size, edge_strength, color_strength, local_strength, seed, steps, cfg ) x["from_backend"]["generated_image"] = res_base64 except Exception as e: print(f"Error in generation: {e}") x["from_backend"]["generated_image"] = None return x with gr.Blocks(title="MagicQuill V2") as demo: with gr.Row(): ms = MagicQuillV2() with gr.Row(): with gr.Column(): btn = gr.Button("Run", variant="primary") with gr.Column(): with gr.Accordion("parameters", open=False): negative_prompt = gr.Textbox( label="Negative Prompt", value="", interactive=True ) fine_edge = gr.Radio( label="Fine Edge", choices=['enable', 'disable'], value='disable', interactive=True ) fix_perspective = gr.Radio( label="Fix Perspective", choices=['enable', 'disable'], value='disable', interactive=True ) grow_size = gr.Slider( label="Grow Size", minimum=10, maximum=100, value=50, step=1, interactive=True ) edge_strength = gr.Slider( label="Edge Strength", minimum=0.0, maximum=5.0, value=0.6, step=0.01, interactive=True ) color_strength = gr.Slider( label="Color Strength", minimum=0.0, maximum=5.0, value=1.5, step=0.01, interactive=True ) local_strength = gr.Slider( label="Local Strength", minimum=0.0, maximum=5.0, value=1.0, step=0.01, interactive=True ) seed = gr.Number( label="Seed", value=-1, precision=0, interactive=True ) steps = gr.Slider( label="Steps", minimum=0, maximum=50, value=20, interactive=True ) cfg = gr.Slider( label="CFG", minimum=0.0, maximum=20.0, value=3.5, step=0.1, interactive=True ) btn.click(generate_image_handler, inputs=[ms, negative_prompt, fine_edge, fix_perspective, grow_size, edge_strength, color_strength, local_strength, seed, steps, cfg], outputs=ms) app = FastAPI() app.add_middleware( CORSMiddleware, allow_origins=['*'], allow_credentials=True, allow_methods=["*"], allow_headers=["*"], ) def get_root_url( request: Request, route_path: str, root_path: str | None ): print(root_path) return root_path import gradio.route_utils gr.route_utils.get_root_url = get_root_url gr.mount_gradio_app(app, demo, path="/demo", root_path="/demo") @app.post("/magic_quill/generate_image") async def generate_image(request: Request): data = await request.json() res = generate( data["merged_image"], data["total_mask"], data["original_image"], data["add_color_image"], data["add_edge_mask"], data["remove_edge_mask"], data["fill_mask"], data["add_prop_image"], data["positive_prompt"], data["negative_prompt"], data["fine_edge"], data["fix_perspective"], data["grow_size"], data["edge_strength"], data["color_strength"], data["local_strength"], data["seed"], data["steps"], data["cfg"] ) return {'res': res} @app.post("/magic_quill/process_background_img") async def process_background_img(request: Request): img = await request.json() from util import process_background # process_background returns tensor [1, H, W, 3] in uint8 or float resized_img_tensor = process_background(img) # tensor_to_base64 from util expects tensor resized_img_base64 = "data:image/webp;base64," + tensor_to_base64( resized_img_tensor, quality=80, method=6 ) return resized_img_base64 @app.post("/magic_quill/segmentation") async def segmentation(request: Request): json_data = await request.json() image_base64 = json_data.get("image", None) coordinates_positive = json_data.get("coordinates_positive", None) coordinates_negative = json_data.get("coordinates_negative", None) bboxes = json_data.get("bboxes", None) if sam_client is None: return {"error": "sam client not initialized"} # Process coordinates and bboxes pos_coordinates = None if coordinates_positive and len(coordinates_positive) > 0: pos_coordinates = [] for coord in coordinates_positive: coord['x'] = int(round(coord['x'])) coord['y'] = int(round(coord['y'])) pos_coordinates.append({'x': coord['x'], 'y': coord['y']}) pos_coordinates = json.dumps(pos_coordinates) neg_coordinates = None if coordinates_negative and len(coordinates_negative) > 0: neg_coordinates = [] for coord in coordinates_negative: coord['x'] = int(round(coord['x'])) coord['y'] = int(round(coord['y'])) neg_coordinates.append({'x': coord['x'], 'y': coord['y']}) neg_coordinates = json.dumps(neg_coordinates) bboxes_xyxy = None if bboxes and len(bboxes) > 0: valid_bboxes = [] for bbox in bboxes: if (bbox.get("startX") is None or bbox.get("startY") is None or bbox.get("endX") is None or bbox.get("endY") is None): continue else: x_min = max(min(int(bbox["startX"]), int(bbox["endX"])), 0) y_min = max(min(int(bbox["startY"]), int(bbox["endY"])), 0) # Note: image_tensor not available here easily without loading image, # but usually we don't need to clip strictly if SAM handles it or we clip to large values # For now, we skip strict clipping against image dims or assume 10000 x_max = int(bbox["startX"]) if int(bbox["startX"]) > int(bbox["endX"]) else int(bbox["endX"]) y_max = int(bbox["startY"]) if int(bbox["startY"]) > int(bbox["endY"]) else int(bbox["endY"]) valid_bboxes.append((x_min, y_min, x_max, y_max)) bboxes_xyxy = [] for bbox in valid_bboxes: x_min, y_min, x_max, y_max = bbox bboxes_xyxy.append((x_min, y_min, x_max, y_max)) # Convert to JSON string if that's what the client expects, or keep as list # Assuming JSON string for consistency with coords if bboxes_xyxy: bboxes_xyxy = json.dumps(bboxes_xyxy) print(f"Segmentation request: pos={pos_coordinates}, neg={neg_coordinates}, bboxes={bboxes_xyxy}") try: # Save base64 image to temp file image_bytes = read_base64_image_utils(image_base64) # Image.open to verify and save as WebP (smaller size) pil_image = Image.open(image_bytes) with tempfile.NamedTemporaryFile(suffix=".webp", delete=False) as temp_in: pil_image.save(temp_in.name, format="WEBP", quality=80) temp_in_path = temp_in.name # Execute segmentation via Client # We assume the remote space returns a filepath to the segmented image (with alpha) # NOW it returns mask_np image result_path = sam_client.predict( handle_file(temp_in_path), pos_coordinates, neg_coordinates, bboxes_xyxy, api_name="/segment" ) # Clean up input temp os.unlink(temp_in_path) # Process result # result_path should be a generic object, usually a tuple (image_path, mask_path) or just image_path # Depending on how the remote space is implemented. if isinstance(result_path, (list, tuple)): result_path = result_path[0] # Take the first return value if multiple if not result_path or not os.path.exists(result_path): raise RuntimeError("Client returned invalid result path") # result_path is the Mask Image (White=Selected, Black=Background) mask_pil = Image.open(result_path) if mask_pil.mode != 'L': mask_pil = mask_pil.convert('L') pil_image = pil_image.convert("RGB") if pil_image.size != mask_pil.size: mask_pil = mask_pil.resize(pil_image.size, Image.NEAREST) r, g, b = pil_image.split() res_pil = Image.merge("RGBA", (r, g, b, mask_pil)) # Extract bbox from mask (alpha) mask_tensor = torch.from_numpy(np.array(mask_pil) / 255.0).float().unsqueeze(0) mask_bbox = get_mask_bbox(mask_tensor) if mask_bbox: x_min, y_min, x_max, y_max = mask_bbox seg_bbox = {'startX': x_min, 'startY': y_min, 'endX': x_max, 'endY': y_max} else: seg_bbox = {'startX': 0, 'startY': 0, 'endX': 0, 'endY': 0} print(seg_bbox) # Convert result to base64 # We need to convert the PIL image to base64 string buffered = io.BytesIO() res_pil.save(buffered, format="PNG") image_base64_res = base64.b64encode(buffered.getvalue()).decode("utf-8") return { "error": False, "segmentation_image": "data:image/png;base64," + image_base64_res, "segmentation_bbox": seg_bbox } except Exception as e: print(f"Error in segmentation: {e}") return {"error": str(e)} app = gr.mount_gradio_app(app, demo, "/") if __name__ == "__main__": uvicorn.run(app, host="0.0.0.0", port=7860) # demo.launch()