"""Gradio app that serves MetaLATTE metal-binding predictions.

Start-up order matters:

1. Point the Hugging Face caches at ``/data`` *before* any Hugging Face
   library is imported.
2. Download the tokenizer and model snapshots into ``/data/snapshots``.
3. Register the MetaLATTE config/model classes and patch
   ``MultitaskProteinModel.from_pretrained`` so it can load a local directory.
4. Build and launch the Gradio interface.
"""

# ---- BOOTSTRAP: stable cache to /data, minimal downloads ----
import os
import subprocess
import sys

os.makedirs("/data/.cache/huggingface/hub", exist_ok=True)
os.makedirs("/data/snapshots", exist_ok=True)
os.environ.setdefault("XDG_CACHE_HOME", "/data/.cache")
os.environ.setdefault("HF_HOME", "/data/.cache/huggingface")
os.environ.setdefault("HF_HUB_CACHE", "/data/.cache/huggingface/hub")

# Optional: keep pip cache small. Best effort only: a missing or
# non-executable `pip` must never stop the app from starting.
try:
    subprocess.run(["pip", "cache", "purge"], check=False)
except OSError:
    pass
# ---- END BOOTSTRAP ----

# These imports deliberately come AFTER the environment set-up above:
# huggingface_hub resolves its cache locations when it is first imported, so
# importing it earlier would make the cache variables above ineffective.
import gradio as gr  # noqa: E402
import pandas as pd  # noqa: E402
import torch  # noqa: E402
from huggingface_hub import snapshot_download  # noqa: E402
from transformers import AutoConfig, AutoModel, AutoTokenizer  # noqa: E402

# Pin via Space → Settings → Variables if you want (helps avoid repeated downloads)
MODEL_ID = "ChatterjeeLab/MetaLATTE"
TOKENIZER_ID = "facebook/esm2_t33_650M_UR50D"
MODEL_REV = os.getenv("MODEL_REV", "ad1716045c768b30ce87eb6b3963d58578fa5401")  # pinned commit
TOKENIZER_REV = os.getenv("TOKENIZER_REV", "")  # empty string means "no pinned revision"


def snapshot_to(local_name, repo_id, revision, allow_patterns):
    """Download a filtered snapshot of a Hub repo into ``/data/snapshots``.

    Args:
        local_name: Sub-directory name created under ``/data/snapshots``.
        repo_id: Hub repository id, e.g. ``"org/name"``.
        revision: Revision to fetch; any falsy value (``""`` or ``None``) is
            forwarded as ``None``.
        allow_patterns: Glob patterns restricting which files are downloaded.

    Returns:
        Whatever ``snapshot_download`` returns for the requested ``local_dir``.
    """
    local_dir = f"/data/snapshots/{local_name}"
    os.makedirs(local_dir, exist_ok=True)
    return snapshot_download(
        repo_id=repo_id,
        revision=revision or None,
        allow_patterns=allow_patterns,
        local_dir=local_dir,  # new hub ignores symlink flag; this is enough
    )


# Download tokenizer (unchanged)
esm_local = snapshot_to(
    "esm2_tokenizer",
    TOKENIZER_ID,
    TOKENIZER_REV,
    allow_patterns=[
        "tokenizer.json",
        "tokenizer_config.json",
        "vocab.*",
        "merges.*",
        "special_tokens_map.json",
        "*.model",
        "tokenizer*.txt",
        "spiece.*",
        "*.tiktoken",
        "config.json",
    ],
)

# Download MetaLATTE: include both main and stage1 in case your loader uses them
metalatte_local = snapshot_to(
    "metalatte_model",
    MODEL_ID,
    MODEL_REV,
    allow_patterns=[
        "config.json",
        "pytorch_model.bin",
        "model/pytorch_model.bin",
        "model.safetensors",
        "model/model.safetensors",
        "stage1_model.bin",
        "model/stage1_model.bin",
    ],
)

# --- your local package ---
sys.path.insert(0, ".")
from configuration import MetaLATTEConfig  # noqa: E402
from modeling_metalatte import MultitaskProteinModel  # noqa: E402

# Register types BEFORE loading
AutoConfig.register("metalatte", MetaLATTEConfig)
AutoModel.register(MetaLATTEConfig, MultitaskProteinModel)


# ---- Monkey-patch: make your from_pretrained support local dirs ----
def _local_aware_from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
    """Replacement ``from_pretrained`` that can load from a local directory.

    When ``pretrained_model_name_or_path`` is an existing directory, the model
    is built from its config and the first weight file found there is loaded
    strictly. Any other value is delegated to the original ``from_pretrained``.

    Args:
        cls: Model class (supplied by ``classmethod``).
        pretrained_model_name_or_path: Local directory, or an identifier the
            original loader understands.
        *args: Forwarded to the original loader on the non-local path.
        **kwargs: ``config`` is honoured on the local path; everything is
            forwarded to the original loader on the non-local path.

    Returns:
        The loaded model. On the local path it is already switched to eval mode.

    Raises:
        FileNotFoundError: The directory contains none of the candidate
            weight files.
        RuntimeError: Raised by ``load_state_dict(strict=True)`` when the
            checkpoint keys do not match the model.
    """
    if not os.path.isdir(pretrained_model_name_or_path):
        # Fall back to the original remote/HF logic. `_orig_from_pretrained`
        # is the plain function behind the original classmethod, so `cls`
        # has to be passed explicitly.
        return _orig_from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs)

    config = kwargs.get("config")
    if config is None:
        try:
            # Works because the "metalatte" type was registered above.
            config = AutoConfig.from_pretrained(
                pretrained_model_name_or_path, local_files_only=True
            )
        except Exception:
            # Fallback in case AutoConfig isn't enough.
            config = MetaLATTEConfig.from_pretrained(
                pretrained_model_name_or_path, local_files_only=True
            )

    model = cls(config)

    # Look for weights in common locations; prefer .safetensors > pytorch .bin > stage1
    candidates = [
        "model/model.safetensors",
        "model.safetensors",
        "model/pytorch_model.bin",
        "pytorch_model.bin",
        "model/stage1_model.bin",
        "stage1_model.bin",
    ]
    weight_path = None
    for relative in candidates:
        full_path = os.path.join(pretrained_model_name_or_path, relative)
        if os.path.exists(full_path):
            weight_path = full_path
            break
    if weight_path is None:
        raise FileNotFoundError(
            f"No weights found in {pretrained_model_name_or_path}; tried {candidates}"
        )

    if weight_path.endswith(".safetensors"):
        from safetensors.torch import load_file as load_safetensors

        state = load_safetensors(weight_path, device="cpu")
    else:
        # NOTE(security): torch.load unpickles the file. Only acceptable here
        # because the checkpoint comes from a repository/revision we chose.
        state = torch.load(weight_path, map_location="cpu")

    # STRICT: any missing/unexpected key raises RuntimeError here instead of
    # being silently skipped, so no separate post-check is needed.
    model.load_state_dict(state, strict=True)

    model.eval()
    return model


# Swap in the monkey patch (but keep a handle to the original)
_orig_from_pretrained = MultitaskProteinModel.from_pretrained.__func__
MultitaskProteinModel.from_pretrained = classmethod(_local_aware_from_pretrained)
# --------------------------------------------------------------------

# Load config and model exactly like before (now it will use the local-aware loader)
config = AutoConfig.from_pretrained(metalatte_local, local_files_only=True)
tokenizer = AutoTokenizer.from_pretrained(esm_local, local_files_only=True)
model = AutoModel.from_pretrained(metalatte_local, config=config, local_files_only=True)
model.eval()


@torch.inference_mode()
def predict(sequence):
    """Predict metal-binding labels for one protein sequence.

    Args:
        sequence: Text entered in the UI textbox.

    Returns:
        A one-row ``pandas.DataFrame`` with one column per ``config.id2label``
        entry: ``'✓'`` where the model's prediction equals 1, ``''`` otherwise.

    Raises:
        gr.Error: The input is empty or contains only whitespace.
    """
    if not (sequence or "").strip():
        # Report blank input to the user instead of running the model on it.
        raise gr.Error("Please enter a protein sequence.")

    inputs = tokenizer(sequence, return_tensors="pt")
    _raw_probs, predictions = model.predict(**inputs)
    id2label = config.id2label
    # predictions[0] is treated as the per-label flags for the single input
    # (assumes index i lines up with id2label[i] — confirm against the model).
    row = {
        id2label[i]: ("✓" if int(pred) == 1 else "")
        for i, pred in enumerate(predictions[0])
    }
    return pd.DataFrame([row])


iface = gr.Interface(
    fn=predict,
    inputs=gr.Textbox(lines=3, placeholder="Enter protein sequence here..."),
    outputs=gr.Dataframe(headers=list(config.id2label.values())),
    title="MetaLATTE: Metal Binding Prediction",
    description="Enter a protein sequence to predict its metal binding properties.",
)

iface.launch()