diff --git a/rlpytorch/model_loader.py b/rlpytorch/model_loader.py index 560c1537..618a744c 100644 --- a/rlpytorch/model_loader.py +++ b/rlpytorch/model_loader.py @@ -96,6 +96,9 @@ def load_model(self, params): return model +def str2bool(v): + return v.lower() in ("yes", "true", "t", "1") + def load_env(envs, num_models=None, overrides=dict(), defaults=dict(), **kwargs): ''' Load envs. envs will be specified as environment variables, more specifically, ``game``, ``model_file`` and ``model`` are required. @@ -132,5 +135,6 @@ def load_env(envs, num_models=None, overrides=dict(), defaults=dict(), **kwargs) env.update(kwargs) parser = argparse.ArgumentParser() + parser.register('type', 'bool', str2bool) all_args = ArgsProvider.Load(parser, env, global_defaults=defaults, global_overrides=overrides) return env, all_args