Spaces:
Paused
Paused
| import json | |
| import requests | |
def check_server_health(cloud_gateway_api: str) -> bool:
    """
    Use the appropriate API endpoint to check the server health.
    Args:
        cloud_gateway_api: API endpoint to probe.
    Returns:
        True if server is active, False otherwise.
    """
    try:
        # Timeout prevents this probe from hanging indefinitely on a dead host.
        response = requests.get(cloud_gateway_api + "/health", timeout=5)
        # Explicitly return False for non-200 statuses (the original fell
        # through and returned None, contradicting the documented contract).
        return response.status_code == 200
    except requests.RequestException:
        # RequestException covers ConnectionError, Timeout, and other
        # transport-level failures — all mean the server is unreachable.
        print("Failed to establish connection to the server.")
        return False
def request_generation(message: str,
                       system_prompt: str,
                       cloud_gateway_api: str,
                       max_new_tokens: int = 1024,
                       temperature: float = 0.6,
                       top_p: float = 0.9,
                       top_k: int = 50,
                       repetition_penalty: float = 1.2):
    """
    Request streaming generation from the cloud gateway API. Uses the simple requests module with stream=True to utilize
    token-by-token generation from LLM.
    Args:
        message: prompt from the user.
        system_prompt: system prompt to append.
        cloud_gateway_api (str): API endpoint to send the request.
        max_new_tokens: maximum number of tokens to generate, ignoring the number of tokens in the prompt.
        temperature: the value used to module the next token probabilities.
        top_p: if set to float<1, only the smallest set of most probable tokens with probabilities that add up to top_p
            or higher are kept for generation.
        top_k: the number of highest probability vocabulary tokens to keep for top-k-filtering.
        repetition_penalty: the parameter for repetition penalty. 1.0 means no penalty.
    Yields:
        str: each piece of generated text ("delta" content) as it streams
        back from the gateway.
    """
    payload = {
        "model": "meta-llama/Meta-Llama-3-8B-Instruct",
        "messages": [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": message}
        ],
        "max_tokens": max_new_tokens,
        "temperature": temperature,
        "top_p": top_p,
        "repetition_penalty": repetition_penalty,
        "top_k": top_k,
        "stream": True  # Enable streaming
    }
    # Connect timeout of 10s so a dead gateway fails fast; read timeout is
    # None because token generation may legitimately pause between chunks.
    with requests.post(cloud_gateway_api + "/v1/chat/completions",
                       json=payload, stream=True, timeout=(10, None)) as response:
        for chunk in response.iter_lines():
            if chunk:
                # Convert the chunk from bytes to a string and then parse it as json
                chunk_str = chunk.decode('utf-8')
                # Remove the `data: ` prefix from the chunk if it exists
                if chunk_str.startswith("data: "):
                    chunk_str = chunk_str[len("data: "):]
                # The server signals end-of-stream with a literal "[DONE]" sentinel
                if chunk_str.strip() == "[DONE]":
                    break
                # Parse the chunk into a JSON object
                try:
                    chunk_json = json.loads(chunk_str)
                    # Extract the "content" field from the choices
                    content = chunk_json["choices"][0]["delta"].get("content", "")
                    # Yield the generated content as it's streamed
                    if content:
                        yield content
                except (json.JSONDecodeError, KeyError, IndexError):
                    # Skip malformed chunks (bad JSON, or valid JSON missing
                    # the expected choices/delta structure) rather than
                    # killing the whole stream.
                    continue