Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ go 1.24

require (
github.com/Azure/azure-extension-foundation v0.0.0-20250620154556-caff9e3c3c5c
github.com/Azure/azure-extension-platform v0.0.0-20250107200156-aa20f765d49f
github.com/Azure/azure-extension-platform v0.0.0-20260107210613-2a62cc200c34
github.com/Azure/azure-sdk-for-go v63.2.0+incompatible
github.com/ahmetalpbalkan/go-httpbin v0.0.0-20160706084156-8817b883dae1
github.com/go-kit/kit v0.12.0
Expand Down Expand Up @@ -35,6 +35,7 @@ require (
github.com/xeipuuv/gojsonpointer v0.0.0-20190905194746-02993c407bfb // indirect
github.com/xeipuuv/gojsonreference v0.0.0-20180127040603-bd5ef7bd5415 // indirect
golang.org/x/crypto v0.17.0 // indirect
golang.org/x/sys v0.15.0 // indirect
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)
Expand Down
3 changes: 3 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ github.com/Azure/azure-extension-foundation v0.0.0-20250620154556-caff9e3c3c5c h
github.com/Azure/azure-extension-foundation v0.0.0-20250620154556-caff9e3c3c5c/go.mod h1:sNC6lMTUkXwjrQ+nttr6GXhDfvSGT7t3UDq30BEYzu8=
github.com/Azure/azure-extension-platform v0.0.0-20250107200156-aa20f765d49f h1:ddsUz/suc9txCMz/xWOslqNMvzhbWFMTflUrbcMNoSw=
github.com/Azure/azure-extension-platform v0.0.0-20250107200156-aa20f765d49f/go.mod h1:0458BvQsi5ch6kn+KZtI5m88Z3L9UFXdoY1+6nKdivY=
github.com/Azure/azure-extension-platform v0.0.0-20260107210613-2a62cc200c34 h1:7bEC4DJC4w0gx7SBy7M7Q2qi6ckmHcnnlFJzo+X/gi4=
github.com/Azure/azure-extension-platform v0.0.0-20260107210613-2a62cc200c34/go.mod h1:0458BvQsi5ch6kn+KZtI5m88Z3L9UFXdoY1+6nKdivY=
github.com/Azure/azure-sdk-for-go v63.2.0+incompatible h1:OIqkK/zTGqVUuzpEvY0B1YSYDRAFC/j+y0w2GovCggI=
github.com/Azure/azure-sdk-for-go v63.2.0+incompatible/go.mod h1:9XXNKU+eRnpl9moKnB4QOLf1HestfXbmab5FXxiDBjc=
github.com/Azure/go-autorest v14.2.0+incompatible h1:V5VMDjClD3GiElqLWO7mz2MxNAK/vTfRHdAubSIPRgs=
Expand Down Expand Up @@ -91,6 +93,7 @@ golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBc
golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.15.0 h1:h48lPFYpsTvQJZF4EKyI4aLHaev3CxivZmv7yZig9pc=
golang.org/x/sys v0.15.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
Expand Down
55 changes: 30 additions & 25 deletions main/cmds.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,9 @@ import (
"strconv"
"time"

utils "github.com/Azure/azure-extension-platform/pkg/utils"
"github.com/Azure/azure-extension-platform/pkg/utils"
vmextension "github.com/Azure/azure-extension-platform/vmextension"
"github.com/Azure/custom-script-extension-linux/pkg/errorutil"
"github.com/Azure/custom-script-extension-linux/pkg/seqnum"
"github.com/go-kit/kit/log"
"github.com/pkg/errors"
Expand All @@ -22,7 +24,7 @@ const (
maxScriptSize = 256 * 1024
)

type cmdFunc func(ctx *log.Context, hEnv HandlerEnvironment, seqNum int) (msg string, err error)
type cmdFunc func(ctx *log.Context, hEnv HandlerEnvironment, seqNum int) (msg string, ewc *vmextension.ErrorWithClarification)
type preFunc func(ctx *log.Context, hEnv HandlerEnvironment, seqNum int) error

type cmd struct {
Expand Down Expand Up @@ -55,14 +57,14 @@ var (
}
)

func noop(ctx *log.Context, h HandlerEnvironment, seqNum int) (string, error) {
func noop(ctx *log.Context, h HandlerEnvironment, seqNum int) (string, *vmextension.ErrorWithClarification) {
ctx.Log("event", "noop")
return "", nil
}

func install(ctx *log.Context, h HandlerEnvironment, seqNum int) (string, error) {
func install(ctx *log.Context, h HandlerEnvironment, seqNum int) (string, *vmextension.ErrorWithClarification) {
if err := os.MkdirAll(dataDir, 0755); err != nil {
return "", errors.Wrap(err, "failed to create data dir")
return "", vmextension.NewErrorWithClarificationPtr(errorutil.SystemError, errors.Wrap(err, "failed to create data dir"))
}

// If the file mrseq does not exists it is for two possible reasons.
Expand All @@ -77,12 +79,12 @@ func install(ctx *log.Context, h HandlerEnvironment, seqNum int) (string, error)
return "", nil
}

func uninstall(ctx *log.Context, h HandlerEnvironment, seqNum int) (string, error) {
func uninstall(ctx *log.Context, h HandlerEnvironment, seqNum int) (string, *vmextension.ErrorWithClarification) {
{ // a new context scope with path
ctx = ctx.With("path", dataDir)
ctx.Log("event", "removing data dir", "path", dataDir)
if err := os.RemoveAll(dataDir); err != nil {
return "", errors.Wrap(err, "failed to delete data directory")
return "", vmextension.NewErrorWithClarificationPtr(errorutil.Os_FailedToDeleteDataDir, errors.Wrap(err, "failed to delete data directory"))
}
ctx.Log("event", "removed data dir")
}
Expand Down Expand Up @@ -110,16 +112,18 @@ func min(a, b int) int {
return b
}

func enable(ctx *log.Context, h HandlerEnvironment, seqNum int) (string, error) {
func enable(ctx *log.Context, h HandlerEnvironment, seqNum int) (string, *vmextension.ErrorWithClarification) {
// parse the extension handler settings (not available prior to 'enable')
cfg, err := parseAndValidateSettings(ctx, h.HandlerEnvironment.ConfigFolder, seqNum)
if err != nil {
return "", errors.Wrap(err, "failed to get configuration")
cfg, ewc := parseAndValidateSettings(ctx, h.HandlerEnvironment.ConfigFolder, seqNum)
if ewc != nil {
ewc.Err = errors.Wrap(ewc.Err, "failed to get configuration")
return "", ewc
}

dir := filepath.Join(dataDir, downloadDir, fmt.Sprintf("%d", seqNum))
if err := downloadFiles(ctx, dir, cfg); err != nil {
return "", errors.Wrap(err, "processing file downloads failed")
if ewc := downloadFiles(ctx, dir, cfg); ewc != nil {
ewc.Err = errors.Wrap(ewc.Err, "processing file downloads failed")
return "", ewc
}

// execute the command, save its error
Expand Down Expand Up @@ -175,12 +179,12 @@ func checkAndSaveSeqNum(ctx log.Logger, seq int, mrseqPath string) (shouldExit b

// downloadFiles downloads the files specified in cfg into dir (creates if does
// not exist) and takes storage credentials specified in cfg into account.
func downloadFiles(ctx *log.Context, dir string, cfg handlerSettings) error {
func downloadFiles(ctx *log.Context, dir string, cfg handlerSettings) *vmextension.ErrorWithClarification {
// - prepare the output directory for files and the command output
// - create the directory if missing
ctx.Log("event", "creating output directory", "path", dir)
if err := os.MkdirAll(dir, 0700); err != nil {
return errors.Wrap(err, "failed to prepare output directory")
return vmextension.NewErrorWithClarificationPtr(errorutil.FileDownload_unableToCreateDownloadDirectory, errors.Wrap(err, "failed to prepare output directory"))
}
ctx.Log("event", "created output directory")

Expand All @@ -200,21 +204,22 @@ func downloadFiles(ctx *log.Context, dir string, cfg handlerSettings) error {
for i, f := range cfg.fileUrls() {
ctx := ctx.With("file", i)
ctx.Log("event", "download start")
if err := downloadAndProcessURL(ctx, f, dir, &cfg); err != nil {
ctx.Log("event", "download failed", "error", err)
return errors.Wrapf(err, "failed to download file[%d]", i)
if ewc := downloadAndProcessURL(ctx, f, dir, &cfg); ewc != nil {
ctx.Log("event", "download failed", "error", ewc.Err)
return vmextension.NewErrorWithClarificationPtr(ewc.ErrorCode, errors.Wrapf(ewc.Err, "failed to download file[%d]", i))
}
ctx.Log("event", "download complete", "output", dir)
}
return nil
}

// runCmd runs the command (extracted from cfg) in the given dir (assumed to exist).
func runCmd(ctx log.Logger, dir string, cfg handlerSettings) (err error) {
func runCmd(ctx log.Logger, dir string, cfg handlerSettings) (ewc *vmextension.ErrorWithClarification) {
ctx.Log("event", "executing command", "output", dir)
var cmd string
var scenario string
var scenarioInfo string
var err error

// So many ways to execute a command!
if cfg.publicSettings.CommandToExecute != "" {
Expand All @@ -228,27 +233,27 @@ func runCmd(ctx log.Logger, dir string, cfg handlerSettings) (err error) {
} else if cfg.publicSettings.Script != "" {
ctx.Log("event", "executing public script", "output", dir)
if cmd, scenarioInfo, err = writeTempScript(cfg.publicSettings.Script, dir, cfg.publicSettings.SkipDos2Unix); err != nil {
return
return nil
}
scenario = fmt.Sprintf("public-script;%s", scenarioInfo)
} else if cfg.protectedSettings.Script != "" {
ctx.Log("event", "executing protected script", "output", dir)
if cmd, scenarioInfo, err = writeTempScript(cfg.protectedSettings.Script, dir, cfg.publicSettings.SkipDos2Unix); err != nil {
return
return nil
}
scenario = fmt.Sprintf("protected-script;%s", scenarioInfo)
}

begin := time.Now()
err = ExecCmdInDir(cmd, dir)
ewc = ExecCmdInDir(cmd, dir)
elapsed := time.Now().Sub(begin)
isSuccess := err == nil
isSuccess := ewc == nil

telemetry("scenario", scenario, isSuccess, elapsed)

if err != nil {
if ewc != nil {
ctx.Log("event", "failed to execute command", "error", err, "output", dir)
return errors.Wrap(err, "failed to execute command")
return vmextension.NewErrorWithClarificationPtr(ewc.ErrorCode, errors.Wrap(ewc.Err, "failed to execute command"))
}
ctx.Log("event", "executed command", "output", dir)
return nil
Expand Down
14 changes: 8 additions & 6 deletions main/cmds_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"path/filepath"
"testing"

"github.com/Azure/custom-script-extension-linux/pkg/errorutil"
"github.com/ahmetalpbalkan/go-httpbin"
"github.com/go-kit/kit/log"
"github.com/stretchr/testify/require"
Expand Down Expand Up @@ -84,7 +85,7 @@ func Test_runCmd_success(t *testing.T) {

require.Nil(t, runCmd(log.NewNopLogger(), dir, handlerSettings{
publicSettings: publicSettings{CommandToExecute: "date"},
}), "command should run successfully")
}).Err, "command should run successfully")

// check stdout stderr files
_, err = os.Stat(filepath.Join(dir, "stdout"))
Expand All @@ -98,11 +99,12 @@ func Test_runCmd_fail(t *testing.T) {
require.Nil(t, err)
defer os.RemoveAll(dir)

err = runCmd(log.NewNopLogger(), dir, handlerSettings{
ewc := runCmd(log.NewNopLogger(), dir, handlerSettings{
publicSettings: publicSettings{CommandToExecute: "non-existing-cmd"},
})
require.NotNil(t, err, "command terminated with exit status")
require.Contains(t, err.Error(), "failed to execute command")
require.Equal(t, errorutil.CommandExecution_failureExitCode, ewc.ErrorCode)
require.NotNil(t, ewc.Err, "command terminated with exit status")
require.Contains(t, ewc.Err.Error(), "failed to execute command")
}

func Test_downloadFiles(t *testing.T) {
Expand All @@ -113,7 +115,7 @@ func Test_downloadFiles(t *testing.T) {
srv := httptest.NewServer(httpbin.GetMux())
defer srv.Close()

err = downloadFiles(log.NewContext(log.NewNopLogger()),
ewc := downloadFiles(log.NewContext(log.NewNopLogger()),
dir,
handlerSettings{
publicSettings: publicSettings{
Expand All @@ -123,7 +125,7 @@ func Test_downloadFiles(t *testing.T) {
srv.URL + "/bytes/1000",
}},
})
require.Nil(t, err)
require.Nil(t, ewc)

// check the files
f := []string{"10", "100", "1000"}
Expand Down
22 changes: 14 additions & 8 deletions main/exec.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ import (
"path/filepath"
"syscall"

vmextension "github.com/Azure/azure-extension-platform/vmextension"
errorutil "github.com/Azure/custom-script-extension-linux/pkg/errorutil"
"github.com/pkg/errors"
)

Expand All @@ -16,7 +18,7 @@ import (
//
// On error, an exit code may be returned if it is an exit code error.
// Given stdout and stderr will be closed upon returning.
func Exec(cmd, workdir string, stdout, stderr io.WriteCloser) (int, error) {
func Exec(cmd, workdir string, stdout, stderr io.WriteCloser) (int, *vmextension.ErrorWithClarification) {
defer stdout.Close()
defer stderr.Close()

Expand All @@ -30,10 +32,14 @@ func Exec(cmd, workdir string, stdout, stderr io.WriteCloser) (int, error) {
if ok {
if status, ok := exitErr.Sys().(syscall.WaitStatus); ok {
code := status.ExitStatus()
return code, fmt.Errorf("command terminated with exit status=%d", code)
return code, vmextension.NewErrorWithClarificationPtr(errorutil.CommandExecution_failureExitCode, fmt.Errorf("command terminated with exit status=%d", code))
}
}
return 0, errors.Wrapf(err, "failed to execute command")
if err == nil {
return 0, nil
}

return 0, vmextension.NewErrorWithClarificationPtr(errorutil.CommandExecution_failedUnknownError, errors.Wrapf(err, "failed to execute command"))
}

// ExecCmdInDir executes the given command in given directory and saves output
Expand All @@ -42,20 +48,20 @@ func Exec(cmd, workdir string, stdout, stderr io.WriteCloser) (int, error) {
//
// Ideally, we execute commands only once per sequence number in custom-script-extension,
// and save their output under /var/lib/waagent/<dir>/download/<seqnum>/*.
func ExecCmdInDir(cmd, workdir string) error {
func ExecCmdInDir(cmd, workdir string) *vmextension.ErrorWithClarification {
outFn, errFn := logPaths(workdir)

outF, err := os.OpenFile(outFn, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0600)
if err != nil {
return errors.Wrapf(err, "failed to open stdout file")
return vmextension.NewErrorWithClarificationPtr(errorutil.Os_FailedToOpenStdOut, errors.Wrapf(err, "failed to open stdout file"))
}
errF, err := os.OpenFile(errFn, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0600)
if err != nil {
return errors.Wrapf(err, "failed to open stderr file")
return vmextension.NewErrorWithClarificationPtr(errorutil.Os_FailedToOpenStdErr, errors.Wrapf(err, "failed to open stderr file"))
}

_, err = Exec(cmd, workdir, outF, errF)
return err
_, ewc := Exec(cmd, workdir, outF, errF)
return ewc
}

// logPaths returns stdout and stderr file paths for the specified output
Expand Down
31 changes: 19 additions & 12 deletions main/exec_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,14 @@ import (
"path/filepath"
"testing"

"github.com/Azure/custom-script-extension-linux/pkg/errorutil"
"github.com/stretchr/testify/require"
)

func TestExec_success(t *testing.T) {
v := new(mockFile)
ec, err := Exec("date", "/", v, v)
require.Nil(t, err, "err: %v -- out: %s", err, v.b.Bytes())
require.Nil(t, err.Err, "err: %v -- out: %s", err.Err, v.b.Bytes())
require.EqualValues(t, 0, ec)
}

Expand All @@ -24,7 +25,7 @@ func TestExec_success_redirectsStdStreams_closesFds(t *testing.T) {
require.False(t, e.closed, "stderr open")

_, err := Exec("/bin/echo 'I am stdout!'>&1; /bin/echo 'I am stderr!'>&2", "/", o, e)
require.Nil(t, err, "err: %v -- stderr: %s", err, e.b.Bytes())
require.Nil(t, err, "err: %v -- stderr: %s", err.Err, e.b.Bytes())
require.Equal(t, "I am stdout!\n", string(o.b.Bytes()))
require.Equal(t, "I am stderr!\n", string(e.b.Bytes()))
require.True(t, o.closed, "stdout closed")
Expand All @@ -33,24 +34,27 @@ func TestExec_success_redirectsStdStreams_closesFds(t *testing.T) {

func TestExec_failure_exitError(t *testing.T) {
ec, err := Exec("exit 12", "/", new(mockFile), new(mockFile))
require.NotNil(t, err)
require.EqualError(t, err, "command terminated with exit status=12") // error is customized
require.Equal(t, err.ErrorCode, errorutil.CommandExecution_failureExitCode)
require.NotNil(t, err.Err)
require.EqualError(t, err.Err, "command terminated with exit status=12") // error is customized
require.EqualValues(t, 12, ec)
}

func TestExec_failure_genericError(t *testing.T) {
_, err := Exec("date", "/non-existing-path", new(mockFile), new(mockFile))
require.NotNil(t, err)
require.Contains(t, err.Error(), "failed to execute command:") // error is wrapped
require.Equal(t, err.ErrorCode, errorutil.CommandExecution_failedUnknownError)
require.NotNil(t, err.Err)
require.Contains(t, err.Err.Error(), "failed to execute command:") // error is wrapped
}

func TestExec_failure_fdClosed(t *testing.T) {
out := new(mockFile)
require.Nil(t, out.Close())

_, err := Exec("date", "/", out, out)
require.NotNil(t, err)
require.Contains(t, err.Error(), "file closed") // error is wrapped
require.Equal(t, err.ErrorCode, errorutil.CommandExecution_failedUnknownError)
require.NotNil(t, err.Err)
require.Contains(t, err.Err.Error(), "file closed") // error is wrapped
}

func TestExec_failure_redirectsStdStreams_closesFds(t *testing.T) {
Expand All @@ -59,7 +63,8 @@ func TestExec_failure_redirectsStdStreams_closesFds(t *testing.T) {
require.False(t, e.closed, "stderr open")

_, err := Exec(`/bin/echo 'I am stdout!'>&1; /bin/echo 'I am stderr!'>&2; exit 12`, "/", o, e)
require.NotNil(t, err)
require.Equal(t, err.ErrorCode, errorutil.CommandExecution_failureExitCode)
require.NotNil(t, err.Err)
require.Equal(t, "I am stdout!\n", string(o.b.Bytes()))
require.Equal(t, "I am stderr!\n", string(e.b.Bytes()))
require.True(t, o.closed, "stdout closed")
Expand All @@ -71,8 +76,8 @@ func TestExecCmdInDir(t *testing.T) {
require.Nil(t, err)
defer os.RemoveAll(dir)

err = ExecCmdInDir("/bin/echo 'Hello world'", dir)
require.Nil(t, err)
ewc := ExecCmdInDir("/bin/echo 'Hello world'", dir)
require.Nil(t, ewc)
require.True(t, fileExists(t, filepath.Join(dir, "stdout")), "stdout file should be created")
require.True(t, fileExists(t, filepath.Join(dir, "stderr")), "stderr file should be created")

Expand All @@ -87,7 +92,9 @@ func TestExecCmdInDir(t *testing.T) {

func TestExecCmdInDir_cantOpenError(t *testing.T) {
err := ExecCmdInDir("/bin/echo 'Hello world'", "/non-existing-dir")
require.Contains(t, err.Error(), "failed to open stdout file")
require.Equal(t, err.ErrorCode, errorutil.Os_FailedToOpenStdErr)
require.NotNil(t, err.Err)
require.Contains(t, err.Err.Error(), "failed to open stdout file")
}

func TestExecCmdInDir_truncates(t *testing.T) {
Expand Down
Loading