mods/mods.go

package main

import (
	"bufio"
	"context"
	"errors"
	"fmt"
	"io"
	"math"
	"net/http"
	"os"
	"strings"
	"time"

	tea "github.com/charmbracelet/bubbletea"
	"github.com/charmbracelet/lipgloss"
	"github.com/mattn/go-isatty"
	openai "github.com/sashabaranov/go-openai"
)

const markdownPrefix = "Format the response as Markdown."

type state int

const (
	startState state = iota
	configLoadedState
	completionState
	errorState
)
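
// The model advances through these states in order: startState until the
// config loads, configLoadedState while waiting on stdin, completionState
// while the request is in flight, and errorState if any step fails (see
// Update below).
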
// Mods is the Bubble Tea model that manages reading stdin and querying the
// OpenAI API.
type Mods struct {
	Config   config
	Output   string
	Input    string
	Error    *modsError
	state    state
	retries  int
	styles   styles
	renderer *lipgloss.Renderer
	anim     tea.Model
	width    int
	height   int
}

func newMods(r *lipgloss.Renderer) *Mods {
	s := makeStyles(r)
	return &Mods{
		state:    startState,
		renderer: r,
		styles:   s,
	}
}
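
// A typical way to run this model (a sketch; the actual wiring lives in this
// package's main.go):
//
//	p := tea.NewProgram(newMods(lipgloss.DefaultRenderer()))
//	if _, err := p.Run(); err != nil {
//		// handle the error
//	}
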
// completionInput is a tea.Msg that wraps the content read from stdin.
type completionInput struct{ content string }

// completionOutput is a tea.Msg that wraps the content returned from the
// OpenAI API.
type completionOutput struct{ content string }

// modsError is a wrapper around an error that adds additional context.
type modsError struct {
	err    error
	reason string
}

func (m modsError) Error() string {
	return m.err.Error()
}

// Init implements tea.Model.
func (m *Mods) Init() tea.Cmd {
	return m.loadConfigCmd
}

// Update implements tea.Model.
func (m *Mods) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
	switch msg := msg.(type) {
	case config:
		m.Config = msg
		m.state = configLoadedState
		if m.Config.ShowHelp || m.Config.Version || m.Config.Settings {
			return m, tea.Quit
		}
		m.anim = newAnim(m.Config.Fanciness, m.Config.StatusText, m.renderer, m.styles)
		return m, tea.Batch(readStdinCmd, m.anim.Init())
	case completionInput:
		if msg.content == "" && m.Config.Prefix == "" {
			return m, tea.Quit
		}
		if msg.content != "" {
			m.Input = msg.content
		}
		m.state = completionState
		return m, m.startCompletionCmd(msg.content)
	case completionOutput:
		m.Output = msg.content
		return m, tea.Quit
	case modsError:
		m.Error = &msg
		m.state = errorState
		return m, tea.Quit
	case tea.WindowSizeMsg:
		m.width, m.height = msg.Width, msg.Height
	case tea.KeyMsg:
		switch msg.String() {
		case "q", "ctrl+c":
			return m, tea.Quit
		}
	}
	if m.state == configLoadedState || m.state == completionState {
		var cmd tea.Cmd
		m.anim, cmd = m.anim.Update(msg)
		return m, cmd
	}
	return m, nil
}

// View implements tea.Model.
func (m *Mods) View() string {
	//nolint:exhaustive
	switch m.state {
	case errorState:
		return m.ErrorView()
	case completionState:
		if !m.Config.Quiet {
			return m.anim.View()
		}
	}
	return ""
}

// ErrorView renders the currently set modsError.
func (m Mods) ErrorView() string {
	const maxWidth = 120
	const horizontalPadding = 2
	w := m.width - (horizontalPadding * 2)
	if w > maxWidth {
		w = maxWidth
	}
	s := m.renderer.NewStyle().Width(w).Padding(0, horizontalPadding)
	return fmt.Sprintf(
		"\n%s\n\n%s\n\n",
		s.Render(m.styles.errorHeader.String(), m.Error.reason),
		s.Render(m.styles.errorDetails.Render(m.Error.Error())),
	)
}

// FormattedOutput returns the response from OpenAI with the user-configured
// prefix and standard input settings applied.
func (m *Mods) FormattedOutput() string {
	prefixFormat := "> %s\n\n---\n\n%s"
	stdinFormat := "```\n%s```\n\n---\n\n%s"
	out := m.Output
	if m.Config.IncludePrompt != 0 {
		if m.Config.IncludePrompt < 0 {
			// A negative IncludePrompt includes all of stdin.
			out = fmt.Sprintf(stdinFormat, m.Input, out)
		} else {
			// A positive IncludePrompt includes only that many leading lines
			// of stdin.
			scanner := bufio.NewScanner(strings.NewReader(m.Input))
			i := 0
			in := ""
			for scanner.Scan() {
				if i == m.Config.IncludePrompt {
					break
				}
				in += (scanner.Text() + "\n")
				i++
			}
			out = fmt.Sprintf(stdinFormat, in, out)
		}
	}
	if m.Config.IncludePromptArgs || m.Config.IncludePrompt != 0 {
		out = fmt.Sprintf(prefixFormat, m.Config.Prefix, out)
	}
	return out
}
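
// For instance, with IncludePromptArgs set and a prefix of "summarize", a
// response of "All done." renders as:
//
//	> summarize
//
//	---
//
//	All done.
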
func (m *Mods) retry(content string, err modsError) tea.Msg {
	m.retries++
	if m.retries >= m.Config.MaxRetries {
		return err
	}
	wait := time.Millisecond * 100 * time.Duration(math.Pow(2, float64(m.retries))) //nolint:gomnd
	time.Sleep(wait)
	return completionInput{content}
}
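
// A worked example of the backoff above: the first retry sets m.retries to 1
// and waits 100ms*2^1 = 200ms, the second waits 400ms, the third 800ms, and
// so on until MaxRetries is exhausted.
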
func (m *Mods) loadConfigCmd() tea.Msg {
	cfg, err := newConfig()
	if err != nil {
		return modsError{err, "There was an error in your config file."}
	}
	return cfg
}

func (m *Mods) startCompletionCmd(content string) tea.Cmd {
	return func() tea.Msg {
		var ok bool
		var mod Model
		var key string
		cfg := m.Config
		mod, ok = cfg.Models[cfg.Model]
		if !ok {
			if cfg.API == "" {
				return modsError{
					reason: "Model " + m.styles.inlineCode.Render(cfg.Model) + " is not in the settings file.",
					err:    fmt.Errorf("Please specify an API endpoint with %s or configure the model in the settings: %s", m.styles.inlineCode.Render("--api"), m.styles.inlineCode.Render("mods -s")),
				}
			}
			mod.Name = cfg.Model
			mod.API = cfg.API
			mod.MaxChars = cfg.MaxInputChars
		}
		if mod.API == "openai" {
			key = os.Getenv("OPENAI_API_KEY")
			if key == "" {
				return modsError{
					reason: m.styles.inlineCode.Render("OPENAI_API_KEY") + " environment variable is required.",
					err:    fmt.Errorf("You can grab one at %s", m.styles.link.Render("https://platform.openai.com/account/api-keys.")),
				}
			}
		}
		ccfg := openai.DefaultConfig(key)
		api, ok := cfg.APIs[mod.API]
		if !ok {
			eps := make([]string, 0)
			for k := range cfg.APIs {
				eps = append(eps, m.styles.inlineCode.Render(k))
			}
			return modsError{
				reason: fmt.Sprintf("The API endpoint %s is not configured.", m.styles.inlineCode.Render(mod.API)),
				err:    fmt.Errorf("Your configured API endpoints are: %s", strings.Join(eps, ", ")),
			}
		}
		ccfg.BaseURL = api.BaseURL
		client := openai.NewClientWithConfig(ccfg)
		ctx, cancel := context.WithCancel(context.Background())
		defer cancel()
		prefix := cfg.Prefix
		if cfg.Markdown {
			prefix = fmt.Sprintf("%s %s", prefix, markdownPrefix)
		}
		if prefix != "" {
			content = strings.TrimSpace(prefix + "\n\n" + content)
		}
		if !cfg.NoLimit {
			if len(content) > mod.MaxChars {
				content = content[:mod.MaxChars]
			}
		}
		resp, err := client.CreateChatCompletion(
			ctx,
			openai.ChatCompletionRequest{
				Model:       mod.Name,
				Temperature: noOmitFloat(cfg.Temperature),
				TopP:        noOmitFloat(cfg.TopP),
				MaxTokens:   cfg.MaxTokens,
				Messages: []openai.ChatCompletionMessage{
					{
						Role:    openai.ChatMessageRoleUser,
						Content: content,
					},
				},
			},
		)
		ae := &openai.APIError{}
		if errors.As(err, &ae) {
			switch ae.HTTPStatusCode {
			case http.StatusNotFound:
				if mod.Fallback != "" {
					m.Config.Model = mod.Fallback
					return m.retry(content, modsError{err: err, reason: "OpenAI API server error."})
				}
				return modsError{err: err, reason: fmt.Sprintf("Missing model '%s' for API '%s'", cfg.Model, cfg.API)}
			case http.StatusBadRequest:
				if ae.Code == "context_length_exceeded" {
					pe := modsError{err: err, reason: "Maximum prompt size exceeded."}
					if cfg.NoLimit {
						return pe
					}
					// Trim the content a bit and try again.
					return m.retry(content[:len(content)-10], pe)
				}
				// bad request (do not retry)
				return modsError{err: err, reason: "OpenAI API request error."}
			case http.StatusUnauthorized:
				// invalid auth or key (do not retry)
				return modsError{err: err, reason: "Invalid OpenAI API key."}
			case http.StatusTooManyRequests:
				// rate limiting or engine overload (wait and retry)
				return m.retry(content, modsError{err: err, reason: "You've hit your OpenAI API rate limit."})
			case http.StatusInternalServerError:
				if mod.API == "openai" {
					return m.retry(content, modsError{err: err, reason: "OpenAI API server error."})
				}
				return modsError{err: err, reason: fmt.Sprintf("Error loading model '%s' for API '%s'", mod.Name, mod.API)}
			default:
				return m.retry(content, modsError{err: err, reason: "Unknown OpenAI API error."})
			}
		}
		if err != nil {
			return modsError{err: err, reason: "There was a problem with the OpenAI API request."}
		}
		return completionOutput{resp.Choices[0].Message.Content}
	}
}

func readStdinCmd() tea.Msg {
	if !isatty.IsTerminal(os.Stdin.Fd()) {
		reader := bufio.NewReader(os.Stdin)
		stdinBytes, err := io.ReadAll(reader)
		if err != nil {
			return modsError{err, "Unable to read stdin."}
		}
		return completionInput{string(stdinBytes)}
	}
	return completionInput{""}
}
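
// In effect, piped input becomes the prompt body while an interactive
// terminal contributes nothing, e.g. (illustrative invocation):
//
//	cat main.go | mods "explain this code"
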
// noOmitFloat converts a 0.0 value to a float usable by the OpenAI client
// library, which currently uses Float32 fields in the request struct with the
// omitempty tag. This means we need to use math.SmallestNonzeroFloat32 instead
// of 0.0 so it doesn't get stripped from the request and replaced server side
// with the default values.
// Issue: https://github.com/sashabaranov/go-openai/issues/9
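//
// As a rough sketch of the encoding/json behavior being worked around (req is
// a hypothetical struct here, not the library's actual request type):
//
//	type req struct {
//		Temperature float32 `json:"temperature,omitempty"`
//	}
//	json.Marshal(req{Temperature: 0})              // {} (field dropped)
//	json.Marshal(req{Temperature: noOmitFloat(0)}) // {"temperature":1e-45} (field kept)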
func noOmitFloat(f float32) float32 {
	if f == 0.0 {
		return math.SmallestNonzeroFloat32
	}
	return f
}