sagar007 committed
Commit 9fc75bd · verified · 1 Parent(s): cb26575

Upload folder using huggingface_hub

app.py ADDED
@@ -0,0 +1,349 @@
+ #!/usr/bin/env python3
+ """
+ Gradio UI for Multimodal Gemma Model - Hugging Face Space Version
+ """
+ import sys
+ import torch
+ import gradio as gr
+ from pathlib import Path
+ from PIL import Image
+ import io
+ import time
+ import logging
+ from huggingface_hub import hf_hub_download
+
+ # Model imports
+ from src.models import MultimodalGemmaLightning
+ from src.utils.config import load_config, merge_configs
+
+ # Global model variable
+ model = None
+ config = None
+
+ def download_and_load_model():
+     """Download and load the trained multimodal model from HF"""
+     global model, config
+
+     if model is not None:
+         return "✅ Model already loaded!"
+
+     try:
+         print("🔄 Downloading multimodal Gemma model from HF...")
+
+         # Download model checkpoint
+         checkpoint_path = hf_hub_download(
+             repo_id="sagar007/multimodal-gemma-270m-llava",
+             filename="final_model.ckpt",
+             cache_dir="./model_cache"
+         )
+
+         # Download config files
+         model_config_path = hf_hub_download(
+             repo_id="sagar007/multimodal-gemma-270m-llava",
+             filename="configs/model_config.yaml",
+             cache_dir="./model_cache"
+         )
+         training_config_path = hf_hub_download(
+             repo_id="sagar007/multimodal-gemma-270m-llava",
+             filename="configs/training_config.yaml",
+             cache_dir="./model_cache"
+         )
+         data_config_path = hf_hub_download(
+             repo_id="sagar007/multimodal-gemma-270m-llava",
+             filename="configs/data_config.yaml",
+             cache_dir="./model_cache"
+         )
+
+         # Load configs
+         model_config = load_config(model_config_path)
+         training_config = load_config(training_config_path)
+         data_config = load_config(data_config_path)
+         config = merge_configs([model_config, training_config, data_config])
+
+         print("📁 Loading model from checkpoint...")
+         model = MultimodalGemmaLightning.load_from_checkpoint(
+             checkpoint_path,
+             config=config,
+             strict=False,
+             map_location="cuda" if torch.cuda.is_available() else "cpu"
+         )
+         model.eval()
+
+         # Move to appropriate device
+         device = "cuda" if torch.cuda.is_available() else "cpu"
+         model = model.to(device)
+
+         print(f"✅ Model loaded successfully on {device}!")
+         return f"✅ Model loaded successfully on {device}!"
+
+     except Exception as e:
+         error_msg = f"❌ Error loading model: {str(e)}"
+         print(error_msg)
+         return error_msg
+
+ def predict_with_image(image, question, max_tokens=100, temperature=0.7):
+     """Generate response for image + text input"""
+     global model, config
+
+     if model is None:
+         return "❌ Please load the model first using the 'Load Model' button!"
+
+     if image is None:
+         return "❌ Please upload an image!"
+
+     if not question.strip():
+         question = "What do you see in this image?"
+
+     try:
+         # Get device
+         device = next(model.parameters()).device
+
+         # Process image
+         if isinstance(image, str):
+             image = Image.open(image).convert('RGB')
+         elif not isinstance(image, Image.Image):
+             image = Image.fromarray(image).convert('RGB')
+
+         # Prepare image for model
+         vision_inputs = model.model.vision_processor(
+             images=[image],
+             return_tensors="pt"
+         )
+         pixel_values = vision_inputs["pixel_values"].to(device)
+
+         # Prepare text prompt
+         prompt = f"<image>\nHuman: {question}\nAssistant:"
+
+         # Tokenize text
+         text_inputs = model.model.tokenizer(
+             prompt,
+             return_tensors="pt",
+             padding=True,
+             truncation=True,
+             max_length=256
+         )
+
+         input_ids = text_inputs["input_ids"].to(device)
+         attention_mask = text_inputs["attention_mask"].to(device)
+
+         # Generate response
+         with torch.no_grad():
+             # Use the full multimodal model with image inputs
+             outputs = model.model.generate(
+                 input_ids=input_ids,
+                 attention_mask=attention_mask,
+                 images=pixel_values,
+                 max_new_tokens=min(max_tokens, 150),
+                 temperature=min(max(temperature, 0.1), 2.0),
+                 do_sample=temperature > 0.1,
+                 repetition_penalty=1.1
+             )
+
+         # Decode response
+         input_length = input_ids.shape[1]
+         generated_tokens = outputs[0][input_length:]
+         response = model.model.tokenizer.decode(generated_tokens, skip_special_tokens=True)
+
+         # Clean up response
+         response = response.strip()
+         if not response:
+             response = "I can see the image, but I'm having trouble generating a detailed response."
+
+         return response
+
+     except Exception as e:
+         error_msg = f"❌ Error during inference: {str(e)}"
+         print(error_msg)
+         return error_msg
+
+ def chat_with_image(image, question, history, max_tokens, temperature):
+     """Chat interface function"""
+     if model is None:
+         response = "❌ Please load the model first!"
+     else:
+         response = predict_with_image(image, question, max_tokens, temperature)
+
+     # Add to history - using messages format
+     history.append({"role": "user", "content": question})
+     history.append({"role": "assistant", "content": response})
+     return history, ""
+
+ def create_gradio_interface():
+     """Create the Gradio interface"""
+
+     # Custom CSS for better styling
+     css = """
+     .container {
+         max-width: 1200px;
+         margin: auto;
+         padding: 20px;
+     }
+     .header {
+         text-align: center;
+         margin-bottom: 30px;
+     }
+     .model-info {
+         background-color: #f0f8ff;
+         padding: 15px;
+         border-radius: 10px;
+         margin-bottom: 20px;
+     }
+     """
+
+     with gr.Blocks(css=css, title="Multimodal Gemma Chat") as demo:
+         gr.HTML("""
+         <div class="header">
+             <h1>🎉 Multimodal Gemma-270M Chat</h1>
+             <p>Upload an image and chat with your trained vision-language model!</p>
+             <p><a href="https://huggingface.co/sagar007/multimodal-gemma-270m-llava">🤗 Model</a></p>
+         </div>
+         """)
+
+         # Model status section
+         with gr.Row():
+             with gr.Column():
+                 gr.HTML("""
+                 <div class="model-info">
+                     <h3>📊 Model Info</h3>
+                     <ul>
+                         <li><strong>Base Model:</strong> Google Gemma-270M</li>
+                         <li><strong>Vision:</strong> CLIP ViT-Large</li>
+                         <li><strong>Training:</strong> LLaVA-150K + COCO Images</li>
+                         <li><strong>Parameters:</strong> 18.6M trainable / 539M total</li>
+                     </ul>
+                 </div>
+                 """)
+
+                 # Model loading
+                 load_btn = gr.Button("🚀 Load Model", variant="primary", size="lg")
+                 model_status = gr.Textbox(
+                     label="Model Status",
+                     value="Click 'Load Model' to start",
+                     interactive=False
+                 )
+
+         gr.HTML("<hr>")
+
+         # Main interface
+         with gr.Row():
+             # Left column - Image and controls
+             with gr.Column(scale=1):
+                 image_input = gr.Image(
+                     label="📸 Upload Image",
+                     type="pil",
+                     height=300
+                 )
+
+                 # Example images
+                 gr.HTML("<p><strong>💡 Tip:</strong> Upload any image and ask questions about it</p>")
+
+                 # Generation settings
+                 with gr.Accordion("⚙️ Generation Settings", open=False):
+                     max_tokens = gr.Slider(
+                         minimum=10,
+                         maximum=200,
+                         value=100,
+                         step=10,
+                         label="Max Tokens"
+                     )
+                     temperature = gr.Slider(
+                         minimum=0.1,
+                         maximum=2.0,
+                         value=0.7,
+                         step=0.1,
+                         label="Temperature"
+                     )
+
+             # Right column - Chat interface
+             with gr.Column(scale=2):
+                 chatbot = gr.Chatbot(
+                     label="💬 Chat with Image",
+                     height=400,
+                     show_label=True,
+                     type="messages"
+                 )
+
+                 question_input = gr.Textbox(
+                     label="❓ Ask about the image",
+                     placeholder="What do you see in this image?",
+                     lines=2
+                 )
+
+                 with gr.Row():
+                     submit_btn = gr.Button("💬 Send", variant="primary")
+                     clear_btn = gr.Button("🗑️ Clear Chat")
+
+         # Example prompts
+         with gr.Row():
+             gr.HTML("<h3>💡 Example Questions:</h3>")
+
+         example_questions = [
+             "What do you see in this image?",
+             "Describe the main objects in the picture.",
+             "What colors are prominent in this image?",
+             "Are there any people in the image?",
+             "What's the setting or location?",
+             "What objects are in the foreground?"
+         ]
+
+         # Lay out the example-question buttons three to a row
+         for row_start in range(0, len(example_questions), 3):
+             with gr.Row():
+                 for example in example_questions[row_start:row_start + 3]:
+                     gr.Button(
+                         example,
+                         size="sm"
+                     ).click(
+                         lambda x=example: x,
+                         outputs=question_input
+                     )
+
+         # Footer
+         gr.HTML("""
+         <hr>
+         <div style="text-align: center; margin-top: 20px;">
+             <p><strong>🎯 Your Multimodal Gemma Model</strong></p>
+             <p>Text-only → Vision-Language Model using LLaVA Architecture</p>
+             <p>Model: <a href="https://huggingface.co/sagar007/multimodal-gemma-270m-llava">sagar007/multimodal-gemma-270m-llava</a></p>
+         </div>
+         """)
+
+         # Event handlers
+         load_btn.click(
+             fn=download_and_load_model,
+             outputs=model_status
+         )
+
+         submit_btn.click(
+             fn=chat_with_image,
+             inputs=[image_input, question_input, chatbot, max_tokens, temperature],
+             outputs=[chatbot, question_input]
+         )
+
+         question_input.submit(
+             fn=chat_with_image,
+             inputs=[image_input, question_input, chatbot, max_tokens, temperature],
+             outputs=[chatbot, question_input]
+         )
+
+         clear_btn.click(
+             fn=lambda: ([], ""),
+             outputs=[chatbot, question_input]
+         )
+
+     return demo
+
+ def main():
+     """Main function to launch the Gradio app"""
+     print("🚀 Starting Multimodal Gemma Gradio Space...")
+
+     # Create interface
+     demo = create_gradio_interface()
+
+     # Launch
+     print("🌐 Launching Gradio interface...")
+     demo.launch()
+
+ if __name__ == "__main__":
+     main()
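
For a quick smoke test outside the Gradio UI, the two functions above can be called directly. This is an illustrative sketch only: it assumes it is run from the repository root with the listed dependencies installed, and the image path is a placeholder.

    from PIL import Image
    import app  # the module defined above

    # Downloads the checkpoint and configs, then loads the model onto GPU/CPU
    print(app.download_and_load_model())

    # Ask a question about a local image (placeholder path)
    img = Image.open("example.jpg").convert("RGB")
    print(app.predict_with_image(img, "What do you see in this image?", max_tokens=80, temperature=0.7))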
requirements.txt ADDED
@@ -0,0 +1,12 @@
+ torch>=2.0.0
+ torchvision
+ transformers>=4.36.0
+ accelerate
+ bitsandbytes
+ peft>=0.6.0
+ lightning>=2.0.0
+ gradio>=4.0.0
+ pillow
+ huggingface-hub
+ pyyaml
+ omegaconf
src/__init__.py ADDED
File without changes
src/models/__init__.py ADDED
@@ -0,0 +1,10 @@
+ from .multimodal_gemma import MultimodalGemma
+ from .lightning_module import MultimodalGemmaLightning
+ from .projectors import VisionProjector, AudioProjector
+
+ __all__ = [
+     "MultimodalGemma",
+     "MultimodalGemmaLightning",
+     "VisionProjector",
+     "AudioProjector"
+ ]
src/models/multimodal_gemma.py ADDED
@@ -0,0 +1,323 @@
+ """
+ Multimodal Gemma model implementation
+ """
+ import torch
+ import torch.nn as nn
+ from transformers import (
+     AutoTokenizer,
+     AutoModelForCausalLM,
+     CLIPVisionModel,
+     CLIPProcessor,
+     BitsAndBytesConfig
+ )
+ from peft import LoraConfig, get_peft_model, TaskType
+ from typing import Dict, Any, Optional, Tuple
+ import logging
+
+ from .projectors import VisionProjector
+
+ logger = logging.getLogger(__name__)
+
+
+ class MultimodalGemma(nn.Module):
+     """Multimodal Gemma model with vision capabilities"""
+
+     def __init__(self, config: Dict[str, Any]):
+         super().__init__()
+         self.config = config
+
+         # Initialize tokenizer first
+         self._setup_tokenizer()
+
+         # Initialize language model
+         self._setup_language_model()
+
+         # Initialize vision components
+         self._setup_vision_components()
+
+         # Initialize projectors
+         self._setup_projectors()
+
+         # Freeze encoders
+         self._freeze_encoders()
+
+         # Setup LoRA
+         self._setup_lora()
+
+         logger.info("MultimodalGemma model initialized successfully")
+
+         # Move projectors to the same device as the language model
+         self._move_to_device()
+
+     def _setup_tokenizer(self):
+         """Initialize and configure tokenizer"""
+         model_name = self.config["model"]["gemma_model_name"]
+         self.tokenizer = AutoTokenizer.from_pretrained(
+             model_name,
+             trust_remote_code=True,
+             use_fast=True
+         )
+
+         # Set padding token
+         if self.tokenizer.pad_token is None:
+             self.tokenizer.pad_token = self.tokenizer.eos_token
+             self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
+
+         # Add special tokens
+         special_tokens = self.config.get("special_tokens", {})
+         new_tokens = []
+
+         for token_name, token_value in special_tokens.items():
+             if token_value not in self.tokenizer.get_vocab():
+                 new_tokens.append(token_value)
+
+         if new_tokens:
+             self.tokenizer.add_special_tokens({"additional_special_tokens": new_tokens})
+             logger.info(f"Added special tokens: {new_tokens}")
+
+     def _setup_language_model(self):
+         """Initialize language model with quantization if specified"""
+         model_name = self.config["model"]["gemma_model_name"]
+
+         # Setup quantization config
+         quantization_config = None
+         if self.config["model"].get("use_4bit", False):
+             quantization_config = BitsAndBytesConfig(
+                 load_in_4bit=True,
+                 bnb_4bit_compute_dtype=getattr(torch, self.config["model"]["bnb_4bit_compute_dtype"]),
+                 bnb_4bit_quant_type=self.config["model"]["bnb_4bit_quant_type"],
+                 bnb_4bit_use_double_quant=self.config["model"]["use_nested_quant"]
+             )
+
+         self.language_model = AutoModelForCausalLM.from_pretrained(
+             model_name,
+             quantization_config=quantization_config,
+             torch_dtype=torch.bfloat16,
+             device_map=None,  # Lightning handles device placement
+             trust_remote_code=True,
+             attn_implementation="eager"  # Use eager attention (flash_attn not required)
+         )
+
+         # Resize embeddings if we added special tokens
+         if len(self.tokenizer) > self.language_model.config.vocab_size:
+             self.language_model.resize_token_embeddings(len(self.tokenizer))
+             logger.info(f"Resized embeddings to {len(self.tokenizer)}")
+
+         # Store image token ID for later use
+         self.image_token_id = self.tokenizer.convert_tokens_to_ids(
+             self.config.get("special_tokens", {}).get("image_token", "<image>")
+         )
+
+     def _setup_vision_components(self):
+         """Initialize vision encoder and processor"""
+         vision_model_name = self.config["model"]["vision_model_name"]
+
+         self.vision_encoder = CLIPVisionModel.from_pretrained(
+             vision_model_name,
+             torch_dtype=torch.bfloat16
+         )
+         self.vision_processor = CLIPProcessor.from_pretrained(vision_model_name)
+
+         logger.info(f"Loaded vision model: {vision_model_name}")
+
+     def _setup_projectors(self):
+         """Initialize projection layers"""
+         vision_dim = self.vision_encoder.config.hidden_size
+         language_dim = self.language_model.config.hidden_size
+
+         # Vision projector
+         self.vision_projector = VisionProjector(
+             vision_dim=vision_dim,
+             language_dim=language_dim,
+             hidden_dim=self.config["model"].get("projector_hidden_dim", language_dim)
+         ).to(torch.bfloat16)  # Match the model dtype
+
+         logger.info("Initialized vision projection layer")
+
+     def _freeze_encoders(self):
+         """Freeze vision encoder"""
+         # Freeze vision encoder
+         for param in self.vision_encoder.parameters():
+             param.requires_grad = False
+
+         logger.info("Froze vision encoder parameters")
+
+     def _setup_lora(self):
+         """Setup LoRA for the language model"""
+         lora_config = LoraConfig(
+             r=self.config["model"]["lora"]["r"],
+             lora_alpha=self.config["model"]["lora"]["alpha"],
+             target_modules=self.config["model"]["lora"]["target_modules"],
+             lora_dropout=self.config["model"]["lora"]["dropout"],
+             bias="none",
+             task_type=TaskType.CAUSAL_LM,
+         )
+
+         self.language_model = get_peft_model(self.language_model, lora_config)
+         self.language_model.print_trainable_parameters()
+
+         logger.info("Setup LoRA adapters")
+
+     def _move_to_device(self):
+         """Move all components to the same device as the language model"""
+         device = next(self.language_model.parameters()).device
+
+         # Move vision components
+         self.vision_encoder = self.vision_encoder.to(device)
+         self.vision_projector = self.vision_projector.to(device)
+
+         logger.info(f"Moved vision components to device: {device}")
+
+     def encode_images(self, images: torch.Tensor) -> torch.Tensor:
+         """
+         Encode images using CLIP and project to language space
+
+         Args:
+             images: [batch_size, 3, height, width]
+         Returns:
+             projected_features: [batch_size, language_dim]
+         """
+         with torch.no_grad():
+             vision_outputs = self.vision_encoder(pixel_values=images)
+             # Use the pooled output (CLS token equivalent)
+             image_features = vision_outputs.pooler_output
+
+         # Project to language model space (outside no_grad so the projector stays trainable)
+         projected_features = self.vision_projector(image_features)
+         return projected_features
+
+     def forward(
+         self,
+         input_ids: torch.Tensor,
+         attention_mask: Optional[torch.Tensor] = None,
+         images: Optional[torch.Tensor] = None,
+         labels: Optional[torch.Tensor] = None,
+     ) -> Dict[str, torch.Tensor]:
+         """
+         Forward pass with multimodal inputs
+
+         Args:
+             input_ids: [batch_size, seq_len]
+             attention_mask: [batch_size, seq_len]
+             images: [batch_size, 3, height, width] or None
+             labels: [batch_size, seq_len] or None
+
+         Returns:
+             Dictionary with loss and logits
+         """
+         if images is not None:
+             # Encode images and project to language space
+             image_features = self.encode_images(images)  # [batch_size, language_dim]
+
+             # Replace <image> tokens with actual image features
+             input_embeds, attention_mask, labels = self._merge_image_features(
+                 input_ids, image_features, attention_mask, labels
+             )
+
+             # Forward through language model with merged embeddings
+             outputs = self.language_model(
+                 inputs_embeds=input_embeds,
+                 attention_mask=attention_mask,
+                 labels=labels,
+             )
+         else:
+             # Standard text-only forward pass
+             outputs = self.language_model(
+                 input_ids=input_ids,
+                 attention_mask=attention_mask,
+                 labels=labels,
+             )
+
+         return {
+             "loss": outputs.loss,
+             "logits": outputs.logits,
+         }
+
+     def _merge_image_features(
+         self,
+         input_ids: torch.Tensor,
+         image_features: torch.Tensor,
+         attention_mask: Optional[torch.Tensor] = None,
+         labels: Optional[torch.Tensor] = None,
+     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+         """
+         Merge image features with text embeddings at <image> token positions
+
+         Args:
+             input_ids: [batch_size, seq_len]
+             image_features: [batch_size, language_dim]
+             attention_mask: [batch_size, seq_len]
+             labels: [batch_size, seq_len]
+
+         Returns:
+             input_embeds: [batch_size, seq_len, hidden_size]
+             attention_mask: [batch_size, seq_len]
+             labels: [batch_size, seq_len]
+         """
+         batch_size, seq_len = input_ids.shape
+
+         # Get text embeddings
+         text_embeds = self.language_model.get_input_embeddings()(input_ids)
+
+         # Find positions of <image> tokens
+         image_token_mask = (input_ids == self.image_token_id)
+
+         # Replace <image> token embeddings with projected image features
+         for batch_idx in range(batch_size):
+             image_positions = torch.where(image_token_mask[batch_idx])[0]
+
+             if len(image_positions) > 0:
+                 # Use the first <image> token position (assuming one image per sample)
+                 img_pos = image_positions[0]
+                 text_embeds[batch_idx, img_pos] = image_features[batch_idx]
+
+         return text_embeds, attention_mask, labels
+
+     def generate(
+         self,
+         input_ids: torch.Tensor,
+         attention_mask: Optional[torch.Tensor] = None,
+         images: Optional[torch.Tensor] = None,
+         max_new_tokens: int = 150,
+         temperature: float = 0.7,
+         do_sample: bool = True,
+         **kwargs
+     ) -> torch.Tensor:
+         """Generate text with multimodal context"""
+
+         if images is not None:
+             # Encode images and merge with text embeddings
+             image_features = self.encode_images(images)
+             input_embeds, attention_mask, _ = self._merge_image_features(
+                 input_ids, image_features, attention_mask, None
+             )
+
+             # Generate using language model with merged embeddings
+             with torch.no_grad():
+                 outputs = self.language_model.generate(
+                     inputs_embeds=input_embeds,
+                     attention_mask=attention_mask,
+                     max_new_tokens=max_new_tokens,
+                     temperature=temperature,
+                     do_sample=do_sample,
+                     pad_token_id=self.tokenizer.pad_token_id,
+                     eos_token_id=self.tokenizer.eos_token_id,
+                     **kwargs
+                 )
+         else:
+             # Standard text-only generation
+             with torch.no_grad():
+                 outputs = self.language_model.generate(
+                     input_ids=input_ids,
+                     attention_mask=attention_mask,
+                     max_new_tokens=max_new_tokens,
+                     temperature=temperature,
+                     do_sample=do_sample,
+                     pad_token_id=self.tokenizer.pad_token_id,
+                     eos_token_id=self.tokenizer.eos_token_id,
+                     **kwargs
+                 )
+
+         return outputs
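
Note that src/models/projectors.py is imported above but is not part of this upload. For reference, a minimal sketch of a VisionProjector compatible with the constructor call in `_setup_projectors` (keyword arguments vision_dim, language_dim, hidden_dim; pooled CLIP features in, language-model-sized features out) could look like the following. The two-layer MLP with GELU is an assumption, not the shipped implementation.

    import torch.nn as nn

    class VisionProjector(nn.Module):
        # Hypothetical sketch only; the actual projectors.py is not included in this commit.
        def __init__(self, vision_dim: int, language_dim: int, hidden_dim: int):
            super().__init__()
            self.proj = nn.Sequential(
                nn.Linear(vision_dim, hidden_dim),
                nn.GELU(),
                nn.Linear(hidden_dim, language_dim),
            )

        def forward(self, image_features):
            # [batch_size, vision_dim] -> [batch_size, language_dim]
            return self.proj(image_features)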
src/utils/__init__.py ADDED
@@ -0,0 +1,8 @@
+ from .config import load_config, merge_configs
+ from .logging import setup_logging
+
+ __all__ = [
+     "load_config",
+     "merge_configs",
+     "setup_logging"
+ ]
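
src/utils/logging.py is referenced here but not included in this upload. A minimal stand-in that would satisfy the import could be as simple as the sketch below; the format string is an assumption.

    import logging

    def setup_logging(level: int = logging.INFO) -> None:
        # Hypothetical placeholder for src/utils/logging.py (not part of this commit)
        logging.basicConfig(
            level=level,
            format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
        )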
src/utils/config.py ADDED
@@ -0,0 +1,110 @@
+ """
+ Configuration utilities
+ """
+ import yaml
+ from pathlib import Path
+ from typing import Dict, Any, List, Union
+ import logging
+
+ logger = logging.getLogger(__name__)
+
+
+ def load_config(config_path: Union[str, Path]) -> Dict[str, Any]:
+     """Load configuration from YAML file"""
+     config_path = Path(config_path)
+
+     if not config_path.exists():
+         raise FileNotFoundError(f"Configuration file not found: {config_path}")
+
+     try:
+         with open(config_path, 'r', encoding='utf-8') as file:
+             config = yaml.safe_load(file)
+
+         logger.info(f"Loaded configuration from: {config_path}")
+         return config
+
+     except Exception as e:
+         logger.error(f"Failed to load configuration from {config_path}: {e}")
+         raise
+
+
+ def merge_configs(configs: List[Dict[str, Any]]) -> Dict[str, Any]:
+     """Merge multiple configuration dictionaries"""
+     merged = {}
+
+     for config in configs:
+         merged.update(config)
+
+     logger.info(f"Merged {len(configs)} configuration files")
+     return merged
+
+
+ def save_config(config: Dict[str, Any], save_path: Union[str, Path]) -> None:
+     """Save configuration to YAML file"""
+     save_path = Path(save_path)
+     save_path.parent.mkdir(parents=True, exist_ok=True)
+
+     try:
+         with open(save_path, 'w', encoding='utf-8') as file:
+             yaml.dump(config, file, default_flow_style=False, indent=2)
+
+         logger.info(f"Saved configuration to: {save_path}")
+
+     except Exception as e:
+         logger.error(f"Failed to save configuration to {save_path}: {e}")
+         raise
+
+
+ def validate_config(config: Dict[str, Any]) -> bool:
+     """Validate configuration structure"""
+     required_sections = ["model", "training", "data"]
+
+     for section in required_sections:
+         if section not in config:
+             logger.error(f"Missing required configuration section: {section}")
+             return False
+
+     # Validate model config
+     model_config = config["model"]
+     required_model_keys = ["gemma_model_name", "vision_model_name", "lora"]
+     for key in required_model_keys:
+         if key not in model_config:
+             logger.error(f"Missing required model config key: {key}")
+             return False
+
+     # Validate training config
+     training_config = config["training"]
+     required_training_keys = ["max_epochs", "batch_size", "lora_lr", "projector_lr"]
+     for key in required_training_keys:
+         if key not in training_config:
+             logger.error(f"Missing required training config key: {key}")
+             return False
+
+     # Validate data config
+     data_config = config["data"]
+     required_data_keys = ["dataset_name", "max_length", "image_size"]
+     for key in required_data_keys:
+         if key not in data_config:
+             logger.error(f"Missing required data config key: {key}")
+             return False
+
+     logger.info("Configuration validation passed")
+     return True
+
+
+ def update_config(config: Dict[str, Any], updates: Dict[str, Any]) -> Dict[str, Any]:
+     """Update configuration with new values"""
+     def deep_update(base_dict, update_dict):
+         """Recursively update nested dictionaries"""
+         for key, value in update_dict.items():
+             if isinstance(value, dict) and key in base_dict and isinstance(base_dict[key], dict):
+                 deep_update(base_dict[key], value)
+             else:
+                 base_dict[key] = value
+
+     import copy
+     updated_config = copy.deepcopy(config)
+     deep_update(updated_config, updates)
+
+     logger.info("Configuration updated")
+     return updated_config
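
To make the shape that validate_config expects concrete, here is an illustrative merged configuration written as a Python dict (the form that load_config plus merge_configs would produce). Every value is an assumption chosen for demonstration; the real settings live in the repo's configs/*.yaml files.

    example_config = {
        "model": {
            "gemma_model_name": "google/gemma-3-270m",             # placeholder model id
            "vision_model_name": "openai/clip-vit-large-patch14",  # placeholder vision encoder
            "use_4bit": False,
            "projector_hidden_dim": 1024,
            "lora": {
                "r": 16,
                "alpha": 32,
                "dropout": 0.05,
                "target_modules": ["q_proj", "k_proj", "v_proj", "o_proj"],
            },
        },
        "training": {
            "max_epochs": 1,
            "batch_size": 8,
            "lora_lr": 2e-4,
            "projector_lr": 1e-3,
        },
        "data": {
            "dataset_name": "llava_instruct_150k",  # placeholder dataset name
            "max_length": 256,
            "image_size": 224,
        },
        "special_tokens": {"image_token": "<image>"},
    }

    # Should pass the structural checks defined above
    assert validate_config(example_config)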