From bedcc95f95417c87111326de1c13c99e6c0221ce Mon Sep 17 00:00:00 2001 From: Matthew Nibecker Date: Thu, 16 Nov 2023 15:03:17 -0700 Subject: [PATCH] semantic check: identify boolean udfs (#4886) Update the semantic analyzer isBool function to identify user-defined functions that return a boolean value. --- compiler/semantic/op.go | 12 ++++++++---- compiler/ztests/udf-implied-where.yaml | 14 ++++++++++++++ 2 files changed, 22 insertions(+), 4 deletions(-) create mode 100644 compiler/ztests/udf-implied-where.yaml diff --git a/compiler/semantic/op.go b/compiler/semantic/op.go index 721b5120ac..7a041e6646 100644 --- a/compiler/semantic/op.go +++ b/compiler/semantic/op.go @@ -1014,7 +1014,7 @@ func (a *analyzer) semOpExpr(e ast.Expr, seq dag.Seq) (dag.Seq, error) { if err != nil { return nil, err } - if isBool(out) { + if a.isBool(out) { return append(seq, dag.NewFilter(out)), nil } return append(seq, &dag.Yield{ @@ -1023,12 +1023,12 @@ func (a *analyzer) semOpExpr(e ast.Expr, seq dag.Seq) (dag.Seq, error) { }), nil } -func isBool(e dag.Expr) bool { +func (a *analyzer) isBool(e dag.Expr) bool { switch e := e.(type) { case *dag.Literal: return e.Value == "true" || e.Value == "false" case *dag.UnaryExpr: - return isBool(e.Operand) + return a.isBool(e.Operand) case *dag.BinaryExpr: switch e.Op { case "and", "or", "in", "==", "!=", "<", "<=", ">", ">=": @@ -1037,8 +1037,12 @@ func isBool(e dag.Expr) bool { return false } case *dag.Conditional: - return isBool(e.Then) && isBool(e.Else) + return a.isBool(e.Then) && a.isBool(e.Else) case *dag.Call: + // If udf recurse to inner expression. + if f, _ := a.scope.LookupExpr(e.Name); f != nil { + return a.isBool(f.(*dag.Func).Expr) + } if e.Name == "cast" { if len(e.Args) != 2 { return false diff --git a/compiler/ztests/udf-implied-where.yaml b/compiler/ztests/udf-implied-where.yaml new file mode 100644 index 0000000000..63bfdd15b5 --- /dev/null +++ b/compiler/ztests/udf-implied-where.yaml @@ -0,0 +1,14 @@ +script: | + zc -C -s 'func h(e): ( has(e) ) h(foo)' + +outputs: + - name: stdout + data: | + reader + | ( + func h(e): ( + has(e) + ) + + where h(foo) + )