diff options
Diffstat (limited to '')
| -rw-r--r-- | runner/run_test.go | 163 |
1 files changed, 163 insertions, 0 deletions
diff --git a/runner/run_test.go b/runner/run_test.go new file mode 100644 index 0000000..d6c1500 --- /dev/null +++ b/runner/run_test.go @@ -0,0 +1,163 @@ +package runner_test + +import ( + "bytes" + "context" + "errors" + "flag" + "io" + "reflect" + "strings" + "testing" + + "go.sudomsg.com/kit/runner" +) + +func TestCommandSet_Run(t *testing.T) { + t.Run("Dispatch", func(t *testing.T) { + t.Parallel() + + ctx := t.Context() + + var called bool + cs := &runner.CommandSet{ + FlagSet: flag.NewFlagSet("test", flag.ContinueOnError), + } + cs.SetOutput(io.Discard) + + cs.AddSubcommand("foo", "foo command", func(ctx context.Context, cs *runner.CommandSet, args []string) error { + called = true + if !reflect.DeepEqual(args, []string{"bar", "baz"}) { + t.Fatalf("callback got args %v, want [bar baz]", args) + } + if gotName := cs.Name(); gotName != "foo" { + t.Fatalf("sub FlagSet name = %q, want %q", gotName, "foo") + } + + return nil + }) + + err := cs.Run(ctx, []string{"foo", "bar", "baz"}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if !called { + t.Fatal("expected foo subcommand to be called") + } + }) + + t.Run("Recursive Dispatch", func(t *testing.T) { + t.Parallel() + + ctx := t.Context() + + var called bool + cs := &runner.CommandSet{ + FlagSet: flag.NewFlagSet("test", flag.ContinueOnError), + } + cs.SetOutput(io.Discard) + + cs.AddSubcommand("foo", "foo command", func(ctx context.Context, cs *runner.CommandSet, args []string) error { + cs.AddSubcommand("bar", "bar command", func(ctx context.Context, cs *runner.CommandSet, args []string) error { + called = true + if !reflect.DeepEqual(args, []string{"baz"}) { + t.Fatalf("callback got args %v, want [bar baz]", args) + } + if gotName := cs.Name(); gotName != "bar" { + t.Fatalf("sub FlagSet name = %q, want %q", gotName, "foo") + } + return nil + }) + return cs.Run(ctx, args) + }) + + err := cs.Run(ctx, []string{"foo", "bar", "baz"}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if !called { + t.Fatal("expected foo subcommand to be called") + } + + }) + + t.Run("No Subcommand", func(t *testing.T) { + t.Parallel() + + ctx := t.Context() + var buf bytes.Buffer + + cs := &runner.CommandSet{ + FlagSet: flag.NewFlagSet("test", flag.ContinueOnError), + } + cs.SetOutput(&buf) + + err := cs.Run(ctx, []string{}) + if err == nil { + t.Fatal("expected error, got nil") + } + var ue runner.UsageError + if !errors.As(err, &ue) { + t.Fatalf("expected UsageError, got %T", err) + } + if !strings.Contains(buf.String(), "Missing Subcommand") { + t.Errorf("expected help output to mention missing subcommand, got %q", buf.String()) + } + }) + + t.Run("Invalid Subcommand", func(t *testing.T) { + t.Parallel() + + ctx := t.Context() + var buf bytes.Buffer + + cs := &runner.CommandSet{ + FlagSet: flag.NewFlagSet("test", flag.ContinueOnError), + } + cs.SetOutput(&buf) + + cs.AddSubcommand("foo", "bar", func(ctx context.Context, cs *runner.CommandSet, args []string) error { + t.Fatalf("foo command called") + return nil + }) + + err := cs.Run(ctx, []string{"nope"}) + if err == nil { + t.Fatal("expected error, got nil") + } + + var ue runner.UsageError + if !errors.As(err, &ue) { + t.Fatalf("expected UsageError, got %T", err) + } + if !strings.Contains(buf.String(), "Invalid Subcommand") { + t.Errorf("expected help output to mention invalid subcommand, got %q", buf.String()) + } + }) + + t.Run("Invalid Flags", func(t *testing.T) { + t.Parallel() + + ctx := t.Context() + var buf bytes.Buffer + + cs := &runner.CommandSet{ + FlagSet: flag.NewFlagSet("test", flag.ContinueOnError), + } + cs.SetOutput(&buf) + + cs.AddSubcommand("foo", "foo command", func(ctx context.Context, cs *runner.CommandSet, args []string) error { + return cs.Run(ctx, args) + }) + + err := cs.Run(ctx, []string{"foo", "-bar"}) + if err == nil { + t.Fatal("expected error, got nil") + } + + var ue runner.UsageError + if !errors.As(err, &ue) { + t.Fatalf("expected UsageError, got %T", err) + } + }) +} |
