import gradio as gr
import numpy as np
import torch
import random
import os
import spaces
from PIL import Image, ImageOps, ImageFilter
# FluxImg2ImgPipeline backs the fallback path, which runs image-to-image;
# the plain text-to-image FluxPipeline does not accept an input image.
from diffusers import FluxImg2ImgPipeline

# Constants
MAX_SEED = np.iinfo(np.int32).max
HF_TOKEN = os.getenv("HF_TOKEN")

# Model configuration
KONTEXT_MODEL = "black-forest-labs/FLUX.1-Kontext-dev"
FALLBACK_MODEL = "black-forest-labs/FLUX.1-dev"
LORA_MODEL = "thedeoxen/refcontrol-flux-kontext-reference-pose-lora"
TRIGGER_WORD = "refcontrolpose"

# Initialize pipeline
print("Loading models...")


def load_pipeline():
    """Load the appropriate pipeline based on availability."""
    global pipe, MODEL_STATUS
    try:
        # First, check whether the Kontext pipeline and PEFT are available
        try:
            from diffusers import FluxKontextPipeline
            import peft  # noqa: F401 -- required for LoRA loading

            print("PEFT library found")
            use_kontext = True
        except ImportError:
            print("FluxKontextPipeline or PEFT not available, using fallback")
            use_kontext = False

        if use_kontext and HF_TOKEN:
            # Try to load the Kontext model
            pipe = FluxKontextPipeline.from_pretrained(
                KONTEXT_MODEL,
                torch_dtype=torch.bfloat16,
                token=HF_TOKEN,
            )
            # Try to load the LoRA weights (requires PEFT)
            try:
                pipe.load_lora_weights(
                    LORA_MODEL,
                    adapter_name="refcontrol",
                    token=HF_TOKEN,
                )
                MODEL_STATUS = "✅ Flux Kontext + RefControl LoRA loaded"
            except Exception as e:
                print(f"Could not load LoRA: {e}")
                MODEL_STATUS = "⚠️ Flux Kontext loaded (without LoRA - PEFT required)"
            pipe = pipe.to("cuda")
        else:
            # Fall back to standard FLUX in image-to-image mode
            pipe = FluxImg2ImgPipeline.from_pretrained(
                FALLBACK_MODEL,
                torch_dtype=torch.bfloat16,
                token=HF_TOKEN if HF_TOKEN else True,
            )
            pipe = pipe.to("cuda")
            MODEL_STATUS = "⚠️ Using FLUX.1-dev (fallback mode)"

    except Exception as e:
        print(f"Error loading models: {e}")
        MODEL_STATUS = f"❌ Error: {str(e)}"
        pipe = None

    return pipe, MODEL_STATUS


# Load the pipeline
pipe, MODEL_STATUS = load_pipeline()
print(MODEL_STATUS)


def prepare_images_for_kontext(reference_image, pose_image, target_size=768):
    """
    Prepare reference and pose images for Kontext processing.
    Following the RefControl format: reference (left) | pose (right).
    """
    if reference_image is None or pose_image is None:
        return None

    # Convert to RGB
    reference_image = reference_image.convert("RGB")
    pose_image = pose_image.convert("RGB")

    # Calculate dimensions maintaining aspect ratio
    ref_ratio = reference_image.width / reference_image.height
    pose_ratio = pose_image.width / pose_image.height

    # Set heights to target size
    height = target_size
    ref_width = int(height * ref_ratio)
    pose_width = int(height * pose_ratio)

    # Ensure dimensions are divisible by 8 (FLUX requirement)
    ref_width = (ref_width // 8) * 8
    pose_width = (pose_width // 8) * 8
    height = (height // 8) * 8

    # Resize images
    reference_resized = reference_image.resize((ref_width, height), Image.LANCZOS)
    pose_resized = pose_image.resize((pose_width, height), Image.LANCZOS)

    # Concatenate horizontally: reference | pose
    total_width = ref_width + pose_width
    concatenated = Image.new("RGB", (total_width, height))
    concatenated.paste(reference_resized, (0, 0))
    concatenated.paste(pose_resized, (ref_width, 0))

    return concatenated


def extract_pose_edges(image):
    """
    Extract edge/pose information from an image.
    """
    if image is None:
        return None

    # Convert to grayscale
    gray = image.convert("L")

    # Apply edge detection
    edges = gray.filter(ImageFilter.FIND_EDGES)

    # Enhance contrast
    edges = ImageOps.autocontrast(edges)

    # Invert to get black lines on white
    edges = ImageOps.invert(edges)

    # Smooth the result
    edges = edges.filter(ImageFilter.SMOOTH_MORE)

    # Convert back to RGB
    return edges.convert("RGB")


@spaces.GPU(duration=60)
def generate_pose_transfer(
    reference_image,
    pose_image,
    prompt="",
    negative_prompt="",
    seed=42,
    randomize_seed=False,
    guidance_scale=3.5,
    num_inference_steps=28,
    lora_scale=1.0,
    enhance_pose=False,
    progress=gr.Progress(track_tqdm=True),
):
    """
    Main generation function using the RefControl approach.
    """
    if pipe is None:
        raise gr.Error("Model not loaded. Please check HF_TOKEN and restart the Space.")

    if reference_image is None or pose_image is None:
        raise gr.Error("Please upload both reference and pose images")

    # Randomize seed if requested
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)

    # Enhance pose if requested
    if enhance_pose:
        pose_image = extract_pose_edges(pose_image)

    # Prepare concatenated input
    concatenated_input = prepare_images_for_kontext(reference_image, pose_image)
    if concatenated_input is None:
        raise gr.Error("Failed to process images")

    # Construct prompt with trigger word
    if prompt:
        full_prompt = f"{TRIGGER_WORD}, {prompt}"
    else:
        full_prompt = (
            f"{TRIGGER_WORD}, transfer the pose from the right image to the subject "
            "in the left image while maintaining their identity, clothing, and style"
        )

    # Add instruction for the model
    full_prompt += (
        ". The left image shows the reference subject, "
        "the right image shows the target pose."
    )

    # Set generator for reproducibility
    generator = torch.Generator("cuda").manual_seed(seed)

    try:
        # Check if we have LoRA capabilities
        has_lora = hasattr(pipe, "set_adapters") and "RefControl" in MODEL_STATUS

        with torch.autocast("cuda"):
            if has_lora:
                # Try to set LoRA adapter strength
                try:
                    pipe.set_adapters(["refcontrol"], adapter_weights=[lora_scale])
                except Exception as e:
                    print(f"Could not set LoRA adapter: {e}")

            # Generate image based on pipeline type
            if "Kontext" in MODEL_STATUS:
                # Use the Kontext pipeline
                result = pipe(
                    image=concatenated_input,
                    prompt=full_prompt,
                    negative_prompt=negative_prompt if negative_prompt else None,
                    guidance_scale=guidance_scale,
                    num_inference_steps=num_inference_steps,
                    generator=generator,
                    width=concatenated_input.width,
                    height=concatenated_input.height,
                ).images[0]
            else:
                # Use the standard FLUX pipeline in image-to-image mode
                result = pipe(
                    prompt=full_prompt,
                    image=concatenated_input,
                    guidance_scale=guidance_scale,
                    num_inference_steps=num_inference_steps,
                    generator=generator,
                    strength=0.85,  # img2img denoising strength
                ).images[0]

        return result, seed, concatenated_input

    except Exception as e:
        raise gr.Error(f"Generation failed: {str(e)}")


# CSS styling
css = """
#col-container {
    margin: 0 auto;
    max-width: 1280px;
}
.header {
    text-align: center;
    background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
    padding: 20px;
    border-radius: 12px;
    margin-bottom: 20px;
}
.header h1 {
    color: white;
    margin: 0;
    font-size: 2em;
}
.status-box {
    padding: 10px;
    border-radius: 8px;
    margin: 10px 0;
    font-weight: bold;
    text-align: center;
}
.input-image {
    border: 2px solid #e0e0e0;
    border-radius: 8px;
    overflow: hidden;
}
.result-image {
    border: 3px solid #4CAF50;
    border-radius: 8px;
    overflow: hidden;
}
.info-box {
    background: #f0f0f0;
    padding: 10px;
    border-radius: 8px;
    margin: 10px 0;
}
"""

# Create Gradio interface
with gr.Blocks(css=css, theme=gr.themes.Soft()) as demo:
    with gr.Column(elem_id="col-container"):
        # Header
        gr.HTML(
            """
            <div class="header">
                <h1>RefControl Pose Transfer</h1>
                <p>Transfer poses while preserving identity</p>
            </div>
            """
        )
        # Model status and setup hint
        gr.HTML(f'<div class="status-box">{MODEL_STATUS}</div>')
        if "LoRA loaded" not in MODEL_STATUS:
            gr.HTML(
                '<div class="info-box">ℹ️ To enable the RefControl LoRA, add '
                '<code>peft</code> to your requirements.txt file.</div>'
            )
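        # The original layout for the controls was truncated in the source, so
        # the widgets below are a minimal sketch: parameter names and defaults
        # mirror the generate_pose_transfer() signature above, while the exact
        # widget choices and value ranges are assumptions.
        with gr.Row():
            with gr.Column():
                reference_image = gr.Image(
                    label="Reference Image", type="pil", elem_classes="input-image"
                )
                pose_image = gr.Image(
                    label="Pose Image", type="pil", elem_classes="input-image"
                )
                prompt = gr.Textbox(
                    label="Prompt (optional)",
                    placeholder="Describe the desired output...",
                )
                negative_prompt = gr.Textbox(label="Negative Prompt (optional)")
                with gr.Accordion("Advanced Settings", open=False):
                    seed = gr.Slider(0, MAX_SEED, value=42, step=1, label="Seed")
                    randomize_seed = gr.Checkbox(value=False, label="Randomize seed")
                    guidance_scale = gr.Slider(
                        1.0, 10.0, value=3.5, step=0.1, label="Guidance scale"
                    )
                    num_inference_steps = gr.Slider(
                        1, 50, value=28, step=1, label="Inference steps"
                    )
                    lora_scale = gr.Slider(
                        0.0, 2.0, value=1.0, step=0.05, label="LoRA scale"
                    )
                    enhance_pose = gr.Checkbox(
                        value=False, label="Extract edges from pose image"
                    )
                generate_btn = gr.Button("Generate", variant="primary")
            with gr.Column():
                result_image = gr.Image(label="Result", elem_classes="result-image")
                used_seed = gr.Number(label="Seed used", interactive=False)
                concat_preview = gr.Image(
                    label="Concatenated input (reference | pose)"
                )

        # Wire the button to the generation function; outputs match its
        # (result, seed, concatenated_input) return values.
        generate_btn.click(
            fn=generate_pose_transfer,
            inputs=[
                reference_image,
                pose_image,
                prompt,
                negative_prompt,
                seed,
                randomize_seed,
                guidance_scale,
                num_inference_steps,
                lora_scale,
                enhance_pose,
            ],
            outputs=[result_image, used_seed, concat_preview],
        )

demo.launch()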