| 
							 | 
import base64
from io import BytesIO

import torch
from diffusers import DiffusionPipeline, FlowMatchEulerDiscreteScheduler
from PIL import Image

from custom_pipeline import WanTransformer3DModel, AutoencoderKLWan
					
					
						
# Hugging Face Hub repo id of the Wan 2.1 image-to-video model (Diffusers layout).
model_id = "grnr9730/Wan2.1-I2V-14B-720P-Diffusers"
					
						
						| 
							 | 
class WanImageToVideoPipeline(DiffusionPipeline):
    """Minimal Diffusers-style wrapper bundling a Wan transformer, VAE, and scheduler.

    NOTE(review): ``__call__`` below looks like a placeholder denoising loop —
    it ignores ``prompt`` and ``kwargs``, starts from random noise rather than
    an input image, and never calls ``scheduler.set_timesteps`` or
    ``scheduler.step``. Confirm against the real Wan I2V pipeline before
    trusting its output.
    """

    def __init__(self, transformer, vae, scheduler):
        super().__init__()
        self.transformer = transformer
        self.vae = vae
        self.scheduler = scheduler
        # register_modules exposes the components to DiffusionPipeline
        # machinery (save_pretrained, device placement, offloading hooks).
        self.register_modules(transformer=transformer, vae=vae, scheduler=scheduler)

    def __call__(self, prompt, **kwargs):
        # NOTE(review): `prompt` and `kwargs` are unused — no text or image
        # conditioning is applied anywhere in this method.

        # Encode a random 1x3x224x224 tensor. The shape is hard-coded and no
        # caller-supplied image is used despite the "ImageToVideo" name —
        # TODO confirm intended input handling.
        latents = self.vae.encode(torch.randn(1, 3, 224, 224)).latent_dist.sample()

        # NOTE(review): `scheduler.timesteps` is read without a prior
        # set_timesteps() call, and scheduler.step() is never invoked — each
        # iteration simply overwrites the latents with the transformer output.
        for _ in self.scheduler.timesteps:
            latents = self.transformer(latents)

        video_frames = self.vae.decode(latents).sample

        # Returns an ad-hoc class object (not an instance) whose `frames`
        # attribute is a list of PIL images. The `* 255` scaling assumes the
        # decoder emits floats in [0, 1] — TODO confirm value range and that
        # each frame's layout is HxWxC as Image.fromarray requires.
        return type('Result', (), {'frames': [Image.fromarray((frame * 255).byte().cpu().numpy()) for frame in video_frames]})
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						 | 
					
					
						
						| 
							 | 
# Assemble the pipeline from its three components. FlowMatchEulerDiscreteScheduler
# comes from the `diffusers` import at the top of the file.
# NOTE(review): every component is loaded from the repo root — Diffusers-format
# repos typically keep weights under subfolders ("transformer", "vae",
# "scheduler"); confirm whether `subfolder=` arguments are required here.
# NOTE(review): torch_dtype is passed only to the pipeline constructor, not to
# the already-instantiated components, so they presumably load at their default
# dtype — verify this matches the intended bfloat16 setup.
pipe = WanImageToVideoPipeline.from_pretrained(
    model_id,
    transformer=WanTransformer3DModel.from_pretrained(model_id),
    vae=AutoencoderKLWan.from_pretrained(model_id),
    scheduler=FlowMatchEulerDiscreteScheduler.from_pretrained(model_id),
    torch_dtype=torch.bfloat16
)
# Keep weights on CPU between forward passes to reduce GPU memory pressure.
pipe.enable_model_cpu_offload()
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						 | 
					
					
						
						| 
							 | 
						def infer(data): | 
					
					
						
						| 
							 | 
						    prompt = data.get("prompt", "A futuristic cityscape") | 
					
					
						
						| 
							 | 
						     | 
					
					
						
						| 
							 | 
						     | 
					
					
						
						| 
							 | 
						    video = pipe(prompt).frames   | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						     | 
					
					
						
						| 
							 | 
						    first_frame = video[0] | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						     | 
					
					
						
						| 
							 | 
						    buffered = BytesIO() | 
					
					
						
						| 
							 | 
						    first_frame.save(buffered, format="PNG") | 
					
					
						
						| 
							 | 
						    img_str = base64.b64encode(buffered.getvalue()).decode("utf-8") | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						    return {"image": img_str} |