Skip to content

Commit

Permalink
refactor: simplify merging secondary configs
Browse files Browse the repository at this point in the history
  • Loading branch information
mrexox committed Nov 4, 2024
1 parent 723d1c6 commit 144e436
Showing 1 changed file with 43 additions and 51 deletions.
94 changes: 43 additions & 51 deletions internal/config/load.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,12 +38,30 @@ func (err NotFoundError) Error() string {

// Loads configs from the given directory with extensions.
func Load(fs afero.Fs, repo *git.Repository) (*Config, error) {
global, err := readOne(fs, repo.RootPath, []string{"lefthook", ".lefthook"})
main, err := readOne(fs, repo.RootPath, []string{"lefthook", ".lefthook"})
if err != nil {
return nil, err
}

extends, err := mergeAll(fs, repo)
extends := main.GetStringSlice("extends")
var remote *Remote
var remotes []*Remote
err = main.UnmarshalKey("remotes", &remotes)
if err != nil {
return nil, err
}
// Deprecated
err = main.UnmarshalKey("remote", &remote)
if err != nil {
return nil, err
}

// Backward compatibility
if remote != nil {
remotes = append(remotes, remote)
}

secondary, err := readSecondary(fs, repo, extends, remotes)
if err != nil {
return nil, err
}
Expand All @@ -53,7 +71,7 @@ func Load(fs afero.Fs, repo *git.Repository) (*Config, error) {
config.SourceDir = DefaultSourceDir
config.SourceDirLocal = DefaultSourceDirLocal

err = unmarshalConfigs(global, extends, &config)
err = unmarshalConfigs(main, secondary, &config)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -104,34 +122,26 @@ func readOne(fs afero.Fs, path string, names []string) (*viper.Viper, error) {
return nil, NotFoundError{fmt.Sprintf("No config files with names %q have been found in \"%s\"", names, path)}
}

// mergeAll merges configs using the following order.
// - lefthook/.lefthook
// readSecondary reads extends, remotes and local config.
// - files from `extends`
// - files from `remotes`
// - lefthook-local/.lefthook-local.
func mergeAll(fs afero.Fs, repo *git.Repository) (*viper.Viper, error) {
extends, err := readOne(fs, repo.RootPath, []string{"lefthook", ".lefthook"})
if err != nil {
func readSecondary(fs afero.Fs, repo *git.Repository, extends []string, remotes []*Remote) (*viper.Viper, error) {
secondary := newViper(fs, repo.RootPath)
if err := extend(fs, repo.RootPath, secondary, extends); err != nil {
return nil, err
}

if err := extend(fs, extends, repo.RootPath); err != nil {
return nil, err
}

// Save global extends to compare them after merging local config
globalExtends := extends.GetStringSlice("extends")

if err := mergeRemotes(fs, repo, extends); err != nil {
if err := mergeRemotes(fs, repo, secondary, remotes); err != nil {
return nil, err
}

//nolint:nestif
if err := mergeLocal(extends); err == nil {
if err := mergeLocal(secondary); err == nil {
// Local extends need to be re-applied only if they have different settings
localExtends := extends.GetStringSlice("extends")
if !slices.Equal(globalExtends, localExtends) {
if err = extend(fs, extends, repo.RootPath); err != nil {
localExtends := secondary.GetStringSlice("extends")
if !slices.Equal(extends, localExtends) {
if err = extend(fs, repo.RootPath, secondary, localExtends); err != nil {
return nil, err
}
}
Expand All @@ -142,30 +152,11 @@ func mergeAll(fs afero.Fs, repo *git.Repository) (*viper.Viper, error) {
}
}

return extends, nil
return secondary, nil
}

// mergeRemotes merges remote configs to the current one.
func mergeRemotes(fs afero.Fs, repo *git.Repository, v *viper.Viper) error {
var remote *Remote // Deprecated
var remotes []*Remote

err := v.UnmarshalKey("remotes", &remotes)
if err != nil {
return err
}

// Deprecated
err = v.UnmarshalKey("remote", &remote)
if err != nil {
return err
}

// Backward compatibility
if remote != nil {
remotes = append(remotes, remote)
}

func mergeRemotes(fs afero.Fs, repo *git.Repository, v *viper.Viper, remotes []*Remote) error {
for _, remote := range remotes {
if !remote.Configured() {
continue
Expand All @@ -187,7 +178,7 @@ func mergeRemotes(fs afero.Fs, repo *git.Repository, v *viper.Viper) error {

log.Debugf("Merging remote config: %s: %s", remote.GitURL, configPath)

_, err = fs.Stat(configPath)
_, err := fs.Stat(configPath)
if err != nil {
continue
}
Expand All @@ -196,13 +187,14 @@ func mergeRemotes(fs afero.Fs, repo *git.Repository, v *viper.Viper) error {
return err
}

if err = extend(fs, v, filepath.Dir(configPath)); err != nil {
extends := v.GetStringSlice("extends")
if err = extend(fs, filepath.Dir(configPath), v, extends); err != nil {
return err
}
}

// Reset extends to omit issues when extending with remote extends.
err = v.MergeConfigMap(map[string]interface{}{"extends": nil})
err := v.MergeConfigMap(map[string]interface{}{"extends": nil})
if err != nil {
return err
}
Expand All @@ -212,14 +204,14 @@ func mergeRemotes(fs afero.Fs, repo *git.Repository, v *viper.Viper) error {
}

// extend merges all files listed in 'extends' option into the config.
func extend(fs afero.Fs, v *viper.Viper, root string) error {
return extendRecursive(fs, v, root, make(map[string]struct{}))
func extend(fs afero.Fs, root string, v *viper.Viper, extends []string) error {
return extendRecursive(fs, root, v, extends, make(map[string]struct{}))
}

// extendRecursive merges extends.
// If extends contain other extends they get merged too.
func extendRecursive(fs afero.Fs, v *viper.Viper, root string, extends map[string]struct{}) error {
for _, pathOrGlob := range v.GetStringSlice("extends") {
func extendRecursive(fs afero.Fs, root string, v *viper.Viper, extends []string, visited map[string]struct{}) error {
for _, pathOrGlob := range extends {
if !filepath.IsAbs(pathOrGlob) {
pathOrGlob = filepath.Join(root, pathOrGlob)
}
Expand All @@ -230,18 +222,18 @@ func extendRecursive(fs afero.Fs, v *viper.Viper, root string, extends map[strin
}

for _, path := range paths {
if _, contains := extends[path]; contains {
if _, contains := visited[path]; contains {
return fmt.Errorf("possible recursion in extends: path %s is specified multiple times", path)
}
extends[path] = struct{}{}
visited[path] = struct{}{}

extendV := newViper(fs, root)
extendV.SetConfigFile(path)
if err := extendV.ReadInConfig(); err != nil {
return err
}

if err := extendRecursive(fs, extendV, root, extends); err != nil {
if err := extendRecursive(fs, root, extendV, extendV.GetStringSlice("extends"), visited); err != nil {
return err
}

Expand Down

0 comments on commit 144e436

Please sign in to comment.