# NOTE(review): removed non-code scraping artifacts ("Spaces:", "Runtime error")
# that preceded the source in the exported copy.
import argparse
import ast
import json
import os
import sys
import traceback

import cairosvg
from dotenv import load_dotenv
from openai import OpenAI
from PIL import Image

import utils
from prompts import sketch_first_prompt, system_prompt, gt_example
def call_argparse(argv=None):
    """Parse CLI arguments and prepare the output directory.

    Args:
        argv: Optional list of argument strings. Defaults to sys.argv[1:]
            (argparse's own default); passing an explicit list makes the
            function testable and embeddable.

    Returns:
        argparse.Namespace with the parsed options plus derived fields:
        ``grid_size`` (canvas side in pixels), ``save_name`` (concept with
        spaces replaced by underscores) and ``path2save`` extended with the
        save name. The directory is created and a fresh, empty
        ``experiment_log.json`` is written into it.
    """
    parser = argparse.ArgumentParser(description='Process Arguments')

    # General
    parser.add_argument('--concept_to_draw', type=str, default="cat")
    parser.add_argument('--path2save', type=str, default="results/test")
    parser.add_argument('--temperature', type=float, default=0.3)
    parser.add_argument('--model', type=str, default='o3')
    parser.add_argument('--gen_mode', type=str, default='generation',
                        choices=['generation', 'completion'])

    # Grid params
    parser.add_argument('--res', type=int, default=50,
                        help="the resolution of the grid is set to 50x50")
    parser.add_argument('--cell_size', type=int, default=12,
                        help="size of each cell in the grid")
    parser.add_argument('--stroke_width', type=float, default=7.0)

    args = parser.parse_args(argv)

    # The canvas is (res + 1) cells wide/high: one extra cell of header
    # beyond the res x res drawing grid.
    args.grid_size = (args.res + 1) * args.cell_size
    args.save_name = args.concept_to_draw.replace(" ", "_")
    args.path2save = f"{args.path2save}/{args.save_name}"

    # exist_ok avoids the check-then-create race of the original
    # os.path.exists / os.makedirs pair.
    os.makedirs(args.path2save, exist_ok=True)

    # Start each run with an empty experiment log (overwrites a stale one).
    with open(f"{args.path2save}/experiment_log.json", 'w') as json_file:
        json.dump([], json_file, indent=4)
    return args
class SketchApp:
    """Generate a sketch of a concept by asking an LLM for stroke commands.

    The model answers with an XML-ish stroke description (terminated by
    ``</answer>``), which is parsed into control points, rendered to SVG via
    ``utils``, rasterized with cairosvg, and pasted onto the grid canvas.
    """

    def __init__(self, args):
        # General
        self.path2save = args.path2save
        self.target_concept = args.concept_to_draw

        # Grid related
        self.res = args.res
        self.num_cells = args.res
        self.cell_size = args.cell_size
        self.grid_size = (args.grid_size, args.grid_size)
        self.init_canvas, self.positions = utils.create_grid_image(
            res=args.res, cell_size=args.cell_size, header_size=args.cell_size)
        self.init_canvas_str = utils.image_to_str(self.init_canvas)
        self.cells_to_pixels_map = utils.cells_to_pixels(
            args.res, args.cell_size, header_size=args.cell_size)

        # SVG related
        self.stroke_width = args.stroke_width

        # LLM setup (provide OPENAI_API_KEY in your .env file).
        # Fix: load_dotenv() was imported but never called, so a key defined
        # only in .env was invisible to os.getenv below.
        load_dotenv()
        self.max_tokens = 3000
        openai_key = os.getenv("OPENAI_API_KEY")
        self.client = OpenAI(api_key=openai_key)
        # NOTE(review): args.model (default 'o3') is parsed but deliberately
        # not used here; the model is hard-coded. Confirm whether --model
        # should take precedence before changing this.
        self.model = "gpt-4o"
        self.input_prompt = sketch_first_prompt.format(
            concept=args.concept_to_draw, gt_sketches_str=gt_example)
        self.gen_mode = args.gen_mode
        self.temperature = args.temperature

    def call_llm(self, system_message, other_msg, additional_args):
        """Send one chat-completion request and return the raw text reply.

        Args:
            system_message: system prompt string.
            other_msg: list of non-system chat messages.
            additional_args: dict that may carry a "stop" sequence.
        """
        response = self.client.chat.completions.create(
            model=self.model,
            messages=[{"role": "system", "content": system_message}] + other_msg,
            max_tokens=self.max_tokens,
            temperature=self.temperature,
            stop=additional_args.get("stop", None),
        )
        return response.choices[0].message.content

    def define_input_to_llm(self, msg_history, init_canvas_str, msg):
        """Append a new user turn (optional canvas image + text) to history.

        Args:
            msg_history: prior chat messages (not mutated).
            init_canvas_str: optional base64 JPEG of the canvas, or None.
            msg: text of the new user message.

        Returns:
            New message list: history plus one user message.
        """
        content = []
        if init_canvas_str is not None:
            content.append({
                "type": "image_url",
                "image_url": "data:image/jpeg;base64," + init_canvas_str,
            })
        content.append({"type": "text", "text": msg})
        return msg_history + [{"role": "user", "content": content}]

    def get_response_from_llm(
        self,
        msg,
        system_message,
        msg_history=None,
        init_canvas_str=None,
        prefill_msg=None,
        stop=None,
        gen_mode="generation",
    ):
        """Query the LLM and log the exchange to experiment_log.json.

        Args:
            msg: user message text.
            system_message: system prompt.
            msg_history: prior messages; None means empty. (Fix: the original
                used a mutable default ``[]`` shared across calls.)
            init_canvas_str: optional base64 canvas image to attach.
            prefill_msg: assistant prefix for "completion" mode; the reply is
                prepended with it so callers see the full text.
            stop: stop sequence(s); defaults to ["</answer>"].
            gen_mode: "generation" or "completion".

        Returns:
            The assistant's text content (prefill included in completion mode).
        """
        if msg_history is None:
            msg_history = []
        additional_args = {}

        # other_msg holds every message except the system prompt.
        other_msg = self.define_input_to_llm(msg_history, init_canvas_str, msg)

        # Fix: the original stripped the last message and prepended the
        # literal string "None" when gen_mode == "completion" but no
        # prefill_msg was given; prefill handling is now fully guarded.
        prefilled = gen_mode == "completion" and bool(prefill_msg)
        if prefilled:
            other_msg = other_msg + [{"role": "assistant", "content": f"{prefill_msg}"}]

        # In case of stroke-by-stroke generation callers pass their own stop.
        additional_args["stop"] = stop if stop else ["</answer>"]

        content = self.call_llm(system_message, other_msg, additional_args)
        if prefilled:
            other_msg = other_msg[:-1]  # remove initial assistant prompt
            content = f"{prefill_msg}{content}"

        # Persist the full exchange for later inspection.
        if self.path2save is not None:
            system_message_json = [{"role": "system", "content": system_message}]
            new_msg_history = other_msg + [
                {
                    "role": "assistant",
                    "content": [
                        {
                            "type": "text",
                            "text": content,
                        }
                    ],
                }
            ]
            with open(f"{self.path2save}/experiment_log.json", 'w') as json_file:
                json.dump(system_message_json + new_msg_history, json_file, indent=4)
            print(f"Data has been saved to [{self.path2save}/experiment_log.json]")
        return content

    def call_model_for_sketch_generation(self):
        """Ask the LLM for a full sketch; returns text ending in </answer>."""
        print("Calling LLM...")
        add_args = {"stop": "</answer>"}
        msg_history = []
        init_canvas_str = None  # self.init_canvas_str
        all_llm_output = self.get_response_from_llm(
            msg=self.input_prompt,
            system_message=system_prompt.format(res=self.res),
            msg_history=msg_history,
            init_canvas_str=init_canvas_str,
            gen_mode=self.gen_mode,
            **add_args,
        )
        # The stop sequence is consumed by the API, so restore the closing tag
        # for the XML parser downstream.
        all_llm_output += "</answer>"
        return all_llm_output

    def parse_model_to_svg(self, model_rep_sketch):
        """Convert the model's textual sketch representation into SVG text."""
        # Parse the model output with the project's XML parser.
        strokes_list_str, t_values_str = utils.parse_xml_string(model_rep_sketch, self.res)
        strokes_list = ast.literal_eval(strokes_list_str)
        t_values = ast.literal_eval(t_values_str)
        # Extract control points from the sampled lists.
        all_control_points = utils.get_control_points(
            strokes_list, t_values, self.cells_to_pixels_map)
        # Define the SVG based on the control points.
        return utils.format_svg(
            all_control_points, dim=self.grid_size, stroke_width=self.stroke_width)

    def generate_sketch(self):
        """Full pipeline: query LLM, save SVG, rasterize, composite on canvas."""
        sketching_commands = self.call_model_for_sketch_generation()
        model_strokes_svg = self.parse_model_to_svg(sketching_commands)

        # Save the SVG sketch.
        with open(f"{self.path2save}/{self.target_concept}.svg", "w") as svg_file:
            svg_file.write(model_strokes_svg)

        # Vector -> pixel: sketch on a blank white background.
        cairosvg.svg2png(
            url=f"{self.path2save}/{self.target_concept}.svg",
            write_to=f"{self.path2save}/{self.target_concept}.png",
            background_color="white")

        # Sketch composited onto the grid canvas (transparent background so
        # the image doubles as its own paste mask).
        output_png_path = f"{self.path2save}/{self.target_concept}_canvas.png"
        cairosvg.svg2png(
            url=f"{self.path2save}/{self.target_concept}.svg",
            write_to=output_png_path)
        foreground = Image.open(output_png_path)
        # Fix: reuse the already-opened image instead of opening it twice.
        self.init_canvas.paste(foreground, (0, 0), foreground)
        self.init_canvas.save(output_png_path)
# Initialize and run the SketchApp
if __name__ == '__main__':
    args = call_argparse()
    sketch_app = SketchApp(args)
    # LLM output is stochastic and may fail to parse; retry a few times.
    for attempt in range(3):
        try:
            sketch_app.generate_sketch()
            break  # success -- stop retrying
        except Exception as e:
            print(f"An error has occurred: {e}")
            traceback.print_exc()
    else:
        # Fix: the original fell off the loop with exit status 0 even when
        # every attempt failed; report failure to the shell instead.
        sys.exit(1)