-
Notifications
You must be signed in to change notification settings - Fork 39
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add - task implementation - evaluation - logic to write predictions to file to support hidden test test - readme/.toml Tests - Test for evaluation on dev set - Toy implementation task
- Loading branch information
Showing
16 changed files
with
720 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
# AssistantBench <> BrowserGym | ||
|
||
This package provides an implementation for using the [AssistantBench](https://assistantbench.github.io/) benchmark in BrowserGym. | ||
|
||
Because AssistantBench includes open-ended tasks, setup is extremely easy and simply requires installing the package. | ||
|
||
Please note that AssistantBench has a hidden test set, so test set predictions will need to be uploaded to the official [leaderboard](https://huggingface.co/spaces/AssistantBench/leaderboard). | ||
|
||
## Setting up | ||
|
||
- Install the package (this is still a wip) | ||
``` | ||
pip install browsergym-assistantbench | ||
``` | ||
|
||
- Run inference, e.g., run the following commands for demo on a simple toy task | ||
``` | ||
cd demo-agent | ||
python run_demo.py --task_name ab.imp.0 | ||
``` | ||
|
||
- Test set predictions will be saved to `browsergym/assistantbench/predictions/test.jsonl`. To evaluate on the official test set, update these predictions to the official [leaderboard](https://huggingface.co/spaces/AssistantBench/leaderboard). |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
from browsergym.core.registration import register_task | ||
from browsergym.assistantbench.src import task | ||
|
||
ALL_AB_TASK_IDS = [] | ||
|
||
# register a toy easy task for testing implemenation | ||
gym_id = f"ab.imp.0" | ||
register_task( | ||
gym_id, | ||
task.AssistantBenchTask, | ||
task_kwargs={"task_id": f"imp.0", | ||
"output_file_path": "browsergym/assistantbench/predictions/imp.jsonl"}, | ||
) | ||
ALL_AB_TASK_IDS.append(gym_id) | ||
|
||
# register the AssistantBench dev set | ||
for task_id in range(33): | ||
gym_id = f"ab.{task_id}" | ||
register_task( | ||
gym_id, | ||
task.AssistantBenchTask, | ||
task_kwargs={"task_id": task_id}, | ||
) | ||
ALL_AB_TASK_IDS.append(gym_id) | ||
|
||
# register the AssistantBench test set | ||
for task_id in range(181): | ||
gym_id = f"ab.test.{task_id}" | ||
register_task( | ||
gym_id, | ||
task.AssistantBenchTask, | ||
task_kwargs={"task_id": f'test.{task_id}', | ||
"output_file_path": "browsergym/assistantbench/predictions/test.jsonl"}, | ||
) | ||
ALL_AB_TASK_IDS.append(gym_id) | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
[build-system] | ||
requires = ["hatchling", "hatch-requirements-txt"] | ||
build-backend = "hatchling.build" | ||
|
||
[project] | ||
name = "browsergym-miniwob" | ||
description = "AssistantBench - BrowserGym" | ||
authors = [ | ||
{name = "Ori Yoran"}, | ||
{name = "Maxime Gasse"}, | ||
|
||
] | ||
readme = "README.md" | ||
requires-python = ">3.7" | ||
license = {text = "Apache-2.0"} | ||
classifiers = [ | ||
"Development Status :: 2 - Pre-Alpha", | ||
"Programming Language :: Python :: 3", | ||
"Operating System :: OS Independent", | ||
"Intended Audience :: Science/Research", | ||
"Topic :: Scientific/Engineering :: Artificial Intelligence", | ||
"License :: OSI Approved :: Apache Software License", | ||
] | ||
dynamic = ["dependencies", "version"] | ||
|
||
[project.urls] | ||
homepage = "https://github.com/ServiceNow/BrowserGym" | ||
|
||
[tool.hatch.version] | ||
path = "../core/src/browsergym/core/__init__.py" | ||
|
||
[tool.hatch.metadata.hooks.requirements_txt] | ||
files = ["requirements.txt"] | ||
|
||
[tool.hatch.build.targets.wheel] | ||
packages = ["src/browsergym"] |
Empty file.
Empty file.
68 changes: 68 additions & 0 deletions
68
browsergym/assistantbench/src/evaluation/evaluate_utils/evaluate_dicts.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,68 @@ | ||
from typing import Dict, List | ||
import numpy as np | ||
|
||
from browsergym.assistantbench.src.evaluation.evaluate_utils.utils import _align_bags | ||
|
||
|
||
def calculate_f1_score(precision, recall): | ||
if precision + recall == 0: | ||
return 0 # Handle the case to avoid division by zero | ||
return 2 * (precision * recall) / (precision + recall) | ||
|
||
|
||
def calc_recall(pred: Dict, gold: Dict, use_gold_for_eval: bool): | ||
from browsergym.assistantbench.src.evaluation.evaluate_utils.evaluate_factory import get_evaluator_from_gold_answer | ||
|
||
recall = [] | ||
for gold_key, gold_value in gold.items(): | ||
pred_value = pred.get(gold_key) | ||
gold_value = fix_number(gold_value) | ||
pred_value = fix_number(pred_value) | ||
if gold_key not in pred: | ||
recall.append(0) | ||
else: | ||
evaluator = ( | ||
get_evaluator_from_gold_answer(type(gold_value)) | ||
if use_gold_for_eval | ||
else get_evaluator_from_gold_answer(type(pred_value)) | ||
) | ||
if type(pred_value) != type(gold_value): | ||
recall.append(0) | ||
continue | ||
recall.append(evaluator(pred_value, gold_value)) | ||
avg_recall = np.average(recall) | ||
return avg_recall | ||
|
||
|
||
def fix_number(number): | ||
|
||
if type(number) == str: | ||
copy_ans = number | ||
copy_ans = ' '.join(' '.join(' '.join(copy_ans.split('$')).split('%')).split('sqft')).strip() | ||
copy_ans = copy_ans.strip() | ||
copy_ans = copy_ans.replace(',', '.') | ||
try: | ||
return float(copy_ans) | ||
except: | ||
return number | ||
elif type(number) == int: | ||
return float(number) | ||
else: | ||
return number | ||
|
||
def evaluate_pair_of_dicts(pred: Dict, gold: Dict): | ||
recall = calc_recall(pred, gold, True) | ||
precision = calc_recall(gold, pred, False) | ||
f1 = calculate_f1_score(precision, recall) | ||
return f1 | ||
|
||
|
||
def evaluate_dicts(pred: List[Dict], gold: List[Dict]): | ||
if not ( | ||
type(pred) == dict | ||
or len(pred) == 0 | ||
or (type(pred) == list and type(pred[0]) == dict) | ||
): | ||
return 0 | ||
max_alignment_scores = _align_bags(pred, gold, evaluate_pair_of_dicts) | ||
return np.average(max_alignment_scores) |
28 changes: 28 additions & 0 deletions
28
browsergym/assistantbench/src/evaluation/evaluate_utils/evaluate_factory.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
from typing import Union, Dict | ||
|
||
from browsergym.assistantbench.src.evaluation.evaluate_utils.evaluate_dicts import evaluate_dicts | ||
from browsergym.assistantbench.src.evaluation.evaluate_utils.evaluate_numbers import evaluate_numbers | ||
from browsergym.assistantbench.src.evaluation.evaluate_utils.evaluate_strings import evaluate_strings | ||
|
||
EvaluatorFactory = { | ||
"string": evaluate_strings, | ||
"number": evaluate_numbers, | ||
"json": evaluate_dicts, | ||
"string list": evaluate_strings, | ||
} | ||
|
||
EvaluatorFactoryFromType = { | ||
str: evaluate_strings, | ||
int: evaluate_numbers, | ||
float: evaluate_numbers, | ||
bool: evaluate_strings, | ||
list: evaluate_strings | ||
} | ||
|
||
|
||
def get_evaluator(evaluator: str): | ||
return EvaluatorFactory[evaluator] | ||
|
||
|
||
def get_evaluator_from_gold_answer(gold_answer: Union[str, int, float]): | ||
return EvaluatorFactoryFromType[gold_answer] |
33 changes: 33 additions & 0 deletions
33
browsergym/assistantbench/src/evaluation/evaluate_utils/evaluate_numbers.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
from typing import Union | ||
import numpy as np | ||
|
||
|
||
# Renamed calc_z function to distance_function_log | ||
def distance_function_log(pred: float, gold: float): | ||
if pred == gold == 0: | ||
return 1 | ||
if pred == 0: | ||
pred = 1e-4 | ||
if gold == 0: | ||
gold = 1e-4 | ||
if pred > gold: | ||
return max(0, 1 - np.log(pred / gold)) | ||
else: | ||
return max(0, 1 - np.log(gold / pred)) | ||
|
||
|
||
def evaluate_numbers(pred: Union[float, str], gold: float): | ||
res = None | ||
if type(pred) != float and type(pred) != int: | ||
try: | ||
pred = float(pred) | ||
except ValueError: | ||
res = 0 | ||
if type(gold) != float and type(gold) != int: | ||
try: | ||
gold = float(gold) | ||
except ValueError: | ||
res = 0 | ||
if res is None: | ||
res = distance_function_log(pred, gold) | ||
return res |
Oops, something went wrong.