Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
tudor-malene committed May 1, 2024
1 parent 29cbacc commit 17de35e
Show file tree
Hide file tree
Showing 9 changed files with 121 additions and 138 deletions.
1 change: 1 addition & 0 deletions go/common/enclave.go
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,7 @@ type EnclaveScan interface {
GetTotalContractCount(context.Context) (*big.Int, SystemError)

// GetCustomQuery returns the data of a custom query
// todo - better name and description
GetCustomQuery(ctx context.Context, encryptedParams EncryptedParamsGetStorageAt) (*responses.PrivateQueryResponse, SystemError)

// GetPublicTransactionData returns a list of public transaction data
Expand Down
70 changes: 25 additions & 45 deletions go/enclave/storage/enclavedb/batch.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ import (
"errors"
"fmt"
"math/big"
"strings"

"github.com/ethereum/go-ethereum/params"

Expand All @@ -20,30 +19,9 @@ import (
)

const (
bodyInsert = "replace into batch_body values (?,?)"
txInsert = "replace into tx (hash, full_hash, content, sender_address, nonce, idx, body) values "
txInsertValue = "(?,?,?,?,?,?,?)"

batchInsert = "insert into batch values (?,?,?,?,?,?,?,?,?,?)"
updateBatchExecuted = "update batch set is_executed=true where sequence=?"

selectBatch = "select b.header, bb.content from batch b join batch_body bb on b.body=bb.id"

txExecInsert = "insert into exec_tx (created_contract_address, receipt, tx, batch) values "
txExecInsertValue = "(?,?,?,?)"
queryReceipts = "select exec_tx.receipt, tx.content, batch.full_hash, batch.height from exec_tx join tx on tx.id=exec_tx.tx join batch on batch.sequence=exec_tx.batch "
queryReceiptsCount = "select count(1) from exec_tx join tx on tx.id=exec_tx.tx join batch on batch.sequence=exec_tx.batch "

selectTxQuery = "select tx.content, batch.full_hash, batch.height, tx.idx from exec_tx join tx on tx.id=exec_tx.tx join batch on batch.sequence=exec_tx.batch where batch.is_canonical=true and tx.hash=? and tx.full_hash=?"

selectContractCreationTx = "select tx.full_hash from exec_tx join tx on tx.id=exec_tx.tx where created_contract_address=?"
selectTotalCreatedContracts = "select count( distinct created_contract_address) from exec_tx "
queryBatchWasExecuted = "select is_executed from batch where is_canonical=true and hash=? and full_hash=?"

isCanonQuery = "select is_canonical from block where hash=? and full_hash=?"

queryTxList = "select tx.full_hash, batch.height, batch.header from exec_tx join batch on batch.sequence=exec_tx.batch join tx on tx.id=exec_tx.tx where batch.is_canonical=true"
queryTxCountList = "select count(1) from exec_tx join batch on batch.sequence=exec_tx.batch where batch.is_canonical=true"
queryReceipts = "select exec_tx.receipt, tx.content, batch.full_hash, batch.height from exec_tx join tx on tx.id=exec_tx.tx join batch on batch.sequence=exec_tx.batch "
)

// WriteBatchAndTransactions - persists the batch and the transactions
Expand All @@ -60,17 +38,17 @@ func WriteBatchAndTransactions(ctx context.Context, dbtx DBTransaction, batch *c
return fmt.Errorf("could not encode batch header. Cause: %w", err)
}

dbtx.ExecuteSQL(bodyInsert, batchBodyID, body)
dbtx.ExecuteSQL("replace into batch_body values (?,?)", batchBodyID, body)

var isCanon bool
err = dbtx.GetDB().QueryRowContext(ctx, isCanonQuery, truncTo4(batch.Header.L1Proof), batch.Header.L1Proof.Bytes()).Scan(&isCanon)
err = dbtx.GetDB().QueryRowContext(ctx, "select is_canonical from block where hash=? and full_hash=?", truncTo4(batch.Header.L1Proof), batch.Header.L1Proof.Bytes()).Scan(&isCanon)
if err != nil {
// if the block is not found, we assume it is non-canonical
// fmt.Printf("IsCanon %s err: %s\n", batch.Header.L1Proof, err)
isCanon = false
}

dbtx.ExecuteSQL(batchInsert,
dbtx.ExecuteSQL("insert into batch values (?,?,?,?,?,?,?,?,?,?)",
batch.Header.SequencerOrderNo.Uint64(), // sequence
batch.Hash(), // full hash
convertedHash, // converted_hash
Expand All @@ -85,8 +63,7 @@ func WriteBatchAndTransactions(ctx context.Context, dbtx DBTransaction, batch *c

// creates a big insert statement for all transactions
if len(batch.Transactions) > 0 {
insert := txInsert + strings.Repeat(txInsertValue+",", len(batch.Transactions))
insert = insert[0 : len(insert)-1] // remove trailing comma
insert := "replace into tx (hash, full_hash, content, sender_address, nonce, idx, body) values " + repeat("(?,?,?,?,?,?,?)", ",", len(batch.Transactions))

args := make([]any, 0)
for i, transaction := range batch.Transactions {
Expand Down Expand Up @@ -116,7 +93,7 @@ func WriteBatchAndTransactions(ctx context.Context, dbtx DBTransaction, batch *c

// WriteBatchExecution - insert all receipts to the db
func WriteBatchExecution(ctx context.Context, dbtx DBTransaction, seqNo *big.Int, receipts []*types.Receipt) error {
dbtx.ExecuteSQL(updateBatchExecuted, seqNo.Uint64())
dbtx.ExecuteSQL("update batch set is_executed=true where sequence=?", seqNo.Uint64())

args := make([]any, 0)
for _, receipt := range receipts {
Expand All @@ -127,18 +104,16 @@ func WriteBatchExecution(ctx context.Context, dbtx DBTransaction, seqNo *big.Int
return fmt.Errorf("failed to encode block receipts. Cause: %w", err)
}

// ignore the error because synthetic transactions will not be inserted
txId, _ := ReadTxId(ctx, dbtx, storageReceipt.TxHash)
//if err != nil {
// return err
//}
args = append(args, receipt.ContractAddress.Bytes()) // created_contract_address
args = append(args, receiptBytes) // the serialised receipt
args = append(args, txId) // tx id
args = append(args, seqNo.Uint64()) // batch_seq
args = append(args, truncBTo4(receipt.ContractAddress.Bytes())) // created_contract_address
args = append(args, receipt.ContractAddress.Bytes()) // created_contract_address
args = append(args, receiptBytes) // the serialised receipt
args = append(args, txId) // tx id
args = append(args, seqNo.Uint64()) // batch_seq
}
if len(args) > 0 {
insert := txExecInsert + strings.Repeat(txExecInsertValue+",", len(receipts))
insert = insert[0 : len(insert)-1] // remove trailing comma
insert := "insert into exec_tx (created_contract_address,created_contract_address_full, receipt, tx, batch) values " + repeat("(?,?,?,?,?)", ",", len(receipts))
dbtx.ExecuteSQL(insert, args...)
}
return nil
Expand Down Expand Up @@ -369,7 +344,9 @@ func ReadReceipt(ctx context.Context, db *sql.DB, txHash common.L2TxHash, config
}

func ReadTransaction(ctx context.Context, db *sql.DB, txHash gethcommon.Hash) (*types.Transaction, common.L2BatchHash, uint64, uint64, error) {
row := db.QueryRowContext(ctx, selectTxQuery, truncTo4(txHash), txHash.Bytes())
row := db.QueryRowContext(ctx,
"select tx.content, batch.full_hash, batch.height, tx.idx from exec_tx join tx on tx.id=exec_tx.tx join batch on batch.sequence=exec_tx.batch where batch.is_canonical=true and tx.hash=? and tx.full_hash=?",
truncTo4(txHash), txHash.Bytes())

// tx, batch, height, idx
var txData []byte
Expand All @@ -394,7 +371,7 @@ func ReadTransaction(ctx context.Context, db *sql.DB, txHash gethcommon.Hash) (*
}

func GetContractCreationTx(ctx context.Context, db *sql.DB, address gethcommon.Address) (*gethcommon.Hash, error) {
row := db.QueryRowContext(ctx, selectContractCreationTx, address.Bytes())
row := db.QueryRowContext(ctx, "select tx.full_hash from exec_tx join tx on tx.id=exec_tx.tx where created_contract_address=? and created_contract_address_full=?", truncBTo4(address.Bytes()), address.Bytes())

var txHashBytes []byte
err := row.Scan(&txHashBytes)
Expand All @@ -411,7 +388,7 @@ func GetContractCreationTx(ctx context.Context, db *sql.DB, address gethcommon.A
}

func ReadContractCreationCount(ctx context.Context, db *sql.DB) (*big.Int, error) {
row := db.QueryRowContext(ctx, selectTotalCreatedContracts)
row := db.QueryRowContext(ctx, "select count( distinct created_contract_address) from exec_tx ")

var count int64
err := row.Scan(&count)
Expand All @@ -427,7 +404,7 @@ func ReadUnexecutedBatches(ctx context.Context, db *sql.DB, from *big.Int) ([]*c
}

func BatchWasExecuted(ctx context.Context, db *sql.DB, hash common.L2BatchHash) (bool, error) {
row := db.QueryRowContext(ctx, queryBatchWasExecuted, truncTo4(hash), hash.Bytes())
row := db.QueryRowContext(ctx, "select is_executed from batch where is_canonical=true and hash=? and full_hash=?", truncTo4(hash), hash.Bytes())

var result bool
err := row.Scan(&result)
Expand All @@ -443,11 +420,13 @@ func BatchWasExecuted(ctx context.Context, db *sql.DB, hash common.L2BatchHash)
}

func GetReceiptsPerAddress(ctx context.Context, db *sql.DB, config *params.ChainConfig, address *gethcommon.Address, pagination *common.QueryPagination) (types.Receipts, error) {
// todo - not indexed
return selectReceipts(ctx, db, config, "where tx.sender_address = ? ORDER BY height DESC LIMIT ? OFFSET ? ", address.Bytes(), pagination.Size, pagination.Offset)
}

func GetReceiptsPerAddressCount(ctx context.Context, db *sql.DB, address *gethcommon.Address) (uint64, error) {
row := db.QueryRowContext(ctx, queryReceiptsCount+" where tx.sender_address = ?", address.Bytes())
// todo - this is not indexed and will do a full table scan!
row := db.QueryRowContext(ctx, "select count(1) from exec_tx join tx on tx.id=exec_tx.tx join batch on batch.sequence=exec_tx.batch "+" where tx.sender_address = ?", address.Bytes())

var count uint64
err := row.Scan(&count)
Expand All @@ -465,7 +444,8 @@ func GetPublicTransactionData(ctx context.Context, db *sql.DB, pagination *commo
func selectPublicTxsBySender(ctx context.Context, db *sql.DB, query string, args ...any) ([]common.PublicTransaction, error) {
var publicTxs []common.PublicTransaction

rows, err := db.QueryContext(ctx, queryTxList+" "+query, args...)
q := "select tx.full_hash, batch.height, batch.header from exec_tx join batch on batch.sequence=exec_tx.batch join tx on tx.id=exec_tx.tx where batch.is_canonical=true " + query
rows, err := db.QueryContext(ctx, q, args...)
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
// make sure the error is converted to obscuro-wide not found error
Expand Down Expand Up @@ -503,7 +483,7 @@ func selectPublicTxsBySender(ctx context.Context, db *sql.DB, query string, args
}

func GetPublicTransactionCount(ctx context.Context, db *sql.DB) (uint64, error) {
row := db.QueryRowContext(ctx, queryTxCountList)
row := db.QueryRowContext(ctx, "select count(1) from exec_tx join batch on batch.sequence=exec_tx.batch where batch.is_canonical=true")

var count uint64
err := row.Scan(&count)
Expand Down
50 changes: 15 additions & 35 deletions go/enclave/storage/enclavedb/block.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,38 +7,20 @@ import (
"errors"
"fmt"
"math/big"
"strings"

"github.com/ethereum/go-ethereum/core/types"
"github.com/ethereum/go-ethereum/rlp"
"github.com/ten-protocol/go-ten/go/common"
"github.com/ten-protocol/go-ten/go/common/errutil"
)

const (
blockInsert = "insert into block (hash,full_hash,is_canonical,header,height) values (?,?,?,?,?)"
selectBlockHeader = "select header from block "

l1msgInsert = "insert into l1_msg (message, block, is_transfer) values "
l1msgValue = "(?,?,?)"
selectL1Msg = "select message from l1_msg m join block b on m.block=b.id "

rollupInsert = "replace into rollup (hash, full_hash, start_seq, end_seq, time_stamp, header, compression_block) values (?,?,?,?,?,?,?)"
rollupSelect = "select full_hash from rollup r join block b on r.compression_block=b.id where "
rollupSelectMetadata = "select start_seq, time_stamp from rollup where hash = ? and full_hash=?"

updateCanonicalBlock = "update block set is_canonical=? where "
// todo - do we need the is_canonical field?
updateCanonicalBatches = "update batch set is_canonical=? where l1_proof in "
)

func WriteBlock(_ context.Context, dbtx DBTransaction, b *types.Header) error {
header, err := rlp.EncodeToBytes(b)
if err != nil {
return fmt.Errorf("could not encode block header. Cause: %w", err)
}

dbtx.ExecuteSQL(blockInsert,
dbtx.ExecuteSQL("insert into block (hash,full_hash,is_canonical,header,height) values (?,?,?,?,?)",
truncTo4(b.Hash()), // hash
b.Hash().Bytes(), // full_hash
true, // is_canonical
Expand All @@ -58,20 +40,18 @@ func UpdateCanonicalBlocks(ctx context.Context, dbtx DBTransaction, canonical []
}

func updateCanonicalValue(_ context.Context, dbtx DBTransaction, isCanonical bool, blocks []common.L1BlockHash) {
token := "(hash=? and full_hash=?) OR "
updateBlocksWhere := strings.Repeat(token, len(blocks))
updateBlocksWhere = updateBlocksWhere + "1=0"

updateBlocks := updateCanonicalBlock + updateBlocksWhere
canonicalBlocks := repeat("(hash=? and full_hash=?)", "OR", len(blocks))

args := make([]any, 0)
args = append(args, isCanonical)
for _, blockHash := range blocks {
args = append(args, truncTo4(blockHash), blockHash.Bytes())
}

updateBlocks := "update block set is_canonical=? where " + canonicalBlocks
dbtx.ExecuteSQL(updateBlocks, args...)

updateBatches := updateCanonicalBatches + "(" + "select id from block where " + updateBlocksWhere + ")"
updateBatches := "update batch set is_canonical=? where l1_proof in (select id from block where " + canonicalBlocks + ")"
dbtx.ExecuteSQL(updateBatches, args...)
}

Expand All @@ -81,6 +61,7 @@ func FetchBlock(ctx context.Context, db *sql.DB, hash common.L1BlockHash) (*type
}

func FetchHeadBlock(ctx context.Context, db *sql.DB) (*types.Block, error) {
// todo - just read the one with the max id
return fetchBlock(ctx, db, "where is_canonical=true and height=(select max(b.height) from block b where is_canonical=true)")
}

Expand All @@ -95,8 +76,7 @@ func GetBlockId(ctx context.Context, db *sql.DB, hash common.L1BlockHash) (uint6
}

func WriteL1Messages[T any](ctx context.Context, db *sql.DB, blockId uint64, messages []T, isValueTransfer bool) error {
insert := l1msgInsert + strings.Repeat(l1msgValue+",", len(messages))
insert = insert[0 : len(insert)-1] // remove trailing comma
insert := "insert into l1_msg (message, block, is_transfer) values " + repeat("(?,?,?)", ",", len(messages))

args := make([]any, 0)

Expand All @@ -118,7 +98,7 @@ func WriteL1Messages[T any](ctx context.Context, db *sql.DB, blockId uint64, mes

func FetchL1Messages[T any](ctx context.Context, db *sql.DB, blockHash common.L1BlockHash, isTransfer bool) ([]T, error) {
var result []T
query := selectL1Msg + " where b.hash = ? and b.full_hash = ? and is_transfer = ?"
query := "select message from l1_msg m join block b on m.block=b.id where b.hash = ? and b.full_hash = ? and is_transfer = ?"
rows, err := db.QueryContext(ctx, query, truncTo4(blockHash), blockHash.Bytes(), isTransfer)
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
Expand Down Expand Up @@ -153,7 +133,7 @@ func WriteRollup(_ context.Context, dbtx DBTransaction, rollup *common.RollupHea
if err != nil {
return fmt.Errorf("could not encode batch header. Cause: %w", err)
}
dbtx.ExecuteSQL(rollupInsert,
dbtx.ExecuteSQL("replace into rollup (hash, full_hash, start_seq, end_seq, time_stamp, header, compression_block) values (?,?,?,?,?,?,?)",
truncTo4(rollup.Hash()),
rollup.Hash().Bytes(),
internalHeader.FirstBatchSequence.Uint64(),
Expand All @@ -166,11 +146,9 @@ func WriteRollup(_ context.Context, dbtx DBTransaction, rollup *common.RollupHea
}

func FetchReorgedRollup(ctx context.Context, db *sql.DB, reorgedBlocks []common.L1BlockHash) (*common.L2BatchHash, error) {
token := "(b.hash=? and b.full_hash=?) OR "
whereClause := strings.Repeat(token, len(reorgedBlocks))
whereClause = whereClause + "1=0"
whereClause := repeat("(b.hash=? and b.full_hash=?)", "OR", len(reorgedBlocks))

query := rollupSelect + whereClause
query := "select full_hash from rollup r join block b on r.compression_block=b.id where " + whereClause

args := make([]any, 0)
for _, blockHash := range reorgedBlocks {
Expand All @@ -193,7 +171,9 @@ func FetchRollupMetadata(ctx context.Context, db *sql.DB, hash common.L2RollupHa
var startTime uint64

rollup := new(common.PublicRollupMetadata)
err := db.QueryRowContext(ctx, rollupSelectMetadata, truncTo4(hash), hash.Bytes()).Scan(&startSeq, &startTime)
err := db.QueryRowContext(ctx,
"select start_seq, time_stamp from rollup where hash = ? and full_hash=?", truncTo4(hash), hash.Bytes(),
).Scan(&startSeq, &startTime)
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
return nil, errutil.ErrNotFound
Expand All @@ -207,7 +187,7 @@ func FetchRollupMetadata(ctx context.Context, db *sql.DB, hash common.L2RollupHa

func fetchBlockHeader(ctx context.Context, db *sql.DB, whereQuery string, args ...any) (*types.Header, error) {
var header string
query := selectBlockHeader + " " + whereQuery
query := "select header from block " + whereQuery
var err error
if len(args) > 0 {
err = db.QueryRowContext(ctx, query, args...).Scan(&header)
Expand Down
Loading

0 comments on commit 17de35e

Please sign in to comment.