Skip to content

Commit

Permalink
Add MDP policy gradient (#14)
Browse files Browse the repository at this point in the history
Co-authored-by: Fabian Konstantinidis <[email protected]>
Co-authored-by: Maximilian Naumann <[email protected]>
Co-authored-by: Naumann Maximilian (CR/AIR4.2) <[email protected]>
  • Loading branch information
4 people authored Feb 5, 2024
1 parent 5733399 commit 4a509da
Show file tree
Hide file tree
Showing 7 changed files with 536 additions and 23 deletions.
30 changes: 15 additions & 15 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -12,42 +12,42 @@ jobs:
steps:
- uses: actions/checkout@v2

- name: set up python
- name: Set up python
uses: actions/setup-python@v2
with:
python-version: 3.7
python-version: 3.8

- name: set up env
- name: Set up env
run: python -m pip install -e .[docs,dev]

- name: run black
- name: Run black
run: black --check .

- name: run isort
- name: Run isort
run: isort .

- name: run pylint for mdp folder
- name: Run pylint for mdp folder
run: pylint src/behavior_generation_lecture_python/mdp --errors-only

- name: run mypy for mdp folder
- name: Run mypy for mdp folder
run: mypy src/behavior_generation_lecture_python/mdp

- name: test
- name: Test
run: |
export DISPLAY=:99
Xvfb :99 &
pytest
- name: check coverage
- name: Check coverage
run: |
export DISPLAY=:99
Xvfb :99 &
pytest --cov=src --cov-fail-under=85
- name: copy notebooks to docs folder
- name: Copy notebooks to docs folder
run: cp -r notebooks/* docs/notebooks

- name: build docs
- name: Build docs
run: mkdocs build

deploy-pages:
Expand All @@ -58,15 +58,15 @@ jobs:
steps:
- uses: actions/checkout@v2

- name: set up python
- name: Set up python
uses: actions/setup-python@v2
with:
python-version: 3.7
python-version: 3.8

- name: set up env
- name: Set up env
run: python -m pip install -e .[docs]

- name: copy notebooks to docs folder
- name: Copy notebooks to docs folder
run: cp -r notebooks/* docs/notebooks

- run: mkdocs gh-deploy --force
Expand Down
2 changes: 0 additions & 2 deletions mkdocs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,6 @@ plugins:
rendering:
show_if_no_docstring: true
show_signature_annotations: true
watch:
- src
- gen-files:
scripts:
- docs/gen_ref_pages.py
Expand Down
250 changes: 250 additions & 0 deletions notebooks/mdp_policy_gradient.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,250 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from behavior_generation_lecture_python.mdp.policy import CategorialPolicy\n",
"from behavior_generation_lecture_python.utils.grid_plotting import (\n",
" make_plot_policy_step_function,\n",
")\n",
"from behavior_generation_lecture_python.mdp.mdp import (\n",
" GridMDP,\n",
" policy_gradient,\n",
" derive_deterministic_policy,\n",
" GRID_MDP_DICT,\n",
" HIGHWAY_MDP_DICT,\n",
" LC_RIGHT_ACTION,\n",
" STAY_IN_LANE_ACTION,\n",
")\n",
"\n",
"HIGHWAY_MDP_DICT[\"restrict_actions_to_available_states\"] = False"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## TOY EXAMPLE"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"grid_mdp = GridMDP(**GRID_MDP_DICT)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"policy = CategorialPolicy(\n",
" sizes=[len(grid_mdp.initial_state), 32, len(grid_mdp.actions)],\n",
" actions=list(grid_mdp.actions),\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"model_checkpoints = policy_gradient(\n",
" mdp=grid_mdp,\n",
" policy=policy,\n",
" iterations=100,\n",
" return_history=True,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"policy_array = [\n",
" derive_deterministic_policy(mdp=grid_mdp, policy=model)\n",
" for model in model_checkpoints\n",
"]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"plot_policy_step_grid_map = make_plot_policy_step_function(\n",
" columns=4, rows=3, policy_over_time=policy_array\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"mkdocs_flag = True\n",
"if mkdocs_flag:\n",
" import ipywidgets\n",
" from IPython.display import display\n",
"\n",
" iteration_slider = ipywidgets.IntSlider(\n",
" min=0, max=len(model_checkpoints) - 1, step=1, value=0\n",
" )\n",
" w = ipywidgets.interactive(plot_policy_step_grid_map, iteration=iteration_slider)\n",
" display(w)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"plot_policy_step_grid_map(100)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## HIGHWAY EXAMPLE"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"if False:\n",
" # we will change this to true later on, to see the effect\n",
" HIGHWAY_MDP_DICT[\"transition_probabilities_per_action\"][LC_RIGHT_ACTION] = [\n",
" (0.4, LC_RIGHT_ACTION),\n",
" (0.6, STAY_IN_LANE_ACTION),\n",
" ]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"highway_mdp = GridMDP(**HIGHWAY_MDP_DICT)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"policy = CategorialPolicy(\n",
" sizes=[len(highway_mdp.initial_state), 32, len(highway_mdp.actions)],\n",
" actions=list(highway_mdp.actions),\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"model_checkpoints = policy_gradient(\n",
" mdp=highway_mdp,\n",
" policy=policy,\n",
" iterations=200,\n",
" return_history=True,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"policy_array = [\n",
" derive_deterministic_policy(mdp=highway_mdp, policy=model)\n",
" for model in model_checkpoints\n",
"]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"plot_policy_step_grid_map = make_plot_policy_step_function(\n",
" columns=10, rows=4, policy_over_time=policy_array\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"if mkdocs_flag:\n",
" import ipywidgets\n",
" from IPython.display import display\n",
"\n",
" iteration_slider = ipywidgets.IntSlider(\n",
" min=0, max=len(model_checkpoints) - 1, step=1, value=0\n",
" )\n",
" w = ipywidgets.interactive(plot_policy_step_grid_map, iteration=iteration_slider)\n",
" display(w)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"plot_policy_step_grid_map(200)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3.11.5 ('kit_vorlesung_tutorial')",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.18"
},
"vscode": {
"interpreter": {
"hash": "c55583abd569aed2a1a6538892df4383b19c955ebf68dd4bc0814f5cb22bab0c"
}
}
},
"nbformat": 4,
"nbformat_minor": 4
}
5 changes: 3 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ name = "behavior_generation_lecture_python"
version = "0.0.2"
description = "Python code for the respective lecture at KIT"
readme = "README.md"
requires-python = ">=3.7"
requires-python = ">=3.8"
license = {file = "LICENSE"}
authors = [
{name = "Organizers of the lecture 'Verhaltensgenerierung für Fahrzeuge' at KIT" }
Expand All @@ -17,7 +17,8 @@ dependencies = [
"matplotlib>=2.2.4",
"scipy",
"jupyter",
"python-statemachine"
"python-statemachine",
"torch"
]

[project.optional-dependencies]
Expand Down
Loading

0 comments on commit 4a509da

Please sign in to comment.