hanquansanren commited on
Commit
3a8784c
·
1 Parent(s): 05fb4ab

Add application file

Browse files
Files changed (2) hide show
  1. .gitignore +0 -1
  2. run_training.py +78 -0
.gitignore CHANGED
@@ -3,7 +3,6 @@ vis_hp
3
  assets
4
  images
5
  backup
6
- run_training.py
7
  run_gradio.py
8
  run_foward.py
9
 
 
3
  assets
4
  images
5
  backup
 
6
  run_gradio.py
7
  run_foward.py
8
 
run_training.py ADDED
@@ -0,0 +1,78 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import importlib
3
+ import os
4
+ os.environ["CUDA_VISIBLE_DEVICES"] = "7"
5
+ os.environ["HDF5_USE_FILE_LOCKING"] = "0"
6
+ import random
7
+ from datetime import date
8
+ from shutil import copyfile
9
+
10
+ import cv2 as cv
11
+ import numpy as np
12
+ import torch
13
+ import torch.backends.cudnn
14
+
15
+ import admin.settings as ws_settings
16
+
17
+
18
+ def run_training(train_module, train_name, seed, cudnn_benchmark=True):
19
+ """Run a train scripts in train_settings.
20
+ args:
21
+ train_module: Name of module in the "train_settings/" folder.
22
+ train_name: Name of the train settings file.
23
+ cudnn_benchmark: Use cudnn benchmark or not (default is True).
24
+ """
25
+
26
+ # This is needed to avoid strange crashes related to opencv
27
+ cv.setNumThreads(0)
28
+
29
+ torch.backends.cudnn.benchmark = cudnn_benchmark
30
+
31
+ # dd/mm/YY
32
+ today = date.today()
33
+ d1 = today.strftime("%d/%m/%Y")
34
+ print('Training: {} {}\nDate: {}'.format(train_module, train_name, d1))
35
+
36
+ settings = ws_settings.Settings()
37
+ settings.module_name = train_module
38
+ settings.script_name = train_name
39
+ settings.project_path = 'train_settings/{}/{}'.format(train_module, train_name)
40
+ settings.seed = seed
41
+
42
+ # will save the checkpoints there
43
+
44
+ save_dir = os.path.join(settings.env.workspace_dir, settings.project_path)
45
+ if not os.path.exists(save_dir):
46
+ os.makedirs(save_dir)
47
+ copyfile(settings.project_path + '.py', os.path.join(save_dir, settings.script_name + '.py'))
48
+
49
+ expr_module = importlib.import_module('train_settings.{}.{}'.format(train_module.replace('/', '.'),
50
+ train_name.replace('/', '.')))
51
+
52
+ expr_func = getattr(expr_module, 'run')
53
+
54
+ expr_func(settings)
55
+
56
+
57
+ def main():
58
+ parser = argparse.ArgumentParser(description='Run a train scripts in train_settings.')
59
+ parser.add_argument('--train_module', type=str, help='Name of module in the "train_settings/" folder.')
60
+ parser.add_argument('--train_name', type=str, help='Name of the train settings file.')
61
+ parser.add_argument('--cudnn_benchmark', type=bool, default=True,
62
+ help='Set cudnn benchmark on (1) or off (0) (default is on).')
63
+ parser.add_argument('--seed', type=int, default=1992, help='Pseudo-RNG seed')
64
+ args = parser.parse_args()
65
+
66
+ # args.seed = random.randint(0, 3000000)
67
+ args.seed = torch.initial_seed() & (2 ** 32 - 1)
68
+ print('Seed is {}'.format(args.seed))
69
+ random.seed(int(args.seed))
70
+ np.random.seed(args.seed)
71
+ torch.manual_seed(args.seed)
72
+ torch.cuda.manual_seed(args.seed)
73
+
74
+ run_training(args.train_module, args.train_name, cudnn_benchmark=args.cudnn_benchmark, seed=args.seed)
75
+
76
+
77
+ if __name__ == '__main__':
78
+ main()