Update code block
Browse files
    	
        README.md
    CHANGED
    
    | @@ -134,55 +134,32 @@ pip install git+github.com/huggingface/diffusers.git | |
| 134 | 
             
            ```py
         | 
| 135 | 
             
            import torch
         | 
| 136 | 
             
            from collections import deque
         | 
| 137 | 
            -
            from diffusers import ModularPipelineBlocks, FlowMatchEulerDiscreteScheduler
         | 
| 138 | 
             
            from diffusers.utils import export_to_video
         | 
|  | |
| 139 | 
             
            from diffusers.modular_pipelines import PipelineState, WanModularPipeline
         | 
| 140 |  | 
| 141 | 
            -
             | 
| 142 | 
            -
             | 
| 143 | 
            -
             | 
| 144 | 
            -
                    return 60
         | 
| 145 | 
            -
             | 
| 146 | 
            -
                @property
         | 
| 147 | 
            -
                def default_sample_width(self):
         | 
| 148 | 
            -
                    return 104
         | 
| 149 | 
            -
             | 
| 150 | 
            -
                @property
         | 
| 151 | 
            -
                def frame_seq_length(self):
         | 
| 152 | 
            -
                    return 1560
         | 
| 153 | 
            -
             | 
| 154 | 
            -
                @property
         | 
| 155 | 
            -
                def seq_length(self):
         | 
| 156 | 
            -
                    return 32760
         | 
| 157 | 
            -
             | 
| 158 | 
            -
                @property
         | 
| 159 | 
            -
                def kv_cache_num_frames(self):
         | 
| 160 | 
            -
                    return 3
         | 
| 161 | 
            -
             | 
| 162 | 
            -
                @property
         | 
| 163 | 
            -
                def frame_cache_len(self):
         | 
| 164 | 
            -
                    return 1 + (self.kv_cache_num_frames - 1) * 4
         | 
| 165 | 
            -
             | 
| 166 | 
            -
             | 
| 167 | 
            -
            block_path = "krea/krea-realtime-video"
         | 
| 168 | 
            -
            blocks = ModularPipelineBlocks.from_pretrained(block_path, trust_remote_code=True)
         | 
| 169 | 
            -
            pipe = WanRTStreamingPipeline(blocks, block_path)
         | 
| 170 |  | 
| 171 | 
             
            pipe.load_components(
         | 
| 172 | 
             
                trust_remote_code=True,
         | 
| 173 | 
             
                device_map="cuda",
         | 
| 174 | 
            -
                torch_dtype={"default": torch.bfloat16, "vae": torch. | 
| 175 | 
             
            )
         | 
| 176 | 
            -
            pipe.scheduler = FlowMatchEulerDiscreteScheduler(shift=5.0)
         | 
| 177 | 
            -
             | 
| 178 | 
            -
            prompt = ["A cat sitting on a boat"]
         | 
| 179 |  | 
| 180 | 
             
            num_frames_per_block = 3
         | 
| 181 | 
             
            num_blocks = 9
         | 
| 182 |  | 
| 183 | 
             
            frames = []
         | 
| 184 | 
             
            state = PipelineState()
         | 
| 185 | 
            -
            state.set("frame_cache_context", deque(maxlen=pipe.frame_cache_len))
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 186 | 
             
            for block_idx in range(num_blocks):
         | 
| 187 | 
             
                state = pipe(
         | 
| 188 | 
             
                    state,
         | 
| @@ -191,8 +168,9 @@ for block_idx in range(num_blocks): | |
| 191 | 
             
                    num_blocks=num_blocks,
         | 
| 192 | 
             
                    num_frames_per_block=num_frames_per_block,
         | 
| 193 | 
             
                    block_idx=block_idx,
         | 
|  | |
| 194 | 
             
                )
         | 
| 195 | 
             
                frames.extend(state.values["videos"][0])
         | 
| 196 |  | 
| 197 | 
            -
            export_to_video(frames, " | 
| 198 | 
             
            ``` 
         | 
|  | |
| 134 | 
             
            ```py
         | 
| 135 | 
             
            import torch
         | 
| 136 | 
             
            from collections import deque
         | 
|  | |
| 137 | 
             
            from diffusers.utils import export_to_video
         | 
| 138 | 
            +
            from diffusers import ModularPipelineBlocks
         | 
| 139 | 
             
            from diffusers.modular_pipelines import PipelineState, WanModularPipeline
         | 
| 140 |  | 
| 141 | 
            +
            repo_id = "krea/krea-realtime-video"
         | 
| 142 | 
            +
            blocks = ModularPipelineBlocks.from_pretrained(repo_id, trust_remote_code=True)
         | 
| 143 | 
            +
            pipe = WanModularPipeline(blocks, repo_id)
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 144 |  | 
| 145 | 
             
            pipe.load_components(
         | 
| 146 | 
             
                trust_remote_code=True,
         | 
| 147 | 
             
                device_map="cuda",
         | 
| 148 | 
            +
                torch_dtype={"default": torch.bfloat16, "vae": torch.float16},
         | 
| 149 | 
             
            )
         | 
|  | |
|  | |
|  | |
| 150 |  | 
| 151 | 
             
            num_frames_per_block = 3
         | 
| 152 | 
             
            num_blocks = 9
         | 
| 153 |  | 
| 154 | 
             
            frames = []
         | 
| 155 | 
             
            state = PipelineState()
         | 
| 156 | 
            +
            state.set("frame_cache_context", deque(maxlen=pipe.config.frame_cache_len))
         | 
| 157 | 
            +
             | 
| 158 | 
            +
            prompt = ["a cat sitting on a boat"]
         | 
| 159 | 
            +
             | 
| 160 | 
            +
            for block in pipe.transformer.blocks:
         | 
| 161 | 
            +
                block.self_attn.fuse_projections()
         | 
| 162 | 
            +
             | 
| 163 | 
             
            for block_idx in range(num_blocks):
         | 
| 164 | 
             
                state = pipe(
         | 
| 165 | 
             
                    state,
         | 
|  | |
| 168 | 
             
                    num_blocks=num_blocks,
         | 
| 169 | 
             
                    num_frames_per_block=num_frames_per_block,
         | 
| 170 | 
             
                    block_idx=block_idx,
         | 
| 171 | 
            +
                    generator=torch.Generator("cuda").manual_seed(42),
         | 
| 172 | 
             
                )
         | 
| 173 | 
             
                frames.extend(state.values["videos"][0])
         | 
| 174 |  | 
| 175 | 
            +
            export_to_video(frames, "output.mp4", fps=16)
         | 
| 176 | 
             
            ``` 
         | 

