Spaces:
Running
Running
| import os | |
| import json | |
| import time | |
| import kiui | |
| from typing import List | |
| import replicate | |
| import subprocess | |
| from constants import OFFLINE_GIF_DIR | |
| # Set REPLICATE_API_TOKEN in the environment (e.g. os.environ["REPLICATE_API_TOKEN"] = "<your-token>"); never commit real tokens to source. | |
class BaseModelWorker:
    """Common configuration and interface shared by all model workers.

    A worker may serve cached ("offline") results loaded from
    ``OFFLINE_GIF_DIR/<model_name>.json``; subclasses override
    ``inference`` / ``render`` to run a live backend.
    """

    def __init__(self,
                 model_name: str,
                 i2s_model: bool,
                 online_model: bool,
                 model_api: str = None
                 ):
        """Store the worker configuration and load the offline cache if present.

        Args:
            model_name: Model name; also the stem of the offline JSON cache file.
            i2s_model: Flag stored as-is (presumably image-to-shape vs.
                text-to-shape — confirm against callers).
            online_model: Whether the model is served by a live backend.
            model_api: Optional backend identifier used by subclasses.
        """
        self.model_name = model_name
        self.i2s_model = i2s_model
        self.online_model = online_model
        self.model_api = model_api
        # Parsed JSON from the cache file (looked up by str(offline_idx));
        # stays None when no cache file exists.
        self.urls_json = None
        urls_json_path = os.path.join(OFFLINE_GIF_DIR, f"{model_name}.json")
        if os.path.exists(urls_json_path):
            # JSON text is UTF-8; don't depend on the platform's locale encoding.
            with open(urls_json_path, 'r', encoding='utf-8') as f:
                self.urls_json = json.load(f)

    def check_online(self) -> bool:
        """Return True if the worker is online and has no local model loaded.

        ``self.model`` is never assigned by this class, so it is looked up
        defensively; a missing attribute counts as "no local model" instead
        of raising AttributeError.
        """
        return bool(self.online_model and not getattr(self, "model", None))

    def load_offline(self, offline: bool, offline_idx):
        """Return the cached result for *offline_idx*, or None.

        Args:
            offline: Only consult the cache when this is truthy.
            offline_idx: Cache key; converted with ``str`` because JSON
                object keys are always strings.

        Returns:
            The cached entry, or None when offline mode is off, no cache file
            was loaded, or the key is absent.
        """
        if not offline or not self.urls_json:
            return None
        return self.urls_json.get(str(offline_idx))

    def inference(self, prompt):
        """Run the model on *prompt*; the base implementation does nothing and returns None."""

    def render(self, shape, rgb_on=True, normal_on=True):
        """Render *shape*; the base implementation does nothing and returns None."""
class HuggingfaceApiWorker(BaseModelWorker):
    """Worker for models reached through a Hugging Face API endpoint.

    It currently adds nothing of its own; the constructor simply hands its
    configuration to ``BaseModelWorker`` (``model_api`` is required here).
    """

    def __init__(self, model_name: str, i2s_model: bool, online_model: bool, model_api: str):
        super().__init__(
            model_name=model_name,
            i2s_model=i2s_model,
            online_model=online_model,
            model_api=model_api,
        )
class PointE_Worker(BaseModelWorker):
    """Worker for the Point-E model; forwards its configuration unchanged.

    ``model_api`` is required for this worker.
    """

    def __init__(self, model_name: str, i2s_model: bool, online_model: bool, model_api: str):
        super().__init__(
            model_name=model_name,
            i2s_model=i2s_model,
            online_model=online_model,
            model_api=model_api,
        )
class TriplaneGaussian(BaseModelWorker):
    """Worker for the TriplaneGaussian model; forwards its configuration unchanged.

    ``model_api`` is optional and defaults to None.
    """

    def __init__(self,
                 model_name: str,
                 i2s_model: bool,
                 online_model: bool,
                 model_api: str = None):
        super().__init__(
            model_name=model_name,
            i2s_model=i2s_model,
            online_model=online_model,
            model_api=model_api,
        )
class LGM_Worker(BaseModelWorker):
    """Image-to-3D worker that runs LGM remotely through the Replicate API.

    ``inference`` submits an image to the Replicate model and returns its
    ``.ply`` output; ``render`` converts a result to a mesh and renders it
    with the ``kiui.render`` command-line tool.
    """

    def __init__(self,
                 model_name: str,
                 i2s_model: bool,
                 online_model: bool,
                 model_api: str = "camenduru/lgm:d2870893aa115773465a823fe70fd446673604189843f39a99642dd9171e05e2",
                 ):
        """Configure the worker and create its Replicate client.

        Args:
            model_name: Forwarded to ``BaseModelWorker``.
            i2s_model: Forwarded to ``BaseModelWorker``.
            online_model: Forwarded to ``BaseModelWorker``.
            model_api: Replicate ``owner/model:version`` identifier passed to
                ``Client.run``; defaults to a pinned LGM version.
        """
        super().__init__(model_name, i2s_model, online_model, model_api)
        # No module-level REPLICATE_API_TOKEN is defined here, so read the
        # token from the environment (the same variable the old comment used).
        self.model_client = replicate.Client(
            api_token=os.environ.get("REPLICATE_API_TOKEN"))

    def inference(self, image):
        """Run LGM on *image* through Replicate and return its ``.ply`` output.

        Args:
            image: Sent as Replicate's ``input_image`` (a URL or file handle —
                confirm which forms the model accepts).

        Returns:
            Element ``[1]`` of the Replicate output, which is ``[.mp4, .ply]``.
        """
        output = self.model_client.run(
            self.model_api,
            input={"input_image": image}
        )
        # => [.mp4, .ply]
        return output[1]

    def render(self, shape, rgb_on=True, normal_on=True, *,
               path_normal=None, path_rgb=None):
        """Convert *shape* to a mesh and render its normal and RGB views.

        Args:
            shape: Value handed to the splat-to-mesh client (presumably the
                ``.ply`` returned by ``inference`` — confirm).
            rgb_on: Run the RGB pass (same flag as ``BaseModelWorker.render``).
            normal_on: Run the normal pass (same flag as ``BaseModelWorker.render``).
            path_normal: ``--save`` target for the normal pass; defaults to the
                mesh path with its extension replaced by ``_normal``.
            path_rgb: ``--save`` target for the RGB pass; defaults to the mesh
                path with its extension replaced by ``_rgb``.

        Returns:
            ``(path_normal, path_rgb)``; an entry is None if its pass is disabled.

        Raises:
            subprocess.CalledProcessError: if a render command exits non-zero.
        """
        # NOTE(review): Gau2Mesh_client is neither defined nor imported in this
        # file; it must be provided at module level before render() is called.
        mesh = Gau2Mesh_client.run(shape)
        mesh_path = str(mesh)
        stem = os.path.splitext(mesh_path)[0]

        if normal_on:
            if path_normal is None:
                path_normal = f"{stem}_normal"
            self._kiui_render(mesh_path, path_normal, "normal")
        else:
            path_normal = None

        if rgb_on:
            if path_rgb is None:
                path_rgb = f"{stem}_rgb"
            # NOTE(review): confirm "rgb" is an accepted --mode for kiui.render.
            self._kiui_render(mesh_path, path_rgb, "rgb")
        else:
            path_rgb = None

        return path_normal, path_rgb

    @staticmethod
    def _kiui_render(mesh_path, save_path, mode):
        """Run ``python -m kiui.render`` on *mesh_path* in *mode*, saving to *save_path*."""
        import sys  # used only to re-launch the current interpreter

        # `-m` needs an absolute module name, so use the installed `kiui`
        # package (imported at the top of this file). An argument list with the
        # default shell=False avoids quoting/injection problems with paths.
        cmd = [
            sys.executable, "-m", "kiui.render", mesh_path,
            "--save", save_path,
            "--wogui",
            "--H", "512", "--W", "512",
            "--radius", "3", "--elevation", "0", "--num_azimuth", "40",
            "--front_dir=+z",
            "--mode", mode,
        ]
        subprocess.run(cmd, check=True)
class V3D_Worker(BaseModelWorker):
    """Worker for the V3D model; forwards its configuration unchanged.

    ``model_api`` is optional and defaults to None.
    """

    def __init__(self,
                 model_name: str,
                 i2s_model: bool,
                 online_model: bool,
                 model_api: str = None):
        super().__init__(
            model_name=model_name,
            i2s_model=i2s_model,
            online_model=online_model,
            model_api=model_api,
        )
| # model = 'LGM' | |
| # # model = 'TriplaneGaussian' | |
| # folder = 'glbs_full' | |
| # form = 'glb' | |
| # pose = '+z' | |
| # pair = ('OpenLRM', 'meshes', 'obj', '-y') | |
| # pair = ('TriplaneGaussian', 'glbs_full', 'glb', '-y') | |
| # pair = ('LGM', 'glbs_full', 'glb', '+z') | |
if __name__ == "__main__":
    # Manual smoke test: convert a previously generated LGM Gaussian splat
    # (.ply) into a mesh using the splat-to-mesh Hugging Face Space.
    #
    # Live LGM inference, kept for reference:
    #   client = replicate.Client(api_token=REPLICATE_API_TOKEN)
    #   client.run(
    #       "camenduru/lgm:d2870893aa115773465a823fe70fd446673604189843f39a99642dd9171e05e2",
    #       input={"input_image": "https://replicate.delivery/pbxt/KN0hQI9pYB3NOpHLqktkkQIblwpXt0IG7qI90n5hEnmV9kvo/bird_rgba.png"},
    #   )
    #   # => [<...>/gradio_output.mp4, <...>/gradio_output.ply]
    #
    # NOTE(review): `Client` is not imported in this file; it is presumably
    # gradio_client.Client and must be imported before this script can run.
    sample_lgm_output = [
        'https://replicate.delivery/pbxt/RPSTEes37lzAJav3jy1lPuzizm76WGU4IqDcFcAMxhQocjUJA/gradio_output.mp4',
        'https://replicate.delivery/pbxt/2Vy8yrPO3PYiI1YJBxPXAzryR0SC0oyqW3XKPnXiuWHUuRqE/gradio_output.ply',
    ]
    splat_to_mesh = Client(
        "https://dylanebert-splat-to-mesh.hf.space/",
        upload_files=True,
        download_files=True,
    )
    ply_url = sample_lgm_output[1]
    mesh_result = splat_to_mesh.predict(ply_url, api_name="/run")
    print(mesh_result)