mirror of
https://github.com/linux-do/new-api.git
synced 2025-09-17 16:06:38 +08:00
feat: realtime
(cherry picked from commit a5529df3e1a4c08a120e8c05203a7d885b0fe8d8)
This commit is contained in:
parent
e3c85572d4
commit
33af069fae
@ -421,6 +421,20 @@ func GetCompletionRatio(name string) float64 {
|
|||||||
return 1
|
return 1
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func GetAudioRatio(name string) float64 {
|
||||||
|
if strings.HasPrefix(name, "gpt-4o-realtime") {
|
||||||
|
return 20
|
||||||
|
}
|
||||||
|
return 20
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetAudioCompletionRatio(name string) float64 {
|
||||||
|
if strings.HasPrefix(name, "gpt-4o-realtime") {
|
||||||
|
return 10
|
||||||
|
}
|
||||||
|
return 10
|
||||||
|
}
|
||||||
|
|
||||||
func GetCompletionRatioMap() map[string]float64 {
|
func GetCompletionRatioMap() map[string]float64 {
|
||||||
if CompletionRatio == nil {
|
if CompletionRatio == nil {
|
||||||
CompletionRatio = defaultCompletionRatio
|
CompletionRatio = defaultCompletionRatio
|
||||||
|
@ -102,17 +102,22 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return err, nil
|
return err, nil
|
||||||
}
|
}
|
||||||
if resp != nil && resp.StatusCode != http.StatusOK {
|
var httpResp *http.Response
|
||||||
err := service.RelayErrorHandler(resp)
|
if resp != nil {
|
||||||
return fmt.Errorf("status code %d: %s", resp.StatusCode, err.Error.Message), err
|
httpResp = resp.(*http.Response)
|
||||||
|
if httpResp.StatusCode != http.StatusOK {
|
||||||
|
err := service.RelayErrorHandler(httpResp)
|
||||||
|
return fmt.Errorf("status code %d: %s", httpResp.StatusCode, err.Error.Message), err
|
||||||
}
|
}
|
||||||
usage, respErr := adaptor.DoResponse(c, resp, meta)
|
}
|
||||||
|
usageA, respErr := adaptor.DoResponse(c, httpResp, meta)
|
||||||
if respErr != nil {
|
if respErr != nil {
|
||||||
return fmt.Errorf("%s", respErr.Error.Message), respErr
|
return fmt.Errorf("%s", respErr.Error.Message), respErr
|
||||||
}
|
}
|
||||||
if usage == nil {
|
if usageA == nil {
|
||||||
return errors.New("usage is nil"), nil
|
return errors.New("usage is nil"), nil
|
||||||
}
|
}
|
||||||
|
usage := usageA.(dto.Usage)
|
||||||
result := w.Result()
|
result := w.Result()
|
||||||
respBody, err := io.ReadAll(result.Body)
|
respBody, err := io.ReadAll(result.Body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -39,6 +39,15 @@ func relayHandler(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func wsHandler(c *gin.Context, ws *websocket.Conn, relayMode int) *dto.OpenAIErrorWithStatusCode {
|
||||||
|
var err *dto.OpenAIErrorWithStatusCode
|
||||||
|
switch relayMode {
|
||||||
|
default:
|
||||||
|
err = relay.TextHelper(c)
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
func Playground(c *gin.Context) {
|
func Playground(c *gin.Context) {
|
||||||
var openaiErr *dto.OpenAIErrorWithStatusCode
|
var openaiErr *dto.OpenAIErrorWithStatusCode
|
||||||
|
|
||||||
@ -143,12 +152,16 @@ var upgrader = websocket.Upgrader{
|
|||||||
|
|
||||||
func WssRelay(c *gin.Context) {
|
func WssRelay(c *gin.Context) {
|
||||||
// 将 HTTP 连接升级为 WebSocket 连接
|
// 将 HTTP 连接升级为 WebSocket 连接
|
||||||
|
|
||||||
ws, err := upgrader.Upgrade(c.Writer, c.Request, nil)
|
ws, err := upgrader.Upgrade(c.Writer, c.Request, nil)
|
||||||
|
defer ws.Close()
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
openaiErr := service.OpenAIErrorWrapper(err, "get_channel_failed", http.StatusInternalServerError)
|
openaiErr := service.OpenAIErrorWrapper(err, "get_channel_failed", http.StatusInternalServerError)
|
||||||
service.WssError(c, ws, openaiErr.Error)
|
service.WssError(c, ws, openaiErr.Error)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
relayMode := constant.Path2RelayMode(c.Request.URL.Path)
|
relayMode := constant.Path2RelayMode(c.Request.URL.Path)
|
||||||
requestId := c.GetString(common.RequestIdKey)
|
requestId := c.GetString(common.RequestIdKey)
|
||||||
group := c.GetString("group")
|
group := c.GetString("group")
|
||||||
@ -164,7 +177,7 @@ func WssRelay(c *gin.Context) {
|
|||||||
break
|
break
|
||||||
}
|
}
|
||||||
|
|
||||||
openaiErr = relayRequest(c, relayMode, channel)
|
openaiErr = wssRequest(c, ws, relayMode, channel)
|
||||||
|
|
||||||
if openaiErr == nil {
|
if openaiErr == nil {
|
||||||
return // 成功处理请求,直接返回
|
return // 成功处理请求,直接返回
|
||||||
@ -198,6 +211,13 @@ func relayRequest(c *gin.Context, relayMode int, channel *model.Channel) *dto.Op
|
|||||||
return relayHandler(c, relayMode)
|
return relayHandler(c, relayMode)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func wssRequest(c *gin.Context, ws *websocket.Conn, relayMode int, channel *model.Channel) *dto.OpenAIErrorWithStatusCode {
|
||||||
|
addUsedChannel(c, channel.Id)
|
||||||
|
requestBody, _ := common.GetRequestBody(c)
|
||||||
|
c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
|
||||||
|
return relay.WssHelper(c, ws)
|
||||||
|
}
|
||||||
|
|
||||||
func addUsedChannel(c *gin.Context, channelId int) {
|
func addUsedChannel(c *gin.Context, channelId int) {
|
||||||
useChannel := c.GetStringSlice("use_channel")
|
useChannel := c.GetStringSlice("use_channel")
|
||||||
useChannel = append(useChannel, fmt.Sprintf("%d", channelId))
|
useChannel = append(useChannel, fmt.Sprintf("%d", channelId))
|
||||||
|
@ -7,6 +7,10 @@ const (
|
|||||||
RealtimeEventTypeResponseCreate = "response.create"
|
RealtimeEventTypeResponseCreate = "response.create"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
RealtimeEventTypeResponseDone = "response.done"
|
||||||
|
)
|
||||||
|
|
||||||
type RealtimeEvent struct {
|
type RealtimeEvent struct {
|
||||||
EventId string `json:"event_id"`
|
EventId string `json:"event_id"`
|
||||||
Type string `json:"type"`
|
Type string `json:"type"`
|
||||||
@ -14,6 +18,30 @@ type RealtimeEvent struct {
|
|||||||
Session *RealtimeSession `json:"session,omitempty"`
|
Session *RealtimeSession `json:"session,omitempty"`
|
||||||
Item *RealtimeItem `json:"item,omitempty"`
|
Item *RealtimeItem `json:"item,omitempty"`
|
||||||
Error *OpenAIError `json:"error,omitempty"`
|
Error *OpenAIError `json:"error,omitempty"`
|
||||||
|
Response *RealtimeResponse `json:"response,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type RealtimeResponse struct {
|
||||||
|
Usage *RealtimeUsage `json:"usage"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type RealtimeUsage struct {
|
||||||
|
TotalTokens int `json:"total_tokens"`
|
||||||
|
InputTokens int `json:"input_tokens"`
|
||||||
|
OutputTokens int `json:"output_tokens"`
|
||||||
|
InputTokenDetails InputTokenDetails `json:"input_token_details"`
|
||||||
|
OutputTokenDetails OutputTokenDetails `json:"output_token_details"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type InputTokenDetails struct {
|
||||||
|
CachedTokens int `json:"cached_tokens"`
|
||||||
|
TextTokens int `json:"text_tokens"`
|
||||||
|
AudioTokens int `json:"audio_tokens"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type OutputTokenDetails struct {
|
||||||
|
TextTokens int `json:"text_tokens"`
|
||||||
|
AudioTokens int `json:"audio_tokens"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type RealtimeSession struct {
|
type RealtimeSession struct {
|
||||||
@ -27,7 +55,7 @@ type RealtimeSession struct {
|
|||||||
Tools []RealTimeTool `json:"tools"`
|
Tools []RealTimeTool `json:"tools"`
|
||||||
ToolChoice string `json:"tool_choice"`
|
ToolChoice string `json:"tool_choice"`
|
||||||
Temperature float64 `json:"temperature"`
|
Temperature float64 `json:"temperature"`
|
||||||
MaxResponseOutputTokens int `json:"max_response_output_tokens"`
|
//MaxResponseOutputTokens int `json:"max_response_output_tokens"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type InputAudioTranscription struct {
|
type InputAudioTranscription struct {
|
||||||
@ -46,7 +74,7 @@ type RealtimeItem struct {
|
|||||||
Type string `json:"type"`
|
Type string `json:"type"`
|
||||||
Status string `json:"status"`
|
Status string `json:"status"`
|
||||||
Role string `json:"role"`
|
Role string `json:"role"`
|
||||||
Content RealtimeContent `json:"content"`
|
Content []RealtimeContent `json:"content"`
|
||||||
Name *string `json:"name,omitempty"`
|
Name *string `json:"name,omitempty"`
|
||||||
ToolCalls any `json:"tool_calls,omitempty"`
|
ToolCalls any `json:"tool_calls,omitempty"`
|
||||||
CallId string `json:"call_id,omitempty"`
|
CallId string `json:"call_id,omitempty"`
|
||||||
|
@ -155,8 +155,27 @@ func RootAuth() func(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func WssAuth(c *gin.Context) {
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
func TokenAuth() func(c *gin.Context) {
|
func TokenAuth() func(c *gin.Context) {
|
||||||
return func(c *gin.Context) {
|
return func(c *gin.Context) {
|
||||||
|
// 先检测是否为ws
|
||||||
|
if c.Request.Header.Get("Sec-WebSocket-Protocol") != "" {
|
||||||
|
// Sec-WebSocket-Protocol: realtime, openai-insecure-api-key.sk-xxx, openai-beta.realtime-v1
|
||||||
|
// read sk from Sec-WebSocket-Protocol
|
||||||
|
key := c.Request.Header.Get("Sec-WebSocket-Protocol")
|
||||||
|
parts := strings.Split(key, ",")
|
||||||
|
for _, part := range parts {
|
||||||
|
part = strings.TrimSpace(part)
|
||||||
|
if strings.HasPrefix(part, "openai-insecure-api-key") {
|
||||||
|
key = strings.TrimPrefix(part, "openai-insecure-api-key.")
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
c.Request.Header.Set("Authorization", "Bearer "+key)
|
||||||
|
}
|
||||||
key := c.Request.Header.Get("Authorization")
|
key := c.Request.Header.Get("Authorization")
|
||||||
parts := make([]string, 0)
|
parts := make([]string, 0)
|
||||||
key = strings.TrimPrefix(key, "Bearer ")
|
key = strings.TrimPrefix(key, "Bearer ")
|
||||||
|
@ -12,13 +12,13 @@ type Adaptor interface {
|
|||||||
// Init IsStream bool
|
// Init IsStream bool
|
||||||
Init(info *relaycommon.RelayInfo)
|
Init(info *relaycommon.RelayInfo)
|
||||||
GetRequestURL(info *relaycommon.RelayInfo) (string, error)
|
GetRequestURL(info *relaycommon.RelayInfo) (string, error)
|
||||||
SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error
|
SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error
|
||||||
ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error)
|
ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error)
|
||||||
ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error)
|
ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error)
|
||||||
ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error)
|
ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error)
|
||||||
ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error)
|
ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error)
|
||||||
DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error)
|
DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error)
|
||||||
DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode)
|
DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode)
|
||||||
GetModelList() []string
|
GetModelList() []string
|
||||||
GetChannelName() string
|
GetChannelName() string
|
||||||
}
|
}
|
||||||
|
@ -32,14 +32,14 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
|||||||
return fullRequestURL, nil
|
return fullRequestURL, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
|
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
|
||||||
channel.SetupApiRequestHeader(info, c, req)
|
channel.SetupApiRequestHeader(info, c, req)
|
||||||
req.Header.Set("Authorization", "Bearer "+info.ApiKey)
|
req.Set("Authorization", "Bearer "+info.ApiKey)
|
||||||
if info.IsStream {
|
if info.IsStream {
|
||||||
req.Header.Set("X-DashScope-SSE", "enable")
|
req.Set("X-DashScope-SSE", "enable")
|
||||||
}
|
}
|
||||||
if c.GetString("plugin") != "" {
|
if c.GetString("plugin") != "" {
|
||||||
req.Header.Set("X-DashScope-Plugin", c.GetString("plugin"))
|
req.Set("X-DashScope-Plugin", c.GetString("plugin"))
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@ -72,11 +72,11 @@ func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInf
|
|||||||
return nil, errors.New("not implemented")
|
return nil, errors.New("not implemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
|
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
|
||||||
return channel.DoApiRequest(a, c, info, requestBody)
|
return channel.DoApiRequest(a, c, info, requestBody)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
|
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
|
||||||
switch info.RelayMode {
|
switch info.RelayMode {
|
||||||
case constant.RelayModeImagesGenerations:
|
case constant.RelayModeImagesGenerations:
|
||||||
err, usage = aliImageHandler(c, resp, info)
|
err, usage = aliImageHandler(c, resp, info)
|
||||||
|
@ -4,6 +4,7 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/gorilla/websocket"
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"one-api/relay/common"
|
"one-api/relay/common"
|
||||||
@ -11,14 +12,16 @@ import (
|
|||||||
"one-api/service"
|
"one-api/service"
|
||||||
)
|
)
|
||||||
|
|
||||||
func SetupApiRequestHeader(info *common.RelayInfo, c *gin.Context, req *http.Request) {
|
func SetupApiRequestHeader(info *common.RelayInfo, c *gin.Context, req *http.Header) {
|
||||||
if info.RelayMode == constant.RelayModeAudioTranscription || info.RelayMode == constant.RelayModeAudioTranslation {
|
if info.RelayMode == constant.RelayModeAudioTranscription || info.RelayMode == constant.RelayModeAudioTranslation {
|
||||||
// multipart/form-data
|
// multipart/form-data
|
||||||
|
} else if info.RelayMode == constant.RelayModeRealtime {
|
||||||
|
// websocket
|
||||||
} else {
|
} else {
|
||||||
req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
|
req.Set("Content-Type", c.Request.Header.Get("Content-Type"))
|
||||||
req.Header.Set("Accept", c.Request.Header.Get("Accept"))
|
req.Set("Accept", c.Request.Header.Get("Accept"))
|
||||||
if info.IsStream && c.Request.Header.Get("Accept") == "" {
|
if info.IsStream && c.Request.Header.Get("Accept") == "" {
|
||||||
req.Header.Set("Accept", "text/event-stream")
|
req.Set("Accept", "text/event-stream")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -32,7 +35,7 @@ func DoApiRequest(a Adaptor, c *gin.Context, info *common.RelayInfo, requestBody
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("new request failed: %w", err)
|
return nil, fmt.Errorf("new request failed: %w", err)
|
||||||
}
|
}
|
||||||
err = a.SetupRequestHeader(c, req, info)
|
err = a.SetupRequestHeader(c, &req.Header, info)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("setup request header failed: %w", err)
|
return nil, fmt.Errorf("setup request header failed: %w", err)
|
||||||
}
|
}
|
||||||
@ -55,7 +58,7 @@ func DoFormRequest(a Adaptor, c *gin.Context, info *common.RelayInfo, requestBod
|
|||||||
// set form data
|
// set form data
|
||||||
req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
|
req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
|
||||||
|
|
||||||
err = a.SetupRequestHeader(c, req, info)
|
err = a.SetupRequestHeader(c, &req.Header, info)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("setup request header failed: %w", err)
|
return nil, fmt.Errorf("setup request header failed: %w", err)
|
||||||
}
|
}
|
||||||
@ -66,6 +69,27 @@ func DoFormRequest(a Adaptor, c *gin.Context, info *common.RelayInfo, requestBod
|
|||||||
return resp, nil
|
return resp, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func DoWssRequest(a Adaptor, c *gin.Context, info *common.RelayInfo, requestBody io.Reader) (*websocket.Conn, error) {
|
||||||
|
fullRequestURL, err := a.GetRequestURL(info)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("get request url failed: %w", err)
|
||||||
|
}
|
||||||
|
targetHeader := http.Header{}
|
||||||
|
err = a.SetupRequestHeader(c, &targetHeader, info)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("setup request header failed: %w", err)
|
||||||
|
}
|
||||||
|
targetHeader.Set("Content-Type", c.Request.Header.Get("Content-Type"))
|
||||||
|
targetConn, _, err := websocket.DefaultDialer.Dial(fullRequestURL, targetHeader)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("dial failed to %s: %w", fullRequestURL, err)
|
||||||
|
}
|
||||||
|
// send request body
|
||||||
|
//all, err := io.ReadAll(requestBody)
|
||||||
|
//err = service.WssString(c, targetConn, string(all))
|
||||||
|
return targetConn, nil
|
||||||
|
}
|
||||||
|
|
||||||
func doRequest(c *gin.Context, req *http.Request) (*http.Response, error) {
|
func doRequest(c *gin.Context, req *http.Request) (*http.Response, error) {
|
||||||
resp, err := service.GetHttpClient().Do(req)
|
resp, err := service.GetHttpClient().Do(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -37,7 +37,7 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
|||||||
return "", nil
|
return "", nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
|
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -59,11 +59,11 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
|
|||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
|
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
|
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
|
||||||
if info.IsStream {
|
if info.IsStream {
|
||||||
err, usage = awsStreamHandler(c, resp, info, a.RequestMode)
|
err, usage = awsStreamHandler(c, resp, info, a.RequestMode)
|
||||||
} else {
|
} else {
|
||||||
|
@ -98,9 +98,9 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
|||||||
return fullRequestURL, nil
|
return fullRequestURL, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
|
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
|
||||||
channel.SetupApiRequestHeader(info, c, req)
|
channel.SetupApiRequestHeader(info, c, req)
|
||||||
req.Header.Set("Authorization", "Bearer "+info.ApiKey)
|
req.Set("Authorization", "Bearer "+info.ApiKey)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -122,11 +122,11 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
|
|||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
|
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
|
||||||
return channel.DoApiRequest(a, c, info, requestBody)
|
return channel.DoApiRequest(a, c, info, requestBody)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
|
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
|
||||||
if info.IsStream {
|
if info.IsStream {
|
||||||
err, usage = baiduStreamHandler(c, resp)
|
err, usage = baiduStreamHandler(c, resp)
|
||||||
} else {
|
} else {
|
||||||
|
@ -47,14 +47,14 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
|
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
|
||||||
channel.SetupApiRequestHeader(info, c, req)
|
channel.SetupApiRequestHeader(info, c, req)
|
||||||
req.Header.Set("x-api-key", info.ApiKey)
|
req.Set("x-api-key", info.ApiKey)
|
||||||
anthropicVersion := c.Request.Header.Get("anthropic-version")
|
anthropicVersion := c.Request.Header.Get("anthropic-version")
|
||||||
if anthropicVersion == "" {
|
if anthropicVersion == "" {
|
||||||
anthropicVersion = "2023-06-01"
|
anthropicVersion = "2023-06-01"
|
||||||
}
|
}
|
||||||
req.Header.Set("anthropic-version", anthropicVersion)
|
req.Set("anthropic-version", anthropicVersion)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -73,11 +73,11 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
|
|||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
|
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
|
||||||
return channel.DoApiRequest(a, c, info, requestBody)
|
return channel.DoApiRequest(a, c, info, requestBody)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
|
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
|
||||||
if info.IsStream {
|
if info.IsStream {
|
||||||
err, usage = ClaudeStreamHandler(c, resp, info, a.RequestMode)
|
err, usage = ClaudeStreamHandler(c, resp, info, a.RequestMode)
|
||||||
} else {
|
} else {
|
||||||
|
@ -30,9 +30,9 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
|
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
|
||||||
channel.SetupApiRequestHeader(info, c, req)
|
channel.SetupApiRequestHeader(info, c, req)
|
||||||
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", info.ApiKey))
|
req.Set("Authorization", fmt.Sprintf("Bearer %s", info.ApiKey))
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -48,7 +48,7 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, re
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
|
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
|
||||||
return channel.DoApiRequest(a, c, info, requestBody)
|
return channel.DoApiRequest(a, c, info, requestBody)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -78,7 +78,7 @@ func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInf
|
|||||||
return nil, errors.New("not implemented")
|
return nil, errors.New("not implemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
|
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
|
||||||
switch info.RelayMode {
|
switch info.RelayMode {
|
||||||
case constant.RelayModeEmbeddings:
|
case constant.RelayModeEmbeddings:
|
||||||
fallthrough
|
fallthrough
|
||||||
|
@ -36,9 +36,9 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
|
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
|
||||||
channel.SetupApiRequestHeader(info, c, req)
|
channel.SetupApiRequestHeader(info, c, req)
|
||||||
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", info.ApiKey))
|
req.Set("Authorization", fmt.Sprintf("Bearer %s", info.ApiKey))
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -46,7 +46,7 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, re
|
|||||||
return requestOpenAI2Cohere(*request), nil
|
return requestOpenAI2Cohere(*request), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
|
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
|
||||||
return channel.DoApiRequest(a, c, info, requestBody)
|
return channel.DoApiRequest(a, c, info, requestBody)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -54,7 +54,7 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
|
|||||||
return requestConvertRerank2Cohere(request), nil
|
return requestConvertRerank2Cohere(request), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
|
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
|
||||||
if info.RelayMode == constant.RelayModeRerank {
|
if info.RelayMode == constant.RelayModeRerank {
|
||||||
err, usage = cohereRerankHandler(c, resp, info)
|
err, usage = cohereRerankHandler(c, resp, info)
|
||||||
} else {
|
} else {
|
||||||
|
@ -31,9 +31,9 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
|||||||
return fmt.Sprintf("%s/v1/chat-messages", info.BaseUrl), nil
|
return fmt.Sprintf("%s/v1/chat-messages", info.BaseUrl), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
|
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
|
||||||
channel.SetupApiRequestHeader(info, c, req)
|
channel.SetupApiRequestHeader(info, c, req)
|
||||||
req.Header.Set("Authorization", "Bearer "+info.ApiKey)
|
req.Set("Authorization", "Bearer "+info.ApiKey)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -48,11 +48,11 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
|
|||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
|
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
|
||||||
return channel.DoApiRequest(a, c, info, requestBody)
|
return channel.DoApiRequest(a, c, info, requestBody)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
|
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
|
||||||
if info.IsStream {
|
if info.IsStream {
|
||||||
err, usage = difyStreamHandler(c, resp, info)
|
err, usage = difyStreamHandler(c, resp, info)
|
||||||
} else {
|
} else {
|
||||||
|
@ -47,9 +47,9 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
|||||||
return fmt.Sprintf("%s/%s/models/%s:%s", info.BaseUrl, version, info.UpstreamModelName, action), nil
|
return fmt.Sprintf("%s/%s/models/%s:%s", info.BaseUrl, version, info.UpstreamModelName, action), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
|
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
|
||||||
channel.SetupApiRequestHeader(info, c, req)
|
channel.SetupApiRequestHeader(info, c, req)
|
||||||
req.Header.Set("x-goog-api-key", info.ApiKey)
|
req.Set("x-goog-api-key", info.ApiKey)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -64,11 +64,11 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
|
|||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
|
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
|
||||||
return channel.DoApiRequest(a, c, info, requestBody)
|
return channel.DoApiRequest(a, c, info, requestBody)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
|
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
|
||||||
if info.IsStream {
|
if info.IsStream {
|
||||||
err, usage = GeminiChatStreamHandler(c, resp, info)
|
err, usage = GeminiChatStreamHandler(c, resp, info)
|
||||||
} else {
|
} else {
|
||||||
|
@ -37,9 +37,9 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
|||||||
return "", errors.New("invalid relay mode")
|
return "", errors.New("invalid relay mode")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
|
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
|
||||||
channel.SetupApiRequestHeader(info, c, req)
|
channel.SetupApiRequestHeader(info, c, req)
|
||||||
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", info.ApiKey))
|
req.Set("Authorization", fmt.Sprintf("Bearer %s", info.ApiKey))
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -47,7 +47,7 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, re
|
|||||||
return request, nil
|
return request, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
|
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
|
||||||
return channel.DoApiRequest(a, c, info, requestBody)
|
return channel.DoApiRequest(a, c, info, requestBody)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -55,7 +55,7 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
|
|||||||
return request, nil
|
return request, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
|
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
|
||||||
if info.RelayMode == constant.RelayModeRerank {
|
if info.RelayMode == constant.RelayModeRerank {
|
||||||
err, usage = jinaRerankHandler(c, resp)
|
err, usage = jinaRerankHandler(c, resp)
|
||||||
} else if info.RelayMode == constant.RelayModeEmbeddings {
|
} else if info.RelayMode == constant.RelayModeEmbeddings {
|
||||||
|
@ -37,7 +37,7 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
|
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
|
||||||
channel.SetupApiRequestHeader(info, c, req)
|
channel.SetupApiRequestHeader(info, c, req)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@ -58,11 +58,11 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
|
|||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
|
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
|
||||||
return channel.DoApiRequest(a, c, info, requestBody)
|
return channel.DoApiRequest(a, c, info, requestBody)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
|
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
|
||||||
if info.IsStream {
|
if info.IsStream {
|
||||||
err, usage = openai.OaiStreamHandler(c, resp, info)
|
err, usage = openai.OaiStreamHandler(c, resp, info)
|
||||||
} else {
|
} else {
|
||||||
|
@ -31,6 +31,13 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||||
|
if info.RelayMode == constant.RelayModeRealtime {
|
||||||
|
// trim https
|
||||||
|
baseUrl := strings.TrimPrefix(info.BaseUrl, "https://")
|
||||||
|
baseUrl = strings.TrimPrefix(baseUrl, "http://")
|
||||||
|
baseUrl = "wss://" + baseUrl
|
||||||
|
info.BaseUrl = baseUrl
|
||||||
|
}
|
||||||
switch info.ChannelType {
|
switch info.ChannelType {
|
||||||
case common.ChannelTypeAzure:
|
case common.ChannelTypeAzure:
|
||||||
// https://learn.microsoft.com/en-us/azure/cognitive-services/openai/chatgpt-quickstart?pivots=rest-api&tabs=command-line#rest-api
|
// https://learn.microsoft.com/en-us/azure/cognitive-services/openai/chatgpt-quickstart?pivots=rest-api&tabs=command-line#rest-api
|
||||||
@ -54,16 +61,19 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
|
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
|
||||||
channel.SetupApiRequestHeader(info, c, req)
|
channel.SetupApiRequestHeader(info, c, req)
|
||||||
if info.ChannelType == common.ChannelTypeAzure {
|
if info.ChannelType == common.ChannelTypeAzure {
|
||||||
req.Header.Set("api-key", info.ApiKey)
|
req.Set("api-key", info.ApiKey)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
if info.ChannelType == common.ChannelTypeOpenAI && "" != info.Organization {
|
if info.ChannelType == common.ChannelTypeOpenAI && "" != info.Organization {
|
||||||
req.Header.Set("OpenAI-Organization", info.Organization)
|
req.Set("OpenAI-Organization", info.Organization)
|
||||||
|
}
|
||||||
|
req.Set("Authorization", "Bearer "+info.ApiKey)
|
||||||
|
if info.RelayMode == constant.RelayModeRealtime {
|
||||||
|
req.Set("openai-beta", "realtime=v1")
|
||||||
}
|
}
|
||||||
req.Header.Set("Authorization", "Bearer "+info.ApiKey)
|
|
||||||
//if info.ChannelType == common.ChannelTypeOpenRouter {
|
//if info.ChannelType == common.ChannelTypeOpenRouter {
|
||||||
// req.Header.Set("HTTP-Referer", "https://github.com/songquanpeng/one-api")
|
// req.Header.Set("HTTP-Referer", "https://github.com/songquanpeng/one-api")
|
||||||
// req.Header.Set("X-Title", "One API")
|
// req.Header.Set("X-Title", "One API")
|
||||||
@ -131,16 +141,20 @@ func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInf
|
|||||||
return request, nil
|
return request, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
|
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
|
||||||
if info.RelayMode == constant.RelayModeAudioTranscription || info.RelayMode == constant.RelayModeAudioTranslation {
|
if info.RelayMode == constant.RelayModeAudioTranscription || info.RelayMode == constant.RelayModeAudioTranslation {
|
||||||
return channel.DoFormRequest(a, c, info, requestBody)
|
return channel.DoFormRequest(a, c, info, requestBody)
|
||||||
|
} else if info.RelayMode == constant.RelayModeRealtime {
|
||||||
|
return channel.DoWssRequest(a, c, info, requestBody)
|
||||||
} else {
|
} else {
|
||||||
return channel.DoApiRequest(a, c, info, requestBody)
|
return channel.DoApiRequest(a, c, info, requestBody)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
|
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
|
||||||
switch info.RelayMode {
|
switch info.RelayMode {
|
||||||
|
case constant.RelayModeRealtime:
|
||||||
|
err, usage = OpenaiRealtimeHandler(c, info)
|
||||||
case constant.RelayModeAudioSpeech:
|
case constant.RelayModeAudioSpeech:
|
||||||
err, usage = OpenaiTTSHandler(c, resp, info)
|
err, usage = OpenaiTTSHandler(c, resp, info)
|
||||||
case constant.RelayModeAudioTranslation:
|
case constant.RelayModeAudioTranslation:
|
||||||
|
@ -7,6 +7,7 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"github.com/bytedance/gopkg/util/gopool"
|
"github.com/bytedance/gopkg/util/gopool"
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/gorilla/websocket"
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
@ -373,3 +374,106 @@ func getTextFromJSON(body []byte) (string, error) {
|
|||||||
}
|
}
|
||||||
return whisperResponse.Text, nil
|
return whisperResponse.Text, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func OpenaiRealtimeHandler(c *gin.Context, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.RealtimeUsage) {
|
||||||
|
info.IsStream = true
|
||||||
|
clientConn := info.ClientWs
|
||||||
|
targetConn := info.TargetWs
|
||||||
|
|
||||||
|
clientClosed := make(chan struct{})
|
||||||
|
targetClosed := make(chan struct{})
|
||||||
|
sendChan := make(chan []byte, 100)
|
||||||
|
receiveChan := make(chan []byte, 100)
|
||||||
|
errChan := make(chan error, 2)
|
||||||
|
|
||||||
|
usage := &dto.RealtimeUsage{}
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-c.Done():
|
||||||
|
return
|
||||||
|
default:
|
||||||
|
_, message, err := clientConn.ReadMessage()
|
||||||
|
if err != nil {
|
||||||
|
if !websocket.IsCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway) {
|
||||||
|
errChan <- fmt.Errorf("error reading from client: %v", err)
|
||||||
|
}
|
||||||
|
close(clientClosed)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
err = service.WssString(c, targetConn, string(message))
|
||||||
|
if err != nil {
|
||||||
|
errChan <- fmt.Errorf("error writing to target: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
select {
|
||||||
|
case sendChan <- message:
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-c.Done():
|
||||||
|
return
|
||||||
|
default:
|
||||||
|
_, message, err := targetConn.ReadMessage()
|
||||||
|
if err != nil {
|
||||||
|
if !websocket.IsCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway) {
|
||||||
|
errChan <- fmt.Errorf("error reading from target: %v", err)
|
||||||
|
}
|
||||||
|
close(targetClosed)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
info.SetFirstResponseTime()
|
||||||
|
realtimeEvent := &dto.RealtimeEvent{}
|
||||||
|
err = json.Unmarshal(message, realtimeEvent)
|
||||||
|
if err != nil {
|
||||||
|
errChan <- fmt.Errorf("error unmarshalling message: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if realtimeEvent.Type == dto.RealtimeEventTypeResponseDone {
|
||||||
|
realtimeUsage := realtimeEvent.Response.Usage
|
||||||
|
if realtimeUsage != nil {
|
||||||
|
usage.TotalTokens += realtimeUsage.TotalTokens
|
||||||
|
usage.InputTokens += realtimeUsage.InputTokens
|
||||||
|
usage.OutputTokens += realtimeUsage.OutputTokens
|
||||||
|
usage.InputTokenDetails.AudioTokens += realtimeUsage.InputTokenDetails.AudioTokens
|
||||||
|
usage.InputTokenDetails.CachedTokens += realtimeUsage.InputTokenDetails.CachedTokens
|
||||||
|
usage.InputTokenDetails.TextTokens += realtimeUsage.InputTokenDetails.TextTokens
|
||||||
|
usage.OutputTokenDetails.AudioTokens += realtimeUsage.OutputTokenDetails.AudioTokens
|
||||||
|
usage.OutputTokenDetails.TextTokens += realtimeUsage.OutputTokenDetails.TextTokens
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
err = service.WssString(c, clientConn, string(message))
|
||||||
|
if err != nil {
|
||||||
|
errChan <- fmt.Errorf("error writing to client: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
select {
|
||||||
|
case receiveChan <- message:
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-clientClosed:
|
||||||
|
case <-targetClosed:
|
||||||
|
case <-errChan:
|
||||||
|
//return service.OpenAIErrorWrapper(err, "realtime_error", http.StatusInternalServerError), nil
|
||||||
|
case <-c.Done():
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, usage
|
||||||
|
}
|
||||||
|
@ -32,9 +32,9 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
|||||||
return fmt.Sprintf("%s/v1beta2/models/chat-bison-001:generateMessage", info.BaseUrl), nil
|
return fmt.Sprintf("%s/v1beta2/models/chat-bison-001:generateMessage", info.BaseUrl), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
|
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
|
||||||
channel.SetupApiRequestHeader(info, c, req)
|
channel.SetupApiRequestHeader(info, c, req)
|
||||||
req.Header.Set("x-goog-api-key", info.ApiKey)
|
req.Set("x-goog-api-key", info.ApiKey)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -49,11 +49,11 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
|
|||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
|
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
|
||||||
return channel.DoApiRequest(a, c, info, requestBody)
|
return channel.DoApiRequest(a, c, info, requestBody)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
|
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
|
||||||
if info.IsStream {
|
if info.IsStream {
|
||||||
var responseText string
|
var responseText string
|
||||||
err, responseText = palmStreamHandler(c, resp)
|
err, responseText = palmStreamHandler(c, resp)
|
||||||
|
@ -32,9 +32,9 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
|||||||
return fmt.Sprintf("%s/chat/completions", info.BaseUrl), nil
|
return fmt.Sprintf("%s/chat/completions", info.BaseUrl), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
|
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
|
||||||
channel.SetupApiRequestHeader(info, c, req)
|
channel.SetupApiRequestHeader(info, c, req)
|
||||||
req.Header.Set("Authorization", "Bearer "+info.ApiKey)
|
req.Set("Authorization", "Bearer "+info.ApiKey)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -52,11 +52,11 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
|
|||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
|
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
|
||||||
return channel.DoApiRequest(a, c, info, requestBody)
|
return channel.DoApiRequest(a, c, info, requestBody)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
|
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
|
||||||
if info.IsStream {
|
if info.IsStream {
|
||||||
err, usage = openai.OaiStreamHandler(c, resp, info)
|
err, usage = openai.OaiStreamHandler(c, resp, info)
|
||||||
} else {
|
} else {
|
||||||
|
@ -40,9 +40,9 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
|||||||
return "", errors.New("invalid relay mode")
|
return "", errors.New("invalid relay mode")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
|
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
|
||||||
channel.SetupApiRequestHeader(info, c, req)
|
channel.SetupApiRequestHeader(info, c, req)
|
||||||
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", info.ApiKey))
|
req.Set("Authorization", fmt.Sprintf("Bearer %s", info.ApiKey))
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -50,7 +50,7 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, re
|
|||||||
return request, nil
|
return request, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
|
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
|
||||||
return channel.DoApiRequest(a, c, info, requestBody)
|
return channel.DoApiRequest(a, c, info, requestBody)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -58,7 +58,7 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
|
|||||||
return request, nil
|
return request, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
|
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
|
||||||
switch info.RelayMode {
|
switch info.RelayMode {
|
||||||
case constant.RelayModeRerank:
|
case constant.RelayModeRerank:
|
||||||
err, usage = siliconflowRerankHandler(c, resp)
|
err, usage = siliconflowRerankHandler(c, resp)
|
||||||
|
@ -43,12 +43,12 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
|||||||
return fmt.Sprintf("%s/", info.BaseUrl), nil
|
return fmt.Sprintf("%s/", info.BaseUrl), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
|
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
|
||||||
channel.SetupApiRequestHeader(info, c, req)
|
channel.SetupApiRequestHeader(info, c, req)
|
||||||
req.Header.Set("Authorization", a.Sign)
|
req.Set("Authorization", a.Sign)
|
||||||
req.Header.Set("X-TC-Action", a.Action)
|
req.Set("X-TC-Action", a.Action)
|
||||||
req.Header.Set("X-TC-Version", a.Version)
|
req.Set("X-TC-Version", a.Version)
|
||||||
req.Header.Set("X-TC-Timestamp", strconv.FormatInt(a.Timestamp, 10))
|
req.Set("X-TC-Timestamp", strconv.FormatInt(a.Timestamp, 10))
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -73,11 +73,11 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
|
|||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
|
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
|
||||||
return channel.DoApiRequest(a, c, info, requestBody)
|
return channel.DoApiRequest(a, c, info, requestBody)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
|
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
|
||||||
if info.IsStream {
|
if info.IsStream {
|
||||||
var responseText string
|
var responseText string
|
||||||
err, responseText = tencentStreamHandler(c, resp)
|
err, responseText = tencentStreamHandler(c, resp)
|
||||||
|
@ -107,13 +107,13 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
|||||||
return "", errors.New("unsupported request mode")
|
return "", errors.New("unsupported request mode")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
|
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
|
||||||
channel.SetupApiRequestHeader(info, c, req)
|
channel.SetupApiRequestHeader(info, c, req)
|
||||||
accessToken, err := getAccessToken(a, info)
|
accessToken, err := getAccessToken(a, info)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
req.Header.Set("Authorization", "Bearer "+accessToken)
|
req.Set("Authorization", "Bearer "+accessToken)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -148,11 +148,11 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
|
|||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
|
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
|
||||||
return channel.DoApiRequest(a, c, info, requestBody)
|
return channel.DoApiRequest(a, c, info, requestBody)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
|
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
|
||||||
if info.IsStream {
|
if info.IsStream {
|
||||||
switch a.RequestMode {
|
switch a.RequestMode {
|
||||||
case RequestModeClaude:
|
case RequestModeClaude:
|
||||||
|
@ -33,7 +33,7 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
|||||||
return "", nil
|
return "", nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
|
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
|
||||||
channel.SetupApiRequestHeader(info, c, req)
|
channel.SetupApiRequestHeader(info, c, req)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@ -50,14 +50,14 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
|
|||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
|
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
|
||||||
// xunfei's request is not http request, so we don't need to do anything here
|
// xunfei's request is not http request, so we don't need to do anything here
|
||||||
dummyResp := &http.Response{}
|
dummyResp := &http.Response{}
|
||||||
dummyResp.StatusCode = http.StatusOK
|
dummyResp.StatusCode = http.StatusOK
|
||||||
return dummyResp, nil
|
return dummyResp, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
|
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
|
||||||
splits := strings.Split(info.ApiKey, "|")
|
splits := strings.Split(info.ApiKey, "|")
|
||||||
if len(splits) != 3 {
|
if len(splits) != 3 {
|
||||||
return nil, service.OpenAIErrorWrapper(errors.New("invalid auth"), "invalid_auth", http.StatusBadRequest)
|
return nil, service.OpenAIErrorWrapper(errors.New("invalid auth"), "invalid_auth", http.StatusBadRequest)
|
||||||
|
@ -35,10 +35,10 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
|||||||
return fmt.Sprintf("%s/api/paas/v3/model-api/%s/%s", info.BaseUrl, info.UpstreamModelName, method), nil
|
return fmt.Sprintf("%s/api/paas/v3/model-api/%s/%s", info.BaseUrl, info.UpstreamModelName, method), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
|
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
|
||||||
channel.SetupApiRequestHeader(info, c, req)
|
channel.SetupApiRequestHeader(info, c, req)
|
||||||
token := getZhipuToken(info.ApiKey)
|
token := getZhipuToken(info.ApiKey)
|
||||||
req.Header.Set("Authorization", token)
|
req.Set("Authorization", token)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -56,11 +56,11 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
|
|||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
|
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
|
||||||
return channel.DoApiRequest(a, c, info, requestBody)
|
return channel.DoApiRequest(a, c, info, requestBody)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
|
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
|
||||||
if info.IsStream {
|
if info.IsStream {
|
||||||
err, usage = zhipuStreamHandler(c, resp)
|
err, usage = zhipuStreamHandler(c, resp)
|
||||||
} else {
|
} else {
|
||||||
|
@ -32,10 +32,10 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
|||||||
return fmt.Sprintf("%s/api/paas/v4/chat/completions", info.BaseUrl), nil
|
return fmt.Sprintf("%s/api/paas/v4/chat/completions", info.BaseUrl), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
|
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
|
||||||
channel.SetupApiRequestHeader(info, c, req)
|
channel.SetupApiRequestHeader(info, c, req)
|
||||||
token := getZhipuToken(info.ApiKey)
|
token := getZhipuToken(info.ApiKey)
|
||||||
req.Header.Set("Authorization", token)
|
req.Set("Authorization", token)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -53,11 +53,11 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
|
|||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
|
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
|
||||||
return channel.DoApiRequest(a, c, info, requestBody)
|
return channel.DoApiRequest(a, c, info, requestBody)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
|
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
|
||||||
if info.IsStream {
|
if info.IsStream {
|
||||||
err, usage = openai.OaiStreamHandler(c, resp, info)
|
err, usage = openai.OaiStreamHandler(c, resp, info)
|
||||||
} else {
|
} else {
|
||||||
|
@ -2,6 +2,7 @@ package common
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/gorilla/websocket"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
"one-api/relay/constant"
|
"one-api/relay/constant"
|
||||||
"strings"
|
"strings"
|
||||||
@ -32,6 +33,14 @@ type RelayInfo struct {
|
|||||||
BaseUrl string
|
BaseUrl string
|
||||||
SupportStreamOptions bool
|
SupportStreamOptions bool
|
||||||
ShouldIncludeUsage bool
|
ShouldIncludeUsage bool
|
||||||
|
ClientWs *websocket.Conn
|
||||||
|
TargetWs *websocket.Conn
|
||||||
|
}
|
||||||
|
|
||||||
|
func GenRelayInfoWs(c *gin.Context, ws *websocket.Conn) *RelayInfo {
|
||||||
|
info := GenRelayInfo(c)
|
||||||
|
info.ClientWs = ws
|
||||||
|
return info
|
||||||
}
|
}
|
||||||
|
|
||||||
func GenRelayInfo(c *gin.Context) *RelayInfo {
|
func GenRelayInfo(c *gin.Context) *RelayInfo {
|
||||||
|
@ -122,19 +122,21 @@ func AudioHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return service.OpenAIErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)
|
return service.OpenAIErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)
|
||||||
}
|
}
|
||||||
|
|
||||||
statusCodeMappingStr := c.GetString("status_code_mapping")
|
statusCodeMappingStr := c.GetString("status_code_mapping")
|
||||||
|
|
||||||
|
var httpResp *http.Response
|
||||||
if resp != nil {
|
if resp != nil {
|
||||||
if resp.StatusCode != http.StatusOK {
|
httpResp = resp.(*http.Response)
|
||||||
|
if httpResp.StatusCode != http.StatusOK {
|
||||||
returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota)
|
returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota)
|
||||||
openaiErr := service.RelayErrorHandler(resp)
|
openaiErr := service.RelayErrorHandler(httpResp)
|
||||||
// reset status code 重置状态码
|
// reset status code 重置状态码
|
||||||
service.ResetStatusCode(openaiErr, statusCodeMappingStr)
|
service.ResetStatusCode(openaiErr, statusCodeMappingStr)
|
||||||
return openaiErr
|
return openaiErr
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
usage, openaiErr := adaptor.DoResponse(c, resp, relayInfo)
|
usage, openaiErr := adaptor.DoResponse(c, httpResp, relayInfo)
|
||||||
if openaiErr != nil {
|
if openaiErr != nil {
|
||||||
returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota)
|
returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota)
|
||||||
// reset status code 重置状态码
|
// reset status code 重置状态码
|
||||||
@ -142,7 +144,7 @@ func AudioHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
|
|||||||
return openaiErr
|
return openaiErr
|
||||||
}
|
}
|
||||||
|
|
||||||
postConsumeQuota(c, relayInfo, audioRequest.Model, usage, ratio, preConsumedQuota, userQuota, modelRatio, groupRatio, 0, false, "")
|
postConsumeQuota(c, relayInfo, audioRequest.Model, usage.(*dto.Usage), ratio, preConsumedQuota, userQuota, modelRatio, groupRatio, 0, false, "")
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
@ -149,22 +149,24 @@ func ImageHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode {
|
|||||||
requestBody = bytes.NewBuffer(jsonData)
|
requestBody = bytes.NewBuffer(jsonData)
|
||||||
|
|
||||||
statusCodeMappingStr := c.GetString("status_code_mapping")
|
statusCodeMappingStr := c.GetString("status_code_mapping")
|
||||||
|
|
||||||
resp, err := adaptor.DoRequest(c, relayInfo, requestBody)
|
resp, err := adaptor.DoRequest(c, relayInfo, requestBody)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return service.OpenAIErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)
|
return service.OpenAIErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)
|
||||||
}
|
}
|
||||||
|
var httpResp *http.Response
|
||||||
if resp != nil {
|
if resp != nil {
|
||||||
relayInfo.IsStream = relayInfo.IsStream || strings.HasPrefix(resp.Header.Get("Content-Type"), "text/event-stream")
|
httpResp = resp.(*http.Response)
|
||||||
if resp.StatusCode != http.StatusOK {
|
relayInfo.IsStream = relayInfo.IsStream || strings.HasPrefix(httpResp.Header.Get("Content-Type"), "text/event-stream")
|
||||||
openaiErr := service.RelayErrorHandler(resp)
|
if httpResp.StatusCode != http.StatusOK {
|
||||||
|
openaiErr := service.RelayErrorHandler(httpResp)
|
||||||
// reset status code 重置状态码
|
// reset status code 重置状态码
|
||||||
service.ResetStatusCode(openaiErr, statusCodeMappingStr)
|
service.ResetStatusCode(openaiErr, statusCodeMappingStr)
|
||||||
return openaiErr
|
return openaiErr
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
_, openaiErr := adaptor.DoResponse(c, resp, relayInfo)
|
_, openaiErr := adaptor.DoResponse(c, httpResp, relayInfo)
|
||||||
if openaiErr != nil {
|
if openaiErr != nil {
|
||||||
// reset status code 重置状态码
|
// reset status code 重置状态码
|
||||||
service.ResetStatusCode(openaiErr, statusCodeMappingStr)
|
service.ResetStatusCode(openaiErr, statusCodeMappingStr)
|
||||||
|
@ -180,30 +180,32 @@ func TextHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
|
|||||||
}
|
}
|
||||||
|
|
||||||
statusCodeMappingStr := c.GetString("status_code_mapping")
|
statusCodeMappingStr := c.GetString("status_code_mapping")
|
||||||
|
var httpResp *http.Response
|
||||||
resp, err := adaptor.DoRequest(c, relayInfo, requestBody)
|
resp, err := adaptor.DoRequest(c, relayInfo, requestBody)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return service.OpenAIErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)
|
return service.OpenAIErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)
|
||||||
}
|
}
|
||||||
|
|
||||||
if resp != nil {
|
if resp != nil {
|
||||||
relayInfo.IsStream = relayInfo.IsStream || strings.HasPrefix(resp.Header.Get("Content-Type"), "text/event-stream")
|
httpResp = resp.(*http.Response)
|
||||||
if resp.StatusCode != http.StatusOK {
|
relayInfo.IsStream = relayInfo.IsStream || strings.HasPrefix(httpResp.Header.Get("Content-Type"), "text/event-stream")
|
||||||
|
if httpResp.StatusCode != http.StatusOK {
|
||||||
returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota)
|
returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota)
|
||||||
openaiErr := service.RelayErrorHandler(resp)
|
openaiErr := service.RelayErrorHandler(httpResp)
|
||||||
// reset status code 重置状态码
|
// reset status code 重置状态码
|
||||||
service.ResetStatusCode(openaiErr, statusCodeMappingStr)
|
service.ResetStatusCode(openaiErr, statusCodeMappingStr)
|
||||||
return openaiErr
|
return openaiErr
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
usage, openaiErr := adaptor.DoResponse(c, resp, relayInfo)
|
usage, openaiErr := adaptor.DoResponse(c, httpResp, relayInfo)
|
||||||
if openaiErr != nil {
|
if openaiErr != nil {
|
||||||
returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota)
|
returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota)
|
||||||
// reset status code 重置状态码
|
// reset status code 重置状态码
|
||||||
service.ResetStatusCode(openaiErr, statusCodeMappingStr)
|
service.ResetStatusCode(openaiErr, statusCodeMappingStr)
|
||||||
return openaiErr
|
return openaiErr
|
||||||
}
|
}
|
||||||
postConsumeQuota(c, relayInfo, textRequest.Model, usage, ratio, preConsumedQuota, userQuota, modelRatio, groupRatio, modelPrice, getModelPriceSuccess, "")
|
postConsumeQuota(c, relayInfo, textRequest.Model, usage.(*dto.Usage), ratio, preConsumedQuota, userQuota, modelRatio, groupRatio, modelPrice, getModelPriceSuccess, "")
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -99,23 +99,26 @@ func RerankHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return service.OpenAIErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)
|
return service.OpenAIErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var httpResp *http.Response
|
||||||
if resp != nil {
|
if resp != nil {
|
||||||
if resp.StatusCode != http.StatusOK {
|
httpResp = resp.(*http.Response)
|
||||||
|
if httpResp.StatusCode != http.StatusOK {
|
||||||
returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota)
|
returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota)
|
||||||
openaiErr := service.RelayErrorHandler(resp)
|
openaiErr := service.RelayErrorHandler(httpResp)
|
||||||
// reset status code 重置状态码
|
// reset status code 重置状态码
|
||||||
service.ResetStatusCode(openaiErr, statusCodeMappingStr)
|
service.ResetStatusCode(openaiErr, statusCodeMappingStr)
|
||||||
return openaiErr
|
return openaiErr
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
usage, openaiErr := adaptor.DoResponse(c, resp, relayInfo)
|
usage, openaiErr := adaptor.DoResponse(c, httpResp, relayInfo)
|
||||||
if openaiErr != nil {
|
if openaiErr != nil {
|
||||||
returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota)
|
returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota)
|
||||||
// reset status code 重置状态码
|
// reset status code 重置状态码
|
||||||
service.ResetStatusCode(openaiErr, statusCodeMappingStr)
|
service.ResetStatusCode(openaiErr, statusCodeMappingStr)
|
||||||
return openaiErr
|
return openaiErr
|
||||||
}
|
}
|
||||||
postConsumeQuota(c, relayInfo, rerankRequest.Model, usage, ratio, preConsumedQuota, userQuota, modelRatio, groupRatio, modelPrice, success, "")
|
postConsumeQuota(c, relayInfo, rerankRequest.Model, usage.(*dto.Usage), ratio, preConsumedQuota, userQuota, modelRatio, groupRatio, modelPrice, success, "")
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
242
relay/websocket.go
Normal file
242
relay/websocket.go
Normal file
@ -0,0 +1,242 @@
|
|||||||
|
package relay
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/gorilla/websocket"
|
||||||
|
"math"
|
||||||
|
"net/http"
|
||||||
|
"one-api/common"
|
||||||
|
"one-api/dto"
|
||||||
|
"one-api/model"
|
||||||
|
relaycommon "one-api/relay/common"
|
||||||
|
"one-api/service"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
//func getAndValidateWssRequest(c *gin.Context, ws *websocket.Conn) (*dto.RealtimeEvent, error) {
|
||||||
|
// _, p, err := ws.ReadMessage()
|
||||||
|
// if err != nil {
|
||||||
|
// return nil, err
|
||||||
|
// }
|
||||||
|
// realtimeEvent := &dto.RealtimeEvent{}
|
||||||
|
// err = json.Unmarshal(p, realtimeEvent)
|
||||||
|
// if err != nil {
|
||||||
|
// return nil, err
|
||||||
|
// }
|
||||||
|
// // save the original request
|
||||||
|
// if realtimeEvent.Session == nil {
|
||||||
|
// return nil, errors.New("session object is nil")
|
||||||
|
// }
|
||||||
|
// c.Set("first_wss_request", p)
|
||||||
|
// return realtimeEvent, nil
|
||||||
|
//}
|
||||||
|
|
||||||
|
func WssHelper(c *gin.Context, ws *websocket.Conn) *dto.OpenAIErrorWithStatusCode {
|
||||||
|
relayInfo := relaycommon.GenRelayInfoWs(c, ws)
|
||||||
|
|
||||||
|
// get & validate textRequest 获取并验证文本请求
|
||||||
|
//realtimeEvent, err := getAndValidateWssRequest(c, ws)
|
||||||
|
//if err != nil {
|
||||||
|
// common.LogError(c, fmt.Sprintf("getAndValidateWssRequest failed: %s", err.Error()))
|
||||||
|
// return service.OpenAIErrorWrapperLocal(err, "invalid_text_request", http.StatusBadRequest)
|
||||||
|
//}
|
||||||
|
|
||||||
|
// map model name
|
||||||
|
modelMapping := c.GetString("model_mapping")
|
||||||
|
//isModelMapped := false
|
||||||
|
if modelMapping != "" && modelMapping != "{}" {
|
||||||
|
modelMap := make(map[string]string)
|
||||||
|
err := json.Unmarshal([]byte(modelMapping), &modelMap)
|
||||||
|
if err != nil {
|
||||||
|
return service.OpenAIErrorWrapperLocal(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
if modelMap[relayInfo.OriginModelName] != "" {
|
||||||
|
relayInfo.UpstreamModelName = modelMap[relayInfo.OriginModelName]
|
||||||
|
// set upstream model name
|
||||||
|
//isModelMapped = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
//relayInfo.UpstreamModelName = textRequest.Model
|
||||||
|
modelPrice, getModelPriceSuccess := common.GetModelPrice(relayInfo.UpstreamModelName, false)
|
||||||
|
groupRatio := common.GetGroupRatio(relayInfo.Group)
|
||||||
|
|
||||||
|
var preConsumedQuota int
|
||||||
|
var ratio float64
|
||||||
|
var modelRatio float64
|
||||||
|
//err := service.SensitiveWordsCheck(textRequest)
|
||||||
|
|
||||||
|
//if constant.ShouldCheckPromptSensitive() {
|
||||||
|
// err = checkRequestSensitive(textRequest, relayInfo)
|
||||||
|
// if err != nil {
|
||||||
|
// return service.OpenAIErrorWrapperLocal(err, "sensitive_words_detected", http.StatusBadRequest)
|
||||||
|
// }
|
||||||
|
//}
|
||||||
|
|
||||||
|
//promptTokens, err := getWssPromptTokens(realtimeEvent, relayInfo)
|
||||||
|
//// count messages token error 计算promptTokens错误
|
||||||
|
//if err != nil {
|
||||||
|
// return service.OpenAIErrorWrapper(err, "count_token_messages_failed", http.StatusInternalServerError)
|
||||||
|
//}
|
||||||
|
//
|
||||||
|
if !getModelPriceSuccess {
|
||||||
|
preConsumedTokens := common.PreConsumedQuota
|
||||||
|
//if realtimeEvent.Session.MaxResponseOutputTokens != 0 {
|
||||||
|
// preConsumedTokens = promptTokens + int(realtimeEvent.Session.MaxResponseOutputTokens)
|
||||||
|
//}
|
||||||
|
modelRatio = common.GetModelRatio(relayInfo.UpstreamModelName)
|
||||||
|
ratio = modelRatio * groupRatio
|
||||||
|
preConsumedQuota = int(float64(preConsumedTokens) * ratio)
|
||||||
|
} else {
|
||||||
|
preConsumedQuota = int(modelPrice * common.QuotaPerUnit * groupRatio)
|
||||||
|
}
|
||||||
|
|
||||||
|
// pre-consume quota 预消耗配额
|
||||||
|
preConsumedQuota, userQuota, openaiErr := preConsumeQuota(c, preConsumedQuota, relayInfo)
|
||||||
|
if openaiErr != nil {
|
||||||
|
return openaiErr
|
||||||
|
}
|
||||||
|
|
||||||
|
adaptor := GetAdaptor(relayInfo.ApiType)
|
||||||
|
if adaptor == nil {
|
||||||
|
return service.OpenAIErrorWrapperLocal(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), "invalid_api_type", http.StatusBadRequest)
|
||||||
|
}
|
||||||
|
adaptor.Init(relayInfo)
|
||||||
|
//var requestBody io.Reader
|
||||||
|
//firstWssRequest, _ := c.Get("first_wss_request")
|
||||||
|
//requestBody = bytes.NewBuffer(firstWssRequest.([]byte))
|
||||||
|
|
||||||
|
statusCodeMappingStr := c.GetString("status_code_mapping")
|
||||||
|
resp, err := adaptor.DoRequest(c, relayInfo, nil)
|
||||||
|
if err != nil {
|
||||||
|
return service.OpenAIErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
|
||||||
|
if resp != nil {
|
||||||
|
relayInfo.TargetWs = resp.(*websocket.Conn)
|
||||||
|
defer relayInfo.TargetWs.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
usage, openaiErr := adaptor.DoResponse(c, nil, relayInfo)
|
||||||
|
if openaiErr != nil {
|
||||||
|
returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota)
|
||||||
|
// reset status code 重置状态码
|
||||||
|
service.ResetStatusCode(openaiErr, statusCodeMappingStr)
|
||||||
|
return openaiErr
|
||||||
|
}
|
||||||
|
postWssConsumeQuota(c, relayInfo, relayInfo.UpstreamModelName, usage.(*dto.RealtimeUsage), ratio, preConsumedQuota, userQuota, modelRatio, groupRatio, modelPrice, getModelPriceSuccess, "")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func postWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, modelName string,
|
||||||
|
usage *dto.RealtimeUsage, ratio float64, preConsumedQuota int, userQuota int, modelRatio float64,
|
||||||
|
groupRatio float64,
|
||||||
|
modelPrice float64, usePrice bool, extraContent string) {
|
||||||
|
|
||||||
|
useTimeSeconds := time.Now().Unix() - relayInfo.StartTime.Unix()
|
||||||
|
textInputTokens := usage.InputTokenDetails.TextTokens
|
||||||
|
textOutTokens := usage.OutputTokenDetails.TextTokens
|
||||||
|
|
||||||
|
audioInputTokens := usage.InputTokenDetails.AudioTokens
|
||||||
|
audioOutTokens := usage.OutputTokenDetails.AudioTokens
|
||||||
|
|
||||||
|
tokenName := ctx.GetString("token_name")
|
||||||
|
completionRatio := common.GetCompletionRatio(modelName)
|
||||||
|
audioRatio := common.GetAudioRatio(relayInfo.UpstreamModelName)
|
||||||
|
audioCompletionRatio := common.GetAudioCompletionRatio(modelName)
|
||||||
|
|
||||||
|
quota := 0
|
||||||
|
if !usePrice {
|
||||||
|
quota = textInputTokens + int(math.Round(float64(textOutTokens)*completionRatio))
|
||||||
|
quota += int(math.Round(float64(audioInputTokens)*audioRatio)) + int(math.Round(float64(audioOutTokens)*completionRatio*audioCompletionRatio))
|
||||||
|
|
||||||
|
quota = int(math.Round(float64(quota) * ratio))
|
||||||
|
if ratio != 0 && quota <= 0 {
|
||||||
|
quota = 1
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
quota = int(modelPrice * common.QuotaPerUnit * groupRatio)
|
||||||
|
}
|
||||||
|
totalTokens := usage.TotalTokens
|
||||||
|
var logContent string
|
||||||
|
if !usePrice {
|
||||||
|
logContent = fmt.Sprintf("模型倍率 %.2f,补全倍率 %.2f,音频倍率 %.2f,音频补全倍率 %.2f,分组倍率 %.2f", modelRatio, completionRatio, audioRatio, audioCompletionRatio, groupRatio)
|
||||||
|
} else {
|
||||||
|
logContent = fmt.Sprintf("模型价格 %.2f,分组倍率 %.2f", modelPrice, groupRatio)
|
||||||
|
}
|
||||||
|
|
||||||
|
// record all the consume log even if quota is 0
|
||||||
|
if totalTokens == 0 {
|
||||||
|
// in this case, must be some error happened
|
||||||
|
// we cannot just return, because we may have to return the pre-consumed quota
|
||||||
|
quota = 0
|
||||||
|
logContent += fmt.Sprintf("(可能是上游超时)")
|
||||||
|
common.LogError(ctx, fmt.Sprintf("total tokens is 0, cannot consume quota, userId %d, channelId %d, "+
|
||||||
|
"tokenId %d, model %s, pre-consumed quota %d", relayInfo.UserId, relayInfo.ChannelId, relayInfo.TokenId, modelName, preConsumedQuota))
|
||||||
|
} else {
|
||||||
|
//if sensitiveResp != nil {
|
||||||
|
// logContent += fmt.Sprintf(",敏感词:%s", strings.Join(sensitiveResp.SensitiveWords, ", "))
|
||||||
|
//}
|
||||||
|
quotaDelta := quota - preConsumedQuota
|
||||||
|
if quotaDelta != 0 {
|
||||||
|
err := model.PostConsumeTokenQuota(relayInfo, userQuota, quotaDelta, preConsumedQuota, true)
|
||||||
|
if err != nil {
|
||||||
|
common.LogError(ctx, "error consuming token remain quota: "+err.Error())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
err := model.CacheUpdateUserQuota(relayInfo.UserId)
|
||||||
|
if err != nil {
|
||||||
|
common.LogError(ctx, "error update user quota cache: "+err.Error())
|
||||||
|
}
|
||||||
|
model.UpdateUserUsedQuotaAndRequestCount(relayInfo.UserId, quota)
|
||||||
|
model.UpdateChannelUsedQuota(relayInfo.ChannelId, quota)
|
||||||
|
}
|
||||||
|
|
||||||
|
logModel := modelName
|
||||||
|
if strings.HasPrefix(logModel, "gpt-4-gizmo") {
|
||||||
|
logModel = "gpt-4-gizmo-*"
|
||||||
|
logContent += fmt.Sprintf(",模型 %s", modelName)
|
||||||
|
}
|
||||||
|
if strings.HasPrefix(logModel, "gpt-4o-gizmo") {
|
||||||
|
logModel = "gpt-4o-gizmo-*"
|
||||||
|
logContent += fmt.Sprintf(",模型 %s", modelName)
|
||||||
|
}
|
||||||
|
if extraContent != "" {
|
||||||
|
logContent += ", " + extraContent
|
||||||
|
}
|
||||||
|
other := service.GenerateWssOtherInfo(ctx, relayInfo, usage, modelRatio, groupRatio, completionRatio, modelPrice)
|
||||||
|
model.RecordConsumeLog(ctx, relayInfo.UserId, relayInfo.ChannelId, usage.InputTokens, usage.OutputTokens, logModel,
|
||||||
|
tokenName, quota, logContent, relayInfo.TokenId, userQuota, int(useTimeSeconds), relayInfo.IsStream, other)
|
||||||
|
|
||||||
|
//if quota != 0 {
|
||||||
|
//
|
||||||
|
//}
|
||||||
|
}
|
||||||
|
|
||||||
|
func getWssPromptTokens(textRequest *dto.RealtimeEvent, info *relaycommon.RelayInfo) (int, error) {
|
||||||
|
var promptTokens int
|
||||||
|
var err error
|
||||||
|
switch info.RelayMode {
|
||||||
|
default:
|
||||||
|
promptTokens, err = service.CountTokenRealtime(*textRequest, info.UpstreamModelName)
|
||||||
|
}
|
||||||
|
info.PromptTokens = promptTokens
|
||||||
|
return promptTokens, err
|
||||||
|
}
|
||||||
|
|
||||||
|
//func checkWssRequestSensitive(textRequest *dto.GeneralOpenAIRequest, info *relaycommon.RelayInfo) error {
|
||||||
|
// var err error
|
||||||
|
// switch info.RelayMode {
|
||||||
|
// case relayconstant.RelayModeChatCompletions:
|
||||||
|
// err = service.CheckSensitiveMessages(textRequest.Messages)
|
||||||
|
// case relayconstant.RelayModeCompletions:
|
||||||
|
// err = service.CheckSensitiveInput(textRequest.Prompt)
|
||||||
|
// case relayconstant.RelayModeModerations:
|
||||||
|
// err = service.CheckSensitiveInput(textRequest.Input)
|
||||||
|
// case relayconstant.RelayModeEmbeddings:
|
||||||
|
// err = service.CheckSensitiveInput(textRequest.Input)
|
||||||
|
// }
|
||||||
|
// return err
|
||||||
|
//}
|
@ -2,6 +2,7 @@ package service
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
|
"one-api/dto"
|
||||||
relaycommon "one-api/relay/common"
|
relaycommon "one-api/relay/common"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -17,3 +18,13 @@ func GenerateTextOtherInfo(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, m
|
|||||||
other["admin_info"] = adminInfo
|
other["admin_info"] = adminInfo
|
||||||
return other
|
return other
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func GenerateWssOtherInfo(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage *dto.RealtimeUsage, modelRatio, groupRatio, completionRatio, modelPrice float64) map[string]interface{} {
|
||||||
|
info := GenerateTextOtherInfo(ctx, relayInfo, modelRatio, groupRatio, completionRatio, modelPrice)
|
||||||
|
info["ws"] = true
|
||||||
|
info["audio_input"] = usage.InputTokenDetails.AudioTokens
|
||||||
|
info["audio_output"] = usage.OutputTokenDetails.AudioTokens
|
||||||
|
info["text_input"] = usage.InputTokenDetails.TextTokens
|
||||||
|
info["text_output"] = usage.OutputTokenDetails.TextTokens
|
||||||
|
return info
|
||||||
|
}
|
||||||
|
@ -43,11 +43,25 @@ func Done(c *gin.Context) {
|
|||||||
_ = StringData(c, "[DONE]")
|
_ = StringData(c, "[DONE]")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func WssString(c *gin.Context, ws *websocket.Conn, str string) error {
|
||||||
|
if ws == nil {
|
||||||
|
common.LogError(c, "websocket connection is nil")
|
||||||
|
return errors.New("websocket connection is nil")
|
||||||
|
}
|
||||||
|
common.LogInfo(c, fmt.Sprintf("sending message: %s", str))
|
||||||
|
return ws.WriteMessage(1, []byte(str))
|
||||||
|
}
|
||||||
|
|
||||||
func WssObject(c *gin.Context, ws *websocket.Conn, object interface{}) error {
|
func WssObject(c *gin.Context, ws *websocket.Conn, object interface{}) error {
|
||||||
jsonData, err := json.Marshal(object)
|
jsonData, err := json.Marshal(object)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("error marshalling object: %w", err)
|
return fmt.Errorf("error marshalling object: %w", err)
|
||||||
}
|
}
|
||||||
|
if ws == nil {
|
||||||
|
common.LogError(c, "websocket connection is nil")
|
||||||
|
return errors.New("websocket connection is nil")
|
||||||
|
}
|
||||||
|
common.LogInfo(c, fmt.Sprintf("sending message: %s", jsonData))
|
||||||
return ws.WriteMessage(1, jsonData)
|
return ws.WriteMessage(1, jsonData)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -191,6 +191,45 @@ func CountTokenChatRequest(request dto.GeneralOpenAIRequest, model string) (int,
|
|||||||
return tkm, nil
|
return tkm, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func CountTokenRealtime(request dto.RealtimeEvent, model string) (int, error) {
|
||||||
|
tkm := 0
|
||||||
|
ratio := 1
|
||||||
|
if request.Session != nil {
|
||||||
|
msgTokens, err := CountTokenText(request.Session.Instructions, model)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
ratio = len(request.Session.Modalities)
|
||||||
|
tkm += msgTokens
|
||||||
|
if request.Session.Tools != nil {
|
||||||
|
toolsData, _ := json.Marshal(request.Session.Tools)
|
||||||
|
var openaiTools []dto.OpenAITools
|
||||||
|
err := json.Unmarshal(toolsData, &openaiTools)
|
||||||
|
if err != nil {
|
||||||
|
return 0, errors.New(fmt.Sprintf("count_tools_token_fail: %s", err.Error()))
|
||||||
|
}
|
||||||
|
countStr := ""
|
||||||
|
for _, tool := range openaiTools {
|
||||||
|
countStr = tool.Function.Name
|
||||||
|
if tool.Function.Description != "" {
|
||||||
|
countStr += tool.Function.Description
|
||||||
|
}
|
||||||
|
if tool.Function.Parameters != nil {
|
||||||
|
countStr += fmt.Sprintf("%v", tool.Function.Parameters)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
toolTokens, err := CountTokenInput(countStr, model)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
tkm += 8
|
||||||
|
tkm += toolTokens
|
||||||
|
}
|
||||||
|
}
|
||||||
|
tkm *= ratio
|
||||||
|
return tkm, nil
|
||||||
|
}
|
||||||
|
|
||||||
func CountTokenMessages(messages []dto.Message, model string, stream bool) (int, error) {
|
func CountTokenMessages(messages []dto.Message, model string, stream bool) (int, error) {
|
||||||
//recover when panic
|
//recover when panic
|
||||||
tokenEncoder := getTokenEncoder(model)
|
tokenEncoder := getTokenEncoder(model)
|
||||||
|
@ -11,7 +11,7 @@ import {
|
|||||||
|
|
||||||
import {
|
import {
|
||||||
Avatar,
|
Avatar,
|
||||||
Button,
|
Button, Descriptions,
|
||||||
Form,
|
Form,
|
||||||
Layout,
|
Layout,
|
||||||
Modal,
|
Modal,
|
||||||
@ -20,7 +20,7 @@ import {
|
|||||||
Spin,
|
Spin,
|
||||||
Table,
|
Table,
|
||||||
Tag,
|
Tag,
|
||||||
Tooltip,
|
Tooltip
|
||||||
} from '@douyinfe/semi-ui';
|
} from '@douyinfe/semi-ui';
|
||||||
import { ITEMS_PER_PAGE } from '../constants';
|
import { ITEMS_PER_PAGE } from '../constants';
|
||||||
import {
|
import {
|
||||||
@ -336,33 +336,33 @@ const LogsTable = () => {
|
|||||||
);
|
);
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
// {
|
||||||
title: '重试',
|
// title: '重试',
|
||||||
dataIndex: 'retry',
|
// dataIndex: 'retry',
|
||||||
className: isAdmin() ? 'tableShow' : 'tableHiddle',
|
// className: isAdmin() ? 'tableShow' : 'tableHiddle',
|
||||||
render: (text, record, index) => {
|
// render: (text, record, index) => {
|
||||||
let content = '渠道:' + record.channel;
|
// let content = '渠道:' + record.channel;
|
||||||
if (record.other !== '') {
|
// if (record.other !== '') {
|
||||||
let other = JSON.parse(record.other);
|
// let other = JSON.parse(record.other);
|
||||||
if (other === null) {
|
// if (other === null) {
|
||||||
return <></>;
|
// return <></>;
|
||||||
}
|
// }
|
||||||
if (other.admin_info !== undefined) {
|
// if (other.admin_info !== undefined) {
|
||||||
if (
|
// if (
|
||||||
other.admin_info.use_channel !== null &&
|
// other.admin_info.use_channel !== null &&
|
||||||
other.admin_info.use_channel !== undefined &&
|
// other.admin_info.use_channel !== undefined &&
|
||||||
other.admin_info.use_channel !== ''
|
// other.admin_info.use_channel !== ''
|
||||||
) {
|
// ) {
|
||||||
// channel id array
|
// // channel id array
|
||||||
let useChannel = other.admin_info.use_channel;
|
// let useChannel = other.admin_info.use_channel;
|
||||||
let useChannelStr = useChannel.join('->');
|
// let useChannelStr = useChannel.join('->');
|
||||||
content = `渠道:${useChannelStr}`;
|
// content = `渠道:${useChannelStr}`;
|
||||||
}
|
// }
|
||||||
}
|
// }
|
||||||
}
|
// }
|
||||||
return isAdminUser ? <div>{content}</div> : <></>;
|
// return isAdminUser ? <div>{content}</div> : <></>;
|
||||||
},
|
// },
|
||||||
},
|
// },
|
||||||
{
|
{
|
||||||
title: '详情',
|
title: '详情',
|
||||||
dataIndex: 'content',
|
dataIndex: 'content',
|
||||||
@ -409,6 +409,7 @@ const LogsTable = () => {
|
|||||||
];
|
];
|
||||||
|
|
||||||
const [logs, setLogs] = useState([]);
|
const [logs, setLogs] = useState([]);
|
||||||
|
const [expandData, setExpandData] = useState({});
|
||||||
const [showStat, setShowStat] = useState(false);
|
const [showStat, setShowStat] = useState(false);
|
||||||
const [loading, setLoading] = useState(false);
|
const [loading, setLoading] = useState(false);
|
||||||
const [loadingStat, setLoadingStat] = useState(false);
|
const [loadingStat, setLoadingStat] = useState(false);
|
||||||
@ -512,10 +513,54 @@ const LogsTable = () => {
|
|||||||
};
|
};
|
||||||
|
|
||||||
const setLogsFormat = (logs) => {
|
const setLogsFormat = (logs) => {
|
||||||
|
let expandDatesLocal = {};
|
||||||
for (let i = 0; i < logs.length; i++) {
|
for (let i = 0; i < logs.length; i++) {
|
||||||
logs[i].timestamp2string = timestamp2string(logs[i].created_at);
|
logs[i].timestamp2string = timestamp2string(logs[i].created_at);
|
||||||
logs[i].key = '' + logs[i].id;
|
logs[i].key = '' + logs[i].id;
|
||||||
|
let other = getLogOther(logs[i].other);
|
||||||
|
let expandDataLocal = [];
|
||||||
|
if (isAdmin()) {
|
||||||
|
let content = '渠道:' + logs[i].channel;
|
||||||
|
if (other.admin_info !== undefined) {
|
||||||
|
if (
|
||||||
|
other.admin_info.use_channel !== null &&
|
||||||
|
other.admin_info.use_channel !== undefined &&
|
||||||
|
other.admin_info.use_channel !== ''
|
||||||
|
) {
|
||||||
|
// channel id array
|
||||||
|
let useChannel = other.admin_info.use_channel;
|
||||||
|
let useChannelStr = useChannel.join('->');
|
||||||
|
content = `渠道:${useChannelStr}`;
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
expandDataLocal.push({
|
||||||
|
key: '重试',
|
||||||
|
value: content,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
if (other.ws) {
|
||||||
|
expandDataLocal.push({
|
||||||
|
key: '语音输入',
|
||||||
|
value: other.audio_input,
|
||||||
|
});
|
||||||
|
expandDataLocal.push({
|
||||||
|
key: '语音输出',
|
||||||
|
value: other.audio_output,
|
||||||
|
});
|
||||||
|
expandDataLocal.push({
|
||||||
|
key: '文字输入',
|
||||||
|
value: other.text_input,
|
||||||
|
});
|
||||||
|
expandDataLocal.push({
|
||||||
|
key: '文字输出',
|
||||||
|
value: other.text_output,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
expandDatesLocal[logs[i].key] = expandDataLocal;
|
||||||
|
}
|
||||||
|
console.log(expandDatesLocal);
|
||||||
|
setExpandData(expandDatesLocal);
|
||||||
|
|
||||||
setLogs(logs);
|
setLogs(logs);
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -588,6 +633,10 @@ const LogsTable = () => {
|
|||||||
handleEyeClick();
|
handleEyeClick();
|
||||||
}, []);
|
}, []);
|
||||||
|
|
||||||
|
const expandRowRender = (record, index) => {
|
||||||
|
return <Descriptions align="justify" data={expandData[record.key]} />;
|
||||||
|
};
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<>
|
<>
|
||||||
<Layout>
|
<Layout>
|
||||||
@ -686,7 +735,9 @@ const LogsTable = () => {
|
|||||||
<Table
|
<Table
|
||||||
style={{ marginTop: 5 }}
|
style={{ marginTop: 5 }}
|
||||||
columns={columns}
|
columns={columns}
|
||||||
|
expandedRowRender={expandRowRender}
|
||||||
dataSource={logs}
|
dataSource={logs}
|
||||||
|
rowKey="key"
|
||||||
pagination={{
|
pagination={{
|
||||||
currentPage: activePage,
|
currentPage: activePage,
|
||||||
pageSize: pageSize,
|
pageSize: pageSize,
|
||||||
|
Loading…
Reference in New Issue
Block a user