Skip to content

Commit

Permalink
feat: add Reader monad
Browse files Browse the repository at this point in the history
  • Loading branch information
edeckers committed Mar 16, 2024
1 parent 4edbdc1 commit 8712c53
Show file tree
Hide file tree
Showing 3 changed files with 243 additions and 1 deletion.
2 changes: 1 addition & 1 deletion src/pyella/either.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ def pure(
value: TB_co, # type: ignore [misc] # covariant arg ok, b/c function is pure
) -> Either[TA_co, TB_co]: # pylint: disable=invalid-name
"""
Alias for :py:func:`pure(self) <pure>`
Alias for :py:func:`pure(value) <pure>`
"""
return pure(value)

Expand Down
160 changes: 160 additions & 0 deletions src/pyella/reader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
# Copyright (c) Ely Deckers.
#
# This source code is licensed under the MPL-2.0 license found in the
# LICENSE file in the root directory of this source tree.

"""
Reader - Contains the :py:class:`Reader[TE, TA] <Reader>` type and related
functions. It was strongly inspired by the Haskell ``Control.Monad.Reader`` module.
More information on the Haskell ``Control.Monad.Reader`` type can be found here:
https://hackage.haskell.org/package/mtl-2.3.1/docs/Control-Monad-Reader.html
"""

from __future__ import annotations

from dataclasses import dataclass
from typing import Callable, Generic, TypeVar

from pyella.shared import _const

TE_co = TypeVar("TE_co", covariant=True) # pylint: disable=invalid-name
TA_co = TypeVar("TA_co", covariant=True) # pylint: disable=invalid-name
TC_co = TypeVar("TC_co", covariant=True) # pylint: disable=invalid-name


@dataclass(frozen=True)
class Reader(Generic[TE_co, TA_co]): # pylint: disable=too-few-public-methods
"""
Represents a computation, which can read values from a shared environment, pass values from function to function,
and execute sub-computations in a modified environment.
"""

run_reader: Callable[[TE_co], TA_co]

def bind(
self, apply: Callable[[TA_co], Reader[TE_co, TC_co]]
) -> Reader[TE_co, TC_co]:
"""
Alias for :py:func:`bind(self, apply) <bind>`
"""
return bind(self, apply)

def chain(self, em1: Reader[TE_co, TC_co]) -> Reader[TE_co, TC_co]:
"""
Alias for :py:func:`chain(self, em1) <chain>`
"""
return chain(self, em1)

def discard(
self, apply: Callable[[TA_co], Reader[TE_co, TA_co]]
) -> Reader[TE_co, TA_co]:
"""
Alias for :py:func:`discard(self, apply) <discard>`
"""
return discard(self, apply)

def fmap(self, apply: Callable[[TA_co], TC_co]) -> Reader[TE_co, TC_co]:
"""
Alias for :py:func:`fmap(self, apply) <fmap>`
"""
return fmap(self, apply)

@staticmethod
def pure(
run_reader: Callable[[TE_co], TA_co], # type: ignore [misc] # covariant arg ok, b/c function is pure
) -> Reader[TE_co, TA_co]: # pylint: disable=invalid-name
"""
Alias for :py:func:`pure(run_reader) <pure>`
"""
return pure(run_reader)

def __eq__(self, __o: object) -> bool:
if not isinstance(__o, Reader):
return False

return __o.run_reader == self.run_reader


def ask() -> Reader[TE_co, TE_co]: # type: ignore [misc] # covariant arg ok, b/c function is pure
"""
Retrieves the monad environment.
.. note:: Haskell: `ask <https://hackage.haskell.org/package/mtl-2.3.1/docs/Control-Monad-Reader.html#v:ask>`__
"""
return Reader(lambda v: v)


def bind(
em0: Reader[TE_co, TA_co], apply: Callable[[TA_co], Reader[TE_co, TC_co]]
) -> Reader[TE_co, TC_co]:
"""
Passes the inherited environment to both subcomputations, first the original `Reader [TE, TA]` and
then the result of the function `TA -> Reader [TE, TC]`
.. note:: Haskell: `>>= <https://hackage.haskell.org/package/mtl-2.3.1/docs/Control-Monad-Reader.html#v:-62--62--61->`__
"""

def _bind(environment: TE_co) -> TC_co: # type: ignore [misc] # covariant arg ok, b/c function is pure
environment_1 = em0.run_reader(environment)

return apply(environment_1).run_reader(environment)

return Reader(_bind)


def chain(
em0: Reader[TC_co, TE_co], # type: ignore [misc] # covariant arg ok, b/c function is pure
em1: Reader[TC_co, TA_co],
) -> Reader[TC_co, TA_co]:
"""
Discard the current value of a :py:class:`Reader[TE,TA] <Reader>` and replace it with the given :py:class:`Reader[TC, TB] <Reader>`
.. note:: Haskell: `>> <https://hackage.haskell.org/package/mtl-2.3.1/docs/Control-Monad-Reader.html#v:-62--62->`__
"""
return bind(em0, lambda _: em1)


def discard(
em0: Reader[TC_co, TE_co], apply: Callable[[TE_co], Reader[TC_co, TA_co]]
) -> Reader[TC_co, TE_co]:
"""
Apply the given function to the value of a :py:class:`Reader[TE,TA] <Reader>` and discard the result
"""
return em0.bind(apply).chain(em0)


def replace(
self, value: TC_co # type: ignore [misc] # covariant arg ok, b/c function is pure
) -> Reader[TE_co, TC_co]:
"""
Replace the value of an :py:class:`Reader` with a new value
:return: a :py:class:`Reader[TE,TA] <Reader>` with provided value
.. note:: Haskell: `<$ <https://hackage.haskell.org/package/mtl-2.3.1/docs/Control-Monad-Reader.html#v:-60--36->`__
"""
return fmap(self, _const(value))


def fmap(
em0: Reader[TE_co, TA_co], # type: ignore [misc] # covariant arg ok, b/c function is pure
apply: Callable[[TA_co], TC_co],
) -> Reader[TE_co, TC_co]: # type: ignore [misc] # covariant arg ok, b/c function is pure
"""
Map a function over the value of a :py:class:`Reader[TE,TA] <Reader>`
.. note:: Haskell: `fmap <https://hackage.haskell.org/package/mtl-2.3.1/docs/Control-Monad-Reader.html>`__
"""
return bind(em0, lambda m0: pure(lambda _: apply(m0)))


def pure(
run_reader: Callable[[TE_co], TA_co], # type: ignore [misc] # covariant arg ok, b/c function is pure
) -> Reader[TE_co, TA_co]:
"""
Create a :py:class:`Reader[TE,TA] <Reader>` from a function
.. note:: Haskell: `pure <https://hackage.haskell.org/package/mtl-2.3.1/docs/Control-Monad-Reader.html#v:pure>`__
"""
return Reader(run_reader)
82 changes: 82 additions & 0 deletions src/tests/test_reader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
# Copyright (c) Ely Deckers.
#
# This source code is licensed under the MPL-2.0 license found in the
# LICENSE file in the root directory of this source tree.


# pylint: disable=missing-module-docstring
# pylint: disable=missing-class-docstring
# pylint: disable=missing-function-docstring
import unittest
from dataclasses import dataclass
from random import randint

from pyella.reader import Reader, ask, bind, pure
from pyella.shared import _const


@dataclass(frozen=True)
class Environment:
some_value: int


class TestReader(unittest.TestCase):
def test_reader_objects_compare_correctly(self):
# arrange
def times2(some_value: int):
return some_value * 2

def times3(some_value: int):
return some_value * 3

r0_2 = Reader(times2)
r1_2 = Reader(times2)

r2_3 = Reader(times3)

# act
# assert
self.assertEqual(
r0_2,
r1_2,
"The given Reader-objects contain the same function, so they should be considered Equal",
)
self.assertNotEqual(
r0_2,
r2_3,
# pylint: disable=line-too-long
"The given Reader-objects contain different functions, so they should be considered Not Equal",
)

def test_reader_reads_and_applies_environment_data_correctly(self):
# arrange

random_value = randint(0, 100) # nosec # B311 random in test is safe

environment = Environment(random_value)

def times2(env: Environment) -> int:
return env.some_value * 2

def multiply_env_by_2() -> Reader[Environment, int]:
return bind(ask(), lambda env: pure(_const(times2(env))))

def env_multiplication_as_tuple() -> Reader[Environment, tuple[str, int]]:
return multiply_env_by_2().fmap(lambda x: (str(x), x))

# act
result_str, result_int = env_multiplication_as_tuple().run_reader(environment)

# assert
expected_int = 2 * random_value

self.assertEqual(
str(expected_int),
result_str,
"The resulting string position of the tuple doesn't match the expected integer",
)
self.assertEqual(
expected_int,
result_int,
"The resulting integer position of the tuple doesn't match the expected integer",
)

0 comments on commit 8712c53

Please sign in to comment.