# Spaces: Sleeping  (page status banner captured together with the source; not part of the program)
import base64
import html
import io
import logging
import multiprocessing
import os
import queue
from pathlib import Path

import gymnasium as gym
import numpy as np
from PIL import Image

from agent.checklist import generate_checklist
from agent.reward import get_ar_reward
from browser_agent import BrowserAgent
# Module-level logger; INFO is enabled so per-step progress messages are emitted.
logger = logging.getLogger(__name__)
logger.setLevel('INFO')

# Static CSS/HTML fragments used to render the reward-model cards and the
# trajectory view. They are read once at import time from ./templates.
# The encoding is pinned to UTF-8 so the result does not depend on the
# locale's default encoding of the host machine.
templates_dir = Path(__file__).parent / "templates"
CSS_RM_CARDS: str = (templates_dir / "rm_cards.css").read_text(encoding="utf-8")
CSS_TRAJECTORY: str = (templates_dir / "trajectory.css").read_text(encoding="utf-8")
CARD_HTML_TEMPLATE: str = (templates_dir / "card.html").read_text(encoding="utf-8")

# Reward-model endpoint configuration, forwarded to get_ar_reward() as
# base_url / model_name. Both variables are required: importing this module
# raises KeyError if either one is missing from the environment.
RM_BASE_URL = os.environ['RM_BASE_URL']
RM_MODEL_NAME = os.environ['RM_MODEL_NAME']
def return_state(state, screenshot=None):
    """Build a status-only update tuple.

    The result has the same five positions as the full updates yielded by
    run_agent -- (state, checklist, cards, screenshot, trajectory) -- with
    every slot except the status message and the screenshot set to None.
    """
    checklist = cards = trajectory = None
    return (state, checklist, cards, screenshot, trajectory)
def run_agent(instruction: str, model_name: str = "gpt-4o", start_url: str = "about:blank",
              use_html: bool = False, use_axtree: bool = True, use_screenshot: bool = False, max_steps: int = 20):
    """Run the browser agent on ``instruction`` and stream progress updates.

    This is a generator. Every yielded item is a 5-tuple
    ``(state, checklist, cards, screenshot, trajectory)``:

    * ``state``      -- Markdown status message.
    * ``checklist``  -- checklist returned by ``generate_checklist`` (``None``
      for status-only updates built with ``return_state``).
    * ``cards``      -- list of dicts with ``thought``/``action``/``feedback``/
      ``reward`` for the candidates of the current step, selected one first
      (``None`` for status-only updates).
    * ``screenshot`` -- current ``som_screenshot`` of the page.
    * ``trajectory`` -- shallow copy of the list of
      ``(screenshot, [{'action', 'reward'}, ...])`` entries, one per executed
      step (``None`` for status-only updates).

    Args:
        instruction: Task the agent should accomplish; also used as the goal
            of the environment and as the intent for the reward model.
        model_name: Model name forwarded to ``BrowserAgent``.
        start_url: URL the environment starts at.
        use_html, use_axtree, use_screenshot: Observation flags forwarded to
            ``BrowserAgent``.
        max_steps: Maximum number of policy steps before the rollout stops.

    The environment is closed when the generator finishes, fails, or is
    closed early by its consumer.
    """
    logger.info(f"Starting agent with instruction: {instruction}")
    logger.info(f"Configuration: model={model_name}, start_url={start_url}")

    trajectory = []
    trajectory_str = ''

    agent = BrowserAgent(
        model_name=model_name,
        use_html=use_html,
        use_axtree=use_axtree,
        use_screenshot=use_screenshot
    )

    # Initialize BrowserGym environment
    logger.info("Initializing BrowserGym environment")
    yield return_state("## Initializing BrowserGym environment...", None)

    env = None
    try:
        env = gym.make(
            "browsergym/openended",
            task_kwargs={
                "start_url": start_url,
                "goal": instruction,
            },
            wait_for_user_message=True
        )
        obs, info = env.reset()
        logger.info("Environment initialized")

        # Send user instruction to the environment
        logger.info("Sending user instruction to environment")
        obs, reward, terminated, truncated, info = env.step({
            "type": "send_msg_to_user",
            "message": instruction
        })
        processed_obs = agent.obs_preprocessor(obs)
        logger.info(f"Obs: {processed_obs.keys()}")
        logger.info(f"axtree_txt: {processed_obs['axtree_txt']}")

        # NOTE(review): the two reads below take 'som_screenshot' from the raw
        # observation, while the loop reads it from processed_obs -- confirm
        # that obs_preprocessor makes the key available on the raw dict.
        yield return_state("## Generating checklist...", obs['som_screenshot'])
        checklist = generate_checklist(intent=instruction, start_url=start_url, text_observation=processed_obs['axtree_txt'])

        # yield initial state
        current_screenshot = obs['som_screenshot'].copy()
        yield "## Rollout actions from policy...", checklist, [], current_screenshot, trajectory.copy()

        step_count = 0
        while step_count < max_steps:
            logger.info(f"Step {step_count}: Getting next action")
            # Get next action candidates from agent
            candidates, _ = agent.get_action(processed_obs)

            yield return_state("## Rewarding actions...", current_screenshot)

            # Score every candidate with the reward model against the checklist.
            current_url = processed_obs['open_pages_urls'][processed_obs['active_page_index'][0]]
            total_rewards, total_thoughts = get_ar_reward(
                dataset=[
                    {
                        'text_observation': processed_obs['axtree_txt'],
                        'intent': instruction,
                        'trajectory': trajectory_str,
                        'current_url': current_url,
                        'checklist': checklist,
                        'thought': cand['thought'],
                        'action': cand['action'],
                    } for cand in candidates
                ],
                base_url=RM_BASE_URL,
                model_name=RM_MODEL_NAME,
            )

            # Choose the action to execute. Candidate 0 is treated as the
            # "most frequent" action (presumably the agent returns candidates
            # ordered by sampling frequency -- verify against BrowserAgent).
            # It is kept unless another candidate beats it by more than 0.01.
            best_reward = max(total_rewards)
            diff_reward = abs(best_reward - total_rewards[0])
            if diff_reward <= 0.01:
                logger.info(f"diff_reward: {diff_reward} -> most frequent action")
                max_index = 0  # most frequent action
            else:
                logger.info(f"diff_reward: {diff_reward} -> highest reward")
                max_index = total_rewards.index(best_reward)  # highest reward

            # Put the chosen candidate first, the rest by descending reward.
            sorted_indices = sorted(
                enumerate(total_rewards),
                key=lambda x: (-1 if x[0] == max_index else 0, -x[1]),
            )
            new_order = [idx for idx, _ in sorted_indices]
            candidates = [candidates[idx] for idx in new_order]
            total_rewards = [total_rewards[idx] for idx in new_order]
            total_thoughts = [total_thoughts[idx] for idx in new_order]

            best_cand = candidates[0]
            agent.action_history.append(best_cand['response'])
            action = best_cand['action']

            # processing action
            step_info = {
                'thought': best_cand['thought'],
                'action': action
            }
            current_cards = [
                {
                    'thought': cand['thought'],
                    'action': cand['action'],
                    'feedback': feedback,
                    'reward': round(cand_reward, 2),
                }
                for cand, cand_reward, feedback in zip(candidates, total_rewards, total_thoughts)
            ]

            # Textual history of executed steps, passed to the reward model
            # as 'trajectory' on the following steps.
            trajectory_str += f'THOUGHT {step_count+1}: {step_info["thought"]}\nACTION {step_count+1}: {step_info["action"]}\n\n'

            # Execute action
            logger.info(f"Step {step_count}: Executing action: {action}")
            yield f"## Executing action: {action}", checklist, current_cards, current_screenshot, trajectory.copy()

            if action.startswith('send_msg_to_user'):
                # A message to the user ends the episode; the environment is
                # not stepped for it.
                terminated = True
                truncated = False
            else:
                obs, reward, terminated, truncated, info = env.step(action)

            # Store the screenshot the action was chosen on (kept as-is, not
            # encoded) together with the ranked candidates of this step.
            trajectory.append((
                processed_obs['som_screenshot'],
                [
                    {'action': cand['action'], 'reward': round(cand_reward, 2)}
                    for cand, cand_reward in zip(candidates, total_rewards)
                ],
            ))
            processed_obs = agent.obs_preprocessor(obs)
            current_screenshot = processed_obs['som_screenshot'].copy()

            # NOTE(review): this blank-line collapsing runs after the thought
            # was already appended to trajectory_str, and step_info is not
            # read again, so it has no observable effect. Confirm whether it
            # was meant to run before the trajectory string is built.
            while '\n\n' in step_info['thought']:
                step_info['thought'] = step_info['thought'].replace('\n\n', '\n')

            logger.info(f"Step {step_count}: Saved screenshot and updated trajectory")
            step_count += 1

            # yield by each step
            yield "## Rollout actions from policy...", checklist, current_cards, current_screenshot, trajectory.copy()

            if terminated or truncated:
                logger.info(f"Episode ended: terminated={terminated}, truncated={truncated}")
                yield return_state("## Episode ended", current_screenshot)
                break
    finally:
        # Always release the environment, including when initialisation
        # fails or the consumer closes the generator early.
        try:
            if env is not None:
                env.close()
        finally:
            logger.info("Finished")
def run_agent_worker(instruction, model_name, start_url, use_html, use_axtree, use_screenshot, max_steps, return_queue):
    """Worker function that runs the agent in a separate process and puts results in a queue.

    Every item produced by ``run_agent`` is forwarded to ``return_queue``
    unchanged. If the agent raises, one status-only error item is enqueued
    instead. In all cases ``None`` is enqueued last as the end-of-stream
    signal.

    Args:
        instruction, model_name, start_url, use_html, use_axtree,
        use_screenshot, max_steps: Forwarded positionally to ``run_agent``.
        return_queue: Queue the results are put on.
    """
    try:
        for result in run_agent(instruction, model_name, start_url, use_html, use_axtree, use_screenshot, max_steps):
            return_queue.put(result)
    except Exception as e:
        # logger.exception records the traceback together with the message.
        logger.exception(f"Error in agent worker process: {e}")
        # The consumer unpacks five fields (state, checklist, cards,
        # screenshot, trajectory); a status-only item leaves everything but
        # the message as None, matching the shape built by return_state.
        return_queue.put(("Error occurred in agent process", None, None, None, None))
    finally:
        # Signal that the process is done
        return_queue.put(None)
def run_agent_wrapper(instruction, model_name="gpt-4o", start_url="about:blank",
                      use_html=False, use_axtree=True, use_screenshot=False, max_steps=20):
    """Wrapper function that runs the agent in a separate process and yields results.

    Starts ``run_agent_worker`` in a daemon process and yields every item it
    puts on the queue until the ``None`` end signal arrives. The process is
    terminated (if still alive) and joined when the stream ends, when an
    error propagates, or when the consumer stops iterating early.

    Args:
        instruction, model_name, start_url, use_html, use_axtree,
        use_screenshot, max_steps: Forwarded to ``run_agent_worker``.
    """
    return_queue = multiprocessing.Queue()

    # Start the agent in a separate process
    p = multiprocessing.Process(
        target=run_agent_worker,
        args=(instruction, model_name, start_url, use_html, use_axtree, use_screenshot, max_steps, return_queue)
    )
    p.daemon = True  # Ensure process terminates when parent terminates
    p.start()

    try:
        # Get results from the queue and yield them
        while True:
            try:
                # Poll with a timeout instead of blocking forever, so a child
                # that died without sending the end signal is noticed.
                result = return_queue.get(timeout=1.0)
            except queue.Empty:
                if p.is_alive():
                    continue
                # The child is gone: pick up anything it flushed right
                # before exiting, otherwise stop waiting.
                try:
                    result = return_queue.get(timeout=0.1)
                except queue.Empty:
                    logger.error("Agent process exited without sending the end signal")
                    break
            if result is None:  # End signal
                break
            yield result
    finally:
        # Clean up
        if p.is_alive():
            p.terminate()
        p.join()
def _screenshot_to_img_src(img):
    """Return a value usable as an ``<img src>``: arrays and PIL images are
    inlined as base64 JPEG data URIs, anything else is passed through."""
    if isinstance(img, np.ndarray):
        img = Image.fromarray(img)
    if isinstance(img, Image.Image):
        buffer = io.BytesIO()
        img.save(buffer, format="JPEG")
        img_str = base64.b64encode(buffer.getvalue()).decode()
        return f"data:image/jpeg;base64,{img_str}"
    return img


def _build_rm_cards_html(rm_cards):
    """Render the reward-model candidate cards; the first card is marked best."""
    parts = [f"""
    <style>
    {CSS_RM_CARDS}
    </style>
    <div class="rm-cards-container">
    """]
    for idx, card in enumerate(rm_cards):
        # NOTE(review): the card fields are inserted as-is. The template is
        # an external file, so whether it expects plain text or markup cannot
        # be told from here -- confirm and escape if it expects plain text.
        parts.append(CARD_HTML_TEMPLATE.format(
            additional_class='top-candidate' if idx == 0 else '',
            k=idx+1,
            suffix='(best)' if idx == 0 else '',
            thought=card['thought'],
            action=card['action'],
            reward=card['reward'],
            feedback=card['feedback']
        ))
    parts.append("</div>")
    return "".join(parts)


def _build_trajectory_html(trajectory):
    """Render one panel per executed step: screenshot plus ranked candidates."""
    parts = [f"""
    <style>
    {CSS_TRAJECTORY}
    </style>
    <div class="trajectory-container">
    """]
    for idx, (step_img, cands) in enumerate(trajectory):
        img_src = _screenshot_to_img_src(step_img)
        parts.append(f"""
        <div class="step-container">
            <div class="step-header">Step {idx + 1}</div>
            <div class="step-content">
                <div class="step-image">
                    <img src="{img_src}" alt="Browser state">
                </div>
                <div class="step-info">
                    <div class="box-title">Action Candidates:</div>
                    <div class="action-candidates">
        """)
        # Display all candidates for this step; the first one was executed.
        for i, cand in enumerate(cands):
            # Action text is model output shown inside <pre>; escape it so
            # characters such as '<' or '&' are displayed, not parsed.
            action = html.escape(str(cand['action']), quote=False)
            reward = cand['reward']
            parts.append(f"""
                        <div class="candidate-box{' selected' if i == 0 else ''}">
                            <div class="box-title">
                                Action {i+1}{' (Selected)' if i == 0 else ''}
                                <span class="reward-text">Reward: {reward}</span>
                            </div>
                            <pre>{action}</pre>
                        </div>
            """)
        parts.append("""
                    </div>
                </div>
            </div>
        </div>
        """)
    parts.append("</div>")
    return "".join(parts)


def process_run(instruction, model_name, start_url):
    """Run the agent and stream UI-ready updates.

    Consumes ``run_agent_wrapper`` (AXTree observations only) and yields
    5-tuples ``(state, screenshot, checklist_view, rm_cards_html,
    trajectory_html)``:

    * For status-only updates (checklist is ``None``) the last known
      checklist and trajectory HTML are repeated and ``rm_cards_html`` is
      ``None``.
    * For full updates the reward-model cards and the trajectory are
      rendered to HTML.

    After the stream ends, the last state is yielded once more with the most
    recently rendered cards. Nothing is yielded if the agent produced no
    update at all.
    """
    # Use the wrapper function instead of directly calling run_agent
    trajectory_generator = run_agent_wrapper(
        instruction,
        model_name,
        start_url,
        use_html=False,
        use_axtree=True,
        use_screenshot=False
    )

    # Pre-bind everything the final yield reads, so a run that never reaches
    # a full update (e.g. an early failure) does not hit unbound names.
    state, screenshot = None, None
    received_any = False
    last_checklist_view, last_rm_cards_html, last_trajectory_html = None, None, None

    for state, checklist_view, rm_cards, screenshot, trajectory in trajectory_generator:
        received_any = True
        if checklist_view is None:
            yield state, screenshot, last_checklist_view, None, last_trajectory_html
            continue

        last_rm_cards_html = _build_rm_cards_html(rm_cards)
        last_trajectory_html = _build_trajectory_html(trajectory)
        last_checklist_view = checklist_view
        yield state, screenshot, last_checklist_view, last_rm_cards_html, last_trajectory_html

    if received_any:
        yield state, screenshot, last_checklist_view, last_rm_cards_html, last_trajectory_html