import argparse
import math
import os
import random
import time

import numpy as np
import torch
import wandb

from data.dataset import TextDataset, CollectionTextDataset
from models.model import VATr
from util.misc import EpochLossTracker, add_vatr_args, LinearScheduler


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--resume", action='store_true')
    parser = add_vatr_args(parser)
    args = parser.parse_args()

    rSeed(args.seed)
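    # Training and validation views of the same collection; min_virtual_size presumably
    # sets a floor on the reported dataset length so small sets still fill an epoch.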
    dataset = CollectionTextDataset(
        args.dataset, 'files', TextDataset, file_suffix=args.file_suffix, num_examples=args.num_examples,
        collator_resolution=args.resolution, min_virtual_size=339, validation=False, debug=False,
        height=args.img_height
    )
    datasetval = CollectionTextDataset(
        args.dataset, 'files', TextDataset, file_suffix=args.file_suffix, num_examples=args.num_examples,
        collator_resolution=args.resolution, min_virtual_size=161, validation=True, height=args.img_height
    )
    args.num_writers = dataset.num_writers
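    # IAM and CVL use the fixed alphabet below; other datasets derive theirs from
    # the characters that actually occur in the training and validation data.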
    if args.dataset in ('IAM', 'CVL'):
        args.alphabet = 'Only thewigsofrcvdampbkuq.A-210xT5\'MDL,RYHJ"ISPWENj&BC93VGFKz();#:!7U64Q8?+*ZX/%'
    else:
        args.alphabet = ''.join(sorted(set(dataset.alphabet + datasetval.alphabet)))
    args.special_alphabet = ''.join(c for c in args.special_alphabet if c not in dataset.alphabet)
| args.exp_name = f"{args.dataset}-{args.num_writers}-{args.num_examples}-LR{args.g_lr}-bs{args.batch_size}-{args.tag}" | |
| config = {k: v for k, v in args.__dict__.items() if isinstance(v, (bool, int, str, float))} | |
| args.wandb = args.wandb and (not torch.cuda.is_available() or torch.cuda.get_device_name(0) != 'Tesla K80') | |
| wandb_id = wandb.util.generate_id() | |
| MODEL_PATH = os.path.join(args.save_model_path, args.exp_name) | |
| os.makedirs(MODEL_PATH, exist_ok=True) | |
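    # Both loaders drop incomplete batches and use the dataset-specific collate function.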
    train_loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.num_workers,
        pin_memory=True, drop_last=True,
        collate_fn=dataset.collate_fn)
    val_loader = torch.utils.data.DataLoader(
        datasetval,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.num_workers,
        pin_memory=True, drop_last=True,
        collate_fn=datasetval.collate_fn)
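    # VATr bundles the generator (netG) and the discriminators optimised below.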
    model = VATr(args)
    start_epoch = 0

    del config['alphabet']
    del config['special_alphabet']
    wandb_params = {
        'project': 'VATr',
        'config': config,
        'name': args.exp_name,
        'id': wandb_id
    }

    checkpoint_path = os.path.join(MODEL_PATH, 'model.pth')
    loss_tracker = EpochLossTracker()
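    # Resume from the last checkpoint if requested; otherwise optionally warm-start
    # the feature encoder from a pretrained ResNet-18 checkpoint.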
    if args.resume and os.path.exists(checkpoint_path):
        checkpoint = torch.load(checkpoint_path, map_location=args.device)
        model.load_state_dict(checkpoint['model'])
        start_epoch = checkpoint['epoch']
        wandb_params['id'] = checkpoint['wandb_id']
        wandb_params['resume'] = True
        print(f'{checkpoint_path}: model loaded successfully')
    elif args.resume:
        raise FileNotFoundError(f'No model found at {checkpoint_path}')
    else:
        if args.feat_model_path is not None and args.feat_model_path.lower() != 'none':
            print('Loading...', args.feat_model_path)
            assert os.path.exists(args.feat_model_path)
            checkpoint = torch.load(args.feat_model_path, map_location=args.device)
            # Average the RGB filters of the first conv layer down to one channel
            # (the inputs are grayscale) and drop the classification head.
            checkpoint['model']['conv1.weight'] = checkpoint['model']['conv1.weight'].mean(1).unsqueeze(1)
            del checkpoint['model']['fc.weight']
            del checkpoint['model']['fc.bias']
            miss, unexp = model.netG.Feat_Encoder.load_state_dict(checkpoint['model'], strict=False)
            if not os.path.isdir(MODEL_PATH):
                os.mkdir(MODEL_PATH)
        else:
            print('WARNING: no ResNet-18 checkpoint to load, starting from scratch')
    if args.wandb:
        wandb.init(**wandb_params)
        wandb.watch(model)

    print("Starting training")
    for epoch in range(start_epoch, args.epochs):
        start_time = time.time()
        log_time = time.time()
        loss_tracker.reset()
        model.d_acc.update(0.0)

        if args.text_augment_strength > 0:
            model.set_text_aug_strength(args.text_augment_strength)
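        # Per batch: alternate generator updates with the OCR-related (D_OCR) and
        # writer-related (D_WL) discriminator updates.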
        for i, data in enumerate(train_loader):
            model.update_parameters(epoch)
            model._set_input(data)

            model.optimize_G_only()
            model.optimize_G_step()

            model.optimize_D_OCR()
            model.optimize_D_OCR_step()

            model.optimize_G_WL()
            model.optimize_G_step()

            model.optimize_D_WL()
            model.optimize_D_WL_step()

            # Print a progress line at most once every 10 seconds.
            if time.time() - log_time > 10:
                print(f'Epoch {epoch} {i / len(train_loader) * 100:.02f}% done, elapsed: {time.time() - start_time:.2f} s')
                log_time = time.time()

            batch_losses = model.get_current_losses()
            batch_losses['d_acc'] = model.d_acc.avg
            loss_tracker.add_batch(batch_losses)
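        # End of epoch: aggregate losses, render sample pages from a training and a
        # validation batch, and collect discriminator score statistics.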
        end_time = time.time()
        data_val = next(iter(val_loader))
        losses = loss_tracker.get_epoch_loss()

        page = model._generate_page(model.sdata, model.input['swids'])
        page_val = model._generate_page(data_val['simg'].to(args.device), data_val['swids'])
        d_train, d_val, d_fake = model.compute_d_stats(train_loader, val_loader)
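        # Log epoch losses, discriminator statistics and the generated pages to wandb.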
        if args.wandb:
            wandb.log({
                'loss-G': losses['G'],
                'loss-D': losses['D'],
                'loss-Dfake': losses['Dfake'],
                'loss-Dreal': losses['Dreal'],
                'loss-OCR_fake': losses['OCR_fake'],
                'loss-OCR_real': losses['OCR_real'],
                'loss-w_fake': losses['w_fake'],
                'loss-w_real': losses['w_real'],
                'd_acc': losses['d_acc'],
                'd-rv': (d_train - d_val) / (d_train - d_fake),
                'd-fake': d_fake,
                'd-real': d_train,
                'd-val': d_val,
                'l_cycle': losses['cycle'],
                'epoch': epoch,
                'timeperepoch': end_time - start_time,
                'result': [wandb.Image(page, caption="page"), wandb.Image(page_val, caption="page_val")],
                'd-crop-size': model.netD.augmenter.get_current_width() if model.netD.crop else 0
            })
        print({'EPOCH': epoch, 'TIME': end_time - start_time, 'LOSSES': losses})
        print(f"Text sample: {model.get_text_sample(10)}")
        checkpoint = {
            'model': model.state_dict(),
            'wandb_id': wandb_id,
            'epoch': epoch
        }
        if epoch % args.save_model == 0:
            torch.save(checkpoint, os.path.join(MODEL_PATH, 'model.pth'))
        if epoch % args.save_model_history == 0:
            torch.save(checkpoint, os.path.join(MODEL_PATH, f'{epoch:04d}_model.pth'))
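# Seed Python, NumPy and PyTorch (CPU and current CUDA device) RNGs for reproducibility.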
def rSeed(sd):
    random.seed(sd)
    np.random.seed(sd)
    torch.manual_seed(sd)
    torch.cuda.manual_seed(sd)
| if __name__ == "__main__": | |
| print("Training Model") | |
| main() | |
| wandb.finish() | |