Add mypy and make it happy
m-naumann committed Jan 16, 2024
1 parent a2dd321 commit 0f2ce12
Showing 3 changed files with 35 additions and 22 deletions.
3 changes: 3 additions & 0 deletions .github/workflows/ci.yml

@@ -26,6 +26,9 @@ jobs:
      - name: run pylint for mdp folder
        run: pylint src/behavior_generation_lecture_python/mdp --errors-only

+      - name: run mypy for mdp folder
+        run: mypy src/behavior_generation_lecture_python/mdp

      - name: test
        run: pytest

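The new step runs mypy over the mdp package with its default settings. To reproduce the check outside CI, mypy can be invoked with the same command line, or programmatically through its Python API. A minimal sketch, assuming mypy is installed from the dev extras added below; only the target path is taken from the workflow step above:

from mypy import api

# Run mypy on the same target as the CI step and fail loudly on type errors.
report, errors, exit_status = api.run(["src/behavior_generation_lecture_python/mdp"])
print(report)
assert exit_status == 0, "mypy found type errors"
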
3 changes: 2 additions & 1 deletion pyproject.toml

@@ -25,7 +25,8 @@ dev = [
    "black[jupyter]==22.3.0",
    "pytest",
    "pytest-cov>=3.0.0",
-    "pylint"
+    "pylint",
+    "mypy"
]
docs = [
    "mkdocs-material",
51 changes: 30 additions & 21 deletions src/behavior_generation_lecture_python/mdp/mdp.py

@@ -93,7 +93,7 @@ def __init__(
            for action in self.actions:
                if (state, action) not in transition_probabilities:
                    continue
-                total_prob = 0
+                total_prob = 0.0
                for prob, next_state in transition_probabilities[(state, action)]:
                    assert (
                        next_state in self.states
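
A note on the 0 → 0.0 change: mypy infers the type of total_prob from its seed literal, so starting from the int 0 and then accumulating float probabilities is reported as an incompatible assignment. A purely illustrative sketch (the accumulation line is assumed, the diff above only shows the loop head):

from typing import List

def check_distribution(probs: List[float]) -> None:
    total_prob = 0  # mypy infers int from the literal
    for prob in probs:
        total_prob += prob  # mypy: expression has type "float", variable has type "int"
    assert abs(total_prob - 1.0) < 1e-9

# Seeding with 0.0 instead makes the inferred type float and the error goes away.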

@@ -151,14 +151,17 @@ def sample_next_state(self, state, action) -> Any:
        return prob_per_transition[choice][1]


+GridState = Tuple[int, int]


class GridMDP(MDP):
    def __init__(
        self,
        grid: List[List[Union[float, None]]],
-        initial_state: Tuple[int, int],
-        terminal_states: Set[Tuple[int, int]],
+        initial_state: GridState,
+        terminal_states: Set[GridState],
        transition_probabilities_per_action: Dict[
-            Tuple[int, int], List[Tuple[float, Tuple[int, int]]]
+            GridState, List[Tuple[float, GridState]]
        ],
        restrict_actions_to_available_states: Optional[bool] = False,
    ) -> None:
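
The GridState alias shortens the deeply nested annotations above and gives grid coordinates a readable name; to mypy it is fully interchangeable with Tuple[int, int]. A small, hypothetical usage example (not part of the module):

from typing import List, Tuple

GridState = Tuple[int, int]

def manhattan_distance(a: GridState, b: GridState) -> int:
    # A plain alias is transparent: GridState and Tuple[int, int] are the same type.
    return abs(a[0] - b[0]) + abs(a[1] - b[1])

path: List[GridState] = [(0, 0), (1, 0), (1, 2)]
print(manhattan_distance(path[0], path[-1]))  # 3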

@@ -186,7 +189,9 @@ def __init__(
            for y in range(rows):
                if grid[y][x] is not None:
                    states.add((x, y))
-                    reward[(x, y)] = grid[y][x]
+                    reward_xy = grid[y][x]
+                    assert reward_xy is not None
+                    reward[(x, y)] = reward_xy

        transition_probabilities = {}
        for state in states:
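
This is the classic Optional-narrowing pattern: the grid cells are typed Union[float, None], and mypy does not narrow the repeated subscript expression grid[y][x] inside the is-not-None check, so assigning it directly into a dict of floats is rejected. Binding the cell to a local name and asserting it is not None narrows that name to float. A self-contained, illustrative sketch of the same pattern (the function name and shapes are made up):

from typing import Dict, List, Optional, Tuple

def collect_rewards(grid: List[List[Optional[float]]]) -> Dict[Tuple[int, int], float]:
    reward: Dict[Tuple[int, int], float] = {}
    for y in range(len(grid)):
        for x in range(len(grid[y])):
            if grid[y][x] is not None:
                cell = grid[y][x]        # to mypy this is still Optional[float]
                assert cell is not None  # narrows cell to float
                reward[(x, y)] = cell    # accepted
    return reward

print(collect_rewards([[0.0, None], [None, 1.0]]))  # {(0, 0): 0.0, (1, 1): 1.0}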

@@ -260,8 +265,11 @@ def _next_state_deterministic(
        return state


+StateValueTable = Dict[Any, float]


def expected_utility_of_action(
-    mdp: MDP, state: Any, action: Any, utility_of_states: Dict[Any, float]
+    mdp: MDP, state: Any, action: Any, utility_of_states: StateValueTable
) -> float:
    """Compute the expected utility of taking an action in a state.

@@ -283,7 +291,7 @@ def expected_utility_of_action(
    )


-def derive_policy(mdp: MDP, utility_of_states: Dict[Any, float]) -> Dict[Any, Any]:
+def derive_policy(mdp: MDP, utility_of_states: StateValueTable) -> Dict[Any, Any]:
    """Compute the best policy for an MDP given the utility of the states.

    Args:

@@ -310,7 +318,7 @@ def value_iteration(
    epsilon: float,
    max_iterations: int,
    return_history: Optional[bool] = False,
-) -> Union[Dict[Any, float], List[Dict[Any, float]]]:
+) -> Union[StateValueTable, List[StateValueTable]]:
    """Derive a utility estimate by means of value iteration.

    Args:

@@ -326,11 +334,11 @@ def value_iteration(
        The final utility estimate, if return_history is false. The
        history of utility estimates as list, if return_history is true.
    """
-    utility = {state: 0 for state in mdp.get_states()}
+    utility = {state: 0.0 for state in mdp.get_states()}
    utility_history = [utility.copy()]
    for _ in range(max_iterations):
        utility_old = utility.copy()
-        max_delta = 0
+        max_delta = 0.0
        for state in mdp.get_states():
            utility[state] = max(
                expected_utility_of_action(
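
Same mypy reasoning as for total_prob above, this time for a dict comprehension: the value type is inferred from the seed literal, so {state: 0 for ...} is a Dict[Any, int] and later assigning a float utility to it fails. Seeding with 0.0 matches the StateValueTable alias. A tiny illustration (not the module's code):

from typing import Any, Dict

states = ["a", "b"]

utility_int = {state: 0 for state in states}  # inferred as Dict[str, int]
# utility_int["a"] = 0.5  # mypy would reject: float assigned to an int value

utility: Dict[Any, float] = {state: 0.0 for state in states}  # matches StateValueTable
utility["a"] = 0.5  # fine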

@@ -348,8 +356,11 @@ def value_iteration(
    raise RuntimeError(f"Did not converge in {max_iterations} iterations")


+QTable = Dict[Tuple[Any, Any], float]


def best_action_from_q_table(
-    *, state: Any, available_actions: Set[Any], q_table: Dict[Tuple[Any, Any], float]
+    *, state: Any, available_actions: Set[Any], q_table: QTable
) -> Any:
    """Derive the best action from a Q table.

@@ -361,9 +372,9 @@ def best_action_from_q_table(
    Returns:
        The best action according to the Q table.
    """
-    available_actions = list(available_actions)
-    values = np.array([q_table[(state, action)] for action in available_actions])
-    action = available_actions[np.argmax(values)]
+    available_actions_list = list(available_actions)
+    values = np.array([q_table[(state, action)] for action in available_actions_list])
+    action = available_actions_list[np.argmax(values)]
    return action
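
The rename here (and in random_action below) is needed because mypy pins a parameter to its annotated type for the whole function body: rebinding available_actions: Set[Any] to a list is an incompatible assignment. A fresh name keeps both types honest. A hypothetical illustration:

from typing import Any, List, Set

def smallest(options: Set[Any]) -> Any:
    # options = sorted(options)  # mypy: List[Any] cannot be assigned to Set[Any]
    options_list: List[Any] = sorted(options)  # separate name, separate type
    return options_list[0]

print(smallest({3, 1, 2}))  # 1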


@@ -376,15 +387,13 @@ def random_action(available_actions: Set[Any]) -> Any:
    Returns:
        A random action.
    """
-    available_actions = list(available_actions)
-    num_actions = len(available_actions)
+    available_actions_list = list(available_actions)
+    num_actions = len(available_actions_list)
    choice = np.random.choice(num_actions)
-    return available_actions[choice]
+    return available_actions_list[choice]


-def greedy_value_estimate_for_state(
-    *, q_table: Dict[Tuple[Any, Any], float], state: Any
-) -> float:
+def greedy_value_estimate_for_state(*, q_table: QTable, state: Any) -> float:
    """Compute the greedy (best possible) value estimate for a state from the Q table.

    Args:

@@ -407,7 +416,7 @@ def q_learning(
    epsilon: float,
    iterations: int,
    return_history: Optional[bool] = False,
-) -> Dict[Tuple[Any, Any], float]:
+) -> Union[QTable, List[QTable]]:
    """Derive a value estimate for state-action pairs by means of Q learning.

    Args:
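
The corrected return annotation acknowledges that, with return_history set to true, q_learning returns the list of Q-table snapshots rather than a single table, so callers now see a Union and have to narrow it before use. A hedged sketch of call-site narrowing (the helper below is illustrative, not part of the module):

from typing import Any, Dict, List, Tuple, Union

QTable = Dict[Tuple[Any, Any], float]  # alias as introduced in this commit

def final_q_table(result: Union[QTable, List[QTable]]) -> QTable:
    # isinstance narrows the Union for mypy: list -> history, dict -> final table.
    if isinstance(result, list):
        return result[-1]  # last snapshot is the final estimate
    return result

print(final_q_table([{("s", "a"): 0.0}, {("s", "a"): 1.0}]))  # {('s', 'a'): 1.0}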
