one-api/relay/chat.go

122 lines
2.7 KiB
Go

package relay
import (
"encoding/json"
"errors"
"fmt"
"math"
"net/http"
"one-api/common"
"one-api/common/requester"
"one-api/common/utils"
providersBase "one-api/providers/base"
"one-api/types"
"github.com/gin-gonic/gin"
)
// relayChat relays chat completion requests. It embeds the shared relayBase
// state and additionally keeps the decoded request payload.
type relayChat struct {
	relayBase

	// chatRequest is filled by setRequest from the incoming body; send later
	// overwrites its Model field with r.modelName before forwarding it.
	chatRequest types.ChatCompletionRequest
}
// NewRelayChat returns a chat relay bound to the given gin request context.
// All other state starts at its zero value and is populated later.
func NewRelayChat(c *gin.Context) *relayChat {
	chat := new(relayChat)
	chat.c = c
	return chat
}
// setRequest decodes the request body into r.chatRequest, validates it, and
// records the client-supplied model name in r.originalModel.
//
// It returns an error when decoding fails, when max_tokens is negative or
// above math.MaxInt32/2, or when stream options are given for a
// non-streaming request. A request that carries tools marks the gin context
// with "skip_only_chat" (this happens before the stream_options check).
func (r *relayChat) setRequest() error {
	req := &r.chatRequest
	if err := common.UnmarshalBodyReusable(r.c, req); err != nil {
		return err
	}

	// Upper bound presumably leaves headroom against int32 overflow in later
	// token arithmetic — TODO confirm with callers.
	const maxTokensLimit = math.MaxInt32 / 2
	if !(req.MaxTokens >= 0 && req.MaxTokens <= maxTokensLimit) {
		return errors.New("max_tokens is invalid")
	}

	if req.Tools != nil {
		r.c.Set("skip_only_chat", true)
	}

	// NOTE(review): this message breaks Go error-string style (ST1005) but is
	// kept verbatim since it is likely surfaced to API clients.
	if req.StreamOptions != nil && !req.Stream {
		return errors.New("The 'stream_options' parameter is only allowed when 'stream' is enabled.")
	}

	r.originalModel = req.Model
	return nil
}
// getRequest exposes the decoded chat request. It hands out a pointer to the
// relay's own copy, so any changes the caller makes are visible to r.
func (r *relayChat) getRequest() interface{} {
	req := &r.chatRequest
	return req
}
// getPromptTokens counts the tokens in the request messages via
// common.CountTokenMessages for r.modelName. The error result is always nil.
func (r *relayChat) getPromptTokens() (int, error) {
	tokens := common.CountTokenMessages(r.chatRequest.Messages, r.modelName)
	return tokens, nil
}
// send forwards the chat request to r.provider and writes the result to the
// client, either as a stream or as a single JSON body.
//
// done is true when the provider does not implement ChatInterface, or when
// writing the response to the client fails. When the provider call itself
// fails, err is returned with done=false (presumably so the caller may try
// another channel — confirm against the caller). A non-streaming response
// with non-empty content is stored in r.cache after it is written.
func (r *relayChat) send() (err *types.OpenAIErrorWithStatusCode, done bool) {
	chatProvider, ok := r.provider.(providersBase.ChatInterface)
	if !ok {
		return common.StringErrorWrapper("channel not implemented", "channel_error", http.StatusServiceUnavailable), true
	}

	// The upstream request carries r.modelName rather than the model the
	// client originally sent (that one is kept in r.originalModel).
	r.chatRequest.Model = r.modelName

	if !r.chatRequest.Stream {
		var completion *types.ChatCompletionResponse
		if completion, err = chatProvider.CreateChatCompletion(&r.chatRequest); err != nil {
			return err, false
		}
		if err = responseJsonClient(r.c, completion); err != nil {
			return err, true
		}
		if completion.GetContent() != "" {
			r.cache.SetResponse(completion)
		}
		return nil, false
	}

	var stream requester.StreamReaderInterface[string]
	if stream, err = chatProvider.CreateChatCompletionStream(&r.chatRequest); err != nil {
		return err, false
	}
	// The method value is invoked by responseStreamClient to produce the
	// trailing usage chunk ("" when usage reporting was not requested).
	if err = responseStreamClient(r.c, stream, r.cache, r.getUsageResponse); err != nil {
		return err, true
	}
	return nil, false
}
// getUsageResponse builds a JSON "chat.completion.chunk" with no choices,
// carrying the provider's usage figures.
//
// It returns "" when StreamOptions is absent or IncludeUsage is false, and
// also when JSON encoding fails.
func (r *relayChat) getUsageResponse() string {
	opts := r.chatRequest.StreamOptions
	if opts == nil || !opts.IncludeUsage {
		return ""
	}

	chunk := types.ChatCompletionStreamResponse{
		ID:      fmt.Sprintf("chatcmpl-%s", utils.GetUUID()),
		Object:  "chat.completion.chunk",
		Created: utils.GetTimestamp(),
		Model:   r.chatRequest.Model,
		// Non-nil empty slice so it encodes as [] rather than null.
		Choices: []types.ChatCompletionStreamChoice{},
		Usage:   r.provider.GetUsage(),
	}

	encoded, err := json.Marshal(chunk)
	if err != nil {
		return ""
	}
	return string(encoded)
}