# app.py
import math
import time
import traceback
from typing import Optional, Tuple

import gradio as gr
import torch
from huggingface_hub import hf_hub_download
from torch import nn
from torch.nn import functional as F
from transformers import AutoTokenizer


class LayerNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.hidden_size = hidden_size
        self.eps = eps
        # Learnable scale and shift parameters
        self.gamma = nn.Parameter(torch.ones(hidden_size))
        self.beta = nn.Parameter(torch.zeros(hidden_size))

    def forward(self, x):
        mean = x.mean(dim=-1, keepdim=True)
        variance = x.var(dim=-1, keepdim=True, unbiased=False)
        x_normalized = (x - mean) / torch.sqrt(variance + self.eps)
        # Output shape: (batch_size, seq_len, hidden_size)
        return self.gamma * x_normalized + self.beta


class RopePositionEmbedding(nn.Module):
    def __init__(self, dim: int, base=10000):
        super().__init__()
        inv_freq = 1 / base ** (torch.arange(0, dim, 2).float() / dim)
        inv_freq = inv_freq.unsqueeze(0)
        self.register_buffer("inv_freq", inv_freq)

    def rotate_half(self, x: torch.Tensor):
        # Map each coordinate pair (x0, x1) to (-x1, x0)
        odd = x[..., 1::2]
        even = x[..., 0::2]
        return torch.stack((-odd, even), dim=-1).flatten(-2)

    def apply_rope(self, x: torch.Tensor):
        x_len = x.shape[2]
        t = torch.arange(0, x_len, device=x.device, dtype=torch.float32).unsqueeze(1)
        freq = t * self.inv_freq
        freq = torch.repeat_interleave(freq, repeats=2, dim=-1)[None, None, :, :]
        xf = x.float()
        y = xf * freq.cos() + self.rotate_half(xf) * freq.sin()
        return y.to(x.dtype)

    def forward(self, q: torch.Tensor, k: torch.Tensor):
        return self.apply_rope(q), self.apply_rope(k)


class MultiHeadAttention(nn.Module):
    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        qkv_bias: bool = True,
        attn_drop: float = 0.0,
        proj_drop: float = 0.0,
    ):
        super().__init__()
        assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        assert self.head_dim % 2 == 0  # RoPE rotates coordinate pairs

        self.q_proj = nn.Linear(embed_dim, embed_dim, bias=qkv_bias)
        self.k_proj = nn.Linear(embed_dim, embed_dim, bias=qkv_bias)
        self.v_proj = nn.Linear(embed_dim, embed_dim, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=True)
        self.proj_drop = nn.Dropout(proj_drop)
        self.position_embed = RopePositionEmbedding(self.head_dim)

    def _shape(self, x: torch.Tensor, B: int, T: int):
        # [B, T, C] -> [B, H, T, Dh]
        return x.view(B, T, self.num_heads, self.head_dim).transpose(1, 2)

    def forward(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        attn_mask: Optional[torch.Tensor] = None,
    ):
        B, T_q, C = q.shape
        T_k = k.shape[1]

        # Project inputs to Q/K/V
        Q = self.q_proj(q)  # [B, T_q, C]
        K = self.k_proj(k)  # [B, T_k, C]
        V = self.v_proj(v)  # [B, T_k, C]

        # Split into heads
        Q = self._shape(Q, B, T_q)  # [B, H, T_q, Dh]
        K = self._shape(K, B, T_k)  # [B, H, T_k, Dh]
        V = self._shape(V, B, T_k)  # [B, H, T_k, Dh]

        # Apply rotary position embeddings
        Q, K = self.position_embed(Q, K)

        # logits: [B, H, T_q, T_k]
        logits = torch.matmul(Q, K.transpose(-1, -2)) / math.sqrt(self.head_dim)
        if attn_mask is not None:
            if attn_mask.dim() == 2:
                attn_mask = attn_mask[None, None, :, :]
            logits = logits + attn_mask

        attn = F.softmax(logits.float(), dim=-1).to(Q.dtype)
        attn = self.attn_drop(attn)

        # Weighted sum of values
        out = torch.matmul(attn, V)  # [B, H, T_q, Dh]
        out = out.transpose(1, 2).contiguous().view(B, T_q, C)  # [B, T_q, C]
        out = self.out_proj(out)
        out = self.proj_drop(out)
        return out
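# A minimal shape sanity check for MultiHeadAttention with RoPE. This helper is
# an illustrative sketch (the tensor sizes are arbitrary assumptions, and nothing
# calls it at import time): self-attention should preserve the [B, T, C] shape.
def _attention_shape_demo():
    mha = MultiHeadAttention(embed_dim=512, num_heads=8)
    x = torch.randn(2, 10, 512)  # [B, T, C]
    out = mha(x, x, x)  # self-attention, no mask
    assert out.shape == (2, 10, 512)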
class FFN(nn.Module):
    def __init__(self, dim, drop_rate):
        super().__init__()
        self.ffn = nn.Sequential(
            nn.Linear(dim, dim * 4),
            nn.SiLU(),
            nn.Linear(dim * 4, dim),
            nn.Dropout(drop_rate),
        )

    def forward(self, x):
        return self.ffn(x)


class EncoderLayer(nn.Module):
    def __init__(self, embed_dim, num_heads, qkv_bias, attn_drop, proj_drop, ffn_drop):
        super().__init__()
        self.norm_1 = LayerNorm(embed_dim)
        self.attn = MultiHeadAttention(
            embed_dim=embed_dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            attn_drop=attn_drop,
            proj_drop=proj_drop,
        )
        self.norm_2 = LayerNorm(embed_dim)
        self.ffn = FFN(embed_dim, drop_rate=ffn_drop)

    def forward(self, x, encoder_mask):
        # Pre-norm residual blocks: self-attention, then feed-forward
        residual = x
        x = self.norm_1(x)
        x = residual + self.attn(x, x, x, attn_mask=encoder_mask)
        x = x + self.ffn(self.norm_2(x))
        return x


class DecoderLayer(nn.Module):
    def __init__(self, embed_dim, num_heads, qkv_bias, attn_drop, proj_drop, ffn_drop):
        super().__init__()
        self.norm_1 = LayerNorm(embed_dim)
        self.self_attn = MultiHeadAttention(
            embed_dim=embed_dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            attn_drop=attn_drop,
            proj_drop=proj_drop,
        )
        self.norm_2 = LayerNorm(embed_dim)
        self.cross_attn = MultiHeadAttention(
            embed_dim=embed_dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            attn_drop=attn_drop,
            proj_drop=proj_drop,
        )
        self.norm_3 = LayerNorm(embed_dim)
        self.ffn = FFN(embed_dim, drop_rate=ffn_drop)

    def forward(self, x, e_output, encoder_mask, decoder_mask):
        # Masked self-attention, cross-attention over encoder output, feed-forward
        residual = x
        x = self.norm_1(x)
        x = residual + self.self_attn(x, x, x, attn_mask=decoder_mask)
        x = x + self.cross_attn(
            self.norm_2(x), e_output, e_output, attn_mask=encoder_mask
        )
        x = x + self.ffn(self.norm_3(x))
        return x


class TranslateModel(nn.Module):
    def __init__(
        self,
        vocab_size,
        encoder_layers,
        decoder_layers,
        embed_dim,
        num_heads,
        qkv_bias,
        attn_drop,
        proj_drop,
        ffn_drop,
    ):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.encoder = nn.ModuleList(
            [
                EncoderLayer(
                    embed_dim=embed_dim,
                    num_heads=num_heads,
                    qkv_bias=qkv_bias,
                    attn_drop=attn_drop,
                    proj_drop=proj_drop,
                    ffn_drop=ffn_drop,
                )
                for _ in range(encoder_layers)
            ]
        )
        self.decoder = nn.ModuleList(
            [
                DecoderLayer(
                    embed_dim=embed_dim,
                    num_heads=num_heads,
                    qkv_bias=qkv_bias,
                    attn_drop=attn_drop,
                    proj_drop=proj_drop,
                    ffn_drop=ffn_drop,
                )
                for _ in range(decoder_layers)
            ]
        )
        self.head = nn.Linear(embed_dim, vocab_size, bias=False)
        # Tie the output projection to the input embedding matrix
        self.head.weight = self.embedding.weight

    def forward(self, src, tgt_input, src_mask, tgt_mask):
        src_hidden = self.embedding(src)  # [B, Ts, C]
        tgt_hidden = self.embedding(tgt_input)  # [B, Tt, C]
        for layer in self.encoder:
            src_hidden = layer(src_hidden, src_mask)
        src_encoder = src_hidden
        for layer in self.decoder:
            tgt_hidden = layer(tgt_hidden, src_encoder, src_mask, tgt_mask)
        logits = self.head(tgt_hidden)
        return logits
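# An end-to-end shape sketch for TranslateModel with toy hyperparameters (an
# assumption for illustration only, not the deployed configuration; it is never
# called at startup). With both masks set to None this checks shapes only;
# real decoding must pass a causal mask, as in greedy_decode below.
def _translate_model_demo():
    toy = TranslateModel(
        vocab_size=100, encoder_layers=2, decoder_layers=2, embed_dim=64,
        num_heads=4, qkv_bias=True, attn_drop=0.0, proj_drop=0.0, ffn_drop=0.0,
    )
    src = torch.randint(0, 100, (2, 7))  # [B, Ts]
    tgt_in = torch.randint(0, 100, (2, 5))  # [B, Tt]
    logits = toy(src, tgt_in, src_mask=None, tgt_mask=None)
    assert logits.shape == (2, 5, 100)  # [B, Tt, vocab_size]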
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
HF_REPO_ID = "caixiaoshun/translate-zh2en"
HF_FILENAME = "weights/default.pt"

# Global caches
MODEL: Optional[nn.Module] = None
TOKENIZER = None
WEIGHT_LOCAL_PATH: Optional[str] = None


def build_tokenizer():
    return AutoTokenizer.from_pretrained("google/mt5-small", use_fast=False)


def _load_checkpoint_local_path() -> str:
    """Download the weights into the local cache via huggingface_hub and return the path."""
    path = hf_hub_download(repo_id=HF_REPO_ID, filename=HF_FILENAME)
    return path


def build_module(weight_local_path: str) -> nn.Module:
    model: nn.Module = TranslateModel(
        vocab_size=250100,
        encoder_layers=6,
        decoder_layers=6,
        embed_dim=512,
        num_heads=8,
        qkv_bias=True,
        attn_drop=0.02,
        proj_drop=0.02,
        ffn_drop=0.02,
    )
    ckpt = torch.load(weight_local_path, map_location="cpu", weights_only=False)
    # Tolerate different checkpoint layouts: unwrap a nested "state_dict" entry if present
    if isinstance(ckpt, dict) and "state_dict" in ckpt:
        ckpt = ckpt["state_dict"]
    model.load_state_dict(ckpt, strict=True)
    model.eval().to(DEVICE)
    return model


# ----------------------------
# Masks & greedy decoding
# ----------------------------
def build_encoder_mask(src_ids: torch.Tensor, pad_id: int) -> torch.Tensor:
    device = src_ids.device
    neg_inf = -1e8
    visible = (src_ids != pad_id)[:, None, None, :]
    return torch.where(
        visible,
        torch.tensor(0.0, device=device),
        torch.tensor(neg_inf, device=device),
    )


def build_decoder_mask(tgt_in_ids: torch.Tensor, pad_id: int) -> torch.Tensor:
    device = tgt_in_ids.device
    B, T = tgt_in_ids.shape
    neg_inf = -1e8
    causal = torch.zeros((T, T), device=device)
    causal = causal.masked_fill(
        torch.triu(torch.ones((T, T), dtype=torch.bool, device=device), diagonal=1),
        neg_inf,
    )
    causal = causal[None, None, :, :]  # [1, 1, T, T]
    key_visible = (tgt_in_ids != pad_id)[:, None, None, :]  # [B, 1, 1, T]
    key_mask = torch.where(
        key_visible,
        torch.tensor(0.0, device=device),
        torch.tensor(neg_inf, device=device),
    )
    query_visible = (tgt_in_ids != pad_id)[:, None, :, None]  # [B, 1, T, 1]
    query_mask = torch.where(
        query_visible,
        torch.tensor(0.0, device=device),
        torch.tensor(neg_inf, device=device),
    )
    return causal + key_mask + query_mask


@torch.no_grad()
def greedy_decode(
    model, src_ids: torch.Tensor, pad_id: int, eos_id: int, max_len: int = 128
) -> torch.Tensor:
    model.eval()
    device = src_ids.device

    # Encode the source once and reuse it at every decoding step
    enc_mask = build_encoder_mask(src_ids, pad_id=pad_id)  # [B, 1, 1, S]
    src_hidden = model.embedding(src_ids)  # [B, S, C]
    for layer in model.encoder:
        src_hidden = layer(src_hidden, enc_mask)
    memory = src_hidden

    # Decoding starts from the pad token (used here as the BOS, T5-style)
    out = torch.full((src_ids.size(0), 1), pad_id, dtype=torch.long, device=device)
    for _ in range(max_len - 1):
        dec_mask = build_decoder_mask(out, pad_id=pad_id)  # [B, 1, T, T]
        tgt_hidden = model.embedding(out)  # [B, T, C]
        for layer in model.decoder:
            tgt_hidden = layer(tgt_hidden, memory, enc_mask, dec_mask)
        logits = model.head(tgt_hidden)  # [B, T, V]
        next_token = logits[:, -1].argmax(-1, keepdim=True)
        out = torch.cat([out, next_token], dim=1)
        if (next_token == eos_id).all():
            break
    return out


def translate_one(
    model, tokenizer, zh_text: str, max_src_len: int = 128, max_tgt_len: int = 128
) -> str:
    pad_id = tokenizer.pad_token_id
    eos_id = tokenizer.eos_token_id
    assert pad_id is not None and eos_id is not None, "tokenizer is missing pad/eos tokens"

    src_ids = tokenizer.encode(zh_text, add_special_tokens=False)
    src_ids = src_ids[: max_src_len - 1] + [eos_id]
    src = torch.tensor(src_ids, dtype=torch.long, device=DEVICE)[None, :]  # [1, S]

    out = greedy_decode(model, src, pad_id=pad_id, eos_id=eos_id, max_len=max_tgt_len)
    pred = out[0, 1:].tolist()
    if eos_id in pred:
        pred = pred[: pred.index(eos_id) + 1]
    return tokenizer.decode(pred, skip_special_tokens=True)


# ----------------------------
# Lazy loading (download and load on the first translation)
# ----------------------------
def ensure_loaded() -> Tuple[str, str]:
    global MODEL, TOKENIZER, WEIGHT_LOCAL_PATH
    if MODEL is not None and TOKENIZER is not None:
        return "Ready", f"Device: {DEVICE}"
    t0 = time.perf_counter()
    try:
        if WEIGHT_LOCAL_PATH is None:
            WEIGHT_LOCAL_PATH = _load_checkpoint_local_path()
        TOKENIZER = build_tokenizer()
        MODEL = build_module(WEIGHT_LOCAL_PATH)
        dt = time.perf_counter() - t0
        return f"✅ Model loaded ({dt:.2f}s)", f"Device: {DEVICE}"
    except Exception as e:
        traceback.print_exc()
        return f"❌ Load failed: {repr(e)}\n\n{traceback.format_exc()}", ""
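# Hypothetical local smoke test (an assumption for illustration; it is not wired
# into the UI and never runs at import). Calling it triggers the lazy download
# and model load, then translates one sample sentence.
def _smoke_test(text: str = "你好,世界!") -> None:
    status, device_info = ensure_loaded()
    print(status, device_info)
    if MODEL is not None and TOKENIZER is not None:
        print(translate_one(MODEL, TOKENIZER, text))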
# ----------------------------
# Inference (single sentence)
# ----------------------------
@torch.no_grad()
def translate_single(zh_text: str) -> Tuple[str, str]:
    if not zh_text or not zh_text.strip():
        return "", "Please enter a Chinese sentence."
    status, device_info = ensure_loaded()
    if status.startswith("❌"):
        return "", status
    t0 = time.perf_counter()
    try:
        out = translate_one(
            MODEL, TOKENIZER, zh_text.strip(), max_src_len=128, max_tgt_len=128
        )
        dt = time.perf_counter() - t0
        info = f"{device_info} | Time: {dt:.2f}s"
        return out, info
    except Exception as e:
        traceback.print_exc()
        return f"[Translation failed] {repr(e)}", ""


# ----------------------------
# Minimal Gradio UI (single-sentence input)
# ----------------------------
with gr.Blocks(title="Translate-zh2en", theme=gr.themes.Soft()) as demo:
    zh_in = gr.Textbox(
        lines=1,
        label="Chinese input",
        placeholder="祝你生日快乐!希望你喜欢这份礼物。",
    )
    run_btn = gr.Button("Translate ✅", variant="primary")
    en_out = gr.Textbox(lines=3, label="English output")
    info_out = gr.Markdown()
    examples = gr.Examples(
        examples=[
            ["祝你生日快乐!希望你喜欢这份礼物。"],
            ["我们将于明天上午九点开会,请准时参加。"],
            ["根据《中国疼痛医学发展报告(2020)》的数据显示,我国目前有超过3亿人正在经受慢性疼痛的困扰,慢性疼痛已经成为仅次于心脑血管疾病和肿瘤的第三大健康问题。"],
        ],
        inputs=[zh_in],
        label="Examples",
    )

    run_btn.click(fn=translate_single, inputs=[zh_in], outputs=[en_out, info_out])
    zh_in.submit(fn=translate_single, inputs=[zh_in], outputs=[en_out, info_out])  # Press Enter to translate

demo.queue(max_size=16).launch(server_name="0.0.0.0", server_port=7860, show_error=True)