diff --git a/README.md b/README.md index b4d828b8497c7bc7137a1fa17a90faafc661d56b..d012f8b356bb33e4f34cf1fa8d5813d7a7931814 100644 --- a/README.md +++ b/README.md @@ -1,14 +1,13 @@ --- -title: DreamO -emoji: 🐨 -colorFrom: purple -colorTo: yellow +title: 🇧🇷 LTX Video MegaUltraGigaFast 🇧🇷 +emoji: 🎥 +colorFrom: yellow +colorTo: pink sdk: gradio -sdk_version: 5.29.0 +sdk_version: 5.39.0 app_file: app.py pinned: false -license: apache-2.0 -short_description: A Unified Framework for Image Customization +short_description: ultra-fast video model, LTX 0.9.8 13B distilled --- -Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference +Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference \ No newline at end of file diff --git a/app-old.py b/app-old.py new file mode 100644 index 0000000000000000000000000000000000000000..0597d5294aaa93e8ff045276bf074c65067a221a --- /dev/null +++ b/app-old.py @@ -0,0 +1,215 @@ +# --- app.py (O Painel de Controle do Maestro - Produção em Lote com Diário de Bordo) --- +# By Carlex & Gemini + +# --- Ato 1: A Convocação da Orquestra (Importações) --- +import gradio as gr +import torch +import spaces +import os +import yaml +from PIL import Image +import shutil +import gc +import traceback +import subprocess +import math +import google.generativeai as genai +import numpy as np +import imageio +import tempfile +from pathlib import Path +from huggingface_hub import hf_hub_download +import json +from facexlib.utils.face_restoration_helper import FaceRestoreHelper +import huggingface_hub +import spaces +import argparse + +import spaces +import argparse + + +import cv2 + +from facexlib.utils.face_restoration_helper import FaceRestoreHelper +import huggingface_hub + + + +from dreamo.dreamo_pipeline import DreamOPipeline +from dreamo.utils import img2tensor, resize_numpy_image_area, tensor2img, resize_numpy_image_long +from tools import BEN2 + + +# --- Músicos Originais (Sua implementação) --- +from inference import create_ltx_video_pipeline, load_image_to_tensor_with_resize_and_crop, seed_everething, calculate_padding +from ltx_video.pipelines.pipeline_ltx_video import ConditioningItem + +# --- Ato 2: A Preparação do Palco (Configurações) --- +config_file_path = "configs/ltxv-13b-0.9.8-distilled.yaml" +with open(config_file_path, "r") as file: + PIPELINE_CONFIG_YAML = yaml.safe_load(file) + +# --- Constantes Globais --- +LTX_REPO = "Lightricks/LTX-Video" +models_dir = "downloaded_models_gradio_cpu_init" +Path(models_dir).mkdir(parents=True, exist_ok=True) +WORKSPACE_DIR = "aduc_workspace" +GEMINI_API_KEY = os.environ.get("GEMINI_API_KEY") + +# --- Carregamento de Modelos LTX na CPU --- +print("Baixando e criando pipelines LTX na CPU...") +distilled_model_actual_path = hf_hub_download(repo_id=LTX_REPO, filename=PIPELINE_CONFIG_YAML["checkpoint_path"], local_dir=models_dir, local_dir_use_symlinks=False) +pipeline_instance = create_ltx_video_pipeline(ckpt_path=distilled_model_actual_path, precision=PIPELINE_CONFIG_YAML["precision"], text_encoder_model_name_or_path=PIPELINE_CONFIG_YAML["text_encoder_model_name_or_path"], sampler=PIPELINE_CONFIG_YAML["sampler"], device="cpu") +print("Modelos LTX prontos.") + + +# --- Ato 3: As Partituras dos Músicos (Funções) --- + +def get_storyboard_from_director_v2(num_fragments: int, prompt: str, initial_image_path: str, progress=gr.Progress()): + progress(0.5, desc="[Diretor Gemini] Criando o storyboard completo...") + if not initial_image_path: raise 
gr.Error("Por favor, forneça uma imagem de referência inicial.") + if not GEMINI_API_KEY: raise gr.Error("Chave da API Gemini (GEMINI_API_KEY) não configurada!") + genai.configure(api_key=GEMINI_API_KEY) + try: + with open("prompts/director_storyboard_v2.txt", "r", encoding="utf-8") as f: template = f.read() + except FileNotFoundError: raise gr.Error("'prompts/director_storyboard_v2.txt' não encontrado!") + director_prompt = template.format(user_prompt=prompt, num_fragments=int(num_fragments)) + model = genai.GenerativeModel('gemini-2.0-flash') + img = Image.open(initial_image_path) + response = model.generate_content([director_prompt, img]) + try: + cleaned_response = response.text.strip().replace("```json", "").replace("```", "") + storyboard_data = json.loads(cleaned_response) + storyboard_list = storyboard_data.get("storyboard", []) + if not storyboard_list: raise gr.Error("A IA não retornou um storyboard válido.") + return storyboard_list + except (json.JSONDecodeError, KeyError, TypeError) as e: + raise gr.Error(f"O Diretor retornou uma resposta inesperada. Erro: {e}\nResposta Bruta: {response.text}") + +def run_ltx_animation(current_fragment_index, motion_prompt, input_frame_path, height, width, fps, seed, cfg, progress=gr.Progress()): + progress(0, desc=f"[Animador LTX] Aquecendo para a Cena {current_fragment_index}...") + target_device = "cuda"; output_path = os.path.join(WORKSPACE_DIR, f"fragment_{current_fragment_index}.mp4") + try: + pipeline_instance.to(target_device) + duration_fragment, target_frames_ideal = 3.0, 3.0 * fps + n_val = round((float(round(target_frames_ideal)) - 1.0) / 8.0); actual_num_frames = max(9, min(int(n_val * 8 + 1), 257)) + num_frames_padded = ((actual_num_frames - 2) // 8 + 1) * 8 + 1 + padded_h, padded_w = ((int(height) - 1) // 32 + 1) * 32, ((int(width) - 1) // 32 + 1) * 32 + padding_vals = calculate_padding(int(height), int(width), padded_h, padded_w) + timesteps = PIPELINE_CONFIG_YAML.get("first_pass", {}).get("timesteps") + kwargs = {"prompt": motion_prompt, "negative_prompt": "blurry, distorted", "height": padded_h, "width": padded_w, "num_frames": num_frames_padded, "frame_rate": int(fps), "generator": torch.Generator(device=target_device).manual_seed(int(seed) + current_fragment_index), "output_type": "pt", "guidance_scale": float(cfg), "timesteps": timesteps, "vae_per_channel_normalize": True, "decode_timestep": PIPELINE_CONFIG_YAML["decode_timestep"], "decode_noise_scale": PIPELINE_CONFIG_YAML["decode_noise_scale"], "stochastic_sampling": PIPELINE_CONFIG_YAML["stochastic_sampling"], "image_cond_noise_scale": 0.15, "is_video": True, "mixed_precision": (PIPELINE_CONFIG_YAML["precision"] == "mixed_precision"), "offload_to_cpu": False, "enhance_prompt": False} + media_tensor = load_image_to_tensor_with_resize_and_crop(input_frame_path, int(height), int(width)); media_tensor = torch.nn.functional.pad(media_tensor, padding_vals); kwargs["conditioning_items"] = [ConditioningItem(media_tensor.to(target_device), 0, 1.0)] + result_tensor = pipeline_instance(**kwargs).images + pad_l, pad_r, pad_t, pad_b = padding_vals; slice_h, slice_w = (-pad_b if pad_b > 0 else None), (-pad_r if pad_r > 0 else None) + cropped_tensor = result_tensor[:, :, :actual_num_frames, pad_t:slice_h, pad_l:slice_w]; video_np = (cropped_tensor[0].permute(1, 2, 3, 0).cpu().float().numpy() * 255).astype(np.uint8) + with imageio.get_writer(output_path, fps=int(fps), codec='libx264', quality=8) as writer: + for i, frame in enumerate(video_np): progress(i / len(video_np), 
desc=f"Renderizando frame {i+1}/{len(video_np)}..."); writer.append_data(frame) + return output_path + finally: + pipeline_instance.to("cpu"); gc.collect(); torch.cuda.empty_cache() + +def concatenate_masterpiece(fragment_paths: list, progress=gr.Progress()): + progress(0.5, desc="Montando a obra-prima final..."); list_file_path, final_output_path = os.path.join(WORKSPACE_DIR, "concat_list.txt"), os.path.join(WORKSPACE_DIR, "obra_prima_final.mp4") + with open(list_file_path, "w") as f: + for path in fragment_paths: f.write(f"file '{os.path.abspath(path)}'\n") + command = f"ffmpeg -y -f concat -safe 0 -i {list_file_path} -c copy {final_output_path}" + try: + subprocess.run(command, shell=True, check=True, capture_output=True, text=True); return final_output_path + except subprocess.CalledProcessError as e: + raise gr.Error(f"FFmpeg falhou ao unir os vídeos: {e.stderr}") + +def run_full_production(storyboard, ref_img_path, height, width, fps, seed, cfg): + if not storyboard: raise gr.Error("Nenhum roteiro para produzir.") + if not ref_img_path: raise gr.Error("Nenhuma imagem de referência definida.") + video_fragments, log_history = [], "" + for i, motion_prompt in enumerate(storyboard): + log_message = f"Iniciando produção da Cena {i+1}/{len(storyboard)}..." + log_history += log_message + "\n" + yield {production_log_output: gr.update(value=log_history)} + fragment_path = run_ltx_animation(i + 1, motion_prompt, ref_img_path, height, width, fps, seed, cfg, gr.Progress()) + video_fragments.append(fragment_path) + log_message = f"Cena {i+1} concluída e salva em {os.path.basename(fragment_path)}." + log_history += log_message + "\n" + yield {production_log_output: gr.update(value=log_history), fragment_gallery_output: gr.update(value=video_fragments), fragment_list_state: video_fragments, final_fragments_display: gr.update(value=video_fragments)} + log_history += "\nProdução de todas as cenas concluída!" 
+ yield {production_log_output: gr.update(value=log_history)} + +# --- Ato 4: A Apresentação (UI do Gradio) --- +with gr.Blocks(theme=gr.themes.Soft()) as demo: + gr.Markdown("# LTX Video - Storyboard em Vídeo (ADUC-SDR)\n*By Carlex & Gemini*") + + storyboard_state = gr.State([]) + reference_image_state = gr.State("") + fragment_list_state = gr.State([]) + + if os.path.exists(WORKSPACE_DIR): shutil.rmtree(WORKSPACE_DIR) + os.makedirs(WORKSPACE_DIR) + + with gr.Tabs(): + with gr.TabItem("ETAPA 1: O DIRETOR (Roteiro Visual)"): + with gr.Row(): + with gr.Column(): + num_fragments_input = gr.Slider(2, 10, 4, step=1, label="Número de Cenas (Fragmentos)") + prompt_input = gr.Textbox(label="Ideia Geral (Prompt)") + image_input = gr.Image(type="filepath", label="Imagem de Referência") + director_button = gr.Button("▶️ Gerar Roteiro Visual (Gemini)", variant="primary") + with gr.Column(): + storyboard_output = gr.JSON(label="Roteiro Visual Gerado (Storyboard)") + + with gr.TabItem("ETAPA 2: A PRODUÇÃO (Gerar Cenas em Vídeo)"): + with gr.Row(): + with gr.Column(): + storyboard_to_render = gr.JSON(label="Roteiro a ser Produzido") + animator_button = gr.Button("▶️ Produzir TODAS as Cenas (LTX)", variant="primary") + production_log_output = gr.Textbox(label="Diário de Bordo da Produção", lines=5, interactive=False, placeholder="Aguardando início da produção...") + with gr.Column(): + fragment_gallery_output = gr.Gallery(label="Cenas Produzidas (Fragmentos de Vídeo)", object_fit="contain", height="auto") + with gr.Row(): + height_slider = gr.Slider(256, 1024, 512, step=32, label="Altura") + width_slider = gr.Slider(256, 1024, 512, step=32, label="Largura") + with gr.Row(): + fps_slider = gr.Slider(8, 24, 15, step=1, label="FPS") + seed_number = gr.Number(42, label="Seed") + cfg_slider = gr.Slider(1.0, 10.0, 2.5, step=0.1, label="CFG") + + with gr.TabItem("ETAPA 3: PÓS-PRODUÇÃO"): + with gr.Row(): + with gr.Column(): + final_fragments_display = gr.JSON(label="Vídeos a Concatenar") + editor_button = gr.Button("▶️ Concatenar Tudo (FFmpeg)", variant="primary") + with gr.Column(): + final_video_output = gr.Video(label="A Obra-Prima Final") + + # --- Ato 5: A Regência (Lógica de Conexão dos Botões) --- + + def director_success(img_path, storyboard_json): + if not img_path: raise gr.Error("A imagem de referência é necessária.") + storyboard_list = storyboard_json if isinstance(storyboard_json, list) else storyboard_json.get("storyboard", []) + if not storyboard_list: raise gr.Error("O storyboard está vazio.") + return storyboard_list, img_path, gr.update(value=storyboard_json) + + director_button.click( + fn=get_storyboard_from_director_v2, + inputs=[num_fragments_input, prompt_input, image_input], + outputs=[storyboard_output] + ).success( + fn=director_success, + inputs=[image_input, storyboard_output], + outputs=[storyboard_state, reference_image_state, storyboard_to_render] + ) + + animator_button.click( + fn=run_full_production, + inputs=[storyboard_state, reference_image_state, height_slider, width_slider, fps_slider, seed_number, cfg_slider], + outputs=[production_log_output, fragment_gallery_output, fragment_list_state, final_fragments_display] + ) + + editor_button.click( + fn=concatenate_masterpiece, + inputs=[fragment_list_state], + outputs=[final_video_output] + ) + +if __name__ == "__main__": + demo.queue().launch(server_name="0.0.0.0", share=True) \ No newline at end of file diff --git a/app-v1.py b/app-v1.py new file mode 100644 index 
0000000000000000000000000000000000000000..6ebc39f9ea63a70e014155df8db0bd148f36b4fb --- /dev/null +++ b/app-v1.py @@ -0,0 +1,301 @@ +# --- app.py (O Painel de Controle do Maestro - Depuração Focada) --- +# By Carlex & Gemini & DreamO + +# ... (importações e inicializações inalteradas) ... +import gradio as gr +import torch +import os +import yaml +from PIL import Image +import shutil +import gc +import subprocess +import math +import google.generativeai as genai +import numpy as np +import imageio +from pathlib import Path +import huggingface_hub +import json + +from inference import create_ltx_video_pipeline, load_image_to_tensor_with_resize_and_crop, seed_everething, calculate_padding +from ltx_video.pipelines.pipeline_ltx_video import ConditioningItem +from dreamo_helpers import dreamo_generator_singleton + +# ... (configurações e constantes inalteradas) ... +config_file_path = "configs/ltxv-13b-0.9.8-distilled.yaml" +with open(config_file_path, "r") as file: + PIPELINE_CONFIG_YAML = yaml.safe_load(file) + +LTX_REPO = "Lightricks/LTX-Video" +models_dir = "downloaded_models_gradio_cpu_init" +Path(models_dir).mkdir(parents=True, exist_ok=True) +WORKSPACE_DIR = "aduc_workspace" +GEMINI_API_KEY = os.environ.get("GEMINI_API_KEY") + +VIDEO_WIDTH = 720 +VIDEO_HEIGHT = 720 +VIDEO_FPS = 24 +VIDEO_DURATION_SECONDS = 4 +VIDEO_TOTAL_FRAMES = VIDEO_DURATION_SECONDS * VIDEO_FPS + +print("Baixando e criando pipelines LTX na CPU...") +distilled_model_actual_path = huggingface_hub.hf_hub_download(repo_id=LTX_REPO, filename=PIPELINE_CONFIG_YAML["checkpoint_path"], local_dir=models_dir, local_dir_use_symlinks=False) +pipeline_instance = create_ltx_video_pipeline(ckpt_path=distilled_model_actual_path, precision=PIPELINE_CONFIG_YAML["precision"], text_encoder_model_name_or_path=PIPELINE_CONFIG_YAML["text_encoder_model_name_or_path"], sampler=PIPELINE_CONFIG_YAML["sampler"], device='cpu') +print("Modelos LTX prontos (na CPU).") + +# --- Ato 3: As Partituras dos Músicos (Funções) --- + +# ... (get_storyboard_from_director e run_keyframe_generation inalterados) ... +def get_storyboard_from_director(num_fragments: int, prompt: str, initial_image_path: str, progress=gr.Progress()): + progress(0.5, desc="[Diretor Gemini] Criando o storyboard...") + if not initial_image_path: raise gr.Error("Por favor, forneça uma imagem de referência inicial.") + if not GEMINI_API_KEY: raise gr.Error("Chave da API Gemini não configurada!") + genai.configure(api_key=GEMINI_API_KEY) + try: + script_dir = os.path.dirname(os.path.abspath(__file__)) + prompt_file_path = os.path.join(script_dir, "prompts", "director_storyboard_v2.txt") + with open(prompt_file_path, "r", encoding="utf-8") as f: template = f.read() + except FileNotFoundError: raise gr.Error(f"Arquivo de prompt não encontrado em '{prompt_file_path}'!") + director_prompt = template.format(user_prompt=prompt, num_fragments=int(num_fragments)) + model = genai.GenerativeModel('gemini-2.5-flash') + img = Image.open(initial_image_path) + response = model.generate_content([director_prompt, img]) + try: + cleaned_response = response.text.strip().replace("```json", "").replace("```", "") + if not cleaned_response: raise ValueError("A resposta do Gemini estava vazia após a limpeza.") + storyboard_data = json.loads(cleaned_response) + return storyboard_data.get("storyboard", []) + except (json.JSONDecodeError, ValueError) as e: + raise gr.Error(f"O Diretor retornou uma resposta inválida. Erro: {e}. 
Resposta Bruta: '{response.text}'") + +def run_keyframe_generation(storyboard, ref_img_path_1, ref_img_path_2, ref_task_1, ref_task_2): + if not storyboard: raise gr.Error("Nenhum roteiro para gerar imagens-chave.") + if not ref_img_path_1: raise gr.Error("A Referência 1 é obrigatória.") + + with Image.open(ref_img_path_1) as img: + width, height = img.size + width = (width // 32) * 32 + height = (height // 32) * 32 + + keyframe_paths, log_history = [], "" + try: + dreamo_generator_singleton.to_gpu() + for i, prompt in enumerate(storyboard): + log_message = f"Pintando Cena {i+1}/{len(storyboard)} com DreamO ({width}x{height})..." + log_history += log_message + "\n" + yield {keyframe_log_output: gr.update(value=log_history)} + output_path = os.path.join(WORKSPACE_DIR, f"keyframe_image_{i+1}.png") + image = dreamo_generator_singleton.generate_image_with_gpu_management( + ref_image1_np=np.array(Image.open(ref_img_path_1).convert("RGB")) if ref_img_path_1 else None, + ref_image2_np=np.array(Image.open(ref_img_path_2).convert("RGB")) if ref_img_path_2 else None, + ref_task1=ref_task_1, ref_task2=ref_task_2, + prompt=prompt, width=width, height=height + ) + image.save(output_path) + keyframe_paths.append(output_path) + log_message = f"Cena {i+1} pintada." + log_history += log_message + "\n" + yield {keyframe_log_output: gr.update(value=log_history), keyframe_gallery_output: gr.update(value=keyframe_paths), keyframe_images_state: keyframe_paths} + finally: + dreamo_generator_singleton.to_cpu() + + log_history += "\nPintura de todas as cenas concluída!" + yield {keyframe_log_output: gr.update(value=log_history)} + +def run_ltx_animation(current_fragment_index, motion_prompt, conditioning_items_data, width, height, seed, cfg, progress=gr.Progress()): + # ... 
(código inalterado) + progress(0, desc=f"[Animador LTX] Gerando Cena {current_fragment_index}...") + output_path = os.path.join(WORKSPACE_DIR, f"fragment_{current_fragment_index}.mp4") + target_device = 'cuda' if torch.cuda.is_available() else 'cpu' + try: + pipeline_instance.to(target_device) + conditioning_items = [] + for (path, start_frame, strength) in conditioning_items_data: + tensor = load_image_to_tensor_with_resize_and_crop(path, height, width) + conditioning_items.append(ConditioningItem(tensor.to(target_device), start_frame, strength)) + n_val = round((float(VIDEO_TOTAL_FRAMES) - 1.0) / 8.0) + actual_num_frames = int(n_val * 8 + 1) + padded_h, padded_w = ((height - 1) // 32 + 1) * 32, ((width - 1) // 32 + 1) * 32 + padding_vals = calculate_padding(height, width, padded_h, padded_w) + for cond_item in conditioning_items: cond_item.media_item = torch.nn.functional.pad(cond_item.media_item, padding_vals) + timesteps = PIPELINE_CONFIG_YAML.get("first_pass", {}).get("timesteps") + kwargs = {"prompt": motion_prompt, "negative_prompt": "blurry, distorted, bad quality, artifacts", "height": padded_h, "width": padded_w, "num_frames": actual_num_frames, "frame_rate": VIDEO_FPS, "generator": torch.Generator(device=target_device).manual_seed(int(seed) + current_fragment_index), "output_type": "pt", "guidance_scale": float(cfg), "timesteps": timesteps, "conditioning_items": conditioning_items, "vae_per_channel_normalize": True, "decode_timestep": PIPELINE_CONFIG_YAML["decode_timestep"], "decode_noise_scale": PIPELINE_CONFIG_YAML["decode_noise_scale"], "stochastic_sampling": PIPELINE_CONFIG_YAML["stochastic_sampling"], "image_cond_noise_scale": 0.15, "is_video": True, "mixed_precision": (PIPELINE_CONFIG_YAML["precision"] == "mixed_precision"), "offload_to_cpu": False, "enhance_prompt": False} + result_tensor = pipeline_instance(**kwargs).images + pad_l, pad_r, pad_t, pad_b = padding_vals; slice_h, slice_w = (-pad_b if pad_b > 0 else None), (-pad_r if pad_r > 0 else None) + cropped_tensor = result_tensor[:, :, :VIDEO_TOTAL_FRAMES, pad_t:slice_h, pad_l:slice_w]; + video_np = (cropped_tensor[0].permute(1, 2, 3, 0).cpu().float().numpy() * 255).astype(np.uint8) + with imageio.get_writer(output_path, fps=VIDEO_FPS, codec='libx264', quality=8) as writer: + for i, frame in enumerate(video_np): progress(i / len(video_np), desc=f"Renderizando frame {i+1}/{len(video_np)}..."); writer.append_data(frame) + return output_path + finally: + pipeline_instance.to('cpu'); gc.collect() + if torch.cuda.is_available(): torch.cuda.empty_cache() + +# <<<< FUNÇÃO DE PRODUÇÃO SIMPLIFICADA PARA DEPURAÇÃO >>>> +def run_full_video_production(storyboard, keyframe_image_paths, seed, cfg): + if not storyboard or not keyframe_image_paths: raise gr.Error("Roteiro e/ou imagens-chave estão faltando.") + if len(storyboard) != len(keyframe_image_paths): raise gr.Error("A contagem de prompts do roteiro e imagens-chave não coincide.") + + with Image.open(keyframe_image_paths[0]) as img: + width, height = img.size + + video_fragments, log_history = [], "" + num_keyframes = len(keyframe_image_paths) + + n_val = round((float(VIDEO_TOTAL_FRAMES) - 1.0) / 8.0) + actual_num_frames = int(n_val * 8 + 1) + end_frame_index = actual_num_frames - 1 + + for i in range(num_keyframes - 1): + # ... (lógica de interpolação inalterada) + motion_prompt = storyboard[i] + start_image_path = keyframe_image_paths[i] + end_image_path = keyframe_image_paths[i+1] + log_message = f"Preparando Cena de Interpolação {i+1}/{num_keyframes}..." 
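+        # Each fragment is conditioned on two keyframes: the current one pinned at frame 0 and the
+        # next one pinned at the last frame (end_frame_index), so the LTX pipeline interpolates
+        # between consecutive keyframes. The frame count must satisfy the model's 8*n + 1 rule:
+        # with VIDEO_TOTAL_FRAMES = 4 s * 24 fps = 96, n_val = round(95 / 8) = 12,
+        # actual_num_frames = 12 * 8 + 1 = 97 and end_frame_index = 96.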
+ log_history += log_message + "\n" + yield {video_production_log_output: gr.update(value=log_history), fragment_list_state: video_fragments} + conditioning_items_data = [(start_image_path, 0, 1.0), (end_image_path, end_frame_index, 1.0)] + log_message = f" -> De: {os.path.basename(start_image_path)} | Para: {os.path.basename(end_image_path)}" + log_history += log_message + "\n" + yield {video_production_log_output: gr.update(value=log_history), fragment_list_state: video_fragments} + fragment_path = run_ltx_animation(i + 1, motion_prompt, conditioning_items_data, width, height, seed, cfg) + video_fragments.append(fragment_path) + log_message = f"Cena {i+1} concluída." + log_history += log_message + "\n" + yield {video_production_log_output: gr.update(value=log_history), fragment_list_state: video_fragments} + + if num_keyframes > 0: + # ... (lógica da cena final inalterada) + last_scene_index = num_keyframes - 1 + last_motion_prompt = storyboard[last_scene_index] + last_image_path = keyframe_image_paths[last_scene_index] + log_message = f"Preparando Cena Final (Animação Livre) {num_keyframes}/{num_keyframes}..." + log_history += log_message + "\n" + yield {video_production_log_output: gr.update(value=log_history), fragment_list_state: video_fragments} + conditioning_items_data = [(last_image_path, 0, 1.0)] + log_message = f" -> Ponto de Partida: {os.path.basename(last_image_path)}" + log_history += log_message + "\n" + yield {video_production_log_output: gr.update(value=log_history), fragment_list_state: video_fragments} + fragment_path = run_ltx_animation(last_scene_index + 1, last_motion_prompt, conditioning_items_data, width, height, seed, cfg) + video_fragments.append(fragment_path) + log_message = f"Cena Final concluída." + log_history += log_message + "\n" + yield {video_production_log_output: gr.update(value=log_history), fragment_list_state: video_fragments} + + log_history += "\nProdução de todas as cenas de vídeo concluída!" + yield {video_production_log_output: gr.update(value=log_history), fragment_list_state: video_fragments} + +def concatenate_masterpiece(fragment_paths: list, progress=gr.Progress()): + # ... (código inalterado) + progress(0.5, desc="Montando a obra-prima final..."); list_file_path, final_output_path = os.path.join(WORKSPACE_DIR, "concat_list.txt"), os.path.join(WORKSPACE_DIR, "obra_prima_final.mp4") + with open(list_file_path, "w") as f: + for path in fragment_paths: f.write(f"file '{os.path.abspath(path)}'\n") + command = f"ffmpeg -y -f concat -safe 0 -i {list_file_path} -c copy {final_output_path}" + try: subprocess.run(command, shell=True, check=True, capture_output=True, text=True); return final_output_path + except subprocess.CalledProcessError as e: raise gr.Error(f"FFmpeg falhou ao unir os vídeos: {e.stderr}") + +with gr.Blocks(theme=gr.themes.Soft()) as demo: + # ... (UI inalterada) + gr.Markdown("# LTX Video - Storyboard em Vídeo (ADUC-SDR)\n*By Carlex & Gemini & DreamO*") + storyboard_state = gr.State([]) + keyframe_images_state = gr.State([]) + fragment_list_state = gr.State([]) + if os.path.exists(WORKSPACE_DIR): shutil.rmtree(WORKSPACE_DIR) + os.makedirs(WORKSPACE_DIR) + + with gr.Tabs(): + with gr.TabItem("ETAPA 1: O DIRETOR (Roteiro Visual)"): + # ... 
(UI inalterada) + with gr.Row(): + with gr.Column(): + num_fragments_input = gr.Slider(2, 10, 4, step=1, label="Número de Cenas") + prompt_input = gr.Textbox(label="Ideia Geral (Prompt)") + image_input = gr.Image(type="filepath", label="Imagem de Referência Principal") + director_button = gr.Button("▶️ 1. Gerar Roteiro Visual", variant="primary") + with gr.Column(): + storyboard_to_show = gr.JSON(label="Roteiro Gerado (para visualização)") + with gr.TabItem("ETAPA 2: O PINTOR (Imagens-Chave)"): + # ... (UI inalterada) + with gr.Row(): + with gr.Column(scale=2): + gr.Markdown("### Controles do Pintor (DreamO)") + with gr.Row(): + ref_image_1_input = gr.Image(label="Referência 1 (Principal)", type="filepath") + ref_image_2_input = gr.Image(label="Referência 2 (Opcional, para composição)", type="filepath") + with gr.Row(): + ref_task_1_input = gr.Dropdown(choices=["ip", "id", "style"], value="ip", label="Tarefa para Referência 1") + ref_task_2_input = gr.Dropdown(choices=["ip", "id", "style"], value="ip", label="Tarefa para Referência 2") + photographer_button = gr.Button("▶️ 2. Pintar Imagens-Chave", variant="primary") + keyframe_log_output = gr.Textbox(label="Diário de Bordo do Pintor", lines=5, interactive=False) + with gr.Column(scale=1): + keyframe_gallery_output = gr.Gallery(label="Imagens-Chave Pintadas", object_fit="contain", height="auto", type="filepath") + with gr.TabItem("ETAPA 3: A PRODUÇÃO (Gerar Cenas em Vídeo)"): + # ... (UI inalterada) + gr.Markdown(f"Gere o vídeo interpolando entre as imagens-chave. A resolução será a mesma da sua imagem de referência. Cada clipe terá **{VIDEO_DURATION_SECONDS} segundos a {VIDEO_FPS} FPS**.") + with gr.Row(): + with gr.Column(): + keyframes_to_render = gr.Gallery(label="Imagens-Chave para Animar", object_fit="contain", height="auto", interactive=False) + animator_button = gr.Button("▶️ 3. Produzir Cenas em Vídeo", variant="primary", interactive=False) + video_production_log_output = gr.Textbox(label="Diário de Bordo da Produção", lines=10, interactive=False) + with gr.Column(): + # <<<< REMOVIDO PARA DEPURAÇÃO >>>> + # fragment_gallery_output = gr.Gallery(label="Cenas Produzidas (Vídeos)", object_fit="contain", height="auto") + gr.Markdown("A galeria de vídeos foi desativada para depuração. Verifique o resultado na Etapa 4.") + with gr.Row(): + seed_number = gr.Number(42, label="Seed") + cfg_slider = gr.Slider(1.0, 10.0, 2.5, step=0.1, label="CFG") + with gr.TabItem("ETAPA 4: PÓS-PRODUÇÃO"): + # ... (UI inalterada) + with gr.Row(): + with gr.Column(): + editor_button = gr.Button("▶️ 4. Concatenar Vídeo Final", variant="primary") + final_fragments_display = gr.JSON(label="Fragmentos a Concatenar") + with gr.Column(): + final_video_output = gr.Video(label="A Obra-Prima Final") + + # --- Ato 5: A Regência (Lógica de Conexão dos Botões) --- + def director_success(storyboard_list, img_path): + # ... 
(lógica inalterada) + if not storyboard_list: raise gr.Error("O storyboard está vazio ou em formato inválido.") + return {storyboard_state: storyboard_list, storyboard_to_show: gr.update(value=storyboard_list), ref_image_1_input: gr.update(value=img_path)} + + director_button.click( + fn=get_storyboard_from_director, + inputs=[num_fragments_input, prompt_input, image_input], + outputs=[storyboard_state] + ).then( + fn=director_success, + inputs=[storyboard_state, image_input], + outputs=[storyboard_state, storyboard_to_show, ref_image_1_input] + ) + + photographer_button.click( + fn=run_keyframe_generation, + inputs=[storyboard_state, ref_image_1_input, ref_image_2_input, ref_task_1_input, ref_task_2_input], + outputs=[keyframe_log_output, keyframe_gallery_output, keyframe_images_state] + ).then( + lambda paths: {keyframes_to_render: gr.update(value=paths), animator_button: gr.update(interactive=True)}, + inputs=[keyframe_images_state], + outputs=[keyframes_to_render, animator_button] + ) + + # <<<< CHAMADA DE CLICK SIMPLIFICADA PARA DEPURAÇÃO >>>> + animator_button.click( + fn=run_full_video_production, + inputs=[storyboard_state, keyframe_images_state, seed_number, cfg_slider], + outputs=[video_production_log_output, fragment_list_state] + ).then( + lambda paths: gr.update(value=paths), + inputs=[fragment_list_state], + outputs=[final_fragments_display] + ) + + editor_button.click( + fn=concatenate_masterpiece, + inputs=[fragment_list_state], + outputs=[final_video_output] + ) + +if __name__ == "__main__": + demo.queue().launch(server_name="0.0.0.0", share=True) \ No newline at end of file diff --git a/app1.py b/app1.py new file mode 100644 index 0000000000000000000000000000000000000000..e6c19603d941d4ec0c15e12ef03c0479feb5e63a --- /dev/null +++ b/app1.py @@ -0,0 +1,437 @@ +# --- app.py (O Painel de Controle do Maestro - Versão Final Completa) --- +# By Carlex & Gemini & DreamO + +# --- Ato 1: A Convocação da Orquestra (Importações) --- +import gradio as gr +import torch +import os +import yaml +from PIL import Image +import shutil +import gc +import subprocess +import google.generativeai as genai +import numpy as np +import imageio +from pathlib import Path +import huggingface_hub +import json + +from inference import create_ltx_video_pipeline, load_image_to_tensor_with_resize_and_crop, calculate_padding, ConditioningItem +from dreamo_helpers import dreamo_generator_singleton +import ltx_video.pipelines.crf_compressor as crf_compressor + +# --- Ato 2: A Preparação do Palco (Configurações) --- +config_file_path = "configs/ltxv-13b-0.9.8-distilled.yaml" +with open(config_file_path, "r") as file: PIPELINE_CONFIG_YAML = yaml.safe_load(file) + +LTX_REPO = "Lightricks/LTX-Video" +models_dir = "downloaded_models_gradio_cpu_init" +Path(models_dir).mkdir(parents=True, exist_ok=True) +WORKSPACE_DIR = "aduc_workspace" +GEMINI_API_KEY = os.environ.get("GEMINI_API_KEY") + +VIDEO_FPS = 30 +VIDEO_DURATION_SECONDS = 3 +VIDEO_TOTAL_FRAMES = VIDEO_DURATION_SECONDS * VIDEO_FPS +CONVERGENCE_FRAMES = 8 +MAX_REFS = 5 # Definimos um máximo de 5 referências para a UI + +print("Baixando e criando pipelines LTX na CPU...") +distilled_model_actual_path = huggingface_hub.hf_hub_download(repo_id=LTX_REPO, filename=PIPELINE_CONFIG_YAML["checkpoint_path"], local_dir=models_dir, local_dir_use_symlinks=False) +pipeline_instance = create_ltx_video_pipeline(ckpt_path=distilled_model_actual_path, precision=PIPELINE_CONFIG_YAML["precision"], 
text_encoder_model_name_or_path=PIPELINE_CONFIG_YAML["text_encoder_model_name_or_path"], sampler=PIPELINE_CONFIG_YAML["sampler"], device='cpu') +print("Modelos LTX prontos (na CPU).") + + +# --- Ato 3: As Partituras dos Músicos (Funções) --- + +def get_static_scenes_storyboard(num_fragments: int, prompt: str, initial_image_path: str, progress=gr.Progress()): + progress(0.5, desc="[Fotógrafo Gemini] Descrevendo as cenas estáticas...") + if not initial_image_path: raise gr.Error("Por favor, forneça uma imagem de referência inicial.") + if not GEMINI_API_KEY: raise gr.Error("Chave da API Gemini não configurada!") + genai.configure(api_key=GEMINI_API_KEY) + + prompt_file = "prompts/photographer_prompt.txt" + try: + script_dir = os.path.dirname(os.path.abspath(__file__)) + prompt_file_path = os.path.join(script_dir, prompt_file) + with open(prompt_file_path, "r", encoding="utf-8") as f: template = f.read() + except FileNotFoundError: raise gr.Error(f"Arquivo de prompt '{prompt_file}' não encontrado!") + + director_prompt = template.format(user_prompt=prompt, num_fragments=int(num_fragments)) + model = genai.GenerativeModel('gemini-2.0-flash') + img = Image.open(initial_image_path) + response = model.generate_content([director_prompt, img]) + + try: + cleaned_response = response.text.strip().replace("```json", "").replace("```", "") + if not cleaned_response: raise ValueError("A resposta do Gemini estava vazia.") + storyboard_data = json.loads(cleaned_response) + return storyboard_data.get("scene_storyboard", []) + except (json.JSONDecodeError, ValueError) as e: + raise gr.Error(f"O Fotógrafo retornou uma resposta inválida. Erro: {e}. Resposta Bruta: '{response.text}'") + +def get_motion_storyboard(user_prompt: str, keyframe_image_paths: list, progress=gr.Progress()): + progress(0.5, desc="[Diretor Gemini] Criando o roteiro de movimento...") + if not keyframe_image_paths: raise gr.Error("Nenhuma imagem-chave fornecida para o diretor de cena.") + if not GEMINI_API_KEY: raise gr.Error("Chave da API Gemini não configurada!") + genai.configure(api_key=GEMINI_API_KEY) + + prompt_file = "prompts/director_motion_prompt.txt" + try: + script_dir = os.path.dirname(os.path.abspath(__file__)) + prompt_file_path = os.path.join(script_dir, prompt_file) + with open(prompt_file_path, "r", encoding="utf-8") as f: template = f.read() + except FileNotFoundError: raise gr.Error(f"Arquivo de prompt '{prompt_file}' não encontrado!") + + director_prompt = template.format(user_prompt=user_prompt, num_fragments=len(keyframe_image_paths)) + + model_contents = [director_prompt] + for img_path in keyframe_image_paths: + img = Image.open(img_path) + model_contents.append(img) + + model = genai.GenerativeModel('gemini-2.0-flash') + response = model.generate_content(model_contents) + + try: + cleaned_response = response.text.strip().replace("```json", "").replace("```", "") + if not cleaned_response: raise ValueError("A resposta do Gemini estava vazia.") + storyboard_data = json.loads(cleaned_response) + return storyboard_data.get("motion_storyboard", []) + except (json.JSONDecodeError, ValueError) as e: + raise gr.Error(f"O Diretor de Cena retornou uma resposta inválida. Erro: {e}. 
Resposta Bruta: '{response.text}'") + +def run_sequential_keyframe_generation(storyboard, initial_ref_image_path, *reference_args): + if not storyboard: raise gr.Error("Nenhum roteiro para gerar imagens-chave.") + if not initial_ref_image_path: raise gr.Error("A imagem de referência inicial é obrigatória.") + + ref_paths = reference_args[:MAX_REFS] + ref_tasks = reference_args[MAX_REFS:] + + with Image.open(initial_ref_image_path) as img: + width, height = img.size + width, height = (width // 32) * 32, (height // 32) * 32 + + keyframe_paths, log_history = [], "" + current_ref_image_path = initial_ref_image_path + + try: + dreamo_generator_singleton.to_gpu() + for i, prompt in enumerate(storyboard): + log_history += f"Pintando Cena Sequencial {i+1}/{len(storyboard)}...\n" + yield {keyframe_log_output: gr.update(value=log_history), keyframe_gallery_output: gr.update(value=keyframe_paths)} + + reference_items_for_dreamo = [] + + reference_items_for_dreamo.append({ + 'image_np': np.array(Image.open(current_ref_image_path).convert("RGB")), + 'task': ref_tasks[0] + }) + + for j in range(1, MAX_REFS): + if ref_paths[j]: + reference_items_for_dreamo.append({ + 'image_np': np.array(Image.open(ref_paths[j]).convert("RGB")), + 'task': ref_tasks[j] + }) + + output_path = os.path.join(WORKSPACE_DIR, f"keyframe_image_{i+1}.png") + image = dreamo_generator_singleton.generate_image_with_gpu_management( + reference_items=reference_items_for_dreamo, + prompt=prompt, + width=width, + height=height + ) + image.save(output_path) + keyframe_paths.append(output_path) + current_ref_image_path = output_path + + log_history += f"Cena {i+1} pintada. A próxima cena usará '{os.path.basename(output_path)}' como referência.\n" + yield { + keyframe_log_output: gr.update(value=log_history), + keyframe_gallery_output: gr.update(value=keyframe_paths), + keyframe_images_state: keyframe_paths, + ref_image_inputs[0]: gr.update(value=current_ref_image_path) + } + finally: + dreamo_generator_singleton.to_cpu() + log_history += "\nPintura sequencial de todas as cenas concluída!" + yield {keyframe_log_output: gr.update(value=log_history)} + +def extract_final_frames_video(input_video_path: str, output_video_path: str, num_frames: int): + if not os.path.exists(input_video_path): raise gr.Error(f"Erro Interno: Vídeo de entrada para extração não encontrado: {input_video_path}") + try: + command_probe = f"ffprobe -v error -select_streams v:0 -count_frames -show_entries stream=nb_read_frames -of default=noprint_wrappers=1:nokey=1 \"{input_video_path}\"" + result_probe = subprocess.run(command_probe, shell=True, check=True, capture_output=True, text=True) + total_frames = int(result_probe.stdout.strip()) + start_frame_index = total_frames - num_frames + if start_frame_index < 0: + print(f"Aviso: O vídeo tem menos de {num_frames} frames. 
Usando o vídeo inteiro como convergência.") + shutil.copyfile(input_video_path, output_video_path) + return output_video_path + command_extract = f"ffmpeg -y -i \"{input_video_path}\" -vf \"select='gte(n,{start_frame_index})'\" -c:v libx264 -preset ultrafast -an \"{output_video_path}\"" + subprocess.run(command_extract, shell=True, check=True, capture_output=True, text=True) + return output_video_path + except (subprocess.CalledProcessError, ValueError) as e: + error_message = f"FFmpeg/FFprobe falhou ao extrair os frames finais: {e}" + if hasattr(e, 'stderr'): error_message += f"\nDetalhes: {e.stderr}" + raise gr.Error(error_message) + +def load_conditioning_tensor(media_path: str, height: int, width: int) -> torch.Tensor: + if media_path.lower().endswith(('.png', '.jpg', '.jpeg')): + return load_image_to_tensor_with_resize_and_crop(media_path, height, width) + elif media_path.lower().endswith('.mp4'): + try: + with imageio.get_reader(media_path) as reader: + first_frame = reader.get_data(0) + image = Image.fromarray(first_frame).convert("RGB").resize((width, height)) + image = np.array(image) + frame_tensor = torch.from_numpy(image).float() + frame_tensor = crf_compressor.compress(frame_tensor / 255.0) * 255.0 + frame_tensor = frame_tensor.permute(2, 0, 1) + frame_tensor = (frame_tensor / 127.5) - 1.0 + return frame_tensor.unsqueeze(0).unsqueeze(2) + except Exception as e: + raise gr.Error(f"Falha ao ler o primeiro frame do vídeo de convergência '{media_path}': {e}") + else: + raise gr.Error(f"Formato de arquivo de condicionamento não suportado: {media_path}") + +def run_ltx_animation(current_fragment_index, motion_prompt, conditioning_items_data, width, height, seed, cfg, progress=gr.Progress()): + progress(0, desc=f"[Animador LTX] Gerando Cena {current_fragment_index}...") + output_path = os.path.join(WORKSPACE_DIR, f"fragment_{current_fragment_index}.mp4") + target_device = 'cuda' if torch.cuda.is_available() else 'cpu' + try: + pipeline_instance.to(target_device) + conditioning_items = [] + for (path, start_frame, strength) in conditioning_items_data: + tensor = load_conditioning_tensor(path, height, width) + conditioning_items.append(ConditioningItem(tensor.to(target_device), start_frame, strength)) + + n_val = round((float(VIDEO_TOTAL_FRAMES) - 1.0) / 8.0) + actual_num_frames = int(n_val * 8 + 1) + padded_h, padded_w = ((height - 1) // 32 + 1) * 32, ((width - 1) // 32 + 1) * 32 + padding_vals = calculate_padding(height, width, padded_h, padded_w) + for cond_item in conditioning_items: cond_item.media_item = torch.nn.functional.pad(cond_item.media_item, padding_vals) + timesteps = PIPELINE_CONFIG_YAML.get("first_pass", {}).get("timesteps") + kwargs = {"prompt": motion_prompt, "negative_prompt": "blurry, distorted, bad quality, artifacts", "height": padded_h, "width": padded_w, "num_frames": actual_num_frames, "frame_rate": VIDEO_FPS, "generator": torch.Generator(device=target_device).manual_seed(int(seed) + current_fragment_index), "output_type": "pt", "guidance_scale": float(cfg), "timesteps": timesteps, "conditioning_items": conditioning_items, "vae_per_channel_normalize": True, "decode_timestep": PIPELINE_CONFIG_YAML["decode_timestep"], "decode_noise_scale": PIPELINE_CONFIG_YAML["decode_noise_scale"], "stochastic_sampling": PIPELINE_CONFIG_YAML["stochastic_sampling"], "image_cond_noise_scale": 0.15, "is_video": True, "mixed_precision": (PIPELINE_CONFIG_YAML["precision"] == "mixed_precision"), "offload_to_cpu": False, "enhance_prompt": False} + result_tensor = 
pipeline_instance(**kwargs).images + pad_l, pad_r, pad_t, pad_b = padding_vals + slice_h, slice_w = (-pad_b if pad_b > 0 else None), (-pad_r if pad_r > 0 else None) + cropped_tensor = result_tensor[:, :, :VIDEO_TOTAL_FRAMES, pad_t:slice_h, pad_l:slice_w] + video_np = (cropped_tensor[0].permute(1, 2, 3, 0).cpu().float().numpy() * 255).astype(np.uint8) + with imageio.get_writer(output_path, fps=VIDEO_FPS, codec='libx264', quality=8) as writer: + for i, frame in enumerate(video_np): progress(i / len(video_np), desc=f"Renderizando frame {i+1}/{len(video_np)}..."); writer.append_data(frame) + return output_path + finally: + pipeline_instance.to('cpu'); gc.collect(); torch.cuda.empty_cache() + +def run_full_video_production(prompt_geral, keyframe_image_paths, seed, cfg): + if not keyframe_image_paths: raise gr.Error("Imagens-chave estão faltando.") + + log_history = "Iniciando Etapa 3: Geração do Roteiro de Movimento...\n" + yield {video_production_log_output: gr.update(value=log_history)} + motion_storyboard = get_motion_storyboard(prompt_geral, keyframe_image_paths) + if not motion_storyboard or len(motion_storyboard) != len(keyframe_image_paths): + raise gr.Error("Falha ao gerar o roteiro de movimento ou o número de prompts não corresponde ao número de imagens.") + log_history += "Roteiro de movimento gerado com sucesso.\n\nIniciando Etapa 4: Produção dos Vídeos com Convergência Física...\n" + yield {video_production_log_output: gr.update(value=log_history)} + + with Image.open(keyframe_image_paths[0]) as img: width, height = img.size + + video_fragments = [] + num_keyframes = len(keyframe_image_paths) + n_val = round((float(VIDEO_TOTAL_FRAMES) - 1.0) / 8.0) + actual_num_frames = int(n_val * 8 + 1) + end_frame_index = actual_num_frames - 1 + + previous_media_path = keyframe_image_paths[0] + + for i in range(num_keyframes): + current_motion_prompt = motion_storyboard[i] + + log_message = f"\n--- Preparando Fragmento {i+1}/{num_keyframes} ---\n" + log_message += f"Motor de partida (convergência): {os.path.basename(previous_media_path)}\n" + log_history += log_message + yield {video_production_log_output: gr.update(value=log_history)} + + start_media_path = previous_media_path + + if i < num_keyframes - 1: + end_image_path = keyframe_image_paths[i+1] + conditioning_items_data = [(start_media_path, 0, 1.0), (end_image_path, end_frame_index, 1.0)] + log_message = f"Ponto final (alvo): {os.path.basename(end_image_path)}\n" + else: + conditioning_items_data = [(start_media_path, 0, 1.0)] + log_message = "Animação final livre (sem ponto final definido).\n" + + log_history += log_message + yield {video_production_log_output: gr.update(value=log_history)} + + full_fragment_path = run_ltx_animation(i + 1, current_motion_prompt, conditioning_items_data, width, height, seed, cfg) + video_fragments.append(full_fragment_path) + + log_message = f"Fragmento {i+1} concluído: {os.path.basename(full_fragment_path)}\n" + log_history += log_message + yield { + video_production_log_output: gr.update(value=log_history), + fragment_gallery_output: gr.update(value=video_fragments), + fragment_list_state: video_fragments, + final_fragments_display: gr.update(value=video_fragments) + } + + if i < num_keyframes - 1: + convergence_video_path = os.path.join(WORKSPACE_DIR, f"convergence_clip_{i+1}.mp4") + log_message = f"Extraindo {CONVERGENCE_FRAMES} frames de convergência para a próxima etapa...\n" + log_history += log_message + yield {video_production_log_output: gr.update(value=log_history)} + 
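+            # "Physical convergence": keep only the last CONVERGENCE_FRAMES frames of the fragment
+            # that was just rendered. load_conditioning_tensor() later reads the first frame of this
+            # short clip, so the next fragment's frame-0 conditioning comes from a frame taken a few
+            # frames before the end of the previous one, smoothing the cut between consecutive clips.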
extract_final_frames_video(full_fragment_path, convergence_video_path, CONVERGENCE_FRAMES) + previous_media_path = convergence_video_path + + log_history += "\nProdução de todas as cenas de vídeo concluída!" + yield {video_production_log_output: gr.update(value=log_history)} + +def concatenate_masterpiece(fragment_paths: list, progress=gr.Progress()): + progress(0.5, desc="Montando a obra-prima final...") + list_file_path = os.path.join(WORKSPACE_DIR, "concat_list.txt") + final_output_path = os.path.join(WORKSPACE_DIR, "obra_prima_final.mp4") + with open(list_file_path, "w") as f: + for path in fragment_paths: f.write(f"file '{os.path.abspath(path)}'\n") + command = f"ffmpeg -y -f concat -safe 0 -i {list_file_path} -c copy {final_output_path}" + try: + subprocess.run(command, shell=True, check=True, capture_output=True, text=True) + return final_output_path + except subprocess.CalledProcessError as e: + raise gr.Error(f"FFmpeg falhou ao unir os vídeos: {e.stderr}") + +with gr.Blocks(theme=gr.themes.Soft()) as demo: + gr.Markdown("# LTX Video - Storyboard em Vídeo (ADUC-SDR)\n*By Carlex & Gemini & DreamO*") + + scene_storyboard_state = gr.State([]) + keyframe_images_state = gr.State([]) + fragment_list_state = gr.State([]) + prompt_geral_state = gr.State("") + + if os.path.exists(WORKSPACE_DIR): shutil.rmtree(WORKSPACE_DIR) + os.makedirs(WORKSPACE_DIR) + + with gr.Tabs(): + with gr.TabItem("ETAPA 1: O FOTÓGRAFO (Roteiro de Cenas)"): + with gr.Row(): + with gr.Column(): + num_fragments_input = gr.Slider(2, 10, 4, step=1, label="Número de Cenas") + prompt_input = gr.Textbox(label="Ideia Geral (Prompt)") + image_input = gr.Image(type="filepath", label="Imagem de Referência Principal") + director_button = gr.Button("▶️ 1. Gerar Roteiro de Cenas", variant="primary") + with gr.Column(): + storyboard_to_show = gr.JSON(label="Roteiro de Cenas Gerado") + + with gr.TabItem("ETAPA 2: O PINTOR (Imagens-Chave)"): + with gr.Row(): + with gr.Column(scale=2): + gr.Markdown("### Controles do Pintor (DreamO)\nUse os botões `+` e `-` para adicionar ou remover slots de referência opcionais (até 5 no total).") + + visible_references_state = gr.State(1) + ref_image_inputs = [] + ref_task_inputs = [] + + with gr.Blocks() as ref_blocks: + for i in range(MAX_REFS): + is_visible = i < 1 + label_prefix = f"Referência {i+1}" + if i == 0: + label_prefix += " (Sequencial)" + default_task = "style" + is_interactive = False + else: + label_prefix += " (Opcional, Fixa)" + default_task = "ip" + is_interactive = True + + with gr.Row(visible=is_visible) as ref_row: + img = gr.Image(label=label_prefix, type="filepath", interactive=is_interactive) + task = gr.Dropdown(choices=["ip", "id", "style"], value=default_task, label=f"Tarefa para Ref {i+1}") + ref_image_inputs.append(img) + ref_task_inputs.append(task) + + with gr.Row(): + add_ref_button = gr.Button("➕ Adicionar Referência") + remove_ref_button = gr.Button("➖ Remover Referência") + + photographer_button = gr.Button("▶️ 2. 
Pintar Imagens-Chave em Sequência", variant="primary") + keyframe_log_output = gr.Textbox(label="Diário de Bordo do Pintor", lines=5, interactive=False) + + with gr.Column(scale=1): + keyframe_gallery_output = gr.Gallery(label="Imagens-Chave Pintadas", object_fit="contain", height="auto", type="filepath") + + with gr.TabItem("ETAPA 3: PRODUÇÃO (Gerar Vídeos)"): + gr.Markdown("Nesta etapa, o sistema irá primeiro gerar o roteiro de movimento e depois animar os clipes, **usando o final de um clipe para dar partida no próximo**.") + with gr.Row(): + with gr.Column(): + keyframes_to_render = gr.Gallery(label="Imagens-Chave para Animar", object_fit="contain", height="auto", interactive=False) + animator_button = gr.Button("▶️ 3. Produzir Cenas em Vídeo", variant="primary", interactive=False) + video_production_log_output = gr.Textbox(label="Diário de Bordo da Produção", lines=10, interactive=False) + with gr.Column(): + fragment_gallery_output = gr.Gallery(label="Cenas Produzidas (Vídeos)", object_fit="contain", height="auto") + with gr.Row(): + seed_number = gr.Number(42, label="Seed") + cfg_slider = gr.Slider(1.0, 10.0, 2.5, step=0.1, label="CFG") + + with gr.TabItem("ETAPA 4: PÓS-PRODUÇÃO"): + with gr.Row(): + with gr.Column(): + editor_button = gr.Button("▶️ 4. Concatenar Vídeo Final", variant="primary") + final_fragments_display = gr.JSON(label="Fragmentos a Concatenar") + with gr.Column(): + final_video_output = gr.Video(label="A Obra-Prima Final") + + # --- Ato 5: A Regência (Lógica de Conexão dos Botões) --- + + def on_director_success(storyboard_list, img_path, prompt_geral): + if not storyboard_list: raise gr.Error("O storyboard está vazio ou em formato inválido.") + return storyboard_list, img_path, prompt_geral, gr.update(value=storyboard_list), gr.update(value=img_path) + + director_button.click( + fn=get_static_scenes_storyboard, + inputs=[num_fragments_input, prompt_input, image_input], + outputs=[scene_storyboard_state] + ).then( + fn=on_director_success, + inputs=[scene_storyboard_state, image_input, prompt_input], + outputs=[scene_storyboard_state, ref_image_inputs[0], prompt_geral_state, storyboard_to_show, ref_image_inputs[0]] + ) + + def update_reference_visibility(current_count, action): + if action == "add": new_count = min(MAX_REFS, current_count + 1) + else: new_count = max(1, current_count - 1) + updates = [gr.update(visible=(i < new_count)) for i in range(MAX_REFS)] + return [new_count] + updates + + all_ref_rows = [comp.parent for comp in ref_image_inputs] + add_ref_button.click(fn=update_reference_visibility, inputs=[visible_references_state, gr.State("add")], outputs=[visible_references_state] + all_ref_rows) + remove_ref_button.click(fn=update_reference_visibility, inputs=[visible_references_state, gr.State("remove")], outputs=[visible_references_state] + all_ref_rows) + + photographer_button.click( + fn=run_sequential_keyframe_generation, + inputs=[scene_storyboard_state, ref_image_inputs[0]] + ref_image_inputs + ref_task_inputs, + outputs=[keyframe_log_output, keyframe_gallery_output, keyframe_images_state, ref_image_inputs[0]] + ).then( + lambda paths: {keyframes_to_render: gr.update(value=paths), animator_button: gr.update(interactive=True)}, + inputs=[keyframe_images_state], + outputs=[keyframes_to_render, animator_button] + ) + + animator_button.click( + fn=run_full_video_production, + inputs=[prompt_geral_state, keyframe_images_state, seed_number, cfg_slider], + outputs=[video_production_log_output, fragment_gallery_output, fragment_list_state, 
final_fragments_display] + ) + + editor_button.click( + fn=concatenate_masterpiece, + inputs=[fragment_list_state], + outputs=[final_video_output] + ) + +if __name__ == "__main__": + demo.queue().launch(server_name="0.0.0.0", share=True) \ No newline at end of file diff --git a/appv3.py b/appv3.py new file mode 100644 index 0000000000000000000000000000000000000000..e35f6f477e53fe6f8bc4ba07b339fe100ba073aa --- /dev/null +++ b/appv3.py @@ -0,0 +1,295 @@ +# --- app.py (ADUC-SDR v4.0 - Correção de Compilação e Estado) --- +# By Carlex & Gemini & DreamO + +# --- Ato 1: A Convocação da Orquestra (Importações) --- +import gradio as gr +import torch +import os +import yaml +from PIL import Image +import shutil +import gc +import subprocess +import math +import google.generativeai as genai +import numpy as np +import imageio +from pathlib import Path +import huggingface_hub +import json + +from inference import create_ltx_video_pipeline, load_image_to_tensor_with_resize_and_crop, ConditioningItem, calculate_padding +from dreamo_helpers import dreamo_generator_singleton + +# --- Ato 2: A Preparação do Palco (Configurações) --- +config_file_path = "configs/ltxv-13b-0.9.8-distilled.yaml" +with open(config_file_path, "r") as file: + PIPELINE_CONFIG_YAML = yaml.safe_load(file) + +LTX_REPO = "Lightricks/LTX-Video" +models_dir = "downloaded_models_gradio_cpu_init" +Path(models_dir).mkdir(parents=True, exist_ok=True) +WORKSPACE_DIR = "aduc_workspace" +GEMINI_API_KEY = os.environ.get("GEMINI_API_KEY") + +VIDEO_FPS = 30 +VIDEO_DURATION_SECONDS = 3 +VIDEO_TOTAL_FRAMES = VIDEO_DURATION_SECONDS * VIDEO_FPS + +print("Baixando e criando pipelines LTX na CPU...") +distilled_model_actual_path = huggingface_hub.hf_hub_download(repo_id=LTX_REPO, filename=PIPELINE_CONFIG_YAML["checkpoint_path"], local_dir=models_dir, local_dir_use_symlinks=False) +pipeline_instance_original = create_ltx_video_pipeline(ckpt_path=distilled_model_actual_path, precision=PIPELINE_CONFIG_YAML["precision"], text_encoder_model_name_or_path=PIPELINE_CONFIG_YAML["text_encoder_model_name_or_path"], sampler=PIPELINE_CONFIG_YAML["sampler"], device='cpu') +print("Modelos LTX prontos (na CPU).") + +# <<< CORREÇÃO: A variável global `pipeline_instance` será o objeto que usamos. >>> +pipeline_instance = pipeline_instance_original + +if torch.cuda.is_available(): + print("Compilando o modelo LTX para otimização de desempenho (torch.compile)...") + try: + # Reatribui a variável global com a versão compilada + pipeline_instance = torch.compile(pipeline_instance_original, mode="reduce-overhead", fullgraph=True) + print("Modelo compilado com sucesso.") + except Exception as e: + print(f"Falha ao compilar o modelo, usando a versão não compilada. 
Erro: {e}") + pipeline_instance = pipeline_instance_original + + +# --- Ato 3: As Partituras dos Músicos (Funções) --- + +def get_next_scene_prompt(user_prompt: str, prompt_history_str: str, previous_image_path: str): + genai.configure(api_key=GEMINI_API_KEY) + script_dir = os.path.dirname(os.path.abspath(__file__)) + prompt_file_path = os.path.join(script_dir, "prompts", "photographer_sequential_prompt.txt") + with open(prompt_file_path, "r", encoding="utf-8") as f: template = f.read() + + model_prompt = template.format(user_prompt=user_prompt, prompt_history=prompt_history_str) + img = Image.open(previous_image_path) + model = genai.GenerativeModel('gemini-2.0-flash') + response = model.generate_content([model_prompt, img]) + + try: + cleaned_response = response.text.strip().replace("```json", "").replace("```", "") + data = json.loads(cleaned_response) + return data.get("next_scene_prompt") + except Exception as e: + raise gr.Error(f"Fotógrafo Sequencial falhou: {e}. Resposta: {response.text}") + +def get_motion_prompt_for_pair(user_prompt: str, start_image_path: str, end_image_path: str): + genai.configure(api_key=GEMINI_API_KEY) + script_dir = os.path.dirname(os.path.abspath(__file__)) + prompt_file_path = os.path.join(script_dir, "prompts", "director_sequential_prompt.txt") + with open(prompt_file_path, "r", encoding="utf-8") as f: template = f.read() + + model_prompt = template.format(user_prompt=user_prompt) + img1 = Image.open(start_image_path) + img2 = Image.open(end_image_path) + model = genai.GenerativeModel('gemini-2.0-flash') + response = model.generate_content([model_prompt, img1, img2]) + + try: + cleaned_response = response.text.strip().replace("```json", "").replace("```", "") + data = json.loads(cleaned_response) + return data.get("motion_prompt") + except Exception as e: + raise gr.Error(f"Diretor Sequencial falhou: {e}. 
Resposta: {response.text}") + +def run_ltx_animation(current_fragment_index, motion_prompt, conditioning_items_data, width, height, seed, cfg, progress=gr.Progress()): + progress(0, desc=f"[Animador LTX] Gerando Cena {current_fragment_index}..."); + output_path = os.path.join(WORKSPACE_DIR, f"fragment_{current_fragment_index}.mp4"); + target_device = 'cuda' if torch.cuda.is_available() else 'cpu' + + result_tensor = None + try: + pipeline_instance.to(target_device) + + conditioning_items = [] + for (path, start_frame, strength) in conditioning_items_data: + tensor = load_image_to_tensor_with_resize_and_crop(path, height, width) + conditioning_items.append(ConditioningItem(tensor.to(target_device), start_frame, strength)) + + n_val = round((float(VIDEO_TOTAL_FRAMES) - 1.0) / 8.0) + actual_num_frames = int(n_val * 8 + 1) + padded_h, padded_w = ((height - 1) // 32 + 1) * 32, ((width - 1) // 32 + 1) * 32 + padding_vals = calculate_padding(height, width, padded_h, padded_w) + for cond_item in conditioning_items: cond_item.media_item = torch.nn.functional.pad(cond_item.media_item, padding_vals) + + first_pass_config = PIPELINE_CONFIG_YAML.get("first_pass", {}) + + kwargs = { + "prompt": motion_prompt, "negative_prompt": "blurry, distorted, bad quality, artifacts", + "height": padded_h, "width": padded_w, "num_frames": actual_num_frames, "frame_rate": VIDEO_FPS, + "generator": torch.Generator(device=target_device).manual_seed(int(seed) + current_fragment_index), + "output_type": "pt", "guidance_scale": float(cfg), "timesteps": first_pass_config.get("timesteps"), + "stg_scale": first_pass_config.get("stg_scale"), "rescaling_scale": first_pass_config.get("rescaling_scale"), + "skip_block_list": first_pass_config.get("skip_block_list"), "conditioning_items": conditioning_items, + "decode_timestep": PIPELINE_CONFIG_YAML.get("decode_timestep"), "decode_noise_scale": PIPELINE_CONFIG_YAML.get("decode_noise_scale"), + "stochastic_sampling": PIPELINE_CONFIG_YAML.get("stochastic_sampling"), "image_cond_noise_scale": 0.15, + "is_video": True, "vae_per_channel_normalize": True, + "mixed_precision": (PIPELINE_CONFIG_YAML.get("precision") == "mixed_precision"), "offload_to_cpu": False, "enhance_prompt": False + } + + result_tensor = pipeline_instance(**kwargs).images + + pad_l, pad_r, pad_t, pad_b = map(int, padding_vals) + slice_h = -pad_b if pad_b > 0 else None + slice_w = -pad_r if pad_r > 0 else None + + cropped_tensor = result_tensor[:, :, :VIDEO_TOTAL_FRAMES, pad_t:slice_h, pad_l:slice_w] + video_np = (cropped_tensor[0].permute(1, 2, 3, 0).cpu().float().numpy() * 255).astype(np.uint8) + + with imageio.get_writer(output_path, fps=VIDEO_FPS, codec='libx264', quality=8) as writer: + for i, frame in enumerate(video_np): + progress(i / len(video_np), desc=f"Renderizando frame {i+1}/{len(video_np)}...") + writer.append_data(frame) + return output_path + finally: + pipeline_instance.to('cpu') + + if result_tensor is not None: del result_tensor + gc.collect() + if torch.cuda.is_available(): + torch.cuda.empty_cache() + print("Memória do Animador LTX liberada.") + +def concatenate_masterpiece(fragment_paths: list, progress=gr.Progress()): + if not fragment_paths: return None + progress(0.5, desc="Montando a obra-prima final..."); + list_file_path = os.path.join(WORKSPACE_DIR, "concat_list.txt") + final_output_path = os.path.join(WORKSPACE_DIR, "obra_prima_final.mp4") + with open(list_file_path, "w") as f: + for path in fragment_paths: f.write(f"file '{os.path.abspath(path)}'\n") + command = f"ffmpeg -y -f concat 
-safe 0 -i {list_file_path} -c copy {final_output_path}" + try: + subprocess.run(command, shell=True, check=True, capture_output=True, text=True) + return final_output_path + except subprocess.CalledProcessError as e: + raise gr.Error(f"FFmpeg falhou ao unir os vídeos: {e.stderr}") + +def editor_magic(video_path: str, fragment_index: int): + print(f"--- [ADUC-SDR] Editor (FFmpeg) trabalhando no Fragmento {fragment_index}... ---") + output_image_path = os.path.join(WORKSPACE_DIR, f"last_frame_frag_{fragment_index}.jpg") + + if not video_path or not os.path.exists(video_path): + raise gr.Error(f"Erro Interno: O vídeo do fragmento {fragment_index} não foi encontrado para extrair o frame.") + + try: + command_probe = f"ffprobe -v error -count_frames -select_streams v:0 -show_entries stream=nb_read_frames -of default=noprint_wrappers=1:nokey=1 \"{video_path}\"" + result_probe = subprocess.run(command_probe, shell=True, check=True, capture_output=True, text=True) + total_frames = int(result_probe.stdout.strip()) + last_frame_index = total_frames - 1 + + if last_frame_index < 0: + raise gr.Error("FFprobe retornou um número de frames inválido.") + + command_extract = f"ffmpeg -y -i \"{video_path}\" -vf \"select='eq(n,{last_frame_index})'\" -vsync vfr -frames:v 1 \"{output_image_path}\"" + subprocess.run(command_extract, shell=True, check=True, capture_output=True, text=True) + + print(f"Último frame ({last_frame_index}) extraído com sucesso para: {output_image_path}") + return output_image_path + except (subprocess.CalledProcessError, ValueError) as e: + error_message = f"FFmpeg/FFprobe falhou ao extrair último frame: {e}" + if hasattr(e, 'stderr'): + error_message += f"\nDetalhes: {e.stderr}" + raise gr.Error(error_message) + +def run_sequential_production(num_fragments, user_prompt, ref_image_path, seed, cfg, progress=gr.Progress()): + if not ref_image_path: raise gr.Error("Por favor, forneça uma imagem de referência.") + + video_fragments = [] + log_history = "Iniciando Produção Sequencial com Memória Contextual...\n" + + prompt_history = [] + image_anterior_path = ref_image_path + + for i in range(int(num_fragments)): + progress(i / num_fragments, desc=f"Gerando Fragmento {i+1}/{num_fragments}") + log_history += f"\n--- FRAGMENTO {i+1} ---\n" + yield log_history, None, image_anterior_path, None + + log_history += "Fotógrafo (Gemini) criando prompt da próxima cena (com memória)...\n" + yield log_history, None, image_anterior_path, None + + prompt_history_str = "\n".join([f"- Cena {idx+1}: {p}" for idx, p in enumerate(prompt_history)]) + if not prompt_history_str: + prompt_history_str = "Esta é a primeira cena." 
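+        # Contextual memory: the Gemini "photographer" receives the user's overall idea,
+        # the history of all previous scene prompts and the last keyframe image, so the
+        # prompt it writes for the next scene stays consistent with what was already generated.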
+ + prompt_proxima_cena = get_next_scene_prompt(user_prompt, prompt_history_str, image_anterior_path) + prompt_history.append(prompt_proxima_cena) + + log_history += f"Pintor (DreamO) renderizando a próxima cena: '{prompt_proxima_cena}'...\n" + yield log_history, None, image_anterior_path, None + + image_atual_path = os.path.join(WORKSPACE_DIR, f"keyframe_{i+1}.png") + with Image.open(image_anterior_path) as img: width, height = img.size + width, height = (width // 32) * 32, (height // 32) * 32 + + dreamo_generator_singleton.to_gpu() + try: + image_atual = dreamo_generator_singleton.generate_image_with_gpu_management( + ref_image1_np=np.array(Image.open(image_anterior_path).convert("RGB")), ref_task1="style", + ref_image2_np=np.array(Image.open(image_anterior_path).convert("RGB")), ref_task2="ip", + prompt=prompt_proxima_cena, width=width, height=height + ) + image_atual.save(image_atual_path) + log_history += "Nova imagem de keyframe gerada.\n" + yield log_history, None, image_anterior_path, image_atual_path + finally: + dreamo_generator_singleton.to_cpu() + + log_history += "Diretor de Cena (Gemini) criando prompt de movimento...\n" + yield log_history, None, image_anterior_path, image_atual_path + prompt_movimento = get_motion_prompt_for_pair(user_prompt, image_anterior_path, image_atual_path) + + log_history += f"Animador (LTX) gerando vídeo: '{prompt_movimento}'...\n" + yield log_history, None, image_anterior_path, image_atual_path + + n_val = round((float(VIDEO_TOTAL_FRAMES) - 1.0) / 8.0) + actual_num_frames = int(n_val * 8 + 1) + end_frame_index = actual_num_frames - 1 + conditioning_items_data = [(image_anterior_path, 0, 1.0), (image_atual_path, end_frame_index, 1.0)] + + fragment_path = run_ltx_animation(i + 1, prompt_movimento, conditioning_items_data, width, height, seed, cfg, progress) + video_fragments.append(fragment_path) + + log_history += "Editor (FFmpeg) extraindo último frame para continuidade...\n" + yield log_history, None, image_anterior_path, image_atual_path + image_anterior_path = editor_magic(fragment_path, i + 1) + + log_history += "\nConcatenando vídeo final...\n" + yield log_history, None, None, None + final_video_path = concatenate_masterpiece(video_fragments, progress) + + log_history += "\nProdução Concluída! Vídeo final pronto." 
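+    # Final yield: hand the production log and the concatenated video to the UI outputs
+    # (log, final video) and clear the two keyframe preview components.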
+ yield log_history, final_video_path, None, None + +# --- Ato 4: A Interface com o Mundo (Gradio UI) --- +with gr.Blocks(theme=gr.themes.Soft()) as demo: + gr.Markdown("# LTX Video - ADUC-SDR v4.0 (Compilação Corrigida)\n*By Carlex & Gemini & DreamO*") + + if os.path.exists(WORKSPACE_DIR): shutil.rmtree(WORKSPACE_DIR) + os.makedirs(WORKSPACE_DIR) + + with gr.Row(): + with gr.Column(scale=1): + num_fragments_input = gr.Slider(1, 10, 4, step=1, label="Número de Fragmentos a Gerar") + prompt_input = gr.Textbox(label="Ideia Geral (Prompt)") + image_input = gr.Image(type="filepath", label="Imagem de Referência Inicial") + seed_number = gr.Number(42, label="Seed") + cfg_slider = gr.Slider(1.0, 10.0, 2.5, step=0.1, label="CFG") + run_button = gr.Button("▶️ Gerar Vídeo Completo", variant="primary") + with gr.Column(scale=2): + with gr.Row(): + start_keyframe_display = gr.Image(label="Keyframe Inicial da Animação", interactive=False) + end_keyframe_display = gr.Image(label="Keyframe Final da Animação", interactive=False) + log_output = gr.Textbox(label="Diário de Bordo da Produção", lines=10, interactive=False) + video_output = gr.Video(label="Vídeo Final") + + run_button.click( + fn=run_sequential_production, + inputs=[num_fragments_input, prompt_input, image_input, seed_number, cfg_slider], + outputs=[log_output, video_output, start_keyframe_display, end_keyframe_display] + ) + +if __name__ == "__main__": + demo.queue().launch(server_name="0.0.0.0", share=True) \ No newline at end of file diff --git a/configs/ltxv-13b-0.9.7-dev.yaml b/configs/ltxv-13b-0.9.7-dev.yaml new file mode 100644 index 0000000000000000000000000000000000000000..ae548253526c1de5804bb430407850573305cd14 --- /dev/null +++ b/configs/ltxv-13b-0.9.7-dev.yaml @@ -0,0 +1,34 @@ +pipeline_type: multi-scale +checkpoint_path: "ltxv-13b-0.9.7-dev.safetensors" +downscale_factor: 0.6666666 +spatial_upscaler_model_path: "ltxv-spatial-upscaler-0.9.7.safetensors" +stg_mode: "attention_values" # options: "attention_values", "attention_skip", "residual", "transformer_block" +decode_timestep: 0.05 +decode_noise_scale: 0.025 +text_encoder_model_name_or_path: "PixArt-alpha/PixArt-XL-2-1024-MS" +precision: "bfloat16" +sampler: "from_checkpoint" # options: "uniform", "linear-quadratic", "from_checkpoint" +prompt_enhancement_words_threshold: 120 +prompt_enhancer_image_caption_model_name_or_path: "MiaoshouAI/Florence-2-large-PromptGen-v2.0" +prompt_enhancer_llm_model_name_or_path: "unsloth/Llama-3.2-3B-Instruct" +stochastic_sampling: false + +first_pass: + guidance_scale: [1, 1, 6, 8, 6, 1, 1] + stg_scale: [0, 0, 4, 4, 4, 2, 1] + rescaling_scale: [1, 1, 0.5, 0.5, 1, 1, 1] + guidance_timesteps: [1.0, 0.996, 0.9933, 0.9850, 0.9767, 0.9008, 0.6180] + skip_block_list: [[], [11, 25, 35, 39], [22, 35, 39], [28], [28], [28], [28]] + num_inference_steps: 30 + skip_final_inference_steps: 3 + cfg_star_rescale: true + +second_pass: + guidance_scale: [1] + stg_scale: [1] + rescaling_scale: [1] + guidance_timesteps: [1.0] + skip_block_list: [27] + num_inference_steps: 30 + skip_initial_inference_steps: 17 + cfg_star_rescale: true \ No newline at end of file diff --git a/configs/ltxv-13b-0.9.7-distilled.yaml b/configs/ltxv-13b-0.9.7-distilled.yaml new file mode 100644 index 0000000000000000000000000000000000000000..9df17bb001b39d6d12c7013cb823c44b85d28aea --- /dev/null +++ b/configs/ltxv-13b-0.9.7-distilled.yaml @@ -0,0 +1,28 @@ +pipeline_type: multi-scale +checkpoint_path: "ltxv-13b-0.9.7-distilled.safetensors" +downscale_factor: 0.6666666 
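+# Multi-scale rendering: the first pass is generated at a reduced resolution given by
+# downscale_factor; the spatial upscaler below then upsamples the latents before the
+# second, refining pass.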
+spatial_upscaler_model_path: "ltxv-spatial-upscaler-0.9.7.safetensors" +stg_mode: "attention_values" # options: "attention_values", "attention_skip", "residual", "transformer_block" +decode_timestep: 0.05 +decode_noise_scale: 0.025 +text_encoder_model_name_or_path: "PixArt-alpha/PixArt-XL-2-1024-MS" +precision: "bfloat16" +sampler: "from_checkpoint" # options: "uniform", "linear-quadratic", "from_checkpoint" +prompt_enhancement_words_threshold: 120 +prompt_enhancer_image_caption_model_name_or_path: "MiaoshouAI/Florence-2-large-PromptGen-v2.0" +prompt_enhancer_llm_model_name_or_path: "unsloth/Llama-3.2-3B-Instruct" +stochastic_sampling: false + +first_pass: + timesteps: [1.0000, 0.9937, 0.9875, 0.9812, 0.9750, 0.9094, 0.7250] + guidance_scale: 1 + stg_scale: 0 + rescaling_scale: 1 + skip_block_list: [42] + +second_pass: + timesteps: [0.9094, 0.7250, 0.4219] + guidance_scale: 1 + stg_scale: 0 + rescaling_scale: 1 + skip_block_list: [42] diff --git a/configs/ltxv-13b-0.9.8-dev-fp8.yaml b/configs/ltxv-13b-0.9.8-dev-fp8.yaml new file mode 100644 index 0000000000000000000000000000000000000000..76b25f1373061a873a3134d471b927b66c37aa54 --- /dev/null +++ b/configs/ltxv-13b-0.9.8-dev-fp8.yaml @@ -0,0 +1,34 @@ +pipeline_type: multi-scale +checkpoint_path: "ltxv-13b-0.9.8-dev-fp8.safetensors" +downscale_factor: 0.6666666 +spatial_upscaler_model_path: "ltxv-spatial-upscaler-0.9.8.safetensors" +stg_mode: "attention_values" # options: "attention_values", "attention_skip", "residual", "transformer_block" +decode_timestep: 0.05 +decode_noise_scale: 0.025 +text_encoder_model_name_or_path: "PixArt-alpha/PixArt-XL-2-1024-MS" +precision: "float8_e4m3fn" # options: "float8_e4m3fn", "bfloat16", "mixed_precision" +sampler: "from_checkpoint" # options: "uniform", "linear-quadratic", "from_checkpoint" +prompt_enhancement_words_threshold: 120 +prompt_enhancer_image_caption_model_name_or_path: "MiaoshouAI/Florence-2-large-PromptGen-v2.0" +prompt_enhancer_llm_model_name_or_path: "unsloth/Llama-3.2-3B-Instruct" +stochastic_sampling: false + +first_pass: + guidance_scale: [1, 1, 6, 8, 6, 1, 1] + stg_scale: [0, 0, 4, 4, 4, 2, 1] + rescaling_scale: [1, 1, 0.5, 0.5, 1, 1, 1] + guidance_timesteps: [1.0, 0.996, 0.9933, 0.9850, 0.9767, 0.9008, 0.6180] + skip_block_list: [[], [11, 25, 35, 39], [22, 35, 39], [28], [28], [28], [28]] + num_inference_steps: 30 + skip_final_inference_steps: 3 + cfg_star_rescale: true + +second_pass: + guidance_scale: [1] + stg_scale: [1] + rescaling_scale: [1] + guidance_timesteps: [1.0] + skip_block_list: [27] + num_inference_steps: 30 + skip_initial_inference_steps: 17 + cfg_star_rescale: true diff --git a/configs/ltxv-13b-0.9.8-dev.yaml b/configs/ltxv-13b-0.9.8-dev.yaml new file mode 100644 index 0000000000000000000000000000000000000000..0c22e9e5b3704146d521e7c60a841c043373c66e --- /dev/null +++ b/configs/ltxv-13b-0.9.8-dev.yaml @@ -0,0 +1,34 @@ +pipeline_type: multi-scale +checkpoint_path: "ltxv-13b-0.9.8-dev.safetensors" +downscale_factor: 0.6666666 +spatial_upscaler_model_path: "ltxv-spatial-upscaler-0.9.8.safetensors" +stg_mode: "attention_values" # options: "attention_values", "attention_skip", "residual", "transformer_block" +decode_timestep: 0.05 +decode_noise_scale: 0.025 +text_encoder_model_name_or_path: "PixArt-alpha/PixArt-XL-2-1024-MS" +precision: "bfloat16" +sampler: "from_checkpoint" # options: "uniform", "linear-quadratic", "from_checkpoint" +prompt_enhancement_words_threshold: 120 +prompt_enhancer_image_caption_model_name_or_path: "MiaoshouAI/Florence-2-large-PromptGen-v2.0" 
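+# Prompt enhancement (the Florence-2 captioner above plus the LLM below) is applied only
+# when the prompt has fewer words than prompt_enhancement_words_threshold; longer prompts
+# are passed through unchanged (see inference.py).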
+prompt_enhancer_llm_model_name_or_path: "unsloth/Llama-3.2-3B-Instruct" +stochastic_sampling: false + +first_pass: + guidance_scale: [1, 1, 6, 8, 6, 1, 1] + stg_scale: [0, 0, 4, 4, 4, 2, 1] + rescaling_scale: [1, 1, 0.5, 0.5, 1, 1, 1] + guidance_timesteps: [1.0, 0.996, 0.9933, 0.9850, 0.9767, 0.9008, 0.6180] + skip_block_list: [[], [11, 25, 35, 39], [22, 35, 39], [28], [28], [28], [28]] + num_inference_steps: 30 + skip_final_inference_steps: 3 + cfg_star_rescale: true + +second_pass: + guidance_scale: [1] + stg_scale: [1] + rescaling_scale: [1] + guidance_timesteps: [1.0] + skip_block_list: [27] + num_inference_steps: 30 + skip_initial_inference_steps: 17 + cfg_star_rescale: true \ No newline at end of file diff --git a/configs/ltxv-13b-0.9.8-distilled-fp8.yaml b/configs/ltxv-13b-0.9.8-distilled-fp8.yaml new file mode 100644 index 0000000000000000000000000000000000000000..444718bacbaa698c6b3df9cff6c89c9a2f95923c --- /dev/null +++ b/configs/ltxv-13b-0.9.8-distilled-fp8.yaml @@ -0,0 +1,29 @@ +pipeline_type: multi-scale +checkpoint_path: "ltxv-13b-0.9.8-distilled-fp8.safetensors" +downscale_factor: 0.6666666 +spatial_upscaler_model_path: "ltxv-spatial-upscaler-0.9.8.safetensors" +stg_mode: "attention_values" # options: "attention_values", "attention_skip", "residual", "transformer_block" +decode_timestep: 0.05 +decode_noise_scale: 0.025 +text_encoder_model_name_or_path: "PixArt-alpha/PixArt-XL-2-1024-MS" +precision: "float8_e4m3fn" # options: "float8_e4m3fn", "bfloat16", "mixed_precision" +sampler: "from_checkpoint" # options: "uniform", "linear-quadratic", "from_checkpoint" +prompt_enhancement_words_threshold: 120 +prompt_enhancer_image_caption_model_name_or_path: "MiaoshouAI/Florence-2-large-PromptGen-v2.0" +prompt_enhancer_llm_model_name_or_path: "unsloth/Llama-3.2-3B-Instruct" +stochastic_sampling: false + +first_pass: + timesteps: [1.0000, 0.9937, 0.9875, 0.9812, 0.9750, 0.9094, 0.7250] + guidance_scale: 1 + stg_scale: 0 + rescaling_scale: 1 + skip_block_list: [42] + +second_pass: + timesteps: [0.9094, 0.7250, 0.4219] + guidance_scale: 1 + stg_scale: 0 + rescaling_scale: 1 + skip_block_list: [42] + tone_map_compression_ratio: 0.6 diff --git a/configs/ltxv-13b-0.9.8-distilled.yaml b/configs/ltxv-13b-0.9.8-distilled.yaml new file mode 100644 index 0000000000000000000000000000000000000000..a1ac7239f3c3ecf0a8e4e03c3a1415a8b257dbf0 --- /dev/null +++ b/configs/ltxv-13b-0.9.8-distilled.yaml @@ -0,0 +1,29 @@ +pipeline_type: multi-scale +checkpoint_path: "ltxv-13b-0.9.8-distilled.safetensors" +downscale_factor: 0.6666666 +spatial_upscaler_model_path: "ltxv-spatial-upscaler-0.9.8.safetensors" +stg_mode: "attention_values" # options: "attention_values", "attention_skip", "residual", "transformer_block" +decode_timestep: 0.05 +decode_noise_scale: 0.025 +text_encoder_model_name_or_path: "PixArt-alpha/PixArt-XL-2-1024-MS" +precision: "bfloat16" +sampler: "from_checkpoint" # options: "uniform", "linear-quadratic", "from_checkpoint" +prompt_enhancement_words_threshold: 120 +prompt_enhancer_image_caption_model_name_or_path: "MiaoshouAI/Florence-2-large-PromptGen-v2.0" +prompt_enhancer_llm_model_name_or_path: "unsloth/Llama-3.2-3B-Instruct" +stochastic_sampling: false + +first_pass: + timesteps: [1.0000, 0.9937, 0.9875, 0.9812, 0.9750, 0.9094, 0.7250] + guidance_scale: 1 + stg_scale: 0 + rescaling_scale: 1 + skip_block_list: [42] + +second_pass: + timesteps: [0.9094, 0.7250, 0.4219] + guidance_scale: 1 + stg_scale: 0 + rescaling_scale: 1 + skip_block_list: [42] + tone_map_compression_ratio: 0.6 diff 
--git a/configs/ltxv-2b-0.9.1.yaml b/configs/ltxv-2b-0.9.1.yaml new file mode 100644 index 0000000000000000000000000000000000000000..6e888de3fb5ff258cd4caf52453eb707a3941761 --- /dev/null +++ b/configs/ltxv-2b-0.9.1.yaml @@ -0,0 +1,17 @@ +pipeline_type: base +checkpoint_path: "ltx-video-2b-v0.9.1.safetensors" +guidance_scale: 3 +stg_scale: 1 +rescaling_scale: 0.7 +skip_block_list: [19] +num_inference_steps: 40 +stg_mode: "attention_values" # options: "attention_values", "attention_skip", "residual", "transformer_block" +decode_timestep: 0.05 +decode_noise_scale: 0.025 +text_encoder_model_name_or_path: "PixArt-alpha/PixArt-XL-2-1024-MS" +precision: "bfloat16" +sampler: "from_checkpoint" # options: "uniform", "linear-quadratic", "from_checkpoint" +prompt_enhancement_words_threshold: 120 +prompt_enhancer_image_caption_model_name_or_path: "MiaoshouAI/Florence-2-large-PromptGen-v2.0" +prompt_enhancer_llm_model_name_or_path: "unsloth/Llama-3.2-3B-Instruct" +stochastic_sampling: false \ No newline at end of file diff --git a/configs/ltxv-2b-0.9.5.yaml b/configs/ltxv-2b-0.9.5.yaml new file mode 100644 index 0000000000000000000000000000000000000000..5998c6040bdbc3b4b0f6838bb7b61b58d0b58b5d --- /dev/null +++ b/configs/ltxv-2b-0.9.5.yaml @@ -0,0 +1,17 @@ +pipeline_type: base +checkpoint_path: "ltx-video-2b-v0.9.5.safetensors" +guidance_scale: 3 +stg_scale: 1 +rescaling_scale: 0.7 +skip_block_list: [19] +num_inference_steps: 40 +stg_mode: "attention_values" # options: "attention_values", "attention_skip", "residual", "transformer_block" +decode_timestep: 0.05 +decode_noise_scale: 0.025 +text_encoder_model_name_or_path: "PixArt-alpha/PixArt-XL-2-1024-MS" +precision: "bfloat16" +sampler: "from_checkpoint" # options: "uniform", "linear-quadratic", "from_checkpoint" +prompt_enhancement_words_threshold: 120 +prompt_enhancer_image_caption_model_name_or_path: "MiaoshouAI/Florence-2-large-PromptGen-v2.0" +prompt_enhancer_llm_model_name_or_path: "unsloth/Llama-3.2-3B-Instruct" +stochastic_sampling: false \ No newline at end of file diff --git a/configs/ltxv-2b-0.9.6-dev.yaml b/configs/ltxv-2b-0.9.6-dev.yaml new file mode 100644 index 0000000000000000000000000000000000000000..487f99708e0672dd17b5bd78424f25261163f7dc --- /dev/null +++ b/configs/ltxv-2b-0.9.6-dev.yaml @@ -0,0 +1,17 @@ +pipeline_type: base +checkpoint_path: "ltxv-2b-0.9.6-dev-04-25.safetensors" +guidance_scale: 3 +stg_scale: 1 +rescaling_scale: 0.7 +skip_block_list: [19] +num_inference_steps: 40 +stg_mode: "attention_values" # options: "attention_values", "attention_skip", "residual", "transformer_block" +decode_timestep: 0.05 +decode_noise_scale: 0.025 +text_encoder_model_name_or_path: "PixArt-alpha/PixArt-XL-2-1024-MS" +precision: "bfloat16" +sampler: "from_checkpoint" # options: "uniform", "linear-quadratic", "from_checkpoint" +prompt_enhancement_words_threshold: 120 +prompt_enhancer_image_caption_model_name_or_path: "MiaoshouAI/Florence-2-large-PromptGen-v2.0" +prompt_enhancer_llm_model_name_or_path: "unsloth/Llama-3.2-3B-Instruct" +stochastic_sampling: false \ No newline at end of file diff --git a/configs/ltxv-2b-0.9.6-distilled.yaml b/configs/ltxv-2b-0.9.6-distilled.yaml new file mode 100644 index 0000000000000000000000000000000000000000..328d9291613f16ba191cb56f97340f3bfa4d341d --- /dev/null +++ b/configs/ltxv-2b-0.9.6-distilled.yaml @@ -0,0 +1,16 @@ +pipeline_type: base +checkpoint_path: "ltxv-2b-0.9.6-distilled-04-25.safetensors" +guidance_scale: 1 +stg_scale: 0 +rescaling_scale: 1 +num_inference_steps: 8 +stg_mode: "attention_values" # 
options: "attention_values", "attention_skip", "residual", "transformer_block" +decode_timestep: 0.05 +decode_noise_scale: 0.025 +text_encoder_model_name_or_path: "PixArt-alpha/PixArt-XL-2-1024-MS" +precision: "bfloat16" +sampler: "from_checkpoint" # options: "uniform", "linear-quadratic", "from_checkpoint" +prompt_enhancement_words_threshold: 120 +prompt_enhancer_image_caption_model_name_or_path: "MiaoshouAI/Florence-2-large-PromptGen-v2.0" +prompt_enhancer_llm_model_name_or_path: "unsloth/Llama-3.2-3B-Instruct" +stochastic_sampling: true \ No newline at end of file diff --git a/configs/ltxv-2b-0.9.8-distilled-fp8.yaml b/configs/ltxv-2b-0.9.8-distilled-fp8.yaml new file mode 100644 index 0000000000000000000000000000000000000000..c02b2057cb2050ea8f277697a3d741ce1ed03403 --- /dev/null +++ b/configs/ltxv-2b-0.9.8-distilled-fp8.yaml @@ -0,0 +1,28 @@ +pipeline_type: multi-scale +checkpoint_path: "ltxv-2b-0.9.8-distilled-fp8.safetensors" +downscale_factor: 0.6666666 +spatial_upscaler_model_path: "ltxv-spatial-upscaler-0.9.8.safetensors" +stg_mode: "attention_values" # options: "attention_values", "attention_skip", "residual", "transformer_block" +decode_timestep: 0.05 +decode_noise_scale: 0.025 +text_encoder_model_name_or_path: "PixArt-alpha/PixArt-XL-2-1024-MS" +precision: "float8_e4m3fn" # options: "float8_e4m3fn", "bfloat16", "mixed_precision" +sampler: "from_checkpoint" # options: "uniform", "linear-quadratic", "from_checkpoint" +prompt_enhancement_words_threshold: 120 +prompt_enhancer_image_caption_model_name_or_path: "MiaoshouAI/Florence-2-large-PromptGen-v2.0" +prompt_enhancer_llm_model_name_or_path: "unsloth/Llama-3.2-3B-Instruct" +stochastic_sampling: false + +first_pass: + timesteps: [1.0000, 0.9937, 0.9875, 0.9812, 0.9750, 0.9094, 0.7250] + guidance_scale: 1 + stg_scale: 0 + rescaling_scale: 1 + skip_block_list: [42] + +second_pass: + timesteps: [0.9094, 0.7250, 0.4219] + guidance_scale: 1 + stg_scale: 0 + rescaling_scale: 1 + skip_block_list: [42] diff --git a/configs/ltxv-2b-0.9.8-distilled.yaml b/configs/ltxv-2b-0.9.8-distilled.yaml new file mode 100644 index 0000000000000000000000000000000000000000..9e24b0eb46b7113e2fe52b3d86d8f0eb4adae8de --- /dev/null +++ b/configs/ltxv-2b-0.9.8-distilled.yaml @@ -0,0 +1,28 @@ +pipeline_type: multi-scale +checkpoint_path: "ltxv-2b-0.9.8-distilled.safetensors" +downscale_factor: 0.6666666 +spatial_upscaler_model_path: "ltxv-spatial-upscaler-0.9.8.safetensors" +stg_mode: "attention_values" # options: "attention_values", "attention_skip", "residual", "transformer_block" +decode_timestep: 0.05 +decode_noise_scale: 0.025 +text_encoder_model_name_or_path: "PixArt-alpha/PixArt-XL-2-1024-MS" +precision: "bfloat16" +sampler: "from_checkpoint" # options: "uniform", "linear-quadratic", "from_checkpoint" +prompt_enhancement_words_threshold: 120 +prompt_enhancer_image_caption_model_name_or_path: "MiaoshouAI/Florence-2-large-PromptGen-v2.0" +prompt_enhancer_llm_model_name_or_path: "unsloth/Llama-3.2-3B-Instruct" +stochastic_sampling: false + +first_pass: + timesteps: [1.0000, 0.9937, 0.9875, 0.9812, 0.9750, 0.9094, 0.7250] + guidance_scale: 1 + stg_scale: 0 + rescaling_scale: 1 + skip_block_list: [42] + +second_pass: + timesteps: [0.9094, 0.7250, 0.4219] + guidance_scale: 1 + stg_scale: 0 + rescaling_scale: 1 + skip_block_list: [42] diff --git a/configs/ltxv-2b-0.9.yaml b/configs/ltxv-2b-0.9.yaml new file mode 100644 index 0000000000000000000000000000000000000000..f501ca62c24085192cebe10c87261fba38c930bc --- /dev/null +++ b/configs/ltxv-2b-0.9.yaml @@ -0,0 
+1,17 @@ +pipeline_type: base +checkpoint_path: "ltx-video-2b-v0.9.safetensors" +guidance_scale: 3 +stg_scale: 1 +rescaling_scale: 0.7 +skip_block_list: [19] +num_inference_steps: 40 +stg_mode: "attention_values" # options: "attention_values", "attention_skip", "residual", "transformer_block" +decode_timestep: 0.05 +decode_noise_scale: 0.025 +text_encoder_model_name_or_path: "PixArt-alpha/PixArt-XL-2-1024-MS" +precision: "bfloat16" +sampler: "from_checkpoint" # options: "uniform", "linear-quadratic", "from_checkpoint" +prompt_enhancement_words_threshold: 120 +prompt_enhancer_image_caption_model_name_or_path: "MiaoshouAI/Florence-2-large-PromptGen-v2.0" +prompt_enhancer_llm_model_name_or_path: "unsloth/Llama-3.2-3B-Instruct" +stochastic_sampling: false \ No newline at end of file diff --git a/dreamo_helpers.py b/dreamo_helpers.py new file mode 100644 index 0000000000000000000000000000000000000000..b9cf6d6881cd5d86a506efd2b1823cf67b7b0b6b --- /dev/null +++ b/dreamo_helpers.py @@ -0,0 +1,123 @@ +# dreamo_helpers.py +# Módulo de serviço para o DreamO, com gestão de memória e aceitando uma lista dinâmica de referências. + +import os +import cv2 +import torch +import numpy as np +from PIL import Image +import huggingface_hub +import gc +from facexlib.utils.face_restoration_helper import FaceRestoreHelper +from torchvision.transforms.functional import normalize +from dreamo.dreamo_pipeline import DreamOPipeline +from dreamo.utils import img2tensor, tensor2img +from tools import BEN2 + +class Generator: + def __init__(self): + self.cpu_device = torch.device('cpu') + self.gpu_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + + print("Carregando modelos DreamO para a CPU...") + model_root = 'black-forest-labs/FLUX.1-dev' + self.dreamo_pipeline = DreamOPipeline.from_pretrained(model_root, torch_dtype=torch.bfloat16) + self.dreamo_pipeline.load_dreamo_model(self.cpu_device, use_turbo=True) + + self.bg_rm_model = BEN2.BEN_Base().to(self.cpu_device).eval() + huggingface_hub.hf_hub_download(repo_id='PramaLLC/BEN2', filename='BEN2_Base.pth', local_dir='models') + self.bg_rm_model.loadcheckpoints('models/BEN2_Base.pth') + + self.face_helper = FaceRestoreHelper( + upscale_factor=1, face_size=512, crop_ratio=(1, 1), + det_model='retinaface_resnet50', save_ext='png', device=self.cpu_device, + ) + print("Modelos DreamO prontos (na CPU).") + + def to_gpu(self): + if self.gpu_device.type == 'cpu': return + print("Movendo modelos DreamO para a GPU...") + self.dreamo_pipeline.to(self.gpu_device) + self.bg_rm_model.to(self.gpu_device) + self.face_helper.device = self.gpu_device + self.dreamo_pipeline.t5_embedding.to(self.gpu_device) + self.dreamo_pipeline.task_embedding.to(self.gpu_device) + self.dreamo_pipeline.idx_embedding.to(self.gpu_device) + if hasattr(self.face_helper, 'face_parse'): self.face_helper.face_parse.to(self.gpu_device) + if hasattr(self.face_helper, 'face_det'): self.face_helper.face_det.to(self.gpu_device) + print("Modelos DreamO na GPU.") + + def to_cpu(self): + if self.gpu_device.type == 'cpu': return + print("Descarregando modelos DreamO da GPU...") + self.dreamo_pipeline.to(self.cpu_device) + self.bg_rm_model.to(self.cpu_device) + self.face_helper.device = self.cpu_device + self.dreamo_pipeline.t5_embedding.to(self.cpu_device) + self.dreamo_pipeline.task_embedding.to(self.cpu_device) + self.dreamo_pipeline.idx_embedding.to(self.cpu_device) + if hasattr(self.face_helper, 'face_det'): self.face_helper.face_det.to(self.cpu_device) + if hasattr(self.face_helper, 
'face_parse'): self.face_helper.face_parse.to(self.cpu_device) + gc.collect() + if torch.cuda.is_available(): torch.cuda.empty_cache() + + @torch.inference_mode() + # <<<<< MODIFICAÇÃO PRINCIPAL: Aceita uma lista de dicionários de referência >>>>> + def generate_image_with_gpu_management(self, reference_items, prompt, width, height): + ref_conds = [] + + for idx, item in enumerate(reference_items): + ref_image_np = item.get('image_np') + ref_task = item.get('task') + + if ref_image_np is not None: + if ref_task == "id": + ref_image = self.get_align_face(ref_image_np) + elif ref_task != "style": + ref_image = self.bg_rm_model.inference(Image.fromarray(ref_image_np)) + else: # Style usa a imagem original + ref_image = ref_image_np + + ref_image_tensor = img2tensor(np.array(ref_image), bgr2rgb=False).unsqueeze(0) / 255.0 + ref_image_tensor = (2 * ref_image_tensor - 1.0).to(self.gpu_device, dtype=torch.bfloat16) + + # O modelo DreamO espera o índice começando em 1 + ref_conds.append({'img': ref_image_tensor, 'task': ref_task, 'idx': idx + 1}) + + image = self.dreamo_pipeline( + prompt=prompt, + width=width, + height=height, + num_inference_steps=12, + guidance_scale=4.5, + ref_conds=ref_conds, + generator=torch.Generator(device="cpu").manual_seed(42) + ).images[0] + return image + + @torch.no_grad() + def get_align_face(self, img): + # ... (lógica inalterada) + self.face_helper.clean_all() + image_bgr = cv2.cvtColor(img, cv2.COLOR_RGB2BGR) + self.face_helper.read_image(image_bgr) + self.face_helper.get_face_landmarks_5(only_center_face=True) + self.face_helper.align_warp_face() + if len(self.face_helper.cropped_faces) == 0: return None + align_face = self.face_helper.cropped_faces[0] + input_tensor = img2tensor(align_face, bgr2rgb=True).unsqueeze(0) / 255.0 + input_tensor = input_tensor.to(self.gpu_device) + parsing_out = self.face_helper.face_parse(normalize(input_tensor, [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]))[0] + parsing_out = parsing_out.argmax(dim=1, keepdim=True) + bg_label = [0, 16, 18, 7, 8, 9, 14, 15] + bg = sum(parsing_out == i for i in bg_label).bool() + white_image = torch.ones_like(input_tensor) + face_features_image = torch.where(bg, white_image, input_tensor) + return tensor2img(face_features_image, rgb2bgr=False) + +# --- Instância Singleton --- +print("Inicializando o Pintor de Cenas (DreamO Helper)...") +hf_token = os.getenv('HF_TOKEN') +if hf_token: huggingface_hub.login(token=hf_token) +dreamo_generator_singleton = Generator() +print("Pintor de Cenas (DreamO Helper) pronto.") \ No newline at end of file diff --git a/inference.py b/inference.py new file mode 100644 index 0000000000000000000000000000000000000000..0be9213f5016d6559ca3c04e435388069235c88d --- /dev/null +++ b/inference.py @@ -0,0 +1,774 @@ +import argparse +import os +import random +from datetime import datetime +from pathlib import Path +from diffusers.utils import logging +from typing import Optional, List, Union +import yaml + +import imageio +import json +import numpy as np +import torch +import cv2 +from safetensors import safe_open +from PIL import Image +from transformers import ( + T5EncoderModel, + T5Tokenizer, + AutoModelForCausalLM, + AutoProcessor, + AutoTokenizer, +) +from huggingface_hub import hf_hub_download + +from ltx_video.models.autoencoders.causal_video_autoencoder import ( + CausalVideoAutoencoder, +) +from ltx_video.models.transformers.symmetric_patchifier import SymmetricPatchifier +from ltx_video.models.transformers.transformer3d import Transformer3DModel +from 
ltx_video.pipelines.pipeline_ltx_video import ( + ConditioningItem, + LTXVideoPipeline, + LTXMultiScalePipeline, +) +from ltx_video.schedulers.rf import RectifiedFlowScheduler +from ltx_video.utils.skip_layer_strategy import SkipLayerStrategy +from ltx_video.models.autoencoders.latent_upsampler import LatentUpsampler +import ltx_video.pipelines.crf_compressor as crf_compressor + +MAX_HEIGHT = 720 +MAX_WIDTH = 1280 +MAX_NUM_FRAMES = 257 + +logger = logging.get_logger("LTX-Video") + + +def get_total_gpu_memory(): + if torch.cuda.is_available(): + total_memory = torch.cuda.get_device_properties(0).total_memory / (1024**3) + return total_memory + return 44 + + +def get_device(): + if torch.cuda.is_available(): + return "cuda" + elif torch.backends.mps.is_available(): + return "mps" + return "cuda" + + +def load_image_to_tensor_with_resize_and_crop( + image_input: Union[str, Image.Image], + target_height: int = 512, + target_width: int = 768, + just_crop: bool = False, +) -> torch.Tensor: + """Load and process an image into a tensor. + + Args: + image_input: Either a file path (str) or a PIL Image object + target_height: Desired height of output tensor + target_width: Desired width of output tensor + just_crop: If True, only crop the image to the target size without resizing + """ + if isinstance(image_input, str): + image = Image.open(image_input).convert("RGB") + elif isinstance(image_input, Image.Image): + image = image_input + else: + raise ValueError("image_input must be either a file path or a PIL Image object") + + input_width, input_height = image.size + aspect_ratio_target = target_width / target_height + aspect_ratio_frame = input_width / input_height + if aspect_ratio_frame > aspect_ratio_target: + new_width = int(input_height * aspect_ratio_target) + new_height = input_height + x_start = (input_width - new_width) // 2 + y_start = 0 + else: + new_width = input_width + new_height = int(input_width / aspect_ratio_target) + x_start = 0 + y_start = (input_height - new_height) // 2 + + image = image.crop((x_start, y_start, x_start + new_width, y_start + new_height)) + if not just_crop: + image = image.resize((target_width, target_height)) + + image = np.array(image) + image = cv2.GaussianBlur(image, (3, 3), 0) + frame_tensor = torch.from_numpy(image).float() + frame_tensor = crf_compressor.compress(frame_tensor / 255.0) * 255.0 + frame_tensor = frame_tensor.permute(2, 0, 1) + frame_tensor = (frame_tensor / 127.5) - 1.0 + # Create 5D tensor: (batch_size=1, channels=3, num_frames=1, height, width) + return frame_tensor.unsqueeze(0).unsqueeze(2) + + +def calculate_padding( + source_height: int, source_width: int, target_height: int, target_width: int +) -> tuple[int, int, int, int]: + + # Calculate total padding needed + pad_height = target_height - source_height + pad_width = target_width - source_width + + # Calculate padding for each side + pad_top = pad_height // 2 + pad_bottom = pad_height - pad_top # Handles odd padding + pad_left = pad_width // 2 + pad_right = pad_width - pad_left # Handles odd padding + + # Return padded tensor + # Padding format is (left, right, top, bottom) + padding = (pad_left, pad_right, pad_top, pad_bottom) + return padding + + +def convert_prompt_to_filename(text: str, max_len: int = 20) -> str: + # Remove non-letters and convert to lowercase + clean_text = "".join( + char.lower() for char in text if char.isalpha() or char.isspace() + ) + + # Split into words + words = clean_text.split() + + # Build result string keeping track of length + result = [] + 
current_length = 0 + + for word in words: + # Add word length plus 1 for underscore (except for first word) + new_length = current_length + len(word) + + if new_length <= max_len: + result.append(word) + current_length += len(word) + else: + break + + return "-".join(result) + + +# Generate output video name +def get_unique_filename( + base: str, + ext: str, + prompt: str, + seed: int, + resolution: tuple[int, int, int], + dir: Path, + endswith=None, + index_range=1000, +) -> Path: + base_filename = f"{base}_{convert_prompt_to_filename(prompt, max_len=30)}_{seed}_{resolution[0]}x{resolution[1]}x{resolution[2]}" + for i in range(index_range): + filename = dir / f"{base_filename}_{i}{endswith if endswith else ''}{ext}" + if not os.path.exists(filename): + return filename + raise FileExistsError( + f"Could not find a unique filename after {index_range} attempts." + ) + + +def seed_everething(seed: int): + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed(seed) + if torch.backends.mps.is_available(): + torch.mps.manual_seed(seed) + + +def main(): + parser = argparse.ArgumentParser( + description="Load models from separate directories and run the pipeline." + ) + + # Directories + parser.add_argument( + "--output_path", + type=str, + default=None, + help="Path to the folder to save output video, if None will save in outputs/ directory.", + ) + parser.add_argument("--seed", type=int, default="171198") + + # Pipeline parameters + parser.add_argument( + "--num_images_per_prompt", + type=int, + default=1, + help="Number of images per prompt", + ) + parser.add_argument( + "--image_cond_noise_scale", + type=float, + default=0.15, + help="Amount of noise to add to the conditioned image", + ) + parser.add_argument( + "--height", + type=int, + default=704, + help="Height of the output video frames. Optional if an input image provided.", + ) + parser.add_argument( + "--width", + type=int, + default=1216, + help="Width of the output video frames. If None will infer from input image.", + ) + parser.add_argument( + "--num_frames", + type=int, + default=121, + help="Number of frames to generate in the output video", + ) + parser.add_argument( + "--frame_rate", type=int, default=30, help="Frame rate for the output video" + ) + parser.add_argument( + "--device", + default=None, + help="Device to run inference on. If not specified, will automatically detect and use CUDA or MPS if available, else CPU.", + ) + parser.add_argument( + "--pipeline_config", + type=str, + default="configs/ltxv-13b-0.9.7-dev.yaml", + help="The path to the config file for the pipeline, which contains the parameters for the pipeline", + ) + + # Prompts + parser.add_argument( + "--prompt", + type=str, + help="Text prompt to guide generation", + ) + parser.add_argument( + "--negative_prompt", + type=str, + default="worst quality, inconsistent motion, blurry, jittery, distorted", + help="Negative prompt for undesired features", + ) + + parser.add_argument( + "--offload_to_cpu", + action="store_true", + help="Offloading unnecessary computations to CPU.", + ) + + # video-to-video arguments: + parser.add_argument( + "--input_media_path", + type=str, + default=None, + help="Path to the input video (or imaage) to be modified using the video-to-video pipeline", + ) + + # Conditioning arguments + parser.add_argument( + "--conditioning_media_paths", + type=str, + nargs="*", + help="List of paths to conditioning media (images or videos). 
Each path will be used as a conditioning item.", + ) + parser.add_argument( + "--conditioning_strengths", + type=float, + nargs="*", + help="List of conditioning strengths (between 0 and 1) for each conditioning item. Must match the number of conditioning items.", + ) + parser.add_argument( + "--conditioning_start_frames", + type=int, + nargs="*", + help="List of frame indices where each conditioning item should be applied. Must match the number of conditioning items.", + ) + + args = parser.parse_args() + logger.warning(f"Running generation with arguments: {args}") + infer(**vars(args)) + + +def create_ltx_video_pipeline( + ckpt_path: str, + precision: str, + text_encoder_model_name_or_path: str, + sampler: Optional[str] = None, + device: Optional[str] = None, + enhance_prompt: bool = False, + prompt_enhancer_image_caption_model_name_or_path: Optional[str] = None, + prompt_enhancer_llm_model_name_or_path: Optional[str] = None, +) -> LTXVideoPipeline: + ckpt_path = Path(ckpt_path) + assert os.path.exists( + ckpt_path + ), f"Ckpt path provided (--ckpt_path) {ckpt_path} does not exist" + + with safe_open(ckpt_path, framework="pt") as f: + metadata = f.metadata() + config_str = metadata.get("config") + configs = json.loads(config_str) + allowed_inference_steps = configs.get("allowed_inference_steps", None) + + vae = CausalVideoAutoencoder.from_pretrained(ckpt_path) + transformer = Transformer3DModel.from_pretrained(ckpt_path) + + # Use constructor if sampler is specified, otherwise use from_pretrained + if sampler == "from_checkpoint" or not sampler: + scheduler = RectifiedFlowScheduler.from_pretrained(ckpt_path) + else: + scheduler = RectifiedFlowScheduler( + sampler=("Uniform" if sampler.lower() == "uniform" else "LinearQuadratic") + ) + + text_encoder = T5EncoderModel.from_pretrained( + text_encoder_model_name_or_path, subfolder="text_encoder" + ) + patchifier = SymmetricPatchifier(patch_size=1) + tokenizer = T5Tokenizer.from_pretrained( + text_encoder_model_name_or_path, subfolder="tokenizer" + ) + + transformer = transformer.to(device) + vae = vae.to(device) + text_encoder = text_encoder.to(device) + + if enhance_prompt: + prompt_enhancer_image_caption_model = AutoModelForCausalLM.from_pretrained( + prompt_enhancer_image_caption_model_name_or_path, trust_remote_code=True + ) + prompt_enhancer_image_caption_processor = AutoProcessor.from_pretrained( + prompt_enhancer_image_caption_model_name_or_path, trust_remote_code=True + ) + prompt_enhancer_llm_model = AutoModelForCausalLM.from_pretrained( + prompt_enhancer_llm_model_name_or_path, + torch_dtype="bfloat16", + ) + prompt_enhancer_llm_tokenizer = AutoTokenizer.from_pretrained( + prompt_enhancer_llm_model_name_or_path, + ) + else: + prompt_enhancer_image_caption_model = None + prompt_enhancer_image_caption_processor = None + prompt_enhancer_llm_model = None + prompt_enhancer_llm_tokenizer = None + + vae = vae.to(torch.bfloat16) + if precision == "bfloat16" and transformer.dtype != torch.bfloat16: + transformer = transformer.to(torch.bfloat16) + text_encoder = text_encoder.to(torch.bfloat16) + + # Use submodels for the pipeline + submodel_dict = { + "transformer": transformer, + "patchifier": patchifier, + "text_encoder": text_encoder, + "tokenizer": tokenizer, + "scheduler": scheduler, + "vae": vae, + "prompt_enhancer_image_caption_model": prompt_enhancer_image_caption_model, + "prompt_enhancer_image_caption_processor": prompt_enhancer_image_caption_processor, + "prompt_enhancer_llm_model": prompt_enhancer_llm_model, + 
"prompt_enhancer_llm_tokenizer": prompt_enhancer_llm_tokenizer, + "allowed_inference_steps": allowed_inference_steps, + } + + pipeline = LTXVideoPipeline(**submodel_dict) + pipeline = pipeline.to(device) + return pipeline + + +def create_latent_upsampler(latent_upsampler_model_path: str, device: str): + latent_upsampler = LatentUpsampler.from_pretrained(latent_upsampler_model_path) + latent_upsampler.to(device) + latent_upsampler.eval() + return latent_upsampler + + +def infer( + output_path: Optional[str], + seed: int, + pipeline_config: str, + image_cond_noise_scale: float, + height: Optional[int], + width: Optional[int], + num_frames: int, + frame_rate: int, + prompt: str, + negative_prompt: str, + offload_to_cpu: bool, + input_media_path: Optional[str] = None, + conditioning_media_paths: Optional[List[str]] = None, + conditioning_strengths: Optional[List[float]] = None, + conditioning_start_frames: Optional[List[int]] = None, + device: Optional[str] = None, + **kwargs, +): + # check if pipeline_config is a file + if not os.path.isfile(pipeline_config): + raise ValueError(f"Pipeline config file {pipeline_config} does not exist") + with open(pipeline_config, "r") as f: + pipeline_config = yaml.safe_load(f) + + models_dir = "MODEL_DIR" + + ltxv_model_name_or_path = pipeline_config["checkpoint_path"] + if not os.path.isfile(ltxv_model_name_or_path): + ltxv_model_path = hf_hub_download( + repo_id="Lightricks/LTX-Video", + filename=ltxv_model_name_or_path, + local_dir=models_dir, + repo_type="model", + ) + else: + ltxv_model_path = ltxv_model_name_or_path + + spatial_upscaler_model_name_or_path = pipeline_config.get( + "spatial_upscaler_model_path" + ) + if spatial_upscaler_model_name_or_path and not os.path.isfile( + spatial_upscaler_model_name_or_path + ): + spatial_upscaler_model_path = hf_hub_download( + repo_id="Lightricks/LTX-Video", + filename=spatial_upscaler_model_name_or_path, + local_dir=models_dir, + repo_type="model", + ) + else: + spatial_upscaler_model_path = spatial_upscaler_model_name_or_path + + if kwargs.get("input_image_path", None): + logger.warning( + "Please use conditioning_media_paths instead of input_image_path." + ) + assert not conditioning_media_paths and not conditioning_start_frames + conditioning_media_paths = [kwargs["input_image_path"]] + conditioning_start_frames = [0] + + # Validate conditioning arguments + if conditioning_media_paths: + # Use default strengths of 1.0 + if not conditioning_strengths: + conditioning_strengths = [1.0] * len(conditioning_media_paths) + if not conditioning_start_frames: + raise ValueError( + "If `conditioning_media_paths` is provided, " + "`conditioning_start_frames` must also be provided" + ) + if len(conditioning_media_paths) != len(conditioning_strengths) or len( + conditioning_media_paths + ) != len(conditioning_start_frames): + raise ValueError( + "`conditioning_media_paths`, `conditioning_strengths`, " + "and `conditioning_start_frames` must have the same length" + ) + if any(s < 0 or s > 1 for s in conditioning_strengths): + raise ValueError("All conditioning strengths must be between 0 and 1") + if any(f < 0 or f >= num_frames for f in conditioning_start_frames): + raise ValueError( + f"All conditioning start frames must be between 0 and {num_frames-1}" + ) + + seed_everething(seed) + if offload_to_cpu and not torch.cuda.is_available(): + logger.warning( + "offload_to_cpu is set to True, but offloading will not occur since the model is already running on CPU." 
+ ) + offload_to_cpu = False + else: + offload_to_cpu = offload_to_cpu and get_total_gpu_memory() < 30 + + output_dir = ( + Path(output_path) + if output_path + else Path(f"outputs/{datetime.today().strftime('%Y-%m-%d')}") + ) + output_dir.mkdir(parents=True, exist_ok=True) + + # Adjust dimensions to be divisible by 32 and num_frames to be (N * 8 + 1) + height_padded = ((height - 1) // 32 + 1) * 32 + width_padded = ((width - 1) // 32 + 1) * 32 + num_frames_padded = ((num_frames - 2) // 8 + 1) * 8 + 1 + + padding = calculate_padding(height, width, height_padded, width_padded) + + logger.warning( + f"Padded dimensions: {height_padded}x{width_padded}x{num_frames_padded}" + ) + + prompt_enhancement_words_threshold = pipeline_config[ + "prompt_enhancement_words_threshold" + ] + + prompt_word_count = len(prompt.split()) + enhance_prompt = ( + prompt_enhancement_words_threshold > 0 + and prompt_word_count < prompt_enhancement_words_threshold + ) + + if prompt_enhancement_words_threshold > 0 and not enhance_prompt: + logger.info( + f"Prompt has {prompt_word_count} words, which exceeds the threshold of {prompt_enhancement_words_threshold}. Prompt enhancement disabled." + ) + + precision = pipeline_config["precision"] + text_encoder_model_name_or_path = pipeline_config["text_encoder_model_name_or_path"] + sampler = pipeline_config["sampler"] + prompt_enhancer_image_caption_model_name_or_path = pipeline_config[ + "prompt_enhancer_image_caption_model_name_or_path" + ] + prompt_enhancer_llm_model_name_or_path = pipeline_config[ + "prompt_enhancer_llm_model_name_or_path" + ] + + pipeline = create_ltx_video_pipeline( + ckpt_path=ltxv_model_path, + precision=precision, + text_encoder_model_name_or_path=text_encoder_model_name_or_path, + sampler=sampler, + device=kwargs.get("device", get_device()), + enhance_prompt=enhance_prompt, + prompt_enhancer_image_caption_model_name_or_path=prompt_enhancer_image_caption_model_name_or_path, + prompt_enhancer_llm_model_name_or_path=prompt_enhancer_llm_model_name_or_path, + ) + + if pipeline_config.get("pipeline_type", None) == "multi-scale": + if not spatial_upscaler_model_path: + raise ValueError( + "spatial upscaler model path is missing from pipeline config file and is required for multi-scale rendering" + ) + latent_upsampler = create_latent_upsampler( + spatial_upscaler_model_path, pipeline.device + ) + pipeline = LTXMultiScalePipeline(pipeline, latent_upsampler=latent_upsampler) + + media_item = None + if input_media_path: + media_item = load_media_file( + media_path=input_media_path, + height=height, + width=width, + max_frames=num_frames_padded, + padding=padding, + ) + + conditioning_items = ( + prepare_conditioning( + conditioning_media_paths=conditioning_media_paths, + conditioning_strengths=conditioning_strengths, + conditioning_start_frames=conditioning_start_frames, + height=height, + width=width, + num_frames=num_frames, + padding=padding, + pipeline=pipeline, + ) + if conditioning_media_paths + else None + ) + + stg_mode = pipeline_config.get("stg_mode", "attention_values") + del pipeline_config["stg_mode"] + if stg_mode.lower() == "stg_av" or stg_mode.lower() == "attention_values": + skip_layer_strategy = SkipLayerStrategy.AttentionValues + elif stg_mode.lower() == "stg_as" or stg_mode.lower() == "attention_skip": + skip_layer_strategy = SkipLayerStrategy.AttentionSkip + elif stg_mode.lower() == "stg_r" or stg_mode.lower() == "residual": + skip_layer_strategy = SkipLayerStrategy.Residual + elif stg_mode.lower() == "stg_t" or stg_mode.lower() == 
"transformer_block": + skip_layer_strategy = SkipLayerStrategy.TransformerBlock + else: + raise ValueError(f"Invalid spatiotemporal guidance mode: {stg_mode}") + + # Prepare input for the pipeline + sample = { + "prompt": prompt, + "prompt_attention_mask": None, + "negative_prompt": negative_prompt, + "negative_prompt_attention_mask": None, + } + + device = device or get_device() + generator = torch.Generator(device=device).manual_seed(seed) + + images = pipeline( + **pipeline_config, + skip_layer_strategy=skip_layer_strategy, + generator=generator, + output_type="pt", + callback_on_step_end=None, + height=height_padded, + width=width_padded, + num_frames=num_frames_padded, + frame_rate=frame_rate, + **sample, + media_items=media_item, + conditioning_items=conditioning_items, + is_video=True, + vae_per_channel_normalize=True, + image_cond_noise_scale=image_cond_noise_scale, + mixed_precision=(precision == "mixed_precision"), + offload_to_cpu=offload_to_cpu, + device=device, + enhance_prompt=enhance_prompt, + ).images + + # Crop the padded images to the desired resolution and number of frames + (pad_left, pad_right, pad_top, pad_bottom) = padding + pad_bottom = -pad_bottom + pad_right = -pad_right + if pad_bottom == 0: + pad_bottom = images.shape[3] + if pad_right == 0: + pad_right = images.shape[4] + images = images[:, :, :num_frames, pad_top:pad_bottom, pad_left:pad_right] + + for i in range(images.shape[0]): + # Gathering from B, C, F, H, W to C, F, H, W and then permuting to F, H, W, C + video_np = images[i].permute(1, 2, 3, 0).cpu().float().numpy() + # Unnormalizing images to [0, 255] range + video_np = (video_np * 255).astype(np.uint8) + fps = frame_rate + height, width = video_np.shape[1:3] + # In case a single image is generated + if video_np.shape[0] == 1: + output_filename = get_unique_filename( + f"image_output_{i}", + ".png", + prompt=prompt, + seed=seed, + resolution=(height, width, num_frames), + dir=output_dir, + ) + imageio.imwrite(output_filename, video_np[0]) + else: + output_filename = get_unique_filename( + f"video_output_{i}", + ".mp4", + prompt=prompt, + seed=seed, + resolution=(height, width, num_frames), + dir=output_dir, + ) + + # Write video + with imageio.get_writer(output_filename, fps=fps) as video: + for frame in video_np: + video.append_data(frame) + + logger.warning(f"Output saved to {output_filename}") + + +def prepare_conditioning( + conditioning_media_paths: List[str], + conditioning_strengths: List[float], + conditioning_start_frames: List[int], + height: int, + width: int, + num_frames: int, + padding: tuple[int, int, int, int], + pipeline: LTXVideoPipeline, +) -> Optional[List[ConditioningItem]]: + """Prepare conditioning items based on input media paths and their parameters. + + Args: + conditioning_media_paths: List of paths to conditioning media (images or videos) + conditioning_strengths: List of conditioning strengths for each media item + conditioning_start_frames: List of frame indices where each item should be applied + height: Height of the output frames + width: Width of the output frames + num_frames: Number of frames in the output video + padding: Padding to apply to the frames + pipeline: LTXVideoPipeline object used for condition video trimming + + Returns: + A list of ConditioningItem objects. 
+ """ + conditioning_items = [] + for path, strength, start_frame in zip( + conditioning_media_paths, conditioning_strengths, conditioning_start_frames + ): + num_input_frames = orig_num_input_frames = get_media_num_frames(path) + if hasattr(pipeline, "trim_conditioning_sequence") and callable( + getattr(pipeline, "trim_conditioning_sequence") + ): + num_input_frames = pipeline.trim_conditioning_sequence( + start_frame, orig_num_input_frames, num_frames + ) + if num_input_frames < orig_num_input_frames: + logger.warning( + f"Trimming conditioning video {path} from {orig_num_input_frames} to {num_input_frames} frames." + ) + + media_tensor = load_media_file( + media_path=path, + height=height, + width=width, + max_frames=num_input_frames, + padding=padding, + just_crop=True, + ) + conditioning_items.append(ConditioningItem(media_tensor, start_frame, strength)) + return conditioning_items + + +def get_media_num_frames(media_path: str) -> int: + is_video = any( + media_path.lower().endswith(ext) for ext in [".mp4", ".avi", ".mov", ".mkv"] + ) + num_frames = 1 + if is_video: + reader = imageio.get_reader(media_path) + num_frames = reader.count_frames() + reader.close() + return num_frames + + +def load_media_file( + media_path: str, + height: int, + width: int, + max_frames: int, + padding: tuple[int, int, int, int], + just_crop: bool = False, +) -> torch.Tensor: + is_video = any( + media_path.lower().endswith(ext) for ext in [".mp4", ".avi", ".mov", ".mkv"] + ) + if is_video: + reader = imageio.get_reader(media_path) + num_input_frames = min(reader.count_frames(), max_frames) + + # Read and preprocess the relevant frames from the video file. + frames = [] + for i in range(num_input_frames): + frame = Image.fromarray(reader.get_data(i)) + frame_tensor = load_image_to_tensor_with_resize_and_crop( + frame, height, width, just_crop=just_crop + ) + frame_tensor = torch.nn.functional.pad(frame_tensor, padding) + frames.append(frame_tensor) + reader.close() + + # Stack frames along the temporal dimension + media_tensor = torch.cat(frames, dim=2) + else: # Input image + media_tensor = load_image_to_tensor_with_resize_and_crop( + media_path, height, width, just_crop=just_crop + ) + media_tensor = torch.nn.functional.pad(media_tensor, padding) + return media_tensor + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/ltx_video/__init__.py b/ltx_video/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/ltx_video/models/__init__.py b/ltx_video/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/ltx_video/models/autoencoders/__init__.py b/ltx_video/models/autoencoders/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/ltx_video/models/autoencoders/causal_conv3d.py b/ltx_video/models/autoencoders/causal_conv3d.py new file mode 100644 index 0000000000000000000000000000000000000000..98249c2f5ffe52eead83b38476e034c4f03bdccd --- /dev/null +++ b/ltx_video/models/autoencoders/causal_conv3d.py @@ -0,0 +1,63 @@ +from typing import Tuple, Union + +import torch +import torch.nn as nn + + +class CausalConv3d(nn.Module): + def __init__( + self, + in_channels, + out_channels, + kernel_size: int = 3, + stride: Union[int, Tuple[int]] = 1, + dilation: int = 1, + groups: int = 1, + spatial_padding_mode: str = "zeros", + **kwargs, + ): + 
super().__init__() + + self.in_channels = in_channels + self.out_channels = out_channels + + kernel_size = (kernel_size, kernel_size, kernel_size) + self.time_kernel_size = kernel_size[0] + + dilation = (dilation, 1, 1) + + height_pad = kernel_size[1] // 2 + width_pad = kernel_size[2] // 2 + padding = (0, height_pad, width_pad) + + self.conv = nn.Conv3d( + in_channels, + out_channels, + kernel_size, + stride=stride, + dilation=dilation, + padding=padding, + padding_mode=spatial_padding_mode, + groups=groups, + ) + + def forward(self, x, causal: bool = True): + if causal: + first_frame_pad = x[:, :, :1, :, :].repeat( + (1, 1, self.time_kernel_size - 1, 1, 1) + ) + x = torch.concatenate((first_frame_pad, x), dim=2) + else: + first_frame_pad = x[:, :, :1, :, :].repeat( + (1, 1, (self.time_kernel_size - 1) // 2, 1, 1) + ) + last_frame_pad = x[:, :, -1:, :, :].repeat( + (1, 1, (self.time_kernel_size - 1) // 2, 1, 1) + ) + x = torch.concatenate((first_frame_pad, x, last_frame_pad), dim=2) + x = self.conv(x) + return x + + @property + def weight(self): + return self.conv.weight diff --git a/ltx_video/models/autoencoders/causal_video_autoencoder.py b/ltx_video/models/autoencoders/causal_video_autoencoder.py new file mode 100644 index 0000000000000000000000000000000000000000..736c96a3c65e22a7ada0bb20535e0e15bc47b123 --- /dev/null +++ b/ltx_video/models/autoencoders/causal_video_autoencoder.py @@ -0,0 +1,1398 @@ +import json +import os +from functools import partial +from types import SimpleNamespace +from typing import Any, Mapping, Optional, Tuple, Union, List +from pathlib import Path + +import torch +import numpy as np +from einops import rearrange +from torch import nn +from diffusers.utils import logging +import torch.nn.functional as F +from diffusers.models.embeddings import PixArtAlphaCombinedTimestepSizeEmbeddings +from safetensors import safe_open + + +from ltx_video.models.autoencoders.conv_nd_factory import make_conv_nd, make_linear_nd +from ltx_video.models.autoencoders.pixel_norm import PixelNorm +from ltx_video.models.autoencoders.pixel_shuffle import PixelShuffleND +from ltx_video.models.autoencoders.vae import AutoencoderKLWrapper +from ltx_video.models.transformers.attention import Attention +from ltx_video.utils.diffusers_config_mapping import ( + diffusers_and_ours_config_mapping, + make_hashable_key, + VAE_KEYS_RENAME_DICT, +) + +PER_CHANNEL_STATISTICS_PREFIX = "per_channel_statistics." 
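+# Checkpoints may ship per-channel latent statistics ("std-of-means" / "mean-of-means")
+# under this key prefix; load_state_dict below registers them as the std_of_means and
+# mean_of_means buffers used for per-channel latent normalization.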
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class CausalVideoAutoencoder(AutoencoderKLWrapper): + @classmethod + def from_pretrained( + cls, + pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], + *args, + **kwargs, + ): + pretrained_model_name_or_path = Path(pretrained_model_name_or_path) + if ( + pretrained_model_name_or_path.is_dir() + and (pretrained_model_name_or_path / "autoencoder.pth").exists() + ): + config_local_path = pretrained_model_name_or_path / "config.json" + config = cls.load_config(config_local_path, **kwargs) + + model_local_path = pretrained_model_name_or_path / "autoencoder.pth" + state_dict = torch.load(model_local_path, map_location=torch.device("cpu")) + + statistics_local_path = ( + pretrained_model_name_or_path / "per_channel_statistics.json" + ) + if statistics_local_path.exists(): + with open(statistics_local_path, "r") as file: + data = json.load(file) + transposed_data = list(zip(*data["data"])) + data_dict = { + col: torch.tensor(vals) + for col, vals in zip(data["columns"], transposed_data) + } + std_of_means = data_dict["std-of-means"] + mean_of_means = data_dict.get( + "mean-of-means", torch.zeros_like(data_dict["std-of-means"]) + ) + state_dict[f"{PER_CHANNEL_STATISTICS_PREFIX}std-of-means"] = ( + std_of_means + ) + state_dict[f"{PER_CHANNEL_STATISTICS_PREFIX}mean-of-means"] = ( + mean_of_means + ) + + elif pretrained_model_name_or_path.is_dir(): + config_path = pretrained_model_name_or_path / "vae" / "config.json" + with open(config_path, "r") as f: + config = make_hashable_key(json.load(f)) + + assert config in diffusers_and_ours_config_mapping, ( + "Provided diffusers checkpoint config for VAE is not suppported. " + "We only support diffusers configs found in Lightricks/LTX-Video." 
+ ) + + config = diffusers_and_ours_config_mapping[config] + + state_dict_path = ( + pretrained_model_name_or_path + / "vae" + / "diffusion_pytorch_model.safetensors" + ) + + state_dict = {} + with safe_open(state_dict_path, framework="pt", device="cpu") as f: + for k in f.keys(): + state_dict[k] = f.get_tensor(k) + for key in list(state_dict.keys()): + new_key = key + for replace_key, rename_key in VAE_KEYS_RENAME_DICT.items(): + new_key = new_key.replace(replace_key, rename_key) + + state_dict[new_key] = state_dict.pop(key) + + elif pretrained_model_name_or_path.is_file() and str( + pretrained_model_name_or_path + ).endswith(".safetensors"): + state_dict = {} + with safe_open( + pretrained_model_name_or_path, framework="pt", device="cpu" + ) as f: + metadata = f.metadata() + for k in f.keys(): + state_dict[k] = f.get_tensor(k) + configs = json.loads(metadata["config"]) + config = configs["vae"] + + video_vae = cls.from_config(config) + if "torch_dtype" in kwargs: + video_vae.to(kwargs["torch_dtype"]) + video_vae.load_state_dict(state_dict) + return video_vae + + @staticmethod + def from_config(config): + assert ( + config["_class_name"] == "CausalVideoAutoencoder" + ), "config must have _class_name=CausalVideoAutoencoder" + if isinstance(config["dims"], list): + config["dims"] = tuple(config["dims"]) + + assert config["dims"] in [2, 3, (2, 1)], "dims must be 2, 3 or (2, 1)" + + double_z = config.get("double_z", True) + latent_log_var = config.get( + "latent_log_var", "per_channel" if double_z else "none" + ) + use_quant_conv = config.get("use_quant_conv", True) + normalize_latent_channels = config.get("normalize_latent_channels", False) + + if use_quant_conv and latent_log_var in ["uniform", "constant"]: + raise ValueError( + f"latent_log_var={latent_log_var} requires use_quant_conv=False" + ) + + encoder = Encoder( + dims=config["dims"], + in_channels=config.get("in_channels", 3), + out_channels=config["latent_channels"], + blocks=config.get("encoder_blocks", config.get("blocks")), + patch_size=config.get("patch_size", 1), + latent_log_var=latent_log_var, + norm_layer=config.get("norm_layer", "group_norm"), + base_channels=config.get("encoder_base_channels", 128), + spatial_padding_mode=config.get("spatial_padding_mode", "zeros"), + ) + + decoder = Decoder( + dims=config["dims"], + in_channels=config["latent_channels"], + out_channels=config.get("out_channels", 3), + blocks=config.get("decoder_blocks", config.get("blocks")), + patch_size=config.get("patch_size", 1), + norm_layer=config.get("norm_layer", "group_norm"), + causal=config.get("causal_decoder", False), + timestep_conditioning=config.get("timestep_conditioning", False), + base_channels=config.get("decoder_base_channels", 128), + spatial_padding_mode=config.get("spatial_padding_mode", "zeros"), + ) + + dims = config["dims"] + return CausalVideoAutoencoder( + encoder=encoder, + decoder=decoder, + latent_channels=config["latent_channels"], + dims=dims, + use_quant_conv=use_quant_conv, + normalize_latent_channels=normalize_latent_channels, + ) + + @property + def config(self): + return SimpleNamespace( + _class_name="CausalVideoAutoencoder", + dims=self.dims, + in_channels=self.encoder.conv_in.in_channels // self.encoder.patch_size**2, + out_channels=self.decoder.conv_out.out_channels + // self.decoder.patch_size**2, + latent_channels=self.decoder.conv_in.in_channels, + encoder_blocks=self.encoder.blocks_desc, + decoder_blocks=self.decoder.blocks_desc, + scaling_factor=1.0, + norm_layer=self.encoder.norm_layer, + 
patch_size=self.encoder.patch_size, + latent_log_var=self.encoder.latent_log_var, + use_quant_conv=self.use_quant_conv, + causal_decoder=self.decoder.causal, + timestep_conditioning=self.decoder.timestep_conditioning, + normalize_latent_channels=self.normalize_latent_channels, + ) + + @property + def is_video_supported(self): + """ + Check if the model supports video inputs of shape (B, C, F, H, W). Otherwise, the model only supports 2D images. + """ + return self.dims != 2 + + @property + def spatial_downscale_factor(self): + return ( + 2 + ** len( + [ + block + for block in self.encoder.blocks_desc + if block[0] + in [ + "compress_space", + "compress_all", + "compress_all_res", + "compress_space_res", + ] + ] + ) + * self.encoder.patch_size + ) + + @property + def temporal_downscale_factor(self): + return 2 ** len( + [ + block + for block in self.encoder.blocks_desc + if block[0] + in [ + "compress_time", + "compress_all", + "compress_all_res", + "compress_time_res", + ] + ] + ) + + def to_json_string(self) -> str: + import json + + return json.dumps(self.config.__dict__) + + def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True): + if any([key.startswith("vae.") for key in state_dict.keys()]): + state_dict = { + key.replace("vae.", ""): value + for key, value in state_dict.items() + if key.startswith("vae.") + } + ckpt_state_dict = { + key: value + for key, value in state_dict.items() + if not key.startswith(PER_CHANNEL_STATISTICS_PREFIX) + } + + model_keys = set(name for name, _ in self.named_modules()) + + key_mapping = { + ".resnets.": ".res_blocks.", + "downsamplers.0": "downsample", + "upsamplers.0": "upsample", + } + converted_state_dict = {} + for key, value in ckpt_state_dict.items(): + for k, v in key_mapping.items(): + key = key.replace(k, v) + + key_prefix = ".".join(key.split(".")[:-1]) + if "norm" in key and key_prefix not in model_keys: + logger.info( + f"Removing key {key} from state_dict as it is not present in the model" + ) + continue + + converted_state_dict[key] = value + + super().load_state_dict(converted_state_dict, strict=strict) + + data_dict = { + key.removeprefix(PER_CHANNEL_STATISTICS_PREFIX): value + for key, value in state_dict.items() + if key.startswith(PER_CHANNEL_STATISTICS_PREFIX) + } + if len(data_dict) > 0: + self.register_buffer("std_of_means", data_dict["std-of-means"]) + self.register_buffer( + "mean_of_means", + data_dict.get( + "mean-of-means", torch.zeros_like(data_dict["std-of-means"]) + ), + ) + + def last_layer(self): + if hasattr(self.decoder, "conv_out"): + if isinstance(self.decoder.conv_out, nn.Sequential): + last_layer = self.decoder.conv_out[-1] + else: + last_layer = self.decoder.conv_out + else: + last_layer = self.decoder.layers[-1] + return last_layer + + def set_use_tpu_flash_attention(self): + for block in self.decoder.up_blocks: + if isinstance(block, UNetMidBlock3D) and block.attention_blocks: + for attention_block in block.attention_blocks: + attention_block.set_use_tpu_flash_attention() + + +class Encoder(nn.Module): + r""" + The `Encoder` layer of a variational autoencoder that encodes its input into a latent representation. + + Args: + dims (`int` or `Tuple[int, int]`, *optional*, defaults to 3): + The number of dimensions to use in convolutions. + in_channels (`int`, *optional*, defaults to 3): + The number of input channels. + out_channels (`int`, *optional*, defaults to 3): + The number of output channels. 
+ blocks (`List[Tuple[str, int]]`, *optional*, defaults to `[("res_x", 1)]`): + The blocks to use. Each block is a tuple of the block name and the number of layers. + base_channels (`int`, *optional*, defaults to 128): + The number of output channels for the first convolutional layer. + norm_num_groups (`int`, *optional*, defaults to 32): + The number of groups for normalization. + patch_size (`int`, *optional*, defaults to 1): + The patch size to use. Should be a power of 2. + norm_layer (`str`, *optional*, defaults to `group_norm`): + The normalization layer to use. Can be either `group_norm` or `pixel_norm`. + latent_log_var (`str`, *optional*, defaults to `per_channel`): + The number of channels for the log variance. Can be either `per_channel`, `uniform`, `constant` or `none`. + """ + + def __init__( + self, + dims: Union[int, Tuple[int, int]] = 3, + in_channels: int = 3, + out_channels: int = 3, + blocks: List[Tuple[str, int | dict]] = [("res_x", 1)], + base_channels: int = 128, + norm_num_groups: int = 32, + patch_size: Union[int, Tuple[int]] = 1, + norm_layer: str = "group_norm", # group_norm, pixel_norm + latent_log_var: str = "per_channel", + spatial_padding_mode: str = "zeros", + ): + super().__init__() + self.patch_size = patch_size + self.norm_layer = norm_layer + self.latent_channels = out_channels + self.latent_log_var = latent_log_var + self.blocks_desc = blocks + + in_channels = in_channels * patch_size**2 + output_channel = base_channels + + self.conv_in = make_conv_nd( + dims=dims, + in_channels=in_channels, + out_channels=output_channel, + kernel_size=3, + stride=1, + padding=1, + causal=True, + spatial_padding_mode=spatial_padding_mode, + ) + + self.down_blocks = nn.ModuleList([]) + + for block_name, block_params in blocks: + input_channel = output_channel + if isinstance(block_params, int): + block_params = {"num_layers": block_params} + + if block_name == "res_x": + block = UNetMidBlock3D( + dims=dims, + in_channels=input_channel, + num_layers=block_params["num_layers"], + resnet_eps=1e-6, + resnet_groups=norm_num_groups, + norm_layer=norm_layer, + spatial_padding_mode=spatial_padding_mode, + ) + elif block_name == "res_x_y": + output_channel = block_params.get("multiplier", 2) * output_channel + block = ResnetBlock3D( + dims=dims, + in_channels=input_channel, + out_channels=output_channel, + eps=1e-6, + groups=norm_num_groups, + norm_layer=norm_layer, + spatial_padding_mode=spatial_padding_mode, + ) + elif block_name == "compress_time": + block = make_conv_nd( + dims=dims, + in_channels=input_channel, + out_channels=output_channel, + kernel_size=3, + stride=(2, 1, 1), + causal=True, + spatial_padding_mode=spatial_padding_mode, + ) + elif block_name == "compress_space": + block = make_conv_nd( + dims=dims, + in_channels=input_channel, + out_channels=output_channel, + kernel_size=3, + stride=(1, 2, 2), + causal=True, + spatial_padding_mode=spatial_padding_mode, + ) + elif block_name == "compress_all": + block = make_conv_nd( + dims=dims, + in_channels=input_channel, + out_channels=output_channel, + kernel_size=3, + stride=(2, 2, 2), + causal=True, + spatial_padding_mode=spatial_padding_mode, + ) + elif block_name == "compress_all_x_y": + output_channel = block_params.get("multiplier", 2) * output_channel + block = make_conv_nd( + dims=dims, + in_channels=input_channel, + out_channels=output_channel, + kernel_size=3, + stride=(2, 2, 2), + causal=True, + spatial_padding_mode=spatial_padding_mode, + ) + elif block_name == "compress_all_res": + output_channel = 
block_params.get("multiplier", 2) * output_channel + block = SpaceToDepthDownsample( + dims=dims, + in_channels=input_channel, + out_channels=output_channel, + stride=(2, 2, 2), + spatial_padding_mode=spatial_padding_mode, + ) + elif block_name == "compress_space_res": + output_channel = block_params.get("multiplier", 2) * output_channel + block = SpaceToDepthDownsample( + dims=dims, + in_channels=input_channel, + out_channels=output_channel, + stride=(1, 2, 2), + spatial_padding_mode=spatial_padding_mode, + ) + elif block_name == "compress_time_res": + output_channel = block_params.get("multiplier", 2) * output_channel + block = SpaceToDepthDownsample( + dims=dims, + in_channels=input_channel, + out_channels=output_channel, + stride=(2, 1, 1), + spatial_padding_mode=spatial_padding_mode, + ) + else: + raise ValueError(f"unknown block: {block_name}") + + self.down_blocks.append(block) + + # out + if norm_layer == "group_norm": + self.conv_norm_out = nn.GroupNorm( + num_channels=output_channel, num_groups=norm_num_groups, eps=1e-6 + ) + elif norm_layer == "pixel_norm": + self.conv_norm_out = PixelNorm() + elif norm_layer == "layer_norm": + self.conv_norm_out = LayerNorm(output_channel, eps=1e-6) + + self.conv_act = nn.SiLU() + + conv_out_channels = out_channels + if latent_log_var == "per_channel": + conv_out_channels *= 2 + elif latent_log_var == "uniform": + conv_out_channels += 1 + elif latent_log_var == "constant": + conv_out_channels += 1 + elif latent_log_var != "none": + raise ValueError(f"Invalid latent_log_var: {latent_log_var}") + self.conv_out = make_conv_nd( + dims, + output_channel, + conv_out_channels, + 3, + padding=1, + causal=True, + spatial_padding_mode=spatial_padding_mode, + ) + + self.gradient_checkpointing = False + + def forward(self, sample: torch.FloatTensor) -> torch.FloatTensor: + r"""The forward method of the `Encoder` class.""" + + sample = patchify(sample, patch_size_hw=self.patch_size, patch_size_t=1) + sample = self.conv_in(sample) + + checkpoint_fn = ( + partial(torch.utils.checkpoint.checkpoint, use_reentrant=False) + if self.gradient_checkpointing and self.training + else lambda x: x + ) + + for down_block in self.down_blocks: + sample = checkpoint_fn(down_block)(sample) + + sample = self.conv_norm_out(sample) + sample = self.conv_act(sample) + sample = self.conv_out(sample) + + if self.latent_log_var == "uniform": + last_channel = sample[:, -1:, ...] + num_dims = sample.dim() + + if num_dims == 4: + # For shape (B, C, H, W) + repeated_last_channel = last_channel.repeat( + 1, sample.shape[1] - 2, 1, 1 + ) + sample = torch.cat([sample, repeated_last_channel], dim=1) + elif num_dims == 5: + # For shape (B, C, F, H, W) + repeated_last_channel = last_channel.repeat( + 1, sample.shape[1] - 2, 1, 1, 1 + ) + sample = torch.cat([sample, repeated_last_channel], dim=1) + else: + raise ValueError(f"Invalid input shape: {sample.shape}") + elif self.latent_log_var == "constant": + sample = sample[:, :-1, ...] + approx_ln_0 = ( + -30 + ) # this is the minimal clamp value in DiagonalGaussianDistribution objects + sample = torch.cat( + [sample, torch.ones_like(sample, device=sample.device) * approx_ln_0], + dim=1, + ) + + return sample + + +class Decoder(nn.Module): + r""" + The `Decoder` layer of a variational autoencoder that decodes its latent representation into an output sample. + + Args: + dims (`int` or `Tuple[int, int]`, *optional*, defaults to 3): + The number of dimensions to use in convolutions. 
+ in_channels (`int`, *optional*, defaults to 3): + The number of input channels. + out_channels (`int`, *optional*, defaults to 3): + The number of output channels. + blocks (`List[Tuple[str, int]]`, *optional*, defaults to `[("res_x", 1)]`): + The blocks to use. Each block is a tuple of the block name and the number of layers. + base_channels (`int`, *optional*, defaults to 128): + The number of output channels for the first convolutional layer. + norm_num_groups (`int`, *optional*, defaults to 32): + The number of groups for normalization. + patch_size (`int`, *optional*, defaults to 1): + The patch size to use. Should be a power of 2. + norm_layer (`str`, *optional*, defaults to `group_norm`): + The normalization layer to use. Can be either `group_norm` or `pixel_norm`. + causal (`bool`, *optional*, defaults to `True`): + Whether to use causal convolutions or not. + """ + + def __init__( + self, + dims, + in_channels: int = 3, + out_channels: int = 3, + blocks: List[Tuple[str, int | dict]] = [("res_x", 1)], + base_channels: int = 128, + layers_per_block: int = 2, + norm_num_groups: int = 32, + patch_size: int = 1, + norm_layer: str = "group_norm", + causal: bool = True, + timestep_conditioning: bool = False, + spatial_padding_mode: str = "zeros", + ): + super().__init__() + self.patch_size = patch_size + self.layers_per_block = layers_per_block + out_channels = out_channels * patch_size**2 + self.causal = causal + self.blocks_desc = blocks + + # Compute output channel to be product of all channel-multiplier blocks + output_channel = base_channels + for block_name, block_params in list(reversed(blocks)): + block_params = block_params if isinstance(block_params, dict) else {} + if block_name == "res_x_y": + output_channel = output_channel * block_params.get("multiplier", 2) + if block_name.startswith("compress"): + output_channel = output_channel * block_params.get("multiplier", 1) + + self.conv_in = make_conv_nd( + dims, + in_channels, + output_channel, + kernel_size=3, + stride=1, + padding=1, + causal=True, + spatial_padding_mode=spatial_padding_mode, + ) + + self.up_blocks = nn.ModuleList([]) + + for block_name, block_params in list(reversed(blocks)): + input_channel = output_channel + if isinstance(block_params, int): + block_params = {"num_layers": block_params} + + if block_name == "res_x": + block = UNetMidBlock3D( + dims=dims, + in_channels=input_channel, + num_layers=block_params["num_layers"], + resnet_eps=1e-6, + resnet_groups=norm_num_groups, + norm_layer=norm_layer, + inject_noise=block_params.get("inject_noise", False), + timestep_conditioning=timestep_conditioning, + spatial_padding_mode=spatial_padding_mode, + ) + elif block_name == "attn_res_x": + block = UNetMidBlock3D( + dims=dims, + in_channels=input_channel, + num_layers=block_params["num_layers"], + resnet_groups=norm_num_groups, + norm_layer=norm_layer, + inject_noise=block_params.get("inject_noise", False), + timestep_conditioning=timestep_conditioning, + attention_head_dim=block_params["attention_head_dim"], + spatial_padding_mode=spatial_padding_mode, + ) + elif block_name == "res_x_y": + output_channel = output_channel // block_params.get("multiplier", 2) + block = ResnetBlock3D( + dims=dims, + in_channels=input_channel, + out_channels=output_channel, + eps=1e-6, + groups=norm_num_groups, + norm_layer=norm_layer, + inject_noise=block_params.get("inject_noise", False), + timestep_conditioning=False, + spatial_padding_mode=spatial_padding_mode, + ) + elif block_name == "compress_time": + block = 
DepthToSpaceUpsample( + dims=dims, + in_channels=input_channel, + stride=(2, 1, 1), + spatial_padding_mode=spatial_padding_mode, + ) + elif block_name == "compress_space": + block = DepthToSpaceUpsample( + dims=dims, + in_channels=input_channel, + stride=(1, 2, 2), + spatial_padding_mode=spatial_padding_mode, + ) + elif block_name == "compress_all": + output_channel = output_channel // block_params.get("multiplier", 1) + block = DepthToSpaceUpsample( + dims=dims, + in_channels=input_channel, + stride=(2, 2, 2), + residual=block_params.get("residual", False), + out_channels_reduction_factor=block_params.get("multiplier", 1), + spatial_padding_mode=spatial_padding_mode, + ) + else: + raise ValueError(f"unknown layer: {block_name}") + + self.up_blocks.append(block) + + if norm_layer == "group_norm": + self.conv_norm_out = nn.GroupNorm( + num_channels=output_channel, num_groups=norm_num_groups, eps=1e-6 + ) + elif norm_layer == "pixel_norm": + self.conv_norm_out = PixelNorm() + elif norm_layer == "layer_norm": + self.conv_norm_out = LayerNorm(output_channel, eps=1e-6) + + self.conv_act = nn.SiLU() + self.conv_out = make_conv_nd( + dims, + output_channel, + out_channels, + 3, + padding=1, + causal=True, + spatial_padding_mode=spatial_padding_mode, + ) + + self.gradient_checkpointing = False + + self.timestep_conditioning = timestep_conditioning + + if timestep_conditioning: + self.timestep_scale_multiplier = nn.Parameter( + torch.tensor(1000.0, dtype=torch.float32) + ) + self.last_time_embedder = PixArtAlphaCombinedTimestepSizeEmbeddings( + output_channel * 2, 0 + ) + self.last_scale_shift_table = nn.Parameter( + torch.randn(2, output_channel) / output_channel**0.5 + ) + + def forward( + self, + sample: torch.FloatTensor, + target_shape, + timestep: Optional[torch.Tensor] = None, + ) -> torch.FloatTensor: + r"""The forward method of the `Decoder` class.""" + assert target_shape is not None, "target_shape must be provided" + batch_size = sample.shape[0] + + sample = self.conv_in(sample, causal=self.causal) + + upscale_dtype = next(iter(self.up_blocks.parameters())).dtype + + checkpoint_fn = ( + partial(torch.utils.checkpoint.checkpoint, use_reentrant=False) + if self.gradient_checkpointing and self.training + else lambda x: x + ) + + sample = sample.to(upscale_dtype) + + if self.timestep_conditioning: + assert ( + timestep is not None + ), "should pass timestep with timestep_conditioning=True" + scaled_timestep = timestep * self.timestep_scale_multiplier + + for up_block in self.up_blocks: + if self.timestep_conditioning and isinstance(up_block, UNetMidBlock3D): + sample = checkpoint_fn(up_block)( + sample, causal=self.causal, timestep=scaled_timestep + ) + else: + sample = checkpoint_fn(up_block)(sample, causal=self.causal) + + sample = self.conv_norm_out(sample) + + if self.timestep_conditioning: + embedded_timestep = self.last_time_embedder( + timestep=scaled_timestep.flatten(), + resolution=None, + aspect_ratio=None, + batch_size=sample.shape[0], + hidden_dtype=sample.dtype, + ) + embedded_timestep = embedded_timestep.view( + batch_size, embedded_timestep.shape[-1], 1, 1, 1 + ) + ada_values = self.last_scale_shift_table[ + None, ..., None, None, None + ] + embedded_timestep.reshape( + batch_size, + 2, + -1, + embedded_timestep.shape[-3], + embedded_timestep.shape[-2], + embedded_timestep.shape[-1], + ) + shift, scale = ada_values.unbind(dim=1) + sample = sample * (1 + scale) + shift + + sample = self.conv_act(sample) + sample = self.conv_out(sample, causal=self.causal) + + sample = 
unpatchify(sample, patch_size_hw=self.patch_size, patch_size_t=1) + + return sample + + +class UNetMidBlock3D(nn.Module): + """ + A 3D UNet mid-block [`UNetMidBlock3D`] with multiple residual blocks. + + Args: + in_channels (`int`): The number of input channels. + dropout (`float`, *optional*, defaults to 0.0): The dropout rate. + num_layers (`int`, *optional*, defaults to 1): The number of residual blocks. + resnet_eps (`float`, *optional*, 1e-6 ): The epsilon value for the resnet blocks. + resnet_groups (`int`, *optional*, defaults to 32): + The number of groups to use in the group normalization layers of the resnet blocks. + norm_layer (`str`, *optional*, defaults to `group_norm`): + The normalization layer to use. Can be either `group_norm` or `pixel_norm`. + inject_noise (`bool`, *optional*, defaults to `False`): + Whether to inject noise into the hidden states. + timestep_conditioning (`bool`, *optional*, defaults to `False`): + Whether to condition the hidden states on the timestep. + attention_head_dim (`int`, *optional*, defaults to -1): + The dimension of the attention head. If -1, no attention is used. + + Returns: + `torch.FloatTensor`: The output of the last residual block, which is a tensor of shape `(batch_size, + in_channels, height, width)`. + + """ + + def __init__( + self, + dims: Union[int, Tuple[int, int]], + in_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_groups: int = 32, + norm_layer: str = "group_norm", + inject_noise: bool = False, + timestep_conditioning: bool = False, + attention_head_dim: int = -1, + spatial_padding_mode: str = "zeros", + ): + super().__init__() + resnet_groups = ( + resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) + ) + self.timestep_conditioning = timestep_conditioning + + if timestep_conditioning: + self.time_embedder = PixArtAlphaCombinedTimestepSizeEmbeddings( + in_channels * 4, 0 + ) + + self.res_blocks = nn.ModuleList( + [ + ResnetBlock3D( + dims=dims, + in_channels=in_channels, + out_channels=in_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + norm_layer=norm_layer, + inject_noise=inject_noise, + timestep_conditioning=timestep_conditioning, + spatial_padding_mode=spatial_padding_mode, + ) + for _ in range(num_layers) + ] + ) + + self.attention_blocks = None + + if attention_head_dim > 0: + if attention_head_dim > in_channels: + raise ValueError( + "attention_head_dim must be less than or equal to in_channels" + ) + + self.attention_blocks = nn.ModuleList( + [ + Attention( + query_dim=in_channels, + heads=in_channels // attention_head_dim, + dim_head=attention_head_dim, + bias=True, + out_bias=True, + qk_norm="rms_norm", + residual_connection=True, + ) + for _ in range(num_layers) + ] + ) + + def forward( + self, + hidden_states: torch.FloatTensor, + causal: bool = True, + timestep: Optional[torch.Tensor] = None, + ) -> torch.FloatTensor: + timestep_embed = None + if self.timestep_conditioning: + assert ( + timestep is not None + ), "should pass timestep with timestep_conditioning=True" + batch_size = hidden_states.shape[0] + timestep_embed = self.time_embedder( + timestep=timestep.flatten(), + resolution=None, + aspect_ratio=None, + batch_size=batch_size, + hidden_dtype=hidden_states.dtype, + ) + timestep_embed = timestep_embed.view( + batch_size, timestep_embed.shape[-1], 1, 1, 1 + ) + + if self.attention_blocks: + for resnet, attention in zip(self.res_blocks, self.attention_blocks): + hidden_states = resnet( + hidden_states, 
causal=causal, timestep=timestep_embed + ) + + # Reshape the hidden states to be (batch_size, frames * height * width, channel) + batch_size, channel, frames, height, width = hidden_states.shape + hidden_states = hidden_states.view( + batch_size, channel, frames * height * width + ).transpose(1, 2) + + if attention.use_tpu_flash_attention: + # Pad the second dimension to be divisible by block_k_major (block in flash attention) + seq_len = hidden_states.shape[1] + block_k_major = 512 + pad_len = (block_k_major - seq_len % block_k_major) % block_k_major + if pad_len > 0: + hidden_states = F.pad( + hidden_states, (0, 0, 0, pad_len), "constant", 0 + ) + + # Create a mask with ones for the original sequence length and zeros for the padded indexes + mask = torch.ones( + (hidden_states.shape[0], seq_len), + device=hidden_states.device, + dtype=hidden_states.dtype, + ) + if pad_len > 0: + mask = F.pad(mask, (0, pad_len), "constant", 0) + + hidden_states = attention( + hidden_states, + attention_mask=( + None if not attention.use_tpu_flash_attention else mask + ), + ) + + if attention.use_tpu_flash_attention: + # Remove the padding + if pad_len > 0: + hidden_states = hidden_states[:, :-pad_len, :] + + # Reshape the hidden states back to (batch_size, channel, frames, height, width, channel) + hidden_states = hidden_states.transpose(-1, -2).reshape( + batch_size, channel, frames, height, width + ) + else: + for resnet in self.res_blocks: + hidden_states = resnet( + hidden_states, causal=causal, timestep=timestep_embed + ) + + return hidden_states + + +class SpaceToDepthDownsample(nn.Module): + def __init__(self, dims, in_channels, out_channels, stride, spatial_padding_mode): + super().__init__() + self.stride = stride + self.group_size = in_channels * np.prod(stride) // out_channels + self.conv = make_conv_nd( + dims=dims, + in_channels=in_channels, + out_channels=out_channels // np.prod(stride), + kernel_size=3, + stride=1, + causal=True, + spatial_padding_mode=spatial_padding_mode, + ) + + def forward(self, x, causal: bool = True): + if self.stride[0] == 2: + x = torch.cat( + [x[:, :, :1, :, :], x], dim=2 + ) # duplicate first frames for padding + + # skip connection + x_in = rearrange( + x, + "b c (d p1) (h p2) (w p3) -> b (c p1 p2 p3) d h w", + p1=self.stride[0], + p2=self.stride[1], + p3=self.stride[2], + ) + x_in = rearrange(x_in, "b (c g) d h w -> b c g d h w", g=self.group_size) + x_in = x_in.mean(dim=2) + + # conv + x = self.conv(x, causal=causal) + x = rearrange( + x, + "b c (d p1) (h p2) (w p3) -> b (c p1 p2 p3) d h w", + p1=self.stride[0], + p2=self.stride[1], + p3=self.stride[2], + ) + + x = x + x_in + + return x + + +class DepthToSpaceUpsample(nn.Module): + def __init__( + self, + dims, + in_channels, + stride, + residual=False, + out_channels_reduction_factor=1, + spatial_padding_mode="zeros", + ): + super().__init__() + self.stride = stride + self.out_channels = ( + np.prod(stride) * in_channels // out_channels_reduction_factor + ) + self.conv = make_conv_nd( + dims=dims, + in_channels=in_channels, + out_channels=self.out_channels, + kernel_size=3, + stride=1, + causal=True, + spatial_padding_mode=spatial_padding_mode, + ) + self.pixel_shuffle = PixelShuffleND(dims=dims, upscale_factors=stride) + self.residual = residual + self.out_channels_reduction_factor = out_channels_reduction_factor + + def forward(self, x, causal: bool = True): + if self.residual: + # Reshape and duplicate the input to match the output shape + x_in = self.pixel_shuffle(x) + num_repeat = np.prod(self.stride) // 
self.out_channels_reduction_factor + x_in = x_in.repeat(1, num_repeat, 1, 1, 1) + if self.stride[0] == 2: + x_in = x_in[:, :, 1:, :, :] + x = self.conv(x, causal=causal) + x = self.pixel_shuffle(x) + if self.stride[0] == 2: + x = x[:, :, 1:, :, :] + if self.residual: + x = x + x_in + return x + + +class LayerNorm(nn.Module): + def __init__(self, dim, eps, elementwise_affine=True) -> None: + super().__init__() + self.norm = nn.LayerNorm(dim, eps=eps, elementwise_affine=elementwise_affine) + + def forward(self, x): + x = rearrange(x, "b c d h w -> b d h w c") + x = self.norm(x) + x = rearrange(x, "b d h w c -> b c d h w") + return x + + +class ResnetBlock3D(nn.Module): + r""" + A Resnet block. + + Parameters: + in_channels (`int`): The number of channels in the input. + out_channels (`int`, *optional*, default to be `None`): + The number of output channels for the first conv layer. If None, same as `in_channels`. + dropout (`float`, *optional*, defaults to `0.0`): The dropout probability to use. + groups (`int`, *optional*, default to `32`): The number of groups to use for the first normalization layer. + eps (`float`, *optional*, defaults to `1e-6`): The epsilon to use for the normalization. + """ + + def __init__( + self, + dims: Union[int, Tuple[int, int]], + in_channels: int, + out_channels: Optional[int] = None, + dropout: float = 0.0, + groups: int = 32, + eps: float = 1e-6, + norm_layer: str = "group_norm", + inject_noise: bool = False, + timestep_conditioning: bool = False, + spatial_padding_mode: str = "zeros", + ): + super().__init__() + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + self.inject_noise = inject_noise + + if norm_layer == "group_norm": + self.norm1 = nn.GroupNorm( + num_groups=groups, num_channels=in_channels, eps=eps, affine=True + ) + elif norm_layer == "pixel_norm": + self.norm1 = PixelNorm() + elif norm_layer == "layer_norm": + self.norm1 = LayerNorm(in_channels, eps=eps, elementwise_affine=True) + + self.non_linearity = nn.SiLU() + + self.conv1 = make_conv_nd( + dims, + in_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1, + causal=True, + spatial_padding_mode=spatial_padding_mode, + ) + + if inject_noise: + self.per_channel_scale1 = nn.Parameter(torch.zeros((in_channels, 1, 1))) + + if norm_layer == "group_norm": + self.norm2 = nn.GroupNorm( + num_groups=groups, num_channels=out_channels, eps=eps, affine=True + ) + elif norm_layer == "pixel_norm": + self.norm2 = PixelNorm() + elif norm_layer == "layer_norm": + self.norm2 = LayerNorm(out_channels, eps=eps, elementwise_affine=True) + + self.dropout = torch.nn.Dropout(dropout) + + self.conv2 = make_conv_nd( + dims, + out_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1, + causal=True, + spatial_padding_mode=spatial_padding_mode, + ) + + if inject_noise: + self.per_channel_scale2 = nn.Parameter(torch.zeros((in_channels, 1, 1))) + + self.conv_shortcut = ( + make_linear_nd( + dims=dims, in_channels=in_channels, out_channels=out_channels + ) + if in_channels != out_channels + else nn.Identity() + ) + + self.norm3 = ( + LayerNorm(in_channels, eps=eps, elementwise_affine=True) + if in_channels != out_channels + else nn.Identity() + ) + + self.timestep_conditioning = timestep_conditioning + + if timestep_conditioning: + self.scale_shift_table = nn.Parameter( + torch.randn(4, in_channels) / in_channels**0.5 + ) + + def _feed_spatial_noise( + self, hidden_states: torch.FloatTensor, 
per_channel_scale: torch.FloatTensor + ) -> torch.FloatTensor: + spatial_shape = hidden_states.shape[-2:] + device = hidden_states.device + dtype = hidden_states.dtype + + # similar to the "explicit noise inputs" method in style-gan + spatial_noise = torch.randn(spatial_shape, device=device, dtype=dtype)[None] + scaled_noise = (spatial_noise * per_channel_scale)[None, :, None, ...] + hidden_states = hidden_states + scaled_noise + + return hidden_states + + def forward( + self, + input_tensor: torch.FloatTensor, + causal: bool = True, + timestep: Optional[torch.Tensor] = None, + ) -> torch.FloatTensor: + hidden_states = input_tensor + batch_size = hidden_states.shape[0] + + hidden_states = self.norm1(hidden_states) + if self.timestep_conditioning: + assert ( + timestep is not None + ), "should pass timestep with timestep_conditioning=True" + ada_values = self.scale_shift_table[ + None, ..., None, None, None + ] + timestep.reshape( + batch_size, + 4, + -1, + timestep.shape[-3], + timestep.shape[-2], + timestep.shape[-1], + ) + shift1, scale1, shift2, scale2 = ada_values.unbind(dim=1) + + hidden_states = hidden_states * (1 + scale1) + shift1 + + hidden_states = self.non_linearity(hidden_states) + + hidden_states = self.conv1(hidden_states, causal=causal) + + if self.inject_noise: + hidden_states = self._feed_spatial_noise( + hidden_states, self.per_channel_scale1 + ) + + hidden_states = self.norm2(hidden_states) + + if self.timestep_conditioning: + hidden_states = hidden_states * (1 + scale2) + shift2 + + hidden_states = self.non_linearity(hidden_states) + + hidden_states = self.dropout(hidden_states) + + hidden_states = self.conv2(hidden_states, causal=causal) + + if self.inject_noise: + hidden_states = self._feed_spatial_noise( + hidden_states, self.per_channel_scale2 + ) + + input_tensor = self.norm3(input_tensor) + + batch_size = input_tensor.shape[0] + + input_tensor = self.conv_shortcut(input_tensor) + + output_tensor = input_tensor + hidden_states + + return output_tensor + + +def patchify(x, patch_size_hw, patch_size_t=1): + if patch_size_hw == 1 and patch_size_t == 1: + return x + if x.dim() == 4: + x = rearrange( + x, "b c (h q) (w r) -> b (c r q) h w", q=patch_size_hw, r=patch_size_hw + ) + elif x.dim() == 5: + x = rearrange( + x, + "b c (f p) (h q) (w r) -> b (c p r q) f h w", + p=patch_size_t, + q=patch_size_hw, + r=patch_size_hw, + ) + else: + raise ValueError(f"Invalid input shape: {x.shape}") + + return x + + +def unpatchify(x, patch_size_hw, patch_size_t=1): + if patch_size_hw == 1 and patch_size_t == 1: + return x + + if x.dim() == 4: + x = rearrange( + x, "b (c r q) h w -> b c (h q) (w r)", q=patch_size_hw, r=patch_size_hw + ) + elif x.dim() == 5: + x = rearrange( + x, + "b (c p r q) f h w -> b c (f p) (h q) (w r)", + p=patch_size_t, + q=patch_size_hw, + r=patch_size_hw, + ) + + return x + + +def create_video_autoencoder_demo_config( + latent_channels: int = 64, +): + encoder_blocks = [ + ("res_x", {"num_layers": 2}), + ("compress_space_res", {"multiplier": 2}), + ("compress_time_res", {"multiplier": 2}), + ("compress_all_res", {"multiplier": 2}), + ("compress_all_res", {"multiplier": 2}), + ("res_x", {"num_layers": 1}), + ] + decoder_blocks = [ + ("res_x", {"num_layers": 2, "inject_noise": False}), + ("compress_all", {"residual": True, "multiplier": 2}), + ("compress_all", {"residual": True, "multiplier": 2}), + ("compress_all", {"residual": True, "multiplier": 2}), + ("res_x", {"num_layers": 2, "inject_noise": False}), + ] + return { + "_class_name": 
"CausalVideoAutoencoder", + "dims": 3, + "encoder_blocks": encoder_blocks, + "decoder_blocks": decoder_blocks, + "latent_channels": latent_channels, + "norm_layer": "pixel_norm", + "patch_size": 4, + "latent_log_var": "uniform", + "use_quant_conv": False, + "causal_decoder": False, + "timestep_conditioning": True, + "spatial_padding_mode": "replicate", + } + + +def test_vae_patchify_unpatchify(): + import torch + + x = torch.randn(2, 3, 8, 64, 64) + x_patched = patchify(x, patch_size_hw=4, patch_size_t=4) + x_unpatched = unpatchify(x_patched, patch_size_hw=4, patch_size_t=4) + assert torch.allclose(x, x_unpatched) + + +def demo_video_autoencoder_forward_backward(): + # Configuration for the VideoAutoencoder + config = create_video_autoencoder_demo_config() + + # Instantiate the VideoAutoencoder with the specified configuration + video_autoencoder = CausalVideoAutoencoder.from_config(config) + + print(video_autoencoder) + video_autoencoder.eval() + # Print the total number of parameters in the video autoencoder + total_params = sum(p.numel() for p in video_autoencoder.parameters()) + print(f"Total number of parameters in VideoAutoencoder: {total_params:,}") + + # Create a mock input tensor simulating a batch of videos + # Shape: (batch_size, channels, depth, height, width) + # E.g., 4 videos, each with 3 color channels, 16 frames, and 64x64 pixels per frame + input_videos = torch.randn(2, 3, 17, 64, 64) + + # Forward pass: encode and decode the input videos + latent = video_autoencoder.encode(input_videos).latent_dist.mode() + print(f"input shape={input_videos.shape}") + print(f"latent shape={latent.shape}") + + timestep = torch.ones(input_videos.shape[0]) * 0.1 + reconstructed_videos = video_autoencoder.decode( + latent, target_shape=input_videos.shape, timestep=timestep + ).sample + + print(f"reconstructed shape={reconstructed_videos.shape}") + + # Validate that single image gets treated the same way as first frame + input_image = input_videos[:, :, :1, :, :] + image_latent = video_autoencoder.encode(input_image).latent_dist.mode() + _ = video_autoencoder.decode( + image_latent, target_shape=image_latent.shape, timestep=timestep + ).sample + + first_frame_latent = latent[:, :, :1, :, :] + + assert torch.allclose(image_latent, first_frame_latent, atol=1e-6) + # assert torch.allclose(reconstructed_image, reconstructed_videos[:, :, :1, :, :], atol=1e-6) + # assert torch.allclose(image_latent, first_frame_latent, atol=1e-6) + # assert (reconstructed_image == reconstructed_videos[:, :, :1, :, :]).all() + + # Calculate the loss (e.g., mean squared error) + loss = torch.nn.functional.mse_loss(input_videos, reconstructed_videos) + + # Perform backward pass + loss.backward() + + print(f"Demo completed with loss: {loss.item()}") + + +# Ensure to call the demo function to execute the forward and backward pass +if __name__ == "__main__": + demo_video_autoencoder_forward_backward() diff --git a/ltx_video/models/autoencoders/conv_nd_factory.py b/ltx_video/models/autoencoders/conv_nd_factory.py new file mode 100644 index 0000000000000000000000000000000000000000..718c69befd959c7466c4a57d71e46bb80bfe9fba --- /dev/null +++ b/ltx_video/models/autoencoders/conv_nd_factory.py @@ -0,0 +1,90 @@ +from typing import Tuple, Union + +import torch + +from ltx_video.models.autoencoders.dual_conv3d import DualConv3d +from ltx_video.models.autoencoders.causal_conv3d import CausalConv3d + + +def make_conv_nd( + dims: Union[int, Tuple[int, int]], + in_channels: int, + out_channels: int, + kernel_size: int, + stride=1, + 
padding=0, + dilation=1, + groups=1, + bias=True, + causal=False, + spatial_padding_mode="zeros", + temporal_padding_mode="zeros", +): + if not (spatial_padding_mode == temporal_padding_mode or causal): + raise NotImplementedError("spatial and temporal padding modes must be equal") + if dims == 2: + return torch.nn.Conv2d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups, + bias=bias, + padding_mode=spatial_padding_mode, + ) + elif dims == 3: + if causal: + return CausalConv3d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups, + bias=bias, + spatial_padding_mode=spatial_padding_mode, + ) + return torch.nn.Conv3d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups, + bias=bias, + padding_mode=spatial_padding_mode, + ) + elif dims == (2, 1): + return DualConv3d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + bias=bias, + padding_mode=spatial_padding_mode, + ) + else: + raise ValueError(f"unsupported dimensions: {dims}") + + +def make_linear_nd( + dims: int, + in_channels: int, + out_channels: int, + bias=True, +): + if dims == 2: + return torch.nn.Conv2d( + in_channels=in_channels, out_channels=out_channels, kernel_size=1, bias=bias + ) + elif dims == 3 or dims == (2, 1): + return torch.nn.Conv3d( + in_channels=in_channels, out_channels=out_channels, kernel_size=1, bias=bias + ) + else: + raise ValueError(f"unsupported dimensions: {dims}") diff --git a/ltx_video/models/autoencoders/dual_conv3d.py b/ltx_video/models/autoencoders/dual_conv3d.py new file mode 100644 index 0000000000000000000000000000000000000000..dcf889296750d3d7e553af37ecf77d1b10245af3 --- /dev/null +++ b/ltx_video/models/autoencoders/dual_conv3d.py @@ -0,0 +1,217 @@ +import math +from typing import Tuple, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange + + +class DualConv3d(nn.Module): + def __init__( + self, + in_channels, + out_channels, + kernel_size, + stride: Union[int, Tuple[int, int, int]] = 1, + padding: Union[int, Tuple[int, int, int]] = 0, + dilation: Union[int, Tuple[int, int, int]] = 1, + groups=1, + bias=True, + padding_mode="zeros", + ): + super(DualConv3d, self).__init__() + + self.in_channels = in_channels + self.out_channels = out_channels + self.padding_mode = padding_mode + # Ensure kernel_size, stride, padding, and dilation are tuples of length 3 + if isinstance(kernel_size, int): + kernel_size = (kernel_size, kernel_size, kernel_size) + if kernel_size == (1, 1, 1): + raise ValueError( + "kernel_size must be greater than 1. Use make_linear_nd instead." 
+ ) + if isinstance(stride, int): + stride = (stride, stride, stride) + if isinstance(padding, int): + padding = (padding, padding, padding) + if isinstance(dilation, int): + dilation = (dilation, dilation, dilation) + + # Set parameters for convolutions + self.groups = groups + self.bias = bias + + # Define the size of the channels after the first convolution + intermediate_channels = ( + out_channels if in_channels < out_channels else in_channels + ) + + # Define parameters for the first convolution + self.weight1 = nn.Parameter( + torch.Tensor( + intermediate_channels, + in_channels // groups, + 1, + kernel_size[1], + kernel_size[2], + ) + ) + self.stride1 = (1, stride[1], stride[2]) + self.padding1 = (0, padding[1], padding[2]) + self.dilation1 = (1, dilation[1], dilation[2]) + if bias: + self.bias1 = nn.Parameter(torch.Tensor(intermediate_channels)) + else: + self.register_parameter("bias1", None) + + # Define parameters for the second convolution + self.weight2 = nn.Parameter( + torch.Tensor( + out_channels, intermediate_channels // groups, kernel_size[0], 1, 1 + ) + ) + self.stride2 = (stride[0], 1, 1) + self.padding2 = (padding[0], 0, 0) + self.dilation2 = (dilation[0], 1, 1) + if bias: + self.bias2 = nn.Parameter(torch.Tensor(out_channels)) + else: + self.register_parameter("bias2", None) + + # Initialize weights and biases + self.reset_parameters() + + def reset_parameters(self): + nn.init.kaiming_uniform_(self.weight1, a=math.sqrt(5)) + nn.init.kaiming_uniform_(self.weight2, a=math.sqrt(5)) + if self.bias: + fan_in1, _ = nn.init._calculate_fan_in_and_fan_out(self.weight1) + bound1 = 1 / math.sqrt(fan_in1) + nn.init.uniform_(self.bias1, -bound1, bound1) + fan_in2, _ = nn.init._calculate_fan_in_and_fan_out(self.weight2) + bound2 = 1 / math.sqrt(fan_in2) + nn.init.uniform_(self.bias2, -bound2, bound2) + + def forward(self, x, use_conv3d=False, skip_time_conv=False): + if use_conv3d: + return self.forward_with_3d(x=x, skip_time_conv=skip_time_conv) + else: + return self.forward_with_2d(x=x, skip_time_conv=skip_time_conv) + + def forward_with_3d(self, x, skip_time_conv): + # First convolution + x = F.conv3d( + x, + self.weight1, + self.bias1, + self.stride1, + self.padding1, + self.dilation1, + self.groups, + padding_mode=self.padding_mode, + ) + + if skip_time_conv: + return x + + # Second convolution + x = F.conv3d( + x, + self.weight2, + self.bias2, + self.stride2, + self.padding2, + self.dilation2, + self.groups, + padding_mode=self.padding_mode, + ) + + return x + + def forward_with_2d(self, x, skip_time_conv): + b, c, d, h, w = x.shape + + # First 2D convolution + x = rearrange(x, "b c d h w -> (b d) c h w") + # Squeeze the depth dimension out of weight1 since it's 1 + weight1 = self.weight1.squeeze(2) + # Select stride, padding, and dilation for the 2D convolution + stride1 = (self.stride1[1], self.stride1[2]) + padding1 = (self.padding1[1], self.padding1[2]) + dilation1 = (self.dilation1[1], self.dilation1[2]) + x = F.conv2d( + x, + weight1, + self.bias1, + stride1, + padding1, + dilation1, + self.groups, + padding_mode=self.padding_mode, + ) + + _, _, h, w = x.shape + + if skip_time_conv: + x = rearrange(x, "(b d) c h w -> b c d h w", b=b) + return x + + # Second convolution which is essentially treated as a 1D convolution across the 'd' dimension + x = rearrange(x, "(b d) c h w -> (b h w) c d", b=b) + + # Reshape weight2 to match the expected dimensions for conv1d + weight2 = self.weight2.squeeze(-1).squeeze(-1) + # Use only the relevant dimension for stride, padding, and 
dilation for the 1D convolution + stride2 = self.stride2[0] + padding2 = self.padding2[0] + dilation2 = self.dilation2[0] + x = F.conv1d( + x, + weight2, + self.bias2, + stride2, + padding2, + dilation2, + self.groups, + padding_mode=self.padding_mode, + ) + x = rearrange(x, "(b h w) c d -> b c d h w", b=b, h=h, w=w) + + return x + + @property + def weight(self): + return self.weight2 + + +def test_dual_conv3d_consistency(): + # Initialize parameters + in_channels = 3 + out_channels = 5 + kernel_size = (3, 3, 3) + stride = (2, 2, 2) + padding = (1, 1, 1) + + # Create an instance of the DualConv3d class + dual_conv3d = DualConv3d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + bias=True, + ) + + # Example input tensor + test_input = torch.randn(1, 3, 10, 10, 10) + + # Perform forward passes with both 3D and 2D settings + output_conv3d = dual_conv3d(test_input, use_conv3d=True) + output_2d = dual_conv3d(test_input, use_conv3d=False) + + # Assert that the outputs from both methods are sufficiently close + assert torch.allclose( + output_conv3d, output_2d, atol=1e-6 + ), "Outputs are not consistent between 3D and 2D convolutions." diff --git a/ltx_video/models/autoencoders/latent_upsampler.py b/ltx_video/models/autoencoders/latent_upsampler.py new file mode 100644 index 0000000000000000000000000000000000000000..4a76bc21d1a503d61dec673cf5cb980bb6d703fd --- /dev/null +++ b/ltx_video/models/autoencoders/latent_upsampler.py @@ -0,0 +1,203 @@ +from typing import Optional, Union +from pathlib import Path +import os +import json + +import torch +import torch.nn as nn +from einops import rearrange +from diffusers import ConfigMixin, ModelMixin +from safetensors.torch import safe_open + +from ltx_video.models.autoencoders.pixel_shuffle import PixelShuffleND + + +class ResBlock(nn.Module): + def __init__( + self, channels: int, mid_channels: Optional[int] = None, dims: int = 3 + ): + super().__init__() + if mid_channels is None: + mid_channels = channels + + Conv = nn.Conv2d if dims == 2 else nn.Conv3d + + self.conv1 = Conv(channels, mid_channels, kernel_size=3, padding=1) + self.norm1 = nn.GroupNorm(32, mid_channels) + self.conv2 = Conv(mid_channels, channels, kernel_size=3, padding=1) + self.norm2 = nn.GroupNorm(32, channels) + self.activation = nn.SiLU() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + residual = x + x = self.conv1(x) + x = self.norm1(x) + x = self.activation(x) + x = self.conv2(x) + x = self.norm2(x) + x = self.activation(x + residual) + return x + + +class LatentUpsampler(ModelMixin, ConfigMixin): + """ + Model to spatially upsample VAE latents. 
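+
+ Depending on `spatial_upsample` and `temporal_upsample`, a single convolution expands
+ the channel dimension (4x `mid_channels` for spatial, 2x for temporal, 8x for both) and
+ PixelShuffleND rearranges the extra channels into height/width, frames, or both; stacks
+ of ResBlocks run before and after the shuffle, and a final convolution maps back to
+ `in_channels`.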
+ + Args: + in_channels (`int`): Number of channels in the input latent + mid_channels (`int`): Number of channels in the middle layers + num_blocks_per_stage (`int`): Number of ResBlocks to use in each stage (pre/post upsampling) + dims (`int`): Number of dimensions for convolutions (2 or 3) + spatial_upsample (`bool`): Whether to spatially upsample the latent + temporal_upsample (`bool`): Whether to temporally upsample the latent + """ + + def __init__( + self, + in_channels: int = 128, + mid_channels: int = 512, + num_blocks_per_stage: int = 4, + dims: int = 3, + spatial_upsample: bool = True, + temporal_upsample: bool = False, + ): + super().__init__() + + self.in_channels = in_channels + self.mid_channels = mid_channels + self.num_blocks_per_stage = num_blocks_per_stage + self.dims = dims + self.spatial_upsample = spatial_upsample + self.temporal_upsample = temporal_upsample + + Conv = nn.Conv2d if dims == 2 else nn.Conv3d + + self.initial_conv = Conv(in_channels, mid_channels, kernel_size=3, padding=1) + self.initial_norm = nn.GroupNorm(32, mid_channels) + self.initial_activation = nn.SiLU() + + self.res_blocks = nn.ModuleList( + [ResBlock(mid_channels, dims=dims) for _ in range(num_blocks_per_stage)] + ) + + if spatial_upsample and temporal_upsample: + self.upsampler = nn.Sequential( + nn.Conv3d(mid_channels, 8 * mid_channels, kernel_size=3, padding=1), + PixelShuffleND(3), + ) + elif spatial_upsample: + self.upsampler = nn.Sequential( + nn.Conv2d(mid_channels, 4 * mid_channels, kernel_size=3, padding=1), + PixelShuffleND(2), + ) + elif temporal_upsample: + self.upsampler = nn.Sequential( + nn.Conv3d(mid_channels, 2 * mid_channels, kernel_size=3, padding=1), + PixelShuffleND(1), + ) + else: + raise ValueError( + "Either spatial_upsample or temporal_upsample must be True" + ) + + self.post_upsample_res_blocks = nn.ModuleList( + [ResBlock(mid_channels, dims=dims) for _ in range(num_blocks_per_stage)] + ) + + self.final_conv = Conv(mid_channels, in_channels, kernel_size=3, padding=1) + + def forward(self, latent: torch.Tensor) -> torch.Tensor: + b, c, f, h, w = latent.shape + + if self.dims == 2: + x = rearrange(latent, "b c f h w -> (b f) c h w") + x = self.initial_conv(x) + x = self.initial_norm(x) + x = self.initial_activation(x) + + for block in self.res_blocks: + x = block(x) + + x = self.upsampler(x) + + for block in self.post_upsample_res_blocks: + x = block(x) + + x = self.final_conv(x) + x = rearrange(x, "(b f) c h w -> b c f h w", b=b, f=f) + else: + x = self.initial_conv(latent) + x = self.initial_norm(x) + x = self.initial_activation(x) + + for block in self.res_blocks: + x = block(x) + + if self.temporal_upsample: + x = self.upsampler(x) + x = x[:, :, 1:, :, :] + else: + x = rearrange(x, "b c f h w -> (b f) c h w") + x = self.upsampler(x) + x = rearrange(x, "(b f) c h w -> b c f h w", b=b, f=f) + + for block in self.post_upsample_res_blocks: + x = block(x) + + x = self.final_conv(x) + + return x + + @classmethod + def from_config(cls, config): + return cls( + in_channels=config.get("in_channels", 4), + mid_channels=config.get("mid_channels", 128), + num_blocks_per_stage=config.get("num_blocks_per_stage", 4), + dims=config.get("dims", 2), + spatial_upsample=config.get("spatial_upsample", True), + temporal_upsample=config.get("temporal_upsample", False), + ) + + def config(self): + return { + "_class_name": "LatentUpsampler", + "in_channels": self.in_channels, + "mid_channels": self.mid_channels, + "num_blocks_per_stage": self.num_blocks_per_stage, + "dims": self.dims, + 
"spatial_upsample": self.spatial_upsample, + "temporal_upsample": self.temporal_upsample, + } + + @classmethod + def from_pretrained( + cls, + pretrained_model_path: Optional[Union[str, os.PathLike]], + *args, + **kwargs, + ): + pretrained_model_path = Path(pretrained_model_path) + if pretrained_model_path.is_file() and str(pretrained_model_path).endswith( + ".safetensors" + ): + state_dict = {} + with safe_open(pretrained_model_path, framework="pt", device="cpu") as f: + metadata = f.metadata() + for k in f.keys(): + state_dict[k] = f.get_tensor(k) + config = json.loads(metadata["config"]) + with torch.device("meta"): + latent_upsampler = LatentUpsampler.from_config(config) + latent_upsampler.load_state_dict(state_dict, assign=True) + return latent_upsampler + + +if __name__ == "__main__": + latent_upsampler = LatentUpsampler(num_blocks_per_stage=4, dims=3) + print(latent_upsampler) + total_params = sum(p.numel() for p in latent_upsampler.parameters()) + print(f"Total number of parameters: {total_params:,}") + latent = torch.randn(1, 128, 9, 16, 16) + upsampled_latent = latent_upsampler(latent) + print(f"Upsampled latent shape: {upsampled_latent.shape}") diff --git a/ltx_video/models/autoencoders/pixel_norm.py b/ltx_video/models/autoencoders/pixel_norm.py new file mode 100644 index 0000000000000000000000000000000000000000..9bc3ea60e8a6453e7e12a7fb5aca4de3958a2567 --- /dev/null +++ b/ltx_video/models/autoencoders/pixel_norm.py @@ -0,0 +1,12 @@ +import torch +from torch import nn + + +class PixelNorm(nn.Module): + def __init__(self, dim=1, eps=1e-8): + super(PixelNorm, self).__init__() + self.dim = dim + self.eps = eps + + def forward(self, x): + return x / torch.sqrt(torch.mean(x**2, dim=self.dim, keepdim=True) + self.eps) diff --git a/ltx_video/models/autoencoders/pixel_shuffle.py b/ltx_video/models/autoencoders/pixel_shuffle.py new file mode 100644 index 0000000000000000000000000000000000000000..4e79ae28483d5ad684ea68092bc955ef025722e6 --- /dev/null +++ b/ltx_video/models/autoencoders/pixel_shuffle.py @@ -0,0 +1,33 @@ +import torch.nn as nn +from einops import rearrange + + +class PixelShuffleND(nn.Module): + def __init__(self, dims, upscale_factors=(2, 2, 2)): + super().__init__() + assert dims in [1, 2, 3], "dims must be 1, 2, or 3" + self.dims = dims + self.upscale_factors = upscale_factors + + def forward(self, x): + if self.dims == 3: + return rearrange( + x, + "b (c p1 p2 p3) d h w -> b c (d p1) (h p2) (w p3)", + p1=self.upscale_factors[0], + p2=self.upscale_factors[1], + p3=self.upscale_factors[2], + ) + elif self.dims == 2: + return rearrange( + x, + "b (c p1 p2) h w -> b c (h p1) (w p2)", + p1=self.upscale_factors[0], + p2=self.upscale_factors[1], + ) + elif self.dims == 1: + return rearrange( + x, + "b (c p1) f h w -> b c (f p1) h w", + p1=self.upscale_factors[0], + ) diff --git a/ltx_video/models/autoencoders/vae.py b/ltx_video/models/autoencoders/vae.py new file mode 100644 index 0000000000000000000000000000000000000000..5b22217c158eb26bca45b2b6a5e475e8a71b8181 --- /dev/null +++ b/ltx_video/models/autoencoders/vae.py @@ -0,0 +1,380 @@ +from typing import Optional, Union + +import torch +import inspect +import math +import torch.nn as nn +from diffusers import ConfigMixin, ModelMixin +from diffusers.models.autoencoders.vae import ( + DecoderOutput, + DiagonalGaussianDistribution, +) +from diffusers.models.modeling_outputs import AutoencoderKLOutput +from ltx_video.models.autoencoders.conv_nd_factory import make_conv_nd + + +class AutoencoderKLWrapper(ModelMixin, ConfigMixin): 
+ """Variational Autoencoder (VAE) model with KL loss. + + VAE from the paper Auto-Encoding Variational Bayes by Diederik P. Kingma and Max Welling. + This model is a wrapper around an encoder and a decoder, and it adds a KL loss term to the reconstruction loss. + + Args: + encoder (`nn.Module`): + Encoder module. + decoder (`nn.Module`): + Decoder module. + latent_channels (`int`, *optional*, defaults to 4): + Number of latent channels. + """ + + def __init__( + self, + encoder: nn.Module, + decoder: nn.Module, + latent_channels: int = 4, + dims: int = 2, + sample_size=512, + use_quant_conv: bool = True, + normalize_latent_channels: bool = False, + ): + super().__init__() + + # pass init params to Encoder + self.encoder = encoder + self.use_quant_conv = use_quant_conv + self.normalize_latent_channels = normalize_latent_channels + + # pass init params to Decoder + quant_dims = 2 if dims == 2 else 3 + self.decoder = decoder + if use_quant_conv: + self.quant_conv = make_conv_nd( + quant_dims, 2 * latent_channels, 2 * latent_channels, 1 + ) + self.post_quant_conv = make_conv_nd( + quant_dims, latent_channels, latent_channels, 1 + ) + else: + self.quant_conv = nn.Identity() + self.post_quant_conv = nn.Identity() + + if normalize_latent_channels: + if dims == 2: + self.latent_norm_out = nn.BatchNorm2d(latent_channels, affine=False) + else: + self.latent_norm_out = nn.BatchNorm3d(latent_channels, affine=False) + else: + self.latent_norm_out = nn.Identity() + self.use_z_tiling = False + self.use_hw_tiling = False + self.dims = dims + self.z_sample_size = 1 + + self.decoder_params = inspect.signature(self.decoder.forward).parameters + + # only relevant if vae tiling is enabled + self.set_tiling_params(sample_size=sample_size, overlap_factor=0.25) + + def set_tiling_params(self, sample_size: int = 512, overlap_factor: float = 0.25): + self.tile_sample_min_size = sample_size + num_blocks = len(self.encoder.down_blocks) + self.tile_latent_min_size = int(sample_size / (2 ** (num_blocks - 1))) + self.tile_overlap_factor = overlap_factor + + def enable_z_tiling(self, z_sample_size: int = 8): + r""" + Enable tiling during VAE decoding. + + When this option is enabled, the VAE will split the input tensor in tiles to compute decoding in several + steps. This is useful to save some memory and allow larger batch sizes. + """ + self.use_z_tiling = z_sample_size > 1 + self.z_sample_size = z_sample_size + assert ( + z_sample_size % 8 == 0 or z_sample_size == 1 + ), f"z_sample_size must be a multiple of 8 or 1. Got {z_sample_size}." + + def disable_z_tiling(self): + r""" + Disable tiling during VAE decoding. If `use_tiling` was previously invoked, this method will go back to computing + decoding in one step. + """ + self.use_z_tiling = False + + def enable_hw_tiling(self): + r""" + Enable tiling during VAE decoding along the height and width dimension. + """ + self.use_hw_tiling = True + + def disable_hw_tiling(self): + r""" + Disable tiling during VAE decoding along the height and width dimension. + """ + self.use_hw_tiling = False + + def _hw_tiled_encode(self, x: torch.FloatTensor, return_dict: bool = True): + overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor)) + blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor) + row_limit = self.tile_latent_min_size - blend_extent + + # Split the image into 512x512 tiles and encode them separately. 
+ rows = [] + for i in range(0, x.shape[3], overlap_size): + row = [] + for j in range(0, x.shape[4], overlap_size): + tile = x[ + :, + :, + :, + i : i + self.tile_sample_min_size, + j : j + self.tile_sample_min_size, + ] + tile = self.encoder(tile) + tile = self.quant_conv(tile) + row.append(tile) + rows.append(row) + result_rows = [] + for i, row in enumerate(rows): + result_row = [] + for j, tile in enumerate(row): + # blend the above tile and the left tile + # to the current tile and add the current tile to the result row + if i > 0: + tile = self.blend_v(rows[i - 1][j], tile, blend_extent) + if j > 0: + tile = self.blend_h(row[j - 1], tile, blend_extent) + result_row.append(tile[:, :, :, :row_limit, :row_limit]) + result_rows.append(torch.cat(result_row, dim=4)) + + moments = torch.cat(result_rows, dim=3) + return moments + + def blend_z( + self, a: torch.Tensor, b: torch.Tensor, blend_extent: int + ) -> torch.Tensor: + blend_extent = min(a.shape[2], b.shape[2], blend_extent) + for z in range(blend_extent): + b[:, :, z, :, :] = a[:, :, -blend_extent + z, :, :] * ( + 1 - z / blend_extent + ) + b[:, :, z, :, :] * (z / blend_extent) + return b + + def blend_v( + self, a: torch.Tensor, b: torch.Tensor, blend_extent: int + ) -> torch.Tensor: + blend_extent = min(a.shape[3], b.shape[3], blend_extent) + for y in range(blend_extent): + b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * ( + 1 - y / blend_extent + ) + b[:, :, :, y, :] * (y / blend_extent) + return b + + def blend_h( + self, a: torch.Tensor, b: torch.Tensor, blend_extent: int + ) -> torch.Tensor: + blend_extent = min(a.shape[4], b.shape[4], blend_extent) + for x in range(blend_extent): + b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * ( + 1 - x / blend_extent + ) + b[:, :, :, :, x] * (x / blend_extent) + return b + + def _hw_tiled_decode(self, z: torch.FloatTensor, target_shape): + overlap_size = int(self.tile_latent_min_size * (1 - self.tile_overlap_factor)) + blend_extent = int(self.tile_sample_min_size * self.tile_overlap_factor) + row_limit = self.tile_sample_min_size - blend_extent + tile_target_shape = ( + *target_shape[:3], + self.tile_sample_min_size, + self.tile_sample_min_size, + ) + # Split z into overlapping 64x64 tiles and decode them separately. + # The tiles have an overlap to avoid seams between tiles. 
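+        # Mirror of _hw_tiled_encode, but the stride is measured in latent units and the
+        # blending/cropping in pixel units. blend_v / blend_h cross-fade the overlapping
+        # border linearly, e.g. for the vertical direction:
+        #   b[..., y, :] = a[..., -blend_extent + y, :] * (1 - y / blend_extent)
+        #                  + b[..., y, :] * (y / blend_extent)
+        # With tile_latent_min_size=64, tile_sample_min_size=512, tile_overlap_factor=0.25:
+        #   overlap_size = 48, blend_extent = 128, row_limit = 384.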
+ rows = [] + for i in range(0, z.shape[3], overlap_size): + row = [] + for j in range(0, z.shape[4], overlap_size): + tile = z[ + :, + :, + :, + i : i + self.tile_latent_min_size, + j : j + self.tile_latent_min_size, + ] + tile = self.post_quant_conv(tile) + decoded = self.decoder(tile, target_shape=tile_target_shape) + row.append(decoded) + rows.append(row) + result_rows = [] + for i, row in enumerate(rows): + result_row = [] + for j, tile in enumerate(row): + # blend the above tile and the left tile + # to the current tile and add the current tile to the result row + if i > 0: + tile = self.blend_v(rows[i - 1][j], tile, blend_extent) + if j > 0: + tile = self.blend_h(row[j - 1], tile, blend_extent) + result_row.append(tile[:, :, :, :row_limit, :row_limit]) + result_rows.append(torch.cat(result_row, dim=4)) + + dec = torch.cat(result_rows, dim=3) + return dec + + def encode( + self, z: torch.FloatTensor, return_dict: bool = True + ) -> Union[DecoderOutput, torch.FloatTensor]: + if self.use_z_tiling and z.shape[2] > self.z_sample_size > 1: + num_splits = z.shape[2] // self.z_sample_size + sizes = [self.z_sample_size] * num_splits + sizes = ( + sizes + [z.shape[2] - sum(sizes)] + if z.shape[2] - sum(sizes) > 0 + else sizes + ) + tiles = z.split(sizes, dim=2) + moments_tiles = [ + ( + self._hw_tiled_encode(z_tile, return_dict) + if self.use_hw_tiling + else self._encode(z_tile) + ) + for z_tile in tiles + ] + moments = torch.cat(moments_tiles, dim=2) + + else: + moments = ( + self._hw_tiled_encode(z, return_dict) + if self.use_hw_tiling + else self._encode(z) + ) + + posterior = DiagonalGaussianDistribution(moments) + if not return_dict: + return (posterior,) + + return AutoencoderKLOutput(latent_dist=posterior) + + def _normalize_latent_channels(self, z: torch.FloatTensor) -> torch.FloatTensor: + if isinstance(self.latent_norm_out, nn.BatchNorm3d): + _, c, _, _, _ = z.shape + z = torch.cat( + [ + self.latent_norm_out(z[:, : c // 2, :, :, :]), + z[:, c // 2 :, :, :, :], + ], + dim=1, + ) + elif isinstance(self.latent_norm_out, nn.BatchNorm2d): + raise NotImplementedError("BatchNorm2d not supported") + return z + + def _unnormalize_latent_channels(self, z: torch.FloatTensor) -> torch.FloatTensor: + if isinstance(self.latent_norm_out, nn.BatchNorm3d): + running_mean = self.latent_norm_out.running_mean.view(1, -1, 1, 1, 1) + running_var = self.latent_norm_out.running_var.view(1, -1, 1, 1, 1) + eps = self.latent_norm_out.eps + + z = z * torch.sqrt(running_var + eps) + running_mean + elif isinstance(self.latent_norm_out, nn.BatchNorm3d): + raise NotImplementedError("BatchNorm2d not supported") + return z + + def _encode(self, x: torch.FloatTensor) -> AutoencoderKLOutput: + h = self.encoder(x) + moments = self.quant_conv(h) + moments = self._normalize_latent_channels(moments) + return moments + + def _decode( + self, + z: torch.FloatTensor, + target_shape=None, + timestep: Optional[torch.Tensor] = None, + ) -> Union[DecoderOutput, torch.FloatTensor]: + z = self._unnormalize_latent_channels(z) + z = self.post_quant_conv(z) + if "timestep" in self.decoder_params: + dec = self.decoder(z, target_shape=target_shape, timestep=timestep) + else: + dec = self.decoder(z, target_shape=target_shape) + return dec + + def decode( + self, + z: torch.FloatTensor, + return_dict: bool = True, + target_shape=None, + timestep: Optional[torch.Tensor] = None, + ) -> Union[DecoderOutput, torch.FloatTensor]: + assert target_shape is not None, "target_shape must be provided for decoding" + if self.use_z_tiling and 
z.shape[2] > self.z_sample_size > 1: + reduction_factor = int( + self.encoder.patch_size_t + * 2 + ** ( + len(self.encoder.down_blocks) + - 1 + - math.sqrt(self.encoder.patch_size) + ) + ) + split_size = self.z_sample_size // reduction_factor + num_splits = z.shape[2] // split_size + + # copy target shape, and divide frame dimension (=2) by the context size + target_shape_split = list(target_shape) + target_shape_split[2] = target_shape[2] // num_splits + + decoded_tiles = [ + ( + self._hw_tiled_decode(z_tile, target_shape_split) + if self.use_hw_tiling + else self._decode(z_tile, target_shape=target_shape_split) + ) + for z_tile in torch.tensor_split(z, num_splits, dim=2) + ] + decoded = torch.cat(decoded_tiles, dim=2) + else: + decoded = ( + self._hw_tiled_decode(z, target_shape) + if self.use_hw_tiling + else self._decode(z, target_shape=target_shape, timestep=timestep) + ) + + if not return_dict: + return (decoded,) + + return DecoderOutput(sample=decoded) + + def forward( + self, + sample: torch.FloatTensor, + sample_posterior: bool = False, + return_dict: bool = True, + generator: Optional[torch.Generator] = None, + ) -> Union[DecoderOutput, torch.FloatTensor]: + r""" + Args: + sample (`torch.FloatTensor`): Input sample. + sample_posterior (`bool`, *optional*, defaults to `False`): + Whether to sample from the posterior. + return_dict (`bool`, *optional*, defaults to `True`): + Whether to return a [`DecoderOutput`] instead of a plain tuple. + generator (`torch.Generator`, *optional*): + Generator used to sample from the posterior. + """ + x = sample + posterior = self.encode(x).latent_dist + if sample_posterior: + z = posterior.sample(generator=generator) + else: + z = posterior.mode() + dec = self.decode(z, target_shape=sample.shape).sample + + if not return_dict: + return (dec,) + + return DecoderOutput(sample=dec) diff --git a/ltx_video/models/autoencoders/vae_encode.py b/ltx_video/models/autoencoders/vae_encode.py new file mode 100644 index 0000000000000000000000000000000000000000..bfc97f6720ecbef51711cb47cd759532d8813128 --- /dev/null +++ b/ltx_video/models/autoencoders/vae_encode.py @@ -0,0 +1,247 @@ +from typing import Tuple +import torch +from diffusers import AutoencoderKL +from einops import rearrange +from torch import Tensor + + +from ltx_video.models.autoencoders.causal_video_autoencoder import ( + CausalVideoAutoencoder, +) +from ltx_video.models.autoencoders.video_autoencoder import ( + Downsample3D, + VideoAutoencoder, +) + +try: + import torch_xla.core.xla_model as xm +except ImportError: + xm = None + + +def vae_encode( + media_items: Tensor, + vae: AutoencoderKL, + split_size: int = 1, + vae_per_channel_normalize=False, +) -> Tensor: + """ + Encodes media items (images or videos) into latent representations using a specified VAE model. + The function supports processing batches of images or video frames and can handle the processing + in smaller sub-batches if needed. + + Args: + media_items (Tensor): A torch Tensor containing the media items to encode. The expected + shape is (batch_size, channels, height, width) for images or (batch_size, channels, + frames, height, width) for videos. + vae (AutoencoderKL): An instance of the `AutoencoderKL` class from the `diffusers` library, + pre-configured and loaded with the appropriate model weights. + split_size (int, optional): The number of sub-batches to split the input batch into for encoding. + If set to more than 1, the input media items are processed in smaller batches according to + this value. 
Defaults to 1, which processes all items in a single batch. + + Returns: + Tensor: A torch Tensor of the encoded latent representations. The shape of the tensor is adjusted + to match the input shape, scaled by the model's configuration. + + Examples: + >>> import torch + >>> from diffusers import AutoencoderKL + >>> vae = AutoencoderKL.from_pretrained('your-model-name') + >>> images = torch.rand(10, 3, 8 256, 256) # Example tensor with 10 videos of 8 frames. + >>> latents = vae_encode(images, vae) + >>> print(latents.shape) # Output shape will depend on the model's latent configuration. + + Note: + In case of a video, the function encodes the media item frame-by frame. + """ + is_video_shaped = media_items.dim() == 5 + batch_size, channels = media_items.shape[0:2] + + if channels != 3: + raise ValueError(f"Expects tensors with 3 channels, got {channels}.") + + if is_video_shaped and not isinstance( + vae, (VideoAutoencoder, CausalVideoAutoencoder) + ): + media_items = rearrange(media_items, "b c n h w -> (b n) c h w") + if split_size > 1: + if len(media_items) % split_size != 0: + raise ValueError( + "Error: The batch size must be divisible by 'train.vae_bs_split" + ) + encode_bs = len(media_items) // split_size + # latents = [vae.encode(image_batch).latent_dist.sample() for image_batch in media_items.split(encode_bs)] + latents = [] + if media_items.device.type == "xla": + xm.mark_step() + for image_batch in media_items.split(encode_bs): + latents.append(vae.encode(image_batch).latent_dist.sample()) + if media_items.device.type == "xla": + xm.mark_step() + latents = torch.cat(latents, dim=0) + else: + latents = vae.encode(media_items).latent_dist.sample() + + latents = normalize_latents(latents, vae, vae_per_channel_normalize) + if is_video_shaped and not isinstance( + vae, (VideoAutoencoder, CausalVideoAutoencoder) + ): + latents = rearrange(latents, "(b n) c h w -> b c n h w", b=batch_size) + return latents + + +def vae_decode( + latents: Tensor, + vae: AutoencoderKL, + is_video: bool = True, + split_size: int = 1, + vae_per_channel_normalize=False, + timestep=None, +) -> Tensor: + is_video_shaped = latents.dim() == 5 + batch_size = latents.shape[0] + + if is_video_shaped and not isinstance( + vae, (VideoAutoencoder, CausalVideoAutoencoder) + ): + latents = rearrange(latents, "b c n h w -> (b n) c h w") + if split_size > 1: + if len(latents) % split_size != 0: + raise ValueError( + "Error: The batch size must be divisible by 'train.vae_bs_split" + ) + encode_bs = len(latents) // split_size + image_batch = [ + _run_decoder( + latent_batch, vae, is_video, vae_per_channel_normalize, timestep + ) + for latent_batch in latents.split(encode_bs) + ] + images = torch.cat(image_batch, dim=0) + else: + images = _run_decoder( + latents, vae, is_video, vae_per_channel_normalize, timestep + ) + + if is_video_shaped and not isinstance( + vae, (VideoAutoencoder, CausalVideoAutoencoder) + ): + images = rearrange(images, "(b n) c h w -> b c n h w", b=batch_size) + return images + + +def _run_decoder( + latents: Tensor, + vae: AutoencoderKL, + is_video: bool, + vae_per_channel_normalize=False, + timestep=None, +) -> Tensor: + if isinstance(vae, (VideoAutoencoder, CausalVideoAutoencoder)): + *_, fl, hl, wl = latents.shape + temporal_scale, spatial_scale, _ = get_vae_size_scale_factor(vae) + latents = latents.to(vae.dtype) + vae_decode_kwargs = {} + if timestep is not None: + vae_decode_kwargs["timestep"] = timestep + image = vae.decode( + un_normalize_latents(latents, vae, vae_per_channel_normalize), + 
return_dict=False, + target_shape=( + 1, + 3, + fl * temporal_scale if is_video else 1, + hl * spatial_scale, + wl * spatial_scale, + ), + **vae_decode_kwargs, + )[0] + else: + image = vae.decode( + un_normalize_latents(latents, vae, vae_per_channel_normalize), + return_dict=False, + )[0] + return image + + +def get_vae_size_scale_factor(vae: AutoencoderKL) -> float: + if isinstance(vae, CausalVideoAutoencoder): + spatial = vae.spatial_downscale_factor + temporal = vae.temporal_downscale_factor + else: + down_blocks = len( + [ + block + for block in vae.encoder.down_blocks + if isinstance(block.downsample, Downsample3D) + ] + ) + spatial = vae.config.patch_size * 2**down_blocks + temporal = ( + vae.config.patch_size_t * 2**down_blocks + if isinstance(vae, VideoAutoencoder) + else 1 + ) + + return (temporal, spatial, spatial) + + +def latent_to_pixel_coords( + latent_coords: Tensor, vae: AutoencoderKL, causal_fix: bool = False +) -> Tensor: + """ + Converts latent coordinates to pixel coordinates by scaling them according to the VAE's + configuration. + + Args: + latent_coords (Tensor): A tensor of shape [batch_size, 3, num_latents] + containing the latent corner coordinates of each token. + vae (AutoencoderKL): The VAE model + causal_fix (bool): Whether to take into account the different temporal scale + of the first frame. Default = False for backwards compatibility. + Returns: + Tensor: A tensor of pixel coordinates corresponding to the input latent coordinates. + """ + + scale_factors = get_vae_size_scale_factor(vae) + causal_fix = isinstance(vae, CausalVideoAutoencoder) and causal_fix + pixel_coords = latent_to_pixel_coords_from_factors( + latent_coords, scale_factors, causal_fix + ) + return pixel_coords + + +def latent_to_pixel_coords_from_factors( + latent_coords: Tensor, scale_factors: Tuple, causal_fix: bool = False +) -> Tensor: + pixel_coords = ( + latent_coords + * torch.tensor(scale_factors, device=latent_coords.device)[None, :, None] + ) + if causal_fix: + # Fix temporal scale for first frame to 1 due to causality + pixel_coords[:, 0] = (pixel_coords[:, 0] + 1 - scale_factors[0]).clamp(min=0) + return pixel_coords + + +def normalize_latents( + latents: Tensor, vae: AutoencoderKL, vae_per_channel_normalize: bool = False +) -> Tensor: + return ( + (latents - vae.mean_of_means.to(latents.dtype).view(1, -1, 1, 1, 1)) + / vae.std_of_means.to(latents.dtype).view(1, -1, 1, 1, 1) + if vae_per_channel_normalize + else latents * vae.config.scaling_factor + ) + + +def un_normalize_latents( + latents: Tensor, vae: AutoencoderKL, vae_per_channel_normalize: bool = False +) -> Tensor: + return ( + latents * vae.std_of_means.to(latents.dtype).view(1, -1, 1, 1, 1) + + vae.mean_of_means.to(latents.dtype).view(1, -1, 1, 1, 1) + if vae_per_channel_normalize + else latents / vae.config.scaling_factor + ) diff --git a/ltx_video/models/autoencoders/video_autoencoder.py b/ltx_video/models/autoencoders/video_autoencoder.py new file mode 100644 index 0000000000000000000000000000000000000000..3c7926c1d3afb8188221b2e569aaaf89f7271bce --- /dev/null +++ b/ltx_video/models/autoencoders/video_autoencoder.py @@ -0,0 +1,1045 @@ +import json +import os +from functools import partial +from types import SimpleNamespace +from typing import Any, Mapping, Optional, Tuple, Union + +import torch +from einops import rearrange +from torch import nn +from torch.nn import functional + +from diffusers.utils import logging + +from ltx_video.utils.torch_utils import Identity +from 
ltx_video.models.autoencoders.conv_nd_factory import make_conv_nd, make_linear_nd +from ltx_video.models.autoencoders.pixel_norm import PixelNorm +from ltx_video.models.autoencoders.vae import AutoencoderKLWrapper + +logger = logging.get_logger(__name__) + + +class VideoAutoencoder(AutoencoderKLWrapper): + @classmethod + def from_pretrained( + cls, + pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], + *args, + **kwargs, + ): + config_local_path = pretrained_model_name_or_path / "config.json" + config = cls.load_config(config_local_path, **kwargs) + video_vae = cls.from_config(config) + video_vae.to(kwargs["torch_dtype"]) + + model_local_path = pretrained_model_name_or_path / "autoencoder.pth" + ckpt_state_dict = torch.load(model_local_path) + video_vae.load_state_dict(ckpt_state_dict) + + statistics_local_path = ( + pretrained_model_name_or_path / "per_channel_statistics.json" + ) + if statistics_local_path.exists(): + with open(statistics_local_path, "r") as file: + data = json.load(file) + transposed_data = list(zip(*data["data"])) + data_dict = { + col: torch.tensor(vals) + for col, vals in zip(data["columns"], transposed_data) + } + video_vae.register_buffer("std_of_means", data_dict["std-of-means"]) + video_vae.register_buffer( + "mean_of_means", + data_dict.get( + "mean-of-means", torch.zeros_like(data_dict["std-of-means"]) + ), + ) + + return video_vae + + @staticmethod + def from_config(config): + assert ( + config["_class_name"] == "VideoAutoencoder" + ), "config must have _class_name=VideoAutoencoder" + if isinstance(config["dims"], list): + config["dims"] = tuple(config["dims"]) + + assert config["dims"] in [2, 3, (2, 1)], "dims must be 2, 3 or (2, 1)" + + double_z = config.get("double_z", True) + latent_log_var = config.get( + "latent_log_var", "per_channel" if double_z else "none" + ) + use_quant_conv = config.get("use_quant_conv", True) + + if use_quant_conv and latent_log_var == "uniform": + raise ValueError("uniform latent_log_var requires use_quant_conv=False") + + encoder = Encoder( + dims=config["dims"], + in_channels=config.get("in_channels", 3), + out_channels=config["latent_channels"], + block_out_channels=config["block_out_channels"], + patch_size=config.get("patch_size", 1), + latent_log_var=latent_log_var, + norm_layer=config.get("norm_layer", "group_norm"), + patch_size_t=config.get("patch_size_t", config.get("patch_size", 1)), + add_channel_padding=config.get("add_channel_padding", False), + ) + + decoder = Decoder( + dims=config["dims"], + in_channels=config["latent_channels"], + out_channels=config.get("out_channels", 3), + block_out_channels=config["block_out_channels"], + patch_size=config.get("patch_size", 1), + norm_layer=config.get("norm_layer", "group_norm"), + patch_size_t=config.get("patch_size_t", config.get("patch_size", 1)), + add_channel_padding=config.get("add_channel_padding", False), + ) + + dims = config["dims"] + return VideoAutoencoder( + encoder=encoder, + decoder=decoder, + latent_channels=config["latent_channels"], + dims=dims, + use_quant_conv=use_quant_conv, + ) + + @property + def config(self): + return SimpleNamespace( + _class_name="VideoAutoencoder", + dims=self.dims, + in_channels=self.encoder.conv_in.in_channels + // (self.encoder.patch_size_t * self.encoder.patch_size**2), + out_channels=self.decoder.conv_out.out_channels + // (self.decoder.patch_size_t * self.decoder.patch_size**2), + latent_channels=self.decoder.conv_in.in_channels, + block_out_channels=[ + 
self.encoder.down_blocks[i].res_blocks[-1].conv1.out_channels + for i in range(len(self.encoder.down_blocks)) + ], + scaling_factor=1.0, + norm_layer=self.encoder.norm_layer, + patch_size=self.encoder.patch_size, + latent_log_var=self.encoder.latent_log_var, + use_quant_conv=self.use_quant_conv, + patch_size_t=self.encoder.patch_size_t, + add_channel_padding=self.encoder.add_channel_padding, + ) + + @property + def is_video_supported(self): + """ + Check if the model supports video inputs of shape (B, C, F, H, W). Otherwise, the model only supports 2D images. + """ + return self.dims != 2 + + @property + def downscale_factor(self): + return self.encoder.downsample_factor + + def to_json_string(self) -> str: + import json + + return json.dumps(self.config.__dict__) + + def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True): + model_keys = set(name for name, _ in self.named_parameters()) + + key_mapping = { + ".resnets.": ".res_blocks.", + "downsamplers.0": "downsample", + "upsamplers.0": "upsample", + } + + converted_state_dict = {} + for key, value in state_dict.items(): + for k, v in key_mapping.items(): + key = key.replace(k, v) + + if "norm" in key and key not in model_keys: + logger.info( + f"Removing key {key} from state_dict as it is not present in the model" + ) + continue + + converted_state_dict[key] = value + + super().load_state_dict(converted_state_dict, strict=strict) + + def last_layer(self): + if hasattr(self.decoder, "conv_out"): + if isinstance(self.decoder.conv_out, nn.Sequential): + last_layer = self.decoder.conv_out[-1] + else: + last_layer = self.decoder.conv_out + else: + last_layer = self.decoder.layers[-1] + return last_layer + + +class Encoder(nn.Module): + r""" + The `Encoder` layer of a variational autoencoder that encodes its input into a latent representation. + + Args: + in_channels (`int`, *optional*, defaults to 3): + The number of input channels. + out_channels (`int`, *optional*, defaults to 3): + The number of output channels. + block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(64,)`): + The number of output channels for each block. + layers_per_block (`int`, *optional*, defaults to 2): + The number of layers per block. + norm_num_groups (`int`, *optional*, defaults to 32): + The number of groups for normalization. + patch_size (`int`, *optional*, defaults to 1): + The patch size to use. Should be a power of 2. + norm_layer (`str`, *optional*, defaults to `group_norm`): + The normalization layer to use. Can be either `group_norm` or `pixel_norm`. + latent_log_var (`str`, *optional*, defaults to `per_channel`): + The number of channels for the log variance. Can be either `per_channel`, `uniform`, or `none`. + """ + + def __init__( + self, + dims: Union[int, Tuple[int, int]] = 3, + in_channels: int = 3, + out_channels: int = 3, + block_out_channels: Tuple[int, ...] 
= (64,), + layers_per_block: int = 2, + norm_num_groups: int = 32, + patch_size: Union[int, Tuple[int]] = 1, + norm_layer: str = "group_norm", # group_norm, pixel_norm + latent_log_var: str = "per_channel", + patch_size_t: Optional[int] = None, + add_channel_padding: Optional[bool] = False, + ): + super().__init__() + self.patch_size = patch_size + self.patch_size_t = patch_size_t if patch_size_t is not None else patch_size + self.add_channel_padding = add_channel_padding + self.layers_per_block = layers_per_block + self.norm_layer = norm_layer + self.latent_channels = out_channels + self.latent_log_var = latent_log_var + if add_channel_padding: + in_channels = in_channels * self.patch_size**3 + else: + in_channels = in_channels * self.patch_size_t * self.patch_size**2 + self.in_channels = in_channels + output_channel = block_out_channels[0] + + self.conv_in = make_conv_nd( + dims=dims, + in_channels=in_channels, + out_channels=output_channel, + kernel_size=3, + stride=1, + padding=1, + ) + + self.down_blocks = nn.ModuleList([]) + + for i in range(len(block_out_channels)): + input_channel = output_channel + output_channel = block_out_channels[i] + is_final_block = i == len(block_out_channels) - 1 + + down_block = DownEncoderBlock3D( + dims=dims, + in_channels=input_channel, + out_channels=output_channel, + num_layers=self.layers_per_block, + add_downsample=not is_final_block and 2**i >= patch_size, + resnet_eps=1e-6, + downsample_padding=0, + resnet_groups=norm_num_groups, + norm_layer=norm_layer, + ) + self.down_blocks.append(down_block) + + self.mid_block = UNetMidBlock3D( + dims=dims, + in_channels=block_out_channels[-1], + num_layers=self.layers_per_block, + resnet_eps=1e-6, + resnet_groups=norm_num_groups, + norm_layer=norm_layer, + ) + + # out + if norm_layer == "group_norm": + self.conv_norm_out = nn.GroupNorm( + num_channels=block_out_channels[-1], + num_groups=norm_num_groups, + eps=1e-6, + ) + elif norm_layer == "pixel_norm": + self.conv_norm_out = PixelNorm() + self.conv_act = nn.SiLU() + + conv_out_channels = out_channels + if latent_log_var == "per_channel": + conv_out_channels *= 2 + elif latent_log_var == "uniform": + conv_out_channels += 1 + elif latent_log_var != "none": + raise ValueError(f"Invalid latent_log_var: {latent_log_var}") + self.conv_out = make_conv_nd( + dims, block_out_channels[-1], conv_out_channels, 3, padding=1 + ) + + self.gradient_checkpointing = False + + @property + def downscale_factor(self): + return ( + 2 + ** len( + [ + block + for block in self.down_blocks + if isinstance(block.downsample, Downsample3D) + ] + ) + * self.patch_size + ) + + def forward( + self, sample: torch.FloatTensor, return_features=False + ) -> torch.FloatTensor: + r"""The forward method of the `Encoder` class.""" + + downsample_in_time = sample.shape[2] != 1 + + # patchify + patch_size_t = self.patch_size_t if downsample_in_time else 1 + sample = patchify( + sample, + patch_size_hw=self.patch_size, + patch_size_t=patch_size_t, + add_channel_padding=self.add_channel_padding, + ) + + sample = self.conv_in(sample) + + checkpoint_fn = ( + partial(torch.utils.checkpoint.checkpoint, use_reentrant=False) + if self.gradient_checkpointing and self.training + else lambda x: x + ) + + if return_features: + features = [] + for down_block in self.down_blocks: + sample = checkpoint_fn(down_block)( + sample, downsample_in_time=downsample_in_time + ) + if return_features: + features.append(sample) + + sample = checkpoint_fn(self.mid_block)(sample) + + # post-process + sample = 
self.conv_norm_out(sample) + sample = self.conv_act(sample) + sample = self.conv_out(sample) + + if self.latent_log_var == "uniform": + last_channel = sample[:, -1:, ...] + num_dims = sample.dim() + + if num_dims == 4: + # For shape (B, C, H, W) + repeated_last_channel = last_channel.repeat( + 1, sample.shape[1] - 2, 1, 1 + ) + sample = torch.cat([sample, repeated_last_channel], dim=1) + elif num_dims == 5: + # For shape (B, C, F, H, W) + repeated_last_channel = last_channel.repeat( + 1, sample.shape[1] - 2, 1, 1, 1 + ) + sample = torch.cat([sample, repeated_last_channel], dim=1) + else: + raise ValueError(f"Invalid input shape: {sample.shape}") + + if return_features: + features.append(sample[:, : self.latent_channels, ...]) + return sample, features + return sample + + +class Decoder(nn.Module): + r""" + The `Decoder` layer of a variational autoencoder that decodes its latent representation into an output sample. + + Args: + in_channels (`int`, *optional*, defaults to 3): + The number of input channels. + out_channels (`int`, *optional*, defaults to 3): + The number of output channels. + block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(64,)`): + The number of output channels for each block. + layers_per_block (`int`, *optional*, defaults to 2): + The number of layers per block. + norm_num_groups (`int`, *optional*, defaults to 32): + The number of groups for normalization. + patch_size (`int`, *optional*, defaults to 1): + The patch size to use. Should be a power of 2. + norm_layer (`str`, *optional*, defaults to `group_norm`): + The normalization layer to use. Can be either `group_norm` or `pixel_norm`. + """ + + def __init__( + self, + dims, + in_channels: int = 3, + out_channels: int = 3, + block_out_channels: Tuple[int, ...] = (64,), + layers_per_block: int = 2, + norm_num_groups: int = 32, + patch_size: int = 1, + norm_layer: str = "group_norm", + patch_size_t: Optional[int] = None, + add_channel_padding: Optional[bool] = False, + ): + super().__init__() + self.patch_size = patch_size + self.patch_size_t = patch_size_t if patch_size_t is not None else patch_size + self.add_channel_padding = add_channel_padding + self.layers_per_block = layers_per_block + if add_channel_padding: + out_channels = out_channels * self.patch_size**3 + else: + out_channels = out_channels * self.patch_size_t * self.patch_size**2 + self.out_channels = out_channels + + self.conv_in = make_conv_nd( + dims, + in_channels, + block_out_channels[-1], + kernel_size=3, + stride=1, + padding=1, + ) + + self.mid_block = None + self.up_blocks = nn.ModuleList([]) + + self.mid_block = UNetMidBlock3D( + dims=dims, + in_channels=block_out_channels[-1], + num_layers=self.layers_per_block, + resnet_eps=1e-6, + resnet_groups=norm_num_groups, + norm_layer=norm_layer, + ) + + reversed_block_out_channels = list(reversed(block_out_channels)) + output_channel = reversed_block_out_channels[0] + for i in range(len(reversed_block_out_channels)): + prev_output_channel = output_channel + output_channel = reversed_block_out_channels[i] + + is_final_block = i == len(block_out_channels) - 1 + + up_block = UpDecoderBlock3D( + dims=dims, + num_layers=self.layers_per_block + 1, + in_channels=prev_output_channel, + out_channels=output_channel, + add_upsample=not is_final_block + and 2 ** (len(block_out_channels) - i - 1) > patch_size, + resnet_eps=1e-6, + resnet_groups=norm_num_groups, + norm_layer=norm_layer, + ) + self.up_blocks.append(up_block) + + if norm_layer == "group_norm": + self.conv_norm_out = nn.GroupNorm( + 
num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6 + ) + elif norm_layer == "pixel_norm": + self.conv_norm_out = PixelNorm() + + self.conv_act = nn.SiLU() + self.conv_out = make_conv_nd( + dims, block_out_channels[0], out_channels, 3, padding=1 + ) + + self.gradient_checkpointing = False + + def forward(self, sample: torch.FloatTensor, target_shape) -> torch.FloatTensor: + r"""The forward method of the `Decoder` class.""" + assert target_shape is not None, "target_shape must be provided" + upsample_in_time = sample.shape[2] < target_shape[2] + + sample = self.conv_in(sample) + + upscale_dtype = next(iter(self.up_blocks.parameters())).dtype + + checkpoint_fn = ( + partial(torch.utils.checkpoint.checkpoint, use_reentrant=False) + if self.gradient_checkpointing and self.training + else lambda x: x + ) + + sample = checkpoint_fn(self.mid_block)(sample) + sample = sample.to(upscale_dtype) + + for up_block in self.up_blocks: + sample = checkpoint_fn(up_block)(sample, upsample_in_time=upsample_in_time) + + # post-process + sample = self.conv_norm_out(sample) + sample = self.conv_act(sample) + sample = self.conv_out(sample) + + # un-patchify + patch_size_t = self.patch_size_t if upsample_in_time else 1 + sample = unpatchify( + sample, + patch_size_hw=self.patch_size, + patch_size_t=patch_size_t, + add_channel_padding=self.add_channel_padding, + ) + + return sample + + +class DownEncoderBlock3D(nn.Module): + def __init__( + self, + dims: Union[int, Tuple[int, int]], + in_channels: int, + out_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_groups: int = 32, + add_downsample: bool = True, + downsample_padding: int = 1, + norm_layer: str = "group_norm", + ): + super().__init__() + res_blocks = [] + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + res_blocks.append( + ResnetBlock3D( + dims=dims, + in_channels=in_channels, + out_channels=out_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + norm_layer=norm_layer, + ) + ) + + self.res_blocks = nn.ModuleList(res_blocks) + + if add_downsample: + self.downsample = Downsample3D( + dims, + out_channels, + out_channels=out_channels, + padding=downsample_padding, + ) + else: + self.downsample = Identity() + + def forward( + self, hidden_states: torch.FloatTensor, downsample_in_time + ) -> torch.FloatTensor: + for resnet in self.res_blocks: + hidden_states = resnet(hidden_states) + + hidden_states = self.downsample( + hidden_states, downsample_in_time=downsample_in_time + ) + + return hidden_states + + +class UNetMidBlock3D(nn.Module): + """ + A 3D UNet mid-block [`UNetMidBlock3D`] with multiple residual blocks. + + Args: + in_channels (`int`): The number of input channels. + dropout (`float`, *optional*, defaults to 0.0): The dropout rate. + num_layers (`int`, *optional*, defaults to 1): The number of residual blocks. + resnet_eps (`float`, *optional*, 1e-6 ): The epsilon value for the resnet blocks. + resnet_groups (`int`, *optional*, defaults to 32): + The number of groups to use in the group normalization layers of the resnet blocks. + + Returns: + `torch.FloatTensor`: The output of the last residual block, which is a tensor of shape `(batch_size, + in_channels, height, width)`. 
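+
+    Example:
+        Illustrative only (not part of the original docstring); assumes ``make_conv_nd``
+        with ``dims=3`` builds plain ``Conv3d`` layers, so the block preserves the input
+        shape and channel count.
+
+        >>> block = UNetMidBlock3D(dims=3, in_channels=512, num_layers=2)
+        >>> block(torch.randn(1, 512, 2, 8, 8)).shape
+        torch.Size([1, 512, 2, 8, 8])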
+ + """ + + def __init__( + self, + dims: Union[int, Tuple[int, int]], + in_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_groups: int = 32, + norm_layer: str = "group_norm", + ): + super().__init__() + resnet_groups = ( + resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) + ) + + self.res_blocks = nn.ModuleList( + [ + ResnetBlock3D( + dims=dims, + in_channels=in_channels, + out_channels=in_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + norm_layer=norm_layer, + ) + for _ in range(num_layers) + ] + ) + + def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor: + for resnet in self.res_blocks: + hidden_states = resnet(hidden_states) + + return hidden_states + + +class UpDecoderBlock3D(nn.Module): + def __init__( + self, + dims: Union[int, Tuple[int, int]], + in_channels: int, + out_channels: int, + resolution_idx: Optional[int] = None, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_groups: int = 32, + add_upsample: bool = True, + norm_layer: str = "group_norm", + ): + super().__init__() + res_blocks = [] + + for i in range(num_layers): + input_channels = in_channels if i == 0 else out_channels + + res_blocks.append( + ResnetBlock3D( + dims=dims, + in_channels=input_channels, + out_channels=out_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + norm_layer=norm_layer, + ) + ) + + self.res_blocks = nn.ModuleList(res_blocks) + + if add_upsample: + self.upsample = Upsample3D( + dims=dims, channels=out_channels, out_channels=out_channels + ) + else: + self.upsample = Identity() + + self.resolution_idx = resolution_idx + + def forward( + self, hidden_states: torch.FloatTensor, upsample_in_time=True + ) -> torch.FloatTensor: + for resnet in self.res_blocks: + hidden_states = resnet(hidden_states) + + hidden_states = self.upsample(hidden_states, upsample_in_time=upsample_in_time) + + return hidden_states + + +class ResnetBlock3D(nn.Module): + r""" + A Resnet block. + + Parameters: + in_channels (`int`): The number of channels in the input. + out_channels (`int`, *optional*, default to be `None`): + The number of output channels for the first conv layer. If None, same as `in_channels`. + dropout (`float`, *optional*, defaults to `0.0`): The dropout probability to use. + groups (`int`, *optional*, default to `32`): The number of groups to use for the first normalization layer. + eps (`float`, *optional*, defaults to `1e-6`): The epsilon to use for the normalization. 
+ """ + + def __init__( + self, + dims: Union[int, Tuple[int, int]], + in_channels: int, + out_channels: Optional[int] = None, + conv_shortcut: bool = False, + dropout: float = 0.0, + groups: int = 32, + eps: float = 1e-6, + norm_layer: str = "group_norm", + ): + super().__init__() + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + self.use_conv_shortcut = conv_shortcut + + if norm_layer == "group_norm": + self.norm1 = torch.nn.GroupNorm( + num_groups=groups, num_channels=in_channels, eps=eps, affine=True + ) + elif norm_layer == "pixel_norm": + self.norm1 = PixelNorm() + + self.non_linearity = nn.SiLU() + + self.conv1 = make_conv_nd( + dims, in_channels, out_channels, kernel_size=3, stride=1, padding=1 + ) + + if norm_layer == "group_norm": + self.norm2 = torch.nn.GroupNorm( + num_groups=groups, num_channels=out_channels, eps=eps, affine=True + ) + elif norm_layer == "pixel_norm": + self.norm2 = PixelNorm() + + self.dropout = torch.nn.Dropout(dropout) + + self.conv2 = make_conv_nd( + dims, out_channels, out_channels, kernel_size=3, stride=1, padding=1 + ) + + self.conv_shortcut = ( + make_linear_nd( + dims=dims, in_channels=in_channels, out_channels=out_channels + ) + if in_channels != out_channels + else nn.Identity() + ) + + def forward( + self, + input_tensor: torch.FloatTensor, + ) -> torch.FloatTensor: + hidden_states = input_tensor + + hidden_states = self.norm1(hidden_states) + + hidden_states = self.non_linearity(hidden_states) + + hidden_states = self.conv1(hidden_states) + + hidden_states = self.norm2(hidden_states) + + hidden_states = self.non_linearity(hidden_states) + + hidden_states = self.dropout(hidden_states) + + hidden_states = self.conv2(hidden_states) + + input_tensor = self.conv_shortcut(input_tensor) + + output_tensor = input_tensor + hidden_states + + return output_tensor + + +class Downsample3D(nn.Module): + def __init__( + self, + dims, + in_channels: int, + out_channels: int, + kernel_size: int = 3, + padding: int = 1, + ): + super().__init__() + stride: int = 2 + self.padding = padding + self.in_channels = in_channels + self.dims = dims + self.conv = make_conv_nd( + dims=dims, + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + ) + + def forward(self, x, downsample_in_time=True): + conv = self.conv + if self.padding == 0: + if self.dims == 2: + padding = (0, 1, 0, 1) + else: + padding = (0, 1, 0, 1, 0, 1 if downsample_in_time else 0) + + x = functional.pad(x, padding, mode="constant", value=0) + + if self.dims == (2, 1) and not downsample_in_time: + return conv(x, skip_time_conv=True) + + return conv(x) + + +class Upsample3D(nn.Module): + """ + An upsampling layer for 3D tensors of shape (B, C, D, H, W). + + :param channels: channels in the inputs and outputs. 
+ """ + + def __init__(self, dims, channels, out_channels=None): + super().__init__() + self.dims = dims + self.channels = channels + self.out_channels = out_channels or channels + self.conv = make_conv_nd( + dims, channels, out_channels, kernel_size=3, padding=1, bias=True + ) + + def forward(self, x, upsample_in_time): + if self.dims == 2: + x = functional.interpolate( + x, (x.shape[2] * 2, x.shape[3] * 2), mode="nearest" + ) + else: + time_scale_factor = 2 if upsample_in_time else 1 + # print("before:", x.shape) + b, c, d, h, w = x.shape + x = rearrange(x, "b c d h w -> (b d) c h w") + # height and width interpolate + x = functional.interpolate( + x, (x.shape[2] * 2, x.shape[3] * 2), mode="nearest" + ) + _, _, h, w = x.shape + + if not upsample_in_time and self.dims == (2, 1): + x = rearrange(x, "(b d) c h w -> b c d h w ", b=b, h=h, w=w) + return self.conv(x, skip_time_conv=True) + + # Second ** upsampling ** which is essentially treated as a 1D convolution across the 'd' dimension + x = rearrange(x, "(b d) c h w -> (b h w) c 1 d", b=b) + + # (b h w) c 1 d + new_d = x.shape[-1] * time_scale_factor + x = functional.interpolate(x, (1, new_d), mode="nearest") + # (b h w) c 1 new_d + x = rearrange( + x, "(b h w) c 1 new_d -> b c new_d h w", b=b, h=h, w=w, new_d=new_d + ) + # b c d h w + + # x = functional.interpolate( + # x, (x.shape[2] * time_scale_factor, x.shape[3] * 2, x.shape[4] * 2), mode="nearest" + # ) + # print("after:", x.shape) + + return self.conv(x) + + +def patchify(x, patch_size_hw, patch_size_t=1, add_channel_padding=False): + if patch_size_hw == 1 and patch_size_t == 1: + return x + if x.dim() == 4: + x = rearrange( + x, "b c (h q) (w r) -> b (c r q) h w", q=patch_size_hw, r=patch_size_hw + ) + elif x.dim() == 5: + x = rearrange( + x, + "b c (f p) (h q) (w r) -> b (c p r q) f h w", + p=patch_size_t, + q=patch_size_hw, + r=patch_size_hw, + ) + else: + raise ValueError(f"Invalid input shape: {x.shape}") + + if ( + (x.dim() == 5) + and (patch_size_hw > patch_size_t) + and (patch_size_t > 1 or add_channel_padding) + ): + channels_to_pad = x.shape[1] * (patch_size_hw // patch_size_t) - x.shape[1] + padding_zeros = torch.zeros( + x.shape[0], + channels_to_pad, + x.shape[2], + x.shape[3], + x.shape[4], + device=x.device, + dtype=x.dtype, + ) + x = torch.cat([padding_zeros, x], dim=1) + + return x + + +def unpatchify(x, patch_size_hw, patch_size_t=1, add_channel_padding=False): + if patch_size_hw == 1 and patch_size_t == 1: + return x + + if ( + (x.dim() == 5) + and (patch_size_hw > patch_size_t) + and (patch_size_t > 1 or add_channel_padding) + ): + channels_to_keep = int(x.shape[1] * (patch_size_t / patch_size_hw)) + x = x[:, :channels_to_keep, :, :, :] + + if x.dim() == 4: + x = rearrange( + x, "b (c r q) h w -> b c (h q) (w r)", q=patch_size_hw, r=patch_size_hw + ) + elif x.dim() == 5: + x = rearrange( + x, + "b (c p r q) f h w -> b c (f p) (h q) (w r)", + p=patch_size_t, + q=patch_size_hw, + r=patch_size_hw, + ) + + return x + + +def create_video_autoencoder_config( + latent_channels: int = 4, +): + config = { + "_class_name": "VideoAutoencoder", + "dims": ( + 2, + 1, + ), # 2 for Conv2, 3 for Conv3d, (2, 1) for Conv2d followed by Conv1d + "in_channels": 3, # Number of input color channels (e.g., RGB) + "out_channels": 3, # Number of output color channels + "latent_channels": latent_channels, # Number of channels in the latent space representation + "block_out_channels": [ + 128, + 256, + 512, + 512, + ], # Number of output channels of each encoder / decoder inner block + 
"patch_size": 1, + } + + return config + + +def create_video_autoencoder_pathify4x4x4_config( + latent_channels: int = 4, +): + config = { + "_class_name": "VideoAutoencoder", + "dims": ( + 2, + 1, + ), # 2 for Conv2, 3 for Conv3d, (2, 1) for Conv2d followed by Conv1d + "in_channels": 3, # Number of input color channels (e.g., RGB) + "out_channels": 3, # Number of output color channels + "latent_channels": latent_channels, # Number of channels in the latent space representation + "block_out_channels": [512] + * 4, # Number of output channels of each encoder / decoder inner block + "patch_size": 4, + "latent_log_var": "uniform", + } + + return config + + +def create_video_autoencoder_pathify4x4_config( + latent_channels: int = 4, +): + config = { + "_class_name": "VideoAutoencoder", + "dims": 2, # 2 for Conv2, 3 for Conv3d, (2, 1) for Conv2d followed by Conv1d + "in_channels": 3, # Number of input color channels (e.g., RGB) + "out_channels": 3, # Number of output color channels + "latent_channels": latent_channels, # Number of channels in the latent space representation + "block_out_channels": [512] + * 4, # Number of output channels of each encoder / decoder inner block + "patch_size": 4, + "norm_layer": "pixel_norm", + } + + return config + + +def test_vae_patchify_unpatchify(): + import torch + + x = torch.randn(2, 3, 8, 64, 64) + x_patched = patchify(x, patch_size_hw=4, patch_size_t=4) + x_unpatched = unpatchify(x_patched, patch_size_hw=4, patch_size_t=4) + assert torch.allclose(x, x_unpatched) + + +def demo_video_autoencoder_forward_backward(): + # Configuration for the VideoAutoencoder + config = create_video_autoencoder_pathify4x4x4_config() + + # Instantiate the VideoAutoencoder with the specified configuration + video_autoencoder = VideoAutoencoder.from_config(config) + + print(video_autoencoder) + + # Print the total number of parameters in the video autoencoder + total_params = sum(p.numel() for p in video_autoencoder.parameters()) + print(f"Total number of parameters in VideoAutoencoder: {total_params:,}") + + # Create a mock input tensor simulating a batch of videos + # Shape: (batch_size, channels, depth, height, width) + # E.g., 4 videos, each with 3 color channels, 16 frames, and 64x64 pixels per frame + input_videos = torch.randn(2, 3, 8, 64, 64) + + # Forward pass: encode and decode the input videos + latent = video_autoencoder.encode(input_videos).latent_dist.mode() + print(f"input shape={input_videos.shape}") + print(f"latent shape={latent.shape}") + reconstructed_videos = video_autoencoder.decode( + latent, target_shape=input_videos.shape + ).sample + + print(f"reconstructed shape={reconstructed_videos.shape}") + + # Calculate the loss (e.g., mean squared error) + loss = torch.nn.functional.mse_loss(input_videos, reconstructed_videos) + + # Perform backward pass + loss.backward() + + print(f"Demo completed with loss: {loss.item()}") + + +# Ensure to call the demo function to execute the forward and backward pass +if __name__ == "__main__": + demo_video_autoencoder_forward_backward() diff --git a/ltx_video/models/transformers/__init__.py b/ltx_video/models/transformers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/ltx_video/models/transformers/attention.py b/ltx_video/models/transformers/attention.py new file mode 100644 index 0000000000000000000000000000000000000000..bee0839ad78bfc33d2e940818edec2701ece99c7 --- /dev/null +++ b/ltx_video/models/transformers/attention.py @@ -0,0 
+1,1264 @@ +import inspect +from importlib import import_module +from typing import Any, Dict, Optional, Tuple + +import torch +import torch.nn.functional as F +from diffusers.models.activations import GEGLU, GELU, ApproximateGELU +from diffusers.models.attention import _chunked_feed_forward +from diffusers.models.attention_processor import ( + LoRAAttnAddedKVProcessor, + LoRAAttnProcessor, + LoRAAttnProcessor2_0, + LoRAXFormersAttnProcessor, + SpatialNorm, +) +from diffusers.models.lora import LoRACompatibleLinear +from diffusers.models.normalization import RMSNorm +from diffusers.utils import deprecate, logging +from diffusers.utils.torch_utils import maybe_allow_in_graph +from einops import rearrange +from torch import nn + +from ltx_video.utils.skip_layer_strategy import SkipLayerStrategy + +try: + from torch_xla.experimental.custom_kernel import flash_attention +except ImportError: + # workaround for automatic tests. Currently this function is manually patched + # to the torch_xla lib on setup of container + pass + +# code adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention.py + +logger = logging.get_logger(__name__) + + +@maybe_allow_in_graph +class BasicTransformerBlock(nn.Module): + r""" + A basic Transformer block. + + Parameters: + dim (`int`): The number of channels in the input and output. + num_attention_heads (`int`): The number of heads to use for multi-head attention. + attention_head_dim (`int`): The number of channels in each head. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention. + activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. + num_embeds_ada_norm (: + obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`. + attention_bias (: + obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter. + only_cross_attention (`bool`, *optional*): + Whether to use only cross-attention layers. In this case two cross attention layers are used. + double_self_attention (`bool`, *optional*): + Whether to use two self-attention layers. In this case no cross attention layers are used. + upcast_attention (`bool`, *optional*): + Whether to upcast the attention computation to float32. This is useful for mixed precision training. + norm_elementwise_affine (`bool`, *optional*, defaults to `True`): + Whether to use learnable elementwise affine parameters for normalization. + qk_norm (`str`, *optional*, defaults to None): + Set to 'layer_norm' or `rms_norm` to perform query and key normalization. + adaptive_norm (`str`, *optional*, defaults to `"single_scale_shift"`): + The type of adaptive norm to use. Can be `"single_scale_shift"`, `"single_scale"` or "none". + standardization_norm (`str`, *optional*, defaults to `"layer_norm"`): + The type of pre-normalization to use. Can be `"layer_norm"` or `"rms_norm"`. + final_dropout (`bool` *optional*, defaults to False): + Whether to apply a final dropout after the last feed-forward layer. + attention_type (`str`, *optional*, defaults to `"default"`): + The type of attention to use. Can be `"default"` or `"gated"` or `"gated-text-image"`. + positional_embeddings (`str`, *optional*, defaults to `None`): + The type of positional embeddings to apply to. 
+ num_positional_embeddings (`int`, *optional*, defaults to `None`): + The maximum number of positional embeddings to apply. + """ + + def __init__( + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + dropout=0.0, + cross_attention_dim: Optional[int] = None, + activation_fn: str = "geglu", + num_embeds_ada_norm: Optional[int] = None, # pylint: disable=unused-argument + attention_bias: bool = False, + only_cross_attention: bool = False, + double_self_attention: bool = False, + upcast_attention: bool = False, + norm_elementwise_affine: bool = True, + adaptive_norm: str = "single_scale_shift", # 'single_scale_shift', 'single_scale' or 'none' + standardization_norm: str = "layer_norm", # 'layer_norm' or 'rms_norm' + norm_eps: float = 1e-5, + qk_norm: Optional[str] = None, + final_dropout: bool = False, + attention_type: str = "default", # pylint: disable=unused-argument + ff_inner_dim: Optional[int] = None, + ff_bias: bool = True, + attention_out_bias: bool = True, + use_tpu_flash_attention: bool = False, + use_rope: bool = False, + ): + super().__init__() + self.only_cross_attention = only_cross_attention + self.use_tpu_flash_attention = use_tpu_flash_attention + self.adaptive_norm = adaptive_norm + + assert standardization_norm in ["layer_norm", "rms_norm"] + assert adaptive_norm in ["single_scale_shift", "single_scale", "none"] + + make_norm_layer = ( + nn.LayerNorm if standardization_norm == "layer_norm" else RMSNorm + ) + + # Define 3 blocks. Each block has its own normalization layer. + # 1. Self-Attn + self.norm1 = make_norm_layer( + dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps + ) + + self.attn1 = Attention( + query_dim=dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + cross_attention_dim=cross_attention_dim if only_cross_attention else None, + upcast_attention=upcast_attention, + out_bias=attention_out_bias, + use_tpu_flash_attention=use_tpu_flash_attention, + qk_norm=qk_norm, + use_rope=use_rope, + ) + + # 2. Cross-Attn + if cross_attention_dim is not None or double_self_attention: + self.attn2 = Attention( + query_dim=dim, + cross_attention_dim=( + cross_attention_dim if not double_self_attention else None + ), + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + upcast_attention=upcast_attention, + out_bias=attention_out_bias, + use_tpu_flash_attention=use_tpu_flash_attention, + qk_norm=qk_norm, + use_rope=use_rope, + ) # is self-attn if encoder_hidden_states is none + + if adaptive_norm == "none": + self.attn2_norm = make_norm_layer( + dim, norm_eps, norm_elementwise_affine + ) + else: + self.attn2 = None + self.attn2_norm = None + + self.norm2 = make_norm_layer(dim, norm_eps, norm_elementwise_affine) + + # 3. Feed-forward + self.ff = FeedForward( + dim, + dropout=dropout, + activation_fn=activation_fn, + final_dropout=final_dropout, + inner_dim=ff_inner_dim, + bias=ff_bias, + ) + + # 5. Scale-shift for PixArt-Alpha. + if adaptive_norm != "none": + num_ada_params = 4 if adaptive_norm == "single_scale" else 6 + self.scale_shift_table = nn.Parameter( + torch.randn(num_ada_params, dim) / dim**0.5 + ) + + # let chunk size default to None + self._chunk_size = None + self._chunk_dim = 0 + + def set_use_tpu_flash_attention(self): + r""" + Function sets the flag in this object and propagates down the children. The flag will enforce the usage of TPU + attention kernel. 
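+
+        Note: this assumes the block owns a cross-attention layer; when it was built with
+        ``cross_attention_dim=None`` and ``double_self_attention=False``, ``self.attn2`` is
+        ``None`` and the propagation below would raise an ``AttributeError``.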
+ """ + self.use_tpu_flash_attention = True + self.attn1.set_use_tpu_flash_attention() + self.attn2.set_use_tpu_flash_attention() + + def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0): + # Sets chunk feed-forward + self._chunk_size = chunk_size + self._chunk_dim = dim + + def forward( + self, + hidden_states: torch.FloatTensor, + freqs_cis: Optional[Tuple[torch.FloatTensor, torch.FloatTensor]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + timestep: Optional[torch.LongTensor] = None, + cross_attention_kwargs: Dict[str, Any] = None, + class_labels: Optional[torch.LongTensor] = None, + skip_layer_mask: Optional[torch.Tensor] = None, + skip_layer_strategy: Optional[SkipLayerStrategy] = None, + ) -> torch.FloatTensor: + if cross_attention_kwargs is not None: + if cross_attention_kwargs.get("scale", None) is not None: + logger.warning( + "Passing `scale` to `cross_attention_kwargs` is depcrecated. `scale` will be ignored." + ) + + # Notice that normalization is always applied before the real computation in the following blocks. + # 0. Self-Attention + batch_size = hidden_states.shape[0] + + original_hidden_states = hidden_states + + norm_hidden_states = self.norm1(hidden_states) + + # Apply ada_norm_single + if self.adaptive_norm in ["single_scale_shift", "single_scale"]: + assert timestep.ndim == 3 # [batch, 1 or num_tokens, embedding_dim] + num_ada_params = self.scale_shift_table.shape[0] + ada_values = self.scale_shift_table[None, None] + timestep.reshape( + batch_size, timestep.shape[1], num_ada_params, -1 + ) + if self.adaptive_norm == "single_scale_shift": + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ( + ada_values.unbind(dim=2) + ) + norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa + else: + scale_msa, gate_msa, scale_mlp, gate_mlp = ada_values.unbind(dim=2) + norm_hidden_states = norm_hidden_states * (1 + scale_msa) + elif self.adaptive_norm == "none": + scale_msa, gate_msa, scale_mlp, gate_mlp = None, None, None, None + else: + raise ValueError(f"Unknown adaptive norm type: {self.adaptive_norm}") + + norm_hidden_states = norm_hidden_states.squeeze( + 1 + ) # TODO: Check if this is needed + + # 1. Prepare GLIGEN inputs + cross_attention_kwargs = ( + cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {} + ) + + attn_output = self.attn1( + norm_hidden_states, + freqs_cis=freqs_cis, + encoder_hidden_states=( + encoder_hidden_states if self.only_cross_attention else None + ), + attention_mask=attention_mask, + skip_layer_mask=skip_layer_mask, + skip_layer_strategy=skip_layer_strategy, + **cross_attention_kwargs, + ) + if gate_msa is not None: + attn_output = gate_msa * attn_output + + hidden_states = attn_output + hidden_states + if hidden_states.ndim == 4: + hidden_states = hidden_states.squeeze(1) + + # 3. Cross-Attention + if self.attn2 is not None: + if self.adaptive_norm == "none": + attn_input = self.attn2_norm(hidden_states) + else: + attn_input = hidden_states + attn_output = self.attn2( + attn_input, + freqs_cis=freqs_cis, + encoder_hidden_states=encoder_hidden_states, + attention_mask=encoder_attention_mask, + **cross_attention_kwargs, + ) + hidden_states = attn_output + hidden_states + + # 4. 
Feed-forward + norm_hidden_states = self.norm2(hidden_states) + if self.adaptive_norm == "single_scale_shift": + norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp + elif self.adaptive_norm == "single_scale": + norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + elif self.adaptive_norm == "none": + pass + else: + raise ValueError(f"Unknown adaptive norm type: {self.adaptive_norm}") + + if self._chunk_size is not None: + # "feed_forward_chunk_size" can be used to save memory + ff_output = _chunked_feed_forward( + self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size + ) + else: + ff_output = self.ff(norm_hidden_states) + if gate_mlp is not None: + ff_output = gate_mlp * ff_output + + hidden_states = ff_output + hidden_states + if hidden_states.ndim == 4: + hidden_states = hidden_states.squeeze(1) + + if ( + skip_layer_mask is not None + and skip_layer_strategy == SkipLayerStrategy.TransformerBlock + ): + skip_layer_mask = skip_layer_mask.view(-1, 1, 1) + hidden_states = hidden_states * skip_layer_mask + original_hidden_states * ( + 1.0 - skip_layer_mask + ) + + return hidden_states + + +@maybe_allow_in_graph +class Attention(nn.Module): + r""" + A cross attention layer. + + Parameters: + query_dim (`int`): + The number of channels in the query. + cross_attention_dim (`int`, *optional*): + The number of channels in the encoder_hidden_states. If not given, defaults to `query_dim`. + heads (`int`, *optional*, defaults to 8): + The number of heads to use for multi-head attention. + dim_head (`int`, *optional*, defaults to 64): + The number of channels in each head. + dropout (`float`, *optional*, defaults to 0.0): + The dropout probability to use. + bias (`bool`, *optional*, defaults to False): + Set to `True` for the query, key, and value linear layers to contain a bias parameter. + upcast_attention (`bool`, *optional*, defaults to False): + Set to `True` to upcast the attention computation to `float32`. + upcast_softmax (`bool`, *optional*, defaults to False): + Set to `True` to upcast the softmax computation to `float32`. + cross_attention_norm (`str`, *optional*, defaults to `None`): + The type of normalization to use for the cross attention. Can be `None`, `layer_norm`, or `group_norm`. + cross_attention_norm_num_groups (`int`, *optional*, defaults to 32): + The number of groups to use for the group norm in the cross attention. + added_kv_proj_dim (`int`, *optional*, defaults to `None`): + The number of channels to use for the added key and value projections. If `None`, no projection is used. + norm_num_groups (`int`, *optional*, defaults to `None`): + The number of groups to use for the group norm in the attention. + spatial_norm_dim (`int`, *optional*, defaults to `None`): + The number of channels to use for the spatial normalization. + out_bias (`bool`, *optional*, defaults to `True`): + Set to `True` to use a bias in the output linear layer. + scale_qk (`bool`, *optional*, defaults to `True`): + Set to `True` to scale the query and key by `1 / sqrt(dim_head)`. + qk_norm (`str`, *optional*, defaults to None): + Set to 'layer_norm' or `rms_norm` to perform query and key normalization. + only_cross_attention (`bool`, *optional*, defaults to `False`): + Set to `True` to only use cross attention and not added_kv_proj_dim. Can only be set to `True` if + `added_kv_proj_dim` is not `None`. + eps (`float`, *optional*, defaults to 1e-5): + An additional value added to the denominator in group normalization that is used for numerical stability. 
+ rescale_output_factor (`float`, *optional*, defaults to 1.0): + A factor to rescale the output by dividing it with this value. + residual_connection (`bool`, *optional*, defaults to `False`): + Set to `True` to add the residual connection to the output. + _from_deprecated_attn_block (`bool`, *optional*, defaults to `False`): + Set to `True` if the attention block is loaded from a deprecated state dict. + processor (`AttnProcessor`, *optional*, defaults to `None`): + The attention processor to use. If `None`, defaults to `AttnProcessor2_0` if `torch 2.x` is used and + `AttnProcessor` otherwise. + """ + + def __init__( + self, + query_dim: int, + cross_attention_dim: Optional[int] = None, + heads: int = 8, + dim_head: int = 64, + dropout: float = 0.0, + bias: bool = False, + upcast_attention: bool = False, + upcast_softmax: bool = False, + cross_attention_norm: Optional[str] = None, + cross_attention_norm_num_groups: int = 32, + added_kv_proj_dim: Optional[int] = None, + norm_num_groups: Optional[int] = None, + spatial_norm_dim: Optional[int] = None, + out_bias: bool = True, + scale_qk: bool = True, + qk_norm: Optional[str] = None, + only_cross_attention: bool = False, + eps: float = 1e-5, + rescale_output_factor: float = 1.0, + residual_connection: bool = False, + _from_deprecated_attn_block: bool = False, + processor: Optional["AttnProcessor"] = None, + out_dim: int = None, + use_tpu_flash_attention: bool = False, + use_rope: bool = False, + ): + super().__init__() + self.inner_dim = out_dim if out_dim is not None else dim_head * heads + self.query_dim = query_dim + self.use_bias = bias + self.is_cross_attention = cross_attention_dim is not None + self.cross_attention_dim = ( + cross_attention_dim if cross_attention_dim is not None else query_dim + ) + self.upcast_attention = upcast_attention + self.upcast_softmax = upcast_softmax + self.rescale_output_factor = rescale_output_factor + self.residual_connection = residual_connection + self.dropout = dropout + self.fused_projections = False + self.out_dim = out_dim if out_dim is not None else query_dim + self.use_tpu_flash_attention = use_tpu_flash_attention + self.use_rope = use_rope + + # we make use of this private variable to know whether this class is loaded + # with an deprecated state dict so that we can convert it on the fly + self._from_deprecated_attn_block = _from_deprecated_attn_block + + self.scale_qk = scale_qk + self.scale = dim_head**-0.5 if self.scale_qk else 1.0 + + if qk_norm is None: + self.q_norm = nn.Identity() + self.k_norm = nn.Identity() + elif qk_norm == "rms_norm": + self.q_norm = RMSNorm(dim_head * heads, eps=1e-5) + self.k_norm = RMSNorm(dim_head * heads, eps=1e-5) + elif qk_norm == "layer_norm": + self.q_norm = nn.LayerNorm(dim_head * heads, eps=1e-5) + self.k_norm = nn.LayerNorm(dim_head * heads, eps=1e-5) + else: + raise ValueError(f"Unsupported qk_norm method: {qk_norm}") + + self.heads = out_dim // dim_head if out_dim is not None else heads + # for slice_size > 0 the attention score computation + # is split across the batch axis to save memory + # You can set slice_size with `set_attention_slice` + self.sliceable_head_dim = heads + + self.added_kv_proj_dim = added_kv_proj_dim + self.only_cross_attention = only_cross_attention + + if self.added_kv_proj_dim is None and self.only_cross_attention: + raise ValueError( + "`only_cross_attention` can only be set to True if `added_kv_proj_dim` is not None. Make sure to set either `only_cross_attention=False` or define `added_kv_proj_dim`." 
+ ) + + if norm_num_groups is not None: + self.group_norm = nn.GroupNorm( + num_channels=query_dim, num_groups=norm_num_groups, eps=eps, affine=True + ) + else: + self.group_norm = None + + if spatial_norm_dim is not None: + self.spatial_norm = SpatialNorm( + f_channels=query_dim, zq_channels=spatial_norm_dim + ) + else: + self.spatial_norm = None + + if cross_attention_norm is None: + self.norm_cross = None + elif cross_attention_norm == "layer_norm": + self.norm_cross = nn.LayerNorm(self.cross_attention_dim) + elif cross_attention_norm == "group_norm": + if self.added_kv_proj_dim is not None: + # The given `encoder_hidden_states` are initially of shape + # (batch_size, seq_len, added_kv_proj_dim) before being projected + # to (batch_size, seq_len, cross_attention_dim). The norm is applied + # before the projection, so we need to use `added_kv_proj_dim` as + # the number of channels for the group norm. + norm_cross_num_channels = added_kv_proj_dim + else: + norm_cross_num_channels = self.cross_attention_dim + + self.norm_cross = nn.GroupNorm( + num_channels=norm_cross_num_channels, + num_groups=cross_attention_norm_num_groups, + eps=1e-5, + affine=True, + ) + else: + raise ValueError( + f"unknown cross_attention_norm: {cross_attention_norm}. Should be None, 'layer_norm' or 'group_norm'" + ) + + linear_cls = nn.Linear + + self.linear_cls = linear_cls + self.to_q = linear_cls(query_dim, self.inner_dim, bias=bias) + + if not self.only_cross_attention: + # only relevant for the `AddedKVProcessor` classes + self.to_k = linear_cls(self.cross_attention_dim, self.inner_dim, bias=bias) + self.to_v = linear_cls(self.cross_attention_dim, self.inner_dim, bias=bias) + else: + self.to_k = None + self.to_v = None + + if self.added_kv_proj_dim is not None: + self.add_k_proj = linear_cls(added_kv_proj_dim, self.inner_dim) + self.add_v_proj = linear_cls(added_kv_proj_dim, self.inner_dim) + + self.to_out = nn.ModuleList([]) + self.to_out.append(linear_cls(self.inner_dim, self.out_dim, bias=out_bias)) + self.to_out.append(nn.Dropout(dropout)) + + # set attention processor + # We use the AttnProcessor2_0 by default when torch 2.x is used which uses + # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention + # but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1 + if processor is None: + processor = AttnProcessor2_0() + self.set_processor(processor) + + def set_use_tpu_flash_attention(self): + r""" + Function sets the flag in this object. The flag will enforce the usage of TPU attention kernel. + """ + self.use_tpu_flash_attention = True + + def set_processor(self, processor: "AttnProcessor") -> None: + r""" + Set the attention processor to use. + + Args: + processor (`AttnProcessor`): + The attention processor to use. + """ + # if current processor is in `self._modules` and if passed `processor` is not, we need to + # pop `processor` from `self._modules` + if ( + hasattr(self, "processor") + and isinstance(self.processor, torch.nn.Module) + and not isinstance(processor, torch.nn.Module) + ): + logger.info( + f"You are removing possibly trained weights of {self.processor} with {processor}" + ) + self._modules.pop("processor") + + self.processor = processor + + def get_processor( + self, return_deprecated_lora: bool = False + ) -> "AttentionProcessor": # noqa: F821 + r""" + Get the attention processor in use. 
+ + Args: + return_deprecated_lora (`bool`, *optional*, defaults to `False`): + Set to `True` to return the deprecated LoRA attention processor. + + Returns: + "AttentionProcessor": The attention processor in use. + """ + if not return_deprecated_lora: + return self.processor + + # TODO(Sayak, Patrick). The rest of the function is needed to ensure backwards compatible + # serialization format for LoRA Attention Processors. It should be deleted once the integration + # with PEFT is completed. + is_lora_activated = { + name: module.lora_layer is not None + for name, module in self.named_modules() + if hasattr(module, "lora_layer") + } + + # 1. if no layer has a LoRA activated we can return the processor as usual + if not any(is_lora_activated.values()): + return self.processor + + # If doesn't apply LoRA do `add_k_proj` or `add_v_proj` + is_lora_activated.pop("add_k_proj", None) + is_lora_activated.pop("add_v_proj", None) + # 2. else it is not posssible that only some layers have LoRA activated + if not all(is_lora_activated.values()): + raise ValueError( + f"Make sure that either all layers or no layers have LoRA activated, but have {is_lora_activated}" + ) + + # 3. And we need to merge the current LoRA layers into the corresponding LoRA attention processor + non_lora_processor_cls_name = self.processor.__class__.__name__ + lora_processor_cls = getattr( + import_module(__name__), "LoRA" + non_lora_processor_cls_name + ) + + hidden_size = self.inner_dim + + # now create a LoRA attention processor from the LoRA layers + if lora_processor_cls in [ + LoRAAttnProcessor, + LoRAAttnProcessor2_0, + LoRAXFormersAttnProcessor, + ]: + kwargs = { + "cross_attention_dim": self.cross_attention_dim, + "rank": self.to_q.lora_layer.rank, + "network_alpha": self.to_q.lora_layer.network_alpha, + "q_rank": self.to_q.lora_layer.rank, + "q_hidden_size": self.to_q.lora_layer.out_features, + "k_rank": self.to_k.lora_layer.rank, + "k_hidden_size": self.to_k.lora_layer.out_features, + "v_rank": self.to_v.lora_layer.rank, + "v_hidden_size": self.to_v.lora_layer.out_features, + "out_rank": self.to_out[0].lora_layer.rank, + "out_hidden_size": self.to_out[0].lora_layer.out_features, + } + + if hasattr(self.processor, "attention_op"): + kwargs["attention_op"] = self.processor.attention_op + + lora_processor = lora_processor_cls(hidden_size, **kwargs) + lora_processor.to_q_lora.load_state_dict(self.to_q.lora_layer.state_dict()) + lora_processor.to_k_lora.load_state_dict(self.to_k.lora_layer.state_dict()) + lora_processor.to_v_lora.load_state_dict(self.to_v.lora_layer.state_dict()) + lora_processor.to_out_lora.load_state_dict( + self.to_out[0].lora_layer.state_dict() + ) + elif lora_processor_cls == LoRAAttnAddedKVProcessor: + lora_processor = lora_processor_cls( + hidden_size, + cross_attention_dim=self.add_k_proj.weight.shape[0], + rank=self.to_q.lora_layer.rank, + network_alpha=self.to_q.lora_layer.network_alpha, + ) + lora_processor.to_q_lora.load_state_dict(self.to_q.lora_layer.state_dict()) + lora_processor.to_k_lora.load_state_dict(self.to_k.lora_layer.state_dict()) + lora_processor.to_v_lora.load_state_dict(self.to_v.lora_layer.state_dict()) + lora_processor.to_out_lora.load_state_dict( + self.to_out[0].lora_layer.state_dict() + ) + + # only save if used + if self.add_k_proj.lora_layer is not None: + lora_processor.add_k_proj_lora.load_state_dict( + self.add_k_proj.lora_layer.state_dict() + ) + lora_processor.add_v_proj_lora.load_state_dict( + self.add_v_proj.lora_layer.state_dict() + ) + else: + 
lora_processor.add_k_proj_lora = None + lora_processor.add_v_proj_lora = None + else: + raise ValueError(f"{lora_processor_cls} does not exist.") + + return lora_processor + + def forward( + self, + hidden_states: torch.FloatTensor, + freqs_cis: Optional[Tuple[torch.FloatTensor, torch.FloatTensor]] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + skip_layer_mask: Optional[torch.Tensor] = None, + skip_layer_strategy: Optional[SkipLayerStrategy] = None, + **cross_attention_kwargs, + ) -> torch.Tensor: + r""" + The forward method of the `Attention` class. + + Args: + hidden_states (`torch.Tensor`): + The hidden states of the query. + encoder_hidden_states (`torch.Tensor`, *optional*): + The hidden states of the encoder. + attention_mask (`torch.Tensor`, *optional*): + The attention mask to use. If `None`, no mask is applied. + skip_layer_mask (`torch.Tensor`, *optional*): + The skip layer mask to use. If `None`, no mask is applied. + skip_layer_strategy (`SkipLayerStrategy`, *optional*, defaults to `None`): + Controls which layers to skip for spatiotemporal guidance. + **cross_attention_kwargs: + Additional keyword arguments to pass along to the cross attention. + + Returns: + `torch.Tensor`: The output of the attention layer. + """ + # The `Attention` class can call different attention processors / attention functions + # here we simply pass along all tensors to the selected processor class + # For standard processors that are defined here, `**cross_attention_kwargs` is empty + + attn_parameters = set( + inspect.signature(self.processor.__call__).parameters.keys() + ) + unused_kwargs = [ + k for k, _ in cross_attention_kwargs.items() if k not in attn_parameters + ] + if len(unused_kwargs) > 0: + logger.warning( + f"cross_attention_kwargs {unused_kwargs} are not expected by" + f" {self.processor.__class__.__name__} and will be ignored." + ) + cross_attention_kwargs = { + k: w for k, w in cross_attention_kwargs.items() if k in attn_parameters + } + + return self.processor( + self, + hidden_states, + freqs_cis=freqs_cis, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + skip_layer_mask=skip_layer_mask, + skip_layer_strategy=skip_layer_strategy, + **cross_attention_kwargs, + ) + + def batch_to_head_dim(self, tensor: torch.Tensor) -> torch.Tensor: + r""" + Reshape the tensor from `[batch_size, seq_len, dim]` to `[batch_size // heads, seq_len, dim * heads]`. `heads` + is the number of heads initialized while constructing the `Attention` class. + + Args: + tensor (`torch.Tensor`): The tensor to reshape. + + Returns: + `torch.Tensor`: The reshaped tensor. + """ + head_size = self.heads + batch_size, seq_len, dim = tensor.shape + tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim) + tensor = tensor.permute(0, 2, 1, 3).reshape( + batch_size // head_size, seq_len, dim * head_size + ) + return tensor + + def head_to_batch_dim(self, tensor: torch.Tensor, out_dim: int = 3) -> torch.Tensor: + r""" + Reshape the tensor from `[batch_size, seq_len, dim]` to `[batch_size, seq_len, heads, dim // heads]` `heads` is + the number of heads initialized while constructing the `Attention` class. + + Args: + tensor (`torch.Tensor`): The tensor to reshape. + out_dim (`int`, *optional*, defaults to `3`): The output dimension of the tensor. If `3`, the tensor is + reshaped to `[batch_size * heads, seq_len, dim // heads]`. + + Returns: + `torch.Tensor`: The reshaped tensor. 
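+
+        Shape example (illustrative): with `heads = 8`, a `[2, 77, 512]` tensor is
+        reshaped to `[2, 77, 8, 64]`, permuted to `[2, 8, 77, 64]` and, for
+        `out_dim == 3`, flattened to `[16, 77, 64]`.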
+ """ + + head_size = self.heads + if tensor.ndim == 3: + batch_size, seq_len, dim = tensor.shape + extra_dim = 1 + else: + batch_size, extra_dim, seq_len, dim = tensor.shape + tensor = tensor.reshape( + batch_size, seq_len * extra_dim, head_size, dim // head_size + ) + tensor = tensor.permute(0, 2, 1, 3) + + if out_dim == 3: + tensor = tensor.reshape( + batch_size * head_size, seq_len * extra_dim, dim // head_size + ) + + return tensor + + def get_attention_scores( + self, + query: torch.Tensor, + key: torch.Tensor, + attention_mask: torch.Tensor = None, + ) -> torch.Tensor: + r""" + Compute the attention scores. + + Args: + query (`torch.Tensor`): The query tensor. + key (`torch.Tensor`): The key tensor. + attention_mask (`torch.Tensor`, *optional*): The attention mask to use. If `None`, no mask is applied. + + Returns: + `torch.Tensor`: The attention probabilities/scores. + """ + dtype = query.dtype + if self.upcast_attention: + query = query.float() + key = key.float() + + if attention_mask is None: + baddbmm_input = torch.empty( + query.shape[0], + query.shape[1], + key.shape[1], + dtype=query.dtype, + device=query.device, + ) + beta = 0 + else: + baddbmm_input = attention_mask + beta = 1 + + attention_scores = torch.baddbmm( + baddbmm_input, + query, + key.transpose(-1, -2), + beta=beta, + alpha=self.scale, + ) + del baddbmm_input + + if self.upcast_softmax: + attention_scores = attention_scores.float() + + attention_probs = attention_scores.softmax(dim=-1) + del attention_scores + + attention_probs = attention_probs.to(dtype) + + return attention_probs + + def prepare_attention_mask( + self, + attention_mask: torch.Tensor, + target_length: int, + batch_size: int, + out_dim: int = 3, + ) -> torch.Tensor: + r""" + Prepare the attention mask for the attention computation. + + Args: + attention_mask (`torch.Tensor`): + The attention mask to prepare. + target_length (`int`): + The target length of the attention mask. This is the length of the attention mask after padding. + batch_size (`int`): + The batch size, which is used to repeat the attention mask. + out_dim (`int`, *optional*, defaults to `3`): + The output dimension of the attention mask. Can be either `3` or `4`. + + Returns: + `torch.Tensor`: The prepared attention mask. + """ + head_size = self.heads + if attention_mask is None: + return attention_mask + + current_length: int = attention_mask.shape[-1] + if current_length != target_length: + if attention_mask.device.type == "mps": + # HACK: MPS: Does not support padding by greater than dimension of input tensor. + # Instead, we can manually construct the padding tensor. 
+ padding_shape = ( + attention_mask.shape[0], + attention_mask.shape[1], + target_length, + ) + padding = torch.zeros( + padding_shape, + dtype=attention_mask.dtype, + device=attention_mask.device, + ) + attention_mask = torch.cat([attention_mask, padding], dim=2) + else: + # TODO: for pipelines such as stable-diffusion, padding cross-attn mask: + # we want to instead pad by (0, remaining_length), where remaining_length is: + # remaining_length: int = target_length - current_length + # TODO: re-enable tests/models/test_models_unet_2d_condition.py#test_model_xattn_padding + attention_mask = F.pad(attention_mask, (0, target_length), value=0.0) + + if out_dim == 3: + if attention_mask.shape[0] < batch_size * head_size: + attention_mask = attention_mask.repeat_interleave(head_size, dim=0) + elif out_dim == 4: + attention_mask = attention_mask.unsqueeze(1) + attention_mask = attention_mask.repeat_interleave(head_size, dim=1) + + return attention_mask + + def norm_encoder_hidden_states( + self, encoder_hidden_states: torch.Tensor + ) -> torch.Tensor: + r""" + Normalize the encoder hidden states. Requires `self.norm_cross` to be specified when constructing the + `Attention` class. + + Args: + encoder_hidden_states (`torch.Tensor`): Hidden states of the encoder. + + Returns: + `torch.Tensor`: The normalized encoder hidden states. + """ + assert ( + self.norm_cross is not None + ), "self.norm_cross must be defined to call self.norm_encoder_hidden_states" + + if isinstance(self.norm_cross, nn.LayerNorm): + encoder_hidden_states = self.norm_cross(encoder_hidden_states) + elif isinstance(self.norm_cross, nn.GroupNorm): + # Group norm norms along the channels dimension and expects + # input to be in the shape of (N, C, *). In this case, we want + # to norm along the hidden dimension, so we need to move + # (batch_size, sequence_length, hidden_size) -> + # (batch_size, hidden_size, sequence_length) + encoder_hidden_states = encoder_hidden_states.transpose(1, 2) + encoder_hidden_states = self.norm_cross(encoder_hidden_states) + encoder_hidden_states = encoder_hidden_states.transpose(1, 2) + else: + assert False + + return encoder_hidden_states + + @staticmethod + def apply_rotary_emb( + input_tensor: torch.Tensor, + freqs_cis: Tuple[torch.FloatTensor, torch.FloatTensor], + ) -> Tuple[torch.Tensor, torch.Tensor]: + cos_freqs = freqs_cis[0] + sin_freqs = freqs_cis[1] + + t_dup = rearrange(input_tensor, "... (d r) -> ... d r", r=2) + t1, t2 = t_dup.unbind(dim=-1) + t_dup = torch.stack((-t2, t1), dim=-1) + input_tensor_rot = rearrange(t_dup, "... d r -> ... (d r)") + + out = input_tensor * cos_freqs + input_tensor_rot * sin_freqs + + return out + + +class AttnProcessor2_0: + r""" + Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). + """ + + def __init__(self): + pass + + def __call__( + self, + attn: Attention, + hidden_states: torch.FloatTensor, + freqs_cis: Tuple[torch.FloatTensor, torch.FloatTensor], + encoder_hidden_states: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + temb: Optional[torch.FloatTensor] = None, + skip_layer_mask: Optional[torch.FloatTensor] = None, + skip_layer_strategy: Optional[SkipLayerStrategy] = None, + *args, + **kwargs, + ) -> torch.FloatTensor: + if len(args) > 0 or kwargs.get("scale", None) is not None: + deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. 
`scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." + deprecate("scale", "1.0.0", deprecation_message) + + residual = hidden_states + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view( + batch_size, channel, height * width + ).transpose(1, 2) + + batch_size, sequence_length, _ = ( + hidden_states.shape + if encoder_hidden_states is None + else encoder_hidden_states.shape + ) + + if skip_layer_mask is not None: + skip_layer_mask = skip_layer_mask.reshape(batch_size, 1, 1) + + if (attention_mask is not None) and (not attn.use_tpu_flash_attention): + attention_mask = attn.prepare_attention_mask( + attention_mask, sequence_length, batch_size + ) + # scaled_dot_product_attention expects attention_mask shape to be + # (batch, heads, source_length, target_length) + attention_mask = attention_mask.view( + batch_size, attn.heads, -1, attention_mask.shape[-1] + ) + + if attn.group_norm is not None: + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose( + 1, 2 + ) + + query = attn.to_q(hidden_states) + query = attn.q_norm(query) + + if encoder_hidden_states is not None: + if attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states( + encoder_hidden_states + ) + key = attn.to_k(encoder_hidden_states) + key = attn.k_norm(key) + else: # if no context provided do self-attention + encoder_hidden_states = hidden_states + key = attn.to_k(hidden_states) + key = attn.k_norm(key) + if attn.use_rope: + key = attn.apply_rotary_emb(key, freqs_cis) + query = attn.apply_rotary_emb(query, freqs_cis) + + value = attn.to_v(encoder_hidden_states) + value_for_stg = value + + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + # the output of sdp = (batch, num_heads, seq_len, head_dim) + + if attn.use_tpu_flash_attention: # use tpu attention offload 'flash attention' + q_segment_indexes = None + if ( + attention_mask is not None + ): # if mask is required need to tune both segmenIds fields + # attention_mask = torch.squeeze(attention_mask).to(torch.float32) + attention_mask = attention_mask.to(torch.float32) + q_segment_indexes = torch.ones( + batch_size, query.shape[2], device=query.device, dtype=torch.float32 + ) + assert ( + attention_mask.shape[1] == key.shape[2] + ), f"ERROR: KEY SHAPE must be same as attention mask [{key.shape[2]}, {attention_mask.shape[1]}]" + + assert ( + query.shape[2] % 128 == 0 + ), f"ERROR: QUERY SHAPE must be divisible by 128 (TPU limitation) [{query.shape[2]}]" + assert ( + key.shape[2] % 128 == 0 + ), f"ERROR: KEY SHAPE must be divisible by 128 (TPU limitation) [{key.shape[2]}]" + + # run the TPU kernel implemented in jax with pallas + hidden_states_a = flash_attention( + q=query, + k=key, + v=value, + q_segment_ids=q_segment_indexes, + kv_segment_ids=attention_mask, + sm_scale=attn.scale, + ) + else: + hidden_states_a = F.scaled_dot_product_attention( + query, + key, + value, + attn_mask=attention_mask, + dropout_p=0.0, + is_causal=False, + ) + + hidden_states_a = hidden_states_a.transpose(1, 2).reshape( + batch_size, -1, attn.heads * head_dim + ) + 
hidden_states_a = hidden_states_a.to(query.dtype) + + if ( + skip_layer_mask is not None + and skip_layer_strategy == SkipLayerStrategy.AttentionSkip + ): + hidden_states = hidden_states_a * skip_layer_mask + hidden_states * ( + 1.0 - skip_layer_mask + ) + elif ( + skip_layer_mask is not None + and skip_layer_strategy == SkipLayerStrategy.AttentionValues + ): + hidden_states = hidden_states_a * skip_layer_mask + value_for_stg * ( + 1.0 - skip_layer_mask + ) + else: + hidden_states = hidden_states_a + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape( + batch_size, channel, height, width + ) + if ( + skip_layer_mask is not None + and skip_layer_strategy == SkipLayerStrategy.Residual + ): + skip_layer_mask = skip_layer_mask.reshape(batch_size, 1, 1, 1) + + if attn.residual_connection: + if ( + skip_layer_mask is not None + and skip_layer_strategy == SkipLayerStrategy.Residual + ): + hidden_states = hidden_states + residual * skip_layer_mask + else: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states + + +class AttnProcessor: + r""" + Default processor for performing attention-related computations. + """ + + def __call__( + self, + attn: Attention, + hidden_states: torch.FloatTensor, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + temb: Optional[torch.FloatTensor] = None, + *args, + **kwargs, + ) -> torch.Tensor: + if len(args) > 0 or kwargs.get("scale", None) is not None: + deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." 
+ deprecate("scale", "1.0.0", deprecation_message) + + residual = hidden_states + + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view( + batch_size, channel, height * width + ).transpose(1, 2) + + batch_size, sequence_length, _ = ( + hidden_states.shape + if encoder_hidden_states is None + else encoder_hidden_states.shape + ) + attention_mask = attn.prepare_attention_mask( + attention_mask, sequence_length, batch_size + ) + + if attn.group_norm is not None: + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose( + 1, 2 + ) + + query = attn.to_q(hidden_states) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states( + encoder_hidden_states + ) + + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + query = attn.head_to_batch_dim(query) + key = attn.head_to_batch_dim(key) + value = attn.head_to_batch_dim(value) + + query = attn.q_norm(query) + key = attn.k_norm(key) + + attention_probs = attn.get_attention_scores(query, key, attention_mask) + hidden_states = torch.bmm(attention_probs, value) + hidden_states = attn.batch_to_head_dim(hidden_states) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape( + batch_size, channel, height, width + ) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states + + +class FeedForward(nn.Module): + r""" + A feed-forward layer. + + Parameters: + dim (`int`): The number of channels in the input. + dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`. + mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. + final_dropout (`bool` *optional*, defaults to False): Apply a final dropout. + bias (`bool`, defaults to True): Whether to use a bias in the linear layer. 
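+
+    A rough usage sketch: with the defaults, `FeedForward(dim=64)` builds a
+    GEGLU MLP with hidden width `64 * mult = 256`, so a `(batch, seq_len, 64)`
+    input returns a tensor of the same shape.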
+ """ + + def __init__( + self, + dim: int, + dim_out: Optional[int] = None, + mult: int = 4, + dropout: float = 0.0, + activation_fn: str = "geglu", + final_dropout: bool = False, + inner_dim=None, + bias: bool = True, + ): + super().__init__() + if inner_dim is None: + inner_dim = int(dim * mult) + dim_out = dim_out if dim_out is not None else dim + linear_cls = nn.Linear + + if activation_fn == "gelu": + act_fn = GELU(dim, inner_dim, bias=bias) + elif activation_fn == "gelu-approximate": + act_fn = GELU(dim, inner_dim, approximate="tanh", bias=bias) + elif activation_fn == "geglu": + act_fn = GEGLU(dim, inner_dim, bias=bias) + elif activation_fn == "geglu-approximate": + act_fn = ApproximateGELU(dim, inner_dim, bias=bias) + else: + raise ValueError(f"Unsupported activation function: {activation_fn}") + + self.net = nn.ModuleList([]) + # project in + self.net.append(act_fn) + # project dropout + self.net.append(nn.Dropout(dropout)) + # project out + self.net.append(linear_cls(inner_dim, dim_out, bias=bias)) + # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout + if final_dropout: + self.net.append(nn.Dropout(dropout)) + + def forward(self, hidden_states: torch.Tensor, scale: float = 1.0) -> torch.Tensor: + compatible_cls = (GEGLU, LoRACompatibleLinear) + for module in self.net: + if isinstance(module, compatible_cls): + hidden_states = module(hidden_states, scale) + else: + hidden_states = module(hidden_states) + return hidden_states diff --git a/ltx_video/models/transformers/embeddings.py b/ltx_video/models/transformers/embeddings.py new file mode 100644 index 0000000000000000000000000000000000000000..a30d6be16b4f3fe709cf24465e06eb798889ba66 --- /dev/null +++ b/ltx_video/models/transformers/embeddings.py @@ -0,0 +1,129 @@ +# Adapted from: https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/embeddings.py +import math + +import numpy as np +import torch +from einops import rearrange +from torch import nn + + +def get_timestep_embedding( + timesteps: torch.Tensor, + embedding_dim: int, + flip_sin_to_cos: bool = False, + downscale_freq_shift: float = 1, + scale: float = 1, + max_period: int = 10000, +): + """ + This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings. + + :param timesteps: a 1-D Tensor of N indices, one per batch element. + These may be fractional. + :param embedding_dim: the dimension of the output. :param max_period: controls the minimum frequency of the + embeddings. :return: an [N x dim] Tensor of positional embeddings. 
+ """ + assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array" + + half_dim = embedding_dim // 2 + exponent = -math.log(max_period) * torch.arange( + start=0, end=half_dim, dtype=torch.float32, device=timesteps.device + ) + exponent = exponent / (half_dim - downscale_freq_shift) + + emb = torch.exp(exponent) + emb = timesteps[:, None].float() * emb[None, :] + + # scale embeddings + emb = scale * emb + + # concat sine and cosine embeddings + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1) + + # flip sine and cosine embeddings + if flip_sin_to_cos: + emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1) + + # zero pad + if embedding_dim % 2 == 1: + emb = torch.nn.functional.pad(emb, (0, 1, 0, 0)) + return emb + + +def get_3d_sincos_pos_embed(embed_dim, grid, w, h, f): + """ + grid_size: int of the grid height and width return: pos_embed: [grid_size*grid_size, embed_dim] or + [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) + """ + grid = rearrange(grid, "c (f h w) -> c f h w", h=h, w=w) + grid = rearrange(grid, "c f h w -> c h w f", h=h, w=w) + grid = grid.reshape([3, 1, w, h, f]) + pos_embed = get_3d_sincos_pos_embed_from_grid(embed_dim, grid) + pos_embed = pos_embed.transpose(1, 0, 2, 3) + return rearrange(pos_embed, "h w f c -> (f h w) c") + + +def get_3d_sincos_pos_embed_from_grid(embed_dim, grid): + if embed_dim % 3 != 0: + raise ValueError("embed_dim must be divisible by 3") + + # use half of dimensions to encode grid_h + emb_f = get_1d_sincos_pos_embed_from_grid(embed_dim // 3, grid[0]) # (H*W*T, D/3) + emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 3, grid[1]) # (H*W*T, D/3) + emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 3, grid[2]) # (H*W*T, D/3) + + emb = np.concatenate([emb_h, emb_w, emb_f], axis=-1) # (H*W*T, D) + return emb + + +def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): + """ + embed_dim: output dimension for each position pos: a list of positions to be encoded: size (M,) out: (M, D) + """ + if embed_dim % 2 != 0: + raise ValueError("embed_dim must be divisible by 2") + + omega = np.arange(embed_dim // 2, dtype=np.float64) + omega /= embed_dim / 2.0 + omega = 1.0 / 10000**omega # (D/2,) + + pos_shape = pos.shape + + pos = pos.reshape(-1) + out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product + out = out.reshape([*pos_shape, -1])[0] + + emb_sin = np.sin(out) # (M, D/2) + emb_cos = np.cos(out) # (M, D/2) + + emb = np.concatenate([emb_sin, emb_cos], axis=-1) # (M, D) + return emb + + +class SinusoidalPositionalEmbedding(nn.Module): + """Apply positional information to a sequence of embeddings. + + Takes in a sequence of embeddings with shape (batch_size, seq_length, embed_dim) and adds positional embeddings to + them + + Args: + embed_dim: (int): Dimension of the positional embedding. 
+ max_seq_length: Maximum sequence length to apply positional embeddings + + """ + + def __init__(self, embed_dim: int, max_seq_length: int = 32): + super().__init__() + position = torch.arange(max_seq_length).unsqueeze(1) + div_term = torch.exp( + torch.arange(0, embed_dim, 2) * (-math.log(10000.0) / embed_dim) + ) + pe = torch.zeros(1, max_seq_length, embed_dim) + pe[0, :, 0::2] = torch.sin(position * div_term) + pe[0, :, 1::2] = torch.cos(position * div_term) + self.register_buffer("pe", pe) + + def forward(self, x): + _, seq_length, _ = x.shape + x = x + self.pe[:, :seq_length] + return x diff --git a/ltx_video/models/transformers/symmetric_patchifier.py b/ltx_video/models/transformers/symmetric_patchifier.py new file mode 100644 index 0000000000000000000000000000000000000000..2eca32033eef03c0dbffd7a25cca993bbda57ded --- /dev/null +++ b/ltx_video/models/transformers/symmetric_patchifier.py @@ -0,0 +1,84 @@ +from abc import ABC, abstractmethod +from typing import Tuple + +import torch +from diffusers.configuration_utils import ConfigMixin +from einops import rearrange +from torch import Tensor + + +class Patchifier(ConfigMixin, ABC): + def __init__(self, patch_size: int): + super().__init__() + self._patch_size = (1, patch_size, patch_size) + + @abstractmethod + def patchify(self, latents: Tensor) -> Tuple[Tensor, Tensor]: + raise NotImplementedError("Patchify method not implemented") + + @abstractmethod + def unpatchify( + self, + latents: Tensor, + output_height: int, + output_width: int, + out_channels: int, + ) -> Tuple[Tensor, Tensor]: + pass + + @property + def patch_size(self): + return self._patch_size + + def get_latent_coords( + self, latent_num_frames, latent_height, latent_width, batch_size, device + ): + """ + Return a tensor of shape [batch_size, 3, num_patches] containing the + top-left corner latent coordinates of each latent patch. + The tensor is repeated for each batch element. 
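+
+        For example, with `patch_size=1` and a latent grid of 2 frames by 4x4
+        spatial positions, the returned tensor has shape [batch_size, 3, 32],
+        where the 3 channels hold the (frame, row, column) indices of each patch.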
+ """ + latent_sample_coords = torch.meshgrid( + torch.arange(0, latent_num_frames, self._patch_size[0], device=device), + torch.arange(0, latent_height, self._patch_size[1], device=device), + torch.arange(0, latent_width, self._patch_size[2], device=device), + ) + latent_sample_coords = torch.stack(latent_sample_coords, dim=0) + latent_coords = latent_sample_coords.unsqueeze(0).repeat(batch_size, 1, 1, 1, 1) + latent_coords = rearrange( + latent_coords, "b c f h w -> b c (f h w)", b=batch_size + ) + return latent_coords + + +class SymmetricPatchifier(Patchifier): + def patchify(self, latents: Tensor) -> Tuple[Tensor, Tensor]: + b, _, f, h, w = latents.shape + latent_coords = self.get_latent_coords(f, h, w, b, latents.device) + latents = rearrange( + latents, + "b c (f p1) (h p2) (w p3) -> b (f h w) (c p1 p2 p3)", + p1=self._patch_size[0], + p2=self._patch_size[1], + p3=self._patch_size[2], + ) + return latents, latent_coords + + def unpatchify( + self, + latents: Tensor, + output_height: int, + output_width: int, + out_channels: int, + ) -> Tuple[Tensor, Tensor]: + output_height = output_height // self._patch_size[1] + output_width = output_width // self._patch_size[2] + latents = rearrange( + latents, + "b (f h w) (c p q) -> b c f (h p) (w q)", + h=output_height, + w=output_width, + p=self._patch_size[1], + q=self._patch_size[2], + ) + return latents diff --git a/ltx_video/models/transformers/transformer3d.py b/ltx_video/models/transformers/transformer3d.py new file mode 100644 index 0000000000000000000000000000000000000000..3dc08d8e3f1669287bca04135fd63498385d014d --- /dev/null +++ b/ltx_video/models/transformers/transformer3d.py @@ -0,0 +1,507 @@ +# Adapted from: https://github.com/huggingface/diffusers/blob/v0.26.3/src/diffusers/models/transformers/transformer_2d.py +import math +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Union +import os +import json +import glob +from pathlib import Path + +import torch +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.models.embeddings import PixArtAlphaTextProjection +from diffusers.models.modeling_utils import ModelMixin +from diffusers.models.normalization import AdaLayerNormSingle +from diffusers.utils import BaseOutput, is_torch_version +from diffusers.utils import logging +from torch import nn +from safetensors import safe_open + + +from ltx_video.models.transformers.attention import BasicTransformerBlock +from ltx_video.utils.skip_layer_strategy import SkipLayerStrategy + +from ltx_video.utils.diffusers_config_mapping import ( + diffusers_and_ours_config_mapping, + make_hashable_key, + TRANSFORMER_KEYS_RENAME_DICT, +) + + +logger = logging.get_logger(__name__) + + +@dataclass +class Transformer3DModelOutput(BaseOutput): + """ + The output of [`Transformer2DModel`]. + + Args: + sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` or `(batch size, num_vector_embeds - 1, num_latent_pixels)` if [`Transformer2DModel`] is discrete): + The hidden states output conditioned on the `encoder_hidden_states` input. If discrete, returns probability + distributions for the unnoised latent pixels. 
+ """ + + sample: torch.FloatTensor + + +class Transformer3DModel(ModelMixin, ConfigMixin): + _supports_gradient_checkpointing = True + + @register_to_config + def __init__( + self, + num_attention_heads: int = 16, + attention_head_dim: int = 88, + in_channels: Optional[int] = None, + out_channels: Optional[int] = None, + num_layers: int = 1, + dropout: float = 0.0, + norm_num_groups: int = 32, + cross_attention_dim: Optional[int] = None, + attention_bias: bool = False, + num_vector_embeds: Optional[int] = None, + activation_fn: str = "geglu", + num_embeds_ada_norm: Optional[int] = None, + use_linear_projection: bool = False, + only_cross_attention: bool = False, + double_self_attention: bool = False, + upcast_attention: bool = False, + adaptive_norm: str = "single_scale_shift", # 'single_scale_shift' or 'single_scale' + standardization_norm: str = "layer_norm", # 'layer_norm' or 'rms_norm' + norm_elementwise_affine: bool = True, + norm_eps: float = 1e-5, + attention_type: str = "default", + caption_channels: int = None, + use_tpu_flash_attention: bool = False, # if True uses the TPU attention offload ('flash attention') + qk_norm: Optional[str] = None, + positional_embedding_type: str = "rope", + positional_embedding_theta: Optional[float] = None, + positional_embedding_max_pos: Optional[List[int]] = None, + timestep_scale_multiplier: Optional[float] = None, + causal_temporal_positioning: bool = False, # For backward compatibility, will be deprecated + ): + super().__init__() + self.use_tpu_flash_attention = ( + use_tpu_flash_attention # FIXME: push config down to the attention modules + ) + self.use_linear_projection = use_linear_projection + self.num_attention_heads = num_attention_heads + self.attention_head_dim = attention_head_dim + inner_dim = num_attention_heads * attention_head_dim + self.inner_dim = inner_dim + self.patchify_proj = nn.Linear(in_channels, inner_dim, bias=True) + self.positional_embedding_type = positional_embedding_type + self.positional_embedding_theta = positional_embedding_theta + self.positional_embedding_max_pos = positional_embedding_max_pos + self.use_rope = self.positional_embedding_type == "rope" + self.timestep_scale_multiplier = timestep_scale_multiplier + + if self.positional_embedding_type == "absolute": + raise ValueError("Absolute positional embedding is no longer supported") + elif self.positional_embedding_type == "rope": + if positional_embedding_theta is None: + raise ValueError( + "If `positional_embedding_type` type is rope, `positional_embedding_theta` must also be defined" + ) + if positional_embedding_max_pos is None: + raise ValueError( + "If `positional_embedding_type` type is rope, `positional_embedding_max_pos` must also be defined" + ) + + # 3. Define transformers blocks + self.transformer_blocks = nn.ModuleList( + [ + BasicTransformerBlock( + inner_dim, + num_attention_heads, + attention_head_dim, + dropout=dropout, + cross_attention_dim=cross_attention_dim, + activation_fn=activation_fn, + num_embeds_ada_norm=num_embeds_ada_norm, + attention_bias=attention_bias, + only_cross_attention=only_cross_attention, + double_self_attention=double_self_attention, + upcast_attention=upcast_attention, + adaptive_norm=adaptive_norm, + standardization_norm=standardization_norm, + norm_elementwise_affine=norm_elementwise_affine, + norm_eps=norm_eps, + attention_type=attention_type, + use_tpu_flash_attention=use_tpu_flash_attention, + qk_norm=qk_norm, + use_rope=self.use_rope, + ) + for d in range(num_layers) + ] + ) + + # 4. 
Define output layers + self.out_channels = in_channels if out_channels is None else out_channels + self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6) + self.scale_shift_table = nn.Parameter( + torch.randn(2, inner_dim) / inner_dim**0.5 + ) + self.proj_out = nn.Linear(inner_dim, self.out_channels) + + self.adaln_single = AdaLayerNormSingle( + inner_dim, use_additional_conditions=False + ) + if adaptive_norm == "single_scale": + self.adaln_single.linear = nn.Linear(inner_dim, 4 * inner_dim, bias=True) + + self.caption_projection = None + if caption_channels is not None: + self.caption_projection = PixArtAlphaTextProjection( + in_features=caption_channels, hidden_size=inner_dim + ) + + self.gradient_checkpointing = False + + def set_use_tpu_flash_attention(self): + r""" + Function sets the flag in this object and propagates down the children. The flag will enforce the usage of TPU + attention kernel. + """ + logger.info("ENABLE TPU FLASH ATTENTION -> TRUE") + self.use_tpu_flash_attention = True + # push config down to the attention modules + for block in self.transformer_blocks: + block.set_use_tpu_flash_attention() + + def create_skip_layer_mask( + self, + batch_size: int, + num_conds: int, + ptb_index: int, + skip_block_list: Optional[List[int]] = None, + ): + if skip_block_list is None or len(skip_block_list) == 0: + return None + num_layers = len(self.transformer_blocks) + mask = torch.ones( + (num_layers, batch_size * num_conds), device=self.device, dtype=self.dtype + ) + for block_idx in skip_block_list: + mask[block_idx, ptb_index::num_conds] = 0 + return mask + + def _set_gradient_checkpointing(self, module, value=False): + if hasattr(module, "gradient_checkpointing"): + module.gradient_checkpointing = value + + def get_fractional_positions(self, indices_grid): + fractional_positions = torch.stack( + [ + indices_grid[:, i] / self.positional_embedding_max_pos[i] + for i in range(3) + ], + dim=-1, + ) + return fractional_positions + + def precompute_freqs_cis(self, indices_grid, spacing="exp"): + dtype = torch.float32 # We need full precision in the freqs_cis computation. 
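+        # The rotary embedding is factorized over the three latent axes
+        # (frame, height, width): positions are normalized by
+        # `positional_embedding_max_pos`, `dim // 6` frequencies are built per
+        # axis (each contributing a cos/sin channel pair, 3 axes * 2 = 6), and
+        # any remaining `dim % 6` channels are padded with cos=1 / sin=0
+        # (an identity rotation).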
+ dim = self.inner_dim + theta = self.positional_embedding_theta + + fractional_positions = self.get_fractional_positions(indices_grid) + + start = 1 + end = theta + device = fractional_positions.device + if spacing == "exp": + indices = theta ** ( + torch.linspace( + math.log(start, theta), + math.log(end, theta), + dim // 6, + device=device, + dtype=dtype, + ) + ) + indices = indices.to(dtype=dtype) + elif spacing == "exp_2": + indices = 1.0 / theta ** (torch.arange(0, dim, 6, device=device) / dim) + indices = indices.to(dtype=dtype) + elif spacing == "linear": + indices = torch.linspace(start, end, dim // 6, device=device, dtype=dtype) + elif spacing == "sqrt": + indices = torch.linspace( + start**2, end**2, dim // 6, device=device, dtype=dtype + ).sqrt() + + indices = indices * math.pi / 2 + + if spacing == "exp_2": + freqs = ( + (indices * fractional_positions.unsqueeze(-1)) + .transpose(-1, -2) + .flatten(2) + ) + else: + freqs = ( + (indices * (fractional_positions.unsqueeze(-1) * 2 - 1)) + .transpose(-1, -2) + .flatten(2) + ) + + cos_freq = freqs.cos().repeat_interleave(2, dim=-1) + sin_freq = freqs.sin().repeat_interleave(2, dim=-1) + if dim % 6 != 0: + cos_padding = torch.ones_like(cos_freq[:, :, : dim % 6]) + sin_padding = torch.zeros_like(cos_freq[:, :, : dim % 6]) + cos_freq = torch.cat([cos_padding, cos_freq], dim=-1) + sin_freq = torch.cat([sin_padding, sin_freq], dim=-1) + return cos_freq.to(self.dtype), sin_freq.to(self.dtype) + + def load_state_dict( + self, + state_dict: Dict, + *args, + **kwargs, + ): + if any([key.startswith("model.diffusion_model.") for key in state_dict.keys()]): + state_dict = { + key.replace("model.diffusion_model.", ""): value + for key, value in state_dict.items() + if key.startswith("model.diffusion_model.") + } + super().load_state_dict(state_dict, *args, **kwargs) + + @classmethod + def from_pretrained( + cls, + pretrained_model_path: Optional[Union[str, os.PathLike]], + *args, + **kwargs, + ): + pretrained_model_path = Path(pretrained_model_path) + if pretrained_model_path.is_dir(): + config_path = pretrained_model_path / "transformer" / "config.json" + with open(config_path, "r") as f: + config = make_hashable_key(json.load(f)) + + assert config in diffusers_and_ours_config_mapping, ( + "Provided diffusers checkpoint config for transformer is not suppported. " + "We only support diffusers configs found in Lightricks/LTX-Video." 
+ ) + + config = diffusers_and_ours_config_mapping[config] + state_dict = {} + ckpt_paths = ( + pretrained_model_path + / "transformer" + / "diffusion_pytorch_model*.safetensors" + ) + dict_list = glob.glob(str(ckpt_paths)) + for dict_path in dict_list: + part_dict = {} + with safe_open(dict_path, framework="pt", device="cpu") as f: + for k in f.keys(): + part_dict[k] = f.get_tensor(k) + state_dict.update(part_dict) + + for key in list(state_dict.keys()): + new_key = key + for replace_key, rename_key in TRANSFORMER_KEYS_RENAME_DICT.items(): + new_key = new_key.replace(replace_key, rename_key) + state_dict[new_key] = state_dict.pop(key) + + with torch.device("meta"): + transformer = cls.from_config(config) + transformer.load_state_dict(state_dict, assign=True, strict=True) + elif pretrained_model_path.is_file() and str(pretrained_model_path).endswith( + ".safetensors" + ): + comfy_single_file_state_dict = {} + with safe_open(pretrained_model_path, framework="pt", device="cpu") as f: + metadata = f.metadata() + for k in f.keys(): + comfy_single_file_state_dict[k] = f.get_tensor(k) + configs = json.loads(metadata["config"]) + transformer_config = configs["transformer"] + with torch.device("meta"): + transformer = Transformer3DModel.from_config(transformer_config) + transformer.load_state_dict(comfy_single_file_state_dict, assign=True) + return transformer + + def forward( + self, + hidden_states: torch.Tensor, + indices_grid: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + timestep: Optional[torch.LongTensor] = None, + class_labels: Optional[torch.LongTensor] = None, + cross_attention_kwargs: Dict[str, Any] = None, + attention_mask: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + skip_layer_mask: Optional[torch.Tensor] = None, + skip_layer_strategy: Optional[SkipLayerStrategy] = None, + return_dict: bool = True, + ): + """ + The [`Transformer2DModel`] forward method. + + Args: + hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, channel, height, width)` if continuous): + Input `hidden_states`. + indices_grid (`torch.LongTensor` of shape `(batch size, 3, num latent pixels)`): + encoder_hidden_states ( `torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*): + Conditional embeddings for cross attention layer. If not given, cross-attention defaults to + self-attention. + timestep ( `torch.LongTensor`, *optional*): + Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`. + class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*): + Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in + `AdaLayerZeroNorm`. + cross_attention_kwargs ( `Dict[str, Any]`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + attention_mask ( `torch.Tensor`, *optional*): + An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask + is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large + negative values to the attention scores corresponding to "discard" tokens. 
+ encoder_attention_mask ( `torch.Tensor`, *optional*): + Cross-attention mask applied to `encoder_hidden_states`. Two formats supported: + + * Mask `(batch, sequence_length)` True = keep, False = discard. + * Bias `(batch, 1, sequence_length)` 0 = keep, -10000 = discard. + + If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format + above. This bias will be added to the cross-attention scores. + skip_layer_mask ( `torch.Tensor`, *optional*): + A mask of shape `(num_layers, batch)` that indicates which layers to skip. `0` at position + `layer, batch_idx` indicates that the layer should be skipped for the corresponding batch index. + skip_layer_strategy ( `SkipLayerStrategy`, *optional*, defaults to `None`): + Controls which layers are skipped when calculating a perturbed latent for spatiotemporal guidance. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] instead of a plain + tuple. + + Returns: + If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a + `tuple` where the first element is the sample tensor. + """ + # for tpu attention offload 2d token masks are used. No need to transform. + if not self.use_tpu_flash_attention: + # ensure attention_mask is a bias, and give it a singleton query_tokens dimension. + # we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward. + # we can tell by counting dims; if ndim == 2: it's a mask rather than a bias. + # expects mask of shape: + # [batch, key_tokens] + # adds singleton query_tokens dimension: + # [batch, 1, key_tokens] + # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes: + # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn) + # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn) + if attention_mask is not None and attention_mask.ndim == 2: + # assume that mask is expressed as: + # (1 = keep, 0 = discard) + # convert mask into a bias that can be added to attention scores: + # (keep = +0, discard = -10000.0) + attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0 + attention_mask = attention_mask.unsqueeze(1) + + # convert encoder_attention_mask to a bias the same way we do for attention_mask + if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2: + encoder_attention_mask = ( + 1 - encoder_attention_mask.to(hidden_states.dtype) + ) * -10000.0 + encoder_attention_mask = encoder_attention_mask.unsqueeze(1) + + # 1. Input + hidden_states = self.patchify_proj(hidden_states) + + if self.timestep_scale_multiplier: + timestep = self.timestep_scale_multiplier * timestep + + freqs_cis = self.precompute_freqs_cis(indices_grid) + + batch_size = hidden_states.shape[0] + timestep, embedded_timestep = self.adaln_single( + timestep.flatten(), + {"resolution": None, "aspect_ratio": None}, + batch_size=batch_size, + hidden_dtype=hidden_states.dtype, + ) + # Second dimension is 1 or number of tokens (if timestep_per_token) + timestep = timestep.view(batch_size, -1, timestep.shape[-1]) + embedded_timestep = embedded_timestep.view( + batch_size, -1, embedded_timestep.shape[-1] + ) + + # 2. 
Blocks + if self.caption_projection is not None: + batch_size = hidden_states.shape[0] + encoder_hidden_states = self.caption_projection(encoder_hidden_states) + encoder_hidden_states = encoder_hidden_states.view( + batch_size, -1, hidden_states.shape[-1] + ) + + for block_idx, block in enumerate(self.transformer_blocks): + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = ( + {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + ) + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + hidden_states, + freqs_cis, + attention_mask, + encoder_hidden_states, + encoder_attention_mask, + timestep, + cross_attention_kwargs, + class_labels, + ( + skip_layer_mask[block_idx] + if skip_layer_mask is not None + else None + ), + skip_layer_strategy, + **ckpt_kwargs, + ) + else: + hidden_states = block( + hidden_states, + freqs_cis=freqs_cis, + attention_mask=attention_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + timestep=timestep, + cross_attention_kwargs=cross_attention_kwargs, + class_labels=class_labels, + skip_layer_mask=( + skip_layer_mask[block_idx] + if skip_layer_mask is not None + else None + ), + skip_layer_strategy=skip_layer_strategy, + ) + + # 3. Output + scale_shift_values = ( + self.scale_shift_table[None, None] + embedded_timestep[:, :, None] + ) + shift, scale = scale_shift_values[:, :, 0], scale_shift_values[:, :, 1] + hidden_states = self.norm_out(hidden_states) + # Modulation + hidden_states = hidden_states * (1 + scale) + shift + hidden_states = self.proj_out(hidden_states) + if not return_dict: + return (hidden_states,) + + return Transformer3DModelOutput(sample=hidden_states) diff --git a/ltx_video/pipelines/__init__.py b/ltx_video/pipelines/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/ltx_video/pipelines/ai_studio_code (11).py b/ltx_video/pipelines/ai_studio_code (11).py new file mode 100644 index 0000000000000000000000000000000000000000..6318f6a37a9fb906d37db4d27b1f540c7f334426 --- /dev/null +++ b/ltx_video/pipelines/ai_studio_code (11).py @@ -0,0 +1,157 @@ +# Adaptado de: https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +# (e com a nossa modificação pela ciência!) 
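+# This file is a partial, annotated copy of the LTX video pipeline: only the
+# sections around `__call__` are reproduced, with extra print statements that
+# log the original prompt next to the prompt returned by the prompt enhancer.
+# The full pipeline used for inference lives in
+# `ltx_video/pipelines/pipeline_ltx_video.py`.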
+ +import copy +import inspect +import math +import re +from contextlib import nullcontext +from dataclasses import dataclass +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import torch +import torch.nn.functional as F +from diffusers.image_processor import VaeImageProcessor +from diffusers.models import AutoencoderKL +from diffusers.pipelines.pipeline_utils import DiffusionPipeline, ImagePipelineOutput +from diffusers.schedulers import DPMSolverMultistepScheduler +from diffusers.utils import deprecate, logging +from diffusers.utils.torch_utils import randn_tensor +from einops import rearrange +from transformers import ( + T5EncoderModel, + T5Tokenizer, + AutoModelForCausalLM, + AutoProcessor, + AutoTokenizer, +) + +from ltx_video.models.autoencoders.causal_video_autoencoder import ( + CausalVideoAutoencoder, +) +from ltx_video.models.autoencoders.vae_encode import ( + get_vae_size_scale_factor, + latent_to_pixel_coords, + vae_decode, + vae_encode, +) +from ltx_video.models.transformers.symmetric_patchifier import Patchifier +from ltx_video.models.transformers.transformer3d import Transformer3DModel +from ltx_video.schedulers.rf import TimestepShifter +from ltx_video.utils.skip_layer_strategy import SkipLayerStrategy +from ltx_video.utils.prompt_enhance_utils import generate_cinematic_prompt +from ltx_video.models.autoencoders.latent_upsampler import LatentUpsampler +from ltx_video.models.autoencoders.vae_encode import ( + un_normalize_latents, + normalize_latents, +) + +# ... (Todo o código inicial do arquivo permanece o mesmo, incluindo ASPECT_RATIO_BINS, retrieve_timesteps, ConditioningItem, etc.) +# ... (Vou pular para a classe LTXVideoPipeline para manter a resposta focada) + +class LTXVideoPipeline(DiffusionPipeline): + # ... (O __init__ e outras funções como encode_prompt, check_inputs, etc., permanecem as mesmas) + # ... 
+
+    @torch.no_grad()
+    def __call__(
+        self,
+        height: int,
+        width: int,
+        num_frames: int,
+        frame_rate: float,
+        prompt: Union[str, List[str]] = None,
+        negative_prompt: str = "",
+        num_inference_steps: int = 20,
+        skip_initial_inference_steps: int = 0,
+        skip_final_inference_steps: int = 0,
+        timesteps: List[int] = None,
+        guidance_scale: Union[float, List[float]] = 4.5,
+        cfg_star_rescale: bool = False,
+        skip_layer_strategy: Optional[SkipLayerStrategy] = None,
+        skip_block_list: Optional[Union[List[List[int]], List[int]]] = None,
+        stg_scale: Union[float, List[float]] = 1.0,
+        rescaling_scale: Union[float, List[float]] = 0.7,
+        guidance_timesteps: Optional[List[int]] = None,
+        num_images_per_prompt: Optional[int] = 1,
+        eta: float = 0.0,
+        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+        latents: Optional[torch.FloatTensor] = None,
+        prompt_embeds: Optional[torch.FloatTensor] = None,
+        prompt_attention_mask: Optional[torch.FloatTensor] = None,
+        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+        negative_prompt_attention_mask: Optional[torch.FloatTensor] = None,
+        output_type: Optional[str] = "pil",
+        return_dict: bool = True,
+        callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
+        conditioning_items: Optional[List[ConditioningItem]] = None,
+        decode_timestep: Union[List[float], float] = 0.0,
+        decode_noise_scale: Optional[List[float]] = None,
+        mixed_precision: bool = False,
+        offload_to_cpu: bool = False,
+        enhance_prompt: bool = False,
+        text_encoder_max_tokens: int = 256,
+        stochastic_sampling: bool = False,
+        media_items: Optional[torch.Tensor] = None,
+        tone_map_compression_ratio: float = 0.0,
+        **kwargs,
+    ) -> Union[ImagePipelineOutput, Tuple]:
+
+        # --- [OUR MODIFICATION] Capture the original prompt for logging ---
+        original_prompt_for_logging = prompt
+
+        # ... (The rest of the initial __call__ code remains the same) ...
+        # ... (check_inputs, default height/width, etc.)
+
+        if enhance_prompt:
+            self.prompt_enhancer_image_caption_model = (
+                self.prompt_enhancer_image_caption_model.to(self._execution_device)
+            )
+            self.prompt_enhancer_llm_model = self.prompt_enhancer_llm_model.to(
+                self._execution_device
+            )
+
+            # The call to the Assistant Director
+            enhanced_prompt = generate_cinematic_prompt(
+                self.prompt_enhancer_image_caption_model,
+                self.prompt_enhancer_image_caption_processor,
+                self.prompt_enhancer_llm_model,
+                self.prompt_enhancer_llm_tokenizer,
+                prompt,
+                conditioning_items,
+                max_new_tokens=text_encoder_max_tokens,
+            )
+
+            # --- [OUR SECRET WIRETAP, FOR SCIENCE!] ---
+            print("\n" + "="*50)
+            print("--- [ASSISTANT DIRECTOR LOG (PROMPT ENHANCER)] ---")
+            print(f"Maestro's original prompt: {original_prompt_for_logging}")
+            print(f"FINAL ENHANCED PROMPT (sent to LTX): {enhanced_prompt}")
+            print("--- [END OF ASSISTANT DIRECTOR LOG] ---")
+            print("="*50 + "\n")
+            # --- [END OF WIRETAP] ---
+
+            # Update the prompt that the rest of the function will use
+            prompt = enhanced_prompt
+
+        # ... (The rest of __call__ continues from here, using either the new or the original `prompt`)
+        # ... (encode_prompt, prepare_latents, denoising loop, etc.)
+
+        # 3. Encode input prompt
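+        # encode_prompt (below) returns a 4-tuple: (prompt_embeds, prompt_attention_mask,
+        # negative_prompt_embeds, negative_prompt_attention_mask); the negative embeddings
+        # are what classifier-free guidance later contrasts against.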
+        if self.text_encoder is not None:
+            self.text_encoder = self.text_encoder.to(self._execution_device)
+
+        (
+            prompt_embeds,
+            prompt_attention_mask,
+            negative_prompt_embeds,
+            negative_prompt_attention_mask,
+        ) = self.encode_prompt(
+            prompt,
+            True,
+            negative_prompt=negative_prompt,
+            # ... (rest of the parameters)
+        )
+
+        # ... (all the rest of the file, with no further modifications) ...
+        # ... (denoising_step, prepare_conditioning, etc.)
\ No newline at end of file
diff --git a/ltx_video/pipelines/crf_compressor.py b/ltx_video/pipelines/crf_compressor.py
new file mode 100644
index 0000000000000000000000000000000000000000..9b9380afb7f92e0a2379c9db4cf5ce9f5a20942c
--- /dev/null
+++ b/ltx_video/pipelines/crf_compressor.py
@@ -0,0 +1,50 @@
+import av
+import torch
+import io
+import numpy as np
+
+
+def _encode_single_frame(output_file, image_array: np.ndarray, crf):
+    container = av.open(output_file, "w", format="mp4")
+    try:
+        stream = container.add_stream(
+            "libx264", rate=1, options={"crf": str(crf), "preset": "veryfast"}
+        )
+        stream.height = image_array.shape[0]
+        stream.width = image_array.shape[1]
+        av_frame = av.VideoFrame.from_ndarray(image_array, format="rgb24").reformat(
+            format="yuv420p"
+        )
+        container.mux(stream.encode(av_frame))
+        container.mux(stream.encode())
+    finally:
+        container.close()
+
+
+def _decode_single_frame(video_file):
+    container = av.open(video_file)
+    try:
+        stream = next(s for s in container.streams if s.type == "video")
+        frame = next(container.decode(stream))
+    finally:
+        container.close()
+    return frame.to_ndarray(format="rgb24")
+
+
+def compress(image: torch.Tensor, crf=29):
+    if crf == 0:
+        return image
+
+    image_array = (
+        (image[: (image.shape[0] // 2) * 2, : (image.shape[1] // 2) * 2] * 255.0)
+        .byte()
+        .cpu()
+        .numpy()
+    )
+    with io.BytesIO() as output_file:
+        _encode_single_frame(output_file, image_array, crf)
+        video_bytes = output_file.getvalue()
+    with io.BytesIO(video_bytes) as video_file:
+        image_array = _decode_single_frame(video_file)
+    tensor = torch.tensor(image_array, dtype=image.dtype, device=image.device) / 255.0
+    return tensor
diff --git a/ltx_video/pipelines/pipeline_ltx_video.py b/ltx_video/pipelines/pipeline_ltx_video.py
new file mode 100644
index 0000000000000000000000000000000000000000..4381d5be25ab2b2929e41fbca9fb8e541002f451
--- /dev/null
+++ b/ltx_video/pipelines/pipeline_ltx_video.py
@@ -0,0 +1,1903 @@
+# Adapted from: https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py
+import copy
+import inspect
+import math
+import re
+from contextlib import nullcontext
+from dataclasses import dataclass
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+
+import torch
+import torch.nn.functional as F
+from diffusers.image_processor import VaeImageProcessor
+from diffusers.models import AutoencoderKL
+from diffusers.pipelines.pipeline_utils import DiffusionPipeline, ImagePipelineOutput
+from diffusers.schedulers import DPMSolverMultistepScheduler
+from diffusers.utils import deprecate, logging
+from diffusers.utils.torch_utils import randn_tensor
+from einops import rearrange
+from transformers import (
+    T5EncoderModel,
+    T5Tokenizer,
+    AutoModelForCausalLM,
+    AutoProcessor,
+    AutoTokenizer,
+)
+
+from ltx_video.models.autoencoders.causal_video_autoencoder import (
+    CausalVideoAutoencoder,
+)
+from ltx_video.models.autoencoders.vae_encode import (
+    get_vae_size_scale_factor,
+    latent_to_pixel_coords,
+    vae_decode,
+    
vae_encode, +) +from ltx_video.models.transformers.symmetric_patchifier import Patchifier +from ltx_video.models.transformers.transformer3d import Transformer3DModel +from ltx_video.schedulers.rf import TimestepShifter +from ltx_video.utils.skip_layer_strategy import SkipLayerStrategy +from ltx_video.utils.prompt_enhance_utils import generate_cinematic_prompt +from ltx_video.models.autoencoders.latent_upsampler import LatentUpsampler +from ltx_video.models.autoencoders.vae_encode import ( + un_normalize_latents, + normalize_latents, +) + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +ASPECT_RATIO_1024_BIN = { + "0.25": [512.0, 2048.0], + "0.28": [512.0, 1856.0], + "0.32": [576.0, 1792.0], + "0.33": [576.0, 1728.0], + "0.35": [576.0, 1664.0], + "0.4": [640.0, 1600.0], + "0.42": [640.0, 1536.0], + "0.48": [704.0, 1472.0], + "0.5": [704.0, 1408.0], + "0.52": [704.0, 1344.0], + "0.57": [768.0, 1344.0], + "0.6": [768.0, 1280.0], + "0.68": [832.0, 1216.0], + "0.72": [832.0, 1152.0], + "0.78": [896.0, 1152.0], + "0.82": [896.0, 1088.0], + "0.88": [960.0, 1088.0], + "0.94": [960.0, 1024.0], + "1.0": [1024.0, 1024.0], + "1.07": [1024.0, 960.0], + "1.13": [1088.0, 960.0], + "1.21": [1088.0, 896.0], + "1.29": [1152.0, 896.0], + "1.38": [1152.0, 832.0], + "1.46": [1216.0, 832.0], + "1.67": [1280.0, 768.0], + "1.75": [1344.0, 768.0], + "2.0": [1408.0, 704.0], + "2.09": [1472.0, 704.0], + "2.4": [1536.0, 640.0], + "2.5": [1600.0, 640.0], + "3.0": [1728.0, 576.0], + "4.0": [2048.0, 512.0], +} + +ASPECT_RATIO_512_BIN = { + "0.25": [256.0, 1024.0], + "0.28": [256.0, 928.0], + "0.32": [288.0, 896.0], + "0.33": [288.0, 864.0], + "0.35": [288.0, 832.0], + "0.4": [320.0, 800.0], + "0.42": [320.0, 768.0], + "0.48": [352.0, 736.0], + "0.5": [352.0, 704.0], + "0.52": [352.0, 672.0], + "0.57": [384.0, 672.0], + "0.6": [384.0, 640.0], + "0.68": [416.0, 608.0], + "0.72": [416.0, 576.0], + "0.78": [448.0, 576.0], + "0.82": [448.0, 544.0], + "0.88": [480.0, 544.0], + "0.94": [480.0, 512.0], + "1.0": [512.0, 512.0], + "1.07": [512.0, 480.0], + "1.13": [544.0, 480.0], + "1.21": [544.0, 448.0], + "1.29": [576.0, 448.0], + "1.38": [576.0, 416.0], + "1.46": [608.0, 416.0], + "1.67": [640.0, 384.0], + "1.75": [672.0, 384.0], + "2.0": [704.0, 352.0], + "2.09": [736.0, 352.0], + "2.4": [768.0, 320.0], + "2.5": [800.0, 320.0], + "3.0": [864.0, 288.0], + "4.0": [1024.0, 256.0], +} + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + skip_initial_inference_steps: int = 0, + skip_final_inference_steps: int = 0, + **kwargs, +): + """ + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, + `timesteps` must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to support arbitrary spacing between timesteps. 
If `None`, then the default + timestep spacing strategy of the scheduler is used. If `timesteps` is passed, `num_inference_steps` + must be `None`. + max_timestep ('float', *optional*, defaults to 1.0): + The initial noising level for image-to-image/video-to-video. The list if timestamps will be + truncated to start with a timestamp greater or equal to this. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None: + accepts_timesteps = "timesteps" in set( + inspect.signature(scheduler.set_timesteps).parameters.keys() + ) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + + if ( + skip_initial_inference_steps < 0 + or skip_final_inference_steps < 0 + or skip_initial_inference_steps + skip_final_inference_steps + >= num_inference_steps + ): + raise ValueError( + "invalid skip inference step values: must be non-negative and the sum of skip_initial_inference_steps and skip_final_inference_steps must be less than the number of inference steps" + ) + + timesteps = timesteps[ + skip_initial_inference_steps : len(timesteps) - skip_final_inference_steps + ] + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + num_inference_steps = len(timesteps) + + return timesteps, num_inference_steps + + +@dataclass +class ConditioningItem: + """ + Defines a single frame-conditioning item - a single frame or a sequence of frames. + + Attributes: + media_item (torch.Tensor): shape=(b, 3, f, h, w). The media item to condition on. + media_frame_number (int): The start-frame number of the media item in the generated video. + conditioning_strength (float): The strength of the conditioning (1.0 = full conditioning). + media_x (Optional[int]): Optional left x coordinate of the media item in the generated frame. + media_y (Optional[int]): Optional top y coordinate of the media item in the generated frame. + """ + + media_item: torch.Tensor + media_frame_number: int + conditioning_strength: float + media_x: Optional[int] = None + media_y: Optional[int] = None + + +class LTXVideoPipeline(DiffusionPipeline): + r""" + Pipeline for text-to-image generation using LTX-Video. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`T5EncoderModel`]): + Frozen text-encoder. This uses + [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel), specifically the + [t5-v1_1-xxl](https://huggingface.co/PixArt-alpha/PixArt-alpha/tree/main/t5-v1_1-xxl) variant. + tokenizer (`T5Tokenizer`): + Tokenizer of class + [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer). 
+ transformer ([`Transformer2DModel`]): + A text conditioned `Transformer2DModel` to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + """ + + bad_punct_regex = re.compile( + r"[" + + "#®•©™&@·º½¾¿¡§~" + + r"\)" + + r"\(" + + r"\]" + + r"\[" + + r"\}" + + r"\{" + + r"\|" + + "\\" + + r"\/" + + r"\*" + + r"]{1,}" + ) # noqa + + _optional_components = [ + "tokenizer", + "text_encoder", + "prompt_enhancer_image_caption_model", + "prompt_enhancer_image_caption_processor", + "prompt_enhancer_llm_model", + "prompt_enhancer_llm_tokenizer", + ] + model_cpu_offload_seq = "prompt_enhancer_image_caption_model->prompt_enhancer_llm_model->text_encoder->transformer->vae" + + def __init__( + self, + tokenizer: T5Tokenizer, + text_encoder: T5EncoderModel, + vae: AutoencoderKL, + transformer: Transformer3DModel, + scheduler: DPMSolverMultistepScheduler, + patchifier: Patchifier, + prompt_enhancer_image_caption_model: AutoModelForCausalLM, + prompt_enhancer_image_caption_processor: AutoProcessor, + prompt_enhancer_llm_model: AutoModelForCausalLM, + prompt_enhancer_llm_tokenizer: AutoTokenizer, + allowed_inference_steps: Optional[List[float]] = None, + ): + super().__init__() + + self.register_modules( + tokenizer=tokenizer, + text_encoder=text_encoder, + vae=vae, + transformer=transformer, + scheduler=scheduler, + patchifier=patchifier, + prompt_enhancer_image_caption_model=prompt_enhancer_image_caption_model, + prompt_enhancer_image_caption_processor=prompt_enhancer_image_caption_processor, + prompt_enhancer_llm_model=prompt_enhancer_llm_model, + prompt_enhancer_llm_tokenizer=prompt_enhancer_llm_tokenizer, + ) + + self.video_scale_factor, self.vae_scale_factor, _ = get_vae_size_scale_factor( + self.vae + ) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + + self.allowed_inference_steps = allowed_inference_steps + + def mask_text_embeddings(self, emb, mask): + if emb.shape[0] == 1: + keep_index = mask.sum().item() + return emb[:, :, :keep_index, :], keep_index + else: + masked_feature = emb * mask[:, None, :, None] + return masked_feature, emb.shape[2] + + # Adapted from diffusers.pipelines.deepfloyd_if.pipeline_if.encode_prompt + def encode_prompt( + self, + prompt: Union[str, List[str]], + do_classifier_free_guidance: bool = True, + negative_prompt: str = "", + num_images_per_prompt: int = 1, + device: Optional[torch.device] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + prompt_attention_mask: Optional[torch.FloatTensor] = None, + negative_prompt_attention_mask: Optional[torch.FloatTensor] = None, + text_encoder_max_tokens: int = 256, + **kwargs, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + negative_prompt (`str` or `List[str]`, *optional*): + The prompt not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds` + instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). For + This should be "". 
+ do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): + whether to use classifier free guidance or not + num_images_per_prompt (`int`, *optional*, defaults to 1): + number of images that should be generated per prompt + device: (`torch.device`, *optional*): + torch device to place the resulting embeddings on + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. + """ + + if "mask_feature" in kwargs: + deprecation_message = "The use of `mask_feature` is deprecated. It is no longer used in any computation and that doesn't affect the end results. It will be removed in a future version." + deprecate("mask_feature", "1.0.0", deprecation_message, standard_warn=False) + + if device is None: + device = self._execution_device + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # See Section 3.1. of the paper. + max_length = ( + text_encoder_max_tokens # TPU supports only lengths multiple of 128 + ) + if prompt_embeds is None: + assert ( + self.text_encoder is not None + ), "You should provide either prompt_embeds or self.text_encoder should not be None," + text_enc_device = next(self.text_encoder.parameters()).device + prompt = self._text_preprocessing(prompt) + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_length, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer( + prompt, padding="longest", return_tensors="pt" + ).input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[ + -1 + ] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {max_length} tokens: {removed_text}" + ) + + prompt_attention_mask = text_inputs.attention_mask + prompt_attention_mask = prompt_attention_mask.to(text_enc_device) + prompt_attention_mask = prompt_attention_mask.to(device) + + prompt_embeds = self.text_encoder( + text_input_ids.to(text_enc_device), attention_mask=prompt_attention_mask + ) + prompt_embeds = prompt_embeds[0] + + if self.text_encoder is not None: + dtype = self.text_encoder.dtype + elif self.transformer is not None: + dtype = self.transformer.dtype + else: + dtype = None + + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view( + bs_embed * num_images_per_prompt, seq_len, -1 + ) + prompt_attention_mask = prompt_attention_mask.repeat(1, num_images_per_prompt) + prompt_attention_mask = prompt_attention_mask.view( + bs_embed * num_images_per_prompt, -1 + ) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens = self._text_preprocessing(negative_prompt) + 
uncond_tokens = uncond_tokens * batch_size + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_attention_mask=True, + add_special_tokens=True, + return_tensors="pt", + ) + negative_prompt_attention_mask = uncond_input.attention_mask + negative_prompt_attention_mask = negative_prompt_attention_mask.to( + text_enc_device + ) + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(text_enc_device), + attention_mask=negative_prompt_attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to( + dtype=dtype, device=device + ) + + negative_prompt_embeds = negative_prompt_embeds.repeat( + 1, num_images_per_prompt, 1 + ) + negative_prompt_embeds = negative_prompt_embeds.view( + batch_size * num_images_per_prompt, seq_len, -1 + ) + + negative_prompt_attention_mask = negative_prompt_attention_mask.repeat( + 1, num_images_per_prompt + ) + negative_prompt_attention_mask = negative_prompt_attention_mask.view( + bs_embed * num_images_per_prompt, -1 + ) + else: + negative_prompt_embeds = None + negative_prompt_attention_mask = None + + return ( + prompt_embeds, + prompt_attention_mask, + negative_prompt_embeds, + negative_prompt_attention_mask, + ) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set( + inspect.signature(self.scheduler.step).parameters.keys() + ) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set( + inspect.signature(self.scheduler.step).parameters.keys() + ) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + height, + width, + negative_prompt, + prompt_embeds=None, + negative_prompt_embeds=None, + prompt_attention_mask=None, + negative_prompt_attention_mask=None, + enhance_prompt=False, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError( + f"`height` and `width` have to be divisible by 8 but are {height} and {width}." + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." 
+ ) + elif prompt is not None and ( + not isinstance(prompt, str) and not isinstance(prompt, list) + ): + raise ValueError( + f"`prompt` has to be of type `str` or `list` but is {type(prompt)}" + ) + + if prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and prompt_attention_mask is None: + raise ValueError( + "Must provide `prompt_attention_mask` when specifying `prompt_embeds`." + ) + + if ( + negative_prompt_embeds is not None + and negative_prompt_attention_mask is None + ): + raise ValueError( + "Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + if prompt_attention_mask.shape != negative_prompt_attention_mask.shape: + raise ValueError( + "`prompt_attention_mask` and `negative_prompt_attention_mask` must have the same shape when passed directly, but" + f" got: `prompt_attention_mask` {prompt_attention_mask.shape} != `negative_prompt_attention_mask`" + f" {negative_prompt_attention_mask.shape}." + ) + + if enhance_prompt: + assert ( + self.prompt_enhancer_image_caption_model is not None + ), "Image caption model must be initialized if enhance_prompt is True" + assert ( + self.prompt_enhancer_image_caption_processor is not None + ), "Image caption processor must be initialized if enhance_prompt is True" + assert ( + self.prompt_enhancer_llm_model is not None + ), "Text prompt enhancer model must be initialized if enhance_prompt is True" + assert ( + self.prompt_enhancer_llm_tokenizer is not None + ), "Text prompt enhancer tokenizer must be initialized if enhance_prompt is True" + + def _text_preprocessing(self, text): + if not isinstance(text, (tuple, list)): + text = [text] + + def process(text: str): + text = text.strip() + return text + + return [process(t) for t in text] + + @staticmethod + def add_noise_to_image_conditioning_latents( + t: float, + init_latents: torch.Tensor, + latents: torch.Tensor, + noise_scale: float, + conditioning_mask: torch.Tensor, + generator, + eps=1e-6, + ): + """ + Add timestep-dependent noise to the hard-conditioning latents. + This helps with motion continuity, especially when conditioned on a single frame. 
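+        The noised value computed below is `init_latents + noise_scale * noise * t**2`,
+        applied only where `conditioning_mask` equals 1.0 (the hard-conditioning tokens).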
+ """ + noise = randn_tensor( + latents.shape, + generator=generator, + device=latents.device, + dtype=latents.dtype, + ) + # Add noise only to hard-conditioning latents (conditioning_mask = 1.0) + need_to_noise = (conditioning_mask > 1.0 - eps).unsqueeze(-1) + noised_latents = init_latents + noise_scale * noise * (t**2) + latents = torch.where(need_to_noise, noised_latents, latents) + return latents + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents + def prepare_latents( + self, + latents: torch.Tensor | None, + media_items: torch.Tensor | None, + timestep: float, + latent_shape: torch.Size | Tuple[Any, ...], + dtype: torch.dtype, + device: torch.device, + generator: torch.Generator | List[torch.Generator], + vae_per_channel_normalize: bool = True, + ): + """ + Prepare the initial latent tensor to be denoised. + The latents are either pure noise or a noised version of the encoded media items. + Args: + latents (`torch.FloatTensor` or `None`): + The latents to use (provided by the user) or `None` to create new latents. + media_items (`torch.FloatTensor` or `None`): + An image or video to be updated using img2img or vid2vid. The media item is encoded and noised. + timestep (`float`): + The timestep to noise the encoded media_items to. + latent_shape (`torch.Size`): + The target latent shape. + dtype (`torch.dtype`): + The target dtype. + device (`torch.device`): + The target device. + generator (`torch.Generator` or `List[torch.Generator]`): + Generator(s) to be used for the noising process. + vae_per_channel_normalize ('bool'): + When encoding the media_items, whether to normalize the latents per-channel. + Returns: + `torch.FloatTensor`: The latents to be used for the denoising process. This is a tensor of shape + (batch_size, num_channels, height, width). + """ + if isinstance(generator, list) and len(generator) != latent_shape[0]: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {latent_shape[0]}. Make sure the batch size matches the length of the generators." + ) + + # Initialize the latents with the given latents or encoded media item, if provided + assert ( + latents is None or media_items is None + ), "Cannot provide both latents and media_items. Please provide only one of the two." + + assert ( + latents is None and media_items is None or timestep < 1.0 + ), "Input media_item or latents are provided, but they will be replaced with noise." + + if media_items is not None: + latents = vae_encode( + media_items.to(dtype=self.vae.dtype, device=self.vae.device), + self.vae, + vae_per_channel_normalize=vae_per_channel_normalize, + ) + if latents is not None: + assert ( + latents.shape == latent_shape + ), f"Latents have to be of shape {latent_shape} but are {latents.shape}." 
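+            # These latents (or pure noise when none are provided) are blended further below
+            # as `timestep * noise + (1 - timestep) * latents`, so e.g. timestep == 1.0 keeps
+            # pure noise while timestep == 0.25 keeps 75% of the encoded media latents.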
+ latents = latents.to(device=device, dtype=dtype) + + # For backward compatibility, generate in the "patchified" shape and rearrange + b, c, f, h, w = latent_shape + noise = randn_tensor( + (b, f * h * w, c), generator=generator, device=device, dtype=dtype + ) + noise = rearrange(noise, "b (f h w) c -> b c f h w", f=f, h=h, w=w) + + # scale the initial noise by the standard deviation required by the scheduler + noise = noise * self.scheduler.init_noise_sigma + + if latents is None: + latents = noise + else: + # Noise the latents to the required (first) timestep + latents = timestep * noise + (1 - timestep) * latents + + return latents + + @staticmethod + def classify_height_width_bin( + height: int, width: int, ratios: dict + ) -> Tuple[int, int]: + """Returns binned height and width.""" + ar = float(height / width) + closest_ratio = min(ratios.keys(), key=lambda ratio: abs(float(ratio) - ar)) + default_hw = ratios[closest_ratio] + return int(default_hw[0]), int(default_hw[1]) + + @staticmethod + def resize_and_crop_tensor( + samples: torch.Tensor, new_width: int, new_height: int + ) -> torch.Tensor: + n_frames, orig_height, orig_width = samples.shape[-3:] + + # Check if resizing is needed + if orig_height != new_height or orig_width != new_width: + ratio = max(new_height / orig_height, new_width / orig_width) + resized_width = int(orig_width * ratio) + resized_height = int(orig_height * ratio) + + # Resize + samples = LTXVideoPipeline.resize_tensor( + samples, resized_height, resized_width + ) + + # Center Crop + start_x = (resized_width - new_width) // 2 + end_x = start_x + new_width + start_y = (resized_height - new_height) // 2 + end_y = start_y + new_height + samples = samples[..., start_y:end_y, start_x:end_x] + + return samples + + @staticmethod + def resize_tensor(media_items, height, width): + n_frames = media_items.shape[2] + if media_items.shape[-2:] != (height, width): + media_items = rearrange(media_items, "b c n h w -> (b n) c h w") + media_items = F.interpolate( + media_items, + size=(height, width), + mode="bilinear", + align_corners=False, + ) + media_items = rearrange(media_items, "(b n) c h w -> b c n h w", n=n_frames) + return media_items + + @torch.no_grad() + def __call__( + self, + height: int, + width: int, + num_frames: int, + frame_rate: float, + prompt: Union[str, List[str]] = None, + negative_prompt: str = "", + num_inference_steps: int = 20, + skip_initial_inference_steps: int = 0, + skip_final_inference_steps: int = 0, + timesteps: List[int] = None, + guidance_scale: Union[float, List[float]] = 4.5, + cfg_star_rescale: bool = False, + skip_layer_strategy: Optional[SkipLayerStrategy] = None, + skip_block_list: Optional[Union[List[List[int]], List[int]]] = None, + stg_scale: Union[float, List[float]] = 1.0, + rescaling_scale: Union[float, List[float]] = 0.7, + guidance_timesteps: Optional[List[int]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + prompt_attention_mask: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_attention_mask: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + conditioning_items: Optional[List[ConditioningItem]] = None, + decode_timestep: 
Union[List[float], float] = 0.0, + decode_noise_scale: Optional[List[float]] = None, + mixed_precision: bool = False, + offload_to_cpu: bool = False, + enhance_prompt: bool = False, + text_encoder_max_tokens: int = 256, + stochastic_sampling: bool = False, + media_items: Optional[torch.Tensor] = None, + tone_map_compression_ratio: float = 0.0, + **kwargs, + ) -> Union[ImagePipelineOutput, Tuple]: + """ + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + num_inference_steps (`int`, *optional*, defaults to 100): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. If `timesteps` is provided, this parameter is ignored. + skip_initial_inference_steps (`int`, *optional*, defaults to 0): + The number of initial timesteps to skip. After calculating the timesteps, this number of timesteps will + be removed from the beginning of the timesteps list. Meaning the highest-timesteps values will not run. + skip_final_inference_steps (`int`, *optional*, defaults to 0): + The number of final timesteps to skip. After calculating the timesteps, this number of timesteps will + be removed from the end of the timesteps list. Meaning the lowest-timesteps values will not run. + timesteps (`List[int]`, *optional*): + Custom timesteps to use for the denoising process. If not defined, equal spaced `num_inference_steps` + timesteps are used. Must be in descending order. + guidance_scale (`float`, *optional*, defaults to 4.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + cfg_star_rescale (`bool`, *optional*, defaults to `False`): + If set to `True`, applies the CFG star rescale. Scales the negative prediction according to dot + product between positive and negative. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + height (`int`, *optional*, defaults to self.unet.config.sample_size): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to self.unet.config.sample_size): + The width in pixels of the generated image. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. 
Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + prompt_attention_mask (`torch.FloatTensor`, *optional*): Pre-generated attention mask for text embeddings. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. This negative prompt should be "". If not + provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. + negative_prompt_attention_mask (`torch.FloatTensor`, *optional*): + Pre-generated attention mask for negative text embeddings. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether to return a [`~pipelines.stable_diffusion.IFPipelineOutput`] instead of a plain tuple. + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + use_resolution_binning (`bool` defaults to `True`): + If set to `True`, the requested height and width are first mapped to the closest resolutions using + `ASPECT_RATIO_1024_BIN`. After the produced latents are decoded into images, they are resized back to + the requested resolution. Useful for generating non-square images. + enhance_prompt (`bool`, *optional*, defaults to `False`): + If set to `True`, the prompt is enhanced using a LLM model. + text_encoder_max_tokens (`int`, *optional*, defaults to `256`): + The maximum number of tokens to use for the text encoder. + stochastic_sampling (`bool`, *optional*, defaults to `False`): + If set to `True`, the sampling is stochastic. If set to `False`, the sampling is deterministic. + media_items ('torch.Tensor', *optional*): + The input media item used for image-to-image / video-to-video. + tone_map_compression_ratio: compression ratio for tone mapping, defaults to 0.0. + If set to 0.0, no tone mapping is applied. If set to 1.0 - full compression is applied. + Examples: + + Returns: + [`~pipelines.ImagePipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.ImagePipelineOutput`] is returned, otherwise a `tuple` is + returned where the first element is a list with the generated images + """ + if "mask_feature" in kwargs: + deprecation_message = "The use of `mask_feature` is deprecated. It is no longer used in any computation and that doesn't affect the end results. It will be removed in a future version." + deprecate("mask_feature", "1.0.0", deprecation_message, standard_warn=False) + + is_video = kwargs.get("is_video", False) + self.check_inputs( + prompt, + height, + width, + negative_prompt, + prompt_embeds, + negative_prompt_embeds, + prompt_attention_mask, + negative_prompt_attention_mask, + ) + + # 2. 
Default height and width to transformer + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + self.video_scale_factor = self.video_scale_factor if is_video else 1 + vae_per_channel_normalize = kwargs.get("vae_per_channel_normalize", True) + image_cond_noise_scale = kwargs.get("image_cond_noise_scale", 0.0) + + latent_height = height // self.vae_scale_factor + latent_width = width // self.vae_scale_factor + latent_num_frames = num_frames // self.video_scale_factor + if isinstance(self.vae, CausalVideoAutoencoder) and is_video: + latent_num_frames += 1 + latent_shape = ( + batch_size * num_images_per_prompt, + self.transformer.config.in_channels, + latent_num_frames, + latent_height, + latent_width, + ) + + # Prepare the list of denoising time-steps + + retrieve_timesteps_kwargs = {} + if isinstance(self.scheduler, TimestepShifter): + retrieve_timesteps_kwargs["samples_shape"] = latent_shape + + assert ( + skip_initial_inference_steps == 0 + or latents is not None + or media_items is not None + ), ( + f"skip_initial_inference_steps ({skip_initial_inference_steps}) is used for image-to-image/video-to-video - " + "media_item or latents should be provided." + ) + + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, + num_inference_steps, + device, + timesteps, + skip_initial_inference_steps=skip_initial_inference_steps, + skip_final_inference_steps=skip_final_inference_steps, + **retrieve_timesteps_kwargs, + ) + + if self.allowed_inference_steps is not None: + for timestep in [round(x, 4) for x in timesteps.tolist()]: + assert ( + timestep in self.allowed_inference_steps + ), f"Invalid inference timestep {timestep}. Allowed timesteps are {self.allowed_inference_steps}." + + if guidance_timesteps: + guidance_mapping = [] + for timestep in timesteps: + indices = [ + i for i, val in enumerate(guidance_timesteps) if val <= timestep + ] + # assert len(indices) > 0, f"No guidance timestep found for {timestep}" + guidance_mapping.append( + indices[0] if len(indices) > 0 else (len(guidance_timesteps) - 1) + ) + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. 
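+        # As a concrete sketch of that weighting: for step i the guidance branch below computes
+        # noise_pred_uncond + guidance_scale[i] * (noise_pred_text - noise_pred_uncond), so a
+        # value of 1.0 reduces to the conditional prediction alone and larger values push the
+        # sample further towards the text condition.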
+
+        if not isinstance(guidance_scale, List):
+            guidance_scale = [guidance_scale] * len(timesteps)
+        else:
+            guidance_scale = [
+                guidance_scale[guidance_mapping[i]] for i in range(len(timesteps))
+            ]
+
+        if not isinstance(stg_scale, List):
+            stg_scale = [stg_scale] * len(timesteps)
+        else:
+            stg_scale = [stg_scale[guidance_mapping[i]] for i in range(len(timesteps))]
+
+        if not isinstance(rescaling_scale, List):
+            rescaling_scale = [rescaling_scale] * len(timesteps)
+        else:
+            rescaling_scale = [
+                rescaling_scale[guidance_mapping[i]] for i in range(len(timesteps))
+            ]
+
+        # Normalize skip_block_list to always be None or a list of lists matching timesteps
+        if skip_block_list is not None:
+            # Convert single list to list of lists if needed
+            if len(skip_block_list) == 0 or not isinstance(skip_block_list[0], list):
+                skip_block_list = [skip_block_list] * len(timesteps)
+            else:
+                new_skip_block_list = []
+                for i, timestep in enumerate(timesteps):
+                    new_skip_block_list.append(skip_block_list[guidance_mapping[i]])
+                skip_block_list = new_skip_block_list
+
+        if enhance_prompt:
+            self.prompt_enhancer_image_caption_model = (
+                self.prompt_enhancer_image_caption_model.to(self._execution_device)
+            )
+            self.prompt_enhancer_llm_model = self.prompt_enhancer_llm_model.to(
+                self._execution_device
+            )
+
+            prompt = generate_cinematic_prompt(
+                self.prompt_enhancer_image_caption_model,
+                self.prompt_enhancer_image_caption_processor,
+                self.prompt_enhancer_llm_model,
+                self.prompt_enhancer_llm_tokenizer,
+                prompt,
+                conditioning_items,
+                max_new_tokens=text_encoder_max_tokens,
+            )
+
+            # --- [OUR SECRET WIRETAP HERE] ---
+            print("--- [ASSISTANT DIRECTOR LOG (PROMPT ENHANCER)] ---")
+            print("Maestro's original prompt:", kwargs.get("original_prompt_for_logging", "N/A"))  # We still need to pass this in
+            print("FINAL ENHANCED PROMPT (sent to LTX):", prompt)
+            print("--- [END OF ASSISTANT DIRECTOR LOG] ---")
+            # --- [END OF WIRETAP] ---
+
+
+        # 3. Encode input prompt
+        if self.text_encoder is not None:
+            self.text_encoder = self.text_encoder.to(self._execution_device)
+
+        (
+            prompt_embeds,
+            prompt_attention_mask,
+            negative_prompt_embeds,
+            negative_prompt_attention_mask,
+        ) = self.encode_prompt(
+            prompt,
+            True,
+            negative_prompt=negative_prompt,
+            num_images_per_prompt=num_images_per_prompt,
+            device=device,
+            prompt_embeds=prompt_embeds,
+            negative_prompt_embeds=negative_prompt_embeds,
+            prompt_attention_mask=prompt_attention_mask,
+            negative_prompt_attention_mask=negative_prompt_attention_mask,
+            text_encoder_max_tokens=text_encoder_max_tokens,
+        )
+
+        if offload_to_cpu and self.text_encoder is not None:
+            self.text_encoder = self.text_encoder.cpu()
+
+        self.transformer = self.transformer.to(self._execution_device)
+
+        prompt_embeds_batch = prompt_embeds
+        prompt_attention_mask_batch = prompt_attention_mask
+        negative_prompt_embeds = (
+            torch.zeros_like(prompt_embeds)
+            if negative_prompt_embeds is None
+            else negative_prompt_embeds
+        )
+        negative_prompt_attention_mask = (
+            torch.zeros_like(prompt_attention_mask)
+            if negative_prompt_attention_mask is None
+            else negative_prompt_attention_mask
+        )
+
+        prompt_embeds_batch = torch.cat(
+            [negative_prompt_embeds, prompt_embeds, prompt_embeds], dim=0
+        )
+        prompt_attention_mask_batch = torch.cat(
+            [
+                negative_prompt_attention_mask,
+                prompt_attention_mask,
+                prompt_attention_mask,
+            ],
+            dim=0,
+        )
+        # 4. 
Prepare the initial latents using the provided media and conditioning items + + # Prepare the initial latents tensor, shape = (b, c, f, h, w) + latents = self.prepare_latents( + latents=latents, + media_items=media_items, + timestep=timesteps[0], + latent_shape=latent_shape, + dtype=prompt_embeds.dtype, + device=device, + generator=generator, + vae_per_channel_normalize=vae_per_channel_normalize, + ) + + # Update the latents with the conditioning items and patchify them into (b, n, c) + latents, pixel_coords, conditioning_mask, num_cond_latents = ( + self.prepare_conditioning( + conditioning_items=conditioning_items, + init_latents=latents, + num_frames=num_frames, + height=height, + width=width, + vae_per_channel_normalize=vae_per_channel_normalize, + generator=generator, + ) + ) + init_latents = latents.clone() # Used for image_cond_noise_update + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7. Denoising loop + num_warmup_steps = max( + len(timesteps) - num_inference_steps * self.scheduler.order, 0 + ) + + orig_conditioning_mask = conditioning_mask + + # Befor compiling this code please be aware: + # This code might generate different input shapes if some timesteps have no STG or CFG. + # This means that the codes might need to be compiled mutliple times. + # To avoid that, use the same STG and CFG values for all timesteps. + + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + do_classifier_free_guidance = guidance_scale[i] > 1.0 + do_spatio_temporal_guidance = stg_scale[i] > 0 + do_rescaling = rescaling_scale[i] != 1.0 + + num_conds = 1 + if do_classifier_free_guidance: + num_conds += 1 + if do_spatio_temporal_guidance: + num_conds += 1 + + if do_classifier_free_guidance and do_spatio_temporal_guidance: + indices = slice(batch_size * 0, batch_size * 3) + elif do_classifier_free_guidance: + indices = slice(batch_size * 0, batch_size * 2) + elif do_spatio_temporal_guidance: + indices = slice(batch_size * 1, batch_size * 3) + else: + indices = slice(batch_size * 1, batch_size * 2) + + # Prepare skip layer masks + skip_layer_mask: Optional[torch.Tensor] = None + if do_spatio_temporal_guidance: + if skip_block_list is not None: + skip_layer_mask = self.transformer.create_skip_layer_mask( + batch_size, num_conds, num_conds - 1, skip_block_list[i] + ) + + batch_pixel_coords = torch.cat([pixel_coords] * num_conds) + conditioning_mask = orig_conditioning_mask + if conditioning_mask is not None and is_video: + assert num_images_per_prompt == 1 + conditioning_mask = torch.cat([conditioning_mask] * num_conds) + fractional_coords = batch_pixel_coords.to(torch.float32) + fractional_coords[:, 0] = fractional_coords[:, 0] * (1.0 / frame_rate) + + if conditioning_mask is not None and image_cond_noise_scale > 0.0: + latents = self.add_noise_to_image_conditioning_latents( + t, + init_latents, + latents, + image_cond_noise_scale, + orig_conditioning_mask, + generator, + ) + + latent_model_input = ( + torch.cat([latents] * num_conds) if num_conds > 1 else latents + ) + latent_model_input = self.scheduler.scale_model_input( + latent_model_input, t + ) + + current_timestep = t + if not torch.is_tensor(current_timestep): + # TODO: this requires sync between CPU and GPU. 
So try to pass timesteps as tensors if you can + # This would be a good case for the `match` statement (Python 3.10+) + is_mps = latent_model_input.device.type == "mps" + if isinstance(current_timestep, float): + dtype = torch.float32 if is_mps else torch.float64 + else: + dtype = torch.int32 if is_mps else torch.int64 + current_timestep = torch.tensor( + [current_timestep], + dtype=dtype, + device=latent_model_input.device, + ) + elif len(current_timestep.shape) == 0: + current_timestep = current_timestep[None].to( + latent_model_input.device + ) + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + current_timestep = current_timestep.expand( + latent_model_input.shape[0] + ).unsqueeze(-1) + + if conditioning_mask is not None: + # Conditioning latents have an initial timestep and noising level of (1.0 - conditioning_mask) + # and will start to be denoised when the current timestep is lower than their conditioning timestep. + current_timestep = torch.min( + current_timestep, 1.0 - conditioning_mask + ) + + # Choose the appropriate context manager based on `mixed_precision` + if mixed_precision: + context_manager = torch.autocast(device.type, dtype=torch.bfloat16) + else: + context_manager = nullcontext() # Dummy context manager + + # predict noise model_output + with context_manager: + noise_pred = self.transformer( + latent_model_input.to(self.transformer.dtype), + indices_grid=fractional_coords, + encoder_hidden_states=prompt_embeds_batch[indices].to( + self.transformer.dtype + ), + encoder_attention_mask=prompt_attention_mask_batch[indices], + timestep=current_timestep, + skip_layer_mask=skip_layer_mask, + skip_layer_strategy=skip_layer_strategy, + return_dict=False, + )[0] + + # perform guidance + if do_spatio_temporal_guidance: + noise_pred_text, noise_pred_text_perturb = noise_pred.chunk( + num_conds + )[-2:] + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(num_conds)[:2] + + if cfg_star_rescale: + # Rescales the unconditional noise prediction using the projection of the conditional prediction onto it: + # α = (⟨ε_text, ε_uncond⟩ / ||ε_uncond||²), then ε_uncond ← α * ε_uncond + # where ε_text is the conditional noise prediction and ε_uncond is the unconditional one. 
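+                        # i.e. ε_uncond is scaled by the projection coefficient of ε_text onto
+                        # ε_uncond (computed per sample just below) before the usual CFG combine.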
+ positive_flat = noise_pred_text.view(batch_size, -1) + negative_flat = noise_pred_uncond.view(batch_size, -1) + dot_product = torch.sum( + positive_flat * negative_flat, dim=1, keepdim=True + ) + squared_norm = ( + torch.sum(negative_flat**2, dim=1, keepdim=True) + 1e-8 + ) + alpha = dot_product / squared_norm + noise_pred_uncond = alpha * noise_pred_uncond + + noise_pred = noise_pred_uncond + guidance_scale[i] * ( + noise_pred_text - noise_pred_uncond + ) + elif do_spatio_temporal_guidance: + noise_pred = noise_pred_text + if do_spatio_temporal_guidance: + noise_pred = noise_pred + stg_scale[i] * ( + noise_pred_text - noise_pred_text_perturb + ) + if do_rescaling and stg_scale[i] > 0.0: + noise_pred_text_std = noise_pred_text.view(batch_size, -1).std( + dim=1, keepdim=True + ) + noise_pred_std = noise_pred.view(batch_size, -1).std( + dim=1, keepdim=True + ) + + factor = noise_pred_text_std / noise_pred_std + factor = rescaling_scale[i] * factor + (1 - rescaling_scale[i]) + + noise_pred = noise_pred * factor.view(batch_size, 1, 1) + + current_timestep = current_timestep[:1] + # learned sigma + if ( + self.transformer.config.out_channels // 2 + == self.transformer.config.in_channels + ): + noise_pred = noise_pred.chunk(2, dim=1)[0] + + # compute previous image: x_t -> x_t-1 + latents = self.denoising_step( + latents, + noise_pred, + current_timestep, + orig_conditioning_mask, + t, + extra_step_kwargs, + stochastic_sampling=stochastic_sampling, + ) + + # call the callback, if provided + if i == len(timesteps) - 1 or ( + (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0 + ): + progress_bar.update() + + if callback_on_step_end is not None: + callback_on_step_end(self, i, t, {}) + + if offload_to_cpu: + self.transformer = self.transformer.cpu() + if self._execution_device == "cuda": + torch.cuda.empty_cache() + + # Remove the added conditioning latents + latents = latents[:, num_cond_latents:] + + latents = self.patchifier.unpatchify( + latents=latents, + output_height=latent_height, + output_width=latent_width, + out_channels=self.transformer.in_channels + // math.prod(self.patchifier.patch_size), + ) + if output_type != "latent": + if self.vae.decoder.timestep_conditioning: + noise = torch.randn_like(latents) + if not isinstance(decode_timestep, list): + decode_timestep = [decode_timestep] * latents.shape[0] + if decode_noise_scale is None: + decode_noise_scale = decode_timestep + elif not isinstance(decode_noise_scale, list): + decode_noise_scale = [decode_noise_scale] * latents.shape[0] + + decode_timestep = torch.tensor(decode_timestep).to(latents.device) + decode_noise_scale = torch.tensor(decode_noise_scale).to( + latents.device + )[:, None, None, None, None] + latents = ( + latents * (1 - decode_noise_scale) + noise * decode_noise_scale + ) + else: + decode_timestep = None + latents = self.tone_map_latents(latents, tone_map_compression_ratio) + image = vae_decode( + latents, + self.vae, + is_video, + vae_per_channel_normalize=kwargs["vae_per_channel_normalize"], + timestep=decode_timestep, + ) + + image = self.image_processor.postprocess(image, output_type=output_type) + + else: + image = latents + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return ImagePipelineOutput(images=image) + + def denoising_step( + self, + latents: torch.Tensor, + noise_pred: torch.Tensor, + current_timestep: torch.Tensor, + conditioning_mask: torch.Tensor, + t: float, + extra_step_kwargs, + t_eps=1e-6, + stochastic_sampling=False, + ): 
+ """ + Perform the denoising step for the required tokens, based on the current timestep and + conditioning mask: + Conditioning latents have an initial timestep and noising level of (1.0 - conditioning_mask) + and will start to be denoised when the current timestep is equal or lower than their + conditioning timestep. + (hard-conditioning latents with conditioning_mask = 1.0 are never denoised) + """ + # Denoise the latents using the scheduler + denoised_latents = self.scheduler.step( + noise_pred, + t if current_timestep is None else current_timestep, + latents, + **extra_step_kwargs, + return_dict=False, + stochastic_sampling=stochastic_sampling, + )[0] + + if conditioning_mask is None: + return denoised_latents + + tokens_to_denoise_mask = (t - t_eps < (1.0 - conditioning_mask)).unsqueeze(-1) + return torch.where(tokens_to_denoise_mask, denoised_latents, latents) + + def prepare_conditioning( + self, + conditioning_items: Optional[List[ConditioningItem]], + init_latents: torch.Tensor, + num_frames: int, + height: int, + width: int, + vae_per_channel_normalize: bool = False, + generator=None, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, int]: + """ + Prepare conditioning tokens based on the provided conditioning items. + + This method encodes provided conditioning items (video frames or single frames) into latents + and integrates them with the initial latent tensor. It also calculates corresponding pixel + coordinates, a mask indicating the influence of conditioning latents, and the total number of + conditioning latents. + + Args: + conditioning_items (Optional[List[ConditioningItem]]): A list of ConditioningItem objects. + init_latents (torch.Tensor): The initial latent tensor of shape (b, c, f_l, h_l, w_l), where + `f_l` is the number of latent frames, and `h_l` and `w_l` are latent spatial dimensions. + num_frames, height, width: The dimensions of the generated video. + vae_per_channel_normalize (bool, optional): Whether to normalize channels during VAE encoding. + Defaults to `False`. + generator: The random generator + + Returns: + Tuple[torch.Tensor, torch.Tensor, torch.Tensor, int]: + - `init_latents` (torch.Tensor): The updated latent tensor including conditioning latents, + patchified into (b, n, c) shape. + - `init_pixel_coords` (torch.Tensor): The pixel coordinates corresponding to the updated + latent tensor. + - `conditioning_mask` (torch.Tensor): A mask indicating the conditioning-strength of each + latent token. + - `num_cond_latents` (int): The total number of latent tokens added from conditioning items. + + Raises: + AssertionError: If input shapes, dimensions, or conditions for applying conditioning are invalid. 
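+        Example (illustrative only; the tensor is a placeholder): conditioning the clip on a
+        single first frame can be expressed as
+        `ConditioningItem(media_item=frame, media_frame_number=0, conditioning_strength=1.0)`,
+        where `frame` has shape (b, 3, 1, h, w) and matches the target height and width.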
+ """ + assert isinstance(self.vae, CausalVideoAutoencoder) + + if conditioning_items: + batch_size, _, num_latent_frames = init_latents.shape[:3] + + init_conditioning_mask = torch.zeros( + init_latents[:, 0, :, :, :].shape, + dtype=torch.float32, + device=init_latents.device, + ) + + extra_conditioning_latents = [] + extra_conditioning_pixel_coords = [] + extra_conditioning_mask = [] + extra_conditioning_num_latents = 0 # Number of extra conditioning latents added (should be removed before decoding) + + # Process each conditioning item + for conditioning_item in conditioning_items: + conditioning_item = self._resize_conditioning_item( + conditioning_item, height, width + ) + media_item = conditioning_item.media_item + media_frame_number = conditioning_item.media_frame_number + strength = conditioning_item.conditioning_strength + assert media_item.ndim == 5 # (b, c, f, h, w) + b, c, n_frames, h, w = media_item.shape + assert ( + height == h and width == w + ) or media_frame_number == 0, f"Dimensions do not match: {height}x{width} != {h}x{w} - allowed only when media_frame_number == 0" + assert n_frames % 8 == 1 + assert ( + media_frame_number >= 0 + and media_frame_number + n_frames <= num_frames + ) + + # Encode the provided conditioning media item + media_item_latents = vae_encode( + media_item.to(dtype=self.vae.dtype, device=self.vae.device), + self.vae, + vae_per_channel_normalize=vae_per_channel_normalize, + ).to(dtype=init_latents.dtype) + + # Handle the different conditioning cases + if media_frame_number == 0: + # Get the target spatial position of the latent conditioning item + media_item_latents, l_x, l_y = self._get_latent_spatial_position( + media_item_latents, + conditioning_item, + height, + width, + strip_latent_border=True, + ) + b, c_l, f_l, h_l, w_l = media_item_latents.shape + + # First frame or sequence - just update the initial noise latents and the mask + init_latents[:, :, :f_l, l_y : l_y + h_l, l_x : l_x + w_l] = ( + torch.lerp( + init_latents[:, :, :f_l, l_y : l_y + h_l, l_x : l_x + w_l], + media_item_latents, + strength, + ) + ) + init_conditioning_mask[ + :, :f_l, l_y : l_y + h_l, l_x : l_x + w_l + ] = strength + else: + # Non-first frame or sequence + if n_frames > 1: + # Handle non-first sequence. + # Encoded latents are either fully consumed, or the prefix is handled separately below. 
+ ( + init_latents, + init_conditioning_mask, + media_item_latents, + ) = self._handle_non_first_conditioning_sequence( + init_latents, + init_conditioning_mask, + media_item_latents, + media_frame_number, + strength, + ) + + # Single frame or sequence-prefix latents + if media_item_latents is not None: + noise = randn_tensor( + media_item_latents.shape, + generator=generator, + device=media_item_latents.device, + dtype=media_item_latents.dtype, + ) + + media_item_latents = torch.lerp( + noise, media_item_latents, strength + ) + + # Patchify the extra conditioning latents and calculate their pixel coordinates + media_item_latents, latent_coords = self.patchifier.patchify( + latents=media_item_latents + ) + pixel_coords = latent_to_pixel_coords( + latent_coords, + self.vae, + causal_fix=self.transformer.config.causal_temporal_positioning, + ) + + # Update the frame numbers to match the target frame number + pixel_coords[:, 0] += media_frame_number + extra_conditioning_num_latents += media_item_latents.shape[1] + + conditioning_mask = torch.full( + media_item_latents.shape[:2], + strength, + dtype=torch.float32, + device=init_latents.device, + ) + + extra_conditioning_latents.append(media_item_latents) + extra_conditioning_pixel_coords.append(pixel_coords) + extra_conditioning_mask.append(conditioning_mask) + + # Patchify the updated latents and calculate their pixel coordinates + init_latents, init_latent_coords = self.patchifier.patchify( + latents=init_latents + ) + init_pixel_coords = latent_to_pixel_coords( + init_latent_coords, + self.vae, + causal_fix=self.transformer.config.causal_temporal_positioning, + ) + + if not conditioning_items: + return init_latents, init_pixel_coords, None, 0 + + init_conditioning_mask, _ = self.patchifier.patchify( + latents=init_conditioning_mask.unsqueeze(1) + ) + init_conditioning_mask = init_conditioning_mask.squeeze(-1) + + if extra_conditioning_latents: + # Stack the extra conditioning latents, pixel coordinates and mask + init_latents = torch.cat([*extra_conditioning_latents, init_latents], dim=1) + init_pixel_coords = torch.cat( + [*extra_conditioning_pixel_coords, init_pixel_coords], dim=2 + ) + init_conditioning_mask = torch.cat( + [*extra_conditioning_mask, init_conditioning_mask], dim=1 + ) + + if self.transformer.use_tpu_flash_attention: + # When flash attention is used, keep the original number of tokens by removing + # tokens from the end. + init_latents = init_latents[:, :-extra_conditioning_num_latents] + init_pixel_coords = init_pixel_coords[ + :, :, :-extra_conditioning_num_latents + ] + init_conditioning_mask = init_conditioning_mask[ + :, :-extra_conditioning_num_latents + ] + + return ( + init_latents, + init_pixel_coords, + init_conditioning_mask, + extra_conditioning_num_latents, + ) + + @staticmethod + def _resize_conditioning_item( + conditioning_item: ConditioningItem, + height: int, + width: int, + ): + if conditioning_item.media_x or conditioning_item.media_y: + raise ValueError( + "Provide media_item in the target size for spatial conditioning." + ) + new_conditioning_item = copy.copy(conditioning_item) + new_conditioning_item.media_item = LTXVideoPipeline.resize_tensor( + conditioning_item.media_item, height, width + ) + return new_conditioning_item + + def _get_latent_spatial_position( + self, + latents: torch.Tensor, + conditioning_item: ConditioningItem, + height: int, + width: int, + strip_latent_border, + ): + """ + Get the spatial position of the conditioning item in the latent space. 
+        If requested, strip the conditioning latent borders that do not align with target borders.
+        (border latents look different than other latents and might confuse the model)
+        """
+        scale = self.vae_scale_factor
+        h, w = conditioning_item.media_item.shape[-2:]
+        assert (
+            h <= height and w <= width
+        ), f"Conditioning item size {h}x{w} is larger than target size {height}x{width}"
+        assert h % scale == 0 and w % scale == 0
+
+        # Compute the start and end spatial positions of the media item
+        x_start, y_start = conditioning_item.media_x, conditioning_item.media_y
+        x_start = (width - w) // 2 if x_start is None else x_start
+        y_start = (height - h) // 2 if y_start is None else y_start
+        x_end, y_end = x_start + w, y_start + h
+        assert (
+            x_end <= width and y_end <= height
+        ), f"Conditioning item {x_start}:{x_end}x{y_start}:{y_end} is out of bounds for target size {width}x{height}"
+
+        if strip_latent_border:
+            # Strip one latent from left/right and/or top/bottom, update x, y accordingly
+            if x_start > 0:
+                x_start += scale
+                latents = latents[:, :, :, :, 1:]
+
+            if y_start > 0:
+                y_start += scale
+                latents = latents[:, :, :, 1:, :]
+
+            if x_end < width:
+                latents = latents[:, :, :, :, :-1]
+
+            if y_end < height:
+                latents = latents[:, :, :, :-1, :]
+
+        return latents, x_start // scale, y_start // scale
+
+    @staticmethod
+    def _handle_non_first_conditioning_sequence(
+        init_latents: torch.Tensor,
+        init_conditioning_mask: torch.Tensor,
+        latents: torch.Tensor,
+        media_frame_number: int,
+        strength: float,
+        num_prefix_latent_frames: int = 2,
+        prefix_latents_mode: str = "concat",
+        prefix_soft_conditioning_strength: float = 0.15,
+    ):
+        """
+        Special handling for a conditioning sequence that does not start on the first frame.
+        The special handling is required to allow a short encoded video to be used as middle
+        (or last) sequence in a longer video.
+        Args:
+            init_latents (torch.Tensor): The initial noise latents to be updated.
+            init_conditioning_mask (torch.Tensor): The initial conditioning mask to be updated.
+            latents (torch.Tensor): The encoded conditioning item.
+            media_frame_number (int): The target frame number of the first frame in the conditioning sequence.
+            strength (float): The conditioning strength for the conditioning latents.
+            num_prefix_latent_frames (int, optional): The length of the sequence prefix, to be handled
+                separately. Defaults to 2.
+            prefix_latents_mode (str, optional): Special treatment for prefix (boundary) latents.
+                - "drop": Drop the prefix latents.
+                - "soft": Use the prefix latents, but with soft-conditioning
+                - "concat": Add the prefix latents as extra tokens (like single frames)
+            prefix_soft_conditioning_strength (float, optional): The strength of the soft-conditioning for
+                the prefix latents, relevant if `prefix_latents_mode` is "soft". Defaults to 0.15.
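+
+        Returns:
+            Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: The updated `init_latents`,
+                the updated `init_conditioning_mask`, and the remaining prefix latents still to be
+                handled as extra conditioning tokens (`None` when no prefix latents remain).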
+ + """ + f_l = latents.shape[2] + f_l_p = num_prefix_latent_frames + assert f_l >= f_l_p + assert media_frame_number % 8 == 0 + if f_l > f_l_p: + # Insert the conditioning latents **excluding the prefix** into the sequence + f_l_start = media_frame_number // 8 + f_l_p + f_l_end = f_l_start + f_l - f_l_p + init_latents[:, :, f_l_start:f_l_end] = torch.lerp( + init_latents[:, :, f_l_start:f_l_end], + latents[:, :, f_l_p:], + strength, + ) + # Mark these latent frames as conditioning latents + init_conditioning_mask[:, f_l_start:f_l_end] = strength + + # Handle the prefix-latents + if prefix_latents_mode == "soft": + if f_l_p > 1: + # Drop the first (single-frame) latent and soft-condition the remaining prefix + f_l_start = media_frame_number // 8 + 1 + f_l_end = f_l_start + f_l_p - 1 + strength = min(prefix_soft_conditioning_strength, strength) + init_latents[:, :, f_l_start:f_l_end] = torch.lerp( + init_latents[:, :, f_l_start:f_l_end], + latents[:, :, 1:f_l_p], + strength, + ) + # Mark these latent frames as conditioning latents + init_conditioning_mask[:, f_l_start:f_l_end] = strength + latents = None # No more latents to handle + elif prefix_latents_mode == "drop": + # Drop the prefix latents + latents = None + elif prefix_latents_mode == "concat": + # Pass-on the prefix latents to be handled as extra conditioning frames + latents = latents[:, :, :f_l_p] + else: + raise ValueError(f"Invalid prefix_latents_mode: {prefix_latents_mode}") + return ( + init_latents, + init_conditioning_mask, + latents, + ) + + def trim_conditioning_sequence( + self, start_frame: int, sequence_num_frames: int, target_num_frames: int + ): + """ + Trim a conditioning sequence to the allowed number of frames. + + Args: + start_frame (int): The target frame number of the first frame in the sequence. + sequence_num_frames (int): The number of frames in the sequence. + target_num_frames (int): The target number of frames in the generated video. + + Returns: + int: updated sequence length + """ + scale_factor = self.video_scale_factor + num_frames = min(sequence_num_frames, target_num_frames - start_frame) + # Trim down to a multiple of temporal_scale_factor frames plus 1 + num_frames = (num_frames - 1) // scale_factor * scale_factor + 1 + return num_frames + + @staticmethod + def tone_map_latents( + latents: torch.Tensor, + compression: float, + ) -> torch.Tensor: + """ + Applies a non-linear tone-mapping function to latent values to reduce their dynamic range + in a perceptually smooth way using a sigmoid-based compression. + + This is useful for regularizing high-variance latents or for conditioning outputs + during generation, especially when controlling dynamic behavior with a `compression` factor. + + Parameters: + ---------- + latents : torch.Tensor + Input latent tensor with arbitrary shape. Expected to be roughly in [-1, 1] or [0, 1] range. + compression : float + Compression strength in the range [0, 1]. + - 0.0: No tone-mapping (identity transform) + - 1.0: Full compression effect + + Returns: + ------- + torch.Tensor + The tone-mapped latent tensor of the same shape as input. 
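+
+        Example:
+        -------
+        A minimal usage sketch; the latent shape below is illustrative only.
+
+        >>> import torch
+        >>> latents = torch.randn(1, 128, 8, 16, 16)
+        >>> mapped = LTXVideoPipeline.tone_map_latents(latents, compression=0.6)
+        >>> mapped.shape == latents.shape
+        True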
+ """ + if not (0 <= compression <= 1): + raise ValueError("Compression must be in the range [0, 1]") + + # Remap [0-1] to [0-0.75] and apply sigmoid compression in one shot + scale_factor = compression * 0.75 + abs_latents = torch.abs(latents) + + # Sigmoid compression: sigmoid shifts large values toward 0.2, small values stay ~1.0 + # When scale_factor=0, sigmoid term vanishes, when scale_factor=0.75, full effect + sigmoid_term = torch.sigmoid(4.0 * scale_factor * (abs_latents - 1.0)) + scales = 1.0 - 0.8 * scale_factor * sigmoid_term + + filtered = latents * scales + return filtered + + +def adain_filter_latent( + latents: torch.Tensor, reference_latents: torch.Tensor, factor=1.0 +): + """ + Applies Adaptive Instance Normalization (AdaIN) to a latent tensor based on + statistics from a reference latent tensor. + + Args: + latent (torch.Tensor): Input latents to normalize + reference_latent (torch.Tensor): The reference latents providing style statistics. + factor (float): Blending factor between original and transformed latent. + Range: -10.0 to 10.0, Default: 1.0 + + Returns: + torch.Tensor: The transformed latent tensor + """ + result = latents.clone() + + for i in range(latents.size(0)): + for c in range(latents.size(1)): + r_sd, r_mean = torch.std_mean( + reference_latents[i, c], dim=None + ) # index by original dim order + i_sd, i_mean = torch.std_mean(result[i, c], dim=None) + + result[i, c] = ((result[i, c] - i_mean) / i_sd) * r_sd + r_mean + + result = torch.lerp(latents, result, factor) + return result + + +class LTXMultiScalePipeline: + def _upsample_latents( + self, latest_upsampler: LatentUpsampler, latents: torch.Tensor + ): + assert latents.device == latest_upsampler.device + + latents = un_normalize_latents( + latents, self.vae, vae_per_channel_normalize=True + ) + upsampled_latents = latest_upsampler(latents) + upsampled_latents = normalize_latents( + upsampled_latents, self.vae, vae_per_channel_normalize=True + ) + return upsampled_latents + + def __init__( + self, video_pipeline: LTXVideoPipeline, latent_upsampler: LatentUpsampler + ): + self.video_pipeline = video_pipeline + self.vae = video_pipeline.vae + self.latent_upsampler = latent_upsampler + + def __call__( + self, + downscale_factor: float, + first_pass: dict, + second_pass: dict, + *args: Any, + **kwargs: Any, + ) -> Any: + original_kwargs = kwargs.copy() + original_output_type = kwargs["output_type"] + original_width = kwargs["width"] + original_height = kwargs["height"] + + x_width = int(kwargs["width"] * downscale_factor) + downscaled_width = x_width - (x_width % self.video_pipeline.vae_scale_factor) + x_height = int(kwargs["height"] * downscale_factor) + downscaled_height = x_height - (x_height % self.video_pipeline.vae_scale_factor) + + kwargs["output_type"] = "latent" + kwargs["width"] = downscaled_width + kwargs["height"] = downscaled_height + kwargs.update(**first_pass) + result = self.video_pipeline(*args, **kwargs) + latents = result.images + + upsampled_latents = self._upsample_latents(self.latent_upsampler, latents) + upsampled_latents = adain_filter_latent( + latents=upsampled_latents, reference_latents=latents + ) + + kwargs = original_kwargs + + kwargs["latents"] = upsampled_latents + kwargs["output_type"] = original_output_type + kwargs["width"] = downscaled_width * 2 + kwargs["height"] = downscaled_height * 2 + kwargs.update(**second_pass) + + result = self.video_pipeline(*args, **kwargs) + if original_output_type != "latent": + num_frames = result.images.shape[2] + videos = 
rearrange(result.images, "b c f h w -> (b f) c h w") + + videos = F.interpolate( + videos, + size=(original_height, original_width), + mode="bilinear", + align_corners=False, + ) + videos = rearrange(videos, "(b f) c h w -> b c f h w", f=num_frames) + result.images = videos + + return result diff --git a/ltx_video/schedulers/__init__.py b/ltx_video/schedulers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/ltx_video/schedulers/rf.py b/ltx_video/schedulers/rf.py new file mode 100644 index 0000000000000000000000000000000000000000..c7d2ab3426645941efa71ec0c5d866d9ea9c90d4 --- /dev/null +++ b/ltx_video/schedulers/rf.py @@ -0,0 +1,386 @@ +import math +from abc import ABC, abstractmethod +from dataclasses import dataclass +from typing import Callable, Optional, Tuple, Union +import json +import os +from pathlib import Path + +import torch +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.schedulers.scheduling_utils import SchedulerMixin +from diffusers.utils import BaseOutput +from torch import Tensor +from safetensors import safe_open + + +from ltx_video.utils.torch_utils import append_dims + +from ltx_video.utils.diffusers_config_mapping import ( + diffusers_and_ours_config_mapping, + make_hashable_key, +) + + +def linear_quadratic_schedule(num_steps, threshold_noise=0.025, linear_steps=None): + if num_steps == 1: + return torch.tensor([1.0]) + if linear_steps is None: + linear_steps = num_steps // 2 + linear_sigma_schedule = [ + i * threshold_noise / linear_steps for i in range(linear_steps) + ] + threshold_noise_step_diff = linear_steps - threshold_noise * num_steps + quadratic_steps = num_steps - linear_steps + quadratic_coef = threshold_noise_step_diff / (linear_steps * quadratic_steps**2) + linear_coef = threshold_noise / linear_steps - 2 * threshold_noise_step_diff / ( + quadratic_steps**2 + ) + const = quadratic_coef * (linear_steps**2) + quadratic_sigma_schedule = [ + quadratic_coef * (i**2) + linear_coef * i + const + for i in range(linear_steps, num_steps) + ] + sigma_schedule = linear_sigma_schedule + quadratic_sigma_schedule + [1.0] + sigma_schedule = [1.0 - x for x in sigma_schedule] + return torch.tensor(sigma_schedule[:-1]) + + +def simple_diffusion_resolution_dependent_timestep_shift( + samples_shape: torch.Size, + timesteps: Tensor, + n: int = 32 * 32, +) -> Tensor: + if len(samples_shape) == 3: + _, m, _ = samples_shape + elif len(samples_shape) in [4, 5]: + m = math.prod(samples_shape[2:]) + else: + raise ValueError( + "Samples must have shape (b, t, c), (b, c, h, w) or (b, c, f, h, w)" + ) + snr = (timesteps / (1 - timesteps)) ** 2 + shift_snr = torch.log(snr) + 2 * math.log(m / n) + shifted_timesteps = torch.sigmoid(0.5 * shift_snr) + + return shifted_timesteps + + +def time_shift(mu: float, sigma: float, t: Tensor): + return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma) + + +def get_normal_shift( + n_tokens: int, + min_tokens: int = 1024, + max_tokens: int = 4096, + min_shift: float = 0.95, + max_shift: float = 2.05, +) -> Callable[[float], float]: + m = (max_shift - min_shift) / (max_tokens - min_tokens) + b = min_shift - m * min_tokens + return m * n_tokens + b + + +def strech_shifts_to_terminal(shifts: Tensor, terminal=0.1): + """ + Stretch a function (given as sampled shifts) so that its final value matches the given terminal value + using the provided formula. 
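+    Concretely: stretched = 1 - (1 - shifts) * (1 - terminal) / (1 - shifts[-1]), which guarantees
+    that the final stretched value equals `terminal` while preserving the overall shape of the schedule.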
+ + Parameters: + - shifts (Tensor): The samples of the function to be stretched (PyTorch Tensor). + - terminal (float): The desired terminal value (value at the last sample). + + Returns: + - Tensor: The stretched shifts such that the final value equals `terminal`. + """ + if shifts.numel() == 0: + raise ValueError("The 'shifts' tensor must not be empty.") + + # Ensure terminal value is valid + if terminal <= 0 or terminal >= 1: + raise ValueError("The terminal value must be between 0 and 1 (exclusive).") + + # Transform the shifts using the given formula + one_minus_z = 1 - shifts + scale_factor = one_minus_z[-1] / (1 - terminal) + stretched_shifts = 1 - (one_minus_z / scale_factor) + + return stretched_shifts + + +def sd3_resolution_dependent_timestep_shift( + samples_shape: torch.Size, + timesteps: Tensor, + target_shift_terminal: Optional[float] = None, +) -> Tensor: + """ + Shifts the timestep schedule as a function of the generated resolution. + + In the SD3 paper, the authors empirically how to shift the timesteps based on the resolution of the target images. + For more details: https://arxiv.org/pdf/2403.03206 + + In Flux they later propose a more dynamic resolution dependent timestep shift, see: + https://github.com/black-forest-labs/flux/blob/87f6fff727a377ea1c378af692afb41ae84cbe04/src/flux/sampling.py#L66 + + + Args: + samples_shape (torch.Size): The samples batch shape (batch_size, channels, height, width) or + (batch_size, channels, frame, height, width). + timesteps (Tensor): A batch of timesteps with shape (batch_size,). + target_shift_terminal (float): The target terminal value for the shifted timesteps. + + Returns: + Tensor: The shifted timesteps. + """ + if len(samples_shape) == 3: + _, m, _ = samples_shape + elif len(samples_shape) in [4, 5]: + m = math.prod(samples_shape[2:]) + else: + raise ValueError( + "Samples must have shape (b, t, c), (b, c, h, w) or (b, c, f, h, w)" + ) + + shift = get_normal_shift(m) + time_shifts = time_shift(shift, 1, timesteps) + if target_shift_terminal is not None: # Stretch the shifts to the target terminal + time_shifts = strech_shifts_to_terminal(time_shifts, target_shift_terminal) + return time_shifts + + +class TimestepShifter(ABC): + @abstractmethod + def shift_timesteps(self, samples_shape: torch.Size, timesteps: Tensor) -> Tensor: + pass + + +@dataclass +class RectifiedFlowSchedulerOutput(BaseOutput): + """ + Output class for the scheduler's step function output. + + Args: + prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): + Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the + denoising loop. + pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): + The predicted denoised sample (x_{0}) based on the model output from the current timestep. + `pred_original_sample` can be used to preview progress or for guidance. 
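+
+    Note:
+        The current `RectifiedFlowScheduler.step()` implementation populates only `prev_sample`
+        and leaves `pred_original_sample` as `None`.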
+ """ + + prev_sample: torch.FloatTensor + pred_original_sample: Optional[torch.FloatTensor] = None + + +class RectifiedFlowScheduler(SchedulerMixin, ConfigMixin, TimestepShifter): + order = 1 + + @register_to_config + def __init__( + self, + num_train_timesteps=1000, + shifting: Optional[str] = None, + base_resolution: int = 32**2, + target_shift_terminal: Optional[float] = None, + sampler: Optional[str] = "Uniform", + shift: Optional[float] = None, + ): + super().__init__() + self.init_noise_sigma = 1.0 + self.num_inference_steps = None + self.sampler = sampler + self.shifting = shifting + self.base_resolution = base_resolution + self.target_shift_terminal = target_shift_terminal + self.timesteps = self.sigmas = self.get_initial_timesteps( + num_train_timesteps, shift=shift + ) + self.shift = shift + + def get_initial_timesteps( + self, num_timesteps: int, shift: Optional[float] = None + ) -> Tensor: + if self.sampler == "Uniform": + return torch.linspace(1, 1 / num_timesteps, num_timesteps) + elif self.sampler == "LinearQuadratic": + return linear_quadratic_schedule(num_timesteps) + elif self.sampler == "Constant": + assert ( + shift is not None + ), "Shift must be provided for constant time shift sampler." + return time_shift( + shift, 1, torch.linspace(1, 1 / num_timesteps, num_timesteps) + ) + + def shift_timesteps(self, samples_shape: torch.Size, timesteps: Tensor) -> Tensor: + if self.shifting == "SD3": + return sd3_resolution_dependent_timestep_shift( + samples_shape, timesteps, self.target_shift_terminal + ) + elif self.shifting == "SimpleDiffusion": + return simple_diffusion_resolution_dependent_timestep_shift( + samples_shape, timesteps, self.base_resolution + ) + return timesteps + + def set_timesteps( + self, + num_inference_steps: Optional[int] = None, + samples_shape: Optional[torch.Size] = None, + timesteps: Optional[Tensor] = None, + device: Union[str, torch.device] = None, + ): + """ + Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference. + If `timesteps` are provided, they will be used instead of the scheduled timesteps. + + Args: + num_inference_steps (`int` *optional*): The number of diffusion steps used when generating samples. + samples_shape (`torch.Size` *optional*): The samples batch shape, used for shifting. + timesteps ('torch.Tensor' *optional*): Specific timesteps to use instead of scheduled timesteps. + device (`Union[str, torch.device]`, *optional*): The device to which the timesteps tensor will be moved. + """ + if timesteps is not None and num_inference_steps is not None: + raise ValueError( + "You cannot provide both `timesteps` and `num_inference_steps`." 
+ ) + if timesteps is None: + num_inference_steps = min( + self.config.num_train_timesteps, num_inference_steps + ) + timesteps = self.get_initial_timesteps( + num_inference_steps, shift=self.shift + ).to(device) + timesteps = self.shift_timesteps(samples_shape, timesteps) + else: + timesteps = torch.Tensor(timesteps).to(device) + num_inference_steps = len(timesteps) + self.timesteps = timesteps + self.num_inference_steps = num_inference_steps + self.sigmas = self.timesteps + + @staticmethod + def from_pretrained(pretrained_model_path: Union[str, os.PathLike]): + pretrained_model_path = Path(pretrained_model_path) + if pretrained_model_path.is_file(): + comfy_single_file_state_dict = {} + with safe_open(pretrained_model_path, framework="pt", device="cpu") as f: + metadata = f.metadata() + for k in f.keys(): + comfy_single_file_state_dict[k] = f.get_tensor(k) + configs = json.loads(metadata["config"]) + config = configs["scheduler"] + del comfy_single_file_state_dict + + elif pretrained_model_path.is_dir(): + diffusers_noise_scheduler_config_path = ( + pretrained_model_path / "scheduler" / "scheduler_config.json" + ) + + with open(diffusers_noise_scheduler_config_path, "r") as f: + scheduler_config = json.load(f) + hashable_config = make_hashable_key(scheduler_config) + if hashable_config in diffusers_and_ours_config_mapping: + config = diffusers_and_ours_config_mapping[hashable_config] + return RectifiedFlowScheduler.from_config(config) + + def scale_model_input( + self, sample: torch.FloatTensor, timestep: Optional[int] = None + ) -> torch.FloatTensor: + # pylint: disable=unused-argument + """ + Ensures interchangeability with schedulers that need to scale the denoising model input depending on the + current timestep. + + Args: + sample (`torch.FloatTensor`): input sample + timestep (`int`, optional): current timestep + + Returns: + `torch.FloatTensor`: scaled input sample + """ + return sample + + def step( + self, + model_output: torch.FloatTensor, + timestep: torch.FloatTensor, + sample: torch.FloatTensor, + return_dict: bool = True, + stochastic_sampling: Optional[bool] = False, + **kwargs, + ) -> Union[RectifiedFlowSchedulerOutput, Tuple]: + """ + Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion + process from the learned model outputs (most often the predicted noise). + z_{t_1} = z_t - Delta_t * v + The method finds the next timestep that is lower than the input timestep(s) and denoises the latents + to that level. The input timestep(s) are not required to be one of the predefined timesteps. + + Args: + model_output (`torch.FloatTensor`): + The direct output from learned diffusion model - the velocity, + timestep (`float`): + The current discrete timestep in the diffusion chain (global or per-token). + sample (`torch.FloatTensor`): + A current latent tokens to be de-noised. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~schedulers.scheduling_ddim.DDIMSchedulerOutput`] or `tuple`. + stochastic_sampling (`bool`, *optional*, defaults to `False`): + Whether to use stochastic sampling for the sampling process. + + Returns: + [`~schedulers.scheduling_utils.RectifiedFlowSchedulerOutput`] or `tuple`: + If return_dict is `True`, [`~schedulers.rf_scheduler.RectifiedFlowSchedulerOutput`] is returned, + otherwise a tuple is returned where the first element is the sample tensor. 
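+
+        Example:
+            A minimal sketch on random tensors; the sample shape is illustrative only.
+
+            >>> import torch
+            >>> scheduler = RectifiedFlowScheduler()
+            >>> scheduler.set_timesteps(num_inference_steps=4, device="cpu")
+            >>> sample = torch.randn(1, 256, 128)
+            >>> velocity = torch.randn_like(sample)
+            >>> out = scheduler.step(velocity, scheduler.timesteps[0], sample)
+            >>> out.prev_sample.shape
+            torch.Size([1, 256, 128])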
+ """ + if self.num_inference_steps is None: + raise ValueError( + "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" + ) + t_eps = 1e-6 # Small epsilon to avoid numerical issues in timestep values + + timesteps_padded = torch.cat( + [self.timesteps, torch.zeros(1, device=self.timesteps.device)] + ) + + # Find the next lower timestep(s) and compute the dt from the current timestep(s) + if timestep.ndim == 0: + # Global timestep case + lower_mask = timesteps_padded < timestep - t_eps + lower_timestep = timesteps_padded[lower_mask][0] # Closest lower timestep + dt = timestep - lower_timestep + + else: + # Per-token case + assert timestep.ndim == 2 + lower_mask = timesteps_padded[:, None, None] < timestep[None] - t_eps + lower_timestep = lower_mask * timesteps_padded[:, None, None] + lower_timestep, _ = lower_timestep.max(dim=0) + dt = (timestep - lower_timestep)[..., None] + + # Compute previous sample + if stochastic_sampling: + x0 = sample - timestep[..., None] * model_output + next_timestep = timestep[..., None] - dt + prev_sample = self.add_noise(x0, torch.randn_like(sample), next_timestep) + else: + prev_sample = sample - dt * model_output + + if not return_dict: + return (prev_sample,) + + return RectifiedFlowSchedulerOutput(prev_sample=prev_sample) + + def add_noise( + self, + original_samples: torch.FloatTensor, + noise: torch.FloatTensor, + timesteps: torch.FloatTensor, + ) -> torch.FloatTensor: + sigmas = timesteps + sigmas = append_dims(sigmas, original_samples.ndim) + alphas = 1 - sigmas + noisy_samples = alphas * original_samples + sigmas * noise + return noisy_samples diff --git a/ltx_video/utils/__init__.py b/ltx_video/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/ltx_video/utils/diffusers_config_mapping.py b/ltx_video/utils/diffusers_config_mapping.py new file mode 100644 index 0000000000000000000000000000000000000000..53c0082d182617f6f84eab9c849f7ef0224becb8 --- /dev/null +++ b/ltx_video/utils/diffusers_config_mapping.py @@ -0,0 +1,174 @@ +def make_hashable_key(dict_key): + def convert_value(value): + if isinstance(value, list): + return tuple(value) + elif isinstance(value, dict): + return tuple(sorted((k, convert_value(v)) for k, v in value.items())) + else: + return value + + return tuple(sorted((k, convert_value(v)) for k, v in dict_key.items())) + + +DIFFUSERS_SCHEDULER_CONFIG = { + "_class_name": "FlowMatchEulerDiscreteScheduler", + "_diffusers_version": "0.32.0.dev0", + "base_image_seq_len": 1024, + "base_shift": 0.95, + "invert_sigmas": False, + "max_image_seq_len": 4096, + "max_shift": 2.05, + "num_train_timesteps": 1000, + "shift": 1.0, + "shift_terminal": 0.1, + "use_beta_sigmas": False, + "use_dynamic_shifting": True, + "use_exponential_sigmas": False, + "use_karras_sigmas": False, +} +DIFFUSERS_TRANSFORMER_CONFIG = { + "_class_name": "LTXVideoTransformer3DModel", + "_diffusers_version": "0.32.0.dev0", + "activation_fn": "gelu-approximate", + "attention_bias": True, + "attention_head_dim": 64, + "attention_out_bias": True, + "caption_channels": 4096, + "cross_attention_dim": 2048, + "in_channels": 128, + "norm_elementwise_affine": False, + "norm_eps": 1e-06, + "num_attention_heads": 32, + "num_layers": 28, + "out_channels": 128, + "patch_size": 1, + "patch_size_t": 1, + "qk_norm": "rms_norm_across_heads", +} +DIFFUSERS_VAE_CONFIG = { + "_class_name": "AutoencoderKLLTXVideo", + "_diffusers_version": "0.32.0.dev0", + 
"block_out_channels": [128, 256, 512, 512], + "decoder_causal": False, + "encoder_causal": True, + "in_channels": 3, + "latent_channels": 128, + "layers_per_block": [4, 3, 3, 3, 4], + "out_channels": 3, + "patch_size": 4, + "patch_size_t": 1, + "resnet_norm_eps": 1e-06, + "scaling_factor": 1.0, + "spatio_temporal_scaling": [True, True, True, False], +} + +OURS_SCHEDULER_CONFIG = { + "_class_name": "RectifiedFlowScheduler", + "_diffusers_version": "0.25.1", + "num_train_timesteps": 1000, + "shifting": "SD3", + "base_resolution": None, + "target_shift_terminal": 0.1, +} + +OURS_TRANSFORMER_CONFIG = { + "_class_name": "Transformer3DModel", + "_diffusers_version": "0.25.1", + "_name_or_path": "PixArt-alpha/PixArt-XL-2-256x256", + "activation_fn": "gelu-approximate", + "attention_bias": True, + "attention_head_dim": 64, + "attention_type": "default", + "caption_channels": 4096, + "cross_attention_dim": 2048, + "double_self_attention": False, + "dropout": 0.0, + "in_channels": 128, + "norm_elementwise_affine": False, + "norm_eps": 1e-06, + "norm_num_groups": 32, + "num_attention_heads": 32, + "num_embeds_ada_norm": 1000, + "num_layers": 28, + "num_vector_embeds": None, + "only_cross_attention": False, + "out_channels": 128, + "project_to_2d_pos": True, + "upcast_attention": False, + "use_linear_projection": False, + "qk_norm": "rms_norm", + "standardization_norm": "rms_norm", + "positional_embedding_type": "rope", + "positional_embedding_theta": 10000.0, + "positional_embedding_max_pos": [20, 2048, 2048], + "timestep_scale_multiplier": 1000, +} +OURS_VAE_CONFIG = { + "_class_name": "CausalVideoAutoencoder", + "dims": 3, + "in_channels": 3, + "out_channels": 3, + "latent_channels": 128, + "blocks": [ + ["res_x", 4], + ["compress_all", 1], + ["res_x_y", 1], + ["res_x", 3], + ["compress_all", 1], + ["res_x_y", 1], + ["res_x", 3], + ["compress_all", 1], + ["res_x", 3], + ["res_x", 4], + ], + "scaling_factor": 1.0, + "norm_layer": "pixel_norm", + "patch_size": 4, + "latent_log_var": "uniform", + "use_quant_conv": False, + "causal_decoder": False, +} + + +diffusers_and_ours_config_mapping = { + make_hashable_key(DIFFUSERS_SCHEDULER_CONFIG): OURS_SCHEDULER_CONFIG, + make_hashable_key(DIFFUSERS_TRANSFORMER_CONFIG): OURS_TRANSFORMER_CONFIG, + make_hashable_key(DIFFUSERS_VAE_CONFIG): OURS_VAE_CONFIG, +} + + +TRANSFORMER_KEYS_RENAME_DICT = { + "proj_in": "patchify_proj", + "time_embed": "adaln_single", + "norm_q": "q_norm", + "norm_k": "k_norm", +} + + +VAE_KEYS_RENAME_DICT = { + "decoder.up_blocks.3.conv_in": "decoder.up_blocks.7", + "decoder.up_blocks.3.upsamplers.0": "decoder.up_blocks.8", + "decoder.up_blocks.3": "decoder.up_blocks.9", + "decoder.up_blocks.2.upsamplers.0": "decoder.up_blocks.5", + "decoder.up_blocks.2.conv_in": "decoder.up_blocks.4", + "decoder.up_blocks.2": "decoder.up_blocks.6", + "decoder.up_blocks.1.upsamplers.0": "decoder.up_blocks.2", + "decoder.up_blocks.1": "decoder.up_blocks.3", + "decoder.up_blocks.0": "decoder.up_blocks.1", + "decoder.mid_block": "decoder.up_blocks.0", + "encoder.down_blocks.3": "encoder.down_blocks.8", + "encoder.down_blocks.2.downsamplers.0": "encoder.down_blocks.7", + "encoder.down_blocks.2": "encoder.down_blocks.6", + "encoder.down_blocks.1.downsamplers.0": "encoder.down_blocks.4", + "encoder.down_blocks.1.conv_out": "encoder.down_blocks.5", + "encoder.down_blocks.1": "encoder.down_blocks.3", + "encoder.down_blocks.0.conv_out": "encoder.down_blocks.2", + "encoder.down_blocks.0.downsamplers.0": "encoder.down_blocks.1", + "encoder.down_blocks.0": 
"encoder.down_blocks.0", + "encoder.mid_block": "encoder.down_blocks.9", + "conv_shortcut.conv": "conv_shortcut", + "resnets": "res_blocks", + "norm3": "norm3.norm", + "latents_mean": "per_channel_statistics.mean-of-means", + "latents_std": "per_channel_statistics.std-of-means", +} diff --git a/ltx_video/utils/prompt_enhance_utils.py b/ltx_video/utils/prompt_enhance_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..9010517282925f8f3d2343829347f309e5c0e41a --- /dev/null +++ b/ltx_video/utils/prompt_enhance_utils.py @@ -0,0 +1,226 @@ +import logging +from typing import Union, List, Optional + +import torch +from PIL import Image + +logger = logging.getLogger(__name__) # pylint: disable=invalid-name + +T2V_CINEMATIC_PROMPT = """You are an expert cinematic director with many award winning movies, When writing prompts based on the user input, focus on detailed, chronological descriptions of actions and scenes. +Include specific movements, appearances, camera angles, and environmental details - all in a single flowing paragraph. +Start directly with the action, and keep descriptions literal and precise. +Think like a cinematographer describing a shot list. +Do not change the user input intent, just enhance it. +Keep within 150 words. +For best results, build your prompts using this structure: +Start with main action in a single sentence +Add specific details about movements and gestures +Describe character/object appearances precisely +Include background and environment details +Specify camera angles and movements +Describe lighting and colors +Note any changes or sudden events +Do not exceed the 150 word limit! +Output the enhanced prompt only. +""" + +I2V_CINEMATIC_PROMPT = """You are an expert cinematic director with many award winning movies, When writing prompts based on the user input, focus on detailed, chronological descriptions of actions and scenes. +Include specific movements, appearances, camera angles, and environmental details - all in a single flowing paragraph. +Start directly with the action, and keep descriptions literal and precise. +Think like a cinematographer describing a shot list. +Keep within 150 words. +For best results, build your prompts using this structure: +Describe the image first and then add the user input. Image description should be in first priority! Align to the image caption if it contradicts the user text input. +Start with main action in a single sentence +Add specific details about movements and gestures +Describe character/object appearances precisely +Include background and environment details +Specify camera angles and movements +Describe lighting and colors +Note any changes or sudden events +Align to the image caption if it contradicts the user text input. +Do not exceed the 150 word limit! +Output the enhanced prompt only. 
+""" + + +def tensor_to_pil(tensor): + # Ensure tensor is in range [-1, 1] + assert tensor.min() >= -1 and tensor.max() <= 1 + + # Convert from [-1, 1] to [0, 1] + tensor = (tensor + 1) / 2 + + # Rearrange from [C, H, W] to [H, W, C] + tensor = tensor.permute(1, 2, 0) + + # Convert to numpy array and then to uint8 range [0, 255] + numpy_image = (tensor.cpu().numpy() * 255).astype("uint8") + + # Convert to PIL Image + return Image.fromarray(numpy_image) + + +def generate_cinematic_prompt( + image_caption_model, + image_caption_processor, + prompt_enhancer_model, + prompt_enhancer_tokenizer, + prompt: Union[str, List[str]], + conditioning_items: Optional[List] = None, + max_new_tokens: int = 256, +) -> List[str]: + prompts = [prompt] if isinstance(prompt, str) else prompt + + if conditioning_items is None: + prompts = _generate_t2v_prompt( + prompt_enhancer_model, + prompt_enhancer_tokenizer, + prompts, + max_new_tokens, + T2V_CINEMATIC_PROMPT, + ) + else: + if len(conditioning_items) > 1 or conditioning_items[0].media_frame_number != 0: + logger.warning( + "prompt enhancement does only support unconditional or first frame of conditioning items, returning original prompts" + ) + return prompts + + first_frame_conditioning_item = conditioning_items[0] + first_frames = _get_first_frames_from_conditioning_item( + first_frame_conditioning_item + ) + + assert len(first_frames) == len( + prompts + ), "Number of conditioning frames must match number of prompts" + + prompts = _generate_i2v_prompt( + image_caption_model, + image_caption_processor, + prompt_enhancer_model, + prompt_enhancer_tokenizer, + prompts, + first_frames, + max_new_tokens, + I2V_CINEMATIC_PROMPT, + ) + + return prompts + + +def _get_first_frames_from_conditioning_item(conditioning_item) -> List[Image.Image]: + frames_tensor = conditioning_item.media_item + return [ + tensor_to_pil(frames_tensor[i, :, 0, :, :]) + for i in range(frames_tensor.shape[0]) + ] + + +def _generate_t2v_prompt( + prompt_enhancer_model, + prompt_enhancer_tokenizer, + prompts: List[str], + max_new_tokens: int, + system_prompt: str, +) -> List[str]: + messages = [ + [ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": f"user_prompt: {p}"}, + ] + for p in prompts + ] + + texts = [ + prompt_enhancer_tokenizer.apply_chat_template( + m, tokenize=False, add_generation_prompt=True + ) + for m in messages + ] + model_inputs = prompt_enhancer_tokenizer(texts, return_tensors="pt").to( + prompt_enhancer_model.device + ) + + return _generate_and_decode_prompts( + prompt_enhancer_model, prompt_enhancer_tokenizer, model_inputs, max_new_tokens + ) + + +def _generate_i2v_prompt( + image_caption_model, + image_caption_processor, + prompt_enhancer_model, + prompt_enhancer_tokenizer, + prompts: List[str], + first_frames: List[Image.Image], + max_new_tokens: int, + system_prompt: str, +) -> List[str]: + image_captions = _generate_image_captions( + image_caption_model, image_caption_processor, first_frames + ) + + messages = [ + [ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": f"user_prompt: {p}\nimage_caption: {c}"}, + ] + for p, c in zip(prompts, image_captions) + ] + + texts = [ + prompt_enhancer_tokenizer.apply_chat_template( + m, tokenize=False, add_generation_prompt=True + ) + for m in messages + ] + model_inputs = prompt_enhancer_tokenizer(texts, return_tensors="pt").to( + prompt_enhancer_model.device + ) + + return _generate_and_decode_prompts( + prompt_enhancer_model, prompt_enhancer_tokenizer, model_inputs, 
max_new_tokens + ) + + +def _generate_image_captions( + image_caption_model, + image_caption_processor, + images: List[Image.Image], + system_prompt: str = "", +) -> List[str]: + image_caption_prompts = [system_prompt] * len(images) + inputs = image_caption_processor( + image_caption_prompts, images, return_tensors="pt" + ).to(image_caption_model.device) + + with torch.inference_mode(): + generated_ids = image_caption_model.generate( + input_ids=inputs["input_ids"], + pixel_values=inputs["pixel_values"], + max_new_tokens=1024, + do_sample=False, + num_beams=3, + ) + + return image_caption_processor.batch_decode(generated_ids, skip_special_tokens=True) + + +def _generate_and_decode_prompts( + prompt_enhancer_model, prompt_enhancer_tokenizer, model_inputs, max_new_tokens: int +) -> List[str]: + with torch.inference_mode(): + outputs = prompt_enhancer_model.generate( + **model_inputs, max_new_tokens=max_new_tokens + ) + generated_ids = [ + output_ids[len(input_ids) :] + for input_ids, output_ids in zip(model_inputs.input_ids, outputs) + ] + decoded_prompts = prompt_enhancer_tokenizer.batch_decode( + generated_ids, skip_special_tokens=True + ) + + return decoded_prompts diff --git a/ltx_video/utils/skip_layer_strategy.py b/ltx_video/utils/skip_layer_strategy.py new file mode 100644 index 0000000000000000000000000000000000000000..30f9016e1cf2abbe62360775e914fa63876e4cf7 --- /dev/null +++ b/ltx_video/utils/skip_layer_strategy.py @@ -0,0 +1,8 @@ +from enum import Enum, auto + + +class SkipLayerStrategy(Enum): + AttentionSkip = auto() + AttentionValues = auto() + Residual = auto() + TransformerBlock = auto() diff --git a/ltx_video/utils/torch_utils.py b/ltx_video/utils/torch_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..991b07c36269ef4dafb88a85834f2596647ba816 --- /dev/null +++ b/ltx_video/utils/torch_utils.py @@ -0,0 +1,25 @@ +import torch +from torch import nn + + +def append_dims(x: torch.Tensor, target_dims: int) -> torch.Tensor: + """Appends dimensions to the end of a tensor until it has target_dims dimensions.""" + dims_to_append = target_dims - x.ndim + if dims_to_append < 0: + raise ValueError( + f"input has {x.ndim} dims but target_dims is {target_dims}, which is less" + ) + elif dims_to_append == 0: + return x + return x[(...,) + (None,) * dims_to_append] + + +class Identity(nn.Module): + """A placeholder identity operator that is argument-insensitive.""" + + def __init__(self, *args, **kwargs) -> None: # pylint: disable=unused-argument + super().__init__() + + # pylint: disable=unused-argument + def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor: + return x diff --git a/prompts/director_motion_prompt.txt b/prompts/director_motion_prompt.txt new file mode 100644 index 0000000000000000000000000000000000000000..c5cc539b8a161731ef938a6b02983092396cdc04 --- /dev/null +++ b/prompts/director_motion_prompt.txt @@ -0,0 +1,33 @@ +You are a world-class Animation Director and VFX Supervisor. Your task is to look at a sequence of static keyframe images and a general story idea, and then write the specific ANIMATION COMMANDS that will bring these images to life, creating fluid transitions between them. + +GOLDEN RULES: +1. **CONNECT THE DOTS:** Your main job is to describe the MOVEMENT that happens *between* each pair of keyframe images. Each command corresponds to a 3-second video clip that starts at one keyframe and ends at the next. +2. 
**FOCUS ON DYNAMICS:** Describe camera movements (zoom, pan, dolly), character actions, environmental changes (wind, water), and lighting shifts that create a seamless and engaging transition. +3. **RESPECT THE ANCHORS:** The generated video will start exactly like the first image in a pair and end exactly like the second. Your description should logically connect these two visual states. +4. **FOLLOW THE STORY:** Your animation commands must be consistent with the user's general idea. + +CONTEXT INPUT: +- General Idea (Story Theme): "{user_prompt}" +- Sequence of Keyframe Images (Visual Anchors): (attached) + +YOUR TASK: +Based on the sequence of {num_fragments} images, create a storyboard of `{num_fragments}` ANIMATION COMMANDS. + +- **Command for Image 1 to 2:** Describe the animation that transitions from Keyframe 1 to Keyframe 2. +- **Command for Image 2 to 3:** Describe the animation that transitions from Keyframe 2 to Keyframe 3. +- ... +- **Final Command (from last Keyframe):** For the very last image, describe a "free animation" that continues the motion from that final static pose. + +RESPONSE FORMAT: +Return a single JSON object with the key `"motion_storyboard"`, containing an array of strings (the animation commands). + +EXAMPLE: +- General Idea: "A character puts on a hoodie and smiles." +- Keyframes: [Image of character looking neutral, Image of character with hoodie on and smiling] +- Expected Response: +{{ + "motion_storyboard": [ + "The camera holds steady as the character smoothly raises their arms, pulling the hoodie over their head. As the hood settles, a gentle smile begins to form, and their eyes crinkle slightly.", + "Holding the final smiling pose, a gentle breeze rustles the fabric of the hoodie, and the background subtly shifts out of focus, bringing all attention to their happy expression." + ] +}} \ No newline at end of file diff --git a/prompts/director_sequential_prompt.txt b/prompts/director_sequential_prompt.txt new file mode 100644 index 0000000000000000000000000000000000000000..62314df149a61a8edb835174b1c00e6f031bb6bb --- /dev/null +++ b/prompts/director_sequential_prompt.txt @@ -0,0 +1,34 @@ +You are a pragmatic Director of Photography and a Physicist of Motion. Your task is to analyze a START frame and an END frame and write a single, direct animation command that describes the logical transition between them. + +GOLDEN RULES (CRITICAL): + +1. **ACTION AND REACTION:** Your command must describe a linear cause-and-effect motion. The START frame is the "cause". The END frame is the "effect". Your prompt must describe the physical action that connects them. + * **GOOD:** START: "in the air" -> END: "hitting the water". Your prompt describes the *descent and splash*. + * **BAD:** START: "jumping" -> END: "jumping higher". This is not a reaction, it's a repetitive action. Avoid this. + +2. **FOCUS ON CAMERA AND VISIBLE MOTION:** Describe the physical movement of the SUBJECT and the CAMERA. Be literal. + +3. **USE TECHNICAL LANGUAGE:** Use clear, cinematic keywords: "slow pan left," "camera zooms in," "wide tracking shot," "galloping," "flying," "swimming," "leaping." + +4. **BE DIRECT:** Start the sentence with the subject and its main action. Describe what is seen, not what is felt. + +5. **CONCISENESS (CRITICAL):** Keep your command under 70 tokens (approximately 50 words). 
+ +CONTEXT INPUT: +- General Story Idea: "{user_prompt}" +- START Frame (The "Cause"): (attached) +- END Frame (The "Effect"): (attached) + +YOUR TASK: +Write a single, direct, and technical animation prompt based on the principle of action and reaction that connects the START frame to the END frame. + +RESPONSE FORMAT: +Return a single JSON object with the key `"motion_prompt"`. + +EXAMPLE: +- START Frame: A salmon at the peak of its leap. +- END Frame: The salmon splashing back into the river. +- Expected Response: +{{ + "motion_prompt": "Following its arc, the salmon descends powerfully towards the river, hitting the surface in a dramatic slow-motion splash. The camera tracks the descent closely." +}} \ No newline at end of file diff --git a/prompts/photographer_prompt.txt b/prompts/photographer_prompt.txt new file mode 100644 index 0000000000000000000000000000000000000000..269776f398c0794f9bb08ace37abc01a1e35f090 --- /dev/null +++ b/prompts/photographer_prompt.txt @@ -0,0 +1,34 @@ +You are a creative Photographer with a powerful computational imagination. Your task is to analyze a reference image and a story idea, and then describe a sequence of distinct, static photographs that tell this story. + +GOLDEN RULES: +1. **STATIC SCENES ONLY:** Describe each image as a completely still photograph. DO NOT describe movement, actions over time, or camera motion. +2. **VISUAL STORYTELLING:** Each photo description should logically follow the previous one, creating a clear narrative sequence through different compositions, character expressions, and settings. +3. **MAINTAIN CONSISTENCY:** The visual style (e.g., character appearance, clothing, overall mood) should remain consistent with the reference image, unless the story idea explicitly calls for a change. +4. **FOCUS ON DETAILS:** Describe composition, lighting, character pose, facial expression, and background elements for each individual photograph. + +CONTEXT INPUT: +- General Idea (Story): "{user_prompt}" +- Number of Photos (Fragments): {num_fragments} +- Reference Image (Visual Style Guide): (attached) + +YOUR TASK: +Create a "photo album" of `{num_fragments}` detailed descriptions. + +- **Photo 1 Description:** Describe a static scene that sets up the story, inspired by the reference image. +- **Subsequent Photo Descriptions:** Describe the next static scenes that continue the narrative. + +RESPONSE FORMAT: +Return a single JSON object with the key `"scene_storyboard"`, containing an array of strings (each representing one photo description). + + +EXAMPLE: +- General Idea: "A woman explores a forest and finds a glowing flower." +- Number of Photos: 3 +- Expected Response: +{{ + "scene_storyboard": [ + "A full-body shot of the woman standing at the edge of a dense, misty forest, looking inward with a curious expression. The lighting is soft and diffused.", + "A medium shot of the woman deeper in the woods, kneeling down to look at something just out of frame. Her face is illuminated by a soft, warm light from below.", + "A close-up shot of the woman's hands gently cupping a single, brilliantly glowing flower. Her face, seen just above, is filled with awe and wonder." 
+    ]
+}}
\ No newline at end of file
diff --git a/prompts/photographer_sequential_prompt.txt b/prompts/photographer_sequential_prompt.txt
new file mode 100644
index 0000000000000000000000000000000000000000..e2f84161ae9e5b10e6ff8ba738cc3c97af52085c
--- /dev/null
+++ b/prompts/photographer_sequential_prompt.txt
@@ -0,0 +1,33 @@
+**MISSION: DIRECTOR OF PHOTOGRAPHY WITH CONTEXTUAL MEMORY**
+
+You are a creative director of photography. Your task is to analyze the narrative history, the general idea, and the image of the previous scene in order to create the prompt for the **next keyframe**.
+
+Your decision must ensure that the new scene is a **logical and creative evolution** of the story, staying coherent with everything that has happened before.
+
+**1. GENERAL STORY IDEA (from the User):**
+{user_prompt}
+
+**2. HISTORY OF PREVIOUS SCENE PROMPTS (Your Memory):**
+{prompt_history}
+
+**3. REFERENCE IMAGE (The end of the last generated scene):**
+[The image will be provided here]
+
+**YOUR TASK:**
+
+1. **Analyze the History:** Read the prompt history to understand the trajectory of the narrative so far.
+2. **Analyze the Last Image:** Observe the action and composition of the last keyframe.
+3. **Continue the Story:** Based on the "General Idea" and the "History", decide on the next logical and visual step of the narrative. What would happen next?
+4. **Create the Prompt for the NEXT IMAGE:** Write a clear, descriptive prompt for an image generator to create the **final keyframe** of this new scene.
+
+**CRITICAL RULES FOR THE PROMPT:**
+* **EVOLUTION AND COHERENCE:** The prompt must move the story forward, avoiding repetition while staying consistent with the history.
+* **VISUAL DESCRIPTION:** Describe the composition, lighting, and main action of the new scene.
+* **BE CONCISE:** Keep the prompt at 40-80 tokens.
+
+YOUR TASK:
+Write the prompt for the very next static scene in the story.
+
+RESPONSE FORMAT:
+Return a single JSON object with the key `"next_scene_prompt"`.
+
diff --git a/requirements.txt b/requirements.txt
index f6f6b7762ae5f051ef8d05b8887bd27d2c321d16..b69e9d4e1f7412830d84392b61116c2d090a0b35 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,12 +1,25 @@
-torch
+#accelerate
+gradio
+google-generativeai
+Pillow
+ffmpeg-python
+sentence_transformers
+#transformers
+sentencepiece
+numpy
 torchvision
+huggingface_hub>=0.20.0
+spaces
+opencv-python
+imageio
+imageio-ffmpeg
 einops
 timm
+av
+#git+https://github.com/huggingface/diffusers.git@main
+torch
+peft
 diffusers==0.31.0
 transformers==4.45.2
-sentencepiece
-spaces
-huggingface_hub
 accelerate==0.32.0
-peft
 git+https://github.com/ToTheBeginning/facexlib.git
\ No newline at end of file
diff --git a/workspaces/tmp b/workspaces/tmp
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391