diff --git a/tests/test_base.py b/tests/test_base.py index d6c3725..95e1a71 100644 --- a/tests/test_base.py +++ b/tests/test_base.py @@ -1,6 +1,8 @@ # -*- coding: utf-8 -*- +import sys import unittest import trafaret as t +import trafaret.utils as tu from collections import Mapping as AbcMapping from trafaret import extract_error, ignore, DataError from trafaret.extras import KeysSubset @@ -264,6 +266,26 @@ def test_enum(self): res = extract_error(trafaret, 2) self.assertEqual(res, "value doesn't match any variant") + @unittest.skipIf(sys.version_info < (3, 4), + "not supported in this veresion" + ) + def test_enum_py3(self): + import enum + + class Colors(enum.Enum): + red = 0 + green = 1 + blue = 2 + + trafaret = t.Enum.from_enum(Colors) + self.assertEqual(repr(trafaret), "") + res = trafaret.check('red') + res = trafaret.check('green') + res = extract_error(trafaret, 'unknown') + self.assertEqual(res, "value doesn't match any variant") + + with self.assertRaises(TypeError): + trafaret = t.Enum.from_enum('not a enum instance') class TestFloat(unittest.TestCase): diff --git a/trafaret/__init__.py b/trafaret/__init__.py index f0d820e..5e901e0 100644 --- a/trafaret/__init__.py +++ b/trafaret/__init__.py @@ -15,11 +15,14 @@ __VERSION__ = (0, 10, 2) +enum = None + # Python3 support if py3: import urllib.parse as urlparse str_types = (str, bytes) unicode = str + else: try: from future_builtins import map @@ -1354,6 +1357,15 @@ def check_value(self, value): def __repr__(self): return "" % (", ".join(map(repr, self.variants))) + @classmethod + def from_enum(cls, variant): + import enum + + if not isinstance(variant, enum.EnumMeta): + raise TypeError("Expect 'enum.EnumMeta' got %r" % type(variant)) + + return cls(*(x.name for x in variant)) + class Callable(Trafaret): """