diff --git a/src/pyella/either.py b/src/pyella/either.py index 142a62a..4ca5eb9 100644 --- a/src/pyella/either.py +++ b/src/pyella/either.py @@ -130,7 +130,7 @@ def pure( value: TB_co, # type: ignore [misc] # covariant arg ok, b/c function is pure ) -> Either[TA_co, TB_co]: # pylint: disable=invalid-name """ - Alias for :py:func:`pure(self) ` + Alias for :py:func:`pure(value) ` """ return pure(value) diff --git a/src/pyella/reader.py b/src/pyella/reader.py new file mode 100644 index 0000000..f36415c --- /dev/null +++ b/src/pyella/reader.py @@ -0,0 +1,160 @@ +# Copyright (c) Ely Deckers. +# +# This source code is licensed under the MPL-2.0 license found in the +# LICENSE file in the root directory of this source tree. + +""" +Reader - Contains the :py:class:`Reader[TE, TA] ` type and related +functions. It was strongly inspired by the Haskell ``Control.Monad.Reader`` module. + +More information on the Haskell ``Control.Monad.Reader`` type can be found here: +https://hackage.haskell.org/package/mtl-2.3.1/docs/Control-Monad-Reader.html +""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Callable, Generic, TypeVar + +from pyella.shared import _const + +TE_co = TypeVar("TE_co", covariant=True) # pylint: disable=invalid-name +TA_co = TypeVar("TA_co", covariant=True) # pylint: disable=invalid-name +TC_co = TypeVar("TC_co", covariant=True) # pylint: disable=invalid-name + + +@dataclass(frozen=True) +class Reader(Generic[TE_co, TA_co]): # pylint: disable=too-few-public-methods + """ + Represents a computation, which can read values from a shared environment, pass values from function to function, + and execute sub-computations in a modified environment. + """ + + run_reader: Callable[[TE_co], TA_co] + + def bind( + self, apply: Callable[[TA_co], Reader[TE_co, TC_co]] + ) -> Reader[TE_co, TC_co]: + """ + Alias for :py:func:`bind(self, apply) ` + """ + return bind(self, apply) + + def chain(self, em1: Reader[TE_co, TC_co]) -> Reader[TE_co, TC_co]: + """ + Alias for :py:func:`chain(self, em1) ` + """ + return chain(self, em1) + + def discard( + self, apply: Callable[[TA_co], Reader[TE_co, TA_co]] + ) -> Reader[TE_co, TA_co]: + """ + Alias for :py:func:`discard(self, apply) ` + """ + return discard(self, apply) + + def fmap(self, apply: Callable[[TA_co], TC_co]) -> Reader[TE_co, TC_co]: + """ + Alias for :py:func:`fmap(self, apply) ` + """ + return fmap(self, apply) + + @staticmethod + def pure( + run_reader: Callable[[TE_co], TA_co], # type: ignore [misc] # covariant arg ok, b/c function is pure + ) -> Reader[TE_co, TA_co]: # pylint: disable=invalid-name + """ + Alias for :py:func:`pure(run_reader) ` + """ + return pure(run_reader) + + def __eq__(self, __o: object) -> bool: + if not isinstance(__o, Reader): + return False + + return __o.run_reader == self.run_reader + + +def ask() -> Reader[TE_co, TE_co]: # type: ignore [misc] # covariant arg ok, b/c function is pure + """ + Retrieves the monad environment. + + .. note:: Haskell: `ask `__ + """ + return Reader(lambda v: v) + + +def bind( + em0: Reader[TE_co, TA_co], apply: Callable[[TA_co], Reader[TE_co, TC_co]] +) -> Reader[TE_co, TC_co]: + """ + Passes the inherited environment to both subcomputations, first the original `Reader [TE, TA]` and + then the result of the function `TA -> Reader [TE, TC]` + + .. note:: Haskell: `>>= `__ + """ + + def _bind(environment: TE_co) -> TC_co: # type: ignore [misc] # covariant arg ok, b/c function is pure + environment_1 = em0.run_reader(environment) + + return apply(environment_1).run_reader(environment) + + return Reader(_bind) + + +def chain( + em0: Reader[TC_co, TE_co], # type: ignore [misc] # covariant arg ok, b/c function is pure + em1: Reader[TC_co, TA_co], +) -> Reader[TC_co, TA_co]: + """ + Discard the current value of a :py:class:`Reader[TE,TA] ` and replace it with the given :py:class:`Reader[TC, TB] ` + + .. note:: Haskell: `>> `__ + """ + return bind(em0, lambda _: em1) + + +def discard( + em0: Reader[TC_co, TE_co], apply: Callable[[TE_co], Reader[TC_co, TA_co]] +) -> Reader[TC_co, TE_co]: + """ + Apply the given function to the value of a :py:class:`Reader[TE,TA] ` and discard the result + """ + return em0.bind(apply).chain(em0) + + +def replace( + self, value: TC_co # type: ignore [misc] # covariant arg ok, b/c function is pure +) -> Reader[TE_co, TC_co]: + """ + Replace the value of an :py:class:`Reader` with a new value + + :return: a :py:class:`Reader[TE,TA] ` with provided value + + .. note:: Haskell: `<$ `__ + """ + return fmap(self, _const(value)) + + +def fmap( + em0: Reader[TE_co, TA_co], # type: ignore [misc] # covariant arg ok, b/c function is pure + apply: Callable[[TA_co], TC_co], +) -> Reader[TE_co, TC_co]: # type: ignore [misc] # covariant arg ok, b/c function is pure + """ + Map a function over the value of a :py:class:`Reader[TE,TA] ` + + .. note:: Haskell: `fmap `__ + """ + return bind(em0, lambda m0: pure(lambda _: apply(m0))) + + +def pure( + run_reader: Callable[[TE_co], TA_co], # type: ignore [misc] # covariant arg ok, b/c function is pure +) -> Reader[TE_co, TA_co]: + """ + Create a :py:class:`Reader[TE,TA] ` from a function + + .. note:: Haskell: `pure `__ + """ + return Reader(run_reader) diff --git a/src/tests/test_reader.py b/src/tests/test_reader.py new file mode 100644 index 0000000..a8bf468 --- /dev/null +++ b/src/tests/test_reader.py @@ -0,0 +1,82 @@ +# Copyright (c) Ely Deckers. +# +# This source code is licensed under the MPL-2.0 license found in the +# LICENSE file in the root directory of this source tree. + + +# pylint: disable=missing-module-docstring +# pylint: disable=missing-class-docstring +# pylint: disable=missing-function-docstring +import unittest +from dataclasses import dataclass +from random import randint + +from pyella.reader import Reader, ask, bind, pure +from pyella.shared import _const + + +@dataclass(frozen=True) +class Environment: + some_value: int + + +class TestReader(unittest.TestCase): + def test_reader_objects_compare_correctly(self): + # arrange + def times2(some_value: int): + return some_value * 2 + + def times3(some_value: int): + return some_value * 3 + + r0_2 = Reader(times2) + r1_2 = Reader(times2) + + r2_3 = Reader(times3) + + # act + # assert + self.assertEqual( + r0_2, + r1_2, + "The given Reader-objects contain the same function, so they should be considered Equal", + ) + self.assertNotEqual( + r0_2, + r2_3, + # pylint: disable=line-too-long + "The given Reader-objects contain different functions, so they should be considered Not Equal", + ) + + def test_reader_reads_and_applies_environment_data_correctly(self): + # arrange + + random_value = randint(0, 100) # nosec # B311 random in test is safe + + environment = Environment(random_value) + + def times2(env: Environment) -> int: + return env.some_value * 2 + + def multiply_env_by_2() -> Reader[Environment, int]: + return bind(ask(), lambda env: pure(_const(times2(env)))) + + def env_multiplication_as_tuple() -> Reader[Environment, tuple[str, int]]: + return multiply_env_by_2().fmap(lambda x: (str(x), x)) + + # act + result_str, result_int = env_multiplication_as_tuple().run_reader(environment) + + # assert + expected_int = 2 * random_value + + self.assertEqual( + str(expected_int), + result_str, + "The resulting string position of the tuple doesn't match the expected integer", + ) + self.assertEqual( + expected_int, + result_int, + "The resulting integer position of the tuple doesn't match the expected integer", + )