import math
import typing as t
from functools import partial
from pathlib import Path

import torch
import torch.nn as nn
from cucim import CuImage
from huggingface_hub import PyTorchModelHubMixin
from torchvision.transforms import functional as TF
from torchvision.transforms import v2 as T

from networks.vit import vit4k_base, vit_base, vit_global_base
from utils.tensor_utils import (
    format_first_stg_act_as_second_stg_inp,
    format_second_stg_act_as_third_stg_inp,
    forward_with_batch_size_limit,
    scale_and_normalize,
    tile,
)
from utils.wsi_utils import load_slide_img, segment_tissue

if t.TYPE_CHECKING:
    from _typeshed import StrPath


class Transform(T.Transform):
    # For compatibility with torchvision <= 0.20
    def _transform(self, inpt, params):
        return self.transform(inpt, params)


class PadToDivisible(Transform):
    def __init__(self, size: int, pad_value: float | None = None):
        super().__init__()
        self.size = size
        self.pad_value = pad_value

    def transform(self, inpt, params):
        assert isinstance(inpt, torch.Tensor) and inpt.ndim >= 3

        H, W = inpt.shape[-2:]

        pad_h = (self.size - H % self.size) % self.size
        pad_w = (self.size - W % self.size) % self.size

        if pad_h > 0 or pad_w > 0:
            inpt = torch.nn.functional.pad(
                inpt, (0, pad_w, 0, pad_h), value=self.pad_value
            )

        return inpt


class EXAONEPathV20(nn.Module, PyTorchModelHubMixin):
    def __init__(
        self,
        small_tile_size: int = 256,
        large_tile_size: int = 4096,
    ):
        super().__init__()

        self.small_tile_size = small_tile_size
        self.large_tile_size = large_tile_size
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

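        # Three-stage hierarchical encoder:
        #   1) tile-level ViT over small tiles (default 256 px),
        #   2) region-level ViT over large tiles (default 4096 px) built from
        #      the stage-1 tile features,
        #   3) slide-level ViT aggregating all region features.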
        self.model_first_stg = vit_base().to(self.device).eval()
        self.model_second_stg = vit4k_base().to(self.device).eval()
        self.model_third_stg = vit_global_base().to(self.device).eval()

    def forward(
        self,
        svs_path: "StrPath",
        target_mpp: float = 0.5,
        first_stg_batch_size: int = 128,
    ):
        small_tiles, is_tile_valid, padded_size, small_tile_size, large_tile_size = (
            self._load_wsi(svs_path, target_mpp=target_mpp)
        )
        width, height = padded_size

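        # Run the three stages under bf16 autocast: tile features (act1) are
        # computed in GPU-sized batches, reshaped into large-tile grids for the
        # second stage (act2), then aggregated slide-wide by the third stage (act3).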
        with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
            with torch.no_grad():
                act1 = forward_with_batch_size_limit(
                    self.model_first_stg,
                    small_tiles,
                    batch_size_on_gpu=first_stg_batch_size,
                    preproc_fn=partial(
                        _preproc,
                        small_tile_size_with_this_mpp=small_tile_size,
                        small_tile_size_with_target_mpp=self.small_tile_size,
                    ),
                    device=self.device,
                    out_device="cpu",
                    dtype=torch.bfloat16,
                )
                act1 = act1.to(self.device)
                act1_formatted = format_first_stg_act_as_second_stg_inp(
                    act1,
                    height=height,
                    width=width,
                    small_tile_size=small_tile_size,
                    large_tile_size=large_tile_size,
                )
                act2: torch.Tensor = self.model_second_stg(act1_formatted)
                act2_formatted = format_second_stg_act_as_third_stg_inp(
                    act2,
                    height=height,
                    width=width,
                    large_tile_size=large_tile_size,
                )
                act3: torch.Tensor = self.model_third_stg(act2_formatted)

        return act1[is_tile_valid], act2, act3

    def _load_wsi(self, svs_path: "StrPath", target_mpp: float):
        svs_path = str(svs_path)

        # Load the whole-slide image, read its MPP metadata, and build a tissue mask
        with CuImage(str(svs_path)) as wsi_obj:
            try:
                mpp = float(wsi_obj.metadata["aperio"]["MPP"])
            except KeyError:
                print(
                    f"Warning: MPP metadata not found, using default value of {target_mpp}"
                )
                mpp = target_mpp

            img = load_slide_img(wsi_obj)
            height, width = img.shape[:2]
            mask_tensor = torch.from_numpy(
                segment_tissue(Path(svs_path), seg_level=-1)[0]
            )
            mask_tensor = TF.resize(mask_tensor.unsqueeze(0), [height, width]).squeeze(
                0
            )
            x: torch.Tensor = torch.from_numpy(img).permute(2, 0, 1)

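        # Convert tile sizes from target-MPP pixels to this slide's native
        # resolution: a coarser slide (mpp > target_mpp) yields smaller native
        # tiles, which _preproc later resizes back to `self.small_tile_size`.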
        small_tile_size = math.ceil(self.small_tile_size * (target_mpp / mpp))
        large_tile_size = (
            self.large_tile_size // self.small_tile_size
        ) * small_tile_size
        pad_image = PadToDivisible(large_tile_size, 255)
        pad_mask = PadToDivisible(large_tile_size, 0)

        x = pad_image(x)
        padded_size = (x.size(-1), x.size(-2))

        x = tile(x, small_tile_size)
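        # Mark tiles containing at least one tissue pixel; background-only
        # tiles are dropped from the first-stage features returned by forward().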
        mask_padded = pad_mask(mask_tensor.unsqueeze(0))
        mask_tile = tile(mask_padded, small_tile_size).squeeze(1)
        is_tile_valid = mask_tile.sum(dim=(1, 2)) > 0

        return x, is_tile_valid, padded_size, small_tile_size, large_tile_size


def _preproc(
    x: torch.Tensor,
    small_tile_size_with_this_mpp: int,
    small_tile_size_with_target_mpp: int,
):
    # Scale the input tensor to the target MPP
    if small_tile_size_with_this_mpp != small_tile_size_with_target_mpp:
        x = TF.resize(
            x,
            [small_tile_size_with_target_mpp, small_tile_size_with_target_mpp],
        )

    # Normalize the input tensor
    x = scale_and_normalize(x)

    return x
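

if __name__ == "__main__":
    # Usage sketch (an assumption, not part of the released module): load
    # weights via the PyTorchModelHubMixin `from_pretrained` helper and run
    # the full pipeline on one slide. The repo id and slide path below are
    # placeholders and must be replaced with real values.
    model = EXAONEPathV20.from_pretrained("<hub-repo-id>")  # placeholder repo id
    tile_feats, region_feats, slide_feat = model("path/to/slide.svs", target_mpp=0.5)
    print(tile_feats.shape, region_feats.shape, slide_feat.shape)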