multimodalart (HF Staff) committed
Commit 96923a5 · verified · 1 Parent(s): 1e6bffc

Update code block

Files changed (1)
  1. README.md +14 -36
README.md CHANGED
@@ -134,55 +134,32 @@ pip install git+github.com/huggingface/diffusers.git
 ```py
 import torch
 from collections import deque
-from diffusers import ModularPipelineBlocks, FlowMatchEulerDiscreteScheduler
 from diffusers.utils import export_to_video
+from diffusers import ModularPipelineBlocks
 from diffusers.modular_pipelines import PipelineState, WanModularPipeline
 
-class WanRTStreamingPipeline(WanModularPipeline):
-    @property
-    def default_sample_height(self):
-        return 60
-
-    @property
-    def default_sample_width(self):
-        return 104
-
-    @property
-    def frame_seq_length(self):
-        return 1560
-
-    @property
-    def seq_length(self):
-        return 32760
-
-    @property
-    def kv_cache_num_frames(self):
-        return 3
-
-    @property
-    def frame_cache_len(self):
-        return 1 + (self.kv_cache_num_frames - 1) * 4
-
-
-block_path = "krea/krea-realtime-video"
-blocks = ModularPipelineBlocks.from_pretrained(block_path, trust_remote_code=True)
-pipe = WanRTStreamingPipeline(blocks, block_path)
+repo_id = "krea/krea-realtime-video"
+blocks = ModularPipelineBlocks.from_pretrained(repo_id, trust_remote_code=True)
+pipe = WanModularPipeline(blocks, repo_id)
 
 pipe.load_components(
     trust_remote_code=True,
     device_map="cuda",
-    torch_dtype={"default": torch.bfloat16, "vae": torch.float32},
+    torch_dtype={"default": torch.bfloat16, "vae": torch.float16},
 )
-pipe.scheduler = FlowMatchEulerDiscreteScheduler(shift=5.0)
-
-prompt = ["A cat sitting on a boat"]
 
 num_frames_per_block = 3
 num_blocks = 9
 
 frames = []
 state = PipelineState()
-state.set("frame_cache_context", deque(maxlen=pipe.frame_cache_len))
+state.set("frame_cache_context", deque(maxlen=pipe.config.frame_cache_len))
+
+prompt = ["a cat sitting on a boat"]
+
+for block in pipe.transformer.blocks:
+    block.self_attn.fuse_projections()
+
 for block_idx in range(num_blocks):
     state = pipe(
         state,
@@ -191,8 +168,9 @@ for block_idx in range(num_blocks):
         num_blocks=num_blocks,
         num_frames_per_block=num_frames_per_block,
         block_idx=block_idx,
+        generator=torch.Generator("cuda").manual_seed(42),
     )
     frames.extend(state.values["videos"][0])
 
-export_to_video(frames, "krt.mp4")
+export_to_video(frames, "output.mp4", fps=16)
 ```
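The updated snippet keeps its rolling frame cache in a bounded `deque` (`deque(maxlen=pipe.config.frame_cache_len)`). As a minimal standalone sketch of why that works, plain Python only, with nothing pipeline-specific assumed:

```py
from collections import deque

# A deque with maxlen evicts its oldest entries automatically, which is how
# "frame_cache_context" above stays capped at pipe.config.frame_cache_len frames.
cache = deque(maxlen=3)
for frame_idx in range(5):
    cache.append(frame_idx)

print(list(cache))  # [2, 3, 4] -- the two oldest entries were dropped
```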