# app.py – Streamlit app for GO:MF (Gene Ontology molecular function) prediction
# ProtBERT / ProtBERT-BFD fine-tuned (melvinalves/FineTune)
# ESM-2 base (facebook/esm2_t33_650M_UR50D)
import os
import re

import joblib
import numpy as np
import streamlit as st
import torch
from goatools.obo_parser import GODag
from huggingface_hub import hf_hub_download, login
from keras.models import load_model
from transformers import AutoModel, AutoTokenizer

# PAGE SETUP #
# Issued before any other Streamlit activity: Streamlit documents
# set_page_config as the first command of a page, and the cached loaders
# called further down show a spinner by default on their first run.
st.set_page_config(page_title="Predição de Funções Moleculares de Proteínas",
                   page_icon="🧬",
                   layout="centered")

# AUTHENTICATION #
# Raises KeyError when HF_TOKEN is not set in the environment.
login(os.environ["HF_TOKEN"])

# CONFIG #
SPACE_ID = "melvinalves/protein_function_prediction"
TOP_N = 20          # number of best-scoring GO terms listed per sequence
THRESH = 0.37       # probability cut-off for a GO term to count as a hit
CHUNK_PB = 512      # residues per chunk for the ProtBERT encoders
CHUNK_ESM = 1024    # residues per chunk for the ESM-2 encoder

# HF REPOSITORIES: (repo_id, subfolder) tuples for the fine-tuned encoders,
# a plain repo id for the base ESM-2 model.
FINETUNED_PB = ("melvinalves/FineTune", "fineTunedProtbert")
FINETUNED_BFD = ("melvinalves/FineTune", "fineTunedProtbertbfd")
BASE_ESM = "facebook/esm2_t33_650M_UR50D"


# HELPERS #
@st.cache_resource
def download_file(path):
    """Download ``path`` from the app's Hugging Face Space and return the local file path."""
    return hf_hub_download(repo_id=SPACE_ID, repo_type="space", filename=path)


@st.cache_resource
def load_keras(name):
    """Load the Keras model stored as ``models/<name>`` in the Space, without compiling it."""
    return load_model(download_file(f"models/{name}"), compile=False)


@st.cache_resource
def load_hf_encoder(repo_id, subfolder=None, base_tok=None):
    """Load a tokenizer and an encoder, reading the encoder from TensorFlow weights.

    Args:
        repo_id: Hugging Face repository holding the encoder weights.
        subfolder: optional subfolder of ``repo_id`` containing the weights.
        base_tok: repository to take the tokenizer from; defaults to ``repo_id``.

    Returns:
        A ``(tokenizer, model)`` pair; the model has been switched to eval mode.
    """
    if base_tok is None:
        base_tok = repo_id
    tok = AutoTokenizer.from_pretrained(base_tok, do_lower_case=False)
    kwargs = dict(from_tf=True)
    if subfolder:
        kwargs["subfolder"] = subfolder
    mdl = AutoModel.from_pretrained(repo_id, **kwargs)
    mdl.eval()
    return tok, mdl


def embed_seq(model_ref, seq, chunk):
    """Return the mean first-token embedding of ``seq``, computed chunk by chunk.

    Args:
        model_ref: a ``(repo_id, subfolder)`` tuple for a fine-tuned ProtBERT
            encoder (tokenized with ``Rostlab/prot_bert``), or a plain
            repository id for the base ESM-2 model.
        seq: amino-acid sequence; must not be empty.
        chunk: maximum number of residues fed to the encoder at once.

    Returns:
        The element-wise mean, over all chunks, of the hidden state at
        position 0 of each chunk.

    Raises:
        ValueError: if ``seq`` is empty (the mean over zero chunks is undefined).
    """
    if not seq:
        raise ValueError("embed_seq() requires a non-empty sequence")

    if isinstance(model_ref, tuple):      # fine-tuned ProtBERT
        repo_id, subf = model_ref
        tok, mdl = load_hf_encoder(repo_id, subfolder=subf,
                                   base_tok="Rostlab/prot_bert")
    else:                                 # base ESM-2 model
        tok, mdl = load_hf_encoder(model_ref)

    vecs = []
    for start in range(0, len(seq), chunk):
        piece = seq[start:start + chunk]
        # Residues are passed to the tokenizer separated by single spaces.
        toks = tok(" ".join(piece), return_tensors="pt", truncation=False)
        with torch.no_grad():
            out = mdl(**{k: v.to(mdl.device) for k, v in toks.items()})
        vecs.append(out.last_hidden_state[:, 0, :].cpu().numpy())
    return np.mean(vecs, axis=0)


@st.cache_resource
def load_go_info():
    """Parse ``data/go.obo`` and return ``{term_id: (name, raw_definition)}``."""
    dag = GODag(download_file("data/go.obo"), optional_attrs=["defn"])
    return {tid: (term.name, term.defn) for tid, term in dag.items()}


GO_INFO = load_go_info()

# MODELS #
mlp_pb = load_keras("mlp_protbert.h5")
mlp_bfd = load_keras("mlp_protbertbfd.h5")
mlp_esm = load_keras("mlp_esm2.h5")
stacking = load_keras("ensemble_stack.h5")

# NOTE(review): joblib.load unpickles the file; it is only safe as long as
# the Space contents are trusted.
mlb = joblib.load(download_file("data/mlb_597.pkl"))
GO = mlb.classes_

# UI #
# NOTE(review): the markup string below contains only a single space as it
# appears in this file; confirm whether styling markup was meant to be here.
st.markdown(
    " ",
    unsafe_allow_html=True
)

if os.path.exists("logo.png"):
    st.image("logo.png", width=180)

st.title("Predição de Funções Moleculares de Proteínas (GO:MF)")

fasta_input = st.text_area("Insere uma ou mais sequências FASTA:", height=300)
predict_clicked = st.button("Prever GO terms")


# UTILITIES #
def parse_fasta_multiple(text):
    """Extract ``[(header, sequence), ...]`` from FASTA text.

    Text before the first ``>`` is accepted as a header-less record. A record
    whose ``>`` line carries no header text is kept and given a generated
    ``Seq_<n>`` name instead of having its first sequence line taken as the
    header. All whitespace (spaces, tabs, line breaks) is removed from the
    sequence, which is upper-cased. Records without any sequence are skipped.
    """
    out = []
    for i, blk in enumerate(text.strip().split(">")):
        if not blk.strip():
            continue
        if i:
            # Split on the unstripped block so an empty header line stays the header.
            lines = blk.splitlines()
            header = lines[0].strip()
            seq_lines = lines[1:]
        else:
            header = ""
            seq_lines = blk.splitlines()
        if not header:
            header = f"Seq_{len(out) + 1}"
        seq = "".join("".join(seq_lines).split()).upper()
        if seq:
            out.append((header, seq))
    return out


def clean_definition(defin: str) -> str:
    """Return the human-readable part of a raw GO definition.

    The text inside the first pair of double quotes is returned; when there
    is no quoted text, everything before the first ``[`` is returned. Empty
    or missing input yields ``""``.
    """
    if not defin:
        return ""
    m = re.search(r'"([^"]+)"', defin)
    if m:
        return m.group(1).strip()
    return defin.split("[", 1)[0].strip()


def go_link(go_id, name=""):
    """Return a Markdown link to the QuickGO page of ``go_id``, labelled with ``name`` if given."""
    url = f"https://www.ebi.ac.uk/QuickGO/term/{go_id}"
    return f"[{go_id} - {name}]({url})" if name else f"[{go_id}]({url})"


# SHOW RESULTS #
def mostrar(header, y_pred):
    """Render the predictions for one sequence inside an expander.

    Args:
        header: record header used as the expander title; its first
            whitespace-separated token is taken as the protein id.
        y_pred: prediction scores, indexed here as ``y_pred[0][idx]`` —
            presumably shape ``(1, n_terms)``; confirm against the stacking model.
    """
    tokens = header.split()
    pid = tokens[0] if tokens else header   # avoid IndexError on a blank header
    # NOTE(review): `uniprot` is not referenced by the markup below as it
    # appears in this file — confirm whether a link was meant to be rendered.
    uniprot = f"https://www.uniprot.org/uniprotkb/{pid}"

    with st.expander(header, expanded=True):
        st.markdown(
            "\n",
            unsafe_allow_html=True
        )
        col1, col2 = st.columns(2)

        # column 1: terms at or above the threshold
        with col1:
            st.markdown(f"**GO terms com prob ≥ {THRESH}**")
            hits = mlb.inverse_transform((y_pred >= THRESH).astype(int))[0]
            if hits:
                for go_id in hits:
                    name, defin_raw = GO_INFO.get(go_id, ("- sem nome -", ""))
                    defin = clean_definition(defin_raw)
                    st.markdown(f"- {go_link(go_id, name)}")
                    if defin:
                        st.caption(defin)
            else:
                st.code("- nenhum -")

        # column 2: the TOP_N highest-scoring terms
        with col2:
            st.markdown(f"**Top {TOP_N} GO terms mais prováveis**")
            scores = y_pred[0]
            for rank, idx in enumerate(np.argsort(-scores)[:TOP_N], 1):
                go_id = GO[idx]
                name, _ = GO_INFO.get(go_id, ("", ""))
                st.markdown(f"{rank}. {go_link(go_id, name)} : {scores[idx]:.4f}")


# INFERENCE #
if predict_clicked:
    records = parse_fasta_multiple(fasta_input)
    if not records:
        # Give feedback instead of silently rendering nothing.
        st.warning("Nenhuma sequência FASTA válida foi encontrada.")
    for header, seq in records:
        with st.spinner(f"A processar {header}… (pode demorar alguns minutos)"):
            emb_pb = embed_seq(FINETUNED_PB, seq, CHUNK_PB)
            emb_bfd = embed_seq(FINETUNED_BFD, seq, CHUNK_PB)
            emb_esm = embed_seq(BASE_ESM, seq, CHUNK_ESM)

            y_pb = mlp_pb.predict(emb_pb)
            y_bfd = mlp_bfd.predict(emb_bfd)
            # NOTE(review): only the first 597 outputs of the ESM-2 MLP are
            # kept — presumably to match the 597 labels of mlb_597.pkl; confirm.
            y_esm = mlp_esm.predict(emb_esm)[:, :597]
            y_ens = stacking.predict(np.concatenate([y_pb, y_bfd, y_esm], axis=1))
        mostrar(header, y_ens)

# FULL LIST WITH SEARCH BOX #
with st.expander("Mostrar lista completa dos 597 GO terms possíveis", expanded=False):
    search_term = st.text_input("Filtra GO term ou nome:")

    # Normalise the query once; an empty query matches every term.
    needle = search_term.strip().lower()
    filtered_go_terms = []
    for go_id in GO:
        name, _ = GO_INFO.get(go_id, ("", ""))
        if needle in go_id.lower() or needle in name.lower():
            filtered_go_terms.append((go_id, name))

    # lay the matches out across three columns
    if filtered_go_terms:
        cols = st.columns(3)
        for i, (go_id, name) in enumerate(filtered_go_terms):
            cols[i % 3].markdown(f"- {go_link(go_id, name)}")
    else:
        st.info("Nenhum GO term corresponde ao filtro inserido.")