diff --git a/pika.go b/pika.go index 5cc04cc..de13491 100644 --- a/pika.go +++ b/pika.go @@ -85,7 +85,7 @@ type QuerySet[T any] interface { ClearAll() QuerySet[T] // Create creates a new value - Create(value *T) error + Create(value *T, options ...CreateOption) error // Update updates a value // All filters will be applied @@ -130,7 +130,7 @@ type QuerySet[T any] interface { // Query related methods // CreateQuery returns the query and args for Create - CreateQuery(value *T) (string, []interface{}) + CreateQuery(value *T, options ...CreateOption) (string, []interface{}) // UpdateQuery returns the query and args for Update UpdateQuery(value *T) (string, []interface{}) diff --git a/pika_psql.go b/pika_psql.go index f762e9e..c03dbe5 100644 --- a/pika_psql.go +++ b/pika_psql.go @@ -67,6 +67,10 @@ type basePsql[T any] struct { psql *PostgreSQL } +type CreateOption byte + +const InsertOnConflictionDoNothing CreateOption = 1 << iota + // NewPostgreSQL returns a new PostgreSQL instance. // connectionString should be sqlx compatible. func NewPostgreSQL(connectionString string) (*PostgreSQL, error) { @@ -254,19 +258,24 @@ func (b *basePsql[T]) ClearAll() QuerySet[T] { } // Create creates a new record in the database. -func (b *basePsql[T]) Create(x *T) error { +func (b *basePsql[T]) Create(x *T, options ...CreateOption) error { if b.err != nil { return b.err } origIgnoreOrderBy := b.ignoreOrderBy b.ignoreOrderBy = true - q, args := b.CreateQuery(x) + + q, args := b.CreateQuery(x, options...) b.ignoreOrderBy = origIgnoreOrderBy // Execute query err := b.psql.Queryable().Get(x, q, args...) if err != nil { + // ignore no rows in resultset error when ignoreConflict is set to true, this is a normal case + if errors.Is(err, sql.ErrNoRows) && (InsertOnConflictionDoNothing&getOption(options...) != 0) { + return nil + } return err } @@ -468,8 +477,8 @@ func (b *basePsql[T]) ResetOrderBy() QuerySet[T] { } // CreateQuery returns the query and arguments for Create -func (b *basePsql[T]) CreateQuery(x *T) (string, []interface{}) { - q, args := b.psqlCreateQuery(x) +func (b *basePsql[T]) CreateQuery(x *T, options ...CreateOption) (string, []interface{}) { + q, args := b.psqlCreateQuery(x, options...) logger.Debugf("Pika query: %s", q) return q, args @@ -1126,7 +1135,7 @@ func (b *basePsql[T]) psqlCountQuery() string { return q } -func (b *basePsql[T]) psqlCreateQuery(value *T) (string, []any) { +func (b *basePsql[T]) psqlCreateQuery(value *T, options ...CreateOption) (string, []any) { // Get info from metadata tableName := b.metadata[PikaMetadataTableName] modelName := b.metadata[pikaMetadataModelName] @@ -1178,7 +1187,12 @@ func (b *basePsql[T]) psqlCreateQuery(value *T) (string, []any) { // Remove the model name prefix from the select list // since we are inserting into the table selectList = strings.Replace(selectList, fmt.Sprintf("\"%s\".", modelName), "", -1) - q := fmt.Sprintf("INSERT INTO \"%s\" (%s) VALUES (%s) RETURNING %s", tableName, columnStr, valueStr, selectList) + conflict := "" + + if InsertOnConflictionDoNothing&getOption(options...) != 0 { + conflict = " ON CONFLICT DO NOTHING" + } + q := fmt.Sprintf("INSERT INTO \"%s\" (%s) VALUES (%s)%s RETURNING %s", tableName, columnStr, valueStr, conflict, selectList) // Convert value to arguments args := make([]interface{}, 0, ref.Elem().NumField()) @@ -1425,3 +1439,12 @@ func generateRangeSlice(start, length int) *[]int { } return &ret } + +func getOption(options ...CreateOption) CreateOption { + var option CreateOption + for _, o := range options { + option |= o + } + + return option +} diff --git a/pika_psql_test.go b/pika_psql_test.go index 0880b28..a457727 100644 --- a/pika_psql_test.go +++ b/pika_psql_test.go @@ -1144,6 +1144,53 @@ func TestCreate(t *testing.T) { require.Equal(t, entry.Description, x[0].Description) } +func TestCreateIgnore(t *testing.T) { + psql := newPsql(t) + createTestModelCreate(t, psql) + qs := Q[simpleModelCreate](psql) + + // Create a new entry + entry := simpleModelCreate{ + ID: 1, + Title: "test", + Description: "test-description", + } + + expectedQuery := `INSERT INTO "simple_model_create" ("id", "title", "description") VALUES ($1, $2, $3) ON CONFLICT DO NOTHING RETURNING "id", "title", "description"` + expectedArgs := []interface{}{1, "test", "test-description"} + actualQuery, actualArgs := qs.CreateQuery(&entry, InsertOnConflictionDoNothing) + require.Equal(t, expectedQuery, actualQuery) + require.Equal(t, expectedArgs, actualArgs) + + err := qs.Create(&entry) + require.Nil(t, err) + require.Equal(t, 1, entry.ID) + + // create again will cause duplication error + err = qs.Create(&entry) + require.NotNil(t, err) + require.Contains(t, err.Error(), "duplicate key value violates unique constraint") + + // createignore will not throw out the error + err = qs.Create(&entry, InsertOnConflictionDoNothing) + require.Nil(t, err) + + // Select the entry and check if it is the same + qs = Q[simpleModelCreate](psql) + args := NewArgs() + args.Set("id", entry.ID) + args.Set("title", entry.Title) + args.Set("description", entry.Description) + qs = qs.Filter("id=:id", "title=:title", "description=:description").Args(args) + + x, err := qs.All() + require.Nil(t, err) + require.Equal(t, 1, len(x)) + require.Equal(t, entry.ID, x[0].ID) + require.Equal(t, entry.Title, x[0].Title) + require.Equal(t, entry.Description, x[0].Description) +} + func TestUpdate(t *testing.T) { psql := newPsql(t) createTestModelCreate(t, psql)