Update app.py
app.py
CHANGED
@@ -118,30 +118,32 @@ def prepare_images_for_kontext(reference_image, pose_image, target_size=512):
 
     return concatenated
 
-def
+def process_pose_for_control(pose_image):
     """
+    Process pose image to ensure maximum contrast and clarity for control
     """
-    if
+    if pose_image is None:
         return None
 
-    # Convert to grayscale
-    gray =
+    # Convert to grayscale first
+    gray = pose_image.convert("L")
 
-    # Apply edge detection
+    # Apply strong edge detection
     edges = gray.filter(ImageFilter.FIND_EDGES)
+    edges = edges.filter(ImageFilter.EDGE_ENHANCE_MORE)
 
-    #
-    edges = ImageOps.autocontrast(edges)
+    # Maximize contrast
+    edges = ImageOps.autocontrast(edges, cutoff=2)
 
-    #
+    # Convert to pure black and white
+    threshold = 128
+    edges = edges.point(lambda x: 255 if x > threshold else 0, mode='1')
 
-    #
-    edges = edges.
+    # Convert back to RGB with inverted colors (black lines on white)
+    edges = edges.convert("RGB")
+    edges = ImageOps.invert(edges)
 
-    return edges.convert("RGB")
+    return edges
 
 @spaces.GPU(duration=60)
 def generate_pose_transfer(
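For quick iteration, the same cleanup chain can be run standalone with Pillow only; a minimal sketch, with "pose.png" as a placeholder path:

from PIL import Image, ImageFilter, ImageOps

pose = Image.open("pose.png")                       # placeholder input path
gray = pose.convert("L")                            # drop color
edges = gray.filter(ImageFilter.FIND_EDGES)         # white lines on black
edges = edges.filter(ImageFilter.EDGE_ENHANCE_MORE)
edges = ImageOps.autocontrast(edges, cutoff=2)      # clip darkest/brightest 2%
edges = edges.point(lambda x: 255 if x > 128 else 0, mode='1')  # hard binarize
edges = ImageOps.invert(edges.convert("RGB"))       # black lines on white
edges.save("pose_control.png")

Binarizing with point(..., mode='1') and inverting afterwards yields hard black lines on a white background, which is the polarity the comments in the function aim for, rather than soft grayscale edges.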
@@ -151,9 +153,10 @@ def generate_pose_transfer(
     negative_prompt="",
     seed=42,
     randomize_seed=False,
-    guidance_scale=
+    guidance_scale=7.5,  # Increased for better pose adherence
     num_inference_steps=28,
     lora_scale=1.0,
+    controlnet_scale=1.0,  # Added control strength
     enhance_pose=False,
     progress=gr.Progress(track_tqdm=True)
 ):
@@ -205,14 +208,17 @@ def generate_pose_transfer(
     if (width, height) != concatenated_input.size:
         concatenated_input = concatenated_input.resize((width, height), Image.LANCZOS)
 
-    # Construct prompt with trigger word
+    # Construct prompt with trigger word - CRITICAL FOR POSE CONTROL
+    # The prompt must explicitly describe the pose transfer task
+    base_instruction = f"{TRIGGER_WORD}, A photo composed of two images side by side. Left: reference person. Right: target pose skeleton. Task: Generate the person from the left image in the exact pose shown in the right image"
+
     if prompt:
-        full_prompt = f"{
+        full_prompt = f"{base_instruction}. Additional details: {prompt}"
     else:
-        full_prompt =
+        full_prompt = base_instruction
 
-    # Add
-    full_prompt += ".
+    # Add strong pose control instructions
+    full_prompt += ". IMPORTANT: Strictly follow the pose/skeleton from the right image while preserving the identity, clothing, and appearance from the left image. The output should show ONLY the transformed person, not the side-by-side layout."
 
     # Set generator for reproducibility
     generator = torch.Generator("cuda").manual_seed(seed)
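The prompt assembly can be sanity-checked without loading the model; a minimal sketch, where the TRIGGER_WORD value is an assumption (the real constant is defined earlier in app.py, outside this diff):

TRIGGER_WORD = "refcontrol"  # stand-in; the actual trigger token is defined elsewhere in app.py
prompt = "woman in a red dress"

base_instruction = (
    f"{TRIGGER_WORD}, A photo composed of two images side by side. "
    "Left: reference person. Right: target pose skeleton. "
    "Task: Generate the person from the left image in the exact pose shown in the right image"
)
full_prompt = f"{base_instruction}. Additional details: {prompt}" if prompt else base_instruction
full_prompt += (
    ". IMPORTANT: Strictly follow the pose/skeleton from the right image while preserving "
    "the identity, clothing, and appearance from the left image. The output should show ONLY "
    "the transformed person, not the side-by-side layout."
)
print(full_prompt)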
@@ -221,30 +227,33 @@ def generate_pose_transfer(
     # Check if we have LoRA capabilities
     has_lora = hasattr(pipe, 'set_adapters') and "LoRA" in MODEL_STATUS
 
-    # Set LoRA
+    # Set LoRA with higher strength for better pose control
     if has_lora:
         try:
+            # Increase LoRA strength for pose control
+            actual_lora_scale = lora_scale * 1.0  # Boost LoRA influence
+            pipe.set_adapters(["refcontrol"], adapter_weights=[actual_lora_scale])
+            print(f"LoRA adapter set with boosted strength: {actual_lora_scale}")
         except Exception as e:
             print(f"LoRA adapter not set: {e}")
 
     print(f"Generating with size: {width}x{height}")
-    print(f"Prompt: {full_prompt[:
+    print(f"Prompt: {full_prompt[:200]}...")
 
-    # Generate image
+    # Generate image with stronger pose control
     with torch.cuda.amp.autocast(dtype=torch.bfloat16):
         if "Kontext" in MODEL_STATUS:
-            # Use Kontext pipeline
+            # Use Kontext pipeline with enhanced settings
             result = pipe(
                 image=concatenated_input,
                 prompt=full_prompt,
-                negative_prompt=negative_prompt if negative_prompt else "",
-                guidance_scale=guidance_scale,
+                negative_prompt=negative_prompt if negative_prompt else "blurry, distorted, deformed, wrong pose, incorrect posture",
+                guidance_scale=guidance_scale,  # Higher for better control
                 num_inference_steps=num_inference_steps,
                 generator=generator,
                 width=width,
                 height=height,
+                controlnet_conditioning_scale=controlnet_scale,  # Control strength
             ).images[0]
         else:
             # Use standard FLUX pipeline
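One caveat on the new keyword arguments: depending on which pipeline class was loaded, controlnet_conditioning_scale may not be an accepted parameter (in diffusers it belongs to the ControlNet pipeline variants), and an unexpected keyword would raise a TypeError at call time. A defensive sketch that forwards the optional controls only when the pipeline's signature accepts them; variable names mirror the surrounding code:

import inspect

accepted = inspect.signature(pipe.__call__).parameters  # what this pipeline can take

call_kwargs = dict(
    image=concatenated_input,
    prompt=full_prompt,
    guidance_scale=guidance_scale,
    num_inference_steps=num_inference_steps,
    generator=generator,
    width=width,
    height=height,
)
# Forward optional controls only if the loaded pipeline understands them.
if "negative_prompt" in accepted:
    call_kwargs["negative_prompt"] = negative_prompt or "blurry, distorted, deformed, wrong pose, incorrect posture"
if "controlnet_conditioning_scale" in accepted:
    call_kwargs["controlnet_conditioning_scale"] = controlnet_scale

result = pipe(**call_kwargs).images[0]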
@@ -424,11 +433,11 @@ with gr.Blocks(css=css, theme=gr.themes.Soft()) as demo:
 
                 guidance_scale = gr.Slider(
                     label="Guidance Scale",
-                    minimum=
-                    maximum=
+                    minimum=5.0,
+                    maximum=15.0,
                     step=0.5,
-                    value=
-                    info="
+                    value=7.5,
+                    info="Higher = stricter pose following (7-10 recommended)"
                 )
 
                 num_inference_steps = gr.Slider(
@@ -436,17 +445,17 @@ with gr.Blocks(css=css, theme=gr.themes.Soft()) as demo:
                     minimum=20,
                     maximum=50,
                     step=1,
-                    value=
+                    value=30
                 )
 
                 if "LoRA" in MODEL_STATUS:
                     lora_scale = gr.Slider(
                         label="LoRA Strength",
-                        minimum=0.
+                        minimum=0.5,
                         maximum=2.0,
                         step=0.1,
-                        value=1.
-                        info="RefControl LoRA influence"
+                        value=1.2,
+                        info="RefControl LoRA influence (1.0-1.5 recommended)"
                     )
                 else:
                     lora_scale = gr.Slider(
@@ -458,6 +467,15 @@ with gr.Blocks(css=css, theme=gr.themes.Soft()) as demo:
                         interactive=False
                     )
 
+                controlnet_scale = gr.Slider(
+                    label="Pose Control Strength",
+                    minimum=0.5,
+                    maximum=2.0,
+                    step=0.1,
+                    value=1.0,
+                    info="How strongly to enforce the pose"
+                )
+
                 enhance_pose = gr.Checkbox(
                     label="Auto-enhance pose edges",
                     value=False
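Note that the new controlnet_scale slider only reaches generate_pose_transfer if it is also added to the event wiring, which this diff does not show. A hypothetical sketch; the button and component names are assumptions about code outside this diff:

# Hypothetical wiring; the actual component names live elsewhere in app.py.
generate_btn.click(
    fn=generate_pose_transfer,
    inputs=[
        reference_image, pose_image, prompt, negative_prompt,
        seed, randomize_seed, guidance_scale, num_inference_steps,
        lora_scale, controlnet_scale, enhance_pose,
    ],
    outputs=[result_image],
)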
|