Spaces:
Running
Running
| import os | |
| from models.module import DiffGeolocalizer | |
| import hydra | |
| from os.path import join | |
| import torch | |
| from omegaconf import OmegaConf | |
| from omegaconf import open_dict | |
| from hydra.utils import instantiate | |
| from models.eval_best_model import EvalModule | |
# Allow float32 matrix multiplications to use faster internal algorithms
# (TF32 or bfloat16-based) when the backend supports them.
torch.set_float32_matmul_precision("high")

# Make Python's builtin `eval` available to configs as the `${eval:...}`
# resolver, so Hydra/OmegaConf interpolations can perform arithmetic. See:
# https://omegaconf.readthedocs.io/en/2.3_branch/how_to_guides.html
# NOTE(review): `eval` executes arbitrary expressions taken from the config
# (including command-line overrides); only use with trusted configs.
OmegaConf.register_new_resolver(name="eval", resolver=eval)
def load_model(cfg, dict_config, wandb_id):
    """Build the trainer and the evaluation model from a Hydra config.

    A logger is instantiated from ``cfg.logger`` with ``id=wandb_id`` and
    ``resume="allow"``. The ``"model"`` and ``"dataset"`` entries of
    ``dict_config`` are then merged into its pending init config. The logger
    is NOT attached to the trainer: the ``logger=`` argument is commented out.

    Args:
        cfg: Composed Hydra/OmegaConf config providing ``logger``, ``model``
            and ``trainer`` nodes.
        dict_config: Plain-container version of the config; its ``"model"``
            and ``"dataset"`` entries are copied into the logger config.
        wandb_id: Value forwarded verbatim as the logger's ``id``.

    Returns:
        A ``(trainer, model)`` tuple.
    """
    run_logger = instantiate(cfg.logger, id=wandb_id, resume="allow")
    # `_wandb_init` suggests a Lightning WandbLogger; verify against cfg.logger.
    logged_config = {key: dict_config[key] for key in ("model", "dataset")}
    run_logger._wandb_init.update({"config": logged_config})

    eval_model = EvalModule(cfg.model)

    # NOTE(review): `strategy` re-passes the value already stored in
    # cfg.trainer, and the logger above is built but never handed to the
    # trainer (originally `# , logger=logger)`) — confirm this is intended.
    pl_trainer = instantiate(cfg.trainer, strategy=cfg.trainer.strategy)
    return pl_trainer, eval_model
def hydra_boilerplate(cfg):
    """Resolve ``cfg`` into plain containers and build ``(trainer, model)``.

    Args:
        cfg: Composed Hydra/OmegaConf config; must define ``wandb_id``.

    Returns:
        The ``(trainer, model)`` pair produced by ``load_model``.
    """
    resolved = OmegaConf.to_container(cfg, resolve=True)
    built_trainer, built_model = load_model(cfg, resolved, cfg.wandb_id)
    return built_trainer, built_model
| import copy | |
def init_datamodule(cfg):
    """Instantiate and return the object described by ``cfg.datamodule``."""
    return instantiate(cfg.datamodule)
if __name__ == "__main__":
    import sys

    # Insert a Hydra override ahead of the user's own CLI arguments. It stores
    # the runtime list of config sources under the new key `pt_model_path`,
    # which `main` uses below to locate the run directory.
    sys.argv = [
        sys.argv[0],
        "+pt_model_path=${hydra:runtime.config_sources}",
        *sys.argv[1:],
    ]

    # FIX: `main(cfg)` was previously undecorated but invoked as `main()`,
    # which raises TypeError. `hydra.main` composes the config and passes it in.
    # NOTE(review): config_path / config_name / version_base are assumptions.
    # Confirm them against the project's config layout. `--config-path` and
    # `--config-name` on the command line take precedence over them.
    @hydra.main(config_path="configs", config_name="config", version_base=None)
    def main(cfg):
        """Run `trainer.test` on the model/datamodule built from ``cfg``.

        Args:
            cfg: Config composed by Hydra, including the injected
                ``pt_model_path`` override.
        """
        # open_dict lets new keys (wandb_id, checkpoint) be added to the config.
        with open_dict(cfg):
            # Entry 1 of `hydra:runtime.config_sources`. This is presumably the
            # primary config directory (entry 0 is usually Hydra's own
            # package source) — TODO confirm.
            run_dir = cfg.pt_model_path[1]["path"]
            # NOTE(review): this stores a file *path*, and load_model forwards
            # it unchanged as the logger `id`. If the run id should come from
            # wandb_id.txt, read the file's contents instead.
            cfg.wandb_id = join(run_dir, "wandb_id.txt")
            # NOTE(review): nothing visible passes this to trainer.test (no
            # ckpt_path=). Confirm the model config consumes `checkpoint`.
            cfg.checkpoint = join(run_dir, "last.ckpt")
            cfg.computer.devices = 1

        trainer, model = hydra_boilerplate(cfg)
        datamodule = init_datamodule(cfg)
        trainer.test(model, datamodule=datamodule)

    main()