Skip to content

Commit

Permalink
trafaret/Enum: support Python >=3.4 enum module
Browse files Browse the repository at this point in the history
  • Loading branch information
spumer committed Sep 28, 2017
1 parent b38366d commit d88091a
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 0 deletions.
22 changes: 22 additions & 0 deletions tests/test_base.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
# -*- coding: utf-8 -*-
import sys
import unittest
import trafaret as t
import trafaret.utils as tu
from collections import Mapping as AbcMapping
from trafaret import extract_error, ignore, DataError
from trafaret.extras import KeysSubset
Expand Down Expand Up @@ -264,6 +266,26 @@ def test_enum(self):
res = extract_error(trafaret, 2)
self.assertEqual(res, "value doesn't match any variant")

@unittest.skipIf(sys.version_info < (3, 4),
"not supported in this veresion"
)
def test_enum_py3(self):
import enum

class Colors(enum.Enum):
red = 0
green = 1
blue = 2

trafaret = t.Enum.from_enum(Colors)
self.assertEqual(repr(trafaret), "<Enum('red', 'green', 'blue')>")
res = trafaret.check('red')
res = trafaret.check('green')
res = extract_error(trafaret, 'unknown')
self.assertEqual(res, "value doesn't match any variant")

with self.assertRaises(TypeError):
trafaret = t.Enum.from_enum('not a enum instance')


class TestFloat(unittest.TestCase):
Expand Down
12 changes: 12 additions & 0 deletions trafaret/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,14 @@
__VERSION__ = (0, 10, 2)


enum = None

# Python3 support
if py3:
import urllib.parse as urlparse
str_types = (str, bytes)
unicode = str

else:
try:
from future_builtins import map
Expand Down Expand Up @@ -1354,6 +1357,15 @@ def check_value(self, value):
def __repr__(self):
return "<Enum(%s)>" % (", ".join(map(repr, self.variants)))

@classmethod
def from_enum(cls, variant):
import enum

if not isinstance(variant, enum.EnumMeta):
raise TypeError("Expect 'enum.EnumMeta' got %r" % type(variant))

return cls(*(x.name for x in variant))


class Callable(Trafaret):
"""
Expand Down

0 comments on commit d88091a

Please sign in to comment.