Send attention_mask to device (#9)
Browse files- Send attention_mask to device (76fb033a3c0c4b1d6764fcf69699ba60f6ad942e)
Co-authored-by: Michael Verrilli <[email protected]>
- instruct_pipeline.py +1 -1
instruct_pipeline.py
CHANGED
|
@@ -131,7 +131,7 @@ class InstructionTextGenerationPipeline(Pipeline):
|
|
| 131 |
|
| 132 |
generated_sequence = self.model.generate(
|
| 133 |
input_ids=input_ids.to(self.model.device),
|
| 134 |
-
attention_mask=attention_mask,
|
| 135 |
pad_token_id=self.tokenizer.pad_token_id,
|
| 136 |
**generate_kwargs,
|
| 137 |
)
|
|
|
|
| 131 |
|
| 132 |
generated_sequence = self.model.generate(
|
| 133 |
input_ids=input_ids.to(self.model.device),
|
| 134 |
+
attention_mask=attention_mask.to(self.model.device) if attention_mask is not None else None,
|
| 135 |
pad_token_id=self.tokenizer.pad_token_id,
|
| 136 |
**generate_kwargs,
|
| 137 |
)
|