Update modeling_prismatic.py to match public GitHub repo
Check input_ids before adding the special empty token ('') rather than adding it unconditionally.
- modeling_prismatic.py  +12 -11
@@ -504,14 +504,15 @@ class OpenVLAForActionPrediction(PrismaticForConditionalGeneration):
         self.vocab_size = self.config.text_config.vocab_size - self.config.pad_to_multiple_of
 
     def predict_action(
-        self, input_ids: Optional[torch.LongTensor] = None, unnorm_key: Optional[str] = None, **kwargs
+        self, input_ids: Optional[torch.LongTensor] = None, unnorm_key: Optional[str] = None, **kwargs: str
     ) -> np.ndarray:
         """Thin wrapper around .generate() that decodes predicted actions and unnormalizes them."""
-        # We need to add this special empty token ('') after the colon (':') token in "ASSISTANT:"
-        # in order for the predictions to match the training configuration and be accurate.
-        input_ids = torch.cat(
-            (input_ids, torch.unsqueeze(torch.Tensor([29871]).long(), dim=0).to(input_ids.device)), dim=1
-        )
+        # If the special empty token ('') does not already appear after the colon (':') token in the prompt
+        # (after "OUT:" or "ASSISTANT:"), insert it to match the inputs seen at training time
+        if not torch.all(input_ids[:, -1] == 29871):
+            input_ids = torch.cat(
+                (input_ids, torch.unsqueeze(torch.Tensor([29871]).long(), dim=0).to(input_ids.device)), dim=1
+            )
 
         # Run VLA inference
         generated_ids = self.generate(input_ids, max_new_tokens=self.get_action_dim(unnorm_key), **kwargs)
@@ -535,7 +536,7 @@ class OpenVLAForActionPrediction(PrismaticForConditionalGeneration):
         return actions
 
     @staticmethod
-    def _check_unnorm_key(norm_stats, unnorm_key):
+    def _check_unnorm_key(norm_stats: Dict[str, Dict[str, Any]], unnorm_key: Optional[str]) -> str:
         if unnorm_key is None:
             assert len(norm_stats) == 1, (
                 f"Your model was trained on more than one dataset, "
@@ -550,12 +551,12 @@ class OpenVLAForActionPrediction(PrismaticForConditionalGeneration):
             )
         return unnorm_key
 
-    def get_action_dim(self, unnorm_key=None):
-        """Dimensionality of the policy's action space."""
+    def get_action_dim(self, unnorm_key: Optional[str] = None) -> int:
+        """Get the dimensionality of the policy's action space."""
         unnorm_key = self._check_unnorm_key(self.norm_stats, unnorm_key)
         return len(self.norm_stats[unnorm_key]["action"]["q01"])
 
-    def get_action_stats(self, unnorm_key=None):
-        """Dimensionality of the policy's action space."""
+    def get_action_stats(self, unnorm_key: Optional[str] = None) -> Dict[str, Any]:
+        """Get all the logged statistics for the given dataset."""
        unnorm_key = self._check_unnorm_key(self.norm_stats, unnorm_key)
        return self.norm_stats[unnorm_key]["action"]
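
For context on the first hunk: the guard makes the token append idempotent, so prompts whose last token is already the special empty token ('', id 29871) pass through unchanged, while prompts that end right after the colon still get the token appended to match the inputs seen at training time. Below is a minimal standalone sketch of that logic, not code from the repo; the helper name and the toy token ids are made up for illustration:

import torch

EMPTY_TOKEN_ID = 29871  # the special empty token ('') referenced in the diff


def ensure_trailing_empty_token(input_ids: torch.LongTensor) -> torch.LongTensor:
    # Append the empty token only when it is not already the last token in every row.
    if not torch.all(input_ids[:, -1] == EMPTY_TOKEN_ID):
        suffix = torch.full(
            (input_ids.shape[0], 1), EMPTY_TOKEN_ID, dtype=torch.long, device=input_ids.device
        )
        input_ids = torch.cat((input_ids, suffix), dim=1)
    return input_ids


# A prompt that already ends with the empty token is left unchanged (toy token ids)...
already_suffixed = torch.tensor([[1, 512, 913, EMPTY_TOKEN_ID]])
assert torch.equal(ensure_trailing_empty_token(already_suffixed), already_suffixed)

# ...while one that does not ends up with exactly one copy appended.
needs_suffix = torch.tensor([[1, 512, 913]])
assert ensure_trailing_empty_token(needs_suffix)[0, -1].item() == EMPTY_TOKEN_ID

As a result, callers whose tokenizer already emits the trailing empty token no longer get it duplicated, while older prompt-building code keeps working without changes.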
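
The accessors in the last hunk read from self.norm_stats, which maps dataset names to per-modality statistics: get_action_dim infers the action dimensionality from the length of the 1st-percentile vector, and get_action_stats returns the whole "action" entry. A sketch of the minimal structure these two methods rely on, with a made-up dataset key and made-up values (the real entries carry whatever statistics were logged alongside "q01"):

# Hypothetical example of the norm_stats layout assumed by get_action_dim / get_action_stats.
norm_stats = {
    "my_dataset": {  # made-up dataset key, passed as `unnorm_key`
        "action": {
            # made-up 1st-percentile bounds, one value per action dimension
            "q01": [-0.1, -0.1, -0.1, -0.2, -0.2, -0.2, 0.0],
        },
    },
}

# With a single dataset, `unnorm_key` can be omitted; _check_unnorm_key resolves it automatically.
action_dim = len(norm_stats["my_dataset"]["action"]["q01"])  # -> 7 in this example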