import torch
import torch.nn as nn


class Exp:
    """
    Configuration class for the page element model.

    Stores the checkpoint name, target device, YOLOX architecture settings,
    the label set and the inference / NMS thresholds as plain instance
    attributes, and builds the network on demand through ``get_model``.
    """

    def __init__(self):
        self.name = "page-element-v3"
        self.ckpt = "weights.pth"
        # Use the first CUDA device when one is available, otherwise the CPU.
        self.device = "cuda:0" if torch.cuda.is_available() else "cpu"

        # YOLOX architecture parameters
        self.act = "silu"
        self.depth = 1.00
        self.width = 1.00
        self.labels = [
            "table",
            "chart",
            "title",
            "infographic",
            "paragraph",
            "header_footer",
        ]
        # Derived from the label list so the two cannot drift apart.
        self.num_classes = len(self.labels)

        # Inference parameters
        # NOTE(review): the axis order of `size` is not visible here -
        # confirm against the preprocessing code that consumes it.
        self.size = (1024, 1024)
        self.min_bbox_size = 0
        self.normalize_boxes = True

        # NMS & thresholding. These can be updated
        self.conf_thresh = 0.01
        self.iou_thresh = 0.5
        self.class_agnostic = True

        # One threshold per entry of `self.labels`, keyed by label name.
        self.thresholds_per_class = {
            "table": 0.1,
            "chart": 0.01,
            "infographic": 0.01,
            "title": 0.1,
            "paragraph": 0.1,
            "header_footer": 0.1,
        }

    def get_model(self):
        """
        Get the YOLOX model.

        The network is built on the first call and cached on ``self.model``;
        later calls reuse the cached instance. A model assigned to
        ``self.model`` beforehand is used as is. On every call, ``eps`` and
        ``momentum`` are set on each ``nn.BatchNorm2d`` layer of the model.

        Returns:
            The model stored in ``self.model``.
        """
        # Build model
        if getattr(self, "model", None) is None:
            # Imported here so that `yolox` is only required when a model
            # actually has to be constructed.
            from yolox import YOLOX, YOLOPAFPN, YOLOXHead

            in_channels = [256, 512, 1024]
            backbone = YOLOPAFPN(
                self.depth, self.width, in_channels=in_channels, act=self.act
            )
            head = YOLOXHead(
                self.num_classes, self.width, in_channels=in_channels, act=self.act
            )
            self.model = YOLOX(backbone, head)

        # Update batch-norm parameters.
        # `modules()` yields the model itself and every nested submodule, so
        # one pass reaches each layer exactly once. Combining `Module.apply`
        # with an inner `modules()` loop would revisit every layer once per
        # ancestor while producing the same final state.
        for module in self.model.modules():
            if isinstance(module, nn.BatchNorm2d):
                module.eps = 1e-3
                module.momentum = 0.03

        return self.model