Spaces:
Running
Running
| import concurrent.futures | |
| import random | |
| import gradio as gr | |
| import requests | |
| import io, base64, json | |
| # import spaces | |
| from PIL import Image | |
| from .model_config import model_config | |
| from .model_worker import BaseModelWorker | |
class ModelManager:
    """Registry of model workers plus helpers to run them singly or in pairs.

    Workers are built from ``model_config``, whose values expose ``model_name``,
    ``i2s_model``, ``online_model`` and ``model_path``, and are stored in
    ``self.models_worker`` keyed by model name.
    (``i2s``/``t2s`` presumably mean image-to-shape / text-to-shape — TODO confirm.)
    """

    def __init__(self):
        self.models_config = model_config
        # Model name -> worker. This is a dict; the old annotation said list.
        self.models_worker: dict[str, BaseModelWorker] = {}
        self.build_model_workers()

    def build_model_workers(self):
        """Create one ``BaseModelWorker`` per config entry, keyed by model name."""
        for cfg in self.models_config.values():
            worker = BaseModelWorker(cfg.model_name, cfg.i2s_model, cfg.online_model, cfg.model_path)
            self.models_worker[cfg.model_name] = worker

    def get_all_models(self):
        """Return every configured model name, in the config's iteration order."""
        return list(self.models_config)

    def get_t2s_models(self):
        """Return names of models whose config has a falsy ``i2s_model`` flag."""
        return [cfg.model_name for cfg in self.models_config.values() if not cfg.i2s_model]

    def get_i2s_models(self):
        """Return names of models whose config has a truthy ``i2s_model`` flag."""
        return [cfg.model_name for cfg in self.models_config.values() if cfg.i2s_model]

    def get_online_models(self):
        """Return names of models whose config has a truthy ``online_model`` flag."""
        return [cfg.model_name for cfg in self.models_config.values() if cfg.online_model]

    def get_models(self, i2s_model: bool, online_model: bool):
        """Return names of models whose ``i2s_model`` and ``online_model`` flags
        equal the given values."""
        return [
            cfg.model_name
            for cfg in self.models_config.values()
            if cfg.i2s_model == i2s_model and cfg.online_model == online_model
        ]

    def check_online(self, name):
        """Return whether model ``name`` is an online model that is currently reachable.

        Non-online models return False. Online models defer to the worker's
        own ``check_online()``. Previously this method always returned None.
        """
        worker = self.models_worker[name]
        if not worker.online_model:
            return False
        return bool(worker.check_online())

    # @spaces.GPU(duration=120)
    def inference(self,
                  prompt, model_name,
                  offline=False, offline_idx=None):
        """Produce one model's result for ``prompt``.

        If ``offline`` is true, first ask the worker for a stored result via
        ``load_offline(offline_idx)``. If that yields None, or ``offline`` is
        false, run ``worker.inference(prompt)`` — but only when
        ``worker.check_online()`` is truthy.

        Returns:
            The worker's result, or None when neither path produced one.
        """
        worker = self.models_worker[model_name]
        result = worker.load_offline(offline_idx) if offline else None
        # `is None` (not `== None`) so array-like results never trigger
        # element-wise comparison.
        if result is None and worker.check_online():
            result = worker.inference(prompt)
        return result

    def render(self, shape, model_name):
        """Render ``shape`` using the worker registered for ``model_name``."""
        return self.models_worker[model_name].render(shape)

    def _run_pair(self, fn, args_a, args_b):
        """Run ``fn(*args_a)`` and ``fn(*args_b)`` concurrently; return both results in that order."""
        with concurrent.futures.ThreadPoolExecutor(max_workers=2) as executor:
            future_a = executor.submit(fn, *args_a)
            future_b = executor.submit(fn, *args_b)
            # Read results in submission order. Iterating as_completed() (the
            # previous approach) yields completion order, which could swap A and B.
            return future_a.result(), future_b.result()

    def inference_parallel(self,
                           prompt, model_A, model_B,
                           offline=False, offline_idx=None):
        """Run ``model_A`` and ``model_B`` concurrently on ``prompt``.

        Returns:
            ``(result_A, result_B)``, each positionally matched to its model.
        """
        return self._run_pair(
            self.inference,
            (prompt, model_A, offline, offline_idx),
            (prompt, model_B, offline, offline_idx),
        )

    def inference_parallel_anony(self,
                                 prompt, model_A, model_B,
                                 i2s_model: bool, offline: bool = False, offline_idx: "int | None" = None):
        """Anonymous-battle variant of :meth:`inference_parallel`.

        When both model names are ``""``, two distinct models are drawn at
        random:

        * offline: from all i2s models (if ``i2s_model``) or all t2s models;
        * otherwise: from the online models matching ``i2s_model``.

        ``random.sample`` raises ValueError if fewer than two candidates exist.

        Returns:
            ``(result_A, result_B, model_A, model_B)``, with each result
            positionally matched to its model name.
        """
        if model_A == model_B == "":
            if offline:
                pool = self.get_i2s_models() if i2s_model else self.get_t2s_models()
            else:
                pool = self.get_models(i2s_model=i2s_model, online_model=True)
            model_A, model_B = random.sample(pool, 2)
        result_A, result_B = self._run_pair(
            self.inference,
            (prompt, model_A, offline, offline_idx),
            (prompt, model_B, offline, offline_idx),
        )
        return result_A, result_B, model_A, model_B

    def render_parallel(self, shape_A, model_A, shape_B, model_B):
        """Render ``shape_A`` with ``model_A`` and ``shape_B`` with ``model_B`` concurrently.

        Returns:
            ``(render_A, render_B)``, positionally matched to the inputs.
        """
        return self._run_pair(self.render, (shape_A, model_A), (shape_B, model_B))