Mirror of https://github.com/songquanpeng/one-api.git, synced 2025-11-09 02:03:42 +08:00
refactor: refactor relay part (#957)
* refactor: refactor relay part
* refactor: refactor config part
relay/channel/aiproxy/adaptor.go (new file, 22 lines)

@@ -0,0 +1,22 @@
+package aiproxy
+
+import (
+    "github.com/gin-gonic/gin"
+    "net/http"
+    "one-api/relay/channel/openai"
+)
+
+type Adaptor struct {
+}
+
+func (a *Adaptor) Auth(c *gin.Context) error {
+    return nil
+}
+
+func (a *Adaptor) ConvertRequest(request *openai.GeneralOpenAIRequest) (any, error) {
+    return nil, nil
+}
+
+func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, *openai.Usage, error) {
+    return nil, nil, nil
+}

@@ -8,6 +8,8 @@ import (
     "io"
     "net/http"
     "one-api/common"
+    "one-api/common/helper"
+    "one-api/common/logger"
     "one-api/relay/channel/openai"
     "one-api/relay/constant"
     "strconv"

@@ -50,9 +52,9 @@ func responseAIProxyLibrary2OpenAI(response *LibraryResponse) *openai.TextResponse {
         FinishReason: "stop",
     }
     fullTextResponse := openai.TextResponse{
-        Id:      common.GetUUID(),
+        Id:      helper.GetUUID(),
         Object:  "chat.completion",
-        Created: common.GetTimestamp(),
+        Created: helper.GetTimestamp(),
         Choices: []openai.TextResponseChoice{choice},
     }
     return &fullTextResponse

@@ -63,9 +65,9 @@ func documentsAIProxyLibrary(documents []LibraryDocument) *openai.ChatCompletionsStreamResponse {
     choice.Delta.Content = aiProxyDocuments2Markdown(documents)
     choice.FinishReason = &constant.StopFinishReason
     return &openai.ChatCompletionsStreamResponse{
-        Id:      common.GetUUID(),
+        Id:      helper.GetUUID(),
         Object:  "chat.completion.chunk",
-        Created: common.GetTimestamp(),
+        Created: helper.GetTimestamp(),
         Model:   "",
         Choices: []openai.ChatCompletionsStreamResponseChoice{choice},
     }

@@ -75,9 +77,9 @@ func streamResponseAIProxyLibrary2OpenAI(response *LibraryStreamResponse) *openai.ChatCompletionsStreamResponse {
     var choice openai.ChatCompletionsStreamResponseChoice
     choice.Delta.Content = response.Content
     return &openai.ChatCompletionsStreamResponse{
-        Id:      common.GetUUID(),
+        Id:      helper.GetUUID(),
         Object:  "chat.completion.chunk",
-        Created: common.GetTimestamp(),
+        Created: helper.GetTimestamp(),
         Model:   response.Model,
         Choices: []openai.ChatCompletionsStreamResponseChoice{choice},
     }

@@ -122,7 +124,7 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode
     var AIProxyLibraryResponse LibraryStreamResponse
     err := json.Unmarshal([]byte(data), &AIProxyLibraryResponse)
     if err != nil {
-        common.SysError("error unmarshalling stream response: " + err.Error())
+        logger.SysError("error unmarshalling stream response: " + err.Error())
         return true
     }
     if len(AIProxyLibraryResponse.Documents) != 0 {

@@ -131,7 +133,7 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode
     response := streamResponseAIProxyLibrary2OpenAI(&AIProxyLibraryResponse)
     jsonResponse, err := json.Marshal(response)
     if err != nil {
-        common.SysError("error marshalling stream response: " + err.Error())
+        logger.SysError("error marshalling stream response: " + err.Error())
         return true
     }
     c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})

@@ -140,7 +142,7 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode
     response := documentsAIProxyLibrary(documents)
     jsonResponse, err := json.Marshal(response)
     if err != nil {
-        common.SysError("error marshalling stream response: " + err.Error())
+        logger.SysError("error marshalling stream response: " + err.Error())
         return true
     }
     c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
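Taken together, these hunks are a mechanical migration: the UUID and timestamp utilities move out of the catch-all one-api/common package into one-api/common/helper, and logging moves into one-api/common/logger, while the SSE rendering (common.CustomEvent) stays behind. A minimal sketch of what the new helper package presumably provides, inferred only from the call sites in this diff (the real implementation may differ, for example by stripping dashes from the UUID):

package helper

import (
    "time"

    "github.com/google/uuid" // assumption: any UUID source with a string form works here
)

// GetUUID returns a fresh identifier; the diff only shows it used to
// build response IDs such as "chatcmpl-<uuid>".
func GetUUID() string {
    return uuid.New().String()
}

// GetTimestamp returns the current Unix time in seconds, matching the
// int64 Created field of the OpenAI-style response structs above.
func GetTimestamp() int64 {
    return time.Now().Unix()
}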
relay/channel/ali/adaptor.go (new file, 22 lines)

@@ -0,0 +1,22 @@
+package ali
+
+import (
+    "github.com/gin-gonic/gin"
+    "net/http"
+    "one-api/relay/channel/openai"
+)
+
+type Adaptor struct {
+}
+
+func (a *Adaptor) Auth(c *gin.Context) error {
+    return nil
+}
+
+func (a *Adaptor) ConvertRequest(request *openai.GeneralOpenAIRequest) (any, error) {
+    return nil, nil
+}
+
+func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, *openai.Usage, error) {
+    return nil, nil, nil
+}

@@ -7,6 +7,8 @@ import (
     "io"
     "net/http"
     "one-api/common"
+    "one-api/common/helper"
+    "one-api/common/logger"
     "one-api/relay/channel/openai"
     "strings"
 )

@@ -118,7 +120,7 @@ func responseAli2OpenAI(response *ChatResponse) *openai.TextResponse {
     fullTextResponse := openai.TextResponse{
         Id:      response.RequestId,
         Object:  "chat.completion",
-        Created: common.GetTimestamp(),
+        Created: helper.GetTimestamp(),
         Choices: []openai.TextResponseChoice{choice},
         Usage: openai.Usage{
             PromptTokens: response.Usage.InputTokens,

@@ -139,7 +141,7 @@ func streamResponseAli2OpenAI(aliResponse *ChatResponse) *openai.ChatCompletionsStreamResponse {
     response := openai.ChatCompletionsStreamResponse{
         Id:      aliResponse.RequestId,
         Object:  "chat.completion.chunk",
-        Created: common.GetTimestamp(),
+        Created: helper.GetTimestamp(),
         Model:   "qwen",
         Choices: []openai.ChatCompletionsStreamResponseChoice{choice},
     }

@@ -185,7 +187,7 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode
     var aliResponse ChatResponse
     err := json.Unmarshal([]byte(data), &aliResponse)
     if err != nil {
-        common.SysError("error unmarshalling stream response: " + err.Error())
+        logger.SysError("error unmarshalling stream response: " + err.Error())
         return true
     }
     if aliResponse.Usage.OutputTokens != 0 {

@@ -198,7 +200,7 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode
     //lastResponseText = aliResponse.Output.Text
     jsonResponse, err := json.Marshal(response)
     if err != nil {
-        common.SysError("error marshalling stream response: " + err.Error())
+        logger.SysError("error marshalling stream response: " + err.Error())
         return true
     }
     c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
relay/channel/anthropic/adaptor.go (new file, 22 lines)

@@ -0,0 +1,22 @@
+package anthropic
+
+import (
+    "github.com/gin-gonic/gin"
+    "net/http"
+    "one-api/relay/channel/openai"
+)
+
+type Adaptor struct {
+}
+
+func (a *Adaptor) Auth(c *gin.Context) error {
+    return nil
+}
+
+func (a *Adaptor) ConvertRequest(request *openai.GeneralOpenAIRequest) (any, error) {
+    return nil, nil
+}
+
+func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, *openai.Usage, error) {
+    return nil, nil, nil
+}

@@ -8,6 +8,8 @@ import (
     "io"
     "net/http"
     "one-api/common"
+    "one-api/common/helper"
+    "one-api/common/logger"
     "one-api/relay/channel/openai"
     "strings"
 )

@@ -78,9 +80,9 @@ func responseClaude2OpenAI(claudeResponse *Response) *openai.TextResponse {
         FinishReason: stopReasonClaude2OpenAI(claudeResponse.StopReason),
     }
     fullTextResponse := openai.TextResponse{
-        Id:      fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
+        Id:      fmt.Sprintf("chatcmpl-%s", helper.GetUUID()),
         Object:  "chat.completion",
-        Created: common.GetTimestamp(),
+        Created: helper.GetTimestamp(),
         Choices: []openai.TextResponseChoice{choice},
     }
     return &fullTextResponse

@@ -88,8 +90,8 @@ func responseClaude2OpenAI(claudeResponse *Response) *openai.TextResponse {
 
 func StreamHandler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, string) {
     responseText := ""
-    responseId := fmt.Sprintf("chatcmpl-%s", common.GetUUID())
-    createdTime := common.GetTimestamp()
+    responseId := fmt.Sprintf("chatcmpl-%s", helper.GetUUID())
+    createdTime := helper.GetTimestamp()
     scanner := bufio.NewScanner(resp.Body)
     scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
         if atEOF && len(data) == 0 {

@@ -125,7 +127,7 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, string) {
     var claudeResponse Response
     err := json.Unmarshal([]byte(data), &claudeResponse)
     if err != nil {
-        common.SysError("error unmarshalling stream response: " + err.Error())
+        logger.SysError("error unmarshalling stream response: " + err.Error())
         return true
     }
     responseText += claudeResponse.Completion

@@ -134,7 +136,7 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, string) {
     response.Created = createdTime
     jsonStr, err := json.Marshal(response)
     if err != nil {
-        common.SysError("error marshalling stream response: " + err.Error())
+        logger.SysError("error marshalling stream response: " + err.Error())
         return true
     }
     c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonStr)})
relay/channel/baidu/adaptor.go (new file, 22 lines)

@@ -0,0 +1,22 @@
+package baidu
+
+import (
+    "github.com/gin-gonic/gin"
+    "net/http"
+    "one-api/relay/channel/openai"
+)
+
+type Adaptor struct {
+}
+
+func (a *Adaptor) Auth(c *gin.Context) error {
+    return nil
+}
+
+func (a *Adaptor) ConvertRequest(request *openai.GeneralOpenAIRequest) (any, error) {
+    return nil, nil
+}
+
+func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, *openai.Usage, error) {
+    return nil, nil, nil
+}

@@ -9,6 +9,7 @@ import (
     "io"
     "net/http"
     "one-api/common"
+    "one-api/common/logger"
     "one-api/relay/channel/openai"
     "one-api/relay/constant"
     "one-api/relay/util"

@@ -19,49 +20,49 @@ import (
 
 // https://cloud.baidu.com/doc/WENXINWORKSHOP/s/flfmc9do2
 
-type BaiduTokenResponse struct {
+type TokenResponse struct {
     ExpiresIn   int    `json:"expires_in"`
     AccessToken string `json:"access_token"`
 }
 
-type BaiduMessage struct {
+type Message struct {
     Role    string `json:"role"`
     Content string `json:"content"`
 }
 
-type BaiduChatRequest struct {
-    Messages []BaiduMessage `json:"messages"`
-    Stream   bool           `json:"stream"`
-    UserId   string         `json:"user_id,omitempty"`
+type ChatRequest struct {
+    Messages []Message `json:"messages"`
+    Stream   bool      `json:"stream"`
+    UserId   string    `json:"user_id,omitempty"`
 }
 
-type BaiduError struct {
+type Error struct {
     ErrorCode int    `json:"error_code"`
     ErrorMsg  string `json:"error_msg"`
 }
 
 var baiduTokenStore sync.Map
 
-func ConvertRequest(request openai.GeneralOpenAIRequest) *BaiduChatRequest {
-    messages := make([]BaiduMessage, 0, len(request.Messages))
+func ConvertRequest(request openai.GeneralOpenAIRequest) *ChatRequest {
+    messages := make([]Message, 0, len(request.Messages))
     for _, message := range request.Messages {
         if message.Role == "system" {
-            messages = append(messages, BaiduMessage{
+            messages = append(messages, Message{
                 Role:    "user",
                 Content: message.StringContent(),
             })
-            messages = append(messages, BaiduMessage{
+            messages = append(messages, Message{
                 Role:    "assistant",
                 Content: "Okay",
             })
         } else {
-            messages = append(messages, BaiduMessage{
+            messages = append(messages, Message{
                 Role:    message.Role,
                 Content: message.StringContent(),
             })
         }
     }
-    return &BaiduChatRequest{
+    return &ChatRequest{
         Messages: messages,
         Stream:   request.Stream,
     }

@@ -160,7 +161,7 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode
     var baiduResponse ChatStreamResponse
     err := json.Unmarshal([]byte(data), &baiduResponse)
     if err != nil {
-        common.SysError("error unmarshalling stream response: " + err.Error())
+        logger.SysError("error unmarshalling stream response: " + err.Error())
         return true
     }
     if baiduResponse.Usage.TotalTokens != 0 {

@@ -171,7 +172,7 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode
     response := streamResponseBaidu2OpenAI(&baiduResponse)
     jsonResponse, err := json.Marshal(response)
     if err != nil {
-        common.SysError("error marshalling stream response: " + err.Error())
+        logger.SysError("error marshalling stream response: " + err.Error())
         return true
     }
     c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})

@@ -13,7 +13,7 @@ type ChatResponse struct {
     IsTruncated      bool         `json:"is_truncated"`
     NeedClearHistory bool         `json:"need_clear_history"`
     Usage            openai.Usage `json:"usage"`
-    BaiduError
+    Error
 }
 
 type ChatStreamResponse struct {

@@ -38,7 +38,7 @@ type EmbeddingResponse struct {
     Created int64           `json:"created"`
     Data    []EmbeddingData `json:"data"`
     Usage   openai.Usage    `json:"usage"`
-    BaiduError
+    Error
 }
 
 type AccessToken struct {
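Besides the logger migration, the baidu package drops its name prefix from every type (BaiduChatRequest becomes ChatRequest, BaiduError becomes Error, and so on). Go call sites always qualify these names with the package, so the old spelling stuttered; a small sketch of a call site after the rename, using only fields shown in the diff:

package main

import (
    "fmt"

    "one-api/relay/channel/baidu" // module-local import path used throughout this repo
)

func main() {
    // Formerly baidu.BaiduChatRequest / baidu.BaiduMessage; the package
    // qualifier alone now carries the context.
    req := baidu.ChatRequest{
        Messages: []baidu.Message{{Role: "user", Content: "hello"}},
        Stream:   false,
    }
    fmt.Printf("%+v\n", req)
}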
relay/channel/google/adaptor.go (new file, 22 lines)

@@ -0,0 +1,22 @@
+package google
+
+import (
+    "github.com/gin-gonic/gin"
+    "net/http"
+    "one-api/relay/channel/openai"
+)
+
+type Adaptor struct {
+}
+
+func (a *Adaptor) Auth(c *gin.Context) error {
+    return nil
+}
+
+func (a *Adaptor) ConvertRequest(request *openai.GeneralOpenAIRequest) (any, error) {
+    return nil, nil
+}
+
+func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, *openai.Usage, error) {
+    return nil, nil, nil
+}

@@ -7,7 +7,10 @@ import (
     "io"
     "net/http"
     "one-api/common"
+    "one-api/common/config"
+    "one-api/common/helper"
     "one-api/common/image"
+    "one-api/common/logger"
     "one-api/relay/channel/openai"
     "one-api/relay/constant"
     "strings"

@@ -28,19 +31,19 @@ func ConvertGeminiRequest(textRequest openai.GeneralOpenAIRequest) *GeminiChatRequest {
     SafetySettings: []GeminiChatSafetySettings{
         {
             Category:  "HARM_CATEGORY_HARASSMENT",
-            Threshold: common.GeminiSafetySetting,
+            Threshold: config.GeminiSafetySetting,
         },
         {
             Category:  "HARM_CATEGORY_HATE_SPEECH",
-            Threshold: common.GeminiSafetySetting,
+            Threshold: config.GeminiSafetySetting,
         },
         {
             Category:  "HARM_CATEGORY_SEXUALLY_EXPLICIT",
-            Threshold: common.GeminiSafetySetting,
+            Threshold: config.GeminiSafetySetting,
         },
         {
             Category:  "HARM_CATEGORY_DANGEROUS_CONTENT",
-            Threshold: common.GeminiSafetySetting,
+            Threshold: config.GeminiSafetySetting,
         },
     },
     GenerationConfig: GeminiChatGenerationConfig{

@@ -151,9 +154,9 @@ type GeminiChatPromptFeedback struct {
 
 func responseGeminiChat2OpenAI(response *GeminiChatResponse) *openai.TextResponse {
     fullTextResponse := openai.TextResponse{
-        Id:      fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
+        Id:      fmt.Sprintf("chatcmpl-%s", helper.GetUUID()),
         Object:  "chat.completion",
-        Created: common.GetTimestamp(),
+        Created: helper.GetTimestamp(),
         Choices: make([]openai.TextResponseChoice, 0, len(response.Candidates)),
     }
     for i, candidate := range response.Candidates {

@@ -229,15 +232,15 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode
     var choice openai.ChatCompletionsStreamResponseChoice
     choice.Delta.Content = dummy.Content
     response := openai.ChatCompletionsStreamResponse{
-        Id:      fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
+        Id:      fmt.Sprintf("chatcmpl-%s", helper.GetUUID()),
         Object:  "chat.completion.chunk",
-        Created: common.GetTimestamp(),
+        Created: helper.GetTimestamp(),
         Model:   "gemini-pro",
         Choices: []openai.ChatCompletionsStreamResponseChoice{choice},
     }
     jsonResponse, err := json.Marshal(response)
     if err != nil {
-        common.SysError("error marshalling stream response: " + err.Error())
+        logger.SysError("error marshalling stream response: " + err.Error())
         return true
     }
     c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})

@@ -7,6 +7,8 @@ import (
     "io"
     "net/http"
     "one-api/common"
+    "one-api/common/helper"
+    "one-api/common/logger"
     "one-api/relay/channel/openai"
     "one-api/relay/constant"
 )

@@ -71,27 +73,27 @@ func streamResponsePaLM2OpenAI(palmResponse *PaLMChatResponse) *openai.ChatCompletionsStreamResponse {
 
 func PaLMStreamHandler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, string) {
     responseText := ""
-    responseId := fmt.Sprintf("chatcmpl-%s", common.GetUUID())
-    createdTime := common.GetTimestamp()
+    responseId := fmt.Sprintf("chatcmpl-%s", helper.GetUUID())
+    createdTime := helper.GetTimestamp()
     dataChan := make(chan string)
     stopChan := make(chan bool)
     go func() {
         responseBody, err := io.ReadAll(resp.Body)
         if err != nil {
-            common.SysError("error reading stream response: " + err.Error())
+            logger.SysError("error reading stream response: " + err.Error())
             stopChan <- true
             return
         }
         err = resp.Body.Close()
         if err != nil {
-            common.SysError("error closing stream response: " + err.Error())
+            logger.SysError("error closing stream response: " + err.Error())
             stopChan <- true
             return
         }
         var palmResponse PaLMChatResponse
         err = json.Unmarshal(responseBody, &palmResponse)
         if err != nil {
-            common.SysError("error unmarshalling stream response: " + err.Error())
+            logger.SysError("error unmarshalling stream response: " + err.Error())
             stopChan <- true
             return
         }

@@ -103,7 +105,7 @@ func PaLMStreamHandler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, string) {
     }
     jsonResponse, err := json.Marshal(fullTextResponse)
     if err != nil {
-        common.SysError("error marshalling stream response: " + err.Error())
+        logger.SysError("error marshalling stream response: " + err.Error())
         stopChan <- true
         return
     }
relay/channel/interface.go (new file, 15 lines)

@@ -0,0 +1,15 @@
+package channel
+
+import (
+    "github.com/gin-gonic/gin"
+    "net/http"
+    "one-api/relay/channel/openai"
+)
+
+type Adaptor interface {
+    GetRequestURL() string
+    Auth(c *gin.Context) error
+    ConvertRequest(request *openai.GeneralOpenAIRequest) (any, error)
+    DoRequest(request *openai.GeneralOpenAIRequest) error
+    DoResponse(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, *openai.Usage, error)
+}
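This small interface is the centerpiece of the refactor: once each channel package fills in its stub Adaptor, the relay controller can drive any upstream through the same five methods instead of a per-channel switch. A hedged sketch of the intended calling shape (the stubs above still return nil, and the StatusCode field on ErrorWithStatusCode is assumed from its name, so this shows direction, not current behavior):

package relayexample

import (
    "fmt"
    "net/http"

    "github.com/gin-gonic/gin"

    "one-api/relay/channel"
    "one-api/relay/channel/openai"
)

// relayWithAdaptor drives any channel implementation through the shared
// interface: authenticate, convert the OpenAI-style request into the
// upstream's native format, then translate the upstream response back.
func relayWithAdaptor(a channel.Adaptor, c *gin.Context, req *openai.GeneralOpenAIRequest, resp *http.Response) error {
    if err := a.Auth(c); err != nil {
        return fmt.Errorf("auth failed: %w", err)
    }
    converted, err := a.ConvertRequest(req)
    if err != nil {
        return fmt.Errorf("convert failed: %w", err)
    }
    _ = converted // would be marshalled and POSTed to a.GetRequestURL()
    errWithStatus, usage, err := a.DoResponse(c, resp)
    if err != nil {
        return err
    }
    if errWithStatus != nil {
        return fmt.Errorf("upstream error with status %d", errWithStatus.StatusCode) // StatusCode assumed
    }
    if usage != nil {
        fmt.Printf("total tokens: %d\n", usage.TotalTokens)
    }
    return nil
}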
relay/channel/openai/adaptor.go (new file, 21 lines)

@@ -0,0 +1,21 @@
+package openai
+
+import (
+    "github.com/gin-gonic/gin"
+    "net/http"
+)
+
+type Adaptor struct {
+}
+
+func (a *Adaptor) Auth(c *gin.Context) error {
+    return nil
+}
+
+func (a *Adaptor) ConvertRequest(request *GeneralOpenAIRequest) (any, error) {
+    return nil, nil
+}
+
+func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response) (*ErrorWithStatusCode, *Usage, error) {
+    return nil, nil, nil
+}

@@ -8,6 +8,7 @@ import (
    "io"
    "net/http"
    "one-api/common"
+    "one-api/common/logger"
    "one-api/relay/constant"
    "strings"
 )

@@ -46,7 +47,7 @@ func StreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*ErrorWithStatusCode
     var streamResponse ChatCompletionsStreamResponse
     err := json.Unmarshal([]byte(data), &streamResponse)
     if err != nil {
-        common.SysError("error unmarshalling stream response: " + err.Error())
+        logger.SysError("error unmarshalling stream response: " + err.Error())
         continue // just ignore the error
     }
     for _, choice := range streamResponse.Choices {

@@ -56,7 +57,7 @@ func StreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*ErrorWithStatusCode
     var streamResponse CompletionsStreamResponse
     err := json.Unmarshal([]byte(data), &streamResponse)
     if err != nil {
-        common.SysError("error unmarshalling stream response: " + err.Error())
+        logger.SysError("error unmarshalling stream response: " + err.Error())
         continue
     }
     for _, choice := range streamResponse.Choices {

@@ -207,6 +207,11 @@ type Usage struct {
     TotalTokens int `json:"total_tokens"`
 }
 
+type UsageOrResponseText struct {
+    *Usage
+    ResponseText string
+}
+
 type Error struct {
     Message string `json:"message"`
     Type    string `json:"type"`

@@ -6,7 +6,9 @@ import (
     "github.com/pkoukk/tiktoken-go"
     "math"
     "one-api/common"
+    "one-api/common/config"
     "one-api/common/image"
+    "one-api/common/logger"
     "strings"
 )
 

@@ -15,15 +17,15 @@ var tokenEncoderMap = map[string]*tiktoken.Tiktoken{}
 var defaultTokenEncoder *tiktoken.Tiktoken
 
 func InitTokenEncoders() {
-    common.SysLog("initializing token encoders")
+    logger.SysLog("initializing token encoders")
     gpt35TokenEncoder, err := tiktoken.EncodingForModel("gpt-3.5-turbo")
     if err != nil {
-        common.FatalLog(fmt.Sprintf("failed to get gpt-3.5-turbo token encoder: %s", err.Error()))
+        logger.FatalLog(fmt.Sprintf("failed to get gpt-3.5-turbo token encoder: %s", err.Error()))
     }
     defaultTokenEncoder = gpt35TokenEncoder
     gpt4TokenEncoder, err := tiktoken.EncodingForModel("gpt-4")
     if err != nil {
-        common.FatalLog(fmt.Sprintf("failed to get gpt-4 token encoder: %s", err.Error()))
+        logger.FatalLog(fmt.Sprintf("failed to get gpt-4 token encoder: %s", err.Error()))
     }
     for model, _ := range common.ModelRatio {
         if strings.HasPrefix(model, "gpt-3.5") {

@@ -34,7 +36,7 @@ func InitTokenEncoders() {
             tokenEncoderMap[model] = nil
         }
     }
-    common.SysLog("token encoders initialized")
+    logger.SysLog("token encoders initialized")
 }
 
 func getTokenEncoder(model string) *tiktoken.Tiktoken {

@@ -45,7 +47,7 @@ func getTokenEncoder(model string) *tiktoken.Tiktoken {
     if ok {
         tokenEncoder, err := tiktoken.EncodingForModel(model)
         if err != nil {
-            common.SysError(fmt.Sprintf("failed to get token encoder for model %s: %s, using encoder for gpt-3.5-turbo", model, err.Error()))
+            logger.SysError(fmt.Sprintf("failed to get token encoder for model %s: %s, using encoder for gpt-3.5-turbo", model, err.Error()))
             tokenEncoder = defaultTokenEncoder
         }
         tokenEncoderMap[model] = tokenEncoder

@@ -55,7 +57,7 @@ func getTokenEncoder(model string) *tiktoken.Tiktoken {
 }
 
 func getTokenNum(tokenEncoder *tiktoken.Tiktoken, text string) int {
-    if common.ApproximateTokenEnabled {
+    if config.ApproximateTokenEnabled {
         return int(float64(len(text)) * 0.38)
     }
     return len(tokenEncoder.Encode(text, nil, nil))

@@ -99,7 +101,7 @@ func CountTokenMessages(messages []Message, model string) int {
     }
     imageTokens, err := countImageTokens(url, detail)
     if err != nil {
-        common.SysError("error counting image tokens: " + err.Error())
+        logger.SysError("error counting image tokens: " + err.Error())
     } else {
         tokenNum += imageTokens
     }
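One behavioral detail in the token-counter hunks is worth spelling out: when config.ApproximateTokenEnabled is set, getTokenNum bypasses tiktoken entirely and estimates tokens as 0.38 per byte of input (len on a Go string counts bytes). A standalone restatement of that fast path with a worked number:

package tokenexample

// ApproximateTokens mirrors the fast path of getTokenNum above: a flat
// 0.38 tokens-per-byte heuristic instead of a real tiktoken encode.
// For a 1000-byte English prompt this returns 380; the true count is
// usually lower (~250 by the common 4-chars-per-token rule), so the
// estimate presumably errs toward over-charging, and it also absorbs
// multi-byte UTF-8 text, where bytes outnumber characters.
func ApproximateTokens(text string) int {
    return int(float64(len(text)) * 0.38)
}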
relay/channel/tencent/adaptor.go (new file, 22 lines)

@@ -0,0 +1,22 @@
+package tencent
+
+import (
+    "github.com/gin-gonic/gin"
+    "net/http"
+    "one-api/relay/channel/openai"
+)
+
+type Adaptor struct {
+}
+
+func (a *Adaptor) Auth(c *gin.Context) error {
+    return nil
+}
+
+func (a *Adaptor) ConvertRequest(request *openai.GeneralOpenAIRequest) (any, error) {
+    return nil, nil
+}
+
+func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, *openai.Usage, error) {
+    return nil, nil, nil
+}

@@ -12,6 +12,8 @@ import (
     "io"
     "net/http"
     "one-api/common"
+    "one-api/common/helper"
+    "one-api/common/logger"
     "one-api/relay/channel/openai"
     "one-api/relay/constant"
     "sort"

@@ -46,9 +48,9 @@ func ConvertRequest(request openai.GeneralOpenAIRequest) *ChatRequest {
         stream = 1
     }
     return &ChatRequest{
-        Timestamp:   common.GetTimestamp(),
-        Expired:     common.GetTimestamp() + 24*60*60,
-        QueryID:     common.GetUUID(),
+        Timestamp:   helper.GetTimestamp(),
+        Expired:     helper.GetTimestamp() + 24*60*60,
+        QueryID:     helper.GetUUID(),
         Temperature: request.Temperature,
         TopP:        request.TopP,
         Stream:      stream,

@@ -59,7 +61,7 @@ func ConvertRequest(request openai.GeneralOpenAIRequest) *ChatRequest {
 func responseTencent2OpenAI(response *ChatResponse) *openai.TextResponse {
     fullTextResponse := openai.TextResponse{
         Object:  "chat.completion",
-        Created: common.GetTimestamp(),
+        Created: helper.GetTimestamp(),
         Usage:   response.Usage,
     }
     if len(response.Choices) > 0 {

@@ -79,7 +81,7 @@ func responseTencent2OpenAI(response *ChatResponse) *openai.TextResponse {
 func streamResponseTencent2OpenAI(TencentResponse *ChatResponse) *openai.ChatCompletionsStreamResponse {
     response := openai.ChatCompletionsStreamResponse{
         Object:  "chat.completion.chunk",
-        Created: common.GetTimestamp(),
+        Created: helper.GetTimestamp(),
         Model:   "tencent-hunyuan",
     }
     if len(TencentResponse.Choices) > 0 {

@@ -131,7 +133,7 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode
     var TencentResponse ChatResponse
     err := json.Unmarshal([]byte(data), &TencentResponse)
     if err != nil {
-        common.SysError("error unmarshalling stream response: " + err.Error())
+        logger.SysError("error unmarshalling stream response: " + err.Error())
         return true
     }
     response := streamResponseTencent2OpenAI(&TencentResponse)

@@ -140,7 +142,7 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode
     }
     jsonResponse, err := json.Marshal(response)
     if err != nil {
-        common.SysError("error marshalling stream response: " + err.Error())
+        logger.SysError("error marshalling stream response: " + err.Error())
         return true
     }
     c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
relay/channel/xunfei/adaptor.go (new file, 22 lines)

@@ -0,0 +1,22 @@
+package xunfei
+
+import (
+    "github.com/gin-gonic/gin"
+    "net/http"
+    "one-api/relay/channel/openai"
+)
+
+type Adaptor struct {
+}
+
+func (a *Adaptor) Auth(c *gin.Context) error {
+    return nil
+}
+
+func (a *Adaptor) ConvertRequest(request *openai.GeneralOpenAIRequest) (any, error) {
+    return nil, nil
+}
+
+func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, *openai.Usage, error) {
+    return nil, nil, nil
+}

@@ -12,6 +12,8 @@ import (
     "net/http"
     "net/url"
     "one-api/common"
+    "one-api/common/helper"
+    "one-api/common/logger"
     "one-api/relay/channel/openai"
     "one-api/relay/constant"
     "strings"

@@ -68,7 +70,7 @@ func responseXunfei2OpenAI(response *ChatResponse) *openai.TextResponse {
     }
     fullTextResponse := openai.TextResponse{
         Object:  "chat.completion",
-        Created: common.GetTimestamp(),
+        Created: helper.GetTimestamp(),
         Choices: []openai.TextResponseChoice{choice},
         Usage:   response.Payload.Usage.Text,
     }

@@ -90,7 +92,7 @@ func streamResponseXunfei2OpenAI(xunfeiResponse *ChatResponse) *openai.ChatCompletionsStreamResponse {
     }
     response := openai.ChatCompletionsStreamResponse{
         Object:  "chat.completion.chunk",
-        Created: common.GetTimestamp(),
+        Created: helper.GetTimestamp(),
         Model:   "SparkDesk",
         Choices: []openai.ChatCompletionsStreamResponseChoice{choice},
     }

@@ -140,7 +142,7 @@ func StreamHandler(c *gin.Context, textRequest openai.GeneralOpenAIRequest, appId
     response := streamResponseXunfei2OpenAI(&xunfeiResponse)
     jsonResponse, err := json.Marshal(response)
     if err != nil {
-        common.SysError("error marshalling stream response: " + err.Error())
+        logger.SysError("error marshalling stream response: " + err.Error())
         return true
     }
     c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})

@@ -215,20 +217,20 @@ func xunfeiMakeRequest(textRequest openai.GeneralOpenAIRequest, domain, authUrl,
     for {
         _, msg, err := conn.ReadMessage()
         if err != nil {
-            common.SysError("error reading stream response: " + err.Error())
+            logger.SysError("error reading stream response: " + err.Error())
             break
         }
         var response ChatResponse
         err = json.Unmarshal(msg, &response)
         if err != nil {
-            common.SysError("error unmarshalling stream response: " + err.Error())
+            logger.SysError("error unmarshalling stream response: " + err.Error())
             break
         }
         dataChan <- response
         if response.Payload.Choices.Status == 2 {
             err := conn.Close()
             if err != nil {
-                common.SysError("error closing websocket connection: " + err.Error())
+                logger.SysError("error closing websocket connection: " + err.Error())
             }
             break
         }

@@ -247,7 +249,7 @@ func getXunfeiAuthUrl(c *gin.Context, apiKey string, apiSecret string) (string,
     }
     if apiVersion == "" {
         apiVersion = "v1.1"
-        common.SysLog("api_version not found, use default: " + apiVersion)
+        logger.SysLog("api_version not found, use default: " + apiVersion)
     }
     domain := "general"
     if apiVersion != "v1.1" {
relay/channel/zhipu/adaptor.go (new file, 22 lines)

@@ -0,0 +1,22 @@
+package zhipu
+
+import (
+    "github.com/gin-gonic/gin"
+    "net/http"
+    "one-api/relay/channel/openai"
+)
+
+type Adaptor struct {
+}
+
+func (a *Adaptor) Auth(c *gin.Context) error {
+    return nil
+}
+
+func (a *Adaptor) ConvertRequest(request *openai.GeneralOpenAIRequest) (any, error) {
+    return nil, nil
+}
+
+func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, *openai.Usage, error) {
+    return nil, nil, nil
+}

@@ -8,6 +8,8 @@ import (
     "io"
     "net/http"
     "one-api/common"
+    "one-api/common/helper"
+    "one-api/common/logger"
     "one-api/relay/channel/openai"
     "one-api/relay/constant"
     "strings"

@@ -34,7 +36,7 @@ func GetToken(apikey string) string {
 
     split := strings.Split(apikey, ".")
     if len(split) != 2 {
-        common.SysError("invalid zhipu key: " + apikey)
+        logger.SysError("invalid zhipu key: " + apikey)
         return ""
     }
 

@@ -101,7 +103,7 @@ func responseZhipu2OpenAI(response *Response) *openai.TextResponse {
     fullTextResponse := openai.TextResponse{
         Id:      response.Data.TaskId,
         Object:  "chat.completion",
-        Created: common.GetTimestamp(),
+        Created: helper.GetTimestamp(),
         Choices: make([]openai.TextResponseChoice, 0, len(response.Data.Choices)),
         Usage:   response.Data.Usage,
     }

@@ -127,7 +129,7 @@ func streamResponseZhipu2OpenAI(zhipuResponse string) *openai.ChatCompletionsStreamResponse {
     choice.Delta.Content = zhipuResponse
     response := openai.ChatCompletionsStreamResponse{
         Object:  "chat.completion.chunk",
-        Created: common.GetTimestamp(),
+        Created: helper.GetTimestamp(),
         Model:   "chatglm",
         Choices: []openai.ChatCompletionsStreamResponseChoice{choice},
     }

@@ -141,7 +143,7 @@ func streamMetaResponseZhipu2OpenAI(zhipuResponse *StreamMetaResponse) (*openai.ChatCompletionsStreamResponse
     response := openai.ChatCompletionsStreamResponse{
         Id:      zhipuResponse.RequestId,
         Object:  "chat.completion.chunk",
-        Created: common.GetTimestamp(),
+        Created: helper.GetTimestamp(),
         Model:   "chatglm",
         Choices: []openai.ChatCompletionsStreamResponseChoice{choice},
     }

@@ -193,7 +195,7 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode
     response := streamResponseZhipu2OpenAI(data)
     jsonResponse, err := json.Marshal(response)
     if err != nil {
-        common.SysError("error marshalling stream response: " + err.Error())
+        logger.SysError("error marshalling stream response: " + err.Error())
         return true
     }
     c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})

@@ -202,13 +204,13 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode
     var zhipuResponse StreamMetaResponse
     err := json.Unmarshal([]byte(data), &zhipuResponse)
     if err != nil {
-        common.SysError("error unmarshalling stream response: " + err.Error())
+        logger.SysError("error unmarshalling stream response: " + err.Error())
         return true
     }
     response, zhipuUsage := streamMetaResponseZhipu2OpenAI(&zhipuResponse)
     jsonResponse, err := json.Marshal(response)
     if err != nil {
-        common.SysError("error marshalling stream response: " + err.Error())
+        logger.SysError("error marshalling stream response: " + err.Error())
         return true
     }
     usage = zhipuUsage
relay/constant/api_type.go (new file, 69 lines)

@@ -0,0 +1,69 @@
+package constant
+
+import (
+    "one-api/common"
+)
+
+const (
+    APITypeOpenAI = iota
+    APITypeClaude
+    APITypePaLM
+    APITypeBaidu
+    APITypeZhipu
+    APITypeAli
+    APITypeXunfei
+    APITypeAIProxyLibrary
+    APITypeTencent
+    APITypeGemini
+)
+
+func ChannelType2APIType(channelType int) int {
+    apiType := APITypeOpenAI
+    switch channelType {
+    case common.ChannelTypeAnthropic:
+        apiType = APITypeClaude
+    case common.ChannelTypeBaidu:
+        apiType = APITypeBaidu
+    case common.ChannelTypePaLM:
+        apiType = APITypePaLM
+    case common.ChannelTypeZhipu:
+        apiType = APITypeZhipu
+    case common.ChannelTypeAli:
+        apiType = APITypeAli
+    case common.ChannelTypeXunfei:
+        apiType = APITypeXunfei
+    case common.ChannelTypeAIProxyLibrary:
+        apiType = APITypeAIProxyLibrary
+    case common.ChannelTypeTencent:
+        apiType = APITypeTencent
+    case common.ChannelTypeGemini:
+        apiType = APITypeGemini
+    }
+    return apiType
+}
+
+//func GetAdaptor(apiType int) channel.Adaptor {
+//    switch apiType {
+//    case APITypeOpenAI:
+//        return &openai.Adaptor{}
+//    case APITypeClaude:
+//        return &anthropic.Adaptor{}
+//    case APITypePaLM:
+//        return &google.Adaptor{}
+//    case APITypeZhipu:
+//        return &zhipu.Adaptor{}
+//    case APITypeBaidu:
+//        return &baidu.Adaptor{}
+//    case APITypeAli:
+//        return &ali.Adaptor{}
+//    case APITypeXunfei:
+//        return &xunfei.Adaptor{}
+//    case APITypeAIProxyLibrary:
+//        return &aiproxy.Adaptor{}
+//    case APITypeTencent:
+//        return &tencent.Adaptor{}
+//    case APITypeGemini:
+//        return &google.Adaptor{}
+//    }
+//    return nil
+//}
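Extracting this mapping into the constant package shrinks every call site to one line; the rewritten text relay below does exactly apiType := constant.ChannelType2APIType(meta.ChannelType). A quick sketch of the behavior, assuming the common.ChannelType* constants referenced above:

package main

import (
    "fmt"

    "one-api/common"
    "one-api/relay/constant"
)

func main() {
    // A Baidu channel resolves to the Baidu API type.
    fmt.Println(constant.ChannelType2APIType(common.ChannelTypeBaidu) == constant.APITypeBaidu) // true
    // Any channel the switch does not mention falls through to OpenAI,
    // since apiType is initialized to APITypeOpenAI.
    fmt.Println(constant.ChannelType2APIType(common.ChannelTypeAzure) == constant.APITypeOpenAI) // true
}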
relay/constant/common.go (new file, 3 lines)

@@ -0,0 +1,3 @@
+package constant
+
+var StopFinishReason = "stop"
@@ -1,16 +0,0 @@
-package constant
-
-const (
-    RelayModeUnknown = iota
-    RelayModeChatCompletions
-    RelayModeCompletions
-    RelayModeEmbeddings
-    RelayModeModerations
-    RelayModeImagesGenerations
-    RelayModeEdits
-    RelayModeAudioSpeech
-    RelayModeAudioTranscription
-    RelayModeAudioTranslation
-)
-
-var StopFinishReason = "stop"
relay/constant/relay_mode.go (new file, 42 lines)

@@ -0,0 +1,42 @@
+package constant
+
+import "strings"
+
+const (
+    RelayModeUnknown = iota
+    RelayModeChatCompletions
+    RelayModeCompletions
+    RelayModeEmbeddings
+    RelayModeModerations
+    RelayModeImagesGenerations
+    RelayModeEdits
+    RelayModeAudioSpeech
+    RelayModeAudioTranscription
+    RelayModeAudioTranslation
+)
+
+func Path2RelayMode(path string) int {
+    relayMode := RelayModeUnknown
+    if strings.HasPrefix(path, "/v1/chat/completions") {
+        relayMode = RelayModeChatCompletions
+    } else if strings.HasPrefix(path, "/v1/completions") {
+        relayMode = RelayModeCompletions
+    } else if strings.HasPrefix(path, "/v1/embeddings") {
+        relayMode = RelayModeEmbeddings
+    } else if strings.HasSuffix(path, "embeddings") {
+        relayMode = RelayModeEmbeddings
+    } else if strings.HasPrefix(path, "/v1/moderations") {
+        relayMode = RelayModeModerations
+    } else if strings.HasPrefix(path, "/v1/images/generations") {
+        relayMode = RelayModeImagesGenerations
+    } else if strings.HasPrefix(path, "/v1/edits") {
+        relayMode = RelayModeEdits
+    } else if strings.HasPrefix(path, "/v1/audio/speech") {
+        relayMode = RelayModeAudioSpeech
+    } else if strings.HasPrefix(path, "/v1/audio/transcriptions") {
+        relayMode = RelayModeAudioTranscription
+    } else if strings.HasPrefix(path, "/v1/audio/translations") {
+        relayMode = RelayModeAudioTranslation
+    }
+    return relayMode
+}
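The router helper replaces string checks scattered through the controllers. Note the deliberate asymmetry: embeddings match either the canonical /v1/embeddings prefix or any path that merely ends in "embeddings", which presumably covers Azure-style deployment paths. A quick check of the mapping:

package main

import (
    "fmt"

    "one-api/relay/constant"
)

func main() {
    fmt.Println(constant.Path2RelayMode("/v1/chat/completions") == constant.RelayModeChatCompletions) // true
    // Matched by the HasSuffix branch rather than the /v1 prefix:
    fmt.Println(constant.Path2RelayMode("/openai/deployments/ada/embeddings") == constant.RelayModeEmbeddings) // true
    fmt.Println(constant.Path2RelayMode("/v1/rerank") == constant.RelayModeUnknown) // true
}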
@@ -11,6 +11,8 @@ import (
     "io"
     "net/http"
     "one-api/common"
+    "one-api/common/config"
+    "one-api/common/logger"
     "one-api/model"
     "one-api/relay/channel/openai"
     "one-api/relay/constant"

@@ -53,7 +55,7 @@ func RelayAudioHelper(c *gin.Context, relayMode int) *openai.ErrorWithStatusCode
         preConsumedQuota = int(float64(len(ttsRequest.Input)) * ratio)
         quota = preConsumedQuota
     default:
-        preConsumedQuota = int(float64(common.PreConsumedQuota) * ratio)
+        preConsumedQuota = int(float64(config.PreConsumedQuota) * ratio)
     }
     userQuota, err := model.CacheGetUserQuota(userId)
     if err != nil {

@@ -102,7 +104,7 @@ func RelayAudioHelper(c *gin.Context, relayMode int) *openai.ErrorWithStatusCode
     fullRequestURL := util.GetFullRequestURL(baseURL, requestURL, channelType)
     if relayMode == constant.RelayModeAudioTranscription && channelType == common.ChannelTypeAzure {
         // https://learn.microsoft.com/en-us/azure/ai-services/openai/whisper-quickstart?tabs=command-line#rest-api
-        apiVersion := util.GetAPIVersion(c)
+        apiVersion := util.GetAzureAPIVersion(c)
         fullRequestURL = fmt.Sprintf("%s/openai/deployments/%s/audio/transcriptions?api-version=%s", baseURL, audioModel, apiVersion)
     }

@@ -191,7 +193,7 @@ func RelayAudioHelper(c *gin.Context, relayMode int) *openai.ErrorWithStatusCode
         // negative means add quota back for token & user
         err := model.PostConsumeTokenQuota(tokenId, -preConsumedQuota)
         if err != nil {
-            common.LogError(ctx, fmt.Sprintf("error rollback pre-consumed quota: %s", err.Error()))
+            logger.Error(ctx, fmt.Sprintf("error rollback pre-consumed quota: %s", err.Error()))
         }
     }()
 }(c.Request.Context())
@@ -9,6 +9,7 @@ import (
     "io"
     "net/http"
     "one-api/common"
+    "one-api/common/logger"
     "one-api/model"
     "one-api/relay/channel/openai"
     "one-api/relay/util"

@@ -112,7 +113,7 @@ func RelayImageHelper(c *gin.Context, relayMode int) *openai.ErrorWithStatusCode
     fullRequestURL := util.GetFullRequestURL(baseURL, requestURL, channelType)
     if channelType == common.ChannelTypeAzure {
         // https://learn.microsoft.com/en-us/azure/ai-services/openai/dall-e-quickstart?tabs=dalle3%2Ccommand-line&pivots=rest-api
-        apiVersion := util.GetAPIVersion(c)
+        apiVersion := util.GetAzureAPIVersion(c)
         // https://{resource_name}.openai.azure.com/openai/deployments/dall-e-3/images/generations?api-version=2023-06-01-preview
         fullRequestURL = fmt.Sprintf("%s/openai/deployments/%s/images/generations?api-version=%s", baseURL, imageModel, apiVersion)
     }

@@ -175,11 +176,11 @@ func RelayImageHelper(c *gin.Context, relayMode int) *openai.ErrorWithStatusCode
     }
     err := model.PostConsumeTokenQuota(tokenId, quota)
     if err != nil {
-        common.SysError("error consuming token remain quota: " + err.Error())
+        logger.SysError("error consuming token remain quota: " + err.Error())
     }
     err = model.CacheUpdateUserQuota(userId)
     if err != nil {
-        common.SysError("error update user quota cache: " + err.Error())
+        logger.SysError("error update user quota cache: " + err.Error())
     }
     if quota != 0 {
         tokenName := c.GetString("token_name")
@@ -1,206 +1,47 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"io"
|
||||
"math"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/common/config"
|
||||
"one-api/common/logger"
|
||||
"one-api/model"
|
||||
"one-api/relay/channel/aiproxy"
|
||||
"one-api/relay/channel/ali"
|
||||
"one-api/relay/channel/anthropic"
|
||||
"one-api/relay/channel/baidu"
|
||||
"one-api/relay/channel/google"
|
||||
"one-api/relay/channel/openai"
|
||||
"one-api/relay/channel/tencent"
|
||||
"one-api/relay/channel/xunfei"
|
||||
"one-api/relay/channel/zhipu"
|
||||
"one-api/relay/constant"
|
||||
"one-api/relay/util"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const (
|
||||
APITypeOpenAI = iota
|
||||
APITypeClaude
|
||||
APITypePaLM
|
||||
APITypeBaidu
|
||||
APITypeZhipu
|
||||
APITypeAli
|
||||
APITypeXunfei
|
||||
APITypeAIProxyLibrary
|
||||
APITypeTencent
|
||||
APITypeGemini
|
||||
)
|
||||
|
||||
func RelayTextHelper(c *gin.Context, relayMode int) *openai.ErrorWithStatusCode {
|
||||
channelType := c.GetInt("channel")
|
||||
channelId := c.GetInt("channel_id")
|
||||
tokenId := c.GetInt("token_id")
|
||||
userId := c.GetInt("id")
|
||||
group := c.GetString("group")
|
||||
ctx := c.Request.Context()
|
||||
meta := util.GetRelayMeta(c)
|
||||
var textRequest openai.GeneralOpenAIRequest
|
||||
err := common.UnmarshalBodyReusable(c, &textRequest)
|
||||
if err != nil {
|
||||
return openai.ErrorWrapper(err, "bind_request_body_failed", http.StatusBadRequest)
|
||||
}
|
||||
if textRequest.MaxTokens < 0 || textRequest.MaxTokens > math.MaxInt32/2 {
|
||||
return openai.ErrorWrapper(errors.New("max_tokens is invalid"), "invalid_max_tokens", http.StatusBadRequest)
|
||||
}
|
||||
if relayMode == constant.RelayModeModerations && textRequest.Model == "" {
|
||||
textRequest.Model = "text-moderation-latest"
|
||||
}
|
||||
if relayMode == constant.RelayModeEmbeddings && textRequest.Model == "" {
|
||||
textRequest.Model = c.Param("model")
|
||||
}
|
||||
// request validation
|
||||
if textRequest.Model == "" {
|
||||
return openai.ErrorWrapper(errors.New("model is required"), "required_field_missing", http.StatusBadRequest)
|
||||
err = util.ValidateTextRequest(&textRequest, relayMode)
|
||||
if err != nil {
|
||||
return openai.ErrorWrapper(err, "invalid_text_request", http.StatusBadRequest)
|
||||
}
|
||||
switch relayMode {
|
||||
case constant.RelayModeCompletions:
|
||||
if textRequest.Prompt == "" {
|
||||
return openai.ErrorWrapper(errors.New("field prompt is required"), "required_field_missing", http.StatusBadRequest)
|
||||
}
|
||||
case constant.RelayModeChatCompletions:
|
||||
if textRequest.Messages == nil || len(textRequest.Messages) == 0 {
|
||||
return openai.ErrorWrapper(errors.New("field messages is required"), "required_field_missing", http.StatusBadRequest)
|
||||
}
|
||||
case constant.RelayModeEmbeddings:
|
||||
case constant.RelayModeModerations:
|
||||
if textRequest.Input == "" {
|
||||
return openai.ErrorWrapper(errors.New("field input is required"), "required_field_missing", http.StatusBadRequest)
|
||||
}
|
||||
case constant.RelayModeEdits:
|
||||
if textRequest.Instruction == "" {
|
||||
return openai.ErrorWrapper(errors.New("field instruction is required"), "required_field_missing", http.StatusBadRequest)
|
||||
}
|
||||
}
|
||||
// map model name
|
||||
modelMapping := c.GetString("model_mapping")
|
||||
isModelMapped := false
|
||||
if modelMapping != "" && modelMapping != "{}" {
|
||||
modelMap := make(map[string]string)
|
||||
err := json.Unmarshal([]byte(modelMapping), &modelMap)
|
||||
if err != nil {
|
||||
return openai.ErrorWrapper(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError)
|
||||
}
|
||||
if modelMap[textRequest.Model] != "" {
|
||||
textRequest.Model = modelMap[textRequest.Model]
|
||||
isModelMapped = true
|
||||
}
|
||||
}
|
||||
apiType := APITypeOpenAI
|
||||
switch channelType {
|
||||
case common.ChannelTypeAnthropic:
|
||||
apiType = APITypeClaude
|
||||
case common.ChannelTypeBaidu:
|
||||
apiType = APITypeBaidu
|
||||
case common.ChannelTypePaLM:
|
||||
apiType = APITypePaLM
|
||||
case common.ChannelTypeZhipu:
|
||||
apiType = APITypeZhipu
|
||||
case common.ChannelTypeAli:
|
||||
apiType = APITypeAli
|
||||
case common.ChannelTypeXunfei:
|
||||
apiType = APITypeXunfei
|
||||
case common.ChannelTypeAIProxyLibrary:
|
||||
apiType = APITypeAIProxyLibrary
|
||||
case common.ChannelTypeTencent:
|
||||
apiType = APITypeTencent
|
||||
case common.ChannelTypeGemini:
|
||||
apiType = APITypeGemini
|
||||
}
|
||||
baseURL := common.ChannelBaseURLs[channelType]
|
||||
requestURL := c.Request.URL.String()
|
||||
if c.GetString("base_url") != "" {
|
||||
baseURL = c.GetString("base_url")
|
||||
}
|
||||
fullRequestURL := util.GetFullRequestURL(baseURL, requestURL, channelType)
|
||||
switch apiType {
|
||||
case APITypeOpenAI:
|
||||
if channelType == common.ChannelTypeAzure {
|
||||
// https://learn.microsoft.com/en-us/azure/cognitive-services/openai/chatgpt-quickstart?pivots=rest-api&tabs=command-line#rest-api
|
||||
apiVersion := util.GetAPIVersion(c)
|
||||
requestURL := strings.Split(requestURL, "?")[0]
|
||||
requestURL = fmt.Sprintf("%s?api-version=%s", requestURL, apiVersion)
|
||||
baseURL = c.GetString("base_url")
|
||||
task := strings.TrimPrefix(requestURL, "/v1/")
|
||||
model_ := textRequest.Model
|
||||
model_ = strings.Replace(model_, ".", "", -1)
|
||||
// https://github.com/songquanpeng/one-api/issues/67
|
||||
model_ = strings.TrimSuffix(model_, "-0301")
|
||||
model_ = strings.TrimSuffix(model_, "-0314")
|
||||
model_ = strings.TrimSuffix(model_, "-0613")
|
||||
|
||||
requestURL = fmt.Sprintf("/openai/deployments/%s/%s", model_, task)
|
||||
fullRequestURL = util.GetFullRequestURL(baseURL, requestURL, channelType)
|
||||
}
|
||||
case APITypeClaude:
|
||||
fullRequestURL = "https://api.anthropic.com/v1/complete"
|
||||
if baseURL != "" {
|
||||
fullRequestURL = fmt.Sprintf("%s/v1/complete", baseURL)
|
||||
}
|
||||
case APITypeBaidu:
|
||||
switch textRequest.Model {
|
||||
case "ERNIE-Bot":
|
||||
fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions"
|
||||
case "ERNIE-Bot-turbo":
|
||||
fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/eb-instant"
|
||||
case "ERNIE-Bot-4":
|
||||
fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro"
|
||||
case "BLOOMZ-7B":
|
||||
fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/bloomz_7b1"
|
||||
case "Embedding-V1":
|
||||
fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/embeddings/embedding-v1"
|
||||
}
|
||||
apiKey := c.Request.Header.Get("Authorization")
|
||||
apiKey = strings.TrimPrefix(apiKey, "Bearer ")
|
||||
var err error
|
||||
if apiKey, err = baidu.GetAccessToken(apiKey); err != nil {
|
||||
return openai.ErrorWrapper(err, "invalid_baidu_config", http.StatusInternalServerError)
|
||||
}
|
||||
fullRequestURL += "?access_token=" + apiKey
|
||||
case APITypePaLM:
|
||||
fullRequestURL = "https://generativelanguage.googleapis.com/v1beta2/models/chat-bison-001:generateMessage"
|
||||
if baseURL != "" {
|
||||
fullRequestURL = fmt.Sprintf("%s/v1beta2/models/chat-bison-001:generateMessage", baseURL)
|
||||
}
|
||||
case APITypeGemini:
|
||||
requestBaseURL := "https://generativelanguage.googleapis.com"
|
||||
if baseURL != "" {
|
||||
requestBaseURL = baseURL
|
||||
}
|
||||
version := "v1"
|
||||
if c.GetString("api_version") != "" {
|
||||
version = c.GetString("api_version")
|
||||
}
|
||||
action := "generateContent"
|
||||
if textRequest.Stream {
|
||||
action = "streamGenerateContent"
|
||||
}
|
||||
fullRequestURL = fmt.Sprintf("%s/%s/models/%s:%s", requestBaseURL, version, textRequest.Model, action)
|
||||
case APITypeZhipu:
|
||||
method := "invoke"
|
||||
if textRequest.Stream {
|
||||
method = "sse-invoke"
|
||||
}
|
||||
fullRequestURL = fmt.Sprintf("https://open.bigmodel.cn/api/paas/v3/model-api/%s/%s", textRequest.Model, method)
|
||||
case APITypeAli:
|
||||
fullRequestURL = "https://dashscope.aliyuncs.com/api/v1/services/aigc/text-generation/generation"
|
||||
if relayMode == constant.RelayModeEmbeddings {
|
||||
fullRequestURL = "https://dashscope.aliyuncs.com/api/v1/services/embeddings/text-embedding/text-embedding"
|
||||
}
|
||||
case APITypeTencent:
|
||||
fullRequestURL = "https://hunyuan.cloud.tencent.com/hyllm/v1/chat/completions"
|
||||
case APITypeAIProxyLibrary:
|
||||
fullRequestURL = fmt.Sprintf("%s/api/library/ask", baseURL)
|
||||
var isModelMapped bool
|
||||
textRequest.Model, isModelMapped = util.GetMappedModelName(textRequest.Model, meta.ModelMapping)
|
||||
apiType := constant.ChannelType2APIType(meta.ChannelType)
|
||||
fullRequestURL, err := GetRequestURL(c.Request.URL.String(), apiType, relayMode, meta, &textRequest)
|
||||
if err != nil {
|
||||
logger.Error(ctx, fmt.Sprintf("util.GetRequestURL failed: %s", err.Error()))
|
||||
return openai.ErrorWrapper(fmt.Errorf("util.GetRequestURL failed"), "get_request_url_failed", http.StatusInternalServerError)
|
||||
}
|
||||
var promptTokens int
|
||||
var completionTokens int
|
||||
@@ -212,22 +53,22 @@ func RelayTextHelper(c *gin.Context, relayMode int) *openai.ErrorWithStatusCode
|
||||
case constant.RelayModeModerations:
|
||||
promptTokens = openai.CountTokenInput(textRequest.Input, textRequest.Model)
|
||||
}
|
||||
preConsumedTokens := common.PreConsumedQuota
|
||||
preConsumedTokens := config.PreConsumedQuota
|
||||
if textRequest.MaxTokens != 0 {
|
||||
preConsumedTokens = promptTokens + textRequest.MaxTokens
|
||||
}
|
||||
modelRatio := common.GetModelRatio(textRequest.Model)
|
||||
groupRatio := common.GetGroupRatio(group)
|
||||
groupRatio := common.GetGroupRatio(meta.Group)
|
||||
ratio := modelRatio * groupRatio
|
||||
preConsumedQuota := int(float64(preConsumedTokens) * ratio)
|
||||
userQuota, err := model.CacheGetUserQuota(userId)
|
||||
userQuota, err := model.CacheGetUserQuota(meta.UserId)
|
||||
if err != nil {
|
||||
return openai.ErrorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError)
|
||||
}
|
||||
if userQuota-preConsumedQuota < 0 {
|
||||
return openai.ErrorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden)
|
||||
}
|
||||
err = model.CacheDecreaseUserQuota(userId, preConsumedQuota)
|
||||
err = model.CacheDecreaseUserQuota(meta.UserId, preConsumedQuota)
|
||||
if err != nil {
|
||||
return openai.ErrorWrapper(err, "decrease_user_quota_failed", http.StatusInternalServerError)
|
||||
}
|
||||
@@ -235,165 +76,28 @@ func RelayTextHelper(c *gin.Context, relayMode int) *openai.ErrorWithStatusCode
|
||||
// in this case, we do not pre-consume quota
|
||||
// because the user has enough quota
|
||||
preConsumedQuota = 0
|
||||
common.LogInfo(c.Request.Context(), fmt.Sprintf("user %d has enough quota %d, trusted and no need to pre-consume", userId, userQuota))
|
||||
logger.Info(c.Request.Context(), fmt.Sprintf("user %d has enough quota %d, trusted and no need to pre-consume", meta.UserId, userQuota))
|
||||
}
|
||||
if preConsumedQuota > 0 {
|
||||
err := model.PreConsumeTokenQuota(tokenId, preConsumedQuota)
|
||||
err := model.PreConsumeTokenQuota(meta.TokenId, preConsumedQuota)
|
||||
if err != nil {
|
||||
return openai.ErrorWrapper(err, "pre_consume_token_quota_failed", http.StatusForbidden)
|
||||
}
|
||||
}
|
||||
-	var requestBody io.Reader
-	if isModelMapped {
-		jsonStr, err := json.Marshal(textRequest)
-		if err != nil {
-			return openai.ErrorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
-		}
-		requestBody = bytes.NewBuffer(jsonStr)
-	} else {
-		requestBody = c.Request.Body
+	requestBody, err := GetRequestBody(c, textRequest, isModelMapped, apiType, relayMode)
+	if err != nil {
+		return openai.ErrorWrapper(err, "get_request_body_failed", http.StatusInternalServerError)
+	}
-	switch apiType {
-	case APITypeClaude:
-		claudeRequest := anthropic.ConvertRequest(textRequest)
-		jsonStr, err := json.Marshal(claudeRequest)
-		if err != nil {
-			return openai.ErrorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
-		}
-		requestBody = bytes.NewBuffer(jsonStr)
-	case APITypeBaidu:
-		var jsonData []byte
-		var err error
-		switch relayMode {
-		case constant.RelayModeEmbeddings:
-			baiduEmbeddingRequest := baidu.ConvertEmbeddingRequest(textRequest)
-			jsonData, err = json.Marshal(baiduEmbeddingRequest)
-		default:
-			baiduRequest := baidu.ConvertRequest(textRequest)
-			jsonData, err = json.Marshal(baiduRequest)
-		}
-		if err != nil {
-			return openai.ErrorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
-		}
-		requestBody = bytes.NewBuffer(jsonData)
-	case APITypePaLM:
-		palmRequest := google.ConvertPaLMRequest(textRequest)
-		jsonStr, err := json.Marshal(palmRequest)
-		if err != nil {
-			return openai.ErrorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
-		}
-		requestBody = bytes.NewBuffer(jsonStr)
-	case APITypeGemini:
-		geminiChatRequest := google.ConvertGeminiRequest(textRequest)
-		jsonStr, err := json.Marshal(geminiChatRequest)
-		if err != nil {
-			return openai.ErrorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
-		}
-		requestBody = bytes.NewBuffer(jsonStr)
-	case APITypeZhipu:
-		zhipuRequest := zhipu.ConvertRequest(textRequest)
-		jsonStr, err := json.Marshal(zhipuRequest)
-		if err != nil {
-			return openai.ErrorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
-		}
-		requestBody = bytes.NewBuffer(jsonStr)
-	case APITypeAli:
-		var jsonStr []byte
-		var err error
-		switch relayMode {
-		case constant.RelayModeEmbeddings:
-			aliEmbeddingRequest := ali.ConvertEmbeddingRequest(textRequest)
-			jsonStr, err = json.Marshal(aliEmbeddingRequest)
-		default:
-			aliRequest := ali.ConvertRequest(textRequest)
-			jsonStr, err = json.Marshal(aliRequest)
-		}
-		if err != nil {
-			return openai.ErrorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
-		}
-		requestBody = bytes.NewBuffer(jsonStr)
-	case APITypeTencent:
-		apiKey := c.Request.Header.Get("Authorization")
-		apiKey = strings.TrimPrefix(apiKey, "Bearer ")
-		appId, secretId, secretKey, err := tencent.ParseConfig(apiKey)
-		if err != nil {
-			return openai.ErrorWrapper(err, "invalid_tencent_config", http.StatusInternalServerError)
-		}
-		tencentRequest := tencent.ConvertRequest(textRequest)
-		tencentRequest.AppId = appId
-		tencentRequest.SecretId = secretId
-		jsonStr, err := json.Marshal(tencentRequest)
-		if err != nil {
-			return openai.ErrorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
-		}
-		sign := tencent.GetSign(*tencentRequest, secretKey)
-		c.Request.Header.Set("Authorization", sign)
-		requestBody = bytes.NewBuffer(jsonStr)
-	case APITypeAIProxyLibrary:
-		aiProxyLibraryRequest := aiproxy.ConvertRequest(textRequest)
-		aiProxyLibraryRequest.LibraryId = c.GetString("library_id")
-		jsonStr, err := json.Marshal(aiProxyLibraryRequest)
-		if err != nil {
-			return openai.ErrorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
-		}
-		requestBody = bytes.NewBuffer(jsonStr)
-	}
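Every branch of the removed switch above follows one shape — convert the OpenAI-style request to the vendor schema, marshal it, wrap it in a bytes.Buffer — which is why it collapses into the single GetRequestBody call (defined in relay/controller/util.go below). A minimal sketch of that shared shape, with a hypothetical vendor type standing in for the channel packages:

package main

import (
	"bytes"
	"encoding/json"
	"fmt"
	"io"
)

// vendorRequest is a hypothetical stand-in for a channel-specific schema.
type vendorRequest struct {
	Prompt string `json:"prompt"`
}

func buildBody(prompt string) (io.Reader, error) {
	converted := vendorRequest{Prompt: prompt} // the per-channel ConvertRequest step
	data, err := json.Marshal(converted)
	if err != nil {
		return nil, err
	}
	return bytes.NewBuffer(data), nil // the shared wrap-up step
}

func main() {
	body, _ := buildBody("hello")
	buf := new(bytes.Buffer)
	buf.ReadFrom(body)
	fmt.Println(buf.String()) // {"prompt":"hello"}
}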

	var req *http.Request
	var resp *http.Response
	isStream := textRequest.Stream

-	if apiType != APITypeXunfei { // cause xunfei use websocket
+	if apiType != constant.APITypeXunfei { // cause xunfei use websocket
		req, err = http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
		if err != nil {
			return openai.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError)
		}
-		apiKey := c.Request.Header.Get("Authorization")
-		apiKey = strings.TrimPrefix(apiKey, "Bearer ")
-		switch apiType {
-		case APITypeOpenAI:
-			if channelType == common.ChannelTypeAzure {
-				req.Header.Set("api-key", apiKey)
-			} else {
-				req.Header.Set("Authorization", c.Request.Header.Get("Authorization"))
-				if channelType == common.ChannelTypeOpenRouter {
-					req.Header.Set("HTTP-Referer", "https://github.com/songquanpeng/one-api")
-					req.Header.Set("X-Title", "One API")
-				}
-			}
-		case APITypeClaude:
-			req.Header.Set("x-api-key", apiKey)
-			anthropicVersion := c.Request.Header.Get("anthropic-version")
-			if anthropicVersion == "" {
-				anthropicVersion = "2023-06-01"
-			}
-			req.Header.Set("anthropic-version", anthropicVersion)
-		case APITypeZhipu:
-			token := zhipu.GetToken(apiKey)
-			req.Header.Set("Authorization", token)
-		case APITypeAli:
-			req.Header.Set("Authorization", "Bearer "+apiKey)
-			if textRequest.Stream {
-				req.Header.Set("X-DashScope-SSE", "enable")
-			}
-			if c.GetString("plugin") != "" {
-				req.Header.Set("X-DashScope-Plugin", c.GetString("plugin"))
-			}
-		case APITypeTencent:
-			req.Header.Set("Authorization", apiKey)
-		case APITypePaLM:
-			req.Header.Set("x-goog-api-key", apiKey)
-		case APITypeGemini:
-			req.Header.Set("x-goog-api-key", apiKey)
-		default:
-			req.Header.Set("Authorization", "Bearer "+apiKey)
-		}
-		req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
-		req.Header.Set("Accept", c.Request.Header.Get("Accept"))
-		if isStream && c.Request.Header.Get("Accept") == "" {
-			req.Header.Set("Accept", "text/event-stream")
-		}
-		//req.Header.Set("Connection", c.Request.Header.Get("Connection"))
+		SetupRequestHeaders(c, req, apiType, meta, isStream)
		resp, err = util.HTTPClient.Do(req)
		if err != nil {
			return openai.ErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)
@@ -409,29 +113,31 @@ func RelayTextHelper(c *gin.Context, relayMode int) *openai.ErrorWithStatusCode
		isStream = isStream || strings.HasPrefix(resp.Header.Get("Content-Type"), "text/event-stream")

		if resp.StatusCode != http.StatusOK {
-			if preConsumedQuota != 0 {
-				go func(ctx context.Context) {
-					// return pre-consumed quota
-					err := model.PostConsumeTokenQuota(tokenId, -preConsumedQuota)
-					if err != nil {
-						common.LogError(ctx, "error return pre-consumed quota: "+err.Error())
-					}
-				}(c.Request.Context())
-			}
+			util.ReturnPreConsumedQuota(ctx, preConsumedQuota, meta.TokenId)
			return util.RelayErrorHandler(resp)
		}
	}

-	var textResponse openai.SlimTextResponse
-	tokenName := c.GetString("token_name")
+	var respErr *openai.ErrorWithStatusCode
+	var usage *openai.Usage

	defer func(ctx context.Context) {
		// c.Writer.Flush()
+		// Why we use defer here? Because if error happened, we will have to return the pre-consumed quota.
+		if respErr != nil {
+			logger.Errorf(ctx, "respErr is not nil: %+v", respErr)
+			util.ReturnPreConsumedQuota(ctx, preConsumedQuota, meta.TokenId)
+			return
+		}
+		if usage == nil {
+			logger.Error(ctx, "usage is nil, which is unexpected")
+			return
+		}

		go func() {
			quota := 0
			completionRatio := common.GetCompletionRatio(textRequest.Model)
-			promptTokens = textResponse.Usage.PromptTokens
-			completionTokens = textResponse.Usage.CompletionTokens
+			promptTokens = usage.PromptTokens
+			completionTokens = usage.CompletionTokens
			quota = int(math.Ceil((float64(promptTokens) + float64(completionTokens)*completionRatio) * ratio))
			if ratio != 0 && quota <= 0 {
				quota = 1
@@ -443,239 +149,25 @@ func RelayTextHelper(c *gin.Context, relayMode int) *openai.ErrorWithStatusCode
				quota = 0
			}
			quotaDelta := quota - preConsumedQuota
-			err := model.PostConsumeTokenQuota(tokenId, quotaDelta)
+			err := model.PostConsumeTokenQuota(meta.TokenId, quotaDelta)
			if err != nil {
-				common.LogError(ctx, "error consuming token remain quota: "+err.Error())
+				logger.Error(ctx, "error consuming token remain quota: "+err.Error())
			}
-			err = model.CacheUpdateUserQuota(userId)
+			err = model.CacheUpdateUserQuota(meta.UserId)
			if err != nil {
-				common.LogError(ctx, "error update user quota cache: "+err.Error())
+				logger.Error(ctx, "error update user quota cache: "+err.Error())
			}
			if quota != 0 {
				logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
-				model.RecordConsumeLog(ctx, userId, channelId, promptTokens, completionTokens, textRequest.Model, tokenName, quota, logContent)
-				model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
-				model.UpdateChannelUsedQuota(channelId, quota)
+				model.RecordConsumeLog(ctx, meta.UserId, meta.ChannelId, promptTokens, completionTokens, textRequest.Model, meta.TokenName, quota, logContent)
+				model.UpdateUserUsedQuotaAndRequestCount(meta.UserId, quota)
+				model.UpdateChannelUsedQuota(meta.ChannelId, quota)
			}

		}()
-	}(c.Request.Context())
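The settlement in this goroutine reduces to quota = ceil((promptTokens + completionTokens * completionRatio) * ratio), with a floor of 1 for non-free models. A quick check with hypothetical numbers:

package main

import (
	"fmt"
	"math"
)

func main() {
	promptTokens, completionTokens := 100, 50
	completionRatio := 2.0 // hypothetical output-token multiplier
	ratio := 15.0          // modelRatio * groupRatio
	quota := int(math.Ceil((float64(promptTokens) + float64(completionTokens)*completionRatio) * ratio))
	if ratio != 0 && quota <= 0 {
		quota = 1 // never record zero consumption for a billed model
	}
	fmt.Println(quota) // (100 + 50*2) * 15 = 3000
}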
-	switch apiType {
-	case APITypeOpenAI:
-		if isStream {
-			err, responseText := openai.StreamHandler(c, resp, relayMode)
-			if err != nil {
-				return err
-			}
-			textResponse.Usage.PromptTokens = promptTokens
-			textResponse.Usage.CompletionTokens = openai.CountTokenText(responseText, textRequest.Model)
-			return nil
-		} else {
-			err, usage := openai.Handler(c, resp, promptTokens, textRequest.Model)
-			if err != nil {
-				return err
-			}
-			if usage != nil {
-				textResponse.Usage = *usage
-			}
-			return nil
-		}
-	case APITypeClaude:
-		if isStream {
-			err, responseText := anthropic.StreamHandler(c, resp)
-			if err != nil {
-				return err
-			}
-			textResponse.Usage.PromptTokens = promptTokens
-			textResponse.Usage.CompletionTokens = openai.CountTokenText(responseText, textRequest.Model)
-			return nil
-		} else {
-			err, usage := anthropic.Handler(c, resp, promptTokens, textRequest.Model)
-			if err != nil {
-				return err
-			}
-			if usage != nil {
-				textResponse.Usage = *usage
-			}
-			return nil
-		}
-	case APITypeBaidu:
-		if isStream {
-			err, usage := baidu.StreamHandler(c, resp)
-			if err != nil {
-				return err
-			}
-			if usage != nil {
-				textResponse.Usage = *usage
-			}
-			return nil
-		} else {
-			var err *openai.ErrorWithStatusCode
-			var usage *openai.Usage
-			switch relayMode {
-			case constant.RelayModeEmbeddings:
-				err, usage = baidu.EmbeddingHandler(c, resp)
-			default:
-				err, usage = baidu.Handler(c, resp)
-			}
-			if err != nil {
-				return err
-			}
-			if usage != nil {
-				textResponse.Usage = *usage
-			}
-			return nil
-		}
-	case APITypePaLM:
-		if textRequest.Stream { // PaLM2 API does not support stream
-			err, responseText := google.PaLMStreamHandler(c, resp)
-			if err != nil {
-				return err
-			}
-			textResponse.Usage.PromptTokens = promptTokens
-			textResponse.Usage.CompletionTokens = openai.CountTokenText(responseText, textRequest.Model)
-			return nil
-		} else {
-			err, usage := google.PaLMHandler(c, resp, promptTokens, textRequest.Model)
-			if err != nil {
-				return err
-			}
-			if usage != nil {
-				textResponse.Usage = *usage
-			}
-			return nil
-		}
-	case APITypeGemini:
-		if textRequest.Stream {
-			err, responseText := google.StreamHandler(c, resp)
-			if err != nil {
-				return err
-			}
-			textResponse.Usage.PromptTokens = promptTokens
-			textResponse.Usage.CompletionTokens = openai.CountTokenText(responseText, textRequest.Model)
-			return nil
-		} else {
-			err, usage := google.GeminiHandler(c, resp, promptTokens, textRequest.Model)
-			if err != nil {
-				return err
-			}
-			if usage != nil {
-				textResponse.Usage = *usage
-			}
-			return nil
-		}
-	case APITypeZhipu:
-		if isStream {
-			err, usage := zhipu.StreamHandler(c, resp)
-			if err != nil {
-				return err
-			}
-			if usage != nil {
-				textResponse.Usage = *usage
-			}
-			// zhipu's API does not return prompt tokens & completion tokens
-			textResponse.Usage.PromptTokens = textResponse.Usage.TotalTokens
-			return nil
-		} else {
-			err, usage := zhipu.Handler(c, resp)
-			if err != nil {
-				return err
-			}
-			if usage != nil {
-				textResponse.Usage = *usage
-			}
-			// zhipu's API does not return prompt tokens & completion tokens
-			textResponse.Usage.PromptTokens = textResponse.Usage.TotalTokens
-			return nil
-		}
-	case APITypeAli:
-		if isStream {
-			err, usage := ali.StreamHandler(c, resp)
-			if err != nil {
-				return err
-			}
-			if usage != nil {
-				textResponse.Usage = *usage
-			}
-			return nil
-		} else {
-			var err *openai.ErrorWithStatusCode
-			var usage *openai.Usage
-			switch relayMode {
-			case constant.RelayModeEmbeddings:
-				err, usage = ali.EmbeddingHandler(c, resp)
-			default:
-				err, usage = ali.Handler(c, resp)
-			}
-			if err != nil {
-				return err
-			}
-			if usage != nil {
-				textResponse.Usage = *usage
-			}
-			return nil
-		}
-	case APITypeXunfei:
-		auth := c.Request.Header.Get("Authorization")
-		auth = strings.TrimPrefix(auth, "Bearer ")
-		splits := strings.Split(auth, "|")
-		if len(splits) != 3 {
-			return openai.ErrorWrapper(errors.New("invalid auth"), "invalid_auth", http.StatusBadRequest)
-		}
-		var err *openai.ErrorWithStatusCode
-		var usage *openai.Usage
-		if isStream {
-			err, usage = xunfei.StreamHandler(c, textRequest, splits[0], splits[1], splits[2])
-		} else {
-			err, usage = xunfei.Handler(c, textRequest, splits[0], splits[1], splits[2])
-		}
-		if err != nil {
-			return err
-		}
-		if usage != nil {
-			textResponse.Usage = *usage
-		}
-		return nil
-	case APITypeAIProxyLibrary:
-		if isStream {
-			err, usage := aiproxy.StreamHandler(c, resp)
-			if err != nil {
-				return err
-			}
-			if usage != nil {
-				textResponse.Usage = *usage
-			}
-			return nil
-		} else {
-			err, usage := aiproxy.Handler(c, resp)
-			if err != nil {
-				return err
-			}
-			if usage != nil {
-				textResponse.Usage = *usage
-			}
-			return nil
-		}
-	case APITypeTencent:
-		if isStream {
-			err, responseText := tencent.StreamHandler(c, resp)
-			if err != nil {
-				return err
-			}
-			textResponse.Usage.PromptTokens = promptTokens
-			textResponse.Usage.CompletionTokens = openai.CountTokenText(responseText, textRequest.Model)
-			return nil
-		} else {
-			err, usage := tencent.Handler(c, resp)
-			if err != nil {
-				return err
-			}
-			if usage != nil {
-				textResponse.Usage = *usage
-			}
-			return nil
-		}
-	default:
-		return openai.ErrorWrapper(errors.New("unknown api type"), "unknown_api_type", http.StatusInternalServerError)
+	}(ctx)
+	usage, respErr = DoResponse(c, &textRequest, resp, relayMode, apiType, isStream, promptTokens)
+	if respErr != nil {
		return respErr
+	}
	return nil
}

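With the response switch gone, RelayTextHelper reduces to a linear pipeline over the helpers added in this commit. A condensed sketch of the call order (signatures match the diff; the requestURL argument is illustrative and error handling is elided):

func relayFlowSketch(c *gin.Context, textRequest openai.GeneralOpenAIRequest, apiType, relayMode, promptTokens int, isModelMapped, isStream bool) (*openai.Usage, *openai.ErrorWithStatusCode) {
	meta := util.GetRelayMeta(c)
	fullRequestURL, _ := GetRequestURL(c.Request.URL.String(), apiType, relayMode, meta, &textRequest)
	requestBody, _ := GetRequestBody(c, textRequest, isModelMapped, apiType, relayMode)
	req, _ := http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
	SetupRequestHeaders(c, req, apiType, meta, isStream)
	resp, _ := util.HTTPClient.Do(req)
	return DoResponse(c, &textRequest, resp, relayMode, apiType, isStream, promptTokens)
}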
337
relay/controller/util.go
Normal file
@@ -0,0 +1,337 @@
package controller

import (
	"bytes"
	"encoding/json"
	"errors"
	"fmt"
	"github.com/gin-gonic/gin"
	"io"
	"net/http"
	"one-api/common"
	"one-api/common/helper"
	"one-api/relay/channel/aiproxy"
	"one-api/relay/channel/ali"
	"one-api/relay/channel/anthropic"
	"one-api/relay/channel/baidu"
	"one-api/relay/channel/google"
	"one-api/relay/channel/openai"
	"one-api/relay/channel/tencent"
	"one-api/relay/channel/xunfei"
	"one-api/relay/channel/zhipu"
	"one-api/relay/constant"
	"one-api/relay/util"
	"strings"
)

func GetRequestURL(requestURL string, apiType int, relayMode int, meta *util.RelayMeta, textRequest *openai.GeneralOpenAIRequest) (string, error) {
	fullRequestURL := util.GetFullRequestURL(meta.BaseURL, requestURL, meta.ChannelType)
	switch apiType {
	case constant.APITypeOpenAI:
		if meta.ChannelType == common.ChannelTypeAzure {
			// https://learn.microsoft.com/en-us/azure/cognitive-services/openai/chatgpt-quickstart?pivots=rest-api&tabs=command-line#rest-api
			requestURL := strings.Split(requestURL, "?")[0]
			requestURL = fmt.Sprintf("%s?api-version=%s", requestURL, meta.APIVersion)
			task := strings.TrimPrefix(requestURL, "/v1/")
			model_ := textRequest.Model
			model_ = strings.Replace(model_, ".", "", -1)
			// https://github.com/songquanpeng/one-api/issues/67
			model_ = strings.TrimSuffix(model_, "-0301")
			model_ = strings.TrimSuffix(model_, "-0314")
			model_ = strings.TrimSuffix(model_, "-0613")

			requestURL = fmt.Sprintf("/openai/deployments/%s/%s", model_, task)
			fullRequestURL = util.GetFullRequestURL(meta.BaseURL, requestURL, meta.ChannelType)
		}
	case constant.APITypeClaude:
		fullRequestURL = fmt.Sprintf("%s/v1/complete", meta.BaseURL)
	case constant.APITypeBaidu:
		switch textRequest.Model {
		case "ERNIE-Bot":
			fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions"
		case "ERNIE-Bot-turbo":
			fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/eb-instant"
		case "ERNIE-Bot-4":
			fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro"
		case "BLOOMZ-7B":
			fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/bloomz_7b1"
		case "Embedding-V1":
			fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/embeddings/embedding-v1"
		}
		var accessToken string
		var err error
		if accessToken, err = baidu.GetAccessToken(meta.APIKey); err != nil {
			return "", fmt.Errorf("failed to get baidu access token: %w", err)
		}
		fullRequestURL += "?access_token=" + accessToken
	case constant.APITypePaLM:
		fullRequestURL = fmt.Sprintf("%s/v1beta2/models/chat-bison-001:generateMessage", meta.BaseURL)
	case constant.APITypeGemini:
		version := helper.AssignOrDefault(meta.APIVersion, "v1")
		action := "generateContent"
		if textRequest.Stream {
			action = "streamGenerateContent"
		}
		fullRequestURL = fmt.Sprintf("%s/%s/models/%s:%s", meta.BaseURL, version, textRequest.Model, action)
	case constant.APITypeZhipu:
		method := "invoke"
		if textRequest.Stream {
			method = "sse-invoke"
		}
		fullRequestURL = fmt.Sprintf("https://open.bigmodel.cn/api/paas/v3/model-api/%s/%s", textRequest.Model, method)
	case constant.APITypeAli:
		fullRequestURL = "https://dashscope.aliyuncs.com/api/v1/services/aigc/text-generation/generation"
		if relayMode == constant.RelayModeEmbeddings {
			fullRequestURL = "https://dashscope.aliyuncs.com/api/v1/services/embeddings/text-embedding/text-embedding"
		}
	case constant.APITypeTencent:
		fullRequestURL = "https://hunyuan.cloud.tencent.com/hyllm/v1/chat/completions"
	case constant.APITypeAIProxyLibrary:
		fullRequestURL = fmt.Sprintf("%s/api/library/ask", meta.BaseURL)
	}
	return fullRequestURL, nil
}
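The Azure branch above is easiest to follow with concrete strings. A standalone sketch of the same string surgery, assuming api-version 2023-05-15 and model gpt-3.5-turbo (both hypothetical inputs):

package main

import (
	"fmt"
	"strings"
)

func main() {
	requestURL := "/v1/chat/completions"
	apiVersion := "2023-05-15" // hypothetical meta.APIVersion
	model := "gpt-3.5-turbo"

	requestURL = strings.Split(requestURL, "?")[0]
	requestURL = fmt.Sprintf("%s?api-version=%s", requestURL, apiVersion)
	task := strings.TrimPrefix(requestURL, "/v1/")
	model = strings.Replace(model, ".", "", -1) // Azure deployment names have no dots
	fmt.Printf("/openai/deployments/%s/%s\n", model, task)
	// Output: /openai/deployments/gpt-35-turbo/chat/completions?api-version=2023-05-15
}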

func GetRequestBody(c *gin.Context, textRequest openai.GeneralOpenAIRequest, isModelMapped bool, apiType int, relayMode int) (io.Reader, error) {
	var requestBody io.Reader
	if isModelMapped {
		jsonStr, err := json.Marshal(textRequest)
		if err != nil {
			return nil, err
		}
		requestBody = bytes.NewBuffer(jsonStr)
	} else {
		requestBody = c.Request.Body
	}
	switch apiType {
	case constant.APITypeClaude:
		claudeRequest := anthropic.ConvertRequest(textRequest)
		jsonStr, err := json.Marshal(claudeRequest)
		if err != nil {
			return nil, err
		}
		requestBody = bytes.NewBuffer(jsonStr)
	case constant.APITypeBaidu:
		var jsonData []byte
		var err error
		switch relayMode {
		case constant.RelayModeEmbeddings:
			baiduEmbeddingRequest := baidu.ConvertEmbeddingRequest(textRequest)
			jsonData, err = json.Marshal(baiduEmbeddingRequest)
		default:
			baiduRequest := baidu.ConvertRequest(textRequest)
			jsonData, err = json.Marshal(baiduRequest)
		}
		if err != nil {
			return nil, err
		}
		requestBody = bytes.NewBuffer(jsonData)
	case constant.APITypePaLM:
		palmRequest := google.ConvertPaLMRequest(textRequest)
		jsonStr, err := json.Marshal(palmRequest)
		if err != nil {
			return nil, err
		}
		requestBody = bytes.NewBuffer(jsonStr)
	case constant.APITypeGemini:
		geminiChatRequest := google.ConvertGeminiRequest(textRequest)
		jsonStr, err := json.Marshal(geminiChatRequest)
		if err != nil {
			return nil, err
		}
		requestBody = bytes.NewBuffer(jsonStr)
	case constant.APITypeZhipu:
		zhipuRequest := zhipu.ConvertRequest(textRequest)
		jsonStr, err := json.Marshal(zhipuRequest)
		if err != nil {
			return nil, err
		}
		requestBody = bytes.NewBuffer(jsonStr)
	case constant.APITypeAli:
		var jsonStr []byte
		var err error
		switch relayMode {
		case constant.RelayModeEmbeddings:
			aliEmbeddingRequest := ali.ConvertEmbeddingRequest(textRequest)
			jsonStr, err = json.Marshal(aliEmbeddingRequest)
		default:
			aliRequest := ali.ConvertRequest(textRequest)
			jsonStr, err = json.Marshal(aliRequest)
		}
		if err != nil {
			return nil, err
		}
		requestBody = bytes.NewBuffer(jsonStr)
	case constant.APITypeTencent:
		apiKey := c.Request.Header.Get("Authorization")
		apiKey = strings.TrimPrefix(apiKey, "Bearer ")
		appId, secretId, secretKey, err := tencent.ParseConfig(apiKey)
		if err != nil {
			return nil, err
		}
		tencentRequest := tencent.ConvertRequest(textRequest)
		tencentRequest.AppId = appId
		tencentRequest.SecretId = secretId
		jsonStr, err := json.Marshal(tencentRequest)
		if err != nil {
			return nil, err
		}
		sign := tencent.GetSign(*tencentRequest, secretKey)
		c.Request.Header.Set("Authorization", sign)
		requestBody = bytes.NewBuffer(jsonStr)
	case constant.APITypeAIProxyLibrary:
		aiProxyLibraryRequest := aiproxy.ConvertRequest(textRequest)
		aiProxyLibraryRequest.LibraryId = c.GetString("library_id")
		jsonStr, err := json.Marshal(aiProxyLibraryRequest)
		if err != nil {
			return nil, err
		}
		requestBody = bytes.NewBuffer(jsonStr)
	}
	return requestBody, nil
}

func SetupRequestHeaders(c *gin.Context, req *http.Request, apiType int, meta *util.RelayMeta, isStream bool) {
	SetupAuthHeaders(c, req, apiType, meta, isStream)
	req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
	req.Header.Set("Accept", c.Request.Header.Get("Accept"))
	if isStream && c.Request.Header.Get("Accept") == "" {
		req.Header.Set("Accept", "text/event-stream")
	}
}

func SetupAuthHeaders(c *gin.Context, req *http.Request, apiType int, meta *util.RelayMeta, isStream bool) {
	apiKey := meta.APIKey
	switch apiType {
	case constant.APITypeOpenAI:
		if meta.ChannelType == common.ChannelTypeAzure {
			req.Header.Set("api-key", apiKey)
		} else {
			req.Header.Set("Authorization", c.Request.Header.Get("Authorization"))
			if meta.ChannelType == common.ChannelTypeOpenRouter {
				req.Header.Set("HTTP-Referer", "https://github.com/songquanpeng/one-api")
				req.Header.Set("X-Title", "One API")
			}
		}
	case constant.APITypeClaude:
		req.Header.Set("x-api-key", apiKey)
		anthropicVersion := c.Request.Header.Get("anthropic-version")
		if anthropicVersion == "" {
			anthropicVersion = "2023-06-01"
		}
		req.Header.Set("anthropic-version", anthropicVersion)
	case constant.APITypeZhipu:
		token := zhipu.GetToken(apiKey)
		req.Header.Set("Authorization", token)
	case constant.APITypeAli:
		req.Header.Set("Authorization", "Bearer "+apiKey)
		if isStream {
			req.Header.Set("X-DashScope-SSE", "enable")
		}
		if c.GetString("plugin") != "" {
			req.Header.Set("X-DashScope-Plugin", c.GetString("plugin"))
		}
	case constant.APITypeTencent:
		req.Header.Set("Authorization", apiKey)
	case constant.APITypePaLM:
		req.Header.Set("x-goog-api-key", apiKey)
	case constant.APITypeGemini:
		req.Header.Set("x-goog-api-key", apiKey)
	default:
		req.Header.Set("Authorization", "Bearer "+apiKey)
	}
}
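SetupAuthHeaders maps the single meta.APIKey onto each vendor's auth scheme; any channel type without a dedicated case falls through to plain bearer auth. For instance, the Claude branch yields the following outbound headers when the client sent no anthropic-version (key value hypothetical):

req, _ := http.NewRequest(http.MethodPost, "https://api.anthropic.com/v1/complete", nil)
req.Header.Set("x-api-key", "sk-ant-...")         // meta.APIKey
req.Header.Set("anthropic-version", "2023-06-01") // default applied above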

func DoResponse(c *gin.Context, textRequest *openai.GeneralOpenAIRequest, resp *http.Response, relayMode int, apiType int, isStream bool, promptTokens int) (usage *openai.Usage, err *openai.ErrorWithStatusCode) {
	var responseText string
	switch apiType {
	case constant.APITypeOpenAI:
		if isStream {
			err, responseText = openai.StreamHandler(c, resp, relayMode)
		} else {
			err, usage = openai.Handler(c, resp, promptTokens, textRequest.Model)
		}
	case constant.APITypeClaude:
		if isStream {
			err, responseText = anthropic.StreamHandler(c, resp)
		} else {
			err, usage = anthropic.Handler(c, resp, promptTokens, textRequest.Model)
		}
	case constant.APITypeBaidu:
		if isStream {
			err, usage = baidu.StreamHandler(c, resp)
		} else {
			switch relayMode {
			case constant.RelayModeEmbeddings:
				err, usage = baidu.EmbeddingHandler(c, resp)
			default:
				err, usage = baidu.Handler(c, resp)
			}
		}
	case constant.APITypePaLM:
		if isStream { // PaLM2 API does not support stream
			err, responseText = google.PaLMStreamHandler(c, resp)
		} else {
			err, usage = google.PaLMHandler(c, resp, promptTokens, textRequest.Model)
		}
	case constant.APITypeGemini:
		if isStream {
			err, responseText = google.StreamHandler(c, resp)
		} else {
			err, usage = google.GeminiHandler(c, resp, promptTokens, textRequest.Model)
		}
	case constant.APITypeZhipu:
		if isStream {
			err, usage = zhipu.StreamHandler(c, resp)
		} else {
			err, usage = zhipu.Handler(c, resp)
		}
	case constant.APITypeAli:
		if isStream {
			err, usage = ali.StreamHandler(c, resp)
		} else {
			switch relayMode {
			case constant.RelayModeEmbeddings:
				err, usage = ali.EmbeddingHandler(c, resp)
			default:
				err, usage = ali.Handler(c, resp)
			}
		}
	case constant.APITypeXunfei:
		auth := c.Request.Header.Get("Authorization")
		auth = strings.TrimPrefix(auth, "Bearer ")
		splits := strings.Split(auth, "|")
		if len(splits) != 3 {
			return nil, openai.ErrorWrapper(errors.New("invalid auth"), "invalid_auth", http.StatusBadRequest)
		}
		if isStream {
			err, usage = xunfei.StreamHandler(c, *textRequest, splits[0], splits[1], splits[2])
		} else {
			err, usage = xunfei.Handler(c, *textRequest, splits[0], splits[1], splits[2])
		}
	case constant.APITypeAIProxyLibrary:
		if isStream {
			err, usage = aiproxy.StreamHandler(c, resp)
		} else {
			err, usage = aiproxy.Handler(c, resp)
		}
	case constant.APITypeTencent:
		if isStream {
			err, responseText = tencent.StreamHandler(c, resp)
		} else {
			err, usage = tencent.Handler(c, resp)
		}
	default:
		return nil, openai.ErrorWrapper(errors.New("unknown api type"), "unknown_api_type", http.StatusInternalServerError)
	}
	if err != nil {
		return nil, err
	}
	if usage == nil && responseText != "" {
		usage = &openai.Usage{}
		usage.PromptTokens = promptTokens
		usage.CompletionTokens = openai.CountTokenText(responseText, textRequest.Model)
		usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
	}
	return usage, nil
}
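For stream handlers that return only raw text, DoResponse synthesizes a Usage from token counting. A numeric sketch of that fallback, with a crude hypothetical counter standing in for openai.CountTokenText:

package main

import "fmt"

// usage mirrors the three fields DoResponse fills on its fallback path.
type usage struct {
	PromptTokens, CompletionTokens, TotalTokens int
}

func countTokenText(text, model string) int {
	return len(text) / 4 // hypothetical heuristic, for illustration only
}

func main() {
	promptTokens, responseText := 100, "some streamed completion text"
	u := usage{PromptTokens: promptTokens}
	u.CompletionTokens = countTokenText(responseText, "gpt-3.5-turbo")
	u.TotalTokens = u.PromptTokens + u.CompletionTokens
	fmt.Printf("%+v\n", u) // {PromptTokens:100 CompletionTokens:7 TotalTokens:107}
}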
19
relay/util/billing.go
Normal file
@@ -0,0 +1,19 @@
package util

import (
	"context"
	"one-api/common/logger"
	"one-api/model"
)

func ReturnPreConsumedQuota(ctx context.Context, preConsumedQuota int, tokenId int) {
	if preConsumedQuota != 0 {
		go func(ctx context.Context) {
			// return pre-consumed quota
			err := model.PostConsumeTokenQuota(tokenId, -preConsumedQuota)
			if err != nil {
				logger.Error(ctx, "error return pre-consumed quota: "+err.Error())
			}
		}(ctx)
	}
}
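Note the sign convention: the refund is just PostConsumeTokenQuota with a negative delta, run off the request goroutine so a failed relay never blocks on the database. Call sites therefore stay one-liners, as in the error path of RelayTextHelper above:

// after a non-200 upstream status, refund and hand off the error:
util.ReturnPreConsumedQuota(ctx, preConsumedQuota, meta.TokenId)
return util.RelayErrorHandler(resp)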
@@ -7,6 +7,8 @@ import (
	"io"
	"net/http"
	"one-api/common"
+	"one-api/common/config"
+	"one-api/common/logger"
	"one-api/model"
	"one-api/relay/channel/openai"
	"strconv"
@@ -16,7 +18,7 @@ import (
)

func ShouldDisableChannel(err *openai.Error, statusCode int) bool {
-	if !common.AutomaticDisableChannelEnabled {
+	if !config.AutomaticDisableChannelEnabled {
		return false
	}
	if err == nil {
@@ -32,7 +34,7 @@ func ShouldDisableChannel(err *openai.Error, statusCode int) bool {
}

func ShouldEnableChannel(err error, openAIErr *openai.Error) bool {
-	if !common.AutomaticEnableChannelEnabled {
+	if !config.AutomaticEnableChannelEnabled {
		return false
	}
	if err != nil {
@@ -138,11 +140,11 @@ func PostConsumeQuota(ctx context.Context, tokenId int, quotaDelta int, totalQuo
	// quotaDelta is remaining quota to be consumed
	err := model.PostConsumeTokenQuota(tokenId, quotaDelta)
	if err != nil {
-		common.SysError("error consuming token remain quota: " + err.Error())
+		logger.SysError("error consuming token remain quota: " + err.Error())
	}
	err = model.CacheUpdateUserQuota(userId)
	if err != nil {
-		common.SysError("error update user quota cache: " + err.Error())
+		logger.SysError("error update user quota cache: " + err.Error())
	}
	// totalQuota is total quota consumed
	if totalQuota != 0 {
@@ -152,11 +154,11 @@ func PostConsumeQuota(ctx context.Context, tokenId int, quotaDelta int, totalQuo
		model.UpdateChannelUsedQuota(channelId, totalQuota)
	}
	if totalQuota <= 0 {
-		common.LogError(ctx, fmt.Sprintf("totalQuota consumed is %d, something is wrong", totalQuota))
+		logger.Error(ctx, fmt.Sprintf("totalQuota consumed is %d, something is wrong", totalQuota))
	}
}

-func GetAPIVersion(c *gin.Context) string {
+func GetAzureAPIVersion(c *gin.Context) string {
	query := c.Request.URL.Query()
	apiVersion := query.Get("api-version")
	if apiVersion == "" {

@@ -2,7 +2,7 @@ package util

import (
	"net/http"
-	"one-api/common"
+	"one-api/common/config"
	"time"
)

@@ -10,11 +10,11 @@ var HTTPClient *http.Client
var ImpatientHTTPClient *http.Client

func init() {
-	if common.RelayTimeout == 0 {
+	if config.RelayTimeout == 0 {
		HTTPClient = &http.Client{}
	} else {
		HTTPClient = &http.Client{
-			Timeout: time.Duration(common.RelayTimeout) * time.Second,
+			Timeout: time.Duration(config.RelayTimeout) * time.Second,
		}
	}

12
relay/util/model_mapping.go
Normal file
@@ -0,0 +1,12 @@
package util

func GetMappedModelName(modelName string, mapping map[string]string) (string, bool) {
	if mapping == nil {
		return modelName, false
	}
	mappedModelName := mapping[modelName]
	if mappedModelName != "" {
		return mappedModelName, true
	}
	return modelName, false
}
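GetMappedModelName also reports whether a mapping applied, which is what drives the isModelMapped re-marshal in GetRequestBody. A standalone check with a hypothetical mapping (helper inlined for the example):

package main

import "fmt"

func GetMappedModelName(modelName string, mapping map[string]string) (string, bool) {
	if mapping == nil {
		return modelName, false
	}
	if mapped := mapping[modelName]; mapped != "" {
		return mapped, true
	}
	return modelName, false
}

func main() {
	mapping := map[string]string{"gpt-4": "gpt-4-0613"} // hypothetical channel config
	fmt.Println(GetMappedModelName("gpt-4", mapping))         // gpt-4-0613 true
	fmt.Println(GetMappedModelName("gpt-3.5-turbo", mapping)) // gpt-3.5-turbo false
}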
44
relay/util/relay_meta.go
Normal file
@@ -0,0 +1,44 @@
package util

import (
	"github.com/gin-gonic/gin"
	"one-api/common"
	"strings"
)

type RelayMeta struct {
	ChannelType  int
	ChannelId    int
	TokenId      int
	TokenName    string
	UserId       int
	Group        string
	ModelMapping map[string]string
	BaseURL      string
	APIVersion   string
	APIKey       string
	Config       map[string]string
}

func GetRelayMeta(c *gin.Context) *RelayMeta {
	meta := RelayMeta{
		ChannelType:  c.GetInt("channel"),
		ChannelId:    c.GetInt("channel_id"),
		TokenId:      c.GetInt("token_id"),
		TokenName:    c.GetString("token_name"),
		UserId:       c.GetInt("id"),
		Group:        c.GetString("group"),
		ModelMapping: c.GetStringMapString("model_mapping"),
		BaseURL:      c.GetString("base_url"),
		APIVersion:   c.GetString("api_version"),
		APIKey:       strings.TrimPrefix(c.Request.Header.Get("Authorization"), "Bearer "),
		Config:       nil,
	}
	if meta.ChannelType == common.ChannelTypeAzure {
		meta.APIVersion = GetAzureAPIVersion(c)
	}
	if meta.BaseURL == "" {
		meta.BaseURL = common.ChannelBaseURLs[meta.ChannelType]
	}
	return &meta
}
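GetRelayMeta only reads; the gin context keys ("channel", "token_id", "group", ...) are expected to be populated by middleware earlier in the chain, and only the Azure api-version and the base URL get post-processing. A handler then needs a single call to obtain everything the relay helpers consume:

func anyRelayHandler(c *gin.Context) {
	meta := util.GetRelayMeta(c)
	// meta.APIVersion is Azure-aware and meta.BaseURL falls back to
	// common.ChannelBaseURLs[meta.ChannelType], per the two fixups above.
	_ = meta
}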
37
relay/util/validation.go
Normal file
@@ -0,0 +1,37 @@
package util

import (
	"errors"
	"math"
	"one-api/relay/channel/openai"
	"one-api/relay/constant"
)

func ValidateTextRequest(textRequest *openai.GeneralOpenAIRequest, relayMode int) error {
	if textRequest.MaxTokens < 0 || textRequest.MaxTokens > math.MaxInt32/2 {
		return errors.New("max_tokens is invalid")
	}
	if textRequest.Model == "" {
		return errors.New("model is required")
	}
	switch relayMode {
	case constant.RelayModeCompletions:
		if textRequest.Prompt == "" {
			return errors.New("field prompt is required")
		}
	case constant.RelayModeChatCompletions:
		if textRequest.Messages == nil || len(textRequest.Messages) == 0 {
			return errors.New("field messages is required")
		}
	case constant.RelayModeEmbeddings:
	case constant.RelayModeModerations:
		if textRequest.Input == "" {
			return errors.New("field input is required")
		}
	case constant.RelayModeEdits:
		if textRequest.Instruction == "" {
			return errors.New("field instruction is required")
		}
	}
	return nil
}
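A minimal call-site sketch for ValidateTextRequest, assuming relayMode was already derived from the request path:

textRequest := openai.GeneralOpenAIRequest{Model: "gpt-3.5-turbo"} // no messages set
if err := util.ValidateTextRequest(&textRequest, constant.RelayModeChatCompletions); err != nil {
	// err.Error() == "field messages is required"
}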