Spaces: Running on Zero
# Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT
# except for the third-party components listed below.
# Hunyuan 3D does not impose any additional limitations beyond what is outlined
# in the respective licenses of these third-party components.
# Users must comply with all terms and conditions of original licenses of these third-party
# components and must ensure that the usage of the third party components adheres to
# all relevant laws and regulations.
# For avoidance of doubts, Hunyuan 3D means the large language models and
# their software and algorithms, including trained model weights, parameters (including
# optimizer states), machine-learning model code, inference-enabling code, training-enabling code,
# fine-tuning enabling code and other elements of the foregoing made publicly available
# by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.
| import torch | |
| import os, sys | |
| import argparse | |
| import shutil | |
| import subprocess | |
| from omegaconf import OmegaConf | |
| from pytorch_lightning import seed_everything | |
| from pytorch_lightning.trainer import Trainer | |
| from pytorch_lightning.strategies import DDPStrategy | |
| from pytorch_lightning.callbacks import Callback | |
| from pytorch_lightning.utilities import rank_zero_only, rank_zero_warn | |
| from src.utils.train_util import instantiate_from_config | |
| import warnings | |
| warnings.filterwarnings("ignore") | |
| from diffusers.utils import logging as diffusers_logging | |
| diffusers_logging.set_verbosity(50) | |
| def rank_zero_print(*args): | |
| print(*args) | |
| def get_parser(**parser_kwargs): | |
| def str2bool(v): | |
| if isinstance(v, bool): | |
| return v | |
| if v.lower() in ("yes", "true", "t", "y", "1"): | |
| return True | |
| elif v.lower() in ("no", "false", "f", "n", "0"): | |
| return False | |
| else: | |
| raise argparse.ArgumentTypeError("Boolean value expected.") | |
| parser = argparse.ArgumentParser(**parser_kwargs) | |
| parser.add_argument( | |
| "-r", | |
| "--resume", | |
| type=str, | |
| default=None, | |
| help="resume from checkpoint", | |
| ) | |
| parser.add_argument( | |
| "--resume_weights_only", | |
| action="store_true", | |
| help="only resume model weights", | |
| ) | |
| parser.add_argument( | |
| "-b", | |
| "--base", | |
| type=str, | |
| default="base_config.yaml", | |
| help="path to base configs", | |
| ) | |
| parser.add_argument( | |
| "-n", | |
| "--name", | |
| type=str, | |
| default="", | |
| help="experiment name", | |
| ) | |
| parser.add_argument( | |
| "--num_nodes", | |
| type=int, | |
| default=1, | |
| help="number of nodes to use", | |
| ) | |
| parser.add_argument( | |
| "--gpus", | |
| type=str, | |
| default="0,", | |
| help="gpu ids to use", | |
| ) | |
| parser.add_argument( | |
| "-s", | |
| "--seed", | |
| type=int, | |
| default=42, | |
| help="seed for seed_everything", | |
| ) | |
| parser.add_argument( | |
| "-l", | |
| "--logdir", | |
| type=str, | |
| default="logs", | |
| help="directory for logging data", | |
| ) | |
| return parser | |
| class SetupCallback(Callback): | |
| def __init__(self, resume, logdir, ckptdir, cfgdir, config): | |
| super().__init__() | |
| self.resume = resume | |
| self.logdir = logdir | |
| self.ckptdir = ckptdir | |
| self.cfgdir = cfgdir | |
| self.config = config | |
| def on_fit_start(self, trainer, pl_module): | |
| if trainer.global_rank == 0: | |
| # Create logdirs and save configs | |
| os.makedirs(self.logdir, exist_ok=True) | |
| os.makedirs(self.ckptdir, exist_ok=True) | |
| os.makedirs(self.cfgdir, exist_ok=True) | |
| rank_zero_print("Project config") | |
| rank_zero_print(OmegaConf.to_yaml(self.config)) | |
| OmegaConf.save(self.config, os.path.join(self.cfgdir, "project.yaml")) | |
| class CodeSnapshot(Callback): | |
| """ | |
| Modified from https://github.com/threestudio-project/threestudio/blob/main/threestudio/utils/callbacks.py#L60 | |
| """ | |
| def __init__(self, savedir): | |
| self.savedir = savedir | |
| def get_file_list(self): | |
| return [ | |
| b.decode() | |
| for b in set(subprocess.check_output('git ls-files -- ":!:configs/*"', shell=True).splitlines()) | |
| | set( # hard code, TODO: use config to exclude folders or files | |
| subprocess.check_output("git ls-files --others --exclude-standard", shell=True).splitlines() | |
| ) | |
| ] | |
| def save_code_snapshot(self): | |
| os.makedirs(self.savedir, exist_ok=True) | |
| # for f in self.get_file_list(): | |
| # if not os.path.exists(f) or os.path.isdir(f): | |
| # continue | |
| # os.makedirs(os.path.join(self.savedir, os.path.dirname(f)), exist_ok=True) | |
| # shutil.copyfile(f, os.path.join(self.savedir, f)) | |
| def on_fit_start(self, trainer, pl_module): | |
| try: | |
| self.save_code_snapshot() | |
| except: | |
| rank_zero_warn( | |
| "Code snapshot is not saved. Please make sure you have git installed and are in a git repository." | |
| ) | |
| if __name__ == "__main__": | |
| # add cwd for convenience and to make classes in this file available when | |
| # running as `python main.py` | |
| sys.path.append(os.getcwd()) | |
| torch.set_float32_matmul_precision("medium") | |
| parser = get_parser() | |
| opt, unknown = parser.parse_known_args() | |
| cfg_fname = os.path.split(opt.base)[-1] | |
| cfg_name = os.path.splitext(cfg_fname)[0] | |
| exp_name = "-" + opt.name if opt.name != "" else "" | |
| logdir = os.path.join(opt.logdir, cfg_name + exp_name) | |
| # assert not os.path.exists(logdir) or 'test' in logdir, logdir | |
| if os.path.exists(logdir) and opt.resume is None: | |
| auto_resume_path = os.path.join(logdir, "checkpoints", "last.ckpt") | |
| if os.path.exists(auto_resume_path): | |
| opt.resume = auto_resume_path | |
| print(f"Auto set resume ckpt {opt.resume}") | |
| ckptdir = os.path.join(logdir, "checkpoints") | |
| cfgdir = os.path.join(logdir, "configs") | |
| codedir = os.path.join(logdir, "code") | |
| node_rank = int(os.environ.get("NODE_RANK", 0)) # 当前节点的编号 | |
| local_rank = int(os.environ.get("LOCAL_RANK", 0)) # 当前节点上的 GPU 编号 | |
| num_gpus_per_node = torch.cuda.device_count() # 每个节点上的 GPU 数量 | |
| global_rank = node_rank * num_gpus_per_node + local_rank | |
| seed_everything(opt.seed + global_rank) | |
| # init configs | |
| config = OmegaConf.load(opt.base) | |
| lightning_config = config.lightning | |
| trainer_config = lightning_config.trainer | |
| trainer_config["accelerator"] = "gpu" | |
| rank_zero_print(f"Running on GPUs {opt.gpus}") | |
| try: | |
| ngpu = int(opt.gpus) | |
| except: | |
| ngpu = len(opt.gpus.strip(",").split(",")) | |
| trainer_config["devices"] = ngpu | |
| trainer_opt = argparse.Namespace(**trainer_config) | |
| lightning_config.trainer = trainer_config | |
| # model | |
| model = instantiate_from_config(config.model) | |
| model_unet = model.unet.unet | |
| model_unet_prefix = "unet.unet." | |
| if hasattr(model_unet, "unet"): | |
| model_unet = model_unet.unet | |
| model_unet_prefix += "unet." | |
| if getattr(config, "init_unet_from", None): | |
| unet_ckpt_path = config.init_unet_from | |
| sd = torch.load(unet_ckpt_path, map_location="cpu") | |
| model_unet.load_state_dict(sd, strict=True) | |
| if getattr(config, "init_vae_from", None): | |
| vae_ckpt_path = config.init_vae_from | |
| sd_vae = torch.load(vae_ckpt_path, map_location="cpu") | |
| def replace_key(key_str): | |
| replace_pairs = [("key", "to_k"), ("query", "to_q"), ("value", "to_v"), ("proj_attn", "to_out.0")] | |
| for replace_pair in replace_pairs: | |
| key_str = key_str.replace(replace_pair[0], replace_pair[1]) | |
| return key_str | |
| sd_vae = {replace_key(k): v for k, v in sd_vae.items()} | |
| model.pipeline.vae.load_state_dict(sd_vae, strict=True) | |
| if hasattr(model.unet, "controlnet"): | |
| if getattr(config, "init_control_from", None): | |
| unet_ckpt_path = config.init_control_from | |
| sd_control = torch.load(unet_ckpt_path, map_location="cpu") | |
| model.unet.controlnet.load(sd_control, strict=True) | |
| noise_in_channels = config.model.params.get("noise_in_channels", None) | |
| if noise_in_channels is not None: | |
| with torch.no_grad(): | |
| new_conv_in = torch.nn.Conv2d( | |
| noise_in_channels, | |
| model_unet.conv_in.out_channels, | |
| model_unet.conv_in.kernel_size, | |
| model_unet.conv_in.stride, | |
| model_unet.conv_in.padding, | |
| ) | |
| new_conv_in.weight.zero_() | |
| new_conv_in.weight[:, : model_unet.conv_in.in_channels, :, :].copy_(model_unet.conv_in.weight) | |
| new_conv_in.bias.zero_() | |
| new_conv_in.bias[: model_unet.conv_in.bias.size(0)].copy_(model_unet.conv_in.bias) | |
| model_unet.conv_in = new_conv_in | |
| if hasattr(model.unet, "controlnet"): | |
| if config.model.params.get("control_in_channels", None): | |
| control_in_channels = config.model.params.control_in_channels | |
| model.unet.controlnet.config["conditioning_channels"] = control_in_channels | |
| condition_conv_in = model.unet.controlnet.controlnet_cond_embedding.conv_in | |
| new_condition_conv_in = torch.nn.Conv2d( | |
| control_in_channels, | |
| condition_conv_in.out_channels, | |
| kernel_size=condition_conv_in.kernel_size, | |
| stride=condition_conv_in.stride, | |
| padding=condition_conv_in.padding, | |
| ) | |
| with torch.no_grad(): | |
| new_condition_conv_in.weight[:, : condition_conv_in.in_channels, :, :] = condition_conv_in.weight | |
| if condition_conv_in.bias is not None: | |
| new_condition_conv_in.bias = condition_conv_in.bias | |
| model.unet.controlnet.controlnet_cond_embedding.conv_in = new_condition_conv_in | |
| rank_zero_print(f"Loaded Init ...") | |
| if getattr(config, "resume_from", None): | |
| cnet_ckpt_path = config.resume_from | |
| sds = torch.load(cnet_ckpt_path, map_location="cpu")["state_dict"] | |
| sd0 = {k[len(model_unet_prefix) :]: v for k, v in sds.items() if model_unet_prefix in k} | |
| # model.unet.unet.unet.load_state_dict(sd0, strict=True) | |
| model_unet.load_state_dict(sd0, strict=True) | |
| if hasattr(model.unet, "controlnet"): | |
| sd1 = {k[16:]: v for k, v in sds.items() if "unet.controlnet." in k} | |
| model.unet.controlnet.load_state_dict(sd1, strict=True) | |
| rank_zero_print(f"Loaded {cnet_ckpt_path} ...") | |
| if opt.resume and opt.resume_weights_only: | |
| model = model.__class__.load_from_checkpoint(opt.resume, **config.model.params) | |
| model.logdir = logdir | |
| # trainer and callbacks | |
| trainer_kwargs = dict() | |
| # logger | |
| default_logger_cfg = { | |
| "target": "pytorch_lightning.loggers.TensorBoardLogger", | |
| "params": { | |
| "name": "tensorboard", | |
| "save_dir": logdir, | |
| "version": "0", | |
| }, | |
| } | |
| logger_cfg = OmegaConf.merge(default_logger_cfg) | |
| trainer_kwargs["logger"] = instantiate_from_config(logger_cfg) | |
| # model checkpoint | |
| default_modelckpt_cfg = { | |
| "target": "pytorch_lightning.callbacks.ModelCheckpoint", | |
| "params": { | |
| "dirpath": ckptdir, | |
| "filename": "{step:08}", | |
| "verbose": True, | |
| "save_last": True, | |
| "every_n_train_steps": 5000, | |
| "save_top_k": -1, # save all checkpoints | |
| }, | |
| } | |
| if "modelcheckpoint" in lightning_config: | |
| modelckpt_cfg = lightning_config.modelcheckpoint | |
| else: | |
| modelckpt_cfg = OmegaConf.create() | |
| modelckpt_cfg = OmegaConf.merge(default_modelckpt_cfg, modelckpt_cfg) | |
| # callbacks | |
| default_callbacks_cfg = { | |
| "setup_callback": { | |
| "target": "train.SetupCallback", | |
| "params": { | |
| "resume": opt.resume, | |
| "logdir": logdir, | |
| "ckptdir": ckptdir, | |
| "cfgdir": cfgdir, | |
| "config": config, | |
| }, | |
| }, | |
| "learning_rate_logger": { | |
| "target": "pytorch_lightning.callbacks.LearningRateMonitor", | |
| "params": { | |
| "logging_interval": "step", | |
| }, | |
| }, | |
| "code_snapshot": { | |
| "target": "train.CodeSnapshot", | |
| "params": { | |
| "savedir": codedir, | |
| }, | |
| }, | |
| } | |
| default_callbacks_cfg["checkpoint_callback"] = modelckpt_cfg | |
| if "callbacks" in lightning_config: | |
| callbacks_cfg = lightning_config.callbacks | |
| else: | |
| callbacks_cfg = OmegaConf.create() | |
| callbacks_cfg = OmegaConf.merge(default_callbacks_cfg, callbacks_cfg) | |
| trainer_kwargs["callbacks"] = [instantiate_from_config(callbacks_cfg[k]) for k in callbacks_cfg] | |
| trainer_kwargs["precision"] = "bf16" | |
| trainer_kwargs["strategy"] = DDPStrategy(find_unused_parameters=False) | |
| # trainer | |
| trainer = Trainer(**trainer_config, **trainer_kwargs, num_nodes=opt.num_nodes, inference_mode=False) | |
| trainer.logdir = logdir | |
| # data | |
| data = instantiate_from_config(config.data) | |
| data.prepare_data() | |
| data.setup("fit") | |
| # configure learning rate | |
| base_lr = config.model.base_learning_rate | |
| if "accumulate_grad_batches" in lightning_config.trainer: | |
| accumulate_grad_batches = lightning_config.trainer.accumulate_grad_batches | |
| else: | |
| accumulate_grad_batches = 1 | |
| rank_zero_print(f"accumulate_grad_batches = {accumulate_grad_batches}") | |
| lightning_config.trainer.accumulate_grad_batches = accumulate_grad_batches | |
| model.learning_rate = base_lr | |
| rank_zero_print("++++ NOT USING LR SCALING ++++") | |
| rank_zero_print(f"Setting learning rate to {model.learning_rate:.2e}") | |
| # run training loop | |
| if opt.resume and not opt.resume_weights_only: | |
| trainer.fit(model, data, ckpt_path=opt.resume) | |
| else: | |
| trainer.fit(model, data) | |