Spaces:
Running
Running
| ## All Generation Gradio Interface | |
| import uuid | |
| import time | |
| from .utils import * | |
| from .vote_utils import t2s_logger, t2s_multi_logger, i2s_logger, i2s_multi_logger | |
| from constants import RGBA_SERVER, LOG_SERVER, TEXT_PROMPT_PATH, PROMPT_NUM | |
# Load the text prompts once at import time; the sampling helpers below
# pick entries from this list by position.
with open(TEXT_PROMPT_PATH, 'r', encoding='utf-8') as f:
    prompt_list = json.load(f)

# Raise explicitly instead of using `assert`, so the size check is still
# enforced when Python runs with -O (which strips assert statements).
if len(prompt_list) != PROMPT_NUM:
    raise AssertionError(f"Load {len(prompt_list)} text prompts, but expected {PROMPT_NUM}.")
class State:
    """Per-session record for one model in the generation UI.

    Stores the selected model, the input (text prompt or image), the
    offline flag/index, the three video slots (normal, rgb, geo) and the
    ``evaluted_dims`` counter. Each instance gets a fresh random
    ``conv_id``.
    """

    def __init__(self,
                 model_name, i2s_mode=False, offline=False,
                 prompt=None, image=None, offline_idx=None,
                 normal_video=None, rgb_video=None, geo_video=None,
                 evaluted_dims=0):
        # Unique identifier for this session (32 hex characters).
        self.conv_id = uuid.uuid4().hex

        # What is being generated and how.
        self.model_name = model_name
        self.i2s_mode = i2s_mode
        self.offline = offline
        self.offline_idx = offline_idx

        # Inputs.
        self.prompt = prompt
        self.image = image

        # Outputs.
        self.normal_video = normal_video
        self.rgb_video = rgb_video
        self.geo_video = geo_video

        self.evaluted_dims = evaluted_dims

    def dict(self):
        """Return the loggable summary of this state.

        ``offline_idx`` is included only when ``offline`` is truthy; the
        image and video attributes are never part of the summary.
        """
        summary = {
            "conv_id": self.conv_id,
            "model_name": self.model_name,
            "i2s_mode": self.i2s_mode,
            "offline": self.offline,
            "prompt": self.prompt,
            "evaluted_dims": self.evaluted_dims,
        }
        if not self.offline:
            return summary
        summary['offline_idx'] = self.offline_idx
        return summary
| # class StateI2S: | |
| # def __init__(self, model_name): | |
| # self.conv_id = uuid.uuid4().hex | |
| # self.model_name = model_name | |
| # self.image = None | |
| # self.output = None | |
| # def dict(self): | |
| # base = { | |
| # "conv_id": self.conv_id, | |
| # "model_name": self.model_name, | |
| # } | |
| # return base | |
def sample_t2s_model(state_0, state_1, model_list):
    """Pick two distinct models at random for a text-to-shape pair.

    ``model_list`` is a string that evaluates to a sequence of model names.
    Missing states are created; both states end up with their sampled
    model name and ``i2s_mode`` set to False.

    Returns:
        ``(state_0, state_1, model_name_0, model_name_1)``.
    """
    # NOTE(review): eval() executes arbitrary code. This is only safe while
    # `model_list` is produced by the application itself; if it can ever be
    # user-controlled, switch to ast.literal_eval.
    candidates = eval(model_list)
    picked = random.sample(candidates, 2)

    states = []
    for current, name in zip((state_0, state_1), picked):
        if current is None:
            current = State(name, i2s_mode=False)
        current.model_name = name
        current.i2s_mode = False
        states.append(current)

    return states[0], states[1], picked[0], picked[1]
def sample_i2s_model(state_0, state_1, model_list):
    """Pick two distinct models at random for an image-to-shape pair.

    ``model_list`` is a string that evaluates to a sequence of model names.
    Missing states are created; both states end up with their sampled
    model name and ``i2s_mode`` set to True.

    Returns:
        ``(state_0, state_1, model_name_0, model_name_1)``.
    """
    # NOTE(review): eval() executes arbitrary code. This is only safe while
    # `model_list` is produced by the application itself; if it can ever be
    # user-controlled, switch to ast.literal_eval.
    candidates = eval(model_list)
    picked = random.sample(candidates, 2)

    states = []
    for current, name in zip((state_0, state_1), picked):
        if current is None:
            current = State(name, i2s_mode=True)
        current.model_name = name
        current.i2s_mode = True
        states.append(current)

    return states[0], states[1], picked[0], picked[1]
def sample_prompt(state, model_name):
    """Draw a random prompt from ``prompt_list`` and record it on ``state``.

    A new ``State`` is created when ``state`` is None. The state is switched
    to text mode (``i2s_mode`` False), marked offline, and given the index
    of the drawn prompt.

    Returns:
        ``(state, prompt)``.
    """
    if state is None:
        state = State(model_name)

    idx = random.randint(0, PROMPT_NUM - 1)
    prompt = prompt_list[idx]

    state.model_name = model_name
    state.prompt = prompt
    state.i2s_mode = False
    # Fixed: the original had a trailing comma here, which stored the tuple
    # (True,) instead of the boolean True (and was logged as [true]).
    state.offline = True
    state.offline_idx = idx
    return state, prompt
def sample_prompt_side_by_side(state_0, state_1, model_name_0, model_name_1):
    """Draw one random prompt and record it on both states.

    Missing states are created from the matching model name; existing
    states keep whatever ``model_name`` they already have. Both states are
    switched to text mode, marked offline, and share the same prompt index.

    Returns:
        ``(state_0, state_1, prompt)``.
    """
    state_0 = State(model_name_0) if state_0 is None else state_0
    state_1 = State(model_name_1) if state_1 is None else state_1

    idx = random.randint(0, PROMPT_NUM - 1)
    prompt = prompt_list[idx]

    for st in (state_0, state_1):
        st.i2s_mode = False
        st.offline = True
        st.offline_idx = idx
        st.prompt = prompt

    return state_0, state_1, prompt
def sample_image(state, model_name):
    """Pick a random image URL on ``RGBA_SERVER`` and record it on ``state``.

    A new ``State`` is created when ``state`` is None. The state is switched
    to image mode (``i2s_mode`` True), marked offline, and given the index
    used to build the URL ``{RGBA_SERVER}/{idx}.png``.

    Returns:
        ``(state, img_url)``.
    """
    if state is None:
        state = State(model_name)

    idx = random.randint(0, PROMPT_NUM - 1)
    img_url = f"{RGBA_SERVER}/{idx}.png"

    state.model_name = model_name
    state.image = img_url
    state.i2s_mode = True
    # Fixed: the original had a trailing comma here, which stored the tuple
    # (True,) instead of the boolean True (and was logged as [true]).
    state.offline = True
    state.offline_idx = idx
    return state, img_url
def sample_image_side_by_side(state_0, state_1, model_name_0, model_name_1):
    """Pick one random image URL and record it on both states.

    Missing states are created from the matching model name; existing
    states keep whatever ``model_name`` they already have. Both states are
    switched to image mode, marked offline, and share the same index.

    Returns:
        ``(state_0, state_1, img_url)``.
    """
    state_0 = State(model_name_0) if state_0 is None else state_0
    state_1 = State(model_name_1) if state_1 is None else state_1

    idx = random.randint(0, PROMPT_NUM - 1)
    img_url = f"{RGBA_SERVER}/{idx}.png"

    for st in (state_0, state_1):
        st.i2s_mode = True
        st.offline = True
        st.offline_idx = idx
        st.image = img_url

    return state_0, state_1, img_url
def generate_t2s(gen_func, render_func,
                 state,
                 text,
                 model_name,
                 request: gr.Request):
    """Produce shape videos for one model from a text prompt.

    This is a generator: it yields the UI outputs once and, when resumed,
    appends a JSON record to the conversation log file and forwards the
    same record to the log server.

    Args:
        gen_func: Called as ``gen_func(text, model_name, offline=...,
            offline_idx=...)`` when the prompt is found in ``prompt_list``
            (the result is indexed by 'normal', 'rgb' and 'geo'); otherwise
            called as ``gen_func(text, model_name)`` and its result is
            passed to ``render_func``.
        render_func: Called as ``render_func(shape, model_name)`` on the
            online path only.
        state: Existing ``State`` or None (a new one is created).
        text: The prompt; surrounding whitespace is stripped.
        model_name: Name of the model to use.
        request: Gradio request, used to look up the client IP.

    Yields:
        ``(state, geo_video, normal_video, rgb_video)``.
    """
    if not text or text.strip() == "":
        raise gr.Warning("Prompt cannot be empty.")
    if not model_name:
        raise gr.Warning("Model name cannot be empty.")
    if state is None:
        state = State(model_name, i2s_mode=False, offline=False)

    text = text.strip()
    ip = get_ip(request)
    t2s_logger.info(f"generate. ip: {ip}")

    state.model_name = model_name
    state.prompt = text
    state.evaluted_dims = 0

    # list.index raises ValueError when the prompt is not one of the known
    # prompts; only that case means "online".
    try:
        state.offline_idx = prompt_list.index(text)
        state.offline = True
    except ValueError:
        state.offline = False
        state.offline_idx = None

    # Compare against None: index 0 is a valid prompt but is falsy.
    if state.offline and state.offline_idx is not None:
        start_time = time.time()
        videos = gen_func(text, model_name, offline=state.offline, offline_idx=state.offline_idx)
        state.normal_video = videos['normal']
        state.rgb_video = videos['rgb']
        state.geo_video = videos['geo']
        yield state, videos['geo'], videos['normal'], videos['rgb']

        finish_tstamp = time.time()
        data = {
            "tstamp": round(finish_tstamp, 4),
            "model": model_name,
            "type": "offline",
            "gen_params": {},
            "state": state.dict(),
            "start": round(start_time, 4),
            "finish": round(finish_tstamp, 4),
            "ip": ip,
        }
    else:
        start_time = time.time()
        shape = gen_func(text, model_name)
        generate_time = time.time() - start_time
        videos = render_func(shape, model_name)
        finish_time = time.time()
        render_time = finish_time - start_time - generate_time

        state.normal_video = videos['normal']
        state.rgb_video = videos['rgb']
        # NOTE(review): the original online branch yielded only three values
        # while the offline branch yields four. Aligned to four here; the
        # geo entry is None if render_func does not provide one. Confirm
        # against the UI's output wiring.
        state.geo_video = videos.get('geo')
        yield state, state.geo_video, videos['normal'], videos['rgb']

        finish_tstamp = time.time()
        data = {
            "tstamp": round(finish_tstamp, 4),
            "model": model_name,
            "type": "online",
            "gen_params": {},
            "state": state.dict(),
            "start": round(start_time, 4),
            "finish": round(finish_tstamp, 4),
            "time": round(finish_time - start_time, 4),
            "generate_time": round(generate_time, 4),
            "render_time": round(render_time, 4),
            "ip": ip,
        }

    log_path = get_conv_log_filename()
    with open(log_path, "a") as fout:
        fout.write(json.dumps(data) + "\n")
    append_json_item_on_log_server(data, log_path)
| # output_file = f'{IMAGE_DIR}/text2shape/{state.conv_id}.png' | |
| # os.makedirs(os.path.dirname(output_file), exist_ok=True) | |
| # with open(output_file, 'w') as f: | |
| # state.output.save(f, 'PNG') | |
| # save_image_file_on_log_server(output_file) | |
def generate_t2s_multi(gen_func, render_func,
                       state_0, state_1,
                       text,
                       model_name_0, model_name_1,
                       request: gr.Request):
    """Produce shape videos for two named models from one text prompt.

    This is a generator: it yields the UI outputs once and, when resumed,
    appends one JSON record per model to the conversation log file and
    forwards both records to the log server.

    Args:
        gen_func: Called as ``gen_func(text, model_name_0, model_name_1,
            offline=..., offline_idx=...)`` when the prompt is found in
            ``prompt_list`` (returns two video mappings); otherwise called
            as ``gen_func(text, model_name_0, model_name_1)`` and returns
            two shapes for ``render_func``.
        render_func: Called as ``render_func(shape_0, model_name_0,
            shape_1, model_name_1)`` on the online path only.
        state_0, state_1: Existing ``State`` objects or None.
        text: The prompt; surrounding whitespace is stripped.
        model_name_0, model_name_1: Names of the two models.
        request: Gradio request, used to look up the client IP.

    Yields:
        ``(state_0, state_1, geo_0, normal_0, rgb_0, geo_1, normal_1,
        rgb_1)``.
    """
    if not text or text.strip() == "":
        raise gr.Warning("Prompt cannot be empty.")
    if not model_name_0:
        raise gr.Warning("Model name A cannot be empty.")
    if not model_name_1:
        raise gr.Warning("Model name B cannot be empty.")
    if state_0 is None:
        state_0 = State(model_name_0, i2s_mode=False, offline=False)
    if state_1 is None:
        state_1 = State(model_name_1, i2s_mode=False, offline=False)

    text = text.strip()
    ip = get_ip(request)
    t2s_multi_logger.info(f"generate. ip: {ip}")

    state_0.model_name, state_1.model_name = model_name_0, model_name_1
    state_0.prompt, state_1.prompt = text, text
    state_0.evaluted_dims, state_1.evaluted_dims = 0, 0

    # list.index raises ValueError when the prompt is not a known prompt.
    try:
        idx = prompt_list.index(text)
    except ValueError:
        state_0.offline, state_1.offline = False, False
        state_0.offline_idx, state_1.offline_idx = None, None
    else:
        state_0.offline, state_1.offline = True, True
        state_0.offline_idx, state_1.offline_idx = idx, idx

    # Compare against None: index 0 is a valid prompt but is falsy.
    if state_0.offline and state_0.offline_idx is not None:
        start_time = time.time()
        videos_0, videos_1 = gen_func(text, model_name_0, model_name_1,
                                      offline=state_0.offline, offline_idx=state_0.offline_idx)
        state_0.normal_video, state_0.rgb_video, state_0.geo_video = videos_0['normal'], videos_0['rgb'], videos_0['geo']
        state_1.normal_video, state_1.rgb_video, state_1.geo_video = videos_1['normal'], videos_1['rgb'], videos_1['geo']
        yield state_0, state_1, \
            videos_0['geo'], videos_0['normal'], videos_0['rgb'], \
            videos_1['geo'], videos_1['normal'], videos_1['rgb']

        finish_tstamp = time.time()
        records = [
            {
                "tstamp": round(finish_tstamp, 4),
                "model": name,
                "type": "offline",
                "gen_params": {},
                "state": st.dict(),
                "start": round(start_time, 4),
                "finish": round(finish_tstamp, 4),
                "ip": ip,
            }
            for name, st in ((model_name_0, state_0), (model_name_1, state_1))
        ]
    else:
        start_time = time.time()
        shape_0, shape_1 = gen_func(text, model_name_0, model_name_1)
        generate_time = time.time() - start_time
        videos_0, videos_1 = render_func(shape_0, model_name_0, shape_1, model_name_1)
        finish_time = time.time()
        render_time = finish_time - start_time - generate_time

        state_0.normal_video, state_0.rgb_video, state_0.geo_video = videos_0['normal'], videos_0['rgb'], videos_0['geo']
        state_1.normal_video, state_1.rgb_video, state_1.geo_video = videos_1['normal'], videos_1['rgb'], videos_1['geo']
        yield state_0, state_1, \
            videos_0['geo'], videos_0['normal'], videos_0['rgb'], \
            videos_1['geo'], videos_1['normal'], videos_1['rgb']

        finish_tstamp = time.time()
        records = [
            {
                "tstamp": round(finish_tstamp, 4),
                "model": name,
                "type": "online",
                "gen_params": {},
                "state": st.dict(),
                "start": round(start_time, 4),
                "finish": round(finish_tstamp, 4),
                "time": round(finish_time - start_time, 4),
                "generate_time": round(generate_time, 4),
                "render_time": round(render_time, 4),
                "ip": ip,
            }
            for name, st in ((model_name_0, state_0), (model_name_1, state_1))
        ]

    log_path = get_conv_log_filename()
    with open(log_path, "a") as fout:
        for record in records:
            fout.write(json.dumps(record) + "\n")
    for record in records:
        append_json_item_on_log_server(record, log_path)
| # for i, state in enumerate([state_0, state_1]): | |
| # output_file = f'{IMAGE_DIR}/text2shape/{state.conv_id}.png' | |
| # os.makedirs(os.path.dirname(output_file), exist_ok=True) | |
| # with open(output_file, 'w') as f: | |
| # state.output.save(f, 'PNG') | |
| # save_image_file_on_log_server(output_file) | |
def generate_t2s_multi_annoy(gen_func, render_func,
                             state_0, state_1,
                             text,
                             model_name_0, model_name_1,
                             request: gr.Request):
    """Produce shape videos for two models from one text prompt, revealing
    the model names alongside the results.

    Unlike ``generate_t2s_multi``, the model names are not validated up
    front: ``gen_func`` returns the names actually used, and those replace
    ``model_name_0`` / ``model_name_1`` in the states, the UI labels and
    the log records.

    This is a generator: it yields the UI outputs once and, when resumed,
    appends one JSON record per model to the conversation log file and
    forwards both records to the log server.

    Args:
        gen_func: Called with ``i2s_model=False``. On the offline path it
            also receives ``offline``/``offline_idx`` and returns
            ``(videos_0, videos_1, name_0, name_1)``; on the online path it
            returns ``(shape_0, shape_1, name_0, name_1)``.
        render_func: Called as ``render_func(shape_0, model_name_0,
            shape_1, model_name_1)`` on the online path only.
        state_0, state_1: Existing ``State`` objects or None.
        text: The prompt; surrounding whitespace is stripped.
        model_name_0, model_name_1: Initial model names passed to
            ``gen_func``.
        request: Gradio request, used to look up the client IP.

    Yields:
        ``(state_0, state_1, geo_0, normal_0, rgb_0, geo_1, normal_1,
        rgb_1, markdown_a, markdown_b)``.
    """
    if not text or text.strip() == "":
        raise gr.Warning("Prompt cannot be empty.")
    if state_0 is None:
        state_0 = State(model_name_0, i2s_mode=False, offline=False)
    if state_1 is None:
        state_1 = State(model_name_1, i2s_mode=False, offline=False)

    text = text.strip()
    ip = get_ip(request)
    t2s_multi_logger.info(f"generate. ip: {ip}")

    state_0.model_name, state_1.model_name = model_name_0, model_name_1
    state_0.prompt, state_1.prompt = text, text
    state_0.evaluted_dims, state_1.evaluted_dims = 0, 0

    # list.index raises ValueError when the prompt is not a known prompt.
    try:
        idx = prompt_list.index(text)
    except ValueError:
        state_0.offline, state_1.offline = False, False
        state_0.offline_idx, state_1.offline_idx = None, None
    else:
        state_0.offline, state_1.offline = True, True
        state_0.offline_idx, state_1.offline_idx = idx, idx

    # Compare against None: index 0 is a valid prompt but is falsy.
    if state_0.offline and state_0.offline_idx is not None:
        start_time = time.time()
        videos_0, videos_1, model_name_0, model_name_1 = gen_func(
            text, model_name_0, model_name_1,
            i2s_model=False, offline=state_0.offline, offline_idx=state_0.offline_idx)
        state_0.model_name, state_1.model_name = model_name_0, model_name_1
        state_0.normal_video, state_0.rgb_video, state_0.geo_video = videos_0['normal'], videos_0['rgb'], videos_0['geo']
        state_1.normal_video, state_1.rgb_video, state_1.geo_video = videos_1['normal'], videos_1['rgb'], videos_1['geo']
        yield state_0, state_1, \
            videos_0['geo'], videos_0['normal'], videos_0['rgb'], \
            videos_1['geo'], videos_1['normal'], videos_1['rgb'], \
            gr.Markdown(f"### Model A: {model_name_0}"), gr.Markdown(f"### Model B: {model_name_1}")

        finish_tstamp = time.time()
        records = [
            {
                "tstamp": round(finish_tstamp, 4),
                "model": name,
                "type": "offline",
                "gen_params": {},
                "state": st.dict(),
                "start": round(start_time, 4),
                "finish": round(finish_tstamp, 4),
                "ip": ip,
            }
            for name, st in ((model_name_0, state_0), (model_name_1, state_1))
        ]
    else:
        start_time = time.time()
        shape_0, shape_1, model_name_0, model_name_1 = gen_func(
            text, model_name_0, model_name_1, i2s_model=False)
        generate_time = time.time() - start_time
        videos_0, videos_1 = render_func(shape_0, model_name_0, shape_1, model_name_1)
        finish_time = time.time()
        render_time = finish_time - start_time - generate_time

        state_0.model_name, state_1.model_name = model_name_0, model_name_1
        state_0.normal_video, state_0.rgb_video, state_0.geo_video = videos_0['normal'], videos_0['rgb'], videos_0['geo']
        state_1.normal_video, state_1.rgb_video, state_1.geo_video = videos_1['normal'], videos_1['rgb'], videos_1['geo']
        yield state_0, state_1, \
            videos_0['geo'], videos_0['normal'], videos_0['rgb'], \
            videos_1['geo'], videos_1['normal'], videos_1['rgb'], \
            gr.Markdown(f"### Model A: {model_name_0}"), gr.Markdown(f"### Model B: {model_name_1}")

        finish_tstamp = time.time()
        records = [
            {
                "tstamp": round(finish_tstamp, 4),
                "model": name,
                "type": "online",
                "gen_params": {},
                "state": st.dict(),
                "start": round(start_time, 4),
                "finish": round(finish_tstamp, 4),
                "time": round(finish_time - start_time, 4),
                "generate_time": round(generate_time, 4),
                "render_time": round(render_time, 4),
                "ip": ip,
            }
            for name, st in ((model_name_0, state_0), (model_name_1, state_1))
        ]

    log_path = get_conv_log_filename()
    with open(log_path, "a") as fout:
        for record in records:
            fout.write(json.dumps(record) + "\n")
    for record in records:
        append_json_item_on_log_server(record, log_path)
| # for i, state in enumerate([state_0, state_1]): | |
| # output_file = f'{IMAGE_DIR}/text2shape/{state.conv_id}.png' | |
| # os.makedirs(os.path.dirname(output_file), exist_ok=True) | |
| # with open(output_file, 'w') as f: | |
| # state.output.save(f, 'PNG') | |
| # save_image_file_on_log_server(output_file) | |
def generate_i2s(gen_func, render_func, state, image, model_name, request: gr.Request):
    """Produce shape videos for one model from an input image.

    The offline path is taken when the incoming ``state`` is already marked
    offline and carries an ``offline_idx``; this function does not set
    those fields itself.

    This is a generator: it yields the UI outputs once and, when resumed,
    appends a JSON record to the conversation log file and forwards the
    same record to the log server.

    Args:
        gen_func: Called as ``gen_func(image, model_name, offline=...,
            offline_idx=...)`` on the offline path (the result is indexed by
            'normal', 'rgb' and 'geo'); otherwise called as
            ``gen_func(image, model_name)`` and its result is passed to
            ``render_func``.
        render_func: Called as ``render_func(shape, model_name)`` on the
            online path only.
        state: Existing ``State`` or None (a new one is created).
        image: The input image; must not be None.
        model_name: Name of the model to use.
        request: Gradio request, used to look up the client IP.

    Yields:
        ``(state, geo_video, normal_video, rgb_video)``.
    """
    if image is None:
        raise gr.Warning("Image cannot be empty.")
    if not model_name:
        raise gr.Warning("Model name cannot be empty.")
    if state is None:
        state = State(model_name, i2s_mode=True, offline=False)

    ip = get_ip(request)
    i2s_logger.info(f"generate. ip: {ip}")

    state.model_name = model_name
    state.image = image
    state.evaluted_dims = 0

    # Compare against None: index 0 is a valid sample but is falsy.
    if state.offline and state.offline_idx is not None:
        start_time = time.time()
        videos = gen_func(image, model_name, offline=state.offline, offline_idx=state.offline_idx)
        state.normal_video = videos['normal']
        state.rgb_video = videos['rgb']
        state.geo_video = videos['geo']
        yield state, videos['geo'], videos['normal'], videos['rgb']

        finish_tstamp = time.time()
        data = {
            "tstamp": round(finish_tstamp, 4),
            "model": model_name,
            "type": "offline",
            "gen_params": {},
            "state": state.dict(),
            "start": round(start_time, 4),
            "finish": round(finish_tstamp, 4),
            "ip": ip,
        }
    else:
        start_time = time.time()
        shape = gen_func(image, model_name)
        generate_time = time.time() - start_time
        videos = render_func(shape, model_name)
        finish_time = time.time()
        render_time = finish_time - start_time - generate_time

        state.normal_video = videos['normal']
        state.rgb_video = videos['rgb']
        # NOTE(review): the original online branch yielded only three values
        # while the offline branch yields four. Aligned to four here; the
        # geo entry is None if render_func does not provide one. Confirm
        # against the UI's output wiring.
        state.geo_video = videos.get('geo')
        yield state, state.geo_video, videos['normal'], videos['rgb']

        finish_tstamp = time.time()
        data = {
            "tstamp": round(finish_tstamp, 4),
            "model": model_name,
            "type": "online",
            "gen_params": {},
            "state": state.dict(),
            "start": round(start_time, 4),
            "finish": round(finish_tstamp, 4),
            "time": round(finish_time - start_time, 4),
            "generate_time": round(generate_time, 4),
            "render_time": round(render_time, 4),
            "ip": ip,
        }

    log_path = get_conv_log_filename()
    with open(log_path, "a") as fout:
        fout.write(json.dumps(data) + "\n")
    append_json_item_on_log_server(data, log_path)
| # src_img_file = f'{IMAGE_DIR}/image2shape/{state.conv_id}_src.png' | |
| # os.makedirs(os.path.dirname(src_img_file), exist_ok=True) | |
| # with open(src_img_file, 'w') as f: | |
| # state.source_image.save(f, 'PNG') | |
| # output_file = f'{IMAGE_DIR}/image2shape/{state.conv_id}_out.png' | |
| # with open(output_file, 'w') as f: | |
| # state.output.save(f, 'PNG') | |
| # save_image_file_on_log_server(src_img_file) | |
| # save_image_file_on_log_server(output_file) | |
def generate_i2s_multi(gen_func, render_func,
                       state_0, state_1,
                       image,
                       model_name_0, model_name_1,
                       request: gr.Request):
    """Produce shape videos for two named models from one input image.

    The offline path is taken when the incoming ``state_0`` is already
    marked offline and carries an ``offline_idx``; this function does not
    set those fields itself.

    This is a generator: it yields the UI outputs once and, when resumed,
    appends one JSON record per model to the conversation log file and
    forwards both records to the log server.

    Args:
        gen_func: Called as ``gen_func(image, model_name_0, model_name_1,
            offline=..., offline_idx=...)`` on the offline path (returns two
            video mappings); otherwise called as ``gen_func(image,
            model_name_0, model_name_1)`` and returns two shapes for
            ``render_func``.
        render_func: Called as ``render_func(shape_0, model_name_0,
            shape_1, model_name_1)`` on the online path only.
        state_0, state_1: Existing ``State`` objects or None.
        image: The input image; must not be None.
        model_name_0, model_name_1: Names of the two models.
        request: Gradio request, used to look up the client IP.

    Yields:
        ``(state_0, state_1, geo_0, normal_0, rgb_0, geo_1, normal_1,
        rgb_1, markdown_a, markdown_b)``.
    """
    if image is None:
        raise gr.Warning("Image cannot be empty.")
    if not model_name_0:
        raise gr.Warning("Model name A cannot be empty.")
    if not model_name_1:
        raise gr.Warning("Model name B cannot be empty.")
    if state_0 is None:
        state_0 = State(model_name_0, i2s_mode=True, offline=False)
    if state_1 is None:
        state_1 = State(model_name_1, i2s_mode=True, offline=False)

    ip = get_ip(request)
    i2s_multi_logger.info(f"generate. ip: {ip}")

    state_0.model_name, state_1.model_name = model_name_0, model_name_1
    state_0.image, state_1.image = image, image
    state_0.evaluted_dims, state_1.evaluted_dims = 0, 0

    # Compare against None: index 0 is a valid sample but is falsy.
    if state_0.offline and state_0.offline_idx is not None:
        start_time = time.time()
        videos_0, videos_1 = gen_func(image, model_name_0, model_name_1,
                                      offline=state_0.offline, offline_idx=state_0.offline_idx)
        state_0.normal_video, state_0.rgb_video, state_0.geo_video = videos_0['normal'], videos_0['rgb'], videos_0['geo']
        state_1.normal_video, state_1.rgb_video, state_1.geo_video = videos_1['normal'], videos_1['rgb'], videos_1['geo']
        yield state_0, state_1, \
            videos_0['geo'], videos_0['normal'], videos_0['rgb'], \
            videos_1['geo'], videos_1['normal'], videos_1['rgb'], \
            gr.Markdown(f"### Model A: {model_name_0}"), gr.Markdown(f"### Model B: {model_name_1}")

        finish_tstamp = time.time()
        records = [
            {
                "tstamp": round(finish_tstamp, 4),
                "model": name,
                "type": "offline",
                "gen_params": {},
                "state": st.dict(),
                "start": round(start_time, 4),
                "finish": round(finish_tstamp, 4),
                "ip": ip,
            }
            for name, st in ((model_name_0, state_0), (model_name_1, state_1))
        ]
    else:
        start_time = time.time()
        shape_0, shape_1 = gen_func(image, model_name_0, model_name_1)
        generate_time = time.time() - start_time
        videos_0, videos_1 = render_func(shape_0, model_name_0, shape_1, model_name_1)
        finish_time = time.time()
        render_time = finish_time - start_time - generate_time

        state_0.normal_video, state_0.rgb_video, state_0.geo_video = videos_0['normal'], videos_0['rgb'], videos_0['geo']
        state_1.normal_video, state_1.rgb_video, state_1.geo_video = videos_1['normal'], videos_1['rgb'], videos_1['geo']
        # NOTE(review): the original online branch yielded 8 values while
        # the offline branch yields 10 (with the two Markdown labels). The
        # two branches must agree; aligned to the offline shape here.
        # Confirm against the UI's output wiring.
        yield state_0, state_1, \
            videos_0['geo'], videos_0['normal'], videos_0['rgb'], \
            videos_1['geo'], videos_1['normal'], videos_1['rgb'], \
            gr.Markdown(f"### Model A: {model_name_0}"), gr.Markdown(f"### Model B: {model_name_1}")

        finish_tstamp = time.time()
        records = [
            {
                "tstamp": round(finish_tstamp, 4),
                "model": name,
                "type": "online",
                "gen_params": {},
                "state": st.dict(),
                "start": round(start_time, 4),
                "finish": round(finish_tstamp, 4),
                "time": round(finish_time - start_time, 4),
                "generate_time": round(generate_time, 4),
                "render_time": round(render_time, 4),
                "ip": ip,
            }
            for name, st in ((model_name_0, state_0), (model_name_1, state_1))
        ]

    log_path = get_conv_log_filename()
    with open(log_path, "a") as fout:
        for record in records:
            fout.write(json.dumps(record) + "\n")
    for record in records:
        append_json_item_on_log_server(record, log_path)
| # for i, state in enumerate([state_0, state_1]): | |
| # src_img_file = f'{IMAGE_DIR}/image2shape/{state.conv_id}_src.png' | |
| # os.makedirs(os.path.dirname(src_img_file), exist_ok=True) | |
| # with open(src_img_file, 'w') as f: | |
| # state.source_image.save(f, 'PNG') | |
| # output_file = f'{IMAGE_DIR}/image2shape/{state.conv_id}_out.png' | |
| # with open(output_file, 'w') as f: | |
| # state.output.save(f, 'PNG') | |
| # save_image_file_on_log_server(src_img_file) | |
| # save_image_file_on_log_server(output_file) | |
def generate_i2s_multi_annoy(gen_func, render_func,
                             state_0, state_1,
                             image,
                             model_name_0, model_name_1,
                             request: gr.Request):
    """Produce shape videos for two models from one input image, revealing
    the model names alongside the results.

    The offline path is taken when both incoming states are already marked
    offline and carry an ``offline_idx``; this function does not set those
    fields itself. Model names are not validated up front: when
    ``gen_func`` returns the names actually used, they replace
    ``model_name_0`` / ``model_name_1`` in the states, the UI labels and
    the log records.

    This is a generator: it yields the UI outputs once and, when resumed,
    appends one JSON record per model to the conversation log file and
    forwards both records to the log server.

    Args:
        gen_func: Called with ``i2s_model=True``. On the offline path it
            also receives ``offline``/``offline_idx`` and returns
            ``(videos_0, videos_1, name_0, name_1)``; on the online path it
            returns ``(shape_0, shape_1)``, optionally followed by the two
            model names.
        render_func: Called as ``render_func(shape_0, model_name_0,
            shape_1, model_name_1)`` on the online path only.
        state_0, state_1: Existing ``State`` objects or None.
        image: The input image; must not be None.
        model_name_0, model_name_1: Initial model names passed to
            ``gen_func``.
        request: Gradio request, used to look up the client IP.

    Yields:
        ``(state_0, state_1, geo_0, normal_0, rgb_0, geo_1, normal_1,
        rgb_1, markdown_a, markdown_b)``.
    """
    if image is None:
        raise gr.Warning("Image cannot be empty.")
    if state_0 is None:
        state_0 = State(model_name_0, i2s_mode=True, offline=False)
    if state_1 is None:
        state_1 = State(model_name_1, i2s_mode=True, offline=False)

    ip = get_ip(request)
    i2s_multi_logger.info(f"generate. ip: {ip}")

    state_0.model_name, state_1.model_name = model_name_0, model_name_1
    state_0.image, state_1.image = image, image
    state_0.evaluted_dims, state_1.evaluted_dims = 0, 0

    # Compare against None: index 0 is a valid sample but is falsy.
    both_offline = (
        state_0.offline and state_0.offline_idx is not None
        and state_1.offline and state_1.offline_idx is not None
    )
    if both_offline:
        start_time = time.time()
        videos_0, videos_1, model_name_0, model_name_1 = gen_func(
            image, model_name_0, model_name_1,
            i2s_model=True, offline=state_0.offline, offline_idx=state_0.offline_idx)
        state_0.model_name, state_1.model_name = model_name_0, model_name_1
        state_0.normal_video, state_0.rgb_video, state_0.geo_video = videos_0['normal'], videos_0['rgb'], videos_0['geo']
        state_1.normal_video, state_1.rgb_video, state_1.geo_video = videos_1['normal'], videos_1['rgb'], videos_1['geo']
        yield state_0, state_1, \
            videos_0['geo'], videos_0['normal'], videos_0['rgb'], \
            videos_1['geo'], videos_1['normal'], videos_1['rgb'], \
            gr.Markdown(f"### Model A: {model_name_0}"), gr.Markdown(f"### Model B: {model_name_1}")

        finish_tstamp = time.time()
        records = [
            {
                "tstamp": round(finish_tstamp, 4),
                "model": name,
                "type": "offline",
                "gen_params": {},
                "state": st.dict(),
                "start": round(start_time, 4),
                "finish": round(finish_tstamp, 4),
                "ip": ip,
            }
            for name, st in ((model_name_0, state_0), (model_name_1, state_1))
        ]
    else:
        start_time = time.time()
        # NOTE(review): the text-to-shape counterpart unpacks four values
        # (shapes plus the resolved model names) from the same call shape,
        # while the original code here unpacked only two. Accept both forms
        # so neither return convention raises an unpacking error.
        shape_0, shape_1, *resolved_names = gen_func(
            image, model_name_0, model_name_1, i2s_model=True)
        if resolved_names:
            model_name_0, model_name_1 = resolved_names
        generate_time = time.time() - start_time
        videos_0, videos_1 = render_func(shape_0, model_name_0, shape_1, model_name_1)
        finish_time = time.time()
        render_time = finish_time - start_time - generate_time

        state_0.model_name, state_1.model_name = model_name_0, model_name_1
        state_0.normal_video, state_0.rgb_video, state_0.geo_video = videos_0['normal'], videos_0['rgb'], videos_0['geo']
        state_1.normal_video, state_1.rgb_video, state_1.geo_video = videos_1['normal'], videos_1['rgb'], videos_1['geo']
        yield state_0, state_1, \
            videos_0['geo'], videos_0['normal'], videos_0['rgb'], \
            videos_1['geo'], videos_1['normal'], videos_1['rgb'], \
            gr.Markdown(f"### Model A: {model_name_0}"), gr.Markdown(f"### Model B: {model_name_1}")

        finish_tstamp = time.time()
        records = [
            {
                "tstamp": round(finish_tstamp, 4),
                "model": name,
                "type": "online",
                "gen_params": {},
                "state": st.dict(),
                "start": round(start_time, 4),
                "finish": round(finish_tstamp, 4),
                "time": round(finish_time - start_time, 4),
                "generate_time": round(generate_time, 4),
                "render_time": round(render_time, 4),
                "ip": ip,
            }
            for name, st in ((model_name_0, state_0), (model_name_1, state_1))
        ]

    log_path = get_conv_log_filename()
    with open(log_path, "a") as fout:
        for record in records:
            fout.write(json.dumps(record) + "\n")
    for record in records:
        append_json_item_on_log_server(record, log_path)
| # for i, state in enumerate([state_0, state_1]): | |
| # src_img_file = f'{IMAGE_DIR}/image2shape/{state.conv_id}_src.png' | |
| # os.makedirs(os.path.dirname(src_img_file), exist_ok=True) | |
| # with open(src_img_file, 'w') as f: | |
| # state.source_image.save(f, 'PNG') | |
| # output_file = f'{IMAGE_DIR}/image2shape/{state.conv_id}_out.png' | |
| # with open(output_file, 'w') as f: | |
| # state.output.save(f, 'PNG') | |
| # save_image_file_on_log_server(src_img_file) | |
| # save_image_file_on_log_server(output_file) |