From baf0affeb9ec1e6283d16ddbc259ae6caa50677c Mon Sep 17 00:00:00 2001 From: Marvin Blum Date: Wed, 11 Dec 2024 17:29:49 +0100 Subject: [PATCH] Fixed concurrent data access. --- CHANGELOG.md | 1 + Makefile | 6 +++++- pkg/cms/cms.go | 49 ++++++++++++++++++++++++++------------------ pkg/cms/cms_test.go | 37 +++++++++++++++++++++++++++++++++ pkg/cms/tpl_cache.go | 11 +++++++--- 5 files changed, 80 insertions(+), 24 deletions(-) create mode 100644 pkg/cms/cms_test.go diff --git a/CHANGELOG.md b/CHANGELOG.md index 2214b65..09d1f08 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,7 @@ ## 0.10.0 +* fixed concurrent data access * updated Go version * updated dependencies diff --git a/Makefile b/Makefile index aa41c2b..584c60b 100644 --- a/Makefile +++ b/Makefile @@ -8,7 +8,11 @@ deps: go mod vendor test: - go test -cover ./pkg/... + go test -cover ./pkg/cfg + go test -cover -race $$(go list ./pkg/... | grep -v /cfg) + +benchmark: + go test -bench=. ./pkg/... build_mac: test GOOS=darwin go build -a -installsuffix cgo -ldflags "-s -w" cmd/shifu/main.go diff --git a/pkg/cms/cms.go b/pkg/cms/cms.go index 4156bf5..be394ba 100644 --- a/pkg/cms/cms.go +++ b/pkg/cms/cms.go @@ -85,16 +85,14 @@ func (cms *CMS) Serve(w http.ResponseWriter, r *http.Request) { } start := time.Now() - cms.m.RLock() path := r.URL.Path - page, ok := cms.pages[path] + page, ok := cms.getPage(path) if !ok { slog.Debug("Page not found", "path", path) - page, ok = cms.pages[notFoundPath] + page, ok = cms.getPage(notFoundPath) if !ok { - cms.m.RUnlock() return } @@ -102,32 +100,27 @@ func (cms *CMS) Serve(w http.ResponseWriter, r *http.Request) { } if page.Handler != "" { - handler, ok := cms.handler[page.Handler] + handler, ok := cms.getHandler(page.Handler) if !ok { slog.Error("Page handler not found", "path", path, "handler", page.Handler) w.WriteHeader(http.StatusInternalServerError) - cms.m.RUnlock() return } - cms.m.RUnlock() handler(cms, page, w, r) return } - cms.m.RUnlock() cms.RenderPage(w, r, path, &page) slog.Debug("Served page", "time_ms", time.Now().Sub(start).Milliseconds()) } // RenderPage renders given page and returns it to the client. func (cms *CMS) RenderPage(w http.ResponseWriter, r *http.Request, path string, page *Content) { - cms.m.RLock() cms.selectExperiments(w, r, page) if cms.redirectExperiment(w, r, page) { - cms.m.RUnlock() return } @@ -138,14 +131,15 @@ func (cms *CMS) RenderPage(w http.ResponseWriter, r *http.Request, path string, } if !page.DisableCache { + cms.m.RLock() data, ok := cms.pageCache[path] + cms.m.RUnlock() if ok { if _, err := w.Write(data); err != nil { slog.Debug("Error sending response", "path", path, "error", err) } - cms.m.RUnlock() return } } @@ -156,7 +150,6 @@ func (cms *CMS) RenderPage(w http.ResponseWriter, r *http.Request, path string, out, err := cms.renderContent(page, content) if err != nil { - cms.m.RUnlock() slog.Error("Error rendering template", "path", path, "error", err) return } @@ -164,7 +157,6 @@ func (cms *CMS) RenderPage(w http.ResponseWriter, r *http.Request, path string, buffer.Write(out) } - cms.m.RUnlock() data := buffer.Bytes() if _, err := w.Write(data); err != nil { @@ -180,10 +172,8 @@ func (cms *CMS) RenderPage(w http.ResponseWriter, r *http.Request, path string, // Render404 renders the 404 page if it exists. func (cms *CMS) Render404(w http.ResponseWriter, r *http.Request, path string) { - cms.m.RLock() - defer cms.m.RUnlock() slog.Debug("Page not found", "path", path) - page, ok := cms.pages[notFoundPath] + page, ok := cms.getPage(notFoundPath) w.WriteHeader(http.StatusNotFound) if ok { @@ -193,8 +183,6 @@ func (cms *CMS) Render404(w http.ResponseWriter, r *http.Request, path string) { // Render renders and returns the content for given page. func (cms *CMS) Render(page *Content, content []Content) template.HTML { - cms.m.RLock() - defer cms.m.RUnlock() out, err := cms.renderContent(page, content) if err != nil { @@ -222,7 +210,9 @@ func (cms *CMS) renderContent(page *Content, content []Content) ([]byte, error) for _, c := range content { if c.Ref != "" { + cms.m.RLock() ref, ok := cms.refs[c.Ref] + cms.m.RUnlock() if !ok { return nil, errors.New("reference not found") @@ -349,7 +339,9 @@ func (cms *CMS) selectExperiments(w http.ResponseWriter, r *http.Request, page * } if page.Analytics.Experiment.Name != "" { + cms.m.RLock() variants, ok := cms.pageExperiments[page.Analytics.Experiment.Name] + cms.m.RUnlock() if ok && len(variants) > 1 { selectedVariant, ok := selected[page.Analytics.Experiment.Name] @@ -388,6 +380,9 @@ func (cms *CMS) selectExperiments(w http.ResponseWriter, r *http.Request, page * func (cms *CMS) redirectExperiment(w http.ResponseWriter, r *http.Request, page *Content) bool { if page.SelectedPageExperiment != "" && page.Analytics.Experiment.Variant != page.SelectedPageExperiment { + cms.m.RLock() + defer cms.m.RUnlock() + for _, v := range cms.pages { if v.Analytics.Experiment.Name == page.Analytics.Experiment.Name && v.Analytics.Experiment.Variant == page.SelectedPageExperiment { redirect, ok := v.Path[page.Language] @@ -416,8 +411,6 @@ func (cms *CMS) pageView(r *http.Request, page *Content) { } func (cms *CMS) updateContent() { - cms.m.Lock() - defer cms.m.Unlock() pages := make(map[string]Content) refs := make(map[string]Content) pageExperiments := make(map[string][]string) @@ -481,6 +474,8 @@ func (cms *CMS) updateContent() { slog.Error("Error reading website content directory", "error", err) } + cms.m.Lock() + defer cms.m.Unlock() cms.pages = pages cms.refs = refs cms.pageExperiments = pageExperiments @@ -550,3 +545,17 @@ func (cms *CMS) extractExperiments(refs map[string]Content, content *Content, ex } } } + +func (cms *CMS) getPage(path string) (Content, bool) { + cms.m.RLock() + defer cms.m.RUnlock() + page, found := cms.pages[path] + return page, found +} + +func (cms *CMS) getHandler(name string) (Handler, bool) { + cms.m.RLock() + defer cms.m.RUnlock() + handler, found := cms.handler[name] + return handler, found +} diff --git a/pkg/cms/cms_test.go b/pkg/cms/cms_test.go new file mode 100644 index 0000000..534675e --- /dev/null +++ b/pkg/cms/cms_test.go @@ -0,0 +1,37 @@ +package cms + +import ( + "context" + "github.com/emvi/shifu/pkg/cfg" + "github.com/emvi/shifu/pkg/sitemap" + "github.com/emvi/shifu/pkg/source" + "net/http" + "net/http/httptest" + "sync" + "testing" +) + +func BenchmarkCMS(b *testing.B) { + cfg.Get().BaseDir = "../../demo" + c := NewCMS(Options{ + Ctx: context.Background(), + BaseDir: "../../demo", + HotReload: true, + FuncMap: defaultFuncMap, + Source: source.NewFS("../../demo", 1), + Sitemap: sitemap.New(), + }) + var wg sync.WaitGroup + wg.Add(b.N) + + for i := 0; i < b.N; i++ { + go func() { + w := httptest.NewRecorder() + r := httptest.NewRequest(http.MethodGet, "/", nil) + c.Serve(w, r) + wg.Done() + }() + } + + wg.Wait() +} diff --git a/pkg/cms/tpl_cache.go b/pkg/cms/tpl_cache.go index 208cb85..ff9d1ae 100644 --- a/pkg/cms/tpl_cache.go +++ b/pkg/cms/tpl_cache.go @@ -57,9 +57,13 @@ func (cache *Cache) Render(name string, data any) ([]byte, error) { return buffer.Bytes(), nil } -// Get returns the HTML template or loads it in case the cache is disabled or it hasn't been loaded yet. +// Get returns the HTML template or loads it in case the cache is disabled, or it hasn't been loaded yet. func (cache *Cache) Get() *template.Template { - if cache.disabled || !cache.loaded { + cache.m.RLock() + load := cache.disabled || !cache.loaded + cache.m.RUnlock() + + if load { if err := cache.loadTemplate(); err != nil { slog.Error("Error refreshing template files from directory", "error", err, "directory", cache.dir) panic(err) @@ -68,7 +72,8 @@ func (cache *Cache) Get() *template.Template { cache.m.RLock() defer cache.m.RUnlock() - return &cache.temp + t := cache.temp + return &t } func (cache *Cache) loadTemplate() error {