| """Callback handlers used in the app.""" | |
| from typing import Any, Dict, List | |
| from langchain.callbacks.base import AsyncCallbackHandler | |
| from schemas import ChatResponse | |
class StreamingLLMCallbackHandler(AsyncCallbackHandler):
    """Callback handler for streaming LLM responses.

    Forwards every new token produced by the LLM to the connected websocket
    as a ``ChatResponse`` with ``sender="bot"`` and ``type="stream"``.
    """

    def __init__(self, websocket: Any) -> None:
        """Store the websocket that will receive streamed tokens.

        Args:
            websocket: Connection object exposing an awaitable ``send_json``
                method (presumably a Starlette/FastAPI ``WebSocket`` —
                TODO confirm against the caller).
        """
        # Run the base callback-handler initializer so this subclass keeps
        # working if the base class ever gains per-instance state.
        super().__init__()
        self.websocket = websocket

    async def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
        """Send a newly generated token to the client as a stream message.

        Args:
            token: The token text just produced by the LLM.
            **kwargs: Extra callback arguments supplied by the framework;
                not used here.
        """
        resp = ChatResponse(sender="bot", message=token, type="stream")
        # NOTE(review): ``.dict()`` assumes ChatResponse is a pydantic (v1-style)
        # model; its output is sent as JSON over the websocket.
        await self.websocket.send_json(resp.dict())
class QuestionGenCallbackHandler(AsyncCallbackHandler):
    """Callback handler for question generation.

    When the question-generating LLM starts, notifies the connected websocket
    with an informational ``ChatResponse`` ("Synthesizing question...").
    """

    def __init__(self, websocket: Any) -> None:
        """Store the websocket that will receive status notifications.

        Args:
            websocket: Connection object exposing an awaitable ``send_json``
                method (presumably a Starlette/FastAPI ``WebSocket`` —
                TODO confirm against the caller).
        """
        # Run the base callback-handler initializer so this subclass keeps
        # working if the base class ever gains per-instance state.
        super().__init__()
        self.websocket = websocket

    async def on_llm_start(
        self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
    ) -> None:
        """Run when LLM starts running.

        Sends a single ``type="info"`` message telling the client that a
        question is being synthesized.

        Args:
            serialized: Serialized description of the LLM; not used here.
            prompts: Prompts passed to the LLM; not used here.
            **kwargs: Extra callback arguments supplied by the framework;
                not used here.
        """
        resp = ChatResponse(
            sender="bot", message="Synthesizing question...", type="info"
        )
        # NOTE(review): ``.dict()`` assumes ChatResponse is a pydantic (v1-style)
        # model; its output is sent as JSON over the websocket.
        await self.websocket.send_json(resp.dict())