From 37809b8c855250b931ec592f12fd548ddfa1dabe Mon Sep 17 00:00:00 2001 From: Marc Pervaz Boocha Date: Tue, 24 Feb 2026 23:43:18 +0530 Subject: Split net packages --- go.mod | 5 +- go.sum | 2 - http/http.go | 3 +- http/http_test.go | 18 ++++- http/server.go | 171 ----------------------------------------------- http/server_test.go | 180 -------------------------------------------------- http/systemd_linux.go | 16 ----- http/systemd_stub.go | 7 -- net/server.go | 175 ++++++++++++++++++++++++++++++++++++++++++++++++ net/server_test.go | 180 ++++++++++++++++++++++++++++++++++++++++++++++++++ net/systemd_linux.go | 58 ++++++++++++++++ net/systemd_stub.go | 7 ++ 12 files changed, 439 insertions(+), 383 deletions(-) delete mode 100644 http/server.go delete mode 100644 http/server_test.go delete mode 100644 http/systemd_linux.go delete mode 100644 http/systemd_stub.go create mode 100644 net/server.go create mode 100644 net/server_test.go create mode 100644 net/systemd_linux.go create mode 100644 net/systemd_stub.go diff --git a/go.mod b/go.mod index b775451..0cc7e2a 100644 --- a/go.mod +++ b/go.mod @@ -2,7 +2,4 @@ module go.sudomsg.com/kit go 1.26.0 -require ( - github.com/coreos/go-systemd/v22 v22.7.0 - golang.org/x/sync v0.19.0 -) +require golang.org/x/sync v0.19.0 diff --git a/go.sum b/go.sum index 30ce3a9..159532a 100644 --- a/go.sum +++ b/go.sum @@ -1,4 +1,2 @@ -github.com/coreos/go-systemd/v22 v22.7.0 h1:LAEzFkke61DFROc7zNLX/WA2i5J8gYqe0rSj9KI28KA= -github.com/coreos/go-systemd/v22 v22.7.0/go.mod h1:xNUYtjHu2EDXbsxz1i41wouACIwT7Ybq9o0BQhMwD0w= golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4= golang.org/x/sync v0.19.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI= diff --git a/http/http.go b/http/http.go index 71075a7..42d8f4f 100644 --- a/http/http.go +++ b/http/http.go @@ -10,6 +10,7 @@ import ( "time" "go.sudomsg.com/kit/logging" + netkit "go.sudomsg.com/kit/net" "golang.org/x/sync/errgroup" ) @@ -22,7 +23,7 @@ import ( // Each server logs startup, shutdown, and errors. // // This function blocks until all servers have stopped or an error occurs. -func RunHTTPServers(ctx context.Context, lns Listeners, handler http.Handler) error { +func RunHTTPServers(ctx context.Context, lns netkit.Listeners, handler http.Handler) error { g, ctx := errgroup.WithContext(ctx) for _, ln := range lns { diff --git a/http/http_test.go b/http/http_test.go index d7eb3c4..dce9092 100644 --- a/http/http_test.go +++ b/http/http_test.go @@ -2,14 +2,28 @@ package http_test import ( "context" - httpServer "go.sudomsg.com/kit/http" "io" + "net" "net/http" "strings" "testing" "time" + + httpServer "go.sudomsg.com/kit/http" + netkit "go.sudomsg.com/kit/net" ) +func newListener(tb testing.TB) net.Listener { + tb.Helper() + + ln, err := net.Listen("tcp", "localhost:0") + if err != nil { + tb.Fatalf("failed to open listener: %v", err) + } + return ln +} + + func TestRunHTTPServers(t *testing.T) { t.Parallel() t.Run("basic serve and shutdown", func(t *testing.T) { @@ -45,7 +59,7 @@ func TestRunHTTPServers(t *testing.T) { cancel() // shutdown the server }() - err := httpServer.RunHTTPServers(ctx, httpServer.Listeners{ln}, handler) + err := httpServer.RunHTTPServers(ctx, netkit.Listeners{ln}, handler) if err != nil { t.Fatalf("RunHTTPServers failed: %v", err) } diff --git a/http/server.go b/http/server.go deleted file mode 100644 index b74e39b..0000000 --- a/http/server.go +++ /dev/null @@ -1,171 +0,0 @@ -// Package http provides utilities for managing HTTP servers and listeners. -// -// It supports opening multiple network listeners from configuration, -// running multiple HTTP servers concurrently with graceful shutdown, -// and integrates structured logging with context. -// -// This package is designed to be used with context cancellation for concurrency control. - -package http - -import ( - "context" - "errors" - "fmt" - "io/fs" - "net" - "os" - "strings" - "sync" - - "golang.org/x/sync/errgroup" -) - -// Listeners is a slice of net.Listener interfaces representing multiple network listeners. -type Listeners []net.Listener - -// CloseAll closes all listeners in the slice. -// It aggregates all errors returned by individual Close calls using errors.Join. -// -// If no errors occur, it returns nil. -func (ls Listeners) CloseAll() error { - var errs []error - for _, l := range ls { - if err := l.Close(); err != nil { - errs = append(errs, err) - } - } - if len(errs) > 0 { - return errors.Join(errs...) - } - return nil -} - -type NetType int - -const ( - NetTCP NetType = iota - NetTCP4 - NetTCP6 - NetUnix - NetUnixPacket -) - -func (n NetType) String() string { - switch n { - case NetTCP: - return "tcp" - case NetTCP4: - return "tcp4" - case NetTCP6: - return "tcp6" - case NetUnix: - return "unix" - case NetUnixPacket: - return "unixpacket" - default: - return fmt.Sprintf("NetType(%d)", n) - } -} - -func (n *NetType) Set(s string) error { - switch strings.ToLower(s) { - case "tcp": - *n = NetTCP - case "tcp4": - *n = NetTCP4 - case "tcp6": - *n = NetTCP6 - case "unix": - *n = NetUnix - case "unixpacket": - *n = NetUnixPacket - default: - return fmt.Errorf("invalid NetType %q", s) - } - return nil -} - -func (n *NetType) UnmarshalText(b []byte) error { - return n.Set(string(b)) -} - -func (n NetType) MarshalText() ([]byte, error) { - return n.AppendText(nil) -} - -func (n NetType) AppendText(dst []byte) ([]byte, error) { - return append(dst, n.String()...), nil -} - -// ServerConfig defines a single network listener configuration. -// -// Network is the network type, e.g., "tcp". -// Address is the socket address, e.g., ":8080". -type ServerConfig struct { - Network NetType - Address string - Mode fs.FileMode -} - -// OpenConfigListeners opens network listeners as specified by the provided ServerConfig slice. -// -// It attempts to open all listeners concurrently and returns them if all succeed. -// -// If any listener fails to open, it closes all previously opened listeners and returns an error. -// -// The context controls cancellation of the opening process. -func OpenConfigListners(ctx context.Context, config []ServerConfig) (Listeners, error) { - lns := make(Listeners, 0, len(config)) - var mu sync.Mutex - g, ctx := errgroup.WithContext(ctx) - - for _, cfg := range config { - g.Go(func() error { - ln, err := listenConfig(ctx, cfg) - if err != nil { - return err - } - - mu.Lock() - lns = append(lns, ln) - mu.Unlock() - return nil - }) - } - if err := g.Wait(); err != nil { - _ = lns.CloseAll() - return nil, err - } - return lns, nil -} - -func listenConfig(ctx context.Context, cfg ServerConfig) (net.Listener, error) { - network := cfg.Network - - var lc net.ListenConfig - ln, err := lc.Listen(ctx, network.String(), cfg.Address) - if err != nil { - return nil, fmt.Errorf("failed to listen on %s %s: %w", network, cfg.Address, err) - } - if _, ok := ln.(*net.UnixListener); cfg.Mode != 0 && ok { - if err := os.Chmod(cfg.Address, cfg.Mode); err != nil { - ln.Close() - return nil, fmt.Errorf("chmod failed on %s: %w", cfg.Address, err) - } - } - return ln, nil -} - -func OpenListeners(ctx context.Context, config []ServerConfig) (Listeners, error) { - lns, err := getSystemdListeners() - if err != nil { - return nil, fmt.Errorf("systemd socket activation failed: %w", err) - } - if lns != nil { - // Systemd is in charge; we don't honor ServerConfig here - return lns, nil - } - - return OpenConfigListners(ctx, config) -} diff --git a/http/server_test.go b/http/server_test.go deleted file mode 100644 index 0f49f01..0000000 --- a/http/server_test.go +++ /dev/null @@ -1,180 +0,0 @@ -package http_test - -import ( - "net" - "strings" - "testing" - - "go.sudomsg.com/kit/http" -) - -// helper to find free TCP ports -func newListener(tb testing.TB) net.Listener { - tb.Helper() - - ln, err := net.Listen("tcp", "localhost:0") - if err != nil { - tb.Fatalf("failed to open listener: %v", err) - } - return ln -} - -func TestOpenConfigListeners(t *testing.T) { - t.Parallel() - t.Run("successful config", func(t *testing.T) { - t.Parallel() - ctx := t.Context() - - cfg := []http.ServerConfig{ - {Network: http.NetTCP, Address: "localhost:0"}, - {Network: http.NetTCP, Address: "localhost:0"}, - } - - lns, err := http.OpenConfigListners(ctx, cfg) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - defer lns.CloseAll() - - if got := len(lns); got != len(cfg) { - t.Errorf("expected %d listeners, got %d", len(cfg), got) - } - }) - - t.Run("port conflict triggers cleanup", func(t *testing.T) { - t.Parallel() - ctx := t.Context() - - conflict := newListener(t) - defer conflict.Close() - - cfg := []http.ServerConfig{ - {Network: http.NetTCP, Address: "localhost:0"}, - {Network: http.NetTCP, Address: conflict.Addr().String()}, // will fail - } - - lns, err := http.OpenConfigListners(ctx, cfg) - if err == nil { - defer lns.CloseAll() - t.Fatal("expected error due to conflict, got nil") - } - }) -} - -func TestCloseAll(t *testing.T) { - t.Run("closes all listeners", func(t *testing.T) { - t.Parallel() - ln1 := newListener(t) - ln2 := newListener(t) - - ls := http.Listeners{ln1, ln2} - err := ls.CloseAll() - if err != nil { - t.Errorf("unexpected error from CloseAll: %v", err) - } - - for _, ln := range ls { - if _, err := ln.Accept(); err == nil { - t.Error("expected listener to be closed, but Accept succeeded") - } - } - }) -} - -func TestInvalidConfig(t *testing.T) { - t.Parallel() - - tests := []struct { - name string - config http.ServerConfig - }{ - { - name: "invalid network", - config: http.ServerConfig{Network: -1, Address: "localhost:0"}, - }, - { - name: "invalid address", - config: http.ServerConfig{Network: http.NetTCP, Address: "::::"}, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - t.Parallel() - - ctx := t.Context() - - _, err := http.OpenConfigListners(ctx, []http.ServerConfig{tt.config}) - if err == nil { - t.Fatal("OpenConfigListners() expected error, got nil") - } - }) - } -} - -func TestNetType(t *testing.T) { - t.Parallel() - - t.Run("RoundTrip", func(t *testing.T) { - t.Parallel() - - testCases := []http.NetType{ - http.NetTCP, - http.NetTCP4, - http.NetTCP6, - http.NetUnix, - http.NetUnixPacket, - } - - for _, original := range testCases { - t.Run(original.String(), func(t *testing.T) { - t.Run("case same", func(t *testing.T) { - t.Run(original.String(), func(t *testing.T) { - text, err := original.MarshalText() - if err != nil { - t.Fatalf("MarshalText failed: %v", err) - } - - var decoded http.NetType - if err := decoded.UnmarshalText(text); err != nil { - t.Fatalf("UnmarshalText failed: %v", err) - } - - if decoded != original { - t.Fatalf("Round-trip mismatch: got %v, want %v", decoded, original) - } - }) - }) - t.Run("case not same", func(t *testing.T) { - t.Run(original.String(), func(t *testing.T) { - text, err := original.MarshalText() - if err != nil { - t.Fatalf("MarshalText failed: %v", err) - } - - var decoded http.NetType - if err := decoded.UnmarshalText([]byte(strings.ToUpper(string(text)))); err != nil { - t.Fatalf("UnmarshalText failed: %v", err) - } - - if decoded != original { - t.Fatalf("Round-trip mismatch: got %v, want %v", decoded, original) - } - }) - }) - - }) - } - }) - - t.Run("Invalid", func(t *testing.T) { - t.Parallel() - - input := "invalid" - var decoded http.NetType - - if err := decoded.UnmarshalText([]byte(input)); err == nil { - t.Errorf("expected error for input %q, got none", input) - } - }) -} diff --git a/http/systemd_linux.go b/http/systemd_linux.go deleted file mode 100644 index bf1e6a1..0000000 --- a/http/systemd_linux.go +++ /dev/null @@ -1,16 +0,0 @@ -//go:build linux - -package http - -import ( - "github.com/coreos/go-systemd/v22/activation" -) - -func getSystemdListeners() (Listeners, error) { - lns, err := activation.Listeners() - if err != nil { - return nil, err - } - - return lns, nil -} diff --git a/http/systemd_stub.go b/http/systemd_stub.go deleted file mode 100644 index 00699f3..0000000 --- a/http/systemd_stub.go +++ /dev/null @@ -1,7 +0,0 @@ -//go:build !linux - -package http - -func getSystemdListeners() (Listeners, error) { - return nil, nil // No-op on non-Linux -} diff --git a/net/server.go b/net/server.go new file mode 100644 index 0000000..22448b3 --- /dev/null +++ b/net/server.go @@ -0,0 +1,175 @@ +// Package http provides utilities for managing HTTP servers and listeners. +// +// It supports opening multiple network listeners from configuration, +// running multiple HTTP servers concurrently with graceful shutdown, +// and integrates structured logging with context. +// +// This package is designed to be used with context cancellation for concurrency control. + +package net + +import ( + "context" + "errors" + "fmt" + "io/fs" + "net" + "os" + "strings" + "sync" + + "golang.org/x/sync/errgroup" +) + +// Listeners is a slice of net.Listener interfaces representing multiple network listeners. +type Listeners []net.Listener + +// CloseAll closes all listeners in the slice. +// It aggregates all errors returned by individual Close calls using errors.Join. +// +// If no errors occur, it returns nil. +func (ls Listeners) CloseAll() error { + var errs []error + for _, l := range ls { + if err := l.Close(); err != nil { + errs = append(errs, err) + } + } + if len(errs) > 0 { + return errors.Join(errs...) + } + return nil +} + +type NetType int + +const ( + NetTCP NetType = iota + NetTCP4 + NetTCP6 + NetUDP + NetUDP4 + NetUDP6 + NetUnix + NetUnixDatagram + NetUnixPacket +) + +func (n NetType) String() string { + switch n { + case NetTCP: + return "tcp" + case NetTCP4: + return "tcp4" + case NetTCP6: + return "tcp6" + case NetUnix: + return "unix" + case NetUnixPacket: + return "unixpacket" + default: + return fmt.Sprintf("NetType(%d)", n) + } +} + +func (n *NetType) Set(s string) error { + switch strings.ToLower(s) { + case "tcp": + *n = NetTCP + case "tcp4": + *n = NetTCP4 + case "tcp6": + *n = NetTCP6 + case "unix": + *n = NetUnix + case "unixpacket": + *n = NetUnixPacket + default: + return fmt.Errorf("invalid NetType %q", s) + } + return nil +} + +func (n *NetType) UnmarshalText(b []byte) error { + return n.Set(string(b)) +} + +func (n NetType) MarshalText() ([]byte, error) { + return n.AppendText(nil) +} + +func (n NetType) AppendText(dst []byte) ([]byte, error) { + return append(dst, n.String()...), nil +} + +// ServerConfig defines a single network listener configuration. +// +// Network is the network type, e.g., "tcp". +// Address is the socket address, e.g., ":8080". +type ServerConfig struct { + Network NetType + Address string + Mode fs.FileMode +} + +// OpenConfigListeners opens network listeners as specified by the provided ServerConfig slice. +// +// It attempts to open all listeners concurrently and returns them if all succeed. +// +// If any listener fails to open, it closes all previously opened listeners and returns an error. +// +// The context controls cancellation of the opening process. +func OpenConfigListners(ctx context.Context, config []ServerConfig) (Listeners, error) { + lns := make(Listeners, 0, len(config)) + var mu sync.Mutex + g, ctx := errgroup.WithContext(ctx) + + for _, cfg := range config { + g.Go(func() error { + ln, err := listenConfig(ctx, cfg) + if err != nil { + return err + } + + mu.Lock() + lns = append(lns, ln) + mu.Unlock() + return nil + }) + } + if err := g.Wait(); err != nil { + _ = lns.CloseAll() + return nil, err + } + return lns, nil +} + +func listenConfig(ctx context.Context, cfg ServerConfig) (net.Listener, error) { + network := cfg.Network + + var lc net.ListenConfig + ln, err := lc.Listen(ctx, network.String(), cfg.Address) + if err != nil { + return nil, fmt.Errorf("failed to listen on %s %s: %w", network, cfg.Address, err) + } + if _, ok := ln.(*net.UnixListener); cfg.Mode != 0 && ok { + if err := os.Chmod(cfg.Address, cfg.Mode); err != nil { + ln.Close() + return nil, fmt.Errorf("chmod failed on %s: %w", cfg.Address, err) + } + } + return ln, nil +} + +func OpenListeners(ctx context.Context, config []ServerConfig) (Listeners, error) { + lns, err := getSystemdListeners() + if err != nil { + return nil, fmt.Errorf("systemd socket activation failed: %w", err) + } + if lns != nil { + // Systemd is in charge; we don't honor ServerConfig here + return lns, nil + } + + return OpenConfigListners(ctx, config) +} diff --git a/net/server_test.go b/net/server_test.go new file mode 100644 index 0000000..ca4836c --- /dev/null +++ b/net/server_test.go @@ -0,0 +1,180 @@ +package net_test + +import ( + "net" + "strings" + "testing" + + netkit "go.sudomsg.com/kit/net" +) + +// helper to find free TCP ports +func newListener(tb testing.TB) net.Listener { + tb.Helper() + + ln, err := net.Listen("tcp", "localhost:0") + if err != nil { + tb.Fatalf("failed to open listener: %v", err) + } + return ln +} + +func TestOpenConfigListeners(t *testing.T) { + t.Parallel() + t.Run("successful config", func(t *testing.T) { + t.Parallel() + ctx := t.Context() + + cfg := []netkit.ServerConfig{ + {Network: netkit.NetTCP, Address: "localhost:0"}, + {Network: netkit.NetTCP, Address: "localhost:0"}, + } + + lns, err := netkit.OpenConfigListners(ctx, cfg) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + defer lns.CloseAll() + + if got := len(lns); got != len(cfg) { + t.Errorf("expected %d listeners, got %d", len(cfg), got) + } + }) + + t.Run("port conflict triggers cleanup", func(t *testing.T) { + t.Parallel() + ctx := t.Context() + + conflict := newListener(t) + defer conflict.Close() + + cfg := []netkit.ServerConfig{ + {Network: netkit.NetTCP, Address: "localhost:0"}, + {Network: netkit.NetTCP, Address: conflict.Addr().String()}, // will fail + } + + lns, err := netkit.OpenConfigListners(ctx, cfg) + if err == nil { + defer lns.CloseAll() + t.Fatal("expected error due to conflict, got nil") + } + }) +} + +func TestCloseAll(t *testing.T) { + t.Run("closes all listeners", func(t *testing.T) { + t.Parallel() + ln1 := newListener(t) + ln2 := newListener(t) + + ls := netkit.Listeners{ln1, ln2} + err := ls.CloseAll() + if err != nil { + t.Errorf("unexpected error from CloseAll: %v", err) + } + + for _, ln := range ls { + if _, err := ln.Accept(); err == nil { + t.Error("expected listener to be closed, but Accept succeeded") + } + } + }) +} + +func TestInvalidConfig(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + config netkit.ServerConfig + }{ + { + name: "invalid network", + config: netkit.ServerConfig{Network: -1, Address: "localhost:0"}, + }, + { + name: "invalid address", + config: netkit.ServerConfig{Network: netkit.NetTCP, Address: "::::"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + ctx := t.Context() + + _, err := netkit.OpenConfigListners(ctx, []netkit.ServerConfig{tt.config}) + if err == nil { + t.Fatal("OpenConfigListners() expected error, got nil") + } + }) + } +} + +func TestNetType(t *testing.T) { + t.Parallel() + + t.Run("RoundTrip", func(t *testing.T) { + t.Parallel() + + testCases := []netkit.NetType{ + netkit.NetTCP, + netkit.NetTCP4, + netkit.NetTCP6, + netkit.NetUnix, + netkit.NetUnixPacket, + } + + for _, original := range testCases { + t.Run(original.String(), func(t *testing.T) { + t.Run("case same", func(t *testing.T) { + t.Run(original.String(), func(t *testing.T) { + text, err := original.MarshalText() + if err != nil { + t.Fatalf("MarshalText failed: %v", err) + } + + var decoded netkit.NetType + if err := decoded.UnmarshalText(text); err != nil { + t.Fatalf("UnmarshalText failed: %v", err) + } + + if decoded != original { + t.Fatalf("Round-trip mismatch: got %v, want %v", decoded, original) + } + }) + }) + t.Run("case not same", func(t *testing.T) { + t.Run(original.String(), func(t *testing.T) { + text, err := original.MarshalText() + if err != nil { + t.Fatalf("MarshalText failed: %v", err) + } + + var decoded netkit.NetType + if err := decoded.UnmarshalText([]byte(strings.ToUpper(string(text)))); err != nil { + t.Fatalf("UnmarshalText failed: %v", err) + } + + if decoded != original { + t.Fatalf("Round-trip mismatch: got %v, want %v", decoded, original) + } + }) + }) + + }) + } + }) + + t.Run("Invalid", func(t *testing.T) { + t.Parallel() + + input := "invalid" + var decoded netkit.NetType + + if err := decoded.UnmarshalText([]byte(input)); err == nil { + t.Errorf("expected error for input %q, got none", input) + } + }) +} diff --git a/net/systemd_linux.go b/net/systemd_linux.go new file mode 100644 index 0000000..efccd27 --- /dev/null +++ b/net/systemd_linux.go @@ -0,0 +1,58 @@ +//go:build linux + +package net + +import ( + "fmt" + "net" + "os" + "strconv" +) + +func getSystemdListeners() (Listeners, error) { + pidStr := os.Getenv("LISTEN_PID") + fdStr := os.Getenv("LISTEN_FDS") + + if pidStr == "" || fdStr == "" { + return nil, nil // Not running under systemd + } + + pid, err := strconv.Atoi(pidStr) + if err != nil { + // Not our activation — another process might have inherited the env. + return nil, nil + } + if pid != os.Getpid() { + return nil, fmt.Errorf("LISTEN_PID %d does not match current PID %d", pid, os.Getpid()) + } + + defer func() { + _ = os.Unsetenv("LISTEN_PID") + _ = os.Unsetenv("LISTEN_FDS") + }() + + nfds, err := strconv.ParseUint(fdStr, 10, 64) + if err != nil { + return nil, fmt.Errorf("invalid LISTEN_FDS: %w", err) + } + if nfds == 0 { + return nil, nil // Nothing to do + } + + lns := make(Listeners, 0, nfds) + for i := range nfds { + fd := i + 3 + file := os.NewFile(uintptr(fd), fmt.Sprintf("fd_%d", fd)) + if file == nil { + return nil, fmt.Errorf("fd %d was nil", fd) + } + ln, err := net.FileListener(file) + if err != nil { + file.Close() + return nil, fmt.Errorf("failed to create listener from fd %d: %w", fd, err) + } + lns = append(lns, ln) + } + + return lns, nil +} diff --git a/net/systemd_stub.go b/net/systemd_stub.go new file mode 100644 index 0000000..502cf4b --- /dev/null +++ b/net/systemd_stub.go @@ -0,0 +1,7 @@ +//go:build !linux + +package net + +func getSystemdListeners() (Listeners, error) { + return nil, nil // No-op on non-Linux +} -- cgit v1.3