Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: refactor HIR #674

Merged
merged 3 commits into from
Feb 11, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pkg/binfile/binfile.go
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ func (p *Header) IsCompatible() bool {
// matter what version, we should always have the ZKBINARY identifier first,
// followed by a GOB encoding of the header. What follows after that, however,
// is determined by the major version.
const BINFILE_MAJOR_VERSION uint16 = 2
const BINFILE_MAJOR_VERSION uint16 = 3

// BINFILE_MINOR_VERSION gives the minor version of the binary file format. The
// expected interpretation is that older versions are compatible with newer
Expand Down
2 changes: 1 addition & 1 deletion pkg/binfile/legacy/constraint_set.go
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ func allocateRegisters(cs *constraintSet, schema *hir.Schema) map[uint]uint {
// Check whether a type constraint required or not.
if c.MustProve && col_type.AsUint() != nil {
bound := col_type.AsUint().Bound()
schema.AddRangeConstraint(c.Handle, ctx, &hir.ColumnAccess{Column: cid, Shift: 0}, bound)
schema.AddRangeConstraint(c.Handle, ctx, hir.NewColumnAccess(cid, 0), bound)
}
}
}
Expand Down
30 changes: 15 additions & 15 deletions pkg/binfile/legacy/expr.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ func (e *jsonTypedExpr) ToHir(colmap map[uint]uint, schema *hir.Schema) hir.Expr
// ToHir converts a big integer represented as a sequence of unsigned 32bit
// words into HIR constant expression.
func (e *jsonExprConst) ToHir(schema *hir.Schema) hir.Expr {
return &hir.Constant{Val: e.ToField()}
return hir.NewConst(e.ToField())
}

func (e *jsonExprConst) ToField() fr.Element {
Expand Down Expand Up @@ -122,7 +122,7 @@ func (e *jsonExprColumn) ToHir(colmap map[uint]uint, schema *hir.Schema) hir.Exp
// Determine binfile column index
cid := asColumn(e.Handle)
// Map to schema column index
return &hir.ColumnAccess{Column: colmap[cid], Shift: e.Shift}
return hir.NewColumnAccess(colmap[cid], e.Shift)
}

func (e *jsonExprFuncall) ToHir(colmap map[uint]uint, schema *hir.Schema) hir.Expr {
Expand All @@ -135,47 +135,47 @@ func (e *jsonExprFuncall) ToHir(colmap map[uint]uint, schema *hir.Schema) hir.Ex
switch e.Func {
case "Normalize":
if len(args) == 1 {
return &hir.Normalise{Arg: args[0]}
return hir.Normalise(args[0])
} else {
panic("incorrect arguments for Normalize")
}
case "VectorAdd", "Add":
return &hir.Add{Args: args}
return hir.Sum(args...)
case "VectorMul", "Mul":
return &hir.Mul{Args: args}
return hir.Product(args...)
case "VectorSub", "Sub":
return &hir.Sub{Args: args}
return hir.Subtract(args...)
case "Exp":
if len(args) != 2 {
panic(fmt.Sprintf("incorrect number of arguments for Exp (%d)", len(args)))
}

c, ok := args[1].(*hir.Constant)
c, ok := args[1].Term.(*hir.Constant)

if !ok {
panic(fmt.Sprintf("constant power expected for Exp, got %s", args[1].Lisp(schema)))
} else if !c.Val.IsUint64() {
} else if !c.Value.IsUint64() {
panic("constant power too large for Exp")
}

var k big.Int
// Convert power to uint64
c.Val.BigInt(&k)
c.Value.BigInt(&k)
// Done
return &hir.Exp{Arg: args[0], Pow: k.Uint64()}
return hir.Exponent(args[0], k.Uint64())
case "IfZero":
if len(args) == 2 {
return &hir.IfZero{Condition: args[0], TrueBranch: args[1], FalseBranch: nil}
return hir.If(args[0], args[1], hir.VOID)
} else if len(args) == 3 {
return &hir.IfZero{Condition: args[0], TrueBranch: args[1], FalseBranch: args[2]}
return hir.If(args[0], args[1], args[2])
} else {
panic(fmt.Sprintf("incorrect number of arguments for IfZero (%d)", len(args)))
}
case "IfNotZero":
if len(args) == 2 {
return &hir.IfZero{Condition: args[0], TrueBranch: nil, FalseBranch: args[1]}
return hir.If(args[0], hir.VOID, args[1])
} else if len(args) == 3 {
return &hir.IfZero{Condition: args[0], TrueBranch: args[2], FalseBranch: args[1]}
return hir.If(args[0], args[2], args[1])
} else {
panic(fmt.Sprintf("incorrect number of arguments for IfNotZero (%d)", len(args)))
}
Expand All @@ -190,7 +190,7 @@ func jsonListToHir(Args []jsonTypedExpr, colmap map[uint]uint, schema *hir.Schem
args[i] = Args[i].ToHir(colmap, schema)
}

return &hir.List{Args: args}
return hir.ListOf(args...)
}

func jsonExprsToHirUnit(Args []jsonTypedExpr, colmap map[uint]uint, schema *hir.Schema) []hir.UnitExpr {
Expand Down
2 changes: 1 addition & 1 deletion pkg/corset/compiler.go
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,7 @@ func compileSelector(env compiler.Environment, selector ast.Expr) *hir.UnitExpr
// Lookup column binding
register_id := env.RegisterOf(binding.AbsolutePath())
// Done
expr := &hir.ColumnAccess{Column: register_id, Shift: 0}
expr := hir.NewColumnAccess(register_id, 0)
//
return &hir.UnitExpr{Expr: expr}
}
Expand Down
60 changes: 30 additions & 30 deletions pkg/corset/compiler/translator.go
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ func (t *translator) translateTypeConstraints(regIndex uint) {
}
// Add appropriate type constraint
bound := regInfo.DataType.AsUint().Bound()
t.schema.AddRangeConstraint(regInfo.Name(), regInfo.Context, &hir.ColumnAccess{Column: regIndex, Shift: 0}, bound)
t.schema.AddRangeConstraint(regInfo.Name(), regInfo.Context, hir.NewColumnAccess(regIndex, 0), bound)
}
}

Expand Down Expand Up @@ -283,22 +283,22 @@ func (t *translator) translateDefConstraint(decl *ast.DefConstraint, module util
errors = append(errors, guard_errors...)
errors = append(errors, selector_errors...)
// Apply guard
if constraint == nil {
if constraint == hir.VOID {
// NOTE: in this case, the constraint itself has been translated as nil.
// This means there is no constraint (e.g. its a debug constraint, but
// debug mode is not enabled).
return errors
}
// Apply guard (if applicable)
if guard != nil {
constraint = &hir.IfZero{Condition: guard, TrueBranch: nil, FalseBranch: constraint}
if guard != hir.VOID {
constraint = hir.If(guard, hir.VOID, constraint)
}
// Apply perspective selector (if applicable)
if selector != nil {
if selector != hir.VOID {
// NOTE: using an ifnot (as above) would be preferable here. However,
// this is currently done just to ensure constraints identical to the
// original are generated.
constraint = &hir.Mul{Args: []hir.Expr{selector, constraint}}
constraint = hir.Product(selector, constraint)
}
//
if len(errors) == 0 {
Expand Down Expand Up @@ -328,7 +328,7 @@ func (t *translator) translateSelectorInModule(perspective *ast.PerspectiveName,
return t.translateExpressionInModule(perspective.InnerBinding().Selector, module, 0)
}
//
return nil, nil
return hir.VOID, nil
}

// Translate a "deflookup" declaration.
Expand Down Expand Up @@ -454,7 +454,7 @@ func (t *translator) translateOptionalExpressionInModule(expr ast.Expr, module u
return t.translateExpressionInModule(expr, module, shift)
}

return nil, nil
return hir.VOID, nil
}

// Translate an optional expression in a given context. That is an expression
Expand Down Expand Up @@ -490,10 +490,10 @@ func (t *translator) translateExpressionsInModule(exprs []ast.Expr, module util.
var errs []SyntaxError
hirExprs[i], errs = t.translateExpressionInModule(e, module, shift)
errors = append(errors, errs...)
// Check for non-voidability
if hirExprs[i] == nil {
errors = append(errors, *t.srcmap.SyntaxError(e, "void expression not permitted here"))
}
} else {
// Strictly speaking, this assignment is unnecessary. However, the
// purpose is just to make it clear what's going on.
hirExprs[i] = hir.VOID
}
}
//
Expand All @@ -509,47 +509,47 @@ func (t *translator) translateExpressionInModule(expr ast.Expr, module util.Path
// Lookup underlying column info
registerId, errors := t.registerOfColumnAccess(e)
// Done
return &hir.ColumnAccess{Column: registerId, Shift: shift}, errors
return hir.NewColumnAccess(registerId, shift), errors
case *ast.Add:
args, errs := t.translateExpressionsInModule(e.Args, module, shift)
return &hir.Add{Args: args}, errs
return hir.Sum(args...), errs
case *ast.Constant:
var val fr.Element
// Initialise field from bigint
val.SetBigInt(&e.Val)
//
return &hir.Constant{Val: val}, nil
return hir.NewConst(val), nil
case *ast.Exp:
return t.translateExpInModule(e, module, shift)
case *ast.If:
args, errs := t.translateExpressionsInModule([]ast.Expr{e.Condition, e.TrueBranch, e.FalseBranch}, module, shift)
// Construct appropriate if form
if e.IsIfZero() {
return &hir.IfZero{Condition: args[0], TrueBranch: args[1], FalseBranch: args[2]}, errs
return hir.If(args[0], args[1], args[2]), errs
} else if e.IsIfNotZero() {
// In this case, switch the ordering.
return &hir.IfZero{Condition: args[0], TrueBranch: args[2], FalseBranch: args[1]}, errs
return hir.If(args[0], args[2], args[1]), errs
}
// Should be unreachable
return nil, t.srcmap.SyntaxErrors(expr, "unresolved conditional encountered during translation")
return hir.VOID, t.srcmap.SyntaxErrors(expr, "unresolved conditional encountered during translation")
case *ast.List:
args, errs := t.translateExpressionsInModule(e.Args, module, shift)
return &hir.List{Args: args}, errs
return hir.ListOf(args...), errs
case *ast.Mul:
args, errs := t.translateExpressionsInModule(e.Args, module, shift)
return &hir.Mul{Args: args}, errs
return hir.Product(args...), errs
case *ast.Normalise:
arg, errs := t.translateExpressionInModule(e.Arg, module, shift)
return &hir.Normalise{Arg: arg}, errs
return hir.Normalise(arg), errs
case *ast.Sub:
args, errs := t.translateExpressionsInModule(e.Args, module, shift)
return &hir.Sub{Args: args}, errs
return hir.Subtract(args...), errs
case *ast.Shift:
return t.translateShiftInModule(e, module, shift)
case *ast.VariableAccess:
return t.translateVariableAccessInModule(e, shift)
default:
return nil, t.srcmap.SyntaxErrors(expr, "unknown expression encountered during translation")
return hir.VOID, t.srcmap.SyntaxErrors(expr, "unknown expression encountered during translation")
}
}

Expand All @@ -564,19 +564,19 @@ func (t *translator) translateExpInModule(expr *ast.Exp, module util.Path, shift
}
// Sanity check errors
if len(errs) == 0 {
return &hir.Exp{Arg: arg, Pow: pow.Uint64()}, errs
return hir.Exponent(arg, pow.Uint64()), errs
}
//
return nil, errs
return hir.VOID, errs
}

func (t *translator) translateShiftInModule(expr *ast.Shift, module util.Path, shift int) (hir.Expr, []SyntaxError) {
constant := expr.Shift.AsConstant()
// Determine the shift constant
if constant == nil {
return nil, t.srcmap.SyntaxErrors(expr.Shift, "expected constant shift")
return hir.VOID, t.srcmap.SyntaxErrors(expr.Shift, "expected constant shift")
} else if !constant.IsInt64() {
return nil, t.srcmap.SyntaxErrors(expr.Shift, "constant shift too large")
return hir.VOID, t.srcmap.SyntaxErrors(expr.Shift, "constant shift too large")
}
// Now translate target expression with updated shift.
return t.translateExpressionInModule(expr.Arg, module, shift+int(constant.Int64()))
Expand All @@ -587,17 +587,17 @@ func (t *translator) translateVariableAccessInModule(expr *ast.VariableAccess, s
// Lookup column binding
register_id := t.env.RegisterOf(binding.AbsolutePath())
// Done
return &hir.ColumnAccess{Column: register_id, Shift: shift}, nil
return hir.NewColumnAccess(register_id, shift), nil
} else if binding, ok := expr.Binding().(*ast.ConstantBinding); ok {
// Just fill in the constant.
var constant fr.Element
// Initialise field from bigint
constant.SetBigInt(binding.Value.AsConstant())
//
return &hir.Constant{Val: constant}, nil
return hir.NewConst(constant), nil
}
// error
return nil, t.srcmap.SyntaxErrors(expr, "unbound variable")
return hir.VOID, t.srcmap.SyntaxErrors(expr, "unbound variable")
}

// Determine the underlying register for a symbol which represents a column access.
Expand Down
Loading