Skip to content

Commit

Permalink
Merge pull request #55 from chrisjohnson/command-line-flags
Browse files Browse the repository at this point in the history
Command line flags
  • Loading branch information
chrisjohnson authored Feb 20, 2020
2 parents bdc1f91 + 5e5c3b9 commit e963dff
Show file tree
Hide file tree
Showing 13 changed files with 116 additions and 63 deletions.
2 changes: 1 addition & 1 deletion authconfig/azureconfig.go
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ func Environment() *azure.Environment {
if err != nil {
// TODO: move to initialization of var
panic(fmt.Sprintf(
"invalid cloud name '%s' specified, cannot continue\n", cloudName))
"invalid cloud name '%s' specified, cannot continue", cloudName))
}
environment = &env
return environment
Expand Down
12 changes: 6 additions & 6 deletions authconfig/env.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,38 +26,38 @@ func ParseEnvironment() error {
var err error
useDeviceFlow, err = strconv.ParseBool(envy.Get("AZURE_USE_DEVICEFLOW", "0"))
if err != nil {
log.Printf("invalid value specified for AZURE_USE_DEVICEFLOW, disabling\n")
log.Printf("invalid value specified for AZURE_USE_DEVICEFLOW, disabling")
useDeviceFlow = false
}
keepResources, err = strconv.ParseBool(envy.Get("AZURE_SAMPLES_KEEP_RESOURCES", "0"))
if err != nil {
log.Printf("invalid value specified for AZURE_SAMPLES_KEEP_RESOURCES, discarding\n")
log.Printf("invalid value specified for AZURE_SAMPLES_KEEP_RESOURCES, discarding")
keepResources = false
}

// these must be provided by environment
// clientID
clientID, err = envy.MustGet("AZURE_CLIENT_ID")
if err != nil {
return fmt.Errorf("expected env vars not provided: %s\n", err)
return fmt.Errorf("expected env vars not provided: %s", err)
}

// clientSecret
clientSecret, err = envy.MustGet("AZURE_CLIENT_SECRET")
if err != nil && useDeviceFlow != true { // don't need a secret for device flow
return fmt.Errorf("expected env vars not provided: %s\n", err)
return fmt.Errorf("expected env vars not provided: %s", err)
}

// tenantID (AAD)
tenantID, err = envy.MustGet("AZURE_TENANT_ID")
if err != nil {
return fmt.Errorf("expected env vars not provided: %s\n", err)
return fmt.Errorf("expected env vars not provided: %s", err)
}

// subscriptionID (ARM)
subscriptionID, err = envy.MustGet("AZURE_SUBSCRIPTION_ID")
if err != nil {
return fmt.Errorf("expected env vars not provided: %s\n", err)
return fmt.Errorf("expected env vars not provided: %s", err)
}

return nil
Expand Down
14 changes: 7 additions & 7 deletions authconfig/flags.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,14 @@ import (

// AddFlags adds flags applicable to all services.
// Remember to call `flag.Parse()` in your main or TestMain.
func AddFlags() error {
flag.StringVar(&subscriptionID, "subscription", subscriptionID, "Subscription for tests.")
flag.StringVar(&locationDefault, "location", locationDefault, "Default location for tests.")
flag.StringVar(&cloudName, "cloud", cloudName, "Name of Azure cloud.")
flag.StringVar(&baseGroupName, "baseGroupName", BaseGroupName(), "Specify prefix name of resource group for sample resources.")
func AddFlags(fs flag.FlagSet) error {
fs.StringVar(&subscriptionID, "subscription", subscriptionID, "Subscription for tests.")
fs.StringVar(&locationDefault, "location", locationDefault, "Default location for tests.")
fs.StringVar(&cloudName, "cloud", cloudName, "Name of Azure cloud.")
fs.StringVar(&baseGroupName, "baseGroupName", BaseGroupName(), "Specify prefix name of resource group for sample resources.")

flag.BoolVar(&useDeviceFlow, "useDeviceFlow", useDeviceFlow, "Use device-flow grant type rather than client credentials.")
flag.BoolVar(&keepResources, "keepResources", keepResources, "Keep resources created by samples.")
fs.BoolVar(&useDeviceFlow, "useDeviceFlow", useDeviceFlow, "Use device-flow grant type rather than client credentials.")
fs.BoolVar(&keepResources, "keepResources", keepResources, "Keep resources created by samples.")

return nil
}
12 changes: 6 additions & 6 deletions certs/certs.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ func newClient() keyvault.BaseClient {
client := keyvault.New()
a, err := iam.GetKeyvaultAuthorizer()
if err != nil {
panic(fmt.Sprintf("Error authorizing: %v\n", err.Error()))
panic(fmt.Sprintf("Error authorizing: %v", err.Error()))
}
client.Authorizer = a
client.AddToUserAgent(authconfig.UserAgent())
Expand All @@ -34,7 +34,7 @@ func newClient() keyvault.BaseClient {
func GetCert(vaultBaseURL string, certName string, certVersion string) (Cert, error) {
cert, err := newClient().GetCertificate(context.Background(), vaultBaseURL, certName, certVersion)
if err != nil {
log.Printf("Error getting cert: %v\n", err.Error())
log.Printf("Error getting cert: %v", err.Error())
return Cert{}, err
}

Expand All @@ -44,7 +44,7 @@ func GetCert(vaultBaseURL string, certName string, certVersion string) (Cert, er
func GetCertByURL(certURL string) (Cert, error) {
u, err := url.Parse(certURL)
if err != nil {
log.Printf("Failed to parse URL for cert: %v\n", err.Error())
log.Printf("Failed to parse URL for cert: %v", err.Error())
return Cert{}, err
}
vaultBaseURL := fmt.Sprintf("%v://%v", u.Scheme, u.Host)
Expand All @@ -55,7 +55,7 @@ func GetCertByURL(certURL string) (Cert, error) {

result, err := GetCert(vaultBaseURL, certName, "")
if err != nil {
log.Printf("Failed to get cert from parsed values %v and %v: %v\n", vaultBaseURL, certName, err.Error())
log.Printf("Failed to get cert from parsed values %v and %v: %v", vaultBaseURL, certName, err.Error())
return Cert{}, err
}

Expand All @@ -66,7 +66,7 @@ func GetCerts(vaultBaseURL string) (results []Cert, err error) {
max := int32(25)
pages, err := newClient().GetCertificates(context.Background(), vaultBaseURL, &max)
if err != nil {
log.Printf("Error getting cert: %v\n", err.Error())
log.Printf("Error getting cert: %v", err.Error())
return nil, err
}

Expand All @@ -75,7 +75,7 @@ func GetCerts(vaultBaseURL string) (results []Cert, err error) {
certURL := *value.ID
cert, err := GetCertByURL(certURL)
if err != nil {
log.Printf("Error loading cert contents: %v\n", err.Error())
log.Printf("Error loading cert contents: %v", err.Error())
return nil, err
}

Expand Down
4 changes: 2 additions & 2 deletions configparser/configparser.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@ package configparser

import (
"fmt"
"io/ioutil"
log "github.com/sirupsen/logrus"
"io/ioutil"
"strconv"
"time"

Expand Down Expand Up @@ -42,7 +42,7 @@ func validateWorkerConfigs(workerConfigs []config.WorkerConfig) {
for i, workerConfig := range workerConfigs {
err := validate.Struct(workerConfig)
if err != nil {
panic(fmt.Sprintf("Error parsing worker config: %v\n", err))
panic(fmt.Sprintf("Error parsing worker config: %v", err))
}

// Convert human readable time and save into TimeFrequency
Expand Down
8 changes: 4 additions & 4 deletions configwatcher/configwatcher.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ import (
func Watcher(path string) {
watcher, err := fsnotify.NewWatcher()
if err != nil {
panic(fmt.Sprintf("Error establishing file watcher: %v\n", err))
panic(fmt.Sprintf("Error establishing file watcher: %v", err))
}

// If something goes wrong along the way, close the watcher
Expand All @@ -30,7 +30,7 @@ func Watcher(path string) {

err = watcher.Add(filepath.Dir(path))
if err != nil {
panic(fmt.Sprintf("Error watching path %v: %v\n", path, err))
panic(fmt.Sprintf("Error watching path %v: %v", path, err))
}
<-done // Block until done
}
Expand Down Expand Up @@ -62,7 +62,7 @@ func doWatch(watcher *fsnotify.Watcher, cancel context.CancelFunc, path string)
}

if (event.Op&fsnotify.Write == fsnotify.Write || event.Op&fsnotify.Create == fsnotify.Create) && event.Name == path {
log.Printf("Config watcher noticed a change to %v\n", event.Name)
log.Printf("Config watcher noticed a change to %v", event.Name)
// Kill workers
cancel()
// Start new workers
Expand All @@ -72,7 +72,7 @@ func doWatch(watcher *fsnotify.Watcher, cancel context.CancelFunc, path string)
if !ok {
continue
}
log.Printf("Config watcher encountered an error for %v: %v\n", path, err)
log.Printf("Config watcher encountered an error for %v: %v", path, err)
return
}
}
Expand Down
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ require (
github.com/huandu/xstrings v1.3.0 // indirect
github.com/imdario/mergo v0.3.8 // indirect
github.com/jpillora/backoff v1.0.0
github.com/luci/luci-go v0.0.0-20200220034857-6a27eb3e318d
github.com/marstr/randname v0.0.0-20181206212954-d5b0f288ab8c
github.com/mitchellh/copystructure v1.0.0 // indirect
github.com/sirupsen/logrus v1.4.2
Expand Down
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,8 @@ github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE=
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
github.com/leodido/go-urn v1.2.0 h1:hpXL4XnriNwQ/ABnpepYM/1vCLWNDfUNts8dX3xTG6Y=
github.com/leodido/go-urn v1.2.0/go.mod h1:+8+nEpDfqqsY+g338gtMEUOtuK+4dEMhiQEgxpxOKII=
github.com/luci/luci-go v0.0.0-20200220034857-6a27eb3e318d h1:5ODLMo1u+AXIjn88uA577TQ+ODf1/uyAZVI+WLMtgLM=
github.com/luci/luci-go v0.0.0-20200220034857-6a27eb3e318d/go.mod h1:k3ApF9lUbij/xWd3mIeo0MlMVJ4YBamc9AkaqV2lcnI=
github.com/marstr/collection v1.0.1 h1:j61osRfyny7zxBlLRtoCvOZ2VX7HEyybkZcsLNLJ0z0=
github.com/marstr/collection v1.0.1/go.mod h1:HHDXVxjLO3UYCBXJWY+J/ZrxCUOYqrO66ob1AzIsmYA=
github.com/marstr/randname v0.0.0-20181206212954-d5b0f288ab8c h1:JE+MDz5rhFN5EC9Dj/N8dLYKboTWm6FXeWhnyKVj0vA=
Expand Down
12 changes: 6 additions & 6 deletions keys/keys.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ func newClient() keyvault.BaseClient {
client := keyvault.New()
a, err := iam.GetKeyvaultAuthorizer()
if err != nil {
panic(fmt.Sprintf("Error authorizing: %v\n", err.Error()))
panic(fmt.Sprintf("Error authorizing: %v", err.Error()))
}
client.Authorizer = a
client.AddToUserAgent(authconfig.UserAgent())
Expand All @@ -50,7 +50,7 @@ func newClient() keyvault.BaseClient {
func GetKey(vaultBaseURL string, keyName string, keyVersion string) (Key, error) {
key, err := newClient().GetKey(context.Background(), vaultBaseURL, keyName, keyVersion)
if err != nil {
log.Printf("Error getting key: %v\n", err.Error())
log.Printf("Error getting key: %v", err.Error())
return Key{}, err
}

Expand All @@ -62,7 +62,7 @@ func GetKey(vaultBaseURL string, keyName string, keyVersion string) (Key, error)
func GetKeyByURL(keyURL string) (Key, error) {
u, err := url.Parse(keyURL)
if err != nil {
log.Printf("Failed to parse URL for key: %v\n", err.Error())
log.Printf("Failed to parse URL for key: %v", err.Error())
return Key{}, err
}
vaultBaseURL := fmt.Sprintf("%v://%v", u.Scheme, u.Host)
Expand All @@ -73,7 +73,7 @@ func GetKeyByURL(keyURL string) (Key, error) {

result, err := GetKey(vaultBaseURL, keyName, "")
if err != nil {
log.Printf("Failed to get key from parsed values %v and %v: %v\n", vaultBaseURL, keyName, err.Error())
log.Printf("Failed to get key from parsed values %v and %v: %v", vaultBaseURL, keyName, err.Error())
return Key{}, err
}

Expand All @@ -84,7 +84,7 @@ func GetKeys(vaultBaseURL string) (results []Key, err error) {
max := int32(25)
pages, err := newClient().GetKeys(context.Background(), vaultBaseURL, &max)
if err != nil {
log.Printf("Error getting key: %v\n", err.Error())
log.Printf("Error getting key: %v", err.Error())
return nil, err
}

Expand All @@ -93,7 +93,7 @@ func GetKeys(vaultBaseURL string) (results []Key, err error) {
keyURL := *value.Kid
key, err := GetKeyByURL(keyURL)
if err != nil {
log.Printf("Error loading key contents: %v\n", err.Error())
log.Printf("Error loading key contents: %v", err.Error())
return nil, err
}

Expand Down
75 changes: 63 additions & 12 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,30 +2,82 @@ package main

import (
"flag"
"fmt"
"github.com/chrisjohnson/azure-key-vault-agent/authconfig"
"github.com/chrisjohnson/azure-key-vault-agent/configwatcher"
"github.com/luci/luci-go/common/flag/flagenum"
log "github.com/sirupsen/logrus"
"os"
)

type outputType uint

var outputTypeEnum = flagenum.Enum{
"json": outputType(10),
"text": outputType(20),
}

func (val *outputType) Set(v string) error {
return outputTypeEnum.FlagSet(val, v)
}

func (val *outputType) String() string {
return outputTypeEnum.FlagString(*val)
}

func (val outputType) MarshalJSON() ([]byte, error) {
return outputTypeEnum.JSONMarshal(val)
}

var configFile string
var output outputType
var help bool

func init() {
// JSON Format customized to use _timestamp so it marshals first alphabetically
log.SetFormatter(&log.JSONFormatter{
FieldMap: log.FieldMap{
log.FieldKeyTime: "_timestamp",
},
})
fs := flag.NewFlagSet(os.Args[0], flag.ExitOnError)
fs.SetOutput(os.Stdout)

fs.BoolVar(&help, "help", false, "Show this help text")
fs.StringVar(&configFile, "config", "", "Read config from this `file`")
fs.StringVar(&configFile, "c", "", "Read config from this `file` (shorthand)")
fs.Var(&output, "output", fmt.Sprintf("Output type (default json). Options are: %v (default json)", outputTypeEnum.Choices()))

fs.Parse(os.Args[1:])

if help {
fs.PrintDefaults()
os.Exit(0)
}

if configFile == "" {
log.Fatalf("Missing --config/-c")
}

if output != outputTypeEnum["text"] {
// JSON Format customized to use _timestamp so it marshals first alphabetically
log.SetFormatter(&log.JSONFormatter{
FieldMap: log.FieldMap{
log.FieldKeyTime: "_timestamp",
},
})
} else {
log.SetFormatter(&log.TextFormatter{
FullTimestamp: true,
})
}

var err error
err = authconfig.ParseEnvironment()
if err != nil {
log.Fatalf("failed to parse env: %v\n", err.Error())
log.Fatalf("AuthConfig: Failed to parse env: %v", err.Error())
}

err = authconfig.AddFlags()
err = authconfig.AddFlags(*fs)
if err != nil {
log.Fatalf("failed to parse flags: %v\n", err.Error())
log.Fatalf("AuthConfig: Failed to add flags: %v", err.Error())
}
flag.Parse()

fs.Parse(os.Args[1:])
}

func main() {
Expand All @@ -35,6 +87,5 @@ func main() {
}
}()

configwatcher.Watcher("akva.yaml")
configwatcher.Watcher(configFile)
}

Loading

0 comments on commit e963dff

Please sign in to comment.