Spaces:
				
			
			
	
			
			
		Running
		
			on 
			
			Zero
	
	
	
			
			
	
	
	
	
		
		
		Running
		
			on 
			
			Zero
	Update app.py
Browse files
    	
        app.py
    CHANGED
    
    | @@ -12,6 +12,14 @@ from pipeline_fill_sd_xl import StableDiffusionXLFillPipeline | |
| 12 | 
             
            from PIL import Image, ImageDraw
         | 
| 13 | 
             
            import numpy as np
         | 
| 14 |  | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 15 | 
             
            config_file = hf_hub_download(
         | 
| 16 | 
             
                "xinsir/controlnet-union-sdxl-1.0",
         | 
| 17 | 
             
                filename="config_promax.json",
         | 
| @@ -37,20 +45,19 @@ result = ControlNetModel_Union._load_pretrained_model( | |
| 37 |  | 
| 38 | 
             
            # Use the first element from the result
         | 
| 39 | 
             
            model = result[0]
         | 
| 40 | 
            -
            model = model.to(device= | 
| 41 | 
            -
             | 
| 42 |  | 
| 43 | 
             
            vae = AutoencoderKL.from_pretrained(
         | 
| 44 | 
            -
                "madebyollin/sdxl-vae-fp16-fix", torch_dtype= | 
| 45 | 
            -
            ).to( | 
| 46 |  | 
| 47 | 
             
            pipe = StableDiffusionXLFillPipeline.from_pretrained(
         | 
| 48 | 
             
                "SG161222/RealVisXL_V5.0_Lightning",
         | 
| 49 | 
            -
                torch_dtype= | 
| 50 | 
             
                vae=vae,
         | 
| 51 | 
             
                controlnet=model,
         | 
| 52 | 
            -
                variant="fp16",
         | 
| 53 | 
            -
            ).to( | 
| 54 |  | 
| 55 | 
             
            pipe.scheduler = TCDScheduler.from_config(pipe.scheduler.config)
         | 
| 56 |  | 
| @@ -152,7 +159,6 @@ def prepare_image_and_mask(image, width, height, overlap_percentage, resize_opti | |
| 152 | 
             
                elif alignment == "Bottom":
         | 
| 153 | 
             
                    bottom_overlap = margin_y + new_height - overlap_y if overlap_bottom else margin_y + new_height
         | 
| 154 |  | 
| 155 | 
            -
             | 
| 156 | 
             
                # Draw the mask
         | 
| 157 | 
             
                mask_draw.rectangle([
         | 
| 158 | 
             
                    (left_overlap, top_overlap),
         | 
| @@ -181,39 +187,45 @@ def preview_image_and_mask(image, width, height, overlap_percentage, resize_opti | |
| 181 |  | 
| 182 | 
             
            @spaces.GPU(duration=24)
         | 
| 183 | 
             
            def infer(image, width, height, overlap_percentage, num_inference_steps, resize_option, custom_resize_percentage, prompt_input, alignment, overlap_left, overlap_right, overlap_top, overlap_bottom):
         | 
| 184 | 
            -
                 | 
| 185 | 
            -
             | 
| 186 | 
            -
             | 
| 187 | 
            -
                     | 
| 188 | 
            -
             | 
| 189 | 
            -
             | 
| 190 | 
            -
             | 
| 191 | 
            -
             | 
| 192 | 
            -
             | 
| 193 | 
            -
             | 
| 194 | 
            -
             | 
| 195 | 
            -
             | 
| 196 | 
            -
                    (
         | 
| 197 | 
            -
                         | 
| 198 | 
            -
             | 
| 199 | 
            -
             | 
| 200 | 
            -
             | 
| 201 | 
            -
             | 
| 202 | 
            -
             | 
| 203 | 
            -
             | 
| 204 | 
            -
                         | 
| 205 | 
            -
             | 
| 206 | 
            -
             | 
| 207 | 
            -
             | 
| 208 | 
            -
             | 
| 209 | 
            -
             | 
| 210 | 
            -
             | 
| 211 | 
            -
                         | 
| 212 | 
            -
             | 
| 213 | 
            -
             | 
| 214 | 
            -
             | 
| 215 | 
            -
             | 
| 216 | 
            -
             | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 217 |  | 
| 218 | 
             
            def clear_result():
         | 
| 219 | 
             
                """Clears the result ImageSlider."""
         | 
| @@ -253,9 +265,21 @@ def update_history(new_image, history): | |
| 253 | 
             
                """Updates the history gallery with the new image."""
         | 
| 254 | 
             
                if history is None:
         | 
| 255 | 
             
                    history = []
         | 
| 256 | 
            -
                 | 
|  | |
|  | |
|  | |
|  | |
|  | |
| 257 | 
             
                return history
         | 
| 258 |  | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 259 | 
             
            css = """
         | 
| 260 | 
             
            .gradio-container {
         | 
| 261 | 
             
                width: 1200px !important;
         | 
| @@ -358,8 +382,6 @@ with gr.Blocks(theme="soft", css=css) as demo: | |
| 358 | 
             
                                inputs=[input_image, width_slider, height_slider, alignment_dropdown],
         | 
| 359 | 
             
                            )
         | 
| 360 |  | 
| 361 | 
            -
                            
         | 
| 362 | 
            -
             | 
| 363 | 
             
                        with gr.Column():
         | 
| 364 | 
             
                            result = ImageSlider(
         | 
| 365 | 
             
                                interactive=False,
         | 
| @@ -370,11 +392,11 @@ with gr.Blocks(theme="soft", css=css) as demo: | |
| 370 | 
             
                            history_gallery = gr.Gallery(label="History", columns=6, object_fit="contain", interactive=False)
         | 
| 371 | 
             
                            preview_image = gr.Image(label="Preview")
         | 
| 372 |  | 
| 373 | 
            -
                    
         | 
| 374 | 
            -
             | 
| 375 | 
             
                def use_output_as_input(output_image):
         | 
| 376 | 
             
                    """Sets the generated output as the new input image."""
         | 
| 377 | 
            -
                     | 
|  | |
|  | |
| 378 |  | 
| 379 | 
             
                use_as_input_button.click(
         | 
| 380 | 
             
                    fn=use_output_as_input,
         | 
| @@ -421,7 +443,7 @@ with gr.Blocks(theme="soft", css=css) as demo: | |
| 421 | 
             
                            overlap_left, overlap_right, overlap_top, overlap_bottom],
         | 
| 422 | 
             
                    outputs=result,
         | 
| 423 | 
             
                ).then(  # Update the history gallery
         | 
| 424 | 
            -
                    fn= | 
| 425 | 
             
                    inputs=[result, history_gallery],
         | 
| 426 | 
             
                    outputs=history_gallery,
         | 
| 427 | 
             
                ).then(  # Show the "Use as Input Image" button
         | 
| @@ -441,7 +463,7 @@ with gr.Blocks(theme="soft", css=css) as demo: | |
| 441 | 
             
                            overlap_left, overlap_right, overlap_top, overlap_bottom],
         | 
| 442 | 
             
                    outputs=result,
         | 
| 443 | 
             
                ).then(  # Update the history gallery
         | 
| 444 | 
            -
                    fn= | 
| 445 | 
             
                    inputs=[result, history_gallery],
         | 
| 446 | 
             
                    outputs=history_gallery,
         | 
| 447 | 
             
                ).then(  # Show the "Use as Input Image" button
         | 
|  | |
| 12 | 
             
            from PIL import Image, ImageDraw
         | 
| 13 | 
             
            import numpy as np
         | 
| 14 |  | 
| 15 | 
            +
            # Initialize CUDA if available
         | 
| 16 | 
            +
            if torch.cuda.is_available():
         | 
| 17 | 
            +
                device = "cuda"
         | 
| 18 | 
            +
                dtype = torch.float16
         | 
| 19 | 
            +
            else:
         | 
| 20 | 
            +
                device = "cpu"
         | 
| 21 | 
            +
                dtype = torch.float32
         | 
| 22 | 
            +
             | 
| 23 | 
             
            config_file = hf_hub_download(
         | 
| 24 | 
             
                "xinsir/controlnet-union-sdxl-1.0",
         | 
| 25 | 
             
                filename="config_promax.json",
         | 
|  | |
| 45 |  | 
| 46 | 
             
            # Use the first element from the result
         | 
| 47 | 
             
            model = result[0]
         | 
| 48 | 
            +
            model = model.to(device=device, dtype=dtype)
         | 
|  | |
| 49 |  | 
| 50 | 
             
            vae = AutoencoderKL.from_pretrained(
         | 
| 51 | 
            +
                "madebyollin/sdxl-vae-fp16-fix", torch_dtype=dtype
         | 
| 52 | 
            +
            ).to(device)
         | 
| 53 |  | 
| 54 | 
             
            pipe = StableDiffusionXLFillPipeline.from_pretrained(
         | 
| 55 | 
             
                "SG161222/RealVisXL_V5.0_Lightning",
         | 
| 56 | 
            +
                torch_dtype=dtype,
         | 
| 57 | 
             
                vae=vae,
         | 
| 58 | 
             
                controlnet=model,
         | 
| 59 | 
            +
                variant="fp16" if dtype == torch.float16 else None,
         | 
| 60 | 
            +
            ).to(device)
         | 
| 61 |  | 
| 62 | 
             
            pipe.scheduler = TCDScheduler.from_config(pipe.scheduler.config)
         | 
| 63 |  | 
|  | |
| 159 | 
             
                elif alignment == "Bottom":
         | 
| 160 | 
             
                    bottom_overlap = margin_y + new_height - overlap_y if overlap_bottom else margin_y + new_height
         | 
| 161 |  | 
|  | |
| 162 | 
             
                # Draw the mask
         | 
| 163 | 
             
                mask_draw.rectangle([
         | 
| 164 | 
             
                    (left_overlap, top_overlap),
         | 
|  | |
| 187 |  | 
| 188 | 
             
            @spaces.GPU(duration=24)
         | 
| 189 | 
             
            def infer(image, width, height, overlap_percentage, num_inference_steps, resize_option, custom_resize_percentage, prompt_input, alignment, overlap_left, overlap_right, overlap_top, overlap_bottom):
         | 
| 190 | 
            +
                try:
         | 
| 191 | 
            +
                    background, mask = prepare_image_and_mask(image, width, height, overlap_percentage, resize_option, custom_resize_percentage, alignment, overlap_left, overlap_right, overlap_top, overlap_bottom)
         | 
| 192 | 
            +
                    
         | 
| 193 | 
            +
                    if not can_expand(background.width, background.height, width, height, alignment):
         | 
| 194 | 
            +
                        alignment = "Middle"
         | 
| 195 | 
            +
             | 
| 196 | 
            +
                    cnet_image = background.copy()
         | 
| 197 | 
            +
                    cnet_image.paste(0, (0, 0), mask)
         | 
| 198 | 
            +
             | 
| 199 | 
            +
                    final_prompt = f"{prompt_input} , high quality, 4k" if prompt_input else "high quality, 4k"
         | 
| 200 | 
            +
             | 
| 201 | 
            +
                    # Use with torch.autocast to ensure consistent dtype
         | 
| 202 | 
            +
                    with torch.autocast(device_type=device, dtype=dtype):
         | 
| 203 | 
            +
                        (
         | 
| 204 | 
            +
                            prompt_embeds,
         | 
| 205 | 
            +
                            negative_prompt_embeds,
         | 
| 206 | 
            +
                            pooled_prompt_embeds,
         | 
| 207 | 
            +
                            negative_pooled_prompt_embeds,
         | 
| 208 | 
            +
                        ) = pipe.encode_prompt(final_prompt, device, True)
         | 
| 209 | 
            +
             | 
| 210 | 
            +
                        for image in pipe(
         | 
| 211 | 
            +
                            prompt_embeds=prompt_embeds,
         | 
| 212 | 
            +
                            negative_prompt_embeds=negative_prompt_embeds,
         | 
| 213 | 
            +
                            pooled_prompt_embeds=pooled_prompt_embeds,
         | 
| 214 | 
            +
                            negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
         | 
| 215 | 
            +
                            image=cnet_image,
         | 
| 216 | 
            +
                            num_inference_steps=num_inference_steps
         | 
| 217 | 
            +
                        ):
         | 
| 218 | 
            +
                            yield cnet_image, image
         | 
| 219 | 
            +
             | 
| 220 | 
            +
                    image = image.convert("RGBA")
         | 
| 221 | 
            +
                    cnet_image.paste(image, (0, 0), mask)
         | 
| 222 | 
            +
             | 
| 223 | 
            +
                    yield background, cnet_image
         | 
| 224 | 
            +
                    
         | 
| 225 | 
            +
                except Exception as e:
         | 
| 226 | 
            +
                    print(f"Error in infer function: {e}")
         | 
| 227 | 
            +
                    # Return a placeholder or error image
         | 
| 228 | 
            +
                    yield None, None
         | 
| 229 |  | 
| 230 | 
             
            def clear_result():
         | 
| 231 | 
             
                """Clears the result ImageSlider."""
         | 
|  | |
| 265 | 
             
                """Updates the history gallery with the new image."""
         | 
| 266 | 
             
                if history is None:
         | 
| 267 | 
             
                    history = []
         | 
| 268 | 
            +
                
         | 
| 269 | 
            +
                # Check if new_image is valid and has the expected structure
         | 
| 270 | 
            +
                if new_image is not None and isinstance(new_image, (tuple, list)) and len(new_image) > 1:
         | 
| 271 | 
            +
                    if new_image[1] is not None:  # Check if the second element exists
         | 
| 272 | 
            +
                        history.insert(0, new_image[1])
         | 
| 273 | 
            +
                
         | 
| 274 | 
             
                return history
         | 
| 275 |  | 
| 276 | 
            +
            # Safe wrapper for update_history to handle None values
         | 
| 277 | 
            +
            def safe_update_history(result, history):
         | 
| 278 | 
            +
                """Safely updates the history gallery with the new image."""
         | 
| 279 | 
            +
                if result is None:
         | 
| 280 | 
            +
                    return history
         | 
| 281 | 
            +
                return update_history(result, history)
         | 
| 282 | 
            +
             | 
| 283 | 
             
            css = """
         | 
| 284 | 
             
            .gradio-container {
         | 
| 285 | 
             
                width: 1200px !important;
         | 
|  | |
| 382 | 
             
                                inputs=[input_image, width_slider, height_slider, alignment_dropdown],
         | 
| 383 | 
             
                            )
         | 
| 384 |  | 
|  | |
|  | |
| 385 | 
             
                        with gr.Column():
         | 
| 386 | 
             
                            result = ImageSlider(
         | 
| 387 | 
             
                                interactive=False,
         | 
|  | |
| 392 | 
             
                            history_gallery = gr.Gallery(label="History", columns=6, object_fit="contain", interactive=False)
         | 
| 393 | 
             
                            preview_image = gr.Image(label="Preview")
         | 
| 394 |  | 
|  | |
|  | |
| 395 | 
             
                def use_output_as_input(output_image):
         | 
| 396 | 
             
                    """Sets the generated output as the new input image."""
         | 
| 397 | 
            +
                    if output_image is not None and isinstance(output_image, (tuple, list)) and len(output_image) > 1:
         | 
| 398 | 
            +
                        return gr.update(value=output_image[1])
         | 
| 399 | 
            +
                    return gr.update()
         | 
| 400 |  | 
| 401 | 
             
                use_as_input_button.click(
         | 
| 402 | 
             
                    fn=use_output_as_input,
         | 
|  | |
| 443 | 
             
                            overlap_left, overlap_right, overlap_top, overlap_bottom],
         | 
| 444 | 
             
                    outputs=result,
         | 
| 445 | 
             
                ).then(  # Update the history gallery
         | 
| 446 | 
            +
                    fn=safe_update_history,
         | 
| 447 | 
             
                    inputs=[result, history_gallery],
         | 
| 448 | 
             
                    outputs=history_gallery,
         | 
| 449 | 
             
                ).then(  # Show the "Use as Input Image" button
         | 
|  | |
| 463 | 
             
                            overlap_left, overlap_right, overlap_top, overlap_bottom],
         | 
| 464 | 
             
                    outputs=result,
         | 
| 465 | 
             
                ).then(  # Update the history gallery
         | 
| 466 | 
            +
                    fn=safe_update_history,
         | 
| 467 | 
             
                    inputs=[result, history_gallery],
         | 
| 468 | 
             
                    outputs=history_gallery,
         | 
| 469 | 
             
                ).then(  # Show the "Use as Input Image" button
         | 
 
			

