"""HTTP bridge exposing a small set of Neo4j and PostgreSQL "tools".

A single POST endpoint (``/mcp``) receives ``{"tool": ..., "params": {...}}``,
authenticates the caller via the ``X-API-Key`` header, and dispatches to one
of the tool handlers defined below.
"""

from fastapi import FastAPI, Header, HTTPException
from neo4j import GraphDatabase
import os
import json
import secrets
from contextlib import closing
from datetime import datetime
import psycopg2
from psycopg2.extras import RealDictCursor


def _neo4j_password() -> str:
    """Extract the password from ``NEO4J_AUTH`` (``<user>/<password>``).

    Only the first ``/`` separates user from password, so passwords that
    themselves contain ``/`` are preserved intact.

    Raises:
        RuntimeError: if ``NEO4J_AUTH`` is unset or has no ``/`` separator.
    """
    raw = os.getenv("NEO4J_AUTH")
    if not raw or "/" not in raw:
        raise RuntimeError("NEO4J_AUTH must be set to '<user>/<password>'")
    return raw.partition("/")[2]


app = FastAPI()

# The user name is fixed to "neo4j"; only the password part of NEO4J_AUTH
# is consumed.
driver = GraphDatabase.driver(
    os.getenv("NEO4J_BOLT_URL"),
    auth=("neo4j", _neo4j_password()),
)

# Blank entries are dropped: "".split(",") yields [""], which would otherwise
# let an empty X-API-Key header authenticate when MCP_API_KEYS is unset or
# contains a trailing comma.
VALID_API_KEYS = [
    key.strip()
    for key in os.getenv("MCP_API_KEYS", "").split(",")
    if key.strip()
]
POSTGRES_CONN = os.getenv("POSTGRES_CONNECTION")

_NEXT_INSTRUCTION_QUERY = """
MATCH (i:Instruction {status: 'pending'})
RETURN i
ORDER BY i.sequence
LIMIT 1
"""

_PUBLIC_TABLES_SQL = """
SELECT table_name, table_schema
FROM information_schema.tables
WHERE table_schema = 'public' AND table_type = 'BASE TABLE'
"""

_TABLE_COLUMNS_SQL = """
SELECT column_name, data_type, is_nullable, column_default,
       character_maximum_length
FROM information_schema.columns
WHERE table_schema = 'public' AND table_name = %s
ORDER BY ordinal_position
"""


def _is_valid_api_key(candidate) -> bool:
    """Return True when *candidate* matches one of the configured API keys.

    Comparison is done on bytes with ``secrets.compare_digest`` so the check
    does not leak key contents through timing and tolerates non-ASCII input.
    """
    if not candidate:
        return False
    supplied = candidate.encode("utf-8")
    return any(
        secrets.compare_digest(supplied, key.encode("utf-8"))
        for key in VALID_API_KEYS
    )


def _node_to_dict(node) -> dict:
    """Convert a value returned by Neo4j into a plain dict."""
    if hasattr(node, "items"):
        return dict(node)
    return {
        "id": str(node.id),
        "labels": list(node.labels),
        "properties": dict(node),
    }


def _get_schema(params: dict) -> dict:
    """Return every node label known to the graph database."""
    with driver.session() as session:
        result = session.run(
            "CALL db.labels() YIELD label RETURN collect(label) as labels"
        )
        return {"labels": result.single()["labels"]}


def _query_graph(params: dict) -> dict:
    """Run a caller-supplied Cypher query with optional parameters."""
    query = params.get("query")
    if not query:
        return {"error": "Missing 'query' parameter"}
    query_params = params.get("parameters") or {}
    with driver.session() as session:
        result = session.run(query, query_params)
        return {"data": [dict(record) for record in result]}


def _write_graph(params: dict) -> dict:
    """Perform a structured graph write; only ``create_node`` is supported."""
    action = params.get("action")
    if action != "create_node":
        return {"error": "Unknown action"}

    label = params.get("label")
    # Labels cannot be passed as Cypher parameters, so the label is spliced
    # into the statement text. Restrict it to a plain identifier to prevent
    # Cypher injection through this field.
    if not isinstance(label, str) or not label.isidentifier():
        return {"error": "Invalid label"}
    properties = params.get("properties") or {}

    with driver.session() as session:
        result = session.run(
            f"CREATE (n:{label} $props) RETURN n", {"props": properties}
        )
        record = result.single()
        if record:
            return {"created": _node_to_dict(record["n"])}
        return {"created": {}}


def _get_next_instruction(params: dict) -> dict:
    """Return the pending ``Instruction`` node with the lowest ``sequence``."""
    with driver.session() as session:
        record = session.run(_NEXT_INSTRUCTION_QUERY).single()
        return {"instruction": dict(record["i"]) if record else None}


def _query_postgres(params: dict) -> dict:
    """Execute a caller-supplied SQL statement against PostgreSQL.

    Returns ``{"data", "row_count"}`` when the statement produces rows and
    ``{"affected_rows"}`` otherwise. Any failure is reported as
    ``{"error": <message>}`` rather than raised.
    """
    query = params.get("query")
    if not query:
        return {"error": "Missing 'query' parameter"}
    try:
        # closing() guarantees the connection is released on every path.
        with closing(psycopg2.connect(POSTGRES_CONN)) as conn:
            with conn.cursor(cursor_factory=RealDictCursor) as cur:
                cur.execute(query)
                rows = cur.fetchall() if cur.description else None
                affected = cur.rowcount
            # Commit unconditionally: statements such as
            # INSERT ... RETURNING produce rows *and* modify data, and would
            # be rolled back on close if only row-less statements committed.
            conn.commit()
        if rows is not None:
            return {"data": rows, "row_count": len(rows)}
        return {"affected_rows": affected}
    except Exception as exc:  # boundary: surface the failure to the caller
        return {"error": str(exc)}


def _discover_postgres_schema(params: dict) -> dict:
    """Describe the columns of every base table in the ``public`` schema."""
    try:
        with closing(psycopg2.connect(POSTGRES_CONN)) as conn:
            with conn.cursor(cursor_factory=RealDictCursor) as cur:
                cur.execute(_PUBLIC_TABLES_SQL)
                tables = cur.fetchall()

                schema_info = {}
                for table in tables:
                    table_name = table["table_name"]
                    cur.execute(_TABLE_COLUMNS_SQL, (table_name,))
                    schema_info[table_name] = cur.fetchall()
        return {"schema": schema_info}
    except Exception as exc:  # boundary: surface the failure to the caller
        return {"error": str(exc)}


# Maps the request's "tool" value to the function that implements it.
_TOOL_HANDLERS = {
    "get_schema": _get_schema,
    "query_graph": _query_graph,
    "write_graph": _write_graph,
    "get_next_instruction": _get_next_instruction,
    "query_postgres": _query_postgres,
    "discover_postgres_schema": _discover_postgres_schema,
}


@app.get("/health")
def health():
    """Liveness probe returning the current server-local timestamp."""
    return {"ok": True, "timestamp": datetime.now().isoformat()}


@app.post("/mcp")
async def execute_tool(request: dict, x_api_key: str = Header(None)):
    """Authenticate the caller and dispatch to the requested tool.

    Args:
        request: JSON body with a ``tool`` name and an optional ``params``
            object passed to the tool handler.
        x_api_key: value of the ``X-API-Key`` header.

    Returns:
        The handler's result dict, or ``{"error": ...}`` for an unknown tool
        or malformed ``params``.

    Raises:
        HTTPException: 401 when the API key is missing or not recognised.
    """
    # NOTE: the handlers perform blocking database I/O inside this async
    # endpoint; kept as-is to preserve the coroutine interface.
    if not _is_valid_api_key(x_api_key):
        raise HTTPException(status_code=401, detail="Invalid API key")

    tool = request.get("tool")
    # Guard against unhashable values (e.g. a JSON array) before the lookup.
    handler = _TOOL_HANDLERS.get(tool) if isinstance(tool, str) else None
    if handler is None:
        return {"error": "Unknown tool"}

    # An explicit JSON null for "params" is treated like an omitted field.
    params = request.get("params") or {}
    if not isinstance(params, dict):
        return {"error": "'params' must be an object"}
    return handler(params)