Skip to content

Primary branch flag for scan create command (AST-102468) #1207

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 12 commits into from
Jul 21, 2025
16 changes: 16 additions & 0 deletions internal/commands/scan.go
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,7 @@ const (
ScsRepoWarningMsg = "SCS scan warning: Unable to start Scorecard scan due to missing required flags, please include in the ast-cli arguments: " +
"--scs-repo-url your_repo_url --scs-repo-token your_repo_token"
ScsScorecardUnsupportedHostWarningMsg = "SCS scan warning: Unable to run Scorecard scanner due to unsupported repo host. Currently, Scorecard can only run on GitHub Cloud repos."
BranchPrimaryPrefix = "--branch-primary="
)

var (
Expand Down Expand Up @@ -722,6 +723,11 @@ func scanCreateSubCommand(
"Enable SAST scan using light query configuration",
)

createScanCmd.PersistentFlags().Bool(
commonParams.BranchPrimaryFlag,
false,
"This flag sets the branch specified in --branch as the PRIMARY branch for the project")

createScanCmd.PersistentFlags().Bool(
commonParams.SastRecommendedExclusionsFlags,
false,
Expand Down Expand Up @@ -845,6 +851,7 @@ func setupScanTypeProjectAndConfig(
userAllowedEngines, _ := jwtWrapper.GetAllowedEngines(featureFlagsWrapper)
var info map[string]interface{}
newProjectName, _ := cmd.Flags().GetString(commonParams.ProjectName)

_ = json.Unmarshal(*input, &info)
info[resultsMapType] = getUploadType(cmd)
// Handle the project settings
Expand Down Expand Up @@ -3006,6 +3013,15 @@ func validateCreateScanFlags(cmd *cobra.Command) error {
return fmt.Errorf("Invalid value for --%s flag. Must be a valid UUID.", commonParams.IacsPresetIDFlag)
}
}
// check if flag was passed as arg
isBranchChanged := cmd.Flags().Changed(commonParams.BranchPrimaryFlag)
if isBranchChanged {
for _, a := range os.Args[1:] {
if strings.HasPrefix(a, BranchPrimaryPrefix) {
return fmt.Errorf("invalid value for --branch-primary flag. This flag is sent without any values")
}
}
}

return nil
}
Expand Down
35 changes: 31 additions & 4 deletions internal/commands/scan_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ const (
SCSScoreCardError = "SCS scan failed to start: Scorecard scan is missing required flags, please include in the ast-cli arguments: " +
"--scs-repo-url your_repo_url --scs-repo-token your_repo_token"
outputFileName = "test_output.log"
noUpdatesForExistingProject = "No tags to update. Skipping project update."
noUpdatesForExistingProject = "No tags or branch to update. Skipping project update."
ScaResolverZipNotSupportedErr = "Scanning Zip files is not supported by ScaResolver.Please use non-zip source"
)

Expand Down Expand Up @@ -400,6 +400,7 @@ func TestCreateScanBranches(t *testing.T) {
// Bind cx_branch environment variable
_ = viper.BindEnv("cx_branch", "CX_BRANCH")
viper.SetDefault("cx_branch", "branch_from_environment_variable")
assert.Equal(t, viper.GetString("cx_branch"), "branch_from_environment_variable")

// Test branch from environment variable. Since the cx_branch is bind the scan must run successfully without a branch flag defined
execCmdNilAssertion(t, "scan", "create", "--project-name", "MOCK", "-s", dummyRepo)
Expand Down Expand Up @@ -637,6 +638,35 @@ func TestCreateScanResubmitWithScanTypes(t *testing.T) {
execCmdNilAssertion(t, "scan", "create", "--project-name", "MOCK", "-s", dummyRepo, "-b", "dummy_branch", "--scan-types", "sast,iac-security,sca", "--debug", "--resubmit")
}

func TestCreateScanWithPrimaryBranchFlag_Passed(t *testing.T) {
execCmdNilAssertion(t, "scan", "create", "--project-name", "MOCK", "-s", dummyRepo, "-b", "dummy_branch", "--debug", "--branch-primary")
}

func TestCreateScanWithPrimaryBranchFlagBooleanValueTrue_Failed(t *testing.T) {
original := os.Args
defer func() { os.Args = original }()
os.Args = []string{
"scan", "create", "--project-name", "MOCK", "-s", dummyRepo, "-b", "dummy_branch", "--debug", "--branch-primary=true",
}
err := execCmdNotNilAssertion(t, "scan", "create", "--project-name", "MOCK", "-s", dummyRepo, "-b", "dummy_branch", "--debug", "--branch-primary=true")
assert.ErrorContains(t, err, "invalid value for --branch-primary flag", err.Error())
}

func TestCreateScanWithPrimaryBranchFlagBooleanValueFalse_Failed(t *testing.T) {
original := os.Args
defer func() { os.Args = original }()
os.Args = []string{
"scan", "create", "--project-name", "MOCK", "-s", dummyRepo, "-b", "dummy_branch", "--debug", "--branch-primary=false",
}
err := execCmdNotNilAssertion(t, "scan", "create", "--project-name", "MOCK", "-s", dummyRepo, "-b", "dummy_branch", "--debug", "--branch-primary=false")
assert.ErrorContains(t, err, "invalid value for --branch-primary flag", err.Error())
}

func TestCreateScanWithPrimaryBranchFlagStringValue_Should_Fail(t *testing.T) {
err := execCmdNotNilAssertion(t, "scan", "create", "--project-name", "MOCK", "-s", dummyRepo, "-b", "dummy_branch", "--debug", "--branch-primary=string")
assert.ErrorContains(t, err, "invalid argument \"string\"", err.Error())
}

func Test_parseThresholdSuccess(t *testing.T) {
want := make(map[string]int)
want["iac-security-low"] = 1
Expand All @@ -645,7 +675,6 @@ func Test_parseThresholdSuccess(t *testing.T) {
t.Errorf("parseThreshold() = %v, want %v", got, want)
}
}

func Test_parseThresholdsSuccess(t *testing.T) {
want := make(map[string]int)
want["sast-high"] = 1
Expand All @@ -656,15 +685,13 @@ func Test_parseThresholdsSuccess(t *testing.T) {
t.Errorf("parseThreshold() = %v, want %v", got, want)
}
}

func Test_parseThresholdParseError(t *testing.T) {
want := make(map[string]int)
threshold := " KICS - LoW=error"
if got := parseThreshold(threshold); !reflect.DeepEqual(got, want) {
t.Errorf("parseThreshold() = %v, want %v", got, want)
}
}

func TestCreateScanProjectTags(t *testing.T) {
execCmdNilAssertion(t, scanCommand, "create", "--project-name", "MOCK", "-s", dummyRepo, "-b", "dummy_branch",
"--project-tags", "test", "--debug")
Expand Down
1 change: 1 addition & 0 deletions internal/params/flags.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ const (
NtlmProxyDomainFlag = "proxy-ntlm-domain"
SastFastScanFlag = "sast-fast-scan"
SastLightQueriesFlag = "sast-light-queries"
BranchPrimaryFlag = "branch-primary"
SastRecommendedExclusionsFlags = "sast-recommended-exclusions"
NtlmProxyDomainFlagUsage = "Window domain when using NTLM proxy"
BaseURIFlagUsage = "The base system URI"
Expand Down
31 changes: 24 additions & 7 deletions internal/services/projects.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
package services

import (
"fmt"
"slices"
"strconv"
"strings"
"time"

featureFlagsConstants "github.com/checkmarx/ast-cli/internal/constants/feature-flags"
Expand All @@ -11,6 +13,7 @@ import (
"github.com/checkmarx/ast-cli/internal/wrappers"
"github.com/pkg/errors"
"github.com/spf13/cobra"
"github.com/spf13/viper"
)

const (
Expand All @@ -31,17 +34,19 @@ func FindProject(
applicationWrapper wrappers.ApplicationsWrapper,
featureFlagsWrapper wrappers.FeatureFlagsWrapper,
) (string, error) {
var isBranchPrimary bool
resp, err := GetProjectsCollectionByProjectName(projectName, projectsWrapper)
if err != nil {
return "", err
}

branchName := strings.TrimSpace(viper.GetString(commonParams.BranchKey))
isBranchPrimary, _ = cmd.Flags().GetBool(commonParams.BranchPrimaryFlag)
for i := 0; i < len(resp.Projects); i++ {
project := resp.Projects[i]
if project.Name == projectName {
projectTags, _ := cmd.Flags().GetString(commonParams.ProjectTagList)
projectPrivatePackage, _ := cmd.Flags().GetString(commonParams.ProjecPrivatePackageFlag)
return updateProject(&project, projectsWrapper, projectTags, projectPrivatePackage)
return updateProject(&project, projectsWrapper, projectTags, projectPrivatePackage, isBranchPrimary, branchName)
}
}

Expand All @@ -55,7 +60,7 @@ func FindProject(
}

projectID, err := createProject(projectName, cmd, projectsWrapper, groupsWrapper, accessManagementWrapper, applicationWrapper,
applicationID, projectGroups, projectPrivatePackage, featureFlagsWrapper)
applicationID, projectGroups, projectPrivatePackage, featureFlagsWrapper, isBranchPrimary, branchName)
if err != nil {
logger.PrintIfVerbose("error in creating project!")
return "", err
Expand Down Expand Up @@ -97,12 +102,18 @@ func createProject(
projectGroups string,
projectPrivatePackage string,
featureFlagsWrapper wrappers.FeatureFlagsWrapper,
isBranchPrimary bool,
branchName string,
) (string, error) {
projectTags, _ := cmd.Flags().GetString(commonParams.ProjectTagList)
applicationName, _ := cmd.Flags().GetString(commonParams.ApplicationName)
var projModel = wrappers.Project{}
projModel.Name = projectName
projModel.ApplicationIds = applicationID
if isBranchPrimary {
logger.PrintIfVerbose(fmt.Sprintf("Setting the branch in project : %s", branchName))
projModel.MainBranch = branchName
}
var groupsMap []*wrappers.Group
if projectGroups != "" {
var groups []string
Expand Down Expand Up @@ -179,14 +190,20 @@ func verifyApplicationAssociationDone(applicationName, projectID string, applica
//nolint:gocyclo
func updateProject(project *wrappers.ProjectResponseModel,
projectsWrapper wrappers.ProjectsWrapper,
projectTags string, projectPrivatePackage string) (string, error) {
projectTags string, projectPrivatePackage string, isBranchPrimary bool, branchName string) (string, error) {
var projectID string
var projModel = wrappers.Project{}
projectID = project.ID
projModel.MainBranch = project.MainBranch
if isBranchPrimary {
projModel.MainBranch = branchName
logger.PrintfIfVerbose("Updating the branch as primary: %s", branchName)
} else {
projModel.MainBranch = project.MainBranch
}
projModel.RepoURL = project.RepoURL
if projectTags == "" && projectPrivatePackage == "" {
logger.PrintIfVerbose("No tags to update. Skipping project update.")

if projectTags == "" && projectPrivatePackage == "" && isBranchPrimary == false {
logger.PrintIfVerbose("No tags or branch to update. Skipping project update.")
return projectID, nil
}
if projectPrivatePackage != "" {
Expand Down
4 changes: 2 additions & 2 deletions internal/services/projects_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ func Test_createProject(t *testing.T) {
ttt.args.applicationID,
ttt.args.projectGroups,
ttt.args.projectPrivatePackage,
ttt.args.featureFlagsWrapper)
ttt.args.featureFlagsWrapper, false, "")
if (err != nil) != ttt.wantErr {
t.Errorf("createProject() error = %v, wantErr %v", err, ttt.wantErr)
return
Expand Down Expand Up @@ -240,7 +240,7 @@ func Test_updateProject(t *testing.T) {
ttt := tt
t.Run(tt.name, func(t *testing.T) {
got, err := updateProject(ttt.args.project, ttt.args.projectsWrapper,
ttt.args.projectTags, ttt.args.projectPrivatePackage)
ttt.args.projectTags, ttt.args.projectPrivatePackage, false, "")
if (err != nil) != ttt.wantErr {
t.Errorf("updateProject() error = %v, wantErr %v", err, ttt.wantErr)
return
Expand Down
10 changes: 10 additions & 0 deletions test/integration/scan_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -654,6 +654,12 @@ func TestIncrementalScan(t *testing.T) {
executeScanAssertions(t, projectIDInc, scanIDInc, map[string]string{})
}

func TestBranchPrimaryFlag(t *testing.T) {
projectName := getProjectNameForScanTests()
scanID, projectID := createScanWithPrimaryBranchFlag(t, Dir, projectName, map[string]string{})
executeScanAssertions(t, projectID, scanID, map[string]string{})
}

// Start a scan guaranteed to take considerable time, cancel it and assert the status
func TestCancelScan(t *testing.T) {
scanID, _ := createScanSastNoWait(t, SlowRepo, map[string]string{})
Expand Down Expand Up @@ -969,6 +975,10 @@ func createScanScaWithResolver(
)
}

func createScanWithPrimaryBranchFlag(t *testing.T, source string, name string, tags map[string]string) (string, string) {
return executeCreateScan(t, append(getCreateArgsWithName(source, tags, name, "sast,sca,iac-security"), "--branch-primary"))
}

func createScanIncremental(t *testing.T, source string, name string, tags map[string]string) (string, string) {
return executeCreateScan(t, append(getCreateArgsWithName(source, tags, name, "sast,sca,iac-security"), "--sast-incremental"))
}
Expand Down
Loading