mirror of
				https://github.com/linux-do/new-api.git
				synced 2025-11-04 13:23:42 +08:00 
			
		
		
		
	feat: realtime
(cherry picked from commit a5529df3e1a4c08a120e8c05203a7d885b0fe8d8)
This commit is contained in:
		
				
					committed by
					
						
						CalciumIon
					
				
			
			
				
	
			
			
			
						parent
						
							e3c85572d4
						
					
				
				
					commit
					33af069fae
				
			@@ -421,6 +421,20 @@ func GetCompletionRatio(name string) float64 {
 | 
			
		||||
	return 1
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func GetAudioRatio(name string) float64 {
 | 
			
		||||
	if strings.HasPrefix(name, "gpt-4o-realtime") {
 | 
			
		||||
		return 20
 | 
			
		||||
	}
 | 
			
		||||
	return 20
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func GetAudioCompletionRatio(name string) float64 {
 | 
			
		||||
	if strings.HasPrefix(name, "gpt-4o-realtime") {
 | 
			
		||||
		return 10
 | 
			
		||||
	}
 | 
			
		||||
	return 10
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func GetCompletionRatioMap() map[string]float64 {
 | 
			
		||||
	if CompletionRatio == nil {
 | 
			
		||||
		CompletionRatio = defaultCompletionRatio
 | 
			
		||||
 
 | 
			
		||||
@@ -102,17 +102,22 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err, nil
 | 
			
		||||
	}
 | 
			
		||||
	if resp != nil && resp.StatusCode != http.StatusOK {
 | 
			
		||||
		err := service.RelayErrorHandler(resp)
 | 
			
		||||
		return fmt.Errorf("status code %d: %s", resp.StatusCode, err.Error.Message), err
 | 
			
		||||
	var httpResp *http.Response
 | 
			
		||||
	if resp != nil {
 | 
			
		||||
		httpResp = resp.(*http.Response)
 | 
			
		||||
		if httpResp.StatusCode != http.StatusOK {
 | 
			
		||||
			err := service.RelayErrorHandler(httpResp)
 | 
			
		||||
			return fmt.Errorf("status code %d: %s", httpResp.StatusCode, err.Error.Message), err
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	usage, respErr := adaptor.DoResponse(c, resp, meta)
 | 
			
		||||
	usageA, respErr := adaptor.DoResponse(c, httpResp, meta)
 | 
			
		||||
	if respErr != nil {
 | 
			
		||||
		return fmt.Errorf("%s", respErr.Error.Message), respErr
 | 
			
		||||
	}
 | 
			
		||||
	if usage == nil {
 | 
			
		||||
	if usageA == nil {
 | 
			
		||||
		return errors.New("usage is nil"), nil
 | 
			
		||||
	}
 | 
			
		||||
	usage := usageA.(dto.Usage)
 | 
			
		||||
	result := w.Result()
 | 
			
		||||
	respBody, err := io.ReadAll(result.Body)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
 
 | 
			
		||||
@@ -39,6 +39,15 @@ func relayHandler(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode
 | 
			
		||||
	return err
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func wsHandler(c *gin.Context, ws *websocket.Conn, relayMode int) *dto.OpenAIErrorWithStatusCode {
 | 
			
		||||
	var err *dto.OpenAIErrorWithStatusCode
 | 
			
		||||
	switch relayMode {
 | 
			
		||||
	default:
 | 
			
		||||
		err = relay.TextHelper(c)
 | 
			
		||||
	}
 | 
			
		||||
	return err
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func Playground(c *gin.Context) {
 | 
			
		||||
	var openaiErr *dto.OpenAIErrorWithStatusCode
 | 
			
		||||
 | 
			
		||||
@@ -143,12 +152,16 @@ var upgrader = websocket.Upgrader{
 | 
			
		||||
 | 
			
		||||
func WssRelay(c *gin.Context) {
 | 
			
		||||
	// 将 HTTP 连接升级为 WebSocket 连接
 | 
			
		||||
 | 
			
		||||
	ws, err := upgrader.Upgrade(c.Writer, c.Request, nil)
 | 
			
		||||
	defer ws.Close()
 | 
			
		||||
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		openaiErr := service.OpenAIErrorWrapper(err, "get_channel_failed", http.StatusInternalServerError)
 | 
			
		||||
		service.WssError(c, ws, openaiErr.Error)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	relayMode := constant.Path2RelayMode(c.Request.URL.Path)
 | 
			
		||||
	requestId := c.GetString(common.RequestIdKey)
 | 
			
		||||
	group := c.GetString("group")
 | 
			
		||||
@@ -164,7 +177,7 @@ func WssRelay(c *gin.Context) {
 | 
			
		||||
			break
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		openaiErr = relayRequest(c, relayMode, channel)
 | 
			
		||||
		openaiErr = wssRequest(c, ws, relayMode, channel)
 | 
			
		||||
 | 
			
		||||
		if openaiErr == nil {
 | 
			
		||||
			return // 成功处理请求,直接返回
 | 
			
		||||
@@ -198,6 +211,13 @@ func relayRequest(c *gin.Context, relayMode int, channel *model.Channel) *dto.Op
 | 
			
		||||
	return relayHandler(c, relayMode)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func wssRequest(c *gin.Context, ws *websocket.Conn, relayMode int, channel *model.Channel) *dto.OpenAIErrorWithStatusCode {
 | 
			
		||||
	addUsedChannel(c, channel.Id)
 | 
			
		||||
	requestBody, _ := common.GetRequestBody(c)
 | 
			
		||||
	c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
 | 
			
		||||
	return relay.WssHelper(c, ws)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func addUsedChannel(c *gin.Context, channelId int) {
 | 
			
		||||
	useChannel := c.GetStringSlice("use_channel")
 | 
			
		||||
	useChannel = append(useChannel, fmt.Sprintf("%d", channelId))
 | 
			
		||||
 
 | 
			
		||||
@@ -7,13 +7,41 @@ const (
 | 
			
		||||
	RealtimeEventTypeResponseCreate     = "response.create"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
const (
 | 
			
		||||
	RealtimeEventTypeResponseDone = "response.done"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type RealtimeEvent struct {
 | 
			
		||||
	EventId string `json:"event_id"`
 | 
			
		||||
	Type    string `json:"type"`
 | 
			
		||||
	//PreviousItemId string `json:"previous_item_id"`
 | 
			
		||||
	Session *RealtimeSession `json:"session,omitempty"`
 | 
			
		||||
	Item    *RealtimeItem    `json:"item,omitempty"`
 | 
			
		||||
	Error   *OpenAIError     `json:"error,omitempty"`
 | 
			
		||||
	Session  *RealtimeSession  `json:"session,omitempty"`
 | 
			
		||||
	Item     *RealtimeItem     `json:"item,omitempty"`
 | 
			
		||||
	Error    *OpenAIError      `json:"error,omitempty"`
 | 
			
		||||
	Response *RealtimeResponse `json:"response,omitempty"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type RealtimeResponse struct {
 | 
			
		||||
	Usage *RealtimeUsage `json:"usage"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type RealtimeUsage struct {
 | 
			
		||||
	TotalTokens        int                `json:"total_tokens"`
 | 
			
		||||
	InputTokens        int                `json:"input_tokens"`
 | 
			
		||||
	OutputTokens       int                `json:"output_tokens"`
 | 
			
		||||
	InputTokenDetails  InputTokenDetails  `json:"input_token_details"`
 | 
			
		||||
	OutputTokenDetails OutputTokenDetails `json:"output_token_details"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type InputTokenDetails struct {
 | 
			
		||||
	CachedTokens int `json:"cached_tokens"`
 | 
			
		||||
	TextTokens   int `json:"text_tokens"`
 | 
			
		||||
	AudioTokens  int `json:"audio_tokens"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type OutputTokenDetails struct {
 | 
			
		||||
	TextTokens  int `json:"text_tokens"`
 | 
			
		||||
	AudioTokens int `json:"audio_tokens"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type RealtimeSession struct {
 | 
			
		||||
@@ -27,7 +55,7 @@ type RealtimeSession struct {
 | 
			
		||||
	Tools                   []RealTimeTool          `json:"tools"`
 | 
			
		||||
	ToolChoice              string                  `json:"tool_choice"`
 | 
			
		||||
	Temperature             float64                 `json:"temperature"`
 | 
			
		||||
	MaxResponseOutputTokens int                     `json:"max_response_output_tokens"`
 | 
			
		||||
	//MaxResponseOutputTokens int                     `json:"max_response_output_tokens"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type InputAudioTranscription struct {
 | 
			
		||||
@@ -42,14 +70,14 @@ type RealTimeTool struct {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type RealtimeItem struct {
 | 
			
		||||
	Id        string          `json:"id"`
 | 
			
		||||
	Type      string          `json:"type"`
 | 
			
		||||
	Status    string          `json:"status"`
 | 
			
		||||
	Role      string          `json:"role"`
 | 
			
		||||
	Content   RealtimeContent `json:"content"`
 | 
			
		||||
	Name      *string         `json:"name,omitempty"`
 | 
			
		||||
	ToolCalls any             `json:"tool_calls,omitempty"`
 | 
			
		||||
	CallId    string          `json:"call_id,omitempty"`
 | 
			
		||||
	Id        string            `json:"id"`
 | 
			
		||||
	Type      string            `json:"type"`
 | 
			
		||||
	Status    string            `json:"status"`
 | 
			
		||||
	Role      string            `json:"role"`
 | 
			
		||||
	Content   []RealtimeContent `json:"content"`
 | 
			
		||||
	Name      *string           `json:"name,omitempty"`
 | 
			
		||||
	ToolCalls any               `json:"tool_calls,omitempty"`
 | 
			
		||||
	CallId    string            `json:"call_id,omitempty"`
 | 
			
		||||
}
 | 
			
		||||
type RealtimeContent struct {
 | 
			
		||||
	Type       string `json:"type"`
 | 
			
		||||
 
 | 
			
		||||
@@ -155,8 +155,27 @@ func RootAuth() func(c *gin.Context) {
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func WssAuth(c *gin.Context) {
 | 
			
		||||
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TokenAuth() func(c *gin.Context) {
 | 
			
		||||
	return func(c *gin.Context) {
 | 
			
		||||
		// 先检测是否为ws
 | 
			
		||||
		if c.Request.Header.Get("Sec-WebSocket-Protocol") != "" {
 | 
			
		||||
			// Sec-WebSocket-Protocol: realtime, openai-insecure-api-key.sk-xxx, openai-beta.realtime-v1
 | 
			
		||||
			// read sk from Sec-WebSocket-Protocol
 | 
			
		||||
			key := c.Request.Header.Get("Sec-WebSocket-Protocol")
 | 
			
		||||
			parts := strings.Split(key, ",")
 | 
			
		||||
			for _, part := range parts {
 | 
			
		||||
				part = strings.TrimSpace(part)
 | 
			
		||||
				if strings.HasPrefix(part, "openai-insecure-api-key") {
 | 
			
		||||
					key = strings.TrimPrefix(part, "openai-insecure-api-key.")
 | 
			
		||||
					break
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
			c.Request.Header.Set("Authorization", "Bearer "+key)
 | 
			
		||||
		}
 | 
			
		||||
		key := c.Request.Header.Get("Authorization")
 | 
			
		||||
		parts := make([]string, 0)
 | 
			
		||||
		key = strings.TrimPrefix(key, "Bearer ")
 | 
			
		||||
 
 | 
			
		||||
@@ -12,13 +12,13 @@ type Adaptor interface {
 | 
			
		||||
	// Init IsStream bool
 | 
			
		||||
	Init(info *relaycommon.RelayInfo)
 | 
			
		||||
	GetRequestURL(info *relaycommon.RelayInfo) (string, error)
 | 
			
		||||
	SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error
 | 
			
		||||
	SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error
 | 
			
		||||
	ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error)
 | 
			
		||||
	ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error)
 | 
			
		||||
	ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error)
 | 
			
		||||
	ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error)
 | 
			
		||||
	DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error)
 | 
			
		||||
	DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode)
 | 
			
		||||
	DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error)
 | 
			
		||||
	DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode)
 | 
			
		||||
	GetModelList() []string
 | 
			
		||||
	GetChannelName() string
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -32,14 +32,14 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
 | 
			
		||||
	return fullRequestURL, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
 | 
			
		||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
 | 
			
		||||
	channel.SetupApiRequestHeader(info, c, req)
 | 
			
		||||
	req.Header.Set("Authorization", "Bearer "+info.ApiKey)
 | 
			
		||||
	req.Set("Authorization", "Bearer "+info.ApiKey)
 | 
			
		||||
	if info.IsStream {
 | 
			
		||||
		req.Header.Set("X-DashScope-SSE", "enable")
 | 
			
		||||
		req.Set("X-DashScope-SSE", "enable")
 | 
			
		||||
	}
 | 
			
		||||
	if c.GetString("plugin") != "" {
 | 
			
		||||
		req.Header.Set("X-DashScope-Plugin", c.GetString("plugin"))
 | 
			
		||||
		req.Set("X-DashScope-Plugin", c.GetString("plugin"))
 | 
			
		||||
	}
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
@@ -72,11 +72,11 @@ func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInf
 | 
			
		||||
	return nil, errors.New("not implemented")
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
 | 
			
		||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
 | 
			
		||||
	return channel.DoApiRequest(a, c, info, requestBody)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
 | 
			
		||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
 | 
			
		||||
	switch info.RelayMode {
 | 
			
		||||
	case constant.RelayModeImagesGenerations:
 | 
			
		||||
		err, usage = aliImageHandler(c, resp, info)
 | 
			
		||||
 
 | 
			
		||||
@@ -4,6 +4,7 @@ import (
 | 
			
		||||
	"errors"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
	"github.com/gorilla/websocket"
 | 
			
		||||
	"io"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"one-api/relay/common"
 | 
			
		||||
@@ -11,14 +12,16 @@ import (
 | 
			
		||||
	"one-api/service"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func SetupApiRequestHeader(info *common.RelayInfo, c *gin.Context, req *http.Request) {
 | 
			
		||||
func SetupApiRequestHeader(info *common.RelayInfo, c *gin.Context, req *http.Header) {
 | 
			
		||||
	if info.RelayMode == constant.RelayModeAudioTranscription || info.RelayMode == constant.RelayModeAudioTranslation {
 | 
			
		||||
		// multipart/form-data
 | 
			
		||||
	} else if info.RelayMode == constant.RelayModeRealtime {
 | 
			
		||||
		// websocket
 | 
			
		||||
	} else {
 | 
			
		||||
		req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
 | 
			
		||||
		req.Header.Set("Accept", c.Request.Header.Get("Accept"))
 | 
			
		||||
		req.Set("Content-Type", c.Request.Header.Get("Content-Type"))
 | 
			
		||||
		req.Set("Accept", c.Request.Header.Get("Accept"))
 | 
			
		||||
		if info.IsStream && c.Request.Header.Get("Accept") == "" {
 | 
			
		||||
			req.Header.Set("Accept", "text/event-stream")
 | 
			
		||||
			req.Set("Accept", "text/event-stream")
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
@@ -32,7 +35,7 @@ func DoApiRequest(a Adaptor, c *gin.Context, info *common.RelayInfo, requestBody
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, fmt.Errorf("new request failed: %w", err)
 | 
			
		||||
	}
 | 
			
		||||
	err = a.SetupRequestHeader(c, req, info)
 | 
			
		||||
	err = a.SetupRequestHeader(c, &req.Header, info)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, fmt.Errorf("setup request header failed: %w", err)
 | 
			
		||||
	}
 | 
			
		||||
@@ -55,7 +58,7 @@ func DoFormRequest(a Adaptor, c *gin.Context, info *common.RelayInfo, requestBod
 | 
			
		||||
	// set form data
 | 
			
		||||
	req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
 | 
			
		||||
 | 
			
		||||
	err = a.SetupRequestHeader(c, req, info)
 | 
			
		||||
	err = a.SetupRequestHeader(c, &req.Header, info)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, fmt.Errorf("setup request header failed: %w", err)
 | 
			
		||||
	}
 | 
			
		||||
@@ -66,6 +69,27 @@ func DoFormRequest(a Adaptor, c *gin.Context, info *common.RelayInfo, requestBod
 | 
			
		||||
	return resp, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func DoWssRequest(a Adaptor, c *gin.Context, info *common.RelayInfo, requestBody io.Reader) (*websocket.Conn, error) {
 | 
			
		||||
	fullRequestURL, err := a.GetRequestURL(info)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, fmt.Errorf("get request url failed: %w", err)
 | 
			
		||||
	}
 | 
			
		||||
	targetHeader := http.Header{}
 | 
			
		||||
	err = a.SetupRequestHeader(c, &targetHeader, info)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, fmt.Errorf("setup request header failed: %w", err)
 | 
			
		||||
	}
 | 
			
		||||
	targetHeader.Set("Content-Type", c.Request.Header.Get("Content-Type"))
 | 
			
		||||
	targetConn, _, err := websocket.DefaultDialer.Dial(fullRequestURL, targetHeader)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, fmt.Errorf("dial failed to %s: %w", fullRequestURL, err)
 | 
			
		||||
	}
 | 
			
		||||
	// send request body
 | 
			
		||||
	//all, err := io.ReadAll(requestBody)
 | 
			
		||||
	//err = service.WssString(c, targetConn, string(all))
 | 
			
		||||
	return targetConn, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func doRequest(c *gin.Context, req *http.Request) (*http.Response, error) {
 | 
			
		||||
	resp, err := service.GetHttpClient().Do(req)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
 
 | 
			
		||||
@@ -37,7 +37,7 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
 | 
			
		||||
	return "", nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
 | 
			
		||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@@ -59,11 +59,11 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
 | 
			
		||||
	return nil, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
 | 
			
		||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
 | 
			
		||||
	return nil, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
 | 
			
		||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
 | 
			
		||||
	if info.IsStream {
 | 
			
		||||
		err, usage = awsStreamHandler(c, resp, info, a.RequestMode)
 | 
			
		||||
	} else {
 | 
			
		||||
 
 | 
			
		||||
@@ -98,9 +98,9 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
 | 
			
		||||
	return fullRequestURL, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
 | 
			
		||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
 | 
			
		||||
	channel.SetupApiRequestHeader(info, c, req)
 | 
			
		||||
	req.Header.Set("Authorization", "Bearer "+info.ApiKey)
 | 
			
		||||
	req.Set("Authorization", "Bearer "+info.ApiKey)
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@@ -122,11 +122,11 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
 | 
			
		||||
	return nil, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
 | 
			
		||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
 | 
			
		||||
	return channel.DoApiRequest(a, c, info, requestBody)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
 | 
			
		||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
 | 
			
		||||
	if info.IsStream {
 | 
			
		||||
		err, usage = baiduStreamHandler(c, resp)
 | 
			
		||||
	} else {
 | 
			
		||||
 
 | 
			
		||||
@@ -47,14 +47,14 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
 | 
			
		||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
 | 
			
		||||
	channel.SetupApiRequestHeader(info, c, req)
 | 
			
		||||
	req.Header.Set("x-api-key", info.ApiKey)
 | 
			
		||||
	req.Set("x-api-key", info.ApiKey)
 | 
			
		||||
	anthropicVersion := c.Request.Header.Get("anthropic-version")
 | 
			
		||||
	if anthropicVersion == "" {
 | 
			
		||||
		anthropicVersion = "2023-06-01"
 | 
			
		||||
	}
 | 
			
		||||
	req.Header.Set("anthropic-version", anthropicVersion)
 | 
			
		||||
	req.Set("anthropic-version", anthropicVersion)
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@@ -73,11 +73,11 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
 | 
			
		||||
	return nil, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
 | 
			
		||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
 | 
			
		||||
	return channel.DoApiRequest(a, c, info, requestBody)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
 | 
			
		||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
 | 
			
		||||
	if info.IsStream {
 | 
			
		||||
		err, usage = ClaudeStreamHandler(c, resp, info, a.RequestMode)
 | 
			
		||||
	} else {
 | 
			
		||||
 
 | 
			
		||||
@@ -30,9 +30,9 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
 | 
			
		||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
 | 
			
		||||
	channel.SetupApiRequestHeader(info, c, req)
 | 
			
		||||
	req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", info.ApiKey))
 | 
			
		||||
	req.Set("Authorization", fmt.Sprintf("Bearer %s", info.ApiKey))
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@@ -48,7 +48,7 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, re
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
 | 
			
		||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
 | 
			
		||||
	return channel.DoApiRequest(a, c, info, requestBody)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@@ -78,7 +78,7 @@ func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInf
 | 
			
		||||
	return nil, errors.New("not implemented")
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
 | 
			
		||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
 | 
			
		||||
	switch info.RelayMode {
 | 
			
		||||
	case constant.RelayModeEmbeddings:
 | 
			
		||||
		fallthrough
 | 
			
		||||
 
 | 
			
		||||
@@ -36,9 +36,9 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
 | 
			
		||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
 | 
			
		||||
	channel.SetupApiRequestHeader(info, c, req)
 | 
			
		||||
	req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", info.ApiKey))
 | 
			
		||||
	req.Set("Authorization", fmt.Sprintf("Bearer %s", info.ApiKey))
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@@ -46,7 +46,7 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, re
 | 
			
		||||
	return requestOpenAI2Cohere(*request), nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
 | 
			
		||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
 | 
			
		||||
	return channel.DoApiRequest(a, c, info, requestBody)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@@ -54,7 +54,7 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
 | 
			
		||||
	return requestConvertRerank2Cohere(request), nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
 | 
			
		||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
 | 
			
		||||
	if info.RelayMode == constant.RelayModeRerank {
 | 
			
		||||
		err, usage = cohereRerankHandler(c, resp, info)
 | 
			
		||||
	} else {
 | 
			
		||||
 
 | 
			
		||||
@@ -31,9 +31,9 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
 | 
			
		||||
	return fmt.Sprintf("%s/v1/chat-messages", info.BaseUrl), nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
 | 
			
		||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
 | 
			
		||||
	channel.SetupApiRequestHeader(info, c, req)
 | 
			
		||||
	req.Header.Set("Authorization", "Bearer "+info.ApiKey)
 | 
			
		||||
	req.Set("Authorization", "Bearer "+info.ApiKey)
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@@ -48,11 +48,11 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
 | 
			
		||||
	return nil, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
 | 
			
		||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
 | 
			
		||||
	return channel.DoApiRequest(a, c, info, requestBody)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
 | 
			
		||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
 | 
			
		||||
	if info.IsStream {
 | 
			
		||||
		err, usage = difyStreamHandler(c, resp, info)
 | 
			
		||||
	} else {
 | 
			
		||||
 
 | 
			
		||||
@@ -47,9 +47,9 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
 | 
			
		||||
	return fmt.Sprintf("%s/%s/models/%s:%s", info.BaseUrl, version, info.UpstreamModelName, action), nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
 | 
			
		||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
 | 
			
		||||
	channel.SetupApiRequestHeader(info, c, req)
 | 
			
		||||
	req.Header.Set("x-goog-api-key", info.ApiKey)
 | 
			
		||||
	req.Set("x-goog-api-key", info.ApiKey)
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@@ -64,11 +64,11 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
 | 
			
		||||
	return nil, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
 | 
			
		||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
 | 
			
		||||
	return channel.DoApiRequest(a, c, info, requestBody)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
 | 
			
		||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
 | 
			
		||||
	if info.IsStream {
 | 
			
		||||
		err, usage = GeminiChatStreamHandler(c, resp, info)
 | 
			
		||||
	} else {
 | 
			
		||||
 
 | 
			
		||||
@@ -37,9 +37,9 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
 | 
			
		||||
	return "", errors.New("invalid relay mode")
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
 | 
			
		||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
 | 
			
		||||
	channel.SetupApiRequestHeader(info, c, req)
 | 
			
		||||
	req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", info.ApiKey))
 | 
			
		||||
	req.Set("Authorization", fmt.Sprintf("Bearer %s", info.ApiKey))
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@@ -47,7 +47,7 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, re
 | 
			
		||||
	return request, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
 | 
			
		||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
 | 
			
		||||
	return channel.DoApiRequest(a, c, info, requestBody)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@@ -55,7 +55,7 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
 | 
			
		||||
	return request, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
 | 
			
		||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
 | 
			
		||||
	if info.RelayMode == constant.RelayModeRerank {
 | 
			
		||||
		err, usage = jinaRerankHandler(c, resp)
 | 
			
		||||
	} else if info.RelayMode == constant.RelayModeEmbeddings {
 | 
			
		||||
 
 | 
			
		||||
@@ -37,7 +37,7 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
 | 
			
		||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
 | 
			
		||||
	channel.SetupApiRequestHeader(info, c, req)
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
@@ -58,11 +58,11 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
 | 
			
		||||
	return nil, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
 | 
			
		||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
 | 
			
		||||
	return channel.DoApiRequest(a, c, info, requestBody)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
 | 
			
		||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
 | 
			
		||||
	if info.IsStream {
 | 
			
		||||
		err, usage = openai.OaiStreamHandler(c, resp, info)
 | 
			
		||||
	} else {
 | 
			
		||||
 
 | 
			
		||||
@@ -31,6 +31,13 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
 | 
			
		||||
	if info.RelayMode == constant.RelayModeRealtime {
 | 
			
		||||
		// trim https
 | 
			
		||||
		baseUrl := strings.TrimPrefix(info.BaseUrl, "https://")
 | 
			
		||||
		baseUrl = strings.TrimPrefix(baseUrl, "http://")
 | 
			
		||||
		baseUrl = "wss://" + baseUrl
 | 
			
		||||
		info.BaseUrl = baseUrl
 | 
			
		||||
	}
 | 
			
		||||
	switch info.ChannelType {
 | 
			
		||||
	case common.ChannelTypeAzure:
 | 
			
		||||
		// https://learn.microsoft.com/en-us/azure/cognitive-services/openai/chatgpt-quickstart?pivots=rest-api&tabs=command-line#rest-api
 | 
			
		||||
@@ -54,16 +61,19 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
 | 
			
		||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
 | 
			
		||||
	channel.SetupApiRequestHeader(info, c, req)
 | 
			
		||||
	if info.ChannelType == common.ChannelTypeAzure {
 | 
			
		||||
		req.Header.Set("api-key", info.ApiKey)
 | 
			
		||||
		req.Set("api-key", info.ApiKey)
 | 
			
		||||
		return nil
 | 
			
		||||
	}
 | 
			
		||||
	if info.ChannelType == common.ChannelTypeOpenAI && "" != info.Organization {
 | 
			
		||||
		req.Header.Set("OpenAI-Organization", info.Organization)
 | 
			
		||||
		req.Set("OpenAI-Organization", info.Organization)
 | 
			
		||||
	}
 | 
			
		||||
	req.Set("Authorization", "Bearer "+info.ApiKey)
 | 
			
		||||
	if info.RelayMode == constant.RelayModeRealtime {
 | 
			
		||||
		req.Set("openai-beta", "realtime=v1")
 | 
			
		||||
	}
 | 
			
		||||
	req.Header.Set("Authorization", "Bearer "+info.ApiKey)
 | 
			
		||||
	//if info.ChannelType == common.ChannelTypeOpenRouter {
 | 
			
		||||
	//	req.Header.Set("HTTP-Referer", "https://github.com/songquanpeng/one-api")
 | 
			
		||||
	//	req.Header.Set("X-Title", "One API")
 | 
			
		||||
@@ -131,16 +141,20 @@ func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInf
 | 
			
		||||
	return request, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
 | 
			
		||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
 | 
			
		||||
	if info.RelayMode == constant.RelayModeAudioTranscription || info.RelayMode == constant.RelayModeAudioTranslation {
 | 
			
		||||
		return channel.DoFormRequest(a, c, info, requestBody)
 | 
			
		||||
	} else if info.RelayMode == constant.RelayModeRealtime {
 | 
			
		||||
		return channel.DoWssRequest(a, c, info, requestBody)
 | 
			
		||||
	} else {
 | 
			
		||||
		return channel.DoApiRequest(a, c, info, requestBody)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
 | 
			
		||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
 | 
			
		||||
	switch info.RelayMode {
 | 
			
		||||
	case constant.RelayModeRealtime:
 | 
			
		||||
		err, usage = OpenaiRealtimeHandler(c, info)
 | 
			
		||||
	case constant.RelayModeAudioSpeech:
 | 
			
		||||
		err, usage = OpenaiTTSHandler(c, resp, info)
 | 
			
		||||
	case constant.RelayModeAudioTranslation:
 | 
			
		||||
 
 | 
			
		||||
@@ -7,6 +7,7 @@ import (
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"github.com/bytedance/gopkg/util/gopool"
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
	"github.com/gorilla/websocket"
 | 
			
		||||
	"io"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"one-api/common"
 | 
			
		||||
@@ -373,3 +374,106 @@ func getTextFromJSON(body []byte) (string, error) {
 | 
			
		||||
	}
 | 
			
		||||
	return whisperResponse.Text, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func OpenaiRealtimeHandler(c *gin.Context, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.RealtimeUsage) {
 | 
			
		||||
	info.IsStream = true
 | 
			
		||||
	clientConn := info.ClientWs
 | 
			
		||||
	targetConn := info.TargetWs
 | 
			
		||||
 | 
			
		||||
	clientClosed := make(chan struct{})
 | 
			
		||||
	targetClosed := make(chan struct{})
 | 
			
		||||
	sendChan := make(chan []byte, 100)
 | 
			
		||||
	receiveChan := make(chan []byte, 100)
 | 
			
		||||
	errChan := make(chan error, 2)
 | 
			
		||||
 | 
			
		||||
	usage := &dto.RealtimeUsage{}
 | 
			
		||||
 | 
			
		||||
	go func() {
 | 
			
		||||
		for {
 | 
			
		||||
			select {
 | 
			
		||||
			case <-c.Done():
 | 
			
		||||
				return
 | 
			
		||||
			default:
 | 
			
		||||
				_, message, err := clientConn.ReadMessage()
 | 
			
		||||
				if err != nil {
 | 
			
		||||
					if !websocket.IsCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway) {
 | 
			
		||||
						errChan <- fmt.Errorf("error reading from client: %v", err)
 | 
			
		||||
					}
 | 
			
		||||
					close(clientClosed)
 | 
			
		||||
					return
 | 
			
		||||
				}
 | 
			
		||||
 | 
			
		||||
				err = service.WssString(c, targetConn, string(message))
 | 
			
		||||
				if err != nil {
 | 
			
		||||
					errChan <- fmt.Errorf("error writing to target: %v", err)
 | 
			
		||||
					return
 | 
			
		||||
				}
 | 
			
		||||
 | 
			
		||||
				select {
 | 
			
		||||
				case sendChan <- message:
 | 
			
		||||
				default:
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	}()
 | 
			
		||||
 | 
			
		||||
	go func() {
 | 
			
		||||
		for {
 | 
			
		||||
			select {
 | 
			
		||||
			case <-c.Done():
 | 
			
		||||
				return
 | 
			
		||||
			default:
 | 
			
		||||
				_, message, err := targetConn.ReadMessage()
 | 
			
		||||
				if err != nil {
 | 
			
		||||
					if !websocket.IsCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway) {
 | 
			
		||||
						errChan <- fmt.Errorf("error reading from target: %v", err)
 | 
			
		||||
					}
 | 
			
		||||
					close(targetClosed)
 | 
			
		||||
					return
 | 
			
		||||
				}
 | 
			
		||||
				info.SetFirstResponseTime()
 | 
			
		||||
				realtimeEvent := &dto.RealtimeEvent{}
 | 
			
		||||
				err = json.Unmarshal(message, realtimeEvent)
 | 
			
		||||
				if err != nil {
 | 
			
		||||
					errChan <- fmt.Errorf("error unmarshalling message: %v", err)
 | 
			
		||||
					return
 | 
			
		||||
				}
 | 
			
		||||
 | 
			
		||||
				if realtimeEvent.Type == dto.RealtimeEventTypeResponseDone {
 | 
			
		||||
					realtimeUsage := realtimeEvent.Response.Usage
 | 
			
		||||
					if realtimeUsage != nil {
 | 
			
		||||
						usage.TotalTokens += realtimeUsage.TotalTokens
 | 
			
		||||
						usage.InputTokens += realtimeUsage.InputTokens
 | 
			
		||||
						usage.OutputTokens += realtimeUsage.OutputTokens
 | 
			
		||||
						usage.InputTokenDetails.AudioTokens += realtimeUsage.InputTokenDetails.AudioTokens
 | 
			
		||||
						usage.InputTokenDetails.CachedTokens += realtimeUsage.InputTokenDetails.CachedTokens
 | 
			
		||||
						usage.InputTokenDetails.TextTokens += realtimeUsage.InputTokenDetails.TextTokens
 | 
			
		||||
						usage.OutputTokenDetails.AudioTokens += realtimeUsage.OutputTokenDetails.AudioTokens
 | 
			
		||||
						usage.OutputTokenDetails.TextTokens += realtimeUsage.OutputTokenDetails.TextTokens
 | 
			
		||||
					}
 | 
			
		||||
				}
 | 
			
		||||
 | 
			
		||||
				err = service.WssString(c, clientConn, string(message))
 | 
			
		||||
				if err != nil {
 | 
			
		||||
					errChan <- fmt.Errorf("error writing to client: %v", err)
 | 
			
		||||
					return
 | 
			
		||||
				}
 | 
			
		||||
 | 
			
		||||
				select {
 | 
			
		||||
				case receiveChan <- message:
 | 
			
		||||
				default:
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	}()
 | 
			
		||||
 | 
			
		||||
	select {
 | 
			
		||||
	case <-clientClosed:
 | 
			
		||||
	case <-targetClosed:
 | 
			
		||||
	case <-errChan:
 | 
			
		||||
		//return service.OpenAIErrorWrapper(err, "realtime_error", http.StatusInternalServerError), nil
 | 
			
		||||
	case <-c.Done():
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return nil, usage
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -32,9 +32,9 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
 | 
			
		||||
	return fmt.Sprintf("%s/v1beta2/models/chat-bison-001:generateMessage", info.BaseUrl), nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
 | 
			
		||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
 | 
			
		||||
	channel.SetupApiRequestHeader(info, c, req)
 | 
			
		||||
	req.Header.Set("x-goog-api-key", info.ApiKey)
 | 
			
		||||
	req.Set("x-goog-api-key", info.ApiKey)
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@@ -49,11 +49,11 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
 | 
			
		||||
	return nil, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
 | 
			
		||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
 | 
			
		||||
	return channel.DoApiRequest(a, c, info, requestBody)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
 | 
			
		||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
 | 
			
		||||
	if info.IsStream {
 | 
			
		||||
		var responseText string
 | 
			
		||||
		err, responseText = palmStreamHandler(c, resp)
 | 
			
		||||
 
 | 
			
		||||
@@ -32,9 +32,9 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
 | 
			
		||||
	return fmt.Sprintf("%s/chat/completions", info.BaseUrl), nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
 | 
			
		||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
 | 
			
		||||
	channel.SetupApiRequestHeader(info, c, req)
 | 
			
		||||
	req.Header.Set("Authorization", "Bearer "+info.ApiKey)
 | 
			
		||||
	req.Set("Authorization", "Bearer "+info.ApiKey)
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@@ -52,11 +52,11 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
 | 
			
		||||
	return nil, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
 | 
			
		||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
 | 
			
		||||
	return channel.DoApiRequest(a, c, info, requestBody)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
 | 
			
		||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
 | 
			
		||||
	if info.IsStream {
 | 
			
		||||
		err, usage = openai.OaiStreamHandler(c, resp, info)
 | 
			
		||||
	} else {
 | 
			
		||||
 
 | 
			
		||||
@@ -40,9 +40,9 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
 | 
			
		||||
	return "", errors.New("invalid relay mode")
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
 | 
			
		||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
 | 
			
		||||
	channel.SetupApiRequestHeader(info, c, req)
 | 
			
		||||
	req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", info.ApiKey))
 | 
			
		||||
	req.Set("Authorization", fmt.Sprintf("Bearer %s", info.ApiKey))
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@@ -50,7 +50,7 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, re
 | 
			
		||||
	return request, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
 | 
			
		||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
 | 
			
		||||
	return channel.DoApiRequest(a, c, info, requestBody)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@@ -58,7 +58,7 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
 | 
			
		||||
	return request, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
 | 
			
		||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
 | 
			
		||||
	switch info.RelayMode {
 | 
			
		||||
	case constant.RelayModeRerank:
 | 
			
		||||
		err, usage = siliconflowRerankHandler(c, resp)
 | 
			
		||||
 
 | 
			
		||||
@@ -43,12 +43,12 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
 | 
			
		||||
	return fmt.Sprintf("%s/", info.BaseUrl), nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
 | 
			
		||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
 | 
			
		||||
	channel.SetupApiRequestHeader(info, c, req)
 | 
			
		||||
	req.Header.Set("Authorization", a.Sign)
 | 
			
		||||
	req.Header.Set("X-TC-Action", a.Action)
 | 
			
		||||
	req.Header.Set("X-TC-Version", a.Version)
 | 
			
		||||
	req.Header.Set("X-TC-Timestamp", strconv.FormatInt(a.Timestamp, 10))
 | 
			
		||||
	req.Set("Authorization", a.Sign)
 | 
			
		||||
	req.Set("X-TC-Action", a.Action)
 | 
			
		||||
	req.Set("X-TC-Version", a.Version)
 | 
			
		||||
	req.Set("X-TC-Timestamp", strconv.FormatInt(a.Timestamp, 10))
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@@ -73,11 +73,11 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
 | 
			
		||||
	return nil, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
 | 
			
		||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
 | 
			
		||||
	return channel.DoApiRequest(a, c, info, requestBody)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
 | 
			
		||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
 | 
			
		||||
	if info.IsStream {
 | 
			
		||||
		var responseText string
 | 
			
		||||
		err, responseText = tencentStreamHandler(c, resp)
 | 
			
		||||
 
 | 
			
		||||
@@ -107,13 +107,13 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
 | 
			
		||||
	return "", errors.New("unsupported request mode")
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
 | 
			
		||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
 | 
			
		||||
	channel.SetupApiRequestHeader(info, c, req)
 | 
			
		||||
	accessToken, err := getAccessToken(a, info)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
	req.Header.Set("Authorization", "Bearer "+accessToken)
 | 
			
		||||
	req.Set("Authorization", "Bearer "+accessToken)
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@@ -148,11 +148,11 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
 | 
			
		||||
	return nil, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
 | 
			
		||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
 | 
			
		||||
	return channel.DoApiRequest(a, c, info, requestBody)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
 | 
			
		||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
 | 
			
		||||
	if info.IsStream {
 | 
			
		||||
		switch a.RequestMode {
 | 
			
		||||
		case RequestModeClaude:
 | 
			
		||||
 
 | 
			
		||||
@@ -33,7 +33,7 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
 | 
			
		||||
	return "", nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
 | 
			
		||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
 | 
			
		||||
	channel.SetupApiRequestHeader(info, c, req)
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
@@ -50,14 +50,14 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
 | 
			
		||||
	return nil, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
 | 
			
		||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
 | 
			
		||||
	// xunfei's request is not http request, so we don't need to do anything here
 | 
			
		||||
	dummyResp := &http.Response{}
 | 
			
		||||
	dummyResp.StatusCode = http.StatusOK
 | 
			
		||||
	return dummyResp, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
 | 
			
		||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
 | 
			
		||||
	splits := strings.Split(info.ApiKey, "|")
 | 
			
		||||
	if len(splits) != 3 {
 | 
			
		||||
		return nil, service.OpenAIErrorWrapper(errors.New("invalid auth"), "invalid_auth", http.StatusBadRequest)
 | 
			
		||||
 
 | 
			
		||||
@@ -35,10 +35,10 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
 | 
			
		||||
	return fmt.Sprintf("%s/api/paas/v3/model-api/%s/%s", info.BaseUrl, info.UpstreamModelName, method), nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
 | 
			
		||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
 | 
			
		||||
	channel.SetupApiRequestHeader(info, c, req)
 | 
			
		||||
	token := getZhipuToken(info.ApiKey)
 | 
			
		||||
	req.Header.Set("Authorization", token)
 | 
			
		||||
	req.Set("Authorization", token)
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@@ -56,11 +56,11 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
 | 
			
		||||
	return nil, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
 | 
			
		||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
 | 
			
		||||
	return channel.DoApiRequest(a, c, info, requestBody)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
 | 
			
		||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
 | 
			
		||||
	if info.IsStream {
 | 
			
		||||
		err, usage = zhipuStreamHandler(c, resp)
 | 
			
		||||
	} else {
 | 
			
		||||
 
 | 
			
		||||
@@ -32,10 +32,10 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
 | 
			
		||||
	return fmt.Sprintf("%s/api/paas/v4/chat/completions", info.BaseUrl), nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
 | 
			
		||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
 | 
			
		||||
	channel.SetupApiRequestHeader(info, c, req)
 | 
			
		||||
	token := getZhipuToken(info.ApiKey)
 | 
			
		||||
	req.Header.Set("Authorization", token)
 | 
			
		||||
	req.Set("Authorization", token)
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@@ -53,11 +53,11 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
 | 
			
		||||
	return nil, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
 | 
			
		||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
 | 
			
		||||
	return channel.DoApiRequest(a, c, info, requestBody)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
 | 
			
		||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
 | 
			
		||||
	if info.IsStream {
 | 
			
		||||
		err, usage = openai.OaiStreamHandler(c, resp, info)
 | 
			
		||||
	} else {
 | 
			
		||||
 
 | 
			
		||||
@@ -2,6 +2,7 @@ package common
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
	"github.com/gorilla/websocket"
 | 
			
		||||
	"one-api/common"
 | 
			
		||||
	"one-api/relay/constant"
 | 
			
		||||
	"strings"
 | 
			
		||||
@@ -32,6 +33,14 @@ type RelayInfo struct {
 | 
			
		||||
	BaseUrl              string
 | 
			
		||||
	SupportStreamOptions bool
 | 
			
		||||
	ShouldIncludeUsage   bool
 | 
			
		||||
	ClientWs             *websocket.Conn
 | 
			
		||||
	TargetWs             *websocket.Conn
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func GenRelayInfoWs(c *gin.Context, ws *websocket.Conn) *RelayInfo {
 | 
			
		||||
	info := GenRelayInfo(c)
 | 
			
		||||
	info.ClientWs = ws
 | 
			
		||||
	return info
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func GenRelayInfo(c *gin.Context) *RelayInfo {
 | 
			
		||||
 
 | 
			
		||||
@@ -122,19 +122,21 @@ func AudioHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return service.OpenAIErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	statusCodeMappingStr := c.GetString("status_code_mapping")
 | 
			
		||||
 | 
			
		||||
	var httpResp *http.Response
 | 
			
		||||
	if resp != nil {
 | 
			
		||||
		if resp.StatusCode != http.StatusOK {
 | 
			
		||||
		httpResp = resp.(*http.Response)
 | 
			
		||||
		if httpResp.StatusCode != http.StatusOK {
 | 
			
		||||
			returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota)
 | 
			
		||||
			openaiErr := service.RelayErrorHandler(resp)
 | 
			
		||||
			openaiErr := service.RelayErrorHandler(httpResp)
 | 
			
		||||
			// reset status code 重置状态码
 | 
			
		||||
			service.ResetStatusCode(openaiErr, statusCodeMappingStr)
 | 
			
		||||
			return openaiErr
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	usage, openaiErr := adaptor.DoResponse(c, resp, relayInfo)
 | 
			
		||||
	usage, openaiErr := adaptor.DoResponse(c, httpResp, relayInfo)
 | 
			
		||||
	if openaiErr != nil {
 | 
			
		||||
		returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota)
 | 
			
		||||
		// reset status code 重置状态码
 | 
			
		||||
@@ -142,7 +144,7 @@ func AudioHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
 | 
			
		||||
		return openaiErr
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	postConsumeQuota(c, relayInfo, audioRequest.Model, usage, ratio, preConsumedQuota, userQuota, modelRatio, groupRatio, 0, false, "")
 | 
			
		||||
	postConsumeQuota(c, relayInfo, audioRequest.Model, usage.(*dto.Usage), ratio, preConsumedQuota, userQuota, modelRatio, groupRatio, 0, false, "")
 | 
			
		||||
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -149,22 +149,24 @@ func ImageHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode {
 | 
			
		||||
	requestBody = bytes.NewBuffer(jsonData)
 | 
			
		||||
 | 
			
		||||
	statusCodeMappingStr := c.GetString("status_code_mapping")
 | 
			
		||||
 | 
			
		||||
	resp, err := adaptor.DoRequest(c, relayInfo, requestBody)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return service.OpenAIErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var httpResp *http.Response
 | 
			
		||||
	if resp != nil {
 | 
			
		||||
		relayInfo.IsStream = relayInfo.IsStream || strings.HasPrefix(resp.Header.Get("Content-Type"), "text/event-stream")
 | 
			
		||||
		if resp.StatusCode != http.StatusOK {
 | 
			
		||||
			openaiErr := service.RelayErrorHandler(resp)
 | 
			
		||||
		httpResp = resp.(*http.Response)
 | 
			
		||||
		relayInfo.IsStream = relayInfo.IsStream || strings.HasPrefix(httpResp.Header.Get("Content-Type"), "text/event-stream")
 | 
			
		||||
		if httpResp.StatusCode != http.StatusOK {
 | 
			
		||||
			openaiErr := service.RelayErrorHandler(httpResp)
 | 
			
		||||
			// reset status code 重置状态码
 | 
			
		||||
			service.ResetStatusCode(openaiErr, statusCodeMappingStr)
 | 
			
		||||
			return openaiErr
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	_, openaiErr := adaptor.DoResponse(c, resp, relayInfo)
 | 
			
		||||
	_, openaiErr := adaptor.DoResponse(c, httpResp, relayInfo)
 | 
			
		||||
	if openaiErr != nil {
 | 
			
		||||
		// reset status code 重置状态码
 | 
			
		||||
		service.ResetStatusCode(openaiErr, statusCodeMappingStr)
 | 
			
		||||
 
 | 
			
		||||
@@ -180,30 +180,32 @@ func TextHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	statusCodeMappingStr := c.GetString("status_code_mapping")
 | 
			
		||||
	var httpResp *http.Response
 | 
			
		||||
	resp, err := adaptor.DoRequest(c, relayInfo, requestBody)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return service.OpenAIErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if resp != nil {
 | 
			
		||||
		relayInfo.IsStream = relayInfo.IsStream || strings.HasPrefix(resp.Header.Get("Content-Type"), "text/event-stream")
 | 
			
		||||
		if resp.StatusCode != http.StatusOK {
 | 
			
		||||
		httpResp = resp.(*http.Response)
 | 
			
		||||
		relayInfo.IsStream = relayInfo.IsStream || strings.HasPrefix(httpResp.Header.Get("Content-Type"), "text/event-stream")
 | 
			
		||||
		if httpResp.StatusCode != http.StatusOK {
 | 
			
		||||
			returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota)
 | 
			
		||||
			openaiErr := service.RelayErrorHandler(resp)
 | 
			
		||||
			openaiErr := service.RelayErrorHandler(httpResp)
 | 
			
		||||
			// reset status code 重置状态码
 | 
			
		||||
			service.ResetStatusCode(openaiErr, statusCodeMappingStr)
 | 
			
		||||
			return openaiErr
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	usage, openaiErr := adaptor.DoResponse(c, resp, relayInfo)
 | 
			
		||||
	usage, openaiErr := adaptor.DoResponse(c, httpResp, relayInfo)
 | 
			
		||||
	if openaiErr != nil {
 | 
			
		||||
		returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota)
 | 
			
		||||
		// reset status code 重置状态码
 | 
			
		||||
		service.ResetStatusCode(openaiErr, statusCodeMappingStr)
 | 
			
		||||
		return openaiErr
 | 
			
		||||
	}
 | 
			
		||||
	postConsumeQuota(c, relayInfo, textRequest.Model, usage, ratio, preConsumedQuota, userQuota, modelRatio, groupRatio, modelPrice, getModelPriceSuccess, "")
 | 
			
		||||
	postConsumeQuota(c, relayInfo, textRequest.Model, usage.(*dto.Usage), ratio, preConsumedQuota, userQuota, modelRatio, groupRatio, modelPrice, getModelPriceSuccess, "")
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
@@ -99,23 +99,26 @@ func RerankHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return service.OpenAIErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var httpResp *http.Response
 | 
			
		||||
	if resp != nil {
 | 
			
		||||
		if resp.StatusCode != http.StatusOK {
 | 
			
		||||
		httpResp = resp.(*http.Response)
 | 
			
		||||
		if httpResp.StatusCode != http.StatusOK {
 | 
			
		||||
			returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota)
 | 
			
		||||
			openaiErr := service.RelayErrorHandler(resp)
 | 
			
		||||
			openaiErr := service.RelayErrorHandler(httpResp)
 | 
			
		||||
			// reset status code 重置状态码
 | 
			
		||||
			service.ResetStatusCode(openaiErr, statusCodeMappingStr)
 | 
			
		||||
			return openaiErr
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	usage, openaiErr := adaptor.DoResponse(c, resp, relayInfo)
 | 
			
		||||
	usage, openaiErr := adaptor.DoResponse(c, httpResp, relayInfo)
 | 
			
		||||
	if openaiErr != nil {
 | 
			
		||||
		returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota)
 | 
			
		||||
		// reset status code 重置状态码
 | 
			
		||||
		service.ResetStatusCode(openaiErr, statusCodeMappingStr)
 | 
			
		||||
		return openaiErr
 | 
			
		||||
	}
 | 
			
		||||
	postConsumeQuota(c, relayInfo, rerankRequest.Model, usage, ratio, preConsumedQuota, userQuota, modelRatio, groupRatio, modelPrice, success, "")
 | 
			
		||||
	postConsumeQuota(c, relayInfo, rerankRequest.Model, usage.(*dto.Usage), ratio, preConsumedQuota, userQuota, modelRatio, groupRatio, modelPrice, success, "")
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										242
									
								
								relay/websocket.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										242
									
								
								relay/websocket.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,242 @@
 | 
			
		||||
package relay
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"encoding/json"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
	"github.com/gorilla/websocket"
 | 
			
		||||
	"math"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"one-api/common"
 | 
			
		||||
	"one-api/dto"
 | 
			
		||||
	"one-api/model"
 | 
			
		||||
	relaycommon "one-api/relay/common"
 | 
			
		||||
	"one-api/service"
 | 
			
		||||
	"strings"
 | 
			
		||||
	"time"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
//func getAndValidateWssRequest(c *gin.Context, ws *websocket.Conn) (*dto.RealtimeEvent, error) {
 | 
			
		||||
//	_, p, err := ws.ReadMessage()
 | 
			
		||||
//	if err != nil {
 | 
			
		||||
//		return nil, err
 | 
			
		||||
//	}
 | 
			
		||||
//	realtimeEvent := &dto.RealtimeEvent{}
 | 
			
		||||
//	err = json.Unmarshal(p, realtimeEvent)
 | 
			
		||||
//	if err != nil {
 | 
			
		||||
//		return nil, err
 | 
			
		||||
//	}
 | 
			
		||||
//	// save the original request
 | 
			
		||||
//	if realtimeEvent.Session == nil {
 | 
			
		||||
//		return nil, errors.New("session object is nil")
 | 
			
		||||
//	}
 | 
			
		||||
//	c.Set("first_wss_request", p)
 | 
			
		||||
//	return realtimeEvent, nil
 | 
			
		||||
//}
 | 
			
		||||
 | 
			
		||||
func WssHelper(c *gin.Context, ws *websocket.Conn) *dto.OpenAIErrorWithStatusCode {
 | 
			
		||||
	relayInfo := relaycommon.GenRelayInfoWs(c, ws)
 | 
			
		||||
 | 
			
		||||
	// get & validate textRequest 获取并验证文本请求
 | 
			
		||||
	//realtimeEvent, err := getAndValidateWssRequest(c, ws)
 | 
			
		||||
	//if err != nil {
 | 
			
		||||
	//	common.LogError(c, fmt.Sprintf("getAndValidateWssRequest failed: %s", err.Error()))
 | 
			
		||||
	//	return service.OpenAIErrorWrapperLocal(err, "invalid_text_request", http.StatusBadRequest)
 | 
			
		||||
	//}
 | 
			
		||||
 | 
			
		||||
	// map model name
 | 
			
		||||
	modelMapping := c.GetString("model_mapping")
 | 
			
		||||
	//isModelMapped := false
 | 
			
		||||
	if modelMapping != "" && modelMapping != "{}" {
 | 
			
		||||
		modelMap := make(map[string]string)
 | 
			
		||||
		err := json.Unmarshal([]byte(modelMapping), &modelMap)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return service.OpenAIErrorWrapperLocal(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError)
 | 
			
		||||
		}
 | 
			
		||||
		if modelMap[relayInfo.OriginModelName] != "" {
 | 
			
		||||
			relayInfo.UpstreamModelName = modelMap[relayInfo.OriginModelName]
 | 
			
		||||
			// set upstream model name
 | 
			
		||||
			//isModelMapped = true
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	//relayInfo.UpstreamModelName = textRequest.Model
 | 
			
		||||
	modelPrice, getModelPriceSuccess := common.GetModelPrice(relayInfo.UpstreamModelName, false)
 | 
			
		||||
	groupRatio := common.GetGroupRatio(relayInfo.Group)
 | 
			
		||||
 | 
			
		||||
	var preConsumedQuota int
 | 
			
		||||
	var ratio float64
 | 
			
		||||
	var modelRatio float64
 | 
			
		||||
	//err := service.SensitiveWordsCheck(textRequest)
 | 
			
		||||
 | 
			
		||||
	//if constant.ShouldCheckPromptSensitive() {
 | 
			
		||||
	//	err = checkRequestSensitive(textRequest, relayInfo)
 | 
			
		||||
	//	if err != nil {
 | 
			
		||||
	//		return service.OpenAIErrorWrapperLocal(err, "sensitive_words_detected", http.StatusBadRequest)
 | 
			
		||||
	//	}
 | 
			
		||||
	//}
 | 
			
		||||
 | 
			
		||||
	//promptTokens, err := getWssPromptTokens(realtimeEvent, relayInfo)
 | 
			
		||||
	//// count messages token error 计算promptTokens错误
 | 
			
		||||
	//if err != nil {
 | 
			
		||||
	//	return service.OpenAIErrorWrapper(err, "count_token_messages_failed", http.StatusInternalServerError)
 | 
			
		||||
	//}
 | 
			
		||||
	//
 | 
			
		||||
	if !getModelPriceSuccess {
 | 
			
		||||
		preConsumedTokens := common.PreConsumedQuota
 | 
			
		||||
		//if realtimeEvent.Session.MaxResponseOutputTokens != 0 {
 | 
			
		||||
		//	preConsumedTokens = promptTokens + int(realtimeEvent.Session.MaxResponseOutputTokens)
 | 
			
		||||
		//}
 | 
			
		||||
		modelRatio = common.GetModelRatio(relayInfo.UpstreamModelName)
 | 
			
		||||
		ratio = modelRatio * groupRatio
 | 
			
		||||
		preConsumedQuota = int(float64(preConsumedTokens) * ratio)
 | 
			
		||||
	} else {
 | 
			
		||||
		preConsumedQuota = int(modelPrice * common.QuotaPerUnit * groupRatio)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// pre-consume quota 预消耗配额
 | 
			
		||||
	preConsumedQuota, userQuota, openaiErr := preConsumeQuota(c, preConsumedQuota, relayInfo)
 | 
			
		||||
	if openaiErr != nil {
 | 
			
		||||
		return openaiErr
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	adaptor := GetAdaptor(relayInfo.ApiType)
 | 
			
		||||
	if adaptor == nil {
 | 
			
		||||
		return service.OpenAIErrorWrapperLocal(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), "invalid_api_type", http.StatusBadRequest)
 | 
			
		||||
	}
 | 
			
		||||
	adaptor.Init(relayInfo)
 | 
			
		||||
	//var requestBody io.Reader
 | 
			
		||||
	//firstWssRequest, _ := c.Get("first_wss_request")
 | 
			
		||||
	//requestBody = bytes.NewBuffer(firstWssRequest.([]byte))
 | 
			
		||||
 | 
			
		||||
	statusCodeMappingStr := c.GetString("status_code_mapping")
 | 
			
		||||
	resp, err := adaptor.DoRequest(c, relayInfo, nil)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return service.OpenAIErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if resp != nil {
 | 
			
		||||
		relayInfo.TargetWs = resp.(*websocket.Conn)
 | 
			
		||||
		defer relayInfo.TargetWs.Close()
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	usage, openaiErr := adaptor.DoResponse(c, nil, relayInfo)
 | 
			
		||||
	if openaiErr != nil {
 | 
			
		||||
		returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota)
 | 
			
		||||
		// reset status code 重置状态码
 | 
			
		||||
		service.ResetStatusCode(openaiErr, statusCodeMappingStr)
 | 
			
		||||
		return openaiErr
 | 
			
		||||
	}
 | 
			
		||||
	postWssConsumeQuota(c, relayInfo, relayInfo.UpstreamModelName, usage.(*dto.RealtimeUsage), ratio, preConsumedQuota, userQuota, modelRatio, groupRatio, modelPrice, getModelPriceSuccess, "")
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func postWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, modelName string,
 | 
			
		||||
	usage *dto.RealtimeUsage, ratio float64, preConsumedQuota int, userQuota int, modelRatio float64,
 | 
			
		||||
	groupRatio float64,
 | 
			
		||||
	modelPrice float64, usePrice bool, extraContent string) {
 | 
			
		||||
 | 
			
		||||
	useTimeSeconds := time.Now().Unix() - relayInfo.StartTime.Unix()
 | 
			
		||||
	textInputTokens := usage.InputTokenDetails.TextTokens
 | 
			
		||||
	textOutTokens := usage.OutputTokenDetails.TextTokens
 | 
			
		||||
 | 
			
		||||
	audioInputTokens := usage.InputTokenDetails.AudioTokens
 | 
			
		||||
	audioOutTokens := usage.OutputTokenDetails.AudioTokens
 | 
			
		||||
 | 
			
		||||
	tokenName := ctx.GetString("token_name")
 | 
			
		||||
	completionRatio := common.GetCompletionRatio(modelName)
 | 
			
		||||
	audioRatio := common.GetAudioRatio(relayInfo.UpstreamModelName)
 | 
			
		||||
	audioCompletionRatio := common.GetAudioCompletionRatio(modelName)
 | 
			
		||||
 | 
			
		||||
	quota := 0
 | 
			
		||||
	if !usePrice {
 | 
			
		||||
		quota = textInputTokens + int(math.Round(float64(textOutTokens)*completionRatio))
 | 
			
		||||
		quota += int(math.Round(float64(audioInputTokens)*audioRatio)) + int(math.Round(float64(audioOutTokens)*completionRatio*audioCompletionRatio))
 | 
			
		||||
 | 
			
		||||
		quota = int(math.Round(float64(quota) * ratio))
 | 
			
		||||
		if ratio != 0 && quota <= 0 {
 | 
			
		||||
			quota = 1
 | 
			
		||||
		}
 | 
			
		||||
	} else {
 | 
			
		||||
		quota = int(modelPrice * common.QuotaPerUnit * groupRatio)
 | 
			
		||||
	}
 | 
			
		||||
	totalTokens := usage.TotalTokens
 | 
			
		||||
	var logContent string
 | 
			
		||||
	if !usePrice {
 | 
			
		||||
		logContent = fmt.Sprintf("模型倍率 %.2f,补全倍率 %.2f,音频倍率 %.2f,音频补全倍率 %.2f,分组倍率 %.2f", modelRatio, completionRatio, audioRatio, audioCompletionRatio, groupRatio)
 | 
			
		||||
	} else {
 | 
			
		||||
		logContent = fmt.Sprintf("模型价格 %.2f,分组倍率 %.2f", modelPrice, groupRatio)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// record all the consume log even if quota is 0
 | 
			
		||||
	if totalTokens == 0 {
 | 
			
		||||
		// in this case, must be some error happened
 | 
			
		||||
		// we cannot just return, because we may have to return the pre-consumed quota
 | 
			
		||||
		quota = 0
 | 
			
		||||
		logContent += fmt.Sprintf("(可能是上游超时)")
 | 
			
		||||
		common.LogError(ctx, fmt.Sprintf("total tokens is 0, cannot consume quota, userId %d, channelId %d, "+
 | 
			
		||||
			"tokenId %d, model %s, pre-consumed quota %d", relayInfo.UserId, relayInfo.ChannelId, relayInfo.TokenId, modelName, preConsumedQuota))
 | 
			
		||||
	} else {
 | 
			
		||||
		//if sensitiveResp != nil {
 | 
			
		||||
		//	logContent += fmt.Sprintf(",敏感词:%s", strings.Join(sensitiveResp.SensitiveWords, ", "))
 | 
			
		||||
		//}
 | 
			
		||||
		quotaDelta := quota - preConsumedQuota
 | 
			
		||||
		if quotaDelta != 0 {
 | 
			
		||||
			err := model.PostConsumeTokenQuota(relayInfo, userQuota, quotaDelta, preConsumedQuota, true)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				common.LogError(ctx, "error consuming token remain quota: "+err.Error())
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
		err := model.CacheUpdateUserQuota(relayInfo.UserId)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			common.LogError(ctx, "error update user quota cache: "+err.Error())
 | 
			
		||||
		}
 | 
			
		||||
		model.UpdateUserUsedQuotaAndRequestCount(relayInfo.UserId, quota)
 | 
			
		||||
		model.UpdateChannelUsedQuota(relayInfo.ChannelId, quota)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	logModel := modelName
 | 
			
		||||
	if strings.HasPrefix(logModel, "gpt-4-gizmo") {
 | 
			
		||||
		logModel = "gpt-4-gizmo-*"
 | 
			
		||||
		logContent += fmt.Sprintf(",模型 %s", modelName)
 | 
			
		||||
	}
 | 
			
		||||
	if strings.HasPrefix(logModel, "gpt-4o-gizmo") {
 | 
			
		||||
		logModel = "gpt-4o-gizmo-*"
 | 
			
		||||
		logContent += fmt.Sprintf(",模型 %s", modelName)
 | 
			
		||||
	}
 | 
			
		||||
	if extraContent != "" {
 | 
			
		||||
		logContent += ", " + extraContent
 | 
			
		||||
	}
 | 
			
		||||
	other := service.GenerateWssOtherInfo(ctx, relayInfo, usage, modelRatio, groupRatio, completionRatio, modelPrice)
 | 
			
		||||
	model.RecordConsumeLog(ctx, relayInfo.UserId, relayInfo.ChannelId, usage.InputTokens, usage.OutputTokens, logModel,
 | 
			
		||||
		tokenName, quota, logContent, relayInfo.TokenId, userQuota, int(useTimeSeconds), relayInfo.IsStream, other)
 | 
			
		||||
 | 
			
		||||
	//if quota != 0 {
 | 
			
		||||
	//
 | 
			
		||||
	//}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func getWssPromptTokens(textRequest *dto.RealtimeEvent, info *relaycommon.RelayInfo) (int, error) {
 | 
			
		||||
	var promptTokens int
 | 
			
		||||
	var err error
 | 
			
		||||
	switch info.RelayMode {
 | 
			
		||||
	default:
 | 
			
		||||
		promptTokens, err = service.CountTokenRealtime(*textRequest, info.UpstreamModelName)
 | 
			
		||||
	}
 | 
			
		||||
	info.PromptTokens = promptTokens
 | 
			
		||||
	return promptTokens, err
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
//func checkWssRequestSensitive(textRequest *dto.GeneralOpenAIRequest, info *relaycommon.RelayInfo) error {
 | 
			
		||||
//	var err error
 | 
			
		||||
//	switch info.RelayMode {
 | 
			
		||||
//	case relayconstant.RelayModeChatCompletions:
 | 
			
		||||
//		err = service.CheckSensitiveMessages(textRequest.Messages)
 | 
			
		||||
//	case relayconstant.RelayModeCompletions:
 | 
			
		||||
//		err = service.CheckSensitiveInput(textRequest.Prompt)
 | 
			
		||||
//	case relayconstant.RelayModeModerations:
 | 
			
		||||
//		err = service.CheckSensitiveInput(textRequest.Input)
 | 
			
		||||
//	case relayconstant.RelayModeEmbeddings:
 | 
			
		||||
//		err = service.CheckSensitiveInput(textRequest.Input)
 | 
			
		||||
//	}
 | 
			
		||||
//	return err
 | 
			
		||||
//}
 | 
			
		||||
@@ -2,6 +2,7 @@ package service
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
	"one-api/dto"
 | 
			
		||||
	relaycommon "one-api/relay/common"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
@@ -17,3 +18,13 @@ func GenerateTextOtherInfo(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, m
 | 
			
		||||
	other["admin_info"] = adminInfo
 | 
			
		||||
	return other
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func GenerateWssOtherInfo(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage *dto.RealtimeUsage, modelRatio, groupRatio, completionRatio, modelPrice float64) map[string]interface{} {
 | 
			
		||||
	info := GenerateTextOtherInfo(ctx, relayInfo, modelRatio, groupRatio, completionRatio, modelPrice)
 | 
			
		||||
	info["ws"] = true
 | 
			
		||||
	info["audio_input"] = usage.InputTokenDetails.AudioTokens
 | 
			
		||||
	info["audio_output"] = usage.OutputTokenDetails.AudioTokens
 | 
			
		||||
	info["text_input"] = usage.InputTokenDetails.TextTokens
 | 
			
		||||
	info["text_output"] = usage.OutputTokenDetails.TextTokens
 | 
			
		||||
	return info
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -43,11 +43,25 @@ func Done(c *gin.Context) {
 | 
			
		||||
	_ = StringData(c, "[DONE]")
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func WssString(c *gin.Context, ws *websocket.Conn, str string) error {
 | 
			
		||||
	if ws == nil {
 | 
			
		||||
		common.LogError(c, "websocket connection is nil")
 | 
			
		||||
		return errors.New("websocket connection is nil")
 | 
			
		||||
	}
 | 
			
		||||
	common.LogInfo(c, fmt.Sprintf("sending message: %s", str))
 | 
			
		||||
	return ws.WriteMessage(1, []byte(str))
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func WssObject(c *gin.Context, ws *websocket.Conn, object interface{}) error {
 | 
			
		||||
	jsonData, err := json.Marshal(object)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return fmt.Errorf("error marshalling object: %w", err)
 | 
			
		||||
	}
 | 
			
		||||
	if ws == nil {
 | 
			
		||||
		common.LogError(c, "websocket connection is nil")
 | 
			
		||||
		return errors.New("websocket connection is nil")
 | 
			
		||||
	}
 | 
			
		||||
	common.LogInfo(c, fmt.Sprintf("sending message: %s", jsonData))
 | 
			
		||||
	return ws.WriteMessage(1, jsonData)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
@@ -191,6 +191,45 @@ func CountTokenChatRequest(request dto.GeneralOpenAIRequest, model string) (int,
 | 
			
		||||
	return tkm, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func CountTokenRealtime(request dto.RealtimeEvent, model string) (int, error) {
 | 
			
		||||
	tkm := 0
 | 
			
		||||
	ratio := 1
 | 
			
		||||
	if request.Session != nil {
 | 
			
		||||
		msgTokens, err := CountTokenText(request.Session.Instructions, model)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return 0, err
 | 
			
		||||
		}
 | 
			
		||||
		ratio = len(request.Session.Modalities)
 | 
			
		||||
		tkm += msgTokens
 | 
			
		||||
		if request.Session.Tools != nil {
 | 
			
		||||
			toolsData, _ := json.Marshal(request.Session.Tools)
 | 
			
		||||
			var openaiTools []dto.OpenAITools
 | 
			
		||||
			err := json.Unmarshal(toolsData, &openaiTools)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				return 0, errors.New(fmt.Sprintf("count_tools_token_fail: %s", err.Error()))
 | 
			
		||||
			}
 | 
			
		||||
			countStr := ""
 | 
			
		||||
			for _, tool := range openaiTools {
 | 
			
		||||
				countStr = tool.Function.Name
 | 
			
		||||
				if tool.Function.Description != "" {
 | 
			
		||||
					countStr += tool.Function.Description
 | 
			
		||||
				}
 | 
			
		||||
				if tool.Function.Parameters != nil {
 | 
			
		||||
					countStr += fmt.Sprintf("%v", tool.Function.Parameters)
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
			toolTokens, err := CountTokenInput(countStr, model)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				return 0, err
 | 
			
		||||
			}
 | 
			
		||||
			tkm += 8
 | 
			
		||||
			tkm += toolTokens
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	tkm *= ratio
 | 
			
		||||
	return tkm, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func CountTokenMessages(messages []dto.Message, model string, stream bool) (int, error) {
 | 
			
		||||
	//recover when panic
 | 
			
		||||
	tokenEncoder := getTokenEncoder(model)
 | 
			
		||||
 
 | 
			
		||||
@@ -11,7 +11,7 @@ import {
 | 
			
		||||
 | 
			
		||||
import {
 | 
			
		||||
  Avatar,
 | 
			
		||||
  Button,
 | 
			
		||||
  Button, Descriptions,
 | 
			
		||||
  Form,
 | 
			
		||||
  Layout,
 | 
			
		||||
  Modal,
 | 
			
		||||
@@ -20,7 +20,7 @@ import {
 | 
			
		||||
  Spin,
 | 
			
		||||
  Table,
 | 
			
		||||
  Tag,
 | 
			
		||||
  Tooltip,
 | 
			
		||||
  Tooltip
 | 
			
		||||
} from '@douyinfe/semi-ui';
 | 
			
		||||
import { ITEMS_PER_PAGE } from '../constants';
 | 
			
		||||
import {
 | 
			
		||||
@@ -336,33 +336,33 @@ const LogsTable = () => {
 | 
			
		||||
        );
 | 
			
		||||
      },
 | 
			
		||||
    },
 | 
			
		||||
    {
 | 
			
		||||
      title: '重试',
 | 
			
		||||
      dataIndex: 'retry',
 | 
			
		||||
      className: isAdmin() ? 'tableShow' : 'tableHiddle',
 | 
			
		||||
      render: (text, record, index) => {
 | 
			
		||||
        let content = '渠道:' + record.channel;
 | 
			
		||||
        if (record.other !== '') {
 | 
			
		||||
          let other = JSON.parse(record.other);
 | 
			
		||||
          if (other === null) {
 | 
			
		||||
            return <></>;
 | 
			
		||||
          }
 | 
			
		||||
          if (other.admin_info !== undefined) {
 | 
			
		||||
            if (
 | 
			
		||||
              other.admin_info.use_channel !== null &&
 | 
			
		||||
              other.admin_info.use_channel !== undefined &&
 | 
			
		||||
              other.admin_info.use_channel !== ''
 | 
			
		||||
            ) {
 | 
			
		||||
              // channel id array
 | 
			
		||||
              let useChannel = other.admin_info.use_channel;
 | 
			
		||||
              let useChannelStr = useChannel.join('->');
 | 
			
		||||
              content = `渠道:${useChannelStr}`;
 | 
			
		||||
            }
 | 
			
		||||
          }
 | 
			
		||||
        }
 | 
			
		||||
        return isAdminUser ? <div>{content}</div> : <></>;
 | 
			
		||||
      },
 | 
			
		||||
    },
 | 
			
		||||
    // {
 | 
			
		||||
    //   title: '重试',
 | 
			
		||||
    //   dataIndex: 'retry',
 | 
			
		||||
    //   className: isAdmin() ? 'tableShow' : 'tableHiddle',
 | 
			
		||||
    //   render: (text, record, index) => {
 | 
			
		||||
    //     let content = '渠道:' + record.channel;
 | 
			
		||||
    //     if (record.other !== '') {
 | 
			
		||||
    //       let other = JSON.parse(record.other);
 | 
			
		||||
    //       if (other === null) {
 | 
			
		||||
    //         return <></>;
 | 
			
		||||
    //       }
 | 
			
		||||
    //       if (other.admin_info !== undefined) {
 | 
			
		||||
    //         if (
 | 
			
		||||
    //           other.admin_info.use_channel !== null &&
 | 
			
		||||
    //           other.admin_info.use_channel !== undefined &&
 | 
			
		||||
    //           other.admin_info.use_channel !== ''
 | 
			
		||||
    //         ) {
 | 
			
		||||
    //           // channel id array
 | 
			
		||||
    //           let useChannel = other.admin_info.use_channel;
 | 
			
		||||
    //           let useChannelStr = useChannel.join('->');
 | 
			
		||||
    //           content = `渠道:${useChannelStr}`;
 | 
			
		||||
    //         }
 | 
			
		||||
    //       }
 | 
			
		||||
    //     }
 | 
			
		||||
    //     return isAdminUser ? <div>{content}</div> : <></>;
 | 
			
		||||
    //   },
 | 
			
		||||
    // },
 | 
			
		||||
    {
 | 
			
		||||
      title: '详情',
 | 
			
		||||
      dataIndex: 'content',
 | 
			
		||||
@@ -409,6 +409,7 @@ const LogsTable = () => {
 | 
			
		||||
  ];
 | 
			
		||||
 | 
			
		||||
  const [logs, setLogs] = useState([]);
 | 
			
		||||
  const [expandData, setExpandData] = useState({});
 | 
			
		||||
  const [showStat, setShowStat] = useState(false);
 | 
			
		||||
  const [loading, setLoading] = useState(false);
 | 
			
		||||
  const [loadingStat, setLoadingStat] = useState(false);
 | 
			
		||||
@@ -512,10 +513,54 @@ const LogsTable = () => {
 | 
			
		||||
  };
 | 
			
		||||
 | 
			
		||||
  const setLogsFormat = (logs) => {
 | 
			
		||||
    let expandDatesLocal = {};
 | 
			
		||||
    for (let i = 0; i < logs.length; i++) {
 | 
			
		||||
      logs[i].timestamp2string = timestamp2string(logs[i].created_at);
 | 
			
		||||
      logs[i].key = '' + logs[i].id;
 | 
			
		||||
      let other = getLogOther(logs[i].other);
 | 
			
		||||
      let expandDataLocal = [];
 | 
			
		||||
      if (isAdmin()) {
 | 
			
		||||
        let content = '渠道:' + logs[i].channel;
 | 
			
		||||
        if (other.admin_info !== undefined) {
 | 
			
		||||
          if (
 | 
			
		||||
            other.admin_info.use_channel !== null &&
 | 
			
		||||
            other.admin_info.use_channel !== undefined &&
 | 
			
		||||
            other.admin_info.use_channel !== ''
 | 
			
		||||
          ) {
 | 
			
		||||
            // channel id array
 | 
			
		||||
            let useChannel = other.admin_info.use_channel;
 | 
			
		||||
            let useChannelStr = useChannel.join('->');
 | 
			
		||||
            content = `渠道:${useChannelStr}`;
 | 
			
		||||
          }
 | 
			
		||||
        }
 | 
			
		||||
        expandDataLocal.push({
 | 
			
		||||
          key: '重试',
 | 
			
		||||
          value: content,
 | 
			
		||||
        })
 | 
			
		||||
      }
 | 
			
		||||
      if (other.ws) {
 | 
			
		||||
        expandDataLocal.push({
 | 
			
		||||
          key: '语音输入',
 | 
			
		||||
          value: other.audio_input,
 | 
			
		||||
        });
 | 
			
		||||
        expandDataLocal.push({
 | 
			
		||||
          key: '语音输出',
 | 
			
		||||
          value: other.audio_output,
 | 
			
		||||
        });
 | 
			
		||||
        expandDataLocal.push({
 | 
			
		||||
          key: '文字输入',
 | 
			
		||||
          value: other.text_input,
 | 
			
		||||
        });
 | 
			
		||||
        expandDataLocal.push({
 | 
			
		||||
          key: '文字输出',
 | 
			
		||||
          value: other.text_output,
 | 
			
		||||
        });
 | 
			
		||||
      }
 | 
			
		||||
      expandDatesLocal[logs[i].key] = expandDataLocal;
 | 
			
		||||
    }
 | 
			
		||||
    console.log(expandDatesLocal);
 | 
			
		||||
    setExpandData(expandDatesLocal);
 | 
			
		||||
 | 
			
		||||
    setLogs(logs);
 | 
			
		||||
  };
 | 
			
		||||
 | 
			
		||||
@@ -588,6 +633,10 @@ const LogsTable = () => {
 | 
			
		||||
    handleEyeClick();
 | 
			
		||||
  }, []);
 | 
			
		||||
 | 
			
		||||
  const expandRowRender = (record, index) => {
 | 
			
		||||
    return <Descriptions align="justify" data={expandData[record.key]} />;
 | 
			
		||||
  };
 | 
			
		||||
 | 
			
		||||
  return (
 | 
			
		||||
    <>
 | 
			
		||||
      <Layout>
 | 
			
		||||
@@ -686,7 +735,9 @@ const LogsTable = () => {
 | 
			
		||||
        <Table
 | 
			
		||||
          style={{ marginTop: 5 }}
 | 
			
		||||
          columns={columns}
 | 
			
		||||
          expandedRowRender={expandRowRender}
 | 
			
		||||
          dataSource={logs}
 | 
			
		||||
          rowKey="key"
 | 
			
		||||
          pagination={{
 | 
			
		||||
            currentPage: activePage,
 | 
			
		||||
            pageSize: pageSize,
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user