diff --git a/go/common/storage/database/init/sqlite/host_init.sql b/go/common/storage/database/init/sqlite/host_init.sql index 3e29e4bc56..3d18668e3a 100644 --- a/go/common/storage/database/init/sqlite/host_init.sql +++ b/go/common/storage/database/init/sqlite/host_init.sql @@ -34,12 +34,22 @@ create table if not exists batch tx_count int NOT NULL, header blob NOT NULL, body_id int NOT NULL REFERENCES batch_body - ); +); create index IDX_BATCH_HASH on batch (hash); create index IDX_BATCH_HEIGHT on batch (height); +create table if not exists transactions +( + hash binary(16) primary key, + full_hash binary(32) NOT NULL, + body_id int REFERENCES batch_body +); + create table if not exists transaction_count ( id int NOT NULL primary key, total int NOT NULL ); + +insert into transaction_count (id, total) +values (1, 0) on CONFLICT (id) DO NOTHING; \ No newline at end of file diff --git a/go/host/rpc/clientapi/client_api_eth.go b/go/host/rpc/clientapi/client_api_eth.go index 9c15154e07..5e86f0313d 100644 --- a/go/host/rpc/clientapi/client_api_eth.go +++ b/go/host/rpc/clientapi/client_api_eth.go @@ -12,6 +12,7 @@ import ( "github.com/ten-protocol/go-ten/go/common" "github.com/ten-protocol/go-ten/go/common/host" "github.com/ten-protocol/go-ten/go/common/log" + "github.com/ten-protocol/go-ten/go/host/storage/hostdb" "github.com/ten-protocol/go-ten/go/responses" gethcommon "github.com/ethereum/go-ethereum/common" @@ -39,7 +40,7 @@ func (api *EthereumAPI) ChainId() (*hexutil.Big, error) { //nolint:stylecheck,re // BlockNumber returns the height of the current head batch. func (api *EthereumAPI) BlockNumber() hexutil.Uint64 { - header, err := api.host.DB().GetHeadBatchHeader() + header, err := hostdb.GetHeadBatchHeader(api.host.DB()) if err != nil { // This error may be nefarious, but unfortunately the Eth API doesn't allow us to return an error. api.logger.Error("could not retrieve head batch header", log.ErrKey, err) @@ -59,7 +60,7 @@ func (api *EthereumAPI) GetBlockByNumber(ctx context.Context, number rpc.BlockNu // GetBlockByHash returns the header of the batch with the given hash. func (api *EthereumAPI) GetBlockByHash(_ context.Context, hash gethcommon.Hash, _ bool) (*common.BatchHeader, error) { - batchHeader, err := api.host.DB().GetBatchHeader(hash) + batchHeader, err := hostdb.GetBatchHeader(api.host.DB(), hash) if err != nil { return nil, err } @@ -68,7 +69,7 @@ func (api *EthereumAPI) GetBlockByHash(_ context.Context, hash gethcommon.Hash, // GasPrice is a placeholder for an RPC method required by MetaMask/Remix. func (api *EthereumAPI) GasPrice(context.Context) (*hexutil.Big, error) { - header, err := api.host.DB().GetHeadBatchHeader() + header, err := hostdb.GetHeadBatchHeader(api.host.DB()) if err != nil { return nil, err } @@ -187,7 +188,7 @@ func (api *EthereumAPI) GetStorageAt(_ context.Context, encryptedParams common.E // rpc.DecimalOrHex -> []byte func (api *EthereumAPI) FeeHistory(context.Context, string, rpc.BlockNumber, []float64) (*FeeHistoryResult, error) { // todo (#1621) - return a non-dummy fee history - header, err := api.host.DB().GetHeadBatchHeader() + header, err := hostdb.GetHeadBatchHeader(api.host.DB()) if err != nil { api.logger.Error("Unable to retrieve header for fee history.", log.ErrKey, err) return nil, fmt.Errorf("unable to retrieve fee history") @@ -226,7 +227,7 @@ func (api *EthereumAPI) batchNumberToBatchHash(batchNumber rpc.BlockNumber) (*ge // note: our API currently treats all these block statuses the same for obscuro batches if batchNumber == rpc.LatestBlockNumber || batchNumber == rpc.PendingBlockNumber || batchNumber == rpc.FinalizedBlockNumber || batchNumber == rpc.SafeBlockNumber { - batchHeader, err := api.host.DB().GetHeadBatchHeader() + batchHeader, err := hostdb.GetHeadBatchHeader(api.host.DB()) if err != nil { return nil, err } @@ -235,7 +236,7 @@ func (api *EthereumAPI) batchNumberToBatchHash(batchNumber rpc.BlockNumber) (*ge } batchNumberBig := big.NewInt(batchNumber.Int64()) - batchHash, err := api.host.DB().GetBatchHash(batchNumberBig) + batchHash, err := hostdb.GetBatchHashByNumber(api.host.DB(), batchNumberBig) if err != nil { return nil, err } diff --git a/go/host/storage/hostdb/batch.go b/go/host/storage/hostdb/batch.go index 1a9be95e68..b9539290fe 100644 --- a/go/host/storage/hostdb/batch.go +++ b/go/host/storage/hostdb/batch.go @@ -13,7 +13,7 @@ import ( ) const ( - selectTxCount = "SELECT count FROM transaction_count WHERE id = 1" + selectTxCount = "SELECT total FROM transaction_count WHERE id = 1" selectBatch = "SELECT b.sequence_order, b.full_hash, b.hash, b.height, b.tx_count, b.header, b.body_id, bb.body FROM batch b JOIN batch_body bb ON b.body_id = bb.id" selectBatchBody = "SELECT content FROM batch_body WHERE id = ?" selectDescendingBatches = ` @@ -23,16 +23,17 @@ const ( ORDER BY b.sequence_order DESC LIMIT 1 ` - selectHeader = "select b.header from batch b" + selectHeader = "SELECt b.header FROM batch b" + selectTransactions = "SELECT t.hash FROM transactions t JOIN batch b ON t.body_id = b.body_id WHERE b.full_hash = ?" - insertBatchBody = "INSERT INTO batch_body (id, content) VALUES (?, ?)" - insertBatch = "INSERT INTO batch (sequence_order, full_hash, hash, height, tx_count, header, body_id) VALUES (?, ?, ?, ?, ?, ?,?)" - insertTxCount = "INSERT INTO transaction_count (id, count) VALUES (?, ?) ON DUPLICATE KEY UPDATE count = ?" + insertBatchBody = "INSERT INTO batch_body (id, content) VALUES (?, ?)" + insertBatch = "INSERT INTO batch (sequence_order, full_hash, hash, height, tx_count, header, body_id) VALUES (?, ?, ?, ?, ?, ?, ?)" + insertTransactions = "INSERT INTO transactions (hash, full_hash, body_id) VALUES (?, ?, ?)" + insertTxCount = "INSERT INTO transaction_count (id, total) VALUES (?, ?) ON CONFLICT(id) DO UPDATE SET total = excluded.total;" ) // AddBatch adds a batch and its header to the DB func AddBatch(db *sql.DB, batch *common.ExtBatch) error { - // Encode batch data batchBodyID := batch.Header.SequencerOrderNo.Uint64() body, err := rlp.EncodeToBytes(batch.EncryptedTxBlob) @@ -44,25 +45,34 @@ func AddBatch(db *sql.DB, batch *common.ExtBatch) error { return fmt.Errorf("could not encode batch header: %w", err) } - // Execute body insert + // Insert the batch body data _, err = db.Exec(insertBatchBody, batchBodyID, body) - //_, err = batchBodyStmt.Exec(batchBodyID, body) if err != nil { return fmt.Errorf("failed to insert body: %w", err) } if len(batch.TxHashes) > 0 { + //Insert transactions + for _, transaction := range batch.TxHashes { + // GET LAST 16 s + shortHash := truncTo16(transaction) + fullHash := transaction.Bytes() + _, err := db.Exec(insertTransactions, shortHash, fullHash, batchBodyID) + if err != nil { + return fmt.Errorf("failed to insert transaction with hash: %d", err) + } + } + //Increment total count var currentTotal int err := db.QueryRow(selectTxCount).Scan(¤tTotal) - if err != nil { - return fmt.Errorf("failed to retrieve current tx total value: %w", err) - } newTotal := currentTotal + len(batch.TxHashes) + // Increase the TX count _, err = db.Exec(insertTxCount, 1, newTotal, newTotal) if err != nil { return fmt.Errorf("failed to update transaction count: %w", err) } } + // Insert the batch data _, err = db.Exec(insertBatch, batch.Header.SequencerOrderNo.Uint64(), // sequence batch.Hash(), // full hash @@ -121,27 +131,56 @@ func GetBatchHeader(db *sql.DB, hash gethcommon.Hash) (*common.BatchHeader, erro return fetchBatchHeader(db, " where hash=?", truncTo16(hash)) } -// GetBatchHash returns the hash of a batch given its number. -func GetBatchHash(db *sql.DB, number *big.Int) (*gethcommon.Hash, error) { - panic("implement me") -} - -// GetBatchTxs returns the transaction hashes of the batch with the given hash. -func GetBatchTxs(db *sql.DB, batchHash gethcommon.Hash) ([]gethcommon.Hash, error) { - panic("implement me") +// GetBatchHashByNumber returns the hash of a batch given its number. +func GetBatchHashByNumber(db *sql.DB, number *big.Int) (*gethcommon.Hash, error) { + batch, err := fetchBatchHeader(db, " where sequence_order=?", number.Uint64()) + if err != nil { + return nil, err + } + l2BatchHash := batch.Hash() + return &l2BatchHash, nil } func GetHeadBatchHeader(db *sql.DB) (*common.BatchHeader, error) { batch, err := fetchHeadBatch(db) if err != nil { - return nil, fmt.Errorf("failed to fetch head batch: %w", err) + return nil, err } return batch.Header, nil } // GetBatchNumber returns the number of the batch containing the given transaction hash. func GetBatchNumber(db *sql.DB, txHash gethcommon.Hash) (*big.Int, error) { - panic("implement me") + batch, err := fetchBatchHeader(db, " where b.hash=?", truncTo16(txHash)) + if err != nil { + return nil, err + } + return batch.Number, nil +} + +// GetBatchTxs returns the transaction hashes of the batch with the given hash. +func GetBatchTxs(db *sql.DB, batchHash gethcommon.Hash) ([]gethcommon.Hash, error) { + rows, err := db.Query(selectTransactions, batchHash) + if err != nil { + return nil, fmt.Errorf("query execution failed: %w", err) + } + defer rows.Close() + + var transactions []gethcommon.Hash + for rows.Next() { + var txHashBytes []byte + if err := rows.Scan(&txHashBytes); err != nil { + return nil, fmt.Errorf("failed to scan transaction hash: %w", err) + } + txHash := gethcommon.BytesToHash(txHashBytes) + transactions = append(transactions, txHash) + } + + if err := rows.Err(); err != nil { + return nil, fmt.Errorf("error looping through transacion rows: %w", err) + } + + return transactions, nil } // GetTotalTransactions returns the total number of batched transactions. @@ -309,22 +348,22 @@ func fetchHeadBatch(db *sql.DB) (*common.PublicBatch, error) { var heightInt64 int var txCountInt64 int var headerBlob []byte - var body_id int + var bodyId int - err := db.QueryRow(selectDescendingBatches).Scan(&sequenceInt64, &fullHash, &hash, &heightInt64, &txCountInt64, &headerBlob, &body_id) + err := db.QueryRow(selectDescendingBatches).Scan(&sequenceInt64, &fullHash, &hash, &heightInt64, &txCountInt64, &headerBlob, &bodyId) if err != nil { if errors.Is(err, sql.ErrNoRows) { - return nil, fmt.Errorf("no batches found") + return nil, errutil.ErrNotFound } return nil, fmt.Errorf("failed to fetch current head batch: %w", err) } - //Select from batch_body table + var content []byte - err = db.QueryRow(selectBatchBody, &body_id).Scan(&content) + err = db.QueryRow(selectBatchBody, &bodyId).Scan(&content) if err != nil { return nil, fmt.Errorf("failed to fetch batch content given the id: %w", err) } - // Decode the batch header + var header common.BatchHeader err = rlp.DecodeBytes(headerBlob, &header) if err != nil { diff --git a/go/host/storage/hostdb/batch_test.go b/go/host/storage/hostdb/batch_test.go index 60b0cb0c32..16f1c6113c 100644 --- a/go/host/storage/hostdb/batch_test.go +++ b/go/host/storage/hostdb/batch_test.go @@ -114,7 +114,6 @@ func TestLowerNumberBatchDoesNotBecomeBatchHeader(t *testing.T) { //nolint:dupl } func TestHeadBatchHeaderIsNotSetInitially(t *testing.T) { - //FIXME db, err := createSQLiteDB(t) _, err = GetHeadBatchHeader(db) @@ -124,7 +123,6 @@ func TestHeadBatchHeaderIsNotSetInitially(t *testing.T) { } func TestCanRetrieveBatchHashByNumber(t *testing.T) { - //FIXME Implement me db, err := createSQLiteDB(t) batch, err := getBatch(batchNumber, []common.L2TxHash{}) if err != nil { @@ -136,7 +134,7 @@ func TestCanRetrieveBatchHashByNumber(t *testing.T) { t.Errorf("could not store batch. Cause: %s", err) } - batchHash, err := GetBatchHash(db, batch.Header.Number) + batchHash, err := GetBatchHashByNumber(db, batch.Header.Number) if err != nil { t.Errorf("stored batch but could not retrieve headers hash by number. Cause: %s", err) } @@ -149,7 +147,7 @@ func TestUnknownBatchNumberReturnsNotFound(t *testing.T) { db, err := createSQLiteDB(t) header := types.Header{} - _, err = GetBatchHash(db, header.Number) + _, err = GetBatchHashByNumber(db, header.Number) if !errors.Is(err, errutil.ErrNotFound) { t.Errorf("did not store batch hash but was able to retrieve it") } diff --git a/go/host/storage/hostdb/utils.go b/go/host/storage/hostdb/utils.go index 4d9e0103d2..cba373a34e 100644 --- a/go/host/storage/hostdb/utils.go +++ b/go/host/storage/hostdb/utils.go @@ -5,7 +5,7 @@ import gethcommon "github.com/ethereum/go-ethereum/common" const truncHash = 16 func truncTo16(hash gethcommon.Hash) []byte { - return truncBTo16(hash.Bytes()) + return truncLastTo16(hash.Bytes()) } func truncBTo16(bytes []byte) []byte { @@ -17,3 +17,17 @@ func truncBTo16(bytes []byte) []byte { copy(c, b) return c } + +func truncLastTo16(bytes []byte) []byte { + if len(bytes) == 0 { + return bytes + } + start := len(bytes) - truncHash + if start < 0 { + start = 0 + } + b := bytes[start:] + c := make([]byte, truncHash) + copy(c, b) + return c +}