merge upstream

Signed-off-by: wozulong <>
This commit is contained in:
wozulong
2024-07-11 14:10:10 +08:00
72 changed files with 1467 additions and 398 deletions

View File

@@ -11,9 +11,11 @@ import (
type Adaptor interface {
// Init IsStream bool
Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest)
InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest)
GetRequestURL(info *relaycommon.RelayInfo) (string, error)
SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error
ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error)
ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error)
DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error)
DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode)
GetModelList() []string

View File

@@ -15,6 +15,9 @@ import (
type Adaptor struct {
}
func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) {
}
func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
}
@@ -53,6 +56,10 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.Gen
}
}
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
return nil, nil
}
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
return channel.DoApiRequest(a, c, info, requestBody)
}

View File

@@ -20,6 +20,11 @@ type Adaptor struct {
RequestMode int
}
func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) {
//TODO implement me
}
func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
if strings.HasPrefix(info.UpstreamModelName, "claude-3") {
a.RequestMode = RequestModeMessage
@@ -53,13 +58,17 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.Gen
return claudeReq, err
}
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
return nil, nil
}
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
return nil, nil
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
if info.IsStream {
err, usage = awsStreamHandler(c, info, a.RequestMode)
err, usage = awsStreamHandler(c, resp, info, a.RequestMode)
} else {
err, usage = awsHandler(c, info, a.RequestMode)
}

View File

@@ -13,6 +13,7 @@ import (
relaymodel "one-api/dto"
"one-api/relay/channel/claude"
relaycommon "one-api/relay/common"
"one-api/service"
"strings"
"time"
@@ -112,7 +113,7 @@ func awsHandler(c *gin.Context, info *relaycommon.RelayInfo, requestMode int) (*
return nil, &usage
}
func awsStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, requestMode int) (*relaymodel.OpenAIErrorWithStatusCode, *relaymodel.Usage) {
func awsStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, requestMode int) (*relaymodel.OpenAIErrorWithStatusCode, *relaymodel.Usage) {
awsCli, err := newAwsClient(c, info)
if err != nil {
return wrapErr(errors.Wrap(err, "newAwsClient")), nil
@@ -162,7 +163,6 @@ func awsStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, requestMode i
c.Stream(func(w io.Writer) bool {
event, ok := <-stream.Events()
if !ok {
c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
return false
}
@@ -214,6 +214,17 @@ func awsStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, requestMode i
return false
}
})
if info.ShouldIncludeUsage {
response := service.GenerateFinalUsageResponse(id, createdTime, info.UpstreamModelName, usage)
err := service.ObjectData(c, response)
if err != nil {
common.SysError("send final response failed: " + err.Error())
}
}
service.Done(c)
err = resp.Body.Close()
if err != nil {
return service.OpenAIErrorWrapperLocal(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
return nil, &usage
}

View File

@@ -2,6 +2,7 @@ package baidu
import (
"errors"
"fmt"
"github.com/gin-gonic/gin"
"io"
"net/http"
@@ -15,44 +16,74 @@ import (
type Adaptor struct {
}
func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) {
//TODO implement me
}
func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
}
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
var fullRequestURL string
switch info.UpstreamModelName {
case "ERNIE-Bot-4":
fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro"
case "ERNIE-Bot-8K":
fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie_bot_8k"
case "ERNIE-Bot":
fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions"
case "ERNIE-Speed":
fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie_speed"
case "ERNIE-Bot-turbo":
fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/eb-instant"
case "BLOOMZ-7B":
fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/bloomz_7b1"
case "ERNIE-4.0-8K":
fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro"
case "ERNIE-3.5-8K":
fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions"
case "ERNIE-Speed-8K":
fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie_speed"
case "ERNIE-Character-8K":
fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-char-8k"
case "ERNIE-Functions-8K":
fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-func-8k"
case "ERNIE-Lite-8K-0922":
fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/eb-instant"
case "Yi-34B-Chat":
fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/yi_34b_chat"
case "Embedding-V1":
fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/embeddings/embedding-v1"
default:
fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/" + strings.ToLower(info.UpstreamModelName)
// https://cloud.baidu.com/doc/WENXINWORKSHOP/s/clntwmv7t
suffix := "chat/"
if strings.HasPrefix(info.UpstreamModelName, "Embedding") {
suffix = "embeddings/"
}
if strings.HasPrefix(info.UpstreamModelName, "bge-large") {
suffix = "embeddings/"
}
if strings.HasPrefix(info.UpstreamModelName, "tao-8k") {
suffix = "embeddings/"
}
switch info.UpstreamModelName {
case "ERNIE-4.0":
suffix += "completions_pro"
case "ERNIE-Bot-4":
suffix += "completions_pro"
case "ERNIE-Bot":
suffix += "completions"
case "ERNIE-Bot-turbo":
suffix += "eb-instant"
case "ERNIE-Speed":
suffix += "ernie_speed"
case "ERNIE-4.0-8K":
suffix += "completions_pro"
case "ERNIE-3.5-8K":
suffix += "completions"
case "ERNIE-3.5-8K-0205":
suffix += "ernie-3.5-8k-0205"
case "ERNIE-3.5-8K-1222":
suffix += "ernie-3.5-8k-1222"
case "ERNIE-Bot-8K":
suffix += "ernie_bot_8k"
case "ERNIE-3.5-4K-0205":
suffix += "ernie-3.5-4k-0205"
case "ERNIE-Speed-8K":
suffix += "ernie_speed"
case "ERNIE-Speed-128K":
suffix += "ernie-speed-128k"
case "ERNIE-Lite-8K-0922":
suffix += "eb-instant"
case "ERNIE-Lite-8K-0308":
suffix += "ernie-lite-8k"
case "ERNIE-Tiny-8K":
suffix += "ernie-tiny-8k"
case "BLOOMZ-7B":
suffix += "bloomz_7b1"
case "Embedding-V1":
suffix += "embedding-v1"
case "bge-large-zh":
suffix += "bge_large_zh"
case "bge-large-en":
suffix += "bge_large_en"
case "tao-8k":
suffix += "tao_8k"
default:
suffix += strings.ToLower(info.UpstreamModelName)
}
fullRequestURL := fmt.Sprintf("%s/rpc/2.0/ai_custom/v1/wenxinworkshop/%s", info.BaseUrl, suffix)
var accessToken string
var err error
if accessToken, err = getBaiduAccessToken(info.ApiKey); err != nil {
@@ -82,6 +113,10 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.Gen
}
}
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
return nil, nil
}
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
return channel.DoApiRequest(a, c, info, requestBody)
}

View File

@@ -1,20 +1,22 @@
package baidu
var ModelList = []string{
"ERNIE-3.5-8K",
"ERNIE-4.0-8K",
"ERNIE-3.5-8K",
"ERNIE-3.5-8K-0205",
"ERNIE-3.5-8K-1222",
"ERNIE-Bot-8K",
"ERNIE-3.5-4K-0205",
"ERNIE-Speed-8K",
"ERNIE-Speed-128K",
"ERNIE-Lite-8K",
"ERNIE-Lite-8K-0922",
"ERNIE-Lite-8K-0308",
"ERNIE-Tiny-8K",
"ERNIE-Character-8K",
"ERNIE-Functions-8K",
//"ERNIE-Bot-4",
//"ERNIE-Bot-8K",
//"ERNIE-Bot",
//"ERNIE-Speed",
//"ERNIE-Bot-turbo",
"BLOOMZ-7B",
"Embedding-V1",
"bge-large-zh",
"bge-large-en",
"tao-8k",
}
var ChannelName = "baidu"

View File

@@ -11,9 +11,16 @@ type BaiduMessage struct {
}
type BaiduChatRequest struct {
Messages []BaiduMessage `json:"messages"`
Stream bool `json:"stream"`
UserId string `json:"user_id,omitempty"`
Messages []BaiduMessage `json:"messages"`
Temperature float64 `json:"temperature,omitempty"`
TopP float64 `json:"top_p,omitempty"`
PenaltyScore float64 `json:"penalty_score,omitempty"`
Stream bool `json:"stream,omitempty"`
System string `json:"system,omitempty"`
DisableSearch bool `json:"disable_search,omitempty"`
EnableCitation bool `json:"enable_citation,omitempty"`
MaxOutputTokens *int `json:"max_output_tokens,omitempty"`
UserId string `json:"user_id,omitempty"`
}
type Error struct {

View File

@@ -22,17 +22,33 @@ import (
var baiduTokenStore sync.Map
func requestOpenAI2Baidu(request dto.GeneralOpenAIRequest) *BaiduChatRequest {
messages := make([]BaiduMessage, 0, len(request.Messages))
baiduRequest := BaiduChatRequest{
Temperature: request.Temperature,
TopP: request.TopP,
PenaltyScore: request.FrequencyPenalty,
Stream: request.Stream,
DisableSearch: false,
EnableCitation: false,
UserId: request.User,
}
if request.MaxTokens != 0 {
maxTokens := int(request.MaxTokens)
if request.MaxTokens == 1 {
maxTokens = 2
}
baiduRequest.MaxOutputTokens = &maxTokens
}
for _, message := range request.Messages {
messages = append(messages, BaiduMessage{
Role: message.Role,
Content: message.StringContent(),
})
}
return &BaiduChatRequest{
Messages: messages,
Stream: request.Stream,
if message.Role == "system" {
baiduRequest.System = message.StringContent()
} else {
baiduRequest.Messages = append(baiduRequest.Messages, BaiduMessage{
Role: message.Role,
Content: message.StringContent(),
})
}
}
return &baiduRequest
}
func responseBaidu2OpenAI(response *BaiduChatResponse) *dto.OpenAITextResponse {

View File

@@ -21,6 +21,11 @@ type Adaptor struct {
RequestMode int
}
func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) {
//TODO implement me
}
func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
if strings.HasPrefix(info.UpstreamModelName, "claude-3") {
a.RequestMode = RequestModeMessage
@@ -59,6 +64,10 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.Gen
}
}
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
return nil, nil
}
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
return channel.DoApiRequest(a, c, info, requestBody)
}

View File

@@ -330,22 +330,15 @@ func claudeStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.
response.Created = createdTime
response.Model = info.UpstreamModelName
jsonStr, err := json.Marshal(response)
err = service.ObjectData(c, response)
if err != nil {
common.SysError("error marshalling stream response: " + err.Error())
return true
common.SysError(err.Error())
}
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonStr)})
return true
case <-stopChan:
c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
return false
}
})
err := resp.Body.Close()
if err != nil {
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
if requestMode == RequestModeCompletion {
usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
} else {
@@ -356,6 +349,18 @@ func claudeStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.
usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, usage.PromptTokens)
}
}
if info.ShouldIncludeUsage {
response := service.GenerateFinalUsageResponse(responseId, createdTime, info.UpstreamModelName, *usage)
err := service.ObjectData(c, response)
if err != nil {
common.SysError("send final response failed: " + err.Error())
}
}
service.Done(c)
err := resp.Body.Close()
if err != nil {
return service.OpenAIErrorWrapperLocal(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
return nil, usage
}

View File

@@ -8,16 +8,24 @@ import (
"one-api/dto"
"one-api/relay/channel"
relaycommon "one-api/relay/common"
"one-api/relay/constant"
)
type Adaptor struct {
}
func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) {
}
func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
}
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
return fmt.Sprintf("%s/v1/chat", info.BaseUrl), nil
if info.RelayMode == constant.RelayModeRerank {
return fmt.Sprintf("%s/v1/rerank", info.BaseUrl), nil
} else {
return fmt.Sprintf("%s/v1/chat", info.BaseUrl), nil
}
}
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
@@ -34,11 +42,19 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
return channel.DoApiRequest(a, c, info, requestBody)
}
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
return requestConvertRerank2Cohere(request), nil
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
if info.IsStream {
err, usage = cohereStreamHandler(c, resp, info)
if info.RelayMode == constant.RelayModeRerank {
err, usage = cohereRerankHandler(c, resp, info)
} else {
err, usage = cohereHandler(c, resp, info.UpstreamModelName, info.PromptTokens)
if info.IsStream {
err, usage = cohereStreamHandler(c, resp, info)
} else {
err, usage = cohereHandler(c, resp, info.UpstreamModelName, info.PromptTokens)
}
}
return
}

View File

@@ -2,6 +2,7 @@ package cohere
var ModelList = []string{
"command-r", "command-r-plus", "command-light", "command-light-nightly", "command", "command-nightly",
"rerank-english-v3.0", "rerank-multilingual-v3.0", "rerank-english-v2.0", "rerank-multilingual-v2.0",
}
var ChannelName = "cohere"

View File

@@ -1,5 +1,7 @@
package cohere
import "one-api/dto"
type CohereRequest struct {
Model string `json:"model"`
ChatHistory []ChatHistory `json:"chat_history"`
@@ -28,6 +30,19 @@ type CohereResponseResult struct {
Meta CohereMeta `json:"meta"`
}
type CohereRerankRequest struct {
Documents []any `json:"documents"`
Query string `json:"query"`
Model string `json:"model"`
TopN int `json:"top_n"`
ReturnDocuments bool `json:"return_documents"`
}
type CohereRerankResponseResult struct {
Results []dto.RerankResponseDocument `json:"results"`
Meta CohereMeta `json:"meta"`
}
type CohereMeta struct {
//Tokens CohereTokens `json:"tokens"`
BilledUnits CohereBilledUnits `json:"billed_units"`

View File

@@ -47,6 +47,20 @@ func requestOpenAI2Cohere(textRequest dto.GeneralOpenAIRequest) *CohereRequest {
return &cohereReq
}
func requestConvertRerank2Cohere(rerankRequest dto.RerankRequest) *CohereRerankRequest {
if rerankRequest.TopN == 0 {
rerankRequest.TopN = 1
}
cohereReq := CohereRerankRequest{
Query: rerankRequest.Query,
Documents: rerankRequest.Documents,
Model: rerankRequest.Model,
TopN: rerankRequest.TopN,
ReturnDocuments: true,
}
return &cohereReq
}
func stopReasonCohere2OpenAI(reason string) string {
switch reason {
case "COMPLETE":
@@ -194,3 +208,42 @@ func cohereHandler(c *gin.Context, resp *http.Response, modelName string, prompt
_, err = c.Writer.Write(jsonResponse)
return nil, &usage
}
func cohereRerankHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
}
err = resp.Body.Close()
if err != nil {
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
var cohereResp CohereRerankResponseResult
err = json.Unmarshal(responseBody, &cohereResp)
if err != nil {
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
}
usage := dto.Usage{}
if cohereResp.Meta.BilledUnits.InputTokens == 0 {
usage.PromptTokens = info.PromptTokens
usage.CompletionTokens = 0
usage.TotalTokens = info.PromptTokens
} else {
usage.PromptTokens = cohereResp.Meta.BilledUnits.InputTokens
usage.CompletionTokens = cohereResp.Meta.BilledUnits.OutputTokens
usage.TotalTokens = cohereResp.Meta.BilledUnits.InputTokens + cohereResp.Meta.BilledUnits.OutputTokens
}
var rerankResp dto.RerankResponse
rerankResp.Results = cohereResp.Results
rerankResp.Usage = usage
jsonResponse, err := json.Marshal(rerankResp)
if err != nil {
return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
}
c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode)
_, err = c.Writer.Write(jsonResponse)
return nil, &usage
}

View File

@@ -0,0 +1,65 @@
package dify
import (
"errors"
"fmt"
"github.com/gin-gonic/gin"
"io"
"net/http"
"one-api/dto"
"one-api/relay/channel"
relaycommon "one-api/relay/common"
)
type Adaptor struct {
}
func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) {
//TODO implement me
}
func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
}
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
return fmt.Sprintf("%s/v1/chat-messages", info.BaseUrl), nil
}
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
channel.SetupApiRequestHeader(info, c, req)
req.Header.Set("Authorization", "Bearer "+info.ApiKey)
return nil
}
func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) {
if request == nil {
return nil, errors.New("request is nil")
}
return requestOpenAI2Dify(*request), nil
}
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
return nil, nil
}
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
return channel.DoApiRequest(a, c, info, requestBody)
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
if info.IsStream {
err, usage = difyStreamHandler(c, resp, info)
} else {
err, usage = difyHandler(c, resp, info)
}
return
}
func (a *Adaptor) GetModelList() []string {
return ModelList
}
func (a *Adaptor) GetChannelName() string {
return ChannelName
}

View File

@@ -0,0 +1,5 @@
package dify
var ModelList []string
var ChannelName = "dify"

35
relay/channel/dify/dto.go Normal file
View File

@@ -0,0 +1,35 @@
package dify
import "one-api/dto"
type DifyChatRequest struct {
Inputs map[string]interface{} `json:"inputs"`
Query string `json:"query"`
ResponseMode string `json:"response_mode"`
User string `json:"user"`
AutoGenerateName bool `json:"auto_generate_name"`
}
type DifyMetaData struct {
Usage dto.Usage `json:"usage"`
}
type DifyData struct {
WorkflowId string `json:"workflow_id"`
NodeId string `json:"node_id"`
}
type DifyChatCompletionResponse struct {
ConversationId string `json:"conversation_id"`
Answer string `json:"answer"`
CreateAt int64 `json:"create_at"`
MetaData DifyMetaData `json:"metadata"`
}
type DifyChunkChatCompletionResponse struct {
Event string `json:"event"`
ConversationId string `json:"conversation_id"`
Answer string `json:"answer"`
Data DifyData `json:"data"`
MetaData DifyMetaData `json:"metadata"`
}

View File

@@ -0,0 +1,156 @@
package dify
import (
"bufio"
"encoding/json"
"github.com/gin-gonic/gin"
"io"
"net/http"
"one-api/common"
"one-api/constant"
"one-api/dto"
relaycommon "one-api/relay/common"
"one-api/service"
"strings"
)
func requestOpenAI2Dify(request dto.GeneralOpenAIRequest) *DifyChatRequest {
content := ""
for _, message := range request.Messages {
if message.Role == "system" {
content += "SYSTEM: \n" + message.StringContent() + "\n"
} else if message.Role == "assistant" {
content += "ASSISTANT: \n" + message.StringContent() + "\n"
} else {
content += "USER: \n" + message.StringContent() + "\n"
}
}
mode := "blocking"
if request.Stream {
mode = "streaming"
}
user := request.User
if user == "" {
user = "api-user"
}
return &DifyChatRequest{
Inputs: make(map[string]interface{}),
Query: content,
ResponseMode: mode,
User: user,
AutoGenerateName: false,
}
}
func streamResponseDify2OpenAI(difyResponse DifyChunkChatCompletionResponse) *dto.ChatCompletionsStreamResponse {
response := dto.ChatCompletionsStreamResponse{
Object: "chat.completion.chunk",
Created: common.GetTimestamp(),
Model: "dify",
}
var choice dto.ChatCompletionsStreamResponseChoice
if constant.DifyDebug && difyResponse.Event == "workflow_started" {
choice.Delta.SetContentString("Workflow: " + difyResponse.Data.WorkflowId + "\n")
} else if constant.DifyDebug && difyResponse.Event == "node_started" {
choice.Delta.SetContentString("Node: " + difyResponse.Data.NodeId + "\n")
} else if difyResponse.Event == "message" {
choice.Delta.SetContentString(difyResponse.Answer)
}
response.Choices = append(response.Choices, choice)
return &response
}
func difyStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
var responseText string
usage := &dto.Usage{}
scanner := bufio.NewScanner(resp.Body)
scanner.Split(bufio.ScanLines)
service.SetEventStreamHeaders(c)
for scanner.Scan() {
data := scanner.Text()
if len(data) < 5 || !strings.HasPrefix(data, "data:") {
continue
}
data = strings.TrimPrefix(data, "data:")
var difyResponse DifyChunkChatCompletionResponse
err := json.Unmarshal([]byte(data), &difyResponse)
if err != nil {
common.SysError("error unmarshalling stream response: " + err.Error())
continue
}
var openaiResponse dto.ChatCompletionsStreamResponse
if difyResponse.Event == "message_end" {
usage = &difyResponse.MetaData.Usage
break
} else if difyResponse.Event == "error" {
break
} else {
openaiResponse = *streamResponseDify2OpenAI(difyResponse)
if len(openaiResponse.Choices) != 0 {
responseText += openaiResponse.Choices[0].Delta.GetContentString()
}
}
err = service.ObjectData(c, openaiResponse)
if err != nil {
common.SysError(err.Error())
}
}
if err := scanner.Err(); err != nil {
common.SysError("error reading stream: " + err.Error())
}
service.Done(c)
err := resp.Body.Close()
if err != nil {
//return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
common.SysError("close_response_body_failed: " + err.Error())
}
if usage.TotalTokens == 0 {
usage.PromptTokens = info.PromptTokens
usage.CompletionTokens, _ = service.CountTokenText("gpt-3.5-turbo", responseText)
usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
}
return nil, usage
}
func difyHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
var difyResponse DifyChatCompletionResponse
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
}
err = resp.Body.Close()
if err != nil {
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
err = json.Unmarshal(responseBody, &difyResponse)
if err != nil {
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
}
fullTextResponse := dto.OpenAITextResponse{
Id: difyResponse.ConversationId,
Object: "chat.completion",
Created: common.GetTimestamp(),
Usage: difyResponse.MetaData.Usage,
}
content, _ := json.Marshal(difyResponse.Answer)
choice := dto.OpenAITextResponseChoice{
Index: 0,
Message: dto.Message{
Role: "assistant",
Content: content,
},
FinishReason: "stop",
}
fullTextResponse.Choices = append(fullTextResponse.Choices, choice)
jsonResponse, err := json.Marshal(fullTextResponse)
if err != nil {
return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
}
c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode)
_, err = c.Writer.Write(jsonResponse)
return nil, &difyResponse.MetaData.Usage
}

View File

@@ -9,12 +9,14 @@ import (
"one-api/dto"
"one-api/relay/channel"
relaycommon "one-api/relay/common"
"one-api/service"
)
type Adaptor struct {
}
func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) {
}
func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
}
@@ -56,15 +58,17 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.Gen
return CovertGemini2OpenAI(*request), nil
}
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
return nil, nil
}
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
return channel.DoApiRequest(a, c, info, requestBody)
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
if info.IsStream {
var responseText string
err, responseText = geminiChatStreamHandler(c, resp, info)
usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
err, usage = geminiChatStreamHandler(c, resp, info)
} else {
err, usage = geminiChatHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
}

View File

@@ -59,4 +59,11 @@ type GeminiChatPromptFeedback struct {
type GeminiChatResponse struct {
Candidates []GeminiChatCandidate `json:"candidates"`
PromptFeedback GeminiChatPromptFeedback `json:"promptFeedback"`
UsageMetadata GeminiUsageMetadata `json:"usageMetadata"`
}
type GeminiUsageMetadata struct {
PromptTokenCount int `json:"promptTokenCount"`
CandidatesTokenCount int `json:"candidatesTokenCount"`
TotalTokenCount int `json:"totalTokenCount"`
}

View File

@@ -5,6 +5,7 @@ import (
"encoding/json"
"fmt"
"io"
"log"
"net/http"
"one-api/common"
"one-api/constant"
@@ -162,8 +163,12 @@ func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) *dto.Ch
return &response
}
func geminiChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, string) {
func geminiChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
responseText := ""
responseJson := ""
id := fmt.Sprintf("chatcmpl-%s", common.GetUUID())
createAt := common.GetTimestamp()
var usage = &dto.Usage{}
dataChan := make(chan string, 5)
stopChan := make(chan bool, 2)
scanner := bufio.NewScanner(resp.Body)
@@ -182,6 +187,7 @@ func geminiChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycom
go func() {
for scanner.Scan() {
data := scanner.Text()
responseJson += data
data = strings.TrimSpace(data)
if !strings.HasPrefix(data, "\"text\": \"") {
continue
@@ -216,10 +222,10 @@ func geminiChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycom
var choice dto.ChatCompletionsStreamResponseChoice
choice.Delta.SetContentString(dummy.Content)
response := dto.ChatCompletionsStreamResponse{
Id: fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
Id: id,
Object: "chat.completion.chunk",
Created: common.GetTimestamp(),
Model: "gemini-pro",
Created: createAt,
Model: info.UpstreamModelName,
Choices: []dto.ChatCompletionsStreamResponseChoice{choice},
}
jsonResponse, err := json.Marshal(response)
@@ -230,15 +236,34 @@ func geminiChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycom
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
return true
case <-stopChan:
c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
return false
}
})
err := resp.Body.Close()
var geminiChatResponses []GeminiChatResponse
err := json.Unmarshal([]byte(responseJson), &geminiChatResponses)
if err != nil {
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), ""
log.Printf("cannot get gemini usage: %s", err.Error())
usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
} else {
for _, response := range geminiChatResponses {
usage.PromptTokens = response.UsageMetadata.PromptTokenCount
usage.CompletionTokens = response.UsageMetadata.CandidatesTokenCount
}
usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
}
return nil, responseText
if info.ShouldIncludeUsage {
response := service.GenerateFinalUsageResponse(id, createAt, info.UpstreamModelName, *usage)
err := service.ObjectData(c, response)
if err != nil {
common.SysError("send final response failed: " + err.Error())
}
}
service.Done(c)
err = resp.Body.Close()
if err != nil {
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), usage
}
return nil, usage
}
func geminiChatHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
@@ -267,11 +292,10 @@ func geminiChatHandler(c *gin.Context, resp *http.Response, promptTokens int, mo
}, nil
}
fullTextResponse := responseGeminiChat2OpenAI(&geminiResponse)
completionTokens, _ := service.CountTokenText(geminiResponse.GetResponseText(), model)
usage := dto.Usage{
PromptTokens: promptTokens,
CompletionTokens: completionTokens,
TotalTokens: promptTokens + completionTokens,
PromptTokens: geminiResponse.UsageMetadata.PromptTokenCount,
CompletionTokens: geminiResponse.UsageMetadata.CandidatesTokenCount,
TotalTokens: geminiResponse.UsageMetadata.TotalTokenCount,
}
fullTextResponse.Usage = usage
jsonResponse, err := json.Marshal(fullTextResponse)

View File

@@ -0,0 +1,64 @@
package jina
import (
"errors"
"fmt"
"github.com/gin-gonic/gin"
"io"
"net/http"
"one-api/dto"
"one-api/relay/channel"
relaycommon "one-api/relay/common"
"one-api/relay/constant"
)
type Adaptor struct {
}
func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) {
}
func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
}
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
if info.RelayMode == constant.RelayModeRerank {
return fmt.Sprintf("%s/v1/rerank", info.BaseUrl), nil
} else if info.RelayMode == constant.RelayModeEmbeddings {
return fmt.Sprintf("%s/v1/embeddings ", info.BaseUrl), nil
}
return "", errors.New("invalid relay mode")
}
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
channel.SetupApiRequestHeader(info, c, req)
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", info.ApiKey))
return nil
}
func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) {
return request, nil
}
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
return channel.DoApiRequest(a, c, info, requestBody)
}
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
return request, nil
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
if info.RelayMode == constant.RelayModeRerank {
err, usage = jinaRerankHandler(c, resp)
}
return
}
func (a *Adaptor) GetModelList() []string {
return ModelList
}
func (a *Adaptor) GetChannelName() string {
return ChannelName
}

View File

@@ -0,0 +1,8 @@
package jina
var ModelList = []string{
"jina-clip-v1",
"jina-reranker-v2-base-multilingual",
}
var ChannelName = "jina"

View File

@@ -0,0 +1,35 @@
package jina
import (
"encoding/json"
"github.com/gin-gonic/gin"
"io"
"net/http"
"one-api/dto"
"one-api/service"
)
func jinaRerankHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
}
err = resp.Body.Close()
if err != nil {
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
var jinaResp dto.RerankResponse
err = json.Unmarshal(responseBody, &jinaResp)
if err != nil {
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
}
jsonResponse, err := json.Marshal(jinaResp)
if err != nil {
return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
}
c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode)
_, err = c.Writer.Write(jsonResponse)
return nil, &jinaResp.Usage
}

View File

@@ -16,6 +16,9 @@ import (
type Adaptor struct {
}
func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) {
}
func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
}
@@ -45,6 +48,10 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.Gen
}
}
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
return nil, nil
}
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
return channel.DoApiRequest(a, c, info, requestBody)
}
@@ -52,8 +59,10 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
if info.IsStream {
var responseText string
err, responseText, _ = openai.OpenaiStreamHandler(c, resp, info)
usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
err, usage, responseText, _ = openai.OpenaiStreamHandler(c, resp, info)
if usage == nil || usage.TotalTokens == 0 || (usage.PromptTokens+usage.CompletionTokens) == 0 {
usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
}
} else {
if info.RelayMode == relayconstant.RelayModeEmbeddings {
err, usage = ollamaEmbeddingHandler(c, resp, info.PromptTokens, info.UpstreamModelName, info.RelayMode)

View File

@@ -22,6 +22,13 @@ type Adaptor struct {
ChannelType int
}
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
return nil, nil
}
func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) {
}
func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
a.ChannelType = info.ChannelType
}
@@ -82,9 +89,11 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
if info.IsStream {
var responseText string
var toolCount int
err, responseText, toolCount = OpenaiStreamHandler(c, resp, info)
usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
usage.CompletionTokens += toolCount * 7
err, usage, responseText, toolCount = OpenaiStreamHandler(c, resp, info)
if usage == nil || usage.TotalTokens == 0 || (usage.PromptTokens+usage.CompletionTokens) == 0 {
usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
usage.CompletionTokens += toolCount * 7
}
} else {
err, usage = OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
}

View File

@@ -18,9 +18,10 @@ import (
"time"
)
func OpenaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, string, int) {
func OpenaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage, string, int) {
//checkSensitive := constant.ShouldCheckCompletionSensitive()
var responseTextBuilder strings.Builder
var usage dto.Usage
toolCount := 0
scanner := bufio.NewScanner(resp.Body)
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
@@ -62,17 +63,24 @@ func OpenaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.
streamItems = append(streamItems, data)
}
}
// 计算token
streamResp := "[" + strings.Join(streamItems, ",") + "]"
switch info.RelayMode {
case relayconstant.RelayModeChatCompletions:
var streamResponses []dto.ChatCompletionsStreamResponseSimple
err := json.Unmarshal(common.StringToByteSlice(streamResp), &streamResponses)
if err != nil {
// 一次性解析失败,逐个解析
common.SysError("error unmarshalling stream response: " + err.Error())
for _, item := range streamItems {
var streamResponse dto.ChatCompletionsStreamResponseSimple
err := json.Unmarshal(common.StringToByteSlice(item), &streamResponse)
if err == nil {
if streamResponse.Usage != nil {
if streamResponse.Usage.TotalTokens != 0 {
usage = *streamResponse.Usage
}
}
for _, choice := range streamResponse.Choices {
responseTextBuilder.WriteString(choice.Delta.GetContentString())
if choice.Delta.ToolCalls != nil {
@@ -89,6 +97,11 @@ func OpenaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.
}
} else {
for _, streamResponse := range streamResponses {
if streamResponse.Usage != nil {
if streamResponse.Usage.TotalTokens != 0 {
usage = *streamResponse.Usage
}
}
for _, choice := range streamResponse.Choices {
responseTextBuilder.WriteString(choice.Delta.GetContentString())
if choice.Delta.ToolCalls != nil {
@@ -107,6 +120,7 @@ func OpenaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.
var streamResponses []dto.CompletionsStreamResponse
err := json.Unmarshal(common.StringToByteSlice(streamResp), &streamResponses)
if err != nil {
// 一次性解析失败,逐个解析
common.SysError("error unmarshalling stream response: " + err.Error())
for _, item := range streamItems {
var streamResponse dto.CompletionsStreamResponse
@@ -133,13 +147,19 @@ func OpenaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.
}()
service.SetEventStreamHeaders(c)
isFirst := true
ticker := time.NewTicker(time.Duration(constant.StreamingTimeout) * time.Second)
defer ticker.Stop()
c.Stream(func(w io.Writer) bool {
select {
case <-ticker.C:
common.LogError(c, "reading data from upstream timeout")
return false
case data := <-dataChan:
if isFirst {
isFirst = false
info.FirstResponseTime = time.Now()
}
ticker.Reset(time.Duration(constant.StreamingTimeout) * time.Second)
if strings.HasPrefix(data, "data: [DONE]") {
data = data[:12]
}
@@ -153,10 +173,10 @@ func OpenaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.
})
err := resp.Body.Close()
if err != nil {
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), "", toolCount
return service.OpenAIErrorWrapperLocal(err, "close_response_body_failed", http.StatusInternalServerError), nil, "", toolCount
}
wg.Wait()
return nil, responseTextBuilder.String(), toolCount
return nil, &usage, responseTextBuilder.String(), toolCount
}
func OpenaiHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {

View File

@@ -15,6 +15,11 @@ import (
type Adaptor struct {
}
func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) {
//TODO implement me
}
func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
}
@@ -35,6 +40,10 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.Gen
return request, nil
}
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
return nil, nil
}
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
return channel.DoApiRequest(a, c, info, requestBody)
}

View File

@@ -16,6 +16,11 @@ import (
type Adaptor struct {
}
func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) {
//TODO implement me
}
func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
}
@@ -39,6 +44,10 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.Gen
return requestOpenAI2Perplexity(*request), nil
}
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
return nil, nil
}
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
return channel.DoApiRequest(a, c, info, requestBody)
}
@@ -46,8 +55,10 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
if info.IsStream {
var responseText string
err, responseText, _ = openai.OpenaiStreamHandler(c, resp, info)
usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
err, usage, responseText, _ = openai.OpenaiStreamHandler(c, resp, info)
if usage == nil || usage.TotalTokens == 0 || (usage.PromptTokens+usage.CompletionTokens) == 0 {
usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
}
} else {
err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
}

View File

@@ -6,28 +6,44 @@ import (
"github.com/gin-gonic/gin"
"io"
"net/http"
"one-api/common"
"one-api/dto"
"one-api/relay/channel"
relaycommon "one-api/relay/common"
"one-api/service"
"strconv"
"strings"
)
type Adaptor struct {
Sign string
Sign string
AppID int64
Action string
Version string
Timestamp int64
}
func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) {
//TODO implement me
}
func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
a.Action = "ChatCompletions"
a.Version = "2023-09-01"
a.Timestamp = common.GetTimestamp()
}
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
return fmt.Sprintf("%s/hyllm/v1/chat/completions", info.BaseUrl), nil
return fmt.Sprintf("%s/", info.BaseUrl), nil
}
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
channel.SetupApiRequestHeader(info, c, req)
req.Header.Set("Authorization", a.Sign)
req.Header.Set("X-TC-Action", info.UpstreamModelName)
req.Header.Set("X-TC-Action", a.Action)
req.Header.Set("X-TC-Version", a.Version)
req.Header.Set("X-TC-Timestamp", strconv.FormatInt(a.Timestamp, 10))
return nil
}
@@ -38,17 +54,20 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.Gen
apiKey := c.Request.Header.Get("Authorization")
apiKey = strings.TrimPrefix(apiKey, "Bearer ")
appId, secretId, secretKey, err := parseTencentConfig(apiKey)
a.AppID = appId
if err != nil {
return nil, err
}
tencentRequest := requestOpenAI2Tencent(*request)
tencentRequest.AppId = appId
tencentRequest.SecretId = secretId
tencentRequest := requestOpenAI2Tencent(a, *request)
// we have to calculate the sign here
a.Sign = getTencentSign(*tencentRequest, secretKey)
a.Sign = getTencentSign(*tencentRequest, a, secretId, secretKey)
return tencentRequest, nil
}
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
return nil, nil
}
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
return channel.DoApiRequest(a, c, info, requestBody)
}

View File

@@ -1,9 +1,10 @@
package tencent
var ModelList = []string{
"ChatPro",
"ChatStd",
"hunyuan",
"hunyuan-lite",
"hunyuan-standard",
"hunyuan-standard-256K",
"hunyuan-pro",
}
var ChannelName = "tencent"

View File

@@ -1,62 +1,75 @@
package tencent
import "one-api/dto"
type TencentMessage struct {
Role string `json:"role"`
Content string `json:"content"`
Role string `json:"Role"`
Content string `json:"Content"`
}
type TencentChatRequest struct {
AppId int64 `json:"app_id"` // 腾讯云账号的 APPID
SecretId string `json:"secret_id"` // 官网 SecretId
// Timestamp当前 UNIX 时间戳,单位为秒,可记录发起 API 请求的时间。
// 例如1529223702如果与当前时间相差过大会引起签名过期错误
Timestamp int64 `json:"timestamp"`
// Expired 签名的有效期,是一个符合 UNIX Epoch 时间戳规范的数值,
// 单位为秒Expired 必须大于 Timestamp 且 Expired-Timestamp 小于90天
Expired int64 `json:"expired"`
QueryID string `json:"query_id"` //请求 Id用于问题排查
// Temperature 较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定
// 默认 1.0,取值区间为[0.0,2.0],非必要不建议使用,不合理的取值会影响效果
// 建议该参数和 top_p 只设置1个不要同时更改 top_p
Temperature float64 `json:"temperature"`
// TopP 影响输出文本的多样性,取值越大,生成文本的多样性越强
// 默认1.0,取值区间为[0.0, 1.0],非必要不建议使用, 不合理的取值会影响效果
// 建议该参数和 temperature 只设置1个不要同时更改
TopP float64 `json:"top_p"`
// Stream 0同步1流式 默认协议SSE)
// 同步请求超时60s如果内容较长建议使用流式
Stream int `json:"stream"`
// Messages 会话内容, 长度最多为40, 按对话时间从旧到新在数组中排列
// 输入 content 总数最大支持 3000 token。
Messages []TencentMessage `json:"messages"`
Model string `json:"model"` // 模型名称
// 模型名称,可选值包括 hunyuan-lite、hunyuan-standard、hunyuan-standard-256K、hunyuan-pro。
// 各模型介绍请阅读 [产品概述](https://cloud.tencent.com/document/product/1729/104753) 中的说明。
//
// 注意:
// 不同的模型计费不同,请根据 [购买指南](https://cloud.tencent.com/document/product/1729/97731) 按需调用。
Model *string `json:"Model"`
// 聊天上下文信息。
// 说明:
// 1. 长度最多为 40按对话时间从旧到新在数组中排列。
// 2. Message.Role 可选值system、user、assistant。
// 其中system 角色可选如存在则必须位于列表的最开始。user 和 assistant 需交替出现(一问一答),以 user 提问开始和结束,且 Content 不能为空。Role 的顺序示例:[system可选 user assistant user assistant user ...]。
// 3. Messages 中 Content 总长度不能超过模型输入长度上限(可参考 [产品概述](https://cloud.tencent.com/document/product/1729/104753) 文档),超过则会截断最前面的内容,只保留尾部内容。
Messages []*TencentMessage `json:"Messages"`
// 流式调用开关。
// 说明:
// 1. 未传值时默认为非流式调用false
// 2. 流式调用时以 SSE 协议增量返回结果(返回值取 Choices[n].Delta 中的值,需要拼接增量数据才能获得完整结果)。
// 3. 非流式调用时:
// 调用方式与普通 HTTP 请求无异。
// 接口响应耗时较长,**如需更低时延建议设置为 true**。
// 只返回一次最终结果(返回值取 Choices[n].Message 中的值)。
//
// 注意:
// 通过 SDK 调用时,流式和非流式调用需用**不同的方式**获取返回值,具体参考 SDK 中的注释或示例(在各语言 SDK 代码仓库的 examples/hunyuan/v20230901/ 目录中)。
Stream *bool `json:"Stream,omitempty"`
// 说明:
// 1. 影响输出文本的多样性,取值越大,生成文本的多样性越强。
// 2. 取值区间为 [0.0, 1.0],未传值时使用各模型推荐值。
// 3. 非必要不建议使用,不合理的取值会影响效果。
TopP *float64 `json:"TopP,omitempty"`
// 说明:
// 1. 较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定。
// 2. 取值区间为 [0.0, 2.0],未传值时使用各模型推荐值。
// 3. 非必要不建议使用,不合理的取值会影响效果。
Temperature *float64 `json:"Temperature,omitempty"`
}
type TencentError struct {
Code int `json:"code"`
Message string `json:"message"`
Code int `json:"Code"`
Message string `json:"Message"`
}
type TencentUsage struct {
InputTokens int `json:"input_tokens"`
OutputTokens int `json:"output_tokens"`
TotalTokens int `json:"total_tokens"`
PromptTokens int `json:"PromptTokens"`
CompletionTokens int `json:"CompletionTokens"`
TotalTokens int `json:"TotalTokens"`
}
type TencentResponseChoices struct {
FinishReason string `json:"finish_reason,omitempty"` // 流式结束标志位,为 stop 则表示尾包
Messages TencentMessage `json:"messages,omitempty"` // 内容,同步模式返回内容,流模式为 null 输出 content 内容总数最多支持 1024token。
Delta TencentMessage `json:"delta,omitempty"` // 内容,流模式返回内容,同步模式为 null 输出 content 内容总数最多支持 1024token。
FinishReason string `json:"FinishReason,omitempty"` // 流式结束标志位,为 stop 则表示尾包
Messages TencentMessage `json:"Message,omitempty"` // 内容,同步模式返回内容,流模式为 null 输出 content 内容总数最多支持 1024token。
Delta TencentMessage `json:"Delta,omitempty"` // 内容,流模式返回内容,同步模式为 null 输出 content 内容总数最多支持 1024token。
}
type TencentChatResponse struct {
Choices []TencentResponseChoices `json:"choices,omitempty"` // 结果
Created string `json:"created,omitempty"` // unix 时间戳的字符串
Id string `json:"id,omitempty"` // 会话 id
Usage dto.Usage `json:"usage,omitempty"` // token 数量
Error TencentError `json:"error,omitempty"` // 错误信息 注意:此字段可能返回 null表示取不到有效值
Note string `json:"note,omitempty"` // 注释
ReqID string `json:"req_id,omitempty"` // 唯一请求 Id每次请求都会返回。用于反馈接口入参
Choices []TencentResponseChoices `json:"Choices,omitempty"` // 结果
Created int64 `json:"Created,omitempty"` // unix 时间戳的字符串
Id string `json:"Id,omitempty"` // 会话 id
Usage TencentUsage `json:"Usage,omitempty"` // token 数量
Error TencentError `json:"Error,omitempty"` // 错误信息 注意:此字段可能返回 null表示取不到有效值
Note string `json:"Note,omitempty"` // 注释
ReqID string `json:"Req_id,omitempty"` // 唯一请求 Id每次请求都会返回。用于反馈接口入参
}
type TencentChatResponseSB struct {
Response TencentChatResponse `json:"Response,omitempty"`
}

View File

@@ -3,8 +3,8 @@ package tencent
import (
"bufio"
"crypto/hmac"
"crypto/sha1"
"encoding/base64"
"crypto/sha256"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
@@ -15,54 +15,46 @@ import (
"one-api/dto"
relaycommon "one-api/relay/common"
"one-api/service"
"sort"
"strconv"
"strings"
"time"
)
// https://cloud.tencent.com/document/product/1729/97732
func requestOpenAI2Tencent(request dto.GeneralOpenAIRequest) *TencentChatRequest {
messages := make([]TencentMessage, 0, len(request.Messages))
func requestOpenAI2Tencent(a *Adaptor, request dto.GeneralOpenAIRequest) *TencentChatRequest {
messages := make([]*TencentMessage, 0, len(request.Messages))
for i := 0; i < len(request.Messages); i++ {
message := request.Messages[i]
if message.Role == "system" {
messages = append(messages, TencentMessage{
Role: "user",
Content: message.StringContent(),
})
messages = append(messages, TencentMessage{
Role: "assistant",
Content: "Okay",
})
continue
}
messages = append(messages, TencentMessage{
messages = append(messages, &TencentMessage{
Content: message.StringContent(),
Role: message.Role,
})
}
stream := 0
if request.Stream {
stream = 1
var req = TencentChatRequest{
Stream: &request.Stream,
Messages: messages,
Model: &request.Model,
}
return &TencentChatRequest{
Timestamp: common.GetTimestamp(),
Expired: common.GetTimestamp() + 24*60*60,
QueryID: common.GetUUID(),
Temperature: request.Temperature,
TopP: request.TopP,
Stream: stream,
Messages: messages,
Model: request.Model,
if request.TopP != 0 {
req.TopP = &request.TopP
}
if request.Temperature != 0 {
req.Temperature = &request.Temperature
}
return &req
}
func responseTencent2OpenAI(response *TencentChatResponse) *dto.OpenAITextResponse {
fullTextResponse := dto.OpenAITextResponse{
Id: response.Id,
Object: "chat.completion",
Created: common.GetTimestamp(),
Usage: response.Usage,
Usage: dto.Usage{
PromptTokens: response.Usage.PromptTokens,
CompletionTokens: response.Usage.CompletionTokens,
TotalTokens: response.Usage.TotalTokens,
},
}
if len(response.Choices) > 0 {
content, _ := json.Marshal(response.Choices[0].Messages.Content)
@@ -99,69 +91,51 @@ func streamResponseTencent2OpenAI(TencentResponse *TencentChatResponse) *dto.Cha
func tencentStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, string) {
var responseText string
scanner := bufio.NewScanner(resp.Body)
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
if atEOF && len(data) == 0 {
return 0, nil, nil
}
if i := strings.Index(string(data), "\n"); i >= 0 {
return i + 1, data[0:i], nil
}
if atEOF {
return len(data), data, nil
}
return 0, nil, nil
})
dataChan := make(chan string)
stopChan := make(chan bool)
go func() {
for scanner.Scan() {
data := scanner.Text()
if len(data) < 5 { // ignore blank line or wrong format
continue
}
if data[:5] != "data:" {
continue
}
data = data[5:]
dataChan <- data
}
stopChan <- true
}()
scanner.Split(bufio.ScanLines)
service.SetEventStreamHeaders(c)
c.Stream(func(w io.Writer) bool {
select {
case data := <-dataChan:
var TencentResponse TencentChatResponse
err := json.Unmarshal([]byte(data), &TencentResponse)
if err != nil {
common.SysError("error unmarshalling stream response: " + err.Error())
return true
}
response := streamResponseTencent2OpenAI(&TencentResponse)
if len(response.Choices) != 0 {
responseText += response.Choices[0].Delta.GetContentString()
}
jsonResponse, err := json.Marshal(response)
if err != nil {
common.SysError("error marshalling stream response: " + err.Error())
return true
}
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
return true
case <-stopChan:
c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
return false
for scanner.Scan() {
data := scanner.Text()
if len(data) < 5 || !strings.HasPrefix(data, "data:") {
continue
}
})
data = strings.TrimPrefix(data, "data:")
var tencentResponse TencentChatResponse
err := json.Unmarshal([]byte(data), &tencentResponse)
if err != nil {
common.SysError("error unmarshalling stream response: " + err.Error())
continue
}
response := streamResponseTencent2OpenAI(&tencentResponse)
if len(response.Choices) != 0 {
responseText += response.Choices[0].Delta.GetContentString()
}
err = service.ObjectData(c, response)
if err != nil {
common.SysError(err.Error())
}
}
if err := scanner.Err(); err != nil {
common.SysError("error reading stream: " + err.Error())
}
service.Done(c)
err := resp.Body.Close()
if err != nil {
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), ""
}
return nil, responseText
}
func tencentHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
var TencentResponse TencentChatResponse
var tencentSb TencentChatResponseSB
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
@@ -170,20 +144,20 @@ func tencentHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithSt
if err != nil {
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
err = json.Unmarshal(responseBody, &TencentResponse)
err = json.Unmarshal(responseBody, &tencentSb)
if err != nil {
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
}
if TencentResponse.Error.Code != 0 {
if tencentSb.Response.Error.Code != 0 {
return &dto.OpenAIErrorWithStatusCode{
Error: dto.OpenAIError{
Message: TencentResponse.Error.Message,
Code: TencentResponse.Error.Code,
Message: tencentSb.Response.Error.Message,
Code: tencentSb.Response.Error.Code,
},
StatusCode: resp.StatusCode,
}, nil
}
fullTextResponse := responseTencent2OpenAI(&TencentResponse)
fullTextResponse := responseTencent2OpenAI(&tencentSb.Response)
jsonResponse, err := json.Marshal(fullTextResponse)
if err != nil {
return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
@@ -206,29 +180,62 @@ func parseTencentConfig(config string) (appId int64, secretId string, secretKey
return
}
func getTencentSign(req TencentChatRequest, secretKey string) string {
params := make([]string, 0)
params = append(params, "app_id="+strconv.FormatInt(req.AppId, 10))
params = append(params, "secret_id="+req.SecretId)
params = append(params, "timestamp="+strconv.FormatInt(req.Timestamp, 10))
params = append(params, "query_id="+req.QueryID)
params = append(params, "temperature="+strconv.FormatFloat(req.Temperature, 'f', -1, 64))
params = append(params, "top_p="+strconv.FormatFloat(req.TopP, 'f', -1, 64))
params = append(params, "stream="+strconv.Itoa(req.Stream))
params = append(params, "expired="+strconv.FormatInt(req.Expired, 10))
var messageStr string
for _, msg := range req.Messages {
messageStr += fmt.Sprintf(`{"role":"%s","content":"%s"},`, msg.Role, msg.Content)
}
messageStr = strings.TrimSuffix(messageStr, ",")
params = append(params, "messages=["+messageStr+"]")
sort.Sort(sort.StringSlice(params))
url := "hunyuan.cloud.tencent.com/hyllm/v1/chat/completions?" + strings.Join(params, "&")
mac := hmac.New(sha1.New, []byte(secretKey))
signURL := url
mac.Write([]byte(signURL))
sign := mac.Sum([]byte(nil))
return base64.StdEncoding.EncodeToString(sign)
func sha256hex(s string) string {
b := sha256.Sum256([]byte(s))
return hex.EncodeToString(b[:])
}
func hmacSha256(s, key string) string {
hashed := hmac.New(sha256.New, []byte(key))
hashed.Write([]byte(s))
return string(hashed.Sum(nil))
}
func getTencentSign(req TencentChatRequest, adaptor *Adaptor, secId, secKey string) string {
// build canonical request string
host := "hunyuan.tencentcloudapi.com"
httpRequestMethod := "POST"
canonicalURI := "/"
canonicalQueryString := ""
canonicalHeaders := fmt.Sprintf("content-type:%s\nhost:%s\nx-tc-action:%s\n",
"application/json", host, strings.ToLower(adaptor.Action))
signedHeaders := "content-type;host;x-tc-action"
payload, _ := json.Marshal(req)
hashedRequestPayload := sha256hex(string(payload))
canonicalRequest := fmt.Sprintf("%s\n%s\n%s\n%s\n%s\n%s",
httpRequestMethod,
canonicalURI,
canonicalQueryString,
canonicalHeaders,
signedHeaders,
hashedRequestPayload)
// build string to sign
algorithm := "TC3-HMAC-SHA256"
requestTimestamp := strconv.FormatInt(adaptor.Timestamp, 10)
timestamp, _ := strconv.ParseInt(requestTimestamp, 10, 64)
t := time.Unix(timestamp, 0).UTC()
// must be the format 2006-01-02, ref to package time for more info
date := t.Format("2006-01-02")
credentialScope := fmt.Sprintf("%s/%s/tc3_request", date, "hunyuan")
hashedCanonicalRequest := sha256hex(canonicalRequest)
string2sign := fmt.Sprintf("%s\n%s\n%s\n%s",
algorithm,
requestTimestamp,
credentialScope,
hashedCanonicalRequest)
// sign string
secretDate := hmacSha256(date, "TC3"+secKey)
secretService := hmacSha256("hunyuan", secretDate)
secretKey := hmacSha256("tc3_request", secretService)
signature := hex.EncodeToString([]byte(hmacSha256(string2sign, secretKey)))
// build authorization
authorization := fmt.Sprintf("%s Credential=%s/%s, SignedHeaders=%s, Signature=%s",
algorithm,
secId,
credentialScope,
signedHeaders,
signature)
return authorization
}

View File

@@ -16,6 +16,11 @@ type Adaptor struct {
request *dto.GeneralOpenAIRequest
}
func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) {
//TODO implement me
}
func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
}
@@ -36,6 +41,10 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.Gen
return request, nil
}
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
return nil, nil
}
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
// xunfei's request is not http request, so we don't need to do anything here
dummyResp := &http.Response{}

View File

@@ -14,6 +14,11 @@ import (
type Adaptor struct {
}
func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) {
//TODO implement me
}
func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
}
@@ -42,6 +47,10 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.Gen
return requestOpenAI2Zhipu(*request), nil
}
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
return nil, nil
}
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
return channel.DoApiRequest(a, c, info, requestBody)
}

View File

@@ -16,6 +16,11 @@ import (
type Adaptor struct {
}
func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) {
//TODO implement me
}
func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
}
@@ -40,6 +45,10 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.Gen
return requestOpenAI2Zhipu(*request), nil
}
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
return nil, nil
}
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
return channel.DoApiRequest(a, c, info, requestBody)
}
@@ -48,9 +57,11 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
if info.IsStream {
var responseText string
var toolCount int
err, responseText, toolCount = openai.OpenaiStreamHandler(c, resp, info)
usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
usage.CompletionTokens += toolCount * 7
err, usage, responseText, toolCount = openai.OpenaiStreamHandler(c, resp, info)
if usage == nil || usage.TotalTokens == 0 || (usage.PromptTokens+usage.CompletionTokens) == 0 {
usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
usage.CompletionTokens += toolCount * 7
}
} else {
err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
}