Spaces:
Runtime error
Runtime error
| import torch | |
| import torch.nn.functional as F | |
| import numpy as np | |
| from PIL import Image | |
| import network | |
| import os | |
| import math | |
| import render_utils | |
| import paddle | |
| import paddle.nn as nn | |
| import paddle.nn.functional as F | |
| import cv2 | |
| import render_parallel | |
| import render_serial | |
def _write_image(path, image):
    """Write *image* to *path* with OpenCV, raising if the write fails.

    ``cv2.imwrite`` signals failure (e.g. an unsupported file extension or an
    unwritable directory) by returning False instead of raising, so the
    result is checked here rather than silently ignored.
    """
    if not cv2.imwrite(path, image):
        raise OSError('Failed to write image: %s' % path)


def _load_painter(model_path, stroke_num):
    """Build ``network.Painter``, load weights from *model_path*, and freeze it."""
    # The constructor arguments must match the architecture the checkpoint
    # was saved with; their meaning is defined in network.Painter.
    net_g = network.Painter(5, stroke_num, 256, 8, 3, 3)
    # NOTE(review): paddle.load deserializes the checkpoint file (presumably
    # pickle-based for .pdparams) -- only load model files from trusted sources.
    net_g.set_state_dict(paddle.load(model_path))
    net_g.eval()
    # Inference only: disable gradient tracking on every parameter.
    for param in net_g.parameters():
        param.stop_gradient = True
    return net_g


def _load_meta_brushes():
    """Read the vertical and horizontal large brush masks and concat them on axis 0."""
    # These paths are relative to the current working directory, not to this file.
    brush_large_vertical = render_utils.read_img('brush/brush_large_vertical.png', 'L')
    brush_large_horizontal = render_utils.read_img('brush/brush_large_horizontal.png', 'L')
    return paddle.concat([brush_large_vertical, brush_large_horizontal], axis=0)


def main(input_path, model_path, output_dir, need_animation=False,
         resize_h=None, resize_w=None, serial=False, device='cpu'):
    """Paint *input_path* with the Painter model and save the result(s).

    Args:
        input_path: Path of the image to paint.
        model_path: Path of the Paddle state dict loaded into ``network.Painter``.
        output_dir: Output directory; created (including parents) if missing.
            The final image is written here under the input's file name.
        need_animation: If True, write every frame returned by
            ``render_serial.render_serial`` into ``<output_dir>/<input stem>/``
            instead of writing the final image. Forces ``serial=True``.
        resize_h, resize_w: Forwarded to ``render_utils.read_img``; None means
            the input is not resized.
        serial: Render with ``render_serial`` instead of ``render_parallel``.
        device: Device string passed to ``paddle.set_device``. Defaults to
            ``'cpu'``, the previously hard-coded setting.

    Raises:
        OSError: If OpenCV fails to write an output image.
        RuntimeError: If serial rendering returns no frames.
    """
    import time

    os.makedirs(output_dir, exist_ok=True)
    input_name = os.path.basename(input_path)
    output_path = os.path.join(output_dir, input_name)

    frame_dir = None
    if need_animation:
        if not serial:
            print('It must be under serial mode if animation results are required, so serial flag is set to True!')
            serial = True
        # splitext keeps multi-dot stems intact ("a.b.jpg" -> "a.b") and does
        # not chop the last character off names without an extension.
        frame_dir = os.path.join(output_dir, os.path.splitext(input_name)[0])
        os.makedirs(frame_dir, exist_ok=True)

    #* ----- load model ----- *#
    stroke_num = 8
    paddle.set_device(device)  # must precede building the network
    net_g = _load_painter(model_path, stroke_num)

    #* ----- load brush ----- *#
    meta_brushes = _load_meta_brushes()

    t0 = time.perf_counter()
    original_img = render_utils.read_img(input_path, 'RGB', resize_h, resize_w)
    if serial:
        final_result_list = render_serial.render_serial(original_img, net_g, meta_brushes)
        if len(final_result_list) == 0:
            raise RuntimeError('render_serial returned no frames for %s' % input_path)
        if need_animation:
            print("total frame:", len(final_result_list))
            # Pad to at least 3 digits (the original naming) but widen when
            # needed so file names still sort in frame order past 999 frames.
            width = max(3, len(str(len(final_result_list) - 1)))
            for idx, frame in enumerate(final_result_list):
                _write_image(os.path.join(frame_dir, '%0*d.png' % (width, idx)), frame)
        else:
            _write_image(output_path, final_result_list[-1])
    else:
        final_result = render_parallel.render_parallel(original_img, net_g, meta_brushes)
        _write_image(output_path, final_result)
    print("total infer time:", time.perf_counter() - t0)
if __name__ == '__main__':
    # Demo run: paint input/chicago.jpg at 512x512 and keep every
    # intermediate frame for building an animation.
    demo_kwargs = dict(
        input_path='input/chicago.jpg',
        model_path='paint_best.pdparams',
        output_dir='output/',
        need_animation=True,  # also dump intermediate frames
        resize_h=512,  # target height; None keeps the original size
        resize_w=512,  # target width; None keeps the original size
        serial=True,  # animation output requires serial rendering
    )
    main(**demo_kwargs)