Update tokenization_dart.py
Browse files — tokenization_dart.py (+1, −23)
tokenization_dart.py
CHANGED
|
@@ -1,7 +1,5 @@
|
|
| 1 |
import logging
|
| 2 |
-
import
|
| 3 |
-
from typing import Dict, List
|
| 4 |
-
from pydantic.dataclasses import dataclass
|
| 5 |
|
| 6 |
from transformers import PreTrainedTokenizerFast
|
| 7 |
from tokenizers.decoders import Decoder
|
|
@@ -57,26 +55,6 @@ PROMPT_TEMPLATE = (
|
|
| 57 |
# fmt: on
|
| 58 |
|
| 59 |
|
| 60 |
-
@dataclass
class Category:
    """A tag category together with its delimiting special-token ids.

    NOTE(review): ``dataclass`` here is pydantic's (see the
    ``from pydantic.dataclasses import dataclass`` import), so the
    generated constructor validates field types.
    """

    # Category name; presumably the key used in TagCategoryConfig.categories — confirm against the config JSON.
    name: str
    # Token id marking the beginning of this category's span (inferred from the "bos" naming — verify).
    bos_token_id: int
    # Token id marking the end of this category's span (inferred from the "eos" naming — verify).
    eos_token_id: int
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
@dataclass
class TagCategoryConfig:
    """Configuration describing tag categories and their token-id membership.

    NOTE(review): built directly from a JSON payload via
    ``TagCategoryConfig(**json.loads(...))``, so the JSON object's top-level
    keys must be exactly ``categories`` and ``category_to_token_ids``.
    """

    # Category metadata keyed by string — presumably the category name or id; confirm against the config JSON.
    categories: Dict[str, Category]
    # For each category key, the list of token ids belonging to that category (per the field name).
    category_to_token_ids: Dict[str, List[int]]
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
def load_tag_category_config(config_json: str) -> TagCategoryConfig:
    """Load a :class:`TagCategoryConfig` from a JSON file.

    Args:
        config_json: Path to the JSON configuration file.

    Returns:
        The parsed configuration. Field validation is performed by the
        pydantic dataclass constructor.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the file contents are not valid JSON.
    """
    with open(config_json, "rb") as file:
        # json.load parses straight from the file object — idiomatic and
        # avoids materializing the whole payload via file.read() first.
        config = TagCategoryConfig(**json.load(file))
    return config
|
| 78 |
-
|
| 79 |
-
|
| 80 |
class DartDecoder:
|
| 81 |
def __init__(self, special_tokens: List[str]):
|
| 82 |
self.special_tokens = list(special_tokens)
|
|
|
|
| 1 |
import logging
|
| 2 |
+
from typing import List
|
|
|
|
|
|
|
| 3 |
|
| 4 |
from transformers import PreTrainedTokenizerFast
|
| 5 |
from tokenizers.decoders import Decoder
|
|
|
|
| 55 |
# fmt: on
|
| 56 |
|
| 57 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 58 |
class DartDecoder:
|
| 59 |
def __init__(self, special_tokens: List[str]):
|
| 60 |
self.special_tokens = list(special_tokens)
|