Mirror of https://github.com/charmbracelet/mods, synced 2024-10-18 15:22:17 +00:00
feat: move to pflag

commit 559044d90d (parent 03aba7b9ef)
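
The commit swaps the standard library flag package for spf13/pflag, which keeps a similar API but adds GNU-style long flags with one-letter shorthands; because pflag's registration helpers return pointers, the Config struct in the diff below also moves to pointer fields. A minimal sketch of the pattern being adopted (illustrative only, not code from this repository):

package main

import (
    "fmt"

    // pflag is a drop-in replacement for the standard flag package.
    flag "github.com/spf13/pflag"
)

func main() {
    // StringP and BoolP register a long flag plus a one-letter shorthand and
    // return pointers, which is why the diff's Config struct switches to
    // pointer fields that are only dereferenced after Parse.
    model := flag.StringP("model", "m", "gpt-4", "OpenAI model (gpt-3.5-turbo, gpt-4).")
    quiet := flag.BoolP("quiet", "q", false, "Quiet mode (hide the spinner while loading).")
    flag.Parse()

    fmt.Println("model:", *model, "quiet:", *quiet)
}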
go.mod (1 change)

@@ -10,6 +10,7 @@ require (
     github.com/muesli/termenv v0.15.1
     github.com/pkg/errors v0.9.1
     github.com/sashabaranov/go-openai v1.7.0
+    github.com/spf13/pflag v1.0.5
 )
 
 require (
go.sum (2 changes)

@@ -51,6 +51,8 @@ github.com/rivo/uniseg v0.4.4/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUc
 github.com/sahilm/fuzzy v0.1.0/go.mod h1:VFvziUEIMCrT6A6tw2RFIXPXXmzXbOsSHF0DOI8ZK9Y=
 github.com/sashabaranov/go-openai v1.7.0 h1:D1dBXoZhtf/aKNu6WFf0c7Ah2NM30PZ/3Mqly6cZ7fk=
 github.com/sashabaranov/go-openai v1.7.0/go.mod h1:lj5b/K+zjTSFxVLijLSTDZuP7adOgerWeFyZLUhAKRg=
+github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA=
+github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
 golang.org/x/sync v0.1.0 h1:wsuoTGHzEhffawBOhz5CYhcrV4IdKZbEyZjBMuTp12o=
 golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
 golang.org/x/sys v0.0.0-20210124154548-22da62e12c0c/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
main.go (177 changes)

@@ -3,7 +3,6 @@ package main
 import (
     "bufio"
     "context"
-    "flag"
     "fmt"
     "io"
     "math"
@@ -17,56 +16,34 @@ import (
     "github.com/muesli/termenv"
     "github.com/pkg/errors"
     openai "github.com/sashabaranov/go-openai"
-)
-
-const (
-    modelTypeFlagShorthand     = "m"
-    quietFlagShorthand         = "q"
-    markdownFlagShorthand      = "md"
-    temperatureFlagShorthand   = "temp"
-    maxTokensFlagShorthand     = "max"
-    topPFlagShorthand          = "top"
-    typeFlagDescription        = "OpenAI model (gpt-3.5-turbo, gpt-4)."
-    markdownFlagDescription    = "Format response as markdown."
-    quietFlagDescription       = "Quiet mode (hide the spinner while loading)."
-    temperatureFlagDescription = "Temperature (randomness) of results, from 0.0 to 2.0."
-    maxTokensFlagDescription   = "Maximum number of tokens in response."
-    topPFlagDescription        = "TopP, an alternative to temperature that narrows response, from 0.0 to 1.0."
+    flag "github.com/spf13/pflag"
 )
 
 var errorStyle = lipgloss.NewStyle().Foreground(lipgloss.Color("1"))
 var codeStyle = lipgloss.NewStyle().Foreground(lipgloss.Color("1")).Background(lipgloss.Color("0")).Padding(0, 1)
 var linkStyle = lipgloss.NewStyle().Foreground(lipgloss.Color("10")).Underline(true)
+var helpAppStyle = lipgloss.NewStyle().Foreground(lipgloss.Color("208")).Bold(true)
+var helpFlagStyle = lipgloss.NewStyle().Foreground(lipgloss.Color("#41ffef")).Bold(true)
+var helpDescriptionStyle = lipgloss.NewStyle().Foreground(lipgloss.Color("244"))
 
 type Config struct {
-    Model       string
-    Markdown    bool
-    Quiet       bool
-    MaxTokens   int
-    Temperature float32
-    TopP        float32
+    Model       *string
+    Markdown    *bool
+    Quiet       *bool
+    MaxTokens   *int
+    Temperature *float32
+    TopP        *float32
 }
 
-func printUsage() {
-    lipgloss.SetColorProfile(termenv.ColorProfile())
-    appNameStyle := lipgloss.NewStyle().
-        Foreground(lipgloss.Color("208")).
-        Bold(true)
-    flagStyle := lipgloss.NewStyle().
-        Foreground(lipgloss.Color("#41ffef")).
-        Bold(true)
-    descriptionStyle := lipgloss.NewStyle().
-        Foreground(lipgloss.Color("244"))
-
-    fmt.Printf("Usage: %s [OPTIONS] [PREFIX TERM]\n", appNameStyle.Render(os.Args[0]))
-    fmt.Println()
-    fmt.Println("Options:")
-    fmt.Printf(" %s\t%s\n", flagStyle.Render("-"+modelTypeFlagShorthand), descriptionStyle.Render(typeFlagDescription))
-    fmt.Printf(" %s\t%s\n", flagStyle.Render("-"+quietFlagShorthand), descriptionStyle.Render(quietFlagDescription))
-    fmt.Printf(" %s\t%s\n", flagStyle.Render("-"+markdownFlagShorthand), descriptionStyle.Render(markdownFlagDescription))
-    fmt.Printf(" %s\t%s\n", flagStyle.Render("-"+temperatureFlagShorthand), descriptionStyle.Render(temperatureFlagDescription))
-    fmt.Printf(" %s\t%s\n", flagStyle.Render("-"+topPFlagShorthand), descriptionStyle.Render(topPFlagDescription))
-    fmt.Printf(" %s\t%s\n", flagStyle.Render("-"+maxTokensFlagShorthand), descriptionStyle.Render(maxTokensFlagDescription))
+func newConfig() Config {
+    return Config{
+        Model:       flag.StringP("model", "m", "gpt-4", "OpenAI model (gpt-3.5-turbo, gpt-4)."),
+        Markdown:    flag.BoolP("format", "f", false, "Format response as markdown."),
+        Quiet:       flag.BoolP("quiet", "q", false, "Quiet mode (hide the spinner while loading)."),
+        MaxTokens:   flag.Int("max", 0, "Maximum number of tokens in response."),
+        Temperature: flag.Float32("temp", 1.0, "Temperature (randomness) of results, from 0.0 to 2.0."),
+        TopP:        flag.Float32("top", 1.0, "TopP, an alternative to temperature that narrows response, from 0.0 to 1.0."),
+    }
 }
 
 func readStdinContent() string {
@@ -74,29 +51,47 @@ func readStdinContent() string {
         reader := bufio.NewReader(os.Stdin)
         stdinBytes, err := io.ReadAll(reader)
         if err != nil {
-            fmt.Println()
-            fmt.Println(errorStyle.Render(" Unable to read stdin."))
-            fmt.Println()
-            fmt.Println(" " + errorStyle.Render(err.Error()))
-            fmt.Println()
-            os.Exit(1)
+            handleError(err, "Unable to read stdin.")
         }
         return string(stdinBytes)
     }
     return ""
 }
 
-// flagToFloat converts a flag value to a float usable by the OpenAI client
+// noOmitFloat converts a 0.0 value to a float usable by the OpenAI client
 // library, which currently uses Float32 fields in the request struct with the
 // omitempty tag. This means we need to use math.SmallestNonzeroFloat32 instead
 // of 0.0 so it doesn't get stripped from the request and replaced server side
 // with the default values.
 // Issue: https://github.com/sashabaranov/go-openai/issues/9
-func flagToFloat(f *float64) float32 {
-    if *f == 0.0 {
+func noOmitFloat(f float32) float32 {
+    if f == 0.0 {
         return math.SmallestNonzeroFloat32
     }
-    return float32(*f)
+    return f
 }
 
+func usage() {
+    lipgloss.SetColorProfile(termenv.ColorProfile())
+    fmt.Printf("Usage: %s [OPTIONS] [PREFIX TERM]\n", helpAppStyle.Render(os.Args[0]))
+    fmt.Println()
+    fmt.Println("Options:")
+    flag.VisitAll(func(f *flag.Flag) {
+        if f.Shorthand == "" {
+            fmt.Printf(
+                " %-38s %s\n",
+                helpFlagStyle.Render("--"+f.Name),
+                helpDescriptionStyle.Render(f.Usage),
+            )
+        } else {
+            fmt.Printf(
+                " %s, %-34s %s\n",
+                helpFlagStyle.Render("-"+f.Shorthand),
+                helpFlagStyle.Render("--"+f.Name),
+                helpDescriptionStyle.Render(f.Usage),
+            )
+        }
+    })
+}
+
 func createClient(apiKey string) *openai.Client {
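
The noOmitFloat comment in the hunk above explains the workaround it keeps: go-openai tags its Float32 request fields with omitempty, so an explicit 0 would be dropped from the JSON body and the server-side default used instead. A standalone illustration of that encoding/json behavior (not part of the commit; the struct and field here are made up):

package main

import (
    "encoding/json"
    "fmt"
    "math"
)

// request mimics the shape the comment describes: a float32 field tagged
// omitempty, as in the go-openai request struct. Illustrative only.
type request struct {
    Temperature float32 `json:"temperature,omitempty"`
}

func main() {
    zero, _ := json.Marshal(request{Temperature: 0})
    tiny, _ := json.Marshal(request{Temperature: math.SmallestNonzeroFloat32})
    fmt.Println(string(zero)) // prints {} : an explicit 0 is stripped, so the server default wins
    fmt.Println(string(tiny)) // the temperature field survives with a negligibly small value
}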
@@ -115,10 +110,10 @@ func startChatCompletion(client openai.Client, config Config, content string) (s
     resp, err := client.CreateChatCompletion(
         context.Background(),
         openai.ChatCompletionRequest{
-            Model:       config.Model,
-            Temperature: config.Temperature,
-            TopP:        config.TopP,
-            MaxTokens:   config.MaxTokens,
+            Model:       *config.Model,
+            Temperature: noOmitFloat(*config.Temperature),
+            TopP:        noOmitFloat(*config.TopP),
+            MaxTokens:   *config.MaxTokens,
             Messages: []openai.ChatCompletionMessage{
                 {
                     Role: openai.ChatMessageRoleUser,
@@ -133,35 +128,28 @@ func startChatCompletion(client openai.Client, config Config, content string) (s
     return resp.Choices[0].Message.Content, nil
 }
 
-func configFromFlags() Config {
-    modelTypeFlag := flag.String(modelTypeFlagShorthand, "gpt-4", typeFlagDescription)
-    markdownFlag := flag.Bool(markdownFlagShorthand, false, markdownFlagDescription)
-    quietFlag := flag.Bool(quietFlagShorthand, false, quietFlagDescription)
-    temperatureFlag := flag.Float64(temperatureFlagShorthand, 1.0, temperatureFlagDescription)
-    maxTokenFlag := flag.Int(maxTokensFlagShorthand, 0, maxTokensFlagDescription)
-    topPFlag := flag.Float64(topPFlagShorthand, 1.0, topPFlagDescription)
-    flag.Usage = printUsage
-    flag.Parse()
-    return Config{
-        Model:       *modelTypeFlag,
-        Quiet:       *quietFlag,
-        MaxTokens:   *maxTokenFlag,
-        Markdown:    *markdownFlag,
-        Temperature: flagToFloat(temperatureFlag),
-        TopP:        flagToFloat(topPFlag),
-    }
+func handleError(err error, reason string) {
+    fmt.Println()
+    fmt.Println(errorStyle.Render(" Error: %s", reason))
+    fmt.Println()
+    fmt.Println(" " + errorStyle.Render(err.Error()))
+    fmt.Println()
+    os.Exit(1)
 }
 
 func main() {
-    config := configFromFlags()
+    flag.Usage = usage
+    flag.CommandLine.SortFlags = false
+    config := newConfig()
+    flag.Parse()
     client := createClient(os.Getenv("OPENAI_API_KEY"))
     content := readStdinContent()
     prefix := strings.Join(flag.Args(), " ")
     if prefix == "" && content == "" {
-        printUsage()
+        flag.Usage()
         os.Exit(0)
     }
-    if config.Markdown {
+    if *config.Markdown {
         prefix = fmt.Sprintf("%s Format output as Markdown.", prefix)
     }
     if prefix != "" {
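
Two of the lines added to main() above configure pflag's help output: pflag sorts flags alphabetically in usage listings by default, so SortFlags = false keeps them in the order newConfig registers them, while assigning flag.Usage routes usage printing through the styled usage() function. A compressed sketch of that wiring (illustrative; the flag shown is made up):

package main

import (
    "fmt"
    "os"

    flag "github.com/spf13/pflag"
)

func main() {
    verbose := flag.BoolP("verbose", "v", false, "Verbose output.")

    // flag.Usage is the function pflag falls back to when it prints usage.
    flag.Usage = func() {
        fmt.Fprintf(os.Stderr, "Usage: %s [OPTIONS]\n\nOptions:\n", os.Args[0])
        flag.PrintDefaults() // renders "-v, --verbose" style entries
    }
    // SortFlags defaults to true, which lists flags alphabetically rather than
    // in registration order.
    flag.CommandLine.SortFlags = false
    flag.Parse()

    _ = *verbose
}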
@@ -169,35 +157,30 @@ func main() {
     }
 
     var p *tea.Program
-    if !config.Quiet {
+    var output string
+    var err error
+    if !*config.Quiet {
         lipgloss.SetColorProfile(termenv.NewOutput(os.Stderr).ColorProfile())
         spinner := spinner.New(spinner.WithSpinner(spinner.Dot), spinner.WithStyle(spinnerStyle))
         p = tea.NewProgram(Model{spinner: spinner}, tea.WithOutput(os.Stderr))
-    }
 
-    if !config.Quiet {
         go func() {
-            output, err := startChatCompletion(*client, config, content)
-            p.Send(quitMsg{})
+            output, err = startChatCompletion(*client, config, content)
+            p.Quit()
             if err != nil {
-                fmt.Println()
-                fmt.Println(errorStyle.Render(" Error: Unable to generate response."))
-                fmt.Println()
-                fmt.Println(" " + errorStyle.Render(err.Error()))
-                fmt.Println()
-                os.Exit(1)
+                handleError(err, "There was a problem with the OpenAI API.")
             }
-            fmt.Println(output)
         }()
-    } else {
-        output, err := startChatCompletion(*client, config, content)
-        if err != nil {
-            fmt.Println(err.Error())
-        }
-        fmt.Println(output)
-    }
 
-    if !config.Quiet {
-        _, _ = p.Run()
+        _, err = p.Run()
+        if err != nil {
+            handleError(err, "Can't run the Bubble Tea program.")
+        }
+    } else {
+        output, err = startChatCompletion(*client, config, content)
+        if err != nil {
+            handleError(err, "There was a problem with the OpenAI API.")
+        }
     }
+    fmt.Println(output)
 }