Merge pull request #551 from Calcium-Ion/realtime

feat: support openai realtime api
This commit is contained in:
Calcium-Ion 2024-11-05 19:45:43 +08:00 committed by GitHub
commit a859ff5985
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
48 changed files with 1266 additions and 198 deletions

View File

@ -421,6 +421,34 @@ func GetCompletionRatio(name string) float64 {
return 1 return 1
} }
func GetAudioRatio(name string) float64 {
if strings.HasPrefix(name, "gpt-4o-realtime") {
return 20
}
return 20
}
func GetAudioCompletionRatio(name string) float64 {
if strings.HasPrefix(name, "gpt-4o-realtime") {
return 10
}
return 2
}
//func GetAudioPricePerMinute(name string) float64 {
// if strings.HasPrefix(name, "gpt-4o-realtime") {
// return 0.06
// }
// return 0.06
//}
//
//func GetAudioCompletionPricePerMinute(name string) float64 {
// if strings.HasPrefix(name, "gpt-4o-realtime") {
// return 0.24
// }
// return 0.24
//}
func GetCompletionRatioMap() map[string]float64 { func GetCompletionRatioMap() map[string]float64 {
if CompletionRatio == nil { if CompletionRatio == nil {
CompletionRatio = defaultCompletionRatio CompletionRatio = defaultCompletionRatio

View File

@ -102,17 +102,22 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr
if err != nil { if err != nil {
return err, nil return err, nil
} }
if resp != nil && resp.StatusCode != http.StatusOK { var httpResp *http.Response
err := service.RelayErrorHandler(resp) if resp != nil {
return fmt.Errorf("status code %d: %s", resp.StatusCode, err.Error.Message), err httpResp = resp.(*http.Response)
if httpResp.StatusCode != http.StatusOK {
err := service.RelayErrorHandler(httpResp)
return fmt.Errorf("status code %d: %s", httpResp.StatusCode, err.Error.Message), err
} }
usage, respErr := adaptor.DoResponse(c, resp, meta) }
usageA, respErr := adaptor.DoResponse(c, httpResp, meta)
if respErr != nil { if respErr != nil {
return fmt.Errorf("%s", respErr.Error.Message), respErr return fmt.Errorf("%s", respErr.Error.Message), respErr
} }
if usage == nil { if usageA == nil {
return errors.New("usage is nil"), nil return errors.New("usage is nil"), nil
} }
usage := usageA.(*dto.Usage)
result := w.Result() result := w.Result()
respBody, err := io.ReadAll(result.Body) respBody, err := io.ReadAll(result.Body)
if err != nil { if err != nil {

View File

@ -5,6 +5,7 @@ import (
"errors" "errors"
"fmt" "fmt"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/gorilla/websocket"
"io" "io"
"log" "log"
"net/http" "net/http"
@ -38,6 +39,15 @@ func relayHandler(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode
return err return err
} }
func wsHandler(c *gin.Context, ws *websocket.Conn, relayMode int) *dto.OpenAIErrorWithStatusCode {
var err *dto.OpenAIErrorWithStatusCode
switch relayMode {
default:
err = relay.TextHelper(c)
}
return err
}
func Playground(c *gin.Context) { func Playground(c *gin.Context) {
var openaiErr *dto.OpenAIErrorWithStatusCode var openaiErr *dto.OpenAIErrorWithStatusCode
@ -134,6 +144,67 @@ func Relay(c *gin.Context) {
} }
} }
var upgrader = websocket.Upgrader{
Subprotocols: []string{"realtime"}, // WS 握手支持的协议,如果有使用 Sec-WebSocket-Protocol则必须在此声明对应的 Protocol TODO add other protocol
CheckOrigin: func(r *http.Request) bool {
return true // 允许跨域
},
}
func WssRelay(c *gin.Context) {
// 将 HTTP 连接升级为 WebSocket 连接
ws, err := upgrader.Upgrade(c.Writer, c.Request, nil)
defer ws.Close()
if err != nil {
openaiErr := service.OpenAIErrorWrapper(err, "get_channel_failed", http.StatusInternalServerError)
service.WssError(c, ws, openaiErr.Error)
return
}
relayMode := constant.Path2RelayMode(c.Request.URL.Path)
requestId := c.GetString(common.RequestIdKey)
group := c.GetString("group")
//wss://api.openai.com/v1/realtime?model=gpt-4o-realtime-preview-2024-10-01
originalModel := c.GetString("original_model")
var openaiErr *dto.OpenAIErrorWithStatusCode
for i := 0; i <= common.RetryTimes; i++ {
channel, err := getChannel(c, group, originalModel, i)
if err != nil {
common.LogError(c, err.Error())
openaiErr = service.OpenAIErrorWrapperLocal(err, "get_channel_failed", http.StatusInternalServerError)
break
}
openaiErr = wssRequest(c, ws, relayMode, channel)
if openaiErr == nil {
return // 成功处理请求,直接返回
}
go processChannelError(c, channel.Id, channel.Type, channel.Name, channel.GetAutoBan(), openaiErr)
if !shouldRetry(c, openaiErr, common.RetryTimes-i) {
break
}
}
useChannel := c.GetStringSlice("use_channel")
if len(useChannel) > 1 {
retryLogStr := fmt.Sprintf("重试:%s", strings.Trim(strings.Join(strings.Fields(fmt.Sprint(useChannel)), "->"), "[]"))
common.LogInfo(c, retryLogStr)
}
if openaiErr != nil {
if openaiErr.StatusCode == http.StatusTooManyRequests {
openaiErr.Error.Message = "当前分组上游负载已饱和,请稍后再试"
}
openaiErr.Error.Message = common.MessageWithRequestId(openaiErr.Error.Message, requestId)
service.WssError(c, ws, openaiErr.Error)
}
}
func relayRequest(c *gin.Context, relayMode int, channel *model.Channel) *dto.OpenAIErrorWithStatusCode { func relayRequest(c *gin.Context, relayMode int, channel *model.Channel) *dto.OpenAIErrorWithStatusCode {
addUsedChannel(c, channel.Id) addUsedChannel(c, channel.Id)
requestBody, _ := common.GetRequestBody(c) requestBody, _ := common.GetRequestBody(c)
@ -141,6 +212,13 @@ func relayRequest(c *gin.Context, relayMode int, channel *model.Channel) *dto.Op
return relayHandler(c, relayMode) return relayHandler(c, relayMode)
} }
func wssRequest(c *gin.Context, ws *websocket.Conn, relayMode int, channel *model.Channel) *dto.OpenAIErrorWithStatusCode {
addUsedChannel(c, channel.Id)
requestBody, _ := common.GetRequestBody(c)
c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
return relay.WssHelper(c, ws)
}
func addUsedChannel(c *gin.Context, channelId int) { func addUsedChannel(c *gin.Context, channelId int) {
useChannel := c.GetStringSlice("use_channel") useChannel := c.GetStringSlice("use_channel")
useChannel = append(useChannel, fmt.Sprintf("%d", channelId)) useChannel = append(useChannel, fmt.Sprintf("%d", channelId))

97
dto/realtime.go Normal file
View File

@ -0,0 +1,97 @@
package dto
const (
RealtimeEventTypeError = "error"
RealtimeEventTypeSessionUpdate = "session.update"
RealtimeEventTypeConversationCreate = "conversation.item.create"
RealtimeEventTypeResponseCreate = "response.create"
RealtimeEventInputAudioBufferAppend = "input_audio_buffer.append"
)
const (
RealtimeEventTypeResponseDone = "response.done"
RealtimeEventTypeSessionUpdated = "session.updated"
RealtimeEventTypeSessionCreated = "session.created"
RealtimeEventResponseAudioDelta = "response.audio.delta"
RealtimeEventResponseAudioTranscriptionDelta = "response.audio_transcript.delta"
RealtimeEventResponseFunctionCallArgumentsDelta = "response.function_call_arguments.delta"
RealtimeEventResponseFunctionCallArgumentsDone = "response.function_call_arguments.done"
RealtimeEventConversationItemCreated = "conversation.item.created"
)
type RealtimeEvent struct {
EventId string `json:"event_id"`
Type string `json:"type"`
//PreviousItemId string `json:"previous_item_id"`
Session *RealtimeSession `json:"session,omitempty"`
Item *RealtimeItem `json:"item,omitempty"`
Error *OpenAIError `json:"error,omitempty"`
Response *RealtimeResponse `json:"response,omitempty"`
Delta string `json:"delta,omitempty"`
Audio string `json:"audio,omitempty"`
}
type RealtimeResponse struct {
Usage *RealtimeUsage `json:"usage"`
}
type RealtimeUsage struct {
TotalTokens int `json:"total_tokens"`
InputTokens int `json:"input_tokens"`
OutputTokens int `json:"output_tokens"`
InputTokenDetails InputTokenDetails `json:"input_token_details"`
OutputTokenDetails OutputTokenDetails `json:"output_token_details"`
}
type InputTokenDetails struct {
CachedTokens int `json:"cached_tokens"`
TextTokens int `json:"text_tokens"`
AudioTokens int `json:"audio_tokens"`
}
type OutputTokenDetails struct {
TextTokens int `json:"text_tokens"`
AudioTokens int `json:"audio_tokens"`
}
type RealtimeSession struct {
Modalities []string `json:"modalities"`
Instructions string `json:"instructions"`
Voice string `json:"voice"`
InputAudioFormat string `json:"input_audio_format"`
OutputAudioFormat string `json:"output_audio_format"`
InputAudioTranscription InputAudioTranscription `json:"input_audio_transcription"`
TurnDetection interface{} `json:"turn_detection"`
Tools []RealTimeTool `json:"tools"`
ToolChoice string `json:"tool_choice"`
Temperature float64 `json:"temperature"`
//MaxResponseOutputTokens int `json:"max_response_output_tokens"`
}
type InputAudioTranscription struct {
Model string `json:"model"`
}
type RealTimeTool struct {
Type string `json:"type"`
Name string `json:"name"`
Description string `json:"description"`
Parameters any `json:"parameters"`
}
type RealtimeItem struct {
Id string `json:"id"`
Type string `json:"type"`
Status string `json:"status"`
Role string `json:"role"`
Content []RealtimeContent `json:"content"`
Name *string `json:"name,omitempty"`
ToolCalls any `json:"tool_calls,omitempty"`
CallId string `json:"call_id,omitempty"`
}
type RealtimeContent struct {
Type string `json:"type"`
Text string `json:"text,omitempty"`
Audio string `json:"audio,omitempty"` // Base64-encoded audio bytes.
Transcript string `json:"transcript,omitempty"`
}

View File

@ -155,8 +155,27 @@ func RootAuth() func(c *gin.Context) {
} }
} }
func WssAuth(c *gin.Context) {
}
func TokenAuth() func(c *gin.Context) { func TokenAuth() func(c *gin.Context) {
return func(c *gin.Context) { return func(c *gin.Context) {
// 先检测是否为ws
if c.Request.Header.Get("Sec-WebSocket-Protocol") != "" {
// Sec-WebSocket-Protocol: realtime, openai-insecure-api-key.sk-xxx, openai-beta.realtime-v1
// read sk from Sec-WebSocket-Protocol
key := c.Request.Header.Get("Sec-WebSocket-Protocol")
parts := strings.Split(key, ",")
for _, part := range parts {
part = strings.TrimSpace(part)
if strings.HasPrefix(part, "openai-insecure-api-key") {
key = strings.TrimPrefix(part, "openai-insecure-api-key.")
break
}
}
c.Request.Header.Set("Authorization", "Bearer "+key)
}
key := c.Request.Header.Get("Authorization") key := c.Request.Header.Get("Authorization")
parts := make([]string, 0) parts := make([]string, 0)
key = strings.TrimPrefix(key, "Bearer ") key = strings.TrimPrefix(key, "Bearer ")

View File

@ -170,6 +170,10 @@ func getModelRequest(c *gin.Context) (*ModelRequest, bool, error) {
abortWithOpenAiMessage(c, http.StatusBadRequest, "无效的请求, "+err.Error()) abortWithOpenAiMessage(c, http.StatusBadRequest, "无效的请求, "+err.Error())
return nil, false, errors.New("无效的请求, " + err.Error()) return nil, false, errors.New("无效的请求, " + err.Error())
} }
if strings.HasPrefix(c.Request.URL.Path, "/v1/realtime") {
//wss://api.openai.com/v1/realtime?model=gpt-4o-realtime-preview-2024-10-01
modelRequest.Model = c.Query("model")
}
if strings.HasPrefix(c.Request.URL.Path, "/v1/moderations") { if strings.HasPrefix(c.Request.URL.Path, "/v1/moderations") {
if modelRequest.Model == "" { if modelRequest.Model == "" {
modelRequest.Model = "text-moderation-stable" modelRequest.Model = "text-moderation-stable"

View File

@ -12,13 +12,13 @@ type Adaptor interface {
// Init IsStream bool // Init IsStream bool
Init(info *relaycommon.RelayInfo) Init(info *relaycommon.RelayInfo)
GetRequestURL(info *relaycommon.RelayInfo) (string, error) GetRequestURL(info *relaycommon.RelayInfo) (string, error)
SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error
ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error)
ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error)
ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error)
ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error)
DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error)
DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode)
GetModelList() []string GetModelList() []string
GetChannelName() string GetChannelName() string
} }

View File

@ -32,14 +32,14 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
return fullRequestURL, nil return fullRequestURL, nil
} }
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error { func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
channel.SetupApiRequestHeader(info, c, req) channel.SetupApiRequestHeader(info, c, req)
req.Header.Set("Authorization", "Bearer "+info.ApiKey) req.Set("Authorization", "Bearer "+info.ApiKey)
if info.IsStream { if info.IsStream {
req.Header.Set("X-DashScope-SSE", "enable") req.Set("X-DashScope-SSE", "enable")
} }
if c.GetString("plugin") != "" { if c.GetString("plugin") != "" {
req.Header.Set("X-DashScope-Plugin", c.GetString("plugin")) req.Set("X-DashScope-Plugin", c.GetString("plugin"))
} }
return nil return nil
} }
@ -72,11 +72,11 @@ func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInf
return nil, errors.New("not implemented") return nil, errors.New("not implemented")
} }
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) { func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
return channel.DoApiRequest(a, c, info, requestBody) return channel.DoApiRequest(a, c, info, requestBody)
} }
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) { func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
switch info.RelayMode { switch info.RelayMode {
case constant.RelayModeImagesGenerations: case constant.RelayModeImagesGenerations:
err, usage = aliImageHandler(c, resp, info) err, usage = aliImageHandler(c, resp, info)

View File

@ -4,6 +4,7 @@ import (
"errors" "errors"
"fmt" "fmt"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/gorilla/websocket"
"io" "io"
"net/http" "net/http"
"one-api/relay/common" "one-api/relay/common"
@ -11,14 +12,16 @@ import (
"one-api/service" "one-api/service"
) )
func SetupApiRequestHeader(info *common.RelayInfo, c *gin.Context, req *http.Request) { func SetupApiRequestHeader(info *common.RelayInfo, c *gin.Context, req *http.Header) {
if info.RelayMode == constant.RelayModeAudioTranscription || info.RelayMode == constant.RelayModeAudioTranslation { if info.RelayMode == constant.RelayModeAudioTranscription || info.RelayMode == constant.RelayModeAudioTranslation {
// multipart/form-data // multipart/form-data
} else if info.RelayMode == constant.RelayModeRealtime {
// websocket
} else { } else {
req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type")) req.Set("Content-Type", c.Request.Header.Get("Content-Type"))
req.Header.Set("Accept", c.Request.Header.Get("Accept")) req.Set("Accept", c.Request.Header.Get("Accept"))
if info.IsStream && c.Request.Header.Get("Accept") == "" { if info.IsStream && c.Request.Header.Get("Accept") == "" {
req.Header.Set("Accept", "text/event-stream") req.Set("Accept", "text/event-stream")
} }
} }
} }
@ -32,7 +35,7 @@ func DoApiRequest(a Adaptor, c *gin.Context, info *common.RelayInfo, requestBody
if err != nil { if err != nil {
return nil, fmt.Errorf("new request failed: %w", err) return nil, fmt.Errorf("new request failed: %w", err)
} }
err = a.SetupRequestHeader(c, req, info) err = a.SetupRequestHeader(c, &req.Header, info)
if err != nil { if err != nil {
return nil, fmt.Errorf("setup request header failed: %w", err) return nil, fmt.Errorf("setup request header failed: %w", err)
} }
@ -55,7 +58,7 @@ func DoFormRequest(a Adaptor, c *gin.Context, info *common.RelayInfo, requestBod
// set form data // set form data
req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type")) req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
err = a.SetupRequestHeader(c, req, info) err = a.SetupRequestHeader(c, &req.Header, info)
if err != nil { if err != nil {
return nil, fmt.Errorf("setup request header failed: %w", err) return nil, fmt.Errorf("setup request header failed: %w", err)
} }
@ -66,6 +69,27 @@ func DoFormRequest(a Adaptor, c *gin.Context, info *common.RelayInfo, requestBod
return resp, nil return resp, nil
} }
func DoWssRequest(a Adaptor, c *gin.Context, info *common.RelayInfo, requestBody io.Reader) (*websocket.Conn, error) {
fullRequestURL, err := a.GetRequestURL(info)
if err != nil {
return nil, fmt.Errorf("get request url failed: %w", err)
}
targetHeader := http.Header{}
err = a.SetupRequestHeader(c, &targetHeader, info)
if err != nil {
return nil, fmt.Errorf("setup request header failed: %w", err)
}
targetHeader.Set("Content-Type", c.Request.Header.Get("Content-Type"))
targetConn, _, err := websocket.DefaultDialer.Dial(fullRequestURL, targetHeader)
if err != nil {
return nil, fmt.Errorf("dial failed to %s: %w", fullRequestURL, err)
}
// send request body
//all, err := io.ReadAll(requestBody)
//err = service.WssString(c, targetConn, string(all))
return targetConn, nil
}
func doRequest(c *gin.Context, req *http.Request) (*http.Response, error) { func doRequest(c *gin.Context, req *http.Request) (*http.Response, error) {
resp, err := service.GetHttpClient().Do(req) resp, err := service.GetHttpClient().Do(req)
if err != nil { if err != nil {

View File

@ -37,7 +37,7 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
return "", nil return "", nil
} }
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error { func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
return nil return nil
} }
@ -59,11 +59,11 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
return nil, nil return nil, nil
} }
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) { func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
return nil, nil return nil, nil
} }
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) { func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
if info.IsStream { if info.IsStream {
err, usage = awsStreamHandler(c, resp, info, a.RequestMode) err, usage = awsStreamHandler(c, resp, info, a.RequestMode)
} else { } else {

View File

@ -98,9 +98,9 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
return fullRequestURL, nil return fullRequestURL, nil
} }
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error { func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
channel.SetupApiRequestHeader(info, c, req) channel.SetupApiRequestHeader(info, c, req)
req.Header.Set("Authorization", "Bearer "+info.ApiKey) req.Set("Authorization", "Bearer "+info.ApiKey)
return nil return nil
} }
@ -122,11 +122,11 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
return nil, nil return nil, nil
} }
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) { func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
return channel.DoApiRequest(a, c, info, requestBody) return channel.DoApiRequest(a, c, info, requestBody)
} }
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) { func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
if info.IsStream { if info.IsStream {
err, usage = baiduStreamHandler(c, resp) err, usage = baiduStreamHandler(c, resp)
} else { } else {

View File

@ -47,14 +47,14 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
} }
} }
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error { func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
channel.SetupApiRequestHeader(info, c, req) channel.SetupApiRequestHeader(info, c, req)
req.Header.Set("x-api-key", info.ApiKey) req.Set("x-api-key", info.ApiKey)
anthropicVersion := c.Request.Header.Get("anthropic-version") anthropicVersion := c.Request.Header.Get("anthropic-version")
if anthropicVersion == "" { if anthropicVersion == "" {
anthropicVersion = "2023-06-01" anthropicVersion = "2023-06-01"
} }
req.Header.Set("anthropic-version", anthropicVersion) req.Set("anthropic-version", anthropicVersion)
return nil return nil
} }
@ -73,11 +73,11 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
return nil, nil return nil, nil
} }
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) { func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
return channel.DoApiRequest(a, c, info, requestBody) return channel.DoApiRequest(a, c, info, requestBody)
} }
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) { func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
if info.IsStream { if info.IsStream {
err, usage = ClaudeStreamHandler(c, resp, info, a.RequestMode) err, usage = ClaudeStreamHandler(c, resp, info, a.RequestMode)
} else { } else {

View File

@ -509,7 +509,7 @@ func ClaudeHandler(c *gin.Context, resp *http.Response, requestMode int, info *r
}, nil }, nil
} }
fullTextResponse := ResponseClaude2OpenAI(requestMode, &claudeResponse) fullTextResponse := ResponseClaude2OpenAI(requestMode, &claudeResponse)
completionTokens, err := service.CountTokenText(claudeResponse.Completion, info.OriginModelName) completionTokens, err := service.CountTextToken(claudeResponse.Completion, info.OriginModelName)
if err != nil { if err != nil {
return service.OpenAIErrorWrapper(err, "count_token_text_failed", http.StatusInternalServerError), nil return service.OpenAIErrorWrapper(err, "count_token_text_failed", http.StatusInternalServerError), nil
} }

View File

@ -30,9 +30,9 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
} }
} }
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error { func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
channel.SetupApiRequestHeader(info, c, req) channel.SetupApiRequestHeader(info, c, req)
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", info.ApiKey)) req.Set("Authorization", fmt.Sprintf("Bearer %s", info.ApiKey))
return nil return nil
} }
@ -48,7 +48,7 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, re
} }
} }
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) { func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
return channel.DoApiRequest(a, c, info, requestBody) return channel.DoApiRequest(a, c, info, requestBody)
} }
@ -78,7 +78,7 @@ func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInf
return nil, errors.New("not implemented") return nil, errors.New("not implemented")
} }
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) { func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
switch info.RelayMode { switch info.RelayMode {
case constant.RelayModeEmbeddings: case constant.RelayModeEmbeddings:
fallthrough fallthrough

View File

@ -149,7 +149,7 @@ func cfSTTHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayIn
usage := &dto.Usage{} usage := &dto.Usage{}
usage.PromptTokens = info.PromptTokens usage.PromptTokens = info.PromptTokens
usage.CompletionTokens, _ = service.CountTokenText(cfResp.Result.Text, info.UpstreamModelName) usage.CompletionTokens, _ = service.CountTextToken(cfResp.Result.Text, info.UpstreamModelName)
usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
return nil, usage return nil, usage

View File

@ -36,9 +36,9 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
} }
} }
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error { func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
channel.SetupApiRequestHeader(info, c, req) channel.SetupApiRequestHeader(info, c, req)
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", info.ApiKey)) req.Set("Authorization", fmt.Sprintf("Bearer %s", info.ApiKey))
return nil return nil
} }
@ -46,7 +46,7 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, re
return requestOpenAI2Cohere(*request), nil return requestOpenAI2Cohere(*request), nil
} }
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) { func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
return channel.DoApiRequest(a, c, info, requestBody) return channel.DoApiRequest(a, c, info, requestBody)
} }
@ -54,7 +54,7 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
return requestConvertRerank2Cohere(request), nil return requestConvertRerank2Cohere(request), nil
} }
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) { func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
if info.RelayMode == constant.RelayModeRerank { if info.RelayMode == constant.RelayModeRerank {
err, usage = cohereRerankHandler(c, resp, info) err, usage = cohereRerankHandler(c, resp, info)
} else { } else {

View File

@ -31,9 +31,9 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
return fmt.Sprintf("%s/v1/chat-messages", info.BaseUrl), nil return fmt.Sprintf("%s/v1/chat-messages", info.BaseUrl), nil
} }
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error { func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
channel.SetupApiRequestHeader(info, c, req) channel.SetupApiRequestHeader(info, c, req)
req.Header.Set("Authorization", "Bearer "+info.ApiKey) req.Set("Authorization", "Bearer "+info.ApiKey)
return nil return nil
} }
@ -48,11 +48,11 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
return nil, nil return nil, nil
} }
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) { func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
return channel.DoApiRequest(a, c, info, requestBody) return channel.DoApiRequest(a, c, info, requestBody)
} }
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) { func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
if info.IsStream { if info.IsStream {
err, usage = difyStreamHandler(c, resp, info) err, usage = difyStreamHandler(c, resp, info)
} else { } else {

View File

@ -108,7 +108,7 @@ func difyStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Re
} }
if usage.TotalTokens == 0 { if usage.TotalTokens == 0 {
usage.PromptTokens = info.PromptTokens usage.PromptTokens = info.PromptTokens
usage.CompletionTokens, _ = service.CountTokenText("gpt-3.5-turbo", responseText) usage.CompletionTokens, _ = service.CountTextToken("gpt-3.5-turbo", responseText)
usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
} }
return nil, usage return nil, usage

View File

@ -47,9 +47,9 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
return fmt.Sprintf("%s/%s/models/%s:%s", info.BaseUrl, version, info.UpstreamModelName, action), nil return fmt.Sprintf("%s/%s/models/%s:%s", info.BaseUrl, version, info.UpstreamModelName, action), nil
} }
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error { func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
channel.SetupApiRequestHeader(info, c, req) channel.SetupApiRequestHeader(info, c, req)
req.Header.Set("x-goog-api-key", info.ApiKey) req.Set("x-goog-api-key", info.ApiKey)
return nil return nil
} }
@ -64,11 +64,11 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
return nil, nil return nil, nil
} }
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) { func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
return channel.DoApiRequest(a, c, info, requestBody) return channel.DoApiRequest(a, c, info, requestBody)
} }
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) { func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
if info.IsStream { if info.IsStream {
err, usage = GeminiChatStreamHandler(c, resp, info) err, usage = GeminiChatStreamHandler(c, resp, info)
} else { } else {

View File

@ -37,9 +37,9 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
return "", errors.New("invalid relay mode") return "", errors.New("invalid relay mode")
} }
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error { func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
channel.SetupApiRequestHeader(info, c, req) channel.SetupApiRequestHeader(info, c, req)
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", info.ApiKey)) req.Set("Authorization", fmt.Sprintf("Bearer %s", info.ApiKey))
return nil return nil
} }
@ -47,7 +47,7 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, re
return request, nil return request, nil
} }
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) { func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
return channel.DoApiRequest(a, c, info, requestBody) return channel.DoApiRequest(a, c, info, requestBody)
} }
@ -55,7 +55,7 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
return request, nil return request, nil
} }
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) { func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
if info.RelayMode == constant.RelayModeRerank { if info.RelayMode == constant.RelayModeRerank {
err, usage = jinaRerankHandler(c, resp) err, usage = jinaRerankHandler(c, resp)
} else if info.RelayMode == constant.RelayModeEmbeddings { } else if info.RelayMode == constant.RelayModeEmbeddings {

View File

@ -31,9 +31,9 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
return relaycommon.GetFullRequestURL(info.BaseUrl, info.RequestURLPath, info.ChannelType), nil return relaycommon.GetFullRequestURL(info.BaseUrl, info.RequestURLPath, info.ChannelType), nil
} }
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error { func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
channel.SetupApiRequestHeader(info, c, req) channel.SetupApiRequestHeader(info, c, req)
req.Header.Set("Authorization", "Bearer "+info.ApiKey) req.Set("Authorization", "Bearer "+info.ApiKey)
return nil return nil
} }
@ -50,11 +50,11 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
return nil, nil return nil, nil
} }
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) { func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
return channel.DoApiRequest(a, c, info, requestBody) return channel.DoApiRequest(a, c, info, requestBody)
} }
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) { func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
if info.IsStream { if info.IsStream {
err, usage = openai.OaiStreamHandler(c, resp, info) err, usage = openai.OaiStreamHandler(c, resp, info)
} else { } else {

View File

@ -37,7 +37,7 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
} }
} }
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error { func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
channel.SetupApiRequestHeader(info, c, req) channel.SetupApiRequestHeader(info, c, req)
return nil return nil
} }
@ -58,11 +58,11 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
return nil, nil return nil, nil
} }
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) { func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
return channel.DoApiRequest(a, c, info, requestBody) return channel.DoApiRequest(a, c, info, requestBody)
} }
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) { func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
if info.IsStream { if info.IsStream {
err, usage = openai.OaiStreamHandler(c, resp, info) err, usage = openai.OaiStreamHandler(c, resp, info)
} else { } else {

View File

@ -31,6 +31,13 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
} }
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
if info.RelayMode == constant.RelayModeRealtime {
// trim https
baseUrl := strings.TrimPrefix(info.BaseUrl, "https://")
baseUrl = strings.TrimPrefix(baseUrl, "http://")
baseUrl = "wss://" + baseUrl
info.BaseUrl = baseUrl
}
switch info.ChannelType { switch info.ChannelType {
case common.ChannelTypeAzure: case common.ChannelTypeAzure:
// https://learn.microsoft.com/en-us/azure/cognitive-services/openai/chatgpt-quickstart?pivots=rest-api&tabs=command-line#rest-api // https://learn.microsoft.com/en-us/azure/cognitive-services/openai/chatgpt-quickstart?pivots=rest-api&tabs=command-line#rest-api
@ -40,8 +47,10 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
model_ := info.UpstreamModelName model_ := info.UpstreamModelName
model_ = strings.Replace(model_, ".", "", -1) model_ = strings.Replace(model_, ".", "", -1)
// https://github.com/songquanpeng/one-api/issues/67 // https://github.com/songquanpeng/one-api/issues/67
requestURL = fmt.Sprintf("/openai/deployments/%s/%s", model_, task) requestURL = fmt.Sprintf("/openai/deployments/%s/%s", model_, task)
if info.RelayMode == constant.RelayModeRealtime {
requestURL = fmt.Sprintf("/openai/realtime?deployment=%s&api-version=%s", model_, info.ApiVersion)
}
return relaycommon.GetFullRequestURL(info.BaseUrl, requestURL, info.ChannelType), nil return relaycommon.GetFullRequestURL(info.BaseUrl, requestURL, info.ChannelType), nil
case common.ChannelTypeMiniMax: case common.ChannelTypeMiniMax:
return minimax.GetRequestURL(info) return minimax.GetRequestURL(info)
@ -54,16 +63,34 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
} }
} }
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error { func (a *Adaptor) SetupRequestHeader(c *gin.Context, header *http.Header, info *relaycommon.RelayInfo) error {
channel.SetupApiRequestHeader(info, c, req) channel.SetupApiRequestHeader(info, c, header)
if info.ChannelType == common.ChannelTypeAzure { if info.ChannelType == common.ChannelTypeAzure {
req.Header.Set("api-key", info.ApiKey) header.Set("api-key", info.ApiKey)
return nil return nil
} }
if info.ChannelType == common.ChannelTypeOpenAI && "" != info.Organization { if info.ChannelType == common.ChannelTypeOpenAI && "" != info.Organization {
req.Header.Set("OpenAI-Organization", info.Organization) header.Set("OpenAI-Organization", info.Organization)
}
if info.RelayMode == constant.RelayModeRealtime {
swp := c.Request.Header.Get("Sec-WebSocket-Protocol")
if swp != "" {
items := []string{
"realtime",
"openai-insecure-api-key." + info.ApiKey,
"openai-beta.realtime-v1",
}
header.Set("Sec-WebSocket-Protocol", strings.Join(items, ","))
//req.Header.Set("Sec-WebSocket-Key", c.Request.Header.Get("Sec-WebSocket-Key"))
//req.Header.Set("Sec-Websocket-Extensions", c.Request.Header.Get("Sec-Websocket-Extensions"))
//req.Header.Set("Sec-Websocket-Version", c.Request.Header.Get("Sec-Websocket-Version"))
} else {
header.Set("openai-beta", "realtime=v1")
header.Set("Authorization", "Bearer "+info.ApiKey)
}
} else {
header.Set("Authorization", "Bearer "+info.ApiKey)
} }
req.Header.Set("Authorization", "Bearer "+info.ApiKey)
//if info.ChannelType == common.ChannelTypeOpenRouter { //if info.ChannelType == common.ChannelTypeOpenRouter {
// req.Header.Set("HTTP-Referer", "https://github.com/songquanpeng/one-api") // req.Header.Set("HTTP-Referer", "https://github.com/songquanpeng/one-api")
// req.Header.Set("X-Title", "One API") // req.Header.Set("X-Title", "One API")
@ -131,16 +158,20 @@ func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInf
return request, nil return request, nil
} }
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) { func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
if info.RelayMode == constant.RelayModeAudioTranscription || info.RelayMode == constant.RelayModeAudioTranslation { if info.RelayMode == constant.RelayModeAudioTranscription || info.RelayMode == constant.RelayModeAudioTranslation {
return channel.DoFormRequest(a, c, info, requestBody) return channel.DoFormRequest(a, c, info, requestBody)
} else if info.RelayMode == constant.RelayModeRealtime {
return channel.DoWssRequest(a, c, info, requestBody)
} else { } else {
return channel.DoApiRequest(a, c, info, requestBody) return channel.DoApiRequest(a, c, info, requestBody)
} }
} }
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) { func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
switch info.RelayMode { switch info.RelayMode {
case constant.RelayModeRealtime:
err, usage = OpenaiRealtimeHandler(c, info)
case constant.RelayModeAudioSpeech: case constant.RelayModeAudioSpeech:
err, usage = OpenaiTTSHandler(c, resp, info) err, usage = OpenaiTTSHandler(c, resp, info)
case constant.RelayModeAudioTranslation: case constant.RelayModeAudioTranslation:

View File

@ -7,6 +7,7 @@ import (
"fmt" "fmt"
"github.com/bytedance/gopkg/util/gopool" "github.com/bytedance/gopkg/util/gopool"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/gorilla/websocket"
"io" "io"
"net/http" "net/http"
"one-api/common" "one-api/common"
@ -231,7 +232,7 @@ func OpenaiHandler(c *gin.Context, resp *http.Response, promptTokens int, model
if simpleResponse.Usage.TotalTokens == 0 || (simpleResponse.Usage.PromptTokens == 0 && simpleResponse.Usage.CompletionTokens == 0) { if simpleResponse.Usage.TotalTokens == 0 || (simpleResponse.Usage.PromptTokens == 0 && simpleResponse.Usage.CompletionTokens == 0) {
completionTokens := 0 completionTokens := 0
for _, choice := range simpleResponse.Choices { for _, choice := range simpleResponse.Choices {
ctkm, _ := service.CountTokenText(string(choice.Message.Content), model) ctkm, _ := service.CountTextToken(string(choice.Message.Content), model)
completionTokens += ctkm completionTokens += ctkm
} }
simpleResponse.Usage = dto.Usage{ simpleResponse.Usage = dto.Usage{
@ -324,7 +325,7 @@ func OpenaiSTTHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
usage := &dto.Usage{} usage := &dto.Usage{}
usage.PromptTokens = info.PromptTokens usage.PromptTokens = info.PromptTokens
usage.CompletionTokens, _ = service.CountTokenText(text, info.UpstreamModelName) usage.CompletionTokens, _ = service.CountTextToken(text, info.UpstreamModelName)
usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
return nil, usage return nil, usage
} }
@ -373,3 +374,210 @@ func getTextFromJSON(body []byte) (string, error) {
} }
return whisperResponse.Text, nil return whisperResponse.Text, nil
} }
func OpenaiRealtimeHandler(c *gin.Context, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.RealtimeUsage) {
info.IsStream = true
clientConn := info.ClientWs
targetConn := info.TargetWs
clientClosed := make(chan struct{})
targetClosed := make(chan struct{})
sendChan := make(chan []byte, 100)
receiveChan := make(chan []byte, 100)
errChan := make(chan error, 2)
usage := &dto.RealtimeUsage{}
localUsage := &dto.RealtimeUsage{}
sumUsage := &dto.RealtimeUsage{}
gopool.Go(func() {
for {
select {
case <-c.Done():
return
default:
_, message, err := clientConn.ReadMessage()
if err != nil {
if !websocket.IsCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway) {
errChan <- fmt.Errorf("error reading from client: %v", err)
}
close(clientClosed)
return
}
realtimeEvent := &dto.RealtimeEvent{}
err = json.Unmarshal(message, realtimeEvent)
if err != nil {
errChan <- fmt.Errorf("error unmarshalling message: %v", err)
return
}
if realtimeEvent.Type == dto.RealtimeEventTypeSessionUpdate {
if realtimeEvent.Session != nil {
if realtimeEvent.Session.Tools != nil {
info.RealtimeTools = realtimeEvent.Session.Tools
}
}
}
textToken, audioToken, err := service.CountTokenRealtime(info, *realtimeEvent, info.UpstreamModelName)
if err != nil {
errChan <- fmt.Errorf("error counting text token: %v", err)
return
}
common.LogInfo(c, fmt.Sprintf("type: %s, textToken: %d, audioToken: %d", realtimeEvent.Type, textToken, audioToken))
localUsage.TotalTokens += textToken + audioToken
localUsage.InputTokens += textToken + audioToken
localUsage.InputTokenDetails.TextTokens += textToken
localUsage.InputTokenDetails.AudioTokens += audioToken
err = service.WssString(c, targetConn, string(message))
if err != nil {
errChan <- fmt.Errorf("error writing to target: %v", err)
return
}
select {
case sendChan <- message:
default:
}
}
}
})
gopool.Go(func() {
for {
select {
case <-c.Done():
return
default:
_, message, err := targetConn.ReadMessage()
if err != nil {
if !websocket.IsCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway) {
errChan <- fmt.Errorf("error reading from target: %v", err)
}
close(targetClosed)
return
}
info.SetFirstResponseTime()
realtimeEvent := &dto.RealtimeEvent{}
err = json.Unmarshal(message, realtimeEvent)
if err != nil {
errChan <- fmt.Errorf("error unmarshalling message: %v", err)
return
}
if realtimeEvent.Type == dto.RealtimeEventTypeResponseDone {
realtimeUsage := realtimeEvent.Response.Usage
if realtimeUsage != nil {
usage.TotalTokens += realtimeUsage.TotalTokens
usage.InputTokens += realtimeUsage.InputTokens
usage.OutputTokens += realtimeUsage.OutputTokens
usage.InputTokenDetails.AudioTokens += realtimeUsage.InputTokenDetails.AudioTokens
usage.InputTokenDetails.CachedTokens += realtimeUsage.InputTokenDetails.CachedTokens
usage.InputTokenDetails.TextTokens += realtimeUsage.InputTokenDetails.TextTokens
usage.OutputTokenDetails.AudioTokens += realtimeUsage.OutputTokenDetails.AudioTokens
usage.OutputTokenDetails.TextTokens += realtimeUsage.OutputTokenDetails.TextTokens
err := preConsumeUsage(c, info, usage, sumUsage)
if err != nil {
errChan <- fmt.Errorf("error consume usage: %v", err)
return
}
// 本次计费完成,清除
usage = &dto.RealtimeUsage{}
localUsage = &dto.RealtimeUsage{}
} else {
textToken, audioToken, err := service.CountTokenRealtime(info, *realtimeEvent, info.UpstreamModelName)
if err != nil {
errChan <- fmt.Errorf("error counting text token: %v", err)
return
}
common.LogInfo(c, fmt.Sprintf("type: %s, textToken: %d, audioToken: %d", realtimeEvent.Type, textToken, audioToken))
localUsage.TotalTokens += textToken + audioToken
info.IsFirstRequest = false
localUsage.InputTokens += textToken + audioToken
localUsage.InputTokenDetails.TextTokens += textToken
localUsage.InputTokenDetails.AudioTokens += audioToken
err = preConsumeUsage(c, info, localUsage, sumUsage)
if err != nil {
errChan <- fmt.Errorf("error consume usage: %v", err)
return
}
// 本次计费完成,清除
localUsage = &dto.RealtimeUsage{}
// print now usage
}
//common.LogInfo(c, fmt.Sprintf("realtime streaming sumUsage: %v", sumUsage))
//common.LogInfo(c, fmt.Sprintf("realtime streaming localUsage: %v", localUsage))
//common.LogInfo(c, fmt.Sprintf("realtime streaming localUsage: %v", localUsage))
} else if realtimeEvent.Type == dto.RealtimeEventTypeSessionUpdated || realtimeEvent.Type == dto.RealtimeEventTypeSessionCreated {
realtimeSession := realtimeEvent.Session
if realtimeSession != nil {
// update audio format
info.InputAudioFormat = common.GetStringIfEmpty(realtimeSession.InputAudioFormat, info.InputAudioFormat)
info.OutputAudioFormat = common.GetStringIfEmpty(realtimeSession.OutputAudioFormat, info.OutputAudioFormat)
}
} else {
textToken, audioToken, err := service.CountTokenRealtime(info, *realtimeEvent, info.UpstreamModelName)
if err != nil {
errChan <- fmt.Errorf("error counting text token: %v", err)
return
}
common.LogInfo(c, fmt.Sprintf("type: %s, textToken: %d, audioToken: %d", realtimeEvent.Type, textToken, audioToken))
localUsage.TotalTokens += textToken + audioToken
localUsage.OutputTokens += textToken + audioToken
localUsage.OutputTokenDetails.TextTokens += textToken
localUsage.OutputTokenDetails.AudioTokens += audioToken
}
err = service.WssString(c, clientConn, string(message))
if err != nil {
errChan <- fmt.Errorf("error writing to client: %v", err)
return
}
select {
case receiveChan <- message:
default:
}
}
}
})
select {
case <-clientClosed:
case <-targetClosed:
case err := <-errChan:
//return service.OpenAIErrorWrapper(err, "realtime_error", http.StatusInternalServerError), nil
common.LogError(c, "realtime error: "+err.Error())
case <-c.Done():
}
if usage.TotalTokens != 0 {
_ = preConsumeUsage(c, info, usage, sumUsage)
}
if localUsage.TotalTokens != 0 {
_ = preConsumeUsage(c, info, localUsage, sumUsage)
}
// check usage total tokens, if 0, use local usage
return nil, sumUsage
}
func preConsumeUsage(ctx *gin.Context, info *relaycommon.RelayInfo, usage *dto.RealtimeUsage, totalUsage *dto.RealtimeUsage) error {
totalUsage.TotalTokens += usage.TotalTokens
totalUsage.InputTokens += usage.InputTokens
totalUsage.OutputTokens += usage.OutputTokens
totalUsage.InputTokenDetails.CachedTokens += usage.InputTokenDetails.CachedTokens
totalUsage.InputTokenDetails.TextTokens += usage.InputTokenDetails.TextTokens
totalUsage.InputTokenDetails.AudioTokens += usage.InputTokenDetails.AudioTokens
totalUsage.OutputTokenDetails.TextTokens += usage.OutputTokenDetails.TextTokens
totalUsage.OutputTokenDetails.AudioTokens += usage.OutputTokenDetails.AudioTokens
// clear usage
err := service.PreWssConsumeQuota(ctx, info, usage)
return err
}

View File

@ -32,9 +32,9 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
return fmt.Sprintf("%s/v1beta2/models/chat-bison-001:generateMessage", info.BaseUrl), nil return fmt.Sprintf("%s/v1beta2/models/chat-bison-001:generateMessage", info.BaseUrl), nil
} }
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error { func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
channel.SetupApiRequestHeader(info, c, req) channel.SetupApiRequestHeader(info, c, req)
req.Header.Set("x-goog-api-key", info.ApiKey) req.Set("x-goog-api-key", info.ApiKey)
return nil return nil
} }
@ -49,11 +49,11 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
return nil, nil return nil, nil
} }
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) { func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
return channel.DoApiRequest(a, c, info, requestBody) return channel.DoApiRequest(a, c, info, requestBody)
} }
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) { func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
if info.IsStream { if info.IsStream {
var responseText string var responseText string
err, responseText = palmStreamHandler(c, resp) err, responseText = palmStreamHandler(c, resp)

View File

@ -156,7 +156,7 @@ func palmHandler(c *gin.Context, resp *http.Response, promptTokens int, model st
}, nil }, nil
} }
fullTextResponse := responsePaLM2OpenAI(&palmResponse) fullTextResponse := responsePaLM2OpenAI(&palmResponse)
completionTokens, _ := service.CountTokenText(palmResponse.Candidates[0].Content, model) completionTokens, _ := service.CountTextToken(palmResponse.Candidates[0].Content, model)
usage := dto.Usage{ usage := dto.Usage{
PromptTokens: promptTokens, PromptTokens: promptTokens,
CompletionTokens: completionTokens, CompletionTokens: completionTokens,

View File

@ -32,9 +32,9 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
return fmt.Sprintf("%s/chat/completions", info.BaseUrl), nil return fmt.Sprintf("%s/chat/completions", info.BaseUrl), nil
} }
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error { func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
channel.SetupApiRequestHeader(info, c, req) channel.SetupApiRequestHeader(info, c, req)
req.Header.Set("Authorization", "Bearer "+info.ApiKey) req.Set("Authorization", "Bearer "+info.ApiKey)
return nil return nil
} }
@ -52,11 +52,11 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
return nil, nil return nil, nil
} }
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) { func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
return channel.DoApiRequest(a, c, info, requestBody) return channel.DoApiRequest(a, c, info, requestBody)
} }
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) { func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
if info.IsStream { if info.IsStream {
err, usage = openai.OaiStreamHandler(c, resp, info) err, usage = openai.OaiStreamHandler(c, resp, info)
} else { } else {

View File

@ -40,9 +40,9 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
return "", errors.New("invalid relay mode") return "", errors.New("invalid relay mode")
} }
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error { func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
channel.SetupApiRequestHeader(info, c, req) channel.SetupApiRequestHeader(info, c, req)
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", info.ApiKey)) req.Set("Authorization", fmt.Sprintf("Bearer %s", info.ApiKey))
return nil return nil
} }
@ -50,7 +50,7 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, re
return request, nil return request, nil
} }
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) { func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
return channel.DoApiRequest(a, c, info, requestBody) return channel.DoApiRequest(a, c, info, requestBody)
} }
@ -58,7 +58,7 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
return request, nil return request, nil
} }
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) { func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
switch info.RelayMode { switch info.RelayMode {
case constant.RelayModeRerank: case constant.RelayModeRerank:
err, usage = siliconflowRerankHandler(c, resp) err, usage = siliconflowRerankHandler(c, resp)

View File

@ -43,12 +43,12 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
return fmt.Sprintf("%s/", info.BaseUrl), nil return fmt.Sprintf("%s/", info.BaseUrl), nil
} }
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error { func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
channel.SetupApiRequestHeader(info, c, req) channel.SetupApiRequestHeader(info, c, req)
req.Header.Set("Authorization", a.Sign) req.Set("Authorization", a.Sign)
req.Header.Set("X-TC-Action", a.Action) req.Set("X-TC-Action", a.Action)
req.Header.Set("X-TC-Version", a.Version) req.Set("X-TC-Version", a.Version)
req.Header.Set("X-TC-Timestamp", strconv.FormatInt(a.Timestamp, 10)) req.Set("X-TC-Timestamp", strconv.FormatInt(a.Timestamp, 10))
return nil return nil
} }
@ -73,11 +73,11 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
return nil, nil return nil, nil
} }
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) { func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
return channel.DoApiRequest(a, c, info, requestBody) return channel.DoApiRequest(a, c, info, requestBody)
} }
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) { func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
if info.IsStream { if info.IsStream {
var responseText string var responseText string
err, responseText = tencentStreamHandler(c, resp) err, responseText = tencentStreamHandler(c, resp)

View File

@ -107,13 +107,13 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
return "", errors.New("unsupported request mode") return "", errors.New("unsupported request mode")
} }
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error { func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
channel.SetupApiRequestHeader(info, c, req) channel.SetupApiRequestHeader(info, c, req)
accessToken, err := getAccessToken(a, info) accessToken, err := getAccessToken(a, info)
if err != nil { if err != nil {
return err return err
} }
req.Header.Set("Authorization", "Bearer "+accessToken) req.Set("Authorization", "Bearer "+accessToken)
return nil return nil
} }
@ -148,11 +148,11 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
return nil, nil return nil, nil
} }
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) { func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
return channel.DoApiRequest(a, c, info, requestBody) return channel.DoApiRequest(a, c, info, requestBody)
} }
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) { func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
if info.IsStream { if info.IsStream {
switch a.RequestMode { switch a.RequestMode {
case RequestModeClaude: case RequestModeClaude:

View File

@ -33,7 +33,7 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
return "", nil return "", nil
} }
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error { func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
channel.SetupApiRequestHeader(info, c, req) channel.SetupApiRequestHeader(info, c, req)
return nil return nil
} }
@ -50,14 +50,14 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
return nil, nil return nil, nil
} }
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) { func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
// xunfei's request is not http request, so we don't need to do anything here // xunfei's request is not http request, so we don't need to do anything here
dummyResp := &http.Response{} dummyResp := &http.Response{}
dummyResp.StatusCode = http.StatusOK dummyResp.StatusCode = http.StatusOK
return dummyResp, nil return dummyResp, nil
} }
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) { func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
splits := strings.Split(info.ApiKey, "|") splits := strings.Split(info.ApiKey, "|")
if len(splits) != 3 { if len(splits) != 3 {
return nil, service.OpenAIErrorWrapper(errors.New("invalid auth"), "invalid_auth", http.StatusBadRequest) return nil, service.OpenAIErrorWrapper(errors.New("invalid auth"), "invalid_auth", http.StatusBadRequest)

View File

@ -35,10 +35,10 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
return fmt.Sprintf("%s/api/paas/v3/model-api/%s/%s", info.BaseUrl, info.UpstreamModelName, method), nil return fmt.Sprintf("%s/api/paas/v3/model-api/%s/%s", info.BaseUrl, info.UpstreamModelName, method), nil
} }
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error { func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
channel.SetupApiRequestHeader(info, c, req) channel.SetupApiRequestHeader(info, c, req)
token := getZhipuToken(info.ApiKey) token := getZhipuToken(info.ApiKey)
req.Header.Set("Authorization", token) req.Set("Authorization", token)
return nil return nil
} }
@ -56,11 +56,11 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
return nil, nil return nil, nil
} }
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) { func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
return channel.DoApiRequest(a, c, info, requestBody) return channel.DoApiRequest(a, c, info, requestBody)
} }
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) { func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
if info.IsStream { if info.IsStream {
err, usage = zhipuStreamHandler(c, resp) err, usage = zhipuStreamHandler(c, resp)
} else { } else {

View File

@ -32,10 +32,10 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
return fmt.Sprintf("%s/api/paas/v4/chat/completions", info.BaseUrl), nil return fmt.Sprintf("%s/api/paas/v4/chat/completions", info.BaseUrl), nil
} }
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error { func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
channel.SetupApiRequestHeader(info, c, req) channel.SetupApiRequestHeader(info, c, req)
token := getZhipuToken(info.ApiKey) token := getZhipuToken(info.ApiKey)
req.Header.Set("Authorization", token) req.Set("Authorization", token)
return nil return nil
} }
@ -53,11 +53,11 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
return nil, nil return nil, nil
} }
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) { func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
return channel.DoApiRequest(a, c, info, requestBody) return channel.DoApiRequest(a, c, info, requestBody)
} }
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) { func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
if info.IsStream { if info.IsStream {
err, usage = openai.OaiStreamHandler(c, resp, info) err, usage = openai.OaiStreamHandler(c, resp, info)
} else { } else {

View File

@ -2,7 +2,9 @@ package common
import ( import (
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/gorilla/websocket"
"one-api/common" "one-api/common"
"one-api/dto"
"one-api/relay/constant" "one-api/relay/constant"
"strings" "strings"
"time" "time"
@ -21,6 +23,7 @@ type RelayInfo struct {
ApiType int ApiType int
IsStream bool IsStream bool
IsPlayground bool IsPlayground bool
UsePrice bool
RelayMode int RelayMode int
UpstreamModelName string UpstreamModelName string
OriginModelName string OriginModelName string
@ -32,6 +35,21 @@ type RelayInfo struct {
BaseUrl string BaseUrl string
SupportStreamOptions bool SupportStreamOptions bool
ShouldIncludeUsage bool ShouldIncludeUsage bool
ClientWs *websocket.Conn
TargetWs *websocket.Conn
InputAudioFormat string
OutputAudioFormat string
RealtimeTools []dto.RealTimeTool
IsFirstRequest bool
}
func GenRelayInfoWs(c *gin.Context, ws *websocket.Conn) *RelayInfo {
info := GenRelayInfo(c)
info.ClientWs = ws
info.InputAudioFormat = "pcm16"
info.OutputAudioFormat = "pcm16"
info.IsFirstRequest = true
return info
} }
func GenRelayInfo(c *gin.Context) *RelayInfo { func GenRelayInfo(c *gin.Context) *RelayInfo {

View File

@ -38,6 +38,8 @@ const (
RelayModeSunoSubmit RelayModeSunoSubmit
RelayModeRerank RelayModeRerank
RelayModeRealtime
) )
func Path2RelayMode(path string) int { func Path2RelayMode(path string) int {
@ -64,6 +66,8 @@ func Path2RelayMode(path string) int {
relayMode = RelayModeAudioTranslation relayMode = RelayModeAudioTranslation
} else if strings.HasPrefix(path, "/v1/rerank") { } else if strings.HasPrefix(path, "/v1/rerank") {
relayMode = RelayModeRerank relayMode = RelayModeRerank
} else if strings.HasPrefix(path, "/v1/realtime") {
relayMode = RelayModeRealtime
} }
return relayMode return relayMode
} }

View File

@ -46,7 +46,7 @@ func getAndValidAudioRequest(c *gin.Context, info *relaycommon.RelayInfo) (*dto.
return audioRequest, nil return audioRequest, nil
} }
func AudioHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode { func AudioHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
relayInfo := relaycommon.GenRelayInfo(c) relayInfo := relaycommon.GenRelayInfo(c)
audioRequest, err := getAndValidAudioRequest(c, relayInfo) audioRequest, err := getAndValidAudioRequest(c, relayInfo)
@ -58,7 +58,7 @@ func AudioHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
promptTokens := 0 promptTokens := 0
preConsumedTokens := common.PreConsumedQuota preConsumedTokens := common.PreConsumedQuota
if relayInfo.RelayMode == relayconstant.RelayModeAudioSpeech { if relayInfo.RelayMode == relayconstant.RelayModeAudioSpeech {
promptTokens, err = service.CountAudioToken(audioRequest.Input, audioRequest.Model) promptTokens, err = service.CountTTSToken(audioRequest.Input, audioRequest.Model)
if err != nil { if err != nil {
return service.OpenAIErrorWrapper(err, "count_audio_token_failed", http.StatusInternalServerError) return service.OpenAIErrorWrapper(err, "count_audio_token_failed", http.StatusInternalServerError)
} }
@ -92,6 +92,11 @@ func AudioHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
return service.OpenAIErrorWrapperLocal(err, "pre_consume_token_quota_failed", http.StatusForbidden) return service.OpenAIErrorWrapperLocal(err, "pre_consume_token_quota_failed", http.StatusForbidden)
} }
} }
defer func() {
if openaiErr != nil {
returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota)
}
}()
// map model name // map model name
modelMapping := c.GetString("model_mapping") modelMapping := c.GetString("model_mapping")
@ -122,19 +127,20 @@ func AudioHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
if err != nil { if err != nil {
return service.OpenAIErrorWrapper(err, "do_request_failed", http.StatusInternalServerError) return service.OpenAIErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)
} }
statusCodeMappingStr := c.GetString("status_code_mapping") statusCodeMappingStr := c.GetString("status_code_mapping")
var httpResp *http.Response
if resp != nil { if resp != nil {
if resp.StatusCode != http.StatusOK { httpResp = resp.(*http.Response)
returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota) if httpResp.StatusCode != http.StatusOK {
openaiErr := service.RelayErrorHandler(resp) openaiErr = service.RelayErrorHandler(httpResp)
// reset status code 重置状态码 // reset status code 重置状态码
service.ResetStatusCode(openaiErr, statusCodeMappingStr) service.ResetStatusCode(openaiErr, statusCodeMappingStr)
return openaiErr return openaiErr
} }
} }
usage, openaiErr := adaptor.DoResponse(c, resp, relayInfo) usage, openaiErr := adaptor.DoResponse(c, httpResp, relayInfo)
if openaiErr != nil { if openaiErr != nil {
returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota) returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota)
// reset status code 重置状态码 // reset status code 重置状态码
@ -142,7 +148,7 @@ func AudioHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
return openaiErr return openaiErr
} }
postConsumeQuota(c, relayInfo, audioRequest.Model, usage, ratio, preConsumedQuota, userQuota, modelRatio, groupRatio, 0, false, "") postConsumeQuota(c, relayInfo, audioRequest.Model, usage.(*dto.Usage), ratio, preConsumedQuota, userQuota, modelRatio, groupRatio, 0, false, "")
return nil return nil
} }

View File

@ -149,22 +149,24 @@ func ImageHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode {
requestBody = bytes.NewBuffer(jsonData) requestBody = bytes.NewBuffer(jsonData)
statusCodeMappingStr := c.GetString("status_code_mapping") statusCodeMappingStr := c.GetString("status_code_mapping")
resp, err := adaptor.DoRequest(c, relayInfo, requestBody) resp, err := adaptor.DoRequest(c, relayInfo, requestBody)
if err != nil { if err != nil {
return service.OpenAIErrorWrapper(err, "do_request_failed", http.StatusInternalServerError) return service.OpenAIErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)
} }
var httpResp *http.Response
if resp != nil { if resp != nil {
relayInfo.IsStream = relayInfo.IsStream || strings.HasPrefix(resp.Header.Get("Content-Type"), "text/event-stream") httpResp = resp.(*http.Response)
if resp.StatusCode != http.StatusOK { relayInfo.IsStream = relayInfo.IsStream || strings.HasPrefix(httpResp.Header.Get("Content-Type"), "text/event-stream")
openaiErr := service.RelayErrorHandler(resp) if httpResp.StatusCode != http.StatusOK {
openaiErr := service.RelayErrorHandler(httpResp)
// reset status code 重置状态码 // reset status code 重置状态码
service.ResetStatusCode(openaiErr, statusCodeMappingStr) service.ResetStatusCode(openaiErr, statusCodeMappingStr)
return openaiErr return openaiErr
} }
} }
_, openaiErr := adaptor.DoResponse(c, resp, relayInfo) _, openaiErr := adaptor.DoResponse(c, httpResp, relayInfo)
if openaiErr != nil { if openaiErr != nil {
// reset status code 重置状态码 // reset status code 重置状态码
service.ResetStatusCode(openaiErr, statusCodeMappingStr) service.ResetStatusCode(openaiErr, statusCodeMappingStr)

View File

@ -64,7 +64,7 @@ func getAndValidateTextRequest(c *gin.Context, relayInfo *relaycommon.RelayInfo)
return textRequest, nil return textRequest, nil
} }
func TextHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode { func TextHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
relayInfo := relaycommon.GenRelayInfo(c) relayInfo := relaycommon.GenRelayInfo(c)
@ -131,7 +131,11 @@ func TextHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
if openaiErr != nil { if openaiErr != nil {
return openaiErr return openaiErr
} }
defer func() {
if openaiErr != nil {
returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota)
}
}()
includeUsage := false includeUsage := false
// 判断用户是否需要返回使用情况 // 判断用户是否需要返回使用情况
if textRequest.StreamOptions != nil && textRequest.StreamOptions.IncludeUsage { if textRequest.StreamOptions != nil && textRequest.StreamOptions.IncludeUsage {
@ -180,30 +184,30 @@ func TextHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
} }
statusCodeMappingStr := c.GetString("status_code_mapping") statusCodeMappingStr := c.GetString("status_code_mapping")
var httpResp *http.Response
resp, err := adaptor.DoRequest(c, relayInfo, requestBody) resp, err := adaptor.DoRequest(c, relayInfo, requestBody)
if err != nil { if err != nil {
return service.OpenAIErrorWrapper(err, "do_request_failed", http.StatusInternalServerError) return service.OpenAIErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)
} }
if resp != nil { if resp != nil {
relayInfo.IsStream = relayInfo.IsStream || strings.HasPrefix(resp.Header.Get("Content-Type"), "text/event-stream") httpResp = resp.(*http.Response)
if resp.StatusCode != http.StatusOK { relayInfo.IsStream = relayInfo.IsStream || strings.HasPrefix(httpResp.Header.Get("Content-Type"), "text/event-stream")
returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota) if httpResp.StatusCode != http.StatusOK {
openaiErr := service.RelayErrorHandler(resp) openaiErr = service.RelayErrorHandler(httpResp)
// reset status code 重置状态码 // reset status code 重置状态码
service.ResetStatusCode(openaiErr, statusCodeMappingStr) service.ResetStatusCode(openaiErr, statusCodeMappingStr)
return openaiErr return openaiErr
} }
} }
usage, openaiErr := adaptor.DoResponse(c, resp, relayInfo) usage, openaiErr := adaptor.DoResponse(c, httpResp, relayInfo)
if openaiErr != nil { if openaiErr != nil {
returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota)
// reset status code 重置状态码 // reset status code 重置状态码
service.ResetStatusCode(openaiErr, statusCodeMappingStr) service.ResetStatusCode(openaiErr, statusCodeMappingStr)
return openaiErr return openaiErr
} }
postConsumeQuota(c, relayInfo, textRequest.Model, usage, ratio, preConsumedQuota, userQuota, modelRatio, groupRatio, modelPrice, getModelPriceSuccess, "") postConsumeQuota(c, relayInfo, textRequest.Model, usage.(*dto.Usage), ratio, preConsumedQuota, userQuota, modelRatio, groupRatio, modelPrice, getModelPriceSuccess, "")
return nil return nil
} }

View File

@ -23,7 +23,7 @@ func getRerankPromptToken(rerankRequest dto.RerankRequest) int {
return token return token
} }
func RerankHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode { func RerankHelper(c *gin.Context, relayMode int) (openaiErr *dto.OpenAIErrorWithStatusCode) {
relayInfo := relaycommon.GenRelayInfo(c) relayInfo := relaycommon.GenRelayInfo(c)
var rerankRequest *dto.RerankRequest var rerankRequest *dto.RerankRequest
@ -79,6 +79,12 @@ func RerankHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode
if openaiErr != nil { if openaiErr != nil {
return openaiErr return openaiErr
} }
defer func() {
if openaiErr != nil {
returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota)
}
}()
adaptor := GetAdaptor(relayInfo.ApiType) adaptor := GetAdaptor(relayInfo.ApiType)
if adaptor == nil { if adaptor == nil {
return service.OpenAIErrorWrapperLocal(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), "invalid_api_type", http.StatusBadRequest) return service.OpenAIErrorWrapperLocal(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), "invalid_api_type", http.StatusBadRequest)
@ -99,23 +105,24 @@ func RerankHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode
if err != nil { if err != nil {
return service.OpenAIErrorWrapper(err, "do_request_failed", http.StatusInternalServerError) return service.OpenAIErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)
} }
var httpResp *http.Response
if resp != nil { if resp != nil {
if resp.StatusCode != http.StatusOK { httpResp = resp.(*http.Response)
returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota) if httpResp.StatusCode != http.StatusOK {
openaiErr := service.RelayErrorHandler(resp) openaiErr = service.RelayErrorHandler(httpResp)
// reset status code 重置状态码 // reset status code 重置状态码
service.ResetStatusCode(openaiErr, statusCodeMappingStr) service.ResetStatusCode(openaiErr, statusCodeMappingStr)
return openaiErr return openaiErr
} }
} }
usage, openaiErr := adaptor.DoResponse(c, resp, relayInfo) usage, openaiErr := adaptor.DoResponse(c, httpResp, relayInfo)
if openaiErr != nil { if openaiErr != nil {
returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota)
// reset status code 重置状态码 // reset status code 重置状态码
service.ResetStatusCode(openaiErr, statusCodeMappingStr) service.ResetStatusCode(openaiErr, statusCodeMappingStr)
return openaiErr return openaiErr
} }
postConsumeQuota(c, relayInfo, rerankRequest.Model, usage, ratio, preConsumedQuota, userQuota, modelRatio, groupRatio, modelPrice, success, "") postConsumeQuota(c, relayInfo, rerankRequest.Model, usage.(*dto.Usage), ratio, preConsumedQuota, userQuota, modelRatio, groupRatio, modelPrice, success, "")
return nil return nil
} }

159
relay/websocket.go Normal file
View File

@ -0,0 +1,159 @@
package relay
import (
"encoding/json"
"fmt"
"github.com/gin-gonic/gin"
"github.com/gorilla/websocket"
"net/http"
"one-api/common"
"one-api/dto"
relaycommon "one-api/relay/common"
"one-api/service"
)
//func getAndValidateWssRequest(c *gin.Context, ws *websocket.Conn) (*dto.RealtimeEvent, error) {
// _, p, err := ws.ReadMessage()
// if err != nil {
// return nil, err
// }
// realtimeEvent := &dto.RealtimeEvent{}
// err = json.Unmarshal(p, realtimeEvent)
// if err != nil {
// return nil, err
// }
// // save the original request
// if realtimeEvent.Session == nil {
// return nil, errors.New("session object is nil")
// }
// c.Set("first_wss_request", p)
// return realtimeEvent, nil
//}
func WssHelper(c *gin.Context, ws *websocket.Conn) (openaiErr *dto.OpenAIErrorWithStatusCode) {
relayInfo := relaycommon.GenRelayInfoWs(c, ws)
// get & validate textRequest 获取并验证文本请求
//realtimeEvent, err := getAndValidateWssRequest(c, ws)
//if err != nil {
// common.LogError(c, fmt.Sprintf("getAndValidateWssRequest failed: %s", err.Error()))
// return service.OpenAIErrorWrapperLocal(err, "invalid_text_request", http.StatusBadRequest)
//}
// map model name
modelMapping := c.GetString("model_mapping")
//isModelMapped := false
if modelMapping != "" && modelMapping != "{}" {
modelMap := make(map[string]string)
err := json.Unmarshal([]byte(modelMapping), &modelMap)
if err != nil {
return service.OpenAIErrorWrapperLocal(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError)
}
if modelMap[relayInfo.OriginModelName] != "" {
relayInfo.UpstreamModelName = modelMap[relayInfo.OriginModelName]
// set upstream model name
//isModelMapped = true
}
}
//relayInfo.UpstreamModelName = textRequest.Model
modelPrice, getModelPriceSuccess := common.GetModelPrice(relayInfo.UpstreamModelName, false)
groupRatio := common.GetGroupRatio(relayInfo.Group)
var preConsumedQuota int
var ratio float64
var modelRatio float64
//err := service.SensitiveWordsCheck(textRequest)
//if constant.ShouldCheckPromptSensitive() {
// err = checkRequestSensitive(textRequest, relayInfo)
// if err != nil {
// return service.OpenAIErrorWrapperLocal(err, "sensitive_words_detected", http.StatusBadRequest)
// }
//}
//promptTokens, err := getWssPromptTokens(realtimeEvent, relayInfo)
//// count messages token error 计算promptTokens错误
//if err != nil {
// return service.OpenAIErrorWrapper(err, "count_token_messages_failed", http.StatusInternalServerError)
//}
//
if !getModelPriceSuccess {
preConsumedTokens := common.PreConsumedQuota
//if realtimeEvent.Session.MaxResponseOutputTokens != 0 {
// preConsumedTokens = promptTokens + int(realtimeEvent.Session.MaxResponseOutputTokens)
//}
modelRatio = common.GetModelRatio(relayInfo.UpstreamModelName)
ratio = modelRatio * groupRatio
preConsumedQuota = int(float64(preConsumedTokens) * ratio)
} else {
preConsumedQuota = int(modelPrice * common.QuotaPerUnit * groupRatio)
relayInfo.UsePrice = true
}
// pre-consume quota 预消耗配额
preConsumedQuota, userQuota, openaiErr := preConsumeQuota(c, preConsumedQuota, relayInfo)
if openaiErr != nil {
return openaiErr
}
defer func() {
if openaiErr != nil {
returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota)
}
}()
adaptor := GetAdaptor(relayInfo.ApiType)
if adaptor == nil {
return service.OpenAIErrorWrapperLocal(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), "invalid_api_type", http.StatusBadRequest)
}
adaptor.Init(relayInfo)
//var requestBody io.Reader
//firstWssRequest, _ := c.Get("first_wss_request")
//requestBody = bytes.NewBuffer(firstWssRequest.([]byte))
statusCodeMappingStr := c.GetString("status_code_mapping")
resp, err := adaptor.DoRequest(c, relayInfo, nil)
if err != nil {
return service.OpenAIErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)
}
if resp != nil {
relayInfo.TargetWs = resp.(*websocket.Conn)
defer relayInfo.TargetWs.Close()
}
usage, openaiErr := adaptor.DoResponse(c, nil, relayInfo)
if openaiErr != nil {
// reset status code 重置状态码
service.ResetStatusCode(openaiErr, statusCodeMappingStr)
return openaiErr
}
service.PostWssConsumeQuota(c, relayInfo, relayInfo.UpstreamModelName, usage.(*dto.RealtimeUsage), ratio, preConsumedQuota, userQuota, modelRatio, groupRatio, modelPrice, getModelPriceSuccess, "")
return nil
}
//func getWssPromptTokens(textRequest *dto.RealtimeEvent, info *relaycommon.RelayInfo) (int, error) {
// var promptTokens int
// var err error
// switch info.RelayMode {
// default:
// promptTokens, err = service.CountTokenRealtime(*textRequest, info.UpstreamModelName)
// }
// info.PromptTokens = promptTokens
// return promptTokens, err
//}
//func checkWssRequestSensitive(textRequest *dto.GeneralOpenAIRequest, info *relaycommon.RelayInfo) error {
// var err error
// switch info.RelayMode {
// case relayconstant.RelayModeChatCompletions:
// err = service.CheckSensitiveMessages(textRequest.Messages)
// case relayconstant.RelayModeCompletions:
// err = service.CheckSensitiveInput(textRequest.Prompt)
// case relayconstant.RelayModeModerations:
// err = service.CheckSensitiveInput(textRequest.Input)
// case relayconstant.RelayModeEmbeddings:
// err = service.CheckSensitiveInput(textRequest.Input)
// }
// return err
//}

View File

@ -22,32 +22,41 @@ func SetRelayRouter(router *gin.Engine) {
playgroundRouter.POST("/chat/completions", controller.Playground) playgroundRouter.POST("/chat/completions", controller.Playground)
} }
relayV1Router := router.Group("/v1") relayV1Router := router.Group("/v1")
relayV1Router.Use(middleware.TokenAuth(), middleware.Distribute()) relayV1Router.Use(middleware.TokenAuth())
{ {
relayV1Router.POST("/completions", controller.Relay) // WebSocket 路由
relayV1Router.POST("/chat/completions", controller.Relay) wsRouter := relayV1Router.Group("")
relayV1Router.POST("/edits", controller.Relay) wsRouter.Use(middleware.Distribute())
relayV1Router.POST("/images/generations", controller.Relay) wsRouter.GET("/realtime", controller.WssRelay)
relayV1Router.POST("/images/edits", controller.RelayNotImplemented) }
relayV1Router.POST("/images/variations", controller.RelayNotImplemented) {
relayV1Router.POST("/embeddings", controller.Relay) //http router
relayV1Router.POST("/engines/:model/embeddings", controller.Relay) httpRouter := relayV1Router.Group("")
relayV1Router.POST("/audio/transcriptions", controller.Relay) httpRouter.Use(middleware.Distribute())
relayV1Router.POST("/audio/translations", controller.Relay) httpRouter.POST("/completions", controller.Relay)
relayV1Router.POST("/audio/speech", controller.Relay) httpRouter.POST("/chat/completions", controller.Relay)
relayV1Router.GET("/files", controller.RelayNotImplemented) httpRouter.POST("/edits", controller.Relay)
relayV1Router.POST("/files", controller.RelayNotImplemented) httpRouter.POST("/images/generations", controller.Relay)
relayV1Router.DELETE("/files/:id", controller.RelayNotImplemented) httpRouter.POST("/images/edits", controller.RelayNotImplemented)
relayV1Router.GET("/files/:id", controller.RelayNotImplemented) httpRouter.POST("/images/variations", controller.RelayNotImplemented)
relayV1Router.GET("/files/:id/content", controller.RelayNotImplemented) httpRouter.POST("/embeddings", controller.Relay)
relayV1Router.POST("/fine-tunes", controller.RelayNotImplemented) httpRouter.POST("/engines/:model/embeddings", controller.Relay)
relayV1Router.GET("/fine-tunes", controller.RelayNotImplemented) httpRouter.POST("/audio/transcriptions", controller.Relay)
relayV1Router.GET("/fine-tunes/:id", controller.RelayNotImplemented) httpRouter.POST("/audio/translations", controller.Relay)
relayV1Router.POST("/fine-tunes/:id/cancel", controller.RelayNotImplemented) httpRouter.POST("/audio/speech", controller.Relay)
relayV1Router.GET("/fine-tunes/:id/events", controller.RelayNotImplemented) httpRouter.GET("/files", controller.RelayNotImplemented)
relayV1Router.DELETE("/models/:model", controller.RelayNotImplemented) httpRouter.POST("/files", controller.RelayNotImplemented)
relayV1Router.POST("/moderations", controller.Relay) httpRouter.DELETE("/files/:id", controller.RelayNotImplemented)
relayV1Router.POST("/rerank", controller.Relay) httpRouter.GET("/files/:id", controller.RelayNotImplemented)
httpRouter.GET("/files/:id/content", controller.RelayNotImplemented)
httpRouter.POST("/fine-tunes", controller.RelayNotImplemented)
httpRouter.GET("/fine-tunes", controller.RelayNotImplemented)
httpRouter.GET("/fine-tunes/:id", controller.RelayNotImplemented)
httpRouter.POST("/fine-tunes/:id/cancel", controller.RelayNotImplemented)
httpRouter.GET("/fine-tunes/:id/events", controller.RelayNotImplemented)
httpRouter.DELETE("/models/:model", controller.RelayNotImplemented)
httpRouter.POST("/moderations", controller.Relay)
httpRouter.POST("/rerank", controller.Relay)
} }
relayMjRouter := router.Group("/mj") relayMjRouter := router.Group("/mj")

31
service/audio.go Normal file
View File

@ -0,0 +1,31 @@
package service
import (
"encoding/base64"
"fmt"
)
func parseAudio(audioBase64 string, format string) (duration float64, err error) {
audioData, err := base64.StdEncoding.DecodeString(audioBase64)
if err != nil {
return 0, fmt.Errorf("base64 decode error: %v", err)
}
var samplesCount int
var sampleRate int
switch format {
case "pcm16":
samplesCount = len(audioData) / 2 // 16位 = 2字节每样本
sampleRate = 24000 // 24kHz
case "g711_ulaw", "g711_alaw":
samplesCount = len(audioData) // 8位 = 1字节每样本
sampleRate = 8000 // 8kHz
default:
samplesCount = len(audioData) // 8位 = 1字节每样本
sampleRate = 8000 // 8kHz
}
duration = float64(samplesCount) / float64(sampleRate)
return duration, nil
}

View File

@ -2,6 +2,7 @@ package service
import ( import (
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"one-api/dto"
relaycommon "one-api/relay/common" relaycommon "one-api/relay/common"
) )
@ -17,3 +18,13 @@ func GenerateTextOtherInfo(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, m
other["admin_info"] = adminInfo other["admin_info"] = adminInfo
return other return other
} }
func GenerateWssOtherInfo(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage *dto.RealtimeUsage, modelRatio, groupRatio, completionRatio, modelPrice float64) map[string]interface{} {
info := GenerateTextOtherInfo(ctx, relayInfo, modelRatio, groupRatio, completionRatio, modelPrice)
info["ws"] = true
info["audio_input"] = usage.InputTokenDetails.AudioTokens
info["audio_output"] = usage.OutputTokenDetails.AudioTokens
info["text_input"] = usage.InputTokenDetails.TextTokens
info["text_output"] = usage.OutputTokenDetails.TextTokens
return info
}

140
service/quota.go Normal file
View File

@ -0,0 +1,140 @@
package service
import (
"errors"
"fmt"
"github.com/gin-gonic/gin"
"math"
"one-api/common"
"one-api/dto"
"one-api/model"
relaycommon "one-api/relay/common"
"strings"
"time"
)
func PreWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage *dto.RealtimeUsage) error {
if relayInfo.UsePrice {
return nil
}
userQuota, err := model.GetUserQuota(relayInfo.UserId)
if err != nil {
return err
}
modelName := relayInfo.UpstreamModelName
textInputTokens := usage.InputTokenDetails.TextTokens
textOutTokens := usage.OutputTokenDetails.TextTokens
audioInputTokens := usage.InputTokenDetails.AudioTokens
audioOutTokens := usage.OutputTokenDetails.AudioTokens
completionRatio := common.GetCompletionRatio(modelName)
audioRatio := common.GetAudioRatio(relayInfo.UpstreamModelName)
audioCompletionRatio := common.GetAudioCompletionRatio(modelName)
groupRatio := common.GetGroupRatio(relayInfo.Group)
modelRatio := common.GetModelRatio(modelName)
ratio := groupRatio * modelRatio
quota := textInputTokens + int(math.Round(float64(textOutTokens)*completionRatio))
quota += int(math.Round(float64(audioInputTokens)*audioRatio)) + int(math.Round(float64(audioOutTokens)*audioRatio*audioCompletionRatio))
quota = int(math.Round(float64(quota) * ratio))
if ratio != 0 && quota <= 0 {
quota = 1
}
if userQuota < quota {
return errors.New(fmt.Sprintf("用户额度不足,剩余额度为 %d", userQuota))
}
err = model.PostConsumeTokenQuota(relayInfo, 0, quota, 0, false)
if err != nil {
return err
}
common.LogInfo(ctx, "realtime streaming consume quota success, quota: "+fmt.Sprintf("%d", quota))
err = model.CacheUpdateUserQuota(relayInfo.UserId)
if err != nil {
return err
}
return nil
}
func PostWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, modelName string,
usage *dto.RealtimeUsage, ratio float64, preConsumedQuota int, userQuota int, modelRatio float64,
groupRatio float64,
modelPrice float64, usePrice bool, extraContent string) {
useTimeSeconds := time.Now().Unix() - relayInfo.StartTime.Unix()
textInputTokens := usage.InputTokenDetails.TextTokens
textOutTokens := usage.OutputTokenDetails.TextTokens
audioInputTokens := usage.InputTokenDetails.AudioTokens
audioOutTokens := usage.OutputTokenDetails.AudioTokens
tokenName := ctx.GetString("token_name")
completionRatio := common.GetCompletionRatio(modelName)
audioRatio := common.GetAudioRatio(relayInfo.UpstreamModelName)
audioCompletionRatio := common.GetAudioCompletionRatio(modelName)
quota := 0
if !usePrice {
quota = int(math.Round(float64(textInputTokens)*ratio + float64(textOutTokens)*ratio*completionRatio))
quota += int(math.Round(float64(audioInputTokens)*ratio*audioRatio + float64(audioOutTokens)*ratio*audioRatio*audioCompletionRatio))
if ratio != 0 && quota <= 0 {
quota = 1
}
} else {
quota = int(modelPrice * common.QuotaPerUnit * groupRatio)
}
totalTokens := usage.TotalTokens
var logContent string
if !usePrice {
logContent = fmt.Sprintf("模型倍率 %.2f,补全倍率 %.2f,音频倍率 %.2f,音频补全倍率 %.2f,分组倍率 %.2f", modelRatio, completionRatio, audioRatio, audioCompletionRatio, groupRatio)
} else {
logContent = fmt.Sprintf("模型价格 %.2f,分组倍率 %.2f", modelPrice, groupRatio)
}
// record all the consume log even if quota is 0
if totalTokens == 0 {
// in this case, must be some error happened
// we cannot just return, because we may have to return the pre-consumed quota
quota = 0
logContent += fmt.Sprintf("(可能是上游超时)")
common.LogError(ctx, fmt.Sprintf("total tokens is 0, cannot consume quota, userId %d, channelId %d, "+
"tokenId %d, model %s pre-consumed quota %d", relayInfo.UserId, relayInfo.ChannelId, relayInfo.TokenId, modelName, preConsumedQuota))
} else {
//if sensitiveResp != nil {
// logContent += fmt.Sprintf(",敏感词:%s", strings.Join(sensitiveResp.SensitiveWords, ", "))
//}
//quotaDelta := quota - preConsumedQuota
//if quotaDelta != 0 {
// err := model.PostConsumeTokenQuota(relayInfo, userQuota, quotaDelta, preConsumedQuota, true)
// if err != nil {
// common.LogError(ctx, "error consuming token remain quota: "+err.Error())
// }
//}
//err := model.CacheUpdateUserQuota(relayInfo.UserId)
//if err != nil {
// common.LogError(ctx, "error update user quota cache: "+err.Error())
//}
model.UpdateUserUsedQuotaAndRequestCount(relayInfo.UserId, quota)
model.UpdateChannelUsedQuota(relayInfo.ChannelId, quota)
}
logModel := modelName
if strings.HasPrefix(logModel, "gpt-4-gizmo") {
logModel = "gpt-4-gizmo-*"
logContent += fmt.Sprintf(",模型 %s", modelName)
}
if strings.HasPrefix(logModel, "gpt-4o-gizmo") {
logModel = "gpt-4o-gizmo-*"
logContent += fmt.Sprintf(",模型 %s", modelName)
}
if extraContent != "" {
logContent += ", " + extraContent
}
other := GenerateWssOtherInfo(ctx, relayInfo, usage, modelRatio, groupRatio, completionRatio, modelPrice)
model.RecordConsumeLog(ctx, relayInfo.UserId, relayInfo.ChannelId, usage.InputTokens, usage.OutputTokens, logModel,
tokenName, quota, logContent, relayInfo.TokenId, userQuota, int(useTimeSeconds), relayInfo.IsStream, other)
}

View File

@ -5,6 +5,7 @@ import (
"errors" "errors"
"fmt" "fmt"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/gorilla/websocket"
"net/http" "net/http"
"one-api/common" "one-api/common"
"one-api/dto" "one-api/dto"
@ -42,11 +43,47 @@ func Done(c *gin.Context) {
_ = StringData(c, "[DONE]") _ = StringData(c, "[DONE]")
} }
func WssString(c *gin.Context, ws *websocket.Conn, str string) error {
if ws == nil {
common.LogError(c, "websocket connection is nil")
return errors.New("websocket connection is nil")
}
//common.LogInfo(c, fmt.Sprintf("sending message: %s", str))
return ws.WriteMessage(1, []byte(str))
}
func WssObject(c *gin.Context, ws *websocket.Conn, object interface{}) error {
jsonData, err := json.Marshal(object)
if err != nil {
return fmt.Errorf("error marshalling object: %w", err)
}
if ws == nil {
common.LogError(c, "websocket connection is nil")
return errors.New("websocket connection is nil")
}
//common.LogInfo(c, fmt.Sprintf("sending message: %s", jsonData))
return ws.WriteMessage(1, jsonData)
}
func WssError(c *gin.Context, ws *websocket.Conn, openaiError dto.OpenAIError) {
errorObj := &dto.RealtimeEvent{
Type: "error",
EventId: GetLocalRealtimeID(c),
Error: &openaiError,
}
_ = WssObject(c, ws, errorObj)
}
func GetResponseID(c *gin.Context) string { func GetResponseID(c *gin.Context) string {
logID := c.GetString("X-Oneapi-Request-Id") logID := c.GetString(common.RequestIdKey)
return fmt.Sprintf("chatcmpl-%s", logID) return fmt.Sprintf("chatcmpl-%s", logID)
} }
func GetLocalRealtimeID(c *gin.Context) string {
logID := c.GetString(common.RequestIdKey)
return fmt.Sprintf("evt_%s", logID)
}
func GenerateStopResponse(id string, createAt int64, model string, finishReason string) *dto.ChatCompletionsStreamResponse { func GenerateStopResponse(id string, createAt int64, model string, finishReason string) *dto.ChatCompletionsStreamResponse {
return &dto.ChatCompletionsStreamResponse{ return &dto.ChatCompletionsStreamResponse{
Id: id, Id: id,

View File

@ -11,6 +11,7 @@ import (
"one-api/common" "one-api/common"
"one-api/constant" "one-api/constant"
"one-api/dto" "one-api/dto"
relaycommon "one-api/relay/common"
"strings" "strings"
"unicode/utf8" "unicode/utf8"
) )
@ -191,6 +192,72 @@ func CountTokenChatRequest(request dto.GeneralOpenAIRequest, model string) (int,
return tkm, nil return tkm, nil
} }
func CountTokenRealtime(info *relaycommon.RelayInfo, request dto.RealtimeEvent, model string) (int, int, error) {
audioToken := 0
textToken := 0
switch request.Type {
case dto.RealtimeEventTypeSessionUpdate:
if request.Session != nil {
msgTokens, err := CountTextToken(request.Session.Instructions, model)
if err != nil {
return 0, 0, err
}
textToken += msgTokens
}
case dto.RealtimeEventResponseAudioDelta:
// count audio token
atk, err := CountAudioTokenOutput(request.Delta, info.OutputAudioFormat)
if err != nil {
return 0, 0, fmt.Errorf("error counting audio token: %v", err)
}
audioToken += atk
case dto.RealtimeEventResponseAudioTranscriptionDelta, dto.RealtimeEventResponseFunctionCallArgumentsDelta:
// count text token
tkm, err := CountTextToken(request.Delta, model)
if err != nil {
return 0, 0, fmt.Errorf("error counting text token: %v", err)
}
textToken += tkm
case dto.RealtimeEventInputAudioBufferAppend:
// count audio token
atk, err := CountAudioTokenInput(request.Audio, info.InputAudioFormat)
if err != nil {
return 0, 0, fmt.Errorf("error counting audio token: %v", err)
}
audioToken += atk
case dto.RealtimeEventConversationItemCreated:
if request.Item != nil {
switch request.Item.Type {
case "message":
for _, content := range request.Item.Content {
if content.Type == "input_text" {
tokens, err := CountTextToken(content.Text, model)
if err != nil {
return 0, 0, err
}
textToken += tokens
}
}
}
}
case dto.RealtimeEventTypeResponseDone:
// count tools token
if !info.IsFirstRequest {
if info.RealtimeTools != nil && len(info.RealtimeTools) > 0 {
for _, tool := range info.RealtimeTools {
toolTokens, err := CountTokenInput(tool, model)
if err != nil {
return 0, 0, err
}
textToken += 8
textToken += toolTokens
}
}
}
}
return textToken, audioToken, nil
}
func CountTokenMessages(messages []dto.Message, model string, stream bool) (int, error) { func CountTokenMessages(messages []dto.Message, model string, stream bool) (int, error) {
//recover when panic //recover when panic
tokenEncoder := getTokenEncoder(model) tokenEncoder := getTokenEncoder(model)
@ -248,13 +315,13 @@ func CountTokenMessages(messages []dto.Message, model string, stream bool) (int,
func CountTokenInput(input any, model string) (int, error) { func CountTokenInput(input any, model string) (int, error) {
switch v := input.(type) { switch v := input.(type) {
case string: case string:
return CountTokenText(v, model) return CountTextToken(v, model)
case []string: case []string:
text := "" text := ""
for _, s := range v { for _, s := range v {
text += s text += s
} }
return CountTokenText(text, model) return CountTextToken(text, model)
} }
return CountTokenInput(fmt.Sprintf("%v", input), model) return CountTokenInput(fmt.Sprintf("%v", input), model)
} }
@ -276,16 +343,44 @@ func CountTokenStreamChoices(messages []dto.ChatCompletionsStreamResponseChoice,
return tokens return tokens
} }
func CountAudioToken(text string, model string) (int, error) { func CountTTSToken(text string, model string) (int, error) {
if strings.HasPrefix(model, "tts") { if strings.HasPrefix(model, "tts") {
return utf8.RuneCountInString(text), nil return utf8.RuneCountInString(text), nil
} else { } else {
return CountTokenText(text, model) return CountTextToken(text, model)
} }
} }
// CountTokenText 统计文本的token数量仅当文本包含敏感词返回错误同时返回token数量 func CountAudioTokenInput(audioBase64 string, audioFormat string) (int, error) {
func CountTokenText(text string, model string) (int, error) { if audioBase64 == "" {
return 0, nil
}
duration, err := parseAudio(audioBase64, audioFormat)
if err != nil {
return 0, err
}
return int(duration / 60 * 100 / 0.06), nil
}
func CountAudioTokenOutput(audioBase64 string, audioFormat string) (int, error) {
if audioBase64 == "" {
return 0, nil
}
duration, err := parseAudio(audioBase64, audioFormat)
if err != nil {
return 0, err
}
return int(duration / 60 * 200 / 0.24), nil
}
//func CountAudioToken(sec float64, audioType string) {
// if audioType == "input" {
//
// }
//}
// CountTextToken 统计文本的token数量仅当文本包含敏感词返回错误同时返回token数量
func CountTextToken(text string, model string) (int, error) {
var err error var err error
tokenEncoder := getTokenEncoder(model) tokenEncoder := getTokenEncoder(model)
return getTokenNum(tokenEncoder, text), err return getTokenNum(tokenEncoder, text), err

View File

@ -19,7 +19,7 @@ import (
func ResponseText2Usage(responseText string, modeName string, promptTokens int) (*dto.Usage, error) { func ResponseText2Usage(responseText string, modeName string, promptTokens int) (*dto.Usage, error) {
usage := &dto.Usage{} usage := &dto.Usage{}
usage.PromptTokens = promptTokens usage.PromptTokens = promptTokens
ctkm, err := CountTokenText(responseText, modeName) ctkm, err := CountTextToken(responseText, modeName)
usage.CompletionTokens = ctkm usage.CompletionTokens = ctkm
usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
return usage, err return usage, err

View File

@ -11,7 +11,7 @@ import {
import { import {
Avatar, Avatar,
Button, Button, Descriptions,
Form, Form,
Layout, Layout,
Modal, Modal,
@ -20,7 +20,7 @@ import {
Spin, Spin,
Table, Table,
Tag, Tag,
Tooltip, Tooltip
} from '@douyinfe/semi-ui'; } from '@douyinfe/semi-ui';
import { ITEMS_PER_PAGE } from '../constants'; import { ITEMS_PER_PAGE } from '../constants';
import { import {
@ -336,33 +336,33 @@ const LogsTable = () => {
); );
}, },
}, },
{ // {
title: '重试', // title: '重试',
dataIndex: 'retry', // dataIndex: 'retry',
className: isAdmin() ? 'tableShow' : 'tableHiddle', // className: isAdmin() ? 'tableShow' : 'tableHiddle',
render: (text, record, index) => { // render: (text, record, index) => {
let content = '渠道:' + record.channel; // let content = '渠道:' + record.channel;
if (record.other !== '') { // if (record.other !== '') {
let other = JSON.parse(record.other); // let other = JSON.parse(record.other);
if (other === null) { // if (other === null) {
return <></>; // return <></>;
} // }
if (other.admin_info !== undefined) { // if (other.admin_info !== undefined) {
if ( // if (
other.admin_info.use_channel !== null && // other.admin_info.use_channel !== null &&
other.admin_info.use_channel !== undefined && // other.admin_info.use_channel !== undefined &&
other.admin_info.use_channel !== '' // other.admin_info.use_channel !== ''
) { // ) {
// channel id array // // channel id array
let useChannel = other.admin_info.use_channel; // let useChannel = other.admin_info.use_channel;
let useChannelStr = useChannel.join('->'); // let useChannelStr = useChannel.join('->');
content = `渠道:${useChannelStr}`; // content = `渠道:${useChannelStr}`;
} // }
} // }
} // }
return isAdminUser ? <div>{content}</div> : <></>; // return isAdminUser ? <div>{content}</div> : <></>;
}, // },
}, // },
{ {
title: '详情', title: '详情',
dataIndex: 'content', dataIndex: 'content',
@ -409,6 +409,7 @@ const LogsTable = () => {
]; ];
const [logs, setLogs] = useState([]); const [logs, setLogs] = useState([]);
const [expandData, setExpandData] = useState({});
const [showStat, setShowStat] = useState(false); const [showStat, setShowStat] = useState(false);
const [loading, setLoading] = useState(false); const [loading, setLoading] = useState(false);
const [loadingStat, setLoadingStat] = useState(false); const [loadingStat, setLoadingStat] = useState(false);
@ -512,10 +513,54 @@ const LogsTable = () => {
}; };
const setLogsFormat = (logs) => { const setLogsFormat = (logs) => {
let expandDatesLocal = {};
for (let i = 0; i < logs.length; i++) { for (let i = 0; i < logs.length; i++) {
logs[i].timestamp2string = timestamp2string(logs[i].created_at); logs[i].timestamp2string = timestamp2string(logs[i].created_at);
logs[i].key = '' + logs[i].id; logs[i].key = '' + logs[i].id;
let other = getLogOther(logs[i].other);
let expandDataLocal = [];
if (isAdmin()) {
let content = '渠道:' + logs[i].channel;
if (other.admin_info !== undefined) {
if (
other.admin_info.use_channel !== null &&
other.admin_info.use_channel !== undefined &&
other.admin_info.use_channel !== ''
) {
// channel id array
let useChannel = other.admin_info.use_channel;
let useChannelStr = useChannel.join('->');
content = `渠道:${useChannelStr}`;
} }
}
expandDataLocal.push({
key: '重试',
value: content,
})
}
if (other.ws) {
expandDataLocal.push({
key: '语音输入',
value: other.audio_input,
});
expandDataLocal.push({
key: '语音输出',
value: other.audio_output,
});
expandDataLocal.push({
key: '文字输入',
value: other.text_input,
});
expandDataLocal.push({
key: '文字输出',
value: other.text_output,
});
}
expandDatesLocal[logs[i].key] = expandDataLocal;
}
console.log(expandDatesLocal);
setExpandData(expandDatesLocal);
setLogs(logs); setLogs(logs);
}; };
@ -588,6 +633,10 @@ const LogsTable = () => {
handleEyeClick(); handleEyeClick();
}, []); }, []);
const expandRowRender = (record, index) => {
return <Descriptions align="justify" data={expandData[record.key]} />;
};
return ( return (
<> <>
<Layout> <Layout>
@ -686,7 +735,9 @@ const LogsTable = () => {
<Table <Table
style={{ marginTop: 5 }} style={{ marginTop: 5 }}
columns={columns} columns={columns}
expandedRowRender={expandRowRender}
dataSource={logs} dataSource={logs}
rowKey="key"
pagination={{ pagination={{
currentPage: activePage, currentPage: activePage,
pageSize: pageSize, pageSize: pageSize,