Spaces:
Sleeping
Sleeping
| import os | |
| import json | |
| import re | |
| import math | |
| from typing import Optional | |
| from openai import OpenAI, OpenAIError | |
| from tools.registry import TOOL_REGISTRY | |
| from typing import Any, Callable, Dict, Optional, Type, Union | |
| from tools.base import Tool | |
# Stateless tools are constructed a single time at import and reused afterwards.
# The "history_taking" entry is deliberately left out of this mapping.
STATIC_TOOL_INSTANCES = {
    tool_name: tool_cls()
    for tool_name, tool_cls in TOOL_REGISTRY.items()
    if tool_name != "history_taking"
}
| """ | |
| skills_registry.py | |
| ------------------ | |
| Registry and helper functions for tool instantiation, syndrome matching, and embedding lookup. | |
| """ | |
# Knowledge-base matching setup
# 1) Build the OpenAI client; importing this module fails fast without an API key.
openai_api_key = os.environ.get("OPENAI_API_KEY")
if openai_api_key:
    _client = OpenAI(api_key=openai_api_key)
else:
    raise RuntimeError("OPENAI_API_KEY environment variable is not set. Please set it before running the application.")
# 2) Load your precomputed syndrome embeddings
# Make sure syndrome_embeddings.json is in your working dir
# The encoding is pinned to UTF-8 so decoding does not depend on the
# platform's locale default.
with open("syndrome_embeddings.json", "r", encoding="utf-8") as f:
    SYNDROME_EMBS = json.load(f)  # { syndrome_key: [float, ...], … }
| def _cosine_sim(a: list[float], b: list[float]) -> float: | |
| """ | |
| Compute the cosine similarity between two vectors. | |
| Args: | |
| a (list[float]): First vector. | |
| b (list[float]): Second vector. | |
| Returns: | |
| float: Cosine similarity between a and b. | |
| """ | |
| dot = sum(x*y for x, y in zip(a, b)) | |
| norm_a = math.sqrt(sum(x*x for x in a)) | |
| norm_b = math.sqrt(sum(y*y for y in b)) | |
| return dot / (norm_a * norm_b) if norm_a and norm_b else 0.0 | |
| def _match_syndrome(user_query: str, | |
| model: str = "text-embedding-ada-002", | |
| threshold: float = 0.7 | |
| ) -> Optional[str]: | |
| # TODO: implement syndrome matching logic or remove if unused | |
| pass | |
| # This function is now implemented at the end of the file (see below) | |
| pass | |
| """ | |
| Semantically match the user_query to the best syndrome_key | |
| via cosine similarity against precomputed embeddings. | |
| Args: | |
| user_query (str): The user's query string. | |
| model (str, optional): Embedding model to use. Defaults to "text-embedding-ada-002". | |
| threshold (float, optional): Minimum similarity threshold. Defaults to 0.7. | |
| Returns: | |
| Optional[str]: The best-matching syndrome key, or None if no match meets the threshold. | |
| """ | |
| # normalize | |
| q = user_query.lower() | |
| q = re.sub(r"[^a-z0-9\s]", " ", q) | |
| q = re.sub(r"\s+", " ", q).strip() | |
| # embed the query | |
| try: | |
| resp = _client.embeddings.create(model=model, input=[q]) | |
| q_emb = resp.data[0].embedding | |
| except OpenAIError as e: | |
| # if embedding fails, fall back to no match | |
| print(f"[Embedding error] {e}") | |
| return None | |
| # find best cosine similarity | |
| best_key, best_score = None, -1.0 | |
| for key, emb in SYNDROME_EMBS.items(): | |
| score = _cosine_sim(q_emb, emb) | |
| if score > best_score: | |
| best_key, best_score = key, score | |
| return best_key if best_score >= threshold else None | |
# Unified tool registry keyed by tool name. Every entry has an "fn" slot:
# for "history_taking" it holds the HistoryTakingTool class itself, for all
# other tools it holds the shared instance from STATIC_TOOL_INSTANCES together
# with that instance's args_schema.
from tools.history_taking import HistoryTakingTool

tool_registry: Dict[str, Dict[str, Any]] = {}
for name, cls in TOOL_REGISTRY.items():
    if name != "history_taking":
        instance = STATIC_TOOL_INSTANCES[name]
        tool_registry[name] = dict(fn=instance, args_schema=instance.args_schema)
        continue
    tool_registry[name] = dict(fn=HistoryTakingTool)
def get_tool_by_name(
    name: str,
    context: Dict[str, Any]
) -> Optional[Tool]:
    """
    Retrieve a tool instance by name, optionally using context for dynamic instantiation.

    Args:
        name (str): The tool key, e.g. "history_taking".
        context (Dict[str, Any]): Must include "user_query" for dynamic tools.

    Returns:
        Optional[Tool]: The tool instance, or None if not found or not instantiable
        (unknown name, missing/empty "user_query", or no syndrome matched).
    """
    entry = tool_registry.get(name)
    if entry is None:
        # Unknown tool key: honor the documented "None if not found" contract
        # instead of letting a KeyError escape.
        return None

    fn = entry["fn"]
    if name != "history_taking":
        # static tools: fn is already an instance
        return fn

    # Dynamic tool: a syndrome key derived from the user's query is required.
    user_query = context.get("user_query")
    if not user_query:
        return None
    syndrome_key = _match_syndrome(user_query)
    if not syndrome_key:
        return None
    return fn(syndrome_key)  # instantiate with dynamic key