import os
import re
import time
import json
import copy
import random
import requests
import torch
import cv2
import numpy as np
import gradio as gr
import spaces
from PIL import Image
from urllib.parse import quote

# Disable Torch JIT compilation for compatibility
torch.jit.script = lambda f: f

# Model & Utilities
import timm
import diffusers
from diffusers.utils import load_image
from diffusers.models import ControlNetModel
from diffusers import AutoencoderKL, DPMSolverMultistepScheduler, UNet2DConditionModel
from safetensors.torch import load_file
from huggingface_hub import hf_hub_download, HfFileSystem, ModelCard, snapshot_download
from insightface.app import FaceAnalysis
from controlnet_aux import ZoeDetector
from compel import Compel, ReturnedEmbeddingsType
from gradio_imageslider import ImageSlider

# Custom imports
try:
    from pipeline_stable_diffusion_xl_instantid_img2img import StableDiffusionXLInstantIDImg2ImgPipeline, draw_kps
    from cog_sdxl_dataset_and_utils import TokenEmbeddingsHandler
except ImportError as e:
    print(f"Import Error: {e}. Check if modules exist or paths are correct.")
    exit()

# Device setup
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load LoRA configuration
with open("sdxl_loras.json", "r") as file:
    sdxl_loras_raw = json.load(file)

with open("defaults_data.json", "r") as file:
    lora_defaults = json.load(file)

# Download required models
CHECKPOINT_DIR = "/data/checkpoints"
hf_hub_download(repo_id="InstantX/InstantID", filename="ControlNetModel/config.json", local_dir=CHECKPOINT_DIR)
hf_hub_download(repo_id="InstantX/InstantID", filename="ControlNetModel/diffusion_pytorch_model.safetensors", local_dir=CHECKPOINT_DIR)
hf_hub_download(repo_id="InstantX/InstantID", filename="ip-adapter.bin", local_dir=CHECKPOINT_DIR)
hf_hub_download(repo_id="latent-consistency/lcm-lora-sdxl", filename="pytorch_lora_weights.safetensors", local_dir=CHECKPOINT_DIR)

# Download Antelopev2 face recognition model
antelope_download = snapshot_download(repo_id="DIAMONIK7777/antelopev2", local_dir="/data/models/antelopev2")
print("Antelopev2 Download Path:", antelope_download)

# Initialize FaceAnalysis
app = FaceAnalysis(name="antelopev2", root="/data", providers=["CPUExecutionProvider"])
app.prepare(ctx_id=0, det_size=(640, 640))

# Load identity & depth ControlNets and the fp16-fixed VAE
face_adapter = os.path.join(CHECKPOINT_DIR, "ip-adapter.bin")
controlnet_path = os.path.join(CHECKPOINT_DIR, "ControlNetModel")

identitynet = ControlNetModel.from_pretrained(controlnet_path, torch_dtype=torch.float16)
zoedepthnet = ControlNetModel.from_pretrained("diffusers/controlnet-zoe-depth-sdxl-1.0", torch_dtype=torch.float16)
vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16)

# Load main pipeline
pipe = StableDiffusionXLInstantIDImg2ImgPipeline.from_pretrained(
    "frankjoshua/albedobaseXL_v21",
    vae=vae,
    controlnet=[identitynet, zoedepthnet],
    torch_dtype=torch.float16,
)
pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config, use_karras_sigmas=True)
pipe.load_ip_adapter_instantid(face_adapter)
pipe.set_ip_adapter_scale(0.8)

# Initialize Compel for text conditioning
compel = Compel(
    tokenizer=[pipe.tokenizer, pipe.tokenizer_2],
    text_encoder=[pipe.text_encoder, pipe.text_encoder_2],
    returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED,
    requires_pooled=[False, True],
)

# Load ZoeDetector for depth estimation
zoe = ZoeDetector.from_pretrained("lllyasviel/Annotators")
zoe.to(device)
pipe.to(device)

# LoRA management state
last_lora = ""
= "" last_fused = False # --- Utility Functions --- def update_selection(selected_state, sdxl_loras, face_strength, image_strength, weight, depth_control_scale, negative): index = selected_state.index lora_repo = sdxl_loras[index]["repo"] updated_text = f"### Selected: [{lora_repo}](https://huggingface.co/{lora_repo})" for lora_list in lora_defaults: if lora_list["model"] == lora_repo: face_strength = lora_list.get("face_strength", 0.85) image_strength = lora_list.get("image_strength", 0.15) weight = lora_list.get("weight", 0.9) depth_control_scale = lora_list.get("depth_control_scale", 0.8) negative = lora_list.get("negative", "") return ( updated_text, gr.update(placeholder="Type a prompt"), face_strength, image_strength, weight, depth_control_scale, negative, selected_state ) def center_crop_image(img): square_size = min(img.size) left = (img.width - square_size) // 2 top = (img.height - square_size) // 2 return img.crop((left, top, left + square_size, top + square_size)) def process_face(image): face_info = app.get(cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)) face_info = sorted(face_info, key=lambda x: (x['bbox'][2]-x['bbox'][0]) * (x['bbox'][3]-x['bbox'][1]))[-1] face_emb = face_info['embedding'] face_kps = draw_kps(image, face_info['kps']) return face_emb, face_kps def generate_image(prompt, negative, face_emb, face_image, face_kps, image_strength, guidance_scale, face_strength, depth_control_scale, repo_name, lora_scale): global last_fused, last_lora if last_lora != repo_name and last_fused: pipe.unfuse_lora() pipe.unload_lora_weights() pipe.load_lora_weights(repo_name) pipe.fuse_lora(lora_scale) last_lora, last_fused = repo_name, True conditioning, pooled = compel(prompt) negative_conditioning, negative_pooled = compel(negative) if negative else (None, None) images = [face_kps, zoe(face_image).resize(face_kps.size)] return pipe( prompt_embeds=conditioning, pooled_prompt_embeds=pooled, negative_prompt_embeds=negative_conditioning, negative_pooled_prompt_embeds=negative_pooled, width=1024, height=1024, image_embeds=face_emb, image=face_image, strength=1-image_strength, control_image=images, num_inference_steps=20, guidance_scale=guidance_scale, controlnet_conditioning_scale=[face_strength, depth_control_scale] ).images[0] # --- UI Setup --- with gr.Blocks() as demo: photo = gr.Image(label="Upload a picture", interactive=True, type="pil", height=300) gallery = gr.Gallery(label="Pick a style", allow_preview=False, columns=4, height=550) prompt = gr.Textbox(label="Prompt", placeholder="Enter prompt...") button = gr.Button("Run") result = ImageSlider(interactive=False, label="Generated Image") button.click(fn=generate_image, inputs=[prompt, gr.State(), gr.State()], outputs=result) demo.queue() demo.launch(share=True)