arjunanand13 committed · Commit 30142b4 · verified · 1 Parent(s): 7af2d0e

Create app.py

Files changed (1): app.py +171 -0
app.py ADDED
@@ -0,0 +1,171 @@
import cv2
import torch
from PIL import Image
import numpy as np
import os
import shutil
import gradio as gr
import mediapipe as mp
from transformers import LlavaNextVideoProcessor, LlavaNextVideoForConditionalGeneration, BitsAndBytesConfig

device = "cuda" if torch.cuda.is_available() else "cpu"
model_id = "llava-hf/LLaVA-NeXT-Video-7B-hf"

# Quantize the 7B model to 4-bit NF4 so it fits on a single consumer GPU.
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4"
)

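# Rough memory arithmetic (an estimate, not a measurement): 7B parameters
# at 4 bits is about 3.5 GB of weights, versus roughly 14 GB in float16,
# so 4-bit NF4 is what makes this model practical on a 16 GB GPU.
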
model = LlavaNextVideoForConditionalGeneration.from_pretrained(
    model_id,
    quantization_config=quantization_config,
    low_cpu_mem_usage=True,
    device_map="auto"
)

processor = LlavaNextVideoProcessor.from_pretrained(model_id)

# MediaPipe Hands: static_image_mode=True runs detection on every frame
# independently, which suits the sparsely sampled frames used here.
mpHands = mp.solutions.hands
hands = mpHands.Hands(static_image_mode=True, max_num_hands=2)
mpDraw = mp.solutions.drawing_utils

def track_hand_position(frame):
    height, width = frame.shape[:2]
    mid_width = width // 2

    imgRGB = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    results = hands.process(imgRGB)

    hand_positions = []

    if results.multi_hand_landmarks:
        for handLms in results.multi_hand_landmarks:
            # Average the x-coordinate of all 21 landmarks to locate the hand.
            cx_values = []
            for lm in handLms.landmark:
                cx = int(lm.x * width)
                cx_values.append(cx)

            avg_cx = sum(cx_values) / len(cx_values)

            if avg_cx < mid_width:
                hand_positions.append("Region A")
            else:
                hand_positions.append("Region B")

            mpDraw.draw_landmarks(frame, handLms, mpHands.HAND_CONNECTIONS)

    return frame, hand_positions

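# Worked example (hypothetical numbers): on a 1280-px-wide frame,
# mid_width is 640; a hand whose landmarks average cx = 300 falls in
# "Region A", while one averaging cx = 900 falls in "Region B".
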
def add_regions_to_frame(frame, frame_idx, output_dir):
    height, width = frame.shape[:2]
    mid_width = width // 2

    # Tint the left half blue and the right half green (BGR color order).
    overlay = frame.copy()
    cv2.rectangle(overlay, (0, 0), (mid_width, height), (255, 0, 0), -1)
    cv2.rectangle(overlay, (mid_width, 0), (width, height), (0, 255, 0), -1)

    frame = cv2.addWeighted(frame, 0.7, overlay, 0.3, 0)

    cv2.line(frame, (mid_width, 0), (mid_width, height), (255, 255, 255), 3)

    cv2.putText(frame, "Region A", (mid_width//4, height//2), cv2.FONT_HERSHEY_SIMPLEX, 2, (255, 255, 255), 3)
    cv2.putText(frame, "Region B", (mid_width + mid_width//4, height//2), cv2.FONT_HERSHEY_SIMPLEX, 2, (255, 255, 255), 3)

    tracked_frame, hand_pos = track_hand_position(frame.copy())

    cv2.imwrite(f"{output_dir}/frame_{frame_idx:03d}.jpg", tracked_frame)

    return tracked_frame, hand_pos

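# cv2.addWeighted blends pixel-wise as 0.7*frame + 0.3*overlay, so the
# region tint stays translucent and the pipe and hands remain visible to
# both MediaPipe and the vision-language model.
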
def sample_frames(video_path, num_frames):
    output_dir = "/tmp/processed_frames"

    # Start from a clean frame directory on every run.
    if os.path.exists(output_dir):
        shutil.rmtree(output_dir)
    os.makedirs(output_dir)

    video = cv2.VideoCapture(video_path)

    if not video.isOpened():
        raise ValueError(f"Could not open video file: {video_path}")

    total_frames = int(video.get(cv2.CAP_PROP_FRAME_COUNT))
    interval = max(1, total_frames // num_frames)
    frames = []
    frame_count = 0
    hand_tracking_log = []

    for i in range(total_frames):
        ret, frame = video.read()
        if not ret:
            continue
        if i % interval == 0 and len(frames) < num_frames:
            processed_frame, hand_positions = add_regions_to_frame(frame, frame_count, output_dir)
            pil_img = Image.fromarray(cv2.cvtColor(processed_frame, cv2.COLOR_BGR2RGB))
            frames.append(pil_img)
            hand_tracking_log.append(f"Frame {frame_count}: {hand_positions}")
            frame_count += 1

    video.release()

    frame_paths = [f"{output_dir}/frame_{i:03d}.jpg" for i in range(frame_count)]

    return frames, frame_paths, hand_tracking_log

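# Sampling arithmetic (illustrative): a 240-frame clip with num_frames=8
# gives interval = 240 // 8 = 30, so frames 0, 30, 60, ..., 210 are kept;
# the len(frames) < num_frames guard prevents collecting more than 8.
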
def analyze_video(video_path):
    conversation = [
        {
            "role": "user",
            "content": [
                {"type": "text", "text": "Analyze this gas pipe quality control video and classify into one category: 1) PASSED - pipe taken from Region A, dipped in water, no bubbles, moved to Region B. Example: Person picks pipe from left side, tests in water, no bubbles seen, places in right side. 2) FAILED - pipe tested in water, bubbles visible. Example: Person dips pipe in water, bubbles appear indicating leak, pipe rejected. 3) CHEATING - pipe moved from A to B without testing. Example: Person takes pipe from left and directly places in right without water test. Give classification and brief reason."},
                {"type": "video"},
            ],
        },
    ]

    prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)

    video_frames, frame_paths, hand_log = sample_frames(video_path, 8)

    # return_tensors="pt" is required so the batch can be moved to the GPU.
    inputs = processor(text=prompt, videos=video_frames, padding=True, return_tensors="pt")
    inputs = {k: v.to(model.device) for k, v in inputs.items()}

    output = model.generate(
        **inputs,
        max_new_tokens=150,
        do_sample=True,
        temperature=0.7,
        top_p=0.9,
        top_k=50,
        repetition_penalty=1.1,
        pad_token_id=processor.tokenizer.eos_token_id
    )

    result = processor.decode(output[0][2:], skip_special_tokens=True)

    hand_tracking_summary = "\n".join(hand_log)

    return frame_paths, result, hand_tracking_summary

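# Note: do_sample=True makes the verdict non-deterministic across runs;
# if repeatable QC classifications are needed, greedy decoding
# (do_sample=False) would be the usual alternative.
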
examples = [
    ["/front view/07.mp4"],
    ["/front view/09.mp4"],
    ["/front view/29.mp4"]
]

iface = gr.Interface(
    fn=analyze_video,
    inputs=gr.Video(),
    outputs=[
        gr.Gallery(label="Processed Frames"),
        gr.Textbox(label="LLM Analysis", lines=10),
        gr.Textbox(label="Hand Tracking Log", lines=15)
    ],
    title="Gas Pipe Quality Control Analyzer",
    examples=examples,
    cache_examples=False
)

iface.launch(share=True)
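
For scripted use without the Gradio UI, a minimal sketch (assuming a local test clip at the hypothetical path sample.mp4):

frame_paths, verdict, hand_log = analyze_video("sample.mp4")  # hypothetical input path
print(verdict)    # PASSED / FAILED / CHEATING classification with reason
print(hand_log)   # per-frame "Region A" / "Region B" hand positions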