manu02 committed (verified)
Commit 5fe125f · 1 Parent(s): 4722048

Update utils/complete_model.py

Files changed (1)
  1. utils/complete_model.py +490 -276
utils/complete_model.py CHANGED
@@ -1,276 +1,490 @@
- import os
- import torch
- import torch.nn as nn
- from transformers import AutoModel, GPT2Tokenizer
-
- from utils.modifiedGPT2 import create_decoder
-
- from utils.layer_mask import gaussian_layer_stack_pipeline
-
-
- class DINOEncoder(nn.Module):
-     def __init__(self, model_id="facebook/dinov3-vits16-pretrain-lvd1689m", freeze=True):
-         super().__init__()
-         self.model = AutoModel.from_pretrained(model_id)
-         if freeze:
-             for p in self.model.parameters():
-                 p.requires_grad = False
-     @torch.no_grad()
-     def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
-         """
-         pixel_values: [B, C, H, W]
-         returns patches: [B, Np, Cenc]
-         """
-         out = self.model(pixel_values=pixel_values)
-         tokens = out.last_hidden_state  # [B, 1+Np, Cenc] (CLS + patches) for ViT-like
-         # Skip a few special tokens if your backbone adds them; adjust as needed.
-         patches = tokens[:, 5:, :]  # [B, Np, Cenc]
-         return patches
-
- class DinoUNet(nn.Module):
-     def __init__(self, model_name="facebook/dinov3-convnext-small-pretrain-lvd1689m", freeze=True):
-         super().__init__()
-         self.encoder = AutoModel.from_pretrained(model_name)
-         # NOTE: confirm channels of the chosen hidden state; 768 is common for small convnext/dinov3
-         self.channel_adapter = nn.Conv2d(768, 512, kernel_size=1)
-         self.decoder = nn.Sequential(
-             nn.Conv2d(512, 256, 3, padding=1), nn.ReLU(inplace=True),
-             nn.ConvTranspose2d(256, 128, 2, stride=2), nn.ReLU(inplace=True),
-             nn.ConvTranspose2d(128, 64, 2, stride=2), nn.ReLU(inplace=True),
-             nn.Conv2d(64, 1, 1)
-         )
-         if freeze:
-             for m in (self.encoder, self.channel_adapter, self.decoder):
-                 for p in m.parameters():
-                     p.requires_grad = False
-
-     @torch.no_grad()
-     def forward(self, x: torch.Tensor, num_layers: int) -> torch.Tensor:
-         """
-         x: [B, C, H, W]; returns mask: [B, 1, H', W'] (your upsampling stack defines H',W')
-         """
-         enc_feats = self.encoder(x, output_hidden_states=True, return_dict=True)
-         # take the last 4D feature map from hidden_states
-         feats = next(h for h in reversed(enc_feats.hidden_states) if isinstance(h, torch.Tensor) and h.ndim == 4)
-         feats = self.channel_adapter(feats)
-         pred = self.decoder(feats)  # (B,1,h,w)
-         _, _, segmentation_mask = gaussian_layer_stack_pipeline(pred, n_layers=num_layers)
-         return segmentation_mask  # [B, num_layers, h, w]
-
-
- class LinearProjection(nn.Module):
-     def __init__(self, input_dim=384, output_dim=768, freeze=False):
-         super().__init__()
-         self.proj = nn.Linear(input_dim, output_dim)
-         if freeze:
-             for p in self.proj.parameters():
-                 p.requires_grad = False
-
-     def forward(self, x: torch.Tensor) -> torch.Tensor:
-         # x: [B, Np, input_dim] -> [B, Np, output_dim]
-         return self.proj(x)
-
-
- class CustomModel(nn.Module):
-     def __init__(
-         self,
-         device: str = "cuda",
-         ENCODER_MODEL_PATH: str | None = "dino_encoder.pth",
-         SEGMENTER_MODEL_PATH: str | None = "dino_segmenter.pth",
-         DECODER_MODEL_PATH: str | None = "dino_decoder.pth",
-         LINEAR_PROJECTION_PATH: str | None = "linear_projection.pth",
-         freeze_encoder: bool = True,
-         freeze_segmenter: bool = True,
-         freeze_linear_projection: bool = False,
-         freeze_decoder: bool = False,
-         attention_implementation: str = "sdpa",
-     ):
-         super().__init__()
-         self.device = torch.device(device)
-
-         # Encoder
-         self.encoder = DINOEncoder()
-         if ENCODER_MODEL_PATH and os.path.exists(ENCODER_MODEL_PATH):
-             self.encoder.load_state_dict(torch.load(ENCODER_MODEL_PATH, map_location="cpu"), strict=False)
-             print("Loaded encoder weights from", ENCODER_MODEL_PATH)
-         if freeze_encoder:
-             self.encoder.eval()
-
-         # Segmenter
-         self.segmenter = DinoUNet()
-         if SEGMENTER_MODEL_PATH and os.path.exists(SEGMENTER_MODEL_PATH):
-             self.segmenter.load_state_dict(torch.load(SEGMENTER_MODEL_PATH, map_location="cpu"), strict=False)
-             print("Loaded segmenter weights from", SEGMENTER_MODEL_PATH)
-         if freeze_segmenter:
-             self.segmenter.eval()
-
-         # Decoder (modified GPT-2)
-         self.decoder = create_decoder(attention=attention_implementation)  # must expose .config.hidden_size & .config.num_hidden_layers
-         if DECODER_MODEL_PATH and os.path.exists(DECODER_MODEL_PATH):
-             self.decoder.load_state_dict(torch.load(DECODER_MODEL_PATH, map_location="cpu"), strict=False)
-             print("Loaded decoder weights from", DECODER_MODEL_PATH)
-         if freeze_decoder:
-             self.decoder.eval()
-
-         # Linear projection: DINO hidden -> GPT2 hidden
-         enc_h = self.encoder.model.config.hidden_size
-         dec_h = self.decoder.config.hidden_size
-         self.linear_projection = LinearProjection(input_dim=enc_h, output_dim=dec_h)
-         if LINEAR_PROJECTION_PATH and os.path.exists(LINEAR_PROJECTION_PATH):
-             self.linear_projection.load_state_dict(torch.load(LINEAR_PROJECTION_PATH, map_location="cpu"), strict=False)
-             print("Loaded linear projection weights from", LINEAR_PROJECTION_PATH)
-         if freeze_linear_projection:
-             self.linear_projection.eval()
-
-         # Tokenizer (pad token for GPT-2)
-         self.tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
-         if self.tokenizer.pad_token_id is None:
-             self.tokenizer.pad_token = self.tokenizer.eos_token
-         self.pad_token_id = self.tokenizer.pad_token_id  # ✅ use ID, not string
-
-         self.num_layers = self.decoder.config.num_hidden_layers
-
-         # move everything once
-         self.to(self.device)
-
-     def forward(self, pixel_values: torch.Tensor, tgt_ids: torch.Tensor | None = None, **kwargs) -> dict:
-         """
-         pixel_values: [B,C,H,W], float
-         tgt_ids: [B,T], long (token IDs), padded with pad_token_id if any padding is present
-         """
-         pixel_values = pixel_values.to(self.device, non_blocking=True)
-
-         # Visual path
-         patches = self.encoder(pixel_values)  # [B,Np,Cenc]
-         projected_patches = self.linear_projection(patches)  # [B,Np,n_embd]
-
-         # Segmentation path per layer
-         segmented_layers = self.segmenter(pixel_values, self.num_layers)  # [B,n_layers,H,W] (per current decoder)
-
-         # Text path (optional teacher-forced training)
-         labels = None
-         if tgt_ids is not None:
-             if tgt_ids.dtype != torch.long:
-                 tgt_ids = tgt_ids.long()
-             tgt_ids = tgt_ids.to(self.device, non_blocking=True)  # [B,T]
-             text_embeds = self.decoder.transformer.wte(tgt_ids)  # [B,T,n_embd]
-             inputs_embeds = torch.cat([projected_patches, text_embeds], dim=1)  # [B,Np+T,n_embd]
-
-             # Labels: ignore prefix tokens (vision) and PADs in text
-             B, Np, _ = projected_patches.shape
-             labels_prefix = torch.full((B, Np), -100, device=self.device, dtype=torch.long)
-             text_labels = tgt_ids.clone()
-             text_labels[text_labels == self.pad_token_id] = -100  # ✅ compare to ID
-             labels = torch.cat([labels_prefix, text_labels], dim=1)  # [B,Np+T]
-         else:
-             inputs_embeds = projected_patches
-
-         # Decoder forward
-         out = self.decoder(inputs_embeds=inputs_embeds, segmentation_mask=segmented_layers, labels=labels, **kwargs)
-         return out
-
-     @torch.inference_mode()
-     def generate(
-         self,
-         pixel_values: torch.Tensor,
-         max_new_tokens: int = 100,
-         output_attentions: bool = False,
-     ) -> torch.Tensor:
-         """
-         pixel_values: [B,C,H,W], float
-         returns generated_ids: [B, T]
-         """
-         pixel_values = pixel_values.to(self.device, non_blocking=True)
-
-         # Visual path
-         patches = self.encoder(pixel_values)  # [B,Np,Cenc]
-         projected_patches = self.linear_projection(patches)  # [B,Np,n_embd]
-
-         # Segmentation path per layer
-         segmented_layers = self.segmenter(pixel_values, self.num_layers)  # [B,n_layers,H,W] (per current decoder)
-
-         # Generate
-         output = self.decoder.generate(
-             inputs_embeds=projected_patches,
-             max_new_tokens=max_new_tokens,
-             do_sample=False,
-             repetition_penalty=1.2,
-             eos_token_id=self.tokenizer.eos_token_id,
-             pad_token_id=self.pad_token_id,
-             use_cache=True,
-             segmentation_mask=segmented_layers,
-             prefix_allowed_length=0,
-             plot_attention_mask=False,
-             plot_attention_mask_layer=[],
-             plot_attention_map=False,
-             plot_attention_map_layer=[],
-             plot_attention_map_generation=0,
-             output_attentions=output_attentions,
-             return_dict_in_generate=True,
-         )
-         # Remove prefix tokens (vision)
-         generated_ids = output.sequences#[:, projected_patches.shape[1]:]  # [B,T]
-         generated_text = self.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
-         return generated_ids, generated_text, output.attentions if output_attentions else None
-
- def create_complete_model(device: str = "cuda", **kwargs) -> CustomModel:
-     model = CustomModel(device=device, **kwargs)
-     return model
-
- def save_complete_model(model: CustomModel, save_path: str, device: str = "cuda") -> None:
-     # Ensure folder exists
-     os.makedirs(os.path.dirname(save_path) or ".", exist_ok=True)
-
-     # Save on CPU to keep checkpoint portable
-     orig_device = next(model.parameters()).device
-     model.to("cpu")
-     torch.save(model.state_dict(), save_path)
-     print(f"Saved complete model weights to {save_path}")
-
-     # Restore model device
-     model.to(device if isinstance(device, str) else orig_device)
-
- def save_checkpoint(model: CustomModel, optimizer: torch.optim.Optimizer, save_path: str) -> None:
-     # Ensure folder exists
-     os.makedirs(os.path.dirname(save_path) or ".", exist_ok=True)
-
-     checkpoint = {
-         "model_state_dict": model.state_dict(),
-         "optimizer_state_dict": optimizer.state_dict(),
-     }
-     torch.save(checkpoint, save_path)
-     print(f"Saved checkpoint to {save_path}")
-
- def load_complete_model(model: CustomModel, load_path: str, device: str = "cpu", strict: bool = True) -> CustomModel:
-     if not os.path.exists(load_path):
-         print(f"No weights found at {load_path}")
-         model.to(device)
-         return model
-
-     # Load to CPU first, then move to target device
-     state = torch.load(load_path, map_location="cpu")
-     missing, unexpected = model.load_state_dict(state, strict=strict)
-     if not strict:
-         if missing:
-             print(f"[load warning] Missing keys: {missing}")
-         if unexpected:
-             print(f"[load warning] Unexpected keys: {unexpected}")
-
-     model.to(device)
-     print(f"Loaded complete model weights from {load_path}")
-     return model
-
- def load_checkpoint(model: CustomModel, optimizer: torch.optim.Optimizer, load_path: str, device: str = "cpu") -> tuple[CustomModel, torch.optim.Optimizer]:
-     if not os.path.exists(load_path):
-         print(f"No checkpoint found at {load_path}")
-         model.to(device)
-         return model, optimizer
-
-     # Load to CPU first, then move to target device
-     checkpoint = torch.load(load_path, map_location="cpu")
-     model.load_state_dict(checkpoint["model_state_dict"])
-     optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
-
-     model.to(device)
-     print(f"Loaded checkpoint from {load_path}")
-     return model, optimizer
+ import os
+ import torch
+ import torch.nn as nn
+ from transformers import AutoModel, GPT2Tokenizer
+
+ from utils.modifiedGPT2 import create_decoder
+
+ from utils.layer_mask import gaussian_layer_stack_pipeline
+
+ class DINOEncoder(nn.Module):
+     def __init__(self, model_id="facebook/dinov3-vits16-pretrain-lvd1689m", freeze=True):
+         super().__init__()
+         self.model = AutoModel.from_pretrained(model_id)
+         if freeze:
+             for p in self.model.parameters():
+                 p.requires_grad = False
+     @torch.no_grad()
+     def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
+         """
+         pixel_values: [B, C, H, W]
+         returns patches: [B, Np, Cenc]
+         """
+         out = self.model(pixel_values=pixel_values)
+         tokens = out.last_hidden_state  # [B, 1+Np, Cenc] (CLS + patches) for ViT-like
+         # Skip a few special tokens if your backbone adds them; adjust as needed.
+         patches = tokens[:, 5:, :]  # [B, Np, Cenc]
+         return patches
+
+ class DinoUNet(nn.Module):
+     def __init__(self, model_name="facebook/dinov3-convnext-small-pretrain-lvd1689m", freeze=True):
+         super().__init__()
+         self.encoder = AutoModel.from_pretrained(model_name)
+         # NOTE: confirm channels of the chosen hidden state; 768 is common for small convnext/dinov3
+         self.channel_adapter = nn.Conv2d(768, 512, kernel_size=1)
+         self.decoder = nn.Sequential(
+             nn.Conv2d(512, 256, 3, padding=1), nn.ReLU(inplace=True),
+             nn.ConvTranspose2d(256, 128, 2, stride=2), nn.ReLU(inplace=True),
+             nn.ConvTranspose2d(128, 64, 2, stride=2), nn.ReLU(inplace=True),
+             nn.Conv2d(64, 1, 1)
+         )
+         if freeze:
+             for m in (self.encoder, self.channel_adapter, self.decoder):
+                 for p in m.parameters():
+                     p.requires_grad = False
+
+     @torch.no_grad()
+     def forward(self, x: torch.Tensor, num_layers: int) -> torch.Tensor:
+         """
+         x: [B, C, H, W]; returns segmentation masks: [B, num_layers, h, w]
+         (h, w are set by the decoder's upsampling stack)
+         """
+         enc_feats = self.encoder(x, output_hidden_states=True, return_dict=True)
+         # take the last 4D feature map from hidden_states
+         feats = next(h for h in reversed(enc_feats.hidden_states) if isinstance(h, torch.Tensor) and h.ndim == 4)
+         feats = self.channel_adapter(feats)
+         pred = self.decoder(feats)  # (B,1,h,w)
+         _, _, segmentation_mask = gaussian_layer_stack_pipeline(pred, n_layers=num_layers)
+         return segmentation_mask  # [B, num_layers, h, w]
+
+
+ class LinearProjection(nn.Module):
+     def __init__(self, input_dim=384, output_dim=768, freeze=False):
+         super().__init__()
+         self.proj = nn.Linear(input_dim, output_dim)
+         if freeze:
+             for p in self.proj.parameters():
+                 p.requires_grad = False
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         # x: [B, Np, input_dim] -> [B, Np, output_dim]
+         return self.proj(x)
+
+
+ class CustomModel(nn.Module):
+     def __init__(
+         self,
+         device: str = "cuda",
+         ENCODER_MODEL_PATH: str | None = "dino_encoder.pth",
+         SEGMENTER_MODEL_PATH: str | None = "dino_segmenter.pth",
+         DECODER_MODEL_PATH: str | None = "dino_decoder.pth",
+         LINEAR_PROJECTION_PATH: str | None = "linear_projection.pth",
+         freeze_encoder: bool = True,
+         freeze_segmenter: bool = True,
+         freeze_linear_projection: bool = False,
+         freeze_decoder: bool = False,
+         attention_implementation: str = "sdpa",
+     ):
+         super().__init__()
+         self.device = torch.device(device)
+
+         # Encoder
+         self.encoder = DINOEncoder()
+         if ENCODER_MODEL_PATH and os.path.exists(ENCODER_MODEL_PATH):
+             self.encoder.load_state_dict(torch.load(ENCODER_MODEL_PATH, map_location="cpu"), strict=False)
+             print("Loaded encoder weights from", ENCODER_MODEL_PATH)
+         if freeze_encoder:
+             self.encoder.eval()
+
+         # Segmenter
+         self.segmenter = DinoUNet()
+         if SEGMENTER_MODEL_PATH and os.path.exists(SEGMENTER_MODEL_PATH):
+             self.segmenter.load_state_dict(torch.load(SEGMENTER_MODEL_PATH, map_location="cpu"), strict=False)
+             print("Loaded segmenter weights from", SEGMENTER_MODEL_PATH)
+         if freeze_segmenter:
+             self.segmenter.eval()
+
+         # Decoder (modified GPT-2)
+         self.decoder = create_decoder(attention=attention_implementation)  # must expose .config.hidden_size & .config.num_hidden_layers
+         if DECODER_MODEL_PATH and os.path.exists(DECODER_MODEL_PATH):
+             self.decoder.load_state_dict(torch.load(DECODER_MODEL_PATH, map_location="cpu"), strict=False)
+             print("Loaded decoder weights from", DECODER_MODEL_PATH)
+         if freeze_decoder:
+             self.decoder.eval()
+
+         # Linear projection: DINO hidden -> GPT2 hidden
+         enc_h = self.encoder.model.config.hidden_size
+         dec_h = self.decoder.config.hidden_size
+         self.linear_projection = LinearProjection(input_dim=enc_h, output_dim=dec_h)
+         if LINEAR_PROJECTION_PATH and os.path.exists(LINEAR_PROJECTION_PATH):
+             self.linear_projection.load_state_dict(torch.load(LINEAR_PROJECTION_PATH, map_location="cpu"), strict=False)
+             print("Loaded linear projection weights from", LINEAR_PROJECTION_PATH)
+         if freeze_linear_projection:
+             self.linear_projection.eval()
+
+         # Tokenizer (pad token for GPT-2)
+         self.tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
+         if self.tokenizer.pad_token_id is None:
+             self.tokenizer.pad_token = self.tokenizer.eos_token
+         self.pad_token_id = self.tokenizer.pad_token_id  # ✅ use ID, not string
+
+         self.num_layers = self.decoder.config.num_hidden_layers
+
+         # move everything once
+         self.to(self.device)
+
+     def forward(self, pixel_values: torch.Tensor, tgt_ids: torch.Tensor | None = None, **kwargs) -> dict:
+         """
+         pixel_values: [B,C,H,W], float
+         tgt_ids: [B,T], long (token IDs), padded with pad_token_id if any padding is present
+         """
+         pixel_values = pixel_values.to(self.device, non_blocking=True)
+
+         # Visual path
+         patches = self.encoder(pixel_values)  # [B,Np,Cenc]
+         projected_patches = self.linear_projection(patches)  # [B,Np,n_embd]
+
+         # Segmentation path per layer
+         segmented_layers = self.segmenter(pixel_values, self.num_layers)  # [B,n_layers,H,W] (per current decoder)
+
+         # Text path (optional teacher-forced training)
+         labels = None
+         if tgt_ids is not None:
+             if tgt_ids.dtype != torch.long:
+                 tgt_ids = tgt_ids.long()
+             tgt_ids = tgt_ids.to(self.device, non_blocking=True)  # [B,T]
+             text_embeds = self.decoder.transformer.wte(tgt_ids)  # [B,T,n_embd]
+             inputs_embeds = torch.cat([projected_patches, text_embeds], dim=1)  # [B,Np+T,n_embd]
+
+             # Labels: ignore prefix tokens (vision) and PADs in text
+             B, Np, _ = projected_patches.shape
+             labels_prefix = torch.full((B, Np), -100, device=self.device, dtype=torch.long)
+             text_labels = tgt_ids.clone()
+             text_labels[text_labels == self.pad_token_id] = -100  # ✅ compare to ID
+             labels = torch.cat([labels_prefix, text_labels], dim=1)  # [B,Np+T]
+         else:
+             inputs_embeds = projected_patches
+
+         # Decoder forward
+         out = self.decoder(inputs_embeds=inputs_embeds, segmentation_mask=segmented_layers, labels=labels, **kwargs)
+         return out
+
+     @torch.inference_mode()
+     def generate(
+         self,
+         pixel_values: torch.Tensor,
+         max_new_tokens: int = 100,
+         output_attentions: bool = False,
+     ) -> tuple:
+         """
+         pixel_values: [B,C,H,W], float
+         returns (generated_ids: [B, T], generated_text: list[str], attentions or None)
+         """
+         pixel_values = pixel_values.to(self.device, non_blocking=True)
+
+         # Visual path
+         patches = self.encoder(pixel_values)  # [B,Np,Cenc]
+         projected_patches = self.linear_projection(patches)  # [B,Np,n_embd]
+
+         # Segmentation path per layer
+         segmented_layers = self.segmenter(pixel_values, self.num_layers)  # [B,n_layers,H,W] (per current decoder)
+
+         # Generate
+         output = self.decoder.generate(
+             inputs_embeds=projected_patches,
+             max_new_tokens=max_new_tokens,
+             do_sample=False,
+             repetition_penalty=1.2,
+             eos_token_id=self.tokenizer.eos_token_id,
+             pad_token_id=self.pad_token_id,
+             use_cache=True,
+             segmentation_mask=segmented_layers,
+             prefix_allowed_length=0,
+             plot_attention_mask=False,
+             plot_attention_mask_layer=[],
+             plot_attention_map=False,
+             plot_attention_map_layer=[],
+             plot_attention_map_generation=0,
+             output_attentions=output_attentions,
+             return_dict_in_generate=True,
+         )
+         # output.sequences holds only the newly generated tokens when generation
+         # starts from inputs_embeds, so no vision-prefix slice is needed here.
+         generated_ids = output.sequences  # [B, T]
+         generated_text = self.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
+         return generated_ids, generated_text, output.attentions if output_attentions else None
+
+ def create_complete_model(device: str = "cuda", **kwargs) -> CustomModel:
+     model = CustomModel(device=device, **kwargs)
+     return model
+
+ def save_complete_model(model: CustomModel, save_path: str, device: str = "cuda") -> None:
+     # Ensure folder exists
+     os.makedirs(os.path.dirname(save_path) or ".", exist_ok=True)
+
+     # Save on CPU to keep checkpoint portable
+     orig_device = next(model.parameters()).device
+     model.to("cpu")
+     torch.save(model.state_dict(), save_path)
+     print(f"Saved complete model weights to {save_path}")
+
+     # Restore model device
+     model.to(device if isinstance(device, str) else orig_device)
+
+ def save_checkpoint(model: CustomModel, optimizer: torch.optim.Optimizer, save_path: str) -> None:
+     # Ensure folder exists
+     os.makedirs(os.path.dirname(save_path) or ".", exist_ok=True)
+
+     checkpoint = {
+         "model_state_dict": model.state_dict(),
+         "optimizer_state_dict": optimizer.state_dict(),
+     }
+     torch.save(checkpoint, save_path)
+     print(f"Saved checkpoint to {save_path}")
+
+ def load_complete_model(model: CustomModel, load_path: str, device: str = "cpu", strict: bool = True) -> CustomModel:
+     if not os.path.exists(load_path):
+         print(f"No weights found at {load_path}")
+         model.to(device)
+         return model
+
+     # Load to CPU first, then move to target device
+     state = torch.load(load_path, map_location="cpu")
+     missing, unexpected = model.load_state_dict(state, strict=strict)
+     if not strict:
+         if missing:
+             print(f"[load warning] Missing keys: {missing}")
+         if unexpected:
+             print(f"[load warning] Unexpected keys: {unexpected}")
+
+     model.to(device)
+     print(f"Loaded complete model weights from {load_path}")
+     return model
+
+ def load_checkpoint(model: CustomModel, optimizer: torch.optim.Optimizer, load_path: str, device: str = "cpu") -> tuple[CustomModel, torch.optim.Optimizer]:
+     if not os.path.exists(load_path):
+         print(f"No checkpoint found at {load_path}")
+         model.to(device)
+         return model, optimizer
+
+     # Load to CPU first, then move to target device
+     checkpoint = torch.load(load_path, map_location="cpu")
+     model.load_state_dict(checkpoint["model_state_dict"])
+     optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
+
+     model.to(device)
+     print(f"Loaded checkpoint from {load_path}")
+     return model, optimizer
+
+ from transformers import AutoImageProcessor
+ from PIL import Image
+ import logging
+ import re
+
+ # Configure basic logging
+ logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
+ logger = logging.getLogger(__name__)
+
+ # ==============================================================================
+ # 1. Architecture Definition (MLP)
+ # ==============================================================================
+ class EmbeddingClassifier(nn.Module):
+     """
+     Flexible MLP Classifier: Input Embeddings -> Hidden Layers -> Logits.
+     """
+     def __init__(self, embedding_dim, num_classes, custom_dims=(512, 256, 256),
+                  activation="gelu", dropout=0.05, bn=False, use_layernorm=True):
+         super().__init__()
+         layers = []
+
+         # First layer: Embeddings -> First hidden dimension
+         layers.append(nn.Linear(embedding_dim, custom_dims[0]))
+         if use_layernorm: layers.append(nn.LayerNorm(custom_dims[0]))
+         elif bn: layers.append(nn.BatchNorm1d(custom_dims[0]))
+         layers.append(nn.GELU() if activation.lower() == "gelu" else nn.ReLU())
+         if dropout > 0: layers.append(nn.Dropout(dropout))
+
+         # Intermediate layers
+         for i in range(len(custom_dims) - 1):
+             layers.append(nn.Linear(custom_dims[i], custom_dims[i + 1]))
+             if use_layernorm: layers.append(nn.LayerNorm(custom_dims[i + 1]))
+             elif bn: layers.append(nn.BatchNorm1d(custom_dims[i + 1]))
+             layers.append(nn.GELU() if activation.lower() == "gelu" else nn.ReLU())
+             if dropout > 0: layers.append(nn.Dropout(dropout))
+
+         # Final layer: Last hidden dim -> Num classes (Logits)
+         layers.append(nn.Linear(custom_dims[-1], num_classes))
+         self.classifier = nn.Sequential(*layers)
+
+     def forward(self, embeddings):
+         return self.classifier(embeddings)
+
+
+ # ==============================================================================
+ # 2. Prediction Wrapper Class
+ # ==============================================================================
+ class ChestXrayPredictor:
+     """
+     Wrapper class responsible for receiving an image, processing it,
+     and returning class probabilities.
+     """
+
+     def __init__(self, base_model, classifier, processor, label_cols, device):
+         self.base_model = base_model
+         self.classifier = classifier
+         self.processor = processor
+         self.label_cols = label_cols
+         self.device = device
+
+         # Ensure models are in eval mode
+         self.base_model.eval()
+         self.classifier.eval()
+
+     def predict(self, image_source):
+         """
+         Runs inference on a single image.
+
+         Args:
+             image_source: File path (str) or PIL.Image object.
+
+         Returns:
+             dict: { "Class_Name": probability (0.0 - 1.0) }
+         """
+         try:
+             # 1. Flexible Input Handling (Path or Object)
+             if isinstance(image_source, str):
+                 image = Image.open(image_source).convert('RGB')
+             else:
+                 image = image_source.convert('RGB')
+
+             # 2. Preprocessing
+             inputs = self.processor(images=image, return_tensors="pt")
+             pixel_values = inputs['pixel_values'].to(self.device)
+
+             # 3. Inference
+             with torch.no_grad():
+                 # A. Get Embeddings from DINO
+                 outputs = self.base_model(pixel_values=pixel_values)
+
+                 # Handle different transformer output formats
+                 if hasattr(outputs, 'last_hidden_state'):
+                     embeddings = outputs.last_hidden_state.mean(dim=1)
+                 else:
+                     embeddings = outputs[0].mean(dim=1)
+
+                 # B. Classify Embeddings
+                 logits = self.classifier(embeddings)
+
+                 # Convert to standard Python float list for JSON serialization
+                 probs = torch.sigmoid(logits).cpu().numpy()[0].tolist()
+
+             # 4. Format Output
+             return {
+                 label: round(prob, 4)
+                 for label, prob in zip(self.label_cols, probs)
+             }
+
+         except Exception as e:
+             logger.error(f"Error predicting image: {e}")
+             return {"error": str(e)}
+
+
+ # ==============================================================================
+ # 3. Factory Function (The "Builder")
+ # ==============================================================================
+ def create_classifier(checkpoint_path, model_id="facebook/dinov3-vits16-pretrain-lvd1689m", device=None):
+     """
+     Loads the checkpoint, reconstructs the specific architecture,
+     and returns a ready-to-use ChestXrayPredictor instance.
+
+     Args:
+         checkpoint_path (str): Path to the .pth file.
+         model_id (str): HuggingFace model ID for DINO.
+         device (str, optional): 'cuda' or 'cpu'. Auto-detects if None.
+
+     Returns:
+         ChestXrayPredictor: Initialized object ready for prediction.
+     """
+     device = device or ('cuda' if torch.cuda.is_available() else 'cpu')
+     logger.info(f"🔄 Starting model initialization on: {device}")
+
+     try:
+         # A. Load Checkpoint
+         checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)
+         label_cols = checkpoint.get('label_cols', ["Class_1", "Class_2"])  # Fallback
+
+         # B. Load Base Model (DINO)
+         logger.info("🤖 Loading DINO backbone...")
+         base_model = AutoModel.from_pretrained(model_id).to(device)
+
+         # Load fine-tuned DINO weights if they exist in checkpoint
+         if 'base_model_state_dict' in checkpoint:
+             base_model.load_state_dict(checkpoint['base_model_state_dict'])
+             logger.info(" - Fine-tuned DINO weights loaded from checkpoint.")
+         else:
+             logger.info(" - Using default pre-trained DINO weights.")
+
+         processor = AutoImageProcessor.from_pretrained(model_id)
+
+         # C. Detect Embedding Dimension
+         if hasattr(base_model.config, 'hidden_size'):
+             embedding_dim = base_model.config.hidden_size
+         else:
+             # Dummy inference to detect output size
+             with torch.no_grad():
+                 dummy = torch.randn(1, 3, 224, 224).to(device)
+                 out = base_model(pixel_values=dummy)
+                 embedding_dim = out.last_hidden_state.shape[-1]
+
+         # D. Reconstruct Classifier Architecture
+         logger.info("🏗️ Reconstructing classifier architecture...")
+         model_state = checkpoint['model_state_dict']
+         classifier = _build_mlp_from_state(model_state, embedding_dim)
+
+         # Load classifier weights
+         classifier.load_state_dict(model_state)
+         classifier.to(device)
+
+         logger.info("✅ Model created successfully.")
+
+         # E. Return the Wrapper Instance
+         return ChestXrayPredictor(base_model, classifier, processor, label_cols, device)
+
+     except Exception as e:
+         logger.error(f"❌ Fatal error creating the classifier: {e}")
+         raise e
+
+
+ def _build_mlp_from_state(model_state, embedding_dim):
+     """
+     Private helper function to inspect the state_dict and rebuild the MLP architecture.
+     """
+     linear_layers = []
+     for key, val in model_state.items():
+         # Look for 2D weights (Linear layers) inside the classifier
+         if 'classifier' in key and key.endswith('.weight') and len(val.shape) == 2:
+             match = re.search(r'classifier\.(\d+)\.weight', key)
+             if match:
+                 layer_idx = int(match.group(1))
+                 linear_layers.append((layer_idx, val.shape[1], val.shape[0]))  # idx, in_features, out_features
+
+     if not linear_layers:
+         raise ValueError("No linear layers found in checkpoint. Check architecture.")
+
+     # Sort by layer index to ensure correct order
+     linear_layers.sort(key=lambda x: x[0])
+
+     num_classes = linear_layers[-1][2]
+     hidden_dims = tuple([x[2] for x in linear_layers[:-1]])
+
+     # Detect Normalization types
+     uses_bn = any('running_mean' in k for k in model_state.keys())
+     has_norm = any(k.endswith('.weight') and len(model_state[k].shape) == 1 for k in model_state.keys() if 'classifier' in k)
+     uses_layernorm = has_norm and not uses_bn
+
+     return EmbeddingClassifier(
+         embedding_dim=embedding_dim,
+         num_classes=num_classes,
+         custom_dims=hidden_dims,
+         bn=uses_bn,
+         use_layernorm=uses_layernorm
+     )
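
A minimal usage sketch for the two entry points this file now exposes. It assumes the pretrained backbones can be fetched from the Hugging Face Hub and that the relevant weight files exist locally; the checkpoint and image paths below are hypothetical placeholders, not files shipped with this repo.

    # usage_sketch.py -- illustrative only; paths are placeholders
    import torch
    from utils.complete_model import create_complete_model, create_classifier

    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Captioning model: image batch -> (token IDs, decoded text, attentions).
    model = create_complete_model(device=device)
    pixel_values = torch.randn(1, 3, 224, 224)  # stand-in for a preprocessed image
    generated_ids, generated_text, _ = model.generate(pixel_values, max_new_tokens=50)
    print(generated_text[0])

    # Multi-label classifier: image file -> {label: probability}.
    predictor = create_classifier("checkpoints/classifier.pth")  # hypothetical path
    print(predictor.predict("sample_xray.png"))  # hypothetical image path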