aboutsummaryrefslogtreecommitdiffstats
path: root/net
diff options
context:
space:
mode:
authorMarc Pervaz Boocha <mboocha@sudomsg.com>2026-02-24 23:43:18 +0530
committerMarc Pervaz Boocha <mboocha@sudomsg.com>2026-02-24 23:43:18 +0530
commit37809b8c855250b931ec592f12fd548ddfa1dabe (patch)
treee9fbb747cc8a29fc031dba3de77919c94edebc4a /net
parentAdded more logging stuff (diff)
downloadkit-37809b8c855250b931ec592f12fd548ddfa1dabe.tar
kit-37809b8c855250b931ec592f12fd548ddfa1dabe.tar.gz
kit-37809b8c855250b931ec592f12fd548ddfa1dabe.tar.bz2
kit-37809b8c855250b931ec592f12fd548ddfa1dabe.tar.lz
kit-37809b8c855250b931ec592f12fd548ddfa1dabe.tar.xz
kit-37809b8c855250b931ec592f12fd548ddfa1dabe.tar.zst
kit-37809b8c855250b931ec592f12fd548ddfa1dabe.zip
Split net packages
Diffstat (limited to 'net')
-rw-r--r--net/server.go175
-rw-r--r--net/server_test.go180
-rw-r--r--net/systemd_linux.go58
-rw-r--r--net/systemd_stub.go7
4 files changed, 420 insertions, 0 deletions
diff --git a/net/server.go b/net/server.go
new file mode 100644
index 0000000..22448b3
--- /dev/null
+++ b/net/server.go
@@ -0,0 +1,175 @@
+// Package http provides utilities for managing HTTP servers and listeners.
+//
+// It supports opening multiple network listeners from configuration,
+// running multiple HTTP servers concurrently with graceful shutdown,
+// and integrates structured logging with context.
+//
+// This package is designed to be used with context cancellation for concurrency control.
+
+package net
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "io/fs"
+ "net"
+ "os"
+ "strings"
+ "sync"
+
+ "golang.org/x/sync/errgroup"
+)
+
+// Listeners is a slice of net.Listener interfaces representing multiple network listeners.
+type Listeners []net.Listener
+
+// CloseAll closes all listeners in the slice.
+// It aggregates all errors returned by individual Close calls using errors.Join.
+//
+// If no errors occur, it returns nil.
+func (ls Listeners) CloseAll() error {
+ var errs []error
+ for _, l := range ls {
+ if err := l.Close(); err != nil {
+ errs = append(errs, err)
+ }
+ }
+ if len(errs) > 0 {
+ return errors.Join(errs...)
+ }
+ return nil
+}
+
+type NetType int
+
+const (
+ NetTCP NetType = iota
+ NetTCP4
+ NetTCP6
+ NetUDP
+ NetUDP4
+ NetUDP6
+ NetUnix
+ NetUnixDatagram
+ NetUnixPacket
+)
+
+func (n NetType) String() string {
+ switch n {
+ case NetTCP:
+ return "tcp"
+ case NetTCP4:
+ return "tcp4"
+ case NetTCP6:
+ return "tcp6"
+ case NetUnix:
+ return "unix"
+ case NetUnixPacket:
+ return "unixpacket"
+ default:
+ return fmt.Sprintf("NetType(%d)", n)
+ }
+}
+
+func (n *NetType) Set(s string) error {
+ switch strings.ToLower(s) {
+ case "tcp":
+ *n = NetTCP
+ case "tcp4":
+ *n = NetTCP4
+ case "tcp6":
+ *n = NetTCP6
+ case "unix":
+ *n = NetUnix
+ case "unixpacket":
+ *n = NetUnixPacket
+ default:
+ return fmt.Errorf("invalid NetType %q", s)
+ }
+ return nil
+}
+
+func (n *NetType) UnmarshalText(b []byte) error {
+ return n.Set(string(b))
+}
+
+func (n NetType) MarshalText() ([]byte, error) {
+ return n.AppendText(nil)
+}
+
+func (n NetType) AppendText(dst []byte) ([]byte, error) {
+ return append(dst, n.String()...), nil
+}
+
+// ServerConfig defines a single network listener configuration.
+//
+// Network is the network type, e.g., "tcp".
+// Address is the socket address, e.g., ":8080".
+type ServerConfig struct {
+ Network NetType
+ Address string
+ Mode fs.FileMode
+}
+
+// OpenConfigListeners opens network listeners as specified by the provided ServerConfig slice.
+//
+// It attempts to open all listeners concurrently and returns them if all succeed.
+//
+// If any listener fails to open, it closes all previously opened listeners and returns an error.
+//
+// The context controls cancellation of the opening process.
+func OpenConfigListners(ctx context.Context, config []ServerConfig) (Listeners, error) {
+ lns := make(Listeners, 0, len(config))
+ var mu sync.Mutex
+ g, ctx := errgroup.WithContext(ctx)
+
+ for _, cfg := range config {
+ g.Go(func() error {
+ ln, err := listenConfig(ctx, cfg)
+ if err != nil {
+ return err
+ }
+
+ mu.Lock()
+ lns = append(lns, ln)
+ mu.Unlock()
+ return nil
+ })
+ }
+ if err := g.Wait(); err != nil {
+ _ = lns.CloseAll()
+ return nil, err
+ }
+ return lns, nil
+}
+
+func listenConfig(ctx context.Context, cfg ServerConfig) (net.Listener, error) {
+ network := cfg.Network
+
+ var lc net.ListenConfig
+ ln, err := lc.Listen(ctx, network.String(), cfg.Address)
+ if err != nil {
+ return nil, fmt.Errorf("failed to listen on %s %s: %w", network, cfg.Address, err)
+ }
+ if _, ok := ln.(*net.UnixListener); cfg.Mode != 0 && ok {
+ if err := os.Chmod(cfg.Address, cfg.Mode); err != nil {
+ ln.Close()
+ return nil, fmt.Errorf("chmod failed on %s: %w", cfg.Address, err)
+ }
+ }
+ return ln, nil
+}
+
+func OpenListeners(ctx context.Context, config []ServerConfig) (Listeners, error) {
+ lns, err := getSystemdListeners()
+ if err != nil {
+ return nil, fmt.Errorf("systemd socket activation failed: %w", err)
+ }
+ if lns != nil {
+ // Systemd is in charge; we don't honor ServerConfig here
+ return lns, nil
+ }
+
+ return OpenConfigListners(ctx, config)
+}
diff --git a/net/server_test.go b/net/server_test.go
new file mode 100644
index 0000000..ca4836c
--- /dev/null
+++ b/net/server_test.go
@@ -0,0 +1,180 @@
+package net_test
+
+import (
+ "net"
+ "strings"
+ "testing"
+
+ netkit "go.sudomsg.com/kit/net"
+)
+
+// helper to find free TCP ports
+func newListener(tb testing.TB) net.Listener {
+ tb.Helper()
+
+ ln, err := net.Listen("tcp", "localhost:0")
+ if err != nil {
+ tb.Fatalf("failed to open listener: %v", err)
+ }
+ return ln
+}
+
+func TestOpenConfigListeners(t *testing.T) {
+ t.Parallel()
+ t.Run("successful config", func(t *testing.T) {
+ t.Parallel()
+ ctx := t.Context()
+
+ cfg := []netkit.ServerConfig{
+ {Network: netkit.NetTCP, Address: "localhost:0"},
+ {Network: netkit.NetTCP, Address: "localhost:0"},
+ }
+
+ lns, err := netkit.OpenConfigListners(ctx, cfg)
+ if err != nil {
+ t.Fatalf("unexpected error: %v", err)
+ }
+ defer lns.CloseAll()
+
+ if got := len(lns); got != len(cfg) {
+ t.Errorf("expected %d listeners, got %d", len(cfg), got)
+ }
+ })
+
+ t.Run("port conflict triggers cleanup", func(t *testing.T) {
+ t.Parallel()
+ ctx := t.Context()
+
+ conflict := newListener(t)
+ defer conflict.Close()
+
+ cfg := []netkit.ServerConfig{
+ {Network: netkit.NetTCP, Address: "localhost:0"},
+ {Network: netkit.NetTCP, Address: conflict.Addr().String()}, // will fail
+ }
+
+ lns, err := netkit.OpenConfigListners(ctx, cfg)
+ if err == nil {
+ defer lns.CloseAll()
+ t.Fatal("expected error due to conflict, got nil")
+ }
+ })
+}
+
+func TestCloseAll(t *testing.T) {
+ t.Run("closes all listeners", func(t *testing.T) {
+ t.Parallel()
+ ln1 := newListener(t)
+ ln2 := newListener(t)
+
+ ls := netkit.Listeners{ln1, ln2}
+ err := ls.CloseAll()
+ if err != nil {
+ t.Errorf("unexpected error from CloseAll: %v", err)
+ }
+
+ for _, ln := range ls {
+ if _, err := ln.Accept(); err == nil {
+ t.Error("expected listener to be closed, but Accept succeeded")
+ }
+ }
+ })
+}
+
+func TestInvalidConfig(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ name string
+ config netkit.ServerConfig
+ }{
+ {
+ name: "invalid network",
+ config: netkit.ServerConfig{Network: -1, Address: "localhost:0"},
+ },
+ {
+ name: "invalid address",
+ config: netkit.ServerConfig{Network: netkit.NetTCP, Address: "::::"},
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+
+ ctx := t.Context()
+
+ _, err := netkit.OpenConfigListners(ctx, []netkit.ServerConfig{tt.config})
+ if err == nil {
+ t.Fatal("OpenConfigListners() expected error, got nil")
+ }
+ })
+ }
+}
+
+func TestNetType(t *testing.T) {
+ t.Parallel()
+
+ t.Run("RoundTrip", func(t *testing.T) {
+ t.Parallel()
+
+ testCases := []netkit.NetType{
+ netkit.NetTCP,
+ netkit.NetTCP4,
+ netkit.NetTCP6,
+ netkit.NetUnix,
+ netkit.NetUnixPacket,
+ }
+
+ for _, original := range testCases {
+ t.Run(original.String(), func(t *testing.T) {
+ t.Run("case same", func(t *testing.T) {
+ t.Run(original.String(), func(t *testing.T) {
+ text, err := original.MarshalText()
+ if err != nil {
+ t.Fatalf("MarshalText failed: %v", err)
+ }
+
+ var decoded netkit.NetType
+ if err := decoded.UnmarshalText(text); err != nil {
+ t.Fatalf("UnmarshalText failed: %v", err)
+ }
+
+ if decoded != original {
+ t.Fatalf("Round-trip mismatch: got %v, want %v", decoded, original)
+ }
+ })
+ })
+ t.Run("case not same", func(t *testing.T) {
+ t.Run(original.String(), func(t *testing.T) {
+ text, err := original.MarshalText()
+ if err != nil {
+ t.Fatalf("MarshalText failed: %v", err)
+ }
+
+ var decoded netkit.NetType
+ if err := decoded.UnmarshalText([]byte(strings.ToUpper(string(text)))); err != nil {
+ t.Fatalf("UnmarshalText failed: %v", err)
+ }
+
+ if decoded != original {
+ t.Fatalf("Round-trip mismatch: got %v, want %v", decoded, original)
+ }
+ })
+ })
+
+ })
+ }
+ })
+
+ t.Run("Invalid", func(t *testing.T) {
+ t.Parallel()
+
+ input := "invalid"
+ var decoded netkit.NetType
+
+ if err := decoded.UnmarshalText([]byte(input)); err == nil {
+ t.Errorf("expected error for input %q, got none", input)
+ }
+ })
+}
diff --git a/net/systemd_linux.go b/net/systemd_linux.go
new file mode 100644
index 0000000..efccd27
--- /dev/null
+++ b/net/systemd_linux.go
@@ -0,0 +1,58 @@
+//go:build linux
+
+package net
+
+import (
+ "fmt"
+ "net"
+ "os"
+ "strconv"
+)
+
+func getSystemdListeners() (Listeners, error) {
+ pidStr := os.Getenv("LISTEN_PID")
+ fdStr := os.Getenv("LISTEN_FDS")
+
+ if pidStr == "" || fdStr == "" {
+ return nil, nil // Not running under systemd
+ }
+
+ pid, err := strconv.Atoi(pidStr)
+ if err != nil {
+ // Not our activation — another process might have inherited the env.
+ return nil, nil
+ }
+ if pid != os.Getpid() {
+ return nil, fmt.Errorf("LISTEN_PID %d does not match current PID %d", pid, os.Getpid())
+ }
+
+ defer func() {
+ _ = os.Unsetenv("LISTEN_PID")
+ _ = os.Unsetenv("LISTEN_FDS")
+ }()
+
+ nfds, err := strconv.ParseUint(fdStr, 10, 64)
+ if err != nil {
+ return nil, fmt.Errorf("invalid LISTEN_FDS: %w", err)
+ }
+ if nfds == 0 {
+ return nil, nil // Nothing to do
+ }
+
+ lns := make(Listeners, 0, nfds)
+ for i := range nfds {
+ fd := i + 3
+ file := os.NewFile(uintptr(fd), fmt.Sprintf("fd_%d", fd))
+ if file == nil {
+ return nil, fmt.Errorf("fd %d was nil", fd)
+ }
+ ln, err := net.FileListener(file)
+ if err != nil {
+ file.Close()
+ return nil, fmt.Errorf("failed to create listener from fd %d: %w", fd, err)
+ }
+ lns = append(lns, ln)
+ }
+
+ return lns, nil
+}
diff --git a/net/systemd_stub.go b/net/systemd_stub.go
new file mode 100644
index 0000000..502cf4b
--- /dev/null
+++ b/net/systemd_stub.go
@@ -0,0 +1,7 @@
+//go:build !linux
+
+package net
+
+func getSystemdListeners() (Listeners, error) {
+ return nil, nil // No-op on non-Linux
+}