-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #1 from DeBankDeFi/update_pkg
feat: add more utilities package for common usage
- Loading branch information
Showing
11 changed files
with
747 additions
and
12 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,69 @@ | ||
package cmdhelper | ||
|
||
import ( | ||
"fmt" | ||
"os" | ||
"reflect" | ||
"strings" | ||
|
||
"github.com/spf13/cobra" | ||
) | ||
|
||
type EnvFlags map[string]EnvFlag | ||
|
||
type EnvFlag struct { | ||
Env string | ||
Split string | ||
} | ||
|
||
func updateEnvFlags(obj interface{}, envFlags EnvFlags) { | ||
if envFlags == nil { | ||
return | ||
} | ||
|
||
t := reflect.TypeOf(obj) | ||
v := reflect.ValueOf(obj) | ||
for i := 0; i < t.NumField(); i++ { | ||
field := t.Field(i) | ||
if field.Type.Kind() == reflect.Struct { | ||
updateEnvFlags(v.Field(i).Interface(), envFlags) | ||
} | ||
|
||
tag := field.Tag | ||
name := tag.Get("name") | ||
env := tag.Get("env") | ||
split := tag.Get("split") | ||
|
||
if env != "" && name != "" { | ||
envFlags[name] = EnvFlag{ | ||
Env: env, | ||
Split: split, | ||
} | ||
} | ||
} | ||
} | ||
|
||
func ResolveEnvVariable(cmd *cobra.Command, f interface{}) { | ||
envFlags := make(map[string]EnvFlag) | ||
updateEnvFlags(f, envFlags) | ||
|
||
for flagName, env := range envFlags { | ||
flag := cmd.Flag(flagName) | ||
if flag == nil { | ||
continue | ||
} | ||
|
||
flag.Usage = fmt.Sprintf("%v [env %v]", flag.Usage, env.Env) | ||
if value := os.Getenv(env.Env); value != "" { | ||
if env.Split != "" { | ||
strings.Split(value, env.Split) | ||
|
||
for _, v := range strings.Split(value, env.Split) { | ||
flag.Value.Set(v) | ||
} | ||
} else { | ||
flag.Value.Set(value) | ||
} | ||
} | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,188 @@ | ||
package cmdhelper | ||
|
||
import ( | ||
"fmt" | ||
"os" | ||
"reflect" | ||
"regexp" | ||
"strconv" | ||
"strings" | ||
|
||
"github.com/spf13/cobra" | ||
) | ||
|
||
type Flags []Flag | ||
|
||
type Flag struct { | ||
// e.g. foo | ||
Name string | ||
// With struct hierarchy prefix | ||
// e.g. prefix-foo | ||
FullName string | ||
|
||
// Enable env | ||
EnableEnv bool | ||
// With struct hierarchy prefix | ||
// e.g. PREFIX_FOO | ||
FullEnv string | ||
|
||
Split string | ||
Shorthand string | ||
Usage string | ||
|
||
Type string | ||
Value interface{} | ||
Pointer interface{} | ||
} | ||
|
||
func resolveFieldName(field reflect.StructField) string { | ||
if name := field.Tag.Get("name"); name != "" { | ||
return name | ||
} | ||
|
||
// Use field name | ||
return toSnake(field.Name) | ||
} | ||
|
||
var matchFirstCap = regexp.MustCompile("(.)([A-Z][a-z]+)") | ||
var matchAllCap = regexp.MustCompile("([a-z0-9])([A-Z])") | ||
|
||
func toSnake(name string) string { | ||
snake := matchFirstCap.ReplaceAllString(name, "${1}-${2}") | ||
snake = matchAllCap.ReplaceAllString(snake, "${1}-${2}") | ||
return strings.ToLower(snake) | ||
} | ||
|
||
func resolveFlags(obj interface{}, flags Flags, namePrefix string) Flags { | ||
t := reflect.TypeOf(obj) | ||
if t.Kind() == reflect.Ptr && t.Elem().Kind() == reflect.Struct { | ||
t = t.Elem() | ||
} | ||
|
||
v := reflect.ValueOf(obj) | ||
if v.Kind() == reflect.Ptr && v.Elem().Kind() == reflect.Struct { | ||
v = v.Elem() | ||
} | ||
|
||
for i := 0; i < t.NumField(); i++ { | ||
field := t.Field(i) | ||
if field.Type.Kind() == reflect.Struct { | ||
name := resolveFieldName(field) | ||
flags = resolveFlags(v.Field(i).Addr().Interface(), flags, genFullName(namePrefix, name)) | ||
} | ||
|
||
tag := field.Tag | ||
name := resolveFieldName(field) | ||
env := tag.Get("enable-env") | ||
flagType := tag.Get("type") | ||
|
||
if name != "" && flagType != "" { | ||
flag := Flag{ | ||
Name: name, | ||
FullName: genFullName(namePrefix, name), | ||
Split: tag.Get("split"), | ||
Shorthand: tag.Get("shorthand"), | ||
Usage: tag.Get("usage"), | ||
Type: flagType, | ||
Value: v.Field(i).Interface(), | ||
Pointer: v.Field(i).Addr().Interface(), | ||
} | ||
|
||
// Enable env | ||
enableEnv, _ := strconv.ParseBool(env) | ||
if enableEnv { | ||
flag.EnableEnv = true | ||
flag.FullEnv = genEnv(flag.FullName) | ||
} | ||
|
||
flags = append(flags, flag) | ||
} | ||
} | ||
return flags | ||
} | ||
|
||
func genFullName(namePrefix, name string) string { | ||
var fullName string | ||
if namePrefix != "" { | ||
fullName = namePrefix + "-" + name | ||
} else { | ||
fullName = name | ||
} | ||
return fullName | ||
} | ||
|
||
// genEnv replace "-" to "_", and upper all character | ||
func genEnv(name string) string { | ||
return strings.ToUpper(strings.ReplaceAll(name, "-", "_")) | ||
} | ||
|
||
// ResolveFlagVariable register flags and env | ||
func ResolveFlagVariable(cmd *cobra.Command, f interface{}) { | ||
t := reflect.TypeOf(f) | ||
if t.Kind() != reflect.Ptr { | ||
panic("flag variable require pointer type") | ||
} | ||
|
||
var flags Flags | ||
flags = resolveFlags(f, flags, "") | ||
|
||
// Check full name conflict | ||
set := make(map[string]struct{}) | ||
for _, v := range flags { | ||
if _, ok := set[v.FullName]; ok { | ||
panic(fmt.Sprintf("flag full name conflict, %s", v.FullName)) | ||
} else { | ||
set[v.FullName] = struct{}{} | ||
} | ||
} | ||
|
||
// Register flags to cobra | ||
for _, v := range flags { | ||
switch v.Type { | ||
case "bool": | ||
cmd.PersistentFlags().BoolVarP(v.Pointer.(*bool), v.FullName, v.Shorthand, v.Value.(bool), v.Usage) | ||
case "string": | ||
cmd.PersistentFlags().StringVarP(v.Pointer.(*string), v.FullName, v.Shorthand, v.Value.(string), v.Usage) | ||
case "int": | ||
cmd.PersistentFlags().IntVarP(v.Pointer.(*int), v.FullName, v.Shorthand, v.Value.(int), v.Usage) | ||
case "string-slice": | ||
cmd.PersistentFlags().StringSliceVarP(v.Pointer.(*[]string), v.FullName, v.Shorthand, v.Value.([]string), v.Usage) | ||
case "int-slice": | ||
cmd.PersistentFlags().IntSliceVarP(v.Pointer.(*[]int), v.FullName, v.Shorthand, v.Value.([]int), v.Usage) | ||
case "string-to-string": | ||
cmd.PersistentFlags().StringToStringVarP(v.Pointer.(*map[string]string), v.FullName, v.Shorthand, v.Value.(map[string]string), v.Usage) | ||
default: | ||
panic(fmt.Sprintf("not supported flag type: %s", v.Type)) | ||
} | ||
} | ||
|
||
// Register env | ||
for _, v := range flags { | ||
if !v.EnableEnv { | ||
continue | ||
} | ||
|
||
flag := cmd.Flag(v.FullName) | ||
if flag == nil { | ||
continue | ||
} | ||
|
||
if flag.Usage == "" { | ||
flag.Usage = fmt.Sprintf("[env %v]", v.FullEnv) | ||
} else { | ||
flag.Usage = fmt.Sprintf("%v [env %v]", flag.Usage, v.FullEnv) | ||
} | ||
if value := os.Getenv(v.FullEnv); value != "" { | ||
// stringArray 和 stringSlice 可以同时从 env 和 args 里添加 | ||
if v.Split != "" { | ||
strings.Split(value, v.Split) | ||
|
||
for _, v := range strings.Split(value, v.Split) { | ||
flag.Value.Set(v) | ||
} | ||
} else { | ||
flag.Value.Set(value) | ||
} | ||
} | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,76 @@ | ||
package cmdhelper | ||
|
||
import ( | ||
"testing" | ||
|
||
"github.com/stretchr/testify/assert" | ||
) | ||
|
||
type testFlag struct { | ||
Name string `name:"name" type:"string" shorthand:"n" usage:"usage"` | ||
Array []string `name:"array" type:"string-slice" shorthand:"a" split:"," usage:"usage array"` | ||
IntArray []int `name:"int-array" type:"int-slice" shorthand:"a" usage:"usage array"` | ||
Flag2 testFlag2 `name:"prefix"` | ||
Flag3 testFlag3 | ||
} | ||
|
||
type testFlag2 struct { | ||
Foo string `name:"foo" type:"string" enable-env:"true"` | ||
Bar string `name:"bar" type:"string" enable-env:"false"` | ||
Bar2 string `name:"bar2" type:"string"` | ||
} | ||
|
||
type testFlag3 struct { | ||
Bar3 string `type:"string"` | ||
} | ||
|
||
func Test_resolveFlags(t *testing.T) { | ||
f := testFlag{ | ||
Name: "name", | ||
Array: []string{"a"}, | ||
IntArray: []int{80}, | ||
} | ||
|
||
var result Flags | ||
result = resolveFlags(&f, result, "") | ||
assert.Equal(t, Flags{ | ||
{Name: "name", FullName: "name", Shorthand: "n", Usage: "usage", Type: "string", Value: "name", Pointer: &f.Name}, | ||
{Name: "array", FullName: "array", Split: ",", Shorthand: "a", Usage: "usage array", Type: "string-slice", Value: []string{"a"}, Pointer: &f.Array}, | ||
{Name: "int-array", FullName: "int-array", Shorthand: "a", Usage: "usage array", Type: "int-slice", Value: []int{80}, Pointer: &f.IntArray}, | ||
{Name: "foo", FullName: "prefix-foo", Type: "string", EnableEnv: true, FullEnv: "PREFIX_FOO", Value: "", Pointer: &f.Flag2.Foo}, | ||
{Name: "bar", FullName: "prefix-bar", Type: "string", EnableEnv: false, Value: "", Pointer: &f.Flag2.Bar}, | ||
{Name: "bar2", FullName: "prefix-bar2", Type: "string", Value: "", Pointer: &f.Flag2.Bar2}, | ||
{Name: "bar3", FullName: "flag3-bar3", Type: "string", Value: "", Pointer: &f.Flag3.Bar3}, | ||
}, result) | ||
} | ||
|
||
func Test_toSnake(t *testing.T) { | ||
testCases := []struct { | ||
name string | ||
expected string | ||
}{ | ||
{ | ||
name: "InfluxDB", | ||
expected: "influx-db", | ||
}, | ||
{ | ||
name: "InfluxDBV2", | ||
expected: "influx-dbv2", | ||
}, | ||
{ | ||
name: "fooBarVer", | ||
expected: "foo-bar-ver", | ||
}, | ||
{ | ||
name: "EtcdEndpoints", | ||
expected: "etcd-endpoints", | ||
}, | ||
} | ||
|
||
for _, tc := range testCases { | ||
t.Run(tc.name, func(t *testing.T) { | ||
result := toSnake(tc.name) | ||
assert.Equal(t, tc.expected, result) | ||
}) | ||
} | ||
} |
Oops, something went wrong.