diff --git a/config.go b/config.go
index e5f5ce9..3802c71 100644
--- a/config.go
+++ b/config.go
@@ -17,37 +17,34 @@ import (
 
 const configTemplate = `
 # {{ index .Help "apis" }}
+# LocalAI setup instructions: https://github.com/go-skynet/LocalAI#example-use-gpt4all-j-model
 apis:
-  openai: https://api.openai.com/v1
-  # LocalAI setup instructions: https://github.com/go-skynet/LocalAI#example-use-gpt4all-j-model
-  localai: http://localhost:8080
-# {{ index .Help "models" }}
-models:
-  gpt-4:
-    aliases: ["4"]
-    max-input-chars: 24500
-    api: openai
-    fallback: gpt-3.5-turbo
-  gpt-4-32k:
-    aliases: ["32k"]
-    max-input-chars: 98000
-    api: openai
-    fallback: gpt-4
-  gpt-3.5-turbo:
-    aliases: ["35t"]
-    max-input-chars: 12250
-    api: openai
-    fallback: gpt-3.5
-  gpt-3.5:
-    aliases: ["35"]
-    max-input-chars: 12250
-    api: openai
-    fallback:
-  ggml-gpt4all-j:
-    aliases: ["local", "4all"]
-    max-input-chars: 12250
-    api: localai
-    fallback:
+  openai:
+    base-url: https://api.openai.com/v1
+    models:
+      gpt-4:
+        aliases: ["4"]
+        max-input-chars: 24500
+        fallback: gpt-3.5-turbo
+      gpt-4-32k:
+        aliases: ["32k"]
+        max-input-chars: 98000
+        fallback: gpt-4
+      gpt-3.5-turbo:
+        aliases: ["35t"]
+        max-input-chars: 12250
+        fallback: gpt-3.5
+      gpt-3.5:
+        aliases: ["35"]
+        max-input-chars: 12250
+        fallback:
+  localai:
+    base-url: http://localhost:8080
+    models:
+      ggml-gpt4all-j:
+        aliases: ["local", "4all"]
+        max-input-chars: 12250
+        fallback:
 # {{ index .Help "model" }}
 default-model: gpt-4
 # {{ index .Help "max-input-chars" }}
@@ -77,22 +74,22 @@ status-text: Generating
 `
 
 type config struct {
-	API               string            `yaml:"api" env:"API"`
-	APIs              map[string]string `yaml:"apis"`
-	Model             string            `yaml:"default-model" env:"MODEL"`
-	Models            map[string]Model  `yaml:"models"`
-	Markdown          bool              `yaml:"format" env:"FORMAT"`
-	Quiet             bool              `yaml:"quiet" env:"QUIET"`
-	MaxTokens         int               `yaml:"max-tokens" env:"MAX_TOKENS"`
-	MaxInputChars     int               `yaml:"max-input-chars" env:"MAX_INPUT_CHARS"`
-	Temperature       float32           `yaml:"temp" env:"TEMP"`
-	TopP              float32           `yaml:"topp" env:"TOPP"`
-	NoLimit           bool              `yaml:"no-limit" env:"NO_LIMIT"`
-	IncludePromptArgs bool              `yaml:"include-prompt-args" env:"INCLUDE_PROMPT_ARGS"`
-	IncludePrompt     int               `yaml:"include-prompt" env:"INCLUDE_PROMPT"`
-	MaxRetries        int               `yaml:"max-retries" env:"MAX_RETRIES"`
-	Fanciness         uint              `yaml:"fanciness" env:"FANCINESS"`
-	StatusText        string            `yaml:"status-text" env:"STATUS_TEXT"`
+	APIs              map[string]API `yaml:"apis"`
+	Model             string         `yaml:"default-model" env:"MODEL"`
+	Markdown          bool           `yaml:"format" env:"FORMAT"`
+	Quiet             bool           `yaml:"quiet" env:"QUIET"`
+	MaxTokens         int            `yaml:"max-tokens" env:"MAX_TOKENS"`
+	MaxInputChars     int            `yaml:"max-input-chars" env:"MAX_INPUT_CHARS"`
+	Temperature       float32        `yaml:"temp" env:"TEMP"`
+	TopP              float32        `yaml:"topp" env:"TOPP"`
+	NoLimit           bool           `yaml:"no-limit" env:"NO_LIMIT"`
+	IncludePromptArgs bool           `yaml:"include-prompt-args" env:"INCLUDE_PROMPT_ARGS"`
+	IncludePrompt     int            `yaml:"include-prompt" env:"INCLUDE_PROMPT"`
+	MaxRetries        int            `yaml:"max-retries" env:"MAX_RETRIES"`
+	Fanciness         uint           `yaml:"fanciness" env:"FANCINESS"`
+	StatusText        string         `yaml:"status-text" env:"STATUS_TEXT"`
+	API               string
+	Models            map[string]Model
 	ShowHelp          bool
 	Prefix            string
 	Version           bool
@@ -105,10 +102,9 @@ func newConfig() (config, error) {
 	var content []byte
 
 	help := map[string]string{
-		"api":             "Default OpenAI compatible REST API (openai, localai).",
+		"api":             "OpenAI compatible REST API (openai, localai).",
 		"apis":            "Aliases and endpoints for OpenAI compatible REST API.",
 		"model":           "Default model (gpt-3.5-turbo, gpt-4, ggml-gpt4all-j...).",
-		"models":          "Model details and aliases.",
 		"max-input-chars": "Default character limit on input to model.",
 		"format":          "Format response as markdown.",
 		"prompt":          "Include the prompt from the arguments and stdin, truncate stdin to specified number of lines.",
@@ -168,13 +164,15 @@ func newConfig() (config, error) {
 		return c, err
 	}
 
-	// Set model aliases
 	ms := make(map[string]Model)
-	for k, m := range c.Models {
-		m.Name = k
-		ms[k] = m
-		for _, am := range m.Aliases {
-			ms[am] = m
+	for ak, av := range c.APIs {
+		for mk, mv := range av.Models {
+			mv.Name = mk
+			mv.API = ak
+			ms[mk] = mv
+			for _, a := range mv.Aliases {
+				ms[a] = mv
+			}
 		}
 	}
 	c.Models = ms
diff --git a/model.go b/model.go
index e0f18b7..2c7a993 100644
--- a/model.go
+++ b/model.go
@@ -3,8 +3,14 @@ package main
 // Model represents the LLM model used in the API call.
 type Model struct {
 	Name     string
+	API      string
 	MaxChars int      `yaml:"max-input-chars"`
 	Aliases  []string `yaml:"aliases"`
-	API      string   `yaml:"api"`
 	Fallback string   `yaml:"fallback"`
 }
+
+// API represents an API endpoint and its models.
+type API struct {
+	BaseURL string           `yaml:"base-url"`
+	Models  map[string]Model `yaml:"models"`
+}
diff --git a/mods.go b/mods.go
index cc913af..a42426d 100644
--- a/mods.go
+++ b/mods.go
@@ -226,7 +226,7 @@ func (m *Mods) startCompletionCmd(content string) tea.Cmd {
 			}
 		}
 		ccfg := openai.DefaultConfig(key)
-		ccfg.BaseURL, ok = cfg.APIs[mod.API]
+		api, ok := cfg.APIs[mod.API]
 		if !ok {
 			eps := make([]string, 0)
 			for k := range cfg.APIs {
@@ -237,6 +237,7 @@ func (m *Mods) startCompletionCmd(content string) tea.Cmd {
 				err: fmt.Errorf("Your configured API endpoints are: %s", eps),
 			}
 		}
+		ccfg.BaseURL = api.BaseURL
 		client := openai.NewClientWithConfig(ccfg)
 		ctx, cancel := context.WithCancel(context.Background())
 		defer cancel()
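
Reviewer note on the new flattening pass: since models now live under their API in the config, `newConfig` has to rebuild the flat `cfg.Models` lookup that the rest of the code expects. Below is a minimal, self-contained sketch (not part of the diff) of that pass; `demoAPIs` is a hypothetical stand-in for the parsed `apis:` section.

```go
package main

import "fmt"

// Model and API mirror the structs from model.go in this diff.
type Model struct {
	Name     string
	API      string
	MaxChars int
	Aliases  []string
	Fallback string
}

type API struct {
	BaseURL string
	Models  map[string]Model
}

func main() {
	// Hypothetical stand-in for the parsed `apis:` config section.
	demoAPIs := map[string]API{
		"openai": {
			BaseURL: "https://api.openai.com/v1",
			Models: map[string]Model{
				"gpt-4": {Aliases: []string{"4"}, MaxChars: 24500, Fallback: "gpt-3.5-turbo"},
			},
		},
	}

	// Same shape as the new loop in newConfig: index every model under its
	// canonical name and each alias, tagging it with the owning API so
	// startCompletionCmd can later resolve the matching base-url.
	ms := make(map[string]Model)
	for ak, av := range demoAPIs {
		for mk, mv := range av.Models {
			mv.Name = mk
			mv.API = ak
			ms[mk] = mv
			for _, a := range mv.Aliases {
				ms[a] = mv
			}
		}
	}

	fmt.Println(ms["4"].Name, ms["4"].API) // gpt-4 openai
}
```

One behavior worth flagging in review: Go map iteration order is randomized, so if two APIs ever declare the same model name or alias, which entry wins in the flattened map is nondeterministic.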