from lib.config import cfg, args
from lib.networks import make_network
from lib.train import make_trainer, make_optimizer, make_lr_scheduler, make_recorder, set_lr_scheduler
from lib.datasets import make_data_loader
from lib.utils.net_utils import load_model, save_model, load_network
from lib.evaluators import make_evaluator
import torch.multiprocessing
import torch
import torch.distributed as dist
import os
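
# Optionally fix the random seed and force deterministic cuDNN kernels
# (disabling autotuning) so runs are reproducible at the cost of some speed.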
if cfg.fix_random:
    torch.manual_seed(0)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
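

# Full training loop: build the trainer, optimizer, LR scheduler, recorder and
# evaluator, optionally resume from a checkpoint, then alternate training
# epochs with periodic checkpointing and validation.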
def train(cfg, network):
    trainer = make_trainer(cfg, network)
    optimizer = make_optimizer(cfg, network)
    scheduler = make_lr_scheduler(cfg, optimizer)
    recorder = make_recorder(cfg)
    evaluator = make_evaluator(cfg)
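
    # Restore training state (weights, optimizer, scheduler, recorder) from
    # cfg.trained_model_dir when cfg.resume is set; returns the epoch to resume from.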
    begin_epoch = load_model(network,
                             optimizer,
                             scheduler,
                             recorder,
                             cfg.trained_model_dir,
                             resume=cfg.resume)
    set_lr_scheduler(cfg, scheduler)
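
    # Data loaders: the training loader caps each epoch at cfg.ep_iter iterations
    # and is sharded across processes when distributed training is enabled.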
    train_loader = make_data_loader(cfg,
                                    is_train=True,
                                    is_distributed=cfg.distributed,
                                    max_iter=cfg.ep_iter)
    val_loader = make_data_loader(cfg, is_train=False)

    for epoch in range(begin_epoch, cfg.train.epoch):
        recorder.epoch = epoch
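        # Reseed the distributed sampler each epoch so every process sees a
        # freshly shuffled shard of the dataset.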
        if cfg.distributed:
            train_loader.batch_sampler.sampler.set_epoch(epoch)

        trainer.train(epoch, train_loader, optimizer, recorder)
        scheduler.step()
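
        # Periodic checkpoints; only rank 0 writes to disk so distributed
        # processes do not race on the same files.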
        if (epoch + 1) % cfg.save_ep == 0 and cfg.local_rank == 0:
            save_model(network, optimizer, scheduler, recorder,
                       cfg.trained_model_dir, epoch)

        if (epoch + 1) % cfg.save_latest_ep == 0 and cfg.local_rank == 0:
            save_model(network,
                       optimizer,
                       scheduler,
                       recorder,
                       cfg.trained_model_dir,
                       epoch,
                       last=True)
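
        # Periodic validation on the held-out loader, logged through the recorder.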
        if (epoch + 1) % cfg.eval_ep == 0:
            trainer.val(epoch, val_loader, evaluator, recorder)

    return network
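

# Evaluation-only path: load a trained checkpoint and run a single validation pass.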
def test(cfg, network):
    trainer = make_trainer(cfg, network)
    val_loader = make_data_loader(cfg, is_train=False)
    evaluator = make_evaluator(cfg)
    epoch = load_network(network,
                         cfg.trained_model_dir,
                         resume=cfg.resume,
                         epoch=cfg.test.epoch)
    trainer.val(epoch, val_loader, evaluator)


def synchronize():
    """
    Helper function to synchronize (barrier) among all processes when
    using distributed training
    """
    if not dist.is_available():
        return
    if not dist.is_initialized():
        return
    world_size = dist.get_world_size()
    if world_size == 1:
        return
    dist.barrier()
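

# Entry point: in distributed mode, bind each process to one GPU and join the
# NCCL process group before building the network and dispatching to train/test.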
def main():
    if cfg.distributed:
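        # RANK is set by the launcher (e.g. torch.distributed.launch / torchrun);
        # map the global rank to a local GPU index on this node.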
        cfg.local_rank = int(os.environ['RANK']) % torch.cuda.device_count()
        torch.cuda.set_device(cfg.local_rank)
        torch.distributed.init_process_group(backend="nccl",
                                             init_method="env://")
        synchronize()

    network = make_network(cfg)
    if args.test:
        test(cfg, network)
    else:
        train(cfg, network)


if __name__ == "__main__":
    main()
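
# Typical invocations (a sketch; the exact CLI flags are defined in lib.config
# and are assumptions here, not taken from this file):
#   python train_net.py --cfg_file configs/example.yaml
#   python -m torch.distributed.launch --nproc_per_node=4 train_net.py \
#       --cfg_file configs/example.yaml distributed True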