File size: 1,820 Bytes
9954323
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
import torch
import torch.nn as nn


class Exp:
    """
    Configuration class for the page element model.

    Holds the checkpoint name, target device, YOLOX architecture settings,
    inference settings and NMS / thresholding settings as plain instance
    attributes, and lazily builds the YOLOX network through `get_model`.
    """
    def __init__(self):
        # Identification, checkpoint file name and target device.
        self.name = "page-element-v3"
        self.ckpt = "weights.pth"
        # Use the first CUDA device when one is available, otherwise the CPU.
        self.device = "cuda:0" if torch.cuda.is_available() else "cpu"

        # YOLOX architecture parameters (forwarded to the backbone and head
        # constructors in `get_model`).
        self.act = "silu"
        self.depth = 1.00
        self.width = 1.00
        self.labels = ["table", "chart", "title", "infographic", "paragraph", "header_footer"]
        # The number of classes always follows the label list above.
        self.num_classes = len(self.labels)

        # Inference parameters
        # NOTE(review): `size` is presumably the network input resolution and
        # `min_bbox_size` / `normalize_boxes` are consumed by post-processing
        # code outside this class - confirm against the callers.
        self.size = (1024, 1024)
        self.min_bbox_size = 0
        self.normalize_boxes = True

        # NMS & thresholding. These can be updated
        self.conf_thresh = 0.01
        self.iou_thresh = 0.5
        self.class_agnostic = True

        # Per-class thresholds, keyed by the entries of `self.labels`.
        self.thresholds_per_class = {
            "table": 0.1,
            "chart": 0.01,
            "infographic": 0.01,
            "title": 0.1,
            "paragraph": 0.1,
            "header_footer": 0.1,
        }

    def get_model(self):
        """
        Get the YOLOX model.

        The network is built on the first call and stored on `self.model`;
        later calls reuse that instance. On every call, the `eps` and
        `momentum` of each `nn.BatchNorm2d` layer in the model are set to
        1e-3 and 0.03 respectively.

        Returns:
            The YOLOX model held in `self.model`.

        Raises:
            ImportError: if the `yolox` package cannot be imported.
        """
        # Imported lazily so that the configuration itself can be created
        # without the `yolox` package being importable.
        from yolox import YOLOX, YOLOPAFPN, YOLOXHead

        # Build model
        if getattr(self, "model", None) is None:
            in_channels = [256, 512, 1024]
            backbone = YOLOPAFPN(self.depth, self.width, in_channels=in_channels, act=self.act)
            head = YOLOXHead(self.num_classes, self.width, in_channels=in_channels, act=self.act)
            self.model = YOLOX(backbone, head)

        # Update batch-norm parameters.
        # `modules()` already yields every nested submodule, so one pass is
        # enough; wrapping this walk in `Module.apply` would repeat it for
        # each submodule and rewrite the same layers several times.
        for module in self.model.modules():
            if isinstance(module, nn.BatchNorm2d):
                module.eps = 1e-3
                module.momentum = 0.03

        return self.model