add handler.py for HF Dedicated Inference
#14
by jmbrito · opened
- handler.py +102 -0
- requirements.txt +2 -0
handler.py
ADDED
@@ -0,0 +1,102 @@
from transformers import NougatProcessor, VisionEncoderDecoderModel, StoppingCriteria, StoppingCriteriaList
import torch
import io
import base64
from PIL import Image
from typing import Dict, Any
from collections import defaultdict

class RunningVarTorch:
    """Tracks a sliding window of the last `L` pushed 1-D tensors and reports
    their variance along the window dimension (optionally normalized by the
    window length)."""

    def __init__(self, L=15, norm=False):
        self.values = None
        self.L = L
        self.norm = norm

    def push(self, x: torch.Tensor):
        assert x.dim() == 1
        if self.values is None:
            self.values = x[:, None]
        elif self.values.shape[1] < self.L:
            self.values = torch.cat((self.values, x[:, None]), 1)
        else:
            # Window is full: drop the oldest column before appending.
            self.values = torch.cat((self.values[:, 1:], x[:, None]), 1)

    def variance(self):
        if self.values is None:
            return
        if self.norm:
            return torch.var(self.values, 1) / self.values.shape[1]
        else:
            return torch.var(self.values, 1)

class StoppingCriteriaScores(StoppingCriteria):
    """Custom stopping criterion from the Nougat authors: stop generating once
    the variance of the running max-score variance stays below `threshold`,
    which indicates the model has drifted into repetition."""

    def __init__(self, threshold: float = 0.015, window_size: int = 200):
        super().__init__()
        self.threshold = threshold
        self.vars = RunningVarTorch(norm=True)
        self.varvars = RunningVarTorch(L=window_size)
        self.stop_inds = defaultdict(int)
        self.stopped = defaultdict(bool)
        self.size = 0
        self.window_size = window_size

    @torch.no_grad()
    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor):
        # Track the variance of the per-step maximum scores for each batch element.
        last_scores = scores[-1]
        self.vars.push(last_scores.max(1)[0].float().cpu())
        self.varvars.push(self.vars.variance())
        self.size += 1
        if self.size < self.window_size:
            return False

        varvar = self.varvars.variance()
        for b in range(len(last_scores)):
            if varvar[b] < self.threshold:
                if self.stop_inds[b] > 0 and not self.stopped[b]:
                    self.stopped[b] = self.stop_inds[b] >= self.size
                else:
                    self.stop_inds[b] = int(
                        min(max(self.size, 1) * 1.15 + 150 + self.window_size, 4095)
                    )
            else:
                self.stop_inds[b] = 0
                self.stopped[b] = False
        return all(self.stopped.values()) and len(self.stopped) > 0

class EndpointHandler:
    def __init__(self, path="facebook/nougat-base"):
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.processor = NougatProcessor.from_pretrained(path)
        self.model = VisionEncoderDecoderModel.from_pretrained(path)
        self.model = self.model.to(self.device)

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Args:
            data (Dict): The payload with the base64-encoded page image under
                "inputs", optional generation parameters under "parameters",
                and an optional "fix_markdown" flag for post-processing.
        """
        # Get inputs
        inputs = data.pop("inputs", None)
        parameters = data.pop("parameters", None) or {}
        fix_markdown = data.pop("fix_markdown", None)
        if inputs is None:
            raise ValueError("Missing image.")

        # Decode the base64 payload into a PIL image and preprocess it.
        binary_data = base64.b64decode(inputs)
        image = Image.open(io.BytesIO(binary_data))
        pixel_values = self.processor(images=image, return_tensors="pt").pixel_values

        # Autoregressively generate tokens, with custom stopping criteria (as defined by the Nougat authors).
        outputs = self.model.generate(
            pixel_values=pixel_values.to(self.model.device),
            min_length=1,
            bad_words_ids=[[self.processor.tokenizer.unk_token_id]],
            return_dict_in_generate=True,
            output_scores=True,
            stopping_criteria=StoppingCriteriaList([StoppingCriteriaScores()]),
            **parameters,
        )
        generated = self.processor.batch_decode(outputs.sequences, skip_special_tokens=True)[0]
        prediction = self.processor.post_process_generation(generated, fix_markdown=fix_markdown)

        return {"generated_text": prediction}
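
A quick local smoke test of the handler (not part of the diff): the snippet below is a minimal sketch that instantiates EndpointHandler directly and sends it the same payload shape the endpoint would receive. The file path "page.png" and the max_new_tokens value are placeholders, not part of the PR.

import base64

from handler import EndpointHandler

handler = EndpointHandler(path="facebook/nougat-base")

# Encode a sample page image as base64, mirroring what a client would send.
with open("page.png", "rb") as f:  # placeholder path
    encoded_image = base64.b64encode(f.read()).decode("utf-8")

payload = {
    "inputs": encoded_image,
    "parameters": {"max_new_tokens": 1024},  # example generation parameter
    "fix_markdown": True,
}

result = handler(payload)
print(result["generated_text"])
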
requirements.txt
ADDED
@@ -0,0 +1,2 @@
python-Levenshtein
nltk
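
Once deployed as a Dedicated Inference Endpoint, the handler should be reachable with a plain JSON POST; the sketch below is an assumption about client-side usage, with ENDPOINT_URL, HF_TOKEN, and "page.png" all placeholders.

import base64

import requests

ENDPOINT_URL = "https://<your-endpoint>.endpoints.huggingface.cloud"  # placeholder
HF_TOKEN = "hf_..."  # placeholder access token

with open("page.png", "rb") as f:  # placeholder path
    encoded_image = base64.b64encode(f.read()).decode("utf-8")

# Same payload shape that EndpointHandler.__call__ expects as `data`.
response = requests.post(
    ENDPOINT_URL,
    headers={"Authorization": f"Bearer {HF_TOKEN}", "Content-Type": "application/json"},
    json={"inputs": encoded_image, "parameters": {"max_new_tokens": 1024}},
)
print(response.json()["generated_text"])
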