package main

import (
	"bufio"
	"bytes"
	"context"
	"encoding/json"
	"fmt"
	"io"
	"net/http"

	openai "github.com/sashabaranov/go-openai"
)

// OllamaClientConfig represents the configuration for the Ollama API client.
type OllamaClientConfig struct {
	BaseURL            string
	HTTPClient         *http.Client
	EmptyMessagesLimit uint
}

// DefaultOllamaConfig returns the default configuration for the Ollama API client.
func DefaultOllamaConfig() OllamaClientConfig {
	return OllamaClientConfig{
		BaseURL:            "http://localhost:11434/api",
		HTTPClient:         &http.Client{},
		EmptyMessagesLimit: defaultEmptyMessagesLimit,
	}
}

// OllamaMessageCompletionRequestOptions represents the valid model parameter
// options (the "options" object) for the request. MirostatEta and MirostatTau
// are floats, matching Ollama's documented parameter types.
type OllamaMessageCompletionRequestOptions struct {
	Mirostat      int     `json:"mirostat,omitempty"`       // Mirostat sampling (0 = off, 1 = Mirostat, 2 = Mirostat 2.0)
	MirostatEta   float32 `json:"mirostat_eta,omitempty"`   // Mirostat learning rate
	MirostatTau   float32 `json:"mirostat_tau,omitempty"`   // Mirostat target entropy
	NumCtx        int     `json:"num_ctx,omitempty"`        // context window size
	RepeatLastN   int     `json:"repeat_last_n,omitempty"`  // lookback window for repetition penalties
	RepeatPenalty float32 `json:"repeat_penalty,omitempty"` // penalty applied to repeated tokens
	Temperature   float32 `json:"temperature,omitempty"`    // sampling temperature
	Seed          int     `json:"seed,omitempty"`           // random seed for reproducible generations
	Stop          string  `json:"stop,omitempty"`           // stop sequence
	TfsZ          float32 `json:"tfs_z,omitempty"`          // tail-free sampling parameter
	NumPredict    int     `json:"num_predict,omitempty"`    // maximum number of tokens to predict
	TopP          float32 `json:"top_p,omitempty"`          // nucleus sampling cutoff
	TopK          int     `json:"top_k,omitempty"`          // top-k sampling cutoff
}

// OllamaMessageCompletionRequest represents the request body for the chat completion API.
type OllamaMessageCompletionRequest struct {
	Model     string                                `json:"model"`
	Messages  []openai.ChatCompletionMessage        `json:"messages"`
	Options   OllamaMessageCompletionRequestOptions `json:"options,omitempty"`
	Stream    bool                                  `json:"stream,omitempty"`
	KeepAlive string                                `json:"keep_alive,omitempty"`
}
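
// For illustration, a request with a temperature override serializes along
// these lines (model name and values here are hypothetical):
//
//	{
//	  "model": "llama3",
//	  "messages": [{"role": "user", "content": "Hello"}],
//	  "options": {"temperature": 0.7},
//	  "stream": true
//	}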

// OllamaRequestBuilder is an interface for building HTTP requests for the Ollama API.
type OllamaRequestBuilder interface {
	Build(ctx context.Context, method, url string, body any, header http.Header) (*http.Request, error)
}

// NewOllamaRequestBuilder creates a new HTTPRequestBuilder.
func NewOllamaRequestBuilder() *HTTPRequestBuilder {
	return &HTTPRequestBuilder{
		marshaller: &JSONMarshaller{},
	}
}

// OllamaClient is a client for the Ollama API.
type OllamaClient struct {
	config OllamaClientConfig

	requestBuilder OllamaRequestBuilder
}

// NewOllamaClientWithConfig creates a new OllamaClient with the given configuration.
func NewOllamaClientWithConfig(config OllamaClientConfig) *OllamaClient {
	return &OllamaClient{
		config:         config,
		requestBuilder: NewOllamaRequestBuilder(),
	}
}
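
// Illustrative usage sketch (the host below is hypothetical): start from the
// defaults and override whatever differs before constructing the client.
//
//	cfg := DefaultOllamaConfig()
//	cfg.BaseURL = "http://ollama.internal:11434/api"
//	client := NewOllamaClientWithConfig(cfg)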

const ollamaChatCompletionsSuffix = "/chat"

func (c *OllamaClient) newRequest(ctx context.Context, method, url string, setters ...requestOption) (*http.Request, error) {
	// Default options.
	args := &requestOptions{
		body:   nil,
		header: make(http.Header),
	}
	for _, setter := range setters {
		setter(args)
	}
	req, err := c.requestBuilder.Build(ctx, method, url, args.body, args.header)
	if err != nil {
		return nil, fmt.Errorf("OllamaClient.newRequest: %w", err)
	}
	return req, nil
}

func (c *OllamaClient) handleErrorResp(resp *http.Response) error {
	// Decode the error response body; fall back to a plain RequestError when
	// the body is not a well-formed error payload.
	var errRes openai.ErrorResponse
	err := json.NewDecoder(resp.Body).Decode(&errRes)
	if err != nil || errRes.Error == nil {
		reqErr := &openai.RequestError{
			HTTPStatusCode: resp.StatusCode,
			Err:            err,
		}
		if errRes.Error != nil {
			reqErr.Err = errRes.Error
		}
		return reqErr
	}

	errRes.Error.HTTPStatusCode = resp.StatusCode
	return errRes.Error
}

// OllamaMessageUsage represents the token usage of an Ollama message.
type OllamaMessageUsage struct {
	InputTokens  int `json:"input_tokens,omitempty"`
	OutputTokens int `json:"output_tokens,omitempty"`
}

// OllamaMessage represents an Ollama message.
type OllamaMessage struct {
	Usage        *OllamaMessageUsage `json:"usage,omitempty"`
	StopReason   *string             `json:"stop_reason,omitempty"`
	StopSequence *string             `json:"stop_sequence,omitempty"`
	ID           string              `json:"id,omitempty"`
	Type         string              `json:"type"`
	Role         string              `json:"role,omitempty"`
	Model        string              `json:"model,omitempty"`
	Content      []string            `json:"content,omitempty"`
}

// OllamaMessageContentBlock represents a content block in an Ollama message.
type OllamaMessageContentBlock struct {
	Type string `json:"type"`
	Text string `json:"text,omitempty"`
}

// OllamaMessageTextDelta represents a text delta in an Ollama message.
type OllamaMessageTextDelta struct {
	Type string `json:"type"`
	Text string `json:"text,omitempty"`
}

// OllamaCompletionMessageResponse represents the response body for the chat
// completion API. Durations are reported in nanoseconds.
type OllamaCompletionMessageResponse struct {
	Model              string                       `json:"model"`
	CreatedAt          string                       `json:"created_at"`
	Message            openai.ChatCompletionMessage `json:"message"`
	Done               bool                         `json:"done"`
	TotalDuration      int                          `json:"total_duration"`
	LoadDuration       int                          `json:"load_duration"`
	PromptEvalCount    int                          `json:"prompt_eval_count"`
	PromptEvalDuration int                          `json:"prompt_eval_duration"`
	EvalCount          int                          `json:"eval_count"`
	EvalDuration       int                          `json:"eval_duration"`
}
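
// When streaming, Ollama emits one JSON object per line: intermediate chunks
// carry a message delta, and the final chunk has "done": true plus timing
// counters. An illustrative (hypothetical) pair of chunks:
//
//	{"model":"llama3","created_at":"2024-01-01T00:00:00Z","message":{"role":"assistant","content":"Hel"},"done":false}
//	{"model":"llama3","created_at":"2024-01-01T00:00:01Z","done":true,"eval_count":42,"eval_duration":123456789}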

// OllamaChatCompletionStream represents a stream for chat completion.
type OllamaChatCompletionStream struct {
	*ollamaStreamReader
}

type ollamaStreamReader struct {
	emptyMessagesLimit uint
	isFinished         bool

	reader         *bufio.Reader
	response       *http.Response
	errAccumulator ErrorAccumulator
	unmarshaler    Unmarshaler

	httpHeader
}

// Recv reads the next response from the stream.
func (stream *ollamaStreamReader) Recv() (response openai.ChatCompletionStreamResponse, err error) {
	if stream.isFinished {
		err = io.EOF
		return
	}

	response, err = stream.processLines()
	return
}

// Close closes the stream.
func (stream *ollamaStreamReader) Close() error {
	return stream.response.Body.Close() //nolint:wrapcheck
}

//nolint:gocognit
func (stream *ollamaStreamReader) processLines() (openai.ChatCompletionStreamResponse, error) {
	for {
		rawLine, readErr := stream.reader.ReadBytes('\n')
		if readErr != nil {
			return openai.ChatCompletionStreamResponse{}, fmt.Errorf("ollamaStreamReader.processLines: %w", readErr)
		}

		noSpaceLine := bytes.TrimSpace(rawLine)

		var chunk OllamaCompletionMessageResponse
		unmarshalErr := stream.unmarshaler.Unmarshal(noSpaceLine, &chunk)
		if unmarshalErr != nil {
			return openai.ChatCompletionStreamResponse{}, fmt.Errorf("ollamaStreamReader.processLines: %w", unmarshalErr)
		}

		// The final chunk carries "done": true. Mark the stream as finished so
		// that the next Recv returns io.EOF.
		if chunk.Done {
			stream.isFinished = true
			return openai.ChatCompletionStreamResponse{}, nil
		}

		if chunk.Message.Content == "" {
			continue
		}

		// NOTE: Leverage the existing logic built around the OpenAI
		// ChatCompletionStreamResponse by converting Ollama chunks into it.
		response := openai.ChatCompletionStreamResponse{
			Choices: []openai.ChatCompletionStreamChoice{
				{
					Index: 0,
					Delta: openai.ChatCompletionStreamChoiceDelta{
						Content: chunk.Message.Content,
						Role:    "assistant",
					},
				},
			},
		}

		return response, nil
	}
}
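
// For reference, a non-final chunk such as
//
//	{"message":{"role":"assistant","content":"Hel"},"done":false}
//
// is mapped to an openai.ChatCompletionStreamResponse whose single choice
// carries Delta{Role: "assistant", Content: "Hel"}, so downstream code can
// treat Ollama and OpenAI streams uniformly.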

func ollamaSendRequestStream(client *OllamaClient, req *http.Request) (*ollamaStreamReader, error) {
	req.Header.Set("Content-Type", "application/json")

	resp, err := client.config.HTTPClient.Do(req) //nolint:bodyclose // body is closed in stream.Close()
	if err != nil {
		return new(ollamaStreamReader), err
	}
	if isFailureStatusCode(resp) {
		return new(ollamaStreamReader), client.handleErrorResp(resp)
	}

	return &ollamaStreamReader{
		emptyMessagesLimit: client.config.EmptyMessagesLimit,
		reader:             bufio.NewReader(resp.Body),
		response:           resp,
		errAccumulator:     NewErrorAccumulator(),
		unmarshaler:        &JSONUnmarshaler{},
		httpHeader:         httpHeader(resp.Header),
	}, nil
}

// CreateChatCompletionStream is the API call to create a chat completion with
// streaming enabled. Ollama sends back partial progress as newline-delimited
// JSON objects as tokens become available; the stream terminates with a final
// object whose "done" field is true.
func (c *OllamaClient) CreateChatCompletionStream(
	ctx context.Context,
	request OllamaMessageCompletionRequest,
) (stream *OllamaChatCompletionStream, err error) {
	urlSuffix := ollamaChatCompletionsSuffix

	request.Stream = true
	req, err := c.newRequest(ctx, http.MethodPost, c.config.BaseURL+urlSuffix, withBody(request))
	if err != nil {
		return nil, err
	}

	resp, err := ollamaSendRequestStream(c, req)
	if err != nil {
		return
	}

	stream = &OllamaChatCompletionStream{
		ollamaStreamReader: resp,
	}
	return
}
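
// Illustrative end-to-end usage sketch (ctx, model name, and prompt are
// hypothetical; the len guard covers the empty final chunk):
//
//	client := NewOllamaClientWithConfig(DefaultOllamaConfig())
//	stream, err := client.CreateChatCompletionStream(ctx, OllamaMessageCompletionRequest{
//		Model:    "llama3",
//		Messages: []openai.ChatCompletionMessage{{Role: "user", Content: "Hello"}},
//	})
//	if err != nil {
//		return err
//	}
//	defer stream.Close()
//	for {
//		resp, err := stream.Recv()
//		if errors.Is(err, io.EOF) {
//			break
//		}
//		if err != nil {
//			return err
//		}
//		if len(resp.Choices) > 0 {
//			fmt.Print(resp.Choices[0].Delta.Content)
//		}
//	}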