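"""EXAONE Path 2.0: hierarchical feature extraction for whole-slide images (WSIs).

A slide is cut into small tiles that are embedded by a tile-level ViT (stage 1);
tile features are aggregated per large tile (stage 2) and then across the whole
slide (stage 3) to produce a single slide-level representation.
"""
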
import math
import typing as t
from functools import partial
from pathlib import Path

import torch
import torch.nn as nn
from cucim import CuImage
from huggingface_hub import PyTorchModelHubMixin
from torchvision.transforms import functional as TF
from torchvision.transforms import v2 as T

from networks.vit import vit4k_base, vit_base, vit_global_base
from utils.tensor_utils import (
format_first_stg_act_as_second_stg_inp,
format_second_stg_act_as_third_stg_inp,
forward_with_batch_size_limit,
scale_and_normalize,
tile,
)
from utils.wsi_utils import load_slide_img, segment_tissue

if t.TYPE_CHECKING:
    from _typeshed import StrPath


class Transform(T.Transform):
    # torchvision <= 0.20 dispatches to `_transform`; route it to the newer
    # public `transform` so subclasses only need to implement the latter.
    def _transform(self, inpt, params):
        return self.transform(inpt, params)


class PadToDivisible(Transform):
    """Pad the bottom/right of an image tensor so H and W become multiples of `size`.

    E.g. with size=256, a (3, 1000, 1000) input is padded to (3, 1024, 1024).
    """

    def __init__(self, size: int, pad_value: float | None = None):
        super().__init__()
        self.size = size
        self.pad_value = pad_value

    def transform(self, inpt, params):
assert isinstance(inpt, torch.Tensor) and inpt.ndim >= 3
H, W = inpt.shape[-2:]
pad_h = (self.size - H % self.size) % self.size
pad_w = (self.size - W % self.size) % self.size
        if pad_h > 0 or pad_w > 0:
            # F.pad pairs run (left, right, top, bottom): pad only the right and
            # bottom edges so tile origins stay aligned to the top-left corner.
            inpt = torch.nn.functional.pad(
                inpt, (0, pad_w, 0, pad_h), value=self.pad_value
            )
        return inpt


class EXAONEPathV20(nn.Module, PyTorchModelHubMixin):
    """Three-stage hierarchical ViT: small-tile, large-tile, and slide-level features."""

    def __init__(
        self,
        small_tile_size: int = 256,
        large_tile_size: int = 4096,
    ):
        super().__init__()
        self.small_tile_size = small_tile_size
        self.large_tile_size = large_tile_size
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model_first_stg = vit_base().to(self.device).eval()
        self.model_second_stg = vit4k_base().to(self.device).eval()
        self.model_third_stg = vit_global_base().to(self.device).eval()

    def forward(
self,
svs_path: "StrPath",
target_mpp: float = 0.5,
first_stg_batch_size: int = 128,
):
small_tiles, is_tile_valid, padded_size, small_tile_size, large_tile_size = (
self._load_wsi(svs_path, target_mpp=target_mpp)
)
width, height = padded_size
        # Autocast must target the same device the models live on (cuda or cpu).
        with torch.autocast(device_type=self.device.type, dtype=torch.bfloat16):
            with torch.no_grad():
act1 = forward_with_batch_size_limit(
self.model_first_stg,
small_tiles,
batch_size_on_gpu=first_stg_batch_size,
preproc_fn=partial(
_preproc,
small_tile_size_with_this_mpp=small_tile_size,
small_tile_size_with_target_mpp=self.small_tile_size,
),
device=self.device,
out_device="cpu",
dtype=torch.bfloat16,
)
act1 = act1.to(self.device)
act1_formatted = format_first_stg_act_as_second_stg_inp(
act1,
height=height,
width=width,
small_tile_size=small_tile_size,
large_tile_size=large_tile_size,
)
act2: torch.Tensor = self.model_second_stg(act1_formatted)
act2_formatted = format_second_stg_act_as_third_stg_inp(
act2,
height=height,
width=width,
large_tile_size=large_tile_size,
)
act3: torch.Tensor = self.model_third_stg(act2_formatted)
        # Stage-1 features are filtered to tiles that overlap tissue.
        return act1[is_tile_valid], act2, act3

    def _load_wsi(self, svs_path: "StrPath", target_mpp: float):
        svs_path = str(svs_path)
        # Load the full slide image and read its microns-per-pixel (MPP) metadata
        with CuImage(svs_path) as wsi_obj:
try:
mpp = float(wsi_obj.metadata["aperio"]["MPP"])
except KeyError:
print(
f"Warning: MPP metadata not found, using default value of {target_mpp}"
)
mpp = target_mpp
img = load_slide_img(wsi_obj)
height, width = img.shape[:2]
mask_tensor = torch.from_numpy(
segment_tissue(Path(svs_path), seg_level=-1)[0]
)
mask_tensor = TF.resize(mask_tensor.unsqueeze(0), [height, width]).squeeze(
0
)
x: torch.Tensor = torch.from_numpy(img).permute(2, 0, 1)
        # Small-tile size in native pixels covering the same physical area as
        # `self.small_tile_size` pixels at the target MPP.
        small_tile_size = math.ceil(self.small_tile_size * (target_mpp / mpp))
        # Keep the small-tiles-per-large-tile ratio fixed (e.g. 4096 // 256 = 16).
        large_tile_size = (
            self.large_tile_size // self.small_tile_size
        ) * small_tile_size
        # Pad images with white background (255); pad masks with 0 (non-tissue).
        pad_image = PadToDivisible(large_tile_size, 255)
        pad_mask = PadToDivisible(large_tile_size, 0)
x = pad_image(x)
padded_size = (x.size(-1), x.size(-2))
x = tile(x, small_tile_size)
        mask_padded = pad_mask(mask_tensor.unsqueeze(0))
        mask_tile = tile(mask_padded, small_tile_size).squeeze(1)
        # A tile is valid if it overlaps the tissue mask anywhere.
        is_tile_valid = mask_tile.sum(dim=(1, 2)) > 0
        return x, is_tile_valid, padded_size, small_tile_size, large_tile_size


def _preproc(
x: torch.Tensor,
small_tile_size_with_this_mpp: int,
small_tile_size_with_target_mpp: int,
):
    # Resize tiles cut at the slide's native MPP to the model's expected input size
    if small_tile_size_with_this_mpp != small_tile_size_with_target_mpp:
x = TF.resize(
x,
[small_tile_size_with_target_mpp, small_tile_size_with_target_mpp],
)
# Normalize the input tensor
x = scale_and_normalize(x)
return x
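

# Minimal usage sketch (the checkpoint id and slide path below are hypothetical
# placeholders); loading relies on the standard from_pretrained() classmethod
# provided by PyTorchModelHubMixin.
if __name__ == "__main__":
    model = EXAONEPathV20.from_pretrained("your-org/your-exaonepath-checkpoint")
    tile_feats, region_feats, slide_feat = model("/path/to/slide.svs", target_mpp=0.5)
    print(tile_feats.shape, region_feats.shape, slide_feat.shape)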