From 88d01bc7ec9454825788d9735a8e978f5b55d2f6 Mon Sep 17 00:00:00 2001 From: Soham Zemse <22412996+zemse@users.noreply.github.com> Date: Sat, 16 Mar 2024 01:29:00 +0000 Subject: [PATCH] work on merkle tree logic --- src/field.ts | 28 +++++++++ src/index.ts | 1 + src/merkle-tree.ts | 120 ++++++++++++++++++++++++++++++++++++ src/note-merkle-tree.ts | 128 +++++++++++++++++++++++++++++++++++++++ src/transaction.ts | 21 +++++++ test/merkle-tree.test.ts | 59 ++++++++++++++++++ 6 files changed, 357 insertions(+) create mode 100644 src/merkle-tree.ts create mode 100644 src/note-merkle-tree.ts create mode 100644 src/transaction.ts create mode 100644 test/merkle-tree.test.ts diff --git a/src/field.ts b/src/field.ts index a22cdfe..e4fb01a 100644 --- a/src/field.ts +++ b/src/field.ts @@ -50,4 +50,32 @@ export class Field { mul(other: Field): Field { return new Field((this.value * other.value) % PRIME); } + + neg(): Field { + return new Field((PRIME - this.value) % PRIME); + } + + isNeg(): boolean { + return this.value > PRIME / 2n; + } + + gt(other: Field): boolean { + return this.value > other.value; + } + + lt(other: Field): boolean { + return this.value < other.value; + } + + eq(other: Field): boolean { + return this.value === other.value; + } + + gte(other: Field): boolean { + return this.value >= other.value; + } + + lte(other: Field): boolean { + return this.value <= other.value; + } } diff --git a/src/index.ts b/src/index.ts index 5e99a19..2470fbd 100644 --- a/src/index.ts +++ b/src/index.ts @@ -3,3 +3,4 @@ export * from './keypair'; export * from './hash'; export * from './input'; export * from './utils'; +export * from './merkle-tree'; diff --git a/src/merkle-tree.ts b/src/merkle-tree.ts new file mode 100644 index 0000000..5212392 --- /dev/null +++ b/src/merkle-tree.ts @@ -0,0 +1,120 @@ +import { hash } from './hash'; +import { Field } from './field'; + +export class MerkleTree { + public elements: Field[]; + private null_root_cache: { + [null_leaf: string]: { [depth: number]: Field }; + } = {}; + + constructor(public depth: number) { + this.elements = []; + } + + toJson() { + return { + elements: this.elements.map((x) => x.hex()), + depth: this.depth, + }; + } + + static fromJson(json: any): MerkleTree { + let tree = new MerkleTree(json.depth); + tree.elements = json.elements.map((x: string) => Field.from(x)); + return tree; + } + + insert(value: Field) { + if (this.elements.length >= 2 ** this.depth) { + throw new Error('Merkle tree is full'); + } + this.elements.push(value); + } + + /** + * Calculates the root of the merkle tree + * @returns the root of the merkle tree + */ + async calculateRoot(): Promise { + if (this.elements.length === 0) { + return await this._null_root(this.depth, Field.zero()); + } + let _nodes = [...this.elements]; + for (let i = 0; i < this.depth; i++) { + let upperNodes: Field[] = []; + for (let j = 0; j < _nodes.length; j += 2) { + const left = _nodes[j]; + let right: Field; + if (j + 1 === _nodes.length) { + right = await this._null_root(i, Field.zero()); + } else { + right = _nodes[j + 1]; + } + upperNodes.push(await hash([left, right])); + } + _nodes = upperNodes; + } + return _nodes[0]; + } + + async calculateRootHex(): Promise { + return (await this.calculateRoot()).hex(); + } + + /** + * Generates a merkle proof for a given element + * @param index Index of the element to generate the merkle proof for + * @returns An array of the intermediate nodes of the merkle tree + */ + async merkleProof(index: number): Promise { + if (index > this.elements.length) { + throw new Error('Index out of range'); + } + let proof: Field[] = []; + let _nodes = [...this.elements]; + for (let i = 0; i < this.depth; i++) { + let upperNodes: Field[] = []; + for (let j = 0; j < _nodes.length; j += 2) { + // get the left side node + const left = _nodes[j]; + // calculate the right side node + let right: Field; + if (j + 1 === _nodes.length) { + right = await this._null_root(i, Field.zero()); + } else { + right = _nodes[j + 1]; + } + // include in the proof if this node is in the merkle path + if (index === j || index === j + 1) { + proof.push(index % 2 !== 0 ? left : right); + } + // generate the upper depth nodes + upperNodes.push(await hash([left, right])); + } + index = Math.floor(index / 2); + _nodes = upperNodes; + } + + return proof; + } + + async _null_root( + depth: number, + null_leaf: Field = Field.zero() + ): Promise { + if (this.null_root_cache[null_leaf.hex()] === undefined) { + this.null_root_cache[null_leaf.hex()] = {}; + } + if (this.null_root_cache[null_leaf.hex()][depth] !== undefined) { + return this.null_root_cache[null_leaf.hex()][depth]; + } + if (depth == 0) { + return null_leaf; + } + + let node = await this._null_root(depth - 1, null_leaf); + node = await hash([node, node]); + this.null_root_cache[null_leaf.hex()][depth] = node; + return node; + } +} diff --git a/src/note-merkle-tree.ts b/src/note-merkle-tree.ts new file mode 100644 index 0000000..92e3e68 --- /dev/null +++ b/src/note-merkle-tree.ts @@ -0,0 +1,128 @@ +import { BigNumberish } from 'ethers'; + +import { Field } from './field'; +import { MerkleTree } from './merkle-tree'; +import { Note } from './note'; +import { Input } from './input'; +import { KeyPair } from './keypair'; +import { Transaction } from './transaction'; + +interface CreateTransactionArgs { + inputNotes?: Note[]; + depositAmount: Field | BigNumberish; + keypair?: KeyPair; + updateTree: boolean; + withdrawAddress?: Field; +} + +export class NoteMerkleTree extends MerkleTree { + constructor( + public depth: number, + public numInputs: number = 2, + public numOutputs: number = 1 + ) { + super(depth); + this.insert( + // Note.zero().commitment() + Field.from( + 6693032976676388986107828574443457670072006098614160789085314534828627402874n + ) + ); + if (this.numOutputs !== 1) { + throw new Error('exactly 1 outputs are currently supported'); + } + } + + toJson() { + return { + elements: this.elements.map((x) => x.hex()), + depth: this.depth, + numInputs: this.numInputs, + numOutputs: this.numOutputs, + }; + } + + static fromJson(json: any): NoteMerkleTree { + let tree = new NoteMerkleTree(json.depth, json.numInputs, json.numOutputs); + tree.elements = json.elements.map((x: string) => Field.from(x)); + return tree; + } + + /** + * Spend input notes and create new notes if necessary + * Two source notes and one output note is currently supported + */ + async createTransaction({ + inputNotes, + depositAmount, + keypair, + updateTree, + withdrawAddress, + }: CreateTransactionArgs) { + depositAmount = Field.from(depositAmount); + inputNotes = inputNotes ?? []; + keypair = keypair ?? (await KeyPair.random()); + withdrawAddress = withdrawAddress ?? Field.zero(); + + let inputs: Input[] = []; + let sum = Field.zero(); + for (let note of inputNotes) { + const input = await this._createInput(note); + inputs.push(input); + sum = sum.add(input.note.amount); + } + for (let i = inputs.length; i < this.numInputs; i++) { + inputs.push(await this._createInput(Note.zero())); + } + + if (depositAmount.isNeg() && depositAmount.neg().gt(sum)) { + throw new Error('Transaction amount exceeds the sum of input notes'); + } + + const note = new Note(sum.add(depositAmount), keypair, Field.from(0)); + + const root = await this.calculateRoot(); + + if (inputs.length !== 2) { + throw new Error('exactly 2 inputs are supported'); + } + + if (updateTree) { + this.insert(await note.commitment()); + } + + const transaction = new Transaction( + this.depth, + root, + inputs, + [note], + depositAmount, + withdrawAddress + ); + return transaction; + } + + async findNoteIndex(note: Note): Promise { + const commitment = await note.commitment(); + for (let i = 0; i < this.elements.length; i++) { + if (this.elements[i].eq(commitment)) { + return i; + } + } + throw new Error( + `Note not found from list of ${ + this.elements.length + } notes: ${JSON.stringify(note.toString())}` + ); + } + + async _createInput(note: Note) { + const index = await this.findNoteIndex(note); + return new Input( + note, + new Field(index), + await this.merkleProof(index), + this.depth + ); + } +} diff --git a/src/transaction.ts b/src/transaction.ts new file mode 100644 index 0000000..d770b61 --- /dev/null +++ b/src/transaction.ts @@ -0,0 +1,21 @@ +import { Field } from './field'; +import { Input } from './input'; +import { Note } from './note'; + +export class Transaction { + constructor( + public depth: number, + public root: Field, + public inputs: Input[], + public outputs: Note[], + public depositAmount: Field, + public withdrawAddress: Field + ) { + if (inputs.length !== 2) { + throw new Error('exactly two inputs are supported: ' + inputs.length); + } + if (outputs.length !== 1) { + throw new Error('exactly single output is supported: ' + outputs.length); + } + } +} diff --git a/test/merkle-tree.test.ts b/test/merkle-tree.test.ts new file mode 100644 index 0000000..8ee7568 --- /dev/null +++ b/test/merkle-tree.test.ts @@ -0,0 +1,59 @@ +import { MerkleTree, Field } from '../src'; +import 'jest'; + +describe('MerkleTree', () => { + it.skip('merkleProof 3', async () => { + let depth = 3; + const tree = new MerkleTree(depth); + tree.insert(new Field(1n)); + tree.insert(new Field(0n)); + tree.insert(new Field(0n)); + + let merkleProof = await tree.merkleProof(0); + expect(merkleProof.length).toEqual(depth); + + expect(merkleProof[0].hex()).toEqual(hex(0n)); + if (depth === 3) { + expect(merkleProof[1].hex()).toEqual( + hex( + 14744269619966411208579211824598458697587494354926760081771325075741142829156n + ) + ); + expect(merkleProof[2].hex()).toEqual( + hex( + 7423237065226347324353380772367382631490014989348495481811164164159255474657n + ) + ); + } + }); + + it('merkleProof 32', async () => { + let depth = 32; + const tree = new MerkleTree(depth); + tree.insert(new Field(1n)); + + let merkleProof = await tree.merkleProof(0); + expect(merkleProof.length).toEqual(depth); + + expect(merkleProof[0].hex()).toEqual(hex(0n)); + }); + + it('json', async () => { + const tree = new MerkleTree(32); + tree.insert(Field.random()); + tree.insert(Field.random()); + tree.insert(Field.random()); + + let tree2 = MerkleTree.fromJson(JSON.parse(JSON.stringify(tree.toJson()))); + expect(tree2.depth).toEqual(tree.depth); + expect(tree2.elements.length).toEqual(tree.elements.length); + + for (let i = 0; i < tree.elements.length; i++) { + expect(tree2.elements[i].hex()).toEqual(tree.elements[i].hex()); + } + }); +}); + +function hex(a: bigint) { + return Field.from(a).hex(); +}