import sys
import os
import gradio as gr
import spaces
import tempfile
import numpy as np
import io
import base64
from gradio_client import Client, handle_file
from huggingface_hub import snapshot_download
from gradio_magicquillv2 import MagicQuillV2
from fastapi import FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware
import uvicorn
import requests
from PIL import Image, ImageOps
import random
import time
import torch
import json
# Try importing as a package (recommended)
from edit_space import KontextEditModel
from util import (
    load_and_preprocess_image,
    read_base64_image as read_base64_image_utils,
    create_alpha_mask,
    tensor_to_base64,
    get_mask_bbox,
)
# Initialize models
print("Downloading models...")
hf_token = os.environ.get("hf_token")
snapshot_download(repo_id="LiuZichen/MagicQuillV2-models", repo_type="model", local_dir="models", token=hf_token)

print("Initializing models...")
kontext_model = KontextEditModel()

# Initialize the SAM client. Replace with your actual SAM Space ID.
sam_client = Client("LiuZichen/MagicQuillHelper")
print("Models initialized.")
| css = """ | |
| .ms { | |
| width: 60%; | |
| margin: auto | |
| } | |
| """ | |
| url = "http://localhost:7860" | |
def generate(merged_image, total_mask, original_image, add_color_image,
             add_edge_mask, remove_edge_mask, fill_mask, add_prop_image,
             positive_prompt, negative_prompt, fine_edge, fix_perspective,
             grow_size, edge_strength, color_strength, local_strength,
             seed, steps, cfg):
    print("prompt is:", positive_prompt)
    print("other parameters:", negative_prompt, fine_edge, fix_perspective,
          grow_size, edge_strength, color_strength, local_strength,
          seed, steps, cfg)
    if kontext_model is None:
        raise RuntimeError("KontextEditModel not initialized")
    # Preprocess inputs. read_base64_image returns a BytesIO; both
    # load_and_preprocess_image and create_alpha_mask open their input with
    # Image.open, so file-like objects work just as well as paths.
    merged_image_tensor = load_and_preprocess_image(read_base64_image_utils(merged_image))
    total_mask_tensor = create_alpha_mask(read_base64_image_utils(total_mask))
    original_image_tensor = load_and_preprocess_image(read_base64_image_utils(original_image))
    if add_color_image:
        add_color_image_tensor = load_and_preprocess_image(read_base64_image_utils(add_color_image))
    else:
        add_color_image_tensor = original_image_tensor
    # Optional masks default to all-zero tensors of the same shape.
    add_mask = create_alpha_mask(read_base64_image_utils(add_edge_mask)) if add_edge_mask else torch.zeros_like(total_mask_tensor)
    remove_mask = create_alpha_mask(read_base64_image_utils(remove_edge_mask)) if remove_edge_mask else torch.zeros_like(total_mask_tensor)
    add_prop_mask = create_alpha_mask(read_base64_image_utils(add_prop_image)) if add_prop_image else torch.zeros_like(total_mask_tensor)
    fill_mask_tensor = create_alpha_mask(read_base64_image_utils(fill_mask)) if fill_mask else torch.zeros_like(total_mask_tensor)
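    # Edit-mode dispatch, checked in priority order:
    #   foreground   - a dragged-in object (add_prop_mask)      -> blend into background
    #   local        - a brushed fill region (fill_mask)        -> local inpaint
    #   removal      - remove strokes only                      -> erase the instance
    #   precise_edit - add/remove edge strokes or color changes -> stroke-guided edit
    #   kontext      - default                                  -> plain instruction edit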
    flag = "kontext"
    if torch.sum(add_prop_mask) > 0:
        flag = "foreground"
        positive_prompt = "Fill in the white region naturally and adapt the foreground into the background. Fix the perspective of the foreground object if necessary. " + positive_prompt
    elif torch.sum(fill_mask_tensor).item() > 0:
        flag = "local"
    elif torch.sum(remove_mask).item() > 0 and torch.sum(add_mask).item() == 0:
        positive_prompt = "remove the instance"
        flag = "removal"
    elif (torch.sum(add_mask).item() > 0 or torch.sum(remove_mask).item() > 0
          or not torch.equal(original_image_tensor, add_color_image_tensor)):
        flag = "precise_edit"

    print("positive prompt: ", positive_prompt)
    print("current flag: ", flag)
    final_image, condition, mask = kontext_model.process(
        original_image_tensor,
        add_color_image_tensor,
        merged_image_tensor,
        positive_prompt,
        total_mask_tensor,
        add_mask,
        remove_mask,
        add_prop_mask,
        fill_mask_tensor,
        fine_edge,
        fix_perspective,
        edge_strength,
        color_strength,
        local_strength,
        grow_size,
        seed,
        steps,
        cfg,
        flag,
    )
    # tensor_to_base64 returns a bare base64 string (no "data:" prefix)
    res_base64 = tensor_to_base64(final_image)
    return res_base64
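# The MagicQuillV2 component round-trips a single dict `x`: the frontend fills
# x['from_frontend'] (base64 images and masks) and x['from_backend']['prompt'];
# the backend writes the result into x['from_backend']['generated_image']
# (None on failure, which the frontend can treat as an error).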
def generate_image_handler(x, negative_prompt, fine_edge, fix_perspective,
                           grow_size, edge_strength, color_strength,
                           local_strength, seed, steps, cfg):
    merged_image = x['from_frontend']['img']
    total_mask = x['from_frontend']['total_mask']
    original_image = x['from_frontend']['original_image']
    add_color_image = x['from_frontend']['add_color_image']
    add_edge_mask = x['from_frontend']['add_edge_mask']
    remove_edge_mask = x['from_frontend']['remove_edge_mask']
    fill_mask = x['from_frontend']['fill_mask']
    add_prop_image = x['from_frontend']['add_prop_image']
    positive_prompt = x['from_backend']['prompt']
    try:
        res_base64 = generate(
            merged_image,
            total_mask,
            original_image,
            add_color_image,
            add_edge_mask,
            remove_edge_mask,
            fill_mask,
            add_prop_image,
            positive_prompt,
            negative_prompt,
            fine_edge,
            fix_perspective,
            grow_size,
            edge_strength,
            color_strength,
            local_strength,
            seed,
            steps,
            cfg
        )
        x["from_backend"]["generated_image"] = res_base64
    except Exception as e:
        print(f"Error in generation: {e}")
        x["from_backend"]["generated_image"] = None
    return x
with gr.Blocks(title="MagicQuill V2", css=css) as demo:  # css applies the .ms sizing rule above
    with gr.Row():
        ms = MagicQuillV2()
    with gr.Row():
        with gr.Column():
            btn = gr.Button("Run", variant="primary")
        with gr.Column():
            with gr.Accordion("Parameters", open=False):
                negative_prompt = gr.Textbox(
                    label="Negative Prompt",
                    value="",
                    interactive=True
                )
                fine_edge = gr.Radio(
                    label="Fine Edge",
                    choices=['enable', 'disable'],
                    value='disable',
                    interactive=True
                )
                fix_perspective = gr.Radio(
                    label="Fix Perspective",
                    choices=['enable', 'disable'],
                    value='disable',
                    interactive=True
                )
                grow_size = gr.Slider(
                    label="Grow Size",
                    minimum=10,
                    maximum=100,
                    value=50,
                    step=1,
                    interactive=True
                )
                edge_strength = gr.Slider(
                    label="Edge Strength",
                    minimum=0.0,
                    maximum=5.0,
                    value=0.6,
                    step=0.01,
                    interactive=True
                )
                color_strength = gr.Slider(
                    label="Color Strength",
                    minimum=0.0,
                    maximum=5.0,
                    value=1.5,
                    step=0.01,
                    interactive=True
                )
                local_strength = gr.Slider(
                    label="Local Strength",
                    minimum=0.0,
                    maximum=5.0,
                    value=1.0,
                    step=0.01,
                    interactive=True
                )
                seed = gr.Number(
                    label="Seed",
                    value=-1,
                    precision=0,
                    interactive=True
                )
                steps = gr.Slider(
                    label="Steps",
                    minimum=0,
                    maximum=50,
                    value=20,
                    interactive=True
                )
                cfg = gr.Slider(
                    label="CFG",
                    minimum=0.0,
                    maximum=20.0,
                    value=3.5,
                    step=0.1,
                    interactive=True
                )

    btn.click(
        generate_image_handler,
        inputs=[ms, negative_prompt, fine_edge, fix_perspective, grow_size,
                edge_strength, color_strength, local_strength, seed, steps, cfg],
        outputs=ms,
    )
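# Below, the Gradio app is wrapped in a FastAPI app so the same process can
# also expose plain JSON endpoints for the editor frontend.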
app = FastAPI()
app.add_middleware(
    CORSMiddleware,
    allow_origins=['*'],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

def get_root_url(request: Request, route_path: str, root_path: str | None):
    print(root_path)
    return root_path

# Monkeypatch Gradio to trust the mount's root_path directly instead of
# inferring it from request headers.
import gradio.route_utils
gr.route_utils.get_root_url = get_root_url

gr.mount_gradio_app(app, demo, path="/demo", root_path="/demo")
# Route path is an assumption (the decorator did not survive extraction);
# adjust it to whatever the frontend actually calls.
@app.post("/magic_quill/generate_image")
async def generate_image(request: Request):
    data = await request.json()
    res = generate(
        data["merged_image"],
        data["total_mask"],
        data["original_image"],
        data["add_color_image"],
        data["add_edge_mask"],
        data["remove_edge_mask"],
        data["fill_mask"],
        data["add_prop_image"],
        data["positive_prompt"],
        data["negative_prompt"],
        data["fine_edge"],
        data["fix_perspective"],
        data["grow_size"],
        data["edge_strength"],
        data["color_strength"],
        data["local_strength"],
        data["seed"],
        data["steps"],
        data["cfg"]
    )
    return {'res': res}
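# Example client call (a sketch: the route path above is assumed, and the
# field values below are placeholders, not real data):
#
#   payload = {"merged_image": "<base64>", "total_mask": "<base64>",
#              "original_image": "<base64>", "add_color_image": None,
#              "add_edge_mask": None, "remove_edge_mask": None,
#              "fill_mask": None, "add_prop_image": None,
#              "positive_prompt": "make the sky pink", "negative_prompt": "",
#              "fine_edge": "disable", "fix_perspective": "disable",
#              "grow_size": 50, "edge_strength": 0.6, "color_strength": 1.5,
#              "local_strength": 1.0, "seed": -1, "steps": 20, "cfg": 3.5}
#   res = requests.post(url + "/magic_quill/generate_image", json=payload).json()["res"]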
# Route path assumed, as above. The frontend posts a base64 image and gets
# back a resized WebP data URL.
@app.post("/magic_quill/process_background_img")
async def process_background_img(request: Request):
    img = await request.json()
    from util import process_background
    # process_background returns an image tensor of shape [1, H, W, 3]
    resized_img_tensor = process_background(img)
    resized_img_base64 = "data:image/webp;base64," + tensor_to_base64(
        resized_img_tensor,
        quality=80,
        method=6
    )
    return resized_img_base64
# Route path assumed, as above.
@app.post("/magic_quill/segmentation")
async def segmentation(request: Request):
    json_data = await request.json()
    image_base64 = json_data.get("image", None)
    coordinates_positive = json_data.get("coordinates_positive", None)
    coordinates_negative = json_data.get("coordinates_negative", None)
    bboxes = json_data.get("bboxes", None)
    if sam_client is None:
        return {"error": "sam client not initialized"}

    # Round click coordinates to ints and serialize as JSON strings, which is
    # what the helper Space's /segment endpoint expects.
    pos_coordinates = None
    if coordinates_positive:
        pos_coordinates = json.dumps(
            [{'x': int(round(c['x'])), 'y': int(round(c['y']))} for c in coordinates_positive]
        )
    neg_coordinates = None
    if coordinates_negative:
        neg_coordinates = json.dumps(
            [{'x': int(round(c['x'])), 'y': int(round(c['y']))} for c in coordinates_negative]
        )
    # Normalize boxes to (x_min, y_min, x_max, y_max). Mins are clamped to 0;
    # upper bounds are left unclipped (the image size is not known here), so
    # any further clipping is delegated to the segmentation backend.
    bboxes_xyxy = None
    if bboxes:
        bboxes_xyxy = []
        for bbox in bboxes:
            if any(bbox.get(k) is None for k in ("startX", "startY", "endX", "endY")):
                continue
            x0, x1 = int(bbox["startX"]), int(bbox["endX"])
            y0, y1 = int(bbox["startY"]), int(bbox["endY"])
            bboxes_xyxy.append((max(min(x0, x1), 0), max(min(y0, y1), 0),
                                max(x0, x1), max(y0, y1)))
        if bboxes_xyxy:
            # Serialized as a JSON string for consistency with the coordinates.
            bboxes_xyxy = json.dumps(bboxes_xyxy)

    print(f"Segmentation request: pos={pos_coordinates}, neg={neg_coordinates}, bboxes={bboxes_xyxy}")
    try:
        # Save the base64 image to a temp file (WebP keeps the upload small).
        image_bytes = read_base64_image_utils(image_base64)
        pil_image = Image.open(image_bytes)
        with tempfile.NamedTemporaryFile(suffix=".webp", delete=False) as temp_in:
            pil_image.save(temp_in.name, format="WEBP", quality=80)
            temp_in_path = temp_in.name

        # Run segmentation on the remote Space. /segment returns a filepath to
        # a mask image (white = selected, black = background).
        result_path = sam_client.predict(
            handle_file(temp_in_path),
            pos_coordinates,
            neg_coordinates,
            bboxes_xyxy,
            api_name="/segment"
        )
        os.unlink(temp_in_path)

        # Some endpoints return multiple values; keep only the first.
        if isinstance(result_path, (list, tuple)):
            result_path = result_path[0]
        if not result_path or not os.path.exists(result_path):
            raise RuntimeError("Client returned invalid result path")
        # Compose an RGBA result: original RGB with the mask as alpha.
        mask_pil = Image.open(result_path)
        if mask_pil.mode != 'L':
            mask_pil = mask_pil.convert('L')
        pil_image = pil_image.convert("RGB")
        if pil_image.size != mask_pil.size:
            mask_pil = mask_pil.resize(pil_image.size, Image.NEAREST)
        r, g, b = pil_image.split()
        res_pil = Image.merge("RGBA", (r, g, b, mask_pil))

        # Extract the bounding box of the selected region from the mask.
        mask_tensor = torch.from_numpy(np.array(mask_pil) / 255.0).float().unsqueeze(0)
        mask_bbox = get_mask_bbox(mask_tensor)
        if mask_bbox:
            x_min, y_min, x_max, y_max = mask_bbox
            seg_bbox = {'startX': x_min, 'startY': y_min, 'endX': x_max, 'endY': y_max}
        else:
            seg_bbox = {'startX': 0, 'startY': 0, 'endX': 0, 'endY': 0}
        print(seg_bbox)
        # Encode the RGBA result as a base64 PNG.
        buffered = io.BytesIO()
        res_pil.save(buffered, format="PNG")
        image_base64_res = base64.b64encode(buffered.getvalue()).decode("utf-8")
        return {
            "error": False,
            "segmentation_image": "data:image/png;base64," + image_base64_res,
            "segmentation_bbox": seg_bbox
        }
    except Exception as e:
        print(f"Error in segmentation: {e}")
        return {"error": str(e)}
# Mount the Gradio UI at the root of the FastAPI app.
app = gr.mount_gradio_app(app, demo, "/")

if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)
    # demo.launch()