package options import ( "errors" "fmt" "os" "reflect" "regexp" "strings" "github.com/a8m/envsubst" "github.com/go-viper/mapstructure/v2" "github.com/spf13/pflag" "github.com/spf13/viper" "gopkg.in/yaml.v3" ) // Load reads in the config file at the path given, then merges in environment // variables (prefixed with `OAUTH2_PROXY`) and finally merges in flags from the flagSet. // If a config value is unset and the flag has a non-zero value default, this default will be used. // Eg. A field defined: // // FooBar `cfg:"foo_bar" flag:"foo-bar"` // // Can be set in the config file as `foo_bar="baz"`, in the environment as `OAUTH2_PROXY_FOO_BAR=baz`, // or via the command line flag `--foo-bar=baz`. func Load(configFileName string, flagSet *pflag.FlagSet, into interface{}) error { v := viper.New() v.SetConfigFile(configFileName) v.SetConfigType("toml") // Config is in toml format v.SetEnvPrefix("OAUTH2_PROXY") v.AutomaticEnv() v.SetTypeByDefaultValue(true) if configFileName != "" { err := v.ReadInConfig() if err != nil { return fmt.Errorf("unable to load config file: %w", err) } } err := registerFlags(v, "", flagSet, into) if err != nil { // This should only happen if there is a programming error return fmt.Errorf("unable to register flags: %w", err) } // UnmarshalExact will return an error if the config includes options that are // not mapped to fields of the into struct err = v.UnmarshalExact(into, decodeFromCfgTag) if err != nil { return fmt.Errorf("error unmarshalling config: %w", err) } return nil } // LoadYAML will load a YAML based configuration file into the options interface provided. func LoadYAML(configFileName string, opts interface{}) error { buffer, err := loadAndSubstituteEnvs(configFileName) if err != nil { return err } // Generic interface for loading arbitrary yaml structure var intermediate map[string]interface{} if err := yaml.Unmarshal(buffer, &intermediate); err != nil { return fmt.Errorf("error unmarshalling config: %w", err) } // Using mapstructure to decode arbitrary yaml structure into options and // merge with existing values instead of overwriting everything. This is especially // important as we have a lot of default values for boolean which are supposed to be // true by default. Normally by just parsing through yaml all booleans that aren't // referenced in the config file would be parsed as false and we cannot identify after // the fact if they have been explicitly set to false or have not been referenced. return Decode(intermediate, opts) } // Decode processes an input map and decodes it into a given struct while preserving default values. // It ensures proper conversion of duration values from strings, floats, and int64 into time.Duration. // // Parameters: // - input: A map[string]interface{} representing the input data. // - result: A pointer to a struct where the decoded values will be stored. // // Returns: // - An error if decoding fails or if there are unmapped keys. func Decode(input interface{}, result interface{}) error { decoder, err := mapstructure.NewDecoder(&mapstructure.DecoderConfig{ DecodeHook: mapstructure.ComposeDecodeHookFunc( toDurationHookFunc(), stringToBytesHookFunc(), ), Metadata: nil, // Don't track any metadata Result: result, // Decode the result into the prefilled options TagName: "yaml", // Parse all fields that use the json tag ZeroFields: false, // Don't clean the default values from the result map (options) ErrorUnused: true, // Throw an error if keys have been used that aren't mapped to any struct fields IgnoreUntaggedFields: true, // Ignore fields in structures that aren't tagged with json }) if err != nil { return fmt.Errorf("error creating decoder for config: %w", err) } if err := decoder.Decode(input); err != nil { return fmt.Errorf("error decoding config: %w", err) } return nil } // loadAndSubstituteEnvs reads the yaml config into a generic byte buffer and // substitute env references func loadAndSubstituteEnvs(configFileName string) ([]byte, error) { if configFileName == "" { return nil, errors.New("no configuration file provided") } unparsedBuffer, err := os.ReadFile(configFileName) if err != nil { return nil, fmt.Errorf("unable to load config file: %w", err) } modifiedBuffer, err := normalizeSubstitution(unparsedBuffer) if err != nil { return nil, fmt.Errorf("error normalizing substitution string : %w", err) } buffer, err := envsubst.Bytes(modifiedBuffer) if err != nil { return nil, fmt.Errorf("error in substituting env variables : %w", err) } return buffer, nil } // registerFlags uses `cfg` and `flag` tags to associate flags in the flagSet // to the fields in the options interface provided. // Each exported field in the options must have a `cfg` tag otherwise an error will occur. // - For fields, set `cfg` and `flag` so that `flag` is the name of the flag associated to this config option // - For exported fields that are not user facing, set the `cfg` to `,internal` // - For structs containing user facing fields, set the `cfg` to `,squash` func registerFlags(v *viper.Viper, prefix string, flagSet *pflag.FlagSet, options interface{}) error { val := reflect.ValueOf(options) var typ reflect.Type if val.Kind() == reflect.Ptr { typ = val.Elem().Type() } else { typ = val.Type() } for i := 0; i < typ.NumField(); i++ { // pull out the struct tags: // flag - the name of the command line flag // cfg - the name of the config file option field := typ.Field(i) fieldV := reflect.Indirect(val).Field(i) fieldName := strings.Join([]string{prefix, field.Name}, ".") cfgName := field.Tag.Get("cfg") if cfgName == ",internal" { // Public but internal types that should not be exposed to users, skip them continue } if isUnexported(field.Name) { // Unexported fields cannot be set by a user, so won't have tags or flags, skip them continue } if field.Type.Kind() == reflect.Struct { if cfgName != ",squash" { return fmt.Errorf("field %q does not have required cfg tag: `,squash`", fieldName) } err := registerFlags(v, fieldName, flagSet, fieldV.Interface()) if err != nil { return err } continue } flagName := field.Tag.Get("flag") if flagName == "" || cfgName == "" { return fmt.Errorf("field %q does not have required tags (cfg, flag)", fieldName) } if flagSet == nil { return fmt.Errorf("flagset cannot be nil") } f := flagSet.Lookup(flagName) if f == nil { return fmt.Errorf("field %q does not have a registered flag", flagName) } err := v.BindPFlag(cfgName, f) if err != nil { return fmt.Errorf("error binding flag for field %q: %w", fieldName, err) } } return nil } // decodeFromCfgTag sets the Viper decoder to read the names from the `cfg` tag // on each struct entry. func decodeFromCfgTag(c *mapstructure.DecoderConfig) { c.TagName = "cfg" } // isUnexported checks if a field name starts with a lowercase letter and therefore // if it is unexported. func isUnexported(name string) bool { if len(name) == 0 { // This should never happen panic("field name has len 0") } first := string(name[0]) return first == strings.ToLower(first) } // normalizeSubstitution normalizes dollar signs ($) with numerals like // $1 or $2 properly by correctly escaping them func normalizeSubstitution(unparsedBuffer []byte) ([]byte, error) { unparsedString := string(unparsedBuffer) regexPattern := regexp.MustCompile(`\$(\d+)`) substitutedString := regexPattern.ReplaceAllString(unparsedString, `$$$$1`) return []byte(substitutedString), nil }