Spaces:
Configuration error
Configuration error
| import hashlib | |
| from dataclasses import asdict | |
| from typing import Any, Dict, List, Optional | |
| from inference.core import logger | |
| from inference.core.active_learning.entities import ( | |
| ActiveLearningConfiguration, | |
| RoboflowProjectMetadata, | |
| SamplingMethod, | |
| ) | |
| from inference.core.active_learning.samplers.close_to_threshold import ( | |
| initialize_close_to_threshold_sampling, | |
| ) | |
| from inference.core.active_learning.samplers.contains_classes import ( | |
| initialize_classes_based_sampling, | |
| ) | |
| from inference.core.active_learning.samplers.number_of_detections import ( | |
| initialize_detections_number_based_sampling, | |
| ) | |
| from inference.core.active_learning.samplers.random import initialize_random_sampling | |
| from inference.core.cache.base import BaseCache | |
| from inference.core.exceptions import ( | |
| ActiveLearningConfigurationDecodingError, | |
| ActiveLearningConfigurationError, | |
| RoboflowAPINotAuthorizedError, | |
| RoboflowAPINotNotFoundError, | |
| ) | |
| from inference.core.roboflow_api import ( | |
| get_roboflow_active_learning_configuration, | |
| get_roboflow_dataset_type, | |
| get_roboflow_workspace, | |
| ) | |
| from inference.core.utils.roboflow import get_model_id_chunks | |
# Dispatch table: value of the "type" key of a sampling strategy config ->
# initializer that builds the matching sampling method.
TYPE2SAMPLING_INITIALIZERS = dict(
    random=initialize_random_sampling,
    close_to_threshold=initialize_close_to_threshold_sampling,
    classes_based=initialize_classes_based_sampling,
    detections_number_based=initialize_detections_number_based_sampling,
)
# Expiry passed to the cache when storing project metadata: 15 minutes,
# expressed in seconds.
ACTIVE_LEARNING_CONFIG_CACHE_EXPIRE = 15 * 60
def prepare_active_learning_configuration(
    api_key: str,
    model_id: str,
    cache: BaseCache,
) -> Optional[ActiveLearningConfiguration]:
    """Resolve the Active Learning configuration for a model.

    Project metadata is obtained through `get_roboflow_project_metadata`
    (which consults `cache` first).

    Returns:
        The initialised configuration, or None when the project's Active
        Learning configuration is not enabled.
    """
    metadata = get_roboflow_project_metadata(
        api_key=api_key,
        model_id=model_id,
        cache=cache,
    )
    al_config = metadata.active_learning_configuration
    if al_config.get("enabled", False):
        logger.info(
            f"Configuring active learning for workspace: {metadata.workspace_id}, "
            f"project: {metadata.dataset_id} of type: {metadata.dataset_type}. "
            f"AL configuration: {al_config}"
        )
        return initialise_active_learning_configuration(project_metadata=metadata)
    # Active Learning is switched off (or the flag is missing) for this project.
    return None
def prepare_active_learning_configuration_inplace(
    api_key: str,
    model_id: str,
    active_learning_configuration: Optional[dict],
) -> Optional[ActiveLearningConfiguration]:
    """Build the Active Learning configuration from a caller-supplied config dict.

    Unlike `prepare_active_learning_configuration`, the Active Learning
    settings are not fetched from the Roboflow API nor cached - only the
    workspace and dataset type are resolved through the API.

    Args:
        api_key: Roboflow API key used to resolve workspace and dataset type.
        model_id: Model identifier, split by `get_model_id_chunks` into
            dataset id and version id.
        active_learning_configuration: Active Learning settings, or None.

    Returns:
        The initialised configuration, or None when no configuration was
        given or its "enabled" flag is missing / falsy.
    """
    # Truthiness check (instead of `is False`) so that falsy values such as
    # None or 0 also disable Active Learning - the same rule that
    # `prepare_active_learning_configuration` applies to the "enabled" flag.
    if not active_learning_configuration or not active_learning_configuration.get(
        "enabled", False
    ):
        return None
    dataset_id, version_id = get_model_id_chunks(model_id=model_id)
    workspace_id = get_roboflow_workspace(api_key=api_key)
    dataset_type = get_roboflow_dataset_type(
        api_key=api_key,
        workspace_id=workspace_id,
        dataset_id=dataset_id,
    )
    project_metadata = RoboflowProjectMetadata(
        dataset_id=dataset_id,
        version_id=version_id,
        workspace_id=workspace_id,
        dataset_type=dataset_type,
        active_learning_configuration=active_learning_configuration,
    )
    return initialise_active_learning_configuration(
        project_metadata=project_metadata,
    )
def get_roboflow_project_metadata(
    api_key: str,
    model_id: str,
    cache: BaseCache,
) -> RoboflowProjectMetadata:
    """Return project metadata (including Active Learning settings) for a model.

    The cache is consulted first; on a miss the metadata is assembled from
    Roboflow API calls and stored in the cache with
    `ACTIVE_LEARNING_CONFIG_CACHE_EXPIRE` expiry.

    Raises:
        ActiveLearningConfigurationDecodingError: propagated from
            `parse_cached_roboflow_project_metadata` when the cached entry
            cannot be turned back into `RoboflowProjectMetadata`.
    """
    logger.info("Fetching active learning configuration.")
    cache_key = construct_cache_key_for_active_learning_config(
        api_key=api_key, model_id=model_id
    )
    cached_payload = cache.get(cache_key)
    if cached_payload is not None:
        logger.info("Found Active Learning configuration in cache.")
        return parse_cached_roboflow_project_metadata(cached_config=cached_payload)

    # Cache miss - gather everything from the API.
    dataset_id, version_id = get_model_id_chunks(model_id=model_id)
    workspace_id = get_roboflow_workspace(api_key=api_key)
    dataset_type = get_roboflow_dataset_type(
        api_key=api_key,
        workspace_id=workspace_id,
        dataset_id=dataset_id,
    )
    try:
        al_settings = get_roboflow_active_learning_configuration(
            api_key=api_key, workspace_id=workspace_id, dataset_id=dataset_id
        )
    except (RoboflowAPINotAuthorizedError, RoboflowAPINotNotFoundError):
        # The backend currently answers HTTP 404 both when the dataset does
        # not exist and when the workspace derived from the api_key is not
        # the dataset owner (e.g. querying a Universe dataset). Only the
        # owner of a public dataset should be able to define and use AL
        # settings, so for now 404 effectively means "not authorised" (this
        # may change in a future backend iteration). Hence both NotAuthorized
        # and NotFound are treated as: AL cannot be used with this model and
        # this api_key.
        al_settings = {"enabled": False}
    metadata = RoboflowProjectMetadata(
        dataset_id=dataset_id,
        version_id=version_id,
        workspace_id=workspace_id,
        dataset_type=dataset_type,
        active_learning_configuration=al_settings,
    )
    # Stored as a plain dict so it can be rebuilt on a later cache hit.
    cache.set(
        key=cache_key,
        value=asdict(metadata),
        expire=ACTIVE_LEARNING_CONFIG_CACHE_EXPIRE,
    )
    return metadata
def construct_cache_key_for_active_learning_config(api_key: str, model_id: str) -> str:
    """Build the cache key under which project metadata for a model is stored.

    The key contains an MD5 digest of the API key (so the raw key is never
    written to the cache) followed by the full model id.

    The whole `model_id` (dataset AND version) is part of the key because the
    cached metadata carries the model version; keying by dataset only would
    let a request for one version be served metadata cached for another.
    Since `model_id` starts with the dataset id, the key still begins with the
    `active_learning:configurations:<hash>:<dataset_id>` prefix.
    """
    api_key_hash = hashlib.md5(api_key.encode("utf-8")).hexdigest()
    return f"active_learning:configurations:{api_key_hash}:{model_id}"
def parse_cached_roboflow_project_metadata(
    cached_config: dict,
) -> RoboflowProjectMetadata:
    """Rebuild `RoboflowProjectMetadata` from its cached dict form.

    Raises:
        ActiveLearningConfigurationDecodingError: when the cached content
            cannot be used to construct the metadata object (any exception
            raised by the constructor is wrapped and chained).
    """
    try:
        metadata = RoboflowProjectMetadata(**cached_config)
    except Exception as error:
        cause = str(error)
        raise ActiveLearningConfigurationDecodingError(
            f"Failed to initialise Active Learning configuration. Cause: {cause}"
        ) from error
    return metadata
def initialise_active_learning_configuration(
    project_metadata: RoboflowProjectMetadata,
) -> ActiveLearningConfiguration:
    """Create `ActiveLearningConfiguration` out of project metadata.

    Sampling methods are built from the "sampling_strategies" entry of the
    Active Learning settings (the entry is required - a missing key raises
    KeyError). The target workspace / dataset default to the ones from the
    metadata unless "target_workspace" / "target_project" are present in the
    settings.
    """
    al_settings = project_metadata.active_learning_configuration
    strategies = al_settings["sampling_strategies"]
    return ActiveLearningConfiguration.init(
        roboflow_api_configuration=al_settings,
        sampling_methods=initialize_sampling_methods(
            sampling_strategies_configs=strategies,
        ),
        workspace_id=al_settings.get(
            "target_workspace", project_metadata.workspace_id
        ),
        dataset_id=al_settings.get("target_project", project_metadata.dataset_id),
        model_id=f"{project_metadata.dataset_id}/{project_metadata.version_id}",
    )
def initialize_sampling_methods(
    sampling_strategies_configs: List[Dict[str, Any]]
) -> List[SamplingMethod]:
    """Instantiate sampling methods described by strategy configs.

    Each config must hold a "type" key selecting an initializer from
    `TYPE2SAMPLING_INITIALIZERS`; the whole config dict is handed to that
    initializer. Configs with an unrecognised type are skipped with a warning.

    Args:
        sampling_strategies_configs: Strategy configs, in the order in which
            the resulting methods should be returned.

    Returns:
        Initialised sampling methods (possibly empty).

    Raises:
        KeyError: when a config has no "type" key.
        ActiveLearningConfigurationError: when two initialised methods share
            the same name.
    """
    result = []
    for sampling_strategy_config in sampling_strategies_configs:
        sampling_type = sampling_strategy_config["type"]
        initializer = TYPE2SAMPLING_INITIALIZERS.get(sampling_type)
        if initializer is None:
            # `warning` instead of `warn`: in stdlib logging `Logger.warn` is
            # a deprecated alias that newer Python versions no longer provide.
            logger.warning(
                f"Could not identify sampling method `{sampling_type}` - skipping initialisation."
            )
            continue
        result.append(initializer(sampling_strategy_config))
    # A set collapses repeated names, so a size mismatch means duplicates.
    names = {method.name for method in result}
    if len(names) != len(result):
        raise ActiveLearningConfigurationError(
            "Detected duplication of Active Learning strategies names."
        )
    return result