<a href="https://colab.research.google.com/github/your-username/mt5-finetune-en-de/blob/main/mt5_finetune_en_de.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Fine-Tuning T5 for English → Breton, Cornish, and Welsh Translation

This notebook fine-tunes **`t5-small`** on a parallel corpus of English paired with Breton, Cornish, and Welsh, using Hugging Face Transformers.

- Model: T5 (`t5-small`), loaded in Section 3; its multilingual counterpart is `google/mt5-small` (mT5, pre-trained on 101 languages **without supervised translation**)
- Task: Teach it **English → Breton, Cornish, and Welsh translation** via fine-tuning
- Dataset: `parallel_corpus.json`, loaded from Google Drive and wrapped with the `datasets` library
- Framework: `transformers` + `Seq2SeqTrainer`

---

## 1. Mount Drive & Install Dependencies

In [1]:
from google.colab import drive
drive.mount('/content/drive')


Mounted at /content/drive


In [2]:
!pip install -q transformers datasets sentencepiece sacrebleu accelerate evaluate
!pip install -q torch --index-url https://download.pytorch.org/whl/cu118


In [5]:
!pip install pandas



## 2. Load Dataset (Parallel Corpus from Drive)

In [3]:
import json

data_path = "/content/drive/MyDrive/llm-translator/parallel_corpus.json"  # Adjust path
try:
    with open(data_path, "r", encoding="utf-8") as file:
        data = json.load(file)
except FileNotFoundError:
    # Re-raise to halt the notebook here rather than calling exit(), which is meant for scripts
    print(f"Error: The file '{data_path}' was not found.")
    raise
except json.JSONDecodeError:
    print("Error: Failed to decode JSON from the file.")
    raise
print(f"Loaded {len(data)} entries from parallel_corpus.json")

Loaded 31192 entries from parallel_corpus.json
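
The later cells read four fields from each record: `niv_text` (English) plus `koad21_text`, `abk_text`, and `bcnda_text`. A check along the following lines can catch a malformed file before any filtering happens. This is a minimal sketch: the helper name `missing_fields` and the sample records are hypothetical, and it assumes `data` is a list of dicts.

```python
# Hypothetical helper: report which expected fields are absent from each record.
EXPECTED_FIELDS = ("niv_text", "koad21_text", "abk_text", "bcnda_text")

def missing_fields(records, expected=EXPECTED_FIELDS):
    """Return {record_index: [missing field names]} for incomplete records."""
    problems = {}
    for i, rec in enumerate(records):
        missing = [f for f in expected if f not in rec]
        if missing:
            problems[i] = missing
    return problems

# Made-up records standing in for entries of parallel_corpus.json
sample = [
    {"niv_text": "a", "koad21_text": "b", "abk_text": "c", "bcnda_text": "d"},
    {"niv_text": "a", "koad21_text": "b"},
]
print(missing_fields(sample))
# {1: ['abk_text', 'bcnda_text']}
```

On the real corpus, `missing_fields(data)` returning an empty dict would mean every record carries all four fields.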


In [10]:
import string
import pandas as pd
from datasets import Dataset, DatasetDict


# Build the three language pairs: English → Breton, Cornish, Welsh
def is_valid(t):
    # Reject empty or whitespace-only text and text that is a substring of string.punctuation (e.g. a lone ".")
    return bool(t and t.strip() and t.strip() not in string.punctuation)

def build_pair(df, target_col, lang_code):
    # Keep rows where both sides pass is_valid, then rename to the shared en/target schema
    mask = df.apply(lambda r: is_valid(r["niv_text"]) and is_valid(r[target_col]), axis=1)
    pair = df.loc[mask, ["niv_text", target_col]].rename(columns={"niv_text": "en", target_col: "target"})
    pair["language"] = lang_code
    return pair

df = pd.DataFrame(data)
breton_df = build_pair(df, "koad21_text", "br")
cornish_df = build_pair(df, "abk_text", "abk")  # label follows the abk_text column; the ISO 639-1 code for Cornish is "kw"
welsh_df = build_pair(df, "bcnda_text", "cy")


combined_df = pd.concat([breton_df, cornish_df, welsh_df], ignore_index=True)
print(combined_df.head(5))
dataset = Dataset.from_pandas(combined_df).train_test_split(test_size=0.2, seed=42)
print(f"Combined dataset size: {len(combined_df)} pairs (Breton: {len(breton_df)}, Cornish: {len(cornish_df)}, Welsh: {len(welsh_df)})")

raw_datasets = DatasetDict({
    "train": dataset["train"],
    "test" : dataset["test"]
})
print(f"Train: {len(raw_datasets['train'])}, Test: {len(raw_datasets['test'])}")
print(raw_datasets['train'])

                                                  en  \
0  The Lord called to Moses and spoke to him from...   
1  “Speak to the Israelites and say to them: ‘Whe...   
2  “ ‘If the offering is a burnt offering from th...   
3  You are to lay your hand on the head of the bu...   
4  You are to slaughter the young bull before the...   

                                              target language  
0  An AOTROU a c’halvas Moizez hag a gomzas dezha...       br  
1  Komz da vibien Israel ha lavar: Pa raio unan b...       br  
2  Mar d-eo e brof ul loskaberzh a loened bras, e...       br  
3  Lakaat a raio e zorn war benn al loskaberzh, a...       br  
4  Lazhañ a raio ar c’hole dirak an AOTROU ; an a...       br  
Combined dataset size: 93233 pairs (Breton: 31077, Cornish: 31086, Welsh: 31070)
Train: 74586, Test: 18647
Dataset({
    features: ['en', 'target', 'language'],
    num_rows: 74586
})
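
The `is_valid` filter rejects missing values, empty or whitespace-only strings, and any text that is a substring of `string.punctuation`. The sketch below restates it as a standalone function and runs it on a few made-up inputs. Note that a repeated mark such as `"..."` is not a substring of `string.punctuation`, so it passes the filter.

```python
import string

def is_valid(t):
    """Same check as the notebook's filter, restated for a standalone demo."""
    return bool(t and t.strip() and t.strip() not in string.punctuation)

# Made-up inputs, not taken from the corpus
for sample in [None, "", "   ", ".", "...", "Amen."]:
    print(repr(sample), "->", is_valid(sample))
# None -> False
# '' -> False
# '   ' -> False
# '.' -> False
# '...' -> True
# 'Amen.' -> True
```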


## 3. Load Model & Tokenizer

In [11]:
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

model_name = "t5-small"
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

# T5 and mT5 tokenizers are both SentencePiece-based; use_fast=False loads the slow (Python) tokenizer

## 4. Preprocess Function

In [13]:
max_input_length = 128
max_target_length = 128

def preprocess(examples):
    inputs = [f"translate English to {lang}: {en}"
              for lang, en in zip(examples["language"], examples["en"])]
    targets = examples["target"]
    # Truncate only; DataCollatorForSeq2Seq pads per batch and pads labels with -100,
    # so padding positions are ignored by the loss.
    model_inputs = tokenizer(inputs, max_length=max_input_length, truncation=True)
    labels = tokenizer(text_target=targets, max_length=max_target_length, truncation=True)
    model_inputs["labels"] = labels["input_ids"]
    return model_inputs
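
Each input is prefixed with a task instruction that carries the row's language code, so one model can serve all three targets. The sketch below isolates that step with no tokenizer involved. The batch contents are made up and `build_prompts` is a hypothetical name.

```python
def build_prompts(examples):
    """Mirror the prompt format used in preprocess() for a batched dict."""
    return [f"translate English to {lang}: {en}"
            for lang, en in zip(examples["language"], examples["en"])]

# With batched=True, map() passes each column as a list of equal length
batch = {
    "en": ["In the beginning", "Peace be with you"],
    "language": ["br", "cy"],
}
for prompt in build_prompts(batch):
    print(prompt)
# translate English to br: In the beginning
# translate English to cy: Peace be with you
```

The same prefix has to be used at inference time, since the model only sees the target language through this string.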



In [14]:
# Apply preprocessing
print("Tokenising …")
print(raw_datasets)
tokenized_datasets = raw_datasets.map(preprocess,
                               batched=True,
                               remove_columns=raw_datasets["train"].column_names)

print(tokenized_datasets)

Tokenising …
DatasetDict({
    train: Dataset({
        features: ['en', 'target', 'language'],
        num_rows: 74586
    })
    test: Dataset({
        features: ['en', 'target', 'language'],
        num_rows: 18647
    })
})


DatasetDict({
    train: Dataset({
        features: ['input_ids', 'attention_mask', 'labels'],
        num_rows: 74586
    })
    test: Dataset({
        features: ['input_ids', 'attention_mask', 'labels'],
        num_rows: 18647
    })
})


## 5. Data Collator

In [15]:
from transformers import DataCollatorForSeq2Seq

data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)
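
`DataCollatorForSeq2Seq` pads each batch to a common length and, by default, pads `labels` with `-100`, the index that PyTorch's cross-entropy loss ignores. The pure-Python sketch below illustrates only that label-padding idea. It is not the library's implementation, and `pad_labels` is a hypothetical name.

```python
def pad_labels(label_batch, pad_value=-100):
    """Right-pad every label sequence to the longest one in the batch."""
    longest = max(len(seq) for seq in label_batch)
    return [seq + [pad_value] * (longest - len(seq)) for seq in label_batch]

# Made-up token ids; 1 stands in for the end-of-sequence token
batch = [[71, 12, 1], [8, 1]]
print(pad_labels(batch))
# [[71, 12, 1], [8, 1, -100]]
```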

## 6. Evaluation Metric (BLEU)

In [16]:
import evaluate
import numpy as np

metric = evaluate.load("sacrebleu")

def postprocess_text(preds, labels):
    preds = [pred.strip() for pred in preds]
    labels = [[label.strip()] for label in labels]
    return preds, labels

def compute_metrics(eval_preds):
    preds, labels = eval_preds
    if isinstance(preds, tuple):
        preds = preds[0]
    decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)

    # Replace -100 in the labels as we can't decode them
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    decoded_preds, decoded_labels = postprocess_text(decoded_preds, decoded_labels)

    result = metric.compute(predictions=decoded_preds, references=decoded_labels)
    result = {"bleu": result["score"]}

    return result
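
Before decoding, `compute_metrics` swaps every `-100` in the labels back to the pad token id, because `-100` is not a valid vocabulary index. A list-based sketch of the same substitution follows, assuming a pad id of 0 (the T5 tokenizer's pad id) and made-up token ids. `restore_pad` is a hypothetical name.

```python
def restore_pad(labels, pad_token_id=0, ignore_index=-100):
    """Replace the loss-ignore marker with the pad id so rows can be decoded."""
    return [[pad_token_id if tok == ignore_index else tok for tok in row]
            for row in labels]

labels = [[71, 12, 1], [8, 1, -100]]
print(restore_pad(labels))
# [[71, 12, 1], [8, 1, 0]]
```

Decoding with `skip_special_tokens=True` then drops the pad tokens from the reference text.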


## 7. Training Setup

In [17]:
from transformers import Seq2SeqTrainingArguments, Seq2SeqTrainer
final_dir = "/content/drive/MyDrive/llm-low-resource-translator/t5_multilingual/final"
training_args = Seq2SeqTrainingArguments(
    output_dir=final_dir,
    eval_strategy="epoch",
    learning_rate=3e-4,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    weight_decay=0.01,
    save_total_limit=3,
    num_train_epochs=3,
    predict_with_generate=True,
    fp16=True,
    push_to_hub=False,
    logging_steps=100,
    report_to="none"
)

trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["test"],
    processing_class=tokenizer,  # replaces the deprecated tokenizer= argument (transformers >= 4.46)
    data_collator=data_collator,
    compute_metrics=compute_metrics
)
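
With the split sizes printed earlier (74,586 training and 18,647 test examples) and a per-device batch size of 16, the length of a run can be estimated up front. The arithmetic below is a rough sketch that assumes a single GPU and the default of no gradient accumulation.

```python
import math

train_examples, eval_examples = 74586, 18647    # from the split printed earlier
batch_size, epochs, logging_steps = 16, 3, 100  # from the training arguments

steps_per_epoch = math.ceil(train_examples / batch_size)
total_steps = steps_per_epoch * epochs
eval_batches = math.ceil(eval_examples / batch_size)

print(steps_per_epoch)               # 4662
print(total_steps)                   # 13986
print(eval_batches)                  # 1166
print(total_steps // logging_steps)  # 139 logged loss values
```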



## 8. Train the Model

In [18]:
trainer.train()



Epoch,Training Loss,Validation Loss


KeyboardInterrupt: 

## 9. Inference Example

In [None]:
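text = "Hello, how are you today? I hope you're doing well."
lang = "cy"  # one of the language codes used in training: "br", "abk", "cy"

# Use the same task prefix as preprocess(); without it the model gets no target language
prompt = f"translate English to {lang}: {text}"
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
generated_ids = model.generate(
    **inputs,  # passes input_ids and attention_mask
    max_length=128,
    num_beams=5,
    early_stopping=True
)
translation = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
print("English:", text)
print(f"Translation ({lang}):", translation)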

## 10. Save Model (Optional)

In [34]:
import os

os.makedirs(final_dir, exist_ok=True)

trainer.save_model(final_dir)
tokenizer.save_pretrained(final_dir)
print(f"SAVED TO DRIVE: {final_dir}")
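
To use the saved model later, the directory can be loaded back with the same `Auto*` classes used in Section 3. This is a minimal sketch that assumes `final_dir` contains a completed save. The function name `load_finetuned` is hypothetical, and the imports sit inside the function so that defining it has no side effects.

```python
def load_finetuned(path):
    """Load tokenizer and model from a directory written by save_pretrained()."""
    from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
    tokenizer = AutoTokenizer.from_pretrained(path, use_fast=False)
    model = AutoModelForSeq2SeqLM.from_pretrained(path)
    return tokenizer, model

# Usage (requires the directory written by the cell above):
# tokenizer, model = load_finetuned(final_dir)
```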



---

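**Done!** Once training completes, `final_dir` holds a fine-tuned T5 model for **English → Breton, Cornish, and Welsh** translation.

To adapt to **any other language**, change:
- `data_path` → another parallel corpus, or a Hub dataset loaded with `load_dataset()` (e.g., `opus100`, `flores200`)
- The source and target column names (here `niv_text` and `koad21_text`, `abk_text`, `bcnda_text`)
- The `language` code assigned to each pair, which `preprocess` inserts into the prompt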
