mirror of
https://github.com/charmbracelet/mods
synced 2024-10-18 23:32:17 +00:00
339 lines
8.8 KiB
Go
339 lines
8.8 KiB
Go
package main
|
||
|
||
import (
|
||
"bufio"
|
||
"context"
|
||
"errors"
|
||
"fmt"
|
||
"io"
|
||
"math"
|
||
"net/http"
|
||
"os"
|
||
"strings"
|
||
"time"
|
||
|
||
tea "github.com/charmbracelet/bubbletea"
|
||
"github.com/charmbracelet/lipgloss"
|
||
"github.com/mattn/go-isatty"
|
||
openai "github.com/sashabaranov/go-openai"
|
||
)
|
||
|
||
// markdownPrefix is appended to the prompt prefix when Markdown output is
// requested via the config's Markdown flag.
const (
	markdownPrefix = "Format the response as Markdown."
)
|
||
|
||
// state records which lifecycle phase the Mods model is currently in.
type state int

// Lifecycle phases of the Mods model.
const (
	startState        state = iota // nothing loaded yet
	configLoadedState              // config received; stdin is being read
	completionState                // a completion request is in flight
	errorState                     // a modsError was received
)
|
||
|
||
// Mods is the Bubble Tea model that manages reading stdin and querying the
|
||
// OpenAI API.
|
||
type Mods struct {
|
||
Config config
|
||
Output string
|
||
Input string
|
||
Error *modsError
|
||
state state
|
||
retries int
|
||
styles styles
|
||
renderer *lipgloss.Renderer
|
||
anim tea.Model
|
||
width int
|
||
height int
|
||
}
|
||
|
||
func newMods(r *lipgloss.Renderer) *Mods {
|
||
s := makeStyles(r)
|
||
return &Mods{
|
||
state: startState,
|
||
renderer: r,
|
||
styles: s,
|
||
}
|
||
}
|
||
|
||
// completionInput is a tea.Msg carrying the text read from stdin, or the text
// to send again when a completion request is retried.
type completionInput struct {
	content string
}

// completionOutput is a tea.Msg carrying the completion text returned by the
// API.
type completionOutput struct {
	content string
}
|
||
|
||
// modsError wraps an error together with a human-readable reason; both are
// rendered for the user by the error view.
type modsError struct {
	err    error
	reason string
}

// Error implements the error interface by returning the wrapped error's
// message. When no error is wrapped it returns the reason instead of
// panicking on a nil dereference.
func (m modsError) Error() string {
	if m.err == nil {
		return m.reason
	}
	return m.err.Error()
}
|
||
|
||
// Init implements tea.Model.
|
||
func (m *Mods) Init() tea.Cmd {
|
||
return m.loadConfigCmd
|
||
}
|
||
|
||
// Update implements tea.Model.
|
||
func (m *Mods) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||
switch msg := msg.(type) {
|
||
case config:
|
||
m.Config = msg
|
||
m.state = configLoadedState
|
||
if m.Config.ShowHelp || m.Config.Version || m.Config.Settings {
|
||
return m, tea.Quit
|
||
}
|
||
m.anim = newAnim(m.Config.Fanciness, m.Config.StatusText, m.renderer, m.styles)
|
||
return m, tea.Batch(readStdinCmd, m.anim.Init())
|
||
case completionInput:
|
||
if msg.content == "" && m.Config.Prefix == "" {
|
||
return m, tea.Quit
|
||
}
|
||
if msg.content != "" {
|
||
m.Input = msg.content
|
||
}
|
||
m.state = completionState
|
||
return m, m.startCompletionCmd(msg.content)
|
||
case completionOutput:
|
||
m.Output = msg.content
|
||
return m, tea.Quit
|
||
case modsError:
|
||
m.Error = &msg
|
||
m.state = errorState
|
||
return m, tea.Quit
|
||
case tea.WindowSizeMsg:
|
||
m.width, m.height = msg.Width, msg.Height
|
||
case tea.KeyMsg:
|
||
switch msg.String() {
|
||
case "q", "ctrl+c":
|
||
return m, tea.Quit
|
||
}
|
||
}
|
||
if m.state == configLoadedState || m.state == completionState {
|
||
var cmd tea.Cmd
|
||
m.anim, cmd = m.anim.Update(msg)
|
||
return m, cmd
|
||
}
|
||
return m, nil
|
||
}
|
||
|
||
// View implements tea.Model.
|
||
func (m *Mods) View() string {
|
||
//nolint:exhaustive
|
||
switch m.state {
|
||
case errorState:
|
||
return m.ErrorView()
|
||
case completionState:
|
||
if !m.Config.Quiet {
|
||
return m.anim.View()
|
||
}
|
||
}
|
||
return ""
|
||
}
|
||
|
||
// ErrorView renders the currently set modsError
|
||
func (m Mods) ErrorView() string {
|
||
const maxWidth = 120
|
||
const horizontalPadding = 2
|
||
w := m.width - (horizontalPadding * 2)
|
||
if w > maxWidth {
|
||
w = maxWidth
|
||
}
|
||
s := m.renderer.NewStyle().Width(w).Padding(0, horizontalPadding)
|
||
return fmt.Sprintf(
|
||
"\n%s\n\n%s\n\n",
|
||
s.Render(m.styles.errorHeader.String(), m.Error.reason),
|
||
s.Render(m.styles.errorDetails.Render(m.Error.Error())),
|
||
)
|
||
}
|
||
|
||
// FormattedOutput returns the response from OpenAI with the user configured
|
||
// prefix and standard in settings.
|
||
func (m *Mods) FormattedOutput() string {
|
||
prefixFormat := "> %s\n\n---\n\n%s"
|
||
stdinFormat := "```\n%s```\n\n---\n\n%s"
|
||
out := m.Output
|
||
|
||
if m.Config.IncludePrompt != 0 {
|
||
if m.Config.IncludePrompt < 0 {
|
||
out = fmt.Sprintf(stdinFormat, m.Input, out)
|
||
}
|
||
scanner := bufio.NewScanner(strings.NewReader(m.Input))
|
||
i := 0
|
||
in := ""
|
||
for scanner.Scan() {
|
||
if i == m.Config.IncludePrompt {
|
||
break
|
||
}
|
||
in += (scanner.Text() + "\n")
|
||
i++
|
||
}
|
||
out = fmt.Sprintf(stdinFormat, in, out)
|
||
}
|
||
|
||
if m.Config.IncludePromptArgs || m.Config.IncludePrompt != 0 {
|
||
out = fmt.Sprintf(prefixFormat, m.Config.Prefix, out)
|
||
}
|
||
|
||
return out
|
||
}
|
||
|
||
func (m *Mods) retry(content string, err modsError) tea.Msg {
|
||
m.retries++
|
||
if m.retries >= m.Config.MaxRetries {
|
||
return err
|
||
}
|
||
wait := time.Millisecond * 100 * time.Duration(math.Pow(2, float64(m.retries))) //nolint:gomnd
|
||
time.Sleep(wait)
|
||
return completionInput{content}
|
||
}
|
||
|
||
func (m *Mods) loadConfigCmd() tea.Msg {
|
||
cfg, err := newConfig()
|
||
if err != nil {
|
||
return modsError{err, "There was an error in your config file."}
|
||
}
|
||
return cfg
|
||
}
|
||
|
||
func (m *Mods) startCompletionCmd(content string) tea.Cmd {
|
||
return func() tea.Msg {
|
||
var ok bool
|
||
var mod Model
|
||
var key string
|
||
cfg := m.Config
|
||
mod, ok = cfg.Models[cfg.Model]
|
||
if !ok {
|
||
if cfg.API == "" {
|
||
return modsError{
|
||
reason: "Model " + m.styles.inlineCode.Render(cfg.Model) + " is not in the settings file.",
|
||
err: fmt.Errorf("Please specify an API endpoint with %s or configure the model in the settings: %s", m.styles.inlineCode.Render("--api"), m.styles.inlineCode.Render("mods -s")),
|
||
}
|
||
}
|
||
mod.Name = cfg.Model
|
||
mod.API = cfg.API
|
||
mod.MaxChars = cfg.MaxInputChars
|
||
}
|
||
|
||
if mod.API == "openai" {
|
||
key = os.Getenv("OPENAI_API_KEY")
|
||
if key == "" {
|
||
return modsError{
|
||
reason: m.styles.inlineCode.Render("OPENAI_API_KEY") + " environment variabled is required.",
|
||
err: fmt.Errorf("You can grab one at %s", m.styles.link.Render("https://platform.openai.com/account/api-keys.")),
|
||
}
|
||
}
|
||
}
|
||
ccfg := openai.DefaultConfig(key)
|
||
api, ok := cfg.APIs[mod.API]
|
||
if !ok {
|
||
eps := make([]string, 0)
|
||
for k := range cfg.APIs {
|
||
eps = append(eps, m.styles.inlineCode.Render(k))
|
||
}
|
||
return modsError{
|
||
reason: fmt.Sprintf("The API endpoint %s is not configured ", m.styles.inlineCode.Render(cfg.API)),
|
||
err: fmt.Errorf("Your configured API endpoints are: %s", eps),
|
||
}
|
||
}
|
||
ccfg.BaseURL = api.BaseURL
|
||
client := openai.NewClientWithConfig(ccfg)
|
||
ctx, cancel := context.WithCancel(context.Background())
|
||
defer cancel()
|
||
prefix := cfg.Prefix
|
||
if cfg.Markdown {
|
||
prefix = fmt.Sprintf("%s %s", prefix, markdownPrefix)
|
||
}
|
||
if prefix != "" {
|
||
content = strings.TrimSpace(prefix + "\n\n" + content)
|
||
}
|
||
|
||
if !cfg.NoLimit {
|
||
if len(content) > mod.MaxChars {
|
||
content = content[:mod.MaxChars]
|
||
}
|
||
}
|
||
|
||
resp, err := client.CreateChatCompletion(
|
||
ctx,
|
||
openai.ChatCompletionRequest{
|
||
Model: mod.Name,
|
||
Temperature: noOmitFloat(cfg.Temperature),
|
||
TopP: noOmitFloat(cfg.TopP),
|
||
MaxTokens: cfg.MaxTokens,
|
||
Messages: []openai.ChatCompletionMessage{
|
||
{
|
||
Role: openai.ChatMessageRoleUser,
|
||
Content: content,
|
||
},
|
||
},
|
||
},
|
||
)
|
||
ae := &openai.APIError{}
|
||
if errors.As(err, &ae) {
|
||
switch ae.HTTPStatusCode {
|
||
case http.StatusNotFound:
|
||
if mod.Fallback != "" {
|
||
m.Config.Model = mod.Fallback
|
||
return m.retry(content, modsError{err: err, reason: "OpenAI API server error."})
|
||
}
|
||
return modsError{err: err, reason: fmt.Sprintf("Missing model '%s' for API '%s'", cfg.Model, cfg.API)}
|
||
case http.StatusBadRequest:
|
||
if ae.Code == "context_length_exceeded" {
|
||
pe := modsError{err: err, reason: "Maximum prompt size exceeded."}
|
||
if cfg.NoLimit {
|
||
return pe
|
||
}
|
||
return m.retry(content[:len(content)-10], pe)
|
||
}
|
||
// bad request (do not retry)
|
||
return modsError{err: err, reason: "OpenAI API request error."}
|
||
case http.StatusUnauthorized:
|
||
// invalid auth or key (do not retry)
|
||
return modsError{err: err, reason: "Invalid OpenAI API key."}
|
||
case http.StatusTooManyRequests:
|
||
// rate limiting or engine overload (wait and retry)
|
||
return m.retry(content, modsError{err: err, reason: "You’ve hit your OpenAI API rate limit."})
|
||
case http.StatusInternalServerError:
|
||
if mod.API == "openai" {
|
||
return m.retry(content, modsError{err: err, reason: "OpenAI API server error."})
|
||
}
|
||
return modsError{err: err, reason: fmt.Sprintf("Error loading model '%s' for API '%s'", mod.Name, mod.API)}
|
||
default:
|
||
return m.retry(content, modsError{err: err, reason: "Unknown OpenAI API error."})
|
||
}
|
||
}
|
||
|
||
if err != nil {
|
||
return modsError{err: err, reason: "There was a problem with the OpenAI API request."}
|
||
}
|
||
return completionOutput{resp.Choices[0].Message.Content}
|
||
}
|
||
}
|
||
|
||
func readStdinCmd() tea.Msg {
|
||
if !isatty.IsTerminal(os.Stdin.Fd()) {
|
||
reader := bufio.NewReader(os.Stdin)
|
||
stdinBytes, err := io.ReadAll(reader)
|
||
if err != nil {
|
||
return modsError{err, "Unable to read stdin."}
|
||
}
|
||
return completionInput{string(stdinBytes)}
|
||
}
|
||
return completionInput{""}
|
||
}
|
||
|
||
// noOmitFloat protects a zero float from being dropped by the OpenAI client
// library. The library sends its float32 request fields with the omitempty
// tag, so a literal 0.0 is stripped from the request and the server falls back
// to its default value. Zero is therefore replaced with
// math.SmallestNonzeroFloat32; every other value is returned unchanged.
// Issue: https://github.com/sashabaranov/go-openai/issues/9
func noOmitFloat(f float32) float32 {
	if f != 0 {
		return f
	}
	return math.SmallestNonzeroFloat32
}
|