Mirror of https://github.com/charmbracelet/mods, synced 2024-10-18 15:22:17 +00:00
feat: add support for Cohere (#294)

parent ff9a598b20
commit fe8a551e66

README.md | 10 +++++++++-
@@ -17,10 +17,11 @@ tool to add a sprinkle of AI in your command line and make your pipelines
 artificially intelligent.
 
 It works great with LLMs running locally through [LocalAI]. You can
-also use [OpenAI], [Groq], or [Azure OpenAI].
+also use [OpenAI], [Cohere], [Groq], or [Azure OpenAI].
 
 [LocalAI]: https://github.com/go-skynet/LocalAI
 [OpenAI]: https://platform.openai.com/account/api-keys
+[Cohere]: https://dashboard.cohere.com/api-keys
 [Groq]: https://console.groq.com/keys
 [AzureOpenAI]: https://azure.microsoft.com/en-us/products/cognitive-services/openai-service

@@ -196,6 +197,13 @@ can grab it the [OpenAI website](https://platform.openai.com/account/api-keys).
 Alternatively, set the [`AZURE_OPENAI_KEY`] environment variable to use Azure
 OpenAI. Grab a key from [Azure](https://azure.microsoft.com/en-us/products/cognitive-services/openai-service).
 
+### Cohere
+
+Cohere provides enterprise-optimized models.
+
+Set the `COHERE_API_KEY` environment variable. If you don't have one yet, you can
+get it from the [Cohere dashboard](https://dashboard.cohere.com/api-keys).
+
 ### Local AI
 
 Local AI allows you to run models locally. Mods works with the GPT4ALL-J model
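As a concrete illustration of the setup the README section above describes, here is a minimal hypothetical sketch, not part of the commit, of how a consumer reads the `COHERE_API_KEY` environment variable; only the variable name and dashboard URL come from the README, the fail-early behavior is an assumption:

```go
package main

import (
	"fmt"
	"os"
)

func main() {
	// The README instructs users to export COHERE_API_KEY; a consumer
	// would typically fail early when the variable is unset.
	if os.Getenv("COHERE_API_KEY") == "" {
		fmt.Fprintln(os.Stderr, "COHERE_API_KEY is not set; get a key at https://dashboard.cohere.com/api-keys")
		os.Exit(1)
	}
	fmt.Println("Cohere API key found")
}
```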
cohere.go (new file) | 153 +++++++++++++++++++++++++++++++++++++++++++++++++++

@@ -0,0 +1,153 @@
package main

import (
	"context"
	"encoding/json"
	"errors"
	"io"
	"net/http"

	cohere "github.com/cohere-ai/cohere-go/v2"
	"github.com/cohere-ai/cohere-go/v2/client"
	"github.com/cohere-ai/cohere-go/v2/core"
	coherecore "github.com/cohere-ai/cohere-go/v2/core"
	"github.com/cohere-ai/cohere-go/v2/option"
	openai "github.com/sashabaranov/go-openai"
)

// CohereClientConfig represents the configuration for the Cohere API client.
type CohereClientConfig struct {
	AuthToken          string
	BaseURL            string
	HTTPClient         *http.Client
	EmptyMessagesLimit uint
}

// DefaultCohereConfig returns the default configuration for the Cohere API client.
func DefaultCohereConfig(authToken string) CohereClientConfig {
	return CohereClientConfig{
		AuthToken:  authToken,
		BaseURL:    "",
		HTTPClient: &http.Client{},
	}
}

// CohereClient is a client for the Cohere API.
type CohereClient struct {
	*client.Client
}

// NewCohereClientWithConfig creates a new [CohereClient] with the given
// configuration.
func NewCohereClientWithConfig(config CohereClientConfig) *CohereClient {
	opts := []option.RequestOption{
		client.WithToken(config.AuthToken),
		client.WithHTTPClient(config.HTTPClient),
	}

	if config.BaseURL != "" {
		opts = append(opts, client.WithBaseURL(config.BaseURL))
	}

	return &CohereClient{
		Client: client.NewClient(opts...),
	}
}

// CohereChatCompletionStream represents a stream for chat completion.
type CohereChatCompletionStream struct {
	*cohereStreamReader
}

type cohereStreamReader struct {
	*core.Stream[cohere.StreamedChatResponse]
}

// Recv reads the next response from the stream.
func (stream *cohereStreamReader) Recv() (openai.ChatCompletionStreamResponse, error) {
	return stream.processMessages()
}

// Close closes the stream.
func (stream *cohereStreamReader) Close() error {
	return stream.Stream.Close()
}

func (stream *cohereStreamReader) processMessages() (openai.ChatCompletionStreamResponse, error) {
	for {
		message, err := stream.Stream.Recv()
		if err != nil {
			if errors.Is(err, io.EOF) {
				return openai.ChatCompletionStreamResponse{}, io.EOF
			}
			return openai.ChatCompletionStreamResponse{}, err
		}

		// Skip events that carry no generated text, e.g. stream-start.
		if message.EventType != "text-generation" {
			continue
		}

		// NOTE: Leverage the existing logic based on OpenAI
		// ChatCompletionStreamResponse by converting the Cohere events
		// into it.
		response := openai.ChatCompletionStreamResponse{
			Choices: []openai.ChatCompletionStreamChoice{
				{
					Index: 0,
					Delta: openai.ChatCompletionStreamChoiceDelta{
						Content: message.TextGeneration.Text,
						Role:    "assistant",
					},
				},
			},
		}

		return response, nil
	}
}

// CreateChatCompletionStream calls the API to create a chat completion with
// streaming support.
func (c *CohereClient) CreateChatCompletionStream(
	ctx context.Context,
	request *cohere.ChatStreamRequest,
) (stream *CohereChatCompletionStream, err error) {
	resp, err := c.ChatStream(ctx, request)
	if err != nil {
		return
	}

	stream = &CohereChatCompletionStream{
		cohereStreamReader: &cohereStreamReader{
			Stream: resp,
		},
	}
	return
}

// CohereToOpenAIAPIError attempts to convert a Cohere API error into an
// OpenAI API error so the existing error handling logic can be reused.
func CohereToOpenAIAPIError(err error) error {
	ce := &coherecore.APIError{}
	if !errors.As(err, &ce) {
		return err
	}

	unwrapped := ce.Unwrap()
	if unwrapped == nil {
		unwrapped = err
	}

	// The Cohere API reports errors as a JSON body with a "message" field;
	// fall back to the raw error string when that shape doesn't match.
	var message string
	var body map[string]interface{}
	if err := json.Unmarshal([]byte(unwrapped.Error()), &body); err == nil {
		message, _ = body["message"].(string)
	}

	if message == "" {
		message = unwrapped.Error()
	}

	return &openai.APIError{
		HTTPStatusCode: ce.StatusCode,
		Message:        message,
	}
}
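To see the adapter in motion, here is a hypothetical standalone driver, not part of the commit: it assumes the cohere.go definitions above are in scope (same package), that `COHERE_API_KEY` is exported, and it picks `command-r`, one of the models registered in the config hunk below.

```go
package main

import (
	"context"
	"errors"
	"fmt"
	"io"
	"os"

	cohere "github.com/cohere-ai/cohere-go/v2"
)

func main() {
	// Build a client from the default config, the same way mods.go does.
	cfg := DefaultCohereConfig(os.Getenv("COHERE_API_KEY"))
	client := NewCohereClientWithConfig(cfg)

	stream, err := client.CreateChatCompletionStream(context.Background(), &cohere.ChatStreamRequest{
		Model:   cohere.String("command-r"),
		Message: "Say hello in one sentence.",
	})
	if err != nil {
		fmt.Fprintln(os.Stderr, CohereToOpenAIAPIError(err))
		os.Exit(1)
	}
	defer stream.Close() //nolint:errcheck

	// Recv yields OpenAI-shaped deltas until the Cohere stream is drained.
	for {
		resp, err := stream.Recv()
		if errors.Is(err, io.EOF) {
			break
		}
		if err != nil {
			fmt.Fprintln(os.Stderr, CohereToOpenAIAPIError(err))
			os.Exit(1)
		}
		fmt.Print(resp.Choices[0].Delta.Content)
	}
	fmt.Println()
}
```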

(API configuration template)

@@ -93,6 +93,13 @@ apis:
       claude-3-opus-20240229:
         aliases: ["claude3-opus", "opus"]
         max-input-chars: 680000
+  cohere:
+    base-url: https://api.cohere.com/v1
+    models:
+      command-r-plus:
+        max-input-chars: 128000
+      command-r:
+        max-input-chars: 128000
   ollama:
     base-url: http://localhost:11434/api
     models: # https://ollama.com/library

go.mod | 1 +

@@ -16,6 +16,7 @@ require (
 	github.com/charmbracelet/x/editor v0.0.0-20231116172829-450eedbca1ab
 	github.com/charmbracelet/x/exp/ordered v0.0.0-20231010190216-1cb11efc897d
 	github.com/charmbracelet/x/exp/strings v0.0.0-20240524151031-ff83003bf67a
+	github.com/cohere-ai/cohere-go/v2 v2.8.2
 	github.com/jmoiron/sqlx v1.3.5
 	github.com/lucasb-eyer/go-colorful v1.2.0
 	github.com/mattn/go-isatty v0.0.20

go.sum | 2 ++

@@ -42,6 +42,8 @@ github.com/charmbracelet/x/term v0.1.1 h1:3cosVAiPOig+EV4X9U+3LDgtwwAoEzJjNdwbXD
 github.com/charmbracelet/x/term v0.1.1/go.mod h1:wB1fHt5ECsu3mXYusyzcngVWWlu1KKUmmLhfgr/Flxw=
 github.com/charmbracelet/x/windows v0.1.2 h1:Iumiwq2G+BRmgoayww/qfcvof7W/3uLoelhxojXlRWg=
 github.com/charmbracelet/x/windows v0.1.2/go.mod h1:GLEO/l+lizvFDBPLIOk+49gdX49L9YWMB5t+DZd0jkQ=
+github.com/cohere-ai/cohere-go/v2 v2.8.2 h1:NtxtcqkJ3ZBj8DFgk/4hpOrGK7CGnllGNpQn1bkaqQs=
+github.com/cohere-ai/cohere-go/v2 v2.8.2/go.mod h1:dlDCT66i8BqZDuuskFvYzsrc+O0M4l5J9Ibckoflvt4=
 github.com/cpuguy83/go-md2man/v2 v2.0.2/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o=
 github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E=
 github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=

mods.go | 13 +++++++++++++

@@ -262,6 +262,7 @@ func (m *Mods) startCompletionCmd(content string) tea.Cmd {
 	var api API
 	var ccfg openai.ClientConfig
 	var accfg AnthropicClientConfig
+	var cccfg CohereClientConfig
 	var occfg OllamaClientConfig
 
 	cfg := m.Config

@@ -325,6 +326,15 @@ func (m *Mods) startCompletionCmd(content string) tea.Cmd {
 		if api.Version != "" {
 			accfg.Version = AnthropicAPIVersion(api.Version)
 		}
+	case "cohere":
+		key, err := m.ensureKey(api, "COHERE_API_KEY", "https://dashboard.cohere.com/api-keys")
+		if err != nil {
+			return err
+		}
+		cccfg = DefaultCohereConfig(key)
+		if api.BaseURL != "" {
+			cccfg.BaseURL = api.BaseURL
+		}
 	case "azure", "azure-ad":
 		key, err := m.ensureKey(api, "AZURE_OPENAI_KEY", "https://aka.ms/oai/access")
 		if err != nil {

@@ -353,6 +363,7 @@ func (m *Mods) startCompletionCmd(content string) tea.Cmd {
 		httpClient := &http.Client{Transport: &http.Transport{Proxy: http.ProxyURL(proxyURL)}}
 		ccfg.HTTPClient = httpClient
 		accfg.HTTPClient = httpClient
+		cccfg.HTTPClient = httpClient
 		occfg.HTTPClient = httpClient
 	}
 

@@ -363,6 +374,8 @@ func (m *Mods) startCompletionCmd(content string) tea.Cmd {
 	switch mod.API {
 	case "anthropic":
 		return m.createAnthropicStream(content, accfg, mod)
+	case "cohere":
+		return m.createCohereStream(content, cccfg, mod)
 	case "ollama":
 		return m.createOllamaStream(content, occfg, mod)
 	default:

stream.go | 63 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++

@@ -6,6 +6,7 @@ import (
 	"strings"
 
 	tea "github.com/charmbracelet/bubbletea"
+	cohere "github.com/cohere-ai/cohere-go/v2"
 	openai "github.com/sashabaranov/go-openai"
 )
 

@@ -126,6 +127,68 @@ func (m *Mods) createAnthropicStream(content string, accfg AnthropicClientConfig
 	return m.receiveCompletionStreamCmd(completionOutput{stream: stream})()
 }
 
+func (m *Mods) createCohereStream(content string, cccfg CohereClientConfig, mod Model) tea.Msg {
+	cfg := m.Config
+
+	client := NewCohereClientWithConfig(cccfg)
+	ctx, cancel := context.WithCancel(context.Background())
+	m.cancelRequest = cancel
+
+	if err := m.setupStreamContext(content, mod); err != nil {
+		return err
+	}
+
+	var messages []*cohere.Message
+	for _, message := range m.messages {
+		switch message.Role {
+		case openai.ChatMessageRoleSystem:
+			// For system messages, it is recommended to use the `preamble`
+			// field rather than a "SYSTEM" role message.
+			m.system += message.Content + "\n"
+		case openai.ChatMessageRoleAssistant:
+			messages = append(messages, &cohere.Message{
+				Role: "CHATBOT",
+				Chatbot: &cohere.ChatMessage{
+					Message: message.Content,
+				},
+			})
+		case openai.ChatMessageRoleUser:
+			messages = append(messages, &cohere.Message{
+				Role: "USER",
+				User: &cohere.ChatMessage{
+					Message: message.Content,
+				},
+			})
+		}
+	}
+
+	// The latest message is sent as the prompt; everything before it
+	// becomes the chat history.
+	var history []*cohere.Message
+	if len(messages) > 1 {
+		history = messages[:len(messages)-1]
+	}
+
+	req := &cohere.ChatStreamRequest{
+		Model:         cohere.String(mod.Name),
+		ChatHistory:   history,
+		Message:       messages[len(messages)-1].User.Message,
+		Preamble:      cohere.String(m.system),
+		Temperature:   cohere.Float64(float64(cfg.Temperature)),
+		P:             cohere.Float64(float64(cfg.TopP)),
+		StopSequences: cfg.Stop,
+	}
+
+	if cfg.MaxTokens > 0 {
+		req.MaxTokens = cohere.Int(cfg.MaxTokens)
+	}
+
+	stream, err := client.CreateChatCompletionStream(ctx, req)
+	if err != nil {
+		return m.handleRequestError(CohereToOpenAIAPIError(err), mod, content)
+	}
+
+	return m.receiveCompletionStreamCmd(completionOutput{stream: stream})()
+}
+
 func (m *Mods) setupStreamContext(content string, mod Model) error {
 	cfg := m.Config
 	m.messages = []openai.ChatCompletionMessage{}
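The subtle part of createCohereStream is the history split: Cohere's chat API takes the latest user message as `Message` and all earlier turns as `ChatHistory`. Below is a hypothetical extraction of that inline logic with a small usage check; `splitHistory` is illustrative, not a function in the commit, and, like the commit's code, it assumes the conversation ends with a user message:

```go
package main

import (
	"fmt"

	cohere "github.com/cohere-ai/cohere-go/v2"
)

// splitHistory mirrors the inline logic in createCohereStream: everything
// before the final message becomes chat history, and the final user message
// becomes the prompt. It panics if the slice is empty or the last entry is
// not a USER message, matching the assumptions of the code above.
func splitHistory(messages []*cohere.Message) (history []*cohere.Message, prompt string) {
	if len(messages) > 1 {
		history = messages[:len(messages)-1]
	}
	return history, messages[len(messages)-1].User.Message
}

func main() {
	msgs := []*cohere.Message{
		{Role: "USER", User: &cohere.ChatMessage{Message: "What is mods?"}},
		{Role: "CHATBOT", Chatbot: &cohere.ChatMessage{Message: "A CLI for LLM pipelines."}},
		{Role: "USER", User: &cohere.ChatMessage{Message: "Which backends does it support?"}},
	}

	history, prompt := splitHistory(msgs)
	fmt.Println(len(history), prompt) // 2 Which backends does it support?
}
```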