diff --git a/sql/db.go b/sql/db.go new file mode 100644 index 0000000..90197f0 --- /dev/null +++ b/sql/db.go @@ -0,0 +1,241 @@ +package sql + +import ( + "context" + "database/sql" + "errors" + "fmt" + "reflect" + "strings" +) + +const ( + driverPostgres = "postgres" + driverMysql = "mysql" + driverMssql = "mssql" + driverOracle = "oracle" + driverSqlite3 = "sqlite3" + driverNotSupport = "no support" +) + +func query(ctx context.Context, db *sql.DB, results interface{}, sql string, values ...interface{}) ([]string, error) { + rows, er1 := db.QueryContext(ctx, sql, values...) + if er1 != nil { + return nil, er1 + } + defer rows.Close() + columns, er2 := rows.Columns() + if er2 != nil { + return columns, er2 + } + modelType := reflect.TypeOf(results).Elem().Elem() + + fieldsIndex, er3 := getColumnIndexes(modelType) + if er3 != nil { + return columns, er3 + } + + tb, er4 := scans(rows, modelType, fieldsIndex) + if er4 != nil { + return columns, er4 + } + for _, element := range tb { + appendToArray(results, element) + } + er5 := rows.Close() + if er5 != nil { + return columns, er5 + } + // Rows.Err will report the last error encountered by Rows.Scan. + if er6 := rows.Err(); er6 != nil { + return columns, er6 + } + return columns, nil +} +func getColumnIndexes(modelType reflect.Type) (map[string]int, error) { + ma := make(map[string]int, 0) + if modelType.Kind() != reflect.Struct { + return ma, errors.New("bad type") + } + for i := 0; i < modelType.NumField(); i++ { + field := modelType.Field(i) + ormTag := field.Tag.Get("gorm") + column, ok := findTag(ormTag, "column") + column = strings.ToLower(column) + if ok { + ma[column] = i + } + } + return ma, nil +} +func findTag(tag string, key string) (string, bool) { + if has := strings.Contains(tag, key); has { + str1 := strings.Split(tag, ";") + num := len(str1) + for i := 0; i < num; i++ { + str2 := strings.Split(str1[i], ":") + for j := 0; j < len(str2); j++ { + if str2[j] == key { + return str2[j+1], true + } + } + } + } + return "", false +} +func scans(rows *sql.Rows, modelType reflect.Type, fieldsIndex map[string]int) (t []interface{}, err error) { + columns, er0 := getColumns(rows.Columns()) + if er0 != nil { + return nil, er0 + } + for rows.Next() { + initModel := reflect.New(modelType).Interface() + r, swapValues := structScan(initModel, columns, fieldsIndex, -1) + if err = rows.Scan(r...); err == nil { + swapValuesToBool(initModel, &swapValues) + t = append(t, initModel) + } + } + return +} +func getColumns(cols []string, err error) ([]string, error) { + if cols == nil || err != nil { + return cols, err + } + c2 := make([]string, 0) + for _, c := range cols { + s := strings.ToLower(c) + c2 = append(c2, s) + } + return c2, nil +} +func structScan(s interface{}, columns []string, fieldsIndex map[string]int, indexIgnore int) (r []interface{}, swapValues map[int]interface{}) { + if s != nil { + modelType := reflect.TypeOf(s).Elem() + swapValues = make(map[int]interface{}, 0) + maps := reflect.Indirect(reflect.ValueOf(s)) + + if columns == nil { + for i := 0; i < maps.NumField(); i++ { + tagBool := modelType.Field(i).Tag.Get("true") + if tagBool == "" { + r = append(r, maps.Field(i).Addr().Interface()) + } else { + var str string + swapValues[i] = reflect.New(reflect.TypeOf(str)).Elem().Addr().Interface() + r = append(r, swapValues[i]) + } + } + return + } + for i, columnsName := range columns { + if i == indexIgnore { + continue + } + var index int + var ok bool + var modelField reflect.StructField + var valueField reflect.Value + if fieldsIndex == nil { + if modelField, ok = modelType.FieldByName(columnsName); !ok { + var t interface{} + r = append(r, &t) + continue + } + valueField = maps.FieldByName(columnsName) + } else { + if index, ok = fieldsIndex[columnsName]; !ok { + var t interface{} + r = append(r, &t) + continue + } + modelField = modelType.Field(index) + valueField = maps.Field(index) + } + tagBool := modelField.Tag.Get("true") + if tagBool == "" { + r = append(r, valueField.Addr().Interface()) + } else { + var str string + swapValues[index] = reflect.New(reflect.TypeOf(str)).Elem().Addr().Interface() + r = append(r, swapValues[index]) + } + + } + } + return +} +func swapValuesToBool(s interface{}, swap *map[int]interface{}) { + if s != nil { + modelType := reflect.TypeOf(s).Elem() + maps := reflect.Indirect(reflect.ValueOf(s)) + for index, element := range (*swap) { + var isBool bool + boolStr := modelType.Field(index).Tag.Get("true") + var dbValue = element.(*string) + isBool = *dbValue == boolStr + if maps.Field(index).Kind() == reflect.Ptr { + maps.Field(index).Set(reflect.ValueOf(&isBool)) + } else { + maps.Field(index).SetBool(isBool) + } + } + } +} +func appendToArray(arr interface{}, item interface{}) interface{} { + arrValue := reflect.ValueOf(arr) + elemValue := reflect.Indirect(arrValue) + + itemValue := reflect.ValueOf(item) + if itemValue.Kind() == reflect.Ptr { + itemValue = reflect.Indirect(itemValue) + } + elemValue.Set(reflect.Append(elemValue, itemValue)) + return arr +} + +func getDriver(db *sql.DB) string { + if db == nil { + return driverNotSupport + } + driver := reflect.TypeOf(db.Driver()).String() + switch driver { + case "*pq.Driver": + return driverPostgres + case "*godror.drv": + return driverOracle + case "*mysql.MySQLDriver": + return driverMysql + case "*mssql.Driver": + return driverMssql + case "*sqlite3.SQLiteDriver": + return driverSqlite3 + default: + return driverNotSupport + } +} +func replaceQueryArgs(driver string, query string) string { + if driver == driverOracle || driver == driverPostgres || driver == driverMssql { + var x string + if driver == driverOracle { + x = ":val" + } else if driver == driverPostgres { + x = "$" + } else if driver == driverMssql { + x = "@p" + } + i := 1 + k := strings.Index(query, "?") + if k >= 0 { + for { + query = strings.Replace(query, "?", x+fmt.Sprintf("%v", i), 1) + i = i + 1 + k := strings.Index(query, "?") + if k < 0 { + return query + } + } + } + } + return query +} diff --git a/sql/privileges_loader.go b/sql/privileges_loader.go index 69b1a5f..d2e2a6b 100644 --- a/sql/privileges_loader.go +++ b/sql/privileges_loader.go @@ -3,9 +3,7 @@ package sql import ( "context" "database/sql" - "errors" "github.com/core-go/auth" - "reflect" "strings" ) @@ -61,33 +59,11 @@ func (l PrivilegesLoader) Load(ctx context.Context, id string) ([]auth.Privilege params = append(params, id) } } - driver := l.Driver - rows, er1 := l.DB.Query(l.Query, params...) + columns, er1 := query(ctx, l.DB, &models, l.Query, params...) if er1 != nil { return p0, er1 } - defer rows.Close() - columns, er2 := rows.Columns() hasPermission := hasPermissions(columns) - if er2 != nil { - return p0, er2 - } - // get list indexes column - modelTypes := reflect.TypeOf(models).Elem() - modelType := reflect.TypeOf(auth.Module{}) - indexes, er3 := getColumnIndexes(modelType, columns, driver) - if er3 != nil { - return p0, er3 - } - tb, er4 := scanType(rows, modelTypes, indexes) - if er4 != nil { - return p0, er4 - } - for _, v := range tb { - if c, ok := v.(*auth.Module); ok { - models = append(models, *c) - } - } if hasPermission && l.Or { models = auth.OrPermissions(models) } @@ -108,93 +84,3 @@ func hasPermissions(cols []string) bool { } return false } -func scanType(rows *sql.Rows, modelTypes reflect.Type, indexes []int) (t []interface{}, err error) { - for rows.Next() { - initArray := reflect.New(modelTypes).Interface() - if err = rows.Scan(structScan(initArray, indexes)...); err == nil { - t = append(t, initArray) - } - } - return -} -func structScan(s interface{}, indexColumns []int) (r []interface{}) { - if s != nil { - maps := reflect.Indirect(reflect.ValueOf(s)) - for _, index := range indexColumns { - r = append(r, maps.Field(index).Addr().Interface()) - } - } - return -} - -func getColumnIndex(modelType reflect.Type, columnsName string, driver string) (index int, err error) { - if modelType.Kind() != reflect.Struct { - return -1, errors.New("bad type") - } - for i := 0; i < modelType.NumField(); i++ { - field := modelType.Field(i) - ormTag := field.Tag.Get("gorm") - column, ok := findTag(ormTag, "column") - if driver == driverOracle { - column = strings.ToUpper(column) - } else { - column = strings.ToLower(column) - } - if ok { - if columnsName == column { - return i, nil - } - } - } - return -1, errors.New("col " + columnsName + "not found") -} - -func getColumnIndexes(modelType reflect.Type, columnsNames []string, driver string) (indexes []int, err error) { - if modelType.Kind() != reflect.Struct { - return nil, errors.New("bad type") - } - for i := 0; i < len(columnsNames); i++ { - index, err := getColumnIndex(modelType, columnsNames[i], driver) - if err != nil{ - return nil, err - } - indexes = append(indexes, index) - } - /*for i := 0; i < modelType.NumField(); i++ { - field := modelType.Field(i) - ormTag := field.Tag.Get("gorm") - column, ok := FindTag(ormTag, "column") - if driver == DriverOracle { - column = strings.ToUpper(column) - } - if ok { - if contains(columnsNames, column) { - indexes = append(indexes, i) - } - } - }*/ - return -} -func findTag(tag string, key string) (string, bool) { - if has := strings.Contains(tag, key); has { - str1 := strings.Split(tag, ";") - num := len(str1) - for i := 0; i < num; i++ { - str2 := strings.Split(str1[i], ":") - for j := 0; j < len(str2); j++ { - if str2[j] == key { - return str2[j+1], true - } - } - } - } - return "", false -} -func contains(array []string, v string) bool { - for _, s := range array { - if s == v { - return true - } - } - return false -} diff --git a/sql/privileges_reader.go b/sql/privileges_reader.go index 36ea553..0ad19c3 100644 --- a/sql/privileges_reader.go +++ b/sql/privileges_reader.go @@ -3,19 +3,7 @@ package sql import ( "context" "database/sql" - "fmt" "github.com/core-go/auth" - "reflect" - "strings" -) - -const ( - driverPostgres = "postgres" - driverMysql = "mysql" - driverMssql = "mssql" - driverOracle = "oracle" - driverSqlite3 = "sqlite3" - driverNotSupport = "no support" ) type PrivilegesReader struct { @@ -46,31 +34,10 @@ func NewPrivilegesReader(db *sql.DB, query string, options ...bool) *PrivilegesR func (l PrivilegesReader) Privileges(ctx context.Context) ([]auth.Privilege, error) { models := make([]auth.Module, 0) p0 := make([]auth.Privilege, 0) - rows, er1 := l.DB.QueryContext(ctx, l.Query) + _, er1 := query(ctx, l.DB, &models, l.Query) if er1 != nil { return p0, er1 } - defer rows.Close() - columns, er2 := rows.Columns() - if er2 != nil { - return p0, er2 - } - // get list indexes column - modelTypes := reflect.TypeOf(models).Elem() - modelType := reflect.TypeOf(auth.Module{}) - indexes, er3 := getColumnIndexes(modelType, columns, l.Driver) - if er3 != nil { - return p0, er3 - } - tb, er4 := scanType(rows, modelTypes, indexes) - if er4 != nil { - return p0, er4 - } - for _, v := range tb { - if c, ok := v.(*auth.Module); ok { - models = append(models, *c) - } - } var p []auth.Privilege if l.NoSequence == true { p = auth.ToPrivilegesWithNoSequence(models) @@ -79,49 +46,3 @@ func (l PrivilegesReader) Privileges(ctx context.Context) ([]auth.Privilege, err } return p, nil } -func getDriver(db *sql.DB) string { - if db == nil { - return driverNotSupport - } - driver := reflect.TypeOf(db.Driver()).String() - switch driver { - case "*pq.Driver": - return driverPostgres - case "*godror.drv": - return driverOracle - case "*mysql.MySQLDriver": - return driverMysql - case "*mssql.Driver": - return driverMssql - case "*sqlite3.SQLiteDriver": - return driverSqlite3 - default: - return driverNotSupport - } -} - -func replaceQueryArgs(driver string, query string) string { - if driver == driverOracle || driver == driverPostgres || driver == driverMssql { - var x string - if driver == driverOracle { - x = ":val" - } else if driver == driverPostgres { - x = "$" - } else if driver == driverMssql { - x = "@p" - } - i := 1 - k := strings.Index(query, "?") - if k >= 0 { - for { - query = strings.Replace(query, "?", x+fmt.Sprintf("%v", i), 1) - i = i + 1 - k := strings.Index(query, "?") - if k < 0 { - return query - } - } - } - } - return query -} diff --git a/sql/user_info_service.go b/sql/user_info_service.go index 709931a..eecae7f 100644 --- a/sql/user_info_service.go +++ b/sql/user_info_service.go @@ -4,7 +4,6 @@ import ( "context" "database/sql" a "github.com/core-go/auth" - "reflect" "time" ) @@ -47,43 +46,21 @@ func NewSqlUserInfoByConfig(db *sql.DB, c SqlConfig, options...bool) *UserInfoSe } func (l UserInfoService) GetUserInfo(ctx context.Context, auth a.AuthInfo) (*a.UserInfo, error) { models := make([]a.UserInfo, 0) - rows, er1 := l.DB.QueryContext(ctx, l.Query, auth.Username) - if er1 != nil { - switch er1 { - case sql.ErrNoRows: - return nil, nil - default: - return nil, er1 - } - } - defer rows.Close() - modelTypes := reflect.TypeOf(models).Elem() - modelType := reflect.TypeOf(a.UserInfo{}) - columns, er2 := rows.Columns() - if er2 != nil { - return nil, er2 - } - // get list indexes column - indexes, er3 := getColumnIndexes(modelType, columns, l.Driver) - if er3 != nil { - return nil, er3 + _, err := query(ctx, l.DB, &models, l.Query, auth.Username) + if err != nil { + return nil, err } - tb, er4 := scanType(rows, modelTypes, indexes) - if er4 != nil { - return nil, er4 - } - if len(tb) > 0 { - if c, ok := tb[0].(*a.UserInfo); ok { - if len(c.Status) > 0 { - if c.Status == l.SuspendedStatus { - c.Suspended = true - } - if c.Status == l.DisableStatus { - c.Disable = true - } + if len(models) > 0 { + c := models[0] + if len(c.Status) > 0 { + if c.Status == l.SuspendedStatus { + c.Suspended = true + } + if c.Status == l.DisableStatus { + c.Disable = true } - return c, nil } + return &c, nil } return nil, nil }