File size: 1,820 Bytes
9954323 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 |
import torch
import torch.nn as nn
class Exp:
    """
    Configuration class for the page element model.

    Stores the model identity, the YOLOX architecture parameters and the
    inference / post-processing settings as plain instance attributes, and
    builds the network on demand through `get_model`.
    """

    def __init__(self):
        self.name = "page-element-v3"
        self.ckpt = "weights.pth"
        # Use the first CUDA device when one is available, otherwise the CPU.
        self.device = "cuda:0" if torch.cuda.is_available() else "cpu"

        # YOLOX architecture parameters
        self.act = "silu"
        self.depth = 1.00
        self.width = 1.00
        self.labels = ["table", "chart", "title", "infographic", "paragraph", "header_footer"]
        # Derived from `labels` so the two cannot drift apart.
        self.num_classes = len(self.labels)

        # Inference parameters
        self.size = (1024, 1024)
        self.min_bbox_size = 0
        self.normalize_boxes = True

        # NMS & thresholding. These can be updated
        self.conf_thresh = 0.01
        self.iou_thresh = 0.5
        self.class_agnostic = True
        # One threshold per entry of `labels`, keyed by label name.
        self.thresholds_per_class = {
            "table": 0.1,
            "chart": 0.01,
            "infographic": 0.01,
            "title": 0.1,
            "paragraph": 0.1,
            "header_footer": 0.1,
        }

    def get_model(self):
        """
        Get the YOLOX model.

        The network is built on the first call and cached on `self.model`;
        a model already present on `self.model` is reused as-is. The
        batch-norm settings are (re)applied on every call.

        Returns:
            The model stored on `self.model`.
        """
        # Build model
        # `getattr` with a default because `__init__` does not create the attribute.
        if getattr(self, "model", None) is None:
            # Imported here so `yolox` is only required when a model has to be built.
            from yolox import YOLOX, YOLOPAFPN, YOLOXHead

            in_channels = [256, 512, 1024]
            backbone = YOLOPAFPN(self.depth, self.width, in_channels=in_channels, act=self.act)
            head = YOLOXHead(self.num_classes, self.width, in_channels=in_channels, act=self.act)
            self.model = YOLOX(backbone, head)

        # Update batch-norm parameters
        # `modules()` already yields the model and all of its descendants, so a
        # single pass touches every BatchNorm2d layer exactly once.
        for module in self.model.modules():
            if isinstance(module, nn.BatchNorm2d):
                module.eps = 1e-3
                module.momentum = 0.03

        return self.model
|