Jiaqi-hkust commited on
Commit
ef0f225
·
0 Parent(s):

Initial commit on new branch

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +36 -0
  2. README.md +199 -0
  3. app.py +234 -0
  4. configs/eval_configs/eval.yaml +39 -0
  5. configs/prompts/alignment_image.txt +4 -0
  6. configs/train_configs/stage1_pretrain.yaml +77 -0
  7. configs/train_configs/stage2_finetune.yaml +84 -0
  8. environment.yml +225 -0
  9. figs/demo.png +0 -0
  10. figs/examples/car.mp4 +3 -0
  11. figs/examples/explosion2.mp4 +0 -0
  12. figs/icon.png +0 -0
  13. figs/motivation1.png +0 -0
  14. hawk/__init__.py +31 -0
  15. hawk/common/__init__.py +0 -0
  16. hawk/common/config.py +468 -0
  17. hawk/common/dist_utils.py +137 -0
  18. hawk/common/gradcam.py +24 -0
  19. hawk/common/logger.py +195 -0
  20. hawk/common/optims.py +119 -0
  21. hawk/common/registry.py +329 -0
  22. hawk/common/utils.py +424 -0
  23. hawk/configs/datasets/instruct/llava_instruct.yaml +6 -0
  24. hawk/configs/datasets/instruct/webvid_instruct.yaml +6 -0
  25. hawk/configs/datasets/webvid/defaults.yaml +6 -0
  26. hawk/configs/default.yaml +5 -0
  27. hawk/configs/models/minigpt4.yaml +33 -0
  28. hawk/configs/models/video_llama.yaml +36 -0
  29. hawk/conversation/__init__.py +0 -0
  30. hawk/conversation/conversation_video.py +362 -0
  31. hawk/datasets/__init__.py +0 -0
  32. hawk/datasets/builders/__init__.py +77 -0
  33. hawk/datasets/builders/base_dataset_builder.py +236 -0
  34. hawk/datasets/builders/image_text_pair_builder.py +106 -0
  35. hawk/datasets/builders/instruct_builder.py +79 -0
  36. hawk/datasets/builders/video_caption_builder.py +34 -0
  37. hawk/datasets/data_utils.py +196 -0
  38. hawk/datasets/datasets/__init__.py +0 -0
  39. hawk/datasets/datasets/base_dataset.py +68 -0
  40. hawk/datasets/datasets/caption_datasets.py +85 -0
  41. hawk/datasets/datasets/dataloader_utils.py +162 -0
  42. hawk/datasets/datasets/llava_instruct_dataset.py +312 -0
  43. hawk/datasets/datasets/video_instruct_dataset.py +426 -0
  44. hawk/datasets/datasets/webvid_datasets.py +173 -0
  45. hawk/models/ImageBind/.assets/bird_audio.wav +0 -0
  46. hawk/models/ImageBind/.assets/bird_image.jpg +0 -0
  47. hawk/models/ImageBind/.assets/car_audio.wav +0 -0
  48. hawk/models/ImageBind/.assets/car_image.jpg +0 -0
  49. hawk/models/ImageBind/.assets/dog_audio.wav +0 -0
  50. hawk/models/ImageBind/.assets/dog_image.jpg +0 -0
.gitattributes ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ *.mp4 filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,199 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <div align="center">
2
+
3
+ # Hawk: Learning to Understand Open-World Video Anomalies
4
+
5
+ <div align="center">
6
+
7
+ ### This is the official repository for [Hawk](https://arxiv.org/pdf/2405.16886).
8
+
9
+ [Jiaqi Tang^](https://jqt.me/), [Hao Lu^](https://scholar.google.com/citations?user=OOagpAcAAAAJ&hl=en), [Ruizheng Wu](https://scholar.google.com/citations?user=OOagpAcAAAAJ&hl=en), [Xiaogang Xu](https://xuxiaogang.com/), [Ke Ma](https://scholar.google.com.hk/citations?user=yXGNGS8AAAAJ&hl=en), [Cheng Fang](),
10
+ \
11
+ [Bin Guo](http://www.guob.org/), [Jiangbo Lu](https://sites.google.com/site/jiangbolu), [Qifeng Chen](https://cqf.io/) and [Ying-Cong Chen*](https://www.yingcong.me/)
12
+
13
+ ^: Equal contribution.
14
+ *: Corresponding Author.
15
+
16
+ [![made-for-VSCode](https://img.shields.io/badge/Made%20for-VSCode-1f425f.svg)](https://code.visualstudio.com/) [![Visits Badge](https://badges.strrl.dev/visits/jqtangust/hawk)](https://badges.strrl.dev)
17
+
18
+
19
+
20
+ <img src="figs/icon.png" alt="Have eyes like a HAWK!" width="80">
21
+ </div>
22
+ </div>
23
+
24
+ ## 🔍 **Motivation** - Have eyes like a Hawk!
25
+ - 🚩 Current VAD systems are often limited by their superficial semantic understanding of scenes and minimal user interaction.
26
+ - 🚩 Additionally, the prevalent data scarcity in existing datasets restricts their applicability in open-world scenarios.
27
+
28
+ <div align="center">
29
+ <img src="figs/motivation1.png" alt="Hawk">
30
+ </div>
31
+
32
+
33
+ ## 📢 **Updates**
34
+
35
+ - ✅ Feb 24, 2025 - We release the **training and demo code** of **Hawk**.
36
+ - ✅ Feb 24, 2025 - We release the **dataset (video + annotation)** of **Hawk**. Check this Huggingface link for [DOWNLOAD](https://huggingface.co/datasets/Jiaqi-hkust/hawk).
37
+ - ✅ Step 26, 2024 - **Hawk** is accepted by NeurIPS 2024.
38
+ - ✅ June 29, 2024 - We release the **dataset (annotation)** of Hawk. Check this Google Cloud link for [DOWNLOAD](https://drive.google.com/file/d/1WCnizldWZvtS4Yg5SX7ay5C3kUQfz-Eg/view?usp=sharing).
39
+
40
+
41
+ ## ▶️ **Getting Started**
42
+
43
+ ### 🪒 *Installation*
44
+ - Create environment by following steps:
45
+ ```
46
+ apt install ffmpeg
47
+ conda env create -f environment.yml
48
+ conda activate hawk
49
+ ```
50
+
51
+ ### 🏰 *Pretrained and Fine-tuned Model*
52
+
53
+
54
+ - The following checkpoints are utilized to run Hawk:
55
+
56
+ | Checkpoint | Link | Note |
57
+ |:------------------|-------------|-------------|
58
+ | Video-LLaMA-2-7B-Finetuned | [link](https://huggingface.co/DAMO-NLP-SG/Video-LLaMA-2-7B-Finetuned/tree/main) | Used as initial weights for training.|
59
+ | **Hawk_Pretrained** | [link](https://huggingface.co/Jiaqi-hkust/hawk) | Pretrained on the [WebViD](https://github.com/m-bain/webvid)|
60
+ | **Hawk_Finetuned** | [link](https://huggingface.co/Jiaqi-hkust/hawk) | Fine-tuned on [Hawk dataset](https://huggingface.co/datasets/Jiaqi-hkust/hawk)|
61
+
62
+ - If you want to use the pretrained model, please use the **Hawk_Pretrained** checkpoint.
63
+ - If you wish to leverage the model for our anomaly understanding, please opt for the **Hawk_Finetuned** checkpoint.
64
+
65
+
66
+ ## ⏳ **Domo**
67
+
68
+ - The configuration files for [`demo`](/configs/eval_configs/eval.yaml).
69
+
70
+ - Replace the following part as your own path:
71
+ ```
72
+ # Use LLaMA-2-chat as base modal
73
+
74
+ # Some ckpts could be download from Video_LLaMA-2-7B-Finetuned
75
+ # https://huggingface.co/DAMO-NLP-SG/Video-LLaMA-2-7B-Finetuned
76
+ llama_model: ".../Video-LLaMA-2-7B-Finetuned/llama-2-7b-chat-hf"
77
+
78
+ # Hawk Weight (Pretrained or Finetuned)
79
+ ckpt: '.../checkpoint.pth'
80
+ ```
81
+
82
+ - Then, run the script:
83
+ ```
84
+ python app.py \
85
+ --cfg-path configs/eval_configs/eval.yaml \
86
+ --model_type llama_v2 \
87
+ --gpu-id 0
88
+ ```
89
+
90
+ - GUI
91
+ <div align="center">
92
+ <img src="figs/demo.png" alt="Hawk">
93
+ </div>
94
+
95
+ ## 🖥️ **Training**
96
+
97
+ ### 💾 *Dataset Preparation*
98
+
99
+ - **For your convenience, we now provide the video and annotations for the Hawk dataset. You can download them using the Hugglingface: [DOWNLOAD](https://huggingface.co/datasets/Jiaqi-hkust/hawk).**
100
+
101
+ - Traditional Data Acquisition Method:
102
+
103
+ - DOWNLOAD all video datasets for their original dources.
104
+ 1. [CUHK_Avenue](https://www.cse.cuhk.edu.hk/leojia/projects/detectabnormal/dataset.html)
105
+ 2. [DoTA](https://github.com/MoonBlvd/Detection-of-Traffic-Anomaly)
106
+ 3. [Ped1](http://www.svcl.ucsd.edu/projects/anomaly/dataset.htm)
107
+ 4. [Ped2](http://www.svcl.ucsd.edu/projects/anomaly/dataset.htm)
108
+ 5. [ShanghaiTech](https://svip-lab.github.io/dataset/campus_dataset.html)
109
+ 6. [UBNormal](https://github.com/lilygeorgescu/UBnormal/)
110
+ 7. [UCF_Crime](https://www.crcv.ucf.edu/projects/real-world/)
111
+
112
+ - Google Drive Link to [DOWNLOAD](https://drive.google.com/file/d/1WCnizldWZvtS4Yg5SX7ay5C3kUQfz-Eg/view?usp=sharing) our annotations.
113
+
114
+ - Data Structure: each forder contains one annotation file (e.g. CUHK Avenue, DoTA, etc.). The `All_Mix` directory contains all of datasets in training and testing.
115
+
116
+ - The dataset is organized as follows:
117
+
118
+ ```
119
+ (Hawk_data)
120
+
121
+ Annotation
122
+ ├── All_Mix
123
+ │ ├── all_videos_all.json
124
+ │ ���── all_videos_test.json
125
+ │ └── all_videos_train.json
126
+
127
+ ├── CUHK_Avenue
128
+ │ └── Avenue.json
129
+ ├── DoTA
130
+ │ └── DoTA.json
131
+ ├── Ped1
132
+ │ ├── ...
133
+ ├── ...
134
+ └── UCF_Crime
135
+ │ └── ...
136
+
137
+ Videos
138
+ ├── CUHK_Avenue
139
+ │ └── Avenue.json
140
+ ├── DoTA
141
+ │ └── DoTA.json
142
+ ├── Ped1
143
+ │ ├── ...
144
+ ├── ...
145
+
146
+ readme
147
+
148
+ ```
149
+ Note:the data path should be redefined.
150
+
151
+
152
+ ### 🔨 *Configuration*
153
+
154
+ - The configuration files for [`training`](/configs/train_configs) including two stages.
155
+
156
+ - Replace the following part as your own path:
157
+
158
+ ```
159
+ llama_model: ".../Video-LLaMA-2-7B-Finetuned/llama-2-7b-chat-hf"
160
+
161
+ # The ckpt of vision branch after stage1 pretrained, (only for stage 2)
162
+ ckpt: ".../checkpoint.pth"
163
+ ```
164
+
165
+ ### 🖥️ *To Train*
166
+
167
+ - Then, run the script:
168
+ ```
169
+ # for pretraining
170
+ NCCL_P2P_DISABLE=1 CUDA_VISIBLE_DEVICES=0,1,2,3 torchrun --nproc_per_node=4 --master_port='10000' train.py --cfg-path ./configs/train_configs/stage1_pretrain.yaml
171
+
172
+ # for fine-tuning
173
+ NCCL_P2P_DISABLE=1 CUDA_VISIBLE_DEVICES=0,1,2,3 torchrun --nproc_per_node=4 --master_port='12001' train.py --cfg-path ./configs/train_configs/stage2_finetune.yaml
174
+ ```
175
+
176
+ *Resource Usage: Training (stage 1 and stage 2): 4 * RTX A6000 48G*
177
+
178
+ ## 🌐 **Citations**
179
+
180
+ **The following is a BibTeX reference:**
181
+
182
+ ``` latex
183
+ @inproceedings{atang2024hawk,
184
+ title = {Hawk: Learning to Understand Open-World Video Anomalies},
185
+ author = {Tang, Jiaqi and Lu, Hao and Wu, Ruizheng and Xu, Xiaogang and Ma, Ke and Fang, Cheng and Guo, Bin and Lu, Jiangbo and Chen, Qifeng and Chen, Ying-Cong},
186
+ year = {2024},
187
+ booktitle = {Neural Information Processing Systems (NeurIPS)}
188
+ }
189
+ ```
190
+
191
+ ## 📧 **Connecting with Us?**
192
+
193
+ If you have any questions, please feel free to send email to `[email protected]`.
194
+
195
+
196
+ ## 📜 **Acknowledgment**
197
+ This work is supported by the National Natural Science Foundation of China (No. 62206068) and the Natural Science Foundation of Zhejiang Province, China under No. LD24F020002.
198
+
199
+ Also, this project is inspired by [Video-LLaMA](https://github.com/DAMO-NLP-SG/Video-LLaMA).
app.py ADDED
@@ -0,0 +1,234 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Run the following command to start the demo:
3
+
4
+ python demo_video.py \
5
+ --cfg-path /remote-home/share/jiaqitang/Hawk_Ours/configs/eval_configs/eval.yaml \
6
+ --model_type llama_v2 \
7
+ --gpu-id 0
8
+ """
9
+
10
+ import argparse
11
+ import os
12
+ import random
13
+
14
+ import numpy as np
15
+ import torch
16
+ import torch.backends.cudnn as cudnn
17
+ import gradio as gr
18
+
19
+ from hawk.common.config import Config
20
+ from hawk.common.dist_utils import get_rank
21
+ from hawk.common.registry import registry
22
+ from hawk.conversation.conversation_video import Chat, Conversation, default_conversation, SeparatorStyle,conv_llava_llama_2
23
+ import decord
24
+ decord.bridge.set_bridge('torch')
25
+
26
+ #%%
27
+ # imports modules for registration
28
+ from hawk.datasets.builders import *
29
+ from hawk.models import *
30
+ from hawk.processors import *
31
+ from hawk.runners import *
32
+ from hawk.tasks import *
33
+ import time
34
+
35
+
36
+ def parse_args():
37
+ parser = argparse.ArgumentParser(description="Demo")
38
+ parser.add_argument("--cfg-path", required=False, default='./configs/eval_configs/eval.yaml', help="path to configuration file.")
39
+ parser.add_argument("--gpu-id", type=int, default=6, help="specify the gpu to load the model.")
40
+ parser.add_argument("--model_type", type=str, default='llama_v2', help="The type of LLM")
41
+ parser.add_argument(
42
+ "--options",
43
+ nargs="+",
44
+ help="override some settings in the used config, the key-value pair "
45
+ "in xxx=yyy format will be merged into config file (deprecate), "
46
+ "change to --cfg-options instead.",
47
+ )
48
+ args = parser.parse_args()
49
+ return args
50
+
51
+
52
+ def setup_seeds(config):
53
+ seed = config.run_cfg.seed + get_rank()
54
+
55
+ random.seed(seed)
56
+ np.random.seed(seed)
57
+ torch.manual_seed(seed)
58
+
59
+ cudnn.benchmark = False
60
+ cudnn.deterministic = True
61
+
62
+
63
+ # ========================================
64
+ # Model Initialization
65
+ # ========================================
66
+
67
+ print('Initializing Chat')
68
+ args = parse_args()
69
+ cfg = Config(args)
70
+
71
+ model_config = cfg.model_cfg
72
+ model_config.device_8bit = args.gpu_id
73
+ model_cls = registry.get_model_class(model_config.arch)
74
+ model = model_cls.from_config(model_config).to('cuda:{}'.format(args.gpu_id))
75
+ model.eval()
76
+ vis_processor_cfg = cfg.datasets_cfg.webvid.vis_processor.train
77
+ vis_processor = registry.get_processor_class(vis_processor_cfg.name).from_config(vis_processor_cfg)
78
+ chat = Chat(model, vis_processor, device='cuda:{}'.format(args.gpu_id))
79
+ print('Initialization Finished')
80
+
81
+ # ========================================
82
+ # Gradio Setting
83
+ # ========================================
84
+
85
+ def gradio_reset(chat_state, img_list):
86
+ if chat_state is not None:
87
+ chat_state.messages = []
88
+ if img_list is not None:
89
+ img_list = []
90
+ return None, gr.update(value=None, interactive=True), gr.update(interactive=False),gr.update(value="Upload & Start Chat", interactive=True), chat_state, img_list
91
+
92
+ def upload_imgorvideo(gr_video, text_input, chat_state, chatbot):
93
+ # if args.model_type == 'vicuna':
94
+ # chat_state = default_conversation.copy()
95
+ # else:
96
+ chat_state = conv_llava_llama_2.copy()
97
+ if gr_video is None:
98
+ return None, None, None, gr.update(interactive=True), chat_state, None
99
+ # elif gr_img is not None and gr_video is None:
100
+ # print(gr_img)
101
+ # chatbot = chatbot + [((gr_img,), None)]
102
+ # chat_state.system = "You are able to understand the visual content that the user provides. Follow the instructions carefully and explain your answers in detail."
103
+ # img_list = []
104
+ # llm_message = chat.upload_img(gr_img, chat_state, img_list)
105
+ # return gr.update(interactive=False), gr.update(interactive=False), gr.update(interactive=True, placeholder='Type and press Enter'), gr.update(value="Start Chatting", interactive=False), chat_state, img_list,chatbot
106
+ elif gr_video is not None:
107
+ print(gr_video)
108
+ chatbot = chatbot + [((gr_video,), None)]
109
+ chat_state.system = "You are able to understand the visual content that the user provides. Follow the instructions carefully and explain your answers in detail."
110
+ img_list = []
111
+ llm_message = chat.upload_video_without_audio(gr_video, chat_state, img_list)
112
+ return gr.update(interactive=False), gr.update(interactive=True, placeholder='Type and press Enter'), gr.update(value="Start Chatting", interactive=False), chat_state, img_list,chatbot
113
+ # else:
114
+ # # img_list = []
115
+ # return gr.update(interactive=False), gr.update(interactive=False, placeholder='Currently, only one input is supported'), gr.update(value="Currently, only one input is supported", interactive=False), chat_state, None,chatbot
116
+
117
+ def gradio_ask(user_message, chatbot, chat_state):
118
+ if len(user_message) == 0:
119
+ return gr.update(interactive=True, placeholder='Input should not be empty!'), chatbot, chat_state
120
+ chat.ask(user_message, chat_state)
121
+ chatbot = chatbot + [[user_message, None]]
122
+ return '', chatbot, chat_state
123
+
124
+
125
+ def gradio_answer(chatbot, chat_state, img_list, num_beams, temperature):
126
+ llm_message = chat.answer(conv=chat_state,
127
+ img_list=img_list,
128
+ num_beams=num_beams,
129
+ temperature=temperature,
130
+ max_new_tokens=300,
131
+ max_length=2000)[0]
132
+ chatbot[-1][1] = llm_message
133
+ print(chat_state.get_prompt())
134
+ print(chat_state)
135
+ return chatbot, chat_state, img_list
136
+
137
+ title = """
138
+ <div align="center">
139
+ <h1>Hawk: Learning to Understand Open-World Video Anomalies</h1>
140
+ </div>
141
+
142
+ <h5 align="center"> "Have eyes like a Hawk!" </h5>
143
+
144
+ <div style="display: flex; justify-content: center; gap: 0.25rem;">
145
+ <a href='https://github.com/jqtangust/hawk'>
146
+ <img src='https://img.shields.io/badge/Github-Code-success' alt="GitHub Code">
147
+ </a>
148
+ <a href='https://huggingface.co/spaces/Jiaqi-hkust/hawk'>
149
+ <img src='https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue' alt="Hugging Face Spaces">
150
+ </a>
151
+ <a href='https://huggingface.co/spaces/Jiaqi-hkust/hawk'>
152
+ <img src='https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Model-blue' alt="Hugging Face Model">
153
+ </a>
154
+ <a href='https://arxiv.org/pdf/2405.16886'>
155
+ <img src='https://img.shields.io/badge/Paper-PDF-red' alt="Download Paper">
156
+ </a>
157
+ </div>
158
+
159
+ """
160
+
161
+ cite_markdown = ("""
162
+ ## Citation
163
+ The following is a BibTeX reference:
164
+ ```
165
+ @inproceedings{atang2024hawk,
166
+ title = {Hawk: Learning to Understand Open-World Video Anomalies},
167
+ author = {Tang, Jiaqi and Lu, Hao and Wu, Ruizheng and Xu, Xiaogang and Ma, Ke and Fang, Cheng and Guo, Bin and Lu, Jiangbo and Chen, Qifeng and Chen, Ying-Cong},
168
+ year = {2024},
169
+ booktitle = {Neural Information Processing Systems (NeurIPS)}
170
+ }
171
+ """)
172
+
173
+ # case_note_upload = ("""
174
+ # ### We provide some examples at the bottom of the page. Simply click on them to try them out directly.
175
+ # """)
176
+
177
+ #TODO show examples below
178
+
179
+ with gr.Blocks() as demo:
180
+ gr.Markdown(title)
181
+
182
+ with gr.Row():
183
+ with gr.Column(scale=0.5):
184
+ video = gr.Video()
185
+ # image = gr.Image(type="filepath")
186
+ # gr.Markdown(case_note_upload)
187
+
188
+ upload_button = gr.Button(value="Upload & Start Chat", interactive=True, variant="primary")
189
+ clear = gr.Button("Restart")
190
+
191
+ num_beams = gr.Slider(
192
+ minimum=1,
193
+ maximum=10,
194
+ value=1,
195
+ step=1,
196
+ interactive=True,
197
+ label="beam search numbers)",
198
+ )
199
+
200
+ temperature = gr.Slider(
201
+ minimum=0.1,
202
+ maximum=2.0,
203
+ value=1.0,
204
+ step=0.1,
205
+ interactive=True,
206
+ label="Temperature",
207
+ )
208
+ # audio = gr.Checkbox(interactive=True, value=False, label="Audio")
209
+ with gr.Column():
210
+ chat_state = gr.State()
211
+ img_list = gr.State()
212
+ chatbot = gr.Chatbot(label='Hawk')
213
+ text_input = gr.Textbox(label='User', placeholder='Upload your video first and start to chat.', interactive=False)
214
+
215
+
216
+ with gr.Column():
217
+ gr.Examples(examples=[
218
+ [f"figs/examples/explosion2.mp4", "What happened in this video? "],
219
+ [f"figs/examples/car.mp4", "What is the anomaly for the car in this video? "],
220
+ ], inputs=[video, text_input])
221
+
222
+ gr.Markdown(cite_markdown)
223
+ upload_button.click(upload_imgorvideo, [video, text_input, chat_state, chatbot], [video, text_input, upload_button, chat_state, img_list, chatbot])
224
+
225
+ start_time = time.time()
226
+ text_input.submit(gradio_ask, [text_input, chatbot, chat_state], [text_input, chatbot, chat_state]).then(
227
+ gradio_answer, [chatbot, chat_state, img_list, num_beams, temperature], [chatbot, chat_state, img_list]
228
+ )
229
+ end_time = time.time()
230
+ print('Time:', end_time - start_time)
231
+
232
+ clear.click(gradio_reset, [chat_state, img_list], [chatbot, video, text_input, upload_button, chat_state, img_list], queue=False)
233
+
234
+ demo.launch(share=False)
configs/eval_configs/eval.yaml ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model:
2
+ arch: hawk
3
+ model_type: pretrain_llama_v2
4
+ freeze_vit: True
5
+ freeze_qformer: True
6
+ max_txt_len: 512
7
+ end_sym: "</s>"
8
+ low_resource: False
9
+
10
+ frozen_llama_proj: False
11
+
12
+ # Use LLaMA-2-chat as base modal
13
+
14
+ # some ckpts could be download from Video_LLaMA-2-7B-Finetuned
15
+ # https://huggingface.co/DAMO-NLP-SG/Video-LLaMA-2-7B-Finetuned
16
+ llama_model: "/remote-home/share/jiaqitang/Video-LLaMA-2-7B-Finetuned/llama-2-7b-chat-hf"
17
+
18
+ # Hawk Weight
19
+ ckpt: '/remote-home/share/jiaqitang/Hawk_Ours/hawk/output/hawk_finetune/20250221045/checkpoint_5.pth'
20
+
21
+ equip_audio_branch: False
22
+
23
+ fusion_head_layers: 2
24
+ max_frame_pos: 32
25
+ fusion_header_type: "seqTransf"
26
+
27
+ datasets:
28
+ webvid:
29
+ vis_processor:
30
+ train:
31
+ name: "alpro_video_eval"
32
+ n_frms: 32
33
+ image_size: 224
34
+ text_processor:
35
+ train:
36
+ name: "blip_caption"
37
+
38
+ run:
39
+ task: video_text_pretrain
configs/prompts/alignment_image.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ <Image><ImageHere></Image> Describe this video in detail.
2
+ <Image><ImageHere></Image> Take a look at this video and describe what you notice.
3
+ <Image><ImageHere></Image> Please provide a detailed description of the video.
4
+ <Image><ImageHere></Image> Could you describe the contents of this video for me?
configs/train_configs/stage1_pretrain.yaml ADDED
@@ -0,0 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model:
2
+ arch: hawk
3
+ model_type: pretrain_llama_v2
4
+ freeze_vit: True
5
+ freeze_qformer: True
6
+
7
+
8
+ # Q-Former
9
+ num_query_token: 32
10
+
11
+ # If you want train models based on LLaMA-2-chat,
12
+ # some ckpts could be download from our provided huggingface repo
13
+ # i.e. https://huggingface.co/DAMO-NLP-SG/Video-LLaMA-2-13B-Finetuned
14
+
15
+ llama_model: "/remote-home/share/jiaqitang/Video-LLaMA-2-7B-Finetuned/llama-2-7b-chat-hf"
16
+ # imagebind_ckpt_path: "/remote-home/share/jiaqitang/ImageBind/weight"
17
+
18
+ # llama_proj_model: ''
19
+
20
+
21
+ # only train vision branch
22
+ equip_audio_branch: False
23
+ frozen_llama_proj: False
24
+ frozen_video_Qformer: False
25
+ frozen_audio_Qformer: True
26
+
27
+ fusion_head_layers: 2
28
+ max_frame_pos: 32
29
+ fusion_header_type: "seqTransf"
30
+ num_video_query_token: 32
31
+
32
+ datasets:
33
+ webvid:
34
+ data_type: video
35
+ build_info:
36
+ anno_dir: /remote-home/share/jiaqitang/WebVid-2M/train_data/filter_annotations/
37
+ videos_dir: /remote-home/share/jiaqitang/WebVid-2M/train_data/videos/
38
+
39
+ vis_processor:
40
+ train:
41
+ name: "alpro_video_train"
42
+ n_frms: 32
43
+ image_size: 224
44
+ text_processor:
45
+ train:
46
+ name: "blip_caption"
47
+ sample_ratio: 100
48
+
49
+ run:
50
+ task: video_text_pretrain
51
+ # optimizer
52
+ lr_sched: "linear_warmup_cosine_lr"
53
+ init_lr: 1e-5
54
+ min_lr: 1e-6
55
+ warmup_lr: 1e-6
56
+
57
+ weight_decay: 0.05
58
+ max_epoch: 160
59
+ batch_size_train: 1
60
+ batch_size_eval: 1
61
+ num_workers: 16
62
+ warmup_steps: 1000
63
+ iters_per_epoch: 2500
64
+
65
+ seed: 42
66
+ output_dir: "output/hawk_pretrain"
67
+
68
+ amp: True
69
+ resume_ckpt_path: null
70
+
71
+ evaluate: False
72
+ train_splits: ["train"]
73
+
74
+ device: "cuda"
75
+ world_size: 1
76
+ dist_url: "env://"
77
+ distributed: True
configs/train_configs/stage2_finetune.yaml ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model:
2
+ arch: hawk
3
+ model_type: pretrain_llama_v2
4
+ freeze_vit: True
5
+ freeze_qformer: True
6
+
7
+
8
+ # Q-Former
9
+ num_query_token: 32
10
+
11
+ # If you want train models based on LLaMA-2-chat,
12
+ # some ckpts could be download from our provided huggingface repo
13
+ # i.e. https://huggingface.co/DAMO-NLP-SG/Video-LLaMA-2-13B-Finetuned
14
+ llama_model: "/remote-home/share/jiaqitang/Video-LLaMA-2-7B-Finetuned/llama-2-7b-chat-hf"
15
+ # imagebind_ckpt_path: "/remote-home/share/jiaqitang/Video-LLaMA-2-7B-Finetuned"
16
+
17
+ # The ckpt of vision branch after stage1 pretrained,
18
+ ckpt: "/remote-home/share/jiaqitang/Hawk_Ours/hawk/output/hawk_pretrain/20250217073/checkpoint_127.pth"
19
+
20
+
21
+ # only train vision branch
22
+ equip_audio_branch: False
23
+ frozen_llama_proj: False
24
+ frozen_video_Qformer: False
25
+ frozen_audio_Qformer: True
26
+
27
+ fusion_head_layers: 2
28
+ max_frame_pos: 32
29
+ fusion_header_type: "seqTransf"
30
+
31
+ max_txt_len: 320
32
+
33
+ for llama_2_chat:
34
+ end_sym: "</s>"
35
+ prompt_path: "/remote-home/share/jiaqitang/Hawk_Ours/configs/prompts/alignment_image.txt"
36
+ prompt_template: '[INST] <<SYS>>\n \n<</SYS>>\n\n{} [/INST] '
37
+
38
+ datasets:
39
+ webvid_instruct:
40
+ data_type: video
41
+ build_info:
42
+ anno_dir: /remote-home/share/jiaqitang/Data_Annotation/A_Overall/all_videos_train.json
43
+ videos_dir: /remote-home/share/jiaqitang/Data/
44
+ vis_processor:
45
+ train:
46
+ name: "alpro_video_train"
47
+ n_frms: 32
48
+ image_size: 224
49
+ text_processor:
50
+ train:
51
+ name: "blip_caption"
52
+ num_video_query_token: 32
53
+ tokenizer_name: "/remote-home/share/jiaqitang/Video-LLaMA-2-7B-Finetuned/llama-2-7b-chat-hf"
54
+ model_type: "llama_v2"
55
+
56
+ run:
57
+ task: video_text_pretrain
58
+ # optimizer
59
+ lr_sched: "linear_warmup_cosine_lr"
60
+ init_lr: 1e-5
61
+ min_lr: 1e-6
62
+ warmup_lr: 1e-6
63
+
64
+ weight_decay: 0.05
65
+ max_epoch: 160
66
+ batch_size_train: 1
67
+ batch_size_eval: 1
68
+ num_workers: 16
69
+ warmup_steps: 1000
70
+ iters_per_epoch: 2500
71
+
72
+ seed: 42
73
+ output_dir: "output/hawk_finetune"
74
+
75
+ amp: True
76
+ resume_ckpt_path: null
77
+
78
+ evaluate: False
79
+ train_splits: ["train"]
80
+
81
+ device: "cuda"
82
+ world_size: 1
83
+ dist_url: "env://"
84
+ distributed: True
environment.yml ADDED
@@ -0,0 +1,225 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: hawk
2
+ channels:
3
+ - defaults
4
+ dependencies:
5
+ - _libgcc_mutex=0.1=main
6
+ - _openmp_mutex=5.1=1_gnu
7
+ - bzip2=1.0.8=h7b6447c_0
8
+ - ca-certificates=2023.08.22=h06a4308_0
9
+ - ld_impl_linux-64=2.38=h1181459_1
10
+ - libffi=3.4.4=h6a678d5_0
11
+ - libgcc-ng=11.2.0=h1234567_1
12
+ - libgfortran-ng=7.5.0=ha8ba4b0_17
13
+ - libgfortran4=7.5.0=ha8ba4b0_17
14
+ - libgomp=11.2.0=h1234567_1
15
+ - libstdcxx-ng=11.2.0=h1234567_1
16
+ - libuuid=1.41.5=h5eee18b_0
17
+ - mpi=1.0=mpich
18
+ - mpi4py=3.1.4=py310hfc96bbd_0
19
+ - mpich=3.3.2=hc856adb_0
20
+ - ncurses=6.4=h6a678d5_0
21
+ - openssl=3.0.11=h7f8727e_2
22
+ - pip=23.2.1=py310h06a4308_0
23
+ - python=3.10.13=h955ad1f_0
24
+ - readline=8.2=h5eee18b_0
25
+ - setuptools=68.0.0=py310h06a4308_0
26
+ - sqlite=3.41.2=h5eee18b_0
27
+ - tk=8.6.12=h1ccaba5_0
28
+ - wheel=0.41.2=py310h06a4308_0
29
+ - xz=5.4.2=h5eee18b_0
30
+ - zlib=1.2.13=h5eee18b_0
31
+ - pip:
32
+ - absl-py==2.1.0
33
+ - accelerate==0.23.0
34
+ - aiofiles==23.2.1
35
+ - aiohttp==3.8.6
36
+ - aiosignal==1.3.1
37
+ - altair==5.1.2
38
+ - annotated-types==0.6.0
39
+ - antlr4-python3-runtime==4.9.3
40
+ - anyio==3.7.1
41
+ - appdirs==1.4.4
42
+ - asttokens==2.4.0
43
+ - async-timeout==4.0.3
44
+ - attrs==23.1.0
45
+ - av==10.0.0
46
+ - backcall==0.2.0
47
+ - bitsandbytes==0.41.1
48
+ - black==23.9.1
49
+ - blessed==1.20.0
50
+ - blis==1.2.0
51
+ - braceexpand==0.1.7
52
+ - brotli==1.1.0
53
+ - cachetools==5.3.1
54
+ - catalogue==2.0.10
55
+ - certifi==2023.7.22
56
+ - charset-normalizer==3.3.0
57
+ - click==8.1.7
58
+ - cloudpathlib==0.20.0
59
+ - cmake==3.27.6
60
+ - coloredlogs==15.0.1
61
+ - confection==0.1.5
62
+ - contourpy==1.1.1
63
+ - cycler==0.12.1
64
+ - cymem==2.0.11
65
+ - datasets==2.14.5
66
+ - debugpy==1.8.12
67
+ - decorator==5.1.1
68
+ - decord==0.6.0
69
+ - dill==0.3.7
70
+ - einops==0.7.0
71
+ - en-core-web-sm==3.8.0
72
+ - exceptiongroup==1.1.3
73
+ - executing==2.0.0
74
+ - fairscale==0.4.13
75
+ - fastapi==0.115.8
76
+ - ffmpy==0.3.1
77
+ - filelock==3.12.4
78
+ - fire==0.5.0
79
+ - fonttools==4.43.1
80
+ - frozenlist==1.4.0
81
+ - fsspec==2023.6.0
82
+ - ftfy==6.1.1
83
+ - fvcore==0.1.5.post20221221
84
+ - gpustat==1.1.1
85
+ - gradio==5.16.0
86
+ - gradio-client==1.7.0
87
+ - grpcio==1.70.0
88
+ - h11==0.14.0
89
+ - hiq-python==1.1.12
90
+ - httpcore==0.18.0
91
+ - httpx==0.25.0
92
+ - huggingface-hub==0.28.1
93
+ - humanfriendly==10.0
94
+ - idna==3.4
95
+ - importlib-resources==6.1.0
96
+ - inflate64==0.3.1
97
+ - iopath==0.1.10
98
+ - ipython==8.16.1
99
+ - jedi==0.19.1
100
+ - jinja2==3.1.2
101
+ - jsonschema==4.19.1
102
+ - jsonschema-specifications==2023.7.1
103
+ - kiwisolver==1.4.5
104
+ - langcodes==3.5.0
105
+ - language-data==1.3.0
106
+ - linkify-it-py==2.0.2
107
+ - lit==17.0.2
108
+ - loralib==0.1.2
109
+ - marisa-trie==1.2.1
110
+ - markdown==3.7
111
+ - markdown-it-py==2.2.0
112
+ - markupsafe==2.1.3
113
+ - matplotlib==3.8.0
114
+ - matplotlib-inline==0.1.6
115
+ - mdit-py-plugins==0.3.3
116
+ - mdurl==0.1.2
117
+ - mpmath==1.3.0
118
+ - multidict==6.0.4
119
+ - multiprocess==0.70.15
120
+ - multivolumefile==0.2.3
121
+ - murmurhash==1.0.12
122
+ - mypy-extensions==1.0.0
123
+ - networkx==3.1
124
+ - numpy==1.26.0
125
+ - nvidia-ml-py==12.535.108
126
+ - nvitop==1.3.1
127
+ - omegaconf==2.3.0
128
+ - opencv-python==4.8.1.78
129
+ - optimum==1.13.2
130
+ - orjson==3.9.9
131
+ - packaging==23.2
132
+ - pandas==2.1.1
133
+ - parameterized==0.9.0
134
+ - parso==0.8.3
135
+ - pathspec==0.11.2
136
+ - peft==0.5.0
137
+ - pexpect==4.8.0
138
+ - pickleshare==0.7.5
139
+ - pillow==10.0.1
140
+ - platformdirs==3.11.0
141
+ - portalocker==2.8.2
142
+ - preshed==3.0.9
143
+ - prompt-toolkit==3.0.39
144
+ - protobuf==4.24.4
145
+ - psutil==5.9.5
146
+ - ptyprocess==0.7.0
147
+ - pure-eval==0.2.2
148
+ - py-itree==0.0.19
149
+ - py3nvml==0.2.7
150
+ - py7zr==0.20.6
151
+ - pyarrow==13.0.0
152
+ - pybcj==1.0.1
153
+ - pycryptodomex==3.19.0
154
+ - pydantic==2.4.2
155
+ - pydantic-core==2.10.1
156
+ - pydub==0.25.1
157
+ - pygments==2.16.1
158
+ - pyllama==0.0.9
159
+ - pyparsing==3.1.1
160
+ - pyppmd==1.0.0
161
+ - python-dateutil==2.8.2
162
+ - python-multipart==0.0.20
163
+ - pytorchvideo==0.1.5
164
+ - pytz==2023.3.post1
165
+ - pyyaml==6.0.1
166
+ - pyzstd==0.15.9
167
+ - referencing==0.30.2
168
+ - regex==2023.10.3
169
+ - requests==2.31.0
170
+ - rich==13.6.0
171
+ - rpds-py==0.10.6
172
+ - ruff==0.9.6
173
+ - safehttpx==0.1.6
174
+ - safetensors==0.4.0
175
+ - scipy==1.11.3
176
+ - semantic-version==2.10.0
177
+ - sentencepiece==0.1.97
178
+ - shellingham==1.5.4
179
+ - six==1.16.0
180
+ - smart-open==7.1.0
181
+ - sniffio==1.3.0
182
+ - spacy==3.8.4
183
+ - spacy-legacy==3.0.12
184
+ - spacy-loggers==1.0.5
185
+ - srsly==2.5.1
186
+ - stack-data==0.6.3
187
+ - starlette==0.45.3
188
+ - sympy==1.12
189
+ - tabulate==0.9.0
190
+ - tensorboard==2.18.0
191
+ - tensorboard-data-server==0.7.2
192
+ - termcolor==2.3.0
193
+ - texttable==1.7.0
194
+ - thinc==8.3.4
195
+ - timm==0.9.7
196
+ - tokenize-rt==5.2.0
197
+ - tokenizers==0.13.3
198
+ - tomli==2.0.1
199
+ - tomlkit==0.13.2
200
+ - toolz==0.12.0
201
+ - torch==2.0.1+cu117
202
+ - torchaudio==2.0.2+cu117
203
+ - torchvision==0.15.2+cu117
204
+ - tqdm==4.66.1
205
+ - traitlets==5.11.2
206
+ - transformers==4.28.0
207
+ - triton==2.0.0
208
+ - typer==0.15.1
209
+ - typing-extensions==4.8.0
210
+ - tzdata==2023.3
211
+ - uc-micro-py==1.0.2
212
+ - urllib3==2.0.6
213
+ - uvicorn==0.23.2
214
+ - wasabi==1.1.3
215
+ - wcwidth==0.2.8
216
+ - weasel==0.4.1
217
+ - webdataset==0.2.57
218
+ - websockets==11.0.3
219
+ - werkzeug==3.1.3
220
+ - wrapt==1.17.2
221
+ - xmltodict==0.13.0
222
+ - xxhash==3.4.1
223
+ - yacs==0.1.8
224
+ - yarl==1.9.2
225
+ prefix: /root/anaconda3/envs/hawk
figs/demo.png ADDED
figs/examples/car.mp4 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1aca21a877d9569fc1f0c102644012a8232d06caef7a883ab3cc0750640e209d
3
+ size 1443171
figs/examples/explosion2.mp4 ADDED
Binary file (834 kB). View file
 
figs/icon.png ADDED
figs/motivation1.png ADDED
hawk/__init__.py ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Copyright (c) 2022, salesforce.com, inc.
3
+ All rights reserved.
4
+ SPDX-License-Identifier: BSD-3-Clause
5
+ For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ """
7
+
8
+ import os
9
+ import sys
10
+
11
+ from omegaconf import OmegaConf
12
+
13
+ from hawk.common.registry import registry
14
+
15
+ from hawk.datasets.builders import *
16
+ from hawk.models import *
17
+ from hawk.processors import *
18
+ from hawk.tasks import *
19
+
20
+
21
+ root_dir = os.path.dirname(os.path.abspath(__file__))
22
+ default_cfg = OmegaConf.load(os.path.join(root_dir, "configs/default.yaml"))
23
+
24
+ registry.register_path("library_root", root_dir)
25
+ repo_root = os.path.join(root_dir, "..")
26
+ registry.register_path("repo_root", repo_root)
27
+ cache_root = os.path.join(repo_root, default_cfg.env.cache_root)
28
+ registry.register_path("cache_root", cache_root)
29
+
30
+ registry.register("MAX_INT", sys.maxsize)
31
+ registry.register("SPLIT_NAMES", ["train", "val", "test"])
hawk/common/__init__.py ADDED
File without changes
hawk/common/config.py ADDED
@@ -0,0 +1,468 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Copyright (c) 2022, salesforce.com, inc.
3
+ All rights reserved.
4
+ SPDX-License-Identifier: BSD-3-Clause
5
+ For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ """
7
+
8
+ import logging
9
+ import json
10
+ from typing import Dict
11
+
12
+ from omegaconf import OmegaConf
13
+ from hawk.common.registry import registry
14
+
15
+
16
+ class Config:
17
+ def __init__(self, args):
18
+ self.config = {}
19
+
20
+ self.args = args
21
+
22
+ # Register the config and configuration for setup
23
+ registry.register("configuration", self)
24
+
25
+ user_config = self._build_opt_list(self.args.options)
26
+
27
+ config = OmegaConf.load(self.args.cfg_path)
28
+
29
+ runner_config = self.build_runner_config(config)
30
+ model_config = self.build_model_config(config, **user_config)
31
+ dataset_config = self.build_dataset_config(config)
32
+
33
+ # Validate the user-provided runner configuration
34
+ # model and dataset configuration are supposed to be validated by the respective classes
35
+ # [TODO] validate the model/dataset configuration
36
+ # self._validate_runner_config(runner_config)
37
+
38
+ # Override the default configuration with user options.
39
+ self.config = OmegaConf.merge(
40
+ runner_config, model_config, dataset_config, user_config
41
+ )
42
+
43
+ def _validate_runner_config(self, runner_config):
44
+ """
45
+ This method validates the configuration, such that
46
+ 1) all the user specified options are valid;
47
+ 2) no type mismatches between the user specified options and the config.
48
+ """
49
+ runner_config_validator = create_runner_config_validator()
50
+ runner_config_validator.validate(runner_config)
51
+
52
+ def _build_opt_list(self, opts):
53
+ opts_dot_list = self._convert_to_dot_list(opts)
54
+ return OmegaConf.from_dotlist(opts_dot_list)
55
+
56
+ @staticmethod
57
+ def build_model_config(config, **kwargs):
58
+ model = config.get("model", None)
59
+ assert model is not None, "Missing model configuration file."
60
+
61
+ model_cls = registry.get_model_class(model.arch)
62
+ assert model_cls is not None, f"Model '{model.arch}' has not been registered."
63
+
64
+ model_type = kwargs.get("model.model_type", None)
65
+ if not model_type:
66
+ model_type = model.get("model_type", None)
67
+ # else use the model type selected by user.
68
+
69
+ assert model_type is not None, "Missing model_type."
70
+
71
+ model_config_path = model_cls.default_config_path(model_type=model_type)
72
+
73
+ model_config = OmegaConf.create()
74
+ # hierarchy override, customized config > default config
75
+ model_config = OmegaConf.merge(
76
+ model_config,
77
+ OmegaConf.load(model_config_path),
78
+ {"model": config["model"]},
79
+ )
80
+
81
+ return model_config
82
+
83
+ @staticmethod
84
+ def build_runner_config(config):
85
+ return {"run": config.run}
86
+
87
+ @staticmethod
88
+ def build_dataset_config(config):
89
+ datasets = config.get("datasets", None)
90
+ if datasets is None:
91
+ raise KeyError(
92
+ "Expecting 'datasets' as the root key for dataset configuration."
93
+ )
94
+
95
+ dataset_config = OmegaConf.create()
96
+
97
+ for dataset_name in datasets:
98
+ builder_cls = registry.get_builder_class(dataset_name)
99
+
100
+ dataset_config_type = datasets[dataset_name].get("type", "default")
101
+ dataset_config_path = builder_cls.default_config_path(
102
+ type=dataset_config_type
103
+ )
104
+
105
+ # hierarchy override, customized config > default config
106
+ dataset_config = OmegaConf.merge(
107
+ dataset_config,
108
+ OmegaConf.load(dataset_config_path),
109
+ {"datasets": {dataset_name: config["datasets"][dataset_name]}},
110
+ )
111
+
112
+ return dataset_config
113
+
114
+ def _convert_to_dot_list(self, opts):
115
+ if opts is None:
116
+ opts = []
117
+
118
+ if len(opts) == 0:
119
+ return opts
120
+
121
+ has_equal = opts[0].find("=") != -1
122
+
123
+ if has_equal:
124
+ return opts
125
+
126
+ return [(opt + "=" + value) for opt, value in zip(opts[0::2], opts[1::2])]
127
+
128
+ def get_config(self):
129
+ return self.config
130
+
131
+ @property
132
+ def run_cfg(self):
133
+ return self.config.run
134
+
135
+ @property
136
+ def datasets_cfg(self):
137
+ return self.config.datasets
138
+
139
+ @property
140
+ def model_cfg(self):
141
+ return self.config.model
142
+
143
+ def pretty_print(self):
144
+ logging.info("\n===== Running Parameters =====")
145
+ logging.info(self._convert_node_to_json(self.config.run))
146
+
147
+ logging.info("\n====== Dataset Attributes ======")
148
+ datasets = self.config.datasets
149
+
150
+ for dataset in datasets:
151
+ if dataset in self.config.datasets:
152
+ logging.info(f"\n======== {dataset} =======")
153
+ dataset_config = self.config.datasets[dataset]
154
+ logging.info(self._convert_node_to_json(dataset_config))
155
+ else:
156
+ logging.warning(f"No dataset named '{dataset}' in config. Skipping")
157
+
158
+ logging.info(f"\n====== Model Attributes ======")
159
+ logging.info(self._convert_node_to_json(self.config.model))
160
+
161
+ def _convert_node_to_json(self, node):
162
+ container = OmegaConf.to_container(node, resolve=True)
163
+ return json.dumps(container, indent=4, sort_keys=True)
164
+
165
+ def to_dict(self):
166
+ return OmegaConf.to_container(self.config)
167
+
168
+
169
+ def node_to_dict(node):
170
+ return OmegaConf.to_container(node)
171
+
172
+
173
+ class ConfigValidator:
174
+ """
175
+ This is a preliminary implementation to centralize and validate the configuration.
176
+ May be altered in the future.
177
+
178
+ A helper class to validate configurations from yaml file.
179
+
180
+ This serves the following purposes:
181
+ 1. Ensure all the options in the yaml are defined, raise error if not.
182
+ 2. when type mismatches are found, the validator will raise an error.
183
+ 3. a central place to store and display helpful messages for supported configurations.
184
+
185
+ """
186
+
187
+ class _Argument:
188
+ def __init__(self, name, choices=None, type=None, help=None):
189
+ self.name = name
190
+ self.val = None
191
+ self.choices = choices
192
+ self.type = type
193
+ self.help = help
194
+
195
+ def __str__(self):
196
+ s = f"{self.name}={self.val}"
197
+ if self.type is not None:
198
+ s += f", ({self.type})"
199
+ if self.choices is not None:
200
+ s += f", choices: {self.choices}"
201
+ if self.help is not None:
202
+ s += f", ({self.help})"
203
+ return s
204
+
205
+ def __init__(self, description):
206
+ self.description = description
207
+
208
+ self.arguments = dict()
209
+
210
+ self.parsed_args = None
211
+
212
+ def __getitem__(self, key):
213
+ assert self.parsed_args is not None, "No arguments parsed yet."
214
+
215
+ return self.parsed_args[key]
216
+
217
+ def __str__(self) -> str:
218
+ return self.format_help()
219
+
220
+ def add_argument(self, *args, **kwargs):
221
+ """
222
+ Assume the first argument is the name of the argument.
223
+ """
224
+ self.arguments[args[0]] = self._Argument(*args, **kwargs)
225
+
226
+ def validate(self, config=None):
227
+ """
228
+ Convert yaml config (dict-like) to list, required by argparse.
229
+ """
230
+ for k, v in config.items():
231
+ assert (
232
+ k in self.arguments
233
+ ), f"""{k} is not a valid argument. Support arguments are {self.format_arguments()}."""
234
+
235
+ if self.arguments[k].type is not None:
236
+ try:
237
+ self.arguments[k].val = self.arguments[k].type(v)
238
+ except ValueError:
239
+ raise ValueError(f"{k} is not a valid {self.arguments[k].type}.")
240
+
241
+ if self.arguments[k].choices is not None:
242
+ assert (
243
+ v in self.arguments[k].choices
244
+ ), f"""{k} must be one of {self.arguments[k].choices}."""
245
+
246
+ return config
247
+
248
+ def format_arguments(self):
249
+ return str([f"{k}" for k in sorted(self.arguments.keys())])
250
+
251
+ def format_help(self):
252
+ # description + key-value pair string for each argument
253
+ help_msg = str(self.description)
254
+ return help_msg + ", available arguments: " + self.format_arguments()
255
+
256
+ def print_help(self):
257
+ # display help message
258
+ print(self.format_help())
259
+
260
+
261
+ def create_runner_config_validator():
262
+ validator = ConfigValidator(description="Runner configurations")
263
+
264
+ validator.add_argument(
265
+ "runner",
266
+ type=str,
267
+ choices=["runner_base", "runner_iter"],
268
+ help="""Runner to use. The "runner_base" uses epoch-based training while iter-based
269
+ runner runs based on iters. Default: runner_base""",
270
+ )
271
+ # add argumetns for training dataset ratios
272
+ validator.add_argument(
273
+ "train_dataset_ratios",
274
+ type=Dict[str, float],
275
+ help="""Ratios of training dataset. This is used in iteration-based runner.
276
+ Do not support for epoch-based runner because how to define an epoch becomes tricky.
277
+ Default: None""",
278
+ )
279
+ validator.add_argument(
280
+ "max_iters",
281
+ type=float,
282
+ help="Maximum number of iterations to run.",
283
+ )
284
+ validator.add_argument(
285
+ "max_epoch",
286
+ type=int,
287
+ help="Maximum number of epochs to run.",
288
+ )
289
+ # add arguments for iters_per_inner_epoch
290
+ validator.add_argument(
291
+ "iters_per_inner_epoch",
292
+ type=float,
293
+ help="Number of iterations per inner epoch. This is required when runner is runner_iter.",
294
+ )
295
+ lr_scheds_choices = registry.list_lr_schedulers()
296
+ validator.add_argument(
297
+ "lr_sched",
298
+ type=str,
299
+ choices=lr_scheds_choices,
300
+ help="Learning rate scheduler to use, from {}".format(lr_scheds_choices),
301
+ )
302
+ task_choices = registry.list_tasks()
303
+ validator.add_argument(
304
+ "task",
305
+ type=str,
306
+ choices=task_choices,
307
+ help="Task to use, from {}".format(task_choices),
308
+ )
309
+ # add arguments for init_lr
310
+ validator.add_argument(
311
+ "init_lr",
312
+ type=float,
313
+ help="Initial learning rate. This will be the learning rate after warmup and before decay.",
314
+ )
315
+ # add arguments for min_lr
316
+ validator.add_argument(
317
+ "min_lr",
318
+ type=float,
319
+ help="Minimum learning rate (after decay).",
320
+ )
321
+ # add arguments for warmup_lr
322
+ validator.add_argument(
323
+ "warmup_lr",
324
+ type=float,
325
+ help="Starting learning rate for warmup.",
326
+ )
327
+ # add arguments for learning rate decay rate
328
+ validator.add_argument(
329
+ "lr_decay_rate",
330
+ type=float,
331
+ help="Learning rate decay rate. Required if using a decaying learning rate scheduler.",
332
+ )
333
+ # add arguments for weight decay
334
+ validator.add_argument(
335
+ "weight_decay",
336
+ type=float,
337
+ help="Weight decay rate.",
338
+ )
339
+ # add arguments for training batch size
340
+ validator.add_argument(
341
+ "batch_size_train",
342
+ type=int,
343
+ help="Training batch size.",
344
+ )
345
+ # add arguments for evaluation batch size
346
+ validator.add_argument(
347
+ "batch_size_eval",
348
+ type=int,
349
+ help="Evaluation batch size, including validation and testing.",
350
+ )
351
+ # add arguments for number of workers for data loading
352
+ validator.add_argument(
353
+ "num_workers",
354
+ help="Number of workers for data loading.",
355
+ )
356
+ # add arguments for warm up steps
357
+ validator.add_argument(
358
+ "warmup_steps",
359
+ type=int,
360
+ help="Number of warmup steps. Required if a warmup schedule is used.",
361
+ )
362
+ # add arguments for random seed
363
+ validator.add_argument(
364
+ "seed",
365
+ type=int,
366
+ help="Random seed.",
367
+ )
368
+ # add arguments for output directory
369
+ validator.add_argument(
370
+ "output_dir",
371
+ type=str,
372
+ help="Output directory to save checkpoints and logs.",
373
+ )
374
+ # add arguments for whether only use evaluation
375
+ validator.add_argument(
376
+ "evaluate",
377
+ help="Whether to only evaluate the model. If true, training will not be performed.",
378
+ )
379
+ # add arguments for splits used for training, e.g. ["train", "val"]
380
+ validator.add_argument(
381
+ "train_splits",
382
+ type=list,
383
+ help="Splits to use for training.",
384
+ )
385
+ # add arguments for splits used for validation, e.g. ["val"]
386
+ validator.add_argument(
387
+ "valid_splits",
388
+ type=list,
389
+ help="Splits to use for validation. If not provided, will skip the validation.",
390
+ )
391
+ # add arguments for splits used for testing, e.g. ["test"]
392
+ validator.add_argument(
393
+ "test_splits",
394
+ type=list,
395
+ help="Splits to use for testing. If not provided, will skip the testing.",
396
+ )
397
+ # add arguments for accumulating gradient for iterations
398
+ validator.add_argument(
399
+ "accum_grad_iters",
400
+ type=int,
401
+ help="Number of iterations to accumulate gradient for.",
402
+ )
403
+
404
+ # ====== distributed training ======
405
+ validator.add_argument(
406
+ "device",
407
+ type=str,
408
+ choices=["cpu", "cuda"],
409
+ help="Device to use. Support 'cuda' or 'cpu' as for now.",
410
+ )
411
+ validator.add_argument(
412
+ "world_size",
413
+ type=int,
414
+ help="Number of processes participating in the job.",
415
+ )
416
+ validator.add_argument("dist_url", type=str)
417
+ validator.add_argument("distributed", type=bool)
418
+ # add arguments to opt using distributed sampler during evaluation or not
419
+ validator.add_argument(
420
+ "use_dist_eval_sampler",
421
+ type=bool,
422
+ help="Whether to use distributed sampler during evaluation or not.",
423
+ )
424
+
425
+ # ====== task specific ======
426
+ # generation task specific arguments
427
+ # add arguments for maximal length of text output
428
+ validator.add_argument(
429
+ "max_len",
430
+ type=int,
431
+ help="Maximal length of text output.",
432
+ )
433
+ # add arguments for minimal length of text output
434
+ validator.add_argument(
435
+ "min_len",
436
+ type=int,
437
+ help="Minimal length of text output.",
438
+ )
439
+ # add arguments number of beams
440
+ validator.add_argument(
441
+ "num_beams",
442
+ type=int,
443
+ help="Number of beams used for beam search.",
444
+ )
445
+
446
+ # vqa task specific arguments
447
+ # add arguments for number of answer candidates
448
+ validator.add_argument(
449
+ "num_ans_candidates",
450
+ type=int,
451
+ help="""For ALBEF and BLIP, these models first rank answers according to likelihood to select answer candidates.""",
452
+ )
453
+ # add arguments for inference method
454
+ validator.add_argument(
455
+ "inference_method",
456
+ type=str,
457
+ choices=["genearte", "rank"],
458
+ help="""Inference method to use for question answering. If rank, requires a answer list.""",
459
+ )
460
+
461
+ # ====== model specific ======
462
+ validator.add_argument(
463
+ "k_test",
464
+ type=int,
465
+ help="Number of top k most similar samples from ITC/VTC selection to be tested.",
466
+ )
467
+
468
+ return validator
hawk/common/dist_utils.py ADDED
@@ -0,0 +1,137 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Copyright (c) 2022, salesforce.com, inc.
3
+ All rights reserved.
4
+ SPDX-License-Identifier: BSD-3-Clause
5
+ For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ """
7
+
8
+ import datetime
9
+ import functools
10
+ import os
11
+
12
+ import torch
13
+ import torch.distributed as dist
14
+ import timm.models.hub as timm_hub
15
+
16
+
17
+ def setup_for_distributed(is_master):
18
+ """
19
+ This function disables printing when not in master process
20
+ """
21
+ import builtins as __builtin__
22
+
23
+ builtin_print = __builtin__.print
24
+
25
+ def print(*args, **kwargs):
26
+ force = kwargs.pop("force", False)
27
+ if is_master or force:
28
+ builtin_print(*args, **kwargs)
29
+
30
+ __builtin__.print = print
31
+
32
+
33
+ def is_dist_avail_and_initialized():
34
+ if not dist.is_available():
35
+ return False
36
+ if not dist.is_initialized():
37
+ return False
38
+ return True
39
+
40
+
41
+ def get_world_size():
42
+ if not is_dist_avail_and_initialized():
43
+ return 1
44
+ return dist.get_world_size()
45
+
46
+
47
+ def get_rank():
48
+ if not is_dist_avail_and_initialized():
49
+ return 0
50
+ return dist.get_rank()
51
+
52
+
53
+ def is_main_process():
54
+ return get_rank() == 0
55
+
56
+
57
+ def init_distributed_mode(args):
58
+ if "RANK" in os.environ and "WORLD_SIZE" in os.environ:
59
+ args.rank = int(os.environ["RANK"])
60
+ args.world_size = int(os.environ["WORLD_SIZE"])
61
+ args.gpu = int(os.environ["LOCAL_RANK"])
62
+ elif "SLURM_PROCID" in os.environ:
63
+ args.rank = int(os.environ["SLURM_PROCID"])
64
+ args.gpu = args.rank % torch.cuda.device_count()
65
+ else:
66
+ print("Not using distributed mode")
67
+ args.distributed = False
68
+ return
69
+
70
+ args.distributed = True
71
+
72
+ torch.cuda.set_device(args.gpu)
73
+ args.dist_backend = "nccl"
74
+ print(
75
+ "| distributed init (rank {}, world {}): {}".format(
76
+ args.rank, args.world_size, args.dist_url
77
+ ),
78
+ flush=True,
79
+ )
80
+ torch.distributed.init_process_group(
81
+ backend=args.dist_backend,
82
+ init_method=args.dist_url,
83
+ world_size=args.world_size,
84
+ rank=args.rank,
85
+ timeout=datetime.timedelta(
86
+ days=365
87
+ ), # allow auto-downloading and de-compressing
88
+ )
89
+ torch.distributed.barrier()
90
+ setup_for_distributed(args.rank == 0)
91
+
92
+
93
+ def get_dist_info():
94
+ if torch.__version__ < "1.0":
95
+ initialized = dist._initialized
96
+ else:
97
+ initialized = dist.is_initialized()
98
+ if initialized:
99
+ rank = dist.get_rank()
100
+ world_size = dist.get_world_size()
101
+ else: # non-distributed training
102
+ rank = 0
103
+ world_size = 1
104
+ return rank, world_size
105
+
106
+
107
+ def main_process(func):
108
+ @functools.wraps(func)
109
+ def wrapper(*args, **kwargs):
110
+ rank, _ = get_dist_info()
111
+ if rank == 0:
112
+ return func(*args, **kwargs)
113
+
114
+ return wrapper
115
+
116
+
117
+ def download_cached_file(url, check_hash=True, progress=False):
118
+ """
119
+ Download a file from a URL and cache it locally. If the file already exists, it is not downloaded again.
120
+ If distributed, only the main process downloads the file, and the other processes wait for the file to be downloaded.
121
+ """
122
+
123
+ def get_cached_file_path():
124
+ # a hack to sync the file path across processes
125
+ parts = torch.hub.urlparse(url)
126
+ filename = os.path.basename(parts.path)
127
+ cached_file = os.path.join(timm_hub.get_cache_dir(), filename)
128
+
129
+ return cached_file
130
+
131
+ if is_main_process():
132
+ timm_hub.download_cached_file(url, check_hash, progress)
133
+
134
+ if is_dist_avail_and_initialized():
135
+ dist.barrier()
136
+
137
+ return get_cached_file_path()
hawk/common/gradcam.py ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ from matplotlib import pyplot as plt
3
+ from scipy.ndimage import filters
4
+ from skimage import transform as skimage_transform
5
+
6
+
7
+ def getAttMap(img, attMap, blur=True, overlap=True):
8
+ attMap -= attMap.min()
9
+ if attMap.max() > 0:
10
+ attMap /= attMap.max()
11
+ attMap = skimage_transform.resize(attMap, (img.shape[:2]), order=3, mode="constant")
12
+ if blur:
13
+ attMap = filters.gaussian_filter(attMap, 0.02 * max(img.shape[:2]))
14
+ attMap -= attMap.min()
15
+ attMap /= attMap.max()
16
+ cmap = plt.get_cmap("jet")
17
+ attMapV = cmap(attMap)
18
+ attMapV = np.delete(attMapV, 3, 2)
19
+ if overlap:
20
+ attMap = (
21
+ 1 * (1 - attMap**0.7).reshape(attMap.shape + (1,)) * img
22
+ + (attMap**0.7).reshape(attMap.shape + (1,)) * attMapV
23
+ )
24
+ return attMap
hawk/common/logger.py ADDED
@@ -0,0 +1,195 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Copyright (c) 2022, salesforce.com, inc.
3
+ All rights reserved.
4
+ SPDX-License-Identifier: BSD-3-Clause
5
+ For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ """
7
+
8
+ import datetime
9
+ import logging
10
+ import time
11
+ from collections import defaultdict, deque
12
+
13
+ import torch
14
+ import torch.distributed as dist
15
+
16
+ from hawk.common import dist_utils
17
+
18
+
19
+ class SmoothedValue(object):
20
+ """Track a series of values and provide access to smoothed values over a
21
+ window or the global series average.
22
+ """
23
+
24
+ def __init__(self, window_size=20, fmt=None):
25
+ if fmt is None:
26
+ fmt = "{median:.4f} ({global_avg:.4f})"
27
+ self.deque = deque(maxlen=window_size)
28
+ self.total = 0.0
29
+ self.count = 0
30
+ self.fmt = fmt
31
+
32
+ def update(self, value, n=1):
33
+ self.deque.append(value)
34
+ self.count += n
35
+ self.total += value * n
36
+
37
+ def synchronize_between_processes(self):
38
+ """
39
+ Warning: does not synchronize the deque!
40
+ """
41
+ if not dist_utils.is_dist_avail_and_initialized():
42
+ return
43
+ t = torch.tensor([self.count, self.total], dtype=torch.float64, device="cuda")
44
+ dist.barrier()
45
+ dist.all_reduce(t)
46
+ t = t.tolist()
47
+ self.count = int(t[0])
48
+ self.total = t[1]
49
+
50
+ @property
51
+ def median(self):
52
+ d = torch.tensor(list(self.deque))
53
+ return d.median().item()
54
+
55
+ @property
56
+ def avg(self):
57
+ d = torch.tensor(list(self.deque), dtype=torch.float32)
58
+ return d.mean().item()
59
+
60
+ @property
61
+ def global_avg(self):
62
+ return self.total / self.count
63
+
64
+ @property
65
+ def max(self):
66
+ return max(self.deque)
67
+
68
+ @property
69
+ def value(self):
70
+ return self.deque[-1]
71
+
72
+ def __str__(self):
73
+ return self.fmt.format(
74
+ median=self.median,
75
+ avg=self.avg,
76
+ global_avg=self.global_avg,
77
+ max=self.max,
78
+ value=self.value,
79
+ )
80
+
81
+
82
+ class MetricLogger(object):
83
+ def __init__(self, delimiter="\t"):
84
+ self.meters = defaultdict(SmoothedValue)
85
+ self.delimiter = delimiter
86
+
87
+ def update(self, **kwargs):
88
+ for k, v in kwargs.items():
89
+ if isinstance(v, torch.Tensor):
90
+ v = v.item()
91
+ assert isinstance(v, (float, int))
92
+ self.meters[k].update(v)
93
+
94
+ def __getattr__(self, attr):
95
+ if attr in self.meters:
96
+ return self.meters[attr]
97
+ if attr in self.__dict__:
98
+ return self.__dict__[attr]
99
+ raise AttributeError(
100
+ "'{}' object has no attribute '{}'".format(type(self).__name__, attr)
101
+ )
102
+
103
+ def __str__(self):
104
+ loss_str = []
105
+ for name, meter in self.meters.items():
106
+ loss_str.append("{}: {}".format(name, str(meter)))
107
+ return self.delimiter.join(loss_str)
108
+
109
+ def global_avg(self):
110
+ loss_str = []
111
+ for name, meter in self.meters.items():
112
+ loss_str.append("{}: {:.4f}".format(name, meter.global_avg))
113
+ return self.delimiter.join(loss_str)
114
+
115
+ def synchronize_between_processes(self):
116
+ for meter in self.meters.values():
117
+ meter.synchronize_between_processes()
118
+
119
+ def add_meter(self, name, meter):
120
+ self.meters[name] = meter
121
+
122
+ def log_every(self, iterable, print_freq, header=None):
123
+ i = 0
124
+ if not header:
125
+ header = ""
126
+ start_time = time.time()
127
+ end = time.time()
128
+ iter_time = SmoothedValue(fmt="{avg:.4f}")
129
+ data_time = SmoothedValue(fmt="{avg:.4f}")
130
+ space_fmt = ":" + str(len(str(len(iterable)))) + "d"
131
+ log_msg = [
132
+ header,
133
+ "[{0" + space_fmt + "}/{1}]",
134
+ "eta: {eta}",
135
+ "{meters}",
136
+ "time: {time}",
137
+ "data: {data}",
138
+ ]
139
+ if torch.cuda.is_available():
140
+ log_msg.append("max mem: {memory:.0f}")
141
+ log_msg = self.delimiter.join(log_msg)
142
+ MB = 1024.0 * 1024.0
143
+ for obj in iterable:
144
+ data_time.update(time.time() - end)
145
+ yield obj
146
+ iter_time.update(time.time() - end)
147
+ if i % print_freq == 0 or i == len(iterable) - 1:
148
+ eta_seconds = iter_time.global_avg * (len(iterable) - i)
149
+ eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
150
+ if torch.cuda.is_available():
151
+ print(
152
+ log_msg.format(
153
+ i,
154
+ len(iterable),
155
+ eta=eta_string,
156
+ meters=str(self),
157
+ time=str(iter_time),
158
+ data=str(data_time),
159
+ memory=torch.cuda.max_memory_allocated() / MB,
160
+ )
161
+ )
162
+ else:
163
+ print(
164
+ log_msg.format(
165
+ i,
166
+ len(iterable),
167
+ eta=eta_string,
168
+ meters=str(self),
169
+ time=str(iter_time),
170
+ data=str(data_time),
171
+ )
172
+ )
173
+ i += 1
174
+ end = time.time()
175
+ total_time = time.time() - start_time
176
+ total_time_str = str(datetime.timedelta(seconds=int(total_time)))
177
+ print(
178
+ "{} Total time: {} ({:.4f} s / it)".format(
179
+ header, total_time_str, total_time / len(iterable)
180
+ )
181
+ )
182
+
183
+
184
+ class AttrDict(dict):
185
+ def __init__(self, *args, **kwargs):
186
+ super(AttrDict, self).__init__(*args, **kwargs)
187
+ self.__dict__ = self
188
+
189
+
190
+ def setup_logger():
191
+ logging.basicConfig(
192
+ level=logging.INFO if dist_utils.is_main_process() else logging.WARN,
193
+ format="%(asctime)s [%(levelname)s] %(message)s",
194
+ handlers=[logging.StreamHandler()],
195
+ )
hawk/common/optims.py ADDED
@@ -0,0 +1,119 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Copyright (c) 2022, salesforce.com, inc.
3
+ All rights reserved.
4
+ SPDX-License-Identifier: BSD-3-Clause
5
+ For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ """
7
+
8
+ import math
9
+
10
+ from hawk.common.registry import registry
11
+
12
+
13
+ @registry.register_lr_scheduler("linear_warmup_step_lr")
14
+ class LinearWarmupStepLRScheduler:
15
+ def __init__(
16
+ self,
17
+ optimizer,
18
+ max_epoch,
19
+ min_lr,
20
+ init_lr,
21
+ decay_rate=1,
22
+ warmup_start_lr=-1,
23
+ warmup_steps=0,
24
+ **kwargs
25
+ ):
26
+ self.optimizer = optimizer
27
+
28
+ self.max_epoch = max_epoch
29
+ self.min_lr = min_lr
30
+
31
+ self.decay_rate = decay_rate
32
+
33
+ self.init_lr = init_lr
34
+ self.warmup_steps = warmup_steps
35
+ self.warmup_start_lr = warmup_start_lr if warmup_start_lr >= 0 else init_lr
36
+
37
+ def step(self, cur_epoch, cur_step):
38
+ if cur_epoch == 0:
39
+ warmup_lr_schedule(
40
+ step=cur_step,
41
+ optimizer=self.optimizer,
42
+ max_step=self.warmup_steps,
43
+ init_lr=self.warmup_start_lr,
44
+ max_lr=self.init_lr,
45
+ )
46
+ else:
47
+ step_lr_schedule(
48
+ epoch=cur_epoch,
49
+ optimizer=self.optimizer,
50
+ init_lr=self.init_lr,
51
+ min_lr=self.min_lr,
52
+ decay_rate=self.decay_rate,
53
+ )
54
+
55
+
56
+ @registry.register_lr_scheduler("linear_warmup_cosine_lr")
57
+ class LinearWarmupCosineLRScheduler:
58
+ def __init__(
59
+ self,
60
+ optimizer,
61
+ max_epoch,
62
+ iters_per_epoch,
63
+ min_lr,
64
+ init_lr,
65
+ warmup_steps=0,
66
+ warmup_start_lr=-1,
67
+ **kwargs
68
+ ):
69
+ self.optimizer = optimizer
70
+
71
+ self.max_epoch = max_epoch
72
+ self.iters_per_epoch = iters_per_epoch
73
+ self.min_lr = min_lr
74
+
75
+ self.init_lr = init_lr
76
+ self.warmup_steps = warmup_steps
77
+ self.warmup_start_lr = warmup_start_lr if warmup_start_lr >= 0 else init_lr
78
+
79
+ def step(self, cur_epoch, cur_step):
80
+ total_cur_step = cur_epoch * self.iters_per_epoch + cur_step
81
+ if total_cur_step < self.warmup_steps:
82
+ warmup_lr_schedule(
83
+ step=cur_step,
84
+ optimizer=self.optimizer,
85
+ max_step=self.warmup_steps,
86
+ init_lr=self.warmup_start_lr,
87
+ max_lr=self.init_lr,
88
+ )
89
+ else:
90
+ cosine_lr_schedule(
91
+ epoch=total_cur_step,
92
+ optimizer=self.optimizer,
93
+ max_epoch=self.max_epoch * self.iters_per_epoch,
94
+ init_lr=self.init_lr,
95
+ min_lr=self.min_lr,
96
+ )
97
+
98
+
99
+ def cosine_lr_schedule(optimizer, epoch, max_epoch, init_lr, min_lr):
100
+ """Decay the learning rate"""
101
+ lr = (init_lr - min_lr) * 0.5 * (
102
+ 1.0 + math.cos(math.pi * epoch / max_epoch)
103
+ ) + min_lr
104
+ for param_group in optimizer.param_groups:
105
+ param_group["lr"] = lr
106
+
107
+
108
+ def warmup_lr_schedule(optimizer, step, max_step, init_lr, max_lr):
109
+ """Warmup the learning rate"""
110
+ lr = min(max_lr, init_lr + (max_lr - init_lr) * step / max(max_step, 1))
111
+ for param_group in optimizer.param_groups:
112
+ param_group["lr"] = lr
113
+
114
+
115
+ def step_lr_schedule(optimizer, epoch, init_lr, min_lr, decay_rate):
116
+ """Decay the learning rate"""
117
+ lr = max(min_lr, init_lr * (decay_rate**epoch))
118
+ for param_group in optimizer.param_groups:
119
+ param_group["lr"] = lr
hawk/common/registry.py ADDED
@@ -0,0 +1,329 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Copyright (c) 2022, salesforce.com, inc.
3
+ All rights reserved.
4
+ SPDX-License-Identifier: BSD-3-Clause
5
+ For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ """
7
+
8
+
9
+ class Registry:
10
+ mapping = {
11
+ "builder_name_mapping": {},
12
+ "task_name_mapping": {},
13
+ "processor_name_mapping": {},
14
+ "model_name_mapping": {},
15
+ "lr_scheduler_name_mapping": {},
16
+ "runner_name_mapping": {},
17
+ "state": {},
18
+ "paths": {},
19
+ }
20
+
21
+ @classmethod
22
+ def register_builder(cls, name):
23
+ r"""Register a dataset builder to registry with key 'name'
24
+
25
+ Args:
26
+ name: Key with which the builder will be registered.
27
+
28
+ Usage:
29
+
30
+ from video_llama.common.registry import registry
31
+ from video_llama.datasets.base_dataset_builder import BaseDatasetBuilder
32
+ """
33
+
34
+ def wrap(builder_cls):
35
+ from hawk.datasets.builders.base_dataset_builder import BaseDatasetBuilder
36
+
37
+ assert issubclass(
38
+ builder_cls, BaseDatasetBuilder
39
+ ), "All builders must inherit BaseDatasetBuilder class, found {}".format(
40
+ builder_cls
41
+ )
42
+ if name in cls.mapping["builder_name_mapping"]:
43
+ raise KeyError(
44
+ "Name '{}' already registered for {}.".format(
45
+ name, cls.mapping["builder_name_mapping"][name]
46
+ )
47
+ )
48
+ cls.mapping["builder_name_mapping"][name] = builder_cls
49
+ return builder_cls
50
+
51
+ return wrap
52
+
53
+ @classmethod
54
+ def register_task(cls, name):
55
+ r"""Register a task to registry with key 'name'
56
+
57
+ Args:
58
+ name: Key with which the task will be registered.
59
+
60
+ Usage:
61
+
62
+ from video_llama.common.registry import registry
63
+ """
64
+
65
+ def wrap(task_cls):
66
+ from hawk.tasks.base_task import BaseTask
67
+
68
+ assert issubclass(
69
+ task_cls, BaseTask
70
+ ), "All tasks must inherit BaseTask class"
71
+ if name in cls.mapping["task_name_mapping"]:
72
+ raise KeyError(
73
+ "Name '{}' already registered for {}.".format(
74
+ name, cls.mapping["task_name_mapping"][name]
75
+ )
76
+ )
77
+ cls.mapping["task_name_mapping"][name] = task_cls
78
+ return task_cls
79
+
80
+ return wrap
81
+
82
+ @classmethod
83
+ def register_model(cls, name):
84
+ r"""Register a task to registry with key 'name'
85
+
86
+ Args:
87
+ name: Key with which the task will be registered.
88
+
89
+ Usage:
90
+
91
+ from video_llama.common.registry import registry
92
+ """
93
+
94
+ def wrap(model_cls):
95
+ from hawk.models import BaseModel
96
+
97
+ assert issubclass(
98
+ model_cls, BaseModel
99
+ ), "All models must inherit BaseModel class"
100
+ if name in cls.mapping["model_name_mapping"]:
101
+ raise KeyError(
102
+ "Name '{}' already registered for {}.".format(
103
+ name, cls.mapping["model_name_mapping"][name]
104
+ )
105
+ )
106
+ cls.mapping["model_name_mapping"][name] = model_cls
107
+ return model_cls
108
+
109
+ return wrap
110
+
111
+ @classmethod
112
+ def register_processor(cls, name):
113
+ r"""Register a processor to registry with key 'name'
114
+
115
+ Args:
116
+ name: Key with which the task will be registered.
117
+
118
+ Usage:
119
+
120
+ from video_llama.common.registry import registry
121
+ """
122
+
123
+ def wrap(processor_cls):
124
+ from hawk.processors import BaseProcessor
125
+
126
+ assert issubclass(
127
+ processor_cls, BaseProcessor
128
+ ), "All processors must inherit BaseProcessor class"
129
+ if name in cls.mapping["processor_name_mapping"]:
130
+ raise KeyError(
131
+ "Name '{}' already registered for {}.".format(
132
+ name, cls.mapping["processor_name_mapping"][name]
133
+ )
134
+ )
135
+ cls.mapping["processor_name_mapping"][name] = processor_cls
136
+ return processor_cls
137
+
138
+ return wrap
139
+
140
+ @classmethod
141
+ def register_lr_scheduler(cls, name):
142
+ r"""Register a model to registry with key 'name'
143
+
144
+ Args:
145
+ name: Key with which the task will be registered.
146
+
147
+ Usage:
148
+
149
+ from video_llama.common.registry import registry
150
+ """
151
+
152
+ def wrap(lr_sched_cls):
153
+ if name in cls.mapping["lr_scheduler_name_mapping"]:
154
+ raise KeyError(
155
+ "Name '{}' already registered for {}.".format(
156
+ name, cls.mapping["lr_scheduler_name_mapping"][name]
157
+ )
158
+ )
159
+ cls.mapping["lr_scheduler_name_mapping"][name] = lr_sched_cls
160
+ return lr_sched_cls
161
+
162
+ return wrap
163
+
164
+ @classmethod
165
+ def register_runner(cls, name):
166
+ r"""Register a model to registry with key 'name'
167
+
168
+ Args:
169
+ name: Key with which the task will be registered.
170
+
171
+ Usage:
172
+
173
+ from video_llama.common.registry import registry
174
+ """
175
+
176
+ def wrap(runner_cls):
177
+ if name in cls.mapping["runner_name_mapping"]:
178
+ raise KeyError(
179
+ "Name '{}' already registered for {}.".format(
180
+ name, cls.mapping["runner_name_mapping"][name]
181
+ )
182
+ )
183
+ cls.mapping["runner_name_mapping"][name] = runner_cls
184
+ return runner_cls
185
+
186
+ return wrap
187
+
188
+ @classmethod
189
+ def register_path(cls, name, path):
190
+ r"""Register a path to registry with key 'name'
191
+
192
+ Args:
193
+ name: Key with which the path will be registered.
194
+
195
+ Usage:
196
+
197
+ from video_llama.common.registry import registry
198
+ """
199
+ assert isinstance(path, str), "All path must be str."
200
+ if name in cls.mapping["paths"]:
201
+ raise KeyError("Name '{}' already registered.".format(name))
202
+ cls.mapping["paths"][name] = path
203
+
204
+ @classmethod
205
+ def register(cls, name, obj):
206
+ r"""Register an item to registry with key 'name'
207
+
208
+ Args:
209
+ name: Key with which the item will be registered.
210
+
211
+ Usage::
212
+
213
+ from video_llama.common.registry import registry
214
+
215
+ registry.register("config", {})
216
+ """
217
+ path = name.split(".")
218
+ current = cls.mapping["state"]
219
+
220
+ for part in path[:-1]:
221
+ if part not in current:
222
+ current[part] = {}
223
+ current = current[part]
224
+
225
+ current[path[-1]] = obj
226
+
227
+ # @classmethod
228
+ # def get_trainer_class(cls, name):
229
+ # return cls.mapping["trainer_name_mapping"].get(name, None)
230
+
231
+ @classmethod
232
+ def get_builder_class(cls, name):
233
+ return cls.mapping["builder_name_mapping"].get(name, None)
234
+
235
+ @classmethod
236
+ def get_model_class(cls, name):
237
+ return cls.mapping["model_name_mapping"].get(name, None)
238
+
239
+ @classmethod
240
+ def get_task_class(cls, name):
241
+ return cls.mapping["task_name_mapping"].get(name, None)
242
+
243
+ @classmethod
244
+ def get_processor_class(cls, name):
245
+ return cls.mapping["processor_name_mapping"].get(name, None)
246
+
247
+ @classmethod
248
+ def get_lr_scheduler_class(cls, name):
249
+ return cls.mapping["lr_scheduler_name_mapping"].get(name, None)
250
+
251
+ @classmethod
252
+ def get_runner_class(cls, name):
253
+ return cls.mapping["runner_name_mapping"].get(name, None)
254
+
255
+ @classmethod
256
+ def list_runners(cls):
257
+ return sorted(cls.mapping["runner_name_mapping"].keys())
258
+
259
+ @classmethod
260
+ def list_models(cls):
261
+ return sorted(cls.mapping["model_name_mapping"].keys())
262
+
263
+ @classmethod
264
+ def list_tasks(cls):
265
+ return sorted(cls.mapping["task_name_mapping"].keys())
266
+
267
+ @classmethod
268
+ def list_processors(cls):
269
+ return sorted(cls.mapping["processor_name_mapping"].keys())
270
+
271
+ @classmethod
272
+ def list_lr_schedulers(cls):
273
+ return sorted(cls.mapping["lr_scheduler_name_mapping"].keys())
274
+
275
+ @classmethod
276
+ def list_datasets(cls):
277
+ return sorted(cls.mapping["builder_name_mapping"].keys())
278
+
279
+ @classmethod
280
+ def get_path(cls, name):
281
+ return cls.mapping["paths"].get(name, None)
282
+
283
+ @classmethod
284
+ def get(cls, name, default=None, no_warning=False):
285
+ r"""Get an item from registry with key 'name'
286
+
287
+ Args:
288
+ name (string): Key whose value needs to be retrieved.
289
+ default: If passed and key is not in registry, default value will
290
+ be returned with a warning. Default: None
291
+ no_warning (bool): If passed as True, warning when key doesn't exist
292
+ will not be generated. Useful for MMF's
293
+ internal operations. Default: False
294
+ """
295
+ original_name = name
296
+ name = name.split(".")
297
+ value = cls.mapping["state"]
298
+ for subname in name:
299
+ value = value.get(subname, default)
300
+ if value is default:
301
+ break
302
+
303
+ if (
304
+ "writer" in cls.mapping["state"]
305
+ and value == default
306
+ and no_warning is False
307
+ ):
308
+ cls.mapping["state"]["writer"].warning(
309
+ "Key {} is not present in registry, returning default value "
310
+ "of {}".format(original_name, default)
311
+ )
312
+ return value
313
+
314
+ @classmethod
315
+ def unregister(cls, name):
316
+ r"""Remove an item from registry with key 'name'
317
+
318
+ Args:
319
+ name: Key which needs to be removed.
320
+ Usage::
321
+
322
+ from mmf.common.registry import registry
323
+
324
+ config = registry.unregister("config")
325
+ """
326
+ return cls.mapping["state"].pop(name, None)
327
+
328
+
329
+ registry = Registry()
hawk/common/utils.py ADDED
@@ -0,0 +1,424 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Copyright (c) 2022, salesforce.com, inc.
3
+ All rights reserved.
4
+ SPDX-License-Identifier: BSD-3-Clause
5
+ For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ """
7
+
8
+ import io
9
+ import json
10
+ import logging
11
+ import os
12
+ import pickle
13
+ import re
14
+ import shutil
15
+ import urllib
16
+ import urllib.error
17
+ import urllib.request
18
+ from typing import Optional
19
+ from urllib.parse import urlparse
20
+
21
+ import numpy as np
22
+ import pandas as pd
23
+ import yaml
24
+ from iopath.common.download import download
25
+ from iopath.common.file_io import file_lock, g_pathmgr
26
+ from hawk.common.registry import registry
27
+ from torch.utils.model_zoo import tqdm
28
+ from torchvision.datasets.utils import (
29
+ check_integrity,
30
+ download_file_from_google_drive,
31
+ extract_archive,
32
+ )
33
+
34
+
35
+ def now():
36
+ from datetime import datetime
37
+
38
+ return datetime.now().strftime("%Y%m%d%H%M")[:-1]
39
+
40
+
41
+ def is_url(url_or_filename):
42
+ parsed = urlparse(url_or_filename)
43
+ return parsed.scheme in ("http", "https")
44
+
45
+
46
+ def get_cache_path(rel_path):
47
+ return os.path.expanduser(os.path.join(registry.get_path("cache_root"), rel_path))
48
+
49
+
50
+ def get_abs_path(rel_path):
51
+ return os.path.join(registry.get_path("library_root"), rel_path)
52
+
53
+
54
+ def load_json(filename):
55
+ with open(filename, "r") as f:
56
+ return json.load(f)
57
+
58
+
59
+ # The following are adapted from torchvision and vissl
60
+ # torchvision: https://github.com/pytorch/vision
61
+ # vissl: https://github.com/facebookresearch/vissl/blob/main/vissl/utils/download.py
62
+
63
+
64
+ def makedir(dir_path):
65
+ """
66
+ Create the directory if it does not exist.
67
+ """
68
+ is_success = False
69
+ try:
70
+ if not g_pathmgr.exists(dir_path):
71
+ g_pathmgr.mkdirs(dir_path)
72
+ is_success = True
73
+ except BaseException:
74
+ print(f"Error creating directory: {dir_path}")
75
+ return is_success
76
+
77
+
78
+ def get_redirected_url(url: str):
79
+ """
80
+ Given a URL, returns the URL it redirects to or the
81
+ original URL in case of no indirection
82
+ """
83
+ import requests
84
+
85
+ with requests.Session() as session:
86
+ with session.get(url, stream=True, allow_redirects=True) as response:
87
+ if response.history:
88
+ return response.url
89
+ else:
90
+ return url
91
+
92
+
93
+ def to_google_drive_download_url(view_url: str) -> str:
94
+ """
95
+ Utility function to transform a view URL of google drive
96
+ to a download URL for google drive
97
+ Example input:
98
+ https://drive.google.com/file/d/137RyRjvTBkBiIfeYBNZBtViDHQ6_Ewsp/view
99
+ Example output:
100
+ https://drive.google.com/uc?export=download&id=137RyRjvTBkBiIfeYBNZBtViDHQ6_Ewsp
101
+ """
102
+ splits = view_url.split("/")
103
+ assert splits[-1] == "view"
104
+ file_id = splits[-2]
105
+ return f"https://drive.google.com/uc?export=download&id={file_id}"
106
+
107
+
108
+ def download_google_drive_url(url: str, output_path: str, output_file_name: str):
109
+ """
110
+ Download a file from google drive
111
+ Downloading an URL from google drive requires confirmation when
112
+ the file of the size is too big (google drive notifies that
113
+ anti-viral checks cannot be performed on such files)
114
+ """
115
+ import requests
116
+
117
+ with requests.Session() as session:
118
+
119
+ # First get the confirmation token and append it to the URL
120
+ with session.get(url, stream=True, allow_redirects=True) as response:
121
+ for k, v in response.cookies.items():
122
+ if k.startswith("download_warning"):
123
+ url = url + "&confirm=" + v
124
+
125
+ # Then download the content of the file
126
+ with session.get(url, stream=True, verify=True) as response:
127
+ makedir(output_path)
128
+ path = os.path.join(output_path, output_file_name)
129
+ total_size = int(response.headers.get("Content-length", 0))
130
+ with open(path, "wb") as file:
131
+ from tqdm import tqdm
132
+
133
+ with tqdm(total=total_size) as progress_bar:
134
+ for block in response.iter_content(
135
+ chunk_size=io.DEFAULT_BUFFER_SIZE
136
+ ):
137
+ file.write(block)
138
+ progress_bar.update(len(block))
139
+
140
+
141
+ def _get_google_drive_file_id(url: str) -> Optional[str]:
142
+ parts = urlparse(url)
143
+
144
+ if re.match(r"(drive|docs)[.]google[.]com", parts.netloc) is None:
145
+ return None
146
+
147
+ match = re.match(r"/file/d/(?P<id>[^/]*)", parts.path)
148
+ if match is None:
149
+ return None
150
+
151
+ return match.group("id")
152
+
153
+
154
+ def _urlretrieve(url: str, filename: str, chunk_size: int = 1024) -> None:
155
+ with open(filename, "wb") as fh:
156
+ with urllib.request.urlopen(
157
+ urllib.request.Request(url, headers={"User-Agent": "vissl"})
158
+ ) as response:
159
+ with tqdm(total=response.length) as pbar:
160
+ for chunk in iter(lambda: response.read(chunk_size), ""):
161
+ if not chunk:
162
+ break
163
+ pbar.update(chunk_size)
164
+ fh.write(chunk)
165
+
166
+
167
+ def download_url(
168
+ url: str,
169
+ root: str,
170
+ filename: Optional[str] = None,
171
+ md5: Optional[str] = None,
172
+ ) -> None:
173
+ """Download a file from a url and place it in root.
174
+ Args:
175
+ url (str): URL to download file from
176
+ root (str): Directory to place downloaded file in
177
+ filename (str, optional): Name to save the file under.
178
+ If None, use the basename of the URL.
179
+ md5 (str, optional): MD5 checksum of the download. If None, do not check
180
+ """
181
+ root = os.path.expanduser(root)
182
+ if not filename:
183
+ filename = os.path.basename(url)
184
+ fpath = os.path.join(root, filename)
185
+
186
+ makedir(root)
187
+
188
+ # check if file is already present locally
189
+ if check_integrity(fpath, md5):
190
+ print("Using downloaded and verified file: " + fpath)
191
+ return
192
+
193
+ # expand redirect chain if needed
194
+ url = get_redirected_url(url)
195
+
196
+ # check if file is located on Google Drive
197
+ file_id = _get_google_drive_file_id(url)
198
+ if file_id is not None:
199
+ return download_file_from_google_drive(file_id, root, filename, md5)
200
+
201
+ # download the file
202
+ try:
203
+ print("Downloading " + url + " to " + fpath)
204
+ _urlretrieve(url, fpath)
205
+ except (urllib.error.URLError, IOError) as e: # type: ignore[attr-defined]
206
+ if url[:5] == "https":
207
+ url = url.replace("https:", "http:")
208
+ print(
209
+ "Failed download. Trying https -> http instead."
210
+ " Downloading " + url + " to " + fpath
211
+ )
212
+ _urlretrieve(url, fpath)
213
+ else:
214
+ raise e
215
+
216
+ # check integrity of downloaded file
217
+ if not check_integrity(fpath, md5):
218
+ raise RuntimeError("File not found or corrupted.")
219
+
220
+
221
+ def download_and_extract_archive(
222
+ url: str,
223
+ download_root: str,
224
+ extract_root: Optional[str] = None,
225
+ filename: Optional[str] = None,
226
+ md5: Optional[str] = None,
227
+ remove_finished: bool = False,
228
+ ) -> None:
229
+ download_root = os.path.expanduser(download_root)
230
+ if extract_root is None:
231
+ extract_root = download_root
232
+ if not filename:
233
+ filename = os.path.basename(url)
234
+
235
+ download_url(url, download_root, filename, md5)
236
+
237
+ archive = os.path.join(download_root, filename)
238
+ print("Extracting {} to {}".format(archive, extract_root))
239
+ extract_archive(archive, extract_root, remove_finished)
240
+
241
+
242
+ def cache_url(url: str, cache_dir: str) -> str:
243
+ """
244
+ This implementation downloads the remote resource and caches it locally.
245
+ The resource will only be downloaded if not previously requested.
246
+ """
247
+ parsed_url = urlparse(url)
248
+ dirname = os.path.join(cache_dir, os.path.dirname(parsed_url.path.lstrip("/")))
249
+ makedir(dirname)
250
+ filename = url.split("/")[-1]
251
+ cached = os.path.join(dirname, filename)
252
+ with file_lock(cached):
253
+ if not os.path.isfile(cached):
254
+ logging.info(f"Downloading {url} to {cached} ...")
255
+ cached = download(url, dirname, filename=filename)
256
+ logging.info(f"URL {url} cached in {cached}")
257
+ return cached
258
+
259
+
260
+ # TODO (prigoyal): convert this into RAII-style API
261
+ def create_file_symlink(file1, file2):
262
+ """
263
+ Simply create the symlinks for a given file1 to file2.
264
+ Useful during model checkpointing to symlinks to the
265
+ latest successful checkpoint.
266
+ """
267
+ try:
268
+ if g_pathmgr.exists(file2):
269
+ g_pathmgr.rm(file2)
270
+ g_pathmgr.symlink(file1, file2)
271
+ except Exception as e:
272
+ logging.info(f"Could NOT create symlink. Error: {e}")
273
+
274
+
275
+ def save_file(data, filename, append_to_json=True, verbose=True):
276
+ """
277
+ Common i/o utility to handle saving data to various file formats.
278
+ Supported:
279
+ .pkl, .pickle, .npy, .json
280
+ Specifically for .json, users have the option to either append (default)
281
+ or rewrite by passing in Boolean value to append_to_json.
282
+ """
283
+ if verbose:
284
+ logging.info(f"Saving data to file: {filename}")
285
+ file_ext = os.path.splitext(filename)[1]
286
+ if file_ext in [".pkl", ".pickle"]:
287
+ with g_pathmgr.open(filename, "wb") as fopen:
288
+ pickle.dump(data, fopen, pickle.HIGHEST_PROTOCOL)
289
+ elif file_ext == ".npy":
290
+ with g_pathmgr.open(filename, "wb") as fopen:
291
+ np.save(fopen, data)
292
+ elif file_ext == ".json":
293
+ if append_to_json:
294
+ with g_pathmgr.open(filename, "a") as fopen:
295
+ fopen.write(json.dumps(data, sort_keys=True) + "\n")
296
+ fopen.flush()
297
+ else:
298
+ with g_pathmgr.open(filename, "w") as fopen:
299
+ fopen.write(json.dumps(data, sort_keys=True) + "\n")
300
+ fopen.flush()
301
+ elif file_ext == ".yaml":
302
+ with g_pathmgr.open(filename, "w") as fopen:
303
+ dump = yaml.dump(data)
304
+ fopen.write(dump)
305
+ fopen.flush()
306
+ else:
307
+ raise Exception(f"Saving {file_ext} is not supported yet")
308
+
309
+ if verbose:
310
+ logging.info(f"Saved data to file: {filename}")
311
+
312
+
313
+ def load_file(filename, mmap_mode=None, verbose=True, allow_pickle=False):
314
+ """
315
+ Common i/o utility to handle loading data from various file formats.
316
+ Supported:
317
+ .pkl, .pickle, .npy, .json
318
+ For the npy files, we support reading the files in mmap_mode.
319
+ If the mmap_mode of reading is not successful, we load data without the
320
+ mmap_mode.
321
+ """
322
+ if verbose:
323
+ logging.info(f"Loading data from file: {filename}")
324
+
325
+ file_ext = os.path.splitext(filename)[1]
326
+ if file_ext == ".txt":
327
+ with g_pathmgr.open(filename, "r") as fopen:
328
+ data = fopen.readlines()
329
+ elif file_ext in [".pkl", ".pickle"]:
330
+ with g_pathmgr.open(filename, "rb") as fopen:
331
+ data = pickle.load(fopen, encoding="latin1")
332
+ elif file_ext == ".npy":
333
+ if mmap_mode:
334
+ try:
335
+ with g_pathmgr.open(filename, "rb") as fopen:
336
+ data = np.load(
337
+ fopen,
338
+ allow_pickle=allow_pickle,
339
+ encoding="latin1",
340
+ mmap_mode=mmap_mode,
341
+ )
342
+ except ValueError as e:
343
+ logging.info(
344
+ f"Could not mmap {filename}: {e}. Trying without g_pathmgr"
345
+ )
346
+ data = np.load(
347
+ filename,
348
+ allow_pickle=allow_pickle,
349
+ encoding="latin1",
350
+ mmap_mode=mmap_mode,
351
+ )
352
+ logging.info("Successfully loaded without g_pathmgr")
353
+ except Exception:
354
+ logging.info("Could not mmap without g_pathmgr. Trying without mmap")
355
+ with g_pathmgr.open(filename, "rb") as fopen:
356
+ data = np.load(fopen, allow_pickle=allow_pickle, encoding="latin1")
357
+ else:
358
+ with g_pathmgr.open(filename, "rb") as fopen:
359
+ data = np.load(fopen, allow_pickle=allow_pickle, encoding="latin1")
360
+ elif file_ext == ".json":
361
+ with g_pathmgr.open(filename, "r") as fopen:
362
+ data = json.load(fopen)
363
+ elif file_ext == ".yaml":
364
+ with g_pathmgr.open(filename, "r") as fopen:
365
+ data = yaml.load(fopen, Loader=yaml.FullLoader)
366
+ elif file_ext == ".csv":
367
+ with g_pathmgr.open(filename, "r") as fopen:
368
+ data = pd.read_csv(fopen)
369
+ else:
370
+ raise Exception(f"Reading from {file_ext} is not supported yet")
371
+ return data
372
+
373
+
374
+ def abspath(resource_path: str):
375
+ """
376
+ Make a path absolute, but take into account prefixes like
377
+ "http://" or "manifold://"
378
+ """
379
+ regex = re.compile(r"^\w+://")
380
+ if regex.match(resource_path) is None:
381
+ return os.path.abspath(resource_path)
382
+ else:
383
+ return resource_path
384
+
385
+
386
+ def makedir(dir_path):
387
+ """
388
+ Create the directory if it does not exist.
389
+ """
390
+ is_success = False
391
+ try:
392
+ if not g_pathmgr.exists(dir_path):
393
+ g_pathmgr.mkdirs(dir_path)
394
+ is_success = True
395
+ except BaseException:
396
+ logging.info(f"Error creating directory: {dir_path}")
397
+ return is_success
398
+
399
+
400
+ def is_url(input_url):
401
+ """
402
+ Check if an input string is a url. look for http(s):// and ignoring the case
403
+ """
404
+ is_url = re.match(r"^(?:http)s?://", input_url, re.IGNORECASE) is not None
405
+ return is_url
406
+
407
+
408
+ def cleanup_dir(dir):
409
+ """
410
+ Utility for deleting a directory. Useful for cleaning the storage space
411
+ that contains various training artifacts like checkpoints, data etc.
412
+ """
413
+ if os.path.exists(dir):
414
+ logging.info(f"Deleting directory: {dir}")
415
+ shutil.rmtree(dir)
416
+ logging.info(f"Deleted contents of directory: {dir}")
417
+
418
+
419
+ def get_file_size(filename):
420
+ """
421
+ Given a file, get the size of file in MB
422
+ """
423
+ size_in_mb = os.path.getsize(filename) / float(1024**2)
424
+ return size_in_mb
hawk/configs/datasets/instruct/llava_instruct.yaml ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ datasets:
2
+ llava_instruct:
3
+ data_type: image
4
+ build_info:
5
+ anno_dir: /path/llava_instruct_150k.json
6
+ videos_dir: /path/train2014/train2014/
hawk/configs/datasets/instruct/webvid_instruct.yaml ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ datasets:
2
+ webvid_instruct:
3
+ data_type: image
4
+ build_info:
5
+ anno_dir: /path/webvid_align/videochat_instruct_11k.json
6
+ videos_dir: /path/webvid_align/videos/
hawk/configs/datasets/webvid/defaults.yaml ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ datasets:
2
+ webvid:
3
+ data_type: video
4
+ build_info:
5
+ anno_dir: path/webvid/webvid_tain_data/annotations/
6
+ videos_dir: path//webvid/webvid_tain_data/videos/
hawk/configs/default.yaml ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ env:
2
+ # For default users
3
+ # cache_root: "cache"
4
+ # For internal use with persistent storage
5
+ cache_root: "/export/home/.cache/minigpt4"
hawk/configs/models/minigpt4.yaml ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model:
2
+ arch: mini_gpt4
3
+
4
+ # vit encoder
5
+ image_size: 224
6
+ drop_path_rate: 0
7
+ use_grad_checkpoint: False
8
+ vit_precision: "fp16"
9
+ freeze_vit: True
10
+ freeze_qformer: True
11
+
12
+ # Q-Former
13
+ num_query_token: 32
14
+
15
+ # Vicuna
16
+ llama_model: "ckpt/vicuna-13b/"
17
+
18
+ # generation configs
19
+ prompt: ""
20
+
21
+ preprocess:
22
+ vis_processor:
23
+ train:
24
+ name: "blip2_image_train"
25
+ image_size: 224
26
+ eval:
27
+ name: "blip2_image_eval"
28
+ image_size: 224
29
+ text_processor:
30
+ train:
31
+ name: "blip_caption"
32
+ eval:
33
+ name: "blip_caption"
hawk/configs/models/video_llama.yaml ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model:
2
+ arch: video_llama
3
+
4
+ # vit encoder
5
+ image_size: 224
6
+ drop_path_rate: 0
7
+ use_grad_checkpoint: False
8
+ vit_precision: "fp16"
9
+ freeze_vit: True
10
+ freeze_qformer: True
11
+
12
+ # Q-Former
13
+ num_query_token: 32
14
+
15
+ # Vicuna
16
+ llama_model: "ckpt/vicuna-7b/"
17
+
18
+ # generation configs
19
+ prompt: ""
20
+
21
+ preprocess:
22
+ vis_processor:
23
+ train:
24
+ name: "alpro_video_train"
25
+ image_size: 224
26
+ n_frms: 8
27
+ eval:
28
+ name: "alpro_video_eval"
29
+ image_size: 224
30
+ n_frms: 8
31
+ text_processor:
32
+ train:
33
+ name: "blip_caption"
34
+ eval:
35
+ name: "blip_caption"
36
+
hawk/conversation/__init__.py ADDED
File without changes
hawk/conversation/conversation_video.py ADDED
@@ -0,0 +1,362 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Conversation prompt template of Video-LLaMA.
3
+ Adapted from: https://github.com/Vision-CAIR/MiniGPT-4/blob/main/minigpt4/conversation/conversation.py
4
+ """
5
+ import argparse
6
+ import time
7
+ from PIL import Image
8
+ import sys
9
+ import os
10
+ import torch
11
+ from transformers import AutoTokenizer, AutoModelForCausalLM, LlamaTokenizer
12
+ from transformers import StoppingCriteria, StoppingCriteriaList
13
+
14
+ import dataclasses
15
+ from enum import auto, Enum
16
+ from typing import List, Tuple, Any
17
+ import os
18
+ from hawk.common.registry import registry
19
+ from hawk.processors.video_processor import ToTHWC,ToUint8,load_video,load_video_motion
20
+ from hawk.processors import Blip2ImageEvalProcessor
21
+
22
+ from hawk.models.ImageBind.data import load_and_transform_audio_data
23
+ class SeparatorStyle(Enum):
24
+ """Different separator style."""
25
+ SINGLE = auto()
26
+ TWO = auto()
27
+ LLAMA_2 = auto()
28
+
29
+
30
+ @dataclasses.dataclass
31
+ class Conversation:
32
+ """A class that keeps all conversation history."""
33
+ system: str
34
+ roles: List[str]
35
+ messages: List[List[str]]
36
+ offset: int
37
+ # system_img: List[Image.Image] = []
38
+ sep_style: SeparatorStyle = SeparatorStyle.SINGLE
39
+ sep: str = "###"
40
+ sep2: str = None
41
+
42
+ skip_next: bool = False
43
+ conv_id: Any = None
44
+
45
+ def get_prompt(self):
46
+ if self.sep_style == SeparatorStyle.SINGLE:
47
+ ret = self.system + self.sep
48
+ for role, message in self.messages:
49
+ if message:
50
+ ret += role + ": " + message + self.sep
51
+ else:
52
+ ret += role + ":"
53
+ return ret
54
+ elif self.sep_style == SeparatorStyle.TWO:
55
+ seps = [self.sep, self.sep2]
56
+ ret = self.system + seps[0]
57
+ for i, (role, message) in enumerate(self.messages):
58
+ if message:
59
+ ret += role + ": " + message + seps[i % 2]
60
+ else:
61
+ ret += role + ":"
62
+ return ret
63
+ elif self.sep_style == SeparatorStyle.LLAMA_2:
64
+ wrap_sys = lambda msg: f"<<SYS>>\n{msg}\n<</SYS>>\n\n"
65
+ wrap_inst = lambda msg: f"[INST] {msg} [/INST]"
66
+ ret = ""
67
+
68
+ for i, (role, message) in enumerate(self.messages):
69
+ if i == 0:
70
+ assert message, "first message should not be none"
71
+ assert role == self.roles[0], "first message should come from user"
72
+ if message:
73
+ if type(message) is tuple:
74
+ message, _, _ = message
75
+ if i == 0: message = wrap_sys(self.system) + message
76
+ if i % 2 == 0:
77
+ message = wrap_inst(message)
78
+ ret += self.sep + message
79
+ else:
80
+ ret += " " + message + " " + self.sep2
81
+ else:
82
+ ret += ""
83
+ ret = ret.lstrip(self.sep)
84
+ return ret
85
+ else:
86
+ raise ValueError(f"Invalid style: {self.sep_style}")
87
+
88
+ def append_message(self, role, message):
89
+ self.messages.append([role, message])
90
+
91
+ def to_gradio_chatbot(self):
92
+ ret = []
93
+ for i, (role, msg) in enumerate(self.messages[self.offset:]):
94
+ if i % 2 == 0:
95
+ ret.append([msg, None])
96
+ else:
97
+ ret[-1][-1] = msg
98
+ return ret
99
+
100
+ def copy(self):
101
+ return Conversation(
102
+ system=self.system,
103
+ # system_img=self.system_img,
104
+ roles=self.roles,
105
+ messages=[[x, y] for x, y in self.messages],
106
+ offset=self.offset,
107
+ sep_style=self.sep_style,
108
+ sep=self.sep,
109
+ sep2=self.sep2,
110
+ conv_id=self.conv_id)
111
+
112
+ def dict(self):
113
+ return {
114
+ "system": self.system,
115
+ # "system_img": self.system_img,
116
+ "roles": self.roles,
117
+ "messages": self.messages,
118
+ "offset": self.offset,
119
+ "sep": self.sep,
120
+ "sep2": self.sep2,
121
+ "conv_id": self.conv_id,
122
+ }
123
+
124
+
125
+ class StoppingCriteriaSub(StoppingCriteria):
126
+
127
+ def __init__(self, stops=[], encounters=1):
128
+ super().__init__()
129
+ self.stops = stops
130
+
131
+ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor):
132
+ for stop in self.stops:
133
+ if torch.all((stop == input_ids[0][-len(stop):])).item():
134
+ return True
135
+
136
+ return False
137
+
138
+
139
+ CONV_VISION = Conversation(
140
+ system="Give the following image: <Img>ImageContent</Img>. "
141
+ "You will be able to see the image once I provide it to you. Please answer my questions.",
142
+ roles=("Human", "Assistant"),
143
+ messages=[],
144
+ offset=0,
145
+ sep_style=SeparatorStyle.SINGLE,
146
+ sep="###",
147
+ )
148
+
149
+ default_conversation = Conversation(
150
+ system="",
151
+ roles=("Human", "Assistant"),
152
+ messages=[],
153
+ offset=0,
154
+ sep_style=SeparatorStyle.SINGLE,
155
+ sep="###",
156
+ )
157
+ conv_llava_llama_2 = Conversation(
158
+ system="You are a helpful language and vision assistant. "
159
+ "You are able to understand the visual content that the user provides, "
160
+ "and assist the user with a variety of tasks using natural language.",
161
+ roles=("USER", "ASSISTANT"),
162
+ messages=(),
163
+ offset=0,
164
+ sep_style=SeparatorStyle.LLAMA_2,
165
+ sep="<s>",
166
+ sep2="</s>",
167
+ )
168
+ class Chat:
169
+ def __init__(self, model, vis_processor, device='cuda:0'):
170
+ self.device = device
171
+ self.model = model
172
+ self.vis_processor = vis_processor
173
+ self.image_vis_processor = Blip2ImageEvalProcessor()
174
+ # stop_words_ids = [torch.tensor([835]).to(self.device),
175
+ # torch.tensor([2277, 29937]).to(self.device)] # '###' can be encoded in two different ways.
176
+ # self.stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=stop_words_ids)])
177
+
178
+ def ask(self, text, conv):
179
+ if len(conv.messages) > 0 and conv.messages[-1][0] == conv.roles[0] \
180
+ and ('</Video>' in conv.messages[-1][1] or '</Image>' in conv.messages[-1][1]): # last message is image.
181
+ conv.messages[-1][1] = ' '.join([conv.messages[-1][1], text])
182
+ else:
183
+ conv.append_message(conv.roles[0], text)
184
+
185
+ def answer(self, conv, img_list, max_new_tokens=300, num_beams=1, min_length=1, top_p=0.9,
186
+ repetition_penalty=1.0, length_penalty=1, temperature=1.0, max_length=2000):
187
+ conv.append_message(conv.roles[1], None)
188
+ embs = self.get_context_emb(conv, img_list) #torch.Size([1, 312, 4096])
189
+
190
+ current_max_len = embs.shape[1] + max_new_tokens
191
+ if current_max_len - max_length > 0:
192
+ print('Warning: The number of tokens in current conversation exceeds the max length. '
193
+ 'The model will not see the contexts outside the range.')
194
+ begin_idx = max(0, current_max_len - max_length)
195
+
196
+ embs = embs[:, begin_idx:]
197
+ if conv.sep =="###":
198
+ stop_words_ids = [torch.tensor([835]).to(self.device),
199
+ torch.tensor([2277, 29937]).to(self.device)] # '###' can be encoded in two different ways.
200
+ stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=stop_words_ids)])
201
+ else:
202
+ stop_words_ids = [torch.tensor([2]).to(self.device)]
203
+ stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=stop_words_ids)])
204
+
205
+ # stopping_criteria
206
+ outputs = self.model.llama_model.generate(
207
+ inputs_embeds=embs, #torch.Size([1, 312, 4096])
208
+ max_new_tokens=max_new_tokens,
209
+ stopping_criteria=stopping_criteria,
210
+ num_beams=num_beams,
211
+ do_sample=True,
212
+ min_length=min_length,
213
+ top_p=top_p,
214
+ repetition_penalty=repetition_penalty,
215
+ length_penalty=length_penalty,
216
+ temperature=temperature,
217
+ )
218
+ output_token = outputs[0]
219
+ if output_token[0] == 0: # the model might output a unknow token <unk> at the beginning. remove it
220
+ output_token = output_token[1:]
221
+ if output_token[0] == 1: # some users find that there is a start token <s> at the beginning. remove it
222
+ output_token = output_token[1:]
223
+ output_text = self.model.llama_tokenizer.decode(output_token, add_special_tokens=False)
224
+ # TODO: add saving file
225
+
226
+ if conv.sep =="###":
227
+ output_text = output_text.split('###')[0] # remove the stop sign '###'
228
+ output_text = output_text.split('Assistant:')[-1].strip()
229
+ else:
230
+ output_text = output_text.split(conv.sep2)[0] # remove the stop sign '###'
231
+ output_text = output_text.split(conv.roles[1]+':')[-1].strip()
232
+ conv.messages[-1][1] = output_text
233
+ return output_text, output_token.cpu().numpy()
234
+
235
+ def upload_video(self, video_path, conv, img_list):
236
+
237
+ msg = ""
238
+ if isinstance(video_path, str): # is a video path
239
+ ext = os.path.splitext(video_path)[-1].lower()
240
+ print(video_path)
241
+ # image = self.vis_processor(image).unsqueeze(0).to(self.device)
242
+ video, msg = load_video(
243
+ video_path=video_path,
244
+ n_frms=32,
245
+ height=224,
246
+ width=224,
247
+ sampling ="uniform", return_msg = True
248
+ )
249
+ video = self.vis_processor.transform(video)
250
+ video = video.unsqueeze(0).to(self.device)
251
+ # print(image)
252
+ else:
253
+ raise NotImplementedError
254
+
255
+ try:
256
+ audio_flag = 1
257
+ audio = load_and_transform_audio_data([video_path],"cpu", clips_per_video=8)
258
+ audio = audio.to(self.device)
259
+ except :
260
+ print('no audio is found')
261
+ audio_flag = 0
262
+ finally:
263
+ if audio_flag == 1:
264
+ # image_emb, _ = self.model.encode_videoQformer_audiovideo(video,audio)
265
+ image_emb, _ = self.model.encode_videoQformer_visual(video)
266
+ audio_emb,_ = self.model.encode_audioQformer(audio)
267
+ img_list.append(audio_emb)
268
+ img_list.append(image_emb)
269
+ conv.system = ""
270
+ # conv.append_message(conv.roles[0], "The audio of this video is <Video><ImageHere></Video> ")
271
+ conv.append_message(conv.roles[0], "Close your eyes, open your ears and you imagine only based on the sound that: <ImageHere>. \
272
+ Close your ears, open your eyes and you see that <Video><ImageHere></Video>. \
273
+ Now answer my question based on what you have just seen and heard.")
274
+
275
+ else: # only vison no audio
276
+ # conv.system = "You can understand the video that the user provides. Follow the instructions carefully and explain your answers in detail."
277
+ image_emb, _ = self.model.encode_videoQformer_visual(video)
278
+ img_list.append(image_emb)
279
+ conv.append_message(conv.roles[0], "<Video><ImageHere></Video> "+ msg)
280
+ return "Received."
281
+
282
+ def upload_video_without_audio(self, video_path, conv, img_list):
283
+ msg = ""
284
+ if isinstance(video_path, str): # is a video path
285
+ ext = os.path.splitext(video_path)[-1].lower()
286
+ print(video_path)
287
+ # image = self.vis_processor(image).unsqueeze(0).to(self.device)
288
+ video, msg = load_video(
289
+ video_path=video_path,
290
+ n_frms=32,
291
+ height=224,
292
+ width=224,
293
+ sampling ="uniform", return_msg = True
294
+ )
295
+ video_motion, msg_motion = load_video_motion(
296
+ video_path=video_path,
297
+ n_frms=32,
298
+ height=224,
299
+ width=224,
300
+ sampling ="uniform", return_msg = True
301
+ )
302
+ video = self.vis_processor.transform(video)
303
+ video_motion = self.vis_processor.transform(video_motion)
304
+
305
+ video = video.unsqueeze(0).to(self.device)
306
+ video_motion = video_motion.unsqueeze(0).to(self.device)
307
+ # print(image)
308
+ else:
309
+ raise NotImplementedError
310
+
311
+
312
+ # conv.system = "You can understand the video that the user provides. Follow the instructions carefully and explain your answers in detail."
313
+ image_emb, _, _ = self.model.encode_videoQformer_visual(video) # 1,32,4096
314
+ image_motion_emb, _, _ = self.model.encode_videoQformer_visual(video_motion, motion=True) # 1,32,4096
315
+ img_list.append(torch.cat((image_emb, image_motion_emb), dim=1))
316
+ # img_list.append(image_motion_emb)
317
+ conv.append_message(conv.roles[0], "<Video><ImageHere></Video> ")
318
+ return "Received."
319
+
320
+ def upload_img(self, image, conv, img_list):
321
+
322
+ msg = ""
323
+ if isinstance(image, str): # is a image path
324
+ raw_image = Image.open(image).convert('RGB') # 增加一个时间维度
325
+ image = self.image_vis_processor(raw_image).unsqueeze(0).unsqueeze(2).to(self.device)
326
+ elif isinstance(image, Image.Image):
327
+ raw_image = image
328
+ image = self.image_vis_processor(raw_image).unsqueeze(0).unsqueeze(2).to(self.device)
329
+ elif isinstance(image, torch.Tensor):
330
+ if len(image.shape) == 3:
331
+ image = image.unsqueeze(0)
332
+ image = image.to(self.device)
333
+ else:
334
+ raise NotImplementedError
335
+
336
+ image_emb, _ = self.model.encode_videoQformer_visual(image)
337
+ img_list.append(image_emb)
338
+ # Todo msg=""
339
+ conv.append_message(conv.roles[0], "<Image><ImageHere></Image> "+ msg)
340
+
341
+ return "Received."
342
+
343
+ def get_context_emb(self, conv, img_list):
344
+ prompt = conv.get_prompt()
345
+ prompt_segs = prompt.split('<ImageHere>')
346
+ assert len(prompt_segs) == len(img_list) + 1, "Unmatched numbers of image placeholders and images."
347
+ seg_tokens = [
348
+ self.model.llama_tokenizer(
349
+ seg, return_tensors="pt", add_special_tokens=i == 0).to(self.device).input_ids
350
+ # only add bos to the first seg
351
+ for i, seg in enumerate(prompt_segs)
352
+ ]
353
+ seg_embs = [self.model.llama_model.model.embed_tokens(seg_t) for seg_t in seg_tokens] #torch.Size([1, 44, 4096]), torch.Size([1, 204, 4096])
354
+ mixed_embs = [emb for pair in zip(seg_embs[:-1], img_list) for emb in pair] + [seg_embs[-1]] #torch.Size([1, 64, 4096])
355
+ mixed_embs = torch.cat(mixed_embs, dim=1)
356
+ return mixed_embs
357
+
358
+ if __name__ =='__main__':
359
+ video_path = '/mnt/workspace/videoGPT/Video-LLaMA/examples/applausing.mp4'
360
+ # import torch.classes.torchaudio.ffmpeg_StreamReader
361
+ # ffmpeg_StreamReader(video_path)
362
+ load_and_transform_audio_data([video_path],"cpu", clips_per_video=8)
hawk/datasets/__init__.py ADDED
File without changes
hawk/datasets/builders/__init__.py ADDED
@@ -0,0 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Copyright (c) 2022, salesforce.com, inc.
3
+ All rights reserved.
4
+ SPDX-License-Identifier: BSD-3-Clause
5
+ For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ """
7
+
8
+ from hawk.datasets.builders.base_dataset_builder import load_dataset_config
9
+ # from hawk.datasets.builders.image_text_pair_builder import (
10
+ # CCSBUBuilder,
11
+ # LaionBuilder,
12
+ # CCSBUAlignBuilder
13
+ # )
14
+ from hawk.datasets.builders.video_caption_builder import WebvidBuilder
15
+ from hawk.common.registry import registry
16
+ from hawk.datasets.builders.instruct_builder import WebvidInstruct_Builder
17
+ __all__ = [
18
+ # "CCSBUBuilder",
19
+ # "LaionBuilder",
20
+ # "CCSBUAlignBuilder",
21
+ "WebvidBuilder",
22
+ # "LlavaInstruct_Builder",
23
+ "WebvidInstruct_Builder"
24
+
25
+ ]
26
+
27
+
28
+ def load_dataset(name, cfg_path=None, vis_path=None, data_type=None):
29
+ """
30
+ Example
31
+
32
+ >>> dataset = load_dataset("coco_caption", cfg=None)
33
+ >>> splits = dataset.keys()
34
+ >>> print([len(dataset[split]) for split in splits])
35
+
36
+ """
37
+ if cfg_path is None:
38
+ cfg = None
39
+ else:
40
+ cfg = load_dataset_config(cfg_path)
41
+
42
+ try:
43
+ builder = registry.get_builder_class(name)(cfg)
44
+ except TypeError:
45
+ print(
46
+ f"Dataset {name} not found. Available datasets:\n"
47
+ + ", ".join([str(k) for k in dataset_zoo.get_names()])
48
+ )
49
+ exit(1)
50
+
51
+ if vis_path is not None:
52
+ if data_type is None:
53
+ # use default data type in the config
54
+ data_type = builder.config.data_type
55
+
56
+ assert (
57
+ data_type in builder.config.build_info
58
+ ), f"Invalid data_type {data_type} for {name}."
59
+
60
+ builder.config.build_info.get(data_type).storage = vis_path
61
+
62
+ dataset = builder.build_datasets()
63
+ return dataset
64
+
65
+
66
+ class DatasetZoo:
67
+ def __init__(self) -> None:
68
+ self.dataset_zoo = {
69
+ k: list(v.DATASET_CONFIG_DICT.keys())
70
+ for k, v in sorted(registry.mapping["builder_name_mapping"].items())
71
+ }
72
+
73
+ def get_names(self):
74
+ return list(self.dataset_zoo.keys())
75
+
76
+
77
+ dataset_zoo = DatasetZoo()
hawk/datasets/builders/base_dataset_builder.py ADDED
@@ -0,0 +1,236 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ This file is from
3
+ Copyright (c) 2022, salesforce.com, inc.
4
+ All rights reserved.
5
+ SPDX-License-Identifier: BSD-3-Clause
6
+ For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
7
+ """
8
+
9
+ import logging
10
+ import os
11
+ import shutil
12
+ import warnings
13
+
14
+ from omegaconf import OmegaConf
15
+ import torch.distributed as dist
16
+ from torchvision.datasets.utils import download_url
17
+
18
+ import hawk.common.utils as utils
19
+ from hawk.common.dist_utils import is_dist_avail_and_initialized, is_main_process
20
+ from hawk.common.registry import registry
21
+ from hawk.processors.base_processor import BaseProcessor
22
+
23
+
24
+
25
+ class BaseDatasetBuilder:
26
+ train_dataset_cls, eval_dataset_cls = None, None
27
+
28
+ def __init__(self, cfg=None):
29
+ super().__init__()
30
+
31
+ if cfg is None:
32
+ # help to create datasets from default config.
33
+ self.config = load_dataset_config(self.default_config_path())
34
+ elif isinstance(cfg, str):
35
+ self.config = load_dataset_config(cfg)
36
+ else:
37
+ # when called from task.build_dataset()
38
+ self.config = cfg
39
+
40
+ self.data_type = self.config.data_type
41
+
42
+ self.vis_processors = {"train": BaseProcessor(), "eval": BaseProcessor()}
43
+ self.text_processors = {"train": BaseProcessor(), "eval": BaseProcessor()}
44
+
45
+ def build_datasets(self):
46
+ # download, split, etc...
47
+ # only called on 1 GPU/TPU in distributed
48
+
49
+ if is_main_process():
50
+ self._download_data()
51
+
52
+ if is_dist_avail_and_initialized():
53
+ dist.barrier()
54
+
55
+ # at this point, all the annotations and image/videos should be all downloaded to the specified locations.
56
+ logging.info("Building datasets...")
57
+ datasets = self.build() # dataset['train'/'val'/'test']
58
+
59
+ return datasets
60
+
61
+ def build_processors(self):
62
+ vis_proc_cfg = self.config.get("vis_processor")
63
+ txt_proc_cfg = self.config.get("text_processor")
64
+
65
+ if vis_proc_cfg is not None:
66
+ vis_train_cfg = vis_proc_cfg.get("train")
67
+ vis_eval_cfg = vis_proc_cfg.get("eval")
68
+
69
+ self.vis_processors["train"] = self._build_proc_from_cfg(vis_train_cfg)
70
+ self.vis_processors["eval"] = self._build_proc_from_cfg(vis_eval_cfg)
71
+
72
+ if txt_proc_cfg is not None:
73
+ txt_train_cfg = txt_proc_cfg.get("train")
74
+ txt_eval_cfg = txt_proc_cfg.get("eval")
75
+
76
+ self.text_processors["train"] = self._build_proc_from_cfg(txt_train_cfg)
77
+ self.text_processors["eval"] = self._build_proc_from_cfg(txt_eval_cfg)
78
+
79
+ @staticmethod
80
+ def _build_proc_from_cfg(cfg):
81
+ return (
82
+ registry.get_processor_class(cfg.name).from_config(cfg)
83
+ if cfg is not None
84
+ else None
85
+ )
86
+
87
+ @classmethod
88
+ def default_config_path(cls, type="default"):
89
+ return utils.get_abs_path(cls.DATASET_CONFIG_DICT[type])
90
+
91
+ def _download_data(self):
92
+ self._download_ann()
93
+ self._download_vis()
94
+
95
+ def _download_ann(self):
96
+ """
97
+ Download annotation files if necessary.
98
+ All the vision-language datasets should have annotations of unified format.
99
+
100
+ storage_path can be:
101
+ (1) relative/absolute: will be prefixed with env.cache_root to make full path if relative.
102
+ (2) basename/dirname: will be suffixed with base name of URL if dirname is provided.
103
+
104
+ Local annotation paths should be relative.
105
+ """
106
+ anns = self.config.build_info.annotations
107
+
108
+ splits = anns.keys()
109
+
110
+ cache_root = registry.get_path("cache_root")
111
+
112
+ for split in splits:
113
+ info = anns[split]
114
+
115
+ urls, storage_paths = info.get("url", None), info.storage
116
+
117
+ if isinstance(urls, str):
118
+ urls = [urls]
119
+ if isinstance(storage_paths, str):
120
+ storage_paths = [storage_paths]
121
+
122
+ assert len(urls) == len(storage_paths)
123
+
124
+ for url_or_filename, storage_path in zip(urls, storage_paths):
125
+ # if storage_path is relative, make it full by prefixing with cache_root.
126
+ if not os.path.isabs(storage_path):
127
+ storage_path = os.path.join(cache_root, storage_path)
128
+
129
+ dirname = os.path.dirname(storage_path)
130
+ if not os.path.exists(dirname):
131
+ os.makedirs(dirname)
132
+
133
+ if os.path.isfile(url_or_filename):
134
+ src, dst = url_or_filename, storage_path
135
+ if not os.path.exists(dst):
136
+ shutil.copyfile(src=src, dst=dst)
137
+ else:
138
+ logging.info("Using existing file {}.".format(dst))
139
+ else:
140
+ if os.path.isdir(storage_path):
141
+ # if only dirname is provided, suffix with basename of URL.
142
+ raise ValueError(
143
+ "Expecting storage_path to be a file path, got directory {}".format(
144
+ storage_path
145
+ )
146
+ )
147
+ else:
148
+ filename = os.path.basename(storage_path)
149
+
150
+ download_url(url=url_or_filename, root=dirname, filename=filename)
151
+
152
+ def _download_vis(self):
153
+
154
+ storage_path = self.config.build_info.get(self.data_type).storage
155
+ storage_path = utils.get_cache_path(storage_path)
156
+
157
+ if not os.path.exists(storage_path):
158
+ warnings.warn(
159
+ f"""
160
+ The specified path {storage_path} for visual inputs does not exist.
161
+ Please provide a correct path to the visual inputs or
162
+ refer to datasets/download_scripts/README.md for downloading instructions.
163
+ """
164
+ )
165
+
166
+ def build(self):
167
+ """
168
+ Create by split datasets inheriting torch.utils.data.Datasets.
169
+
170
+ # build() can be dataset-specific. Overwrite to customize.
171
+ """
172
+ self.build_processors()
173
+
174
+ build_info = self.config.build_info
175
+
176
+ ann_info = build_info.annotations
177
+ vis_info = build_info.get(self.data_type)
178
+
179
+ datasets = dict()
180
+ for split in ann_info.keys():
181
+ if split not in ["train", "val", "test"]:
182
+ continue
183
+
184
+ is_train = split == "train"
185
+
186
+ # processors
187
+ vis_processor = (
188
+ self.vis_processors["train"]
189
+ if is_train
190
+ else self.vis_processors["eval"]
191
+ )
192
+ text_processor = (
193
+ self.text_processors["train"]
194
+ if is_train
195
+ else self.text_processors["eval"]
196
+ )
197
+
198
+ # annotation path
199
+ ann_paths = ann_info.get(split).storage
200
+ if isinstance(ann_paths, str):
201
+ ann_paths = [ann_paths]
202
+
203
+ abs_ann_paths = []
204
+ for ann_path in ann_paths:
205
+ if not os.path.isabs(ann_path):
206
+ ann_path = utils.get_cache_path(ann_path)
207
+ abs_ann_paths.append(ann_path)
208
+ ann_paths = abs_ann_paths
209
+
210
+ # visual data storage path
211
+ vis_path = os.path.join(vis_info.storage, split)
212
+
213
+ if not os.path.isabs(vis_path):
214
+ # vis_path = os.path.join(utils.get_cache_path(), vis_path)
215
+ vis_path = utils.get_cache_path(vis_path)
216
+
217
+ if not os.path.exists(vis_path):
218
+ warnings.warn("storage path {} does not exist.".format(vis_path))
219
+
220
+ # create datasets
221
+ dataset_cls = self.train_dataset_cls if is_train else self.eval_dataset_cls
222
+ datasets[split] = dataset_cls(
223
+ vis_processor=vis_processor,
224
+ text_processor=text_processor,
225
+ ann_paths=ann_paths,
226
+ vis_root=vis_path,
227
+ )
228
+
229
+ return datasets
230
+
231
+
232
+ def load_dataset_config(cfg_path):
233
+ cfg = OmegaConf.load(cfg_path).datasets
234
+ cfg = cfg[list(cfg.keys())[0]]
235
+
236
+ return cfg
hawk/datasets/builders/image_text_pair_builder.py ADDED
@@ -0,0 +1,106 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import logging
3
+ import warnings
4
+
5
+ from hawk.common.registry import registry
6
+ from hawk.datasets.builders.base_dataset_builder import BaseDatasetBuilder
7
+ # from hawk.datasets.datasets.laion_dataset import LaionDataset
8
+ # from hawk.datasets.datasets.cc_sbu_dataset import CCSBUDataset, CCSBUAlignDataset
9
+
10
+
11
+ # @registry.register_builder("cc_sbu")
12
+ # class CCSBUBuilder(BaseDatasetBuilder):
13
+ # train_dataset_cls = CCSBUDataset
14
+
15
+ # DATASET_CONFIG_DICT = {"default": "configs/datasets/cc_sbu/defaults.yaml"}
16
+
17
+ # def _download_ann(self):
18
+ # pass
19
+
20
+ # def _download_vis(self):
21
+ # pass
22
+
23
+ # def build(self):
24
+ # self.build_processors()
25
+
26
+ # build_info = self.config.build_info
27
+
28
+ # datasets = dict()
29
+ # split = "train"
30
+
31
+ # # create datasets
32
+ # # [NOTE] return inner_datasets (wds.DataPipeline)
33
+ # dataset_cls = self.train_dataset_cls
34
+ # datasets[split] = dataset_cls(
35
+ # vis_processor=self.vis_processors[split],
36
+ # text_processor=self.text_processors[split],
37
+ # location=build_info.storage,
38
+ # ).inner_dataset
39
+
40
+ # return datasets
41
+
42
+
43
+ # @registry.register_builder("laion")
44
+ # class LaionBuilder(BaseDatasetBuilder):
45
+ # train_dataset_cls = LaionDataset
46
+
47
+ # DATASET_CONFIG_DICT = {"default": "configs/datasets/laion/defaults.yaml"}
48
+
49
+ # def _download_ann(self):
50
+ # pass
51
+
52
+ # def _download_vis(self):
53
+ # pass
54
+
55
+ # def build(self):
56
+ # self.build_processors()
57
+
58
+ # build_info = self.config.build_info
59
+
60
+ # datasets = dict()
61
+ # split = "train"
62
+
63
+ # # create datasets
64
+ # # [NOTE] return inner_datasets (wds.DataPipeline)
65
+ # dataset_cls = self.train_dataset_cls
66
+ # datasets[split] = dataset_cls(
67
+ # vis_processor=self.vis_processors[split],
68
+ # text_processor=self.text_processors[split],
69
+ # location=build_info.storage,
70
+ # ).inner_dataset
71
+
72
+ # return datasets
73
+
74
+
75
+ # @registry.register_builder("cc_sbu_align")
76
+ # class CCSBUAlignBuilder(BaseDatasetBuilder):
77
+ # train_dataset_cls = CCSBUAlignDataset
78
+
79
+ # DATASET_CONFIG_DICT = {
80
+ # "default": "configs/datasets/cc_sbu/align.yaml",
81
+ # }
82
+
83
+ # def build_datasets(self):
84
+ # # at this point, all the annotations and image/videos should be all downloaded to the specified locations.
85
+ # logging.info("Building datasets...")
86
+ # self.build_processors()
87
+
88
+ # build_info = self.config.build_info
89
+ # storage_path = build_info.storage
90
+
91
+ # datasets = dict()
92
+
93
+ # if not os.path.exists(storage_path):
94
+ # warnings.warn("storage path {} does not exist.".format(storage_path))
95
+
96
+ # # create datasets
97
+ # dataset_cls = self.train_dataset_cls
98
+ # datasets['train'] = dataset_cls(
99
+ # vis_processor=self.vis_processors["train"],
100
+ # text_processor=self.text_processors["train"],
101
+ # ann_paths=[os.path.join(storage_path, 'filter_cap.json')],
102
+ # vis_root=os.path.join(storage_path, 'image'),
103
+ # )
104
+
105
+ # return datasets
106
+
hawk/datasets/builders/instruct_builder.py ADDED
@@ -0,0 +1,79 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import logging
3
+ import warnings
4
+
5
+ from hawk.common.registry import registry
6
+ from hawk.datasets.builders.base_dataset_builder import BaseDatasetBuilder
7
+ # from hawk.datasets.datasets.laion_dataset import LaionDataset
8
+ from hawk.datasets.datasets.llava_instruct_dataset import Instruct_Dataset
9
+ from hawk.datasets.datasets.video_instruct_dataset import Video_Instruct_Dataset
10
+
11
+ @registry.register_builder("instruct")
12
+ class Instruct_Builder(BaseDatasetBuilder):
13
+ train_dataset_cls = Instruct_Dataset
14
+
15
+ DATASET_CONFIG_DICT = {"default": "configs/datasets/instruct/defaults.yaml"}
16
+
17
+ def _download_ann(self):
18
+ pass
19
+
20
+ def _download_vis(self):
21
+ pass
22
+
23
+ def build(self):
24
+ self.build_processors()
25
+ datasets = dict()
26
+ split = "train"
27
+
28
+ build_info = self.config.build_info
29
+ dataset_cls = self.train_dataset_cls
30
+ if self.config.num_video_query_token:
31
+ num_video_query_token = self.config.num_video_query_token
32
+ else:
33
+ num_video_query_token = 32
34
+
35
+ if self.config.tokenizer_name:
36
+ tokenizer_name = self.config.tokenizer_name
37
+ else:
38
+ tokenizer_name = '/mnt/workspace/ckpt/vicuna-13b/'
39
+
40
+
41
+ datasets[split] = dataset_cls(
42
+ vis_processor=self.vis_processors[split],
43
+ text_processor=self.text_processors[split],
44
+ vis_root=build_info.videos_dir,
45
+ ann_root=build_info.anno_dir,
46
+ num_video_query_token = num_video_query_token,
47
+ tokenizer_name = tokenizer_name,
48
+ data_type = self.config.data_type,
49
+ model_type = self.config.model_type
50
+ )
51
+
52
+ return datasets
53
+
54
+ @registry.register_builder("webvid_instruct")
55
+ class WebvidInstruct_Builder(Instruct_Builder):
56
+ train_dataset_cls = Video_Instruct_Dataset
57
+
58
+ DATASET_CONFIG_DICT = {
59
+ "default": "configs/datasets/instruct/webvid_instruct.yaml",
60
+ }
61
+
62
+ # @registry.register_builder("webvid_instruct_zh")
63
+ # class WebvidInstruct_zh_Builder(Instruct_Builder):
64
+ # train_dataset_cls = Video_Instruct_Dataset
65
+
66
+ # DATASET_CONFIG_DICT = {
67
+ # "default": "configs/datasets/instruct/webvid_instruct.yaml",
68
+ # }
69
+
70
+
71
+
72
+ # @registry.register_builder("llava_instruct")
73
+ # class LlavaInstruct_Builder(Instruct_Builder):
74
+ # train_dataset_cls = Instruct_Dataset
75
+
76
+ # DATASET_CONFIG_DICT = {
77
+ # "default": "configs/datasets/instruct/llava_instruct.yaml",
78
+ # }
79
+
hawk/datasets/builders/video_caption_builder.py ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import logging
3
+ import warnings
4
+
5
+ from hawk.common.registry import registry
6
+ from hawk.datasets.builders.base_dataset_builder import BaseDatasetBuilder
7
+ from hawk.datasets.datasets.webvid_datasets import WebvidDataset
8
+
9
+ @registry.register_builder("webvid")
10
+ class WebvidBuilder(BaseDatasetBuilder):
11
+ train_dataset_cls = WebvidDataset
12
+ DATASET_CONFIG_DICT = {"default": "configs/datasets/webvid/defaults.yaml"}
13
+
14
+ def _download_ann(self):
15
+ pass
16
+
17
+ def _download_vis(self):
18
+ pass
19
+
20
+ def build(self):
21
+ self.build_processors()
22
+ datasets = dict()
23
+ split = "train"
24
+
25
+ build_info = self.config.build_info
26
+ dataset_cls = self.train_dataset_cls
27
+ datasets[split] = dataset_cls(
28
+ vis_processor=self.vis_processors[split],
29
+ text_processor=self.text_processors[split],
30
+ vis_root=build_info.videos_dir,
31
+ ann_root=build_info.anno_dir
32
+ )
33
+
34
+ return datasets
hawk/datasets/data_utils.py ADDED
@@ -0,0 +1,196 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Copyright (c) 2022, salesforce.com, inc.
3
+ All rights reserved.
4
+ SPDX-License-Identifier: BSD-3-Clause
5
+ For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ """
7
+
8
+ import gzip
9
+ import logging
10
+ import os
11
+ import random as rnd
12
+ import tarfile
13
+ import zipfile
14
+ import random
15
+ from typing import List
16
+ from tqdm import tqdm
17
+
18
+ import decord
19
+ from decord import VideoReader
20
+ import webdataset as wds
21
+ import numpy as np
22
+ import torch
23
+ from torch.utils.data.dataset import IterableDataset
24
+
25
+ from hawk.common.registry import registry
26
+ from hawk.datasets.datasets.base_dataset import ConcatDataset
27
+
28
+
29
+ decord.bridge.set_bridge("torch")
30
+ MAX_INT = registry.get("MAX_INT")
31
+
32
+
33
+ class ChainDataset(wds.DataPipeline):
34
+ r"""Dataset for chaining multiple :class:`DataPipeline` s.
35
+
36
+ This class is useful to assemble different existing dataset streams. The
37
+ chaining operation is done on-the-fly, so concatenating large-scale
38
+ datasets with this class will be efficient.
39
+
40
+ Args:
41
+ datasets (iterable of IterableDataset): datasets to be chained together
42
+ """
43
+ def __init__(self, datasets: List[wds.DataPipeline]) -> None:
44
+ super().__init__()
45
+ self.datasets = datasets
46
+ self.prob = []
47
+ self.names = []
48
+ for dataset in self.datasets:
49
+ if hasattr(dataset, 'name'):
50
+ self.names.append(dataset.name)
51
+ else:
52
+ self.names.append('Unknown')
53
+ if hasattr(dataset, 'sample_ratio'):
54
+ self.prob.append(dataset.sample_ratio)
55
+ else:
56
+ self.prob.append(1)
57
+ logging.info("One of the datapipeline doesn't define ratio and set to 1 automatically.")
58
+
59
+ def __iter__(self):
60
+ datastreams = [iter(dataset) for dataset in self.datasets]
61
+ while True:
62
+ select_datastream = random.choices(datastreams, weights=self.prob, k=1)[0]
63
+ yield next(select_datastream)
64
+
65
+
66
+ def apply_to_sample(f, sample):
67
+ if len(sample) == 0:
68
+ return {}
69
+
70
+ def _apply(x):
71
+ if torch.is_tensor(x):
72
+ return f(x)
73
+ elif isinstance(x, dict):
74
+ return {key: _apply(value) for key, value in x.items()}
75
+ elif isinstance(x, list):
76
+ return [_apply(x) for x in x]
77
+ else:
78
+ return x
79
+
80
+ return _apply(sample)
81
+
82
+
83
+ def move_to_cuda(sample):
84
+ def _move_to_cuda(tensor):
85
+ return tensor.cuda()
86
+
87
+ return apply_to_sample(_move_to_cuda, sample)
88
+
89
+
90
+ def prepare_sample(samples, cuda_enabled=True):
91
+ if cuda_enabled:
92
+ samples = move_to_cuda(samples)
93
+
94
+ # TODO fp16 support
95
+
96
+ return samples
97
+
98
+
99
+ def reorg_datasets_by_split(datasets):
100
+ """
101
+ Organizes datasets by split.
102
+
103
+ Args:
104
+ datasets: dict of torch.utils.data.Dataset objects by name.
105
+
106
+ Returns:
107
+ Dict of datasets by split {split_name: List[Datasets]}.
108
+ """
109
+ # if len(datasets) == 1:
110
+ # return datasets[list(datasets.keys())[0]]
111
+ # else:
112
+ reorg_datasets = dict()
113
+
114
+ # reorganize by split
115
+ for _, dataset in datasets.items():
116
+ for split_name, dataset_split in dataset.items():
117
+ if split_name not in reorg_datasets:
118
+ reorg_datasets[split_name] = [dataset_split]
119
+ else:
120
+ reorg_datasets[split_name].append(dataset_split)
121
+
122
+ return reorg_datasets
123
+
124
+
125
+ def concat_datasets(datasets):
126
+ """
127
+ Concatenates multiple datasets into a single dataset.
128
+
129
+ It supports may-style datasets and DataPipeline from WebDataset. Currently, does not support
130
+ generic IterableDataset because it requires creating separate samplers.
131
+
132
+ Now only supports conctenating training datasets and assuming validation and testing
133
+ have only a single dataset. This is because metrics should not be computed on the concatenated
134
+ datasets.
135
+
136
+ Args:
137
+ datasets: dict of torch.utils.data.Dataset objects by split.
138
+
139
+ Returns:
140
+ Dict of concatenated datasets by split, "train" is the concatenation of multiple datasets,
141
+ "val" and "test" remain the same.
142
+
143
+ If the input training datasets contain both map-style and DataPipeline datasets, returns
144
+ a tuple, where the first element is a concatenated map-style dataset and the second
145
+ element is a chained DataPipeline dataset.
146
+
147
+ """
148
+ # concatenate datasets in the same split
149
+ for split_name in datasets:
150
+ if split_name != "train":
151
+ assert (
152
+ len(datasets[split_name]) == 1
153
+ ), "Do not support multiple {} datasets.".format(split_name)
154
+ datasets[split_name] = datasets[split_name][0]
155
+ else:
156
+ iterable_datasets, map_datasets = [], []
157
+ for dataset in datasets[split_name]:
158
+ if isinstance(dataset, wds.DataPipeline):
159
+ logging.info(
160
+ "Dataset {} is IterableDataset, can't be concatenated.".format(
161
+ dataset
162
+ )
163
+ )
164
+ iterable_datasets.append(dataset)
165
+ elif isinstance(dataset, IterableDataset):
166
+ raise NotImplementedError(
167
+ "Do not support concatenation of generic IterableDataset."
168
+ )
169
+ else:
170
+ map_datasets.append(dataset)
171
+
172
+ # if len(iterable_datasets) > 0:
173
+ # concatenate map-style datasets and iterable-style datasets separately
174
+ if len(iterable_datasets) > 1:
175
+ chained_datasets = (
176
+ ChainDataset(iterable_datasets)
177
+ )
178
+ elif len(iterable_datasets) == 1:
179
+ chained_datasets = iterable_datasets[0]
180
+ else:
181
+ chained_datasets = None
182
+
183
+ concat_datasets = (
184
+ ConcatDataset(map_datasets) if len(map_datasets) > 0 else None
185
+ )
186
+
187
+ train_datasets = concat_datasets, chained_datasets
188
+ train_datasets = tuple([x for x in train_datasets if x is not None])
189
+ train_datasets = (
190
+ train_datasets[0] if len(train_datasets) == 1 else train_datasets
191
+ )
192
+
193
+ datasets[split_name] = train_datasets
194
+
195
+ return datasets
196
+
hawk/datasets/datasets/__init__.py ADDED
File without changes
hawk/datasets/datasets/base_dataset.py ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Copyright (c) 2022, salesforce.com, inc.
3
+ All rights reserved.
4
+ SPDX-License-Identifier: BSD-3-Clause
5
+ For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ """
7
+
8
+ import json
9
+ from typing import Iterable
10
+
11
+ from torch.utils.data import Dataset, ConcatDataset
12
+ from torch.utils.data.dataloader import default_collate
13
+
14
+
15
+ class BaseDataset(Dataset):
16
+ def __init__(
17
+ self, vis_processor=None, text_processor=None, vis_root=None, ann_paths=[]
18
+ ):
19
+ """
20
+ vis_root (string): Root directory of images (e.g. coco/images/)
21
+ ann_root (string): directory to store the annotation file
22
+ """
23
+ self.vis_root = vis_root
24
+
25
+ self.annotation = []
26
+ for ann_path in ann_paths:
27
+ self.annotation.extend(json.load(open(ann_path, "r"))['annotations'])
28
+
29
+ self.vis_processor = vis_processor
30
+ self.text_processor = text_processor
31
+
32
+ self._add_instance_ids()
33
+
34
+ def __len__(self):
35
+ return len(self.annotation)
36
+
37
+ def collater(self, samples):
38
+ return default_collate(samples)
39
+
40
+ def set_processors(self, vis_processor, text_processor):
41
+ self.vis_processor = vis_processor
42
+ self.text_processor = text_processor
43
+
44
+ def _add_instance_ids(self, key="instance_id"):
45
+ for idx, ann in enumerate(self.annotation):
46
+ ann[key] = str(idx)
47
+
48
+
49
+ class ConcatDataset(ConcatDataset):
50
+ def __init__(self, datasets: Iterable[Dataset]) -> None:
51
+ super().__init__(datasets)
52
+
53
+ def collater(self, samples):
54
+ # TODO For now only supports datasets with same underlying collater implementations
55
+
56
+ all_keys = set()
57
+ for s in samples:
58
+ all_keys.update(s)
59
+
60
+ shared_keys = all_keys
61
+ for s in samples:
62
+ shared_keys = shared_keys & set(s.keys())
63
+
64
+ samples_shared_keys = []
65
+ for s in samples:
66
+ samples_shared_keys.append({k: s[k] for k in s.keys() if k in shared_keys})
67
+
68
+ return self.datasets[0].collater(samples_shared_keys)
hawk/datasets/datasets/caption_datasets.py ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Copyright (c) 2022, salesforce.com, inc.
3
+ All rights reserved.
4
+ SPDX-License-Identifier: BSD-3-Clause
5
+ For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ """
7
+
8
+ import os
9
+ from collections import OrderedDict
10
+
11
+ from hawk.datasets.datasets.base_dataset import BaseDataset
12
+ from PIL import Image
13
+
14
+
15
+ class __DisplMixin:
16
+ def displ_item(self, index):
17
+ sample, ann = self.__getitem__(index), self.annotation[index]
18
+
19
+ return OrderedDict(
20
+ {
21
+ "file": ann["image"],
22
+ "caption": ann["caption"],
23
+ "image": sample["image"],
24
+ }
25
+ )
26
+
27
+
28
+ class CaptionDataset(BaseDataset, __DisplMixin):
29
+ def __init__(self, vis_processor, text_processor, vis_root, ann_paths):
30
+ """
31
+ vis_root (string): Root directory of images (e.g. coco/images/)
32
+ ann_root (string): directory to store the annotation file
33
+ """
34
+ super().__init__(vis_processor, text_processor, vis_root, ann_paths)
35
+
36
+ self.img_ids = {}
37
+ n = 0
38
+ for ann in self.annotation:
39
+ img_id = ann["image_id"]
40
+ if img_id not in self.img_ids.keys():
41
+ self.img_ids[img_id] = n
42
+ n += 1
43
+
44
+ def __getitem__(self, index):
45
+
46
+ # TODO this assumes image input, not general enough
47
+ ann = self.annotation[index]
48
+
49
+ img_file = '{:0>12}.jpg'.format(ann["image_id"])
50
+ image_path = os.path.join(self.vis_root, img_file)
51
+ image = Image.open(image_path).convert("RGB")
52
+
53
+ image = self.vis_processor(image)
54
+ caption = self.text_processor(ann["caption"])
55
+
56
+ return {
57
+ "image": image,
58
+ "text_input": caption,
59
+ "image_id": self.img_ids[ann["image_id"]],
60
+ }
61
+
62
+
63
+ class CaptionEvalDataset(BaseDataset, __DisplMixin):
64
+ def __init__(self, vis_processor, text_processor, vis_root, ann_paths):
65
+ """
66
+ vis_root (string): Root directory of images (e.g. coco/images/)
67
+ ann_root (string): directory to store the annotation file
68
+ split (string): val or test
69
+ """
70
+ super().__init__(vis_processor, text_processor, vis_root, ann_paths)
71
+
72
+ def __getitem__(self, index):
73
+
74
+ ann = self.annotation[index]
75
+
76
+ image_path = os.path.join(self.vis_root, ann["image"])
77
+ image = Image.open(image_path).convert("RGB")
78
+
79
+ image = self.vis_processor(image)
80
+
81
+ return {
82
+ "image": image,
83
+ "image_id": ann["image_id"],
84
+ "instance_id": ann["instance_id"],
85
+ }
hawk/datasets/datasets/dataloader_utils.py ADDED
@@ -0,0 +1,162 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Copyright (c) 2022, salesforce.com, inc.
3
+ All rights reserved.
4
+ SPDX-License-Identifier: BSD-3-Clause
5
+ For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ """
7
+
8
+ import time
9
+ import random
10
+ import torch
11
+ from hawk.datasets.data_utils import move_to_cuda
12
+ from torch.utils.data import DataLoader
13
+
14
+
15
+ class MultiIterLoader:
16
+ """
17
+ A simple wrapper for iterating over multiple iterators.
18
+
19
+ Args:
20
+ loaders (List[Loader]): List of Iterator loaders.
21
+ ratios (List[float]): List of ratios to sample from each loader. If None, all loaders are sampled uniformly.
22
+ """
23
+
24
+ def __init__(self, loaders, ratios=None):
25
+ # assert all loaders has __next__ method
26
+ for loader in loaders:
27
+ assert hasattr(
28
+ loader, "__next__"
29
+ ), "Loader {} has no __next__ method.".format(loader)
30
+
31
+ if ratios is None:
32
+ ratios = [1.0] * len(loaders)
33
+ else:
34
+ assert len(ratios) == len(loaders)
35
+ ratios = [float(ratio) / sum(ratios) for ratio in ratios]
36
+
37
+ self.loaders = loaders
38
+ self.ratios = ratios
39
+
40
+ def __next__(self):
41
+ # random sample from each loader by ratio
42
+ loader_idx = random.choices(range(len(self.loaders)), self.ratios, k=1)[0]
43
+ return next(self.loaders[loader_idx])
44
+
45
+
46
+ class PrefetchLoader(object):
47
+ """
48
+ Modified from https://github.com/ChenRocks/UNITER.
49
+
50
+ overlap compute and cuda data transfer
51
+ (copied and then modified from nvidia apex)
52
+ """
53
+
54
+ def __init__(self, loader):
55
+ self.loader = loader
56
+ self.stream = torch.cuda.Stream()
57
+
58
+ def __iter__(self):
59
+ loader_it = iter(self.loader)
60
+ self.preload(loader_it)
61
+ batch = self.next(loader_it)
62
+ while batch is not None:
63
+ is_tuple = isinstance(batch, tuple)
64
+ if is_tuple:
65
+ task, batch = batch
66
+
67
+ if is_tuple:
68
+ yield task, batch
69
+ else:
70
+ yield batch
71
+ batch = self.next(loader_it)
72
+
73
+ def __len__(self):
74
+ return len(self.loader)
75
+
76
+ def preload(self, it):
77
+ try:
78
+ self.batch = next(it)
79
+ except StopIteration:
80
+ self.batch = None
81
+ return
82
+ # if record_stream() doesn't work, another option is to make sure
83
+ # device inputs are created on the main stream.
84
+ # self.next_input_gpu = torch.empty_like(self.next_input,
85
+ # device='cuda')
86
+ # self.next_target_gpu = torch.empty_like(self.next_target,
87
+ # device='cuda')
88
+ # Need to make sure the memory allocated for next_* is not still in use
89
+ # by the main stream at the time we start copying to next_*:
90
+ # self.stream.wait_stream(torch.cuda.current_stream())
91
+ with torch.cuda.stream(self.stream):
92
+ self.batch = move_to_cuda(self.batch)
93
+ # more code for the alternative if record_stream() doesn't work:
94
+ # copy_ will record the use of the pinned source tensor in this
95
+ # side stream.
96
+ # self.next_input_gpu.copy_(self.next_input, non_blocking=True)
97
+ # self.next_target_gpu.copy_(self.next_target, non_blocking=True)
98
+ # self.next_input = self.next_input_gpu
99
+ # self.next_target = self.next_target_gpu
100
+
101
+ def next(self, it):
102
+ torch.cuda.current_stream().wait_stream(self.stream)
103
+ batch = self.batch
104
+ if batch is not None:
105
+ record_cuda_stream(batch)
106
+ self.preload(it)
107
+ return batch
108
+
109
+ def __getattr__(self, name):
110
+ method = self.loader.__getattribute__(name)
111
+ return method
112
+
113
+
114
+ def record_cuda_stream(batch):
115
+ if isinstance(batch, torch.Tensor):
116
+ batch.record_stream(torch.cuda.current_stream())
117
+ elif isinstance(batch, list) or isinstance(batch, tuple):
118
+ for t in batch:
119
+ record_cuda_stream(t)
120
+ elif isinstance(batch, dict):
121
+ for t in batch.values():
122
+ record_cuda_stream(t)
123
+ else:
124
+ pass
125
+
126
+
127
+ class IterLoader:
128
+ """
129
+ A wrapper to convert DataLoader as an infinite iterator.
130
+
131
+ Modified from:
132
+ https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/iter_based_runner.py
133
+ """
134
+
135
+ def __init__(self, dataloader: DataLoader, use_distributed: bool = False):
136
+ self._dataloader = dataloader
137
+ self.iter_loader = iter(self._dataloader)
138
+ self._use_distributed = use_distributed
139
+ self._epoch = 0
140
+
141
+ @property
142
+ def epoch(self) -> int:
143
+ return self._epoch
144
+
145
+ def __next__(self):
146
+ try:
147
+ data = next(self.iter_loader)
148
+ except StopIteration:
149
+ self._epoch += 1
150
+ if hasattr(self._dataloader.sampler, "set_epoch") and self._use_distributed:
151
+ self._dataloader.sampler.set_epoch(self._epoch)
152
+ time.sleep(2) # Prevent possible deadlock during epoch transition
153
+ self.iter_loader = iter(self._dataloader)
154
+ data = next(self.iter_loader)
155
+
156
+ return data
157
+
158
+ def __iter__(self):
159
+ return self
160
+
161
+ def __len__(self):
162
+ return len(self._dataloader)
hawk/datasets/datasets/llava_instruct_dataset.py ADDED
@@ -0,0 +1,312 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from hawk.datasets.datasets.base_dataset import BaseDataset
3
+ from hawk.datasets.datasets.caption_datasets import CaptionDataset
4
+ import pandas as pd
5
+ import decord
6
+ from decord import VideoReader
7
+ import random
8
+ import torch
9
+ from torch.utils.data.dataloader import default_collate
10
+ from PIL import Image
11
+ from typing import Dict, Optional, Sequence
12
+ import transformers
13
+ import pathlib
14
+ import json
15
+ from transformers import AutoTokenizer, AutoModelForCausalLM, LlamaTokenizer
16
+ from hawk.conversation.conversation_video import Conversation,SeparatorStyle
17
+ DEFAULT_IMAGE_PATCH_TOKEN = '<ImageHere>'
18
+ DEFAULT_IMAGE_TOKEN = "<image>"
19
+ import copy
20
+ from hawk.processors import transforms_video,AlproVideoTrainProcessor
21
+ IGNORE_INDEX = -100
22
+ image_conversation = Conversation(
23
+ system="",
24
+ roles=("Human", "Assistant"),
25
+ messages=[],
26
+ offset=0,
27
+ sep_style=SeparatorStyle.SINGLE,
28
+ sep="###",
29
+ )
30
+ llama_v2_image_conversation = Conversation(
31
+ system=" ",
32
+ roles=("USER", "ASSISTANT"),
33
+ messages=(),
34
+ offset=0,
35
+ sep_style=SeparatorStyle.LLAMA_2,
36
+ sep="<s>",
37
+ sep2="</s>",
38
+ )
39
+ IGNORE_INDEX = -100
40
+
41
+ class Instruct_Dataset(BaseDataset):
42
+ def __init__(self, vis_processor, text_processor, vis_root, ann_root,num_video_query_token=32,tokenizer_name = '/mnt/workspace/ckpt/vicuna-13b/',data_type = 'image', model_type='vicuna'):
43
+ """
44
+ vis_root (string): Root directory of Llava images (e.g. webvid_eval/video/)
45
+ ann_root (string): Root directory of video (e.g. webvid_eval/annotations/)
46
+ split (string): val or test
47
+ """
48
+ super().__init__(vis_processor=vis_processor, text_processor=text_processor)
49
+
50
+ data_path = pathlib.Path(ann_root)
51
+ with data_path.open(encoding='utf-8') as f:
52
+ self.annotation = json.load(f)
53
+
54
+ self.vis_root = vis_root
55
+ self.resize_size = 224
56
+ self.num_frm = 8
57
+ self.tokenizer = LlamaTokenizer.from_pretrained(tokenizer_name, use_fast=False)
58
+ self.tokenizer.pad_token = self.tokenizer.unk_token
59
+ self.tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
60
+ self.num_video_query_token = num_video_query_token
61
+ self.IMAGE_PATCH_TOKEN_ID = self.tokenizer.get_vocab()[DEFAULT_IMAGE_PATCH_TOKEN]
62
+
63
+ self.transform = AlproVideoTrainProcessor(
64
+ image_size=self.resize_size, n_frms = self.num_frm
65
+ ).transform
66
+ self.data_type = data_type
67
+ self.model_type = model_type
68
+
69
+ def _get_image_path(self, sample):
70
+ rel_video_fp ='COCO_train2014_' + sample['image']
71
+ full_video_fp = os.path.join(self.vis_root, rel_video_fp)
72
+ return full_video_fp
73
+
74
+ def __getitem__(self, index):
75
+ num_retries = 10 # skip error videos
76
+ for _ in range(num_retries):
77
+ try:
78
+ sample = self.annotation[index]
79
+
80
+ image_path = self._get_image_path(sample)
81
+ conversation_list = sample['conversations']
82
+ image = Image.open(image_path).convert("RGB")
83
+
84
+ image = self.vis_processor(image)
85
+ # text = self.text_processor(text)
86
+ sources = preprocess_multimodal(copy.deepcopy(conversation_list), None, cur_token_len=self.num_video_query_token)
87
+ if self.model_type =='vicuna':
88
+ data_dict = preprocess(
89
+ sources,
90
+ self.tokenizer)
91
+ elif self.model_type =='llama_v2':
92
+ data_dict = preprocess_for_llama_v2(
93
+ sources,
94
+ self.tokenizer)
95
+ else:
96
+ print('not support')
97
+ raise('not support')
98
+ data_dict = dict(input_ids=data_dict["input_ids"][0],
99
+ labels=data_dict["labels"][0])
100
+
101
+ # image exist in the data
102
+ data_dict['image'] = image
103
+ except:
104
+ print(f"Failed to load examples with image: {image_path}. "
105
+ f"Will randomly sample an example as a replacement.")
106
+ index = random.randint(0, len(self) - 1)
107
+ continue
108
+ break
109
+ else:
110
+ raise RuntimeError(f"Failed to fetch image after {num_retries} retries.")
111
+ # "image_id" is kept to stay compatible with the COCO evaluation format
112
+ return {
113
+ "image": image,
114
+ "text_input": data_dict["input_ids"],
115
+ "labels": data_dict["labels"],
116
+ "type":'image',
117
+ }
118
+
119
+ def __len__(self):
120
+ return len(self.annotation)
121
+
122
+ def collater(self, instances):
123
+ input_ids, labels = tuple([instance[key] for instance in instances]
124
+ for key in ("text_input", "labels"))
125
+ input_ids = torch.nn.utils.rnn.pad_sequence(
126
+ input_ids,
127
+ batch_first=True,
128
+ padding_value=self.tokenizer.pad_token_id)
129
+ labels = torch.nn.utils.rnn.pad_sequence(labels,
130
+ batch_first=True,
131
+ padding_value=IGNORE_INDEX)
132
+ batch = dict(
133
+ input_ids=input_ids,
134
+ labels=labels,
135
+ attention_mask=input_ids.ne(self.tokenizer.pad_token_id),
136
+ )
137
+
138
+ if 'image' in instances[0]:
139
+ images = [instance['image'] for instance in instances]
140
+ if all(x is not None and x.shape == images[0].shape for x in images):
141
+ batch['images'] = torch.stack(images)
142
+ else:
143
+ batch['images'] = images
144
+ batch['conv_type'] = 'multi'
145
+ return batch
146
+
147
+
148
+ def preprocess_multimodal(
149
+ conversation_list: Sequence[str],
150
+ multimodal_cfg: dict,
151
+ cur_token_len: int,
152
+ ) -> Dict:
153
+ # 将conversational list中
154
+ is_multimodal = True
155
+ # image_token_len = multimodal_cfg['image_token_len']
156
+ image_token_len = cur_token_len
157
+
158
+ for sentence in conversation_list:
159
+ replace_token = '<Image>'+DEFAULT_IMAGE_PATCH_TOKEN * image_token_len+'</Image>'
160
+ sentence["value"] = sentence["value"].replace(DEFAULT_IMAGE_TOKEN, replace_token)
161
+
162
+ return [conversation_list]
163
+
164
+ def _add_speaker_and_signal(header, source, get_conversation=True):
165
+ """Add speaker and start/end signal on each round."""
166
+ BEGIN_SIGNAL = "###"
167
+ END_SIGNAL = "\n"
168
+ conversation = header
169
+ for sentence in source:
170
+ from_str = sentence["from"]
171
+ if from_str.lower() == "human":
172
+ from_str = image_conversation.roles[0]
173
+ elif from_str.lower() == "gpt":
174
+ from_str = image_conversation.roles[1]
175
+ else:
176
+ from_str = 'unknown'
177
+ sentence["value"] = (BEGIN_SIGNAL + from_str + ": " +
178
+ sentence["value"] + END_SIGNAL)
179
+ if get_conversation:
180
+ conversation += sentence["value"]
181
+ conversation += BEGIN_SIGNAL
182
+ return conversation
183
+
184
+ def _tokenize_fn(strings: Sequence[str],
185
+ tokenizer: transformers.PreTrainedTokenizer) -> Dict:
186
+ """Tokenize a list of strings."""
187
+ tokenized_list = [
188
+ tokenizer(
189
+ text,
190
+ return_tensors="pt",
191
+ padding="longest",
192
+ max_length=512,
193
+ truncation=True,
194
+ ) for text in strings
195
+ ]
196
+ input_ids = labels = [
197
+ tokenized.input_ids[0] for tokenized in tokenized_list
198
+ ]
199
+ input_ids_lens = labels_lens = [
200
+ tokenized.input_ids.ne(tokenizer.pad_token_id).sum().item()
201
+ for tokenized in tokenized_list
202
+ ]
203
+ return dict(
204
+ input_ids=input_ids,
205
+ labels=labels,
206
+ input_ids_lens=input_ids_lens,
207
+ labels_lens=labels_lens,
208
+ )
209
+
210
+ def preprocess(
211
+ sources: Sequence[str],
212
+ tokenizer: transformers.PreTrainedTokenizer,
213
+ ) -> Dict:
214
+ """
215
+ Given a list of sources, each is a conversation list. This transform:
216
+ 1. Add signal '### ' at the beginning each sentence, with end signal '\n';
217
+ 2. Concatenate conversations together;
218
+ 3. Tokenize the concatenated conversation;
219
+ 4. Make a deepcopy as the target. Mask human words with IGNORE_INDEX.
220
+ """
221
+ # add end signal and concatenate together
222
+ conversations = []
223
+ for source in sources:
224
+ header = f"{image_conversation.system}\n\n"
225
+ conversation = _add_speaker_and_signal(header, source)
226
+ conversations.append(conversation)
227
+ # tokenize conversations
228
+ conversations_tokenized = _tokenize_fn(conversations, tokenizer)
229
+ input_ids = conversations_tokenized["input_ids"]
230
+ targets = copy.deepcopy(input_ids)
231
+ for target, source in zip(targets, sources):
232
+ tokenized_lens = _tokenize_fn([header] + [s["value"] for s in source],
233
+ tokenizer)["input_ids_lens"]
234
+ speakers = [sentence["from"] for sentence in source]
235
+ _mask_targets(target, tokenized_lens, speakers)
236
+
237
+ return dict(input_ids=input_ids, labels=targets)
238
+
239
+ def preprocess_for_llama_v2(
240
+ sources: Sequence[str],
241
+ tokenizer: transformers.PreTrainedTokenizer,
242
+ ) -> Dict:
243
+ """
244
+ Given a list of sources, each is a conversation list. This transform:
245
+ 1. Add signal '### ' at the beginning each sentence, with end signal '\n';
246
+ 2. Concatenate conversations together;
247
+ 3. Tokenize the concatenated conversation;
248
+ 4. Make a deepcopy as the target. Mask human words with IGNORE_INDEX.
249
+ """
250
+ # add end signal and concatenate together
251
+ conversations = []
252
+ conv = copy.deepcopy(llama_v2_image_conversation.copy())
253
+ roles = {"human": conv.roles[0], "gpt": conv.roles[1]}
254
+ for source in sources:
255
+ # <s>[INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n\n
256
+ header = f"<s>[INST] <<SYS>>\n{conv.system}\n</SYS>>\n\n"
257
+
258
+ if roles[source[0]["from"]] != conv.roles[0]:
259
+ # Skip the first one if it is not from human
260
+ source = source[1:]
261
+ conv.messages = []
262
+ for j, sentence in enumerate(source):
263
+ role = roles[sentence["from"]]
264
+ assert role == conv.roles[j % 2]
265
+ conv.append_message(role, sentence["value"])
266
+ conversations.append(conv.get_prompt())
267
+
268
+ input_ids = tokenizer(
269
+ conversations,
270
+ return_tensors="pt",
271
+ padding="longest",
272
+ max_length=512,
273
+ truncation=True,
274
+ ).input_ids
275
+ targets = copy.deepcopy(input_ids)
276
+
277
+
278
+ sep = "[/INST] "
279
+ for conversation, target in zip(conversations, targets):
280
+ # total_len = int(target.ne(tokenizer.pad_token_id).sum())
281
+ rounds = conversation.split(conv.sep2)
282
+ cur_len = 1
283
+ target[:cur_len] = IGNORE_INDEX
284
+ for i, rou in enumerate(rounds):
285
+ if rou == "":
286
+ break
287
+
288
+ parts = rou.split(sep)
289
+ if len(parts) != 2:
290
+ break
291
+ parts[0] += sep
292
+
293
+
294
+ round_len = len(tokenizer(rou).input_ids)
295
+ instruction_len = len(tokenizer(parts[0]).input_ids) - 2 # 为什么减去2,speical token 的数目
296
+
297
+ target[cur_len : cur_len + instruction_len] = IGNORE_INDEX
298
+
299
+ cur_len += round_len
300
+ target[cur_len:] = IGNORE_INDEX
301
+
302
+ return dict(input_ids=input_ids, labels=targets)
303
+
304
+ def _mask_targets(target, tokenized_lens, speakers):
305
+ # cur_idx = 0
306
+ cur_idx = tokenized_lens[0]
307
+ tokenized_lens = tokenized_lens[1:]
308
+ target[:cur_idx] = IGNORE_INDEX
309
+ for tokenized_len, speaker in zip(tokenized_lens, speakers):
310
+ if speaker == "human":
311
+ target[cur_idx+2:cur_idx + tokenized_len] = IGNORE_INDEX
312
+ cur_idx += tokenized_len
hawk/datasets/datasets/video_instruct_dataset.py ADDED
@@ -0,0 +1,426 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from hawk.datasets.datasets.base_dataset import BaseDataset
3
+ from hawk.datasets.datasets.caption_datasets import CaptionDataset
4
+ import pandas as pd
5
+ import decord
6
+ from decord import VideoReader
7
+ import random
8
+ import torch
9
+ from torch.utils.data.dataloader import default_collate
10
+ from PIL import Image
11
+ from typing import Dict, Optional, Sequence
12
+ import transformers
13
+ import pathlib
14
+ import json
15
+ from transformers import AutoTokenizer, AutoModelForCausalLM, LlamaTokenizer
16
+ import copy
17
+ from hawk.processors import transforms_video,AlproVideoTrainProcessor
18
+ from torchvision import transforms
19
+ from hawk.processors.video_processor import ToTHWC,ToUint8,load_video,load_video_motion
20
+ from hawk.conversation.conversation_video import Conversation,SeparatorStyle
21
+ import numpy as np
22
+
23
+ #提取Motion+Entity
24
+ import spacy
25
+
26
+ # 加载SpaCy英文模型
27
+ nlp = spacy.load("en_core_web_sm")
28
+
29
+ # Define the list of questions
30
+ Question = [
31
+ "Can you describe the anomaly in the video?",
32
+ "How would you detail the anomaly found in the video?",
33
+ "What anomaly can you identify in the video?",
34
+ "Could you explain the anomaly observed in the video?",
35
+ "Can you point out the anomaly in the video?",
36
+ "What's the anomaly depicted in the video?",
37
+ "Could you specify the anomaly present in the video?",
38
+ "How do you perceive the anomaly in the video?",
39
+ "Can you highlight the anomaly within the video?",
40
+ "What anomaly is noticeable in the video?",
41
+ "Could you characterize the anomaly seen in the video?",
42
+ "Can you detail the specific anomaly encountered in the video?",
43
+ "How would you describe the particular anomaly in the video?",
44
+ "What details can you provide about the anomaly in the video?",
45
+ "Could you elucidate on the anomaly detected in the video?",
46
+ "Can you illustrate the nature of the anomaly in the video?",
47
+ "What features of the anomaly in the video can you describe?",
48
+ "Could you outline the anomaly observed in the video?",
49
+ "How does the anomaly in the video manifest?",
50
+ "Can you clarify the aspects of the anomaly in the video?"
51
+ ]
52
+
53
+
54
+ def setup_seed(seed):
55
+ torch.manual_seed(seed)
56
+ torch.cuda.manual_seed_all(seed)
57
+ np.random.seed(seed)
58
+ random.seed(seed)
59
+ torch.backends.cudnn.deterministic = True
60
+
61
+ def extract_actions_and_entities_sentence(sentence):
62
+ doc = nlp(sentence)
63
+ action_sentences = []
64
+
65
+ for token in doc:
66
+ # 检查是否为动词
67
+ if token.pos_ == "VERB":
68
+ subjects = ' and '.join(child.text for child in token.children if child.dep_ in ["nsubj", "nsubjpass"]) #主语
69
+ objects = ' and '.join(child.text for child in token.children if child.dep_ in ["dobj", "pobj", "obj"]) #宾语
70
+
71
+ # 构建包含动作和实体的句子
72
+ action_sentence = f"{subjects} {token.text} {objects}".strip()
73
+ action_sentences.append(action_sentence)
74
+
75
+ return ', '.join(action_sentences)
76
+
77
+
78
+ DEFAULT_IMAGE_PATCH_TOKEN = '<ImageHere>'
79
+ video_conversation = Conversation(
80
+ system="",
81
+ roles=("Human", "Assistant"),
82
+ messages=[],
83
+ offset=0,
84
+ sep_style=SeparatorStyle.SINGLE,
85
+ sep="###",
86
+ )
87
+ llama_v2_video_conversation = Conversation(
88
+ system=" ",
89
+ roles=("USER", "ASSISTANT"),
90
+ messages=(),
91
+ offset=0,
92
+ sep_style=SeparatorStyle.LLAMA_2,
93
+ sep="<s>",
94
+ sep2="</s>",
95
+ )
96
+ IGNORE_INDEX = -100
97
+
98
+ class Video_Instruct_Dataset(BaseDataset):
99
+ def __init__(self, vis_processor, text_processor, vis_root, ann_root,num_video_query_token=32,tokenizer_name = '/mnt/workspace/ckpt/vicuna-13b/',data_type = 'video', model_type='vicuna'):
100
+ """
101
+ vis_root (string): Root directory of Llava images (e.g. webvid_eval/video/)
102
+ ann_root (string): Root directory of video (e.g. webvid_eval/annotations/)
103
+ split (string): val or test
104
+ """
105
+ super().__init__(vis_processor=vis_processor, text_processor=text_processor)
106
+
107
+ data_path = pathlib.Path(ann_root)
108
+ with data_path.open(encoding='utf-8') as f:
109
+ self.annotation = json.load(f)
110
+
111
+ self.num_video_query_token = num_video_query_token
112
+ self.vis_root = vis_root
113
+ self.resize_size = 224
114
+ self.num_frm = 32
115
+ self.tokenizer = LlamaTokenizer.from_pretrained(tokenizer_name, use_fast=False)
116
+ self.tokenizer.pad_token = self.tokenizer.unk_token
117
+ self.tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
118
+ self.IMAGE_PATCH_TOKEN_ID = self.tokenizer.get_vocab()[DEFAULT_IMAGE_PATCH_TOKEN]
119
+
120
+ self.transform = AlproVideoTrainProcessor(
121
+ image_size=self.resize_size, n_frms = self.num_frm
122
+ ).transform
123
+ self.data_type = data_type
124
+ self.model_type = model_type
125
+
126
+ def _get_video_path(self, sample):
127
+ rel_video_fp = sample['video']
128
+ full_video_fp = os.path.join(self.vis_root, rel_video_fp)
129
+ return full_video_fp
130
+
131
+ def __getitem__(self, index):
132
+ num_retries = 10 # skip error videos
133
+ for _ in range(num_retries):
134
+ try:
135
+ sample = self.annotation[index]
136
+
137
+ video_path = self._get_video_path(sample)
138
+ # print(video_path)
139
+ conversation_list = sample['QA']
140
+
141
+ #替换为GPT的回答
142
+ conversation_answer = sample['description']
143
+
144
+ #提取Language Motion
145
+ # conversation_answer = extract_actions_and_entities_sentence(conversation_answer)
146
+
147
+ random_number = random.choice([0, 1])
148
+ if random_number == 1:
149
+ conversation_list[0]["q"] = random.choice(Question)
150
+ conversation_list[0]["a"] = conversation_answer
151
+
152
+ video, msg = load_video(
153
+ video_path=video_path,
154
+ n_frms=self.num_frm,
155
+ height=self.resize_size,
156
+ width=self.resize_size,
157
+ sampling ="uniform", return_msg = True
158
+ )
159
+ #读入动作视频
160
+ video_motion, msg_motion = load_video_motion(
161
+ video_path=video_path,
162
+ n_frms=self.num_frm,
163
+ height=self.resize_size,
164
+ width=self.resize_size,
165
+ sampling ="uniform", return_msg = True
166
+ )
167
+
168
+ random_seed = random.randint(0, 2**32 - 1)
169
+ setup_seed(random_seed)
170
+ video = self.transform(video)
171
+ video_motion = self.transform(video_motion)
172
+
173
+ if 'cn' in self.data_type:
174
+ msg = ""
175
+ # 添加视频<DEFAULT_IMAGE_PATCH_TOKEN>,以及msg到convsation list 0
176
+ sources = preprocess_multimodal(copy.deepcopy(conversation_list), None, cur_token_len=self.num_video_query_token,msg = msg)
177
+ new_sources = convert_source_vicuna_format(sources)
178
+
179
+ if self.model_type =='vicuna':
180
+ data_dict = preprocess(
181
+ new_sources,
182
+ self.tokenizer)
183
+ elif self.model_type =='llama_v2':
184
+ data_dict = preprocess_for_llama_v2(
185
+ new_sources,
186
+ self.tokenizer)
187
+ else:
188
+ print('not support')
189
+ raise('not support')
190
+ data_dict = dict(input_ids=data_dict["input_ids"][0],
191
+ labels=data_dict["labels"][0])
192
+ # image exist in the data
193
+ data_dict['image'] = video
194
+ data_dict['image_motion'] = video_motion
195
+ except:
196
+ print(f"Failed to load examples with video: {video_path}. "
197
+ f"Will randomly sample an example as a replacement.")
198
+ index = random.randint(0, len(self) - 1)
199
+ continue
200
+ break
201
+ else:
202
+ raise RuntimeError(f"Failed to fetch video after {num_retries} retries.")
203
+ # "image_id" is kept to stay compatible with the COCO evaluation format
204
+ return {
205
+ "image": video,
206
+ "image_motion": video_motion,
207
+ "text_input": data_dict["input_ids"],
208
+ "labels": data_dict["labels"],
209
+ "type":'video',
210
+ }
211
+
212
+ def __len__(self):
213
+ return len(self.annotation)
214
+
215
+ def collater(self, instances):
216
+ input_ids, labels = tuple([instance[key] for instance in instances]
217
+ for key in ("text_input", "labels"))
218
+ input_ids = torch.nn.utils.rnn.pad_sequence(
219
+ input_ids,
220
+ batch_first=True,
221
+ padding_value=self.tokenizer.pad_token_id) # 该函数用于将这些列表中的张量填充到相同的长度。这里使用了batch_first=True参数来指定批次维度的位置,以便在后续计算中更容易处理。填充值是self.tokenizer.pad_token_id,它是用于填充输入序列的特殊标记。
222
+ labels = torch.nn.utils.rnn.pad_sequence(labels,
223
+ batch_first=True,
224
+ padding_value=IGNORE_INDEX) #
225
+ batch = dict(
226
+ input_ids=input_ids,
227
+ labels=labels,
228
+ attention_mask=input_ids.ne(self.tokenizer.pad_token_id), #input_ids.ne方法,它返回一个布尔张量,指示输入张量中哪些元素不等于指定值。
229
+ )
230
+
231
+ if 'image' in instances[0]:
232
+ images = [instance['image'] for instance in instances]
233
+ if all(x is not None and x.shape == images[0].shape for x in images):
234
+ batch['images'] = torch.stack(images)
235
+ else:
236
+ batch['images'] = images
237
+
238
+ if 'image_motion' in instances[0]:
239
+ images_motion = [instance['image_motion'] for instance in instances]
240
+ if all(x is not None and x.shape == images_motion[0].shape for x in images_motion):
241
+ batch['images_motion'] = torch.stack(images_motion)
242
+ else:
243
+ batch['images_motion'] = images_motion
244
+
245
+ batch['conv_type'] = 'multi'
246
+ return batch
247
+
248
+ def convert_source_vicuna_format(sources):
249
+ new_sources = []
250
+ for source in sources:
251
+ new_source = []
252
+ for i, sentence in enumerate(source):
253
+ role_0_msg = sentence['q']
254
+ role_1_msg = sentence['a']
255
+ new_source.append({
256
+ 'from':'human',
257
+ 'value': role_0_msg,
258
+ })
259
+ new_source.append({
260
+ 'from':'gpt',
261
+ 'value': role_1_msg,
262
+ })
263
+ new_sources.append(new_source)
264
+ return new_sources
265
+
266
+ def preprocess_multimodal(
267
+ conversation_list: Sequence[str],
268
+ multimodal_cfg: dict,
269
+ cur_token_len: int,
270
+ msg=''
271
+ ) -> Dict:
272
+ # 将conversational list中
273
+ is_multimodal = True
274
+ # image_token_len = multimodal_cfg['image_token_len']
275
+ image_token_len = cur_token_len * 2
276
+ conversation_list[0]["q"] = "<Video>"+DEFAULT_IMAGE_PATCH_TOKEN * image_token_len +"</Video> " + msg + conversation_list[0]["q"]
277
+ return [conversation_list]
278
+
279
+ def _add_speaker_and_signal(header, source, get_conversation=True):
280
+ """Add speaker and start/end signal on each round."""
281
+ BEGIN_SIGNAL = "###"
282
+ END_SIGNAL = "\n"
283
+ conversation = header
284
+ for sentence in source:
285
+ from_str = sentence["from"]
286
+ if from_str.lower() == "human":
287
+ from_str = video_conversation.roles[0]
288
+ elif from_str.lower() == "gpt":
289
+ from_str = video_conversation.roles[1]
290
+ else:
291
+ from_str = 'unknown'
292
+ sentence["value"] = (BEGIN_SIGNAL + from_str + ": " +
293
+ sentence["value"] + END_SIGNAL)
294
+ if get_conversation:
295
+ conversation += sentence["value"]
296
+ conversation += BEGIN_SIGNAL
297
+ return conversation
298
+
299
+ def _tokenize_fn(strings: Sequence[str],
300
+ tokenizer: transformers.PreTrainedTokenizer) -> Dict:
301
+ """Tokenize a list of strings."""
302
+ tokenized_list = [
303
+ tokenizer(
304
+ text,
305
+ return_tensors="pt",
306
+ padding="longest",
307
+ max_length=512,
308
+ truncation=True,
309
+ ) for text in strings
310
+ ]
311
+ input_ids = labels = [
312
+ tokenized.input_ids[0] for tokenized in tokenized_list
313
+ ]
314
+ input_ids_lens = labels_lens = [
315
+ tokenized.input_ids.ne(tokenizer.pad_token_id).sum().item()
316
+ for tokenized in tokenized_list
317
+ ]
318
+ return dict(
319
+ input_ids=input_ids,
320
+ labels=labels,
321
+ input_ids_lens=input_ids_lens,
322
+ labels_lens=labels_lens,
323
+ )
324
+
325
+ def preprocess(
326
+ sources: Sequence[str],
327
+ tokenizer: transformers.PreTrainedTokenizer,
328
+ ) -> Dict:
329
+ """
330
+ Given a list of sources, each is a conversation list. This transform:
331
+ 1. Add signal '### ' at the beginning each sentence, with end signal '\n';
332
+ 2. Concatenate conversations together;
333
+ 3. Tokenize the concatenated conversation;
334
+ 4. Make a deepcopy as the target. Mask human words with IGNORE_INDEX.
335
+ """
336
+ # add end signal and concatenate together
337
+ conversations = []
338
+ for source in sources:
339
+ header = f"{video_conversation.system}\n\n"
340
+ conversation = _add_speaker_and_signal(header, source)
341
+ conversations.append(conversation)
342
+ # tokenize conversations
343
+ conversations_tokenized = _tokenize_fn(conversations, tokenizer)
344
+ input_ids = conversations_tokenized["input_ids"]
345
+ targets = copy.deepcopy(input_ids)
346
+ for target, source in zip(targets, sources):
347
+ tokenized_lens = _tokenize_fn([header] + [s["value"] for s in source],
348
+ tokenizer)["input_ids_lens"]
349
+ speakers = [sentence["from"] for sentence in source]
350
+ _mask_targets(target, tokenized_lens, speakers)
351
+
352
+ return dict(input_ids=input_ids, labels=targets)
353
+
354
+ def preprocess_for_llama_v2(
355
+ sources: Sequence[str],
356
+ tokenizer: transformers.PreTrainedTokenizer,
357
+ ) -> Dict:
358
+ """
359
+ Given a list of sources, each is a conversation list. This transform:
360
+ 1. Add signal '### ' at the beginning each sentence, with end signal '\n';
361
+ 2. Concatenate conversations together;
362
+ 3. Tokenize the concatenated conversation;
363
+ 4. Make a deepcopy as the target. Mask human words with IGNORE_INDEX.
364
+ """
365
+ # add end signal and concatenate together
366
+ conversations = []
367
+ conv = copy.deepcopy(llama_v2_video_conversation.copy())
368
+ roles = {"human": conv.roles[0], "gpt": conv.roles[1]}
369
+ for source in sources:
370
+ # <s>[INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n\n
371
+ header = f"<s>[INST] <<SYS>>\n{conv.system}\n</SYS>>\n\n"
372
+
373
+ if roles[source[0]["from"]] != conv.roles[0]:
374
+ # Skip the first one if it is not from human
375
+ source = source[1:]
376
+ conv.messages = []
377
+ for j, sentence in enumerate(source):
378
+ role = roles[sentence["from"]]
379
+ assert role == conv.roles[j % 2]
380
+ conv.append_message(role, sentence["value"])
381
+ conversations.append(conv.get_prompt())
382
+
383
+ input_ids = tokenizer(
384
+ conversations,
385
+ return_tensors="pt",
386
+ padding="longest",
387
+ max_length=512,
388
+ truncation=True,
389
+ ).input_ids
390
+ targets = copy.deepcopy(input_ids)
391
+
392
+
393
+ sep = "[/INST] "
394
+ for conversation, target in zip(conversations, targets):
395
+ # total_len = int(target.ne(tokenizer.pad_token_id).sum())
396
+ rounds = conversation.split(conv.sep2)
397
+ cur_len = 1
398
+ target[:cur_len] = IGNORE_INDEX
399
+ for i, rou in enumerate(rounds):
400
+ if rou == "":
401
+ break
402
+
403
+ parts = rou.split(sep)
404
+ if len(parts) != 2:
405
+ break
406
+ parts[0] += sep
407
+
408
+
409
+ round_len = len(tokenizer(rou).input_ids)
410
+ instruction_len = len(tokenizer(parts[0]).input_ids) - 2 # 为什么减去2,speical token 的数目
411
+
412
+ target[cur_len : cur_len + instruction_len] = IGNORE_INDEX
413
+
414
+ cur_len += round_len
415
+ target[cur_len:] = IGNORE_INDEX
416
+
417
+ return dict(input_ids=input_ids, labels=targets)
418
+ def _mask_targets(target, tokenized_lens, speakers):
419
+ # cur_idx = 0
420
+ cur_idx = tokenized_lens[0]
421
+ tokenized_lens = tokenized_lens[1:]
422
+ target[:cur_idx] = IGNORE_INDEX
423
+ for tokenized_len, speaker in zip(tokenized_lens, speakers):
424
+ if speaker == "human":
425
+ target[cur_idx+2:cur_idx + tokenized_len] = IGNORE_INDEX
426
+ cur_idx += tokenized_len
hawk/datasets/datasets/webvid_datasets.py ADDED
@@ -0,0 +1,173 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Copyright (c) 2022, salesforce.com, inc.
3
+ All rights reserved.
4
+ SPDX-License-Identifier: BSD-3-Clause
5
+ For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ """
7
+
8
+ import os
9
+ from hawk.datasets.datasets.base_dataset import BaseDataset
10
+ from hawk.datasets.datasets.caption_datasets import CaptionDataset
11
+ import pandas as pd
12
+ import decord
13
+ from decord import VideoReader
14
+ import random
15
+ import torch
16
+ from torch.utils.data.dataloader import default_collate
17
+ import spacy
18
+ import numpy as np
19
+
20
+ def setup_seed(seed):
21
+ torch.manual_seed(seed)
22
+ torch.cuda.manual_seed_all(seed)
23
+ np.random.seed(seed)
24
+ random.seed(seed)
25
+ torch.backends.cudnn.deterministic = True
26
+
27
+ # 加载SpaCy英文模型
28
+ nlp = spacy.load("en_core_web_sm")
29
+
30
+ def extract_actions_and_entities_sentence(sentence):
31
+ doc = nlp(sentence)
32
+ action_sentences = []
33
+
34
+ for token in doc:
35
+ # 检查是否为动词
36
+ if token.pos_ == "VERB":
37
+ subjects = ' and '.join(child.text for child in token.children if child.dep_ in ["nsubj", "nsubjpass"]) #主语
38
+ objects = ' and '.join(child.text for child in token.children if child.dep_ in ["dobj", "pobj", "obj"]) #宾语
39
+
40
+ # 构建包含动作和实体的句子
41
+ action_sentence = f"{subjects} {token.text} {objects}".strip()
42
+ action_sentences.append(action_sentence)
43
+
44
+ return ', '.join(action_sentences)
45
+
46
+ class WebvidDataset(BaseDataset):
47
+ def __init__(self, vis_processor, text_processor, vis_root, ann_root):
48
+ """
49
+ vis_root (string): Root directory of video (e.g. webvid_eval/video/)
50
+ ann_root (string): Root directory of video (e.g. webvid_eval/annotations/)
51
+ split (string): val or test
52
+ """
53
+ super().__init__(vis_processor=vis_processor, text_processor=text_processor)
54
+
55
+
56
+ # 读取一个路径下所有的
57
+ ts_df = []
58
+ for file_name in os.listdir(ann_root):
59
+ if file_name.endswith('.csv'):
60
+ df = pd.read_csv(os.path.join(ann_root, file_name))
61
+ ts_df.append(df)
62
+
63
+ print(ts_df)
64
+ merged_df = pd.concat(ts_df)
65
+
66
+ self.annotation = merged_df
67
+ self.vis_root = vis_root
68
+ self.resize_size = 224
69
+ self.num_frm = 32
70
+ self.frm_sampling_strategy = 'headtail'
71
+
72
+ def _get_video_path(self, sample):
73
+ rel_video_fp = os.path.join(str(sample['page_dir']), str(sample['videoid']) + '.mp4')
74
+ full_video_fp = os.path.join(self.vis_root, rel_video_fp)
75
+ return full_video_fp
76
+
77
+ def __getitem__(self, index):
78
+ num_retries = 10 # skip error videos
79
+ for _ in range(num_retries):
80
+
81
+ sample = self.annotation.iloc[index]
82
+ sample_dict = sample.to_dict()
83
+ # video_id = sample_dict['videoid']
84
+ # fetch video
85
+ video_path = self._get_video_path(sample_dict)
86
+
87
+ # while not os.path.exists(video_path):
88
+ # index = random.randint(0, len(self.annotation) - 1)
89
+ # sample = self.annotation.iloc[index]
90
+ # sample_dict = sample.to_dict()
91
+ # video_path = self._get_video_path(sample_dict)
92
+
93
+ while not os.path.exists(video_path) or (os.path.exists(video_path) and os.path.getsize(video_path) == 0):
94
+ index = random.randint(0, len(self.annotation) - 1)
95
+ sample = self.annotation.iloc[index]
96
+ sample_dict = sample.to_dict()
97
+ video_path = self._get_video_path(sample_dict)
98
+
99
+ if 'name' in sample_dict.keys():
100
+ text = sample_dict['name'].strip()
101
+ text_motion = extract_actions_and_entities_sentence(text)
102
+ else:
103
+ raise NotImplementedError("Un-supported text annotation format.")
104
+
105
+ # if os.path.exists(video_path):
106
+ try:
107
+ random_seed = random.randint(0, 2**32 - 1)
108
+ setup_seed(random_seed)
109
+ video, video_motion = self.vis_processor(video_path)
110
+ except:
111
+ print(f"for A Failed to load examples with video: {video_path}. "
112
+ f"Will randomly sample an example as a replacement.")
113
+ index = random.randint(0, len(self) - 1)
114
+ continue
115
+
116
+ # text = extract_actions_and_entities_sentence(text)
117
+ caption = self.text_processor(text)
118
+ caption_motion = self.text_processor(text_motion)
119
+
120
+ # print(video.size())
121
+ if video is None or caption is None or video.size()!=torch.Size([3,self.vis_processor.n_frms,224,224]):
122
+ print(f"for B Failed to load examples with video: {video_path}. "
123
+ f"Will randomly sample an example as a replacement.")
124
+ index = random.randint(0, len(self) - 1)
125
+ continue
126
+ else:
127
+ break
128
+ else:
129
+ raise RuntimeError(f"Failed to fetch video after {num_retries} retries.")
130
+ # "image_id" is kept to stay compatible with the COCO evaluation format
131
+ return {
132
+ "image": video, #torch.Size([3, 32, 224, 224])
133
+ "image_motion": video_motion, #torch.Size([3, 32, 224, 224])
134
+ "text_input": caption,
135
+ "text_input_motion": caption_motion,
136
+ "type":'video',
137
+ }
138
+
139
+ def __len__(self):
140
+ return len(self.annotation)
141
+
142
+ # def collater(self, samples):
143
+ # new_result = {}
144
+ # new_result['image'] = default_collate( [sample["image"] for sample in samples])
145
+ # new_result['image_motion'] = default_collate( [sample["image_motion"] for sample in samples])
146
+ # new_result['text_input'] = default_collate( [sample["text_input"] for sample in samples])
147
+ # return new_result
148
+
149
+ class WebvidDatasetEvalDataset(BaseDataset):
150
+ def __init__(self, vis_processor, text_processor, vis_root, ann_paths):
151
+ """
152
+ vis_root (string): Root directory of images (e.g. coco/images/)
153
+ ann_root (string): directory to store the annotation file
154
+ split (string): val or test
155
+ """
156
+ super().__init__(vis_processor, text_processor, vis_root, ann_paths)
157
+
158
+ def __getitem__(self, index):
159
+
160
+ ann = self.annotation[index]
161
+
162
+ vname = ann["video"]
163
+ video_path = os.path.join(self.vis_root, vname)
164
+
165
+ video = self.vis_processor(video_path)
166
+
167
+ return {
168
+ "video": video,
169
+ "image_id": ann["image_id"],
170
+ "instance_id": ann["instance_id"],
171
+ }
172
+
173
+
hawk/models/ImageBind/.assets/bird_audio.wav ADDED
Binary file (882 kB). View file
 
hawk/models/ImageBind/.assets/bird_image.jpg ADDED
hawk/models/ImageBind/.assets/car_audio.wav ADDED
Binary file (441 kB). View file
 
hawk/models/ImageBind/.assets/car_image.jpg ADDED
hawk/models/ImageBind/.assets/dog_audio.wav ADDED
Binary file (461 kB). View file
 
hawk/models/ImageBind/.assets/dog_image.jpg ADDED