Commit b05da2d
Parent: d04cd0a

add pose inference

Files changed:
- app.py (+10 -3)
- inference/pose.py (+6 -2)
- load_and_test.ipynb (+23 -21)
app.py
CHANGED
@@ -5,19 +5,26 @@ from PIL import Image
 import cv2
 import spaces
 
-from inference.seg import process_image_or_video
+from inference.seg import process_image_or_video as process_seg
+from inference.pose import process_image_or_video as process_pose
 from config import SAPIENS_LITE_MODELS_PATH
 
 def update_model_choices(task):
     model_choices = list(SAPIENS_LITE_MODELS_PATH[task.lower()].keys())
     return gr.Dropdown(choices=model_choices, value=model_choices[0] if model_choices else None)
 
-@spaces.GPU(
+@spaces.GPU()
 def process_image(input_image, task, version):
     if isinstance(input_image, np.ndarray):
         input_image = Image.fromarray(input_image)
 
-    result = process_image_or_video(input_image, task=task.lower(), version=version)
+    if task.lower() == 'seg':
+        result = process_seg(input_image, task=task.lower(), version=version)
+    elif task.lower() == 'pose':
+        result = process_pose(input_image, task=task.lower(), version=version)
+    else:
+        result = None
+        print(f"Unsupported task: {task}")
 
     return result
 
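For context on how the new dispatcher is meant to be used: the sketch below is a hypothetical wiring of update_model_choices and process_image into a Gradio Blocks UI. It is not taken from this commit's app.py (the actual layout, component names, and labels may differ); it only illustrates that the task selector drives the model dropdown and that process_image receives (image, task, version).

    # Hypothetical Gradio wiring (assumed names and labels, not from this commit).
    import gradio as gr

    with gr.Blocks() as demo:
        task = gr.Radio(choices=["Seg", "Pose"], value="Seg", label="Task")
        version = gr.Dropdown(choices=[], label="Model version")
        image_in = gr.Image(label="Input image")
        image_out = gr.Image(label="Result")

        # Refresh the version dropdown whenever the task changes.
        task.change(update_model_choices, inputs=task, outputs=version)
        # Dispatch to process_seg / process_pose through the GPU-decorated process_image.
        gr.Button("Run").click(process_image, inputs=[image_in, task, version], outputs=image_out)

    demo.launch()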
inference/pose.py
CHANGED
@@ -90,6 +90,9 @@ def load_model(task, version):
     try:
         model_path = SAPIENS_LITE_MODELS_PATH[task][version]
         device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+        if torch.cuda.is_available() and torch.cuda.get_device_properties(0).major >= 8:
+            torch.backends.cuda.matmul.allow_tf32 = True
+            torch.backends.cudnn.allow_tf32 = True
         model = torch.jit.load(model_path)
         model.eval()
         model.to(device)
@@ -109,6 +112,7 @@ def preprocess_image(image, input_shape):
     return img.unsqueeze(0)
 
 def udp_decode(heatmap, img_size, heatmap_size):
+    # This is a simplified version. You might need to implement the full UDP decode logic
     h, w = heatmap_size
     keypoints = np.zeros((heatmap.shape[0], 2))
     keypoint_scores = np.zeros(heatmap.shape[0])
@@ -133,8 +137,6 @@ def process_image_or_video(input_data, task='pose', version='sapiens_1b'):
     if model is None or device is None:
         return None
 
-    input_shape = (3, 1024, 768)
-
     def process_frame(frame):
         if isinstance(frame, np.ndarray):
             frame = Image.fromarray(frame)
@@ -142,6 +144,8 @@ def process_image_or_video(input_data, task='pose', version='sapiens_1b'):
         if frame.mode == 'RGBA':
            frame = frame.convert('RGB')
 
+        input_shape = (3, frame.height, frame.width)
+
         img = preprocess_image(frame, input_shape)
 
         with torch.no_grad():
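The comment added to udp_decode flags that the decode step is simplified. As a reference point, a simplified decode of this kind usually amounts to a per-channel argmax scaled from heatmap space to image space; the sketch below is a generic illustration under that assumption, not the code in inference/pose.py. A full UDP decode would typically add unbiased coordinate transforms and sub-pixel refinement around the peak.

    # Generic argmax-based heatmap decode (illustrative only; assumed shapes:
    # heatmap is (num_keypoints, H, W), img_size and heatmap_size are (height, width)).
    import numpy as np

    def argmax_decode(heatmap, img_size, heatmap_size):
        num_kpts = heatmap.shape[0]
        keypoints = np.zeros((num_kpts, 2))
        keypoint_scores = np.zeros(num_kpts)
        for k in range(num_kpts):
            hm = heatmap[k]
            y, x = np.unravel_index(np.argmax(hm), hm.shape)  # peak location
            keypoint_scores[k] = hm[y, x]
            # Scale from heatmap coordinates to image coordinates (x, y order).
            keypoints[k, 0] = x * img_size[1] / heatmap_size[1]
            keypoints[k, 1] = y * img_size[0] / heatmap_size[0]
        return keypoints, keypoint_scores

The other functional change is that the hard-coded input_shape = (3, 1024, 768) is gone: each frame now carries its own (3, height, width). The repo's preprocess_image is not shown in this diff; a minimal sketch of a preprocess step that consumes such a shape might look like the following (the resize policy and normalization constants are assumptions).

    # Hypothetical preprocess step consuming input_shape = (3, H, W); the actual
    # preprocess_image in inference/pose.py may resize and normalize differently.
    import numpy as np
    import torch
    from PIL import Image

    def preprocess_sketch(image: Image.Image, input_shape):
        _, height, width = input_shape
        image = image.resize((width, height))
        img = torch.from_numpy(np.array(image)).float().permute(2, 0, 1) / 255.0
        mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)  # assumed ImageNet stats
        std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
        img = (img - mean) / std
        return img.unsqueeze(0)  # add batch dimension

Note that if preprocess_image resizes to input_shape, that resize is now effectively a no-op, so the model sees each frame at its native resolution.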
load_and_test.ipynb
CHANGED
@@ -2,7 +2,7 @@
  "cells": [
   {
    "cell_type": "code",
-   "execution_count":
+   "execution_count": 1,
    "metadata": {},
    "outputs": [
     {
@@ -32,7 +32,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count":
+   "execution_count": 2,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -77,7 +77,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count":
+   "execution_count": 5,
    "metadata": {},
    "outputs": [
     {
@@ -829,7 +829,7 @@
       ")"
      ]
     },
-    "execution_count":
+    "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
@@ -842,7 +842,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count":
+   "execution_count": 6,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -856,7 +856,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count":
+   "execution_count": 7,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -871,7 +871,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count":
+   "execution_count": 8,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -888,7 +888,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count":
+   "execution_count": 9,
    "metadata": {},
    "outputs": [
     {
@@ -899,7 +899,7 @@
       "<PIL.Image.Image image mode=RGB size=640x480>"
      ]
     },
-    "execution_count":
+    "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
@@ -917,7 +917,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count":
+   "execution_count": 10,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -926,7 +926,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count":
+   "execution_count": 11,
    "metadata": {},
    "outputs": [
     {
@@ -937,7 +937,7 @@
       "<PIL.Image.Image image mode=RGB size=1024x1024>"
      ]
     },
-    "execution_count":
+    "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
@@ -955,7 +955,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count":
+   "execution_count": 13,
    "metadata": {},
    "outputs": [
     {
@@ -977,7 +977,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count":
+   "execution_count": 14,
    "metadata": {},
    "outputs": [
     {
@@ -2188,7 +2188,7 @@
       ")"
      ]
     },
-    "execution_count":
+    "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
@@ -2230,7 +2230,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count":
+   "execution_count": 15,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -2292,7 +2292,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count":
+   "execution_count": 16,
    "metadata": {},
    "outputs": [
     {
@@ -2303,7 +2303,7 @@
       "<PIL.Image.Image image mode=RGB size=640x480>"
      ]
     },
-    "execution_count":
+    "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
@@ -2321,16 +2321,18 @@
   },
   {
    "cell_type": "code",
-   "execution_count":
+   "execution_count": 18,
    "metadata": {},
    "outputs": [],
    "source": [
+    "from PIL import Image, ImageDraw\n",
+    "\n",
     "output_pose = get_pose(resized_pil_image, model)"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count":
+   "execution_count": 20,
    "metadata": {},
    "outputs": [
     {
@@ -2341,7 +2343,7 @@
       "<PIL.Image.Image image mode=RGB size=640x480>"
      ]
     },
-    "execution_count":
+    "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    }
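The only substantive notebook change is the from PIL import Image, ImageDraw line added ahead of the get_pose call; the remaining hunks are execution-count churn from re-running cells. The notebook's own drawing code is not shown in this diff, but a typical way to overlay decoded keypoints with ImageDraw looks like the sketch below (function name, score threshold, and colors are assumptions).

    # Hypothetical keypoint overlay using PIL.ImageDraw (not from the notebook itself).
    from PIL import Image, ImageDraw

    def draw_keypoints(image: Image.Image, keypoints, scores, threshold=0.3, radius=3):
        out = image.copy()
        draw = ImageDraw.Draw(out)
        for (x, y), score in zip(keypoints, scores):
            if score < threshold:
                continue  # skip low-confidence keypoints
            draw.ellipse([x - radius, y - radius, x + radius, y + radius], fill=(255, 0, 0))
        return out

If the execution-count noise is unwanted in future commits, the notebook can be stripped of outputs and counts before committing, for example with nbstripout.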