diff --git a/.binder/environment.yml b/.binder/environment.yml index c5ea8e4..215a802 100644 --- a/.binder/environment.yml +++ b/.binder/environment.yml @@ -1,4 +1,4 @@ channels: - conda-forge dependencies: -- pyiron_base =0.8.3 +- python diff --git a/.ci_support/environment-tests.yml b/.ci_support/environment-tests.yml new file mode 100644 index 0000000..b157631 --- /dev/null +++ b/.ci_support/environment-tests.yml @@ -0,0 +1,4 @@ +channels: + - conda-forge +dependencies: + - cloudpickle =3.0.0 \ No newline at end of file diff --git a/.ci_support/environment.yml b/.ci_support/environment.yml index deed6b5..4a251b3 100644 --- a/.ci_support/environment.yml +++ b/.ci_support/environment.yml @@ -1,5 +1,4 @@ channels: - conda-forge dependencies: - - pyiron_base =0.8.3 - + - python diff --git a/.github/workflows/daily.yml b/.github/workflows/daily.yml index 10cdbfb..ddc05eb 100644 --- a/.github/workflows/daily.yml +++ b/.github/workflows/daily.yml @@ -9,4 +9,6 @@ on: jobs: codeql: uses: pyiron/actions/.github/workflows/tests-and-coverage.yml@actions-3.1.0 - secrets: inherit \ No newline at end of file + secrets: inherit + with: + tests-env-files: .ci_support/environment.yml .ci_support/environment-tests.yml \ No newline at end of file diff --git a/.github/workflows/push-pull.yml b/.github/workflows/push-pull.yml index e4c5d41..f26c046 100644 --- a/.github/workflows/push-pull.yml +++ b/.github/workflows/push-pull.yml @@ -10,4 +10,7 @@ on: jobs: pyiron: uses: pyiron/actions/.github/workflows/push-pull.yml@actions-3.1.0 - secrets: inherit \ No newline at end of file + secrets: inherit + with: + tests-env-files: .ci_support/environment.yml .ci_support/environment-tests.yml + do-benchmark-tests: false \ No newline at end of file diff --git a/docs/environment.yml b/docs/environment.yml index 33efb13..0d13587 100644 --- a/docs/environment.yml +++ b/docs/environment.yml @@ -7,4 +7,4 @@ dependencies: - sphinx-gallery - sphinx-rtd-theme - versioneer -- pyiron_base =0.8.3 +- 
python diff --git a/notebooks/example.ipynb b/notebooks/example.ipynb index 02892d4..337c021 100644 --- a/notebooks/example.ipynb +++ b/notebooks/example.ipynb @@ -5,23 +5,45 @@ "id": "c44a93e4-1ce4-4a26-a82a-3d2bff41988b", "metadata": {}, "source": [ - "# Demo notebook\n", - "\n", - "In the standard setup, this notebook gets included in both the docs and the tests.\n", + "# Demos" + ] + }, + { + "cell_type": "markdown", + "id": "51352718-c84e-4515-b033-f0cd80150269", + "metadata": {}, + "source": [ + "## `DotDict`" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "0cef82c5-686f-4753-b8aa-ee8125f17380", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "this is dot accessible\n" + ] + } + ], + "source": [ + "from pyiron_snippets.dotdict import DotDict\n", "\n", - "In the notebook gallery in the docs, most notebooks will use their last image as their thumbnail; this notebook specifies usage of the pyiron logo in `docs/conf.py` under the `nbsphinx_thumbnails` dictionary." + "dd = DotDict({\"foo\": \"this is dot accessible\"})\n", + "print(dd.foo)" ] }, { "cell_type": "code", "execution_count": null, - "id": "52a8cfcc", + "id": "203b5240-0ce9-4c71-86a0-6048b45f7b0f", "metadata": {}, "outputs": [], - "source": [ - "import pyiron_snippets\n", - "print(pyiron_snippets.__version__)" - ] + "source": [] } ], "metadata": { @@ -40,7 +62,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.7" + "version": "3.11.9" } }, "nbformat": 4, diff --git a/pyiron_snippets/__init__.py b/pyiron_snippets/__init__.py index ecd3379..4d52a61 100644 --- a/pyiron_snippets/__init__.py +++ b/pyiron_snippets/__init__.py @@ -1,3 +1,3 @@ - from . 
import _version -__version__ = _version.get_versions()['version'] + +__version__ = _version.get_versions()["version"] diff --git a/pyiron_snippets/_version.py b/pyiron_snippets/_version.py index 5bb3058..079914f 100644 --- a/pyiron_snippets/_version.py +++ b/pyiron_snippets/_version.py @@ -1,4 +1,3 @@ - # This file helps to compute a version number in source trees obtained from # git-archive tarball (such as those provided by githubs download-from-tag # feature). Distribution tarballs (built by setup.py sdist) and build @@ -68,12 +67,14 @@ class NotThisMethod(Exception): def register_vcs_handler(vcs: str, method: str) -> Callable: # decorator """Create decorator to mark a method as the handler of a VCS.""" + def decorate(f: Callable) -> Callable: """Store f in HANDLERS[vcs][method].""" if vcs not in HANDLERS: HANDLERS[vcs] = {} HANDLERS[vcs][method] = f return f + return decorate @@ -100,10 +101,14 @@ def run_command( try: dispcmd = str([command] + args) # remember shell=False, so use git.cmd on windows, not just git - process = subprocess.Popen([command] + args, cwd=cwd, env=env, - stdout=subprocess.PIPE, - stderr=(subprocess.PIPE if hide_stderr - else None), **popen_kwargs) + process = subprocess.Popen( + [command] + args, + cwd=cwd, + env=env, + stdout=subprocess.PIPE, + stderr=(subprocess.PIPE if hide_stderr else None), + **popen_kwargs, + ) break except OSError as e: if e.errno == errno.ENOENT: @@ -141,15 +146,21 @@ def versions_from_parentdir( for _ in range(3): dirname = os.path.basename(root) if dirname.startswith(parentdir_prefix): - return {"version": dirname[len(parentdir_prefix):], - "full-revisionid": None, - "dirty": False, "error": None, "date": None} + return { + "version": dirname[len(parentdir_prefix) :], + "full-revisionid": None, + "dirty": False, + "error": None, + "date": None, + } rootdirs.append(root) root = os.path.dirname(root) # up a level if verbose: - print("Tried directories %s but none started with prefix %s" % - (str(rootdirs), 
parentdir_prefix)) + print( + "Tried directories %s but none started with prefix %s" + % (str(rootdirs), parentdir_prefix) + ) raise NotThisMethod("rootdir doesn't start with parentdir_prefix") @@ -212,7 +223,7 @@ def git_versions_from_keywords( # starting in git-1.8.3, tags are listed as "tag: foo-1.0" instead of # just "foo-1.0". If we see a "tag: " prefix, prefer those. TAG = "tag: " - tags = {r[len(TAG):] for r in refs if r.startswith(TAG)} + tags = {r[len(TAG) :] for r in refs if r.startswith(TAG)} if not tags: # Either we're using git < 1.8.3, or there really are no tags. We use # a heuristic: assume all version tags have a digit. The old git %d @@ -221,7 +232,7 @@ def git_versions_from_keywords( # between branches and tags. By ignoring refnames without digits, we # filter out many common branch names like "release" and # "stabilization", as well as "HEAD" and "master". - tags = {r for r in refs if re.search(r'\d', r)} + tags = {r for r in refs if re.search(r"\d", r)} if verbose: print("discarding '%s', no digits" % ",".join(refs - tags)) if verbose: @@ -229,32 +240,36 @@ def git_versions_from_keywords( for ref in sorted(tags): # sorting will prefer e.g. 
"2.0" over "2.0rc1" if ref.startswith(tag_prefix): - r = ref[len(tag_prefix):] + r = ref[len(tag_prefix) :] # Filter out refs that exactly match prefix or that don't start # with a number once the prefix is stripped (mostly a concern # when prefix is '') - if not re.match(r'\d', r): + if not re.match(r"\d", r): continue if verbose: print("picking %s" % r) - return {"version": r, - "full-revisionid": keywords["full"].strip(), - "dirty": False, "error": None, - "date": date} + return { + "version": r, + "full-revisionid": keywords["full"].strip(), + "dirty": False, + "error": None, + "date": date, + } # no suitable tags, so version is "0+unknown", but full hex is still there if verbose: print("no suitable tags, using unknown + full revision id") - return {"version": "0+unknown", - "full-revisionid": keywords["full"].strip(), - "dirty": False, "error": "no suitable tags", "date": None} + return { + "version": "0+unknown", + "full-revisionid": keywords["full"].strip(), + "dirty": False, + "error": "no suitable tags", + "date": None, + } @register_vcs_handler("git", "pieces_from_vcs") def git_pieces_from_vcs( - tag_prefix: str, - root: str, - verbose: bool, - runner: Callable = run_command + tag_prefix: str, root: str, verbose: bool, runner: Callable = run_command ) -> Dict[str, Any]: """Get version from 'git describe' in the root of the source tree. 
@@ -273,8 +288,7 @@ def git_pieces_from_vcs( env.pop("GIT_DIR", None) runner = functools.partial(runner, env=env) - _, rc = runner(GITS, ["rev-parse", "--git-dir"], cwd=root, - hide_stderr=not verbose) + _, rc = runner(GITS, ["rev-parse", "--git-dir"], cwd=root, hide_stderr=not verbose) if rc != 0: if verbose: print("Directory %s not under git control" % root) @@ -282,10 +296,19 @@ def git_pieces_from_vcs( # if there is a tag matching tag_prefix, this yields TAG-NUM-gHEX[-dirty] # if there isn't one, this yields HEX[-dirty] (no NUM) - describe_out, rc = runner(GITS, [ - "describe", "--tags", "--dirty", "--always", "--long", - "--match", f"{tag_prefix}[[:digit:]]*" - ], cwd=root) + describe_out, rc = runner( + GITS, + [ + "describe", + "--tags", + "--dirty", + "--always", + "--long", + "--match", + f"{tag_prefix}[[:digit:]]*", + ], + cwd=root, + ) # --long was added in git-1.5.5 if describe_out is None: raise NotThisMethod("'git describe' failed") @@ -300,8 +323,7 @@ def git_pieces_from_vcs( pieces["short"] = full_out[:7] # maybe improved later pieces["error"] = None - branch_name, rc = runner(GITS, ["rev-parse", "--abbrev-ref", "HEAD"], - cwd=root) + branch_name, rc = runner(GITS, ["rev-parse", "--abbrev-ref", "HEAD"], cwd=root) # --abbrev-ref was added in git-1.6.3 if rc != 0 or branch_name is None: raise NotThisMethod("'git rev-parse --abbrev-ref' returned error") @@ -341,17 +363,16 @@ def git_pieces_from_vcs( dirty = git_describe.endswith("-dirty") pieces["dirty"] = dirty if dirty: - git_describe = git_describe[:git_describe.rindex("-dirty")] + git_describe = git_describe[: git_describe.rindex("-dirty")] # now we have TAG-NUM-gHEX or HEX if "-" in git_describe: # TAG-NUM-gHEX - mo = re.search(r'^(.+)-(\d+)-g([0-9a-f]+)$', git_describe) + mo = re.search(r"^(.+)-(\d+)-g([0-9a-f]+)$", git_describe) if not mo: # unparsable. Maybe git-describe is misbehaving? 
- pieces["error"] = ("unable to parse git-describe output: '%s'" - % describe_out) + pieces["error"] = "unable to parse git-describe output: '%s'" % describe_out return pieces # tag @@ -360,10 +381,12 @@ def git_pieces_from_vcs( if verbose: fmt = "tag '%s' doesn't start with prefix '%s'" print(fmt % (full_tag, tag_prefix)) - pieces["error"] = ("tag '%s' doesn't start with prefix '%s'" - % (full_tag, tag_prefix)) + pieces["error"] = "tag '%s' doesn't start with prefix '%s'" % ( + full_tag, + tag_prefix, + ) return pieces - pieces["closest-tag"] = full_tag[len(tag_prefix):] + pieces["closest-tag"] = full_tag[len(tag_prefix) :] # distance: number of commits since tag pieces["distance"] = int(mo.group(2)) @@ -412,8 +435,7 @@ def render_pep440(pieces: Dict[str, Any]) -> str: rendered += ".dirty" else: # exception #1 - rendered = "0+untagged.%d.g%s" % (pieces["distance"], - pieces["short"]) + rendered = "0+untagged.%d.g%s" % (pieces["distance"], pieces["short"]) if pieces["dirty"]: rendered += ".dirty" return rendered @@ -442,8 +464,7 @@ def render_pep440_branch(pieces: Dict[str, Any]) -> str: rendered = "0" if pieces["branch"] != "master": rendered += ".dev0" - rendered += "+untagged.%d.g%s" % (pieces["distance"], - pieces["short"]) + rendered += "+untagged.%d.g%s" % (pieces["distance"], pieces["short"]) if pieces["dirty"]: rendered += ".dirty" return rendered @@ -604,11 +625,13 @@ def render_git_describe_long(pieces: Dict[str, Any]) -> str: def render(pieces: Dict[str, Any], style: str) -> Dict[str, Any]: """Render the given version pieces into the requested style.""" if pieces["error"]: - return {"version": "unknown", - "full-revisionid": pieces.get("long"), - "dirty": None, - "error": pieces["error"], - "date": None} + return { + "version": "unknown", + "full-revisionid": pieces.get("long"), + "dirty": None, + "error": pieces["error"], + "date": None, + } if not style or style == "default": style = "pep440" # the default @@ -632,9 +655,13 @@ def render(pieces: 
Dict[str, Any], style: str) -> Dict[str, Any]: else: raise ValueError("unknown style '%s'" % style) - return {"version": rendered, "full-revisionid": pieces["long"], - "dirty": pieces["dirty"], "error": None, - "date": pieces.get("date")} + return { + "version": rendered, + "full-revisionid": pieces["long"], + "dirty": pieces["dirty"], + "error": None, + "date": pieces.get("date"), + } def get_versions() -> Dict[str, Any]: @@ -648,8 +675,7 @@ def get_versions() -> Dict[str, Any]: verbose = cfg.verbose try: - return git_versions_from_keywords(get_keywords(), cfg.tag_prefix, - verbose) + return git_versions_from_keywords(get_keywords(), cfg.tag_prefix, verbose) except NotThisMethod: pass @@ -658,13 +684,16 @@ def get_versions() -> Dict[str, Any]: # versionfile_source is the relative path from the top of the source # tree (where the .git directory might live) to this file. Invert # this to find the root from __file__. - for _ in cfg.versionfile_source.split('/'): + for _ in cfg.versionfile_source.split("/"): root = os.path.dirname(root) except NameError: - return {"version": "0+unknown", "full-revisionid": None, - "dirty": None, - "error": "unable to find root of source tree", - "date": None} + return { + "version": "0+unknown", + "full-revisionid": None, + "dirty": None, + "error": "unable to find root of source tree", + "date": None, + } try: pieces = git_pieces_from_vcs(cfg.tag_prefix, root, verbose) @@ -678,6 +707,10 @@ def get_versions() -> Dict[str, Any]: except NotThisMethod: pass - return {"version": "0+unknown", "full-revisionid": None, - "dirty": None, - "error": "unable to compute version", "date": None} + return { + "version": "0+unknown", + "full-revisionid": None, + "dirty": None, + "error": "unable to compute version", + "date": None, + } diff --git a/pyiron_snippets/colors.py b/pyiron_snippets/colors.py new file mode 100644 index 0000000..19311e6 --- /dev/null +++ b/pyiron_snippets/colors.py @@ -0,0 +1,23 @@ +""" +Simple stuff for colors when you don't 
want matplotlib/seaborn in your dependency stack. +""" + + +class SeabornColors: + """ + Hex codes for the ten `seaborn.color_palette()` colors (plus pure white and black), + recreated to avoid adding an entire dependency. + """ + + blue = "#1f77b4" + orange = "#ff7f0e" + green = "#2ca02c" + red = "#d62728" + purple = "#9467bd" + brown = "#8c564b" + pink = "#e377c2" + gray = "#7f7f7f" + olive = "#bcbd22" + cyan = "#17becf" + white = "#ffffff" + black = "#000000" diff --git a/pyiron_snippets/dotdict.py b/pyiron_snippets/dotdict.py new file mode 100644 index 0000000..c7df1d7 --- /dev/null +++ b/pyiron_snippets/dotdict.py @@ -0,0 +1,25 @@ +class DotDict(dict): + def __getattr__(self, item): + try: + return self.__getitem__(item) + except KeyError: + raise AttributeError( + f"{self.__class__.__name__} object has no attribute '{item}'" + ) + + def __setattr__(self, key, value): + self[key] = value + + def __dir__(self): + return set(super().__dir__() + list(self.keys())) + + def __getstate__(self): + return self.__dict__ + + def __setstate__(self, state): + for k, v in state.items(): + self.__dict__[k] = v + + def to_list(self): + """A list of values (order not guaranteed)""" + return list(self.values()) diff --git a/pyiron_snippets/factory.py b/pyiron_snippets/factory.py new file mode 100644 index 0000000..98d5caf --- /dev/null +++ b/pyiron_snippets/factory.py @@ -0,0 +1,437 @@ +""" +Tools for making dynamically generated classes unique, and their instances pickleable. + +Provides two main user-facing tools: :func:`classfactory`, which should be used +_exclusively_ as a decorator (this restriction pertains to namespace requirements for +re-importing), and `ClassFactory`, which can be used to instantiate a new factory from +some existing factory function. 
+ +In both cases, the decorated function/input argument should be a pickleable function +taking only positional arguments, and returning a tuple suitable for use in dynamic +class creation via :func:`builtins.type` -- i.e. taking a class name, a tuple of base +classes, a dictionary of class attributes, and a dictionary of values to be expanded +into kwargs for `__subclass_init__`. + +The resulting factory produces classes that are (a) pickleable, and (b) the same object +as any previously built class with the same name. (Note: avoiding class degeneracy with +respect to class name is the responsibility of the person writing the factory function.) + +These classes are then themselves pickleable, and produce instances which are in turn +pickleable (so long as any data they've been fed as inputs or attributes is pickleable, +i.e. here the only pickle-barrier we resolve is that of having come from a dynamically +generated class). + +Since users need to build their own class factories returning classes with sensible +names, we also provide a helper function :func:`sanitize_callable_name`, which makes +sure a string is compliant with use as a class name. This is run internally on user- +provided names, and failure for the user name and sanitized name to match will give a +clear error message. + +Constructed classes can, in turn be used as bases in further class factories. 
+""" + +from __future__ import annotations + +from abc import ABC, ABCMeta +from functools import wraps +from importlib import import_module +from inspect import signature, Parameter +import pickle +from re import sub +from typing import ClassVar + + +class _SingleInstance(ABCMeta): + """Simple singleton pattern.""" + + _instance = None + + def __call__(cls, *args, **kwargs): + if cls._instance is None: + cls._instance = super(_SingleInstance, cls).__call__(*args, **kwargs) + return cls._instance + + +class _FactoryTown(metaclass=_SingleInstance): + """ + Makes sure two factories created around the same factory function are the same + factory object. + """ + + factories = {} + + @classmethod + def clear(cls): + """ + Remove factories. + + Can be useful if you're + """ + cls.factories = {} + + @staticmethod + def _factory_address(factory_function: callable) -> str: + return f"{factory_function.__module__}.{factory_function.__qualname__}" + + def get_factory(self, factory_function: callable[..., type]) -> _ClassFactory: + + self._verify_function_only_takes_positional_args(factory_function) + + address = self._factory_address(factory_function) + + try: + return self.factories[address] + except KeyError: + factory = self._build_factory(factory_function) + self.factories[address] = factory + return factory + + @staticmethod + def _build_factory(factory_function): + """ + Subclass :class:`_ClassFactory` and make an instance. 
+ """ + new_factory_class = type( + sanitize_callable_name( + f"{factory_function.__module__}{factory_function.__qualname__}" + f"{factory_function.__name__.title()}" + f"{_ClassFactory.__name__}" + ).replace("_", ""), + (_ClassFactory,), + {}, + factory_function=factory_function, + ) + return wraps(factory_function)(new_factory_class()) + + @staticmethod + def _verify_function_only_takes_positional_args(factory_function: callable): + parameters = signature(factory_function).parameters.values() + if any( + p.kind not in [Parameter.POSITIONAL_ONLY, Parameter.VAR_POSITIONAL] + for p in parameters + ): + raise InvalidFactorySignature( + f"{_ClassFactory.__name__} can only be subclassed using factory " + f"functions that take exclusively positional arguments, but " + f"{factory_function.__name__} has the parameters {parameters}" + ) + + +_FACTORY_TOWN = _FactoryTown() + + +class InvalidFactorySignature(ValueError): + """When the factory function's arguments are not purely positional""" + + pass + + +class InvalidClassNameError(ValueError): + """When a string isn't a good class name""" + + pass + + +class _ClassFactory(metaclass=_SingleInstance): + """ + For making dynamically created classes the same class. + """ + + _decorated_as_classfactory: ClassVar[bool] = False + + def __init_subclass__(cls, /, factory_function, **kwargs): + super().__init_subclass__(**kwargs) + cls.factory_function = staticmethod(factory_function) + cls.class_registry = {} + + def __call__(self, *args) -> type[_FactoryMade]: + name, bases, class_dict, sc_init_kwargs = self.factory_function(*args) + self._verify_name_is_legal(name) + try: + return self.class_registry[name] + except KeyError: + factory_made = self._build_class( + name, + bases, + class_dict, + sc_init_kwargs, + args, + ) + self.class_registry[name] = factory_made + return factory_made + + @classmethod + def clear(cls, *class_names, skip_missing=True): + """ + Remove constructed class(es). 
+ + Can be useful if you've updated the constructor and want to remove old + instances. + + Args: + *class_names (str): The names of classes to remove. Removes all of them + when empty. + skip_missing (bool): Whether to pass over key errors when a name is + requested that is not currently in the class registry. (Default is + True, let missing names pass silently.) + """ + if len(class_names) == 0: + cls.class_registry = {} + else: + for name in class_names: + try: + cls.class_registry.pop(name) + except KeyError as e: + if skip_missing: + continue + else: + raise KeyError(f"Could not find class {name}") + + def _build_class( + self, name, bases, class_dict, sc_init_kwargs, class_factory_args + ) -> type[_FactoryMade]: + + if "__module__" not in class_dict.keys(): + class_dict["__module__"] = self.factory_function.__module__ + if "__qualname__" not in class_dict.keys(): + class_dict["__qualname__"] = f"{self.__qualname__}.{name}" + sc_init_kwargs["class_factory"] = self + sc_init_kwargs["class_factory_args"] = class_factory_args + + if not any(_FactoryMade in base.mro() for base in bases): + bases = (_FactoryMade, *bases) + + return type(name, bases, class_dict, **sc_init_kwargs) + + @staticmethod + def _verify_name_is_legal(name): + sanitized_name = sanitize_callable_name(name) + if name != sanitized_name: + raise InvalidClassNameError( + f"The class name {name} failed to match with its sanitized version" + f"({sanitized_name}), please supply a valid class name." 
) + + def __reduce__(self): + if ( + self._decorated_as_classfactory + and "<locals>" not in self.factory_function.__qualname__ + ): + return ( + _import_object, + (self.factory_function.__module__, self.factory_function.__qualname__), + ) + else: + return (_FACTORY_TOWN.get_factory, (self.factory_function,)) + + +def _import_object(module_name, qualname): + module = import_module(module_name) + obj = module + for name in qualname.split("."): + obj = getattr(obj, name) + return obj + + +class _FactoryMade(ABC): + """ + A mix-in to make class-factory-produced classes pickleable. + + If the factory is used as a decorator for another function, it will conflict with + this function (i.e. the owned function will be the true function, and will mismatch + with imports from that location, which will return the post-decorator factory made + class). This can be resolved by setting the + :attr:`_class_returns_from_decorated_function` attribute to be the decorated + function in the decorator definition. + """ + + _class_returns_from_decorated_function: ClassVar[callable | None] = None + + def __init_subclass__(cls, /, class_factory, class_factory_args, **kwargs): + super().__init_subclass__(**kwargs) + cls._class_factory = class_factory + cls._class_factory_args = class_factory_args + cls._factory_town = _FACTORY_TOWN + + def __reduce__(self): + if ( + self._class_returns_from_decorated_function is not None + and "<locals>" + not in self._class_returns_from_decorated_function.__qualname__ + ): + # When we create a class by decorating some other function, this class + # conflicts with its own factory_function attribute in the namespace, so we + # rely on directly re-importing the factory + return ( + _instantiate_from_decorated, + ( + self._class_returns_from_decorated_function.__module__, + self._class_returns_from_decorated_function.__qualname__, + self.__getnewargs_ex__(), + ), + self.__getstate__(), + ) + else: + return ( + _instantiate_from_factory, + ( + self._class_factory, + 
self._class_factory_args, + self.__getnewargs_ex__(), + ), + self.__getstate__(), + ) + + def __getnewargs_ex__(self): + # Child classes can override this as needed + return (), {} + + def __getstate__(self): + # Python <3.11 compatibility + try: + return super().__getstate__() + except AttributeError: + return dict(self.__dict__) + + def __setstate__(self, state): + # Python <3.11 compatibility + try: + super().__setstate__(state) + except AttributeError: + self.__dict__.update(**state) + + +def _instantiate_from_factory(factory, factory_args, newargs_ex): + """ + Recover the dynamic class, then invoke its `__new__` to avoid instantiation (and + the possibility of positional args in `__init__`). + """ + cls = factory(*factory_args) + return cls.__new__(cls, *newargs_ex[0], **newargs_ex[1]) + + +def _instantiate_from_decorated(module, qualname, newargs_ex): + """ + In case the class comes from a decorated function, we need to import it directly. + """ + cls = _import_object(module, qualname) + return cls.__new__(cls, *newargs_ex[0], **newargs_ex[1]) + + +def classfactory( + factory_function: callable[..., tuple[str, tuple[type, ...], dict, dict]] +) -> _ClassFactory: + """ + A decorator for building dynamic class factories whose classes are unique and whose + terminal instances can be pickled. + + Under the hood, classes created by factories get dependence on + :class:`_FactoryMade` mixed in. This class leverages :meth:`__reduce__` and + :meth:`__init_subclass__` and uses up the class namespace :attr:`_class_factory` + and :attr:`_class_factory_args` to hold data (using up corresponding public + variable names in the :meth:`__init_subclass__` kwargs), so any interference with + these fields may cause unexpected side effects. For un-pickling, the dynamic class + gets recreated then its :meth:`__new__` is called using `__newargs_ex__`; a default + implementation returning no arguments is provided on :class:`_FactoryMade` but can + be overridden. 
+ + Args: + factory_function (callable[..., tuple[str, tuple[type, ...], dict, dict]]): + A function returning arguments that would be passed to `builtins.type` to + dynamically generate a class. The function must accept exclusively + positional arguments + + Returns: + (type[_ClassFactory]): A new callable that returns unique classes whose + instances can be pickled. + + Notes: + If the :param:`factory_function` itself, or any data stored on instances of + its resulting class(es) cannot be pickled, then the instances will not be able + to be pickled. Here we only remove the trouble associated with pickling + dynamically created classes. + + If the `__init_subclass__` kwargs are exploited, remember that these are + subject to all the same "gotchas" as their regular non-factory use; namely, all + child classes must specify _all_ parent class kwargs in order to avoid them + getting overwritten by the parent class defaults! + + Dynamically generated classes can, in turn, be used as base classes for further + `@classfactory` decorated factory functions. + + Warnings: + Use _exclusively_ as a decorator. For an inline constructor for an existing + callable, use :class:`ClassFactory` instead. + + Examples: + >>> import pickle + >>> + >>> from pyiron_snippets.factory import classfactory + >>> + >>> class HasN(ABC): + ... '''Some class I want to make dynamically subclass.''' + ... def __init_subclass__(cls, /, n=0, s="foo", **kwargs): + ... super(HasN, cls).__init_subclass__(**kwargs) + ... cls.n = n + ... cls.s = s + ... + ... def __init__(self, x, y=0): + ... self.x = x + ... self.y = y + >>> + >>> @classfactory + ... def has_n_factory(n, s="wrapped_function", /): + ... return ( + ... f"{HasN.__name__}{n}{s}", # New class name + ... (HasN,), # Base class(es) + ... {}, # Class attributes dictionary + ... {"n": n, "s": s} + ... # dict of `builtins.type` kwargs (passed to `__init_subclass__`) + ... 
) + >>> + >>> Has2 = has_n_factory(2, "my_dynamic_class") + >>> HasToo = has_n_factory(2, "my_dynamic_class") + >>> HasToo is Has2 + True + + >>> foo = Has2(42, y=-1) + >>> print(foo.n, foo.s, foo.x, foo.y) + 2 my_dynamic_class 42 -1 + + >>> reloaded = pickle.loads(pickle.dumps(foo)) # doctest: +SKIP + >>> print(reloaded.n, reloaded.s, reloaded.x, reloaded.y) # doctest: +SKIP + 2 my_dynamic_class 42 -1 # doctest: +SKIP + + """ + factory = _FACTORY_TOWN.get_factory(factory_function) + factory._decorated_as_classfactory = True + return factory + + +class ClassFactory: + """ + A constructor for new class factories. + + Use on existing class factory callables, _not_ as a decorator. + + Cf. the :func:`classfactory` decorator for more info. + """ + + def __new__(cls, factory_function): + return _FACTORY_TOWN.get_factory(factory_function) + + +def sanitize_callable_name(name: str): + """ + A helper class for sanitizing a string so it's appropriate as a class/function name. + """ + # Replace non-alphanumeric characters except underscores + sanitized_name = sub(r"\W+", "_", name) + # Ensure the name starts with a letter or underscore + if ( + len(sanitized_name) > 0 + and not sanitized_name[0].isalpha() + and sanitized_name[0] != "_" + ): + sanitized_name = "_" + sanitized_name + return sanitized_name diff --git a/pyiron_snippets/files.py b/pyiron_snippets/files.py new file mode 100644 index 0000000..d65302f --- /dev/null +++ b/pyiron_snippets/files.py @@ -0,0 +1,216 @@ +from __future__ import annotations +from pathlib import Path +import shutil + + +def delete_files_and_directories_recursively(path): + if not path.exists(): + return + for item in path.rglob("*"): + if item.is_file(): + item.unlink() + else: + delete_files_and_directories_recursively(item) + path.rmdir() + + +def categorize_folder_items(folder_path): + types = [ + "dir", + "file", + "mount", + "symlink", + "block_device", + "char_device", + "fifo", + "socket", + ] + results = {t: [] for t in types} + + for 
item in folder_path.iterdir(): + for tt in types: + try: + if getattr(item, f"is_{tt}")(): + results[tt].append(str(item)) + except NotImplementedError: + pass + return results + + +def _resolve_directory_and_path( + file_name: str, + directory: DirectoryObject | str | None = None, + default_directory: str = ".", +): + """ + Internal routine to separate the file name and the directory in case + file name is given in absolute path etc. + """ + path = Path(file_name) + file_name = path.name + if path.is_absolute(): + if directory is not None: + raise ValueError( + "You cannot set `directory` when `file_name` is an absolute path" + ) + # If absolute path, take that of new_file_name regardless of the + # name of directory + directory = str(path.parent) + else: + if directory is None: + # If directory is not given, take default directory + directory = default_directory + else: + # If the directory is given, use it as the main path and append + # additional path if given in new_file_name + if isinstance(directory, DirectoryObject): + directory = directory.path + directory = directory / path.parent + if not isinstance(directory, DirectoryObject): + directory = DirectoryObject(directory) + return file_name, directory + + +class DirectoryObject: + def __init__(self, directory: str | Path | DirectoryObject): + if isinstance(directory, str): + self.path = Path(directory) + elif isinstance(directory, Path): + self.path = directory + elif isinstance(directory, DirectoryObject): + self.path = directory.path + self.create() + + def create(self): + self.path.mkdir(parents=True, exist_ok=True) + + def delete(self, only_if_empty: bool = False): + if self.is_empty() or not only_if_empty: + delete_files_and_directories_recursively(self.path) + + def list_content(self): + return categorize_folder_items(self.path) + + def __len__(self): + return sum([len(cc) for cc in self.list_content().values()]) + + def __repr__(self): + return 
f"DirectoryObject(directory='{self.path}')\n{self.list_content()}" + + def get_path(self, file_name): + return self.path / file_name + + def file_exists(self, file_name): + return self.get_path(file_name).is_file() + + def write(self, file_name, content, mode="w"): + with self.get_path(file_name).open(mode=mode) as f: + f.write(content) + + def create_subdirectory(self, path): + return DirectoryObject(self.path / path) + + def create_file(self, file_name): + return FileObject(file_name, self) + + def is_empty(self) -> bool: + return len(self) == 0 + + def remove_files(self, *files: str): + for file in files: + path = self.get_path(file) + if path.is_file(): + path.unlink() + + +class NoDestinationError(ValueError): + """A custom error for when neither a new file name nor new location are provided""" + + +class FileObject: + def __init__(self, file_name: str, directory: DirectoryObject = None): + self._file_name, self.directory = _resolve_directory_and_path( + file_name=file_name, directory=directory, default_directory="." + ) + + @property + def file_name(self): + return self._file_name + + @property + def path(self): + return self.directory.path / Path(self._file_name) + + def write(self, content, mode="x"): + self.directory.write(file_name=self.file_name, content=content, mode=mode) + + def read(self, mode="r"): + with open(self.path, mode=mode) as f: + return f.read() + + def is_file(self): + return self.directory.file_exists(self.file_name) + + def delete(self): + self.path.unlink() + + def __str__(self): + return str(self.path.absolute()) + + def _resolve_directory_and_path( + self, + file_name: str, + directory: DirectoryObject | str | None = None, + default_directory: str = ".", + ): + """ + Internal routine to separate the file name and the directory in case + file name is given in absolute path etc. 
+ """ + path = Path(file_name) + file_name = path.name + if path.is_absolute(): + # If absolute path, take that of new_file_name regardless of the + # name of directory + directory = str(path.parent) + else: + if directory is None: + # If directory is not given, take default directory + directory = default_directory + else: + # If the directory is given, use it as the main path and append + # additional path if given in new_file_name + if isinstance(directory, DirectoryObject): + directory = directory.path + directory = directory / path.parent + if not isinstance(directory, DirectoryObject): + directory = DirectoryObject(directory) + return file_name, directory + + def copy( + self, + new_file_name: str | None = None, + directory: DirectoryObject | str | None = None, + ): + """ + Copy an existing file to a new location. + Args: + new_file_name (str): New file name. You can also set + an absolute path (in which case `directory` will be ignored) + directory (DirectoryObject): Directory. If None, the same + directory is used + Returns: + (FileObject): file object of the new file + """ + if new_file_name is None: + if directory is None: + raise NoDestinationError( + "Either new file name or directory must be specified" + ) + new_file_name = self.file_name + file_name, directory = self._resolve_directory_and_path( + new_file_name, directory, default_directory=self.directory.path + ) + new_file = FileObject(file_name, directory.path) + shutil.copy(str(self.path), str(new_file.path)) + return new_file diff --git a/pyiron_snippets/has_post.py b/pyiron_snippets/has_post.py new file mode 100644 index 0000000..948263b --- /dev/null +++ b/pyiron_snippets/has_post.py @@ -0,0 +1,22 @@ +from abc import ABCMeta + + +class HasPost(type): + """ + A metaclass for adding a `__post__` method which has a compatible signature with + `__init__` (and indeed receives all its input), but is guaranteed to be called + only _after_ `__init__` is totally finished. 
+ + Based on @jsbueno's reply in [this discussion](https://discuss.python.org/t/add-a-post-method-equivalent-to-the-new-method-but-called-after-init/5449/11) + """ + + def __call__(cls, *args, **kwargs): + instance = super().__call__(*args, **kwargs) + if post := getattr(cls, "__post__", False): + post(instance, *args, **kwargs) + return instance + + +class AbstractHasPost(HasPost, ABCMeta): + # Just for resolving metaclass conflic for ABC classes that have post + pass diff --git a/pyiron_snippets/logger.py b/pyiron_snippets/logger.py new file mode 100644 index 0000000..b80d480 --- /dev/null +++ b/pyiron_snippets/logger.py @@ -0,0 +1,73 @@ +# coding: utf-8 +# Copyright (c) Max-Planck-Institut für Eisenforschung GmbH - Computational Materials Design (CM) Department +# Distributed under the terms of "New BSD License", see the LICENSE file. + +import logging +from types import MethodType + +__author__ = "Joerg Neugebauer" +__copyright__ = ( + "Copyright 2021, Max-Planck-Institut für Eisenforschung GmbH - " + "Computational Materials Design (CM) Department" +) +__version__ = "1.0" +__maintainer__ = "Jan Janssen" +__email__ = "janssen@mpie.de" +__status__ = "production" +__date__ = "Sep 1, 2017" + + +""" +Set the logging level for pyiron +""" + + +def set_logging_level(self, level, channel=None): + """ + Set level for logger + + Args: + level (str): 'DEBUG, INFO, WARN' + channel (int): 0: file_log, 1: stream, None: both + """ + + if channel: + self.handlers[channel].setLevel(level) + else: + self.handlers[0].setLevel(level) + self.handlers[1].setLevel(level) + + +def setup_logger(): + """ + Setup logger - logs are written to pyiron.log + + Returns: + logging.getLogger: Logger + """ + logger = logging.getLogger("pyiron_log") + logger.setLevel(logging.DEBUG) + + formatter = logging.Formatter( + "%(asctime)s - %(name)s - %(levelname)s - %(message)s" + ) + + ch = logging.StreamHandler() + ch.setLevel(logging.WARN) + ch.setFormatter(formatter) + logger.addHandler(ch) + + 
try: + fh = logging.FileHandler("pyiron.log") + except PermissionError: + pass + else: + fh.setLevel(logging.INFO) + fh.setFormatter(formatter) + logger.addHandler(fh) + + return logger + + +logger = setup_logger() +logger.set_logging_level = MethodType(set_logging_level, logger) diff --git a/pyiron_snippets/singleton.py b/pyiron_snippets/singleton.py new file mode 100644 index 0000000..5b8b85d --- /dev/null +++ b/pyiron_snippets/singleton.py @@ -0,0 +1,36 @@ +# coding: utf-8 +# Copyright (c) Max-Planck-Institut für Eisenforschung GmbH - Computational Materials Design (CM) Department +# Distributed under the terms of "New BSD License", see the LICENSE file. + +""" +Utility functions used in pyiron. +In order to be accessible from anywhere in pyiron, they *must* remain free of any imports from pyiron! +""" +from abc import ABCMeta + +__author__ = "Joerg Neugebauer, Jan Janssen" +__copyright__ = ( + "Copyright 2020, Max-Planck-Institut für Eisenforschung GmbH - " + "Computational Materials Design (CM) Department" +) +__version__ = "1.0" +__maintainer__ = "Jan Janssen" +__email__ = "janssen@mpie.de" +__status__ = "production" +__date__ = "Sep 1, 2017" + + +class Singleton(ABCMeta): + """ + Implemented with suggestions from + + http://stackoverflow.com/questions/6760685/creating-a-singleton-in-python + + """ + + _instances = {} + + def __call__(cls, *args, **kwargs): + if cls not in cls._instances: + cls._instances[cls] = super(Singleton, cls).__call__(*args, **kwargs) + return cls._instances[cls] diff --git a/pyproject.toml b/pyproject.toml index 6e964f3..51f029b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,5 @@ [build-system] requires = [ - "pyiron_base", "setuptools", "versioneer[toml]==0.29", ] @@ -13,7 +12,7 @@ readme = "docs/README.md" keywords = [ "pyiron",] requires-python = ">=3.9, <3.13" classifiers = [ - "Development Status :: 3 - Alpha", + "Development Status :: 4 - Beta", "Topic :: Scientific/Engineering", "License :: OSI Approved :: BSD 
License", "Intended Audience :: Science/Research", @@ -24,7 +23,6 @@ classifiers = [ "Programming Language :: Python :: 3.12", ] dependencies = [ - "pyiron_base==0.8.3" ] dynamic = [ "version",] authors = [ @@ -52,3 +50,8 @@ include = [ "pyiron_snippets*",] [tool.setuptools.dynamic.version] attr = "pyiron_snippets.__version__" + +[project.optional-dependencies] +tests = [ + "cloudpickle==3.0.0", +] \ No newline at end of file diff --git a/tests/benchmark/__init__.py b/tests/benchmark/__init__.py deleted file mode 100644 index 102f047..0000000 --- a/tests/benchmark/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -""" -Timed tests to make sure critical components stay sufficiently efficient. -""" \ No newline at end of file diff --git a/tests/benchmark/test_benchmark.py b/tests/benchmark/test_benchmark.py deleted file mode 100644 index 9c25ae2..0000000 --- a/tests/benchmark/test_benchmark.py +++ /dev/null @@ -1,6 +0,0 @@ -import unittest - - -class TestNothing(unittest.TestCase): - def test_nothing(self): - self.assertTrue(True) diff --git a/tests/integration/test_integration.py b/tests/integration/test_integration.py deleted file mode 100644 index 9c25ae2..0000000 --- a/tests/integration/test_integration.py +++ /dev/null @@ -1,6 +0,0 @@ -import unittest - - -class TestNothing(unittest.TestCase): - def test_nothing(self): - self.assertTrue(True) diff --git a/tests/unit/test_dotdict.py b/tests/unit/test_dotdict.py new file mode 100644 index 0000000..ccb17c7 --- /dev/null +++ b/tests/unit/test_dotdict.py @@ -0,0 +1,28 @@ +import unittest + +from pyiron_snippets.dotdict import DotDict + + +class TestDotDict(unittest.TestCase): + def test_dot_dict(self): + dd = DotDict({'foo': 42}) + + self.assertEqual(dd['foo'], dd.foo, msg="Dot access should be equivalent.") + dd.bar = "towel" + self.assertEqual("towel", dd["bar"], msg="Dot assignment should be equivalent.") + + self.assertListEqual(dd.to_list(), [42, "towel"]) + + with self.assertRaises( + KeyError, msg="Failed item access 
should raise key error" + ): + dd["missing"] + + with self.assertRaises( + AttributeError, msg="Failed attribute access should raise attribute error" + ): + dd.missing + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/unit/test_factory.py b/tests/unit/test_factory.py new file mode 100644 index 0000000..8405c95 --- /dev/null +++ b/tests/unit/test_factory.py @@ -0,0 +1,546 @@ +from __future__ import annotations + +from abc import ABC +import pickle +from typing import ClassVar +import unittest + +import cloudpickle + +from pyiron_snippets.factory import ( + _ClassFactory, + _FactoryMade, + ClassFactory, + classfactory, + InvalidClassNameError, + InvalidFactorySignature, + sanitize_callable_name +) + + +class HasN(ABC): + def __init_subclass__(cls, /, n=0, s="foo", **kwargs): + super().__init_subclass__(**kwargs) + cls.n = n + cls.s = s + + def __init__(self, x, *args, y=0, **kwargs): + super().__init__(*args, **kwargs) + self.x = x + self.y = y + + +@classfactory +def has_n_factory(n, s="wrapped_function", /): + return ( + f"{HasN.__name__}{n}{s}", + (HasN,), + {}, + {"n": n, "s": s} + ) + + +def undecorated_function(n, s="undecorated_function", /): + return ( + f"{HasN.__name__}{n}{s}", + (HasN,), + {}, + {"n": n, "s": s} + ) + + +def takes_kwargs(n, /, s="undecorated_function"): + return ( + f"{HasN.__name__}{n}{s}", + (HasN,), + {}, + {"n": n, "s": s} + ) + + +class FactoryOwner: + @staticmethod + @classfactory + def has_n_factory(n, s="decorated_method", /): + return ( + f"{HasN.__name__}{n}{s}", + (HasN,), + {}, + {"n": n, "s": s} + ) + + +Has2 = has_n_factory(2, "factory_made") # For testing repeated inheritance + + +class HasM(ABC): + def __init_subclass__(cls, /, m=0, **kwargs): + super(HasM, cls).__init_subclass__(**kwargs) + cls.m = m + + def __init__(self, z, *args, **kwargs): + super().__init__(*args, **kwargs) + self.z = z + + +@classfactory +def has_n2_m_factory(m, /): + return ( + f"HasN2M{m}", + (Has2, HasM), + {}, + {"m": m, "n": 
Has2.n, "s": Has2.s} + ) + + +@classfactory +def has_m_n2_factory(m, /): + return ( + f"HasM{m}N2", + (HasM, Has2,), + {}, + {"m": m} + ) + + +class AddsNandX(ABC): + fnc: ClassVar[callable] + n: ClassVar[int] + + def __init__(self, x): + self.x = x + + def add_to_function(self, *args, **kwargs): + return self.fnc(*args, **kwargs) + self.n + self.x + + +@classfactory +def adder_factory(fnc, n, /): + return ( + f"{AddsNandX.__name__}{fnc.__name__}", + (AddsNandX,), + { + "fnc": staticmethod(fnc), + "n": n, + "_class_returns_from_decorated_function": fnc + }, + {}, + ) + + +def add_to_this_decorator(n): + def wrapped(fnc): + factory_made = adder_factory(fnc, n) + factory_made._class_returns_from_decorated_function = fnc + return factory_made + return wrapped + + +@add_to_this_decorator(5) +def adds_5_plus_x(y: int): + return y + + +class TestClassfactory(unittest.TestCase): + + def test_factory_initialization(self): + self.assertTrue( + issubclass(has_n_factory.__class__, _ClassFactory), + msg="Creation by decorator should yield a subclass" + ) + self.assertTrue( + issubclass(ClassFactory(undecorated_function).__class__, _ClassFactory), + msg="Creation by public instantiator should yield a subclass" + ) + + factory = has_n_factory(2, "foo") + self.assertTrue( + issubclass(factory, HasN), + msg=f"Resulting class should inherit from the base" + ) + self.assertEqual(2, factory.n, msg="Factory args should get interpreted") + self.assertEqual("foo", factory.s, msg="Factory kwargs should get interpreted") + + def test_factory_uniqueness(self): + f1 = classfactory(undecorated_function) + f2 = classfactory(undecorated_function) + + self.assertIs( + f1, + f2, + msg="Repeatedly packaging the same function should give the exact same " + "factory" + ) + self.assertIsNot( + f1, + has_n_factory, + msg="Factory degeneracy is based on the actual wrapped function, we don't " + "do any parsing for identical behaviour inside those functions." 
+ ) + + def test_factory_pickle(self): + with self.subTest("By decoration"): + reloaded = pickle.loads(pickle.dumps(has_n_factory)) + self.assertIs(has_n_factory, reloaded) + + with self.subTest("From instantiation"): + my_factory = ClassFactory(undecorated_function) + reloaded = pickle.loads(pickle.dumps(my_factory)) + self.assertIs(my_factory, reloaded) + + with self.subTest("From qualname by decoration"): + my_factory = FactoryOwner().has_n_factory + reloaded = pickle.loads(pickle.dumps(my_factory)) + self.assertIs(my_factory, reloaded) + + def test_class_creation(self): + n2 = has_n_factory(2, "something") + self.assertEqual( + 2, + n2.n, + msg="Factory args should be getting parsed" + ) + self.assertEqual( + "something", + n2.s, + msg="Factory kwargs should be getting parsed" + ) + self.assertTrue( + issubclass(n2, HasN), + msg="" + ) + self.assertTrue( + issubclass(n2, HasN), + msg="Resulting classes should inherit from the requested base(s)" + ) + + with self.assertRaises( + InvalidClassNameError, + msg="Invalid class names should raise an error" + ): + has_n_factory( + 2, + "our factory function uses this as part of the class name, but spaces" + "are not allowed!" 
+ ) + + def test_class_uniqueness(self): + n2 = has_n_factory(2) + + self.assertIs( + n2, + has_n_factory(2), + msg="Repeatedly creating the same class should give the exact same class" + ) + self.assertIsNot( + n2, + has_n_factory(2, "something_else"), + msg="Sanity check" + ) + + def test_bad_factory_function(self): + with self.assertRaises( + InvalidFactorySignature, + msg="For compliance with __reduce__, we can only use factory functions " + "that strictly take positional arguments" + ): + ClassFactory(takes_kwargs) + + def test_instance_creation(self): + foo = has_n_factory(2, "used")(42, y=43) + self.assertEqual( + 2, foo.n, msg="Class attributes should be inherited" + ) + self.assertEqual( + "used", foo.s, msg="Class attributes should be inherited" + ) + self.assertEqual( + 42, foo.x, msg="Initialized args should be captured" + ) + self.assertEqual( + 43, foo.y, msg="Initialized kwargs should be captured" + ) + self.assertIsInstance( + foo, + HasN, + msg="Instances should inherit from the requested base(s)" + ) + self.assertIsInstance( + foo, + _FactoryMade, + msg="Instances should get :class:`_FactoryMade` mixed in." + ) + + def test_instance_pickle(self): + foo = has_n_factory(2, "used")(42, y=43) + reloaded = pickle.loads(pickle.dumps(foo)) + self.assertEqual( + foo.n, reloaded.n, msg="Class attributes should be reloaded" + ) + self.assertEqual( + foo.s, reloaded.s, msg="Class attributes should be reloaded" + ) + self.assertEqual( + foo.x, reloaded.x, msg="Initialized args should be reloaded" + ) + self.assertEqual( + foo.y, reloaded.y, msg="Initialized kwargs should be reloaded" + ) + self.assertIsInstance( + reloaded, + HasN, + msg="Instances should inherit from the requested base(s)" + ) + self.assertIsInstance( + reloaded, + _FactoryMade, + msg="Instances should get :class:`_FactoryMade` mixed in." 
+ ) + + def test_decorated_method(self): + msg = "It should be possible to have class factories as methods on a class" + foo = FactoryOwner().has_n_factory(2)(42, y=43) + reloaded = pickle.loads(pickle.dumps(foo)) + self.assertEqual(foo.n, reloaded.n, msg=msg) + self.assertEqual(foo.s, reloaded.s, msg=msg) + self.assertEqual(foo.x, reloaded.x, msg=msg) + self.assertEqual(foo.y, reloaded.y, msg=msg) + + def test_factory_inside_a_function(self): + @classfactory + def internal_factory(n, s="unimportable_scope", /): + return ( + f"{HasN.__name__}{n}{s}", + (HasN,), + {}, + {"n": n, "s": s} + ) + + foo = internal_factory(2)(1, y=0) + self.assertEqual(2, foo.n, msg="Nothing should stop the factory from working") + self.assertEqual( + "unimportable_scope", + foo.s, + msg="Nothing should stop the factory from working" + ) + self.assertEqual(1, foo.x, msg="Nothing should stop the factory from working") + self.assertEqual(0, foo.y, msg="Nothing should stop the factory from working") + with self.assertRaises( + AttributeError, + msg="`internal_factory` is defined only locally inside the scope of " + "another function, so we don't expect it to be pickleable whether it's " + "a class factory or not!" + ): + pickle.dumps(foo) + + reloaded = cloudpickle.loads(cloudpickle.dumps(foo)) + self.assertTupleEqual( + (foo.n, foo.s, foo.x, foo.y), + (reloaded.n, reloaded.s, reloaded.x, reloaded.y), + msg="Cloudpickle is powerful enough to overcome this limitation." 
+ ) + + # And again with a factory from the instance constructor + def internally_undecorated(n, s="undecorated_unimportable", /): + return ( + f"{HasN.__name__}{n}{s}", + (HasN,), + {}, + {"n": n, "s": s} + ) + factory_instance = ClassFactory(internally_undecorated) + bar = factory_instance(2)(1, y=0) + self.assertTupleEqual( + (2, "undecorated_unimportable", 1, 0), + (bar.n, bar.s, bar.x, bar.y), + msg="Sanity check" + ) + + with self.assertRaises( + AttributeError, + msg="The relevant factory function is only in " + ): + pickle.dumps(bar) + + reloaded = cloudpickle.loads(cloudpickle.dumps(bar)) + self.assertTupleEqual( + (bar.n, bar.s, bar.x, bar.y), + (reloaded.n, reloaded.s, reloaded.x, reloaded.y), + msg="Cloudpickle is powerful enough to overcome this limitation." + ) + + + def test_repeated_inheritance(self): + n2m3 = has_n2_m_factory(3)(5, 6) + m3n2 = has_m_n2_factory(3)(5, 6) + + self.assertListEqual( + [3, 2, "factory_made"], + [n2m3.m, n2m3.n, n2m3.s], + msg="Sanity check on class property inheritance" + ) + self.assertListEqual( + [3, 0, "foo"], # n and s defaults from HasN! + [m3n2.m, m3n2.n, m3n2.s], + msg="When exploiting __init_subclass__, each subclass must take care to " + "specify _all_ parent class __init_subclass__ kwargs, or they will " + "revert to the default behaviour. This is totally normal python " + "behaviour, and here we just verify that we're vulnerable to the same " + "'gotcha' as the rest of the language." 
+ ) + self.assertListEqual( + [5, 6], + [n2m3.x, n2m3.z], + msg="Sanity check on instance inheritance" + ) + self.assertListEqual( + [m3n2.z, m3n2.x], + [n2m3.x, n2m3.z], + msg="Inheritance order should impact arg order, also completely as usual " + "for python classes" + ) + reloaded_nm = pickle.loads(pickle.dumps(n2m3)) + self.assertListEqual( + [n2m3.m, n2m3.n, n2m3.s, n2m3.z, n2m3.x, n2m3.y], + [ + reloaded_nm.m, + reloaded_nm.n, + reloaded_nm.s, + reloaded_nm.z, + reloaded_nm.x, + reloaded_nm.y + ], + msg="Pickling behaviour should not care that one of the parents was itself " + "a factory made class." + ) + + reloaded_mn = pickle.loads(pickle.dumps(m3n2)) + self.assertListEqual( + [m3n2.m, m3n2.n, m3n2.s, m3n2.z, m3n2.x, m3n2.y], + [ + reloaded_mn.m, + reloaded_mn.n, + reloaded_mn.s, + reloaded_mn.z, + reloaded_mn.x, + reloaded_nm.y + ], + msg="Pickling behaviour should not care about the order of bases." + ) + + def test_clearing_town(self): + + self.assertGreater(len(Has2._factory_town.factories), 0, msg="Sanity check") + + Has2._factory_town.clear() + self.assertEqual( + len(Has2._factory_town.factories), + 0, + msg="Town should get cleared" + ) + + ClassFactory(undecorated_function) + self.assertEqual( + len(Has2._factory_town.factories), + 1, + msg="Has2 exists in memory and the factory town has forgotten about it, " + "but it still knows about the factory town and can see the newly " + "created one." + ) + + def test_clearing_class_register(self): + self.assertGreater( + len(has_n_factory.class_registry), + 0, + msg="Sanity. We expect to have created at least one class up in the header." 
+ ) + has_n_factory.clear() + self.assertEqual( + len(has_n_factory.class_registry), + 0, + msg="Clear should remove all instances" + ) + n_new = 3 + for i in range(n_new): + has_n_factory(i) + self.assertEqual( + len(has_n_factory.class_registry), + n_new, + msg="Should see the new constructed classes" + ) + + def test_other_decorators(self): + """ + In case the factory-produced class itself comes from a decorator, we need to + check that name conflicts between the class and decorated function are handled. + """ + a5 = adds_5_plus_x(2) + self.assertIsInstance(a5, AddsNandX) + self.assertIsInstance(a5, _FactoryMade) + self.assertEqual(5, a5.n) + self.assertEqual(2, a5.x) + self.assertEqual( + 1 + 5 + 2, # y + n=5 + x=2 + a5.add_to_function(1), + msg="Should execute the function as part of call" + ) + + reloaded = pickle.loads(pickle.dumps(a5)) + self.assertEqual(a5.n, reloaded.n) + self.assertIs(a5.fnc, reloaded.fnc) + self.assertEqual(a5.x, reloaded.x) + + def test_other_decorators_inside_locals(self): + @add_to_this_decorator(6) + def adds_6_plus_x(y: int): + return y + + a6 = adds_6_plus_x(42) + self.assertEqual( + 1 + 42 + 6, + a6.add_to_function(1), + msg="Nothing stops us from creating and running these" + ) + with self.assertRaises( + AttributeError, + msg="We can't find the function defined to import and recreate" + "the factory" + ): + pickle.dumps(a6) + + reloaded = cloudpickle.loads(cloudpickle.dumps(a6)) + self.assertTupleEqual( + (a6.n, a6.x), + (reloaded.n, reloaded.x), + msg="Cloudpickle is powerful enough to overcome this limitation." 
+ ) + + +class TestSanitization(unittest.TestCase): + + def test_simple_string(self): + self.assertEqual(sanitize_callable_name("SimpleString"), "SimpleString") + + def test_string_with_spaces(self): + self.assertEqual( + sanitize_callable_name("String with spaces"), "String_with_spaces" + ) + + def test_string_with_special_characters(self): + self.assertEqual(sanitize_callable_name("a!@#$%b^&*()c"), "a_b_c") + + def test_string_with_numbers_at_start(self): + self.assertEqual(sanitize_callable_name("123Class"), "_123Class") + + def test_empty_string(self): + self.assertEqual(sanitize_callable_name(""), "") + + def test_string_with_only_special_characters(self): + self.assertEqual(sanitize_callable_name("!@#$%"), "_") + + def test_string_with_only_numbers(self): + self.assertEqual(sanitize_callable_name("123456"), "_123456") + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/unit/test_files.py b/tests/unit/test_files.py new file mode 100644 index 0000000..16548b1 --- /dev/null +++ b/tests/unit/test_files.py @@ -0,0 +1,154 @@ +import unittest +from pyiron_snippets.files import DirectoryObject, FileObject +from pathlib import Path +import platform + + +class TestFiles(unittest.TestCase): + def setUp(self): + self.directory = DirectoryObject("test") + + def tearDown(self): + self.directory.delete() + + def test_directory_instantiation(self): + directory = DirectoryObject(Path("test")) + self.assertEqual(directory.path, self.directory.path) + directory = DirectoryObject(self.directory) + self.assertEqual(directory.path, self.directory.path) + + def test_file_instantiation(self): + self.assertEqual( + FileObject("test.txt", self.directory).path, + FileObject("test.txt", "test").path, + msg="DirectoryObject and str must give the same object" + ) + self.assertEqual( + FileObject("test/test.txt").path, + FileObject("test.txt", "test").path, + msg="File path not same as directory path" + ) + + if platform.system() == "Windows": + 
self.assertRaises(ValueError, FileObject, "C:\\test.txt", "test") + else: + self.assertRaises(ValueError, FileObject, "/test.txt", "test") + + + def test_directory_exists(self): + self.assertTrue(Path("test").exists() and Path("test").is_dir()) + + def test_write(self): + self.directory.write(file_name="test.txt", content="something") + self.assertTrue(self.directory.file_exists("test.txt")) + self.assertTrue( + "test/test.txt" in [ + ff.replace("\\", "/") + for ff in self.directory.list_content()['file'] + ] + ) + self.assertEqual(len(self.directory), 1) + + def test_create_subdirectory(self): + self.directory.create_subdirectory("another_test") + self.assertTrue(Path("test/another_test").exists()) + + def test_path(self): + f = FileObject("test.txt", self.directory) + self.assertEqual(str(f.path).replace("\\", "/"), "test/test.txt") + + def test_read_and_write(self): + f = FileObject("test.txt", self.directory) + f.write("something") + self.assertEqual(f.read(), "something") + + def test_is_file(self): + f = FileObject("test.txt", self.directory) + self.assertFalse(f.is_file()) + f.write("something") + self.assertTrue(f.is_file()) + f.delete() + self.assertFalse(f.is_file()) + + def test_is_empty(self): + self.assertTrue(self.directory.is_empty()) + self.directory.write(file_name="test.txt", content="something") + self.assertFalse(self.directory.is_empty()) + + def test_delete(self): + self.assertTrue( + Path("test").exists() and Path("test").is_dir(), + msg="Sanity check on initial state" + ) + self.directory.write(file_name="test.txt", content="something") + self.directory.delete(only_if_empty=True) + self.assertFalse( + self.directory.is_empty(), + msg="Flag argument on delete should have prevented removal" + ) + self.directory.delete() + self.assertFalse( + Path("test").exists(), + msg="Delete should remove the entire directory" + ) + self.directory = DirectoryObject("test") # Rebuild it so the tearDown works + + def test_remove(self): + 
self.directory.write(file_name="test1.txt", content="something") + self.directory.write(file_name="test2.txt", content="something") + self.directory.write(file_name="test3.txt", content="something") + self.assertEqual( + 3, + len(self.directory), + msg="Sanity check on initial state" + ) + self.directory.remove_files("test1.txt", "test2.txt") + self.assertEqual( + 1, + len(self.directory), + msg="Should be able to remove multiple files at once", + ) + self.directory.remove_files("not even there", "nor this") + self.assertEqual( + 1, + len(self.directory), + msg="Removing non-existent things should have no effect", + ) + self.directory.remove_files("test3.txt") + self.assertEqual( + 0, + len(self.directory), + msg="Should be able to remove just one file", + ) + + def test_copy(self): + f = FileObject("test_copy.txt", self.directory) + f.write("sam wrote this wondrful thing") + new_file_1 = f.copy("another_test") + self.assertEqual(new_file_1.read(), "sam wrote this wondrful thing") + new_file_2 = f.copy("another_test", ".") + with open("another_test", "r") as file: + txt = file.read() + self.assertEqual(txt, "sam wrote this wondrful thing") + new_file_2.delete() # needed because current directory + new_file_3 = f.copy(str(f.path.parent / "another_test"), ".") + self.assertEqual(new_file_1.path.absolute(), new_file_3.path.absolute()) + new_file_4 = f.copy(directory=".") + with open("test_copy.txt", "r") as file: + txt = file.read() + self.assertEqual(txt, "sam wrote this wondrful thing") + new_file_4.delete() # needed because current directory + with self.assertRaises(ValueError): + f.copy() + + + def test_str(self): + f = FileObject("test_copy.txt", self.directory) + if platform.system() == "Windows": + txt = f"my file: {self.directory.path.absolute()}\\test_copy.txt" + else: + txt = f"my file: {self.directory.path.absolute()}/test_copy.txt" + self.assertEqual(f"my file: {f}", txt) + +if __name__ == '__main__': + unittest.main() diff --git 
a/tests/unit/test_has_post.py b/tests/unit/test_has_post.py new file mode 100644 index 0000000..8aa5e55 --- /dev/null +++ b/tests/unit/test_has_post.py @@ -0,0 +1,41 @@ +import unittest + +import pyiron_snippets.has_post + + +class TestHasPost(unittest.TestCase): + def test_has_post_metaclass(self): + class Foo(metaclass=pyiron_snippets.has_post.HasPost): + def __init__(self, x=0): + self.x = x + self.y = x + self.z = x + self.x += 1 + + @property + def data(self): + return self.x, self.y, self.z + + class Bar(Foo): + def __init__(self, x=0, extra=1): + super().__init__(x) + + def __post__(self, *args, extra=1, **kwargs): + self.z = self.x + extra + + self.assertTupleEqual( + (1, 0, 0), + Foo().data, + msg="It should be fine to have this metaclass but not define post" + ) + + self.assertTupleEqual( + (1, 0, 2), + Bar().data, + msg="Metaclass should be inherited, able to use input, and happen _after_ " + "__init__" + ) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/unit/test_logger.py b/tests/unit/test_logger.py new file mode 100644 index 0000000..8f6864e --- /dev/null +++ b/tests/unit/test_logger.py @@ -0,0 +1,40 @@ +import unittest +from pyiron_snippets.logger import logger +import os +import shutil + + +class TestLogger(unittest.TestCase): + + @classmethod + def setUpClass(cls) -> None: + cls.logger_file = os.path.join(os.getcwd(), 'pyiron.log') + cls.backup_file = os.path.join(os.getcwd(), 'pyiron.log.test_logger_backup') + shutil.copy(cls.logger_file, cls.backup_file) + + @classmethod + def tearDownClass(cls) -> None: + shutil.move(cls.backup_file, cls.logger_file) + + def test_logger(self): + logsize = os.path.getsize(self.logger_file) + logger.warning("Here is a warning") + self.assertGreater(os.path.getsize(self.logger_file), logsize) + + def test_set_logging_level(self): + logger.set_logging_level(10) + self.assertEqual(10, logger.getEffectiveLevel(), "Overall logger level should match input") + self.assertEqual(10, 
logger.handlers[0].level, "Stream level should match input") + self.assertEqual(10, logger.handlers[0].level, "File level should match input") + + logger.set_logging_level(20, channel=1) + self.assertEqual(10, logger.getEffectiveLevel(), "Overall logger level should not have changed") + self.assertEqual(10, logger.handlers[0].level, "Stream level should not have changed") + self.assertEqual(20, logger.handlers[1].level, "File level should match input") + + logger.set_logging_level("WARNING", channel=0) + self.assertEqual(30, logger.handlers[0].level, "Should be able to set by string") + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/unit/test_singleton.py b/tests/unit/test_singleton.py new file mode 100644 index 0000000..b9e5bcb --- /dev/null +++ b/tests/unit/test_singleton.py @@ -0,0 +1,19 @@ +import unittest +from pyiron_snippets.singleton import Singleton + + +class TestSingleton(unittest.TestCase): + def test_uniqueness(self): + class Foo(metaclass=Singleton): + def __init__(self): + self.x = 1 + + f1 = Foo() + f2 = Foo() + self.assertIs(f1, f2) + f2.x = 2 + self.assertEqual(2, f1.x) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/unit/test_tests.py b/tests/unit/test_tests.py deleted file mode 100644 index 71876c7..0000000 --- a/tests/unit/test_tests.py +++ /dev/null @@ -1,9 +0,0 @@ -import unittest -import pyiron_snippets - - -class TestVersion(unittest.TestCase): - def test_version(self): - version = pyiron_snippets.__version__ - print(version) - self.assertTrue(version.startswith('0'))