mirror of
https://github.com/linux-do/new-api.git
synced 2025-12-26 08:35:58 +08:00
feat: 初步重构
This commit is contained in:
57
relay/channel/adapter.go
Normal file
57
relay/channel/adapter.go
Normal file
@@ -0,0 +1,57 @@
|
||||
package channel
|
||||
|
||||
import (
|
||||
"github.com/gin-gonic/gin"
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/dto"
|
||||
"one-api/relay/channel/ali"
|
||||
"one-api/relay/channel/baidu"
|
||||
"one-api/relay/channel/claude"
|
||||
"one-api/relay/channel/gemini"
|
||||
"one-api/relay/channel/openai"
|
||||
"one-api/relay/channel/palm"
|
||||
"one-api/relay/channel/tencent"
|
||||
"one-api/relay/channel/xunfei"
|
||||
"one-api/relay/channel/zhipu"
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/relay/constant"
|
||||
)
|
||||
|
||||
type Adaptor interface {
|
||||
// Init IsStream bool
|
||||
Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest)
|
||||
GetRequestURL(info *relaycommon.RelayInfo) (string, error)
|
||||
SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error
|
||||
ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error)
|
||||
DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error)
|
||||
DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode)
|
||||
GetModelList() []string
|
||||
GetChannelName() string
|
||||
}
|
||||
|
||||
func GetAdaptor(apiType int) Adaptor {
|
||||
switch apiType {
|
||||
//case constant.APITypeAIProxyLibrary:
|
||||
// return &aiproxy.Adaptor{}
|
||||
case constant.APITypeAli:
|
||||
return &ali.Adaptor{}
|
||||
case constant.APITypeAnthropic:
|
||||
return &claude.Adaptor{}
|
||||
case constant.APITypeBaidu:
|
||||
return &baidu.Adaptor{}
|
||||
case constant.APITypeGemini:
|
||||
return &gemini.Adaptor{}
|
||||
case constant.APITypeOpenAI:
|
||||
return &openai.Adaptor{}
|
||||
case constant.APITypePaLM:
|
||||
return &palm.Adaptor{}
|
||||
case constant.APITypeTencent:
|
||||
return &tencent.Adaptor{}
|
||||
case constant.APITypeXunfei:
|
||||
return &xunfei.Adaptor{}
|
||||
case constant.APITypeZhipu:
|
||||
return &zhipu.Adaptor{}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
80
relay/channel/ali/adaptor.go
Normal file
80
relay/channel/ali/adaptor.go
Normal file
@@ -0,0 +1,80 @@
|
||||
package ali
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/dto"
|
||||
relaychannel "one-api/relay/channel"
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/relay/constant"
|
||||
)
|
||||
|
||||
type Adaptor struct {
|
||||
}
|
||||
|
||||
func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
|
||||
|
||||
}
|
||||
|
||||
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
fullRequestURL := fmt.Sprintf("%s/api/v1/services/aigc/text-generation/generation", info.BaseUrl)
|
||||
if info.RelayMode == constant.RelayModeEmbeddings {
|
||||
fullRequestURL = fmt.Sprintf("%s/api/v1/services/embeddings/text-embedding/text-embedding", info.BaseUrl)
|
||||
}
|
||||
return fullRequestURL, nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
|
||||
relaychannel.SetupApiRequestHeader(info, c, req)
|
||||
req.Header.Set("Authorization", "Bearer "+info.ApiKey)
|
||||
if info.IsStream {
|
||||
req.Header.Set("X-DashScope-SSE", "enable")
|
||||
}
|
||||
if c.GetString("plugin") != "" {
|
||||
req.Header.Set("X-DashScope-Plugin", c.GetString("plugin"))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) {
|
||||
if request == nil {
|
||||
return nil, errors.New("request is nil")
|
||||
}
|
||||
switch relayMode {
|
||||
case constant.RelayModeEmbeddings:
|
||||
baiduEmbeddingRequest := embeddingRequestOpenAI2Ali(*request)
|
||||
return baiduEmbeddingRequest, nil
|
||||
default:
|
||||
baiduRequest := requestOpenAI2Ali(*request)
|
||||
return baiduRequest, nil
|
||||
}
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
|
||||
return relaychannel.DoApiRequest(a, c, info, requestBody)
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
|
||||
if info.IsStream {
|
||||
err, usage = aliStreamHandler(c, resp)
|
||||
} else {
|
||||
switch info.RelayMode {
|
||||
case constant.RelayModeEmbeddings:
|
||||
err, usage = aliEmbeddingHandler(c, resp)
|
||||
default:
|
||||
err, usage = aliHandler(c, resp)
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (a *Adaptor) GetModelList() []string {
|
||||
return ModelList
|
||||
}
|
||||
|
||||
func (a *Adaptor) GetChannelName() string {
|
||||
return ChannelName
|
||||
}
|
||||
8
relay/channel/ali/constants.go
Normal file
8
relay/channel/ali/constants.go
Normal file
@@ -0,0 +1,8 @@
|
||||
package ali
|
||||
|
||||
var ModelList = []string{
|
||||
"qwen-turbo", "qwen-plus", "qwen-max", "qwen-max-longcontext",
|
||||
"text-embedding-v1",
|
||||
}
|
||||
|
||||
var ChannelName = "ali"
|
||||
70
relay/channel/ali/dto.go
Normal file
70
relay/channel/ali/dto.go
Normal file
@@ -0,0 +1,70 @@
|
||||
package ali
|
||||
|
||||
type AliMessage struct {
|
||||
User string `json:"user"`
|
||||
Bot string `json:"bot"`
|
||||
}
|
||||
|
||||
type AliInput struct {
|
||||
Prompt string `json:"prompt"`
|
||||
History []AliMessage `json:"history"`
|
||||
}
|
||||
|
||||
type AliParameters struct {
|
||||
TopP float64 `json:"top_p,omitempty"`
|
||||
TopK int `json:"top_k,omitempty"`
|
||||
Seed uint64 `json:"seed,omitempty"`
|
||||
EnableSearch bool `json:"enable_search,omitempty"`
|
||||
}
|
||||
|
||||
type AliChatRequest struct {
|
||||
Model string `json:"model"`
|
||||
Input AliInput `json:"input"`
|
||||
Parameters AliParameters `json:"parameters,omitempty"`
|
||||
}
|
||||
|
||||
type AliEmbeddingRequest struct {
|
||||
Model string `json:"model"`
|
||||
Input struct {
|
||||
Texts []string `json:"texts"`
|
||||
} `json:"input"`
|
||||
Parameters *struct {
|
||||
TextType string `json:"text_type,omitempty"`
|
||||
} `json:"parameters,omitempty"`
|
||||
}
|
||||
|
||||
type AliEmbedding struct {
|
||||
Embedding []float64 `json:"embedding"`
|
||||
TextIndex int `json:"text_index"`
|
||||
}
|
||||
|
||||
type AliEmbeddingResponse struct {
|
||||
Output struct {
|
||||
Embeddings []AliEmbedding `json:"embeddings"`
|
||||
} `json:"output"`
|
||||
Usage AliUsage `json:"usage"`
|
||||
AliError
|
||||
}
|
||||
|
||||
type AliError struct {
|
||||
Code string `json:"code"`
|
||||
Message string `json:"message"`
|
||||
RequestId string `json:"request_id"`
|
||||
}
|
||||
|
||||
type AliUsage struct {
|
||||
InputTokens int `json:"input_tokens"`
|
||||
OutputTokens int `json:"output_tokens"`
|
||||
TotalTokens int `json:"total_tokens"`
|
||||
}
|
||||
|
||||
type AliOutput struct {
|
||||
Text string `json:"text"`
|
||||
FinishReason string `json:"finish_reason"`
|
||||
}
|
||||
|
||||
type AliChatResponse struct {
|
||||
Output AliOutput `json:"output"`
|
||||
Usage AliUsage `json:"usage"`
|
||||
AliError
|
||||
}
|
||||
263
relay/channel/ali/relay-ali.go
Normal file
263
relay/channel/ali/relay-ali.go
Normal file
@@ -0,0 +1,263 @@
|
||||
package ali
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"encoding/json"
|
||||
"github.com/gin-gonic/gin"
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/dto"
|
||||
"one-api/service"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// https://help.aliyun.com/document_detail/613695.html?spm=a2c4g.2399480.0.0.1adb778fAdzP9w#341800c0f8w0r
|
||||
|
||||
func requestOpenAI2Ali(request dto.GeneralOpenAIRequest) *AliChatRequest {
|
||||
messages := make([]AliMessage, 0, len(request.Messages))
|
||||
prompt := ""
|
||||
for i := 0; i < len(request.Messages); i++ {
|
||||
message := request.Messages[i]
|
||||
if message.Role == "system" {
|
||||
messages = append(messages, AliMessage{
|
||||
User: message.StringContent(),
|
||||
Bot: "Okay",
|
||||
})
|
||||
continue
|
||||
} else {
|
||||
if i == len(request.Messages)-1 {
|
||||
prompt = message.StringContent()
|
||||
break
|
||||
}
|
||||
messages = append(messages, AliMessage{
|
||||
User: message.StringContent(),
|
||||
Bot: string(request.Messages[i+1].Content),
|
||||
})
|
||||
i++
|
||||
}
|
||||
}
|
||||
return &AliChatRequest{
|
||||
Model: request.Model,
|
||||
Input: AliInput{
|
||||
Prompt: prompt,
|
||||
History: messages,
|
||||
},
|
||||
//Parameters: AliParameters{ // ChatGPT's parameters are not compatible with Ali's
|
||||
// TopP: request.TopP,
|
||||
// TopK: 50,
|
||||
// //Seed: 0,
|
||||
// //EnableSearch: false,
|
||||
//},
|
||||
}
|
||||
}
|
||||
|
||||
func embeddingRequestOpenAI2Ali(request dto.GeneralOpenAIRequest) *AliEmbeddingRequest {
|
||||
return &AliEmbeddingRequest{
|
||||
Model: "text-embedding-v1",
|
||||
Input: struct {
|
||||
Texts []string `json:"texts"`
|
||||
}{
|
||||
Texts: request.ParseInput(),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func aliEmbeddingHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
||||
var aliResponse AliEmbeddingResponse
|
||||
err := json.NewDecoder(resp.Body).Decode(&aliResponse)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
|
||||
err = resp.Body.Close()
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
|
||||
if aliResponse.Code != "" {
|
||||
return &dto.OpenAIErrorWithStatusCode{
|
||||
OpenAIError: dto.OpenAIError{
|
||||
Message: aliResponse.Message,
|
||||
Type: aliResponse.Code,
|
||||
Param: aliResponse.RequestId,
|
||||
Code: aliResponse.Code,
|
||||
},
|
||||
StatusCode: resp.StatusCode,
|
||||
}, nil
|
||||
}
|
||||
|
||||
fullTextResponse := embeddingResponseAli2OpenAI(&aliResponse)
|
||||
jsonResponse, err := json.Marshal(fullTextResponse)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
c.Writer.Header().Set("Content-Type", "application/json")
|
||||
c.Writer.WriteHeader(resp.StatusCode)
|
||||
_, err = c.Writer.Write(jsonResponse)
|
||||
return nil, &fullTextResponse.Usage
|
||||
}
|
||||
|
||||
func embeddingResponseAli2OpenAI(response *AliEmbeddingResponse) *dto.OpenAIEmbeddingResponse {
|
||||
openAIEmbeddingResponse := dto.OpenAIEmbeddingResponse{
|
||||
Object: "list",
|
||||
Data: make([]dto.OpenAIEmbeddingResponseItem, 0, len(response.Output.Embeddings)),
|
||||
Model: "text-embedding-v1",
|
||||
Usage: dto.Usage{TotalTokens: response.Usage.TotalTokens},
|
||||
}
|
||||
|
||||
for _, item := range response.Output.Embeddings {
|
||||
openAIEmbeddingResponse.Data = append(openAIEmbeddingResponse.Data, dto.OpenAIEmbeddingResponseItem{
|
||||
Object: `embedding`,
|
||||
Index: item.TextIndex,
|
||||
Embedding: item.Embedding,
|
||||
})
|
||||
}
|
||||
return &openAIEmbeddingResponse
|
||||
}
|
||||
|
||||
func responseAli2OpenAI(response *AliChatResponse) *dto.OpenAITextResponse {
|
||||
content, _ := json.Marshal(response.Output.Text)
|
||||
choice := dto.OpenAITextResponseChoice{
|
||||
Index: 0,
|
||||
Message: dto.Message{
|
||||
Role: "assistant",
|
||||
Content: content,
|
||||
},
|
||||
FinishReason: response.Output.FinishReason,
|
||||
}
|
||||
fullTextResponse := dto.OpenAITextResponse{
|
||||
Id: response.RequestId,
|
||||
Object: "chat.completion",
|
||||
Created: common.GetTimestamp(),
|
||||
Choices: []dto.OpenAITextResponseChoice{choice},
|
||||
Usage: dto.Usage{
|
||||
PromptTokens: response.Usage.InputTokens,
|
||||
CompletionTokens: response.Usage.OutputTokens,
|
||||
TotalTokens: response.Usage.InputTokens + response.Usage.OutputTokens,
|
||||
},
|
||||
}
|
||||
return &fullTextResponse
|
||||
}
|
||||
|
||||
func streamResponseAli2OpenAI(aliResponse *AliChatResponse) *dto.ChatCompletionsStreamResponse {
|
||||
var choice dto.ChatCompletionsStreamResponseChoice
|
||||
choice.Delta.Content = aliResponse.Output.Text
|
||||
if aliResponse.Output.FinishReason != "null" {
|
||||
finishReason := aliResponse.Output.FinishReason
|
||||
choice.FinishReason = &finishReason
|
||||
}
|
||||
response := dto.ChatCompletionsStreamResponse{
|
||||
Id: aliResponse.RequestId,
|
||||
Object: "chat.completion.chunk",
|
||||
Created: common.GetTimestamp(),
|
||||
Model: "ernie-bot",
|
||||
Choices: []dto.ChatCompletionsStreamResponseChoice{choice},
|
||||
}
|
||||
return &response
|
||||
}
|
||||
|
||||
func aliStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
||||
var usage dto.Usage
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
|
||||
if atEOF && len(data) == 0 {
|
||||
return 0, nil, nil
|
||||
}
|
||||
if i := strings.Index(string(data), "\n"); i >= 0 {
|
||||
return i + 1, data[0:i], nil
|
||||
}
|
||||
if atEOF {
|
||||
return len(data), data, nil
|
||||
}
|
||||
return 0, nil, nil
|
||||
})
|
||||
dataChan := make(chan string)
|
||||
stopChan := make(chan bool)
|
||||
go func() {
|
||||
for scanner.Scan() {
|
||||
data := scanner.Text()
|
||||
if len(data) < 5 { // ignore blank line or wrong format
|
||||
continue
|
||||
}
|
||||
if data[:5] != "data:" {
|
||||
continue
|
||||
}
|
||||
data = data[5:]
|
||||
dataChan <- data
|
||||
}
|
||||
stopChan <- true
|
||||
}()
|
||||
service.SetEventStreamHeaders(c)
|
||||
lastResponseText := ""
|
||||
c.Stream(func(w io.Writer) bool {
|
||||
select {
|
||||
case data := <-dataChan:
|
||||
var aliResponse AliChatResponse
|
||||
err := json.Unmarshal([]byte(data), &aliResponse)
|
||||
if err != nil {
|
||||
common.SysError("error unmarshalling stream response: " + err.Error())
|
||||
return true
|
||||
}
|
||||
if aliResponse.Usage.OutputTokens != 0 {
|
||||
usage.PromptTokens = aliResponse.Usage.InputTokens
|
||||
usage.CompletionTokens = aliResponse.Usage.OutputTokens
|
||||
usage.TotalTokens = aliResponse.Usage.InputTokens + aliResponse.Usage.OutputTokens
|
||||
}
|
||||
response := streamResponseAli2OpenAI(&aliResponse)
|
||||
response.Choices[0].Delta.Content = strings.TrimPrefix(response.Choices[0].Delta.Content, lastResponseText)
|
||||
lastResponseText = aliResponse.Output.Text
|
||||
jsonResponse, err := json.Marshal(response)
|
||||
if err != nil {
|
||||
common.SysError("error marshalling stream response: " + err.Error())
|
||||
return true
|
||||
}
|
||||
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
|
||||
return true
|
||||
case <-stopChan:
|
||||
c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
|
||||
return false
|
||||
}
|
||||
})
|
||||
err := resp.Body.Close()
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
return nil, &usage
|
||||
}
|
||||
|
||||
func aliHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
||||
var aliResponse AliChatResponse
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
err = resp.Body.Close()
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
err = json.Unmarshal(responseBody, &aliResponse)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
if aliResponse.Code != "" {
|
||||
return &dto.OpenAIErrorWithStatusCode{
|
||||
OpenAIError: dto.OpenAIError{
|
||||
Message: aliResponse.Message,
|
||||
Type: aliResponse.Code,
|
||||
Param: aliResponse.RequestId,
|
||||
Code: aliResponse.Code,
|
||||
},
|
||||
StatusCode: resp.StatusCode,
|
||||
}, nil
|
||||
}
|
||||
fullTextResponse := responseAli2OpenAI(&aliResponse)
|
||||
jsonResponse, err := json.Marshal(fullTextResponse)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
c.Writer.Header().Set("Content-Type", "application/json")
|
||||
c.Writer.WriteHeader(resp.StatusCode)
|
||||
_, err = c.Writer.Write(jsonResponse)
|
||||
return nil, &fullTextResponse.Usage
|
||||
}
|
||||
52
relay/channel/api_request.go
Normal file
52
relay/channel/api_request.go
Normal file
@@ -0,0 +1,52 @@
|
||||
package channel
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"io"
|
||||
"net/http"
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/service"
|
||||
)
|
||||
|
||||
func SetupApiRequestHeader(info *relaycommon.RelayInfo, c *gin.Context, req *http.Request) {
|
||||
req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
|
||||
req.Header.Set("Accept", c.Request.Header.Get("Accept"))
|
||||
if info.IsStream && c.Request.Header.Get("Accept") == "" {
|
||||
req.Header.Set("Accept", "text/event-stream")
|
||||
}
|
||||
}
|
||||
|
||||
func DoApiRequest(a Adaptor, c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
|
||||
fullRequestURL, err := a.GetRequestURL(info)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get request url failed: %w", err)
|
||||
}
|
||||
req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("new request failed: %w", err)
|
||||
}
|
||||
err = a.SetupRequestHeader(c, req, info)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("setup request header failed: %w", err)
|
||||
}
|
||||
resp, err := doRequest(c, req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("do request failed: %w", err)
|
||||
}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
func doRequest(c *gin.Context, req *http.Request) (*http.Response, error) {
|
||||
resp, err := service.GetHttpClient().Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if resp == nil {
|
||||
return nil, errors.New("resp is nil")
|
||||
}
|
||||
_ = req.Body.Close()
|
||||
_ = c.Request.Body.Close()
|
||||
return resp, nil
|
||||
}
|
||||
92
relay/channel/baidu/adaptor.go
Normal file
92
relay/channel/baidu/adaptor.go
Normal file
@@ -0,0 +1,92 @@
|
||||
package baidu
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"github.com/gin-gonic/gin"
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/dto"
|
||||
relaychannel "one-api/relay/channel"
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/relay/constant"
|
||||
)
|
||||
|
||||
type Adaptor struct {
|
||||
}
|
||||
|
||||
func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
|
||||
|
||||
}
|
||||
|
||||
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
var fullRequestURL string
|
||||
switch info.UpstreamModelName {
|
||||
case "ERNIE-Bot-4":
|
||||
fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro"
|
||||
case "ERNIE-Bot-8K":
|
||||
fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie_bot_8k"
|
||||
case "ERNIE-Bot":
|
||||
fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions"
|
||||
case "ERNIE-Speed":
|
||||
fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie_speed"
|
||||
case "ERNIE-Bot-turbo":
|
||||
fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/eb-instant"
|
||||
case "BLOOMZ-7B":
|
||||
fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/bloomz_7b1"
|
||||
case "Embedding-V1":
|
||||
fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/embeddings/embedding-v1"
|
||||
}
|
||||
var accessToken string
|
||||
var err error
|
||||
if accessToken, err = getBaiduAccessToken(info.ApiKey); err != nil {
|
||||
return "", err
|
||||
}
|
||||
fullRequestURL += "?access_token=" + accessToken
|
||||
return fullRequestURL, nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
|
||||
relaychannel.SetupApiRequestHeader(info, c, req)
|
||||
req.Header.Set("Authorization", "Bearer "+info.ApiKey)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) {
|
||||
if request == nil {
|
||||
return nil, errors.New("request is nil")
|
||||
}
|
||||
switch relayMode {
|
||||
case constant.RelayModeEmbeddings:
|
||||
baiduEmbeddingRequest := embeddingRequestOpenAI2Baidu(*request)
|
||||
return baiduEmbeddingRequest, nil
|
||||
default:
|
||||
baiduRequest := requestOpenAI2Baidu(*request)
|
||||
return baiduRequest, nil
|
||||
}
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
|
||||
return relaychannel.DoApiRequest(a, c, info, requestBody)
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
|
||||
if info.IsStream {
|
||||
err, usage = baiduStreamHandler(c, resp)
|
||||
} else {
|
||||
switch info.RelayMode {
|
||||
case constant.RelayModeEmbeddings:
|
||||
err, usage = baiduEmbeddingHandler(c, resp)
|
||||
default:
|
||||
err, usage = baiduHandler(c, resp)
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (a *Adaptor) GetModelList() []string {
|
||||
return ModelList
|
||||
}
|
||||
|
||||
func (a *Adaptor) GetChannelName() string {
|
||||
return ChannelName
|
||||
}
|
||||
12
relay/channel/baidu/constants.go
Normal file
12
relay/channel/baidu/constants.go
Normal file
@@ -0,0 +1,12 @@
|
||||
package baidu
|
||||
|
||||
var ModelList = []string{
|
||||
"ERNIE-Bot-4",
|
||||
"ERNIE-Bot-8K",
|
||||
"ERNIE-Bot",
|
||||
"ERNIE-Speed",
|
||||
"ERNIE-Bot-turbo",
|
||||
"Embedding-V1",
|
||||
}
|
||||
|
||||
var ChannelName = "baidu"
|
||||
71
relay/channel/baidu/dto.go
Normal file
71
relay/channel/baidu/dto.go
Normal file
@@ -0,0 +1,71 @@
|
||||
package baidu
|
||||
|
||||
import (
|
||||
"one-api/dto"
|
||||
"time"
|
||||
)
|
||||
|
||||
type BaiduMessage struct {
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content"`
|
||||
}
|
||||
|
||||
type BaiduChatRequest struct {
|
||||
Messages []BaiduMessage `json:"messages"`
|
||||
Stream bool `json:"stream"`
|
||||
UserId string `json:"user_id,omitempty"`
|
||||
}
|
||||
|
||||
type Error struct {
|
||||
ErrorCode int `json:"error_code"`
|
||||
ErrorMsg string `json:"error_msg"`
|
||||
}
|
||||
|
||||
type BaiduChatResponse struct {
|
||||
Id string `json:"id"`
|
||||
Object string `json:"object"`
|
||||
Created int64 `json:"created"`
|
||||
Result string `json:"result"`
|
||||
IsTruncated bool `json:"is_truncated"`
|
||||
NeedClearHistory bool `json:"need_clear_history"`
|
||||
Usage dto.Usage `json:"usage"`
|
||||
Error
|
||||
}
|
||||
|
||||
type BaiduChatStreamResponse struct {
|
||||
BaiduChatResponse
|
||||
SentenceId int `json:"sentence_id"`
|
||||
IsEnd bool `json:"is_end"`
|
||||
}
|
||||
|
||||
type BaiduEmbeddingRequest struct {
|
||||
Input []string `json:"input"`
|
||||
}
|
||||
|
||||
type BaiduEmbeddingData struct {
|
||||
Object string `json:"object"`
|
||||
Embedding []float64 `json:"embedding"`
|
||||
Index int `json:"index"`
|
||||
}
|
||||
|
||||
type BaiduEmbeddingResponse struct {
|
||||
Id string `json:"id"`
|
||||
Object string `json:"object"`
|
||||
Created int64 `json:"created"`
|
||||
Data []BaiduEmbeddingData `json:"data"`
|
||||
Usage dto.Usage `json:"usage"`
|
||||
Error
|
||||
}
|
||||
|
||||
type BaiduAccessToken struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
Error string `json:"error,omitempty"`
|
||||
ErrorDescription string `json:"error_description,omitempty"`
|
||||
ExpiresIn int64 `json:"expires_in,omitempty"`
|
||||
ExpiresAt time.Time `json:"-"`
|
||||
}
|
||||
|
||||
type BaiduTokenResponse struct {
|
||||
ExpiresIn int `json:"expires_in"`
|
||||
AccessToken string `json:"access_token"`
|
||||
}
|
||||
298
relay/channel/baidu/relay-baidu.go
Normal file
298
relay/channel/baidu/relay-baidu.go
Normal file
@@ -0,0 +1,298 @@
|
||||
package baidu
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/dto"
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/service"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// https://cloud.baidu.com/doc/WENXINWORKSHOP/s/flfmc9do2
|
||||
|
||||
var baiduTokenStore sync.Map
|
||||
|
||||
func requestOpenAI2Baidu(request dto.GeneralOpenAIRequest) *BaiduChatRequest {
|
||||
messages := make([]BaiduMessage, 0, len(request.Messages))
|
||||
for _, message := range request.Messages {
|
||||
if message.Role == "system" {
|
||||
messages = append(messages, BaiduMessage{
|
||||
Role: "user",
|
||||
Content: message.StringContent(),
|
||||
})
|
||||
messages = append(messages, BaiduMessage{
|
||||
Role: "assistant",
|
||||
Content: "Okay",
|
||||
})
|
||||
} else {
|
||||
messages = append(messages, BaiduMessage{
|
||||
Role: message.Role,
|
||||
Content: message.StringContent(),
|
||||
})
|
||||
}
|
||||
}
|
||||
return &BaiduChatRequest{
|
||||
Messages: messages,
|
||||
Stream: request.Stream,
|
||||
}
|
||||
}
|
||||
|
||||
func responseBaidu2OpenAI(response *BaiduChatResponse) *dto.OpenAITextResponse {
|
||||
content, _ := json.Marshal(response.Result)
|
||||
choice := dto.OpenAITextResponseChoice{
|
||||
Index: 0,
|
||||
Message: dto.Message{
|
||||
Role: "assistant",
|
||||
Content: content,
|
||||
},
|
||||
FinishReason: "stop",
|
||||
}
|
||||
fullTextResponse := dto.OpenAITextResponse{
|
||||
Id: response.Id,
|
||||
Object: "chat.completion",
|
||||
Created: response.Created,
|
||||
Choices: []dto.OpenAITextResponseChoice{choice},
|
||||
Usage: response.Usage,
|
||||
}
|
||||
return &fullTextResponse
|
||||
}
|
||||
|
||||
func streamResponseBaidu2OpenAI(baiduResponse *BaiduChatStreamResponse) *dto.ChatCompletionsStreamResponse {
|
||||
var choice dto.ChatCompletionsStreamResponseChoice
|
||||
choice.Delta.Content = baiduResponse.Result
|
||||
if baiduResponse.IsEnd {
|
||||
choice.FinishReason = &relaycommon.StopFinishReason
|
||||
}
|
||||
response := dto.ChatCompletionsStreamResponse{
|
||||
Id: baiduResponse.Id,
|
||||
Object: "chat.completion.chunk",
|
||||
Created: baiduResponse.Created,
|
||||
Model: "ernie-bot",
|
||||
Choices: []dto.ChatCompletionsStreamResponseChoice{choice},
|
||||
}
|
||||
return &response
|
||||
}
|
||||
|
||||
func embeddingRequestOpenAI2Baidu(request dto.GeneralOpenAIRequest) *BaiduEmbeddingRequest {
|
||||
return &BaiduEmbeddingRequest{
|
||||
Input: request.ParseInput(),
|
||||
}
|
||||
}
|
||||
|
||||
func embeddingResponseBaidu2OpenAI(response *BaiduEmbeddingResponse) *dto.OpenAIEmbeddingResponse {
|
||||
openAIEmbeddingResponse := dto.OpenAIEmbeddingResponse{
|
||||
Object: "list",
|
||||
Data: make([]dto.OpenAIEmbeddingResponseItem, 0, len(response.Data)),
|
||||
Model: "baidu-embedding",
|
||||
Usage: response.Usage,
|
||||
}
|
||||
for _, item := range response.Data {
|
||||
openAIEmbeddingResponse.Data = append(openAIEmbeddingResponse.Data, dto.OpenAIEmbeddingResponseItem{
|
||||
Object: item.Object,
|
||||
Index: item.Index,
|
||||
Embedding: item.Embedding,
|
||||
})
|
||||
}
|
||||
return &openAIEmbeddingResponse
|
||||
}
|
||||
|
||||
func baiduStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
||||
var usage dto.Usage
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
|
||||
if atEOF && len(data) == 0 {
|
||||
return 0, nil, nil
|
||||
}
|
||||
if i := strings.Index(string(data), "\n"); i >= 0 {
|
||||
return i + 1, data[0:i], nil
|
||||
}
|
||||
if atEOF {
|
||||
return len(data), data, nil
|
||||
}
|
||||
return 0, nil, nil
|
||||
})
|
||||
dataChan := make(chan string)
|
||||
stopChan := make(chan bool)
|
||||
go func() {
|
||||
for scanner.Scan() {
|
||||
data := scanner.Text()
|
||||
if len(data) < 6 { // ignore blank line or wrong format
|
||||
continue
|
||||
}
|
||||
data = data[6:]
|
||||
dataChan <- data
|
||||
}
|
||||
stopChan <- true
|
||||
}()
|
||||
service.SetEventStreamHeaders(c)
|
||||
c.Stream(func(w io.Writer) bool {
|
||||
select {
|
||||
case data := <-dataChan:
|
||||
var baiduResponse BaiduChatStreamResponse
|
||||
err := json.Unmarshal([]byte(data), &baiduResponse)
|
||||
if err != nil {
|
||||
common.SysError("error unmarshalling stream response: " + err.Error())
|
||||
return true
|
||||
}
|
||||
if baiduResponse.Usage.TotalTokens != 0 {
|
||||
usage.TotalTokens = baiduResponse.Usage.TotalTokens
|
||||
usage.PromptTokens = baiduResponse.Usage.PromptTokens
|
||||
usage.CompletionTokens = baiduResponse.Usage.TotalTokens - baiduResponse.Usage.PromptTokens
|
||||
}
|
||||
response := streamResponseBaidu2OpenAI(&baiduResponse)
|
||||
jsonResponse, err := json.Marshal(response)
|
||||
if err != nil {
|
||||
common.SysError("error marshalling stream response: " + err.Error())
|
||||
return true
|
||||
}
|
||||
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
|
||||
return true
|
||||
case <-stopChan:
|
||||
c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
|
||||
return false
|
||||
}
|
||||
})
|
||||
err := resp.Body.Close()
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
return nil, &usage
|
||||
}
|
||||
|
||||
func baiduHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
||||
var baiduResponse BaiduChatResponse
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
err = resp.Body.Close()
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
err = json.Unmarshal(responseBody, &baiduResponse)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
if baiduResponse.ErrorMsg != "" {
|
||||
return &dto.OpenAIErrorWithStatusCode{
|
||||
OpenAIError: dto.OpenAIError{
|
||||
Message: baiduResponse.ErrorMsg,
|
||||
Type: "baidu_error",
|
||||
Param: "",
|
||||
Code: baiduResponse.ErrorCode,
|
||||
},
|
||||
StatusCode: resp.StatusCode,
|
||||
}, nil
|
||||
}
|
||||
fullTextResponse := responseBaidu2OpenAI(&baiduResponse)
|
||||
jsonResponse, err := json.Marshal(fullTextResponse)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
c.Writer.Header().Set("Content-Type", "application/json")
|
||||
c.Writer.WriteHeader(resp.StatusCode)
|
||||
_, err = c.Writer.Write(jsonResponse)
|
||||
return nil, &fullTextResponse.Usage
|
||||
}
|
||||
|
||||
func baiduEmbeddingHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
||||
var baiduResponse BaiduEmbeddingResponse
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
err = resp.Body.Close()
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
err = json.Unmarshal(responseBody, &baiduResponse)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
if baiduResponse.ErrorMsg != "" {
|
||||
return &dto.OpenAIErrorWithStatusCode{
|
||||
OpenAIError: dto.OpenAIError{
|
||||
Message: baiduResponse.ErrorMsg,
|
||||
Type: "baidu_error",
|
||||
Param: "",
|
||||
Code: baiduResponse.ErrorCode,
|
||||
},
|
||||
StatusCode: resp.StatusCode,
|
||||
}, nil
|
||||
}
|
||||
fullTextResponse := embeddingResponseBaidu2OpenAI(&baiduResponse)
|
||||
jsonResponse, err := json.Marshal(fullTextResponse)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
c.Writer.Header().Set("Content-Type", "application/json")
|
||||
c.Writer.WriteHeader(resp.StatusCode)
|
||||
_, err = c.Writer.Write(jsonResponse)
|
||||
return nil, &fullTextResponse.Usage
|
||||
}
|
||||
|
||||
func getBaiduAccessToken(apiKey string) (string, error) {
|
||||
if val, ok := baiduTokenStore.Load(apiKey); ok {
|
||||
var accessToken BaiduAccessToken
|
||||
if accessToken, ok = val.(BaiduAccessToken); ok {
|
||||
// soon this will expire
|
||||
if time.Now().Add(time.Hour).After(accessToken.ExpiresAt) {
|
||||
go func() {
|
||||
_, _ = getBaiduAccessTokenHelper(apiKey)
|
||||
}()
|
||||
}
|
||||
return accessToken.AccessToken, nil
|
||||
}
|
||||
}
|
||||
accessToken, err := getBaiduAccessTokenHelper(apiKey)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if accessToken == nil {
|
||||
return "", errors.New("getBaiduAccessToken return a nil token")
|
||||
}
|
||||
return (*accessToken).AccessToken, nil
|
||||
}
|
||||
|
||||
func getBaiduAccessTokenHelper(apiKey string) (*BaiduAccessToken, error) {
|
||||
parts := strings.Split(apiKey, "|")
|
||||
if len(parts) != 2 {
|
||||
return nil, errors.New("invalid baidu apikey")
|
||||
}
|
||||
req, err := http.NewRequest("POST", fmt.Sprintf("https://aip.baidubce.com/oauth/2.0/token?grant_type=client_credentials&client_id=%s&client_secret=%s",
|
||||
parts[0], parts[1]), nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Add("Content-Type", "application/json")
|
||||
req.Header.Add("Accept", "application/json")
|
||||
res, err := service.GetImpatientHttpClient().Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer res.Body.Close()
|
||||
|
||||
var accessToken BaiduAccessToken
|
||||
err = json.NewDecoder(res.Body).Decode(&accessToken)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if accessToken.Error != "" {
|
||||
return nil, errors.New(accessToken.Error + ": " + accessToken.ErrorDescription)
|
||||
}
|
||||
if accessToken.AccessToken == "" {
|
||||
return nil, errors.New("getBaiduAccessTokenHelper get empty access token")
|
||||
}
|
||||
accessToken.ExpiresAt = time.Now().Add(time.Duration(accessToken.ExpiresIn) * time.Second)
|
||||
baiduTokenStore.Store(apiKey, accessToken)
|
||||
return &accessToken, nil
|
||||
}
|
||||
65
relay/channel/claude/adaptor.go
Normal file
65
relay/channel/claude/adaptor.go
Normal file
@@ -0,0 +1,65 @@
|
||||
package claude
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/dto"
|
||||
relaychannel "one-api/relay/channel"
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/service"
|
||||
)
|
||||
|
||||
type Adaptor struct {
|
||||
}
|
||||
|
||||
func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
|
||||
|
||||
}
|
||||
|
||||
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
return fmt.Sprintf("%s/v1/complete", info.BaseUrl), nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
|
||||
relaychannel.SetupApiRequestHeader(info, c, req)
|
||||
req.Header.Set("x-api-key", info.ApiKey)
|
||||
anthropicVersion := c.Request.Header.Get("anthropic-version")
|
||||
if anthropicVersion == "" {
|
||||
anthropicVersion = "2023-06-01"
|
||||
}
|
||||
req.Header.Set("anthropic-version", anthropicVersion)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) {
|
||||
if request == nil {
|
||||
return nil, errors.New("request is nil")
|
||||
}
|
||||
return request, nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
|
||||
return relaychannel.DoApiRequest(a, c, info, requestBody)
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
|
||||
if info.IsStream {
|
||||
var responseText string
|
||||
err, responseText = claudeStreamHandler(c, resp)
|
||||
usage = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
|
||||
} else {
|
||||
err, usage = claudeHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (a *Adaptor) GetModelList() []string {
|
||||
return ModelList
|
||||
}
|
||||
|
||||
func (a *Adaptor) GetChannelName() string {
|
||||
return ChannelName
|
||||
}
|
||||
7
relay/channel/claude/constants.go
Normal file
7
relay/channel/claude/constants.go
Normal file
@@ -0,0 +1,7 @@
|
||||
package claude
|
||||
|
||||
var ModelList = []string{
|
||||
"claude-instant-1", "claude-2", "claude-2.0", "claude-2.1",
|
||||
}
|
||||
|
||||
var ChannelName = "claude"
|
||||
29
relay/channel/claude/dto.go
Normal file
29
relay/channel/claude/dto.go
Normal file
@@ -0,0 +1,29 @@
|
||||
package claude
|
||||
|
||||
type ClaudeMetadata struct {
|
||||
UserId string `json:"user_id"`
|
||||
}
|
||||
|
||||
type ClaudeRequest struct {
|
||||
Model string `json:"model"`
|
||||
Prompt string `json:"prompt"`
|
||||
MaxTokensToSample uint `json:"max_tokens_to_sample"`
|
||||
StopSequences []string `json:"stop_sequences,omitempty"`
|
||||
Temperature float64 `json:"temperature,omitempty"`
|
||||
TopP float64 `json:"top_p,omitempty"`
|
||||
TopK int `json:"top_k,omitempty"`
|
||||
//ClaudeMetadata `json:"metadata,omitempty"`
|
||||
Stream bool `json:"stream,omitempty"`
|
||||
}
|
||||
|
||||
type ClaudeError struct {
|
||||
Type string `json:"type"`
|
||||
Message string `json:"message"`
|
||||
}
|
||||
|
||||
type ClaudeResponse struct {
|
||||
Completion string `json:"completion"`
|
||||
StopReason string `json:"stop_reason"`
|
||||
Model string `json:"model"`
|
||||
Error ClaudeError `json:"error"`
|
||||
}
|
||||
195
relay/channel/claude/relay-claude.go
Normal file
195
relay/channel/claude/relay-claude.go
Normal file
@@ -0,0 +1,195 @@
|
||||
package claude
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/dto"
|
||||
"one-api/service"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func stopReasonClaude2OpenAI(reason string) string {
|
||||
switch reason {
|
||||
case "stop_sequence":
|
||||
return "stop"
|
||||
case "max_tokens":
|
||||
return "length"
|
||||
default:
|
||||
return reason
|
||||
}
|
||||
}
|
||||
|
||||
func requestOpenAI2Claude(textRequest dto.GeneralOpenAIRequest) *ClaudeRequest {
|
||||
claudeRequest := ClaudeRequest{
|
||||
Model: textRequest.Model,
|
||||
Prompt: "",
|
||||
MaxTokensToSample: textRequest.MaxTokens,
|
||||
StopSequences: nil,
|
||||
Temperature: textRequest.Temperature,
|
||||
TopP: textRequest.TopP,
|
||||
Stream: textRequest.Stream,
|
||||
}
|
||||
if claudeRequest.MaxTokensToSample == 0 {
|
||||
claudeRequest.MaxTokensToSample = 1000000
|
||||
}
|
||||
prompt := ""
|
||||
for _, message := range textRequest.Messages {
|
||||
if message.Role == "user" {
|
||||
prompt += fmt.Sprintf("\n\nHuman: %s", message.Content)
|
||||
} else if message.Role == "assistant" {
|
||||
prompt += fmt.Sprintf("\n\nAssistant: %s", message.Content)
|
||||
} else if message.Role == "system" {
|
||||
prompt += fmt.Sprintf("\n\nSystem: %s", message.Content)
|
||||
}
|
||||
}
|
||||
prompt += "\n\nAssistant:"
|
||||
claudeRequest.Prompt = prompt
|
||||
return &claudeRequest
|
||||
}
|
||||
|
||||
func streamResponseClaude2OpenAI(claudeResponse *ClaudeResponse) *dto.ChatCompletionsStreamResponse {
|
||||
var choice dto.ChatCompletionsStreamResponseChoice
|
||||
choice.Delta.Content = claudeResponse.Completion
|
||||
finishReason := stopReasonClaude2OpenAI(claudeResponse.StopReason)
|
||||
if finishReason != "null" {
|
||||
choice.FinishReason = &finishReason
|
||||
}
|
||||
var response dto.ChatCompletionsStreamResponse
|
||||
response.Object = "chat.completion.chunk"
|
||||
response.Model = claudeResponse.Model
|
||||
response.Choices = []dto.ChatCompletionsStreamResponseChoice{choice}
|
||||
return &response
|
||||
}
|
||||
|
||||
func responseClaude2OpenAI(claudeResponse *ClaudeResponse) *dto.OpenAITextResponse {
|
||||
content, _ := json.Marshal(strings.TrimPrefix(claudeResponse.Completion, " "))
|
||||
choice := dto.OpenAITextResponseChoice{
|
||||
Index: 0,
|
||||
Message: dto.Message{
|
||||
Role: "assistant",
|
||||
Content: content,
|
||||
Name: nil,
|
||||
},
|
||||
FinishReason: stopReasonClaude2OpenAI(claudeResponse.StopReason),
|
||||
}
|
||||
fullTextResponse := dto.OpenAITextResponse{
|
||||
Id: fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
|
||||
Object: "chat.completion",
|
||||
Created: common.GetTimestamp(),
|
||||
Choices: []dto.OpenAITextResponseChoice{choice},
|
||||
}
|
||||
return &fullTextResponse
|
||||
}
|
||||
|
||||
func claudeStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, string) {
|
||||
responseText := ""
|
||||
responseId := fmt.Sprintf("chatcmpl-%s", common.GetUUID())
|
||||
createdTime := common.GetTimestamp()
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
|
||||
if atEOF && len(data) == 0 {
|
||||
return 0, nil, nil
|
||||
}
|
||||
if i := strings.Index(string(data), "\r\n\r\n"); i >= 0 {
|
||||
return i + 4, data[0:i], nil
|
||||
}
|
||||
if atEOF {
|
||||
return len(data), data, nil
|
||||
}
|
||||
return 0, nil, nil
|
||||
})
|
||||
dataChan := make(chan string)
|
||||
stopChan := make(chan bool)
|
||||
go func() {
|
||||
for scanner.Scan() {
|
||||
data := scanner.Text()
|
||||
if !strings.HasPrefix(data, "event: completion") {
|
||||
continue
|
||||
}
|
||||
data = strings.TrimPrefix(data, "event: completion\r\ndata: ")
|
||||
dataChan <- data
|
||||
}
|
||||
stopChan <- true
|
||||
}()
|
||||
service.SetEventStreamHeaders(c)
|
||||
c.Stream(func(w io.Writer) bool {
|
||||
select {
|
||||
case data := <-dataChan:
|
||||
// some implementations may add \r at the end of data
|
||||
data = strings.TrimSuffix(data, "\r")
|
||||
var claudeResponse ClaudeResponse
|
||||
err := json.Unmarshal([]byte(data), &claudeResponse)
|
||||
if err != nil {
|
||||
common.SysError("error unmarshalling stream response: " + err.Error())
|
||||
return true
|
||||
}
|
||||
responseText += claudeResponse.Completion
|
||||
response := streamResponseClaude2OpenAI(&claudeResponse)
|
||||
response.Id = responseId
|
||||
response.Created = createdTime
|
||||
jsonStr, err := json.Marshal(response)
|
||||
if err != nil {
|
||||
common.SysError("error marshalling stream response: " + err.Error())
|
||||
return true
|
||||
}
|
||||
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonStr)})
|
||||
return true
|
||||
case <-stopChan:
|
||||
c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
|
||||
return false
|
||||
}
|
||||
})
|
||||
err := resp.Body.Close()
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), ""
|
||||
}
|
||||
return nil, responseText
|
||||
}
|
||||
|
||||
func claudeHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
err = resp.Body.Close()
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
var claudeResponse ClaudeResponse
|
||||
err = json.Unmarshal(responseBody, &claudeResponse)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
if claudeResponse.Error.Type != "" {
|
||||
return &dto.OpenAIErrorWithStatusCode{
|
||||
OpenAIError: dto.OpenAIError{
|
||||
Message: claudeResponse.Error.Message,
|
||||
Type: claudeResponse.Error.Type,
|
||||
Param: "",
|
||||
Code: claudeResponse.Error.Type,
|
||||
},
|
||||
StatusCode: resp.StatusCode,
|
||||
}, nil
|
||||
}
|
||||
fullTextResponse := responseClaude2OpenAI(&claudeResponse)
|
||||
completionTokens := service.CountTokenText(claudeResponse.Completion, model)
|
||||
usage := dto.Usage{
|
||||
PromptTokens: promptTokens,
|
||||
CompletionTokens: completionTokens,
|
||||
TotalTokens: promptTokens + completionTokens,
|
||||
}
|
||||
fullTextResponse.Usage = usage
|
||||
jsonResponse, err := json.Marshal(fullTextResponse)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
c.Writer.Header().Set("Content-Type", "application/json")
|
||||
c.Writer.WriteHeader(resp.StatusCode)
|
||||
_, err = c.Writer.Write(jsonResponse)
|
||||
return nil, &usage
|
||||
}
|
||||
64
relay/channel/gemini/adaptor.go
Normal file
64
relay/channel/gemini/adaptor.go
Normal file
@@ -0,0 +1,64 @@
|
||||
package gemini
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/dto"
|
||||
relaychannel "one-api/relay/channel"
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/service"
|
||||
)
|
||||
|
||||
type Adaptor struct {
|
||||
}
|
||||
|
||||
func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
|
||||
}
|
||||
|
||||
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
version := "v1"
|
||||
action := "generateContent"
|
||||
if info.IsStream {
|
||||
action = "streamGenerateContent"
|
||||
}
|
||||
return fmt.Sprintf("%s/%s/models/%s:%s", info.BaseUrl, version, info.UpstreamModelName, action), nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
|
||||
relaychannel.SetupApiRequestHeader(info, c, req)
|
||||
req.Header.Set("x-goog-api-key", info.ApiKey)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) {
|
||||
if request == nil {
|
||||
return nil, errors.New("request is nil")
|
||||
}
|
||||
return CovertGemini2OpenAI(*request), nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
|
||||
return relaychannel.DoApiRequest(a, c, info, requestBody)
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
|
||||
if info.IsStream {
|
||||
var responseText string
|
||||
err, responseText = geminiChatStreamHandler(c, resp)
|
||||
usage = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
|
||||
} else {
|
||||
err, usage = geminiChatHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (a *Adaptor) GetModelList() []string {
|
||||
return ModelList
|
||||
}
|
||||
|
||||
func (a *Adaptor) GetChannelName() string {
|
||||
return ChannelName
|
||||
}
|
||||
12
relay/channel/gemini/constant.go
Normal file
12
relay/channel/gemini/constant.go
Normal file
@@ -0,0 +1,12 @@
|
||||
package gemini
|
||||
|
||||
const (
|
||||
GeminiVisionMaxImageNum = 16
|
||||
)
|
||||
|
||||
var ModelList = []string{
|
||||
"gemini-pro",
|
||||
"gemini-pro-vision",
|
||||
}
|
||||
|
||||
var ChannelName = "google gemini"
|
||||
62
relay/channel/gemini/dto.go
Normal file
62
relay/channel/gemini/dto.go
Normal file
@@ -0,0 +1,62 @@
|
||||
package gemini
|
||||
|
||||
type GeminiChatRequest struct {
|
||||
Contents []GeminiChatContent `json:"contents"`
|
||||
SafetySettings []GeminiChatSafetySettings `json:"safety_settings,omitempty"`
|
||||
GenerationConfig GeminiChatGenerationConfig `json:"generation_config,omitempty"`
|
||||
Tools []GeminiChatTools `json:"tools,omitempty"`
|
||||
}
|
||||
|
||||
type GeminiInlineData struct {
|
||||
MimeType string `json:"mimeType"`
|
||||
Data string `json:"data"`
|
||||
}
|
||||
|
||||
type GeminiPart struct {
|
||||
Text string `json:"text,omitempty"`
|
||||
InlineData *GeminiInlineData `json:"inlineData,omitempty"`
|
||||
}
|
||||
|
||||
type GeminiChatContent struct {
|
||||
Role string `json:"role,omitempty"`
|
||||
Parts []GeminiPart `json:"parts"`
|
||||
}
|
||||
|
||||
type GeminiChatSafetySettings struct {
|
||||
Category string `json:"category"`
|
||||
Threshold string `json:"threshold"`
|
||||
}
|
||||
|
||||
type GeminiChatTools struct {
|
||||
FunctionDeclarations any `json:"functionDeclarations,omitempty"`
|
||||
}
|
||||
|
||||
type GeminiChatGenerationConfig struct {
|
||||
Temperature float64 `json:"temperature,omitempty"`
|
||||
TopP float64 `json:"topP,omitempty"`
|
||||
TopK float64 `json:"topK,omitempty"`
|
||||
MaxOutputTokens uint `json:"maxOutputTokens,omitempty"`
|
||||
CandidateCount int `json:"candidateCount,omitempty"`
|
||||
StopSequences []string `json:"stopSequences,omitempty"`
|
||||
}
|
||||
|
||||
type GeminiChatCandidate struct {
|
||||
Content GeminiChatContent `json:"content"`
|
||||
FinishReason string `json:"finishReason"`
|
||||
Index int64 `json:"index"`
|
||||
SafetyRatings []GeminiChatSafetyRating `json:"safetyRatings"`
|
||||
}
|
||||
|
||||
type GeminiChatSafetyRating struct {
|
||||
Category string `json:"category"`
|
||||
Probability string `json:"probability"`
|
||||
}
|
||||
|
||||
type GeminiChatPromptFeedback struct {
|
||||
SafetyRatings []GeminiChatSafetyRating `json:"safetyRatings"`
|
||||
}
|
||||
|
||||
type GeminiChatResponse struct {
|
||||
Candidates []GeminiChatCandidate `json:"candidates"`
|
||||
PromptFeedback GeminiChatPromptFeedback `json:"promptFeedback"`
|
||||
}
|
||||
274
relay/channel/gemini/relay-gemini.go
Normal file
274
relay/channel/gemini/relay-gemini.go
Normal file
@@ -0,0 +1,274 @@
|
||||
package gemini
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/dto"
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/service"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// Setting safety to the lowest possible values since Gemini is already powerless enough
|
||||
func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest) *GeminiChatRequest {
|
||||
geminiRequest := GeminiChatRequest{
|
||||
Contents: make([]GeminiChatContent, 0, len(textRequest.Messages)),
|
||||
SafetySettings: []GeminiChatSafetySettings{
|
||||
{
|
||||
Category: "HARM_CATEGORY_HARASSMENT",
|
||||
Threshold: common.GeminiSafetySetting,
|
||||
},
|
||||
{
|
||||
Category: "HARM_CATEGORY_HATE_SPEECH",
|
||||
Threshold: common.GeminiSafetySetting,
|
||||
},
|
||||
{
|
||||
Category: "HARM_CATEGORY_SEXUALLY_EXPLICIT",
|
||||
Threshold: common.GeminiSafetySetting,
|
||||
},
|
||||
{
|
||||
Category: "HARM_CATEGORY_DANGEROUS_CONTENT",
|
||||
Threshold: common.GeminiSafetySetting,
|
||||
},
|
||||
},
|
||||
GenerationConfig: GeminiChatGenerationConfig{
|
||||
Temperature: textRequest.Temperature,
|
||||
TopP: textRequest.TopP,
|
||||
MaxOutputTokens: textRequest.MaxTokens,
|
||||
},
|
||||
}
|
||||
if textRequest.Functions != nil {
|
||||
geminiRequest.Tools = []GeminiChatTools{
|
||||
{
|
||||
FunctionDeclarations: textRequest.Functions,
|
||||
},
|
||||
}
|
||||
}
|
||||
shouldAddDummyModelMessage := false
|
||||
for _, message := range textRequest.Messages {
|
||||
content := GeminiChatContent{
|
||||
Role: message.Role,
|
||||
Parts: []GeminiPart{
|
||||
{
|
||||
Text: message.StringContent(),
|
||||
},
|
||||
},
|
||||
}
|
||||
openaiContent := message.ParseContent()
|
||||
var parts []GeminiPart
|
||||
imageNum := 0
|
||||
for _, part := range openaiContent {
|
||||
|
||||
if part.Type == dto.ContentTypeText {
|
||||
parts = append(parts, GeminiPart{
|
||||
Text: part.Text,
|
||||
})
|
||||
} else if part.Type == dto.ContentTypeImageURL {
|
||||
imageNum += 1
|
||||
if imageNum > GeminiVisionMaxImageNum {
|
||||
continue
|
||||
}
|
||||
mimeType, data, _ := common.GetImageFromUrl(part.ImageUrl.(dto.MessageImageUrl).Url)
|
||||
parts = append(parts, GeminiPart{
|
||||
InlineData: &GeminiInlineData{
|
||||
MimeType: mimeType,
|
||||
Data: data,
|
||||
},
|
||||
})
|
||||
}
|
||||
}
|
||||
content.Parts = parts
|
||||
|
||||
// there's no assistant role in gemini and API shall vomit if Role is not user or model
|
||||
if content.Role == "assistant" {
|
||||
content.Role = "model"
|
||||
}
|
||||
// Converting system prompt to prompt from user for the same reason
|
||||
if content.Role == "system" {
|
||||
content.Role = "user"
|
||||
shouldAddDummyModelMessage = true
|
||||
}
|
||||
geminiRequest.Contents = append(geminiRequest.Contents, content)
|
||||
|
||||
// If a system message is the last message, we need to add a dummy model message to make gemini happy
|
||||
if shouldAddDummyModelMessage {
|
||||
geminiRequest.Contents = append(geminiRequest.Contents, GeminiChatContent{
|
||||
Role: "model",
|
||||
Parts: []GeminiPart{
|
||||
{
|
||||
Text: "Okay",
|
||||
},
|
||||
},
|
||||
})
|
||||
shouldAddDummyModelMessage = false
|
||||
}
|
||||
}
|
||||
|
||||
return &geminiRequest
|
||||
}
|
||||
|
||||
func (g *GeminiChatResponse) GetResponseText() string {
|
||||
if g == nil {
|
||||
return ""
|
||||
}
|
||||
if len(g.Candidates) > 0 && len(g.Candidates[0].Content.Parts) > 0 {
|
||||
return g.Candidates[0].Content.Parts[0].Text
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func responseGeminiChat2OpenAI(response *GeminiChatResponse) *dto.OpenAITextResponse {
|
||||
fullTextResponse := dto.OpenAITextResponse{
|
||||
Id: fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
|
||||
Object: "chat.completion",
|
||||
Created: common.GetTimestamp(),
|
||||
Choices: make([]dto.OpenAITextResponseChoice, 0, len(response.Candidates)),
|
||||
}
|
||||
content, _ := json.Marshal("")
|
||||
for i, candidate := range response.Candidates {
|
||||
choice := dto.OpenAITextResponseChoice{
|
||||
Index: i,
|
||||
Message: dto.Message{
|
||||
Role: "assistant",
|
||||
Content: content,
|
||||
},
|
||||
FinishReason: relaycommon.StopFinishReason,
|
||||
}
|
||||
content, _ = json.Marshal(candidate.Content.Parts[0].Text)
|
||||
if len(candidate.Content.Parts) > 0 {
|
||||
choice.Message.Content = content
|
||||
}
|
||||
fullTextResponse.Choices = append(fullTextResponse.Choices, choice)
|
||||
}
|
||||
return &fullTextResponse
|
||||
}
|
||||
|
||||
func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) *dto.ChatCompletionsStreamResponse {
|
||||
var choice dto.ChatCompletionsStreamResponseChoice
|
||||
choice.Delta.Content = geminiResponse.GetResponseText()
|
||||
choice.FinishReason = &relaycommon.StopFinishReason
|
||||
var response dto.ChatCompletionsStreamResponse
|
||||
response.Object = "chat.completion.chunk"
|
||||
response.Model = "gemini"
|
||||
response.Choices = []dto.ChatCompletionsStreamResponseChoice{choice}
|
||||
return &response
|
||||
}
|
||||
|
||||
func geminiChatStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, string) {
|
||||
responseText := ""
|
||||
dataChan := make(chan string)
|
||||
stopChan := make(chan bool)
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
|
||||
if atEOF && len(data) == 0 {
|
||||
return 0, nil, nil
|
||||
}
|
||||
if i := strings.Index(string(data), "\n"); i >= 0 {
|
||||
return i + 1, data[0:i], nil
|
||||
}
|
||||
if atEOF {
|
||||
return len(data), data, nil
|
||||
}
|
||||
return 0, nil, nil
|
||||
})
|
||||
go func() {
|
||||
for scanner.Scan() {
|
||||
data := scanner.Text()
|
||||
data = strings.TrimSpace(data)
|
||||
if !strings.HasPrefix(data, "\"text\": \"") {
|
||||
continue
|
||||
}
|
||||
data = strings.TrimPrefix(data, "\"text\": \"")
|
||||
data = strings.TrimSuffix(data, "\"")
|
||||
dataChan <- data
|
||||
}
|
||||
stopChan <- true
|
||||
}()
|
||||
service.SetEventStreamHeaders(c)
|
||||
c.Stream(func(w io.Writer) bool {
|
||||
select {
|
||||
case data := <-dataChan:
|
||||
// this is used to prevent annoying \ related format bug
|
||||
data = fmt.Sprintf("{\"content\": \"%s\"}", data)
|
||||
type dummyStruct struct {
|
||||
Content string `json:"content"`
|
||||
}
|
||||
var dummy dummyStruct
|
||||
err := json.Unmarshal([]byte(data), &dummy)
|
||||
responseText += dummy.Content
|
||||
var choice dto.ChatCompletionsStreamResponseChoice
|
||||
choice.Delta.Content = dummy.Content
|
||||
response := dto.ChatCompletionsStreamResponse{
|
||||
Id: fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
|
||||
Object: "chat.completion.chunk",
|
||||
Created: common.GetTimestamp(),
|
||||
Model: "gemini-pro",
|
||||
Choices: []dto.ChatCompletionsStreamResponseChoice{choice},
|
||||
}
|
||||
jsonResponse, err := json.Marshal(response)
|
||||
if err != nil {
|
||||
common.SysError("error marshalling stream response: " + err.Error())
|
||||
return true
|
||||
}
|
||||
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
|
||||
return true
|
||||
case <-stopChan:
|
||||
c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
|
||||
return false
|
||||
}
|
||||
})
|
||||
err := resp.Body.Close()
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), ""
|
||||
}
|
||||
return nil, responseText
|
||||
}
|
||||
|
||||
func geminiChatHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
err = resp.Body.Close()
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
var geminiResponse GeminiChatResponse
|
||||
err = json.Unmarshal(responseBody, &geminiResponse)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
if len(geminiResponse.Candidates) == 0 {
|
||||
return &dto.OpenAIErrorWithStatusCode{
|
||||
OpenAIError: dto.OpenAIError{
|
||||
Message: "No candidates returned",
|
||||
Type: "server_error",
|
||||
Param: "",
|
||||
Code: 500,
|
||||
},
|
||||
StatusCode: resp.StatusCode,
|
||||
}, nil
|
||||
}
|
||||
fullTextResponse := responseGeminiChat2OpenAI(&geminiResponse)
|
||||
completionTokens := service.CountTokenText(geminiResponse.GetResponseText(), model)
|
||||
usage := dto.Usage{
|
||||
PromptTokens: promptTokens,
|
||||
CompletionTokens: completionTokens,
|
||||
TotalTokens: promptTokens + completionTokens,
|
||||
}
|
||||
fullTextResponse.Usage = usage
|
||||
jsonResponse, err := json.Marshal(fullTextResponse)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
c.Writer.Header().Set("Content-Type", "application/json")
|
||||
c.Writer.WriteHeader(resp.StatusCode)
|
||||
_, err = c.Writer.Write(jsonResponse)
|
||||
return nil, &usage
|
||||
}
|
||||
7
relay/channel/moonshot/constants.go
Normal file
7
relay/channel/moonshot/constants.go
Normal file
@@ -0,0 +1,7 @@
|
||||
package moonshot
|
||||
|
||||
var ModelList = []string{
|
||||
"moonshot-v1-8k",
|
||||
"moonshot-v1-32k",
|
||||
"moonshot-v1-128k",
|
||||
}
|
||||
84
relay/channel/openai/adaptor.go
Normal file
84
relay/channel/openai/adaptor.go
Normal file
@@ -0,0 +1,84 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/dto"
|
||||
relaychannel "one-api/relay/channel"
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/service"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type Adaptor struct {
|
||||
}
|
||||
|
||||
func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
|
||||
}
|
||||
|
||||
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
if info.ChannelType == common.ChannelTypeAzure {
|
||||
// https://learn.microsoft.com/en-us/azure/cognitive-services/openai/chatgpt-quickstart?pivots=rest-api&tabs=command-line#rest-api
|
||||
requestURL := strings.Split(info.RequestURLPath, "?")[0]
|
||||
requestURL = fmt.Sprintf("%s?api-version=%s", requestURL, info.ApiVersion)
|
||||
task := strings.TrimPrefix(requestURL, "/v1/")
|
||||
model_ := info.UpstreamModelName
|
||||
model_ = strings.Replace(model_, ".", "", -1)
|
||||
// https://github.com/songquanpeng/one-api/issues/67
|
||||
model_ = strings.TrimSuffix(model_, "-0301")
|
||||
model_ = strings.TrimSuffix(model_, "-0314")
|
||||
model_ = strings.TrimSuffix(model_, "-0613")
|
||||
|
||||
requestURL = fmt.Sprintf("/openai/deployments/%s/%s", model_, task)
|
||||
return relaycommon.GetFullRequestURL(info.BaseUrl, requestURL, info.ChannelType), nil
|
||||
}
|
||||
return relaycommon.GetFullRequestURL(info.BaseUrl, info.RequestURLPath, info.ChannelType), nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
|
||||
relaychannel.SetupApiRequestHeader(info, c, req)
|
||||
if info.ChannelType == common.ChannelTypeAzure {
|
||||
req.Header.Set("api-key", info.ApiKey)
|
||||
return nil
|
||||
}
|
||||
req.Header.Set("Authorization", "Bearer "+info.ApiKey)
|
||||
if info.ChannelType == common.ChannelTypeOpenRouter {
|
||||
req.Header.Set("HTTP-Referer", "https://github.com/songquanpeng/one-api")
|
||||
req.Header.Set("X-Title", "One API")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) {
|
||||
if request == nil {
|
||||
return nil, errors.New("request is nil")
|
||||
}
|
||||
return request, nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
|
||||
return relaychannel.DoApiRequest(a, c, info, requestBody)
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
|
||||
if info.IsStream {
|
||||
var responseText string
|
||||
err, responseText = openaiStreamHandler(c, resp, info.RelayMode)
|
||||
usage = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
|
||||
} else {
|
||||
err, usage = openaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (a *Adaptor) GetModelList() []string {
|
||||
return ModelList
|
||||
}
|
||||
|
||||
func (a *Adaptor) GetChannelName() string {
|
||||
return ChannelName
|
||||
}
|
||||
21
relay/channel/openai/constant.go
Normal file
21
relay/channel/openai/constant.go
Normal file
@@ -0,0 +1,21 @@
|
||||
package openai
|
||||
|
||||
var ModelList = []string{
|
||||
"gpt-3.5-turbo", "gpt-3.5-turbo-0301", "gpt-3.5-turbo-0613", "gpt-3.5-turbo-1106", "gpt-3.5-turbo-0125",
|
||||
"gpt-3.5-turbo-16k", "gpt-3.5-turbo-16k-0613",
|
||||
"gpt-3.5-turbo-instruct",
|
||||
"gpt-4", "gpt-4-0314", "gpt-4-0613", "gpt-4-1106-preview", "gpt-4-0125-preview",
|
||||
"gpt-4-32k", "gpt-4-32k-0314", "gpt-4-32k-0613",
|
||||
"gpt-4-turbo-preview",
|
||||
"gpt-4-vision-preview",
|
||||
"text-embedding-ada-002", "text-embedding-3-small", "text-embedding-3-large",
|
||||
"text-curie-001", "text-babbage-001", "text-ada-001", "text-davinci-002", "text-davinci-003",
|
||||
"text-moderation-latest", "text-moderation-stable",
|
||||
"text-davinci-edit-001",
|
||||
"davinci-002", "babbage-002",
|
||||
"dall-e-2", "dall-e-3",
|
||||
"whisper-1",
|
||||
"tts-1", "tts-1-1106", "tts-1-hd", "tts-1-hd-1106",
|
||||
}
|
||||
|
||||
var ChannelName = "openai"
|
||||
165
relay/channel/openai/relay-openai.go
Normal file
165
relay/channel/openai/relay-openai.go
Normal file
@@ -0,0 +1,165 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"github.com/gin-gonic/gin"
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/dto"
|
||||
relayconstant "one-api/relay/constant"
|
||||
"one-api/service"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
func openaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*dto.OpenAIErrorWithStatusCode, string) {
|
||||
var responseTextBuilder strings.Builder
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
|
||||
if atEOF && len(data) == 0 {
|
||||
return 0, nil, nil
|
||||
}
|
||||
if i := strings.Index(string(data), "\n"); i >= 0 {
|
||||
return i + 1, data[0:i], nil
|
||||
}
|
||||
if atEOF {
|
||||
return len(data), data, nil
|
||||
}
|
||||
return 0, nil, nil
|
||||
})
|
||||
dataChan := make(chan string, 5)
|
||||
stopChan := make(chan bool, 2)
|
||||
defer close(stopChan)
|
||||
defer close(dataChan)
|
||||
var wg sync.WaitGroup
|
||||
|
||||
go func() {
|
||||
wg.Add(1)
|
||||
defer wg.Done()
|
||||
var streamItems []string
|
||||
for scanner.Scan() {
|
||||
data := scanner.Text()
|
||||
if len(data) < 6 { // ignore blank line or wrong format
|
||||
continue
|
||||
}
|
||||
if data[:6] != "data: " && data[:6] != "[DONE]" {
|
||||
continue
|
||||
}
|
||||
dataChan <- data
|
||||
data = data[6:]
|
||||
if !strings.HasPrefix(data, "[DONE]") {
|
||||
streamItems = append(streamItems, data)
|
||||
}
|
||||
}
|
||||
streamResp := "[" + strings.Join(streamItems, ",") + "]"
|
||||
switch relayMode {
|
||||
case relayconstant.RelayModeChatCompletions:
|
||||
var streamResponses []dto.ChatCompletionsStreamResponseSimple
|
||||
err := json.Unmarshal(common.StringToByteSlice(streamResp), &streamResponses)
|
||||
if err != nil {
|
||||
common.SysError("error unmarshalling stream response: " + err.Error())
|
||||
return // just ignore the error
|
||||
}
|
||||
for _, streamResponse := range streamResponses {
|
||||
for _, choice := range streamResponse.Choices {
|
||||
responseTextBuilder.WriteString(choice.Delta.Content)
|
||||
}
|
||||
}
|
||||
case relayconstant.RelayModeCompletions:
|
||||
var streamResponses []dto.CompletionsStreamResponse
|
||||
err := json.Unmarshal(common.StringToByteSlice(streamResp), &streamResponses)
|
||||
if err != nil {
|
||||
common.SysError("error unmarshalling stream response: " + err.Error())
|
||||
return // just ignore the error
|
||||
}
|
||||
for _, streamResponse := range streamResponses {
|
||||
for _, choice := range streamResponse.Choices {
|
||||
responseTextBuilder.WriteString(choice.Text)
|
||||
}
|
||||
}
|
||||
}
|
||||
if len(dataChan) > 0 {
|
||||
// wait data out
|
||||
time.Sleep(2 * time.Second)
|
||||
}
|
||||
common.SafeSend(stopChan, true)
|
||||
}()
|
||||
service.SetEventStreamHeaders(c)
|
||||
c.Stream(func(w io.Writer) bool {
|
||||
select {
|
||||
case data := <-dataChan:
|
||||
if strings.HasPrefix(data, "data: [DONE]") {
|
||||
data = data[:12]
|
||||
}
|
||||
// some implementations may add \r at the end of data
|
||||
data = strings.TrimSuffix(data, "\r")
|
||||
c.Render(-1, common.CustomEvent{Data: data})
|
||||
return true
|
||||
case <-stopChan:
|
||||
return false
|
||||
}
|
||||
})
|
||||
err := resp.Body.Close()
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), ""
|
||||
}
|
||||
wg.Wait()
|
||||
return nil, responseTextBuilder.String()
|
||||
}
|
||||
|
||||
func openaiHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
||||
var textResponse dto.TextResponse
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
err = resp.Body.Close()
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
err = json.Unmarshal(responseBody, &textResponse)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
if textResponse.Error.Type != "" {
|
||||
return &dto.OpenAIErrorWithStatusCode{
|
||||
OpenAIError: textResponse.Error,
|
||||
StatusCode: resp.StatusCode,
|
||||
}, nil
|
||||
}
|
||||
// Reset response body
|
||||
resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
|
||||
// We shouldn't set the header before we parse the response body, because the parse part may fail.
|
||||
// And then we will have to send an error response, but in this case, the header has already been set.
|
||||
// So the httpClient will be confused by the response.
|
||||
// For example, Postman will report error, and we cannot check the response at all.
|
||||
for k, v := range resp.Header {
|
||||
c.Writer.Header().Set(k, v[0])
|
||||
}
|
||||
c.Writer.WriteHeader(resp.StatusCode)
|
||||
_, err = io.Copy(c.Writer, resp.Body)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
err = resp.Body.Close()
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
|
||||
if textResponse.Usage.TotalTokens == 0 {
|
||||
completionTokens := 0
|
||||
for _, choice := range textResponse.Choices {
|
||||
completionTokens += service.CountTokenText(string(choice.Message.Content), model)
|
||||
}
|
||||
textResponse.Usage = dto.Usage{
|
||||
PromptTokens: promptTokens,
|
||||
CompletionTokens: completionTokens,
|
||||
TotalTokens: promptTokens + completionTokens,
|
||||
}
|
||||
}
|
||||
return nil, &textResponse.Usage
|
||||
}
|
||||
59
relay/channel/palm/adaptor.go
Normal file
59
relay/channel/palm/adaptor.go
Normal file
@@ -0,0 +1,59 @@
|
||||
package palm
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/dto"
|
||||
relaychannel "one-api/relay/channel"
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/service"
|
||||
)
|
||||
|
||||
type Adaptor struct {
|
||||
}
|
||||
|
||||
func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
|
||||
}
|
||||
|
||||
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
return fmt.Sprintf("%s/v1beta2/models/chat-bison-001:generateMessage", info.BaseUrl), nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
|
||||
relaychannel.SetupApiRequestHeader(info, c, req)
|
||||
req.Header.Set("x-goog-api-key", info.ApiKey)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) {
|
||||
if request == nil {
|
||||
return nil, errors.New("request is nil")
|
||||
}
|
||||
return request, nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
|
||||
return relaychannel.DoApiRequest(a, c, info, requestBody)
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
|
||||
if info.IsStream {
|
||||
var responseText string
|
||||
err, responseText = palmStreamHandler(c, resp)
|
||||
usage = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
|
||||
} else {
|
||||
err, usage = palmHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (a *Adaptor) GetModelList() []string {
|
||||
return ModelList
|
||||
}
|
||||
|
||||
func (a *Adaptor) GetChannelName() string {
|
||||
return ChannelName
|
||||
}
|
||||
7
relay/channel/palm/constants.go
Normal file
7
relay/channel/palm/constants.go
Normal file
@@ -0,0 +1,7 @@
|
||||
package palm
|
||||
|
||||
var ModelList = []string{
|
||||
"PaLM-2",
|
||||
}
|
||||
|
||||
var ChannelName = "google palm"
|
||||
38
relay/channel/palm/dto.go
Normal file
38
relay/channel/palm/dto.go
Normal file
@@ -0,0 +1,38 @@
|
||||
package palm
|
||||
|
||||
import "one-api/dto"
|
||||
|
||||
type PaLMChatMessage struct {
|
||||
Author string `json:"author"`
|
||||
Content string `json:"content"`
|
||||
}
|
||||
|
||||
type PaLMFilter struct {
|
||||
Reason string `json:"reason"`
|
||||
Message string `json:"message"`
|
||||
}
|
||||
|
||||
type PaLMPrompt struct {
|
||||
Messages []PaLMChatMessage `json:"messages"`
|
||||
}
|
||||
|
||||
type PaLMChatRequest struct {
|
||||
Prompt PaLMPrompt `json:"prompt"`
|
||||
Temperature float64 `json:"temperature,omitempty"`
|
||||
CandidateCount int `json:"candidateCount,omitempty"`
|
||||
TopP float64 `json:"topP,omitempty"`
|
||||
TopK uint `json:"topK,omitempty"`
|
||||
}
|
||||
|
||||
type PaLMError struct {
|
||||
Code int `json:"code"`
|
||||
Message string `json:"message"`
|
||||
Status string `json:"status"`
|
||||
}
|
||||
|
||||
type PaLMChatResponse struct {
|
||||
Candidates []PaLMChatMessage `json:"candidates"`
|
||||
Messages []dto.Message `json:"messages"`
|
||||
Filters []PaLMFilter `json:"filters"`
|
||||
Error PaLMError `json:"error"`
|
||||
}
|
||||
174
relay/channel/palm/relay-palm.go
Normal file
174
relay/channel/palm/relay-palm.go
Normal file
@@ -0,0 +1,174 @@
|
||||
package palm
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/dto"
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/service"
|
||||
)
|
||||
|
||||
// https://developers.generativeai.google/api/rest/generativelanguage/models/generateMessage#request-body
|
||||
// https://developers.generativeai.google/api/rest/generativelanguage/models/generateMessage#response-body
|
||||
|
||||
func requestOpenAI2PaLM(textRequest dto.GeneralOpenAIRequest) *PaLMChatRequest {
|
||||
palmRequest := PaLMChatRequest{
|
||||
Prompt: PaLMPrompt{
|
||||
Messages: make([]PaLMChatMessage, 0, len(textRequest.Messages)),
|
||||
},
|
||||
Temperature: textRequest.Temperature,
|
||||
CandidateCount: textRequest.N,
|
||||
TopP: textRequest.TopP,
|
||||
TopK: textRequest.MaxTokens,
|
||||
}
|
||||
for _, message := range textRequest.Messages {
|
||||
palmMessage := PaLMChatMessage{
|
||||
Content: message.StringContent(),
|
||||
}
|
||||
if message.Role == "user" {
|
||||
palmMessage.Author = "0"
|
||||
} else {
|
||||
palmMessage.Author = "1"
|
||||
}
|
||||
palmRequest.Prompt.Messages = append(palmRequest.Prompt.Messages, palmMessage)
|
||||
}
|
||||
return &palmRequest
|
||||
}
|
||||
|
||||
func responsePaLM2OpenAI(response *PaLMChatResponse) *dto.OpenAITextResponse {
|
||||
fullTextResponse := dto.OpenAITextResponse{
|
||||
Choices: make([]dto.OpenAITextResponseChoice, 0, len(response.Candidates)),
|
||||
}
|
||||
for i, candidate := range response.Candidates {
|
||||
content, _ := json.Marshal(candidate.Content)
|
||||
choice := dto.OpenAITextResponseChoice{
|
||||
Index: i,
|
||||
Message: dto.Message{
|
||||
Role: "assistant",
|
||||
Content: content,
|
||||
},
|
||||
FinishReason: "stop",
|
||||
}
|
||||
fullTextResponse.Choices = append(fullTextResponse.Choices, choice)
|
||||
}
|
||||
return &fullTextResponse
|
||||
}
|
||||
|
||||
func streamResponsePaLM2OpenAI(palmResponse *PaLMChatResponse) *dto.ChatCompletionsStreamResponse {
|
||||
var choice dto.ChatCompletionsStreamResponseChoice
|
||||
if len(palmResponse.Candidates) > 0 {
|
||||
choice.Delta.Content = palmResponse.Candidates[0].Content
|
||||
}
|
||||
choice.FinishReason = &relaycommon.StopFinishReason
|
||||
var response dto.ChatCompletionsStreamResponse
|
||||
response.Object = "chat.completion.chunk"
|
||||
response.Model = "palm2"
|
||||
response.Choices = []dto.ChatCompletionsStreamResponseChoice{choice}
|
||||
return &response
|
||||
}
|
||||
|
||||
func palmStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, string) {
|
||||
responseText := ""
|
||||
responseId := fmt.Sprintf("chatcmpl-%s", common.GetUUID())
|
||||
createdTime := common.GetTimestamp()
|
||||
dataChan := make(chan string)
|
||||
stopChan := make(chan bool)
|
||||
go func() {
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
common.SysError("error reading stream response: " + err.Error())
|
||||
stopChan <- true
|
||||
return
|
||||
}
|
||||
err = resp.Body.Close()
|
||||
if err != nil {
|
||||
common.SysError("error closing stream response: " + err.Error())
|
||||
stopChan <- true
|
||||
return
|
||||
}
|
||||
var palmResponse PaLMChatResponse
|
||||
err = json.Unmarshal(responseBody, &palmResponse)
|
||||
if err != nil {
|
||||
common.SysError("error unmarshalling stream response: " + err.Error())
|
||||
stopChan <- true
|
||||
return
|
||||
}
|
||||
fullTextResponse := streamResponsePaLM2OpenAI(&palmResponse)
|
||||
fullTextResponse.Id = responseId
|
||||
fullTextResponse.Created = createdTime
|
||||
if len(palmResponse.Candidates) > 0 {
|
||||
responseText = palmResponse.Candidates[0].Content
|
||||
}
|
||||
jsonResponse, err := json.Marshal(fullTextResponse)
|
||||
if err != nil {
|
||||
common.SysError("error marshalling stream response: " + err.Error())
|
||||
stopChan <- true
|
||||
return
|
||||
}
|
||||
dataChan <- string(jsonResponse)
|
||||
stopChan <- true
|
||||
}()
|
||||
service.SetEventStreamHeaders(c)
|
||||
c.Stream(func(w io.Writer) bool {
|
||||
select {
|
||||
case data := <-dataChan:
|
||||
c.Render(-1, common.CustomEvent{Data: "data: " + data})
|
||||
return true
|
||||
case <-stopChan:
|
||||
c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
|
||||
return false
|
||||
}
|
||||
})
|
||||
err := resp.Body.Close()
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), ""
|
||||
}
|
||||
return nil, responseText
|
||||
}
|
||||
|
||||
func palmHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
err = resp.Body.Close()
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
var palmResponse PaLMChatResponse
|
||||
err = json.Unmarshal(responseBody, &palmResponse)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
if palmResponse.Error.Code != 0 || len(palmResponse.Candidates) == 0 {
|
||||
return &dto.OpenAIErrorWithStatusCode{
|
||||
OpenAIError: dto.OpenAIError{
|
||||
Message: palmResponse.Error.Message,
|
||||
Type: palmResponse.Error.Status,
|
||||
Param: "",
|
||||
Code: palmResponse.Error.Code,
|
||||
},
|
||||
StatusCode: resp.StatusCode,
|
||||
}, nil
|
||||
}
|
||||
fullTextResponse := responsePaLM2OpenAI(&palmResponse)
|
||||
completionTokens := service.CountTokenText(palmResponse.Candidates[0].Content, model)
|
||||
usage := dto.Usage{
|
||||
PromptTokens: promptTokens,
|
||||
CompletionTokens: completionTokens,
|
||||
TotalTokens: promptTokens + completionTokens,
|
||||
}
|
||||
fullTextResponse.Usage = usage
|
||||
jsonResponse, err := json.Marshal(fullTextResponse)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
c.Writer.Header().Set("Content-Type", "application/json")
|
||||
c.Writer.WriteHeader(resp.StatusCode)
|
||||
_, err = c.Writer.Write(jsonResponse)
|
||||
return nil, &usage
|
||||
}
|
||||
73
relay/channel/tencent/adaptor.go
Normal file
73
relay/channel/tencent/adaptor.go
Normal file
@@ -0,0 +1,73 @@
|
||||
package tencent
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/dto"
|
||||
relaychannel "one-api/relay/channel"
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/service"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type Adaptor struct {
|
||||
Sign string
|
||||
}
|
||||
|
||||
func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
|
||||
}
|
||||
|
||||
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
return fmt.Sprintf("%s/hyllm/v1/chat/completions", info.BaseUrl), nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
|
||||
relaychannel.SetupApiRequestHeader(info, c, req)
|
||||
req.Header.Set("Authorization", a.Sign)
|
||||
req.Header.Set("X-TC-Action", info.UpstreamModelName)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) {
|
||||
if request == nil {
|
||||
return nil, errors.New("request is nil")
|
||||
}
|
||||
apiKey := c.Request.Header.Get("Authorization")
|
||||
apiKey = strings.TrimPrefix(apiKey, "Bearer ")
|
||||
appId, secretId, secretKey, err := parseTencentConfig(apiKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tencentRequest := requestOpenAI2Tencent(*request)
|
||||
tencentRequest.AppId = appId
|
||||
tencentRequest.SecretId = secretId
|
||||
// we have to calculate the sign here
|
||||
a.Sign = getTencentSign(*tencentRequest, secretKey)
|
||||
return tencentRequest, nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
|
||||
return relaychannel.DoApiRequest(a, c, info, requestBody)
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
|
||||
if info.IsStream {
|
||||
var responseText string
|
||||
err, responseText = tencentStreamHandler(c, resp)
|
||||
usage = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
|
||||
} else {
|
||||
err, usage = tencentHandler(c, resp)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (a *Adaptor) GetModelList() []string {
|
||||
return ModelList
|
||||
}
|
||||
|
||||
func (a *Adaptor) GetChannelName() string {
|
||||
return ChannelName
|
||||
}
|
||||
9
relay/channel/tencent/constants.go
Normal file
9
relay/channel/tencent/constants.go
Normal file
@@ -0,0 +1,9 @@
|
||||
package tencent
|
||||
|
||||
var ModelList = []string{
|
||||
"ChatPro",
|
||||
"ChatStd",
|
||||
"hunyuan",
|
||||
}
|
||||
|
||||
var ChannelName = "tencent"
|
||||
61
relay/channel/tencent/dto.go
Normal file
61
relay/channel/tencent/dto.go
Normal file
@@ -0,0 +1,61 @@
|
||||
package tencent
|
||||
|
||||
import "one-api/dto"
|
||||
|
||||
type TencentMessage struct {
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content"`
|
||||
}
|
||||
|
||||
type TencentChatRequest struct {
|
||||
AppId int64 `json:"app_id"` // 腾讯云账号的 APPID
|
||||
SecretId string `json:"secret_id"` // 官网 SecretId
|
||||
// Timestamp当前 UNIX 时间戳,单位为秒,可记录发起 API 请求的时间。
|
||||
// 例如1529223702,如果与当前时间相差过大,会引起签名过期错误
|
||||
Timestamp int64 `json:"timestamp"`
|
||||
// Expired 签名的有效期,是一个符合 UNIX Epoch 时间戳规范的数值,
|
||||
// 单位为秒;Expired 必须大于 Timestamp 且 Expired-Timestamp 小于90天
|
||||
Expired int64 `json:"expired"`
|
||||
QueryID string `json:"query_id"` //请求 Id,用于问题排查
|
||||
// Temperature 较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定
|
||||
// 默认 1.0,取值区间为[0.0,2.0],非必要不建议使用,不合理的取值会影响效果
|
||||
// 建议该参数和 top_p 只设置1个,不要同时更改 top_p
|
||||
Temperature float64 `json:"temperature"`
|
||||
// TopP 影响输出文本的多样性,取值越大,生成文本的多样性越强
|
||||
// 默认1.0,取值区间为[0.0, 1.0],非必要不建议使用, 不合理的取值会影响效果
|
||||
// 建议该参数和 temperature 只设置1个,不要同时更改
|
||||
TopP float64 `json:"top_p"`
|
||||
// Stream 0:同步,1:流式 (默认,协议:SSE)
|
||||
// 同步请求超时:60s,如果内容较长建议使用流式
|
||||
Stream int `json:"stream"`
|
||||
// Messages 会话内容, 长度最多为40, 按对话时间从旧到新在数组中排列
|
||||
// 输入 content 总数最大支持 3000 token。
|
||||
Messages []TencentMessage `json:"messages"`
|
||||
}
|
||||
|
||||
type TencentError struct {
|
||||
Code int `json:"code"`
|
||||
Message string `json:"message"`
|
||||
}
|
||||
|
||||
type TencentUsage struct {
|
||||
InputTokens int `json:"input_tokens"`
|
||||
OutputTokens int `json:"output_tokens"`
|
||||
TotalTokens int `json:"total_tokens"`
|
||||
}
|
||||
|
||||
type TencentResponseChoices struct {
|
||||
FinishReason string `json:"finish_reason,omitempty"` // 流式结束标志位,为 stop 则表示尾包
|
||||
Messages TencentMessage `json:"messages,omitempty"` // 内容,同步模式返回内容,流模式为 null 输出 content 内容总数最多支持 1024token。
|
||||
Delta TencentMessage `json:"delta,omitempty"` // 内容,流模式返回内容,同步模式为 null 输出 content 内容总数最多支持 1024token。
|
||||
}
|
||||
|
||||
type TencentChatResponse struct {
|
||||
Choices []TencentResponseChoices `json:"choices,omitempty"` // 结果
|
||||
Created string `json:"created,omitempty"` // unix 时间戳的字符串
|
||||
Id string `json:"id,omitempty"` // 会话 id
|
||||
Usage dto.Usage `json:"usage,omitempty"` // token 数量
|
||||
Error TencentError `json:"error,omitempty"` // 错误信息 注意:此字段可能返回 null,表示取不到有效值
|
||||
Note string `json:"note,omitempty"` // 注释
|
||||
ReqID string `json:"req_id,omitempty"` // 唯一请求 Id,每次请求都会返回。用于反馈接口入参
|
||||
}
|
||||
233
relay/channel/tencent/relay-tencent.go
Normal file
233
relay/channel/tencent/relay-tencent.go
Normal file
@@ -0,0 +1,233 @@
|
||||
package tencent
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"crypto/hmac"
|
||||
"crypto/sha1"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/dto"
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/service"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// https://cloud.tencent.com/document/product/1729/97732
|
||||
|
||||
func requestOpenAI2Tencent(request dto.GeneralOpenAIRequest) *TencentChatRequest {
|
||||
messages := make([]TencentMessage, 0, len(request.Messages))
|
||||
for i := 0; i < len(request.Messages); i++ {
|
||||
message := request.Messages[i]
|
||||
if message.Role == "system" {
|
||||
messages = append(messages, TencentMessage{
|
||||
Role: "user",
|
||||
Content: message.StringContent(),
|
||||
})
|
||||
messages = append(messages, TencentMessage{
|
||||
Role: "assistant",
|
||||
Content: "Okay",
|
||||
})
|
||||
continue
|
||||
}
|
||||
messages = append(messages, TencentMessage{
|
||||
Content: message.StringContent(),
|
||||
Role: message.Role,
|
||||
})
|
||||
}
|
||||
stream := 0
|
||||
if request.Stream {
|
||||
stream = 1
|
||||
}
|
||||
return &TencentChatRequest{
|
||||
Timestamp: common.GetTimestamp(),
|
||||
Expired: common.GetTimestamp() + 24*60*60,
|
||||
QueryID: common.GetUUID(),
|
||||
Temperature: request.Temperature,
|
||||
TopP: request.TopP,
|
||||
Stream: stream,
|
||||
Messages: messages,
|
||||
}
|
||||
}
|
||||
|
||||
func responseTencent2OpenAI(response *TencentChatResponse) *dto.OpenAITextResponse {
|
||||
fullTextResponse := dto.OpenAITextResponse{
|
||||
Object: "chat.completion",
|
||||
Created: common.GetTimestamp(),
|
||||
Usage: response.Usage,
|
||||
}
|
||||
if len(response.Choices) > 0 {
|
||||
content, _ := json.Marshal(response.Choices[0].Messages.Content)
|
||||
choice := dto.OpenAITextResponseChoice{
|
||||
Index: 0,
|
||||
Message: dto.Message{
|
||||
Role: "assistant",
|
||||
Content: content,
|
||||
},
|
||||
FinishReason: response.Choices[0].FinishReason,
|
||||
}
|
||||
fullTextResponse.Choices = append(fullTextResponse.Choices, choice)
|
||||
}
|
||||
return &fullTextResponse
|
||||
}
|
||||
|
||||
func streamResponseTencent2OpenAI(TencentResponse *TencentChatResponse) *dto.ChatCompletionsStreamResponse {
|
||||
response := dto.ChatCompletionsStreamResponse{
|
||||
Object: "chat.completion.chunk",
|
||||
Created: common.GetTimestamp(),
|
||||
Model: "tencent-hunyuan",
|
||||
}
|
||||
if len(TencentResponse.Choices) > 0 {
|
||||
var choice dto.ChatCompletionsStreamResponseChoice
|
||||
choice.Delta.Content = TencentResponse.Choices[0].Delta.Content
|
||||
if TencentResponse.Choices[0].FinishReason == "stop" {
|
||||
choice.FinishReason = &relaycommon.StopFinishReason
|
||||
}
|
||||
response.Choices = append(response.Choices, choice)
|
||||
}
|
||||
return &response
|
||||
}
|
||||
|
||||
func tencentStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, string) {
|
||||
var responseText string
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
|
||||
if atEOF && len(data) == 0 {
|
||||
return 0, nil, nil
|
||||
}
|
||||
if i := strings.Index(string(data), "\n"); i >= 0 {
|
||||
return i + 1, data[0:i], nil
|
||||
}
|
||||
if atEOF {
|
||||
return len(data), data, nil
|
||||
}
|
||||
return 0, nil, nil
|
||||
})
|
||||
dataChan := make(chan string)
|
||||
stopChan := make(chan bool)
|
||||
go func() {
|
||||
for scanner.Scan() {
|
||||
data := scanner.Text()
|
||||
if len(data) < 5 { // ignore blank line or wrong format
|
||||
continue
|
||||
}
|
||||
if data[:5] != "data:" {
|
||||
continue
|
||||
}
|
||||
data = data[5:]
|
||||
dataChan <- data
|
||||
}
|
||||
stopChan <- true
|
||||
}()
|
||||
service.SetEventStreamHeaders(c)
|
||||
c.Stream(func(w io.Writer) bool {
|
||||
select {
|
||||
case data := <-dataChan:
|
||||
var TencentResponse TencentChatResponse
|
||||
err := json.Unmarshal([]byte(data), &TencentResponse)
|
||||
if err != nil {
|
||||
common.SysError("error unmarshalling stream response: " + err.Error())
|
||||
return true
|
||||
}
|
||||
response := streamResponseTencent2OpenAI(&TencentResponse)
|
||||
if len(response.Choices) != 0 {
|
||||
responseText += response.Choices[0].Delta.Content
|
||||
}
|
||||
jsonResponse, err := json.Marshal(response)
|
||||
if err != nil {
|
||||
common.SysError("error marshalling stream response: " + err.Error())
|
||||
return true
|
||||
}
|
||||
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
|
||||
return true
|
||||
case <-stopChan:
|
||||
c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
|
||||
return false
|
||||
}
|
||||
})
|
||||
err := resp.Body.Close()
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), ""
|
||||
}
|
||||
return nil, responseText
|
||||
}
|
||||
|
||||
func tencentHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
||||
var TencentResponse TencentChatResponse
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
err = resp.Body.Close()
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
err = json.Unmarshal(responseBody, &TencentResponse)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
if TencentResponse.Error.Code != 0 {
|
||||
return &dto.OpenAIErrorWithStatusCode{
|
||||
OpenAIError: dto.OpenAIError{
|
||||
Message: TencentResponse.Error.Message,
|
||||
Code: TencentResponse.Error.Code,
|
||||
},
|
||||
StatusCode: resp.StatusCode,
|
||||
}, nil
|
||||
}
|
||||
fullTextResponse := responseTencent2OpenAI(&TencentResponse)
|
||||
jsonResponse, err := json.Marshal(fullTextResponse)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
c.Writer.Header().Set("Content-Type", "application/json")
|
||||
c.Writer.WriteHeader(resp.StatusCode)
|
||||
_, err = c.Writer.Write(jsonResponse)
|
||||
return nil, &fullTextResponse.Usage
|
||||
}
|
||||
|
||||
func parseTencentConfig(config string) (appId int64, secretId string, secretKey string, err error) {
|
||||
parts := strings.Split(config, "|")
|
||||
if len(parts) != 3 {
|
||||
err = errors.New("invalid tencent config")
|
||||
return
|
||||
}
|
||||
appId, err = strconv.ParseInt(parts[0], 10, 64)
|
||||
secretId = parts[1]
|
||||
secretKey = parts[2]
|
||||
return
|
||||
}
|
||||
|
||||
func getTencentSign(req TencentChatRequest, secretKey string) string {
|
||||
params := make([]string, 0)
|
||||
params = append(params, "app_id="+strconv.FormatInt(req.AppId, 10))
|
||||
params = append(params, "secret_id="+req.SecretId)
|
||||
params = append(params, "timestamp="+strconv.FormatInt(req.Timestamp, 10))
|
||||
params = append(params, "query_id="+req.QueryID)
|
||||
params = append(params, "temperature="+strconv.FormatFloat(req.Temperature, 'f', -1, 64))
|
||||
params = append(params, "top_p="+strconv.FormatFloat(req.TopP, 'f', -1, 64))
|
||||
params = append(params, "stream="+strconv.Itoa(req.Stream))
|
||||
params = append(params, "expired="+strconv.FormatInt(req.Expired, 10))
|
||||
|
||||
var messageStr string
|
||||
for _, msg := range req.Messages {
|
||||
messageStr += fmt.Sprintf(`{"role":"%s","content":"%s"},`, msg.Role, msg.Content)
|
||||
}
|
||||
messageStr = strings.TrimSuffix(messageStr, ",")
|
||||
params = append(params, "messages=["+messageStr+"]")
|
||||
|
||||
sort.Sort(sort.StringSlice(params))
|
||||
url := "hunyuan.cloud.tencent.com/hyllm/v1/chat/completions?" + strings.Join(params, "&")
|
||||
mac := hmac.New(sha1.New, []byte(secretKey))
|
||||
signURL := url
|
||||
mac.Write([]byte(signURL))
|
||||
sign := mac.Sum([]byte(nil))
|
||||
return base64.StdEncoding.EncodeToString(sign)
|
||||
}
|
||||
68
relay/channel/xunfei/adaptor.go
Normal file
68
relay/channel/xunfei/adaptor.go
Normal file
@@ -0,0 +1,68 @@
|
||||
package xunfei
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"github.com/gin-gonic/gin"
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/dto"
|
||||
relaychannel "one-api/relay/channel"
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/service"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type Adaptor struct {
|
||||
request *dto.GeneralOpenAIRequest
|
||||
}
|
||||
|
||||
func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
|
||||
}
|
||||
|
||||
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
|
||||
relaychannel.SetupApiRequestHeader(info, c, req)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) {
|
||||
if request == nil {
|
||||
return nil, errors.New("request is nil")
|
||||
}
|
||||
a.request = request
|
||||
return request, nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
|
||||
// xunfei's request is not http request, so we don't need to do anything here
|
||||
dummyResp := &http.Response{}
|
||||
dummyResp.StatusCode = http.StatusOK
|
||||
return dummyResp, nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
|
||||
splits := strings.Split(info.ApiKey, "|")
|
||||
if len(splits) != 3 {
|
||||
return nil, service.OpenAIErrorWrapper(errors.New("invalid auth"), "invalid_auth", http.StatusBadRequest)
|
||||
}
|
||||
if a.request == nil {
|
||||
return nil, service.OpenAIErrorWrapper(errors.New("request is nil"), "request_is_nil", http.StatusBadRequest)
|
||||
}
|
||||
if info.IsStream {
|
||||
err, usage = xunfeiStreamHandler(c, *a.request, splits[0], splits[1], splits[2])
|
||||
} else {
|
||||
err, usage = xunfeiHandler(c, *a.request, splits[0], splits[1], splits[2])
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (a *Adaptor) GetModelList() []string {
|
||||
return ModelList
|
||||
}
|
||||
|
||||
func (a *Adaptor) GetChannelName() string {
|
||||
return ChannelName
|
||||
}
|
||||
11
relay/channel/xunfei/constants.go
Normal file
11
relay/channel/xunfei/constants.go
Normal file
@@ -0,0 +1,11 @@
|
||||
package xunfei
|
||||
|
||||
var ModelList = []string{
|
||||
"SparkDesk",
|
||||
"SparkDesk-v1.1",
|
||||
"SparkDesk-v2.1",
|
||||
"SparkDesk-v3.1",
|
||||
"SparkDesk-v3.5",
|
||||
}
|
||||
|
||||
var ChannelName = "xunfei"
|
||||
59
relay/channel/xunfei/dto.go
Normal file
59
relay/channel/xunfei/dto.go
Normal file
@@ -0,0 +1,59 @@
|
||||
package xunfei
|
||||
|
||||
import "one-api/dto"
|
||||
|
||||
type XunfeiMessage struct {
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content"`
|
||||
}
|
||||
|
||||
type XunfeiChatRequest struct {
|
||||
Header struct {
|
||||
AppId string `json:"app_id"`
|
||||
} `json:"header"`
|
||||
Parameter struct {
|
||||
Chat struct {
|
||||
Domain string `json:"domain,omitempty"`
|
||||
Temperature float64 `json:"temperature,omitempty"`
|
||||
TopK int `json:"top_k,omitempty"`
|
||||
MaxTokens uint `json:"max_tokens,omitempty"`
|
||||
Auditing bool `json:"auditing,omitempty"`
|
||||
} `json:"chat"`
|
||||
} `json:"parameter"`
|
||||
Payload struct {
|
||||
Message struct {
|
||||
Text []XunfeiMessage `json:"text"`
|
||||
} `json:"message"`
|
||||
} `json:"payload"`
|
||||
}
|
||||
|
||||
type XunfeiChatResponseTextItem struct {
|
||||
Content string `json:"content"`
|
||||
Role string `json:"role"`
|
||||
Index int `json:"index"`
|
||||
}
|
||||
|
||||
type XunfeiChatResponse struct {
|
||||
Header struct {
|
||||
Code int `json:"code"`
|
||||
Message string `json:"message"`
|
||||
Sid string `json:"sid"`
|
||||
Status int `json:"status"`
|
||||
} `json:"header"`
|
||||
Payload struct {
|
||||
Choices struct {
|
||||
Status int `json:"status"`
|
||||
Seq int `json:"seq"`
|
||||
Text []XunfeiChatResponseTextItem `json:"text"`
|
||||
} `json:"choices"`
|
||||
Usage struct {
|
||||
//Text struct {
|
||||
// QuestionTokens string `json:"question_tokens"`
|
||||
// PromptTokens string `json:"prompt_tokens"`
|
||||
// CompletionTokens string `json:"completion_tokens"`
|
||||
// TotalTokens string `json:"total_tokens"`
|
||||
//} `json:"text"`
|
||||
Text dto.Usage `json:"text"`
|
||||
} `json:"usage"`
|
||||
} `json:"payload"`
|
||||
}
|
||||
254
relay/channel/xunfei/relay-xunfei.go
Normal file
254
relay/channel/xunfei/relay-xunfei.go
Normal file
@@ -0,0 +1,254 @@
|
||||
package xunfei
|
||||
|
||||
import (
|
||||
"crypto/hmac"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/gorilla/websocket"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"one-api/common"
|
||||
"one-api/dto"
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/service"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// https://console.xfyun.cn/services/cbm
|
||||
// https://www.xfyun.cn/doc/spark/Web.html
|
||||
|
||||
func requestOpenAI2Xunfei(request dto.GeneralOpenAIRequest, xunfeiAppId string, domain string) *XunfeiChatRequest {
|
||||
messages := make([]XunfeiMessage, 0, len(request.Messages))
|
||||
for _, message := range request.Messages {
|
||||
if message.Role == "system" {
|
||||
messages = append(messages, XunfeiMessage{
|
||||
Role: "user",
|
||||
Content: message.StringContent(),
|
||||
})
|
||||
messages = append(messages, XunfeiMessage{
|
||||
Role: "assistant",
|
||||
Content: "Okay",
|
||||
})
|
||||
} else {
|
||||
messages = append(messages, XunfeiMessage{
|
||||
Role: message.Role,
|
||||
Content: message.StringContent(),
|
||||
})
|
||||
}
|
||||
}
|
||||
xunfeiRequest := XunfeiChatRequest{}
|
||||
xunfeiRequest.Header.AppId = xunfeiAppId
|
||||
xunfeiRequest.Parameter.Chat.Domain = domain
|
||||
xunfeiRequest.Parameter.Chat.Temperature = request.Temperature
|
||||
xunfeiRequest.Parameter.Chat.TopK = request.N
|
||||
xunfeiRequest.Parameter.Chat.MaxTokens = request.MaxTokens
|
||||
xunfeiRequest.Payload.Message.Text = messages
|
||||
return &xunfeiRequest
|
||||
}
|
||||
|
||||
func responseXunfei2OpenAI(response *XunfeiChatResponse) *dto.OpenAITextResponse {
|
||||
if len(response.Payload.Choices.Text) == 0 {
|
||||
response.Payload.Choices.Text = []XunfeiChatResponseTextItem{
|
||||
{
|
||||
Content: "",
|
||||
},
|
||||
}
|
||||
}
|
||||
content, _ := json.Marshal(response.Payload.Choices.Text[0].Content)
|
||||
choice := dto.OpenAITextResponseChoice{
|
||||
Index: 0,
|
||||
Message: dto.Message{
|
||||
Role: "assistant",
|
||||
Content: content,
|
||||
},
|
||||
FinishReason: relaycommon.StopFinishReason,
|
||||
}
|
||||
fullTextResponse := dto.OpenAITextResponse{
|
||||
Object: "chat.completion",
|
||||
Created: common.GetTimestamp(),
|
||||
Choices: []dto.OpenAITextResponseChoice{choice},
|
||||
Usage: response.Payload.Usage.Text,
|
||||
}
|
||||
return &fullTextResponse
|
||||
}
|
||||
|
||||
func streamResponseXunfei2OpenAI(xunfeiResponse *XunfeiChatResponse) *dto.ChatCompletionsStreamResponse {
|
||||
if len(xunfeiResponse.Payload.Choices.Text) == 0 {
|
||||
xunfeiResponse.Payload.Choices.Text = []XunfeiChatResponseTextItem{
|
||||
{
|
||||
Content: "",
|
||||
},
|
||||
}
|
||||
}
|
||||
var choice dto.ChatCompletionsStreamResponseChoice
|
||||
choice.Delta.Content = xunfeiResponse.Payload.Choices.Text[0].Content
|
||||
if xunfeiResponse.Payload.Choices.Status == 2 {
|
||||
choice.FinishReason = &relaycommon.StopFinishReason
|
||||
}
|
||||
response := dto.ChatCompletionsStreamResponse{
|
||||
Object: "chat.completion.chunk",
|
||||
Created: common.GetTimestamp(),
|
||||
Model: "SparkDesk",
|
||||
Choices: []dto.ChatCompletionsStreamResponseChoice{choice},
|
||||
}
|
||||
return &response
|
||||
}
|
||||
|
||||
func buildXunfeiAuthUrl(hostUrl string, apiKey, apiSecret string) string {
|
||||
HmacWithShaToBase64 := func(algorithm, data, key string) string {
|
||||
mac := hmac.New(sha256.New, []byte(key))
|
||||
mac.Write([]byte(data))
|
||||
encodeData := mac.Sum(nil)
|
||||
return base64.StdEncoding.EncodeToString(encodeData)
|
||||
}
|
||||
ul, err := url.Parse(hostUrl)
|
||||
if err != nil {
|
||||
fmt.Println(err)
|
||||
}
|
||||
date := time.Now().UTC().Format(time.RFC1123)
|
||||
signString := []string{"host: " + ul.Host, "date: " + date, "GET " + ul.Path + " HTTP/1.1"}
|
||||
sign := strings.Join(signString, "\n")
|
||||
sha := HmacWithShaToBase64("hmac-sha256", sign, apiSecret)
|
||||
authUrl := fmt.Sprintf("hmac username=\"%s\", algorithm=\"%s\", headers=\"%s\", signature=\"%s\"", apiKey,
|
||||
"hmac-sha256", "host date request-line", sha)
|
||||
authorization := base64.StdEncoding.EncodeToString([]byte(authUrl))
|
||||
v := url.Values{}
|
||||
v.Add("host", ul.Host)
|
||||
v.Add("date", date)
|
||||
v.Add("authorization", authorization)
|
||||
callUrl := hostUrl + "?" + v.Encode()
|
||||
return callUrl
|
||||
}
|
||||
|
||||
func xunfeiStreamHandler(c *gin.Context, textRequest dto.GeneralOpenAIRequest, appId string, apiSecret string, apiKey string) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
||||
domain, authUrl := getXunfeiAuthUrl(c, apiKey, apiSecret)
|
||||
dataChan, stopChan, err := xunfeiMakeRequest(textRequest, domain, authUrl, appId)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "make xunfei request err", http.StatusInternalServerError), nil
|
||||
}
|
||||
service.SetEventStreamHeaders(c)
|
||||
var usage dto.Usage
|
||||
c.Stream(func(w io.Writer) bool {
|
||||
select {
|
||||
case xunfeiResponse := <-dataChan:
|
||||
usage.PromptTokens += xunfeiResponse.Payload.Usage.Text.PromptTokens
|
||||
usage.CompletionTokens += xunfeiResponse.Payload.Usage.Text.CompletionTokens
|
||||
usage.TotalTokens += xunfeiResponse.Payload.Usage.Text.TotalTokens
|
||||
response := streamResponseXunfei2OpenAI(&xunfeiResponse)
|
||||
jsonResponse, err := json.Marshal(response)
|
||||
if err != nil {
|
||||
common.SysError("error marshalling stream response: " + err.Error())
|
||||
return true
|
||||
}
|
||||
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
|
||||
return true
|
||||
case <-stopChan:
|
||||
c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
|
||||
return false
|
||||
}
|
||||
})
|
||||
return nil, &usage
|
||||
}
|
||||
|
||||
func xunfeiHandler(c *gin.Context, textRequest dto.GeneralOpenAIRequest, appId string, apiSecret string, apiKey string) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
||||
domain, authUrl := getXunfeiAuthUrl(c, apiKey, apiSecret)
|
||||
dataChan, stopChan, err := xunfeiMakeRequest(textRequest, domain, authUrl, appId)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "make xunfei request err", http.StatusInternalServerError), nil
|
||||
}
|
||||
var usage dto.Usage
|
||||
var content string
|
||||
var xunfeiResponse XunfeiChatResponse
|
||||
stop := false
|
||||
for !stop {
|
||||
select {
|
||||
case xunfeiResponse = <-dataChan:
|
||||
if len(xunfeiResponse.Payload.Choices.Text) == 0 {
|
||||
continue
|
||||
}
|
||||
content += xunfeiResponse.Payload.Choices.Text[0].Content
|
||||
usage.PromptTokens += xunfeiResponse.Payload.Usage.Text.PromptTokens
|
||||
usage.CompletionTokens += xunfeiResponse.Payload.Usage.Text.CompletionTokens
|
||||
usage.TotalTokens += xunfeiResponse.Payload.Usage.Text.TotalTokens
|
||||
case stop = <-stopChan:
|
||||
}
|
||||
}
|
||||
|
||||
xunfeiResponse.Payload.Choices.Text[0].Content = content
|
||||
|
||||
response := responseXunfei2OpenAI(&xunfeiResponse)
|
||||
jsonResponse, err := json.Marshal(response)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
c.Writer.Header().Set("Content-Type", "application/json")
|
||||
_, _ = c.Writer.Write(jsonResponse)
|
||||
return nil, &usage
|
||||
}
|
||||
|
||||
func xunfeiMakeRequest(textRequest dto.GeneralOpenAIRequest, domain, authUrl, appId string) (chan XunfeiChatResponse, chan bool, error) {
|
||||
d := websocket.Dialer{
|
||||
HandshakeTimeout: 5 * time.Second,
|
||||
}
|
||||
conn, resp, err := d.Dial(authUrl, nil)
|
||||
if err != nil || resp.StatusCode != 101 {
|
||||
return nil, nil, err
|
||||
}
|
||||
data := requestOpenAI2Xunfei(textRequest, appId, domain)
|
||||
err = conn.WriteJSON(data)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
dataChan := make(chan XunfeiChatResponse)
|
||||
stopChan := make(chan bool)
|
||||
go func() {
|
||||
for {
|
||||
_, msg, err := conn.ReadMessage()
|
||||
if err != nil {
|
||||
common.SysError("error reading stream response: " + err.Error())
|
||||
break
|
||||
}
|
||||
var response XunfeiChatResponse
|
||||
err = json.Unmarshal(msg, &response)
|
||||
if err != nil {
|
||||
common.SysError("error unmarshalling stream response: " + err.Error())
|
||||
break
|
||||
}
|
||||
dataChan <- response
|
||||
if response.Payload.Choices.Status == 2 {
|
||||
err := conn.Close()
|
||||
if err != nil {
|
||||
common.SysError("error closing websocket connection: " + err.Error())
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
stopChan <- true
|
||||
}()
|
||||
|
||||
return dataChan, stopChan, nil
|
||||
}
|
||||
|
||||
func getXunfeiAuthUrl(c *gin.Context, apiKey string, apiSecret string) (string, string) {
|
||||
query := c.Request.URL.Query()
|
||||
apiVersion := query.Get("api-version")
|
||||
if apiVersion == "" {
|
||||
apiVersion = c.GetString("api_version")
|
||||
}
|
||||
if apiVersion == "" {
|
||||
apiVersion = "v1.1"
|
||||
common.SysLog("api_version not found, use default: " + apiVersion)
|
||||
}
|
||||
domain := "general"
|
||||
if apiVersion != "v1.1" {
|
||||
domain += strings.Split(apiVersion, ".")[0]
|
||||
}
|
||||
authUrl := buildXunfeiAuthUrl(fmt.Sprintf("wss://spark-api.xf-yun.com/%s/chat", apiVersion), apiKey, apiSecret)
|
||||
return domain, authUrl
|
||||
}
|
||||
61
relay/channel/zhipu/adaptor.go
Normal file
61
relay/channel/zhipu/adaptor.go
Normal file
@@ -0,0 +1,61 @@
|
||||
package zhipu
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/dto"
|
||||
relaychannel "one-api/relay/channel"
|
||||
relaycommon "one-api/relay/common"
|
||||
)
|
||||
|
||||
type Adaptor struct {
|
||||
}
|
||||
|
||||
func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
|
||||
}
|
||||
|
||||
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
method := "invoke"
|
||||
if info.IsStream {
|
||||
method = "sse-invoke"
|
||||
}
|
||||
return fmt.Sprintf("%s/api/paas/v3/model-api/%s/%s", info.BaseUrl, info.UpstreamModelName, method), nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
|
||||
relaychannel.SetupApiRequestHeader(info, c, req)
|
||||
token := getZhipuToken(info.ApiKey)
|
||||
req.Header.Set("Authorization", token)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) {
|
||||
if request == nil {
|
||||
return nil, errors.New("request is nil")
|
||||
}
|
||||
return requestOpenAI2Zhipu(*request), nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
|
||||
return relaychannel.DoApiRequest(a, c, info, requestBody)
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
|
||||
if info.IsStream {
|
||||
err, usage = zhipuStreamHandler(c, resp)
|
||||
} else {
|
||||
err, usage = zhipuHandler(c, resp)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (a *Adaptor) GetModelList() []string {
|
||||
return ModelList
|
||||
}
|
||||
|
||||
func (a *Adaptor) GetChannelName() string {
|
||||
return ChannelName
|
||||
}
|
||||
7
relay/channel/zhipu/constants.go
Normal file
7
relay/channel/zhipu/constants.go
Normal file
@@ -0,0 +1,7 @@
|
||||
package zhipu
|
||||
|
||||
var ModelList = []string{
|
||||
"chatglm_turbo", "chatglm_pro", "chatglm_std", "chatglm_lite",
|
||||
}
|
||||
|
||||
var ChannelName = "zhipu"
|
||||
46
relay/channel/zhipu/dto.go
Normal file
46
relay/channel/zhipu/dto.go
Normal file
@@ -0,0 +1,46 @@
|
||||
package zhipu
|
||||
|
||||
import (
|
||||
"one-api/dto"
|
||||
"time"
|
||||
)
|
||||
|
||||
type ZhipuMessage struct {
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content"`
|
||||
}
|
||||
|
||||
type ZhipuRequest struct {
|
||||
Prompt []ZhipuMessage `json:"prompt"`
|
||||
Temperature float64 `json:"temperature,omitempty"`
|
||||
TopP float64 `json:"top_p,omitempty"`
|
||||
RequestId string `json:"request_id,omitempty"`
|
||||
Incremental bool `json:"incremental,omitempty"`
|
||||
}
|
||||
|
||||
type ZhipuResponseData struct {
|
||||
TaskId string `json:"task_id"`
|
||||
RequestId string `json:"request_id"`
|
||||
TaskStatus string `json:"task_status"`
|
||||
Choices []ZhipuMessage `json:"choices"`
|
||||
dto.Usage `json:"usage"`
|
||||
}
|
||||
|
||||
type ZhipuResponse struct {
|
||||
Code int `json:"code"`
|
||||
Msg string `json:"msg"`
|
||||
Success bool `json:"success"`
|
||||
Data ZhipuResponseData `json:"data"`
|
||||
}
|
||||
|
||||
type ZhipuStreamMetaResponse struct {
|
||||
RequestId string `json:"request_id"`
|
||||
TaskId string `json:"task_id"`
|
||||
TaskStatus string `json:"task_status"`
|
||||
dto.Usage `json:"usage"`
|
||||
}
|
||||
|
||||
type zhipuTokenData struct {
|
||||
Token string
|
||||
ExpiryTime time.Time
|
||||
}
|
||||
265
relay/channel/zhipu/relay-zhipu.go
Normal file
265
relay/channel/zhipu/relay-zhipu.go
Normal file
@@ -0,0 +1,265 @@
|
||||
package zhipu
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"encoding/json"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/golang-jwt/jwt"
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/dto"
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/service"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// https://open.bigmodel.cn/doc/api#chatglm_std
|
||||
// chatglm_std, chatglm_lite
|
||||
// https://open.bigmodel.cn/api/paas/v3/model-api/chatglm_std/invoke
|
||||
// https://open.bigmodel.cn/api/paas/v3/model-api/chatglm_std/sse-invoke
|
||||
|
||||
var zhipuTokens sync.Map
|
||||
var expSeconds int64 = 24 * 3600
|
||||
|
||||
func getZhipuToken(apikey string) string {
|
||||
data, ok := zhipuTokens.Load(apikey)
|
||||
if ok {
|
||||
tokenData := data.(zhipuTokenData)
|
||||
if time.Now().Before(tokenData.ExpiryTime) {
|
||||
return tokenData.Token
|
||||
}
|
||||
}
|
||||
|
||||
split := strings.Split(apikey, ".")
|
||||
if len(split) != 2 {
|
||||
common.SysError("invalid zhipu key: " + apikey)
|
||||
return ""
|
||||
}
|
||||
|
||||
id := split[0]
|
||||
secret := split[1]
|
||||
|
||||
expMillis := time.Now().Add(time.Duration(expSeconds)*time.Second).UnixNano() / 1e6
|
||||
expiryTime := time.Now().Add(time.Duration(expSeconds) * time.Second)
|
||||
|
||||
timestamp := time.Now().UnixNano() / 1e6
|
||||
|
||||
payload := jwt.MapClaims{
|
||||
"api_key": id,
|
||||
"exp": expMillis,
|
||||
"timestamp": timestamp,
|
||||
}
|
||||
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, payload)
|
||||
|
||||
token.Header["alg"] = "HS256"
|
||||
token.Header["sign_type"] = "SIGN"
|
||||
|
||||
tokenString, err := token.SignedString([]byte(secret))
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
zhipuTokens.Store(apikey, zhipuTokenData{
|
||||
Token: tokenString,
|
||||
ExpiryTime: expiryTime,
|
||||
})
|
||||
|
||||
return tokenString
|
||||
}
|
||||
|
||||
func requestOpenAI2Zhipu(request dto.GeneralOpenAIRequest) *ZhipuRequest {
|
||||
messages := make([]ZhipuMessage, 0, len(request.Messages))
|
||||
for _, message := range request.Messages {
|
||||
if message.Role == "system" {
|
||||
messages = append(messages, ZhipuMessage{
|
||||
Role: "system",
|
||||
Content: message.StringContent(),
|
||||
})
|
||||
messages = append(messages, ZhipuMessage{
|
||||
Role: "user",
|
||||
Content: "Okay",
|
||||
})
|
||||
} else {
|
||||
messages = append(messages, ZhipuMessage{
|
||||
Role: message.Role,
|
||||
Content: message.StringContent(),
|
||||
})
|
||||
}
|
||||
}
|
||||
return &ZhipuRequest{
|
||||
Prompt: messages,
|
||||
Temperature: request.Temperature,
|
||||
TopP: request.TopP,
|
||||
Incremental: false,
|
||||
}
|
||||
}
|
||||
|
||||
func responseZhipu2OpenAI(response *ZhipuResponse) *dto.OpenAITextResponse {
|
||||
fullTextResponse := dto.OpenAITextResponse{
|
||||
Id: response.Data.TaskId,
|
||||
Object: "chat.completion",
|
||||
Created: common.GetTimestamp(),
|
||||
Choices: make([]dto.OpenAITextResponseChoice, 0, len(response.Data.Choices)),
|
||||
Usage: response.Data.Usage,
|
||||
}
|
||||
for i, choice := range response.Data.Choices {
|
||||
content, _ := json.Marshal(strings.Trim(choice.Content, "\""))
|
||||
openaiChoice := dto.OpenAITextResponseChoice{
|
||||
Index: i,
|
||||
Message: dto.Message{
|
||||
Role: choice.Role,
|
||||
Content: content,
|
||||
},
|
||||
FinishReason: "",
|
||||
}
|
||||
if i == len(response.Data.Choices)-1 {
|
||||
openaiChoice.FinishReason = "stop"
|
||||
}
|
||||
fullTextResponse.Choices = append(fullTextResponse.Choices, openaiChoice)
|
||||
}
|
||||
return &fullTextResponse
|
||||
}
|
||||
|
||||
func streamResponseZhipu2OpenAI(zhipuResponse string) *dto.ChatCompletionsStreamResponse {
|
||||
var choice dto.ChatCompletionsStreamResponseChoice
|
||||
choice.Delta.Content = zhipuResponse
|
||||
response := dto.ChatCompletionsStreamResponse{
|
||||
Object: "chat.completion.chunk",
|
||||
Created: common.GetTimestamp(),
|
||||
Model: "chatglm",
|
||||
Choices: []dto.ChatCompletionsStreamResponseChoice{choice},
|
||||
}
|
||||
return &response
|
||||
}
|
||||
|
||||
func streamMetaResponseZhipu2OpenAI(zhipuResponse *ZhipuStreamMetaResponse) (*dto.ChatCompletionsStreamResponse, *dto.Usage) {
|
||||
var choice dto.ChatCompletionsStreamResponseChoice
|
||||
choice.Delta.Content = ""
|
||||
choice.FinishReason = &relaycommon.StopFinishReason
|
||||
response := dto.ChatCompletionsStreamResponse{
|
||||
Id: zhipuResponse.RequestId,
|
||||
Object: "chat.completion.chunk",
|
||||
Created: common.GetTimestamp(),
|
||||
Model: "chatglm",
|
||||
Choices: []dto.ChatCompletionsStreamResponseChoice{choice},
|
||||
}
|
||||
return &response, &zhipuResponse.Usage
|
||||
}
|
||||
|
||||
func zhipuStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
||||
var usage *dto.Usage
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
|
||||
if atEOF && len(data) == 0 {
|
||||
return 0, nil, nil
|
||||
}
|
||||
if i := strings.Index(string(data), "\n\n"); i >= 0 && strings.Index(string(data), ":") >= 0 {
|
||||
return i + 2, data[0:i], nil
|
||||
}
|
||||
if atEOF {
|
||||
return len(data), data, nil
|
||||
}
|
||||
return 0, nil, nil
|
||||
})
|
||||
dataChan := make(chan string)
|
||||
metaChan := make(chan string)
|
||||
stopChan := make(chan bool)
|
||||
go func() {
|
||||
for scanner.Scan() {
|
||||
data := scanner.Text()
|
||||
lines := strings.Split(data, "\n")
|
||||
for i, line := range lines {
|
||||
if len(line) < 5 {
|
||||
continue
|
||||
}
|
||||
if line[:5] == "data:" {
|
||||
dataChan <- line[5:]
|
||||
if i != len(lines)-1 {
|
||||
dataChan <- "\n"
|
||||
}
|
||||
} else if line[:5] == "meta:" {
|
||||
metaChan <- line[5:]
|
||||
}
|
||||
}
|
||||
}
|
||||
stopChan <- true
|
||||
}()
|
||||
service.SetEventStreamHeaders(c)
|
||||
c.Stream(func(w io.Writer) bool {
|
||||
select {
|
||||
case data := <-dataChan:
|
||||
response := streamResponseZhipu2OpenAI(data)
|
||||
jsonResponse, err := json.Marshal(response)
|
||||
if err != nil {
|
||||
common.SysError("error marshalling stream response: " + err.Error())
|
||||
return true
|
||||
}
|
||||
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
|
||||
return true
|
||||
case data := <-metaChan:
|
||||
var zhipuResponse ZhipuStreamMetaResponse
|
||||
err := json.Unmarshal([]byte(data), &zhipuResponse)
|
||||
if err != nil {
|
||||
common.SysError("error unmarshalling stream response: " + err.Error())
|
||||
return true
|
||||
}
|
||||
response, zhipuUsage := streamMetaResponseZhipu2OpenAI(&zhipuResponse)
|
||||
jsonResponse, err := json.Marshal(response)
|
||||
if err != nil {
|
||||
common.SysError("error marshalling stream response: " + err.Error())
|
||||
return true
|
||||
}
|
||||
usage = zhipuUsage
|
||||
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
|
||||
return true
|
||||
case <-stopChan:
|
||||
c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
|
||||
return false
|
||||
}
|
||||
})
|
||||
err := resp.Body.Close()
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
return nil, usage
|
||||
}
|
||||
|
||||
func zhipuHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
||||
var zhipuResponse ZhipuResponse
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
err = resp.Body.Close()
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
err = json.Unmarshal(responseBody, &zhipuResponse)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
if !zhipuResponse.Success {
|
||||
return &dto.OpenAIErrorWithStatusCode{
|
||||
OpenAIError: dto.OpenAIError{
|
||||
Message: zhipuResponse.Msg,
|
||||
Type: "zhipu_error",
|
||||
Param: "",
|
||||
Code: zhipuResponse.Code,
|
||||
},
|
||||
StatusCode: resp.StatusCode,
|
||||
}, nil
|
||||
}
|
||||
fullTextResponse := responseZhipu2OpenAI(&zhipuResponse)
|
||||
jsonResponse, err := json.Marshal(fullTextResponse)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
c.Writer.Header().Set("Content-Type", "application/json")
|
||||
c.Writer.WriteHeader(resp.StatusCode)
|
||||
_, err = c.Writer.Write(jsonResponse)
|
||||
return nil, &fullTextResponse.Usage
|
||||
}
|
||||
71
relay/common/relay_info.go
Normal file
71
relay/common/relay_info.go
Normal file
@@ -0,0 +1,71 @@
|
||||
package common
|
||||
|
||||
import (
|
||||
"github.com/gin-gonic/gin"
|
||||
"one-api/common"
|
||||
"one-api/relay/constant"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
type RelayInfo struct {
|
||||
ChannelType int
|
||||
ChannelId int
|
||||
TokenId int
|
||||
UserId int
|
||||
Group string
|
||||
TokenUnlimited bool
|
||||
StartTime time.Time
|
||||
ApiType int
|
||||
IsStream bool
|
||||
RelayMode int
|
||||
UpstreamModelName string
|
||||
RequestURLPath string
|
||||
ApiVersion string
|
||||
PromptTokens int
|
||||
ApiKey string
|
||||
BaseUrl string
|
||||
}
|
||||
|
||||
func GenRelayInfo(c *gin.Context) *RelayInfo {
|
||||
channelType := c.GetInt("channel")
|
||||
channelId := c.GetInt("channel_id")
|
||||
tokenId := c.GetInt("token_id")
|
||||
userId := c.GetInt("id")
|
||||
group := c.GetString("group")
|
||||
tokenUnlimited := c.GetBool("token_unlimited_quota")
|
||||
startTime := time.Now()
|
||||
|
||||
apiType := constant.ChannelType2APIType(channelType)
|
||||
|
||||
info := &RelayInfo{
|
||||
RelayMode: constant.Path2RelayMode(c.Request.URL.Path),
|
||||
BaseUrl: c.GetString("base_url"),
|
||||
RequestURLPath: c.Request.URL.String(),
|
||||
ChannelType: channelType,
|
||||
ChannelId: channelId,
|
||||
TokenId: tokenId,
|
||||
UserId: userId,
|
||||
Group: group,
|
||||
TokenUnlimited: tokenUnlimited,
|
||||
StartTime: startTime,
|
||||
ApiType: apiType,
|
||||
ApiVersion: c.GetString("api_version"),
|
||||
ApiKey: strings.TrimPrefix(c.Request.Header.Get("Authorization"), "Bearer "),
|
||||
}
|
||||
if info.BaseUrl == "" {
|
||||
info.BaseUrl = common.ChannelBaseURLs[channelType]
|
||||
}
|
||||
//if info.ChannelType == common.ChannelTypeAzure {
|
||||
// info.ApiVersion = GetAzureAPIVersion(c)
|
||||
//}
|
||||
return info
|
||||
}
|
||||
|
||||
func (info *RelayInfo) SetPromptTokens(promptTokens int) {
|
||||
info.PromptTokens = promptTokens
|
||||
}
|
||||
|
||||
func (info *RelayInfo) SetIsStream(isStream bool) {
|
||||
info.IsStream = isStream
|
||||
}
|
||||
68
relay/common/relay_utils.go
Normal file
68
relay/common/relay_utils.go
Normal file
@@ -0,0 +1,68 @@
|
||||
package common
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
_ "image/gif"
|
||||
_ "image/jpeg"
|
||||
_ "image/png"
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/dto"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
var StopFinishReason = "stop"
|
||||
|
||||
func RelayErrorHandler(resp *http.Response) (openAIErrorWithStatusCode *dto.OpenAIErrorWithStatusCode) {
|
||||
openAIErrorWithStatusCode = &dto.OpenAIErrorWithStatusCode{
|
||||
StatusCode: resp.StatusCode,
|
||||
OpenAIError: dto.OpenAIError{
|
||||
Message: fmt.Sprintf("bad response status code %d", resp.StatusCode),
|
||||
Type: "upstream_error",
|
||||
Code: "bad_response_status_code",
|
||||
Param: strconv.Itoa(resp.StatusCode),
|
||||
},
|
||||
}
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
err = resp.Body.Close()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
var textResponse dto.TextResponse
|
||||
err = json.Unmarshal(responseBody, &textResponse)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
openAIErrorWithStatusCode.OpenAIError = textResponse.Error
|
||||
return
|
||||
}
|
||||
|
||||
func GetFullRequestURL(baseURL string, requestURL string, channelType int) string {
|
||||
fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL)
|
||||
|
||||
if strings.HasPrefix(baseURL, "https://gateway.ai.cloudflare.com") {
|
||||
switch channelType {
|
||||
case common.ChannelTypeOpenAI:
|
||||
fullRequestURL = fmt.Sprintf("%s%s", baseURL, strings.TrimPrefix(requestURL, "/v1"))
|
||||
case common.ChannelTypeAzure:
|
||||
fullRequestURL = fmt.Sprintf("%s%s", baseURL, strings.TrimPrefix(requestURL, "/openai/deployments"))
|
||||
}
|
||||
}
|
||||
return fullRequestURL
|
||||
}
|
||||
|
||||
func GetAPIVersion(c *gin.Context) string {
|
||||
query := c.Request.URL.Query()
|
||||
apiVersion := query.Get("api-version")
|
||||
if apiVersion == "" {
|
||||
apiVersion = c.GetString("api_version")
|
||||
}
|
||||
return apiVersion
|
||||
}
|
||||
45
relay/constant/api_type.go
Normal file
45
relay/constant/api_type.go
Normal file
@@ -0,0 +1,45 @@
|
||||
package constant
|
||||
|
||||
import (
|
||||
"one-api/common"
|
||||
)
|
||||
|
||||
const (
|
||||
APITypeOpenAI = iota
|
||||
APITypeAnthropic
|
||||
APITypePaLM
|
||||
APITypeBaidu
|
||||
APITypeZhipu
|
||||
APITypeAli
|
||||
APITypeXunfei
|
||||
APITypeAIProxyLibrary
|
||||
APITypeTencent
|
||||
APITypeGemini
|
||||
|
||||
APITypeDummy // this one is only for count, do not add any channel after this
|
||||
)
|
||||
|
||||
func ChannelType2APIType(channelType int) int {
|
||||
apiType := APITypeOpenAI
|
||||
switch channelType {
|
||||
case common.ChannelTypeAnthropic:
|
||||
apiType = APITypeAnthropic
|
||||
case common.ChannelTypeBaidu:
|
||||
apiType = APITypeBaidu
|
||||
case common.ChannelTypePaLM:
|
||||
apiType = APITypePaLM
|
||||
case common.ChannelTypeZhipu:
|
||||
apiType = APITypeZhipu
|
||||
case common.ChannelTypeAli:
|
||||
apiType = APITypeAli
|
||||
case common.ChannelTypeXunfei:
|
||||
apiType = APITypeXunfei
|
||||
case common.ChannelTypeAIProxyLibrary:
|
||||
apiType = APITypeAIProxyLibrary
|
||||
case common.ChannelTypeTencent:
|
||||
apiType = APITypeTencent
|
||||
case common.ChannelTypeGemini:
|
||||
apiType = APITypeGemini
|
||||
}
|
||||
return apiType
|
||||
}
|
||||
50
relay/constant/relay_mode.go
Normal file
50
relay/constant/relay_mode.go
Normal file
@@ -0,0 +1,50 @@
|
||||
package constant
|
||||
|
||||
import "strings"
|
||||
|
||||
const (
|
||||
RelayModeUnknown = iota
|
||||
RelayModeChatCompletions
|
||||
RelayModeCompletions
|
||||
RelayModeEmbeddings
|
||||
RelayModeModerations
|
||||
RelayModeImagesGenerations
|
||||
RelayModeEdits
|
||||
RelayModeMidjourneyImagine
|
||||
RelayModeMidjourneyDescribe
|
||||
RelayModeMidjourneyBlend
|
||||
RelayModeMidjourneyChange
|
||||
RelayModeMidjourneySimpleChange
|
||||
RelayModeMidjourneyNotify
|
||||
RelayModeMidjourneyTaskFetch
|
||||
RelayModeMidjourneyTaskFetchByCondition
|
||||
RelayModeAudioSpeech
|
||||
RelayModeAudioTranscription
|
||||
RelayModeAudioTranslation
|
||||
)
|
||||
|
||||
func Path2RelayMode(path string) int {
|
||||
relayMode := RelayModeUnknown
|
||||
if strings.HasPrefix(path, "/v1/chat/completions") {
|
||||
relayMode = RelayModeChatCompletions
|
||||
} else if strings.HasPrefix(path, "/v1/completions") {
|
||||
relayMode = RelayModeCompletions
|
||||
} else if strings.HasPrefix(path, "/v1/embeddings") {
|
||||
relayMode = RelayModeEmbeddings
|
||||
} else if strings.HasSuffix(path, "embeddings") {
|
||||
relayMode = RelayModeEmbeddings
|
||||
} else if strings.HasPrefix(path, "/v1/moderations") {
|
||||
relayMode = RelayModeModerations
|
||||
} else if strings.HasPrefix(path, "/v1/images/generations") {
|
||||
relayMode = RelayModeImagesGenerations
|
||||
} else if strings.HasPrefix(path, "/v1/edits") {
|
||||
relayMode = RelayModeEdits
|
||||
} else if strings.HasPrefix(path, "/v1/audio/speech") {
|
||||
relayMode = RelayModeAudioSpeech
|
||||
} else if strings.HasPrefix(path, "/v1/audio/transcriptions") {
|
||||
relayMode = RelayModeAudioTranscription
|
||||
} else if strings.HasPrefix(path, "/v1/audio/translations") {
|
||||
relayMode = RelayModeAudioTranslation
|
||||
}
|
||||
return relayMode
|
||||
}
|
||||
228
relay/relay-audio.go
Normal file
228
relay/relay-audio.go
Normal file
@@ -0,0 +1,228 @@
|
||||
package relay
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/controller"
|
||||
"one-api/dto"
|
||||
"one-api/model"
|
||||
"one-api/service"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
var availableVoices = []string{
|
||||
"alloy",
|
||||
"echo",
|
||||
"fable",
|
||||
"onyx",
|
||||
"nova",
|
||||
"shimmer",
|
||||
}
|
||||
|
||||
func RelayAudioHelper(c *gin.Context, relayMode int) *controller.OpenAIErrorWithStatusCode {
|
||||
tokenId := c.GetInt("token_id")
|
||||
channelType := c.GetInt("channel")
|
||||
channelId := c.GetInt("channel_id")
|
||||
userId := c.GetInt("id")
|
||||
group := c.GetString("group")
|
||||
startTime := time.Now()
|
||||
|
||||
var audioRequest AudioRequest
|
||||
if !strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") {
|
||||
err := common.UnmarshalBodyReusable(c, &audioRequest)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "bind_request_body_failed", http.StatusBadRequest)
|
||||
}
|
||||
} else {
|
||||
audioRequest = AudioRequest{
|
||||
Model: "whisper-1",
|
||||
}
|
||||
}
|
||||
//err := common.UnmarshalBodyReusable(c, &audioRequest)
|
||||
|
||||
// request validation
|
||||
if audioRequest.Model == "" {
|
||||
return service.OpenAIErrorWrapper(errors.New("model is required"), "required_field_missing", http.StatusBadRequest)
|
||||
}
|
||||
|
||||
if strings.HasPrefix(audioRequest.Model, "tts-1") {
|
||||
if audioRequest.Voice == "" {
|
||||
return service.OpenAIErrorWrapper(errors.New("voice is required"), "required_field_missing", http.StatusBadRequest)
|
||||
}
|
||||
if !common.StringsContains(availableVoices, audioRequest.Voice) {
|
||||
return service.OpenAIErrorWrapper(errors.New("voice must be one of "+strings.Join(availableVoices, ", ")), "invalid_field_value", http.StatusBadRequest)
|
||||
}
|
||||
}
|
||||
|
||||
preConsumedTokens := common.PreConsumedQuota
|
||||
modelRatio := common.GetModelRatio(audioRequest.Model)
|
||||
groupRatio := common.GetGroupRatio(group)
|
||||
ratio := modelRatio * groupRatio
|
||||
preConsumedQuota := int(float64(preConsumedTokens) * ratio)
|
||||
userQuota, err := model.CacheGetUserQuota(userId)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError)
|
||||
}
|
||||
if userQuota-preConsumedQuota < 0 {
|
||||
return service.OpenAIErrorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden)
|
||||
}
|
||||
err = model.CacheDecreaseUserQuota(userId, preConsumedQuota)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "decrease_user_quota_failed", http.StatusInternalServerError)
|
||||
}
|
||||
if userQuota > 100*preConsumedQuota {
|
||||
// in this case, we do not pre-consume quota
|
||||
// because the user has enough quota
|
||||
preConsumedQuota = 0
|
||||
}
|
||||
if preConsumedQuota > 0 {
|
||||
userQuota, err = model.PreConsumeTokenQuota(tokenId, preConsumedQuota)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "pre_consume_token_quota_failed", http.StatusForbidden)
|
||||
}
|
||||
}
|
||||
|
||||
// map model name
|
||||
modelMapping := c.GetString("model_mapping")
|
||||
if modelMapping != "" {
|
||||
modelMap := make(map[string]string)
|
||||
err := json.Unmarshal([]byte(modelMapping), &modelMap)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError)
|
||||
}
|
||||
if modelMap[audioRequest.Model] != "" {
|
||||
audioRequest.Model = modelMap[audioRequest.Model]
|
||||
}
|
||||
}
|
||||
|
||||
baseURL := common.ChannelBaseURLs[channelType]
|
||||
requestURL := c.Request.URL.String()
|
||||
if c.GetString("base_url") != "" {
|
||||
baseURL = c.GetString("base_url")
|
||||
}
|
||||
|
||||
fullRequestURL := common.getFullRequestURL(baseURL, requestURL, channelType)
|
||||
if relayMode == RelayModeAudioTranscription && channelType == common.ChannelTypeAzure {
|
||||
// https://learn.microsoft.com/en-us/azure/ai-services/openai/whisper-quickstart?tabs=command-line#rest-api
|
||||
apiVersion := common.GetAPIVersion(c)
|
||||
fullRequestURL = fmt.Sprintf("%s/openai/deployments/%s/audio/transcriptions?api-version=%s", baseURL, audioRequest.Model, apiVersion)
|
||||
}
|
||||
|
||||
requestBody := c.Request.Body
|
||||
|
||||
req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "new_request_failed", http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
if relayMode == RelayModeAudioTranscription && channelType == common.ChannelTypeAzure {
|
||||
// https://learn.microsoft.com/en-us/azure/ai-services/openai/whisper-quickstart?tabs=command-line#rest-api
|
||||
apiKey := c.Request.Header.Get("Authorization")
|
||||
apiKey = strings.TrimPrefix(apiKey, "Bearer ")
|
||||
req.Header.Set("api-key", apiKey)
|
||||
req.ContentLength = c.Request.ContentLength
|
||||
} else {
|
||||
req.Header.Set("Authorization", c.Request.Header.Get("Authorization"))
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
|
||||
req.Header.Set("Accept", c.Request.Header.Get("Accept"))
|
||||
|
||||
resp, err := controller.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
err = req.Body.Close()
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "close_request_body_failed", http.StatusInternalServerError)
|
||||
}
|
||||
err = c.Request.Body.Close()
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "close_request_body_failed", http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return common.relayErrorHandler(resp)
|
||||
}
|
||||
|
||||
var audioResponse dto.AudioResponse
|
||||
|
||||
defer func(ctx context.Context) {
|
||||
go func() {
|
||||
useTimeSeconds := time.Now().Unix() - startTime.Unix()
|
||||
quota := 0
|
||||
var promptTokens = 0
|
||||
if strings.HasPrefix(audioRequest.Model, "tts-1") {
|
||||
quota = service.countAudioToken(audioRequest.Input, audioRequest.Model)
|
||||
promptTokens = quota
|
||||
} else {
|
||||
quota = service.countAudioToken(audioResponse.Text, audioRequest.Model)
|
||||
}
|
||||
quota = int(float64(quota) * ratio)
|
||||
if ratio != 0 && quota <= 0 {
|
||||
quota = 1
|
||||
}
|
||||
quotaDelta := quota - preConsumedQuota
|
||||
err := model.PostConsumeTokenQuota(tokenId, userQuota, quotaDelta, preConsumedQuota, true)
|
||||
if err != nil {
|
||||
common.SysError("error consuming token remain quota: " + err.Error())
|
||||
}
|
||||
err = model.CacheUpdateUserQuota(userId)
|
||||
if err != nil {
|
||||
common.SysError("error update user quota cache: " + err.Error())
|
||||
}
|
||||
if quota != 0 {
|
||||
tokenName := c.GetString("token_name")
|
||||
logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
|
||||
model.RecordConsumeLog(ctx, userId, channelId, promptTokens, 0, audioRequest.Model, tokenName, quota, logContent, tokenId, userQuota, int(useTimeSeconds), false)
|
||||
model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
|
||||
channelId := c.GetInt("channel_id")
|
||||
model.UpdateChannelUsedQuota(channelId, quota)
|
||||
}
|
||||
}()
|
||||
}(c.Request.Context())
|
||||
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
|
||||
}
|
||||
err = resp.Body.Close()
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError)
|
||||
}
|
||||
if strings.HasPrefix(audioRequest.Model, "tts-1") {
|
||||
|
||||
} else {
|
||||
err = json.Unmarshal(responseBody, &audioResponse)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError)
|
||||
}
|
||||
}
|
||||
|
||||
resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
|
||||
|
||||
for k, v := range resp.Header {
|
||||
c.Writer.Header().Set(k, v[0])
|
||||
}
|
||||
c.Writer.WriteHeader(resp.StatusCode)
|
||||
|
||||
_, err = io.Copy(c.Writer, resp.Body)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError)
|
||||
}
|
||||
err = resp.Body.Close()
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
233
relay/relay-image.go
Normal file
233
relay/relay-image.go
Normal file
@@ -0,0 +1,233 @@
|
||||
package relay
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/controller"
|
||||
"one-api/dto"
|
||||
"one-api/model"
|
||||
"one-api/relay/common"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
func RelayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
|
||||
tokenId := c.GetInt("token_id")
|
||||
channelType := c.GetInt("channel")
|
||||
channelId := c.GetInt("channel_id")
|
||||
userId := c.GetInt("id")
|
||||
consumeQuota := c.GetBool("consume_quota")
|
||||
group := c.GetString("group")
|
||||
startTime := time.Now()
|
||||
|
||||
var imageRequest dto.ImageRequest
|
||||
if consumeQuota {
|
||||
err := common.UnmarshalBodyReusable(c, &imageRequest)
|
||||
if err != nil {
|
||||
return errorWrapper(err, "bind_request_body_failed", http.StatusBadRequest)
|
||||
}
|
||||
}
|
||||
|
||||
if imageRequest.Model == "" {
|
||||
imageRequest.Model = "dall-e-2"
|
||||
}
|
||||
if imageRequest.Size == "" {
|
||||
imageRequest.Size = "1024x1024"
|
||||
}
|
||||
if imageRequest.N == 0 {
|
||||
imageRequest.N = 1
|
||||
}
|
||||
// Prompt validation
|
||||
if imageRequest.Prompt == "" {
|
||||
return errorWrapper(errors.New("prompt is required"), "required_field_missing", http.StatusBadRequest)
|
||||
}
|
||||
|
||||
if strings.Contains(imageRequest.Size, "×") {
|
||||
return errorWrapper(errors.New("size an unexpected error occurred in the parameter, please use 'x' instead of the multiplication sign '×'"), "invalid_field_value", http.StatusBadRequest)
|
||||
}
|
||||
// Not "256x256", "512x512", or "1024x1024"
|
||||
if imageRequest.Model == "dall-e-2" || imageRequest.Model == "dall-e" {
|
||||
if imageRequest.Size != "" && imageRequest.Size != "256x256" && imageRequest.Size != "512x512" && imageRequest.Size != "1024x1024" {
|
||||
return errorWrapper(errors.New("size must be one of 256x256, 512x512, or 1024x1024, dall-e-3 1024x1792 or 1792x1024"), "invalid_field_value", http.StatusBadRequest)
|
||||
}
|
||||
} else if imageRequest.Model == "dall-e-3" {
|
||||
if imageRequest.Size != "" && imageRequest.Size != "1024x1024" && imageRequest.Size != "1024x1792" && imageRequest.Size != "1792x1024" {
|
||||
return errorWrapper(errors.New("size must be one of 256x256, 512x512, or 1024x1024, dall-e-3 1024x1792 or 1792x1024"), "invalid_field_value", http.StatusBadRequest)
|
||||
}
|
||||
if imageRequest.N != 1 {
|
||||
return errorWrapper(errors.New("n must be 1"), "invalid_field_value", http.StatusBadRequest)
|
||||
}
|
||||
}
|
||||
|
||||
// N should between 1 and 10
|
||||
if imageRequest.N != 0 && (imageRequest.N < 1 || imageRequest.N > 10) {
|
||||
return errorWrapper(errors.New("n must be between 1 and 10"), "invalid_field_value", http.StatusBadRequest)
|
||||
}
|
||||
|
||||
// map model name
|
||||
modelMapping := c.GetString("model_mapping")
|
||||
isModelMapped := false
|
||||
if modelMapping != "" {
|
||||
modelMap := make(map[string]string)
|
||||
err := json.Unmarshal([]byte(modelMapping), &modelMap)
|
||||
if err != nil {
|
||||
return errorWrapper(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError)
|
||||
}
|
||||
if modelMap[imageRequest.Model] != "" {
|
||||
imageRequest.Model = modelMap[imageRequest.Model]
|
||||
isModelMapped = true
|
||||
}
|
||||
}
|
||||
baseURL := common.ChannelBaseURLs[channelType]
|
||||
requestURL := c.Request.URL.String()
|
||||
if c.GetString("base_url") != "" {
|
||||
baseURL = c.GetString("base_url")
|
||||
}
|
||||
fullRequestURL := getFullRequestURL(baseURL, requestURL, channelType)
|
||||
if channelType == common.ChannelTypeAzure && relayMode == RelayModeImagesGenerations {
|
||||
// https://learn.microsoft.com/en-us/azure/ai-services/openai/dall-e-quickstart?tabs=dalle3%2Ccommand-line&pivots=rest-api
|
||||
apiVersion := common.GetAPIVersion(c)
|
||||
// https://{resource_name}.openai.azure.com/openai/deployments/dall-e-3/images/generations?api-version=2023-06-01-preview
|
||||
fullRequestURL = fmt.Sprintf("%s/openai/deployments/%s/images/generations?api-version=%s", baseURL, imageRequest.Model, apiVersion)
|
||||
}
|
||||
var requestBody io.Reader
|
||||
if isModelMapped || channelType == common.ChannelTypeAzure { // make Azure channel request body
|
||||
jsonStr, err := json.Marshal(imageRequest)
|
||||
if err != nil {
|
||||
return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
|
||||
}
|
||||
requestBody = bytes.NewBuffer(jsonStr)
|
||||
} else {
|
||||
requestBody = c.Request.Body
|
||||
}
|
||||
|
||||
modelRatio := common.GetModelRatio(imageRequest.Model)
|
||||
groupRatio := common.GetGroupRatio(group)
|
||||
ratio := modelRatio * groupRatio
|
||||
userQuota, err := model.CacheGetUserQuota(userId)
|
||||
|
||||
sizeRatio := 1.0
|
||||
// Size
|
||||
if imageRequest.Size == "256x256" {
|
||||
sizeRatio = 1
|
||||
} else if imageRequest.Size == "512x512" {
|
||||
sizeRatio = 1.125
|
||||
} else if imageRequest.Size == "1024x1024" {
|
||||
sizeRatio = 1.25
|
||||
} else if imageRequest.Size == "1024x1792" || imageRequest.Size == "1792x1024" {
|
||||
sizeRatio = 2.5
|
||||
}
|
||||
|
||||
qualityRatio := 1.0
|
||||
if imageRequest.Model == "dall-e-3" && imageRequest.Quality == "hd" {
|
||||
qualityRatio = 2.0
|
||||
if imageRequest.Size == "1024×1792" || imageRequest.Size == "1792×1024" {
|
||||
qualityRatio = 1.5
|
||||
}
|
||||
}
|
||||
|
||||
quota := int(ratio*sizeRatio*qualityRatio*1000) * imageRequest.N
|
||||
|
||||
if consumeQuota && userQuota-quota < 0 {
|
||||
return errorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden)
|
||||
}
|
||||
|
||||
req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
|
||||
if err != nil {
|
||||
return errorWrapper(err, "new_request_failed", http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
token := c.Request.Header.Get("Authorization")
|
||||
if channelType == common.ChannelTypeAzure { // Azure authentication
|
||||
token = strings.TrimPrefix(token, "Bearer ")
|
||||
req.Header.Set("api-key", token)
|
||||
} else {
|
||||
req.Header.Set("Authorization", token)
|
||||
}
|
||||
req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
|
||||
req.Header.Set("Accept", c.Request.Header.Get("Accept"))
|
||||
|
||||
resp, err := controller.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return errorWrapper(err, "do_request_failed", http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
err = req.Body.Close()
|
||||
if err != nil {
|
||||
return errorWrapper(err, "close_request_body_failed", http.StatusInternalServerError)
|
||||
}
|
||||
err = c.Request.Body.Close()
|
||||
if err != nil {
|
||||
return errorWrapper(err, "close_request_body_failed", http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return relayErrorHandler(resp)
|
||||
}
|
||||
|
||||
var textResponse ImageResponse
|
||||
defer func(ctx context.Context) {
|
||||
useTimeSeconds := time.Now().Unix() - startTime.Unix()
|
||||
if consumeQuota {
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return
|
||||
}
|
||||
err := model.PostConsumeTokenQuota(tokenId, userQuota, quota, 0, true)
|
||||
if err != nil {
|
||||
common.SysError("error consuming token remain quota: " + err.Error())
|
||||
}
|
||||
err = model.CacheUpdateUserQuota(userId)
|
||||
if err != nil {
|
||||
common.SysError("error update user quota cache: " + err.Error())
|
||||
}
|
||||
if quota != 0 {
|
||||
tokenName := c.GetString("token_name")
|
||||
logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
|
||||
model.RecordConsumeLog(ctx, userId, channelId, 0, 0, imageRequest.Model, tokenName, quota, logContent, tokenId, userQuota, int(useTimeSeconds), false)
|
||||
model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
|
||||
channelId := c.GetInt("channel_id")
|
||||
model.UpdateChannelUsedQuota(channelId, quota)
|
||||
}
|
||||
}
|
||||
}(c.Request.Context())
|
||||
|
||||
if consumeQuota {
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
|
||||
if err != nil {
|
||||
return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
|
||||
}
|
||||
err = resp.Body.Close()
|
||||
if err != nil {
|
||||
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError)
|
||||
}
|
||||
err = json.Unmarshal(responseBody, &textResponse)
|
||||
if err != nil {
|
||||
return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
|
||||
}
|
||||
|
||||
for k, v := range resp.Header {
|
||||
c.Writer.Header().Set(k, v[0])
|
||||
}
|
||||
c.Writer.WriteHeader(resp.StatusCode)
|
||||
|
||||
_, err = io.Copy(c.Writer, resp.Body)
|
||||
if err != nil {
|
||||
return errorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError)
|
||||
}
|
||||
err = resp.Body.Close()
|
||||
if err != nil {
|
||||
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
651
relay/relay-mj.go
Normal file
651
relay/relay-mj.go
Normal file
@@ -0,0 +1,651 @@
|
||||
package relay
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/controller"
|
||||
"one-api/model"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type Midjourney struct {
|
||||
MjId string `json:"id"`
|
||||
Action string `json:"action"`
|
||||
Prompt string `json:"prompt"`
|
||||
PromptEn string `json:"promptEn"`
|
||||
Description string `json:"description"`
|
||||
State string `json:"state"`
|
||||
SubmitTime int64 `json:"submitTime"`
|
||||
StartTime int64 `json:"startTime"`
|
||||
FinishTime int64 `json:"finishTime"`
|
||||
ImageUrl string `json:"imageUrl"`
|
||||
Status string `json:"status"`
|
||||
Progress string `json:"progress"`
|
||||
FailReason string `json:"failReason"`
|
||||
}
|
||||
|
||||
type MidjourneyStatus struct {
|
||||
Status int `json:"status"`
|
||||
}
|
||||
type MidjourneyWithoutStatus struct {
|
||||
Id int `json:"id"`
|
||||
Code int `json:"code"`
|
||||
UserId int `json:"user_id" gorm:"index"`
|
||||
Action string `json:"action"`
|
||||
MjId string `json:"mj_id" gorm:"index"`
|
||||
Prompt string `json:"prompt"`
|
||||
PromptEn string `json:"prompt_en"`
|
||||
Description string `json:"description"`
|
||||
State string `json:"state"`
|
||||
SubmitTime int64 `json:"submit_time"`
|
||||
StartTime int64 `json:"start_time"`
|
||||
FinishTime int64 `json:"finish_time"`
|
||||
ImageUrl string `json:"image_url"`
|
||||
Progress string `json:"progress"`
|
||||
FailReason string `json:"fail_reason"`
|
||||
ChannelId int `json:"channel_id"`
|
||||
}
|
||||
|
||||
var DefaultModelPrice = map[string]float64{
|
||||
"mj_imagine": 0.1,
|
||||
"mj_variation": 0.1,
|
||||
"mj_reroll": 0.1,
|
||||
"mj_blend": 0.1,
|
||||
"mj_describe": 0.05,
|
||||
"mj_upscale": 0.05,
|
||||
}
|
||||
|
||||
func RelayMidjourneyImage(c *gin.Context) {
|
||||
taskId := c.Param("id")
|
||||
midjourneyTask := model.GetByOnlyMJId(taskId)
|
||||
if midjourneyTask == nil {
|
||||
c.JSON(400, gin.H{
|
||||
"error": "midjourney_task_not_found",
|
||||
})
|
||||
return
|
||||
}
|
||||
resp, err := http.Get(midjourneyTask.ImageUrl)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"error": "http_get_image_failed",
|
||||
})
|
||||
return
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
responseBody, _ := io.ReadAll(resp.Body)
|
||||
c.JSON(resp.StatusCode, gin.H{
|
||||
"error": string(responseBody),
|
||||
})
|
||||
return
|
||||
}
|
||||
// 从Content-Type头获取MIME类型
|
||||
contentType := resp.Header.Get("Content-Type")
|
||||
if contentType == "" {
|
||||
// 如果无法确定内容类型,则默认为jpeg
|
||||
contentType = "image/jpeg"
|
||||
}
|
||||
// 设置响应的内容类型
|
||||
c.Writer.Header().Set("Content-Type", contentType)
|
||||
// 将图片流式传输到响应体
|
||||
_, err = io.Copy(c.Writer, resp.Body)
|
||||
if err != nil {
|
||||
log.Println("Failed to stream image:", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func RelayMidjourneyNotify(c *gin.Context) *MidjourneyResponse {
|
||||
var midjRequest Midjourney
|
||||
err := common.UnmarshalBodyReusable(c, &midjRequest)
|
||||
if err != nil {
|
||||
return &MidjourneyResponse{
|
||||
Code: 4,
|
||||
Description: "bind_request_body_failed",
|
||||
Properties: nil,
|
||||
Result: "",
|
||||
}
|
||||
}
|
||||
midjourneyTask := model.GetByOnlyMJId(midjRequest.MjId)
|
||||
if midjourneyTask == nil {
|
||||
return &MidjourneyResponse{
|
||||
Code: 4,
|
||||
Description: "midjourney_task_not_found",
|
||||
Properties: nil,
|
||||
Result: "",
|
||||
}
|
||||
}
|
||||
midjourneyTask.Progress = midjRequest.Progress
|
||||
midjourneyTask.PromptEn = midjRequest.PromptEn
|
||||
midjourneyTask.State = midjRequest.State
|
||||
midjourneyTask.SubmitTime = midjRequest.SubmitTime
|
||||
midjourneyTask.StartTime = midjRequest.StartTime
|
||||
midjourneyTask.FinishTime = midjRequest.FinishTime
|
||||
midjourneyTask.ImageUrl = midjRequest.ImageUrl
|
||||
midjourneyTask.Status = midjRequest.Status
|
||||
midjourneyTask.FailReason = midjRequest.FailReason
|
||||
err = midjourneyTask.Update()
|
||||
if err != nil {
|
||||
return &MidjourneyResponse{
|
||||
Code: 4,
|
||||
Description: "update_midjourney_task_failed",
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func getMidjourneyTaskModel(c *gin.Context, originTask *model.Midjourney) (midjourneyTask Midjourney) {
|
||||
midjourneyTask.MjId = originTask.MjId
|
||||
midjourneyTask.Progress = originTask.Progress
|
||||
midjourneyTask.PromptEn = originTask.PromptEn
|
||||
midjourneyTask.State = originTask.State
|
||||
midjourneyTask.SubmitTime = originTask.SubmitTime
|
||||
midjourneyTask.StartTime = originTask.StartTime
|
||||
midjourneyTask.FinishTime = originTask.FinishTime
|
||||
midjourneyTask.ImageUrl = ""
|
||||
if originTask.ImageUrl != "" {
|
||||
midjourneyTask.ImageUrl = common.ServerAddress + "/mj/image/" + originTask.MjId
|
||||
if originTask.Status != "SUCCESS" {
|
||||
midjourneyTask.ImageUrl += "?rand=" + strconv.FormatInt(time.Now().UnixNano(), 10)
|
||||
}
|
||||
}
|
||||
midjourneyTask.Status = originTask.Status
|
||||
midjourneyTask.FailReason = originTask.FailReason
|
||||
midjourneyTask.Action = originTask.Action
|
||||
midjourneyTask.Description = originTask.Description
|
||||
midjourneyTask.Prompt = originTask.Prompt
|
||||
return
|
||||
}
|
||||
|
||||
func RelayMidjourneyTask(c *gin.Context, relayMode int) *MidjourneyResponse {
|
||||
userId := c.GetInt("id")
|
||||
var err error
|
||||
var respBody []byte
|
||||
switch relayMode {
|
||||
case RelayModeMidjourneyTaskFetch:
|
||||
taskId := c.Param("id")
|
||||
originTask := model.GetByMJId(userId, taskId)
|
||||
if originTask == nil {
|
||||
return &MidjourneyResponse{
|
||||
Code: 4,
|
||||
Description: "task_no_found",
|
||||
}
|
||||
}
|
||||
midjourneyTask := getMidjourneyTaskModel(c, originTask)
|
||||
respBody, err = json.Marshal(midjourneyTask)
|
||||
if err != nil {
|
||||
return &MidjourneyResponse{
|
||||
Code: 4,
|
||||
Description: "unmarshal_response_body_failed",
|
||||
}
|
||||
}
|
||||
case RelayModeMidjourneyTaskFetchByCondition:
|
||||
var condition = struct {
|
||||
IDs []string `json:"ids"`
|
||||
}{}
|
||||
err = c.BindJSON(&condition)
|
||||
if err != nil {
|
||||
return &MidjourneyResponse{
|
||||
Code: 4,
|
||||
Description: "do_request_failed",
|
||||
}
|
||||
}
|
||||
var tasks []Midjourney
|
||||
if len(condition.IDs) != 0 {
|
||||
originTasks := model.GetByMJIds(userId, condition.IDs)
|
||||
for _, originTask := range originTasks {
|
||||
midjourneyTask := getMidjourneyTaskModel(c, originTask)
|
||||
tasks = append(tasks, midjourneyTask)
|
||||
}
|
||||
}
|
||||
if tasks == nil {
|
||||
tasks = make([]Midjourney, 0)
|
||||
}
|
||||
respBody, err = json.Marshal(tasks)
|
||||
if err != nil {
|
||||
return &MidjourneyResponse{
|
||||
Code: 4,
|
||||
Description: "unmarshal_response_body_failed",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
c.Writer.Header().Set("Content-Type", "application/json")
|
||||
|
||||
_, err = io.Copy(c.Writer, bytes.NewBuffer(respBody))
|
||||
if err != nil {
|
||||
return &MidjourneyResponse{
|
||||
Code: 4,
|
||||
Description: "copy_response_body_failed",
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
const (
|
||||
// type 1 根据 mode 价格不同
|
||||
MJSubmitActionImagine = "IMAGINE"
|
||||
MJSubmitActionVariation = "VARIATION" //变换
|
||||
MJSubmitActionBlend = "BLEND" //混图
|
||||
|
||||
MJSubmitActionReroll = "REROLL" //重新生成
|
||||
// type 2 固定价格
|
||||
MJSubmitActionDescribe = "DESCRIBE"
|
||||
MJSubmitActionUpscale = "UPSCALE" // 放大
|
||||
)
|
||||
|
||||
func RelayMidjourneySubmit(c *gin.Context, relayMode int) *MidjourneyResponse {
|
||||
imageModel := "midjourney"
|
||||
|
||||
tokenId := c.GetInt("token_id")
|
||||
channelType := c.GetInt("channel")
|
||||
userId := c.GetInt("id")
|
||||
consumeQuota := c.GetBool("consume_quota")
|
||||
group := c.GetString("group")
|
||||
channelId := c.GetInt("channel_id")
|
||||
var midjRequest MidjourneyRequest
|
||||
if consumeQuota {
|
||||
err := common.UnmarshalBodyReusable(c, &midjRequest)
|
||||
if err != nil {
|
||||
return &MidjourneyResponse{
|
||||
Code: 4,
|
||||
Description: "bind_request_body_failed",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if relayMode == RelayModeMidjourneyImagine { //绘画任务,此类任务可重复
|
||||
if midjRequest.Prompt == "" {
|
||||
return &MidjourneyResponse{
|
||||
Code: 4,
|
||||
Description: "prompt_is_required",
|
||||
}
|
||||
}
|
||||
midjRequest.Action = "IMAGINE"
|
||||
} else if relayMode == RelayModeMidjourneyDescribe { //按图生文任务,此类任务可重复
|
||||
midjRequest.Action = "DESCRIBE"
|
||||
} else if relayMode == RelayModeMidjourneyBlend { //绘画任务,此类任务可重复
|
||||
midjRequest.Action = "BLEND"
|
||||
} else if midjRequest.TaskId != "" { //放大、变换任务,此类任务,如果重复且已有结果,远端api会直接返回最终结果
|
||||
mjId := ""
|
||||
if relayMode == RelayModeMidjourneyChange {
|
||||
if midjRequest.TaskId == "" {
|
||||
return &MidjourneyResponse{
|
||||
Code: 4,
|
||||
Description: "taskId_is_required",
|
||||
}
|
||||
} else if midjRequest.Action == "" {
|
||||
return &MidjourneyResponse{
|
||||
Code: 4,
|
||||
Description: "action_is_required",
|
||||
}
|
||||
} else if midjRequest.Index == 0 {
|
||||
return &MidjourneyResponse{
|
||||
Code: 4,
|
||||
Description: "index_can_only_be_1_2_3_4",
|
||||
}
|
||||
}
|
||||
//action = midjRequest.Action
|
||||
mjId = midjRequest.TaskId
|
||||
} else if relayMode == RelayModeMidjourneySimpleChange {
|
||||
if midjRequest.Content == "" {
|
||||
return &MidjourneyResponse{
|
||||
Code: 4,
|
||||
Description: "content_is_required",
|
||||
}
|
||||
}
|
||||
params := convertSimpleChangeParams(midjRequest.Content)
|
||||
if params == nil {
|
||||
return &MidjourneyResponse{
|
||||
Code: 4,
|
||||
Description: "content_parse_failed",
|
||||
}
|
||||
}
|
||||
mjId = params.ID
|
||||
midjRequest.Action = params.Action
|
||||
}
|
||||
|
||||
originTask := model.GetByMJId(userId, mjId)
|
||||
if originTask == nil {
|
||||
return &MidjourneyResponse{
|
||||
Code: 4,
|
||||
Description: "task_no_found",
|
||||
}
|
||||
} else if originTask.Action == "UPSCALE" {
|
||||
//return errorWrapper(errors.New("upscale task can not be change"), "request_params_error", http.StatusBadRequest).
|
||||
return &MidjourneyResponse{
|
||||
Code: 4,
|
||||
Description: "upscale_task_can_not_be_change",
|
||||
}
|
||||
} else if originTask.Status != "SUCCESS" {
|
||||
return &MidjourneyResponse{
|
||||
Code: 4,
|
||||
Description: "task_status_is_not_success",
|
||||
}
|
||||
} else { //原任务的Status=SUCCESS,则可以做放大UPSCALE、变换VARIATION等动作,此时必须使用原来的请求地址才能正确处理
|
||||
channel, err := model.GetChannelById(originTask.ChannelId, false)
|
||||
if err != nil {
|
||||
return &MidjourneyResponse{
|
||||
Code: 4,
|
||||
Description: "channel_not_found",
|
||||
}
|
||||
}
|
||||
c.Set("base_url", channel.GetBaseURL())
|
||||
c.Set("channel_id", originTask.ChannelId)
|
||||
log.Printf("检测到此操作为放大、变换,获取原channel信息: %s,%s", strconv.Itoa(originTask.ChannelId), channel.GetBaseURL())
|
||||
}
|
||||
midjRequest.Prompt = originTask.Prompt
|
||||
}
|
||||
|
||||
// map model name
|
||||
modelMapping := c.GetString("model_mapping")
|
||||
isModelMapped := false
|
||||
if modelMapping != "" {
|
||||
modelMap := make(map[string]string)
|
||||
err := json.Unmarshal([]byte(modelMapping), &modelMap)
|
||||
if err != nil {
|
||||
//return errorWrapper(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError)
|
||||
return &MidjourneyResponse{
|
||||
Code: 4,
|
||||
Description: "unmarshal_model_mapping_failed",
|
||||
}
|
||||
}
|
||||
if modelMap[imageModel] != "" {
|
||||
imageModel = modelMap[imageModel]
|
||||
isModelMapped = true
|
||||
}
|
||||
}
|
||||
|
||||
baseURL := common.ChannelBaseURLs[channelType]
|
||||
requestURL := c.Request.URL.String()
|
||||
|
||||
if c.GetString("base_url") != "" {
|
||||
baseURL = c.GetString("base_url")
|
||||
}
|
||||
|
||||
//midjRequest.NotifyHook = "http://127.0.0.1:3000/mj/notify"
|
||||
|
||||
fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL)
|
||||
log.Printf("fullRequestURL: %s", fullRequestURL)
|
||||
|
||||
var requestBody io.Reader
|
||||
if isModelMapped {
|
||||
jsonStr, err := json.Marshal(midjRequest)
|
||||
if err != nil {
|
||||
return &MidjourneyResponse{
|
||||
Code: 4,
|
||||
Description: "marshal_text_request_failed",
|
||||
}
|
||||
}
|
||||
requestBody = bytes.NewBuffer(jsonStr)
|
||||
} else {
|
||||
requestBody = c.Request.Body
|
||||
}
|
||||
mjAction := "mj_" + strings.ToLower(midjRequest.Action)
|
||||
modelPrice := common.GetModelPrice(mjAction, true)
|
||||
// 如果没有配置价格,则使用默认价格
|
||||
if modelPrice == -1 {
|
||||
defaultPrice, ok := DefaultModelPrice[mjAction]
|
||||
if !ok {
|
||||
modelPrice = 0.1
|
||||
} else {
|
||||
modelPrice = defaultPrice
|
||||
}
|
||||
}
|
||||
groupRatio := common.GetGroupRatio(group)
|
||||
ratio := modelPrice * groupRatio
|
||||
userQuota, err := model.CacheGetUserQuota(userId)
|
||||
if err != nil {
|
||||
return &MidjourneyResponse{
|
||||
Code: 4,
|
||||
Description: err.Error(),
|
||||
}
|
||||
}
|
||||
quota := int(ratio * common.QuotaPerUnit)
|
||||
|
||||
if consumeQuota && userQuota-quota < 0 {
|
||||
return &MidjourneyResponse{
|
||||
Code: 4,
|
||||
Description: "quota_not_enough",
|
||||
}
|
||||
}
|
||||
|
||||
req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
|
||||
if err != nil {
|
||||
return &MidjourneyResponse{
|
||||
Code: 4,
|
||||
Description: "create_request_failed",
|
||||
}
|
||||
}
|
||||
//req.Header.Set("ApiKey", c.Request.Header.Get("ApiKey"))
|
||||
|
||||
req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
|
||||
req.Header.Set("Accept", c.Request.Header.Get("Accept"))
|
||||
//mjToken := ""
|
||||
//if c.Request.Header.Get("ApiKey") != "" {
|
||||
// mjToken = strings.Split(c.Request.Header.Get("ApiKey"), " ")[1]
|
||||
//}
|
||||
//req.Header.Set("ApiKey", "Bearer midjourney-proxy")
|
||||
req.Header.Set("mj-api-secret", strings.Split(c.Request.Header.Get("Authorization"), " ")[1])
|
||||
// print request header
|
||||
log.Printf("request header: %s", req.Header)
|
||||
log.Printf("request body: %s", midjRequest.Prompt)
|
||||
|
||||
resp, err := controller.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return &MidjourneyResponse{
|
||||
Code: 4,
|
||||
Description: "do_request_failed",
|
||||
}
|
||||
}
|
||||
|
||||
err = req.Body.Close()
|
||||
if err != nil {
|
||||
return &MidjourneyResponse{
|
||||
Code: 4,
|
||||
Description: "close_request_body_failed",
|
||||
}
|
||||
}
|
||||
err = c.Request.Body.Close()
|
||||
if err != nil {
|
||||
return &MidjourneyResponse{
|
||||
Code: 4,
|
||||
Description: "close_request_body_failed",
|
||||
}
|
||||
}
|
||||
var midjResponse MidjourneyResponse
|
||||
|
||||
defer func(ctx context.Context) {
|
||||
if consumeQuota {
|
||||
err := model.PostConsumeTokenQuota(tokenId, userQuota, quota, 0, true)
|
||||
if err != nil {
|
||||
common.SysError("error consuming token remain quota: " + err.Error())
|
||||
}
|
||||
err = model.CacheUpdateUserQuota(userId)
|
||||
if err != nil {
|
||||
common.SysError("error update user quota cache: " + err.Error())
|
||||
}
|
||||
if quota != 0 {
|
||||
tokenName := c.GetString("token_name")
|
||||
logContent := fmt.Sprintf("模型固定价格 %.2f,分组倍率 %.2f,操作 %s", modelPrice, groupRatio, midjRequest.Action)
|
||||
model.RecordConsumeLog(ctx, userId, channelId, 0, 0, imageModel, tokenName, quota, logContent, tokenId, userQuota, 0, false)
|
||||
model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
|
||||
channelId := c.GetInt("channel_id")
|
||||
model.UpdateChannelUsedQuota(channelId, quota)
|
||||
}
|
||||
}
|
||||
}(c.Request.Context())
|
||||
|
||||
//if consumeQuota {
|
||||
//
|
||||
//}
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
|
||||
if err != nil {
|
||||
return &MidjourneyResponse{
|
||||
Code: 4,
|
||||
Description: "read_response_body_failed",
|
||||
}
|
||||
}
|
||||
err = resp.Body.Close()
|
||||
if err != nil {
|
||||
return &MidjourneyResponse{
|
||||
Code: 4,
|
||||
Description: "close_response_body_failed",
|
||||
}
|
||||
}
|
||||
|
||||
err = json.Unmarshal(responseBody, &midjResponse)
|
||||
log.Printf("responseBody: %s", string(responseBody))
|
||||
log.Printf("midjResponse: %v", midjResponse)
|
||||
if resp.StatusCode != 200 {
|
||||
return &MidjourneyResponse{
|
||||
Code: 4,
|
||||
Description: "fail_to_fetch_midjourney status_code: " + strconv.Itoa(resp.StatusCode),
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
return &MidjourneyResponse{
|
||||
Code: 4,
|
||||
Description: "unmarshal_response_body_failed",
|
||||
}
|
||||
}
|
||||
|
||||
// 文档:https://github.com/novicezk/midjourney-proxy/blob/main/docs/api.md
|
||||
//1-提交成功
|
||||
// 21-任务已存在(处理中或者有结果了) {"code":21,"description":"任务已存在","result":"0741798445574458","properties":{"status":"SUCCESS","imageUrl":"https://xxxx"}}
|
||||
// 22-排队中 {"code":22,"description":"排队中,前面还有1个任务","result":"0741798445574458","properties":{"numberOfQueues":1,"discordInstanceId":"1118138338562560102"}}
|
||||
// 23-队列已满,请稍后再试 {"code":23,"description":"队列已满,请稍后尝试","result":"14001929738841620","properties":{"discordInstanceId":"1118138338562560102"}}
|
||||
// 24-prompt包含敏感词 {"code":24,"description":"可能包含敏感词","properties":{"promptEn":"nude body","bannedWord":"nude"}}
|
||||
// other: 提交错误,description为错误描述
|
||||
midjourneyTask := &model.Midjourney{
|
||||
UserId: userId,
|
||||
Code: midjResponse.Code,
|
||||
Action: midjRequest.Action,
|
||||
MjId: midjResponse.Result,
|
||||
Prompt: midjRequest.Prompt,
|
||||
PromptEn: "",
|
||||
Description: midjResponse.Description,
|
||||
State: "",
|
||||
SubmitTime: time.Now().UnixNano() / int64(time.Millisecond),
|
||||
StartTime: 0,
|
||||
FinishTime: 0,
|
||||
ImageUrl: "",
|
||||
Status: "",
|
||||
Progress: "0%",
|
||||
FailReason: "",
|
||||
ChannelId: c.GetInt("channel_id"),
|
||||
Quota: quota,
|
||||
}
|
||||
|
||||
if midjResponse.Code != 1 && midjResponse.Code != 21 && midjResponse.Code != 22 {
|
||||
//非1-提交成功,21-任务已存在和22-排队中,则记录错误原因
|
||||
midjourneyTask.FailReason = midjResponse.Description
|
||||
consumeQuota = false
|
||||
}
|
||||
|
||||
if midjResponse.Code == 21 { //21-任务已存在(处理中或者有结果了)
|
||||
// 将 properties 转换为一个 map
|
||||
properties, ok := midjResponse.Properties.(map[string]interface{})
|
||||
if ok {
|
||||
imageUrl, ok1 := properties["imageUrl"].(string)
|
||||
status, ok2 := properties["status"].(string)
|
||||
if ok1 && ok2 {
|
||||
midjourneyTask.ImageUrl = imageUrl
|
||||
midjourneyTask.Status = status
|
||||
if status == "SUCCESS" {
|
||||
midjourneyTask.Progress = "100%"
|
||||
midjourneyTask.StartTime = time.Now().UnixNano() / int64(time.Millisecond)
|
||||
midjourneyTask.FinishTime = time.Now().UnixNano() / int64(time.Millisecond)
|
||||
midjResponse.Code = 1
|
||||
}
|
||||
}
|
||||
}
|
||||
//修改返回值
|
||||
newBody := strings.Replace(string(responseBody), `"code":21`, `"code":1`, -1)
|
||||
responseBody = []byte(newBody)
|
||||
}
|
||||
|
||||
err = midjourneyTask.Insert()
|
||||
if err != nil {
|
||||
return &MidjourneyResponse{
|
||||
Code: 4,
|
||||
Description: "insert_midjourney_task_failed",
|
||||
}
|
||||
}
|
||||
|
||||
if midjResponse.Code == 22 { //22-排队中,说明任务已存在
|
||||
//修改返回值
|
||||
newBody := strings.Replace(string(responseBody), `"code":22`, `"code":1`, -1)
|
||||
responseBody = []byte(newBody)
|
||||
}
|
||||
|
||||
resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
|
||||
|
||||
for k, v := range resp.Header {
|
||||
c.Writer.Header().Set(k, v[0])
|
||||
}
|
||||
c.Writer.WriteHeader(resp.StatusCode)
|
||||
|
||||
_, err = io.Copy(c.Writer, resp.Body)
|
||||
if err != nil {
|
||||
return &MidjourneyResponse{
|
||||
Code: 4,
|
||||
Description: "copy_response_body_failed",
|
||||
}
|
||||
}
|
||||
err = resp.Body.Close()
|
||||
if err != nil {
|
||||
return &MidjourneyResponse{
|
||||
Code: 4,
|
||||
Description: "close_response_body_failed",
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type taskChangeParams struct {
|
||||
ID string
|
||||
Action string
|
||||
Index int
|
||||
}
|
||||
|
||||
func convertSimpleChangeParams(content string) *taskChangeParams {
|
||||
split := strings.Split(content, " ")
|
||||
if len(split) != 2 {
|
||||
return nil
|
||||
}
|
||||
|
||||
action := strings.ToLower(split[1])
|
||||
changeParams := &taskChangeParams{}
|
||||
changeParams.ID = split[0]
|
||||
|
||||
if action[0] == 'u' {
|
||||
changeParams.Action = "UPSCALE"
|
||||
} else if action[0] == 'v' {
|
||||
changeParams.Action = "VARIATION"
|
||||
} else if action == "r" {
|
||||
changeParams.Action = "REROLL"
|
||||
return changeParams
|
||||
} else {
|
||||
return nil
|
||||
}
|
||||
|
||||
index, err := strconv.Atoi(action[1:2])
|
||||
if err != nil || index < 1 || index > 4 {
|
||||
return nil
|
||||
}
|
||||
changeParams.Index = index
|
||||
return changeParams
|
||||
}
|
||||
277
relay/relay-text.go
Normal file
277
relay/relay-text.go
Normal file
@@ -0,0 +1,277 @@
|
||||
package relay
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"math"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/dto"
|
||||
"one-api/model"
|
||||
relaychannel "one-api/relay/channel"
|
||||
relaycommon "one-api/relay/common"
|
||||
relayconstant "one-api/relay/constant"
|
||||
"one-api/service"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func getAndValidateTextRequest(c *gin.Context, relayInfo *relaycommon.RelayInfo) (*dto.GeneralOpenAIRequest, error) {
|
||||
textRequest := &dto.GeneralOpenAIRequest{}
|
||||
err := common.UnmarshalBodyReusable(c, textRequest)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if relayInfo.RelayMode == relayconstant.RelayModeModerations && textRequest.Model == "" {
|
||||
textRequest.Model = "text-moderation-latest"
|
||||
}
|
||||
if relayInfo.RelayMode == relayconstant.RelayModeEmbeddings && textRequest.Model == "" {
|
||||
textRequest.Model = c.Param("model")
|
||||
}
|
||||
|
||||
if textRequest.MaxTokens < 0 || textRequest.MaxTokens > math.MaxInt32/2 {
|
||||
return nil, errors.New("max_tokens is invalid")
|
||||
}
|
||||
if textRequest.Model == "" {
|
||||
return nil, errors.New("model is required")
|
||||
}
|
||||
switch relayInfo.RelayMode {
|
||||
case relayconstant.RelayModeCompletions:
|
||||
if textRequest.Prompt == "" {
|
||||
return nil, errors.New("field prompt is required")
|
||||
}
|
||||
case relayconstant.RelayModeChatCompletions:
|
||||
if textRequest.Messages == nil || len(textRequest.Messages) == 0 {
|
||||
return nil, errors.New("field messages is required")
|
||||
}
|
||||
case relayconstant.RelayModeEmbeddings:
|
||||
case relayconstant.RelayModeModerations:
|
||||
if textRequest.Input == "" {
|
||||
return nil, errors.New("field input is required")
|
||||
}
|
||||
case relayconstant.RelayModeEdits:
|
||||
if textRequest.Instruction == "" {
|
||||
return nil, errors.New("field instruction is required")
|
||||
}
|
||||
}
|
||||
relayInfo.IsStream = textRequest.Stream
|
||||
return textRequest, nil
|
||||
}
|
||||
|
||||
func TextHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
|
||||
|
||||
relayInfo := relaycommon.GenRelayInfo(c)
|
||||
|
||||
// get & validate textRequest 获取并验证文本请求
|
||||
textRequest, err := getAndValidateTextRequest(c, relayInfo)
|
||||
if err != nil {
|
||||
common.LogError(c, fmt.Sprintf("getAndValidateTextRequest failed: %s", err.Error()))
|
||||
return service.OpenAIErrorWrapper(err, "invalid_text_request", http.StatusBadRequest)
|
||||
}
|
||||
|
||||
// map model name
|
||||
modelMapping := c.GetString("model_mapping")
|
||||
isModelMapped := false
|
||||
if modelMapping != "" && modelMapping != "{}" {
|
||||
modelMap := make(map[string]string)
|
||||
err := json.Unmarshal([]byte(modelMapping), &modelMap)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError)
|
||||
}
|
||||
if modelMap[textRequest.Model] != "" {
|
||||
textRequest.Model = modelMap[textRequest.Model]
|
||||
isModelMapped = true
|
||||
}
|
||||
}
|
||||
modelPrice := common.GetModelPrice(textRequest.Model, false)
|
||||
groupRatio := common.GetGroupRatio(relayInfo.Group)
|
||||
|
||||
var preConsumedQuota int
|
||||
var ratio float64
|
||||
var modelRatio float64
|
||||
promptTokens, err := getPromptTokens(textRequest, relayInfo)
|
||||
|
||||
// count messages token error 计算promptTokens错误
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "count_token_messages_failed", http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
if modelPrice == -1 {
|
||||
preConsumedTokens := common.PreConsumedQuota
|
||||
if textRequest.MaxTokens != 0 {
|
||||
preConsumedTokens = promptTokens + int(textRequest.MaxTokens)
|
||||
}
|
||||
modelRatio = common.GetModelRatio(textRequest.Model)
|
||||
ratio = modelRatio * groupRatio
|
||||
preConsumedQuota = int(float64(preConsumedTokens) * ratio)
|
||||
} else {
|
||||
preConsumedQuota = int(modelPrice * common.QuotaPerUnit * groupRatio)
|
||||
}
|
||||
|
||||
// pre-consume quota 预消耗配额
|
||||
userQuota, openaiErr := preConsumeQuota(c, preConsumedQuota, relayInfo)
|
||||
if err != nil {
|
||||
return openaiErr
|
||||
}
|
||||
|
||||
adaptor := relaychannel.GetAdaptor(relayInfo.ApiType)
|
||||
if adaptor == nil {
|
||||
return service.OpenAIErrorWrapper(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), "invalid_api_type", http.StatusBadRequest)
|
||||
}
|
||||
adaptor.Init(relayInfo, *textRequest)
|
||||
var requestBody io.Reader
|
||||
if relayInfo.ApiType == relayconstant.APITypeOpenAI {
|
||||
if isModelMapped {
|
||||
jsonStr, err := json.Marshal(textRequest)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
|
||||
}
|
||||
requestBody = bytes.NewBuffer(jsonStr)
|
||||
} else {
|
||||
requestBody = c.Request.Body
|
||||
}
|
||||
} else {
|
||||
convertedRequest, err := adaptor.ConvertRequest(c, relayInfo.RelayMode, textRequest)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "convert_request_failed", http.StatusInternalServerError)
|
||||
}
|
||||
jsonData, err := json.Marshal(convertedRequest)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "json_marshal_failed", http.StatusInternalServerError)
|
||||
}
|
||||
requestBody = bytes.NewBuffer(jsonData)
|
||||
}
|
||||
|
||||
resp, err := adaptor.DoRequest(c, relayInfo, requestBody)
|
||||
relayInfo.IsStream = relayInfo.IsStream || strings.HasPrefix(resp.Header.Get("Content-Type"), "text/event-stream")
|
||||
|
||||
usage, openaiErr := adaptor.DoResponse(c, resp, relayInfo)
|
||||
if openaiErr != nil {
|
||||
return openaiErr
|
||||
}
|
||||
|
||||
postConsumeQuota(c, relayInfo, *textRequest, usage, ratio, preConsumedQuota, userQuota, modelRatio, groupRatio, modelPrice)
|
||||
return nil
|
||||
}
|
||||
|
||||
func getPromptTokens(textRequest *dto.GeneralOpenAIRequest, info *relaycommon.RelayInfo) (int, error) {
|
||||
var promptTokens int
|
||||
var err error
|
||||
|
||||
switch info.RelayMode {
|
||||
case relayconstant.RelayModeChatCompletions:
|
||||
promptTokens, err = service.CountTokenMessages(textRequest.Messages, textRequest.Model)
|
||||
case relayconstant.RelayModeCompletions:
|
||||
promptTokens, err = service.CountTokenInput(textRequest.Prompt, textRequest.Model), nil
|
||||
case relayconstant.RelayModeModerations:
|
||||
promptTokens, err = service.CountTokenInput(textRequest.Input, textRequest.Model), nil
|
||||
default:
|
||||
err = errors.New("unknown relay mode")
|
||||
promptTokens = 0
|
||||
}
|
||||
info.PromptTokens = promptTokens
|
||||
return promptTokens, err
|
||||
}
|
||||
|
||||
// 预扣费并返回用户剩余配额
|
||||
func preConsumeQuota(c *gin.Context, preConsumedQuota int, relayInfo *relaycommon.RelayInfo) (int, *dto.OpenAIErrorWithStatusCode) {
|
||||
userQuota, err := model.CacheGetUserQuota(relayInfo.UserId)
|
||||
if err != nil {
|
||||
return 0, service.OpenAIErrorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError)
|
||||
}
|
||||
if userQuota < 0 || userQuota-preConsumedQuota < 0 {
|
||||
return 0, service.OpenAIErrorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden)
|
||||
}
|
||||
err = model.CacheDecreaseUserQuota(relayInfo.UserId, preConsumedQuota)
|
||||
if err != nil {
|
||||
return 0, service.OpenAIErrorWrapper(err, "decrease_user_quota_failed", http.StatusInternalServerError)
|
||||
}
|
||||
if userQuota > 100*preConsumedQuota {
|
||||
// 用户额度充足,判断令牌额度是否充足
|
||||
if !relayInfo.TokenUnlimited {
|
||||
// 非无限令牌,判断令牌额度是否充足
|
||||
tokenQuota := c.GetInt("token_quota")
|
||||
if tokenQuota > 100*preConsumedQuota {
|
||||
// 令牌额度充足,信任令牌
|
||||
preConsumedQuota = 0
|
||||
common.LogInfo(c.Request.Context(), fmt.Sprintf("user %d quota %d and token %d quota %d are enough, trusted and no need to pre-consume", relayInfo.UserId, userQuota, relayInfo.TokenId, tokenQuota))
|
||||
}
|
||||
} else {
|
||||
// in this case, we do not pre-consume quota
|
||||
// because the user has enough quota
|
||||
preConsumedQuota = 0
|
||||
common.LogInfo(c.Request.Context(), fmt.Sprintf("user %d with unlimited token has enough quota %d, trusted and no need to pre-consume", relayInfo.UserId, userQuota))
|
||||
}
|
||||
}
|
||||
if preConsumedQuota > 0 {
|
||||
userQuota, err = model.PreConsumeTokenQuota(relayInfo.TokenId, preConsumedQuota)
|
||||
if err != nil {
|
||||
return 0, service.OpenAIErrorWrapper(err, "pre_consume_token_quota_failed", http.StatusForbidden)
|
||||
}
|
||||
}
|
||||
return userQuota, nil
|
||||
}
|
||||
|
||||
func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, textRequest dto.GeneralOpenAIRequest, usage *dto.Usage, ratio float64, preConsumedQuota int, userQuota int, modelRatio float64, groupRatio float64, modelPrice float64) {
|
||||
useTimeSeconds := time.Now().Unix() - relayInfo.StartTime.Unix()
|
||||
promptTokens := usage.PromptTokens
|
||||
completionTokens := usage.CompletionTokens
|
||||
|
||||
tokenName := ctx.GetString("token_name")
|
||||
|
||||
quota := 0
|
||||
if modelPrice == -1 {
|
||||
completionRatio := common.GetCompletionRatio(textRequest.Model)
|
||||
quota = promptTokens + int(float64(completionTokens)*completionRatio)
|
||||
quota = int(float64(quota) * ratio)
|
||||
if ratio != 0 && quota <= 0 {
|
||||
quota = 1
|
||||
}
|
||||
} else {
|
||||
quota = int(modelPrice * common.QuotaPerUnit * groupRatio)
|
||||
}
|
||||
totalTokens := promptTokens + completionTokens
|
||||
var logContent string
|
||||
if modelPrice == -1 {
|
||||
logContent = fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
|
||||
} else {
|
||||
logContent = fmt.Sprintf("模型价格 %.2f,分组倍率 %.2f", modelPrice, groupRatio)
|
||||
}
|
||||
|
||||
// record all the consume log even if quota is 0
|
||||
if totalTokens == 0 {
|
||||
// in this case, must be some error happened
|
||||
// we cannot just return, because we may have to return the pre-consumed quota
|
||||
quota = 0
|
||||
logContent += fmt.Sprintf("(可能是上游超时)")
|
||||
common.LogError(ctx, fmt.Sprintf("total tokens is 0, cannot consume quota, userId %d, channelId %d, tokenId %d, model %s, pre-consumed quota %d", relayInfo.UserId, relayInfo.ChannelId, relayInfo.TokenId, textRequest.Model, preConsumedQuota))
|
||||
} else {
|
||||
quotaDelta := quota - preConsumedQuota
|
||||
err := model.PostConsumeTokenQuota(relayInfo.TokenId, userQuota, quotaDelta, preConsumedQuota, true)
|
||||
if err != nil {
|
||||
common.LogError(ctx, "error consuming token remain quota: "+err.Error())
|
||||
}
|
||||
err = model.CacheUpdateUserQuota(relayInfo.UserId)
|
||||
if err != nil {
|
||||
common.LogError(ctx, "error update user quota cache: "+err.Error())
|
||||
}
|
||||
model.UpdateUserUsedQuotaAndRequestCount(relayInfo.UserId, quota)
|
||||
model.UpdateChannelUsedQuota(relayInfo.ChannelId, quota)
|
||||
}
|
||||
|
||||
logModel := textRequest.Model
|
||||
if strings.HasPrefix(logModel, "gpt-4-gizmo") {
|
||||
logModel = "gpt-4-gizmo-*"
|
||||
logContent += fmt.Sprintf(",模型 %s", textRequest.Model)
|
||||
}
|
||||
model.RecordConsumeLog(ctx, relayInfo.UserId, relayInfo.ChannelId, promptTokens, completionTokens, logModel, tokenName, quota, logContent, relayInfo.TokenId, userQuota, int(useTimeSeconds), relayInfo.IsStream)
|
||||
|
||||
//if quota != 0 {
|
||||
//
|
||||
//}
|
||||
}
|
||||
Reference in New Issue
Block a user