Skip to content

Commit

Permalink
Merge pull request #6 from JasonYangShadow/insert_ignore
Browse files Browse the repository at this point in the history
  • Loading branch information
mstg authored Nov 28, 2023
2 parents d39dbc9 + 9c37a40 commit 131e7ce
Show file tree
Hide file tree
Showing 3 changed files with 78 additions and 8 deletions.
4 changes: 2 additions & 2 deletions pika.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ type QuerySet[T any] interface {
ClearAll() QuerySet[T]

// Create creates a new value
Create(value *T) error
Create(value *T, options ...CreateOption) error

// Update updates a value
// All filters will be applied
Expand Down Expand Up @@ -130,7 +130,7 @@ type QuerySet[T any] interface {
// Query related methods

// CreateQuery returns the query and args for Create
CreateQuery(value *T) (string, []interface{})
CreateQuery(value *T, options ...CreateOption) (string, []interface{})

// UpdateQuery returns the query and args for Update
UpdateQuery(value *T) (string, []interface{})
Expand Down
35 changes: 29 additions & 6 deletions pika_psql.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,10 @@ type basePsql[T any] struct {
psql *PostgreSQL
}

type CreateOption byte

const InsertOnConflictionDoNothing CreateOption = 1 << iota

// NewPostgreSQL returns a new PostgreSQL instance.
// connectionString should be sqlx compatible.
func NewPostgreSQL(connectionString string) (*PostgreSQL, error) {
Expand Down Expand Up @@ -254,19 +258,24 @@ func (b *basePsql[T]) ClearAll() QuerySet[T] {
}

// Create creates a new record in the database.
func (b *basePsql[T]) Create(x *T) error {
func (b *basePsql[T]) Create(x *T, options ...CreateOption) error {
if b.err != nil {
return b.err
}

origIgnoreOrderBy := b.ignoreOrderBy
b.ignoreOrderBy = true
q, args := b.CreateQuery(x)

q, args := b.CreateQuery(x, options...)
b.ignoreOrderBy = origIgnoreOrderBy

// Execute query
err := b.psql.Queryable().Get(x, q, args...)
if err != nil {
// ignore no rows in resultset error when ignoreConflict is set to true, this is a normal case
if errors.Is(err, sql.ErrNoRows) && (InsertOnConflictionDoNothing&getOption(options...) != 0) {
return nil
}
return err
}

Expand Down Expand Up @@ -468,8 +477,8 @@ func (b *basePsql[T]) ResetOrderBy() QuerySet[T] {
}

// CreateQuery returns the query and arguments for Create
func (b *basePsql[T]) CreateQuery(x *T) (string, []interface{}) {
q, args := b.psqlCreateQuery(x)
func (b *basePsql[T]) CreateQuery(x *T, options ...CreateOption) (string, []interface{}) {
q, args := b.psqlCreateQuery(x, options...)
logger.Debugf("Pika query: %s", q)

return q, args
Expand Down Expand Up @@ -1126,7 +1135,7 @@ func (b *basePsql[T]) psqlCountQuery() string {
return q
}

func (b *basePsql[T]) psqlCreateQuery(value *T) (string, []any) {
func (b *basePsql[T]) psqlCreateQuery(value *T, options ...CreateOption) (string, []any) {
// Get info from metadata
tableName := b.metadata[PikaMetadataTableName]
modelName := b.metadata[pikaMetadataModelName]
Expand Down Expand Up @@ -1178,7 +1187,12 @@ func (b *basePsql[T]) psqlCreateQuery(value *T) (string, []any) {
// Remove the model name prefix from the select list
// since we are inserting into the table
selectList = strings.Replace(selectList, fmt.Sprintf("\"%s\".", modelName), "", -1)
q := fmt.Sprintf("INSERT INTO \"%s\" (%s) VALUES (%s) RETURNING %s", tableName, columnStr, valueStr, selectList)
conflict := ""

if InsertOnConflictionDoNothing&getOption(options...) != 0 {
conflict = " ON CONFLICT DO NOTHING"
}
q := fmt.Sprintf("INSERT INTO \"%s\" (%s) VALUES (%s)%s RETURNING %s", tableName, columnStr, valueStr, conflict, selectList)

// Convert value to arguments
args := make([]interface{}, 0, ref.Elem().NumField())
Expand Down Expand Up @@ -1425,3 +1439,12 @@ func generateRangeSlice(start, length int) *[]int {
}
return &ret
}

func getOption(options ...CreateOption) CreateOption {
var option CreateOption
for _, o := range options {
option |= o
}

return option
}
47 changes: 47 additions & 0 deletions pika_psql_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1144,6 +1144,53 @@ func TestCreate(t *testing.T) {
require.Equal(t, entry.Description, x[0].Description)
}

func TestCreateIgnore(t *testing.T) {
psql := newPsql(t)
createTestModelCreate(t, psql)
qs := Q[simpleModelCreate](psql)

// Create a new entry
entry := simpleModelCreate{
ID: 1,
Title: "test",
Description: "test-description",
}

expectedQuery := `INSERT INTO "simple_model_create" ("id", "title", "description") VALUES ($1, $2, $3) ON CONFLICT DO NOTHING RETURNING "id", "title", "description"`
expectedArgs := []interface{}{1, "test", "test-description"}
actualQuery, actualArgs := qs.CreateQuery(&entry, InsertOnConflictionDoNothing)
require.Equal(t, expectedQuery, actualQuery)
require.Equal(t, expectedArgs, actualArgs)

err := qs.Create(&entry)
require.Nil(t, err)
require.Equal(t, 1, entry.ID)

// create again will cause duplication error
err = qs.Create(&entry)
require.NotNil(t, err)
require.Contains(t, err.Error(), "duplicate key value violates unique constraint")

// createignore will not throw out the error
err = qs.Create(&entry, InsertOnConflictionDoNothing)
require.Nil(t, err)

// Select the entry and check if it is the same
qs = Q[simpleModelCreate](psql)
args := NewArgs()
args.Set("id", entry.ID)
args.Set("title", entry.Title)
args.Set("description", entry.Description)
qs = qs.Filter("id=:id", "title=:title", "description=:description").Args(args)

x, err := qs.All()
require.Nil(t, err)
require.Equal(t, 1, len(x))
require.Equal(t, entry.ID, x[0].ID)
require.Equal(t, entry.Title, x[0].Title)
require.Equal(t, entry.Description, x[0].Description)
}

func TestUpdate(t *testing.T) {
psql := newPsql(t)
createTestModelCreate(t, psql)
Expand Down

0 comments on commit 131e7ce

Please sign in to comment.