multimodalart HF Staff committed on
Commit
d73341d
·
verified ·
1 Parent(s): 7345819

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +16 -19
app.py CHANGED
@@ -45,7 +45,9 @@ from huggingface_hub import hf_hub_download
45
 
46
  from ltx_core.model.video_vae import TilingConfig, get_video_chunks_number
47
  from ltx_core.quantization import QuantizationPolicy
 
48
  from ltx_pipelines.distilled import DistilledPipeline
 
49
  from ltx_pipelines.utils.args import ImageConditioningInput
50
  from ltx_pipelines.utils.media_io import encode_video
51
 
@@ -105,18 +107,6 @@ print("Pipeline ready!")
105
  print("=" * 80)
106
 
107
 
108
- class PrecomputedTextEncoder(torch.nn.Module):
109
- """Fake text encoder that returns pre-computed embeddings."""
110
-
111
- def __init__(self, video_context, audio_context):
112
- super().__init__()
113
- self.video_context = video_context
114
- self.audio_context = audio_context
115
-
116
- def forward(self, text, padding_side="left"):
117
- return self.video_context, self.audio_context, None
118
-
119
-
120
  @spaces.GPU(duration=120, size='xlarge')
121
  def generate_video(
122
  input_image,
@@ -178,7 +168,9 @@ def generate_video(
178
 
179
  embeddings = torch.load(embedding_path)
180
  video_context = embeddings["video_context"].to("cuda")
181
- audio_context = embeddings["audio_context"].to("cuda")
 
 
182
  print("Embeddings loaded successfully")
183
  except Exception as e:
184
  raise RuntimeError(
@@ -186,10 +178,15 @@ def generate_video(
186
  f"Please ensure {TEXT_ENCODER_SPACE} is running properly."
187
  )
188
 
189
- # Patch the model_ledger to return a fake text encoder with pre-computed embeddings
190
- fake_encoder = PrecomputedTextEncoder(video_context, audio_context)
191
- original_text_encoder_fn = pipeline.model_ledger.text_encoder
192
- pipeline.model_ledger.text_encoder = lambda: fake_encoder
 
 
 
 
 
193
 
194
  try:
195
  tiling_config = TilingConfig.default()
@@ -218,8 +215,8 @@ def generate_video(
218
 
219
  return str(output_path), current_seed
220
  finally:
221
- # Restore original text encoder method
222
- pipeline.model_ledger.text_encoder = original_text_encoder_fn
223
 
224
  except Exception as e:
225
  import traceback
 
45
 
46
  from ltx_core.model.video_vae import TilingConfig, get_video_chunks_number
47
  from ltx_core.quantization import QuantizationPolicy
48
+ from ltx_core.text_encoders.gemma.embeddings_processor import EmbeddingsProcessorOutput
49
  from ltx_pipelines.distilled import DistilledPipeline
50
+ from ltx_pipelines.utils import helpers as pipeline_helpers
51
  from ltx_pipelines.utils.args import ImageConditioningInput
52
  from ltx_pipelines.utils.media_io import encode_video
53
 
 
107
  print("=" * 80)
108
 
109
 
 
 
 
 
 
 
 
 
 
 
 
 
110
  @spaces.GPU(duration=120, size='xlarge')
111
  def generate_video(
112
  input_image,
 
168
 
169
  embeddings = torch.load(embedding_path)
170
  video_context = embeddings["video_context"].to("cuda")
171
+ audio_context = embeddings["audio_context"]
172
+ if audio_context is not None:
173
+ audio_context = audio_context.to("cuda")
174
  print("Embeddings loaded successfully")
175
  except Exception as e:
176
  raise RuntimeError(
 
178
  f"Please ensure {TEXT_ENCODER_SPACE} is running properly."
179
  )
180
 
181
+ # Monkey-patch encode_prompts to return pre-computed embeddings
182
+ # instead of loading the text encoder + embeddings processor
183
+ precomputed = EmbeddingsProcessorOutput(
184
+ video_encoding=video_context,
185
+ audio_encoding=audio_context,
186
+ attention_mask=torch.ones(1, device="cuda"), # dummy mask
187
+ )
188
+ original_encode_prompts = pipeline_helpers.encode_prompts
189
+ pipeline_helpers.encode_prompts = lambda *args, **kwargs: [precomputed]
190
 
191
  try:
192
  tiling_config = TilingConfig.default()
 
215
 
216
  return str(output_path), current_seed
217
  finally:
218
+ # Restore original encode_prompts
219
+ pipeline_helpers.encode_prompts = original_encode_prompts
220
 
221
  except Exception as e:
222
  import traceback