Jendersen commited on
Commit
c09fd61
·
verified ·
1 Parent(s): 8a37fa0

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +141 -0
app.py ADDED
@@ -0,0 +1,141 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # train.py
2
+ #!/usr/bin/env python
3
+ import os
4
+ import json
5
+ import string
6
+ import pandas as pd
7
+ import evaluate
8
+ import numpy as np
9
+ from datasets import load_dataset, DatasetDict
10
+ from transformers import (
11
+ AutoTokenizer, AutoModelForSeq2SeqLM,
12
+ Seq2SeqTrainingArguments, Seq2SeqTrainer,
13
+ DataCollatorForSeq2Seq
14
+ )
15
+ from huggingface_hub import login
16
+
17
+ # -------------------------------------------------
18
+ # 0. HF login (set HF_TOKEN in Secrets)
19
+ # -------------------------------------------------
20
+ login() # reads HF_TOKEN from environment
21
+
22
+ # -------------------------------------------------
23
+ # 1. Load dataset from Hub
24
+ # -------------------------------------------------
25
+ dataset = load_dataset("your-username/celtic-parallel")
26
+ data = json.loads(dataset["train"][0]["parallel_corpus.json"]) # dummy – we load the file directly
27
+ # Actually we load the JSON file that was uploaded:
28
+ raw = load_dataset("your-username/celtic-parallel", data_files="parallel_corpus.json")["train"]
29
+ df = pd.DataFrame(json.loads(raw[0]["parallel_corpus.json"]))
30
+
31
+ # -------------------------------------------------
32
+ # 2. Build English → {br, abk, cy}
33
+ # -------------------------------------------------
34
+ def is_valid(t):
35
+ return bool(t and t.strip() and t.strip() not in string.punctuation)
36
+
37
+ br = df[df.apply(lambda r: is_valid(r["niv_text"]) and is_valid(r["koad21_text"]), axis=1)][["niv_text","koad21_text"]].rename(columns={"niv_text":"en","koad21_text":"target"})
38
+ br["language"] = "br"
39
+
40
+ abk = df[df.apply(lambda r: is_valid(r["niv_text"]) and is_valid(r["abk_text"]), axis=1)][["niv_text","abk_text"]].rename(columns={"niv_text":"en","abk_text":"target"})
41
+ abk["language"] = "abk"
42
+
43
+ cy = df[df.apply(lambda r: is_valid(r["niv_text"]) and is_valid(r["bcnda_text"]), axis=1)][["niv_text","bcnda_text"]].rename(columns={"niv_text":"en","bcnda_text":"target"})
44
+ cy["language"] = "cy"
45
+
46
+ combined = pd.concat([br, abk, cy], ignore_index=True)
47
+ print(f"Total examples: {len(combined)} (br:{len(br)}, abk:{len(abk)}, cy:{len(cy)})")
48
+
49
+ # -------------------------------------------------
50
+ # 3. Train / test split
51
+ # -------------------------------------------------
52
+ from datasets import Dataset
53
+ ds = Dataset.from_pandas(combined).train_test_split(test_size=0.2, seed=42)
54
+ raw_datasets = DatasetDict({"train": ds["train"], "test": ds["test"]})
55
+
56
+ # -------------------------------------------------
57
+ # 4. Tokenizer & Model
58
+ # -------------------------------------------------
59
+ model_name = "google/mt5-small"
60
+ tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)
61
+ model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
62
+
63
+ # -------------------------------------------------
64
+ # 5. Pre-process
65
+ # -------------------------------------------------
66
+ MAX_LEN = 96
67
+
68
+ def preprocess(examples):
69
+ inputs = [f"translate English to {lang}: {en}"
70
+ for lang, en in zip(examples["language"], examples["en"])]
71
+ targets = examples["target"]
72
+ model_inputs = tokenizer(inputs, max_length=MAX_LEN, truncation=True, padding="max_length")
73
+ labels = tokenizer(targets, max_length=MAX_LEN, truncation=True, padding="max_length").input_ids
74
+ model_inputs["labels"] = labels
75
+ return model_inputs
76
+
77
+ tokenized = raw_datasets.map(preprocess, batched=True, remove_columns=raw_datasets["train"].column_names)
78
+
79
+ # -------------------------------------------------
80
+ # 6. Data collator & metric
81
+ # -------------------------------------------------
82
+ data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)
83
+
84
+ metric = evaluate.load("sacrebleu")
85
+
86
+ def compute_metrics(eval_preds):
87
+ preds, labels = eval_preds
88
+ if isinstance(preds, tuple):
89
+ preds = preds[0]
90
+ decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
91
+ labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
92
+ decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
93
+ decoded_preds = [p.strip() for p in decoded_preds]
94
+ decoded_labels = [[l.strip()] for l in decoded_labels]
95
+ result = metric.compute(predictions=decoded_preds, references=decoded_labels)
96
+ return {"bleu": result["score"]}
97
+
98
+ # -------------------------------------------------
99
+ # 7. Training args
100
+ # -------------------------------------------------
101
+ training_args = Seq2SeqTrainingArguments(
102
+ output_dir="mt5-celtic-finetuned",
103
+ eval_strategy="epoch",
104
+ save_strategy="epoch",
105
+ learning_rate=3e-4,
106
+ per_device_train_batch_size=16,
107
+ per_device_eval_batch_size=16,
108
+ weight_decay=0.01,
109
+ num_train_epochs=3,
110
+ predict_with_generate=True,
111
+ fp16=True, # GPU
112
+ bf16=True, # TPU (auto-enabled if on TPU)
113
+ logging_steps=100,
114
+ report_to="wandb", # optional
115
+ push_to_hub=True,
116
+ hub_model_id="your-username/mt5-celtic-en-br-abk-cy",
117
+ hub_strategy="end",
118
+ load_best_model_at_end=True,
119
+ metric_for_best_model="bleu",
120
+ )
121
+
122
+ # -------------------------------------------------
123
+ # 8. Trainer
124
+ # -------------------------------------------------
125
+ trainer = Seq2SeqTrainer(
126
+ model=model,
127
+ args=training_args,
128
+ train_dataset=tokenized["train"],
129
+ eval_dataset=tokenized["test"],
130
+ tokenizer=tokenizer,
131
+ data_collator=data_collator,
132
+ compute_metrics=compute_metrics,
133
+ )
134
+
135
+ trainer.train()
136
+
137
+ # -------------------------------------------------
138
+ # 9. Final push
139
+ # -------------------------------------------------
140
+ trainer.push_to_hub("mt5-celtic-en-br-abk-cy")
141
+ print("Model pushed to Hub!")