HillFir commited on
Commit
668a98f
·
verified ·
1 Parent(s): dead8fc

Upload folder using huggingface_hub

Browse files
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ overview.png filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,112 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ license: mit
3
+ tags:
4
+ - RLinf
5
+ language:
6
+ - en
7
+ metrics:
8
+ - accuracy
9
+ base_model:
10
+ - Haozhan72/Openvla-oft-SFT-libero-goal-trajall
11
+ pipeline_tag: reinforcement-learning
12
+ model-index:
13
+ - name: RLinf-openvlaoft-maniskill3-ppo
14
+ results:
15
+ - task:
16
+ type: VLA
17
+ dataset:
18
+ type: maniskill-vision
19
+ name: maniskill-vision
20
+ metrics:
21
+ - type: accuracy
22
+ value: 80.5
23
+ - task:
24
+ type: VLA
25
+ dataset:
26
+ type: maniskill-semantic
27
+ name: maniskill-semantic
28
+ metrics:
29
+ - type: accuracy
30
+ value: 56.6
31
+ - task:
32
+ type: VLA
33
+ dataset:
34
+ type: maniskill-position
35
+ name: maniskill-position
36
+ metrics:
37
+ - type: accuracy
38
+ value: 56.1
39
+ ---
40
+
41
+ <div align="center">
42
+ <img src="logo.svg" alt="RLinf-logo" width="500"/>
43
+ </div>
44
+
45
+
46
+ <div align="center">
47
+ <!-- <a href="TODO"><img src="https://img.shields.io/badge/arXiv-Paper-red?logo=arxiv"></a> -->
48
+ <!-- <a href="TODO"><img src="https://img.shields.io/badge/HuggingFace-yellow?logo=huggingface&logoColor=white" alt="Hugging Face"></a> -->
49
+ <a href="https://github.com/RLinf/RLinf"><img src="https://img.shields.io/badge/Github-blue"></a>
50
+ <a href="https://rlinf.readthedocs.io/en/latest/"><img src="https://img.shields.io/badge/Documentation-Purple?color=8A2BE2&logo=readthedocs"></a>
51
+ <!-- <a href="TODO"><img src="https://devin.ai/assets/deepwiki-badge.png" alt="Ask DeepWiki.com" style="height:20px;"></a>
52
+ <a href="TODO"><img src="https://img.shields.io/badge/微信-green?logo=wechat&amp"></a> -->
53
+ </div>
54
+
55
+ <h1 align="center">RLinf: Reinforcement Learning Infrastructure for Agentic AI</h1>
56
+
57
+ [RLinf](https://github.com/RLinf/RLinf) is a flexible and scalable open-source infrastructure designed for post-training foundation models (LLMs, VLMs, VLAs) via reinforcement learning. The 'inf' in RLinf stands for Infrastructure, highlighting its role as a robust backbone for next-generation training. It also stands for Infinite, symbolizing the system’s support for open-ended learning, continuous generalization, and limitless possibilities in intelligence development.
58
+
59
+
60
+ <div align="center">
61
+ <img src="overview.png" alt="RLinf-overview" width="600"/>
62
+ </div>
63
+
64
+ ## Model Description
65
+ This openvla-oft model is trained on ``Haozhan72/Openvla-oft-SFT-libero10-trajall`` with an additional lora SFT checkpoint and finetuned by Proximal Policy Optimization (PPO) on the ManiSkill simulator.
66
+
67
+ ## Full OOD Evaluation and Results
68
+ ### Overall OOD Eval Results
69
+ Note: rl4vla refers to the paper VLA-RL-Study: What Can RL Bring to VLA Generalization? An Empirical Study.
70
+ | Description | rl4vla | GRPO-openvlaoft | __PPO-openvlaoft__ | PPO-openvla | GRPO-openvla |
71
+ |---------------|-----------|-----------------|----------------|-------------|---------------|
72
+ | Avg results | 0.7608 | 0.61484375 | 0.6453125 | **0.822135417** | 0.7546875 |
73
+ ### OOD Eval on Vision
74
+
75
+ | Description | rl4vla | GRPO-openvlaoft | __PPO-openvlaoft__ | PPO-openvla | GRPO-openvla |
76
+ |---------------|-----------|-----------------|----------------|-------------|---------------|
77
+ | vision avg | 0.7656 | 0.846875 | 0.80546875 | **0.8203125** | 0.746875 |
78
+ | unseen table | 0.844 | 0.9140625 | 0.9453125 | **0.95703125** | 0.8984375 |
79
+ | dynamic texture (weak) | 0.833 | **0.91015625** | 0.82421875 | 0.85546875 | 0.7890625 |
80
+ | dynamic texture (strong) | 0.63 | **0.7734375** | 0.625 | 0.72265625 | 0.65625 |
81
+ | dynamic noise (weak) | 0.854 | 0.89453125 | **0.8984375** | 0.87109375 | 0.796875|
82
+ | dynamic noise (strong) | 0.667 | **0.7421875** | 0.734375 | 0.6953125 | 0.59375|
83
+
84
+ ### OOD Eval on Semantic
85
+ | Description | rl4vla | GRPO-openvlaoft | __PPO-openvlaoft__ | PPO-openvla | GRPO-openvla |
86
+ |---------------|-----------|-----------------|----------------|-------------|---------------|
87
+ | object avg | 0.754 | 0.516113281 | 0.56640625 | **0.805664063** | 0.744140625|
88
+ | train setting | 0.938 | 0.94140625 | 0.91796875 | **0.9609375** | 0.84375|
89
+ | unseen objects | 0.714 | 0.8046875 | 0.77734375 | **0.81640625** | 0.765625|
90
+ | unseen receptacles | 0.75 | 0.7421875 | 0.78125 | **0.8125** | 0.734375|
91
+ | unseen instructions | 0.891 | 0.6796875 | 0.68359375 | **0.9453125** | 0.890625|
92
+ | multi-object (both seen) | 0.75 | 0.3515625 | 0.4296875 | **0.84375** | 0.7578125|
93
+ | multi-object (both unseen) | 0.578 | 0.3046875 | 0.38671875 | **0.62890625** | 0.578125|
94
+ | distractive receptacle | 0.812 | 0.1875 | 0.31640625 | **0.828125** | 0.78125|
95
+ | multi-receptacle (both unseen) | 0.599 | 0.1171875 | 0.23828125 | **0.609375** | 0.6015625|
96
+
97
+ ### OOD Eval on Position
98
+ | Description | rl4vla | GRPO-openvlaoft | __PPO-openvlaoft__ | PPO-openvla | GRPO-openvla |
99
+ |---------------|-----------|-----------------|----------------|-------------|---------------|
100
+ | position avg | 0.776 | 0.4296875 | 0.560546875 | **0.892578125** | 0.81640625|
101
+ | unseen position (object & receptacle) | 0.807 | 0.40234375 | 0.50390625 | **0.86328125** | 0.75|
102
+ | mid-episode object reposition | 0.745 | 0.45703125 | 0.6171875 | **0.921875** | 0.8828125|
103
+
104
+ ## How to Use
105
+ Please integrate the provided model with the [RLinf](https://github.com/RLinf/RLinf) codebase. To do so, modify the following parameters in the configuration file ``examples/embodiment/config/maniskill_ppo_openvlaoft.yaml``:
106
+
107
+ - Set ``actor.checkpoint_load_path``, ``actor.tokenizer.tokenizer_model``, and ``rollout.model_dir`` to the path of the model checkpoint.
108
+
109
+ Note: If you intend to evaluate the model directly, make sure to set ``actor.model.is_lora`` to ``false``.
110
+
111
+ ## License
112
+ This code repository and the model weights are licensed under the MIT License.
added_tokens.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ {
2
+ "<PAD>": 32000
3
+ }
config.json ADDED
@@ -0,0 +1,3322 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_name_or_path": "Haozhan72/Openvla-oft-SFT-libero10-trajall",
3
+ "action_dim": 7,
4
+ "add_bias_linear": false,
5
+ "add_qkv_bias": true,
6
+ "arch_specifier": "no-align+fused-gelu-mlp",
7
+ "architectures": [
8
+ "OpenVLAOFTForRLActionPrediction"
9
+ ],
10
+ "attn_implementation": "flash_attention_2",
11
+ "auto_map": {
12
+ "AutoConfig": "configuration_prismatic.OpenVLAConfig",
13
+ "AutoModelForVision2Seq": "modeling_prismatic.OpenVLAForActionPrediction"
14
+ },
15
+ "center_crop": true,
16
+ "ckpt_path": "/mnt/public/hongzhi/rlinf-0827/logs/20250827-10:10:18/checkpoints/global_step_800/actor/model.pt",
17
+ "hf_llm_id": "meta-llama/Llama-2-7b-hf",
18
+ "hidden_size": 4096,
19
+ "image_resize_strategy": "resize-naive",
20
+ "image_sizes": [
21
+ 224,
22
+ 224
23
+ ],
24
+ "is_lora": true,
25
+ "llm_backbone_id": "llama2-7b-pure",
26
+ "llm_max_length": 2048,
27
+ "lora_path": "/mnt/public/hongzhi/models/oft-sft/lora_004000",
28
+ "lora_rank": 32,
29
+ "low_cpu_mem_usage": true,
30
+ "model_name": "openvla_oft",
31
+ "model_type": "openvla",
32
+ "n_action_bins": 256,
33
+ "norm_stats": {
34
+ "austin_buds_dataset_converted_externally_to_rlds": {
35
+ "action": {
36
+ "mask": [
37
+ true,
38
+ true,
39
+ true,
40
+ true,
41
+ true,
42
+ true,
43
+ false
44
+ ],
45
+ "max": [
46
+ 1.0,
47
+ 1.0,
48
+ 1.0,
49
+ 0.0,
50
+ 0.0,
51
+ 0.0,
52
+ 1.0
53
+ ],
54
+ "mean": [
55
+ -0.07678354531526566,
56
+ 0.0036849044263362885,
57
+ 0.05644911900162697,
58
+ 0.0,
59
+ 0.0,
60
+ 0.0,
61
+ 0.3510494828224182
62
+ ],
63
+ "min": [
64
+ -1.0,
65
+ -1.0,
66
+ -1.0,
67
+ 0.0,
68
+ 0.0,
69
+ 0.0,
70
+ 0.0
71
+ ],
72
+ "q01": [
73
+ -1.0,
74
+ -0.9599999785423279,
75
+ -0.8714285492897034,
76
+ 0.0,
77
+ 0.0,
78
+ 0.0,
79
+ 0.0
80
+ ],
81
+ "q99": [
82
+ 1.0,
83
+ 0.8600000143051147,
84
+ 1.0,
85
+ 0.0,
86
+ 0.0,
87
+ 0.0,
88
+ 1.0
89
+ ],
90
+ "std": [
91
+ 0.6367740631103516,
92
+ 0.37889179587364197,
93
+ 0.47796326875686646,
94
+ 0.0,
95
+ 0.0,
96
+ 0.0,
97
+ 0.47721168398857117
98
+ ]
99
+ },
100
+ "num_trajectories": 50,
101
+ "num_transitions": 34112,
102
+ "proprio": {
103
+ "max": [
104
+ 0.0,
105
+ 0.0,
106
+ 0.0,
107
+ 0.0,
108
+ 0.0,
109
+ 0.0,
110
+ 0.0
111
+ ],
112
+ "mean": [
113
+ 0.0,
114
+ 0.0,
115
+ 0.0,
116
+ 0.0,
117
+ 0.0,
118
+ 0.0,
119
+ 0.0
120
+ ],
121
+ "min": [
122
+ 0.0,
123
+ 0.0,
124
+ 0.0,
125
+ 0.0,
126
+ 0.0,
127
+ 0.0,
128
+ 0.0
129
+ ],
130
+ "q01": [
131
+ 0.0,
132
+ 0.0,
133
+ 0.0,
134
+ 0.0,
135
+ 0.0,
136
+ 0.0,
137
+ 0.0
138
+ ],
139
+ "q99": [
140
+ 0.0,
141
+ 0.0,
142
+ 0.0,
143
+ 0.0,
144
+ 0.0,
145
+ 0.0,
146
+ 0.0
147
+ ],
148
+ "std": [
149
+ 0.0,
150
+ 0.0,
151
+ 0.0,
152
+ 0.0,
153
+ 0.0,
154
+ 0.0,
155
+ 0.0
156
+ ]
157
+ }
158
+ },
159
+ "austin_sailor_dataset_converted_externally_to_rlds": {
160
+ "action": {
161
+ "mask": [
162
+ true,
163
+ true,
164
+ true,
165
+ true,
166
+ true,
167
+ true,
168
+ false
169
+ ],
170
+ "max": [
171
+ 1.0,
172
+ 1.0,
173
+ 1.0,
174
+ 0.0,
175
+ 0.0,
176
+ 0.375,
177
+ 1.0
178
+ ],
179
+ "mean": [
180
+ 0.011825348250567913,
181
+ 0.006461074110120535,
182
+ 0.06023626774549484,
183
+ 0.0,
184
+ 0.0,
185
+ 0.0016465914668515325,
186
+ 0.5260950326919556
187
+ ],
188
+ "min": [
189
+ -1.0,
190
+ -1.0,
191
+ -1.0,
192
+ 0.0,
193
+ 0.0,
194
+ -0.375,
195
+ 0.0
196
+ ],
197
+ "q01": [
198
+ -1.0,
199
+ -0.9828571677207947,
200
+ -0.6000000238418579,
201
+ 0.0,
202
+ 0.0,
203
+ -0.17249999940395355,
204
+ 0.0
205
+ ],
206
+ "q99": [
207
+ 1.0,
208
+ 0.9457142949104309,
209
+ 1.0,
210
+ 0.0,
211
+ 0.0,
212
+ 0.17892856895923615,
213
+ 1.0
214
+ ],
215
+ "std": [
216
+ 0.46348899602890015,
217
+ 0.41240179538726807,
218
+ 0.411862850189209,
219
+ 0.0,
220
+ 0.0,
221
+ 0.0578610822558403,
222
+ 0.49894046783447266
223
+ ]
224
+ },
225
+ "num_trajectories": 240,
226
+ "num_transitions": 353094,
227
+ "proprio": {
228
+ "max": [
229
+ 0.0,
230
+ 0.0,
231
+ 0.0,
232
+ 0.0,
233
+ 0.0,
234
+ 0.0,
235
+ 0.0
236
+ ],
237
+ "mean": [
238
+ 0.0,
239
+ 0.0,
240
+ 0.0,
241
+ 0.0,
242
+ 0.0,
243
+ 0.0,
244
+ 0.0
245
+ ],
246
+ "min": [
247
+ 0.0,
248
+ 0.0,
249
+ 0.0,
250
+ 0.0,
251
+ 0.0,
252
+ 0.0,
253
+ 0.0
254
+ ],
255
+ "q01": [
256
+ 0.0,
257
+ 0.0,
258
+ 0.0,
259
+ 0.0,
260
+ 0.0,
261
+ 0.0,
262
+ 0.0
263
+ ],
264
+ "q99": [
265
+ 0.0,
266
+ 0.0,
267
+ 0.0,
268
+ 0.0,
269
+ 0.0,
270
+ 0.0,
271
+ 0.0
272
+ ],
273
+ "std": [
274
+ 0.0,
275
+ 0.0,
276
+ 0.0,
277
+ 0.0,
278
+ 0.0,
279
+ 0.0,
280
+ 0.0
281
+ ]
282
+ }
283
+ },
284
+ "austin_sirius_dataset_converted_externally_to_rlds": {
285
+ "action": {
286
+ "mask": [
287
+ true,
288
+ true,
289
+ true,
290
+ true,
291
+ true,
292
+ true,
293
+ false
294
+ ],
295
+ "max": [
296
+ 1.0002285242080688,
297
+ 0.960608720779419,
298
+ 1.105179786682129,
299
+ 0.0,
300
+ 0.0,
301
+ 0.341785728931427,
302
+ 1.0
303
+ ],
304
+ "mean": [
305
+ 0.07747682929039001,
306
+ 0.03195561468601227,
307
+ 0.04244732856750488,
308
+ 0.0,
309
+ 0.0,
310
+ -0.01603456400334835,
311
+ 0.43260177969932556
312
+ ],
313
+ "min": [
314
+ -1.0183025598526,
315
+ -0.9800000190734863,
316
+ -0.9774575233459473,
317
+ 0.0,
318
+ 0.0,
319
+ -0.34607142210006714,
320
+ 0.0
321
+ ],
322
+ "q01": [
323
+ -0.780905865430832,
324
+ -0.5667179036140442,
325
+ -0.5254343223571777,
326
+ 0.0,
327
+ 0.0,
328
+ -0.28495091378688814,
329
+ 0.0
330
+ ],
331
+ "q99": [
332
+ 0.9569637751579284,
333
+ 0.6971374487876891,
334
+ 0.8124888157844541,
335
+ 0.0,
336
+ 0.0,
337
+ 0.1971428543329239,
338
+ 1.0
339
+ ],
340
+ "std": [
341
+ 0.3906329572200775,
342
+ 0.2998155355453491,
343
+ 0.2782271206378937,
344
+ 0.0,
345
+ 0.0,
346
+ 0.08120622485876083,
347
+ 0.49528297781944275
348
+ ]
349
+ },
350
+ "num_trajectories": 559,
351
+ "num_transitions": 279939,
352
+ "proprio": {
353
+ "max": [
354
+ 0.0,
355
+ 0.0,
356
+ 0.0,
357
+ 0.0,
358
+ 0.0,
359
+ 0.0,
360
+ 0.0
361
+ ],
362
+ "mean": [
363
+ 0.0,
364
+ 0.0,
365
+ 0.0,
366
+ 0.0,
367
+ 0.0,
368
+ 0.0,
369
+ 0.0
370
+ ],
371
+ "min": [
372
+ 0.0,
373
+ 0.0,
374
+ 0.0,
375
+ 0.0,
376
+ 0.0,
377
+ 0.0,
378
+ 0.0
379
+ ],
380
+ "q01": [
381
+ 0.0,
382
+ 0.0,
383
+ 0.0,
384
+ 0.0,
385
+ 0.0,
386
+ 0.0,
387
+ 0.0
388
+ ],
389
+ "q99": [
390
+ 0.0,
391
+ 0.0,
392
+ 0.0,
393
+ 0.0,
394
+ 0.0,
395
+ 0.0,
396
+ 0.0
397
+ ],
398
+ "std": [
399
+ 0.0,
400
+ 0.0,
401
+ 0.0,
402
+ 0.0,
403
+ 0.0,
404
+ 0.0,
405
+ 0.0
406
+ ]
407
+ }
408
+ },
409
+ "bc_z": {
410
+ "action": {
411
+ "mask": [
412
+ true,
413
+ true,
414
+ true,
415
+ true,
416
+ true,
417
+ true,
418
+ false
419
+ ],
420
+ "max": [
421
+ 0.2165454924106598,
422
+ 0.1251407265663147,
423
+ 0.10772687941789627,
424
+ 0.33544227480888367,
425
+ 0.28117990493774414,
426
+ 0.40614867210388184,
427
+ 1.0
428
+ ],
429
+ "mean": [
430
+ -0.009958467446267605,
431
+ 0.0008958321413956583,
432
+ 0.004995597992092371,
433
+ 0.00029755113064311445,
434
+ -0.008735382929444313,
435
+ -0.030693737789988518,
436
+ 0.8344562649726868
437
+ ],
438
+ "min": [
439
+ -0.1677047461271286,
440
+ -0.14630407094955444,
441
+ -0.10066790133714676,
442
+ -0.29421567916870117,
443
+ -0.32101404666900635,
444
+ -0.4635624885559082,
445
+ 0.0
446
+ ],
447
+ "q01": [
448
+ -0.09220654994249344,
449
+ -0.06456145539879798,
450
+ -0.049121275544166565,
451
+ -0.11594625547528267,
452
+ -0.14152548640966414,
453
+ -0.2251061636209488,
454
+ 0.0
455
+ ],
456
+ "q99": [
457
+ 0.07628866866230968,
458
+ 0.058019736707210584,
459
+ 0.052540797740221024,
460
+ 0.11740604028105736,
461
+ 0.11703975558280955,
462
+ 0.16729306846857078,
463
+ 1.0
464
+ ],
465
+ "std": [
466
+ 0.03053455986082554,
467
+ 0.0231423731893301,
468
+ 0.020641816779971123,
469
+ 0.04155943542718887,
470
+ 0.046427831053733826,
471
+ 0.0769818127155304,
472
+ 0.3610210120677948
473
+ ]
474
+ },
475
+ "num_trajectories": 43264,
476
+ "num_transitions": 6015535,
477
+ "proprio": {
478
+ "max": [
479
+ 0.0,
480
+ 0.0,
481
+ 0.0,
482
+ 0.0,
483
+ 0.0,
484
+ 0.0,
485
+ 0.0
486
+ ],
487
+ "mean": [
488
+ 0.0,
489
+ 0.0,
490
+ 0.0,
491
+ 0.0,
492
+ 0.0,
493
+ 0.0,
494
+ 0.0
495
+ ],
496
+ "min": [
497
+ 0.0,
498
+ 0.0,
499
+ 0.0,
500
+ 0.0,
501
+ 0.0,
502
+ 0.0,
503
+ 0.0
504
+ ],
505
+ "q01": [
506
+ 0.0,
507
+ 0.0,
508
+ 0.0,
509
+ 0.0,
510
+ 0.0,
511
+ 0.0,
512
+ 0.0
513
+ ],
514
+ "q99": [
515
+ 0.0,
516
+ 0.0,
517
+ 0.0,
518
+ 0.0,
519
+ 0.0,
520
+ 0.0,
521
+ 0.0
522
+ ],
523
+ "std": [
524
+ 0.0,
525
+ 0.0,
526
+ 0.0,
527
+ 0.0,
528
+ 0.0,
529
+ 0.0,
530
+ 0.0
531
+ ]
532
+ }
533
+ },
534
+ "berkeley_autolab_ur5": {
535
+ "action": {
536
+ "mask": [
537
+ true,
538
+ true,
539
+ true,
540
+ true,
541
+ true,
542
+ true,
543
+ false
544
+ ],
545
+ "max": [
546
+ 0.019999999552965164,
547
+ 0.019999999552965164,
548
+ 0.019999999552965164,
549
+ 0.06666667014360428,
550
+ 0.06666667014360428,
551
+ 0.06666667014360428,
552
+ 1.0
553
+ ],
554
+ "mean": [
555
+ 0.0005683620693162084,
556
+ 0.001217700308188796,
557
+ -0.0005296372692100704,
558
+ 0.00021029810886830091,
559
+ 6.0695128922816366e-05,
560
+ 0.001204986940138042,
561
+ 0.6298308372497559
562
+ ],
563
+ "min": [
564
+ -0.019999999552965164,
565
+ -0.019999999552965164,
566
+ -0.019999999552965164,
567
+ -0.06666667014360428,
568
+ -0.06666667014360428,
569
+ -0.06666667014360428,
570
+ 0.0
571
+ ],
572
+ "q01": [
573
+ -0.019999999552965164,
574
+ -0.019999999552965164,
575
+ -0.019999999552965164,
576
+ -0.02628571353852749,
577
+ -0.06666667014360428,
578
+ -0.03847619146108627,
579
+ 0.0
580
+ ],
581
+ "q99": [
582
+ 0.019999999552965164,
583
+ 0.019999999552965164,
584
+ 0.019999999552965164,
585
+ 0.031809523701667786,
586
+ 0.06666667014360428,
587
+ 0.036571428179740906,
588
+ 1.0
589
+ ],
590
+ "std": [
591
+ 0.0115329809486866,
592
+ 0.007990492507815361,
593
+ 0.009577835910022259,
594
+ 0.009432995691895485,
595
+ 0.016427582129836082,
596
+ 0.011053967289626598,
597
+ 0.48267969489097595
598
+ ]
599
+ },
600
+ "num_trajectories": 1000,
601
+ "num_transitions": 97939,
602
+ "proprio": {
603
+ "max": [
604
+ 0.0,
605
+ 0.0,
606
+ 0.0,
607
+ 0.0,
608
+ 0.0,
609
+ 0.0,
610
+ 0.0
611
+ ],
612
+ "mean": [
613
+ 0.0,
614
+ 0.0,
615
+ 0.0,
616
+ 0.0,
617
+ 0.0,
618
+ 0.0,
619
+ 0.0
620
+ ],
621
+ "min": [
622
+ 0.0,
623
+ 0.0,
624
+ 0.0,
625
+ 0.0,
626
+ 0.0,
627
+ 0.0,
628
+ 0.0
629
+ ],
630
+ "q01": [
631
+ 0.0,
632
+ 0.0,
633
+ 0.0,
634
+ 0.0,
635
+ 0.0,
636
+ 0.0,
637
+ 0.0
638
+ ],
639
+ "q99": [
640
+ 0.0,
641
+ 0.0,
642
+ 0.0,
643
+ 0.0,
644
+ 0.0,
645
+ 0.0,
646
+ 0.0
647
+ ],
648
+ "std": [
649
+ 0.0,
650
+ 0.0,
651
+ 0.0,
652
+ 0.0,
653
+ 0.0,
654
+ 0.0,
655
+ 0.0
656
+ ]
657
+ }
658
+ },
659
+ "berkeley_cable_routing": {
660
+ "action": {
661
+ "mask": [
662
+ true,
663
+ true,
664
+ true,
665
+ true,
666
+ true,
667
+ true,
668
+ false
669
+ ],
670
+ "max": [
671
+ 0.9633283019065857,
672
+ 1.0,
673
+ 1.0,
674
+ 0.0,
675
+ 0.0,
676
+ 1.0,
677
+ 0.0
678
+ ],
679
+ "mean": [
680
+ -0.07139874249696732,
681
+ 0.023609008640050888,
682
+ 0.10241943597793579,
683
+ 0.0,
684
+ 0.0,
685
+ 0.049671024084091187,
686
+ 0.0
687
+ ],
688
+ "min": [
689
+ -0.9809081554412842,
690
+ -0.9554349184036255,
691
+ -0.9994775056838989,
692
+ 0.0,
693
+ 0.0,
694
+ -1.0,
695
+ 0.0
696
+ ],
697
+ "q01": [
698
+ -0.5534318816661835,
699
+ -0.4797285574674606,
700
+ -0.5314934802055359,
701
+ 0.0,
702
+ 0.0,
703
+ -0.8855219376087189,
704
+ 0.0
705
+ ],
706
+ "q99": [
707
+ 0.42652835428714786,
708
+ 0.5000944086909298,
709
+ 0.639823433756829,
710
+ 0.0,
711
+ 0.0,
712
+ 0.984243879914284,
713
+ 0.0
714
+ ],
715
+ "std": [
716
+ 0.1815500408411026,
717
+ 0.1810990273952484,
718
+ 0.21220779418945312,
719
+ 0.0,
720
+ 0.0,
721
+ 0.3475511968135834,
722
+ 0.0
723
+ ]
724
+ },
725
+ "num_trajectories": 1647,
726
+ "num_transitions": 42328,
727
+ "proprio": {
728
+ "max": [
729
+ 0.0,
730
+ 0.0,
731
+ 0.0,
732
+ 0.0,
733
+ 0.0,
734
+ 0.0,
735
+ 0.0
736
+ ],
737
+ "mean": [
738
+ 0.0,
739
+ 0.0,
740
+ 0.0,
741
+ 0.0,
742
+ 0.0,
743
+ 0.0,
744
+ 0.0
745
+ ],
746
+ "min": [
747
+ 0.0,
748
+ 0.0,
749
+ 0.0,
750
+ 0.0,
751
+ 0.0,
752
+ 0.0,
753
+ 0.0
754
+ ],
755
+ "q01": [
756
+ 0.0,
757
+ 0.0,
758
+ 0.0,
759
+ 0.0,
760
+ 0.0,
761
+ 0.0,
762
+ 0.0
763
+ ],
764
+ "q99": [
765
+ 0.0,
766
+ 0.0,
767
+ 0.0,
768
+ 0.0,
769
+ 0.0,
770
+ 0.0,
771
+ 0.0
772
+ ],
773
+ "std": [
774
+ 0.0,
775
+ 0.0,
776
+ 0.0,
777
+ 0.0,
778
+ 0.0,
779
+ 0.0,
780
+ 0.0
781
+ ]
782
+ }
783
+ },
784
+ "berkeley_fanuc_manipulation": {
785
+ "action": {
786
+ "mask": [
787
+ true,
788
+ true,
789
+ true,
790
+ true,
791
+ true,
792
+ true,
793
+ false
794
+ ],
795
+ "max": [
796
+ 0.009999999776482582,
797
+ 0.009999999776482582,
798
+ 0.009999999776482582,
799
+ 0.03490658476948738,
800
+ 0.03490658476948738,
801
+ 0.03490658476948738,
802
+ 1.0
803
+ ],
804
+ "mean": [
805
+ 0.0007744057802483439,
806
+ -0.00031240080716088414,
807
+ -0.0015001941937953234,
808
+ -0.0007515158504247665,
809
+ -0.00015832878125365824,
810
+ 0.00014327642566058785,
811
+ 0.699295699596405
812
+ ],
813
+ "min": [
814
+ -0.009999999776482582,
815
+ -0.009999999776482582,
816
+ -0.009999999776482582,
817
+ -0.03490658476948738,
818
+ -0.03490658476948738,
819
+ -0.03490658476948738,
820
+ 0.0
821
+ ],
822
+ "q01": [
823
+ -0.009999999776482582,
824
+ -0.009999999776482582,
825
+ -0.009999999776482582,
826
+ -0.03490658476948738,
827
+ 0.0,
828
+ -0.03490658476948738,
829
+ 0.0
830
+ ],
831
+ "q99": [
832
+ 0.009999999776482582,
833
+ 0.009999999776482582,
834
+ 0.009999999776482582,
835
+ 0.03490658476948738,
836
+ 0.0,
837
+ 0.03490658476948738,
838
+ 1.0
839
+ ],
840
+ "std": [
841
+ 0.0034070091787725687,
842
+ 0.0049921851605176926,
843
+ 0.005344334989786148,
844
+ 0.00759894959628582,
845
+ 0.004081866703927517,
846
+ 0.008568956516683102,
847
+ 0.4586937427520752
848
+ ]
849
+ },
850
+ "num_trajectories": 415,
851
+ "num_transitions": 62613,
852
+ "proprio": {
853
+ "max": [
854
+ 0.0,
855
+ 0.0,
856
+ 0.0,
857
+ 0.0,
858
+ 0.0,
859
+ 0.0,
860
+ 0.0
861
+ ],
862
+ "mean": [
863
+ 0.0,
864
+ 0.0,
865
+ 0.0,
866
+ 0.0,
867
+ 0.0,
868
+ 0.0,
869
+ 0.0
870
+ ],
871
+ "min": [
872
+ 0.0,
873
+ 0.0,
874
+ 0.0,
875
+ 0.0,
876
+ 0.0,
877
+ 0.0,
878
+ 0.0
879
+ ],
880
+ "q01": [
881
+ 0.0,
882
+ 0.0,
883
+ 0.0,
884
+ 0.0,
885
+ 0.0,
886
+ 0.0,
887
+ 0.0
888
+ ],
889
+ "q99": [
890
+ 0.0,
891
+ 0.0,
892
+ 0.0,
893
+ 0.0,
894
+ 0.0,
895
+ 0.0,
896
+ 0.0
897
+ ],
898
+ "std": [
899
+ 0.0,
900
+ 0.0,
901
+ 0.0,
902
+ 0.0,
903
+ 0.0,
904
+ 0.0,
905
+ 0.0
906
+ ]
907
+ }
908
+ },
909
+ "bridge_orig": {
910
+ "action": {
911
+ "mask": [
912
+ true,
913
+ true,
914
+ true,
915
+ true,
916
+ true,
917
+ true,
918
+ false
919
+ ],
920
+ "max": [
921
+ 0.41691166162490845,
922
+ 0.25864794850349426,
923
+ 0.21218234300613403,
924
+ 3.122201919555664,
925
+ 1.8618112802505493,
926
+ 6.280478477478027,
927
+ 1.0
928
+ ],
929
+ "mean": [
930
+ 0.0002334194869035855,
931
+ 0.00013004911306779832,
932
+ -0.00012762474943883717,
933
+ -0.0001556558854645118,
934
+ -0.0004039328487124294,
935
+ 0.00023557482927571982,
936
+ 0.5764579176902771
937
+ ],
938
+ "min": [
939
+ -0.4007510244846344,
940
+ -0.13874775171279907,
941
+ -0.22553899884223938,
942
+ -3.2010786533355713,
943
+ -1.8618112802505493,
944
+ -6.279075622558594,
945
+ 0.0
946
+ ],
947
+ "q01": [
948
+ -0.02872725307941437,
949
+ -0.04170349963009357,
950
+ -0.026093858778476715,
951
+ -0.08092105075716972,
952
+ -0.09288699507713317,
953
+ -0.20718276381492615,
954
+ 0.0
955
+ ],
956
+ "q99": [
957
+ 0.028309678435325586,
958
+ 0.040855254605412394,
959
+ 0.040161586627364146,
960
+ 0.08192047759890528,
961
+ 0.07792850524187081,
962
+ 0.20382574498653397,
963
+ 1.0
964
+ ],
965
+ "std": [
966
+ 0.009765930473804474,
967
+ 0.013689135201275349,
968
+ 0.012667362578213215,
969
+ 0.028534092009067535,
970
+ 0.030637972056865692,
971
+ 0.07691419124603271,
972
+ 0.4973701536655426
973
+ ]
974
+ },
975
+ "num_trajectories": 60064,
976
+ "num_transitions": 2135463,
977
+ "proprio": {
978
+ "max": [
979
+ 0.0,
980
+ 0.0,
981
+ 0.0,
982
+ 0.0,
983
+ 0.0,
984
+ 0.0,
985
+ 0.0
986
+ ],
987
+ "mean": [
988
+ 0.0,
989
+ 0.0,
990
+ 0.0,
991
+ 0.0,
992
+ 0.0,
993
+ 0.0,
994
+ 0.0
995
+ ],
996
+ "min": [
997
+ 0.0,
998
+ 0.0,
999
+ 0.0,
1000
+ 0.0,
1001
+ 0.0,
1002
+ 0.0,
1003
+ 0.0
1004
+ ],
1005
+ "q01": [
1006
+ 0.0,
1007
+ 0.0,
1008
+ 0.0,
1009
+ 0.0,
1010
+ 0.0,
1011
+ 0.0,
1012
+ 0.0
1013
+ ],
1014
+ "q99": [
1015
+ 0.0,
1016
+ 0.0,
1017
+ 0.0,
1018
+ 0.0,
1019
+ 0.0,
1020
+ 0.0,
1021
+ 0.0
1022
+ ],
1023
+ "std": [
1024
+ 0.0,
1025
+ 0.0,
1026
+ 0.0,
1027
+ 0.0,
1028
+ 0.0,
1029
+ 0.0,
1030
+ 0.0
1031
+ ]
1032
+ }
1033
+ },
1034
+ "cmu_stretch": {
1035
+ "action": {
1036
+ "mask": [
1037
+ true,
1038
+ true,
1039
+ true,
1040
+ true,
1041
+ true,
1042
+ true,
1043
+ false
1044
+ ],
1045
+ "max": [
1046
+ 0.02338407188653946,
1047
+ 0.0,
1048
+ 0.023404927924275398,
1049
+ 0.0,
1050
+ 0.0,
1051
+ 0.0,
1052
+ 1.0
1053
+ ],
1054
+ "mean": [
1055
+ 0.00036304505192674696,
1056
+ 0.0,
1057
+ 0.0016466958913952112,
1058
+ 0.0,
1059
+ 0.0,
1060
+ 0.0,
1061
+ 0.3987048268318176
1062
+ ],
1063
+ "min": [
1064
+ -0.019353797659277916,
1065
+ 0.0,
1066
+ -0.02019215188920498,
1067
+ 0.0,
1068
+ 0.0,
1069
+ 0.0,
1070
+ 0.0
1071
+ ],
1072
+ "q01": [
1073
+ -0.011175686959177256,
1074
+ 0.0,
1075
+ -0.0032206363626755773,
1076
+ 0.0,
1077
+ 0.0,
1078
+ 0.0,
1079
+ 0.0
1080
+ ],
1081
+ "q99": [
1082
+ 0.014501785952597848,
1083
+ 0.0,
1084
+ 0.015056106168776728,
1085
+ 0.0,
1086
+ 0.0,
1087
+ 0.0,
1088
+ 1.0
1089
+ ],
1090
+ "std": [
1091
+ 0.004081828519701958,
1092
+ 0.0,
1093
+ 0.0037743328139185905,
1094
+ 0.0,
1095
+ 0.0,
1096
+ 0.0,
1097
+ 0.48963725566864014
1098
+ ]
1099
+ },
1100
+ "num_trajectories": 135,
1101
+ "num_transitions": 25016,
1102
+ "proprio": {
1103
+ "max": [
1104
+ 0.0,
1105
+ 0.0,
1106
+ 0.0,
1107
+ 0.0,
1108
+ 0.0,
1109
+ 0.0,
1110
+ 0.0
1111
+ ],
1112
+ "mean": [
1113
+ 0.0,
1114
+ 0.0,
1115
+ 0.0,
1116
+ 0.0,
1117
+ 0.0,
1118
+ 0.0,
1119
+ 0.0
1120
+ ],
1121
+ "min": [
1122
+ 0.0,
1123
+ 0.0,
1124
+ 0.0,
1125
+ 0.0,
1126
+ 0.0,
1127
+ 0.0,
1128
+ 0.0
1129
+ ],
1130
+ "q01": [
1131
+ 0.0,
1132
+ 0.0,
1133
+ 0.0,
1134
+ 0.0,
1135
+ 0.0,
1136
+ 0.0,
1137
+ 0.0
1138
+ ],
1139
+ "q99": [
1140
+ 0.0,
1141
+ 0.0,
1142
+ 0.0,
1143
+ 0.0,
1144
+ 0.0,
1145
+ 0.0,
1146
+ 0.0
1147
+ ],
1148
+ "std": [
1149
+ 0.0,
1150
+ 0.0,
1151
+ 0.0,
1152
+ 0.0,
1153
+ 0.0,
1154
+ 0.0,
1155
+ 0.0
1156
+ ]
1157
+ }
1158
+ },
1159
+ "dlr_edan_shared_control_converted_externally_to_rlds": {
1160
+ "action": {
1161
+ "mask": [
1162
+ true,
1163
+ true,
1164
+ true,
1165
+ true,
1166
+ true,
1167
+ true,
1168
+ false
1169
+ ],
1170
+ "max": [
1171
+ 0.18991442024707794,
1172
+ 0.0739002525806427,
1173
+ 0.18064819276332855,
1174
+ 0.0866486132144928,
1175
+ 0.13464981317520142,
1176
+ 0.16910280287265778,
1177
+ 1.0
1178
+ ],
1179
+ "mean": [
1180
+ 0.006647810339927673,
1181
+ -0.0007657372043468058,
1182
+ 0.006522852927446365,
1183
+ 0.0011679717572405934,
1184
+ -0.006395625416189432,
1185
+ -0.011902998201549053,
1186
+ 0.6985887289047241
1187
+ ],
1188
+ "min": [
1189
+ -0.10054297000169754,
1190
+ -0.08427435159683228,
1191
+ -0.13533438742160797,
1192
+ -0.17556548118591309,
1193
+ -0.18485672771930695,
1194
+ -0.2680685818195343,
1195
+ 0.0
1196
+ ],
1197
+ "q01": [
1198
+ -0.02987122368067503,
1199
+ -0.06013262912631035,
1200
+ -0.08286409199237824,
1201
+ -0.05924444157630205,
1202
+ -0.15986866518855095,
1203
+ -0.15636983573436739,
1204
+ 0.0
1205
+ ],
1206
+ "q99": [
1207
+ 0.08832092039287087,
1208
+ 0.042126184627413736,
1209
+ 0.11311905644834042,
1210
+ 0.0643695573508739,
1211
+ 0.03941855944693088,
1212
+ 0.156646853685379,
1213
+ 1.0
1214
+ ],
1215
+ "std": [
1216
+ 0.021393608301877975,
1217
+ 0.01814231649041176,
1218
+ 0.03374375030398369,
1219
+ 0.01743541844189167,
1220
+ 0.03394376486539841,
1221
+ 0.04641875624656677,
1222
+ 0.4588589072227478
1223
+ ]
1224
+ },
1225
+ "num_trajectories": 104,
1226
+ "num_transitions": 8928,
1227
+ "proprio": {
1228
+ "max": [
1229
+ 0.0,
1230
+ 0.0,
1231
+ 0.0,
1232
+ 0.0,
1233
+ 0.0,
1234
+ 0.0,
1235
+ 0.0
1236
+ ],
1237
+ "mean": [
1238
+ 0.0,
1239
+ 0.0,
1240
+ 0.0,
1241
+ 0.0,
1242
+ 0.0,
1243
+ 0.0,
1244
+ 0.0
1245
+ ],
1246
+ "min": [
1247
+ 0.0,
1248
+ 0.0,
1249
+ 0.0,
1250
+ 0.0,
1251
+ 0.0,
1252
+ 0.0,
1253
+ 0.0
1254
+ ],
1255
+ "q01": [
1256
+ 0.0,
1257
+ 0.0,
1258
+ 0.0,
1259
+ 0.0,
1260
+ 0.0,
1261
+ 0.0,
1262
+ 0.0
1263
+ ],
1264
+ "q99": [
1265
+ 0.0,
1266
+ 0.0,
1267
+ 0.0,
1268
+ 0.0,
1269
+ 0.0,
1270
+ 0.0,
1271
+ 0.0
1272
+ ],
1273
+ "std": [
1274
+ 0.0,
1275
+ 0.0,
1276
+ 0.0,
1277
+ 0.0,
1278
+ 0.0,
1279
+ 0.0,
1280
+ 0.0
1281
+ ]
1282
+ }
1283
+ },
1284
+ "dobbe": {
1285
+ "action": {
1286
+ "mask": [
1287
+ true,
1288
+ true,
1289
+ true,
1290
+ true,
1291
+ true,
1292
+ true,
1293
+ false
1294
+ ],
1295
+ "max": [
1296
+ 38.590423583984375,
1297
+ 17.932697296142578,
1298
+ 4.843764305114746,
1299
+ 1.4372116327285767,
1300
+ 0.4340403974056244,
1301
+ 1.2057193517684937,
1302
+ 0.9998947381973267
1303
+ ],
1304
+ "mean": [
1305
+ -0.0001120665911003016,
1306
+ 0.0011229600058868527,
1307
+ -0.00010194431524723768,
1308
+ -7.371398532995954e-05,
1309
+ -0.00067531579406932,
1310
+ -5.6643435527803376e-05,
1311
+ 0.6318281888961792
1312
+ ],
1313
+ "min": [
1314
+ -5.700923442840576,
1315
+ -21.605947494506836,
1316
+ -123.72489929199219,
1317
+ -1.7229845523834229,
1318
+ -0.4998578727245331,
1319
+ -0.8867913484573364,
1320
+ 1.4196479014572105e-06
1321
+ ],
1322
+ "q01": [
1323
+ -0.01119564864784479,
1324
+ -0.014266146533191203,
1325
+ -0.0071747214533388615,
1326
+ -0.009444301575422287,
1327
+ -0.03990109823644161,
1328
+ -0.017422311007976532,
1329
+ 4.003279136668425e-05
1330
+ ],
1331
+ "q99": [
1332
+ 0.01015154086053368,
1333
+ 0.017181577533483497,
1334
+ 0.007216989761218411,
1335
+ 0.010380979906767595,
1336
+ 0.03556173853576176,
1337
+ 0.018032474815845446,
1338
+ 0.9982578039169312
1339
+ ],
1340
+ "std": [
1341
+ 0.04264938458800316,
1342
+ 0.04428559169173241,
1343
+ 0.12224084138870239,
1344
+ 0.005388413090258837,
1345
+ 0.011246449314057827,
1346
+ 0.006287882570177317,
1347
+ 0.39732322096824646
1348
+ ]
1349
+ },
1350
+ "num_trajectories": 5208,
1351
+ "num_transitions": 1139911,
1352
+ "proprio": {
1353
+ "max": [
1354
+ 0.0,
1355
+ 0.0,
1356
+ 0.0,
1357
+ 0.0,
1358
+ 0.0,
1359
+ 0.0,
1360
+ 0.0
1361
+ ],
1362
+ "mean": [
1363
+ 0.0,
1364
+ 0.0,
1365
+ 0.0,
1366
+ 0.0,
1367
+ 0.0,
1368
+ 0.0,
1369
+ 0.0
1370
+ ],
1371
+ "min": [
1372
+ 0.0,
1373
+ 0.0,
1374
+ 0.0,
1375
+ 0.0,
1376
+ 0.0,
1377
+ 0.0,
1378
+ 0.0
1379
+ ],
1380
+ "q01": [
1381
+ 0.0,
1382
+ 0.0,
1383
+ 0.0,
1384
+ 0.0,
1385
+ 0.0,
1386
+ 0.0,
1387
+ 0.0
1388
+ ],
1389
+ "q99": [
1390
+ 0.0,
1391
+ 0.0,
1392
+ 0.0,
1393
+ 0.0,
1394
+ 0.0,
1395
+ 0.0,
1396
+ 0.0
1397
+ ],
1398
+ "std": [
1399
+ 0.0,
1400
+ 0.0,
1401
+ 0.0,
1402
+ 0.0,
1403
+ 0.0,
1404
+ 0.0,
1405
+ 0.0
1406
+ ]
1407
+ }
1408
+ },
1409
+ "fmb_dataset": {
1410
+ "action": {
1411
+ "mask": [
1412
+ true,
1413
+ true,
1414
+ true,
1415
+ true,
1416
+ true,
1417
+ true,
1418
+ false
1419
+ ],
1420
+ "max": [
1421
+ 1.399999976158142,
1422
+ 1.0,
1423
+ 1.399999976158142,
1424
+ 1.0,
1425
+ 1.0,
1426
+ 1.0,
1427
+ 1.0
1428
+ ],
1429
+ "mean": [
1430
+ 0.059029702097177505,
1431
+ -0.06476633995771408,
1432
+ -0.09787475317716599,
1433
+ 0.004325388930737972,
1434
+ 0.00028963794466108084,
1435
+ -0.04457257315516472,
1436
+ 0.7336440086364746
1437
+ ],
1438
+ "min": [
1439
+ -1.399999976158142,
1440
+ -1.399999976158142,
1441
+ -1.0,
1442
+ -1.0,
1443
+ -1.0,
1444
+ -1.0,
1445
+ 0.0
1446
+ ],
1447
+ "q01": [
1448
+ -0.8257142901420593,
1449
+ -1.399999976158142,
1450
+ -1.0,
1451
+ -1.0,
1452
+ -0.3028571307659149,
1453
+ -1.0,
1454
+ 0.0
1455
+ ],
1456
+ "q99": [
1457
+ 1.0,
1458
+ 0.5257142782211304,
1459
+ 1.0,
1460
+ 1.0,
1461
+ 0.3400000035762787,
1462
+ 1.0,
1463
+ 1.0
1464
+ ],
1465
+ "std": [
1466
+ 0.28809213638305664,
1467
+ 0.2820415794849396,
1468
+ 0.4626740515232086,
1469
+ 0.3266514539718628,
1470
+ 0.10842999070882797,
1471
+ 0.3440099358558655,
1472
+ 0.4435282051563263
1473
+ ]
1474
+ },
1475
+ "num_trajectories": 8612,
1476
+ "num_transitions": 1137459,
1477
+ "proprio": {
1478
+ "max": [
1479
+ 0.0,
1480
+ 0.0,
1481
+ 0.0,
1482
+ 0.0,
1483
+ 0.0,
1484
+ 0.0,
1485
+ 0.0
1486
+ ],
1487
+ "mean": [
1488
+ 0.0,
1489
+ 0.0,
1490
+ 0.0,
1491
+ 0.0,
1492
+ 0.0,
1493
+ 0.0,
1494
+ 0.0
1495
+ ],
1496
+ "min": [
1497
+ 0.0,
1498
+ 0.0,
1499
+ 0.0,
1500
+ 0.0,
1501
+ 0.0,
1502
+ 0.0,
1503
+ 0.0
1504
+ ],
1505
+ "q01": [
1506
+ 0.0,
1507
+ 0.0,
1508
+ 0.0,
1509
+ 0.0,
1510
+ 0.0,
1511
+ 0.0,
1512
+ 0.0
1513
+ ],
1514
+ "q99": [
1515
+ 0.0,
1516
+ 0.0,
1517
+ 0.0,
1518
+ 0.0,
1519
+ 0.0,
1520
+ 0.0,
1521
+ 0.0
1522
+ ],
1523
+ "std": [
1524
+ 0.0,
1525
+ 0.0,
1526
+ 0.0,
1527
+ 0.0,
1528
+ 0.0,
1529
+ 0.0,
1530
+ 0.0
1531
+ ]
1532
+ }
1533
+ },
1534
+ "fractal20220817_data": {
1535
+ "action": {
1536
+ "mask": [
1537
+ true,
1538
+ true,
1539
+ true,
1540
+ true,
1541
+ true,
1542
+ true,
1543
+ false
1544
+ ],
1545
+ "max": [
1546
+ 2.9984593391418457,
1547
+ 22.09052848815918,
1548
+ 2.7507524490356445,
1549
+ 1.570636510848999,
1550
+ 1.5321086645126343,
1551
+ 1.5691522359848022,
1552
+ 1.0
1553
+ ],
1554
+ "mean": [
1555
+ 0.006987582892179489,
1556
+ 0.006265917327255011,
1557
+ -0.01262515690177679,
1558
+ 0.04333311319351196,
1559
+ -0.005756212864071131,
1560
+ 0.0009130256366916001,
1561
+ 0.5354204773902893
1562
+ ],
1563
+ "min": [
1564
+ -2.0204520225524902,
1565
+ -5.497899532318115,
1566
+ -2.031663417816162,
1567
+ -1.569917917251587,
1568
+ -1.569892168045044,
1569
+ -1.570419430732727,
1570
+ 0.0
1571
+ ],
1572
+ "q01": [
1573
+ -0.22453527510166169,
1574
+ -0.14820013284683228,
1575
+ -0.231589707583189,
1576
+ -0.3517994859814644,
1577
+ -0.4193011274933815,
1578
+ -0.43643461108207704,
1579
+ 0.0
1580
+ ],
1581
+ "q99": [
1582
+ 0.17824687153100965,
1583
+ 0.14938379630446405,
1584
+ 0.21842354819178575,
1585
+ 0.5892666035890578,
1586
+ 0.35272657424211445,
1587
+ 0.44796681255102094,
1588
+ 1.0
1589
+ ],
1590
+ "std": [
1591
+ 0.0692116990685463,
1592
+ 0.05970962345600128,
1593
+ 0.07353084534406662,
1594
+ 0.15610496699810028,
1595
+ 0.13164450228214264,
1596
+ 0.14593800902366638,
1597
+ 0.497110515832901
1598
+ ]
1599
+ },
1600
+ "num_trajectories": 87212,
1601
+ "num_transitions": 3786400,
1602
+ "proprio": {
1603
+ "max": [
1604
+ 0.0,
1605
+ 0.0,
1606
+ 0.0,
1607
+ 0.0,
1608
+ 0.0,
1609
+ 0.0,
1610
+ 0.0
1611
+ ],
1612
+ "mean": [
1613
+ 0.0,
1614
+ 0.0,
1615
+ 0.0,
1616
+ 0.0,
1617
+ 0.0,
1618
+ 0.0,
1619
+ 0.0
1620
+ ],
1621
+ "min": [
1622
+ 0.0,
1623
+ 0.0,
1624
+ 0.0,
1625
+ 0.0,
1626
+ 0.0,
1627
+ 0.0,
1628
+ 0.0
1629
+ ],
1630
+ "q01": [
1631
+ 0.0,
1632
+ 0.0,
1633
+ 0.0,
1634
+ 0.0,
1635
+ 0.0,
1636
+ 0.0,
1637
+ 0.0
1638
+ ],
1639
+ "q99": [
1640
+ 0.0,
1641
+ 0.0,
1642
+ 0.0,
1643
+ 0.0,
1644
+ 0.0,
1645
+ 0.0,
1646
+ 0.0
1647
+ ],
1648
+ "std": [
1649
+ 0.0,
1650
+ 0.0,
1651
+ 0.0,
1652
+ 0.0,
1653
+ 0.0,
1654
+ 0.0,
1655
+ 0.0
1656
+ ]
1657
+ }
1658
+ },
1659
+ "furniture_bench_dataset_converted_externally_to_rlds": {
1660
+ "action": {
1661
+ "mask": [
1662
+ true,
1663
+ true,
1664
+ true,
1665
+ true,
1666
+ true,
1667
+ true,
1668
+ false
1669
+ ],
1670
+ "max": [
1671
+ 0.10000000149011612,
1672
+ 0.10000000149011612,
1673
+ 0.10000000149011612,
1674
+ 0.8651833534240723,
1675
+ 1.0909736156463623,
1676
+ 2.863185405731201,
1677
+ 1.0
1678
+ ],
1679
+ "mean": [
1680
+ 0.00014610752987209707,
1681
+ 0.0010830952087417245,
1682
+ 0.0006224989192560315,
1683
+ -0.003303206292912364,
1684
+ -0.0026880695950239897,
1685
+ 0.018242603167891502,
1686
+ 0.48854944109916687
1687
+ ],
1688
+ "min": [
1689
+ -0.10495579987764359,
1690
+ -0.10939455777406693,
1691
+ -0.10000000149011612,
1692
+ -0.971906840801239,
1693
+ -1.0475432872772217,
1694
+ -3.06000018119812,
1695
+ 0.0
1696
+ ],
1697
+ "q01": [
1698
+ -0.053988199681043625,
1699
+ -0.05049169331789017,
1700
+ -0.032499241530895236,
1701
+ -0.1953887003660202,
1702
+ -0.41674559473991396,
1703
+ -0.8886768388748169,
1704
+ 0.0
1705
+ ],
1706
+ "q99": [
1707
+ 0.05414841488003723,
1708
+ 0.04965164884924884,
1709
+ 0.060055799782276154,
1710
+ 0.18231668293476103,
1711
+ 0.39867786407470646,
1712
+ 0.8772023963928218,
1713
+ 1.0
1714
+ ],
1715
+ "std": [
1716
+ 0.01610708422958851,
1717
+ 0.014891477301716805,
1718
+ 0.014014219865202904,
1719
+ 0.058274295181035995,
1720
+ 0.11417088657617569,
1721
+ 0.33479776978492737,
1722
+ 0.49991825222969055
1723
+ ]
1724
+ },
1725
+ "num_trajectories": 5100,
1726
+ "num_transitions": 3948057,
1727
+ "proprio": {
1728
+ "max": [
1729
+ 0.0,
1730
+ 0.0,
1731
+ 0.0,
1732
+ 0.0,
1733
+ 0.0,
1734
+ 0.0,
1735
+ 0.0
1736
+ ],
1737
+ "mean": [
1738
+ 0.0,
1739
+ 0.0,
1740
+ 0.0,
1741
+ 0.0,
1742
+ 0.0,
1743
+ 0.0,
1744
+ 0.0
1745
+ ],
1746
+ "min": [
1747
+ 0.0,
1748
+ 0.0,
1749
+ 0.0,
1750
+ 0.0,
1751
+ 0.0,
1752
+ 0.0,
1753
+ 0.0
1754
+ ],
1755
+ "q01": [
1756
+ 0.0,
1757
+ 0.0,
1758
+ 0.0,
1759
+ 0.0,
1760
+ 0.0,
1761
+ 0.0,
1762
+ 0.0
1763
+ ],
1764
+ "q99": [
1765
+ 0.0,
1766
+ 0.0,
1767
+ 0.0,
1768
+ 0.0,
1769
+ 0.0,
1770
+ 0.0,
1771
+ 0.0
1772
+ ],
1773
+ "std": [
1774
+ 0.0,
1775
+ 0.0,
1776
+ 0.0,
1777
+ 0.0,
1778
+ 0.0,
1779
+ 0.0,
1780
+ 0.0
1781
+ ]
1782
+ }
1783
+ },
1784
+ "iamlab_cmu_pickup_insert_converted_externally_to_rlds": {
1785
+ "action": {
1786
+ "mask": [
1787
+ true,
1788
+ true,
1789
+ true,
1790
+ true,
1791
+ true,
1792
+ true,
1793
+ false
1794
+ ],
1795
+ "max": [
1796
+ 0.6634981632232666,
1797
+ 0.23428471386432648,
1798
+ 0.4308285415172577,
1799
+ 3.1415927410125732,
1800
+ 0.13647015392780304,
1801
+ 3.141592502593994,
1802
+ 1.0
1803
+ ],
1804
+ "mean": [
1805
+ 0.5274372696876526,
1806
+ 0.02858201041817665,
1807
+ 0.18712575733661652,
1808
+ 1.2339589595794678,
1809
+ 0.03226623684167862,
1810
+ -1.4199490547180176,
1811
+ 0.5550631880760193
1812
+ ],
1813
+ "min": [
1814
+ 0.3071657121181488,
1815
+ -0.29754969477653503,
1816
+ 0.06578229367733002,
1817
+ -3.1415927410125732,
1818
+ -0.04584203287959099,
1819
+ -3.141592502593994,
1820
+ 0.0
1821
+ ],
1822
+ "q01": [
1823
+ 0.3148897051811218,
1824
+ -0.20317550599575043,
1825
+ 0.06785467118024827,
1826
+ -3.140952730178833,
1827
+ -0.029743434861302376,
1828
+ -3.141091251373291,
1829
+ 0.0
1830
+ ],
1831
+ "q99": [
1832
+ 0.6472805738449097,
1833
+ 0.20846802592277527,
1834
+ 0.36855655312538155,
1835
+ 3.1409926891326903,
1836
+ 0.11424950212240226,
1837
+ 3.1410969257354737,
1838
+ 1.0
1839
+ ],
1840
+ "std": [
1841
+ 0.08108345419168472,
1842
+ 0.1116757020354271,
1843
+ 0.07747554779052734,
1844
+ 2.8737246990203857,
1845
+ 0.02774704433977604,
1846
+ 2.7678682804107666,
1847
+ 0.49695101380348206
1848
+ ]
1849
+ },
1850
+ "num_trajectories": 631,
1851
+ "num_transitions": 146241,
1852
+ "proprio": {
1853
+ "max": [
1854
+ 0.0,
1855
+ 0.0,
1856
+ 0.0,
1857
+ 0.0,
1858
+ 0.0,
1859
+ 0.0,
1860
+ 0.0
1861
+ ],
1862
+ "mean": [
1863
+ 0.0,
1864
+ 0.0,
1865
+ 0.0,
1866
+ 0.0,
1867
+ 0.0,
1868
+ 0.0,
1869
+ 0.0
1870
+ ],
1871
+ "min": [
1872
+ 0.0,
1873
+ 0.0,
1874
+ 0.0,
1875
+ 0.0,
1876
+ 0.0,
1877
+ 0.0,
1878
+ 0.0
1879
+ ],
1880
+ "q01": [
1881
+ 0.0,
1882
+ 0.0,
1883
+ 0.0,
1884
+ 0.0,
1885
+ 0.0,
1886
+ 0.0,
1887
+ 0.0
1888
+ ],
1889
+ "q99": [
1890
+ 0.0,
1891
+ 0.0,
1892
+ 0.0,
1893
+ 0.0,
1894
+ 0.0,
1895
+ 0.0,
1896
+ 0.0
1897
+ ],
1898
+ "std": [
1899
+ 0.0,
1900
+ 0.0,
1901
+ 0.0,
1902
+ 0.0,
1903
+ 0.0,
1904
+ 0.0,
1905
+ 0.0
1906
+ ]
1907
+ }
1908
+ },
1909
+ "jaco_play": {
1910
+ "action": {
1911
+ "mask": [
1912
+ true,
1913
+ true,
1914
+ true,
1915
+ true,
1916
+ true,
1917
+ true,
1918
+ false
1919
+ ],
1920
+ "max": [
1921
+ 0.20000000298023224,
1922
+ 0.20000000298023224,
1923
+ 0.20000000298023224,
1924
+ 0.0,
1925
+ 0.0,
1926
+ 0.0,
1927
+ 1.0
1928
+ ],
1929
+ "mean": [
1930
+ 0.0009658430935814977,
1931
+ -0.00580078037455678,
1932
+ -0.00395062193274498,
1933
+ 0.0,
1934
+ 0.0,
1935
+ 0.0,
1936
+ 0.34934908151626587
1937
+ ],
1938
+ "min": [
1939
+ -0.20000000298023224,
1940
+ -0.20000000298023224,
1941
+ -0.20000000298023224,
1942
+ 0.0,
1943
+ 0.0,
1944
+ 0.0,
1945
+ 0.0
1946
+ ],
1947
+ "q01": [
1948
+ -0.20000000298023224,
1949
+ -0.20000000298023224,
1950
+ -0.20000000298023224,
1951
+ 0.0,
1952
+ 0.0,
1953
+ 0.0,
1954
+ 0.0
1955
+ ],
1956
+ "q99": [
1957
+ 0.20000000298023224,
1958
+ 0.20000000298023224,
1959
+ 0.20000000298023224,
1960
+ 0.0,
1961
+ 0.0,
1962
+ 0.0,
1963
+ 1.0
1964
+ ],
1965
+ "std": [
1966
+ 0.12235074490308762,
1967
+ 0.09678777307271957,
1968
+ 0.11155334860086441,
1969
+ 0.0,
1970
+ 0.0,
1971
+ 0.0,
1972
+ 0.4768252968788147
1973
+ ]
1974
+ },
1975
+ "num_trajectories": 1085,
1976
+ "num_transitions": 77965,
1977
+ "proprio": {
1978
+ "max": [
1979
+ 0.0,
1980
+ 0.0,
1981
+ 0.0,
1982
+ 0.0,
1983
+ 0.0,
1984
+ 0.0,
1985
+ 0.0
1986
+ ],
1987
+ "mean": [
1988
+ 0.0,
1989
+ 0.0,
1990
+ 0.0,
1991
+ 0.0,
1992
+ 0.0,
1993
+ 0.0,
1994
+ 0.0
1995
+ ],
1996
+ "min": [
1997
+ 0.0,
1998
+ 0.0,
1999
+ 0.0,
2000
+ 0.0,
2001
+ 0.0,
2002
+ 0.0,
2003
+ 0.0
2004
+ ],
2005
+ "q01": [
2006
+ 0.0,
2007
+ 0.0,
2008
+ 0.0,
2009
+ 0.0,
2010
+ 0.0,
2011
+ 0.0,
2012
+ 0.0
2013
+ ],
2014
+ "q99": [
2015
+ 0.0,
2016
+ 0.0,
2017
+ 0.0,
2018
+ 0.0,
2019
+ 0.0,
2020
+ 0.0,
2021
+ 0.0
2022
+ ],
2023
+ "std": [
2024
+ 0.0,
2025
+ 0.0,
2026
+ 0.0,
2027
+ 0.0,
2028
+ 0.0,
2029
+ 0.0,
2030
+ 0.0
2031
+ ]
2032
+ }
2033
+ },
2034
+ "kuka": {
2035
+ "action": {
2036
+ "mask": [
2037
+ true,
2038
+ true,
2039
+ true,
2040
+ true,
2041
+ true,
2042
+ true,
2043
+ false
2044
+ ],
2045
+ "max": [
2046
+ 0.1697135865688324,
2047
+ 0.2777623236179352,
2048
+ 0.43710532784461975,
2049
+ 0.0,
2050
+ 0.0,
2051
+ 1.9684287309646606,
2052
+ 1.0
2053
+ ],
2054
+ "mean": [
2055
+ -0.0004668905457947403,
2056
+ 0.00040138536132872105,
2057
+ -0.001280792523175478,
2058
+ 0.0,
2059
+ 0.0,
2060
+ -0.03722453489899635,
2061
+ 0.4131543040275574
2062
+ ],
2063
+ "min": [
2064
+ -0.159867063164711,
2065
+ -0.2892282009124756,
2066
+ -0.2795473635196686,
2067
+ 0.0,
2068
+ 0.0,
2069
+ -1.9875637292861938,
2070
+ 0.0
2071
+ ],
2072
+ "q01": [
2073
+ -0.06619441494345665,
2074
+ -0.08713878810405731,
2075
+ -0.15083016991615295,
2076
+ 0.0,
2077
+ 0.0,
2078
+ -0.5415697038173676,
2079
+ 0.0
2080
+ ],
2081
+ "q99": [
2082
+ 0.06601839080452929,
2083
+ 0.08732476785779003,
2084
+ 0.18168179214000715,
2085
+ 0.0,
2086
+ 0.0,
2087
+ 0.2923380345106127,
2088
+ 1.0
2089
+ ],
2090
+ "std": [
2091
+ 0.02083250693976879,
2092
+ 0.02915887162089348,
2093
+ 0.06422865390777588,
2094
+ 0.0,
2095
+ 0.0,
2096
+ 0.14224295318126678,
2097
+ 0.49086448550224304
2098
+ ]
2099
+ },
2100
+ "num_trajectories": 209880,
2101
+ "num_transitions": 2455879,
2102
+ "proprio": {
2103
+ "max": [
2104
+ 0.0,
2105
+ 0.0,
2106
+ 0.0,
2107
+ 0.0,
2108
+ 0.0,
2109
+ 0.0,
2110
+ 0.0
2111
+ ],
2112
+ "mean": [
2113
+ 0.0,
2114
+ 0.0,
2115
+ 0.0,
2116
+ 0.0,
2117
+ 0.0,
2118
+ 0.0,
2119
+ 0.0
2120
+ ],
2121
+ "min": [
2122
+ 0.0,
2123
+ 0.0,
2124
+ 0.0,
2125
+ 0.0,
2126
+ 0.0,
2127
+ 0.0,
2128
+ 0.0
2129
+ ],
2130
+ "q01": [
2131
+ 0.0,
2132
+ 0.0,
2133
+ 0.0,
2134
+ 0.0,
2135
+ 0.0,
2136
+ 0.0,
2137
+ 0.0
2138
+ ],
2139
+ "q99": [
2140
+ 0.0,
2141
+ 0.0,
2142
+ 0.0,
2143
+ 0.0,
2144
+ 0.0,
2145
+ 0.0,
2146
+ 0.0
2147
+ ],
2148
+ "std": [
2149
+ 0.0,
2150
+ 0.0,
2151
+ 0.0,
2152
+ 0.0,
2153
+ 0.0,
2154
+ 0.0,
2155
+ 0.0
2156
+ ]
2157
+ }
2158
+ },
2159
+ "libero_10_no_noops": {
2160
+ "action": {
2161
+ "mask": [
2162
+ true,
2163
+ true,
2164
+ true,
2165
+ true,
2166
+ true,
2167
+ true,
2168
+ false
2169
+ ],
2170
+ "max": [
2171
+ 0.9375,
2172
+ 0.9375,
2173
+ 0.9375,
2174
+ 0.30000001192092896,
2175
+ 0.29357144236564636,
2176
+ 0.375,
2177
+ 1.0
2178
+ ],
2179
+ "mean": [
2180
+ 0.01820324920117855,
2181
+ 0.05858374014496803,
2182
+ -0.05592384561896324,
2183
+ 0.004626928828656673,
2184
+ 0.00289608770981431,
2185
+ -0.007673131301999092,
2186
+ 0.5457824468612671
2187
+ ],
2188
+ "min": [
2189
+ -0.9375,
2190
+ -0.9375,
2191
+ -0.9375,
2192
+ -0.23642857372760773,
2193
+ -0.3053571283817291,
2194
+ -0.3675000071525574,
2195
+ 0.0
2196
+ ],
2197
+ "q01": [
2198
+ -0.6348214149475098,
2199
+ -0.7741071581840515,
2200
+ -0.7633928656578064,
2201
+ -0.09749999642372131,
2202
+ -0.14819999992847435,
2203
+ -0.2742857038974762,
2204
+ 0.0
2205
+ ],
2206
+ "q99": [
2207
+ 0.7714285850524902,
2208
+ 0.8464285731315613,
2209
+ 0.9375,
2210
+ 0.13928571343421936,
2211
+ 0.15964286029338837,
2212
+ 0.3246428668498993,
2213
+ 1.0
2214
+ ],
2215
+ "std": [
2216
+ 0.2825464606285095,
2217
+ 0.35904666781425476,
2218
+ 0.3673802614212036,
2219
+ 0.03770702704787254,
2220
+ 0.05429719388484955,
2221
+ 0.08725254982709885,
2222
+ 0.49815231561660767
2223
+ ]
2224
+ },
2225
+ "num_trajectories": 379,
2226
+ "num_transitions": 101469,
2227
+ "proprio": {
2228
+ "max": [
2229
+ 0.21031762659549713,
2230
+ 0.39128610491752625,
2231
+ 1.3332009315490723,
2232
+ 3.6714255809783936,
2233
+ 3.560650587081909,
2234
+ 1.386339545249939,
2235
+ 0.04160946607589722,
2236
+ 0.0013633022317662835
2237
+ ],
2238
+ "mean": [
2239
+ -0.04190658777952194,
2240
+ 0.03539430722594261,
2241
+ 0.8257141709327698,
2242
+ 2.908308267593384,
2243
+ -0.5562185049057007,
2244
+ -0.16649018228054047,
2245
+ 0.028316624462604523,
2246
+ -0.028561657294631004
2247
+ ],
2248
+ "min": [
2249
+ -0.4828203022480011,
2250
+ -0.3255046010017395,
2251
+ 0.445506751537323,
2252
+ 1.1321442127227783,
2253
+ -3.641430377960205,
2254
+ -1.842738389968872,
2255
+ -0.0010040868073701859,
2256
+ -0.04111652821302414
2257
+ ],
2258
+ "q01": [
2259
+ -0.3899900782108307,
2260
+ -0.2838300323486328,
2261
+ 0.44795057058334353,
2262
+ 1.8810229921340942,
2263
+ -2.886677579879761,
2264
+ -1.1599004411697387,
2265
+ 0.002066459748893976,
2266
+ -0.04001387819647789
2267
+ ],
2268
+ "q99": [
2269
+ 0.1530261474847791,
2270
+ 0.32915401458740223,
2271
+ 1.2546923208236693,
2272
+ 3.303542451858519,
2273
+ 2.7496529006957933,
2274
+ 0.6893712210655194,
2275
+ 0.040048558115959164,
2276
+ -0.0017598449345678235
2277
+ ],
2278
+ "std": [
2279
+ 0.10743364691734314,
2280
+ 0.14424669742584229,
2281
+ 0.2572328448295593,
2282
+ 0.3441362977027893,
2283
+ 1.234421730041504,
2284
+ 0.3579835891723633,
2285
+ 0.013308707624673843,
2286
+ 0.013174631632864475
2287
+ ]
2288
+ }
2289
+ },
2290
+ "nyu_franka_play_dataset_converted_externally_to_rlds": {
2291
+ "action": {
2292
+ "mask": [
2293
+ true,
2294
+ true,
2295
+ true,
2296
+ true,
2297
+ true,
2298
+ true,
2299
+ false
2300
+ ],
2301
+ "max": [
2302
+ 0.06424188613891602,
2303
+ 0.07027634978294373,
2304
+ 0.06129661202430725,
2305
+ 6.281067848205566,
2306
+ 0.1967729926109314,
2307
+ 0.26377415657043457,
2308
+ 1.0
2309
+ ],
2310
+ "mean": [
2311
+ 0.001021989737637341,
2312
+ -0.00012002651783404872,
2313
+ 0.00032894269679673016,
2314
+ 0.0015034361276775599,
2315
+ -0.002198522910475731,
2316
+ -0.001663230243138969,
2317
+ 0.7230083346366882
2318
+ ],
2319
+ "min": [
2320
+ -0.05952230095863342,
2321
+ -0.07232445478439331,
2322
+ -0.06730806827545166,
2323
+ -6.278434753417969,
2324
+ -0.21479034423828125,
2325
+ -0.3627619743347168,
2326
+ 0.0
2327
+ ],
2328
+ "q01": [
2329
+ -0.03199600875377655,
2330
+ -0.032861671447753905,
2331
+ -0.03368805110454559,
2332
+ -0.12080862045288086,
2333
+ -0.12175218224525451,
2334
+ -0.11370223641395569,
2335
+ 0.0
2336
+ ],
2337
+ "q99": [
2338
+ 0.03101520001888276,
2339
+ 0.0373908892273903,
2340
+ 0.03646374464035038,
2341
+ 0.11764093399047852,
2342
+ 0.1258920183777809,
2343
+ 0.09366151213645942,
2344
+ 1.0
2345
+ ],
2346
+ "std": [
2347
+ 0.01327415369451046,
2348
+ 0.013215910643339157,
2349
+ 0.012822109274566174,
2350
+ 0.2732451558113098,
2351
+ 0.057022541761398315,
2352
+ 0.039172880351543427,
2353
+ 0.44752755761146545
2354
+ ]
2355
+ },
2356
+ "num_trajectories": 456,
2357
+ "num_transitions": 44875,
2358
+ "proprio": {
2359
+ "max": [
2360
+ 0.0,
2361
+ 0.0,
2362
+ 0.0,
2363
+ 0.0,
2364
+ 0.0,
2365
+ 0.0,
2366
+ 0.0
2367
+ ],
2368
+ "mean": [
2369
+ 0.0,
2370
+ 0.0,
2371
+ 0.0,
2372
+ 0.0,
2373
+ 0.0,
2374
+ 0.0,
2375
+ 0.0
2376
+ ],
2377
+ "min": [
2378
+ 0.0,
2379
+ 0.0,
2380
+ 0.0,
2381
+ 0.0,
2382
+ 0.0,
2383
+ 0.0,
2384
+ 0.0
2385
+ ],
2386
+ "q01": [
2387
+ 0.0,
2388
+ 0.0,
2389
+ 0.0,
2390
+ 0.0,
2391
+ 0.0,
2392
+ 0.0,
2393
+ 0.0
2394
+ ],
2395
+ "q99": [
2396
+ 0.0,
2397
+ 0.0,
2398
+ 0.0,
2399
+ 0.0,
2400
+ 0.0,
2401
+ 0.0,
2402
+ 0.0
2403
+ ],
2404
+ "std": [
2405
+ 0.0,
2406
+ 0.0,
2407
+ 0.0,
2408
+ 0.0,
2409
+ 0.0,
2410
+ 0.0,
2411
+ 0.0
2412
+ ]
2413
+ }
2414
+ },
2415
+ "roboturk": {
2416
+ "action": {
2417
+ "mask": [
2418
+ true,
2419
+ true,
2420
+ true,
2421
+ true,
2422
+ true,
2423
+ true,
2424
+ false
2425
+ ],
2426
+ "max": [
2427
+ 0.39124172925949097,
2428
+ 0.4601028263568878,
2429
+ 0.4870833456516266,
2430
+ 1.816888689994812,
2431
+ 1.8240282535552979,
2432
+ 1.4824820756912231,
2433
+ 1.0
2434
+ ],
2435
+ "mean": [
2436
+ 0.0014448732836171985,
2437
+ -0.0015945249469950795,
2438
+ -0.0011753785656765103,
2439
+ 0.0023012510500848293,
2440
+ -0.0009382463176734746,
2441
+ -0.00011485807772260159,
2442
+ 0.5746025443077087
2443
+ ],
2444
+ "min": [
2445
+ -0.6546999216079712,
2446
+ -0.6365841031074524,
2447
+ -0.4217723608016968,
2448
+ -1.6695482730865479,
2449
+ -1.8023357391357422,
2450
+ -1.4630827903747559,
2451
+ 0.0
2452
+ ],
2453
+ "q01": [
2454
+ -0.1342635464668274,
2455
+ -0.19996687173843383,
2456
+ -0.1482972100377083,
2457
+ -0.20720748245716095,
2458
+ -0.09676413893699647,
2459
+ -0.18075634717941286,
2460
+ 0.0
2461
+ ],
2462
+ "q99": [
2463
+ 0.14956976801157001,
2464
+ 0.1805950567126275,
2465
+ 0.18841815620660796,
2466
+ 0.21615413755178453,
2467
+ 0.09457383215427405,
2468
+ 0.18543301910162005,
2469
+ 1.0
2470
+ ],
2471
+ "std": [
2472
+ 0.04935386776924133,
2473
+ 0.0635455846786499,
2474
+ 0.061164740473032,
2475
+ 0.09553450345993042,
2476
+ 0.08420111238956451,
2477
+ 0.06517903506755829,
2478
+ 0.49452081322669983
2479
+ ]
2480
+ },
2481
+ "num_trajectories": 1995,
2482
+ "num_transitions": 187507,
2483
+ "proprio": {
2484
+ "max": [
2485
+ 0.0,
2486
+ 0.0,
2487
+ 0.0,
2488
+ 0.0,
2489
+ 0.0,
2490
+ 0.0,
2491
+ 0.0
2492
+ ],
2493
+ "mean": [
2494
+ 0.0,
2495
+ 0.0,
2496
+ 0.0,
2497
+ 0.0,
2498
+ 0.0,
2499
+ 0.0,
2500
+ 0.0
2501
+ ],
2502
+ "min": [
2503
+ 0.0,
2504
+ 0.0,
2505
+ 0.0,
2506
+ 0.0,
2507
+ 0.0,
2508
+ 0.0,
2509
+ 0.0
2510
+ ],
2511
+ "q01": [
2512
+ 0.0,
2513
+ 0.0,
2514
+ 0.0,
2515
+ 0.0,
2516
+ 0.0,
2517
+ 0.0,
2518
+ 0.0
2519
+ ],
2520
+ "q99": [
2521
+ 0.0,
2522
+ 0.0,
2523
+ 0.0,
2524
+ 0.0,
2525
+ 0.0,
2526
+ 0.0,
2527
+ 0.0
2528
+ ],
2529
+ "std": [
2530
+ 0.0,
2531
+ 0.0,
2532
+ 0.0,
2533
+ 0.0,
2534
+ 0.0,
2535
+ 0.0,
2536
+ 0.0
2537
+ ]
2538
+ }
2539
+ },
2540
+ "stanford_hydra_dataset_converted_externally_to_rlds": {
2541
+ "action": {
2542
+ "mask": [
2543
+ true,
2544
+ true,
2545
+ true,
2546
+ true,
2547
+ true,
2548
+ true,
2549
+ false
2550
+ ],
2551
+ "max": [
2552
+ 0.02499854564666748,
2553
+ 0.02499903365969658,
2554
+ 0.024999922141432762,
2555
+ 0.24974457919597626,
2556
+ 0.24997030198574066,
2557
+ 0.24999946355819702,
2558
+ 1.0
2559
+ ],
2560
+ "mean": [
2561
+ 0.0007790001109242439,
2562
+ 0.00013707754260394722,
2563
+ -0.0002548607881180942,
2564
+ 0.0012903271708637476,
2565
+ -0.004751681815832853,
2566
+ 0.002692886395379901,
2567
+ 0.48855218291282654
2568
+ ],
2569
+ "min": [
2570
+ -0.024999044835567474,
2571
+ -0.024999700486660004,
2572
+ -0.02499929815530777,
2573
+ -0.24993225932121277,
2574
+ -0.2499666064977646,
2575
+ -0.2499932497739792,
2576
+ 0.0
2577
+ ],
2578
+ "q01": [
2579
+ -0.019992006458342076,
2580
+ -0.02415412735193968,
2581
+ -0.022941758055239916,
2582
+ -0.11085530579090118,
2583
+ -0.12024572037160397,
2584
+ -0.13314770206809043,
2585
+ 0.0
2586
+ ],
2587
+ "q99": [
2588
+ 0.022886231057345868,
2589
+ 0.022358838934451335,
2590
+ 0.02410089675337076,
2591
+ 0.12370114490389822,
2592
+ 0.11323311634361738,
2593
+ 0.18474749639630164,
2594
+ 1.0
2595
+ ],
2596
+ "std": [
2597
+ 0.008022161200642586,
2598
+ 0.009131459519267082,
2599
+ 0.009574338793754578,
2600
+ 0.04122216999530792,
2601
+ 0.0384303517639637,
2602
+ 0.04606688767671585,
2603
+ 0.49976691603660583
2604
+ ]
2605
+ },
2606
+ "num_trajectories": 570,
2607
+ "num_transitions": 358234,
2608
+ "proprio": {
2609
+ "max": [
2610
+ 0.0,
2611
+ 0.0,
2612
+ 0.0,
2613
+ 0.0,
2614
+ 0.0,
2615
+ 0.0,
2616
+ 0.0
2617
+ ],
2618
+ "mean": [
2619
+ 0.0,
2620
+ 0.0,
2621
+ 0.0,
2622
+ 0.0,
2623
+ 0.0,
2624
+ 0.0,
2625
+ 0.0
2626
+ ],
2627
+ "min": [
2628
+ 0.0,
2629
+ 0.0,
2630
+ 0.0,
2631
+ 0.0,
2632
+ 0.0,
2633
+ 0.0,
2634
+ 0.0
2635
+ ],
2636
+ "q01": [
2637
+ 0.0,
2638
+ 0.0,
2639
+ 0.0,
2640
+ 0.0,
2641
+ 0.0,
2642
+ 0.0,
2643
+ 0.0
2644
+ ],
2645
+ "q99": [
2646
+ 0.0,
2647
+ 0.0,
2648
+ 0.0,
2649
+ 0.0,
2650
+ 0.0,
2651
+ 0.0,
2652
+ 0.0
2653
+ ],
2654
+ "std": [
2655
+ 0.0,
2656
+ 0.0,
2657
+ 0.0,
2658
+ 0.0,
2659
+ 0.0,
2660
+ 0.0,
2661
+ 0.0
2662
+ ]
2663
+ }
2664
+ },
2665
+ "taco_play": {
2666
+ "action": {
2667
+ "mask": [
2668
+ true,
2669
+ true,
2670
+ true,
2671
+ true,
2672
+ true,
2673
+ true,
2674
+ false
2675
+ ],
2676
+ "max": [
2677
+ 1.4915844202041626,
2678
+ 2.1842432022094727,
2679
+ 2.6836395263671875,
2680
+ 5.035226821899414,
2681
+ 2.665864944458008,
2682
+ 4.250768661499023,
2683
+ 1.0
2684
+ ],
2685
+ "mean": [
2686
+ -0.003845922416076064,
2687
+ 0.009671456180512905,
2688
+ 0.012780580669641495,
2689
+ -0.005403771996498108,
2690
+ -0.009606587700545788,
2691
+ -0.002480733208358288,
2692
+ 0.4263913035392761
2693
+ ],
2694
+ "min": [
2695
+ -4.242457866668701,
2696
+ -3.192805051803589,
2697
+ -1.3371467590332031,
2698
+ -4.202683448791504,
2699
+ -2.6722638607025146,
2700
+ -3.3467135429382324,
2701
+ 0.0
2702
+ ],
2703
+ "q01": [
2704
+ -0.7106140398979186,
2705
+ -1.056944659948349,
2706
+ -0.5878450274467468,
2707
+ -0.7682853937149048,
2708
+ -0.7180147767066956,
2709
+ -1.5527938604354858,
2710
+ 0.0
2711
+ ],
2712
+ "q99": [
2713
+ 0.6482916426658629,
2714
+ 1.0051310062408447,
2715
+ 0.9480248689651489,
2716
+ 0.6926478147506714,
2717
+ 0.6351067513227462,
2718
+ 1.628010264635086,
2719
+ 1.0
2720
+ ],
2721
+ "std": [
2722
+ 0.23254038393497467,
2723
+ 0.36298269033432007,
2724
+ 0.28692901134490967,
2725
+ 0.2617705166339874,
2726
+ 0.2438892275094986,
2727
+ 0.5216503143310547,
2728
+ 0.4946896731853485
2729
+ ]
2730
+ },
2731
+ "num_trajectories": 3603,
2732
+ "num_transitions": 237798,
2733
+ "proprio": {
2734
+ "max": [
2735
+ 0.0,
2736
+ 0.0,
2737
+ 0.0,
2738
+ 0.0,
2739
+ 0.0,
2740
+ 0.0,
2741
+ 0.0
2742
+ ],
2743
+ "mean": [
2744
+ 0.0,
2745
+ 0.0,
2746
+ 0.0,
2747
+ 0.0,
2748
+ 0.0,
2749
+ 0.0,
2750
+ 0.0
2751
+ ],
2752
+ "min": [
2753
+ 0.0,
2754
+ 0.0,
2755
+ 0.0,
2756
+ 0.0,
2757
+ 0.0,
2758
+ 0.0,
2759
+ 0.0
2760
+ ],
2761
+ "q01": [
2762
+ 0.0,
2763
+ 0.0,
2764
+ 0.0,
2765
+ 0.0,
2766
+ 0.0,
2767
+ 0.0,
2768
+ 0.0
2769
+ ],
2770
+ "q99": [
2771
+ 0.0,
2772
+ 0.0,
2773
+ 0.0,
2774
+ 0.0,
2775
+ 0.0,
2776
+ 0.0,
2777
+ 0.0
2778
+ ],
2779
+ "std": [
2780
+ 0.0,
2781
+ 0.0,
2782
+ 0.0,
2783
+ 0.0,
2784
+ 0.0,
2785
+ 0.0,
2786
+ 0.0
2787
+ ]
2788
+ }
2789
+ },
2790
+ "toto": {
2791
+ "action": {
2792
+ "mask": [
2793
+ true,
2794
+ true,
2795
+ true,
2796
+ true,
2797
+ true,
2798
+ true,
2799
+ false
2800
+ ],
2801
+ "max": [
2802
+ 0.6839867234230042,
2803
+ 0.4454185664653778,
2804
+ 0.7984078526496887,
2805
+ 2.120781660079956,
2806
+ 1.371164321899414,
2807
+ 1.4118704795837402,
2808
+ 0.0
2809
+ ],
2810
+ "mean": [
2811
+ 0.38542115688323975,
2812
+ 0.007769413758069277,
2813
+ 0.3632740378379822,
2814
+ -0.6652036905288696,
2815
+ 0.1890396922826767,
2816
+ 0.03298724442720413,
2817
+ 0.0
2818
+ ],
2819
+ "min": [
2820
+ 0.09922284632921219,
2821
+ -0.5180193781852722,
2822
+ 0.13791072368621826,
2823
+ -2.635117530822754,
2824
+ -1.0734480619430542,
2825
+ -1.9282547235488892,
2826
+ 0.0
2827
+ ],
2828
+ "q01": [
2829
+ 0.1756722891330719,
2830
+ -0.3077590811252594,
2831
+ 0.235383919775486,
2832
+ -2.0908505964279174,
2833
+ -0.6191593289375306,
2834
+ -0.7488683319091797,
2835
+ 0.0
2836
+ ],
2837
+ "q99": [
2838
+ 0.6136963081359863,
2839
+ 0.33704194784164443,
2840
+ 0.6681221985816956,
2841
+ 0.7422861719131538,
2842
+ 0.7955395007133507,
2843
+ 0.740464625358582,
2844
+ 0.0
2845
+ ],
2846
+ "std": [
2847
+ 0.12211652100086212,
2848
+ 0.19378550350666046,
2849
+ 0.10178236663341522,
2850
+ 0.5725259184837341,
2851
+ 0.29884573817253113,
2852
+ 0.3259911835193634,
2853
+ 0.0
2854
+ ]
2855
+ },
2856
+ "num_trajectories": 1003,
2857
+ "num_transitions": 325699,
2858
+ "proprio": {
2859
+ "max": [
2860
+ 0.0,
2861
+ 0.0,
2862
+ 0.0,
2863
+ 0.0,
2864
+ 0.0,
2865
+ 0.0,
2866
+ 0.0
2867
+ ],
2868
+ "mean": [
2869
+ 0.0,
2870
+ 0.0,
2871
+ 0.0,
2872
+ 0.0,
2873
+ 0.0,
2874
+ 0.0,
2875
+ 0.0
2876
+ ],
2877
+ "min": [
2878
+ 0.0,
2879
+ 0.0,
2880
+ 0.0,
2881
+ 0.0,
2882
+ 0.0,
2883
+ 0.0,
2884
+ 0.0
2885
+ ],
2886
+ "q01": [
2887
+ 0.0,
2888
+ 0.0,
2889
+ 0.0,
2890
+ 0.0,
2891
+ 0.0,
2892
+ 0.0,
2893
+ 0.0
2894
+ ],
2895
+ "q99": [
2896
+ 0.0,
2897
+ 0.0,
2898
+ 0.0,
2899
+ 0.0,
2900
+ 0.0,
2901
+ 0.0,
2902
+ 0.0
2903
+ ],
2904
+ "std": [
2905
+ 0.0,
2906
+ 0.0,
2907
+ 0.0,
2908
+ 0.0,
2909
+ 0.0,
2910
+ 0.0,
2911
+ 0.0
2912
+ ]
2913
+ }
2914
+ },
2915
+ "ucsd_kitchen_dataset_converted_externally_to_rlds": {
2916
+ "action": {
2917
+ "mask": [
2918
+ true,
2919
+ true,
2920
+ true,
2921
+ true,
2922
+ true,
2923
+ true,
2924
+ false
2925
+ ],
2926
+ "max": [
2927
+ 678.0,
2928
+ 400.0,
2929
+ 507.0,
2930
+ 180.00001525878906,
2931
+ 6.000013828277588,
2932
+ 116.99998474121094,
2933
+ 1.0
2934
+ ],
2935
+ "mean": [
2936
+ 410.37567138671875,
2937
+ 116.9518814086914,
2938
+ 192.35032653808594,
2939
+ -121.22441864013672,
2940
+ -33.84893035888672,
2941
+ 50.016136169433594,
2942
+ 0.741813600063324
2943
+ ],
2944
+ "min": [
2945
+ 172.0,
2946
+ -166.0,
2947
+ -99.99999237060547,
2948
+ -180.00001525878906,
2949
+ -89.0,
2950
+ -96.00010681152344,
2951
+ 0.0
2952
+ ],
2953
+ "q01": [
2954
+ 200.00001052856445,
2955
+ -102.31004211425781,
2956
+ -94.99993370056153,
2957
+ -180.00001525878906,
2958
+ -88.00001525878906,
2959
+ -38.999977111816406,
2960
+ 0.0
2961
+ ],
2962
+ "q99": [
2963
+ 637.0,
2964
+ 368.30999999999995,
2965
+ 493.0,
2966
+ 180.00001525878906,
2967
+ 0.999983012676239,
2968
+ 105.00001525878906,
2969
+ 1.0
2970
+ ],
2971
+ "std": [
2972
+ 122.81494903564453,
2973
+ 108.8009033203125,
2974
+ 130.303466796875,
2975
+ 116.28205108642578,
2976
+ 27.621843338012695,
2977
+ 41.02094650268555,
2978
+ 0.43763357400894165
2979
+ ]
2980
+ },
2981
+ "num_trajectories": 150,
2982
+ "num_transitions": 3970,
2983
+ "proprio": {
2984
+ "max": [
2985
+ 0.0,
2986
+ 0.0,
2987
+ 0.0,
2988
+ 0.0,
2989
+ 0.0,
2990
+ 0.0,
2991
+ 0.0
2992
+ ],
2993
+ "mean": [
2994
+ 0.0,
2995
+ 0.0,
2996
+ 0.0,
2997
+ 0.0,
2998
+ 0.0,
2999
+ 0.0,
3000
+ 0.0
3001
+ ],
3002
+ "min": [
3003
+ 0.0,
3004
+ 0.0,
3005
+ 0.0,
3006
+ 0.0,
3007
+ 0.0,
3008
+ 0.0,
3009
+ 0.0
3010
+ ],
3011
+ "q01": [
3012
+ 0.0,
3013
+ 0.0,
3014
+ 0.0,
3015
+ 0.0,
3016
+ 0.0,
3017
+ 0.0,
3018
+ 0.0
3019
+ ],
3020
+ "q99": [
3021
+ 0.0,
3022
+ 0.0,
3023
+ 0.0,
3024
+ 0.0,
3025
+ 0.0,
3026
+ 0.0,
3027
+ 0.0
3028
+ ],
3029
+ "std": [
3030
+ 0.0,
3031
+ 0.0,
3032
+ 0.0,
3033
+ 0.0,
3034
+ 0.0,
3035
+ 0.0,
3036
+ 0.0
3037
+ ]
3038
+ }
3039
+ },
3040
+ "utaustin_mutex": {
3041
+ "action": {
3042
+ "mask": [
3043
+ true,
3044
+ true,
3045
+ true,
3046
+ true,
3047
+ true,
3048
+ true,
3049
+ false
3050
+ ],
3051
+ "max": [
3052
+ 1.0,
3053
+ 1.0,
3054
+ 1.0,
3055
+ 0.375,
3056
+ 0.375,
3057
+ 0.375,
3058
+ 1.0
3059
+ ],
3060
+ "mean": [
3061
+ 0.06176406890153885,
3062
+ -0.005005486309528351,
3063
+ 0.10216785222291946,
3064
+ -0.03314131125807762,
3065
+ 0.013895004987716675,
3066
+ -0.011317633092403412,
3067
+ 0.5038976669311523
3068
+ ],
3069
+ "min": [
3070
+ -1.0,
3071
+ -1.0,
3072
+ -1.0,
3073
+ -0.375,
3074
+ -0.375,
3075
+ -0.375,
3076
+ 0.0
3077
+ ],
3078
+ "q01": [
3079
+ -0.4285714328289032,
3080
+ -0.9800000190734863,
3081
+ -0.5571428537368774,
3082
+ -0.375,
3083
+ -0.15642857551574707,
3084
+ -0.335357129573822,
3085
+ 0.0
3086
+ ],
3087
+ "q99": [
3088
+ 0.5914285778999329,
3089
+ 0.9714285731315613,
3090
+ 1.0,
3091
+ 0.3278571367263794,
3092
+ 0.207857146859169,
3093
+ 0.25607141852378845,
3094
+ 1.0
3095
+ ],
3096
+ "std": [
3097
+ 0.1875014752149582,
3098
+ 0.4468473494052887,
3099
+ 0.3792876601219177,
3100
+ 0.14097853004932404,
3101
+ 0.06453701853752136,
3102
+ 0.11765272170305252,
3103
+ 0.501045286655426
3104
+ ]
3105
+ },
3106
+ "num_trajectories": 1500,
3107
+ "num_transitions": 361883,
3108
+ "proprio": {
3109
+ "max": [
3110
+ 0.0,
3111
+ 0.0,
3112
+ 0.0,
3113
+ 0.0,
3114
+ 0.0,
3115
+ 0.0,
3116
+ 0.0
3117
+ ],
3118
+ "mean": [
3119
+ 0.0,
3120
+ 0.0,
3121
+ 0.0,
3122
+ 0.0,
3123
+ 0.0,
3124
+ 0.0,
3125
+ 0.0
3126
+ ],
3127
+ "min": [
3128
+ 0.0,
3129
+ 0.0,
3130
+ 0.0,
3131
+ 0.0,
3132
+ 0.0,
3133
+ 0.0,
3134
+ 0.0
3135
+ ],
3136
+ "q01": [
3137
+ 0.0,
3138
+ 0.0,
3139
+ 0.0,
3140
+ 0.0,
3141
+ 0.0,
3142
+ 0.0,
3143
+ 0.0
3144
+ ],
3145
+ "q99": [
3146
+ 0.0,
3147
+ 0.0,
3148
+ 0.0,
3149
+ 0.0,
3150
+ 0.0,
3151
+ 0.0,
3152
+ 0.0
3153
+ ],
3154
+ "std": [
3155
+ 0.0,
3156
+ 0.0,
3157
+ 0.0,
3158
+ 0.0,
3159
+ 0.0,
3160
+ 0.0,
3161
+ 0.0
3162
+ ]
3163
+ }
3164
+ },
3165
+ "viola": {
3166
+ "action": {
3167
+ "mask": [
3168
+ true,
3169
+ true,
3170
+ true,
3171
+ true,
3172
+ true,
3173
+ true,
3174
+ false
3175
+ ],
3176
+ "max": [
3177
+ 1.0,
3178
+ 1.0,
3179
+ 1.0,
3180
+ 0.375,
3181
+ 0.36321428418159485,
3182
+ 0.375,
3183
+ 1.0
3184
+ ],
3185
+ "mean": [
3186
+ 0.04761844128370285,
3187
+ -0.029204415157437325,
3188
+ 0.05586736649274826,
3189
+ -0.002618510741740465,
3190
+ 0.006867344491183758,
3191
+ -0.01682133786380291,
3192
+ 0.7323777675628662
3193
+ ],
3194
+ "min": [
3195
+ -1.0,
3196
+ -1.0,
3197
+ -1.0,
3198
+ -0.375,
3199
+ -0.375,
3200
+ -0.375,
3201
+ 0.0
3202
+ ],
3203
+ "q01": [
3204
+ -0.9628571271896362,
3205
+ -1.0,
3206
+ -1.0,
3207
+ -0.26249998807907104,
3208
+ -0.21321429312229156,
3209
+ -0.3385714292526245,
3210
+ 0.0
3211
+ ],
3212
+ "q99": [
3213
+ 0.9114285707473755,
3214
+ 0.868571400642395,
3215
+ 1.0,
3216
+ 0.2817857265472412,
3217
+ 0.2239285707473755,
3218
+ 0.3557142913341522,
3219
+ 1.0
3220
+ ],
3221
+ "std": [
3222
+ 0.39157867431640625,
3223
+ 0.4076525568962097,
3224
+ 0.40077948570251465,
3225
+ 0.10023996233940125,
3226
+ 0.0844319611787796,
3227
+ 0.10375042259693146,
3228
+ 0.44260647892951965
3229
+ ]
3230
+ },
3231
+ "num_trajectories": 150,
3232
+ "num_transitions": 76324,
3233
+ "proprio": {
3234
+ "max": [
3235
+ 0.0,
3236
+ 0.0,
3237
+ 0.0,
3238
+ 0.0,
3239
+ 0.0,
3240
+ 0.0,
3241
+ 0.0
3242
+ ],
3243
+ "mean": [
3244
+ 0.0,
3245
+ 0.0,
3246
+ 0.0,
3247
+ 0.0,
3248
+ 0.0,
3249
+ 0.0,
3250
+ 0.0
3251
+ ],
3252
+ "min": [
3253
+ 0.0,
3254
+ 0.0,
3255
+ 0.0,
3256
+ 0.0,
3257
+ 0.0,
3258
+ 0.0,
3259
+ 0.0
3260
+ ],
3261
+ "q01": [
3262
+ 0.0,
3263
+ 0.0,
3264
+ 0.0,
3265
+ 0.0,
3266
+ 0.0,
3267
+ 0.0,
3268
+ 0.0
3269
+ ],
3270
+ "q99": [
3271
+ 0.0,
3272
+ 0.0,
3273
+ 0.0,
3274
+ 0.0,
3275
+ 0.0,
3276
+ 0.0,
3277
+ 0.0
3278
+ ],
3279
+ "std": [
3280
+ 0.0,
3281
+ 0.0,
3282
+ 0.0,
3283
+ 0.0,
3284
+ 0.0,
3285
+ 0.0,
3286
+ 0.0
3287
+ ]
3288
+ }
3289
+ }
3290
+ },
3291
+ "num_action_chunks": 8,
3292
+ "num_images_in_input": 1,
3293
+ "output_projector_states": false,
3294
+ "pad_to_multiple_of": 64,
3295
+ "pad_token_id": 32000,
3296
+ "policy_setup": "widowx_bridge",
3297
+ "precision": "bf16",
3298
+ "text_config": {
3299
+ "model_type": "llama",
3300
+ "pad_token_id": 32000,
3301
+ "torch_dtype": "bfloat16",
3302
+ "vocab_size": 32064
3303
+ },
3304
+ "timm_model_ids": [
3305
+ "vit_large_patch14_reg4_dinov2.lvd142m",
3306
+ "vit_so400m_patch14_siglip_224"
3307
+ ],
3308
+ "timm_override_act_layers": [
3309
+ null,
3310
+ null
3311
+ ],
3312
+ "torch_dtype": "bfloat16",
3313
+ "transformers_version": "4.40.1",
3314
+ "trust_remote_code": true,
3315
+ "unnorm_key": "bridge_orig",
3316
+ "use_fused_vision_backbone": true,
3317
+ "use_proprio": false,
3318
+ "value_type": "chunk_level",
3319
+ "vh_mode": "a0",
3320
+ "vision_backbone_id": "dinosiglip-vit-so-224px",
3321
+ "vocab_size": 32000
3322
+ }
configuration_prismatic.py ADDED
@@ -0,0 +1,140 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ configuration_prismatic.py
3
+
4
+ HuggingFace-style configuration definition for Prismatic VLMs, inheriting from `transformers.PretrainedConfig`.
5
+ Default configuration specifies `siglip-224px+7b`.
6
+ """
7
+
8
+ from typing import Any, Dict, List, Optional
9
+
10
+ from transformers import PretrainedConfig
11
+ from transformers.models.auto import CONFIG_MAPPING
12
+
13
+ # === Utilities for Mapping Prismatic names to HF names ===
14
+ # fmt: off
15
+ VISION_BACKBONE_TO_RESOLUTION: Dict[str, List[int]] = {
16
+ "clip-vit-l": [224], "siglip-vit-so400m": [224], "dinov2-vit-l": [224], "in1k-vit-l": [224],
17
+
18
+ "clip-vit-l-336px": [336],
19
+ "siglip-vit-so400m-384px": [384],
20
+
21
+ "dinoclip-vit-l-336px": [336, 336],
22
+ "dinosiglip-vit-so-224px": [224, 224],
23
+ "dinosiglip-vit-so-384px": [384, 384],
24
+ }
25
+ VISION_BACKBONE_TO_TIMM_ID: Dict[str, List[str]] = {
26
+ "clip-vit-l": ["vit_large_patch14_clip_224.openai"],
27
+ "clip-vit-l-336px": ["vit_large_patch14_clip_336.openai"],
28
+
29
+ "dinov2-vit-l": ["vit_large_patch14_reg4_dinov2.lvd142m"],
30
+ "in1k-vit-l": ["vit_large_patch16_224.augreg_in21k_ft_in1k"],
31
+
32
+ "siglip-vit-so400m": ["vit_so400m_patch14_siglip_224"],
33
+ "siglip-vit-so400m-384px": ["vit_so400m_patch14_siglip_384"],
34
+
35
+ "dinoclip-vit-l-336px": ["vit_large_patch14_reg4_dinov2.lvd142m", "vit_large_patch14_clip_336.openai"],
36
+ "dinosiglip-vit-so-224px": ["vit_large_patch14_reg4_dinov2.lvd142m", "vit_so400m_patch14_siglip_224"],
37
+ "dinosiglip-vit-so-384px": ["vit_large_patch14_reg4_dinov2.lvd142m", "vit_so400m_patch14_siglip_384"],
38
+ }
39
+ TIMM_OVERRIDE_ACT_LAYER: Dict[str, List[Optional[str]]] = {
40
+ "clip-vit-l": ["quick_gelu"], "clip-vit-l-336px": ["quick_gelu"],
41
+ "dinov2-vit-l": [None], "in1k-vit-l": [None],
42
+ "siglip-vit-so400m": [None], "siglip-vit-so400m-384px": [None],
43
+ "dinoclip-vit-l-336px": [None, "quick_gelu"],
44
+ "dinosiglip-vit-so-224px": [None, None], "dinosiglip-vit-so-384px": [None, None]
45
+ }
46
+
47
+ LLM_BACKBONE_TO_HF_PATH = {
48
+ "llama2-7b-pure": "meta-llama/Llama-2-7b-hf", "llama2-13b-pure": "meta-llama/Llama-2-13b-hf",
49
+ "llama2-7b-chat": "meta-llama/Llama-2-7b-chat-hf", "llama2-13b-chat": "meta-llama/Llama-2-13b-chat-hf",
50
+
51
+ "vicuna-v15-7b": "lmsys/vicuna-7b-v1.5", "vicuna-v15-13b": "lmsys/vicuna-13b-v1.5",
52
+
53
+ "mistral-v0.1-7b-pure": "mistralai/Mistral-7B-v0.1",
54
+ "mistral-v0.1-7b-instruct": "mistralai/Mistral-7B-Instruct-v0.1",
55
+
56
+ "phi-2-3b": "microsoft/phi-2",
57
+ }
58
+ LLM_BACKBONE_TO_HF_METACLASS = {
59
+ "llama2-7b-pure": "llama", "llama2-13b-pure": "llama", "llama2-7b-chat": "llama", "llama2-13b-chat": "llama",
60
+ "vicuna-v15-7b": "llama", "vicuna-v15-13b": "llama",
61
+
62
+ "mistral-v0.1-7b-pure": "mistral", "mistral-v0.1-7b-instruct": "mistral",
63
+
64
+ "phi-2-3b": "phi",
65
+ }
66
+
67
+ VALID_VISION_BACKBONES = set(VISION_BACKBONE_TO_RESOLUTION.keys())
68
+ VALID_LLM_BACKBONES = set(LLM_BACKBONE_TO_HF_PATH)
69
+ # fmt: on
70
+
71
+
72
+ class PrismaticConfig(PretrainedConfig):
73
+ model_type: str = "prismatic"
74
+ is_composition: bool = False
75
+
76
+ def __init__(
77
+ self,
78
+ vision_backbone_id: str = "siglip-vit-so400m",
79
+ llm_backbone_id: str = "vicuna-v15-7b",
80
+ arch_specifier: str = "no-align+gelu-mlp",
81
+ use_fused_vision_backbone: Optional[bool] = None,
82
+ image_resize_strategy: str = "letterbox",
83
+ text_config: Optional[Dict[str, Any]] = None,
84
+ llm_max_length: int = 2048,
85
+ pad_token_id: int = 32000,
86
+ pad_to_multiple_of: int = 64,
87
+ output_projector_states: bool = False,
88
+ **kwargs: str,
89
+ ) -> None:
90
+ if vision_backbone_id not in VALID_VISION_BACKBONES:
91
+ raise ValueError(f"Vision backbone `{vision_backbone_id}` not in {VALID_VISION_BACKBONES = }")
92
+
93
+ if llm_backbone_id not in VALID_LLM_BACKBONES:
94
+ raise ValueError(f"LLM backbone `{llm_backbone_id}` not in {VALID_LLM_BACKBONES = }")
95
+
96
+ # Set Prismatic Configuration Fields
97
+ self.vision_backbone_id = vision_backbone_id
98
+ self.llm_backbone_id = llm_backbone_id
99
+ self.arch_specifier = arch_specifier
100
+ self.output_projector_states = output_projector_states
101
+
102
+ # [Contract] All vision backbone parameters are lists =>> supports fused backbones with different preprocessing
103
+ self.use_fused_vision_backbone = (
104
+ use_fused_vision_backbone
105
+ if use_fused_vision_backbone is not None
106
+ else any(self.vision_backbone_id.startswith(v) for v in ["dinoclip", "dinosiglip"])
107
+ )
108
+
109
+ self.timm_model_ids = VISION_BACKBONE_TO_TIMM_ID[self.vision_backbone_id]
110
+ self.timm_override_act_layers = TIMM_OVERRIDE_ACT_LAYER[self.vision_backbone_id]
111
+ self.image_sizes = VISION_BACKBONE_TO_RESOLUTION[self.vision_backbone_id]
112
+ self.image_resize_strategy = image_resize_strategy
113
+
114
+ self.hf_llm_id = LLM_BACKBONE_TO_HF_PATH[self.llm_backbone_id]
115
+ self.llm_max_length = llm_max_length
116
+ self.pad_token_id, self.pad_to_multiple_of = pad_token_id, pad_to_multiple_of
117
+
118
+ # [IMPORTANT] HF Utilities actually look for a `text_config` field... we need to use that specific naming!
119
+ self.text_config = (
120
+ CONFIG_MAPPING[LLM_BACKBONE_TO_HF_METACLASS[self.llm_backbone_id]](**text_config)
121
+ if text_config is not None
122
+ else CONFIG_MAPPING[LLM_BACKBONE_TO_HF_METACLASS[self.llm_backbone_id]]()
123
+ )
124
+
125
+ # Dispatch **kwargs to super() =>> note that `pad_token_id` collides, so we pass it in here as well...
126
+ super().__init__(pad_token_id=pad_token_id, **kwargs)
127
+
128
+
129
+ class OpenVLAConfig(PrismaticConfig):
130
+ model_type: str = "openvla"
131
+
132
+ def __init__(
133
+ self,
134
+ norm_stats: Optional[Dict[str, Dict[str, Dict[str, Dict[str, List[float]]]]]] = None,
135
+ n_action_bins: int = 256,
136
+ **kwargs: str,
137
+ ) -> None:
138
+ self.norm_stats, self.n_action_bins = norm_stats, n_action_bins
139
+
140
+ super().__init__(**kwargs)
dataset_statistics.json ADDED
@@ -0,0 +1,133 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "libero_10_no_noops": {
3
+ "action": {
4
+ "mean": [
5
+ 0.01820324920117855,
6
+ 0.05858374014496803,
7
+ -0.05592384561896324,
8
+ 0.004626928828656673,
9
+ 0.00289608770981431,
10
+ -0.007673131301999092,
11
+ 0.5457824468612671
12
+ ],
13
+ "std": [
14
+ 0.2825464606285095,
15
+ 0.35904666781425476,
16
+ 0.3673802614212036,
17
+ 0.03770702704787254,
18
+ 0.05429719388484955,
19
+ 0.08725254982709885,
20
+ 0.49815231561660767
21
+ ],
22
+ "max": [
23
+ 0.9375,
24
+ 0.9375,
25
+ 0.9375,
26
+ 0.30000001192092896,
27
+ 0.29357144236564636,
28
+ 0.375,
29
+ 1.0
30
+ ],
31
+ "min": [
32
+ -0.9375,
33
+ -0.9375,
34
+ -0.9375,
35
+ -0.23642857372760773,
36
+ -0.3053571283817291,
37
+ -0.3675000071525574,
38
+ 0.0
39
+ ],
40
+ "q01": [
41
+ -0.6348214149475098,
42
+ -0.7741071581840515,
43
+ -0.7633928656578064,
44
+ -0.09749999642372131,
45
+ -0.14819999992847435,
46
+ -0.2742857038974762,
47
+ 0.0
48
+ ],
49
+ "q99": [
50
+ 0.7714285850524902,
51
+ 0.8464285731315613,
52
+ 0.9375,
53
+ 0.13928571343421936,
54
+ 0.15964286029338837,
55
+ 0.3246428668498993,
56
+ 1.0
57
+ ],
58
+ "mask": [
59
+ true,
60
+ true,
61
+ true,
62
+ true,
63
+ true,
64
+ true,
65
+ false
66
+ ]
67
+ },
68
+ "proprio": {
69
+ "mean": [
70
+ -0.04190658777952194,
71
+ 0.03539430722594261,
72
+ 0.8257141709327698,
73
+ 2.908308267593384,
74
+ -0.5562185049057007,
75
+ -0.16649018228054047,
76
+ 0.028316624462604523,
77
+ -0.028561657294631004
78
+ ],
79
+ "std": [
80
+ 0.10743364691734314,
81
+ 0.14424669742584229,
82
+ 0.2572328448295593,
83
+ 0.3441362977027893,
84
+ 1.234421730041504,
85
+ 0.3579835891723633,
86
+ 0.013308707624673843,
87
+ 0.013174631632864475
88
+ ],
89
+ "max": [
90
+ 0.21031762659549713,
91
+ 0.39128610491752625,
92
+ 1.3332009315490723,
93
+ 3.6714255809783936,
94
+ 3.560650587081909,
95
+ 1.386339545249939,
96
+ 0.04160946607589722,
97
+ 0.0013633022317662835
98
+ ],
99
+ "min": [
100
+ -0.4828203022480011,
101
+ -0.3255046010017395,
102
+ 0.445506751537323,
103
+ 1.1321442127227783,
104
+ -3.641430377960205,
105
+ -1.842738389968872,
106
+ -0.0010040868073701859,
107
+ -0.04111652821302414
108
+ ],
109
+ "q01": [
110
+ -0.3899900782108307,
111
+ -0.2838300323486328,
112
+ 0.44795057058334353,
113
+ 1.8810229921340942,
114
+ -2.886677579879761,
115
+ -1.1599004411697387,
116
+ 0.002066459748893976,
117
+ -0.04001387819647789
118
+ ],
119
+ "q99": [
120
+ 0.1530261474847791,
121
+ 0.32915401458740223,
122
+ 1.2546923208236693,
123
+ 3.303542451858519,
124
+ 2.7496529006957933,
125
+ 0.6893712210655194,
126
+ 0.040048558115959164,
127
+ -0.0017598449345678235
128
+ ]
129
+ },
130
+ "num_transitions": 101469,
131
+ "num_trajectories": 379
132
+ }
133
+ }
generation_config.json ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ {
2
+ "_from_model_config": true,
3
+ "bos_token_id": 1,
4
+ "eos_token_id": 2,
5
+ "pad_token_id": 32000,
6
+ "transformers_version": "4.40.1"
7
+ }
logo.svg ADDED
model-00001-of-00004.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8614d125a9731a370fd8c741a46ab08a39b632aa001dc51f2196e93dd856c196
3
+ size 4925122448
model-00002-of-00004.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:546178903b990c4a49ce959fa77a2b8701dfa04253214a6a0f323a1c909da369
3
+ size 4947392496
model-00003-of-00004.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:fabd1cd9f6ca3c749238f3a72e9eb829d71b8402aa94c5ced96b06172cfc511c
3
+ size 4947417456
model-00004-of-00004.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:366d4543637c1793e1badd974129d9b484b4204a4111630423f1e7bdb30dc51e
3
+ size 266995832
model.safetensors.index.json ADDED
@@ -0,0 +1,994 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "metadata": {
3
+ "total_size": 15086801280
4
+ },
5
+ "weight_map": {
6
+ "language_model.lm_head.weight": "model-00004-of-00004.safetensors",
7
+ "language_model.model.embed_tokens.weight": "model-00001-of-00004.safetensors",
8
+ "language_model.model.layers.0.input_layernorm.weight": "model-00001-of-00004.safetensors",
9
+ "language_model.model.layers.0.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
10
+ "language_model.model.layers.0.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
11
+ "language_model.model.layers.0.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
12
+ "language_model.model.layers.0.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
13
+ "language_model.model.layers.0.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
14
+ "language_model.model.layers.0.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
15
+ "language_model.model.layers.0.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
16
+ "language_model.model.layers.0.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
17
+ "language_model.model.layers.1.input_layernorm.weight": "model-00001-of-00004.safetensors",
18
+ "language_model.model.layers.1.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
19
+ "language_model.model.layers.1.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
20
+ "language_model.model.layers.1.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
21
+ "language_model.model.layers.1.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
22
+ "language_model.model.layers.1.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
23
+ "language_model.model.layers.1.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
24
+ "language_model.model.layers.1.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
25
+ "language_model.model.layers.1.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
26
+ "language_model.model.layers.10.input_layernorm.weight": "model-00002-of-00004.safetensors",
27
+ "language_model.model.layers.10.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
28
+ "language_model.model.layers.10.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
29
+ "language_model.model.layers.10.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
30
+ "language_model.model.layers.10.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
31
+ "language_model.model.layers.10.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
32
+ "language_model.model.layers.10.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
33
+ "language_model.model.layers.10.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
34
+ "language_model.model.layers.10.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
35
+ "language_model.model.layers.11.input_layernorm.weight": "model-00002-of-00004.safetensors",
36
+ "language_model.model.layers.11.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
37
+ "language_model.model.layers.11.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
38
+ "language_model.model.layers.11.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
39
+ "language_model.model.layers.11.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
40
+ "language_model.model.layers.11.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
41
+ "language_model.model.layers.11.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
42
+ "language_model.model.layers.11.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
43
+ "language_model.model.layers.11.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
44
+ "language_model.model.layers.12.input_layernorm.weight": "model-00002-of-00004.safetensors",
45
+ "language_model.model.layers.12.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
46
+ "language_model.model.layers.12.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
47
+ "language_model.model.layers.12.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
48
+ "language_model.model.layers.12.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
49
+ "language_model.model.layers.12.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
50
+ "language_model.model.layers.12.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
51
+ "language_model.model.layers.12.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
52
+ "language_model.model.layers.12.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
53
+ "language_model.model.layers.13.input_layernorm.weight": "model-00002-of-00004.safetensors",
54
+ "language_model.model.layers.13.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
55
+ "language_model.model.layers.13.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
56
+ "language_model.model.layers.13.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
57
+ "language_model.model.layers.13.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
58
+ "language_model.model.layers.13.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
59
+ "language_model.model.layers.13.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
60
+ "language_model.model.layers.13.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
61
+ "language_model.model.layers.13.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
62
+ "language_model.model.layers.14.input_layernorm.weight": "model-00002-of-00004.safetensors",
63
+ "language_model.model.layers.14.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
64
+ "language_model.model.layers.14.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
65
+ "language_model.model.layers.14.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
66
+ "language_model.model.layers.14.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
67
+ "language_model.model.layers.14.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
68
+ "language_model.model.layers.14.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
69
+ "language_model.model.layers.14.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
70
+ "language_model.model.layers.14.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
71
+ "language_model.model.layers.15.input_layernorm.weight": "model-00002-of-00004.safetensors",
72
+ "language_model.model.layers.15.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
73
+ "language_model.model.layers.15.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
74
+ "language_model.model.layers.15.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
75
+ "language_model.model.layers.15.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
76
+ "language_model.model.layers.15.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
77
+ "language_model.model.layers.15.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
78
+ "language_model.model.layers.15.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
79
+ "language_model.model.layers.15.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
80
+ "language_model.model.layers.16.input_layernorm.weight": "model-00002-of-00004.safetensors",
81
+ "language_model.model.layers.16.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
82
+ "language_model.model.layers.16.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
83
+ "language_model.model.layers.16.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
84
+ "language_model.model.layers.16.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
85
+ "language_model.model.layers.16.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
86
+ "language_model.model.layers.16.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
87
+ "language_model.model.layers.16.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
88
+ "language_model.model.layers.16.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
89
+ "language_model.model.layers.17.input_layernorm.weight": "model-00002-of-00004.safetensors",
90
+ "language_model.model.layers.17.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
91
+ "language_model.model.layers.17.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
92
+ "language_model.model.layers.17.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
93
+ "language_model.model.layers.17.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
94
+ "language_model.model.layers.17.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
95
+ "language_model.model.layers.17.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
96
+ "language_model.model.layers.17.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
97
+ "language_model.model.layers.17.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
98
+ "language_model.model.layers.18.input_layernorm.weight": "model-00002-of-00004.safetensors",
99
+ "language_model.model.layers.18.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
100
+ "language_model.model.layers.18.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
101
+ "language_model.model.layers.18.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
102
+ "language_model.model.layers.18.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
103
+ "language_model.model.layers.18.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
104
+ "language_model.model.layers.18.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
105
+ "language_model.model.layers.18.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
106
+ "language_model.model.layers.18.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
107
+ "language_model.model.layers.19.input_layernorm.weight": "model-00003-of-00004.safetensors",
108
+ "language_model.model.layers.19.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
109
+ "language_model.model.layers.19.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
110
+ "language_model.model.layers.19.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
111
+ "language_model.model.layers.19.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
112
+ "language_model.model.layers.19.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
113
+ "language_model.model.layers.19.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
114
+ "language_model.model.layers.19.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
115
+ "language_model.model.layers.19.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
116
+ "language_model.model.layers.2.input_layernorm.weight": "model-00001-of-00004.safetensors",
117
+ "language_model.model.layers.2.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
118
+ "language_model.model.layers.2.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
119
+ "language_model.model.layers.2.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
120
+ "language_model.model.layers.2.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
121
+ "language_model.model.layers.2.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
122
+ "language_model.model.layers.2.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
123
+ "language_model.model.layers.2.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
124
+ "language_model.model.layers.2.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
125
+ "language_model.model.layers.20.input_layernorm.weight": "model-00003-of-00004.safetensors",
126
+ "language_model.model.layers.20.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
127
+ "language_model.model.layers.20.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
128
+ "language_model.model.layers.20.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
129
+ "language_model.model.layers.20.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
130
+ "language_model.model.layers.20.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
131
+ "language_model.model.layers.20.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
132
+ "language_model.model.layers.20.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
133
+ "language_model.model.layers.20.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
134
+ "language_model.model.layers.21.input_layernorm.weight": "model-00003-of-00004.safetensors",
135
+ "language_model.model.layers.21.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
136
+ "language_model.model.layers.21.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
137
+ "language_model.model.layers.21.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
138
+ "language_model.model.layers.21.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
139
+ "language_model.model.layers.21.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
140
+ "language_model.model.layers.21.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
141
+ "language_model.model.layers.21.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
142
+ "language_model.model.layers.21.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
143
+ "language_model.model.layers.22.input_layernorm.weight": "model-00003-of-00004.safetensors",
144
+ "language_model.model.layers.22.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
145
+ "language_model.model.layers.22.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
146
+ "language_model.model.layers.22.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
147
+ "language_model.model.layers.22.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
148
+ "language_model.model.layers.22.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
149
+ "language_model.model.layers.22.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
150
+ "language_model.model.layers.22.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
151
+ "language_model.model.layers.22.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
152
+ "language_model.model.layers.23.input_layernorm.weight": "model-00003-of-00004.safetensors",
153
+ "language_model.model.layers.23.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
154
+ "language_model.model.layers.23.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
155
+ "language_model.model.layers.23.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
156
+ "language_model.model.layers.23.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
157
+ "language_model.model.layers.23.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
158
+ "language_model.model.layers.23.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
159
+ "language_model.model.layers.23.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
160
+ "language_model.model.layers.23.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
161
+ "language_model.model.layers.24.input_layernorm.weight": "model-00003-of-00004.safetensors",
162
+ "language_model.model.layers.24.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
163
+ "language_model.model.layers.24.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
164
+ "language_model.model.layers.24.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
165
+ "language_model.model.layers.24.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
166
+ "language_model.model.layers.24.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
167
+ "language_model.model.layers.24.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
168
+ "language_model.model.layers.24.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
169
+ "language_model.model.layers.24.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
170
+ "language_model.model.layers.25.input_layernorm.weight": "model-00003-of-00004.safetensors",
171
+ "language_model.model.layers.25.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
172
+ "language_model.model.layers.25.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
173
+ "language_model.model.layers.25.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
174
+ "language_model.model.layers.25.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
175
+ "language_model.model.layers.25.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
176
+ "language_model.model.layers.25.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
177
+ "language_model.model.layers.25.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
178
+ "language_model.model.layers.25.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
179
+ "language_model.model.layers.26.input_layernorm.weight": "model-00003-of-00004.safetensors",
180
+ "language_model.model.layers.26.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
181
+ "language_model.model.layers.26.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
182
+ "language_model.model.layers.26.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
183
+ "language_model.model.layers.26.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
184
+ "language_model.model.layers.26.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
185
+ "language_model.model.layers.26.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
186
+ "language_model.model.layers.26.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
187
+ "language_model.model.layers.26.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
188
+ "language_model.model.layers.27.input_layernorm.weight": "model-00003-of-00004.safetensors",
189
+ "language_model.model.layers.27.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
190
+ "language_model.model.layers.27.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
191
+ "language_model.model.layers.27.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
192
+ "language_model.model.layers.27.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
193
+ "language_model.model.layers.27.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
194
+ "language_model.model.layers.27.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
195
+ "language_model.model.layers.27.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
196
+ "language_model.model.layers.27.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
197
+ "language_model.model.layers.28.input_layernorm.weight": "model-00003-of-00004.safetensors",
198
+ "language_model.model.layers.28.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
199
+ "language_model.model.layers.28.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
200
+ "language_model.model.layers.28.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
201
+ "language_model.model.layers.28.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
202
+ "language_model.model.layers.28.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
203
+ "language_model.model.layers.28.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
204
+ "language_model.model.layers.28.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
205
+ "language_model.model.layers.28.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
206
+ "language_model.model.layers.29.input_layernorm.weight": "model-00003-of-00004.safetensors",
207
+ "language_model.model.layers.29.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
208
+ "language_model.model.layers.29.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
209
+ "language_model.model.layers.29.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
210
+ "language_model.model.layers.29.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
211
+ "language_model.model.layers.29.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
212
+ "language_model.model.layers.29.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
213
+ "language_model.model.layers.29.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
214
+ "language_model.model.layers.29.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
215
+ "language_model.model.layers.3.input_layernorm.weight": "model-00001-of-00004.safetensors",
216
+ "language_model.model.layers.3.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
217
+ "language_model.model.layers.3.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
218
+ "language_model.model.layers.3.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
219
+ "language_model.model.layers.3.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
220
+ "language_model.model.layers.3.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
221
+ "language_model.model.layers.3.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
222
+ "language_model.model.layers.3.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
223
+ "language_model.model.layers.3.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
224
+ "language_model.model.layers.30.input_layernorm.weight": "model-00003-of-00004.safetensors",
225
+ "language_model.model.layers.30.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
226
+ "language_model.model.layers.30.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
227
+ "language_model.model.layers.30.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
228
+ "language_model.model.layers.30.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
229
+ "language_model.model.layers.30.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
230
+ "language_model.model.layers.30.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
231
+ "language_model.model.layers.30.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
232
+ "language_model.model.layers.30.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
233
+ "language_model.model.layers.31.input_layernorm.weight": "model-00003-of-00004.safetensors",
234
+ "language_model.model.layers.31.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
235
+ "language_model.model.layers.31.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
236
+ "language_model.model.layers.31.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
237
+ "language_model.model.layers.31.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
238
+ "language_model.model.layers.31.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
239
+ "language_model.model.layers.31.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
240
+ "language_model.model.layers.31.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
241
+ "language_model.model.layers.31.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
242
+ "language_model.model.layers.4.input_layernorm.weight": "model-00001-of-00004.safetensors",
243
+ "language_model.model.layers.4.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
244
+ "language_model.model.layers.4.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
245
+ "language_model.model.layers.4.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
246
+ "language_model.model.layers.4.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
247
+ "language_model.model.layers.4.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
248
+ "language_model.model.layers.4.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
249
+ "language_model.model.layers.4.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
250
+ "language_model.model.layers.4.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
251
+ "language_model.model.layers.5.input_layernorm.weight": "model-00001-of-00004.safetensors",
252
+ "language_model.model.layers.5.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
253
+ "language_model.model.layers.5.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
254
+ "language_model.model.layers.5.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
255
+ "language_model.model.layers.5.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
256
+ "language_model.model.layers.5.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
257
+ "language_model.model.layers.5.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
258
+ "language_model.model.layers.5.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
259
+ "language_model.model.layers.5.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
260
+ "language_model.model.layers.6.input_layernorm.weight": "model-00001-of-00004.safetensors",
261
+ "language_model.model.layers.6.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
262
+ "language_model.model.layers.6.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
263
+ "language_model.model.layers.6.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
264
+ "language_model.model.layers.6.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
265
+ "language_model.model.layers.6.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
266
+ "language_model.model.layers.6.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
267
+ "language_model.model.layers.6.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
268
+ "language_model.model.layers.6.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
269
+ "language_model.model.layers.7.input_layernorm.weight": "model-00002-of-00004.safetensors",
270
+ "language_model.model.layers.7.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
271
+ "language_model.model.layers.7.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
272
+ "language_model.model.layers.7.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
273
+ "language_model.model.layers.7.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
274
+ "language_model.model.layers.7.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
275
+ "language_model.model.layers.7.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
276
+ "language_model.model.layers.7.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
277
+ "language_model.model.layers.7.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
278
+ "language_model.model.layers.8.input_layernorm.weight": "model-00002-of-00004.safetensors",
279
+ "language_model.model.layers.8.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
280
+ "language_model.model.layers.8.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
281
+ "language_model.model.layers.8.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
282
+ "language_model.model.layers.8.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
283
+ "language_model.model.layers.8.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
284
+ "language_model.model.layers.8.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
285
+ "language_model.model.layers.8.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
286
+ "language_model.model.layers.8.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
287
+ "language_model.model.layers.9.input_layernorm.weight": "model-00002-of-00004.safetensors",
288
+ "language_model.model.layers.9.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
289
+ "language_model.model.layers.9.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
290
+ "language_model.model.layers.9.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
291
+ "language_model.model.layers.9.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
292
+ "language_model.model.layers.9.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
293
+ "language_model.model.layers.9.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
294
+ "language_model.model.layers.9.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
295
+ "language_model.model.layers.9.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
296
+ "language_model.model.norm.weight": "model-00003-of-00004.safetensors",
297
+ "projector.fc1.bias": "model-00001-of-00004.safetensors",
298
+ "projector.fc1.weight": "model-00001-of-00004.safetensors",
299
+ "projector.fc2.bias": "model-00001-of-00004.safetensors",
300
+ "projector.fc2.weight": "model-00001-of-00004.safetensors",
301
+ "projector.fc3.bias": "model-00001-of-00004.safetensors",
302
+ "projector.fc3.weight": "model-00001-of-00004.safetensors",
303
+ "value_head.head_l1.bias": "model-00004-of-00004.safetensors",
304
+ "value_head.head_l1.weight": "model-00004-of-00004.safetensors",
305
+ "value_head.head_l2.bias": "model-00004-of-00004.safetensors",
306
+ "value_head.head_l2.weight": "model-00004-of-00004.safetensors",
307
+ "value_head.head_l3.weight": "model-00004-of-00004.safetensors",
308
+ "vision_backbone.featurizer.blocks.0.attn.proj.bias": "model-00001-of-00004.safetensors",
309
+ "vision_backbone.featurizer.blocks.0.attn.proj.weight": "model-00001-of-00004.safetensors",
310
+ "vision_backbone.featurizer.blocks.0.attn.qkv.bias": "model-00001-of-00004.safetensors",
311
+ "vision_backbone.featurizer.blocks.0.attn.qkv.weight": "model-00001-of-00004.safetensors",
312
+ "vision_backbone.featurizer.blocks.0.ls1.scale_factor": "model-00001-of-00004.safetensors",
313
+ "vision_backbone.featurizer.blocks.0.ls2.scale_factor": "model-00001-of-00004.safetensors",
314
+ "vision_backbone.featurizer.blocks.0.mlp.fc1.bias": "model-00001-of-00004.safetensors",
315
+ "vision_backbone.featurizer.blocks.0.mlp.fc1.weight": "model-00001-of-00004.safetensors",
316
+ "vision_backbone.featurizer.blocks.0.mlp.fc2.bias": "model-00001-of-00004.safetensors",
317
+ "vision_backbone.featurizer.blocks.0.mlp.fc2.weight": "model-00001-of-00004.safetensors",
318
+ "vision_backbone.featurizer.blocks.0.norm1.bias": "model-00001-of-00004.safetensors",
319
+ "vision_backbone.featurizer.blocks.0.norm1.weight": "model-00001-of-00004.safetensors",
320
+ "vision_backbone.featurizer.blocks.0.norm2.bias": "model-00001-of-00004.safetensors",
321
+ "vision_backbone.featurizer.blocks.0.norm2.weight": "model-00001-of-00004.safetensors",
322
+ "vision_backbone.featurizer.blocks.1.attn.proj.bias": "model-00001-of-00004.safetensors",
323
+ "vision_backbone.featurizer.blocks.1.attn.proj.weight": "model-00001-of-00004.safetensors",
324
+ "vision_backbone.featurizer.blocks.1.attn.qkv.bias": "model-00001-of-00004.safetensors",
325
+ "vision_backbone.featurizer.blocks.1.attn.qkv.weight": "model-00001-of-00004.safetensors",
326
+ "vision_backbone.featurizer.blocks.1.ls1.scale_factor": "model-00001-of-00004.safetensors",
327
+ "vision_backbone.featurizer.blocks.1.ls2.scale_factor": "model-00001-of-00004.safetensors",
328
+ "vision_backbone.featurizer.blocks.1.mlp.fc1.bias": "model-00001-of-00004.safetensors",
329
+ "vision_backbone.featurizer.blocks.1.mlp.fc1.weight": "model-00001-of-00004.safetensors",
330
+ "vision_backbone.featurizer.blocks.1.mlp.fc2.bias": "model-00001-of-00004.safetensors",
331
+ "vision_backbone.featurizer.blocks.1.mlp.fc2.weight": "model-00001-of-00004.safetensors",
332
+ "vision_backbone.featurizer.blocks.1.norm1.bias": "model-00001-of-00004.safetensors",
333
+ "vision_backbone.featurizer.blocks.1.norm1.weight": "model-00001-of-00004.safetensors",
334
+ "vision_backbone.featurizer.blocks.1.norm2.bias": "model-00001-of-00004.safetensors",
335
+ "vision_backbone.featurizer.blocks.1.norm2.weight": "model-00001-of-00004.safetensors",
336
+ "vision_backbone.featurizer.blocks.10.attn.proj.bias": "model-00001-of-00004.safetensors",
337
+ "vision_backbone.featurizer.blocks.10.attn.proj.weight": "model-00001-of-00004.safetensors",
338
+ "vision_backbone.featurizer.blocks.10.attn.qkv.bias": "model-00001-of-00004.safetensors",
339
+ "vision_backbone.featurizer.blocks.10.attn.qkv.weight": "model-00001-of-00004.safetensors",
340
+ "vision_backbone.featurizer.blocks.10.ls1.scale_factor": "model-00001-of-00004.safetensors",
341
+ "vision_backbone.featurizer.blocks.10.ls2.scale_factor": "model-00001-of-00004.safetensors",
342
+ "vision_backbone.featurizer.blocks.10.mlp.fc1.bias": "model-00001-of-00004.safetensors",
343
+ "vision_backbone.featurizer.blocks.10.mlp.fc1.weight": "model-00001-of-00004.safetensors",
344
+ "vision_backbone.featurizer.blocks.10.mlp.fc2.bias": "model-00001-of-00004.safetensors",
345
+ "vision_backbone.featurizer.blocks.10.mlp.fc2.weight": "model-00001-of-00004.safetensors",
346
+ "vision_backbone.featurizer.blocks.10.norm1.bias": "model-00001-of-00004.safetensors",
347
+ "vision_backbone.featurizer.blocks.10.norm1.weight": "model-00001-of-00004.safetensors",
348
+ "vision_backbone.featurizer.blocks.10.norm2.bias": "model-00001-of-00004.safetensors",
349
+ "vision_backbone.featurizer.blocks.10.norm2.weight": "model-00001-of-00004.safetensors",
350
+ "vision_backbone.featurizer.blocks.11.attn.proj.bias": "model-00001-of-00004.safetensors",
351
+ "vision_backbone.featurizer.blocks.11.attn.proj.weight": "model-00001-of-00004.safetensors",
352
+ "vision_backbone.featurizer.blocks.11.attn.qkv.bias": "model-00001-of-00004.safetensors",
353
+ "vision_backbone.featurizer.blocks.11.attn.qkv.weight": "model-00001-of-00004.safetensors",
354
+ "vision_backbone.featurizer.blocks.11.ls1.scale_factor": "model-00001-of-00004.safetensors",
355
+ "vision_backbone.featurizer.blocks.11.ls2.scale_factor": "model-00001-of-00004.safetensors",
356
+ "vision_backbone.featurizer.blocks.11.mlp.fc1.bias": "model-00001-of-00004.safetensors",
357
+ "vision_backbone.featurizer.blocks.11.mlp.fc1.weight": "model-00001-of-00004.safetensors",
358
+ "vision_backbone.featurizer.blocks.11.mlp.fc2.bias": "model-00001-of-00004.safetensors",
359
+ "vision_backbone.featurizer.blocks.11.mlp.fc2.weight": "model-00001-of-00004.safetensors",
360
+ "vision_backbone.featurizer.blocks.11.norm1.bias": "model-00001-of-00004.safetensors",
361
+ "vision_backbone.featurizer.blocks.11.norm1.weight": "model-00001-of-00004.safetensors",
362
+ "vision_backbone.featurizer.blocks.11.norm2.bias": "model-00001-of-00004.safetensors",
363
+ "vision_backbone.featurizer.blocks.11.norm2.weight": "model-00001-of-00004.safetensors",
364
+ "vision_backbone.featurizer.blocks.12.attn.proj.bias": "model-00001-of-00004.safetensors",
365
+ "vision_backbone.featurizer.blocks.12.attn.proj.weight": "model-00001-of-00004.safetensors",
366
+ "vision_backbone.featurizer.blocks.12.attn.qkv.bias": "model-00001-of-00004.safetensors",
367
+ "vision_backbone.featurizer.blocks.12.attn.qkv.weight": "model-00001-of-00004.safetensors",
368
+ "vision_backbone.featurizer.blocks.12.ls1.scale_factor": "model-00001-of-00004.safetensors",
369
+ "vision_backbone.featurizer.blocks.12.ls2.scale_factor": "model-00001-of-00004.safetensors",
370
+ "vision_backbone.featurizer.blocks.12.mlp.fc1.bias": "model-00001-of-00004.safetensors",
371
+ "vision_backbone.featurizer.blocks.12.mlp.fc1.weight": "model-00001-of-00004.safetensors",
372
+ "vision_backbone.featurizer.blocks.12.mlp.fc2.bias": "model-00001-of-00004.safetensors",
373
+ "vision_backbone.featurizer.blocks.12.mlp.fc2.weight": "model-00001-of-00004.safetensors",
374
+ "vision_backbone.featurizer.blocks.12.norm1.bias": "model-00001-of-00004.safetensors",
375
+ "vision_backbone.featurizer.blocks.12.norm1.weight": "model-00001-of-00004.safetensors",
376
+ "vision_backbone.featurizer.blocks.12.norm2.bias": "model-00001-of-00004.safetensors",
377
+ "vision_backbone.featurizer.blocks.12.norm2.weight": "model-00001-of-00004.safetensors",
378
+ "vision_backbone.featurizer.blocks.13.attn.proj.bias": "model-00001-of-00004.safetensors",
379
+ "vision_backbone.featurizer.blocks.13.attn.proj.weight": "model-00001-of-00004.safetensors",
380
+ "vision_backbone.featurizer.blocks.13.attn.qkv.bias": "model-00001-of-00004.safetensors",
381
+ "vision_backbone.featurizer.blocks.13.attn.qkv.weight": "model-00001-of-00004.safetensors",
382
+ "vision_backbone.featurizer.blocks.13.ls1.scale_factor": "model-00001-of-00004.safetensors",
383
+ "vision_backbone.featurizer.blocks.13.ls2.scale_factor": "model-00001-of-00004.safetensors",
384
+ "vision_backbone.featurizer.blocks.13.mlp.fc1.bias": "model-00001-of-00004.safetensors",
385
+ "vision_backbone.featurizer.blocks.13.mlp.fc1.weight": "model-00001-of-00004.safetensors",
386
+ "vision_backbone.featurizer.blocks.13.mlp.fc2.bias": "model-00001-of-00004.safetensors",
387
+ "vision_backbone.featurizer.blocks.13.mlp.fc2.weight": "model-00001-of-00004.safetensors",
388
+ "vision_backbone.featurizer.blocks.13.norm1.bias": "model-00001-of-00004.safetensors",
389
+ "vision_backbone.featurizer.blocks.13.norm1.weight": "model-00001-of-00004.safetensors",
390
+ "vision_backbone.featurizer.blocks.13.norm2.bias": "model-00001-of-00004.safetensors",
391
+ "vision_backbone.featurizer.blocks.13.norm2.weight": "model-00001-of-00004.safetensors",
392
+ "vision_backbone.featurizer.blocks.14.attn.proj.bias": "model-00001-of-00004.safetensors",
393
+ "vision_backbone.featurizer.blocks.14.attn.proj.weight": "model-00001-of-00004.safetensors",
394
+ "vision_backbone.featurizer.blocks.14.attn.qkv.bias": "model-00001-of-00004.safetensors",
395
+ "vision_backbone.featurizer.blocks.14.attn.qkv.weight": "model-00001-of-00004.safetensors",
396
+ "vision_backbone.featurizer.blocks.14.ls1.scale_factor": "model-00001-of-00004.safetensors",
397
+ "vision_backbone.featurizer.blocks.14.ls2.scale_factor": "model-00001-of-00004.safetensors",
398
+ "vision_backbone.featurizer.blocks.14.mlp.fc1.bias": "model-00001-of-00004.safetensors",
399
+ "vision_backbone.featurizer.blocks.14.mlp.fc1.weight": "model-00001-of-00004.safetensors",
400
+ "vision_backbone.featurizer.blocks.14.mlp.fc2.bias": "model-00001-of-00004.safetensors",
401
+ "vision_backbone.featurizer.blocks.14.mlp.fc2.weight": "model-00001-of-00004.safetensors",
402
+ "vision_backbone.featurizer.blocks.14.norm1.bias": "model-00001-of-00004.safetensors",
403
+ "vision_backbone.featurizer.blocks.14.norm1.weight": "model-00001-of-00004.safetensors",
404
+ "vision_backbone.featurizer.blocks.14.norm2.bias": "model-00001-of-00004.safetensors",
405
+ "vision_backbone.featurizer.blocks.14.norm2.weight": "model-00001-of-00004.safetensors",
406
+ "vision_backbone.featurizer.blocks.15.attn.proj.bias": "model-00001-of-00004.safetensors",
407
+ "vision_backbone.featurizer.blocks.15.attn.proj.weight": "model-00001-of-00004.safetensors",
408
+ "vision_backbone.featurizer.blocks.15.attn.qkv.bias": "model-00001-of-00004.safetensors",
409
+ "vision_backbone.featurizer.blocks.15.attn.qkv.weight": "model-00001-of-00004.safetensors",
410
+ "vision_backbone.featurizer.blocks.15.ls1.scale_factor": "model-00001-of-00004.safetensors",
411
+ "vision_backbone.featurizer.blocks.15.ls2.scale_factor": "model-00001-of-00004.safetensors",
412
+ "vision_backbone.featurizer.blocks.15.mlp.fc1.bias": "model-00001-of-00004.safetensors",
413
+ "vision_backbone.featurizer.blocks.15.mlp.fc1.weight": "model-00001-of-00004.safetensors",
414
+ "vision_backbone.featurizer.blocks.15.mlp.fc2.bias": "model-00001-of-00004.safetensors",
415
+ "vision_backbone.featurizer.blocks.15.mlp.fc2.weight": "model-00001-of-00004.safetensors",
416
+ "vision_backbone.featurizer.blocks.15.norm1.bias": "model-00001-of-00004.safetensors",
417
+ "vision_backbone.featurizer.blocks.15.norm1.weight": "model-00001-of-00004.safetensors",
418
+ "vision_backbone.featurizer.blocks.15.norm2.bias": "model-00001-of-00004.safetensors",
419
+ "vision_backbone.featurizer.blocks.15.norm2.weight": "model-00001-of-00004.safetensors",
420
+ "vision_backbone.featurizer.blocks.16.attn.proj.bias": "model-00001-of-00004.safetensors",
421
+ "vision_backbone.featurizer.blocks.16.attn.proj.weight": "model-00001-of-00004.safetensors",
422
+ "vision_backbone.featurizer.blocks.16.attn.qkv.bias": "model-00001-of-00004.safetensors",
423
+ "vision_backbone.featurizer.blocks.16.attn.qkv.weight": "model-00001-of-00004.safetensors",
424
+ "vision_backbone.featurizer.blocks.16.ls1.scale_factor": "model-00001-of-00004.safetensors",
425
+ "vision_backbone.featurizer.blocks.16.ls2.scale_factor": "model-00001-of-00004.safetensors",
426
+ "vision_backbone.featurizer.blocks.16.mlp.fc1.bias": "model-00001-of-00004.safetensors",
427
+ "vision_backbone.featurizer.blocks.16.mlp.fc1.weight": "model-00001-of-00004.safetensors",
428
+ "vision_backbone.featurizer.blocks.16.mlp.fc2.bias": "model-00001-of-00004.safetensors",
429
+ "vision_backbone.featurizer.blocks.16.mlp.fc2.weight": "model-00001-of-00004.safetensors",
430
+ "vision_backbone.featurizer.blocks.16.norm1.bias": "model-00001-of-00004.safetensors",
431
+ "vision_backbone.featurizer.blocks.16.norm1.weight": "model-00001-of-00004.safetensors",
432
+ "vision_backbone.featurizer.blocks.16.norm2.bias": "model-00001-of-00004.safetensors",
433
+ "vision_backbone.featurizer.blocks.16.norm2.weight": "model-00001-of-00004.safetensors",
434
+ "vision_backbone.featurizer.blocks.17.attn.proj.bias": "model-00001-of-00004.safetensors",
435
+ "vision_backbone.featurizer.blocks.17.attn.proj.weight": "model-00001-of-00004.safetensors",
436
+ "vision_backbone.featurizer.blocks.17.attn.qkv.bias": "model-00001-of-00004.safetensors",
437
+ "vision_backbone.featurizer.blocks.17.attn.qkv.weight": "model-00001-of-00004.safetensors",
438
+ "vision_backbone.featurizer.blocks.17.ls1.scale_factor": "model-00001-of-00004.safetensors",
439
+ "vision_backbone.featurizer.blocks.17.ls2.scale_factor": "model-00001-of-00004.safetensors",
440
+ "vision_backbone.featurizer.blocks.17.mlp.fc1.bias": "model-00001-of-00004.safetensors",
441
+ "vision_backbone.featurizer.blocks.17.mlp.fc1.weight": "model-00001-of-00004.safetensors",
442
+ "vision_backbone.featurizer.blocks.17.mlp.fc2.bias": "model-00001-of-00004.safetensors",
443
+ "vision_backbone.featurizer.blocks.17.mlp.fc2.weight": "model-00001-of-00004.safetensors",
444
+ "vision_backbone.featurizer.blocks.17.norm1.bias": "model-00001-of-00004.safetensors",
445
+ "vision_backbone.featurizer.blocks.17.norm1.weight": "model-00001-of-00004.safetensors",
446
+ "vision_backbone.featurizer.blocks.17.norm2.bias": "model-00001-of-00004.safetensors",
447
+ "vision_backbone.featurizer.blocks.17.norm2.weight": "model-00001-of-00004.safetensors",
448
+ "vision_backbone.featurizer.blocks.18.attn.proj.bias": "model-00001-of-00004.safetensors",
449
+ "vision_backbone.featurizer.blocks.18.attn.proj.weight": "model-00001-of-00004.safetensors",
450
+ "vision_backbone.featurizer.blocks.18.attn.qkv.bias": "model-00001-of-00004.safetensors",
451
+ "vision_backbone.featurizer.blocks.18.attn.qkv.weight": "model-00001-of-00004.safetensors",
452
+ "vision_backbone.featurizer.blocks.18.ls1.scale_factor": "model-00001-of-00004.safetensors",
453
+ "vision_backbone.featurizer.blocks.18.ls2.scale_factor": "model-00001-of-00004.safetensors",
454
+ "vision_backbone.featurizer.blocks.18.mlp.fc1.bias": "model-00001-of-00004.safetensors",
455
+ "vision_backbone.featurizer.blocks.18.mlp.fc1.weight": "model-00001-of-00004.safetensors",
456
+ "vision_backbone.featurizer.blocks.18.mlp.fc2.bias": "model-00001-of-00004.safetensors",
457
+ "vision_backbone.featurizer.blocks.18.mlp.fc2.weight": "model-00001-of-00004.safetensors",
458
+ "vision_backbone.featurizer.blocks.18.norm1.bias": "model-00001-of-00004.safetensors",
459
+ "vision_backbone.featurizer.blocks.18.norm1.weight": "model-00001-of-00004.safetensors",
460
+ "vision_backbone.featurizer.blocks.18.norm2.bias": "model-00001-of-00004.safetensors",
461
+ "vision_backbone.featurizer.blocks.18.norm2.weight": "model-00001-of-00004.safetensors",
462
+ "vision_backbone.featurizer.blocks.19.attn.proj.bias": "model-00001-of-00004.safetensors",
463
+ "vision_backbone.featurizer.blocks.19.attn.proj.weight": "model-00001-of-00004.safetensors",
464
+ "vision_backbone.featurizer.blocks.19.attn.qkv.bias": "model-00001-of-00004.safetensors",
465
+ "vision_backbone.featurizer.blocks.19.attn.qkv.weight": "model-00001-of-00004.safetensors",
466
+ "vision_backbone.featurizer.blocks.19.ls1.scale_factor": "model-00001-of-00004.safetensors",
467
+ "vision_backbone.featurizer.blocks.19.ls2.scale_factor": "model-00001-of-00004.safetensors",
468
+ "vision_backbone.featurizer.blocks.19.mlp.fc1.bias": "model-00001-of-00004.safetensors",
469
+ "vision_backbone.featurizer.blocks.19.mlp.fc1.weight": "model-00001-of-00004.safetensors",
470
+ "vision_backbone.featurizer.blocks.19.mlp.fc2.bias": "model-00001-of-00004.safetensors",
471
+ "vision_backbone.featurizer.blocks.19.mlp.fc2.weight": "model-00001-of-00004.safetensors",
472
+ "vision_backbone.featurizer.blocks.19.norm1.bias": "model-00001-of-00004.safetensors",
473
+ "vision_backbone.featurizer.blocks.19.norm1.weight": "model-00001-of-00004.safetensors",
474
+ "vision_backbone.featurizer.blocks.19.norm2.bias": "model-00001-of-00004.safetensors",
475
+ "vision_backbone.featurizer.blocks.19.norm2.weight": "model-00001-of-00004.safetensors",
476
+ "vision_backbone.featurizer.blocks.2.attn.proj.bias": "model-00001-of-00004.safetensors",
477
+ "vision_backbone.featurizer.blocks.2.attn.proj.weight": "model-00001-of-00004.safetensors",
478
+ "vision_backbone.featurizer.blocks.2.attn.qkv.bias": "model-00001-of-00004.safetensors",
479
+ "vision_backbone.featurizer.blocks.2.attn.qkv.weight": "model-00001-of-00004.safetensors",
480
+ "vision_backbone.featurizer.blocks.2.ls1.scale_factor": "model-00001-of-00004.safetensors",
481
+ "vision_backbone.featurizer.blocks.2.ls2.scale_factor": "model-00001-of-00004.safetensors",
482
+ "vision_backbone.featurizer.blocks.2.mlp.fc1.bias": "model-00001-of-00004.safetensors",
483
+ "vision_backbone.featurizer.blocks.2.mlp.fc1.weight": "model-00001-of-00004.safetensors",
484
+ "vision_backbone.featurizer.blocks.2.mlp.fc2.bias": "model-00001-of-00004.safetensors",
485
+ "vision_backbone.featurizer.blocks.2.mlp.fc2.weight": "model-00001-of-00004.safetensors",
486
+ "vision_backbone.featurizer.blocks.2.norm1.bias": "model-00001-of-00004.safetensors",
487
+ "vision_backbone.featurizer.blocks.2.norm1.weight": "model-00001-of-00004.safetensors",
488
+ "vision_backbone.featurizer.blocks.2.norm2.bias": "model-00001-of-00004.safetensors",
489
+ "vision_backbone.featurizer.blocks.2.norm2.weight": "model-00001-of-00004.safetensors",
490
+ "vision_backbone.featurizer.blocks.20.attn.proj.bias": "model-00001-of-00004.safetensors",
491
+ "vision_backbone.featurizer.blocks.20.attn.proj.weight": "model-00001-of-00004.safetensors",
492
+ "vision_backbone.featurizer.blocks.20.attn.qkv.bias": "model-00001-of-00004.safetensors",
493
+ "vision_backbone.featurizer.blocks.20.attn.qkv.weight": "model-00001-of-00004.safetensors",
494
+ "vision_backbone.featurizer.blocks.20.ls1.scale_factor": "model-00001-of-00004.safetensors",
495
+ "vision_backbone.featurizer.blocks.20.ls2.scale_factor": "model-00001-of-00004.safetensors",
496
+ "vision_backbone.featurizer.blocks.20.mlp.fc1.bias": "model-00001-of-00004.safetensors",
497
+ "vision_backbone.featurizer.blocks.20.mlp.fc1.weight": "model-00001-of-00004.safetensors",
498
+ "vision_backbone.featurizer.blocks.20.mlp.fc2.bias": "model-00001-of-00004.safetensors",
499
+ "vision_backbone.featurizer.blocks.20.mlp.fc2.weight": "model-00001-of-00004.safetensors",
500
+ "vision_backbone.featurizer.blocks.20.norm1.bias": "model-00001-of-00004.safetensors",
501
+ "vision_backbone.featurizer.blocks.20.norm1.weight": "model-00001-of-00004.safetensors",
502
+ "vision_backbone.featurizer.blocks.20.norm2.bias": "model-00001-of-00004.safetensors",
503
+ "vision_backbone.featurizer.blocks.20.norm2.weight": "model-00001-of-00004.safetensors",
504
+ "vision_backbone.featurizer.blocks.21.attn.proj.bias": "model-00001-of-00004.safetensors",
505
+ "vision_backbone.featurizer.blocks.21.attn.proj.weight": "model-00001-of-00004.safetensors",
506
+ "vision_backbone.featurizer.blocks.21.attn.qkv.bias": "model-00001-of-00004.safetensors",
507
+ "vision_backbone.featurizer.blocks.21.attn.qkv.weight": "model-00001-of-00004.safetensors",
508
+ "vision_backbone.featurizer.blocks.21.ls1.scale_factor": "model-00001-of-00004.safetensors",
509
+ "vision_backbone.featurizer.blocks.21.ls2.scale_factor": "model-00001-of-00004.safetensors",
510
+ "vision_backbone.featurizer.blocks.21.mlp.fc1.bias": "model-00001-of-00004.safetensors",
511
+ "vision_backbone.featurizer.blocks.21.mlp.fc1.weight": "model-00001-of-00004.safetensors",
512
+ "vision_backbone.featurizer.blocks.21.mlp.fc2.bias": "model-00001-of-00004.safetensors",
513
+ "vision_backbone.featurizer.blocks.21.mlp.fc2.weight": "model-00001-of-00004.safetensors",
514
+ "vision_backbone.featurizer.blocks.21.norm1.bias": "model-00001-of-00004.safetensors",
515
+ "vision_backbone.featurizer.blocks.21.norm1.weight": "model-00001-of-00004.safetensors",
516
+ "vision_backbone.featurizer.blocks.21.norm2.bias": "model-00001-of-00004.safetensors",
517
+ "vision_backbone.featurizer.blocks.21.norm2.weight": "model-00001-of-00004.safetensors",
518
+ "vision_backbone.featurizer.blocks.22.attn.proj.bias": "model-00001-of-00004.safetensors",
519
+ "vision_backbone.featurizer.blocks.22.attn.proj.weight": "model-00001-of-00004.safetensors",
520
+ "vision_backbone.featurizer.blocks.22.attn.qkv.bias": "model-00001-of-00004.safetensors",
521
+ "vision_backbone.featurizer.blocks.22.attn.qkv.weight": "model-00001-of-00004.safetensors",
522
+ "vision_backbone.featurizer.blocks.22.ls1.scale_factor": "model-00001-of-00004.safetensors",
523
+ "vision_backbone.featurizer.blocks.22.ls2.scale_factor": "model-00001-of-00004.safetensors",
524
+ "vision_backbone.featurizer.blocks.22.mlp.fc1.bias": "model-00001-of-00004.safetensors",
525
+ "vision_backbone.featurizer.blocks.22.mlp.fc1.weight": "model-00001-of-00004.safetensors",
526
+ "vision_backbone.featurizer.blocks.22.mlp.fc2.bias": "model-00001-of-00004.safetensors",
527
+ "vision_backbone.featurizer.blocks.22.mlp.fc2.weight": "model-00001-of-00004.safetensors",
528
+ "vision_backbone.featurizer.blocks.22.norm1.bias": "model-00001-of-00004.safetensors",
529
+ "vision_backbone.featurizer.blocks.22.norm1.weight": "model-00001-of-00004.safetensors",
530
+ "vision_backbone.featurizer.blocks.22.norm2.bias": "model-00001-of-00004.safetensors",
531
+ "vision_backbone.featurizer.blocks.22.norm2.weight": "model-00001-of-00004.safetensors",
532
+ "vision_backbone.featurizer.blocks.23.attn.proj.bias": "model-00001-of-00004.safetensors",
533
+ "vision_backbone.featurizer.blocks.23.attn.proj.weight": "model-00001-of-00004.safetensors",
534
+ "vision_backbone.featurizer.blocks.23.attn.qkv.bias": "model-00001-of-00004.safetensors",
535
+ "vision_backbone.featurizer.blocks.23.attn.qkv.weight": "model-00001-of-00004.safetensors",
536
+ "vision_backbone.featurizer.blocks.23.ls1.scale_factor": "model-00001-of-00004.safetensors",
537
+ "vision_backbone.featurizer.blocks.23.ls2.scale_factor": "model-00001-of-00004.safetensors",
538
+ "vision_backbone.featurizer.blocks.23.mlp.fc1.bias": "model-00001-of-00004.safetensors",
539
+ "vision_backbone.featurizer.blocks.23.mlp.fc1.weight": "model-00001-of-00004.safetensors",
540
+ "vision_backbone.featurizer.blocks.23.mlp.fc2.bias": "model-00001-of-00004.safetensors",
541
+ "vision_backbone.featurizer.blocks.23.mlp.fc2.weight": "model-00001-of-00004.safetensors",
542
+ "vision_backbone.featurizer.blocks.23.norm1.bias": "model-00001-of-00004.safetensors",
543
+ "vision_backbone.featurizer.blocks.23.norm1.weight": "model-00001-of-00004.safetensors",
544
+ "vision_backbone.featurizer.blocks.23.norm2.bias": "model-00001-of-00004.safetensors",
545
+ "vision_backbone.featurizer.blocks.23.norm2.weight": "model-00001-of-00004.safetensors",
546
+ "vision_backbone.featurizer.blocks.3.attn.proj.bias": "model-00001-of-00004.safetensors",
547
+ "vision_backbone.featurizer.blocks.3.attn.proj.weight": "model-00001-of-00004.safetensors",
548
+ "vision_backbone.featurizer.blocks.3.attn.qkv.bias": "model-00001-of-00004.safetensors",
549
+ "vision_backbone.featurizer.blocks.3.attn.qkv.weight": "model-00001-of-00004.safetensors",
550
+ "vision_backbone.featurizer.blocks.3.ls1.scale_factor": "model-00001-of-00004.safetensors",
551
+ "vision_backbone.featurizer.blocks.3.ls2.scale_factor": "model-00001-of-00004.safetensors",
552
+ "vision_backbone.featurizer.blocks.3.mlp.fc1.bias": "model-00001-of-00004.safetensors",
553
+ "vision_backbone.featurizer.blocks.3.mlp.fc1.weight": "model-00001-of-00004.safetensors",
554
+ "vision_backbone.featurizer.blocks.3.mlp.fc2.bias": "model-00001-of-00004.safetensors",
555
+ "vision_backbone.featurizer.blocks.3.mlp.fc2.weight": "model-00001-of-00004.safetensors",
556
+ "vision_backbone.featurizer.blocks.3.norm1.bias": "model-00001-of-00004.safetensors",
557
+ "vision_backbone.featurizer.blocks.3.norm1.weight": "model-00001-of-00004.safetensors",
558
+ "vision_backbone.featurizer.blocks.3.norm2.bias": "model-00001-of-00004.safetensors",
559
+ "vision_backbone.featurizer.blocks.3.norm2.weight": "model-00001-of-00004.safetensors",
560
+ "vision_backbone.featurizer.blocks.4.attn.proj.bias": "model-00001-of-00004.safetensors",
561
+ "vision_backbone.featurizer.blocks.4.attn.proj.weight": "model-00001-of-00004.safetensors",
562
+ "vision_backbone.featurizer.blocks.4.attn.qkv.bias": "model-00001-of-00004.safetensors",
563
+ "vision_backbone.featurizer.blocks.4.attn.qkv.weight": "model-00001-of-00004.safetensors",
564
+ "vision_backbone.featurizer.blocks.4.ls1.scale_factor": "model-00001-of-00004.safetensors",
565
+ "vision_backbone.featurizer.blocks.4.ls2.scale_factor": "model-00001-of-00004.safetensors",
566
+ "vision_backbone.featurizer.blocks.4.mlp.fc1.bias": "model-00001-of-00004.safetensors",
567
+ "vision_backbone.featurizer.blocks.4.mlp.fc1.weight": "model-00001-of-00004.safetensors",
568
+ "vision_backbone.featurizer.blocks.4.mlp.fc2.bias": "model-00001-of-00004.safetensors",
569
+ "vision_backbone.featurizer.blocks.4.mlp.fc2.weight": "model-00001-of-00004.safetensors",
570
+ "vision_backbone.featurizer.blocks.4.norm1.bias": "model-00001-of-00004.safetensors",
571
+ "vision_backbone.featurizer.blocks.4.norm1.weight": "model-00001-of-00004.safetensors",
572
+ "vision_backbone.featurizer.blocks.4.norm2.bias": "model-00001-of-00004.safetensors",
573
+ "vision_backbone.featurizer.blocks.4.norm2.weight": "model-00001-of-00004.safetensors",
574
+ "vision_backbone.featurizer.blocks.5.attn.proj.bias": "model-00001-of-00004.safetensors",
575
+ "vision_backbone.featurizer.blocks.5.attn.proj.weight": "model-00001-of-00004.safetensors",
576
+ "vision_backbone.featurizer.blocks.5.attn.qkv.bias": "model-00001-of-00004.safetensors",
577
+ "vision_backbone.featurizer.blocks.5.attn.qkv.weight": "model-00001-of-00004.safetensors",
578
+ "vision_backbone.featurizer.blocks.5.ls1.scale_factor": "model-00001-of-00004.safetensors",
579
+ "vision_backbone.featurizer.blocks.5.ls2.scale_factor": "model-00001-of-00004.safetensors",
580
+ "vision_backbone.featurizer.blocks.5.mlp.fc1.bias": "model-00001-of-00004.safetensors",
581
+ "vision_backbone.featurizer.blocks.5.mlp.fc1.weight": "model-00001-of-00004.safetensors",
582
+ "vision_backbone.featurizer.blocks.5.mlp.fc2.bias": "model-00001-of-00004.safetensors",
583
+ "vision_backbone.featurizer.blocks.5.mlp.fc2.weight": "model-00001-of-00004.safetensors",
584
+ "vision_backbone.featurizer.blocks.5.norm1.bias": "model-00001-of-00004.safetensors",
585
+ "vision_backbone.featurizer.blocks.5.norm1.weight": "model-00001-of-00004.safetensors",
586
+ "vision_backbone.featurizer.blocks.5.norm2.bias": "model-00001-of-00004.safetensors",
587
+ "vision_backbone.featurizer.blocks.5.norm2.weight": "model-00001-of-00004.safetensors",
588
+ "vision_backbone.featurizer.blocks.6.attn.proj.bias": "model-00001-of-00004.safetensors",
589
+ "vision_backbone.featurizer.blocks.6.attn.proj.weight": "model-00001-of-00004.safetensors",
590
+ "vision_backbone.featurizer.blocks.6.attn.qkv.bias": "model-00001-of-00004.safetensors",
591
+ "vision_backbone.featurizer.blocks.6.attn.qkv.weight": "model-00001-of-00004.safetensors",
592
+ "vision_backbone.featurizer.blocks.6.ls1.scale_factor": "model-00001-of-00004.safetensors",
593
+ "vision_backbone.featurizer.blocks.6.ls2.scale_factor": "model-00001-of-00004.safetensors",
594
+ "vision_backbone.featurizer.blocks.6.mlp.fc1.bias": "model-00001-of-00004.safetensors",
595
+ "vision_backbone.featurizer.blocks.6.mlp.fc1.weight": "model-00001-of-00004.safetensors",
596
+ "vision_backbone.featurizer.blocks.6.mlp.fc2.bias": "model-00001-of-00004.safetensors",
597
+ "vision_backbone.featurizer.blocks.6.mlp.fc2.weight": "model-00001-of-00004.safetensors",
598
+ "vision_backbone.featurizer.blocks.6.norm1.bias": "model-00001-of-00004.safetensors",
599
+ "vision_backbone.featurizer.blocks.6.norm1.weight": "model-00001-of-00004.safetensors",
600
+ "vision_backbone.featurizer.blocks.6.norm2.bias": "model-00001-of-00004.safetensors",
601
+ "vision_backbone.featurizer.blocks.6.norm2.weight": "model-00001-of-00004.safetensors",
602
+ "vision_backbone.featurizer.blocks.7.attn.proj.bias": "model-00001-of-00004.safetensors",
603
+ "vision_backbone.featurizer.blocks.7.attn.proj.weight": "model-00001-of-00004.safetensors",
604
+ "vision_backbone.featurizer.blocks.7.attn.qkv.bias": "model-00001-of-00004.safetensors",
605
+ "vision_backbone.featurizer.blocks.7.attn.qkv.weight": "model-00001-of-00004.safetensors",
606
+ "vision_backbone.featurizer.blocks.7.ls1.scale_factor": "model-00001-of-00004.safetensors",
607
+ "vision_backbone.featurizer.blocks.7.ls2.scale_factor": "model-00001-of-00004.safetensors",
608
+ "vision_backbone.featurizer.blocks.7.mlp.fc1.bias": "model-00001-of-00004.safetensors",
609
+ "vision_backbone.featurizer.blocks.7.mlp.fc1.weight": "model-00001-of-00004.safetensors",
610
+ "vision_backbone.featurizer.blocks.7.mlp.fc2.bias": "model-00001-of-00004.safetensors",
611
+ "vision_backbone.featurizer.blocks.7.mlp.fc2.weight": "model-00001-of-00004.safetensors",
612
+ "vision_backbone.featurizer.blocks.7.norm1.bias": "model-00001-of-00004.safetensors",
613
+ "vision_backbone.featurizer.blocks.7.norm1.weight": "model-00001-of-00004.safetensors",
614
+ "vision_backbone.featurizer.blocks.7.norm2.bias": "model-00001-of-00004.safetensors",
615
+ "vision_backbone.featurizer.blocks.7.norm2.weight": "model-00001-of-00004.safetensors",
616
+ "vision_backbone.featurizer.blocks.8.attn.proj.bias": "model-00001-of-00004.safetensors",
617
+ "vision_backbone.featurizer.blocks.8.attn.proj.weight": "model-00001-of-00004.safetensors",
618
+ "vision_backbone.featurizer.blocks.8.attn.qkv.bias": "model-00001-of-00004.safetensors",
619
+ "vision_backbone.featurizer.blocks.8.attn.qkv.weight": "model-00001-of-00004.safetensors",
620
+ "vision_backbone.featurizer.blocks.8.ls1.scale_factor": "model-00001-of-00004.safetensors",
621
+ "vision_backbone.featurizer.blocks.8.ls2.scale_factor": "model-00001-of-00004.safetensors",
622
+ "vision_backbone.featurizer.blocks.8.mlp.fc1.bias": "model-00001-of-00004.safetensors",
623
+ "vision_backbone.featurizer.blocks.8.mlp.fc1.weight": "model-00001-of-00004.safetensors",
624
+ "vision_backbone.featurizer.blocks.8.mlp.fc2.bias": "model-00001-of-00004.safetensors",
625
+ "vision_backbone.featurizer.blocks.8.mlp.fc2.weight": "model-00001-of-00004.safetensors",
626
+ "vision_backbone.featurizer.blocks.8.norm1.bias": "model-00001-of-00004.safetensors",
627
+ "vision_backbone.featurizer.blocks.8.norm1.weight": "model-00001-of-00004.safetensors",
628
+ "vision_backbone.featurizer.blocks.8.norm2.bias": "model-00001-of-00004.safetensors",
629
+ "vision_backbone.featurizer.blocks.8.norm2.weight": "model-00001-of-00004.safetensors",
630
+ "vision_backbone.featurizer.blocks.9.attn.proj.bias": "model-00001-of-00004.safetensors",
631
+ "vision_backbone.featurizer.blocks.9.attn.proj.weight": "model-00001-of-00004.safetensors",
632
+ "vision_backbone.featurizer.blocks.9.attn.qkv.bias": "model-00001-of-00004.safetensors",
633
+ "vision_backbone.featurizer.blocks.9.attn.qkv.weight": "model-00001-of-00004.safetensors",
634
+ "vision_backbone.featurizer.blocks.9.ls1.scale_factor": "model-00001-of-00004.safetensors",
635
+ "vision_backbone.featurizer.blocks.9.ls2.scale_factor": "model-00001-of-00004.safetensors",
636
+ "vision_backbone.featurizer.blocks.9.mlp.fc1.bias": "model-00001-of-00004.safetensors",
637
+ "vision_backbone.featurizer.blocks.9.mlp.fc1.weight": "model-00001-of-00004.safetensors",
638
+ "vision_backbone.featurizer.blocks.9.mlp.fc2.bias": "model-00001-of-00004.safetensors",
639
+ "vision_backbone.featurizer.blocks.9.mlp.fc2.weight": "model-00001-of-00004.safetensors",
640
+ "vision_backbone.featurizer.blocks.9.norm1.bias": "model-00001-of-00004.safetensors",
641
+ "vision_backbone.featurizer.blocks.9.norm1.weight": "model-00001-of-00004.safetensors",
642
+ "vision_backbone.featurizer.blocks.9.norm2.bias": "model-00001-of-00004.safetensors",
643
+ "vision_backbone.featurizer.blocks.9.norm2.weight": "model-00001-of-00004.safetensors",
644
+ "vision_backbone.featurizer.cls_token": "model-00001-of-00004.safetensors",
645
+ "vision_backbone.featurizer.norm.bias": "model-00001-of-00004.safetensors",
646
+ "vision_backbone.featurizer.norm.weight": "model-00001-of-00004.safetensors",
647
+ "vision_backbone.featurizer.patch_embed.proj.bias": "model-00001-of-00004.safetensors",
648
+ "vision_backbone.featurizer.patch_embed.proj.weight": "model-00001-of-00004.safetensors",
649
+ "vision_backbone.featurizer.pos_embed": "model-00001-of-00004.safetensors",
650
+ "vision_backbone.featurizer.reg_token": "model-00001-of-00004.safetensors",
651
+ "vision_backbone.fused_featurizer.attn_pool.kv.bias": "model-00001-of-00004.safetensors",
652
+ "vision_backbone.fused_featurizer.attn_pool.kv.weight": "model-00001-of-00004.safetensors",
653
+ "vision_backbone.fused_featurizer.attn_pool.latent": "model-00001-of-00004.safetensors",
654
+ "vision_backbone.fused_featurizer.attn_pool.mlp.fc1.bias": "model-00001-of-00004.safetensors",
655
+ "vision_backbone.fused_featurizer.attn_pool.mlp.fc1.weight": "model-00001-of-00004.safetensors",
656
+ "vision_backbone.fused_featurizer.attn_pool.mlp.fc2.bias": "model-00001-of-00004.safetensors",
657
+ "vision_backbone.fused_featurizer.attn_pool.mlp.fc2.weight": "model-00001-of-00004.safetensors",
658
+ "vision_backbone.fused_featurizer.attn_pool.norm.bias": "model-00001-of-00004.safetensors",
659
+ "vision_backbone.fused_featurizer.attn_pool.norm.weight": "model-00001-of-00004.safetensors",
660
+ "vision_backbone.fused_featurizer.attn_pool.proj.bias": "model-00001-of-00004.safetensors",
661
+ "vision_backbone.fused_featurizer.attn_pool.proj.weight": "model-00001-of-00004.safetensors",
662
+ "vision_backbone.fused_featurizer.attn_pool.q.bias": "model-00001-of-00004.safetensors",
663
+ "vision_backbone.fused_featurizer.attn_pool.q.weight": "model-00001-of-00004.safetensors",
664
+ "vision_backbone.fused_featurizer.blocks.0.attn.proj.bias": "model-00001-of-00004.safetensors",
665
+ "vision_backbone.fused_featurizer.blocks.0.attn.proj.weight": "model-00001-of-00004.safetensors",
666
+ "vision_backbone.fused_featurizer.blocks.0.attn.qkv.bias": "model-00001-of-00004.safetensors",
667
+ "vision_backbone.fused_featurizer.blocks.0.attn.qkv.weight": "model-00001-of-00004.safetensors",
668
+ "vision_backbone.fused_featurizer.blocks.0.mlp.fc1.bias": "model-00001-of-00004.safetensors",
669
+ "vision_backbone.fused_featurizer.blocks.0.mlp.fc1.weight": "model-00001-of-00004.safetensors",
670
+ "vision_backbone.fused_featurizer.blocks.0.mlp.fc2.bias": "model-00001-of-00004.safetensors",
671
+ "vision_backbone.fused_featurizer.blocks.0.mlp.fc2.weight": "model-00001-of-00004.safetensors",
672
+ "vision_backbone.fused_featurizer.blocks.0.norm1.bias": "model-00001-of-00004.safetensors",
673
+ "vision_backbone.fused_featurizer.blocks.0.norm1.weight": "model-00001-of-00004.safetensors",
674
+ "vision_backbone.fused_featurizer.blocks.0.norm2.bias": "model-00001-of-00004.safetensors",
675
+ "vision_backbone.fused_featurizer.blocks.0.norm2.weight": "model-00001-of-00004.safetensors",
676
+ "vision_backbone.fused_featurizer.blocks.1.attn.proj.bias": "model-00001-of-00004.safetensors",
677
+ "vision_backbone.fused_featurizer.blocks.1.attn.proj.weight": "model-00001-of-00004.safetensors",
678
+ "vision_backbone.fused_featurizer.blocks.1.attn.qkv.bias": "model-00001-of-00004.safetensors",
679
+ "vision_backbone.fused_featurizer.blocks.1.attn.qkv.weight": "model-00001-of-00004.safetensors",
680
+ "vision_backbone.fused_featurizer.blocks.1.mlp.fc1.bias": "model-00001-of-00004.safetensors",
681
+ "vision_backbone.fused_featurizer.blocks.1.mlp.fc1.weight": "model-00001-of-00004.safetensors",
682
+ "vision_backbone.fused_featurizer.blocks.1.mlp.fc2.bias": "model-00001-of-00004.safetensors",
683
+ "vision_backbone.fused_featurizer.blocks.1.mlp.fc2.weight": "model-00001-of-00004.safetensors",
684
+ "vision_backbone.fused_featurizer.blocks.1.norm1.bias": "model-00001-of-00004.safetensors",
685
+ "vision_backbone.fused_featurizer.blocks.1.norm1.weight": "model-00001-of-00004.safetensors",
686
+ "vision_backbone.fused_featurizer.blocks.1.norm2.bias": "model-00001-of-00004.safetensors",
687
+ "vision_backbone.fused_featurizer.blocks.1.norm2.weight": "model-00001-of-00004.safetensors",
688
+ "vision_backbone.fused_featurizer.blocks.10.attn.proj.bias": "model-00001-of-00004.safetensors",
689
+ "vision_backbone.fused_featurizer.blocks.10.attn.proj.weight": "model-00001-of-00004.safetensors",
690
+ "vision_backbone.fused_featurizer.blocks.10.attn.qkv.bias": "model-00001-of-00004.safetensors",
691
+ "vision_backbone.fused_featurizer.blocks.10.attn.qkv.weight": "model-00001-of-00004.safetensors",
692
+ "vision_backbone.fused_featurizer.blocks.10.mlp.fc1.bias": "model-00001-of-00004.safetensors",
693
+ "vision_backbone.fused_featurizer.blocks.10.mlp.fc1.weight": "model-00001-of-00004.safetensors",
694
+ "vision_backbone.fused_featurizer.blocks.10.mlp.fc2.bias": "model-00001-of-00004.safetensors",
695
+ "vision_backbone.fused_featurizer.blocks.10.mlp.fc2.weight": "model-00001-of-00004.safetensors",
696
+ "vision_backbone.fused_featurizer.blocks.10.norm1.bias": "model-00001-of-00004.safetensors",
697
+ "vision_backbone.fused_featurizer.blocks.10.norm1.weight": "model-00001-of-00004.safetensors",
698
+ "vision_backbone.fused_featurizer.blocks.10.norm2.bias": "model-00001-of-00004.safetensors",
699
+ "vision_backbone.fused_featurizer.blocks.10.norm2.weight": "model-00001-of-00004.safetensors",
700
+ "vision_backbone.fused_featurizer.blocks.11.attn.proj.bias": "model-00001-of-00004.safetensors",
701
+ "vision_backbone.fused_featurizer.blocks.11.attn.proj.weight": "model-00001-of-00004.safetensors",
702
+ "vision_backbone.fused_featurizer.blocks.11.attn.qkv.bias": "model-00001-of-00004.safetensors",
703
+ "vision_backbone.fused_featurizer.blocks.11.attn.qkv.weight": "model-00001-of-00004.safetensors",
704
+ "vision_backbone.fused_featurizer.blocks.11.mlp.fc1.bias": "model-00001-of-00004.safetensors",
705
+ "vision_backbone.fused_featurizer.blocks.11.mlp.fc1.weight": "model-00001-of-00004.safetensors",
706
+ "vision_backbone.fused_featurizer.blocks.11.mlp.fc2.bias": "model-00001-of-00004.safetensors",
707
+ "vision_backbone.fused_featurizer.blocks.11.mlp.fc2.weight": "model-00001-of-00004.safetensors",
708
+ "vision_backbone.fused_featurizer.blocks.11.norm1.bias": "model-00001-of-00004.safetensors",
709
+ "vision_backbone.fused_featurizer.blocks.11.norm1.weight": "model-00001-of-00004.safetensors",
710
+ "vision_backbone.fused_featurizer.blocks.11.norm2.bias": "model-00001-of-00004.safetensors",
711
+ "vision_backbone.fused_featurizer.blocks.11.norm2.weight": "model-00001-of-00004.safetensors",
712
+ "vision_backbone.fused_featurizer.blocks.12.attn.proj.bias": "model-00001-of-00004.safetensors",
713
+ "vision_backbone.fused_featurizer.blocks.12.attn.proj.weight": "model-00001-of-00004.safetensors",
714
+ "vision_backbone.fused_featurizer.blocks.12.attn.qkv.bias": "model-00001-of-00004.safetensors",
715
+ "vision_backbone.fused_featurizer.blocks.12.attn.qkv.weight": "model-00001-of-00004.safetensors",
716
+ "vision_backbone.fused_featurizer.blocks.12.mlp.fc1.bias": "model-00001-of-00004.safetensors",
717
+ "vision_backbone.fused_featurizer.blocks.12.mlp.fc1.weight": "model-00001-of-00004.safetensors",
718
+ "vision_backbone.fused_featurizer.blocks.12.mlp.fc2.bias": "model-00001-of-00004.safetensors",
719
+ "vision_backbone.fused_featurizer.blocks.12.mlp.fc2.weight": "model-00001-of-00004.safetensors",
720
+ "vision_backbone.fused_featurizer.blocks.12.norm1.bias": "model-00001-of-00004.safetensors",
721
+ "vision_backbone.fused_featurizer.blocks.12.norm1.weight": "model-00001-of-00004.safetensors",
722
+ "vision_backbone.fused_featurizer.blocks.12.norm2.bias": "model-00001-of-00004.safetensors",
723
+ "vision_backbone.fused_featurizer.blocks.12.norm2.weight": "model-00001-of-00004.safetensors",
724
+ "vision_backbone.fused_featurizer.blocks.13.attn.proj.bias": "model-00001-of-00004.safetensors",
725
+ "vision_backbone.fused_featurizer.blocks.13.attn.proj.weight": "model-00001-of-00004.safetensors",
726
+ "vision_backbone.fused_featurizer.blocks.13.attn.qkv.bias": "model-00001-of-00004.safetensors",
727
+ "vision_backbone.fused_featurizer.blocks.13.attn.qkv.weight": "model-00001-of-00004.safetensors",
728
+ "vision_backbone.fused_featurizer.blocks.13.mlp.fc1.bias": "model-00001-of-00004.safetensors",
729
+ "vision_backbone.fused_featurizer.blocks.13.mlp.fc1.weight": "model-00001-of-00004.safetensors",
730
+ "vision_backbone.fused_featurizer.blocks.13.mlp.fc2.bias": "model-00001-of-00004.safetensors",
731
+ "vision_backbone.fused_featurizer.blocks.13.mlp.fc2.weight": "model-00001-of-00004.safetensors",
732
+ "vision_backbone.fused_featurizer.blocks.13.norm1.bias": "model-00001-of-00004.safetensors",
733
+ "vision_backbone.fused_featurizer.blocks.13.norm1.weight": "model-00001-of-00004.safetensors",
734
+ "vision_backbone.fused_featurizer.blocks.13.norm2.bias": "model-00001-of-00004.safetensors",
735
+ "vision_backbone.fused_featurizer.blocks.13.norm2.weight": "model-00001-of-00004.safetensors",
736
+ "vision_backbone.fused_featurizer.blocks.14.attn.proj.bias": "model-00001-of-00004.safetensors",
737
+ "vision_backbone.fused_featurizer.blocks.14.attn.proj.weight": "model-00001-of-00004.safetensors",
738
+ "vision_backbone.fused_featurizer.blocks.14.attn.qkv.bias": "model-00001-of-00004.safetensors",
739
+ "vision_backbone.fused_featurizer.blocks.14.attn.qkv.weight": "model-00001-of-00004.safetensors",
740
+ "vision_backbone.fused_featurizer.blocks.14.mlp.fc1.bias": "model-00001-of-00004.safetensors",
741
+ "vision_backbone.fused_featurizer.blocks.14.mlp.fc1.weight": "model-00001-of-00004.safetensors",
742
+ "vision_backbone.fused_featurizer.blocks.14.mlp.fc2.bias": "model-00001-of-00004.safetensors",
743
+ "vision_backbone.fused_featurizer.blocks.14.mlp.fc2.weight": "model-00001-of-00004.safetensors",
744
+ "vision_backbone.fused_featurizer.blocks.14.norm1.bias": "model-00001-of-00004.safetensors",
745
+ "vision_backbone.fused_featurizer.blocks.14.norm1.weight": "model-00001-of-00004.safetensors",
746
+ "vision_backbone.fused_featurizer.blocks.14.norm2.bias": "model-00001-of-00004.safetensors",
747
+ "vision_backbone.fused_featurizer.blocks.14.norm2.weight": "model-00001-of-00004.safetensors",
748
+ "vision_backbone.fused_featurizer.blocks.15.attn.proj.bias": "model-00001-of-00004.safetensors",
749
+ "vision_backbone.fused_featurizer.blocks.15.attn.proj.weight": "model-00001-of-00004.safetensors",
750
+ "vision_backbone.fused_featurizer.blocks.15.attn.qkv.bias": "model-00001-of-00004.safetensors",
751
+ "vision_backbone.fused_featurizer.blocks.15.attn.qkv.weight": "model-00001-of-00004.safetensors",
752
+ "vision_backbone.fused_featurizer.blocks.15.mlp.fc1.bias": "model-00001-of-00004.safetensors",
753
+ "vision_backbone.fused_featurizer.blocks.15.mlp.fc1.weight": "model-00001-of-00004.safetensors",
754
+ "vision_backbone.fused_featurizer.blocks.15.mlp.fc2.bias": "model-00001-of-00004.safetensors",
755
+ "vision_backbone.fused_featurizer.blocks.15.mlp.fc2.weight": "model-00001-of-00004.safetensors",
756
+ "vision_backbone.fused_featurizer.blocks.15.norm1.bias": "model-00001-of-00004.safetensors",
757
+ "vision_backbone.fused_featurizer.blocks.15.norm1.weight": "model-00001-of-00004.safetensors",
758
+ "vision_backbone.fused_featurizer.blocks.15.norm2.bias": "model-00001-of-00004.safetensors",
759
+ "vision_backbone.fused_featurizer.blocks.15.norm2.weight": "model-00001-of-00004.safetensors",
760
+ "vision_backbone.fused_featurizer.blocks.16.attn.proj.bias": "model-00001-of-00004.safetensors",
761
+ "vision_backbone.fused_featurizer.blocks.16.attn.proj.weight": "model-00001-of-00004.safetensors",
762
+ "vision_backbone.fused_featurizer.blocks.16.attn.qkv.bias": "model-00001-of-00004.safetensors",
763
+ "vision_backbone.fused_featurizer.blocks.16.attn.qkv.weight": "model-00001-of-00004.safetensors",
764
+ "vision_backbone.fused_featurizer.blocks.16.mlp.fc1.bias": "model-00001-of-00004.safetensors",
765
+ "vision_backbone.fused_featurizer.blocks.16.mlp.fc1.weight": "model-00001-of-00004.safetensors",
766
+ "vision_backbone.fused_featurizer.blocks.16.mlp.fc2.bias": "model-00001-of-00004.safetensors",
767
+ "vision_backbone.fused_featurizer.blocks.16.mlp.fc2.weight": "model-00001-of-00004.safetensors",
768
+ "vision_backbone.fused_featurizer.blocks.16.norm1.bias": "model-00001-of-00004.safetensors",
769
+ "vision_backbone.fused_featurizer.blocks.16.norm1.weight": "model-00001-of-00004.safetensors",
770
+ "vision_backbone.fused_featurizer.blocks.16.norm2.bias": "model-00001-of-00004.safetensors",
771
+ "vision_backbone.fused_featurizer.blocks.16.norm2.weight": "model-00001-of-00004.safetensors",
772
+ "vision_backbone.fused_featurizer.blocks.17.attn.proj.bias": "model-00001-of-00004.safetensors",
773
+ "vision_backbone.fused_featurizer.blocks.17.attn.proj.weight": "model-00001-of-00004.safetensors",
774
+ "vision_backbone.fused_featurizer.blocks.17.attn.qkv.bias": "model-00001-of-00004.safetensors",
775
+ "vision_backbone.fused_featurizer.blocks.17.attn.qkv.weight": "model-00001-of-00004.safetensors",
776
+ "vision_backbone.fused_featurizer.blocks.17.mlp.fc1.bias": "model-00001-of-00004.safetensors",
777
+ "vision_backbone.fused_featurizer.blocks.17.mlp.fc1.weight": "model-00001-of-00004.safetensors",
778
+ "vision_backbone.fused_featurizer.blocks.17.mlp.fc2.bias": "model-00001-of-00004.safetensors",
779
+ "vision_backbone.fused_featurizer.blocks.17.mlp.fc2.weight": "model-00001-of-00004.safetensors",
780
+ "vision_backbone.fused_featurizer.blocks.17.norm1.bias": "model-00001-of-00004.safetensors",
781
+ "vision_backbone.fused_featurizer.blocks.17.norm1.weight": "model-00001-of-00004.safetensors",
782
+ "vision_backbone.fused_featurizer.blocks.17.norm2.bias": "model-00001-of-00004.safetensors",
783
+ "vision_backbone.fused_featurizer.blocks.17.norm2.weight": "model-00001-of-00004.safetensors",
784
+ "vision_backbone.fused_featurizer.blocks.18.attn.proj.bias": "model-00001-of-00004.safetensors",
785
+ "vision_backbone.fused_featurizer.blocks.18.attn.proj.weight": "model-00001-of-00004.safetensors",
786
+ "vision_backbone.fused_featurizer.blocks.18.attn.qkv.bias": "model-00001-of-00004.safetensors",
787
+ "vision_backbone.fused_featurizer.blocks.18.attn.qkv.weight": "model-00001-of-00004.safetensors",
788
+ "vision_backbone.fused_featurizer.blocks.18.mlp.fc1.bias": "model-00001-of-00004.safetensors",
789
+ "vision_backbone.fused_featurizer.blocks.18.mlp.fc1.weight": "model-00001-of-00004.safetensors",
790
+ "vision_backbone.fused_featurizer.blocks.18.mlp.fc2.bias": "model-00001-of-00004.safetensors",
791
+ "vision_backbone.fused_featurizer.blocks.18.mlp.fc2.weight": "model-00001-of-00004.safetensors",
792
+ "vision_backbone.fused_featurizer.blocks.18.norm1.bias": "model-00001-of-00004.safetensors",
793
+ "vision_backbone.fused_featurizer.blocks.18.norm1.weight": "model-00001-of-00004.safetensors",
794
+ "vision_backbone.fused_featurizer.blocks.18.norm2.bias": "model-00001-of-00004.safetensors",
795
+ "vision_backbone.fused_featurizer.blocks.18.norm2.weight": "model-00001-of-00004.safetensors",
796
+ "vision_backbone.fused_featurizer.blocks.19.attn.proj.bias": "model-00001-of-00004.safetensors",
797
+ "vision_backbone.fused_featurizer.blocks.19.attn.proj.weight": "model-00001-of-00004.safetensors",
798
+ "vision_backbone.fused_featurizer.blocks.19.attn.qkv.bias": "model-00001-of-00004.safetensors",
799
+ "vision_backbone.fused_featurizer.blocks.19.attn.qkv.weight": "model-00001-of-00004.safetensors",
800
+ "vision_backbone.fused_featurizer.blocks.19.mlp.fc1.bias": "model-00001-of-00004.safetensors",
801
+ "vision_backbone.fused_featurizer.blocks.19.mlp.fc1.weight": "model-00001-of-00004.safetensors",
802
+ "vision_backbone.fused_featurizer.blocks.19.mlp.fc2.bias": "model-00001-of-00004.safetensors",
803
+ "vision_backbone.fused_featurizer.blocks.19.mlp.fc2.weight": "model-00001-of-00004.safetensors",
804
+ "vision_backbone.fused_featurizer.blocks.19.norm1.bias": "model-00001-of-00004.safetensors",
805
+ "vision_backbone.fused_featurizer.blocks.19.norm1.weight": "model-00001-of-00004.safetensors",
806
+ "vision_backbone.fused_featurizer.blocks.19.norm2.bias": "model-00001-of-00004.safetensors",
807
+ "vision_backbone.fused_featurizer.blocks.19.norm2.weight": "model-00001-of-00004.safetensors",
808
+ "vision_backbone.fused_featurizer.blocks.2.attn.proj.bias": "model-00001-of-00004.safetensors",
809
+ "vision_backbone.fused_featurizer.blocks.2.attn.proj.weight": "model-00001-of-00004.safetensors",
810
+ "vision_backbone.fused_featurizer.blocks.2.attn.qkv.bias": "model-00001-of-00004.safetensors",
811
+ "vision_backbone.fused_featurizer.blocks.2.attn.qkv.weight": "model-00001-of-00004.safetensors",
812
+ "vision_backbone.fused_featurizer.blocks.2.mlp.fc1.bias": "model-00001-of-00004.safetensors",
813
+ "vision_backbone.fused_featurizer.blocks.2.mlp.fc1.weight": "model-00001-of-00004.safetensors",
814
+ "vision_backbone.fused_featurizer.blocks.2.mlp.fc2.bias": "model-00001-of-00004.safetensors",
815
+ "vision_backbone.fused_featurizer.blocks.2.mlp.fc2.weight": "model-00001-of-00004.safetensors",
816
+ "vision_backbone.fused_featurizer.blocks.2.norm1.bias": "model-00001-of-00004.safetensors",
817
+ "vision_backbone.fused_featurizer.blocks.2.norm1.weight": "model-00001-of-00004.safetensors",
818
+ "vision_backbone.fused_featurizer.blocks.2.norm2.bias": "model-00001-of-00004.safetensors",
819
+ "vision_backbone.fused_featurizer.blocks.2.norm2.weight": "model-00001-of-00004.safetensors",
820
+ "vision_backbone.fused_featurizer.blocks.20.attn.proj.bias": "model-00001-of-00004.safetensors",
821
+ "vision_backbone.fused_featurizer.blocks.20.attn.proj.weight": "model-00001-of-00004.safetensors",
822
+ "vision_backbone.fused_featurizer.blocks.20.attn.qkv.bias": "model-00001-of-00004.safetensors",
823
+ "vision_backbone.fused_featurizer.blocks.20.attn.qkv.weight": "model-00001-of-00004.safetensors",
824
+ "vision_backbone.fused_featurizer.blocks.20.mlp.fc1.bias": "model-00001-of-00004.safetensors",
825
+ "vision_backbone.fused_featurizer.blocks.20.mlp.fc1.weight": "model-00001-of-00004.safetensors",
826
+ "vision_backbone.fused_featurizer.blocks.20.mlp.fc2.bias": "model-00001-of-00004.safetensors",
827
+ "vision_backbone.fused_featurizer.blocks.20.mlp.fc2.weight": "model-00001-of-00004.safetensors",
828
+ "vision_backbone.fused_featurizer.blocks.20.norm1.bias": "model-00001-of-00004.safetensors",
829
+ "vision_backbone.fused_featurizer.blocks.20.norm1.weight": "model-00001-of-00004.safetensors",
830
+ "vision_backbone.fused_featurizer.blocks.20.norm2.bias": "model-00001-of-00004.safetensors",
831
+ "vision_backbone.fused_featurizer.blocks.20.norm2.weight": "model-00001-of-00004.safetensors",
832
+ "vision_backbone.fused_featurizer.blocks.21.attn.proj.bias": "model-00001-of-00004.safetensors",
833
+ "vision_backbone.fused_featurizer.blocks.21.attn.proj.weight": "model-00001-of-00004.safetensors",
834
+ "vision_backbone.fused_featurizer.blocks.21.attn.qkv.bias": "model-00001-of-00004.safetensors",
835
+ "vision_backbone.fused_featurizer.blocks.21.attn.qkv.weight": "model-00001-of-00004.safetensors",
836
+ "vision_backbone.fused_featurizer.blocks.21.mlp.fc1.bias": "model-00001-of-00004.safetensors",
837
+ "vision_backbone.fused_featurizer.blocks.21.mlp.fc1.weight": "model-00001-of-00004.safetensors",
838
+ "vision_backbone.fused_featurizer.blocks.21.mlp.fc2.bias": "model-00001-of-00004.safetensors",
839
+ "vision_backbone.fused_featurizer.blocks.21.mlp.fc2.weight": "model-00001-of-00004.safetensors",
840
+ "vision_backbone.fused_featurizer.blocks.21.norm1.bias": "model-00001-of-00004.safetensors",
841
+ "vision_backbone.fused_featurizer.blocks.21.norm1.weight": "model-00001-of-00004.safetensors",
842
+ "vision_backbone.fused_featurizer.blocks.21.norm2.bias": "model-00001-of-00004.safetensors",
843
+ "vision_backbone.fused_featurizer.blocks.21.norm2.weight": "model-00001-of-00004.safetensors",
844
+ "vision_backbone.fused_featurizer.blocks.22.attn.proj.bias": "model-00001-of-00004.safetensors",
845
+ "vision_backbone.fused_featurizer.blocks.22.attn.proj.weight": "model-00001-of-00004.safetensors",
846
+ "vision_backbone.fused_featurizer.blocks.22.attn.qkv.bias": "model-00001-of-00004.safetensors",
847
+ "vision_backbone.fused_featurizer.blocks.22.attn.qkv.weight": "model-00001-of-00004.safetensors",
848
+ "vision_backbone.fused_featurizer.blocks.22.mlp.fc1.bias": "model-00001-of-00004.safetensors",
849
+ "vision_backbone.fused_featurizer.blocks.22.mlp.fc1.weight": "model-00001-of-00004.safetensors",
850
+ "vision_backbone.fused_featurizer.blocks.22.mlp.fc2.bias": "model-00001-of-00004.safetensors",
851
+ "vision_backbone.fused_featurizer.blocks.22.mlp.fc2.weight": "model-00001-of-00004.safetensors",
852
+ "vision_backbone.fused_featurizer.blocks.22.norm1.bias": "model-00001-of-00004.safetensors",
853
+ "vision_backbone.fused_featurizer.blocks.22.norm1.weight": "model-00001-of-00004.safetensors",
854
+ "vision_backbone.fused_featurizer.blocks.22.norm2.bias": "model-00001-of-00004.safetensors",
855
+ "vision_backbone.fused_featurizer.blocks.22.norm2.weight": "model-00001-of-00004.safetensors",
856
+ "vision_backbone.fused_featurizer.blocks.23.attn.proj.bias": "model-00001-of-00004.safetensors",
857
+ "vision_backbone.fused_featurizer.blocks.23.attn.proj.weight": "model-00001-of-00004.safetensors",
858
+ "vision_backbone.fused_featurizer.blocks.23.attn.qkv.bias": "model-00001-of-00004.safetensors",
859
+ "vision_backbone.fused_featurizer.blocks.23.attn.qkv.weight": "model-00001-of-00004.safetensors",
860
+ "vision_backbone.fused_featurizer.blocks.23.mlp.fc1.bias": "model-00001-of-00004.safetensors",
861
+ "vision_backbone.fused_featurizer.blocks.23.mlp.fc1.weight": "model-00001-of-00004.safetensors",
862
+ "vision_backbone.fused_featurizer.blocks.23.mlp.fc2.bias": "model-00001-of-00004.safetensors",
863
+ "vision_backbone.fused_featurizer.blocks.23.mlp.fc2.weight": "model-00001-of-00004.safetensors",
864
+ "vision_backbone.fused_featurizer.blocks.23.norm1.bias": "model-00001-of-00004.safetensors",
865
+ "vision_backbone.fused_featurizer.blocks.23.norm1.weight": "model-00001-of-00004.safetensors",
866
+ "vision_backbone.fused_featurizer.blocks.23.norm2.bias": "model-00001-of-00004.safetensors",
867
+ "vision_backbone.fused_featurizer.blocks.23.norm2.weight": "model-00001-of-00004.safetensors",
868
+ "vision_backbone.fused_featurizer.blocks.24.attn.proj.bias": "model-00001-of-00004.safetensors",
869
+ "vision_backbone.fused_featurizer.blocks.24.attn.proj.weight": "model-00001-of-00004.safetensors",
870
+ "vision_backbone.fused_featurizer.blocks.24.attn.qkv.bias": "model-00001-of-00004.safetensors",
871
+ "vision_backbone.fused_featurizer.blocks.24.attn.qkv.weight": "model-00001-of-00004.safetensors",
872
+ "vision_backbone.fused_featurizer.blocks.24.mlp.fc1.bias": "model-00001-of-00004.safetensors",
873
+ "vision_backbone.fused_featurizer.blocks.24.mlp.fc1.weight": "model-00001-of-00004.safetensors",
874
+ "vision_backbone.fused_featurizer.blocks.24.mlp.fc2.bias": "model-00001-of-00004.safetensors",
875
+ "vision_backbone.fused_featurizer.blocks.24.mlp.fc2.weight": "model-00001-of-00004.safetensors",
876
+ "vision_backbone.fused_featurizer.blocks.24.norm1.bias": "model-00001-of-00004.safetensors",
877
+ "vision_backbone.fused_featurizer.blocks.24.norm1.weight": "model-00001-of-00004.safetensors",
878
+ "vision_backbone.fused_featurizer.blocks.24.norm2.bias": "model-00001-of-00004.safetensors",
879
+ "vision_backbone.fused_featurizer.blocks.24.norm2.weight": "model-00001-of-00004.safetensors",
880
+ "vision_backbone.fused_featurizer.blocks.25.attn.proj.bias": "model-00001-of-00004.safetensors",
881
+ "vision_backbone.fused_featurizer.blocks.25.attn.proj.weight": "model-00001-of-00004.safetensors",
882
+ "vision_backbone.fused_featurizer.blocks.25.attn.qkv.bias": "model-00001-of-00004.safetensors",
883
+ "vision_backbone.fused_featurizer.blocks.25.attn.qkv.weight": "model-00001-of-00004.safetensors",
884
+ "vision_backbone.fused_featurizer.blocks.25.mlp.fc1.bias": "model-00001-of-00004.safetensors",
885
+ "vision_backbone.fused_featurizer.blocks.25.mlp.fc1.weight": "model-00001-of-00004.safetensors",
886
+ "vision_backbone.fused_featurizer.blocks.25.mlp.fc2.bias": "model-00001-of-00004.safetensors",
887
+ "vision_backbone.fused_featurizer.blocks.25.mlp.fc2.weight": "model-00001-of-00004.safetensors",
888
+ "vision_backbone.fused_featurizer.blocks.25.norm1.bias": "model-00001-of-00004.safetensors",
889
+ "vision_backbone.fused_featurizer.blocks.25.norm1.weight": "model-00001-of-00004.safetensors",
890
+ "vision_backbone.fused_featurizer.blocks.25.norm2.bias": "model-00001-of-00004.safetensors",
891
+ "vision_backbone.fused_featurizer.blocks.25.norm2.weight": "model-00001-of-00004.safetensors",
892
+ "vision_backbone.fused_featurizer.blocks.26.attn.proj.bias": "model-00001-of-00004.safetensors",
893
+ "vision_backbone.fused_featurizer.blocks.26.attn.proj.weight": "model-00001-of-00004.safetensors",
894
+ "vision_backbone.fused_featurizer.blocks.26.attn.qkv.bias": "model-00001-of-00004.safetensors",
895
+ "vision_backbone.fused_featurizer.blocks.26.attn.qkv.weight": "model-00001-of-00004.safetensors",
896
+ "vision_backbone.fused_featurizer.blocks.26.mlp.fc1.bias": "model-00001-of-00004.safetensors",
897
+ "vision_backbone.fused_featurizer.blocks.26.mlp.fc1.weight": "model-00001-of-00004.safetensors",
898
+ "vision_backbone.fused_featurizer.blocks.26.mlp.fc2.bias": "model-00001-of-00004.safetensors",
899
+ "vision_backbone.fused_featurizer.blocks.26.mlp.fc2.weight": "model-00001-of-00004.safetensors",
900
+ "vision_backbone.fused_featurizer.blocks.26.norm1.bias": "model-00001-of-00004.safetensors",
901
+ "vision_backbone.fused_featurizer.blocks.26.norm1.weight": "model-00001-of-00004.safetensors",
902
+ "vision_backbone.fused_featurizer.blocks.26.norm2.bias": "model-00001-of-00004.safetensors",
903
+ "vision_backbone.fused_featurizer.blocks.26.norm2.weight": "model-00001-of-00004.safetensors",
904
+ "vision_backbone.fused_featurizer.blocks.3.attn.proj.bias": "model-00001-of-00004.safetensors",
905
+ "vision_backbone.fused_featurizer.blocks.3.attn.proj.weight": "model-00001-of-00004.safetensors",
906
+ "vision_backbone.fused_featurizer.blocks.3.attn.qkv.bias": "model-00001-of-00004.safetensors",
907
+ "vision_backbone.fused_featurizer.blocks.3.attn.qkv.weight": "model-00001-of-00004.safetensors",
908
+ "vision_backbone.fused_featurizer.blocks.3.mlp.fc1.bias": "model-00001-of-00004.safetensors",
909
+ "vision_backbone.fused_featurizer.blocks.3.mlp.fc1.weight": "model-00001-of-00004.safetensors",
910
+ "vision_backbone.fused_featurizer.blocks.3.mlp.fc2.bias": "model-00001-of-00004.safetensors",
911
+ "vision_backbone.fused_featurizer.blocks.3.mlp.fc2.weight": "model-00001-of-00004.safetensors",
912
+ "vision_backbone.fused_featurizer.blocks.3.norm1.bias": "model-00001-of-00004.safetensors",
913
+ "vision_backbone.fused_featurizer.blocks.3.norm1.weight": "model-00001-of-00004.safetensors",
914
+ "vision_backbone.fused_featurizer.blocks.3.norm2.bias": "model-00001-of-00004.safetensors",
915
+ "vision_backbone.fused_featurizer.blocks.3.norm2.weight": "model-00001-of-00004.safetensors",
916
+ "vision_backbone.fused_featurizer.blocks.4.attn.proj.bias": "model-00001-of-00004.safetensors",
917
+ "vision_backbone.fused_featurizer.blocks.4.attn.proj.weight": "model-00001-of-00004.safetensors",
918
+ "vision_backbone.fused_featurizer.blocks.4.attn.qkv.bias": "model-00001-of-00004.safetensors",
919
+ "vision_backbone.fused_featurizer.blocks.4.attn.qkv.weight": "model-00001-of-00004.safetensors",
920
+ "vision_backbone.fused_featurizer.blocks.4.mlp.fc1.bias": "model-00001-of-00004.safetensors",
921
+ "vision_backbone.fused_featurizer.blocks.4.mlp.fc1.weight": "model-00001-of-00004.safetensors",
922
+ "vision_backbone.fused_featurizer.blocks.4.mlp.fc2.bias": "model-00001-of-00004.safetensors",
923
+ "vision_backbone.fused_featurizer.blocks.4.mlp.fc2.weight": "model-00001-of-00004.safetensors",
924
+ "vision_backbone.fused_featurizer.blocks.4.norm1.bias": "model-00001-of-00004.safetensors",
925
+ "vision_backbone.fused_featurizer.blocks.4.norm1.weight": "model-00001-of-00004.safetensors",
926
+ "vision_backbone.fused_featurizer.blocks.4.norm2.bias": "model-00001-of-00004.safetensors",
927
+ "vision_backbone.fused_featurizer.blocks.4.norm2.weight": "model-00001-of-00004.safetensors",
928
+ "vision_backbone.fused_featurizer.blocks.5.attn.proj.bias": "model-00001-of-00004.safetensors",
929
+ "vision_backbone.fused_featurizer.blocks.5.attn.proj.weight": "model-00001-of-00004.safetensors",
930
+ "vision_backbone.fused_featurizer.blocks.5.attn.qkv.bias": "model-00001-of-00004.safetensors",
931
+ "vision_backbone.fused_featurizer.blocks.5.attn.qkv.weight": "model-00001-of-00004.safetensors",
932
+ "vision_backbone.fused_featurizer.blocks.5.mlp.fc1.bias": "model-00001-of-00004.safetensors",
933
+ "vision_backbone.fused_featurizer.blocks.5.mlp.fc1.weight": "model-00001-of-00004.safetensors",
934
+ "vision_backbone.fused_featurizer.blocks.5.mlp.fc2.bias": "model-00001-of-00004.safetensors",
935
+ "vision_backbone.fused_featurizer.blocks.5.mlp.fc2.weight": "model-00001-of-00004.safetensors",
936
+ "vision_backbone.fused_featurizer.blocks.5.norm1.bias": "model-00001-of-00004.safetensors",
937
+ "vision_backbone.fused_featurizer.blocks.5.norm1.weight": "model-00001-of-00004.safetensors",
938
+ "vision_backbone.fused_featurizer.blocks.5.norm2.bias": "model-00001-of-00004.safetensors",
939
+ "vision_backbone.fused_featurizer.blocks.5.norm2.weight": "model-00001-of-00004.safetensors",
940
+ "vision_backbone.fused_featurizer.blocks.6.attn.proj.bias": "model-00001-of-00004.safetensors",
941
+ "vision_backbone.fused_featurizer.blocks.6.attn.proj.weight": "model-00001-of-00004.safetensors",
942
+ "vision_backbone.fused_featurizer.blocks.6.attn.qkv.bias": "model-00001-of-00004.safetensors",
943
+ "vision_backbone.fused_featurizer.blocks.6.attn.qkv.weight": "model-00001-of-00004.safetensors",
944
+ "vision_backbone.fused_featurizer.blocks.6.mlp.fc1.bias": "model-00001-of-00004.safetensors",
945
+ "vision_backbone.fused_featurizer.blocks.6.mlp.fc1.weight": "model-00001-of-00004.safetensors",
946
+ "vision_backbone.fused_featurizer.blocks.6.mlp.fc2.bias": "model-00001-of-00004.safetensors",
947
+ "vision_backbone.fused_featurizer.blocks.6.mlp.fc2.weight": "model-00001-of-00004.safetensors",
948
+ "vision_backbone.fused_featurizer.blocks.6.norm1.bias": "model-00001-of-00004.safetensors",
949
+ "vision_backbone.fused_featurizer.blocks.6.norm1.weight": "model-00001-of-00004.safetensors",
950
+ "vision_backbone.fused_featurizer.blocks.6.norm2.bias": "model-00001-of-00004.safetensors",
951
+ "vision_backbone.fused_featurizer.blocks.6.norm2.weight": "model-00001-of-00004.safetensors",
952
+ "vision_backbone.fused_featurizer.blocks.7.attn.proj.bias": "model-00001-of-00004.safetensors",
953
+ "vision_backbone.fused_featurizer.blocks.7.attn.proj.weight": "model-00001-of-00004.safetensors",
954
+ "vision_backbone.fused_featurizer.blocks.7.attn.qkv.bias": "model-00001-of-00004.safetensors",
955
+ "vision_backbone.fused_featurizer.blocks.7.attn.qkv.weight": "model-00001-of-00004.safetensors",
956
+ "vision_backbone.fused_featurizer.blocks.7.mlp.fc1.bias": "model-00001-of-00004.safetensors",
957
+ "vision_backbone.fused_featurizer.blocks.7.mlp.fc1.weight": "model-00001-of-00004.safetensors",
958
+ "vision_backbone.fused_featurizer.blocks.7.mlp.fc2.bias": "model-00001-of-00004.safetensors",
959
+ "vision_backbone.fused_featurizer.blocks.7.mlp.fc2.weight": "model-00001-of-00004.safetensors",
960
+ "vision_backbone.fused_featurizer.blocks.7.norm1.bias": "model-00001-of-00004.safetensors",
961
+ "vision_backbone.fused_featurizer.blocks.7.norm1.weight": "model-00001-of-00004.safetensors",
962
+ "vision_backbone.fused_featurizer.blocks.7.norm2.bias": "model-00001-of-00004.safetensors",
963
+ "vision_backbone.fused_featurizer.blocks.7.norm2.weight": "model-00001-of-00004.safetensors",
964
+ "vision_backbone.fused_featurizer.blocks.8.attn.proj.bias": "model-00001-of-00004.safetensors",
965
+ "vision_backbone.fused_featurizer.blocks.8.attn.proj.weight": "model-00001-of-00004.safetensors",
966
+ "vision_backbone.fused_featurizer.blocks.8.attn.qkv.bias": "model-00001-of-00004.safetensors",
967
+ "vision_backbone.fused_featurizer.blocks.8.attn.qkv.weight": "model-00001-of-00004.safetensors",
968
+ "vision_backbone.fused_featurizer.blocks.8.mlp.fc1.bias": "model-00001-of-00004.safetensors",
969
+ "vision_backbone.fused_featurizer.blocks.8.mlp.fc1.weight": "model-00001-of-00004.safetensors",
970
+ "vision_backbone.fused_featurizer.blocks.8.mlp.fc2.bias": "model-00001-of-00004.safetensors",
971
+ "vision_backbone.fused_featurizer.blocks.8.mlp.fc2.weight": "model-00001-of-00004.safetensors",
972
+ "vision_backbone.fused_featurizer.blocks.8.norm1.bias": "model-00001-of-00004.safetensors",
973
+ "vision_backbone.fused_featurizer.blocks.8.norm1.weight": "model-00001-of-00004.safetensors",
974
+ "vision_backbone.fused_featurizer.blocks.8.norm2.bias": "model-00001-of-00004.safetensors",
975
+ "vision_backbone.fused_featurizer.blocks.8.norm2.weight": "model-00001-of-00004.safetensors",
976
+ "vision_backbone.fused_featurizer.blocks.9.attn.proj.bias": "model-00001-of-00004.safetensors",
977
+ "vision_backbone.fused_featurizer.blocks.9.attn.proj.weight": "model-00001-of-00004.safetensors",
978
+ "vision_backbone.fused_featurizer.blocks.9.attn.qkv.bias": "model-00001-of-00004.safetensors",
979
+ "vision_backbone.fused_featurizer.blocks.9.attn.qkv.weight": "model-00001-of-00004.safetensors",
980
+ "vision_backbone.fused_featurizer.blocks.9.mlp.fc1.bias": "model-00001-of-00004.safetensors",
981
+ "vision_backbone.fused_featurizer.blocks.9.mlp.fc1.weight": "model-00001-of-00004.safetensors",
982
+ "vision_backbone.fused_featurizer.blocks.9.mlp.fc2.bias": "model-00001-of-00004.safetensors",
983
+ "vision_backbone.fused_featurizer.blocks.9.mlp.fc2.weight": "model-00001-of-00004.safetensors",
984
+ "vision_backbone.fused_featurizer.blocks.9.norm1.bias": "model-00001-of-00004.safetensors",
985
+ "vision_backbone.fused_featurizer.blocks.9.norm1.weight": "model-00001-of-00004.safetensors",
986
+ "vision_backbone.fused_featurizer.blocks.9.norm2.bias": "model-00001-of-00004.safetensors",
987
+ "vision_backbone.fused_featurizer.blocks.9.norm2.weight": "model-00001-of-00004.safetensors",
988
+ "vision_backbone.fused_featurizer.norm.bias": "model-00001-of-00004.safetensors",
989
+ "vision_backbone.fused_featurizer.norm.weight": "model-00001-of-00004.safetensors",
990
+ "vision_backbone.fused_featurizer.patch_embed.proj.bias": "model-00001-of-00004.safetensors",
991
+ "vision_backbone.fused_featurizer.patch_embed.proj.weight": "model-00001-of-00004.safetensors",
992
+ "vision_backbone.fused_featurizer.pos_embed": "model-00001-of-00004.safetensors"
993
+ }
994
+ }
modeling_prismatic.py ADDED
@@ -0,0 +1,1996 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ modeling_prismatic.py
3
+
4
+ Core HuggingFace-style PrismaticPreTrainedModel and PrismaticForConditionalGeneration class definitions.
5
+ Inherits from the default `transformers.PretrainedModel`. Meant to be standalone and self-contained,
6
+ but exactly replicate the logic in `prismatic.models.vlms.prismatic.py`.
7
+ """
8
+
9
+ import logging
10
+ from dataclasses import dataclass
11
+ from functools import partial
12
+ from typing import Any, Callable, ClassVar, Dict, List, Optional, Tuple, Union
13
+
14
+ import numpy as np
15
+ import timm
16
+ import tokenizers
17
+ import torch
18
+ import torch.nn as nn
19
+ import transformers
20
+ from timm.models.vision_transformer import LayerScale
21
+ from transformers import AutoModelForCausalLM, PretrainedConfig, PreTrainedModel
22
+ from transformers.modeling_outputs import ModelOutput
23
+
24
+ from prismatic.training.train_utils import (
25
+ get_current_action_mask,
26
+ get_next_actions_mask,
27
+ )
28
+ from prismatic.vla.constants import (
29
+ ACTION_DIM,
30
+ ACTION_PROPRIO_NORMALIZATION_TYPE,
31
+ ACTION_TOKEN_BEGIN_IDX,
32
+ IGNORE_INDEX,
33
+ NUM_ACTIONS_CHUNK,
34
+ STOP_INDEX,
35
+ NormalizationType,
36
+ )
37
+
38
+ from .configuration_prismatic import OpenVLAConfig, PrismaticConfig
39
+
40
+ # Set up logger
41
+ logger = logging.getLogger(__name__)
42
+
43
+
44
+ # === Utility Functions for Monkey-Patching ===
45
+ def unpack_tuple(fn: Callable[[Any], Tuple[Any]]) -> Callable[[Any], Any]:
46
+ def wrapper(*args: Any, **kwargs: Any) -> Any:
47
+ result = fn(*args, **kwargs)
48
+ return result[0] if isinstance(result, tuple) else result
49
+
50
+ return wrapper
51
+
52
+
53
+ # HF Transformers overwrites parameters with names containing `gamma`; we're going to patch VisionBackbone.LayerScale.
54
+ # =>> TIMM :: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L109
55
+ # =>> Transformers :: https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_utils.py#L3960
56
+ def _ls_new_forward(self, x: torch.Tensor) -> torch.Tensor:
57
+ return x.mul_(self.scale_factor) if self.inplace else x * self.scale_factor
58
+
59
+
60
+ def ls_apply_patch(ls_module: LayerScale):
61
+ ls_module.scale_factor = nn.Parameter(ls_module.gamma.clone())
62
+ ls_module.forward = _ls_new_forward.__get__(ls_module, LayerScale)
63
+ del ls_module.gamma
64
+
65
+
66
+ # === Prismatic Vision Backbone (nn.Module) Definitions (w/ Fused Backbone Support) ===
67
+ class PrismaticVisionBackbone(nn.Module):
68
+ """
69
+ Vision backbone for Prismatic models that handles image feature extraction.
70
+
71
+ Supports both single backbone (e.g., SigLIP) and fused backbone (e.g., SigLIP + DINOv2) configurations.
72
+ For fused backbones, features from both models are concatenated along the feature dimension.
73
+ """
74
+
75
+ def __init__(
76
+ self,
77
+ use_fused_vision_backbone: bool,
78
+ image_sizes: List[int],
79
+ timm_model_ids: List[str],
80
+ timm_override_act_layers: List[Optional[str]],
81
+ ) -> None:
82
+ """
83
+ Initialize the vision backbone.
84
+
85
+ Args:
86
+ use_fused_vision_backbone: Whether to use two backbones and fuse their features
87
+ image_sizes: List of image sizes for each backbone
88
+ timm_model_ids: List of TIMM model IDs to use for each backbone
89
+ timm_override_act_layers: List of activation layer overrides for each backbone
90
+ """
91
+ super().__init__()
92
+ self.use_fused_vision_backbone = use_fused_vision_backbone
93
+ self.num_images_in_input = 1 # Default value, can be overridden later
94
+
95
+ # Validate number of (fused) vision backbones
96
+ if len(timm_model_ids) > 2:
97
+ raise ValueError("Prismatic models only support up to 2 (fused) vision backbones!")
98
+
99
+ # Create primary featurizer
100
+ self.featurizer = self._create_featurizer(
101
+ model_id=timm_model_ids[0], img_size=image_sizes[0], act_layer=timm_override_act_layers[0]
102
+ )
103
+ self.embed_dim = self.featurizer.embed_dim
104
+
105
+ # Create secondary featurizer if using fused backbone
106
+ if self.use_fused_vision_backbone:
107
+ self.fused_featurizer = self._create_featurizer(
108
+ model_id=timm_model_ids[1], img_size=image_sizes[1], act_layer=timm_override_act_layers[1]
109
+ )
110
+ self.embed_dim += self.fused_featurizer.embed_dim
111
+
112
+ # Patch LayerScale modules for HF compatibility
113
+ self._patch_layer_scales()
114
+
115
+ def _create_featurizer(self, model_id: str, img_size: int, act_layer: Optional[str]) -> nn.Module:
116
+ """
117
+ Create a TIMM-based featurizer model with appropriate configurations.
118
+
119
+ Args:
120
+ model_id: The TIMM model ID to load
121
+ img_size: Input image size for the model
122
+ act_layer: Override for the activation layer type
123
+
124
+ Returns:
125
+ A configured featurizer model
126
+ """
127
+ featurizer = timm.create_model(
128
+ model_id,
129
+ pretrained=False,
130
+ num_classes=0,
131
+ img_size=img_size,
132
+ act_layer=act_layer,
133
+ )
134
+
135
+ # Monkey-patch the forward function to extract the second-to-last layer features
136
+ num_blocks = len(featurizer.blocks)
137
+ featurizer.forward = unpack_tuple(partial(featurizer.get_intermediate_layers, n={num_blocks - 2}))
138
+
139
+ return featurizer
140
+
141
+ def _patch_layer_scales(self) -> None:
142
+ """
143
+ Patch all LayerScale modules to be compatible with HF's parameter naming.
144
+
145
+ HF Transformers overwrites parameters with names containing 'gamma',
146
+ so we need to rename and modify the forward method.
147
+ """
148
+ # Patch primary featurizer
149
+ for module in self.featurizer.modules():
150
+ if isinstance(module, LayerScale):
151
+ ls_apply_patch(module)
152
+
153
+ # Patch secondary featurizer if it exists
154
+ if self.use_fused_vision_backbone:
155
+ for module in self.fused_featurizer.modules():
156
+ if isinstance(module, LayerScale):
157
+ ls_apply_patch(module)
158
+
159
+ def get_num_patches(self) -> int:
160
+ """
161
+ Returns the number of vision patches output by the vision backbone.
162
+
163
+ Returns:
164
+ Number of patches per image
165
+ """
166
+ return self.featurizer.patch_embed.num_patches
167
+
168
+ def get_num_images_in_input(self) -> int:
169
+ """
170
+ Returns the number of input images for the vision backbone.
171
+
172
+ Returns:
173
+ Number of images expected in the input
174
+ """
175
+ return self.num_images_in_input
176
+
177
+ def set_num_images_in_input(self, num_images_in_input: int) -> None:
178
+ """
179
+ Sets the number of input images for the vision backbone.
180
+
181
+ Args:
182
+ num_images_in_input: Number of images to expect in the input
183
+ """
184
+ self.num_images_in_input = num_images_in_input
185
+
186
+ def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
187
+ """
188
+ Implements the forward pass for the vision backbone.
189
+
190
+ If `self.use_fused_vision_backbone == True`, uses both SigLIP and DINOv2 transformers to extract visual features
191
+ (otherwise uses SigLIP only). Allows multi-image inputs (but only for fused vision backbone).
192
+
193
+ Args:
194
+ pixel_values (torch.Tensor): Pixels for input image(s), (B, C, H, W).
195
+ """
196
+ if self.num_images_in_input == 1:
197
+ if not self.use_fused_vision_backbone:
198
+ return self.featurizer(pixel_values)
199
+
200
+ # Split `pixel_values :: [bsz, 2 * 3, resolution, resolution]` =>> featurize =>> channel stack
201
+ img, img_fused = torch.split(pixel_values, [3, 3], dim=1)
202
+ patches, patches_fused = self.featurizer(img), self.fused_featurizer(img_fused)
203
+
204
+ return torch.cat([patches, patches_fused], dim=2)
205
+
206
+ else:
207
+ assert self.use_fused_vision_backbone, "Multi-image inputs require using fused backbone!"
208
+
209
+ # Split `pixel_values` into individual images (each with 6 channels: 3 for SigLIP + 3 for DINOv2)
210
+ images = torch.split(pixel_values, [6] * self.num_images_in_input, dim=1)
211
+
212
+ # Process each image and collect patches
213
+ all_patches = []
214
+ for img in images:
215
+ # Split each image further into two stacks of channels (each with 3 channels)
216
+ img_regular, img_fused = torch.split(img, [3, 3], dim=1)
217
+
218
+ # Get patches from both SigLIP and DINOv2 vision transformers
219
+ patches = self.featurizer(img_regular)
220
+ patches_fused = self.fused_featurizer(img_fused)
221
+
222
+ # Concatenate SigLIP and DINOv2 patches along the hidden dimension
223
+ combined_patches = torch.cat([patches, patches_fused], dim=2)
224
+ all_patches.append(combined_patches)
225
+
226
+ # Concatenate all patches along the patch dimension
227
+ return torch.cat(all_patches, dim=1)
228
+
229
+
230
+ # === Prismatic Projector (nn.Module) Definitions ===
231
+ class PrismaticProjector(nn.Module):
232
+ def __init__(self, use_fused_vision_backbone: bool, vision_dim: int, llm_dim: int) -> None:
233
+ super().__init__()
234
+ self.use_fused_vision_backbone = use_fused_vision_backbone
235
+ self.vision_dim, self.llm_dim = vision_dim, llm_dim
236
+
237
+ # Switch on `use_fused_vision_backbone` =>> use slightly different MLPs and projection factors!
238
+ if not self.use_fused_vision_backbone:
239
+ self.fc1 = nn.Linear(self.vision_dim, self.llm_dim, bias=True)
240
+ self.fc2 = nn.Linear(self.llm_dim, self.llm_dim, bias=True)
241
+ self.act_fn1 = nn.GELU()
242
+ else:
243
+ initial_projection_dim = 4 * vision_dim
244
+ self.fc1 = nn.Linear(self.vision_dim, initial_projection_dim, bias=True)
245
+ self.fc2 = nn.Linear(initial_projection_dim, self.llm_dim, bias=True)
246
+ self.fc3 = nn.Linear(self.llm_dim, self.llm_dim, bias=True)
247
+ self.act_fn1 = nn.GELU()
248
+ self.act_fn2 = nn.GELU()
249
+
250
+ def forward(self, img_patches: torch.Tensor) -> torch.Tensor:
251
+ if not self.use_fused_vision_backbone:
252
+ projected_features = self.fc1(img_patches)
253
+ projected_features = self.act_fn1(projected_features)
254
+ projected_features = self.fc2(projected_features)
255
+ else:
256
+ projected_features = self.fc1(img_patches)
257
+ projected_features = self.act_fn1(projected_features)
258
+ projected_features = self.fc2(projected_features)
259
+ projected_features = self.act_fn2(projected_features)
260
+ projected_features = self.fc3(projected_features)
261
+
262
+ return projected_features
263
+
264
+
265
+ # === Main HF Class Definitions ===
266
+ @dataclass
267
+ class PrismaticCausalLMOutputWithPast(ModelOutput):
268
+ """Base class for Prismatic casual (visually-conditioned) language model outputs; also exposes visual features."""
269
+
270
+ loss: Optional[torch.FloatTensor] = None
271
+ logits: torch.FloatTensor = None
272
+ past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
273
+ hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
274
+ attentions: Optional[Tuple[torch.FloatTensor]] = None
275
+
276
+ # Additions for VLMs
277
+ projector_features: Optional[torch.FloatTensor] = None
278
+
279
+
280
+ class PrismaticPreTrainedModel(PreTrainedModel):
281
+ config_class: PretrainedConfig = PrismaticConfig
282
+ base_model_prefix: str = "model"
283
+ supports_gradient_checkpointing: bool = True
284
+
285
+ _no_split_modules: ClassVar[List[str]] = ["PrismaticProjector"]
286
+ _skip_keys_device_placement: str = "past_key_values"
287
+ _supports_flash_attn_2: bool = True
288
+
289
+ def _init_weights(self, module: nn.Module) -> None:
290
+ # Important :: this HF ported version is *not* meant for training from scratch; only inference and fine-tuning!
291
+ # => As such, this init_weights code is not correct; if training VLMs from scratch, use the main codebase at
292
+ # https://github.com/TRI-ML/prismatic-vlms
293
+ std = (
294
+ self.config.initializer_range
295
+ if hasattr(self.config, "initializer_range")
296
+ else self.config.text_config.initializer_range
297
+ )
298
+
299
+ if hasattr(module, "class_embedding"):
300
+ module.class_embedding.data.normal_(mean=0.0, std=std)
301
+
302
+ if isinstance(module, (nn.Linear, nn.Conv2d)):
303
+ module.weight.data.normal_(mean=0.0, std=std)
304
+ if module.bias is not None:
305
+ module.bias.data.zero_()
306
+ elif isinstance(module, nn.Embedding):
307
+ module.weight.data.normal_(mean=0.0, std=std)
308
+ if module.padding_idx is not None:
309
+ module.weight.data[module.padding_idx].zero_()
310
+
311
+ @property
312
+ def _supports_sdpa(self) -> bool:
313
+ """Check LLM supports SDPA Attention"""
314
+ return self.language_model._supports_sdpa
315
+
316
+
317
+ class PrismaticForConditionalGeneration(PrismaticPreTrainedModel):
318
+ def __init__(self, config: PrismaticConfig) -> None:
319
+ super().__init__(config)
320
+
321
+ # [Validation] Lightweight Validate on `config` Fields + Dependency Versions
322
+ if config.use_fused_vision_backbone is None:
323
+ raise ValueError("Missing config field `use_fused_vision_backbone`")
324
+
325
+ if timm.__version__ not in {"0.9.10", "0.9.11", "0.9.12", "0.9.16"}:
326
+ raise NotImplementedError(
327
+ "TIMM Version must be >= 0.9.10 and < 1.0.0 (breaking); please raise a GitHub Issue "
328
+ "if you urgently need support for latest TIMM versions."
329
+ )
330
+
331
+ if (transformers.__version__ != "4.40.1") or (tokenizers.__version__ != "0.19.1"):
332
+ logger.warning(
333
+ f"Expected `transformers==4.40.1` and `tokenizers==0.19.1` but got "
334
+ f"`transformers=={transformers.__version__}` and `tokenizers=={tokenizers.__version__}`; "
335
+ f"there might be inference-time regressions due to dependency changes. If in doubt, please"
336
+ f"use the above versions."
337
+ )
338
+
339
+ # Instantiate PrismaticVisionBackbone (w/ Potential Fused Backbone)
340
+ self.vision_backbone = PrismaticVisionBackbone(
341
+ config.use_fused_vision_backbone, config.image_sizes, config.timm_model_ids, config.timm_override_act_layers
342
+ )
343
+
344
+ # Create Multimodal Projector
345
+ self.projector = PrismaticProjector(
346
+ config.use_fused_vision_backbone,
347
+ vision_dim=self.vision_backbone.embed_dim,
348
+ llm_dim=config.text_config.hidden_size,
349
+ )
350
+
351
+ # Instantiate LLM Backbone
352
+ self.language_model = AutoModelForCausalLM.from_config(
353
+ config.text_config, attn_implementation=config._attn_implementation
354
+ )
355
+ self.vocab_size = config.text_config.vocab_size
356
+ self.pad_token_id = config.pad_token_id
357
+ self.llm_dim = config.text_config.hidden_size
358
+
359
+ # HF Boilerplate =>> initializes weights via `_init_weights()` and sets gradient checkpointing
360
+ self.post_init()
361
+
362
+ # === `PreTrainedModel` Boilerplate ===
363
+ def get_input_embeddings(self) -> nn.Module:
364
+ return self.language_model.get_input_embeddings()
365
+
366
+ def set_input_embeddings(self, value: nn.Module) -> None:
367
+ self.language_model.set_input_embeddings(value)
368
+
369
+ def get_output_embeddings(self) -> nn.Module:
370
+ return self.language_model.get_output_embeddings()
371
+
372
+ def set_output_embeddings(self, new_embeddings: nn.Module) -> None:
373
+ self.language_model.set_output_embeddings(new_embeddings)
374
+
375
+ def get_decoder(self) -> nn.Module:
376
+ return self.language_model.get_decoder()
377
+
378
+ def set_decoder(self, decoder: nn.Module) -> None:
379
+ self.language_model.set_decoder(decoder)
380
+
381
+ def tie_weights(self) -> None:
382
+ self.language_model.tie_weights() # Note: `Llama-2` and `Mistral` don't tie weights (no-op)
383
+
384
+ def resize_token_embeddings(
385
+ self, new_num_tokens: Optional[int] = None, pad_to_multiple_of: Optional[int] = None
386
+ ) -> nn.Embedding:
387
+ updated_embeddings = self.language_model.resize_token_embeddings(new_num_tokens, pad_to_multiple_of)
388
+
389
+ # Update config/instance variables
390
+ self.config.text_config.vocab_size = updated_embeddings.num_embeddings
391
+ self.vocab_size = updated_embeddings.num_embeddings
392
+
393
+ return updated_embeddings
394
+
395
+ def _replace_input_embeddings(self, input_embeddings, all_actions_mask, noisy_action_features):
396
+ """
397
+ Replace embeddings in input_embeddings at positions where all_actions_mask is True
398
+ with embeddings from noisy_action_features, using vectorized operations.
399
+
400
+ Args:
401
+ input_embeddings: Tensor of shape (B, S, D)
402
+ all_actions_mask: Boolean tensor of shape (B, S)
403
+ noisy_action_features: Tensor of shape (B, K, D) where K is the number of True values in mask per sample
404
+
405
+ Returns:
406
+ Modified input_embeddings tensor
407
+ """
408
+ # Clone input to avoid modifying the original tensor
409
+ new_input_embeddings = input_embeddings.clone()
410
+
411
+ # Create a tensor with the same shape of input_embeddings to hold the noisy action features
412
+ repositioned_noisy_action_features = torch.zeros_like(input_embeddings)
413
+
414
+ # Create batch indices for splicing
415
+ batch_indices = torch.arange(input_embeddings.shape[0], device=input_embeddings.device)
416
+ batch_indices = batch_indices.unsqueeze(1).expand(-1, noisy_action_features.shape[1])
417
+
418
+ # Get indices where mask is True for each sample
419
+ masked_indices = torch.stack([torch.where(mask)[0] for mask in all_actions_mask])
420
+
421
+ # Move the noisy action features into their correct positions
422
+ repositioned_noisy_action_features[batch_indices, masked_indices] = noisy_action_features
423
+
424
+ # Combine original input embeddings and noisy action embeddings using the mask
425
+ new_input_embeddings = torch.where(
426
+ all_actions_mask.unsqueeze(-1), repositioned_noisy_action_features, new_input_embeddings
427
+ )
428
+
429
+ return new_input_embeddings
430
+
431
+ def _process_action_masks(self, labels):
432
+ """Helper to get action masks from labels"""
433
+ current_action_mask = get_current_action_mask(labels)
434
+ next_actions_mask = get_next_actions_mask(labels)
435
+ all_actions_mask = current_action_mask | next_actions_mask # (B, seq_len)
436
+ return all_actions_mask
437
+
438
+ def _process_vision_features(self, pixel_values, language_embeddings=None, use_film=False):
439
+ """Process vision features with optional FiLM conditioning"""
440
+ if use_film:
441
+ # FiLM: Infuse language inputs into visual features
442
+ patch_features = self.vision_backbone(pixel_values, language_embeddings) # (bsz, 256 * num_images, D)
443
+ else:
444
+ patch_features = self.vision_backbone(pixel_values) # (bsz, 256 * num_images, D)
445
+
446
+ # Project patch embeddings into language embedding space
447
+ return self.projector(patch_features)
448
+
449
+ def _process_proprio_features(self, projected_patch_embeddings, proprio, proprio_projector):
450
+ """Process proprioceptive features and append to vision features"""
451
+ if proprio_projector is not None and proprio is not None:
452
+ # projected_patch_embeddings: (bsz, num_patches * num_images, llm_dim)
453
+ # proprio: (bsz, proprio_dim) or (propro_dim,)
454
+ proprio = proprio.reshape(projected_patch_embeddings.shape[0], -1) # (bsz, proprio_dim)
455
+ proprio_features = proprio_projector(proprio) # (bsz, llm_dim)
456
+ proprio_features = proprio_features.unsqueeze(dim=1) # (bsz, 1, llm_dim)
457
+ # For simplicity, just append proprio token to the end of projected vision patch tokens
458
+ return torch.cat((projected_patch_embeddings, proprio_features), dim=1)
459
+ return projected_patch_embeddings
460
+
461
+ def _build_multimodal_attention(self, input_embeddings, projected_patch_embeddings, attention_mask):
462
+ """Build multimodal embeddings and attention mask"""
463
+ # Update attention mask
464
+ projected_patch_attention_mask = None
465
+ if attention_mask is not None:
466
+ projected_patch_attention_mask = torch.full(
467
+ (projected_patch_embeddings.shape[0], projected_patch_embeddings.shape[1]),
468
+ fill_value=True,
469
+ dtype=attention_mask.dtype,
470
+ device=attention_mask.device,
471
+ )
472
+
473
+ # Build multimodal embeddings & attention mask; insert embeddings after <BOS> token (1:)
474
+ multimodal_embeddings = torch.cat(
475
+ [input_embeddings[:, :1, :], projected_patch_embeddings, input_embeddings[:, 1:, :]], dim=1
476
+ )
477
+
478
+ multimodal_attention_mask = None
479
+ if attention_mask is not None:
480
+ multimodal_attention_mask = torch.cat(
481
+ [attention_mask[:, :1], projected_patch_attention_mask, attention_mask[:, 1:]], dim=1
482
+ )
483
+
484
+ return multimodal_embeddings, multimodal_attention_mask
485
+
486
+ def _build_multimodal_labels(self, labels, projected_patch_embeddings):
487
+ """Build multimodal labels with IGNORE_INDEX for patch embeddings"""
488
+ if labels is not None:
489
+ projected_patch_labels = torch.full(
490
+ (projected_patch_embeddings.shape[0], projected_patch_embeddings.shape[1]),
491
+ fill_value=IGNORE_INDEX,
492
+ dtype=labels.dtype,
493
+ device=labels.device,
494
+ )
495
+ return torch.cat([labels[:, :1], projected_patch_labels, labels[:, 1:]], dim=1)
496
+ return None
497
+
498
+ # === Core Prismatic VLM `forward()` Logic ===
499
+ # def forward(
500
+ # self,
501
+ # input_ids: Optional[torch.LongTensor] = None,
502
+ # attention_mask: Optional[torch.Tensor] = None,
503
+ # pixel_values: Optional[torch.FloatTensor] = None,
504
+ # labels: Optional[torch.LongTensor] = None,
505
+ # inputs_embeds: Optional[torch.FloatTensor] = None,
506
+ # past_key_values: Optional[List[torch.FloatTensor]] = None,
507
+ # use_cache: Optional[bool] = None,
508
+ # output_attentions: Optional[bool] = None,
509
+ # output_hidden_states: Optional[bool] = None,
510
+ # output_projector_features: Optional[bool] = None,
511
+ # return_dict: Optional[bool] = None,
512
+ # proprio=None,
513
+ # proprio_projector=None,
514
+ # noisy_actions=None,
515
+ # noisy_action_projector=None,
516
+ # diffusion_timestep_embeddings=None,
517
+ # use_film: bool = False,
518
+ # ) -> Union[Tuple, PrismaticCausalLMOutputWithPast]:
519
+ # """Run a forward pass through the VLM, returning a PrismaticCausalLMOutputWithPast instance."""
520
+ # output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
521
+ # output_hidden_states = (
522
+ # output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
523
+ # )
524
+ # output_projector_features = output_projector_features if output_projector_features is not None else False
525
+ # return_dict = return_dict if return_dict is not None else self.config.use_return_dict
526
+
527
+ # # Respect `use_cache` only if not training (even if `gradient_checkpointing` is off)
528
+ # use_cache = use_cache and not self.training
529
+
530
+ # # Instantiate Placeholder for Projector Features
531
+ # projected_patch_embeddings = None
532
+
533
+ # # === Handle Generation with Cache (`input_ids.shape[1] == 1`) =>> requires `past_keys_values` ===
534
+ # if input_ids.shape[1] == 1:
535
+ # assert input_ids.shape[0] == 1, "Generation is only currently supported for batch size of 1!"
536
+ # assert past_key_values is not None, "You must provide `past_key_values` during cached generation!"
537
+ # assert labels is None, "Unexpected key `labels` provided during cached generation!"
538
+
539
+ # language_model_output = self.language_model(
540
+ # input_ids=input_ids,
541
+ # attention_mask=None,
542
+ # position_ids=None,
543
+ # past_key_values=past_key_values,
544
+ # inputs_embeds=None,
545
+ # labels=None,
546
+ # use_cache=use_cache,
547
+ # output_attentions=output_attentions,
548
+ # output_hidden_states=output_hidden_states,
549
+ # return_dict=return_dict,
550
+ # )
551
+
552
+ # # === Handle Unimodal Forward ===
553
+ # elif pixel_values is None:
554
+ # assert (input_ids is not None) and (inputs_embeds is None), "Missing `input_ids` in language-only forward!"
555
+ # assert past_key_values is None, "Unexpected key `past_key_values` provided during language-only forward!"
556
+
557
+ # language_model_output = self.language_model(
558
+ # input_ids=input_ids,
559
+ # attention_mask=attention_mask,
560
+ # position_ids=None,
561
+ # past_key_values=None,
562
+ # inputs_embeds=None,
563
+ # labels=labels,
564
+ # use_cache=use_cache,
565
+ # output_attentions=output_attentions,
566
+ # output_hidden_states=output_hidden_states,
567
+ # return_dict=return_dict,
568
+ # )
569
+
570
+ # # === Handle Multimodal Forward ===
571
+ # elif (input_ids.shape[0] == pixel_values.shape[0]) or (inputs_embeds.shape[0] == pixel_values.shape[0]):
572
+ # assert past_key_values is None, "Unexpected key `past_key_values` provided during multimodal forward!"
573
+
574
+ # #test
575
+ #
576
+ # #test end
577
+
578
+ # # Get input embeddings (from language model embeddings)
579
+ # input_embeddings = self.get_input_embeddings()(input_ids) # (B, seq_len, D)
580
+
581
+ # # Extract action masks
582
+ # all_actions_mask = self._process_action_masks(labels)
583
+
584
+ # # Extract the language portion of the input embeddings (i.e. remove the action tokens portion)
585
+ # language_embeddings = input_embeddings[~all_actions_mask].reshape(
586
+ # input_embeddings.shape[0], -1, input_embeddings.shape[2]
587
+ # ) # (B, lang_seq_len, llm_dim)
588
+
589
+ # # Get visual features
590
+ # projected_patch_embeddings = self._process_vision_features(pixel_values, language_embeddings, use_film)
591
+
592
+ # # Add proprioceptive state if provided
593
+ # projected_patch_embeddings = self._process_proprio_features(
594
+ # projected_patch_embeddings, proprio, proprio_projector
595
+ # )
596
+
597
+ # # [Diffusion] Add diffusion timestep embedding if provided
598
+ # if diffusion_timestep_embeddings is not None:
599
+ # # For simplicity, just append diffusion timestep embedding to the end of projected vision patch tokens
600
+ # projected_patch_embeddings = torch.cat(
601
+ # (projected_patch_embeddings, diffusion_timestep_embeddings), dim=1
602
+ # )
603
+
604
+ # # Process action embeddings
605
+ # if noisy_actions is not None:
606
+ # # Get mask corresponding to all action tokens
607
+ # all_actions_mask = self._process_action_masks(labels)
608
+
609
+ # # Reshape noisy actions into individual action tokens
610
+ # # noisy_actions: (B, chunk_len, action_dim) -> (B, chunk_len * action_dim, 1)
611
+ # B = noisy_actions.shape[0]
612
+ # noisy_actions = noisy_actions.reshape(B, -1).unsqueeze(-1)
613
+
614
+ # # Project noisy action tokens into language model embedding space
615
+ # noisy_action_features = noisy_action_projector(noisy_actions) # (B, chunk_len * action_dim, llm_dim)
616
+
617
+ # # Replace embeddings of the action tokens with noisy action embeddings
618
+ # input_embeddings = self._replace_input_embeddings(
619
+ # input_embeddings, all_actions_mask, noisy_action_features
620
+ # )
621
+ # else:
622
+ # # Replace the embeddings of the action tokens with zeros
623
+ # # (Later on, the positional embeddings will be added to them)
624
+ # all_actions_mask = all_actions_mask.unsqueeze(-1) # (B, seq_len, 1)
625
+ # input_embeddings = input_embeddings * ~all_actions_mask
626
+
627
+ # # Build multimodal embeddings & attention mask
628
+ # multimodal_embeddings, multimodal_attention_mask = self._build_multimodal_attention(
629
+ # input_embeddings, projected_patch_embeddings, attention_mask
630
+ # )
631
+
632
+ # # Build labels for multimodal sequence if needed
633
+ # multimodal_labels = self._build_multimodal_labels(labels, projected_patch_embeddings)
634
+
635
+ # # Dispatch to language model
636
+ # language_model_output = self.language_model(
637
+ # input_ids=None,
638
+ # attention_mask=multimodal_attention_mask,
639
+ # position_ids=None,
640
+ # past_key_values=None,
641
+ # inputs_embeds=multimodal_embeddings,
642
+ # labels=multimodal_labels,
643
+ # use_cache=use_cache,
644
+ # output_attentions=output_attentions,
645
+ # output_hidden_states=output_hidden_states,
646
+ # return_dict=return_dict,
647
+ # )
648
+
649
+ # # === Otherwise =>> Assume Invalid! ===
650
+ # elif (input_ids.shape[0] != pixel_values.shape[0]) or (inputs_embeds.shape[0] != pixel_values.shape[0]):
651
+ # raise ValueError("Non-homogenous batch of (text, image) input -- forward() does not support mixed batches!")
652
+
653
+ # else:
654
+ # raise ValueError(
655
+ # "Invalid PrismaticForConditionalGeneration `forward()` call with provided arguments:\n"
656
+ # f"=> `input_ids` = {input_ids is not None}\n"
657
+ # f"=> `attention_mask` = {attention_mask is not None}\n"
658
+ # f"=> `pixel_values` = {pixel_values is not None}\n"
659
+ # f"=> `labels` = {labels is not None}\n"
660
+ # f"=> `input_embeds` = {inputs_embeds is not None}\n"
661
+ # f"=> `past_key_values` = {past_key_values is not None}\n"
662
+ # f"=> `use_cache` = {use_cache}"
663
+ # )
664
+
665
+ # # Unpack `language_model_output` and return PrismaticCausalLMOutputWithPast (or tuple if not `return_dict`)
666
+ # if not return_dict:
667
+ # if output_projector_features and (projected_patch_embeddings is not None):
668
+ # return *language_model_output, projected_patch_embeddings
669
+
670
+ # return language_model_output
671
+
672
+ # return PrismaticCausalLMOutputWithPast(
673
+ # loss=language_model_output.loss,
674
+ # logits=language_model_output.logits,
675
+ # past_key_values=language_model_output.past_key_values,
676
+ # hidden_states=language_model_output.hidden_states,
677
+ # attentions=language_model_output.attentions,
678
+ # projector_features=projected_patch_embeddings,
679
+ # )
680
+
681
+ # === GenerationMixin Methods ===
682
+ def prepare_inputs_for_generation(
683
+ self,
684
+ input_ids: Optional[torch.Tensor] = None,
685
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
686
+ inputs_embeds: Optional[torch.FloatTensor] = None,
687
+ pixel_values: Optional[torch.FloatTensor] = None,
688
+ attention_mask: Optional[torch.Tensor] = None,
689
+ **kwargs: str,
690
+ ) -> Dict[str, torch.Tensor]:
691
+ """Borrowed from `LlamaForCausalLM` and simplified for batch size = 1; mirrors original PrismaticVLM logic."""
692
+ if ((input_ids is not None) and (input_ids.shape[0] > 1)) or (
693
+ (inputs_embeds is not None) and (inputs_embeds.shape[0] > 1)
694
+ ):
695
+ raise ValueError("Generation with batch size > 1 is not currently supported!")
696
+
697
+ # Handle `past_key_values` (cache) =>> assume `input_ids` just has unprocessed tokens
698
+ if past_key_values is not None:
699
+ input_ids = input_ids[:, -1:]
700
+
701
+ # If `input_embeds` are passed, we only want to use them in the 1st generation step
702
+ if inputs_embeds is not None and past_key_values is None:
703
+ model_inputs = {"input_embeds": inputs_embeds}
704
+ else:
705
+ model_inputs = {"input_ids": input_ids}
706
+
707
+ # Make sure `pixel_values` are preserved in `model_inputs`
708
+ model_inputs.update(
709
+ {
710
+ "attention_mask": attention_mask,
711
+ "pixel_values": pixel_values,
712
+ "past_key_values": past_key_values,
713
+ "use_cache": kwargs.get("use_cache"),
714
+ }
715
+ )
716
+
717
+ return model_inputs
718
+
719
+ # Defer to Language Model (all handle this differently, with different return types)
720
+ def _reorder_cache(self, *args, **kwargs) -> Any:
721
+ return self.language_model._reorder_cache(*args, **kwargs)
722
+
723
+ def _prepare_input_for_action_prediction_verl(self, input_ids, attention_mask):
724
+ """Prepares input for action prediction by adding necessary tokens"""
725
+ # Add (ACTION_DIM * NUM_ACTIONS_CHUNK) placeholder tokens to input_ids to simulate action tokens
726
+ placeholder_action_token_ids = (
727
+ torch.ones((input_ids.shape[0], ACTION_DIM * NUM_ACTIONS_CHUNK)).to(input_ids.device).to(input_ids.dtype)
728
+ )
729
+ input_ids = torch.cat([input_ids, placeholder_action_token_ids], dim=-1)
730
+
731
+ # Add stop token to sequence (needed in non-causal bi-directional self-attention, as it appears at train time)
732
+ stop_token_id = torch.ones((input_ids.shape[0], 1)).to(input_ids.device).to(input_ids.dtype) * STOP_INDEX
733
+ input_ids = torch.cat([input_ids, stop_token_id], dim=-1)
734
+
735
+ # Extend the attention mask to fit the new shape of input
736
+ # Note: Only batch size == 1 supported right now
737
+ mask_extension = (
738
+ torch.ones((attention_mask.shape[0], input_ids.shape[-1] - attention_mask.shape[-1]))
739
+ .to(attention_mask.device)
740
+ .to(attention_mask.dtype)
741
+ )
742
+ attention_mask = torch.cat([attention_mask, mask_extension], dim=-1)
743
+
744
+ return input_ids, attention_mask
745
+
746
+ def _prepare_labels_for_action_prediction_verl(self, labels, input_ids):
747
+ """Creates labels tensor for action prediction if not provided"""
748
+ # Extend labels tensor with fake action labels
749
+ ARBITRARY_ACTION_TOKEN_IDX = ACTION_TOKEN_BEGIN_IDX + 1
750
+ labels_extension = (
751
+ torch.ones((labels.shape[0], input_ids.shape[-1] - labels.shape[-1])).to(labels.device).to(labels.dtype)
752
+ * ARBITRARY_ACTION_TOKEN_IDX
753
+ )
754
+ labels = torch.cat([labels, labels_extension], dim=-1)
755
+
756
+ # Replace last label token with stop token
757
+ labels[:, -1] = STOP_INDEX
758
+
759
+ return labels
760
+
761
+ def _verl_discrete_compute_logits(
762
+ self,
763
+ input_embeddings,
764
+ all_actions_mask,
765
+ projected_patch_embeddings,
766
+ attention_mask,
767
+ labels,
768
+ NUM_PATCHES,
769
+ NUM_PROMPT_TOKENS,
770
+ action_head=None,
771
+ ):#contintue!!!!!
772
+ """Run L1 regression-based continuous action prediction or discrete action tokens prediction."""
773
+ # Zero out action token embeddings
774
+ all_actions_mask = all_actions_mask.unsqueeze(-1) # (B, seq_len, 1)
775
+ input_embeddings = input_embeddings * ~all_actions_mask
776
+
777
+ # Build multimodal embeddings and attention mask
778
+ multimodal_embeddings, multimodal_attention_mask = self._build_multimodal_attention(
779
+ input_embeddings, projected_patch_embeddings, attention_mask
780
+ )
781
+
782
+ # Forward pass through language model
783
+ language_model_output = self.language_model(
784
+ input_ids=None,
785
+ attention_mask=multimodal_attention_mask,
786
+ position_ids=None,
787
+ past_key_values=None,
788
+ inputs_embeds=multimodal_embeddings,
789
+ labels=None,
790
+ use_cache=None,
791
+ output_attentions=False,
792
+ output_hidden_states=False,
793
+ return_dict=True,
794
+ )
795
+
796
+ # Extract hidden states for action tokens
797
+ #last_hidden_states = language_model_output.hidden_states[-1] # (B, seq_len, D)
798
+ # actions_hidden_states = last_hidden_states[
799
+ # :,
800
+ # NUM_PATCHES + NUM_PROMPT_TOKENS : NUM_PATCHES + NUM_PROMPT_TOKENS + ACTION_DIM * NUM_ACTIONS_CHUNK,
801
+ # :,
802
+ # ] # (B, act_chunk_len, D)
803
+
804
+ # Handle different prediction methods
805
+ # if action_head is not None:
806
+ # # L1 regression prediction
807
+ # normalized_actions = action_head.predict_action(actions_hidden_states)
808
+ # normalized_actions = normalized_actions.reshape(NUM_ACTIONS_CHUNK, ACTION_DIM)
809
+ # normalized_actions = normalized_actions.float().cpu().detach().numpy()
810
+ # else:
811
+ # Discrete token-based prediction
812
+
813
+ compute_logits = language_model_output.logits[
814
+ :,
815
+ NUM_PATCHES + NUM_PROMPT_TOKENS : NUM_PATCHES + NUM_PROMPT_TOKENS + ACTION_DIM * NUM_ACTIONS_CHUNK,
816
+ ]
817
+
818
+ return compute_logits
819
+
820
+ # def forward(
821
+ # self,
822
+ # input_ids: Optional[torch.LongTensor] = None,
823
+ # unnorm_key: Optional[str] = None,
824
+ # proprio=None,
825
+ # proprio_projector=None,
826
+ # action_head=None,
827
+ # noisy_action_projector=None,
828
+ # use_film: bool = False,
829
+ # **kwargs: str,
830
+ # ) :
831
+ # """Predict actions from input sequence, with options for different prediction methods.
832
+
833
+ # Args:
834
+ # input_ids: Input token ids
835
+ # unnorm_key: Key for unnormalization statistics
836
+ # proprio: Proprioceptive features
837
+ # proprio_projector: Projector for proprioceptive features
838
+ # action_head: Optional head for L1 regression or diffusion-based prediction
839
+ # noisy_action_projector: Projector for noisy actions in diffusion-based prediction
840
+ # use_film: Whether to use FiLM conditioning
841
+ # **kwargs: Additional arguments including pixel_values and attention_mask
842
+
843
+ # Returns:
844
+ # Tuple of (unnormalized_actions, action_hidden_states)
845
+ # """
846
+ # # If the special empty token ('') does not already appear after the colon (':') token in the prompt
847
+ # # (after "OUT:" or "ASSISTANT:"), insert it to match the inputs seen at training time
848
+ # # if not torch.all(input_ids[:, -1] == 29871):
849
+ # # input_ids = torch.cat(
850
+ # # (input_ids, torch.unsqueeze(torch.Tensor([29871]).long(), dim=0).to(input_ids.device)), dim=1
851
+ # # )
852
+ # #print("!!!!!!!!!!!!!!Entering forward!!!!!!!!!!")
853
+ # pixel_values = kwargs["pixel_values"]
854
+ # attention_mask = kwargs["attention_mask"]
855
+
856
+ # # Create fake labels tensor (needed for action mask)
857
+ # labels = input_ids.clone()
858
+ # labels[:] = IGNORE_INDEX
859
+
860
+ # # Get number of tokens in prompt (excluding the start token)
861
+ # NUM_PROMPT_TOKENS = input_ids.shape[-1] - 1 # Subtract action tokens and stop token
862
+
863
+ # # Prepare inputs by adding necessary tokens
864
+ # #input_ids, attention_mask = self._prepare_input_for_action_prediction_verl(input_ids, attention_mask)
865
+
866
+ # #test
867
+ # placeholder_action_token_ids = (
868
+ # torch.ones((input_ids.shape[0], ACTION_DIM * NUM_ACTIONS_CHUNK)).to(input_ids.device).to(input_ids.dtype)
869
+ # )
870
+ # input_ids = torch.cat([input_ids, placeholder_action_token_ids], dim=-1)
871
+
872
+ # # Add stop token to sequence (needed in non-causal bi-directional self-attention, as it appears at train time)
873
+ # stop_token_id = torch.ones((input_ids.shape[0], 1)).to(input_ids.device).to(input_ids.dtype) * STOP_INDEX
874
+ # input_ids = torch.cat([input_ids, stop_token_id], dim=-1)
875
+
876
+ # # Extend the attention mask to fit the new shape of input
877
+ # # Note: Only batch size == 1 supported right now
878
+ # mask_extension = (
879
+ # torch.ones((attention_mask.shape[0], input_ids.shape[-1] - attention_mask.shape[-1]))
880
+ # .to(attention_mask.device)
881
+ # .to(attention_mask.dtype)
882
+ # )
883
+ # attention_mask = torch.cat([attention_mask, mask_extension], dim=-1)
884
+
885
+ # #return input_ids, attention_mask
886
+
887
+ # #test end
888
+
889
+
890
+ # # Update labels tensor for action mask computation later
891
+ # #labels = self._prepare_labels_for_action_prediction_verl(labels, input_ids)
892
+ # #test
893
+
894
+ # ARBITRARY_ACTION_TOKEN_IDX = ACTION_TOKEN_BEGIN_IDX + 1
895
+ # labels_extension = (
896
+ # torch.ones((labels.shape[0], input_ids.shape[-1] - labels.shape[-1])).to(labels.device).to(labels.dtype)
897
+ # * ARBITRARY_ACTION_TOKEN_IDX
898
+ # )
899
+ # labels = torch.cat([labels, labels_extension], dim=-1)
900
+
901
+ # # Replace last label token with stop token
902
+ # labels[:, -1] = STOP_INDEX
903
+
904
+ # #return labels
905
+
906
+ # #test ed
907
+
908
+
909
+ # # Get input embeddings and action masks
910
+
911
+
912
+
913
+ # input_embeddings = self.get_input_embeddings()(input_ids)
914
+
915
+
916
+ # #all_actions_mask = self._process_action_masks(labels)
917
+ # #test
918
+ # #current_action_mask = get_current_action_mask(labels)
919
+ # newline_positions = labels != IGNORE_INDEX
920
+
921
+ # # Calculate cumulative sum to identify regions between newlines
922
+ # cumsum = torch.cumsum(newline_positions, dim=1)
923
+
924
+ # # Create the mask
925
+ # mask = (1 <= cumsum) & (cumsum <= ACTION_DIM)
926
+
927
+ # # Extract the action part only
928
+ # action_tokens_only_mask = labels > ACTION_TOKEN_BEGIN_IDX
929
+ # current_action_mask = action_tokens_only_mask * mask
930
+
931
+ # #next_actions_mask = get_next_actions_mask(labels)
932
+ # newline_positions = labels != IGNORE_INDEX
933
+
934
+ # # Calculate cumulative sum to identify regions between newlines
935
+ # cumsum = torch.cumsum(newline_positions, dim=1)
936
+
937
+ # # Create the mask
938
+ # mask = cumsum > ACTION_DIM
939
+
940
+ # # Extract the action part only
941
+ # action_tokens_only_mask = labels > ACTION_TOKEN_BEGIN_IDX
942
+ # next_actions_mask = action_tokens_only_mask * mask
943
+
944
+ # all_actions_mask = current_action_mask | next_actions_mask # (B, seq_len)
945
+
946
+ # #test end
947
+
948
+ # # Extract language embeddings
949
+ # language_embeddings = input_embeddings[~all_actions_mask].reshape(
950
+ # input_embeddings.shape[0], -1, input_embeddings.shape[2]
951
+ # )
952
+
953
+ # # Process vision features
954
+ # #projected_patch_embeddings = self._process_vision_features(pixel_values, language_embeddings, use_film)
955
+ # #test
956
+ # if use_film:
957
+ # # FiLM: Infuse language inputs into visual features
958
+ # raise ValueError
959
+ # patch_features = self.vision_backbone(pixel_values, language_embeddings) # (bsz, 256 * num_images, D)
960
+ # else:
961
+ # patch_features = self.vision_backbone(pixel_values) # (bsz, 256 * num_images, D)
962
+
963
+ # projected_patch_embeddings = self.projector(patch_features)
964
+ # #test end
965
+
966
+
967
+ # # Add proprioceptive features if provided
968
+ # use_proprio = proprio_projector is not None and proprio is not None
969
+ # if use_proprio:
970
+ # proprio = torch.Tensor(proprio).to(projected_patch_embeddings.device, dtype=projected_patch_embeddings.dtype)
971
+ # projected_patch_embeddings = self._process_proprio_features(
972
+ # projected_patch_embeddings, proprio, proprio_projector
973
+ # )
974
+
975
+ # # Use diffusion if provided, otherwise use regression or discrete prediction
976
+ # use_diffusion = noisy_action_projector is not None and hasattr(action_head, "noise_scheduler")
977
+
978
+ # # Calculate number of patches (including proprio token and/or diffusion timestep embedding if present)
979
+ # NUM_PATCHES = self.vision_backbone.get_num_patches() * self.vision_backbone.get_num_images_in_input()
980
+ # if use_proprio:
981
+ # NUM_PATCHES += 1
982
+ # if use_diffusion:
983
+ # NUM_PATCHES += 1
984
+
985
+ # if use_diffusion:
986
+ # raise ValueError
987
+ # # Sample random noise with shape equal to output action, used as the starting state for reverse diffusion
988
+ # noise = torch.randn(
989
+ # size=(1, NUM_ACTIONS_CHUNK, ACTION_DIM), device=input_embeddings.device, dtype=input_embeddings.dtype
990
+ # )
991
+
992
+ # # Run diffusion-based prediction
993
+ # normalized_actions, actions_hidden_states = self._run_diffusion_prediction(
994
+ # input_embeddings,
995
+ # all_actions_mask,
996
+ # noise,
997
+ # action_head,
998
+ # projected_patch_embeddings,
999
+ # labels,
1000
+ # attention_mask,
1001
+ # NUM_PATCHES,
1002
+ # NUM_PROMPT_TOKENS,
1003
+ # noisy_action_projector,
1004
+ # )
1005
+ # else:
1006
+ # # Run regression or discrete token-based prediction
1007
+ # # compute_logits = self._verl_discrete_compute_logits(
1008
+ # # input_embeddings,
1009
+ # # all_actions_mask,
1010
+ # # projected_patch_embeddings,
1011
+ # # attention_mask,
1012
+ # # labels,
1013
+ # # NUM_PATCHES,
1014
+ # # NUM_PROMPT_TOKENS,
1015
+ # # action_head,
1016
+ # # )
1017
+
1018
+ # #test
1019
+
1020
+ # all_actions_mask = all_actions_mask.unsqueeze(-1) # (B, seq_len, 1)
1021
+ # input_embeddings = input_embeddings * ~all_actions_mask
1022
+
1023
+ # # Build multimodal embeddings and attention mask
1024
+ # # multimodal_embeddings, multimodal_attention_mask = self._build_multimodal_attention(
1025
+ # # input_embeddings, projected_patch_embeddings, attention_mask
1026
+ # # )
1027
+ # #test
1028
+
1029
+ # projected_patch_attention_mask = None
1030
+ # if attention_mask is not None:
1031
+ # projected_patch_attention_mask = torch.full(
1032
+ # (projected_patch_embeddings.shape[0], projected_patch_embeddings.shape[1]),
1033
+ # fill_value=True,
1034
+ # dtype=attention_mask.dtype,
1035
+ # device=attention_mask.device,
1036
+ # )
1037
+
1038
+ # # Build multimodal embeddings & attention mask; insert embeddings after <BOS> token (1:)
1039
+ # multimodal_embeddings = torch.cat(
1040
+ # [input_embeddings[:, :1, :], projected_patch_embeddings, input_embeddings[:, 1:, :]], dim=1
1041
+ # )
1042
+
1043
+ # multimodal_attention_mask = None
1044
+ # if attention_mask is not None:
1045
+ # multimodal_attention_mask = torch.cat(
1046
+ # [attention_mask[:, :1], projected_patch_attention_mask, attention_mask[:, 1:]], dim=1
1047
+ # )
1048
+
1049
+ # #return multimodal_embeddings, multimodal_attention_mask
1050
+
1051
+ # #test end
1052
+
1053
+ # # Forward pass through language model
1054
+ # language_model_output = self.language_model(
1055
+ # input_ids=None,
1056
+ # attention_mask=multimodal_attention_mask,
1057
+ # position_ids=None,
1058
+ # past_key_values=None,
1059
+ # inputs_embeds=multimodal_embeddings,
1060
+ # labels=None,
1061
+ # use_cache=None,
1062
+ # output_attentions=False,
1063
+ # output_hidden_states=False,
1064
+ # return_dict=True,
1065
+ # )
1066
+
1067
+
1068
+ # compute_logits = language_model_output.logits[
1069
+ # :,
1070
+ # NUM_PATCHES + NUM_PROMPT_TOKENS : NUM_PATCHES + NUM_PROMPT_TOKENS + ACTION_DIM * NUM_ACTIONS_CHUNK,
1071
+ # ]
1072
+
1073
+ # #test end
1074
+
1075
+ # return compute_logits
1076
+
1077
+ def forward(
1078
+ self,
1079
+ input_ids: Optional[torch.LongTensor] = None,
1080
+ pixel_values=None,
1081
+ attention_mask=None,
1082
+ #labels=None,
1083
+ proprio=None,
1084
+ proprio_projector=None,
1085
+ action_head=None,
1086
+ noisy_action_projector=None,
1087
+ use_film: bool = False,
1088
+ **kwargs: str,
1089
+ ) :
1090
+ """Predict actions from input sequence, with options for different prediction methods.
1091
+
1092
+ Args:
1093
+ input_ids: Input token ids
1094
+ unnorm_key: Key for unnormalization statistics
1095
+ proprio: Proprioceptive features
1096
+ proprio_projector: Projector for proprioceptive features
1097
+ action_head: Optional head for L1 regression or diffusion-based prediction
1098
+ noisy_action_projector: Projector for noisy actions in diffusion-based prediction
1099
+ use_film: Whether to use FiLM conditioning
1100
+ **kwargs: Additional arguments including pixel_values and attention_mask
1101
+
1102
+ Returns:
1103
+ Tuple of (unnormalized_actions, action_hidden_states)
1104
+ """
1105
+ # If the special empty token ('') does not already appear after the colon (':') token in the prompt
1106
+ # (after "OUT:" or "ASSISTANT:"), insert it to match the inputs seen at training time
1107
+ # if not torch.all(input_ids[:, -1] == 29871):
1108
+ # input_ids = torch.cat(
1109
+ # (input_ids, torch.unsqueeze(torch.Tensor([29871]).long(), dim=0).to(input_ids.device)), dim=1
1110
+ # )
1111
+
1112
+ #pixel_values = kwargs["pixel_values"]
1113
+ #attention_mask = kwargs["attention_mask"]
1114
+
1115
+ # Create fake labels tensor (needed for action mask)
1116
+ labels = input_ids.clone()
1117
+ labels[:] = IGNORE_INDEX
1118
+
1119
+ # # Get number of tokens in prompt (excluding the start token)
1120
+ NUM_PROMPT_TOKENS = input_ids.shape[-1] - 1 # Subtract action tokens and stop token
1121
+
1122
+
1123
+ # # Prepare inputs by adding necessary tokens
1124
+ # #input_ids, attention_mask = self._prepare_input_for_action_prediction_verl(input_ids, attention_mask)
1125
+
1126
+ # #test
1127
+ placeholder_action_token_ids = (
1128
+ torch.ones((input_ids.shape[0], ACTION_DIM * NUM_ACTIONS_CHUNK)).to(input_ids.device).to(input_ids.dtype)
1129
+ )
1130
+ input_ids = torch.cat([input_ids, placeholder_action_token_ids], dim=-1)
1131
+
1132
+ # Add stop token to sequence (needed in non-causal bi-directional self-attention, as it appears at train time)
1133
+ stop_token_id = torch.ones((input_ids.shape[0], 1)).to(input_ids.device).to(input_ids.dtype) * STOP_INDEX
1134
+ input_ids = torch.cat([input_ids, stop_token_id], dim=-1)
1135
+
1136
+ # Extend the attention mask to fit the new shape of input
1137
+ # Note: Only batch size == 1 supported right now
1138
+ mask_extension = (
1139
+ torch.ones((attention_mask.shape[0], input_ids.shape[-1] - attention_mask.shape[-1]))
1140
+ .to(attention_mask.device)
1141
+ .to(attention_mask.dtype)
1142
+ )
1143
+ attention_mask = torch.cat([attention_mask, mask_extension], dim=-1)
1144
+
1145
+ ARBITRARY_ACTION_TOKEN_IDX = ACTION_TOKEN_BEGIN_IDX + 1
1146
+ labels_extension = (
1147
+ torch.ones((labels.shape[0], input_ids.shape[-1] - labels.shape[-1])).to(labels.device).to(labels.dtype)
1148
+ * ARBITRARY_ACTION_TOKEN_IDX
1149
+ )
1150
+ labels = torch.cat([labels, labels_extension], dim=-1)
1151
+
1152
+ # # Replace last label token with stop token
1153
+ labels[:, -1] = STOP_INDEX
1154
+
1155
+
1156
+ # Get input embeddings and action masks
1157
+
1158
+ #NUM_PROMPT_TOKENS = kwargs["num_prompt_tokens"]
1159
+
1160
+ input_embeddings = self.get_input_embeddings()(input_ids)
1161
+
1162
+
1163
+ #all_actions_mask = self._process_action_masks(labels)
1164
+ #test
1165
+ #current_action_mask = get_current_action_mask(labels)
1166
+ newline_positions = labels != IGNORE_INDEX
1167
+
1168
+ # Calculate cumulative sum to identify regions between newlines
1169
+ cumsum = torch.cumsum(newline_positions, dim=1)
1170
+
1171
+ # Create the mask
1172
+ mask = (1 <= cumsum) & (cumsum <= ACTION_DIM)
1173
+
1174
+ # Extract the action part only
1175
+ action_tokens_only_mask = labels > ACTION_TOKEN_BEGIN_IDX
1176
+ current_action_mask = action_tokens_only_mask * mask
1177
+
1178
+ #next_actions_mask = get_next_actions_mask(labels)
1179
+ newline_positions = labels != IGNORE_INDEX
1180
+
1181
+ # Calculate cumulative sum to identify regions between newlines
1182
+ cumsum = torch.cumsum(newline_positions, dim=1)
1183
+
1184
+ # Create the mask
1185
+ mask = cumsum > ACTION_DIM
1186
+
1187
+ # Extract the action part only
1188
+ action_tokens_only_mask = labels > ACTION_TOKEN_BEGIN_IDX
1189
+ next_actions_mask = action_tokens_only_mask * mask
1190
+
1191
+ all_actions_mask = current_action_mask | next_actions_mask # (B, seq_len)
1192
+
1193
+ #test end
1194
+
1195
+ # Extract language embeddings
1196
+ language_embeddings = input_embeddings[~all_actions_mask].reshape(
1197
+ input_embeddings.shape[0], -1, input_embeddings.shape[2]
1198
+ )
1199
+
1200
+ # Process vision features
1201
+ #projected_patch_embeddings = self._process_vision_features(pixel_values, language_embeddings, use_film)
1202
+ #test
1203
+ if use_film:
1204
+ # FiLM: Infuse language inputs into visual features
1205
+ raise ValueError
1206
+ patch_features = self.vision_backbone(pixel_values, language_embeddings) # (bsz, 256 * num_images, D)
1207
+ else:
1208
+ patch_features = self.vision_backbone(pixel_values) # (bsz, 256 * num_images, D)
1209
+
1210
+ projected_patch_embeddings = self.projector(patch_features)
1211
+ #test end
1212
+
1213
+
1214
+ # Add proprioceptive features if provided
1215
+ use_proprio = proprio_projector is not None and proprio is not None
1216
+ if use_proprio:
1217
+ proprio = torch.Tensor(proprio).to(projected_patch_embeddings.device, dtype=projected_patch_embeddings.dtype)
1218
+ projected_patch_embeddings = self._process_proprio_features(
1219
+ projected_patch_embeddings, proprio, proprio_projector
1220
+ )
1221
+
1222
+ # Use diffusion if provided, otherwise use regression or discrete prediction
1223
+ use_diffusion = noisy_action_projector is not None and hasattr(action_head, "noise_scheduler")
1224
+
1225
+ # Calculate number of patches (including proprio token and/or diffusion timestep embedding if present)
1226
+ NUM_PATCHES = self.vision_backbone.get_num_patches() * self.vision_backbone.get_num_images_in_input()
1227
+ if use_proprio:
1228
+ NUM_PATCHES += 1
1229
+ if use_diffusion:
1230
+ NUM_PATCHES += 1
1231
+
1232
+ if use_diffusion:
1233
+ raise ValueError
1234
+ # Sample random noise with shape equal to output action, used as the starting state for reverse diffusion
1235
+ noise = torch.randn(
1236
+ size=(1, NUM_ACTIONS_CHUNK, ACTION_DIM), device=input_embeddings.device, dtype=input_embeddings.dtype
1237
+ )
1238
+
1239
+ # Run diffusion-based prediction
1240
+ normalized_actions, actions_hidden_states = self._run_diffusion_prediction(
1241
+ input_embeddings,
1242
+ all_actions_mask,
1243
+ noise,
1244
+ action_head,
1245
+ projected_patch_embeddings,
1246
+ labels,
1247
+ attention_mask,
1248
+ NUM_PATCHES,
1249
+ NUM_PROMPT_TOKENS,
1250
+ noisy_action_projector,
1251
+ )
1252
+ else:
1253
+ # Run regression or discrete token-based prediction
1254
+ # compute_logits = self._verl_discrete_compute_logits(
1255
+ # input_embeddings,
1256
+ # all_actions_mask,
1257
+ # projected_patch_embeddings,
1258
+ # attention_mask,
1259
+ # labels,
1260
+ # NUM_PATCHES,
1261
+ # NUM_PROMPT_TOKENS,
1262
+ # action_head,
1263
+ # )
1264
+
1265
+ #test
1266
+
1267
+ all_actions_mask = all_actions_mask.unsqueeze(-1) # (B, seq_len, 1)
1268
+ input_embeddings = input_embeddings * ~all_actions_mask
1269
+
1270
+ # Build multimodal embeddings and attention mask
1271
+ # multimodal_embeddings, multimodal_attention_mask = self._build_multimodal_attention(
1272
+ # input_embeddings, projected_patch_embeddings, attention_mask
1273
+ # )
1274
+ #test
1275
+
1276
+ projected_patch_attention_mask = None
1277
+ if attention_mask is not None:
1278
+ projected_patch_attention_mask = torch.full(
1279
+ (projected_patch_embeddings.shape[0], projected_patch_embeddings.shape[1]),
1280
+ fill_value=True,
1281
+ dtype=attention_mask.dtype,
1282
+ device=attention_mask.device,
1283
+ )
1284
+
1285
+ # Build multimodal embeddings & attention mask; insert embeddings after <BOS> token (1:)
1286
+ multimodal_embeddings = torch.cat(
1287
+ [input_embeddings[:, :1, :], projected_patch_embeddings, input_embeddings[:, 1:, :]], dim=1
1288
+ )
1289
+
1290
+ multimodal_attention_mask = None
1291
+ if attention_mask is not None:
1292
+ multimodal_attention_mask = torch.cat(
1293
+ [attention_mask[:, :1], projected_patch_attention_mask, attention_mask[:, 1:]], dim=1
1294
+ )
1295
+
1296
+ #return multimodal_embeddings, multimodal_attention_mask
1297
+
1298
+ #test end
1299
+
1300
+ # Forward pass through language model
1301
+ language_model_output = self.language_model(
1302
+ input_ids=None,
1303
+ attention_mask=multimodal_attention_mask,
1304
+ position_ids=None,
1305
+ past_key_values=None,
1306
+ inputs_embeds=multimodal_embeddings,
1307
+ labels=None,
1308
+ use_cache=None,
1309
+ output_attentions=False,
1310
+ output_hidden_states=False,
1311
+ return_dict=True,
1312
+ )
1313
+
1314
+
1315
+ compute_logits = language_model_output.logits[
1316
+ :,
1317
+ NUM_PATCHES + NUM_PROMPT_TOKENS : NUM_PATCHES + NUM_PROMPT_TOKENS + ACTION_DIM * NUM_ACTIONS_CHUNK,
1318
+ ]
1319
+
1320
+ #test end
1321
+
1322
+ return compute_logits
1323
+
1324
+
1325
+
1326
+ class OpenVLAForActionPrediction(PrismaticForConditionalGeneration):
1327
+ config_class: PretrainedConfig = OpenVLAConfig
1328
+
1329
+ def __init__(self, config: OpenVLAConfig) -> None:
1330
+ super().__init__(config)
1331
+ self.norm_stats = config.norm_stats
1332
+
1333
+ # Compute action bins
1334
+ self.bins = np.linspace(-1, 1, config.n_action_bins)
1335
+ self.bin_centers = (self.bins[:-1] + self.bins[1:]) / 2.0
1336
+
1337
+ # Compute vocab size for de-tokenization -- revert added "multiple of"
1338
+ self.vocab_size = self.config.text_config.vocab_size - self.config.pad_to_multiple_of
1339
+
1340
+ def _prepare_input_for_action_prediction(self, input_ids, attention_mask):
1341
+ """Prepares input for action prediction by adding necessary tokens"""
1342
+ # Add (ACTION_DIM * NUM_ACTIONS_CHUNK) placeholder tokens to input_ids to simulate action tokens
1343
+ placeholder_action_token_ids = (
1344
+ torch.ones((input_ids.shape[0], ACTION_DIM * NUM_ACTIONS_CHUNK)).to(input_ids.device).to(input_ids.dtype)
1345
+ )
1346
+ input_ids = torch.cat([input_ids, placeholder_action_token_ids], dim=-1)
1347
+
1348
+ # Add stop token to sequence (needed in non-causal bi-directional self-attention, as it appears at train time)
1349
+ stop_token_id = torch.ones((input_ids.shape[0], 1)).to(input_ids.device).to(input_ids.dtype) * STOP_INDEX
1350
+ input_ids = torch.cat([input_ids, stop_token_id], dim=-1)
1351
+
1352
+ # Extend the attention mask to fit the new shape of input
1353
+ # Note: Only batch size == 1 supported right now
1354
+ mask_extension = (
1355
+ torch.ones((attention_mask.shape[0], input_ids.shape[-1] - attention_mask.shape[-1]))
1356
+ .to(attention_mask.device)
1357
+ .to(attention_mask.dtype)
1358
+ )
1359
+ attention_mask = torch.cat([attention_mask, mask_extension], dim=-1)
1360
+
1361
+ return input_ids, attention_mask
1362
+
1363
+ def _prepare_labels_for_action_prediction(self, labels, input_ids):
1364
+ """Creates labels tensor for action prediction if not provided"""
1365
+ # Extend labels tensor with fake action labels
1366
+ ARBITRARY_ACTION_TOKEN_IDX = ACTION_TOKEN_BEGIN_IDX + 1
1367
+ labels_extension = (
1368
+ torch.ones((labels.shape[0], input_ids.shape[-1] - labels.shape[-1])).to(labels.device).to(labels.dtype)
1369
+ * ARBITRARY_ACTION_TOKEN_IDX
1370
+ )
1371
+ labels = torch.cat([labels, labels_extension], dim=-1)
1372
+
1373
+ # Replace last label token with stop token
1374
+ labels[:, -1] = STOP_INDEX
1375
+
1376
+ return labels
1377
+
1378
+ def _unnormalize_actions(self, normalized_actions, unnorm_key=None):
1379
+ """Unnormalize actions using dataset statistics"""
1380
+ action_norm_stats = self.get_action_stats(unnorm_key)
1381
+
1382
+ if ACTION_PROPRIO_NORMALIZATION_TYPE == NormalizationType.BOUNDS:
1383
+ mask = action_norm_stats.get("mask", np.ones_like(action_norm_stats["min"], dtype=bool))
1384
+ action_high, action_low = np.array(action_norm_stats["max"]), np.array(action_norm_stats["min"])
1385
+ elif ACTION_PROPRIO_NORMALIZATION_TYPE == NormalizationType.BOUNDS_Q99:
1386
+ mask = action_norm_stats.get("mask", np.ones_like(action_norm_stats["q01"], dtype=bool))
1387
+ action_high, action_low = np.array(action_norm_stats["q99"]), np.array(action_norm_stats["q01"])
1388
+ else:
1389
+ raise ValueError("Unsupported action/proprio normalization type detected!")
1390
+
1391
+ actions = np.where(
1392
+ mask,
1393
+ 0.5 * (normalized_actions + 1) * (action_high - action_low + 1e-8) + action_low,
1394
+ normalized_actions,
1395
+ )
1396
+
1397
+ return actions
1398
+
1399
+ def _run_diffusion_prediction(
1400
+ self,
1401
+ input_embeddings,
1402
+ all_actions_mask,
1403
+ noise,
1404
+ action_head,
1405
+ projected_patch_embeddings,
1406
+ labels,
1407
+ attention_mask,
1408
+ NUM_PATCHES,
1409
+ NUM_PROMPT_TOKENS,
1410
+ noisy_action_projector,
1411
+ ):
1412
+ """Run diffusion-based action prediction"""
1413
+ # Set diffusion timestep values
1414
+ action_head.noise_scheduler.set_timesteps(action_head.num_diffusion_steps)
1415
+ # Clone embedding for reuse in each timestep
1416
+ orig_projected_patch_embeddings = projected_patch_embeddings.clone()
1417
+ curr_noisy_actions = noise
1418
+
1419
+ # Reverse diffusion: Iteratively denoise to generate action prediction
1420
+ for t in action_head.noise_scheduler.timesteps:
1421
+ # Get diffusion model's noise prediction (conditioned on VLA latent embedding, current noisy action
1422
+ # embedding, and diffusion timestep embedding)
1423
+ timesteps = torch.Tensor([t]).to(labels.device)
1424
+ diffusion_timestep_embeddings = (
1425
+ action_head.time_encoder(timesteps).to(curr_noisy_actions.dtype).to(curr_noisy_actions.device)
1426
+ ) # (B, llm_dim)
1427
+ diffusion_timestep_embeddings = diffusion_timestep_embeddings.unsqueeze(1) # (B, 1, llm_dim)
1428
+
1429
+ # [Diffusion] Replace the embeddings of the action tokens with noisy actions
1430
+ # (Later on, the positional embeddings will be added to them)
1431
+
1432
+ # For simplicity, append diffusion timestep embedding to the end of projected vision tokens
1433
+ projected_patch_embeddings = torch.cat(
1434
+ (orig_projected_patch_embeddings, diffusion_timestep_embeddings), dim=1
1435
+ )
1436
+
1437
+ # Reshape and project noisy actions into language embedding space
1438
+ B = curr_noisy_actions.shape[0]
1439
+ orig_curr_noisy_actions_shape = curr_noisy_actions.shape
1440
+ curr_noisy_actions = curr_noisy_actions.reshape(B, -1).unsqueeze(-1)
1441
+ noisy_action_features = noisy_action_projector(curr_noisy_actions)
1442
+ curr_noisy_actions = curr_noisy_actions.reshape(orig_curr_noisy_actions_shape)
1443
+
1444
+ # Replace action token embeddings with noisy action embeddings
1445
+ input_embeddings = self._replace_input_embeddings(
1446
+ input_embeddings.clone(), all_actions_mask, noisy_action_features
1447
+ )
1448
+
1449
+ # Build multimodal embeddings and attention mask
1450
+ multimodal_embeddings, multimodal_attention_mask = self._build_multimodal_attention(
1451
+ input_embeddings, projected_patch_embeddings, attention_mask
1452
+ )
1453
+
1454
+ # Forward pass through language model
1455
+ language_model_output = self.language_model(
1456
+ input_ids=None,
1457
+ attention_mask=multimodal_attention_mask,
1458
+ position_ids=None,
1459
+ past_key_values=None,
1460
+ inputs_embeds=multimodal_embeddings,
1461
+ labels=None,
1462
+ use_cache=None,
1463
+ output_attentions=False,
1464
+ output_hidden_states=True,
1465
+ return_dict=True,
1466
+ )
1467
+
1468
+ # Extract hidden states for action portion of response
1469
+ last_hidden_states = language_model_output.hidden_states[-1] # (B, seq_len, D)
1470
+ actions_hidden_states = last_hidden_states[
1471
+ :,
1472
+ NUM_PATCHES + NUM_PROMPT_TOKENS : NUM_PATCHES + NUM_PROMPT_TOKENS + ACTION_DIM * NUM_ACTIONS_CHUNK,
1473
+ :,
1474
+ ] # (B, act_chunk_len, D)
1475
+
1476
+ # Predict noise and update noisy actions: x_t -> x_{t-1}
1477
+ noise_pred = action_head.predict_noise(actions_hidden_states)
1478
+ curr_noisy_actions = action_head.noise_scheduler.step(noise_pred, t, curr_noisy_actions).prev_sample
1479
+
1480
+ curr_noisy_actions = curr_noisy_actions.reshape(NUM_ACTIONS_CHUNK, ACTION_DIM)
1481
+
1482
+ # Return final actions
1483
+ return curr_noisy_actions.float().cpu().detach().numpy(), actions_hidden_states
1484
+
1485
+ def _regression_or_discrete_prediction(
1486
+ self,
1487
+ input_embeddings,
1488
+ all_actions_mask,
1489
+ projected_patch_embeddings,
1490
+ attention_mask,
1491
+ labels,
1492
+ NUM_PATCHES,
1493
+ NUM_PROMPT_TOKENS,
1494
+ action_head=None,
1495
+ ):
1496
+ """Run L1 regression-based continuous action prediction or discrete action tokens prediction."""
1497
+ # Zero out action token embeddings
1498
+ all_actions_mask = all_actions_mask.unsqueeze(-1) # (B, seq_len, 1)
1499
+ input_embeddings = input_embeddings * ~all_actions_mask
1500
+
1501
+ # Build multimodal embeddings and attention mask
1502
+ multimodal_embeddings, multimodal_attention_mask = self._build_multimodal_attention(
1503
+ input_embeddings, projected_patch_embeddings, attention_mask
1504
+ )
1505
+
1506
+ # Forward pass through language model
1507
+ language_model_output = self.language_model(
1508
+ input_ids=None,
1509
+ attention_mask=multimodal_attention_mask,
1510
+ position_ids=None,
1511
+ past_key_values=None,
1512
+ inputs_embeds=multimodal_embeddings,
1513
+ labels=None,
1514
+ use_cache=None,
1515
+ output_attentions=False,
1516
+ output_hidden_states=True,
1517
+ return_dict=True,
1518
+ )
1519
+
1520
+ # Extract hidden states for action tokens
1521
+ last_hidden_states = language_model_output.hidden_states[-1] # (B, seq_len, D)
1522
+ actions_hidden_states = last_hidden_states[
1523
+ :,
1524
+ NUM_PATCHES + NUM_PROMPT_TOKENS : NUM_PATCHES + NUM_PROMPT_TOKENS + ACTION_DIM * NUM_ACTIONS_CHUNK,
1525
+ :,
1526
+ ] # (B, act_chunk_len, D)
1527
+
1528
+ # Handle different prediction methods
1529
+ if action_head is not None:
1530
+ # L1 regression prediction
1531
+ normalized_actions = action_head.predict_action(actions_hidden_states)
1532
+ normalized_actions = normalized_actions.reshape(NUM_ACTIONS_CHUNK, ACTION_DIM)
1533
+ normalized_actions = normalized_actions.float().cpu().detach().numpy()
1534
+ else:
1535
+ # Discrete token-based prediction
1536
+ predicted_action_token_ids = (
1537
+ language_model_output.logits[
1538
+ :,
1539
+ NUM_PATCHES + NUM_PROMPT_TOKENS : NUM_PATCHES + NUM_PROMPT_TOKENS + ACTION_DIM * NUM_ACTIONS_CHUNK,
1540
+ ]
1541
+ .argmax(dim=2)
1542
+ .cpu()
1543
+ .numpy()
1544
+ )
1545
+ discretized_actions = self.vocab_size - predicted_action_token_ids
1546
+ discretized_actions = np.clip(discretized_actions - 1, a_min=0, a_max=self.bin_centers.shape[0] - 1)
1547
+ normalized_actions = self.bin_centers[discretized_actions]
1548
+ normalized_actions = normalized_actions.reshape(NUM_ACTIONS_CHUNK, ACTION_DIM)
1549
+
1550
+ return normalized_actions, actions_hidden_states
1551
+
1552
+ def _verl_discrete_prediction(
1553
+ self,
1554
+ input_embeddings,
1555
+ all_actions_mask,
1556
+ projected_patch_embeddings,
1557
+ attention_mask,
1558
+ labels,
1559
+ NUM_PATCHES,
1560
+ NUM_PROMPT_TOKENS,
1561
+ action_head=None,
1562
+ do_sample=True,
1563
+ temperature=1,
1564
+ ):
1565
+ """Run L1 regression-based continuous action prediction or discrete action tokens prediction."""
1566
+ # Zero out action token embeddings
1567
+ all_actions_mask = all_actions_mask.unsqueeze(-1) # (B, seq_len, 1)
1568
+ input_embeddings = input_embeddings * ~all_actions_mask
1569
+
1570
+ # Build multimodal embeddings and attention mask
1571
+ multimodal_embeddings, multimodal_attention_mask = self._build_multimodal_attention(
1572
+ input_embeddings, projected_patch_embeddings, attention_mask
1573
+ )
1574
+
1575
+ # Forward pass through language model
1576
+ language_model_output = self.language_model(
1577
+ input_ids=None,
1578
+ attention_mask=multimodal_attention_mask,
1579
+ position_ids=None,
1580
+ past_key_values=None,
1581
+ inputs_embeds=multimodal_embeddings,
1582
+ labels=None,
1583
+ use_cache=None,
1584
+ output_attentions=False,
1585
+ output_hidden_states=False,
1586
+ return_dict=True,
1587
+ )
1588
+
1589
+ # Extract hidden states for action tokens
1590
+ #last_hidden_states = language_model_output.hidden_states[-1] # (B, seq_len, D)
1591
+ # actions_hidden_states = last_hidden_states[
1592
+ # :,
1593
+ # NUM_PATCHES + NUM_PROMPT_TOKENS : NUM_PATCHES + NUM_PROMPT_TOKENS + ACTION_DIM * NUM_ACTIONS_CHUNK,
1594
+ # :,
1595
+ # ] # (B, act_chunk_len, D)
1596
+
1597
+ # Handle different prediction methods
1598
+ # if action_head is not None:
1599
+ # # L1 regression prediction
1600
+ # normalized_actions = action_head.predict_action(actions_hidden_states)
1601
+ # normalized_actions = normalized_actions.reshape(NUM_ACTIONS_CHUNK, ACTION_DIM)
1602
+ # normalized_actions = normalized_actions.float().cpu().detach().numpy()
1603
+ # else:
1604
+ # Discrete token-based prediction
1605
+
1606
+ #test
1607
+ # NUM_PROMPT_TOKENS = NUM_PROMPT_TOKENS + NUM_PATCHES
1608
+ # j = torch.arange(language_model_output.logits.shape[1], device=NUM_PROMPT_TOKENS.device)
1609
+ # start = NUM_PROMPT_TOKENS.unsqueeze(1)
1610
+ # end = start + ACTION_DIM * NUM_ACTIONS_CHUNK
1611
+ # mask_2d = (j >= start) & (j < end)
1612
+ # mask = mask_2d.unsqueeze(-1)
1613
+ # actions_masks = mask.expand_as(language_model_output.logits)
1614
+
1615
+
1616
+ NUM_PROMPT_TOKENS = NUM_PROMPT_TOKENS + NUM_PATCHES
1617
+ batch_size = language_model_output.logits.shape[0]
1618
+ device = language_model_output.logits.device
1619
+
1620
+
1621
+ start_indices = NUM_PROMPT_TOKENS.unsqueeze(1) # [batch_size, 1]
1622
+ position_offsets = torch.arange(ACTION_DIM * NUM_ACTIONS_CHUNK, device=device).unsqueeze(0) # [1, seq_length]
1623
+ seq_indices = start_indices + position_offsets # [batch_size, ACTION_DIM*NUM_ACTIONS_CHUNK]
1624
+ #test end
1625
+ #test add
1626
+ #print("language_model_output",language_model_output.logits.shape[-1])
1627
+ #print("self.vocab_size",self.vocab_size) 32000
1628
+ #topk_values, topk_indices = torch.topk(language_model_output.logits, k=256, dim=-1)
1629
+ #print(topk_indices)
1630
+ #assert language_model_output.logits.shape[-1] == self.vocab_size
1631
+ #test add
1632
+ if do_sample == False:
1633
+ #org
1634
+ # reponse_ids = language_model_output.logits[
1635
+ # :,
1636
+ # NUM_PATCHES + NUM_PROMPT_TOKENS : NUM_PATCHES + NUM_PROMPT_TOKENS + ACTION_DIM * NUM_ACTIONS_CHUNK,
1637
+ # ].argmax(dim=2)
1638
+ #reponse_ids = language_model_output.logits[actions_masks].argmax(dim=2)
1639
+ #org end
1640
+
1641
+ #padding
1642
+ # reponse_ids = language_model_output.logits[
1643
+ # torch.arange(batch_size, device=device).unsqueeze(-1),
1644
+ # seq_indices,
1645
+ # :
1646
+ # ].argmax(dim=2)
1647
+ #padding end
1648
+
1649
+ #padding + only get last 256 token
1650
+ reponse_ids_logits = language_model_output.logits[
1651
+ torch.arange(batch_size, device=device).unsqueeze(-1),
1652
+ seq_indices,
1653
+ :
1654
+ ]
1655
+ start_index = self.vocab_size - 256
1656
+ response_last256 = reponse_ids_logits[..., -256-64:-64] # Shape: [batch_size, seq_len, 256]
1657
+ last256_argmax = response_last256.argmax(dim=-1) # Shape: [batch_size, seq_len]
1658
+ reponse_ids = last256_argmax + start_index # Shape: [batch_size, seq_len]
1659
+ #padding + only get last 256 token end
1660
+
1661
+ predicted_action_token_ids = reponse_ids.cpu().numpy()
1662
+
1663
+ else:
1664
+ assert temperature>0
1665
+ #org
1666
+ # action_logits = language_model_output.logits[
1667
+ # :,
1668
+ # NUM_PATCHES + NUM_PROMPT_TOKENS : NUM_PATCHES + NUM_PROMPT_TOKENS + ACTION_DIM * NUM_ACTIONS_CHUNK,
1669
+ # ]
1670
+ #action_logits = language_model_output.logits[actions_masks]
1671
+ #org end
1672
+
1673
+ action_logits = language_model_output.logits[
1674
+ torch.arange(batch_size, device=device).unsqueeze(-1),
1675
+ seq_indices,
1676
+ :
1677
+ ]
1678
+ # padding
1679
+ # scaled_logits = action_logits / temperature
1680
+ # probs = torch.softmax(scaled_logits, dim=-1)
1681
+ # probs_flat = probs.reshape(-1, probs.shape[-1]) # (B*act_chunk_len, vocab_size)
1682
+ # sampled_indices_flat = torch.multinomial(probs_flat, num_samples=1) # (B*act_chunk_len, 1)
1683
+ # reponse_ids = sampled_indices_flat.view(action_logits.shape[0], -1)
1684
+ # padding end
1685
+
1686
+ #padding + only get last 256 token
1687
+ action_logits_last256 = action_logits[..., -256-64:-64]
1688
+ scaled_logits = action_logits_last256 / temperature
1689
+ probs = torch.softmax(scaled_logits, dim=-1)
1690
+ assert probs.shape[-1] == 256
1691
+ probs_flat = probs.reshape(-1, probs.shape[-1])
1692
+ sampled_indices_flat = torch.multinomial(probs_flat, num_samples=1)
1693
+ original_ids_flat = sampled_indices_flat + (self.vocab_size - 256)
1694
+ reponse_ids = original_ids_flat.view(action_logits.shape[0], -1)
1695
+ #padding + only get last 256 token end
1696
+
1697
+ predicted_action_token_ids = reponse_ids.cpu().numpy()
1698
+
1699
+ discretized_actions = self.vocab_size - predicted_action_token_ids
1700
+ discretized_actions = np.clip(discretized_actions - 1, a_min=0, a_max=self.bin_centers.shape[0] - 1)
1701
+ normalized_actions = self.bin_centers[discretized_actions]
1702
+ #normalized_actions = normalized_actions.reshape(NUM_ACTIONS_CHUNK, ACTION_DIM)
1703
+ normalized_actions = normalized_actions.reshape(-1, ACTION_DIM)
1704
+
1705
+ return normalized_actions, reponse_ids
1706
+ #return normalized_actions, actions_hidden_states
1707
+
1708
+
1709
+
1710
+
1711
+ def predict_action(
1712
+ self,
1713
+ input_ids: Optional[torch.LongTensor] = None,
1714
+ unnorm_key: Optional[str] = None,
1715
+ proprio=None,
1716
+ proprio_projector=None,
1717
+ action_head=None,
1718
+ noisy_action_projector=None,
1719
+ use_film: bool = False,
1720
+ **kwargs: str,
1721
+ ) -> np.ndarray:
1722
+ """Predict actions from input sequence, with options for different prediction methods.
1723
+
1724
+ Args:
1725
+ input_ids: Input token ids
1726
+ unnorm_key: Key for unnormalization statistics
1727
+ proprio: Proprioceptive features
1728
+ proprio_projector: Projector for proprioceptive features
1729
+ action_head: Optional head for L1 regression or diffusion-based prediction
1730
+ noisy_action_projector: Projector for noisy actions in diffusion-based prediction
1731
+ use_film: Whether to use FiLM conditioning
1732
+ **kwargs: Additional arguments including pixel_values and attention_mask
1733
+
1734
+ Returns:
1735
+ Tuple of (unnormalized_actions, action_hidden_states)
1736
+ """
1737
+ # If the special empty token ('') does not already appear after the colon (':') token in the prompt
1738
+ # (after "OUT:" or "ASSISTANT:"), insert it to match the inputs seen at training time
1739
+ if not torch.all(input_ids[:, -1] == 29871):
1740
+ input_ids = torch.cat(
1741
+ (input_ids, torch.unsqueeze(torch.Tensor([29871]).long(), dim=0).to(input_ids.device)), dim=1
1742
+ )
1743
+
1744
+ pixel_values = kwargs["pixel_values"]
1745
+ attention_mask = kwargs["attention_mask"]
1746
+
1747
+ # Create fake labels tensor (needed for action mask)
1748
+ labels = input_ids.clone()
1749
+ labels[:] = IGNORE_INDEX
1750
+
1751
+ # Get number of tokens in prompt (excluding the start token)
1752
+ NUM_PROMPT_TOKENS = input_ids.shape[-1] - 1 # Subtract action tokens and stop token
1753
+
1754
+ # Prepare inputs by adding necessary tokens
1755
+ input_ids, attention_mask = self._prepare_input_for_action_prediction(input_ids, attention_mask)
1756
+
1757
+ # Update labels tensor for action mask computation later
1758
+ labels = self._prepare_labels_for_action_prediction(labels, input_ids)
1759
+
1760
+ # Get input embeddings and action masks
1761
+ input_embeddings = self.get_input_embeddings()(input_ids)
1762
+ all_actions_mask = self._process_action_masks(labels)
1763
+
1764
+ # Extract language embeddings
1765
+ language_embeddings = input_embeddings[~all_actions_mask].reshape(
1766
+ input_embeddings.shape[0], -1, input_embeddings.shape[2]
1767
+ )
1768
+
1769
+ # Process vision features
1770
+ projected_patch_embeddings = self._process_vision_features(pixel_values, language_embeddings, use_film)
1771
+
1772
+ # Add proprioceptive features if provided
1773
+ use_proprio = proprio_projector is not None and proprio is not None
1774
+ if use_proprio:
1775
+ proprio = torch.Tensor(proprio).to(projected_patch_embeddings.device, dtype=projected_patch_embeddings.dtype)
1776
+ projected_patch_embeddings = self._process_proprio_features(
1777
+ projected_patch_embeddings, proprio, proprio_projector
1778
+ )
1779
+
1780
+ # Use diffusion if provided, otherwise use regression or discrete prediction
1781
+ use_diffusion = noisy_action_projector is not None and hasattr(action_head, "noise_scheduler")
1782
+
1783
+ # Calculate number of patches (including proprio token and/or diffusion timestep embedding if present)
1784
+ NUM_PATCHES = self.vision_backbone.get_num_patches() * self.vision_backbone.get_num_images_in_input()
1785
+ if use_proprio:
1786
+ NUM_PATCHES += 1
1787
+ if use_diffusion:
1788
+ NUM_PATCHES += 1
1789
+
1790
+ if use_diffusion:
1791
+ # Sample random noise with shape equal to output action, used as the starting state for reverse diffusion
1792
+ noise = torch.randn(
1793
+ size=(1, NUM_ACTIONS_CHUNK, ACTION_DIM), device=input_embeddings.device, dtype=input_embeddings.dtype
1794
+ )
1795
+
1796
+ # Run diffusion-based prediction
1797
+ normalized_actions, actions_hidden_states = self._run_diffusion_prediction(
1798
+ input_embeddings,
1799
+ all_actions_mask,
1800
+ noise,
1801
+ action_head,
1802
+ projected_patch_embeddings,
1803
+ labels,
1804
+ attention_mask,
1805
+ NUM_PATCHES,
1806
+ NUM_PROMPT_TOKENS,
1807
+ noisy_action_projector,
1808
+ )
1809
+ else:
1810
+ # Run regression or discrete token-based prediction
1811
+ normalized_actions, actions_hidden_states = self._regression_or_discrete_prediction(
1812
+ input_embeddings,
1813
+ all_actions_mask,
1814
+ projected_patch_embeddings,
1815
+ attention_mask,
1816
+ labels,
1817
+ NUM_PATCHES,
1818
+ NUM_PROMPT_TOKENS,
1819
+ action_head,
1820
+ )
1821
+
1822
+ # Unnormalize predicted actions
1823
+ actions = self._unnormalize_actions(normalized_actions, unnorm_key)
1824
+
1825
+ return actions, actions_hidden_states
1826
+
1827
+ def generate_action_verl(
1828
+ self,
1829
+ input_ids: Optional[torch.LongTensor] = None,
1830
+ unnorm_key: Optional[str] = None,
1831
+ proprio=None,
1832
+ proprio_projector=None,
1833
+ action_head=None,
1834
+ noisy_action_projector=None,
1835
+ use_film: bool = False,
1836
+ **kwargs: str,
1837
+ ) -> np.ndarray:
1838
+ """Predict actions from input sequence, with options for different prediction methods.
1839
+
1840
+ Args:
1841
+ input_ids: Input token ids
1842
+ unnorm_key: Key for unnormalization statistics
1843
+ proprio: Proprioceptive features
1844
+ proprio_projector: Projector for proprioceptive features
1845
+ action_head: Optional head for L1 regression or diffusion-based prediction
1846
+ noisy_action_projector: Projector for noisy actions in diffusion-based prediction
1847
+ use_film: Whether to use FiLM conditioning
1848
+ **kwargs: Additional arguments including pixel_values and attention_mask
1849
+
1850
+ Returns:
1851
+ Tuple of (unnormalized_actions, action_hidden_states)
1852
+ """
1853
+ # If the special empty token ('') does not already appear after the colon (':') token in the prompt
1854
+ # (after "OUT:" or "ASSISTANT:"), insert it to match the inputs seen at training time
1855
+ # if not torch.all(input_ids[:, -1] == 29871):
1856
+ # input_ids = torch.cat(
1857
+ # (input_ids, torch.unsqueeze(torch.Tensor([29871]).long(), dim=0).to(input_ids.device)), dim=1
1858
+ # )
1859
+
1860
+ pixel_values = kwargs["pixel_values"]
1861
+ attention_mask = kwargs["attention_mask"]
1862
+ do_sample = kwargs["do_sample"]
1863
+ temperature = kwargs["temperature"]
1864
+
1865
+ # Create fake labels tensor (needed for action mask)
1866
+ labels = input_ids.clone()
1867
+ labels[:] = IGNORE_INDEX
1868
+
1869
+ # Get number of tokens in prompt (excluding the start token)
1870
+ #NUM_PROMPT_TOKENS = input_ids.shape[-1] - 1 # Subtract action tokens and stop token
1871
+ #test
1872
+ padding_idx = kwargs["padding_idx"]
1873
+ num_prompt_tokens = input_ids.ne(padding_idx).sum(dim=1) - 1
1874
+ #test end
1875
+
1876
+
1877
+ # Prepare inputs by adding necessary tokens
1878
+ input_ids, attention_mask = self._prepare_input_for_action_prediction(input_ids, attention_mask)
1879
+
1880
+ # Update labels tensor for action mask computation later
1881
+ labels = self._prepare_labels_for_action_prediction(labels, input_ids)
1882
+
1883
+ #here to convert padding from before to last
1884
+ #test
1885
+ padding_mask = input_ids.ne(padding_idx)
1886
+ assert torch.all(padding_mask==attention_mask.ne(0))
1887
+ #print("in predict_action padding_mask:", padding_mask)
1888
+ padding_mask = padding_mask.int()
1889
+ sorted_indices = torch.argsort(padding_mask, dim=1, descending=True, stable=True)
1890
+ input_ids = torch.gather(input_ids, 1, sorted_indices)
1891
+ attention_mask = torch.gather(attention_mask, 1, sorted_indices)
1892
+ labels = torch.gather(labels, 1, sorted_indices)
1893
+ assert use_film==False
1894
+ #test end
1895
+
1896
+
1897
+ # Get input embeddings and action masks
1898
+ input_embeddings = self.get_input_embeddings()(input_ids)
1899
+ all_actions_mask = self._process_action_masks(labels)
1900
+
1901
+ # Extract language embeddings
1902
+ language_embeddings = input_embeddings[~all_actions_mask].reshape(
1903
+ input_embeddings.shape[0], -1, input_embeddings.shape[2]
1904
+ )
1905
+
1906
+ # Process vision features
1907
+ projected_patch_embeddings = self._process_vision_features(pixel_values, language_embeddings, use_film)
1908
+
1909
+ # Add proprioceptive features if provided
1910
+ use_proprio = proprio_projector is not None and proprio is not None
1911
+ if use_proprio:
1912
+ proprio = torch.Tensor(proprio).to(projected_patch_embeddings.device, dtype=projected_patch_embeddings.dtype)
1913
+ projected_patch_embeddings = self._process_proprio_features(
1914
+ projected_patch_embeddings, proprio, proprio_projector
1915
+ )
1916
+
1917
+ # Use diffusion if provided, otherwise use regression or discrete prediction
1918
+ use_diffusion = noisy_action_projector is not None and hasattr(action_head, "noise_scheduler")
1919
+
1920
+ # Calculate number of patches (including proprio token and/or diffusion timestep embedding if present)
1921
+ NUM_PATCHES = self.vision_backbone.get_num_patches() * self.vision_backbone.get_num_images_in_input()
1922
+ if use_proprio:
1923
+ NUM_PATCHES += 1
1924
+ if use_diffusion:
1925
+ NUM_PATCHES += 1
1926
+
1927
+ if use_diffusion:
1928
+ raise ValueError
1929
+ # Sample random noise with shape equal to output action, used as the starting state for reverse diffusion
1930
+ noise = torch.randn(
1931
+ size=(1, NUM_ACTIONS_CHUNK, ACTION_DIM), device=input_embeddings.device, dtype=input_embeddings.dtype
1932
+ )
1933
+
1934
+ # Run diffusion-based prediction
1935
+ normalized_actions, actions_hidden_states = self._run_diffusion_prediction(
1936
+ input_embeddings,
1937
+ all_actions_mask,
1938
+ noise,
1939
+ action_head,
1940
+ projected_patch_embeddings,
1941
+ labels,
1942
+ attention_mask,
1943
+ NUM_PATCHES,
1944
+ NUM_PROMPT_TOKENS,
1945
+ noisy_action_projector,
1946
+ )
1947
+ else:
1948
+ # Run regression or discrete token-based prediction
1949
+ normalized_actions, reponse_ids = self._verl_discrete_prediction(
1950
+ input_embeddings,
1951
+ all_actions_mask,
1952
+ projected_patch_embeddings,
1953
+ attention_mask,
1954
+ labels,
1955
+ NUM_PATCHES,
1956
+ num_prompt_tokens,
1957
+ action_head,
1958
+ do_sample=do_sample,
1959
+ temperature=temperature,
1960
+ )
1961
+
1962
+ # Unnormalize predicted actions
1963
+ actions = self._unnormalize_actions(normalized_actions, unnorm_key)
1964
+ #verl add!
1965
+ actions = actions.reshape(-1 ,NUM_ACTIONS_CHUNK, ACTION_DIM)
1966
+ #
1967
+ return actions, reponse_ids
1968
+
1969
+
1970
+
1971
+ @staticmethod
1972
+ def _check_unnorm_key(norm_stats: Dict[str, Dict[str, Any]], unnorm_key: Optional[str]) -> str:
1973
+ """Validate and resolve the unnormalization key for action statistics"""
1974
+ if unnorm_key is None:
1975
+ assert len(norm_stats) == 1, (
1976
+ f"Your model was trained on more than one dataset, "
1977
+ f"please pass a `unnorm_key` from the following options to choose the statistics "
1978
+ f"used for un-normalizing actions: {norm_stats.keys()}"
1979
+ )
1980
+ unnorm_key = next(iter(norm_stats.keys()))
1981
+
1982
+ assert unnorm_key in norm_stats, (
1983
+ f"The `unnorm_key` you chose is not in the set of available dataset statistics, "
1984
+ f"please choose from: {norm_stats.keys()}"
1985
+ )
1986
+ return unnorm_key
1987
+
1988
+ def get_action_dim(self, unnorm_key: Optional[str] = None) -> int:
1989
+ """Get the dimensionality of the policy's action space."""
1990
+ unnorm_key = self._check_unnorm_key(self.norm_stats, unnorm_key)
1991
+ return len(self.norm_stats[unnorm_key]["action"]["min"])
1992
+
1993
+ def get_action_stats(self, unnorm_key: Optional[str] = None) -> Dict[str, Any]:
1994
+ """Get all the logged statistics for the given dataset."""
1995
+ unnorm_key = self._check_unnorm_key(self.norm_stats, unnorm_key)
1996
+ return self.norm_stats[unnorm_key]["action"]
overview.png ADDED

Git LFS Details

  • SHA256: c47aada3c449b61a645b0f0c549b1bfaf7e3a614823844aabaf8e46a7df46e97
  • Pointer size: 131 Bytes
  • Size of remote file: 234 kB
preprocessor_config.json ADDED
@@ -0,0 +1,114 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "auto_map": {
3
+ "AutoImageProcessor": "processing_prismatic.PrismaticImageProcessor",
4
+ "AutoProcessor": "processing_prismatic.PrismaticProcessor"
5
+ },
6
+ "image_processor_type": "PrismaticImageProcessor",
7
+ "image_resize_strategy": "resize-naive",
8
+ "input_sizes": [
9
+ [
10
+ 3,
11
+ 224,
12
+ 224
13
+ ],
14
+ [
15
+ 3,
16
+ 224,
17
+ 224
18
+ ]
19
+ ],
20
+ "interpolations": [
21
+ "bicubic",
22
+ "bicubic"
23
+ ],
24
+ "means": [
25
+ [
26
+ 0.485,
27
+ 0.456,
28
+ 0.406
29
+ ],
30
+ [
31
+ 0.5,
32
+ 0.5,
33
+ 0.5
34
+ ]
35
+ ],
36
+ "processor_class": "PrismaticProcessor",
37
+ "stds": [
38
+ [
39
+ 0.229,
40
+ 0.224,
41
+ 0.225
42
+ ],
43
+ [
44
+ 0.5,
45
+ 0.5,
46
+ 0.5
47
+ ]
48
+ ],
49
+ "tvf_crop_params": [
50
+ {
51
+ "output_size": [
52
+ 224,
53
+ 224
54
+ ]
55
+ },
56
+ {
57
+ "output_size": [
58
+ 224,
59
+ 224
60
+ ]
61
+ }
62
+ ],
63
+ "tvf_do_letterbox": false,
64
+ "tvf_letterbox_fill": null,
65
+ "tvf_normalize_params": [
66
+ {
67
+ "inplace": false,
68
+ "mean": [
69
+ 0.484375,
70
+ 0.455078125,
71
+ 0.40625
72
+ ],
73
+ "std": [
74
+ 0.228515625,
75
+ 0.2236328125,
76
+ 0.224609375
77
+ ]
78
+ },
79
+ {
80
+ "inplace": false,
81
+ "mean": [
82
+ 0.5,
83
+ 0.5,
84
+ 0.5
85
+ ],
86
+ "std": [
87
+ 0.5,
88
+ 0.5,
89
+ 0.5
90
+ ]
91
+ }
92
+ ],
93
+ "tvf_resize_params": [
94
+ {
95
+ "antialias": true,
96
+ "interpolation": 3,
97
+ "max_size": null,
98
+ "size": [
99
+ 224,
100
+ 224
101
+ ]
102
+ },
103
+ {
104
+ "antialias": true,
105
+ "interpolation": 3,
106
+ "max_size": null,
107
+ "size": [
108
+ 224,
109
+ 224
110
+ ]
111
+ }
112
+ ],
113
+ "use_fused_vision_backbone": true
114
+ }
processing_prismatic.py ADDED
@@ -0,0 +1,257 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ processing_prismatic.py
3
+
4
+ HuggingFace-style preprocessor definitions for Prismatic VLMs, inheriting from `ProcessorMixin`. Default configuration
5
+ specifies `siglip-224px+7b`.
6
+ """
7
+
8
+ from typing import Any, ClassVar, List, Optional, Tuple, Union
9
+
10
+ import timm.data
11
+ import torch
12
+ import torchvision.transforms.functional as TVF
13
+ from PIL import Image
14
+ from torchvision.transforms import CenterCrop, Compose, Normalize, Resize, ToTensor
15
+ from transformers import PreTrainedTokenizerBase
16
+ from transformers.image_processing_utils import BatchFeature, ImageProcessingMixin
17
+ from transformers.processing_utils import ProcessorMixin
18
+ from transformers.tokenization_utils import PaddingStrategy, PreTokenizedInput, TextInput, TruncationStrategy
19
+ from transformers.utils import TensorType
20
+
21
+
22
+ # === Image Processing ===
23
+ def letterbox_pad_transform(image: Image.Image, padding_fill_value: Tuple[int, int, int]) -> Image.Image:
24
+ """Given a PIL.Image, pad to square by adding a symmetric border around the height/width."""
25
+ (w, h), max_wh = image.size, max(image.size)
26
+ horizontal_pad, vertical_pad = int((max_wh - w) / 2), int((max_wh - h) / 2)
27
+ padding = (horizontal_pad, vertical_pad, horizontal_pad, vertical_pad)
28
+
29
+ return TVF.pad(image, padding, fill=padding_fill_value, padding_mode="constant")
30
+
31
+
32
+ class PrismaticImageProcessor(ImageProcessingMixin):
33
+ model_input_names: ClassVar[List[str]] = ["pixel_values"]
34
+
35
+ def __init__(
36
+ self,
37
+ use_fused_vision_backbone: bool = False,
38
+ image_resize_strategy: str = "letterbox",
39
+ input_sizes: Optional[List[Tuple[int, int, int]]] = None,
40
+ interpolations: Optional[List[str]] = None,
41
+ means: Optional[List[Tuple[float, float, float]]] = None,
42
+ stds: Optional[List[Tuple[float, float, float]]] = None,
43
+ **kwargs: str,
44
+ ) -> None:
45
+ """
46
+ Initialize a PrismaticImageProcessor as a wrapper around a torchvision transform; this transform will be
47
+ created by TIMM, and edited to follow our custom `image_resize_strategy` logic.
48
+
49
+ @param use_fused_vision_backbone: Boolean indicating single or fused (dual) vision backbone
50
+ @param image_resize_strategy: Prismatic image resize strategy in < resize-naive | resize-crop | letterbox >
51
+ @param input_size: [TIMM :: `data_cfg`] Input image size as tuple (channels, width, height)
52
+ @param interpolation: [TIMM :: `data_cfg`] Interpolation as string (default: "bicubic")
53
+ @param mean: [TIMM :: `data_cfg`] Normalization mean as float tuple (or two-tuple if `fused_backbone`)
54
+ @param std: [TIMM :: `data_cfg`] Normalization std as float tuple (or two-tuple if `fused_backbone`)
55
+ """
56
+ self.use_fused_vision_backbone = use_fused_vision_backbone
57
+ self.image_resize_strategy = image_resize_strategy
58
+
59
+ # Handle `None` default values
60
+ input_sizes = [(3, 224, 224)] if input_sizes is None else input_sizes
61
+ means = [(0.5, 0.5, 0.5)] if means is None else means
62
+ stds = [(0.5, 0.5, 0.5)] if stds is None else stds
63
+
64
+ # TIMM `data_cfg` Parameters
65
+ self.input_sizes, self.interpolations, self.means, self.stds = input_sizes, interpolations, means, stds
66
+
67
+ # Grab torchvision transforms via TIMM =>> need to parse for specific "functional" transform values!
68
+ self.tvf_resize_params, self.tvf_crop_params, self.tvf_normalize_params = [], [], []
69
+ self.tvf_do_letterbox, self.tvf_letterbox_fill = False, None
70
+
71
+ for idx in range(len(input_sizes)):
72
+ transform = timm.data.create_transform(
73
+ input_size=self.input_sizes[idx],
74
+ interpolation=self.interpolations[idx],
75
+ mean=self.means[idx],
76
+ std=self.stds[idx],
77
+ crop_pct=1.0, # Set to 1.0 to ignore cropping (initial Resize sets `input_size`)
78
+ crop_mode="center", # Default crop mode -- no-op when `crop_pct == 1.0`
79
+ is_training=False, # No image augmentations when loading the transform!
80
+ )
81
+
82
+ # [Validation] Ensure appropriate transform structure, expected sizes
83
+ if not (
84
+ isinstance(transform, Compose)
85
+ and (len(transform.transforms) == 4)
86
+ and isinstance(transform.transforms[0], Resize)
87
+ and isinstance(transform.transforms[1], CenterCrop)
88
+ and isinstance(transform.transforms[2], ToTensor)
89
+ and isinstance(transform.transforms[3], Normalize)
90
+ and (transform.transforms[0].size == self.input_sizes[idx][-1])
91
+ and (transform.transforms[1].size == self.input_sizes[idx][-2:])
92
+ ):
93
+ raise ValueError(f"Unexpected TIMM image transformation structure/sizes: `{transform}`")
94
+
95
+ # HF Image Processors *must* be JSON-serializable; as such, cannot have torchvision. as an attribute.
96
+ # => Instead, we're going to parse the transform and call "torchvision.transforms.functional" (`tvf`)
97
+ resize_t, crop_t, norm_t = transform.transforms[0], transform.transforms[1], transform.transforms[3]
98
+ self.tvf_resize_params.append(
99
+ {
100
+ "size": resize_t.size,
101
+ "interpolation": TVF.pil_modes_mapping[resize_t.interpolation],
102
+ "max_size": None,
103
+ "antialias": True,
104
+ }
105
+ )
106
+ self.tvf_crop_params.append({"output_size": crop_t.size})
107
+ self.tvf_normalize_params.append(
108
+ {
109
+ "mean": norm_t.mean.float().numpy().tolist(),
110
+ "std": norm_t.std.float().numpy().tolist(),
111
+ "inplace": False,
112
+ }
113
+ )
114
+ self.tvf_do_letterbox, self.tvf_letterbox_fill = False, None
115
+
116
+ # Handle Prismatic `image_resize_strategy`
117
+ if self.image_resize_strategy == "resize-naive":
118
+ self.tvf_resize_params[idx]["size"] = (resize_t.size, resize_t.size)
119
+ elif self.image_resize_strategy == "letterbox":
120
+ self.tvf_do_letterbox, self.tvf_letterbox_fill = True, tuple([int(x * 255) for x in self.means[idx]])
121
+ elif self.image_resize_strategy == "resize-crop":
122
+ pass
123
+ else:
124
+ raise ValueError(f"Image resize strategy `{self.image_resize_strategy}` is not supported!")
125
+
126
+ # Dispatch **kwargs to super()
127
+ super().__init__(**kwargs)
128
+
129
+ def apply_transform(self, img: Image.Image) -> torch.Tensor:
130
+ """Apply `functional` variant of TIMM's Transform = Compose([Resize -> CenterCrop -> ToTensor -> Normalize])"""
131
+ if self.tvf_do_letterbox:
132
+ img = letterbox_pad_transform(img, self.tvf_letterbox_fill)
133
+
134
+ # [Contract] Fused Backbones expect "channel-stacked" inputs; we'll unpack on the model side!
135
+ imgs_t = []
136
+ for idx in range(len(self.input_sizes)):
137
+ img_idx = TVF.resize(img, **self.tvf_resize_params[idx])
138
+ img_idx = TVF.center_crop(img_idx, **self.tvf_crop_params[idx])
139
+ img_idx_t = TVF.to_tensor(img_idx)
140
+ img_idx_t = TVF.normalize(img_idx_t, **self.tvf_normalize_params[idx])
141
+ imgs_t.append(img_idx_t)
142
+
143
+ # [Contract] `imgs_t` is a list of Tensors of shape [3, input_size, input_size]; stack along dim = 0
144
+ img_t = torch.vstack(imgs_t)
145
+
146
+ return img_t
147
+
148
+ def preprocess(
149
+ self,
150
+ images: Union[Image.Image, List[Image.Image]],
151
+ return_tensors: Optional[Union[str, TensorType]] = None,
152
+ **_: str,
153
+ ) -> BatchFeature:
154
+ """
155
+ Preprocess an image (or batch of images); note that unlike the `transformers :: BaseImageProcessor` we
156
+ explicitly only handle PIL.Image.Image instances for simplicity.
157
+
158
+ @param images: A (batch of) PIL.Image.Image instance(s) to preprocess.
159
+ @param return_tensors: BatchFeature default Tensor format (e.g., "pt" for torch); if None, returns np.ndarray
160
+
161
+ @return: Instance of `transformers :: BatchFeature` with a single key "pixel_values"
162
+ """
163
+ if not isinstance(images, list):
164
+ images = [images]
165
+
166
+ # Apply `self.img_transform` to each image (will return list of torch.Tensors); stack into "batched" Tensor
167
+ pixel_values = torch.stack([self.apply_transform(img.convert("RGB")) for img in images])
168
+
169
+ # Return BatchFeature =>> note that for compatibility, constructor expects Dict[str, np.ndarray], so we convert
170
+ return BatchFeature(data={"pixel_values": pixel_values.float().numpy()}, tensor_type=return_tensors)
171
+
172
+ def __call__(self, images: Union[Image.Image, List[Image.Image]], **kwargs) -> BatchFeature:
173
+ return self.preprocess(images, **kwargs)
174
+
175
+
176
+ # === PrismaticProcessor =>> Wraps both ImageProcessor and Tokenizer ===
177
+ # =>> https://github.com/huggingface/transformers/blob/main/src/transformers/models/llava/processing_llava.py
178
+ class PrismaticProcessor(ProcessorMixin):
179
+ attributes: ClassVar[List[str]] = ["image_processor", "tokenizer"]
180
+ image_processor_class: str = "AutoImageProcessor"
181
+ tokenizer_class: str = "AutoTokenizer"
182
+
183
+ def __init__(
184
+ self,
185
+ image_processor: Optional[ImageProcessingMixin] = None,
186
+ tokenizer: Optional[PreTrainedTokenizerBase] = None,
187
+ ) -> None:
188
+ super().__init__(image_processor, tokenizer)
189
+
190
+ def __call__(
191
+ self,
192
+ text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]],
193
+ images: Union[Image.Image, List[Image.Image]],
194
+ padding: Union[bool, str, PaddingStrategy] = False,
195
+ truncation: Optional[Union[bool, str, TruncationStrategy]] = None,
196
+ max_length: Optional[int] = None,
197
+ return_tensors: Optional[Union[str, TensorType]] = TensorType.PYTORCH,
198
+ ) -> BatchFeature:
199
+ """
200
+ Preprocess a given (batch) of text/images for a Prismatic VLM; forwards text to the underlying LLM's tokenizer,
201
+ forwards images to PrismaticImageProcessor.
202
+
203
+ @param text: The (batch) of text to encode; must be a string or list of strings.
204
+ @param images: A (batch of) PIL.Image.Image instance(s) to preprocess.
205
+ @param padding: Sequence padding strategy (if multiple specified) in < True = "longest" | "max_length" | False >
206
+ @param truncation: Truncation strategy for the output sequences; requires `max_length` to be specified
207
+ @param max_length: Maximum length (in tokens) to truncate
208
+ @param return_tensors: Type of return tensors (usually "pt" or TensorType.PYTORCH)
209
+
210
+ @return: BatchFeature with keys for `input_ids`, `attention_mask` and `pixel_values`.
211
+ """
212
+ pixel_values = self.image_processor(images, return_tensors=return_tensors)["pixel_values"]
213
+ text_inputs = self.tokenizer(
214
+ text, return_tensors=return_tensors, padding=padding, truncation=truncation, max_length=max_length
215
+ )
216
+
217
+ # [Validate] Need same number of images and text inputs!
218
+ if pixel_values.shape[0] != text_inputs.input_ids.shape[0]:
219
+ raise ValueError("Batch is malformed; expected same number of images and text inputs!")
220
+
221
+ return BatchFeature(data={**text_inputs, "pixel_values": pixel_values})
222
+
223
+ # === Tokenizer Dispatch Utilities =>> check `PreTrainedTokenizerBase` for documentation ===
224
+ def batch_decode(
225
+ self,
226
+ sequences: Union[List[int], List[List[int]], torch.Tensor, Any], # `Any` = np.ndarray | tf.Tensor
227
+ skip_special_tokens: bool = False,
228
+ clean_up_tokenization_spaces: Optional[bool] = None,
229
+ **kwargs: str,
230
+ ) -> List[str]:
231
+ return self.tokenizer.batch_decode(
232
+ sequences=sequences,
233
+ skip_special_tokens=skip_special_tokens,
234
+ clean_up_tokenization_spaces=clean_up_tokenization_spaces,
235
+ **kwargs,
236
+ )
237
+
238
+ def decode(
239
+ self,
240
+ token_ids: Union[int, List[int], torch.Tensor, Any], # `Any` = np.ndarray | tf.Tensor
241
+ skip_special_tokens: bool = False,
242
+ clean_up_tokenization_spaces: Optional[bool] = None,
243
+ **kwargs: str,
244
+ ) -> str:
245
+ return self.tokenizer.decode(
246
+ token_ids=token_ids,
247
+ skip_special_tokens=skip_special_tokens,
248
+ clean_up_tokenization_spaces=clean_up_tokenization_spaces,
249
+ **kwargs,
250
+ )
251
+
252
+ @property
253
+ def model_input_names(self) -> List[str]:
254
+ tokenizer_input_names = self.tokenizer.model_input_names
255
+ image_processor_input_names = self.image_processor.model_input_names
256
+
257
+ return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
processor_config.json ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ {
2
+ "auto_map": {
3
+ "AutoProcessor": "processing_prismatic.PrismaticProcessor"
4
+ },
5
+ "processor_class": "PrismaticProcessor"
6
+ }
special_tokens_map.json ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "bos_token": {
3
+ "content": "<s>",
4
+ "lstrip": false,
5
+ "normalized": false,
6
+ "rstrip": false,
7
+ "single_word": false
8
+ },
9
+ "eos_token": {
10
+ "content": "</s>",
11
+ "lstrip": false,
12
+ "normalized": false,
13
+ "rstrip": false,
14
+ "single_word": false
15
+ },
16
+ "pad_token": {
17
+ "content": "<PAD>",
18
+ "lstrip": false,
19
+ "normalized": false,
20
+ "rstrip": false,
21
+ "single_word": false
22
+ },
23
+ "unk_token": {
24
+ "content": "<unk>",
25
+ "lstrip": false,
26
+ "normalized": false,
27
+ "rstrip": false,
28
+ "single_word": false
29
+ }
30
+ }
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer.model ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9e556afd44213b6bd1be2b850ebbbd98f5481437a8021afaf58ee7fb1818d347
3
+ size 499723
tokenizer_config.json ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "add_bos_token": true,
3
+ "add_eos_token": false,
4
+ "added_tokens_decoder": {
5
+ "0": {
6
+ "content": "<unk>",
7
+ "lstrip": false,
8
+ "normalized": false,
9
+ "rstrip": false,
10
+ "single_word": false,
11
+ "special": true
12
+ },
13
+ "1": {
14
+ "content": "<s>",
15
+ "lstrip": false,
16
+ "normalized": false,
17
+ "rstrip": false,
18
+ "single_word": false,
19
+ "special": true
20
+ },
21
+ "2": {
22
+ "content": "</s>",
23
+ "lstrip": false,
24
+ "normalized": false,
25
+ "rstrip": false,
26
+ "single_word": false,
27
+ "special": true
28
+ },
29
+ "32000": {
30
+ "content": "<PAD>",
31
+ "lstrip": false,
32
+ "normalized": false,
33
+ "rstrip": false,
34
+ "single_word": false,
35
+ "special": true
36
+ }
37
+ },
38
+ "auto_map": {
39
+ "AutoProcessor": "processing_prismatic.PrismaticProcessor"
40
+ },
41
+ "bos_token": "<s>",
42
+ "clean_up_tokenization_spaces": false,
43
+ "eos_token": "</s>",
44
+ "legacy": false,
45
+ "model_max_length": 2048,
46
+ "pad_token": "<PAD>",
47
+ "padding_side": "right",
48
+ "processor_class": "PrismaticProcessor",
49
+ "sp_model_kwargs": {},
50
+ "tokenizer_class": "LlamaTokenizer",
51
+ "unk_token": "<unk>",
52
+ "use_default_system_prompt": false
53
+ }