from typing import Any, Dict, List

import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

# Rough per-request cap on token positions processed (input plus generated);
# enforced in __call__ as a guard against out-of-memory errors.
MAX_TOKENS_IN_BATCH = 4_000
# Used when the caller does not pass `max_new_tokens` under `parameters`.
DEFAULT_MAX_NEW_TOKENS = 10


class EndpointHandler:
    """
    Handles inference, with pre- and post-processing, for text2text models.
    See https://huggingface.co/docs/inference-endpoints/guides/custom_handler
    for more details.
    """

    def __init__(self, path: str = ""):
        try:
            self.tokenizer = AutoTokenizer.from_pretrained(path)
            self.model = AutoModelForSeq2SeqLM.from_pretrained(path, device_map="auto")
        except Exception:
            # device_map="auto" relies on accelerate, so log its version to
            # aid debugging before re-raising.
            import accelerate

            print(f"ACCELERATE VERSION: {accelerate.__version__}")
            raise

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """
        This method is called when the endpoint is called.

        Arguments
        ---------
        data (Dict[str, Any]):
            Must contain the input data under the `inputs` key and any
            parameters for the inference under `parameters`.

        Returns
        -------
        output (List[Dict[str, Any]]):
            A list with one item per generated sequence, where each item is a
            dictionary containing `generated_text` (i.e. the character),
            `perplexity` and, if `check_first_tokens` was passed,
            `first_token_probs`.
        """
        input_texts = data["inputs"]
        generate_kwargs = data.get("parameters", {})

        check_first_tokens = generate_kwargs.pop("check_first_tokens", None)
        max_new_tokens = (
            generate_kwargs.pop("max_new_tokens", None) or DEFAULT_MAX_NEW_TOKENS
        )

        inputs = self.tokenizer(
            input_texts, return_tensors="pt", padding=True, truncation=True,
        )["input_ids"]
        assert inputs.dim() == 2, f"Inputs have dimension {inputs.dim()} != 2"
        total_tokens = inputs.shape[0] * (inputs.shape[1] + max_new_tokens - 1)
        assert total_tokens <= MAX_TOKENS_IN_BATCH, (
            f"Passed {total_tokens} tokens (shape: {inputs.shape}, "
            f"max_new_tokens: {max_new_tokens}), which is greater than the "
            f"limit of {MAX_TOKENS_IN_BATCH}"
        )

        # Move inputs to the model's device rather than a hard-coded "cuda:0",
        # so the handler also works on CPU-only instances.
        inputs = inputs.to(self.model.device)
        with torch.no_grad():
            outputs = self.model.generate(
                inputs,
                output_scores=True,
                return_dict_in_generate=True,
                max_new_tokens=max_new_tokens,
                **generate_kwargs,
            )
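
        # Pull the per-step scores back to the CPU and free the rest of the
        # generate() output before post-processing.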
        inputs = inputs.to("cpu")
        scores = [s.to("cpu") for s in outputs.scores]
        del outputs

        to_return: Dict[str, Any] = {
            "generated_text": self._output_text_from_scores(scores),
            "perplexity": [float(p) for p in self._perplexity(scores)],
        }
        if check_first_tokens:
            to_return["first_token_probs"] = self._get_first_token_probs(
                check_first_tokens, scores
            )

        # Transpose the dict of per-field lists into one dict per sequence.
        return [
            {key: to_return[key][ndx] for key in to_return}
            for ndx in range(len(to_return["generated_text"]))
        ]

    def _output_text_from_scores(self, scores: List[torch.Tensor]) -> List[str]:
        """
        Returns the decoded text from the scores.

        TODO (ENG-20823): Use the returned sequences so we pay attention to
        things like bad_words, force_words etc.
        """
        # Greedy-decode each sequence by taking the argmax of each step's
        # scores (per the TODO above, this ignores any sampling or constraints
        # applied inside generate()).
        batch_token_ids = [
            [score[ndx].argmax() for score in scores]
            for ndx in range(scores[0].shape[0])
        ]
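
        # Truncate each sequence at its first EOS token; when no EOS was
        # generated, the final token is dropped instead.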
        new_batch_token_ids = []
        for token_ids in batch_token_ids:
            try:
                new_token_ids = token_ids[
                    : token_ids.index(self.tokenizer.eos_token_id)
                ]
            except ValueError:
                new_token_ids = token_ids[:-1]

            new_batch_token_ids.append(new_token_ids)
        return self.tokenizer.batch_decode(new_batch_token_ids)

    def _perplexity(self, scores: List[torch.Tensor]) -> List[float]:
        """
        Returns the perplexity (model confidence) of the outputted text:
        e^(sum(ln(p(token))) / N). Note this is the geometric mean of the
        greedy per-token probabilities (the reciprocal of conventional
        perplexity), so higher means more confident.

        TODO (ENG-20823): don't include the trailing pad tokens in perplexity
        """
        # log_softmax is the numerically stabler equivalent of
        # softmax(...).log().
        return torch.exp(
            torch.stack(
                [score.log_softmax(dim=1).max(dim=1)[0] for score in scores]
            ).sum(dim=0)
            / len(scores)
        ).tolist()

    def _get_first_token_probs(
        self, tokens: List[str], scores: List[torch.Tensor]
    ) -> List[Dict[str, float]]:
        """
        Returns the softmaxed probabilities of the given tokens for each
        output.
        """
        first_token_probs = []
        # scores[0] holds the logits for the first generated position of every
        # sequence in the batch.
        softmaxed_scores = scores[0].softmax(dim=1)
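
        # Map each requested token string to a single vocabulary id. Strings
        # that encode to more than one id (beyond a trailing EOS, which
        # seq2seq tokenizers typically append) can't be scored from a single
        # step, so they are flagged with -1.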
        token_ids = {}
        for token in tokens:
            encoded_token: List[int] = self.tokenizer.encode(token)
            if len(encoded_token) > 2:
                # Multi-token string: it has no single first-token probability.
                token_ids[token] = -1
            else:
                token_ids[token] = encoded_token[0]

        for seq_ndx in range(scores[0].shape[0]):
            curr_token_probs: Dict[str, float] = {}

            for token in tokens:
                if token_ids[token] == -1:
                    curr_token_probs[token] = 0.0
                else:
                    curr_token_probs[token] = float(
                        softmaxed_scores[seq_ndx, token_ids[token]]
                    )

            first_token_probs.append(curr_token_probs)

        return first_token_probs
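

# Minimal usage sketch (illustrative only; the checkpoint path, prompt, and
# token strings below are assumptions, not part of the handler contract):
#
#     handler = EndpointHandler(path="/repository")
#     results = handler(
#         {
#             "inputs": ["Translate English to German: Hello"],
#             "parameters": {"max_new_tokens": 8, "check_first_tokens": ["Hallo"]},
#         }
#     )
#     for item in results:
#         print(item["generated_text"], item["perplexity"])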