Update app.py
Browse files
app.py
CHANGED
|
@@ -3,6 +3,7 @@ import cv2
|
|
| 3 |
import torch
|
| 4 |
import os
|
| 5 |
import numpy as np
|
|
|
|
| 6 |
from torchvision.models.detection import FasterRCNN
|
| 7 |
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
|
| 8 |
|
|
@@ -22,20 +23,6 @@ def load_model(model_path, backbone_name, num_classes):
|
|
| 22 |
model.to(device)
|
| 23 |
model.eval()
|
| 24 |
return model
|
| 25 |
-
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
| 26 |
-
if backbone_name == "resnet50":
|
| 27 |
-
model = torch.load(model_path, map_location=device)
|
| 28 |
-
elif backbone_name == "mobilenet":
|
| 29 |
-
model = torch.load(model_path, map_location=device)
|
| 30 |
-
model.to(device)
|
| 31 |
-
model.eval()
|
| 32 |
-
return model
|
| 33 |
-
if backbone_name == "resnet50":
|
| 34 |
-
model = torch.load(model_path)
|
| 35 |
-
elif backbone_name == "mobilenet":
|
| 36 |
-
model = torch.load(model_path)
|
| 37 |
-
model.eval()
|
| 38 |
-
return model
|
| 39 |
|
| 40 |
resnet_model = load_model('fasterrcnnResnet.pth', 'resnet50', num_classes=6)
|
| 41 |
mobilenet_model = load_model('fasterrcnnMobilenet.pth', 'mobilenet', num_classes=6)
|
|
@@ -101,4 +88,4 @@ video_interface = gr.Interface(
|
|
| 101 |
title="Video Inference"
|
| 102 |
)
|
| 103 |
|
| 104 |
-
gr.TabbedInterface([image_interface, video_interface], ["Image", "Video"]).launch()
|
|
|
|
| 3 |
import torch
|
| 4 |
import os
|
| 5 |
import numpy as np
|
| 6 |
+
import torchvision
|
| 7 |
from torchvision.models.detection import FasterRCNN
|
| 8 |
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
|
| 9 |
|
|
|
|
| 23 |
model.to(device)
|
| 24 |
model.eval()
|
| 25 |
return model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 26 |
|
| 27 |
resnet_model = load_model('fasterrcnnResnet.pth', 'resnet50', num_classes=6)
|
| 28 |
mobilenet_model = load_model('fasterrcnnMobilenet.pth', 'mobilenet', num_classes=6)
|
|
|
|
| 88 |
title="Video Inference"
|
| 89 |
)
|
| 90 |
|
| 91 |
+
gr.TabbedInterface([image_interface, video_interface], ["Image", "Video"]).launch()
|