feat: 支持CozeV3

This commit is contained in:
suziheng
2025-01-22 16:43:40 +08:00
parent 6eb4e788c7
commit 533f9853ac
13 changed files with 359 additions and 15 deletions

View File

@@ -64,6 +64,9 @@ func GetAdaptor(apiType int) adaptor.Adaptor {
return &proxy.Adaptor{}
case apitype.Replicate:
return &replicate.Adaptor{}
case apitype.CozeV3:
return &coze.AdaptorV3{}
}
return nil
}

View File

@@ -0,0 +1,75 @@
package coze
import (
"errors"
"fmt"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/relay/adaptor"
"github.com/songquanpeng/one-api/relay/adaptor/openai"
"github.com/songquanpeng/one-api/relay/meta"
"github.com/songquanpeng/one-api/relay/model"
"io"
"net/http"
)
// AdaptorV3 adapts the Coze v3 chat API (POST /v3/chat) to the common
// adaptor.Adaptor interface. The meta field is captured in Init and read
// later by ConvertRequest.
type AdaptorV3 struct {
	meta *meta.Meta
}
// Init stores the relay metadata for later use (ConvertRequest reads
// a.meta.Config.UserID). Must be called before ConvertRequest.
func (a *AdaptorV3) Init(meta *meta.Meta) {
	a.meta = meta
}
// GetRequestURL returns the Coze v3 chat endpoint rooted at the channel's
// configured base URL.
func (a *AdaptorV3) GetRequestURL(meta *meta.Meta) (string, error) {
	return fmt.Sprintf("%s/v3/chat", meta.BaseURL), nil
}
// SetupRequestHeader applies the common relay headers and authenticates the
// upstream request with the channel API key as a Bearer token.
func (a *AdaptorV3) SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error {
	adaptor.SetupCommonRequestHeader(c, req, meta)
	req.Header.Set("Authorization", "Bearer "+meta.APIKey)
	return nil
}
// ConvertRequest translates an OpenAI-style request into the Coze request
// shape via the package-level ConvertRequest helper.
// NOTE(review): it mutates the caller's request (overwrites request.User with
// the channel-configured user ID) and dereferences a.meta, so Init must have
// run first — confirm both are guaranteed by the relay pipeline.
func (a *AdaptorV3) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) {
	if request == nil {
		return nil, errors.New("request is nil")
	}
	request.User = a.meta.Config.UserID
	return ConvertRequest(*request), nil
}
// ConvertImageRequest passes the image request through unchanged (no Coze
// specific translation is performed), only guarding against nil.
func (a *AdaptorV3) ConvertImageRequest(request *model.ImageRequest) (any, error) {
	if request == nil {
		return nil, errors.New("request is nil")
	}
	return request, nil
}
// DoRequest delegates the upstream HTTP call to the shared request helper,
// which uses GetRequestURL and SetupRequestHeader from this adaptor.
func (a *AdaptorV3) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) {
	return adaptor.DoRequestHelper(a, c, meta, requestBody)
}
// DoResponse dispatches the upstream response to the streaming or blocking
// V3 handler and derives token usage from the returned text; prompt tokens
// always come from the relay metadata.
func (a *AdaptorV3) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
	handle := func() (*model.ErrorWithStatusCode, *string) {
		if meta.IsStream {
			return V3StreamHandler(c, resp)
		}
		return V3Handler(c, resp, meta.PromptTokens, meta.ActualModelName)
	}
	relayErr, responseText := handle()

	// Estimate completion tokens from the text when we got any; otherwise
	// fall back to an empty usage record so callers never see nil.
	if responseText != nil {
		usage = openai.ResponseText2Usage(*responseText, meta.ActualModelName, meta.PromptTokens)
	} else {
		usage = &model.Usage{}
	}
	usage.PromptTokens = meta.PromptTokens
	usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
	return usage, relayErr
}
// GetModelList returns the package-level ModelList (declared elsewhere in
// this package; presumably shared with the v2 adaptor — confirm).
func (a *AdaptorV3) GetModelList() []string {
	return ModelList
}
// GetChannelName returns the display name of this channel type.
func (a *AdaptorV3) GetChannelName() string {
	return "CozeV3"
}

View File

@@ -1,6 +1,9 @@
package coze
import "github.com/songquanpeng/one-api/relay/adaptor/coze/constant/event"
import (
"github.com/songquanpeng/one-api/relay/adaptor/coze/constant/event"
"strings"
)
func event2StopReason(e *string) string {
if e == nil || *e == event.Message {
@@ -8,3 +11,16 @@ func event2StopReason(e *string) string {
}
return "stop"
}
// splitOnDoubleNewline is a bufio.SplitFunc that tokenizes a server-sent-event
// stream on blank-line ("\n\n") boundaries. Each token is one event block
// without the trailing separator.
func splitOnDoubleNewline(data []byte, atEOF bool) (advance int, token []byte, err error) {
	if atEOF && len(data) == 0 {
		return 0, nil, nil
	}
	if i := strings.Index(string(data), "\n\n"); i >= 0 {
		// Consume BOTH newlines of the separator. The previous "+1" advanced
		// past only the first one, leaving a stray "\n" at the head of the
		// next token that consumers then had to strip.
		return i + 2, data[0:i], nil
	}
	// No separator yet: at EOF flush whatever remains, otherwise ask the
	// scanner for more data.
	if atEOF {
		return len(data), data, nil
	}
	return 0, nil, nil
}

View File

@@ -4,19 +4,18 @@ import (
"bufio"
"encoding/json"
"fmt"
"github.com/songquanpeng/one-api/common/render"
"io"
"net/http"
"strings"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/conv"
"github.com/songquanpeng/one-api/common/helper"
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/common/render"
"github.com/songquanpeng/one-api/relay/adaptor/coze/constant/messagetype"
"github.com/songquanpeng/one-api/relay/adaptor/openai"
"github.com/songquanpeng/one-api/relay/model"
"io"
"net/http"
"strings"
)
// https://www.coze.com/open
@@ -45,12 +44,12 @@ func ConvertRequest(textRequest model.GeneralOpenAIRequest) *Request {
}
for i, message := range textRequest.Messages {
if i == len(textRequest.Messages)-1 {
cozeRequest.Query = message.StringContent()
cozeRequest.Query = message.CozeV3StringContent()
continue
}
cozeMessage := Message{
Role: message.Role,
Content: message.StringContent(),
Content: message.CozeV3StringContent(),
}
cozeRequest.ChatHistory = append(cozeRequest.ChatHistory, cozeMessage)
}
@@ -80,6 +79,28 @@ func StreamResponseCoze2OpenAI(cozeResponse *StreamResponse) (*openai.ChatComple
return &openaiResponse, response
}
// V3StreamResponseCoze2OpenAI maps one Coze v3 stream event onto an OpenAI
// chat-completion chunk. The second return value exists only to mirror the
// v2 converter's signature and is always nil.
func V3StreamResponseCoze2OpenAI(cozeResponse *V3StreamResponse) (*openai.ChatCompletionsStreamResponse, *Response) {
	var delta openai.ChatCompletionsStreamResponseChoice
	delta.Delta.Role = cozeResponse.Role
	delta.Delta.Content = cozeResponse.Content

	chunk := openai.ChatCompletionsStreamResponse{
		Id:      cozeResponse.ConversationId,
		Object:  "chat.completion.chunk",
		Choices: []openai.ChatCompletionsStreamResponseChoice{delta},
	}
	// Usage counters appear to be populated only on the final event
	// (TokenCount is zero otherwise) — attach them when present.
	if cozeResponse.Usage.TokenCount > 0 {
		chunk.Usage = &model.Usage{
			PromptTokens:     cozeResponse.Usage.InputCount,
			CompletionTokens: cozeResponse.Usage.OutputCount,
			TotalTokens:      cozeResponse.Usage.TokenCount,
		}
	}
	return &chunk, nil
}
func ResponseCoze2OpenAI(cozeResponse *Response) *openai.TextResponse {
var responseText string
for _, message := range cozeResponse.Messages {
@@ -107,6 +128,26 @@ func ResponseCoze2OpenAI(cozeResponse *Response) *openai.TextResponse {
return &fullTextResponse
}
// V3ResponseCoze2OpenAI converts a non-streaming Coze v3 response into an
// OpenAI-style text completion carrying a single "assistant" choice with
// finish reason "stop".
func V3ResponseCoze2OpenAI(cozeResponse *V3Response) *openai.TextResponse {
	return &openai.TextResponse{
		Id:      fmt.Sprintf("chatcmpl-%s", cozeResponse.Data.ConversationId),
		Model:   "coze-bot", // placeholder; callers overwrite with the real model name
		Object:  "chat.completion",
		Created: helper.GetTimestamp(),
		Choices: []openai.TextResponseChoice{
			{
				Index: 0,
				Message: model.Message{
					Role:    "assistant",
					Content: cozeResponse.Data.Content,
					Name:    nil,
				},
				FinishReason: "stop",
			},
		},
	}
}
func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *string) {
var responseText string
createdTime := helper.GetTimestamp()
@@ -162,6 +203,63 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC
return nil, &responseText
}
// V3StreamHandler relays a Coze v3 SSE chat stream to the client as OpenAI
// style chunks and returns the accumulated assistant text so the caller can
// estimate completion-token usage.
func V3StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *string) {
	var responseText string
	createdTime := helper.GetTimestamp()

	scanner := bufio.NewScanner(resp.Body)
	scanner.Split(splitOnDoubleNewline)
	common.SetEventStreamHeaders(c)

	var modelName string // NOTE(review): never assigned — chunks go out with an empty model name; confirm intended
	for scanner.Scan() {
		part := strings.TrimPrefix(scanner.Text(), "\n")
		parts := strings.Split(part, "\n")
		if len(parts) != 2 {
			continue
		}
		// Process only well-formed "event:... / data:..." pairs. The original
		// condition was inverted: it skipped exactly the well-formed pairs, so
		// no event was ever relayed, and malformed lines reached the slicing
		// below, which could panic on short input. The prefix checks also
		// guarantee the slice offsets are in bounds.
		if !strings.HasPrefix(parts[0], "event:") || !strings.HasPrefix(parts[1], "data:") {
			continue
		}
		eventType := strings.TrimSpace(parts[0][len("event:"):])
		data := strings.TrimSpace(parts[1][len("data:"):])
		if eventType == "conversation.message.delta" || eventType == "conversation.chat.completed" {
			data = strings.TrimSuffix(data, "\r")
			var cozeResponse V3StreamResponse
			if err := json.Unmarshal([]byte(data), &cozeResponse); err != nil {
				logger.SysError("error unmarshalling stream response: " + err.Error())
				continue
			}
			response, _ := V3StreamResponseCoze2OpenAI(&cozeResponse)
			if response == nil {
				continue
			}
			for _, choice := range response.Choices {
				responseText += conv.AsString(choice.Delta.Content)
			}
			response.Model = modelName
			response.Created = createdTime
			if err := render.ObjectData(c, response); err != nil {
				logger.SysError(err.Error())
			}
		}
	}
	if err := scanner.Err(); err != nil {
		logger.SysError("error reading stream: " + err.Error())
	}
	render.Done(c)
	if err := resp.Body.Close(); err != nil {
		return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
	}
	return nil, &responseText
}
func Handler(c *gin.Context, resp *http.Response, promptTokens int, modelName string) (*model.ErrorWithStatusCode, *string) {
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
@@ -200,3 +298,42 @@ func Handler(c *gin.Context, resp *http.Response, promptTokens int, modelName st
}
return nil, &responseText
}
// V3Handler converts a non-streaming Coze v3 chat response into an OpenAI
// style completion, writes it to the client, and returns the assistant text
// so the caller can compute usage. promptTokens is unused here; it is kept
// for signature symmetry with the other handlers.
func V3Handler(c *gin.Context, resp *http.Response, promptTokens int, modelName string) (*model.ErrorWithStatusCode, *string) {
	responseBody, err := io.ReadAll(resp.Body)
	if err != nil {
		return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
	}
	err = resp.Body.Close()
	if err != nil {
		return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
	}
	var cozeResponse V3Response
	err = json.Unmarshal(responseBody, &cozeResponse)
	if err != nil {
		return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
	}
	// Coze reports business-level failures via a non-zero code in the envelope.
	if cozeResponse.Code != 0 {
		return &model.ErrorWithStatusCode{
			Error: model.Error{
				Message: cozeResponse.Msg,
				Code:    cozeResponse.Code,
			},
			StatusCode: resp.StatusCode,
		}, nil
	}
	fullTextResponse := V3ResponseCoze2OpenAI(&cozeResponse)
	fullTextResponse.Model = modelName
	jsonResponse, err := json.Marshal(fullTextResponse)
	if err != nil {
		return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
	}
	c.Writer.Header().Set("Content-Type", "application/json")
	c.Writer.WriteHeader(resp.StatusCode)
	// The write error was previously assigned and silently dropped; surface it.
	_, err = c.Writer.Write(jsonResponse)
	if err != nil {
		return openai.ErrorWrapper(err, "write_response_body_failed", http.StatusInternalServerError), nil
	}
	var responseText string
	if len(fullTextResponse.Choices) > 0 {
		responseText = fullTextResponse.Choices[0].Message.StringContent()
	}
	return nil, &responseText
}

View File

@@ -36,3 +36,45 @@ type StreamResponse struct {
ConversationId string `json:"conversation_id,omitempty"`
ErrorInformation *ErrorInformation `json:"error_information,omitempty"`
}
// V3StreamResponse is the JSON payload of one event on the Coze v3 chat SSE
// stream (consumed by V3StreamHandler for conversation.message.delta and
// conversation.chat.completed events). Field meanings follow the JSON tags;
// exact semantics are defined by the Coze API — verify against its docs.
type V3StreamResponse struct {
	Id             string `json:"id"`
	ConversationId string `json:"conversation_id"`
	BotId          string `json:"bot_id"`
	Role           string `json:"role"`
	Type           string `json:"type"`
	Content        string `json:"content"`
	ContentType    string `json:"content_type"`
	ChatId         string `json:"chat_id"`
	CreatedAt      int    `json:"created_at"`
	CompletedAt    int    `json:"completed_at"`
	// LastError carries the upstream failure details, if any.
	LastError struct {
		Code int    `json:"code"`
		Msg  string `json:"msg"`
	} `json:"last_error"`
	Status string `json:"status"`
	// Usage token counters; presumably populated only on the terminal event
	// (consumers treat TokenCount == 0 as "no usage") — TODO confirm.
	Usage struct {
		TokenCount  int `json:"token_count"`
		OutputCount int `json:"output_count"`
		InputCount  int `json:"input_count"`
	} `json:"usage"`
	SectionId string `json:"section_id"`
}
// V3Response is the envelope of a non-streaming Coze v3 chat response: the
// business payload sits under "data", and a non-zero top-level Code signals
// an upstream error described by Msg.
type V3Response struct {
	Data struct {
		Id             string `json:"id"`
		ConversationId string `json:"conversation_id"`
		BotId          string `json:"bot_id"`
		Content        string `json:"content"`
		ContentType    string `json:"content_type"`
		CreatedAt      int    `json:"created_at"`
		// LastError carries the upstream failure details, if any.
		LastError struct {
			Code int    `json:"code"`
			Msg  string `json:"msg"`
		} `json:"last_error"`
		Status string `json:"status"`
	} `json:"data"`
	Code int    `json:"code"`
	Msg  string `json:"msg"`
}

View File

@@ -22,4 +22,5 @@ const (
Replicate
Dummy // this one is only for count, do not add any channel after this
CozeV3
)

View File

@@ -1,5 +1,7 @@
package model
import "encoding/json"
type Message struct {
Role string `json:"role,omitempty"`
Content any `json:"content,omitempty"`
@@ -37,6 +39,53 @@ func (m Message) StringContent() string {
return ""
}
// CozeV3StringContent flattens the message content into the string form the
// Coze v3 chat API expects: plain strings pass through unchanged, and
// multi-part content is re-encoded as a JSON array of
// {"type":"text","text":...} / {"type":"image"|"file","file_url":...} objects.
// Unrecognized parts are skipped; when nothing usable remains, "" is returned.
func (m Message) CozeV3StringContent() string {
	if content, ok := m.Content.(string); ok {
		return content
	}
	contentList, ok := m.Content.([]any)
	if !ok {
		return ""
	}
	contents := make([]map[string]any, 0, len(contentList))
	for _, contentItem := range contentList {
		contentMap, ok := contentItem.(map[string]any)
		if !ok {
			continue
		}
		switch contentMap["type"] {
		case "text":
			if text, ok := contentMap["text"].(string); ok {
				contents = append(contents, map[string]any{
					"type": "text",
					"text": text,
				})
			}
		case "image_url":
			// OpenAI-style requests usually carry {"image_url": {"url": "..."}}.
			// Accept that shape in addition to a bare string URL (the only form
			// handled previously, which silently dropped standard requests).
			var url string
			switch v := contentMap["image_url"].(type) {
			case string:
				url = v
			case map[string]any:
				url, _ = v["url"].(string)
			}
			if url != "" {
				contents = append(contents, map[string]any{
					"type":     "image",
					"file_url": url,
				})
			}
		case "file":
			// The original read the "image_url" key here — a copy-paste bug.
			// Prefer "file_url", but keep "image_url" as a fallback so any
			// existing callers relying on the old behavior still work.
			url, ok := contentMap["file_url"].(string)
			if !ok {
				url, ok = contentMap["image_url"].(string)
			}
			if ok {
				contents = append(contents, map[string]any{
					"type":     "file",
					"file_url": url,
				})
			}
		}
	}
	if len(contents) > 0 {
		b, _ := json.Marshal(contents)
		return string(b)
	}
	// Previously returned the never-assigned contentStr, i.e. "".
	return ""
}
func (m Message) ParseContent() []MessageContent {
var contentList []MessageContent
content, ok := m.Content.(string)