Spaces:
Configuration error
Configuration error
| import os | |
| from typing import Optional, Tuple, Union | |
| from inference.core.cache import cache | |
| from inference.core.devices.utils import GLOBAL_DEVICE_ID | |
| from inference.core.entities.types import DatasetID, ModelType, TaskType, VersionID | |
| from inference.core.env import LAMBDA, MODEL_CACHE_DIR | |
| from inference.core.exceptions import ( | |
| MissingApiKeyError, | |
| ModelArtefactError, | |
| ModelNotRecognisedError, | |
| ) | |
| from inference.core.logger import logger | |
| from inference.core.models.base import Model | |
| from inference.core.registries.base import ModelRegistry | |
| from inference.core.roboflow_api import ( | |
| MODEL_TYPE_DEFAULTS, | |
| MODEL_TYPE_KEY, | |
| PROJECT_TASK_TYPE_KEY, | |
| ModelEndpointType, | |
| get_roboflow_dataset_type, | |
| get_roboflow_model_data, | |
| get_roboflow_workspace, | |
| ) | |
| from inference.core.utils.file_system import dump_json, read_json | |
| from inference.core.utils.roboflow import get_model_id_chunks | |
| from inference.models.aliases import resolve_roboflow_model_alias | |
# Models that are not tied to a Roboflow project. `get_model_type` looks the
# dataset id up here first and returns the mapped
# (project task type, model type) pair without consulting the cache or the API.
GENERIC_MODELS = {
    "clip": ("embed", "clip"),
    "sam": ("embed", "sam"),
    "gaze": ("gaze", "l2cs"),
    "doctr": ("ocr", "doctr"),
    "grounding_dino": ("object-detection", "grounding-dino"),
    "cogvlm": ("llm", "cogvlm"),
    "yolo_world": ("object-detection", "yolo-world"),
}

# Version id that makes `get_model_type` resolve the model type to "stub".
STUB_VERSION_ID = "0"

# Passed as `expire` to `cache.lock` around cache metadata reads and writes
# (presumably seconds - confirm against the cache implementation).
CACHE_METADATA_LOCK_TIMEOUT = 1.0
class RoboflowModelRegistry(ModelRegistry):
    """Model registry that resolves Roboflow model ids to registered entries.

    The model id is first translated into a model type by `get_model_type`;
    that value is then used as the key into `self.registry_dict`.
    """

    def get_model(self, model_id: str, api_key: str) -> Model:
        """Look up the registry entry matching the type of the given model.

        Args:
            model_id (str): The ID of the model to be retrieved.
            api_key (str): The API key forwarded to `get_model_type`.

        Returns:
            Model: The entry stored in `self.registry_dict` under the resolved
                model type.

        Raises:
            ModelNotRecognisedError: If the resolved model type is not a key
                of `self.registry_dict`.
        """
        model_type = get_model_type(model_id, api_key)
        if model_type in self.registry_dict:
            return self.registry_dict[model_type]
        raise ModelNotRecognisedError(f"Model type not supported: {model_type}")
def get_model_type(
    model_id: str,
    api_key: Optional[str] = None,
) -> Tuple[TaskType, ModelType]:
    """Resolve the project task type and model type for a model id.

    Resolution order:
      1. The model id is passed through `resolve_roboflow_model_alias` and
         split into dataset id and version id.
      2. If the dataset id is a key of `GENERIC_MODELS`, the mapped pair is
         returned immediately.
      3. If metadata for (dataset id, version id) is found in the local cache,
         it is returned.
      4. Otherwise the Roboflow API is queried (stub versions use the
         workspace / dataset-type endpoints, all others the ORT model data
         endpoint), the result is saved in the cache and returned.

    Args:
        model_id (str): The ID of the model (aliases are resolved first).
        api_key (Optional[str]): The API key used to authenticate. Required
            for stub versions (`version_id == STUB_VERSION_ID`).

    Returns:
        tuple: The project task type and the model type.

    Raises:
        MissingApiKeyError: If a stub version is requested without an API key.
        ModelArtefactError: If the API response has no "ort" section, or the
            task type / model type cannot be determined from it.

    Note:
        The Roboflow API helpers called here may raise their own errors
        (the previous docstring listed WorkspaceLoadError, DatasetLoadError,
        MissingDefaultModelError and MalformedRoboflowAPIResponseError);
        those propagate unchanged.
    """
    model_id = resolve_roboflow_model_alias(model_id=model_id)
    dataset_id, version_id = get_model_id_chunks(model_id=model_id)

    if dataset_id in GENERIC_MODELS:
        logger.debug(f"Loading generic model: {dataset_id}.")
        return GENERIC_MODELS[dataset_id]

    cache_hit = get_model_metadata_from_cache(
        dataset_id=dataset_id, version_id=version_id
    )
    if cache_hit is not None:
        return cache_hit[0], cache_hit[1]

    if version_id == STUB_VERSION_ID:
        if api_key is None:
            raise MissingApiKeyError(
                "Stub model version provided but no API key was provided. API key is required to load stub models."
            )
        workspace_id = get_roboflow_workspace(api_key=api_key)
        project_task_type = get_roboflow_dataset_type(
            api_key=api_key, workspace_id=workspace_id, dataset_id=dataset_id
        )
        model_type = "stub"
    else:
        api_response = get_roboflow_model_data(
            api_key=api_key,
            model_id=model_id,
            endpoint_type=ModelEndpointType.ORT,
            device_id=GLOBAL_DEVICE_ID,
        )
        api_data = api_response.get("ort")
        if api_data is None:
            raise ModelArtefactError("Error loading model artifacts from Roboflow API.")
        # some older projects do not have type field - hence defaulting
        project_task_type = api_data.get("type", "object-detection")
        model_type = api_data.get("modelType")
        if model_type is None or model_type == "ort":
            # some very old model versions do not have modelType reported - and
            # the API responds in a generic way - then we attempt to use the
            # default model for the given task type
            model_type = MODEL_TYPE_DEFAULTS.get(project_task_type)
        if project_task_type is None or model_type is None:
            raise ModelArtefactError("Error loading model artifacts from Roboflow API.")

    # Both API paths persist the resolved metadata before returning it.
    save_model_metadata_in_cache(
        dataset_id=dataset_id,
        version_id=version_id,
        project_task_type=project_task_type,
        model_type=model_type,
    )
    return project_task_type, model_type
def get_model_metadata_from_cache(
    dataset_id: str, version_id: str
) -> Optional[Tuple[TaskType, ModelType]]:
    """Read cached model metadata, taking the metadata lock unless LAMBDA is set.

    Args:
        dataset_id (str): Dataset part of the model id.
        version_id (str): Version part of the model id.

    Returns:
        The (project task type, model type) pair read from the cache, or None
        when `_get_model_metadata_from_cache` finds no usable entry.
    """
    if not LAMBDA:
        lock_key = f"lock:metadata:{dataset_id}:{version_id}"
        with cache.lock(lock_key, expire=CACHE_METADATA_LOCK_TIMEOUT):
            return _get_model_metadata_from_cache(
                dataset_id=dataset_id, version_id=version_id
            )
    # With LAMBDA set the cache is read directly, without acquiring the lock.
    return _get_model_metadata_from_cache(
        dataset_id=dataset_id, version_id=version_id
    )
def _get_model_metadata_from_cache(
    dataset_id: str, version_id: str
) -> Optional[Tuple[TaskType, ModelType]]:
    """Load the model metadata file for the given dataset / version, if usable.

    Args:
        dataset_id (str): Dataset part of the model id.
        version_id (str): Version part of the model id.

    Returns:
        The (project task type, model type) pair stored in the metadata file,
        or None when the file does not exist, its content is rejected by
        `model_metadata_content_is_invalid`, or a ValueError is raised while
        loading it (in which case a warning is logged).
    """
    cache_path = construct_model_type_cache_path(
        dataset_id=dataset_id, version_id=version_id
    )
    if not os.path.isfile(cache_path):
        return None
    try:
        content = read_json(path=cache_path)
        if not model_metadata_content_is_invalid(content=content):
            return content[PROJECT_TASK_TYPE_KEY], content[MODEL_TYPE_KEY]
    except ValueError as error:
        logger.warning(
            f"Could not load model description from cache under path: {cache_path} - decoding issue: {error}."
        )
    # Reached for invalid content as well as for a decoding failure.
    return None
def model_metadata_content_is_invalid(content: Optional[Union[list, dict]]) -> bool:
    """Tell whether content loaded from a metadata cache file is unusable.

    A warning is logged for every rejection reason.

    Args:
        content: Parsed content of the metadata file (None, a list or a dict).

    Returns:
        bool: True when `content` is None, is not a dict, or lacks either
            `PROJECT_TASK_TYPE_KEY` or `MODEL_TYPE_KEY`; False otherwise.
    """
    if content is None:
        logger.warning("Empty model metadata file encountered in cache.")
        return True
    # isinstance is the idiomatic type check (and still accepts dict subclasses).
    if not isinstance(content, dict):
        logger.warning("Malformed file encountered in cache.")
        return True
    if PROJECT_TASK_TYPE_KEY not in content or MODEL_TYPE_KEY not in content:
        logger.warning(
            f"Could not find one of required keys {PROJECT_TASK_TYPE_KEY} or {MODEL_TYPE_KEY} in cache."
        )
        return True
    return False
def save_model_metadata_in_cache(
    dataset_id: DatasetID,
    version_id: VersionID,
    project_task_type: TaskType,
    model_type: ModelType,
) -> None:
    """Persist model metadata, taking the metadata lock unless LAMBDA is set.

    Args:
        dataset_id: Dataset part of the model id.
        version_id: Version part of the model id.
        project_task_type: Task type to store.
        model_type: Model type to store.
    """
    save_kwargs = dict(
        dataset_id=dataset_id,
        version_id=version_id,
        project_task_type=project_task_type,
        model_type=model_type,
    )
    if LAMBDA:
        # With LAMBDA set the metadata is written without acquiring the lock.
        _save_model_metadata_in_cache(**save_kwargs)
        return None
    lock_key = f"lock:metadata:{dataset_id}:{version_id}"
    with cache.lock(lock_key, expire=CACHE_METADATA_LOCK_TIMEOUT):
        _save_model_metadata_in_cache(**save_kwargs)
    return None
def _save_model_metadata_in_cache(
    dataset_id: DatasetID,
    version_id: VersionID,
    project_task_type: TaskType,
    model_type: ModelType,
) -> None:
    """Write the task type and model type to the model's metadata cache file.

    The target path comes from `construct_model_type_cache_path`; the file is
    written via `dump_json` with `allow_override=True`.

    Args:
        dataset_id: Dataset part of the model id.
        version_id: Version part of the model id.
        project_task_type: Value stored under `PROJECT_TASK_TYPE_KEY`.
        model_type: Value stored under `MODEL_TYPE_KEY`.
    """
    target_path = construct_model_type_cache_path(
        dataset_id=dataset_id, version_id=version_id
    )
    dump_json(
        path=target_path,
        content={
            PROJECT_TASK_TYPE_KEY: project_task_type,
            MODEL_TYPE_KEY: model_type,
        },
        allow_override=True,
        indent=4,
    )
def construct_model_type_cache_path(dataset_id: str, version_id: str) -> str:
    """Build the path of the metadata cache file for a dataset / version.

    Args:
        dataset_id (str): Dataset part of the model id.
        version_id (str): Version part of the model id.

    Returns:
        str: `MODEL_CACHE_DIR/<dataset_id>/<version_id>/model_type.json`.
    """
    return os.path.join(MODEL_CACHE_DIR, dataset_id, version_id, "model_type.json")