package net_test

import (
	"errors"
	"net"
	"strings"
	"testing"
	"time"

	netkit "go.sudomsg.com/kit/net"
)

// newListener opens a TCP listener on an OS-assigned localhost port and
// registers a cleanup that closes it when the test finishes, so a test that
// fails early cannot leak the socket. Closing a listener that the test already
// closed is harmless here, so the cleanup deliberately ignores that error.
func newListener(tb testing.TB) net.Listener {
	tb.Helper()

	ln, err := net.Listen("tcp", "localhost:0")
	if err != nil {
		tb.Fatalf("failed to open listener: %v", err)
	}
	tb.Cleanup(func() { _ = ln.Close() })

	return ln
}

// TestOpenConfigListeners checks that OpenConfigListeners opens one listener
// per config entry and reports an error when an address cannot be bound.
func TestOpenConfigListeners(t *testing.T) {
	t.Parallel()

	t.Run("successful config", func(t *testing.T) {
		t.Parallel()

		cfg := []netkit.ServerConfig{
			{Network: netkit.NetTCP, Address: "localhost:0"},
			{Network: netkit.NetTCP, Address: "localhost:0"},
		}

		lns, err := netkit.OpenConfigListeners(t.Context(), cfg)
		if err != nil {
			t.Fatalf("unexpected error: %v", err)
		}
		// Close before the test returns (i.e. before t.Context() is canceled)
		// and surface any close failure instead of dropping it.
		defer func() {
			if err := lns.CloseAll(); err != nil {
				t.Errorf("CloseAll: %v", err)
			}
		}()

		if got := len(lns); got != len(cfg) {
			t.Errorf("expected %d listeners, got %d", len(cfg), got)
		}
	})

	t.Run("port conflict triggers cleanup", func(t *testing.T) {
		t.Parallel()

		conflict := newListener(t)

		cfg := []netkit.ServerConfig{
			{Network: netkit.NetTCP, Address: "localhost:0"},
			// Already bound by conflict, so opening it must fail.
			{Network: netkit.NetTCP, Address: conflict.Addr().String()},
		}

		lns, err := netkit.OpenConfigListeners(t.Context(), cfg)
		if err == nil {
			_ = lns.CloseAll()
			t.Fatal("expected error due to conflict, got nil")
		}
	})
}

// TestCloseAll checks that Listeners.CloseAll closes every listener it holds.
func TestCloseAll(t *testing.T) {
	t.Parallel()

	t.Run("closes all listeners", func(t *testing.T) {
		t.Parallel()

		ls := netkit.Listeners{newListener(t), newListener(t)}

		if err := ls.CloseAll(); err != nil {
			t.Errorf("unexpected error from CloseAll: %v", err)
		}

		for _, ln := range ls {
			// Nothing ever dials these listeners, so Accept on one that is
			// still open would block forever. Bound the wait with a deadline:
			// a closed listener fails immediately with net.ErrClosed, while an
			// open one times out with a different error instead of hanging.
			if tl, ok := ln.(*net.TCPListener); ok {
				// SetDeadline itself fails on an already-closed listener,
				// which is the expected case, so its error is ignored.
				_ = tl.SetDeadline(time.Now().Add(time.Second))
			}
			if _, err := ln.Accept(); !errors.Is(err, net.ErrClosed) {
				t.Errorf("Accept on %v after CloseAll: got error %v, want net.ErrClosed", ln.Addr(), err)
			}
		}
	})
}

// TestInvalidConfig checks that OpenConfigListeners rejects a config with an
// unknown network type or a malformed address.
func TestInvalidConfig(t *testing.T) {
	t.Parallel()

	tests := []struct {
		name   string
		config netkit.ServerConfig
	}{
		{
			name:   "invalid network",
			config: netkit.ServerConfig{Network: -1, Address: "localhost:0"},
		},
		{
			name:   "invalid address",
			config: netkit.ServerConfig{Network: netkit.NetTCP, Address: "::::"},
		},
	}

	for _, tt := range tests {
		// tt is captured by a parallel subtest; this relies on per-iteration
		// loop variables (Go 1.22+ language version).
		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()

			lns, err := netkit.OpenConfigListeners(t.Context(), []netkit.ServerConfig{tt.config})
			if err == nil {
				// Don't leak whatever was unexpectedly opened.
				_ = lns.CloseAll()
				t.Fatal("OpenConfigListeners() expected error, got nil")
			}
		})
	}
}

// TestNetType checks that each NetType survives a MarshalText/UnmarshalText
// round trip (including with the text upper-cased), and that an unknown name
// is rejected by UnmarshalText.
func TestNetType(t *testing.T) {
	t.Parallel()

	t.Run("RoundTrip", func(t *testing.T) {
		t.Parallel()

		types := []netkit.NetType{
			netkit.NetTCP,
			netkit.NetTCP4,
			netkit.NetTCP6,
			netkit.NetUnix,
			netkit.NetUnixPacket,
		}

		// Each variant rewrites the marshaled text before it is decoded again.
		variants := []struct {
			name      string
			transform func(string) string
		}{
			{"case same", func(s string) string { return s }},
			{"case not same", strings.ToUpper},
		}

		for _, original := range types {
			t.Run(original.String(), func(t *testing.T) {
				for _, v := range variants {
					t.Run(v.name, func(t *testing.T) {
						text, err := original.MarshalText()
						if err != nil {
							t.Fatalf("MarshalText failed: %v", err)
						}

						var decoded netkit.NetType
						if err := decoded.UnmarshalText([]byte(v.transform(string(text)))); err != nil {
							t.Fatalf("UnmarshalText failed: %v", err)
						}
						if decoded != original {
							t.Fatalf("Round-trip mismatch: got %v, want %v", decoded, original)
						}
					})
				}
			})
		}
	})

	t.Run("Invalid", func(t *testing.T) {
		t.Parallel()

		input := "invalid"
		var decoded netkit.NetType
		if err := decoded.UnmarshalText([]byte(input)); err == nil {
			t.Errorf("expected error for input %q, got none", input)
		}
	})
}