import spaces
import os
import cv2
import argparse
import numpy as np
import gradio as gr
from tqdm import tqdm
from PIL import Image, ImageEnhance

import torch
from torch.amp import autocast
import torch.nn.functional as F

from network.line_extractor import LineExtractor
def resize(image, max_size=3840):
    h, w = image.shape[:2]
    if h > w:
        h, w = (max_size, int(w * max_size / h))
    else:
        h, w = (int(h * max_size / w), max_size)
    return cv2.resize(image, (w, h))
def increase_sharpness(img, factor=6.0):
    image = Image.fromarray(img)
    enhancer = ImageEnhance.Sharpness(image)
    return np.array(enhancer.enhance(factor))
def load_model(mode):
    if mode == 'basic':
        model = LineExtractor(3, 1, True)
    elif mode == 'detail':
        model = LineExtractor(2, 1, True)
    else:
        raise ValueError(f"Unknown mode: {mode}")
    path_model = os.path.join('weights', f'{mode}.pth')
    model.load_state_dict(torch.load(path_model, weights_only=True))
    for param in model.parameters():
        param.requires_grad = False
    model.eval()
    return model
@spaces.GPU  # request a GPU for this call when running on ZeroGPU hardware
def process_image(image, mode, binarize, threshold, fp16=True):
    if image is None:
        return None
    binarize_value = threshold if binarize else -1
    args = argparse.Namespace(mode=mode, binarize=binarize_value, fp16=fp16, device="cuda:0")
    image = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
    if image.shape[0] > 1920 or image.shape[1] > 1920:
        image = resize(image)
    return inference(image, args)
def process_video(path_in, path_out, fourcc='mp4v', **kwargs):
    video = cv2.VideoCapture(path_in)
    fps = video.get(cv2.CAP_PROP_FPS)
    width = int(video.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(video.get(cv2.CAP_PROP_FRAME_HEIGHT))
    total_frames = int(video.get(cv2.CAP_PROP_FRAME_COUNT))
    fourcc = cv2.VideoWriter_fourcc(*fourcc)
    video_out = cv2.VideoWriter(path_out, fourcc, fps, (width, height))
    for _ in tqdm(range(total_frames), desc='Processing Video'):
        ret, frame = video.read()
        if not ret:
            break
        img = inference(frame, **kwargs)
        img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
        video_out.write(img)
    video.release()
    video_out.release()
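# Hypothetical standalone use of process_video (it is not wired into the Gradio UI below;
# the Namespace mirrors the one built in process_image and the file names are placeholders):
#
#     args = argparse.Namespace(mode='detail', binarize=-1, fp16=True, device='cuda:0')
#     process_video('input.mp4', 'lines.mp4', args=args)
#
# Note: inference() is invoked per frame as inference(frame, **kwargs), so the keyword
# argument must be named `args` for the forwarding above to resolve correctly.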
def inference(img: np.ndarray, args):
    if args.mode == 'basic':
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        img = increase_sharpness(img)
        img = torch.from_numpy(img).permute(2, 0, 1).unsqueeze(0).float().to(args.device) / 255.
        x_in = img
    else:
        # 'detail' mode: grayscale image plus an inverted Sobel edge-magnitude map as a second channel
        img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
        sobelx = cv2.Sobel(img, cv2.CV_64F, 1, 0, ksize=3)
        sobely = cv2.Sobel(img, cv2.CV_64F, 0, 1, ksize=3)
        sobel = cv2.magnitude(sobelx, sobely)
        sobel = 255 - cv2.normalize(sobel, None, 0, 255, cv2.NORM_MINMAX, cv2.CV_8UC1)
        img = torch.from_numpy(img).unsqueeze(0).unsqueeze(0).float().to(args.device) / 255.
        sobel = torch.from_numpy(sobel).unsqueeze(0).unsqueeze(0).float().to(args.device) / 255.
        x_in = torch.cat([img, sobel], dim=1)

    # Pad H and W up to multiples of 8; the padding is cropped back off after inference
    B, C, H, W = x_in.shape
    pad_h = (8 - H % 8) % 8
    pad_w = (8 - W % 8) % 8
    x_in = F.pad(x_in, (0, pad_w, 0, pad_h), mode='reflect')

    with torch.no_grad(), autocast(device_type='cuda', enabled=args.fp16):
        if args.mode == 'basic':
            pred = model_basic(x_in)
        elif args.mode == 'detail':
            pred = model_detail(x_in)

    pred = pred[:, :, :H, :W]
    if args.binarize != -1:
        pred = (pred > args.binarize).float()
    return np.clip((pred[0, 0].cpu().numpy() * 255) + 0.5, 0, 255).astype(np.uint8)
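# Hypothetical direct call to inference(), mirroring how process_image drives it
# (the file name is a placeholder):
#
#     frame = cv2.imread('example.png')  # BGR, uint8
#     args = argparse.Namespace(mode='basic', binarize=-1, fp16=True, device='cuda:0')
#     line_map = inference(frame, args)  # single-channel uint8 line drawing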
| model_basic = load_model("basic").to("cuda:0") | |
| model_detail = load_model("detail").to("cuda:0") | |
with gr.Blocks() as demo:
    gr.Markdown("# AniLines - Anime Line Extractor Demo")
    gr.Markdown("For video and batch processing, please refer to the [project page](https://github.com/zhenglinpan/AniLines-Anime-Line-Extractor)")
    with gr.Tabs():
        with gr.Tab("Image Processing"):
            gr.Markdown("## Process Images")
            gr.Markdown("*Online demo resizes image to a max of 4K if larger.")
            with gr.Row():
                image_input = gr.Image(type="pil", label="Upload Image")
                image_output = gr.Image(label="Processed Output")
            mode_dropdown = gr.Radio(["basic", "detail"], value="detail", label="Processing Mode")
            binarize_checkbox = gr.Checkbox(label="Binarize", value=False)
            binarize_slider = gr.Slider(minimum=0, maximum=1, step=0.05, value=0.75, label="Binarization Threshold (-1 for auto)", visible=False)
            binarize_checkbox.change(lambda binarize: gr.update(visible=binarize), inputs=binarize_checkbox, outputs=binarize_slider)
            process_button = gr.Button("Process")
            gr.Examples(
                examples=["example.png", "example2.jpg"],
                inputs=image_input,
                outputs=image_input
            )
            process_button.click(process_image,
                                 inputs=[image_input, mode_dropdown, binarize_checkbox, binarize_slider],
                                 outputs=image_output)

demo.queue().launch()