diff --git a/tests/test_cli.py b/tests/test_cli.py index 58ad488..e246da3 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -1056,3 +1056,16 @@ def test_login_command_rename(self): self.assertNotEqual(new_tok, 'tok') self.assertEqual(new_tok, 'VeryLongBase664String==') verify(cli.utils, times=1).Poll(...) + + def test_region_flag(self): + """ + python -m unittest tests.test_cli.CLIUnitTests.test_region_flag + """ + with ArgvContext(program, '-p', 'dev', '-t', '-r'): + cli.main() + cred = cli.utils.read_config(self.credentials.name) + new_tok = cred['dev']['aws_session_token'] + self.assertNotEqual(new_tok, 'tok') + self.assertEqual(new_tok, 'VeryLongBase664String==') + self.assertEqual(cred['dev']['region'], 'ap-southeast-2') + verify(cli.utils, times=2).invoke(...) diff --git a/yawsso/cli.py b/yawsso/cli.py index 4b2195e..db41fad 100644 --- a/yawsso/cli.py +++ b/yawsso/cli.py @@ -36,6 +36,9 @@ def boot(): if args.bin: core.aws_bin = args.bin + if args.region: + core.region_flag = args.region + logger.log(TRACE, f"args: {args}") logger.log(TRACE, f"AWS_CONFIG_FILE: {core.aws_config_file}") logger.log(TRACE, f"AWS_SHARED_CREDENTIALS_FILE: {core.aws_shared_credentials_file}") @@ -60,6 +63,7 @@ def parser(): ap.add_argument("-t", "--trace", help="Trace output", action="store_true") ap.add_argument("-e", "--export-vars", dest="export_vars1", help="Print out AWS ENV vars", action="store_true") ap.add_argument("-v", "--version", help="Print version and exit", action="store_true") + ap.add_argument("-r", "--region", help="Add region to credentials file", action="store_true") sp = ap.add_subparsers(title="available commands", metavar="", dest="command") diff --git a/yawsso/core.py b/yawsso/core.py index e9ca448..d80c39f 100644 --- a/yawsso/core.py +++ b/yawsso/core.py @@ -11,7 +11,8 @@ aws_sso_cache_path = u.xu(os.getenv("AWS_SSO_CACHE_PATH", Constant.AWS_SSO_CACHE_PATH.value)) aws_config_file = u.xu(os.getenv("AWS_CONFIG_FILE", Constant.AWS_CONFIG_FILE.value)) aws_shared_credentials_file = u.xu(os.getenv("AWS_SHARED_CREDENTIALS_FILE", Constant.AWS_SHARED_CREDENTIALS_FILE.value)) -aws_default_region = os.getenv("AWS_DEFAULT_REGION", Constant.AWS_DEFAULT_REGION.value) +aws_default_region = os.getenv("AWS_DEFAULT_REGION") +region_flag = False # See https://github.com/victorskl/yawsso/issues/76 def get_aws_cli_v2_sso_cached_login(profile): @@ -34,19 +35,29 @@ def update_aws_cli_v1_credentials(profile_name, profile, credentials): logger.warning(f"No appropriate credentials found. Skip syncing profile `{profile_name}`") return - # region = profile.get("region", aws_default_region) config = u.read_config(aws_shared_credentials_file) + if config.has_section(profile_name): config.remove_section(profile_name) + config.add_section(profile_name) - # config.set(profile_name, "region", region) # See https://github.com/victorskl/yawsso/issues/76 config.set(profile_name, "aws_access_key_id", credentials["accessKeyId"]) config.set(profile_name, "aws_secret_access_key", credentials["secretAccessKey"]) config.set(profile_name, "aws_session_token", credentials["sessionToken"]) config.set(profile_name, "aws_security_token", credentials["sessionToken"]) + + # set expiration ts_expires_millisecond = credentials["expiration"] dt_utc = str(datetime.utcfromtimestamp(ts_expires_millisecond / 1000.0).isoformat() + '+0000') config.set(profile_name, "aws_session_expiration", dt_utc) + + # set region + region = profile.get("region", aws_default_region) + if region_flag and region: + # See https://github.com/victorskl/yawsso/issues/88 + config.set(profile_name, "region", region) + + # write the config out u.write_config(aws_shared_credentials_file, config) logger.debug(f"Done syncing AWS CLI v1 credentials using AWS CLI v2 SSO login session for profile `{profile_name}`")