Update custom_st.py
custom_st.py (+5 -2)
@@ -51,8 +51,8 @@ class Transformer(nn.Module):
         if config_args is None:
             config_args = {}
 
+
         self.config = AutoConfig.from_pretrained(model_name_or_path, **config_args, cache_dir=cache_dir)
-        self.auto_model = AutoModel.from_pretrained(model_name_or_path, config=self.config, cache_dir=cache_dir, **model_args)
 
         self._lora_adaptations = self.config.lora_adaptations
         if (
@@ -65,7 +65,10 @@ class Transformer(nn.Module):
         self._adaptation_map = {
             name: idx for idx, name in enumerate(self._lora_adaptations)
         }
-
+
+        self.default_task = model_args.pop('default_task', None)
+
+        self.auto_model = AutoModel.from_pretrained(model_name_or_path, config=self.config, cache_dir=cache_dir, **model_args)
 
         if max_seq_length is not None and "model_max_length" not in tokenizer_args:
             tokenizer_args["model_max_length"] = max_seq_length
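In effect, the commit defers `AutoModel.from_pretrained` until after `default_task` has been popped out of `model_args`, so the key is stored on the module as `self.default_task` instead of being forwarded as an unexpected keyword argument to `from_pretrained`. A minimal usage sketch, assuming this `custom_st.py` is shipped as a remote-code module on the Hugging Face Hub and loaded through Sentence Transformers' `model_kwargs` passthrough; the repo id and task name below are placeholders, not taken from this commit:

    from sentence_transformers import SentenceTransformer

    # Placeholder repo id and task name, for illustration only.
    model = SentenceTransformer(
        "org/model-with-custom-st",
        trust_remote_code=True,                            # load the hub's custom Transformer class
        model_kwargs={"default_task": "retrieval.query"},  # forwarded into model_args above
    )

    # With this change, Transformer.__init__ pops 'default_task' from model_args
    # before calling AutoModel.from_pretrained(..., **model_args), so
    # from_pretrained no longer receives the unknown 'default_task' keyword.
    embeddings = model.encode(["example sentence"])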