from mcp import ClientSession
from mcp.client.streamable_http import streamablehttp_client

from langgraph.checkpoint.memory import InMemorySaver
from langgraph.prebuilt import create_react_agent
from langchain_mcp_adapters.tools import load_mcp_tools
from langchain_openai import ChatOpenAI

from environs import Env
from loguru import logger

env = Env()
env.read_env()

AGENT_DEBUG = env.bool("AGENT_DEBUG", True)
LLM_PROVIDER = env.str("LLM_PROVIDER", "OPENAI")
LLM_API_KEY = env.str("LLM_API_KEY")
LLM_MODEL = env.str("LLM_MODEL")
MCP_SERVER_URL = env.str("MCP_SERVER_URL")
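
# Expected .env entries; the values shown are illustrative placeholders only:
#
#   AGENT_DEBUG=true
#   LLM_PROVIDER=OPENAI
#   LLM_API_KEY=sk-...
#   LLM_MODEL=gpt-4o-mini
#   MCP_SERVER_URL=http://localhost:8000/mcp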


def get_llm():
    """Return a chat model for the configured LLM_PROVIDER."""
    if LLM_PROVIDER == "OPENAI":
        return ChatOpenAI(model=LLM_MODEL, api_key=LLM_API_KEY)
    elif LLM_PROVIDER == "GEMINI":
        raise NotImplementedError("Gemini is not supported yet")
    else:
        raise ValueError(f"Unsupported LLM provider: {LLM_PROVIDER}")
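
# Hypothetical Gemini wiring for the branch above, assuming the optional
# langchain-google-genai package (not imported in this module); kept as a
# comment rather than live code:
#
#     from langchain_google_genai import ChatGoogleGenerativeAI
#     return ChatGoogleGenerativeAI(model=LLM_MODEL, google_api_key=LLM_API_KEY)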


class MCPClient:
    """Singleton that owns the LLM, MCP tools, and the LangGraph ReAct agent."""

    _instance = None

    def __new__(cls):
        if cls._instance is None:
            logger.info("Creating MCP client instance...")
            cls._instance = super().__new__(cls)
            cls._instance._initialized = False
        return cls._instance

    def __init__(self):
        # __init__ runs on every MCPClient() call, so skip re-initialization
        # once the singleton has been set up.
        if self._initialized:
            return
        logger.info("Initializing MCP client instance...")
        try:
            self.llm = get_llm()
            self.tools = None
            self.agent = None
            self.instruction = ""
            self.api_description = ""
            self.system_message = ""
            self.checkpointer = InMemorySaver()
            self.load_instruction()
            self.load_api_description()
            # Flag success only after the resource files load, so a failed
            # attempt can be retried on the next instantiation.
            self._initialized = True
            logger.info("MCP client instance initialized successfully")
        except Exception as e:
            logger.error(f"Error initializing MCP client: {str(e)}")
            raise

    def load_instruction(self) -> None:
        """Read the agent instruction text from resources/instruction.txt."""
        logger.info("Loading instruction...")
        instruction_path = "resources/instruction.txt"
        try:
            logger.info(f"Looking for instruction file at: {instruction_path}")
            with open(instruction_path, "r") as file:
                self.instruction = file.read()
            logger.info("Instruction loaded successfully")
        except FileNotFoundError:
            logger.error(f"Instruction file not found at: {instruction_path}")
            raise
        except Exception as e:
            logger.error(f"Error loading instruction: {str(e)}")
            raise

    def load_api_description(self) -> None:
        """Read the REST API summary from resources/dpc_restapi_summary.txt."""
        logger.info("Loading api description...")
        api_path = "resources/dpc_restapi_summary.txt"
        try:
            logger.info(f"Looking for API description file at: {api_path}")
            with open(api_path, "r") as file:
                self.api_description = file.read()
            logger.info("API description loaded successfully")
        except FileNotFoundError:
            logger.error(f"API description file not found at: {api_path}")
            raise
        except Exception as e:
            logger.error(f"Error loading API description: {str(e)}")
            raise

    async def initialize(self):
        """Log the configuration and compose the agent's system message."""
        logger.info("Initializing MCP client...")
        try:
            logger.info("Environment variables:")
            logger.info(f"LLM_PROVIDER: {LLM_PROVIDER}")
            logger.info(f"LLM_API_KEY: {LLM_API_KEY[:12]}...")  # log only a key prefix
            logger.info(f"LLM_MODEL: {LLM_MODEL}")
            logger.info(f"MCP_SERVER_URL: {MCP_SERVER_URL}")

            self.system_message = "".join([self.instruction, "\n\n", self.api_description])
            logger.info("MCP client initialization completed successfully")
        except Exception as e:
            logger.error(f"Error during MCP client initialization: {str(e)}")
            raise

    async def clear_memory(self):
        """Clear the checkpointer memory for the current conversation."""
        logger.info("Clearing checkpointer memory...")
        await self.checkpointer.adelete_thread(thread_id="conversation_123")
        logger.info("Checkpointer memory cleared successfully")

    async def invoke(self, input_messages):
        """Run the agent against the MCP server and return its final reply."""
        logger.info(f"Invoking agent with input: {input_messages}")
        # A fresh streamable-HTTP session is opened per call; the tools and
        # agent are rebuilt each time because they are bound to that session.
        async with streamablehttp_client(MCP_SERVER_URL) as (read, write, _):
            async with ClientSession(read, write) as session:
                await session.initialize()
                logger.info("Loading tools...")
                self.tools = await load_mcp_tools(session)
                logger.info("Creating agent...")
                self.agent = create_react_agent(
                    model=self.llm,
                    tools=self.tools,
                    prompt=self.system_message,
                    checkpointer=self.checkpointer,
                    debug=AGENT_DEBUG,
                )
                logger.info("Invoking agent...")
                # The fixed thread_id keeps every call in one conversation;
                # clear_memory() deletes this same thread.
                config = {"configurable": {"thread_id": "conversation_123"}}
                result = await self.agent.ainvoke(
                    input={"messages": input_messages}, config=config
                )
                logger.info(f"Agent result: {result}")
                last_message = result["messages"][-1]
                logger.info(f"Last message: {last_message.content}")
                return last_message.content
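

# Hypothetical entry point illustrating how the pieces above fit together;
# the (role, content) tuple follows LangGraph's message input convention and
# the prompt text is a placeholder.
if __name__ == "__main__":
    import asyncio

    async def main() -> None:
        client = MCPClient()
        await client.initialize()
        reply = await client.invoke([("user", "What tools do you have?")])
        print(reply)

    asyncio.run(main())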