Samin7479 commited on
Commit
4090512
·
1 Parent(s): 22d43f9

Initial commit: EN-BN Translation Project

Browse files
Files changed (3) hide show
  1. TESTAPI.py +57 -0
  2. app.py +1 -2
  3. project_2_mt_en_bn.ipynb +1207 -0
TESTAPI.py ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+
3
+ import os
4
+ import requests
5
+
6
+ BASE = os.getenv("ENBN_API_URL", "https://samin7479-en-bn-translator.hf.space")
7
+ HEADERS = {"Content-Type": "application/json"}
8
+
9
def greet():
    """Ping the service's /greet health endpoint.

    Returns the decoded JSON payload on success; on any failure
    (connection error, non-2xx status, bad JSON) returns a dict of
    the form ``{"error": "<message>"}`` instead of raising.
    """
    try:
        response = requests.get(f"{BASE}/greet", headers=HEADERS, timeout=20)
        response.raise_for_status()
        return response.json()
    except Exception as exc:
        return {"error": str(exc)}
16
+
17
def translate(text, max_new_tokens=128, num_beams=4):
    """Translate one English sentence to Bangla via the /translate endpoint.

    Args:
        text: source English string.
        max_new_tokens: generation length cap forwarded to the server.
        num_beams: beam-search width forwarded to the server.

    Returns the translated string (or ``None`` if the response lacks a
    "translation" key); on any failure returns an ``"[error] ..."`` string.
    """
    body = {
        "text": text,
        "max_new_tokens": max_new_tokens,
        "num_beams": num_beams,
        "do_sample": False,
    }
    try:
        response = requests.post(
            f"{BASE}/translate", json=body, headers=HEADERS, timeout=60
        )
        response.raise_for_status()
        return response.json().get("translation")
    except Exception as exc:
        return f"[error] {exc}"
30
+
31
def translate_batch(texts, max_new_tokens=128, num_beams=4):
    """Translate a list of English sentences to Bangla via /translate_batch.

    Args:
        texts: list of source English strings.
        max_new_tokens: generation length cap forwarded to the server.
        num_beams: beam-search width forwarded to the server.

    Returns a list of translations the same length as ``texts``. On any
    failure returns one ``"[error] ..."`` marker per input, so that the
    result always aligns with ``texts``.
    """
    try:
        payload = {
            "texts": texts,
            "max_new_tokens": max_new_tokens,
            "num_beams": num_beams,
            "do_sample": False
        }
        r = requests.post(f"{BASE}/translate_batch", json=payload, headers=HEADERS, timeout=120)
        r.raise_for_status()
        return r.json().get("translations", [])
    except Exception as e:
        # Bug fix: the old code returned a single-element list here, which
        # made the caller's zip(batch, outs) silently drop every input after
        # the first. Emit one error marker per input instead.
        return [f"[error] {e}"] * len(texts)
44
+
45
+ if __name__ == "__main__":
46
+ # quick smoke test
47
+ print("GREET:", greet())
48
+
49
+ en = "How are you today?"
50
+ bn = translate(en)
51
+ print(f"\nSingle:\nEN: {en}\nBN: {bn}")
52
+
53
+ batch = ["Good morning", "Where is the hospital?", "The weather is nice."]
54
+ outs = translate_batch(batch)
55
+ print("\nBatch:")
56
+ for e, b in zip(batch, outs):
57
+ print(f"EN: {e}\nBN: {b}\n")
app.py CHANGED
@@ -24,8 +24,7 @@ try:
24
  except Exception as e:
25
  raise RuntimeError(f"Failed to load model/tokenizer '{mt_pretrained_model_name}': {e}")
26
 
27
- # Optional: be gentle on CPU-only machines
28
- torch.set_num_threads(max(1, (os.cpu_count() or 1)))
29
 
30
  # -------------------------
31
  # FastAPI app + (optional) CORS
 
24
  except Exception as e:
25
  raise RuntimeError(f"Failed to load model/tokenizer '{mt_pretrained_model_name}': {e}")
26
 
27
+
 
28
 
29
  # -------------------------
30
  # FastAPI app + (optional) CORS
project_2_mt_en_bn.ipynb ADDED
@@ -0,0 +1,1207 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "id": "initial_id",
6
+ "metadata": {
7
+ "collapsed": true,
8
+ "id": "initial_id",
9
+ "executionInfo": {
10
+ "status": "error",
11
+ "timestamp": 1757400199832,
12
+ "user_tz": -360,
13
+ "elapsed": 136,
14
+ "user": {
15
+ "displayName": "KARABI KUMARI MEDHA 1604062",
16
+ "userId": "02676772162340716864"
17
+ }
18
+ },
19
+ "outputId": "4ca67866-ac7a-4f35-9d5f-11d09460a5ef",
20
+ "colab": {
21
+ "base_uri": "https://localhost:8080/",
22
+ "height": 383
23
+ },
24
+ "ExecuteTime": {
25
+ "end_time": "2025-09-14T07:00:59.517452Z",
26
+ "start_time": "2025-09-14T07:00:50.959173Z"
27
+ }
28
+ },
29
+ "source": [
30
+ "from typing import Any\n",
31
+ "\n",
32
+ "from pytorch_lightning.utilities.types import STEP_OUTPUT\n",
33
+ "\n",
34
+ "\"\"\" Class 25 | Project 2 | Machine Translation using Pretrained Model\n",
35
+ "\n",
36
+ "Objectives:\n",
37
+ "1. End-to-end machine translation training pipeline\n",
38
+ "2. Fine-tune a pre-trained model for the custom dataset\n",
39
+ "\"\"\"\n",
40
+ "\n",
41
+ "import pytorch_lightning as pl\n",
42
+ "import torch\n",
43
+ "import torch.nn as nn\n",
44
+ "from torch.utils.data import Dataset, DataLoader\n",
45
+ "import pandas as pd\n",
46
+ "from torchmetrics.text import BLEUScore\n",
47
+ "from transformers import AutoTokenizer, AutoModelForSeq2SeqLM"
48
+ ],
49
+ "outputs": [],
50
+ "execution_count": 2
51
+ },
52
+ {
53
+ "metadata": {
54
+ "id": "cd6712aad1b548d7",
55
+ "outputId": "824c43b6-2d5e-480e-da6d-35338820f1fe",
56
+ "colab": {
57
+ "base_uri": "https://localhost:8080/",
58
+ "height": 159
59
+ },
60
+ "executionInfo": {
61
+ "status": "error",
62
+ "timestamp": 1757187729883,
63
+ "user_tz": -360,
64
+ "elapsed": 187,
65
+ "user": {
66
+ "displayName": "Chironjit Banerjee",
67
+ "userId": "04428016465669976257"
68
+ }
69
+ },
70
+ "ExecuteTime": {
71
+ "end_time": "2025-09-14T07:00:59.576987Z",
72
+ "start_time": "2025-09-14T07:00:59.523970Z"
73
+ }
74
+ },
75
+ "cell_type": "code",
76
+ "source": [
77
+ "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
78
+ "device"
79
+ ],
80
+ "id": "cd6712aad1b548d7",
81
+ "outputs": [
82
+ {
83
+ "data": {
84
+ "text/plain": [
85
+ "device(type='cuda')"
86
+ ]
87
+ },
88
+ "execution_count": 3,
89
+ "metadata": {},
90
+ "output_type": "execute_result"
91
+ }
92
+ ],
93
+ "execution_count": 3
94
+ },
95
+ {
96
+ "metadata": {
97
+ "id": "eaa67c3f07ec30e2",
98
+ "ExecuteTime": {
99
+ "end_time": "2025-09-14T07:00:59.585691Z",
100
+ "start_time": "2025-09-14T07:00:59.581345Z"
101
+ }
102
+ },
103
+ "cell_type": "code",
104
+ "source": [
105
+ "\"\"\"Task: English to Bangla \"\"\"\n",
106
+ "\n",
107
+ "mt_pretrained_model_name = \"shhossain/opus-mt-en-to-bn\""
108
+ ],
109
+ "id": "eaa67c3f07ec30e2",
110
+ "outputs": [],
111
+ "execution_count": 4
112
+ },
113
+ {
114
+ "metadata": {},
115
+ "cell_type": "code",
116
+ "outputs": [],
117
+ "execution_count": null,
118
+ "source": [
119
+ "ROOT_DIR = \"E:\\Projects\\DS & ML\"\n",
120
+ "DATA_DIR = os.path.join(ROOT_DIR, \"DS\")\n",
121
+ "dataset_file = os.path.join(DATA_DIR, \"digit_train.csv\")\n",
122
+ "\n",
123
+ "ARTIFACT_FOLDER_NAME = \"model\" # Directory to save models\n",
124
+ "SOURCE_CODE_PATH = os.path.join(\n",
125
+ " os.getcwd(),\n",
126
+ " \"project_2_mt_en_bn.ipynb\",\n",
127
+ ") # Our current notebook file path\n",
128
+ "\n",
129
+ "SOURCE_CODE_ARTIFACT = \"trainer.ipynb\"\n"
130
+ ],
131
+ "id": "eddf18ffeb5bab6c"
132
+ },
133
+ {
134
+ "metadata": {
135
+ "id": "a0d805fe4a8ab875",
136
+ "colab": {
137
+ "base_uri": "https://localhost:8080/",
138
+ "height": 193
139
+ },
140
+ "executionInfo": {
141
+ "status": "error",
142
+ "timestamp": 1757187729919,
143
+ "user_tz": -360,
144
+ "elapsed": 18,
145
+ "user": {
146
+ "displayName": "Chironjit Banerjee",
147
+ "userId": "04428016465669976257"
148
+ }
149
+ },
150
+ "outputId": "cae2f309-54cf-49f1-dee7-786aedc7622d",
151
+ "ExecuteTime": {
152
+ "end_time": "2025-09-14T07:01:02.538952Z",
153
+ "start_time": "2025-09-14T07:00:59.593206Z"
154
+ }
155
+ },
156
+ "cell_type": "code",
157
+ "source": [
158
+ "\"\"\" For NLP tasks, we basically need two entities:\n",
159
+ "1. Tokenizer\n",
160
+ "2. Model\n",
161
+ "\"\"\"\n",
162
+ "\n",
163
+ "tokenizer = AutoTokenizer.from_pretrained(mt_pretrained_model_name)\n",
164
+ "mt_pretrained_model = AutoModelForSeq2SeqLM.from_pretrained(mt_pretrained_model_name)"
165
+ ],
166
+ "id": "a0d805fe4a8ab875",
167
+ "outputs": [],
168
+ "execution_count": 5
169
+ },
170
+ {
171
+ "metadata": {},
172
+ "cell_type": "code",
173
+ "outputs": [],
174
+ "execution_count": null,
175
+ "source": [
176
+ "\"\"\" Hyperparameters: Parameters that are not for neural networks but use to train\n",
177
+ "models. \"\"\"\n",
178
+ "EPOCHS = 3\n",
179
+ "BATCH_SIZE = 32\n",
180
+ "LEARNING_RATE = 2e-5"
181
+ ],
182
+ "id": "adece7449c00450c"
183
+ },
184
+ {
185
+ "metadata": {
186
+ "id": "e93068cfc700f5f8"
187
+ },
188
+ "cell_type": "markdown",
189
+ "source": [
190
+ "# Data"
191
+ ],
192
+ "id": "e93068cfc700f5f8"
193
+ },
194
+ {
195
+ "metadata": {
196
+ "id": "89449c4bacc42140",
197
+ "colab": {
198
+ "base_uri": "https://localhost:8080/",
199
+ "height": 211
200
+ },
201
+ "executionInfo": {
202
+ "status": "error",
203
+ "timestamp": 1757187730045,
204
+ "user_tz": -360,
205
+ "elapsed": 29,
206
+ "user": {
207
+ "displayName": "Chironjit Banerjee",
208
+ "userId": "04428016465669976257"
209
+ }
210
+ },
211
+ "outputId": "a4dc0573-6e6d-4908-ac0f-8813ed1dc901",
212
+ "ExecuteTime": {
213
+ "end_time": "2025-09-14T07:01:02.555974Z",
214
+ "start_time": "2025-09-14T07:01:02.547968Z"
215
+ }
216
+ },
217
+ "cell_type": "code",
218
+ "source": [
219
+ "\"\"\"\n",
220
+ "Sentence: How are you, dude?\n",
221
+ "Tokens: 'How', 'are', 'you', 'dude?'\n",
222
+ "ids: 125, 14, 145, 78\n",
223
+ "max_length = 3\n",
224
+ "ids: [125, 14, 145]\n",
225
+ "\"\"\"\n",
226
+ "\n",
227
+ "class MTDataset(Dataset):\n",
228
+ " def __init__(self, csv_file):\n",
229
+ " self.data = pd.read_csv(csv_file)\n",
230
+ "\n",
231
+ " def __len__(self):\n",
232
+ " return len(self.data)\n",
233
+ "\n",
234
+ " def __getitem__(self, idx):\n",
235
+ " src_text = str(self.data.iloc[idx]['en'])\n",
236
+ " tgt_text = str(self.data.iloc[idx]['bn'])\n",
237
+ "\n",
238
+ " src_encoding = tokenizer(\n",
239
+ " src_text,\n",
240
+ " max_length=128,\n",
241
+ " padding='max_length',\n",
242
+ " truncation=True,\n",
243
+ " return_tensors='pt',\n",
244
+ " )\n",
245
+ "\n",
246
+ " tgt_encoding = tokenizer(\n",
247
+ " tgt_text,\n",
248
+ " max_length=128,\n",
249
+ " padding='max_length',\n",
250
+ " truncation=True,\n",
251
+ " return_tensors='pt'\n",
252
+ " )\n",
253
+ "\n",
254
+ " return {\n",
255
+ " 'src_input_ids': src_encoding['input_ids'].squeeze(),\n",
256
+ " 'src_attention_mask': src_encoding['attention_mask'].squeeze(),\n",
257
+ " 'tgt_input_ids': tgt_encoding['input_ids'].squeeze(),\n",
258
+ " 'tgt_attention_mask': tgt_encoding['attention_mask'].squeeze()\n",
259
+ " }\n",
260
+ "\n",
261
+ "\"\"\"\n",
262
+ "example: How are you, dude?\n",
263
+ "input_ids: 125, 14, 145, 78\n",
264
+ "max_length = 7\n",
265
+ "input_ids: [125, 14, 145, 147, 0, 0, 0]\n",
266
+ "attention_mask: [1, 1, 1, 1, 0, 0, 0]\n",
267
+ "\"\"\""
268
+ ],
269
+ "id": "89449c4bacc42140",
270
+ "outputs": [
271
+ {
272
+ "data": {
273
+ "text/plain": [
274
+ "'\\nexample: How are you, dude?\\ninput_ids: 125, 14, 145, 78\\nmax_length = 7\\ninput_ids: [125, 14, 145, 147, 0, 0, 0]\\nattention_mask: [1, 1, 1, 1, 0, 0, 0]\\n'"
275
+ ]
276
+ },
277
+ "execution_count": 6,
278
+ "metadata": {},
279
+ "output_type": "execute_result"
280
+ }
281
+ ],
282
+ "execution_count": 6
283
+ },
284
+ {
285
+ "metadata": {
286
+ "id": "7dec7cfe5693f5f1",
287
+ "ExecuteTime": {
288
+ "end_time": "2025-09-14T07:01:02.571975Z",
289
+ "start_time": "2025-09-14T07:01:02.567999Z"
290
+ }
291
+ },
292
+ "cell_type": "code",
293
+ "source": [
294
+ "class MTDataModule(pl.LightningDataModule):\n",
295
+ " def __init__(self, train_csv, val_csv, test_csv, batch_size=BATCH_SIZE):\n",
296
+ " super().__init__()\n",
297
+ " self.train_csv = train_csv\n",
298
+ " self.val_csv = val_csv\n",
299
+ " self.test_csv = test_csv\n",
300
+ " self.batch_size = BATCH_SIZE\n",
301
+ "\n",
302
+ " def setup(self, stage=None):\n",
303
+ " self.train_dataset = MTDataset(self.train_csv)\n",
304
+ " self.val_dataset = MTDataset(self.val_csv)\n",
305
+ " self.test_dataset = MTDataset(self.test_csv)\n",
306
+ "\n",
307
+ " def train_dataloader(self):\n",
308
+ " return DataLoader(\n",
309
+ " self.train_dataset,\n",
310
+ " batch_size=self.BATCH_SIZE,\n",
311
+ " shuffle=True\n",
312
+ " )\n",
313
+ "\n",
314
+ " def val_dataloader(self):\n",
315
+ " return DataLoader(\n",
316
+ " self.val_dataset,\n",
317
+ " batch_size=self.BATCH_SIZE,\n",
318
+ " shuffle=False\n",
319
+ " )\n",
320
+ "\n",
321
+ " def test_dataloader(self):\n",
322
+ " return DataLoader(\n",
323
+ " self.test_dataset,\n",
324
+ " batch_size=self.BATCH_SIZE,\n",
325
+ " shuffle=False\n",
326
+ " )"
327
+ ],
328
+ "id": "7dec7cfe5693f5f1",
329
+ "outputs": [],
330
+ "execution_count": 7
331
+ },
332
+ {
333
+ "metadata": {
334
+ "id": "ef2deed7494ec4b4",
335
+ "ExecuteTime": {
336
+ "end_time": "2025-09-14T07:01:02.583942Z",
337
+ "start_time": "2025-09-14T07:01:02.580979Z"
338
+ }
339
+ },
340
+ "cell_type": "code",
341
+ "source": [
342
+ "data_module = MTDataModule(\n",
343
+ " train_csv=r'E:\\Projects\\DS & ML\\EN to BN ML Project\\train.csv',\n",
344
+ " val_csv=r'E:\\Projects\\DS & ML\\EN to BN ML Project\\val.csv',\n",
345
+ " test_csv=r'E:\\Projects\\DS & ML\\EN to BN ML Project\\test.csv',\n",
346
+ " batch_size= BATCH_SIZE\n",
347
+ ")"
348
+ ],
349
+ "id": "ef2deed7494ec4b4",
350
+ "outputs": [],
351
+ "execution_count": 8
352
+ },
353
+ {
354
+ "metadata": {
355
+ "id": "86e90bfb5b63dafe"
356
+ },
357
+ "cell_type": "markdown",
358
+ "source": [
359
+ "# Model"
360
+ ],
361
+ "id": "86e90bfb5b63dafe"
362
+ },
363
+ {
364
+ "metadata": {
365
+ "id": "70ac9ff9786267a5",
366
+ "colab": {
367
+ "base_uri": "https://localhost:8080/",
368
+ "height": 211
369
+ },
370
+ "executionInfo": {
371
+ "status": "error",
372
+ "timestamp": 1757187730281,
373
+ "user_tz": -360,
374
+ "elapsed": 204,
375
+ "user": {
376
+ "displayName": "Chironjit Banerjee",
377
+ "userId": "04428016465669976257"
378
+ }
379
+ },
380
+ "outputId": "7a8e4dc0-2b1e-46ec-8c32-cef34f48ee96",
381
+ "ExecuteTime": {
382
+ "end_time": "2025-09-14T07:01:02.599947Z",
383
+ "start_time": "2025-09-14T07:01:02.591950Z"
384
+ }
385
+ },
386
+ "cell_type": "code",
387
+ "source": [
388
+ "class MTModel(pl.LightningModule):\n",
389
+ " def __init__(self):\n",
390
+ " super().__init__()\n",
391
+ " # load pretrained model\n",
392
+ " self.model = AutoModelForSeq2SeqLM.from_pretrained(mt_pretrained_model_name)\n",
393
+ " # load pretrained tokenizer\n",
394
+ " self.tokenizer = AutoTokenizer.from_pretrained(mt_pretrained_model_name)\n",
395
+ " # learning rate\n",
396
+ " self.learning_rate = 2e-5\n",
397
+ " # loss function\n",
398
+ " self.loss_fn = nn.CrossEntropyLoss(\n",
399
+ " ignore_index=self.tokenizer.pad_token_id\n",
400
+ " )\n",
401
+ " # evaluation metric\n",
402
+ " self.bleu = BLEUScore()\n",
403
+ "\n",
404
+ " def forward(self,\n",
405
+ " src_input_ids,\n",
406
+ " src_attention_mask,\n",
407
+ " tgt_input_ids,\n",
408
+ " tgt_attention_mask\n",
409
+ " ):\n",
410
+ " outputs = self.model(\n",
411
+ " input_ids=src_input_ids,\n",
412
+ " attention_mask=src_attention_mask,\n",
413
+ " decoder_input_ids=tgt_input_ids[:, :-1],\n",
414
+ " decoder_attention_mask=tgt_attention_mask[:, :-1]\n",
415
+ " )\n",
416
+ " return outputs\n",
417
+ "\n",
418
+ " def training_step(self, batch, batch_idx):\n",
419
+ " loss = self.compute_loss(batch, batch_idx, 'train')\n",
420
+ " self.log('train_loss', loss, prog_bar=True)\n",
421
+ " return loss\n",
422
+ "\n",
423
+ " def validation_step(self, batch, batch_idx):\n",
424
+ " loss = self.compute_loss(batch, batch_idx, 'val')\n",
425
+ " self.log('val_loss', loss, prog_bar=True)\n",
426
+ " return loss\n",
427
+ "\n",
428
+ " def test_step(self, batch, batch_idx):\n",
429
+ " loss = self.compute_loss(batch, batch_idx, 'test')\n",
430
+ " self.log('test_loss', loss, prog_bar=True)\n",
431
+ " return loss\n",
432
+ "\n",
433
+ " def configure_optimizers(self):\n",
434
+ " optimizer = torch.optim.AdamW(self.parameters(), lr=self.learning_rate)\n",
435
+ " scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(\n",
436
+ " optimizer,\n",
437
+ " T_max=10\n",
438
+ " )\n",
439
+ " return {'optimizer': optimizer, 'lr_scheduler': scheduler}\n",
440
+ "\n",
441
+ " def compute_loss(self, batch, batch_idx, stage):\n",
442
+ " src_input_ids = batch['src_input_ids']\n",
443
+ " src_attention_mask = batch['src_attention_mask']\n",
444
+ " tgt_input_ids = batch['tgt_input_ids']\n",
445
+ " tgt_attention_mask = batch['tgt_attention_mask']\n",
446
+ "\n",
447
+ " outputs = self(\n",
448
+ " src_input_ids,\n",
449
+ " src_attention_mask,\n",
450
+ " tgt_input_ids,\n",
451
+ " tgt_attention_mask\n",
452
+ " )\n",
453
+ " logits = outputs.logits\n",
454
+ " loss = self.loss_fn(\n",
455
+ " logits.view(-1, logits.size(-1)),\n",
456
+ " tgt_input_ids[:, 1:].contiguous().view(-1)\n",
457
+ " )\n",
458
+ "\n",
459
+ " if stage == 'val' or stage == 'test':\n",
460
+ " preds = torch.argmax(logits, dim=-1)\n",
461
+ " pred_texts = self.tokenizer.batch_decode(preds, skip_special_tokens=True)\n",
462
+ " tgt_texts = self.tokenizer.batch_decode(tgt_input_ids[:, 1:], skip_special_tokens=True)\n",
463
+ " bleu_score = self.bleu(pred_texts, [[tgt] for tgt in tgt_texts])\n",
464
+ " self.log(f'{stage}_bleu', bleu_score, prog_bar=True)\n",
465
+ "\n",
466
+ " return loss\n"
467
+ ],
468
+ "id": "70ac9ff9786267a5",
469
+ "outputs": [],
470
+ "execution_count": 9
471
+ },
472
+ {
473
+ "metadata": {
474
+ "id": "76dccd8fe08376a4",
475
+ "ExecuteTime": {
476
+ "end_time": "2025-09-14T07:01:05.089548Z",
477
+ "start_time": "2025-09-14T07:01:02.605071Z"
478
+ }
479
+ },
480
+ "cell_type": "code",
481
+ "source": [
482
+ "model = MTModel()"
483
+ ],
484
+ "id": "76dccd8fe08376a4",
485
+ "outputs": [],
486
+ "execution_count": 10
487
+ },
488
+ {
489
+ "metadata": {},
490
+ "cell_type": "code",
491
+ "outputs": [],
492
+ "execution_count": null,
493
+ "source": [
494
+ "early_stopping = EarlyStopping(\n",
495
+ " monitor='val_loss', # Should match with the validation step log key\n",
496
+ " patience=2,\n",
497
+ " verbose=True,\n",
498
+ ")\n",
499
+ "\n",
500
+ "checkpoint_callback = ModelCheckpoint(\n",
501
+ " monitor='val_accuracy', # Should match with the validation step log key\n",
502
+ " save_top_k=1, # Saves top one model\n",
503
+ " mode='max', # top means max validation accuracy\n",
504
+ ")\n",
505
+ "\n",
506
+ "checkpoint_path = os.path.join(\n",
507
+ " os.getcwd(), \"checkpoints\", \"best_model.pth\"\n",
508
+ ")\n"
509
+ ],
510
+ "id": "b280d211a42ceeee"
511
+ },
512
+ {
513
+ "metadata": {
514
+ "id": "c037b19d321b93ff"
515
+ },
516
+ "cell_type": "markdown",
517
+ "source": [
518
+ "# Train"
519
+ ],
520
+ "id": "c037b19d321b93ff"
521
+ },
522
+ {
523
+ "metadata": {
524
+ "id": "1bd38416398d770a",
525
+ "ExecuteTime": {
526
+ "end_time": "2025-09-14T07:01:05.135876Z",
527
+ "start_time": "2025-09-14T07:01:05.096305Z"
528
+ }
529
+ },
530
+ "cell_type": "code",
531
+ "source": [
532
+ "trainer = pl.Trainer(\n",
533
+ " max_epochs=5,\n",
534
+ " accelerator='gpu' if torch.cuda.is_available() else 'cpu',\n",
535
+ " devices=1,\n",
536
+ " precision=\"16-mixed\",\n",
537
+ " log_every_n_steps=10,\n",
538
+ " val_check_interval=0.25\n",
539
+ ")"
540
+ ],
541
+ "id": "1bd38416398d770a",
542
+ "outputs": [
543
+ {
544
+ "name": "stderr",
545
+ "output_type": "stream",
546
+ "text": [
547
+ "Using 16bit Automatic Mixed Precision (AMP)\n",
548
+ "💡 Tip: For seamless cloud uploads and versioning, try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint, which syncs automatically with the Lightning model registry.\n",
549
+ "GPU available: True (cuda), used: True\n",
550
+ "TPU available: False, using: 0 TPU cores\n",
551
+ "HPU available: False, using: 0 HPUs\n"
552
+ ]
553
+ }
554
+ ],
555
+ "execution_count": 11
556
+ },
557
+ {
558
+ "metadata": {
559
+ "id": "add377254e158c86",
560
+ "jupyter": {
561
+ "is_executing": true
562
+ },
563
+ "ExecuteTime": {
564
+ "start_time": "2025-09-14T07:01:05.145883Z"
565
+ }
566
+ },
567
+ "cell_type": "code",
568
+ "source": [
569
+ "trainer.fit(model, data_module)"
570
+ ],
571
+ "id": "add377254e158c86",
572
+ "outputs": [
573
+ {
574
+ "name": "stderr",
575
+ "output_type": "stream",
576
+ "text": [
577
+ "You are using a CUDA device ('NVIDIA GeForce RTX 3060') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision\n",
578
+ "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n",
579
+ "C:\\Users\\User\\PyCharmMiscProject\\.venv\\Lib\\site-packages\\pytorch_lightning\\utilities\\model_summary\\model_summary.py:231: Precision 16-mixed is not supported by the model summary. Estimated model size in MB will not be accurate. Using 32 bits instead.\n",
580
+ "\n",
581
+ " | Name | Type | Params | Mode \n",
582
+ "-----------------------------------------------------\n",
583
+ "0 | model | MarianMTModel | 76.3 M | eval \n",
584
+ "1 | loss_fn | CrossEntropyLoss | 0 | train\n",
585
+ "2 | bleu | BLEUScore | 0 | train\n",
586
+ "-----------------------------------------------------\n",
587
+ "75.8 M Trainable params\n",
588
+ "524 K Non-trainable params\n",
589
+ "76.3 M Total params\n",
590
+ "305.136 Total estimated model params size (MB)\n",
591
+ "2 Modules in train mode\n",
592
+ "178 Modules in eval mode\n"
593
+ ]
594
+ },
595
+ {
596
+ "data": {
597
+ "text/plain": [
598
+ "Sanity Checking: | | 0/? [00:00<?, ?it/s]"
599
+ ],
600
+ "application/vnd.jupyter.widget-view+json": {
601
+ "version_major": 2,
602
+ "version_minor": 0,
603
+ "model_id": "7653d36a0abd4e27bea488b14b89d42b"
604
+ }
605
+ },
606
+ "metadata": {},
607
+ "output_type": "display_data"
608
+ },
609
+ {
610
+ "name": "stderr",
611
+ "output_type": "stream",
612
+ "text": [
613
+ "C:\\Users\\User\\PyCharmMiscProject\\.venv\\Lib\\site-packages\\pytorch_lightning\\trainer\\connectors\\data_connector.py:433: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.\n",
614
+ "C:\\Users\\User\\PyCharmMiscProject\\.venv\\Lib\\site-packages\\pytorch_lightning\\trainer\\connectors\\data_connector.py:433: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.\n"
615
+ ]
616
+ },
617
+ {
618
+ "data": {
619
+ "text/plain": [
620
+ "Training: | | 0/? [00:00<?, ?it/s]"
621
+ ],
622
+ "application/vnd.jupyter.widget-view+json": {
623
+ "version_major": 2,
624
+ "version_minor": 0,
625
+ "model_id": "b3323116f1b44cd9aacdfe59272a1310"
626
+ }
627
+ },
628
+ "metadata": {},
629
+ "output_type": "display_data"
630
+ },
631
+ {
632
+ "data": {
633
+ "text/plain": [
634
+ "Validation: | | 0/? [00:00<?, ?it/s]"
635
+ ],
636
+ "application/vnd.jupyter.widget-view+json": {
637
+ "version_major": 2,
638
+ "version_minor": 0,
639
+ "model_id": "e1f76950fb134014a65f59bd58f85541"
640
+ }
641
+ },
642
+ "metadata": {},
643
+ "output_type": "display_data"
644
+ },
645
+ {
646
+ "data": {
647
+ "text/plain": [
648
+ "Validation: | | 0/? [00:00<?, ?it/s]"
649
+ ],
650
+ "application/vnd.jupyter.widget-view+json": {
651
+ "version_major": 2,
652
+ "version_minor": 0,
653
+ "model_id": "52f93b059f014695ac2984b1ab8ab6e0"
654
+ }
655
+ },
656
+ "metadata": {},
657
+ "output_type": "display_data"
658
+ },
659
+ {
660
+ "data": {
661
+ "text/plain": [
662
+ "Validation: | | 0/? [00:00<?, ?it/s]"
663
+ ],
664
+ "application/vnd.jupyter.widget-view+json": {
665
+ "version_major": 2,
666
+ "version_minor": 0,
667
+ "model_id": "b738ac5554f9427ea2fe44f82b159593"
668
+ }
669
+ },
670
+ "metadata": {},
671
+ "output_type": "display_data"
672
+ },
673
+ {
674
+ "data": {
675
+ "text/plain": [
676
+ "Validation: | | 0/? [00:00<?, ?it/s]"
677
+ ],
678
+ "application/vnd.jupyter.widget-view+json": {
679
+ "version_major": 2,
680
+ "version_minor": 0,
681
+ "model_id": "85d57ade665a4808a138b0364dc40d0d"
682
+ }
683
+ },
684
+ "metadata": {},
685
+ "output_type": "display_data"
686
+ }
687
+ ],
688
+ "execution_count": null
689
+ },
690
+ {
691
+ "metadata": {
692
+ "id": "652b3f73247ae77c",
693
+ "ExecuteTime": {
694
+ "end_time": "2025-09-12T11:07:00.608309Z",
695
+ "start_time": "2025-09-12T11:06:38.554751Z"
696
+ }
697
+ },
698
+ "cell_type": "code",
699
+ "source": [
700
+ "trainer.test(model, data_module)"
701
+ ],
702
+ "id": "652b3f73247ae77c",
703
+ "outputs": [
704
+ {
705
+ "name": "stderr",
706
+ "output_type": "stream",
707
+ "text": [
708
+ "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n",
709
+ "C:\\Users\\User\\PyCharmMiscProject\\.venv\\Lib\\site-packages\\pytorch_lightning\\trainer\\connectors\\data_connector.py:433: The 'test_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.\n"
710
+ ]
711
+ },
712
+ {
713
+ "data": {
714
+ "text/plain": [
715
+ "Testing: | | 0/? [00:00<?, ?it/s]"
716
+ ],
717
+ "application/vnd.jupyter.widget-view+json": {
718
+ "version_major": 2,
719
+ "version_minor": 0,
720
+ "model_id": "37db78bcf4f646bb995c4a725d9126eb"
721
+ }
722
+ },
723
+ "metadata": {},
724
+ "output_type": "display_data"
725
+ },
726
+ {
727
+ "data": {
728
+ "text/plain": [
729
+ "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n",
730
+ "┃\u001B[1m \u001B[0m\u001B[1m Test metric \u001B[0m\u001B[1m \u001B[0m┃\u001B[1m \u001B[0m\u001B[1m DataLoader 0 \u001B[0m\u001B[1m \u001B[0m┃\n",
731
+ "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n",
732
+ "│\u001B[36m \u001B[0m\u001B[36m test_bleu \u001B[0m\u001B[36m \u001B[0m│\u001B[35m \u001B[0m\u001B[35m 0.13064175844192505 \u001B[0m\u001B[35m \u001B[0m│\n",
733
+ "│\u001B[36m \u001B[0m\u001B[36m test_loss \u001B[0m\u001B[36m \u001B[0m│\u001B[35m \u001B[0m\u001B[35m 0.5454719662666321 \u001B[0m\u001B[35m \u001B[0m│\n",
734
+ "└───────────────────────────┴───────────────────────────┘\n"
735
+ ],
736
+ "text/html": [
737
+ "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n",
738
+ "┃<span style=\"font-weight: bold\"> Test metric </span>┃<span style=\"font-weight: bold\"> DataLoader 0 </span>┃\n",
739
+ "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n",
740
+ "│<span style=\"color: #008080; text-decoration-color: #008080\"> test_bleu </span>│<span style=\"color: #800080; text-decoration-color: #800080\"> 0.13064175844192505 </span>│\n",
741
+ "│<span style=\"color: #008080; text-decoration-color: #008080\"> test_loss </span>│<span style=\"color: #800080; text-decoration-color: #800080\"> 0.5454719662666321 </span>│\n",
742
+ "└───────────────────────────┴───────────────────────────┘\n",
743
+ "</pre>\n"
744
+ ]
745
+ },
746
+ "metadata": {},
747
+ "output_type": "display_data"
748
+ },
749
+ {
750
+ "data": {
751
+ "text/plain": [
752
+ "[{'test_bleu': 0.13064175844192505, 'test_loss': 0.5454719662666321}]"
753
+ ]
754
+ },
755
+ "execution_count": 12,
756
+ "metadata": {},
757
+ "output_type": "execute_result"
758
+ }
759
+ ],
760
+ "execution_count": 12
761
+ },
762
+ {
763
+ "metadata": {
764
+ "id": "9f115c718388b7f9",
765
+ "outputId": "2e3186ce-b52e-4344-cfd5-8fea0ee7b28a",
766
+ "colab": {
767
+ "base_uri": "https://localhost:8080/",
768
+ "height": 141
769
+ },
770
+ "executionInfo": {
771
+ "status": "error",
772
+ "timestamp": 1757187730430,
773
+ "user_tz": -360,
774
+ "elapsed": 104,
775
+ "user": {
776
+ "displayName": "Chironjit Banerjee",
777
+ "userId": "04428016465669976257"
778
+ }
779
+ },
780
+ "ExecuteTime": {
781
+ "end_time": "2025-09-12T11:07:00.671344Z",
782
+ "start_time": "2025-09-12T11:07:00.665343Z"
783
+ }
784
+ },
785
+ "cell_type": "code",
786
+ "source": [
787
+ "model.model.config"
788
+ ],
789
+ "id": "9f115c718388b7f9",
790
+ "outputs": [
791
+ {
792
+ "data": {
793
+ "text/plain": [
794
+ "MarianConfig {\n",
795
+ " \"activation_dropout\": 0.0,\n",
796
+ " \"activation_function\": \"swish\",\n",
797
+ " \"add_bias_logits\": false,\n",
798
+ " \"add_final_layer_norm\": false,\n",
799
+ " \"architectures\": [\n",
800
+ " \"MarianMTModel\"\n",
801
+ " ],\n",
802
+ " \"attention_dropout\": 0.0,\n",
803
+ " \"bad_words_ids\": [\n",
804
+ " [\n",
805
+ " 61759\n",
806
+ " ]\n",
807
+ " ],\n",
808
+ " \"bos_token_id\": 0,\n",
809
+ " \"classif_dropout\": 0.0,\n",
810
+ " \"classifier_dropout\": 0.0,\n",
811
+ " \"d_model\": 512,\n",
812
+ " \"decoder_attention_heads\": 8,\n",
813
+ " \"decoder_ffn_dim\": 2048,\n",
814
+ " \"decoder_layerdrop\": 0.0,\n",
815
+ " \"decoder_layers\": 6,\n",
816
+ " \"decoder_start_token_id\": 61759,\n",
817
+ " \"decoder_vocab_size\": 61760,\n",
818
+ " \"dropout\": 0.1,\n",
819
+ " \"dtype\": \"float32\",\n",
820
+ " \"encoder_attention_heads\": 8,\n",
821
+ " \"encoder_ffn_dim\": 2048,\n",
822
+ " \"encoder_layerdrop\": 0.0,\n",
823
+ " \"encoder_layers\": 6,\n",
824
+ " \"eos_token_id\": 0,\n",
825
+ " \"extra_pos_embeddings\": 61760,\n",
826
+ " \"forced_eos_token_id\": 0,\n",
827
+ " \"id2label\": {\n",
828
+ " \"0\": \"LABEL_0\",\n",
829
+ " \"1\": \"LABEL_1\",\n",
830
+ " \"2\": \"LABEL_2\"\n",
831
+ " },\n",
832
+ " \"init_std\": 0.02,\n",
833
+ " \"is_encoder_decoder\": true,\n",
834
+ " \"label2id\": {\n",
835
+ " \"LABEL_0\": 0,\n",
836
+ " \"LABEL_1\": 1,\n",
837
+ " \"LABEL_2\": 2\n",
838
+ " },\n",
839
+ " \"max_length\": 512,\n",
840
+ " \"max_position_embeddings\": 512,\n",
841
+ " \"model_type\": \"marian\",\n",
842
+ " \"normalize_before\": false,\n",
843
+ " \"normalize_embedding\": false,\n",
844
+ " \"num_beams\": 4,\n",
845
+ " \"num_hidden_layers\": 6,\n",
846
+ " \"pad_token_id\": 61759,\n",
847
+ " \"scale_embedding\": true,\n",
848
+ " \"share_encoder_decoder_embeddings\": true,\n",
849
+ " \"static_position_embeddings\": true,\n",
850
+ " \"transformers_version\": \"4.56.1\",\n",
851
+ " \"use_cache\": true,\n",
852
+ " \"vocab_size\": 61760\n",
853
+ "}"
854
+ ]
855
+ },
856
+ "execution_count": 13,
857
+ "metadata": {},
858
+ "output_type": "execute_result"
859
+ }
860
+ ],
861
+ "execution_count": 13
862
+ },
863
+ {
864
+ "metadata": {
865
+ "id": "4db952c7f44ec3b2",
866
+ "outputId": "1c8af67a-9c64-4017-c39b-5bf773a0e8ad",
867
+ "colab": {
868
+ "base_uri": "https://localhost:8080/",
869
+ "height": 159
870
+ },
871
+ "executionInfo": {
872
+ "status": "error",
873
+ "timestamp": 1757187730464,
874
+ "user_tz": -360,
875
+ "elapsed": 20,
876
+ "user": {
877
+ "displayName": "Chironjit Banerjee",
878
+ "userId": "04428016465669976257"
879
+ }
880
+ },
881
+ "ExecuteTime": {
882
+ "end_time": "2025-09-12T11:07:00.694215Z",
883
+ "start_time": "2025-09-12T11:07:00.687372Z"
884
+ }
885
+ },
886
+ "cell_type": "code",
887
+ "source": [
888
+ "for name, module in model.model.named_modules():\n",
889
+ " print(name)"
890
+ ],
891
+ "id": "4db952c7f44ec3b2",
892
+ "outputs": [
893
+ {
894
+ "name": "stdout",
895
+ "output_type": "stream",
896
+ "text": [
897
+ "\n",
898
+ "model\n",
899
+ "model.shared\n",
900
+ "model.encoder\n",
901
+ "model.encoder.embed_positions\n",
902
+ "model.encoder.layers\n",
903
+ "model.encoder.layers.0\n",
904
+ "model.encoder.layers.0.self_attn\n",
905
+ "model.encoder.layers.0.self_attn.k_proj\n",
906
+ "model.encoder.layers.0.self_attn.v_proj\n",
907
+ "model.encoder.layers.0.self_attn.q_proj\n",
908
+ "model.encoder.layers.0.self_attn.out_proj\n",
909
+ "model.encoder.layers.0.self_attn_layer_norm\n",
910
+ "model.encoder.layers.0.activation_fn\n",
911
+ "model.encoder.layers.0.fc1\n",
912
+ "model.encoder.layers.0.fc2\n",
913
+ "model.encoder.layers.0.final_layer_norm\n",
914
+ "model.encoder.layers.1\n",
915
+ "model.encoder.layers.1.self_attn\n",
916
+ "model.encoder.layers.1.self_attn.k_proj\n",
917
+ "model.encoder.layers.1.self_attn.v_proj\n",
918
+ "model.encoder.layers.1.self_attn.q_proj\n",
919
+ "model.encoder.layers.1.self_attn.out_proj\n",
920
+ "model.encoder.layers.1.self_attn_layer_norm\n",
921
+ "model.encoder.layers.1.activation_fn\n",
922
+ "model.encoder.layers.1.fc1\n",
923
+ "model.encoder.layers.1.fc2\n",
924
+ "model.encoder.layers.1.final_layer_norm\n",
925
+ "model.encoder.layers.2\n",
926
+ "model.encoder.layers.2.self_attn\n",
927
+ "model.encoder.layers.2.self_attn.k_proj\n",
928
+ "model.encoder.layers.2.self_attn.v_proj\n",
929
+ "model.encoder.layers.2.self_attn.q_proj\n",
930
+ "model.encoder.layers.2.self_attn.out_proj\n",
931
+ "model.encoder.layers.2.self_attn_layer_norm\n",
932
+ "model.encoder.layers.2.activation_fn\n",
933
+ "model.encoder.layers.2.fc1\n",
934
+ "model.encoder.layers.2.fc2\n",
935
+ "model.encoder.layers.2.final_layer_norm\n",
936
+ "model.encoder.layers.3\n",
937
+ "model.encoder.layers.3.self_attn\n",
938
+ "model.encoder.layers.3.self_attn.k_proj\n",
939
+ "model.encoder.layers.3.self_attn.v_proj\n",
940
+ "model.encoder.layers.3.self_attn.q_proj\n",
941
+ "model.encoder.layers.3.self_attn.out_proj\n",
942
+ "model.encoder.layers.3.self_attn_layer_norm\n",
943
+ "model.encoder.layers.3.activation_fn\n",
944
+ "model.encoder.layers.3.fc1\n",
945
+ "model.encoder.layers.3.fc2\n",
946
+ "model.encoder.layers.3.final_layer_norm\n",
947
+ "model.encoder.layers.4\n",
948
+ "model.encoder.layers.4.self_attn\n",
949
+ "model.encoder.layers.4.self_attn.k_proj\n",
950
+ "model.encoder.layers.4.self_attn.v_proj\n",
951
+ "model.encoder.layers.4.self_attn.q_proj\n",
952
+ "model.encoder.layers.4.self_attn.out_proj\n",
953
+ "model.encoder.layers.4.self_attn_layer_norm\n",
954
+ "model.encoder.layers.4.activation_fn\n",
955
+ "model.encoder.layers.4.fc1\n",
956
+ "model.encoder.layers.4.fc2\n",
957
+ "model.encoder.layers.4.final_layer_norm\n",
958
+ "model.encoder.layers.5\n",
959
+ "model.encoder.layers.5.self_attn\n",
960
+ "model.encoder.layers.5.self_attn.k_proj\n",
961
+ "model.encoder.layers.5.self_attn.v_proj\n",
962
+ "model.encoder.layers.5.self_attn.q_proj\n",
963
+ "model.encoder.layers.5.self_attn.out_proj\n",
964
+ "model.encoder.layers.5.self_attn_layer_norm\n",
965
+ "model.encoder.layers.5.activation_fn\n",
966
+ "model.encoder.layers.5.fc1\n",
967
+ "model.encoder.layers.5.fc2\n",
968
+ "model.encoder.layers.5.final_layer_norm\n",
969
+ "model.decoder\n",
970
+ "model.decoder.embed_positions\n",
971
+ "model.decoder.layers\n",
972
+ "model.decoder.layers.0\n",
973
+ "model.decoder.layers.0.self_attn\n",
974
+ "model.decoder.layers.0.self_attn.k_proj\n",
975
+ "model.decoder.layers.0.self_attn.v_proj\n",
976
+ "model.decoder.layers.0.self_attn.q_proj\n",
977
+ "model.decoder.layers.0.self_attn.out_proj\n",
978
+ "model.decoder.layers.0.activation_fn\n",
979
+ "model.decoder.layers.0.self_attn_layer_norm\n",
980
+ "model.decoder.layers.0.encoder_attn\n",
981
+ "model.decoder.layers.0.encoder_attn.k_proj\n",
982
+ "model.decoder.layers.0.encoder_attn.v_proj\n",
983
+ "model.decoder.layers.0.encoder_attn.q_proj\n",
984
+ "model.decoder.layers.0.encoder_attn.out_proj\n",
985
+ "model.decoder.layers.0.encoder_attn_layer_norm\n",
986
+ "model.decoder.layers.0.fc1\n",
987
+ "model.decoder.layers.0.fc2\n",
988
+ "model.decoder.layers.0.final_layer_norm\n",
989
+ "model.decoder.layers.1\n",
990
+ "model.decoder.layers.1.self_attn\n",
991
+ "model.decoder.layers.1.self_attn.k_proj\n",
992
+ "model.decoder.layers.1.self_attn.v_proj\n",
993
+ "model.decoder.layers.1.self_attn.q_proj\n",
994
+ "model.decoder.layers.1.self_attn.out_proj\n",
995
+ "model.decoder.layers.1.activation_fn\n",
996
+ "model.decoder.layers.1.self_attn_layer_norm\n",
997
+ "model.decoder.layers.1.encoder_attn\n",
998
+ "model.decoder.layers.1.encoder_attn.k_proj\n",
999
+ "model.decoder.layers.1.encoder_attn.v_proj\n",
1000
+ "model.decoder.layers.1.encoder_attn.q_proj\n",
1001
+ "model.decoder.layers.1.encoder_attn.out_proj\n",
1002
+ "model.decoder.layers.1.encoder_attn_layer_norm\n",
1003
+ "model.decoder.layers.1.fc1\n",
1004
+ "model.decoder.layers.1.fc2\n",
1005
+ "model.decoder.layers.1.final_layer_norm\n",
1006
+ "model.decoder.layers.2\n",
1007
+ "model.decoder.layers.2.self_attn\n",
1008
+ "model.decoder.layers.2.self_attn.k_proj\n",
1009
+ "model.decoder.layers.2.self_attn.v_proj\n",
1010
+ "model.decoder.layers.2.self_attn.q_proj\n",
1011
+ "model.decoder.layers.2.self_attn.out_proj\n",
1012
+ "model.decoder.layers.2.activation_fn\n",
1013
+ "model.decoder.layers.2.self_attn_layer_norm\n",
1014
+ "model.decoder.layers.2.encoder_attn\n",
1015
+ "model.decoder.layers.2.encoder_attn.k_proj\n",
1016
+ "model.decoder.layers.2.encoder_attn.v_proj\n",
1017
+ "model.decoder.layers.2.encoder_attn.q_proj\n",
1018
+ "model.decoder.layers.2.encoder_attn.out_proj\n",
1019
+ "model.decoder.layers.2.encoder_attn_layer_norm\n",
1020
+ "model.decoder.layers.2.fc1\n",
1021
+ "model.decoder.layers.2.fc2\n",
1022
+ "model.decoder.layers.2.final_layer_norm\n",
1023
+ "model.decoder.layers.3\n",
1024
+ "model.decoder.layers.3.self_attn\n",
1025
+ "model.decoder.layers.3.self_attn.k_proj\n",
1026
+ "model.decoder.layers.3.self_attn.v_proj\n",
1027
+ "model.decoder.layers.3.self_attn.q_proj\n",
1028
+ "model.decoder.layers.3.self_attn.out_proj\n",
1029
+ "model.decoder.layers.3.activation_fn\n",
1030
+ "model.decoder.layers.3.self_attn_layer_norm\n",
1031
+ "model.decoder.layers.3.encoder_attn\n",
1032
+ "model.decoder.layers.3.encoder_attn.k_proj\n",
1033
+ "model.decoder.layers.3.encoder_attn.v_proj\n",
1034
+ "model.decoder.layers.3.encoder_attn.q_proj\n",
1035
+ "model.decoder.layers.3.encoder_attn.out_proj\n",
1036
+ "model.decoder.layers.3.encoder_attn_layer_norm\n",
1037
+ "model.decoder.layers.3.fc1\n",
1038
+ "model.decoder.layers.3.fc2\n",
1039
+ "model.decoder.layers.3.final_layer_norm\n",
1040
+ "model.decoder.layers.4\n",
1041
+ "model.decoder.layers.4.self_attn\n",
1042
+ "model.decoder.layers.4.self_attn.k_proj\n",
1043
+ "model.decoder.layers.4.self_attn.v_proj\n",
1044
+ "model.decoder.layers.4.self_attn.q_proj\n",
1045
+ "model.decoder.layers.4.self_attn.out_proj\n",
1046
+ "model.decoder.layers.4.activation_fn\n",
1047
+ "model.decoder.layers.4.self_attn_layer_norm\n",
1048
+ "model.decoder.layers.4.encoder_attn\n",
1049
+ "model.decoder.layers.4.encoder_attn.k_proj\n",
1050
+ "model.decoder.layers.4.encoder_attn.v_proj\n",
1051
+ "model.decoder.layers.4.encoder_attn.q_proj\n",
1052
+ "model.decoder.layers.4.encoder_attn.out_proj\n",
1053
+ "model.decoder.layers.4.encoder_attn_layer_norm\n",
1054
+ "model.decoder.layers.4.fc1\n",
1055
+ "model.decoder.layers.4.fc2\n",
1056
+ "model.decoder.layers.4.final_layer_norm\n",
1057
+ "model.decoder.layers.5\n",
1058
+ "model.decoder.layers.5.self_attn\n",
1059
+ "model.decoder.layers.5.self_attn.k_proj\n",
1060
+ "model.decoder.layers.5.self_attn.v_proj\n",
1061
+ "model.decoder.layers.5.self_attn.q_proj\n",
1062
+ "model.decoder.layers.5.self_attn.out_proj\n",
1063
+ "model.decoder.layers.5.activation_fn\n",
1064
+ "model.decoder.layers.5.self_attn_layer_norm\n",
1065
+ "model.decoder.layers.5.encoder_attn\n",
1066
+ "model.decoder.layers.5.encoder_attn.k_proj\n",
1067
+ "model.decoder.layers.5.encoder_attn.v_proj\n",
1068
+ "model.decoder.layers.5.encoder_attn.q_proj\n",
1069
+ "model.decoder.layers.5.encoder_attn.out_proj\n",
1070
+ "model.decoder.layers.5.encoder_attn_layer_norm\n",
1071
+ "model.decoder.layers.5.fc1\n",
1072
+ "model.decoder.layers.5.fc2\n",
1073
+ "model.decoder.layers.5.final_layer_norm\n",
1074
+ "lm_head\n"
1075
+ ]
1076
+ }
1077
+ ],
1078
+ "execution_count": 14
1079
+ },
1080
+ {
1081
+ "metadata": {
1082
+ "ExecuteTime": {
1083
+ "end_time": "2025-09-14T07:00:16.535340Z",
1084
+ "start_time": "2025-09-14T07:00:14.981559Z"
1085
+ }
1086
+ },
1087
+ "cell_type": "code",
1088
+ "source": [
1089
+ "import mlflow\n",
1090
+ "mlflow.set_experiment(experiment_name= \"BongoDev Projects\")\n"
1091
+ ],
1092
+ "id": "1c8e5f9092cfe5c6",
1093
+ "outputs": [
1094
+ {
1095
+ "name": "stderr",
1096
+ "output_type": "stream",
1097
+ "text": [
1098
+ "2025/09/14 13:00:16 INFO mlflow.tracking.fluent: Experiment with name 'BongoDev Projects' does not exist. Creating a new experiment.\n"
1099
+ ]
1100
+ },
1101
+ {
1102
+ "data": {
1103
+ "text/plain": [
1104
+ "<Experiment: artifact_location='file:///C:/Users/User/PyCharmMiscProject/mlruns/793621701339965882', creation_time=1757833216529, experiment_id='793621701339965882', last_update_time=1757833216529, lifecycle_stage='active', name='BongoDev Projects', tags={}>"
1105
+ ]
1106
+ },
1107
+ "execution_count": 1,
1108
+ "metadata": {},
1109
+ "output_type": "execute_result"
1110
+ }
1111
+ ],
1112
+ "execution_count": 1
1113
+ },
1114
+ {
1115
+ "metadata": {},
1116
+ "cell_type": "markdown",
1117
+ "source": "### Experiment Tracking using MLFlow",
1118
+ "id": "1f356c18c4d0a1f9"
1119
+ },
1120
+ {
1121
+ "metadata": {},
1122
+ "cell_type": "code",
1123
+ "outputs": [],
1124
+ "execution_count": null,
1125
+ "source": [
1126
+ "with mlflow.start_run():\n",
1127
+ " # Log Hyperparameters\n",
1128
+ " mlflow.log_param(\"learning_rate\", LEARNING_RATE)\n",
1129
+ " mlflow.log_param(\"batch_size\", BATCH_SIZE)\n",
1130
+ " mlflow.log_param(\"epochs\", EPOCHS)\n",
1131
+ "\n",
1132
+ "\n",
1133
+ "\n",
1134
+ " trainer.fit(\n",
1135
+ " model=model,\n",
1136
+ " datamodule=data_module\n",
1137
+ " )\n",
1138
+ "\n",
1139
+ " # Get the best model\n",
1140
+ " best_model_path = checkpoint_callback.best_model_path\n",
1141
+ " best_model = DigitClassifier.load_from_checkpoint(best_model_path)\n",
1142
+ "\n",
1143
+ " # Evaluate the model on the test set\n",
1144
+ " evaluation_score = trainer.test(\n",
1145
+ " best_model,\n",
1146
+ " datamodule= data_module,\n",
1147
+ " )\n",
1148
+ "\n",
1149
+ "\n",
1150
+ " mlflow.log_metric(\"test_accuracy\", evaluation_score[0][\"test_accuracy\"])\n",
1151
+ " mlflow.log_metric(\"test_loss\", evaluation_score[0][\"test_loss\"])\n",
1152
+ "\n",
1153
+ "\n",
1154
+ " # Save the model\n",
1155
+ " # Prepare a small input_example from the test loader\n",
1156
+ " test_loader = data_module.test_dataloader()\n",
1157
+ " first_batch = next(iter(test_loader))\n",
1158
+ " src_input_ids_example = first_batch[\"src_input_ids\"].cpu().numpy()\n",
1159
+ "\n",
1160
+ " signature = infer_signature(src_input_ids_example, src_input_ids_example)\n",
1161
+ "\n",
1162
+ " # Log the underlying HF Seq2Seq model (nn.Module) to keep it simple\n",
1163
+ " import mlflow.pytorch\n",
1164
+ " mlflow.pytorch.log_model(\n",
1165
+ " pytorch_model=best_model.model,\n",
1166
+ " artifact_path=ARTIFACT_FOLDER_NAME,\n",
1167
+ " input_example=src_input_ids_example,\n",
1168
+ " signature=signature\n",
1169
+ " )\n",
1170
+ "\n",
1171
+ " import shutil\n",
1172
+ " shutil.copyfile(SOURCE_CODE_PATH, SOURCE_CODE_ARTIFACT)\n",
1173
+ " mlflow.log_artifact(SOURCE_CODE_ARTIFACT)\n",
1174
+ "\n",
1175
+ "\n",
1176
+ "\n",
1177
+ "\n",
1178
+ "\n"
1179
+ ],
1180
+ "id": "23776af915ea20ae"
1181
+ }
1182
+ ],
1183
+ "metadata": {
1184
+ "kernelspec": {
1185
+ "display_name": "Python 3",
1186
+ "language": "python",
1187
+ "name": "python3"
1188
+ },
1189
+ "language_info": {
1190
+ "codemirror_mode": {
1191
+ "name": "ipython",
1192
+ "version": 2
1193
+ },
1194
+ "file_extension": ".py",
1195
+ "mimetype": "text/x-python",
1196
+ "name": "python",
1197
+ "nbconvert_exporter": "python",
1198
+ "pygments_lexer": "ipython2",
1199
+ "version": "2.7.6"
1200
+ },
1201
+ "colab": {
1202
+ "provenance": []
1203
+ }
1204
+ },
1205
+ "nbformat": 4,
1206
+ "nbformat_minor": 5
1207
+ }