# nemotron-page-elements-v3 / page_element_v3.py
# Author: Theo Viel - commit 9954323 ("add code"), 1.82 kB
import torch
import torch.nn as nn
class Exp:
    """
    Configuration class for the page element model.

    Holds the YOLOX architecture parameters, the label set, and the
    inference / NMS settings, and builds the network via ``get_model``.
    """

    def __init__(self):
        # Identification and checkpoint file name. The checkpoint is not
        # loaded anywhere in this class; presumably the caller loads it.
        self.name = "page-element-v3"
        self.ckpt = "weights.pth"
        # First CUDA device when one is available, otherwise CPU.
        self.device = "cuda:0" if torch.cuda.is_available() else "cpu"

        # YOLOX architecture parameters
        self.act = "silu"  # activation name forwarded to backbone and head
        self.depth = 1.00  # depth factor passed to YOLOPAFPN
        self.width = 1.00  # width factor passed to YOLOPAFPN and YOLOXHead
        self.labels = ["table", "chart", "title", "infographic", "paragraph", "header_footer"]
        # Derived from the label list so the head size cannot drift from it.
        self.num_classes = len(self.labels)

        # Inference parameters
        self.size = (1024, 1024)  # inference input size
        # NOTE(review): presumably boxes smaller than this are dropped
        # (0 keeps everything) - confirm in the post-processing code.
        self.min_bbox_size = 0
        # NOTE(review): presumably returns box coordinates scaled to [0, 1]
        # - confirm in the post-processing code.
        self.normalize_boxes = True

        # NMS & thresholding. These can be updated after construction.
        self.conf_thresh = 0.01  # global score threshold
        self.iou_thresh = 0.5  # IoU threshold used by NMS
        # Presumably NMS is run across all classes jointly when True - confirm.
        self.class_agnostic = True
        # Per-class score thresholds, keyed by the names in ``self.labels``.
        self.thresholds_per_class = {
            "table": 0.1,
            "chart": 0.01,
            "infographic": 0.01,
            "title": 0.1,
            "paragraph": 0.1,
            "header_footer": 0.1,
        }

    def get_model(self):
        """
        Get the YOLOX model.

        If ``self.model`` is missing or None, a YOLOX network is built from
        ``self.depth``, ``self.width``, ``self.act`` and ``self.num_classes``
        and cached on ``self.model``. Otherwise the existing instance is
        reused.

        On every call, all ``nn.BatchNorm2d`` layers of the model get
        ``eps=1e-3`` and ``momentum=0.03``. No weights are loaded here
        (``self.ckpt`` is not read).

        Returns:
            The cached model, ``self.model``.
        """
        # Build model
        if getattr(self, "model", None) is None:
            # Imported only when a model must actually be built. This means a
            # pre-assigned ``self.model`` does not require the yolox package.
            from yolox import YOLOX, YOLOPAFPN, YOLOXHead

            in_channels = [256, 512, 1024]  # shared by backbone and head
            backbone = YOLOPAFPN(self.depth, self.width, in_channels=in_channels, act=self.act)
            head = YOLOXHead(self.num_classes, self.width, in_channels=in_channels, act=self.act)
            self.model = YOLOX(backbone, head)

        # Update batch-norm parameters. The values are presumably those the
        # checkpoint was trained with - confirm.
        # ``modules()`` already walks the entire module tree, so one pass
        # visits each layer exactly once. Wrapping this loop in
        # ``Module.apply`` would revisit every layer once per ancestor.
        for module in self.model.modules():
            if isinstance(module, nn.BatchNorm2d):
                module.eps = 1e-3
                module.momentum = 0.03

        return self.model