Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 34 additions & 27 deletions flag_groups.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,31 +83,46 @@ func (c *Command) ValidateFlagGroups() error {
return nil
}

flags := c.Flags()

// groupStatus format is the list of flags as a unique ID,
// then a map of each flag name and whether it is set or not.
groupStatus := map[string]map[string]bool{}
oneRequiredGroupStatus := map[string]map[string]bool{}
mutuallyExclusiveGroupStatus := map[string]map[string]bool{}
flags.VisitAll(func(pflag *flag.Flag) {
processFlagForGroupAnnotation(flags, pflag, requiredAsGroupAnnotation, groupStatus)
processFlagForGroupAnnotation(flags, pflag, oneRequiredAnnotation, oneRequiredGroupStatus)
processFlagForGroupAnnotation(flags, pflag, mutuallyExclusiveAnnotation, mutuallyExclusiveGroupStatus)
})
statuses := c.getFlagGroupStatuses()

if err := validateRequiredFlagGroups(groupStatus); err != nil {
if err := validateRequiredFlagGroups(statuses.Required); err != nil {
return err
}
if err := validateOneRequiredFlagGroups(oneRequiredGroupStatus); err != nil {
if err := validateOneRequiredFlagGroups(statuses.OneRequired); err != nil {
return err
}
if err := validateExclusiveFlagGroups(mutuallyExclusiveGroupStatus); err != nil {
if err := validateExclusiveFlagGroups(statuses.MutuallyExclusive); err != nil {
return err
}
return nil
}

type flagGroupStatuses struct {
Required map[string]map[string]bool
OneRequired map[string]map[string]bool
MutuallyExclusive map[string]map[string]bool
}

// getFlagGroupStatuses collects the status of all flags belonging to any flag group.
func (c *Command) getFlagGroupStatuses() flagGroupStatuses {
flags := c.Flags()
required := map[string]map[string]bool{}
oneRequired := map[string]map[string]bool{}
mutuallyExclusive := map[string]map[string]bool{}

flags.VisitAll(func(pflag *flag.Flag) {
processFlagForGroupAnnotation(flags, pflag, requiredAsGroupAnnotation, required)
processFlagForGroupAnnotation(flags, pflag, oneRequiredAnnotation, oneRequired)
processFlagForGroupAnnotation(flags, pflag, mutuallyExclusiveAnnotation, mutuallyExclusive)
})

return flagGroupStatuses{
Required: required,
OneRequired: oneRequired,
MutuallyExclusive: mutuallyExclusive,
}
}

func hasAllFlags(fs *flag.FlagSet, flagnames ...string) bool {
for _, fname := range flagnames {
f := fs.Lookup(fname)
Expand Down Expand Up @@ -227,19 +242,11 @@ func (c *Command) enforceFlagGroupsForCompletion() {
return
}

flags := c.Flags()
groupStatus := map[string]map[string]bool{}
oneRequiredGroupStatus := map[string]map[string]bool{}
mutuallyExclusiveGroupStatus := map[string]map[string]bool{}
c.Flags().VisitAll(func(pflag *flag.Flag) {
processFlagForGroupAnnotation(flags, pflag, requiredAsGroupAnnotation, groupStatus)
processFlagForGroupAnnotation(flags, pflag, oneRequiredAnnotation, oneRequiredGroupStatus)
processFlagForGroupAnnotation(flags, pflag, mutuallyExclusiveAnnotation, mutuallyExclusiveGroupStatus)
})
statuses := c.getFlagGroupStatuses()

// If a flag that is part of a group is present, we make all the other flags
// of that group required so that the shell completion suggests them automatically
for flagList, flagnameAndStatus := range groupStatus {
for flagList, flagnameAndStatus := range statuses.Required {
for _, isSet := range flagnameAndStatus {
if isSet {
// One of the flags of the group is set, mark the other ones as required
Expand All @@ -252,7 +259,7 @@ func (c *Command) enforceFlagGroupsForCompletion() {

// If none of the flags of a one-required group are present, we make all the flags
// of that group required so that the shell completion suggests them automatically
for flagList, flagnameAndStatus := range oneRequiredGroupStatus {
for flagList, flagnameAndStatus := range statuses.OneRequired {
isSet := false

for _, isSet = range flagnameAndStatus {
Expand All @@ -272,7 +279,7 @@ func (c *Command) enforceFlagGroupsForCompletion() {

// If a flag that is mutually exclusive to others is present, we hide the other
// flags of that group so the shell completion does not suggest them
for flagList, flagnameAndStatus := range mutuallyExclusiveGroupStatus {
for flagList, flagnameAndStatus := range statuses.MutuallyExclusive {
for flagName, isSet := range flagnameAndStatus {
if isSet {
// One of the flags of the mutually exclusive group is set, mark the other ones as hidden
Expand Down