diff --git a/.gitignore b/.gitignore index fd133da..5d4635a 100644 --- a/.gitignore +++ b/.gitignore @@ -3,3 +3,4 @@ build dist *.egg-info +.idea \ No newline at end of file diff --git a/README.md b/README.md index c87466b..71074a0 100644 --- a/README.md +++ b/README.md @@ -1,114 +1,198 @@ # ssm-diff -AWS [SSM Parameter Store](https://aws.amazon.com/ec2/systems-manager/parameter-store) is a really convenient, AWS-native, KMS-enabled storage for parameters and secrets. +AWS [SSM Parameter Store](https://aws.amazon.com/ec2/systems-manager/parameter-store) provides convenient, AWS-native, +KMS-enabled storage for parameters and secrets. The API makes it easy to request a branch (i.e. subtree) of parameters +when you need to configure a machine, but AWS provides no human-friendly UI for bulk-managing a subtree. -Unfortunately, as of now, it doesn't seem to provide any human-friendly ways of batch-managing [hierarchies of parameters](http://docs.aws.amazon.com/systems-manager/latest/userguide/sysman-paramstore-working.html#sysman-paramstore-su-organize). +`ssm-diff` enables bulk-editing of the SSM Parameter Store keys by converting the path-style values in the Parameter +Store to and from YAML files, where they can be edited. For example, the values at `/Dev/DBServer/MySQL/app1` and +`/Dev/DBServer/MySQL/app2` will become: -The goal of the `ssm-diff` tool is to simplify that process by unwraping path-style -(/Dev/DBServer/MySQL/db-string13 = value) parameters into a YAML structure: ``` Dev: DBServer: MySQL: - db-string13: value + app1: <value> + app2: <value> ``` -Then, given that this local YAML representation of the SSM Parameter Store state was edited, `calculating and applying diffs` on the parameters. +While `ssm-diff` downloads the entire Parameter Store by default, CLI flags (constructor kwargs for programmatic users) +make it possible to extract and work with specific branches, exclude encrypted (i.e. 
secret) keys, and/or download +the encrypted version of secrets (e.g. for backup purposes). -`ssm-diff` supports complex data types as values and can operate within single or multiple prefixes. +See [`ssm_utils`](https://github.com/dkolb/ssm_utils) for a similar project implemented as a gem. + +## WARNING: MAKE A BACKUP AND ALWAYS `plan` +While this package allows you to apply operations to specific Parameter Store paths, this ability is innately dangerous. +You would not, for example, want to download a copy of a single path and then `push` that single path to the root, +erasing everything outside of that path. Parameter Store versions provide some protection from mistaken +changes, but (to the best of our knowledge) **DELETES ARE IRREVERSIBLE**. + +`ssm-diff` makes an effort to protect you against these kinds of mistakes: + +- The `SSM_NO_DECRYPT` option can be used to create a local backup of your entire Parameter Store without storing +decrypted secrets locally. +- `paths` configurations are stored in environment variables -- and configured during `__init__` for programmatic users -- +to help ensure stability between calls. +- YAML files include metadata that will attempt to prevent you from making calls in an incompatible path. This data is +stored in YAML keys like `ssm-diff:config` and **SHOULD NOT BE CHANGED OR REMOVED**. + +Despite our efforts to protect you, **USE THIS PACKAGE AT YOUR OWN RISK** and **TAKE REASONABLE SAFETY PRECAUTIONS LIKE +KEEPING A BACKUP COPY OF YOUR PARAMETER STORE**. ## Installation ``` pip install ssm-diff ``` -## Geting Started -The tool relies on native AWS SDK, thus, on a way SDK [figures out](http://boto3.readthedocs.io/en/latest/guide/configuration.html) an effective AWS configuration. 
You might want to configure it explicitly, setting `AWS_DEFAULT_REGION`, or `AWS_PROFILE`, before doing and manipulations on parameters
+## Getting Started
+This tool uses the native AWS SDK client `boto3` which provides a variety of [configuration options](https://boto3.amazonaws.com/v1/documentation/api/latest/guide/configuration.html#configuration),
+including environment variables and configuration files.
-When `AWS_PROFILE` environment variable is set, local state file will have a name corresponding to the profile name.
+## Authentication
+Common authentication options include:
-Before we start editing the local representation of parameters state, we have to get it from SMM:
-```
-$ ssm-diff init
-```
+- `AWS_ACCESS_KEY_ID` and `AWS_SECRET_ACCESS_KEY` (created in the IAM - Users - \<User\> - Security Credentials
+section)
+- `AWS_SESSION_TOKEN` for temporary access keys ([CLI only](https://docs.aws.amazon.com/IAM/latest/UserGuide/id_credentials_temp_use-resources.html))
+- call `aws configure` to create a local configuration file
+- if using a shared configuration file, `AWS_PROFILE` determines which profile to use
-will create a local `parameters.yml` (or `<AWS_PROFILE>.yml` if `AWS_PROFILE` is in use) file that stores a YAML representation of the SSM Parameter Store state.
+If using an ENV-based authentication process, it may be necessary to set `AWS_DEFAULT_REGION` (e.g. `us-west-1`, `us-west-2`).
-Once you accomplish editing this file, adding, modifying or deleting parameters, run:
+## Working with Parameters
+To initialize the local YAML file, download it from Parameter Store using `clone`:
 ```
-$ ssm-diff plan
+$ ssm-diff clone
 ```
-Which will show you the diff between this local representation and an SSM Parameter Store.
+The name of this file will depend on your settings (in priority order):
-Finally
-```
-$ ssm-diff apply
-```
-will actually apply local changes to the Parameter Store.
-
-Operations can also be limited to a particular prefix(es):
+- if `-f` is set, the name provided
+- if `SSM_ROOT_PATH` is used (see below), a filename derived from this path
+- if `AWS_PROFILE` is used, `<AWS_PROFILE>.yml`
+- `parameters.yml` if no other condition is met
+To update an existing file with changes from the Parameter Store, use `pull`:
 ```
-$ ssm-diff -p /dev -p /qa/ci {init,plan,apply}
+$ ssm-diff pull
 ```
+By default, this command will preserve local changes. To overwrite local changes (keeping only added keys), use
+`--force`.
-NOTE: when remote state diverges for some reason, but you still want to preserve remote changes, there's a:
-
-```
-$ ssm-diff pull
+After editing the file (e.g. removing `/test/deep/key` and changing `/test/diff/key1`), you can preview the changes by
+running `plan`:
 ```
-command, doing just that.
+$ ssm-diff plan
+-/test/deep/key
+~/test/diff/key1:
+  < old value
+  > new value
+```
+
+When you're ready to actually update the SSM Parameter Store, run `push`:
+```
+$ ssm-diff push
+```
+
+NOTE: The default `DiffResolver` does not cache the remote state so it cannot distinguish between a local add and a
+remote delete. Please use caution if keys are being removed (externally) from the Parameter Store as the `pull`
+command will not remove them from the local storage (even with `--force`) and the `push` command will restore them to
+the Parameter Store.
+
+NOTE: The default `DiffResolver` does not cache the remote state so it cannot recognize concurrent changes (i.e. where
+both the local and remote values have changed). Calling `push` will overwrite any remote changes (including any changes
+made since the last `plan`).
+
+## Options
+As discussed above (in the WARNING section), to help ensure the same configurations are preserved across commands, most
+configurations are managed using environment variables.
The following are available on all commands:
+
+- `SSM_PATHS` limits operations to specific branches identified by these paths (separated by `;` or `:`). For example,
+`clone` will only copy these branches, `pull` will only apply changes to local keys within these branches, and `push`
+will only apply changes to remote keys within these branches.
+- `SSM_ROOT_PATH` determines the path that is used as the root of the YAML file. For example, if `SSM_ROOT_PATH` is set
+to `/app1/dev/server1`, the keys `/app1/dev/server1/username` and `/app1/dev/server1/password` show up in the YAML as:
+  ```
+  username: <value>
+  password: <value>
+  ```
+  As noted above, this will generate a file named `app1~dev~server1.yml` unless `-f` is used. The root path must be
+  an ancestor of all paths in `SSM_PATHS` or an exception will be raised.
+- `SSM_NO_SECURE` excludes encrypted keys from the backup and sync process (when set to `1` or case-insensitive `true`).
+This helps ensure that secrets are not accessed unnecessarily and are not decrypted on local systems.
+- `SSM_NO_DECRYPT` skips decryption of `SecureString` values when they are downloaded. **NOTE: This option should only be
+used to create a local backup without exposing secrets.** The AWS CLI does not provide a way to directly upload
+already-encrypted values. If these values need to be restored, you will need to decrypt them using the KMS API and
+upload the decrypted values. Despite the complexity of a restore, this option ensures that you have a way to back up
+(and recover) your entire parameter store without downloading and storing unencrypted secrets.
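The filename and ancestor-path rules described above can be sketched in Python. This is an illustrative sketch only (`derive_filename` and `validate_root` are hypothetical helper names, not part of this diff), approximating the documented behavior:

```python
import os


def derive_filename(root_path, explicit=None, profile=None):
    """Approximate the filename priority described above:
    -f beats a root-path-derived name, which beats the profile name."""
    if explicit:
        return explicit
    if root_path != '/':
        # e.g. /app1/dev/server1 -> app1~dev~server1.yml
        return root_path.strip('/').replace('/', '~') + '.yml'
    if profile:
        return profile + '.yml'
    return os.environ.get('AWS_PROFILE', 'parameters') + '.yml'


def validate_root(root_path, paths):
    """SSM_ROOT_PATH must be an ancestor of every path in SSM_PATHS."""
    for path in paths:
        if not path.startswith(root_path):
            raise ValueError('Root path {} does not contain path {}'.format(root_path, path))
```

The real logic lives in the script's argument handling and in `YAMLFile.validate_paths` later in this diff, and may differ in detail.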
## Examples
Let's assume we have the following parameters set in SSM Parameter Store:
 ```
 /qa/ci/api/db_schema = foo_ci
 /qa/ci/api/db_user = bar_ci
-/qa/ci/api/db_password = baz_ci
+/qa/ci/api/db_password = baz_ci (SecureString)
 /qa/uat/api/db_schema = foo_uat
-/qa/uat/api/db_user = bar_uat
-/qa/uat/api/db_password = baz_uat
-
+/qa/uat/api/db_user = bar_uat
+/qa/uat/api/db_password = baz_uat (SecureString)
 ```
+`clone` will create a `parameters.yml` file with the following contents:
 ```
-$ ssm-diff init
-```
-will create a `parameters.yml` file with the following content:
-
-```
+ssm-diff:config:
+  ssm-diff:root: /
+  ssm-diff:paths:
+  - /
+  ssm-diff:no-secure: false
+  ssm-diff:no-decrypt: false
 qa:
   ci:
     api:
       db_schema: foo_ci
       db_user: bar_ci
-      db_password: !secure 'baz_ci'
+      db_password: !Secret
+        metadata:
+          aws:kms:alias: alias/aws/ssm
+          encrypted: false
+        secret: 'baz_ci'
   uat:
     api:
       db_schema: foo_uat
       db_user: bar_uat
-      db_password: !secure 'baz_uat'
+      db_password: !Secret
+        metadata:
+          aws:kms:alias: alias/aws/ssm
+          encrypted: false
+        secret: 'baz_uat'
 ```
-KMS-encrypted (SecureString) and String type values are distunguished by `!secure` YAML tag.
+As you can see in this file:
+
+- The environment settings during `clone` are stored in the `ssm-diff:config` metadata section. While
+these are the default values, we strongly recommend that you do not edit (or remove) this section.
+- KMS-encrypted (`SecureString`) values are decrypted and identified by the `!Secret` YAML tag. The `!Secret` tag supports
+custom KMS aliases using the `aws:kms:alias` metadata key. When adding secrets that use the default KMS key, you may
+use the simpler `!SecureString <decrypted value>` or the legacy `!secure <decrypted value>`.
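The nesting shown in this example is a mechanical transformation of the flat parameter paths. A minimal sketch of that transformation (the real code uses `helpers.add`/`DiffBase._unflatten` elsewhere in this diff; `unflatten` here is an illustrative stand-in):

```python
def unflatten(flat, sep='/'):
    """Nest flat keys like {'/qa/ci/api/db_schema': 'foo_ci', ...}
    into the YAML-style dict structure shown above."""
    nested = {}
    for path, value in flat.items():
        parts = path.strip(sep).split(sep)
        node = nested
        for part in parts[:-1]:
            # descend, creating intermediate dicts as needed
            node = node.setdefault(part, {})
        node[parts[-1]] = value
    return nested
```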
-Let's drop the `ci`-related stuff completely, and edit `uat` parameters a bit, ending up with the following `parameters.yml` file contents:
+Now we delete the entire `ci` tree and edit the `uat` parameters (including changing the syntax for the secret):
 ```
+ssm-diff:config:
+  ssm-diff:root: /
+  ssm-diff:paths:
+  - /
+  ssm-diff:no-secure: false
+  ssm-diff:no-decrypt: false
 qa:
   uat:
     api:
       db_schema: foo_uat
       db_charset: utf8mb4
       db_user: bar_changed
-      db_password: !secure 'baz_changed'
+      db_password: !SecureString 'baz_changed'
 ```
-Running
-```
-$ ssm-diff plan
-```
-will give the following output:
+Running `plan` will give the following output:
 ```
 - /qa/ci/api/db_schema
@@ -126,11 +210,4 @@ will give the following output:
 ```
-Finally
-```
-$ ssm-diff apply
-```
-will actually do all the necessary modifications of parameters in SSM Parameter Store itself, applying local changes
-
-## Known issues and limitations
-- There's currently no option to use different KMS keys for `SecureString` values encryption.
+Finally, `push` will run the AWS API calls needed to update the SSM Parameter Store itself to mirror the local changes.
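The `plan` output shown in this example is a set comparison between the flattened remote and local dicts, with the remote copy as the baseline. A simplified sketch mirroring the `DiffResolver.plan` logic that this diff adds in `states/engine.py`:

```python
def plan(remote, local):
    """Compute add/delete/change sets between two flattened dicts,
    treating the remote copy as the baseline."""
    remote_keys, local_keys = set(remote), set(local)
    common = remote_keys & local_keys
    return {
        # present locally but not remotely -> created on push
        'add': {k: local[k] for k in local_keys - remote_keys},
        # present remotely but not locally -> deleted on push
        'delete': {k: remote[k] for k in remote_keys - local_keys},
        # present in both with different values -> overwritten on push
        'change': {k: {'old': remote[k], 'new': local[k]}
                   for k in common if remote[k] != local[k]},
    }
```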
\ No newline at end of file diff --git a/setup.py b/setup.py index be38eac..19b8229 100755 --- a/setup.py +++ b/setup.py @@ -3,24 +3,23 @@ from os import path from setuptools import setup - wd = path.abspath(path.dirname(__file__)) with open(path.join(wd, 'README.md'), encoding='utf-8') as f: long_description = f.read() setup( - description = 'A tool to manage contents of AWS SSM Parameter Store', - name = 'ssm-diff', - version = '0.5', - author = 'Sergey Motovilovets', - author_email = 'motovilovets.sergey@gmail.com', + description='A tool to manage contents of AWS SSM Parameter Store', + name='ssm-diff', + version='0.5', + author='Sergey Motovilovets', + author_email='motovilovets.sergey@gmail.com', license='MIT', - url = 'https://github.com/runtheops/ssm-diff', - download_url = 'https://github.com/runtheops/ssm-diff/archive/0.5.tar.gz', + url='https://github.com/runtheops/ssm-diff', + download_url='https://github.com/runtheops/ssm-diff/archive/0.5.tar.gz', long_description=long_description, long_description_content_type='text/markdown', - keywords = ['aws', 'ssm', 'parameter-store'], - packages = ['states'], + keywords=['aws', 'ssm', 'parameter-store'], + packages=['states'], scripts=['ssm-diff'], install_requires=[ 'termcolor', diff --git a/ssm-diff b/ssm-diff index 73ebca1..c455927 100755 --- a/ssm-diff +++ b/ssm-diff @@ -1,82 +1,117 @@ #!/usr/bin/env python from __future__ import print_function -from states import * -import states.helpers as helpers + import argparse +import logging import os +import sys +from states import * -def init(args): - r, l = RemoteState(args.profile), LocalState(args.filename) - l.save(r.get(flat=False, paths=args.path)) +root = logging.getLogger() +root.setLevel(logging.INFO) +handler = logging.StreamHandler(sys.stdout) +handler.setLevel(logging.INFO) +formatter = logging.Formatter('%(name)s - %(message)s') +handler.setFormatter(formatter) +root.addHandler(handler) -def pull(args): - dictfilter = lambda x, y: dict([ (i,x[i]) for i 
in x if i in set(y) ]) - r, l = RemoteState(args.profile), LocalState(args.filename) - diff = helpers.FlatDictDiffer(r.get(paths=args.path), l.get(paths=args.path)) - if args.force: - ref_set = diff.changed().union(diff.removed()).union(diff.unchanged()) - target_set = diff.added() - else: - ref_set = diff.unchanged().union(diff.removed()) - target_set = diff.added().union(diff.changed()) - state = dictfilter(diff.ref, ref_set) - state.update(dictfilter(diff.target, target_set)) - l.save(helpers.unflatten(state)) +def configure_endpoints(args): + # configure() returns a DiffBase class (whose constructor may be wrapped in `partial` to pre-configure it) + diff_class = DiffBase.get_plugin(args.engine).configure(args) + return storage.ParameterStore(args.profile, diff_class, paths=args.paths, no_secure=args.no_secure, + no_decrypt=args.no_decrypt), \ + storage.YAMLFile(args.filename, root_path=args.yaml_root, paths=args.paths, no_secure=args.no_secure, + no_decrypt=args.no_decrypt) -def apply(args): - r, _, diff = plan(args) +def init(args): + """Create a local YAML file from the SSM Parameter Store (per configs in args)""" + remote, local = configure_endpoints(args) + if local.exists(): + raise ValueError('File already exists, use `pull` instead') + local.save(remote.clone()) + + +def pull(args): + """Update local YAML file with changes in the SSM Parameter Store (per configs in args)""" + remote, local = configure_endpoints(args) + if not local.exists(): + raise ValueError('File does not exist, use `init` instead') + local.save(remote.pull(local.get())) + + +def push(args): + """Apply local changes to the SSM Parameter Store""" + remote, local = configure_endpoints(args) + if not local.exists(): + raise ValueError('File does not exist. 
Adjust the target file using `-f` or get started using `clone`.')
     print("\nApplying changes...")
     try:
-        r.apply(diff)
+        remote.push(local.get())
     except Exception as e:
         print("Failed to apply changes to remote:", e)
     print("Done.")
 
 
 def plan(args):
-    r, l = RemoteState(args.profile), LocalState(args.filename)
-    diff = helpers.FlatDictDiffer(r.get(paths=args.path), l.get(paths=args.path))
-
-    if diff.differ:
-        diff.print_state()
-    else:
-        print("Remote state is up to date.")
-
-    return r, l, diff
+    """Print a representation of the changes that would be applied to the SSM Parameter Store (per config in args)"""
+    remote, local = configure_endpoints(args)
+    if not local.exists():
+        raise ValueError('File does not exist. Adjust the target file using `-f` or get started using `clone`.')
+    print(DiffBase.describe_diff(remote.dry_run(local.get())))
 
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
-    parser.add_argument('-f', help='local state yml file', action='store', dest='filename', default='parameters.yml')
-    parser.add_argument('--path', '-p', action='append', help='filter SSM path')
+    parser.add_argument('-f', help='local state yml file', action='store', dest='filename')
+    parser.add_argument('--engine', '-e', help='diff engine to use when interacting with SSM', action='store', dest='engine', default='DiffResolver')
     parser.add_argument('--profile', help='AWS profile name', action='store', dest='profile')
 
     subparsers = parser.add_subparsers(dest='func', help='commands')
     subparsers.required = True
 
-    parser_plan = subparsers.add_parser('plan', help='display changes between local and remote states')
-    parser_plan.set_defaults(func=plan)
-
-    parser_init = subparsers.add_parser('init', help='create or overwrite local state snapshot')
+    parser_init = subparsers.add_parser('clone', help='create a local copy of the remote storage')
     parser_init.set_defaults(func=init)
 
-    parser_pull = subparsers.add_parser('pull', help='pull updates from remote state')
+
parser_pull = subparsers.add_parser('pull', help='pull changes from remote state')
     parser_pull.set_defaults(func=pull)
     parser_pull.add_argument('--force', help='overwrite local changes', action='store_true', dest='force')
 
-    parser_apply = subparsers.add_parser('apply', help='apply diff to the remote state')
-    parser_apply.set_defaults(func=apply)
+    parser_plan = subparsers.add_parser('plan', help='display changes between local and remote states')
+    parser_plan.set_defaults(func=plan)
+
+    parser_apply = subparsers.add_parser('push', help='push changes to the remote storage')
+    parser_apply.set_defaults(func=push)
 
     args = parser.parse_args()
-    args.path = args.path if args.path else ['/']
-
-    if args.filename == 'parameters.yml':
-        if not args.profile:
-            if 'AWS_PROFILE' in os.environ:
-                args.filename = os.environ['AWS_PROFILE'] + '.yml'
-        else:
-            args.filename = args.profile + '.yml'
+
+    args.no_secure = os.environ.get('SSM_NO_SECURE', 'false').lower() in ['true', '1']
+    args.no_decrypt = os.environ.get('SSM_NO_DECRYPT', 'false').lower() in ['true', '1']
+    args.yaml_root = os.environ.get('SSM_ROOT_PATH', '/')
+    args.paths = os.environ.get('SSM_PATHS', None)
+    if args.paths is not None:
+        # SSM_PATHS entries may be separated by ';' or ':'
+        args.paths = [p for chunk in args.paths.split(';') for p in chunk.split(':') if p]
+    else:
+        # this defaults to ['/']
+        args.paths = [args.yaml_root]
+
+    # root filename
+    if args.filename is not None:
+        filename = args.filename
+    elif args.yaml_root != '/':
+        filename = args.yaml_root.strip('/').replace('/', '~')
+    elif args.profile:
+        filename = args.profile
+    elif 'AWS_PROFILE' in os.environ:
+        filename = os.environ['AWS_PROFILE']
+    else:
+        filename = 'parameters'
+
+    # remove extension (will be restored by storage classes)
+    if filename[-4:] == '.yml':
+        filename = filename[:-4]
+    args.filename = filename
+
     args.func(args)
diff --git a/states/__init__.py b/states/__init__.py
index 38eb373..3129d94 100644
--- a/states/__init__.py
+++ b/states/__init__.py
@@ -1 +1,2 @@
-from .states import *
+from .storage import YAMLFile, ParameterStore
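The `SSM_PATHS` handling above ("separated by `;` or `:`") can be expressed as a small standalone function. This is a sketch of the documented behavior with a hypothetical helper name, not code from this diff:

```python
def parse_paths(value, default_root='/'):
    """Split an SSM_PATHS-style value on ';' or ':', falling back to the root path."""
    if not value:
        return [default_root]
    # split on ';' first, then split each chunk on ':', dropping empty entries
    return [p for chunk in value.split(';') for p in chunk.split(':') if p]
```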
+from .engine import DiffBase
diff --git a/states/engine.py b/states/engine.py
new file mode 100644
index 0000000..4550df2
--- /dev/null
+++ b/states/engine.py
@@ -0,0 +1,158 @@
+import collections.abc
+import logging
+import re
+from functools import partial
+
+from termcolor import colored
+
+from .helpers import add
+
+
+class DiffMount(type):
+    """Metaclass for Diff plugin system"""
+    # noinspection PyUnusedLocal,PyMissingConstructor
+    def __init__(cls, *args, **kwargs):
+        if not hasattr(cls, 'plugins'):
+            cls.plugins = dict()
+        else:
+            cls.plugins[cls.__name__] = cls
+
+
+class DiffBase(metaclass=DiffMount):
+    """Superclass for diff plugins"""
+    def __init__(self, remote, local):
+        self.logger = logging.getLogger(self.__module__)
+        self.remote_flat, self.local_flat = self._flatten(remote), self._flatten(local)
+        self.remote_set, self.local_set = set(self.remote_flat.keys()), set(self.local_flat.keys())
+
+    # noinspection PyUnusedLocal
+    @classmethod
+    def get_plugin(cls, name):
+        if name in cls.plugins:
+            return cls.plugins[name]
+
+    @classmethod
+    def configure(cls, args):
+        """Extract class-specific configurations from CLI args and pre-configure the __init__ method using functools.partial"""
+        return cls
+
+    @classmethod
+    def _flatten(cls, d, current_path='', sep='/'):
+        """Convert a nested dict structure into a "flattened" dict i.e. {"full/path": "value", ...}"""
+        items = {}
+        for k, v in d.items():
+            new = current_path + sep + k if current_path else k
+            if isinstance(v, collections.abc.MutableMapping):
+                items.update(cls._flatten(v, new, sep=sep).items())
+            else:
+                items[sep + new] = v
+        return items
+
+    @classmethod
+    def _unflatten(cls, d, sep='/'):
+        """Converts a "flattened" dict i.e.
{"full/path": "value", ...} into a nested dict structure""" + output = {} + for k, v in d.items(): + add( + obj=output, + path=k, + value=v, + sep=sep, + ) + return output + + @classmethod + def describe_diff(cls, plan): + """Return a (multi-line) string describing all differences""" + description = "" + for k, v in plan['add'].items(): + # { key: new_value } + description += colored("+", 'green') + "{} = {}".format(k, v) + '\n' + + for k in plan['delete']: + # { key: old_value } + description += colored("-", 'red') + k + '\n' + + for k, v in plan['change'].items(): + # { key: {'old': value, 'new': value} } + description += colored("~", 'yellow') + "{}:\n\t< {}\n\t> {}".format(k, v['old'], v['new']) + '\n' + + if description == "": + description = "No Changes Detected" + + return description + + @property + def plan(self): + """Returns a `dict` of operations for updating the remote storage i.e. {'add': {...}, 'change': {...}, 'delete': {...}}""" + raise NotImplementedError + + def merge(self): + """Generate a merge of the local and remote dicts, following configurations set during __init__""" + raise NotImplementedError + + +class DiffResolver(DiffBase): + """Determines diffs between two dicts, where the remote copy is considered the baseline""" + def __init__(self, remote, local, force=False): + super().__init__(remote, local) + self.intersection = self.remote_set.intersection(self.local_set) + self.force = force + + if self.added() or self.removed() or self.changed(): + self.differ = True + else: + self.differ = False + + @classmethod + def configure(cls, args): + kwargs = {} + if hasattr(args, 'force'): + kwargs['force'] = args.force + return partial(cls, **kwargs) + + def added(self): + """Returns a (flattened) dict of added leaves i.e. {"full/path": value, ...}""" + return self.local_set - self.intersection + + def removed(self): + """Returns a (flattened) dict of removed leaves i.e. 
{"full/path": value, ...}""" + return self.remote_set - self.intersection + + def changed(self): + """Returns a (flattened) dict of changed leaves i.e. {"full/path": value, ...}""" + return set(k for k in self.intersection if self.remote_flat[k] != self.local_flat[k]) + + def unchanged(self): + """Returns a (flattened) dict of unchanged leaves i.e. {"full/path": value, ...}""" + return set(k for k in self.intersection if self.remote_flat[k] == self.local_flat[k]) + + @property + def plan(self): + return { + 'add': { + k: self.local_flat[k] for k in self.added() + }, + 'delete': { + k: self.remote_flat[k] for k in self.removed() + }, + 'change': { + k: {'old': self.remote_flat[k], 'new': self.local_flat[k]} for k in self.changed() + } + } + + def merge(self): + dictfilter = lambda original, keep_keys: dict([(i, original[i]) for i in original if i in set(keep_keys)]) + if self.force: + # Overwrite local changes (i.e. only preserve added keys) + # NOTE: Currently the system cannot tell the difference between a remote delete and a local add + prior_set = self.changed().union(self.removed()).union(self.unchanged()) + current_set = self.added() + else: + # Preserve added keys and changed keys + # NOTE: Currently the system cannot tell the difference between a remote delete and a local add + prior_set = self.unchanged().union(self.removed()) + current_set = self.added().union(self.changed()) + state = dictfilter(original=self.remote_flat, keep_keys=prior_set) + state.update(dictfilter(original=self.local_flat, keep_keys=current_set)) + return self._unflatten(state) diff --git a/states/helpers.py b/states/helpers.py index 08d313a..a767982 100644 --- a/states/helpers.py +++ b/states/helpers.py @@ -1,84 +1,35 @@ -from termcolor import colored from copy import deepcopy -import collections -class FlatDictDiffer(object): - def __init__(self, ref, target): - self.ref, self.target = ref, target - self.ref_set, self.target_set = set(ref.keys()), set(target.keys()) - self.isect = 
self.ref_set.intersection(self.target_set) - - if self.added() or self.removed() or self.changed(): - self.differ = True - else: - self.differ = False - - def added(self): - return self.target_set - self.isect - - def removed(self): - return self.ref_set - self.isect - - def changed(self): - return set(k for k in self.isect if self.ref[k] != self.target[k]) - - def unchanged(self): - return set(k for k in self.isect if self.ref[k] == self.target[k]) - - def print_state(self): - for k in self.added(): - print(colored("+", 'green'), "{} = {}".format(k, self.target[k])) - - for k in self.removed(): - print(colored("-", 'red'), k) - - for k in self.changed(): - print(colored("~", 'yellow'), "{}:\n\t< {}\n\t> {}".format(k, self.ref[k], self.target[k])) - - -def flatten(d, pkey='', sep='/'): - items = [] - for k in d: - new = pkey + sep + k if pkey else k - if isinstance(d[k], collections.MutableMapping): - items.extend(flatten(d[k], new, sep=sep).items()) - else: - items.append((sep + new, d[k])) - return dict(items) - - -def add(obj, path, value): - parts = path.strip("/").split("/") +def add(obj, path, value, sep='/'): + """Add value to the `obj` dict at the specified path""" + parts = path.strip(sep).split(sep) last = len(parts) - 1 + current = obj for index, part in enumerate(parts): if index == last: - obj[part] = value + current[part] = value else: - obj = obj.setdefault(part, {}) + current = current.setdefault(part, {}) + # convenience return, object is mutated + return obj def search(state, path): - result = state + """Get value in `state` at the specified path, returning {} if the key is absent""" + if path.strip("/") == '': + return state for p in path.strip("/").split("/"): - if result.get(p): - result = result[p] - else: - result = {} - break - output = {} - add(output, path, result) - return output + if p not in state: + return {} + state = state[p] + return state -def unflatten(d): - output = {} - for k in d: - add( - obj=output, - path=k, - value=d[k]) - 
return output +def filter(state, path): + if path.strip("/") == '': + return state + return add({}, path, search(state, path)) def merge(a, b): diff --git a/states/states.py b/states/states.py deleted file mode 100644 index bb96897..0000000 --- a/states/states.py +++ /dev/null @@ -1,142 +0,0 @@ -from __future__ import print_function -from botocore.exceptions import ClientError, NoCredentialsError -from .helpers import flatten, merge, add, search -import sys -import os -import yaml -import boto3 -import termcolor - -def str_presenter(dumper, data): - if len(data.splitlines()) == 1 and data[-1] == '\n': - return dumper.represent_scalar( - 'tag:yaml.org,2002:str', data, style='>') - if len(data.splitlines()) > 1: - return dumper.represent_scalar( - 'tag:yaml.org,2002:str', data, style='|') - return dumper.represent_scalar( - 'tag:yaml.org,2002:str', data.strip()) - -yaml.SafeDumper.add_representer(str, str_presenter) - -class SecureTag(yaml.YAMLObject): - yaml_tag = u'!secure' - - def __init__(self, secure): - self.secure = secure - - def __repr__(self): - return self.secure - - def __str__(self): - return termcolor.colored(self.secure, 'magenta') - - def __eq__(self, other): - return self.secure == other.secure if isinstance(other, SecureTag) else False - - def __hash__(self): - return hash(self.secure) - - def __ne__(self, other): - return (not self.__eq__(other)) - - @classmethod - def from_yaml(cls, loader, node): - return SecureTag(node.value) - - @classmethod - def to_yaml(cls, dumper, data): - if len(data.secure.splitlines()) > 1: - return dumper.represent_scalar(cls.yaml_tag, data.secure, style='|') - return dumper.represent_scalar(cls.yaml_tag, data.secure) - -yaml.SafeLoader.add_constructor('!secure', SecureTag.from_yaml) -yaml.SafeDumper.add_multi_representer(SecureTag, SecureTag.to_yaml) - - -class LocalState(object): - def __init__(self, filename): - self.filename = filename - - def get(self, paths, flat=True): - try: - output = {} - with 
open(self.filename,'rb') as f: - l = yaml.safe_load(f.read()) - for path in paths: - if path.strip('/'): - output = merge(output, search(l, path)) - else: - return flatten(l) if flat else l - return flatten(output) if flat else output - except IOError as e: - print(e, file=sys.stderr) - if e.errno == 2: - print("Please, run init before doing plan!") - sys.exit(1) - except TypeError as e: - if 'object is not iterable' in e.args[0]: - return dict() - raise - - def save(self, state): - try: - with open(self.filename, 'wb') as f: - content = yaml.safe_dump(state, default_flow_style=False) - f.write(bytes(content.encode('utf-8'))) - except Exception as e: - print(e, file=sys.stderr) - sys.exit(1) - - -class RemoteState(object): - def __init__(self, profile): - if profile: - boto3.setup_default_session(profile_name=profile) - self.ssm = boto3.client('ssm') - - def get(self, paths=['/'], flat=True): - p = self.ssm.get_paginator('get_parameters_by_path') - output = {} - for path in paths: - try: - for page in p.paginate( - Path=path, - Recursive=True, - WithDecryption=True): - for param in page['Parameters']: - add(obj=output, - path=param['Name'], - value=self._read_param(param['Value'], param['Type'])) - except (ClientError, NoCredentialsError) as e: - print("Failed to fetch parameters from SSM!", e, file=sys.stderr) - - return flatten(output) if flat else output - - def _read_param(self, value, ssm_type='String'): - return SecureTag(value) if ssm_type == 'SecureString' else str(value) - - def apply(self, diff): - - for k in diff.added(): - ssm_type = 'String' - if isinstance(diff.target[k], list): - ssm_type = 'StringList' - if isinstance(diff.target[k], SecureTag): - ssm_type = 'SecureString' - self.ssm.put_parameter( - Name=k, - Value=repr(diff.target[k]) if type(diff.target[k]) == SecureTag else str(diff.target[k]), - Type=ssm_type) - - for k in diff.removed(): - self.ssm.delete_parameter(Name=k) - - for k in diff.changed(): - ssm_type = 'SecureString' if 
isinstance(diff.target[k], SecureTag) else 'String' - - self.ssm.put_parameter( - Name=k, - Value=repr(diff.target[k]) if type(diff.target[k]) == SecureTag else str(diff.target[k]), - Overwrite=True, - Type=ssm_type) diff --git a/states/storage.py b/states/storage.py new file mode 100644 index 0000000..526dedd --- /dev/null +++ b/states/storage.py @@ -0,0 +1,345 @@ +from __future__ import print_function + +import logging +import re +import sys +from copy import deepcopy + +import boto3 +import termcolor +import yaml +from botocore.exceptions import ClientError, NoCredentialsError + +from .helpers import merge, add, filter, search + + +def str_presenter(dumper, data): + if len(data.splitlines()) == 1 and data[-1] == '\n': + return dumper.represent_scalar( + 'tag:yaml.org,2002:str', data, style='>') + if len(data.splitlines()) > 1: + return dumper.represent_scalar( + 'tag:yaml.org,2002:str', data, style='|') + return dumper.represent_scalar( + 'tag:yaml.org,2002:str', data.strip()) + + +yaml.SafeDumper.add_representer(str, str_presenter) + + +class SecureTag(yaml.YAMLObject): + yaml_tag = u'!secure' + + def __init__(self, secure): + self.secure = secure + + def __repr__(self): + return self.secure + + def __str__(self): + return termcolor.colored(self.secure, 'magenta') + + def __eq__(self, other): + return self.secure == other.secure if isinstance(other, SecureTag) else False + + def __hash__(self): + return hash(self.secure) + + def __ne__(self, other): + return not self.__eq__(other) + + @classmethod + def from_yaml(cls, loader, node): + return SecureTag(node.value) + + @classmethod + def to_yaml(cls, dumper, data): + if len(data.secure.splitlines()) > 1: + return dumper.represent_scalar(cls.yaml_tag, data.secure, style='|') + return dumper.represent_scalar(cls.yaml_tag, data.secure) + + +class SecureString(yaml.YAMLObject): + yaml_tag = u'!SecureString' + + +class Secret(yaml.YAMLObject): + yaml_tag = u'!Secret' + METADATA_ENCRYPTED = 'encrypted' + + def 
__init__(self, secret, metadata=None, encrypted=False):
+        super().__init__()
+        self.secret = secret
+        self.metadata = {} if metadata is None else metadata
+        self.metadata[self.METADATA_ENCRYPTED] = encrypted
+
+    def __repr__(self):
+        return "{}(secret={!r}, metadata={!r})".format(self.__class__.__name__, self.secret, self.metadata)
+
+    def __eq__(self, other):
+        if isinstance(other, Secret):
+            return self.secret == other.secret and self.metadata == other.metadata
+        if isinstance(other, SecureTag):
+            return self.secret == other.secure
+        return False
+
+
+yaml.SafeLoader.add_constructor('!secure', SecureTag.from_yaml)
+yaml.SafeLoader.add_constructor('!SecureString', SecureTag.from_yaml)
+# yaml.SafeDumper.add_multi_representer(SecureTag, SecureTag.to_yaml)
+yaml.SafeLoader.add_constructor('!Secret', Secret.from_yaml)
+yaml.SafeDumper.add_multi_representer(Secret, Secret.to_yaml)
+
+
+class YAMLFile(object):
+    """Encodes/decodes a dictionary to/from a YAML file"""
+    METADATA_CONFIG = 'ssm-diff:config'
+    METADATA_PATHS = 'ssm-diff:paths'
+    METADATA_ROOT = 'ssm-diff:root'
+    METADATA_NO_SECURE = 'ssm-diff:no-secure'
+    METADATA_NO_DECRYPT = 'ssm-diff:no-decrypt'
+
+    def __init__(self, filename, paths=('/',), root_path='/', no_secure=False, no_decrypt=False):
+        self.filename = '{}.yml'.format(filename)
+        self.root_path = root_path
+        self.paths = paths
+        self.validate_paths()
+        self.no_secure = no_secure
+        self.no_decrypt = no_decrypt
+
+    def validate_paths(self):
+        length = len(self.root_path)
+        for path in self.paths:
+            if path[:length] != self.root_path:
+                raise ValueError('Root path {} does not contain path {}'.format(self.root_path, path))
+
+    def exists(self):
+        # close the handle rather than leaking it; we only care whether the file opens
+        try:
+            with open(self.filename, 'rb'):
+                pass
+        except FileNotFoundError:
+            return False
+        return True
+
+    def get(self):
+        try:
+            output = {}
+            with open(self.filename, 'rb') as f:
+                local = yaml.safe_load(f.read())
+                self.validate_config(local)
+                local = self.nest_root(local)
+                for path in self.paths:
+                    if 
path.strip('/'):
+                        output = merge(output, filter(local, path))
+                    else:
+                        return local
+            return output
+        except TypeError as e:
+            if 'object is not iterable' in e.args[0]:
+                return dict()
+            raise
+
+    def validate_config(self, local):
+        """YAML files may contain a special `ssm-diff:config` key that stores information about the file when it
+        was generated. This information can be used to ensure the file is compatible with future calls. For example,
+        a file created with a particular subpath (e.g. /my/deep/path) should not be used to overwrite the root path,
+        since this would delete any keys not in the original scope. This method does that validation (with
+        permissive defaults for backwards compatibility)."""
+        config = local.pop(self.METADATA_CONFIG, {})
+
+        # strict requirement that the no_secure setting is equal
+        config_no_secure = config.get(self.METADATA_NO_SECURE, False)
+        if config_no_secure != self.no_secure:
+            raise ValueError("YAML file generated with no_secure={} but current class set to no_secure={}".format(
+                config_no_secure, self.no_secure,
+            ))
+        # only apply no_decrypt if we actually download secure
+        if not self.no_secure:
+            config_no_decrypt = config.get(self.METADATA_NO_DECRYPT, False)
+            if config_no_decrypt != self.no_decrypt:
+                raise ValueError("YAML file generated with no_decrypt={} but current class set to no_decrypt={}".format(
+                    config_no_decrypt, self.no_decrypt,
+                ))
+        # strict requirement that root_path is equal
+        config_root = config.get(self.METADATA_ROOT, '/')
+        if config_root != self.root_path:
+            raise ValueError("YAML file generated with root_path={} but current class set to root_path={}".format(
+                config_root, self.root_path,
+            ))
+        # make sure all paths are subsets of file paths
+        config_paths = config.get(self.METADATA_PATHS, ['/'])
+        for path in self.paths:
+            for config_path in config_paths:
+                # if path is not found in a config path, it could look like we've deleted values
+                if path[:len(config_path)] == config_path:
+                    break
+            else:
+                
raise ValueError("Path {} was not included in this file when it was created.".format(path)) + + def unnest_root(self, state): + if self.root_path == '/': + return state + return search(state, self.root_path) + + def nest_root(self, state): + if self.root_path == '/': + return state + return add({}, self.root_path, state) + + def save(self, state): + state = self.unnest_root(state) + # inject state information so we can validate the file on load + # colon is not allowed in SSM keys so this namespace cannot collide with keys at any depth + state[self.METADATA_CONFIG] = { + self.METADATA_PATHS: self.paths, + self.METADATA_ROOT: self.root_path, + self.METADATA_NO_SECURE: self.no_secure + } + try: + with open(self.filename, 'wb') as f: + content = yaml.safe_dump(state, default_flow_style=False) + f.write(bytes(content.encode('utf-8'))) + except Exception as e: + print(e, file=sys.stderr) + sys.exit(1) + + +class ParameterStore(object): + """Encodes/decodes a dict to/from the SSM Parameter Store""" + invalid_characters = r'[^a-zA-Z0-9\-_\./]' + KMS_KEY = 'aws:kms:alias' + + def __init__(self, profile, diff_class, paths=('/',), no_secure=False, no_decrypt=False): + self.logger = logging.getLogger(self.__class__.__name__) + if profile: + boto3.setup_default_session(profile_name=profile) + self.ssm = boto3.client('ssm') + self.diff_class = diff_class + self.paths = paths + self.parameter_filters = [] + if no_secure: + self.parameter_filters.append({ + 'Key': 'Type', + 'Option': 'Equals', + 'Values': [ + 'String', 'StringList', + ] + }) + self.no_decrypt = no_decrypt + + def clone(self): + p = self.ssm.get_paginator('get_parameters_by_path') + output = {} + for path in self.paths: + try: + for page in p.paginate( + Path=path, + Recursive=True, + WithDecryption=not self.no_decrypt, + ParameterFilters=self.parameter_filters, + ): + for param in page['Parameters']: + add(obj=output, + path=param['Name'], + value=self._read_param(param['Value'], param['Type'], 
name=param['Name']))
+            except (ClientError, NoCredentialsError) as e:
+                print("Failed to fetch parameters from SSM!", e, file=sys.stderr)
+
+        return output
+
+    def _read_param(self, value, ssm_type='String', name=None):
+        if ssm_type == 'SecureString':
+            description = self.ssm.describe_parameters(
+                Filters=[{
+                    'Key': 'Name',
+                    'Values': [name]
+                }]
+            )
+            value = Secret(value, {
+                self.KMS_KEY: description['Parameters'][0]['KeyId'],
+            }, encrypted=self.no_decrypt)
+        elif ssm_type == 'StringList':
+            value = value.split(',')
+        return value
+
+    def pull(self, local):
+        diff = self.diff_class(
+            remote=self.clone(),
+            local=local,
+        )
+        return diff.merge()
+
+    @classmethod
+    def coerce_state(cls, state, path='/', sep='/'):
+        errors = {}
+        for k, v in state.items():
+            if re.search(cls.invalid_characters, k) is not None:
+                errors[path+sep+k] = 'Invalid Key'
+                continue
+            if isinstance(v, dict):
+                errors.update(cls.coerce_state(v, path=path + sep + k))
+            elif isinstance(v, list):
+                list_errors = []
+                for item in v:
+                    if not isinstance(item, str):
+                        list_errors.append('list items must be strings: {}'.format(repr(item)))
+                    elif re.search(r'[,]', item) is not None:
+                        list_errors.append("StringList is comma separated so items may not contain commas: {}".format(item))
+                if list_errors:
+                    errors[path+sep+k] = list_errors
+            elif isinstance(v, (str, SecureTag, Secret)):
+                continue
+            elif isinstance(v, (int, float, type(None))):
+                state[k] = str(v)
+            else:
+                errors[path+sep+k] = 'Cannot coerce type {}'.format(type(v))
+        return errors
+
+    def dry_run(self, local):
+        working = deepcopy(local)
+        errors = self.coerce_state(working)
+        if errors:
+            raise ValueError('Errors during dry run:\n{}'.format(errors))
+        plan = self.diff_class(self.clone(), working).plan
+        return plan
+
+    def prepare_param(self, name, value):
+        kwargs = {
+            'Name': name,
+        }
+        if isinstance(value, list):
+            kwargs['Type'] = 'StringList'
+            kwargs['Value'] = ','.join(value)
+        elif isinstance(value, Secret):
+            kwargs['Type'] = 'SecureString'
+            kwargs['Value'] = value.secret
+            # only include KeyId when one was recorded; boto3 rejects KeyId=None
+            key_id = value.metadata.get(self.KMS_KEY)
+            if key_id is not None:
+                kwargs['KeyId'] = key_id
+        elif isinstance(value, SecureTag):
+            kwargs['Type'] = 'SecureString'
+            kwargs['Value'] = value.secure
+        else:
+            kwargs['Type'] = 'String'
+            kwargs['Value'] = value
+        return kwargs
+
+    def push(self, local):
+        plan = self.dry_run(local)
+
+        for k, v in plan['add'].items():
+            # { key: new_value }
+            self.logger.info('add: {}'.format(k))
+            kwargs = self.prepare_param(k, v)
+            self.ssm.put_parameter(**kwargs)
+
+        for k, delta in plan['change'].items():
+            # { key: {'old': value, 'new': value} }
+            self.logger.info('change: {}'.format(k))
+            kwargs = self.prepare_param(k, delta['new'])
+            kwargs['Overwrite'] = True
+            self.ssm.put_parameter(**kwargs)
+
+        for k in plan['delete']:
+            # { key: old_value }
+            self.logger.info('delete: {}'.format(k))
+            self.ssm.delete_parameter(Name=k)
diff --git a/states/tests.py b/states/tests.py
new file mode 100644
index 0000000..3d1addd
--- /dev/null
+++ b/states/tests.py
@@ -0,0 +1,545 @@
+import random
+import string
+from unittest import TestCase, mock
+
+from . 
import engine, storage + + +class DiffBaseFlatten(TestCase): + """Verifies the behavior of the _flatten and _unflatten methods""" + def setUp(self) -> None: + self.obj = engine.DiffBase({}, {}) + + def test_flatten_single(self): + nested = { + "key": "value" + } + flat = { + "/key": "value", + } + self.assertEqual( + flat, + self.obj._flatten(nested) + ) + self.assertEqual( + nested, + self.obj._unflatten(flat) + ) + + def test_flatten_nested(self): + nested = { + "key1": { + "key2": "value" + } + } + flat = { + "/key1/key2": "value", + } + self.assertEqual( + flat, + self.obj._flatten(nested) + ) + self.assertEqual( + nested, + self.obj._unflatten(flat) + ) + + def test_flatten_nested_sep(self): + nested = { + "key1": { + "key2": "value" + } + } + flat = { + "\\key1\\key2": "value", + } + self.assertEqual( + flat, + self.obj._flatten(nested, sep='\\') + ) + self.assertEqual( + nested, + self.obj._unflatten(flat, sep='\\') + ) + + +class DiffResolverMerge(TestCase): + """Verifies that the `merge` method produces the expected output""" + + def test_add_remote(self): + """Remote additions should be added to local""" + remote = { + 'a': {'b': {'c': 'a/b/c', + 'd': 'a/b/d'}}, + 'x': {'y': {'z': 'x/y/z'}} + } + local = { + 'a': {'b': {'c': 'a/b/c', + 'd': 'a/b/d'}}, + } + + plan = engine.DiffResolver( + remote, + local, + ) + + self.assertEqual( + remote, + plan.merge() + ) + + def test_add_local(self): + """Local additions should be preserved so we won't see any changes to local""" + remote = { + 'a': {'b': {'c': 'a/b/c', + 'd': 'a/b/d'}}, + } + local = { + 'a': {'b': {'c': 'a/b/c', + 'd': 'a/b/d'}}, + 'x': {'y': {'z': 'x/y/z'}} + } + + diff = engine.DiffResolver( + remote, + local, + ) + + self.assertEqual( + local, + diff.merge() + ) + + def test_change_local_force(self): + """Local changes should be overwritten if force+True""" + remote = { + 'a': {'b': {'c': 'a/b/c', + 'd': 'a/b/d'}}, + } + local = { + 'a': {'b': {'c': 'a/b/c', + 'd': 'a/b/d_new'}}, + } + + args = 
mock.Mock(force=True) + diff = engine.DiffResolver.configure(args )( + remote, + local, + ) + + self.assertEqual( + remote, + diff.merge() + ) + + def test_change_local_no_force(self): + """Local changes should be preserved if force=False""" + remote = { + 'a': {'b': {'c': 'a/b/c', + 'd': 'a/b/d'}}, + } + local = { + 'a': {'b': {'c': 'a/b/c', + 'd': 'a/b/d_new'}}, + } + + args = mock.Mock(force=False) + diff = engine.DiffResolver.configure(args)( + remote, + local, + ) + + self.assertEqual( + local, + diff.merge() + ) + + +class DiffResolverPlan(TestCase): + + def test_add(self): + """The basic engine will mark any keys present in local but not remote as an add""" + remote = { + 'a': {'b': {'c': 'a/b/c', + 'd': 'a/b/d'}}, + } + local = { + 'a': {'b': {'c': 'a/b/c', + 'd': 'a/b/d'}}, + 'x': {'y': {'z': 'x/y/z'}} + } + + diff = engine.DiffResolver( + remote, + local, + ) + + self.assertDictEqual( + { + 'add': { + '/x/y/z': 'x/y/z', + }, + 'delete': {}, + 'change': {} + }, + diff.plan + ) + + def test_change(self): + """The basic engine will mark any keys that differ between remote and local as a change""" + remote = { + 'a': {'b': {'c': 'a/b/c', + 'd': 'a/b/d'}}, + } + local = { + 'a': {'b': {'c': 'a/b/c', + 'd': 'a/b/d_new'}}, + } + + diff = engine.DiffResolver( + remote, + local, + ) + + self.assertDictEqual( + { + 'add': {}, + 'delete': {}, + 'change': { + '/a/b/d': {'old': 'a/b/d', 'new': 'a/b/d_new'} + } + }, + diff.plan + ) + + def test_delete(self): + """The basic engine will mark any keys present in remote but not local as a delete""" + remote = { + 'a': {'b': {'c': 'a/b/c', + 'd': 'a/b/d'}}, + 'x': {'y': {'z': 'x/y/z'}} + } + local = { + 'a': {'b': {'c': 'a/b/c', + 'd': 'a/b/d'}}, + } + + diff = engine.DiffResolver( + remote, + local, + ) + + self.assertDictEqual( + { + 'add': {}, + 'delete': { + '/x/y/z': 'x/y/z', + }, + 'change': {} + }, + diff.plan + ) + + +class YAMLFileValidatePaths(TestCase): + """YAMLFile calls `validate_paths` in `__init__` to ensure 
the root and paths are compatible""" + def test_validate_paths_invalid(self): + with self.assertRaises(ValueError): + storage.YAMLFile(filename='unused', root_path='/one/branch', paths=['/another/branch']) + + def test_validate_paths_valid_same(self): + self.assertIsInstance( + storage.YAMLFile(filename='unused', root_path='/one/branch', paths=['/one/branch']), + storage.YAMLFile, + ) + + def test_validate_paths_valid_child(self): + self.assertIsInstance( + storage.YAMLFile(filename='unused', root_path='/one/branch', paths=['/one/branch/child']), + storage.YAMLFile, + ) + + +class YAMLFileMetadata(TestCase): + """Verifies that exceptions are thrown if the metadata in the target file is incompatible with the class configuration""" + def test_get_methods(self): + """Make sure we use the methods mocked by other tests""" + filename = ''.join([random.choice(string.ascii_letters + string.digits) for _ in range(32)]) + provider = storage.YAMLFile(filename=filename, no_secure=True) + with mock.patch('states.storage.open') as open_, mock.patch('states.storage.yaml') as yaml, \ + mock.patch.object(provider, 'validate_config'): + self.assertEqual( + provider.get(), + yaml.safe_load.return_value, + ) + open_.assert_called_once_with( + filename + '.yml', 'rb' + ) + yaml.safe_load.assert_called_once_with( + open_.return_value.__enter__.return_value.read.return_value + ) + + def test_get_invalid_no_secure(self): + """Exception should be raised if the secure metadata in the file does not match the instance""" + yaml_contents = { + storage.YAMLFile.METADATA_CONFIG: { + storage.YAMLFile.METADATA_NO_SECURE: False + } + } + filename = ''.join([random.choice(string.ascii_letters + string.digits) for _ in range(32)]) + provider = storage.YAMLFile(filename=filename, no_secure=True) + + with mock.patch('states.storage.open'), mock.patch('states.storage.yaml') as yaml: + yaml.safe_load.return_value = yaml_contents + with self.assertRaises(ValueError): + provider.get() + + def 
test_get_valid_no_secure(self): + yaml_contents = { + storage.YAMLFile.METADATA_CONFIG: { + storage.YAMLFile.METADATA_NO_SECURE: False + } + } + filename = ''.join([random.choice(string.ascii_letters + string.digits) for _ in range(32)]) + provider = storage.YAMLFile(filename=filename, no_secure=False) + + with mock.patch('states.storage.open') as open_, mock.patch('states.storage.yaml') as yaml: + yaml.safe_load.return_value = yaml_contents + self.assertEqual( + provider.get(), + yaml.safe_load.return_value, + ) + + def test_get_valid_no_secure_true(self): + yaml_contents = { + storage.YAMLFile.METADATA_CONFIG: { + storage.YAMLFile.METADATA_NO_SECURE: True + } + } + filename = ''.join([random.choice(string.ascii_letters + string.digits) for _ in range(32)]) + provider = storage.YAMLFile(filename=filename, no_secure=True) + + with mock.patch('states.storage.open'), mock.patch('states.storage.yaml') as yaml: + yaml.safe_load.return_value = yaml_contents + self.assertEqual( + provider.get(), + yaml.safe_load.return_value, + ) + + def test_get_invalid_root(self): + """Exception should be raised if the root metadata in the file does not match the instance""" + yaml_contents = { + storage.YAMLFile.METADATA_CONFIG: { + storage.YAMLFile.METADATA_ROOT: '/' + } + } + filename = ''.join([random.choice(string.ascii_letters + string.digits) for _ in range(32)]) + # make sure validate_paths isn't run + with mock.patch.object(storage.YAMLFile, 'validate_paths'): + provider = storage.YAMLFile(filename=filename, root_path='/another') + + # handle open/yaml processing + with mock.patch('states.storage.open'), mock.patch('states.storage.yaml') as yaml: + yaml.safe_load.return_value = yaml_contents + with self.assertRaises(ValueError): + provider.get() + + def test_get_valid_root(self): + yaml_contents = { + storage.YAMLFile.METADATA_CONFIG: { + storage.YAMLFile.METADATA_ROOT: '/same' + } + } + filename = ''.join([random.choice(string.ascii_letters + string.digits) for _ in 
range(32)]) + # make sure validate_paths isn't run + with mock.patch.object(storage.YAMLFile, 'validate_paths'): + provider = storage.YAMLFile(filename=filename, root_path='/same') + + # handle open/yaml processing + with mock.patch('states.storage.open'), mock.patch('states.storage.yaml') as yaml, \ + mock.patch.object(provider, 'nest_root'): + yaml.safe_load.return_value = yaml_contents + provider.get() + + def test_get_invalid_paths(self): + """Exception should be raised if the paths metadata is incompatible with the instance""" + yaml_contents = { + storage.YAMLFile.METADATA_CONFIG: { + storage.YAMLFile.METADATA_PATHS: ['/limited'] + } + } + filename = ''.join([random.choice(string.ascii_letters + string.digits) for _ in range(32)]) + # make sure validate_paths isn't run + provider = storage.YAMLFile(filename=filename, paths='/') + + # handle open/yaml processing + with mock.patch('states.storage.open'), mock.patch('states.storage.yaml') as yaml: + yaml.safe_load.return_value = yaml_contents + with self.assertRaises(ValueError): + provider.get() + + def test_get_invalid_paths_mixed(self): + """A single invalid path should fail even in the presence of multiple matching paths""" + yaml_contents = { + storage.YAMLFile.METADATA_CONFIG: { + storage.YAMLFile.METADATA_PATHS: ['/limited'] + } + } + filename = ''.join([random.choice(string.ascii_letters + string.digits) for _ in range(32)]) + # make sure validate_paths isn't run + provider = storage.YAMLFile(filename=filename, paths=['/', '/limited']) + + # handle open/yaml processing + with mock.patch('states.storage.open'), mock.patch('states.storage.yaml') as yaml: + yaml.safe_load.return_value = yaml_contents + with self.assertRaises(ValueError): + provider.get() + + def test_get_invalid_paths_multiple(self): + """Multiple invalid paths should fail""" + yaml_contents = { + storage.YAMLFile.METADATA_CONFIG: { + storage.YAMLFile.METADATA_PATHS: ['/limited'] + } + } + filename = 
''.join([random.choice(string.ascii_letters + string.digits) for _ in range(32)]) + # make sure validate_paths isn't run + provider = storage.YAMLFile(filename=filename, paths=['/', '/another']) + + # handle open/yaml processing + with mock.patch('states.storage.open'), mock.patch('states.storage.yaml') as yaml: + yaml.safe_load.return_value = yaml_contents + with self.assertRaises(ValueError): + provider.get() + + def test_get_valid_paths_same(self): + """The same path is valid""" + yaml_contents = { + storage.YAMLFile.METADATA_CONFIG: { + storage.YAMLFile.METADATA_PATHS: ['/'] + } + } + filename = ''.join([random.choice(string.ascii_letters + string.digits) for _ in range(32)]) + # make sure validate_paths isn't run + provider = storage.YAMLFile(filename=filename, paths=['/']) + + # handle open/yaml processing + with mock.patch('states.storage.open'), mock.patch('states.storage.yaml') as yaml: + yaml.safe_load.return_value = yaml_contents + provider.get() + + def test_get_valid_paths_child(self): + """A descendant (child) of a path is valid since it's contained in the original""" + yaml_contents = { + storage.YAMLFile.METADATA_CONFIG: { + storage.YAMLFile.METADATA_PATHS: ['/'] + } + } + filename = ''.join([random.choice(string.ascii_letters + string.digits) for _ in range(32)]) + # make sure validate_paths isn't run + provider = storage.YAMLFile(filename=filename, paths=['/child']) + + # handle open/yaml processing + with mock.patch('states.storage.open'), mock.patch('states.storage.yaml') as yaml: + yaml.safe_load.return_value = yaml_contents + provider.get() + + def test_get_valid_paths_child_multiple(self): + """Multiple descendant (child) of a path is valid since it's contained in the original""" + yaml_contents = { + storage.YAMLFile.METADATA_CONFIG: { + storage.YAMLFile.METADATA_PATHS: ['/'] + } + } + filename = ''.join([random.choice(string.ascii_letters + string.digits) for _ in range(32)]) + # make sure validate_paths isn't run + provider = 
storage.YAMLFile(filename=filename, paths=['/child', '/another_child']) + + # handle open/yaml processing + with mock.patch('states.storage.open'), mock.patch('states.storage.yaml') as yaml: + yaml.safe_load.return_value = yaml_contents + provider.get() + + def test_get_valid_paths_default_nested(self): + """The default path is '/' so it should be valid for anything""" + yaml_contents = { + storage.YAMLFile.METADATA_CONFIG: { + } + } + filename = ''.join([random.choice(string.ascii_letters + string.digits) for _ in range(32)]) + # make sure validate_paths isn't run + provider = storage.YAMLFile(filename=filename, paths=['/child']) + + # handle open/yaml processing + with mock.patch('states.storage.open'), mock.patch('states.storage.yaml') as yaml: + yaml.safe_load.return_value = yaml_contents + provider.get() + + def test_get_valid_paths_default_root(self): + """The default path is '/' so it should be valid for anything""" + yaml_contents = { + storage.YAMLFile.METADATA_CONFIG: { + } + } + filename = ''.join([random.choice(string.ascii_letters + string.digits) for _ in range(32)]) + # make sure validate_paths isn't run + provider = storage.YAMLFile(filename=filename, paths=['/']) + + # handle open/yaml processing + with mock.patch('states.storage.open'), mock.patch('states.storage.yaml') as yaml: + yaml.safe_load.return_value = yaml_contents + provider.get() + + +class YAMLFileRoot(TestCase): + """Verify that the `root_path` config works as expected""" + def test_unnest_path(self): + yaml_contents = { + storage.YAMLFile.METADATA_CONFIG: { + # must match root_path of object to pass checks + storage.YAMLFile.METADATA_ROOT: '/nested/path' + }, + 'key': 'value' + } + filename = ''.join([random.choice(string.ascii_letters + string.digits) for _ in range(32)]) + # make sure validate_paths isn't run + provider = storage.YAMLFile(filename=filename, root_path='/nested/path', paths=['/nested/path']) + + # handle open/yaml processing + with mock.patch('states.storage.open'), 
mock.patch('states.storage.yaml') as yaml: + yaml.safe_load.return_value = yaml_contents + self.assertEqual( + { + 'nested': { + 'path': { + 'key': 'value' + } + } + }, + provider.get(), + ) + + def test_nest_path(self): + filename = ''.join([random.choice(string.ascii_letters + string.digits) for _ in range(32)]) + # make sure validate_paths isn't run + provider = storage.YAMLFile(filename=filename, root_path='/nested/path', paths=['/nested/path']) + + with mock.patch('states.storage.open'), mock.patch('states.storage.yaml') as yaml: + provider.save({ + 'nested': { + 'path': { + 'key': 'value' + } + } + }) + + yaml.safe_dump.assert_called_once_with( + { + storage.YAMLFile.METADATA_CONFIG: { + storage.YAMLFile.METADATA_ROOT: '/nested/path', + storage.YAMLFile.METADATA_PATHS: ['/nested/path'], + storage.YAMLFile.METADATA_NO_SECURE: False, + }, + 'key': 'value' + }, + # appears to replicate a default, but included in the current code + default_flow_style=False + )
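The `DiffBaseFlatten` tests above pin down the path-flattening behavior the engine relies on. A minimal standalone sketch of that behavior (not the package's actual `_flatten` implementation) looks like this:

```python
def flatten(d, path='', sep='/'):
    """Collapse a nested dict into SSM-style {'/key1/key2': value} form."""
    out = {}
    for k, v in d.items():
        key = path + sep + k
        if isinstance(v, dict):
            # recurse, carrying the accumulated path prefix
            out.update(flatten(v, path=key, sep=sep))
        else:
            out[key] = v
    return out

print(flatten({'key1': {'key2': 'value'}}))  # {'/key1/key2': 'value'}
```

The inverse (`_unflatten`) would split each key on the separator and rebuild the nesting.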
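The path checks in `YAMLFile.validate_config` boil down to a prefix test: every requested path must fall under at least one path recorded in the file's `ssm-diff:config` metadata, otherwise a later `push` could look like a mass delete. `paths_covered` below is a hypothetical standalone helper illustrating that rule, not code from the package:

```python
def paths_covered(requested, config_paths):
    """Return the requested paths NOT covered by any path stored in the file metadata."""
    uncovered = []
    for path in requested:
        # same prefix comparison as validate_config: path[:len(config_path)] == config_path
        if not any(path.startswith(config_path) for config_path in config_paths):
            uncovered.append(path)
    return uncovered

print(paths_covered(['/app/db', '/app/cache'], ['/app']))  # []
print(paths_covered(['/', '/limited'], ['/limited']))      # ['/']
```

The second call mirrors the `test_get_invalid_paths_mixed` case: one matching path does not excuse an uncovered one.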
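The mapping `ParameterStore.prepare_param` performs from local values to `put_parameter` kwargs can also be sketched standalone. The `Secret` class here is a simplified stand-in for the package's YAML-backed class (the real one carries a metadata dict keyed by the KMS alias):

```python
class Secret:
    """Simplified stand-in for the package's Secret YAML object."""
    def __init__(self, secret, key_id=None):
        self.secret = secret
        self.key_id = key_id


def prepare_param(name, value):
    kwargs = {'Name': name}
    if isinstance(value, list):
        # StringList values are comma-joined, which is why the dry run
        # rejects list items that themselves contain commas
        kwargs['Type'] = 'StringList'
        kwargs['Value'] = ','.join(value)
    elif isinstance(value, Secret):
        kwargs['Type'] = 'SecureString'
        kwargs['Value'] = value.secret
        if value.key_id is not None:
            kwargs['KeyId'] = value.key_id
    else:
        kwargs['Type'] = 'String'
        kwargs['Value'] = value
    return kwargs

print(prepare_param('/app/hosts', ['a', 'b']))
# {'Name': '/app/hosts', 'Type': 'StringList', 'Value': 'a,b'}
```

Everything that is neither a list nor a secret falls through to plain `String`, which is why `coerce_state` stringifies ints, floats, and `None` beforehand.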