Skip to content

Commit

Permalink
Refactor sql and casandra
Browse files Browse the repository at this point in the history
  • Loading branch information
minhduc140583 committed Oct 10, 2021
1 parent 630ab6a commit ce1d024
Show file tree
Hide file tree
Showing 5 changed files with 179 additions and 106 deletions.
36 changes: 15 additions & 21 deletions cassandra/authentication_repository.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ import (
)

type AuthenticationRepository struct {
UserCassandra *gocql.ClusterConfig
Session *gocql.Session
userTableName string
passwordTableName string
CheckTwoFactors func(ctx context.Context, id string) (bool, error)
Expand Down Expand Up @@ -41,17 +41,17 @@ type AuthenticationRepository struct {
TwoFactorsName string
}

func NewAuthenticationRepositoryByConfig(db *gocql.ClusterConfig, userTableName, passwordTableName string, activatedStatus string, status auth.UserStatusConfig, c auth.SchemaConfig, options ...func(context.Context, string) (bool, error)) *AuthenticationRepository {
return NewAuthenticationRepository(db, userTableName, passwordTableName, activatedStatus, status, c.Id, c.Username, c.UserId, c.SuccessTime, c.FailTime, c.FailCount, c.LockedUntilTime, c.Status, c.PasswordChangedTime, c.Password, c.Contact, c.Email, c.Phone, c.DisplayName, c.MaxPasswordAge, c.UserType, c.AccessDateFrom, c.AccessDateTo, c.AccessTimeFrom, c.AccessTimeTo, c.TwoFactors, options...)
func NewAuthenticationRepositoryByConfig(session *gocql.Session, userTableName, passwordTableName string, activatedStatus string, status auth.UserStatusConfig, c auth.SchemaConfig, options ...func(context.Context, string) (bool, error)) *AuthenticationRepository {
return NewAuthenticationRepository(session, userTableName, passwordTableName, activatedStatus, status, c.Id, c.Username, c.UserId, c.SuccessTime, c.FailTime, c.FailCount, c.LockedUntilTime, c.Status, c.PasswordChangedTime, c.Password, c.Contact, c.Email, c.Phone, c.DisplayName, c.MaxPasswordAge, c.UserType, c.AccessDateFrom, c.AccessDateTo, c.AccessTimeFrom, c.AccessTimeTo, c.TwoFactors, options...)
}

func NewAuthenticationRepository(db *gocql.ClusterConfig, userTableName, passwordTableName string, activatedStatus string, status auth.UserStatusConfig, idName, userName, userID, successTimeName, failTimeName, failCountName, lockedUntilTimeName, statusName, passwordChangedTimeName, passwordName, contactName, emailName, phoneName, displayNameName, maxPasswordAgeName, userTypeName, accessDateFromName, accessDateToName, accessTimeFromName, accessTimeToName, twoFactorsName string, options ...func(context.Context, string) (bool, error)) *AuthenticationRepository {
func NewAuthenticationRepository(session *gocql.Session, userTableName, passwordTableName string, activatedStatus string, status auth.UserStatusConfig, idName, userName, userID, successTimeName, failTimeName, failCountName, lockedUntilTimeName, statusName, passwordChangedTimeName, passwordName, contactName, emailName, phoneName, displayNameName, maxPasswordAgeName, userTypeName, accessDateFromName, accessDateToName, accessTimeFromName, accessTimeToName, twoFactorsName string, options ...func(context.Context, string) (bool, error)) *AuthenticationRepository {
var checkTwoFactors func(context.Context, string) (bool, error)
if len(options) >= 1 {
checkTwoFactors = options[0]
}
return &AuthenticationRepository{
UserCassandra: db,
Session: session,
userTableName: strings.ToLower(userTableName),
passwordTableName: strings.ToLower(passwordTableName),
CheckTwoFactors: checkTwoFactors,
Expand Down Expand Up @@ -82,11 +82,8 @@ func NewAuthenticationRepository(db *gocql.ClusterConfig, userTableName, passwor
}

func (r *AuthenticationRepository) GetUserInfo(ctx context.Context, user string) (*auth.UserInfo, error) {
session := r.Session
userInfo := auth.UserInfo{}
session, er0 := r.UserCassandra.CreateSession()
if er0 != nil {
return nil, er0
}
query := "SELECT * FROM " + r.userTableName + " WHERE " + r.UserName + " = ? ALLOW FILTERING"
raws := session.Query(query, user).Iter()
for {
Expand Down Expand Up @@ -210,6 +207,7 @@ func (r *AuthenticationRepository) PassAndActivate(ctx context.Context, userId s
}

func (r *AuthenticationRepository) passAuthenticationAndActivate(ctx context.Context, userId string, updateStatus bool) (int64, error) {
session := r.Session
if len(r.SuccessTimeName) == 0 && len(r.FailCountName) == 0 && len(r.LockedUntilTimeName) == 0 {
if !updateStatus {
return 0, nil
Expand All @@ -231,23 +229,23 @@ func (r *AuthenticationRepository) passAuthenticationAndActivate(ctx context.Con
r.IdName: userId,
}
if !updateStatus {
return patch(ctx, r.UserCassandra, r.passwordTableName, pass, query)
return patch(ctx, session, r.passwordTableName, pass, query)
}

if r.userTableName == r.passwordTableName {
pass[r.StatusName] = r.activatedStatus
return patch(ctx, r.UserCassandra, r.passwordTableName, pass, query)
return patch(ctx, session, r.passwordTableName, pass, query)
}

k1, err := patch(ctx, r.UserCassandra, r.passwordTableName, pass, query)
k1, err := patch(ctx, session, r.passwordTableName, pass, query)
if err != nil {
return k1, err
}

user := make(map[string]interface{})
user[r.IdName] = userId
user[r.StatusName] = r.activatedStatus
k2, err1 := patch(ctx, r.UserCassandra, r.userTableName, user, query)
k2, err1 := patch(ctx, session, r.userTableName, user, query)
return k1 + k2, err1
}

Expand All @@ -269,15 +267,11 @@ func (r *AuthenticationRepository) Fail(ctx context.Context, userId string, fail
query := map[string]interface{}{
r.IdName: userId,
}
_, err := patch(ctx, r.UserCassandra, r.passwordTableName, pass, query)
_, err := patch(ctx, r.Session, r.passwordTableName, pass, query)
return err
}

func patch(ctx context.Context, db *gocql.ClusterConfig, table string, model map[string]interface{}, query map[string]interface{}) (int64, error) {
session, er0 := db.CreateSession()
if er0 != nil {
return 0, er0
}
func patch(ctx context.Context, session *gocql.Session, table string, model map[string]interface{}, query map[string]interface{}) (int64, error) {
keyUpdate := ""
keyValue := ""
for k, v := range query {
Expand All @@ -296,13 +290,13 @@ func patch(ctx context.Context, db *gocql.ClusterConfig, table string, model map
if !flag {
if k == "failtime" || k == "lockeduntiltime" || k == "successtime" {
queryAddCol := "ALTER TABLE " + table + " ADD " + k + " timestamp"
er0 = session.Query(queryAddCol).Exec()
er0 := session.Query(queryAddCol).Exec()
if er0 != nil {
return 0, er0
}
} else {
queryAddCol := "ALTER TABLE " + table + " ADD " + k + " int"
er0 = session.Query(queryAddCol).Exec()
er0 := session.Query(queryAddCol).Exec()
if er0 != nil {
return 0, er0
}
Expand Down
189 changes: 122 additions & 67 deletions sql/db.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package sql
import (
"context"
"database/sql"
"database/sql/driver"
"errors"
"fmt"
"reflect"
Expand All @@ -18,40 +19,6 @@ const (
driverNotSupport = "no support"
)

func query(ctx context.Context, db *sql.DB, results interface{}, sql string, values ...interface{}) ([]string, error) {
rows, er1 := db.QueryContext(ctx, sql, values...)
if er1 != nil {
return nil, er1
}
defer rows.Close()
columns, er2 := rows.Columns()
if er2 != nil {
return columns, er2
}
modelType := reflect.TypeOf(results).Elem().Elem()

fieldsIndex, er3 := getColumnIndexes(modelType)
if er3 != nil {
return columns, er3
}

tb, er4 := scans(rows, modelType, fieldsIndex)
if er4 != nil {
return columns, er4
}
for _, element := range tb {
appendToArray(results, element)
}
er5 := rows.Close()
if er5 != nil {
return columns, er5
}
// Rows.Err will report the last error encountered by Rows.Scan.
if er6 := rows.Err(); er6 != nil {
return columns, er6
}
return columns, nil
}
func getColumnIndexes(modelType reflect.Type) (map[string]int, error) {
ma := make(map[string]int, 0)
if modelType.Kind() != reflect.Struct {
Expand Down Expand Up @@ -83,33 +50,110 @@ func findTag(tag string, key string) (string, bool) {
}
return "", false
}
func scans(rows *sql.Rows, modelType reflect.Type, fieldsIndex map[string]int) (t []interface{}, err error) {
func appendToArray(arr interface{}, item interface{}) interface{} {
arrValue := reflect.ValueOf(arr)
elemValue := reflect.Indirect(arrValue)

itemValue := reflect.ValueOf(item)
if itemValue.Kind() == reflect.Ptr {
itemValue = reflect.Indirect(itemValue)
}
elemValue.Set(reflect.Append(elemValue, itemValue))
return arr
}
func queryWithMap(ctx context.Context, db *sql.DB, fieldsIndex map[string]int, results interface{}, sql string, values ...interface{}) ([]string, error) {
return queryWithMapAndArray(ctx, db, fieldsIndex, results, nil, sql, values...)
}
func queryWithMapAndArray(ctx context.Context, db *sql.DB, fieldsIndex map[string]int, results interface{}, toArray func(interface{}) interface {
driver.Valuer
sql.Scanner
}, sql string, values ...interface{}) ([]string, error) {
rows, er1 := db.QueryContext(ctx, sql, values...)
if er1 != nil {
return nil, er1
}
defer rows.Close()
columns, er2 := rows.Columns()
if er2 != nil {
return columns, er2
}
modelType := reflect.TypeOf(results).Elem().Elem()
tb, er3 := scan(rows, modelType, fieldsIndex, toArray)
if er3 != nil {
return columns, er3
}
for _, element := range tb {
appendToArray(results, element)
}
er4 := rows.Close()
if er4 != nil {
return columns, er4
}
// Rows.Err will report the last error encountered by Rows.Scan.
if er5 := rows.Err(); er5 != nil {
return columns, er5
}
return columns, nil
}
func getColumns(cols []string, err error) ([]string, error) {
if cols == nil || err != nil {
return cols, err
}
c2 := make([]string, 0)
for _, c := range cols {
s := strings.ToLower(c)
c2 = append(c2, s)
}
return c2, nil
}
func scan(rows *sql.Rows, modelType reflect.Type, fieldsIndex map[string]int, options ...func(interface{}) interface {
driver.Valuer
sql.Scanner
}) (t []interface{}, err error) {
if fieldsIndex == nil {
fieldsIndex, err = getColumnIndexes(modelType)
if err != nil {
return
}
}
var toArray func(interface{}) interface {
driver.Valuer
sql.Scanner
}
if len(options) > 0 {
toArray = options[0]
}
columns, er0 := getColumns(rows.Columns())
if er0 != nil {
return nil, er0
}
for rows.Next() {
initModel := reflect.New(modelType).Interface()
r, swapValues := structScan(initModel, columns, fieldsIndex, -1)
r, swapValues := structScan(initModel, columns, fieldsIndex, toArray)
if err = rows.Scan(r...); err == nil {
swapValuesToBool(initModel, &swapValues)
t = append(t, initModel)
}
}
return
}
func getColumns(cols []string, err error) ([]string, error) {
if cols == nil || err != nil {
return cols, err
func structScan(s interface{}, columns []string, fieldsIndex map[string]int, options ...func(interface{}) interface {
driver.Valuer
sql.Scanner
}) (r []interface{}, swapValues map[int]interface{}) {
var toArray func(interface{}) interface {
driver.Valuer
sql.Scanner
}
c2 := make([]string, 0)
for _, c := range cols {
s := strings.ToLower(c)
c2 = append(c2, s)
if len(options) > 0 {
toArray = options[0]
}
return c2, nil
return structScanAndIgnore(s, columns, fieldsIndex, toArray, -1)
}
func structScan(s interface{}, columns []string, fieldsIndex map[string]int, indexIgnore int) (r []interface{}, swapValues map[int]interface{}) {
func structScanAndIgnore(s interface{}, columns []string, fieldsIndex map[string]int, toArray func(interface{}) interface {
driver.Valuer
sql.Scanner
}, indexIgnore int) (r []interface{}, swapValues map[int]interface{}) {
if s != nil {
modelType := reflect.TypeOf(s).Elem()
swapValues = make(map[int]interface{}, 0)
Expand All @@ -128,6 +172,7 @@ func structScan(s interface{}, columns []string, fieldsIndex map[string]int, ind
}
return
}

for i, columnsName := range columns {
if i == indexIgnore {
continue
Expand All @@ -152,15 +197,19 @@ func structScan(s interface{}, columns []string, fieldsIndex map[string]int, ind
modelField = modelType.Field(index)
valueField = maps.Field(index)
}
x := valueField.Addr().Interface()
tagBool := modelField.Tag.Get("true")
if tagBool == "" {
r = append(r, valueField.Addr().Interface())
if toArray != nil && valueField.Kind() == reflect.Slice {
x = toArray(x)
}
r = append(r, x)
} else {
var str string
swapValues[index] = reflect.New(reflect.TypeOf(str)).Elem().Addr().Interface()
y := reflect.New(reflect.TypeOf(str))
swapValues[index] = y.Elem().Addr().Interface()
r = append(r, swapValues[index])
}

}
}
return
Expand All @@ -169,30 +218,36 @@ func swapValuesToBool(s interface{}, swap *map[int]interface{}) {
if s != nil {
modelType := reflect.TypeOf(s).Elem()
maps := reflect.Indirect(reflect.ValueOf(s))
for index, element := range (*swap) {
var isBool bool
boolStr := modelType.Field(index).Tag.Get("true")
var dbValue = element.(*string)
isBool = *dbValue == boolStr
if maps.Field(index).Kind() == reflect.Ptr {
maps.Field(index).Set(reflect.ValueOf(&isBool))
for index, element := range *swap {
dbValue2, ok2 := element.(*bool)
if ok2 {
if maps.Field(index).Kind() == reflect.Ptr {
maps.Field(index).Set(reflect.ValueOf(dbValue2))
} else {
maps.Field(index).SetBool(*dbValue2)
}
} else {
maps.Field(index).SetBool(isBool)
dbValue, ok := element.(*string)
if ok {
var isBool bool
if *dbValue == "true" {
isBool = true
} else if *dbValue == "false" {
isBool = false
} else {
boolStr := modelType.Field(index).Tag.Get("true")
isBool = *dbValue == boolStr
}
if maps.Field(index).Kind() == reflect.Ptr {
maps.Field(index).Set(reflect.ValueOf(&isBool))
} else {
maps.Field(index).SetBool(isBool)
}
}
}
}
}
}
func appendToArray(arr interface{}, item interface{}) interface{} {
arrValue := reflect.ValueOf(arr)
elemValue := reflect.Indirect(arrValue)

itemValue := reflect.ValueOf(item)
if itemValue.Kind() == reflect.Ptr {
itemValue = reflect.Indirect(itemValue)
}
elemValue.Set(reflect.Append(elemValue, itemValue))
return arr
}

func getDriver(db *sql.DB) string {
if db == nil {
Expand Down
Loading

0 comments on commit ce1d024

Please sign in to comment.