Upload folder using huggingface_hub
- config.json +1 -1
- esm_nv.py +101 -49
config.json
CHANGED
@@ -35,7 +35,7 @@
   "position_embedding_type": "rotary",
   "qkv_weight_interleaved": true,
   "token_dropout": true,
-  "transformers_version": "4.
+  "transformers_version": "4.57.0",
   "use_cache": true,
   "vocab_list": null,
   "vocab_size": 33
esm_nv.py
CHANGED
@@ -39,7 +39,7 @@ from transformers.modeling_outputs import (
 )
 from transformers.modeling_utils import PreTrainedModel
 from transformers.models.esm.configuration_esm import EsmConfig
-from transformers.models.esm.modeling_esm import
+from transformers.models.esm.modeling_esm import EsmPooler
 from transformers.utils import logging


@@ -153,16 +153,6 @@ class NVEsmEncoder(nn.Module):
         if config.position_embedding_type == "rotary":
             self.rotary_embeddings = RotaryPositionEmbedding(config.hidden_size // config.num_attention_heads)

-            # Keep on CPU, pin for faster non_blocking H2D; don't persist in state_dict.
-            if config.attn_input_format == "bshd":
-                self.register_buffer(
-                    "te_rope_emb",
-                    self.rotary_embeddings(max_seq_len=config.max_position_embeddings).cpu().pin_memory(),
-                    persistent=False,
-                )
-            else:
-                self.te_rope_emb = None
-
     def forward(
         self,
         hidden_states: torch.Tensor,
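For readers unfamiliar with the idiom being removed above: the old code precomputed the rotary table once, kept it in pinned host memory for fast non_blocking copies, and excluded it from the checkpoint. A minimal standalone sketch of that pattern (the buffer name and table contents here are illustrative only, not the model's actual code):

import torch
import torch.nn as nn

class RopeCache(nn.Module):
    """Illustrative only: precompute a table, pin it, keep it out of state_dict."""

    def __init__(self, max_seq_len: int = 1024):
        super().__init__()
        table = torch.arange(max_seq_len, dtype=torch.float32)
        if torch.cuda.is_available():
            table = table.pin_memory()  # enables async non_blocking host-to-device copies
        # persistent=False keeps the tensor out of state_dict / saved checkpoints.
        self.register_buffer("table", table, persistent=False)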
@@ -195,6 +185,7 @@ class NVEsmEncoder(nn.Module):
                     "THD expects embeddings shaped [1, total_tokens, hidden_size]."
                 )
             hidden_states = hidden_states.squeeze(0)
+            attention_mask = None

         elif self.config.attn_input_format == "bshd":
             if any(x is not None for x in [cu_seq_lens_q, cu_seq_lens_k, max_length_q, max_length_k]):
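For context on the THD ("total tokens") path touched above: variable-length sequences are packed into a single [1, total_tokens] row and described by cumulative sequence lengths rather than by padding plus an attention mask, which is why the mask is dropped once the batch is squeezed. A minimal sketch of building such a packed batch (the sequence lengths and token ids are made up for illustration):

import torch

# Three unpadded sequences of made-up token ids (lengths 5, 3 and 7).
seqs = [torch.randint(4, 24, (n,)) for n in (5, 3, 7)]

# Pack into one [1, total_tokens] row, as the THD branch expects.
input_ids = torch.cat(seqs).unsqueeze(0)  # shape [1, 15]

# Cumulative sequence lengths mark each sequence's boundaries: [0, 5, 8, 15].
seq_lens = torch.tensor([s.numel() for s in seqs], dtype=torch.int32)
cu_seq_lens = torch.nn.functional.pad(seq_lens.cumsum(0), (1, 0)).to(torch.int32)
max_length = int(seq_lens.max())  # 7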
@@ -202,28 +193,14 @@ class NVEsmEncoder(nn.Module):
                     "cu_seq_lens_q, cu_seq_lens_k, max_length_q, and max_length_k are not allowed when using BSHD inputs."
                 )

-
-
-        if self.config.
-
-
-
-
-
-                raise RuntimeError(
-                    f"ROPE length {te_rope_emb.size(0)} < input seq length {seq_len}. "
-                    f"Increase max_position_embeddings."
-                )
-            te_rope_emb = te_rope_emb[:seq_len]
-
-        elif self.config.attn_input_format == "thd":
-            assert cu_seq_lens_q is not None
-            te_rope_emb = self.rotary_embeddings(max_seq_len=cu_seq_lens_q[-1]).to(
-                device=hidden_states.device, dtype=hidden_states.dtype, non_blocking=True
-            )
-
-        else:
-            raise ValueError(f"Unsupported attention input format: {self.config.attn_input_format}")
+        # Ensure that rotary embeddings are computed at a higher precision outside the torch autocast context.
+        with torch.autocast(device_type="cuda", enabled=False):
+            if self.config.position_embedding_type == "rotary":
+                if self.config.attn_input_format == "bshd":
+                    te_rope_emb = self.rotary_embeddings(max_seq_len=hidden_states.shape[1])
+                elif self.config.attn_input_format == "thd":
+                    te_rope_emb = self.rotary_embeddings(max_seq_len=cu_seq_lens_q[-1])
+            te_rope_emb = te_rope_emb.to(hidden_states.device, dtype=hidden_states.dtype, non_blocking=True)

         for layer_module in self.layers:
             if output_hidden_states:
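The autocast-disabled block added above is plain PyTorch: it keeps the rotary frequency table in full precision even when the surrounding forward pass runs under bf16/fp16 autocast, and casts only afterwards. A standalone sketch of the idea (the frequency formula here is a generic RoPE formulation, not necessarily what TE's RotaryPositionEmbedding computes internally):

import torch

def rope_freqs(max_seq_len: int, dim: int) -> torch.Tensor:
    # Generic rotary frequencies, computed in float32 for accuracy.
    inv_freq = 1.0 / (10000.0 ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
    positions = torch.arange(max_seq_len, dtype=torch.float32)
    return torch.outer(positions, inv_freq)  # [max_seq_len, dim // 2]

with torch.autocast(device_type="cuda", enabled=False):
    # Even inside an outer autocast region, this block stays in float32.
    freqs = rope_freqs(max_seq_len=1024, dim=64)

# Cast only at the end, mirroring the .to(..., dtype=hidden_states.dtype) above.
freqs = freqs.to(torch.bfloat16)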
@@ -305,15 +282,10 @@ class NVEsmModel(NVEsmPreTrainedModel):
         super().__init__(config)
         self.config = config

-        # Create EsmEmbeddings with temporarily modified config to use padded vocab size
-        # This ensures the word embeddings layer uses the padded vocabulary size for FP8 support
-        original_vocab_size = config.vocab_size
-        config.vocab_size = config.padded_vocab_size
         # Ensure pad_token_id is set properly, defaulting to 0 if not specified
         if not hasattr(config, "pad_token_id") or config.pad_token_id is None:
             config.pad_token_id = 0
-        self.embeddings = EsmEmbeddings(config)
-        config.vocab_size = original_vocab_size  # Restore original vocab_size
+        self.embeddings = NVEsmEmbeddings(config)
         self.encoder = NVEsmEncoder(config)
         self.pooler = EsmPooler(config) if add_pooling_layer else None

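The removed lines existed only to build the word-embedding table with config.padded_vocab_size (rather than the real vocab_size of 33) so the weight shapes meet FP8 alignment requirements; the new NVEsmEmbeddings class now does this directly. Vocabulary padding of this kind is typically a round-up to a fixed multiple, sketched below (the multiple of 64 is an illustrative assumption, not necessarily what padded_vocab_size uses):

def pad_vocab_size(vocab_size: int, multiple: int = 64) -> int:
    # Round the vocabulary size up to the next multiple, e.g. 33 -> 64, so the
    # embedding / tied output projection dimensions satisfy FP8 alignment rules.
    return ((vocab_size + multiple - 1) // multiple) * multiple

assert pad_vocab_size(33) == 64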
@@ -337,7 +309,6 @@ class NVEsmModel(NVEsmPreTrainedModel):
         input_ids: Optional[torch.Tensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.Tensor] = None,
-        head_mask: Optional[torch.Tensor] = None,
         inputs_embeds: Optional[torch.Tensor] = None,
         output_hidden_states: Optional[bool] = None,
         cu_seq_lens_q: torch.IntTensor | None = None,
@@ -351,7 +322,6 @@ class NVEsmModel(NVEsmPreTrainedModel):
             input_ids (torch.Tensor): The input ids.
             attention_mask (torch.Tensor): The attention mask.
             position_ids (torch.Tensor): The position ids.
-            head_mask (torch.Tensor): The head mask.
             inputs_embeds (torch.Tensor): The input embeddings.
             output_hidden_states (bool): Whether to output the hidden states.
             cu_seq_lens_q (torch.IntTensor): The cumulative sequence lengths for the query state, if using THD inputs.
@@ -389,18 +359,14 @@ class NVEsmModel(NVEsmPreTrainedModel):
         # TE expects a boolean attention mask, where 1s are masked and 0s are not masked
         extended_attention_mask = extended_attention_mask < -1

-        # Prepare head mask if needed
-        # 1.0 in head_mask indicate we keep the head
-        # attention_probs has shape bsz x n_heads x N x N
-        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
-        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
-        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
-
         embedding_output = self.embeddings(
             input_ids=input_ids,
-            position_ids=position_ids,
             attention_mask=attention_mask,
             inputs_embeds=inputs_embeds,
+            cu_seq_lens_q=cu_seq_lens_q,
+            cu_seq_lens_k=cu_seq_lens_k,
+            max_length_q=max_length_q,
+            max_length_k=max_length_k,
         )
         encoder_outputs = self.encoder(
             embedding_output,
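On the context line about TE's mask convention: Hugging Face's additive "extended" mask uses 0 for visible positions and a large negative value for masked ones, while TransformerEngine wants a boolean mask with True meaning "masked", hence the `< -1` comparison. A simplified sketch of that conversion (shapes reduced to 2-D for clarity):

import torch

# HF-style padding mask: 1 = real token, 0 = padding.
attention_mask = torch.tensor([[1, 1, 1, 0, 0]])

# Additive "extended" mask: 0.0 where visible, a large negative number where masked.
extended = (1.0 - attention_mask.float()) * torch.finfo(torch.float32).min

# TE-style boolean mask: True = masked, False = visible.
te_mask = extended < -1
print(te_mask)  # tensor([[False, False, False,  True,  True]])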
@@ -547,3 +513,89 @@ class NVEsmLMHead(nn.Module):
         x = torch.nn.functional.gelu(x)
         x = self.decoder(x)
         return x
+
+
+class NVEsmEmbeddings(nn.Module):
+    """Modified version of EsmEmbeddings to support THD inputs."""
+
+    def __init__(self, config):
+        """Initialize a NVEsmEmbeddings."""
+        super().__init__()
+        self.word_embeddings = nn.Embedding(
+            config.padded_vocab_size, config.hidden_size, padding_idx=config.pad_token_id
+        )
+
+        self.layer_norm = (
+            nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) if config.emb_layer_norm_before else None
+        )
+
+        if config.position_embedding_type != "rotary":
+            raise ValueError(
+                "The TE-accelerated ESM-2 model only supports rotary position embeddings, received "
+                f"{config.position_embedding_type}"
+            )
+
+        self.padding_idx = config.pad_token_id
+        self.token_dropout = config.token_dropout
+        self.mask_token_id = config.mask_token_id
+
+    def forward(
+        self,
+        input_ids=None,
+        attention_mask=None,
+        inputs_embeds=None,
+        cu_seq_lens_q: torch.IntTensor | None = None,
+        cu_seq_lens_k: torch.IntTensor | None = None,
+        max_length_q: int | None = None,
+        max_length_k: int | None = None,
+    ):
+        """Forward pass of the NVEsmEmbeddings."""
+        if inputs_embeds is None:
+            inputs_embeds = self.word_embeddings(input_ids)
+
+        # Note that if we want to support ESM-1 (not 1b!) in future then we need to support an
+        # embedding_scale factor here.
+        embeddings = inputs_embeds
+
+        if all(x is not None for x in [cu_seq_lens_q, cu_seq_lens_k, max_length_q, max_length_k]):
+            using_thd = True
+            attention_mask = None
+        else:
+            using_thd = False
+
+        # Matt: ESM has the option to handle masking in MLM in a slightly unusual way. If the token_dropout
+        # flag is False then it is handled in the same way as BERT/RoBERTa. If it is set to True, however,
+        # masked tokens are treated as if they were selected for input dropout and zeroed out.
+        # This "mask-dropout" is compensated for when masked tokens are not present, by scaling embeddings by
+        # a factor of (fraction of unmasked tokens during training) / (fraction of unmasked tokens in sample).
+        # This is analogous to the way that dropout layers scale down outputs during evaluation when not
+        # actually dropping out values (or, equivalently, scale up their un-dropped outputs in training).
+        if self.token_dropout and input_ids is not None:
+            embeddings = embeddings.masked_fill((input_ids == self.mask_token_id).unsqueeze(-1), 0.0)
+            mask_ratio_train = 0.15 * 0.8  # Hardcoded as the ratio used in all ESM model training runs
+
+            if not using_thd:
+                # BSHD token dropout correction
+                src_lengths = attention_mask.sum(-1) if attention_mask is not None else input_ids.shape[1]
+                n_masked_per_seq = (input_ids == self.mask_token_id).sum(-1).float()
+                mask_ratio_observed = n_masked_per_seq / src_lengths
+                scale_factor = (1 - mask_ratio_train) / (1 - mask_ratio_observed)
+                embeddings = (embeddings * scale_factor[:, None, None]).to(embeddings.dtype)
+
+            else:
+                src_lengths = torch.diff(cu_seq_lens_q)
+                # We need to find the number of masked tokens in each sequence in the packed batch.
+                is_masked = (input_ids == self.mask_token_id).squeeze(0)
+                n_masked_per_seq = torch.nested.nested_tensor_from_jagged(is_masked, offsets=cu_seq_lens_q).sum(1)
+                mask_ratio_observed = n_masked_per_seq.float() / src_lengths
+                scale_factor = (1 - mask_ratio_train) / (1 - mask_ratio_observed)
+                reshaped_scale_factor = torch.repeat_interleave(scale_factor, src_lengths, dim=0)
+                embeddings = (embeddings * reshaped_scale_factor.unsqueeze(-1)).to(embeddings.dtype)
+
+        if self.layer_norm is not None:
+            embeddings = self.layer_norm(embeddings)
+
+        if attention_mask is not None:
+            embeddings = (embeddings * attention_mask.unsqueeze(-1)).to(embeddings.dtype)
+
+        return embeddings
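The THD branch of the new token-dropout correction computes one scale factor per packed sequence and repeats it across that sequence's tokens. The same arithmetic can be checked with plain indexing over the offsets, as in the sketch below (torch.nested is avoided to keep it dependency-light, and the mask token id of 32 is assumed here to match ESM-2's tokenizer):

import torch

mask_token_id = 32             # assumption: ESM-2's <mask> id
mask_ratio_train = 0.15 * 0.8  # same constant as in the diff above

# Packed THD batch: two sequences of lengths 4 and 6, with 1 and 3 masked tokens.
input_ids = torch.tensor([[5, 32, 6, 7, 8, 32, 32, 9, 32, 10]])
cu_seq_lens = torch.tensor([0, 4, 10], dtype=torch.int32)

src_lengths = torch.diff(cu_seq_lens)               # tensor([4, 6])
is_masked = (input_ids == mask_token_id).squeeze(0)

# Per-sequence masked-token counts via segment sums over the offsets.
n_masked = torch.stack([is_masked[s:e].sum() for s, e in zip(cu_seq_lens[:-1], cu_seq_lens[1:])])
mask_ratio_observed = n_masked.float() / src_lengths        # tensor([0.25, 0.50])
scale = (1 - mask_ratio_train) / (1 - mask_ratio_observed)  # ~tensor([1.173, 1.760])

# Repeat each sequence's factor over its own tokens, as the diff does.
per_token_scale = torch.repeat_interleave(scale, src_lengths.long(), dim=0)
print(per_token_scale.shape)  # torch.Size([10])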