Spaces:
Runtime error
Runtime error
| import sys | |
| import warnings | |
| warnings.filterwarnings('ignore') | |
| warnings.simplefilter('ignore') | |
| import argparse | |
| import multiprocessing as mp | |
| import os | |
| import subprocess as sp | |
| from shutil import copyfile | |
| import numpy as np | |
| import torch | |
| from IPython.display import Image as Image_colab | |
| from IPython.display import display, SVG, clear_output | |
| from ipywidgets import IntSlider, Output, IntProgress, Button | |
| import time | |
# ---------------------------------------------------------------------------
# Command-line interface and run-wide settings.
#
# Everything defined here (args, target, output_dir, num_iter, save_interval,
# use_gpu, seeds, losses_all, ...) is module-level state read by the functions
# and the scheduling code further down in this file.
# ---------------------------------------------------------------------------
parser = argparse.ArgumentParser()
parser.add_argument("--target_file", type=str,
                    help="target image file, located in <target_images>")
parser.add_argument("--num_strokes", type=int, default=16,
                    help="number of strokes used to generate the sketch, this defines the level of abstraction.")
parser.add_argument("--num_iter", type=int, default=2001,
                    help="number of iterations")
parser.add_argument("--fix_scale", type=int, default=0,
                    help="if the target image is not squared, it is recommended to fix the scale")
parser.add_argument("--mask_object", type=int, default=0,
                    help="if the target image contains background, it's better to mask it out")
parser.add_argument("--num_sketches", type=int, default=3,
                    help="it is recommended to draw 3 sketches and automatically chose the best one")
parser.add_argument("--multiprocess", type=int, default=0,
                    help="recommended to use multiprocess if your computer has enough memory")
parser.add_argument('-colab', action='store_true')
parser.add_argument('-cpu', action='store_true')
parser.add_argument('-display', action='store_true')
parser.add_argument('--gpunum', type=int, default=0)
args = parser.parse_args()

# A worker pool is only used outside colab, when more than one sketch is
# requested, and when --multiprocess is non-zero.
multiprocess = not args.colab and args.num_sketches > 1 and args.multiprocess

abs_path = os.path.abspath(os.getcwd())

# The target image is looked up relative to the current working directory.
target = f"{abs_path}/target_images/{args.target_file}"
assert os.path.isfile(target), f"{target} does not exists!"

# Fetch the U2Net weights with gdown when they are not present locally.
if not os.path.isfile(f"{abs_path}/U2Net_/saved_models/u2net.pth"):
    sp.run(["gdown", "https://drive.google.com/uc?id=1ao1ovG1Qtx4b7EoskHXmi2E9rp5CHLcZ",
            "-O", "U2Net_/saved_models/"])

# One output directory per target image, named after the file stem.
test_name = os.path.splitext(args.target_file)[0]
output_dir = f"{abs_path}/output_sketches/{test_name}/"
os.makedirs(output_dir, exist_ok=True)

num_iter = args.num_iter
save_interval = 10
use_gpu = not args.cpu

# Fall back to the CPU when CUDA is unavailable, regardless of the -cpu flag.
cuda_available = torch.cuda.is_available()
if not cuda_available:
    use_gpu = False
    print("CUDA is not configured with GPU, running with CPU instead.")
    print("Note that this will be very slow, it is recommended to use colab.")

if args.colab:
    print("=" * 50)
    print(f"Processing [{args.target_file}] ...")
if args.colab or args.display:
    img_ = Image_colab(target)
    display(img_)
    # Only ask torch for the current device when CUDA is actually usable;
    # querying it on a CUDA-less host raises instead of returning an index.
    device_index = torch.cuda.current_device() if cuda_available else None
    print(f"GPU: {use_gpu}, {device_index}")
# NOTE(review): the nesting of the next two prints was reconstructed from
# de-indented text; they are treated as unconditional here -- confirm.
print(f"Results will be saved to \n[{output_dir}] ...")
print("=" * 50)

# One seed per requested sketch: 0, 1000, 2000, ...
seeds = list(range(0, args.num_sketches * 1000, 1000))

exit_codes = []
# Manager-backed dict so that pool workers can report their losses back.
manager = mp.Manager()
losses_all = manager.dict()
def run(seed, wandb_name):
    """Render one sketch in a child process and record its best evaluation loss.

    Launches ``painterly_rendering.py`` on the module-level ``target`` image,
    forwarding the stroke count, iteration count, save interval, GPU flag and
    masking/display options taken from the module-level settings.  After the
    child finishes, ``{output_dir}/{wandb_name}/config.npy`` is loaded and the
    smallest entry of its ``loss_eval`` field is stored in the shared
    ``losses_all`` mapping under ``wandb_name``.

    Args:
        seed: Value forwarded to the child as ``--seed``.
        wandb_name: Run name; forwarded as ``--wandb_name`` and used both as
            the sub-directory of ``output_dir`` and as the ``losses_all`` key.

    Exits the current process with status 1 when the child returns a non-zero
    exit code.
    """
    # Use the interpreter that is running this script rather than whatever
    # "python" happens to resolve to on PATH (which may be missing or may be
    # an environment without the required packages).
    command = [
        sys.executable, "painterly_rendering.py", target,
        "--num_paths", str(args.num_strokes),
        "--output_dir", output_dir,
        "--wandb_name", wandb_name,
        "--num_iter", str(num_iter),
        "--save_interval", str(save_interval),
        "--seed", str(seed),
        "--use_gpu", str(int(use_gpu)),
        "--fix_scale", str(args.fix_scale),
        "--mask_object", str(args.mask_object),
        # The attention mask flag deliberately mirrors --mask_object.
        "--mask_object_attention", str(args.mask_object),
        "--display_logs", str(int(args.colab)),
        "--display", str(int(args.display)),
    ]
    completed = sp.run(command)
    if completed.returncode:
        sys.exit(1)

    # The file holds a pickled object inside a 0-d array; ``[()]`` unwraps it.
    # NOTE(review): allow_pickle=True is only safe for files produced by our
    # own pipeline -- never point this at an untrusted config.npy.
    config = np.load(f"{output_dir}/{wandb_name}/config.npy",
                     allow_pickle=True)[()]
    loss_eval = np.array(config['loss_eval'])
    # Ascending argsort, so the first index points at the lowest loss.
    inds = np.argsort(loss_eval)
    losses_all[wandb_name] = loss_eval[inds][0]
def display_(seed, wandb_name):
    """Show a run's intermediate SVG snapshots as they appear on disk.

    For every iteration index ``0, save_interval, 2 * save_interval, ...`` below
    ``num_iter``, waits until ``svg_iter{i}.svg`` exists under
    ``{output_dir}/{wandb_name}/svg_logs/`` and then redraws an output widget
    with a progress bar and that SVG.

    Args:
        seed: Unused.
        wandb_name: Run name whose ``svg_logs`` directory is watched.

    Note:
        There is no timeout: if a snapshot never appears, this function keeps
        waiting for it indefinitely.
    """
    path_to_svg = f"{output_dir}/{wandb_name}/svg_logs/"
    display(IntSlider())
    out = Output()
    display(out)
    for i in range(0, num_iter, save_interval):
        svg_path = f"{path_to_svg}/svg_iter{i}.svg"
        # Poll for the snapshot, sleeping between checks so the wait does not
        # spin a CPU core at 100%.
        while not os.path.isfile(svg_path):
            time.sleep(0.1)
        with out:
            clear_output()
            print("")
            display(IntProgress(
                value=i,
                min=0,
                max=num_iter,
                description='Processing:',
                bar_style='info',  # 'success', 'info', 'warning', 'danger' or ''
                style={'bar_color': 'maroon'},
                orientation='horizontal'
            ))
            display(SVG(svg_path))
# ---------------------------------------------------------------------------
# Schedule one sketch per seed, then keep the sketch with the lowest loss.
# ---------------------------------------------------------------------------
if multiprocess:
    ncpus = 10
    P = mp.Pool(ncpus)  # Generate pool of workers

for seed in seeds:
    wandb_name = f"{test_name}_{args.num_strokes}strokes_seed{seed}"
    if multiprocess:
        P.apply_async(run, (seed, wandb_name))
    else:
        run(seed, wandb_name)

# The live viewer is submitted to the pool, which only exists in multiprocess
# mode; without this guard a sequential run with -display would hit a
# NameError on P before the best sketch is copied.
if args.display and multiprocess:
    # Give the seed-0 run a head start before watching its snapshots.
    time.sleep(10)
    P.apply_async(display_, (0, f"{test_name}_{args.num_strokes}strokes_seed0"))

if multiprocess:
    P.close()
    P.join()  # block until every queued task has finished

# Rank the finished runs by loss (ascending) and keep the best one.
ranked = sorted(losses_all.items(), key=lambda item: item[1])
if not ranked:
    # Nothing was recorded (e.g. --num_sketches 0 or no run reported a loss).
    sys.exit("No sketch reported a loss; there is no best sketch to copy.")
best_name = ranked[0][0]
copyfile(f"{output_dir}/{best_name}/best_iter.svg",
         f"{output_dir}/{best_name}_best.svg")