feat: move to pflag

This commit is contained in:
Toby Padilla 2023-04-17 18:29:15 -05:00
parent 03aba7b9ef
commit 559044d90d
3 changed files with 83 additions and 97 deletions

1
go.mod
View file

@ -10,6 +10,7 @@ require (
github.com/muesli/termenv v0.15.1
github.com/pkg/errors v0.9.1
github.com/sashabaranov/go-openai v1.7.0
github.com/spf13/pflag v1.0.5
)
require (

2
go.sum
View file

@ -51,6 +51,8 @@ github.com/rivo/uniseg v0.4.4/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUc
github.com/sahilm/fuzzy v0.1.0/go.mod h1:VFvziUEIMCrT6A6tw2RFIXPXXmzXbOsSHF0DOI8ZK9Y=
github.com/sashabaranov/go-openai v1.7.0 h1:D1dBXoZhtf/aKNu6WFf0c7Ah2NM30PZ/3Mqly6cZ7fk=
github.com/sashabaranov/go-openai v1.7.0/go.mod h1:lj5b/K+zjTSFxVLijLSTDZuP7adOgerWeFyZLUhAKRg=
github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA=
github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
golang.org/x/sync v0.1.0 h1:wsuoTGHzEhffawBOhz5CYhcrV4IdKZbEyZjBMuTp12o=
golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sys v0.0.0-20210124154548-22da62e12c0c/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=

177
main.go
View file

@ -3,7 +3,6 @@ package main
import (
"bufio"
"context"
"flag"
"fmt"
"io"
"math"
@ -17,56 +16,34 @@ import (
"github.com/muesli/termenv"
"github.com/pkg/errors"
openai "github.com/sashabaranov/go-openai"
)
const (
modelTypeFlagShorthand = "m"
quietFlagShorthand = "q"
markdownFlagShorthand = "md"
temperatureFlagShorthand = "temp"
maxTokensFlagShorthand = "max"
topPFlagShorthand = "top"
typeFlagDescription = "OpenAI model (gpt-3.5-turbo, gpt-4)."
markdownFlagDescription = "Format response as markdown."
quietFlagDescription = "Quiet mode (hide the spinner while loading)."
temperatureFlagDescription = "Temperature (randomness) of results, from 0.0 to 2.0."
maxTokensFlagDescription = "Maximum number of tokens in response."
topPFlagDescription = "TopP, an alternative to temperature that narrows response, from 0.0 to 1.0."
flag "github.com/spf13/pflag"
)
var errorStyle = lipgloss.NewStyle().Foreground(lipgloss.Color("1"))
var codeStyle = lipgloss.NewStyle().Foreground(lipgloss.Color("1")).Background(lipgloss.Color("0")).Padding(0, 1)
var linkStyle = lipgloss.NewStyle().Foreground(lipgloss.Color("10")).Underline(true)
var helpAppStyle = lipgloss.NewStyle().Foreground(lipgloss.Color("208")).Bold(true)
var helpFlagStyle = lipgloss.NewStyle().Foreground(lipgloss.Color("#41ffef")).Bold(true)
var helpDescriptionStyle = lipgloss.NewStyle().Foreground(lipgloss.Color("244"))
type Config struct {
Model string
Markdown bool
Quiet bool
MaxTokens int
Temperature float32
TopP float32
Model *string
Markdown *bool
Quiet *bool
MaxTokens *int
Temperature *float32
TopP *float32
}
func printUsage() {
lipgloss.SetColorProfile(termenv.ColorProfile())
appNameStyle := lipgloss.NewStyle().
Foreground(lipgloss.Color("208")).
Bold(true)
flagStyle := lipgloss.NewStyle().
Foreground(lipgloss.Color("#41ffef")).
Bold(true)
descriptionStyle := lipgloss.NewStyle().
Foreground(lipgloss.Color("244"))
fmt.Printf("Usage: %s [OPTIONS] [PREFIX TERM]\n", appNameStyle.Render(os.Args[0]))
fmt.Println()
fmt.Println("Options:")
fmt.Printf(" %s\t%s\n", flagStyle.Render("-"+modelTypeFlagShorthand), descriptionStyle.Render(typeFlagDescription))
fmt.Printf(" %s\t%s\n", flagStyle.Render("-"+quietFlagShorthand), descriptionStyle.Render(quietFlagDescription))
fmt.Printf(" %s\t%s\n", flagStyle.Render("-"+markdownFlagShorthand), descriptionStyle.Render(markdownFlagDescription))
fmt.Printf(" %s\t%s\n", flagStyle.Render("-"+temperatureFlagShorthand), descriptionStyle.Render(temperatureFlagDescription))
fmt.Printf(" %s\t%s\n", flagStyle.Render("-"+topPFlagShorthand), descriptionStyle.Render(topPFlagDescription))
fmt.Printf(" %s\t%s\n", flagStyle.Render("-"+maxTokensFlagShorthand), descriptionStyle.Render(maxTokensFlagDescription))
func newConfig() Config {
return Config{
Model: flag.StringP("model", "m", "gpt-4", "OpenAI model (gpt-3.5-turbo, gpt-4)."),
Markdown: flag.BoolP("format", "f", false, "Format response as markdown."),
Quiet: flag.BoolP("quiet", "q", false, "Quiet mode (hide the spinner while loading)."),
MaxTokens: flag.Int("max", 0, "Maximum number of tokens in response."),
Temperature: flag.Float32("temp", 1.0, "Temperature (randomness) of results, from 0.0 to 2.0."),
TopP: flag.Float32("top", 1.0, "TopP, an alternative to temperature that narrows response, from 0.0 to 1.0."),
}
}
func readStdinContent() string {
@ -74,29 +51,47 @@ func readStdinContent() string {
reader := bufio.NewReader(os.Stdin)
stdinBytes, err := io.ReadAll(reader)
if err != nil {
fmt.Println()
fmt.Println(errorStyle.Render(" Unable to read stdin."))
fmt.Println()
fmt.Println(" " + errorStyle.Render(err.Error()))
fmt.Println()
os.Exit(1)
handleError(err, "Unable to read stdin.")
}
return string(stdinBytes)
}
return ""
}
// flagToFloat converts a flag value to a float usable by the OpenAI client
// noOmitFloat converts a 0.0 value to a float usable by the OpenAI client
// library, which currently uses Float32 fields in the request struct with the
// omitempty tag. This means we need to use math.SmallestNonzeroFloat32 instead
// of 0.0 so it doesn't get stripped from the request and replaced server side
// with the default values.
// Issue: https://github.com/sashabaranov/go-openai/issues/9
func flagToFloat(f *float64) float32 {
if *f == 0.0 {
func noOmitFloat(f float32) float32 {
if f == 0.0 {
return math.SmallestNonzeroFloat32
}
return float32(*f)
return f
}
func usage() {
lipgloss.SetColorProfile(termenv.ColorProfile())
fmt.Printf("Usage: %s [OPTIONS] [PREFIX TERM]\n", helpAppStyle.Render(os.Args[0]))
fmt.Println()
fmt.Println("Options:")
flag.VisitAll(func(f *flag.Flag) {
if f.Shorthand == "" {
fmt.Printf(
" %-38s %s\n",
helpFlagStyle.Render("--"+f.Name),
helpDescriptionStyle.Render(f.Usage),
)
} else {
fmt.Printf(
" %s, %-34s %s\n",
helpFlagStyle.Render("-"+f.Shorthand),
helpFlagStyle.Render("--"+f.Name),
helpDescriptionStyle.Render(f.Usage),
)
}
})
}
func createClient(apiKey string) *openai.Client {
@ -115,10 +110,10 @@ func startChatCompletion(client openai.Client, config Config, content string) (s
resp, err := client.CreateChatCompletion(
context.Background(),
openai.ChatCompletionRequest{
Model: config.Model,
Temperature: config.Temperature,
TopP: config.TopP,
MaxTokens: config.MaxTokens,
Model: *config.Model,
Temperature: noOmitFloat(*config.Temperature),
TopP: noOmitFloat(*config.TopP),
MaxTokens: *config.MaxTokens,
Messages: []openai.ChatCompletionMessage{
{
Role: openai.ChatMessageRoleUser,
@ -133,35 +128,28 @@ func startChatCompletion(client openai.Client, config Config, content string) (s
return resp.Choices[0].Message.Content, nil
}
func configFromFlags() Config {
modelTypeFlag := flag.String(modelTypeFlagShorthand, "gpt-4", typeFlagDescription)
markdownFlag := flag.Bool(markdownFlagShorthand, false, markdownFlagDescription)
quietFlag := flag.Bool(quietFlagShorthand, false, quietFlagDescription)
temperatureFlag := flag.Float64(temperatureFlagShorthand, 1.0, temperatureFlagDescription)
maxTokenFlag := flag.Int(maxTokensFlagShorthand, 0, maxTokensFlagDescription)
topPFlag := flag.Float64(topPFlagShorthand, 1.0, topPFlagDescription)
flag.Usage = printUsage
flag.Parse()
return Config{
Model: *modelTypeFlag,
Quiet: *quietFlag,
MaxTokens: *maxTokenFlag,
Markdown: *markdownFlag,
Temperature: flagToFloat(temperatureFlag),
TopP: flagToFloat(topPFlag),
}
func handleError(err error, reason string) {
fmt.Println()
fmt.Println(errorStyle.Render(" Error: %s", reason))
fmt.Println()
fmt.Println(" " + errorStyle.Render(err.Error()))
fmt.Println()
os.Exit(1)
}
func main() {
config := configFromFlags()
flag.Usage = usage
flag.CommandLine.SortFlags = false
config := newConfig()
flag.Parse()
client := createClient(os.Getenv("OPENAI_API_KEY"))
content := readStdinContent()
prefix := strings.Join(flag.Args(), " ")
if prefix == "" && content == "" {
printUsage()
flag.Usage()
os.Exit(0)
}
if config.Markdown {
if *config.Markdown {
prefix = fmt.Sprintf("%s Format output as Markdown.", prefix)
}
if prefix != "" {
@ -169,35 +157,30 @@ func main() {
}
var p *tea.Program
if !config.Quiet {
var output string
var err error
if !*config.Quiet {
lipgloss.SetColorProfile(termenv.NewOutput(os.Stderr).ColorProfile())
spinner := spinner.New(spinner.WithSpinner(spinner.Dot), spinner.WithStyle(spinnerStyle))
p = tea.NewProgram(Model{spinner: spinner}, tea.WithOutput(os.Stderr))
}
if !config.Quiet {
go func() {
output, err := startChatCompletion(*client, config, content)
p.Send(quitMsg{})
output, err = startChatCompletion(*client, config, content)
p.Quit()
if err != nil {
fmt.Println()
fmt.Println(errorStyle.Render(" Error: Unable to generate response."))
fmt.Println()
fmt.Println(" " + errorStyle.Render(err.Error()))
fmt.Println()
os.Exit(1)
handleError(err, "There was a problem with the OpenAI API.")
}
fmt.Println(output)
}()
} else {
output, err := startChatCompletion(*client, config, content)
if err != nil {
fmt.Println(err.Error())
}
fmt.Println(output)
}
if !config.Quiet {
_, _ = p.Run()
_, err = p.Run()
if err != nil {
handleError(err, "Can't run the Bubble Tea program.")
}
} else {
output, err = startChatCompletion(*client, config, content)
if err != nil {
handleError(err, "There was a problem with the OpenAI API.")
}
}
fmt.Println(output)
}