Skip to content

Commit

Permalink
Merge pull request #129 from gandaldf/master
Browse files Browse the repository at this point in the history
Upgrade GORM version and documentation update
  • Loading branch information
alexedwards authored Nov 29, 2021
2 parents b70d0e0 + 3fa2cbf commit 2e73121
Show file tree
Hide file tree
Showing 8 changed files with 379 additions and 133 deletions.
14 changes: 8 additions & 6 deletions gormstore/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -32,18 +32,20 @@ import (
"log"
"net/http"

"github.com/alexedwards/scs/v2"
"github.com/alexedwards/scs/gormstore"
"github.com/jinzhu/gorm"

_ "github.com/jinzhu/gorm/dialects/postgres"
"github.com/alexedwards/scs/v2"
"gorm.io/driver/postgres"
"gorm.io/gorm"
)

var sessionManager *scs.SessionManager

func main() {
// Establish connection to your store (PostgreSQL here).
db, err := gorm.Open("postgres", "postgres://user:pass@localhost/db")
// Establish connection to your store.
db, err := gorm.Open(postgres.Open("postgres://username:password@host/dbname", &gorm.Config{})) // PostgreSQL
//db, err := gorm.Open(sqlserver.Open("sqlserver://username:password@host?database=dbname", &gorm.Config{})) // MSSQL
//db, err := gorm.Open(mysql.Open(username:password@tcp(host)/dbname?parseTime=true", &gorm.Config{})) // MySQL
//db, err := gorm.Open(sqlite.Open("sqlite3_database.db"), &gorm.Config{})) // SQLite3
if err != nil {
log.Fatal(err)
}
Expand Down
8 changes: 7 additions & 1 deletion gormstore/go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,10 @@ module github.com/alexedwards/scs/gormstore

go 1.12

require github.com/jinzhu/gorm v1.9.12
require (
gorm.io/driver/mysql v1.2.0
gorm.io/driver/postgres v1.2.2
gorm.io/driver/sqlite v1.2.6
gorm.io/driver/sqlserver v1.2.1
gorm.io/gorm v1.22.3
)
197 changes: 182 additions & 15 deletions gormstore/go.sum

Large diffs are not rendered by default.

114 changes: 67 additions & 47 deletions gormstore/gormstore.go
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
package gormstore

import (
"errors"
"log"
"time"

"github.com/jinzhu/gorm"
"gorm.io/gorm"
)

// GORMStore represents the session store.
Expand All @@ -14,7 +15,7 @@ type GORMStore struct {
}

type session struct {
Token string `gorm:"primary_key;type:varchar(100)"`
Token string `gorm:"primary_key;type:varchar(43)"`
Data []byte
Expiry time.Time `gorm:"index"`
}
Expand All @@ -30,81 +31,100 @@ func New(db *gorm.DB) (*GORMStore, error) {
// background cleanup goroutine. Setting it to 0 prevents the cleanup goroutine
// from running (i.e. expired sessions will not be removed).
func NewWithCleanupInterval(db *gorm.DB, cleanupInterval time.Duration) (*GORMStore, error) {
p := &GORMStore{db: db}
if err := p.migrate(); err != nil {
g := &GORMStore{db: db}
if err := g.migrate(); err != nil {
return nil, err
}
if cleanupInterval > 0 {
go p.startCleanup(cleanupInterval)
go g.startCleanup(cleanupInterval)
}
return p, nil
return g, nil
}

// Find returns the data for a given session token from the PostgresStore instance.
// Find returns the data for a given session token from the GORMStore instance.
// If the session token is not found or is expired, the returned exists flag will
// be set to false.
func (p *GORMStore) Find(token string) ([]byte, bool, error) {
row := &session{}
sess := p.db.First(row, "token = ? AND expiry >= ?", token, time.Now())
if sess.RecordNotFound() {
func (g *GORMStore) Find(token string) (b []byte, exists bool, err error) {
s := &session{}
row := g.db.First(s, "token = ? AND expiry >= ?", token, time.Now())
if errors.Is(row.Error, gorm.ErrRecordNotFound) {
return nil, false, nil
} else if errs := sess.GetErrors(); len(errs) > 0 {
return nil, false, errs[0]
} else if row.Error != nil {
return nil, false, row.Error
}
if row == nil {

}
return row.Data, true, nil
return s.Data, true, nil
}

// Commit adds a session token and data to the PostgresStore instance with the
// Commit adds a session token and data to the GORMStore instance with the
// given expiry time. If the session token already exists, then the data and expiry
// time are updated.
func (p *GORMStore) Commit(token string, b []byte, expiry time.Time) error {
row := &session{}
sess := p.db.Where(session{Token: token}).Assign(session{Data: b, Expiry: expiry}).FirstOrCreate(&row)
if errs := sess.GetErrors(); len(errs) > 0 {
return errs[0]
func (g *GORMStore) Commit(token string, b []byte, expiry time.Time) error {
s := &session{}
row := g.db.Where(session{Token: token}).Assign(session{Data: b, Expiry: expiry}).FirstOrCreate(s)
if row.Error != nil {
return row.Error
}
return nil
}

// Delete removes a session token and corresponding data from the PostgresStore
// Delete removes a session token and corresponding data from the GORMStore
// instance.
func (p *GORMStore) Delete(token string) error {
sess := p.db.Delete(&session{}, "token = ?", token)
if errs := sess.GetErrors(); len(errs) > 0 {
return errs[0]
func (g *GORMStore) Delete(token string) error {
row := g.db.Delete(&session{}, "token = ?", token)
if row.Error != nil {
return row.Error
}
return nil
}

func (p *GORMStore) migrate() error {
// All returns a map containing the token and data for all active (i.e.
// not expired) sessions in the GORMStore instance.
func (g *GORMStore) All() (map[string][]byte, error) {
rows, err := g.db.Find(&[]session{}, "expiry >= ?", time.Now()).Rows()
if err != nil {
return nil, err
}
defer rows.Close()
ss := make(map[string][]byte)
for rows.Next() {
s := &session{}
err := g.db.ScanRows(rows, s)
if err != nil {
return nil, err
}
ss[s.Token] = s.Data
}
err = rows.Err()
if err != nil {
return nil, err
}
return ss, nil
}

func (g *GORMStore) migrate() error {
var tableOptions string
// Set table options for MySQL database dialect
if p.db.Dialect().GetName() == "mysql" {
// Set table options for MySQL database dialect.
if g.db.Dialector.Name() == "mysql" {
tableOptions = "ENGINE=InnoDB CHARSET=utf8mb4"
}

sess := p.db.Set("gorm:table_options", tableOptions).
AutoMigrate(&session{})
if errs := sess.GetErrors(); len(errs) > 0 {
return errs[0]
err := g.db.Set("gorm:table_options", tableOptions).AutoMigrate(&session{})
if err != nil {
return err
}
return nil
}

func (p *GORMStore) startCleanup(interval time.Duration) {
p.stopCleanup = make(chan bool)
func (g *GORMStore) startCleanup(interval time.Duration) {
g.stopCleanup = make(chan bool)
ticker := time.NewTicker(interval)
for {
select {
case <-ticker.C:
err := p.deleteExpired()
err := g.deleteExpired()
if err != nil {
log.Println(err)
}
case <-p.stopCleanup:
case <-g.stopCleanup:
ticker.Stop()
return
}
Expand All @@ -121,16 +141,16 @@ func (p *GORMStore) startCleanup(interval time.Duration) {
// scenario, the cleanup goroutine (which will run forever) will prevent the
// GORMStore object from being garbage collected even after the test function
// has finished. You can prevent this by manually calling StopCleanup.
func (p *GORMStore) StopCleanup() {
if p.stopCleanup != nil {
p.stopCleanup <- true
func (g *GORMStore) StopCleanup() {
if g.stopCleanup != nil {
g.stopCleanup <- true
}
}

func (p *GORMStore) deleteExpired() error {
sess := p.db.Delete(&session{}, "expiry < ?", time.Now())
if errs := sess.GetErrors(); len(errs) > 0 {
return errs[0]
func (g *GORMStore) deleteExpired() error {
row := g.db.Delete(&session{}, "expiry < ?", time.Now())
if row.Error != nil {
return row.Error
}
return nil
}
Loading

0 comments on commit 2e73121

Please sign in to comment.