From 8f64b5414dde5832d51a3d0910f168a0060fec24 Mon Sep 17 00:00:00 2001 From: jason yang Date: Fri, 24 Nov 2023 18:00:38 +0900 Subject: [PATCH] add ON CONFLICT DO NOTHING for insert Signed-off-by: jason yang --- pika.go | 5 ++++- pika_psql.go | 26 ++++++++++++++++++++----- pika_psql_test.go | 49 ++++++++++++++++++++++++++++++++++++++++++++++- 3 files changed, 73 insertions(+), 7 deletions(-) diff --git a/pika.go b/pika.go index 5cc04cc..a1a6a93 100644 --- a/pika.go +++ b/pika.go @@ -130,7 +130,7 @@ type QuerySet[T any] interface { // Query related methods // CreateQuery returns the query and args for Create - CreateQuery(value *T) (string, []interface{}) + CreateQuery(value *T, ignoreConflict bool) (string, []interface{}) // UpdateQuery returns the query and args for Update UpdateQuery(value *T) (string, []interface{}) @@ -165,6 +165,9 @@ type QuerySet[T any] interface { Exclude(excludes ...string) QuerySet[T] // Include fields Include(includes ...string) QuerySet[T] + + // CreateIgnore creates a new value but will not return error if conflict occurs + CreateIgnore(value *T) error } func NewArgs() *orderedmap.OrderedMap[string, any] { diff --git a/pika_psql.go b/pika_psql.go index f762e9e..9e061f1 100644 --- a/pika_psql.go +++ b/pika_psql.go @@ -255,18 +255,30 @@ func (b *basePsql[T]) ClearAll() QuerySet[T] { // Create creates a new record in the database. func (b *basePsql[T]) Create(x *T) error { + return b.create(x, false) +} + +func (b *basePsql[T]) CreateIgnore(x *T) error { + return b.create(x, true) +} + +func (b *basePsql[T]) create(x *T, ignoreConflict bool) error { if b.err != nil { return b.err } origIgnoreOrderBy := b.ignoreOrderBy b.ignoreOrderBy = true - q, args := b.CreateQuery(x) + q, args := b.CreateQuery(x, ignoreConflict) b.ignoreOrderBy = origIgnoreOrderBy // Execute query err := b.psql.Queryable().Get(x, q, args...) if err != nil { + // ignore no rows in resultset error when ignoreConflict is set to true, this is a normal case + if errors.Is(err, sql.ErrNoRows) && ignoreConflict { + return nil + } return err } @@ -468,8 +480,8 @@ func (b *basePsql[T]) ResetOrderBy() QuerySet[T] { } // CreateQuery returns the query and arguments for Create -func (b *basePsql[T]) CreateQuery(x *T) (string, []interface{}) { - q, args := b.psqlCreateQuery(x) +func (b *basePsql[T]) CreateQuery(x *T, ignoreConflict bool) (string, []interface{}) { + q, args := b.psqlCreateQuery(x, ignoreConflict) logger.Debugf("Pika query: %s", q) return q, args @@ -1126,7 +1138,7 @@ func (b *basePsql[T]) psqlCountQuery() string { return q } -func (b *basePsql[T]) psqlCreateQuery(value *T) (string, []any) { +func (b *basePsql[T]) psqlCreateQuery(value *T, ignoreConflict bool) (string, []any) { // Get info from metadata tableName := b.metadata[PikaMetadataTableName] modelName := b.metadata[pikaMetadataModelName] @@ -1178,7 +1190,11 @@ func (b *basePsql[T]) psqlCreateQuery(value *T) (string, []any) { // Remove the model name prefix from the select list // since we are inserting into the table selectList = strings.Replace(selectList, fmt.Sprintf("\"%s\".", modelName), "", -1) - q := fmt.Sprintf("INSERT INTO \"%s\" (%s) VALUES (%s) RETURNING %s", tableName, columnStr, valueStr, selectList) + conflict := "" + if ignoreConflict { + conflict = " ON CONFLICT DO NOTHING" + } + q := fmt.Sprintf("INSERT INTO \"%s\" (%s) VALUES (%s)%s RETURNING %s", tableName, columnStr, valueStr, conflict, selectList) // Convert value to arguments args := make([]interface{}, 0, ref.Elem().NumField()) diff --git a/pika_psql_test.go b/pika_psql_test.go index 0880b28..8452c43 100644 --- a/pika_psql_test.go +++ b/pika_psql_test.go @@ -1120,7 +1120,7 @@ func TestCreate(t *testing.T) { expectedQuery := `INSERT INTO "simple_model_create" ("title", "description") VALUES ($1, $2) RETURNING "id", "title", "description"` expectedArgs := []interface{}{"test", "test-description"} - actualQuery, actualArgs := qs.CreateQuery(&entry) + actualQuery, actualArgs := qs.CreateQuery(&entry, false) require.Equal(t, expectedQuery, actualQuery) require.Equal(t, expectedArgs, actualArgs) @@ -1144,6 +1144,53 @@ func TestCreate(t *testing.T) { require.Equal(t, entry.Description, x[0].Description) } +func TestCreateIgnore(t *testing.T) { + psql := newPsql(t) + createTestModelCreate(t, psql) + qs := Q[simpleModelCreate](psql) + + // Create a new entry + entry := simpleModelCreate{ + ID: 1, + Title: "test", + Description: "test-description", + } + + expectedQuery := `INSERT INTO "simple_model_create" ("id", "title", "description") VALUES ($1, $2, $3) ON CONFLICT DO NOTHING RETURNING "id", "title", "description"` + expectedArgs := []interface{}{1, "test", "test-description"} + actualQuery, actualArgs := qs.CreateQuery(&entry, true) + require.Equal(t, expectedQuery, actualQuery) + require.Equal(t, expectedArgs, actualArgs) + + err := qs.Create(&entry) + require.Nil(t, err) + require.Equal(t, 1, entry.ID) + + // create again will cause duplication error + err = qs.Create(&entry) + require.NotNil(t, err) + require.Contains(t, err.Error(), "duplicate key value violates unique constraint") + + // createignore will not throw out the error + err = qs.CreateIgnore(&entry) + require.Nil(t, err) + + // Select the entry and check if it is the same + qs = Q[simpleModelCreate](psql) + args := NewArgs() + args.Set("id", entry.ID) + args.Set("title", entry.Title) + args.Set("description", entry.Description) + qs = qs.Filter("id=:id", "title=:title", "description=:description").Args(args) + + x, err := qs.All() + require.Nil(t, err) + require.Equal(t, 1, len(x)) + require.Equal(t, entry.ID, x[0].ID) + require.Equal(t, entry.Title, x[0].Title) + require.Equal(t, entry.Description, x[0].Description) +} + func TestUpdate(t *testing.T) { psql := newPsql(t) createTestModelCreate(t, psql)