Skip to content

Commit

Permalink
rc2
Browse files Browse the repository at this point in the history
  • Loading branch information
gasse committed Mar 13, 2024
1 parent bf181b7 commit b10d327
Show file tree
Hide file tree
Showing 15 changed files with 110 additions and 31 deletions.
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
browsergym-core==0.1.0rc1
browsergym-core==0.1.0rc2
english-words>=2.0.1
numpy>=1.14
requests>=2.31
Expand Down
2 changes: 1 addition & 1 deletion src/browsergym/workarena/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
__version__ = "0.1.0rc1"
__version__ = "0.1.0rc2"

from browsergym.core.registration import register_task

Expand Down

Large diffs are not rendered by default.

Large diffs are not rendered by default.

Large diffs are not rendered by default.

Large diffs are not rendered by default.

Large diffs are not rendered by default.

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@
"Gisela Kosicki",
"Kyle Lindauer",
"Mildred Gallegas",
"Bob Dylan",
"Son Marschke",
"Veronica Radman",
"Kennith Peto",
"Terrell Rodda",
Expand Down Expand Up @@ -259,7 +259,7 @@
"Ron Kettering",
"Incident Manager",
"Andrew Jackson",
"Alexandre Drouin",
"Lacy Hyten",
"Kay Ganguli",
"Rosalyn Daulton",
"Lashonda Derouen",
Expand Down Expand Up @@ -304,7 +304,7 @@
"Essie Vaill",
"Marc Wanger",
"Kelli Varrato",
"bob too",
"Nadia Wilshire",
"David Miller",
"Marcie Shulz",
"Cathryn Nicolaus",
Expand Down Expand Up @@ -443,7 +443,7 @@
"Kurtis Asberry",
"Rene Dummermuth",
"Rosemarie Fifield",
"Allan Schwantd",
"Heath Vanalphen",
"Jasmin Gum",
"Gayla Geimer",
"Emilia Oxley",
Expand Down
16 changes: 15 additions & 1 deletion src/browsergym/workarena/tasks/form.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,7 +286,21 @@ def setup(self, seed: int, page: Page) -> tuple[str, dict]:
self.task_fields = config["task_fields"]
self.fields = config["fields"]

self.created_sysids = [] # Used to track an
self.created_sysids = []

# generate the goal
goal = (
f"Create a new {self.table_label} with "
+ " and ".join(
[
f'a value of "{self.template_record[f]}"' + f' for field "{self.fields[f]}"'
for f in self.task_fields
]
)
+ "."
)
info = {}
return goal, info

def _run_init_scripts(self, page: Page) -> None:
self._add_init_scripts_to_context_and_reload(
Expand Down
7 changes: 4 additions & 3 deletions src/browsergym/workarena/tasks/list.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,10 +216,11 @@ def setup(self, seed: int, page: Page) -> tuple[str, dict]:
config = self.random.choice(self.all_configs)
self.sort_fields = config["sort_fields"]
self.sort_dirs = config["sort_dirs"]
self.goal = config["goal"]
goal = config["goal"]

def get_goal(self) -> str:
return self.goal
info = {}

return goal, info

def _generate_random_config(self, seed: int, page: Page):
self.pre_setup(seed, page)
Expand Down
54 changes: 41 additions & 13 deletions src/browsergym/workarena/tasks/scripts/validate.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import json
import logging
import multiprocessing

from browsergym.workarena.config import (
Expand Down Expand Up @@ -109,45 +110,72 @@
}


@retry(stop=stop_after_attempt(3), reraise=True)
def validate_task(task_config, task_class, page=None):
    """Validates a task with a given configuration.

    The task is instantiated from ``task_config``, set up, solved with its own
    ``cheat`` method and then checked with ``validate``. An unsuccessful
    validation is attempted again, up to 4 times in total; any exception
    propagates to the ``retry`` decorator.

    Parameters:
    -----------
    task_config:
        Configuration passed to ``task_class`` as ``fixed_config``.
    task_class:
        Task class to instantiate.
    page:
        Page to run the task in. When None, a Chromium browser is launched for
        the duration of the call and closed before returning.

    Returns:
    --------
    tuple
        ``(task_successful, task_config)``.

    """
    num_attempts = 4
    task_successful = False
    browser = None
    p = None
    try:
        if page is None:
            # No page supplied by the caller: start our own browser.
            p = sync_playwright().start()
            browser = p.chromium.launch(slow_mo=1000)
            context = browser.new_context()
            page = context.new_page()

        for tries in range(1, num_attempts + 1):
            task = task_class(fixed_config=task_config)
            task.setup(page=page, seed=1)
            chat_messages = []
            task.cheat(page=page, chat_messages=chat_messages)
            page.wait_for_timeout(2000)
            task_successful = task.validate(page, chat_messages)[1]
            task.teardown()
            if task_successful:
                break
            logging.warning(
                f"Task {task_class.__name__} was not successful ({tries} / {num_attempts})"
            )
    finally:
        # Always release the browser and driver we started, including when the
        # task raised; otherwise each retry leaves another instance running.
        if browser is not None:
            browser.close()
        if p is not None:
            p.stop()

    return task_successful, task_config


def validate_configs(
    task_class, config_path, num_tasks: int = None, save_failed_tasks: bool = True, page=None
) -> list[dict]:
    """Validate that the configs are working.

    Loads the JSON list of configs at ``config_path`` and runs
    ``validate_task`` on each of them. A config counts as failed when
    validation reports no success or raises an exception.

    Parameters:
    -----------
    task_class:
        Task class the configs belong to.
    config_path:
        Path of the JSON file containing the list of configs.
    num_tasks: int
        When not None, only the first ``num_tasks`` configs are validated.
    save_failed_tasks: bool
        When True, the failing configs are written to
        ``failed_<task class name>.json`` so they can be tested.
    page:
        Forwarded to ``validate_task``.

    Returns:
    --------
    list[dict]
        The configs that failed validation.

    """
    with open(config_path, "r") as f:
        all_configs = json.load(f)

    if num_tasks is not None:
        all_configs = all_configs[:num_tasks]

    failed_tasks = []
    with tqdm(
        total=len(all_configs), desc=f"Validating {task_class.__name__} configs", ncols=150
    ) as pbar:
        for task_config in all_configs:
            try:
                success, task_config = validate_task(task_config, task_class, page)
            except Exception:
                # Best effort: keep validating the remaining configs, but keep
                # the traceback so the failure can be diagnosed.
                logging.exception(f"Exception while validating a {task_class.__name__} config")
                failed_tasks.append(task_config)
            else:
                print(f"success: {success}")
                if not success:
                    failed_tasks.append(task_config)
            pbar.update(1)

    if save_failed_tasks:
        # Save failed tasks to a JSON file
        with open(f"failed_{task_class.__name__}.json", "w") as f:
            json.dump(failed_tasks, f)

    return failed_tasks


if __name__ == "__main__":
Expand Down
2 changes: 1 addition & 1 deletion tests/test_random_config_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@
)
@pytest.mark.parametrize("task_entrypoint", RANDOMLY_CONFIGURALBE_TASKS)
@pytest.mark.parametrize("random_seed", range(3))
@pytest.mark.slow
@pytest.mark.skip(reason="Slows CI tests")
def test_cheat_from_random_config(task_entrypoint, random_seed: int, page: Page):
task = task_entrypoint()
task._generate_random_config(seed=random_seed, page=page)
Expand Down
2 changes: 1 addition & 1 deletion tests/test_task_general.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
before_sleep=lambda _: logging.info("Retrying due to a TimeoutError..."),
)
@pytest.mark.parametrize("task_entrypoint", ALL_WORKARENA_TASKS)
@pytest.mark.parametrize("random_seed", range(3))
@pytest.mark.parametrize("random_seed", range(1))
@pytest.mark.slow
def test_cheat(task_entrypoint, random_seed: int, page: Page):
task = task_entrypoint()
Expand Down
35 changes: 35 additions & 0 deletions tests/test_validate.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
"""
Tests that are not specific to any particular kind of task.
"""

import pytest
import json
import logging
import random

# bugfix: use same playwright instance in browsergym and pytest
from utils import setup_playwright
from playwright.sync_api import Page, TimeoutError
from tenacity import retry, stop_after_attempt, retry_if_exception_type
from browsergym.workarena.config import ORDER_APPLE_WATCH_TASK_CONFIG_PATH

from browsergym.workarena.tasks.service_catalog import OrderAppleWatchTask
from browsergym.workarena.tasks.scripts.validate import validate_configs


def _log_timeout_retry(_retry_state):
    """Log that the test is being retried after a TimeoutError."""
    logging.info("Retrying due to a TimeoutError...")


@retry(
    retry=retry_if_exception_type(TimeoutError),
    stop=stop_after_attempt(2),
    before_sleep=_log_timeout_retry,
    reraise=True,
)
def test_validate_configs(page: Page):
    """Validate two Apple Watch order configs and expect no failures.

    Runs ``validate_configs`` on the given ``page`` with ``num_tasks=2`` and
    ``save_failed_tasks=False``, then asserts that no failed config is
    returned.
    """
    failures = validate_configs(
        task_class=OrderAppleWatchTask,
        config_path=ORDER_APPLE_WATCH_TASK_CONFIG_PATH,
        num_tasks=2,
        save_failed_tasks=False,
        page=page,
    )
    assert not failures

0 comments on commit b10d327

Please sign in to comment.