diff --git a/.pylintrc b/.pylintrc index d913d9d0c..89f503418 100644 --- a/.pylintrc +++ b/.pylintrc @@ -3,7 +3,7 @@ variable-rgx=[a-z0-9_]{1,30}$ good-names=log -disable=invalid-name,unused-argument,too-few-public-methods,no-self-use,missing-docstring,logging-format-interpolation,too-many-instance-attributes,duplicate-code,too-many-public-methods,too-many-arguments,protected-access,too-many-lines +disable=invalid-name,unused-argument,too-few-public-methods,no-self-use,missing-docstring,logging-format-interpolation,too-many-instance-attributes,duplicate-code,too-many-public-methods,too-many-arguments,protected-access,too-many-lines,unspecified-encoding,consider-using-dict-items,consider-using-with [FORMAT] max-line-length=119 diff --git a/pysparkling/sql/ast/README.md b/pysparkling/sql/ast/README.md new file mode 100644 index 000000000..87745e9c7 --- /dev/null +++ b/pysparkling/sql/ast/README.md @@ -0,0 +1,26 @@ +# Python Abstract Syntax Tree for Spark SQL + +This folder uses ANTLR4 to convert a SQL statement into an Abstract Syntax Tree in Python. + +This AST is then transformed into the corresponding pysparkling abstraction. + +## Example + + +## Recreate generated files + +First, download the ANTLR complete JAR from [the ANTLR site][antlr]. + +[antlr]:http://www.antlr.org/ + +Next, install the ANTLR4 Python 3 runtime package, which is required for development: + +``` +pip install antlr4-python3-runtime +``` + +Then, run ANTLR to compile the SQL grammar and generate Python code. + +``` +java -Xmx500M -cp ":$CLASSPATH" org.antlr.v4.Tool -Dlanguage=Python3 SqlBase.g4 +``` diff --git a/pysparkling/sql/ast/__init__.py b/pysparkling/sql/ast/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/pysparkling/sql/ast/ast_to_python.py b/pysparkling/sql/ast/ast_to_python.py new file mode 100644 index 000000000..5a395bf1f --- /dev/null +++ b/pysparkling/sql/ast/ast_to_python.py @@ -0,0 +1,637 @@ +import ast +import logging + +from sqlparser import string_to_ast +from sqlparser.internalparser import SqlParsingError + +from ...sql import functions +from ..column import Column, parse +from ..expressions.expressions import expression_registry +from ..expressions.literals import Literal +from ..expressions.mappers import Concat, CreateStruct +from ..expressions.operators import ( + Add, Alias, And, BitwiseAnd, BitwiseNot, BitwiseOr, BitwiseXor, Cast, Divide, Equal, GreaterThan, + GreaterThanOrEqual, Invert, LessThan, LessThanOrEqual, Minus, Mod, Negate, Or, Time, UnaryPositive +) +from ..types import DoubleType, parsed_string_to_type, StringType, StructField, StructType + + +class UnsupportedStatement(SqlParsingError): + pass + + +def check_children(expected, children): + if len(children) != expected: + raise SqlParsingError( + "Expecting {0} children, got {1}: {2}".format( + expected, len(children), children + ) + ) + + +def unwrap(*children): + check_children(1, children) + return convert_tree(children[0]) + + +def never_found(*children): + logging.warning("We should never have encountered this node.") + return unwrap(*children) + + +def unsupported(*children): + if children: + parent_context = children[0].parentCtx.__class__.__name__ + else: + parent_context = 'unknown' + raise UnsupportedStatement(parent_context) + + +def empty(*children): + check_children(0, children) + + +def first_child_only(*children): + return convert_tree(children[0]) + + +def child_and_eof(*children): + check_children(2, children) + return convert_tree(children[0]) + + +def convert_tree(tree): + tree_type = tree.__class__.__name__ + 
if not hasattr(tree, "children"): + return get_leaf_value(tree) + try: + converter = CONVERTERS[tree_type] + except KeyError: + raise SqlParsingError("Unsupported statement {0}".format(tree_type)) from None + children = tree.children or () + for c in children: + if c.__class__.__name__ == 'ErrorNodeImpl': + raise SqlParsingError(f'Unable to parse data type, unexpected {c.symbol}') + return converter(*children) + + +def call_function(*children): + raw_function_name = convert_tree(children[0]) + function_name = next( + (name for name in functions.__all__ if name.lower() == raw_function_name.lower()), + None + ) + if function_name is None: + raise UnsupportedStatement("Unknown function {0}".format(raw_function_name)) + function_expression = expression_registry.get(function_name.lower()) + params = [convert_tree(c) for c in children[2:-1]] + + complex_function = any( + not isinstance(param, Column) and param == ')' + for param in params + ) + if not complex_function: + last_argument_position = None + # filter_clause = None + # over_clause = None + # set_clause = None + else: + # pylint: disable=fixme + # todo: Handle complex functions + last_argument_position = params.index(")") + # filter_clause = ... + # over_clause = ... + # set_clause = ... + + # parameters are comma separated + function_arguments = params[0:last_argument_position:2] + return function_expression(*function_arguments) + + +def binary_operation(*children): + check_children(3, children) + left, operator, right = children + cls = binary_operations[convert_tree(operator).upper()] + return cls( + convert_tree(left), + convert_tree(right) + ) + + +def cast_context(*children): + """ + Children are: + CAST '(' expression AS dataType ')' + + """ + check_children(6, children) + expression = convert_tree(children[2]) + data_type = convert_tree(children[4]) + return parse(expression).cast(data_type) + + +def detect_data_type(*children): + data_type = convert_tree(children[0]) + params = [convert_tree(c) for c in children[2:-1:2]] + return parsed_string_to_type(data_type, params) + + +def unary_operation(*children): + check_children(2, children) + operator, value = children + cls = unary_operations[convert_tree(operator).upper()] + return cls( + convert_tree(value) + ) + + +def parenthesis_context(*children): + check_children(3, children) + return convert_tree(children[1]) + + +def convert_boolean(*children): + check_children(1, children) + value = convert_tree(children[0]) + if value.lower() == "false": + return False + if value.lower() == "true": + return True + raise SqlParsingError("Expecting boolean value, got {0}".format(children)) + + +def convert_to_literal(*children): + check_children(1, children) + value = convert_tree(children[0]) + return Literal(value) + + +def convert_column(*children): + check_children(1, children) + value = convert_tree(children[0]) + return Column(value) + + +def convert_field(*children): + name = convert_tree(children[0]) + data_type = convert_tree(children[1]) + return StructField(name, data_type) + + +def convert_table_schema(*children): + check_children(2, children) + return StructType(fields=list(convert_tree(children[0]))) + + +def convert_to_null(*children): + return None + + +def get_leaf_value(*children): + check_children(1, children) + value = children[0] + if value.__class__.__name__ != "TerminalNodeImpl": + raise SqlParsingError("Expecting TerminalNodeImpl, got {0}".format(type(value).__name__)) + if not hasattr(value, "symbol"): + raise SqlParsingError("Got leaf value but without symbol") + return value.symbol.text + + +def remove_delimiter(*children): + delimited_value = 
get_leaf_value(*children) + return delimited_value[1:-1] + + +def explicit_list(*children): + return tuple( + convert_tree(c) + for c in children[1:-1:2] + ) + + +def convert_to_complex_col_type(*children): + name = convert_tree(children[0]) + data_type = convert_tree(children[2]) + if len(children) > 3: + params = [convert_tree(c) for c in children[3:]] + if len(params) >= 2 and isinstance(params[0], str) and isinstance(params[1], str): + nullable = (params[0].lower(), params[1].lower()) != ('not', 'null') + raw_metadata = params[2:] + else: + nullable = True + raw_metadata = params + metadata = {k.lower(): v for k, v in raw_metadata} or None + else: + nullable = True + metadata = None + return name, data_type, nullable, metadata + + +def implicit_list(*children): + return tuple( + convert_tree(c) + for c in children[::2] + ) + + +def concat_to_value(*children): + return ast.literal_eval("".join(convert_tree(c) for c in children)) + + +def concat_keywords(*children): + return " ".join(convert_tree(c) for c in children) + + +def concat_strings(*children): + return "".join(convert_tree(c) for c in children) + + +def build_struct(*children): + return CreateStruct(*(convert_tree(c) for c in children[2:-1:2])) + + +def potential_alias(*children): + if len(children) == 1: + return convert_tree(children[0]) + if len(children) in (2, 3): + return Alias( + convert_tree(children[0]), + convert_tree(children[-1]) + ) + raise SqlParsingError("Expecting 1, 2 or 3 children, got {0}".format(len(children))) + + +def get_quoted_identifier(*children): + value = get_leaf_value(*children) + return value[1:-1] + + +def check_identifier(*children): + check_children(2, children) + if children[0].children is None: + raise SqlParsingError("Expected identifier not found") from children[0].exception + identifier = convert_tree(children[0]) + if children[1].children: + extra = convert_tree(children[1]) + raise SqlParsingError( + "Possibly unquoted identifier {0}{1} detected. 
" + "Please consider quoting it with back-quotes as `{0}{1}`".format(identifier, extra) + ) + return identifier + + +def convert_comment(*children): + check_children(2, children) + return "COMMENT", convert_tree(children[1])[1:-1] + + +CONVERTERS = { + 'AddTableColumnsContext': unsupported, + 'AddTablePartitionContext': unsupported, + 'AggregationClauseContext': unsupported, + 'AliasedQueryContext': unsupported, + 'AliasedRelationContext': unsupported, + 'AlterColumnActionContext': unsupported, + 'AlterTableAlterColumnContext': unsupported, + 'AlterViewQueryContext': unsupported, + 'AnalyzeContext': unsupported, + 'AnsiNonReservedContext': get_leaf_value, + 'ApplyTransformContext': unsupported, + 'ArithmeticBinaryContext': binary_operation, + 'ArithmeticOperatorContext': get_leaf_value, + 'ArithmeticUnaryContext': unary_operation, + 'AssignmentContext': unsupported, + 'AssignmentListContext': implicit_list, + 'BigDecimalLiteralContext': concat_to_value, + 'BigIntLiteralContext': concat_to_value, + 'BooleanExpressionContext': never_found, + 'BooleanLiteralContext': unwrap, + 'BooleanValueContext': convert_boolean, + 'BucketSpecContext': unsupported, + 'CacheTableContext': unsupported, + 'CastContext': cast_context, + 'ClearCacheContext': unsupported, + 'ColPositionContext': unsupported, + 'ColTypeContext': convert_field, + 'ColTypeListContext': implicit_list, + 'ColumnReferenceContext': convert_column, + 'CommentNamespaceContext': unsupported, + 'CommentSpecContext': convert_comment, + 'CommentTableContext': unsupported, + 'ComparisonContext': binary_operation, + 'ComparisonOperatorContext': get_leaf_value, + 'ComplexColTypeContext': convert_to_complex_col_type, + 'ComplexColTypeListContext': implicit_list, + 'ComplexDataTypeContext': detect_data_type, + 'ConstantContext': never_found, + 'ConstantDefaultContext': convert_to_literal, + 'ConstantListContext': explicit_list, + 'CreateFileFormatContext': unsupported, + 'CreateFunctionContext': unsupported, + 'CreateHiveTableContext': unsupported, + 'CreateNamespaceContext': unsupported, + 'CreateTableClausesContext': unsupported, + 'CreateTableContext': unsupported, + 'CreateTableHeaderContext': unsupported, + 'CreateTableLikeContext': unsupported, + 'CreateTempViewUsingContext': unsupported, + 'CreateViewContext': unsupported, + 'CtesContext': unsupported, + 'CurrentDatetimeContext': unsupported, + 'DataTypeContext': empty, + 'DecimalLiteralContext': concat_to_value, + 'DeleteFromTableContext': unsupported, + 'DereferenceContext': unsupported, + 'DescribeColNameContext': unsupported, + 'DescribeFuncNameContext': unwrap, + 'DescribeFunctionContext': unsupported, + 'DescribeNamespaceContext': unsupported, + 'DescribeQueryContext': unsupported, + 'DescribeRelationContext': unsupported, + 'DmlStatementContext': unsupported, + 'DmlStatementNoWithContext': never_found, + 'DoubleLiteralContext': concat_to_value, + 'DropFunctionContext': unsupported, + 'DropNamespaceContext': unsupported, + 'DropTableColumnsContext': unsupported, + 'DropTableContext': unsupported, + 'DropTablePartitionsContext': unsupported, + 'DropViewContext': unsupported, + 'ErrorCapturingIdentifierContext': check_identifier, + 'ErrorCapturingIdentifierExtraContext': never_found, + 'ErrorCapturingMultiUnitsIntervalContext': unsupported, + 'ErrorCapturingUnitToUnitIntervalContext': unsupported, + 'ErrorIdentContext': concat_strings, + 'ExistsContext': unsupported, + 'ExplainContext': unsupported, + 'ExponentLiteralContext': concat_to_value, + 'ExpressionContext': unwrap, + 
'ExtractContext': unsupported, + 'FailNativeCommandContext': unsupported, + 'FileFormatContext': never_found, + 'FirstContext': unsupported, + 'FrameBoundContext': unsupported, + 'FromClauseContext': unsupported, + 'FromStatementBodyContext': unsupported, + 'FromStatementContext': unsupported, + 'FromStmtContext': unwrap, + 'FunctionCallContext': call_function, + 'FunctionIdentifierContext': unsupported, + 'FunctionNameContext': unwrap, + 'FunctionTableContext': unsupported, + 'GenericFileFormatContext': unwrap, + 'GroupingSetContext': unsupported, + 'HavingClauseContext': unsupported, + 'HintContext': unsupported, + 'HintStatementContext': unsupported, + 'HiveChangeColumnContext': unsupported, + 'HiveReplaceColumnsContext': unsupported, + 'IdentifierCommentContext': unsupported, + 'IdentifierCommentListContext': explicit_list, + 'IdentifierContext': unwrap, + 'IdentifierListContext': explicit_list, + 'IdentifierSeqContext': unsupported, + 'IdentityTransformContext': unwrap, + 'InlineTableContext': unsupported, + 'InlineTableDefault1Context': unwrap, + 'InlineTableDefault2Context': unwrap, + 'InsertIntoContext': never_found, + 'InsertIntoTableContext': unsupported, + 'InsertOverwriteDirContext': unsupported, + 'InsertOverwriteHiveDirContext': unsupported, + 'InsertOverwriteTableContext': unsupported, + 'IntegerLiteralContext': concat_to_value, + 'IntervalContext': unsupported, + 'IntervalLiteralContext': unwrap, + 'IntervalUnitContext': unwrap, + 'IntervalValueContext': unsupported, + 'JoinCriteriaContext': unsupported, + 'JoinRelationContext': unsupported, + 'JoinTypeContext': concat_keywords, + 'LambdaContext': unsupported, + 'LastContext': unsupported, + 'LateralViewContext': unsupported, + 'LegacyDecimalLiteralContext': concat_to_value, + 'LoadDataContext': unsupported, + 'LocationSpecContext': unsupported, + 'LogicalBinaryContext': binary_operation, + 'LogicalNotContext': unary_operation, + 'ManageResourceContext': unsupported, + 'MatchedActionContext': unsupported, + 'MatchedClauseContext': unsupported, + 'MergeIntoTableContext': unsupported, + 'MultiInsertQueryBodyContext': unsupported, + 'MultiInsertQueryContext': unsupported, + 'MultipartIdentifierContext': unsupported, + 'MultipartIdentifierListContext': implicit_list, + 'MultiUnitsIntervalContext': unsupported, + 'NamedExpressionContext': potential_alias, + 'NamedExpressionSeqContext': unsupported, + 'NamedQueryContext': unsupported, + 'NamedWindowContext': unsupported, + 'NamespaceContext': get_leaf_value, + 'NestedConstantListContext': explicit_list, + 'NonReservedContext': get_leaf_value, + 'NotMatchedActionContext': unsupported, + 'NotMatchedClauseContext': unsupported, + 'NullLiteralContext': convert_to_null, + 'NumberContext': never_found, + 'NumericLiteralContext': unwrap, + 'OrderedIdentifierContext': unsupported, + 'OrderedIdentifierListContext': explicit_list, + 'OverlayContext': unsupported, + 'ParenthesizedExpressionContext': parenthesis_context, + 'PartitionSpecContext': unsupported, + 'PartitionSpecLocationContext': unsupported, + 'PartitionValContext': unsupported, + 'PivotClauseContext': unsupported, + 'PivotColumnContext': unsupported, + 'PivotValueContext': unsupported, + 'PositionContext': unsupported, + 'PredicateContext': unsupported, + 'PredicatedContext': unwrap, + 'PredicateOperatorContext': get_leaf_value, + 'PrimaryExpressionContext': never_found, + 'PrimitiveDataTypeContext': detect_data_type, + 'QualifiedColTypeWithPositionContext': unsupported, + 'QualifiedColTypeWithPositionListContext': 
implicit_list, + 'QualifiedNameContext': concat_strings, + 'QualifiedNameListContext': implicit_list, + 'QueryContext': unsupported, + 'QueryOrganizationContext': unsupported, + 'QueryPrimaryContext': never_found, + 'QueryPrimaryDefaultContext': unwrap, + 'QuerySpecificationContext': unsupported, + 'QueryTermContext': never_found, + 'QueryTermDefaultContext': unwrap, + 'QuotedIdentifierAlternativeContext': unwrap, + 'QuotedIdentifierContext': get_quoted_identifier, + 'RealIdentContext': empty, + 'RecoverPartitionsContext': unsupported, + 'RefreshResourceContext': unsupported, + 'RefreshTableContext': unsupported, + 'RegularQuerySpecificationContext': unsupported, + 'RelationContext': unsupported, + 'RelationPrimaryContext': never_found, + 'RenameTableColumnContext': unsupported, + 'RenameTableContext': unsupported, + 'RenameTablePartitionContext': unsupported, + 'RepairTableContext': unsupported, + 'ReplaceTableContext': unsupported, + 'ReplaceTableHeaderContext': unsupported, + 'ResetConfigurationContext': unwrap, + 'ResourceContext': unsupported, + 'RowConstructorContext': unsupported, + 'RowFormatContext': never_found, + 'RowFormatDelimitedContext': unsupported, + 'RowFormatSerdeContext': unsupported, + 'SampleByBucketContext': unsupported, + 'SampleByBytesContext': unsupported, + 'SampleByPercentileContext': unsupported, + 'SampleByRowsContext': unsupported, + 'SampleContext': unsupported, + 'SampleMethodContext': never_found, + 'SearchedCaseContext': unsupported, + 'SelectClauseContext': unsupported, + 'SetClauseContext': unsupported, + 'SetConfigurationContext': unsupported, + 'SetNamespaceLocationContext': unsupported, + 'SetNamespacePropertiesContext': unsupported, + 'SetOperationContext': unsupported, + 'SetQuantifierContext': get_leaf_value, + 'SetTableLocationContext': unsupported, + 'SetTablePropertiesContext': unsupported, + 'SetTableSerDeContext': unsupported, + 'ShowColumnsContext': unsupported, + 'ShowCreateTableContext': unsupported, + 'ShowCurrentNamespaceContext': unsupported, + 'ShowFunctionsContext': unsupported, + 'ShowNamespacesContext': unsupported, + 'ShowPartitionsContext': unsupported, + 'ShowTableContext': unsupported, + 'ShowTablesContext': unsupported, + 'ShowTblPropertiesContext': unsupported, + 'ShowViewsContext': unsupported, + 'SimpleCaseContext': unsupported, + 'SingleDataTypeContext': child_and_eof, + 'SingleExpressionContext': child_and_eof, + 'SingleFunctionIdentifierContext': child_and_eof, + 'SingleInsertQueryContext': unsupported, + 'SingleMultipartIdentifierContext': child_and_eof, + 'SingleStatementContext': first_child_only, + 'SingleTableIdentifierContext': child_and_eof, + 'SingleTableSchemaContext': convert_table_schema, + 'SkewSpecContext': unsupported, + 'SmallIntLiteralContext': concat_to_value, + 'SortItemContext': unsupported, + 'SqlBaseParser': unsupported, + 'StarContext': unsupported, + 'StatementContext': never_found, + 'StatementDefaultContext': unwrap, + 'StorageHandlerContext': unsupported, + 'StrictIdentifierContext': never_found, + 'StrictNonReservedContext': get_leaf_value, + 'StringLiteralContext': remove_delimiter, + 'StructContext': build_struct, + 'SubqueryContext': parenthesis_context, + 'SubqueryExpressionContext': parenthesis_context, + 'SubscriptContext': unsupported, + 'SubstringContext': unsupported, + 'TableAliasContext': unsupported, + 'TableContext': unsupported, + 'TableFileFormatContext': unsupported, + 'TableIdentifierContext': unsupported, + 'TableNameContext': unsupported, + 'TablePropertyContext': unsupported, 
+ 'TablePropertyKeyContext': unsupported, + 'TablePropertyListContext': explicit_list, + 'TablePropertyValueContext': unwrap, + 'TableProviderContext': unsupported, + 'TableValuedFunctionContext': unwrap, + 'TerminalNodeImpl': get_leaf_value, + 'TinyIntLiteralContext': concat_to_value, + 'TransformArgumentContext': unwrap, + 'TransformClauseContext': unsupported, + 'TransformContext': never_found, + 'TransformListContext': explicit_list, + 'TransformQuerySpecificationContext': unsupported, + 'TrimContext': unsupported, + 'TruncateTableContext': unsupported, + 'TypeConstructorContext': unsupported, + 'UncacheTableContext': unsupported, + 'UnitToUnitIntervalContext': unsupported, + 'UnquotedIdentifierContext': unwrap, + 'UnsetTablePropertiesContext': unsupported, + 'UnsupportedHiveNativeCommandsContext': unsupported, + 'UpdateTableContext': unsupported, + 'UseContext': unsupported, + 'ValueExpressionContext': never_found, + 'ValueExpressionDefaultContext': unwrap, + 'WhenClauseContext': unsupported, + 'WhereClauseContext': unsupported, + 'WindowClauseContext': unsupported, + 'WindowDefContext': unsupported, + 'WindowFrameContext': unsupported, + 'WindowRefContext': unsupported, + 'WindowSpecContext': never_found, +} + +binary_operations = { + "=": Equal, + "==": Equal, + "<>": lambda *args: Invert(Equal(*args)), + "!=": lambda *args: Invert(Equal(*args)), + "<": LessThan, + "<=": LessThanOrEqual, + "!>": LessThanOrEqual, + ">": GreaterThan, + ">=": GreaterThanOrEqual, + "!<": GreaterThanOrEqual, + "+": Add, + "-": Minus, + '*': Time, + '/': lambda a, b: Divide(Cast(a, DoubleType), Cast(b, DoubleType)), + '%': Mod, + 'DIV': lambda a, b: Divide(Cast(a, DoubleType), Cast(b, DoubleType)), + '&': BitwiseAnd, + '|': BitwiseOr, + '||': lambda a, b: Concat([Cast(a, StringType), Cast(b, StringType)]), + '^': BitwiseXor, + 'AND': And, + 'OR': Or, +} + +unary_operations = { + "+": UnaryPositive, + "-": Negate, + "~": BitwiseNot, + 'NOT': Invert +} + + +def parse_sql(string, rule, debug=False): + tree = string_to_ast(string, rule, debug=debug) + return convert_tree(tree) + + +def parse_data_type(string, debug=False): + return parse_sql(string, "singleDataType", debug) + + +def parse_schema(string, debug=False): + return parse_sql(string, "singleTableSchema", debug) + + +def parse_expression(string, debug=False): + return parse_sql(string, "singleExpression", debug) + + +def parse_ddl_string(string, debug=False): + try: + # DDL format, "fieldname datatype, fieldname datatype". + return parse_schema(string, debug) + except SqlParsingError: + try: + # For backwards compatibility, "integer", "struct<fieldname: datatype>", etc. + return parse_data_type(string, debug) + except SqlParsingError: + # For backwards compatibility, "fieldname: datatype, fieldname: datatype" case. 
+ return parse_data_type(f"struct<{string.strip()}>") diff --git a/pysparkling/sql/ast/tests/__init__.py b/pysparkling/sql/ast/tests/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/pysparkling/sql/ast/tests/test_functions.py b/pysparkling/sql/ast/tests/test_functions.py new file mode 100644 index 000000000..f8792d158 --- /dev/null +++ b/pysparkling/sql/ast/tests/test_functions.py @@ -0,0 +1,43 @@ +from unittest import TestCase + +from parameterized import parameterized +from parameterized.parameterized import default_name_func + +from pysparkling import Row +from pysparkling.sql.ast.ast_to_python import parse_expression +from pysparkling.sql.types import IntegerType, StructField, StructType + +ROW = Row(a=1, b=2, c=3) +SCHEMA = StructType([ + StructField("a", IntegerType()), + StructField("b", IntegerType()), +]) + + +def format_test_name(func, num, p): + base_name = default_name_func(func, num, p) + if len(p.args) > 1 and isinstance(p.args[1], tuple) and isinstance(p.args[1][0], str): + return base_name + "_" + parameterized.to_safe_name(p.args[1][0]) + return base_name + + +class TestFunctions(TestCase): + SCENARIOS = { + 'Least(-1,0,1)': ('least', 'least(-1, 0, 1)', -1), + 'GREATEST(-1,0,1)': ('greatest', 'greatest(-1, 0, 1)', 1), + 'shiftRight ( 42, 1 )': ('shiftright', 'shiftright(42, 1)', 21), + 'ShiftLeft ( 42, 1 )': ('shiftleft', 'shiftleft(42, 1)', 84), + "concat_ws('/', a, b )": ('concat_ws', 'concat_ws(/, a, b)', "1/2"), + 'instr(a, a)': ('instr', 'instr(a, a)', 1), # rely on columns + 'instr(a, b)': ('instr', 'instr(a, b)', 0), # rely on columns + "instr('abc', 'c')": ('instr', 'instr(abc, c)', 3), # rely on lit + } + + @parameterized.expand(SCENARIOS.items(), name_func=format_test_name) + def test_functions(self, string, expected): + operator, expected_parsed, expected_result = expected + actual_parsed = parse_expression(string, True) + self.assertEqual(expected_parsed, str(actual_parsed)) + self.assertEqual(operator, actual_parsed.pretty_name) + actual_result = actual_parsed.eval(ROW, SCHEMA) + self.assertEqual(expected_result, actual_result) diff --git a/pysparkling/sql/ast/tests/test_operations.py b/pysparkling/sql/ast/tests/test_operations.py new file mode 100644 index 000000000..697e5931b --- /dev/null +++ b/pysparkling/sql/ast/tests/test_operations.py @@ -0,0 +1,66 @@ +import logging +from unittest import TestCase + +from parameterized import parameterized +from parameterized.parameterized import default_name_func + +from pysparkling import Row +from pysparkling.sql.ast.ast_to_python import parse_expression +from pysparkling.sql.types import StructType + +ROW = Row() +SCHEMA = StructType() + + +def format_test_name(func, num, p): + base_name = default_name_func(func, num, p) + if len(p.args) > 1 and isinstance(p.args[1], tuple) and isinstance(p.args[1][0], str): + return base_name + "_" + parameterized.to_safe_name(p.args[1][0]) + return base_name + + +class TestOperations(TestCase): + SCENARIOS = { + '60=60': ('EQ', '(60 = 60)', True), + '60=12': ('EQ', '(60 = 12)', False), + '60==12': ('EQ2', '(60 = 12)', False), + '12<>12': ('NEQ', '(NOT (12 = 12))', False), + '60<>12': ('NEQ', '(NOT (60 = 12))', True), + '60!=12': ('NEQ2', '(NOT (60 = 12))', True), + '60<12': ('LT', '(60 < 12)', False), + '60<=12': ('LTE', '(60 <= 12)', False), + '60!>12': ('LTE2', '(60 <= 12)', False), + '60>12': ('GT', '(60 > 12)', True), + '60>=12': ('GTE', '(60 >= 12)', True), + '60!<12': ('GTE2', '(60 >= 12)', True), + '60+12': ('PLUS', '(60 + 12)', 72), + '60-12': 
('MINUS', '(60 - 12)', 48), + '60*12': ('TIMES', '(60 * 12)', 720), + # '60/12': ('DIVIDE', '(60 / 12)', None), + '60%12': ('MODULO', '(60 % 12)', 0), + # '60 div 12': ('DIV', '(60 DIV 12)', None), + '6&3': ('BITWISE_AND', '(6 & 3)', 2), + '6|3': ('BITWISE_OR', '(6 | 3)', 7), + # '60||12': ('CONCAT', '(60 || 12)', None), + '6^3': ('BITWISE_XOR', '(6 ^ 3)', 5), + 'true and false': ('LOGICAL_AND', '(true AND false)', False), + 'TRUE AND TRUE': ('LOGICAL_AND', '(true AND true)', True), + 'true AND null': ('LOGICAL_AND', '(true AND NULL)', None), + 'True or False': ('LOGICAL_OR', '(true OR false)', True), + 'false or false': ('LOGICAL_OR', '(false OR false)', False), + 'true or NULL': ('LOGICAL_OR', '(true OR NULL)', None), + "+1": ("UNARY_PLUS", '(+ 1)', 1), + "-(1)": ("UNARY_MINUS", '(- 1)', -1), + "~8": ("BITWISE_NOT", '~8', -9), + 'not true': ("NOT", '(NOT true)', False), + 'Not Null': ("NOT", '(NOT NULL)', None), + } + + @parameterized.expand(SCENARIOS.items(), name_func=format_test_name) + def test_operations(self, string, expected): + operator, expected_parsed, expected_result = expected + logging.debug("Testing %s", operator) + actual_parsed = parse_expression(string, True) + self.assertEqual(expected_parsed, str(actual_parsed)) + actual_result = actual_parsed.eval(Row(), SCHEMA) + self.assertEqual(expected_result, actual_result) diff --git a/pysparkling/sql/ast/tests/test_parse_ddl_string.py b/pysparkling/sql/ast/tests/test_parse_ddl_string.py new file mode 100644 index 000000000..ea709bffb --- /dev/null +++ b/pysparkling/sql/ast/tests/test_parse_ddl_string.py @@ -0,0 +1,111 @@ +from unittest import TestCase + +import pytest +from sqlparser.internalparser import SqlParsingError + +from pysparkling.sql.ast.ast_to_python import parse_ddl_string +from pysparkling.sql.types import ( + ArrayType, ByteType, DateType, DecimalType, DoubleType, IntegerType, LongType, MapType, ShortType, StringType, + StructField, StructType +) + + +class TestFunctions(TestCase): + def test_basic_entries(self): + schema = parse_ddl_string('some_str: string, some_int: integer, some_date: date not null') + assert schema == StructType([ + StructField('some_str', StringType(), True), + StructField('some_int', IntegerType(), True), + StructField('some_date', DateType(), False), + ]) + assert str(schema) == ( + 'StructType(List(' + 'StructField(some_str,StringType,true),' + 'StructField(some_int,IntegerType,true),' + 'StructField(some_date,DateType,false)' + '))' + ) + + def test_just_returning_the_type(self): + schema = parse_ddl_string('int') + assert schema == IntegerType() + + schema = parse_ddl_string('INT') + assert schema == IntegerType() + + def test_byte_decimal(self): + schema = parse_ddl_string("a: byte, b: decimal( 16 , 8 ) ") + assert schema == StructType([ + StructField('a', ByteType(), True), + StructField('b', DecimalType(16, 8), True), + ]) + assert str( + schema) == 'StructType(List(StructField(a,ByteType,true),StructField(b,DecimalType(16,8),true)))' + + def test_double_string(self): + schema = parse_ddl_string("a DOUBLE, b STRING") + assert schema == StructType([ + StructField('a', DoubleType(), True), + StructField('b', StringType(), True), + ]) + assert str( + schema) == 'StructType(List(StructField(a,DoubleType,true),StructField(b,StringType,true)))' + + def test_array_short(self): + schema = parse_ddl_string("a: array< short>") + assert schema == StructType([ + StructField('a', ArrayType(ShortType()), True), + ]) + assert str(schema) == 
'StructType(List(StructField(a,ArrayType(ShortType,true),true)))' + + def test_map(self): + schema = parse_ddl_string(" map<string , string > ") + assert schema == MapType(StringType(), StringType(), True) + assert str(schema) == 'MapType(StringType,StringType,true)' + + def test_array(self): + schema = parse_ddl_string('some_str: string, arr: array<string>') + assert schema == StructType([ + StructField('some_str', StringType(), True), + StructField('arr', ArrayType(StringType()), True), + ]) + assert str(schema) == 'StructType(List(' \ + 'StructField(some_str,StringType,true),' \ + 'StructField(arr,ArrayType(StringType,true),true)' \ + '))' + + def test_nested_array(self): + schema = parse_ddl_string('some_str: string, arr: array<array<string>>') + assert schema == StructType([ + StructField('some_str', StringType(), True), + StructField('arr', ArrayType(ArrayType(StringType())), True), + ]) + assert str(schema) == 'StructType(List(' \ + 'StructField(some_str,StringType,true),' \ + 'StructField(arr,ArrayType(ArrayType(StringType,true),true),true)' \ + '))' + + def test_alias__long_bigint(self): + schema = parse_ddl_string('i1: long, i2: bigint') + assert schema == StructType([ + StructField('i1', LongType(), True), + StructField('i2', LongType(), True), + ]) + assert str(schema) == 'StructType(List(StructField(i1,LongType,true),StructField(i2,LongType,true)))' + + # Error cases + def test_wrong_type(self): + with pytest.raises(SqlParsingError): + parse_ddl_string("blabla") + + def test_comma_at_end(self): + with pytest.raises(SqlParsingError): + parse_ddl_string("a: int,") + + def test_unclosed_array(self): + with pytest.raises(SqlParsingError): + parse_ddl_string("array<int>>") diff --git a/pysparkling/sql/ast/tests/test_parser.py b/pysparkling/sql/ast/tests/test_parser.py new file mode 100644 index 000000000..c6d142af4 --- /dev/null +++ b/pysparkling/sql/ast/tests/test_parser.py @@ -0,0 +1,42 @@ +from unittest import TestCase + +from pysparkling.sql.ast.ast_to_python import parse_sql, SqlParsingError + + +class TestParser(TestCase): + def test_where(self): + col = parse_sql("doesItWorks = 'In progress!'", rule="booleanExpression") + self.assertEqual("(doesItWorks = In progress!)", str(col)) + + def test_named_expression_no_alias(self): + col = parse_sql("1 + 2", rule="singleExpression") + self.assertEqual("(1 + 2)", str(col)) + + def test_named_expression_implicit_alias(self): + col = parse_sql("(1 + 2) sum", rule="singleExpression") + self.assertEqual("sum", str(col)) + + def test_named_expression_explicit_alias(self): + col = parse_sql("1 + 2 as sum", rule="singleExpression") + self.assertEqual("sum", str(col)) + + def test_named_expression_bad_alias(self): + with self.assertRaises(SqlParsingError) as ctx: + parse_sql("1 + 2 as invalid-alias", rule="singleExpression") + self.assertEqual( + 'Possibly unquoted identifier invalid-alias detected. 
'Please consider quoting it with back-quotes as `invalid-alias`', + str(ctx.exception) + ) + + def test_struct(self): + col = parse_sql("Struct('Alice', 2)", rule="primaryExpression") + self.assertEqual("struct(Alice, 2)", str(col)) + + def test_function(self): + col = parse_sql("GREATEST(1,2,3)", rule="singleExpression") + self.assertEqual("greatest(1, 2, 3)", str(col)) + + # def test_where_filter(self): + # col = parse_sql("concat(1,2 ,3) filter (where id<2)", rule="singleExpression") + # self.assertEqual("..", str(col)) diff --git a/pysparkling/sql/ast/tests/test_type_parsing.py b/pysparkling/sql/ast/tests/test_type_parsing.py new file mode 100644 index 000000000..f035e8005 --- /dev/null +++ b/pysparkling/sql/ast/tests/test_type_parsing.py @@ -0,0 +1,164 @@ +import contextlib +import io +from unittest import TestCase + +from parameterized import parameterized + +from pysparkling import Context +from pysparkling.sql.ast.ast_to_python import parse_data_type +from pysparkling.sql.session import SparkSession +from pysparkling.sql.types import ( + ArrayType, BinaryType, BooleanType, ByteType, DateType, DecimalType, DoubleType, FloatType, IntegerType, LongType, + MapType, ShortType, StringType, StructField, StructType, TimestampType +) + + +class TypeParsingTest(TestCase): + DATA_TYPE_SCENARIOS = { + "boolean": BooleanType(), + "booLean": BooleanType(), + "tinyint": ByteType(), + "byte": ByteType(), + "smallint": ShortType(), + "short": ShortType(), + "int": IntegerType(), + "integer": IntegerType(), + "INTeger": IntegerType(), + "bigint": LongType(), + "long": LongType(), + "float": FloatType(), + "real": FloatType(), + "double": DoubleType(), + "date": DateType(), + "timestamp": TimestampType(), + "string": StringType(), + "binary": BinaryType(), + "decimal": DecimalType(10, 0), + "decimal(10,5)": DecimalType(10, 5), + "decimal(5)": DecimalType(5, 0), + "decimal(5, 2)": DecimalType(5, 2), + "dec": DecimalType(10, 0), + "dec(25)": DecimalType(25, 0), + "dec(25, 21)": DecimalType(25, 21), + "numeric": DecimalType(10, 0), + "Array<string>": ArrayType(StringType()), + "Array<integer>": ArrayType(IntegerType()), + "array<doublE>": ArrayType(DoubleType()), + "Array<map<int, tinYint>>": ArrayType(MapType(IntegerType(), ByteType(), True), True), + "MAP<int, string>": MapType(IntegerType(), StringType()), + "Map<string, integer>": MapType(StringType(), IntegerType()), + "Map < integer, String >": MapType(IntegerType(), StringType()), + "Struct<name: string, age: integer>": StructType([ + StructField(name="name", dataType=StringType()), + StructField(name="age", dataType=IntegerType()), + ]), + "array<struct<tinYint:tinyint>>": ArrayType( + StructType([StructField('tinYint', ByteType(), True)]), + True + ), + "MAp<int, ARRAY<double>>": MapType(IntegerType(), ArrayType(DoubleType(), True), True), + "MAP<int, struct<varchar:string>>": MapType( + IntegerType(), + StructType([StructField("varchar", StringType(), True)]), + True + ), + "struct<intType: int, ts:timestamp>": StructType([ + StructField("intType", IntegerType(), True), + StructField("ts", TimestampType(), True) + ]), + "Struct<int: int, timestamp:timestamp>": StructType([ + StructField("int", IntegerType(), True), + StructField("timestamp", TimestampType(), True) + ]), + """ + struct< + struct:struct<deciMal:DECimal, anotherDecimal:decimAL(5,2)>, + MAP:Map<timestamp, varchar(10)>, + arrAy:Array<double>, + anotherArray:Array<char(9)> + > + """: ( + StructType([ + StructField("struct", StructType([ + StructField("deciMal", DecimalType(), True), + StructField("anotherDecimal", DecimalType(5, 2), True) + ]), True), + # StructField("MAP", MapType(TimestampType(), VarcharType(10)), True), + StructField("MAP", MapType(TimestampType(), StringType()), True), + StructField("arrAy", ArrayType(DoubleType(), True), True), + # StructField("anotherArray", 
ArrayType(CharType(9), True), True)]) + StructField("anotherArray", ArrayType(StringType(), True), True) + ]) + ), + "struct<`x+y`:int, `!@#$%^&*()`:string, `1_2.345<>:\"`:varchar(20)>": StructType([ + StructField("x+y", IntegerType(), True), + StructField("!@#$%^&*()", StringType(), True), + # StructField("1_2.345<>:\"", VarcharType(20), True)]) + StructField("1_2.345<>:\"", StringType(), True) + ]), + "strUCt<>": StructType([]), + "Struct<TABLE: string, DATE:boolean>": StructType([ + StructField("TABLE", StringType(), True), + StructField("DATE", BooleanType(), True) + ]), + "struct<end: long, select: int, from: string>": StructType([ + StructField("end", LongType(), True), + StructField("select", IntegerType(), True), + StructField("from", StringType(), True), + ]), + "Struct<x: int, y: string COMMENT 'test'>": StructType([ + StructField("x", IntegerType(), True), + StructField("y", StringType(), True, {'comment': 'test'}), + ]) + } + + @parameterized.expand(DATA_TYPE_SCENARIOS.items()) + def test_equal(self, string, data_type): + self.assertEqual(parse_data_type(string), data_type) + + def test_comment(self): + string = ( + "Struct<" + "x: INT Not null, " + "y: STRING COMMENT 'nullable', " + "z: string not null COMMENT 'test'" + ">" + ) + data_type = StructType([ + StructField("x", IntegerType(), False), + StructField("y", StringType(), True, {'comment': 'nullable'}), + StructField("z", StringType(), False, {'comment': 'test'}), + ]) + self.assertEqual(parse_data_type(string), data_type) + + SCHEMA_SCENARIOS = { + 'some_str: string, some_int: integer, some_date: date': ( + 'root\n' + ' |-- some_str: string (nullable = true)\n' + ' |-- some_int: integer (nullable = true)\n' + ' |-- some_date: date (nullable = true)\n' + ), + 'some_str: string, arr: array<string>': ( + 'root\n' + ' |-- some_str: string (nullable = true)\n' + ' |-- arr: array (nullable = true)\n' + ' | |-- element: string (containsNull = true)\n' + ), + 'some_str: string, arr: array<array<string>>': ( + 'root\n' + ' |-- some_str: string (nullable = true)\n' + ' |-- arr: array (nullable = true)\n' + ' | |-- element: array (containsNull = true)\n' + ' | | |-- element: string (containsNull = true)\n' + ), + } + + @parameterized.expand(SCHEMA_SCENARIOS.items()) + def test_dataframe_schema_parsing(self, schema, printed_schema): + spark = SparkSession(Context()) + df = spark.createDataFrame([], schema=schema) + + f = io.StringIO() + with contextlib.redirect_stdout(f): + df.printSchema() + self.assertEqual(printed_schema, f.getvalue()) diff --git a/pysparkling/sql/casts.py b/pysparkling/sql/casts.py index 076a589cc..544b27a92 100644 --- a/pysparkling/sql/casts.py +++ b/pysparkling/sql/casts.py @@ -515,3 +515,15 @@ def get_datetime_parser(java_time_format): for token, _ in JAVA_TIME_FORMAT_TOKENS.findall(java_time_format): python_pattern += FORMAT_MAPPING.get(token, token) return lambda value: datetime.datetime.strptime(value, python_pattern) + + +def tz_diff(date): + """ + Return, as a timedelta, the local timezone's offset from UTC on the given date. + """ + time_here = date.astimezone() + time_utc = time_here.replace(tzinfo=pytz.utc) + + offset = time_utc - time_here + + return offset diff --git a/pysparkling/sql/column.py b/pysparkling/sql/column.py index 42f15e337..822212ee4 100644 --- a/pysparkling/sql/column.py +++ b/pysparkling/sql/column.py @@ -1,4 +1,4 @@ -from .expressions.expressions import Expression +from .expressions.expressions import Expression, RegisteredExpressions from .expressions.fields import find_position_in_schema from .expressions.literals import Literal from .expressions.mappers import CaseWhen, StarOperator @@ -8,7
+8,7 @@ Or, Pow, StartsWith, Substring, Time ) from .expressions.orders import Asc, AscNullsFirst, AscNullsLast, Desc, DescNullsFirst, DescNullsLast, SortOrder -from .types import DataType, string_to_type, StructField +from .types import DataType, StructField from .utils import AnalysisException, IllegalArgumentException @@ -469,7 +469,7 @@ def alias(self, *alias, **kwargs): raise ValueError('Pysparkling does not support alias with metadata') if len(alias) == 1: - return Column(Alias(self, Literal(alias[0]))) + return Column(Alias(self, alias[0])) # pylint: disable=W0511 # todo: support it raise ValueError('Pysparkling does not support multiple aliases') @@ -511,7 +511,9 @@ def cast(self, dataType): """ if isinstance(dataType, str): - dataType = string_to_type(dataType) + # pylint: disable=import-outside-toplevel, cyclic-import + from pysparkling.sql.ast.ast_to_python import parse_data_type + dataType = parse_data_type(dataType) elif not isinstance(dataType, DataType): raise NotImplementedError(f"Unknown cast type: {dataType}") @@ -628,7 +630,7 @@ def output_fields(self, schema): return self.expr.output_fields(schema) return [StructField( name=self.col_name, - dataType=self.data_type, + dataType=self.data_type(schema), nullable=self.is_nullable )] @@ -645,7 +647,6 @@ def mergeStats(self, row, schema): def initialize(self, partition_index): if isinstance(self.expr, Expression): self.expr.recursive_initialize(partition_index) - return self def with_pre_evaluation_schema(self, pre_evaluation_schema): if isinstance(self.expr, Expression): @@ -680,11 +681,20 @@ def __nonzero__(self): __bool__ = __nonzero__ - @property - def data_type(self): - # pylint: disable=W0511 - # todo: be more specific - return DataType() + def data_type(self, schema): + if isinstance(self.expr, (Expression, RegisteredExpressions)): + return self.expr.data_type(schema) + if isinstance(self.expr, str): + try: + return schema[self.expr].dataType + except KeyError: + # pylint: disable=raise-missing-from + raise AnalysisException( + f"cannot resolve '`{self.expr}`' given input columns: {schema.fields};" + ) + raise AnalysisException( + f"cannot resolve '`{self.expr}`' type, expecting str or Expression but got {type(self.expr)};" + ) @property def is_nullable(self): diff --git a/pysparkling/sql/dataframe.py b/pysparkling/sql/dataframe.py index 7e5e27fe9..b0e5dae8f 100644 --- a/pysparkling/sql/dataframe.py +++ b/pysparkling/sql/dataframe.py @@ -7,8 +7,7 @@ from .internal_utils.joins import CROSS_JOIN, JOIN_TYPES from .internals import CUBE_TYPE, InternalGroupedDataFrame, ROLLUP_TYPE from .types import ( - _check_series_convert_timestamps_local_tz, ByteType, FloatType, IntegerType, IntegralType, ShortType, - TimestampType + _check_series_convert_timestamps_local_tz, ByteType, FloatType, IntegerType, IntegralType, ShortType, TimestampType ) from .utils import AnalysisException, IllegalArgumentException, require_minimum_pandas_version diff --git a/pysparkling/sql/expressions/aggregate/collectors.py b/pysparkling/sql/expressions/aggregate/collectors.py index 96e14dff9..19edb9a53 100644 --- a/pysparkling/sql/expressions/aggregate/collectors.py +++ b/pysparkling/sql/expressions/aggregate/collectors.py @@ -124,7 +124,7 @@ def args(self): class CountDistinct(Aggregation): pretty_name = "count" - def __init__(self, columns): + def __init__(self, *columns): super().__init__(columns) self.columns = columns self.items = set() diff --git a/pysparkling/sql/expressions/arrays.py b/pysparkling/sql/expressions/arrays.py index 
2511d75d6..edf901c27 100644 --- a/pysparkling/sql/expressions/arrays.py +++ b/pysparkling/sql/expressions/arrays.py @@ -1,3 +1,5 @@ +from ..column import Column +from ..types import ArrayType, BooleanType, IntegerType, MapType, NullType, StringType, StructField, StructType from ..utils import AnalysisException from .expressions import BinaryOperation, Expression, UnaryExpression @@ -15,6 +17,9 @@ def eval(self, row, schema): return None return False + def data_type(self, schema): + return BooleanType() + class ArrayContains(Expression): pretty_name = "array_contains" @@ -36,11 +41,14 @@ def args(self): self.value ) + def data_type(self, schema): + return BooleanType() + class ArrayColumn(Expression): pretty_name = "array" - def __init__(self, columns): + def __init__(self, *columns): super().__init__(columns) self.columns = columns @@ -50,13 +58,23 @@ def eval(self, row, schema): def args(self): return self.columns + def data_type(self, schema): + if not self.columns: + return ArrayType(elementType=NullType()) + return ArrayType(elementType=self.columns[0].data_type(schema)) + class MapColumn(Expression): pretty_name = "map" - def __init__(self, columns): + def __init__(self, *columns): super().__init__(columns) self.columns = columns + if len(columns) % 2 != 0: + raise AnalysisException( + f"Cannot resolve '{self}' due to data type mismatch: " + f"map expects a positive even number of arguments." + ) self.keys = columns[::2] self.values = columns[1::2] @@ -69,6 +87,14 @@ def eval(self, row, schema): def args(self): return self.columns + def data_type(self, schema): + if not self.columns: + return MapType(keyType=NullType(), valueType=NullType()) + return MapType( + keyType=self.keys[0].data_type(schema), + valueType=self.values[0].data_type(schema), + ) + class MapFromArraysColumn(Expression): pretty_name = "map_from_arrays" @@ -79,9 +105,13 @@ def __init__(self, keys, values): self.values = values def eval(self, row, schema): - return dict( - zip(self.keys.eval(row, schema), self.values.eval(row, schema)) - ) + keys = self.keys.eval(row, schema) + values = self.values.eval(row, schema) + if len(keys) != len(values): + raise AnalysisException( + f"Error in '{self}': The key array and value array of MapData must have the same length." + ) + return dict(zip(keys, values)) def args(self): return ( + self.keys, + self.values + ) + def data_type(self, schema): + if not isinstance(self.keys, Column) and not self.keys: + return MapType(keyType=NullType(), valueType=NullType()) + return MapType( + keyType=self.keys[0].data_type(schema), + valueType=self.values[0].data_type(schema), + ) + class Size(UnaryExpression): pretty_name = "size" @@ -101,6 +139,9 @@ def eval(self, row, schema): f"{self.column} value should be an array or a map, got {type(column_value)}" ) + def data_type(self, schema): + return IntegerType() + class ArraySort(UnaryExpression): pretty_name = "array_sort" @@ -108,19 +149,42 @@ def eval(self, row, schema): return sorted(self.column.eval(row, schema)) + def data_type(self, schema): + return self.column.data_type(schema) + class ArrayMin(UnaryExpression): pretty_name = "array_min" def eval(self, row, schema): - return min(self.column.eval(row, schema)) + column_type = self.column.data_type(schema) + column_value = self.column.eval(row, schema) + if not isinstance(column_type, ArrayType): + raise AnalysisException( + f"Cannot resolve '{self}' due to data type mismatch: argument 1 requires array type, " + f"however, '{column_value}' is of {column_type} type." 
+ ) + return min(column_value) + + def data_type(self, schema): + return self.column.data_type(schema).elementType class ArrayMax(UnaryExpression): pretty_name = "array_max" def eval(self, row, schema): - return max(self.column.eval(row, schema)) + column_type = self.column.data_type(schema) + column_value = self.column.eval(row, schema) + if not isinstance(column_type, ArrayType): + raise AnalysisException( + f"Cannot resolve '{self}' due to data type mismatch: argument 1 requires array type, " + f"however, '{column_value}' is of {column_type} type." + ) + return max(column_value) + + def data_type(self, schema): + return self.column.data_type(schema).elementType class Slice(Expression): @@ -142,6 +206,9 @@ def args(self): self.length ) + def data_type(self, schema): + return self.column.data_type(schema) + class ArrayRepeat(Expression): pretty_name = "array_repeat" @@ -161,9 +228,12 @@ def args(self): self.count ) + def data_type(self, schema): + return ArrayType(self.col.data_type(schema)) + class Sequence(Expression): - pretty_name = "array_join" + pretty_name = "sequence" def __init__(self, start, stop, step): super().__init__(start, stop, step) @@ -201,6 +271,9 @@ def args(self): self.step ) + def data_type(self, schema): + return ArrayType(self.start.data_type(schema)) + class ArrayJoin(Expression): pretty_name = "array_join" @@ -231,6 +304,9 @@ def args(self): self.nullReplacement ) + def data_type(self, schema): + return StringType() + class SortArray(Expression): pretty_name = "sort_array" @@ -249,24 +325,32 @@ def args(self): self.asc ) + def data_type(self, schema): + return self.col.data_type(schema) + class ArraysZip(Expression): pretty_name = "arrays_zip" - def __init__(self, cols): - super().__init__(*cols) - self.cols = cols + def __init__(self, columns): + super().__init__(*columns) + self.columns = columns def eval(self, row, schema): return [ - list(combination) + dict(enumerate(combination)) for combination in zip( - *(c.eval(row, schema) for c in self.cols) + *(c.eval(row, schema) for c in self.columns) ) ] def args(self): - return self.cols + return self.columns + + def data_type(self, schema): + return ArrayType(StructType([ + StructField(str(i), col.data_type(schema), True) + for i, col in enumerate(self.columns) + ])) class Flatten(UnaryExpression): @@ -279,6 +363,9 @@ def eval(self, row, schema): for value in array ] + def data_type(self, schema): + return self.column.data_type(schema).elementType + class ArrayPosition(Expression): pretty_name = "array_position" @@ -303,6 +390,9 @@ def args(self): self.value ) + def data_type(self, schema): + return IntegerType() + class ElementAt(Expression): pretty_name = "element_at" @@ -324,6 +414,9 @@ def args(self): self.extraction ) + def data_type(self, schema): + return self.col.data_type(schema).elementType + class ArrayRemove(Expression): pretty_name = "array_remove" @@ -343,6 +436,9 @@ def args(self): self.element ) + def data_type(self, schema): + return self.col.data_type(schema) + class ArrayDistinct(UnaryExpression): pretty_name = "array_distinct" @@ -350,6 +446,9 @@ class ArrayDistinct(UnaryExpression): def eval(self, row, schema): return list(set(self.column.eval(row, schema))) + def data_type(self, schema): + return self.column.data_type(schema) + class ArrayIntersect(BinaryOperation): pretty_name = "array_intersect" @@ -357,6 +456,9 @@ class ArrayIntersect(BinaryOperation): def eval(self, row, schema): return list(set(self.arg1.eval(row, schema)) & set(self.arg2.eval(row, schema))) + def data_type(self, schema): + return self.arg1.data_type(schema) + class 
ArrayUnion(BinaryOperation): pretty_name = "array_union" @@ -364,6 +466,9 @@ class ArrayUnion(BinaryOperation): def eval(self, row, schema): return list(set(self.arg1.eval(row, schema)) | set(self.arg2.eval(row, schema))) + def data_type(self, schema): + return self.arg1.data_type(schema) + class ArrayExcept(BinaryOperation): pretty_name = "array_except" @@ -371,6 +476,9 @@ class ArrayExcept(BinaryOperation): def eval(self, row, schema): return list(set(self.arg1.eval(row, schema)) - set(self.arg2.eval(row, schema))) + def data_type(self, schema): + return self.arg1.data_type(schema) + __all__ = [ "ArraysZip", "ArrayRepeat", "Flatten", "ArrayMax", "ArrayMin", "SortArray", "Size", diff --git a/pysparkling/sql/expressions/csvs.py b/pysparkling/sql/expressions/csvs.py index 275c26e87..f2ba4644e 100644 --- a/pysparkling/sql/expressions/csvs.py +++ b/pysparkling/sql/expressions/csvs.py @@ -2,6 +2,7 @@ from ..internal_utils.options import Options from ..internal_utils.readers.csvreader import csv_record_to_row, CSVReader from ..internal_utils.readers.utils import guess_schema_from_strings +from ..types import StringType from ..utils import AnalysisException from .expressions import Expression @@ -33,3 +34,6 @@ def eval(self, row, schema): def args(self): return (self.column,) + + def data_type(self, schema): + return StringType() diff --git a/pysparkling/sql/expressions/dates.py b/pysparkling/sql/expressions/dates.py index a081099ad..f68389b30 100644 --- a/pysparkling/sql/expressions/dates.py +++ b/pysparkling/sql/expressions/dates.py @@ -5,8 +5,9 @@ from ...utils import parse_tz from ..casts import get_time_formatter, get_unix_timestamp_parser -from ..types import DateType, FloatType, TimestampType +from ..types import DateType, DoubleType, FloatType, IntegerType, LongType, StringType, TimestampType from .expressions import Expression, UnaryExpression +from .operators import Cast GMT_TIMEZONE = pytz.timezone("GMT") @@ -23,7 +24,7 @@ def __init__(self, start_date, num_months): self.timedelta = datetime.timedelta(days=self.num_months) def eval(self, row, schema): - return self.start_date.cast(DateType()).eval(row, schema) + self.timedelta + return Cast(self.start_date, DateType()).eval(row, schema) + self.timedelta def args(self): return ( @@ -31,6 +32,9 @@ def args(self): self.num_months ) + def data_type(self, schema): + return DateType() + class DateAdd(Expression): pretty_name = "date_add" @@ -42,7 +46,7 @@ def __init__(self, start_date, num_days): self.timedelta = datetime.timedelta(days=self.num_days) def eval(self, row, schema): - return self.start_date.cast(DateType()).eval(row, schema) + self.timedelta + return Cast(self.start_date, DateType()).eval(row, schema) + self.timedelta def args(self): return ( @@ -50,6 +54,9 @@ def args(self): self.num_days ) + def data_type(self, schema): + return DateType() + class DateSub(Expression): pretty_name = "date_sub" @@ -61,7 +68,7 @@ def __init__(self, start_date, num_days): self.timedelta = datetime.timedelta(days=self.num_days) def eval(self, row, schema): - return self.start_date.cast(DateType()).eval(row, schema) - self.timedelta + return Cast(self.start_date, DateType()).eval(row, schema) - self.timedelta def args(self): return ( @@ -69,89 +76,125 @@ def args(self): self.num_days ) + def data_type(self, schema): + return DateType() + class Year(UnaryExpression): pretty_name = "year" def eval(self, row, schema): - return self.column.cast(DateType()).eval(row, schema).year + return Cast(self.column, DateType()).eval(row, schema).year + + def 
data_type(self, schema): + return IntegerType() class Month(UnaryExpression): pretty_name = "month" def eval(self, row, schema): - return self.column.cast(DateType()).eval(row, schema).month + return Cast(self.column, DateType()).eval(row, schema).month + + def data_type(self, schema): + return IntegerType() class Quarter(UnaryExpression): pretty_name = "quarter" def eval(self, row, schema): - month = self.column.cast(DateType()).eval(row, schema).month + month = Cast(self.column, DateType()).eval(row, schema).month return 1 + int((month - 1) / 3) + def data_type(self, schema): + return IntegerType() + class Hour(UnaryExpression): pretty_name = "hour" def eval(self, row, schema): - return self.column.cast(TimestampType()).eval(row, schema).hour + return Cast(self.column, TimestampType()).eval(row, schema).hour + + def data_type(self, schema): + return IntegerType() class Minute(UnaryExpression): pretty_name = "minute" def eval(self, row, schema): - return self.column.cast(TimestampType()).eval(row, schema).minute + return Cast(self.column, TimestampType()).eval(row, schema).minute + + def data_type(self, schema): + return IntegerType() class Second(UnaryExpression): pretty_name = "second" def eval(self, row, schema): - return self.column.cast(TimestampType()).eval(row, schema).second + return Cast(self.column, TimestampType()).eval(row, schema).second + + def data_type(self, schema): + return IntegerType() class DayOfMonth(UnaryExpression): pretty_name = "dayofmonth" def eval(self, row, schema): - return self.column.cast(DateType()).eval(row, schema).day + return Cast(self.column, DateType()).eval(row, schema).day + + def data_type(self, schema): + return IntegerType() class DayOfYear(UnaryExpression): pretty_name = "dayofyear" def eval(self, row, schema): - value = self.column.cast(DateType()).eval(row, schema) + value = Cast(self.column, DateType()).eval(row, schema) day_from_the_first = value - datetime.date(value.year, 1, 1) return 1 + day_from_the_first.days + def data_type(self, schema): + return IntegerType() + class LastDay(UnaryExpression): pretty_name = "last_day" def eval(self, row, schema): - value = self.column.cast(DateType()).eval(row, schema) + value = Cast(self.column, DateType()).eval(row, schema) first_of_next_month = value + relativedelta(months=1, day=1) return first_of_next_month - datetime.timedelta(days=1) + def data_type(self, schema): + return DateType() + class WeekOfYear(UnaryExpression): pretty_name = "weekofyear" def eval(self, row, schema): - return self.column.cast(DateType()).eval(row, schema).isocalendar()[1] + return Cast(self.column, DateType()).eval(row, schema).isocalendar()[1] + + def data_type(self, schema): + return IntegerType() class DayOfWeek(UnaryExpression): pretty_name = "dayofweek" def eval(self, row, schema): - date = self.column.cast(DateType()).eval(row, schema) + date = Cast(self.column, DateType()).eval(row, schema) return date.isoweekday() + 1 if date.isoweekday() != 7 else 1 + def data_type(self, schema): + return IntegerType() + class NextDay(Expression): pretty_name = "next_day" @@ -162,7 +205,7 @@ def __init__(self, column, day_of_week): self.day_of_week = day_of_week.get_literal_value() def eval(self, row, schema): - value = self.column.cast(DateType()).eval(row, schema) + value = Cast(self.column, DateType()).eval(row, schema) if self.day_of_week.upper() not in DAYS_OF_WEEK: return None @@ -180,6 +223,9 @@ def args(self): self.day_of_week ) + def data_type(self, schema): + return DateType() + class MonthsBetween(Expression): 
pretty_name = "months_between" @@ -191,8 +237,8 @@ def __init__(self, column1, column2, round_off): self.round_off = round_off.get_literal_value() def eval(self, row, schema): - value_1 = self.column1.cast(TimestampType()).eval(row, schema) - value_2 = self.column2.cast(TimestampType()).eval(row, schema) + value_1 = Cast(self.column1, TimestampType()).eval(row, schema) + value_2 = Cast(self.column2, TimestampType()).eval(row, schema) if (not isinstance(value_1, datetime.datetime) or not isinstance(value_2, datetime.datetime)): @@ -227,6 +273,9 @@ def args(self): str(self.round_off).lower() ) + def data_type(self, schema): + return DoubleType() + class DateDiff(Expression): pretty_name = "datediff" @@ -237,8 +286,8 @@ def __init__(self, column1, column2): self.column2 = column2 def eval(self, row, schema): - value_1 = self.column1.cast(DateType()).eval(row, schema) - value_2 = self.column2.cast(DateType()).eval(row, schema) + value_1 = Cast(self.column1, DateType()).eval(row, schema) + value_2 = Cast(self.column2, DateType()).eval(row, schema) if (not isinstance(value_1, datetime.date) or not isinstance(value_2, datetime.date)): @@ -252,6 +301,9 @@ def args(self): self.column2 ) + def data_type(self, schema): + return IntegerType() + class FromUnixTime(Expression): pretty_name = "from_unixtime" @@ -263,7 +315,7 @@ def __init__(self, column, f): self.formatter = get_time_formatter(self.format) def eval(self, row, schema): - timestamp = self.column.cast(FloatType()).eval(row, schema) + timestamp = Cast(self.column, FloatType()).eval(row, schema) return self.formatter(datetime.datetime.fromtimestamp(timestamp)) def args(self): @@ -272,6 +324,9 @@ def args(self): self.format ) + def data_type(self, schema): + return StringType() + class DateFormat(Expression): pretty_name = "date_format" @@ -283,7 +338,7 @@ def __init__(self, column, f): self.formatter = get_time_formatter(self.format) def eval(self, row, schema): - timestamp = self.column.cast(TimestampType()).eval(row, schema) + timestamp = Cast(self.column, TimestampType()).eval(row, schema) return self.formatter(timestamp) def args(self): @@ -292,6 +347,9 @@ def args(self): self.format ) + def data_type(self, schema): + return StringType() + class CurrentTimestamp(Expression): pretty_name = "current_timestamp" @@ -310,6 +368,9 @@ def initialize(self, partition_index): def args(self): return () + def data_type(self, schema): + return TimestampType() + class CurrentDate(Expression): pretty_name = "current_date" @@ -328,6 +389,9 @@ def initialize(self, partition_index): def args(self): return () + def data_type(self, schema): + return DateType() + class UnixTimestamp(Expression): pretty_name = "unix_timestamp" @@ -348,6 +412,9 @@ def args(self): self.format ) + def data_type(self, schema): + return LongType() + class ParseToTimestamp(Expression): pretty_name = "to_timestamp" @@ -372,6 +439,9 @@ def args(self): f"'{self.format}'" ) + def data_type(self, schema): + return TimestampType() + class ParseToDate(Expression): pretty_name = "to_date" @@ -396,6 +466,9 @@ def args(self): f"'{self.format}'" ) + def data_type(self, schema): + return DateType() + class TruncDate(Expression): pretty_name = "trunc" @@ -406,7 +479,7 @@ def __init__(self, column, level): self.level = level.get_literal_value() def eval(self, row, schema): - value = self.column.cast(DateType()).eval(row, schema) + value = Cast(self.column, DateType()).eval(row, schema) if self.level in ('year', 'yyyy', 'yy'): return datetime.date(value.year, 1, 1) if self.level in ('month', 
'mon', 'mm'): @@ -419,6 +492,9 @@ def args(self): self.level ) + def data_type(self, schema): + return DateType() + class TruncTimestamp(Expression): pretty_name = "date_trunc" @@ -429,7 +505,7 @@ def __init__(self, level, column): self.column = column def eval(self, row, schema): - value = self.column.cast(TimestampType()).eval(row, schema) + value = Cast(self.column, TimestampType()).eval(row, schema) day_truncation = self.truncate_to_day(value) if day_truncation: @@ -474,6 +550,9 @@ def args(self): self.column ) + def data_type(self, schema): + return TimestampType() + class FromUTCTimestamp(Expression): pretty_name = "from_utc_timestamp" @@ -485,7 +564,7 @@ def __init__(self, column, tz): self.pytz = parse_tz(self.tz) def eval(self, row, schema): - value = self.column.cast(TimestampType()).eval(row, schema) + value = Cast(self.column, TimestampType()).eval(row, schema) if self.pytz is None: return value gmt_date = GMT_TIMEZONE.localize(value) @@ -498,6 +577,9 @@ def args(self): self.tz ) + def data_type(self, schema): + return TimestampType() + class ToUTCTimestamp(Expression): pretty_name = "to_utc_timestamp" @@ -509,7 +591,7 @@ def __init__(self, column, tz): self.pytz = parse_tz(self.tz) def eval(self, row, schema): - value = self.column.cast(TimestampType()).eval(row, schema) + value = Cast(self.column, TimestampType()).eval(row, schema) if self.pytz is None: return value local_date = self.pytz.localize(value) @@ -522,6 +604,9 @@ def args(self): self.tz ) + def data_type(self, schema): + return TimestampType() + __all__ = [ "ToUTCTimestamp", "FromUTCTimestamp", "TruncTimestamp", "TruncDate", "ParseToDate", diff --git a/pysparkling/sql/expressions/explodes.py b/pysparkling/sql/expressions/explodes.py index e70d040c5..6612dd774 100644 --- a/pysparkling/sql/expressions/explodes.py +++ b/pysparkling/sql/expressions/explodes.py @@ -1,4 +1,4 @@ -from ..types import DataType, IntegerType, StructField +from ..types import IntegerType, StructField, StructType from .expressions import UnaryExpression @@ -20,6 +20,9 @@ def eval(self, row, schema): def __str__(self): return "col" + def data_type(self, schema): + return self.column.data_type(schema).elementType + class ExplodeOuter(Explode): def eval(self, row, schema): @@ -31,6 +34,9 @@ def eval(self, row, schema): def __str__(self): return "col" + def data_type(self, schema): + return self.column.data_type(schema).elementType + class PosExplode(UnaryExpression): def eval(self, row, schema): @@ -53,9 +59,12 @@ def may_output_multiple_cols(self): def output_fields(self, schema): return [ StructField("pos", IntegerType(), False), - StructField("col", DataType(), False) + StructField("col", self.column.data_type(schema).elementType, False) ] + def data_type(self, schema): + return StructType(self.output_fields(schema)) + class PosExplodeOuter(PosExplode): def eval(self, row, schema): diff --git a/pysparkling/sql/expressions/expressions.py b/pysparkling/sql/expressions/expressions.py index 838247c27..eb51d6976 100644 --- a/pysparkling/sql/expressions/expressions.py +++ b/pysparkling/sql/expressions/expressions.py @@ -5,16 +5,16 @@ expression_registry = {} -class RegisterExpressions(type): +class RegisteredExpressions(type): pretty_name = None def __init__(cls, what, bases, dct): super().__init__(what, bases, dct) if cls.pretty_name is not None: - expression_registry[cls.pretty_name] = cls + expression_registry[cls.pretty_name.lower()] = cls -class Expression(metaclass=RegisterExpressions): +class Expression(metaclass=RegisteredExpressions): 
pretty_name = None def __init__(self, *children): @@ -36,12 +36,11 @@ def __repr__(self): def output_fields(self, schema): return [StructField( name=str(self), - dataType=self.data_type, + dataType=self.data_type(schema), nullable=self.is_nullable )] - @property - def data_type(self): + def data_type(self, schema): # pylint: disable=W0511 # todo: be more specific return DataType() diff --git a/pysparkling/sql/expressions/fields.py b/pysparkling/sql/expressions/fields.py index 23c794590..1cb8c9815 100644 --- a/pysparkling/sql/expressions/fields.py +++ b/pysparkling/sql/expressions/fields.py @@ -20,6 +20,9 @@ def output_fields(self, schema): def args(self): return (self.field,) + def data_type(self, schema): + return schema[find_position_in_schema(schema, self.field)].dataType + def find_position_in_schema(schema, expr): if isinstance(expr, str): diff --git a/pysparkling/sql/expressions/jsons.py b/pysparkling/sql/expressions/jsons.py index bc508ee4a..96de60d7f 100644 --- a/pysparkling/sql/expressions/jsons.py +++ b/pysparkling/sql/expressions/jsons.py @@ -3,6 +3,7 @@ from ...utils import get_json_encoder from ..internal_utils.options import Options from ..internal_utils.readers.jsonreader import JSONReader +from ..types import StringType from .expressions import Expression @@ -37,5 +38,8 @@ def args(self): self.input_options ) + def data_type(self, schema): + return StringType() + __all__ = ["StructsToJson"] diff --git a/pysparkling/sql/expressions/literals.py b/pysparkling/sql/expressions/literals.py index 775087ff0..e22040250 100644 --- a/pysparkling/sql/expressions/literals.py +++ b/pysparkling/sql/expressions/literals.py @@ -1,3 +1,4 @@ +from ..types import _infer_type from ..utils import AnalysisException from .expressions import Expression @@ -6,6 +7,7 @@ class Literal(Expression): def __init__(self, value): super().__init__() self.value = value + self._data_type = _infer_type(self.value) def eval(self, row, schema): return self.value @@ -28,5 +30,8 @@ def get_literal_value(self): def args(self): return (self.value, ) + def data_type(self, schema): + return self._data_type + __all__ = ["Literal"] diff --git a/pysparkling/sql/expressions/mappers.py b/pysparkling/sql/expressions/mappers.py index f862632e0..657e9d5f9 100644 --- a/pysparkling/sql/expressions/mappers.py +++ b/pysparkling/sql/expressions/mappers.py @@ -6,9 +6,13 @@ from ...utils import half_even_round, half_up_round, MonotonicallyIncreasingIDGenerator, XORShiftRandom from ..internal_utils.column import resolve_column -from ..types import create_row, StringType +from ..types import ( + ArrayType, BinaryType, BooleanType, create_row, DoubleType, IntegerType, LongType, MapType, NullType, StringType, + StructField, StructType +) from ..utils import AnalysisException from .expressions import Expression, NullSafeColumnOperation, UnaryExpression +from .operators import Cast JVM_MAX_INTEGER_SIZE = 2 ** 63 @@ -71,6 +75,9 @@ def set_otherwise(self, default): default ) + def data_type(self, schema): + return self.values[0].data_type(schema) + class Otherwise(Expression): def __init__(self, conditions, values, default): @@ -104,6 +111,9 @@ def args(self): self.default ) + def data_type(self, schema): + return self.values[0].data_type(schema) + class RegExpExtract(Expression): pretty_name = "regexp_extract" @@ -136,6 +146,9 @@ def args(self): self.groupIdx ) + def data_type(self, schema): + return StringType() + class RegExpReplace(Expression): pretty_name = "regexp_replace" @@ -160,6 +173,9 @@ def args(self): self.replacement ) + def 
data_type(self, schema): + return StringType() + class Round(NullSafeColumnOperation): pretty_name = "round" @@ -177,6 +193,9 @@ def args(self): self.scale ) + def data_type(self, schema): + return self.column.data_type(schema) + class Bround(NullSafeColumnOperation): pretty_name = "bround" @@ -194,6 +213,9 @@ def args(self): self.scale ) + def data_type(self, schema): + return self.column.data_type(schema) + class FormatNumber(Expression): pretty_name = "format_number" @@ -218,6 +240,9 @@ def args(self): self.digits ) + def data_type(self, schema): + return StringType() + class SubstringIndex(Expression): pretty_name = "substring_index" @@ -239,11 +264,14 @@ def args(self): self.count ) + def data_type(self, schema): + return StringType() + class Coalesce(Expression): pretty_name = "coalesce" - def __init__(self, columns): + def __init__(self, *columns): super().__init__(columns) self.columns = columns @@ -257,6 +285,13 @@ def eval(self, row, schema): def args(self): return self.columns + def data_type(self, schema): + for col in self.columns: + col_type = col.data_type(schema) + if col_type != NullType(): + return col_type + return NullType() + class IsNaN(UnaryExpression): pretty_name = "isnan" @@ -264,6 +299,9 @@ class IsNaN(UnaryExpression): def eval(self, row, schema): return math.isnan(self.column.eval(row, schema)) + def data_type(self, schema): + return BooleanType() + class NaNvl(Expression): pretty_name = "nanvl" @@ -286,6 +324,9 @@ def args(self): self.col2 ) + def data_type(self, schema): + return DoubleType() + class Hypot(Expression): pretty_name = "hypot" @@ -304,6 +345,9 @@ def args(self): self.b ) + def data_type(self, schema): + return DoubleType() + class Sqrt(UnaryExpression): pretty_name = "SQRT" @@ -311,6 +355,9 @@ class Sqrt(UnaryExpression): def eval(self, row, schema): return math.sqrt(self.column.eval(row, schema)) + def data_type(self, schema): + return DoubleType() + class Cbrt(UnaryExpression): pretty_name = "CBRT" @@ -318,6 +365,9 @@ class Cbrt(UnaryExpression): def eval(self, row, schema): return self.column.eval(row, schema) ** (1. / 3.)
+ def data_type(self, schema): + return DoubleType() + class Abs(UnaryExpression): pretty_name = "ABS" @@ -325,6 +375,9 @@ class Abs(UnaryExpression): def eval(self, row, schema): return abs(self.column.eval(row, schema)) + def data_type(self, schema): + return self.column.data_type(schema) + class Acos(UnaryExpression): pretty_name = "ACOS" @@ -332,6 +385,9 @@ class Acos(UnaryExpression): def eval(self, row, schema): return math.acos(self.column.eval(row, schema)) + def data_type(self, schema): + return DoubleType() + class Asin(UnaryExpression): pretty_name = "ASIN" @@ -339,6 +395,9 @@ class Asin(UnaryExpression): def eval(self, row, schema): return math.asin(self.column.eval(row, schema)) + def data_type(self, schema): + return DoubleType() + class Atan(UnaryExpression): pretty_name = "ATAN" @@ -346,6 +405,9 @@ class Atan(UnaryExpression): def eval(self, row, schema): return math.atan(self.column.eval(row, schema)) + def data_type(self, schema): + return DoubleType() + class Atan2(Expression): pretty_name = "ATAN" @@ -364,6 +426,9 @@ def args(self): self.x ) + def data_type(self, schema): + return DoubleType() + class Tan(UnaryExpression): pretty_name = "TAN" @@ -371,6 +436,9 @@ class Tan(UnaryExpression): def eval(self, row, schema): return math.tan(self.column.eval(row, schema)) + def data_type(self, schema): + return DoubleType() + class Tanh(UnaryExpression): pretty_name = "TANH" @@ -378,6 +446,9 @@ class Tanh(UnaryExpression): def eval(self, row, schema): return math.tanh(self.column.eval(row, schema)) + def data_type(self, schema): + return DoubleType() + class Cos(UnaryExpression): pretty_name = "COS" @@ -385,6 +456,9 @@ class Cos(UnaryExpression): def eval(self, row, schema): return math.cos(self.column.eval(row, schema)) + def data_type(self, schema): + return DoubleType() + class Cosh(UnaryExpression): pretty_name = "COSH" @@ -392,6 +466,9 @@ class Cosh(UnaryExpression): def eval(self, row, schema): return math.cosh(self.column.eval(row, schema)) + def data_type(self, schema): + return DoubleType() + class Sin(UnaryExpression): pretty_name = "SIN" @@ -399,6 +476,9 @@ class Sin(UnaryExpression): def eval(self, row, schema): return math.sin(self.column.eval(row, schema)) + def data_type(self, schema): + return DoubleType() + class Sinh(UnaryExpression): pretty_name = "SINH" @@ -406,6 +486,9 @@ class Sinh(UnaryExpression): def eval(self, row, schema): return math.sinh(self.column.eval(row, schema)) + def data_type(self, schema): + return DoubleType() + class Exp(UnaryExpression): pretty_name = "EXP" @@ -413,6 +496,9 @@ class Exp(UnaryExpression): def eval(self, row, schema): return math.exp(self.column.eval(row, schema)) + def data_type(self, schema): + return DoubleType() + class ExpM1(UnaryExpression): pretty_name = "EXPM1" @@ -420,6 +506,9 @@ class ExpM1(UnaryExpression): def eval(self, row, schema): return math.expm1(self.column.eval(row, schema)) + def data_type(self, schema): + return DoubleType() + class Factorial(UnaryExpression): pretty_name = "factorial" @@ -427,6 +516,9 @@ class Factorial(UnaryExpression): def eval(self, row, schema): return math.factorial(self.column.eval(row, schema)) + def data_type(self, schema): + return LongType() + class Floor(UnaryExpression): pretty_name = "FLOOR" @@ -434,6 +526,9 @@ class Floor(UnaryExpression): def eval(self, row, schema): return math.floor(self.column.eval(row, schema)) + def data_type(self, schema): + return LongType() + class Ceil(UnaryExpression): pretty_name = "CEIL" @@ -441,6 +536,9 @@ class Ceil(UnaryExpression): def 
eval(self, row, schema): return math.ceil(self.column.eval(row, schema)) + def data_type(self, schema): + return LongType() + class Log(Expression): pretty_name = "LOG" @@ -458,12 +556,15 @@ def eval(self, row, schema): def args(self): if self.base == math.e: - return (self.value, ) + return (self.value,) return ( self.base, self.value ) + def data_type(self, schema): + return DoubleType() + class Log10(UnaryExpression): pretty_name = "LOG10" @@ -471,6 +572,9 @@ class Log10(UnaryExpression): def eval(self, row, schema): return math.log10(self.column.eval(row, schema)) + def data_type(self, schema): + return DoubleType() + class Log2(UnaryExpression): pretty_name = "LOG2" @@ -478,6 +582,9 @@ class Log2(UnaryExpression): def eval(self, row, schema): return math.log(self.column.eval(row, schema), 2) + def data_type(self, schema): + return DoubleType() + class Log1p(UnaryExpression): pretty_name = "LOG1P" @@ -485,12 +592,18 @@ class Log1p(UnaryExpression): def eval(self, row, schema): return math.log1p(self.column.eval(row, schema)) + def data_type(self, schema): + return DoubleType() + class Rint(UnaryExpression): pretty_name = "ROUND" def eval(self, row, schema): - return round(self.column.eval(row, schema)) + return float(round(self.column.eval(row, schema))) + + def data_type(self, schema): + return DoubleType() class Signum(UnaryExpression): @@ -504,6 +617,9 @@ def eval(self, row, schema): return 1.0 return -1.0 + def data_type(self, schema): + return DoubleType() + class ToDegrees(UnaryExpression): pretty_name = "DEGREES" @@ -511,6 +627,9 @@ class ToDegrees(UnaryExpression): def eval(self, row, schema): return math.degrees(self.column.eval(row, schema)) + def data_type(self, schema): + return DoubleType() + class ToRadians(UnaryExpression): pretty_name = "RADIANS" @@ -518,6 +637,9 @@ class ToRadians(UnaryExpression): def eval(self, row, schema): return math.radians(self.column.eval(row, schema)) + def data_type(self, schema): + return DoubleType() + class Rand(Expression): pretty_name = "rand" @@ -536,6 +658,9 @@ def initialize(self, partition_index): def args(self): return self.seed + def data_type(self, schema): + return DoubleType() + class Randn(Expression): pretty_name = "randn" @@ -554,6 +679,9 @@ def initialize(self, partition_index): def args(self): return self.seed + def data_type(self, schema): + return DoubleType() + class SparkPartitionID(Expression): pretty_name = "SPARK_PARTITION_ID" @@ -571,11 +699,14 @@ def initialize(self, partition_index): def args(self): return () + def data_type(self, schema): + return IntegerType() + class CreateStruct(Expression): pretty_name = "struct" - def __init__(self, columns): + def __init__(self, *columns): super().__init__(columns) self.columns = columns @@ -590,6 +721,12 @@ def eval(self, row, schema): def args(self): return self.columns + def data_type(self, schema): + return StructType([ + StructField(name=str(col), dataType=col.data_type(schema), nullable=True) + for col in self.columns + ]) + class Bin(UnaryExpression): pretty_name = "bin" @@ -597,6 +734,9 @@ class Bin(UnaryExpression): def eval(self, row, schema): return format(self.column.eval(row, schema), 'b') + def data_type(self, schema): + return StringType() + class ShiftLeft(Expression): pretty_name = "shiftleft" @@ -607,7 +747,7 @@ def __init__(self, arg, num_bits): self.num_bits = num_bits.get_literal_value() def eval(self, row, schema): - return self.arg.eval(row, schema) << self.num_bits + return int(self.arg.eval(row, schema)) << self.num_bits def args(self): return ( @@ 
-615,6 +755,9 @@ def args(self): self.num_bits ) + def data_type(self, schema): + return IntegerType() + class ShiftRight(Expression): pretty_name = "shiftright" @@ -625,7 +768,7 @@ def __init__(self, arg, num_bits): self.num_bits = num_bits.get_literal_value() def eval(self, row, schema): - return self.arg.eval(row, schema) >> self.num_bits + return int(self.arg.eval(row, schema)) >> self.num_bits def args(self): return ( @@ -633,6 +776,9 @@ def args(self): self.num_bits ) + def data_type(self, schema): + return IntegerType() + class ShiftRightUnsigned(Expression): pretty_name = "shiftrightunsigned" @@ -643,7 +789,7 @@ def __init__(self, arg, num_bits): self.num_bits = num_bits.get_literal_value() def eval(self, row, schema): - rightShifted = self.arg.eval(row, schema) >> self.num_bits + rightShifted = int(self.arg.eval(row, schema)) >> self.num_bits return rightShifted % JVM_MAX_INTEGER_SIZE def args(self): @@ -652,11 +798,14 @@ def args(self): self.num_bits ) + def data_type(self, schema): + return IntegerType() + class Greatest(Expression): pretty_name = "greatest" - def __init__(self, columns): + def __init__(self, *columns): super().__init__(columns) self.columns = columns @@ -667,11 +816,16 @@ def eval(self, row, schema): def args(self): return self.columns + def data_type(self, schema): + if not self.columns: + return NullType() + return self.columns[0].data_type(schema) + class Least(Expression): pretty_name = "least" - def __init__(self, columns): + def __init__(self, *columns): super().__init__(columns) self.columns = columns @@ -682,6 +836,11 @@ def eval(self, row, schema): def args(self): return self.columns + def data_type(self, schema): + if not self.columns: + return NullType() + return self.columns[0].data_type(schema) + class Length(UnaryExpression): pretty_name = "length" @@ -689,6 +848,9 @@ class Length(UnaryExpression): def eval(self, row, schema): return len(str(self.column.eval(row, schema))) + def data_type(self, schema): + return IntegerType() + class Lower(UnaryExpression): pretty_name = "lower" @@ -696,6 +858,9 @@ class Lower(UnaryExpression): def eval(self, row, schema): return str(self.column.eval(row, schema)).lower() + def data_type(self, schema): + return StringType() + class Upper(UnaryExpression): pretty_name = "upper" @@ -703,11 +868,14 @@ class Upper(UnaryExpression): def eval(self, row, schema): return str(self.column.eval(row, schema)).upper() + def data_type(self, schema): + return StringType() + class Concat(Expression): pretty_name = "concat" - def __init__(self, columns): + def __init__(self, *columns): super().__init__(columns) self.columns = columns @@ -717,11 +885,14 @@ def eval(self, row, schema): def args(self): return self.columns + def data_type(self, schema): + return StringType() + class ConcatWs(Expression): pretty_name = "concat_ws" - def __init__(self, sep, columns): + def __init__(self, sep, *columns): super().__init__(columns) self.sep = sep.get_literal_value() self.columns = columns @@ -731,9 +902,12 @@ def eval(self, row, schema): def args(self): if self.columns: - return [self.sep] + self.columns + return [self.sep] + list(self.columns) return [self.sep] + def data_type(self, schema): + return StringType() + class Reverse(UnaryExpression): pretty_name = "reverse" @@ -741,6 +915,9 @@ class Reverse(UnaryExpression): def eval(self, row, schema): return str(self.column.eval(row, schema))[::-1] + def data_type(self, schema): + return StringType() + class MapKeys(UnaryExpression): pretty_name = "map_keys" @@ -748,6 +925,9 @@ class 
MapKeys(UnaryExpression): def eval(self, row, schema): return list(self.column.eval(row, schema).keys()) + def data_type(self, schema): + return ArrayType(elementType=self.column.data_type(schema).keyType) + class MapValues(UnaryExpression): pretty_name = "map_values" @@ -755,6 +935,9 @@ class MapValues(UnaryExpression): def eval(self, row, schema): return list(self.column.eval(row, schema).values()) + def data_type(self, schema): + return ArrayType(elementType=self.column.data_type(schema).valueType) + class MapEntries(UnaryExpression): pretty_name = "map_entries" @@ -762,6 +945,9 @@ class MapEntries(UnaryExpression): def eval(self, row, schema): return list(self.column.eval(row, schema).items()) + def data_type(self, schema): + map_type = self.column.data_type(schema) + return ArrayType(elementType=StructType([ + StructField("key", map_type.keyType, True), + StructField("value", map_type.valueType, True), + ])) + class MapFromEntries(UnaryExpression): pretty_name = "map_from_entries" @@ -769,11 +955,20 @@ class MapFromEntries(UnaryExpression): def eval(self, row, schema): return dict(self.column.eval(row, schema)) + def data_type(self, schema): + entry_type = self.column.data_type(schema).elementType + return MapType( + keyType=entry_type.fields[0].dataType, + valueType=entry_type.fields[1].dataType, + ) + class MapConcat(Expression): pretty_name = "map_concat" - def __init__(self, columns): + def __init__(self, *columns): super().__init__(*columns) self.columns = columns @@ -788,6 +983,11 @@ def eval(self, row, schema): def args(self): return self.columns + def data_type(self, schema): + if not self.columns: + return MapType(keyType=StringType(), valueType=StringType()) + return self.columns[0].data_type(schema) + class StringSplit(Expression): pretty_name = "split" @@ -815,6 +1015,9 @@ def args(self): self.limit ) + def data_type(self, schema): + return ArrayType(elementType=StringType()) + class Conv(Expression): pretty_name = "conv" @@ -826,7 +1029,7 @@ def __init__(self, column, from_base, to_base): self.to_base = to_base.get_literal_value() def eval(self, row, schema): - value = self.column.cast(StringType()).eval(row, schema) + value = Cast(self.column, StringType()).eval(row, schema) return self.convert( value, self.from_base, @@ -915,6 +1118,9 @@ def convert(from_string, from_base, to_base, positive_only=False): return returned_string + def data_type(self, schema): + return StringType() + class Hex(UnaryExpression): pretty_name = "hex" @@ -927,6 +1133,9 @@ def eval(self, row, schema): positive_only=True ) + def data_type(self, schema): + return StringType() + class Unhex(UnaryExpression): pretty_name = "unhex" @@ -939,6 +1148,9 @@ def eval(self, row, schema): positive_only=True ) + def data_type(self, schema): + return StringType() + class Ascii(UnaryExpression): pretty_name = "ascii" @@ -952,6 +1164,9 @@ def eval(self, row, schema): return None return ord(value_as_string[0]) + def data_type(self, schema): + return IntegerType() + class MonotonicallyIncreasingID(Expression): pretty_name = "monotonically_increasing_id" @@ -969,6 +1184,9 @@ def initialize(self, partition_index): def args(self): return () + def data_type(self, schema): + return LongType() + class Base64(UnaryExpression): pretty_name = "base64" @@ -978,6 +1196,9 @@ def eval(self, row, schema): encoded = base64.b64encode(bytes(value, encoding="utf-8")) return str(encoded)[2:-1] + def data_type(self, schema): + return StringType() + class UnBase64(UnaryExpression): pretty_name = "unbase64" @@ -986,11 +1207,14 @@ def eval(self, row, schema): value =
self.column.eval(row, schema) return bytearray(base64.b64decode(value)) + def data_type(self, schema): + return BinaryType() + class GroupingID(Expression): pretty_name = "grouping_id" - def __init__(self, columns): + def __init__(self, *columns): super().__init__(*columns) self.columns = columns @@ -1006,6 +1230,9 @@ def eval(self, row, schema): def args(self): return self.columns + def data_type(self, schema): + return IntegerType() + class Grouping(UnaryExpression): pretty_name = "grouping" @@ -1013,10 +1240,13 @@ class Grouping(UnaryExpression): def eval(self, row, schema): metadata = row.get_metadata() if metadata is None or "grouping" not in metadata: - raise AnalysisException("grouping_id() can only be used with GroupingSets/Cube/Rollup") + raise AnalysisException("grouping() can only be used with GroupingSets/Cube/Rollup") pos = self.column.find_position_in_schema(schema) return int(metadata["grouping"][pos]) + def data_type(self, schema): + return self.column.data_type(schema) + class InputFileName(Expression): pretty_name = "input_file_name" @@ -1030,6 +1260,9 @@ def eval(self, row, schema): def args(self): return () + def data_type(self, schema): + return StringType() + __all__ = [ "Grouping", "GroupingID", "Coalesce", "IsNaN", "MonotonicallyIncreasingID", "NaNvl", "Rand", diff --git a/pysparkling/sql/expressions/operators.py b/pysparkling/sql/expressions/operators.py index ce7c9cb85..70310b6ae 100644 --- a/pysparkling/sql/expressions/operators.py +++ b/pysparkling/sql/expressions/operators.py @@ -1,5 +1,5 @@ from ..casts import get_caster -from ..types import Row, StructType +from ..types import BooleanType, DoubleType, largest_numeric_type, Row, StringType, StructType from .expressions import BinaryOperation, Expression, NullSafeBinaryOperation, TypeSafeBinaryOperation, UnaryExpression @@ -10,6 +10,9 @@ def eval(self, row, schema): def __str__(self): return f"(- {self.column})" + def data_type(self, schema): + return self.column.data_type(schema) + class Add(NullSafeBinaryOperation): def unsafe_operation(self, value1, value2): @@ -18,6 +21,11 @@ def unsafe_operation(self, value1, value2): def __str__(self): return f"({self.arg1} + {self.arg2})" + def data_type(self, schema): + type1 = self.arg1.data_type(schema) + type2 = self.arg2.data_type(schema) + return largest_numeric_type(type1, type2, operation="add") + class Minus(NullSafeBinaryOperation): def unsafe_operation(self, value1, value2): @@ -26,6 +34,11 @@ def unsafe_operation(self, value1, value2): def __str__(self): return f"({self.arg1} - {self.arg2})" + def data_type(self, schema): + type1 = self.arg1.data_type(schema) + type2 = self.arg2.data_type(schema) + return largest_numeric_type(type1, type2, operation="minus") + class Time(NullSafeBinaryOperation): def unsafe_operation(self, value1, value2): @@ -34,6 +47,11 @@ def unsafe_operation(self, value1, value2): def __str__(self): return f"({self.arg1} * {self.arg2})" + def data_type(self, schema): + type1 = self.arg1.data_type(schema) + type2 = self.arg2.data_type(schema) + return largest_numeric_type(type1, type2, operation="multiply") + class Divide(NullSafeBinaryOperation): def unsafe_operation(self, value1, value2): @@ -42,6 +60,11 @@ def unsafe_operation(self, value1, value2): def __str__(self): return f"({self.arg1} / {self.arg2})" + def data_type(self, schema): + type1 = self.arg1.data_type(schema) + type2 = self.arg2.data_type(schema) + return largest_numeric_type(type1, type2, operation="divide") + class Mod(NullSafeBinaryOperation): def unsafe_operation(self, 
value1, value2): @@ -50,6 +73,11 @@ def unsafe_operation(self, value1, value2): def __str__(self): return f"({self.arg1} % {self.arg2})" + def data_type(self, schema): + type1 = self.arg1.data_type(schema) + type2 = self.arg2.data_type(schema) + return largest_numeric_type(type1, type2, operation="mod") + class Pow(NullSafeBinaryOperation): def unsafe_operation(self, value1, value2): @@ -58,6 +86,9 @@ def unsafe_operation(self, value1, value2): def __str__(self): return f"POWER({self.arg1}, {self.arg2})" + def data_type(self, schema): + return DoubleType() + class Equal(TypeSafeBinaryOperation): def unsafe_operation(self, value_1, value_2): @@ -66,6 +97,9 @@ def unsafe_operation(self, value_1, value_2): def __str__(self): return f"({self.arg1} = {self.arg2})" + def data_type(self, schema): + return BooleanType() + class LessThan(TypeSafeBinaryOperation): def unsafe_operation(self, value_1, value_2): @@ -74,6 +108,9 @@ def unsafe_operation(self, value_1, value_2): def __str__(self): return f"({self.arg1} < {self.arg2})" + def data_type(self, schema): + return BooleanType() + class LessThanOrEqual(TypeSafeBinaryOperation): def unsafe_operation(self, value_1, value_2): @@ -82,6 +119,9 @@ def unsafe_operation(self, value_1, value_2): def __str__(self): return f"({self.arg1} <= {self.arg2})" + def data_type(self, schema): + return BooleanType() + class GreaterThan(TypeSafeBinaryOperation): def unsafe_operation(self, value_1, value_2): @@ -90,6 +130,9 @@ def unsafe_operation(self, value_1, value_2): def __str__(self): return f"({self.arg1} > {self.arg2})" + def data_type(self, schema): + return BooleanType() + class GreaterThanOrEqual(TypeSafeBinaryOperation): def unsafe_operation(self, value_1, value_2): @@ -98,6 +141,9 @@ def unsafe_operation(self, value_1, value_2): def __str__(self): return f"({self.arg1} >= {self.arg2})" + def data_type(self, schema): + return BooleanType() + class And(TypeSafeBinaryOperation): def unsafe_operation(self, value_1, value_2): @@ -106,6 +152,9 @@ def unsafe_operation(self, value_1, value_2): def __str__(self): return f"({self.arg1} AND {self.arg2})" + def data_type(self, schema): + return BooleanType() + class Or(TypeSafeBinaryOperation): def unsafe_operation(self, value_1, value_2): @@ -114,6 +163,9 @@ def unsafe_operation(self, value_1, value_2): def __str__(self): return f"({self.arg1} OR {self.arg2})" + def data_type(self, schema): + return BooleanType() + class Invert(UnaryExpression): def eval(self, row, schema): @@ -125,6 +177,9 @@ def eval(self, row, schema): def __str__(self): return f"(NOT {self.column})" + def data_type(self, schema): + return BooleanType() + class BitwiseOr(BinaryOperation): def eval(self, row, schema): @@ -133,6 +188,11 @@ def eval(self, row, schema): def __str__(self): return f"({self.arg1} | {self.arg2})" + def data_type(self, schema): + type1 = self.arg1.data_type(schema) + type2 = self.arg2.data_type(schema) + return largest_numeric_type(type1, type2, operation="bitwise_or") + class BitwiseAnd(BinaryOperation): def eval(self, row, schema): @@ -141,6 +201,11 @@ def eval(self, row, schema): def __str__(self): return f"({self.arg1} & {self.arg2})" + def data_type(self, schema): + type1 = self.arg1.data_type(schema) + type2 = self.arg2.data_type(schema) + return largest_numeric_type(type1, type2, operation="bitwise_and") + class BitwiseXor(BinaryOperation): def eval(self, row, schema): @@ -149,6 +214,11 @@ def eval(self, row, schema): def __str__(self): return f"({self.arg1} ^ {self.arg2})" + def data_type(self, schema): + type1 = 
self.arg1.data_type(schema) + type2 = self.arg2.data_type(schema) + return largest_numeric_type(type1, type2, operation="bitwise_xor") + class BitwiseNot(UnaryExpression): def eval(self, row, schema): @@ -157,6 +227,9 @@ def eval(self, row, schema): def __str__(self): return f"~{self.column}" + def data_type(self, schema): + return self.column.data_type(schema) + class EqNullSafe(BinaryOperation): def eval(self, row, schema): @@ -165,6 +238,9 @@ def eval(self, row, schema): def __str__(self): return f"({self.arg1} <=> {self.arg2})" + def data_type(self, schema): + return BooleanType() + class GetField(Expression): def __init__(self, item, field): @@ -217,6 +293,9 @@ def args(self): self.value ) + def data_type(self, schema): + return BooleanType() + class StartsWith(Expression): pretty_name = "startswith" @@ -235,6 +314,9 @@ def args(self): self.substr ) + def data_type(self, schema): + return BooleanType() + class EndsWith(Expression): pretty_name = "endswith" @@ -253,6 +335,9 @@ def args(self): self.substr ) + def data_type(self, schema): + return BooleanType() + class IsIn(Expression): def __init__(self, arg1, cols): @@ -270,6 +355,9 @@ def __str__(self): def args(self): return [self.arg1] + self.cols + def data_type(self, schema): + return BooleanType() + class IsNotNull(UnaryExpression): def eval(self, row, schema): @@ -278,6 +366,9 @@ def eval(self, row, schema): def __str__(self): return f"({self.column} IS NOT NULL)" + def data_type(self, schema): + return BooleanType() + class IsNull(UnaryExpression): def eval(self, row, schema): @@ -286,18 +377,25 @@ def eval(self, row, schema): def __str__(self): return f"({self.column} IS NULL)" + def data_type(self, schema): + return BooleanType() + class Cast(Expression): def __init__(self, column, destination_type): super().__init__(column) self.column = column self.destination_type = destination_type - self.caster = get_caster( - from_type=self.column.data_type, to_type=destination_type, options={} - ) def eval(self, row, schema): - return self.caster(self.column.eval(row, schema)) + caster = get_caster( + from_type=self.column.data_type(schema), + to_type=self.destination_type, + options={} + ) + return caster( + self.column.eval(row, schema) + ) def __str__(self): return str(self.column) @@ -311,6 +409,9 @@ def args(self): self.destination_type ) + def data_type(self, schema): + return self.destination_type + class Substring(Expression): pretty_name = "substring" @@ -331,12 +432,15 @@ def args(self): self.length ) + def data_type(self, schema): + return StringType() + class Alias(Expression): - def __init__(self, expr, alias): + def __init__(self, expr, alias: str): super().__init__(expr, alias) self.expr = expr - self.alias = alias.get_literal_value() + self.alias = alias @property def may_output_multiple_cols(self): @@ -362,6 +466,9 @@ def args(self): self.alias ) + def data_type(self, schema): + return self.expr.data_type(schema) + class UnaryPositive(UnaryExpression): def eval(self, row, schema): @@ -370,6 +477,9 @@ def eval(self, row, schema): def __str__(self): return f"(+ {self.column})" + def data_type(self, schema): + return self.column.data_type(schema) + __all__ = [ "Negate", diff --git a/pysparkling/sql/expressions/orders.py b/pysparkling/sql/expressions/orders.py index 6429abe43..8432d36f4 100644 --- a/pysparkling/sql/expressions/orders.py +++ b/pysparkling/sql/expressions/orders.py @@ -17,6 +17,9 @@ def __str__(self): def args(self): return (self.column,) + def data_type(self, schema): + return self.column.data_type(schema) 
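Note on the pattern running through these hunks: `data_type` went from a schema-less property to a method that receives the bound schema, and the `Cast` hunk above shows why the caster moved from `__init__` into `eval` — the source type is only resolvable once a schema is available. A minimal, self-contained sketch of that ordering constraint follows; the `FieldRef` class, the toy `get_caster`, and the dict-as-schema are illustrative stand-ins, not the library's real API:

```python
# Toy model of the schema-aware data_type protocol used in this diff.
# FieldRef, Cast and get_caster below are illustrative stand-ins only.

def get_caster(from_type, to_type):
    # Stand-in for pysparkling.sql.casts.get_caster.
    casters = {("string", "int"): int, ("int", "string"): str}
    return casters.get((from_type, to_type), lambda value: value)

class FieldRef:
    def __init__(self, name):
        self.name = name

    def data_type(self, schema):
        # A field reference cannot know its type without a schema.
        return schema[self.name]

    def eval(self, row, schema):
        return row[self.name]

class Cast:
    def __init__(self, column, destination_type):
        # No schema is available here, so the caster cannot be built yet;
        # this mirrors why the diff moved get_caster() from __init__ to eval().
        self.column = column
        self.destination_type = destination_type

    def data_type(self, schema):
        return self.destination_type

    def eval(self, row, schema):
        caster = get_caster(self.column.data_type(schema), self.destination_type)
        return caster(self.column.eval(row, schema))

schema = {"age": "string"}
expr = Cast(FieldRef("age"), "int")  # tree is built before any schema is bound
assert expr.eval({"age": "42"}, schema) == 42
assert expr.data_type(schema) == "int"
```

Building the expression tree stays cheap under this protocol; all type resolution is deferred to the point where a schema is actually bound.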
+ class AscNullsFirst(SortOrder): sort_order = "ASC NULLS FIRST" diff --git a/pysparkling/sql/expressions/strings.py b/pysparkling/sql/expressions/strings.py index 2665d69a1..3f5311ddb 100644 --- a/pysparkling/sql/expressions/strings.py +++ b/pysparkling/sql/expressions/strings.py @@ -1,8 +1,9 @@ import string from ...utils import levenshtein_distance -from ..types import StringType +from ..types import IntegerType, StringType from .expressions import Expression, UnaryExpression +from .operators import Cast class StringTrim(UnaryExpression): @@ -11,6 +12,9 @@ class StringTrim(UnaryExpression): def eval(self, row, schema): return self.column.eval(row, schema).strip() + def data_type(self, schema): + return StringType() + class StringLTrim(UnaryExpression): pretty_name = "ltrim" @@ -18,6 +22,9 @@ class StringLTrim(UnaryExpression): def eval(self, row, schema): return self.column.eval(row, schema).lstrip() + def data_type(self, schema): + return StringType() + class StringRTrim(UnaryExpression): pretty_name = "rtrim" @@ -25,6 +32,9 @@ class StringRTrim(UnaryExpression): def eval(self, row, schema): return self.column.eval(row, schema).rstrip() + def data_type(self, schema): + return StringType() + class StringInStr(Expression): pretty_name = "instr" @@ -32,13 +42,14 @@ class StringInStr(Expression): def __init__(self, column, substr): super().__init__(column) self.column = column - self.substr = substr.get_literal_value() + self.substr = substr def eval(self, row, schema): - value = self.column.cast(StringType()).eval(row, schema) + value = Cast(self.column, StringType()).eval(row, schema) + substr_value = Cast(self.substr, StringType()).eval(row, schema) try: - return value.index(self.substr) - except IndexError: + return value.index(substr_value) + 1 + except ValueError: return 0 def args(self): @@ -47,6 +58,9 @@ def args(self): self.substr ) + def data_type(self, schema): + return IntegerType() + class StringLocate(Expression): pretty_name = "locate" @@ -58,7 +72,7 @@ def __init__(self, substr, column, pos): self.start = pos.get_literal_value() - 1 def eval(self, row, schema): - value = self.column.cast(StringType()).eval(row, schema) + value = Cast(self.column, StringType()).eval(row, schema) if self.substr not in value[self.start:]: return 0 return value.index(self.substr, self.start) + 1 @@ -75,6 +89,9 @@ def args(self): self.start ) + def data_type(self, schema): + return IntegerType() + class StringLPad(Expression): pretty_name = "lpad" @@ -86,7 +103,7 @@ def __init__(self, column, length, pad): self.pad = pad.get_literal_value() def eval(self, row, schema): - value = self.column.cast(StringType()).eval(row, schema) + value = Cast(self.column, StringType()).eval(row, schema) delta = self.length - len(value) padding = (self.pad * delta)[:delta] # Handle pad with multiple characters return f"{padding}{value}" @@ -98,6 +115,9 @@ def args(self): self.pad ) + def data_type(self, schema): + return StringType() + class StringRPad(Expression): pretty_name = "rpad" @@ -109,7 +129,7 @@ def __init__(self, column, length, pad): self.pad = pad.get_literal_value() def eval(self, row, schema): - value = self.column.cast(StringType()).eval(row, schema) + value = Cast(self.column, StringType()).eval(row, schema) delta = self.length - len(value) padding = (self.pad * delta)[:delta] # Handle pad with multiple characters return f"{value}{padding}" @@ -121,6 +141,9 @@ def args(self): self.pad ) + def data_type(self, schema): + return StringType() + class StringRepeat(Expression): pretty_name = "repeat" @@ 
-131,7 +154,7 @@ def __init__(self, column, n): self.n = n.get_literal_value() def eval(self, row, schema): - value = self.column.cast(StringType()).eval(row, schema) + value = Cast(self.column, StringType()).eval(row, schema) return value * self.n def args(self): @@ -140,6 +163,9 @@ def args(self): self.n ) + def data_type(self, schema): + return StringType() + class StringTranslate(Expression): pretty_name = "translate" @@ -157,7 +183,7 @@ def __init__(self, column, matching_string, replace_string): ) def eval(self, row, schema): - return self.column.cast(StringType()).eval(row, schema).translate(self.translation_table) + return Cast(self.column, StringType()).eval(row, schema).translate(self.translation_table) def args(self): return ( @@ -166,14 +192,20 @@ def args(self): self.replace_string ) + def data_type(self, schema): + return StringType() + class InitCap(UnaryExpression): pretty_name = "initcap" def eval(self, row, schema): - value = self.column.cast(StringType()).eval(row, schema) + value = Cast(self.column, StringType()).eval(row, schema) return " ".join(word.capitalize() for word in value.split()) + def data_type(self, schema): + return StringType() + class Levenshtein(Expression): pretty_name = "levenshtein" @@ -184,8 +216,8 @@ def __init__(self, column1, column2): self.column2 = column2 def eval(self, row, schema): - value_1 = self.column1.cast(StringType()).eval(row, schema) - value_2 = self.column2.cast(StringType()).eval(row, schema) + value_1 = Cast(self.column1, StringType()).eval(row, schema) + value_2 = Cast(self.column2, StringType()).eval(row, schema) if value_1 is None or value_2 is None: return None return levenshtein_distance(value_1, value_2) @@ -196,6 +228,9 @@ def args(self): self.column2 ) + def data_type(self, schema): + return IntegerType() + class SoundEx(UnaryExpression): pretty_name = "soundex" @@ -209,7 +244,7 @@ class SoundEx(UnaryExpression): } def eval(self, row, schema): - raw_value = self.column.cast(StringType()).eval(row, schema) + raw_value = Cast(self.column, StringType()).eval(row, schema) if raw_value is None: return None @@ -246,6 +281,9 @@ def _encode(self, letter): """ return self._soundex_mapping.get(letter) + def data_type(self, schema): + return StringType() + __all__ = [ "StringTrim", "StringTranslate", "StringRTrim", "StringRepeat", "StringRPad", diff --git a/pysparkling/sql/expressions/userdefined.py b/pysparkling/sql/expressions/userdefined.py index f9559e18f..ae5cc386a 100644 --- a/pysparkling/sql/expressions/userdefined.py +++ b/pysparkling/sql/expressions/userdefined.py @@ -18,5 +18,8 @@ def __str__(self): def args(self): return self.exprs + def data_type(self, schema): + return self.return_type + __all__ = ["UserDefinedFunction"] diff --git a/pysparkling/sql/functions.py b/pysparkling/sql/functions.py index 0ab2d0855..409d4d6db 100644 --- a/pysparkling/sql/functions.py +++ b/pysparkling/sql/functions.py @@ -355,8 +355,8 @@ def struct(*exprs): """ if len(exprs) == 1 and isinstance(exprs[0], list): exprs = exprs[0] - cols = [parse(e) for e in exprs] - return col(CreateStruct(cols)) + columns = [parse(e) for e in exprs] + return col(CreateStruct(*columns)) def array(*exprs): @@ -364,7 +364,7 @@ def array(*exprs): :rtype: Column """ columns = [parse(e) for e in exprs] - return col(ArrayColumn(columns)) + return col(ArrayColumn(*columns)) def map_from_arrays(col1, col2): @@ -412,7 +412,7 @@ def countDistinct(*exprs): :rtype: Column """ columns = [parse(e) for e in exprs] - return col(CountDistinct(columns=columns)) + return 
col(CountDistinct(*columns)) def collect_set(e): @@ -532,8 +532,8 @@ def grouping_id(*exprs): :rtype: Column """ - cols = [parse(e) for e in exprs] - return col(GroupingID(cols)) + columns = [parse(e) for e in exprs] + return col(GroupingID(*columns)) def kurtosis(e): @@ -725,9 +725,9 @@ def create_map(*exprs): """ if len(exprs) == 1 and isinstance(exprs[0], (list, set)): exprs = exprs[0] - cols = [parse(e) for e in exprs] + columns = [parse(e) for e in exprs] - return col(MapColumn(cols)) + return col(MapColumn(*columns)) def broadcast(df): @@ -743,7 +743,7 @@ def coalesce(*exprs): :rtype: Column """ columns = [parse(e) for e in exprs] - return col(Coalesce(columns)) + return col(Coalesce(*columns)) def input_file_name(): @@ -945,8 +945,8 @@ def greatest(*exprs): """ :rtype: Column """ - cols = [parse(e) for e in exprs] - return col(Greatest(cols)) + columns = [parse(e) for e in exprs] + return col(Greatest(*columns)) # noinspection PyShadowingBuiltins @@ -978,8 +978,8 @@ def least(*exprs): """ :rtype: Column """ - cols = [parse(e) for e in exprs] - return col(Least(cols)) + columns = [parse(e) for e in exprs] + return col(Least(*columns)) def log(arg1, arg2=None): @@ -1288,8 +1288,8 @@ def concat_ws(sep, *exprs): +--------------+ """ - cols = [parse(e) for e in exprs] - return col(ConcatWs(lit(sep), cols)) + columns = [parse(e) for e in exprs] + return col(ConcatWs(lit(sep), *columns)) def decode(value, charset): @@ -2177,8 +2177,8 @@ def concat(*exprs): """ :rtype: Column """ - cols = [parse(e) for e in exprs] - return col(Concat(cols)) + columns = [parse(e) for e in exprs] + return col(Concat(*columns)) def array_position(column, value): @@ -2432,7 +2432,8 @@ def arrays_zip(*exprs): """ :rtype: Column """ - return col(ArraysZip([parse(e) for e in exprs])) + columns = [parse(e) for e in exprs] + return col(ArraysZip(*columns)) def map_concat(*exprs): @@ -2453,8 +2454,8 @@ def map_concat(*exprs): [Row(map_concat(m1, m2)={1: 'a', 2: 'c', 3: 'd'})] """ - cols = [parse(e) for e in exprs] - return col(MapConcat(cols)) + columns = [parse(e) for e in exprs] + return col(MapConcat(*columns)) def from_csv(e, schema, options=None): @@ -2538,7 +2539,7 @@ def pandas_udf(f=None, returnType=None, functionType=None): raise NotImplementedError("Pysparkling does not support yet this function") -def callUDF(udfName, *cols): +def callUDF(udfName, *columns): raise NotImplementedError("Pysparkling does not support yet this function") diff --git a/pysparkling/sql/internals.py b/pysparkling/sql/internals.py index 9acefd147..fccd187bf 100644 --- a/pysparkling/sql/internals.py +++ b/pysparkling/sql/internals.py @@ -358,20 +358,21 @@ def select(self, *exprs): def select_mapper(partition_index, partition): # Initialize non deterministic functions so that they are reproducible - initialized_cols = [col.initialize(partition_index) for col in cols] - generators = [col for col in initialized_cols if col.may_output_multiple_rows] - non_generators = [col for col in initialized_cols if not col.may_output_multiple_rows] + for col in cols: + col.initialize(partition_index) + generators = [col for col in cols if col.may_output_multiple_rows] + non_generators = [col for col in cols if not col.may_output_multiple_rows] number_of_generators = len(generators) if number_of_generators > 1: raise Exception( "Only one generator allowed per select clause" - f" but found {number_of_generators}: {', '.join(generators)}" + f" but found {number_of_generators}: {', '.join(str(g) for g in generators)}" ) return 
self.get_select_output_field_lists( partition, non_generators, - initialized_cols, + cols, generators[0] if generators else None ) @@ -423,8 +424,8 @@ def filter(self, condition): condition = parse(condition) def mapper(partition_index, partition): - initialized_condition = condition.initialize(partition_index) - return (row for row in partition if initialized_condition.eval(row, self.bound_schema)) + condition.initialize(partition_index) + return (row for row in partition if condition.eval(row, self.bound_schema)) return self._with_rdd( self._rdd.mapPartitionsWithIndex(mapper), @@ -637,7 +638,7 @@ def horizontal_show(rows, cols, truncate, min_col_width): output += sep output += _generate_show_layout('|', padded_header) output += sep - output += '\n'.join(_generate_show_layout('|', row) for row in padded_rows) + output += ''.join(_generate_show_layout('|', row) for row in padded_rows) output += sep return output diff --git a/pysparkling/sql/session.py b/pysparkling/sql/session.py index babd55d99..0e19fadf5 100644 --- a/pysparkling/sql/session.py +++ b/pysparkling/sql/session.py @@ -3,6 +3,7 @@ from ..__version__ import __version__ from ..context import Context from ..rdd import RDD +from .ast.ast_to_python import parse_ddl_string from .conf import RuntimeConfig from .dataframe import DataFrame from .internals import DataFrameInternal @@ -223,7 +224,7 @@ def createDataFrame(self, data, schema=None, samplingRatio=None, verifySchema=Tr raise TypeError("data is already a DataFrame") if isinstance(schema, str): - schema = StructType.fromDDL(schema) + schema = parse_ddl_string(schema) elif isinstance(schema, (list, tuple)): # Must re-encode any unicode strings to be consistent with StructField names schema = [x.encode('utf-8') if not isinstance(x, str) else x for x in schema] diff --git a/pysparkling/sql/tests/test_casts.py b/pysparkling/sql/tests/test_casts.py index 70fb82dac..eefd0ef83 100644 --- a/pysparkling/sql/tests/test_casts.py +++ b/pysparkling/sql/tests/test_casts.py @@ -2,12 +2,10 @@ import datetime from unittest import TestCase -import pytz - from pysparkling.sql.casts import ( cast_from_none, cast_to_array, cast_to_binary, cast_to_boolean, cast_to_byte, cast_to_date, cast_to_decimal, cast_to_float, cast_to_int, cast_to_long, cast_to_map, cast_to_short, cast_to_string, cast_to_struct, - cast_to_timestamp, FloatType, identity + cast_to_timestamp, FloatType, identity, tz_diff ) from pysparkling.sql.types import ( ArrayType, BooleanType, ByteType, DataType, DateType, DecimalType, DoubleType, IntegerType, LongType, MapType, @@ -20,19 +18,6 @@ class CastTests(TestCase): maxDiff = None - def setUp(self): - # TODO: what is the behaviour of Spark when you load up something in another timezone? 
- # - # os.environ['TZ'] = 'Europe/Paris' - # - # if hasattr(time, 'tzset'): - # time.tzset() # pylint: disable=no-member - - now_here = datetime.datetime.now().astimezone() - now_utc = now_here.replace(tzinfo=pytz.utc) - - self.tz_diff = now_utc - now_here - def test_identity(self): x = object() self.assertEqual(identity(x, options=BASE_OPTIONS), x) @@ -246,13 +231,14 @@ def test_cast_timestamp_to_float_without_jump_issue(self): # there is a discrepancy in behaviours # This test is using a value for which Spark can handle the exact value # Hence the behaviour is the same in pysparkling and PySpark + input_date = datetime.datetime(2019, 8, 28, 0, 2, 40) self.assertEqual( cast_to_float( - datetime.datetime(2019, 8, 28, 0, 2, 40), + input_date, TimestampType(), options=BASE_OPTIONS ), - 1566939760.0 + self.tz_diff.seconds + 1566950560.0 - tz_diff(input_date).seconds ) def test_cast_string_to_binary(self): @@ -327,43 +313,51 @@ def test_cast_basic_string_to_timestamp(self): ) def test_cast_gmt_string_to_timestamp(self): + gmt_expectation = datetime.datetime(2019, 10, 1, 5, 40, 36) + expected = gmt_expectation + tz_diff(gmt_expectation) self.assertEqual( cast_to_timestamp( "2019-10-01T05:40:36Z", StringType(), options=BASE_OPTIONS ), - datetime.datetime(2019, 10, 1, 6, 40, 36) + self.tz_diff + expected ) def test_cast_weird_tz_string_to_timestamp(self): + gmt_expectation = datetime.datetime(2019, 10, 1, 2, 35, 36) + expected = gmt_expectation + tz_diff(gmt_expectation) self.assertEqual( cast_to_timestamp( "2019-10-01T05:40:36+3:5", StringType(), options=BASE_OPTIONS ), - datetime.datetime(2019, 10, 1, 3, 35, 36) + self.tz_diff + expected ) def test_cast_short_tz_string_to_timestamp(self): + gmt_expectation = datetime.datetime(2019, 10, 1, 2, 40, 36) + expected = gmt_expectation + tz_diff(gmt_expectation) self.assertEqual( cast_to_timestamp( "2019-10-01T05:40:36+03", StringType(), options=BASE_OPTIONS ), - datetime.datetime(2019, 10, 1, 3, 40, 36) + self.tz_diff + expected ) def test_cast_longer_tz_string_to_timestamp(self): + gmt_expectation = datetime.datetime(2019, 10, 1, 2, 40, 36) + expected = gmt_expectation + tz_diff(gmt_expectation) self.assertEqual( cast_to_timestamp( "2019-10-01T05:40:36+03:", StringType(), options=BASE_OPTIONS ), - datetime.datetime(2019, 10, 1, 3, 40, 36) + self.tz_diff + expected ) def test_cast_date_string_to_timestamp(self): @@ -422,33 +416,39 @@ def test_cast_hour_string_to_timestamp(self): ) def test_cast_bool_to_timestamp(self): + gmt_expectation = datetime.datetime(1970, 1, 1, 0, 0, 1) + expected = gmt_expectation + tz_diff(gmt_expectation) self.assertEqual( cast_to_timestamp( True, BooleanType(), options=BASE_OPTIONS ), - datetime.datetime(1970, 1, 1, 0, 0, 1) + self.tz_diff + expected ) def test_cast_int_to_timestamp(self): + gmt_expectation = datetime.datetime(1971, 1, 1, 0, 0, 0) + expected = gmt_expectation + tz_diff(gmt_expectation) self.assertEqual( cast_to_timestamp( 86400 * 365, IntegerType(), options=BASE_OPTIONS ), - datetime.datetime(1971, 1, 1, 0, 0, 0) + self.tz_diff + expected ) def test_cast_decimal_to_timestamp(self): + gmt_expectation = datetime.datetime(1970, 1, 1, 0, 2, 27, 580000) + expected = gmt_expectation + tz_diff(gmt_expectation) self.assertEqual( cast_to_timestamp( 147.58, DecimalType(), options=BASE_OPTIONS ), - datetime.datetime(1970, 1, 1, 0, 2, 27, 580000) + self.tz_diff + expected ) def test_cast_date_to_decimal(self): @@ -463,14 +463,16 @@ def test_cast_date_to_decimal(self): ) def 
test_cast_timestamp_to_decimal_without_scale(self): + input_date = datetime.datetime(2019, 8, 28) + gmt_expectation = 1566950400.0 self.assertEqual( cast_to_decimal( - datetime.datetime(2019, 8, 28), + input_date, TimestampType(), DecimalType(), options=BASE_OPTIONS ), - 1566939600.0 + self.tz_diff.seconds + gmt_expectation - tz_diff(input_date).seconds ) def test_cast_timestamp_to_decimal_with_too_small_precision(self): @@ -485,14 +487,16 @@ def test_cast_timestamp_to_decimal_with_too_small_precision(self): ) def test_cast_timestamp_to_decimal_with_scale(self): + input_date = datetime.datetime(2019, 8, 28) + gmt_expectation = 1566950400.0 self.assertEqual( cast_to_decimal( - datetime.datetime(2019, 8, 28), + input_date, TimestampType(), DecimalType(precision=11, scale=1), options=BASE_OPTIONS ), - 1566939600.0 + self.tz_diff.seconds + gmt_expectation - tz_diff(input_date).seconds ) def test_cast_float_to_decimal_with_scale(self): diff --git a/pysparkling/sql/tests/test_write.py b/pysparkling/sql/tests/test_write.py index 88ce6c4de..896479c25 100644 --- a/pysparkling/sql/tests/test_write.py +++ b/pysparkling/sql/tests/test_write.py @@ -6,6 +6,7 @@ from dateutil.tz import tzlocal from pysparkling import Context, Row +from pysparkling.sql.casts import tz_diff from pysparkling.sql.session import SparkSession from pysparkling.sql.utils import AnalysisException @@ -23,6 +24,21 @@ def get_folder_content(folder_path): return folder_content +def format_as_offset(delta: datetime.timedelta): + """ + Format a timedelta as a timezone offset string, e.g. "+01:00" + """ + if delta.days < 0: + sign = "-" + delta = -delta + else: + sign = "+" + hours = delta.seconds // 3600 + minutes = (delta.seconds % 3600) // 60 + formatted = f'{sign}{hours:02d}:{minutes:02d}' + return formatted + + class DataFrameWriterTests(TestCase): maxDiff = None @@ -34,16 +50,17 @@ def clean(): def setUp(self): self.clean() - tz = datetime.datetime.now().astimezone().strftime('%z') # +0100 - self.tz = f'{tz[:3]}:{tz[3:]}' # --> +01:00 - def tearDown(self): self.clean() def test_write_to_csv(self): + alice_time = datetime.datetime(2017, 1, 1, tzinfo=tzlocal()) + bob_time = datetime.datetime(2014, 3, 2, tzinfo=tzlocal()) + alice_tz = format_as_offset(tz_diff(alice_time)) + bob_tz = format_as_offset(tz_diff(bob_time)) df = spark.createDataFrame( - [Row(age=2, name='Alice', time=datetime.datetime(2017, 1, 1, tzinfo=tzlocal()), ), - Row(age=5, name='Bob', time=datetime.datetime(2014, 3, 2, tzinfo=tzlocal()))] + [Row(age=2, name='Alice', time=alice_time), + Row(age=5, name='Bob', time=bob_time)] ) df.write.csv(".tmp/wonderland/") self.assertDictEqual( @@ -51,8 +68,8 @@ def test_write_to_csv(self): { '_SUCCESS': [], 'part-00000-8447389540241120843.csv': [ - f'2,Alice,2017-01-01T00:00:00.000{self.tz}\n', - f'5,Bob,2014-03-02T00:00:00.000{self.tz}\n' + f'2,Alice,2017-01-01T00:00:00.000{alice_tz}\n', + f'5,Bob,2014-03-02T00:00:00.000{bob_tz}\n' ] } ) @@ -98,9 +115,13 @@ def test_write_to_csv_fail_when_overwrite(self): ) def test_write_to_json(self): + alice_time = datetime.datetime(2017, 1, 1, tzinfo=tzlocal()) + bob_time = datetime.datetime(2014, 3, 2, tzinfo=tzlocal()) + alice_tz = format_as_offset(tz_diff(alice_time)) + bob_tz = format_as_offset(tz_diff(bob_time)) df = spark.createDataFrame( - [Row(age=2, name='Alice', time=datetime.datetime(2017, 1, 1, tzinfo=tzlocal()), ), - Row(age=5, name='Bob', time=datetime.datetime(2014, 3, 2, tzinfo=tzlocal()))] + [Row(age=2, name='Alice', time=alice_time), + Row(age=5, name='Bob', 
 class DataFrameWriterTests(TestCase):
     maxDiff = None
 
@@ -34,16 +50,17 @@ def clean():
 
     def setUp(self):
         self.clean()
-        tz = datetime.datetime.now().astimezone().strftime('%z')  # +0100
-        self.tz = f'{tz[:3]}:{tz[3:]}'  # --> +01:00
-
     def tearDown(self):
         self.clean()
 
     def test_write_to_csv(self):
+        alice_time = datetime.datetime(2017, 1, 1, tzinfo=tzlocal())
+        bob_time = datetime.datetime(2014, 3, 2, tzinfo=tzlocal())
+        alice_tz = format_as_offset(tz_diff(alice_time))
+        bob_tz = format_as_offset(tz_diff(bob_time))
         df = spark.createDataFrame(
-            [Row(age=2, name='Alice', time=datetime.datetime(2017, 1, 1, tzinfo=tzlocal()), ),
-             Row(age=5, name='Bob', time=datetime.datetime(2014, 3, 2, tzinfo=tzlocal()))]
+            [Row(age=2, name='Alice', time=alice_time),
+             Row(age=5, name='Bob', time=bob_time)]
         )
         df.write.csv(".tmp/wonderland/")
         self.assertDictEqual(
@@ -51,8 +68,8 @@ def test_write_to_csv(self):
             {
                 '_SUCCESS': [],
                 'part-00000-8447389540241120843.csv': [
-                    f'2,Alice,2017-01-01T00:00:00.000{self.tz}\n',
-                    f'5,Bob,2014-03-02T00:00:00.000{self.tz}\n'
+                    f'2,Alice,2017-01-01T00:00:00.000{alice_tz}\n',
+                    f'5,Bob,2014-03-02T00:00:00.000{bob_tz}\n'
                 ]
             }
         )
@@ -98,9 +115,13 @@ def test_write_to_csv_fail_when_overwrite(self):
         )
 
     def test_write_to_json(self):
+        alice_time = datetime.datetime(2017, 1, 1, tzinfo=tzlocal())
+        bob_time = datetime.datetime(2014, 3, 2, tzinfo=tzlocal())
+        alice_tz = format_as_offset(tz_diff(alice_time))
+        bob_tz = format_as_offset(tz_diff(bob_time))
         df = spark.createDataFrame(
-            [Row(age=2, name='Alice', time=datetime.datetime(2017, 1, 1, tzinfo=tzlocal()), ),
-             Row(age=5, name='Bob', time=datetime.datetime(2014, 3, 2, tzinfo=tzlocal()))]
+            [Row(age=2, name='Alice', time=alice_time),
+             Row(age=5, name='Bob', time=bob_time)]
         )
         df.write.json(".tmp/wonderland/")
         self.assertDictEqual(
@@ -108,8 +129,8 @@ def test_write_to_json(self):
             {
                 '_SUCCESS': [],
                 'part-00000-8447389540241120843.json': [
-                    f'{{"age":2,"name":"Alice","time":"2017-01-01T00:00:00.000{self.tz}"}}\n',
-                    f'{{"age":5,"name":"Bob","time":"2014-03-02T00:00:00.000{self.tz}"}}\n',
+                    f'{{"age":2,"name":"Alice","time":"2017-01-01T00:00:00.000{alice_tz}"}}\n',
+                    f'{{"age":5,"name":"Bob","time":"2014-03-02T00:00:00.000{bob_tz}"}}\n',
                 ],
             }
         )
diff --git a/pysparkling/sql/types.py b/pysparkling/sql/types.py
index 3b6574e05..5447b52b9 100644
--- a/pysparkling/sql/types.py
+++ b/pysparkling/sql/types.py
@@ -24,7 +24,9 @@
 import re
 import sys
 
-from .utils import ParseException, require_minimum_pandas_version
+from sqlparser.internalparser import SqlParsingError
+
+from .utils import AnalysisException, require_minimum_pandas_version
 
 __all__ = [
     "DataType", "NullType", "StringType", "BinaryType", "BooleanType", "DateType",
@@ -197,6 +199,9 @@ class DecimalType(FractionalType):
     :param scale: the number of digits on right side of dot. (default: 0)
     """
 
+    _MAX_PRECISION = 38
+    _MINIMUM_ADJUSTED_SCALE = 6
+
     def __init__(self, precision=10, scale=0):
         self.precision = precision
         self.scale = scale
@@ -211,6 +216,19 @@ def jsonValue(self):
     def __repr__(self):
         return "DecimalType(%d,%d)" % (self.precision, self.scale)
 
+    @classmethod
+    def adjust_precision_scale(cls, precision, scale):
+        if precision <= cls._MAX_PRECISION:
+            return cls(precision, scale)
+        if scale < 0:
+            return cls(cls._MAX_PRECISION, scale)
+
+        current_digits = precision - scale
+        min_scale_value = min(scale, cls._MINIMUM_ADJUSTED_SCALE)
+        adjusted_scale = max(cls._MAX_PRECISION - current_digits, min_scale_value)
+
+        return cls(cls._MAX_PRECISION, adjusted_scale)
+
 
 class DoubleType(FractionalType):
     """Double data type, representing double precision floats.
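The adjustment follows Spark's rule of capping precision at 38 digits while preserving at least `min(scale, 6)` digits of scale when truncating. A few cases, worked out by hand from the code above (assuming `DecimalType` from this diff is in scope):

```python
# Fits within 38 digits: returned unchanged.
DecimalType.adjust_precision_scale(20, 4)   # DecimalType(20,4)
# Overflow: the 30 integral digits are preserved, scale shrinks to 38 - 30 = 8.
DecimalType.adjust_precision_scale(40, 10)  # DecimalType(38,8)
# Heavy overflow: scale is floored at min(scale, 6) = 6,
# so the integral part loses digits instead.
DecimalType.adjust_precision_scale(45, 10)  # DecimalType(38,6)
```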
@@ -553,19 +571,20 @@ def simpleString(self):
 
     def treeString(self):
         """
-        >>> schema = StructType.fromDDL('some_str: string, some_int: integer, some_date: date')
+        >>> from pysparkling.sql.ast.ast_to_python import parse_ddl_string
+        >>> schema = parse_ddl_string('some_str: string, some_int: integer, some_date: date')
         >>> print(schema.treeString())
         |-- some_str: string (nullable = true)
         |-- some_int: integer (nullable = true)
         |-- some_date: date (nullable = true)
 
-        >>> schema = StructType.fromDDL('some_str: string, arr: array<string>')
+        >>> schema = parse_ddl_string('some_str: string, arr: array<string>')
         >>> print(schema.treeString())
         |-- some_str: string (nullable = true)
         |-- arr: array (nullable = true)
         |    |-- element: string (containsNull = true)
 
-        >>> schema = StructType.fromDDL('some_str: string, arr: array<array<string>>')
+        >>> schema = parse_ddl_string('some_str: string, arr: array<array<string>>')
         >>> print(schema.treeString())
         |-- some_str: string (nullable = true)
         |-- arr: array (nullable = true)
@@ -691,29 +710,6 @@ def fromInternal(self, obj):
             values = obj
         return create_row(self.names, values)
 
-    @classmethod
-    def fromDDL(cls, string):
-        def get_class(type_: str) -> DataType:
-            type_to_load = f'{type_.strip().title()}Type'
-
-            if type_to_load not in globals():
-                match = re.match(r'^\s*array\s*<(.*)>\s*$', type_, flags=re.IGNORECASE)
-                if match:
-                    return ArrayType(get_class(match.group(1)))
-
-                raise ValueError(f"Couldn't find '{type_to_load}'?")
-
-            return globals()[type_to_load]()
-
-        fields = StructType()
-
-        for description in string.split(','):
-            name, type_ = [x.strip() for x in description.split(':')]
-
-            fields.add(StructField(name.strip(), get_class(type_), True))
-
-        return fields
-
 
 class UserDefinedType(DataType):
     """User-defined type (UDT).
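`StructType.fromDDL` is superseded by the ANTLR-backed `parse_ddl_string`, which, judging by `parsed_string_to_type` further down, also covers types the old regex-based helper could not (maps, structs, parametrized decimals). Usage mirrors the doctests above:

```python
from pysparkling.sql.ast.ast_to_python import parse_ddl_string

schema = parse_ddl_string('some_str: string, some_int: integer')
print(schema.treeString())
# |-- some_str: string (nullable = true)
# |-- some_int: integer (nullable = true)
```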
@@ -846,34 +842,9 @@ def _parse_datatype_string(s):
     for :class:`IntegerType`. Since Spark 2.3, this also supports a schema in a
     DDL-formatted string and case-insensitive strings.
     """
-    raise NotImplementedError("_parse_datatype_string is not yet supported by pysparkling")
-    # pylint: disable=W0511
-    # todo: implement in pure Python the code below
-    # NB: it probably requires to use antl4r
-
-    # sc = SparkContext._active_spark_context
-    #
-    # def from_ddl_schema(type_str):
-    #     return _parse_datatype_json_string(
-    #         sc._jvm.org.apache.spark.sql.types.StructType.fromDDL(type_str).json())
-    #
-    # def from_ddl_datatype(type_str):
-    #     return _parse_datatype_json_string(
-    #         sc._jvm.org.apache.spark.sql.api.python.PythonSQLUtils.parseDataType(type_str).json())
-    #
-    # try:
-    #     # DDL format, "fieldname datatype, fieldname datatype".
-    #     return from_ddl_schema(s)
-    # except Exception as e:
-    #     try:
-    #         # For backwards compatibility, "integer", "struct<fieldname: datatype>" and etc.
-    #         return from_ddl_datatype(s)
-    #     except:
-    #         try:
-    #             # For backwards compatibility, "fieldname: datatype, fieldname: datatype" case.
-    #             return from_ddl_datatype("struct<%s>" % s.strip())
-    #         except:
-    #             raise e
+    # pylint: disable=import-outside-toplevel, cyclic-import
+    from pysparkling.sql.ast.ast_to_python import parse_ddl_string
+    return parse_ddl_string(s)
 
 
 def _parse_datatype_json_string(json_string):
@@ -1138,6 +1109,56 @@ def _merge_type(a, b, name=None):
     return a
 
 
+INTEGRAL_TYPES_ORDER = (
+    # DecimalType, FloatType, DoubleType,
+    ByteType, ShortType, IntegerType, LongType
+)
+
+
+def largest_numeric_type(a, b, operation):
+    if a == b:
+        return a
+    if not isinstance(a, NumericType) or not isinstance(b, NumericType):
+        raise AnalysisException(f"Expected two numeric types, got {a} and {b}")
+    a_is_fractional = isinstance(a, FractionalType)
+    b_is_fractional = isinstance(b, FractionalType)
+    if not a_is_fractional and not b_is_fractional:
+        a_index = INTEGRAL_TYPES_ORDER.index(a)
+        b_index = INTEGRAL_TYPES_ORDER.index(b)
+        return INTEGRAL_TYPES_ORDER[max(a_index, b_index)]
+    if a_is_fractional and not b_is_fractional:
+        return a
+    if not a_is_fractional and b_is_fractional:
+        return b
+    if isinstance(a, (FloatType, DoubleType)) and isinstance(b, (FloatType, DoubleType)):
+        return DoubleType()
+    if isinstance(a, DecimalType) and isinstance(b, DecimalType):
+        return merge_decimal_types(a.precision, a.scale, b.precision, b.scale, operation)
+    raise AnalysisException(f"Unable to merge types {a} and {b}")
+
+
+def merge_decimal_types(p1, s1, p2, s2, operation):
+    """
+    Compute the result type of an arithmetic operation on two decimals,
+    given their precisions (p1, p2) and scales (s1, s2).
+    """
+    if operation in ("add", "minus"):
+        result_scale = max(s1, s2)
+        return DecimalType.adjust_precision_scale(max(p1 - s1, p2 - s2) + result_scale + 1, result_scale)
+    if operation == "multiply":
+        return DecimalType.adjust_precision_scale(p1 + p2 + 1, s1 + s2)
+    if operation == "divide":
+        result_scale = max(6, s1 + p2 + 1)
+        return DecimalType.adjust_precision_scale(p1 - s1 + s2 + result_scale, result_scale)
+    if operation == "mod":
+        result_scale = max(s1, s2)
+        return DecimalType.adjust_precision_scale(min(p1 - s1, p2 - s2) + result_scale, result_scale)
+    if operation in ("bitwise_or", "bitwise_and", "bitwise_xor"):
+        if (p1, s1) != (p2, s2):
+            raise AnalysisException("data type mismatch: differing types")
+        return DecimalType.adjust_precision_scale(p1, s1)
+    raise ValueError(f"Unknown operation {operation}")
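Worked examples of the decimal result-type rules above, following Spark's decimal-precision semantics; the results are computed by hand from the code (assuming `merge_decimal_types` is in scope):

```python
# add: result_scale = max(2, 4) = 4; precision = max(10 - 2, 5 - 4) + 4 + 1 = 13
merge_decimal_types(10, 2, 5, 4, "add")       # DecimalType(13,4)
# multiply: precision = 10 + 5 + 1 = 16; scale = 2 + 4 = 6
merge_decimal_types(10, 2, 5, 4, "multiply")  # DecimalType(16,6)
# divide: result_scale = max(6, 2 + 5 + 1) = 8; precision = 10 - 2 + 4 + 8 = 20
merge_decimal_types(10, 2, 5, 4, "divide")    # DecimalType(20,8)
```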
+
+
 def _need_converter(dataType):
     if isinstance(dataType, StructType):
         return True
@@ -1840,31 +1861,47 @@ def _check_series_convert_timestamps_tz_local(s, timezone):
     byte=ByteType(),
     smallint=ShortType(),
     short=ShortType(),
-    int=LongType(),
-    integer=LongType(),
+    int=IntegerType(),
+    integer=IntegerType(),
     bigint=LongType(),
     long=LongType(),
     float=FloatType(),
+    real=FloatType(),
     double=DoubleType(),
     date=DateType(),
     timestamp=TimestampType(),
     string=StringType(),
     binary=BinaryType(),
-    decimal=DecimalType()
+    decimal=DecimalType(),
+    dec=DecimalType(),
+    numeric=DecimalType(),
+    struct=StructType(),
 )
 
 
-def string_to_type(string):
-    if string in STRING_TO_TYPE:
-        return STRING_TO_TYPE[string]
-    if string.startswith("decimal("):
-        arguments = string[8:-1]
-        if arguments.count(",") == 1:
-            precision, scale = arguments.split(",")
+def parsed_string_to_type(data_type, arguments):
+    data_type = data_type.lower()
+    if not arguments and data_type in STRING_TO_TYPE:
+        return STRING_TO_TYPE[data_type]
+    if data_type in ("dec", "decimal"):
+        if len(arguments) == 2:
+            precision, scale = arguments
+        elif len(arguments) == 1:
+            precision, scale = arguments[0], 0
         else:
-            precision, scale = arguments, 0
+            raise SqlParsingError("Unrecognized decimal parameters: {0}".format(arguments))
         return DecimalType(precision=int(precision), scale=int(scale))
-    raise ParseException(f"Unable to parse data type {string}")
+    if data_type == "array" and len(arguments) == 1:
+        return ArrayType(arguments[0])
+    if data_type == "map" and len(arguments) == 2:
+        return MapType(arguments[0], arguments[1])
+    if data_type == "struct" and len(arguments) == 1 and all(len(arg) >= 2 for arg in arguments[0]):
+        return StructType([StructField(*arg) for arg in arguments[0]])
+    if data_type in ("char", "varchar") and len(arguments) == 1:
+        return StringType()
+    raise SqlParsingError(
+        "Unable to parse data type {0}{1}".format(data_type, arguments if arguments else "")
+    )
 
 
 # Internal type hierarchy:
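Unlike the old `string_to_type`, which re-parsed raw strings such as `"decimal(10,2)"`, the new `parsed_string_to_type` receives a type name plus already-parsed arguments from the AST layer. Illustrative calls below; the argument values are assumptions about what the parser hands over (it may pass strings rather than type objects):

```python
parsed_string_to_type("string", [])              # StringType()
parsed_string_to_type("decimal", ["10", "2"])    # DecimalType(10,2)
parsed_string_to_type("array", [IntegerType()])  # ArrayType(IntegerType())
```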
diff --git a/pysparkling/sql/utils.py b/pysparkling/sql/utils.py
index d5dcfcd15..ac08f7335 100644
--- a/pysparkling/sql/utils.py
+++ b/pysparkling/sql/utils.py
@@ -6,10 +6,6 @@ class AnalysisException(CapturedException):
     pass
 
 
-class ParseException(CapturedException):
-    pass
-
-
 class IllegalArgumentException(CapturedException):
     pass
 
diff --git a/pysparkling/tests/test_textFile.py b/pysparkling/tests/test_textFile.py
index ec43f4c64..be09f95ad 100644
--- a/pysparkling/tests/test_textFile.py
+++ b/pysparkling/tests/test_textFile.py
@@ -260,21 +260,21 @@ def test_pyspark_compatibility_txt():
     kv = Context().textFile(
         f'{LOCAL_TEST_PATH}/pyspark/key_value.txt').collect()
     print(kv)
-    assert u"('a', 1)" in kv and u"('b', 2)" in kv and len(kv) == 2
+    assert "('a', 1)" in kv and "('b', 2)" in kv and len(kv) == 2
 
 
 def test_pyspark_compatibility_bz2():
     kv = Context().textFile(
         f'{LOCAL_TEST_PATH}/pyspark/key_value.txt.bz2').collect()
     print(kv)
-    assert u"a\t1" in kv and u"b\t2" in kv and len(kv) == 2
+    assert "a\t1" in kv and "b\t2" in kv and len(kv) == 2
 
 
 def test_pyspark_compatibility_gz():
     kv = Context().textFile(
         f'{LOCAL_TEST_PATH}/pyspark/key_value.txt.gz').collect()
     print(kv)
-    assert u"a\t1" in kv and u"b\t2" in kv and len(kv) == 2
+    assert "a\t1" in kv and "b\t2" in kv and len(kv) == 2
 
 
 def test_local_regex_read():
diff --git a/setup.cfg b/setup.cfg
index fbb7bd5a2..ba9e7fdbb 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -6,7 +6,6 @@ max-line-length = 119
 [tool:pytest]
 addopts = --doctest-modules --cov=pysparkling --cov-report=html --cov-branch
 testpaths = pysparkling
-doctest_optionflags = ALLOW_UNICODE NORMALIZE_WHITESPACE
 
 [pycodestyle]
 max-line-length=119
@@ -47,4 +46,4 @@ order_by_type = False
 case_sensitive = False
 multi_line_output = 5
 force_sort_within_sections = True
-skip = versioneer.py
\ No newline at end of file
+skip = versioneer.py
diff --git a/setup.py b/setup.py
index 262d7a714..d1c869518 100644
--- a/setup.py
+++ b/setup.py
@@ -13,37 +13,42 @@
     url='https://github.com/svenkreiss/pysparkling',
 
     install_requires=[
-        'boto>=2.36.0',
-        'future>=0.15',
-        'requests>=2.6.0',
-        'pytz>=2019.3',
-        'python-dateutil>=2.8.0'
+        'boto==2.49.0',
+        'future==0.18.2',
+        'requests==2.26.0',
+        'pytz==2021.1',
+        'python-dateutil==2.8.2',
+        'pythonsqlparser==0.1.2',
     ],
     extras_require={
         'hdfs': ['hdfs>=2.0.0'],
-        'performance': ['matplotlib>=1.5.3'],
-        'streaming': ['tornado>=4.3'],
+        'performance': ['matplotlib==3.3.4'],
+        'streaming': ['tornado==6.1'],
+        'dev': [
+            'antlr4-python3-runtime==4.7.1',
+        ],
         'sql': [
-            'numpy',
-            'pandas>=0.23.2',
+            'numpy==1.19.5',
+            'pandas==1.1.5',
         ],
         'tests': [
             'backports.tempfile==1.0rc1',
-            'cloudpickle>=0.1.0',
-            'futures>=3.0.1',
-            'pylint',
-            'pylzma',
-            'memory-profiler>=0.47',
-            'pycodestyle',
-            'pytest',
-            'pytest-cov',
-            'isort',
-            'tornado>=4.3',
+            'cloudpickle==1.6.0',
+            'futures==3.1.1',
+            'pylint==2.10.2',
+            'pylzma==0.5.0',
+            'memory-profiler==0.58.0',
+            'pycodestyle==2.7.0',
+            'pytest==6.2.4',
+            'pytest-cov==2.12.1',
+            'isort==5.9.3',
+            'tornado==6.1',
+            'parameterized==0.7.4',
         ],
         'scripts': [
-            'ipyparallel',
-            'pyspark',
-            'matplotlib',
+            'ipyparallel==6.3.0',
+            'pyspark==3.1.2',
+            'matplotlib==3.3.4',
         ]
     },
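With the extras pinned as above, a typical development install from a local checkout would look like this (standard pip extras syntax, shown for illustration):

```
pip install -e ".[sql,tests,dev]"
```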