From 5e8b5c4d7e4304ba6858b2c87e83389b7a490a24 Mon Sep 17 00:00:00 2001
From: Zach Kurtz <zkurtz@protonmail.com>
Date: Thu, 28 Nov 2024 22:25:52 -0500
Subject: [PATCH] use runtime module protocol checks

---
 dummio/__init__.py                   |  1 -
 dummio/constants.py                  | 20 +------------------
 tests/test_assert_module_protocol.py | 30 +++++++++++++++++++++++++---
 3 files changed, 28 insertions(+), 23 deletions(-)

diff --git a/dummio/__init__.py b/dummio/__init__.py
index 1e2db21..35c97a2 100644
--- a/dummio/__init__.py
+++ b/dummio/__init__.py
@@ -5,7 +5,6 @@
 
 from dummio import json as json
 from dummio import text as text
-from dummio.constants import ModuleProtocol as ModuleProtocol
 
 try:
     from dummio import yaml as yaml
diff --git a/dummio/constants.py b/dummio/constants.py
index 2ebfd74..1323831 100644
--- a/dummio/constants.py
+++ b/dummio/constants.py
@@ -1,28 +1,10 @@
 """Constants for dummio."""
 
-import typing
 from pathlib import Path
-from typing import Any, Protocol, TypeAlias
+from typing import Any, TypeAlias
 
 PathType: TypeAlias = str | Path
 AnyDict: TypeAlias = dict[Any, Any]
 
-# pyright expect a type var to be used at least twice within a single method. It's having
-# trouble respecting how it's use *accross* methods of a class.
-T = typing.TypeVar("T")  # pyright: ignore
-
 DEFAULT_ENCODING = "utf-8"
 DEFAULT_WRITE_MODE = "w"
-
-
-@typing.runtime_checkable
-class ModuleProtocol(Protocol):
-    """Protocol for dummio's IO modules."""
-
-    def save(self, data: T, *, filepath: PathType) -> None:  # pyright: ignore[reportInvalidTypeVarUse]
-        """Declares the signature of an IO module save method."""
-        ...
-
-    def load(self, filepath: PathType) -> T:  # pyright: ignore[reportInvalidTypeVarUse]
-        """Declares the signature of an IO module load method."""
-        ...
diff --git a/tests/test_assert_module_protocol.py b/tests/test_assert_module_protocol.py
index c54c6e7..4a102b8 100644
--- a/tests/test_assert_module_protocol.py
+++ b/tests/test_assert_module_protocol.py
@@ -1,8 +1,9 @@
-"""Assert that every IO module implements the ModuleProtocol."""
+"""Assert that every IO module implements save and load in a consistent way."""
 
 import importlib
+from typing import Callable, get_type_hints
 
-from dummio import ModuleProtocol
+from dummio.constants import PathType
 
 IO_MODULES = [
     "dummio.json",
@@ -17,4 +18,27 @@
 def test_assert_module_protocol() -> None:
     for module_name in IO_MODULES:
         module = importlib.import_module(module_name)
-        assert isinstance(module, ModuleProtocol)
+        assert hasattr(module, "save")
+        assert hasattr(module, "load")
+
+        # make the following assertions about the save attribute:
+        # - it is a function
+        # - the first argument is named "data"
+        # - all subsequent arguments are keyword-only
+        # - the second argument is "filepath" of type dummio.constants.PathType
+        assert isinstance(module.save, Callable)
+        signature = get_type_hints(module.save)
+        first_two_args = list(signature.keys())[:2]
+        assert first_two_args == ["data", "filepath"]
+        assert signature["filepath"] == PathType
+
+        # make the following assertions about the load attribute:
+        # - it is a function
+        # - the first argument is named "filepath", of type dummio.constants.PathType
+        # - the return type is the same as the "data" argument of the save function
+        assert isinstance(module.load, Callable)
+        signature = get_type_hints(module.load)
+        first_arg = list(signature.keys())[0]
+        assert first_arg == "filepath"
+        assert signature["filepath"] == PathType
+        assert signature["return"] == get_type_hints(module.save)["data"]