mirror of
				https://github.com/songquanpeng/one-api.git
				synced 2025-10-31 22:03:41 +08:00 
			
		
		
		
	Compare commits
	
		
			1 Commits
		
	
	
		
			v0.6.11-al
			...
			refactor
		
	
	| Author | SHA1 | Date | |
|---|---|---|---|
|  | e12b0c7aa8 | 
| @@ -31,3 +31,11 @@ func UnmarshalBodyReusable(c *gin.Context, v any) error { | |||||||
| 	c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody)) | 	c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody)) | ||||||
| 	return nil | 	return nil | ||||||
| } | } | ||||||
|  |  | ||||||
|  | func SetEventStreamHeaders(c *gin.Context) { | ||||||
|  | 	c.Writer.Header().Set("Content-Type", "text/event-stream") | ||||||
|  | 	c.Writer.Header().Set("Cache-Control", "no-cache") | ||||||
|  | 	c.Writer.Header().Set("Connection", "keep-alive") | ||||||
|  | 	c.Writer.Header().Set("Transfer-Encoding", "chunked") | ||||||
|  | 	c.Writer.Header().Set("X-Accel-Buffering", "no") | ||||||
|  | } | ||||||
|   | |||||||
| @@ -4,6 +4,7 @@ import ( | |||||||
| 	"github.com/gin-gonic/gin" | 	"github.com/gin-gonic/gin" | ||||||
| 	"one-api/common" | 	"one-api/common" | ||||||
| 	"one-api/model" | 	"one-api/model" | ||||||
|  | 	"one-api/relay/channel/openai" | ||||||
| ) | ) | ||||||
|  |  | ||||||
| func GetSubscription(c *gin.Context) { | func GetSubscription(c *gin.Context) { | ||||||
| @@ -27,12 +28,12 @@ func GetSubscription(c *gin.Context) { | |||||||
| 		expiredTime = 0 | 		expiredTime = 0 | ||||||
| 	} | 	} | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		openAIError := OpenAIError{ | 		Error := openai.Error{ | ||||||
| 			Message: err.Error(), | 			Message: err.Error(), | ||||||
| 			Type:    "upstream_error", | 			Type:    "upstream_error", | ||||||
| 		} | 		} | ||||||
| 		c.JSON(200, gin.H{ | 		c.JSON(200, gin.H{ | ||||||
| 			"error": openAIError, | 			"error": Error, | ||||||
| 		}) | 		}) | ||||||
| 		return | 		return | ||||||
| 	} | 	} | ||||||
| @@ -69,12 +70,12 @@ func GetUsage(c *gin.Context) { | |||||||
| 		quota, err = model.GetUserUsedQuota(userId) | 		quota, err = model.GetUserUsedQuota(userId) | ||||||
| 	} | 	} | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		openAIError := OpenAIError{ | 		Error := openai.Error{ | ||||||
| 			Message: err.Error(), | 			Message: err.Error(), | ||||||
| 			Type:    "one_api_error", | 			Type:    "one_api_error", | ||||||
| 		} | 		} | ||||||
| 		c.JSON(200, gin.H{ | 		c.JSON(200, gin.H{ | ||||||
| 			"error": openAIError, | 			"error": Error, | ||||||
| 		}) | 		}) | ||||||
| 		return | 		return | ||||||
| 	} | 	} | ||||||
|   | |||||||
| @@ -8,6 +8,7 @@ import ( | |||||||
| 	"net/http" | 	"net/http" | ||||||
| 	"one-api/common" | 	"one-api/common" | ||||||
| 	"one-api/model" | 	"one-api/model" | ||||||
|  | 	"one-api/relay/util" | ||||||
| 	"strconv" | 	"strconv" | ||||||
| 	"time" | 	"time" | ||||||
|  |  | ||||||
| @@ -92,7 +93,7 @@ func GetResponseBody(method, url string, channel *model.Channel, headers http.He | |||||||
| 	for k := range headers { | 	for k := range headers { | ||||||
| 		req.Header.Add(k, headers.Get(k)) | 		req.Header.Add(k, headers.Get(k)) | ||||||
| 	} | 	} | ||||||
| 	res, err := httpClient.Do(req) | 	res, err := util.HTTPClient.Do(req) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, err | 		return nil, err | ||||||
| 	} | 	} | ||||||
|   | |||||||
| @@ -9,6 +9,8 @@ import ( | |||||||
| 	"net/http" | 	"net/http" | ||||||
| 	"one-api/common" | 	"one-api/common" | ||||||
| 	"one-api/model" | 	"one-api/model" | ||||||
|  | 	"one-api/relay/channel/openai" | ||||||
|  | 	"one-api/relay/util" | ||||||
| 	"strconv" | 	"strconv" | ||||||
| 	"sync" | 	"sync" | ||||||
| 	"time" | 	"time" | ||||||
| @@ -16,7 +18,7 @@ import ( | |||||||
| 	"github.com/gin-gonic/gin" | 	"github.com/gin-gonic/gin" | ||||||
| ) | ) | ||||||
|  |  | ||||||
| func testChannel(channel *model.Channel, request ChatRequest) (err error, openaiErr *OpenAIError) { | func testChannel(channel *model.Channel, request openai.ChatRequest) (err error, openaiErr *openai.Error) { | ||||||
| 	switch channel.Type { | 	switch channel.Type { | ||||||
| 	case common.ChannelTypePaLM: | 	case common.ChannelTypePaLM: | ||||||
| 		fallthrough | 		fallthrough | ||||||
| @@ -46,13 +48,13 @@ func testChannel(channel *model.Channel, request ChatRequest) (err error, openai | |||||||
| 	} | 	} | ||||||
| 	requestURL := common.ChannelBaseURLs[channel.Type] | 	requestURL := common.ChannelBaseURLs[channel.Type] | ||||||
| 	if channel.Type == common.ChannelTypeAzure { | 	if channel.Type == common.ChannelTypeAzure { | ||||||
| 		requestURL = getFullRequestURL(channel.GetBaseURL(), fmt.Sprintf("/openai/deployments/%s/chat/completions?api-version=2023-03-15-preview", request.Model), channel.Type) | 		requestURL = util.GetFullRequestURL(channel.GetBaseURL(), fmt.Sprintf("/openai/deployments/%s/chat/completions?api-version=2023-03-15-preview", request.Model), channel.Type) | ||||||
| 	} else { | 	} else { | ||||||
| 		if baseURL := channel.GetBaseURL(); len(baseURL) > 0 { | 		if baseURL := channel.GetBaseURL(); len(baseURL) > 0 { | ||||||
| 			requestURL = baseURL | 			requestURL = baseURL | ||||||
| 		} | 		} | ||||||
|  |  | ||||||
| 		requestURL = getFullRequestURL(requestURL, "/v1/chat/completions", channel.Type) | 		requestURL = util.GetFullRequestURL(requestURL, "/v1/chat/completions", channel.Type) | ||||||
| 	} | 	} | ||||||
| 	jsonData, err := json.Marshal(request) | 	jsonData, err := json.Marshal(request) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| @@ -68,12 +70,12 @@ func testChannel(channel *model.Channel, request ChatRequest) (err error, openai | |||||||
| 		req.Header.Set("Authorization", "Bearer "+channel.Key) | 		req.Header.Set("Authorization", "Bearer "+channel.Key) | ||||||
| 	} | 	} | ||||||
| 	req.Header.Set("Content-Type", "application/json") | 	req.Header.Set("Content-Type", "application/json") | ||||||
| 	resp, err := httpClient.Do(req) | 	resp, err := util.HTTPClient.Do(req) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return err, nil | 		return err, nil | ||||||
| 	} | 	} | ||||||
| 	defer resp.Body.Close() | 	defer resp.Body.Close() | ||||||
| 	var response TextResponse | 	var response openai.SlimTextResponse | ||||||
| 	body, err := io.ReadAll(resp.Body) | 	body, err := io.ReadAll(resp.Body) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return err, nil | 		return err, nil | ||||||
| @@ -91,12 +93,12 @@ func testChannel(channel *model.Channel, request ChatRequest) (err error, openai | |||||||
| 	return nil, nil | 	return nil, nil | ||||||
| } | } | ||||||
|  |  | ||||||
| func buildTestRequest() *ChatRequest { | func buildTestRequest() *openai.ChatRequest { | ||||||
| 	testRequest := &ChatRequest{ | 	testRequest := &openai.ChatRequest{ | ||||||
| 		Model:     "", // this will be set later | 		Model:     "", // this will be set later | ||||||
| 		MaxTokens: 1, | 		MaxTokens: 1, | ||||||
| 	} | 	} | ||||||
| 	testMessage := Message{ | 	testMessage := openai.Message{ | ||||||
| 		Role:    "user", | 		Role:    "user", | ||||||
| 		Content: "hi", | 		Content: "hi", | ||||||
| 	} | 	} | ||||||
| @@ -204,10 +206,10 @@ func testAllChannels(notify bool) error { | |||||||
| 				err = errors.New(fmt.Sprintf("响应时间 %.2fs 超过阈值 %.2fs", float64(milliseconds)/1000.0, float64(disableThreshold)/1000.0)) | 				err = errors.New(fmt.Sprintf("响应时间 %.2fs 超过阈值 %.2fs", float64(milliseconds)/1000.0, float64(disableThreshold)/1000.0)) | ||||||
| 				disableChannel(channel.Id, channel.Name, err.Error()) | 				disableChannel(channel.Id, channel.Name, err.Error()) | ||||||
| 			} | 			} | ||||||
| 			if isChannelEnabled && shouldDisableChannel(openaiErr, -1) { | 			if isChannelEnabled && util.ShouldDisableChannel(openaiErr, -1) { | ||||||
| 				disableChannel(channel.Id, channel.Name, err.Error()) | 				disableChannel(channel.Id, channel.Name, err.Error()) | ||||||
| 			} | 			} | ||||||
| 			if !isChannelEnabled && shouldEnableChannel(err, openaiErr) { | 			if !isChannelEnabled && util.ShouldEnableChannel(err, openaiErr) { | ||||||
| 				enableChannel(channel.Id, channel.Name) | 				enableChannel(channel.Id, channel.Name) | ||||||
| 			} | 			} | ||||||
| 			channel.UpdateResponseTime(milliseconds) | 			channel.UpdateResponseTime(milliseconds) | ||||||
|   | |||||||
| @@ -2,8 +2,8 @@ package controller | |||||||
|  |  | ||||||
| import ( | import ( | ||||||
| 	"fmt" | 	"fmt" | ||||||
|  |  | ||||||
| 	"github.com/gin-gonic/gin" | 	"github.com/gin-gonic/gin" | ||||||
|  | 	"one-api/relay/channel/openai" | ||||||
| ) | ) | ||||||
|  |  | ||||||
| // https://platform.openai.com/docs/api-reference/models/list | // https://platform.openai.com/docs/api-reference/models/list | ||||||
| @@ -613,14 +613,14 @@ func RetrieveModel(c *gin.Context) { | |||||||
| 	if model, ok := openAIModelsMap[modelId]; ok { | 	if model, ok := openAIModelsMap[modelId]; ok { | ||||||
| 		c.JSON(200, model) | 		c.JSON(200, model) | ||||||
| 	} else { | 	} else { | ||||||
| 		openAIError := OpenAIError{ | 		Error := openai.Error{ | ||||||
| 			Message: fmt.Sprintf("The model '%s' does not exist", modelId), | 			Message: fmt.Sprintf("The model '%s' does not exist", modelId), | ||||||
| 			Type:    "invalid_request_error", | 			Type:    "invalid_request_error", | ||||||
| 			Param:   "model", | 			Param:   "model", | ||||||
| 			Code:    "model_not_found", | 			Code:    "model_not_found", | ||||||
| 		} | 		} | ||||||
| 		c.JSON(200, gin.H{ | 		c.JSON(200, gin.H{ | ||||||
| 			"error": openAIError, | 			"error": Error, | ||||||
| 		}) | 		}) | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
|   | |||||||
| @@ -4,349 +4,53 @@ import ( | |||||||
| 	"fmt" | 	"fmt" | ||||||
| 	"net/http" | 	"net/http" | ||||||
| 	"one-api/common" | 	"one-api/common" | ||||||
|  | 	"one-api/relay/channel/openai" | ||||||
|  | 	"one-api/relay/constant" | ||||||
|  | 	"one-api/relay/controller" | ||||||
|  | 	"one-api/relay/util" | ||||||
| 	"strconv" | 	"strconv" | ||||||
| 	"strings" | 	"strings" | ||||||
|  |  | ||||||
| 	"github.com/gin-gonic/gin" | 	"github.com/gin-gonic/gin" | ||||||
| ) | ) | ||||||
|  |  | ||||||
| type Message struct { |  | ||||||
| 	Role    string  `json:"role"` |  | ||||||
| 	Content any     `json:"content"` |  | ||||||
| 	Name    *string `json:"name,omitempty"` |  | ||||||
| } |  | ||||||
|  |  | ||||||
| type ImageURL struct { |  | ||||||
| 	Url    string `json:"url,omitempty"` |  | ||||||
| 	Detail string `json:"detail,omitempty"` |  | ||||||
| } |  | ||||||
|  |  | ||||||
| type TextContent struct { |  | ||||||
| 	Type string `json:"type,omitempty"` |  | ||||||
| 	Text string `json:"text,omitempty"` |  | ||||||
| } |  | ||||||
|  |  | ||||||
| type ImageContent struct { |  | ||||||
| 	Type     string    `json:"type,omitempty"` |  | ||||||
| 	ImageURL *ImageURL `json:"image_url,omitempty"` |  | ||||||
| } |  | ||||||
|  |  | ||||||
| const ( |  | ||||||
| 	ContentTypeText     = "text" |  | ||||||
| 	ContentTypeImageURL = "image_url" |  | ||||||
| ) |  | ||||||
|  |  | ||||||
| type OpenAIMessageContent struct { |  | ||||||
| 	Type     string    `json:"type,omitempty"` |  | ||||||
| 	Text     string    `json:"text"` |  | ||||||
| 	ImageURL *ImageURL `json:"image_url,omitempty"` |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func (m Message) IsStringContent() bool { |  | ||||||
| 	_, ok := m.Content.(string) |  | ||||||
| 	return ok |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func (m Message) StringContent() string { |  | ||||||
| 	content, ok := m.Content.(string) |  | ||||||
| 	if ok { |  | ||||||
| 		return content |  | ||||||
| 	} |  | ||||||
| 	contentList, ok := m.Content.([]any) |  | ||||||
| 	if ok { |  | ||||||
| 		var contentStr string |  | ||||||
| 		for _, contentItem := range contentList { |  | ||||||
| 			contentMap, ok := contentItem.(map[string]any) |  | ||||||
| 			if !ok { |  | ||||||
| 				continue |  | ||||||
| 			} |  | ||||||
| 			if contentMap["type"] == ContentTypeText { |  | ||||||
| 				if subStr, ok := contentMap["text"].(string); ok { |  | ||||||
| 					contentStr += subStr |  | ||||||
| 				} |  | ||||||
| 			} |  | ||||||
| 		} |  | ||||||
| 		return contentStr |  | ||||||
| 	} |  | ||||||
| 	return "" |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func (m Message) ParseContent() []OpenAIMessageContent { |  | ||||||
| 	var contentList []OpenAIMessageContent |  | ||||||
| 	content, ok := m.Content.(string) |  | ||||||
| 	if ok { |  | ||||||
| 		contentList = append(contentList, OpenAIMessageContent{ |  | ||||||
| 			Type: ContentTypeText, |  | ||||||
| 			Text: content, |  | ||||||
| 		}) |  | ||||||
| 		return contentList |  | ||||||
| 	} |  | ||||||
| 	anyList, ok := m.Content.([]any) |  | ||||||
| 	if ok { |  | ||||||
| 		for _, contentItem := range anyList { |  | ||||||
| 			contentMap, ok := contentItem.(map[string]any) |  | ||||||
| 			if !ok { |  | ||||||
| 				continue |  | ||||||
| 			} |  | ||||||
| 			switch contentMap["type"] { |  | ||||||
| 			case ContentTypeText: |  | ||||||
| 				if subStr, ok := contentMap["text"].(string); ok { |  | ||||||
| 					contentList = append(contentList, OpenAIMessageContent{ |  | ||||||
| 						Type: ContentTypeText, |  | ||||||
| 						Text: subStr, |  | ||||||
| 					}) |  | ||||||
| 				} |  | ||||||
| 			case ContentTypeImageURL: |  | ||||||
| 				if subObj, ok := contentMap["image_url"].(map[string]any); ok { |  | ||||||
| 					contentList = append(contentList, OpenAIMessageContent{ |  | ||||||
| 						Type: ContentTypeImageURL, |  | ||||||
| 						ImageURL: &ImageURL{ |  | ||||||
| 							Url: subObj["url"].(string), |  | ||||||
| 						}, |  | ||||||
| 					}) |  | ||||||
| 				} |  | ||||||
| 			} |  | ||||||
| 		} |  | ||||||
| 		return contentList |  | ||||||
| 	} |  | ||||||
| 	return nil |  | ||||||
| } |  | ||||||
|  |  | ||||||
| const ( |  | ||||||
| 	RelayModeUnknown = iota |  | ||||||
| 	RelayModeChatCompletions |  | ||||||
| 	RelayModeCompletions |  | ||||||
| 	RelayModeEmbeddings |  | ||||||
| 	RelayModeModerations |  | ||||||
| 	RelayModeImagesGenerations |  | ||||||
| 	RelayModeEdits |  | ||||||
| 	RelayModeAudioSpeech |  | ||||||
| 	RelayModeAudioTranscription |  | ||||||
| 	RelayModeAudioTranslation |  | ||||||
| ) |  | ||||||
|  |  | ||||||
| // https://platform.openai.com/docs/api-reference/chat | // https://platform.openai.com/docs/api-reference/chat | ||||||
|  |  | ||||||
| type ResponseFormat struct { |  | ||||||
| 	Type string `json:"type,omitempty"` |  | ||||||
| } |  | ||||||
|  |  | ||||||
| type GeneralOpenAIRequest struct { |  | ||||||
| 	Model            string          `json:"model,omitempty"` |  | ||||||
| 	Messages         []Message       `json:"messages,omitempty"` |  | ||||||
| 	Prompt           any             `json:"prompt,omitempty"` |  | ||||||
| 	Stream           bool            `json:"stream,omitempty"` |  | ||||||
| 	MaxTokens        int             `json:"max_tokens,omitempty"` |  | ||||||
| 	Temperature      float64         `json:"temperature,omitempty"` |  | ||||||
| 	TopP             float64         `json:"top_p,omitempty"` |  | ||||||
| 	N                int             `json:"n,omitempty"` |  | ||||||
| 	Input            any             `json:"input,omitempty"` |  | ||||||
| 	Instruction      string          `json:"instruction,omitempty"` |  | ||||||
| 	Size             string          `json:"size,omitempty"` |  | ||||||
| 	Functions        any             `json:"functions,omitempty"` |  | ||||||
| 	FrequencyPenalty float64         `json:"frequency_penalty,omitempty"` |  | ||||||
| 	PresencePenalty  float64         `json:"presence_penalty,omitempty"` |  | ||||||
| 	ResponseFormat   *ResponseFormat `json:"response_format,omitempty"` |  | ||||||
| 	Seed             float64         `json:"seed,omitempty"` |  | ||||||
| 	Tools            any             `json:"tools,omitempty"` |  | ||||||
| 	ToolChoice       any             `json:"tool_choice,omitempty"` |  | ||||||
| 	User             string          `json:"user,omitempty"` |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func (r GeneralOpenAIRequest) ParseInput() []string { |  | ||||||
| 	if r.Input == nil { |  | ||||||
| 		return nil |  | ||||||
| 	} |  | ||||||
| 	var input []string |  | ||||||
| 	switch r.Input.(type) { |  | ||||||
| 	case string: |  | ||||||
| 		input = []string{r.Input.(string)} |  | ||||||
| 	case []any: |  | ||||||
| 		input = make([]string, 0, len(r.Input.([]any))) |  | ||||||
| 		for _, item := range r.Input.([]any) { |  | ||||||
| 			if str, ok := item.(string); ok { |  | ||||||
| 				input = append(input, str) |  | ||||||
| 			} |  | ||||||
| 		} |  | ||||||
| 	} |  | ||||||
| 	return input |  | ||||||
| } |  | ||||||
|  |  | ||||||
| type ChatRequest struct { |  | ||||||
| 	Model     string    `json:"model"` |  | ||||||
| 	Messages  []Message `json:"messages"` |  | ||||||
| 	MaxTokens int       `json:"max_tokens"` |  | ||||||
| } |  | ||||||
|  |  | ||||||
| type TextRequest struct { |  | ||||||
| 	Model     string    `json:"model"` |  | ||||||
| 	Messages  []Message `json:"messages"` |  | ||||||
| 	Prompt    string    `json:"prompt"` |  | ||||||
| 	MaxTokens int       `json:"max_tokens"` |  | ||||||
| 	//Stream   bool      `json:"stream"` |  | ||||||
| } |  | ||||||
|  |  | ||||||
| // ImageRequest docs: https://platform.openai.com/docs/api-reference/images/create |  | ||||||
| type ImageRequest struct { |  | ||||||
| 	Model          string `json:"model"` |  | ||||||
| 	Prompt         string `json:"prompt" binding:"required"` |  | ||||||
| 	N              int    `json:"n,omitempty"` |  | ||||||
| 	Size           string `json:"size,omitempty"` |  | ||||||
| 	Quality        string `json:"quality,omitempty"` |  | ||||||
| 	ResponseFormat string `json:"response_format,omitempty"` |  | ||||||
| 	Style          string `json:"style,omitempty"` |  | ||||||
| 	User           string `json:"user,omitempty"` |  | ||||||
| } |  | ||||||
|  |  | ||||||
| type WhisperJSONResponse struct { |  | ||||||
| 	Text string `json:"text,omitempty"` |  | ||||||
| } |  | ||||||
|  |  | ||||||
| type WhisperVerboseJSONResponse struct { |  | ||||||
| 	Task     string    `json:"task,omitempty"` |  | ||||||
| 	Language string    `json:"language,omitempty"` |  | ||||||
| 	Duration float64   `json:"duration,omitempty"` |  | ||||||
| 	Text     string    `json:"text,omitempty"` |  | ||||||
| 	Segments []Segment `json:"segments,omitempty"` |  | ||||||
| } |  | ||||||
|  |  | ||||||
| type Segment struct { |  | ||||||
| 	Id               int     `json:"id"` |  | ||||||
| 	Seek             int     `json:"seek"` |  | ||||||
| 	Start            float64 `json:"start"` |  | ||||||
| 	End              float64 `json:"end"` |  | ||||||
| 	Text             string  `json:"text"` |  | ||||||
| 	Tokens           []int   `json:"tokens"` |  | ||||||
| 	Temperature      float64 `json:"temperature"` |  | ||||||
| 	AvgLogprob       float64 `json:"avg_logprob"` |  | ||||||
| 	CompressionRatio float64 `json:"compression_ratio"` |  | ||||||
| 	NoSpeechProb     float64 `json:"no_speech_prob"` |  | ||||||
| } |  | ||||||
|  |  | ||||||
| type TextToSpeechRequest struct { |  | ||||||
| 	Model          string  `json:"model" binding:"required"` |  | ||||||
| 	Input          string  `json:"input" binding:"required"` |  | ||||||
| 	Voice          string  `json:"voice" binding:"required"` |  | ||||||
| 	Speed          float64 `json:"speed"` |  | ||||||
| 	ResponseFormat string  `json:"response_format"` |  | ||||||
| } |  | ||||||
|  |  | ||||||
| type Usage struct { |  | ||||||
| 	PromptTokens     int `json:"prompt_tokens"` |  | ||||||
| 	CompletionTokens int `json:"completion_tokens"` |  | ||||||
| 	TotalTokens      int `json:"total_tokens"` |  | ||||||
| } |  | ||||||
|  |  | ||||||
| type OpenAIError struct { |  | ||||||
| 	Message string `json:"message"` |  | ||||||
| 	Type    string `json:"type"` |  | ||||||
| 	Param   string `json:"param"` |  | ||||||
| 	Code    any    `json:"code"` |  | ||||||
| } |  | ||||||
|  |  | ||||||
| type OpenAIErrorWithStatusCode struct { |  | ||||||
| 	OpenAIError |  | ||||||
| 	StatusCode int `json:"status_code"` |  | ||||||
| } |  | ||||||
|  |  | ||||||
| type TextResponse struct { |  | ||||||
| 	Choices []OpenAITextResponseChoice `json:"choices"` |  | ||||||
| 	Usage   `json:"usage"` |  | ||||||
| 	Error   OpenAIError `json:"error"` |  | ||||||
| } |  | ||||||
|  |  | ||||||
| type OpenAITextResponseChoice struct { |  | ||||||
| 	Index        int `json:"index"` |  | ||||||
| 	Message      `json:"message"` |  | ||||||
| 	FinishReason string `json:"finish_reason"` |  | ||||||
| } |  | ||||||
|  |  | ||||||
| type OpenAITextResponse struct { |  | ||||||
| 	Id      string                     `json:"id"` |  | ||||||
| 	Model   string                     `json:"model,omitempty"` |  | ||||||
| 	Object  string                     `json:"object"` |  | ||||||
| 	Created int64                      `json:"created"` |  | ||||||
| 	Choices []OpenAITextResponseChoice `json:"choices"` |  | ||||||
| 	Usage   `json:"usage"` |  | ||||||
| } |  | ||||||
|  |  | ||||||
| type OpenAIEmbeddingResponseItem struct { |  | ||||||
| 	Object    string    `json:"object"` |  | ||||||
| 	Index     int       `json:"index"` |  | ||||||
| 	Embedding []float64 `json:"embedding"` |  | ||||||
| } |  | ||||||
|  |  | ||||||
| type OpenAIEmbeddingResponse struct { |  | ||||||
| 	Object string                        `json:"object"` |  | ||||||
| 	Data   []OpenAIEmbeddingResponseItem `json:"data"` |  | ||||||
| 	Model  string                        `json:"model"` |  | ||||||
| 	Usage  `json:"usage"` |  | ||||||
| } |  | ||||||
|  |  | ||||||
| type ImageResponse struct { |  | ||||||
| 	Created int `json:"created"` |  | ||||||
| 	Data    []struct { |  | ||||||
| 		Url string `json:"url"` |  | ||||||
| 	} |  | ||||||
| } |  | ||||||
|  |  | ||||||
| type ChatCompletionsStreamResponseChoice struct { |  | ||||||
| 	Delta struct { |  | ||||||
| 		Content string `json:"content"` |  | ||||||
| 	} `json:"delta"` |  | ||||||
| 	FinishReason *string `json:"finish_reason,omitempty"` |  | ||||||
| } |  | ||||||
|  |  | ||||||
| type ChatCompletionsStreamResponse struct { |  | ||||||
| 	Id      string                                `json:"id"` |  | ||||||
| 	Object  string                                `json:"object"` |  | ||||||
| 	Created int64                                 `json:"created"` |  | ||||||
| 	Model   string                                `json:"model"` |  | ||||||
| 	Choices []ChatCompletionsStreamResponseChoice `json:"choices"` |  | ||||||
| } |  | ||||||
|  |  | ||||||
| type CompletionsStreamResponse struct { |  | ||||||
| 	Choices []struct { |  | ||||||
| 		Text         string `json:"text"` |  | ||||||
| 		FinishReason string `json:"finish_reason"` |  | ||||||
| 	} `json:"choices"` |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func Relay(c *gin.Context) { | func Relay(c *gin.Context) { | ||||||
| 	relayMode := RelayModeUnknown | 	relayMode := constant.RelayModeUnknown | ||||||
| 	if strings.HasPrefix(c.Request.URL.Path, "/v1/chat/completions") { | 	if strings.HasPrefix(c.Request.URL.Path, "/v1/chat/completions") { | ||||||
| 		relayMode = RelayModeChatCompletions | 		relayMode = constant.RelayModeChatCompletions | ||||||
| 	} else if strings.HasPrefix(c.Request.URL.Path, "/v1/completions") { | 	} else if strings.HasPrefix(c.Request.URL.Path, "/v1/completions") { | ||||||
| 		relayMode = RelayModeCompletions | 		relayMode = constant.RelayModeCompletions | ||||||
| 	} else if strings.HasPrefix(c.Request.URL.Path, "/v1/embeddings") { | 	} else if strings.HasPrefix(c.Request.URL.Path, "/v1/embeddings") { | ||||||
| 		relayMode = RelayModeEmbeddings | 		relayMode = constant.RelayModeEmbeddings | ||||||
| 	} else if strings.HasSuffix(c.Request.URL.Path, "embeddings") { | 	} else if strings.HasSuffix(c.Request.URL.Path, "embeddings") { | ||||||
| 		relayMode = RelayModeEmbeddings | 		relayMode = constant.RelayModeEmbeddings | ||||||
| 	} else if strings.HasPrefix(c.Request.URL.Path, "/v1/moderations") { | 	} else if strings.HasPrefix(c.Request.URL.Path, "/v1/moderations") { | ||||||
| 		relayMode = RelayModeModerations | 		relayMode = constant.RelayModeModerations | ||||||
| 	} else if strings.HasPrefix(c.Request.URL.Path, "/v1/images/generations") { | 	} else if strings.HasPrefix(c.Request.URL.Path, "/v1/images/generations") { | ||||||
| 		relayMode = RelayModeImagesGenerations | 		relayMode = constant.RelayModeImagesGenerations | ||||||
| 	} else if strings.HasPrefix(c.Request.URL.Path, "/v1/edits") { | 	} else if strings.HasPrefix(c.Request.URL.Path, "/v1/edits") { | ||||||
| 		relayMode = RelayModeEdits | 		relayMode = constant.RelayModeEdits | ||||||
| 	} else if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/speech") { | 	} else if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/speech") { | ||||||
| 		relayMode = RelayModeAudioSpeech | 		relayMode = constant.RelayModeAudioSpeech | ||||||
| 	} else if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") { | 	} else if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") { | ||||||
| 		relayMode = RelayModeAudioTranscription | 		relayMode = constant.RelayModeAudioTranscription | ||||||
| 	} else if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/translations") { | 	} else if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/translations") { | ||||||
| 		relayMode = RelayModeAudioTranslation | 		relayMode = constant.RelayModeAudioTranslation | ||||||
| 	} | 	} | ||||||
| 	var err *OpenAIErrorWithStatusCode | 	var err *openai.ErrorWithStatusCode | ||||||
| 	switch relayMode { | 	switch relayMode { | ||||||
| 	case RelayModeImagesGenerations: | 	case constant.RelayModeImagesGenerations: | ||||||
| 		err = relayImageHelper(c, relayMode) | 		err = controller.RelayImageHelper(c, relayMode) | ||||||
| 	case RelayModeAudioSpeech: | 	case constant.RelayModeAudioSpeech: | ||||||
| 		fallthrough | 		fallthrough | ||||||
| 	case RelayModeAudioTranslation: | 	case constant.RelayModeAudioTranslation: | ||||||
| 		fallthrough | 		fallthrough | ||||||
| 	case RelayModeAudioTranscription: | 	case constant.RelayModeAudioTranscription: | ||||||
| 		err = relayAudioHelper(c, relayMode) | 		err = controller.RelayAudioHelper(c, relayMode) | ||||||
| 	default: | 	default: | ||||||
| 		err = relayTextHelper(c, relayMode) | 		err = controller.RelayTextHelper(c, relayMode) | ||||||
| 	} | 	} | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		requestId := c.GetString(common.RequestIdKey) | 		requestId := c.GetString(common.RequestIdKey) | ||||||
| @@ -359,17 +63,17 @@ func Relay(c *gin.Context) { | |||||||
| 			c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s?retry=%d", c.Request.URL.Path, retryTimes-1)) | 			c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s?retry=%d", c.Request.URL.Path, retryTimes-1)) | ||||||
| 		} else { | 		} else { | ||||||
| 			if err.StatusCode == http.StatusTooManyRequests { | 			if err.StatusCode == http.StatusTooManyRequests { | ||||||
| 				err.OpenAIError.Message = "当前分组上游负载已饱和,请稍后再试" | 				err.Error.Message = "当前分组上游负载已饱和,请稍后再试" | ||||||
| 			} | 			} | ||||||
| 			err.OpenAIError.Message = common.MessageWithRequestId(err.OpenAIError.Message, requestId) | 			err.Error.Message = common.MessageWithRequestId(err.Error.Message, requestId) | ||||||
| 			c.JSON(err.StatusCode, gin.H{ | 			c.JSON(err.StatusCode, gin.H{ | ||||||
| 				"error": err.OpenAIError, | 				"error": err.Error, | ||||||
| 			}) | 			}) | ||||||
| 		} | 		} | ||||||
| 		channelId := c.GetInt("channel_id") | 		channelId := c.GetInt("channel_id") | ||||||
| 		common.LogError(c.Request.Context(), fmt.Sprintf("relay error (channel #%d): %s", channelId, err.Message)) | 		common.LogError(c.Request.Context(), fmt.Sprintf("relay error (channel #%d): %s", channelId, err.Message)) | ||||||
| 		// https://platform.openai.com/docs/guides/error-codes/api-errors | 		// https://platform.openai.com/docs/guides/error-codes/api-errors | ||||||
| 		if shouldDisableChannel(&err.OpenAIError, err.StatusCode) { | 		if util.ShouldDisableChannel(&err.Error, err.StatusCode) { | ||||||
| 			channelId := c.GetInt("channel_id") | 			channelId := c.GetInt("channel_id") | ||||||
| 			channelName := c.GetString("channel_name") | 			channelName := c.GetString("channel_name") | ||||||
| 			disableChannel(channelId, channelName, err.Message) | 			disableChannel(channelId, channelName, err.Message) | ||||||
| @@ -378,7 +82,7 @@ func Relay(c *gin.Context) { | |||||||
| } | } | ||||||
|  |  | ||||||
| func RelayNotImplemented(c *gin.Context) { | func RelayNotImplemented(c *gin.Context) { | ||||||
| 	err := OpenAIError{ | 	err := openai.Error{ | ||||||
| 		Message: "API not implemented", | 		Message: "API not implemented", | ||||||
| 		Type:    "one_api_error", | 		Type:    "one_api_error", | ||||||
| 		Param:   "", | 		Param:   "", | ||||||
| @@ -390,7 +94,7 @@ func RelayNotImplemented(c *gin.Context) { | |||||||
| } | } | ||||||
|  |  | ||||||
| func RelayNotFound(c *gin.Context) { | func RelayNotFound(c *gin.Context) { | ||||||
| 	err := OpenAIError{ | 	err := openai.Error{ | ||||||
| 		Message: fmt.Sprintf("Invalid URL (%s %s)", c.Request.Method, c.Request.URL.Path), | 		Message: fmt.Sprintf("Invalid URL (%s %s)", c.Request.Method, c.Request.URL.Path), | ||||||
| 		Type:    "invalid_request_error", | 		Type:    "invalid_request_error", | ||||||
| 		Param:   "", | 		Param:   "", | ||||||
|   | |||||||
							
								
								
									
										3
									
								
								main.go
									
									
									
									
									
								
							
							
						
						
									
										3
									
								
								main.go
									
									
									
									
									
								
							| @@ -10,6 +10,7 @@ import ( | |||||||
| 	"one-api/controller" | 	"one-api/controller" | ||||||
| 	"one-api/middleware" | 	"one-api/middleware" | ||||||
| 	"one-api/model" | 	"one-api/model" | ||||||
|  | 	"one-api/relay/channel/openai" | ||||||
| 	"one-api/router" | 	"one-api/router" | ||||||
| 	"os" | 	"os" | ||||||
| 	"strconv" | 	"strconv" | ||||||
| @@ -80,7 +81,7 @@ func main() { | |||||||
| 		common.SysLog("batch update enabled with interval " + strconv.Itoa(common.BatchUpdateInterval) + "s") | 		common.SysLog("batch update enabled with interval " + strconv.Itoa(common.BatchUpdateInterval) + "s") | ||||||
| 		model.InitBatchUpdater() | 		model.InitBatchUpdater() | ||||||
| 	} | 	} | ||||||
| 	controller.InitTokenEncoders() | 	openai.InitTokenEncoders() | ||||||
|  |  | ||||||
| 	// Initialize HTTP server | 	// Initialize HTTP server | ||||||
| 	server := gin.New() | 	server := gin.New() | ||||||
|   | |||||||
| @@ -16,7 +16,7 @@ var DB *gorm.DB | |||||||
|  |  | ||||||
| func createRootAccountIfNeed() error { | func createRootAccountIfNeed() error { | ||||||
| 	var user User | 	var user User | ||||||
| 	//if user.Status != common.UserStatusEnabled { | 	//if user.Status != util.UserStatusEnabled { | ||||||
| 	if err := DB.First(&user).Error; err != nil { | 	if err := DB.First(&user).Error; err != nil { | ||||||
| 		common.SysLog("no user exists, create a root user for you: username is root, password is 123456") | 		common.SysLog("no user exists, create a root user for you: username is root, password is 123456") | ||||||
| 		hashedPassword, err := common.Password2Hash("123456") | 		hashedPassword, err := common.Password2Hash("123456") | ||||||
|   | |||||||
| @@ -15,7 +15,7 @@ type User struct { | |||||||
| 	Username         string `json:"username" gorm:"unique;index" validate:"max=12"` | 	Username         string `json:"username" gorm:"unique;index" validate:"max=12"` | ||||||
| 	Password         string `json:"password" gorm:"not null;" validate:"min=8,max=20"` | 	Password         string `json:"password" gorm:"not null;" validate:"min=8,max=20"` | ||||||
| 	DisplayName      string `json:"display_name" gorm:"index" validate:"max=20"` | 	DisplayName      string `json:"display_name" gorm:"index" validate:"max=20"` | ||||||
| 	Role             int    `json:"role" gorm:"type:int;default:1"`   // admin, common | 	Role             int    `json:"role" gorm:"type:int;default:1"`   // admin, util | ||||||
| 	Status           int    `json:"status" gorm:"type:int;default:1"` // enabled, disabled | 	Status           int    `json:"status" gorm:"type:int;default:1"` // enabled, disabled | ||||||
| 	Email            string `json:"email" gorm:"index" validate:"max=50"` | 	Email            string `json:"email" gorm:"index" validate:"max=50"` | ||||||
| 	GitHubId         string `json:"github_id" gorm:"column:github_id;index"` | 	GitHubId         string `json:"github_id" gorm:"column:github_id;index"` | ||||||
|   | |||||||
| @@ -1,4 +1,4 @@ | |||||||
| package controller | package aiproxy | ||||||
| 
 | 
 | ||||||
| import ( | import ( | ||||||
| 	"bufio" | 	"bufio" | ||||||
| @@ -8,56 +8,27 @@ import ( | |||||||
| 	"io" | 	"io" | ||||||
| 	"net/http" | 	"net/http" | ||||||
| 	"one-api/common" | 	"one-api/common" | ||||||
|  | 	"one-api/relay/channel/openai" | ||||||
|  | 	"one-api/relay/constant" | ||||||
| 	"strconv" | 	"strconv" | ||||||
| 	"strings" | 	"strings" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| // https://docs.aiproxy.io/dev/library#使用已经定制好的知识库进行对话问答 | // https://docs.aiproxy.io/dev/library#使用已经定制好的知识库进行对话问答 | ||||||
| 
 | 
 | ||||||
| type AIProxyLibraryRequest struct { | func ConvertRequest(request openai.GeneralOpenAIRequest) *LibraryRequest { | ||||||
| 	Model     string `json:"model"` |  | ||||||
| 	Query     string `json:"query"` |  | ||||||
| 	LibraryId string `json:"libraryId"` |  | ||||||
| 	Stream    bool   `json:"stream"` |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| type AIProxyLibraryError struct { |  | ||||||
| 	ErrCode int    `json:"errCode"` |  | ||||||
| 	Message string `json:"message"` |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| type AIProxyLibraryDocument struct { |  | ||||||
| 	Title string `json:"title"` |  | ||||||
| 	URL   string `json:"url"` |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| type AIProxyLibraryResponse struct { |  | ||||||
| 	Success   bool                     `json:"success"` |  | ||||||
| 	Answer    string                   `json:"answer"` |  | ||||||
| 	Documents []AIProxyLibraryDocument `json:"documents"` |  | ||||||
| 	AIProxyLibraryError |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| type AIProxyLibraryStreamResponse struct { |  | ||||||
| 	Content   string                   `json:"content"` |  | ||||||
| 	Finish    bool                     `json:"finish"` |  | ||||||
| 	Model     string                   `json:"model"` |  | ||||||
| 	Documents []AIProxyLibraryDocument `json:"documents"` |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func requestOpenAI2AIProxyLibrary(request GeneralOpenAIRequest) *AIProxyLibraryRequest { |  | ||||||
| 	query := "" | 	query := "" | ||||||
| 	if len(request.Messages) != 0 { | 	if len(request.Messages) != 0 { | ||||||
| 		query = request.Messages[len(request.Messages)-1].StringContent() | 		query = request.Messages[len(request.Messages)-1].StringContent() | ||||||
| 	} | 	} | ||||||
| 	return &AIProxyLibraryRequest{ | 	return &LibraryRequest{ | ||||||
| 		Model:  request.Model, | 		Model:  request.Model, | ||||||
| 		Stream: request.Stream, | 		Stream: request.Stream, | ||||||
| 		Query:  query, | 		Query:  query, | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func aiProxyDocuments2Markdown(documents []AIProxyLibraryDocument) string { | func aiProxyDocuments2Markdown(documents []LibraryDocument) string { | ||||||
| 	if len(documents) == 0 { | 	if len(documents) == 0 { | ||||||
| 		return "" | 		return "" | ||||||
| 	} | 	} | ||||||
| @@ -68,52 +39,52 @@ func aiProxyDocuments2Markdown(documents []AIProxyLibraryDocument) string { | |||||||
| 	return content | 	return content | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func responseAIProxyLibrary2OpenAI(response *AIProxyLibraryResponse) *OpenAITextResponse { | func responseAIProxyLibrary2OpenAI(response *LibraryResponse) *openai.TextResponse { | ||||||
| 	content := response.Answer + aiProxyDocuments2Markdown(response.Documents) | 	content := response.Answer + aiProxyDocuments2Markdown(response.Documents) | ||||||
| 	choice := OpenAITextResponseChoice{ | 	choice := openai.TextResponseChoice{ | ||||||
| 		Index: 0, | 		Index: 0, | ||||||
| 		Message: Message{ | 		Message: openai.Message{ | ||||||
| 			Role:    "assistant", | 			Role:    "assistant", | ||||||
| 			Content: content, | 			Content: content, | ||||||
| 		}, | 		}, | ||||||
| 		FinishReason: "stop", | 		FinishReason: "stop", | ||||||
| 	} | 	} | ||||||
| 	fullTextResponse := OpenAITextResponse{ | 	fullTextResponse := openai.TextResponse{ | ||||||
| 		Id:      common.GetUUID(), | 		Id:      common.GetUUID(), | ||||||
| 		Object:  "chat.completion", | 		Object:  "chat.completion", | ||||||
| 		Created: common.GetTimestamp(), | 		Created: common.GetTimestamp(), | ||||||
| 		Choices: []OpenAITextResponseChoice{choice}, | 		Choices: []openai.TextResponseChoice{choice}, | ||||||
| 	} | 	} | ||||||
| 	return &fullTextResponse | 	return &fullTextResponse | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func documentsAIProxyLibrary(documents []AIProxyLibraryDocument) *ChatCompletionsStreamResponse { | func documentsAIProxyLibrary(documents []LibraryDocument) *openai.ChatCompletionsStreamResponse { | ||||||
| 	var choice ChatCompletionsStreamResponseChoice | 	var choice openai.ChatCompletionsStreamResponseChoice | ||||||
| 	choice.Delta.Content = aiProxyDocuments2Markdown(documents) | 	choice.Delta.Content = aiProxyDocuments2Markdown(documents) | ||||||
| 	choice.FinishReason = &stopFinishReason | 	choice.FinishReason = &constant.StopFinishReason | ||||||
| 	return &ChatCompletionsStreamResponse{ | 	return &openai.ChatCompletionsStreamResponse{ | ||||||
| 		Id:      common.GetUUID(), | 		Id:      common.GetUUID(), | ||||||
| 		Object:  "chat.completion.chunk", | 		Object:  "chat.completion.chunk", | ||||||
| 		Created: common.GetTimestamp(), | 		Created: common.GetTimestamp(), | ||||||
| 		Model:   "", | 		Model:   "", | ||||||
| 		Choices: []ChatCompletionsStreamResponseChoice{choice}, | 		Choices: []openai.ChatCompletionsStreamResponseChoice{choice}, | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func streamResponseAIProxyLibrary2OpenAI(response *AIProxyLibraryStreamResponse) *ChatCompletionsStreamResponse { | func streamResponseAIProxyLibrary2OpenAI(response *LibraryStreamResponse) *openai.ChatCompletionsStreamResponse { | ||||||
| 	var choice ChatCompletionsStreamResponseChoice | 	var choice openai.ChatCompletionsStreamResponseChoice | ||||||
| 	choice.Delta.Content = response.Content | 	choice.Delta.Content = response.Content | ||||||
| 	return &ChatCompletionsStreamResponse{ | 	return &openai.ChatCompletionsStreamResponse{ | ||||||
| 		Id:      common.GetUUID(), | 		Id:      common.GetUUID(), | ||||||
| 		Object:  "chat.completion.chunk", | 		Object:  "chat.completion.chunk", | ||||||
| 		Created: common.GetTimestamp(), | 		Created: common.GetTimestamp(), | ||||||
| 		Model:   response.Model, | 		Model:   response.Model, | ||||||
| 		Choices: []ChatCompletionsStreamResponseChoice{choice}, | 		Choices: []openai.ChatCompletionsStreamResponseChoice{choice}, | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func aiProxyLibraryStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) { | func StreamHandler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, *openai.Usage) { | ||||||
| 	var usage Usage | 	var usage openai.Usage | ||||||
| 	scanner := bufio.NewScanner(resp.Body) | 	scanner := bufio.NewScanner(resp.Body) | ||||||
| 	scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { | 	scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { | ||||||
| 		if atEOF && len(data) == 0 { | 		if atEOF && len(data) == 0 { | ||||||
| @@ -143,12 +114,12 @@ func aiProxyLibraryStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIEr | |||||||
| 		} | 		} | ||||||
| 		stopChan <- true | 		stopChan <- true | ||||||
| 	}() | 	}() | ||||||
| 	setEventStreamHeaders(c) | 	common.SetEventStreamHeaders(c) | ||||||
| 	var documents []AIProxyLibraryDocument | 	var documents []LibraryDocument | ||||||
| 	c.Stream(func(w io.Writer) bool { | 	c.Stream(func(w io.Writer) bool { | ||||||
| 		select { | 		select { | ||||||
| 		case data := <-dataChan: | 		case data := <-dataChan: | ||||||
| 			var AIProxyLibraryResponse AIProxyLibraryStreamResponse | 			var AIProxyLibraryResponse LibraryStreamResponse | ||||||
| 			err := json.Unmarshal([]byte(data), &AIProxyLibraryResponse) | 			err := json.Unmarshal([]byte(data), &AIProxyLibraryResponse) | ||||||
| 			if err != nil { | 			if err != nil { | ||||||
| 				common.SysError("error unmarshalling stream response: " + err.Error()) | 				common.SysError("error unmarshalling stream response: " + err.Error()) | ||||||
| @@ -179,28 +150,28 @@ func aiProxyLibraryStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIEr | |||||||
| 	}) | 	}) | ||||||
| 	err := resp.Body.Close() | 	err := resp.Body.Close() | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil | 		return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil | ||||||
| 	} | 	} | ||||||
| 	return nil, &usage | 	return nil, &usage | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func aiProxyLibraryHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) { | func Handler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, *openai.Usage) { | ||||||
| 	var AIProxyLibraryResponse AIProxyLibraryResponse | 	var AIProxyLibraryResponse LibraryResponse | ||||||
| 	responseBody, err := io.ReadAll(resp.Body) | 	responseBody, err := io.ReadAll(resp.Body) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil | 		return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil | ||||||
| 	} | 	} | ||||||
| 	err = resp.Body.Close() | 	err = resp.Body.Close() | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil | 		return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil | ||||||
| 	} | 	} | ||||||
| 	err = json.Unmarshal(responseBody, &AIProxyLibraryResponse) | 	err = json.Unmarshal(responseBody, &AIProxyLibraryResponse) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil | 		return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil | ||||||
| 	} | 	} | ||||||
| 	if AIProxyLibraryResponse.ErrCode != 0 { | 	if AIProxyLibraryResponse.ErrCode != 0 { | ||||||
| 		return &OpenAIErrorWithStatusCode{ | 		return &openai.ErrorWithStatusCode{ | ||||||
| 			OpenAIError: OpenAIError{ | 			Error: openai.Error{ | ||||||
| 				Message: AIProxyLibraryResponse.Message, | 				Message: AIProxyLibraryResponse.Message, | ||||||
| 				Type:    strconv.Itoa(AIProxyLibraryResponse.ErrCode), | 				Type:    strconv.Itoa(AIProxyLibraryResponse.ErrCode), | ||||||
| 				Code:    AIProxyLibraryResponse.ErrCode, | 				Code:    AIProxyLibraryResponse.ErrCode, | ||||||
| @@ -211,7 +182,7 @@ func aiProxyLibraryHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWit | |||||||
| 	fullTextResponse := responseAIProxyLibrary2OpenAI(&AIProxyLibraryResponse) | 	fullTextResponse := responseAIProxyLibrary2OpenAI(&AIProxyLibraryResponse) | ||||||
| 	jsonResponse, err := json.Marshal(fullTextResponse) | 	jsonResponse, err := json.Marshal(fullTextResponse) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil | 		return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil | ||||||
| 	} | 	} | ||||||
| 	c.Writer.Header().Set("Content-Type", "application/json") | 	c.Writer.Header().Set("Content-Type", "application/json") | ||||||
| 	c.Writer.WriteHeader(resp.StatusCode) | 	c.Writer.WriteHeader(resp.StatusCode) | ||||||
							
								
								
									
										32
									
								
								relay/channel/aiproxy/model.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										32
									
								
								relay/channel/aiproxy/model.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,32 @@ | |||||||
|  | package aiproxy | ||||||
|  |  | ||||||
|  | type LibraryRequest struct { | ||||||
|  | 	Model     string `json:"model"` | ||||||
|  | 	Query     string `json:"query"` | ||||||
|  | 	LibraryId string `json:"libraryId"` | ||||||
|  | 	Stream    bool   `json:"stream"` | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type LibraryError struct { | ||||||
|  | 	ErrCode int    `json:"errCode"` | ||||||
|  | 	Message string `json:"message"` | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type LibraryDocument struct { | ||||||
|  | 	Title string `json:"title"` | ||||||
|  | 	URL   string `json:"url"` | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type LibraryResponse struct { | ||||||
|  | 	Success   bool              `json:"success"` | ||||||
|  | 	Answer    string            `json:"answer"` | ||||||
|  | 	Documents []LibraryDocument `json:"documents"` | ||||||
|  | 	LibraryError | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type LibraryStreamResponse struct { | ||||||
|  | 	Content   string            `json:"content"` | ||||||
|  | 	Finish    bool              `json:"finish"` | ||||||
|  | 	Model     string            `json:"model"` | ||||||
|  | 	Documents []LibraryDocument `json:"documents"` | ||||||
|  | } | ||||||
| @@ -1,4 +1,4 @@ | |||||||
| package controller | package ali | ||||||
| 
 | 
 | ||||||
| import ( | import ( | ||||||
| 	"bufio" | 	"bufio" | ||||||
| @@ -7,112 +7,43 @@ import ( | |||||||
| 	"io" | 	"io" | ||||||
| 	"net/http" | 	"net/http" | ||||||
| 	"one-api/common" | 	"one-api/common" | ||||||
|  | 	"one-api/relay/channel/openai" | ||||||
| 	"strings" | 	"strings" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| // https://help.aliyun.com/document_detail/613695.html?spm=a2c4g.2399480.0.0.1adb778fAdzP9w#341800c0f8w0r | // https://help.aliyun.com/document_detail/613695.html?spm=a2c4g.2399480.0.0.1adb778fAdzP9w#341800c0f8w0r | ||||||
| 
 | 
 | ||||||
| type AliMessage struct { | const EnableSearchModelSuffix = "-internet" | ||||||
| 	Content string `json:"content"` |  | ||||||
| 	Role    string `json:"role"` |  | ||||||
| } |  | ||||||
| 
 | 
 | ||||||
| type AliInput struct { | func ConvertRequest(request openai.GeneralOpenAIRequest) *ChatRequest { | ||||||
| 	//Prompt   string       `json:"prompt"` | 	messages := make([]Message, 0, len(request.Messages)) | ||||||
| 	Messages []AliMessage `json:"messages"` |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| type AliParameters struct { |  | ||||||
| 	TopP              float64 `json:"top_p,omitempty"` |  | ||||||
| 	TopK              int     `json:"top_k,omitempty"` |  | ||||||
| 	Seed              uint64  `json:"seed,omitempty"` |  | ||||||
| 	EnableSearch      bool    `json:"enable_search,omitempty"` |  | ||||||
| 	IncrementalOutput bool    `json:"incremental_output,omitempty"` |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| type AliChatRequest struct { |  | ||||||
| 	Model      string        `json:"model"` |  | ||||||
| 	Input      AliInput      `json:"input"` |  | ||||||
| 	Parameters AliParameters `json:"parameters,omitempty"` |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| type AliEmbeddingRequest struct { |  | ||||||
| 	Model string `json:"model"` |  | ||||||
| 	Input struct { |  | ||||||
| 		Texts []string `json:"texts"` |  | ||||||
| 	} `json:"input"` |  | ||||||
| 	Parameters *struct { |  | ||||||
| 		TextType string `json:"text_type,omitempty"` |  | ||||||
| 	} `json:"parameters,omitempty"` |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| type AliEmbedding struct { |  | ||||||
| 	Embedding []float64 `json:"embedding"` |  | ||||||
| 	TextIndex int       `json:"text_index"` |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| type AliEmbeddingResponse struct { |  | ||||||
| 	Output struct { |  | ||||||
| 		Embeddings []AliEmbedding `json:"embeddings"` |  | ||||||
| 	} `json:"output"` |  | ||||||
| 	Usage AliUsage `json:"usage"` |  | ||||||
| 	AliError |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| type AliError struct { |  | ||||||
| 	Code      string `json:"code"` |  | ||||||
| 	Message   string `json:"message"` |  | ||||||
| 	RequestId string `json:"request_id"` |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| type AliUsage struct { |  | ||||||
| 	InputTokens  int `json:"input_tokens"` |  | ||||||
| 	OutputTokens int `json:"output_tokens"` |  | ||||||
| 	TotalTokens  int `json:"total_tokens"` |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| type AliOutput struct { |  | ||||||
| 	Text         string `json:"text"` |  | ||||||
| 	FinishReason string `json:"finish_reason"` |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| type AliChatResponse struct { |  | ||||||
| 	Output AliOutput `json:"output"` |  | ||||||
| 	Usage  AliUsage  `json:"usage"` |  | ||||||
| 	AliError |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| const AliEnableSearchModelSuffix = "-internet" |  | ||||||
| 
 |  | ||||||
| func requestOpenAI2Ali(request GeneralOpenAIRequest) *AliChatRequest { |  | ||||||
| 	messages := make([]AliMessage, 0, len(request.Messages)) |  | ||||||
| 	for i := 0; i < len(request.Messages); i++ { | 	for i := 0; i < len(request.Messages); i++ { | ||||||
| 		message := request.Messages[i] | 		message := request.Messages[i] | ||||||
| 		messages = append(messages, AliMessage{ | 		messages = append(messages, Message{ | ||||||
| 			Content: message.StringContent(), | 			Content: message.StringContent(), | ||||||
| 			Role:    strings.ToLower(message.Role), | 			Role:    strings.ToLower(message.Role), | ||||||
| 		}) | 		}) | ||||||
| 	} | 	} | ||||||
| 	enableSearch := false | 	enableSearch := false | ||||||
| 	aliModel := request.Model | 	aliModel := request.Model | ||||||
| 	if strings.HasSuffix(aliModel, AliEnableSearchModelSuffix) { | 	if strings.HasSuffix(aliModel, EnableSearchModelSuffix) { | ||||||
| 		enableSearch = true | 		enableSearch = true | ||||||
| 		aliModel = strings.TrimSuffix(aliModel, AliEnableSearchModelSuffix) | 		aliModel = strings.TrimSuffix(aliModel, EnableSearchModelSuffix) | ||||||
| 	} | 	} | ||||||
| 	return &AliChatRequest{ | 	return &ChatRequest{ | ||||||
| 		Model: aliModel, | 		Model: aliModel, | ||||||
| 		Input: AliInput{ | 		Input: Input{ | ||||||
| 			Messages: messages, | 			Messages: messages, | ||||||
| 		}, | 		}, | ||||||
| 		Parameters: AliParameters{ | 		Parameters: Parameters{ | ||||||
| 			EnableSearch:      enableSearch, | 			EnableSearch:      enableSearch, | ||||||
| 			IncrementalOutput: request.Stream, | 			IncrementalOutput: request.Stream, | ||||||
| 		}, | 		}, | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func embeddingRequestOpenAI2Ali(request GeneralOpenAIRequest) *AliEmbeddingRequest { | func ConvertEmbeddingRequest(request openai.GeneralOpenAIRequest) *EmbeddingRequest { | ||||||
| 	return &AliEmbeddingRequest{ | 	return &EmbeddingRequest{ | ||||||
| 		Model: "text-embedding-v1", | 		Model: "text-embedding-v1", | ||||||
| 		Input: struct { | 		Input: struct { | ||||||
| 			Texts []string `json:"texts"` | 			Texts []string `json:"texts"` | ||||||
| @@ -122,21 +53,21 @@ func embeddingRequestOpenAI2Ali(request GeneralOpenAIRequest) *AliEmbeddingReque | |||||||
| 	} | 	} | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func aliEmbeddingHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) { | func EmbeddingHandler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, *openai.Usage) { | ||||||
| 	var aliResponse AliEmbeddingResponse | 	var aliResponse EmbeddingResponse | ||||||
| 	err := json.NewDecoder(resp.Body).Decode(&aliResponse) | 	err := json.NewDecoder(resp.Body).Decode(&aliResponse) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil | 		return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	err = resp.Body.Close() | 	err = resp.Body.Close() | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil | 		return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	if aliResponse.Code != "" { | 	if aliResponse.Code != "" { | ||||||
| 		return &OpenAIErrorWithStatusCode{ | 		return &openai.ErrorWithStatusCode{ | ||||||
| 			OpenAIError: OpenAIError{ | 			Error: openai.Error{ | ||||||
| 				Message: aliResponse.Message, | 				Message: aliResponse.Message, | ||||||
| 				Type:    aliResponse.Code, | 				Type:    aliResponse.Code, | ||||||
| 				Param:   aliResponse.RequestId, | 				Param:   aliResponse.RequestId, | ||||||
| @@ -149,7 +80,7 @@ func aliEmbeddingHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithS | |||||||
| 	fullTextResponse := embeddingResponseAli2OpenAI(&aliResponse) | 	fullTextResponse := embeddingResponseAli2OpenAI(&aliResponse) | ||||||
| 	jsonResponse, err := json.Marshal(fullTextResponse) | 	jsonResponse, err := json.Marshal(fullTextResponse) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil | 		return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil | ||||||
| 	} | 	} | ||||||
| 	c.Writer.Header().Set("Content-Type", "application/json") | 	c.Writer.Header().Set("Content-Type", "application/json") | ||||||
| 	c.Writer.WriteHeader(resp.StatusCode) | 	c.Writer.WriteHeader(resp.StatusCode) | ||||||
| @@ -157,16 +88,16 @@ func aliEmbeddingHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithS | |||||||
| 	return nil, &fullTextResponse.Usage | 	return nil, &fullTextResponse.Usage | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func embeddingResponseAli2OpenAI(response *AliEmbeddingResponse) *OpenAIEmbeddingResponse { | func embeddingResponseAli2OpenAI(response *EmbeddingResponse) *openai.EmbeddingResponse { | ||||||
| 	openAIEmbeddingResponse := OpenAIEmbeddingResponse{ | 	openAIEmbeddingResponse := openai.EmbeddingResponse{ | ||||||
| 		Object: "list", | 		Object: "list", | ||||||
| 		Data:   make([]OpenAIEmbeddingResponseItem, 0, len(response.Output.Embeddings)), | 		Data:   make([]openai.EmbeddingResponseItem, 0, len(response.Output.Embeddings)), | ||||||
| 		Model:  "text-embedding-v1", | 		Model:  "text-embedding-v1", | ||||||
| 		Usage:  Usage{TotalTokens: response.Usage.TotalTokens}, | 		Usage:  openai.Usage{TotalTokens: response.Usage.TotalTokens}, | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	for _, item := range response.Output.Embeddings { | 	for _, item := range response.Output.Embeddings { | ||||||
| 		openAIEmbeddingResponse.Data = append(openAIEmbeddingResponse.Data, OpenAIEmbeddingResponseItem{ | 		openAIEmbeddingResponse.Data = append(openAIEmbeddingResponse.Data, openai.EmbeddingResponseItem{ | ||||||
| 			Object:    `embedding`, | 			Object:    `embedding`, | ||||||
| 			Index:     item.TextIndex, | 			Index:     item.TextIndex, | ||||||
| 			Embedding: item.Embedding, | 			Embedding: item.Embedding, | ||||||
| @@ -175,21 +106,21 @@ func embeddingResponseAli2OpenAI(response *AliEmbeddingResponse) *OpenAIEmbeddin | |||||||
| 	return &openAIEmbeddingResponse | 	return &openAIEmbeddingResponse | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func responseAli2OpenAI(response *AliChatResponse) *OpenAITextResponse { | func responseAli2OpenAI(response *ChatResponse) *openai.TextResponse { | ||||||
| 	choice := OpenAITextResponseChoice{ | 	choice := openai.TextResponseChoice{ | ||||||
| 		Index: 0, | 		Index: 0, | ||||||
| 		Message: Message{ | 		Message: openai.Message{ | ||||||
| 			Role:    "assistant", | 			Role:    "assistant", | ||||||
| 			Content: response.Output.Text, | 			Content: response.Output.Text, | ||||||
| 		}, | 		}, | ||||||
| 		FinishReason: response.Output.FinishReason, | 		FinishReason: response.Output.FinishReason, | ||||||
| 	} | 	} | ||||||
| 	fullTextResponse := OpenAITextResponse{ | 	fullTextResponse := openai.TextResponse{ | ||||||
| 		Id:      response.RequestId, | 		Id:      response.RequestId, | ||||||
| 		Object:  "chat.completion", | 		Object:  "chat.completion", | ||||||
| 		Created: common.GetTimestamp(), | 		Created: common.GetTimestamp(), | ||||||
| 		Choices: []OpenAITextResponseChoice{choice}, | 		Choices: []openai.TextResponseChoice{choice}, | ||||||
| 		Usage: Usage{ | 		Usage: openai.Usage{ | ||||||
| 			PromptTokens:     response.Usage.InputTokens, | 			PromptTokens:     response.Usage.InputTokens, | ||||||
| 			CompletionTokens: response.Usage.OutputTokens, | 			CompletionTokens: response.Usage.OutputTokens, | ||||||
| 			TotalTokens:      response.Usage.InputTokens + response.Usage.OutputTokens, | 			TotalTokens:      response.Usage.InputTokens + response.Usage.OutputTokens, | ||||||
| @@ -198,25 +129,25 @@ func responseAli2OpenAI(response *AliChatResponse) *OpenAITextResponse { | |||||||
| 	return &fullTextResponse | 	return &fullTextResponse | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func streamResponseAli2OpenAI(aliResponse *AliChatResponse) *ChatCompletionsStreamResponse { | func streamResponseAli2OpenAI(aliResponse *ChatResponse) *openai.ChatCompletionsStreamResponse { | ||||||
| 	var choice ChatCompletionsStreamResponseChoice | 	var choice openai.ChatCompletionsStreamResponseChoice | ||||||
| 	choice.Delta.Content = aliResponse.Output.Text | 	choice.Delta.Content = aliResponse.Output.Text | ||||||
| 	if aliResponse.Output.FinishReason != "null" { | 	if aliResponse.Output.FinishReason != "null" { | ||||||
| 		finishReason := aliResponse.Output.FinishReason | 		finishReason := aliResponse.Output.FinishReason | ||||||
| 		choice.FinishReason = &finishReason | 		choice.FinishReason = &finishReason | ||||||
| 	} | 	} | ||||||
| 	response := ChatCompletionsStreamResponse{ | 	response := openai.ChatCompletionsStreamResponse{ | ||||||
| 		Id:      aliResponse.RequestId, | 		Id:      aliResponse.RequestId, | ||||||
| 		Object:  "chat.completion.chunk", | 		Object:  "chat.completion.chunk", | ||||||
| 		Created: common.GetTimestamp(), | 		Created: common.GetTimestamp(), | ||||||
| 		Model:   "qwen", | 		Model:   "qwen", | ||||||
| 		Choices: []ChatCompletionsStreamResponseChoice{choice}, | 		Choices: []openai.ChatCompletionsStreamResponseChoice{choice}, | ||||||
| 	} | 	} | ||||||
| 	return &response | 	return &response | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func aliStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) { | func StreamHandler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, *openai.Usage) { | ||||||
| 	var usage Usage | 	var usage openai.Usage | ||||||
| 	scanner := bufio.NewScanner(resp.Body) | 	scanner := bufio.NewScanner(resp.Body) | ||||||
| 	scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { | 	scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { | ||||||
| 		if atEOF && len(data) == 0 { | 		if atEOF && len(data) == 0 { | ||||||
| @@ -246,12 +177,12 @@ func aliStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStat | |||||||
| 		} | 		} | ||||||
| 		stopChan <- true | 		stopChan <- true | ||||||
| 	}() | 	}() | ||||||
| 	setEventStreamHeaders(c) | 	common.SetEventStreamHeaders(c) | ||||||
| 	//lastResponseText := "" | 	//lastResponseText := "" | ||||||
| 	c.Stream(func(w io.Writer) bool { | 	c.Stream(func(w io.Writer) bool { | ||||||
| 		select { | 		select { | ||||||
| 		case data := <-dataChan: | 		case data := <-dataChan: | ||||||
| 			var aliResponse AliChatResponse | 			var aliResponse ChatResponse | ||||||
| 			err := json.Unmarshal([]byte(data), &aliResponse) | 			err := json.Unmarshal([]byte(data), &aliResponse) | ||||||
| 			if err != nil { | 			if err != nil { | ||||||
| 				common.SysError("error unmarshalling stream response: " + err.Error()) | 				common.SysError("error unmarshalling stream response: " + err.Error()) | ||||||
| @@ -279,28 +210,28 @@ func aliStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStat | |||||||
| 	}) | 	}) | ||||||
| 	err := resp.Body.Close() | 	err := resp.Body.Close() | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil | 		return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil | ||||||
| 	} | 	} | ||||||
| 	return nil, &usage | 	return nil, &usage | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func aliHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) { | func Handler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, *openai.Usage) { | ||||||
| 	var aliResponse AliChatResponse | 	var aliResponse ChatResponse | ||||||
| 	responseBody, err := io.ReadAll(resp.Body) | 	responseBody, err := io.ReadAll(resp.Body) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil | 		return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil | ||||||
| 	} | 	} | ||||||
| 	err = resp.Body.Close() | 	err = resp.Body.Close() | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil | 		return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil | ||||||
| 	} | 	} | ||||||
| 	err = json.Unmarshal(responseBody, &aliResponse) | 	err = json.Unmarshal(responseBody, &aliResponse) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil | 		return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil | ||||||
| 	} | 	} | ||||||
| 	if aliResponse.Code != "" { | 	if aliResponse.Code != "" { | ||||||
| 		return &OpenAIErrorWithStatusCode{ | 		return &openai.ErrorWithStatusCode{ | ||||||
| 			OpenAIError: OpenAIError{ | 			Error: openai.Error{ | ||||||
| 				Message: aliResponse.Message, | 				Message: aliResponse.Message, | ||||||
| 				Type:    aliResponse.Code, | 				Type:    aliResponse.Code, | ||||||
| 				Param:   aliResponse.RequestId, | 				Param:   aliResponse.RequestId, | ||||||
| @@ -313,7 +244,7 @@ func aliHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode | |||||||
| 	fullTextResponse.Model = "qwen" | 	fullTextResponse.Model = "qwen" | ||||||
| 	jsonResponse, err := json.Marshal(fullTextResponse) | 	jsonResponse, err := json.Marshal(fullTextResponse) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil | 		return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil | ||||||
| 	} | 	} | ||||||
| 	c.Writer.Header().Set("Content-Type", "application/json") | 	c.Writer.Header().Set("Content-Type", "application/json") | ||||||
| 	c.Writer.WriteHeader(resp.StatusCode) | 	c.Writer.WriteHeader(resp.StatusCode) | ||||||
							
								
								
									
										71
									
								
								relay/channel/ali/model.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										71
									
								
								relay/channel/ali/model.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,71 @@ | |||||||
|  | package ali | ||||||
|  |  | ||||||
|  | type Message struct { | ||||||
|  | 	Content string `json:"content"` | ||||||
|  | 	Role    string `json:"role"` | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type Input struct { | ||||||
|  | 	//Prompt   string       `json:"prompt"` | ||||||
|  | 	Messages []Message `json:"messages"` | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type Parameters struct { | ||||||
|  | 	TopP              float64 `json:"top_p,omitempty"` | ||||||
|  | 	TopK              int     `json:"top_k,omitempty"` | ||||||
|  | 	Seed              uint64  `json:"seed,omitempty"` | ||||||
|  | 	EnableSearch      bool    `json:"enable_search,omitempty"` | ||||||
|  | 	IncrementalOutput bool    `json:"incremental_output,omitempty"` | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type ChatRequest struct { | ||||||
|  | 	Model      string     `json:"model"` | ||||||
|  | 	Input      Input      `json:"input"` | ||||||
|  | 	Parameters Parameters `json:"parameters,omitempty"` | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type EmbeddingRequest struct { | ||||||
|  | 	Model string `json:"model"` | ||||||
|  | 	Input struct { | ||||||
|  | 		Texts []string `json:"texts"` | ||||||
|  | 	} `json:"input"` | ||||||
|  | 	Parameters *struct { | ||||||
|  | 		TextType string `json:"text_type,omitempty"` | ||||||
|  | 	} `json:"parameters,omitempty"` | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type Embedding struct { | ||||||
|  | 	Embedding []float64 `json:"embedding"` | ||||||
|  | 	TextIndex int       `json:"text_index"` | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type EmbeddingResponse struct { | ||||||
|  | 	Output struct { | ||||||
|  | 		Embeddings []Embedding `json:"embeddings"` | ||||||
|  | 	} `json:"output"` | ||||||
|  | 	Usage Usage `json:"usage"` | ||||||
|  | 	Error | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type Error struct { | ||||||
|  | 	Code      string `json:"code"` | ||||||
|  | 	Message   string `json:"message"` | ||||||
|  | 	RequestId string `json:"request_id"` | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type Usage struct { | ||||||
|  | 	InputTokens  int `json:"input_tokens"` | ||||||
|  | 	OutputTokens int `json:"output_tokens"` | ||||||
|  | 	TotalTokens  int `json:"total_tokens"` | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type Output struct { | ||||||
|  | 	Text         string `json:"text"` | ||||||
|  | 	FinishReason string `json:"finish_reason"` | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type ChatResponse struct { | ||||||
|  | 	Output Output `json:"output"` | ||||||
|  | 	Usage  Usage  `json:"usage"` | ||||||
|  | 	Error | ||||||
|  | } | ||||||
| @@ -1,4 +1,4 @@ | |||||||
| package controller | package anthropic | ||||||
| 
 | 
 | ||||||
| import ( | import ( | ||||||
| 	"bufio" | 	"bufio" | ||||||
| @@ -8,37 +8,10 @@ import ( | |||||||
| 	"io" | 	"io" | ||||||
| 	"net/http" | 	"net/http" | ||||||
| 	"one-api/common" | 	"one-api/common" | ||||||
|  | 	"one-api/relay/channel/openai" | ||||||
| 	"strings" | 	"strings" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| type ClaudeMetadata struct { |  | ||||||
| 	UserId string `json:"user_id"` |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| type ClaudeRequest struct { |  | ||||||
| 	Model             string   `json:"model"` |  | ||||||
| 	Prompt            string   `json:"prompt"` |  | ||||||
| 	MaxTokensToSample int      `json:"max_tokens_to_sample"` |  | ||||||
| 	StopSequences     []string `json:"stop_sequences,omitempty"` |  | ||||||
| 	Temperature       float64  `json:"temperature,omitempty"` |  | ||||||
| 	TopP              float64  `json:"top_p,omitempty"` |  | ||||||
| 	TopK              int      `json:"top_k,omitempty"` |  | ||||||
| 	//ClaudeMetadata    `json:"metadata,omitempty"` |  | ||||||
| 	Stream bool `json:"stream,omitempty"` |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| type ClaudeError struct { |  | ||||||
| 	Type    string `json:"type"` |  | ||||||
| 	Message string `json:"message"` |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| type ClaudeResponse struct { |  | ||||||
| 	Completion string      `json:"completion"` |  | ||||||
| 	StopReason string      `json:"stop_reason"` |  | ||||||
| 	Model      string      `json:"model"` |  | ||||||
| 	Error      ClaudeError `json:"error"` |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func stopReasonClaude2OpenAI(reason string) string { | func stopReasonClaude2OpenAI(reason string) string { | ||||||
| 	switch reason { | 	switch reason { | ||||||
| 	case "stop_sequence": | 	case "stop_sequence": | ||||||
| @@ -50,8 +23,8 @@ func stopReasonClaude2OpenAI(reason string) string { | |||||||
| 	} | 	} | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func requestOpenAI2Claude(textRequest GeneralOpenAIRequest) *ClaudeRequest { | func ConvertRequest(textRequest openai.GeneralOpenAIRequest) *Request { | ||||||
| 	claudeRequest := ClaudeRequest{ | 	claudeRequest := Request{ | ||||||
| 		Model:             textRequest.Model, | 		Model:             textRequest.Model, | ||||||
| 		Prompt:            "", | 		Prompt:            "", | ||||||
| 		MaxTokensToSample: textRequest.MaxTokens, | 		MaxTokensToSample: textRequest.MaxTokens, | ||||||
| @@ -80,40 +53,40 @@ func requestOpenAI2Claude(textRequest GeneralOpenAIRequest) *ClaudeRequest { | |||||||
| 	return &claudeRequest | 	return &claudeRequest | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func streamResponseClaude2OpenAI(claudeResponse *ClaudeResponse) *ChatCompletionsStreamResponse { | func streamResponseClaude2OpenAI(claudeResponse *Response) *openai.ChatCompletionsStreamResponse { | ||||||
| 	var choice ChatCompletionsStreamResponseChoice | 	var choice openai.ChatCompletionsStreamResponseChoice | ||||||
| 	choice.Delta.Content = claudeResponse.Completion | 	choice.Delta.Content = claudeResponse.Completion | ||||||
| 	finishReason := stopReasonClaude2OpenAI(claudeResponse.StopReason) | 	finishReason := stopReasonClaude2OpenAI(claudeResponse.StopReason) | ||||||
| 	if finishReason != "null" { | 	if finishReason != "null" { | ||||||
| 		choice.FinishReason = &finishReason | 		choice.FinishReason = &finishReason | ||||||
| 	} | 	} | ||||||
| 	var response ChatCompletionsStreamResponse | 	var response openai.ChatCompletionsStreamResponse | ||||||
| 	response.Object = "chat.completion.chunk" | 	response.Object = "chat.completion.chunk" | ||||||
| 	response.Model = claudeResponse.Model | 	response.Model = claudeResponse.Model | ||||||
| 	response.Choices = []ChatCompletionsStreamResponseChoice{choice} | 	response.Choices = []openai.ChatCompletionsStreamResponseChoice{choice} | ||||||
| 	return &response | 	return &response | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func responseClaude2OpenAI(claudeResponse *ClaudeResponse) *OpenAITextResponse { | func responseClaude2OpenAI(claudeResponse *Response) *openai.TextResponse { | ||||||
| 	choice := OpenAITextResponseChoice{ | 	choice := openai.TextResponseChoice{ | ||||||
| 		Index: 0, | 		Index: 0, | ||||||
| 		Message: Message{ | 		Message: openai.Message{ | ||||||
| 			Role:    "assistant", | 			Role:    "assistant", | ||||||
| 			Content: strings.TrimPrefix(claudeResponse.Completion, " "), | 			Content: strings.TrimPrefix(claudeResponse.Completion, " "), | ||||||
| 			Name:    nil, | 			Name:    nil, | ||||||
| 		}, | 		}, | ||||||
| 		FinishReason: stopReasonClaude2OpenAI(claudeResponse.StopReason), | 		FinishReason: stopReasonClaude2OpenAI(claudeResponse.StopReason), | ||||||
| 	} | 	} | ||||||
| 	fullTextResponse := OpenAITextResponse{ | 	fullTextResponse := openai.TextResponse{ | ||||||
| 		Id:      fmt.Sprintf("chatcmpl-%s", common.GetUUID()), | 		Id:      fmt.Sprintf("chatcmpl-%s", common.GetUUID()), | ||||||
| 		Object:  "chat.completion", | 		Object:  "chat.completion", | ||||||
| 		Created: common.GetTimestamp(), | 		Created: common.GetTimestamp(), | ||||||
| 		Choices: []OpenAITextResponseChoice{choice}, | 		Choices: []openai.TextResponseChoice{choice}, | ||||||
| 	} | 	} | ||||||
| 	return &fullTextResponse | 	return &fullTextResponse | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func claudeStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, string) { | func StreamHandler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, string) { | ||||||
| 	responseText := "" | 	responseText := "" | ||||||
| 	responseId := fmt.Sprintf("chatcmpl-%s", common.GetUUID()) | 	responseId := fmt.Sprintf("chatcmpl-%s", common.GetUUID()) | ||||||
| 	createdTime := common.GetTimestamp() | 	createdTime := common.GetTimestamp() | ||||||
| @@ -143,13 +116,13 @@ func claudeStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithS | |||||||
| 		} | 		} | ||||||
| 		stopChan <- true | 		stopChan <- true | ||||||
| 	}() | 	}() | ||||||
| 	setEventStreamHeaders(c) | 	common.SetEventStreamHeaders(c) | ||||||
| 	c.Stream(func(w io.Writer) bool { | 	c.Stream(func(w io.Writer) bool { | ||||||
| 		select { | 		select { | ||||||
| 		case data := <-dataChan: | 		case data := <-dataChan: | ||||||
| 			// some implementations may add \r at the end of data | 			// some implementations may add \r at the end of data | ||||||
| 			data = strings.TrimSuffix(data, "\r") | 			data = strings.TrimSuffix(data, "\r") | ||||||
| 			var claudeResponse ClaudeResponse | 			var claudeResponse Response | ||||||
| 			err := json.Unmarshal([]byte(data), &claudeResponse) | 			err := json.Unmarshal([]byte(data), &claudeResponse) | ||||||
| 			if err != nil { | 			if err != nil { | ||||||
| 				common.SysError("error unmarshalling stream response: " + err.Error()) | 				common.SysError("error unmarshalling stream response: " + err.Error()) | ||||||
| @@ -173,28 +146,28 @@ func claudeStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithS | |||||||
| 	}) | 	}) | ||||||
| 	err := resp.Body.Close() | 	err := resp.Body.Close() | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), "" | 		return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), "" | ||||||
| 	} | 	} | ||||||
| 	return nil, responseText | 	return nil, responseText | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func claudeHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*OpenAIErrorWithStatusCode, *Usage) { | func Handler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*openai.ErrorWithStatusCode, *openai.Usage) { | ||||||
| 	responseBody, err := io.ReadAll(resp.Body) | 	responseBody, err := io.ReadAll(resp.Body) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil | 		return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil | ||||||
| 	} | 	} | ||||||
| 	err = resp.Body.Close() | 	err = resp.Body.Close() | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil | 		return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil | ||||||
| 	} | 	} | ||||||
| 	var claudeResponse ClaudeResponse | 	var claudeResponse Response | ||||||
| 	err = json.Unmarshal(responseBody, &claudeResponse) | 	err = json.Unmarshal(responseBody, &claudeResponse) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil | 		return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil | ||||||
| 	} | 	} | ||||||
| 	if claudeResponse.Error.Type != "" { | 	if claudeResponse.Error.Type != "" { | ||||||
| 		return &OpenAIErrorWithStatusCode{ | 		return &openai.ErrorWithStatusCode{ | ||||||
| 			OpenAIError: OpenAIError{ | 			Error: openai.Error{ | ||||||
| 				Message: claudeResponse.Error.Message, | 				Message: claudeResponse.Error.Message, | ||||||
| 				Type:    claudeResponse.Error.Type, | 				Type:    claudeResponse.Error.Type, | ||||||
| 				Param:   "", | 				Param:   "", | ||||||
| @@ -205,8 +178,8 @@ func claudeHandler(c *gin.Context, resp *http.Response, promptTokens int, model | |||||||
| 	} | 	} | ||||||
| 	fullTextResponse := responseClaude2OpenAI(&claudeResponse) | 	fullTextResponse := responseClaude2OpenAI(&claudeResponse) | ||||||
| 	fullTextResponse.Model = model | 	fullTextResponse.Model = model | ||||||
| 	completionTokens := countTokenText(claudeResponse.Completion, model) | 	completionTokens := openai.CountTokenText(claudeResponse.Completion, model) | ||||||
| 	usage := Usage{ | 	usage := openai.Usage{ | ||||||
| 		PromptTokens:     promptTokens, | 		PromptTokens:     promptTokens, | ||||||
| 		CompletionTokens: completionTokens, | 		CompletionTokens: completionTokens, | ||||||
| 		TotalTokens:      promptTokens + completionTokens, | 		TotalTokens:      promptTokens + completionTokens, | ||||||
| @@ -214,7 +187,7 @@ func claudeHandler(c *gin.Context, resp *http.Response, promptTokens int, model | |||||||
| 	fullTextResponse.Usage = usage | 	fullTextResponse.Usage = usage | ||||||
| 	jsonResponse, err := json.Marshal(fullTextResponse) | 	jsonResponse, err := json.Marshal(fullTextResponse) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil | 		return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil | ||||||
| 	} | 	} | ||||||
| 	c.Writer.Header().Set("Content-Type", "application/json") | 	c.Writer.Header().Set("Content-Type", "application/json") | ||||||
| 	c.Writer.WriteHeader(resp.StatusCode) | 	c.Writer.WriteHeader(resp.StatusCode) | ||||||
							
								
								
									
										29
									
								
								relay/channel/anthropic/model.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										29
									
								
								relay/channel/anthropic/model.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,29 @@ | |||||||
|  | package anthropic | ||||||
|  |  | ||||||
|  | type Metadata struct { | ||||||
|  | 	UserId string `json:"user_id"` | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type Request struct { | ||||||
|  | 	Model             string   `json:"model"` | ||||||
|  | 	Prompt            string   `json:"prompt"` | ||||||
|  | 	MaxTokensToSample int      `json:"max_tokens_to_sample"` | ||||||
|  | 	StopSequences     []string `json:"stop_sequences,omitempty"` | ||||||
|  | 	Temperature       float64  `json:"temperature,omitempty"` | ||||||
|  | 	TopP              float64  `json:"top_p,omitempty"` | ||||||
|  | 	TopK              int      `json:"top_k,omitempty"` | ||||||
|  | 	//Metadata    `json:"metadata,omitempty"` | ||||||
|  | 	Stream bool `json:"stream,omitempty"` | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type Error struct { | ||||||
|  | 	Type    string `json:"type"` | ||||||
|  | 	Message string `json:"message"` | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type Response struct { | ||||||
|  | 	Completion string `json:"completion"` | ||||||
|  | 	StopReason string `json:"stop_reason"` | ||||||
|  | 	Model      string `json:"model"` | ||||||
|  | 	Error      Error  `json:"error"` | ||||||
|  | } | ||||||
| @@ -1,4 +1,4 @@ | |||||||
| package controller | package baidu | ||||||
| 
 | 
 | ||||||
| import ( | import ( | ||||||
| 	"bufio" | 	"bufio" | ||||||
| @@ -9,6 +9,9 @@ import ( | |||||||
| 	"io" | 	"io" | ||||||
| 	"net/http" | 	"net/http" | ||||||
| 	"one-api/common" | 	"one-api/common" | ||||||
|  | 	"one-api/relay/channel/openai" | ||||||
|  | 	"one-api/relay/constant" | ||||||
|  | 	"one-api/relay/util" | ||||||
| 	"strings" | 	"strings" | ||||||
| 	"sync" | 	"sync" | ||||||
| 	"time" | 	"time" | ||||||
| @@ -37,53 +40,9 @@ type BaiduError struct { | |||||||
| 	ErrorMsg  string `json:"error_msg"` | 	ErrorMsg  string `json:"error_msg"` | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| type BaiduChatResponse struct { |  | ||||||
| 	Id               string `json:"id"` |  | ||||||
| 	Object           string `json:"object"` |  | ||||||
| 	Created          int64  `json:"created"` |  | ||||||
| 	Result           string `json:"result"` |  | ||||||
| 	IsTruncated      bool   `json:"is_truncated"` |  | ||||||
| 	NeedClearHistory bool   `json:"need_clear_history"` |  | ||||||
| 	Usage            Usage  `json:"usage"` |  | ||||||
| 	BaiduError |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| type BaiduChatStreamResponse struct { |  | ||||||
| 	BaiduChatResponse |  | ||||||
| 	SentenceId int  `json:"sentence_id"` |  | ||||||
| 	IsEnd      bool `json:"is_end"` |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| type BaiduEmbeddingRequest struct { |  | ||||||
| 	Input []string `json:"input"` |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| type BaiduEmbeddingData struct { |  | ||||||
| 	Object    string    `json:"object"` |  | ||||||
| 	Embedding []float64 `json:"embedding"` |  | ||||||
| 	Index     int       `json:"index"` |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| type BaiduEmbeddingResponse struct { |  | ||||||
| 	Id      string               `json:"id"` |  | ||||||
| 	Object  string               `json:"object"` |  | ||||||
| 	Created int64                `json:"created"` |  | ||||||
| 	Data    []BaiduEmbeddingData `json:"data"` |  | ||||||
| 	Usage   Usage                `json:"usage"` |  | ||||||
| 	BaiduError |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| type BaiduAccessToken struct { |  | ||||||
| 	AccessToken      string    `json:"access_token"` |  | ||||||
| 	Error            string    `json:"error,omitempty"` |  | ||||||
| 	ErrorDescription string    `json:"error_description,omitempty"` |  | ||||||
| 	ExpiresIn        int64     `json:"expires_in,omitempty"` |  | ||||||
| 	ExpiresAt        time.Time `json:"-"` |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| var baiduTokenStore sync.Map | var baiduTokenStore sync.Map | ||||||
| 
 | 
 | ||||||
| func requestOpenAI2Baidu(request GeneralOpenAIRequest) *BaiduChatRequest { | func ConvertRequest(request openai.GeneralOpenAIRequest) *BaiduChatRequest { | ||||||
| 	messages := make([]BaiduMessage, 0, len(request.Messages)) | 	messages := make([]BaiduMessage, 0, len(request.Messages)) | ||||||
| 	for _, message := range request.Messages { | 	for _, message := range request.Messages { | ||||||
| 		if message.Role == "system" { | 		if message.Role == "system" { | ||||||
| @@ -108,56 +67,56 @@ func requestOpenAI2Baidu(request GeneralOpenAIRequest) *BaiduChatRequest { | |||||||
| 	} | 	} | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func responseBaidu2OpenAI(response *BaiduChatResponse) *OpenAITextResponse { | func responseBaidu2OpenAI(response *ChatResponse) *openai.TextResponse { | ||||||
| 	choice := OpenAITextResponseChoice{ | 	choice := openai.TextResponseChoice{ | ||||||
| 		Index: 0, | 		Index: 0, | ||||||
| 		Message: Message{ | 		Message: openai.Message{ | ||||||
| 			Role:    "assistant", | 			Role:    "assistant", | ||||||
| 			Content: response.Result, | 			Content: response.Result, | ||||||
| 		}, | 		}, | ||||||
| 		FinishReason: "stop", | 		FinishReason: "stop", | ||||||
| 	} | 	} | ||||||
| 	fullTextResponse := OpenAITextResponse{ | 	fullTextResponse := openai.TextResponse{ | ||||||
| 		Id:      response.Id, | 		Id:      response.Id, | ||||||
| 		Object:  "chat.completion", | 		Object:  "chat.completion", | ||||||
| 		Created: response.Created, | 		Created: response.Created, | ||||||
| 		Choices: []OpenAITextResponseChoice{choice}, | 		Choices: []openai.TextResponseChoice{choice}, | ||||||
| 		Usage:   response.Usage, | 		Usage:   response.Usage, | ||||||
| 	} | 	} | ||||||
| 	return &fullTextResponse | 	return &fullTextResponse | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func streamResponseBaidu2OpenAI(baiduResponse *BaiduChatStreamResponse) *ChatCompletionsStreamResponse { | func streamResponseBaidu2OpenAI(baiduResponse *ChatStreamResponse) *openai.ChatCompletionsStreamResponse { | ||||||
| 	var choice ChatCompletionsStreamResponseChoice | 	var choice openai.ChatCompletionsStreamResponseChoice | ||||||
| 	choice.Delta.Content = baiduResponse.Result | 	choice.Delta.Content = baiduResponse.Result | ||||||
| 	if baiduResponse.IsEnd { | 	if baiduResponse.IsEnd { | ||||||
| 		choice.FinishReason = &stopFinishReason | 		choice.FinishReason = &constant.StopFinishReason | ||||||
| 	} | 	} | ||||||
| 	response := ChatCompletionsStreamResponse{ | 	response := openai.ChatCompletionsStreamResponse{ | ||||||
| 		Id:      baiduResponse.Id, | 		Id:      baiduResponse.Id, | ||||||
| 		Object:  "chat.completion.chunk", | 		Object:  "chat.completion.chunk", | ||||||
| 		Created: baiduResponse.Created, | 		Created: baiduResponse.Created, | ||||||
| 		Model:   "ernie-bot", | 		Model:   "ernie-bot", | ||||||
| 		Choices: []ChatCompletionsStreamResponseChoice{choice}, | 		Choices: []openai.ChatCompletionsStreamResponseChoice{choice}, | ||||||
| 	} | 	} | ||||||
| 	return &response | 	return &response | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func embeddingRequestOpenAI2Baidu(request GeneralOpenAIRequest) *BaiduEmbeddingRequest { | func ConvertEmbeddingRequest(request openai.GeneralOpenAIRequest) *EmbeddingRequest { | ||||||
| 	return &BaiduEmbeddingRequest{ | 	return &EmbeddingRequest{ | ||||||
| 		Input: request.ParseInput(), | 		Input: request.ParseInput(), | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func embeddingResponseBaidu2OpenAI(response *BaiduEmbeddingResponse) *OpenAIEmbeddingResponse { | func embeddingResponseBaidu2OpenAI(response *EmbeddingResponse) *openai.EmbeddingResponse { | ||||||
| 	openAIEmbeddingResponse := OpenAIEmbeddingResponse{ | 	openAIEmbeddingResponse := openai.EmbeddingResponse{ | ||||||
| 		Object: "list", | 		Object: "list", | ||||||
| 		Data:   make([]OpenAIEmbeddingResponseItem, 0, len(response.Data)), | 		Data:   make([]openai.EmbeddingResponseItem, 0, len(response.Data)), | ||||||
| 		Model:  "baidu-embedding", | 		Model:  "baidu-embedding", | ||||||
| 		Usage:  response.Usage, | 		Usage:  response.Usage, | ||||||
| 	} | 	} | ||||||
| 	for _, item := range response.Data { | 	for _, item := range response.Data { | ||||||
| 		openAIEmbeddingResponse.Data = append(openAIEmbeddingResponse.Data, OpenAIEmbeddingResponseItem{ | 		openAIEmbeddingResponse.Data = append(openAIEmbeddingResponse.Data, openai.EmbeddingResponseItem{ | ||||||
| 			Object:    item.Object, | 			Object:    item.Object, | ||||||
| 			Index:     item.Index, | 			Index:     item.Index, | ||||||
| 			Embedding: item.Embedding, | 			Embedding: item.Embedding, | ||||||
| @@ -166,8 +125,8 @@ func embeddingResponseBaidu2OpenAI(response *BaiduEmbeddingResponse) *OpenAIEmbe | |||||||
| 	return &openAIEmbeddingResponse | 	return &openAIEmbeddingResponse | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func baiduStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) { | func StreamHandler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, *openai.Usage) { | ||||||
| 	var usage Usage | 	var usage openai.Usage | ||||||
| 	scanner := bufio.NewScanner(resp.Body) | 	scanner := bufio.NewScanner(resp.Body) | ||||||
| 	scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { | 	scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { | ||||||
| 		if atEOF && len(data) == 0 { | 		if atEOF && len(data) == 0 { | ||||||
| @@ -194,11 +153,11 @@ func baiduStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithSt | |||||||
| 		} | 		} | ||||||
| 		stopChan <- true | 		stopChan <- true | ||||||
| 	}() | 	}() | ||||||
| 	setEventStreamHeaders(c) | 	common.SetEventStreamHeaders(c) | ||||||
| 	c.Stream(func(w io.Writer) bool { | 	c.Stream(func(w io.Writer) bool { | ||||||
| 		select { | 		select { | ||||||
| 		case data := <-dataChan: | 		case data := <-dataChan: | ||||||
| 			var baiduResponse BaiduChatStreamResponse | 			var baiduResponse ChatStreamResponse | ||||||
| 			err := json.Unmarshal([]byte(data), &baiduResponse) | 			err := json.Unmarshal([]byte(data), &baiduResponse) | ||||||
| 			if err != nil { | 			if err != nil { | ||||||
| 				common.SysError("error unmarshalling stream response: " + err.Error()) | 				common.SysError("error unmarshalling stream response: " + err.Error()) | ||||||
| @@ -224,28 +183,28 @@ func baiduStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithSt | |||||||
| 	}) | 	}) | ||||||
| 	err := resp.Body.Close() | 	err := resp.Body.Close() | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil | 		return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil | ||||||
| 	} | 	} | ||||||
| 	return nil, &usage | 	return nil, &usage | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func baiduHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) { | func Handler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, *openai.Usage) { | ||||||
| 	var baiduResponse BaiduChatResponse | 	var baiduResponse ChatResponse | ||||||
| 	responseBody, err := io.ReadAll(resp.Body) | 	responseBody, err := io.ReadAll(resp.Body) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil | 		return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil | ||||||
| 	} | 	} | ||||||
| 	err = resp.Body.Close() | 	err = resp.Body.Close() | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil | 		return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil | ||||||
| 	} | 	} | ||||||
| 	err = json.Unmarshal(responseBody, &baiduResponse) | 	err = json.Unmarshal(responseBody, &baiduResponse) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil | 		return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil | ||||||
| 	} | 	} | ||||||
| 	if baiduResponse.ErrorMsg != "" { | 	if baiduResponse.ErrorMsg != "" { | ||||||
| 		return &OpenAIErrorWithStatusCode{ | 		return &openai.ErrorWithStatusCode{ | ||||||
| 			OpenAIError: OpenAIError{ | 			Error: openai.Error{ | ||||||
| 				Message: baiduResponse.ErrorMsg, | 				Message: baiduResponse.ErrorMsg, | ||||||
| 				Type:    "baidu_error", | 				Type:    "baidu_error", | ||||||
| 				Param:   "", | 				Param:   "", | ||||||
| @@ -258,7 +217,7 @@ func baiduHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCo | |||||||
| 	fullTextResponse.Model = "ernie-bot" | 	fullTextResponse.Model = "ernie-bot" | ||||||
| 	jsonResponse, err := json.Marshal(fullTextResponse) | 	jsonResponse, err := json.Marshal(fullTextResponse) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil | 		return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil | ||||||
| 	} | 	} | ||||||
| 	c.Writer.Header().Set("Content-Type", "application/json") | 	c.Writer.Header().Set("Content-Type", "application/json") | ||||||
| 	c.Writer.WriteHeader(resp.StatusCode) | 	c.Writer.WriteHeader(resp.StatusCode) | ||||||
| @@ -266,23 +225,23 @@ func baiduHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCo | |||||||
| 	return nil, &fullTextResponse.Usage | 	return nil, &fullTextResponse.Usage | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func baiduEmbeddingHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) { | func EmbeddingHandler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, *openai.Usage) { | ||||||
| 	var baiduResponse BaiduEmbeddingResponse | 	var baiduResponse EmbeddingResponse | ||||||
| 	responseBody, err := io.ReadAll(resp.Body) | 	responseBody, err := io.ReadAll(resp.Body) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil | 		return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil | ||||||
| 	} | 	} | ||||||
| 	err = resp.Body.Close() | 	err = resp.Body.Close() | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil | 		return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil | ||||||
| 	} | 	} | ||||||
| 	err = json.Unmarshal(responseBody, &baiduResponse) | 	err = json.Unmarshal(responseBody, &baiduResponse) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil | 		return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil | ||||||
| 	} | 	} | ||||||
| 	if baiduResponse.ErrorMsg != "" { | 	if baiduResponse.ErrorMsg != "" { | ||||||
| 		return &OpenAIErrorWithStatusCode{ | 		return &openai.ErrorWithStatusCode{ | ||||||
| 			OpenAIError: OpenAIError{ | 			Error: openai.Error{ | ||||||
| 				Message: baiduResponse.ErrorMsg, | 				Message: baiduResponse.ErrorMsg, | ||||||
| 				Type:    "baidu_error", | 				Type:    "baidu_error", | ||||||
| 				Param:   "", | 				Param:   "", | ||||||
| @@ -294,7 +253,7 @@ func baiduEmbeddingHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWit | |||||||
| 	fullTextResponse := embeddingResponseBaidu2OpenAI(&baiduResponse) | 	fullTextResponse := embeddingResponseBaidu2OpenAI(&baiduResponse) | ||||||
| 	jsonResponse, err := json.Marshal(fullTextResponse) | 	jsonResponse, err := json.Marshal(fullTextResponse) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil | 		return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil | ||||||
| 	} | 	} | ||||||
| 	c.Writer.Header().Set("Content-Type", "application/json") | 	c.Writer.Header().Set("Content-Type", "application/json") | ||||||
| 	c.Writer.WriteHeader(resp.StatusCode) | 	c.Writer.WriteHeader(resp.StatusCode) | ||||||
| @@ -302,10 +261,10 @@ func baiduEmbeddingHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWit | |||||||
| 	return nil, &fullTextResponse.Usage | 	return nil, &fullTextResponse.Usage | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func getBaiduAccessToken(apiKey string) (string, error) { | func GetAccessToken(apiKey string) (string, error) { | ||||||
| 	if val, ok := baiduTokenStore.Load(apiKey); ok { | 	if val, ok := baiduTokenStore.Load(apiKey); ok { | ||||||
| 		var accessToken BaiduAccessToken | 		var accessToken AccessToken | ||||||
| 		if accessToken, ok = val.(BaiduAccessToken); ok { | 		if accessToken, ok = val.(AccessToken); ok { | ||||||
| 			// soon this will expire | 			// soon this will expire | ||||||
| 			if time.Now().Add(time.Hour).After(accessToken.ExpiresAt) { | 			if time.Now().Add(time.Hour).After(accessToken.ExpiresAt) { | ||||||
| 				go func() { | 				go func() { | ||||||
| @@ -320,12 +279,12 @@ func getBaiduAccessToken(apiKey string) (string, error) { | |||||||
| 		return "", err | 		return "", err | ||||||
| 	} | 	} | ||||||
| 	if accessToken == nil { | 	if accessToken == nil { | ||||||
| 		return "", errors.New("getBaiduAccessToken return a nil token") | 		return "", errors.New("GetAccessToken return a nil token") | ||||||
| 	} | 	} | ||||||
| 	return (*accessToken).AccessToken, nil | 	return (*accessToken).AccessToken, nil | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func getBaiduAccessTokenHelper(apiKey string) (*BaiduAccessToken, error) { | func getBaiduAccessTokenHelper(apiKey string) (*AccessToken, error) { | ||||||
| 	parts := strings.Split(apiKey, "|") | 	parts := strings.Split(apiKey, "|") | ||||||
| 	if len(parts) != 2 { | 	if len(parts) != 2 { | ||||||
| 		return nil, errors.New("invalid baidu apikey") | 		return nil, errors.New("invalid baidu apikey") | ||||||
| @@ -337,13 +296,13 @@ func getBaiduAccessTokenHelper(apiKey string) (*BaiduAccessToken, error) { | |||||||
| 	} | 	} | ||||||
| 	req.Header.Add("Content-Type", "application/json") | 	req.Header.Add("Content-Type", "application/json") | ||||||
| 	req.Header.Add("Accept", "application/json") | 	req.Header.Add("Accept", "application/json") | ||||||
| 	res, err := impatientHTTPClient.Do(req) | 	res, err := util.ImpatientHTTPClient.Do(req) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, err | 		return nil, err | ||||||
| 	} | 	} | ||||||
| 	defer res.Body.Close() | 	defer res.Body.Close() | ||||||
| 
 | 
 | ||||||
| 	var accessToken BaiduAccessToken | 	var accessToken AccessToken | ||||||
| 	err = json.NewDecoder(res.Body).Decode(&accessToken) | 	err = json.NewDecoder(res.Body).Decode(&accessToken) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, err | 		return nil, err | ||||||
							
								
								
									
										50
									
								
								relay/channel/baidu/model.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										50
									
								
								relay/channel/baidu/model.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,50 @@ | |||||||
|  | package baidu | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"one-api/relay/channel/openai" | ||||||
|  | 	"time" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | type ChatResponse struct { | ||||||
|  | 	Id               string       `json:"id"` | ||||||
|  | 	Object           string       `json:"object"` | ||||||
|  | 	Created          int64        `json:"created"` | ||||||
|  | 	Result           string       `json:"result"` | ||||||
|  | 	IsTruncated      bool         `json:"is_truncated"` | ||||||
|  | 	NeedClearHistory bool         `json:"need_clear_history"` | ||||||
|  | 	Usage            openai.Usage `json:"usage"` | ||||||
|  | 	BaiduError | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type ChatStreamResponse struct { | ||||||
|  | 	ChatResponse | ||||||
|  | 	SentenceId int  `json:"sentence_id"` | ||||||
|  | 	IsEnd      bool `json:"is_end"` | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type EmbeddingRequest struct { | ||||||
|  | 	Input []string `json:"input"` | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type EmbeddingData struct { | ||||||
|  | 	Object    string    `json:"object"` | ||||||
|  | 	Embedding []float64 `json:"embedding"` | ||||||
|  | 	Index     int       `json:"index"` | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type EmbeddingResponse struct { | ||||||
|  | 	Id      string          `json:"id"` | ||||||
|  | 	Object  string          `json:"object"` | ||||||
|  | 	Created int64           `json:"created"` | ||||||
|  | 	Data    []EmbeddingData `json:"data"` | ||||||
|  | 	Usage   openai.Usage    `json:"usage"` | ||||||
|  | 	BaiduError | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type AccessToken struct { | ||||||
|  | 	AccessToken      string    `json:"access_token"` | ||||||
|  | 	Error            string    `json:"error,omitempty"` | ||||||
|  | 	ErrorDescription string    `json:"error_description,omitempty"` | ||||||
|  | 	ExpiresIn        int64     `json:"expires_in,omitempty"` | ||||||
|  | 	ExpiresAt        time.Time `json:"-"` | ||||||
|  | } | ||||||
| @@ -1,4 +1,4 @@ | |||||||
| package controller | package google | ||||||
| 
 | 
 | ||||||
| import ( | import ( | ||||||
| 	"bufio" | 	"bufio" | ||||||
| @@ -8,6 +8,8 @@ import ( | |||||||
| 	"net/http" | 	"net/http" | ||||||
| 	"one-api/common" | 	"one-api/common" | ||||||
| 	"one-api/common/image" | 	"one-api/common/image" | ||||||
|  | 	"one-api/relay/channel/openai" | ||||||
|  | 	"one-api/relay/constant" | ||||||
| 	"strings" | 	"strings" | ||||||
| 
 | 
 | ||||||
| 	"github.com/gin-gonic/gin" | 	"github.com/gin-gonic/gin" | ||||||
| @@ -19,48 +21,8 @@ const ( | |||||||
| 	GeminiVisionMaxImageNum = 16 | 	GeminiVisionMaxImageNum = 16 | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| type GeminiChatRequest struct { |  | ||||||
| 	Contents         []GeminiChatContent        `json:"contents"` |  | ||||||
| 	SafetySettings   []GeminiChatSafetySettings `json:"safety_settings,omitempty"` |  | ||||||
| 	GenerationConfig GeminiChatGenerationConfig `json:"generation_config,omitempty"` |  | ||||||
| 	Tools            []GeminiChatTools          `json:"tools,omitempty"` |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| type GeminiInlineData struct { |  | ||||||
| 	MimeType string `json:"mimeType"` |  | ||||||
| 	Data     string `json:"data"` |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| type GeminiPart struct { |  | ||||||
| 	Text       string            `json:"text,omitempty"` |  | ||||||
| 	InlineData *GeminiInlineData `json:"inlineData,omitempty"` |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| type GeminiChatContent struct { |  | ||||||
| 	Role  string       `json:"role,omitempty"` |  | ||||||
| 	Parts []GeminiPart `json:"parts"` |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| type GeminiChatSafetySettings struct { |  | ||||||
| 	Category  string `json:"category"` |  | ||||||
| 	Threshold string `json:"threshold"` |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| type GeminiChatTools struct { |  | ||||||
| 	FunctionDeclarations any `json:"functionDeclarations,omitempty"` |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| type GeminiChatGenerationConfig struct { |  | ||||||
| 	Temperature     float64  `json:"temperature,omitempty"` |  | ||||||
| 	TopP            float64  `json:"topP,omitempty"` |  | ||||||
| 	TopK            float64  `json:"topK,omitempty"` |  | ||||||
| 	MaxOutputTokens int      `json:"maxOutputTokens,omitempty"` |  | ||||||
| 	CandidateCount  int      `json:"candidateCount,omitempty"` |  | ||||||
| 	StopSequences   []string `json:"stopSequences,omitempty"` |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| // Setting safety to the lowest possible values since Gemini is already powerless enough | // Setting safety to the lowest possible values since Gemini is already powerless enough | ||||||
| func requestOpenAI2Gemini(textRequest GeneralOpenAIRequest) *GeminiChatRequest { | func ConvertGeminiRequest(textRequest openai.GeneralOpenAIRequest) *GeminiChatRequest { | ||||||
| 	geminiRequest := GeminiChatRequest{ | 	geminiRequest := GeminiChatRequest{ | ||||||
| 		Contents: make([]GeminiChatContent, 0, len(textRequest.Messages)), | 		Contents: make([]GeminiChatContent, 0, len(textRequest.Messages)), | ||||||
| 		SafetySettings: []GeminiChatSafetySettings{ | 		SafetySettings: []GeminiChatSafetySettings{ | ||||||
| @@ -108,11 +70,11 @@ func requestOpenAI2Gemini(textRequest GeneralOpenAIRequest) *GeminiChatRequest { | |||||||
| 		var parts []GeminiPart | 		var parts []GeminiPart | ||||||
| 		imageNum := 0 | 		imageNum := 0 | ||||||
| 		for _, part := range openaiContent { | 		for _, part := range openaiContent { | ||||||
| 			if part.Type == ContentTypeText { | 			if part.Type == openai.ContentTypeText { | ||||||
| 				parts = append(parts, GeminiPart{ | 				parts = append(parts, GeminiPart{ | ||||||
| 					Text: part.Text, | 					Text: part.Text, | ||||||
| 				}) | 				}) | ||||||
| 			} else if part.Type == ContentTypeImageURL { | 			} else if part.Type == openai.ContentTypeImageURL { | ||||||
| 				imageNum += 1 | 				imageNum += 1 | ||||||
| 				if imageNum > GeminiVisionMaxImageNum { | 				if imageNum > GeminiVisionMaxImageNum { | ||||||
| 					continue | 					continue | ||||||
| @@ -187,21 +149,21 @@ type GeminiChatPromptFeedback struct { | |||||||
| 	SafetyRatings []GeminiChatSafetyRating `json:"safetyRatings"` | 	SafetyRatings []GeminiChatSafetyRating `json:"safetyRatings"` | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func responseGeminiChat2OpenAI(response *GeminiChatResponse) *OpenAITextResponse { | func responseGeminiChat2OpenAI(response *GeminiChatResponse) *openai.TextResponse { | ||||||
| 	fullTextResponse := OpenAITextResponse{ | 	fullTextResponse := openai.TextResponse{ | ||||||
| 		Id:      fmt.Sprintf("chatcmpl-%s", common.GetUUID()), | 		Id:      fmt.Sprintf("chatcmpl-%s", common.GetUUID()), | ||||||
| 		Object:  "chat.completion", | 		Object:  "chat.completion", | ||||||
| 		Created: common.GetTimestamp(), | 		Created: common.GetTimestamp(), | ||||||
| 		Choices: make([]OpenAITextResponseChoice, 0, len(response.Candidates)), | 		Choices: make([]openai.TextResponseChoice, 0, len(response.Candidates)), | ||||||
| 	} | 	} | ||||||
| 	for i, candidate := range response.Candidates { | 	for i, candidate := range response.Candidates { | ||||||
| 		choice := OpenAITextResponseChoice{ | 		choice := openai.TextResponseChoice{ | ||||||
| 			Index: i, | 			Index: i, | ||||||
| 			Message: Message{ | 			Message: openai.Message{ | ||||||
| 				Role:    "assistant", | 				Role:    "assistant", | ||||||
| 				Content: "", | 				Content: "", | ||||||
| 			}, | 			}, | ||||||
| 			FinishReason: stopFinishReason, | 			FinishReason: constant.StopFinishReason, | ||||||
| 		} | 		} | ||||||
| 		if len(candidate.Content.Parts) > 0 { | 		if len(candidate.Content.Parts) > 0 { | ||||||
| 			choice.Message.Content = candidate.Content.Parts[0].Text | 			choice.Message.Content = candidate.Content.Parts[0].Text | ||||||
| @@ -211,18 +173,18 @@ func responseGeminiChat2OpenAI(response *GeminiChatResponse) *OpenAITextResponse | |||||||
| 	return &fullTextResponse | 	return &fullTextResponse | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) *ChatCompletionsStreamResponse { | func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) *openai.ChatCompletionsStreamResponse { | ||||||
| 	var choice ChatCompletionsStreamResponseChoice | 	var choice openai.ChatCompletionsStreamResponseChoice | ||||||
| 	choice.Delta.Content = geminiResponse.GetResponseText() | 	choice.Delta.Content = geminiResponse.GetResponseText() | ||||||
| 	choice.FinishReason = &stopFinishReason | 	choice.FinishReason = &constant.StopFinishReason | ||||||
| 	var response ChatCompletionsStreamResponse | 	var response openai.ChatCompletionsStreamResponse | ||||||
| 	response.Object = "chat.completion.chunk" | 	response.Object = "chat.completion.chunk" | ||||||
| 	response.Model = "gemini" | 	response.Model = "gemini" | ||||||
| 	response.Choices = []ChatCompletionsStreamResponseChoice{choice} | 	response.Choices = []openai.ChatCompletionsStreamResponseChoice{choice} | ||||||
| 	return &response | 	return &response | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func geminiChatStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, string) { | func StreamHandler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, string) { | ||||||
| 	responseText := "" | 	responseText := "" | ||||||
| 	dataChan := make(chan string) | 	dataChan := make(chan string) | ||||||
| 	stopChan := make(chan bool) | 	stopChan := make(chan bool) | ||||||
| @@ -252,7 +214,7 @@ func geminiChatStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorW | |||||||
| 		} | 		} | ||||||
| 		stopChan <- true | 		stopChan <- true | ||||||
| 	}() | 	}() | ||||||
| 	setEventStreamHeaders(c) | 	common.SetEventStreamHeaders(c) | ||||||
| 	c.Stream(func(w io.Writer) bool { | 	c.Stream(func(w io.Writer) bool { | ||||||
| 		select { | 		select { | ||||||
| 		case data := <-dataChan: | 		case data := <-dataChan: | ||||||
| @@ -264,14 +226,14 @@ func geminiChatStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorW | |||||||
| 			var dummy dummyStruct | 			var dummy dummyStruct | ||||||
| 			err := json.Unmarshal([]byte(data), &dummy) | 			err := json.Unmarshal([]byte(data), &dummy) | ||||||
| 			responseText += dummy.Content | 			responseText += dummy.Content | ||||||
| 			var choice ChatCompletionsStreamResponseChoice | 			var choice openai.ChatCompletionsStreamResponseChoice | ||||||
| 			choice.Delta.Content = dummy.Content | 			choice.Delta.Content = dummy.Content | ||||||
| 			response := ChatCompletionsStreamResponse{ | 			response := openai.ChatCompletionsStreamResponse{ | ||||||
| 				Id:      fmt.Sprintf("chatcmpl-%s", common.GetUUID()), | 				Id:      fmt.Sprintf("chatcmpl-%s", common.GetUUID()), | ||||||
| 				Object:  "chat.completion.chunk", | 				Object:  "chat.completion.chunk", | ||||||
| 				Created: common.GetTimestamp(), | 				Created: common.GetTimestamp(), | ||||||
| 				Model:   "gemini-pro", | 				Model:   "gemini-pro", | ||||||
| 				Choices: []ChatCompletionsStreamResponseChoice{choice}, | 				Choices: []openai.ChatCompletionsStreamResponseChoice{choice}, | ||||||
| 			} | 			} | ||||||
| 			jsonResponse, err := json.Marshal(response) | 			jsonResponse, err := json.Marshal(response) | ||||||
| 			if err != nil { | 			if err != nil { | ||||||
| @@ -287,28 +249,28 @@ func geminiChatStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorW | |||||||
| 	}) | 	}) | ||||||
| 	err := resp.Body.Close() | 	err := resp.Body.Close() | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), "" | 		return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), "" | ||||||
| 	} | 	} | ||||||
| 	return nil, responseText | 	return nil, responseText | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func geminiChatHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*OpenAIErrorWithStatusCode, *Usage) { | func GeminiHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*openai.ErrorWithStatusCode, *openai.Usage) { | ||||||
| 	responseBody, err := io.ReadAll(resp.Body) | 	responseBody, err := io.ReadAll(resp.Body) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil | 		return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil | ||||||
| 	} | 	} | ||||||
| 	err = resp.Body.Close() | 	err = resp.Body.Close() | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil | 		return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil | ||||||
| 	} | 	} | ||||||
| 	var geminiResponse GeminiChatResponse | 	var geminiResponse GeminiChatResponse | ||||||
| 	err = json.Unmarshal(responseBody, &geminiResponse) | 	err = json.Unmarshal(responseBody, &geminiResponse) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil | 		return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil | ||||||
| 	} | 	} | ||||||
| 	if len(geminiResponse.Candidates) == 0 { | 	if len(geminiResponse.Candidates) == 0 { | ||||||
| 		return &OpenAIErrorWithStatusCode{ | 		return &openai.ErrorWithStatusCode{ | ||||||
| 			OpenAIError: OpenAIError{ | 			Error: openai.Error{ | ||||||
| 				Message: "No candidates returned", | 				Message: "No candidates returned", | ||||||
| 				Type:    "server_error", | 				Type:    "server_error", | ||||||
| 				Param:   "", | 				Param:   "", | ||||||
| @@ -319,8 +281,8 @@ func geminiChatHandler(c *gin.Context, resp *http.Response, promptTokens int, mo | |||||||
| 	} | 	} | ||||||
| 	fullTextResponse := responseGeminiChat2OpenAI(&geminiResponse) | 	fullTextResponse := responseGeminiChat2OpenAI(&geminiResponse) | ||||||
| 	fullTextResponse.Model = model | 	fullTextResponse.Model = model | ||||||
| 	completionTokens := countTokenText(geminiResponse.GetResponseText(), model) | 	completionTokens := openai.CountTokenText(geminiResponse.GetResponseText(), model) | ||||||
| 	usage := Usage{ | 	usage := openai.Usage{ | ||||||
| 		PromptTokens:     promptTokens, | 		PromptTokens:     promptTokens, | ||||||
| 		CompletionTokens: completionTokens, | 		CompletionTokens: completionTokens, | ||||||
| 		TotalTokens:      promptTokens + completionTokens, | 		TotalTokens:      promptTokens + completionTokens, | ||||||
| @@ -328,7 +290,7 @@ func geminiChatHandler(c *gin.Context, resp *http.Response, promptTokens int, mo | |||||||
| 	fullTextResponse.Usage = usage | 	fullTextResponse.Usage = usage | ||||||
| 	jsonResponse, err := json.Marshal(fullTextResponse) | 	jsonResponse, err := json.Marshal(fullTextResponse) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil | 		return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil | ||||||
| 	} | 	} | ||||||
| 	c.Writer.Header().Set("Content-Type", "application/json") | 	c.Writer.Header().Set("Content-Type", "application/json") | ||||||
| 	c.Writer.WriteHeader(resp.StatusCode) | 	c.Writer.WriteHeader(resp.StatusCode) | ||||||
							
								
								
									
										80
									
								
								relay/channel/google/model.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										80
									
								
								relay/channel/google/model.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,80 @@ | |||||||
|  | package google | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"one-api/relay/channel/openai" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | type GeminiChatRequest struct { | ||||||
|  | 	Contents         []GeminiChatContent        `json:"contents"` | ||||||
|  | 	SafetySettings   []GeminiChatSafetySettings `json:"safety_settings,omitempty"` | ||||||
|  | 	GenerationConfig GeminiChatGenerationConfig `json:"generation_config,omitempty"` | ||||||
|  | 	Tools            []GeminiChatTools          `json:"tools,omitempty"` | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type GeminiInlineData struct { | ||||||
|  | 	MimeType string `json:"mimeType"` | ||||||
|  | 	Data     string `json:"data"` | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type GeminiPart struct { | ||||||
|  | 	Text       string            `json:"text,omitempty"` | ||||||
|  | 	InlineData *GeminiInlineData `json:"inlineData,omitempty"` | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type GeminiChatContent struct { | ||||||
|  | 	Role  string       `json:"role,omitempty"` | ||||||
|  | 	Parts []GeminiPart `json:"parts"` | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type GeminiChatSafetySettings struct { | ||||||
|  | 	Category  string `json:"category"` | ||||||
|  | 	Threshold string `json:"threshold"` | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type GeminiChatTools struct { | ||||||
|  | 	FunctionDeclarations any `json:"functionDeclarations,omitempty"` | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type GeminiChatGenerationConfig struct { | ||||||
|  | 	Temperature     float64  `json:"temperature,omitempty"` | ||||||
|  | 	TopP            float64  `json:"topP,omitempty"` | ||||||
|  | 	TopK            float64  `json:"topK,omitempty"` | ||||||
|  | 	MaxOutputTokens int      `json:"maxOutputTokens,omitempty"` | ||||||
|  | 	CandidateCount  int      `json:"candidateCount,omitempty"` | ||||||
|  | 	StopSequences   []string `json:"stopSequences,omitempty"` | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type PaLMChatMessage struct { | ||||||
|  | 	Author  string `json:"author"` | ||||||
|  | 	Content string `json:"content"` | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type PaLMFilter struct { | ||||||
|  | 	Reason  string `json:"reason"` | ||||||
|  | 	Message string `json:"message"` | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type PaLMPrompt struct { | ||||||
|  | 	Messages []PaLMChatMessage `json:"messages"` | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type PaLMChatRequest struct { | ||||||
|  | 	Prompt         PaLMPrompt `json:"prompt"` | ||||||
|  | 	Temperature    float64    `json:"temperature,omitempty"` | ||||||
|  | 	CandidateCount int        `json:"candidateCount,omitempty"` | ||||||
|  | 	TopP           float64    `json:"topP,omitempty"` | ||||||
|  | 	TopK           int        `json:"topK,omitempty"` | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type PaLMError struct { | ||||||
|  | 	Code    int    `json:"code"` | ||||||
|  | 	Message string `json:"message"` | ||||||
|  | 	Status  string `json:"status"` | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type PaLMChatResponse struct { | ||||||
|  | 	Candidates []PaLMChatMessage `json:"candidates"` | ||||||
|  | 	Messages   []openai.Message  `json:"messages"` | ||||||
|  | 	Filters    []PaLMFilter      `json:"filters"` | ||||||
|  | 	Error      PaLMError         `json:"error"` | ||||||
|  | } | ||||||
| @@ -1,4 +1,4 @@ | |||||||
| package controller | package google | ||||||
| 
 | 
 | ||||||
| import ( | import ( | ||||||
| 	"encoding/json" | 	"encoding/json" | ||||||
| @@ -7,47 +7,14 @@ import ( | |||||||
| 	"io" | 	"io" | ||||||
| 	"net/http" | 	"net/http" | ||||||
| 	"one-api/common" | 	"one-api/common" | ||||||
|  | 	"one-api/relay/channel/openai" | ||||||
|  | 	"one-api/relay/constant" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| // https://developers.generativeai.google/api/rest/generativelanguage/models/generateMessage#request-body | // https://developers.generativeai.google/api/rest/generativelanguage/models/generateMessage#request-body | ||||||
| // https://developers.generativeai.google/api/rest/generativelanguage/models/generateMessage#response-body | // https://developers.generativeai.google/api/rest/generativelanguage/models/generateMessage#response-body | ||||||
| 
 | 
 | ||||||
| type PaLMChatMessage struct { | func ConvertPaLMRequest(textRequest openai.GeneralOpenAIRequest) *PaLMChatRequest { | ||||||
| 	Author  string `json:"author"` |  | ||||||
| 	Content string `json:"content"` |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| type PaLMFilter struct { |  | ||||||
| 	Reason  string `json:"reason"` |  | ||||||
| 	Message string `json:"message"` |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| type PaLMPrompt struct { |  | ||||||
| 	Messages []PaLMChatMessage `json:"messages"` |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| type PaLMChatRequest struct { |  | ||||||
| 	Prompt         PaLMPrompt `json:"prompt"` |  | ||||||
| 	Temperature    float64    `json:"temperature,omitempty"` |  | ||||||
| 	CandidateCount int        `json:"candidateCount,omitempty"` |  | ||||||
| 	TopP           float64    `json:"topP,omitempty"` |  | ||||||
| 	TopK           int        `json:"topK,omitempty"` |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| type PaLMError struct { |  | ||||||
| 	Code    int    `json:"code"` |  | ||||||
| 	Message string `json:"message"` |  | ||||||
| 	Status  string `json:"status"` |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| type PaLMChatResponse struct { |  | ||||||
| 	Candidates []PaLMChatMessage `json:"candidates"` |  | ||||||
| 	Messages   []Message         `json:"messages"` |  | ||||||
| 	Filters    []PaLMFilter      `json:"filters"` |  | ||||||
| 	Error      PaLMError         `json:"error"` |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func requestOpenAI2PaLM(textRequest GeneralOpenAIRequest) *PaLMChatRequest { |  | ||||||
| 	palmRequest := PaLMChatRequest{ | 	palmRequest := PaLMChatRequest{ | ||||||
| 		Prompt: PaLMPrompt{ | 		Prompt: PaLMPrompt{ | ||||||
| 			Messages: make([]PaLMChatMessage, 0, len(textRequest.Messages)), | 			Messages: make([]PaLMChatMessage, 0, len(textRequest.Messages)), | ||||||
| @@ -71,14 +38,14 @@ func requestOpenAI2PaLM(textRequest GeneralOpenAIRequest) *PaLMChatRequest { | |||||||
| 	return &palmRequest | 	return &palmRequest | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func responsePaLM2OpenAI(response *PaLMChatResponse) *OpenAITextResponse { | func responsePaLM2OpenAI(response *PaLMChatResponse) *openai.TextResponse { | ||||||
| 	fullTextResponse := OpenAITextResponse{ | 	fullTextResponse := openai.TextResponse{ | ||||||
| 		Choices: make([]OpenAITextResponseChoice, 0, len(response.Candidates)), | 		Choices: make([]openai.TextResponseChoice, 0, len(response.Candidates)), | ||||||
| 	} | 	} | ||||||
| 	for i, candidate := range response.Candidates { | 	for i, candidate := range response.Candidates { | ||||||
| 		choice := OpenAITextResponseChoice{ | 		choice := openai.TextResponseChoice{ | ||||||
| 			Index: i, | 			Index: i, | ||||||
| 			Message: Message{ | 			Message: openai.Message{ | ||||||
| 				Role:    "assistant", | 				Role:    "assistant", | ||||||
| 				Content: candidate.Content, | 				Content: candidate.Content, | ||||||
| 			}, | 			}, | ||||||
| @@ -89,20 +56,20 @@ func responsePaLM2OpenAI(response *PaLMChatResponse) *OpenAITextResponse { | |||||||
| 	return &fullTextResponse | 	return &fullTextResponse | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func streamResponsePaLM2OpenAI(palmResponse *PaLMChatResponse) *ChatCompletionsStreamResponse { | func streamResponsePaLM2OpenAI(palmResponse *PaLMChatResponse) *openai.ChatCompletionsStreamResponse { | ||||||
| 	var choice ChatCompletionsStreamResponseChoice | 	var choice openai.ChatCompletionsStreamResponseChoice | ||||||
| 	if len(palmResponse.Candidates) > 0 { | 	if len(palmResponse.Candidates) > 0 { | ||||||
| 		choice.Delta.Content = palmResponse.Candidates[0].Content | 		choice.Delta.Content = palmResponse.Candidates[0].Content | ||||||
| 	} | 	} | ||||||
| 	choice.FinishReason = &stopFinishReason | 	choice.FinishReason = &constant.StopFinishReason | ||||||
| 	var response ChatCompletionsStreamResponse | 	var response openai.ChatCompletionsStreamResponse | ||||||
| 	response.Object = "chat.completion.chunk" | 	response.Object = "chat.completion.chunk" | ||||||
| 	response.Model = "palm2" | 	response.Model = "palm2" | ||||||
| 	response.Choices = []ChatCompletionsStreamResponseChoice{choice} | 	response.Choices = []openai.ChatCompletionsStreamResponseChoice{choice} | ||||||
| 	return &response | 	return &response | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func palmStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, string) { | func PaLMStreamHandler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, string) { | ||||||
| 	responseText := "" | 	responseText := "" | ||||||
| 	responseId := fmt.Sprintf("chatcmpl-%s", common.GetUUID()) | 	responseId := fmt.Sprintf("chatcmpl-%s", common.GetUUID()) | ||||||
| 	createdTime := common.GetTimestamp() | 	createdTime := common.GetTimestamp() | ||||||
| @@ -143,7 +110,7 @@ func palmStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithSta | |||||||
| 		dataChan <- string(jsonResponse) | 		dataChan <- string(jsonResponse) | ||||||
| 		stopChan <- true | 		stopChan <- true | ||||||
| 	}() | 	}() | ||||||
| 	setEventStreamHeaders(c) | 	common.SetEventStreamHeaders(c) | ||||||
| 	c.Stream(func(w io.Writer) bool { | 	c.Stream(func(w io.Writer) bool { | ||||||
| 		select { | 		select { | ||||||
| 		case data := <-dataChan: | 		case data := <-dataChan: | ||||||
| @@ -156,28 +123,28 @@ func palmStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithSta | |||||||
| 	}) | 	}) | ||||||
| 	err := resp.Body.Close() | 	err := resp.Body.Close() | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), "" | 		return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), "" | ||||||
| 	} | 	} | ||||||
| 	return nil, responseText | 	return nil, responseText | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func palmHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*OpenAIErrorWithStatusCode, *Usage) { | func PaLMHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*openai.ErrorWithStatusCode, *openai.Usage) { | ||||||
| 	responseBody, err := io.ReadAll(resp.Body) | 	responseBody, err := io.ReadAll(resp.Body) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil | 		return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil | ||||||
| 	} | 	} | ||||||
| 	err = resp.Body.Close() | 	err = resp.Body.Close() | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil | 		return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil | ||||||
| 	} | 	} | ||||||
| 	var palmResponse PaLMChatResponse | 	var palmResponse PaLMChatResponse | ||||||
| 	err = json.Unmarshal(responseBody, &palmResponse) | 	err = json.Unmarshal(responseBody, &palmResponse) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil | 		return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil | ||||||
| 	} | 	} | ||||||
| 	if palmResponse.Error.Code != 0 || len(palmResponse.Candidates) == 0 { | 	if palmResponse.Error.Code != 0 || len(palmResponse.Candidates) == 0 { | ||||||
| 		return &OpenAIErrorWithStatusCode{ | 		return &openai.ErrorWithStatusCode{ | ||||||
| 			OpenAIError: OpenAIError{ | 			Error: openai.Error{ | ||||||
| 				Message: palmResponse.Error.Message, | 				Message: palmResponse.Error.Message, | ||||||
| 				Type:    palmResponse.Error.Status, | 				Type:    palmResponse.Error.Status, | ||||||
| 				Param:   "", | 				Param:   "", | ||||||
| @@ -188,8 +155,8 @@ func palmHandler(c *gin.Context, resp *http.Response, promptTokens int, model st | |||||||
| 	} | 	} | ||||||
| 	fullTextResponse := responsePaLM2OpenAI(&palmResponse) | 	fullTextResponse := responsePaLM2OpenAI(&palmResponse) | ||||||
| 	fullTextResponse.Model = model | 	fullTextResponse.Model = model | ||||||
| 	completionTokens := countTokenText(palmResponse.Candidates[0].Content, model) | 	completionTokens := openai.CountTokenText(palmResponse.Candidates[0].Content, model) | ||||||
| 	usage := Usage{ | 	usage := openai.Usage{ | ||||||
| 		PromptTokens:     promptTokens, | 		PromptTokens:     promptTokens, | ||||||
| 		CompletionTokens: completionTokens, | 		CompletionTokens: completionTokens, | ||||||
| 		TotalTokens:      promptTokens + completionTokens, | 		TotalTokens:      promptTokens + completionTokens, | ||||||
| @@ -197,7 +164,7 @@ func palmHandler(c *gin.Context, resp *http.Response, promptTokens int, model st | |||||||
| 	fullTextResponse.Usage = usage | 	fullTextResponse.Usage = usage | ||||||
| 	jsonResponse, err := json.Marshal(fullTextResponse) | 	jsonResponse, err := json.Marshal(fullTextResponse) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil | 		return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil | ||||||
| 	} | 	} | ||||||
| 	c.Writer.Header().Set("Content-Type", "application/json") | 	c.Writer.Header().Set("Content-Type", "application/json") | ||||||
| 	c.Writer.WriteHeader(resp.StatusCode) | 	c.Writer.WriteHeader(resp.StatusCode) | ||||||
							
								
								
									
										6
									
								
								relay/channel/openai/constant.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										6
									
								
								relay/channel/openai/constant.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,6 @@ | |||||||
|  | package openai | ||||||
|  |  | ||||||
|  | const ( | ||||||
|  | 	ContentTypeText     = "text" | ||||||
|  | 	ContentTypeImageURL = "image_url" | ||||||
|  | ) | ||||||
| @@ -1,4 +1,4 @@ | |||||||
| package controller | package openai | ||||||
| 
 | 
 | ||||||
| import ( | import ( | ||||||
| 	"bufio" | 	"bufio" | ||||||
| @@ -8,10 +8,11 @@ import ( | |||||||
| 	"io" | 	"io" | ||||||
| 	"net/http" | 	"net/http" | ||||||
| 	"one-api/common" | 	"one-api/common" | ||||||
|  | 	"one-api/relay/constant" | ||||||
| 	"strings" | 	"strings" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| func openaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*OpenAIErrorWithStatusCode, string) { | func StreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*ErrorWithStatusCode, string) { | ||||||
| 	responseText := "" | 	responseText := "" | ||||||
| 	scanner := bufio.NewScanner(resp.Body) | 	scanner := bufio.NewScanner(resp.Body) | ||||||
| 	scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { | 	scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { | ||||||
| @@ -41,7 +42,7 @@ func openaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*O | |||||||
| 			data = data[6:] | 			data = data[6:] | ||||||
| 			if !strings.HasPrefix(data, "[DONE]") { | 			if !strings.HasPrefix(data, "[DONE]") { | ||||||
| 				switch relayMode { | 				switch relayMode { | ||||||
| 				case RelayModeChatCompletions: | 				case constant.RelayModeChatCompletions: | ||||||
| 					var streamResponse ChatCompletionsStreamResponse | 					var streamResponse ChatCompletionsStreamResponse | ||||||
| 					err := json.Unmarshal([]byte(data), &streamResponse) | 					err := json.Unmarshal([]byte(data), &streamResponse) | ||||||
| 					if err != nil { | 					if err != nil { | ||||||
| @@ -51,7 +52,7 @@ func openaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*O | |||||||
| 					for _, choice := range streamResponse.Choices { | 					for _, choice := range streamResponse.Choices { | ||||||
| 						responseText += choice.Delta.Content | 						responseText += choice.Delta.Content | ||||||
| 					} | 					} | ||||||
| 				case RelayModeCompletions: | 				case constant.RelayModeCompletions: | ||||||
| 					var streamResponse CompletionsStreamResponse | 					var streamResponse CompletionsStreamResponse | ||||||
| 					err := json.Unmarshal([]byte(data), &streamResponse) | 					err := json.Unmarshal([]byte(data), &streamResponse) | ||||||
| 					if err != nil { | 					if err != nil { | ||||||
| @@ -66,7 +67,7 @@ func openaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*O | |||||||
| 		} | 		} | ||||||
| 		stopChan <- true | 		stopChan <- true | ||||||
| 	}() | 	}() | ||||||
| 	setEventStreamHeaders(c) | 	common.SetEventStreamHeaders(c) | ||||||
| 	c.Stream(func(w io.Writer) bool { | 	c.Stream(func(w io.Writer) bool { | ||||||
| 		select { | 		select { | ||||||
| 		case data := <-dataChan: | 		case data := <-dataChan: | ||||||
| @@ -83,29 +84,29 @@ func openaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*O | |||||||
| 	}) | 	}) | ||||||
| 	err := resp.Body.Close() | 	err := resp.Body.Close() | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), "" | 		return ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), "" | ||||||
| 	} | 	} | ||||||
| 	return nil, responseText | 	return nil, responseText | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func openaiHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*OpenAIErrorWithStatusCode, *Usage) { | func Handler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*ErrorWithStatusCode, *Usage) { | ||||||
| 	var textResponse TextResponse | 	var textResponse SlimTextResponse | ||||||
| 	responseBody, err := io.ReadAll(resp.Body) | 	responseBody, err := io.ReadAll(resp.Body) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil | 		return ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil | ||||||
| 	} | 	} | ||||||
| 	err = resp.Body.Close() | 	err = resp.Body.Close() | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil | 		return ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil | ||||||
| 	} | 	} | ||||||
| 	err = json.Unmarshal(responseBody, &textResponse) | 	err = json.Unmarshal(responseBody, &textResponse) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil | 		return ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil | ||||||
| 	} | 	} | ||||||
| 	if textResponse.Error.Type != "" { | 	if textResponse.Error.Type != "" { | ||||||
| 		return &OpenAIErrorWithStatusCode{ | 		return &ErrorWithStatusCode{ | ||||||
| 			OpenAIError: textResponse.Error, | 			Error:      textResponse.Error, | ||||||
| 			StatusCode:  resp.StatusCode, | 			StatusCode: resp.StatusCode, | ||||||
| 		}, nil | 		}, nil | ||||||
| 	} | 	} | ||||||
| 	// Reset response body | 	// Reset response body | ||||||
| @@ -113,7 +114,7 @@ func openaiHandler(c *gin.Context, resp *http.Response, promptTokens int, model | |||||||
| 
 | 
 | ||||||
| 	// We shouldn't set the header before we parse the response body, because the parse part may fail. | 	// We shouldn't set the header before we parse the response body, because the parse part may fail. | ||||||
| 	// And then we will have to send an error response, but in this case, the header has already been set. | 	// And then we will have to send an error response, but in this case, the header has already been set. | ||||||
| 	// So the httpClient will be confused by the response. | 	// So the HTTPClient will be confused by the response. | ||||||
| 	// For example, Postman will report error, and we cannot check the response at all. | 	// For example, Postman will report error, and we cannot check the response at all. | ||||||
| 	for k, v := range resp.Header { | 	for k, v := range resp.Header { | ||||||
| 		c.Writer.Header().Set(k, v[0]) | 		c.Writer.Header().Set(k, v[0]) | ||||||
| @@ -121,17 +122,17 @@ func openaiHandler(c *gin.Context, resp *http.Response, promptTokens int, model | |||||||
| 	c.Writer.WriteHeader(resp.StatusCode) | 	c.Writer.WriteHeader(resp.StatusCode) | ||||||
| 	_, err = io.Copy(c.Writer, resp.Body) | 	_, err = io.Copy(c.Writer, resp.Body) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return errorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil | 		return ErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil | ||||||
| 	} | 	} | ||||||
| 	err = resp.Body.Close() | 	err = resp.Body.Close() | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil | 		return ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	if textResponse.Usage.TotalTokens == 0 { | 	if textResponse.Usage.TotalTokens == 0 { | ||||||
| 		completionTokens := 0 | 		completionTokens := 0 | ||||||
| 		for _, choice := range textResponse.Choices { | 		for _, choice := range textResponse.Choices { | ||||||
| 			completionTokens += countTokenText(choice.Message.StringContent(), model) | 			completionTokens += CountTokenText(choice.Message.StringContent(), model) | ||||||
| 		} | 		} | ||||||
| 		textResponse.Usage = Usage{ | 		textResponse.Usage = Usage{ | ||||||
| 			PromptTokens:     promptTokens, | 			PromptTokens:     promptTokens, | ||||||
							
								
								
									
										283
									
								
								relay/channel/openai/model.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										283
									
								
								relay/channel/openai/model.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,283 @@ | |||||||
|  | package openai | ||||||
|  |  | ||||||
|  | type Message struct { | ||||||
|  | 	Role    string  `json:"role"` | ||||||
|  | 	Content any     `json:"content"` | ||||||
|  | 	Name    *string `json:"name,omitempty"` | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type ImageURL struct { | ||||||
|  | 	Url    string `json:"url,omitempty"` | ||||||
|  | 	Detail string `json:"detail,omitempty"` | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type TextContent struct { | ||||||
|  | 	Type string `json:"type,omitempty"` | ||||||
|  | 	Text string `json:"text,omitempty"` | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type ImageContent struct { | ||||||
|  | 	Type     string    `json:"type,omitempty"` | ||||||
|  | 	ImageURL *ImageURL `json:"image_url,omitempty"` | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type OpenAIMessageContent struct { | ||||||
|  | 	Type     string    `json:"type,omitempty"` | ||||||
|  | 	Text     string    `json:"text"` | ||||||
|  | 	ImageURL *ImageURL `json:"image_url,omitempty"` | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (m Message) IsStringContent() bool { | ||||||
|  | 	_, ok := m.Content.(string) | ||||||
|  | 	return ok | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (m Message) StringContent() string { | ||||||
|  | 	content, ok := m.Content.(string) | ||||||
|  | 	if ok { | ||||||
|  | 		return content | ||||||
|  | 	} | ||||||
|  | 	contentList, ok := m.Content.([]any) | ||||||
|  | 	if ok { | ||||||
|  | 		var contentStr string | ||||||
|  | 		for _, contentItem := range contentList { | ||||||
|  | 			contentMap, ok := contentItem.(map[string]any) | ||||||
|  | 			if !ok { | ||||||
|  | 				continue | ||||||
|  | 			} | ||||||
|  | 			if contentMap["type"] == ContentTypeText { | ||||||
|  | 				if subStr, ok := contentMap["text"].(string); ok { | ||||||
|  | 					contentStr += subStr | ||||||
|  | 				} | ||||||
|  | 			} | ||||||
|  | 		} | ||||||
|  | 		return contentStr | ||||||
|  | 	} | ||||||
|  | 	return "" | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (m Message) ParseContent() []OpenAIMessageContent { | ||||||
|  | 	var contentList []OpenAIMessageContent | ||||||
|  | 	content, ok := m.Content.(string) | ||||||
|  | 	if ok { | ||||||
|  | 		contentList = append(contentList, OpenAIMessageContent{ | ||||||
|  | 			Type: ContentTypeText, | ||||||
|  | 			Text: content, | ||||||
|  | 		}) | ||||||
|  | 		return contentList | ||||||
|  | 	} | ||||||
|  | 	anyList, ok := m.Content.([]any) | ||||||
|  | 	if ok { | ||||||
|  | 		for _, contentItem := range anyList { | ||||||
|  | 			contentMap, ok := contentItem.(map[string]any) | ||||||
|  | 			if !ok { | ||||||
|  | 				continue | ||||||
|  | 			} | ||||||
|  | 			switch contentMap["type"] { | ||||||
|  | 			case ContentTypeText: | ||||||
|  | 				if subStr, ok := contentMap["text"].(string); ok { | ||||||
|  | 					contentList = append(contentList, OpenAIMessageContent{ | ||||||
|  | 						Type: ContentTypeText, | ||||||
|  | 						Text: subStr, | ||||||
|  | 					}) | ||||||
|  | 				} | ||||||
|  | 			case ContentTypeImageURL: | ||||||
|  | 				if subObj, ok := contentMap["image_url"].(map[string]any); ok { | ||||||
|  | 					contentList = append(contentList, OpenAIMessageContent{ | ||||||
|  | 						Type: ContentTypeImageURL, | ||||||
|  | 						ImageURL: &ImageURL{ | ||||||
|  | 							Url: subObj["url"].(string), | ||||||
|  | 						}, | ||||||
|  | 					}) | ||||||
|  | 				} | ||||||
|  | 			} | ||||||
|  | 		} | ||||||
|  | 		return contentList | ||||||
|  | 	} | ||||||
|  | 	return nil | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type ResponseFormat struct { | ||||||
|  | 	Type string `json:"type,omitempty"` | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type GeneralOpenAIRequest struct { | ||||||
|  | 	Model            string          `json:"model,omitempty"` | ||||||
|  | 	Messages         []Message       `json:"messages,omitempty"` | ||||||
|  | 	Prompt           any             `json:"prompt,omitempty"` | ||||||
|  | 	Stream           bool            `json:"stream,omitempty"` | ||||||
|  | 	MaxTokens        int             `json:"max_tokens,omitempty"` | ||||||
|  | 	Temperature      float64         `json:"temperature,omitempty"` | ||||||
|  | 	TopP             float64         `json:"top_p,omitempty"` | ||||||
|  | 	N                int             `json:"n,omitempty"` | ||||||
|  | 	Input            any             `json:"input,omitempty"` | ||||||
|  | 	Instruction      string          `json:"instruction,omitempty"` | ||||||
|  | 	Size             string          `json:"size,omitempty"` | ||||||
|  | 	Functions        any             `json:"functions,omitempty"` | ||||||
|  | 	FrequencyPenalty float64         `json:"frequency_penalty,omitempty"` | ||||||
|  | 	PresencePenalty  float64         `json:"presence_penalty,omitempty"` | ||||||
|  | 	ResponseFormat   *ResponseFormat `json:"response_format,omitempty"` | ||||||
|  | 	Seed             float64         `json:"seed,omitempty"` | ||||||
|  | 	Tools            any             `json:"tools,omitempty"` | ||||||
|  | 	ToolChoice       any             `json:"tool_choice,omitempty"` | ||||||
|  | 	User             string          `json:"user,omitempty"` | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (r GeneralOpenAIRequest) ParseInput() []string { | ||||||
|  | 	if r.Input == nil { | ||||||
|  | 		return nil | ||||||
|  | 	} | ||||||
|  | 	var input []string | ||||||
|  | 	switch r.Input.(type) { | ||||||
|  | 	case string: | ||||||
|  | 		input = []string{r.Input.(string)} | ||||||
|  | 	case []any: | ||||||
|  | 		input = make([]string, 0, len(r.Input.([]any))) | ||||||
|  | 		for _, item := range r.Input.([]any) { | ||||||
|  | 			if str, ok := item.(string); ok { | ||||||
|  | 				input = append(input, str) | ||||||
|  | 			} | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 	return input | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type ChatRequest struct { | ||||||
|  | 	Model     string    `json:"model"` | ||||||
|  | 	Messages  []Message `json:"messages"` | ||||||
|  | 	MaxTokens int       `json:"max_tokens"` | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type TextRequest struct { | ||||||
|  | 	Model     string    `json:"model"` | ||||||
|  | 	Messages  []Message `json:"messages"` | ||||||
|  | 	Prompt    string    `json:"prompt"` | ||||||
|  | 	MaxTokens int       `json:"max_tokens"` | ||||||
|  | 	//Stream   bool      `json:"stream"` | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // ImageRequest docs: https://platform.openai.com/docs/api-reference/images/create | ||||||
|  | type ImageRequest struct { | ||||||
|  | 	Model          string `json:"model"` | ||||||
|  | 	Prompt         string `json:"prompt" binding:"required"` | ||||||
|  | 	N              int    `json:"n,omitempty"` | ||||||
|  | 	Size           string `json:"size,omitempty"` | ||||||
|  | 	Quality        string `json:"quality,omitempty"` | ||||||
|  | 	ResponseFormat string `json:"response_format,omitempty"` | ||||||
|  | 	Style          string `json:"style,omitempty"` | ||||||
|  | 	User           string `json:"user,omitempty"` | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type WhisperJSONResponse struct { | ||||||
|  | 	Text string `json:"text,omitempty"` | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type WhisperVerboseJSONResponse struct { | ||||||
|  | 	Task     string    `json:"task,omitempty"` | ||||||
|  | 	Language string    `json:"language,omitempty"` | ||||||
|  | 	Duration float64   `json:"duration,omitempty"` | ||||||
|  | 	Text     string    `json:"text,omitempty"` | ||||||
|  | 	Segments []Segment `json:"segments,omitempty"` | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type Segment struct { | ||||||
|  | 	Id               int     `json:"id"` | ||||||
|  | 	Seek             int     `json:"seek"` | ||||||
|  | 	Start            float64 `json:"start"` | ||||||
|  | 	End              float64 `json:"end"` | ||||||
|  | 	Text             string  `json:"text"` | ||||||
|  | 	Tokens           []int   `json:"tokens"` | ||||||
|  | 	Temperature      float64 `json:"temperature"` | ||||||
|  | 	AvgLogprob       float64 `json:"avg_logprob"` | ||||||
|  | 	CompressionRatio float64 `json:"compression_ratio"` | ||||||
|  | 	NoSpeechProb     float64 `json:"no_speech_prob"` | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type TextToSpeechRequest struct { | ||||||
|  | 	Model          string  `json:"model" binding:"required"` | ||||||
|  | 	Input          string  `json:"input" binding:"required"` | ||||||
|  | 	Voice          string  `json:"voice" binding:"required"` | ||||||
|  | 	Speed          float64 `json:"speed"` | ||||||
|  | 	ResponseFormat string  `json:"response_format"` | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type Usage struct { | ||||||
|  | 	PromptTokens     int `json:"prompt_tokens"` | ||||||
|  | 	CompletionTokens int `json:"completion_tokens"` | ||||||
|  | 	TotalTokens      int `json:"total_tokens"` | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type Error struct { | ||||||
|  | 	Message string `json:"message"` | ||||||
|  | 	Type    string `json:"type"` | ||||||
|  | 	Param   string `json:"param"` | ||||||
|  | 	Code    any    `json:"code"` | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type ErrorWithStatusCode struct { | ||||||
|  | 	Error | ||||||
|  | 	StatusCode int `json:"status_code"` | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type SlimTextResponse struct { | ||||||
|  | 	Choices []TextResponseChoice `json:"choices"` | ||||||
|  | 	Usage   `json:"usage"` | ||||||
|  | 	Error   Error `json:"error"` | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type TextResponseChoice struct { | ||||||
|  | 	Index        int `json:"index"` | ||||||
|  | 	Message      `json:"message"` | ||||||
|  | 	FinishReason string `json:"finish_reason"` | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type TextResponse struct { | ||||||
|  | 	Id      string               `json:"id"` | ||||||
|  | 	Model   string               `json:"model,omitempty"` | ||||||
|  | 	Object  string               `json:"object"` | ||||||
|  | 	Created int64                `json:"created"` | ||||||
|  | 	Choices []TextResponseChoice `json:"choices"` | ||||||
|  | 	Usage   `json:"usage"` | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type EmbeddingResponseItem struct { | ||||||
|  | 	Object    string    `json:"object"` | ||||||
|  | 	Index     int       `json:"index"` | ||||||
|  | 	Embedding []float64 `json:"embedding"` | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type EmbeddingResponse struct { | ||||||
|  | 	Object string                  `json:"object"` | ||||||
|  | 	Data   []EmbeddingResponseItem `json:"data"` | ||||||
|  | 	Model  string                  `json:"model"` | ||||||
|  | 	Usage  `json:"usage"` | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type ImageResponse struct { | ||||||
|  | 	Created int `json:"created"` | ||||||
|  | 	Data    []struct { | ||||||
|  | 		Url string `json:"url"` | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type ChatCompletionsStreamResponseChoice struct { | ||||||
|  | 	Delta struct { | ||||||
|  | 		Content string `json:"content"` | ||||||
|  | 	} `json:"delta"` | ||||||
|  | 	FinishReason *string `json:"finish_reason,omitempty"` | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type ChatCompletionsStreamResponse struct { | ||||||
|  | 	Id      string                                `json:"id"` | ||||||
|  | 	Object  string                                `json:"object"` | ||||||
|  | 	Created int64                                 `json:"created"` | ||||||
|  | 	Model   string                                `json:"model"` | ||||||
|  | 	Choices []ChatCompletionsStreamResponseChoice `json:"choices"` | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type CompletionsStreamResponse struct { | ||||||
|  | 	Choices []struct { | ||||||
|  | 		Text         string `json:"text"` | ||||||
|  | 		FinishReason string `json:"finish_reason"` | ||||||
|  | 	} `json:"choices"` | ||||||
|  | } | ||||||
| @@ -1,25 +1,15 @@ | |||||||
| package controller | package openai | ||||||
| 
 | 
 | ||||||
| import ( | import ( | ||||||
| 	"context" |  | ||||||
| 	"encoding/json" |  | ||||||
| 	"errors" | 	"errors" | ||||||
| 	"fmt" | 	"fmt" | ||||||
| 	"io" | 	"github.com/pkoukk/tiktoken-go" | ||||||
| 	"math" | 	"math" | ||||||
| 	"net/http" |  | ||||||
| 	"one-api/common" | 	"one-api/common" | ||||||
| 	"one-api/common/image" | 	"one-api/common/image" | ||||||
| 	"one-api/model" |  | ||||||
| 	"strconv" |  | ||||||
| 	"strings" | 	"strings" | ||||||
| 
 |  | ||||||
| 	"github.com/gin-gonic/gin" |  | ||||||
| 	"github.com/pkoukk/tiktoken-go" |  | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| var stopFinishReason = "stop" |  | ||||||
| 
 |  | ||||||
| // tokenEncoderMap won't grow after initialization | // tokenEncoderMap won't grow after initialization | ||||||
| var tokenEncoderMap = map[string]*tiktoken.Tiktoken{} | var tokenEncoderMap = map[string]*tiktoken.Tiktoken{} | ||||||
| var defaultTokenEncoder *tiktoken.Tiktoken | var defaultTokenEncoder *tiktoken.Tiktoken | ||||||
| @@ -71,7 +61,7 @@ func getTokenNum(tokenEncoder *tiktoken.Tiktoken, text string) int { | |||||||
| 	return len(tokenEncoder.Encode(text, nil, nil)) | 	return len(tokenEncoder.Encode(text, nil, nil)) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func countTokenMessages(messages []Message, model string) int { | func CountTokenMessages(messages []Message, model string) int { | ||||||
| 	tokenEncoder := getTokenEncoder(model) | 	tokenEncoder := getTokenEncoder(model) | ||||||
| 	// Reference: | 	// Reference: | ||||||
| 	// https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb | 	// https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb | ||||||
| @@ -195,191 +185,21 @@ func countImageTokens(url string, detail string) (_ int, err error) { | |||||||
| 	} | 	} | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func countTokenInput(input any, model string) int { | func CountTokenInput(input any, model string) int { | ||||||
| 	switch v := input.(type) { | 	switch v := input.(type) { | ||||||
| 	case string: | 	case string: | ||||||
| 		return countTokenText(v, model) | 		return CountTokenText(v, model) | ||||||
| 	case []string: | 	case []string: | ||||||
| 		text := "" | 		text := "" | ||||||
| 		for _, s := range v { | 		for _, s := range v { | ||||||
| 			text += s | 			text += s | ||||||
| 		} | 		} | ||||||
| 		return countTokenText(text, model) | 		return CountTokenText(text, model) | ||||||
| 	} | 	} | ||||||
| 	return 0 | 	return 0 | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func countTokenText(text string, model string) int { | func CountTokenText(text string, model string) int { | ||||||
| 	tokenEncoder := getTokenEncoder(model) | 	tokenEncoder := getTokenEncoder(model) | ||||||
| 	return getTokenNum(tokenEncoder, text) | 	return getTokenNum(tokenEncoder, text) | ||||||
| } | } | ||||||
| 
 |  | ||||||
| func errorWrapper(err error, code string, statusCode int) *OpenAIErrorWithStatusCode { |  | ||||||
| 	openAIError := OpenAIError{ |  | ||||||
| 		Message: err.Error(), |  | ||||||
| 		Type:    "one_api_error", |  | ||||||
| 		Code:    code, |  | ||||||
| 	} |  | ||||||
| 	return &OpenAIErrorWithStatusCode{ |  | ||||||
| 		OpenAIError: openAIError, |  | ||||||
| 		StatusCode:  statusCode, |  | ||||||
| 	} |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func shouldDisableChannel(err *OpenAIError, statusCode int) bool { |  | ||||||
| 	if !common.AutomaticDisableChannelEnabled { |  | ||||||
| 		return false |  | ||||||
| 	} |  | ||||||
| 	if err == nil { |  | ||||||
| 		return false |  | ||||||
| 	} |  | ||||||
| 	if statusCode == http.StatusUnauthorized { |  | ||||||
| 		return true |  | ||||||
| 	} |  | ||||||
| 	if err.Type == "insufficient_quota" || err.Code == "invalid_api_key" || err.Code == "account_deactivated" { |  | ||||||
| 		return true |  | ||||||
| 	} |  | ||||||
| 	return false |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func shouldEnableChannel(err error, openAIErr *OpenAIError) bool { |  | ||||||
| 	if !common.AutomaticEnableChannelEnabled { |  | ||||||
| 		return false |  | ||||||
| 	} |  | ||||||
| 	if err != nil { |  | ||||||
| 		return false |  | ||||||
| 	} |  | ||||||
| 	if openAIErr != nil { |  | ||||||
| 		return false |  | ||||||
| 	} |  | ||||||
| 	return true |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func setEventStreamHeaders(c *gin.Context) { |  | ||||||
| 	c.Writer.Header().Set("Content-Type", "text/event-stream") |  | ||||||
| 	c.Writer.Header().Set("Cache-Control", "no-cache") |  | ||||||
| 	c.Writer.Header().Set("Connection", "keep-alive") |  | ||||||
| 	c.Writer.Header().Set("Transfer-Encoding", "chunked") |  | ||||||
| 	c.Writer.Header().Set("X-Accel-Buffering", "no") |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| type GeneralErrorResponse struct { |  | ||||||
| 	Error    OpenAIError `json:"error"` |  | ||||||
| 	Message  string      `json:"message"` |  | ||||||
| 	Msg      string      `json:"msg"` |  | ||||||
| 	Err      string      `json:"err"` |  | ||||||
| 	ErrorMsg string      `json:"error_msg"` |  | ||||||
| 	Header   struct { |  | ||||||
| 		Message string `json:"message"` |  | ||||||
| 	} `json:"header"` |  | ||||||
| 	Response struct { |  | ||||||
| 		Error struct { |  | ||||||
| 			Message string `json:"message"` |  | ||||||
| 		} `json:"error"` |  | ||||||
| 	} `json:"response"` |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func (e GeneralErrorResponse) ToMessage() string { |  | ||||||
| 	if e.Error.Message != "" { |  | ||||||
| 		return e.Error.Message |  | ||||||
| 	} |  | ||||||
| 	if e.Message != "" { |  | ||||||
| 		return e.Message |  | ||||||
| 	} |  | ||||||
| 	if e.Msg != "" { |  | ||||||
| 		return e.Msg |  | ||||||
| 	} |  | ||||||
| 	if e.Err != "" { |  | ||||||
| 		return e.Err |  | ||||||
| 	} |  | ||||||
| 	if e.ErrorMsg != "" { |  | ||||||
| 		return e.ErrorMsg |  | ||||||
| 	} |  | ||||||
| 	if e.Header.Message != "" { |  | ||||||
| 		return e.Header.Message |  | ||||||
| 	} |  | ||||||
| 	if e.Response.Error.Message != "" { |  | ||||||
| 		return e.Response.Error.Message |  | ||||||
| 	} |  | ||||||
| 	return "" |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func relayErrorHandler(resp *http.Response) (openAIErrorWithStatusCode *OpenAIErrorWithStatusCode) { |  | ||||||
| 	openAIErrorWithStatusCode = &OpenAIErrorWithStatusCode{ |  | ||||||
| 		StatusCode: resp.StatusCode, |  | ||||||
| 		OpenAIError: OpenAIError{ |  | ||||||
| 			Message: "", |  | ||||||
| 			Type:    "upstream_error", |  | ||||||
| 			Code:    "bad_response_status_code", |  | ||||||
| 			Param:   strconv.Itoa(resp.StatusCode), |  | ||||||
| 		}, |  | ||||||
| 	} |  | ||||||
| 	responseBody, err := io.ReadAll(resp.Body) |  | ||||||
| 	if err != nil { |  | ||||||
| 		return |  | ||||||
| 	} |  | ||||||
| 	err = resp.Body.Close() |  | ||||||
| 	if err != nil { |  | ||||||
| 		return |  | ||||||
| 	} |  | ||||||
| 	var errResponse GeneralErrorResponse |  | ||||||
| 	err = json.Unmarshal(responseBody, &errResponse) |  | ||||||
| 	if err != nil { |  | ||||||
| 		return |  | ||||||
| 	} |  | ||||||
| 	if errResponse.Error.Message != "" { |  | ||||||
| 		// OpenAI format error, so we override the default one |  | ||||||
| 		openAIErrorWithStatusCode.OpenAIError = errResponse.Error |  | ||||||
| 	} else { |  | ||||||
| 		openAIErrorWithStatusCode.OpenAIError.Message = errResponse.ToMessage() |  | ||||||
| 	} |  | ||||||
| 	if openAIErrorWithStatusCode.OpenAIError.Message == "" { |  | ||||||
| 		openAIErrorWithStatusCode.OpenAIError.Message = fmt.Sprintf("bad response status code %d", resp.StatusCode) |  | ||||||
| 	} |  | ||||||
| 	return |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func getFullRequestURL(baseURL string, requestURL string, channelType int) string { |  | ||||||
| 	fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL) |  | ||||||
| 
 |  | ||||||
| 	if strings.HasPrefix(baseURL, "https://gateway.ai.cloudflare.com") { |  | ||||||
| 		switch channelType { |  | ||||||
| 		case common.ChannelTypeOpenAI: |  | ||||||
| 			fullRequestURL = fmt.Sprintf("%s%s", baseURL, strings.TrimPrefix(requestURL, "/v1")) |  | ||||||
| 		case common.ChannelTypeAzure: |  | ||||||
| 			fullRequestURL = fmt.Sprintf("%s%s", baseURL, strings.TrimPrefix(requestURL, "/openai/deployments")) |  | ||||||
| 		} |  | ||||||
| 	} |  | ||||||
| 	return fullRequestURL |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func postConsumeQuota(ctx context.Context, tokenId int, quotaDelta int, totalQuota int, userId int, channelId int, modelRatio float64, groupRatio float64, modelName string, tokenName string) { |  | ||||||
| 	// quotaDelta is remaining quota to be consumed |  | ||||||
| 	err := model.PostConsumeTokenQuota(tokenId, quotaDelta) |  | ||||||
| 	if err != nil { |  | ||||||
| 		common.SysError("error consuming token remain quota: " + err.Error()) |  | ||||||
| 	} |  | ||||||
| 	err = model.CacheUpdateUserQuota(userId) |  | ||||||
| 	if err != nil { |  | ||||||
| 		common.SysError("error update user quota cache: " + err.Error()) |  | ||||||
| 	} |  | ||||||
| 	// totalQuota is total quota consumed |  | ||||||
| 	if totalQuota != 0 { |  | ||||||
| 		logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio) |  | ||||||
| 		model.RecordConsumeLog(ctx, userId, channelId, totalQuota, 0, modelName, tokenName, totalQuota, logContent) |  | ||||||
| 		model.UpdateUserUsedQuotaAndRequestCount(userId, totalQuota) |  | ||||||
| 		model.UpdateChannelUsedQuota(channelId, totalQuota) |  | ||||||
| 	} |  | ||||||
| 	if totalQuota <= 0 { |  | ||||||
| 		common.LogError(ctx, fmt.Sprintf("totalQuota consumed is %d, something is wrong", totalQuota)) |  | ||||||
| 	} |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func GetAPIVersion(c *gin.Context) string { |  | ||||||
| 	query := c.Request.URL.Query() |  | ||||||
| 	apiVersion := query.Get("api-version") |  | ||||||
| 	if apiVersion == "" { |  | ||||||
| 		apiVersion = c.GetString("api_version") |  | ||||||
| 	} |  | ||||||
| 	return apiVersion |  | ||||||
| } |  | ||||||
							
								
								
									
										13
									
								
								relay/channel/openai/util.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										13
									
								
								relay/channel/openai/util.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,13 @@ | |||||||
|  | package openai | ||||||
|  |  | ||||||
|  | func ErrorWrapper(err error, code string, statusCode int) *ErrorWithStatusCode { | ||||||
|  | 	Error := Error{ | ||||||
|  | 		Message: err.Error(), | ||||||
|  | 		Type:    "one_api_error", | ||||||
|  | 		Code:    code, | ||||||
|  | 	} | ||||||
|  | 	return &ErrorWithStatusCode{ | ||||||
|  | 		Error:      Error, | ||||||
|  | 		StatusCode: statusCode, | ||||||
|  | 	} | ||||||
|  | } | ||||||
| @@ -1,4 +1,4 @@ | |||||||
| package controller | package tencent | ||||||
| 
 | 
 | ||||||
| import ( | import ( | ||||||
| 	"bufio" | 	"bufio" | ||||||
| @@ -12,6 +12,8 @@ import ( | |||||||
| 	"io" | 	"io" | ||||||
| 	"net/http" | 	"net/http" | ||||||
| 	"one-api/common" | 	"one-api/common" | ||||||
|  | 	"one-api/relay/channel/openai" | ||||||
|  | 	"one-api/relay/constant" | ||||||
| 	"sort" | 	"sort" | ||||||
| 	"strconv" | 	"strconv" | ||||||
| 	"strings" | 	"strings" | ||||||
| @@ -19,80 +21,22 @@ import ( | |||||||
| 
 | 
 | ||||||
| // https://cloud.tencent.com/document/product/1729/97732 | // https://cloud.tencent.com/document/product/1729/97732 | ||||||
| 
 | 
 | ||||||
| type TencentMessage struct { | func ConvertRequest(request openai.GeneralOpenAIRequest) *ChatRequest { | ||||||
| 	Role    string `json:"role"` | 	messages := make([]Message, 0, len(request.Messages)) | ||||||
| 	Content string `json:"content"` |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| type TencentChatRequest struct { |  | ||||||
| 	AppId    int64  `json:"app_id"`    // 腾讯云账号的 APPID |  | ||||||
| 	SecretId string `json:"secret_id"` // 官网 SecretId |  | ||||||
| 	// Timestamp当前 UNIX 时间戳,单位为秒,可记录发起 API 请求的时间。 |  | ||||||
| 	// 例如1529223702,如果与当前时间相差过大,会引起签名过期错误 |  | ||||||
| 	Timestamp int64 `json:"timestamp"` |  | ||||||
| 	// Expired 签名的有效期,是一个符合 UNIX Epoch 时间戳规范的数值, |  | ||||||
| 	// 单位为秒;Expired 必须大于 Timestamp 且 Expired-Timestamp 小于90天 |  | ||||||
| 	Expired int64  `json:"expired"` |  | ||||||
| 	QueryID string `json:"query_id"` //请求 Id,用于问题排查 |  | ||||||
| 	// Temperature 较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定 |  | ||||||
| 	// 默认 1.0,取值区间为[0.0,2.0],非必要不建议使用,不合理的取值会影响效果 |  | ||||||
| 	// 建议该参数和 top_p 只设置1个,不要同时更改 top_p |  | ||||||
| 	Temperature float64 `json:"temperature"` |  | ||||||
| 	// TopP 影响输出文本的多样性,取值越大,生成文本的多样性越强 |  | ||||||
| 	// 默认1.0,取值区间为[0.0, 1.0],非必要不建议使用, 不合理的取值会影响效果 |  | ||||||
| 	// 建议该参数和 temperature 只设置1个,不要同时更改 |  | ||||||
| 	TopP float64 `json:"top_p"` |  | ||||||
| 	// Stream 0:同步,1:流式 (默认,协议:SSE) |  | ||||||
| 	// 同步请求超时:60s,如果内容较长建议使用流式 |  | ||||||
| 	Stream int `json:"stream"` |  | ||||||
| 	// Messages 会话内容, 长度最多为40, 按对话时间从旧到新在数组中排列 |  | ||||||
| 	// 输入 content 总数最大支持 3000 token。 |  | ||||||
| 	Messages []TencentMessage `json:"messages"` |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| type TencentError struct { |  | ||||||
| 	Code    int    `json:"code"` |  | ||||||
| 	Message string `json:"message"` |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| type TencentUsage struct { |  | ||||||
| 	InputTokens  int `json:"input_tokens"` |  | ||||||
| 	OutputTokens int `json:"output_tokens"` |  | ||||||
| 	TotalTokens  int `json:"total_tokens"` |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| type TencentResponseChoices struct { |  | ||||||
| 	FinishReason string         `json:"finish_reason,omitempty"` // 流式结束标志位,为 stop 则表示尾包 |  | ||||||
| 	Messages     TencentMessage `json:"messages,omitempty"`      // 内容,同步模式返回内容,流模式为 null 输出 content 内容总数最多支持 1024token。 |  | ||||||
| 	Delta        TencentMessage `json:"delta,omitempty"`         // 内容,流模式返回内容,同步模式为 null 输出 content 内容总数最多支持 1024token。 |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| type TencentChatResponse struct { |  | ||||||
| 	Choices []TencentResponseChoices `json:"choices,omitempty"` // 结果 |  | ||||||
| 	Created string                   `json:"created,omitempty"` // unix 时间戳的字符串 |  | ||||||
| 	Id      string                   `json:"id,omitempty"`      // 会话 id |  | ||||||
| 	Usage   Usage                    `json:"usage,omitempty"`   // token 数量 |  | ||||||
| 	Error   TencentError             `json:"error,omitempty"`   // 错误信息 注意:此字段可能返回 null,表示取不到有效值 |  | ||||||
| 	Note    string                   `json:"note,omitempty"`    // 注释 |  | ||||||
| 	ReqID   string                   `json:"req_id,omitempty"`  // 唯一请求 Id,每次请求都会返回。用于反馈接口入参 |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func requestOpenAI2Tencent(request GeneralOpenAIRequest) *TencentChatRequest { |  | ||||||
| 	messages := make([]TencentMessage, 0, len(request.Messages)) |  | ||||||
| 	for i := 0; i < len(request.Messages); i++ { | 	for i := 0; i < len(request.Messages); i++ { | ||||||
| 		message := request.Messages[i] | 		message := request.Messages[i] | ||||||
| 		if message.Role == "system" { | 		if message.Role == "system" { | ||||||
| 			messages = append(messages, TencentMessage{ | 			messages = append(messages, Message{ | ||||||
| 				Role:    "user", | 				Role:    "user", | ||||||
| 				Content: message.StringContent(), | 				Content: message.StringContent(), | ||||||
| 			}) | 			}) | ||||||
| 			messages = append(messages, TencentMessage{ | 			messages = append(messages, Message{ | ||||||
| 				Role:    "assistant", | 				Role:    "assistant", | ||||||
| 				Content: "Okay", | 				Content: "Okay", | ||||||
| 			}) | 			}) | ||||||
| 			continue | 			continue | ||||||
| 		} | 		} | ||||||
| 		messages = append(messages, TencentMessage{ | 		messages = append(messages, Message{ | ||||||
| 			Content: message.StringContent(), | 			Content: message.StringContent(), | ||||||
| 			Role:    message.Role, | 			Role:    message.Role, | ||||||
| 		}) | 		}) | ||||||
| @@ -101,7 +45,7 @@ func requestOpenAI2Tencent(request GeneralOpenAIRequest) *TencentChatRequest { | |||||||
| 	if request.Stream { | 	if request.Stream { | ||||||
| 		stream = 1 | 		stream = 1 | ||||||
| 	} | 	} | ||||||
| 	return &TencentChatRequest{ | 	return &ChatRequest{ | ||||||
| 		Timestamp:   common.GetTimestamp(), | 		Timestamp:   common.GetTimestamp(), | ||||||
| 		Expired:     common.GetTimestamp() + 24*60*60, | 		Expired:     common.GetTimestamp() + 24*60*60, | ||||||
| 		QueryID:     common.GetUUID(), | 		QueryID:     common.GetUUID(), | ||||||
| @@ -112,16 +56,16 @@ func requestOpenAI2Tencent(request GeneralOpenAIRequest) *TencentChatRequest { | |||||||
| 	} | 	} | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func responseTencent2OpenAI(response *TencentChatResponse) *OpenAITextResponse { | func responseTencent2OpenAI(response *ChatResponse) *openai.TextResponse { | ||||||
| 	fullTextResponse := OpenAITextResponse{ | 	fullTextResponse := openai.TextResponse{ | ||||||
| 		Object:  "chat.completion", | 		Object:  "chat.completion", | ||||||
| 		Created: common.GetTimestamp(), | 		Created: common.GetTimestamp(), | ||||||
| 		Usage:   response.Usage, | 		Usage:   response.Usage, | ||||||
| 	} | 	} | ||||||
| 	if len(response.Choices) > 0 { | 	if len(response.Choices) > 0 { | ||||||
| 		choice := OpenAITextResponseChoice{ | 		choice := openai.TextResponseChoice{ | ||||||
| 			Index: 0, | 			Index: 0, | ||||||
| 			Message: Message{ | 			Message: openai.Message{ | ||||||
| 				Role:    "assistant", | 				Role:    "assistant", | ||||||
| 				Content: response.Choices[0].Messages.Content, | 				Content: response.Choices[0].Messages.Content, | ||||||
| 			}, | 			}, | ||||||
| @@ -132,24 +76,24 @@ func responseTencent2OpenAI(response *TencentChatResponse) *OpenAITextResponse { | |||||||
| 	return &fullTextResponse | 	return &fullTextResponse | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func streamResponseTencent2OpenAI(TencentResponse *TencentChatResponse) *ChatCompletionsStreamResponse { | func streamResponseTencent2OpenAI(TencentResponse *ChatResponse) *openai.ChatCompletionsStreamResponse { | ||||||
| 	response := ChatCompletionsStreamResponse{ | 	response := openai.ChatCompletionsStreamResponse{ | ||||||
| 		Object:  "chat.completion.chunk", | 		Object:  "chat.completion.chunk", | ||||||
| 		Created: common.GetTimestamp(), | 		Created: common.GetTimestamp(), | ||||||
| 		Model:   "tencent-hunyuan", | 		Model:   "tencent-hunyuan", | ||||||
| 	} | 	} | ||||||
| 	if len(TencentResponse.Choices) > 0 { | 	if len(TencentResponse.Choices) > 0 { | ||||||
| 		var choice ChatCompletionsStreamResponseChoice | 		var choice openai.ChatCompletionsStreamResponseChoice | ||||||
| 		choice.Delta.Content = TencentResponse.Choices[0].Delta.Content | 		choice.Delta.Content = TencentResponse.Choices[0].Delta.Content | ||||||
| 		if TencentResponse.Choices[0].FinishReason == "stop" { | 		if TencentResponse.Choices[0].FinishReason == "stop" { | ||||||
| 			choice.FinishReason = &stopFinishReason | 			choice.FinishReason = &constant.StopFinishReason | ||||||
| 		} | 		} | ||||||
| 		response.Choices = append(response.Choices, choice) | 		response.Choices = append(response.Choices, choice) | ||||||
| 	} | 	} | ||||||
| 	return &response | 	return &response | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func tencentStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, string) { | func StreamHandler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, string) { | ||||||
| 	var responseText string | 	var responseText string | ||||||
| 	scanner := bufio.NewScanner(resp.Body) | 	scanner := bufio.NewScanner(resp.Body) | ||||||
| 	scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { | 	scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { | ||||||
| @@ -180,11 +124,11 @@ func tencentStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWith | |||||||
| 		} | 		} | ||||||
| 		stopChan <- true | 		stopChan <- true | ||||||
| 	}() | 	}() | ||||||
| 	setEventStreamHeaders(c) | 	common.SetEventStreamHeaders(c) | ||||||
| 	c.Stream(func(w io.Writer) bool { | 	c.Stream(func(w io.Writer) bool { | ||||||
| 		select { | 		select { | ||||||
| 		case data := <-dataChan: | 		case data := <-dataChan: | ||||||
| 			var TencentResponse TencentChatResponse | 			var TencentResponse ChatResponse | ||||||
| 			err := json.Unmarshal([]byte(data), &TencentResponse) | 			err := json.Unmarshal([]byte(data), &TencentResponse) | ||||||
| 			if err != nil { | 			if err != nil { | ||||||
| 				common.SysError("error unmarshalling stream response: " + err.Error()) | 				common.SysError("error unmarshalling stream response: " + err.Error()) | ||||||
| @@ -208,28 +152,28 @@ func tencentStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWith | |||||||
| 	}) | 	}) | ||||||
| 	err := resp.Body.Close() | 	err := resp.Body.Close() | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), "" | 		return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), "" | ||||||
| 	} | 	} | ||||||
| 	return nil, responseText | 	return nil, responseText | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func tencentHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) { | func Handler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, *openai.Usage) { | ||||||
| 	var TencentResponse TencentChatResponse | 	var TencentResponse ChatResponse | ||||||
| 	responseBody, err := io.ReadAll(resp.Body) | 	responseBody, err := io.ReadAll(resp.Body) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil | 		return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil | ||||||
| 	} | 	} | ||||||
| 	err = resp.Body.Close() | 	err = resp.Body.Close() | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil | 		return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil | ||||||
| 	} | 	} | ||||||
| 	err = json.Unmarshal(responseBody, &TencentResponse) | 	err = json.Unmarshal(responseBody, &TencentResponse) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil | 		return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil | ||||||
| 	} | 	} | ||||||
| 	if TencentResponse.Error.Code != 0 { | 	if TencentResponse.Error.Code != 0 { | ||||||
| 		return &OpenAIErrorWithStatusCode{ | 		return &openai.ErrorWithStatusCode{ | ||||||
| 			OpenAIError: OpenAIError{ | 			Error: openai.Error{ | ||||||
| 				Message: TencentResponse.Error.Message, | 				Message: TencentResponse.Error.Message, | ||||||
| 				Code:    TencentResponse.Error.Code, | 				Code:    TencentResponse.Error.Code, | ||||||
| 			}, | 			}, | ||||||
| @@ -240,7 +184,7 @@ func tencentHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatus | |||||||
| 	fullTextResponse.Model = "hunyuan" | 	fullTextResponse.Model = "hunyuan" | ||||||
| 	jsonResponse, err := json.Marshal(fullTextResponse) | 	jsonResponse, err := json.Marshal(fullTextResponse) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil | 		return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil | ||||||
| 	} | 	} | ||||||
| 	c.Writer.Header().Set("Content-Type", "application/json") | 	c.Writer.Header().Set("Content-Type", "application/json") | ||||||
| 	c.Writer.WriteHeader(resp.StatusCode) | 	c.Writer.WriteHeader(resp.StatusCode) | ||||||
| @@ -248,7 +192,7 @@ func tencentHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatus | |||||||
| 	return nil, &fullTextResponse.Usage | 	return nil, &fullTextResponse.Usage | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func parseTencentConfig(config string) (appId int64, secretId string, secretKey string, err error) { | func ParseConfig(config string) (appId int64, secretId string, secretKey string, err error) { | ||||||
| 	parts := strings.Split(config, "|") | 	parts := strings.Split(config, "|") | ||||||
| 	if len(parts) != 3 { | 	if len(parts) != 3 { | ||||||
| 		err = errors.New("invalid tencent config") | 		err = errors.New("invalid tencent config") | ||||||
| @@ -260,7 +204,7 @@ func parseTencentConfig(config string) (appId int64, secretId string, secretKey | |||||||
| 	return | 	return | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func getTencentSign(req TencentChatRequest, secretKey string) string { | func GetSign(req ChatRequest, secretKey string) string { | ||||||
| 	params := make([]string, 0) | 	params := make([]string, 0) | ||||||
| 	params = append(params, "app_id="+strconv.FormatInt(req.AppId, 10)) | 	params = append(params, "app_id="+strconv.FormatInt(req.AppId, 10)) | ||||||
| 	params = append(params, "secret_id="+req.SecretId) | 	params = append(params, "secret_id="+req.SecretId) | ||||||
							
								
								
									
										63
									
								
								relay/channel/tencent/model.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										63
									
								
								relay/channel/tencent/model.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,63 @@ | |||||||
|  | package tencent | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"one-api/relay/channel/openai" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | type Message struct { | ||||||
|  | 	Role    string `json:"role"` | ||||||
|  | 	Content string `json:"content"` | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type ChatRequest struct { | ||||||
|  | 	AppId    int64  `json:"app_id"`    // 腾讯云账号的 APPID | ||||||
|  | 	SecretId string `json:"secret_id"` // 官网 SecretId | ||||||
|  | 	// Timestamp当前 UNIX 时间戳,单位为秒,可记录发起 API 请求的时间。 | ||||||
|  | 	// 例如1529223702,如果与当前时间相差过大,会引起签名过期错误 | ||||||
|  | 	Timestamp int64 `json:"timestamp"` | ||||||
|  | 	// Expired 签名的有效期,是一个符合 UNIX Epoch 时间戳规范的数值, | ||||||
|  | 	// 单位为秒;Expired 必须大于 Timestamp 且 Expired-Timestamp 小于90天 | ||||||
|  | 	Expired int64  `json:"expired"` | ||||||
|  | 	QueryID string `json:"query_id"` //请求 Id,用于问题排查 | ||||||
|  | 	// Temperature 较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定 | ||||||
|  | 	// 默认 1.0,取值区间为[0.0,2.0],非必要不建议使用,不合理的取值会影响效果 | ||||||
|  | 	// 建议该参数和 top_p 只设置1个,不要同时更改 top_p | ||||||
|  | 	Temperature float64 `json:"temperature"` | ||||||
|  | 	// TopP 影响输出文本的多样性,取值越大,生成文本的多样性越强 | ||||||
|  | 	// 默认1.0,取值区间为[0.0, 1.0],非必要不建议使用, 不合理的取值会影响效果 | ||||||
|  | 	// 建议该参数和 temperature 只设置1个,不要同时更改 | ||||||
|  | 	TopP float64 `json:"top_p"` | ||||||
|  | 	// Stream 0:同步,1:流式 (默认,协议:SSE) | ||||||
|  | 	// 同步请求超时:60s,如果内容较长建议使用流式 | ||||||
|  | 	Stream int `json:"stream"` | ||||||
|  | 	// Messages 会话内容, 长度最多为40, 按对话时间从旧到新在数组中排列 | ||||||
|  | 	// 输入 content 总数最大支持 3000 token。 | ||||||
|  | 	Messages []Message `json:"messages"` | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type Error struct { | ||||||
|  | 	Code    int    `json:"code"` | ||||||
|  | 	Message string `json:"message"` | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type Usage struct { | ||||||
|  | 	InputTokens  int `json:"input_tokens"` | ||||||
|  | 	OutputTokens int `json:"output_tokens"` | ||||||
|  | 	TotalTokens  int `json:"total_tokens"` | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type ResponseChoices struct { | ||||||
|  | 	FinishReason string  `json:"finish_reason,omitempty"` // 流式结束标志位,为 stop 则表示尾包 | ||||||
|  | 	Messages     Message `json:"messages,omitempty"`      // 内容,同步模式返回内容,流模式为 null 输出 content 内容总数最多支持 1024token。 | ||||||
|  | 	Delta        Message `json:"delta,omitempty"`         // 内容,流模式返回内容,同步模式为 null 输出 content 内容总数最多支持 1024token。 | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type ChatResponse struct { | ||||||
|  | 	Choices []ResponseChoices `json:"choices,omitempty"` // 结果 | ||||||
|  | 	Created string            `json:"created,omitempty"` // unix 时间戳的字符串 | ||||||
|  | 	Id      string            `json:"id,omitempty"`      // 会话 id | ||||||
|  | 	Usage   openai.Usage      `json:"usage,omitempty"`   // token 数量 | ||||||
|  | 	Error   Error             `json:"error,omitempty"`   // 错误信息 注意:此字段可能返回 null,表示取不到有效值 | ||||||
|  | 	Note    string            `json:"note,omitempty"`    // 注释 | ||||||
|  | 	ReqID   string            `json:"req_id,omitempty"`  // 唯一请求 Id,每次请求都会返回。用于反馈接口入参 | ||||||
|  | } | ||||||
| @@ -1,4 +1,4 @@ | |||||||
| package controller | package xunfei | ||||||
| 
 | 
 | ||||||
| import ( | import ( | ||||||
| 	"crypto/hmac" | 	"crypto/hmac" | ||||||
| @@ -12,6 +12,8 @@ import ( | |||||||
| 	"net/http" | 	"net/http" | ||||||
| 	"net/url" | 	"net/url" | ||||||
| 	"one-api/common" | 	"one-api/common" | ||||||
|  | 	"one-api/relay/channel/openai" | ||||||
|  | 	"one-api/relay/constant" | ||||||
| 	"strings" | 	"strings" | ||||||
| 	"time" | 	"time" | ||||||
| ) | ) | ||||||
| @@ -19,82 +21,26 @@ import ( | |||||||
| // https://console.xfyun.cn/services/cbm | // https://console.xfyun.cn/services/cbm | ||||||
| // https://www.xfyun.cn/doc/spark/Web.html | // https://www.xfyun.cn/doc/spark/Web.html | ||||||
| 
 | 
 | ||||||
| type XunfeiMessage struct { | func requestOpenAI2Xunfei(request openai.GeneralOpenAIRequest, xunfeiAppId string, domain string) *ChatRequest { | ||||||
| 	Role    string `json:"role"` | 	messages := make([]Message, 0, len(request.Messages)) | ||||||
| 	Content string `json:"content"` |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| type XunfeiChatRequest struct { |  | ||||||
| 	Header struct { |  | ||||||
| 		AppId string `json:"app_id"` |  | ||||||
| 	} `json:"header"` |  | ||||||
| 	Parameter struct { |  | ||||||
| 		Chat struct { |  | ||||||
| 			Domain      string  `json:"domain,omitempty"` |  | ||||||
| 			Temperature float64 `json:"temperature,omitempty"` |  | ||||||
| 			TopK        int     `json:"top_k,omitempty"` |  | ||||||
| 			MaxTokens   int     `json:"max_tokens,omitempty"` |  | ||||||
| 			Auditing    bool    `json:"auditing,omitempty"` |  | ||||||
| 		} `json:"chat"` |  | ||||||
| 	} `json:"parameter"` |  | ||||||
| 	Payload struct { |  | ||||||
| 		Message struct { |  | ||||||
| 			Text []XunfeiMessage `json:"text"` |  | ||||||
| 		} `json:"message"` |  | ||||||
| 	} `json:"payload"` |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| type XunfeiChatResponseTextItem struct { |  | ||||||
| 	Content string `json:"content"` |  | ||||||
| 	Role    string `json:"role"` |  | ||||||
| 	Index   int    `json:"index"` |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| type XunfeiChatResponse struct { |  | ||||||
| 	Header struct { |  | ||||||
| 		Code    int    `json:"code"` |  | ||||||
| 		Message string `json:"message"` |  | ||||||
| 		Sid     string `json:"sid"` |  | ||||||
| 		Status  int    `json:"status"` |  | ||||||
| 	} `json:"header"` |  | ||||||
| 	Payload struct { |  | ||||||
| 		Choices struct { |  | ||||||
| 			Status int                          `json:"status"` |  | ||||||
| 			Seq    int                          `json:"seq"` |  | ||||||
| 			Text   []XunfeiChatResponseTextItem `json:"text"` |  | ||||||
| 		} `json:"choices"` |  | ||||||
| 		Usage struct { |  | ||||||
| 			//Text struct { |  | ||||||
| 			//	QuestionTokens   string `json:"question_tokens"` |  | ||||||
| 			//	PromptTokens     string `json:"prompt_tokens"` |  | ||||||
| 			//	CompletionTokens string `json:"completion_tokens"` |  | ||||||
| 			//	TotalTokens      string `json:"total_tokens"` |  | ||||||
| 			//} `json:"text"` |  | ||||||
| 			Text Usage `json:"text"` |  | ||||||
| 		} `json:"usage"` |  | ||||||
| 	} `json:"payload"` |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func requestOpenAI2Xunfei(request GeneralOpenAIRequest, xunfeiAppId string, domain string) *XunfeiChatRequest { |  | ||||||
| 	messages := make([]XunfeiMessage, 0, len(request.Messages)) |  | ||||||
| 	for _, message := range request.Messages { | 	for _, message := range request.Messages { | ||||||
| 		if message.Role == "system" { | 		if message.Role == "system" { | ||||||
| 			messages = append(messages, XunfeiMessage{ | 			messages = append(messages, Message{ | ||||||
| 				Role:    "user", | 				Role:    "user", | ||||||
| 				Content: message.StringContent(), | 				Content: message.StringContent(), | ||||||
| 			}) | 			}) | ||||||
| 			messages = append(messages, XunfeiMessage{ | 			messages = append(messages, Message{ | ||||||
| 				Role:    "assistant", | 				Role:    "assistant", | ||||||
| 				Content: "Okay", | 				Content: "Okay", | ||||||
| 			}) | 			}) | ||||||
| 		} else { | 		} else { | ||||||
| 			messages = append(messages, XunfeiMessage{ | 			messages = append(messages, Message{ | ||||||
| 				Role:    message.Role, | 				Role:    message.Role, | ||||||
| 				Content: message.StringContent(), | 				Content: message.StringContent(), | ||||||
| 			}) | 			}) | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
| 	xunfeiRequest := XunfeiChatRequest{} | 	xunfeiRequest := ChatRequest{} | ||||||
| 	xunfeiRequest.Header.AppId = xunfeiAppId | 	xunfeiRequest.Header.AppId = xunfeiAppId | ||||||
| 	xunfeiRequest.Parameter.Chat.Domain = domain | 	xunfeiRequest.Parameter.Chat.Domain = domain | ||||||
| 	xunfeiRequest.Parameter.Chat.Temperature = request.Temperature | 	xunfeiRequest.Parameter.Chat.Temperature = request.Temperature | ||||||
| @@ -104,49 +50,49 @@ func requestOpenAI2Xunfei(request GeneralOpenAIRequest, xunfeiAppId string, doma | |||||||
| 	return &xunfeiRequest | 	return &xunfeiRequest | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func responseXunfei2OpenAI(response *XunfeiChatResponse) *OpenAITextResponse { | func responseXunfei2OpenAI(response *ChatResponse) *openai.TextResponse { | ||||||
| 	if len(response.Payload.Choices.Text) == 0 { | 	if len(response.Payload.Choices.Text) == 0 { | ||||||
| 		response.Payload.Choices.Text = []XunfeiChatResponseTextItem{ | 		response.Payload.Choices.Text = []ChatResponseTextItem{ | ||||||
| 			{ | 			{ | ||||||
| 				Content: "", | 				Content: "", | ||||||
| 			}, | 			}, | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
| 	choice := OpenAITextResponseChoice{ | 	choice := openai.TextResponseChoice{ | ||||||
| 		Index: 0, | 		Index: 0, | ||||||
| 		Message: Message{ | 		Message: openai.Message{ | ||||||
| 			Role:    "assistant", | 			Role:    "assistant", | ||||||
| 			Content: response.Payload.Choices.Text[0].Content, | 			Content: response.Payload.Choices.Text[0].Content, | ||||||
| 		}, | 		}, | ||||||
| 		FinishReason: stopFinishReason, | 		FinishReason: constant.StopFinishReason, | ||||||
| 	} | 	} | ||||||
| 	fullTextResponse := OpenAITextResponse{ | 	fullTextResponse := openai.TextResponse{ | ||||||
| 		Object:  "chat.completion", | 		Object:  "chat.completion", | ||||||
| 		Created: common.GetTimestamp(), | 		Created: common.GetTimestamp(), | ||||||
| 		Choices: []OpenAITextResponseChoice{choice}, | 		Choices: []openai.TextResponseChoice{choice}, | ||||||
| 		Usage:   response.Payload.Usage.Text, | 		Usage:   response.Payload.Usage.Text, | ||||||
| 	} | 	} | ||||||
| 	return &fullTextResponse | 	return &fullTextResponse | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func streamResponseXunfei2OpenAI(xunfeiResponse *XunfeiChatResponse) *ChatCompletionsStreamResponse { | func streamResponseXunfei2OpenAI(xunfeiResponse *ChatResponse) *openai.ChatCompletionsStreamResponse { | ||||||
| 	if len(xunfeiResponse.Payload.Choices.Text) == 0 { | 	if len(xunfeiResponse.Payload.Choices.Text) == 0 { | ||||||
| 		xunfeiResponse.Payload.Choices.Text = []XunfeiChatResponseTextItem{ | 		xunfeiResponse.Payload.Choices.Text = []ChatResponseTextItem{ | ||||||
| 			{ | 			{ | ||||||
| 				Content: "", | 				Content: "", | ||||||
| 			}, | 			}, | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
| 	var choice ChatCompletionsStreamResponseChoice | 	var choice openai.ChatCompletionsStreamResponseChoice | ||||||
| 	choice.Delta.Content = xunfeiResponse.Payload.Choices.Text[0].Content | 	choice.Delta.Content = xunfeiResponse.Payload.Choices.Text[0].Content | ||||||
| 	if xunfeiResponse.Payload.Choices.Status == 2 { | 	if xunfeiResponse.Payload.Choices.Status == 2 { | ||||||
| 		choice.FinishReason = &stopFinishReason | 		choice.FinishReason = &constant.StopFinishReason | ||||||
| 	} | 	} | ||||||
| 	response := ChatCompletionsStreamResponse{ | 	response := openai.ChatCompletionsStreamResponse{ | ||||||
| 		Object:  "chat.completion.chunk", | 		Object:  "chat.completion.chunk", | ||||||
| 		Created: common.GetTimestamp(), | 		Created: common.GetTimestamp(), | ||||||
| 		Model:   "SparkDesk", | 		Model:   "SparkDesk", | ||||||
| 		Choices: []ChatCompletionsStreamResponseChoice{choice}, | 		Choices: []openai.ChatCompletionsStreamResponseChoice{choice}, | ||||||
| 	} | 	} | ||||||
| 	return &response | 	return &response | ||||||
| } | } | ||||||
| @@ -177,14 +123,14 @@ func buildXunfeiAuthUrl(hostUrl string, apiKey, apiSecret string) string { | |||||||
| 	return callUrl | 	return callUrl | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func xunfeiStreamHandler(c *gin.Context, textRequest GeneralOpenAIRequest, appId string, apiSecret string, apiKey string) (*OpenAIErrorWithStatusCode, *Usage) { | func StreamHandler(c *gin.Context, textRequest openai.GeneralOpenAIRequest, appId string, apiSecret string, apiKey string) (*openai.ErrorWithStatusCode, *openai.Usage) { | ||||||
| 	domain, authUrl := getXunfeiAuthUrl(c, apiKey, apiSecret) | 	domain, authUrl := getXunfeiAuthUrl(c, apiKey, apiSecret) | ||||||
| 	dataChan, stopChan, err := xunfeiMakeRequest(textRequest, domain, authUrl, appId) | 	dataChan, stopChan, err := xunfeiMakeRequest(textRequest, domain, authUrl, appId) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return errorWrapper(err, "make xunfei request err", http.StatusInternalServerError), nil | 		return openai.ErrorWrapper(err, "make xunfei request err", http.StatusInternalServerError), nil | ||||||
| 	} | 	} | ||||||
| 	setEventStreamHeaders(c) | 	common.SetEventStreamHeaders(c) | ||||||
| 	var usage Usage | 	var usage openai.Usage | ||||||
| 	c.Stream(func(w io.Writer) bool { | 	c.Stream(func(w io.Writer) bool { | ||||||
| 		select { | 		select { | ||||||
| 		case xunfeiResponse := <-dataChan: | 		case xunfeiResponse := <-dataChan: | ||||||
| @@ -207,15 +153,15 @@ func xunfeiStreamHandler(c *gin.Context, textRequest GeneralOpenAIRequest, appId | |||||||
| 	return nil, &usage | 	return nil, &usage | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func xunfeiHandler(c *gin.Context, textRequest GeneralOpenAIRequest, appId string, apiSecret string, apiKey string) (*OpenAIErrorWithStatusCode, *Usage) { | func Handler(c *gin.Context, textRequest openai.GeneralOpenAIRequest, appId string, apiSecret string, apiKey string) (*openai.ErrorWithStatusCode, *openai.Usage) { | ||||||
| 	domain, authUrl := getXunfeiAuthUrl(c, apiKey, apiSecret) | 	domain, authUrl := getXunfeiAuthUrl(c, apiKey, apiSecret) | ||||||
| 	dataChan, stopChan, err := xunfeiMakeRequest(textRequest, domain, authUrl, appId) | 	dataChan, stopChan, err := xunfeiMakeRequest(textRequest, domain, authUrl, appId) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return errorWrapper(err, "make xunfei request err", http.StatusInternalServerError), nil | 		return openai.ErrorWrapper(err, "make xunfei request err", http.StatusInternalServerError), nil | ||||||
| 	} | 	} | ||||||
| 	var usage Usage | 	var usage openai.Usage | ||||||
| 	var content string | 	var content string | ||||||
| 	var xunfeiResponse XunfeiChatResponse | 	var xunfeiResponse ChatResponse | ||||||
| 	stop := false | 	stop := false | ||||||
| 	for !stop { | 	for !stop { | ||||||
| 		select { | 		select { | ||||||
| @@ -231,7 +177,7 @@ func xunfeiHandler(c *gin.Context, textRequest GeneralOpenAIRequest, appId strin | |||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
| 	if len(xunfeiResponse.Payload.Choices.Text) == 0 { | 	if len(xunfeiResponse.Payload.Choices.Text) == 0 { | ||||||
| 		xunfeiResponse.Payload.Choices.Text = []XunfeiChatResponseTextItem{ | 		xunfeiResponse.Payload.Choices.Text = []ChatResponseTextItem{ | ||||||
| 			{ | 			{ | ||||||
| 				Content: "", | 				Content: "", | ||||||
| 			}, | 			}, | ||||||
| @@ -242,14 +188,14 @@ func xunfeiHandler(c *gin.Context, textRequest GeneralOpenAIRequest, appId strin | |||||||
| 	response := responseXunfei2OpenAI(&xunfeiResponse) | 	response := responseXunfei2OpenAI(&xunfeiResponse) | ||||||
| 	jsonResponse, err := json.Marshal(response) | 	jsonResponse, err := json.Marshal(response) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil | 		return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil | ||||||
| 	} | 	} | ||||||
| 	c.Writer.Header().Set("Content-Type", "application/json") | 	c.Writer.Header().Set("Content-Type", "application/json") | ||||||
| 	_, _ = c.Writer.Write(jsonResponse) | 	_, _ = c.Writer.Write(jsonResponse) | ||||||
| 	return nil, &usage | 	return nil, &usage | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func xunfeiMakeRequest(textRequest GeneralOpenAIRequest, domain, authUrl, appId string) (chan XunfeiChatResponse, chan bool, error) { | func xunfeiMakeRequest(textRequest openai.GeneralOpenAIRequest, domain, authUrl, appId string) (chan ChatResponse, chan bool, error) { | ||||||
| 	d := websocket.Dialer{ | 	d := websocket.Dialer{ | ||||||
| 		HandshakeTimeout: 5 * time.Second, | 		HandshakeTimeout: 5 * time.Second, | ||||||
| 	} | 	} | ||||||
| @@ -263,7 +209,7 @@ func xunfeiMakeRequest(textRequest GeneralOpenAIRequest, domain, authUrl, appId | |||||||
| 		return nil, nil, err | 		return nil, nil, err | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	dataChan := make(chan XunfeiChatResponse) | 	dataChan := make(chan ChatResponse) | ||||||
| 	stopChan := make(chan bool) | 	stopChan := make(chan bool) | ||||||
| 	go func() { | 	go func() { | ||||||
| 		for { | 		for { | ||||||
| @@ -272,7 +218,7 @@ func xunfeiMakeRequest(textRequest GeneralOpenAIRequest, domain, authUrl, appId | |||||||
| 				common.SysError("error reading stream response: " + err.Error()) | 				common.SysError("error reading stream response: " + err.Error()) | ||||||
| 				break | 				break | ||||||
| 			} | 			} | ||||||
| 			var response XunfeiChatResponse | 			var response ChatResponse | ||||||
| 			err = json.Unmarshal(msg, &response) | 			err = json.Unmarshal(msg, &response) | ||||||
| 			if err != nil { | 			if err != nil { | ||||||
| 				common.SysError("error unmarshalling stream response: " + err.Error()) | 				common.SysError("error unmarshalling stream response: " + err.Error()) | ||||||
							
								
								
									
										61
									
								
								relay/channel/xunfei/model.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										61
									
								
								relay/channel/xunfei/model.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,61 @@ | |||||||
|  | package xunfei | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"one-api/relay/channel/openai" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | type Message struct { | ||||||
|  | 	Role    string `json:"role"` | ||||||
|  | 	Content string `json:"content"` | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type ChatRequest struct { | ||||||
|  | 	Header struct { | ||||||
|  | 		AppId string `json:"app_id"` | ||||||
|  | 	} `json:"header"` | ||||||
|  | 	Parameter struct { | ||||||
|  | 		Chat struct { | ||||||
|  | 			Domain      string  `json:"domain,omitempty"` | ||||||
|  | 			Temperature float64 `json:"temperature,omitempty"` | ||||||
|  | 			TopK        int     `json:"top_k,omitempty"` | ||||||
|  | 			MaxTokens   int     `json:"max_tokens,omitempty"` | ||||||
|  | 			Auditing    bool    `json:"auditing,omitempty"` | ||||||
|  | 		} `json:"chat"` | ||||||
|  | 	} `json:"parameter"` | ||||||
|  | 	Payload struct { | ||||||
|  | 		Message struct { | ||||||
|  | 			Text []Message `json:"text"` | ||||||
|  | 		} `json:"message"` | ||||||
|  | 	} `json:"payload"` | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type ChatResponseTextItem struct { | ||||||
|  | 	Content string `json:"content"` | ||||||
|  | 	Role    string `json:"role"` | ||||||
|  | 	Index   int    `json:"index"` | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type ChatResponse struct { | ||||||
|  | 	Header struct { | ||||||
|  | 		Code    int    `json:"code"` | ||||||
|  | 		Message string `json:"message"` | ||||||
|  | 		Sid     string `json:"sid"` | ||||||
|  | 		Status  int    `json:"status"` | ||||||
|  | 	} `json:"header"` | ||||||
|  | 	Payload struct { | ||||||
|  | 		Choices struct { | ||||||
|  | 			Status int                    `json:"status"` | ||||||
|  | 			Seq    int                    `json:"seq"` | ||||||
|  | 			Text   []ChatResponseTextItem `json:"text"` | ||||||
|  | 		} `json:"choices"` | ||||||
|  | 		Usage struct { | ||||||
|  | 			//Text struct { | ||||||
|  | 			//	QuestionTokens   string `json:"question_tokens"` | ||||||
|  | 			//	PromptTokens     string `json:"prompt_tokens"` | ||||||
|  | 			//	CompletionTokens string `json:"completion_tokens"` | ||||||
|  | 			//	TotalTokens      string `json:"total_tokens"` | ||||||
|  | 			//} `json:"text"` | ||||||
|  | 			Text openai.Usage `json:"text"` | ||||||
|  | 		} `json:"usage"` | ||||||
|  | 	} `json:"payload"` | ||||||
|  | } | ||||||
| @@ -1,4 +1,4 @@ | |||||||
| package controller | package zhipu | ||||||
| 
 | 
 | ||||||
| import ( | import ( | ||||||
| 	"bufio" | 	"bufio" | ||||||
| @@ -8,6 +8,8 @@ import ( | |||||||
| 	"io" | 	"io" | ||||||
| 	"net/http" | 	"net/http" | ||||||
| 	"one-api/common" | 	"one-api/common" | ||||||
|  | 	"one-api/relay/channel/openai" | ||||||
|  | 	"one-api/relay/constant" | ||||||
| 	"strings" | 	"strings" | ||||||
| 	"sync" | 	"sync" | ||||||
| 	"time" | 	"time" | ||||||
| @@ -18,53 +20,13 @@ import ( | |||||||
| // https://open.bigmodel.cn/api/paas/v3/model-api/chatglm_std/invoke | // https://open.bigmodel.cn/api/paas/v3/model-api/chatglm_std/invoke | ||||||
| // https://open.bigmodel.cn/api/paas/v3/model-api/chatglm_std/sse-invoke | // https://open.bigmodel.cn/api/paas/v3/model-api/chatglm_std/sse-invoke | ||||||
| 
 | 
 | ||||||
| type ZhipuMessage struct { |  | ||||||
| 	Role    string `json:"role"` |  | ||||||
| 	Content string `json:"content"` |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| type ZhipuRequest struct { |  | ||||||
| 	Prompt      []ZhipuMessage `json:"prompt"` |  | ||||||
| 	Temperature float64        `json:"temperature,omitempty"` |  | ||||||
| 	TopP        float64        `json:"top_p,omitempty"` |  | ||||||
| 	RequestId   string         `json:"request_id,omitempty"` |  | ||||||
| 	Incremental bool           `json:"incremental,omitempty"` |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| type ZhipuResponseData struct { |  | ||||||
| 	TaskId     string         `json:"task_id"` |  | ||||||
| 	RequestId  string         `json:"request_id"` |  | ||||||
| 	TaskStatus string         `json:"task_status"` |  | ||||||
| 	Choices    []ZhipuMessage `json:"choices"` |  | ||||||
| 	Usage      `json:"usage"` |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| type ZhipuResponse struct { |  | ||||||
| 	Code    int               `json:"code"` |  | ||||||
| 	Msg     string            `json:"msg"` |  | ||||||
| 	Success bool              `json:"success"` |  | ||||||
| 	Data    ZhipuResponseData `json:"data"` |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| type ZhipuStreamMetaResponse struct { |  | ||||||
| 	RequestId  string `json:"request_id"` |  | ||||||
| 	TaskId     string `json:"task_id"` |  | ||||||
| 	TaskStatus string `json:"task_status"` |  | ||||||
| 	Usage      `json:"usage"` |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| type zhipuTokenData struct { |  | ||||||
| 	Token      string |  | ||||||
| 	ExpiryTime time.Time |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| var zhipuTokens sync.Map | var zhipuTokens sync.Map | ||||||
| var expSeconds int64 = 24 * 3600 | var expSeconds int64 = 24 * 3600 | ||||||
| 
 | 
 | ||||||
| func getZhipuToken(apikey string) string { | func GetToken(apikey string) string { | ||||||
| 	data, ok := zhipuTokens.Load(apikey) | 	data, ok := zhipuTokens.Load(apikey) | ||||||
| 	if ok { | 	if ok { | ||||||
| 		tokenData := data.(zhipuTokenData) | 		tokenData := data.(tokenData) | ||||||
| 		if time.Now().Before(tokenData.ExpiryTime) { | 		if time.Now().Before(tokenData.ExpiryTime) { | ||||||
| 			return tokenData.Token | 			return tokenData.Token | ||||||
| 		} | 		} | ||||||
| @@ -100,7 +62,7 @@ func getZhipuToken(apikey string) string { | |||||||
| 		return "" | 		return "" | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	zhipuTokens.Store(apikey, zhipuTokenData{ | 	zhipuTokens.Store(apikey, tokenData{ | ||||||
| 		Token:      tokenString, | 		Token:      tokenString, | ||||||
| 		ExpiryTime: expiryTime, | 		ExpiryTime: expiryTime, | ||||||
| 	}) | 	}) | ||||||
| @@ -108,26 +70,26 @@ func getZhipuToken(apikey string) string { | |||||||
| 	return tokenString | 	return tokenString | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func requestOpenAI2Zhipu(request GeneralOpenAIRequest) *ZhipuRequest { | func ConvertRequest(request openai.GeneralOpenAIRequest) *Request { | ||||||
| 	messages := make([]ZhipuMessage, 0, len(request.Messages)) | 	messages := make([]Message, 0, len(request.Messages)) | ||||||
| 	for _, message := range request.Messages { | 	for _, message := range request.Messages { | ||||||
| 		if message.Role == "system" { | 		if message.Role == "system" { | ||||||
| 			messages = append(messages, ZhipuMessage{ | 			messages = append(messages, Message{ | ||||||
| 				Role:    "system", | 				Role:    "system", | ||||||
| 				Content: message.StringContent(), | 				Content: message.StringContent(), | ||||||
| 			}) | 			}) | ||||||
| 			messages = append(messages, ZhipuMessage{ | 			messages = append(messages, Message{ | ||||||
| 				Role:    "user", | 				Role:    "user", | ||||||
| 				Content: "Okay", | 				Content: "Okay", | ||||||
| 			}) | 			}) | ||||||
| 		} else { | 		} else { | ||||||
| 			messages = append(messages, ZhipuMessage{ | 			messages = append(messages, Message{ | ||||||
| 				Role:    message.Role, | 				Role:    message.Role, | ||||||
| 				Content: message.StringContent(), | 				Content: message.StringContent(), | ||||||
| 			}) | 			}) | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
| 	return &ZhipuRequest{ | 	return &Request{ | ||||||
| 		Prompt:      messages, | 		Prompt:      messages, | ||||||
| 		Temperature: request.Temperature, | 		Temperature: request.Temperature, | ||||||
| 		TopP:        request.TopP, | 		TopP:        request.TopP, | ||||||
| @@ -135,18 +97,18 @@ func requestOpenAI2Zhipu(request GeneralOpenAIRequest) *ZhipuRequest { | |||||||
| 	} | 	} | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func responseZhipu2OpenAI(response *ZhipuResponse) *OpenAITextResponse { | func responseZhipu2OpenAI(response *Response) *openai.TextResponse { | ||||||
| 	fullTextResponse := OpenAITextResponse{ | 	fullTextResponse := openai.TextResponse{ | ||||||
| 		Id:      response.Data.TaskId, | 		Id:      response.Data.TaskId, | ||||||
| 		Object:  "chat.completion", | 		Object:  "chat.completion", | ||||||
| 		Created: common.GetTimestamp(), | 		Created: common.GetTimestamp(), | ||||||
| 		Choices: make([]OpenAITextResponseChoice, 0, len(response.Data.Choices)), | 		Choices: make([]openai.TextResponseChoice, 0, len(response.Data.Choices)), | ||||||
| 		Usage:   response.Data.Usage, | 		Usage:   response.Data.Usage, | ||||||
| 	} | 	} | ||||||
| 	for i, choice := range response.Data.Choices { | 	for i, choice := range response.Data.Choices { | ||||||
| 		openaiChoice := OpenAITextResponseChoice{ | 		openaiChoice := openai.TextResponseChoice{ | ||||||
| 			Index: i, | 			Index: i, | ||||||
| 			Message: Message{ | 			Message: openai.Message{ | ||||||
| 				Role:    choice.Role, | 				Role:    choice.Role, | ||||||
| 				Content: strings.Trim(choice.Content, "\""), | 				Content: strings.Trim(choice.Content, "\""), | ||||||
| 			}, | 			}, | ||||||
| @@ -160,34 +122,34 @@ func responseZhipu2OpenAI(response *ZhipuResponse) *OpenAITextResponse { | |||||||
| 	return &fullTextResponse | 	return &fullTextResponse | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func streamResponseZhipu2OpenAI(zhipuResponse string) *ChatCompletionsStreamResponse { | func streamResponseZhipu2OpenAI(zhipuResponse string) *openai.ChatCompletionsStreamResponse { | ||||||
| 	var choice ChatCompletionsStreamResponseChoice | 	var choice openai.ChatCompletionsStreamResponseChoice | ||||||
| 	choice.Delta.Content = zhipuResponse | 	choice.Delta.Content = zhipuResponse | ||||||
| 	response := ChatCompletionsStreamResponse{ | 	response := openai.ChatCompletionsStreamResponse{ | ||||||
| 		Object:  "chat.completion.chunk", | 		Object:  "chat.completion.chunk", | ||||||
| 		Created: common.GetTimestamp(), | 		Created: common.GetTimestamp(), | ||||||
| 		Model:   "chatglm", | 		Model:   "chatglm", | ||||||
| 		Choices: []ChatCompletionsStreamResponseChoice{choice}, | 		Choices: []openai.ChatCompletionsStreamResponseChoice{choice}, | ||||||
| 	} | 	} | ||||||
| 	return &response | 	return &response | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func streamMetaResponseZhipu2OpenAI(zhipuResponse *ZhipuStreamMetaResponse) (*ChatCompletionsStreamResponse, *Usage) { | func streamMetaResponseZhipu2OpenAI(zhipuResponse *StreamMetaResponse) (*openai.ChatCompletionsStreamResponse, *openai.Usage) { | ||||||
| 	var choice ChatCompletionsStreamResponseChoice | 	var choice openai.ChatCompletionsStreamResponseChoice | ||||||
| 	choice.Delta.Content = "" | 	choice.Delta.Content = "" | ||||||
| 	choice.FinishReason = &stopFinishReason | 	choice.FinishReason = &constant.StopFinishReason | ||||||
| 	response := ChatCompletionsStreamResponse{ | 	response := openai.ChatCompletionsStreamResponse{ | ||||||
| 		Id:      zhipuResponse.RequestId, | 		Id:      zhipuResponse.RequestId, | ||||||
| 		Object:  "chat.completion.chunk", | 		Object:  "chat.completion.chunk", | ||||||
| 		Created: common.GetTimestamp(), | 		Created: common.GetTimestamp(), | ||||||
| 		Model:   "chatglm", | 		Model:   "chatglm", | ||||||
| 		Choices: []ChatCompletionsStreamResponseChoice{choice}, | 		Choices: []openai.ChatCompletionsStreamResponseChoice{choice}, | ||||||
| 	} | 	} | ||||||
| 	return &response, &zhipuResponse.Usage | 	return &response, &zhipuResponse.Usage | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func zhipuStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) { | func StreamHandler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, *openai.Usage) { | ||||||
| 	var usage *Usage | 	var usage *openai.Usage | ||||||
| 	scanner := bufio.NewScanner(resp.Body) | 	scanner := bufio.NewScanner(resp.Body) | ||||||
| 	scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { | 	scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { | ||||||
| 		if atEOF && len(data) == 0 { | 		if atEOF && len(data) == 0 { | ||||||
| @@ -224,7 +186,7 @@ func zhipuStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithSt | |||||||
| 		} | 		} | ||||||
| 		stopChan <- true | 		stopChan <- true | ||||||
| 	}() | 	}() | ||||||
| 	setEventStreamHeaders(c) | 	common.SetEventStreamHeaders(c) | ||||||
| 	c.Stream(func(w io.Writer) bool { | 	c.Stream(func(w io.Writer) bool { | ||||||
| 		select { | 		select { | ||||||
| 		case data := <-dataChan: | 		case data := <-dataChan: | ||||||
| @@ -237,7 +199,7 @@ func zhipuStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithSt | |||||||
| 			c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)}) | 			c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)}) | ||||||
| 			return true | 			return true | ||||||
| 		case data := <-metaChan: | 		case data := <-metaChan: | ||||||
| 			var zhipuResponse ZhipuStreamMetaResponse | 			var zhipuResponse StreamMetaResponse | ||||||
| 			err := json.Unmarshal([]byte(data), &zhipuResponse) | 			err := json.Unmarshal([]byte(data), &zhipuResponse) | ||||||
| 			if err != nil { | 			if err != nil { | ||||||
| 				common.SysError("error unmarshalling stream response: " + err.Error()) | 				common.SysError("error unmarshalling stream response: " + err.Error()) | ||||||
| @@ -259,28 +221,28 @@ func zhipuStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithSt | |||||||
| 	}) | 	}) | ||||||
| 	err := resp.Body.Close() | 	err := resp.Body.Close() | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil | 		return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil | ||||||
| 	} | 	} | ||||||
| 	return nil, usage | 	return nil, usage | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func zhipuHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) { | func Handler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, *openai.Usage) { | ||||||
| 	var zhipuResponse ZhipuResponse | 	var zhipuResponse Response | ||||||
| 	responseBody, err := io.ReadAll(resp.Body) | 	responseBody, err := io.ReadAll(resp.Body) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil | 		return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil | ||||||
| 	} | 	} | ||||||
| 	err = resp.Body.Close() | 	err = resp.Body.Close() | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil | 		return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil | ||||||
| 	} | 	} | ||||||
| 	err = json.Unmarshal(responseBody, &zhipuResponse) | 	err = json.Unmarshal(responseBody, &zhipuResponse) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil | 		return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil | ||||||
| 	} | 	} | ||||||
| 	if !zhipuResponse.Success { | 	if !zhipuResponse.Success { | ||||||
| 		return &OpenAIErrorWithStatusCode{ | 		return &openai.ErrorWithStatusCode{ | ||||||
| 			OpenAIError: OpenAIError{ | 			Error: openai.Error{ | ||||||
| 				Message: zhipuResponse.Msg, | 				Message: zhipuResponse.Msg, | ||||||
| 				Type:    "zhipu_error", | 				Type:    "zhipu_error", | ||||||
| 				Param:   "", | 				Param:   "", | ||||||
| @@ -293,7 +255,7 @@ func zhipuHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCo | |||||||
| 	fullTextResponse.Model = "chatglm" | 	fullTextResponse.Model = "chatglm" | ||||||
| 	jsonResponse, err := json.Marshal(fullTextResponse) | 	jsonResponse, err := json.Marshal(fullTextResponse) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil | 		return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil | ||||||
| 	} | 	} | ||||||
| 	c.Writer.Header().Set("Content-Type", "application/json") | 	c.Writer.Header().Set("Content-Type", "application/json") | ||||||
| 	c.Writer.WriteHeader(resp.StatusCode) | 	c.Writer.WriteHeader(resp.StatusCode) | ||||||
							
								
								
									
										46
									
								
								relay/channel/zhipu/model.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										46
									
								
								relay/channel/zhipu/model.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,46 @@ | |||||||
|  | package zhipu | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"one-api/relay/channel/openai" | ||||||
|  | 	"time" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | type Message struct { | ||||||
|  | 	Role    string `json:"role"` | ||||||
|  | 	Content string `json:"content"` | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type Request struct { | ||||||
|  | 	Prompt      []Message `json:"prompt"` | ||||||
|  | 	Temperature float64   `json:"temperature,omitempty"` | ||||||
|  | 	TopP        float64   `json:"top_p,omitempty"` | ||||||
|  | 	RequestId   string    `json:"request_id,omitempty"` | ||||||
|  | 	Incremental bool      `json:"incremental,omitempty"` | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type ResponseData struct { | ||||||
|  | 	TaskId       string    `json:"task_id"` | ||||||
|  | 	RequestId    string    `json:"request_id"` | ||||||
|  | 	TaskStatus   string    `json:"task_status"` | ||||||
|  | 	Choices      []Message `json:"choices"` | ||||||
|  | 	openai.Usage `json:"usage"` | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type Response struct { | ||||||
|  | 	Code    int          `json:"code"` | ||||||
|  | 	Msg     string       `json:"msg"` | ||||||
|  | 	Success bool         `json:"success"` | ||||||
|  | 	Data    ResponseData `json:"data"` | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type StreamMetaResponse struct { | ||||||
|  | 	RequestId    string `json:"request_id"` | ||||||
|  | 	TaskId       string `json:"task_id"` | ||||||
|  | 	TaskStatus   string `json:"task_status"` | ||||||
|  | 	openai.Usage `json:"usage"` | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type tokenData struct { | ||||||
|  | 	Token      string | ||||||
|  | 	ExpiryTime time.Time | ||||||
|  | } | ||||||
							
								
								
									
										16
									
								
								relay/constant/main.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										16
									
								
								relay/constant/main.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,16 @@ | |||||||
|  | package constant | ||||||
|  |  | ||||||
|  | const ( | ||||||
|  | 	RelayModeUnknown = iota | ||||||
|  | 	RelayModeChatCompletions | ||||||
|  | 	RelayModeCompletions | ||||||
|  | 	RelayModeEmbeddings | ||||||
|  | 	RelayModeModerations | ||||||
|  | 	RelayModeImagesGenerations | ||||||
|  | 	RelayModeEdits | ||||||
|  | 	RelayModeAudioSpeech | ||||||
|  | 	RelayModeAudioTranscription | ||||||
|  | 	RelayModeAudioTranslation | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | var StopFinishReason = "stop" | ||||||
| @@ -12,10 +12,13 @@ import ( | |||||||
| 	"net/http" | 	"net/http" | ||||||
| 	"one-api/common" | 	"one-api/common" | ||||||
| 	"one-api/model" | 	"one-api/model" | ||||||
|  | 	"one-api/relay/channel/openai" | ||||||
|  | 	"one-api/relay/constant" | ||||||
|  | 	"one-api/relay/util" | ||||||
| 	"strings" | 	"strings" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { | func RelayAudioHelper(c *gin.Context, relayMode int) *openai.ErrorWithStatusCode { | ||||||
| 	audioModel := "whisper-1" | 	audioModel := "whisper-1" | ||||||
| 
 | 
 | ||||||
| 	tokenId := c.GetInt("token_id") | 	tokenId := c.GetInt("token_id") | ||||||
| @@ -25,18 +28,18 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode | |||||||
| 	group := c.GetString("group") | 	group := c.GetString("group") | ||||||
| 	tokenName := c.GetString("token_name") | 	tokenName := c.GetString("token_name") | ||||||
| 
 | 
 | ||||||
| 	var ttsRequest TextToSpeechRequest | 	var ttsRequest openai.TextToSpeechRequest | ||||||
| 	if relayMode == RelayModeAudioSpeech { | 	if relayMode == constant.RelayModeAudioSpeech { | ||||||
| 		// Read JSON | 		// Read JSON | ||||||
| 		err := common.UnmarshalBodyReusable(c, &ttsRequest) | 		err := common.UnmarshalBodyReusable(c, &ttsRequest) | ||||||
| 		// Check if JSON is valid | 		// Check if JSON is valid | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			return errorWrapper(err, "invalid_json", http.StatusBadRequest) | 			return openai.ErrorWrapper(err, "invalid_json", http.StatusBadRequest) | ||||||
| 		} | 		} | ||||||
| 		audioModel = ttsRequest.Model | 		audioModel = ttsRequest.Model | ||||||
| 		// Check if text is too long 4096 | 		// Check if text is too long 4096 | ||||||
| 		if len(ttsRequest.Input) > 4096 { | 		if len(ttsRequest.Input) > 4096 { | ||||||
| 			return errorWrapper(errors.New("input is too long (over 4096 characters)"), "text_too_long", http.StatusBadRequest) | 			return openai.ErrorWrapper(errors.New("input is too long (over 4096 characters)"), "text_too_long", http.StatusBadRequest) | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| @@ -46,7 +49,7 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode | |||||||
| 	var quota int | 	var quota int | ||||||
| 	var preConsumedQuota int | 	var preConsumedQuota int | ||||||
| 	switch relayMode { | 	switch relayMode { | ||||||
| 	case RelayModeAudioSpeech: | 	case constant.RelayModeAudioSpeech: | ||||||
| 		preConsumedQuota = int(float64(len(ttsRequest.Input)) * ratio) | 		preConsumedQuota = int(float64(len(ttsRequest.Input)) * ratio) | ||||||
| 		quota = preConsumedQuota | 		quota = preConsumedQuota | ||||||
| 	default: | 	default: | ||||||
| @@ -54,16 +57,16 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode | |||||||
| 	} | 	} | ||||||
| 	userQuota, err := model.CacheGetUserQuota(userId) | 	userQuota, err := model.CacheGetUserQuota(userId) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return errorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError) | 		return openai.ErrorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	// Check if user quota is enough | 	// Check if user quota is enough | ||||||
| 	if userQuota-preConsumedQuota < 0 { | 	if userQuota-preConsumedQuota < 0 { | ||||||
| 		return errorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden) | 		return openai.ErrorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden) | ||||||
| 	} | 	} | ||||||
| 	err = model.CacheDecreaseUserQuota(userId, preConsumedQuota) | 	err = model.CacheDecreaseUserQuota(userId, preConsumedQuota) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return errorWrapper(err, "decrease_user_quota_failed", http.StatusInternalServerError) | 		return openai.ErrorWrapper(err, "decrease_user_quota_failed", http.StatusInternalServerError) | ||||||
| 	} | 	} | ||||||
| 	if userQuota > 100*preConsumedQuota { | 	if userQuota > 100*preConsumedQuota { | ||||||
| 		// in this case, we do not pre-consume quota | 		// in this case, we do not pre-consume quota | ||||||
| @@ -73,7 +76,7 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode | |||||||
| 	if preConsumedQuota > 0 { | 	if preConsumedQuota > 0 { | ||||||
| 		err := model.PreConsumeTokenQuota(tokenId, preConsumedQuota) | 		err := model.PreConsumeTokenQuota(tokenId, preConsumedQuota) | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			return errorWrapper(err, "pre_consume_token_quota_failed", http.StatusForbidden) | 			return openai.ErrorWrapper(err, "pre_consume_token_quota_failed", http.StatusForbidden) | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| @@ -83,7 +86,7 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode | |||||||
| 		modelMap := make(map[string]string) | 		modelMap := make(map[string]string) | ||||||
| 		err := json.Unmarshal([]byte(modelMapping), &modelMap) | 		err := json.Unmarshal([]byte(modelMapping), &modelMap) | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			return errorWrapper(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError) | 			return openai.ErrorWrapper(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError) | ||||||
| 		} | 		} | ||||||
| 		if modelMap[audioModel] != "" { | 		if modelMap[audioModel] != "" { | ||||||
| 			audioModel = modelMap[audioModel] | 			audioModel = modelMap[audioModel] | ||||||
| @@ -96,27 +99,27 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode | |||||||
| 		baseURL = c.GetString("base_url") | 		baseURL = c.GetString("base_url") | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	fullRequestURL := getFullRequestURL(baseURL, requestURL, channelType) | 	fullRequestURL := util.GetFullRequestURL(baseURL, requestURL, channelType) | ||||||
| 	if relayMode == RelayModeAudioTranscription && channelType == common.ChannelTypeAzure { | 	if relayMode == constant.RelayModeAudioTranscription && channelType == common.ChannelTypeAzure { | ||||||
| 		// https://learn.microsoft.com/en-us/azure/ai-services/openai/whisper-quickstart?tabs=command-line#rest-api | 		// https://learn.microsoft.com/en-us/azure/ai-services/openai/whisper-quickstart?tabs=command-line#rest-api | ||||||
| 		apiVersion := GetAPIVersion(c) | 		apiVersion := util.GetAPIVersion(c) | ||||||
| 		fullRequestURL = fmt.Sprintf("%s/openai/deployments/%s/audio/transcriptions?api-version=%s", baseURL, audioModel, apiVersion) | 		fullRequestURL = fmt.Sprintf("%s/openai/deployments/%s/audio/transcriptions?api-version=%s", baseURL, audioModel, apiVersion) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	requestBody := &bytes.Buffer{} | 	requestBody := &bytes.Buffer{} | ||||||
| 	_, err = io.Copy(requestBody, c.Request.Body) | 	_, err = io.Copy(requestBody, c.Request.Body) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return errorWrapper(err, "new_request_body_failed", http.StatusInternalServerError) | 		return openai.ErrorWrapper(err, "new_request_body_failed", http.StatusInternalServerError) | ||||||
| 	} | 	} | ||||||
| 	c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody.Bytes())) | 	c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody.Bytes())) | ||||||
| 	responseFormat := c.DefaultPostForm("response_format", "json") | 	responseFormat := c.DefaultPostForm("response_format", "json") | ||||||
| 
 | 
 | ||||||
| 	req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody) | 	req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return errorWrapper(err, "new_request_failed", http.StatusInternalServerError) | 		return openai.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	if relayMode == RelayModeAudioTranscription && channelType == common.ChannelTypeAzure { | 	if relayMode == constant.RelayModeAudioTranscription && channelType == common.ChannelTypeAzure { | ||||||
| 		// https://learn.microsoft.com/en-us/azure/ai-services/openai/whisper-quickstart?tabs=command-line#rest-api | 		// https://learn.microsoft.com/en-us/azure/ai-services/openai/whisper-quickstart?tabs=command-line#rest-api | ||||||
| 		apiKey := c.Request.Header.Get("Authorization") | 		apiKey := c.Request.Header.Get("Authorization") | ||||||
| 		apiKey = strings.TrimPrefix(apiKey, "Bearer ") | 		apiKey = strings.TrimPrefix(apiKey, "Bearer ") | ||||||
| @@ -128,34 +131,34 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode | |||||||
| 	req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type")) | 	req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type")) | ||||||
| 	req.Header.Set("Accept", c.Request.Header.Get("Accept")) | 	req.Header.Set("Accept", c.Request.Header.Get("Accept")) | ||||||
| 
 | 
 | ||||||
| 	resp, err := httpClient.Do(req) | 	resp, err := util.HTTPClient.Do(req) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return errorWrapper(err, "do_request_failed", http.StatusInternalServerError) | 		return openai.ErrorWrapper(err, "do_request_failed", http.StatusInternalServerError) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	err = req.Body.Close() | 	err = req.Body.Close() | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return errorWrapper(err, "close_request_body_failed", http.StatusInternalServerError) | 		return openai.ErrorWrapper(err, "close_request_body_failed", http.StatusInternalServerError) | ||||||
| 	} | 	} | ||||||
| 	err = c.Request.Body.Close() | 	err = c.Request.Body.Close() | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return errorWrapper(err, "close_request_body_failed", http.StatusInternalServerError) | 		return openai.ErrorWrapper(err, "close_request_body_failed", http.StatusInternalServerError) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	if relayMode != RelayModeAudioSpeech { | 	if relayMode != constant.RelayModeAudioSpeech { | ||||||
| 		responseBody, err := io.ReadAll(resp.Body) | 		responseBody, err := io.ReadAll(resp.Body) | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError) | 			return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError) | ||||||
| 		} | 		} | ||||||
| 		err = resp.Body.Close() | 		err = resp.Body.Close() | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError) | 			return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError) | ||||||
| 		} | 		} | ||||||
| 
 | 
 | ||||||
| 		var openAIErr TextResponse | 		var openAIErr openai.SlimTextResponse | ||||||
| 		if err = json.Unmarshal(responseBody, &openAIErr); err == nil { | 		if err = json.Unmarshal(responseBody, &openAIErr); err == nil { | ||||||
| 			if openAIErr.Error.Message != "" { | 			if openAIErr.Error.Message != "" { | ||||||
| 				return errorWrapper(fmt.Errorf("type %s, code %v, message %s", openAIErr.Error.Type, openAIErr.Error.Code, openAIErr.Error.Message), "request_error", http.StatusInternalServerError) | 				return openai.ErrorWrapper(fmt.Errorf("type %s, code %v, message %s", openAIErr.Error.Type, openAIErr.Error.Code, openAIErr.Error.Message), "request_error", http.StatusInternalServerError) | ||||||
| 			} | 			} | ||||||
| 		} | 		} | ||||||
| 
 | 
 | ||||||
| @@ -172,12 +175,12 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode | |||||||
| 		case "vtt": | 		case "vtt": | ||||||
| 			text, err = getTextFromVTT(responseBody) | 			text, err = getTextFromVTT(responseBody) | ||||||
| 		default: | 		default: | ||||||
| 			return errorWrapper(errors.New("unexpected_response_format"), "unexpected_response_format", http.StatusInternalServerError) | 			return openai.ErrorWrapper(errors.New("unexpected_response_format"), "unexpected_response_format", http.StatusInternalServerError) | ||||||
| 		} | 		} | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			return errorWrapper(err, "get_text_from_body_err", http.StatusInternalServerError) | 			return openai.ErrorWrapper(err, "get_text_from_body_err", http.StatusInternalServerError) | ||||||
| 		} | 		} | ||||||
| 		quota = countTokenText(text, audioModel) | 		quota = openai.CountTokenText(text, audioModel) | ||||||
| 		resp.Body = io.NopCloser(bytes.NewBuffer(responseBody)) | 		resp.Body = io.NopCloser(bytes.NewBuffer(responseBody)) | ||||||
| 	} | 	} | ||||||
| 	if resp.StatusCode != http.StatusOK { | 	if resp.StatusCode != http.StatusOK { | ||||||
| @@ -193,11 +196,11 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode | |||||||
| 				}() | 				}() | ||||||
| 			}(c.Request.Context()) | 			}(c.Request.Context()) | ||||||
| 		} | 		} | ||||||
| 		return relayErrorHandler(resp) | 		return util.RelayErrorHandler(resp) | ||||||
| 	} | 	} | ||||||
| 	quotaDelta := quota - preConsumedQuota | 	quotaDelta := quota - preConsumedQuota | ||||||
| 	defer func(ctx context.Context) { | 	defer func(ctx context.Context) { | ||||||
| 		go postConsumeQuota(ctx, tokenId, quotaDelta, quota, userId, channelId, modelRatio, groupRatio, audioModel, tokenName) | 		go util.PostConsumeQuota(ctx, tokenId, quotaDelta, quota, userId, channelId, modelRatio, groupRatio, audioModel, tokenName) | ||||||
| 	}(c.Request.Context()) | 	}(c.Request.Context()) | ||||||
| 
 | 
 | ||||||
| 	for k, v := range resp.Header { | 	for k, v := range resp.Header { | ||||||
| @@ -207,11 +210,11 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode | |||||||
| 
 | 
 | ||||||
| 	_, err = io.Copy(c.Writer, resp.Body) | 	_, err = io.Copy(c.Writer, resp.Body) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return errorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError) | 		return openai.ErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError) | ||||||
| 	} | 	} | ||||||
| 	err = resp.Body.Close() | 	err = resp.Body.Close() | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError) | 		return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError) | ||||||
| 	} | 	} | ||||||
| 	return nil | 	return nil | ||||||
| } | } | ||||||
| @@ -221,7 +224,7 @@ func getTextFromVTT(body []byte) (string, error) { | |||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func getTextFromVerboseJSON(body []byte) (string, error) { | func getTextFromVerboseJSON(body []byte) (string, error) { | ||||||
| 	var whisperResponse WhisperVerboseJSONResponse | 	var whisperResponse openai.WhisperVerboseJSONResponse | ||||||
| 	if err := json.Unmarshal(body, &whisperResponse); err != nil { | 	if err := json.Unmarshal(body, &whisperResponse); err != nil { | ||||||
| 		return "", fmt.Errorf("unmarshal_response_body_failed err :%w", err) | 		return "", fmt.Errorf("unmarshal_response_body_failed err :%w", err) | ||||||
| 	} | 	} | ||||||
| @@ -254,7 +257,7 @@ func getTextFromText(body []byte) (string, error) { | |||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func getTextFromJSON(body []byte) (string, error) { | func getTextFromJSON(body []byte) (string, error) { | ||||||
| 	var whisperResponse WhisperJSONResponse | 	var whisperResponse openai.WhisperJSONResponse | ||||||
| 	if err := json.Unmarshal(body, &whisperResponse); err != nil { | 	if err := json.Unmarshal(body, &whisperResponse); err != nil { | ||||||
| 		return "", fmt.Errorf("unmarshal_response_body_failed err :%w", err) | 		return "", fmt.Errorf("unmarshal_response_body_failed err :%w", err) | ||||||
| 	} | 	} | ||||||
| @@ -10,6 +10,8 @@ import ( | |||||||
| 	"net/http" | 	"net/http" | ||||||
| 	"one-api/common" | 	"one-api/common" | ||||||
| 	"one-api/model" | 	"one-api/model" | ||||||
|  | 	"one-api/relay/channel/openai" | ||||||
|  | 	"one-api/relay/util" | ||||||
| 	"strings" | 	"strings" | ||||||
| 
 | 
 | ||||||
| 	"github.com/gin-gonic/gin" | 	"github.com/gin-gonic/gin" | ||||||
| @@ -25,7 +27,7 @@ func isWithinRange(element string, value int) bool { | |||||||
| 	return value >= min && value <= max | 	return value >= min && value <= max | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { | func RelayImageHelper(c *gin.Context, relayMode int) *openai.ErrorWithStatusCode { | ||||||
| 	imageModel := "dall-e-2" | 	imageModel := "dall-e-2" | ||||||
| 	imageSize := "1024x1024" | 	imageSize := "1024x1024" | ||||||
| 
 | 
 | ||||||
| @@ -35,10 +37,10 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode | |||||||
| 	userId := c.GetInt("id") | 	userId := c.GetInt("id") | ||||||
| 	group := c.GetString("group") | 	group := c.GetString("group") | ||||||
| 
 | 
 | ||||||
| 	var imageRequest ImageRequest | 	var imageRequest openai.ImageRequest | ||||||
| 	err := common.UnmarshalBodyReusable(c, &imageRequest) | 	err := common.UnmarshalBodyReusable(c, &imageRequest) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return errorWrapper(err, "bind_request_body_failed", http.StatusBadRequest) | 		return openai.ErrorWrapper(err, "bind_request_body_failed", http.StatusBadRequest) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	if imageRequest.N == 0 { | 	if imageRequest.N == 0 { | ||||||
| @@ -67,24 +69,24 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode | |||||||
| 			} | 			} | ||||||
| 		} | 		} | ||||||
| 	} else { | 	} else { | ||||||
| 		return errorWrapper(errors.New("size not supported for this image model"), "size_not_supported", http.StatusBadRequest) | 		return openai.ErrorWrapper(errors.New("size not supported for this image model"), "size_not_supported", http.StatusBadRequest) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	// Prompt validation | 	// Prompt validation | ||||||
| 	if imageRequest.Prompt == "" { | 	if imageRequest.Prompt == "" { | ||||||
| 		return errorWrapper(errors.New("prompt is required"), "prompt_missing", http.StatusBadRequest) | 		return openai.ErrorWrapper(errors.New("prompt is required"), "prompt_missing", http.StatusBadRequest) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	// Check prompt length | 	// Check prompt length | ||||||
| 	if len(imageRequest.Prompt) > common.DalleImagePromptLengthLimitations[imageModel] { | 	if len(imageRequest.Prompt) > common.DalleImagePromptLengthLimitations[imageModel] { | ||||||
| 		return errorWrapper(errors.New("prompt is too long"), "prompt_too_long", http.StatusBadRequest) | 		return openai.ErrorWrapper(errors.New("prompt is too long"), "prompt_too_long", http.StatusBadRequest) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	// Number of generated images validation | 	// Number of generated images validation | ||||||
| 	if isWithinRange(imageModel, imageRequest.N) == false { | 	if isWithinRange(imageModel, imageRequest.N) == false { | ||||||
| 		// channel not azure | 		// channel not azure | ||||||
| 		if channelType != common.ChannelTypeAzure { | 		if channelType != common.ChannelTypeAzure { | ||||||
| 			return errorWrapper(errors.New("invalid value of n"), "n_not_within_range", http.StatusBadRequest) | 			return openai.ErrorWrapper(errors.New("invalid value of n"), "n_not_within_range", http.StatusBadRequest) | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| @@ -95,7 +97,7 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode | |||||||
| 		modelMap := make(map[string]string) | 		modelMap := make(map[string]string) | ||||||
| 		err := json.Unmarshal([]byte(modelMapping), &modelMap) | 		err := json.Unmarshal([]byte(modelMapping), &modelMap) | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			return errorWrapper(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError) | 			return openai.ErrorWrapper(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError) | ||||||
| 		} | 		} | ||||||
| 		if modelMap[imageModel] != "" { | 		if modelMap[imageModel] != "" { | ||||||
| 			imageModel = modelMap[imageModel] | 			imageModel = modelMap[imageModel] | ||||||
| @@ -107,10 +109,10 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode | |||||||
| 	if c.GetString("base_url") != "" { | 	if c.GetString("base_url") != "" { | ||||||
| 		baseURL = c.GetString("base_url") | 		baseURL = c.GetString("base_url") | ||||||
| 	} | 	} | ||||||
| 	fullRequestURL := getFullRequestURL(baseURL, requestURL, channelType) | 	fullRequestURL := util.GetFullRequestURL(baseURL, requestURL, channelType) | ||||||
| 	if channelType == common.ChannelTypeAzure { | 	if channelType == common.ChannelTypeAzure { | ||||||
| 		// https://learn.microsoft.com/en-us/azure/ai-services/openai/dall-e-quickstart?tabs=dalle3%2Ccommand-line&pivots=rest-api | 		// https://learn.microsoft.com/en-us/azure/ai-services/openai/dall-e-quickstart?tabs=dalle3%2Ccommand-line&pivots=rest-api | ||||||
| 		apiVersion := GetAPIVersion(c) | 		apiVersion := util.GetAPIVersion(c) | ||||||
| 		// https://{resource_name}.openai.azure.com/openai/deployments/dall-e-3/images/generations?api-version=2023-06-01-preview | 		// https://{resource_name}.openai.azure.com/openai/deployments/dall-e-3/images/generations?api-version=2023-06-01-preview | ||||||
| 		fullRequestURL = fmt.Sprintf("%s/openai/deployments/%s/images/generations?api-version=%s", baseURL, imageModel, apiVersion) | 		fullRequestURL = fmt.Sprintf("%s/openai/deployments/%s/images/generations?api-version=%s", baseURL, imageModel, apiVersion) | ||||||
| 	} | 	} | ||||||
| @@ -119,7 +121,7 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode | |||||||
| 	if isModelMapped || channelType == common.ChannelTypeAzure { // make Azure channel request body | 	if isModelMapped || channelType == common.ChannelTypeAzure { // make Azure channel request body | ||||||
| 		jsonStr, err := json.Marshal(imageRequest) | 		jsonStr, err := json.Marshal(imageRequest) | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError) | 			return openai.ErrorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError) | ||||||
| 		} | 		} | ||||||
| 		requestBody = bytes.NewBuffer(jsonStr) | 		requestBody = bytes.NewBuffer(jsonStr) | ||||||
| 	} else { | 	} else { | ||||||
| @@ -134,12 +136,12 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode | |||||||
| 	quota := int(ratio*imageCostRatio*1000) * imageRequest.N | 	quota := int(ratio*imageCostRatio*1000) * imageRequest.N | ||||||
| 
 | 
 | ||||||
| 	if userQuota-quota < 0 { | 	if userQuota-quota < 0 { | ||||||
| 		return errorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden) | 		return openai.ErrorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody) | 	req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return errorWrapper(err, "new_request_failed", http.StatusInternalServerError) | 		return openai.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError) | ||||||
| 	} | 	} | ||||||
| 	token := c.Request.Header.Get("Authorization") | 	token := c.Request.Header.Get("Authorization") | ||||||
| 	if channelType == common.ChannelTypeAzure { // Azure authentication | 	if channelType == common.ChannelTypeAzure { // Azure authentication | ||||||
| @@ -152,20 +154,20 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode | |||||||
| 	req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type")) | 	req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type")) | ||||||
| 	req.Header.Set("Accept", c.Request.Header.Get("Accept")) | 	req.Header.Set("Accept", c.Request.Header.Get("Accept")) | ||||||
| 
 | 
 | ||||||
| 	resp, err := httpClient.Do(req) | 	resp, err := util.HTTPClient.Do(req) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return errorWrapper(err, "do_request_failed", http.StatusInternalServerError) | 		return openai.ErrorWrapper(err, "do_request_failed", http.StatusInternalServerError) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	err = req.Body.Close() | 	err = req.Body.Close() | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return errorWrapper(err, "close_request_body_failed", http.StatusInternalServerError) | 		return openai.ErrorWrapper(err, "close_request_body_failed", http.StatusInternalServerError) | ||||||
| 	} | 	} | ||||||
| 	err = c.Request.Body.Close() | 	err = c.Request.Body.Close() | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return errorWrapper(err, "close_request_body_failed", http.StatusInternalServerError) | 		return openai.ErrorWrapper(err, "close_request_body_failed", http.StatusInternalServerError) | ||||||
| 	} | 	} | ||||||
| 	var textResponse ImageResponse | 	var textResponse openai.ImageResponse | ||||||
| 
 | 
 | ||||||
| 	defer func(ctx context.Context) { | 	defer func(ctx context.Context) { | ||||||
| 		if resp.StatusCode != http.StatusOK { | 		if resp.StatusCode != http.StatusOK { | ||||||
| @@ -192,15 +194,15 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode | |||||||
| 	responseBody, err := io.ReadAll(resp.Body) | 	responseBody, err := io.ReadAll(resp.Body) | ||||||
| 
 | 
 | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError) | 		return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError) | ||||||
| 	} | 	} | ||||||
| 	err = resp.Body.Close() | 	err = resp.Body.Close() | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError) | 		return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError) | ||||||
| 	} | 	} | ||||||
| 	err = json.Unmarshal(responseBody, &textResponse) | 	err = json.Unmarshal(responseBody, &textResponse) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError) | 		return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	resp.Body = io.NopCloser(bytes.NewBuffer(responseBody)) | 	resp.Body = io.NopCloser(bytes.NewBuffer(responseBody)) | ||||||
| @@ -212,11 +214,11 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode | |||||||
| 
 | 
 | ||||||
| 	_, err = io.Copy(c.Writer, resp.Body) | 	_, err = io.Copy(c.Writer, resp.Body) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return errorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError) | 		return openai.ErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError) | ||||||
| 	} | 	} | ||||||
| 	err = resp.Body.Close() | 	err = resp.Body.Close() | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError) | 		return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError) | ||||||
| 	} | 	} | ||||||
| 	return nil | 	return nil | ||||||
| } | } | ||||||
| @@ -6,15 +6,24 @@ import ( | |||||||
| 	"encoding/json" | 	"encoding/json" | ||||||
| 	"errors" | 	"errors" | ||||||
| 	"fmt" | 	"fmt" | ||||||
|  | 	"github.com/gin-gonic/gin" | ||||||
| 	"io" | 	"io" | ||||||
| 	"math" | 	"math" | ||||||
| 	"net/http" | 	"net/http" | ||||||
| 	"one-api/common" | 	"one-api/common" | ||||||
| 	"one-api/model" | 	"one-api/model" | ||||||
|  | 	"one-api/relay/channel/aiproxy" | ||||||
|  | 	"one-api/relay/channel/ali" | ||||||
|  | 	"one-api/relay/channel/anthropic" | ||||||
|  | 	"one-api/relay/channel/baidu" | ||||||
|  | 	"one-api/relay/channel/google" | ||||||
|  | 	"one-api/relay/channel/openai" | ||||||
|  | 	"one-api/relay/channel/tencent" | ||||||
|  | 	"one-api/relay/channel/xunfei" | ||||||
|  | 	"one-api/relay/channel/zhipu" | ||||||
|  | 	"one-api/relay/constant" | ||||||
|  | 	"one-api/relay/util" | ||||||
| 	"strings" | 	"strings" | ||||||
| 	"time" |  | ||||||
| 
 |  | ||||||
| 	"github.com/gin-gonic/gin" |  | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| const ( | const ( | ||||||
| @@ -30,64 +39,47 @@ const ( | |||||||
| 	APITypeGemini | 	APITypeGemini | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| var httpClient *http.Client | func RelayTextHelper(c *gin.Context, relayMode int) *openai.ErrorWithStatusCode { | ||||||
| var impatientHTTPClient *http.Client |  | ||||||
| 
 |  | ||||||
| func init() { |  | ||||||
| 	if common.RelayTimeout == 0 { |  | ||||||
| 		httpClient = &http.Client{} |  | ||||||
| 	} else { |  | ||||||
| 		httpClient = &http.Client{ |  | ||||||
| 			Timeout: time.Duration(common.RelayTimeout) * time.Second, |  | ||||||
| 		} |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	impatientHTTPClient = &http.Client{ |  | ||||||
| 		Timeout: 5 * time.Second, |  | ||||||
| 	} |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { |  | ||||||
| 	channelType := c.GetInt("channel") | 	channelType := c.GetInt("channel") | ||||||
| 	channelId := c.GetInt("channel_id") | 	channelId := c.GetInt("channel_id") | ||||||
| 	tokenId := c.GetInt("token_id") | 	tokenId := c.GetInt("token_id") | ||||||
| 	userId := c.GetInt("id") | 	userId := c.GetInt("id") | ||||||
| 	group := c.GetString("group") | 	group := c.GetString("group") | ||||||
| 	var textRequest GeneralOpenAIRequest | 	var textRequest openai.GeneralOpenAIRequest | ||||||
| 	err := common.UnmarshalBodyReusable(c, &textRequest) | 	err := common.UnmarshalBodyReusable(c, &textRequest) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return errorWrapper(err, "bind_request_body_failed", http.StatusBadRequest) | 		return openai.ErrorWrapper(err, "bind_request_body_failed", http.StatusBadRequest) | ||||||
| 	} | 	} | ||||||
| 	if textRequest.MaxTokens < 0 || textRequest.MaxTokens > math.MaxInt32/2 { | 	if textRequest.MaxTokens < 0 || textRequest.MaxTokens > math.MaxInt32/2 { | ||||||
| 		return errorWrapper(errors.New("max_tokens is invalid"), "invalid_max_tokens", http.StatusBadRequest) | 		return openai.ErrorWrapper(errors.New("max_tokens is invalid"), "invalid_max_tokens", http.StatusBadRequest) | ||||||
| 	} | 	} | ||||||
| 	if relayMode == RelayModeModerations && textRequest.Model == "" { | 	if relayMode == constant.RelayModeModerations && textRequest.Model == "" { | ||||||
| 		textRequest.Model = "text-moderation-latest" | 		textRequest.Model = "text-moderation-latest" | ||||||
| 	} | 	} | ||||||
| 	if relayMode == RelayModeEmbeddings && textRequest.Model == "" { | 	if relayMode == constant.RelayModeEmbeddings && textRequest.Model == "" { | ||||||
| 		textRequest.Model = c.Param("model") | 		textRequest.Model = c.Param("model") | ||||||
| 	} | 	} | ||||||
| 	// request validation | 	// request validation | ||||||
| 	if textRequest.Model == "" { | 	if textRequest.Model == "" { | ||||||
| 		return errorWrapper(errors.New("model is required"), "required_field_missing", http.StatusBadRequest) | 		return openai.ErrorWrapper(errors.New("model is required"), "required_field_missing", http.StatusBadRequest) | ||||||
| 	} | 	} | ||||||
| 	switch relayMode { | 	switch relayMode { | ||||||
| 	case RelayModeCompletions: | 	case constant.RelayModeCompletions: | ||||||
| 		if textRequest.Prompt == "" { | 		if textRequest.Prompt == "" { | ||||||
| 			return errorWrapper(errors.New("field prompt is required"), "required_field_missing", http.StatusBadRequest) | 			return openai.ErrorWrapper(errors.New("field prompt is required"), "required_field_missing", http.StatusBadRequest) | ||||||
| 		} | 		} | ||||||
| 	case RelayModeChatCompletions: | 	case constant.RelayModeChatCompletions: | ||||||
| 		if textRequest.Messages == nil || len(textRequest.Messages) == 0 { | 		if textRequest.Messages == nil || len(textRequest.Messages) == 0 { | ||||||
| 			return errorWrapper(errors.New("field messages is required"), "required_field_missing", http.StatusBadRequest) | 			return openai.ErrorWrapper(errors.New("field messages is required"), "required_field_missing", http.StatusBadRequest) | ||||||
| 		} | 		} | ||||||
| 	case RelayModeEmbeddings: | 	case constant.RelayModeEmbeddings: | ||||||
| 	case RelayModeModerations: | 	case constant.RelayModeModerations: | ||||||
| 		if textRequest.Input == "" { | 		if textRequest.Input == "" { | ||||||
| 			return errorWrapper(errors.New("field input is required"), "required_field_missing", http.StatusBadRequest) | 			return openai.ErrorWrapper(errors.New("field input is required"), "required_field_missing", http.StatusBadRequest) | ||||||
| 		} | 		} | ||||||
| 	case RelayModeEdits: | 	case constant.RelayModeEdits: | ||||||
| 		if textRequest.Instruction == "" { | 		if textRequest.Instruction == "" { | ||||||
| 			return errorWrapper(errors.New("field instruction is required"), "required_field_missing", http.StatusBadRequest) | 			return openai.ErrorWrapper(errors.New("field instruction is required"), "required_field_missing", http.StatusBadRequest) | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
| 	// map model name | 	// map model name | ||||||
| @@ -97,7 +89,7 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { | |||||||
| 		modelMap := make(map[string]string) | 		modelMap := make(map[string]string) | ||||||
| 		err := json.Unmarshal([]byte(modelMapping), &modelMap) | 		err := json.Unmarshal([]byte(modelMapping), &modelMap) | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			return errorWrapper(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError) | 			return openai.ErrorWrapper(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError) | ||||||
| 		} | 		} | ||||||
| 		if modelMap[textRequest.Model] != "" { | 		if modelMap[textRequest.Model] != "" { | ||||||
| 			textRequest.Model = modelMap[textRequest.Model] | 			textRequest.Model = modelMap[textRequest.Model] | ||||||
| @@ -130,12 +122,12 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { | |||||||
| 	if c.GetString("base_url") != "" { | 	if c.GetString("base_url") != "" { | ||||||
| 		baseURL = c.GetString("base_url") | 		baseURL = c.GetString("base_url") | ||||||
| 	} | 	} | ||||||
| 	fullRequestURL := getFullRequestURL(baseURL, requestURL, channelType) | 	fullRequestURL := util.GetFullRequestURL(baseURL, requestURL, channelType) | ||||||
| 	switch apiType { | 	switch apiType { | ||||||
| 	case APITypeOpenAI: | 	case APITypeOpenAI: | ||||||
| 		if channelType == common.ChannelTypeAzure { | 		if channelType == common.ChannelTypeAzure { | ||||||
| 			// https://learn.microsoft.com/en-us/azure/cognitive-services/openai/chatgpt-quickstart?pivots=rest-api&tabs=command-line#rest-api | 			// https://learn.microsoft.com/en-us/azure/cognitive-services/openai/chatgpt-quickstart?pivots=rest-api&tabs=command-line#rest-api | ||||||
| 			apiVersion := GetAPIVersion(c) | 			apiVersion := util.GetAPIVersion(c) | ||||||
| 			requestURL := strings.Split(requestURL, "?")[0] | 			requestURL := strings.Split(requestURL, "?")[0] | ||||||
| 			requestURL = fmt.Sprintf("%s?api-version=%s", requestURL, apiVersion) | 			requestURL = fmt.Sprintf("%s?api-version=%s", requestURL, apiVersion) | ||||||
| 			baseURL = c.GetString("base_url") | 			baseURL = c.GetString("base_url") | ||||||
| @@ -148,7 +140,7 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { | |||||||
| 			model_ = strings.TrimSuffix(model_, "-0613") | 			model_ = strings.TrimSuffix(model_, "-0613") | ||||||
| 
 | 
 | ||||||
| 			requestURL = fmt.Sprintf("/openai/deployments/%s/%s", model_, task) | 			requestURL = fmt.Sprintf("/openai/deployments/%s/%s", model_, task) | ||||||
| 			fullRequestURL = getFullRequestURL(baseURL, requestURL, channelType) | 			fullRequestURL = util.GetFullRequestURL(baseURL, requestURL, channelType) | ||||||
| 		} | 		} | ||||||
| 	case APITypeClaude: | 	case APITypeClaude: | ||||||
| 		fullRequestURL = "https://api.anthropic.com/v1/complete" | 		fullRequestURL = "https://api.anthropic.com/v1/complete" | ||||||
| @@ -171,8 +163,8 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { | |||||||
| 		apiKey := c.Request.Header.Get("Authorization") | 		apiKey := c.Request.Header.Get("Authorization") | ||||||
| 		apiKey = strings.TrimPrefix(apiKey, "Bearer ") | 		apiKey = strings.TrimPrefix(apiKey, "Bearer ") | ||||||
| 		var err error | 		var err error | ||||||
| 		if apiKey, err = getBaiduAccessToken(apiKey); err != nil { | 		if apiKey, err = baidu.GetAccessToken(apiKey); err != nil { | ||||||
| 			return errorWrapper(err, "invalid_baidu_config", http.StatusInternalServerError) | 			return openai.ErrorWrapper(err, "invalid_baidu_config", http.StatusInternalServerError) | ||||||
| 		} | 		} | ||||||
| 		fullRequestURL += "?access_token=" + apiKey | 		fullRequestURL += "?access_token=" + apiKey | ||||||
| 	case APITypePaLM: | 	case APITypePaLM: | ||||||
| @@ -202,7 +194,7 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { | |||||||
| 		fullRequestURL = fmt.Sprintf("https://open.bigmodel.cn/api/paas/v3/model-api/%s/%s", textRequest.Model, method) | 		fullRequestURL = fmt.Sprintf("https://open.bigmodel.cn/api/paas/v3/model-api/%s/%s", textRequest.Model, method) | ||||||
| 	case APITypeAli: | 	case APITypeAli: | ||||||
| 		fullRequestURL = "https://dashscope.aliyuncs.com/api/v1/services/aigc/text-generation/generation" | 		fullRequestURL = "https://dashscope.aliyuncs.com/api/v1/services/aigc/text-generation/generation" | ||||||
| 		if relayMode == RelayModeEmbeddings { | 		if relayMode == constant.RelayModeEmbeddings { | ||||||
| 			fullRequestURL = "https://dashscope.aliyuncs.com/api/v1/services/embeddings/text-embedding/text-embedding" | 			fullRequestURL = "https://dashscope.aliyuncs.com/api/v1/services/embeddings/text-embedding/text-embedding" | ||||||
| 		} | 		} | ||||||
| 	case APITypeTencent: | 	case APITypeTencent: | ||||||
| @@ -213,12 +205,12 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { | |||||||
| 	var promptTokens int | 	var promptTokens int | ||||||
| 	var completionTokens int | 	var completionTokens int | ||||||
| 	switch relayMode { | 	switch relayMode { | ||||||
| 	case RelayModeChatCompletions: | 	case constant.RelayModeChatCompletions: | ||||||
| 		promptTokens = countTokenMessages(textRequest.Messages, textRequest.Model) | 		promptTokens = openai.CountTokenMessages(textRequest.Messages, textRequest.Model) | ||||||
| 	case RelayModeCompletions: | 	case constant.RelayModeCompletions: | ||||||
| 		promptTokens = countTokenInput(textRequest.Prompt, textRequest.Model) | 		promptTokens = openai.CountTokenInput(textRequest.Prompt, textRequest.Model) | ||||||
| 	case RelayModeModerations: | 	case constant.RelayModeModerations: | ||||||
| 		promptTokens = countTokenInput(textRequest.Input, textRequest.Model) | 		promptTokens = openai.CountTokenInput(textRequest.Input, textRequest.Model) | ||||||
| 	} | 	} | ||||||
| 	preConsumedTokens := common.PreConsumedQuota | 	preConsumedTokens := common.PreConsumedQuota | ||||||
| 	if textRequest.MaxTokens != 0 { | 	if textRequest.MaxTokens != 0 { | ||||||
| @@ -230,14 +222,14 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { | |||||||
| 	preConsumedQuota := int(float64(preConsumedTokens) * ratio) | 	preConsumedQuota := int(float64(preConsumedTokens) * ratio) | ||||||
| 	userQuota, err := model.CacheGetUserQuota(userId) | 	userQuota, err := model.CacheGetUserQuota(userId) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return errorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError) | 		return openai.ErrorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError) | ||||||
| 	} | 	} | ||||||
| 	if userQuota-preConsumedQuota < 0 { | 	if userQuota-preConsumedQuota < 0 { | ||||||
| 		return errorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden) | 		return openai.ErrorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden) | ||||||
| 	} | 	} | ||||||
| 	err = model.CacheDecreaseUserQuota(userId, preConsumedQuota) | 	err = model.CacheDecreaseUserQuota(userId, preConsumedQuota) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return errorWrapper(err, "decrease_user_quota_failed", http.StatusInternalServerError) | 		return openai.ErrorWrapper(err, "decrease_user_quota_failed", http.StatusInternalServerError) | ||||||
| 	} | 	} | ||||||
| 	if userQuota > 100*preConsumedQuota { | 	if userQuota > 100*preConsumedQuota { | ||||||
| 		// in this case, we do not pre-consume quota | 		// in this case, we do not pre-consume quota | ||||||
| @@ -248,14 +240,14 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { | |||||||
| 	if preConsumedQuota > 0 { | 	if preConsumedQuota > 0 { | ||||||
| 		err := model.PreConsumeTokenQuota(tokenId, preConsumedQuota) | 		err := model.PreConsumeTokenQuota(tokenId, preConsumedQuota) | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			return errorWrapper(err, "pre_consume_token_quota_failed", http.StatusForbidden) | 			return openai.ErrorWrapper(err, "pre_consume_token_quota_failed", http.StatusForbidden) | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
| 	var requestBody io.Reader | 	var requestBody io.Reader | ||||||
| 	if isModelMapped { | 	if isModelMapped { | ||||||
| 		jsonStr, err := json.Marshal(textRequest) | 		jsonStr, err := json.Marshal(textRequest) | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError) | 			return openai.ErrorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError) | ||||||
| 		} | 		} | ||||||
| 		requestBody = bytes.NewBuffer(jsonStr) | 		requestBody = bytes.NewBuffer(jsonStr) | ||||||
| 	} else { | 	} else { | ||||||
| @@ -263,86 +255,86 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { | |||||||
| 	} | 	} | ||||||
| 	switch apiType { | 	switch apiType { | ||||||
| 	case APITypeClaude: | 	case APITypeClaude: | ||||||
| 		claudeRequest := requestOpenAI2Claude(textRequest) | 		claudeRequest := anthropic.ConvertRequest(textRequest) | ||||||
| 		jsonStr, err := json.Marshal(claudeRequest) | 		jsonStr, err := json.Marshal(claudeRequest) | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError) | 			return openai.ErrorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError) | ||||||
| 		} | 		} | ||||||
| 		requestBody = bytes.NewBuffer(jsonStr) | 		requestBody = bytes.NewBuffer(jsonStr) | ||||||
| 	case APITypeBaidu: | 	case APITypeBaidu: | ||||||
| 		var jsonData []byte | 		var jsonData []byte | ||||||
| 		var err error | 		var err error | ||||||
| 		switch relayMode { | 		switch relayMode { | ||||||
| 		case RelayModeEmbeddings: | 		case constant.RelayModeEmbeddings: | ||||||
| 			baiduEmbeddingRequest := embeddingRequestOpenAI2Baidu(textRequest) | 			baiduEmbeddingRequest := baidu.ConvertEmbeddingRequest(textRequest) | ||||||
| 			jsonData, err = json.Marshal(baiduEmbeddingRequest) | 			jsonData, err = json.Marshal(baiduEmbeddingRequest) | ||||||
| 		default: | 		default: | ||||||
| 			baiduRequest := requestOpenAI2Baidu(textRequest) | 			baiduRequest := baidu.ConvertRequest(textRequest) | ||||||
| 			jsonData, err = json.Marshal(baiduRequest) | 			jsonData, err = json.Marshal(baiduRequest) | ||||||
| 		} | 		} | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError) | 			return openai.ErrorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError) | ||||||
| 		} | 		} | ||||||
| 		requestBody = bytes.NewBuffer(jsonData) | 		requestBody = bytes.NewBuffer(jsonData) | ||||||
| 	case APITypePaLM: | 	case APITypePaLM: | ||||||
| 		palmRequest := requestOpenAI2PaLM(textRequest) | 		palmRequest := google.ConvertPaLMRequest(textRequest) | ||||||
| 		jsonStr, err := json.Marshal(palmRequest) | 		jsonStr, err := json.Marshal(palmRequest) | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError) | 			return openai.ErrorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError) | ||||||
| 		} | 		} | ||||||
| 		requestBody = bytes.NewBuffer(jsonStr) | 		requestBody = bytes.NewBuffer(jsonStr) | ||||||
| 	case APITypeGemini: | 	case APITypeGemini: | ||||||
| 		geminiChatRequest := requestOpenAI2Gemini(textRequest) | 		geminiChatRequest := google.ConvertGeminiRequest(textRequest) | ||||||
| 		jsonStr, err := json.Marshal(geminiChatRequest) | 		jsonStr, err := json.Marshal(geminiChatRequest) | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError) | 			return openai.ErrorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError) | ||||||
| 		} | 		} | ||||||
| 		requestBody = bytes.NewBuffer(jsonStr) | 		requestBody = bytes.NewBuffer(jsonStr) | ||||||
| 	case APITypeZhipu: | 	case APITypeZhipu: | ||||||
| 		zhipuRequest := requestOpenAI2Zhipu(textRequest) | 		zhipuRequest := zhipu.ConvertRequest(textRequest) | ||||||
| 		jsonStr, err := json.Marshal(zhipuRequest) | 		jsonStr, err := json.Marshal(zhipuRequest) | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError) | 			return openai.ErrorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError) | ||||||
| 		} | 		} | ||||||
| 		requestBody = bytes.NewBuffer(jsonStr) | 		requestBody = bytes.NewBuffer(jsonStr) | ||||||
| 	case APITypeAli: | 	case APITypeAli: | ||||||
| 		var jsonStr []byte | 		var jsonStr []byte | ||||||
| 		var err error | 		var err error | ||||||
| 		switch relayMode { | 		switch relayMode { | ||||||
| 		case RelayModeEmbeddings: | 		case constant.RelayModeEmbeddings: | ||||||
| 			aliEmbeddingRequest := embeddingRequestOpenAI2Ali(textRequest) | 			aliEmbeddingRequest := ali.ConvertEmbeddingRequest(textRequest) | ||||||
| 			jsonStr, err = json.Marshal(aliEmbeddingRequest) | 			jsonStr, err = json.Marshal(aliEmbeddingRequest) | ||||||
| 		default: | 		default: | ||||||
| 			aliRequest := requestOpenAI2Ali(textRequest) | 			aliRequest := ali.ConvertRequest(textRequest) | ||||||
| 			jsonStr, err = json.Marshal(aliRequest) | 			jsonStr, err = json.Marshal(aliRequest) | ||||||
| 		} | 		} | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError) | 			return openai.ErrorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError) | ||||||
| 		} | 		} | ||||||
| 		requestBody = bytes.NewBuffer(jsonStr) | 		requestBody = bytes.NewBuffer(jsonStr) | ||||||
| 	case APITypeTencent: | 	case APITypeTencent: | ||||||
| 		apiKey := c.Request.Header.Get("Authorization") | 		apiKey := c.Request.Header.Get("Authorization") | ||||||
| 		apiKey = strings.TrimPrefix(apiKey, "Bearer ") | 		apiKey = strings.TrimPrefix(apiKey, "Bearer ") | ||||||
| 		appId, secretId, secretKey, err := parseTencentConfig(apiKey) | 		appId, secretId, secretKey, err := tencent.ParseConfig(apiKey) | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			return errorWrapper(err, "invalid_tencent_config", http.StatusInternalServerError) | 			return openai.ErrorWrapper(err, "invalid_tencent_config", http.StatusInternalServerError) | ||||||
| 		} | 		} | ||||||
| 		tencentRequest := requestOpenAI2Tencent(textRequest) | 		tencentRequest := tencent.ConvertRequest(textRequest) | ||||||
| 		tencentRequest.AppId = appId | 		tencentRequest.AppId = appId | ||||||
| 		tencentRequest.SecretId = secretId | 		tencentRequest.SecretId = secretId | ||||||
| 		jsonStr, err := json.Marshal(tencentRequest) | 		jsonStr, err := json.Marshal(tencentRequest) | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError) | 			return openai.ErrorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError) | ||||||
| 		} | 		} | ||||||
| 		sign := getTencentSign(*tencentRequest, secretKey) | 		sign := tencent.GetSign(*tencentRequest, secretKey) | ||||||
| 		c.Request.Header.Set("Authorization", sign) | 		c.Request.Header.Set("Authorization", sign) | ||||||
| 		requestBody = bytes.NewBuffer(jsonStr) | 		requestBody = bytes.NewBuffer(jsonStr) | ||||||
| 	case APITypeAIProxyLibrary: | 	case APITypeAIProxyLibrary: | ||||||
| 		aiProxyLibraryRequest := requestOpenAI2AIProxyLibrary(textRequest) | 		aiProxyLibraryRequest := aiproxy.ConvertRequest(textRequest) | ||||||
| 		aiProxyLibraryRequest.LibraryId = c.GetString("library_id") | 		aiProxyLibraryRequest.LibraryId = c.GetString("library_id") | ||||||
| 		jsonStr, err := json.Marshal(aiProxyLibraryRequest) | 		jsonStr, err := json.Marshal(aiProxyLibraryRequest) | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError) | 			return openai.ErrorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError) | ||||||
| 		} | 		} | ||||||
| 		requestBody = bytes.NewBuffer(jsonStr) | 		requestBody = bytes.NewBuffer(jsonStr) | ||||||
| 	} | 	} | ||||||
| @@ -354,7 +346,7 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { | |||||||
| 	if apiType != APITypeXunfei { // cause xunfei use websocket | 	if apiType != APITypeXunfei { // cause xunfei use websocket | ||||||
| 		req, err = http.NewRequest(c.Request.Method, fullRequestURL, requestBody) | 		req, err = http.NewRequest(c.Request.Method, fullRequestURL, requestBody) | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			return errorWrapper(err, "new_request_failed", http.StatusInternalServerError) | 			return openai.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError) | ||||||
| 		} | 		} | ||||||
| 		apiKey := c.Request.Header.Get("Authorization") | 		apiKey := c.Request.Header.Get("Authorization") | ||||||
| 		apiKey = strings.TrimPrefix(apiKey, "Bearer ") | 		apiKey = strings.TrimPrefix(apiKey, "Bearer ") | ||||||
| @@ -377,7 +369,7 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { | |||||||
| 			} | 			} | ||||||
| 			req.Header.Set("anthropic-version", anthropicVersion) | 			req.Header.Set("anthropic-version", anthropicVersion) | ||||||
| 		case APITypeZhipu: | 		case APITypeZhipu: | ||||||
| 			token := getZhipuToken(apiKey) | 			token := zhipu.GetToken(apiKey) | ||||||
| 			req.Header.Set("Authorization", token) | 			req.Header.Set("Authorization", token) | ||||||
| 		case APITypeAli: | 		case APITypeAli: | ||||||
| 			req.Header.Set("Authorization", "Bearer "+apiKey) | 			req.Header.Set("Authorization", "Bearer "+apiKey) | ||||||
| @@ -402,17 +394,17 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { | |||||||
| 			req.Header.Set("Accept", "text/event-stream") | 			req.Header.Set("Accept", "text/event-stream") | ||||||
| 		} | 		} | ||||||
| 		//req.Header.Set("Connection", c.Request.Header.Get("Connection")) | 		//req.Header.Set("Connection", c.Request.Header.Get("Connection")) | ||||||
| 		resp, err = httpClient.Do(req) | 		resp, err = util.HTTPClient.Do(req) | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			return errorWrapper(err, "do_request_failed", http.StatusInternalServerError) | 			return openai.ErrorWrapper(err, "do_request_failed", http.StatusInternalServerError) | ||||||
| 		} | 		} | ||||||
| 		err = req.Body.Close() | 		err = req.Body.Close() | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			return errorWrapper(err, "close_request_body_failed", http.StatusInternalServerError) | 			return openai.ErrorWrapper(err, "close_request_body_failed", http.StatusInternalServerError) | ||||||
| 		} | 		} | ||||||
| 		err = c.Request.Body.Close() | 		err = c.Request.Body.Close() | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			return errorWrapper(err, "close_request_body_failed", http.StatusInternalServerError) | 			return openai.ErrorWrapper(err, "close_request_body_failed", http.StatusInternalServerError) | ||||||
| 		} | 		} | ||||||
| 		isStream = isStream || strings.HasPrefix(resp.Header.Get("Content-Type"), "text/event-stream") | 		isStream = isStream || strings.HasPrefix(resp.Header.Get("Content-Type"), "text/event-stream") | ||||||
| 
 | 
 | ||||||
| @@ -426,11 +418,11 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { | |||||||
| 					} | 					} | ||||||
| 				}(c.Request.Context()) | 				}(c.Request.Context()) | ||||||
| 			} | 			} | ||||||
| 			return relayErrorHandler(resp) | 			return util.RelayErrorHandler(resp) | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	var textResponse TextResponse | 	var textResponse openai.SlimTextResponse | ||||||
| 	tokenName := c.GetString("token_name") | 	tokenName := c.GetString("token_name") | ||||||
| 
 | 
 | ||||||
| 	defer func(ctx context.Context) { | 	defer func(ctx context.Context) { | ||||||
| @@ -471,15 +463,15 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { | |||||||
| 	switch apiType { | 	switch apiType { | ||||||
| 	case APITypeOpenAI: | 	case APITypeOpenAI: | ||||||
| 		if isStream { | 		if isStream { | ||||||
| 			err, responseText := openaiStreamHandler(c, resp, relayMode) | 			err, responseText := openai.StreamHandler(c, resp, relayMode) | ||||||
| 			if err != nil { | 			if err != nil { | ||||||
| 				return err | 				return err | ||||||
| 			} | 			} | ||||||
| 			textResponse.Usage.PromptTokens = promptTokens | 			textResponse.Usage.PromptTokens = promptTokens | ||||||
| 			textResponse.Usage.CompletionTokens = countTokenText(responseText, textRequest.Model) | 			textResponse.Usage.CompletionTokens = openai.CountTokenText(responseText, textRequest.Model) | ||||||
| 			return nil | 			return nil | ||||||
| 		} else { | 		} else { | ||||||
| 			err, usage := openaiHandler(c, resp, promptTokens, textRequest.Model) | 			err, usage := openai.Handler(c, resp, promptTokens, textRequest.Model) | ||||||
| 			if err != nil { | 			if err != nil { | ||||||
| 				return err | 				return err | ||||||
| 			} | 			} | ||||||
| @@ -490,15 +482,15 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { | |||||||
| 		} | 		} | ||||||
| 	case APITypeClaude: | 	case APITypeClaude: | ||||||
| 		if isStream { | 		if isStream { | ||||||
| 			err, responseText := claudeStreamHandler(c, resp) | 			err, responseText := anthropic.StreamHandler(c, resp) | ||||||
| 			if err != nil { | 			if err != nil { | ||||||
| 				return err | 				return err | ||||||
| 			} | 			} | ||||||
| 			textResponse.Usage.PromptTokens = promptTokens | 			textResponse.Usage.PromptTokens = promptTokens | ||||||
| 			textResponse.Usage.CompletionTokens = countTokenText(responseText, textRequest.Model) | 			textResponse.Usage.CompletionTokens = openai.CountTokenText(responseText, textRequest.Model) | ||||||
| 			return nil | 			return nil | ||||||
| 		} else { | 		} else { | ||||||
| 			err, usage := claudeHandler(c, resp, promptTokens, textRequest.Model) | 			err, usage := anthropic.Handler(c, resp, promptTokens, textRequest.Model) | ||||||
| 			if err != nil { | 			if err != nil { | ||||||
| 				return err | 				return err | ||||||
| 			} | 			} | ||||||
| @@ -509,7 +501,7 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { | |||||||
| 		} | 		} | ||||||
| 	case APITypeBaidu: | 	case APITypeBaidu: | ||||||
| 		if isStream { | 		if isStream { | ||||||
| 			err, usage := baiduStreamHandler(c, resp) | 			err, usage := baidu.StreamHandler(c, resp) | ||||||
| 			if err != nil { | 			if err != nil { | ||||||
| 				return err | 				return err | ||||||
| 			} | 			} | ||||||
| @@ -518,13 +510,13 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { | |||||||
| 			} | 			} | ||||||
| 			return nil | 			return nil | ||||||
| 		} else { | 		} else { | ||||||
| 			var err *OpenAIErrorWithStatusCode | 			var err *openai.ErrorWithStatusCode | ||||||
| 			var usage *Usage | 			var usage *openai.Usage | ||||||
| 			switch relayMode { | 			switch relayMode { | ||||||
| 			case RelayModeEmbeddings: | 			case constant.RelayModeEmbeddings: | ||||||
| 				err, usage = baiduEmbeddingHandler(c, resp) | 				err, usage = baidu.EmbeddingHandler(c, resp) | ||||||
| 			default: | 			default: | ||||||
| 				err, usage = baiduHandler(c, resp) | 				err, usage = baidu.Handler(c, resp) | ||||||
| 			} | 			} | ||||||
| 			if err != nil { | 			if err != nil { | ||||||
| 				return err | 				return err | ||||||
| @@ -536,15 +528,15 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { | |||||||
| 		} | 		} | ||||||
| 	case APITypePaLM: | 	case APITypePaLM: | ||||||
| 		if textRequest.Stream { // PaLM2 API does not support stream | 		if textRequest.Stream { // PaLM2 API does not support stream | ||||||
| 			err, responseText := palmStreamHandler(c, resp) | 			err, responseText := google.PaLMStreamHandler(c, resp) | ||||||
| 			if err != nil { | 			if err != nil { | ||||||
| 				return err | 				return err | ||||||
| 			} | 			} | ||||||
| 			textResponse.Usage.PromptTokens = promptTokens | 			textResponse.Usage.PromptTokens = promptTokens | ||||||
| 			textResponse.Usage.CompletionTokens = countTokenText(responseText, textRequest.Model) | 			textResponse.Usage.CompletionTokens = openai.CountTokenText(responseText, textRequest.Model) | ||||||
| 			return nil | 			return nil | ||||||
| 		} else { | 		} else { | ||||||
| 			err, usage := palmHandler(c, resp, promptTokens, textRequest.Model) | 			err, usage := google.PaLMHandler(c, resp, promptTokens, textRequest.Model) | ||||||
| 			if err != nil { | 			if err != nil { | ||||||
| 				return err | 				return err | ||||||
| 			} | 			} | ||||||
| @@ -555,15 +547,15 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { | |||||||
| 		} | 		} | ||||||
| 	case APITypeGemini: | 	case APITypeGemini: | ||||||
| 		if textRequest.Stream { | 		if textRequest.Stream { | ||||||
| 			err, responseText := geminiChatStreamHandler(c, resp) | 			err, responseText := google.StreamHandler(c, resp) | ||||||
| 			if err != nil { | 			if err != nil { | ||||||
| 				return err | 				return err | ||||||
| 			} | 			} | ||||||
| 			textResponse.Usage.PromptTokens = promptTokens | 			textResponse.Usage.PromptTokens = promptTokens | ||||||
| 			textResponse.Usage.CompletionTokens = countTokenText(responseText, textRequest.Model) | 			textResponse.Usage.CompletionTokens = openai.CountTokenText(responseText, textRequest.Model) | ||||||
| 			return nil | 			return nil | ||||||
| 		} else { | 		} else { | ||||||
| 			err, usage := geminiChatHandler(c, resp, promptTokens, textRequest.Model) | 			err, usage := google.GeminiHandler(c, resp, promptTokens, textRequest.Model) | ||||||
| 			if err != nil { | 			if err != nil { | ||||||
| 				return err | 				return err | ||||||
| 			} | 			} | ||||||
| @@ -574,7 +566,7 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { | |||||||
| 		} | 		} | ||||||
| 	case APITypeZhipu: | 	case APITypeZhipu: | ||||||
| 		if isStream { | 		if isStream { | ||||||
| 			err, usage := zhipuStreamHandler(c, resp) | 			err, usage := zhipu.StreamHandler(c, resp) | ||||||
| 			if err != nil { | 			if err != nil { | ||||||
| 				return err | 				return err | ||||||
| 			} | 			} | ||||||
| @@ -585,7 +577,7 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { | |||||||
| 			textResponse.Usage.PromptTokens = textResponse.Usage.TotalTokens | 			textResponse.Usage.PromptTokens = textResponse.Usage.TotalTokens | ||||||
| 			return nil | 			return nil | ||||||
| 		} else { | 		} else { | ||||||
| 			err, usage := zhipuHandler(c, resp) | 			err, usage := zhipu.Handler(c, resp) | ||||||
| 			if err != nil { | 			if err != nil { | ||||||
| 				return err | 				return err | ||||||
| 			} | 			} | ||||||
| @@ -598,7 +590,7 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { | |||||||
| 		} | 		} | ||||||
| 	case APITypeAli: | 	case APITypeAli: | ||||||
| 		if isStream { | 		if isStream { | ||||||
| 			err, usage := aliStreamHandler(c, resp) | 			err, usage := ali.StreamHandler(c, resp) | ||||||
| 			if err != nil { | 			if err != nil { | ||||||
| 				return err | 				return err | ||||||
| 			} | 			} | ||||||
| @@ -607,13 +599,13 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { | |||||||
| 			} | 			} | ||||||
| 			return nil | 			return nil | ||||||
| 		} else { | 		} else { | ||||||
| 			var err *OpenAIErrorWithStatusCode | 			var err *openai.ErrorWithStatusCode | ||||||
| 			var usage *Usage | 			var usage *openai.Usage | ||||||
| 			switch relayMode { | 			switch relayMode { | ||||||
| 			case RelayModeEmbeddings: | 			case constant.RelayModeEmbeddings: | ||||||
| 				err, usage = aliEmbeddingHandler(c, resp) | 				err, usage = ali.EmbeddingHandler(c, resp) | ||||||
| 			default: | 			default: | ||||||
| 				err, usage = aliHandler(c, resp) | 				err, usage = ali.Handler(c, resp) | ||||||
| 			} | 			} | ||||||
| 			if err != nil { | 			if err != nil { | ||||||
| 				return err | 				return err | ||||||
| @@ -628,14 +620,14 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { | |||||||
| 		auth = strings.TrimPrefix(auth, "Bearer ") | 		auth = strings.TrimPrefix(auth, "Bearer ") | ||||||
| 		splits := strings.Split(auth, "|") | 		splits := strings.Split(auth, "|") | ||||||
| 		if len(splits) != 3 { | 		if len(splits) != 3 { | ||||||
| 			return errorWrapper(errors.New("invalid auth"), "invalid_auth", http.StatusBadRequest) | 			return openai.ErrorWrapper(errors.New("invalid auth"), "invalid_auth", http.StatusBadRequest) | ||||||
| 		} | 		} | ||||||
| 		var err *OpenAIErrorWithStatusCode | 		var err *openai.ErrorWithStatusCode | ||||||
| 		var usage *Usage | 		var usage *openai.Usage | ||||||
| 		if isStream { | 		if isStream { | ||||||
| 			err, usage = xunfeiStreamHandler(c, textRequest, splits[0], splits[1], splits[2]) | 			err, usage = xunfei.StreamHandler(c, textRequest, splits[0], splits[1], splits[2]) | ||||||
| 		} else { | 		} else { | ||||||
| 			err, usage = xunfeiHandler(c, textRequest, splits[0], splits[1], splits[2]) | 			err, usage = xunfei.Handler(c, textRequest, splits[0], splits[1], splits[2]) | ||||||
| 		} | 		} | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			return err | 			return err | ||||||
| @@ -646,7 +638,7 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { | |||||||
| 		return nil | 		return nil | ||||||
| 	case APITypeAIProxyLibrary: | 	case APITypeAIProxyLibrary: | ||||||
| 		if isStream { | 		if isStream { | ||||||
| 			err, usage := aiProxyLibraryStreamHandler(c, resp) | 			err, usage := aiproxy.StreamHandler(c, resp) | ||||||
| 			if err != nil { | 			if err != nil { | ||||||
| 				return err | 				return err | ||||||
| 			} | 			} | ||||||
| @@ -655,7 +647,7 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { | |||||||
| 			} | 			} | ||||||
| 			return nil | 			return nil | ||||||
| 		} else { | 		} else { | ||||||
| 			err, usage := aiProxyLibraryHandler(c, resp) | 			err, usage := aiproxy.Handler(c, resp) | ||||||
| 			if err != nil { | 			if err != nil { | ||||||
| 				return err | 				return err | ||||||
| 			} | 			} | ||||||
| @@ -666,15 +658,15 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { | |||||||
| 		} | 		} | ||||||
| 	case APITypeTencent: | 	case APITypeTencent: | ||||||
| 		if isStream { | 		if isStream { | ||||||
| 			err, responseText := tencentStreamHandler(c, resp) | 			err, responseText := tencent.StreamHandler(c, resp) | ||||||
| 			if err != nil { | 			if err != nil { | ||||||
| 				return err | 				return err | ||||||
| 			} | 			} | ||||||
| 			textResponse.Usage.PromptTokens = promptTokens | 			textResponse.Usage.PromptTokens = promptTokens | ||||||
| 			textResponse.Usage.CompletionTokens = countTokenText(responseText, textRequest.Model) | 			textResponse.Usage.CompletionTokens = openai.CountTokenText(responseText, textRequest.Model) | ||||||
| 			return nil | 			return nil | ||||||
| 		} else { | 		} else { | ||||||
| 			err, usage := tencentHandler(c, resp) | 			err, usage := tencent.Handler(c, resp) | ||||||
| 			if err != nil { | 			if err != nil { | ||||||
| 				return err | 				return err | ||||||
| 			} | 			} | ||||||
| @@ -684,6 +676,6 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { | |||||||
| 			return nil | 			return nil | ||||||
| 		} | 		} | ||||||
| 	default: | 	default: | ||||||
| 		return errorWrapper(errors.New("unknown api type"), "unknown_api_type", http.StatusInternalServerError) | 		return openai.ErrorWrapper(errors.New("unknown api type"), "unknown_api_type", http.StatusInternalServerError) | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
							
								
								
									
										166
									
								
								relay/util/common.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										166
									
								
								relay/util/common.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,166 @@ | |||||||
|  | package util | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"context" | ||||||
|  | 	"encoding/json" | ||||||
|  | 	"fmt" | ||||||
|  | 	"io" | ||||||
|  | 	"net/http" | ||||||
|  | 	"one-api/common" | ||||||
|  | 	"one-api/model" | ||||||
|  | 	"one-api/relay/channel/openai" | ||||||
|  | 	"strconv" | ||||||
|  | 	"strings" | ||||||
|  |  | ||||||
|  | 	"github.com/gin-gonic/gin" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | func ShouldDisableChannel(err *openai.Error, statusCode int) bool { | ||||||
|  | 	if !common.AutomaticDisableChannelEnabled { | ||||||
|  | 		return false | ||||||
|  | 	} | ||||||
|  | 	if err == nil { | ||||||
|  | 		return false | ||||||
|  | 	} | ||||||
|  | 	if statusCode == http.StatusUnauthorized { | ||||||
|  | 		return true | ||||||
|  | 	} | ||||||
|  | 	if err.Type == "insufficient_quota" || err.Code == "invalid_api_key" || err.Code == "account_deactivated" { | ||||||
|  | 		return true | ||||||
|  | 	} | ||||||
|  | 	return false | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func ShouldEnableChannel(err error, openAIErr *openai.Error) bool { | ||||||
|  | 	if !common.AutomaticEnableChannelEnabled { | ||||||
|  | 		return false | ||||||
|  | 	} | ||||||
|  | 	if err != nil { | ||||||
|  | 		return false | ||||||
|  | 	} | ||||||
|  | 	if openAIErr != nil { | ||||||
|  | 		return false | ||||||
|  | 	} | ||||||
|  | 	return true | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type GeneralErrorResponse struct { | ||||||
|  | 	Error    openai.Error `json:"error"` | ||||||
|  | 	Message  string       `json:"message"` | ||||||
|  | 	Msg      string       `json:"msg"` | ||||||
|  | 	Err      string       `json:"err"` | ||||||
|  | 	ErrorMsg string       `json:"error_msg"` | ||||||
|  | 	Header   struct { | ||||||
|  | 		Message string `json:"message"` | ||||||
|  | 	} `json:"header"` | ||||||
|  | 	Response struct { | ||||||
|  | 		Error struct { | ||||||
|  | 			Message string `json:"message"` | ||||||
|  | 		} `json:"error"` | ||||||
|  | 	} `json:"response"` | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (e GeneralErrorResponse) ToMessage() string { | ||||||
|  | 	if e.Error.Message != "" { | ||||||
|  | 		return e.Error.Message | ||||||
|  | 	} | ||||||
|  | 	if e.Message != "" { | ||||||
|  | 		return e.Message | ||||||
|  | 	} | ||||||
|  | 	if e.Msg != "" { | ||||||
|  | 		return e.Msg | ||||||
|  | 	} | ||||||
|  | 	if e.Err != "" { | ||||||
|  | 		return e.Err | ||||||
|  | 	} | ||||||
|  | 	if e.ErrorMsg != "" { | ||||||
|  | 		return e.ErrorMsg | ||||||
|  | 	} | ||||||
|  | 	if e.Header.Message != "" { | ||||||
|  | 		return e.Header.Message | ||||||
|  | 	} | ||||||
|  | 	if e.Response.Error.Message != "" { | ||||||
|  | 		return e.Response.Error.Message | ||||||
|  | 	} | ||||||
|  | 	return "" | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func RelayErrorHandler(resp *http.Response) (ErrorWithStatusCode *openai.ErrorWithStatusCode) { | ||||||
|  | 	ErrorWithStatusCode = &openai.ErrorWithStatusCode{ | ||||||
|  | 		StatusCode: resp.StatusCode, | ||||||
|  | 		Error: openai.Error{ | ||||||
|  | 			Message: "", | ||||||
|  | 			Type:    "upstream_error", | ||||||
|  | 			Code:    "bad_response_status_code", | ||||||
|  | 			Param:   strconv.Itoa(resp.StatusCode), | ||||||
|  | 		}, | ||||||
|  | 	} | ||||||
|  | 	responseBody, err := io.ReadAll(resp.Body) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  | 	err = resp.Body.Close() | ||||||
|  | 	if err != nil { | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  | 	var errResponse GeneralErrorResponse | ||||||
|  | 	err = json.Unmarshal(responseBody, &errResponse) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  | 	if errResponse.Error.Message != "" { | ||||||
|  | 		// OpenAI format error, so we override the default one | ||||||
|  | 		ErrorWithStatusCode.Error = errResponse.Error | ||||||
|  | 	} else { | ||||||
|  | 		ErrorWithStatusCode.Error.Message = errResponse.ToMessage() | ||||||
|  | 	} | ||||||
|  | 	if ErrorWithStatusCode.Error.Message == "" { | ||||||
|  | 		ErrorWithStatusCode.Error.Message = fmt.Sprintf("bad response status code %d", resp.StatusCode) | ||||||
|  | 	} | ||||||
|  | 	return | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func GetFullRequestURL(baseURL string, requestURL string, channelType int) string { | ||||||
|  | 	fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL) | ||||||
|  |  | ||||||
|  | 	if strings.HasPrefix(baseURL, "https://gateway.ai.cloudflare.com") { | ||||||
|  | 		switch channelType { | ||||||
|  | 		case common.ChannelTypeOpenAI: | ||||||
|  | 			fullRequestURL = fmt.Sprintf("%s%s", baseURL, strings.TrimPrefix(requestURL, "/v1")) | ||||||
|  | 		case common.ChannelTypeAzure: | ||||||
|  | 			fullRequestURL = fmt.Sprintf("%s%s", baseURL, strings.TrimPrefix(requestURL, "/openai/deployments")) | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 	return fullRequestURL | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func PostConsumeQuota(ctx context.Context, tokenId int, quotaDelta int, totalQuota int, userId int, channelId int, modelRatio float64, groupRatio float64, modelName string, tokenName string) { | ||||||
|  | 	// quotaDelta is remaining quota to be consumed | ||||||
|  | 	err := model.PostConsumeTokenQuota(tokenId, quotaDelta) | ||||||
|  | 	if err != nil { | ||||||
|  | 		common.SysError("error consuming token remain quota: " + err.Error()) | ||||||
|  | 	} | ||||||
|  | 	err = model.CacheUpdateUserQuota(userId) | ||||||
|  | 	if err != nil { | ||||||
|  | 		common.SysError("error update user quota cache: " + err.Error()) | ||||||
|  | 	} | ||||||
|  | 	// totalQuota is total quota consumed | ||||||
|  | 	if totalQuota != 0 { | ||||||
|  | 		logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio) | ||||||
|  | 		model.RecordConsumeLog(ctx, userId, channelId, totalQuota, 0, modelName, tokenName, totalQuota, logContent) | ||||||
|  | 		model.UpdateUserUsedQuotaAndRequestCount(userId, totalQuota) | ||||||
|  | 		model.UpdateChannelUsedQuota(channelId, totalQuota) | ||||||
|  | 	} | ||||||
|  | 	if totalQuota <= 0 { | ||||||
|  | 		common.LogError(ctx, fmt.Sprintf("totalQuota consumed is %d, something is wrong", totalQuota)) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func GetAPIVersion(c *gin.Context) string { | ||||||
|  | 	query := c.Request.URL.Query() | ||||||
|  | 	apiVersion := query.Get("api-version") | ||||||
|  | 	if apiVersion == "" { | ||||||
|  | 		apiVersion = c.GetString("api_version") | ||||||
|  | 	} | ||||||
|  | 	return apiVersion | ||||||
|  | } | ||||||
							
								
								
									
										24
									
								
								relay/util/init.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										24
									
								
								relay/util/init.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,24 @@ | |||||||
|  | package util | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"net/http" | ||||||
|  | 	"one-api/common" | ||||||
|  | 	"time" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | var HTTPClient *http.Client | ||||||
|  | var ImpatientHTTPClient *http.Client | ||||||
|  |  | ||||||
|  | func init() { | ||||||
|  | 	if common.RelayTimeout == 0 { | ||||||
|  | 		HTTPClient = &http.Client{} | ||||||
|  | 	} else { | ||||||
|  | 		HTTPClient = &http.Client{ | ||||||
|  | 			Timeout: time.Duration(common.RelayTimeout) * time.Second, | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	ImpatientHTTPClient = &http.Client{ | ||||||
|  | 		Timeout: 5 * time.Second, | ||||||
|  | 	} | ||||||
|  | } | ||||||
		Reference in New Issue
	
	Block a user