hungry-geese-demo

Project Title

  • Imitation Learning and Hyperparameter Search with Wandb

Overview

  • This repository contains an imitation learning implementation that trains an agent for the Kaggle game "Hungry Geese". The game requires a policy to navigate the board, consuming food while avoiding collisions with opposing geese. We use a custom CNN model named GeeseNet, originally designed by yuricat and kyazuki, to capture spatial information from the game board.
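
For orientation, below is a minimal, illustrative sketch of a GeeseNet-style network in PyTorch. It is not the repository's actual code: it assumes the standard 7×11 Hungry Geese board (which wraps around at the edges, hence the torus-style padding) and an arbitrary 17-channel observation encoding; see model.py for the real definition.

import torch
import torch.nn as nn
import torch.nn.functional as F

class TorusConv2d(nn.Module):
    # Convolution that pads by wrapping around the board edges,
    # matching the toroidal topology of the Hungry Geese board.
    def __init__(self, in_ch, out_ch, kernel=3):
        super().__init__()
        self.pad = kernel // 2
        self.conv = nn.Conv2d(in_ch, out_ch, kernel)
        self.bn = nn.BatchNorm2d(out_ch)
    def forward(self, x):
        # Tile the opposite edge onto each side before convolving.
        x = torch.cat([x[:, :, -self.pad:], x, x[:, :, :self.pad]], dim=2)
        x = torch.cat([x[..., -self.pad:], x, x[..., :self.pad]], dim=3)
        return self.bn(self.conv(x))

class GeeseNetSketch(nn.Module):
    def __init__(self, layers=12, filters=32, in_ch=17):
        super().__init__()
        self.stem = TorusConv2d(in_ch, filters)
        self.blocks = nn.ModuleList([TorusConv2d(filters, filters) for _ in range(layers)])
        self.policy = nn.Linear(filters, 4)  # NORTH, EAST, SOUTH, WEST
    def forward(self, x):  # x: (batch, in_ch, 7, 11)
        h = F.relu(self.stem(x))
        for block in self.blocks:
            h = F.relu(h + block(h))  # residual connection per layer
        return self.policy(h.mean(dim=(2, 3)))  # pool over the board, emit action logits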

Prerequisites

  • Wandb account (for tracking experiments and hyperparameter search)

Repository Contents

  • training.py - The main script to train the model. It includes functions for training and validating the model, parsing hyperparameters, and creating DataLoaders for the training and validation sets.

  • get_data.py - Script to download and extract the training dataset.

  • data_processing.py - Contains functions to process the downloaded dataset and return train and validation sets.

  • model.py - Defines the GeeseNet model and a function to create a submission file based on the trained model.

  • utils.py - Contains helper functions like creating folders, getting path lists, etc.

  • visualization.py - Contains a function to create a GIF from a submission for visualization purposes.

  • optuna_config_hungry_geese.yaml - Sample configuration file for hyperparameter optimization via W&B Launch with Optuna.

Getting Started

You can start by cloning the repository:

git clone https://github.com/nejumi/hungry-geese-demo

Next, install the required Python packages. It's recommended to create a fresh Python environment; once it is activated, install the packages with:

cd hungry-geese-demo
pip install -r requirements.txt

Training

You can train the model by running the training.py script:

cd src
python training.py

You can provide the following optional arguments:

  • --layers: number of layers in the GeeseNet model.
  • --filters: number of filters in each convolutional layer in the GeeseNet model.
  • --batch_size: the batch size used for training.
  • --data_folder: the path of the folder where the training data is located.
  • --val_size: the size of the validation set.
  • --n_epochs: the number of training epochs per chunk.
  • --chunk_size: the number of samples in each chunk of the training data.
  • --chunk_num: the number of chunks to be used for training.
  • --project: project name for wandb.
  • --entity: entity name for wandb.

For example:

python training.py --layers 16 --filters 16 --batch_size 4096
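
For a sense of how these flags typically fit together, here is a simplified, illustrative sketch of a chunked training loop in the spirit of training.py. The defaults are made up, and load_chunk and run_epoch are hypothetical placeholders, not the script's actual API:

import argparse
import wandb

parser = argparse.ArgumentParser()
parser.add_argument("--layers", type=int, default=12)      # defaults here are illustrative
parser.add_argument("--filters", type=int, default=32)
parser.add_argument("--batch_size", type=int, default=2048)
parser.add_argument("--n_epochs", type=int, default=1)
parser.add_argument("--chunk_size", type=int, default=100000)
parser.add_argument("--chunk_num", type=int, default=10)
parser.add_argument("--project", type=str, default="hungry-geese")
parser.add_argument("--entity", type=str, default=None)
args = parser.parse_args()

# Every hyperparameter is recorded in the run config for later comparison.
wandb.init(project=args.project, entity=args.entity, config=vars(args))

for chunk_idx in range(args.chunk_num):
    # load_chunk is a hypothetical placeholder: it would yield a DataLoader
    # over the next chunk_size samples of the dataset.
    train_loader = load_chunk(chunk_idx, args.chunk_size, args.batch_size)
    for epoch in range(args.n_epochs):   # n_epochs passes per chunk
        loss = run_epoch(train_loader)   # hypothetical placeholder
        wandb.log({"chunk": chunk_idx, "loss": loss})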

Hyperparameter Optimization by Launch with Optuna

W&B Sweeps can be used to optimize hyperparameters, and running them on Launch makes the search more scalable. Please refer to the W&B documentation for details on the settings. For example:

wandb launch-sweep optuna_config_hungry_geese.yaml -q "your-queue-name" -p your-project-name -e your-entity-name

Adjust the YAML file, including the name of the job, to match your requirements.
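
The search space inside that file follows the same schema as a standard W&B sweep configuration. As a rough illustration only (the parameter ranges below are invented, and the Launch/Optuna scheduler fields from the sample file are omitted), the equivalent search space expressed as a Python dict looks like this:

# Illustrative only: invented ranges, standard W&B sweep schema.
sweep_space = {
    "metric": {"name": "val_loss", "goal": "minimize"},
    "parameters": {
        "layers": {"values": [8, 12, 16]},
        "filters": {"values": [16, 32, 64]},
        "batch_size": {"values": [1024, 2048, 4096]},
    },
}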

Results

The training script uses the Weights & Biases (wandb) platform to track the model's performance. After each epoch, the script logs the model's loss, accuracy, and win rate, along with a GIF of a self-match episode played by the agent. You can visualize these results in the wandb UI.
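
Under the hood, this boils down to a handful of wandb calls along the following lines (a minimal sketch with dummy values; the metric names and the GIF file name are illustrative, not necessarily those used by training.py):

import wandb

wandb.init(project="hungry-geese")
# Scalar metrics, logged once per epoch (dummy values shown here).
wandb.log({"loss": 0.42, "accuracy": 0.81, "win_rate": 0.55})
# wandb.Video uploads the rendered self-match episode to the run page.
wandb.log({"episode": wandb.Video("episode.gif")})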

License

This project is licensed under the MIT License. See the LICENSE file for details.

Acknowledgments

  • Kaggle for providing the Hungry Geese environment.
  • yuricat and kyazuki for the GeeseNet and related implementations.