import gradio as gr import numpy as np import torch import random import os import spaces from PIL import Image, ImageOps, ImageFilter from diffusers import FluxPipeline, DiffusionPipeline import requests from io import BytesIO # Constants MAX_SEED = np.iinfo(np.int32).max HF_TOKEN = os.getenv("HF_TOKEN") # Model configuration KONTEXT_MODEL = "black-forest-labs/FLUX.1-Kontext-dev" FALLBACK_MODEL = "black-forest-labs/FLUX.1-dev" LORA_MODEL = "thedeoxen/refcontrol-flux-kontext-reference-pose-lora" TRIGGER_WORD = "refcontrolpose" # Initialize pipeline print("Loading models...") def load_pipeline(): """Load the appropriate pipeline based on availability""" global pipe, MODEL_STATUS try: # First, try to import necessary libraries try: from diffusers import FluxKontextPipeline import peft print("PEFT library found") use_kontext = True except ImportError: print("FluxKontextPipeline or PEFT not available, using fallback") use_kontext = False if use_kontext and HF_TOKEN: # Try to load Kontext model pipe = FluxKontextPipeline.from_pretrained( KONTEXT_MODEL, torch_dtype=torch.bfloat16, token=HF_TOKEN ) # Try to load LoRA if PEFT is available try: pipe.load_lora_weights( LORA_MODEL, adapter_name="refcontrol", token=HF_TOKEN ) MODEL_STATUS = "✅ Flux Kontext + RefControl LoRA loaded" except Exception as e: print(f"Could not load LoRA: {e}") MODEL_STATUS = "⚠️ Flux Kontext loaded (without LoRA - PEFT required)" pipe = pipe.to("cuda") else: # Fallback to standard FLUX pipe = FluxPipeline.from_pretrained( FALLBACK_MODEL, torch_dtype=torch.bfloat16, token=HF_TOKEN if HF_TOKEN else True ) pipe = pipe.to("cuda") MODEL_STATUS = "⚠️ Using FLUX.1-dev (fallback mode)" except Exception as e: print(f"Error loading models: {e}") MODEL_STATUS = f"❌ Error: {str(e)}" pipe = None return pipe, MODEL_STATUS # Load the pipeline pipe, MODEL_STATUS = load_pipeline() print(MODEL_STATUS) def prepare_images_for_kontext(reference_image, pose_image, target_size=768): """ Prepare reference and pose images for Kontext processing. Following the RefControl format: reference (left) | pose (right) """ if reference_image is None or pose_image is None: return None # Convert to RGB reference_image = reference_image.convert("RGB") pose_image = pose_image.convert("RGB") # Calculate dimensions maintaining aspect ratio ref_ratio = reference_image.width / reference_image.height pose_ratio = pose_image.width / pose_image.height # Set heights to target size height = target_size ref_width = int(height * ref_ratio) pose_width = int(height * pose_ratio) # Ensure dimensions are divisible by 8 (FLUX requirement) ref_width = (ref_width // 8) * 8 pose_width = (pose_width // 8) * 8 height = (height // 8) * 8 # Resize images reference_resized = reference_image.resize((ref_width, height), Image.LANCZOS) pose_resized = pose_image.resize((pose_width, height), Image.LANCZOS) # Concatenate horizontally: reference | pose total_width = ref_width + pose_width concatenated = Image.new('RGB', (total_width, height)) concatenated.paste(reference_resized, (0, 0)) concatenated.paste(pose_resized, (ref_width, 0)) return concatenated def extract_pose_edges(image): """ Extract edge/pose information from an image. """ if image is None: return None # Convert to grayscale gray = image.convert("L") # Apply edge detection edges = gray.filter(ImageFilter.FIND_EDGES) # Enhance contrast edges = ImageOps.autocontrast(edges) # Invert to get black lines on white edges = ImageOps.invert(edges) # Smooth the result edges = edges.filter(ImageFilter.SMOOTH_MORE) # Convert back to RGB return edges.convert("RGB") @spaces.GPU(duration=60) def generate_pose_transfer( reference_image, pose_image, prompt="", negative_prompt="", seed=42, randomize_seed=False, guidance_scale=3.5, num_inference_steps=28, lora_scale=1.0, enhance_pose=False, progress=gr.Progress(track_tqdm=True) ): """ Main generation function using RefControl approach. """ if pipe is None: return None, 0, "Model not loaded. Please check HF_TOKEN and restart the Space" if reference_image is None or pose_image is None: raise gr.Error("Please upload both reference and pose images") # Randomize seed if requested if randomize_seed: seed = random.randint(0, MAX_SEED) # Enhance pose if requested if enhance_pose: pose_image = extract_pose_edges(pose_image) # Prepare concatenated input concatenated_input = prepare_images_for_kontext(reference_image, pose_image) if concatenated_input is None: raise gr.Error("Failed to process images") # Construct prompt with trigger word if prompt: full_prompt = f"{TRIGGER_WORD}, {prompt}" else: full_prompt = f"{TRIGGER_WORD}, transfer the pose from the right image to the subject in the left image while maintaining their identity, clothing, and style" # Add instruction for the model full_prompt += ". The left image shows the reference subject, the right image shows the target pose." # Set generator for reproducibility generator = torch.Generator("cuda").manual_seed(seed) try: # Check if we have LoRA capabilities has_lora = hasattr(pipe, 'set_adapters') and "RefControl" in MODEL_STATUS with torch.autocast("cuda"): if has_lora: # Try to set LoRA adapter strength try: pipe.set_adapters(["refcontrol"], adapter_weights=[lora_scale]) except Exception as e: print(f"Could not set LoRA adapter: {e}") # Generate image based on pipeline type if "Kontext" in MODEL_STATUS: # Use Kontext pipeline result = pipe( image=concatenated_input, prompt=full_prompt, negative_prompt=negative_prompt if negative_prompt else None, guidance_scale=guidance_scale, num_inference_steps=num_inference_steps, generator=generator, width=concatenated_input.width, height=concatenated_input.height, ).images[0] else: # Use standard FLUX pipeline (image-to-image) result = pipe( prompt=full_prompt, image=concatenated_input, guidance_scale=guidance_scale, num_inference_steps=num_inference_steps, generator=generator, strength=0.85, # For img2img mode ).images[0] return result, seed, concatenated_input except Exception as e: raise gr.Error(f"Generation failed: {str(e)}") # CSS styling css = """ #col-container { margin: 0 auto; max-width: 1280px; } .header { text-align: center; background: linear-gradient(135deg, #667eea 0%, #764ba2 100%); padding: 20px; border-radius: 12px; margin-bottom: 20px; } .header h1 { color: white; margin: 0; font-size: 2em; } .status-box { padding: 10px; border-radius: 8px; margin: 10px 0; font-weight: bold; text-align: center; } .input-image { border: 2px solid #e0e0e0; border-radius: 8px; overflow: hidden; } .result-image { border: 3px solid #4CAF50; border-radius: 8px; overflow: hidden; } .info-box { background: #f0f0f0; padding: 10px; border-radius: 8px; margin: 10px 0; } """ # Create Gradio interface with gr.Blocks(css=css, theme=gr.themes.Soft()) as demo: with gr.Column(elem_id="col-container"): # Header gr.HTML("""

🎭 FLUX Pose Transfer System

Transfer poses while preserving identity

""") # Model status status_color = "#d4edda" if "✅" in MODEL_STATUS else "#fff3cd" if "⚠️" in MODEL_STATUS else "#f8d7da" gr.HTML(f"""
{MODEL_STATUS}
""") # Authentication check if not HF_TOKEN: gr.Markdown(""" ### 🔐 Authentication Required To use this Space with full features: 1. Go to **Settings** → **Variables and secrets** 2. Add `HF_TOKEN` with your Hugging Face token 3. Restart the Space Or click below to sign in: """) gr.LoginButton("Sign in with Hugging Face", size="lg") # Info box for PEFT requirement if "PEFT required" in MODEL_STATUS: gr.HTML("""
Note: For full LoRA support, PEFT library is required. Add peft to your requirements.txt file.
""") # Main interface with gr.Row(): with gr.Column(scale=1): gr.Markdown("### 📥 Input Images") # Reference image reference_image = gr.Image( label="Reference Image (Subject to transform)", type="pil", elem_classes=["input-image"], height=300 ) # Pose image pose_image = gr.Image( label="Pose Control (Line art or skeleton)", type="pil", elem_classes=["input-image"], height=300 ) # Pose extraction tool with gr.Accordion("🔧 Extract Pose from Image", open=False): extract_source = gr.Image( label="Source image for pose extraction", type="pil", height=200 ) extract_btn = gr.Button("Extract Pose", size="sm") # Prompts prompt = gr.Textbox( label=f"Prompt ('{TRIGGER_WORD}' added automatically)", placeholder="e.g., wearing elegant dress, professional photography", lines=2 ) negative_prompt = gr.Textbox( label="Negative Prompt (optional)", placeholder="e.g., blurry, low quality, distorted", lines=1, value="blurry, low quality, distorted, deformed" ) # Generate button generate_btn = gr.Button( "🎨 Generate Pose Transfer", variant="primary", size="lg" ) # Advanced settings with gr.Accordion("⚙️ Advanced Settings", open=False): with gr.Row(): seed = gr.Slider( label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=42 ) randomize_seed = gr.Checkbox( label="Randomize", value=True ) guidance_scale = gr.Slider( label="Guidance Scale", minimum=1.0, maximum=10.0, step=0.5, value=3.5, info="How strictly to follow the pose" ) num_inference_steps = gr.Slider( label="Inference Steps", minimum=20, maximum=50, step=1, value=28 ) if "LoRA" in MODEL_STATUS: lora_scale = gr.Slider( label="LoRA Strength", minimum=0.0, maximum=2.0, step=0.1, value=1.0, info="RefControl LoRA influence" ) else: lora_scale = gr.Slider( label="LoRA Strength (not available)", minimum=0.0, maximum=2.0, step=0.1, value=1.0, interactive=False ) enhance_pose = gr.Checkbox( label="Auto-enhance pose edges", value=False ) with gr.Column(scale=1): gr.Markdown("### 🖼️ Result") # Result image result_image = gr.Image( label="Generated Image", elem_classes=["result-image"], interactive=False, height=500 ) # Seed display seed_used = gr.Number( label="Seed Used", interactive=False ) # Debug view with gr.Accordion("🔍 Debug View", open=False): concat_preview = gr.Image( label="Input Concatenation (Reference | Pose)", height=200 ) # Action buttons with gr.Row(): reuse_ref_btn = gr.Button("♻️ Use as Reference", size="sm") reuse_pose_btn = gr.Button("📐 Extract Pose", size="sm") clear_btn = gr.Button("🗑️ Clear All", size="sm") # Examples gr.Markdown("### 💡 Example Prompts") gr.Examples( examples=[ ["professional portrait, studio lighting"], ["wearing red dress, outdoor garden"], ["business attire, office setting"], ["casual streetwear, urban background"], ["athletic wear, gym environment"], ], inputs=[prompt] ) # Instructions with gr.Accordion("📖 Instructions", open=False): gr.Markdown(f""" ## How to Use: 1. **Upload Reference Image**: The person whose appearance you want to keep 2. **Upload Pose Image**: Line art or skeleton pose to follow 3. **Add Prompt** (optional): Describe additional details 4. **Click Generate**: Create your pose-transferred image ## Model Information: - **Current Model**: {MODEL_STATUS} - **Trigger Word**: `{TRIGGER_WORD}` (added automatically) ## Tips: - Use clear, high-contrast pose images - Black lines on white background work best for poses - Adjust guidance scale for pose adherence strength - Higher steps = better quality but slower ## Requirements: - **HF_TOKEN**: Required for model access - **peft**: Required for LoRA support (add to requirements.txt) """) # Event handlers generate_btn.click( fn=generate_pose_transfer, inputs=[ reference_image, pose_image, prompt, negative_prompt, seed, randomize_seed, guidance_scale, num_inference_steps, lora_scale, enhance_pose ], outputs=[result_image, seed_used, concat_preview] ) extract_btn.click( fn=extract_pose_edges, inputs=[extract_source], outputs=[pose_image] ) reuse_ref_btn.click( fn=lambda x: x, inputs=[result_image], outputs=[reference_image] ) reuse_pose_btn.click( fn=extract_pose_edges, inputs=[result_image], outputs=[pose_image] ) clear_btn.click( fn=lambda: [None, None, "", "blurry, low quality, distorted, deformed", 42, None, None], outputs=[ reference_image, pose_image, prompt, negative_prompt, seed_used, result_image, concat_preview ] ) # Launch the app if __name__ == "__main__": demo.queue() demo.launch()