sagar007 committed (verified)
Commit 05f961b · Parent(s): 029a3b8

Upload src/models/lightning_module.py with huggingface_hub
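For context, an upload like this is typically performed with the `huggingface_hub` client. A minimal sketch, assuming a logged-in token and a placeholder repo id (the actual target repo is not shown on this page):

from huggingface_hub import HfApi

api = HfApi()  # picks up the token from `huggingface-cli login`
api.upload_file(
    path_or_fileobj="src/models/lightning_module.py",  # local file
    path_in_repo="src/models/lightning_module.py",     # destination path in the repo
    repo_id="sagar007/your-repo",                      # placeholder, not from this commit
    commit_message="Upload src/models/lightning_module.py with huggingface_hub",
)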

Files changed (1):
  src/models/lightning_module.py (+237, −0)
src/models/lightning_module.py ADDED
@@ -0,0 +1,237 @@
+ """
+ PyTorch Lightning module for Multimodal Gemma training
+ """
+ import torch
+ import lightning as L
+ from typing import Dict, Any
+ from transformers import get_linear_schedule_with_warmup
+ import logging
+
+ from .multimodal_gemma import MultimodalGemma
+
+ logger = logging.getLogger(__name__)
+
+
+ class MultimodalGemmaLightning(L.LightningModule):
+     """Lightning module for Multimodal Gemma training"""
+
+     def __init__(self, config: Dict[str, Any]):
+         super().__init__()
+         self.save_hyperparameters()
+         self.config = config
+
+         # Initialize model
+         self.model = MultimodalGemma(config)
+
+         # Training metrics tracking
+         self.training_step_outputs = []
+         self.validation_step_outputs = []
+
+         # Use automatic optimization (the Lightning default)
+         self.automatic_optimization = True
+
+         logger.info("MultimodalGemmaLightning initialized")
+
+     def forward(self, batch: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
+         """Forward pass"""
+         return self.model(
+             input_ids=batch["input_ids"],
+             attention_mask=batch["attention_mask"],
+             images=batch.get("images"),
+             labels=batch["labels"],
+         )
+
+     def training_step(self, batch: Dict[str, torch.Tensor], batch_idx: int) -> torch.Tensor:
+         """Training step"""
+         outputs = self(batch)
+         loss = outputs["loss"]
+
+         # Log metrics
+         self.log("train/loss", loss, on_step=True, on_epoch=True, prog_bar=True, sync_dist=True)
+         self.log("train/learning_rate", self.optimizers().param_groups[0]["lr"], on_step=True)
+
+         # Store outputs for epoch end
+         self.training_step_outputs.append(loss.detach())
+
+         return loss
+
+     def validation_step(self, batch: Dict[str, torch.Tensor], batch_idx: int) -> torch.Tensor:
+         """Validation step"""
+         outputs = self(batch)
+         loss = outputs["loss"]
+
+         # Log metrics
+         self.log("val/loss", loss, on_step=False, on_epoch=True, prog_bar=True, sync_dist=True)
+
+         # Store outputs for epoch end
+         self.validation_step_outputs.append(loss.detach())
+
+         return loss
+
+     def on_train_epoch_end(self) -> None:
+         """Called at the end of each training epoch"""
+         if self.training_step_outputs:
+             avg_loss = torch.stack(self.training_step_outputs).mean()
+             self.log("train/epoch_loss", avg_loss, prog_bar=False, sync_dist=True)
+             self.training_step_outputs.clear()
+
+     def on_validation_epoch_end(self) -> None:
+         """Called at the end of each validation epoch"""
+         if self.validation_step_outputs:
+             avg_loss = torch.stack(self.validation_step_outputs).mean()
+             self.log("val/epoch_loss", avg_loss, prog_bar=False, sync_dist=True)
+             self.validation_step_outputs.clear()
+
+     def configure_optimizers(self):
+         """Configure optimizer and scheduler"""
+         # Collect trainable parameters with different learning rates
+         param_groups = []
+
+         # Ensure learning rates are floats
+         projector_lr = float(self.config["training"]["projector_lr"])
+         lora_lr = float(self.config["training"]["lora_lr"])
+
+         # Vision projector parameters
+         vision_proj_params = list(self.model.vision_projector.parameters())
+         if vision_proj_params:
+             param_groups.append({
+                 "params": vision_proj_params,
+                 "lr": projector_lr,
+                 "name": "vision_projector",
+             })
+
+         # Audio projector parameters (if enabled)
+         if hasattr(self.model, "audio_projector"):
+             audio_proj_params = list(self.model.audio_projector.parameters())
+             if audio_proj_params:
+                 param_groups.append({
+                     "params": audio_proj_params,
+                     "lr": projector_lr,
+                     "name": "audio_projector",
+                 })
+
+         # LoRA parameters from the language model (the only ones left trainable)
+         lora_params = [
+             param for _, param in self.model.language_model.named_parameters()
+             if param.requires_grad
+         ]
+
+         if lora_params:
+             param_groups.append({
+                 "params": lora_params,
+                 "lr": lora_lr,
+                 "name": "lora_adapters",
+             })
+
+         if not param_groups:
+             raise ValueError("No trainable parameters found!")
+
+         # Log parameter counts
+         for group in param_groups:
+             param_count = sum(p.numel() for p in group["params"])
+             logger.info(f"{group['name']}: {param_count:,} parameters, lr={group['lr']}")
+
+         # Create optimizer; the fused AdamW kernel is opted into via the
+         # `fused` argument (CUDA only), not a separate optimizer class
+         use_fused = (
+             self.config.get("optimization", {}).get("use_fused_adamw", False)
+             and torch.cuda.is_available()
+         )
+         optimizer = torch.optim.AdamW(
+             param_groups,
+             weight_decay=self.config["training"]["weight_decay"],
+             eps=1e-8,
+             betas=(0.9, 0.999),
+             fused=use_fused,
+         )
+
+         # Calculate total steps for the scheduler
+         if self.trainer.datamodule is not None:
+             steps_per_epoch = len(self.trainer.datamodule.train_dataloader())
+         else:
+             # Fallback estimation
+             steps_per_epoch = self.config["training"].get("steps_per_epoch", 1000)
+
+         max_epochs = self.config["training"]["max_epochs"]
+         accumulate_grad_batches = self.config["training"].get("accumulate_grad_batches", 1)
+
+         total_steps = (steps_per_epoch // accumulate_grad_batches) * max_epochs
+         warmup_steps = int(total_steps * self.config["training"]["warmup_ratio"])
+
+         logger.info(f"Scheduler setup: {total_steps} total steps, {warmup_steps} warmup steps")
+
+         # Create scheduler
+         scheduler = get_linear_schedule_with_warmup(
+             optimizer,
+             num_warmup_steps=warmup_steps,
+             num_training_steps=total_steps,
+         )
+
+         return {
+             "optimizer": optimizer,
+             "lr_scheduler": {
+                 "scheduler": scheduler,
+                 "interval": "step",
+                 "frequency": 1,
+                 "name": "learning_rate",
+             },
+         }
+
+     def lr_scheduler_step(self, scheduler, metric):
+         """Custom learning rate scheduler step"""
+         scheduler.step()
+
+     def on_before_optimizer_step(self, optimizer):
+         """Called before the optimizer step"""
+         # Periodically log the global L2 norm of the gradients
+         if self.global_step % 100 == 0:
+             sq_norm_sum = 0.0
+             for param_group in optimizer.param_groups:
+                 for param in param_group["params"]:
+                     if param.grad is not None:
+                         sq_norm_sum += param.grad.detach().norm(2).item() ** 2
+
+             if sq_norm_sum > 0:
+                 self.log("train/grad_norm", sq_norm_sum ** 0.5, on_step=True, prog_bar=False)
+
+     def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
+         """Called when saving a checkpoint"""
+         # Save additional model components
+         checkpoint["model_config"] = self.config
+         checkpoint["tokenizer_vocab_size"] = len(self.model.tokenizer)
+
+     def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
+         """Called when loading a checkpoint"""
+         # Restore model configuration if needed
+         if "model_config" in checkpoint:
+             logger.info("Loaded model configuration from checkpoint")
+
+     def predict_step(self, batch: Dict[str, torch.Tensor], batch_idx: int) -> Dict[str, Any]:
+         """Prediction step for inference"""
+         outputs = self.model.generate(
+             input_ids=batch["input_ids"],
+             attention_mask=batch["attention_mask"],
+             images=batch.get("images"),
+             max_new_tokens=150,
+             temperature=0.7,
+             do_sample=True,
+         )
+
+         # Decode generated text
+         generated_text = []
+         for i, output in enumerate(outputs):
+             # Remove input tokens from the output
+             input_length = batch["input_ids"][i].shape[0]
+             generated_tokens = output[input_length:]
+             text = self.model.tokenizer.decode(generated_tokens, skip_special_tokens=True)
+             generated_text.append(text)
+
+         return {
+             "generated_text": generated_text,
+             "input_ids": batch["input_ids"],
+         }
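
The module reads a nested config dict whose expected keys are only implied by the lookups above. A minimal sketch of that structure, with illustrative values (not taken from this commit), covering the keys `__init__` and `configure_optimizers` access; `MultimodalGemma` itself will read additional keys not shown here:

config = {
    "training": {
        "projector_lr": 1e-4,          # vision/audio projector learning rate
        "lora_lr": 2e-5,               # LoRA adapter learning rate
        "weight_decay": 0.01,
        "max_epochs": 3,
        "warmup_ratio": 0.03,
        "accumulate_grad_batches": 4,  # optional, defaults to 1
        "steps_per_epoch": 1000,       # optional fallback when no datamodule is attached
    },
    "optimization": {
        "use_fused_adamw": True,       # optional, opts into fused AdamW on CUDA
    },
}

And a minimal training sketch, assuming a hypothetical `MultimodalDataModule` that yields the batch keys the module expects (input_ids, attention_mask, labels, and optionally images):

import lightning as L
from src.models.lightning_module import MultimodalGemmaLightning

model = MultimodalGemmaLightning(config)
datamodule = MultimodalDataModule(config)  # hypothetical LightningDataModule

trainer = L.Trainer(
    max_epochs=config["training"]["max_epochs"],
    accumulate_grad_batches=config["training"]["accumulate_grad_batches"],
    precision="bf16-mixed",
)
trainer.fit(model, datamodule=datamodule)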