import argparse import importlib import os import random from datetime import date from shutil import copyfile import cv2 as cv import numpy as np import torch import torch.backends.cudnn import admin.settings as ws_settings def run_sampling(train_module, train_name, seed, name, cudnn_benchmark=True, corruption=False): """Run a sampling scripts in train_settings. args: train_module: Name of module in the "train_settings/" folder. train_name: Name of the train settings file. cudnn_benchmark: Use cudnn benchmark or not (default is True). """ # This is needed to avoid strange crashes related to opencv cv.setNumThreads(0) torch.backends.cudnn.benchmark = cudnn_benchmark # dd/mm/YY today = date.today() d1 = today.strftime("%d/%m/%Y") print('Sampling: {} {}\nDate: {}'.format(train_module, train_name, d1)) settings = ws_settings.Settings() settings.module_name = train_module settings.script_name = train_name settings.project_path = 'train_settings/{}/{}'.format(train_module, train_name) settings.seed = seed settings.name = name # will save the checkpoints there save_dir = os.path.join(settings.env.workspace_dir, settings.project_path) if not os.path.exists(save_dir): os.makedirs(save_dir) copyfile(settings.project_path + '.py', os.path.join(save_dir, settings.script_name + '.py')) expr_module = importlib.import_module('train_settings.{}.{}'.format(train_module.replace('/', '.'), train_name.replace('/', '.'))) expr_func = getattr(expr_module, 'run') if corruption: for severity in [5]: settings.severity = severity for corruption_number in range(0, 15): # [0, 18]; useful for easy looping; 15, 16, 17, 18 are validation corruption numbers settings.corruption_number = corruption_number expr_func(settings) else: settings.severity = 0 settings.corruption_number = 0 expr_func(settings) def main(): parser = argparse.ArgumentParser(description='Run a sampling scripts in train_settings.') parser.add_argument('--train_module', type=str, help='Name of module in the "train_settings/" folder.') parser.add_argument('--train_name', type=str, help='Name of the train settings file.') parser.add_argument('--cudnn_benchmark', type=bool, default=True, help='Set cudnn benchmark on (1) or off (0) (default is on).') parser.add_argument('--seed', type=int, default=1992, help='Pseudo-RNG seed') parser.add_argument('--name', type=str, default="Default", help='Name of the experiment') parser.add_argument('--corruption', action='store_true') args = parser.parse_args() args.seed = random.randint(0, 3000000) args.seed = torch.initial_seed() & (2 ** 32 - 1) print('Seed is {}'.format(args.seed)) random.seed(int(args.seed)) np.random.seed(args.seed) run_sampling( args.train_module, args.train_name, cudnn_benchmark=args.cudnn_benchmark, seed=args.seed, name=args.name, corruption=args.corruption) if __name__ == '__main__': main()