Skip to content

Commit

Permalink
fix race condition
Browse files Browse the repository at this point in the history
  • Loading branch information
tamirms committed Aug 14, 2024
1 parent 35174b0 commit c89e633
Show file tree
Hide file tree
Showing 4 changed files with 213 additions and 60 deletions.
118 changes: 89 additions & 29 deletions services/horizon/internal/db2/history/account_loader.go
Original file line number Diff line number Diff line change
Expand Up @@ -90,11 +90,6 @@ type LoaderStats struct {
Inserted int
}

type accountGetOrCreate struct {
Account
Inserted bool `db:"inserted"`
}

// Exec will look up all the history account ids for the addresses registered in the loader.
// If there are no history account ids for a given set of addresses, Exec will insert rows
// into the history_accounts table to establish a mapping between address and history account id.
Expand All @@ -112,8 +107,8 @@ func (a *AccountLoader) Exec(ctx context.Context, session db.SessionInterface) e
// https://github.com/stellar/go/issues/2370
sort.Strings(addresses)

var accounts []accountGetOrCreate
err := bulkGetOrCreate(
var accounts []Account
err := bulkInsert(
ctx,
q,
"history_accounts",
Expand All @@ -131,11 +126,41 @@ func (a *AccountLoader) Exec(ctx context.Context, session db.SessionInterface) e
}
for _, account := range accounts {
a.ids[account.Address] = account.ID
if account.Inserted {
a.stats.Inserted++
}
a.stats.Inserted++
}
a.stats.Total += len(accounts)

remaining := make([]string, 0, len(addresses))
for _, address := range addresses {
if _, ok := a.ids[address]; ok {
continue
}
remaining = append(remaining, address)
}
if len(remaining) > 0 {
var remainingAccounts []Account
err = bulkGet(
ctx,
q,
"history_accounts",
[]columnValues{
{
name: "address",
dbType: "character varying(64)",
objects: remaining,
},
},
&remainingAccounts,
)
if err != nil {
return err
}
for _, account := range remainingAccounts {
a.ids[account.Address] = account.ID
}
a.stats.Total += len(remainingAccounts)
}

return nil
}

Expand All @@ -155,12 +180,12 @@ type columnValues struct {
objects []string
}

func bulkGetOrCreate(ctx context.Context, q *Q, table string, fields []columnValues, response interface{}) error {
func bulkInsert(ctx context.Context, q *Q, table string, fields []columnValues, response interface{}) error {
unnestPart := make([]string, 0, len(fields))
insertFieldsPart := make([]string, 0, len(fields))
pqArrays := make([]interface{}, 0, len(fields))

// In the code below we are building the bulk insert part of the query which looks like:
// In the code below we are building the bulk insert query which looks like:
//
// WITH rows AS
// (SELECT
Expand All @@ -175,7 +200,7 @@ func bulkGetOrCreate(ctx context.Context, q *Q, table string, fields []columnVal
// field2,
// ...
// )
// SELECT * FROM rows ON CONFLICT (field1, field2, ...) DO NOTHING
// SELECT * FROM rows ON CONFLICT (field1, field2, ...) DO NOTHING RETURNING *
//
// Using unnest allows to get around the maximum limit of 65,535 query parameters,
// see https://www.postgresql.org/docs/12/limits.html and
Expand All @@ -199,25 +224,60 @@ func bulkGetOrCreate(ctx context.Context, q *Q, table string, fields []columnVal
}
columns := strings.Join(insertFieldsPart, ",")

// We can combine the inserted rows with a query to find pre-existing rows
// using a UNION ALL clause. Note that the query to fetch pre-existing rows
// will not see the effects of the inserted_rows CTE because of the snapshot
// isolation semantics of postgres CTEs (see
// https://www.postgresql.org/docs/12/queries-with.html ).
sql := `
WITH rows AS
(SELECT ` + strings.Join(unnestPart, ",") + `),
inserted_rows AS (
INSERT INTO ` + table + `
(` + columns + `)
SELECT * FROM rows
ON CONFLICT (` + columns + `) DO NOTHING
RETURNING *
(SELECT ` + strings.Join(unnestPart, ",") + `)
INSERT INTO ` + table + `
(` + columns + `)
SELECT * FROM rows
ON CONFLICT (` + columns + `) DO NOTHING
RETURNING *`

return q.SelectRaw(
ctx,
response,
sql,
pqArrays...,
)
SELECT *, true as inserted FROM inserted_rows
UNION ALL
SELECT *, false as inserted FROM ` + table + ` WHERE (` + columns + `) IN
(SELECT * FROM rows)`
}

func bulkGet(ctx context.Context, q *Q, table string, fields []columnValues, response interface{}) error {
unnestPart := make([]string, 0, len(fields))
columns := make([]string, 0, len(fields))
pqArrays := make([]interface{}, 0, len(fields))

// In the code below we are building the bulk get query which looks like:
//
// SELECT * FROM table WHERE (field1, field2, ...) IN
// (SELECT
// /* unnestPart */
// unnest(?::type1[]), /* field1 */
// unnest(?::type2[]), /* field2 */
// ...
// )
//
// Using unnest allows to get around the maximum limit of 65,535 query parameters,
// see https://www.postgresql.org/docs/12/limits.html and
// https://klotzandrew.com/blog/postgres-passing-65535-parameter-limit/
//
// Without using unnest we would have to use multiple select statements to obtain
// all the rows for large datasets.
for _, field := range fields {
unnestPart = append(
unnestPart,
fmt.Sprintf("unnest(?::%s[]) /* %s */", field.dbType, field.name),
)
columns = append(
columns,
field.name,
)
pqArrays = append(
pqArrays,
pq.Array(field.objects),
)
}
sql := `SELECT * FROM ` + table + ` WHERE (` + strings.Join(columns, ",") + `) IN
(SELECT ` + strings.Join(unnestPart, ",") + `)`

return q.SelectRaw(
ctx,
Expand Down
66 changes: 56 additions & 10 deletions services/horizon/internal/db2/history/asset_loader.go
Original file line number Diff line number Diff line change
Expand Up @@ -99,11 +99,6 @@ func (a *AssetLoader) GetNow(asset AssetKey) (int64, error) {
}
}

type assetGetOrCreate struct {
Asset
Inserted bool `db:"inserted"`
}

// Exec will look up all the history asset ids for the assets registered in the loader.
// If there are no history asset ids for a given set of assets, Exec will insert rows
// into the history_assets table.
Expand Down Expand Up @@ -132,8 +127,8 @@ func (a *AssetLoader) Exec(ctx context.Context, session db.SessionInterface) err
assetIssuers = append(assetIssuers, key.Issuer)
}

var rows []assetGetOrCreate
err := bulkGetOrCreate(
var rows []Asset
err := bulkInsert(
ctx,
q,
"history_assets",
Expand Down Expand Up @@ -165,12 +160,63 @@ func (a *AssetLoader) Exec(ctx context.Context, session db.SessionInterface) err
Code: row.Code,
Issuer: row.Issuer,
}] = row.ID
if row.Inserted {
a.stats.Inserted++
a.stats.Inserted++
}
a.stats.Total += len(rows)

remaining := make([]AssetKey, 0, len(keys))
for _, key := range keys {
if _, ok := a.ids[key]; ok {
continue
}
remaining = append(remaining, key)
}
a.stats.Total += len(keys)
if len(remaining) > 0 {
assetTypes = make([]string, 0, len(remaining))
assetCodes = make([]string, 0, len(remaining))
assetIssuers = make([]string, 0, len(remaining))
for _, key := range remaining {
assetTypes = append(assetTypes, key.Type)
assetCodes = append(assetCodes, key.Code)
assetIssuers = append(assetIssuers, key.Issuer)
}

var remainingRows []Asset
err = bulkGet(
ctx,
q,
"history_assets",
[]columnValues{
{
name: "asset_code",
dbType: "character varying(12)",
objects: assetCodes,
},
{
name: "asset_type",
dbType: "character varying(64)",
objects: assetTypes,
},
{
name: "asset_issuer",
dbType: "character varying(56)",
objects: assetIssuers,
},
},
&remainingRows,
)
if err != nil {
return err
}
for _, row := range remainingRows {
a.ids[AssetKey{
Type: row.Type,
Code: row.Code,
Issuer: row.Issuer,
}] = row.ID
}
a.stats.Total += len(remainingRows)
}
return nil
}

Expand Down
44 changes: 34 additions & 10 deletions services/horizon/internal/db2/history/claimable_balance_loader.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,11 +76,6 @@ func (a *ClaimableBalanceLoader) getNow(id string) (int64, error) {
}
}

type historyClaimableBalanceGetOrCreate struct {
HistoryClaimableBalance
Inserted bool `db:"inserted"`
}

// Exec will look up all the internal history ids for the claimable balances registered in the loader.
// If there are no internal ids for a given set of claimable balances, Exec will insert rows
// into the history_claimable_balances table.
Expand All @@ -98,8 +93,8 @@ func (a *ClaimableBalanceLoader) Exec(ctx context.Context, session db.SessionInt
// sort entries before inserting rows to prevent deadlocks on acquiring a ShareLock
// https://github.com/stellar/go/issues/2370
sort.Strings(ids)
var rows []historyClaimableBalanceGetOrCreate
err := bulkGetOrCreate(
var rows []HistoryClaimableBalance
err := bulkInsert(
ctx,
q,
"history_claimable_balances",
Expand All @@ -117,11 +112,40 @@ func (a *ClaimableBalanceLoader) Exec(ctx context.Context, session db.SessionInt
}
for _, row := range rows {
a.ids[row.BalanceID] = row.InternalID
if row.Inserted {
a.stats.Inserted++
a.stats.Inserted++
}
a.stats.Total += len(rows)

remaining := make([]string, 0, len(ids))
for _, id := range ids {
if _, ok := a.ids[id]; ok {
continue
}
remaining = append(remaining, id)
}
if len(remaining) > 0 {
var remainingRows []HistoryClaimableBalance
err = bulkGet(
ctx,
q,
"history_claimable_balances",
[]columnValues{
{
name: "claimable_balance_id",
dbType: "text",
objects: remaining,
},
},
&remainingRows,
)
if err != nil {
return err
}
for _, row := range remainingRows {
a.ids[row.BalanceID] = row.InternalID
}
a.stats.Total += len(remainingRows)
}
a.stats.Total += len(ids)
return nil
}

Expand Down
45 changes: 34 additions & 11 deletions services/horizon/internal/db2/history/liquidity_pool_loader.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,11 +76,6 @@ func (a *LiquidityPoolLoader) GetNow(id string) (int64, error) {
}
}

type historyLiquidityPoolGetOrCreate struct {
HistoryLiquidityPool
Inserted bool `db:"inserted"`
}

// Exec will look up all the internal history ids for the liquidity pools registered in the loader.
// If there are no internal history ids for a given set of liquidity pools, Exec will insert rows
// into the history_liquidity_pools table.
Expand All @@ -98,8 +93,8 @@ func (a *LiquidityPoolLoader) Exec(ctx context.Context, session db.SessionInterf
// sort entries before inserting rows to prevent deadlocks on acquiring a ShareLock
// https://github.com/stellar/go/issues/2370
sort.Strings(ids)
var rows []historyLiquidityPoolGetOrCreate
err := bulkGetOrCreate(
var rows []HistoryLiquidityPool
err := bulkInsert(
ctx,
q,
"history_liquidity_pools",
Expand All @@ -117,12 +112,40 @@ func (a *LiquidityPoolLoader) Exec(ctx context.Context, session db.SessionInterf
}
for _, row := range rows {
a.ids[row.PoolID] = row.InternalID
if row.Inserted {
a.stats.Inserted++
}
a.stats.Inserted++
}
a.stats.Total += len(ids)
a.stats.Total += len(rows)

remaining := make([]string, 0, len(ids))
for _, id := range ids {
if _, ok := a.ids[id]; ok {
continue
}
remaining = append(remaining, id)
}
if len(remaining) > 0 {
var remainingRows []HistoryLiquidityPool
err = bulkGet(
ctx,
q,
"history_liquidity_pools",
[]columnValues{
{
name: "liquidity_pool_id",
dbType: "text",
objects: remaining,
},
},
&remainingRows,
)
if err != nil {
return err
}
for _, row := range remainingRows {
a.ids[row.PoolID] = row.InternalID
}
a.stats.Total += len(remainingRows)
}
return nil
}

Expand Down

0 comments on commit c89e633

Please sign in to comment.