pstjohn committed · verified
Commit 2fcafda · 1 Parent(s): 83343c0

Upload folder using huggingface_hub

Files changed (2)
  1. config.json +1 -1
  2. esm_nv.py +101 -49
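This checkpoint ships its own modeling file (esm_nv.py), so it is normally loaded with trust_remote_code=True so that the NVEsm classes defined in that file are used. A minimal loading sketch; the repository id below is a placeholder, not taken from this commit:

    from transformers import AutoModel, AutoTokenizer

    repo_id = "<namespace>/<esm2-te-checkpoint>"  # placeholder repository id
    tokenizer = AutoTokenizer.from_pretrained(repo_id)
    # Assumes the repo's config maps AutoModel to the NVEsm* classes in esm_nv.py.
    model = AutoModel.from_pretrained(repo_id, trust_remote_code=True)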
config.json CHANGED
@@ -35,7 +35,7 @@
   "position_embedding_type": "rotary",
   "qkv_weight_interleaved": true,
   "token_dropout": true,
-  "transformers_version": "4.56.2",
+  "transformers_version": "4.57.0",
   "use_cache": true,
   "vocab_list": null,
   "vocab_size": 33
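The only config.json change is the recorded transformers version (4.56.2 to 4.57.0); transformers_version is metadata written at save time, not an enforced requirement. A quick way to compare it against the locally installed library, with the repository id again a placeholder:

    import transformers
    from transformers import AutoConfig

    config = AutoConfig.from_pretrained("<namespace>/<esm2-te-checkpoint>", trust_remote_code=True)  # placeholder
    print("saved with:", config.transformers_version, "| installed:", transformers.__version__)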
esm_nv.py CHANGED
@@ -39,7 +39,7 @@ from transformers.modeling_outputs import (
 )
 from transformers.modeling_utils import PreTrainedModel
 from transformers.models.esm.configuration_esm import EsmConfig
-from transformers.models.esm.modeling_esm import EsmEmbeddings, EsmPooler
+from transformers.models.esm.modeling_esm import EsmPooler
 from transformers.utils import logging
 
 
@@ -153,16 +153,6 @@ class NVEsmEncoder(nn.Module):
         if config.position_embedding_type == "rotary":
             self.rotary_embeddings = RotaryPositionEmbedding(config.hidden_size // config.num_attention_heads)
 
-        # Keep on CPU, pin for faster non_blocking H2D; don't persist in state_dict.
-        if config.attn_input_format == "bshd":
-            self.register_buffer(
-                "te_rope_emb",
-                self.rotary_embeddings(max_seq_len=config.max_position_embeddings).cpu().pin_memory(),
-                persistent=False,
-            )
-        else:
-            self.te_rope_emb = None
-
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -195,6 +185,7 @@ class NVEsmEncoder(nn.Module):
                 "THD expects embeddings shaped [1, total_tokens, hidden_size]."
             )
             hidden_states = hidden_states.squeeze(0)
+            attention_mask = None
 
         elif self.config.attn_input_format == "bshd":
             if any(x is not None for x in [cu_seq_lens_q, cu_seq_lens_k, max_length_q, max_length_k]):
@@ -202,28 +193,14 @@
                     "cu_seq_lens_q, cu_seq_lens_k, max_length_q, and max_length_k are not allowed when using BSHD inputs."
                 )
 
-        te_rope_emb = None
-        if self.config.position_embedding_type == "rotary":
-            if self.config.attn_input_format == "bshd":
-                te_rope_emb = self.te_rope_emb.to(
-                    device=hidden_states.device, dtype=hidden_states.dtype, non_blocking=True
-                )
-                seq_len = hidden_states.shape[1]
-                if te_rope_emb.size(0) < seq_len:
-                    raise RuntimeError(
-                        f"ROPE length {te_rope_emb.size(0)} < input seq length {seq_len}. "
-                        f"Increase max_position_embeddings."
-                    )
-                te_rope_emb = te_rope_emb[:seq_len]
-
-            elif self.config.attn_input_format == "thd":
-                assert cu_seq_lens_q is not None
-                te_rope_emb = self.rotary_embeddings(max_seq_len=cu_seq_lens_q[-1]).to(
-                    device=hidden_states.device, dtype=hidden_states.dtype, non_blocking=True
-                )
-
-            else:
-                raise ValueError(f"Unsupported attention input format: {self.config.attn_input_format}")
+        # Ensure that rotary embeddings are computed at a higher precision, outside the torch autocast context.
+        with torch.autocast(device_type="cuda", enabled=False):
+            if self.config.position_embedding_type == "rotary":
+                if self.config.attn_input_format == "bshd":
+                    te_rope_emb = self.rotary_embeddings(max_seq_len=hidden_states.shape[1])
+                elif self.config.attn_input_format == "thd":
+                    te_rope_emb = self.rotary_embeddings(max_seq_len=cu_seq_lens_q[-1])
+                te_rope_emb = te_rope_emb.to(hidden_states.device, dtype=hidden_states.dtype, non_blocking=True)
 
         for layer_module in self.layers:
             if output_hidden_states:
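The replacement code builds the rotary table on the fly inside the forward pass, with autocast disabled so the frequencies are computed in full precision before being cast to the activation dtype. A self-contained sketch of that pattern, using a plain float32 stand-in (rope_frequencies) for TE's RotaryPositionEmbedding and illustrative tensor shapes; it only mirrors the precision handling, not TE's exact output layout:

    import torch

    def rope_frequencies(max_seq_len: int, dim: int) -> torch.Tensor:
        # Stand-in for transformer_engine's RotaryPositionEmbedding; float32 throughout.
        inv_freq = 1.0 / (10000.0 ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
        positions = torch.arange(max_seq_len, dtype=torch.float32)
        return torch.outer(positions, inv_freq)

    hidden_states = torch.randn(2, 128, 1024, device="cuda", dtype=torch.bfloat16)  # illustrative BSHD activations
    with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
        with torch.autocast(device_type="cuda", enabled=False):
            # Built with autocast disabled, so the table stays float32 regardless of the surrounding context.
            rope_emb = rope_frequencies(hidden_states.shape[1], dim=64)
        rope_emb = rope_emb.to(hidden_states.device, dtype=hidden_states.dtype, non_blocking=True)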
@@ -305,15 +282,10 @@ class NVEsmModel(NVEsmPreTrainedModel):
         super().__init__(config)
         self.config = config
 
-        # Create EsmEmbeddings with temporarily modified config to use padded vocab size
-        # This ensures the word embeddings layer uses the padded vocabulary size for FP8 support
-        original_vocab_size = config.vocab_size
-        config.vocab_size = config.padded_vocab_size
         # Ensure pad_token_id is set properly, defaulting to 0 if not specified
        if not hasattr(config, "pad_token_id") or config.pad_token_id is None:
             config.pad_token_id = 0
-        self.embeddings = EsmEmbeddings(config)
-        config.vocab_size = original_vocab_size  # Restore original vocab_size
+        self.embeddings = NVEsmEmbeddings(config)
         self.encoder = NVEsmEncoder(config)
         self.pooler = EsmPooler(config) if add_pooling_layer else None
 
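Swapping EsmEmbeddings for NVEsmEmbeddings removes the temporary config.vocab_size patching: the new module reads config.padded_vocab_size directly, and per the removed comment the padded vocabulary exists for FP8 support. The actual padded size is not shown in this diff; a hypothetical helper illustrating the usual round-up padding, where the multiple of 16 is an assumption rather than a value taken from this checkpoint:

    def pad_vocab_size(vocab_size: int, multiple: int = 16) -> int:
        # Hypothetical helper: round the vocabulary up so GEMM dimensions divide evenly.
        return ((vocab_size + multiple - 1) // multiple) * multiple

    print(pad_vocab_size(33))  # 48 under the multiple-of-16 assumption; vocab_size is 33 in config.json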
@@ -337,7 +309,6 @@ class NVEsmModel(NVEsmPreTrainedModel):
         input_ids: Optional[torch.Tensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.Tensor] = None,
-        head_mask: Optional[torch.Tensor] = None,
         inputs_embeds: Optional[torch.Tensor] = None,
         output_hidden_states: Optional[bool] = None,
         cu_seq_lens_q: torch.IntTensor | None = None,
@@ -351,7 +322,6 @@
             input_ids (torch.Tensor): The input ids.
             attention_mask (torch.Tensor): The attention mask.
             position_ids (torch.Tensor): The position ids.
-            head_mask (torch.Tensor): The head mask.
             inputs_embeds (torch.Tensor): The input embeddings.
             output_hidden_states (bool): Whether to output the hidden states.
             cu_seq_lens_q (torch.IntTensor): The cumulative sequence lengths for the query state, if using THD inputs.
@@ -389,18 +359,14 @@
         # TE expects a boolean attention mask, where 1s are masked and 0s are not masked
         extended_attention_mask = extended_attention_mask < -1
 
-        # Prepare head mask if needed
-        # 1.0 in head_mask indicate we keep the head
-        # attention_probs has shape bsz x n_heads x N x N
-        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
-        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
-        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
-
         embedding_output = self.embeddings(
             input_ids=input_ids,
-            position_ids=position_ids,
             attention_mask=attention_mask,
             inputs_embeds=inputs_embeds,
+            cu_seq_lens_q=cu_seq_lens_q,
+            cu_seq_lens_k=cu_seq_lens_k,
+            max_length_q=max_length_q,
+            max_length_k=max_length_k,
         )
         encoder_outputs = self.encoder(
             embedding_output,
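With these kwargs forwarded into the embeddings, a packed (THD) batch is driven entirely by the cumulative sequence lengths. A usage sketch; the repository id is a placeholder, the token ids are made up, and the config override via from_pretrained kwargs is an assumption about the custom config, but the keyword names match the signature in this diff:

    import torch
    from transformers import AutoModel

    model = AutoModel.from_pretrained(
        "<namespace>/<esm2-te-checkpoint>",  # placeholder repository id
        trust_remote_code=True,
        attn_input_format="thd",  # config override; the field name comes from esm_nv.py
    ).cuda()

    # Two sequences of lengths 6 and 4 packed into a single row of 10 tokens (THD layout).
    input_ids = torch.tensor([[0, 5, 6, 7, 8, 2, 0, 9, 10, 2]], device="cuda")
    cu_seq_lens = torch.tensor([0, 6, 10], dtype=torch.int32, device="cuda")

    outputs = model(
        input_ids=input_ids,
        cu_seq_lens_q=cu_seq_lens,
        cu_seq_lens_k=cu_seq_lens,
        max_length_q=6,
        max_length_k=6,
    )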
@@ -547,3 +513,89 @@ class NVEsmLMHead(nn.Module):
         x = torch.nn.functional.gelu(x)
         x = self.decoder(x)
         return x
+
+
+class NVEsmEmbeddings(nn.Module):
+    """Modified version of EsmEmbeddings to support THD inputs."""
+
+    def __init__(self, config):
+        """Initialize a NVEsmEmbeddings."""
+        super().__init__()
+        self.word_embeddings = nn.Embedding(
+            config.padded_vocab_size, config.hidden_size, padding_idx=config.pad_token_id
+        )
+
+        self.layer_norm = (
+            nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) if config.emb_layer_norm_before else None
+        )
+
+        if config.position_embedding_type != "rotary":
+            raise ValueError(
+                "The TE-accelerated ESM-2 model only supports rotary position embeddings, received "
+                f"{config.position_embedding_type}"
+            )
+
+        self.padding_idx = config.pad_token_id
+        self.token_dropout = config.token_dropout
+        self.mask_token_id = config.mask_token_id
+
+    def forward(
+        self,
+        input_ids=None,
+        attention_mask=None,
+        inputs_embeds=None,
+        cu_seq_lens_q: torch.IntTensor | None = None,
+        cu_seq_lens_k: torch.IntTensor | None = None,
+        max_length_q: int | None = None,
+        max_length_k: int | None = None,
+    ):
+        """Forward pass of the NVEsmEmbeddings."""
+        if inputs_embeds is None:
+            inputs_embeds = self.word_embeddings(input_ids)
+
+        # Note that if we want to support ESM-1 (not 1b!) in future then we need to support an
+        # embedding_scale factor here.
+        embeddings = inputs_embeds
+
+        if all(x is not None for x in [cu_seq_lens_q, cu_seq_lens_k, max_length_q, max_length_k]):
+            using_thd = True
+            attention_mask = None
+        else:
+            using_thd = False
+
+        # Matt: ESM has the option to handle masking in MLM in a slightly unusual way. If the token_dropout
+        # flag is False then it is handled in the same way as BERT/RoBERTa. If it is set to True, however,
+        # masked tokens are treated as if they were selected for input dropout and zeroed out.
+        # This "mask-dropout" is compensated for when masked tokens are not present, by scaling embeddings by
+        # a factor of (fraction of unmasked tokens during training) / (fraction of unmasked tokens in sample).
+        # This is analogous to the way that dropout layers scale down outputs during evaluation when not
+        # actually dropping out values (or, equivalently, scale up their un-dropped outputs in training).
+        if self.token_dropout and input_ids is not None:
+            embeddings = embeddings.masked_fill((input_ids == self.mask_token_id).unsqueeze(-1), 0.0)
+            mask_ratio_train = 0.15 * 0.8  # Hardcoded as the ratio used in all ESM model training runs
+
+            if not using_thd:
+                # BSHD token dropout correction
+                src_lengths = attention_mask.sum(-1) if attention_mask is not None else input_ids.shape[1]
+                n_masked_per_seq = (input_ids == self.mask_token_id).sum(-1).float()
+                mask_ratio_observed = n_masked_per_seq / src_lengths
+                scale_factor = (1 - mask_ratio_train) / (1 - mask_ratio_observed)
+                embeddings = (embeddings * scale_factor[:, None, None]).to(embeddings.dtype)
+
+            else:
+                src_lengths = torch.diff(cu_seq_lens_q)
+                # We need to find the number of masked tokens in each sequence in the packed batch.
+                is_masked = (input_ids == self.mask_token_id).squeeze(0)
+                n_masked_per_seq = torch.nested.nested_tensor_from_jagged(is_masked, offsets=cu_seq_lens_q).sum(1)
+                mask_ratio_observed = n_masked_per_seq.float() / src_lengths
+                scale_factor = (1 - mask_ratio_train) / (1 - mask_ratio_observed)
+                reshaped_scale_factor = torch.repeat_interleave(scale_factor, src_lengths, dim=0)
+                embeddings = (embeddings * reshaped_scale_factor.unsqueeze(-1)).to(embeddings.dtype)
+
+        if self.layer_norm is not None:
+            embeddings = self.layer_norm(embeddings)
+
+        if attention_mask is not None:
+            embeddings = (embeddings * attention_mask.unsqueeze(-1)).to(embeddings.dtype)
+
+        return embeddings
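The THD branch of the new embeddings module counts masked tokens per packed sequence through torch.nested.nested_tensor_from_jagged, then broadcasts one scale factor per token with repeat_interleave. The same bookkeeping in isolation, with a made-up token stream and a placeholder mask token id (requires a recent PyTorch with jagged nested-tensor support):

    import torch

    mask_token_id = 32  # placeholder; the real id comes from config.mask_token_id
    # Two packed sequences of lengths 4 and 3 flattened into one THD token stream.
    input_ids = torch.tensor([5, 32, 7, 8, 32, 32, 9])
    cu_seq_lens = torch.tensor([0, 4, 7])

    src_lengths = torch.diff(cu_seq_lens)                     # tensor([4, 3])
    is_masked = input_ids == mask_token_id
    n_masked = torch.nested.nested_tensor_from_jagged(is_masked, offsets=cu_seq_lens).sum(1)
    mask_ratio_observed = n_masked.float() / src_lengths      # tensor([0.2500, 0.6667])
    scale = (1 - 0.15 * 0.8) / (1 - mask_ratio_observed)      # mask_ratio_train = 0.15 * 0.8, as hardcoded in the model
    per_token_scale = torch.repeat_interleave(scale, src_lengths)  # one scale per packed token, length 7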