from typing import Any, Dict, List

import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

MAX_TOKENS_IN_BATCH = 4_000  # Hard limit to prevent OOMs
DEFAULT_MAX_NEW_TOKENS = 10  # By default limit the output to 10 tokens


class EndpointHandler:
    """
    Handles inference, including pre- and post-processing, for text2text
    models. See
    https://huggingface.co/docs/inference-endpoints/guides/custom_handler for
    more details.
    """

    def __init__(self, path: str = ""):
        try:
            self.tokenizer = AutoTokenizer.from_pretrained(path)
            self.model = AutoModelForSeq2SeqLM.from_pretrained(path, device_map="auto")
        except Exception:
            # Log the accelerate version before re-raising, since loading with
            # device_map="auto" depends on it.
            import accelerate

            print(f"ACCELERATE VERSION: {accelerate.__version__}")
            raise

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """
        This method is called when the endpoint receives a request.

        Arguments
        ---------
        data (Dict[str, Any]):
            Must contain the input texts under the `inputs` key and any
            parameters for the inference under `parameters`.

        Returns
        -------
        output (List[Dict[str, Any]]):
            A list with one entry per generated output, where each item is a
            dictionary containing `generated_text`, `perplexity` and, if
            `check_first_tokens` was passed, `first_token_probs`.
        """
        input_texts = data["inputs"]
        generate_kwargs = data.get("parameters", {})
        # Not technically a generate kwarg, but it needs to live under
        # `parameters` in the request payload.
        check_first_tokens = generate_kwargs.pop("check_first_tokens", None)
        max_new_tokens = (
            generate_kwargs.pop("max_new_tokens", None) or DEFAULT_MAX_NEW_TOKENS
        )
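        # Illustrative request payload for this handler (the concrete values
        # below are made up):
        #   {
        #       "inputs": ["First input text", "Second input text"],
        #       "parameters": {"max_new_tokens": 5, "check_first_tokens": ["yes", "no"]},
        #   }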

        # Tokenize the input texts
        inputs = self.tokenizer(
            input_texts, return_tensors="pt", padding=True, truncation=True,
        )["input_ids"]

        # Make sure not to OOM if too many inputs are passed at once
        assert inputs.dim() == 2, f"Inputs have dimension {inputs.dim()} != 2"
        total_tokens = inputs.shape[0] * (inputs.shape[1] + max_new_tokens - 1)
        assert total_tokens <= MAX_TOKENS_IN_BATCH, (
            f"Passed {total_tokens} tokens (shape: {inputs.shape}, "
            f"max_new_tokens: {max_new_tokens}), which is greater than the "
            f"limit of {MAX_TOKENS_IN_BATCH}"
        )

        # Run inference on GPU
        inputs = inputs.to("cuda:0")
        with torch.no_grad():
            outputs = self.model.generate(
                inputs,
                output_scores=True,
                return_dict_in_generate=True,
                max_new_tokens=max_new_tokens,
                **generate_kwargs,
            )

        # Move tensors back to CPU and drop the generation output early
        inputs = inputs.to("cpu")
        scores = [s.to("cpu") for s in outputs.scores]
        del outputs

        # Process outputs
        to_return: Dict[str, Any] = {
            "generated_text": self._output_text_from_scores(scores),
            "perplexity": [float(p) for p in self._perplexity(scores)],
        }
        if check_first_tokens:
            to_return["first_token_probs"] = self._get_first_token_probs(
                check_first_tokens, scores
            )

        # Reformat output to conform to the HF pipeline format
        return [
            {key: to_return[key][ndx] for key in to_return}
            for ndx in range(len(to_return["generated_text"]))
        ]

    def _output_text_from_scores(self, scores: List[torch.Tensor]) -> List[str]:
        """
        Returns the decoded text from the scores.

        TODO (ENG-20823): Use the returned sequences so we pay attention to
        things like bad_words, force_words etc.
        """
        # Greedily pick the highest-scoring token at each generation step,
        # giving one list of token ids per sequence in the batch.
        batch_token_ids = [
            [score[ndx].argmax() for score in scores]
            for ndx in range(scores[0].shape[0])
        ]
        # Fix for new tokens being generated after EOS: truncate each sequence
        # at the first EOS token (if there is no EOS, drop the final token).
        new_batch_token_ids = []
        for token_ids in batch_token_ids:
            try:
                new_token_ids = token_ids[
                    : token_ids.index(self.tokenizer.eos_token_id)
                ]
            except ValueError:
                new_token_ids = token_ids[:-1]
            new_batch_token_ids.append(new_token_ids)
        return self.tokenizer.batch_decode(new_batch_token_ids)

    def _perplexity(self, scores: List[torch.Tensor]) -> List[float]:
        """
        Returns the perplexity (model confidence) of the outputted text:
        e^(sum(ln(p(word))) / N), i.e. the geometric mean of the top-token
        probabilities across the N generation steps.

        TODO (ENG-20823): don't include the trailing pad tokens in perplexity
        """
        return torch.exp(
            torch.stack(
                [score.softmax(axis=1).log().max(axis=1)[0] for score in scores]
            ).sum(axis=0)
            / len(scores)
        ).tolist()

    def _get_first_token_probs(
        self, tokens: List[str], scores: List[torch.Tensor]
    ) -> List[Dict[str, float]]:
        """
        Returns the softmaxed probability of each of the given tokens being the
        first generated token, for every output in the batch.
        """
        first_token_probs = []
        softmaxed_scores = scores[0].softmax(axis=1)

        # Finding the correct token IDs
        # TODO (ENG-20824): Support multi-token words
        token_ids = {}
        for token in tokens:
            encoded_token: List[int] = self.tokenizer.encode(token)
            if len(encoded_token) > 2:
                # The tokenizer broke the token up into multiple parts, which
                # is not supported yet; mark it with a sentinel value.
                token_ids[token] = -1
            else:
                token_ids[token] = encoded_token[0]

        # Now finding the scores for each token in the list
        for seq_ndx in range(scores[0].shape[0]):
            curr_token_probs: Dict[str, float] = {}
            for token in tokens:
                if token_ids[token] == -1:
                    # Multi-token words get a probability of 0 for now
                    curr_token_probs[token] = 0.0
                else:
                    curr_token_probs[token] = float(
                        softmaxed_scores[seq_ndx, token_ids[token]]
                    )
            first_token_probs.append(curr_token_probs)
        return first_token_probs
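

if __name__ == "__main__":
    # Minimal local smoke test, not used by Inference Endpoints itself. It
    # assumes a seq2seq checkpoint is available at the (hypothetical) path
    # below and that a CUDA device is present, since __call__ moves inputs to
    # "cuda:0".
    handler = EndpointHandler(path="./model")  # hypothetical local checkpoint
    example_request = {
        "inputs": ["First example input", "Second example input"],
        "parameters": {"max_new_tokens": 5, "check_first_tokens": ["yes", "no"]},
    }
    for item in handler(example_request):
        print(item["generated_text"], item["perplexity"], item.get("first_token_probs"))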