import csv
import datetime
from collections import defaultdict

import numpy as np
import torch
import torchvision
import wandb
from termcolor import colored
from torch.utils.tensorboard import SummaryWriter

# Each tuple is (csv_key, display_key, format_type) for console output.
COMMON_TRAIN_FORMAT = [('frame', 'F', 'int'), ('step', 'S', 'int'),
                       ('episode', 'E', 'int'), ('episode_length', 'L', 'int'),
                       ('episode_reward', 'R', 'float'),
                       ('fps', 'FPS', 'float'), ('total_time', 'T', 'time')]

COMMON_EVAL_FORMAT = [('frame', 'F', 'int'), ('step', 'S', 'int'),
                      ('episode', 'E', 'int'), ('episode_length', 'L', 'int'),
                      ('episode_reward', 'R', 'float'),
                      ('total_time', 'T', 'time')]


class AverageMeter(object):
    """Keeps a running sum and count to report an average value."""

    def __init__(self):
        self._sum = 0
        self._count = 0

    def update(self, value, n=1):
        self._sum += value
        self._count += n

    def value(self):
        return self._sum / max(1, self._count)


class MetersGroup(object):
    def __init__(self, csv_file_name, formating, use_wandb):
        self._csv_file_name = csv_file_name
        self._formating = formating
        self._meters = defaultdict(AverageMeter)
        self._csv_file = None
        self._csv_writer = None
        self.use_wandb = use_wandb

    def log(self, key, value, n=1):
        self._meters[key].update(value, n)

    def _prime_meters(self):
        data = dict()
        for key, meter in self._meters.items():
            if key.startswith('train'):
                key = key[len('train') + 1:]
            else:
                key = key[len('eval') + 1:]
            key = key.replace('/', '_')
            data[key] = meter.value()
        return data

    def _remove_old_entries(self, data):
        rows = []
        with self._csv_file_name.open('r') as f:
            reader = csv.DictReader(f)
            for row in reader:
                if 'episode' in row:
                    # BUGFIX: covers weird cases where the CSV was badly written
                    if row['episode'] == '':
                        rows.append(row)
                        continue
                    if row['episode'] is None:
                        continue
                    if float(row['episode']) >= data['episode']:
                        break
                rows.append(row)
        with self._csv_file_name.open('w') as f:
            # Handle CSVs that have more keys than the new data.
            keys = set(data.keys())
            if len(rows) > 0:
                keys |= set(rows[-1].keys())
            keys = sorted(keys)
            writer = csv.DictWriter(f, fieldnames=keys, restval=0.0)
            writer.writeheader()
            for row in rows:
                writer.writerow(row)

    def _dump_to_csv(self, data):
        if self._csv_writer is None:
            should_write_header = True
            if self._csv_file_name.exists():
                self._remove_old_entries(data)
                should_write_header = False
            self._csv_file = self._csv_file_name.open('a')
            self._csv_writer = csv.DictWriter(self._csv_file,
                                              fieldnames=sorted(data.keys()),
                                              restval=0.0)
            if should_write_header:
                self._csv_writer.writeheader()
        # To handle components that start training later
        # (restval covers only the case where data has fewer keys than the CSV).
        if self._csv_writer.fieldnames != sorted(data.keys()) and \
                len(self._csv_writer.fieldnames) < len(data.keys()):
            self._csv_file.close()
            self._csv_file = self._csv_file_name.open('r')
            dict_reader = csv.DictReader(self._csv_file)
            rows = [row for row in dict_reader]
            self._csv_file.close()
            self._csv_file = self._csv_file_name.open('w')
            self._csv_writer = csv.DictWriter(self._csv_file,
                                              fieldnames=sorted(data.keys()),
                                              restval=0.0)
            self._csv_writer.writeheader()
            for row in rows:
                self._csv_writer.writerow(row)
        self._csv_writer.writerow(data)
        self._csv_file.flush()

    def _format(self, key, value, ty):
        if ty == 'int':
            value = int(value)
            return f'{key}: {value}'
        elif ty == 'float':
            return f'{key}: {value:.4f}'
        elif ty == 'time':
            value = str(datetime.timedelta(seconds=int(value)))
            return f'{key}: {value}'
        else:
            raise ValueError(f'invalid format type: {ty}')

    def _dump_to_console(self, data, prefix):
        prefix = colored(prefix, 'yellow' if prefix == 'train' else 'green')
        pieces = [f'| {prefix: <14}']
        for key, disp_key, ty in self._formating:
            value = data.get(key, 0)
            pieces.append(self._format(disp_key, value, ty))
        print(' | '.join(pieces))

    def _dump_to_wandb(self, data):
        wandb.log(data)

    def dump(self, step, prefix):
        if len(self._meters) == 0:
            return
        data = self._prime_meters()
        data['frame'] = step
        if self.use_wandb:
            wandb_data = {prefix + '/' + key: val for key, val in data.items()}
            self._dump_to_wandb(data=wandb_data)
        # self._dump_to_csv(data)
        self._dump_to_console(data, prefix)
        self._meters.clear()


class Logger(object):
    def __init__(self, log_dir, use_tb, use_wandb):
        self._log_dir = log_dir
        self._train_mg = MetersGroup(log_dir / 'train.csv',
                                     formating=COMMON_TRAIN_FORMAT,
                                     use_wandb=use_wandb)
        self._eval_mg = MetersGroup(log_dir / 'eval.csv',
                                    formating=COMMON_EVAL_FORMAT,
                                    use_wandb=use_wandb)
        if use_tb:
            self._sw = SummaryWriter(str(log_dir / 'tb'))
        else:
            self._sw = None
        self.use_wandb = use_wandb

    def _try_sw_log(self, key, value, step):
        if self._sw is not None:
            self._sw.add_scalar(key, value, step)

    def log(self, key, value, step):
        assert key.startswith('train') or key.startswith('eval')
        if isinstance(value, torch.Tensor):
            value = value.item()
        self._try_sw_log(key, value, step)
        mg = self._train_mg if key.startswith('train') else self._eval_mg
        mg.log(key, value)

    def log_metrics(self, metrics, step, ty):
        for key, value in metrics.items():
            self.log(f'{ty}/{key}', value, step)

    def dump(self, step, ty=None):
        if ty is None or ty == 'eval':
            self._eval_mg.dump(step, 'eval')
        if ty is None or ty == 'train':
            self._train_mg.dump(step, 'train')

    def log_and_dump_ctx(self, step, ty):
        return LogAndDumpCtx(self, step, ty)

    def log_visual(self, data, step):
        if self._sw is not None:
            for k, v in data.items():
                if len(v.shape) == 3:
                    self._sw.add_image(k, v, global_step=step)
                else:
                    if len(v.shape) == 4:
                        v = np.expand_dims(v, axis=0)
                    self._sw.add_video(k, v, global_step=step, fps=15)
        if self.use_wandb:
            for k, v in data.items():
                if not isinstance(v, np.ndarray):
                    v = v.cpu().numpy()
                if v.dtype != np.uint8:
                    v = np.uint8(v * 255)
                # Note: defaulting to saving only one image/video to save storage on wandb.
                if len(v.shape) == 3:
                    if v.shape[0] == 3:
                        v = v.transpose(1, 2, 0)
                    wandb.log({k: wandb.Image(v)})
                else:
                    wandb.log({k: wandb.Video(v, fps=15, format='gif')})


class LogAndDumpCtx:
    def __init__(self, logger, step, ty):
        self._logger = logger
        self._step = step
        self._ty = ty

    def __enter__(self):
        return self

    def __call__(self, key, value):
        self._logger.log(f'{self._ty}/{key}', value, self._step)

    def __exit__(self, *args):
        self._logger.dump(self._step, self._ty)
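

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative, not part of the original module). It
# assumes a hypothetical `./logs` output directory and disables TensorBoard
# and wandb so no external services are touched; note that CSV dumping is
# currently commented out in MetersGroup.dump, so only the console is written.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    from pathlib import Path

    log_dir = Path('./logs')  # hypothetical output directory
    log_dir.mkdir(parents=True, exist_ok=True)
    logger = Logger(log_dir, use_tb=False, use_wandb=False)

    # Accumulate per-step metrics; each AverageMeter averages them until dump().
    for step in range(100):
        logger.log('train/episode_reward', float(step), step)
        logger.log('train/episode_length', 100, step)

    # Flush the averaged 'train' metrics to the console.
    logger.dump(100, ty='train')

    # Or log and dump in one scope via the context manager.
    with logger.log_and_dump_ctx(100, ty='eval') as log:
        log('episode_reward', 42.0)
        log('episode_length', 100)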