# ------------------------------------------------------------------------------------------------- # app.py – Streamlit app para predição de GO:MF # • ProtBERT / ProtBERT-BFD fine-tuned (melvinalves/FineTune) # • ESM-2 base (facebook/esm2_t33_650M_UR50D) # ------------------------------------------------------------------------------------------------- import os, re, numpy as np, torch, joblib, streamlit as st from huggingface_hub import login from transformers import AutoTokenizer, AutoModel from keras.models import load_model from goatools.obo_parser import GODag # ——————————————————— AUTENTICAÇÃO ——————————————————— # login(os.environ["HF_TOKEN"]) # ——————————————————— CONFIG ——————————————————— # SPACE_ID = "melvinalves/protein_function_prediction" TOP_N = 20 # mostra agora top-20 THRESH = 0.37 CHUNK_PB = 512 # janela ProtBERT / ProtBERT-BFD CHUNK_ESM = 1024 # janela ESM-2 # repositórios HF FINETUNED_PB = ("melvinalves/FineTune", "fineTunedProtbert") FINETUNED_BFD = ("melvinalves/FineTune", "fineTunedProtbertbfd") BASE_ESM = "facebook/esm2_t33_650M_UR50D" # ——————————————————— HELPERS ——————————————————— # @st.cache_resource def download_file(path): """Ficheiros pequenos (≤1 GB) guardados no Space.""" from huggingface_hub import hf_hub_download return hf_hub_download(repo_id=SPACE_ID, repo_type="space", filename=path) @st.cache_resource def load_keras(name): """Carrega modelos Keras (MLPs e stacking).""" return load_model(download_file(f"models/{name}"), compile=False) # ---------- carregar tokenizer + encoder ---------- @st.cache_resource def load_hf_encoder(repo_id, subfolder=None, base_tok=None): """ • repo_id : repositório HF ou caminho local • subfolder : subpasta onde vivem pesos/config (None se não houver) • base_tok : repo para o tokenizer (None => usa repo_id) Converte tf_model.h5 → PyTorch on-the-fly (from_tf=True). """ if base_tok is None: base_tok = repo_id tok = AutoTokenizer.from_pretrained(base_tok, do_lower_case=False) kwargs = dict(from_tf=True) if subfolder: kwargs["subfolder"] = subfolder mdl = AutoModel.from_pretrained(repo_id, **kwargs) mdl.eval() return tok, mdl # ---------- extrair embedding ---------- def embed_seq(model_ref, seq, chunk): """ • model_ref = string (modelo base) OU tuple(repo_id, subfolder) (modelo fine-tuned) Retorna embedding CLS médio (caso a sequência seja dividida em chunks). """ if isinstance(model_ref, tuple): # ProtBERT / ProtBERT-BFD fine-tuned repo_id, subf = model_ref tok, mdl = load_hf_encoder(repo_id, subfolder=subf, base_tok="Rostlab/prot_bert") else: # modelo base (ESM-2) tok, mdl = load_hf_encoder(model_ref) parts = [seq[i:i+chunk] for i in range(0, len(seq), chunk)] vecs = [] for p in parts: toks = tok(" ".join(p), return_tensors="pt", truncation=False) with torch.no_grad(): out = mdl(**{k: v.to(mdl.device) for k, v in toks.items()}) vecs.append(out.last_hidden_state[:, 0, :].cpu().numpy()) return np.mean(vecs, axis=0) @st.cache_resource def load_go_info(): """Lê GO.obo e devolve dicionário id → (name, definition).""" obo_path = download_file("data/go.obo") dag = GODag(obo_path, optional_attrs=["defn"]) return {tid: (term.name, term.defn) for tid, term in dag.items()} GO_INFO = load_go_info() # ——————————————————— CARGA MODELOS ——————————————————— # mlp_pb = load_keras("mlp_protbert.h5") mlp_bfd = load_keras("mlp_protbertbfd.h5") mlp_esm = load_keras("mlp_esm2.h5") stacking = load_keras("ensemble_stack.h5") mlb = joblib.load(download_file("data/mlb_597.pkl")) GO = mlb.classes_ # ——————————————————— UI ——————————————————— # # --- aspecto geral st.set_page_config(page_title="Predição de Funções Moleculares de Proteínas", page_icon="🧬", layout="centered") # fundo branco + pequenos ajustes de margem/padding st.markdown(""" """, unsafe_allow_html=True) # logo (coloca um ficheiro logo.png na pasta raiz do Space) LOGO_PATH = "logo.png" if os.path.exists(LOGO_PATH): st.image(LOGO_PATH, width=180) st.title("Predição de Funções Moleculares de Proteínas (GO:MF)") fasta_input = st.text_area("Insere uma ou mais sequências FASTA:", height=300) predict_clicked = st.button("Prever GO terms") # ——————————————————— PARSE DE MÚLTIPLAS SEQUÊNCIAS ——————————————————— # def parse_fasta_multiple(fasta_str): """ Devolve lista de (header, seq) a partir de texto FASTA possivelmente múltiplo. Suporta bloco inicial sem '>'. """ entries, parsed = fasta_str.strip().split(">"), [] for i, entry in enumerate(entries): if not entry.strip(): continue lines = entry.strip().splitlines() if i > 0: # bloco típico FASTA header = lines[0].strip() seq = "".join(lines[1:]).replace(" ", "").upper() else: # sequência sem '>' header = f"Seq_{i+1}" seq = "".join(lines).replace(" ", "").upper() if seq: parsed.append((header, seq)) return parsed # ——————————————————— FUNÇÕES AUXILIARES DE LAYOUT ——————————————————— # def go_link(go_id, name=""): """Cria link para página do GO term (QuickGO).""" url = f"https://www.ebi.ac.uk/QuickGO/term/{go_id}" label = f"{go_id} — {name}" if name else go_id return f"[{label}]({url})" def prot_link(header): """Tenta gerar link para UniProt usando o primeiro token do header.""" pid = header.split()[0] url = f"https://www.uniprot.org/uniprotkb/{pid}" return f"[{header}]({url})" # ——————————————————— FUNÇÃO PRINCIPAL DE RESULTADOS ——————————————————— # def mostrar(tag, y_pred): """Mostra resultados em duas colunas dentro de um expander.""" with st.expander(tag, expanded=True): col1, col2 = st.columns(2) # ——— coluna 1 : termos acima do threshold with col1: st.markdown(f"**GO terms com prob ≥ {THRESH}**") hits = mlb.inverse_transform((y_pred >= THRESH).astype(int))[0] if hits: for go_id in hits: name, defin = GO_INFO.get(go_id, ("— sem nome —", "")) defin = re.sub(r'^\\s*"?(.+?)"?\\s*(\\[[^\\]]*\\])?\\s*$', r'\\1', defin or "") st.markdown(f"- {go_link(go_id, name)} ") if defin: st.caption(defin) else: st.code("— nenhum —") # ——— coluna 2 : top-N mais prováveis with col2: st.markdown(f"**Top {TOP_N} GO terms mais prováveis**") for rank, idx in enumerate(np.argsort(-y_pred[0])[:TOP_N], start=1): go_id = GO[idx] name, _ = GO_INFO.get(go_id, ("", "")) st.markdown(f"{rank}. {go_link(go_id, name)} : {y_pred[0][idx]:.4f}") # ——————————————————— INFERÊNCIA ——————————————————— # if predict_clicked: parsed_seqs = parse_fasta_multiple(fasta_input) if not parsed_seqs: st.warning("Não foi possível encontrar nenhuma sequência válida.") st.stop() for header, seq in parsed_seqs: with st.spinner(f"A processar {header}… (pode demorar alguns minutos)"): # ———————————— EMBEDDINGS ———————————— # emb_pb = embed_seq(FINETUNED_PB, seq, CHUNK_PB) emb_bfd = embed_seq(FINETUNED_BFD, seq, CHUNK_PB) emb_esm = embed_seq(BASE_ESM, seq, CHUNK_ESM) # ———————————— PREDIÇÕES MLPs ———————————— # y_pb = mlp_pb.predict(emb_pb) y_bfd = mlp_bfd.predict(emb_bfd) y_esm = mlp_esm.predict(emb_esm)[:, :597] # alinhar nº de termos # ———————————— STACKING ———————————— # X = np.concatenate([y_pb, y_bfd, y_esm], axis=1) y_ens = stacking.predict(X) # header como link para UniProt mostrar(prot_link(header), y_ens) # ——————————————————— LISTA COMPLETA DE TERMOS SUPORTADOS ——————————————————— # with st.expander("Mostrar lista completa dos 597 GO terms possíveis", expanded=False): cols = st.columns(3) for i, go_id in enumerate(GO): name, _ = GO_INFO.get(go_id, ("", "")) cols[i % 3].markdown(f"- {go_link(go_id, name)}")