diff --git a/config_test.go b/config_test.go index 4ac3876..eac6c83 100644 --- a/config_test.go +++ b/config_test.go @@ -8,7 +8,12 @@ import ( "github.com/stretchr/testify/require" ) -func createDefaultTestConfig(t *testing.T) *config { +type testConfigParam interface { + *testing.T | *testing.B + TempDir() string +} + +func createDefaultTestConfig[V testConfigParam](t V) *config { c := createDefaultConfig() dir := t.TempDir() c.Db.File = filepath.Join(dir, "blog.db") diff --git a/markdown_test.go b/markdown_test.go index 91ae261..3a8da00 100644 --- a/markdown_test.go +++ b/markdown_test.go @@ -100,11 +100,7 @@ func Benchmark_markdown(b *testing.B) { mdExp := string(markdownExample) app := &goBlog{ - cfg: &config{ - Server: &configServer{ - PublicAddress: "https://example.com", - }, - }, + cfg: createDefaultTestConfig(b), } app.initMarkdown() diff --git a/reactions.go b/reactions.go index b0c1ba8..f6948df 100644 --- a/reactions.go +++ b/reactions.go @@ -4,11 +4,11 @@ import ( "errors" "io" "net/http" + "strings" "time" "github.com/dgraph-io/ristretto" "github.com/samber/lo" - "go.goblog.app/app/pkgs/builderpool" "go.goblog.app/app/pkgs/contenttype" ) @@ -16,12 +16,9 @@ const reactionsCacheTTL = 6 * time.Hour // Hardcoded for now var allowedReactions = []string{ - "❤️", - "👍", - "🎉", - "😂", - "😱", + "❤️", "👍", "🎉", "😂", "😱", } +var allowedReactionsStr = strings.Join(allowedReactions, "") func (a *goBlog) reactionsEnabled() bool { return a.cfg.Reactions != nil && a.cfg.Reactions.Enabled @@ -98,6 +95,10 @@ func (a *goBlog) getReactions(w http.ResponseWriter, r *http.Request) { io.WriteString(w, reactions) } +const reactionsQuery = "select json_group_object(reaction, count) as json_result from (" + + "select reaction, count from reactions where path = ? and instr(?, reaction) > 0 " + + "and path not in (select path from post_parameters where parameter=? and value=?) and count > 0)" + func (a *goBlog) getReactionsFromDatabase(path string) (string, error) { // Init a.initReactions() @@ -108,24 +109,7 @@ func (a *goBlog) getReactionsFromDatabase(path string) (string, error) { } // Get reactions res, err, _ := a.reactionsSfg.Do(path, func() (any, error) { - // Build query - sqlBuf := builderpool.Get() - defer builderpool.Put(sqlBuf) - sqlArgs := []any{} - sqlBuf.WriteString("select json_group_object(reaction, count) as json_result from (") - sqlBuf.WriteString("select reaction, count from reactions where path=? and reaction in (") - sqlArgs = append(sqlArgs, path) - for i, reaction := range allowedReactions { - if i > 0 { - sqlBuf.WriteString(",") - } - sqlBuf.WriteString("?") - sqlArgs = append(sqlArgs, reaction) - } - sqlBuf.WriteString(") and path not in (select path from post_parameters where parameter=? and value=?) and count > 0)") - sqlArgs = append(sqlArgs, reactionsPostParam, "false") - // Execute query - row, err := a.db.QueryRow(sqlBuf.String(), sqlArgs...) + row, err := a.db.QueryRow(reactionsQuery, path, allowedReactionsStr, reactionsPostParam, "false") if err != nil { return nil, err } diff --git a/reactions_test.go b/reactions_test.go index 1ce0254..d015f4e 100644 --- a/reactions_test.go +++ b/reactions_test.go @@ -15,6 +15,7 @@ func Test_reactionsLowLevel(t *testing.T) { app := &goBlog{ cfg: createDefaultTestConfig(t), } + app.cfg.Reactions = &configReactions{Enabled: true} _ = app.initConfig(false) _ = app.initCache() @@ -94,6 +95,7 @@ func Test_reactionsHighLevel(t *testing.T) { app := &goBlog{ cfg: createDefaultTestConfig(t), } + app.cfg.Reactions = &configReactions{Enabled: true} _ = app.initConfig(false) app.initMarkdown()