diff --git a/reexec/reexec.go b/reexec/reexec.go index 414f35b..b97a0aa 100644 --- a/reexec/reexec.go +++ b/reexec/reexec.go @@ -7,6 +7,7 @@ package reexec import ( + "context" "fmt" "os" "os/exec" @@ -52,6 +53,20 @@ func Command(args ...string) *exec.Cmd { return command(args...) } +// CommandContext is like [Command] but includes a context. It uses +// [exec.CommandContext] under the hood. +// +// The provided context is used to interrupt the process +// (by calling cmd.Cancel or [os.Process.Kill]) +// if the context becomes done before the command completes on its own. +// +// CommandContext sets the command's Cancel function to invoke the Kill method +// on its Process, and leaves its WaitDelay unset. The caller may change the +// cancellation behavior by modifying those fields before starting the command. +func CommandContext(ctx context.Context, args ...string) *exec.Cmd { + return commandContext(ctx, args...) +} + // Self returns the path to the current process's binary. // // On Linux, it returns "/proc/self/exe", which provides the in-memory version diff --git a/reexec/reexec_linux.go b/reexec/reexec_linux.go index 12b6c32..cbe13cf 100644 --- a/reexec/reexec_linux.go +++ b/reexec/reexec_linux.go @@ -1,6 +1,7 @@ package reexec import ( + "context" "os/exec" "syscall" ) @@ -16,3 +17,15 @@ func command(args ...string) *exec.Cmd { } return cmd } + +func commandContext(ctx context.Context, args ...string) *exec.Cmd { + // We try to stay close to exec.Command's behavior, but after + // constructing the cmd, we remove "Self()" from cmd.Args, which + // is prepended by exec.Command. + cmd := exec.CommandContext(ctx, Self(), args...) + cmd.Args = cmd.Args[1:] + cmd.SysProcAttr = &syscall.SysProcAttr{ + Pdeathsig: syscall.SIGTERM, + } + return cmd +} diff --git a/reexec/reexec_other.go b/reexec/reexec_other.go index e07dd2d..27bd128 100644 --- a/reexec/reexec_other.go +++ b/reexec/reexec_other.go @@ -3,6 +3,7 @@ package reexec import ( + "context" "os/exec" ) @@ -14,3 +15,12 @@ func command(args ...string) *exec.Cmd { cmd.Args = cmd.Args[1:] return cmd } + +func commandContext(ctx context.Context, args ...string) *exec.Cmd { + // We try to stay close to exec.Command's behavior, but after + // constructing the cmd, we remove "Self()" from cmd.Args, which + // is prepended by exec.Command. + cmd := exec.CommandContext(ctx, Self(), args...) + cmd.Args = cmd.Args[1:] + return cmd +} diff --git a/reexec/reexec_test.go b/reexec/reexec_test.go index 057c4fa..0620c37 100644 --- a/reexec/reexec_test.go +++ b/reexec/reexec_test.go @@ -1,6 +1,7 @@ package reexec import ( + "context" "errors" "fmt" "os" @@ -9,11 +10,13 @@ import ( "reflect" "strings" "testing" + "time" ) const ( testReExec = "test-reexec" testReExec2 = "test-reexec2" + testReExec3 = "test-reexec3" ) func init() { @@ -28,6 +31,11 @@ func init() { fmt.Println("Hello", testReExec2, args) os.Exit(0) }) + Register(testReExec3, func() { + fmt.Println("Hello " + testReExec3) + time.Sleep(1 * time.Second) + os.Exit(0) + }) Init() } @@ -113,6 +121,83 @@ func TestCommand(t *testing.T) { } } +func TestCommandContext(t *testing.T) { + tests := []struct { + doc string + cmdAndArgs []string + cancel bool + expOut string + expError bool + }{ + { + doc: "basename", + cmdAndArgs: []string{testReExec2}, + expOut: "Hello test-reexec2", + }, + { + doc: "full path", + cmdAndArgs: []string{filepath.Join("something", testReExec2)}, + expOut: "Hello test-reexec2", + }, + { + doc: "command with args", + cmdAndArgs: []string{testReExec2, "--some-flag", "some-value", "arg1", "arg2"}, + expOut: `Hello test-reexec2 (args: []string{"--some-flag", "some-value", "arg1", "arg2"})`, + }, + { + doc: "context canceled", + cancel: true, + cmdAndArgs: []string{testReExec2}, + expError: true, + }, + { + doc: "context timeout", + cmdAndArgs: []string{testReExec3}, + expOut: "Hello test-reexec3", + expError: true, + }, + } + + for _, tc := range tests { + t.Run(tc.doc, func(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer cancel() + + cmd := CommandContext(ctx, tc.cmdAndArgs...) + if !reflect.DeepEqual(cmd.Args, tc.cmdAndArgs) { + t.Fatalf("got %+v, want %+v", cmd.Args, tc.cmdAndArgs) + } + + w, err := cmd.StdinPipe() + if err != nil { + t.Fatalf("Error on pipe creation: %v", err) + } + defer func() { _ = w.Close() }() + if tc.cancel { + cancel() + } + out, err := cmd.CombinedOutput() + if tc.cancel { + if !errors.Is(err, context.Canceled) { + t.Errorf("got %[1]v (%[1]T), want %v", err, context.Canceled) + } + } + if tc.expError { + if err == nil { + t.Errorf("expected error, got nil") + } + } else if err != nil { + t.Errorf("error on re-exec cmd: %v, out: %v", err, string(out)) + } + + actual := strings.TrimSpace(string(out)) + if actual != tc.expOut { + t.Errorf("got %v, want %v", actual, tc.expOut) + } + }) + } +} + func TestNaiveSelf(t *testing.T) { if os.Getenv("TEST_CHECK") == "1" { os.Exit(2)