fix: make sure data and adapter on same device (#11)
Browse files- fix: make sure data and adapter on same device (08577bc2e88cb6d2e7ffa9fb2c45ba7c16c02836)
- custom_st.py +1 -2
custom_st.py
CHANGED
|
@@ -55,7 +55,6 @@ class Transformer(nn.Module):
|
|
| 55 |
|
| 56 |
config = AutoConfig.from_pretrained(model_name_or_path, **config_args, cache_dir=cache_dir)
|
| 57 |
self.auto_model = AutoModel.from_pretrained(model_name_or_path, config=config, cache_dir=cache_dir, **model_args)
|
| 58 |
-
self.device = next(self.auto_model.parameters()).device
|
| 59 |
|
| 60 |
self._lora_adaptations = config.lora_adaptations
|
| 61 |
if (
|
|
@@ -111,7 +110,7 @@ class Transformer(nn.Module):
|
|
| 111 |
num_examples = len(features['input_ids'])
|
| 112 |
|
| 113 |
adapter_mask = torch.full(
|
| 114 |
-
(num_examples,), task_id, dtype=torch.int32, device=
|
| 115 |
)
|
| 116 |
|
| 117 |
lora_arguments = (
|
|
|
|
| 55 |
|
| 56 |
config = AutoConfig.from_pretrained(model_name_or_path, **config_args, cache_dir=cache_dir)
|
| 57 |
self.auto_model = AutoModel.from_pretrained(model_name_or_path, config=config, cache_dir=cache_dir, **model_args)
|
|
|
|
| 58 |
|
| 59 |
self._lora_adaptations = config.lora_adaptations
|
| 60 |
if (
|
|
|
|
| 110 |
num_examples = len(features['input_ids'])
|
| 111 |
|
| 112 |
adapter_mask = torch.full(
|
| 113 |
+
(num_examples,), task_id, dtype=torch.int32, device=features['input_ids'].device
|
| 114 |
)
|
| 115 |
|
| 116 |
lora_arguments = (
|