# NOTE: non-code residue from a Hugging Face Spaces page header ("Running on Zero")
# was removed here during extraction cleanup.
import pyrootutils

# Must run BEFORE the `ripe` imports below: adds the project root to sys.path
# (pythonpath=True) and loads environment variables from .env (dotenv=True).
root = pyrootutils.setup_root(
    search_from=__file__,
    indicator=[".git", "pyproject.toml"],
    pythonpath=True,
    dotenv=True,
)

import collections
import os
import random
from pathlib import Path

import hydra
import numpy as np
import torch
import tqdm
import wandb
from hydra.utils import instantiate
from lightning.fabric import Fabric
from torch.optim.adamw import AdamW
from torch.utils.data import DataLoader

from ripe import utils
from ripe.benchmarks.imw_2020 import IMW_2020_Benchmark
from ripe.utils.utils import get_rewards
from ripe.utils.wandb_utils import get_flattened_wandb_cfg

log = utils.get_pylogger(__name__)

# Global seeding for reproducibility.
SEED = 32000
print(SEED)
# NOTE(review): setting PYTHONHASHSEED from inside a running interpreter does
# NOT change its own hash randomization — only child processes inherit it.
# If deterministic str/bytes hashing is required, export it before launch.
os.environ["PYTHONHASHSEED"] = str(SEED)
torch.manual_seed(SEED)
np.random.seed(SEED)
random.seed(SEED)
def unpack_batch(batch):
    """Split a dataloader batch dict into its component tensors.

    Returns a tuple ``(src_image, trg_image, src_mask, trg_mask, H, label)``
    where ``H`` is the homography relating the source and target views.
    """
    return (
        batch["src_image"],
        batch["trg_image"],
        batch["src_mask"],
        batch["trg_mask"],
        batch["homography"],
        batch["label"],
    )
def train(cfg):
    """Main training function for the RIPE model.

    Runs the full reinforcement-style training loop: instantiates the dataset,
    matcher, network and optimizer from the Hydra config ``cfg``, then trains
    for ``cfg.num_steps`` steps with policy-gradient rewards from RANSAC
    inliers, optional descriptor loss and keypoint penalties, logging to
    wandb and periodically evaluating on the IMW 2020 benchmark.

    Args:
        cfg: Hydra/OmegaConf config. Fields read here include (non-exhaustive):
            num_gpus, precision, output_dir, project_name, wandb_mode,
            transformation_model, batch_size, num_steps, lr, num_grad_accs,
            data, matcher, desc_loss_weight, descriptor_loss, network,
            backbones, descriptor_dim, fp_penalty, kp_penalty, lr_scheduler,
            conf_inference, alpha_scheduler, beta_scheduler, inl_th,
            padding_filter_mode, no_filtering_negatives, selected_only,
            reward_type, use_whitening, num_workers, log_interval,
            val_interval.

    Side effects: creates a wandb run, writes a final model checkpoint to
    ``cfg.output_dir``.
    """
    # Prepare model, data and hyperparms
    strategy = "ddp" if cfg.num_gpus > 1 else "auto"
    fabric = Fabric(
        accelerator="cuda",
        devices=cfg.num_gpus,
        precision=cfg.precision,
        strategy=strategy,
    )
    fabric.launch()

    # Derive a human-readable experiment name from the Hydra output directory
    # layout (assumes .../<experiment>/<run_id>/<date>/<time> — TODO confirm).
    output_dir = Path(cfg.output_dir)
    experiment_name = output_dir.parent.parent.parent.name
    run_id = output_dir.parent.parent.name
    timestamp = output_dir.parent.name + "_" + output_dir.name
    experiment_name = run_id + " " + timestamp + " " + experiment_name

    # setup logger
    wandb_logger = wandb.init(
        project=cfg.project_name,
        name=experiment_name,
        config=get_flattened_wandb_cfg(cfg),
        dir=cfg.output_dir,
        mode=cfg.wandb_mode,
    )

    # Minimal number of correspondences each geometric model needs.
    min_nums_matches = {"homography": 4, "fundamental": 8, "fundamental_7pt": 7}
    min_num_matches = min_nums_matches[cfg.transformation_model]

    print(f"Minimum number of matches for {cfg.transformation_model} is {min_num_matches}")

    batch_size = cfg.batch_size
    steps = cfg.num_steps
    lr = cfg.lr
    num_grad_accs = (
        cfg.num_grad_accs
    )  # this performs grad accumulation to simulate larger batch size, set to 1 to disable;

    # instantiate dataset
    ds = instantiate(cfg.data)

    # prepare dataloader
    dl = DataLoader(
        ds,
        batch_size=batch_size,
        shuffle=True,
        drop_last=True,  # keep batches full so the per-batch loops below are uniform
        persistent_workers=False,
        num_workers=cfg.num_workers,
    )
    dl = fabric.setup_dataloaders(dl)
    i_dl = iter(dl)

    # create matcher
    matcher = instantiate(cfg.matcher)

    if cfg.desc_loss_weight != 0.0:
        descriptor_loss = instantiate(cfg.descriptor_loss)
    else:
        log.warning(
            "Descriptor loss weight is 0.0, descriptor loss will not be used. 1x1 conv for descriptors will be deactivated!"
        )
        descriptor_loss = None

    # Optional upsampler for descriptors; only present in some configs.
    upsampler = instantiate(cfg.upsampler) if "upsampler" in cfg else None

    # create network
    net = instantiate(cfg.network)(
        net=instantiate(cfg.backbones),
        upsampler=upsampler,
        descriptor_dim=cfg.descriptor_dim if descriptor_loss is not None else None,
        device=fabric.device,
    ).train()

    # get num parameters
    num_params = sum(p.numel() for p in net.parameters() if p.requires_grad)
    log.info(f"Number of parameters: {num_params}")

    fp_penalty = cfg.fp_penalty  # small penalty for not finding a match
    kp_penalty = cfg.kp_penalty  # small penalty for low logprob keypoints

    opt_pi = AdamW(filter(lambda x: x.requires_grad, net.parameters()), lr=lr, weight_decay=1e-5)
    net, opt_pi = fabric.setup(net, opt_pi)

    if cfg.lr_scheduler:
        scheduler = instantiate(cfg.lr_scheduler)(optimizer=opt_pi, steps_init=0)
    else:
        scheduler = None

    val_benchmark = IMW_2020_Benchmark(
        use_predefined_subset=True,
        conf_inference=cfg.conf_inference,
        edge_input_divisible_by=None,
    )

    # mean average of skipped batches
    # this is used to monitor how many batches were skipped due to not enough keypoints
    # this is useful to detect if the model is not learning anything -> should be zero
    ma_skipped_batches = collections.deque(maxlen=100)

    opt_pi.zero_grad()

    # initialize scheduler
    # alpha ramps the penalties, beta mixes the reward terms, inl_th is the
    # RANSAC inlier threshold — all scheduled over training steps.
    alpha_scheduler = instantiate(cfg.alpha_scheduler)
    beta_scheduler = instantiate(cfg.beta_scheduler)
    inl_th_scheduler = instantiate(cfg.inl_th)

    # ====== Training Loop ======
    # check if the model is in training mode
    net.train()
    with tqdm.tqdm(total=steps) as pbar:
        for i_step in range(steps):
            alpha = alpha_scheduler(i_step)
            beta = beta_scheduler(i_step)
            inl_th = inl_th_scheduler(i_step)

            if scheduler:
                scheduler.step()

            # Initialize vars for current step
            # We need to handle batching because the description can have arbitrary number of keypoints
            sum_reward_batch = 0
            sum_num_keypoints_1 = 0
            sum_num_keypoints_2 = 0
            loss = None
            loss_policy_stack = None
            loss_desc_stack = None
            loss_kp_stack = None

            # Cycle the dataloader indefinitely: restart the iterator when
            # the epoch ends, since training is step-based, not epoch-based.
            try:
                batch = next(i_dl)
            except StopIteration:
                i_dl = iter(dl)
                batch = next(i_dl)

            p1, p2, mask_padding_1, mask_padding_2, Hs, label = unpack_batch(batch)

            # Forward both views; training=True makes the net return sampled
            # keypoints with their log-probabilities (policy outputs).
            (
                kpts1,
                logprobs1,
                selected_mask1,
                mask_padding_grid_1,
                logits_selected_1,
                out1,
            ) = net(p1, mask_padding_1, training=True)
            (
                kpts2,
                logprobs2,
                selected_mask2,
                mask_padding_grid_2,
                logits_selected_2,
                out2,
            ) = net(p2, mask_padding_2, training=True)

            # upsample coarse descriptors for all keypoints from the intermediate feature maps from the encoder
            desc_1 = net.get_descs(out1["coarse_descs"], p1, kpts1, p1.shape[2], p1.shape[3])
            desc_2 = net.get_descs(out2["coarse_descs"], p2, kpts2, p2.shape[2], p2.shape[3])

            if cfg.padding_filter_mode == "ignore":  # remove keypoints that are in padding
                batch_mask_selection_for_matching_1 = selected_mask1 & mask_padding_grid_1
                batch_mask_selection_for_matching_2 = selected_mask2 & mask_padding_grid_2
            elif cfg.padding_filter_mode == "punish":
                batch_mask_selection_for_matching_1 = selected_mask1  # keep all keypoints
                batch_mask_selection_for_matching_2 = selected_mask2  # punish the keypoints in the padding area
            else:
                raise ValueError(f"Unknown padding filter mode: {cfg.padding_filter_mode}")

            # Match + RANSAC per pair; entries are None when a pair had too
            # few keypoints to fit the geometric model.
            (
                batch_rel_idx_matches,
                batch_abs_idx_matches,
                batch_ransac_inliers,
                batch_Fm,
            ) = matcher(
                kpts1,
                kpts2,
                desc_1,
                desc_2,
                batch_mask_selection_for_matching_1,
                batch_mask_selection_for_matching_2,
                inl_th,
                label if cfg.no_filtering_negatives else None,
            )

            for b in range(batch_size):
                # ignore if less than 16 keypoints have been detected
                if batch_rel_idx_matches[b] is None:
                    ma_skipped_batches.append(1)
                    continue
                else:
                    ma_skipped_batches.append(0)

                mask_selection_for_matching_1 = batch_mask_selection_for_matching_1[b]
                mask_selection_for_matching_2 = batch_mask_selection_for_matching_2[b]
                rel_idx_matches = batch_rel_idx_matches[b]
                abs_idx_matches = batch_abs_idx_matches[b]
                ransac_inliers = batch_ransac_inliers[b]

                # Joint log-probability of every candidate keypoint pair
                # (outer sum over the two images' per-keypoint logprobs).
                if cfg.selected_only:
                    # every SELECTED keypoint with every other SELECTED keypoint
                    dense_logprobs = logprobs1[b][mask_selection_for_matching_1].view(-1, 1) + logprobs2[b][
                        mask_selection_for_matching_2
                    ].view(1, -1)
                else:
                    if cfg.padding_filter_mode == "ignore":
                        # every keypoint with every other keypoint, but WITHOUT keypoint in the padding area
                        dense_logprobs = logprobs1[b][mask_padding_grid_1[b]].view(-1, 1) + logprobs2[b][
                            mask_padding_grid_2[b]
                        ].view(1, -1)
                    elif cfg.padding_filter_mode == "punish":
                        # every keypoint with every other keypoint, also WITH keypoints in the padding areas -> will be punished by the reward
                        dense_logprobs = logprobs1[b].view(-1, 1) + logprobs2[b].view(1, -1)
                    else:
                        raise ValueError(f"Unknown padding filter mode: {cfg.padding_filter_mode}")

                reward = None
                if cfg.reward_type == "inlier":
                    reward = (
                        0.5 if cfg.no_filtering_negatives and not label[b] else 1.0
                    )  # reward is 1.0 if the pair is positive, 0.5 if negative and no filtering is applied
                elif cfg.reward_type == "inlier_ratio":
                    ratio_inlier = ransac_inliers.sum() / len(abs_idx_matches)
                    reward = ratio_inlier  # reward is the ratio of inliers -> higher if more matches are inliers
                elif cfg.reward_type == "inlier+inlier_ratio":
                    ratio_inlier = ransac_inliers.sum() / len(abs_idx_matches)
                    reward = (
                        (1.0 - beta) * 1.0 + beta * ratio_inlier
                    )  # reward is a combination of the ratio of inliers and the number of inliers -> gradually changes
                else:
                    raise ValueError(f"Unknown reward type: {cfg.reward_type}")

                # Expand the scalar reward into a dense per-pair reward map,
                # penalizing unmatched keypoints with fp_penalty * alpha.
                dense_rewards = get_rewards(
                    reward,
                    kpts1[b],
                    kpts2[b],
                    mask_selection_for_matching_1,
                    mask_selection_for_matching_2,
                    mask_padding_grid_1[b],
                    mask_padding_grid_2[b],
                    rel_idx_matches,
                    abs_idx_matches,
                    ransac_inliers,
                    label[b],
                    fp_penalty * alpha,
                    use_whitening=cfg.use_whitening,
                    selected_only=cfg.selected_only,
                    filter_mode=cfg.padding_filter_mode,
                )

                if descriptor_loss is not None:
                    hard_loss = descriptor_loss(
                        desc1=desc_1[b],
                        desc2=desc_2[b],
                        matches=abs_idx_matches,
                        inliers=ransac_inliers,
                        label=label[b],
                        logits_1=None,
                        logits_2=None,
                    )
                    loss_desc_stack = (
                        hard_loss if loss_desc_stack is None else torch.hstack((loss_desc_stack, hard_loss))
                    )

                sum_reward_batch += dense_rewards.sum()
                # REINFORCE-style policy term: reward-weighted log-probability.
                current_loss_policy = (dense_rewards * dense_logprobs).view(-1)
                loss_policy_stack = (
                    current_loss_policy
                    if loss_policy_stack is None
                    else torch.hstack((loss_policy_stack, current_loss_policy))
                )

                if kp_penalty != 0.0:
                    # keypoints with low logprob are penalized
                    # as they get large negative logprob values multiplying them with the penalty will make the loss larger
                    loss_kp = (
                        logprobs1[b][mask_selection_for_matching_1]
                        * torch.full_like(
                            logprobs1[b][mask_selection_for_matching_1],
                            kp_penalty * alpha,
                        )
                    ).mean() + (
                        logprobs2[b][mask_selection_for_matching_2]
                        * torch.full_like(
                            logprobs2[b][mask_selection_for_matching_2],
                            kp_penalty * alpha,
                        )
                    ).mean()
                    loss_kp_stack = loss_kp if loss_kp_stack is None else torch.hstack((loss_kp_stack, loss_kp))

                sum_num_keypoints_1 += mask_selection_for_matching_1.sum()
                sum_num_keypoints_2 += mask_selection_for_matching_2.sum()

            # NOTE(review): if every pair in this batch was skipped above,
            # loss_policy_stack is still None and the next line raises
            # AttributeError — presumably never happens in practice; confirm.
            loss = loss_policy_stack.mean()
            if loss_kp_stack is not None:
                loss += loss_kp_stack.mean()
            # Negate: the optimizer minimizes, but reward * logprob should be maximized.
            loss = -loss
            if descriptor_loss is not None:
                loss += cfg.desc_loss_weight * loss_desc_stack.mean()

            pbar.set_description(
                f"LP: {loss.item():.4f} - Det: ({sum_num_keypoints_1 / batch_size:.4f}, {sum_num_keypoints_2 / batch_size:.4f}), #mRwd: {sum_reward_batch / batch_size:.1f}"
            )
            pbar.update()

            # backward pass
            # Scale by the accumulation factor so the effective gradient
            # matches a single large batch.
            loss /= num_grad_accs
            fabric.backward(loss)

            # NOTE(review): with num_grad_accs > 1 this also steps at
            # i_step == 0 after a single micro-batch — verify intended.
            if i_step % num_grad_accs == 0:
                opt_pi.step()
                opt_pi.zero_grad()

            if i_step % cfg.log_interval == 0:
                wandb_logger.log(
                    {
                        "loss": loss.item(),
                        "loss_policy": -loss_policy_stack.mean().item(),
                        "loss_kp": loss_kp_stack.mean().item() if loss_kp_stack is not None else 0.0,
                        "loss_hard": (loss_desc_stack.mean().item() if loss_desc_stack is not None else 0.0),
                        "mean_num_det_kpts1": sum_num_keypoints_1 / batch_size,
                        "mean_num_det_kpts2": sum_num_keypoints_2 / batch_size,
                        "mean_reward": sum_reward_batch / batch_size,
                        "lr": opt_pi.param_groups[0]["lr"],
                        "ma_skipped_batches": sum(ma_skipped_batches) / len(ma_skipped_batches),
                        "inl_th": inl_th,
                    },
                    step=i_step,
                )

            if i_step % cfg.val_interval == 0:
                val_benchmark.evaluate(net, fabric.device, progress_bar=False)
                val_benchmark.log_results(logger=wandb_logger, step=i_step)
                # ensure that the model is in training mode again
                net.train()

    # save the model
    torch.save(
        net.state_dict(),
        output_dir / ("model" + "_" + str(i_step + 1) + "_final" + ".pth"),
    )
if __name__ == "__main__":
    # NOTE(review): `train` requires a `cfg` argument, so calling it bare here
    # raises TypeError when the file is run directly. Presumably a Hydra entry
    # point (@hydra.main) or an external launcher is meant to supply cfg —
    # confirm against the project's run scripts.
    train()