diff --git a/engine/clause.go b/engine/clause.go index 34bd07d..afa1d46 100644 --- a/engine/clause.go +++ b/engine/clause.go @@ -64,19 +64,19 @@ type clause struct { } func compileClause(head Term, body Term, env *Env) (clause, error) { - var c clause - var goals []Term - - head, goals = desugar(head, goals) - body, goals = desugar(body, goals) + head, preds := desugarHead(head, env) + body = desugarBody(body, env) - if body != nil { - goals = append(goals, body) - } - if len(goals) > 0 { - body = seq(atomComma, goals...) + if len(preds) > 0 { + predSeq := seq(atomComma, preds...) + if body == nil { + body = predSeq + } else { + body = atomComma.Apply(body, predSeq) + } } + var c clause c.compileHead(head, env) if body != nil { @@ -86,22 +86,48 @@ func compileClause(head Term, body Term, env *Env) (clause, error) { } c.emit(instruction{opcode: OpExit}) + return c, nil } -func desugar(term Term, acc []Term) (Term, []Term) { - switch t := term.(type) { +func desugarHead(head Term, env *Env) (Term, []Term) { + if head, ok := env.Resolve(head).(Compound); ok { + return desugarPred(head, nil, env) + } + return head, nil +} + +func desugarBody(body Term, env *Env) Term { + if body == nil { + return body + } + + var items []Term + iter := seqIterator{Seq: body, Env: env} + for iter.Next() { + t, preds := desugarPred(iter.Current(), nil, env) + if len(preds) > 0 { + items = append(items, preds...) + } + items = append(items, t) + } + + return seq(atomComma, items...) +} + +func desugarPred(term Term, acc []Term, env *Env) (Term, []Term) { + switch t := env.Resolve(term).(type) { case charList, codeList: return t, acc case list: l := make(list, len(t)) for i, e := range t { - l[i], acc = desugar(e, acc) + l[i], acc = desugarPred(e, acc, env) } return l, acc case *partial: - c, acc := desugar(t.Compound, acc) - tail, acc := desugar(*t.tail, acc) + c, acc := desugarPred(t.Compound, acc, env) + tail, acc := desugarPred(*t.tail, acc, env) return &partial{ Compound: c.(Compound), tail: &tail, @@ -109,8 +135,8 @@ func desugar(term Term, acc []Term) (Term, []Term) { case Compound: if t.Functor() == atomSpecialDot && t.Arity() == 2 { tempV := NewVariable() - lhs, acc := desugar(t.Arg(0), acc) - rhs, acc := desugar(t.Arg(1), acc) + lhs, acc := desugarPred(t.Arg(0), acc, env) + rhs, acc := desugarPred(t.Arg(1), acc, env) return tempV, append(acc, atomDot.Apply(lhs, rhs, tempV)) } @@ -120,7 +146,7 @@ func desugar(term Term, acc []Term) (Term, []Term) { args: make([]Term, t.Arity()), } for i := 0; i < t.Arity(); i++ { - c.args[i], acc = desugar(t.Arg(i), acc) + c.args[i], acc = desugarPred(t.Arg(i), acc, env) } if _, ok := t.(Dict); ok { diff --git a/engine/text_test.go b/engine/text_test.go index 1a0cedb..3ddd593 100644 --- a/engine/text_test.go +++ b/engine/text_test.go @@ -282,6 +282,58 @@ point(point{x: 5}.x). }, }, )}, + {title: "dict head (3)", text: ` +point(point{x: 5}.x) :- true. +`, result: buildOrderedMap( + procedurePair{ + Key: procedureIndicator{name: NewAtom("foo"), arity: 1}, + Value: &userDefined{ + multifile: true, + clauses: clauses{ + { + pi: procedureIndicator{name: NewAtom("foo"), arity: 1}, + raw: &compound{functor: NewAtom("foo"), args: []Term{NewAtom("c")}}, + bytecode: bytecode{ + {opcode: OpGetConst, operand: NewAtom("c")}, + {opcode: OpExit}, + }, + }, + }, + }, + }, + procedurePair{ + Key: procedureIndicator{name: NewAtom("point"), arity: 1}, + Value: &userDefined{ + clauses: clauses{ + { + pi: procedureIndicator{name: NewAtom("point"), arity: 1}, + raw: atomIf.Apply( + &compound{functor: "point", args: []Term{ + &compound{functor: "$dot", args: []Term{ + &dict{compound: compound{functor: "dict", args: []Term{NewAtom("point"), NewAtom("x"), Integer(5)}}}, + NewAtom("x"), + }}}}, + NewAtom("true")), + vars: []Variable{lastVariable() + 1}, + bytecode: bytecode{ + {opcode: OpGetVar, operand: Integer(0)}, + {opcode: OpEnter}, + {opcode: OpCall, operand: procedureIndicator{name: atomTrue, arity: 0}}, + {opcode: OpPutDict, operand: Integer(3)}, + {opcode: OpPutConst, operand: NewAtom("point")}, + {opcode: OpPutConst, operand: NewAtom("x")}, + {opcode: OpPutConst, operand: Integer(5)}, + {opcode: OpPop}, + {opcode: OpPutConst, operand: NewAtom("x")}, + {opcode: OpPutVar, operand: Integer(0)}, + {opcode: OpCall, operand: procedureIndicator{name: atomDot, arity: Integer(3)}}, + {opcode: OpExit}, + }, + }, + }, + }, + }, + )}, {title: "dict body", text: ` p :- foo(point{x: 5}). `, result: buildOrderedMap( @@ -328,6 +380,58 @@ p :- foo(point{x: 5}). }, }, )}, + {title: "dict body (2)", text: ` +x(X) :- p(P), =(X, P.x). +`, result: buildOrderedMap( + procedurePair{ + Key: procedureIndicator{name: NewAtom("foo"), arity: 1}, + Value: &userDefined{ + multifile: true, + clauses: clauses{ + { + pi: procedureIndicator{name: NewAtom("foo"), arity: 1}, + raw: &compound{functor: NewAtom("foo"), args: []Term{NewAtom("c")}}, + bytecode: bytecode{ + {opcode: OpGetConst, operand: NewAtom("c")}, + {opcode: OpExit}, + }, + }, + }, + }, + }, + procedurePair{ + Key: procedureIndicator{name: NewAtom("x"), arity: 1}, + Value: &userDefined{ + clauses: clauses{ + { + pi: procedureIndicator{name: NewAtom("x"), arity: 1}, + raw: atomIf.Apply( + NewAtom("x").Apply(lastVariable()+1), + seq( + atomComma, + NewAtom("p").Apply(lastVariable()+2), + atomEqual.Apply(lastVariable()+1, NewAtom("$dot").Apply(lastVariable()+2, NewAtom("x"))), + )), + vars: []Variable{lastVariable() + 1, lastVariable() + 2, lastVariable() + 3}, + bytecode: bytecode{ + {opcode: OpGetVar, operand: Integer(0)}, + {opcode: OpEnter}, + {opcode: OpPutVar, operand: Integer(1)}, + {opcode: OpCall, operand: procedureIndicator{name: NewAtom("p"), arity: 1}}, + {opcode: OpPutVar, operand: Integer(1)}, + {opcode: OpPutConst, operand: NewAtom("x")}, + {opcode: OpPutVar, operand: Integer(2)}, + {opcode: OpCall, operand: procedureIndicator{name: NewAtom("."), arity: 3}}, + {opcode: OpPutVar, operand: Integer(0)}, + {opcode: OpPutVar, operand: Integer(2)}, + {opcode: OpCall, operand: procedureIndicator{name: NewAtom("="), arity: 2}}, + {opcode: OpExit}, + }, + }, + }, + }, + }, + )}, {title: "dynamic", text: ` :- dynamic(foo/1). foo(a). diff --git a/interpreter_test.go b/interpreter_test.go index db01532..2cfd796 100644 --- a/interpreter_test.go +++ b/interpreter_test.go @@ -1022,6 +1022,27 @@ func TestDict(t *testing.T) { "X": "1", }}}, }, + { + program: "ok. p(point{x:1}.x) :- ok.", + query: "p(X).", + wantResult: []result{{solutions: map[string]TermString{ + "X": "1", + }}}, + }, + { + program: "point(point{x: X}.x) :- X = 5.", + query: "point(X).", + wantResult: []result{{solutions: map[string]TermString{ + "X": "5", + }}}, + }, + { + program: "point(point{x: 5}.X) :- X = x.", + query: "point(X).", + wantResult: []result{{solutions: map[string]TermString{ + "X": "5", + }}}, + }, // access { query: "A = point{x:1,y:2}.x.", @@ -1055,6 +1076,13 @@ func TestDict(t *testing.T) { "X": "10", }}}, }, + { + program: "p(P) :- P = point{x:10, y:20}. x(X) :- p(P), X = P.x.", + query: "x(X).", + wantResult: []result{{solutions: map[string]TermString{ + "X": "10", + }}}, + }, { query: "A = point{x:1,y:2}.z.", wantResult: []result{