Trouter-Library committed on
Commit dccc9c1 · verified · 1 Parent(s): ed61a86

Create inference/server.py

Files changed (1)
  inference/server.py +405 -0
inference/server.py ADDED
@@ -0,0 +1,405 @@
#!/usr/bin/env python3
"""
Helion-2.5-Rnd Inference Server
High-performance inference server with vLLM backend
"""

import argparse
import asyncio
import json
import logging
import os
import time
from typing import AsyncGenerator, Dict, List, Optional, Union

import torch
import uvicorn
from fastapi import FastAPI, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, StreamingResponse
from pydantic import BaseModel, Field
from vllm import AsyncEngineArgs, AsyncLLMEngine, SamplingParams
from vllm.utils import random_uuid

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)


class ChatMessage(BaseModel):
    """Chat message format"""
    role: str = Field(..., description="Role: system, user, or assistant")
    content: str = Field(..., description="Message content")


class ChatCompletionRequest(BaseModel):
    """Chat completion request format"""
    model: str = Field(default="DeepXR/Helion-2.5-Rnd")
    messages: List[ChatMessage]
    temperature: float = Field(default=0.7, ge=0.0, le=2.0)
    top_p: float = Field(default=0.9, ge=0.0, le=1.0)
    top_k: int = Field(default=50, ge=0)
    max_tokens: int = Field(default=4096, ge=1)
    stream: bool = Field(default=False)
    stop: Optional[List[str]] = None
    presence_penalty: float = Field(default=0.0, ge=-2.0, le=2.0)
    frequency_penalty: float = Field(default=0.0, ge=-2.0, le=2.0)
    repetition_penalty: float = Field(default=1.1, ge=1.0, le=2.0)
    n: int = Field(default=1, ge=1, le=10)
    logprobs: Optional[int] = None
    echo: bool = Field(default=False)


class CompletionRequest(BaseModel):
    """Text completion request format"""
    model: str = Field(default="DeepXR/Helion-2.5-Rnd")
    prompt: Union[str, List[str]]
    temperature: float = Field(default=0.7, ge=0.0, le=2.0)
    top_p: float = Field(default=0.9, ge=0.0, le=1.0)
    max_tokens: int = Field(default=4096, ge=1)
    stream: bool = Field(default=False)
    stop: Optional[List[str]] = None
    n: int = Field(default=1, ge=1, le=10)

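# For reference, an illustrative (not exhaustive) JSON body accepted by
# ChatCompletionRequest on /v1/chat/completions; field names and defaults
# mirror the model above:
#
#   {
#     "model": "DeepXR/Helion-2.5-Rnd",
#     "messages": [
#       {"role": "system", "content": "You are a helpful assistant."},
#       {"role": "user", "content": "Summarize vLLM in one sentence."}
#     ],
#     "temperature": 0.7,
#     "max_tokens": 256,
#     "stream": false
#   }
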
class HelionInferenceServer:
    """Main inference server class"""

    def __init__(
        self,
        model_path: str,
        tensor_parallel_size: int = 2,
        max_model_len: int = 131072,
        gpu_memory_utilization: float = 0.95,
        dtype: str = "bfloat16"
    ):
        self.model_path = model_path
        self.model_name = "DeepXR/Helion-2.5-Rnd"

        # Initialize vLLM engine
        engine_args = AsyncEngineArgs(
            model=model_path,
            tensor_parallel_size=tensor_parallel_size,
            max_model_len=max_model_len,
            gpu_memory_utilization=gpu_memory_utilization,
            dtype=dtype,
            trust_remote_code=True,
            enforce_eager=False,
            disable_log_stats=False,
        )

        logger.info(f"Initializing Helion-2.5-Rnd from {model_path}")
        self.engine = AsyncLLMEngine.from_engine_args(engine_args)
        logger.info("Engine initialized successfully")

        # Statistics
        self.request_count = 0
        self.start_time = time.time()

    def format_chat_prompt(self, messages: List[ChatMessage]) -> str:
        """Format chat messages into prompt"""
        formatted = ""
        for msg in messages:
            formatted += f"<|im_start|>{msg.role}\n{msg.content}<|im_end|>\n"
        formatted += "<|im_start|>assistant\n"
        return formatted

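    # For a two-message conversation, format_chat_prompt produces a
    # ChatML-style prompt like the following (the trailing assistant header
    # leaves the model to fill in its reply):
    #
    #   <|im_start|>system
    #   You are a helpful assistant.<|im_end|>
    #   <|im_start|>user
    #   Hello!<|im_end|>
    #   <|im_start|>assistant
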
    async def generate(
        self,
        prompt: str,
        sampling_params: SamplingParams,
        request_id: str
    ) -> AsyncGenerator[str, None]:
        """Generate text streaming"""
        results_generator = self.engine.generate(
            prompt,
            sampling_params,
            request_id
        )

        async for request_output in results_generator:
            text = request_output.outputs[0].text
            yield text

    async def chat_completion(
        self,
        request: ChatCompletionRequest
    ) -> Union[Dict, AsyncGenerator]:
        """Handle chat completion request"""
        request_id = f"helion-{random_uuid()}"
        self.request_count += 1

        # Format prompt
        prompt = self.format_chat_prompt(request.messages)

        # Create sampling parameters
        sampling_params = SamplingParams(
            temperature=request.temperature,
            top_p=request.top_p,
            top_k=request.top_k,
            max_tokens=request.max_tokens,
            stop=request.stop or ["<|im_end|>", "<|endoftext|>"],
            presence_penalty=request.presence_penalty,
            frequency_penalty=request.frequency_penalty,
            repetition_penalty=request.repetition_penalty,
            n=request.n,
            logprobs=request.logprobs,
        )

        if request.stream:
            # Await here so the caller receives the async generator itself,
            # not an un-awaited coroutine.
            return await self._stream_chat_completion(
                prompt,
                sampling_params,
                request_id,
                request.model
            )
        else:
            return await self._complete_chat_completion(
                prompt,
                sampling_params,
                request_id,
                request.model
            )

    async def _complete_chat_completion(
        self,
        prompt: str,
        sampling_params: SamplingParams,
        request_id: str,
        model: str
    ) -> Dict:
        """Non-streaming chat completion"""
        final_output = None
        async for request_output in self.engine.generate(
            prompt, sampling_params, request_id
        ):
            final_output = request_output

        if final_output is None:
            raise HTTPException(status_code=500, detail="Generation failed")

        choice = {
            "index": 0,
            "message": {
                "role": "assistant",
                "content": final_output.outputs[0].text
            },
            "finish_reason": final_output.outputs[0].finish_reason
        }

        return {
            "id": request_id,
            "object": "chat.completion",
            "created": int(time.time()),
            "model": model,
            "choices": [choice],
            "usage": {
                "prompt_tokens": len(final_output.prompt_token_ids),
                "completion_tokens": len(final_output.outputs[0].token_ids),
                "total_tokens": len(final_output.prompt_token_ids) + len(final_output.outputs[0].token_ids)
            }
        }

    async def _stream_chat_completion(
        self,
        prompt: str,
        sampling_params: SamplingParams,
        request_id: str,
        model: str
    ) -> AsyncGenerator:
        """Streaming chat completion"""
        async def generate():
            previous_text = ""
            async for request_output in self.engine.generate(
                prompt, sampling_params, request_id
            ):
                text = request_output.outputs[0].text
                delta = text[len(previous_text):]
                previous_text = text

                chunk = {
                    "id": request_id,
                    "object": "chat.completion.chunk",
                    "created": int(time.time()),
                    "model": model,
                    "choices": [{
                        "index": 0,
                        "delta": {"content": delta},
                        "finish_reason": None
                    }]
                }
                yield f"data: {json.dumps(chunk)}\n\n"

            # Final chunk
            final_chunk = {
                "id": request_id,
                "object": "chat.completion.chunk",
                "created": int(time.time()),
                "model": model,
                "choices": [{
                    "index": 0,
                    "delta": {},
                    "finish_reason": "stop"
                }]
            }
            yield f"data: {json.dumps(final_chunk)}\n\n"
            yield "data: [DONE]\n\n"

        return generate()

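# The streaming path above emits OpenAI-style server-sent events: one
# "data: {chat.completion.chunk JSON}\n\n" event per text delta, a final chunk
# with finish_reason "stop", then a literal "data: [DONE]" terminator.
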
# Initialize FastAPI app
app = FastAPI(
    title="Helion-2.5-Rnd Inference API",
    description="Advanced language model inference server",
    version="2.5.0-rnd"
)

# Add CORS middleware
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Global server instance
server: Optional[HelionInferenceServer] = None


@app.on_event("startup")
async def startup_event():
    """Initialize the model on startup"""
    global server

    model_path = os.getenv("MODEL_PATH", "/models/helion")
    tensor_parallel = int(os.getenv("TENSOR_PARALLEL_SIZE", "2"))
    max_len = int(os.getenv("MAX_MODEL_LEN", "131072"))
    gpu_util = float(os.getenv("GPU_MEMORY_UTILIZATION", "0.95"))

    server = HelionInferenceServer(
        model_path=model_path,
        tensor_parallel_size=tensor_parallel,
        max_model_len=max_len,
        gpu_memory_utilization=gpu_util
    )
    logger.info("Helion-2.5-Rnd server started successfully")

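# Configuration summary (environment variables read at startup, with the
# defaults above; main() mirrors them as CLI flags):
#   MODEL_PATH              local model directory         (default /models/helion)
#   TENSOR_PARALLEL_SIZE    GPUs used for tensor parallel (default 2)
#   MAX_MODEL_LEN           max context length in tokens  (default 131072)
#   GPU_MEMORY_UTILIZATION  fraction of VRAM to reserve   (default 0.95)
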
@app.get("/")
async def root():
    """Root endpoint"""
    return {
        "model": "DeepXR/Helion-2.5-Rnd",
        "version": "2.5.0-rnd",
        "status": "ready",
        "type": "research"
    }


@app.get("/health")
async def health():
    """Health check endpoint"""
    if server is None:
        raise HTTPException(status_code=503, detail="Server not initialized")

    return {
        "status": "healthy",
        "model": server.model_name,
        "requests_served": server.request_count,
        "uptime_seconds": int(time.time() - server.start_time)
    }


@app.get("/v1/models")
async def list_models():
    """List available models"""
    return {
        "object": "list",
        "data": [{
            "id": "DeepXR/Helion-2.5-Rnd",
            "object": "model",
            "created": int(time.time()),
            "owned_by": "DeepXR"
        }]
    }

@app.post("/v1/chat/completions")
async def chat_completions(request: ChatCompletionRequest):
    """Chat completion endpoint"""
    if server is None:
        raise HTTPException(status_code=503, detail="Server not initialized")

    try:
        result = await server.chat_completion(request)

        if request.stream:
            return StreamingResponse(
                result,
                media_type="text/event-stream"
            )
        else:
            return JSONResponse(content=result)

    except Exception as e:
        logger.error(f"Error in chat completion: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))


@app.post("/v1/completions")
async def completions(request: CompletionRequest):
    """Text completion endpoint"""
    if server is None:
        raise HTTPException(status_code=503, detail="Server not initialized")

    # Convert to chat format; prompt may be a string or a list of strings
    prompt_text = request.prompt if isinstance(request.prompt, str) else "\n".join(request.prompt)
    messages = [ChatMessage(role="user", content=prompt_text)]
    chat_request = ChatCompletionRequest(
        model=request.model,
        messages=messages,
        temperature=request.temperature,
        top_p=request.top_p,
        max_tokens=request.max_tokens,
        stream=request.stream,
        stop=request.stop,
        n=request.n
    )

    return await chat_completions(chat_request)

def main():
    """Main entry point"""
    parser = argparse.ArgumentParser(description="Helion-2.5-Rnd Inference Server")
    parser.add_argument("--model", type=str, default="/models/helion")
    parser.add_argument("--host", type=str, default="0.0.0.0")
    parser.add_argument("--port", type=int, default=8000)
    parser.add_argument("--tensor-parallel-size", type=int, default=2)
    parser.add_argument("--max-model-len", type=int, default=131072)
    parser.add_argument("--gpu-memory-utilization", type=float, default=0.95)

    args = parser.parse_args()

    # Set environment variables
    os.environ["MODEL_PATH"] = args.model
    os.environ["TENSOR_PARALLEL_SIZE"] = str(args.tensor_parallel_size)
    os.environ["MAX_MODEL_LEN"] = str(args.max_model_len)
    os.environ["GPU_MEMORY_UTILIZATION"] = str(args.gpu_memory_utilization)

    # Run server
    uvicorn.run(
        app,
        host=args.host,
        port=args.port,
        log_level="info",
        access_log=True
    )


if __name__ == "__main__":
    main()
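
Not part of the commit: a minimal client sketch against the endpoints defined above, assuming the server is running locally with the defaults from main() (for example, started with python inference/server.py --model /models/helion, listening on port 8000) and that the third-party requests library is installed. The payload fields mirror ChatCompletionRequest; the streaming loop parses the server-sent events exactly as _stream_chat_completion emits them.

import json

import requests  # assumed dependency: pip install requests

BASE_URL = "http://localhost:8000"  # host/port assumed from main() defaults

# Non-streaming chat completion
resp = requests.post(
    f"{BASE_URL}/v1/chat/completions",
    json={
        "model": "DeepXR/Helion-2.5-Rnd",
        "messages": [{"role": "user", "content": "Say hello in one sentence."}],
        "max_tokens": 64,
    },
    timeout=300,
)
resp.raise_for_status()
print(resp.json()["choices"][0]["message"]["content"])

# Streaming chat completion: consume the server-sent events line by line
with requests.post(
    f"{BASE_URL}/v1/chat/completions",
    json={
        "model": "DeepXR/Helion-2.5-Rnd",
        "messages": [{"role": "user", "content": "Count to five."}],
        "stream": True,
    },
    stream=True,
    timeout=300,
) as stream_resp:
    for line in stream_resp.iter_lines(decode_unicode=True):
        if not line or not line.startswith("data: "):
            continue
        payload = line[len("data: "):]
        if payload == "[DONE]":
            break
        chunk = json.loads(payload)
        print(chunk["choices"][0]["delta"].get("content", ""), end="", flush=True)
print()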