Skip to content

Commit

Permalink
Improve reactions query using temporary table
Browse files Browse the repository at this point in the history
  • Loading branch information
jlelse committed Jul 26, 2024
1 parent f6a6ec2 commit cbe2e00
Show file tree
Hide file tree
Showing 4 changed files with 54 additions and 37 deletions.
1 change: 0 additions & 1 deletion app.go
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,6 @@ type goBlog struct {
profileImageHashString string
profileImageHashGroup singleflight.Group
// Reactions
reactionsInit sync.Once
reactionsCache *ristretto.Cache
reactionsSfg singleflight.Group
// Regex Redirects
Expand Down
4 changes: 4 additions & 0 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,10 @@ func (app *goBlog) initComponents() {
app.logErrAndQuit("Failed to init HTTP logging", "err", err)
return
}
if err = app.initReactions(); err != nil {
app.logErrAndQuit("Failed to init reactions", "err", err)
return
}
if err = app.initActivityPub(); err != nil {
app.logErrAndQuit("Failed to init ActivityPub", "err", err)
return
Expand Down
69 changes: 35 additions & 34 deletions reactions.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,14 @@ package main

import (
"errors"
"fmt"
"io"
"net/http"
"strings"
"time"

"github.com/dgraph-io/ristretto"
"github.com/samber/lo"
"go.goblog.app/app/pkgs/builderpool"
"go.goblog.app/app/pkgs/contenttype"
)

Expand All @@ -33,22 +34,38 @@ func (a *goBlog) reactionsEnabledForPost(post *post) bool {
return a.reactionsEnabled() && post != nil && post.firstParameter(reactionsPostParam) != "false"
}

func (a *goBlog) initReactions() {
a.reactionsInit.Do(func() {
if !a.reactionsEnabled() {
return
}
a.reactionsCache, _ = ristretto.NewCache(&ristretto.Config{
NumCounters: 1000,
MaxCost: 100, // Cache reactions for 100 posts
BufferItems: 64,
IgnoreInternalCost: true,
})
func (a *goBlog) initReactions() (err error) {
if !a.reactionsEnabled() {
return nil
}
a.reactionsCache, err = ristretto.NewCache(&ristretto.Config{
NumCounters: 1000,
MaxCost: 100, // Cache reactions for 100 posts
BufferItems: 64,
IgnoreInternalCost: true,
})
if err != nil {
return err
}
return a.createAllowedReactionsTempTable()
}

func (a *goBlog) createAllowedReactionsTempTable() error {
_, err := a.db.Exec("create temp table if not exists temp_allowed_reactions (r text)")
if err != nil {
return err
}
values := make([]string, len(allowedReactions))
args := make([]any, len(allowedReactions))
for i, reaction := range allowedReactions {
values[i] = "(?)"
args[i] = reaction
}
_, err = a.db.Exec(fmt.Sprintf(`insert into temp_allowed_reactions (r) values %s`, strings.Join(values, ", ")), args...)
return err
}

func (a *goBlog) deleteReactionsCache(path string) {
a.initReactions()
if a.reactionsCache != nil {
a.reactionsCache.Del(path)
}
Expand Down Expand Up @@ -76,8 +93,6 @@ func (a *goBlog) saveReaction(reaction, path string) error {
if !lo.Contains(allowedReactions, reaction) {
return errors.New("reaction not allowed")
}
// Init
a.initReactions()
// Delete from cache
defer a.reactionsSfg.Forget(path)
defer a.reactionsCache.Del(path)
Expand All @@ -99,33 +114,19 @@ func (a *goBlog) getReactions(w http.ResponseWriter, r *http.Request) {
}

func (a *goBlog) getReactionsFromDatabase(path string) (string, error) {
// Init
a.initReactions()
// Check cache
if val, cached := a.reactionsCache.Get(path); cached {
// Return from cache
return val.(string), nil
}
// Get reactions
res, err, _ := a.reactionsSfg.Do(path, func() (any, error) {
// Build query
sqlBuf := builderpool.Get()
defer builderpool.Put(sqlBuf)
sqlArgs := []any{}
sqlBuf.WriteString("select json_group_object(reaction, count) as json_result from (")
sqlBuf.WriteString("select reaction, count from reactions where path=? and reaction in (")
sqlArgs = append(sqlArgs, path)
for i, reaction := range allowedReactions {
if i > 0 {
sqlBuf.WriteString(",")
}
sqlBuf.WriteString("?")
sqlArgs = append(sqlArgs, reaction)
}
sqlBuf.WriteString(") and path not in (select path from post_parameters where parameter=? and value=?) and count > 0)")
sqlArgs = append(sqlArgs, reactionsPostParam, "false")
// Execute query
row, err := a.db.QueryRow(sqlBuf.String(), sqlArgs...)
row, err := a.db.QueryRow(`
select json_group_object(reaction, count) as json_result from (
select reaction, count from reactions where path=? and reaction in ( select r from temp_allowed_reactions )
and path not in (select path from post_parameters where parameter=? and value=?))
`, path, reactionsPostParam, "false")
if err != nil {
return nil, err
}
Expand Down
17 changes: 15 additions & 2 deletions reactions_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,17 @@ func Test_reactionsLowLevel(t *testing.T) {
app := &goBlog{
cfg: createDefaultTestConfig(t),
}
app.cfg.Reactions = &configReactions{
Enabled: true,
}

_ = app.initConfig(false)
_ = app.initCache()

err := app.saveReaction("🖕", "/testpost")
err := app.initReactions()
require.NoError(t, err)

err = app.saveReaction("🖕", "/testpost")
assert.ErrorContains(t, err, "not allowed")

err = app.saveReaction("❤️", "/testpost")
Expand Down Expand Up @@ -94,12 +100,18 @@ func Test_reactionsHighLevel(t *testing.T) {
app := &goBlog{
cfg: createDefaultTestConfig(t),
}
app.cfg.Reactions = &configReactions{
Enabled: true,
}

_ = app.initConfig(false)
app.initMarkdown()
app.initSessions()
_ = app.initCache()

err := app.initReactions()
require.NoError(t, err)

// Send unsuccessful reaction
form := url.Values{
"reaction": {"❤️"},
Expand All @@ -112,7 +124,7 @@ func Test_reactionsHighLevel(t *testing.T) {
assert.Equal(t, http.StatusBadRequest, rec.Code)

// Create a post
err := app.createPost(&post{
err = app.createPost(&post{
Path: "/testpost",
Content: "test",
})
Expand All @@ -128,6 +140,7 @@ func Test_reactionsHighLevel(t *testing.T) {
rec = httptest.NewRecorder()
app.postReaction(rec, req)
assert.Equal(t, http.StatusOK, rec.Code)
assert.Equal(t, "{\"❤️\":1}", rec.Body.String())

// Check if reaction count is 1
req = httptest.NewRequest(http.MethodGet, "/?path=/testpost", nil)
Expand Down

0 comments on commit cbe2e00

Please sign in to comment.