Upload FlaxTransformerLMForCausalLM
Browse files
modeling_transformerlm_flax.py
CHANGED
|
@@ -426,6 +426,7 @@ class FlaxTransformerLMPreTrainedModel(FlaxPreTrainedModel):
|
|
| 426 |
mutable=mutable,
|
| 427 |
)
|
| 428 |
lm_logits = output.logits
|
|
|
|
| 429 |
if input_ids.shape[1] > 1:
|
| 430 |
lm_logits = lm_logits[:, 1:, :] # Ignore leading zeros in prompts
|
| 431 |
|
|
|
|
| 426 |
mutable=mutable,
|
| 427 |
)
|
| 428 |
lm_logits = output.logits
|
| 429 |
+
|
| 430 |
if input_ids.shape[1] > 1:
|
| 431 |
lm_logits = lm_logits[:, 1:, :] # Ignore leading zeros in prompts
|
| 432 |
|