feat: add support for Cohere (#294)

Maxime Brunet 2024-07-29 19:39:10 +00:00 committed by GitHub
parent ff9a598b20
commit fe8a551e66
7 changed files with 248 additions and 1 deletion

README.md

@@ -17,10 +17,11 @@ tool to add a sprinkle of AI in your command line and make your pipelines
artificially intelligent.
It works great with LLMs running locally through [LocalAI]. You can
also use [OpenAI], [Cohere], [Groq], or [Azure OpenAI][AzureOpenAI].

[LocalAI]: https://github.com/go-skynet/LocalAI
[OpenAI]: https://platform.openai.com/account/api-keys
[Cohere]: https://dashboard.cohere.com/api-keys
[Groq]: https://console.groq.com/keys
[AzureOpenAI]: https://azure.microsoft.com/en-us/products/cognitive-services/openai-service
@@ -196,6 +197,13 @@ can grab it the [OpenAI website](https://platform.openai.com/account/api-keys).
Alternatively, set the [`AZURE_OPENAI_KEY`] environment variable to use Azure
OpenAI. Grab a key from [Azure](https://azure.microsoft.com/en-us/products/cognitive-services/openai-service).
### Cohere

Cohere provides enterprise-optimized models.

Set the `COHERE_API_KEY` environment variable. If you don't have one yet,
you can get one from the [Cohere dashboard](https://dashboard.cohere.com/api-keys).

### Local AI
Local AI allows you to run models locally. Mods works with the GPT4ALL-J model

cohere.go (new file)

@@ -0,0 +1,153 @@
package main
import (
"context"
"encoding/json"
"errors"
"io"
"net/http"
cohere "github.com/cohere-ai/cohere-go/v2"
"github.com/cohere-ai/cohere-go/v2/client"
"github.com/cohere-ai/cohere-go/v2/core"
coherecore "github.com/cohere-ai/cohere-go/v2/core"
"github.com/cohere-ai/cohere-go/v2/option"
openai "github.com/sashabaranov/go-openai"
)
// CohereClientConfig represents the configuration for the Cohere API client.
type CohereClientConfig struct {
AuthToken string
BaseURL string
HTTPClient *http.Client
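	// EmptyMessagesLimit mirrors openai.ClientConfig for interface parity;
	// it appears unused by the Cohere path in this file.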
EmptyMessagesLimit uint
}
// DefaultCohereConfig returns the default configuration for the Cohere API client.
func DefaultCohereConfig(authToken string) CohereClientConfig {
return CohereClientConfig{
AuthToken: authToken,
BaseURL: "",
HTTPClient: &http.Client{},
}
}
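// NOTE: an empty BaseURL is left as-is here; NewCohereClientWithConfig only
// applies WithBaseURL when a custom URL is configured, so the SDK's default
// endpoint is used otherwise.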
// CohereClient is a client for the Cohere API.
type CohereClient struct {
*client.Client
}
// NewCohereClientWithConfig creates a new [CohereClient] with the given configuration.
func NewCohereClientWithConfig(config CohereClientConfig) *CohereClient {
opts := []option.RequestOption{
client.WithToken(config.AuthToken),
client.WithHTTPClient(config.HTTPClient),
}
if config.BaseURL != "" {
opts = append(opts, client.WithBaseURL(config.BaseURL))
}
return &CohereClient{
Client: client.NewClient(opts...),
}
}
// CohereChatCompletionStream represents a stream for chat completion.
type CohereChatCompletionStream struct {
*cohereStreamReader
}
type cohereStreamReader struct {
*core.Stream[cohere.StreamedChatResponse]
}
// Recv reads the next response from the stream.
func (stream *cohereStreamReader) Recv() (response openai.ChatCompletionStreamResponse, err error) {
return stream.processMessages()
}
// Close closes the stream.
func (stream *cohereStreamReader) Close() error {
return stream.Stream.Close()
}
func (stream *cohereStreamReader) processMessages() (openai.ChatCompletionStreamResponse, error) {
for {
message, err := stream.Stream.Recv()
if err != nil {
		if errors.Is(err, io.EOF) {
			return openai.ChatCompletionStreamResponse{}, io.EOF
		}
		return openai.ChatCompletionStreamResponse{}, err
}
if message.EventType != "text-generation" {
continue
}
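		// Other event types (e.g. "stream-start", "stream-end", or
		// citation events) carry no text delta and are skipped.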
// NOTE: Leverage the existing logic based on OpenAI ChatCompletionStreamResponse by
// converting the Cohere events into them.
response := openai.ChatCompletionStreamResponse{
Choices: []openai.ChatCompletionStreamChoice{
{
Index: 0,
Delta: openai.ChatCompletionStreamChoiceDelta{
Content: message.TextGeneration.Text,
Role: "assistant",
},
},
},
}
return response, nil
}
}
// CreateChatCompletionStream makes an API call to create a chat completion
// with streaming support.
func (c *CohereClient) CreateChatCompletionStream(
ctx context.Context,
request *cohere.ChatStreamRequest,
) (stream *CohereChatCompletionStream, err error) {
resp, err := c.ChatStream(ctx, request)
if err != nil {
return
}
stream = &CohereChatCompletionStream{
cohereStreamReader: &cohereStreamReader{
Stream: resp,
},
}
return
}
// CohereToOpenAIAPIError attempts to convert a Cohere API error into
// an OpenAI API error to later reuse the existing error handling logic.
func CohereToOpenAIAPIError(err error) error {
	ce := &core.APIError{}
if !errors.As(err, &ce) {
return err
}
unwrapped := ce.Unwrap()
if unwrapped == nil {
unwrapped = err
}
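	// The wrapped error's text is typically the raw response body JSON;
	// try to extract its "message" field for a cleaner display.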
var message string
var body map[string]interface{}
if err := json.Unmarshal([]byte(unwrapped.Error()), &body); err == nil {
message, _ = body["message"].(string)
}
if message == "" {
message = unwrapped.Error()
}
return &openai.APIError{
HTTPStatusCode: ce.StatusCode,
Message: message,
}
}
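Taken together, these pieces expose Cohere streaming behind go-openai's interface. Below is a minimal sketch of driving the wrapper directly; it assumes a scratch function compiled alongside cohere.go in the same package, and the model name and prompt are illustrative.

```go
// streamHello is a hypothetical scratch function, assumed to live next to
// cohere.go in the same package. Model name and prompt are illustrative.
package main

import (
	"context"
	"errors"
	"fmt"
	"io"
	"os"

	cohere "github.com/cohere-ai/cohere-go/v2"
)

func streamHello() error {
	cfg := DefaultCohereConfig(os.Getenv("COHERE_API_KEY"))
	client := NewCohereClientWithConfig(cfg)

	stream, err := client.CreateChatCompletionStream(context.Background(), &cohere.ChatStreamRequest{
		Model:   cohere.String("command-r"),
		Message: "Say hello in one short sentence.",
	})
	if err != nil {
		// Normalize the error so existing OpenAI-oriented handling applies.
		return CohereToOpenAIAPIError(err)
	}
	defer stream.Close()

	for {
		resp, err := stream.Recv()
		if errors.Is(err, io.EOF) {
			return nil
		}
		if err != nil {
			return CohereToOpenAIAPIError(err)
		}
		fmt.Print(resp.Choices[0].Delta.Content)
	}
}
```

Routing both the request error and the per-chunk errors through CohereToOpenAIAPIError keeps the output consistent with what the existing OpenAI error handling expects.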

config_template.yml

@@ -93,6 +93,13 @@ apis:
claude-3-opus-20240229:
aliases: ["claude3-opus", "opus"]
max-input-chars: 680000
cohere:
base-url: https://api.cohere.com/v1
models:
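# model names follow https://docs.cohere.com/docs/models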
command-r-plus:
max-input-chars: 128000
command-r:
max-input-chars: 128000
ollama:
base-url: http://localhost:11434/api
models: # https://ollama.com/library

go.mod

@@ -16,6 +16,7 @@ require (
github.com/charmbracelet/x/editor v0.0.0-20231116172829-450eedbca1ab
github.com/charmbracelet/x/exp/ordered v0.0.0-20231010190216-1cb11efc897d
github.com/charmbracelet/x/exp/strings v0.0.0-20240524151031-ff83003bf67a
github.com/cohere-ai/cohere-go/v2 v2.8.2
github.com/jmoiron/sqlx v1.3.5
github.com/lucasb-eyer/go-colorful v1.2.0
github.com/mattn/go-isatty v0.0.20

go.sum

@@ -42,6 +42,8 @@ github.com/charmbracelet/x/term v0.1.1 h1:3cosVAiPOig+EV4X9U+3LDgtwwAoEzJjNdwbXD
github.com/charmbracelet/x/term v0.1.1/go.mod h1:wB1fHt5ECsu3mXYusyzcngVWWlu1KKUmmLhfgr/Flxw=
github.com/charmbracelet/x/windows v0.1.2 h1:Iumiwq2G+BRmgoayww/qfcvof7W/3uLoelhxojXlRWg=
github.com/charmbracelet/x/windows v0.1.2/go.mod h1:GLEO/l+lizvFDBPLIOk+49gdX49L9YWMB5t+DZd0jkQ=
github.com/cohere-ai/cohere-go/v2 v2.8.2 h1:NtxtcqkJ3ZBj8DFgk/4hpOrGK7CGnllGNpQn1bkaqQs=
github.com/cohere-ai/cohere-go/v2 v2.8.2/go.mod h1:dlDCT66i8BqZDuuskFvYzsrc+O0M4l5J9Ibckoflvt4=
github.com/cpuguy83/go-md2man/v2 v2.0.2/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o=
github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=

mods.go

@@ -262,6 +262,7 @@ func (m *Mods) startCompletionCmd(content string) tea.Cmd {
var api API
var ccfg openai.ClientConfig
var accfg AnthropicClientConfig
var cccfg CohereClientConfig
var occfg OllamaClientConfig
cfg := m.Config
@@ -325,6 +326,15 @@ func (m *Mods) startCompletionCmd(content string) tea.Cmd {
if api.Version != "" {
accfg.Version = AnthropicAPIVersion(api.Version)
}
case "cohere":
key, err := m.ensureKey(api, "COHERE_API_KEY", "https://dashboard.cohere.com/api-keys")
if err != nil {
return err
}
cccfg = DefaultCohereConfig(key)
if api.BaseURL != "" {
			cccfg.BaseURL = api.BaseURL
}
case "azure", "azure-ad":
key, err := m.ensureKey(api, "AZURE_OPENAI_KEY", "https://aka.ms/oai/access")
if err != nil {
@@ -353,6 +363,7 @@ func (m *Mods) startCompletionCmd(content string) tea.Cmd {
httpClient := &http.Client{Transport: &http.Transport{Proxy: http.ProxyURL(proxyURL)}}
ccfg.HTTPClient = httpClient
accfg.HTTPClient = httpClient
cccfg.HTTPClient = httpClient
occfg.HTTPClient = httpClient
}
@@ -363,6 +374,8 @@ func (m *Mods) startCompletionCmd(content string) tea.Cmd {
switch mod.API {
case "anthropic":
return m.createAnthropicStream(content, accfg, mod)
case "cohere":
return m.createCohereStream(content, cccfg, mod)
case "ollama":
return m.createOllamaStream(content, occfg, mod)
default:
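As wired above, one proxied http.Client is shared by every provider config, including the new Cohere one. A self-contained sketch of that transport setup follows; proxiedClient is a hypothetical helper and the error handling is slightly expanded from the hunk above.

```go
package main

import (
	"net/http"
	"net/url"
)

// proxiedClient builds an HTTP client that routes all requests through the
// given proxy, matching the transport configured in startCompletionCmd.
func proxiedClient(raw string) (*http.Client, error) {
	proxyURL, err := url.Parse(raw)
	if err != nil {
		return nil, err
	}
	return &http.Client{
		Transport: &http.Transport{Proxy: http.ProxyURL(proxyURL)},
	}, nil
}
```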

stream.go

@@ -6,6 +6,7 @@ import (
"strings"
tea "github.com/charmbracelet/bubbletea"
cohere "github.com/cohere-ai/cohere-go/v2"
openai "github.com/sashabaranov/go-openai"
)
@@ -126,6 +127,68 @@ func (m *Mods) createAnthropicStream(content string, accfg AnthropicClientConfig
return m.receiveCompletionStreamCmd(completionOutput{stream: stream})()
}
func (m *Mods) createCohereStream(content string, cccfg CohereClientConfig, mod Model) tea.Msg {
cfg := m.Config
client := NewCohereClientWithConfig(cccfg)
ctx, cancel := context.WithCancel(context.Background())
m.cancelRequest = cancel
if err := m.setupStreamContext(content, mod); err != nil {
return err
}
var messages []*cohere.Message
for _, message := range m.messages {
switch message.Role {
case openai.ChatMessageRoleSystem:
			// For system messages, it is recommended to use the `preamble`
			// field rather than a "SYSTEM"-role message.
m.system += message.Content + "\n"
case openai.ChatMessageRoleAssistant:
messages = append(messages, &cohere.Message{
Role: "CHATBOT",
Chatbot: &cohere.ChatMessage{
Message: message.Content,
},
})
case openai.ChatMessageRoleUser:
messages = append(messages, &cohere.Message{
Role: "USER",
User: &cohere.ChatMessage{
Message: message.Content,
},
})
}
}
var history []*cohere.Message
if len(messages) > 1 {
history = messages[:len(messages)-1]
}
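	// NOTE: the final entry is assumed to be a USER turn (the current prompt
	// is appended last), so .User is non-nil when building the request below.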
req := &cohere.ChatStreamRequest{
Model: cohere.String(mod.Name),
ChatHistory: history,
Message: messages[len(messages)-1].User.Message,
Preamble: cohere.String(m.system),
Temperature: cohere.Float64(float64(cfg.Temperature)),
P: cohere.Float64(float64(cfg.TopP)),
StopSequences: cfg.Stop,
}
if cfg.MaxTokens > 0 {
req.MaxTokens = cohere.Int(cfg.MaxTokens)
}
stream, err := client.CreateChatCompletionStream(ctx, req)
if err != nil {
return m.handleRequestError(CohereToOpenAIAPIError(err), mod, content)
}
return m.receiveCompletionStreamCmd(completionOutput{stream: stream})()
}
func (m *Mods) setupStreamContext(content string, mod Model) error {
cfg := m.Config
m.messages = []openai.ChatCompletionMessage{}
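For reference, here is a toy illustration of the history/message split performed by createCohereStream above: everything before the final turn becomes ChatHistory, and the final user turn becomes the Message. The transcript values are made up.

```go
package main

import (
	"fmt"

	cohere "github.com/cohere-ai/cohere-go/v2"
)

func main() {
	// Illustrative three-turn transcript, already mapped to Cohere roles.
	messages := []*cohere.Message{
		{Role: "USER", User: &cohere.ChatMessage{Message: "What is mods?"}},
		{Role: "CHATBOT", Chatbot: &cohere.ChatMessage{Message: "A CLI that pipes LLMs into your shell."}},
		{Role: "USER", User: &cohere.ChatMessage{Message: "Which providers does it support?"}},
	}

	// Same split as createCohereStream: all but the last turn is history.
	var history []*cohere.Message
	if len(messages) > 1 {
		history = messages[:len(messages)-1]
	}
	last := messages[len(messages)-1]

	fmt.Println(len(history))      // 2
	fmt.Println(last.User.Message) // Which providers does it support?
}
```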