aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorMarc Pervaz Boocha <mboocha@sudomsg.com>2025-08-07 22:51:34 +0530
committerMarc Pervaz Boocha <mboocha@sudomsg.com>2025-08-07 22:51:34 +0530
commit1326bb4103694d7ceac23b23329997ea2207a3f6 (patch)
tree72eb0065b597121c4e54518d303f5d15de40336d
parentFixed missing signals (diff)
downloadkit-1326bb4103694d7ceac23b23329997ea2207a3f6.tar
kit-1326bb4103694d7ceac23b23329997ea2207a3f6.tar.gz
kit-1326bb4103694d7ceac23b23329997ea2207a3f6.tar.bz2
kit-1326bb4103694d7ceac23b23329997ea2207a3f6.tar.lz
kit-1326bb4103694d7ceac23b23329997ea2207a3f6.tar.xz
kit-1326bb4103694d7ceac23b23329997ea2207a3f6.tar.zst
kit-1326bb4103694d7ceac23b23329997ea2207a3f6.zip
Added File Mode to sockets and socket activation
-rw-r--r--http/http.go85
-rw-r--r--http/http_test.go54
-rw-r--r--http/server.go153
-rw-r--r--http/server_test.go141
-rw-r--r--http/systemd_linux.go61
-rw-r--r--http/systemd_stub.go7
6 files changed, 377 insertions, 124 deletions
diff --git a/http/http.go b/http/http.go
new file mode 100644
index 0000000..063c1e8
--- /dev/null
+++ b/http/http.go
@@ -0,0 +1,85 @@
+package http
+
+import (
+ "context"
+ "fmt"
+ "go.sudomsg.com/kit/logging"
+ "golang.org/x/sync/errgroup"
+ "log/slog"
+ "net"
+ "net/http"
+)
+
+// RunHTTPServers runs HTTP servers concurrently on all provided listeners.
+//
+// The provided handler is used for all servers.
+// Servers respond to context cancellation by performing a graceful shutdown with a timeout of 10 seconds.
+//
+// Logging is performed using the slog.Logger extracted from context.
+// Each server logs startup, shutdown, and errors.
+//
+// This function blocks until all servers have stopped or an error occurs.
+func RunHTTPServers(ctx context.Context, lns Listeners, handler http.Handler) error {
+ g, ctx := errgroup.WithContext(ctx)
+
+ for _, ln := range lns {
+ g.Go(func() error {
+ logger, ctx := logging.With(ctx, "address", ln.Addr())
+
+ srv := &http.Server{
+ Addr: ln.Addr().String(),
+ Handler: handler,
+ BaseContext: func(l net.Listener) context.Context {
+ return ctx
+ },
+ ErrorLog: slog.NewLogLogger(logger.Handler(), slog.LevelError),
+ }
+
+ logger.Log(ctx, slog.LevelInfo, "HTTP server serving")
+
+ if err := httpServeContext(ctx, srv, ln); err != nil {
+ return fmt.Errorf("HTTP server Serve Error: %w", err)
+ }
+ return nil
+ })
+
+ }
+ return g.Wait()
+}
+
+func httpServeContext(ctx context.Context, srv *http.Server, ln net.Listener) error {
+ logger := logging.FromContext(ctx)
+
+ shutdownErrCh := make(chan error, 1)
+
+ go func() {
+ <-ctx.Done()
+ shutdownErrCh <- httpServeShutdown(srv, logger)
+ }()
+ serveErr := srv.Serve(ln)
+
+ // Always wait for shutdown result
+ shutdownErr := <-shutdownErrCh
+
+ // Prioritize Serve error
+ if serveErr != nil && !errors.Is(serveErr, http.ErrServerClosed) {
+ return fmt.Errorf("http serve error: %w", serveErr)
+ }
+ if shutdownErr != nil {
+ return fmt.Errorf("http shutdown error: %w", shutdownErr)
+ }
+ return nil
+}
+
+func httpServeShutdown(srv *http.Server, logger *slog.Logger) error {
+ ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
+ defer cancel()
+
+ if err := srv.Shutdown(ctx); err != nil {
+ logger.Log(ctx, slog.LevelWarn, "HTTP server Shutdown Error", slog.Any("error", err))
+ return err
+ }
+
+ logger.Log(ctx, slog.LevelInfo, "HTTP Server Shutdown Complete")
+ return nil
+}
diff --git a/http/http_test.go b/http/http_test.go
new file mode 100644
index 0000000..603dce3
--- /dev/null
+++ b/http/http_test.go
@@ -0,0 +1,54 @@
+package http_test
+
+import (
+ "context"
+ httpServer "go.sudomsg.com/kit/http"
+ "io"
+ "net/http"
+ "strings"
+ "testing"
+ "time"
+)
+
+func TestRunHTTPServers(t *testing.T) {
+ t.Parallel()
+ t.Run("basic serve and shutdown", func(t *testing.T) {
+ t.Parallel()
+ ctx := t.Context()
+ handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ io.WriteString(w, "hello")
+ })
+
+ ln := newListener(t)
+ addr := ln.Addr().String()
+
+ ctx, cancel := context.WithCancel(ctx)
+ defer cancel()
+
+ go func() {
+ // Wait for server to be ready
+ time.Sleep(200 * time.Millisecond)
+
+ r, err := http.NewRequestWithContext(ctx, http.MethodGet, "http://"+addr, nil)
+ resp, err := http.DefaultClient.Do(r)
+ if err != nil {
+ t.Errorf("http.Do error: %v", err)
+ return
+ }
+ defer resp.Body.Close()
+
+ body, _ := io.ReadAll(resp.Body)
+ if got := strings.TrimSpace(string(body)); got != "hello" {
+ t.Errorf("unexpected response body: %q", got)
+ }
+
+ cancel() // shutdown the server
+ }()
+
+ err := httpServer.RunHTTPServers(ctx, httpServer.Listeners{ln}, handler)
+ if err != nil {
+ t.Fatalf("RunHTTPServers failed: %v", err)
+ }
+ })
+}
+
diff --git a/http/server.go b/http/server.go
index dc74008..b74e39b 100644
--- a/http/server.go
+++ b/http/server.go
@@ -12,13 +12,12 @@ import (
"context"
"errors"
"fmt"
- "log/slog"
+ "io/fs"
"net"
- "net/http"
+ "os"
+ "strings"
"sync"
- "time"
- "go.sudomsg.com/kit/logging"
"golang.org/x/sync/errgroup"
)
@@ -42,13 +41,71 @@ func (ls Listeners) CloseAll() error {
return nil
}
+type NetType int
+
+const (
+ NetTCP NetType = iota
+ NetTCP4
+ NetTCP6
+ NetUnix
+ NetUnixPacket
+)
+
+func (n NetType) String() string {
+ switch n {
+ case NetTCP:
+ return "tcp"
+ case NetTCP4:
+ return "tcp4"
+ case NetTCP6:
+ return "tcp6"
+ case NetUnix:
+ return "unix"
+ case NetUnixPacket:
+ return "unixpacket"
+ default:
+ return fmt.Sprintf("NetType(%d)", n)
+ }
+}
+
+func (n *NetType) Set(s string) error {
+ switch strings.ToLower(s) {
+ case "tcp":
+ *n = NetTCP
+ case "tcp4":
+ *n = NetTCP4
+ case "tcp6":
+ *n = NetTCP6
+ case "unix":
+ *n = NetUnix
+ case "unixpacket":
+ *n = NetUnixPacket
+ default:
+ return fmt.Errorf("invalid NetType %q", s)
+ }
+ return nil
+}
+
+func (n *NetType) UnmarshalText(b []byte) error {
+ return n.Set(string(b))
+}
+
+func (n NetType) MarshalText() ([]byte, error) {
+ return n.AppendText(nil)
+}
+
+func (n NetType) AppendText(dst []byte) ([]byte, error) {
+ return append(dst, n.String()...), nil
+}
+
// ServerConfig defines a single network listener configuration.
//
// Network is the network type, e.g., "tcp".
// Address is the socket address, e.g., ":8080".
type ServerConfig struct {
- Network string `toml:"network"`
- Address string `toml:"address"`
+ Network NetType
+ Address string
+ Mode fs.FileMode
}
// OpenConfigListeners opens network listeners as specified by the provided ServerConfig slice.
@@ -62,16 +119,12 @@ func OpenConfigListners(ctx context.Context, config []ServerConfig) (Listeners,
lns := make(Listeners, 0, len(config))
var mu sync.Mutex
g, ctx := errgroup.WithContext(ctx)
- var lc net.ListenConfig
+
for _, cfg := range config {
g.Go(func() error {
- network := cfg.Network
- if network == "" {
- network = "tcp"
- }
- ln, err := lc.Listen(ctx, network, cfg.Address)
+ ln, err := listenConfig(ctx, cfg)
if err != nil {
- return fmt.Errorf("failed to listen on %s %s: %w", network, cfg.Address, err)
+ return err
}
mu.Lock()
@@ -87,60 +140,32 @@ func OpenConfigListners(ctx context.Context, config []ServerConfig) (Listeners,
return lns, nil
}
-func OpenListeners(ctx context.Context, config []ServerConfig) (Listeners, error) {
- return OpenConfigListners(ctx, config)
-}
-
-// RunHTTPServers runs HTTP servers concurrently on all provided listeners.
-//
-// The provided handler is used for all servers.
-// Servers respond to context cancellation by performing a graceful shutdown with a timeout of 10 seconds.
-//
-// Logging is performed using the slog.Logger extracted from context.
-// Each server logs startup, shutdown, and errors.
-//
-// This function blocks until all servers have stopped or an error occurs.
-func RunHTTPServers(ctx context.Context, lns Listeners, handler http.Handler) error {
- g, ctx := errgroup.WithContext(ctx)
-
- for _, ln := range lns {
- g.Go(func() error {
- logger, ctx := logging.With(ctx, "address", ln.Addr())
-
- srv := &http.Server{
- Addr: ln.Addr().String(),
- Handler: handler,
- BaseContext: func(l net.Listener) context.Context {
- return ctx
- },
- ErrorLog: slog.NewLogLogger(logger.Handler(), slog.LevelError),
- }
-
- logger.Log(ctx, slog.LevelInfo, "HTTP server serving")
-
- if err := httpServeContext(ctx, srv, ln); err != nil && !errors.Is(err, http.ErrServerClosed) {
- return fmt.Errorf("HTTP server Serve Error: %w", err)
- }
- return nil
- })
+func listenConfig(ctx context.Context, cfg ServerConfig) (net.Listener, error) {
+ network := cfg.Network
+ var lc net.ListenConfig
+ ln, err := lc.Listen(ctx, network.String(), cfg.Address)
+ if err != nil {
+ return nil, fmt.Errorf("failed to listen on %s %s: %w", network, cfg.Address, err)
+ }
+ if _, ok := ln.(*net.UnixListener); cfg.Mode != 0 && ok {
+ if err := os.Chmod(cfg.Address, cfg.Mode); err != nil {
+ ln.Close()
+ return nil, fmt.Errorf("chmod failed on %s: %w", cfg.Address, err)
+ }
}
- return g.Wait()
+ return ln, nil
}
-func httpServeContext(ctx context.Context, srv *http.Server, ln net.Listener) error {
- logger := logging.FromContext(ctx)
- go func() {
- <-ctx.Done()
- ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
- defer cancel()
+func OpenListeners(ctx context.Context, config []ServerConfig) (Listeners, error) {
+ lns, err := getSystemdListeners()
+ if err != nil {
+ return nil, fmt.Errorf("systemd socket activation failed: %w", err)
+ }
+ if lns != nil {
+ // Systemd is in charge; we don't honor ServerConfig here
+ return lns, nil
+ }
- err := srv.Shutdown(ctx)
- if err != nil {
- logger.Log(ctx, slog.LevelWarn, "HTTP server Shutdown Error", slog.Any("error", err))
- } else {
- logger.Log(ctx, slog.LevelInfo, "HTTP Server Shutdown Complete")
- }
- }()
- return srv.Serve(ln)
+ return OpenConfigListners(ctx, config)
}
diff --git a/http/server_test.go b/http/server_test.go
index 7138c73..0f49f01 100644
--- a/http/server_test.go
+++ b/http/server_test.go
@@ -1,15 +1,11 @@
package http_test
import (
- "context"
- "io"
"net"
- "net/http"
"strings"
"testing"
- "time"
- httpServer "go.sudomsg.com/kit/http"
+ "go.sudomsg.com/kit/http"
)
// helper to find free TCP ports
@@ -29,12 +25,12 @@ func TestOpenConfigListeners(t *testing.T) {
t.Parallel()
ctx := t.Context()
- cfg := []httpServer.ServerConfig{
- {Network: "tcp", Address: "localhost:0"},
- {Network: "tcp", Address: "localhost:0"},
+ cfg := []http.ServerConfig{
+ {Network: http.NetTCP, Address: "localhost:0"},
+ {Network: http.NetTCP, Address: "localhost:0"},
}
- lns, err := httpServer.OpenConfigListners(ctx, cfg)
+ lns, err := http.OpenConfigListners(ctx, cfg)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
@@ -52,12 +48,12 @@ func TestOpenConfigListeners(t *testing.T) {
conflict := newListener(t)
defer conflict.Close()
- cfg := []httpServer.ServerConfig{
- {Network: "tcp", Address: "localhost:0"},
- {Network: "tcp", Address: conflict.Addr().String()}, // will fail
+ cfg := []http.ServerConfig{
+ {Network: http.NetTCP, Address: "localhost:0"},
+ {Network: http.NetTCP, Address: conflict.Addr().String()}, // will fail
}
- lns, err := httpServer.OpenConfigListners(ctx, cfg)
+ lns, err := http.OpenConfigListners(ctx, cfg)
if err == nil {
defer lns.CloseAll()
t.Fatal("expected error due to conflict, got nil")
@@ -71,7 +67,7 @@ func TestCloseAll(t *testing.T) {
ln1 := newListener(t)
ln2 := newListener(t)
- ls := httpServer.Listeners{ln1, ln2}
+ ls := http.Listeners{ln1, ln2}
err := ls.CloseAll()
if err != nil {
t.Errorf("unexpected error from CloseAll: %v", err)
@@ -85,62 +81,20 @@ func TestCloseAll(t *testing.T) {
})
}
-func TestRunHTTPServers(t *testing.T) {
- t.Parallel()
- t.Run("basic serve and shutdown", func(t *testing.T) {
- t.Parallel()
- ctx := t.Context()
- handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- io.WriteString(w, "hello")
- })
-
- ln := newListener(t)
- addr := ln.Addr().String()
-
- ctx, cancel := context.WithCancel(ctx)
- defer cancel()
-
- go func() {
- // Wait for server to be ready
- time.Sleep(200 * time.Millisecond)
-
- r, err := http.NewRequestWithContext(ctx, http.MethodGet, "http://"+addr, nil)
- resp, err := http.DefaultClient.Do(r)
- if err != nil {
- t.Errorf("http.Do error: %v", err)
- return
- }
- defer resp.Body.Close()
-
- body, _ := io.ReadAll(resp.Body)
- if got := strings.TrimSpace(string(body)); got != "hello" {
- t.Errorf("unexpected response body: %q", got)
- }
-
- cancel() // shutdown the server
- }()
-
- err := httpServer.RunHTTPServers(ctx, httpServer.Listeners{ln}, handler)
- if err != nil {
- t.Fatalf("RunHTTPServers failed: %v", err)
- }
- })
-}
-
func TestInvalidConfig(t *testing.T) {
t.Parallel()
tests := []struct {
name string
- config httpServer.ServerConfig
+ config http.ServerConfig
}{
{
name: "invalid network",
- config: httpServer.ServerConfig{Network: "invalid", Address: "localhost:0"},
+ config: http.ServerConfig{Network: -1, Address: "localhost:0"},
},
{
name: "invalid address",
- config: httpServer.ServerConfig{Network: "tcp", Address: "::::"},
+ config: http.ServerConfig{Network: http.NetTCP, Address: "::::"},
},
}
@@ -150,10 +104,77 @@ func TestInvalidConfig(t *testing.T) {
ctx := t.Context()
- _, err := httpServer.OpenConfigListners(ctx, []httpServer.ServerConfig{tt.config})
+ _, err := http.OpenConfigListners(ctx, []http.ServerConfig{tt.config})
if err == nil {
t.Fatal("OpenConfigListners() expected error, got nil")
}
})
}
}
+
+func TestNetType(t *testing.T) {
+ t.Parallel()
+
+ t.Run("RoundTrip", func(t *testing.T) {
+ t.Parallel()
+
+ testCases := []http.NetType{
+ http.NetTCP,
+ http.NetTCP4,
+ http.NetTCP6,
+ http.NetUnix,
+ http.NetUnixPacket,
+ }
+
+ for _, original := range testCases {
+ t.Run(original.String(), func(t *testing.T) {
+ t.Run("case same", func(t *testing.T) {
+ t.Run(original.String(), func(t *testing.T) {
+ text, err := original.MarshalText()
+ if err != nil {
+ t.Fatalf("MarshalText failed: %v", err)
+ }
+
+ var decoded http.NetType
+ if err := decoded.UnmarshalText(text); err != nil {
+ t.Fatalf("UnmarshalText failed: %v", err)
+ }
+
+ if decoded != original {
+ t.Fatalf("Round-trip mismatch: got %v, want %v", decoded, original)
+ }
+ })
+ })
+ t.Run("case not same", func(t *testing.T) {
+ t.Run(original.String(), func(t *testing.T) {
+ text, err := original.MarshalText()
+ if err != nil {
+ t.Fatalf("MarshalText failed: %v", err)
+ }
+
+ var decoded http.NetType
+ if err := decoded.UnmarshalText([]byte(strings.ToUpper(string(text)))); err != nil {
+ t.Fatalf("UnmarshalText failed: %v", err)
+ }
+
+ if decoded != original {
+ t.Fatalf("Round-trip mismatch: got %v, want %v", decoded, original)
+ }
+ })
+ })
+
+ })
+ }
+ })
+
+ t.Run("Invalid", func(t *testing.T) {
+ t.Parallel()
+
+ input := "invalid"
+ var decoded http.NetType
+
+ if err := decoded.UnmarshalText([]byte(input)); err == nil {
+ t.Errorf("expected error for input %q, got none", input)
+ }
+ })
+}
diff --git a/http/systemd_linux.go b/http/systemd_linux.go
new file mode 100644
index 0000000..817b421
--- /dev/null
+++ b/http/systemd_linux.go
@@ -0,0 +1,61 @@
+//go:build linux
+
+package http
+
+import (
+ "fmt"
+
+ "net"
+ "os"
+ "strconv"
+)
+
+func getSystemdListeners() (Listeners, error) {
+ pidStr := os.Getenv("LISTEN_PID")
+ fdStr := os.Getenv("LISTEN_FDS")
+
+ if pidStr == "" || fdStr == "" {
+ return nil, nil // Not running under systemd
+ }
+
+ pid, err := strconv.Atoi(pidStr)
+ if err != nil {
+ // Not our activation — another process might have inherited the env.
+ return nil, nil
+ }
+ if pid != os.Getpid() {
+ return nil, fmt.Errorf("LISTEN_PID %d does not match current PID %d", pid, os.Getpid())
+ }
+
+ defer func() {
+ _ = os.Unsetenv("LISTEN_PID")
+ _ = os.Unsetenv("LISTEN_FDS")
+ }()
+
+ nfds, err := strconv.ParseUint(fdStr, 10, 64)
+ if err != nil {
+ return nil, fmt.Errorf("invalid LISTEN_FDS: %w", err)
+ }
+ if nfds == 0 {
+ return nil, nil // Nothing to do
+ }
+
+ lns := make(Listeners, 0, nfds)
+ for i := range nfds {
+ fd := i + 3
+ file := os.NewFile(uintptr(fd), fmt.Sprintf("fd_%d", fd))
+ if file == nil {
+ return nil, fmt.Errorf("fd %d was nil", fd)
+ }
+ ln, err := net.FileListener(file)
+ if err != nil {
+ file.Close()
+ return nil, fmt.Errorf("failed to create listener from fd %d: %w", fd, err)
+ }
+ lns = append(lns, ln)
+ }
+
+ return lns, nil
+}
+
+
diff --git a/http/systemd_stub.go b/http/systemd_stub.go
new file mode 100644
index 0000000..00699f3
--- /dev/null
+++ b/http/systemd_stub.go
@@ -0,0 +1,7 @@
+//go:build !linux
+
+package http
+
+func getSystemdListeners() (Listeners, error) {
+ return nil, nil // No-op on non-Linux
+}