-
Notifications
You must be signed in to change notification settings - Fork 2.7k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
23 changed files
with
2,352 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,69 @@ | ||
# Can't Do That Anymore | ||
|
||
We evaluate how well models can adapt to new rules of an environment, by applying novel reasoning to a task rather than following biases seen during their training. We task models to play a variant of chess and evaluate whether they can avoid making moves that are ordinarily legal, but are illegal in our variant which has slightly different rules. In our variant of chess, bishops move as knights do. | ||
|
||
## Usage | ||
|
||
Run with: | ||
|
||
``` | ||
oaieval <solver> cant_do_that_anymore | ||
``` | ||
|
||
We suggest using `generation/direct/gpt-3.5-turbo` or `generation/direct/gpt-4-turbo-preview` as default choices for `<solver>` | ||
|
||
For more examples of running this eval, see `scripts/run_experiments.sh` | ||
|
||
## Dataset | ||
|
||
For each model we evaluate, we construct a dataset where every sample contains a board position and the next move that was played, which is legal for the board position under the normal rules of chess, but illegal under the rules of our variant (i.e. the next move is a bishop moving diagonally). We call these types of moves *special moves*. We additionally filter to only include special moves that the model would have predicted under temperature=0 with the normal rules. We can use this to evaluate if models will change their predictions when given the variant rules, despite normally strongly predicting the move under the normal rules. | ||
|
||
Each model's dataset is automatically found and loaded upon running the eval. If a dataset doesn't exist for a particular solver, one will automatically be constructed for it. | ||
|
||
## Evaluation Process | ||
|
||
Samples from the dataset are evaluated one-by-one. Each sample contains a board position and the special move (next move). We prompt models to predict the next best move given the board position, separately under both the normal rules of chess and our variant's rules. We then measure whether the model predicted the special move from the sample under both rule settings. If the model was perfectly following the given rules, we'd expect it to never predict the special move under the variant's rules. | ||
|
||
To see how we prompt models under each rule setting, see `defaults.py`. | ||
|
||
## Metrics | ||
|
||
The below are the key metrics of this eval: | ||
|
||
| Metric | Interpretation | | ||
| --- | --- | | ||
| `variant_impact_factor` | The relative decrease in special move predictions when under the variant's rules, relative to the special move predictions under the normal rules. Lower is better, perfect score is -1. | ||
| `delta` | The absolute decrease in predicting the special move when under the variant's rules, relative to the models predictions under the normal rules. Lower is better. | ||
| `predicted_move_proportion` | The proportion of examples where the model predicted the special move under the normal rules. | ||
| `predicted_move_in_variant_proportion` | The proportion of examples where the model predicted the special move under the variant's rules. | ||
| `avg_num_previous_moves` | Average number of previous moves leading up to the board positions across all samples. | ||
| `std_num_previous_moves` | Standard deviation of the number of previous moves leading up to the board positions across all samples. | ||
|
||
## Variants | ||
|
||
| Variant | Notes | | ||
| --- | --- | | ||
| Default: `cant_do_that_anymore.all` | Default setting. Each dataset has 1000 samples. | | ||
| `cant_do_that_anymore.all_small` | A smaller version of the default setting. Each dataset has 100 samples. | | ||
| `cant_do_that_anymore.all_diagonal` | In this variant, we measure the proportion of samples (board positions) where the model will attempt to move a bishop diagonally. | | ||
|
||
## Custom Solvers | ||
|
||
We use two custom solvers for the base models we evaluate: `chess/generation/direct/gpt-3.5-turbo-instruct` and `chess/generation/direct/gpt-4-base`. These only generate up to four tokens, which prevents the base models from simulating the entire game. | ||
|
||
## Token Usage Estimates | ||
|
||
Below is a rough estimate of the total number of tokens used by the default variant: | ||
|
||
| Solver | Input Tokens | Output Tokens | Total Tokens | | ||
| --- | --- | --- | --- | | ||
| generation/direct/gpt-3.5-turbo | 375,000 | 10,000 | 385,000 | | ||
| generation/direct/gpt-4-turbo-preview | 375,000 | 10,000 | 385,000 | | ||
|
||
## Version History | ||
|
||
- v0: Initial version released | ||
|
||
## Contribution statement | ||
|
||
Eval design, implementation, and results evaluation was primarily conducted by Oliver Jaffe with contributions from Giulio Starace, under the guidance of (alphabetically by last-name) Steven Adler, James Aung, and Chan Jun Shern who scoped and managed the broader research project, including input on evaluation design, results analysis, and interpretation. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,244 @@ | ||
import copy | ||
from typing import Callable, Dict, Sequence | ||
|
||
from evals.elsuite.cant_do_that_anymore.chess.notation import NotationParser | ||
from evals.elsuite.cant_do_that_anymore.chess.pieces import Piece | ||
from evals.elsuite.cant_do_that_anymore.chess.utils import ( | ||
Move, | ||
get_other_player_id, | ||
get_path_between_coords, | ||
parse_piece, | ||
) | ||
|
||
|
||
class Board: | ||
""" | ||
Represents one board position. Is instantiated several times | ||
by the BoardController to simulate future boards after playing | ||
some moves. | ||
""" | ||
|
||
def __init__( | ||
self, | ||
board_state: Sequence[Sequence[str]], | ||
piece_id_to_instance: Dict[int, Piece], | ||
piece_str_to_id: Dict[str, int], | ||
piece_id_to_str: Dict[int, str], | ||
): | ||
self.board_state = board_state | ||
self.piece_id_to_instance = piece_id_to_instance | ||
self.piece_str_to_id = piece_str_to_id | ||
self.piece_id_to_str = piece_id_to_str | ||
|
||
def __str__(self) -> str: | ||
str_board = [["" for _ in range(8)] for _ in range(8)] | ||
|
||
for row_idx in range(len(self.board_state)): | ||
row = self.board_state[row_idx] | ||
for col_idx in range(len(row)): | ||
piece_color, piece_id = parse_piece(self.board_state, row_idx, col_idx) | ||
|
||
if piece_color != "E": | ||
white_piece = piece_color == "W" | ||
s = ( | ||
self.piece_id_to_instance[piece_id].white_render | ||
if white_piece | ||
else self.piece_id_to_instance[piece_id].black_render | ||
) | ||
else: | ||
s = "\u25A1" | ||
str_board[row_idx][col_idx] = s | ||
|
||
# Add letters on bottom | ||
str_board += [["-"] * 8] | ||
str_board += [["a", "b", "c", "d", "e", "f", "g", "h"]] | ||
|
||
# Add numbers on side | ||
str_board = [["|"] + row for row in str_board] | ||
numbers = list(range(8, 0, -1)) + [" ", " "] | ||
str_board = [[str(numbers[idx])] + row for (idx, row) in enumerate(str_board)] | ||
|
||
# Render as string | ||
str_board = "\n".join([" ".join(row) for row in str_board]) | ||
return str_board | ||
|
||
def _update_board(self, move: Move): | ||
""" | ||
Updates board_state according to given move. This move must have previously been checked | ||
to be legal. Edge cases for moves that: | ||
1) Take pieces at other positions where this piece isn't moving (en passant) | ||
2) Move two pieces (castling) | ||
3) Change the id of the piece (promotion) | ||
""" | ||
start_coord, target_coord = move.start_coord, move.target_coord | ||
piece_color, piece_id = parse_piece(self.board_state, start_coord[0], start_coord[1]) | ||
target_piece_color, target_piece_id = parse_piece( | ||
self.board_state, target_coord[0], target_coord[1] | ||
) | ||
|
||
# En passant | ||
if piece_id == 0 and target_piece_color == "E": | ||
dy = target_coord[1] - start_coord[1] | ||
target_en_passant_piece = [start_coord[0], start_coord[1] + dy] | ||
self.board_state[target_en_passant_piece[0]][target_en_passant_piece[1]] = "E" | ||
|
||
# Castling | ||
if move.castling: | ||
path = get_path_between_coords(start_coord, target_coord) | ||
rook_tile = path[0] | ||
self.board_state[rook_tile[0]][rook_tile[1]] = f"{piece_color}3" | ||
|
||
kingside = target_coord[1] <= 4 | ||
old_rook_tile = [start_coord[0], 0] if kingside else [start_coord[0], 7] | ||
self.board_state[old_rook_tile[0]][old_rook_tile[1]] = "E" | ||
|
||
# Move piece | ||
self.board_state[start_coord[0]][start_coord[1]] = "E" | ||
self.board_state[target_coord[0]][target_coord[1]] = f"{piece_color}{piece_id}" | ||
|
||
# Promotion | ||
if move.promotion is not None: | ||
self.board_state[target_coord[0]][target_coord[1]] = f"{piece_color}{move.promotion}" | ||
|
||
def _get_player_moves(self, player_id: str, previous_moves: Sequence[Move]) -> Sequence[Move]: | ||
""" | ||
Returns all possible moves by pieces for a player. Doesn't filter out moves that | ||
result in the king being placed under check | ||
""" | ||
moves = [] | ||
for row_idx in range(len(self.board_state)): | ||
row = self.board_state[row_idx] | ||
for col_idx in range(len(row)): | ||
piece_color, piece_id = parse_piece(self.board_state, row_idx, col_idx) | ||
if piece_color != player_id: | ||
continue | ||
|
||
piece = self.piece_id_to_instance[piece_id] | ||
possible_piece_moves = piece.get_piece_moves( | ||
self.board_state, player_id, [row_idx, col_idx], previous_moves | ||
) | ||
moves += possible_piece_moves | ||
|
||
return moves | ||
|
||
def _is_king_in_check(self, player_id: str) -> bool: | ||
other_player_id = get_other_player_id(player_id) | ||
|
||
other_player_moves = self._get_player_moves(other_player_id, []) | ||
king_capturing_moves = self._filter_for_king_capturing_moves(other_player_moves, player_id) | ||
return len(king_capturing_moves) != 0 | ||
|
||
def _filter_for_king_capturing_moves( | ||
self, moves: Sequence[Move], king_color: str | ||
) -> Sequence[Move]: | ||
king_capturing_moves = [] | ||
for move in moves: | ||
piece_color, piece_id = parse_piece( | ||
self.board_state, move.target_coord[0], move.target_coord[1] | ||
) | ||
if piece_color == king_color and piece_id == 5: | ||
king_capturing_moves.append(move) | ||
|
||
return king_capturing_moves | ||
|
||
|
||
class BoardController: | ||
""" | ||
Manages a single game of chess. Contains logic to find all legal | ||
moves for a particular player and update the internal board according | ||
to a given move. Maintains one Board obj to represent the true state of play | ||
""" | ||
|
||
def __init__( | ||
self, | ||
board_init: Callable[..., Sequence[Sequence[str]]], | ||
piece_id_to_instance: Dict[int, Piece], | ||
piece_str_to_id: Dict[str, int], | ||
piece_id_to_str: Dict[int, str], | ||
notation_parser: NotationParser, | ||
): | ||
self.board = Board(board_init(), piece_id_to_instance, piece_str_to_id, piece_id_to_str) | ||
self.notation_parser = notation_parser | ||
|
||
self.previous_moves = [] | ||
|
||
def __str__(self) -> str: | ||
return self.board.__str__() | ||
|
||
def update_board(self, move: str): | ||
""" | ||
Parses move, updates the internal board state, then stores the move | ||
since knowing previous moves is necessary for En Passant and castling | ||
""" | ||
move = self.notation_parser._str_to_move(move, self.board.board_state) | ||
self.board._update_board(move) | ||
self.previous_moves.append(move) | ||
|
||
def get_player_legal_moves(self, player_id: str) -> Sequence[str]: | ||
""" | ||
Gets all legal moves for a player with the given player_id, returned in | ||
the notation this object was initialised with | ||
""" | ||
legal_moves = self.board._get_player_moves(player_id, self.previous_moves) | ||
legal_moves = self._filter_to_prevent_pinning(legal_moves, player_id) | ||
|
||
legal_moves = [ | ||
self.notation_parser._move_to_str(i, self.board.board_state) for i in legal_moves | ||
] | ||
return legal_moves | ||
|
||
def _filter_to_prevent_pinning(self, moves: Sequence[Move], player_id: str) -> Sequence[Move]: | ||
""" | ||
Filter out moves that would result in the king being pinned, or the king moving over a pinned | ||
position when castling | ||
""" | ||
|
||
def _is_valid_castling(move: Move) -> bool: | ||
if self.board._is_king_in_check(player_id): | ||
return False | ||
|
||
# Check that the king won't move over an attacked position | ||
dy = (move.target_coord[1] - move.start_coord[1]) / abs( | ||
move.target_coord[1] - move.start_coord[1] | ||
) | ||
king_path = get_path_between_coords( | ||
move.start_coord, [move.target_coord[0], move.target_coord[1] + dy] | ||
) | ||
|
||
not_pinned_along_path = [] | ||
for coord in king_path: | ||
simulated_board = copy.deepcopy(self.board) | ||
simulated_board._update_board( | ||
Move(move.start_coord, coord, promotion=None, castling=False) | ||
) | ||
pinned = simulated_board._is_king_in_check(player_id) | ||
not_pinned_along_path.append(not pinned) | ||
|
||
if all(not_pinned_along_path): | ||
return True | ||
|
||
return False | ||
|
||
filtered_moves = [] | ||
for move in moves: | ||
if move.castling and _is_valid_castling(move): | ||
filtered_moves.append(move) | ||
elif not move.castling: | ||
simulated_board = copy.deepcopy(self.board) | ||
simulated_board._update_board(move) | ||
if not simulated_board._is_king_in_check(player_id): | ||
filtered_moves.append(move) | ||
|
||
return filtered_moves | ||
|
||
def _is_checkmate(self, player_id: str) -> bool: | ||
legal_moves = self.get_player_legal_moves(player_id) | ||
if len(legal_moves) == 0 and self.board._is_king_in_check(player_id): | ||
return True | ||
return False | ||
|
||
def _is_stalemate(self, player_id: str) -> bool: | ||
legal_moves = self.get_player_legal_moves(player_id) | ||
if len(legal_moves) == 0 and not self.board._is_king_in_check(player_id): | ||
return True | ||
return False |
Oops, something went wrong.