From 559044d90d25a641ce6b28e54a22975f0fb451f0 Mon Sep 17 00:00:00 2001 From: Toby Padilla Date: Mon, 17 Apr 2023 18:29:15 -0500 Subject: [PATCH] feat: move to pflag --- go.mod | 1 + go.sum | 2 + main.go | 177 +++++++++++++++++++++++++------------------------------- 3 files changed, 83 insertions(+), 97 deletions(-) diff --git a/go.mod b/go.mod index 4ec6d59..2867b98 100644 --- a/go.mod +++ b/go.mod @@ -10,6 +10,7 @@ require ( github.com/muesli/termenv v0.15.1 github.com/pkg/errors v0.9.1 github.com/sashabaranov/go-openai v1.7.0 + github.com/spf13/pflag v1.0.5 ) require ( diff --git a/go.sum b/go.sum index 3b18677..813a625 100644 --- a/go.sum +++ b/go.sum @@ -51,6 +51,8 @@ github.com/rivo/uniseg v0.4.4/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUc github.com/sahilm/fuzzy v0.1.0/go.mod h1:VFvziUEIMCrT6A6tw2RFIXPXXmzXbOsSHF0DOI8ZK9Y= github.com/sashabaranov/go-openai v1.7.0 h1:D1dBXoZhtf/aKNu6WFf0c7Ah2NM30PZ/3Mqly6cZ7fk= github.com/sashabaranov/go-openai v1.7.0/go.mod h1:lj5b/K+zjTSFxVLijLSTDZuP7adOgerWeFyZLUhAKRg= +github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA= +github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= golang.org/x/sync v0.1.0 h1:wsuoTGHzEhffawBOhz5CYhcrV4IdKZbEyZjBMuTp12o= golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sys v0.0.0-20210124154548-22da62e12c0c/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= diff --git a/main.go b/main.go index ba0e37d..fc75966 100644 --- a/main.go +++ b/main.go @@ -3,7 +3,6 @@ package main import ( "bufio" "context" - "flag" "fmt" "io" "math" @@ -17,56 +16,34 @@ import ( "github.com/muesli/termenv" "github.com/pkg/errors" openai "github.com/sashabaranov/go-openai" -) - -const ( - modelTypeFlagShorthand = "m" - quietFlagShorthand = "q" - markdownFlagShorthand = "md" - temperatureFlagShorthand = "temp" - maxTokensFlagShorthand = "max" - topPFlagShorthand = "top" - typeFlagDescription = "OpenAI model (gpt-3.5-turbo, gpt-4)." - markdownFlagDescription = "Format response as markdown." - quietFlagDescription = "Quiet mode (hide the spinner while loading)." - temperatureFlagDescription = "Temperature (randomness) of results, from 0.0 to 2.0." - maxTokensFlagDescription = "Maximum number of tokens in response." - topPFlagDescription = "TopP, an alternative to temperature that narrows response, from 0.0 to 1.0." + flag "github.com/spf13/pflag" ) var errorStyle = lipgloss.NewStyle().Foreground(lipgloss.Color("1")) var codeStyle = lipgloss.NewStyle().Foreground(lipgloss.Color("1")).Background(lipgloss.Color("0")).Padding(0, 1) var linkStyle = lipgloss.NewStyle().Foreground(lipgloss.Color("10")).Underline(true) +var helpAppStyle = lipgloss.NewStyle().Foreground(lipgloss.Color("208")).Bold(true) +var helpFlagStyle = lipgloss.NewStyle().Foreground(lipgloss.Color("#41ffef")).Bold(true) +var helpDescriptionStyle = lipgloss.NewStyle().Foreground(lipgloss.Color("244")) type Config struct { - Model string - Markdown bool - Quiet bool - MaxTokens int - Temperature float32 - TopP float32 + Model *string + Markdown *bool + Quiet *bool + MaxTokens *int + Temperature *float32 + TopP *float32 } -func printUsage() { - lipgloss.SetColorProfile(termenv.ColorProfile()) - appNameStyle := lipgloss.NewStyle(). - Foreground(lipgloss.Color("208")). - Bold(true) - flagStyle := lipgloss.NewStyle(). - Foreground(lipgloss.Color("#41ffef")). - Bold(true) - descriptionStyle := lipgloss.NewStyle(). - Foreground(lipgloss.Color("244")) - - fmt.Printf("Usage: %s [OPTIONS] [PREFIX TERM]\n", appNameStyle.Render(os.Args[0])) - fmt.Println() - fmt.Println("Options:") - fmt.Printf(" %s\t%s\n", flagStyle.Render("-"+modelTypeFlagShorthand), descriptionStyle.Render(typeFlagDescription)) - fmt.Printf(" %s\t%s\n", flagStyle.Render("-"+quietFlagShorthand), descriptionStyle.Render(quietFlagDescription)) - fmt.Printf(" %s\t%s\n", flagStyle.Render("-"+markdownFlagShorthand), descriptionStyle.Render(markdownFlagDescription)) - fmt.Printf(" %s\t%s\n", flagStyle.Render("-"+temperatureFlagShorthand), descriptionStyle.Render(temperatureFlagDescription)) - fmt.Printf(" %s\t%s\n", flagStyle.Render("-"+topPFlagShorthand), descriptionStyle.Render(topPFlagDescription)) - fmt.Printf(" %s\t%s\n", flagStyle.Render("-"+maxTokensFlagShorthand), descriptionStyle.Render(maxTokensFlagDescription)) +func newConfig() Config { + return Config{ + Model: flag.StringP("model", "m", "gpt-4", "OpenAI model (gpt-3.5-turbo, gpt-4)."), + Markdown: flag.BoolP("format", "f", false, "Format response as markdown."), + Quiet: flag.BoolP("quiet", "q", false, "Quiet mode (hide the spinner while loading)."), + MaxTokens: flag.Int("max", 0, "Maximum number of tokens in response."), + Temperature: flag.Float32("temp", 1.0, "Temperature (randomness) of results, from 0.0 to 2.0."), + TopP: flag.Float32("top", 1.0, "TopP, an alternative to temperature that narrows response, from 0.0 to 1.0."), + } } func readStdinContent() string { @@ -74,29 +51,47 @@ func readStdinContent() string { reader := bufio.NewReader(os.Stdin) stdinBytes, err := io.ReadAll(reader) if err != nil { - fmt.Println() - fmt.Println(errorStyle.Render(" Unable to read stdin.")) - fmt.Println() - fmt.Println(" " + errorStyle.Render(err.Error())) - fmt.Println() - os.Exit(1) + handleError(err, "Unable to read stdin.") } return string(stdinBytes) } return "" } -// flagToFloat converts a flag value to a float usable by the OpenAI client +// noOmitFloat converts a 0.0 value to a float usable by the OpenAI client // library, which currently uses Float32 fields in the request struct with the // omitempty tag. This means we need to use math.SmallestNonzeroFloat32 instead // of 0.0 so it doesn't get stripped from the request and replaced server side // with the default values. // Issue: https://github.com/sashabaranov/go-openai/issues/9 -func flagToFloat(f *float64) float32 { - if *f == 0.0 { +func noOmitFloat(f float32) float32 { + if f == 0.0 { return math.SmallestNonzeroFloat32 } - return float32(*f) + return f +} + +func usage() { + lipgloss.SetColorProfile(termenv.ColorProfile()) + fmt.Printf("Usage: %s [OPTIONS] [PREFIX TERM]\n", helpAppStyle.Render(os.Args[0])) + fmt.Println() + fmt.Println("Options:") + flag.VisitAll(func(f *flag.Flag) { + if f.Shorthand == "" { + fmt.Printf( + " %-38s %s\n", + helpFlagStyle.Render("--"+f.Name), + helpDescriptionStyle.Render(f.Usage), + ) + } else { + fmt.Printf( + " %s, %-34s %s\n", + helpFlagStyle.Render("-"+f.Shorthand), + helpFlagStyle.Render("--"+f.Name), + helpDescriptionStyle.Render(f.Usage), + ) + } + }) } func createClient(apiKey string) *openai.Client { @@ -115,10 +110,10 @@ func startChatCompletion(client openai.Client, config Config, content string) (s resp, err := client.CreateChatCompletion( context.Background(), openai.ChatCompletionRequest{ - Model: config.Model, - Temperature: config.Temperature, - TopP: config.TopP, - MaxTokens: config.MaxTokens, + Model: *config.Model, + Temperature: noOmitFloat(*config.Temperature), + TopP: noOmitFloat(*config.TopP), + MaxTokens: *config.MaxTokens, Messages: []openai.ChatCompletionMessage{ { Role: openai.ChatMessageRoleUser, @@ -133,35 +128,28 @@ func startChatCompletion(client openai.Client, config Config, content string) (s return resp.Choices[0].Message.Content, nil } -func configFromFlags() Config { - modelTypeFlag := flag.String(modelTypeFlagShorthand, "gpt-4", typeFlagDescription) - markdownFlag := flag.Bool(markdownFlagShorthand, false, markdownFlagDescription) - quietFlag := flag.Bool(quietFlagShorthand, false, quietFlagDescription) - temperatureFlag := flag.Float64(temperatureFlagShorthand, 1.0, temperatureFlagDescription) - maxTokenFlag := flag.Int(maxTokensFlagShorthand, 0, maxTokensFlagDescription) - topPFlag := flag.Float64(topPFlagShorthand, 1.0, topPFlagDescription) - flag.Usage = printUsage - flag.Parse() - return Config{ - Model: *modelTypeFlag, - Quiet: *quietFlag, - MaxTokens: *maxTokenFlag, - Markdown: *markdownFlag, - Temperature: flagToFloat(temperatureFlag), - TopP: flagToFloat(topPFlag), - } +func handleError(err error, reason string) { + fmt.Println() + fmt.Println(errorStyle.Render(" Error: %s", reason)) + fmt.Println() + fmt.Println(" " + errorStyle.Render(err.Error())) + fmt.Println() + os.Exit(1) } func main() { - config := configFromFlags() + flag.Usage = usage + flag.CommandLine.SortFlags = false + config := newConfig() + flag.Parse() client := createClient(os.Getenv("OPENAI_API_KEY")) content := readStdinContent() prefix := strings.Join(flag.Args(), " ") if prefix == "" && content == "" { - printUsage() + flag.Usage() os.Exit(0) } - if config.Markdown { + if *config.Markdown { prefix = fmt.Sprintf("%s Format output as Markdown.", prefix) } if prefix != "" { @@ -169,35 +157,30 @@ func main() { } var p *tea.Program - if !config.Quiet { + var output string + var err error + if !*config.Quiet { lipgloss.SetColorProfile(termenv.NewOutput(os.Stderr).ColorProfile()) spinner := spinner.New(spinner.WithSpinner(spinner.Dot), spinner.WithStyle(spinnerStyle)) p = tea.NewProgram(Model{spinner: spinner}, tea.WithOutput(os.Stderr)) - } - if !config.Quiet { go func() { - output, err := startChatCompletion(*client, config, content) - p.Send(quitMsg{}) + output, err = startChatCompletion(*client, config, content) + p.Quit() if err != nil { - fmt.Println() - fmt.Println(errorStyle.Render(" Error: Unable to generate response.")) - fmt.Println() - fmt.Println(" " + errorStyle.Render(err.Error())) - fmt.Println() - os.Exit(1) + handleError(err, "There was a problem with the OpenAI API.") } - fmt.Println(output) }() - } else { - output, err := startChatCompletion(*client, config, content) - if err != nil { - fmt.Println(err.Error()) - } - fmt.Println(output) - } - if !config.Quiet { - _, _ = p.Run() + _, err = p.Run() + if err != nil { + handleError(err, "Can't run the Bubble Tea program.") + } + } else { + output, err = startChatCompletion(*client, config, content) + if err != nil { + handleError(err, "There was a problem with the OpenAI API.") + } } + fmt.Println(output) }