ybbwcwaps
commited on
Commit
·
9011f2d
1
Parent(s):
2fc658c
ui
Browse files
app.py
CHANGED
|
@@ -1,14 +1,27 @@
|
|
| 1 |
import gradio as gr
|
| 2 |
-
from run import detect_video
|
|
|
|
|
|
|
| 3 |
|
| 4 |
def greet(video):
|
| 5 |
print(video, type(video))
|
| 6 |
-
pred = detect_video(video_path=video)
|
| 7 |
if pred > 0.5:
|
| 8 |
string = f"Fake: {pred*100:.2f}%"
|
| 9 |
else:
|
| 10 |
string = f"Real: {(1-pred)*100:.2f}%"
|
|
|
|
| 11 |
return string
|
| 12 |
|
| 13 |
-
|
| 14 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import gradio as gr
|
| 2 |
+
from run import get_model, detect_video
|
| 3 |
+
|
| 4 |
+
model = get_model()
|
| 5 |
|
| 6 |
def greet(video):
|
| 7 |
print(video, type(video))
|
| 8 |
+
pred = detect_video(video_path=video, model=model)
|
| 9 |
if pred > 0.5:
|
| 10 |
string = f"Fake: {pred*100:.2f}%"
|
| 11 |
else:
|
| 12 |
string = f"Real: {(1-pred)*100:.2f}%"
|
| 13 |
+
print(string)
|
| 14 |
return string
|
| 15 |
|
| 16 |
+
with gr.Blocks() as demo:
|
| 17 |
+
gr.Markdown("# Fake Video Detector")
|
| 18 |
+
with gr.Tabs():
|
| 19 |
+
with gr.TabItem("Video Detect"):
|
| 20 |
+
with gr.Row():
|
| 21 |
+
video_input = gr.Video()
|
| 22 |
+
detect_output = gr.Textbox("Output")
|
| 23 |
+
video_button = gr.Button("detect")
|
| 24 |
+
|
| 25 |
+
video_button.click(greet, inputs=video_input, outputs=detect_output)
|
| 26 |
+
|
| 27 |
+
demo.launch()
|
run.py
CHANGED
|
@@ -19,24 +19,25 @@ import options
|
|
| 19 |
from networks.validator import Validator
|
| 20 |
|
| 21 |
|
| 22 |
-
def
|
| 23 |
val_opt = options.TestOptions().parse(print_options=False)
|
| 24 |
-
|
| 25 |
output_dir=os.path.join(val_opt.output, val_opt.name)
|
| 26 |
os.makedirs(output_dir, exist_ok=True)
|
|
|
|
| 27 |
# logger = create_logger(output_dir=output_dir, name="FakeVideoDetector")
|
| 28 |
print(f"working...")
|
| 29 |
|
| 30 |
model = Validator(val_opt)
|
| 31 |
model.load_state_dict(val_opt.ckpt)
|
| 32 |
print("ckpt loaded!")
|
|
|
|
|
|
|
| 33 |
|
| 34 |
-
|
| 35 |
frames, _, _ = read_video(str(video_path), pts_unit='sec')
|
| 36 |
frames = frames[:16]
|
| 37 |
frames = frames.permute(0, 3, 1, 2) # (T,H,W,C) -> (T,C,H,W)
|
| 38 |
|
| 39 |
-
|
| 40 |
video_frames = torch.cat([model.clip_model.preprocess(TF.to_pil_image(frame)).unsqueeze(0) for frame in frames])
|
| 41 |
|
| 42 |
with torch.no_grad():
|
|
@@ -73,7 +74,9 @@ if __name__ == '__main__':
|
|
| 73 |
|
| 74 |
# pred = model.model(model.input).view(-1).unsqueeze(1).sigmoid()
|
| 75 |
|
| 76 |
-
|
|
|
|
|
|
|
| 77 |
if pred > 0.5:
|
| 78 |
print(f"Fake: {pred*100:.2f}%")
|
| 79 |
else:
|
|
|
|
| 19 |
from networks.validator import Validator
|
| 20 |
|
| 21 |
|
| 22 |
+
def get_model():
|
| 23 |
val_opt = options.TestOptions().parse(print_options=False)
|
|
|
|
| 24 |
output_dir=os.path.join(val_opt.output, val_opt.name)
|
| 25 |
os.makedirs(output_dir, exist_ok=True)
|
| 26 |
+
|
| 27 |
# logger = create_logger(output_dir=output_dir, name="FakeVideoDetector")
|
| 28 |
print(f"working...")
|
| 29 |
|
| 30 |
model = Validator(val_opt)
|
| 31 |
model.load_state_dict(val_opt.ckpt)
|
| 32 |
print("ckpt loaded!")
|
| 33 |
+
return model
|
| 34 |
+
|
| 35 |
|
| 36 |
+
def detect_video(video_path, model):
|
| 37 |
frames, _, _ = read_video(str(video_path), pts_unit='sec')
|
| 38 |
frames = frames[:16]
|
| 39 |
frames = frames.permute(0, 3, 1, 2) # (T,H,W,C) -> (T,C,H,W)
|
| 40 |
|
|
|
|
| 41 |
video_frames = torch.cat([model.clip_model.preprocess(TF.to_pil_image(frame)).unsqueeze(0) for frame in frames])
|
| 42 |
|
| 43 |
with torch.no_grad():
|
|
|
|
| 74 |
|
| 75 |
# pred = model.model(model.input).view(-1).unsqueeze(1).sigmoid()
|
| 76 |
|
| 77 |
+
model = get_model()
|
| 78 |
+
|
| 79 |
+
pred = detect_video(video_path, model)
|
| 80 |
if pred > 0.5:
|
| 81 |
print(f"Fake: {pred*100:.2f}%")
|
| 82 |
else:
|