Skip to content

Commit

Permalink
demo agent fixes and cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
gasse committed Mar 13, 2024
1 parent d3136b5 commit 032c387
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 81 deletions.
46 changes: 4 additions & 42 deletions demo_agent/agents/dynamic_prompting.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,16 +8,17 @@
from textwrap import dedent
from typing import Literal
from warnings import warn

from browsergym.core.action.base import AbstractActionSet
from browsergym.core.action.highlevel import HighLevelActionSet
from browsergym.core.action.python import PythonActionSet

from utils.llm_utils import ParseError
from utils.llm_utils import (
count_tokens,
image_to_jpg_base64_url,
parse_html_tags_raise,
)
from browsergym.core.action.python import PythonActionSet
from utils.llm_utils import ParseError


@dataclass
Expand All @@ -36,7 +37,7 @@ class Flags:
use_concrete_example: bool = True
use_abstract_example: bool = False
multi_actions: bool = False
action_space: Literal["python", "bid", "coord", "bid+coord"] = "bid"
action_space: Literal["python", "bid", "coord", "bid+coord", "bid+nav", "coord+nav", "bid+coord+nav"] = "bid"
is_strict: bool = False
# This flag will be automatically disabled `if not chat_model_args.has_vision()`
use_screenshot: bool = True
Expand Down Expand Up @@ -65,45 +66,6 @@ def from_dict(self, flags_dict):
return Flags(**flags_dict)


BASIC_FLAGS = Flags(
use_html=True,
use_ax_tree=False,
drop_ax_tree_first=True,
use_thinking=False,
use_error_logs=True,
use_past_error_logs=False,
use_history=False,
use_action_history=False,
use_memory=False,
use_diff=False,
html_type="pruned_html",
use_concrete_example=False,
use_abstract_example=True,
multi_actions=False,
action_space="bid",
use_screenshot=True,
)

ALL_TRUE_FLAGS = Flags(
use_html=True,
use_ax_tree=True,
drop_ax_tree_first=True,
use_thinking=True,
use_error_logs=True,
use_past_error_logs=True,
use_history=True,
use_action_history=True,
use_memory=True,
use_diff=True,
html_type="pruned_html",
use_concrete_example=True,
use_abstract_example=True,
multi_actions=True,
action_space="bid+coord",
use_screenshot=True,
)


class PromptElement:
"""Base class for all prompt elements. Prompt elements can be hidden.
Expand Down
55 changes: 23 additions & 32 deletions demo_agent/run_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,6 @@

def parse_args():
parser = argparse.ArgumentParser(description="Run experiment with hyperparameters.")
parser.add_argument(
"--start_url",
type=str,
default="https://www.google.com",
help="Starting URL for the task.",
)
parser.add_argument(
"--model_name",
type=str,
Expand All @@ -24,29 +18,25 @@ def parse_args():
"--task_name",
type=str,
default="openended",
help="Task name for the experiment. If 'openended', you need to specify a 'start_url'",
help="Name of the Browsergym task to run. If 'openended', you need to specify a 'start_url'",
)
parser.add_argument(
"--slow_mo", type=int, default=500, help="Slow motion delay for the experiment."
"--start_url",
type=str,
default="https://www.google.com",
help="Starting URL (only for the openended task).",
)
parser.add_argument(
"--enable_debug",
default="True",
help="Enable debug mode for the experiment. If False, it will be difficult to debug because of the nested try statements.",
"--slow_mo", type=int, default=500, help="Slow motion delay for the playwright actions."
)
parser.add_argument(
"--headless", type=str2bool, default=False, help="Run in headless mode for the experiment."
"--headless", type=str2bool, default=False, help="Run the experiment in headless mode (hides the browser windows)."
)

# BrowserGym Flags
parser.add_argument(
"--demo_mode",
type=str2bool,
default=True,
help="add visual effects when the agents performs actions.",
)
parser.add_argument(
"--enable_chat", type=str2bool, default=True, help="Enable chat in the agent."
help="Add visual effects when the agents performs actions.",
)
parser.add_argument(
"--use_html", type=str2bool, default=True, help="Use HTML in the agent's observation space."
Expand All @@ -70,7 +60,7 @@ def parse_args():
"--action_space",
type=str,
default="bid",
choices=["python", "bid", "coord", "bid+coord"],
choices=["python", "bid", "coord", "bid+coord", "bid+nav", "coord+nav", "bid+coord+nav"],
help="",
)
parser.add_argument(
Expand All @@ -79,8 +69,6 @@ def parse_args():
default=True,
help="Use history in the agent's observation space.",
)

# Agent Flags
parser.add_argument(
"--use_thinking",
type=str2bool,
Expand All @@ -94,7 +82,16 @@ def parse_args():
def main():
args = parse_args()

RESULTS_DIR = Path("./results")
task_kwargs={
"viewport": {"width": 1500, "height": 1280},
"slow_mo": args.slow_mo,
}

if args.task_name == "openended":
task_kwargs.update({
"start_url": args.start_url,
"wait_for_user_message": True,
})

exp_args = ExpArgs(
agent_args=GenericAgentArgs(
Expand All @@ -107,7 +104,7 @@ def main():
flags=Flags(
use_html=args.use_html,
use_ax_tree=args.use_ax_tree,
use_thinking=False, # "Enable the agent with a memory (scratchpad)."
use_thinking=args.use_thinking, # "Enable the agent with a memory (scratchpad)."
use_error_logs=True, # "Prompt the agent with the error logs."
use_memory=False, # "Enables the agent with a memory (scratchpad)."
use_history=args.use_history,
Expand All @@ -118,24 +115,18 @@ def main():
use_abstract_example=True, # "Prompt the agent with an abstract example."
use_concrete_example=True, # "Prompt the agent with a concrete example."
use_screenshot=args.use_screenshot,
enable_chat=args.enable_chat,
enable_chat=True,
demo_mode=args.demo_mode,
),
),
max_steps=100, # "Maximum steps for the experiment."
task_seed=None,
task_name=args.task_name,
task_kwargs={
"start_url": args.start_url,
"wait_for_user_message": True,
"viewport": {"width": 1500, "height": 1280},
"slow_mo": args.slow_mo,
},
enable_debug=args.enable_debug,
task_kwargs=task_kwargs,
headless=args.headless,
)

exp_args.prepare(RESULTS_DIR / "live_agent_tests")
exp_args.prepare(Path("./results"))
exp_args.run()


Expand Down
23 changes: 16 additions & 7 deletions demo_agent/utils/exp_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,24 +9,25 @@
import time
import traceback
import uuid
import pandas as pd

from datetime import datetime
from pathlib import Path
from warnings import warn
from langchain_community.callbacks import get_openai_callback
from langchain_community.callbacks.openai_info import OpenAICallbackHandler
from contexttimer import Timer
from PIL import Image
from tqdm import tqdm

import gymnasium as gym
import browsergym.miniwob # important, registers "browsergym/miniwob.*" gym environment
import browsergym.workarena # important, registers "browsergym/workarena.*" gym environment
import browsergym.webarena # important, registers "browsergym/webarena.*" gym environment
import gymnasium as gym
from PIL import Image
import pandas as pd
from tqdm import tqdm
from browsergym.core.chat import Chat

from agents import AgentArgs
from agents.base import Agent
import os

from utils.llm_utils import count_messages_token, count_tokens

Expand Down Expand Up @@ -84,7 +85,6 @@ class ExpArgs:
max_steps: int = 10
headless: bool = True
sleep_at_each_step: float = None
enable_debug: bool = True
order: int = None # use to keep the original order the experiments were meant to be lancuhed.

def prepare(self, savedir_base):
Expand Down Expand Up @@ -141,6 +141,8 @@ def run(self):
if action is None:
break

send_chat_info(env.unwrapped.chat, action, step_info.agent_info)

step_info = StepInfo(step=step_info.step + 1)
episode_info.append(step_info)
step_info.perform_step(env, action, obs_processor=agent.preprocess_obs)
Expand All @@ -153,7 +155,7 @@ def run(self):
stack_trace = traceback.format_exc()

warn(err_msg)
if _is_debugging() and self.enable_debug:
if _is_debugging():
raise

finally:
Expand All @@ -164,6 +166,13 @@ def run(self):
logging.error(f"Error while closing the environment: {e}")


def send_chat_info(chat: Chat, action: str, agent_info: dict):
info = {"think": agent_info.get("think", None), "action": action}
msg = "\n\n".join([f"{key}:\n{val}" for key, val in info.items() if val is not None])
logging.info(msg)
chat.add_message(role="info", msg=msg)


@dataclasses.dataclass
class StepStats:
"""Collects statistics about a step."""
Expand Down

0 comments on commit 032c387

Please sign in to comment.