From b33989eb95588eeb054c72eff07c6a16e0dacc51 Mon Sep 17 00:00:00 2001 From: Jim Pivarski Date: Thu, 12 Oct 2017 14:14:17 -0500 Subject: [PATCH] nullable handling through if statements --- arrowed/compiler.py | 24 +++++++++++++++++++----- arrowed/version.py | 2 +- 2 files changed, 20 insertions(+), 6 deletions(-) diff --git a/arrowed/compiler.py b/arrowed/compiler.py index 0fa6077..7f100dd 100644 --- a/arrowed/compiler.py +++ b/arrowed/compiler.py @@ -376,10 +376,13 @@ def __init__(self, possibilities, parameter): self.nullable = False def setnullable(self, value=True): - out = ArrowedType(self.possibilities, self.parameter) - out.isparameter = self.isparameter - out.nullable = value - return out + if self is nullable or self is untracked: + return self + else: + out = ArrowedType(self.possibilities, self.parameter) + out.isparameter = self.isparameter + out.nullable = value + return out def generate(self, handler): out = None @@ -654,7 +657,14 @@ def do_Attribute(node, symtable, externalfcns, env, sym, sourcefile, recurse): def handler(schema): if isinstance(schema, Record): if node.attr in schema.contents: - return retyped(node.value, ArrowedType(schema.contents[node.attr], node.value.atype.parameter)) + if node.value.atype.nullable: # FIXME: or schema.nullable + value = toexpr("REFUSENONE(VALUE)", + REFUSENONE = toname(newrefusenone(env, sym, "record" if schema.name is None else schema.name, node.value.lineno, sourcefile)), + VALUE = node.value) + else: + value = node.value + return retyped(value, ArrowedType(schema.contents[node.attr], node.value.atype.parameter)) + elif schema.name is None: raise AttributeError("attribute {0} not found in record with structure:\n\n{1}\n\nat line {lineno} of {sourcefile}".format( repr(node.attr), schema.format(" "), lineno=node.lineno, sourcefile=sourcefile)) @@ -814,6 +824,10 @@ def handler(schema): # IfExp ("test", "body", "orelse") # If ("test", "body", "orelse") +def do_If(node, symtable, externalfcns, env, sym, sourcefile, recurse): + body = recurse(node.body) + orelse = recurse(node.orelse) + return rebuilt(node, recurse(node.test), body, orelse) # ImportFrom ("module", "names", "level") diff --git a/arrowed/version.py b/arrowed/version.py index b7fed3f..9861e60 100644 --- a/arrowed/version.py +++ b/arrowed/version.py @@ -16,7 +16,7 @@ import re -__version__ = "0.0.11" +__version__ = "0.0.12" version = __version__ version_info = tuple(re.split(r"[-\.]", __version__))