Sofia Santos commited on
Commit
87d0429
Β·
1 Parent(s): f20f615

feat: improves ui

Browse files
Files changed (1) hide show
  1. tdagent/grchat.py +283 -50
tdagent/grchat.py CHANGED
@@ -2,7 +2,7 @@ from __future__ import annotations
2
 
3
  from collections.abc import Mapping, Sequence
4
  from types import MappingProxyType
5
- from typing import TYPE_CHECKING
6
 
7
  import boto3
8
  import botocore
@@ -10,8 +10,11 @@ import botocore.exceptions
10
  import gradio as gr
11
  from langchain_aws import ChatBedrock
12
  from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
 
13
  from langchain_mcp_adapters.client import MultiServerMCPClient
14
  from langgraph.prebuilt import create_react_agent
 
 
15
 
16
  from tdagent.grcomponents import MutableCheckBoxGroup, MutableCheckBoxGroupEntry
17
 
@@ -48,6 +51,15 @@ GRADIO_ROLE_TO_LG_MESSAGE_TYPE = MappingProxyType(
48
  },
49
  )
50
 
 
 
 
 
 
 
 
 
 
51
 
52
  #### Shared variables ####
53
 
@@ -56,12 +68,15 @@ llm_agent: CompiledGraph | None = None
56
  #### Utility functions ####
57
 
58
 
 
59
  def create_bedrock_llm(
60
  bedrock_model_id: str,
61
  aws_access_key: str,
62
  aws_secret_key: str,
63
  aws_session_token: str,
64
  aws_region: str,
 
 
65
  ) -> tuple[ChatBedrock | None, str]:
66
  """Create a LangGraph Bedrock agent."""
67
  boto3_config = {
@@ -70,7 +85,6 @@ def create_bedrock_llm(
70
  "aws_session_token": aws_session_token if aws_session_token else None,
71
  "region_name": aws_region,
72
  }
73
-
74
  # Verify credentials
75
  try:
76
  sts = boto3.client("sts", **boto3_config)
@@ -83,7 +97,7 @@ def create_bedrock_llm(
83
  llm = ChatBedrock(
84
  model_id=bedrock_model_id,
85
  client=bedrock_client,
86
- model_kwargs={"temperature": 0.8},
87
  )
88
  except Exception as e: # noqa: BLE001
89
  return None, str(e)
@@ -91,20 +105,59 @@ def create_bedrock_llm(
91
  return llm, ""
92
 
93
 
94
- #### UI functionality ####
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
95
 
96
 
97
- async def gr_connect_to_bedrock(
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
98
  model_id: str,
99
  access_key: str,
100
  secret_key: str,
101
  session_token: str,
102
  region: str,
103
  mcp_servers: Sequence[MutableCheckBoxGroupEntry] | None,
 
 
104
  ) -> str:
105
  """Initialize Bedrock agent."""
106
  global llm_agent # noqa: PLW0603
107
-
108
  if not access_key or not secret_key:
109
  return "❌ Please provide both Access Key ID and Secret Access Key"
110
 
@@ -114,6 +167,8 @@ async def gr_connect_to_bedrock(
114
  secret_key.strip(),
115
  session_token.strip(),
116
  region,
 
 
117
  )
118
 
119
  if llm is None:
@@ -128,7 +183,6 @@ async def gr_connect_to_bedrock(
128
  # }
129
  # )
130
  # tools = await client.get_tools()
131
-
132
  if mcp_servers:
133
  client = MultiServerMCPClient(
134
  {
@@ -142,7 +196,6 @@ async def gr_connect_to_bedrock(
142
  tools = await client.get_tools()
143
  else:
144
  tools = []
145
-
146
  llm_agent = create_react_agent(
147
  model=llm,
148
  tools=tools,
@@ -152,6 +205,73 @@ async def gr_connect_to_bedrock(
152
  return "βœ… Successfully connected to AWS Bedrock!"
153
 
154
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
155
  async def gr_chat_function( # noqa: D103
156
  message: str,
157
  history: list[Mapping[str, str]],
@@ -178,49 +298,110 @@ async def gr_chat_function( # noqa: D103
178
 
179
  ## UI components ##
180
 
181
- with gr.Blocks() as gr_app:
182
- gr.Markdown("# πŸ” Secure Bedrock Chatbot")
183
-
184
- ### MCP Servers ###
185
- with gr.Accordion():
186
- mcp_list = MutableCheckBoxGroup(
187
- values=[
188
- MutableCheckBoxGroupEntry(
189
- name="TDAgent tools",
190
- value="https://agents-mcp-hackathon-tdagenttools.hf.space/gradio_api/mcp/sse",
191
- ),
192
- ],
193
- label="MCP Servers",
194
- )
195
 
196
- # Credentials section (collapsible)
197
- with gr.Accordion("πŸ”‘ Bedrock Configuration", open=True):
198
- gr.Markdown(
199
- "**Note**: Credentials are only stored in memory during your session.",
200
- )
201
- with gr.Row():
202
- bedrock_model_id_textbox = gr.Textbox(
203
- label="Bedrock Model Id",
204
- value="eu.anthropic.claude-3-5-sonnet-20240620-v1:0",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
205
  )
206
- with gr.Row():
207
  aws_access_key_textbox = gr.Textbox(
208
  label="AWS Access Key ID",
209
  type="password",
210
  placeholder="Enter your AWS Access Key ID",
 
211
  )
212
  aws_secret_key_textbox = gr.Textbox(
213
  label="AWS Secret Access Key",
214
  type="password",
215
  placeholder="Enter your AWS Secret Access Key",
 
216
  )
217
- with gr.Row():
218
- aws_session_token_textbox = gr.Textbox(
219
- label="AWS Session Token",
220
- type="password",
221
- placeholder="Enter your AWS session token",
222
- )
223
- with gr.Row():
224
  aws_region_dropdown = gr.Dropdown(
225
  label="AWS Region",
226
  choices=[
@@ -231,31 +412,83 @@ with gr.Blocks() as gr_app:
231
  "ap-southeast-1",
232
  ],
233
  value="eu-west-1",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
234
  )
235
- connect_btn = gr.Button("πŸ”Œ Connect to Bedrock", variant="primary")
236
 
 
237
  status_textbox = gr.Textbox(label="Connection Status", interactive=False)
238
 
239
  connect_btn.click(
240
- gr_connect_to_bedrock,
241
  inputs=[
242
- bedrock_model_id_textbox,
 
 
243
  aws_access_key_textbox,
244
  aws_secret_key_textbox,
245
  aws_session_token_textbox,
246
  aws_region_dropdown,
247
- mcp_list.state,
 
 
248
  ],
249
  outputs=[status_textbox],
250
  )
251
 
252
- chat_interface = gr.ChatInterface(
253
- fn=gr_chat_function,
254
- type="messages",
255
- examples=[],
256
- title="Agent with MCP Tools",
257
- description="This is a simple agent that uses MCP tools.",
258
- )
 
259
 
260
 
261
  if __name__ == "__main__":
 
2
 
3
  from collections.abc import Mapping, Sequence
4
  from types import MappingProxyType
5
+ from typing import TYPE_CHECKING, Any
6
 
7
  import boto3
8
  import botocore
 
10
  import gradio as gr
11
  from langchain_aws import ChatBedrock
12
  from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
13
+ from langchain_huggingface import HuggingFaceEndpoint
14
  from langchain_mcp_adapters.client import MultiServerMCPClient
15
  from langgraph.prebuilt import create_react_agent
16
+ from openai import OpenAI
17
+ from openai.types.chat import ChatCompletion
18
 
19
  from tdagent.grcomponents import MutableCheckBoxGroup, MutableCheckBoxGroupEntry
20
 
 
51
  },
52
  )
53
 
54
+ MODEL_OPTIONS = {
55
+ "AWS Bedrock": {
56
+ "Anthropic Claude 3.5 Sonnet": "eu.anthropic.claude-3-5-sonnet-20240620-v1:0",
57
+ # "Anthropic Claude 3.7 Sonnet": "anthropic.claude-3-7-sonnet-20250219-v1:0",
58
+ },
59
+ "HuggingFace": {
60
+ "Mistral 7B Instruct": "mistralai/Mistral-7B-Instruct",
61
+ },
62
+ }
63
 
64
  #### Shared variables ####
65
 
 
68
  #### Utility functions ####
69
 
70
 
71
+ ## Bedrock LLM creation ##
72
  def create_bedrock_llm(
73
  bedrock_model_id: str,
74
  aws_access_key: str,
75
  aws_secret_key: str,
76
  aws_session_token: str,
77
  aws_region: str,
78
+ temperature: float = 0.8,
79
+ max_tokens: int = 512,
80
  ) -> tuple[ChatBedrock | None, str]:
81
  """Create a LangGraph Bedrock agent."""
82
  boto3_config = {
 
85
  "aws_session_token": aws_session_token if aws_session_token else None,
86
  "region_name": aws_region,
87
  }
 
88
  # Verify credentials
89
  try:
90
  sts = boto3.client("sts", **boto3_config)
 
97
  llm = ChatBedrock(
98
  model_id=bedrock_model_id,
99
  client=bedrock_client,
100
+ model_kwargs={"temperature": temperature, "max_tokens": max_tokens},
101
  )
102
  except Exception as e: # noqa: BLE001
103
  return None, str(e)
 
105
  return llm, ""
106
 
107
 
108
+ ## Hugging Face LLM creation ##
109
+ def create_hf_llm(
110
+ hf_model_id: str,
111
+ huggingfacehub_api_token: str | None = None,
112
+ ) -> tuple[HuggingFaceEndpoint | None, str]:
113
+ """Create a LangGraph Hugging Face agent."""
114
+ try:
115
+ llm = HuggingFaceEndpoint(
116
+ model=hf_model_id,
117
+ huggingfacehub_api_token=huggingfacehub_api_token,
118
+ temperature=0.8,
119
+ )
120
+ except Exception as e: # noqa: BLE001
121
+ return None, str(e)
122
+
123
+ return llm, ""
124
 
125
 
126
+ ## OpenAI LLM creation ##
127
+ def create_openai_llm(
128
+ model_id: str,
129
+ token_id: str,
130
+ ) -> tuple[ChatCompletion | None, str]:
131
+ """Create a LangGraph OpenAI agent."""
132
+ try:
133
+ client = OpenAI(
134
+ base_url="https://api.studio.nebius.com/v1/",
135
+ api_key=token_id,
136
+ )
137
+ llm = client.chat.completions.create(
138
+ messages=[], # needs to be fixed
139
+ model=model_id,
140
+ max_tokens=512,
141
+ temperature=0.8,
142
+ )
143
+ except Exception as e: # noqa: BLE001
144
+ return None, str(e)
145
+ return llm, ""
146
+
147
+
148
+ #### UI functionality ####
149
+ async def gr_connect_to_bedrock( # noqa: PLR0913
150
  model_id: str,
151
  access_key: str,
152
  secret_key: str,
153
  session_token: str,
154
  region: str,
155
  mcp_servers: Sequence[MutableCheckBoxGroupEntry] | None,
156
+ temperature: float = 0.8,
157
+ max_tokens: int = 512,
158
  ) -> str:
159
  """Initialize Bedrock agent."""
160
  global llm_agent # noqa: PLW0603
 
161
  if not access_key or not secret_key:
162
  return "❌ Please provide both Access Key ID and Secret Access Key"
163
 
 
167
  secret_key.strip(),
168
  session_token.strip(),
169
  region,
170
+ temperature=temperature,
171
+ max_tokens=max_tokens,
172
  )
173
 
174
  if llm is None:
 
183
  # }
184
  # )
185
  # tools = await client.get_tools()
 
186
  if mcp_servers:
187
  client = MultiServerMCPClient(
188
  {
 
196
  tools = await client.get_tools()
197
  else:
198
  tools = []
 
199
  llm_agent = create_react_agent(
200
  model=llm,
201
  tools=tools,
 
205
  return "βœ… Successfully connected to AWS Bedrock!"
206
 
207
 
208
+ async def gr_connect_to_hf(
209
+ model_id: str,
210
+ hf_access_token_textbox: str | None,
211
+ mcp_servers: Sequence[MutableCheckBoxGroupEntry] | None,
212
+ ) -> str:
213
+ """Initialize Hugging Face agent."""
214
+ global llm_agent # noqa: PLW0603
215
+
216
+ llm, error = create_hf_llm(model_id, hf_access_token_textbox)
217
+
218
+ if llm is None:
219
+ return f"❌ Connection failed: {error}"
220
+ tools = []
221
+ if mcp_servers:
222
+ client = MultiServerMCPClient(
223
+ {
224
+ server.name.replace(" ", "-"): {
225
+ "url": server.value,
226
+ "transport": "sse",
227
+ }
228
+ for server in mcp_servers
229
+ },
230
+ )
231
+ tools = await client.get_tools()
232
+
233
+ llm_agent = create_react_agent(
234
+ model=llm,
235
+ tools=tools,
236
+ prompt=SYSTEM_MESSAGE,
237
+ )
238
+
239
+ return "βœ… Successfully connected to Hugging Face!"
240
+
241
+
242
+ async def gr_connect_to_nebius(
243
+ model_id: str,
244
+ nebius_access_token_textbox: str,
245
+ mcp_servers: Sequence[MutableCheckBoxGroupEntry] | None,
246
+ ) -> str:
247
+ """Initialize Hugging Face agent."""
248
+ global llm_agent # noqa: PLW0603
249
+
250
+ llm, error = create_openai_llm(model_id, nebius_access_token_textbox)
251
+
252
+ if llm is None:
253
+ return f"❌ Connection failed: {error}"
254
+ tools = []
255
+ if mcp_servers:
256
+ client = MultiServerMCPClient(
257
+ {
258
+ server.name.replace(" ", "-"): {
259
+ "url": server.value,
260
+ "transport": "sse",
261
+ }
262
+ for server in mcp_servers
263
+ },
264
+ )
265
+ tools = await client.get_tools()
266
+
267
+ llm_agent = create_react_agent(
268
+ model=str(llm),
269
+ tools=tools,
270
+ prompt=SYSTEM_MESSAGE,
271
+ )
272
+ return "βœ… Successfully connected to nebius!"
273
+
274
+
275
  async def gr_chat_function( # noqa: D103
276
  message: str,
277
  history: list[Mapping[str, str]],
 
298
 
299
  ## UI components ##
300
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
301
 
302
+ # Function to toggle visibility and set model IDs
303
+ def toggle_model_fields(
304
+ provider: str,
305
+ ) -> tuple[
306
+ dict[str, Any],
307
+ dict[str, Any],
308
+ dict[str, Any],
309
+ dict[str, Any],
310
+ dict[str, Any],
311
+ dict[str, Any],
312
+ ]: # ignore: F821
313
+ """Toggle visibility of model fields based on the selected provider."""
314
+ # Update model choices based on the selected provider
315
+ if provider in MODEL_OPTIONS:
316
+ model_choices = list(MODEL_OPTIONS[provider].keys())
317
+ model_pretty = gr.update(choices=model_choices, visible=True, interactive=True)
318
+ else:
319
+ model_pretty = gr.update(choices=[], visible=False)
320
+
321
+ # Visibility settings for fields specific to each provider
322
+ is_aws = provider == "AWS Bedrock"
323
+ is_hf = provider == "HuggingFace"
324
+ return (
325
+ model_pretty,
326
+ gr.update(visible=is_aws, interactive=is_aws),
327
+ gr.update(visible=is_aws, interactive=is_aws),
328
+ gr.update(visible=is_aws, interactive=is_aws),
329
+ gr.update(visible=is_aws, interactive=is_aws),
330
+ gr.update(visible=is_hf, interactive=is_hf),
331
+ )
332
+
333
+
334
+ async def update_connection_status( # noqa: PLR0913
335
+ provider: str,
336
+ pretty_model: str,
337
+ mcp_list_state: Sequence[MutableCheckBoxGroupEntry] | None,
338
+ aws_access_key_textbox: str,
339
+ aws_secret_key_textbox: str,
340
+ aws_session_token_textbox: str,
341
+ aws_region_dropdown: str,
342
+ hf_token: str,
343
+ temperature: float,
344
+ max_tokens: int,
345
+ ) -> str:
346
+ """Update the connection status based on the selected provider and model."""
347
+ if not provider or not pretty_model:
348
+ return "❌ Please select a provider and model."
349
+ model_id = MODEL_OPTIONS.get(provider, {}).get(pretty_model)
350
+ if model_id:
351
+ if provider == "AWS Bedrock":
352
+ connection = await gr_connect_to_bedrock(
353
+ model_id,
354
+ aws_access_key_textbox,
355
+ aws_secret_key_textbox,
356
+ aws_session_token_textbox,
357
+ aws_region_dropdown,
358
+ mcp_list_state,
359
+ temperature,
360
+ max_tokens,
361
+ )
362
+ elif provider == "HuggingFace":
363
+ connection = await gr_connect_to_hf(model_id, hf_token, mcp_list_state)
364
+ elif provider == "Nebius":
365
+ connection = await gr_connect_to_nebius(model_id, hf_token, mcp_list_state)
366
+ else:
367
+ return "❌ Invalid provider"
368
+ return connection if connection else "❌ Invalid provider"
369
+
370
+
371
+ with gr.Blocks(
372
+ theme=gr.themes.Origin(primary_hue="teal", spacing_size="sm", font="sans-serif"),
373
+ title="TDAgent",
374
+ ) as gr_app, gr.Row():
375
+ with gr.Column(scale=1):
376
+ with gr.Accordion("πŸ”Œ MCP Servers", open=False):
377
+ mcp_list = MutableCheckBoxGroup(
378
+ values=[
379
+ MutableCheckBoxGroupEntry(
380
+ name="TDAgent tools",
381
+ value="https://agents-mcp-hackathon-tdagenttools.hf.space/gradio_api/mcp/sse",
382
+ ),
383
+ ],
384
+ label="MCP Servers",
385
+ )
386
+
387
+ with gr.Accordion("βš™οΈ Provider Configuration", open=True):
388
+ model_provider = gr.Dropdown(
389
+ choices=list(MODEL_OPTIONS.keys()),
390
+ value=None,
391
+ label="Select Model Provider",
392
  )
 
393
  aws_access_key_textbox = gr.Textbox(
394
  label="AWS Access Key ID",
395
  type="password",
396
  placeholder="Enter your AWS Access Key ID",
397
+ visible=False,
398
  )
399
  aws_secret_key_textbox = gr.Textbox(
400
  label="AWS Secret Access Key",
401
  type="password",
402
  placeholder="Enter your AWS Secret Access Key",
403
+ visible=False,
404
  )
 
 
 
 
 
 
 
405
  aws_region_dropdown = gr.Dropdown(
406
  label="AWS Region",
407
  choices=[
 
412
  "ap-southeast-1",
413
  ],
414
  value="eu-west-1",
415
+ visible=False,
416
+ )
417
+ aws_session_token_textbox = gr.Textbox(
418
+ label="AWS Session Token",
419
+ type="password",
420
+ placeholder="Enter your AWS session token",
421
+ visible=False,
422
+ )
423
+ hf_token = gr.Textbox(
424
+ label="HuggingFace Token",
425
+ type="password",
426
+ placeholder="Enter your Hugging Face Access Token",
427
+ visible=False,
428
+ )
429
+
430
+ with gr.Accordion("🧠 Model Configuration", open=True):
431
+ model_display_id = gr.Dropdown(
432
+ label="Select Model ID",
433
+ choices=[],
434
+ visible=False,
435
+ )
436
+ model_provider.change(
437
+ toggle_model_fields,
438
+ inputs=[model_provider],
439
+ outputs=[
440
+ model_display_id,
441
+ aws_access_key_textbox,
442
+ aws_secret_key_textbox,
443
+ aws_session_token_textbox,
444
+ aws_region_dropdown,
445
+ hf_token,
446
+ ],
447
+ )
448
+ # Initialize the temperature and max tokens based on model specifications
449
+ temperature = gr.Slider(
450
+ label="Temperature",
451
+ minimum=0.0,
452
+ maximum=1.0,
453
+ value=0.8,
454
+ step=0.1,
455
+ )
456
+ max_tokens = gr.Slider(
457
+ label="Max Tokens",
458
+ minimum=64,
459
+ maximum=4096,
460
+ value=512,
461
+ step=64,
462
  )
 
463
 
464
+ connect_btn = gr.Button("πŸ”Œ Connect to Model", variant="primary")
465
  status_textbox = gr.Textbox(label="Connection Status", interactive=False)
466
 
467
  connect_btn.click(
468
+ update_connection_status,
469
  inputs=[
470
+ model_provider,
471
+ model_display_id,
472
+ mcp_list.state,
473
  aws_access_key_textbox,
474
  aws_secret_key_textbox,
475
  aws_session_token_textbox,
476
  aws_region_dropdown,
477
+ hf_token,
478
+ temperature,
479
+ max_tokens,
480
  ],
481
  outputs=[status_textbox],
482
  )
483
 
484
+ with gr.Column(scale=2):
485
+ chat_interface = gr.ChatInterface(
486
+ fn=gr_chat_function,
487
+ type="messages",
488
+ examples=[], # Add examples if needed
489
+ title="πŸ‘©β€πŸ’» TDAgent",
490
+ description="This is a simple agent that uses MCP tools.",
491
+ )
492
 
493
 
494
  if __name__ == "__main__":