CARLEXsX commited on
Commit
25b989e
·
verified ·
1 Parent(s): 71d6826

Update ltx_manager_helpers.py

Browse files
Files changed (1) hide show
  1. ltx_manager_helpers.py +171 -12
ltx_manager_helpers.py CHANGED
@@ -1,8 +1,125 @@
1
  # ltx_manager_helpers.py
2
- # ... (imports e outras classes) ...
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
 
4
  class LtxPoolManager:
5
- # ... (__init__, _cleanup_worker) ...
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
  def generate_video_fragment(
7
  self,
8
  motion_prompt: str, conditioning_items_data: list,
@@ -13,10 +130,44 @@ class LtxPoolManager:
13
  ):
14
  worker_to_use = None
15
  try:
16
- # ... (lógica de seleção e limpeza de worker) ...
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
 
18
- # ... (lógica de preparação de padding e conditioning_items) ...
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
 
 
 
 
20
  kwargs = {
21
  "prompt": motion_prompt,
22
  "negative_prompt": "blurry, distorted, bad quality, artifacts",
@@ -38,28 +189,36 @@ class LtxPoolManager:
38
  "enhance_prompt": False,
39
  }
40
 
41
- # --- CORREÇÃO AQUI ---
42
  # Verifica se o config do modelo especifica uma lista de timesteps.
43
  # Se sim, usa essa lista. Se não, usa o num_inference_steps da UI.
44
  first_pass_config = worker_to_use.config.get("first_pass", {})
45
  if "timesteps" in first_pass_config:
46
  print("Usando timesteps customizados do arquivo de configuração para o modelo distilled.")
47
  kwargs["timesteps"] = first_pass_config["timesteps"]
48
- # Quando usamos timesteps customizados, o num_inference_steps é inferido
49
- # a partir do tamanho da lista de timesteps.
50
  kwargs["num_inference_steps"] = len(first_pass_config["timesteps"])
 
 
51
  else:
52
- # Comportamento antigo para modelos não-destilados
53
  print(f"Usando num_inference_steps da UI: {num_inference_steps}")
54
  kwargs["num_inference_steps"] = int(num_inference_steps)
55
- # --- FIM DA CORREÇÃO ---
56
-
57
  progress(0.1, desc=f"[Câmera LTX em {worker_to_use.device}] Filmando Cena {current_fragment_index}...")
58
  result_tensor = worker_to_use.generate_video_fragment_internal(**kwargs).images
59
 
60
- # ... (lógica de cropping e salvamento) ...
 
 
 
 
 
 
 
 
61
 
62
  return output_path, video_total_frames
63
 
64
  finally:
65
- # ...
 
 
 
 
1
  # ltx_manager_helpers.py
2
+ # Gerente de Pool de Workers LTX para revezamento assíncrono em múltiplas GPUs.
3
+ # Este arquivo é parte do projeto Euia-AducSdr e está sob a licença AGPL v3.
4
+ # Copyright (C) 4 de Agosto de 2025 Carlos Rodrigues dos Santos
5
+
6
+ import torch
7
+ import gc
8
+ import os
9
+ import yaml
10
+ import numpy as np
11
+ import imageio
12
+ from pathlib import Path
13
+ import huggingface_hub
14
+ import threading
15
+ from PIL import Image
16
+
17
+ # Importa as funções e classes necessárias do inference.py
18
+ from inference import (
19
+ create_ltx_video_pipeline,
20
+ create_latent_upsampler,
21
+ ConditioningItem,
22
+ calculate_padding,
23
+ prepare_conditioning
24
+ )
25
+ from ltx_video.pipelines.pipeline_ltx_video import LTXMultiScalePipeline
26
+
27
+ class LtxWorker:
28
+ """
29
+ Representa uma única instância do pipeline LTX, associada a uma GPU específica.
30
+ O pipeline é carregado na CPU por padrão e movido para a GPU sob demanda.
31
+ """
32
+ def __init__(self, device_id='cuda:0'):
33
+ self.device = torch.device(device_id if torch.cuda.is_available() else 'cpu')
34
+ print(f"LTX Worker: Inicializando para o dispositivo {self.device} (carregando na CPU)...")
35
+
36
+ config_file_path = "configs/ltxv-13b-0.9.8-distilled.yaml"
37
+ with open(config_file_path, "r") as file:
38
+ self.config = yaml.safe_load(file)
39
+
40
+ LTX_REPO = "Lightricks/LTX-Video"
41
+ models_dir = "downloaded_models_gradio"
42
+
43
+ model_actual_path = huggingface_hub.hf_hub_download(
44
+ repo_id=LTX_REPO,
45
+ filename=self.config["checkpoint_path"],
46
+ local_dir=models_dir,
47
+ local_dir_use_symlinks=False
48
+ )
49
+
50
+ self.pipeline = create_ltx_video_pipeline(
51
+ ckpt_path=model_actual_path,
52
+ precision=self.config["precision"],
53
+ text_encoder_model_name_or_path=self.config["text_encoder_model_name_or_path"],
54
+ sampler=self.config["sampler"],
55
+ device='cpu'
56
+ )
57
+
58
+ print(f"LTX Worker para {self.device}: Compilando o transformer (isso pode levar um momento)...")
59
+ self.pipeline.transformer.to(memory_format=torch.channels_last)
60
+ try:
61
+ self.pipeline.transformer = torch.compile(self.pipeline.transformer, mode="reduce-overhead", fullgraph=True)
62
+ print(f"LTX Worker para {self.device}: Transformer compilado com sucesso.")
63
+ except Exception as e:
64
+ print(f"AVISO: A compilação do Transformer falhou em {self.device}: {e}. Continuando sem compilação.")
65
+
66
+ self.latent_upsampler = None
67
+ if self.config.get("pipeline_type") == "multi-scale":
68
+ print(f"LTX Worker para {self.device}: Carregando Latent Upsampler (Multi-Scale)...")
69
+ upscaler_path = huggingface_hub.hf_hub_download(
70
+ repo_id=LTX_REPO,
71
+ filename=self.config["spatial_upscaler_model_path"],
72
+ local_dir=models_dir,
73
+ local_dir_use_symlinks=False
74
+ )
75
+ self.latent_upsampler = create_latent_upsampler(upscaler_path, 'cpu')
76
+
77
+ print(f"LTX Worker para {self.device} pronto na CPU.")
78
+
79
+ def to_gpu(self):
80
+ """Move o pipeline e o upsampler para a GPU designada."""
81
+ if self.device.type == 'cpu': return
82
+ print(f"LTX Worker: Movendo pipeline para {self.device}...")
83
+ self.pipeline.to(self.device)
84
+ if self.latent_upsampler:
85
+ print(f"LTX Worker: Movendo Latent Upsampler para {self.device}...")
86
+ self.latent_upsampler.to(self.device)
87
+ print(f"LTX Worker: Pipeline na GPU {self.device}.")
88
+
89
+ def to_cpu(self):
90
+ """Move o pipeline de volta para a CPU e limpa a memória da GPU."""
91
+ if self.device.type == 'cpu': return
92
+ print(f"LTX Worker: Descarregando pipeline da GPU {self.device}...")
93
+ self.pipeline.to('cpu')
94
+ if self.latent_upsampler:
95
+ self.latent_upsampler.to('cpu')
96
+ gc.collect()
97
+ if torch.cuda.is_available():
98
+ torch.cuda.empty_cache()
99
+ print(f"LTX Worker: GPU {self.device} limpa.")
100
+
101
+ def generate_video_fragment_internal(self, **kwargs):
102
+ """A lógica real da geração de vídeo, que espera estar na GPU."""
103
+ return self.pipeline(**kwargs)
104
 
105
  class LtxPoolManager:
106
+ """
107
+ Gerencia um pool de LtxWorkers, orquestrando um revezamento entre GPUs
108
+ para permitir que a limpeza de uma GPU ocorra em paralelo com a computação em outra.
109
+ """
110
+ def __init__(self, device_ids=['cuda:2', 'cuda:3']):
111
+ print(f"LTX POOL MANAGER: Criando workers para os dispositivos: {device_ids}")
112
+ self.workers = [LtxWorker(device_id) for device_id in device_ids]
113
+ self.current_worker_index = 0
114
+ self.lock = threading.Lock()
115
+ self.last_cleanup_thread = None
116
+
117
+ def _cleanup_worker(self, worker):
118
+ """Função alvo para a thread de limpeza."""
119
+ print(f"CLEANUP THREAD: Iniciando limpeza da GPU {worker.device} em background...")
120
+ worker.to_cpu()
121
+ print(f"CLEANUP THREAD: Limpeza da GPU {worker.device} concluída.")
122
+
123
  def generate_video_fragment(
124
  self,
125
  motion_prompt: str, conditioning_items_data: list,
 
130
  ):
131
  worker_to_use = None
132
  try:
133
+ with self.lock:
134
+ if self.last_cleanup_thread and self.last_cleanup_thread.is_alive():
135
+ print("LTX POOL MANAGER: Aguardando limpeza da GPU anterior...")
136
+ self.last_cleanup_thread.join()
137
+
138
+ worker_to_use = self.workers[self.current_worker_index]
139
+ previous_worker_index = (self.current_worker_index - 1 + len(self.workers)) % len(self.workers)
140
+ worker_to_cleanup = self.workers[previous_worker_index]
141
+
142
+ cleanup_thread = threading.Thread(target=self._cleanup_worker, args=(worker_to_cleanup,))
143
+ cleanup_thread.start()
144
+ self.last_cleanup_thread = cleanup_thread
145
+
146
+ worker_to_use.to_gpu()
147
+
148
+ self.current_worker_index = (self.current_worker_index + 1) % len(self.workers)
149
+
150
+ target_device = worker_to_use.device
151
 
152
+ if use_attention_slicing:
153
+ worker_to_use.pipeline.enable_attention_slicing()
154
+
155
+ media_paths = [item[0] for item in conditioning_items_data]
156
+ start_frames = [item[1] for item in conditioning_items_data]
157
+ strengths = [item[2] for item in conditioning_items_data]
158
+
159
+ padded_h, padded_w = ((height - 1) // 32 + 1) * 32, ((width - 1) // 32 + 1) * 32
160
+ padding_vals = calculate_padding(height, width, padded_h, padded_w)
161
+
162
+ conditioning_items = prepare_conditioning(
163
+ conditioning_media_paths=media_paths, conditioning_strengths=strengths,
164
+ conditioning_start_frames=start_frames, height=height, width=width,
165
+ num_frames=video_total_frames, padding=padding_vals, pipeline=worker_to_use.pipeline,
166
+ )
167
 
168
+ for item in conditioning_items:
169
+ item.media_item = item.media_item.to(target_device)
170
+
171
  kwargs = {
172
  "prompt": motion_prompt,
173
  "negative_prompt": "blurry, distorted, bad quality, artifacts",
 
189
  "enhance_prompt": False,
190
  }
191
 
 
192
  # Verifica se o config do modelo especifica uma lista de timesteps.
193
  # Se sim, usa essa lista. Se não, usa o num_inference_steps da UI.
194
  first_pass_config = worker_to_use.config.get("first_pass", {})
195
  if "timesteps" in first_pass_config:
196
  print("Usando timesteps customizados do arquivo de configuração para o modelo distilled.")
197
  kwargs["timesteps"] = first_pass_config["timesteps"]
 
 
198
  kwargs["num_inference_steps"] = len(first_pass_config["timesteps"])
199
+ # Para modelos distilled, a UI de steps é ignorada, mas outros params do config são usados
200
+ kwargs.update({k: v for k, v in first_pass_config.items() if k != "timesteps"})
201
  else:
 
202
  print(f"Usando num_inference_steps da UI: {num_inference_steps}")
203
  kwargs["num_inference_steps"] = int(num_inference_steps)
204
+
 
205
  progress(0.1, desc=f"[Câmera LTX em {worker_to_use.device}] Filmando Cena {current_fragment_index}...")
206
  result_tensor = worker_to_use.generate_video_fragment_internal(**kwargs).images
207
 
208
+ pad_l, pad_r, pad_t, pad_b = map(int, padding_vals)
209
+ slice_h = -pad_b if pad_b > 0 else None
210
+ slice_w = -pad_r if pad_r > 0 else None
211
+ cropped_tensor = result_tensor[:, :, :video_total_frames, pad_t:slice_h, pad_l:slice_w]
212
+ video_np = (cropped_tensor[0].permute(1, 2, 3, 0).cpu().float().numpy() * 255).astype(np.uint8)
213
+
214
+ with imageio.get_writer(output_path, fps=video_fps, codec='libx264', quality=8) as writer:
215
+ for frame in video_np:
216
+ writer.append_data(frame)
217
 
218
  return output_path, video_total_frames
219
 
220
  finally:
221
+ if use_attention_slicing and worker_to_use and worker_to_use.pipeline:
222
+ worker_to_use.pipeline.disable_attention_slicing()
223
+
224
+ ltx_manager_singleton = LtxPoolManager(device_ids=['cuda:2', 'cuda:3'])