diff --git a/demo_agent/agents/dynamic_prompting.py b/demo_agent/agents/dynamic_prompting.py index 552aae3f..44d5f064 100644 --- a/demo_agent/agents/dynamic_prompting.py +++ b/demo_agent/agents/dynamic_prompting.py @@ -8,16 +8,17 @@ from textwrap import dedent from typing import Literal from warnings import warn + from browsergym.core.action.base import AbstractActionSet from browsergym.core.action.highlevel import HighLevelActionSet +from browsergym.core.action.python import PythonActionSet +from utils.llm_utils import ParseError from utils.llm_utils import ( count_tokens, image_to_jpg_base64_url, parse_html_tags_raise, ) -from browsergym.core.action.python import PythonActionSet -from utils.llm_utils import ParseError @dataclass @@ -36,7 +37,7 @@ class Flags: use_concrete_example: bool = True use_abstract_example: bool = False multi_actions: bool = False - action_space: Literal["python", "bid", "coord", "bid+coord"] = "bid" + action_space: Literal["python", "bid", "coord", "bid+coord", "bid+nav", "coord+nav", "bid+coord+nav"] = "bid" is_strict: bool = False # This flag will be automatically disabled `if not chat_model_args.has_vision()` use_screenshot: bool = True @@ -65,45 +66,6 @@ def from_dict(self, flags_dict): return Flags(**flags_dict) -BASIC_FLAGS = Flags( - use_html=True, - use_ax_tree=False, - drop_ax_tree_first=True, - use_thinking=False, - use_error_logs=True, - use_past_error_logs=False, - use_history=False, - use_action_history=False, - use_memory=False, - use_diff=False, - html_type="pruned_html", - use_concrete_example=False, - use_abstract_example=True, - multi_actions=False, - action_space="bid", - use_screenshot=True, -) - -ALL_TRUE_FLAGS = Flags( - use_html=True, - use_ax_tree=True, - drop_ax_tree_first=True, - use_thinking=True, - use_error_logs=True, - use_past_error_logs=True, - use_history=True, - use_action_history=True, - use_memory=True, - use_diff=True, - html_type="pruned_html", - use_concrete_example=True, - use_abstract_example=True, 
- multi_actions=True, - action_space="bid+coord", - use_screenshot=True, -) - - class PromptElement: """Base class for all prompt elements. Prompt elements can be hidden. diff --git a/demo_agent/run_demo.py b/demo_agent/run_demo.py index 2369faba..dadc647a 100644 --- a/demo_agent/run_demo.py +++ b/demo_agent/run_demo.py @@ -8,12 +8,6 @@ def parse_args(): parser = argparse.ArgumentParser(description="Run experiment with hyperparameters.") - parser.add_argument( - "--start_url", - type=str, - default="https://www.google.com", - help="Starting URL for the task.", - ) parser.add_argument( "--model_name", type=str, @@ -24,29 +18,25 @@ def parse_args(): "--task_name", type=str, default="openended", - help="Task name for the experiment. If 'openended', you need to specify a 'start_url'", + help="Name of the Browsergym task to run. If 'openended', you need to specify a 'start_url'", ) parser.add_argument( - "--slow_mo", type=int, default=500, help="Slow motion delay for the experiment." + "--start_url", + type=str, + default="https://www.google.com", + help="Starting URL (only for the openended task).", ) parser.add_argument( - "--enable_debug", - default="True", - help="Enable debug mode for the experiment. If False, it will be difficult to debug because of the nested try statements.", + "--slow_mo", type=int, default=500, help="Slow motion delay for the playwright actions." ) parser.add_argument( - "--headless", type=str2bool, default=False, help="Run in headless mode for the experiment." + "--headless", type=str2bool, default=False, help="Run the experiment in headless mode (hides the browser windows)." ) - - # BrowserGym Flags parser.add_argument( "--demo_mode", type=str2bool, default=True, - help="add visual effects when the agents performs actions.", - ) - parser.add_argument( - "--enable_chat", type=str2bool, default=True, help="Enable chat in the agent." 
+ help="Add visual effects when the agents performs actions.", ) parser.add_argument( "--use_html", type=str2bool, default=True, help="Use HTML in the agent's observation space." @@ -70,7 +60,7 @@ def parse_args(): "--action_space", type=str, default="bid", - choices=["python", "bid", "coord", "bid+coord"], + choices=["python", "bid", "coord", "bid+coord", "bid+nav", "coord+nav", "bid+coord+nav"], help="", ) parser.add_argument( @@ -79,8 +69,6 @@ def parse_args(): default=True, help="Use history in the agent's observation space.", ) - - # Agent Flags parser.add_argument( "--use_thinking", type=str2bool, @@ -94,7 +82,16 @@ def parse_args(): def main(): args = parse_args() - RESULTS_DIR = Path("./results") + task_kwargs={ + "viewport": {"width": 1500, "height": 1280}, + "slow_mo": args.slow_mo, + } + + if args.task_name == "openended": + task_kwargs.update({ + "start_url": args.start_url, + "wait_for_user_message": True, + }) exp_args = ExpArgs( agent_args=GenericAgentArgs( @@ -107,7 +104,7 @@ def main(): flags=Flags( use_html=args.use_html, use_ax_tree=args.use_ax_tree, - use_thinking=False, # "Enable the agent with a memory (scratchpad)." + use_thinking=args.use_thinking, # "Enable the agent with a memory (scratchpad)." use_error_logs=True, # "Prompt the agent with the error logs." use_memory=False, # "Enables the agent with a memory (scratchpad)." use_history=args.use_history, @@ -118,24 +115,18 @@ def main(): use_abstract_example=True, # "Prompt the agent with an abstract example." use_concrete_example=True, # "Prompt the agent with a concrete example." use_screenshot=args.use_screenshot, - enable_chat=args.enable_chat, + enable_chat=True, demo_mode=args.demo_mode, ), ), max_steps=100, # "Maximum steps for the experiment." 
task_seed=None, task_name=args.task_name, - task_kwargs={ - "start_url": args.start_url, - "wait_for_user_message": True, - "viewport": {"width": 1500, "height": 1280}, - "slow_mo": args.slow_mo, - }, - enable_debug=args.enable_debug, + task_kwargs=task_kwargs, headless=args.headless, ) - exp_args.prepare(RESULTS_DIR / "live_agent_tests") + exp_args.prepare(Path("./results")) exp_args.run() diff --git a/demo_agent/utils/exp_utils.py b/demo_agent/utils/exp_utils.py index 24e7f785..ac1b6520 100644 --- a/demo_agent/utils/exp_utils.py +++ b/demo_agent/utils/exp_utils.py @@ -9,24 +9,25 @@ import time import traceback import uuid +import pandas as pd + from datetime import datetime from pathlib import Path from warnings import warn from langchain_community.callbacks import get_openai_callback from langchain_community.callbacks.openai_info import OpenAICallbackHandler from contexttimer import Timer +from PIL import Image +from tqdm import tqdm +import gymnasium as gym import browsergym.miniwob # important, registers "browsergym/miniwob.*" gym environment import browsergym.workarena # important, registers "browsergym/workarena.*" gym environment import browsergym.webarena # important, registers "browsergym/webarena.*" gym environment -import gymnasium as gym -from PIL import Image -import pandas as pd -from tqdm import tqdm +from browsergym.core.chat import Chat from agents import AgentArgs from agents.base import Agent -import os from utils.llm_utils import count_messages_token, count_tokens @@ -84,7 +85,6 @@ class ExpArgs: max_steps: int = 10 headless: bool = True sleep_at_each_step: float = None - enable_debug: bool = True order: int = None # use to keep the original order the experiments were meant to be lancuhed. 
def prepare(self, savedir_base): @@ -141,6 +141,8 @@ def run(self): if action is None: break + send_chat_info(env.unwrapped.chat, action, step_info.agent_info) + step_info = StepInfo(step=step_info.step + 1) episode_info.append(step_info) step_info.perform_step(env, action, obs_processor=agent.preprocess_obs) @@ -153,7 +155,7 @@ def run(self): stack_trace = traceback.format_exc() warn(err_msg) - if _is_debugging() and self.enable_debug: + if _is_debugging(): raise finally: @@ -164,6 +166,13 @@ def run(self): logging.error(f"Error while closing the environment: {e}") +def send_chat_info(chat: Chat, action: str, agent_info: dict) -> None: # Report the agent's latest "think" text and action: log the combined message and post it to the chat with role "info". + info = {"think": agent_info.get("think", None), "action": action} # "think" is None when agent_info has no such key + msg = "\n\n".join([f"{key}:\n{val}" for key, val in info.items() if val is not None]) # entries whose value is None are left out of the message + logging.info(msg) + chat.add_message(role="info", msg=msg) + + @dataclasses.dataclass class StepStats: """Collects statistics about a step."""