From 30aec0042088cb5b89624f5dcb026bfc5848cf9c Mon Sep 17 00:00:00 2001 From: Marc Pervaz Boocha Date: Sat, 4 Oct 2025 20:24:23 +0530 Subject: Added subcommand support --- runner/run.go | 145 ++++++++++++++++++++++++++++++++++++++--------- runner/run_test.go | 163 +++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 282 insertions(+), 26 deletions(-) create mode 100644 runner/run_test.go (limited to 'runner') diff --git a/runner/run.go b/runner/run.go index 04f1907..893ee7d 100644 --- a/runner/run.go +++ b/runner/run.go @@ -2,7 +2,9 @@ package runner import ( "context" + "errors" "flag" + "fmt" "log/slog" "os" "os/signal" @@ -11,49 +13,140 @@ import ( "go.sudomsg.com/kit/logging" ) -// LoadWithArgs initializes context, signal handling, flag parsing, and runs the given application callback. -// -// - args: Command-line arguments (excluding/exact os.Args). -// - run: Callback function (your main app logic). Receives a context (with signal cancellation), -// a FlagSet (ready for use but not yet parsed), and trailing arguments. -// -// Handles panics and logs via logging.RecoverAndLog. -// Logs and exits on non-nil errors from the run callback. -func LoadWithArgs(args []string, run func(ctx context.Context, fs *flag.FlagSet, args []string) error) { +const ( + ExitSuccess = 0 + ExitFailure = 1 + ExitPanic = 2 + ExitUsage = 3 +) + +type UsageError struct { + Err error +} + +func (e UsageError) Error() string { + return e.Err.Error() +} + +func (e UsageError) Unwrap() error { + return e.Err +} + +func RunWithArgs(args []string, cmd Command) { ctx := context.Background() - ctx, stop := signal.NotifyContext(ctx, os.Interrupt, syscall.SIGTERM) + ctx, stop := signal.NotifyContext(ctx, os.Interrupt, syscall.SIGTERM) defer stop() defer func() { if err := recover(); err != nil { logging.RecoverAndLog(ctx, "Panicked", err) - os.Exit(1) + os.Exit(ExitPanic) } }() - name := "app" + if err := runCmd(ctx, cmd, args); err != nil { + var ue UsageError + if errors.As(err, &ue) { + os.Exit(ExitUsage) + } + slog.Log(ctx, slog.LevelError, "Program Terminated", "error", err) + os.Exit(ExitFailure) + } +} + +func Run(cmd Command) { + RunWithArgs(os.Args, cmd) +} + +type Command struct { + Name string + Description string + Run func(ctx context.Context, cs *CommandSet, args []string) error +} + +func runCmd(ctx context.Context, cmd Command, args []string) error { + name := "command" if len(args) > 0 && args[0] != "" { name = args[0] args = args[1:] } + if cmd.Name != "" { + name = cmd.Name + } + + fs := flag.NewFlagSet(name, flag.ContinueOnError) + + c := &CommandSet{ + FlagSet: fs, + description: cmd.Description, + } + + c.Usage = func() { + fmt.Fprintf(fs.Output(), "Usage of %s:\n", fs.Name()) + c.PrintDefaults() + } + + c.AddSubcommand("help", "Print out the help", func(ctx context.Context, cs *CommandSet, args []string) error { + c.Usage() + return nil + }) + + return cmd.Run(ctx, c, args) +} + +type CommandSet struct { + *flag.FlagSet + cmds []Command + description string +} - fs := flag.NewFlagSet(name, flag.ExitOnError) +var ErrInvalidSubcommand = errors.New("invalid subcommand") - if err := run(ctx, fs, args); err != nil { - slog.ErrorContext(ctx, "Program Terminated", "error", err) - os.Exit(1) +func (c *CommandSet) Run(ctx context.Context, args []string) error { + if err := c.Parse(args); err != nil { + return err + } + + args = c.Args() + if len(args) == 0 { + fmt.Fprintf(c.Output(), "Missing Subcommand") + c.Usage() + return UsageError{Err: ErrInvalidSubcommand} + } + + for _, sc := range c.cmds { + if sc.Name == args[0] { + return runCmd(ctx, sc, args) + } + } + + fmt.Fprintf(c.Output(), "Invalid Subcommand: %s", args[0]) + c.Usage() + return UsageError{Err: fmt.Errorf("invalid subcommand: %s", args[0])} + +} + +func (c *CommandSet) AddSubcommand(name string, usage string, run func(ctx context.Context, cs *CommandSet, args []string) error) { + c.cmds = append(c.cmds, Command{Name: name, Description: usage, Run: run}) +} + +func (c *CommandSet) PrintDefaults() { + c.FlagSet.PrintDefaults() + if c.description != "" { + fmt.Fprintf(c.Output(), "\n%s\n", c.description) + } + if len(c.cmds) > 0 { + fmt.Fprintf(c.Output(), "\nSubcommands:\n") + for _, sc := range c.cmds { + fmt.Fprintf(c.Output(), " %s\t%s\n", sc.Name, sc.Description) + } } } -// LoadWithArgs initializes context, signal handling, flag parsing, and runs the given application callback. -// -// - args: Command-line arguments (excluding/exact os.Args). -// - run: Callback function (your main app logic). Receives a context (with signal cancellation), -// a FlagSet (ready for use but not yet parsed), and trailing arguments. -// -// Handles panics and logs via logging.RecoverAndLog. -// Logs and exits on non-nil errors from the run callback. -func Load(run func(ctx context.Context, fs *flag.FlagSet, args []string) error) { - LoadWithArgs(os.Args, run) +func (c *CommandSet) Parse(args []string) error { + if err := c.FlagSet.Parse(args); err != nil { + return UsageError{Err: err} + } + return nil } diff --git a/runner/run_test.go b/runner/run_test.go new file mode 100644 index 0000000..d6c1500 --- /dev/null +++ b/runner/run_test.go @@ -0,0 +1,163 @@ +package runner_test + +import ( + "bytes" + "context" + "errors" + "flag" + "io" + "reflect" + "strings" + "testing" + + "go.sudomsg.com/kit/runner" +) + +func TestCommandSet_Run(t *testing.T) { + t.Run("Dispatch", func(t *testing.T) { + t.Parallel() + + ctx := t.Context() + + var called bool + cs := &runner.CommandSet{ + FlagSet: flag.NewFlagSet("test", flag.ContinueOnError), + } + cs.SetOutput(io.Discard) + + cs.AddSubcommand("foo", "foo command", func(ctx context.Context, cs *runner.CommandSet, args []string) error { + called = true + if !reflect.DeepEqual(args, []string{"bar", "baz"}) { + t.Fatalf("callback got args %v, want [bar baz]", args) + } + if gotName := cs.Name(); gotName != "foo" { + t.Fatalf("sub FlagSet name = %q, want %q", gotName, "foo") + } + + return nil + }) + + err := cs.Run(ctx, []string{"foo", "bar", "baz"}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if !called { + t.Fatal("expected foo subcommand to be called") + } + }) + + t.Run("Recursive Dispatch", func(t *testing.T) { + t.Parallel() + + ctx := t.Context() + + var called bool + cs := &runner.CommandSet{ + FlagSet: flag.NewFlagSet("test", flag.ContinueOnError), + } + cs.SetOutput(io.Discard) + + cs.AddSubcommand("foo", "foo command", func(ctx context.Context, cs *runner.CommandSet, args []string) error { + cs.AddSubcommand("bar", "bar command", func(ctx context.Context, cs *runner.CommandSet, args []string) error { + called = true + if !reflect.DeepEqual(args, []string{"baz"}) { + t.Fatalf("callback got args %v, want [bar baz]", args) + } + if gotName := cs.Name(); gotName != "bar" { + t.Fatalf("sub FlagSet name = %q, want %q", gotName, "foo") + } + return nil + }) + return cs.Run(ctx, args) + }) + + err := cs.Run(ctx, []string{"foo", "bar", "baz"}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if !called { + t.Fatal("expected foo subcommand to be called") + } + + }) + + t.Run("No Subcommand", func(t *testing.T) { + t.Parallel() + + ctx := t.Context() + var buf bytes.Buffer + + cs := &runner.CommandSet{ + FlagSet: flag.NewFlagSet("test", flag.ContinueOnError), + } + cs.SetOutput(&buf) + + err := cs.Run(ctx, []string{}) + if err == nil { + t.Fatal("expected error, got nil") + } + var ue runner.UsageError + if !errors.As(err, &ue) { + t.Fatalf("expected UsageError, got %T", err) + } + if !strings.Contains(buf.String(), "Missing Subcommand") { + t.Errorf("expected help output to mention missing subcommand, got %q", buf.String()) + } + }) + + t.Run("Invalid Subcommand", func(t *testing.T) { + t.Parallel() + + ctx := t.Context() + var buf bytes.Buffer + + cs := &runner.CommandSet{ + FlagSet: flag.NewFlagSet("test", flag.ContinueOnError), + } + cs.SetOutput(&buf) + + cs.AddSubcommand("foo", "bar", func(ctx context.Context, cs *runner.CommandSet, args []string) error { + t.Fatalf("foo command called") + return nil + }) + + err := cs.Run(ctx, []string{"nope"}) + if err == nil { + t.Fatal("expected error, got nil") + } + + var ue runner.UsageError + if !errors.As(err, &ue) { + t.Fatalf("expected UsageError, got %T", err) + } + if !strings.Contains(buf.String(), "Invalid Subcommand") { + t.Errorf("expected help output to mention invalid subcommand, got %q", buf.String()) + } + }) + + t.Run("Invalid Flags", func(t *testing.T) { + t.Parallel() + + ctx := t.Context() + var buf bytes.Buffer + + cs := &runner.CommandSet{ + FlagSet: flag.NewFlagSet("test", flag.ContinueOnError), + } + cs.SetOutput(&buf) + + cs.AddSubcommand("foo", "foo command", func(ctx context.Context, cs *runner.CommandSet, args []string) error { + return cs.Run(ctx, args) + }) + + err := cs.Run(ctx, []string{"foo", "-bar"}) + if err == nil { + t.Fatal("expected error, got nil") + } + + var ue runner.UsageError + if !errors.As(err, &ue) { + t.Fatalf("expected UsageError, got %T", err) + } + }) +} -- cgit v1.2.3-70-g09d2