CARLEXsX commited on
Commit
71d6826
·
verified ·
1 Parent(s): 1c44f8b

Update ltx_manager_helpers.py

Browse files
Files changed (1) hide show
  1. ltx_manager_helpers.py +23 -171
ltx_manager_helpers.py CHANGED
@@ -1,126 +1,8 @@
1
  # ltx_manager_helpers.py
2
- # Gerente de Pool de Workers LTX para revezamento assíncrono em múltiplas GPUs.
3
- # Este arquivo é parte do projeto Euia-AducSdr e está sob a licença AGPL v3.
4
- # Copyright (C) 4 de Agosto de 2025 Carlos Rodrigues dos Santos
5
-
6
- import torch
7
- import gc
8
- import os
9
- import yaml
10
- import numpy as np
11
- import imageio
12
- from pathlib import Path
13
- import huggingface_hub
14
- import threading
15
- from PIL import Image
16
-
17
- # Importa as funções e classes necessárias do inference.py
18
- from inference import (
19
- create_ltx_video_pipeline,
20
- create_latent_upsampler,
21
- ConditioningItem,
22
- calculate_padding,
23
- prepare_conditioning
24
- )
25
- from ltx_video.pipelines.pipeline_ltx_video import LTXMultiScalePipeline
26
-
27
- class LtxWorker:
28
- """
29
- Representa uma única instância do pipeline LTX, associada a uma GPU específica.
30
- O pipeline é carregado na CPU por padrão e movido para a GPU sob demanda.
31
- """
32
- def __init__(self, device_id='cuda:0'):
33
- self.device = torch.device(device_id if torch.cuda.is_available() else 'cpu')
34
- print(f"LTX Worker: Inicializando para o dispositivo {self.device} (carregando na CPU)...")
35
-
36
- config_file_path = "configs/ltxv-13b-0.9.8-distilled.yaml"
37
- with open(config_file_path, "r") as file:
38
- self.config = yaml.safe_load(file)
39
-
40
- LTX_REPO = "Lightricks/LTX-Video"
41
- models_dir = "downloaded_models_gradio"
42
-
43
- model_actual_path = huggingface_hub.hf_hub_download(
44
- repo_id=LTX_REPO,
45
- filename=self.config["checkpoint_path"],
46
- local_dir=models_dir,
47
- local_dir_use_symlinks=False
48
- )
49
-
50
- self.pipeline = create_ltx_video_pipeline(
51
- ckpt_path=model_actual_path,
52
- precision=self.config["precision"],
53
- text_encoder_model_name_or_path=self.config["text_encoder_model_name_or_path"],
54
- sampler=self.config["sampler"],
55
- device='cpu'
56
- )
57
-
58
- print(f"LTX Worker para {self.device}: Compilando o transformer (isso pode levar um momento)...")
59
- self.pipeline.transformer.to(memory_format=torch.channels_last)
60
- try:
61
- # Usando um modo de compilação menos agressivo e mais compatível.
62
- self.pipeline.transformer = torch.compile(self.pipeline.transformer, mode="reduce-overhead", fullgraph=True)
63
- print(f"LTX Worker para {self.device}: Transformer compilado com sucesso.")
64
- except Exception as e:
65
- print(f"AVISO: A compilação do Transformer falhou em {self.device}: {e}. Continuando sem compilação.")
66
-
67
- self.latent_upsampler = None
68
- if self.config.get("pipeline_type") == "multi-scale":
69
- print(f"LTX Worker para {self.device}: Carregando Latent Upsampler (Multi-Scale)...")
70
- upscaler_path = huggingface_hub.hf_hub_download(
71
- repo_id=LTX_REPO,
72
- filename=self.config["spatial_upscaler_model_path"],
73
- local_dir=models_dir,
74
- local_dir_use_symlinks=False
75
- )
76
- self.latent_upsampler = create_latent_upsampler(upscaler_path, 'cpu')
77
-
78
- print(f"LTX Worker para {self.device} pronto na CPU.")
79
-
80
- def to_gpu(self):
81
- """Move o pipeline e o upsampler para a GPU designada."""
82
- if self.device.type == 'cpu': return
83
- print(f"LTX Worker: Movendo pipeline para {self.device}...")
84
- self.pipeline.to(self.device)
85
- if self.latent_upsampler:
86
- print(f"LTX Worker: Movendo Latent Upsampler para {self.device}...")
87
- self.latent_upsampler.to(self.device)
88
- print(f"LTX Worker: Pipeline na GPU {self.device}.")
89
-
90
- def to_cpu(self):
91
- """Move o pipeline de volta para a CPU e limpa a memória da GPU."""
92
- if self.device.type == 'cpu': return
93
- print(f"LTX Worker: Descarregando pipeline da GPU {self.device}...")
94
- self.pipeline.to('cpu')
95
- if self.latent_upsampler:
96
- self.latent_upsampler.to('cpu')
97
- gc.collect()
98
- if torch.cuda.is_available():
99
- torch.cuda.empty_cache()
100
- print(f"LTX Worker: GPU {self.device} limpa.")
101
-
102
- def generate_video_fragment_internal(self, **kwargs):
103
- """A lógica real da geração de vídeo, que espera estar na GPU."""
104
- return self.pipeline(**kwargs)
105
 
106
  class LtxPoolManager:
107
- """
108
- Gerencia um pool de LtxWorkers, orquestrando um revezamento entre GPUs
109
- para permitir que a limpeza de uma GPU ocorra em paralelo com a computação em outra.
110
- """
111
- def __init__(self, device_ids=['cuda:2', 'cuda:3']):
112
- print(f"LTX POOL MANAGER: Criando workers para os dispositivos: {device_ids}")
113
- self.workers = [LtxWorker(device_id) for device_id in device_ids]
114
- self.current_worker_index = 0
115
- self.lock = threading.Lock()
116
- self.last_cleanup_thread = None
117
-
118
- def _cleanup_worker(self, worker):
119
- """Função alvo para a thread de limpeza."""
120
- print(f"CLEANUP THREAD: Iniciando limpeza da GPU {worker.device} em background...")
121
- worker.to_cpu()
122
- print(f"CLEANUP THREAD: Limpeza da GPU {worker.device} concluída.")
123
-
124
  def generate_video_fragment(
125
  self,
126
  motion_prompt: str, conditioning_items_data: list,
@@ -131,44 +13,10 @@ class LtxPoolManager:
131
  ):
132
  worker_to_use = None
133
  try:
134
- with self.lock:
135
- if self.last_cleanup_thread and self.last_cleanup_thread.is_alive():
136
- print("LTX POOL MANAGER: Aguardando limpeza da GPU anterior...")
137
- self.last_cleanup_thread.join()
138
-
139
- worker_to_use = self.workers[self.current_worker_index]
140
- previous_worker_index = (self.current_worker_index - 1 + len(self.workers)) % len(self.workers)
141
- worker_to_cleanup = self.workers[previous_worker_index]
142
-
143
- cleanup_thread = threading.Thread(target=self._cleanup_worker, args=(worker_to_cleanup,))
144
- cleanup_thread.start()
145
- self.last_cleanup_thread = cleanup_thread
146
-
147
- worker_to_use.to_gpu()
148
-
149
- self.current_worker_index = (self.current_worker_index + 1) % len(self.workers)
150
-
151
- target_device = worker_to_use.device
152
 
153
- if use_attention_slicing:
154
- worker_to_use.pipeline.enable_attention_slicing()
155
-
156
- media_paths = [item[0] for item in conditioning_items_data]
157
- start_frames = [item[1] for item in conditioning_items_data]
158
- strengths = [item[2] for item in conditioning_items_data]
159
-
160
- padded_h, padded_w = ((height - 1) // 32 + 1) * 32, ((width - 1) // 32 + 1) * 32
161
- padding_vals = calculate_padding(height, width, padded_h, padded_w)
162
-
163
- conditioning_items = prepare_conditioning(
164
- conditioning_media_paths=media_paths, conditioning_strengths=strengths,
165
- conditioning_start_frames=start_frames, height=height, width=width,
166
- num_frames=video_total_frames, padding=padding_vals, pipeline=worker_to_use.pipeline,
167
- )
168
 
169
- for item in conditioning_items:
170
- item.media_item = item.media_item.to(target_device)
171
-
172
  kwargs = {
173
  "prompt": motion_prompt,
174
  "negative_prompt": "blurry, distorted, bad quality, artifacts",
@@ -188,26 +36,30 @@ class LtxPoolManager:
188
  "vae_per_channel_normalize": True,
189
  "mixed_precision": (worker_to_use.config.get("precision") == "mixed_precision"),
190
  "enhance_prompt": False,
191
- "num_inference_steps": int(num_inference_steps),
192
  }
193
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
194
  progress(0.1, desc=f"[Câmera LTX em {worker_to_use.device}] Filmando Cena {current_fragment_index}...")
195
  result_tensor = worker_to_use.generate_video_fragment_internal(**kwargs).images
196
 
197
- pad_l, pad_r, pad_t, pad_b = map(int, padding_vals)
198
- slice_h = -pad_b if pad_b > 0 else None
199
- slice_w = -pad_r if pad_r > 0 else None
200
- cropped_tensor = result_tensor[:, :, :video_total_frames, pad_t:slice_h, pad_l:slice_w]
201
- video_np = (cropped_tensor[0].permute(1, 2, 3, 0).cpu().float().numpy() * 255).astype(np.uint8)
202
-
203
- with imageio.get_writer(output_path, fps=video_fps, codec='libx264', quality=8) as writer:
204
- for frame in video_np:
205
- writer.append_data(frame)
206
 
207
  return output_path, video_total_frames
208
 
209
  finally:
210
- if use_attention_slicing and worker_to_use and worker_to_use.pipeline:
211
- worker_to_use.pipeline.disable_attention_slicing()
212
-
213
- ltx_manager_singleton = LtxPoolManager(device_ids=['cuda:2', 'cuda:3'])
 
1
  # ltx_manager_helpers.py
2
+ # ... (imports e outras classes) ...
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
 
4
  class LtxPoolManager:
5
+ # ... (__init__, _cleanup_worker) ...
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
  def generate_video_fragment(
7
  self,
8
  motion_prompt: str, conditioning_items_data: list,
 
13
  ):
14
  worker_to_use = None
15
  try:
16
+ # ... (lógica de seleção e limpeza de worker) ...
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
 
18
+ # ... (lógica de preparação de padding e conditioning_items) ...
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
 
 
 
 
20
  kwargs = {
21
  "prompt": motion_prompt,
22
  "negative_prompt": "blurry, distorted, bad quality, artifacts",
 
36
  "vae_per_channel_normalize": True,
37
  "mixed_precision": (worker_to_use.config.get("precision") == "mixed_precision"),
38
  "enhance_prompt": False,
 
39
  }
40
+
41
+ # --- CORREÇÃO AQUI ---
42
+ # Verifica se o config do modelo especifica uma lista de timesteps.
43
+ # Se sim, usa essa lista. Se não, usa o num_inference_steps da UI.
44
+ first_pass_config = worker_to_use.config.get("first_pass", {})
45
+ if "timesteps" in first_pass_config:
46
+ print("Usando timesteps customizados do arquivo de configuração para o modelo distilled.")
47
+ kwargs["timesteps"] = first_pass_config["timesteps"]
48
+ # Quando usamos timesteps customizados, o num_inference_steps é inferido
49
+ # a partir do tamanho da lista de timesteps.
50
+ kwargs["num_inference_steps"] = len(first_pass_config["timesteps"])
51
+ else:
52
+ # Comportamento antigo para modelos não-destilados
53
+ print(f"Usando num_inference_steps da UI: {num_inference_steps}")
54
+ kwargs["num_inference_steps"] = int(num_inference_steps)
55
+ # --- FIM DA CORREÇÃO ---
56
+
57
  progress(0.1, desc=f"[Câmera LTX em {worker_to_use.device}] Filmando Cena {current_fragment_index}...")
58
  result_tensor = worker_to_use.generate_video_fragment_internal(**kwargs).images
59
 
60
+ # ... (lógica de cropping e salvamento) ...
 
 
 
 
 
 
 
 
61
 
62
  return output_path, video_total_frames
63
 
64
  finally:
65
+ # ...