"""Command-line smoke test for the Nano Banana image-edit HTTP API.

Runs one full round trip against a backend: health check, image upload,
edit submission, result lookup, and download of the edited image.
Every step logs to stdout; the script exits with status 1 on the first
failing step.
"""

import argparse
import mimetypes
import os
import sys
import time
from typing import Optional

import requests


def log(msg: str) -> None:
    """Print a progress/diagnostic message to stdout."""
    print(msg)


def _send(method: str, url: str, what: str, **kwargs) -> Optional[requests.Response]:
    """Issue an HTTP request; log and return None on any transport-level error.

    ``what`` names the step (e.g. "Upload") for the log message. Remaining
    keyword arguments are passed straight to ``requests.request``.
    """
    try:
        return requests.request(method, url, **kwargs)
    except requests.RequestException as exc:
        # Covers connection errors, timeouts, invalid URLs, etc., so a network
        # hiccup ends the run with a readable message instead of a traceback.
        log("%s request failed: %s" % (what, exc))
        return None


def _json_body(r: requests.Response, what: str) -> Optional[dict]:
    """Return the response body parsed as a JSON object, or None if it isn't one.

    ``what`` names the step for the log message. Non-JSON bodies and JSON
    values that are not objects (lists, strings, ...) are both rejected,
    because callers read fields with ``.get``.
    """
    try:
        body = r.json()
    except ValueError:
        # requests' JSON decode error subclasses ValueError across versions.
        log("%s returned non-JSON body: %s" % (what, r.text))
        return None
    if not isinstance(body, dict):
        log("%s returned unexpected JSON: %r" % (what, body))
        return None
    return body


def upload_image(base_url: str, image_path: str, headers: dict) -> Optional[str]:
    """Upload a local image to ``POST {base_url}/upload`` as multipart field "file".

    Returns the ``image_id`` from the JSON response, or None if the file
    cannot be read, the request fails, the status is not 200, or the
    response carries no ``image_id``.
    """
    log("Uploading image: %s" % image_path)
    mime, _ = mimetypes.guess_type(image_path)
    if not mime or not mime.startswith("image/"):
        # Unknown or non-image extension: label the part as JPEG anyway.
        mime = "image/jpeg"
    filename = os.path.basename(image_path) or "image.jpg"
    try:
        with open(image_path, "rb") as f:
            files = {"file": (filename, f, mime)}
            r = _send(
                "POST", "%s/upload" % base_url, "Upload",
                files=files, headers=headers, timeout=120,
            )
    except OSError as exc:
        log("Cannot read image %s: %s" % (image_path, exc))
        return None
    if r is None:
        return None
    if r.status_code != 200:
        log("Upload failed: %s %s" % (r.status_code, r.text))
        return None
    body = _json_body(r, "Upload")
    if body is None:
        return None
    image_id = body.get("image_id")
    log("Uploaded. image_id=%s" % image_id)
    return image_id


def edit_image(base_url: str, image_id: str, prompt: str, headers: dict) -> Optional[str]:
    """Submit an edit of a previously uploaded image to ``POST {base_url}/edit``.

    Sends ``image_id`` and ``prompt`` as form fields. Returns the ``task_id``
    from the JSON response, or None on any failure.
    """
    log("Editing image with prompt: %s" % prompt)
    r = _send(
        "POST", "%s/edit" % base_url, "Edit",
        data={"image_id": image_id, "prompt": prompt}, headers=headers, timeout=300,
    )
    if r is None:
        return None
    if r.status_code != 200:
        log("Edit failed: %s %s" % (r.status_code, r.text))
        return None
    body = _json_body(r, "Edit")
    if body is None:
        return None
    task_id = body.get("task_id")
    log("Edit submitted. task_id=%s status=%s" % (task_id, body.get("status")))
    return task_id


def get_result(base_url: str, task_id: str, headers: dict) -> Optional[dict]:
    """Fetch the status record for ``task_id`` from ``GET {base_url}/result/{task_id}``.

    Returns the parsed JSON object, or None on transport error, non-200
    status, or a body that is not a JSON object.
    """
    r = _send(
        "GET", "%s/result/%s" % (base_url, task_id), "Result",
        headers=headers, timeout=120,
    )
    if r is None:
        return None
    if r.status_code != 200:
        log("Result failed: %s %s" % (r.status_code, r.text))
        return None
    return _json_body(r, "Result")


def download_image(base_url: str, result_image_id: str, out_path: str, headers: dict) -> bool:
    """Download ``GET {base_url}/result/image/{result_image_id}`` into ``out_path``.

    Writes the raw response bytes. Returns True on success, False if the
    request fails, the status is not 200, or the file cannot be written.
    """
    r = _send(
        "GET", "%s/result/image/%s" % (base_url, result_image_id), "Download",
        headers=headers, timeout=300,
    )
    if r is None:
        return False
    if r.status_code != 200:
        log("Download failed: %s %s" % (r.status_code, r.text))
        return False
    try:
        with open(out_path, "wb") as f:
            f.write(r.content)
    except OSError as exc:
        log("Cannot write %s: %s" % (out_path, exc))
        return False
    log("Saved: %s" % out_path)
    return True


def health(base_url: str, headers: dict) -> bool:
    """Check ``GET {base_url}/health`` and log its ``status`` / ``model_loaded`` fields.

    Returns True only for a 200 response whose body is a JSON object.
    """
    try:
        r = requests.get("%s/health" % base_url, headers=headers, timeout=30)
    except requests.ConnectionError:
        log("Cannot connect to %s" % base_url)
        return False
    except requests.RequestException as exc:
        # e.g. a read timeout, which is not a ConnectionError.
        log("Health request failed: %s" % exc)
        return False
    if r.status_code != 200:
        log("Health failed: %s %s" % (r.status_code, r.text))
        return False
    j = _json_body(r, "Health")
    if j is None:
        return False
    log("Health: %s (model_loaded=%s)" % (j.get("status"), j.get("model_loaded")))
    return True


def main():
    """Parse CLI arguments and run health -> upload -> edit -> result -> download."""
    parser = argparse.ArgumentParser(description="Test Nano Banana Image Edit API")
    parser.add_argument(
        "--base", dest="base_url", required=True,
        help="Base URL of the API, e.g. https://hf.space/... or http://127.0.0.1:7860",
    )
    parser.add_argument("--image", dest="image_path", required=True, help="Path to input image")
    parser.add_argument("--prompt", dest="prompt", default="enhance the image", help="Edit prompt")
    parser.add_argument(
        "--out", dest="out_path", default="edited.png", help="Output file for edited image"
    )
    parser.add_argument(
        "--token", dest="token", default="", help="Hugging Face access token for private Spaces"
    )
    args = parser.parse_args()

    base_url = args.base_url.rstrip("/")
    # isfile (not exists): a directory path would otherwise fail later in open().
    if not os.path.isfile(args.image_path):
        log("Image not found: %s" % args.image_path)
        sys.exit(1)

    headers = {}
    if args.token:
        headers["Authorization"] = "Bearer %s" % args.token

    if not health(base_url, headers):
        sys.exit(1)

    image_id = upload_image(base_url, args.image_path, headers)
    if not image_id:
        sys.exit(1)

    task_id = edit_image(base_url, image_id, args.prompt, headers)
    if not task_id:
        sys.exit(1)

    # If the backend ever becomes async, poll here. For now a single short
    # pause followed by one result lookup (presumably the edit is already
    # complete when /edit returns — verify against the backend).
    time.sleep(1.0)
    result = get_result(base_url, task_id, headers)
    if not result or result.get("status") != "completed":
        log("Result not completed: %s" % (result,))
        sys.exit(1)

    result_image_id = result.get("result_image_id")
    if not result_image_id:
        log("No result_image_id in response")
        sys.exit(1)

    ok = download_image(base_url, result_image_id, args.out_path, headers)
    if not ok:
        sys.exit(1)
    log("Done.")


if __name__ == "__main__":
    main()