import json
import os
from typing import Any, Dict, List, Optional

import joblib
from transformers import PreTrainedModel

from .configuration_knn import KNNConfig


class KNNModel(PreTrainedModel):
    """
    A tiny wrapper so an sklearn KNN (joblib) can be saved/loaded with the
    transformers save_pretrained / from_pretrained pattern.

    Notes:
    - We persist the sklearn object as `model.joblib` inside the folder.
    - Loading from the Hub via `transformers` will require
      `trust_remote_code=True` or using this module locally.
    """

    config_class = KNNConfig
    base_model_prefix = "knn"

    def __init__(
        self,
        config: KNNConfig,
        model: Optional[Any] = None,
        models: Optional[List] = None,
    ):
        super().__init__(config)
        # self.knn is the actual sklearn KNN object
        # (e.g., sklearn.neighbors.KNeighborsClassifier) for single models.
        self.knn = model
        # self.models is a list of sklearn KNN objects for ensemble models.
        self.models = models or []
        self.is_ensemble = config.is_ensemble or len(self.models) > 1

    def forward(self, X, **kwargs):
        """Return predictions for an input array-like X.

        For ensemble models, uses the first model's predictions.
        (You can implement voting/averaging logic here if desired.)

        This is intentionally simple; you can adapt to return ModelOutput
        structured objects if desired.

        Raises:
            ValueError: if no sklearn object has been loaded yet.
        """
        if self.is_ensemble and self.models:
            # Use first model for now; could implement ensemble voting
            return self.models[0].predict(X)
        if self.knn is not None:
            return self.knn.predict(X)
        raise ValueError("Model not loaded. Call from_pretrained or load a joblib model first.")

    def save_pretrained(self, save_directory: str, **kwargs) -> None:
        """
        Save only the config and the sklearn object(s).

        We intentionally avoid calling the parent `save_pretrained` because the
        transformers implementation expects a PyTorch model (and tries to infer
        a `dtype` from model tensors), which fails for non-torch objects and
        raises the IndexError seen in CI/when running locally.

        Instead we use the config's `save_pretrained` method and persist the
        sklearn object as `model.joblib` (or multiple files for ensembles).
        """
        os.makedirs(save_directory, exist_ok=True)
        # save config.json (PretrainedConfig handles this)
        self.config.save_pretrained(save_directory)
        # persist sklearn object(s) as joblib
        if self.is_ensemble and self.models:
            # Save each ensemble member with its original filename
            for member_name, model_obj in zip(self.config.ensemble_members, self.models):
                out_path = os.path.join(save_directory, member_name)
                member_dir = os.path.dirname(out_path)
                # Member names usually carry a "models/" prefix; guard against
                # a bare filename, where makedirs("") would raise.
                if member_dir:
                    os.makedirs(member_dir, exist_ok=True)
                joblib.dump(model_obj, out_path)
        elif self.knn is not None:
            joblib.dump(self.knn, os.path.join(save_directory, "model.joblib"))

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: str, *model_args, **kwargs):
        """
        Load a KNN model with optional variant selection.

        Supports two modes:
        1. Direct loading: loads model.joblib from the specified path/repo
        2. Variant selection: specify parameters to auto-select a model variant

        For ensemble models (7T-21T, Synthetic), automatically loads all sub-models.

        Args:
            pretrained_model_name_or_path: Local path or HF Hub repo ID
            data_source: Optional. One of: "7T", "21T", "7T-21T", "Synthetic"
            k_neighbors: Optional. 1 or 3
            metric: Optional. "euclidean" or "manhattan"
            training_version: Optional. For single models only, ignored for ensembles
            variant: Optional. Direct variant name (e.g., "knn_21T_k1_euclidean")

        Examples:
            # Load default best model
            model = KNNModel.from_pretrained("SaeedLab/dom-formula-assignment-using-knn")

            # Load specific variant by parameters
            model = KNNModel.from_pretrained(
                "SaeedLab/dom-formula-assignment-using-knn",
                data_source="7T-21T",  # This is an ensemble!
                k_neighbors=1,
                metric="euclidean"
            )

            # Load by variant name
            model = KNNModel.from_pretrained(
                "SaeedLab/dom-formula-assignment-using-knn",
                variant="knn_21T_k3_manhattan"
            )
        """
        # Extract variant selection parameters so they don't reach the config loader
        data_source = kwargs.pop("data_source", None)
        k_neighbors = kwargs.pop("k_neighbors", None)
        metric = kwargs.pop("metric", None)
        training_version = kwargs.pop("training_version", None)
        variant = kwargs.pop("variant", None)
        # Pop hub kwargs too, so they are not set as stray config attributes.
        hub_kwargs = kwargs.pop("hub_kwargs", {})

        # Determine if this is an ensemble model
        is_ensemble = data_source in ("7T-21T", "Synthetic") if data_source else False

        # load config using parent machinery (handles repo id or local path)
        config = cls.config_class.from_pretrained(pretrained_model_name_or_path, **kwargs)

        # Update config with variant info if provided
        if k_neighbors is not None:
            config.n_neighbors = k_neighbors
        if metric is not None:
            config.metric = metric
        if data_source is not None:
            config.data_source = data_source
        if training_version is not None:
            config.training_version = training_version

        if is_ensemble:
            # Load ensemble model (multiple joblib files)
            model_filenames = cls._resolve_ensemble_filenames(
                pretrained_model_name_or_path,
                variant=variant,
                data_source=data_source,
                k_neighbors=k_neighbors,
                metric=metric,
            )
            config.is_ensemble = True
            config.ensemble_members = model_filenames
            models = [
                cls._load_joblib(pretrained_model_name_or_path, name, hub_kwargs)
                for name in model_filenames
            ]
            return cls(config=config, models=models)

        # Load single model
        model_filename = cls._resolve_model_filename(
            pretrained_model_name_or_path,
            variant=variant,
            data_source=data_source,
            k_neighbors=k_neighbors,
            metric=metric,
            training_version=training_version,
        )
        config.is_ensemble = False
        knn = cls._load_joblib(pretrained_model_name_or_path, model_filename, hub_kwargs)
        return cls(config=config, model=knn)

    @staticmethod
    def _load_joblib(
        pretrained_model_name_or_path: str,
        model_filename: str,
        hub_kwargs: Dict[str, Any],
    ) -> Any:
        """Load one joblib artifact from a local folder or the HF Hub.

        Args:
            pretrained_model_name_or_path: Local directory or HF Hub repo ID.
            model_filename: Relative path of the .joblib file to load.
            hub_kwargs: Extra keyword arguments forwarded to hf_hub_download.

        Returns:
            The unpickled sklearn object.

        Raises:
            RuntimeError: if the file can neither be found locally nor downloaded.
        """
        model_file = os.path.join(pretrained_model_name_or_path, model_filename)
        if os.path.exists(model_file):
            return joblib.load(model_file)
        # try to download from hub
        try:
            from huggingface_hub import hf_hub_download

            model_path = hf_hub_download(
                repo_id=pretrained_model_name_or_path,
                filename=model_filename,
                **hub_kwargs,
            )
            # NOTE: joblib.load unpickles arbitrary code — only load artifacts
            # from repos you trust.
            return joblib.load(model_path)
        except Exception as exc:
            raise RuntimeError(
                f"Could not locate or download {model_filename} for {pretrained_model_name_or_path}: {exc}"
            ) from exc

    @staticmethod
    def _resolve_model_filename(
        pretrained_model_name_or_path: str,
        variant: Optional[str] = None,
        data_source: Optional[str] = None,
        k_neighbors: Optional[int] = None,
        metric: Optional[str] = None,
        training_version: Optional[str] = None,
    ) -> str:
        """
        Resolve the model filename based on variant parameters.

        Returns:
            Filename of the .joblib model to load
            (e.g., "models/knn_21T_k1_euclidean.joblib")

        Raises:
            ValueError: if the index is unusable and data_source/k_neighbors/metric
                are not all provided, so no filename can be constructed.
        """
        # If direct variant name provided, use it
        if variant:
            # Ensure .joblib extension
            if not variant.endswith(".joblib"):
                variant = f"{variant}.joblib"
            # Check if it needs models/ prefix
            if not variant.startswith("models/"):
                return f"models/{variant}"
            return variant

        # If no parameters provided, use default (best performing model)
        if not any([data_source, k_neighbors, metric, training_version]):
            return "models/knn_21T_k1_euclidean.joblib"

        # Try to load model index to find matching variant; failure to load it
        # is tolerated and handled by the constructed-filename fallback below.
        index = None
        try:
            index_path = os.path.join(pretrained_model_name_or_path, "model_index.json")
            if not os.path.exists(index_path):
                # Try to download from hub
                from huggingface_hub import hf_hub_download

                index_path = hf_hub_download(
                    repo_id=pretrained_model_name_or_path,
                    filename="model_index.json",
                )
            with open(index_path, "r") as f:
                index = json.load(f)
        except Exception:
            index = None

        if index is not None:
            # Search for matching variant; only the provided parameters constrain
            # the match.
            for variant_info in index.get("variants", {}).values():
                if data_source and variant_info.get("data_source") != data_source:
                    continue
                if k_neighbors and variant_info.get("k_neighbors") != k_neighbors:
                    continue
                # get() may return None for a malformed entry; guard before lower()
                if metric and (variant_info.get("metric") or "").lower() != metric.lower():
                    continue
                if training_version and variant_info.get("training_version") != training_version:
                    continue
                return variant_info["filename"]

        # Fallback: construct filename from parameters
        if not data_source or not k_neighbors or not metric:
            raise ValueError(
                "Could not load model_index.json and insufficient parameters provided. "
                "Please specify: data_source, k_neighbors, and metric"
            )
        ds = data_source.replace("-", "")  # "7T-21T" -> "7T21T"
        version_suffix = f"_{training_version}" if training_version else ""
        return f"models/knn_{ds}_k{k_neighbors}_{metric.lower()}{version_suffix}.joblib"

    @staticmethod
    def _resolve_ensemble_filenames(
        pretrained_model_name_or_path: str,
        variant: Optional[str] = None,
        data_source: Optional[str] = None,
        k_neighbors: Optional[int] = None,
        metric: Optional[str] = None,
    ) -> List[str]:
        """
        Resolve ensemble model filenames (multiple .joblib files for one logical model).

        For 7T-21T: returns 2 filenames (ver2 and ver3)
        For Synthetic: returns 3 filenames (ver2, ver3, synthetic_data)

        Returns:
            List of filenames to load

        Raises:
            ValueError: if data_source is missing, not an ensemble type, or
                k_neighbors/metric are missing.
        """
        if not data_source:
            raise ValueError("data_source is required for ensemble models")
        # Define ensemble members for each type
        ensemble_versions = {
            "7T-21T": ["DOM_training_set_ver2", "DOM_training_set_ver3"],
            "Synthetic": [
                "DOM_training_set_ver2",
                "DOM_training_set_ver3",
                "synthetic_data",
            ],
        }
        if data_source not in ensemble_versions:
            raise ValueError(f"data_source '{data_source}' is not an ensemble model")
        if not k_neighbors or not metric:
            raise ValueError("k_neighbors and metric are required for ensemble models")

        # Construct filenames based on original naming pattern
        # Pattern: knn_model_Model-{data_source}_K{k}_{Metric}_{training_version}.joblib
        metric_name = metric.capitalize()  # "euclidean" -> "Euclidean"
        return [
            f"models/knn_model_Model-{data_source}_K{k_neighbors}_{metric_name}_{version}.joblib"
            for version in ensemble_versions[data_source]
        ]