import os

IS_SPACE = True
if IS_SPACE:
    import spaces
import sys
import warnings
import subprocess
from pathlib import Path
from typing import Any, Dict, Optional, Tuple

import torch


def space_context(duration: int):
    """Return the ZeroGPU decorator on Spaces, otherwise a no-op decorator."""
    if IS_SPACE:
        return spaces.GPU(duration=duration)
    return lambda x: x
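
# `space_context` is defined but not applied anywhere below. On a ZeroGPU
# Space it would presumably be used as a decorator, e.g. (hypothetical
# sketch; `_gpu_task` is not part of this app):
#
#   @space_context(duration=120)
#   def _gpu_task():
#       return torch.cuda.get_device_name(0)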

def test_env():
    """Check that CUDA is available and install flash-attn on first run."""
    assert torch.cuda.is_available()
    try:
        import flash_attn  # noqa: F401 -- import only probes availability
    except ImportError:
        print("Flash-attn not found, installing...")
        os.system("pip install flash-attn==2.7.3 --no-build-isolation")
    else:
        print("Flash-attn found, skipping installation...")


test_env()
warnings.filterwarnings("ignore")

# Add the current directory to the Python path
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
try:
    import gradio as gr
    from PIL import Image
    from hyimage.diffusion.pipelines.hunyuanimage_pipeline import HunyuanImagePipeline
    from huggingface_hub import snapshot_download
    import modelscope  # noqa: F401 -- presence check for the ModelScope CLI
except ImportError as e:
    print(f"Missing required dependencies: {e}")
    print("Please install with: pip install -r requirements_gradio.txt")
    print("For checkpoint downloads, also install: pip install -U 'huggingface_hub[cli]' modelscope")
    sys.exit(1)

class CheckpointDownloader:
    """Handles downloading of all required checkpoints for HunyuanImage."""

    def __init__(self, base_dir: str = "/data/ckpts"):
        self.base_dir = Path(base_dir)
        self.base_dir.mkdir(parents=True, exist_ok=True)
        # Define all required checkpoints
        self.checkpoints = {
            "main_model": {
                "repo_id": "tencent/HunyuanImage-2.1",
                "local_dir": self.base_dir,
            },
            "mllm_encoder": {
                "repo_id": "Qwen/Qwen2.5-VL-7B-Instruct",
                "local_dir": self.base_dir / "text_encoder" / "llm",
            },
            "byt5_encoder": {
                "repo_id": "google/byt5-small",
                "local_dir": self.base_dir / "text_encoder" / "byt5-small",
            },
            "glyph_encoder": {
                "repo_id": "AI-ModelScope/Glyph-SDXL-v2",
                "local_dir": self.base_dir / "text_encoder" / "Glyph-SDXL-v2",
                "use_modelscope": True,
            },
        }
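
    # Resulting on-disk layout (derived from the local_dir values above):
    #
    #   /data/ckpts/                            <- tencent/HunyuanImage-2.1
    #   /data/ckpts/text_encoder/llm/           <- Qwen/Qwen2.5-VL-7B-Instruct
    #   /data/ckpts/text_encoder/byt5-small/    <- google/byt5-small
    #   /data/ckpts/text_encoder/Glyph-SDXL-v2/ <- AI-ModelScope/Glyph-SDXL-v2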

    def download_checkpoint(self, checkpoint_name: str, progress_callback=None) -> Tuple[bool, str]:
        """Download a specific checkpoint."""
        if checkpoint_name not in self.checkpoints:
            return False, f"Unknown checkpoint: {checkpoint_name}"
        config = self.checkpoints[checkpoint_name]
        local_dir = config["local_dir"]
        local_dir.mkdir(parents=True, exist_ok=True)
        try:
            if config.get("use_modelscope", False):
                # Use ModelScope for models hosted there
                return self._download_with_modelscope(config, progress_callback)
            else:
                # Use huggingface_hub for everything else
                return self._download_with_hf(config, progress_callback)
        except Exception as e:
            return False, f"Download failed: {str(e)}"

    def _download_with_hf(self, config: Dict, progress_callback=None) -> Tuple[bool, str]:
        """Download using huggingface_hub."""
        repo_id = config["repo_id"]
        local_dir = config["local_dir"]
        if progress_callback:
            progress_callback(f"Downloading {repo_id}...")
        try:
            # Note: local_dir_use_symlinks and resume_download are deprecated
            # in recent huggingface_hub releases but are still accepted.
            snapshot_download(
                repo_id=repo_id,
                local_dir=str(local_dir),
                local_dir_use_symlinks=False,
                resume_download=True,
            )
            return True, f"Successfully downloaded {repo_id}"
        except Exception as e:
            return False, f"HF download failed: {str(e)}"

    def _download_with_modelscope(self, config: Dict, progress_callback=None) -> Tuple[bool, str]:
        """Download using the ModelScope CLI."""
        repo_id = config["repo_id"]
        local_dir = config["local_dir"]
        if progress_callback:
            progress_callback(f"Downloading {repo_id} via ModelScope...")
        print(f"Downloading {repo_id} via ModelScope...")
        try:
            # Call the modelscope CLI in a subprocess
            cmd = [
                "modelscope", "download",
                "--model", repo_id,
                "--local_dir", str(local_dir),
            ]
            subprocess.run(cmd, capture_output=True, text=True, check=True)
            return True, f"Successfully downloaded {repo_id} via ModelScope"
        except subprocess.CalledProcessError as e:
            return False, f"ModelScope download failed: {e.stderr}"
        except FileNotFoundError:
            return False, "ModelScope CLI not found. Install with: pip install modelscope"

    def download_all_checkpoints(self, progress_callback=None) -> Tuple[bool, str, Dict[str, Any]]:
        """Download all checkpoints, stopping at the first failure."""
        results = {}
        for name in self.checkpoints:
            if progress_callback:
                progress_callback(f"Starting download of {name}...")
            success, message = self.download_checkpoint(name, progress_callback)
            results[name] = {"success": success, "message": message}
            if not success:
                return False, f"Failed to download {name}: {message}", results
        return True, "All checkpoints downloaded successfully", results
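
# Usage sketch (assumes /data is writable, as on a Space with persistent storage):
#
#   downloader = CheckpointDownloader()
#   ok, msg, results = downloader.download_all_checkpoints(progress_callback=print)
#   if not ok:
#       print(msg)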

def load_pipeline(use_distilled: bool = False, device: str = "cuda"):
    """Load the HunyuanImage pipeline (load once; refiner and reprompt are accessed from it)."""
    try:
        assert not use_distilled  # use_distilled is a placeholder for the future
        print(f"Loading HunyuanImage pipeline (distilled={use_distilled})...")
        model_name = "hunyuanimage-v2.1-distilled" if use_distilled else "hunyuanimage-v2.1"
        pipeline = HunyuanImagePipeline.from_pretrained(
            model_name=model_name,
            device=device,
            enable_dit_offloading=True,
            enable_reprompt_model_offloading=True,
            enable_refiner_offloading=True,
        )
        print("✓ Pipeline loaded successfully")
        return pipeline
    except Exception as e:
        print(f"✗ Error loading pipeline: {str(e)}")
        raise
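
# The three enable_*_offloading flags are taken at face value here: judging by
# their names, they let the pipeline keep the DiT, reprompt model, and refiner
# off the GPU until each is actually needed, trading latency for GPU memory.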

if IS_SPACE:
    downloader = CheckpointDownloader()
    downloader.download_all_checkpoints()

device = "cuda" if torch.cuda.is_available() else "cpu"
pipeline = load_pipeline(use_distilled=False, device=device)

class HunyuanImageApp:
    def __init__(self, auto_load: bool = True, use_distilled: bool = False, device: str = "cuda"):
        """Initialize the HunyuanImage Gradio app.

        The auto_load/use_distilled/device arguments are currently unused:
        the app wraps the module-level pipeline loaded above.
        """
        global pipeline
        self.pipeline = pipeline
        # The download helpers below expect a downloader instance.
        self.downloader = CheckpointDownloader()
        self.current_use_distilled = None

    def print_peak_memory(self):
        stats = torch.cuda.memory_stats()
        peak_bytes_requirement = stats["allocated_bytes.all.peak"]
        print(f"Peak memory requirement: {peak_bytes_requirement / 1024 ** 3:.2f} GB")

    def generate_image(self,
                       prompt: str,
                       negative_prompt: str,
                       width: int,
                       height: int,
                       num_inference_steps: int,
                       guidance_scale: float,
                       seed: int,
                       use_reprompt: bool,
                       use_refiner: bool,
                       # use_distilled: bool
                       ) -> Tuple[Optional[Image.Image], str]:
        """Generate an image using the HunyuanImage pipeline."""
        try:
            if self.pipeline is None:
                return None, "Pipeline not loaded. Please try again."
            # Checking the private `_refiner_pipeline` attribute (rather than
            # the `refiner_pipeline` property) should avoid lazily loading the
            # refiner just to move it off the GPU.
            if hasattr(self.pipeline, '_refiner_pipeline'):
                self.pipeline.refiner_pipeline.to('cpu')
            self.pipeline.to('cuda')
            # Generate image
            image = self.pipeline(
                prompt=prompt,
                negative_prompt=negative_prompt,
                width=width,
                height=height,
                num_inference_steps=num_inference_steps,
                guidance_scale=guidance_scale,
                seed=seed,
                use_reprompt=use_reprompt,
                use_refiner=use_refiner,
            )
            self.print_peak_memory()
            return image, "Image generated successfully!"
        except Exception as e:
            error_msg = f"Error generating image: {str(e)}"
            print(f"✗ {error_msg}")
            return None, error_msg

    def enhance_prompt(self, prompt: str,  # use_distilled: bool
                       ) -> Tuple[str, str]:
        """Enhance a prompt using the reprompt model."""
        try:
            if self.pipeline is None:
                return prompt, "Pipeline not loaded. Please try again."
            # Free the GPU for the reprompt model
            self.pipeline.to('cpu')
            if hasattr(self.pipeline, '_refiner_pipeline'):
                self.pipeline.refiner_pipeline.to('cpu')
            # Use the reprompt model from the main pipeline
            enhanced_prompt = self.pipeline.reprompt_model.predict(prompt)
            self.print_peak_memory()
            return enhanced_prompt, "Prompt enhanced successfully!"
        except Exception as e:
            error_msg = f"Error enhancing prompt: {str(e)}"
            print(f"✗ {error_msg}")
            return prompt, error_msg
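
    # Standalone reprompt sketch, assuming the pipeline exposes `reprompt_model`
    # with a `.predict(str) -> str` method as used above:
    #
    #   better = pipeline.reprompt_model.predict("A cat sitting on a table")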

    def refine_image(self,
                     image: Image.Image,
                     prompt: str,
                     negative_prompt: str,
                     width: int,
                     height: int,
                     num_inference_steps: int,
                     guidance_scale: float,
                     seed: int) -> Tuple[Optional[Image.Image], str]:
        """Refine an image using the refiner pipeline."""
        try:
            if image is None:
                return None, "Please upload an image to refine."
            # Resize the image to the target dimensions if needed
            if image.size != (width, height):
                image = image.resize((width, height), Image.Resampling.LANCZOS)
            # Swap the main pipeline out and the refiner in
            self.pipeline.to('cpu')
            self.pipeline.refiner_pipeline.to('cuda')
            # Use the refiner from the main pipeline
            refined_image = self.pipeline.refiner_pipeline(
                image=image,
                prompt=prompt,
                negative_prompt=negative_prompt,
                width=width,
                height=height,
                num_inference_steps=num_inference_steps,
                guidance_scale=guidance_scale,
                seed=seed,
            )
            self.print_peak_memory()
            return refined_image, "Image refined successfully!"
        except Exception as e:
            error_msg = f"Error refining image: {str(e)}"
            print(f"✗ {error_msg}")
            return None, error_msg

    def download_single_checkpoint(self, checkpoint_name: str) -> Tuple[bool, str]:
        """Download a single checkpoint."""
        try:
            success, message = self.downloader.download_checkpoint(checkpoint_name)
            return success, message
        except Exception as e:
            return False, f"Download error: {str(e)}"

    def download_all_checkpoints(self) -> Tuple[bool, str, Dict[str, Any]]:
        """Download all missing checkpoints."""
        try:
            success, message, results = self.downloader.download_all_checkpoints()
            return success, message, results
        except Exception as e:
            return False, f"Download error: {str(e)}", {}

def create_interface(auto_load: bool = True, use_distilled: bool = False, device: str = "cuda"):
    """Create the Gradio interface."""
    app = HunyuanImageApp(auto_load=auto_load, use_distilled=use_distilled, device=device)
    # Custom CSS for better styling
    css = """
    .gradio-container {
        max-width: 1200px !important;
        margin: auto !important;
    }
    .tab-nav {
        background: linear-gradient(90deg, #667eea 0%, #764ba2 100%);
        border-radius: 10px;
        padding: 10px;
        margin-bottom: 20px;
    }
    .model-info {
        background: #f8f9fa;
        border: 1px solid #dee2e6;
        border-radius: 8px;
        padding: 15px;
        margin-bottom: 20px;
    }
    """
    with gr.Blocks(css=css, title="HunyuanImage Pipeline", theme=gr.themes.Soft()) as demo:
        gr.Markdown(
            """
            # 🎨 HunyuanImage 2.1 Pipeline

            **HunyuanImage-2.1: An Efficient Diffusion Model for High-Resolution (2K) Text-to-Image Generation**

            This app provides three main functionalities:
            1. **Text-to-Image Generation**: Generate high-quality images from text prompts
            2. **Prompt Enhancement**: Improve your prompts using MLLM reprompting
            3. **Image Refinement**: Enhance existing images with the refiner model (not supported yet; coming soon)
            """,
            elem_classes="model-info"
        )
        with gr.Tabs():
            # Tab 1: Text-to-Image Generation
            with gr.Tab("🖼️ Text-to-Image Generation"):
                with gr.Row():
                    with gr.Column(scale=1):
                        gr.Markdown("### Generation Settings")
                        gr.Markdown("**Model**: HunyuanImage v2.1 (Non-distilled)")
                        # use_distilled = gr.Checkbox(
                        #     label="Use Distilled Model",
                        #     value=False,
                        #     info="Faster generation with slightly lower quality"
                        # )
                        use_distilled = False
                        prompt = gr.Textbox(
                            label="Prompt",
                            placeholder="",
                            lines=3,
                            value="A cute, cartoon-style anthropomorphic penguin plush toy with fluffy fur, standing in a painting studio, wearing a red knitted scarf and a red beret with the word “Tencent” on it, holding a paintbrush with a focused expression as it paints an oil painting of the Mona Lisa, rendered in a photorealistic photographic style."
                        )
                        negative_prompt = gr.Textbox(
                            label="Negative Prompt",
                            placeholder="",
                            lines=2,
                            value=""
                        )
                        with gr.Row():
                            width = gr.Slider(
                                minimum=512, maximum=2048, step=64, value=2048,
                                label="Width", info="Image width in pixels"
                            )
                            height = gr.Slider(
                                minimum=512, maximum=2048, step=64, value=2048,
                                label="Height", info="Image height in pixels"
                            )
                        with gr.Row():
                            num_inference_steps = gr.Slider(
                                minimum=10, maximum=100, step=5, value=50,
                                label="Inference Steps", info="More steps = better quality, slower generation"
                            )
                            guidance_scale = gr.Slider(
                                minimum=1.0, maximum=10.0, step=0.1, value=3.5,
                                label="Guidance Scale", info="How closely to follow the prompt"
                            )
                        with gr.Row():
                            seed = gr.Number(
                                label="Seed", value=649151, precision=0,
                                info="Random seed for reproducibility"
                            )
                            use_reprompt = gr.Checkbox(
                                label="Use Reprompt", value=False,
                                info="Enhance prompt automatically"
                            )
                            use_refiner = gr.Checkbox(
                                label="Use Refiner", value=False,
                                info="Apply refiner after generation (not supported yet; coming soon)",
                                interactive=False
                            )
                        generate_btn = gr.Button("🎨 Generate Image", variant="primary", size="lg")
                    with gr.Column(scale=1):
                        gr.Markdown("### Generated Image")
                        generated_image = gr.Image(
                            label="Generated Image",
                            type="pil",
                            height=600
                        )
                        generation_status = gr.Textbox(
                            label="Status",
                            interactive=False,
                            value="Ready to generate"
                        )
            # Tab 2: Prompt Enhancement
            with gr.Tab("✨ Prompt Enhancement"):
                with gr.Row():
                    with gr.Column(scale=1):
                        gr.Markdown("### Prompt Enhancement Settings")
                        gr.Markdown("**Model**: HunyuanImage v2.1 Reprompt Model")
                        # enhance_use_distilled = gr.Checkbox(
                        #     label="Use Distilled Model",
                        #     value=False,
                        #     info="For loading the reprompt model"
                        # )
                        enhance_use_distilled = False
                        original_prompt = gr.Textbox(
                            label="Original Prompt",
                            placeholder="A cat sitting on a table",
                            lines=4,
                            value="A cat sitting on a table"
                        )
                        enhance_btn = gr.Button("✨ Enhance Prompt", variant="primary", size="lg")
                    with gr.Column(scale=1):
                        gr.Markdown("### Enhanced Prompt")
                        enhanced_prompt = gr.Textbox(
                            label="Enhanced Prompt",
                            lines=6,
                            interactive=False
                        )
                        enhancement_status = gr.Textbox(
                            label="Status",
                            interactive=False,
                            value="Ready to enhance"
                        )
            # Tab 3: Image Refinement (disabled until the refiner is supported)
            # with gr.Tab("🔧 Image Refinement"):
            #     with gr.Row():
            #         with gr.Column(scale=1):
            #             gr.Markdown("### Refinement Settings")
            #             gr.Markdown("**Model**: HunyuanImage v2.1 Refiner")
            #             input_image = gr.Image(
            #                 label="Input Image",
            #                 type="pil",
            #                 height=300
            #             )
            #             refine_prompt = gr.Textbox(
            #                 label="Refinement Prompt",
            #                 placeholder="Make the image more detailed and high quality",
            #                 lines=2,
            #                 value="Make the image more detailed and high quality"
            #             )
            #             refine_negative_prompt = gr.Textbox(
            #                 label="Negative Prompt",
            #                 placeholder="",
            #                 lines=2,
            #                 value=""
            #             )
            #             with gr.Row():
            #                 refine_width = gr.Slider(
            #                     minimum=512, maximum=2048, step=64, value=2048,
            #                     label="Width", info="Output width"
            #                 )
            #                 refine_height = gr.Slider(
            #                     minimum=512, maximum=2048, step=64, value=2048,
            #                     label="Height", info="Output height"
            #                 )
            #             with gr.Row():
            #                 refine_steps = gr.Slider(
            #                     minimum=1, maximum=20, step=1, value=4,
            #                     label="Refinement Steps", info="More steps = more refinement"
            #                 )
            #                 refine_guidance = gr.Slider(
            #                     minimum=1.0, maximum=10.0, step=0.1, value=3.5,
            #                     label="Guidance Scale", info="How strongly to follow the prompt"
            #                 )
            #             refine_seed = gr.Number(
            #                 label="Seed", value=649151, precision=0,
            #                 info="Random seed for reproducibility"
            #             )
            #             refine_btn = gr.Button("🔧 Refine Image", variant="primary", size="lg")
            #         with gr.Column(scale=1):
            #             gr.Markdown("### Refined Image")
            #             refined_image = gr.Image(
            #                 label="Refined Image",
            #                 type="pil",
            #                 height=600
            #             )
            #             refinement_status = gr.Textbox(
            #                 label="Status",
            #                 interactive=False,
            #                 value="Ready to refine"
            #             )
        # Event handlers
        generate_btn.click(
            fn=app.generate_image,
            inputs=[
                prompt, negative_prompt, width, height, num_inference_steps,
                guidance_scale, seed, use_reprompt, use_refiner  # , use_distilled
            ],
            outputs=[generated_image, generation_status]
        )
        enhance_btn.click(
            fn=app.enhance_prompt,
            inputs=[original_prompt],
            outputs=[enhanced_prompt, enhancement_status]
        )
        # refine_btn.click(
        #     fn=app.refine_image,
        #     inputs=[
        #         input_image, refine_prompt, refine_negative_prompt,
        #         refine_width, refine_height, refine_steps, refine_guidance, refine_seed
        #     ],
        #     outputs=[refined_image, refinement_status]
        # )
        # Additional info
        gr.Markdown(
            """
            ### 📝 Usage Tips

            **Text-to-Image Generation:**
            - Use descriptive prompts with specific details
            - Adjust guidance scale: higher values follow the prompt more closely
            - More inference steps generally produce better quality
            - Enable reprompt for automatic prompt enhancement
            - Enable refiner for additional quality improvement (coming soon)

            **Prompt Enhancement:**
            - Enter your basic prompt idea
            - The AI will enhance it with better structure and details
            - Enhanced prompts often produce better results

            **Image Refinement (coming soon):**
            - Upload any image you want to improve
            - Describe the desired improvements in the refinement prompt
            - The refiner will enhance details and quality
            - Works best with images generated by HunyuanImage
            """,
            elem_classes="model-info"
        )
    return demo

if __name__ == "__main__":
    import argparse

    # Parse command line arguments
    parser = argparse.ArgumentParser(description="Launch HunyuanImage Gradio App")
    parser.add_argument("--no-auto-load", action="store_true", help="Disable auto-loading pipeline on startup")
    parser.add_argument("--use-distilled", action="store_true", help="Use distilled model")
    parser.add_argument("--device", type=str, default="cuda", help="Device to use (cuda/cpu)")
    parser.add_argument("--port", type=int, default=8081, help="Port to run the app on")
    parser.add_argument("--host", type=str, default="0.0.0.0", help="Host to bind to")
    args = parser.parse_args()

    # Create and launch the interface
    auto_load = not args.no_auto_load
    demo = create_interface(auto_load=auto_load, use_distilled=args.use_distilled, device=args.device)

    print("🚀 Starting HunyuanImage Gradio App...")
    print(f"📱 The app will be available at: http://{args.host}:{args.port}")
    print(f"🔧 Auto-load pipeline: {'Yes' if auto_load else 'No'}")
    print(f"🎯 Model type: {'Distilled' if args.use_distilled else 'Non-distilled'}")
    print(f"💻 Device: {args.device}")
    print("⚠️ Make sure you have the required model checkpoints downloaded!")
    demo.launch(
        server_name=args.host,
        # server_port=args.port,  # left unset so Gradio (or Spaces) picks the
        # port; the URL printed above may therefore not match the actual port
        share=False,
        show_error=True,
        quiet=False,
        # max_threads=1,  # Default: sequential processing (recommended for GPU apps)
        # max_threads=4,  # Enable parallel processing (requires more GPU memory)
    )
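
# Local launch sketch (assumes a CUDA GPU and checkpoints under /data/ckpts;
# the script name is whatever this file is saved as):
#
#   python app.py --device cuda --port 8081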