import logging
import time
from asyncio import Queue as AioQueue
from dataclasses import asdict
from multiprocessing import shared_memory
from queue import Queue
from threading import Thread
from typing import Dict, List, Tuple

import numpy as np
import orjson
from redis import ConnectionPool, Redis

from inference.core.entities.requests.inference import (
    InferenceRequest,
    request_from_type,
)
from inference.core.env import MAX_ACTIVE_MODELS, MAX_BATCH_SIZE, REDIS_HOST, REDIS_PORT
from inference.core.managers.base import ModelManager
from inference.core.managers.decorators.fixed_size_cache import WithFixedSizeCache
from inference.core.models.roboflow import RoboflowInferenceModel
from inference.core.registries.roboflow import RoboflowModelRegistry
from inference.enterprise.parallel.tasks import postprocess
from inference.enterprise.parallel.utils import (
    SharedMemoryMetadata,
    failure_handler,
    shm_manager,
)

logging.basicConfig(level=logging.WARNING)
logger = logging.getLogger()

from inference.models.utils import ROBOFLOW_MODEL_TYPES

# Fall back to a finite default batch size when MAX_BATCH_SIZE is unbounded.
BATCH_SIZE = MAX_BATCH_SIZE
if BATCH_SIZE == float("inf"):
    BATCH_SIZE = 32

# Scales the request-age term relative to the queue-fullness term in
# select_best_inference_batch.
AGE_TRADEOFF_SECONDS_FACTOR = 30
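
# Redis layout this server expects (inferred from the reads/writes below; the
# keys are populated by the preprocessing stage, which is not shown here):
#   - hash "requests": model_name -> count of pending requests for that model
#   - sorted set "infer:{model_name}": serialized request payloads, scored by
#     whatever ordering value the preprocessing stage assigned at enqueue time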


class InferServer:
    def __init__(self, redis: Redis) -> None:
        self.redis = redis
        model_registry = RoboflowModelRegistry(ROBOFLOW_MODEL_TYPES)
        model_manager = ModelManager(model_registry)
        self.model_manager = WithFixedSizeCache(
            model_manager, max_size=MAX_ACTIVE_MODELS
        )
        self.running = True
        # Responses are written to shared memory and dispatched to postprocessing
        # on a dedicated writer thread.
        self.response_queue = Queue()
        self.write_thread = Thread(target=self.write_responses)
        self.write_thread.start()
        # A single-slot queue hands assembled batches from the batching loop to
        # the inference thread.
        self.batch_queue = Queue(maxsize=1)
        self.infer_thread = Thread(target=self.infer)
        self.infer_thread.start()

    def write_responses(self):
        while True:
            try:
                response = self.response_queue.get()
                write_infer_arrays_and_launch_postprocess(*response)
            except Exception as error:
                logger.warning(
                    "Encountered error while writing response:\n" + str(error)
                )

    def infer_loop(self):
        while self.running:
            try:
                model_names = get_requested_model_names(self.redis)
                if not model_names:
                    time.sleep(0.001)
                    continue
                self.get_batch(model_names)
            except Exception as error:
                logger.warning("Encountered error in infer loop:\n" + str(error))
                continue

    def infer(self):
        while True:
            model_id, images, batch, preproc_return_metadatas = self.batch_queue.get()
            outputs = self.model_manager.predict(model_id, images)
            # zip(*outputs) transposes the batched output arrays into one tuple
            # of arrays per image, matched up with its request and metadata.
            for output, b, metadata in zip(
                zip(*outputs), batch, preproc_return_metadatas
            ):
                self.response_queue.put_nowait((output, b["request"], metadata))

    def get_batch(self, model_names):
        start = time.perf_counter()
        batch, model_id = get_batch(self.redis, model_names)
        logger.info(f"Inferring: model<{model_id}> batch_size<{len(batch)}>")
        with failure_handler(self.redis, *[b["request"]["id"] for b in batch]):
            self.model_manager.add_model(model_id, batch[0]["request"]["api_key"])
            model_type = self.model_manager.get_task_type(model_id)
            for b in batch:
                request = request_from_type(model_type, b["request"])
                b["request"] = request
                b["shm_metadata"] = SharedMemoryMetadata(**b["shm_metadata"])
            metadata_processed = time.perf_counter()
            logger.info(
                f"Took {(metadata_processed - start):.3f} seconds to process metadata"
            )
            with shm_manager(
                *[b["shm_metadata"].shm_name for b in batch], unlink_on_success=True
            ) as shms:
                images, preproc_return_metadatas = load_batch(batch, shms)
                loaded = time.perf_counter()
                logger.info(
                    f"Took {(loaded - metadata_processed):.3f} seconds to load batch"
                )
                self.batch_queue.put(
                    (model_id, images, batch, preproc_return_metadatas)
                )


def get_requested_model_names(redis: Redis) -> List[str]:
    request_counts = redis.hgetall("requests")
    model_names = [
        model_name for model_name, count in request_counts.items() if int(count) > 0
    ]
    return model_names


def get_batch(redis: Redis, model_names: List[str]) -> Tuple[List[Dict], str]:
    """Run a heuristic to select the best batch to infer on.

    Args:
        redis (Redis): Redis client.
        model_names (List[str]): Models with a nonzero number of pending requests.

    Returns:
        Tuple[List[Dict], str]: The selected batch of request dicts and the id
        of the model it belongs to.
    """
    batch_sizes = [
        RoboflowInferenceModel.model_metadata_from_memcache_endpoint(m)["batch_size"]
        for m in model_names
    ]
    batch_sizes = [b if not isinstance(b, str) else BATCH_SIZE for b in batch_sizes]
    batches = [
        redis.zrange(f"infer:{m}", 0, b - 1, withscores=True)
        for m, b in zip(model_names, batch_sizes)
    ]
    model_index = select_best_inference_batch(batches, batch_sizes)
    batch = batches[model_index]
    selected_model = model_names[model_index]
    redis.zrem(f"infer:{selected_model}", *[b[0] for b in batch])
    redis.hincrby("requests", selected_model, -len(batch))
    batch = [orjson.loads(b[0]) for b in batch]
    return batch, selected_model


def select_best_inference_batch(batches, batch_sizes):
    now = time.time()
    # Score term: mean (score - now) across each candidate batch, scaled by the
    # tradeoff factor. Fullness term: fraction of the model's batch size filled.
    average_ages = [np.mean([float(b[1]) - now for b in batch]) for batch in batches]
    lengths = [
        len(batch) / batch_size for batch, batch_size in zip(batches, batch_sizes)
    ]
    fitnesses = [
        age / AGE_TRADEOFF_SECONDS_FACTOR + length
        for age, length in zip(average_ages, lengths)
    ]
    model_index = fitnesses.index(max(fitnesses))
    return model_index
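

# Illustrative sketch, not used by the server: a worked example of the selection
# heuristic above. The concrete scores are hypothetical and assume the sorted-set
# score is an enqueue timestamp; the real convention is set by the preprocessing
# stage.
def _example_select_best_inference_batch() -> int:
    now = time.time()
    # Model A: two requests enqueued 20s and 10s ago; model B: eight fresh requests.
    batch_a = [("req-a-1", now - 20.0), ("req-a-2", now - 10.0)]
    batch_b = [(f"req-b-{i}", now - 1.0) for i in range(8)]
    # fitness_a = -15/30 + 2/32 = -0.4375, fitness_b = -1/30 + 8/32 ~= 0.2167,
    # so index 1 (model B) is returned.
    return select_best_inference_batch([batch_a, batch_b], [32, 32])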


def load_batch(
    batch: List[Dict[str, str]], shms: List[shared_memory.SharedMemory]
) -> Tuple[List[np.ndarray], List[Dict]]:
    images = []
    preproc_return_metadatas = []
    for b, shm in zip(batch, shms):
        shm_metadata: SharedMemoryMetadata = b["shm_metadata"]
        # Copy the image out of shared memory so the segment can be unlinked.
        image = np.ndarray(
            shm_metadata.array_shape, dtype=shm_metadata.array_dtype, buffer=shm.buf
        ).copy()
        images.append(image)
        preproc_return_metadatas.append(b["preprocess_metadata"])
    return images, preproc_return_metadatas


def write_infer_arrays_and_launch_postprocess(
    arrs: Tuple[np.ndarray, ...],
    request: InferenceRequest,
    preproc_return_metadata: Dict,
):
    """Write inference results to shared memory and launch the postprocessing task."""
    shms = [shared_memory.SharedMemory(create=True, size=arr.nbytes) for arr in arrs]
    with shm_manager(*shms):
        shm_metadatas = []
        for arr, shm in zip(arrs, shms):
            shared = np.ndarray(arr.shape, dtype=arr.dtype, buffer=shm.buf)
            shared[:] = arr[:]
            shm_metadata = SharedMemoryMetadata(
                shm_name=shm.name, array_shape=arr.shape, array_dtype=arr.dtype.name
            )
            shm_metadatas.append(asdict(shm_metadata))
        postprocess.s(
            tuple(shm_metadatas), request.dict(), preproc_return_metadata
        ).delay()
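

# Sketch for illustration only: how a downstream consumer could reattach to the
# arrays written above, mirroring load_batch(). The real consumer is the
# postprocess task in inference.enterprise.parallel.tasks; this helper is an
# assumption about its reading side, not a copy of it.
def _example_read_infer_arrays(shm_metadatas: Tuple[Dict, ...]) -> List[np.ndarray]:
    arrays = []
    for meta in shm_metadatas:
        shm = shared_memory.SharedMemory(name=meta["shm_name"])
        try:
            # Copy out of the shared buffer before closing the handle.
            arrays.append(
                np.ndarray(
                    meta["array_shape"], dtype=meta["array_dtype"], buffer=shm.buf
                ).copy()
            )
        finally:
            shm.close()
    return arrays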


if __name__ == "__main__":
    pool = ConnectionPool(host=REDIS_HOST, port=REDIS_PORT, decode_responses=True)
    redis = Redis(connection_pool=pool)
    InferServer(redis).infer_loop()