diff --git a/server.go b/server.go index b05af92..a726c32 100644 --- a/server.go +++ b/server.go @@ -8,14 +8,12 @@ import ( "crypto/tls" "crypto/x509" "encoding/base64" - "errors" "fmt" "io" "net" "os" "os/signal" "os/user" - "runtime" "sort" "strconv" "strings" @@ -502,11 +500,6 @@ func Serve(opts *ServeConfig) { } select { case <-ctx.Done(): - // Cancellation. We can stop the server by closing the listener. - // This isn't graceful at all but this is currently only used by - // tests and its our only way to stop. - _ = listener.Close() - // If this is a grpc server, then we also ask the server itself to // end which will kill all connections. There isn't an easy way to do // this for net/rpc currently but net/rpc is more and more unused. @@ -514,6 +507,11 @@ func Serve(opts *ServeConfig) { s.Stop() } + // Cancellation. We can stop the server by closing the listener. + // This isn't graceful at all but this is currently only used by + // tests and its our only way to stop. + _ = listener.Close() + // Wait for the server itself to shut down <-doneCh @@ -525,56 +523,6 @@ func Serve(opts *ServeConfig) { } } -func serverListener(unixSocketCfg UnixSocketConfig) (net.Listener, error) { - if runtime.GOOS == "windows" { - return serverListener_tcp() - } - - return serverListener_unix(unixSocketCfg) -} - -func serverListener_tcp() (net.Listener, error) { - envMinPort := os.Getenv("PLUGIN_MIN_PORT") - envMaxPort := os.Getenv("PLUGIN_MAX_PORT") - - var minPort, maxPort int64 - var err error - - switch { - case len(envMinPort) == 0: - minPort = 0 - default: - minPort, err = strconv.ParseInt(envMinPort, 10, 32) - if err != nil { - return nil, fmt.Errorf("couldn't get value from PLUGIN_MIN_PORT: %v", err) - } - } - - switch { - case len(envMaxPort) == 0: - maxPort = 0 - default: - maxPort, err = strconv.ParseInt(envMaxPort, 10, 32) - if err != nil { - return nil, fmt.Errorf("couldn't get value from PLUGIN_MAX_PORT: %v", err) - } - } - - if minPort > maxPort { - return nil, fmt.Errorf("PLUGIN_MIN_PORT value of %d is greater than PLUGIN_MAX_PORT value of %d", minPort, maxPort) - } - - for port := minPort; port <= maxPort; port++ { - address := fmt.Sprintf("127.0.0.1:%d", port) - listener, err := net.Listen("tcp", address) - if err == nil { - return listener, nil - } - } - - return nil, errors.New("couldn't bind plugin TCP listener") -} - func serverListener_unix(unixSocketCfg UnixSocketConfig) (net.Listener, error) { tf, err := os.CreateTemp(unixSocketCfg.socketDir, "plugin") if err != nil { diff --git a/server_test.go b/server_test.go index 24bb3a3..22a569e 100644 --- a/server_test.go +++ b/server_test.go @@ -9,7 +9,7 @@ import ( "log" "net" "os" - "path" + "path/filepath" "runtime" "strings" "testing" @@ -68,6 +68,21 @@ func TestServer_testMode(t *testing.T) { if err := client.Ping(); err != nil { t.Fatalf("should not err: %s", err) } + // Grab the impl + raw, err := client.Dispense("test") + if err != nil { + t.Fatalf("err should be nil, got %s", err) + } + + tester, ok := raw.(testInterface) + if !ok { + t.Fatalf("bad: %#v", raw) + } + + n := tester.Double(3) + if n != 6 { + t.Fatal("invalid response", n) + } // Kill which should do nothing c.Kill() @@ -309,9 +324,10 @@ func TestServer_testStdLogger(t *testing.T) { func TestUnixSocketDir(t *testing.T) { if runtime.GOOS == "windows" { - t.Skip("go-plugin doesn't support unix sockets on Windows") + if !isSupportUnix() { + t.Skip("go-plugin doesn't support unix sockets on Windows") + } } - tmpDir := t.TempDir() t.Setenv(EnvUnixSocketDir, tmpDir) @@ -344,8 +360,8 @@ func TestUnixSocketDir(t *testing.T) { t.Fatal("should've received reattach") } - actualDir := path.Clean(path.Dir(cfg.Addr.String())) - expectedDir := path.Clean(tmpDir) + actualDir := filepath.Clean(filepath.Dir(cfg.Addr.String())) + expectedDir := filepath.Clean(tmpDir) if actualDir != expectedDir { t.Fatalf("Expected socket in dir: %s, but was in %s", expectedDir, actualDir) } diff --git a/server_unix.go b/server_unix.go new file mode 100644 index 0000000..6f779f6 --- /dev/null +++ b/server_unix.go @@ -0,0 +1,13 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: MPL-2.0 + +//go:build !windows +// +build !windows + +package plugin + +var serverListener = serverListener_unix + +func isSupportUnix() bool { + return true +} diff --git a/server_windows.go b/server_windows.go new file mode 100644 index 0000000..2cd12cf --- /dev/null +++ b/server_windows.go @@ -0,0 +1,72 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: MPL-2.0 + +//go:build windows +// +build windows + +package plugin + +import ( + "errors" + "fmt" + "net" + "os" + "strconv" + + "golang.org/x/sys/windows" +) + +func serverListener(unixSocketCfg UnixSocketConfig) (net.Listener, error) { + if isSupportUnix() { + unixSocketCfg.Group = "" + return serverListener_unix(unixSocketCfg) + } + return serverListener_tcp() +} + +func serverListener_tcp() (net.Listener, error) { + envMinPort := os.Getenv("PLUGIN_MIN_PORT") + envMaxPort := os.Getenv("PLUGIN_MAX_PORT") + + var minPort, maxPort int64 + var err error + + switch { + case len(envMinPort) == 0: + minPort = 0 + default: + minPort, err = strconv.ParseInt(envMinPort, 10, 32) + if err != nil { + return nil, fmt.Errorf("couldn't get value from PLUGIN_MIN_PORT: %v", err) + } + } + + switch { + case len(envMaxPort) == 0: + maxPort = 0 + default: + maxPort, err = strconv.ParseInt(envMaxPort, 10, 32) + if err != nil { + return nil, fmt.Errorf("couldn't get value from PLUGIN_MAX_PORT: %v", err) + } + } + + if minPort > maxPort { + return nil, fmt.Errorf("PLUGIN_MIN_PORT value of %d is greater than PLUGIN_MAX_PORT value of %d", minPort, maxPort) + } + + for port := minPort; port <= maxPort; port++ { + address := fmt.Sprintf("127.0.0.1:%d", port) + listener, err := net.Listen("tcp", address) + if err == nil { + return listener, nil + } + } + + return nil, errors.New("couldn't bind plugin TCP listener") +} + +func isSupportUnix() bool { + major, _, build := windows.RtlGetNtVersionNumbers() + return major >= 10 && build >= 17063 +}