diff options
author | Marc Pervaz Boocha <mboocha@sudomsg.com> | 2025-08-07 22:51:34 +0530 |
---|---|---|
committer | Marc Pervaz Boocha <mboocha@sudomsg.com> | 2025-08-07 22:51:34 +0530 |
commit | 1326bb4103694d7ceac23b23329997ea2207a3f6 (patch) | |
tree | 72eb0065b597121c4e54518d303f5d15de40336d | |
parent | Fixed missing signals (diff) | |
download | kit-1326bb4103694d7ceac23b23329997ea2207a3f6.tar kit-1326bb4103694d7ceac23b23329997ea2207a3f6.tar.gz kit-1326bb4103694d7ceac23b23329997ea2207a3f6.tar.bz2 kit-1326bb4103694d7ceac23b23329997ea2207a3f6.tar.lz kit-1326bb4103694d7ceac23b23329997ea2207a3f6.tar.xz kit-1326bb4103694d7ceac23b23329997ea2207a3f6.tar.zst kit-1326bb4103694d7ceac23b23329997ea2207a3f6.zip |
Added File Mode to sockets and socket activation
-rw-r--r-- | http/http.go | 85 | ||||
-rw-r--r-- | http/http_test.go | 54 | ||||
-rw-r--r-- | http/server.go | 153 | ||||
-rw-r--r-- | http/server_test.go | 141 | ||||
-rw-r--r-- | http/systemd_linux.go | 61 | ||||
-rw-r--r-- | http/systemd_stub.go | 7 |
6 files changed, 377 insertions, 124 deletions
diff --git a/http/http.go b/http/http.go new file mode 100644 index 0000000..063c1e8 --- /dev/null +++ b/http/http.go @@ -0,0 +1,85 @@ +package http + +import ( + "context" + "fmt" + "go.sudomsg.com/kit/logging" + "golang.org/x/sync/errgroup" + "log/slog" + "net" + "net/http" +) + +// RunHTTPServers runs HTTP servers concurrently on all provided listeners. +// +// The provided handler is used for all servers. +// Servers respond to context cancellation by performing a graceful shutdown with a timeout of 10 seconds. +// +// Logging is performed using the slog.Logger extracted from context. +// Each server logs startup, shutdown, and errors. +// +// This function blocks until all servers have stopped or an error occurs. +func RunHTTPServers(ctx context.Context, lns Listeners, handler http.Handler) error { + g, ctx := errgroup.WithContext(ctx) + + for _, ln := range lns { + g.Go(func() error { + logger, ctx := logging.With(ctx, "address", ln.Addr()) + + srv := &http.Server{ + Addr: ln.Addr().String(), + Handler: handler, + BaseContext: func(l net.Listener) context.Context { + return ctx + }, + ErrorLog: slog.NewLogLogger(logger.Handler(), slog.LevelError), + } + + logger.Log(ctx, slog.LevelInfo, "HTTP server serving") + + if err := httpServeContext(ctx, srv, ln); err != nil { + return fmt.Errorf("HTTP server Serve Error: %w", err) + } + return nil + }) + + } + return g.Wait() +} + +func httpServeContext(ctx context.Context, srv *http.Server, ln net.Listener) error { + logger := logging.FromContext(ctx) + + shutdownErrCh := make(chan error, 1) + + go func() { + <-ctx.Done() + shutdownErrCh <- httpServeShutdown(srv, logger) + }() + serveErr := srv.Serve(ln) + + // Always wait for shutdown result + shutdownErr := <-shutdownErrCh + + // Prioritize Serve error + if serveErr != nil && !errors.Is(serveErr, http.ErrServerClosed) { + return fmt.Errorf("http serve error: %w", serveErr) + } + if shutdownErr != nil { + return fmt.Errorf("http shutdown error: %w", shutdownErr) + } + return nil +} + +func httpServeShutdown(srv *http.Server, logger *slog.Logger) error { + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + if err := srv.Shutdown(ctx); err != nil { + logger.Log(ctx, slog.LevelWarn, "HTTP server Shutdown Error", slog.Any("error", err)) + return err + } + + logger.Log(ctx, slog.LevelInfo, "HTTP Server Shutdown Complete") + return nil +} diff --git a/http/http_test.go b/http/http_test.go new file mode 100644 index 0000000..603dce3 --- /dev/null +++ b/http/http_test.go @@ -0,0 +1,54 @@ +package http_test + +import ( + "context" + httpServer "go.sudomsg.com/kit/http" + "io" + "net/http" + "strings" + "testing" + "time" +) + +func TestRunHTTPServers(t *testing.T) { + t.Parallel() + t.Run("basic serve and shutdown", func(t *testing.T) { + t.Parallel() + ctx := t.Context() + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + io.WriteString(w, "hello") + }) + + ln := newListener(t) + addr := ln.Addr().String() + + ctx, cancel := context.WithCancel(ctx) + defer cancel() + + go func() { + // Wait for server to be ready + time.Sleep(200 * time.Millisecond) + + r, err := http.NewRequestWithContext(ctx, http.MethodGet, "http://"+addr, nil) + resp, err := http.DefaultClient.Do(r) + if err != nil { + t.Errorf("http.Do error: %v", err) + return + } + defer resp.Body.Close() + + body, _ := io.ReadAll(resp.Body) + if got := strings.TrimSpace(string(body)); got != "hello" { + t.Errorf("unexpected response body: %q", got) + } + + cancel() // shutdown the server + }() + + err := httpServer.RunHTTPServers(ctx, httpServer.Listeners{ln}, handler) + if err != nil { + t.Fatalf("RunHTTPServers failed: %v", err) + } + }) +} + diff --git a/http/server.go b/http/server.go index dc74008..b74e39b 100644 --- a/http/server.go +++ b/http/server.go @@ -12,13 +12,12 @@ import ( "context" "errors" "fmt" - "log/slog" + "io/fs" "net" - "net/http" + "os" + "strings" "sync" - "time" - "go.sudomsg.com/kit/logging" "golang.org/x/sync/errgroup" ) @@ -42,13 +41,71 @@ func (ls Listeners) CloseAll() error { return nil } +type NetType int + +const ( + NetTCP NetType = iota + NetTCP4 + NetTCP6 + NetUnix + NetUnixPacket +) + +func (n NetType) String() string { + switch n { + case NetTCP: + return "tcp" + case NetTCP4: + return "tcp4" + case NetTCP6: + return "tcp6" + case NetUnix: + return "unix" + case NetUnixPacket: + return "unixpacket" + default: + return fmt.Sprintf("NetType(%d)", n) + } +} + +func (n *NetType) Set(s string) error { + switch strings.ToLower(s) { + case "tcp": + *n = NetTCP + case "tcp4": + *n = NetTCP4 + case "tcp6": + *n = NetTCP6 + case "unix": + *n = NetUnix + case "unixpacket": + *n = NetUnixPacket + default: + return fmt.Errorf("invalid NetType %q", s) + } + return nil +} + +func (n *NetType) UnmarshalText(b []byte) error { + return n.Set(string(b)) +} + +func (n NetType) MarshalText() ([]byte, error) { + return n.AppendText(nil) +} + +func (n NetType) AppendText(dst []byte) ([]byte, error) { + return append(dst, n.String()...), nil +} + // ServerConfig defines a single network listener configuration. // // Network is the network type, e.g., "tcp". // Address is the socket address, e.g., ":8080". type ServerConfig struct { - Network string `toml:"network"` - Address string `toml:"address"` + Network NetType + Address string + Mode fs.FileMode } // OpenConfigListeners opens network listeners as specified by the provided ServerConfig slice. @@ -62,16 +119,12 @@ func OpenConfigListners(ctx context.Context, config []ServerConfig) (Listeners, lns := make(Listeners, 0, len(config)) var mu sync.Mutex g, ctx := errgroup.WithContext(ctx) - var lc net.ListenConfig + for _, cfg := range config { g.Go(func() error { - network := cfg.Network - if network == "" { - network = "tcp" - } - ln, err := lc.Listen(ctx, network, cfg.Address) + ln, err := listenConfig(ctx, cfg) if err != nil { - return fmt.Errorf("failed to listen on %s %s: %w", network, cfg.Address, err) + return err } mu.Lock() @@ -87,60 +140,32 @@ func OpenConfigListners(ctx context.Context, config []ServerConfig) (Listeners, return lns, nil } -func OpenListeners(ctx context.Context, config []ServerConfig) (Listeners, error) { - return OpenConfigListners(ctx, config) -} - -// RunHTTPServers runs HTTP servers concurrently on all provided listeners. -// -// The provided handler is used for all servers. -// Servers respond to context cancellation by performing a graceful shutdown with a timeout of 10 seconds. -// -// Logging is performed using the slog.Logger extracted from context. -// Each server logs startup, shutdown, and errors. -// -// This function blocks until all servers have stopped or an error occurs. -func RunHTTPServers(ctx context.Context, lns Listeners, handler http.Handler) error { - g, ctx := errgroup.WithContext(ctx) - - for _, ln := range lns { - g.Go(func() error { - logger, ctx := logging.With(ctx, "address", ln.Addr()) - - srv := &http.Server{ - Addr: ln.Addr().String(), - Handler: handler, - BaseContext: func(l net.Listener) context.Context { - return ctx - }, - ErrorLog: slog.NewLogLogger(logger.Handler(), slog.LevelError), - } - - logger.Log(ctx, slog.LevelInfo, "HTTP server serving") - - if err := httpServeContext(ctx, srv, ln); err != nil && !errors.Is(err, http.ErrServerClosed) { - return fmt.Errorf("HTTP server Serve Error: %w", err) - } - return nil - }) +func listenConfig(ctx context.Context, cfg ServerConfig) (net.Listener, error) { + network := cfg.Network + var lc net.ListenConfig + ln, err := lc.Listen(ctx, network.String(), cfg.Address) + if err != nil { + return nil, fmt.Errorf("failed to listen on %s %s: %w", network, cfg.Address, err) + } + if _, ok := ln.(*net.UnixListener); cfg.Mode != 0 && ok { + if err := os.Chmod(cfg.Address, cfg.Mode); err != nil { + ln.Close() + return nil, fmt.Errorf("chmod failed on %s: %w", cfg.Address, err) + } } - return g.Wait() + return ln, nil } -func httpServeContext(ctx context.Context, srv *http.Server, ln net.Listener) error { - logger := logging.FromContext(ctx) - go func() { - <-ctx.Done() - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) - defer cancel() +func OpenListeners(ctx context.Context, config []ServerConfig) (Listeners, error) { + lns, err := getSystemdListeners() + if err != nil { + return nil, fmt.Errorf("systemd socket activation failed: %w", err) + } + if lns != nil { + // Systemd is in charge; we don't honor ServerConfig here + return lns, nil + } - err := srv.Shutdown(ctx) - if err != nil { - logger.Log(ctx, slog.LevelWarn, "HTTP server Shutdown Error", slog.Any("error", err)) - } else { - logger.Log(ctx, slog.LevelInfo, "HTTP Server Shutdown Complete") - } - }() - return srv.Serve(ln) + return OpenConfigListners(ctx, config) } diff --git a/http/server_test.go b/http/server_test.go index 7138c73..0f49f01 100644 --- a/http/server_test.go +++ b/http/server_test.go @@ -1,15 +1,11 @@ package http_test import ( - "context" - "io" "net" - "net/http" "strings" "testing" - "time" - httpServer "go.sudomsg.com/kit/http" + "go.sudomsg.com/kit/http" ) // helper to find free TCP ports @@ -29,12 +25,12 @@ func TestOpenConfigListeners(t *testing.T) { t.Parallel() ctx := t.Context() - cfg := []httpServer.ServerConfig{ - {Network: "tcp", Address: "localhost:0"}, - {Network: "tcp", Address: "localhost:0"}, + cfg := []http.ServerConfig{ + {Network: http.NetTCP, Address: "localhost:0"}, + {Network: http.NetTCP, Address: "localhost:0"}, } - lns, err := httpServer.OpenConfigListners(ctx, cfg) + lns, err := http.OpenConfigListners(ctx, cfg) if err != nil { t.Fatalf("unexpected error: %v", err) } @@ -52,12 +48,12 @@ func TestOpenConfigListeners(t *testing.T) { conflict := newListener(t) defer conflict.Close() - cfg := []httpServer.ServerConfig{ - {Network: "tcp", Address: "localhost:0"}, - {Network: "tcp", Address: conflict.Addr().String()}, // will fail + cfg := []http.ServerConfig{ + {Network: http.NetTCP, Address: "localhost:0"}, + {Network: http.NetTCP, Address: conflict.Addr().String()}, // will fail } - lns, err := httpServer.OpenConfigListners(ctx, cfg) + lns, err := http.OpenConfigListners(ctx, cfg) if err == nil { defer lns.CloseAll() t.Fatal("expected error due to conflict, got nil") @@ -71,7 +67,7 @@ func TestCloseAll(t *testing.T) { ln1 := newListener(t) ln2 := newListener(t) - ls := httpServer.Listeners{ln1, ln2} + ls := http.Listeners{ln1, ln2} err := ls.CloseAll() if err != nil { t.Errorf("unexpected error from CloseAll: %v", err) @@ -85,62 +81,20 @@ func TestCloseAll(t *testing.T) { }) } -func TestRunHTTPServers(t *testing.T) { - t.Parallel() - t.Run("basic serve and shutdown", func(t *testing.T) { - t.Parallel() - ctx := t.Context() - handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - io.WriteString(w, "hello") - }) - - ln := newListener(t) - addr := ln.Addr().String() - - ctx, cancel := context.WithCancel(ctx) - defer cancel() - - go func() { - // Wait for server to be ready - time.Sleep(200 * time.Millisecond) - - r, err := http.NewRequestWithContext(ctx, http.MethodGet, "http://"+addr, nil) - resp, err := http.DefaultClient.Do(r) - if err != nil { - t.Errorf("http.Do error: %v", err) - return - } - defer resp.Body.Close() - - body, _ := io.ReadAll(resp.Body) - if got := strings.TrimSpace(string(body)); got != "hello" { - t.Errorf("unexpected response body: %q", got) - } - - cancel() // shutdown the server - }() - - err := httpServer.RunHTTPServers(ctx, httpServer.Listeners{ln}, handler) - if err != nil { - t.Fatalf("RunHTTPServers failed: %v", err) - } - }) -} - func TestInvalidConfig(t *testing.T) { t.Parallel() tests := []struct { name string - config httpServer.ServerConfig + config http.ServerConfig }{ { name: "invalid network", - config: httpServer.ServerConfig{Network: "invalid", Address: "localhost:0"}, + config: http.ServerConfig{Network: -1, Address: "localhost:0"}, }, { name: "invalid address", - config: httpServer.ServerConfig{Network: "tcp", Address: "::::"}, + config: http.ServerConfig{Network: http.NetTCP, Address: "::::"}, }, } @@ -150,10 +104,77 @@ func TestInvalidConfig(t *testing.T) { ctx := t.Context() - _, err := httpServer.OpenConfigListners(ctx, []httpServer.ServerConfig{tt.config}) + _, err := http.OpenConfigListners(ctx, []http.ServerConfig{tt.config}) if err == nil { t.Fatal("OpenConfigListners() expected error, got nil") } }) } } + +func TestNetType(t *testing.T) { + t.Parallel() + + t.Run("RoundTrip", func(t *testing.T) { + t.Parallel() + + testCases := []http.NetType{ + http.NetTCP, + http.NetTCP4, + http.NetTCP6, + http.NetUnix, + http.NetUnixPacket, + } + + for _, original := range testCases { + t.Run(original.String(), func(t *testing.T) { + t.Run("case same", func(t *testing.T) { + t.Run(original.String(), func(t *testing.T) { + text, err := original.MarshalText() + if err != nil { + t.Fatalf("MarshalText failed: %v", err) + } + + var decoded http.NetType + if err := decoded.UnmarshalText(text); err != nil { + t.Fatalf("UnmarshalText failed: %v", err) + } + + if decoded != original { + t.Fatalf("Round-trip mismatch: got %v, want %v", decoded, original) + } + }) + }) + t.Run("case not same", func(t *testing.T) { + t.Run(original.String(), func(t *testing.T) { + text, err := original.MarshalText() + if err != nil { + t.Fatalf("MarshalText failed: %v", err) + } + + var decoded http.NetType + if err := decoded.UnmarshalText([]byte(strings.ToUpper(string(text)))); err != nil { + t.Fatalf("UnmarshalText failed: %v", err) + } + + if decoded != original { + t.Fatalf("Round-trip mismatch: got %v, want %v", decoded, original) + } + }) + }) + + }) + } + }) + + t.Run("Invalid", func(t *testing.T) { + t.Parallel() + + input := "invalid" + var decoded http.NetType + + if err := decoded.UnmarshalText([]byte(input)); err == nil { + t.Errorf("expected error for input %q, got none", input) + } + }) +} diff --git a/http/systemd_linux.go b/http/systemd_linux.go new file mode 100644 index 0000000..817b421 --- /dev/null +++ b/http/systemd_linux.go @@ -0,0 +1,61 @@ +//go:build linux + +package http + +import ( + "fmt" + + "net" + "os" + "strconv" +) + +func getSystemdListeners() (Listeners, error) { + pidStr := os.Getenv("LISTEN_PID") + fdStr := os.Getenv("LISTEN_FDS") + + if pidStr == "" || fdStr == "" { + return nil, nil // Not running under systemd + } + + pid, err := strconv.Atoi(pidStr) + if err != nil { + // Not our activation — another process might have inherited the env. + return nil, nil + } + if pid != os.Getpid() { + return nil, fmt.Errorf("LISTEN_PID %d does not match current PID %d", pid, os.Getpid()) + } + + defer func() { + _ = os.Unsetenv("LISTEN_PID") + _ = os.Unsetenv("LISTEN_FDS") + }() + + nfds, err := strconv.ParseUint(fdStr, 10, 64) + if err != nil { + return nil, fmt.Errorf("invalid LISTEN_FDS: %w", err) + } + if nfds == 0 { + return nil, nil // Nothing to do + } + + lns := make(Listeners, 0, nfds) + for i := range nfds { + fd := i + 3 + file := os.NewFile(uintptr(fd), fmt.Sprintf("fd_%d", fd)) + if file == nil { + return nil, fmt.Errorf("fd %d was nil", fd) + } + ln, err := net.FileListener(file) + if err != nil { + file.Close() + return nil, fmt.Errorf("failed to create listener from fd %d: %w", fd, err) + } + lns = append(lns, ln) + } + + return lns, nil +} + + diff --git a/http/systemd_stub.go b/http/systemd_stub.go new file mode 100644 index 0000000..00699f3 --- /dev/null +++ b/http/systemd_stub.go @@ -0,0 +1,7 @@ +//go:build !linux + +package http + +func getSystemdListeners() (Listeners, error) { + return nil, nil // No-op on non-Linux +} |