nishanth-saka commited on
Commit
0ce4a01
ยท
verified ยท
1 Parent(s): 9a79af1

speed + confidence labels on the output video

Browse files
Files changed (1) hide show
  1. app.py +161 -73
app.py CHANGED
@@ -1,96 +1,184 @@
1
  # ============================================================
2
- # ๐Ÿš— Stage 4 โ€” Speed Calculation (Gradio UI)
3
  # ============================================================
4
 
5
  import gradio as gr
6
- import numpy as np, json, cv2, tempfile, os, math, time
7
- from collections import defaultdict, deque
 
 
 
8
 
9
  # ------------------------------------------------------------
10
- # ๐Ÿง  Configuration defaults
11
  # ------------------------------------------------------------
12
- DEFAULT_PIXEL_TO_METER = 0.05 # meters per pixel
13
- DEFAULT_SPEED_LIMIT = 60.0 # km/h
14
- WINDOW_SIZE = 5 # frames for moving average
15
 
16
  # ------------------------------------------------------------
17
- # ๐Ÿงฎ Compute speeds from trajectories JSON
18
  # ------------------------------------------------------------
19
- def compute_speeds(json_file, pixel_to_meter, speed_limit):
20
- try:
21
- data = json.load(open(json_file))
22
- except Exception as e:
23
- return None, {"error": f"Invalid JSON file: {e}"}
24
-
25
- # Input format: {track_id: [[x,y], [x,y], ...]}
26
- results = {}
27
- overlay = np.ones((600, 900, 3), dtype=np.uint8) * 40
28
-
29
- for tid, pts in data.items():
30
- pts = np.array(pts, dtype=float)
31
- if len(pts) < 2:
32
- continue
33
-
34
- # compute displacement per frame
35
- diffs = np.diff(pts, axis=0)
36
- dists_pix = np.linalg.norm(diffs, axis=1)
37
- # assume 30 FPS default โ†’ ฮ”t = 1/30 s
38
- speeds_m_s = (dists_pix * pixel_to_meter) * 30.0
39
- speeds_kmph = speeds_m_s * 3.6
40
- avg_speed = np.mean(speeds_kmph)
41
- status = "SPEEDING" if avg_speed > speed_limit else "OK"
42
-
43
- results[str(tid)] = {
44
- "avg_speed_kmph": round(float(avg_speed), 2),
45
- "max_speed_kmph": round(float(np.max(speeds_kmph)), 2),
46
- "status": status
47
- }
48
-
49
-
50
- # draw on overlay (simple visualization)
51
- color = (0, 0, 255) if status == "SPEEDING" else (0, 255, 0)
52
- start = tuple(np.int32(pts[0]))
53
- end = tuple(np.int32(pts[-1]))
54
- cv2.arrowedLine(overlay, start, end, color, 2, tipLength=0.2)
55
- cv2.putText(
56
- overlay,
57
- f"ID:{tid} {avg_speed:.1f}km/h",
58
- end,
59
- cv2.FONT_HERSHEY_SIMPLEX,
60
- 0.6,
61
- color,
62
- 2,
63
- )
64
-
65
- # Save overlay
66
- out_path = tempfile.NamedTemporaryFile(suffix=".jpg", delete=False).name
67
- cv2.imwrite(out_path, overlay)
68
-
69
- return out_path, results
70
 
 
 
 
71
 
72
- # ------------------------------------------------------------
73
- # ๐Ÿ–ฅ๏ธ Gradio Interface
74
- # ------------------------------------------------------------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
75
  description = """
76
- ### ๐Ÿงฎ Stage 4 โ€” Speed Calculation
77
- Uploads the **trajectories JSON** (same format as Stage 1 output).
78
- Calculates per-vehicle speed (km/h) using a pixel-to-meter calibration and 30 FPS default.
79
- Color-codes SPEEDING vehicles (red).
 
 
 
 
 
80
  """
81
 
82
  demo = gr.Interface(
83
- fn=compute_speeds,
84
  inputs=[
85
- gr.File(label="Upload trajectories JSON"),
86
- gr.Slider(0.01, 0.2, value=DEFAULT_PIXEL_TO_METER, step=0.01, label="Pixel โ†’ Meter Conversion"),
87
- gr.Slider(10, 120, value=DEFAULT_SPEED_LIMIT, step=5, label="Speed Limit (km/h)"),
 
 
88
  ],
89
  outputs=[
90
- gr.Image(label="Speed Overlay"),
91
- gr.JSON(label="Speed Stats (Stage 4 Output)")
92
  ],
93
- title="๐Ÿš— Stage 4 โ€“ Speed Calculation",
94
  description=description,
95
  )
96
 
 
1
  # ============================================================
2
+ # ๐Ÿš— Stage 4 โ€” Speed Calculation (Video + Confidence + Filter)
3
  # ============================================================
4
 
5
  import gradio as gr
6
+ import cv2, os, json, tempfile, math, time
7
+ import numpy as np
8
+ from ultralytics import YOLO
9
+ from filterpy.kalman import KalmanFilter
10
+ from scipy.optimize import linear_sum_assignment
11
 
12
  # ------------------------------------------------------------
13
+ # ๐Ÿง  Safe-load fix for PyTorch 2.6
14
  # ------------------------------------------------------------
15
+ import torch, ultralytics.nn.tasks as ultralytics_tasks
16
+ torch.serialization.add_safe_globals([ulralytics_tasks.DetectionModel])
 
17
 
18
  # ------------------------------------------------------------
19
+ # โš™๏ธ Model + Config
20
  # ------------------------------------------------------------
21
+ MODEL_PATH = "yolov8n.pt"
22
+ model = YOLO(MODEL_PATH)
23
+ VEHICLE_CLASSES = [2, 3, 5, 7] # car, motorcycle, bus, truck
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
 
25
+ # Speed calculation constants
26
+ PIXEL_TO_METER = 0.05
27
+ FPS_DEFAULT = 30.0
28
 
29
+ # ============================================================
30
+ # ๐Ÿงฉ Kalman-based Tracker
31
+ # ============================================================
32
+ class Track:
33
+ def __init__(self, bbox, tid):
34
+ self.id = tid
35
+ self.kf = KalmanFilter(dim_x=4, dim_z=2)
36
+ self.kf.F = np.array([[1,0,1,0],[0,1,0,1],[0,0,1,0],[0,0,0,1]])
37
+ self.kf.H = np.array([[1,0,0,0],[0,1,0,0]])
38
+ self.kf.P *= 10
39
+ self.kf.R *= 1
40
+ self.kf.x[:2] = np.array(bbox[:2]).reshape(2,1)
41
+ self.history = []
42
+ self.frames_seen = 0
43
+ self.avg_speed = 0
44
+ self.confidence = 1.0
45
+ self.status = "OK"
46
+
47
+ def update(self, bbox):
48
+ self.kf.predict()
49
+ self.kf.update(np.array(bbox[:2]))
50
+ x, y = self.kf.x[:2].reshape(-1)
51
+ self.history.append([x, y])
52
+ if len(self.history) > 30:
53
+ self.history.pop(0)
54
+ self.frames_seen += 1
55
+ return [x, y]
56
+
57
+
58
+ # ============================================================
59
+ # ๐Ÿงฎ Utility Functions
60
+ # ============================================================
61
+ def compute_speed(track, fps, pixel_to_meter):
62
+ if len(track.history) < 2:
63
+ return 0.0
64
+ pts = np.array(track.history)
65
+ diffs = np.diff(pts, axis=0)
66
+ dists = np.linalg.norm(diffs, axis=1)
67
+ mean_pix_per_frame = np.mean(dists)
68
+ speed_m_s = mean_pix_per_frame * pixel_to_meter * fps
69
+ return speed_m_s * 3.6 # km/h
70
+
71
+
72
+ # ============================================================
73
+ # ๐Ÿš€ Main Processing Function
74
+ # ============================================================
75
+ def process_video(video_file, speed_limit, pixel_to_meter, confidence_filter, show_only_speeding):
76
+
77
+ cap = cv2.VideoCapture(video_file)
78
+ fps = cap.get(cv2.CAP_PROP_FPS) or FPS_DEFAULT
79
+ w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
80
+ h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
81
+ out_path = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False).name
82
+ out = cv2.VideoWriter(out_path, cv2.VideoWriter_fourcc(*"mp4v"), fps, (w, h))
83
+
84
+ tracks, next_id = {}, 0
85
+ frame_no = 0
86
+ DELAY_FRAMES = 5
87
+ SPEED_SMOOTH_ALPHA = 0.5 # exponential moving average
88
+
89
+ while True:
90
+ ret, frame = cap.read()
91
+ if not ret:
92
+ break
93
+ frame_no += 1
94
+ results = model(frame)[0]
95
+ dets = []
96
+ for box in results.boxes:
97
+ cls = int(box.cls[0])
98
+ conf = float(box.conf[0])
99
+ if cls in VEHICLE_CLASSES:
100
+ x1, y1, x2, y2 = box.xyxy[0].cpu().numpy()
101
+ cx, cy = (x1 + x2) / 2, (y1 + y2) / 2
102
+ dets.append([cx, cy, conf])
103
+ dets = np.array(dets)
104
+
105
+ # --- Tracker update ---
106
+ assigned = set()
107
+ if len(dets) > 0 and len(tracks) > 0:
108
+ existing = np.array([t.kf.x[:2].reshape(-1) for t in tracks.values()])
109
+ dists = np.linalg.norm(existing[:, None, :] - dets[None, :, :2], axis=2)
110
+ row_idx, col_idx = linear_sum_assignment(dists)
111
+ for r, c in zip(row_idx, col_idx):
112
+ if dists[r, c] < 60:
113
+ tid = list(tracks.keys())[r]
114
+ tracks[tid].update(dets[c])
115
+ tracks[tid].confidence = float(dets[c][2])
116
+ assigned.add(c)
117
+ for i, d in enumerate(dets):
118
+ if i not in assigned:
119
+ tracks[next_id] = Track(d, next_id)
120
+ tracks[next_id].confidence = float(d[2])
121
+ next_id += 1
122
+
123
+ # --- Speed & Draw ---
124
+ for tid, trk in list(tracks.items()):
125
+ pos = trk.update(trk.kf.x[:2].reshape(-1))
126
+ if trk.frames_seen < DELAY_FRAMES:
127
+ continue
128
+
129
+ # compute speed
130
+ speed = compute_speed(trk, fps, pixel_to_meter)
131
+ # smooth speed
132
+ trk.avg_speed = SPEED_SMOOTH_ALPHA * speed + (1 - SPEED_SMOOTH_ALPHA) * trk.avg_speed
133
+ status = "SPEEDING" if trk.avg_speed > speed_limit else "OK"
134
+ trk.status = status
135
+
136
+ # skip by confidence filter
137
+ if trk.confidence < confidence_filter:
138
+ continue
139
+ if show_only_speeding and trk.status != "SPEEDING":
140
+ continue
141
+
142
+ color = (0, 0, 255) if status == "SPEEDING" else (0, 255, 0)
143
+ label = f"ID:{tid} {trk.avg_speed:.1f}km/h ({trk.confidence:.2f})"
144
+ cv2.putText(frame, label, tuple(np.int32(pos)), cv2.FONT_HERSHEY_SIMPLEX, 0.6, color, 2)
145
+ cv2.circle(frame, tuple(np.int32(pos)), 4, color, -1)
146
+
147
+ out.write(frame)
148
+
149
+ cap.release()
150
+ out.release()
151
+ return out_path
152
+
153
+
154
+ # ============================================================
155
+ # ๐ŸŽ›๏ธ Gradio UI
156
+ # ============================================================
157
  description = """
158
+ ### ๐Ÿš— Stage 4 โ€” Speed Calculation (Video + Confidence + Filtering)
159
+ Uploads a traffic video, detects and tracks vehicles,
160
+ computes their **approximate speed**, and overlays **speed + confidence labels**.
161
+
162
+ **Controls:**
163
+ - ๐ŸŽš๏ธ Pixel โ†’ Meter conversion for calibration
164
+ - ๐Ÿšง Speed limit for violation tagging
165
+ - ๐Ÿง  Confidence threshold (hide low-confidence detections)
166
+ - ๐Ÿšจ Option to show only SPEEDING vehicles
167
  """
168
 
169
  demo = gr.Interface(
170
+ fn=process_video,
171
  inputs=[
172
+ gr.File(label="Upload Traffic Video (.mp4)"),
173
+ gr.Slider(10, 120, value=60, step=5, label="Speed Limit (km/h)"),
174
+ gr.Slider(0.01, 0.2, value=0.05, step=0.01, label="Pixel โ†’ Meter Conversion"),
175
+ gr.Slider(0.0, 1.0, value=0.3, step=0.05, label="Confidence Filter (Show โ‰ฅ this value)"),
176
+ gr.Checkbox(label="Show ONLY Speeding Vehicles", value=False)
177
  ],
178
  outputs=[
179
+ gr.Video(label="Output Video (Speed Overlay)")
 
180
  ],
181
+ title="๐Ÿš— Stage 4 โ€“ Speed Calculation with Confidence & Filter",
182
  description=description,
183
  )
184