Spaces:
Running
Running
| import json | |
| import aiohttp | |
class TextGenerator:
    """Client for a text-generation HTTP server with ``/generate`` and
    ``/generate_stream`` endpoints.

    Requests are sent as ``{"inputs": prompt, "parameters": {...}}``.
    Non-streaming replies are read from the ``generated_text`` key.
    Streaming replies are parsed as ``data:``-prefixed lines whose JSON
    carries a ``token`` object with ``text`` and ``special`` fields.
    NOTE(review): this appears to target a text-generation-inference style
    server; confirm against the deployed backend.
    """

    def __init__(self, host_url):
        """Derive both endpoint URLs from *host_url*; a trailing ``/`` is tolerated."""
        base_url = host_url.rstrip("/")
        self.host_url = base_url + "/generate"
        self.host_url_stream = base_url + "/generate_stream"

    @staticmethod
    def _build_payload(prompt, max_new_tokens, do_sample, temperature, **extra_parameters):
        """Return the JSON request body shared by all generation methods.

        Any *extra_parameters* (e.g. ``stop``, ``best_of``) are appended to
        the ``parameters`` object after the common sampling settings.
        """
        return {
            "inputs": prompt,
            "parameters": {
                "max_new_tokens": max_new_tokens,
                "do_sample": do_sample,
                "temperature": temperature,
                **extra_parameters,
            },
        }

    @staticmethod
    def _parse_stream_line(line):
        """Return the token text carried by one raw stream line, or ``None``.

        *line* may be ``bytes`` (as yielded by ``requests``' ``iter_lines``)
        or ``str``. ``None`` is returned for blank lines, lines without a
        ``data:`` prefix, and tokens whose ``special`` flag is true. A
        non-special token whose text is ``""`` is returned as ``""``, so
        callers must test ``is not None`` rather than truthiness.
        """
        if not line:
            return None
        text = line.decode("utf-8") if isinstance(line, bytes) else line
        if not text.startswith("data:"):
            return None
        # json.loads tolerates the optional space after "data:".
        event = json.loads(text[len("data:"):])
        token = event["token"]
        if token["special"]:
            return None
        return token["text"]

    async def generate_text_async(self, prompt, max_new_tokens=100, do_sample=True, temperature=0.8):
        """Asynchronously request a full completion for *prompt*.

        Returns the ``generated_text`` string from the server's JSON reply,
        or ``None`` when the server answers with any status other than 200.
        """
        payload = self._build_payload(prompt, max_new_tokens, do_sample, temperature)
        async with aiohttp.ClientSession() as session:
            async with session.post(self.host_url, json=payload) as response:
                if response.status != 200:
                    # Best-effort contract kept from the original: signal
                    # failure with None instead of raising.
                    return None
                data = await response.json()
                return data["generated_text"]

    def generate_text(self, prompt, max_new_tokens=100, do_sample=True, temperature=0.8):
        """Synchronously request a full completion for *prompt*.

        Returns the ``generated_text`` string from the server's JSON reply.

        Raises:
            requests.HTTPError: if the server answers with a 4xx/5xx status.
        """
        import requests

        payload = self._build_payload(prompt, max_new_tokens, do_sample, temperature)
        # The with-block guarantees the connection is released.
        with requests.post(self.host_url, json=payload) as response:
            response.raise_for_status()
            return response.json()["generated_text"]

    def generate_text_stream(self, prompt, max_new_tokens=100, do_sample=True, temperature=0.8, stop=None, best_of=1):
        """Yield generated token texts for *prompt* as the server streams them.

        *stop* defaults to an empty list (``None`` avoids a shared mutable
        default). Special tokens are filtered out.

        Raises:
            requests.HTTPError: if the server answers with a 4xx/5xx status.
        """
        import requests

        payload = self._build_payload(
            prompt,
            max_new_tokens,
            do_sample,
            temperature,
            stop=[] if stop is None else stop,
            best_of=best_of,
        )
        headers = {
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
        }
        # stream=True keeps the socket open until the body is consumed.
        # The with-block closes it even if the consumer stops iterating early.
        with requests.post(self.host_url_stream, json=payload, headers=headers, stream=True) as response:
            response.raise_for_status()
            for line in response.iter_lines():
                token = self._parse_stream_line(line)
                if token is not None:
                    yield token
class SummarizerGenerator:
    """Stream a summary of some text from an HTTP endpoint.

    The endpoint receives ``{"text": text}`` as a JSON POST. Its streamed
    response is consumed line by line, and each line is lightly reformatted
    for Markdown display.
    """

    def __init__(self, api):
        """Store *api*, the full URL the summary request is POSTed to."""
        self.api = api

    @staticmethod
    def _format_line(line):
        """Turn one non-empty raw response line (``bytes``) into display text.

        - A trailing ``<|eot_id|>`` marker is removed. NOTE(review): this
          looks like a model end-of-turn token leaking from the backend;
          confirm.
        - If the line starts with ``•``, every ``•`` in it becomes ``-``.
        - Lines starting with ``-`` get a ``"\\n\\n"`` prefix. Presumably this
          makes them render as separate Markdown list items, since
          ``iter_lines`` strips the original newlines.
        """
        data = line.decode("utf-8").removesuffix("<|eot_id|>")
        if data.startswith("•"):
            data = data.replace("•", "-")
        if data.startswith("-"):
            data = "\n\n" + data
        return data

    def generate_summary_stream(self, text):
        """Yield formatted summary fragments for *text* as they stream in.

        Blank lines in the response are skipped.

        Raises:
            requests.HTTPError: if the server answers with a 4xx/5xx status,
                instead of streaming the error body out as summary text.
        """
        import requests

        headers = {
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
        }
        # The with-block closes the streamed connection even if the consumer
        # abandons the generator early.
        with requests.post(self.api, json={"text": text}, headers=headers, stream=True) as response:
            response.raise_for_status()
            for line in response.iter_lines():
                if line:
                    yield self._format_line(line)