pstjohn committed
Commit cf74082 · verified · 1 Parent(s): e5c4b7d

Upload folder using huggingface_hub
config.json CHANGED
@@ -29,6 +29,7 @@
    "num_attention_heads": 40,
    "num_hidden_layers": 36,
    "pad_token_id": 1,
+   "padded_vocab_size": 64,
    "position_embedding_type": "rotary",
    "qkv_weight_interleaved": true,
    "token_dropout": true,
esm_nv.py CHANGED
@@ -23,7 +23,7 @@
  Adapted from `modeling_esm.py` in huggingface/transformers.
  """

- from typing import Optional, Tuple, Union
+ from typing import Literal, Optional

  # TODO: put import guard around transformer_engine here, with an informative error message around
  # installation and the nvidia docker container.
@@ -35,7 +35,6 @@ from transformer_engine.pytorch.attention.rope import RotaryPositionEmbedding
  from transformers.modeling_outputs import (
      BaseModelOutput,
      BaseModelOutputWithPooling,
-     BaseModelOutputWithPoolingAndCrossAttentions,
      MaskedLMOutput,
  )
  from transformers.modeling_utils import PreTrainedModel
@@ -56,10 +55,11 @@ class NVEsmConfig(EsmConfig):
          self,
          qkv_weight_interleaved: bool = True,
          encoder_activation: str = "gelu",
-         attn_input_format: str = "bshd",
+         attn_input_format: Literal["bshd", "thd"] = "bshd",
          fuse_qkv_params: bool = True,
          micro_batch_size: Optional[int] = None,
          max_seq_length: Optional[int] = None,
+         padded_vocab_size: Optional[int] = 64,
          **kwargs,
      ):
          """Initialize the NVEsmConfig with additional TE-related config options.
@@ -87,6 +87,8 @@ class NVEsmConfig(EsmConfig):
              max_seq_length: The maximum sequence length to use for the attention. This is needed for
                  JIT Warmup, a technique where jit fused functions are warmed up before training to
                  ensure the same kernels are used for forward propagation and the activation recompute phase.
+             padded_vocab_size: The padded vocabulary size to support FP8. If not provided, defaults
+                 to vocab_size. Must be greater than or equal to vocab_size.
              **kwargs: Additional config options to pass to EsmConfig.
          """
          super().__init__(**kwargs)
@@ -98,6 +100,15 @@ class NVEsmConfig(EsmConfig):
          self.micro_batch_size = micro_batch_size
          self.max_seq_length = max_seq_length

+         # Set padded_vocab_size with default fallback to vocab_size
+         self.padded_vocab_size = padded_vocab_size if padded_vocab_size is not None else self.vocab_size
+
+         # Ensure padded_vocab_size is at least as large as vocab_size
+         if self.padded_vocab_size is not None and self.vocab_size is not None:
+             assert self.padded_vocab_size >= self.vocab_size, (
+                 f"padded_vocab_size ({self.padded_vocab_size}) must be greater than or equal to vocab_size ({self.vocab_size})"
+             )
+

  class NVEsmEncoder(nn.Module):
      """NVEsmEncoder is a TransformerEngine-optimized ESM encoder."""
@@ -138,15 +149,26 @@ class NVEsmEncoder(nn.Module):
          self.emb_layer_norm_after = transformer_engine.pytorch.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
          if config.position_embedding_type == "rotary":
              self.rotary_embeddings = RotaryPositionEmbedding(config.hidden_size // config.num_attention_heads)
-             self.te_rope_emb = self.rotary_embeddings(max_seq_len=config.max_position_embeddings)
-         else:
-             self.te_rope_emb = None
+
+             # Keep on CPU, pin for faster non_blocking H2D; don't persist in state_dict.
+             if config.attn_input_format == "bshd":
+                 self.register_buffer(
+                     "te_rope_emb",
+                     self.rotary_embeddings(max_seq_len=config.max_position_embeddings).cpu().pin_memory(),
+                     persistent=False,
+                 )
+             else:
+                 self.te_rope_emb = None

      def forward(
          self,
          hidden_states: torch.Tensor,
          attention_mask: Optional[torch.Tensor] = None,
          output_hidden_states: bool = False,
+         cu_seq_lens_q: torch.IntTensor | None = None,
+         cu_seq_lens_k: torch.IntTensor | None = None,
+         max_length_q: int | None = None,
+         max_length_k: int | None = None,
      ):
          """Forward pass of the NVEsmEncoder.

@@ -154,14 +176,51 @@ class NVEsmEncoder(nn.Module):
              hidden_states (torch.Tensor): The hidden states.
              attention_mask (torch.Tensor): The attention mask.
              output_hidden_states (bool): Whether to output the hidden states.
+             cu_seq_lens_q (torch.IntTensor): The cumulative sequence lengths for the query state, if using THD inputs.
+             cu_seq_lens_k (torch.IntTensor): The cumulative sequence lengths for the key state, if using THD inputs.
+             max_length_q (int): The maximum length for the query state, if using THD inputs.
+             max_length_k (int): The maximum length for the key state, if using THD inputs.
          """
-         all_hidden_states = () if output_hidden_states else None
+         all_hidden_states: tuple[torch.Tensor, ...] = ()

-         if self.te_rope_emb is not None:
-             te_rope_emb = self.te_rope_emb.to(hidden_states.device, non_blocking=True)
-             te_rope_emb = te_rope_emb[: hidden_states.shape[1]]
-         else:
-             te_rope_emb = None
+         if self.config.attn_input_format == "thd":
+             if any(x is None for x in [cu_seq_lens_q, cu_seq_lens_k, max_length_q, max_length_k]):
+                 raise ValueError(
+                     "cu_seq_lens_q, cu_seq_lens_k, max_length_q, and max_length_k must be provided when using THD inputs."
+                 )
+             assert hidden_states.dim() == 3 and hidden_states.size(0) == 1, (
+                 "THD expects embeddings shaped [1, total_tokens, hidden_size]."
+             )
+             hidden_states = hidden_states.squeeze(0)
+
+         elif self.config.attn_input_format == "bshd":
+             if any(x is not None for x in [cu_seq_lens_q, cu_seq_lens_k, max_length_q, max_length_k]):
+                 raise ValueError(
+                     "cu_seq_lens_q, cu_seq_lens_k, max_length_q, and max_length_k are not allowed when using BSHD inputs."
+                 )
+
+         te_rope_emb = None
+         if self.config.position_embedding_type == "rotary":
+             if self.config.attn_input_format == "bshd":
+                 te_rope_emb = self.te_rope_emb.to(
+                     device=hidden_states.device, dtype=hidden_states.dtype, non_blocking=True
+                 )
+                 seq_len = hidden_states.shape[1]
+                 if te_rope_emb.size(0) < seq_len:
+                     raise RuntimeError(
+                         f"ROPE length {te_rope_emb.size(0)} < input seq length {seq_len}. "
+                         f"Increase max_position_embeddings."
+                     )
+                 te_rope_emb = te_rope_emb[:seq_len]
+
+             elif self.config.attn_input_format == "thd":
+                 assert cu_seq_lens_q is not None
+                 te_rope_emb = self.rotary_embeddings(max_seq_len=cu_seq_lens_q[-1]).to(
+                     device=hidden_states.device, dtype=hidden_states.dtype, non_blocking=True
+                 )
+
+             else:
+                 raise ValueError(f"Unsupported attention input format: {self.config.attn_input_format}")

          for layer_module in self.layers:
              if output_hidden_states:
@@ -171,6 +230,10 @@ class NVEsmEncoder(nn.Module):
                  hidden_states,
                  attention_mask,
                  rotary_pos_emb=te_rope_emb,
+                 cu_seqlens_q=cu_seq_lens_q,
+                 cu_seqlens_kv=cu_seq_lens_k,
+                 max_seqlen_q=max_length_q,
+                 max_seqlen_kv=max_length_k,
              )

          hidden_states = self.emb_layer_norm_after(hidden_states)
@@ -180,7 +243,7 @@ class NVEsmEncoder(nn.Module):

          return BaseModelOutput(
              last_hidden_state=hidden_states,
-             hidden_states=all_hidden_states,
+             hidden_states=all_hidden_states if all_hidden_states else None,
          )


@@ -239,7 +302,15 @@ class NVEsmModel(NVEsmPreTrainedModel):
          super().__init__(config)
          self.config = config

+         # Create EsmEmbeddings with temporarily modified config to use padded vocab size
+         # This ensures the word embeddings layer uses the padded vocabulary size for FP8 support
+         original_vocab_size = config.vocab_size
+         config.vocab_size = config.padded_vocab_size
+         # Ensure pad_token_id is set properly, defaulting to 0 if not specified
+         if not hasattr(config, "pad_token_id") or config.pad_token_id is None:
+             config.pad_token_id = 0
          self.embeddings = EsmEmbeddings(config)
+         config.vocab_size = original_vocab_size  # Restore original vocab_size
          self.encoder = NVEsmEncoder(config)
          self.pooler = EsmPooler(config) if add_pooling_layer else None

@@ -266,7 +337,11 @@ class NVEsmModel(NVEsmPreTrainedModel):
          head_mask: Optional[torch.Tensor] = None,
          inputs_embeds: Optional[torch.Tensor] = None,
          output_hidden_states: Optional[bool] = None,
-     ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]:
+         cu_seq_lens_q: torch.IntTensor | None = None,
+         cu_seq_lens_k: torch.IntTensor | None = None,
+         max_length_q: int | None = None,
+         max_length_k: int | None = None,
+     ) -> BaseModelOutputWithPooling:
          """Forward pass of the NVEsmModel.

          Args:
@@ -276,25 +351,14 @@ class NVEsmModel(NVEsmPreTrainedModel):
              head_mask (torch.Tensor): The head mask.
              inputs_embeds (torch.Tensor): The input embeddings.
              output_hidden_states (bool): Whether to output the hidden states.
+             cu_seq_lens_q (torch.IntTensor): The cumulative sequence lengths for the query state, if using THD inputs.
+             cu_seq_lens_k (torch.IntTensor): The cumulative sequence lengths for the key state, if using THD inputs.
+             max_length_q (int): The maximum length for the query state, if using THD inputs.
+             max_length_k (int): The maximum length for the key state, if using THD inputs.

          Returns:
              BaseModelOutputWithPooling: The output of the model.
          """
-         r"""
-         encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
-             Sequence of hidden-states at the output of the last layer of the encoder. Used in the
-             cross-attention if the model is configured as a decoder.
-         encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
-             Mask to avoid performing attention on the padding token indices of the encoder input.
-             This mask is used in the cross-attention if the model is configured as a decoder. Mask
-             values selected in `[0, 1]`:
-
-             - 1 for tokens that are **not masked**,
-             - 0 for tokens that are **masked**.
-
-             Note that this mask is inverted when it is passed to TransformerEngine, which expects a
-             boolean mask where 1s are masked and 0s are not masked.
-         """
          output_hidden_states = (
              output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
          )
@@ -339,6 +403,10 @@ class NVEsmModel(NVEsmPreTrainedModel):
              embedding_output,
              attention_mask=extended_attention_mask,
              output_hidden_states=output_hidden_states,
+             cu_seq_lens_q=cu_seq_lens_q,
+             cu_seq_lens_k=cu_seq_lens_k,
+             max_length_q=max_length_q,
+             max_length_k=max_length_k,
          )
          sequence_output = encoder_outputs[0]
          pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
@@ -391,7 +459,11 @@ class NVEsmForMaskedLM(NVEsmPreTrainedModel):
          inputs_embeds: Optional[torch.FloatTensor] = None,
          labels: Optional[torch.LongTensor] = None,
          output_hidden_states: Optional[bool] = None,
-     ) -> Union[Tuple, MaskedLMOutput]:
+         cu_seq_lens_q: torch.IntTensor | None = None,
+         cu_seq_lens_k: torch.IntTensor | None = None,
+         max_length_q: int | None = None,
+         max_length_k: int | None = None,
+     ) -> MaskedLMOutput:
          """Forward pass of the NVEsmForMaskedLM.

          Args:
@@ -401,34 +473,39 @@ class NVEsmForMaskedLM(NVEsmPreTrainedModel):
              inputs_embeds (torch.FloatTensor): The input embeddings.
              labels (torch.LongTensor): The labels.
              output_hidden_states (bool): Whether to output the hidden states.
+             cu_seq_lens_q (torch.IntTensor): The cumulative sequence lengths for the query state, if using THD inputs.
+             cu_seq_lens_k (torch.IntTensor): The cumulative sequence lengths for the key state, if using THD inputs.
+             max_length_q (int): The maximum length for the query state, if using THD inputs.
+             max_length_k (int): The maximum length for the key state, if using THD inputs.

          Returns:
              MaskedLMOutput: The output of the model.
          """
-         r"""
-         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
-             Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
-             config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the
-             loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
-         kwargs (`Dict[str, any]`, *optional*, defaults to `{}`):
-             Used to hide legacy arguments that have been deprecated.
-         """
          outputs = self.esm(
              input_ids,
              attention_mask=attention_mask,
              position_ids=position_ids,
              inputs_embeds=inputs_embeds,
              output_hidden_states=output_hidden_states,
+             cu_seq_lens_q=cu_seq_lens_q,
+             cu_seq_lens_k=cu_seq_lens_k,
+             max_length_q=max_length_q,
+             max_length_k=max_length_k,
          )
          sequence_output = outputs[0]
          prediction_scores = self.lm_head(sequence_output)

+         # Truncate logits back to original vocab_size if padding was used
+         if self.config.padded_vocab_size != self.config.vocab_size:
+             prediction_scores = prediction_scores[..., : self.config.vocab_size]
+
          masked_lm_loss = None
          if labels is not None:
              loss_fct = CrossEntropyLoss()
-
-             labels = labels.to(prediction_scores.device)
-             masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
+             masked_lm_loss = loss_fct(
+                 prediction_scores.view(-1, self.config.vocab_size),
+                 labels.to(prediction_scores.device).view(-1),
+             )

          return MaskedLMOutput(
              loss=masked_lm_loss,
@@ -436,18 +513,6 @@ class NVEsmForMaskedLM(NVEsmPreTrainedModel):
              hidden_states=outputs.hidden_states,
          )

-     def predict_contacts(self, tokens: torch.Tensor, attention_mask: torch.Tensor):
-         """Predict the contacts of the model.
-
-         Args:
-             tokens (torch.Tensor): The tokens.
-             attention_mask (torch.Tensor): The attention mask.
-
-         Returns:
-             torch.Tensor: The predicted contacts.
-         """
-         return self.esm.predict_contacts(tokens, attention_mask=attention_mask)
-

  class NVEsmLMHead(nn.Module):
      """ESM Head for masked language modeling using TransformerEngine."""
@@ -463,7 +528,7 @@ class NVEsmLMHead(nn.Module):

          self.decoder = transformer_engine.pytorch.LayerNormLinear(
              config.hidden_size,
-             config.vocab_size,
+             config.padded_vocab_size if config.padded_vocab_size is not None else config.vocab_size,
              bias=True,
              eps=config.layer_norm_eps,
          )
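
Together, the new cu_seq_lens_q / cu_seq_lens_k / max_length_q / max_length_k arguments expose Transformer Engine's packed "THD" attention path, where multiple sequences are concatenated into a single row of tokens. A hedged usage sketch, not taken from this repository: it assumes a placeholder repo id, that attn_input_format="thd" can be overridden through from_pretrained kwargs, and that no attention_mask is needed for packed sequences; cu_seq_lens are the cumulative per-sequence token counts.

# Hedged usage sketch (assumptions noted above); not from this repo.
import torch
from transformers import AutoModelForMaskedLM, AutoTokenizer

repo_id = "<this-model-repo>"  # placeholder, not a real id
tokenizer = AutoTokenizer.from_pretrained(repo_id)
model = AutoModelForMaskedLM.from_pretrained(repo_id, trust_remote_code=True, attn_input_format="thd")

seqs = ["MKTAYIAKQR", "MLELLPTAVEGVSQAQITGRPEWIW"]
ids = [tokenizer(s, return_tensors="pt")["input_ids"][0] for s in seqs]
lengths = torch.tensor([len(x) for x in ids], dtype=torch.int32)

# Pack all tokens into one row: shape [1, total_tokens]
input_ids = torch.cat(ids).unsqueeze(0)
cu_seq_lens = torch.cat([torch.zeros(1, dtype=torch.int32), torch.cumsum(lengths, 0).to(torch.int32)])

out = model(
    input_ids,
    cu_seq_lens_q=cu_seq_lens,
    cu_seq_lens_k=cu_seq_lens,
    max_length_q=int(lengths.max()),
    max_length_k=int(lengths.max()),
)
print(out.logits.shape)  # logits are truncated back to the original (unpadded) vocab size
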
model-00001-of-00003.safetensors CHANGED
@@ -1,3 +1,3 @@
  version https://git-lfs.github.com/spec/v1
- oid sha256:556b6aa9e921bf1caf2826232e2f1dd7bf40e62b3f07275fc37880c97e3c2745
- size 4930807200
+ oid sha256:7e8931ca9ef4b515e3ad09daca46bec50a98bc580e2a8a87dcd2a73a8ebbc0a6
+ size 4931124640
model-00003-of-00003.safetensors CHANGED
@@ -1,3 +1,3 @@
  version https://git-lfs.github.com/spec/v1
- oid sha256:af3f10ee45ecb126c2b94a0c06688d63a3f3f2249cee1ea4e5e967ba9e70c850
- size 1494863401
+ oid sha256:cd83d7ac3fe52dde4fc414511ee0babb770422a19b26a6f83b6f4ab39a6e0114
+ size 1494863525
model.safetensors.index.json CHANGED
@@ -1,7 +1,7 @@
  {
    "metadata": {
-     "total_parameters": 2839004193,
-     "total_size": 11356016905
+     "total_parameters": 2839083584,
+     "total_size": 11356334469
    },
    "weight_map": {
      "esm.embeddings.word_embeddings.weight": "model-00001-of-00003.safetensors",