mirror of
				https://github.com/songquanpeng/one-api.git
				synced 2025-11-04 15:53:42 +08:00 
			
		
		
		
	Compare commits
	
		
			14 Commits
		
	
	
		
			v0.5.11-al
			...
			refactor
		
	
	| Author | SHA1 | Date | |
|---|---|---|---|
| 
						 | 
					e12b0c7aa8 | ||
| 
						 | 
					f2c51a494c | ||
| 
						 | 
					8a4d6f3327 | ||
| 
						 | 
					cf4e33cb12 | ||
| 
						 | 
					5d60305570 | ||
| 
						 | 
					d062bc60e4 | ||
| 
						 | 
					39c1882970 | ||
| 
						 | 
					9c42c7dfd9 | ||
| 
						 | 
					903aaeded0 | ||
| 
						 | 
					bdd4be562d | ||
| 
						 | 
					37afb313b5 | ||
| 
						 | 
					c9ebcab8b8 | ||
| 
						 | 
					86261cc656 | ||
| 
						 | 
					8491785c9d | 
@@ -414,6 +414,9 @@ https://openai.justsong.cn
 | 
				
			|||||||
8. 升级之前数据库需要做变更吗?
 | 
					8. 升级之前数据库需要做变更吗?
 | 
				
			||||||
   + 一般情况下不需要,系统将在初始化的时候自动调整。
 | 
					   + 一般情况下不需要,系统将在初始化的时候自动调整。
 | 
				
			||||||
   + 如果需要的话,我会在更新日志中说明,并给出脚本。
 | 
					   + 如果需要的话,我会在更新日志中说明,并给出脚本。
 | 
				
			||||||
 | 
					9. 手动修改数据库后报错:`数据库一致性已被破坏,请联系管理员`?
 | 
				
			||||||
 | 
					   + 这是检测到 ability 表里有些记录的通道 id 是不存在的,这大概率是因为你删了 channel 表里的记录但是没有同步在 ability 表里清理无效的通道。
 | 
				
			||||||
 | 
					   + 对于每一个通道,其所支持的模型都需要有一个专门的 ability 表的记录,表示该通道支持该模型。
 | 
				
			||||||
 | 
					
 | 
				
			||||||
## 相关项目
 | 
					## 相关项目
 | 
				
			||||||
* [FastGPT](https://github.com/labring/FastGPT): 基于 LLM 大语言模型的知识库问答系统
 | 
					* [FastGPT](https://github.com/labring/FastGPT): 基于 LLM 大语言模型的知识库问答系统
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -101,6 +101,10 @@ var RelayTimeout = GetOrDefault("RELAY_TIMEOUT", 0) // unit is second
 | 
				
			|||||||
var GeminiSafetySetting = GetOrDefaultString("GEMINI_SAFETY_SETTING", "BLOCK_NONE")
 | 
					var GeminiSafetySetting = GetOrDefaultString("GEMINI_SAFETY_SETTING", "BLOCK_NONE")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
var Theme = GetOrDefaultString("THEME", "default")
 | 
					var Theme = GetOrDefaultString("THEME", "default")
 | 
				
			||||||
 | 
					var ValidThemes = map[string]bool{
 | 
				
			||||||
 | 
						"default": true,
 | 
				
			||||||
 | 
						"berry":   true,
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
const (
 | 
					const (
 | 
				
			||||||
	RequestIdKey = "X-Oneapi-Request-Id"
 | 
						RequestIdKey = "X-Oneapi-Request-Id"
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -31,3 +31,11 @@ func UnmarshalBodyReusable(c *gin.Context, v any) error {
 | 
				
			|||||||
	c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
 | 
						c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
 | 
				
			||||||
	return nil
 | 
						return nil
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func SetEventStreamHeaders(c *gin.Context) {
 | 
				
			||||||
 | 
						c.Writer.Header().Set("Content-Type", "text/event-stream")
 | 
				
			||||||
 | 
						c.Writer.Header().Set("Cache-Control", "no-cache")
 | 
				
			||||||
 | 
						c.Writer.Header().Set("Connection", "keep-alive")
 | 
				
			||||||
 | 
						c.Writer.Header().Set("Transfer-Encoding", "chunked")
 | 
				
			||||||
 | 
						c.Writer.Header().Set("X-Accel-Buffering", "no")
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -4,6 +4,7 @@ import (
 | 
				
			|||||||
	"github.com/gin-gonic/gin"
 | 
						"github.com/gin-gonic/gin"
 | 
				
			||||||
	"one-api/common"
 | 
						"one-api/common"
 | 
				
			||||||
	"one-api/model"
 | 
						"one-api/model"
 | 
				
			||||||
 | 
						"one-api/relay/channel/openai"
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func GetSubscription(c *gin.Context) {
 | 
					func GetSubscription(c *gin.Context) {
 | 
				
			||||||
@@ -27,12 +28,12 @@ func GetSubscription(c *gin.Context) {
 | 
				
			|||||||
		expiredTime = 0
 | 
							expiredTime = 0
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		openAIError := OpenAIError{
 | 
							Error := openai.Error{
 | 
				
			||||||
			Message: err.Error(),
 | 
								Message: err.Error(),
 | 
				
			||||||
			Type:    "upstream_error",
 | 
								Type:    "upstream_error",
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
		c.JSON(200, gin.H{
 | 
							c.JSON(200, gin.H{
 | 
				
			||||||
			"error": openAIError,
 | 
								"error": Error,
 | 
				
			||||||
		})
 | 
							})
 | 
				
			||||||
		return
 | 
							return
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
@@ -69,12 +70,12 @@ func GetUsage(c *gin.Context) {
 | 
				
			|||||||
		quota, err = model.GetUserUsedQuota(userId)
 | 
							quota, err = model.GetUserUsedQuota(userId)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		openAIError := OpenAIError{
 | 
							Error := openai.Error{
 | 
				
			||||||
			Message: err.Error(),
 | 
								Message: err.Error(),
 | 
				
			||||||
			Type:    "one_api_error",
 | 
								Type:    "one_api_error",
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
		c.JSON(200, gin.H{
 | 
							c.JSON(200, gin.H{
 | 
				
			||||||
			"error": openAIError,
 | 
								"error": Error,
 | 
				
			||||||
		})
 | 
							})
 | 
				
			||||||
		return
 | 
							return
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -8,6 +8,7 @@ import (
 | 
				
			|||||||
	"net/http"
 | 
						"net/http"
 | 
				
			||||||
	"one-api/common"
 | 
						"one-api/common"
 | 
				
			||||||
	"one-api/model"
 | 
						"one-api/model"
 | 
				
			||||||
 | 
						"one-api/relay/util"
 | 
				
			||||||
	"strconv"
 | 
						"strconv"
 | 
				
			||||||
	"time"
 | 
						"time"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -92,7 +93,7 @@ func GetResponseBody(method, url string, channel *model.Channel, headers http.He
 | 
				
			|||||||
	for k := range headers {
 | 
						for k := range headers {
 | 
				
			||||||
		req.Header.Add(k, headers.Get(k))
 | 
							req.Header.Add(k, headers.Get(k))
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	res, err := httpClient.Do(req)
 | 
						res, err := util.HTTPClient.Do(req)
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return nil, err
 | 
							return nil, err
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -9,6 +9,8 @@ import (
 | 
				
			|||||||
	"net/http"
 | 
						"net/http"
 | 
				
			||||||
	"one-api/common"
 | 
						"one-api/common"
 | 
				
			||||||
	"one-api/model"
 | 
						"one-api/model"
 | 
				
			||||||
 | 
						"one-api/relay/channel/openai"
 | 
				
			||||||
 | 
						"one-api/relay/util"
 | 
				
			||||||
	"strconv"
 | 
						"strconv"
 | 
				
			||||||
	"sync"
 | 
						"sync"
 | 
				
			||||||
	"time"
 | 
						"time"
 | 
				
			||||||
@@ -16,7 +18,7 @@ import (
 | 
				
			|||||||
	"github.com/gin-gonic/gin"
 | 
						"github.com/gin-gonic/gin"
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func testChannel(channel *model.Channel, request ChatRequest) (err error, openaiErr *OpenAIError) {
 | 
					func testChannel(channel *model.Channel, request openai.ChatRequest) (err error, openaiErr *openai.Error) {
 | 
				
			||||||
	switch channel.Type {
 | 
						switch channel.Type {
 | 
				
			||||||
	case common.ChannelTypePaLM:
 | 
						case common.ChannelTypePaLM:
 | 
				
			||||||
		fallthrough
 | 
							fallthrough
 | 
				
			||||||
@@ -46,13 +48,13 @@ func testChannel(channel *model.Channel, request ChatRequest) (err error, openai
 | 
				
			|||||||
	}
 | 
						}
 | 
				
			||||||
	requestURL := common.ChannelBaseURLs[channel.Type]
 | 
						requestURL := common.ChannelBaseURLs[channel.Type]
 | 
				
			||||||
	if channel.Type == common.ChannelTypeAzure {
 | 
						if channel.Type == common.ChannelTypeAzure {
 | 
				
			||||||
		requestURL = getFullRequestURL(channel.GetBaseURL(), fmt.Sprintf("/openai/deployments/%s/chat/completions?api-version=2023-03-15-preview", request.Model), channel.Type)
 | 
							requestURL = util.GetFullRequestURL(channel.GetBaseURL(), fmt.Sprintf("/openai/deployments/%s/chat/completions?api-version=2023-03-15-preview", request.Model), channel.Type)
 | 
				
			||||||
	} else {
 | 
						} else {
 | 
				
			||||||
		if baseURL := channel.GetBaseURL(); len(baseURL) > 0 {
 | 
							if baseURL := channel.GetBaseURL(); len(baseURL) > 0 {
 | 
				
			||||||
			requestURL = baseURL
 | 
								requestURL = baseURL
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		requestURL = getFullRequestURL(requestURL, "/v1/chat/completions", channel.Type)
 | 
							requestURL = util.GetFullRequestURL(requestURL, "/v1/chat/completions", channel.Type)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	jsonData, err := json.Marshal(request)
 | 
						jsonData, err := json.Marshal(request)
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
@@ -68,12 +70,12 @@ func testChannel(channel *model.Channel, request ChatRequest) (err error, openai
 | 
				
			|||||||
		req.Header.Set("Authorization", "Bearer "+channel.Key)
 | 
							req.Header.Set("Authorization", "Bearer "+channel.Key)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	req.Header.Set("Content-Type", "application/json")
 | 
						req.Header.Set("Content-Type", "application/json")
 | 
				
			||||||
	resp, err := httpClient.Do(req)
 | 
						resp, err := util.HTTPClient.Do(req)
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return err, nil
 | 
							return err, nil
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	defer resp.Body.Close()
 | 
						defer resp.Body.Close()
 | 
				
			||||||
	var response TextResponse
 | 
						var response openai.SlimTextResponse
 | 
				
			||||||
	body, err := io.ReadAll(resp.Body)
 | 
						body, err := io.ReadAll(resp.Body)
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return err, nil
 | 
							return err, nil
 | 
				
			||||||
@@ -91,12 +93,12 @@ func testChannel(channel *model.Channel, request ChatRequest) (err error, openai
 | 
				
			|||||||
	return nil, nil
 | 
						return nil, nil
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func buildTestRequest() *ChatRequest {
 | 
					func buildTestRequest() *openai.ChatRequest {
 | 
				
			||||||
	testRequest := &ChatRequest{
 | 
						testRequest := &openai.ChatRequest{
 | 
				
			||||||
		Model:     "", // this will be set later
 | 
							Model:     "", // this will be set later
 | 
				
			||||||
		MaxTokens: 1,
 | 
							MaxTokens: 1,
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	testMessage := Message{
 | 
						testMessage := openai.Message{
 | 
				
			||||||
		Role:    "user",
 | 
							Role:    "user",
 | 
				
			||||||
		Content: "hi",
 | 
							Content: "hi",
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
@@ -204,10 +206,10 @@ func testAllChannels(notify bool) error {
 | 
				
			|||||||
				err = errors.New(fmt.Sprintf("响应时间 %.2fs 超过阈值 %.2fs", float64(milliseconds)/1000.0, float64(disableThreshold)/1000.0))
 | 
									err = errors.New(fmt.Sprintf("响应时间 %.2fs 超过阈值 %.2fs", float64(milliseconds)/1000.0, float64(disableThreshold)/1000.0))
 | 
				
			||||||
				disableChannel(channel.Id, channel.Name, err.Error())
 | 
									disableChannel(channel.Id, channel.Name, err.Error())
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
			if isChannelEnabled && shouldDisableChannel(openaiErr, -1) {
 | 
								if isChannelEnabled && util.ShouldDisableChannel(openaiErr, -1) {
 | 
				
			||||||
				disableChannel(channel.Id, channel.Name, err.Error())
 | 
									disableChannel(channel.Id, channel.Name, err.Error())
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
			if !isChannelEnabled && shouldEnableChannel(err, openaiErr) {
 | 
								if !isChannelEnabled && util.ShouldEnableChannel(err, openaiErr) {
 | 
				
			||||||
				enableChannel(channel.Id, channel.Name)
 | 
									enableChannel(channel.Id, channel.Name)
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
			channel.UpdateResponseTime(milliseconds)
 | 
								channel.UpdateResponseTime(milliseconds)
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -2,8 +2,8 @@ package controller
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
import (
 | 
					import (
 | 
				
			||||||
	"fmt"
 | 
						"fmt"
 | 
				
			||||||
 | 
					 | 
				
			||||||
	"github.com/gin-gonic/gin"
 | 
						"github.com/gin-gonic/gin"
 | 
				
			||||||
 | 
						"one-api/relay/channel/openai"
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// https://platform.openai.com/docs/api-reference/models/list
 | 
					// https://platform.openai.com/docs/api-reference/models/list
 | 
				
			||||||
@@ -613,14 +613,14 @@ func RetrieveModel(c *gin.Context) {
 | 
				
			|||||||
	if model, ok := openAIModelsMap[modelId]; ok {
 | 
						if model, ok := openAIModelsMap[modelId]; ok {
 | 
				
			||||||
		c.JSON(200, model)
 | 
							c.JSON(200, model)
 | 
				
			||||||
	} else {
 | 
						} else {
 | 
				
			||||||
		openAIError := OpenAIError{
 | 
							Error := openai.Error{
 | 
				
			||||||
			Message: fmt.Sprintf("The model '%s' does not exist", modelId),
 | 
								Message: fmt.Sprintf("The model '%s' does not exist", modelId),
 | 
				
			||||||
			Type:    "invalid_request_error",
 | 
								Type:    "invalid_request_error",
 | 
				
			||||||
			Param:   "model",
 | 
								Param:   "model",
 | 
				
			||||||
			Code:    "model_not_found",
 | 
								Code:    "model_not_found",
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
		c.JSON(200, gin.H{
 | 
							c.JSON(200, gin.H{
 | 
				
			||||||
			"error": openAIError,
 | 
								"error": Error,
 | 
				
			||||||
		})
 | 
							})
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -42,6 +42,14 @@ func UpdateOption(c *gin.Context) {
 | 
				
			|||||||
		return
 | 
							return
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	switch option.Key {
 | 
						switch option.Key {
 | 
				
			||||||
 | 
						case "Theme":
 | 
				
			||||||
 | 
							if !common.ValidThemes[option.Value] {
 | 
				
			||||||
 | 
								c.JSON(http.StatusOK, gin.H{
 | 
				
			||||||
 | 
									"success": false,
 | 
				
			||||||
 | 
									"message": "无效的主题",
 | 
				
			||||||
 | 
								})
 | 
				
			||||||
 | 
								return
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
	case "GitHubOAuthEnabled":
 | 
						case "GitHubOAuthEnabled":
 | 
				
			||||||
		if option.Value == "true" && common.GitHubClientId == "" {
 | 
							if option.Value == "true" && common.GitHubClientId == "" {
 | 
				
			||||||
			c.JSON(http.StatusOK, gin.H{
 | 
								c.JSON(http.StatusOK, gin.H{
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -4,349 +4,53 @@ import (
 | 
				
			|||||||
	"fmt"
 | 
						"fmt"
 | 
				
			||||||
	"net/http"
 | 
						"net/http"
 | 
				
			||||||
	"one-api/common"
 | 
						"one-api/common"
 | 
				
			||||||
 | 
						"one-api/relay/channel/openai"
 | 
				
			||||||
 | 
						"one-api/relay/constant"
 | 
				
			||||||
 | 
						"one-api/relay/controller"
 | 
				
			||||||
 | 
						"one-api/relay/util"
 | 
				
			||||||
	"strconv"
 | 
						"strconv"
 | 
				
			||||||
	"strings"
 | 
						"strings"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	"github.com/gin-gonic/gin"
 | 
						"github.com/gin-gonic/gin"
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
type Message struct {
 | 
					 | 
				
			||||||
	Role    string  `json:"role"`
 | 
					 | 
				
			||||||
	Content any     `json:"content"`
 | 
					 | 
				
			||||||
	Name    *string `json:"name,omitempty"`
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
type ImageURL struct {
 | 
					 | 
				
			||||||
	Url    string `json:"url,omitempty"`
 | 
					 | 
				
			||||||
	Detail string `json:"detail,omitempty"`
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
type TextContent struct {
 | 
					 | 
				
			||||||
	Type string `json:"type,omitempty"`
 | 
					 | 
				
			||||||
	Text string `json:"text,omitempty"`
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
type ImageContent struct {
 | 
					 | 
				
			||||||
	Type     string    `json:"type,omitempty"`
 | 
					 | 
				
			||||||
	ImageURL *ImageURL `json:"image_url,omitempty"`
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
const (
 | 
					 | 
				
			||||||
	ContentTypeText     = "text"
 | 
					 | 
				
			||||||
	ContentTypeImageURL = "image_url"
 | 
					 | 
				
			||||||
)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
type OpenAIMessageContent struct {
 | 
					 | 
				
			||||||
	Type     string    `json:"type,omitempty"`
 | 
					 | 
				
			||||||
	Text     string    `json:"text"`
 | 
					 | 
				
			||||||
	ImageURL *ImageURL `json:"image_url,omitempty"`
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func (m Message) IsStringContent() bool {
 | 
					 | 
				
			||||||
	_, ok := m.Content.(string)
 | 
					 | 
				
			||||||
	return ok
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func (m Message) StringContent() string {
 | 
					 | 
				
			||||||
	content, ok := m.Content.(string)
 | 
					 | 
				
			||||||
	if ok {
 | 
					 | 
				
			||||||
		return content
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	contentList, ok := m.Content.([]any)
 | 
					 | 
				
			||||||
	if ok {
 | 
					 | 
				
			||||||
		var contentStr string
 | 
					 | 
				
			||||||
		for _, contentItem := range contentList {
 | 
					 | 
				
			||||||
			contentMap, ok := contentItem.(map[string]any)
 | 
					 | 
				
			||||||
			if !ok {
 | 
					 | 
				
			||||||
				continue
 | 
					 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
			if contentMap["type"] == ContentTypeText {
 | 
					 | 
				
			||||||
				if subStr, ok := contentMap["text"].(string); ok {
 | 
					 | 
				
			||||||
					contentStr += subStr
 | 
					 | 
				
			||||||
				}
 | 
					 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
		return contentStr
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	return ""
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func (m Message) ParseContent() []OpenAIMessageContent {
 | 
					 | 
				
			||||||
	var contentList []OpenAIMessageContent
 | 
					 | 
				
			||||||
	content, ok := m.Content.(string)
 | 
					 | 
				
			||||||
	if ok {
 | 
					 | 
				
			||||||
		contentList = append(contentList, OpenAIMessageContent{
 | 
					 | 
				
			||||||
			Type: ContentTypeText,
 | 
					 | 
				
			||||||
			Text: content,
 | 
					 | 
				
			||||||
		})
 | 
					 | 
				
			||||||
		return contentList
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	anyList, ok := m.Content.([]any)
 | 
					 | 
				
			||||||
	if ok {
 | 
					 | 
				
			||||||
		for _, contentItem := range anyList {
 | 
					 | 
				
			||||||
			contentMap, ok := contentItem.(map[string]any)
 | 
					 | 
				
			||||||
			if !ok {
 | 
					 | 
				
			||||||
				continue
 | 
					 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
			switch contentMap["type"] {
 | 
					 | 
				
			||||||
			case ContentTypeText:
 | 
					 | 
				
			||||||
				if subStr, ok := contentMap["text"].(string); ok {
 | 
					 | 
				
			||||||
					contentList = append(contentList, OpenAIMessageContent{
 | 
					 | 
				
			||||||
						Type: ContentTypeText,
 | 
					 | 
				
			||||||
						Text: subStr,
 | 
					 | 
				
			||||||
					})
 | 
					 | 
				
			||||||
				}
 | 
					 | 
				
			||||||
			case ContentTypeImageURL:
 | 
					 | 
				
			||||||
				if subObj, ok := contentMap["image_url"].(map[string]any); ok {
 | 
					 | 
				
			||||||
					contentList = append(contentList, OpenAIMessageContent{
 | 
					 | 
				
			||||||
						Type: ContentTypeImageURL,
 | 
					 | 
				
			||||||
						ImageURL: &ImageURL{
 | 
					 | 
				
			||||||
							Url: subObj["url"].(string),
 | 
					 | 
				
			||||||
						},
 | 
					 | 
				
			||||||
					})
 | 
					 | 
				
			||||||
				}
 | 
					 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
		return contentList
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	return nil
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
const (
 | 
					 | 
				
			||||||
	RelayModeUnknown = iota
 | 
					 | 
				
			||||||
	RelayModeChatCompletions
 | 
					 | 
				
			||||||
	RelayModeCompletions
 | 
					 | 
				
			||||||
	RelayModeEmbeddings
 | 
					 | 
				
			||||||
	RelayModeModerations
 | 
					 | 
				
			||||||
	RelayModeImagesGenerations
 | 
					 | 
				
			||||||
	RelayModeEdits
 | 
					 | 
				
			||||||
	RelayModeAudioSpeech
 | 
					 | 
				
			||||||
	RelayModeAudioTranscription
 | 
					 | 
				
			||||||
	RelayModeAudioTranslation
 | 
					 | 
				
			||||||
)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
// https://platform.openai.com/docs/api-reference/chat
 | 
					// https://platform.openai.com/docs/api-reference/chat
 | 
				
			||||||
 | 
					
 | 
				
			||||||
type ResponseFormat struct {
 | 
					 | 
				
			||||||
	Type string `json:"type,omitempty"`
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
type GeneralOpenAIRequest struct {
 | 
					 | 
				
			||||||
	Model            string          `json:"model,omitempty"`
 | 
					 | 
				
			||||||
	Messages         []Message       `json:"messages,omitempty"`
 | 
					 | 
				
			||||||
	Prompt           any             `json:"prompt,omitempty"`
 | 
					 | 
				
			||||||
	Stream           bool            `json:"stream,omitempty"`
 | 
					 | 
				
			||||||
	MaxTokens        int             `json:"max_tokens,omitempty"`
 | 
					 | 
				
			||||||
	Temperature      float64         `json:"temperature,omitempty"`
 | 
					 | 
				
			||||||
	TopP             float64         `json:"top_p,omitempty"`
 | 
					 | 
				
			||||||
	N                int             `json:"n,omitempty"`
 | 
					 | 
				
			||||||
	Input            any             `json:"input,omitempty"`
 | 
					 | 
				
			||||||
	Instruction      string          `json:"instruction,omitempty"`
 | 
					 | 
				
			||||||
	Size             string          `json:"size,omitempty"`
 | 
					 | 
				
			||||||
	Functions        any             `json:"functions,omitempty"`
 | 
					 | 
				
			||||||
	FrequencyPenalty float64         `json:"frequency_penalty,omitempty"`
 | 
					 | 
				
			||||||
	PresencePenalty  float64         `json:"presence_penalty,omitempty"`
 | 
					 | 
				
			||||||
	ResponseFormat   *ResponseFormat `json:"response_format,omitempty"`
 | 
					 | 
				
			||||||
	Seed             float64         `json:"seed,omitempty"`
 | 
					 | 
				
			||||||
	Tools            any             `json:"tools,omitempty"`
 | 
					 | 
				
			||||||
	ToolChoice       any             `json:"tool_choice,omitempty"`
 | 
					 | 
				
			||||||
	User             string          `json:"user,omitempty"`
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func (r GeneralOpenAIRequest) ParseInput() []string {
 | 
					 | 
				
			||||||
	if r.Input == nil {
 | 
					 | 
				
			||||||
		return nil
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	var input []string
 | 
					 | 
				
			||||||
	switch r.Input.(type) {
 | 
					 | 
				
			||||||
	case string:
 | 
					 | 
				
			||||||
		input = []string{r.Input.(string)}
 | 
					 | 
				
			||||||
	case []any:
 | 
					 | 
				
			||||||
		input = make([]string, 0, len(r.Input.([]any)))
 | 
					 | 
				
			||||||
		for _, item := range r.Input.([]any) {
 | 
					 | 
				
			||||||
			if str, ok := item.(string); ok {
 | 
					 | 
				
			||||||
				input = append(input, str)
 | 
					 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	return input
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
type ChatRequest struct {
 | 
					 | 
				
			||||||
	Model     string    `json:"model"`
 | 
					 | 
				
			||||||
	Messages  []Message `json:"messages"`
 | 
					 | 
				
			||||||
	MaxTokens int       `json:"max_tokens"`
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
type TextRequest struct {
 | 
					 | 
				
			||||||
	Model     string    `json:"model"`
 | 
					 | 
				
			||||||
	Messages  []Message `json:"messages"`
 | 
					 | 
				
			||||||
	Prompt    string    `json:"prompt"`
 | 
					 | 
				
			||||||
	MaxTokens int       `json:"max_tokens"`
 | 
					 | 
				
			||||||
	//Stream   bool      `json:"stream"`
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
// ImageRequest docs: https://platform.openai.com/docs/api-reference/images/create
 | 
					 | 
				
			||||||
type ImageRequest struct {
 | 
					 | 
				
			||||||
	Model          string `json:"model"`
 | 
					 | 
				
			||||||
	Prompt         string `json:"prompt" binding:"required"`
 | 
					 | 
				
			||||||
	N              int    `json:"n,omitempty"`
 | 
					 | 
				
			||||||
	Size           string `json:"size,omitempty"`
 | 
					 | 
				
			||||||
	Quality        string `json:"quality,omitempty"`
 | 
					 | 
				
			||||||
	ResponseFormat string `json:"response_format,omitempty"`
 | 
					 | 
				
			||||||
	Style          string `json:"style,omitempty"`
 | 
					 | 
				
			||||||
	User           string `json:"user,omitempty"`
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
type WhisperJSONResponse struct {
 | 
					 | 
				
			||||||
	Text string `json:"text,omitempty"`
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
type WhisperVerboseJSONResponse struct {
 | 
					 | 
				
			||||||
	Task     string    `json:"task,omitempty"`
 | 
					 | 
				
			||||||
	Language string    `json:"language,omitempty"`
 | 
					 | 
				
			||||||
	Duration float64   `json:"duration,omitempty"`
 | 
					 | 
				
			||||||
	Text     string    `json:"text,omitempty"`
 | 
					 | 
				
			||||||
	Segments []Segment `json:"segments,omitempty"`
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
type Segment struct {
 | 
					 | 
				
			||||||
	Id               int     `json:"id"`
 | 
					 | 
				
			||||||
	Seek             int     `json:"seek"`
 | 
					 | 
				
			||||||
	Start            float64 `json:"start"`
 | 
					 | 
				
			||||||
	End              float64 `json:"end"`
 | 
					 | 
				
			||||||
	Text             string  `json:"text"`
 | 
					 | 
				
			||||||
	Tokens           []int   `json:"tokens"`
 | 
					 | 
				
			||||||
	Temperature      float64 `json:"temperature"`
 | 
					 | 
				
			||||||
	AvgLogprob       float64 `json:"avg_logprob"`
 | 
					 | 
				
			||||||
	CompressionRatio float64 `json:"compression_ratio"`
 | 
					 | 
				
			||||||
	NoSpeechProb     float64 `json:"no_speech_prob"`
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
type TextToSpeechRequest struct {
 | 
					 | 
				
			||||||
	Model          string  `json:"model" binding:"required"`
 | 
					 | 
				
			||||||
	Input          string  `json:"input" binding:"required"`
 | 
					 | 
				
			||||||
	Voice          string  `json:"voice" binding:"required"`
 | 
					 | 
				
			||||||
	Speed          float64 `json:"speed"`
 | 
					 | 
				
			||||||
	ResponseFormat string  `json:"response_format"`
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
type Usage struct {
 | 
					 | 
				
			||||||
	PromptTokens     int `json:"prompt_tokens"`
 | 
					 | 
				
			||||||
	CompletionTokens int `json:"completion_tokens"`
 | 
					 | 
				
			||||||
	TotalTokens      int `json:"total_tokens"`
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
type OpenAIError struct {
 | 
					 | 
				
			||||||
	Message string `json:"message"`
 | 
					 | 
				
			||||||
	Type    string `json:"type"`
 | 
					 | 
				
			||||||
	Param   string `json:"param"`
 | 
					 | 
				
			||||||
	Code    any    `json:"code"`
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
type OpenAIErrorWithStatusCode struct {
 | 
					 | 
				
			||||||
	OpenAIError
 | 
					 | 
				
			||||||
	StatusCode int `json:"status_code"`
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
type TextResponse struct {
 | 
					 | 
				
			||||||
	Choices []OpenAITextResponseChoice `json:"choices"`
 | 
					 | 
				
			||||||
	Usage   `json:"usage"`
 | 
					 | 
				
			||||||
	Error   OpenAIError `json:"error"`
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
type OpenAITextResponseChoice struct {
 | 
					 | 
				
			||||||
	Index        int `json:"index"`
 | 
					 | 
				
			||||||
	Message      `json:"message"`
 | 
					 | 
				
			||||||
	FinishReason string `json:"finish_reason"`
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
type OpenAITextResponse struct {
 | 
					 | 
				
			||||||
	Id      string                     `json:"id"`
 | 
					 | 
				
			||||||
	Model   string                     `json:"model,omitempty"`
 | 
					 | 
				
			||||||
	Object  string                     `json:"object"`
 | 
					 | 
				
			||||||
	Created int64                      `json:"created"`
 | 
					 | 
				
			||||||
	Choices []OpenAITextResponseChoice `json:"choices"`
 | 
					 | 
				
			||||||
	Usage   `json:"usage"`
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
type OpenAIEmbeddingResponseItem struct {
 | 
					 | 
				
			||||||
	Object    string    `json:"object"`
 | 
					 | 
				
			||||||
	Index     int       `json:"index"`
 | 
					 | 
				
			||||||
	Embedding []float64 `json:"embedding"`
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
type OpenAIEmbeddingResponse struct {
 | 
					 | 
				
			||||||
	Object string                        `json:"object"`
 | 
					 | 
				
			||||||
	Data   []OpenAIEmbeddingResponseItem `json:"data"`
 | 
					 | 
				
			||||||
	Model  string                        `json:"model"`
 | 
					 | 
				
			||||||
	Usage  `json:"usage"`
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
type ImageResponse struct {
 | 
					 | 
				
			||||||
	Created int `json:"created"`
 | 
					 | 
				
			||||||
	Data    []struct {
 | 
					 | 
				
			||||||
		Url string `json:"url"`
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
type ChatCompletionsStreamResponseChoice struct {
 | 
					 | 
				
			||||||
	Delta struct {
 | 
					 | 
				
			||||||
		Content string `json:"content"`
 | 
					 | 
				
			||||||
	} `json:"delta"`
 | 
					 | 
				
			||||||
	FinishReason *string `json:"finish_reason,omitempty"`
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
type ChatCompletionsStreamResponse struct {
 | 
					 | 
				
			||||||
	Id      string                                `json:"id"`
 | 
					 | 
				
			||||||
	Object  string                                `json:"object"`
 | 
					 | 
				
			||||||
	Created int64                                 `json:"created"`
 | 
					 | 
				
			||||||
	Model   string                                `json:"model"`
 | 
					 | 
				
			||||||
	Choices []ChatCompletionsStreamResponseChoice `json:"choices"`
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
type CompletionsStreamResponse struct {
 | 
					 | 
				
			||||||
	Choices []struct {
 | 
					 | 
				
			||||||
		Text         string `json:"text"`
 | 
					 | 
				
			||||||
		FinishReason string `json:"finish_reason"`
 | 
					 | 
				
			||||||
	} `json:"choices"`
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func Relay(c *gin.Context) {
 | 
					func Relay(c *gin.Context) {
 | 
				
			||||||
	relayMode := RelayModeUnknown
 | 
						relayMode := constant.RelayModeUnknown
 | 
				
			||||||
	if strings.HasPrefix(c.Request.URL.Path, "/v1/chat/completions") {
 | 
						if strings.HasPrefix(c.Request.URL.Path, "/v1/chat/completions") {
 | 
				
			||||||
		relayMode = RelayModeChatCompletions
 | 
							relayMode = constant.RelayModeChatCompletions
 | 
				
			||||||
	} else if strings.HasPrefix(c.Request.URL.Path, "/v1/completions") {
 | 
						} else if strings.HasPrefix(c.Request.URL.Path, "/v1/completions") {
 | 
				
			||||||
		relayMode = RelayModeCompletions
 | 
							relayMode = constant.RelayModeCompletions
 | 
				
			||||||
	} else if strings.HasPrefix(c.Request.URL.Path, "/v1/embeddings") {
 | 
						} else if strings.HasPrefix(c.Request.URL.Path, "/v1/embeddings") {
 | 
				
			||||||
		relayMode = RelayModeEmbeddings
 | 
							relayMode = constant.RelayModeEmbeddings
 | 
				
			||||||
	} else if strings.HasSuffix(c.Request.URL.Path, "embeddings") {
 | 
						} else if strings.HasSuffix(c.Request.URL.Path, "embeddings") {
 | 
				
			||||||
		relayMode = RelayModeEmbeddings
 | 
							relayMode = constant.RelayModeEmbeddings
 | 
				
			||||||
	} else if strings.HasPrefix(c.Request.URL.Path, "/v1/moderations") {
 | 
						} else if strings.HasPrefix(c.Request.URL.Path, "/v1/moderations") {
 | 
				
			||||||
		relayMode = RelayModeModerations
 | 
							relayMode = constant.RelayModeModerations
 | 
				
			||||||
	} else if strings.HasPrefix(c.Request.URL.Path, "/v1/images/generations") {
 | 
						} else if strings.HasPrefix(c.Request.URL.Path, "/v1/images/generations") {
 | 
				
			||||||
		relayMode = RelayModeImagesGenerations
 | 
							relayMode = constant.RelayModeImagesGenerations
 | 
				
			||||||
	} else if strings.HasPrefix(c.Request.URL.Path, "/v1/edits") {
 | 
						} else if strings.HasPrefix(c.Request.URL.Path, "/v1/edits") {
 | 
				
			||||||
		relayMode = RelayModeEdits
 | 
							relayMode = constant.RelayModeEdits
 | 
				
			||||||
	} else if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/speech") {
 | 
						} else if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/speech") {
 | 
				
			||||||
		relayMode = RelayModeAudioSpeech
 | 
							relayMode = constant.RelayModeAudioSpeech
 | 
				
			||||||
	} else if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") {
 | 
						} else if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") {
 | 
				
			||||||
		relayMode = RelayModeAudioTranscription
 | 
							relayMode = constant.RelayModeAudioTranscription
 | 
				
			||||||
	} else if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/translations") {
 | 
						} else if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/translations") {
 | 
				
			||||||
		relayMode = RelayModeAudioTranslation
 | 
							relayMode = constant.RelayModeAudioTranslation
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	var err *OpenAIErrorWithStatusCode
 | 
						var err *openai.ErrorWithStatusCode
 | 
				
			||||||
	switch relayMode {
 | 
						switch relayMode {
 | 
				
			||||||
	case RelayModeImagesGenerations:
 | 
						case constant.RelayModeImagesGenerations:
 | 
				
			||||||
		err = relayImageHelper(c, relayMode)
 | 
							err = controller.RelayImageHelper(c, relayMode)
 | 
				
			||||||
	case RelayModeAudioSpeech:
 | 
						case constant.RelayModeAudioSpeech:
 | 
				
			||||||
		fallthrough
 | 
							fallthrough
 | 
				
			||||||
	case RelayModeAudioTranslation:
 | 
						case constant.RelayModeAudioTranslation:
 | 
				
			||||||
		fallthrough
 | 
							fallthrough
 | 
				
			||||||
	case RelayModeAudioTranscription:
 | 
						case constant.RelayModeAudioTranscription:
 | 
				
			||||||
		err = relayAudioHelper(c, relayMode)
 | 
							err = controller.RelayAudioHelper(c, relayMode)
 | 
				
			||||||
	default:
 | 
						default:
 | 
				
			||||||
		err = relayTextHelper(c, relayMode)
 | 
							err = controller.RelayTextHelper(c, relayMode)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		requestId := c.GetString(common.RequestIdKey)
 | 
							requestId := c.GetString(common.RequestIdKey)
 | 
				
			||||||
@@ -359,17 +63,17 @@ func Relay(c *gin.Context) {
 | 
				
			|||||||
			c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s?retry=%d", c.Request.URL.Path, retryTimes-1))
 | 
								c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s?retry=%d", c.Request.URL.Path, retryTimes-1))
 | 
				
			||||||
		} else {
 | 
							} else {
 | 
				
			||||||
			if err.StatusCode == http.StatusTooManyRequests {
 | 
								if err.StatusCode == http.StatusTooManyRequests {
 | 
				
			||||||
				err.OpenAIError.Message = "当前分组上游负载已饱和,请稍后再试"
 | 
									err.Error.Message = "当前分组上游负载已饱和,请稍后再试"
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
			err.OpenAIError.Message = common.MessageWithRequestId(err.OpenAIError.Message, requestId)
 | 
								err.Error.Message = common.MessageWithRequestId(err.Error.Message, requestId)
 | 
				
			||||||
			c.JSON(err.StatusCode, gin.H{
 | 
								c.JSON(err.StatusCode, gin.H{
 | 
				
			||||||
				"error": err.OpenAIError,
 | 
									"error": err.Error,
 | 
				
			||||||
			})
 | 
								})
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
		channelId := c.GetInt("channel_id")
 | 
							channelId := c.GetInt("channel_id")
 | 
				
			||||||
		common.LogError(c.Request.Context(), fmt.Sprintf("relay error (channel #%d): %s", channelId, err.Message))
 | 
							common.LogError(c.Request.Context(), fmt.Sprintf("relay error (channel #%d): %s", channelId, err.Message))
 | 
				
			||||||
		// https://platform.openai.com/docs/guides/error-codes/api-errors
 | 
							// https://platform.openai.com/docs/guides/error-codes/api-errors
 | 
				
			||||||
		if shouldDisableChannel(&err.OpenAIError, err.StatusCode) {
 | 
							if util.ShouldDisableChannel(&err.Error, err.StatusCode) {
 | 
				
			||||||
			channelId := c.GetInt("channel_id")
 | 
								channelId := c.GetInt("channel_id")
 | 
				
			||||||
			channelName := c.GetString("channel_name")
 | 
								channelName := c.GetString("channel_name")
 | 
				
			||||||
			disableChannel(channelId, channelName, err.Message)
 | 
								disableChannel(channelId, channelName, err.Message)
 | 
				
			||||||
@@ -378,7 +82,7 @@ func Relay(c *gin.Context) {
 | 
				
			|||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func RelayNotImplemented(c *gin.Context) {
 | 
					func RelayNotImplemented(c *gin.Context) {
 | 
				
			||||||
	err := OpenAIError{
 | 
						err := openai.Error{
 | 
				
			||||||
		Message: "API not implemented",
 | 
							Message: "API not implemented",
 | 
				
			||||||
		Type:    "one_api_error",
 | 
							Type:    "one_api_error",
 | 
				
			||||||
		Param:   "",
 | 
							Param:   "",
 | 
				
			||||||
@@ -390,7 +94,7 @@ func RelayNotImplemented(c *gin.Context) {
 | 
				
			|||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func RelayNotFound(c *gin.Context) {
 | 
					func RelayNotFound(c *gin.Context) {
 | 
				
			||||||
	err := OpenAIError{
 | 
						err := openai.Error{
 | 
				
			||||||
		Message: fmt.Sprintf("Invalid URL (%s %s)", c.Request.Method, c.Request.URL.Path),
 | 
							Message: fmt.Sprintf("Invalid URL (%s %s)", c.Request.Method, c.Request.URL.Path),
 | 
				
			||||||
		Type:    "invalid_request_error",
 | 
							Type:    "invalid_request_error",
 | 
				
			||||||
		Param:   "",
 | 
							Param:   "",
 | 
				
			||||||
 
 | 
				
			|||||||
							
								
								
									
										6
									
								
								main.go
									
									
									
									
									
								
							
							
						
						
									
										6
									
								
								main.go
									
									
									
									
									
								
							@@ -10,6 +10,7 @@ import (
 | 
				
			|||||||
	"one-api/controller"
 | 
						"one-api/controller"
 | 
				
			||||||
	"one-api/middleware"
 | 
						"one-api/middleware"
 | 
				
			||||||
	"one-api/model"
 | 
						"one-api/model"
 | 
				
			||||||
 | 
						"one-api/relay/channel/openai"
 | 
				
			||||||
	"one-api/router"
 | 
						"one-api/router"
 | 
				
			||||||
	"os"
 | 
						"os"
 | 
				
			||||||
	"strconv"
 | 
						"strconv"
 | 
				
			||||||
@@ -20,7 +21,7 @@ var buildFS embed.FS
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
func main() {
 | 
					func main() {
 | 
				
			||||||
	common.SetupLogger()
 | 
						common.SetupLogger()
 | 
				
			||||||
	common.SysLog(fmt.Sprintf("One API %s started with theme %s", common.Version, common.Theme))
 | 
						common.SysLog(fmt.Sprintf("One API %s started", common.Version))
 | 
				
			||||||
	if os.Getenv("GIN_MODE") != "debug" {
 | 
						if os.Getenv("GIN_MODE") != "debug" {
 | 
				
			||||||
		gin.SetMode(gin.ReleaseMode)
 | 
							gin.SetMode(gin.ReleaseMode)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
@@ -47,6 +48,7 @@ func main() {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
	// Initialize options
 | 
						// Initialize options
 | 
				
			||||||
	model.InitOptionMap()
 | 
						model.InitOptionMap()
 | 
				
			||||||
 | 
						common.SysLog(fmt.Sprintf("using theme %s", common.Theme))
 | 
				
			||||||
	if common.RedisEnabled {
 | 
						if common.RedisEnabled {
 | 
				
			||||||
		// for compatibility with old versions
 | 
							// for compatibility with old versions
 | 
				
			||||||
		common.MemoryCacheEnabled = true
 | 
							common.MemoryCacheEnabled = true
 | 
				
			||||||
@@ -79,7 +81,7 @@ func main() {
 | 
				
			|||||||
		common.SysLog("batch update enabled with interval " + strconv.Itoa(common.BatchUpdateInterval) + "s")
 | 
							common.SysLog("batch update enabled with interval " + strconv.Itoa(common.BatchUpdateInterval) + "s")
 | 
				
			||||||
		model.InitBatchUpdater()
 | 
							model.InitBatchUpdater()
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	controller.InitTokenEncoders()
 | 
						openai.InitTokenEncoders()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// Initialize HTTP server
 | 
						// Initialize HTTP server
 | 
				
			||||||
	server := gin.New()
 | 
						server := gin.New()
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -9,9 +9,9 @@ import (
 | 
				
			|||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
type Log struct {
 | 
					type Log struct {
 | 
				
			||||||
	Id               int    `json:"id;index:idx_created_at_id,priority:1"`
 | 
						Id               int    `json:"id"`
 | 
				
			||||||
	UserId           int    `json:"user_id" gorm:"index"`
 | 
						UserId           int    `json:"user_id" gorm:"index"`
 | 
				
			||||||
	CreatedAt        int64  `json:"created_at" gorm:"bigint;index:idx_created_at_id,priority:2;index:idx_created_at_type"`
 | 
						CreatedAt        int64  `json:"created_at" gorm:"bigint;index:idx_created_at_type"`
 | 
				
			||||||
	Type             int    `json:"type" gorm:"index:idx_created_at_type"`
 | 
						Type             int    `json:"type" gorm:"index:idx_created_at_type"`
 | 
				
			||||||
	Content          string `json:"content"`
 | 
						Content          string `json:"content"`
 | 
				
			||||||
	Username         string `json:"username" gorm:"index:index_username_model_name,priority:2;default:''"`
 | 
						Username         string `json:"username" gorm:"index:index_username_model_name,priority:2;default:''"`
 | 
				
			||||||
@@ -218,7 +218,5 @@ func SearchLogsByDayAndModel(userId, start, end int) (LogStatistics []*LogStatis
 | 
				
			|||||||
		ORDER BY day, model_name
 | 
							ORDER BY day, model_name
 | 
				
			||||||
	`, userId, start, end).Scan(&LogStatistics).Error
 | 
						`, userId, start, end).Scan(&LogStatistics).Error
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	fmt.Println(userId, start, end)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	return LogStatistics, err
 | 
						return LogStatistics, err
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -16,7 +16,7 @@ var DB *gorm.DB
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
func createRootAccountIfNeed() error {
 | 
					func createRootAccountIfNeed() error {
 | 
				
			||||||
	var user User
 | 
						var user User
 | 
				
			||||||
	//if user.Status != common.UserStatusEnabled {
 | 
						//if user.Status != util.UserStatusEnabled {
 | 
				
			||||||
	if err := DB.First(&user).Error; err != nil {
 | 
						if err := DB.First(&user).Error; err != nil {
 | 
				
			||||||
		common.SysLog("no user exists, create a root user for you: username is root, password is 123456")
 | 
							common.SysLog("no user exists, create a root user for you: username is root, password is 123456")
 | 
				
			||||||
		hashedPassword, err := common.Password2Hash("123456")
 | 
							hashedPassword, err := common.Password2Hash("123456")
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -72,6 +72,7 @@ func InitOptionMap() {
 | 
				
			|||||||
	common.OptionMap["ChatLink"] = common.ChatLink
 | 
						common.OptionMap["ChatLink"] = common.ChatLink
 | 
				
			||||||
	common.OptionMap["QuotaPerUnit"] = strconv.FormatFloat(common.QuotaPerUnit, 'f', -1, 64)
 | 
						common.OptionMap["QuotaPerUnit"] = strconv.FormatFloat(common.QuotaPerUnit, 'f', -1, 64)
 | 
				
			||||||
	common.OptionMap["RetryTimes"] = strconv.Itoa(common.RetryTimes)
 | 
						common.OptionMap["RetryTimes"] = strconv.Itoa(common.RetryTimes)
 | 
				
			||||||
 | 
						common.OptionMap["Theme"] = common.Theme
 | 
				
			||||||
	common.OptionMapRWMutex.Unlock()
 | 
						common.OptionMapRWMutex.Unlock()
 | 
				
			||||||
	loadOptionsFromDatabase()
 | 
						loadOptionsFromDatabase()
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
@@ -220,6 +221,8 @@ func updateOptionMap(key string, value string) (err error) {
 | 
				
			|||||||
		common.ChannelDisableThreshold, _ = strconv.ParseFloat(value, 64)
 | 
							common.ChannelDisableThreshold, _ = strconv.ParseFloat(value, 64)
 | 
				
			||||||
	case "QuotaPerUnit":
 | 
						case "QuotaPerUnit":
 | 
				
			||||||
		common.QuotaPerUnit, _ = strconv.ParseFloat(value, 64)
 | 
							common.QuotaPerUnit, _ = strconv.ParseFloat(value, 64)
 | 
				
			||||||
 | 
						case "Theme":
 | 
				
			||||||
 | 
							common.Theme = value
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	return err
 | 
						return err
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -15,7 +15,7 @@ type User struct {
 | 
				
			|||||||
	Username         string `json:"username" gorm:"unique;index" validate:"max=12"`
 | 
						Username         string `json:"username" gorm:"unique;index" validate:"max=12"`
 | 
				
			||||||
	Password         string `json:"password" gorm:"not null;" validate:"min=8,max=20"`
 | 
						Password         string `json:"password" gorm:"not null;" validate:"min=8,max=20"`
 | 
				
			||||||
	DisplayName      string `json:"display_name" gorm:"index" validate:"max=20"`
 | 
						DisplayName      string `json:"display_name" gorm:"index" validate:"max=20"`
 | 
				
			||||||
	Role             int    `json:"role" gorm:"type:int;default:1"`   // admin, common
 | 
						Role             int    `json:"role" gorm:"type:int;default:1"`   // admin, util
 | 
				
			||||||
	Status           int    `json:"status" gorm:"type:int;default:1"` // enabled, disabled
 | 
						Status           int    `json:"status" gorm:"type:int;default:1"` // enabled, disabled
 | 
				
			||||||
	Email            string `json:"email" gorm:"index" validate:"max=50"`
 | 
						Email            string `json:"email" gorm:"index" validate:"max=50"`
 | 
				
			||||||
	GitHubId         string `json:"github_id" gorm:"column:github_id;index"`
 | 
						GitHubId         string `json:"github_id" gorm:"column:github_id;index"`
 | 
				
			||||||
@@ -141,7 +141,15 @@ func (user *User) ValidateAndFill() (err error) {
 | 
				
			|||||||
	if user.Username == "" || password == "" {
 | 
						if user.Username == "" || password == "" {
 | 
				
			||||||
		return errors.New("用户名或密码为空")
 | 
							return errors.New("用户名或密码为空")
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	DB.Where(User{Username: user.Username}).First(user)
 | 
						err = DB.Where("username = ?", user.Username).First(user).Error
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							// we must make sure check username firstly
 | 
				
			||||||
 | 
							// consider this case: a malicious user set his username as other's email
 | 
				
			||||||
 | 
							err := DB.Where("email = ?", user.Username).First(user).Error
 | 
				
			||||||
 | 
							if err != nil {
 | 
				
			||||||
 | 
								return errors.New("用户名或密码错误,或用户已被封禁")
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
	okay := common.ValidatePasswordAndHash(password, user.Password)
 | 
						okay := common.ValidatePasswordAndHash(password, user.Password)
 | 
				
			||||||
	if !okay || user.Status != common.UserStatusEnabled {
 | 
						if !okay || user.Status != common.UserStatusEnabled {
 | 
				
			||||||
		return errors.New("用户名或密码错误,或用户已被封禁")
 | 
							return errors.New("用户名或密码错误,或用户已被封禁")
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -1,4 +1,4 @@
 | 
				
			|||||||
package controller
 | 
					package aiproxy
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import (
 | 
					import (
 | 
				
			||||||
	"bufio"
 | 
						"bufio"
 | 
				
			||||||
@@ -8,56 +8,27 @@ import (
 | 
				
			|||||||
	"io"
 | 
						"io"
 | 
				
			||||||
	"net/http"
 | 
						"net/http"
 | 
				
			||||||
	"one-api/common"
 | 
						"one-api/common"
 | 
				
			||||||
 | 
						"one-api/relay/channel/openai"
 | 
				
			||||||
 | 
						"one-api/relay/constant"
 | 
				
			||||||
	"strconv"
 | 
						"strconv"
 | 
				
			||||||
	"strings"
 | 
						"strings"
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// https://docs.aiproxy.io/dev/library#使用已经定制好的知识库进行对话问答
 | 
					// https://docs.aiproxy.io/dev/library#使用已经定制好的知识库进行对话问答
 | 
				
			||||||
 | 
					
 | 
				
			||||||
type AIProxyLibraryRequest struct {
 | 
					func ConvertRequest(request openai.GeneralOpenAIRequest) *LibraryRequest {
 | 
				
			||||||
	Model     string `json:"model"`
 | 
					 | 
				
			||||||
	Query     string `json:"query"`
 | 
					 | 
				
			||||||
	LibraryId string `json:"libraryId"`
 | 
					 | 
				
			||||||
	Stream    bool   `json:"stream"`
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
type AIProxyLibraryError struct {
 | 
					 | 
				
			||||||
	ErrCode int    `json:"errCode"`
 | 
					 | 
				
			||||||
	Message string `json:"message"`
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
type AIProxyLibraryDocument struct {
 | 
					 | 
				
			||||||
	Title string `json:"title"`
 | 
					 | 
				
			||||||
	URL   string `json:"url"`
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
type AIProxyLibraryResponse struct {
 | 
					 | 
				
			||||||
	Success   bool                     `json:"success"`
 | 
					 | 
				
			||||||
	Answer    string                   `json:"answer"`
 | 
					 | 
				
			||||||
	Documents []AIProxyLibraryDocument `json:"documents"`
 | 
					 | 
				
			||||||
	AIProxyLibraryError
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
type AIProxyLibraryStreamResponse struct {
 | 
					 | 
				
			||||||
	Content   string                   `json:"content"`
 | 
					 | 
				
			||||||
	Finish    bool                     `json:"finish"`
 | 
					 | 
				
			||||||
	Model     string                   `json:"model"`
 | 
					 | 
				
			||||||
	Documents []AIProxyLibraryDocument `json:"documents"`
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func requestOpenAI2AIProxyLibrary(request GeneralOpenAIRequest) *AIProxyLibraryRequest {
 | 
					 | 
				
			||||||
	query := ""
 | 
						query := ""
 | 
				
			||||||
	if len(request.Messages) != 0 {
 | 
						if len(request.Messages) != 0 {
 | 
				
			||||||
		query = request.Messages[len(request.Messages)-1].StringContent()
 | 
							query = request.Messages[len(request.Messages)-1].StringContent()
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	return &AIProxyLibraryRequest{
 | 
						return &LibraryRequest{
 | 
				
			||||||
		Model:  request.Model,
 | 
							Model:  request.Model,
 | 
				
			||||||
		Stream: request.Stream,
 | 
							Stream: request.Stream,
 | 
				
			||||||
		Query:  query,
 | 
							Query:  query,
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func aiProxyDocuments2Markdown(documents []AIProxyLibraryDocument) string {
 | 
					func aiProxyDocuments2Markdown(documents []LibraryDocument) string {
 | 
				
			||||||
	if len(documents) == 0 {
 | 
						if len(documents) == 0 {
 | 
				
			||||||
		return ""
 | 
							return ""
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
@@ -68,52 +39,52 @@ func aiProxyDocuments2Markdown(documents []AIProxyLibraryDocument) string {
 | 
				
			|||||||
	return content
 | 
						return content
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func responseAIProxyLibrary2OpenAI(response *AIProxyLibraryResponse) *OpenAITextResponse {
 | 
					func responseAIProxyLibrary2OpenAI(response *LibraryResponse) *openai.TextResponse {
 | 
				
			||||||
	content := response.Answer + aiProxyDocuments2Markdown(response.Documents)
 | 
						content := response.Answer + aiProxyDocuments2Markdown(response.Documents)
 | 
				
			||||||
	choice := OpenAITextResponseChoice{
 | 
						choice := openai.TextResponseChoice{
 | 
				
			||||||
		Index: 0,
 | 
							Index: 0,
 | 
				
			||||||
		Message: Message{
 | 
							Message: openai.Message{
 | 
				
			||||||
			Role:    "assistant",
 | 
								Role:    "assistant",
 | 
				
			||||||
			Content: content,
 | 
								Content: content,
 | 
				
			||||||
		},
 | 
							},
 | 
				
			||||||
		FinishReason: "stop",
 | 
							FinishReason: "stop",
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	fullTextResponse := OpenAITextResponse{
 | 
						fullTextResponse := openai.TextResponse{
 | 
				
			||||||
		Id:      common.GetUUID(),
 | 
							Id:      common.GetUUID(),
 | 
				
			||||||
		Object:  "chat.completion",
 | 
							Object:  "chat.completion",
 | 
				
			||||||
		Created: common.GetTimestamp(),
 | 
							Created: common.GetTimestamp(),
 | 
				
			||||||
		Choices: []OpenAITextResponseChoice{choice},
 | 
							Choices: []openai.TextResponseChoice{choice},
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	return &fullTextResponse
 | 
						return &fullTextResponse
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func documentsAIProxyLibrary(documents []AIProxyLibraryDocument) *ChatCompletionsStreamResponse {
 | 
					func documentsAIProxyLibrary(documents []LibraryDocument) *openai.ChatCompletionsStreamResponse {
 | 
				
			||||||
	var choice ChatCompletionsStreamResponseChoice
 | 
						var choice openai.ChatCompletionsStreamResponseChoice
 | 
				
			||||||
	choice.Delta.Content = aiProxyDocuments2Markdown(documents)
 | 
						choice.Delta.Content = aiProxyDocuments2Markdown(documents)
 | 
				
			||||||
	choice.FinishReason = &stopFinishReason
 | 
						choice.FinishReason = &constant.StopFinishReason
 | 
				
			||||||
	return &ChatCompletionsStreamResponse{
 | 
						return &openai.ChatCompletionsStreamResponse{
 | 
				
			||||||
		Id:      common.GetUUID(),
 | 
							Id:      common.GetUUID(),
 | 
				
			||||||
		Object:  "chat.completion.chunk",
 | 
							Object:  "chat.completion.chunk",
 | 
				
			||||||
		Created: common.GetTimestamp(),
 | 
							Created: common.GetTimestamp(),
 | 
				
			||||||
		Model:   "",
 | 
							Model:   "",
 | 
				
			||||||
		Choices: []ChatCompletionsStreamResponseChoice{choice},
 | 
							Choices: []openai.ChatCompletionsStreamResponseChoice{choice},
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func streamResponseAIProxyLibrary2OpenAI(response *AIProxyLibraryStreamResponse) *ChatCompletionsStreamResponse {
 | 
					func streamResponseAIProxyLibrary2OpenAI(response *LibraryStreamResponse) *openai.ChatCompletionsStreamResponse {
 | 
				
			||||||
	var choice ChatCompletionsStreamResponseChoice
 | 
						var choice openai.ChatCompletionsStreamResponseChoice
 | 
				
			||||||
	choice.Delta.Content = response.Content
 | 
						choice.Delta.Content = response.Content
 | 
				
			||||||
	return &ChatCompletionsStreamResponse{
 | 
						return &openai.ChatCompletionsStreamResponse{
 | 
				
			||||||
		Id:      common.GetUUID(),
 | 
							Id:      common.GetUUID(),
 | 
				
			||||||
		Object:  "chat.completion.chunk",
 | 
							Object:  "chat.completion.chunk",
 | 
				
			||||||
		Created: common.GetTimestamp(),
 | 
							Created: common.GetTimestamp(),
 | 
				
			||||||
		Model:   response.Model,
 | 
							Model:   response.Model,
 | 
				
			||||||
		Choices: []ChatCompletionsStreamResponseChoice{choice},
 | 
							Choices: []openai.ChatCompletionsStreamResponseChoice{choice},
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func aiProxyLibraryStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) {
 | 
					func StreamHandler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, *openai.Usage) {
 | 
				
			||||||
	var usage Usage
 | 
						var usage openai.Usage
 | 
				
			||||||
	scanner := bufio.NewScanner(resp.Body)
 | 
						scanner := bufio.NewScanner(resp.Body)
 | 
				
			||||||
	scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
 | 
						scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
 | 
				
			||||||
		if atEOF && len(data) == 0 {
 | 
							if atEOF && len(data) == 0 {
 | 
				
			||||||
@@ -143,12 +114,12 @@ func aiProxyLibraryStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIEr
 | 
				
			|||||||
		}
 | 
							}
 | 
				
			||||||
		stopChan <- true
 | 
							stopChan <- true
 | 
				
			||||||
	}()
 | 
						}()
 | 
				
			||||||
	setEventStreamHeaders(c)
 | 
						common.SetEventStreamHeaders(c)
 | 
				
			||||||
	var documents []AIProxyLibraryDocument
 | 
						var documents []LibraryDocument
 | 
				
			||||||
	c.Stream(func(w io.Writer) bool {
 | 
						c.Stream(func(w io.Writer) bool {
 | 
				
			||||||
		select {
 | 
							select {
 | 
				
			||||||
		case data := <-dataChan:
 | 
							case data := <-dataChan:
 | 
				
			||||||
			var AIProxyLibraryResponse AIProxyLibraryStreamResponse
 | 
								var AIProxyLibraryResponse LibraryStreamResponse
 | 
				
			||||||
			err := json.Unmarshal([]byte(data), &AIProxyLibraryResponse)
 | 
								err := json.Unmarshal([]byte(data), &AIProxyLibraryResponse)
 | 
				
			||||||
			if err != nil {
 | 
								if err != nil {
 | 
				
			||||||
				common.SysError("error unmarshalling stream response: " + err.Error())
 | 
									common.SysError("error unmarshalling stream response: " + err.Error())
 | 
				
			||||||
@@ -179,28 +150,28 @@ func aiProxyLibraryStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIEr
 | 
				
			|||||||
	})
 | 
						})
 | 
				
			||||||
	err := resp.Body.Close()
 | 
						err := resp.Body.Close()
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
 | 
							return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	return nil, &usage
 | 
						return nil, &usage
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func aiProxyLibraryHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) {
 | 
					func Handler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, *openai.Usage) {
 | 
				
			||||||
	var AIProxyLibraryResponse AIProxyLibraryResponse
 | 
						var AIProxyLibraryResponse LibraryResponse
 | 
				
			||||||
	responseBody, err := io.ReadAll(resp.Body)
 | 
						responseBody, err := io.ReadAll(resp.Body)
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
 | 
							return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	err = resp.Body.Close()
 | 
						err = resp.Body.Close()
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
 | 
							return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	err = json.Unmarshal(responseBody, &AIProxyLibraryResponse)
 | 
						err = json.Unmarshal(responseBody, &AIProxyLibraryResponse)
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
 | 
							return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	if AIProxyLibraryResponse.ErrCode != 0 {
 | 
						if AIProxyLibraryResponse.ErrCode != 0 {
 | 
				
			||||||
		return &OpenAIErrorWithStatusCode{
 | 
							return &openai.ErrorWithStatusCode{
 | 
				
			||||||
			OpenAIError: OpenAIError{
 | 
								Error: openai.Error{
 | 
				
			||||||
				Message: AIProxyLibraryResponse.Message,
 | 
									Message: AIProxyLibraryResponse.Message,
 | 
				
			||||||
				Type:    strconv.Itoa(AIProxyLibraryResponse.ErrCode),
 | 
									Type:    strconv.Itoa(AIProxyLibraryResponse.ErrCode),
 | 
				
			||||||
				Code:    AIProxyLibraryResponse.ErrCode,
 | 
									Code:    AIProxyLibraryResponse.ErrCode,
 | 
				
			||||||
@@ -211,7 +182,7 @@ func aiProxyLibraryHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWit
 | 
				
			|||||||
	fullTextResponse := responseAIProxyLibrary2OpenAI(&AIProxyLibraryResponse)
 | 
						fullTextResponse := responseAIProxyLibrary2OpenAI(&AIProxyLibraryResponse)
 | 
				
			||||||
	jsonResponse, err := json.Marshal(fullTextResponse)
 | 
						jsonResponse, err := json.Marshal(fullTextResponse)
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
 | 
							return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	c.Writer.Header().Set("Content-Type", "application/json")
 | 
						c.Writer.Header().Set("Content-Type", "application/json")
 | 
				
			||||||
	c.Writer.WriteHeader(resp.StatusCode)
 | 
						c.Writer.WriteHeader(resp.StatusCode)
 | 
				
			||||||
							
								
								
									
										32
									
								
								relay/channel/aiproxy/model.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										32
									
								
								relay/channel/aiproxy/model.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,32 @@
 | 
				
			|||||||
 | 
					package aiproxy
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type LibraryRequest struct {
 | 
				
			||||||
 | 
						Model     string `json:"model"`
 | 
				
			||||||
 | 
						Query     string `json:"query"`
 | 
				
			||||||
 | 
						LibraryId string `json:"libraryId"`
 | 
				
			||||||
 | 
						Stream    bool   `json:"stream"`
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type LibraryError struct {
 | 
				
			||||||
 | 
						ErrCode int    `json:"errCode"`
 | 
				
			||||||
 | 
						Message string `json:"message"`
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type LibraryDocument struct {
 | 
				
			||||||
 | 
						Title string `json:"title"`
 | 
				
			||||||
 | 
						URL   string `json:"url"`
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type LibraryResponse struct {
 | 
				
			||||||
 | 
						Success   bool              `json:"success"`
 | 
				
			||||||
 | 
						Answer    string            `json:"answer"`
 | 
				
			||||||
 | 
						Documents []LibraryDocument `json:"documents"`
 | 
				
			||||||
 | 
						LibraryError
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type LibraryStreamResponse struct {
 | 
				
			||||||
 | 
						Content   string            `json:"content"`
 | 
				
			||||||
 | 
						Finish    bool              `json:"finish"`
 | 
				
			||||||
 | 
						Model     string            `json:"model"`
 | 
				
			||||||
 | 
						Documents []LibraryDocument `json:"documents"`
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
@@ -1,4 +1,4 @@
 | 
				
			|||||||
package controller
 | 
					package ali
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import (
 | 
					import (
 | 
				
			||||||
	"bufio"
 | 
						"bufio"
 | 
				
			||||||
@@ -7,112 +7,43 @@ import (
 | 
				
			|||||||
	"io"
 | 
						"io"
 | 
				
			||||||
	"net/http"
 | 
						"net/http"
 | 
				
			||||||
	"one-api/common"
 | 
						"one-api/common"
 | 
				
			||||||
 | 
						"one-api/relay/channel/openai"
 | 
				
			||||||
	"strings"
 | 
						"strings"
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// https://help.aliyun.com/document_detail/613695.html?spm=a2c4g.2399480.0.0.1adb778fAdzP9w#341800c0f8w0r
 | 
					// https://help.aliyun.com/document_detail/613695.html?spm=a2c4g.2399480.0.0.1adb778fAdzP9w#341800c0f8w0r
 | 
				
			||||||
 | 
					
 | 
				
			||||||
type AliMessage struct {
 | 
					const EnableSearchModelSuffix = "-internet"
 | 
				
			||||||
	Content string `json:"content"`
 | 
					 | 
				
			||||||
	Role    string `json:"role"`
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
type AliInput struct {
 | 
					func ConvertRequest(request openai.GeneralOpenAIRequest) *ChatRequest {
 | 
				
			||||||
	//Prompt   string       `json:"prompt"`
 | 
						messages := make([]Message, 0, len(request.Messages))
 | 
				
			||||||
	Messages []AliMessage `json:"messages"`
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
type AliParameters struct {
 | 
					 | 
				
			||||||
	TopP              float64 `json:"top_p,omitempty"`
 | 
					 | 
				
			||||||
	TopK              int     `json:"top_k,omitempty"`
 | 
					 | 
				
			||||||
	Seed              uint64  `json:"seed,omitempty"`
 | 
					 | 
				
			||||||
	EnableSearch      bool    `json:"enable_search,omitempty"`
 | 
					 | 
				
			||||||
	IncrementalOutput bool    `json:"incremental_output,omitempty"`
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
type AliChatRequest struct {
 | 
					 | 
				
			||||||
	Model      string        `json:"model"`
 | 
					 | 
				
			||||||
	Input      AliInput      `json:"input"`
 | 
					 | 
				
			||||||
	Parameters AliParameters `json:"parameters,omitempty"`
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
type AliEmbeddingRequest struct {
 | 
					 | 
				
			||||||
	Model string `json:"model"`
 | 
					 | 
				
			||||||
	Input struct {
 | 
					 | 
				
			||||||
		Texts []string `json:"texts"`
 | 
					 | 
				
			||||||
	} `json:"input"`
 | 
					 | 
				
			||||||
	Parameters *struct {
 | 
					 | 
				
			||||||
		TextType string `json:"text_type,omitempty"`
 | 
					 | 
				
			||||||
	} `json:"parameters,omitempty"`
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
type AliEmbedding struct {
 | 
					 | 
				
			||||||
	Embedding []float64 `json:"embedding"`
 | 
					 | 
				
			||||||
	TextIndex int       `json:"text_index"`
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
type AliEmbeddingResponse struct {
 | 
					 | 
				
			||||||
	Output struct {
 | 
					 | 
				
			||||||
		Embeddings []AliEmbedding `json:"embeddings"`
 | 
					 | 
				
			||||||
	} `json:"output"`
 | 
					 | 
				
			||||||
	Usage AliUsage `json:"usage"`
 | 
					 | 
				
			||||||
	AliError
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
type AliError struct {
 | 
					 | 
				
			||||||
	Code      string `json:"code"`
 | 
					 | 
				
			||||||
	Message   string `json:"message"`
 | 
					 | 
				
			||||||
	RequestId string `json:"request_id"`
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
type AliUsage struct {
 | 
					 | 
				
			||||||
	InputTokens  int `json:"input_tokens"`
 | 
					 | 
				
			||||||
	OutputTokens int `json:"output_tokens"`
 | 
					 | 
				
			||||||
	TotalTokens  int `json:"total_tokens"`
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
type AliOutput struct {
 | 
					 | 
				
			||||||
	Text         string `json:"text"`
 | 
					 | 
				
			||||||
	FinishReason string `json:"finish_reason"`
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
type AliChatResponse struct {
 | 
					 | 
				
			||||||
	Output AliOutput `json:"output"`
 | 
					 | 
				
			||||||
	Usage  AliUsage  `json:"usage"`
 | 
					 | 
				
			||||||
	AliError
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
const AliEnableSearchModelSuffix = "-internet"
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func requestOpenAI2Ali(request GeneralOpenAIRequest) *AliChatRequest {
 | 
					 | 
				
			||||||
	messages := make([]AliMessage, 0, len(request.Messages))
 | 
					 | 
				
			||||||
	for i := 0; i < len(request.Messages); i++ {
 | 
						for i := 0; i < len(request.Messages); i++ {
 | 
				
			||||||
		message := request.Messages[i]
 | 
							message := request.Messages[i]
 | 
				
			||||||
		messages = append(messages, AliMessage{
 | 
							messages = append(messages, Message{
 | 
				
			||||||
			Content: message.StringContent(),
 | 
								Content: message.StringContent(),
 | 
				
			||||||
			Role:    strings.ToLower(message.Role),
 | 
								Role:    strings.ToLower(message.Role),
 | 
				
			||||||
		})
 | 
							})
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	enableSearch := false
 | 
						enableSearch := false
 | 
				
			||||||
	aliModel := request.Model
 | 
						aliModel := request.Model
 | 
				
			||||||
	if strings.HasSuffix(aliModel, AliEnableSearchModelSuffix) {
 | 
						if strings.HasSuffix(aliModel, EnableSearchModelSuffix) {
 | 
				
			||||||
		enableSearch = true
 | 
							enableSearch = true
 | 
				
			||||||
		aliModel = strings.TrimSuffix(aliModel, AliEnableSearchModelSuffix)
 | 
							aliModel = strings.TrimSuffix(aliModel, EnableSearchModelSuffix)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	return &AliChatRequest{
 | 
						return &ChatRequest{
 | 
				
			||||||
		Model: aliModel,
 | 
							Model: aliModel,
 | 
				
			||||||
		Input: AliInput{
 | 
							Input: Input{
 | 
				
			||||||
			Messages: messages,
 | 
								Messages: messages,
 | 
				
			||||||
		},
 | 
							},
 | 
				
			||||||
		Parameters: AliParameters{
 | 
							Parameters: Parameters{
 | 
				
			||||||
			EnableSearch:      enableSearch,
 | 
								EnableSearch:      enableSearch,
 | 
				
			||||||
			IncrementalOutput: request.Stream,
 | 
								IncrementalOutput: request.Stream,
 | 
				
			||||||
		},
 | 
							},
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func embeddingRequestOpenAI2Ali(request GeneralOpenAIRequest) *AliEmbeddingRequest {
 | 
					func ConvertEmbeddingRequest(request openai.GeneralOpenAIRequest) *EmbeddingRequest {
 | 
				
			||||||
	return &AliEmbeddingRequest{
 | 
						return &EmbeddingRequest{
 | 
				
			||||||
		Model: "text-embedding-v1",
 | 
							Model: "text-embedding-v1",
 | 
				
			||||||
		Input: struct {
 | 
							Input: struct {
 | 
				
			||||||
			Texts []string `json:"texts"`
 | 
								Texts []string `json:"texts"`
 | 
				
			||||||
@@ -122,21 +53,21 @@ func embeddingRequestOpenAI2Ali(request GeneralOpenAIRequest) *AliEmbeddingReque
 | 
				
			|||||||
	}
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func aliEmbeddingHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) {
 | 
					func EmbeddingHandler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, *openai.Usage) {
 | 
				
			||||||
	var aliResponse AliEmbeddingResponse
 | 
						var aliResponse EmbeddingResponse
 | 
				
			||||||
	err := json.NewDecoder(resp.Body).Decode(&aliResponse)
 | 
						err := json.NewDecoder(resp.Body).Decode(&aliResponse)
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
 | 
							return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	err = resp.Body.Close()
 | 
						err = resp.Body.Close()
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
 | 
							return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	if aliResponse.Code != "" {
 | 
						if aliResponse.Code != "" {
 | 
				
			||||||
		return &OpenAIErrorWithStatusCode{
 | 
							return &openai.ErrorWithStatusCode{
 | 
				
			||||||
			OpenAIError: OpenAIError{
 | 
								Error: openai.Error{
 | 
				
			||||||
				Message: aliResponse.Message,
 | 
									Message: aliResponse.Message,
 | 
				
			||||||
				Type:    aliResponse.Code,
 | 
									Type:    aliResponse.Code,
 | 
				
			||||||
				Param:   aliResponse.RequestId,
 | 
									Param:   aliResponse.RequestId,
 | 
				
			||||||
@@ -149,7 +80,7 @@ func aliEmbeddingHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithS
 | 
				
			|||||||
	fullTextResponse := embeddingResponseAli2OpenAI(&aliResponse)
 | 
						fullTextResponse := embeddingResponseAli2OpenAI(&aliResponse)
 | 
				
			||||||
	jsonResponse, err := json.Marshal(fullTextResponse)
 | 
						jsonResponse, err := json.Marshal(fullTextResponse)
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
 | 
							return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	c.Writer.Header().Set("Content-Type", "application/json")
 | 
						c.Writer.Header().Set("Content-Type", "application/json")
 | 
				
			||||||
	c.Writer.WriteHeader(resp.StatusCode)
 | 
						c.Writer.WriteHeader(resp.StatusCode)
 | 
				
			||||||
@@ -157,16 +88,16 @@ func aliEmbeddingHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithS
 | 
				
			|||||||
	return nil, &fullTextResponse.Usage
 | 
						return nil, &fullTextResponse.Usage
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func embeddingResponseAli2OpenAI(response *AliEmbeddingResponse) *OpenAIEmbeddingResponse {
 | 
					func embeddingResponseAli2OpenAI(response *EmbeddingResponse) *openai.EmbeddingResponse {
 | 
				
			||||||
	openAIEmbeddingResponse := OpenAIEmbeddingResponse{
 | 
						openAIEmbeddingResponse := openai.EmbeddingResponse{
 | 
				
			||||||
		Object: "list",
 | 
							Object: "list",
 | 
				
			||||||
		Data:   make([]OpenAIEmbeddingResponseItem, 0, len(response.Output.Embeddings)),
 | 
							Data:   make([]openai.EmbeddingResponseItem, 0, len(response.Output.Embeddings)),
 | 
				
			||||||
		Model:  "text-embedding-v1",
 | 
							Model:  "text-embedding-v1",
 | 
				
			||||||
		Usage:  Usage{TotalTokens: response.Usage.TotalTokens},
 | 
							Usage:  openai.Usage{TotalTokens: response.Usage.TotalTokens},
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	for _, item := range response.Output.Embeddings {
 | 
						for _, item := range response.Output.Embeddings {
 | 
				
			||||||
		openAIEmbeddingResponse.Data = append(openAIEmbeddingResponse.Data, OpenAIEmbeddingResponseItem{
 | 
							openAIEmbeddingResponse.Data = append(openAIEmbeddingResponse.Data, openai.EmbeddingResponseItem{
 | 
				
			||||||
			Object:    `embedding`,
 | 
								Object:    `embedding`,
 | 
				
			||||||
			Index:     item.TextIndex,
 | 
								Index:     item.TextIndex,
 | 
				
			||||||
			Embedding: item.Embedding,
 | 
								Embedding: item.Embedding,
 | 
				
			||||||
@@ -175,21 +106,21 @@ func embeddingResponseAli2OpenAI(response *AliEmbeddingResponse) *OpenAIEmbeddin
 | 
				
			|||||||
	return &openAIEmbeddingResponse
 | 
						return &openAIEmbeddingResponse
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func responseAli2OpenAI(response *AliChatResponse) *OpenAITextResponse {
 | 
					func responseAli2OpenAI(response *ChatResponse) *openai.TextResponse {
 | 
				
			||||||
	choice := OpenAITextResponseChoice{
 | 
						choice := openai.TextResponseChoice{
 | 
				
			||||||
		Index: 0,
 | 
							Index: 0,
 | 
				
			||||||
		Message: Message{
 | 
							Message: openai.Message{
 | 
				
			||||||
			Role:    "assistant",
 | 
								Role:    "assistant",
 | 
				
			||||||
			Content: response.Output.Text,
 | 
								Content: response.Output.Text,
 | 
				
			||||||
		},
 | 
							},
 | 
				
			||||||
		FinishReason: response.Output.FinishReason,
 | 
							FinishReason: response.Output.FinishReason,
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	fullTextResponse := OpenAITextResponse{
 | 
						fullTextResponse := openai.TextResponse{
 | 
				
			||||||
		Id:      response.RequestId,
 | 
							Id:      response.RequestId,
 | 
				
			||||||
		Object:  "chat.completion",
 | 
							Object:  "chat.completion",
 | 
				
			||||||
		Created: common.GetTimestamp(),
 | 
							Created: common.GetTimestamp(),
 | 
				
			||||||
		Choices: []OpenAITextResponseChoice{choice},
 | 
							Choices: []openai.TextResponseChoice{choice},
 | 
				
			||||||
		Usage: Usage{
 | 
							Usage: openai.Usage{
 | 
				
			||||||
			PromptTokens:     response.Usage.InputTokens,
 | 
								PromptTokens:     response.Usage.InputTokens,
 | 
				
			||||||
			CompletionTokens: response.Usage.OutputTokens,
 | 
								CompletionTokens: response.Usage.OutputTokens,
 | 
				
			||||||
			TotalTokens:      response.Usage.InputTokens + response.Usage.OutputTokens,
 | 
								TotalTokens:      response.Usage.InputTokens + response.Usage.OutputTokens,
 | 
				
			||||||
@@ -198,25 +129,25 @@ func responseAli2OpenAI(response *AliChatResponse) *OpenAITextResponse {
 | 
				
			|||||||
	return &fullTextResponse
 | 
						return &fullTextResponse
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func streamResponseAli2OpenAI(aliResponse *AliChatResponse) *ChatCompletionsStreamResponse {
 | 
					func streamResponseAli2OpenAI(aliResponse *ChatResponse) *openai.ChatCompletionsStreamResponse {
 | 
				
			||||||
	var choice ChatCompletionsStreamResponseChoice
 | 
						var choice openai.ChatCompletionsStreamResponseChoice
 | 
				
			||||||
	choice.Delta.Content = aliResponse.Output.Text
 | 
						choice.Delta.Content = aliResponse.Output.Text
 | 
				
			||||||
	if aliResponse.Output.FinishReason != "null" {
 | 
						if aliResponse.Output.FinishReason != "null" {
 | 
				
			||||||
		finishReason := aliResponse.Output.FinishReason
 | 
							finishReason := aliResponse.Output.FinishReason
 | 
				
			||||||
		choice.FinishReason = &finishReason
 | 
							choice.FinishReason = &finishReason
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	response := ChatCompletionsStreamResponse{
 | 
						response := openai.ChatCompletionsStreamResponse{
 | 
				
			||||||
		Id:      aliResponse.RequestId,
 | 
							Id:      aliResponse.RequestId,
 | 
				
			||||||
		Object:  "chat.completion.chunk",
 | 
							Object:  "chat.completion.chunk",
 | 
				
			||||||
		Created: common.GetTimestamp(),
 | 
							Created: common.GetTimestamp(),
 | 
				
			||||||
		Model:   "qwen",
 | 
							Model:   "qwen",
 | 
				
			||||||
		Choices: []ChatCompletionsStreamResponseChoice{choice},
 | 
							Choices: []openai.ChatCompletionsStreamResponseChoice{choice},
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	return &response
 | 
						return &response
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func aliStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) {
 | 
					func StreamHandler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, *openai.Usage) {
 | 
				
			||||||
	var usage Usage
 | 
						var usage openai.Usage
 | 
				
			||||||
	scanner := bufio.NewScanner(resp.Body)
 | 
						scanner := bufio.NewScanner(resp.Body)
 | 
				
			||||||
	scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
 | 
						scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
 | 
				
			||||||
		if atEOF && len(data) == 0 {
 | 
							if atEOF && len(data) == 0 {
 | 
				
			||||||
@@ -246,12 +177,12 @@ func aliStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStat
 | 
				
			|||||||
		}
 | 
							}
 | 
				
			||||||
		stopChan <- true
 | 
							stopChan <- true
 | 
				
			||||||
	}()
 | 
						}()
 | 
				
			||||||
	setEventStreamHeaders(c)
 | 
						common.SetEventStreamHeaders(c)
 | 
				
			||||||
	//lastResponseText := ""
 | 
						//lastResponseText := ""
 | 
				
			||||||
	c.Stream(func(w io.Writer) bool {
 | 
						c.Stream(func(w io.Writer) bool {
 | 
				
			||||||
		select {
 | 
							select {
 | 
				
			||||||
		case data := <-dataChan:
 | 
							case data := <-dataChan:
 | 
				
			||||||
			var aliResponse AliChatResponse
 | 
								var aliResponse ChatResponse
 | 
				
			||||||
			err := json.Unmarshal([]byte(data), &aliResponse)
 | 
								err := json.Unmarshal([]byte(data), &aliResponse)
 | 
				
			||||||
			if err != nil {
 | 
								if err != nil {
 | 
				
			||||||
				common.SysError("error unmarshalling stream response: " + err.Error())
 | 
									common.SysError("error unmarshalling stream response: " + err.Error())
 | 
				
			||||||
@@ -279,28 +210,28 @@ func aliStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStat
 | 
				
			|||||||
	})
 | 
						})
 | 
				
			||||||
	err := resp.Body.Close()
 | 
						err := resp.Body.Close()
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
 | 
							return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	return nil, &usage
 | 
						return nil, &usage
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func aliHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) {
 | 
					func Handler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, *openai.Usage) {
 | 
				
			||||||
	var aliResponse AliChatResponse
 | 
						var aliResponse ChatResponse
 | 
				
			||||||
	responseBody, err := io.ReadAll(resp.Body)
 | 
						responseBody, err := io.ReadAll(resp.Body)
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
 | 
							return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	err = resp.Body.Close()
 | 
						err = resp.Body.Close()
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
 | 
							return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	err = json.Unmarshal(responseBody, &aliResponse)
 | 
						err = json.Unmarshal(responseBody, &aliResponse)
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
 | 
							return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	if aliResponse.Code != "" {
 | 
						if aliResponse.Code != "" {
 | 
				
			||||||
		return &OpenAIErrorWithStatusCode{
 | 
							return &openai.ErrorWithStatusCode{
 | 
				
			||||||
			OpenAIError: OpenAIError{
 | 
								Error: openai.Error{
 | 
				
			||||||
				Message: aliResponse.Message,
 | 
									Message: aliResponse.Message,
 | 
				
			||||||
				Type:    aliResponse.Code,
 | 
									Type:    aliResponse.Code,
 | 
				
			||||||
				Param:   aliResponse.RequestId,
 | 
									Param:   aliResponse.RequestId,
 | 
				
			||||||
@@ -313,7 +244,7 @@ func aliHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode
 | 
				
			|||||||
	fullTextResponse.Model = "qwen"
 | 
						fullTextResponse.Model = "qwen"
 | 
				
			||||||
	jsonResponse, err := json.Marshal(fullTextResponse)
 | 
						jsonResponse, err := json.Marshal(fullTextResponse)
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
 | 
							return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	c.Writer.Header().Set("Content-Type", "application/json")
 | 
						c.Writer.Header().Set("Content-Type", "application/json")
 | 
				
			||||||
	c.Writer.WriteHeader(resp.StatusCode)
 | 
						c.Writer.WriteHeader(resp.StatusCode)
 | 
				
			||||||
							
								
								
									
										71
									
								
								relay/channel/ali/model.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										71
									
								
								relay/channel/ali/model.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,71 @@
 | 
				
			|||||||
 | 
					package ali
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type Message struct {
 | 
				
			||||||
 | 
						Content string `json:"content"`
 | 
				
			||||||
 | 
						Role    string `json:"role"`
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type Input struct {
 | 
				
			||||||
 | 
						//Prompt   string       `json:"prompt"`
 | 
				
			||||||
 | 
						Messages []Message `json:"messages"`
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type Parameters struct {
 | 
				
			||||||
 | 
						TopP              float64 `json:"top_p,omitempty"`
 | 
				
			||||||
 | 
						TopK              int     `json:"top_k,omitempty"`
 | 
				
			||||||
 | 
						Seed              uint64  `json:"seed,omitempty"`
 | 
				
			||||||
 | 
						EnableSearch      bool    `json:"enable_search,omitempty"`
 | 
				
			||||||
 | 
						IncrementalOutput bool    `json:"incremental_output,omitempty"`
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type ChatRequest struct {
 | 
				
			||||||
 | 
						Model      string     `json:"model"`
 | 
				
			||||||
 | 
						Input      Input      `json:"input"`
 | 
				
			||||||
 | 
						Parameters Parameters `json:"parameters,omitempty"`
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type EmbeddingRequest struct {
 | 
				
			||||||
 | 
						Model string `json:"model"`
 | 
				
			||||||
 | 
						Input struct {
 | 
				
			||||||
 | 
							Texts []string `json:"texts"`
 | 
				
			||||||
 | 
						} `json:"input"`
 | 
				
			||||||
 | 
						Parameters *struct {
 | 
				
			||||||
 | 
							TextType string `json:"text_type,omitempty"`
 | 
				
			||||||
 | 
						} `json:"parameters,omitempty"`
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type Embedding struct {
 | 
				
			||||||
 | 
						Embedding []float64 `json:"embedding"`
 | 
				
			||||||
 | 
						TextIndex int       `json:"text_index"`
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type EmbeddingResponse struct {
 | 
				
			||||||
 | 
						Output struct {
 | 
				
			||||||
 | 
							Embeddings []Embedding `json:"embeddings"`
 | 
				
			||||||
 | 
						} `json:"output"`
 | 
				
			||||||
 | 
						Usage Usage `json:"usage"`
 | 
				
			||||||
 | 
						Error
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type Error struct {
 | 
				
			||||||
 | 
						Code      string `json:"code"`
 | 
				
			||||||
 | 
						Message   string `json:"message"`
 | 
				
			||||||
 | 
						RequestId string `json:"request_id"`
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type Usage struct {
 | 
				
			||||||
 | 
						InputTokens  int `json:"input_tokens"`
 | 
				
			||||||
 | 
						OutputTokens int `json:"output_tokens"`
 | 
				
			||||||
 | 
						TotalTokens  int `json:"total_tokens"`
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type Output struct {
 | 
				
			||||||
 | 
						Text         string `json:"text"`
 | 
				
			||||||
 | 
						FinishReason string `json:"finish_reason"`
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type ChatResponse struct {
 | 
				
			||||||
 | 
						Output Output `json:"output"`
 | 
				
			||||||
 | 
						Usage  Usage  `json:"usage"`
 | 
				
			||||||
 | 
						Error
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
@@ -1,4 +1,4 @@
 | 
				
			|||||||
package controller
 | 
					package anthropic
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import (
 | 
					import (
 | 
				
			||||||
	"bufio"
 | 
						"bufio"
 | 
				
			||||||
@@ -8,37 +8,10 @@ import (
 | 
				
			|||||||
	"io"
 | 
						"io"
 | 
				
			||||||
	"net/http"
 | 
						"net/http"
 | 
				
			||||||
	"one-api/common"
 | 
						"one-api/common"
 | 
				
			||||||
 | 
						"one-api/relay/channel/openai"
 | 
				
			||||||
	"strings"
 | 
						"strings"
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
type ClaudeMetadata struct {
 | 
					 | 
				
			||||||
	UserId string `json:"user_id"`
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
type ClaudeRequest struct {
 | 
					 | 
				
			||||||
	Model             string   `json:"model"`
 | 
					 | 
				
			||||||
	Prompt            string   `json:"prompt"`
 | 
					 | 
				
			||||||
	MaxTokensToSample int      `json:"max_tokens_to_sample"`
 | 
					 | 
				
			||||||
	StopSequences     []string `json:"stop_sequences,omitempty"`
 | 
					 | 
				
			||||||
	Temperature       float64  `json:"temperature,omitempty"`
 | 
					 | 
				
			||||||
	TopP              float64  `json:"top_p,omitempty"`
 | 
					 | 
				
			||||||
	TopK              int      `json:"top_k,omitempty"`
 | 
					 | 
				
			||||||
	//ClaudeMetadata    `json:"metadata,omitempty"`
 | 
					 | 
				
			||||||
	Stream bool `json:"stream,omitempty"`
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
type ClaudeError struct {
 | 
					 | 
				
			||||||
	Type    string `json:"type"`
 | 
					 | 
				
			||||||
	Message string `json:"message"`
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
type ClaudeResponse struct {
 | 
					 | 
				
			||||||
	Completion string      `json:"completion"`
 | 
					 | 
				
			||||||
	StopReason string      `json:"stop_reason"`
 | 
					 | 
				
			||||||
	Model      string      `json:"model"`
 | 
					 | 
				
			||||||
	Error      ClaudeError `json:"error"`
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func stopReasonClaude2OpenAI(reason string) string {
 | 
					func stopReasonClaude2OpenAI(reason string) string {
 | 
				
			||||||
	switch reason {
 | 
						switch reason {
 | 
				
			||||||
	case "stop_sequence":
 | 
						case "stop_sequence":
 | 
				
			||||||
@@ -50,8 +23,8 @@ func stopReasonClaude2OpenAI(reason string) string {
 | 
				
			|||||||
	}
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func requestOpenAI2Claude(textRequest GeneralOpenAIRequest) *ClaudeRequest {
 | 
					func ConvertRequest(textRequest openai.GeneralOpenAIRequest) *Request {
 | 
				
			||||||
	claudeRequest := ClaudeRequest{
 | 
						claudeRequest := Request{
 | 
				
			||||||
		Model:             textRequest.Model,
 | 
							Model:             textRequest.Model,
 | 
				
			||||||
		Prompt:            "",
 | 
							Prompt:            "",
 | 
				
			||||||
		MaxTokensToSample: textRequest.MaxTokens,
 | 
							MaxTokensToSample: textRequest.MaxTokens,
 | 
				
			||||||
@@ -80,40 +53,40 @@ func requestOpenAI2Claude(textRequest GeneralOpenAIRequest) *ClaudeRequest {
 | 
				
			|||||||
	return &claudeRequest
 | 
						return &claudeRequest
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func streamResponseClaude2OpenAI(claudeResponse *ClaudeResponse) *ChatCompletionsStreamResponse {
 | 
					func streamResponseClaude2OpenAI(claudeResponse *Response) *openai.ChatCompletionsStreamResponse {
 | 
				
			||||||
	var choice ChatCompletionsStreamResponseChoice
 | 
						var choice openai.ChatCompletionsStreamResponseChoice
 | 
				
			||||||
	choice.Delta.Content = claudeResponse.Completion
 | 
						choice.Delta.Content = claudeResponse.Completion
 | 
				
			||||||
	finishReason := stopReasonClaude2OpenAI(claudeResponse.StopReason)
 | 
						finishReason := stopReasonClaude2OpenAI(claudeResponse.StopReason)
 | 
				
			||||||
	if finishReason != "null" {
 | 
						if finishReason != "null" {
 | 
				
			||||||
		choice.FinishReason = &finishReason
 | 
							choice.FinishReason = &finishReason
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	var response ChatCompletionsStreamResponse
 | 
						var response openai.ChatCompletionsStreamResponse
 | 
				
			||||||
	response.Object = "chat.completion.chunk"
 | 
						response.Object = "chat.completion.chunk"
 | 
				
			||||||
	response.Model = claudeResponse.Model
 | 
						response.Model = claudeResponse.Model
 | 
				
			||||||
	response.Choices = []ChatCompletionsStreamResponseChoice{choice}
 | 
						response.Choices = []openai.ChatCompletionsStreamResponseChoice{choice}
 | 
				
			||||||
	return &response
 | 
						return &response
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func responseClaude2OpenAI(claudeResponse *ClaudeResponse) *OpenAITextResponse {
 | 
					func responseClaude2OpenAI(claudeResponse *Response) *openai.TextResponse {
 | 
				
			||||||
	choice := OpenAITextResponseChoice{
 | 
						choice := openai.TextResponseChoice{
 | 
				
			||||||
		Index: 0,
 | 
							Index: 0,
 | 
				
			||||||
		Message: Message{
 | 
							Message: openai.Message{
 | 
				
			||||||
			Role:    "assistant",
 | 
								Role:    "assistant",
 | 
				
			||||||
			Content: strings.TrimPrefix(claudeResponse.Completion, " "),
 | 
								Content: strings.TrimPrefix(claudeResponse.Completion, " "),
 | 
				
			||||||
			Name:    nil,
 | 
								Name:    nil,
 | 
				
			||||||
		},
 | 
							},
 | 
				
			||||||
		FinishReason: stopReasonClaude2OpenAI(claudeResponse.StopReason),
 | 
							FinishReason: stopReasonClaude2OpenAI(claudeResponse.StopReason),
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	fullTextResponse := OpenAITextResponse{
 | 
						fullTextResponse := openai.TextResponse{
 | 
				
			||||||
		Id:      fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
 | 
							Id:      fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
 | 
				
			||||||
		Object:  "chat.completion",
 | 
							Object:  "chat.completion",
 | 
				
			||||||
		Created: common.GetTimestamp(),
 | 
							Created: common.GetTimestamp(),
 | 
				
			||||||
		Choices: []OpenAITextResponseChoice{choice},
 | 
							Choices: []openai.TextResponseChoice{choice},
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	return &fullTextResponse
 | 
						return &fullTextResponse
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func claudeStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, string) {
 | 
					func StreamHandler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, string) {
 | 
				
			||||||
	responseText := ""
 | 
						responseText := ""
 | 
				
			||||||
	responseId := fmt.Sprintf("chatcmpl-%s", common.GetUUID())
 | 
						responseId := fmt.Sprintf("chatcmpl-%s", common.GetUUID())
 | 
				
			||||||
	createdTime := common.GetTimestamp()
 | 
						createdTime := common.GetTimestamp()
 | 
				
			||||||
@@ -143,13 +116,13 @@ func claudeStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithS
 | 
				
			|||||||
		}
 | 
							}
 | 
				
			||||||
		stopChan <- true
 | 
							stopChan <- true
 | 
				
			||||||
	}()
 | 
						}()
 | 
				
			||||||
	setEventStreamHeaders(c)
 | 
						common.SetEventStreamHeaders(c)
 | 
				
			||||||
	c.Stream(func(w io.Writer) bool {
 | 
						c.Stream(func(w io.Writer) bool {
 | 
				
			||||||
		select {
 | 
							select {
 | 
				
			||||||
		case data := <-dataChan:
 | 
							case data := <-dataChan:
 | 
				
			||||||
			// some implementations may add \r at the end of data
 | 
								// some implementations may add \r at the end of data
 | 
				
			||||||
			data = strings.TrimSuffix(data, "\r")
 | 
								data = strings.TrimSuffix(data, "\r")
 | 
				
			||||||
			var claudeResponse ClaudeResponse
 | 
								var claudeResponse Response
 | 
				
			||||||
			err := json.Unmarshal([]byte(data), &claudeResponse)
 | 
								err := json.Unmarshal([]byte(data), &claudeResponse)
 | 
				
			||||||
			if err != nil {
 | 
								if err != nil {
 | 
				
			||||||
				common.SysError("error unmarshalling stream response: " + err.Error())
 | 
									common.SysError("error unmarshalling stream response: " + err.Error())
 | 
				
			||||||
@@ -173,28 +146,28 @@ func claudeStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithS
 | 
				
			|||||||
	})
 | 
						})
 | 
				
			||||||
	err := resp.Body.Close()
 | 
						err := resp.Body.Close()
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), ""
 | 
							return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), ""
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	return nil, responseText
 | 
						return nil, responseText
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func claudeHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*OpenAIErrorWithStatusCode, *Usage) {
 | 
					func Handler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*openai.ErrorWithStatusCode, *openai.Usage) {
 | 
				
			||||||
	responseBody, err := io.ReadAll(resp.Body)
 | 
						responseBody, err := io.ReadAll(resp.Body)
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
 | 
							return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	err = resp.Body.Close()
 | 
						err = resp.Body.Close()
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
 | 
							return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	var claudeResponse ClaudeResponse
 | 
						var claudeResponse Response
 | 
				
			||||||
	err = json.Unmarshal(responseBody, &claudeResponse)
 | 
						err = json.Unmarshal(responseBody, &claudeResponse)
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
 | 
							return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	if claudeResponse.Error.Type != "" {
 | 
						if claudeResponse.Error.Type != "" {
 | 
				
			||||||
		return &OpenAIErrorWithStatusCode{
 | 
							return &openai.ErrorWithStatusCode{
 | 
				
			||||||
			OpenAIError: OpenAIError{
 | 
								Error: openai.Error{
 | 
				
			||||||
				Message: claudeResponse.Error.Message,
 | 
									Message: claudeResponse.Error.Message,
 | 
				
			||||||
				Type:    claudeResponse.Error.Type,
 | 
									Type:    claudeResponse.Error.Type,
 | 
				
			||||||
				Param:   "",
 | 
									Param:   "",
 | 
				
			||||||
@@ -205,8 +178,8 @@ func claudeHandler(c *gin.Context, resp *http.Response, promptTokens int, model
 | 
				
			|||||||
	}
 | 
						}
 | 
				
			||||||
	fullTextResponse := responseClaude2OpenAI(&claudeResponse)
 | 
						fullTextResponse := responseClaude2OpenAI(&claudeResponse)
 | 
				
			||||||
	fullTextResponse.Model = model
 | 
						fullTextResponse.Model = model
 | 
				
			||||||
	completionTokens := countTokenText(claudeResponse.Completion, model)
 | 
						completionTokens := openai.CountTokenText(claudeResponse.Completion, model)
 | 
				
			||||||
	usage := Usage{
 | 
						usage := openai.Usage{
 | 
				
			||||||
		PromptTokens:     promptTokens,
 | 
							PromptTokens:     promptTokens,
 | 
				
			||||||
		CompletionTokens: completionTokens,
 | 
							CompletionTokens: completionTokens,
 | 
				
			||||||
		TotalTokens:      promptTokens + completionTokens,
 | 
							TotalTokens:      promptTokens + completionTokens,
 | 
				
			||||||
@@ -214,7 +187,7 @@ func claudeHandler(c *gin.Context, resp *http.Response, promptTokens int, model
 | 
				
			|||||||
	fullTextResponse.Usage = usage
 | 
						fullTextResponse.Usage = usage
 | 
				
			||||||
	jsonResponse, err := json.Marshal(fullTextResponse)
 | 
						jsonResponse, err := json.Marshal(fullTextResponse)
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
 | 
							return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	c.Writer.Header().Set("Content-Type", "application/json")
 | 
						c.Writer.Header().Set("Content-Type", "application/json")
 | 
				
			||||||
	c.Writer.WriteHeader(resp.StatusCode)
 | 
						c.Writer.WriteHeader(resp.StatusCode)
 | 
				
			||||||
							
								
								
									
										29
									
								
								relay/channel/anthropic/model.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										29
									
								
								relay/channel/anthropic/model.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,29 @@
 | 
				
			|||||||
 | 
					package anthropic
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type Metadata struct {
 | 
				
			||||||
 | 
						UserId string `json:"user_id"`
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type Request struct {
 | 
				
			||||||
 | 
						Model             string   `json:"model"`
 | 
				
			||||||
 | 
						Prompt            string   `json:"prompt"`
 | 
				
			||||||
 | 
						MaxTokensToSample int      `json:"max_tokens_to_sample"`
 | 
				
			||||||
 | 
						StopSequences     []string `json:"stop_sequences,omitempty"`
 | 
				
			||||||
 | 
						Temperature       float64  `json:"temperature,omitempty"`
 | 
				
			||||||
 | 
						TopP              float64  `json:"top_p,omitempty"`
 | 
				
			||||||
 | 
						TopK              int      `json:"top_k,omitempty"`
 | 
				
			||||||
 | 
						//Metadata    `json:"metadata,omitempty"`
 | 
				
			||||||
 | 
						Stream bool `json:"stream,omitempty"`
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type Error struct {
 | 
				
			||||||
 | 
						Type    string `json:"type"`
 | 
				
			||||||
 | 
						Message string `json:"message"`
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type Response struct {
 | 
				
			||||||
 | 
						Completion string `json:"completion"`
 | 
				
			||||||
 | 
						StopReason string `json:"stop_reason"`
 | 
				
			||||||
 | 
						Model      string `json:"model"`
 | 
				
			||||||
 | 
						Error      Error  `json:"error"`
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
@@ -1,4 +1,4 @@
 | 
				
			|||||||
package controller
 | 
					package baidu
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import (
 | 
					import (
 | 
				
			||||||
	"bufio"
 | 
						"bufio"
 | 
				
			||||||
@@ -9,6 +9,9 @@ import (
 | 
				
			|||||||
	"io"
 | 
						"io"
 | 
				
			||||||
	"net/http"
 | 
						"net/http"
 | 
				
			||||||
	"one-api/common"
 | 
						"one-api/common"
 | 
				
			||||||
 | 
						"one-api/relay/channel/openai"
 | 
				
			||||||
 | 
						"one-api/relay/constant"
 | 
				
			||||||
 | 
						"one-api/relay/util"
 | 
				
			||||||
	"strings"
 | 
						"strings"
 | 
				
			||||||
	"sync"
 | 
						"sync"
 | 
				
			||||||
	"time"
 | 
						"time"
 | 
				
			||||||
@@ -37,53 +40,9 @@ type BaiduError struct {
 | 
				
			|||||||
	ErrorMsg  string `json:"error_msg"`
 | 
						ErrorMsg  string `json:"error_msg"`
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
type BaiduChatResponse struct {
 | 
					 | 
				
			||||||
	Id               string `json:"id"`
 | 
					 | 
				
			||||||
	Object           string `json:"object"`
 | 
					 | 
				
			||||||
	Created          int64  `json:"created"`
 | 
					 | 
				
			||||||
	Result           string `json:"result"`
 | 
					 | 
				
			||||||
	IsTruncated      bool   `json:"is_truncated"`
 | 
					 | 
				
			||||||
	NeedClearHistory bool   `json:"need_clear_history"`
 | 
					 | 
				
			||||||
	Usage            Usage  `json:"usage"`
 | 
					 | 
				
			||||||
	BaiduError
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
type BaiduChatStreamResponse struct {
 | 
					 | 
				
			||||||
	BaiduChatResponse
 | 
					 | 
				
			||||||
	SentenceId int  `json:"sentence_id"`
 | 
					 | 
				
			||||||
	IsEnd      bool `json:"is_end"`
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
type BaiduEmbeddingRequest struct {
 | 
					 | 
				
			||||||
	Input []string `json:"input"`
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
type BaiduEmbeddingData struct {
 | 
					 | 
				
			||||||
	Object    string    `json:"object"`
 | 
					 | 
				
			||||||
	Embedding []float64 `json:"embedding"`
 | 
					 | 
				
			||||||
	Index     int       `json:"index"`
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
type BaiduEmbeddingResponse struct {
 | 
					 | 
				
			||||||
	Id      string               `json:"id"`
 | 
					 | 
				
			||||||
	Object  string               `json:"object"`
 | 
					 | 
				
			||||||
	Created int64                `json:"created"`
 | 
					 | 
				
			||||||
	Data    []BaiduEmbeddingData `json:"data"`
 | 
					 | 
				
			||||||
	Usage   Usage                `json:"usage"`
 | 
					 | 
				
			||||||
	BaiduError
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
type BaiduAccessToken struct {
 | 
					 | 
				
			||||||
	AccessToken      string    `json:"access_token"`
 | 
					 | 
				
			||||||
	Error            string    `json:"error,omitempty"`
 | 
					 | 
				
			||||||
	ErrorDescription string    `json:"error_description,omitempty"`
 | 
					 | 
				
			||||||
	ExpiresIn        int64     `json:"expires_in,omitempty"`
 | 
					 | 
				
			||||||
	ExpiresAt        time.Time `json:"-"`
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
var baiduTokenStore sync.Map
 | 
					var baiduTokenStore sync.Map
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func requestOpenAI2Baidu(request GeneralOpenAIRequest) *BaiduChatRequest {
 | 
					func ConvertRequest(request openai.GeneralOpenAIRequest) *BaiduChatRequest {
 | 
				
			||||||
	messages := make([]BaiduMessage, 0, len(request.Messages))
 | 
						messages := make([]BaiduMessage, 0, len(request.Messages))
 | 
				
			||||||
	for _, message := range request.Messages {
 | 
						for _, message := range request.Messages {
 | 
				
			||||||
		if message.Role == "system" {
 | 
							if message.Role == "system" {
 | 
				
			||||||
@@ -108,56 +67,56 @@ func requestOpenAI2Baidu(request GeneralOpenAIRequest) *BaiduChatRequest {
 | 
				
			|||||||
	}
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func responseBaidu2OpenAI(response *BaiduChatResponse) *OpenAITextResponse {
 | 
					func responseBaidu2OpenAI(response *ChatResponse) *openai.TextResponse {
 | 
				
			||||||
	choice := OpenAITextResponseChoice{
 | 
						choice := openai.TextResponseChoice{
 | 
				
			||||||
		Index: 0,
 | 
							Index: 0,
 | 
				
			||||||
		Message: Message{
 | 
							Message: openai.Message{
 | 
				
			||||||
			Role:    "assistant",
 | 
								Role:    "assistant",
 | 
				
			||||||
			Content: response.Result,
 | 
								Content: response.Result,
 | 
				
			||||||
		},
 | 
							},
 | 
				
			||||||
		FinishReason: "stop",
 | 
							FinishReason: "stop",
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	fullTextResponse := OpenAITextResponse{
 | 
						fullTextResponse := openai.TextResponse{
 | 
				
			||||||
		Id:      response.Id,
 | 
							Id:      response.Id,
 | 
				
			||||||
		Object:  "chat.completion",
 | 
							Object:  "chat.completion",
 | 
				
			||||||
		Created: response.Created,
 | 
							Created: response.Created,
 | 
				
			||||||
		Choices: []OpenAITextResponseChoice{choice},
 | 
							Choices: []openai.TextResponseChoice{choice},
 | 
				
			||||||
		Usage:   response.Usage,
 | 
							Usage:   response.Usage,
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	return &fullTextResponse
 | 
						return &fullTextResponse
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func streamResponseBaidu2OpenAI(baiduResponse *BaiduChatStreamResponse) *ChatCompletionsStreamResponse {
 | 
					func streamResponseBaidu2OpenAI(baiduResponse *ChatStreamResponse) *openai.ChatCompletionsStreamResponse {
 | 
				
			||||||
	var choice ChatCompletionsStreamResponseChoice
 | 
						var choice openai.ChatCompletionsStreamResponseChoice
 | 
				
			||||||
	choice.Delta.Content = baiduResponse.Result
 | 
						choice.Delta.Content = baiduResponse.Result
 | 
				
			||||||
	if baiduResponse.IsEnd {
 | 
						if baiduResponse.IsEnd {
 | 
				
			||||||
		choice.FinishReason = &stopFinishReason
 | 
							choice.FinishReason = &constant.StopFinishReason
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	response := ChatCompletionsStreamResponse{
 | 
						response := openai.ChatCompletionsStreamResponse{
 | 
				
			||||||
		Id:      baiduResponse.Id,
 | 
							Id:      baiduResponse.Id,
 | 
				
			||||||
		Object:  "chat.completion.chunk",
 | 
							Object:  "chat.completion.chunk",
 | 
				
			||||||
		Created: baiduResponse.Created,
 | 
							Created: baiduResponse.Created,
 | 
				
			||||||
		Model:   "ernie-bot",
 | 
							Model:   "ernie-bot",
 | 
				
			||||||
		Choices: []ChatCompletionsStreamResponseChoice{choice},
 | 
							Choices: []openai.ChatCompletionsStreamResponseChoice{choice},
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	return &response
 | 
						return &response
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func embeddingRequestOpenAI2Baidu(request GeneralOpenAIRequest) *BaiduEmbeddingRequest {
 | 
					func ConvertEmbeddingRequest(request openai.GeneralOpenAIRequest) *EmbeddingRequest {
 | 
				
			||||||
	return &BaiduEmbeddingRequest{
 | 
						return &EmbeddingRequest{
 | 
				
			||||||
		Input: request.ParseInput(),
 | 
							Input: request.ParseInput(),
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func embeddingResponseBaidu2OpenAI(response *BaiduEmbeddingResponse) *OpenAIEmbeddingResponse {
 | 
					func embeddingResponseBaidu2OpenAI(response *EmbeddingResponse) *openai.EmbeddingResponse {
 | 
				
			||||||
	openAIEmbeddingResponse := OpenAIEmbeddingResponse{
 | 
						openAIEmbeddingResponse := openai.EmbeddingResponse{
 | 
				
			||||||
		Object: "list",
 | 
							Object: "list",
 | 
				
			||||||
		Data:   make([]OpenAIEmbeddingResponseItem, 0, len(response.Data)),
 | 
							Data:   make([]openai.EmbeddingResponseItem, 0, len(response.Data)),
 | 
				
			||||||
		Model:  "baidu-embedding",
 | 
							Model:  "baidu-embedding",
 | 
				
			||||||
		Usage:  response.Usage,
 | 
							Usage:  response.Usage,
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	for _, item := range response.Data {
 | 
						for _, item := range response.Data {
 | 
				
			||||||
		openAIEmbeddingResponse.Data = append(openAIEmbeddingResponse.Data, OpenAIEmbeddingResponseItem{
 | 
							openAIEmbeddingResponse.Data = append(openAIEmbeddingResponse.Data, openai.EmbeddingResponseItem{
 | 
				
			||||||
			Object:    item.Object,
 | 
								Object:    item.Object,
 | 
				
			||||||
			Index:     item.Index,
 | 
								Index:     item.Index,
 | 
				
			||||||
			Embedding: item.Embedding,
 | 
								Embedding: item.Embedding,
 | 
				
			||||||
@@ -166,8 +125,8 @@ func embeddingResponseBaidu2OpenAI(response *BaiduEmbeddingResponse) *OpenAIEmbe
 | 
				
			|||||||
	return &openAIEmbeddingResponse
 | 
						return &openAIEmbeddingResponse
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func baiduStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) {
 | 
					func StreamHandler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, *openai.Usage) {
 | 
				
			||||||
	var usage Usage
 | 
						var usage openai.Usage
 | 
				
			||||||
	scanner := bufio.NewScanner(resp.Body)
 | 
						scanner := bufio.NewScanner(resp.Body)
 | 
				
			||||||
	scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
 | 
						scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
 | 
				
			||||||
		if atEOF && len(data) == 0 {
 | 
							if atEOF && len(data) == 0 {
 | 
				
			||||||
@@ -194,11 +153,11 @@ func baiduStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithSt
 | 
				
			|||||||
		}
 | 
							}
 | 
				
			||||||
		stopChan <- true
 | 
							stopChan <- true
 | 
				
			||||||
	}()
 | 
						}()
 | 
				
			||||||
	setEventStreamHeaders(c)
 | 
						common.SetEventStreamHeaders(c)
 | 
				
			||||||
	c.Stream(func(w io.Writer) bool {
 | 
						c.Stream(func(w io.Writer) bool {
 | 
				
			||||||
		select {
 | 
							select {
 | 
				
			||||||
		case data := <-dataChan:
 | 
							case data := <-dataChan:
 | 
				
			||||||
			var baiduResponse BaiduChatStreamResponse
 | 
								var baiduResponse ChatStreamResponse
 | 
				
			||||||
			err := json.Unmarshal([]byte(data), &baiduResponse)
 | 
								err := json.Unmarshal([]byte(data), &baiduResponse)
 | 
				
			||||||
			if err != nil {
 | 
								if err != nil {
 | 
				
			||||||
				common.SysError("error unmarshalling stream response: " + err.Error())
 | 
									common.SysError("error unmarshalling stream response: " + err.Error())
 | 
				
			||||||
@@ -224,28 +183,28 @@ func baiduStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithSt
 | 
				
			|||||||
	})
 | 
						})
 | 
				
			||||||
	err := resp.Body.Close()
 | 
						err := resp.Body.Close()
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
 | 
							return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	return nil, &usage
 | 
						return nil, &usage
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func baiduHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) {
 | 
					func Handler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, *openai.Usage) {
 | 
				
			||||||
	var baiduResponse BaiduChatResponse
 | 
						var baiduResponse ChatResponse
 | 
				
			||||||
	responseBody, err := io.ReadAll(resp.Body)
 | 
						responseBody, err := io.ReadAll(resp.Body)
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
 | 
							return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	err = resp.Body.Close()
 | 
						err = resp.Body.Close()
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
 | 
							return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	err = json.Unmarshal(responseBody, &baiduResponse)
 | 
						err = json.Unmarshal(responseBody, &baiduResponse)
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
 | 
							return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	if baiduResponse.ErrorMsg != "" {
 | 
						if baiduResponse.ErrorMsg != "" {
 | 
				
			||||||
		return &OpenAIErrorWithStatusCode{
 | 
							return &openai.ErrorWithStatusCode{
 | 
				
			||||||
			OpenAIError: OpenAIError{
 | 
								Error: openai.Error{
 | 
				
			||||||
				Message: baiduResponse.ErrorMsg,
 | 
									Message: baiduResponse.ErrorMsg,
 | 
				
			||||||
				Type:    "baidu_error",
 | 
									Type:    "baidu_error",
 | 
				
			||||||
				Param:   "",
 | 
									Param:   "",
 | 
				
			||||||
@@ -258,7 +217,7 @@ func baiduHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCo
 | 
				
			|||||||
	fullTextResponse.Model = "ernie-bot"
 | 
						fullTextResponse.Model = "ernie-bot"
 | 
				
			||||||
	jsonResponse, err := json.Marshal(fullTextResponse)
 | 
						jsonResponse, err := json.Marshal(fullTextResponse)
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
 | 
							return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	c.Writer.Header().Set("Content-Type", "application/json")
 | 
						c.Writer.Header().Set("Content-Type", "application/json")
 | 
				
			||||||
	c.Writer.WriteHeader(resp.StatusCode)
 | 
						c.Writer.WriteHeader(resp.StatusCode)
 | 
				
			||||||
@@ -266,23 +225,23 @@ func baiduHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCo
 | 
				
			|||||||
	return nil, &fullTextResponse.Usage
 | 
						return nil, &fullTextResponse.Usage
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func baiduEmbeddingHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) {
 | 
					func EmbeddingHandler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, *openai.Usage) {
 | 
				
			||||||
	var baiduResponse BaiduEmbeddingResponse
 | 
						var baiduResponse EmbeddingResponse
 | 
				
			||||||
	responseBody, err := io.ReadAll(resp.Body)
 | 
						responseBody, err := io.ReadAll(resp.Body)
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
 | 
							return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	err = resp.Body.Close()
 | 
						err = resp.Body.Close()
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
 | 
							return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	err = json.Unmarshal(responseBody, &baiduResponse)
 | 
						err = json.Unmarshal(responseBody, &baiduResponse)
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
 | 
							return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	if baiduResponse.ErrorMsg != "" {
 | 
						if baiduResponse.ErrorMsg != "" {
 | 
				
			||||||
		return &OpenAIErrorWithStatusCode{
 | 
							return &openai.ErrorWithStatusCode{
 | 
				
			||||||
			OpenAIError: OpenAIError{
 | 
								Error: openai.Error{
 | 
				
			||||||
				Message: baiduResponse.ErrorMsg,
 | 
									Message: baiduResponse.ErrorMsg,
 | 
				
			||||||
				Type:    "baidu_error",
 | 
									Type:    "baidu_error",
 | 
				
			||||||
				Param:   "",
 | 
									Param:   "",
 | 
				
			||||||
@@ -294,7 +253,7 @@ func baiduEmbeddingHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWit
 | 
				
			|||||||
	fullTextResponse := embeddingResponseBaidu2OpenAI(&baiduResponse)
 | 
						fullTextResponse := embeddingResponseBaidu2OpenAI(&baiduResponse)
 | 
				
			||||||
	jsonResponse, err := json.Marshal(fullTextResponse)
 | 
						jsonResponse, err := json.Marshal(fullTextResponse)
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
 | 
							return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	c.Writer.Header().Set("Content-Type", "application/json")
 | 
						c.Writer.Header().Set("Content-Type", "application/json")
 | 
				
			||||||
	c.Writer.WriteHeader(resp.StatusCode)
 | 
						c.Writer.WriteHeader(resp.StatusCode)
 | 
				
			||||||
@@ -302,10 +261,10 @@ func baiduEmbeddingHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWit
 | 
				
			|||||||
	return nil, &fullTextResponse.Usage
 | 
						return nil, &fullTextResponse.Usage
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func getBaiduAccessToken(apiKey string) (string, error) {
 | 
					func GetAccessToken(apiKey string) (string, error) {
 | 
				
			||||||
	if val, ok := baiduTokenStore.Load(apiKey); ok {
 | 
						if val, ok := baiduTokenStore.Load(apiKey); ok {
 | 
				
			||||||
		var accessToken BaiduAccessToken
 | 
							var accessToken AccessToken
 | 
				
			||||||
		if accessToken, ok = val.(BaiduAccessToken); ok {
 | 
							if accessToken, ok = val.(AccessToken); ok {
 | 
				
			||||||
			// soon this will expire
 | 
								// soon this will expire
 | 
				
			||||||
			if time.Now().Add(time.Hour).After(accessToken.ExpiresAt) {
 | 
								if time.Now().Add(time.Hour).After(accessToken.ExpiresAt) {
 | 
				
			||||||
				go func() {
 | 
									go func() {
 | 
				
			||||||
@@ -320,12 +279,12 @@ func getBaiduAccessToken(apiKey string) (string, error) {
 | 
				
			|||||||
		return "", err
 | 
							return "", err
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	if accessToken == nil {
 | 
						if accessToken == nil {
 | 
				
			||||||
		return "", errors.New("getBaiduAccessToken return a nil token")
 | 
							return "", errors.New("GetAccessToken return a nil token")
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	return (*accessToken).AccessToken, nil
 | 
						return (*accessToken).AccessToken, nil
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func getBaiduAccessTokenHelper(apiKey string) (*BaiduAccessToken, error) {
 | 
					func getBaiduAccessTokenHelper(apiKey string) (*AccessToken, error) {
 | 
				
			||||||
	parts := strings.Split(apiKey, "|")
 | 
						parts := strings.Split(apiKey, "|")
 | 
				
			||||||
	if len(parts) != 2 {
 | 
						if len(parts) != 2 {
 | 
				
			||||||
		return nil, errors.New("invalid baidu apikey")
 | 
							return nil, errors.New("invalid baidu apikey")
 | 
				
			||||||
@@ -337,13 +296,13 @@ func getBaiduAccessTokenHelper(apiKey string) (*BaiduAccessToken, error) {
 | 
				
			|||||||
	}
 | 
						}
 | 
				
			||||||
	req.Header.Add("Content-Type", "application/json")
 | 
						req.Header.Add("Content-Type", "application/json")
 | 
				
			||||||
	req.Header.Add("Accept", "application/json")
 | 
						req.Header.Add("Accept", "application/json")
 | 
				
			||||||
	res, err := impatientHTTPClient.Do(req)
 | 
						res, err := util.ImpatientHTTPClient.Do(req)
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return nil, err
 | 
							return nil, err
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	defer res.Body.Close()
 | 
						defer res.Body.Close()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	var accessToken BaiduAccessToken
 | 
						var accessToken AccessToken
 | 
				
			||||||
	err = json.NewDecoder(res.Body).Decode(&accessToken)
 | 
						err = json.NewDecoder(res.Body).Decode(&accessToken)
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return nil, err
 | 
							return nil, err
 | 
				
			||||||
							
								
								
									
										50
									
								
								relay/channel/baidu/model.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										50
									
								
								relay/channel/baidu/model.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,50 @@
 | 
				
			|||||||
 | 
					package baidu
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import (
 | 
				
			||||||
 | 
						"one-api/relay/channel/openai"
 | 
				
			||||||
 | 
						"time"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type ChatResponse struct {
 | 
				
			||||||
 | 
						Id               string       `json:"id"`
 | 
				
			||||||
 | 
						Object           string       `json:"object"`
 | 
				
			||||||
 | 
						Created          int64        `json:"created"`
 | 
				
			||||||
 | 
						Result           string       `json:"result"`
 | 
				
			||||||
 | 
						IsTruncated      bool         `json:"is_truncated"`
 | 
				
			||||||
 | 
						NeedClearHistory bool         `json:"need_clear_history"`
 | 
				
			||||||
 | 
						Usage            openai.Usage `json:"usage"`
 | 
				
			||||||
 | 
						BaiduError
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type ChatStreamResponse struct {
 | 
				
			||||||
 | 
						ChatResponse
 | 
				
			||||||
 | 
						SentenceId int  `json:"sentence_id"`
 | 
				
			||||||
 | 
						IsEnd      bool `json:"is_end"`
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type EmbeddingRequest struct {
 | 
				
			||||||
 | 
						Input []string `json:"input"`
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type EmbeddingData struct {
 | 
				
			||||||
 | 
						Object    string    `json:"object"`
 | 
				
			||||||
 | 
						Embedding []float64 `json:"embedding"`
 | 
				
			||||||
 | 
						Index     int       `json:"index"`
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type EmbeddingResponse struct {
 | 
				
			||||||
 | 
						Id      string          `json:"id"`
 | 
				
			||||||
 | 
						Object  string          `json:"object"`
 | 
				
			||||||
 | 
						Created int64           `json:"created"`
 | 
				
			||||||
 | 
						Data    []EmbeddingData `json:"data"`
 | 
				
			||||||
 | 
						Usage   openai.Usage    `json:"usage"`
 | 
				
			||||||
 | 
						BaiduError
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type AccessToken struct {
 | 
				
			||||||
 | 
						AccessToken      string    `json:"access_token"`
 | 
				
			||||||
 | 
						Error            string    `json:"error,omitempty"`
 | 
				
			||||||
 | 
						ErrorDescription string    `json:"error_description,omitempty"`
 | 
				
			||||||
 | 
						ExpiresIn        int64     `json:"expires_in,omitempty"`
 | 
				
			||||||
 | 
						ExpiresAt        time.Time `json:"-"`
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
@@ -1,4 +1,4 @@
 | 
				
			|||||||
package controller
 | 
					package google
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import (
 | 
					import (
 | 
				
			||||||
	"bufio"
 | 
						"bufio"
 | 
				
			||||||
@@ -8,6 +8,8 @@ import (
 | 
				
			|||||||
	"net/http"
 | 
						"net/http"
 | 
				
			||||||
	"one-api/common"
 | 
						"one-api/common"
 | 
				
			||||||
	"one-api/common/image"
 | 
						"one-api/common/image"
 | 
				
			||||||
 | 
						"one-api/relay/channel/openai"
 | 
				
			||||||
 | 
						"one-api/relay/constant"
 | 
				
			||||||
	"strings"
 | 
						"strings"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	"github.com/gin-gonic/gin"
 | 
						"github.com/gin-gonic/gin"
 | 
				
			||||||
@@ -19,48 +21,8 @@ const (
 | 
				
			|||||||
	GeminiVisionMaxImageNum = 16
 | 
						GeminiVisionMaxImageNum = 16
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
type GeminiChatRequest struct {
 | 
					 | 
				
			||||||
	Contents         []GeminiChatContent        `json:"contents"`
 | 
					 | 
				
			||||||
	SafetySettings   []GeminiChatSafetySettings `json:"safety_settings,omitempty"`
 | 
					 | 
				
			||||||
	GenerationConfig GeminiChatGenerationConfig `json:"generation_config,omitempty"`
 | 
					 | 
				
			||||||
	Tools            []GeminiChatTools          `json:"tools,omitempty"`
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
type GeminiInlineData struct {
 | 
					 | 
				
			||||||
	MimeType string `json:"mimeType"`
 | 
					 | 
				
			||||||
	Data     string `json:"data"`
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
type GeminiPart struct {
 | 
					 | 
				
			||||||
	Text       string            `json:"text,omitempty"`
 | 
					 | 
				
			||||||
	InlineData *GeminiInlineData `json:"inlineData,omitempty"`
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
type GeminiChatContent struct {
 | 
					 | 
				
			||||||
	Role  string       `json:"role,omitempty"`
 | 
					 | 
				
			||||||
	Parts []GeminiPart `json:"parts"`
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
type GeminiChatSafetySettings struct {
 | 
					 | 
				
			||||||
	Category  string `json:"category"`
 | 
					 | 
				
			||||||
	Threshold string `json:"threshold"`
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
type GeminiChatTools struct {
 | 
					 | 
				
			||||||
	FunctionDeclarations any `json:"functionDeclarations,omitempty"`
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
type GeminiChatGenerationConfig struct {
 | 
					 | 
				
			||||||
	Temperature     float64  `json:"temperature,omitempty"`
 | 
					 | 
				
			||||||
	TopP            float64  `json:"topP,omitempty"`
 | 
					 | 
				
			||||||
	TopK            float64  `json:"topK,omitempty"`
 | 
					 | 
				
			||||||
	MaxOutputTokens int      `json:"maxOutputTokens,omitempty"`
 | 
					 | 
				
			||||||
	CandidateCount  int      `json:"candidateCount,omitempty"`
 | 
					 | 
				
			||||||
	StopSequences   []string `json:"stopSequences,omitempty"`
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
// Setting safety to the lowest possible values since Gemini is already powerless enough
 | 
					// Setting safety to the lowest possible values since Gemini is already powerless enough
 | 
				
			||||||
func requestOpenAI2Gemini(textRequest GeneralOpenAIRequest) *GeminiChatRequest {
 | 
					func ConvertGeminiRequest(textRequest openai.GeneralOpenAIRequest) *GeminiChatRequest {
 | 
				
			||||||
	geminiRequest := GeminiChatRequest{
 | 
						geminiRequest := GeminiChatRequest{
 | 
				
			||||||
		Contents: make([]GeminiChatContent, 0, len(textRequest.Messages)),
 | 
							Contents: make([]GeminiChatContent, 0, len(textRequest.Messages)),
 | 
				
			||||||
		SafetySettings: []GeminiChatSafetySettings{
 | 
							SafetySettings: []GeminiChatSafetySettings{
 | 
				
			||||||
@@ -108,11 +70,11 @@ func requestOpenAI2Gemini(textRequest GeneralOpenAIRequest) *GeminiChatRequest {
 | 
				
			|||||||
		var parts []GeminiPart
 | 
							var parts []GeminiPart
 | 
				
			||||||
		imageNum := 0
 | 
							imageNum := 0
 | 
				
			||||||
		for _, part := range openaiContent {
 | 
							for _, part := range openaiContent {
 | 
				
			||||||
			if part.Type == ContentTypeText {
 | 
								if part.Type == openai.ContentTypeText {
 | 
				
			||||||
				parts = append(parts, GeminiPart{
 | 
									parts = append(parts, GeminiPart{
 | 
				
			||||||
					Text: part.Text,
 | 
										Text: part.Text,
 | 
				
			||||||
				})
 | 
									})
 | 
				
			||||||
			} else if part.Type == ContentTypeImageURL {
 | 
								} else if part.Type == openai.ContentTypeImageURL {
 | 
				
			||||||
				imageNum += 1
 | 
									imageNum += 1
 | 
				
			||||||
				if imageNum > GeminiVisionMaxImageNum {
 | 
									if imageNum > GeminiVisionMaxImageNum {
 | 
				
			||||||
					continue
 | 
										continue
 | 
				
			||||||
@@ -187,21 +149,21 @@ type GeminiChatPromptFeedback struct {
 | 
				
			|||||||
	SafetyRatings []GeminiChatSafetyRating `json:"safetyRatings"`
 | 
						SafetyRatings []GeminiChatSafetyRating `json:"safetyRatings"`
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func responseGeminiChat2OpenAI(response *GeminiChatResponse) *OpenAITextResponse {
 | 
					func responseGeminiChat2OpenAI(response *GeminiChatResponse) *openai.TextResponse {
 | 
				
			||||||
	fullTextResponse := OpenAITextResponse{
 | 
						fullTextResponse := openai.TextResponse{
 | 
				
			||||||
		Id:      fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
 | 
							Id:      fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
 | 
				
			||||||
		Object:  "chat.completion",
 | 
							Object:  "chat.completion",
 | 
				
			||||||
		Created: common.GetTimestamp(),
 | 
							Created: common.GetTimestamp(),
 | 
				
			||||||
		Choices: make([]OpenAITextResponseChoice, 0, len(response.Candidates)),
 | 
							Choices: make([]openai.TextResponseChoice, 0, len(response.Candidates)),
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	for i, candidate := range response.Candidates {
 | 
						for i, candidate := range response.Candidates {
 | 
				
			||||||
		choice := OpenAITextResponseChoice{
 | 
							choice := openai.TextResponseChoice{
 | 
				
			||||||
			Index: i,
 | 
								Index: i,
 | 
				
			||||||
			Message: Message{
 | 
								Message: openai.Message{
 | 
				
			||||||
				Role:    "assistant",
 | 
									Role:    "assistant",
 | 
				
			||||||
				Content: "",
 | 
									Content: "",
 | 
				
			||||||
			},
 | 
								},
 | 
				
			||||||
			FinishReason: stopFinishReason,
 | 
								FinishReason: constant.StopFinishReason,
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
		if len(candidate.Content.Parts) > 0 {
 | 
							if len(candidate.Content.Parts) > 0 {
 | 
				
			||||||
			choice.Message.Content = candidate.Content.Parts[0].Text
 | 
								choice.Message.Content = candidate.Content.Parts[0].Text
 | 
				
			||||||
@@ -211,18 +173,18 @@ func responseGeminiChat2OpenAI(response *GeminiChatResponse) *OpenAITextResponse
 | 
				
			|||||||
	return &fullTextResponse
 | 
						return &fullTextResponse
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) *ChatCompletionsStreamResponse {
 | 
					func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) *openai.ChatCompletionsStreamResponse {
 | 
				
			||||||
	var choice ChatCompletionsStreamResponseChoice
 | 
						var choice openai.ChatCompletionsStreamResponseChoice
 | 
				
			||||||
	choice.Delta.Content = geminiResponse.GetResponseText()
 | 
						choice.Delta.Content = geminiResponse.GetResponseText()
 | 
				
			||||||
	choice.FinishReason = &stopFinishReason
 | 
						choice.FinishReason = &constant.StopFinishReason
 | 
				
			||||||
	var response ChatCompletionsStreamResponse
 | 
						var response openai.ChatCompletionsStreamResponse
 | 
				
			||||||
	response.Object = "chat.completion.chunk"
 | 
						response.Object = "chat.completion.chunk"
 | 
				
			||||||
	response.Model = "gemini"
 | 
						response.Model = "gemini"
 | 
				
			||||||
	response.Choices = []ChatCompletionsStreamResponseChoice{choice}
 | 
						response.Choices = []openai.ChatCompletionsStreamResponseChoice{choice}
 | 
				
			||||||
	return &response
 | 
						return &response
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func geminiChatStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, string) {
 | 
					func StreamHandler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, string) {
 | 
				
			||||||
	responseText := ""
 | 
						responseText := ""
 | 
				
			||||||
	dataChan := make(chan string)
 | 
						dataChan := make(chan string)
 | 
				
			||||||
	stopChan := make(chan bool)
 | 
						stopChan := make(chan bool)
 | 
				
			||||||
@@ -252,7 +214,7 @@ func geminiChatStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorW
 | 
				
			|||||||
		}
 | 
							}
 | 
				
			||||||
		stopChan <- true
 | 
							stopChan <- true
 | 
				
			||||||
	}()
 | 
						}()
 | 
				
			||||||
	setEventStreamHeaders(c)
 | 
						common.SetEventStreamHeaders(c)
 | 
				
			||||||
	c.Stream(func(w io.Writer) bool {
 | 
						c.Stream(func(w io.Writer) bool {
 | 
				
			||||||
		select {
 | 
							select {
 | 
				
			||||||
		case data := <-dataChan:
 | 
							case data := <-dataChan:
 | 
				
			||||||
@@ -264,14 +226,14 @@ func geminiChatStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorW
 | 
				
			|||||||
			var dummy dummyStruct
 | 
								var dummy dummyStruct
 | 
				
			||||||
			err := json.Unmarshal([]byte(data), &dummy)
 | 
								err := json.Unmarshal([]byte(data), &dummy)
 | 
				
			||||||
			responseText += dummy.Content
 | 
								responseText += dummy.Content
 | 
				
			||||||
			var choice ChatCompletionsStreamResponseChoice
 | 
								var choice openai.ChatCompletionsStreamResponseChoice
 | 
				
			||||||
			choice.Delta.Content = dummy.Content
 | 
								choice.Delta.Content = dummy.Content
 | 
				
			||||||
			response := ChatCompletionsStreamResponse{
 | 
								response := openai.ChatCompletionsStreamResponse{
 | 
				
			||||||
				Id:      fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
 | 
									Id:      fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
 | 
				
			||||||
				Object:  "chat.completion.chunk",
 | 
									Object:  "chat.completion.chunk",
 | 
				
			||||||
				Created: common.GetTimestamp(),
 | 
									Created: common.GetTimestamp(),
 | 
				
			||||||
				Model:   "gemini-pro",
 | 
									Model:   "gemini-pro",
 | 
				
			||||||
				Choices: []ChatCompletionsStreamResponseChoice{choice},
 | 
									Choices: []openai.ChatCompletionsStreamResponseChoice{choice},
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
			jsonResponse, err := json.Marshal(response)
 | 
								jsonResponse, err := json.Marshal(response)
 | 
				
			||||||
			if err != nil {
 | 
								if err != nil {
 | 
				
			||||||
@@ -287,28 +249,28 @@ func geminiChatStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorW
 | 
				
			|||||||
	})
 | 
						})
 | 
				
			||||||
	err := resp.Body.Close()
 | 
						err := resp.Body.Close()
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), ""
 | 
							return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), ""
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	return nil, responseText
 | 
						return nil, responseText
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func geminiChatHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*OpenAIErrorWithStatusCode, *Usage) {
 | 
					func GeminiHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*openai.ErrorWithStatusCode, *openai.Usage) {
 | 
				
			||||||
	responseBody, err := io.ReadAll(resp.Body)
 | 
						responseBody, err := io.ReadAll(resp.Body)
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
 | 
							return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	err = resp.Body.Close()
 | 
						err = resp.Body.Close()
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
 | 
							return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	var geminiResponse GeminiChatResponse
 | 
						var geminiResponse GeminiChatResponse
 | 
				
			||||||
	err = json.Unmarshal(responseBody, &geminiResponse)
 | 
						err = json.Unmarshal(responseBody, &geminiResponse)
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
 | 
							return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	if len(geminiResponse.Candidates) == 0 {
 | 
						if len(geminiResponse.Candidates) == 0 {
 | 
				
			||||||
		return &OpenAIErrorWithStatusCode{
 | 
							return &openai.ErrorWithStatusCode{
 | 
				
			||||||
			OpenAIError: OpenAIError{
 | 
								Error: openai.Error{
 | 
				
			||||||
				Message: "No candidates returned",
 | 
									Message: "No candidates returned",
 | 
				
			||||||
				Type:    "server_error",
 | 
									Type:    "server_error",
 | 
				
			||||||
				Param:   "",
 | 
									Param:   "",
 | 
				
			||||||
@@ -319,8 +281,8 @@ func geminiChatHandler(c *gin.Context, resp *http.Response, promptTokens int, mo
 | 
				
			|||||||
	}
 | 
						}
 | 
				
			||||||
	fullTextResponse := responseGeminiChat2OpenAI(&geminiResponse)
 | 
						fullTextResponse := responseGeminiChat2OpenAI(&geminiResponse)
 | 
				
			||||||
	fullTextResponse.Model = model
 | 
						fullTextResponse.Model = model
 | 
				
			||||||
	completionTokens := countTokenText(geminiResponse.GetResponseText(), model)
 | 
						completionTokens := openai.CountTokenText(geminiResponse.GetResponseText(), model)
 | 
				
			||||||
	usage := Usage{
 | 
						usage := openai.Usage{
 | 
				
			||||||
		PromptTokens:     promptTokens,
 | 
							PromptTokens:     promptTokens,
 | 
				
			||||||
		CompletionTokens: completionTokens,
 | 
							CompletionTokens: completionTokens,
 | 
				
			||||||
		TotalTokens:      promptTokens + completionTokens,
 | 
							TotalTokens:      promptTokens + completionTokens,
 | 
				
			||||||
@@ -328,7 +290,7 @@ func geminiChatHandler(c *gin.Context, resp *http.Response, promptTokens int, mo
 | 
				
			|||||||
	fullTextResponse.Usage = usage
 | 
						fullTextResponse.Usage = usage
 | 
				
			||||||
	jsonResponse, err := json.Marshal(fullTextResponse)
 | 
						jsonResponse, err := json.Marshal(fullTextResponse)
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
 | 
							return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	c.Writer.Header().Set("Content-Type", "application/json")
 | 
						c.Writer.Header().Set("Content-Type", "application/json")
 | 
				
			||||||
	c.Writer.WriteHeader(resp.StatusCode)
 | 
						c.Writer.WriteHeader(resp.StatusCode)
 | 
				
			||||||
							
								
								
									
										80
									
								
								relay/channel/google/model.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										80
									
								
								relay/channel/google/model.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,80 @@
 | 
				
			|||||||
 | 
					package google
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import (
 | 
				
			||||||
 | 
						"one-api/relay/channel/openai"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type GeminiChatRequest struct {
 | 
				
			||||||
 | 
						Contents         []GeminiChatContent        `json:"contents"`
 | 
				
			||||||
 | 
						SafetySettings   []GeminiChatSafetySettings `json:"safety_settings,omitempty"`
 | 
				
			||||||
 | 
						GenerationConfig GeminiChatGenerationConfig `json:"generation_config,omitempty"`
 | 
				
			||||||
 | 
						Tools            []GeminiChatTools          `json:"tools,omitempty"`
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type GeminiInlineData struct {
 | 
				
			||||||
 | 
						MimeType string `json:"mimeType"`
 | 
				
			||||||
 | 
						Data     string `json:"data"`
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type GeminiPart struct {
 | 
				
			||||||
 | 
						Text       string            `json:"text,omitempty"`
 | 
				
			||||||
 | 
						InlineData *GeminiInlineData `json:"inlineData,omitempty"`
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type GeminiChatContent struct {
 | 
				
			||||||
 | 
						Role  string       `json:"role,omitempty"`
 | 
				
			||||||
 | 
						Parts []GeminiPart `json:"parts"`
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type GeminiChatSafetySettings struct {
 | 
				
			||||||
 | 
						Category  string `json:"category"`
 | 
				
			||||||
 | 
						Threshold string `json:"threshold"`
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type GeminiChatTools struct {
 | 
				
			||||||
 | 
						FunctionDeclarations any `json:"functionDeclarations,omitempty"`
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type GeminiChatGenerationConfig struct {
 | 
				
			||||||
 | 
						Temperature     float64  `json:"temperature,omitempty"`
 | 
				
			||||||
 | 
						TopP            float64  `json:"topP,omitempty"`
 | 
				
			||||||
 | 
						TopK            float64  `json:"topK,omitempty"`
 | 
				
			||||||
 | 
						MaxOutputTokens int      `json:"maxOutputTokens,omitempty"`
 | 
				
			||||||
 | 
						CandidateCount  int      `json:"candidateCount,omitempty"`
 | 
				
			||||||
 | 
						StopSequences   []string `json:"stopSequences,omitempty"`
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type PaLMChatMessage struct {
 | 
				
			||||||
 | 
						Author  string `json:"author"`
 | 
				
			||||||
 | 
						Content string `json:"content"`
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type PaLMFilter struct {
 | 
				
			||||||
 | 
						Reason  string `json:"reason"`
 | 
				
			||||||
 | 
						Message string `json:"message"`
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type PaLMPrompt struct {
 | 
				
			||||||
 | 
						Messages []PaLMChatMessage `json:"messages"`
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type PaLMChatRequest struct {
 | 
				
			||||||
 | 
						Prompt         PaLMPrompt `json:"prompt"`
 | 
				
			||||||
 | 
						Temperature    float64    `json:"temperature,omitempty"`
 | 
				
			||||||
 | 
						CandidateCount int        `json:"candidateCount,omitempty"`
 | 
				
			||||||
 | 
						TopP           float64    `json:"topP,omitempty"`
 | 
				
			||||||
 | 
						TopK           int        `json:"topK,omitempty"`
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type PaLMError struct {
 | 
				
			||||||
 | 
						Code    int    `json:"code"`
 | 
				
			||||||
 | 
						Message string `json:"message"`
 | 
				
			||||||
 | 
						Status  string `json:"status"`
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type PaLMChatResponse struct {
 | 
				
			||||||
 | 
						Candidates []PaLMChatMessage `json:"candidates"`
 | 
				
			||||||
 | 
						Messages   []openai.Message  `json:"messages"`
 | 
				
			||||||
 | 
						Filters    []PaLMFilter      `json:"filters"`
 | 
				
			||||||
 | 
						Error      PaLMError         `json:"error"`
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
@@ -1,4 +1,4 @@
 | 
				
			|||||||
package controller
 | 
					package google
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import (
 | 
					import (
 | 
				
			||||||
	"encoding/json"
 | 
						"encoding/json"
 | 
				
			||||||
@@ -7,47 +7,14 @@ import (
 | 
				
			|||||||
	"io"
 | 
						"io"
 | 
				
			||||||
	"net/http"
 | 
						"net/http"
 | 
				
			||||||
	"one-api/common"
 | 
						"one-api/common"
 | 
				
			||||||
 | 
						"one-api/relay/channel/openai"
 | 
				
			||||||
 | 
						"one-api/relay/constant"
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// https://developers.generativeai.google/api/rest/generativelanguage/models/generateMessage#request-body
 | 
					// https://developers.generativeai.google/api/rest/generativelanguage/models/generateMessage#request-body
 | 
				
			||||||
// https://developers.generativeai.google/api/rest/generativelanguage/models/generateMessage#response-body
 | 
					// https://developers.generativeai.google/api/rest/generativelanguage/models/generateMessage#response-body
 | 
				
			||||||
 | 
					
 | 
				
			||||||
type PaLMChatMessage struct {
 | 
					func ConvertPaLMRequest(textRequest openai.GeneralOpenAIRequest) *PaLMChatRequest {
 | 
				
			||||||
	Author  string `json:"author"`
 | 
					 | 
				
			||||||
	Content string `json:"content"`
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
type PaLMFilter struct {
 | 
					 | 
				
			||||||
	Reason  string `json:"reason"`
 | 
					 | 
				
			||||||
	Message string `json:"message"`
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
type PaLMPrompt struct {
 | 
					 | 
				
			||||||
	Messages []PaLMChatMessage `json:"messages"`
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
type PaLMChatRequest struct {
 | 
					 | 
				
			||||||
	Prompt         PaLMPrompt `json:"prompt"`
 | 
					 | 
				
			||||||
	Temperature    float64    `json:"temperature,omitempty"`
 | 
					 | 
				
			||||||
	CandidateCount int        `json:"candidateCount,omitempty"`
 | 
					 | 
				
			||||||
	TopP           float64    `json:"topP,omitempty"`
 | 
					 | 
				
			||||||
	TopK           int        `json:"topK,omitempty"`
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
type PaLMError struct {
 | 
					 | 
				
			||||||
	Code    int    `json:"code"`
 | 
					 | 
				
			||||||
	Message string `json:"message"`
 | 
					 | 
				
			||||||
	Status  string `json:"status"`
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
type PaLMChatResponse struct {
 | 
					 | 
				
			||||||
	Candidates []PaLMChatMessage `json:"candidates"`
 | 
					 | 
				
			||||||
	Messages   []Message         `json:"messages"`
 | 
					 | 
				
			||||||
	Filters    []PaLMFilter      `json:"filters"`
 | 
					 | 
				
			||||||
	Error      PaLMError         `json:"error"`
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func requestOpenAI2PaLM(textRequest GeneralOpenAIRequest) *PaLMChatRequest {
 | 
					 | 
				
			||||||
	palmRequest := PaLMChatRequest{
 | 
						palmRequest := PaLMChatRequest{
 | 
				
			||||||
		Prompt: PaLMPrompt{
 | 
							Prompt: PaLMPrompt{
 | 
				
			||||||
			Messages: make([]PaLMChatMessage, 0, len(textRequest.Messages)),
 | 
								Messages: make([]PaLMChatMessage, 0, len(textRequest.Messages)),
 | 
				
			||||||
@@ -71,14 +38,14 @@ func requestOpenAI2PaLM(textRequest GeneralOpenAIRequest) *PaLMChatRequest {
 | 
				
			|||||||
	return &palmRequest
 | 
						return &palmRequest
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func responsePaLM2OpenAI(response *PaLMChatResponse) *OpenAITextResponse {
 | 
					func responsePaLM2OpenAI(response *PaLMChatResponse) *openai.TextResponse {
 | 
				
			||||||
	fullTextResponse := OpenAITextResponse{
 | 
						fullTextResponse := openai.TextResponse{
 | 
				
			||||||
		Choices: make([]OpenAITextResponseChoice, 0, len(response.Candidates)),
 | 
							Choices: make([]openai.TextResponseChoice, 0, len(response.Candidates)),
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	for i, candidate := range response.Candidates {
 | 
						for i, candidate := range response.Candidates {
 | 
				
			||||||
		choice := OpenAITextResponseChoice{
 | 
							choice := openai.TextResponseChoice{
 | 
				
			||||||
			Index: i,
 | 
								Index: i,
 | 
				
			||||||
			Message: Message{
 | 
								Message: openai.Message{
 | 
				
			||||||
				Role:    "assistant",
 | 
									Role:    "assistant",
 | 
				
			||||||
				Content: candidate.Content,
 | 
									Content: candidate.Content,
 | 
				
			||||||
			},
 | 
								},
 | 
				
			||||||
@@ -89,20 +56,20 @@ func responsePaLM2OpenAI(response *PaLMChatResponse) *OpenAITextResponse {
 | 
				
			|||||||
	return &fullTextResponse
 | 
						return &fullTextResponse
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func streamResponsePaLM2OpenAI(palmResponse *PaLMChatResponse) *ChatCompletionsStreamResponse {
 | 
					func streamResponsePaLM2OpenAI(palmResponse *PaLMChatResponse) *openai.ChatCompletionsStreamResponse {
 | 
				
			||||||
	var choice ChatCompletionsStreamResponseChoice
 | 
						var choice openai.ChatCompletionsStreamResponseChoice
 | 
				
			||||||
	if len(palmResponse.Candidates) > 0 {
 | 
						if len(palmResponse.Candidates) > 0 {
 | 
				
			||||||
		choice.Delta.Content = palmResponse.Candidates[0].Content
 | 
							choice.Delta.Content = palmResponse.Candidates[0].Content
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	choice.FinishReason = &stopFinishReason
 | 
						choice.FinishReason = &constant.StopFinishReason
 | 
				
			||||||
	var response ChatCompletionsStreamResponse
 | 
						var response openai.ChatCompletionsStreamResponse
 | 
				
			||||||
	response.Object = "chat.completion.chunk"
 | 
						response.Object = "chat.completion.chunk"
 | 
				
			||||||
	response.Model = "palm2"
 | 
						response.Model = "palm2"
 | 
				
			||||||
	response.Choices = []ChatCompletionsStreamResponseChoice{choice}
 | 
						response.Choices = []openai.ChatCompletionsStreamResponseChoice{choice}
 | 
				
			||||||
	return &response
 | 
						return &response
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func palmStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, string) {
 | 
					func PaLMStreamHandler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, string) {
 | 
				
			||||||
	responseText := ""
 | 
						responseText := ""
 | 
				
			||||||
	responseId := fmt.Sprintf("chatcmpl-%s", common.GetUUID())
 | 
						responseId := fmt.Sprintf("chatcmpl-%s", common.GetUUID())
 | 
				
			||||||
	createdTime := common.GetTimestamp()
 | 
						createdTime := common.GetTimestamp()
 | 
				
			||||||
@@ -143,7 +110,7 @@ func palmStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithSta
 | 
				
			|||||||
		dataChan <- string(jsonResponse)
 | 
							dataChan <- string(jsonResponse)
 | 
				
			||||||
		stopChan <- true
 | 
							stopChan <- true
 | 
				
			||||||
	}()
 | 
						}()
 | 
				
			||||||
	setEventStreamHeaders(c)
 | 
						common.SetEventStreamHeaders(c)
 | 
				
			||||||
	c.Stream(func(w io.Writer) bool {
 | 
						c.Stream(func(w io.Writer) bool {
 | 
				
			||||||
		select {
 | 
							select {
 | 
				
			||||||
		case data := <-dataChan:
 | 
							case data := <-dataChan:
 | 
				
			||||||
@@ -156,28 +123,28 @@ func palmStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithSta
 | 
				
			|||||||
	})
 | 
						})
 | 
				
			||||||
	err := resp.Body.Close()
 | 
						err := resp.Body.Close()
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), ""
 | 
							return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), ""
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	return nil, responseText
 | 
						return nil, responseText
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func palmHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*OpenAIErrorWithStatusCode, *Usage) {
 | 
					func PaLMHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*openai.ErrorWithStatusCode, *openai.Usage) {
 | 
				
			||||||
	responseBody, err := io.ReadAll(resp.Body)
 | 
						responseBody, err := io.ReadAll(resp.Body)
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
 | 
							return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	err = resp.Body.Close()
 | 
						err = resp.Body.Close()
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
 | 
							return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	var palmResponse PaLMChatResponse
 | 
						var palmResponse PaLMChatResponse
 | 
				
			||||||
	err = json.Unmarshal(responseBody, &palmResponse)
 | 
						err = json.Unmarshal(responseBody, &palmResponse)
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
 | 
							return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	if palmResponse.Error.Code != 0 || len(palmResponse.Candidates) == 0 {
 | 
						if palmResponse.Error.Code != 0 || len(palmResponse.Candidates) == 0 {
 | 
				
			||||||
		return &OpenAIErrorWithStatusCode{
 | 
							return &openai.ErrorWithStatusCode{
 | 
				
			||||||
			OpenAIError: OpenAIError{
 | 
								Error: openai.Error{
 | 
				
			||||||
				Message: palmResponse.Error.Message,
 | 
									Message: palmResponse.Error.Message,
 | 
				
			||||||
				Type:    palmResponse.Error.Status,
 | 
									Type:    palmResponse.Error.Status,
 | 
				
			||||||
				Param:   "",
 | 
									Param:   "",
 | 
				
			||||||
@@ -188,8 +155,8 @@ func palmHandler(c *gin.Context, resp *http.Response, promptTokens int, model st
 | 
				
			|||||||
	}
 | 
						}
 | 
				
			||||||
	fullTextResponse := responsePaLM2OpenAI(&palmResponse)
 | 
						fullTextResponse := responsePaLM2OpenAI(&palmResponse)
 | 
				
			||||||
	fullTextResponse.Model = model
 | 
						fullTextResponse.Model = model
 | 
				
			||||||
	completionTokens := countTokenText(palmResponse.Candidates[0].Content, model)
 | 
						completionTokens := openai.CountTokenText(palmResponse.Candidates[0].Content, model)
 | 
				
			||||||
	usage := Usage{
 | 
						usage := openai.Usage{
 | 
				
			||||||
		PromptTokens:     promptTokens,
 | 
							PromptTokens:     promptTokens,
 | 
				
			||||||
		CompletionTokens: completionTokens,
 | 
							CompletionTokens: completionTokens,
 | 
				
			||||||
		TotalTokens:      promptTokens + completionTokens,
 | 
							TotalTokens:      promptTokens + completionTokens,
 | 
				
			||||||
@@ -197,7 +164,7 @@ func palmHandler(c *gin.Context, resp *http.Response, promptTokens int, model st
 | 
				
			|||||||
	fullTextResponse.Usage = usage
 | 
						fullTextResponse.Usage = usage
 | 
				
			||||||
	jsonResponse, err := json.Marshal(fullTextResponse)
 | 
						jsonResponse, err := json.Marshal(fullTextResponse)
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
 | 
							return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	c.Writer.Header().Set("Content-Type", "application/json")
 | 
						c.Writer.Header().Set("Content-Type", "application/json")
 | 
				
			||||||
	c.Writer.WriteHeader(resp.StatusCode)
 | 
						c.Writer.WriteHeader(resp.StatusCode)
 | 
				
			||||||
							
								
								
									
										6
									
								
								relay/channel/openai/constant.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										6
									
								
								relay/channel/openai/constant.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,6 @@
 | 
				
			|||||||
 | 
					package openai
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					const (
 | 
				
			||||||
 | 
						ContentTypeText     = "text"
 | 
				
			||||||
 | 
						ContentTypeImageURL = "image_url"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
@@ -1,4 +1,4 @@
 | 
				
			|||||||
package controller
 | 
					package openai
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import (
 | 
					import (
 | 
				
			||||||
	"bufio"
 | 
						"bufio"
 | 
				
			||||||
@@ -8,10 +8,11 @@ import (
 | 
				
			|||||||
	"io"
 | 
						"io"
 | 
				
			||||||
	"net/http"
 | 
						"net/http"
 | 
				
			||||||
	"one-api/common"
 | 
						"one-api/common"
 | 
				
			||||||
 | 
						"one-api/relay/constant"
 | 
				
			||||||
	"strings"
 | 
						"strings"
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func openaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*OpenAIErrorWithStatusCode, string) {
 | 
					func StreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*ErrorWithStatusCode, string) {
 | 
				
			||||||
	responseText := ""
 | 
						responseText := ""
 | 
				
			||||||
	scanner := bufio.NewScanner(resp.Body)
 | 
						scanner := bufio.NewScanner(resp.Body)
 | 
				
			||||||
	scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
 | 
						scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
 | 
				
			||||||
@@ -41,7 +42,7 @@ func openaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*O
 | 
				
			|||||||
			data = data[6:]
 | 
								data = data[6:]
 | 
				
			||||||
			if !strings.HasPrefix(data, "[DONE]") {
 | 
								if !strings.HasPrefix(data, "[DONE]") {
 | 
				
			||||||
				switch relayMode {
 | 
									switch relayMode {
 | 
				
			||||||
				case RelayModeChatCompletions:
 | 
									case constant.RelayModeChatCompletions:
 | 
				
			||||||
					var streamResponse ChatCompletionsStreamResponse
 | 
										var streamResponse ChatCompletionsStreamResponse
 | 
				
			||||||
					err := json.Unmarshal([]byte(data), &streamResponse)
 | 
										err := json.Unmarshal([]byte(data), &streamResponse)
 | 
				
			||||||
					if err != nil {
 | 
										if err != nil {
 | 
				
			||||||
@@ -51,7 +52,7 @@ func openaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*O
 | 
				
			|||||||
					for _, choice := range streamResponse.Choices {
 | 
										for _, choice := range streamResponse.Choices {
 | 
				
			||||||
						responseText += choice.Delta.Content
 | 
											responseText += choice.Delta.Content
 | 
				
			||||||
					}
 | 
										}
 | 
				
			||||||
				case RelayModeCompletions:
 | 
									case constant.RelayModeCompletions:
 | 
				
			||||||
					var streamResponse CompletionsStreamResponse
 | 
										var streamResponse CompletionsStreamResponse
 | 
				
			||||||
					err := json.Unmarshal([]byte(data), &streamResponse)
 | 
										err := json.Unmarshal([]byte(data), &streamResponse)
 | 
				
			||||||
					if err != nil {
 | 
										if err != nil {
 | 
				
			||||||
@@ -66,7 +67,7 @@ func openaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*O
 | 
				
			|||||||
		}
 | 
							}
 | 
				
			||||||
		stopChan <- true
 | 
							stopChan <- true
 | 
				
			||||||
	}()
 | 
						}()
 | 
				
			||||||
	setEventStreamHeaders(c)
 | 
						common.SetEventStreamHeaders(c)
 | 
				
			||||||
	c.Stream(func(w io.Writer) bool {
 | 
						c.Stream(func(w io.Writer) bool {
 | 
				
			||||||
		select {
 | 
							select {
 | 
				
			||||||
		case data := <-dataChan:
 | 
							case data := <-dataChan:
 | 
				
			||||||
@@ -83,28 +84,28 @@ func openaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*O
 | 
				
			|||||||
	})
 | 
						})
 | 
				
			||||||
	err := resp.Body.Close()
 | 
						err := resp.Body.Close()
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), ""
 | 
							return ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), ""
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	return nil, responseText
 | 
						return nil, responseText
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func openaiHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*OpenAIErrorWithStatusCode, *Usage) {
 | 
					func Handler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*ErrorWithStatusCode, *Usage) {
 | 
				
			||||||
	var textResponse TextResponse
 | 
						var textResponse SlimTextResponse
 | 
				
			||||||
	responseBody, err := io.ReadAll(resp.Body)
 | 
						responseBody, err := io.ReadAll(resp.Body)
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
 | 
							return ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	err = resp.Body.Close()
 | 
						err = resp.Body.Close()
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
 | 
							return ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	err = json.Unmarshal(responseBody, &textResponse)
 | 
						err = json.Unmarshal(responseBody, &textResponse)
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
 | 
							return ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	if textResponse.Error.Type != "" {
 | 
						if textResponse.Error.Type != "" {
 | 
				
			||||||
		return &OpenAIErrorWithStatusCode{
 | 
							return &ErrorWithStatusCode{
 | 
				
			||||||
			OpenAIError: textResponse.Error,
 | 
								Error:      textResponse.Error,
 | 
				
			||||||
			StatusCode: resp.StatusCode,
 | 
								StatusCode: resp.StatusCode,
 | 
				
			||||||
		}, nil
 | 
							}, nil
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
@@ -113,7 +114,7 @@ func openaiHandler(c *gin.Context, resp *http.Response, promptTokens int, model
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
	// We shouldn't set the header before we parse the response body, because the parse part may fail.
 | 
						// We shouldn't set the header before we parse the response body, because the parse part may fail.
 | 
				
			||||||
	// And then we will have to send an error response, but in this case, the header has already been set.
 | 
						// And then we will have to send an error response, but in this case, the header has already been set.
 | 
				
			||||||
	// So the httpClient will be confused by the response.
 | 
						// So the HTTPClient will be confused by the response.
 | 
				
			||||||
	// For example, Postman will report error, and we cannot check the response at all.
 | 
						// For example, Postman will report error, and we cannot check the response at all.
 | 
				
			||||||
	for k, v := range resp.Header {
 | 
						for k, v := range resp.Header {
 | 
				
			||||||
		c.Writer.Header().Set(k, v[0])
 | 
							c.Writer.Header().Set(k, v[0])
 | 
				
			||||||
@@ -121,17 +122,17 @@ func openaiHandler(c *gin.Context, resp *http.Response, promptTokens int, model
 | 
				
			|||||||
	c.Writer.WriteHeader(resp.StatusCode)
 | 
						c.Writer.WriteHeader(resp.StatusCode)
 | 
				
			||||||
	_, err = io.Copy(c.Writer, resp.Body)
 | 
						_, err = io.Copy(c.Writer, resp.Body)
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return errorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil
 | 
							return ErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	err = resp.Body.Close()
 | 
						err = resp.Body.Close()
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
 | 
							return ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	if textResponse.Usage.TotalTokens == 0 {
 | 
						if textResponse.Usage.TotalTokens == 0 {
 | 
				
			||||||
		completionTokens := 0
 | 
							completionTokens := 0
 | 
				
			||||||
		for _, choice := range textResponse.Choices {
 | 
							for _, choice := range textResponse.Choices {
 | 
				
			||||||
			completionTokens += countTokenText(choice.Message.StringContent(), model)
 | 
								completionTokens += CountTokenText(choice.Message.StringContent(), model)
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
		textResponse.Usage = Usage{
 | 
							textResponse.Usage = Usage{
 | 
				
			||||||
			PromptTokens:     promptTokens,
 | 
								PromptTokens:     promptTokens,
 | 
				
			||||||
							
								
								
									
										283
									
								
								relay/channel/openai/model.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										283
									
								
								relay/channel/openai/model.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,283 @@
 | 
				
			|||||||
 | 
					package openai
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type Message struct {
 | 
				
			||||||
 | 
						Role    string  `json:"role"`
 | 
				
			||||||
 | 
						Content any     `json:"content"`
 | 
				
			||||||
 | 
						Name    *string `json:"name,omitempty"`
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type ImageURL struct {
 | 
				
			||||||
 | 
						Url    string `json:"url,omitempty"`
 | 
				
			||||||
 | 
						Detail string `json:"detail,omitempty"`
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type TextContent struct {
 | 
				
			||||||
 | 
						Type string `json:"type,omitempty"`
 | 
				
			||||||
 | 
						Text string `json:"text,omitempty"`
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type ImageContent struct {
 | 
				
			||||||
 | 
						Type     string    `json:"type,omitempty"`
 | 
				
			||||||
 | 
						ImageURL *ImageURL `json:"image_url,omitempty"`
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type OpenAIMessageContent struct {
 | 
				
			||||||
 | 
						Type     string    `json:"type,omitempty"`
 | 
				
			||||||
 | 
						Text     string    `json:"text"`
 | 
				
			||||||
 | 
						ImageURL *ImageURL `json:"image_url,omitempty"`
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (m Message) IsStringContent() bool {
 | 
				
			||||||
 | 
						_, ok := m.Content.(string)
 | 
				
			||||||
 | 
						return ok
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (m Message) StringContent() string {
 | 
				
			||||||
 | 
						content, ok := m.Content.(string)
 | 
				
			||||||
 | 
						if ok {
 | 
				
			||||||
 | 
							return content
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						contentList, ok := m.Content.([]any)
 | 
				
			||||||
 | 
						if ok {
 | 
				
			||||||
 | 
							var contentStr string
 | 
				
			||||||
 | 
							for _, contentItem := range contentList {
 | 
				
			||||||
 | 
								contentMap, ok := contentItem.(map[string]any)
 | 
				
			||||||
 | 
								if !ok {
 | 
				
			||||||
 | 
									continue
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
								if contentMap["type"] == ContentTypeText {
 | 
				
			||||||
 | 
									if subStr, ok := contentMap["text"].(string); ok {
 | 
				
			||||||
 | 
										contentStr += subStr
 | 
				
			||||||
 | 
									}
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							return contentStr
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return ""
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (m Message) ParseContent() []OpenAIMessageContent {
 | 
				
			||||||
 | 
						var contentList []OpenAIMessageContent
 | 
				
			||||||
 | 
						content, ok := m.Content.(string)
 | 
				
			||||||
 | 
						if ok {
 | 
				
			||||||
 | 
							contentList = append(contentList, OpenAIMessageContent{
 | 
				
			||||||
 | 
								Type: ContentTypeText,
 | 
				
			||||||
 | 
								Text: content,
 | 
				
			||||||
 | 
							})
 | 
				
			||||||
 | 
							return contentList
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						anyList, ok := m.Content.([]any)
 | 
				
			||||||
 | 
						if ok {
 | 
				
			||||||
 | 
							for _, contentItem := range anyList {
 | 
				
			||||||
 | 
								contentMap, ok := contentItem.(map[string]any)
 | 
				
			||||||
 | 
								if !ok {
 | 
				
			||||||
 | 
									continue
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
								switch contentMap["type"] {
 | 
				
			||||||
 | 
								case ContentTypeText:
 | 
				
			||||||
 | 
									if subStr, ok := contentMap["text"].(string); ok {
 | 
				
			||||||
 | 
										contentList = append(contentList, OpenAIMessageContent{
 | 
				
			||||||
 | 
											Type: ContentTypeText,
 | 
				
			||||||
 | 
											Text: subStr,
 | 
				
			||||||
 | 
										})
 | 
				
			||||||
 | 
									}
 | 
				
			||||||
 | 
								case ContentTypeImageURL:
 | 
				
			||||||
 | 
									if subObj, ok := contentMap["image_url"].(map[string]any); ok {
 | 
				
			||||||
 | 
										contentList = append(contentList, OpenAIMessageContent{
 | 
				
			||||||
 | 
											Type: ContentTypeImageURL,
 | 
				
			||||||
 | 
											ImageURL: &ImageURL{
 | 
				
			||||||
 | 
												Url: subObj["url"].(string),
 | 
				
			||||||
 | 
											},
 | 
				
			||||||
 | 
										})
 | 
				
			||||||
 | 
									}
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							return contentList
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return nil
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type ResponseFormat struct {
 | 
				
			||||||
 | 
						Type string `json:"type,omitempty"`
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type GeneralOpenAIRequest struct {
 | 
				
			||||||
 | 
						Model            string          `json:"model,omitempty"`
 | 
				
			||||||
 | 
						Messages         []Message       `json:"messages,omitempty"`
 | 
				
			||||||
 | 
						Prompt           any             `json:"prompt,omitempty"`
 | 
				
			||||||
 | 
						Stream           bool            `json:"stream,omitempty"`
 | 
				
			||||||
 | 
						MaxTokens        int             `json:"max_tokens,omitempty"`
 | 
				
			||||||
 | 
						Temperature      float64         `json:"temperature,omitempty"`
 | 
				
			||||||
 | 
						TopP             float64         `json:"top_p,omitempty"`
 | 
				
			||||||
 | 
						N                int             `json:"n,omitempty"`
 | 
				
			||||||
 | 
						Input            any             `json:"input,omitempty"`
 | 
				
			||||||
 | 
						Instruction      string          `json:"instruction,omitempty"`
 | 
				
			||||||
 | 
						Size             string          `json:"size,omitempty"`
 | 
				
			||||||
 | 
						Functions        any             `json:"functions,omitempty"`
 | 
				
			||||||
 | 
						FrequencyPenalty float64         `json:"frequency_penalty,omitempty"`
 | 
				
			||||||
 | 
						PresencePenalty  float64         `json:"presence_penalty,omitempty"`
 | 
				
			||||||
 | 
						ResponseFormat   *ResponseFormat `json:"response_format,omitempty"`
 | 
				
			||||||
 | 
						Seed             float64         `json:"seed,omitempty"`
 | 
				
			||||||
 | 
						Tools            any             `json:"tools,omitempty"`
 | 
				
			||||||
 | 
						ToolChoice       any             `json:"tool_choice,omitempty"`
 | 
				
			||||||
 | 
						User             string          `json:"user,omitempty"`
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (r GeneralOpenAIRequest) ParseInput() []string {
 | 
				
			||||||
 | 
						if r.Input == nil {
 | 
				
			||||||
 | 
							return nil
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						var input []string
 | 
				
			||||||
 | 
						switch r.Input.(type) {
 | 
				
			||||||
 | 
						case string:
 | 
				
			||||||
 | 
							input = []string{r.Input.(string)}
 | 
				
			||||||
 | 
						case []any:
 | 
				
			||||||
 | 
							input = make([]string, 0, len(r.Input.([]any)))
 | 
				
			||||||
 | 
							for _, item := range r.Input.([]any) {
 | 
				
			||||||
 | 
								if str, ok := item.(string); ok {
 | 
				
			||||||
 | 
									input = append(input, str)
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return input
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type ChatRequest struct {
 | 
				
			||||||
 | 
						Model     string    `json:"model"`
 | 
				
			||||||
 | 
						Messages  []Message `json:"messages"`
 | 
				
			||||||
 | 
						MaxTokens int       `json:"max_tokens"`
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type TextRequest struct {
 | 
				
			||||||
 | 
						Model     string    `json:"model"`
 | 
				
			||||||
 | 
						Messages  []Message `json:"messages"`
 | 
				
			||||||
 | 
						Prompt    string    `json:"prompt"`
 | 
				
			||||||
 | 
						MaxTokens int       `json:"max_tokens"`
 | 
				
			||||||
 | 
						//Stream   bool      `json:"stream"`
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// ImageRequest docs: https://platform.openai.com/docs/api-reference/images/create
 | 
				
			||||||
 | 
					type ImageRequest struct {
 | 
				
			||||||
 | 
						Model          string `json:"model"`
 | 
				
			||||||
 | 
						Prompt         string `json:"prompt" binding:"required"`
 | 
				
			||||||
 | 
						N              int    `json:"n,omitempty"`
 | 
				
			||||||
 | 
						Size           string `json:"size,omitempty"`
 | 
				
			||||||
 | 
						Quality        string `json:"quality,omitempty"`
 | 
				
			||||||
 | 
						ResponseFormat string `json:"response_format,omitempty"`
 | 
				
			||||||
 | 
						Style          string `json:"style,omitempty"`
 | 
				
			||||||
 | 
						User           string `json:"user,omitempty"`
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type WhisperJSONResponse struct {
 | 
				
			||||||
 | 
						Text string `json:"text,omitempty"`
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type WhisperVerboseJSONResponse struct {
 | 
				
			||||||
 | 
						Task     string    `json:"task,omitempty"`
 | 
				
			||||||
 | 
						Language string    `json:"language,omitempty"`
 | 
				
			||||||
 | 
						Duration float64   `json:"duration,omitempty"`
 | 
				
			||||||
 | 
						Text     string    `json:"text,omitempty"`
 | 
				
			||||||
 | 
						Segments []Segment `json:"segments,omitempty"`
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type Segment struct {
 | 
				
			||||||
 | 
						Id               int     `json:"id"`
 | 
				
			||||||
 | 
						Seek             int     `json:"seek"`
 | 
				
			||||||
 | 
						Start            float64 `json:"start"`
 | 
				
			||||||
 | 
						End              float64 `json:"end"`
 | 
				
			||||||
 | 
						Text             string  `json:"text"`
 | 
				
			||||||
 | 
						Tokens           []int   `json:"tokens"`
 | 
				
			||||||
 | 
						Temperature      float64 `json:"temperature"`
 | 
				
			||||||
 | 
						AvgLogprob       float64 `json:"avg_logprob"`
 | 
				
			||||||
 | 
						CompressionRatio float64 `json:"compression_ratio"`
 | 
				
			||||||
 | 
						NoSpeechProb     float64 `json:"no_speech_prob"`
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type TextToSpeechRequest struct {
 | 
				
			||||||
 | 
						Model          string  `json:"model" binding:"required"`
 | 
				
			||||||
 | 
						Input          string  `json:"input" binding:"required"`
 | 
				
			||||||
 | 
						Voice          string  `json:"voice" binding:"required"`
 | 
				
			||||||
 | 
						Speed          float64 `json:"speed"`
 | 
				
			||||||
 | 
						ResponseFormat string  `json:"response_format"`
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type Usage struct {
 | 
				
			||||||
 | 
						PromptTokens     int `json:"prompt_tokens"`
 | 
				
			||||||
 | 
						CompletionTokens int `json:"completion_tokens"`
 | 
				
			||||||
 | 
						TotalTokens      int `json:"total_tokens"`
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type Error struct {
 | 
				
			||||||
 | 
						Message string `json:"message"`
 | 
				
			||||||
 | 
						Type    string `json:"type"`
 | 
				
			||||||
 | 
						Param   string `json:"param"`
 | 
				
			||||||
 | 
						Code    any    `json:"code"`
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type ErrorWithStatusCode struct {
 | 
				
			||||||
 | 
						Error
 | 
				
			||||||
 | 
						StatusCode int `json:"status_code"`
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type SlimTextResponse struct {
 | 
				
			||||||
 | 
						Choices []TextResponseChoice `json:"choices"`
 | 
				
			||||||
 | 
						Usage   `json:"usage"`
 | 
				
			||||||
 | 
						Error   Error `json:"error"`
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type TextResponseChoice struct {
 | 
				
			||||||
 | 
						Index        int `json:"index"`
 | 
				
			||||||
 | 
						Message      `json:"message"`
 | 
				
			||||||
 | 
						FinishReason string `json:"finish_reason"`
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type TextResponse struct {
 | 
				
			||||||
 | 
						Id      string               `json:"id"`
 | 
				
			||||||
 | 
						Model   string               `json:"model,omitempty"`
 | 
				
			||||||
 | 
						Object  string               `json:"object"`
 | 
				
			||||||
 | 
						Created int64                `json:"created"`
 | 
				
			||||||
 | 
						Choices []TextResponseChoice `json:"choices"`
 | 
				
			||||||
 | 
						Usage   `json:"usage"`
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type EmbeddingResponseItem struct {
 | 
				
			||||||
 | 
						Object    string    `json:"object"`
 | 
				
			||||||
 | 
						Index     int       `json:"index"`
 | 
				
			||||||
 | 
						Embedding []float64 `json:"embedding"`
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type EmbeddingResponse struct {
 | 
				
			||||||
 | 
						Object string                  `json:"object"`
 | 
				
			||||||
 | 
						Data   []EmbeddingResponseItem `json:"data"`
 | 
				
			||||||
 | 
						Model  string                  `json:"model"`
 | 
				
			||||||
 | 
						Usage  `json:"usage"`
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type ImageResponse struct {
 | 
				
			||||||
 | 
						Created int `json:"created"`
 | 
				
			||||||
 | 
						Data    []struct {
 | 
				
			||||||
 | 
							Url string `json:"url"`
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type ChatCompletionsStreamResponseChoice struct {
 | 
				
			||||||
 | 
						Delta struct {
 | 
				
			||||||
 | 
							Content string `json:"content"`
 | 
				
			||||||
 | 
						} `json:"delta"`
 | 
				
			||||||
 | 
						FinishReason *string `json:"finish_reason,omitempty"`
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type ChatCompletionsStreamResponse struct {
 | 
				
			||||||
 | 
						Id      string                                `json:"id"`
 | 
				
			||||||
 | 
						Object  string                                `json:"object"`
 | 
				
			||||||
 | 
						Created int64                                 `json:"created"`
 | 
				
			||||||
 | 
						Model   string                                `json:"model"`
 | 
				
			||||||
 | 
						Choices []ChatCompletionsStreamResponseChoice `json:"choices"`
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type CompletionsStreamResponse struct {
 | 
				
			||||||
 | 
						Choices []struct {
 | 
				
			||||||
 | 
							Text         string `json:"text"`
 | 
				
			||||||
 | 
							FinishReason string `json:"finish_reason"`
 | 
				
			||||||
 | 
						} `json:"choices"`
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
@@ -1,25 +1,15 @@
 | 
				
			|||||||
package controller
 | 
					package openai
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import (
 | 
					import (
 | 
				
			||||||
	"context"
 | 
					 | 
				
			||||||
	"encoding/json"
 | 
					 | 
				
			||||||
	"errors"
 | 
						"errors"
 | 
				
			||||||
	"fmt"
 | 
						"fmt"
 | 
				
			||||||
	"io"
 | 
						"github.com/pkoukk/tiktoken-go"
 | 
				
			||||||
	"math"
 | 
						"math"
 | 
				
			||||||
	"net/http"
 | 
					 | 
				
			||||||
	"one-api/common"
 | 
						"one-api/common"
 | 
				
			||||||
	"one-api/common/image"
 | 
						"one-api/common/image"
 | 
				
			||||||
	"one-api/model"
 | 
					 | 
				
			||||||
	"strconv"
 | 
					 | 
				
			||||||
	"strings"
 | 
						"strings"
 | 
				
			||||||
 | 
					 | 
				
			||||||
	"github.com/gin-gonic/gin"
 | 
					 | 
				
			||||||
	"github.com/pkoukk/tiktoken-go"
 | 
					 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
var stopFinishReason = "stop"
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
// tokenEncoderMap won't grow after initialization
 | 
					// tokenEncoderMap won't grow after initialization
 | 
				
			||||||
var tokenEncoderMap = map[string]*tiktoken.Tiktoken{}
 | 
					var tokenEncoderMap = map[string]*tiktoken.Tiktoken{}
 | 
				
			||||||
var defaultTokenEncoder *tiktoken.Tiktoken
 | 
					var defaultTokenEncoder *tiktoken.Tiktoken
 | 
				
			||||||
@@ -71,7 +61,7 @@ func getTokenNum(tokenEncoder *tiktoken.Tiktoken, text string) int {
 | 
				
			|||||||
	return len(tokenEncoder.Encode(text, nil, nil))
 | 
						return len(tokenEncoder.Encode(text, nil, nil))
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func countTokenMessages(messages []Message, model string) int {
 | 
					func CountTokenMessages(messages []Message, model string) int {
 | 
				
			||||||
	tokenEncoder := getTokenEncoder(model)
 | 
						tokenEncoder := getTokenEncoder(model)
 | 
				
			||||||
	// Reference:
 | 
						// Reference:
 | 
				
			||||||
	// https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
 | 
						// https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
 | 
				
			||||||
@@ -195,191 +185,21 @@ func countImageTokens(url string, detail string) (_ int, err error) {
 | 
				
			|||||||
	}
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func countTokenInput(input any, model string) int {
 | 
					func CountTokenInput(input any, model string) int {
 | 
				
			||||||
	switch v := input.(type) {
 | 
						switch v := input.(type) {
 | 
				
			||||||
	case string:
 | 
						case string:
 | 
				
			||||||
		return countTokenText(v, model)
 | 
							return CountTokenText(v, model)
 | 
				
			||||||
	case []string:
 | 
						case []string:
 | 
				
			||||||
		text := ""
 | 
							text := ""
 | 
				
			||||||
		for _, s := range v {
 | 
							for _, s := range v {
 | 
				
			||||||
			text += s
 | 
								text += s
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
		return countTokenText(text, model)
 | 
							return CountTokenText(text, model)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	return 0
 | 
						return 0
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func countTokenText(text string, model string) int {
 | 
					func CountTokenText(text string, model string) int {
 | 
				
			||||||
	tokenEncoder := getTokenEncoder(model)
 | 
						tokenEncoder := getTokenEncoder(model)
 | 
				
			||||||
	return getTokenNum(tokenEncoder, text)
 | 
						return getTokenNum(tokenEncoder, text)
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					 | 
				
			||||||
func errorWrapper(err error, code string, statusCode int) *OpenAIErrorWithStatusCode {
 | 
					 | 
				
			||||||
	openAIError := OpenAIError{
 | 
					 | 
				
			||||||
		Message: err.Error(),
 | 
					 | 
				
			||||||
		Type:    "one_api_error",
 | 
					 | 
				
			||||||
		Code:    code,
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	return &OpenAIErrorWithStatusCode{
 | 
					 | 
				
			||||||
		OpenAIError: openAIError,
 | 
					 | 
				
			||||||
		StatusCode:  statusCode,
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func shouldDisableChannel(err *OpenAIError, statusCode int) bool {
 | 
					 | 
				
			||||||
	if !common.AutomaticDisableChannelEnabled {
 | 
					 | 
				
			||||||
		return false
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	if err == nil {
 | 
					 | 
				
			||||||
		return false
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	if statusCode == http.StatusUnauthorized {
 | 
					 | 
				
			||||||
		return true
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	if err.Type == "insufficient_quota" || err.Code == "invalid_api_key" || err.Code == "account_deactivated" {
 | 
					 | 
				
			||||||
		return true
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	return false
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func shouldEnableChannel(err error, openAIErr *OpenAIError) bool {
 | 
					 | 
				
			||||||
	if !common.AutomaticEnableChannelEnabled {
 | 
					 | 
				
			||||||
		return false
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	if err != nil {
 | 
					 | 
				
			||||||
		return false
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	if openAIErr != nil {
 | 
					 | 
				
			||||||
		return false
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	return true
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func setEventStreamHeaders(c *gin.Context) {
 | 
					 | 
				
			||||||
	c.Writer.Header().Set("Content-Type", "text/event-stream")
 | 
					 | 
				
			||||||
	c.Writer.Header().Set("Cache-Control", "no-cache")
 | 
					 | 
				
			||||||
	c.Writer.Header().Set("Connection", "keep-alive")
 | 
					 | 
				
			||||||
	c.Writer.Header().Set("Transfer-Encoding", "chunked")
 | 
					 | 
				
			||||||
	c.Writer.Header().Set("X-Accel-Buffering", "no")
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
type GeneralErrorResponse struct {
 | 
					 | 
				
			||||||
	Error    OpenAIError `json:"error"`
 | 
					 | 
				
			||||||
	Message  string      `json:"message"`
 | 
					 | 
				
			||||||
	Msg      string      `json:"msg"`
 | 
					 | 
				
			||||||
	Err      string      `json:"err"`
 | 
					 | 
				
			||||||
	ErrorMsg string      `json:"error_msg"`
 | 
					 | 
				
			||||||
	Header   struct {
 | 
					 | 
				
			||||||
		Message string `json:"message"`
 | 
					 | 
				
			||||||
	} `json:"header"`
 | 
					 | 
				
			||||||
	Response struct {
 | 
					 | 
				
			||||||
		Error struct {
 | 
					 | 
				
			||||||
			Message string `json:"message"`
 | 
					 | 
				
			||||||
		} `json:"error"`
 | 
					 | 
				
			||||||
	} `json:"response"`
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func (e GeneralErrorResponse) ToMessage() string {
 | 
					 | 
				
			||||||
	if e.Error.Message != "" {
 | 
					 | 
				
			||||||
		return e.Error.Message
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	if e.Message != "" {
 | 
					 | 
				
			||||||
		return e.Message
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	if e.Msg != "" {
 | 
					 | 
				
			||||||
		return e.Msg
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	if e.Err != "" {
 | 
					 | 
				
			||||||
		return e.Err
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	if e.ErrorMsg != "" {
 | 
					 | 
				
			||||||
		return e.ErrorMsg
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	if e.Header.Message != "" {
 | 
					 | 
				
			||||||
		return e.Header.Message
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	if e.Response.Error.Message != "" {
 | 
					 | 
				
			||||||
		return e.Response.Error.Message
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	return ""
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func relayErrorHandler(resp *http.Response) (openAIErrorWithStatusCode *OpenAIErrorWithStatusCode) {
 | 
					 | 
				
			||||||
	openAIErrorWithStatusCode = &OpenAIErrorWithStatusCode{
 | 
					 | 
				
			||||||
		StatusCode: resp.StatusCode,
 | 
					 | 
				
			||||||
		OpenAIError: OpenAIError{
 | 
					 | 
				
			||||||
			Message: "",
 | 
					 | 
				
			||||||
			Type:    "upstream_error",
 | 
					 | 
				
			||||||
			Code:    "bad_response_status_code",
 | 
					 | 
				
			||||||
			Param:   strconv.Itoa(resp.StatusCode),
 | 
					 | 
				
			||||||
		},
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	responseBody, err := io.ReadAll(resp.Body)
 | 
					 | 
				
			||||||
	if err != nil {
 | 
					 | 
				
			||||||
		return
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	err = resp.Body.Close()
 | 
					 | 
				
			||||||
	if err != nil {
 | 
					 | 
				
			||||||
		return
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	var errResponse GeneralErrorResponse
 | 
					 | 
				
			||||||
	err = json.Unmarshal(responseBody, &errResponse)
 | 
					 | 
				
			||||||
	if err != nil {
 | 
					 | 
				
			||||||
		return
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	if errResponse.Error.Message != "" {
 | 
					 | 
				
			||||||
		// OpenAI format error, so we override the default one
 | 
					 | 
				
			||||||
		openAIErrorWithStatusCode.OpenAIError = errResponse.Error
 | 
					 | 
				
			||||||
	} else {
 | 
					 | 
				
			||||||
		openAIErrorWithStatusCode.OpenAIError.Message = errResponse.ToMessage()
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	if openAIErrorWithStatusCode.OpenAIError.Message == "" {
 | 
					 | 
				
			||||||
		openAIErrorWithStatusCode.OpenAIError.Message = fmt.Sprintf("bad response status code %d", resp.StatusCode)
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	return
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func getFullRequestURL(baseURL string, requestURL string, channelType int) string {
 | 
					 | 
				
			||||||
	fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	if strings.HasPrefix(baseURL, "https://gateway.ai.cloudflare.com") {
 | 
					 | 
				
			||||||
		switch channelType {
 | 
					 | 
				
			||||||
		case common.ChannelTypeOpenAI:
 | 
					 | 
				
			||||||
			fullRequestURL = fmt.Sprintf("%s%s", baseURL, strings.TrimPrefix(requestURL, "/v1"))
 | 
					 | 
				
			||||||
		case common.ChannelTypeAzure:
 | 
					 | 
				
			||||||
			fullRequestURL = fmt.Sprintf("%s%s", baseURL, strings.TrimPrefix(requestURL, "/openai/deployments"))
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	return fullRequestURL
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func postConsumeQuota(ctx context.Context, tokenId int, quotaDelta int, totalQuota int, userId int, channelId int, modelRatio float64, groupRatio float64, modelName string, tokenName string) {
 | 
					 | 
				
			||||||
	// quotaDelta is remaining quota to be consumed
 | 
					 | 
				
			||||||
	err := model.PostConsumeTokenQuota(tokenId, quotaDelta)
 | 
					 | 
				
			||||||
	if err != nil {
 | 
					 | 
				
			||||||
		common.SysError("error consuming token remain quota: " + err.Error())
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	err = model.CacheUpdateUserQuota(userId)
 | 
					 | 
				
			||||||
	if err != nil {
 | 
					 | 
				
			||||||
		common.SysError("error update user quota cache: " + err.Error())
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	// totalQuota is total quota consumed
 | 
					 | 
				
			||||||
	if totalQuota != 0 {
 | 
					 | 
				
			||||||
		logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
 | 
					 | 
				
			||||||
		model.RecordConsumeLog(ctx, userId, channelId, totalQuota, 0, modelName, tokenName, totalQuota, logContent)
 | 
					 | 
				
			||||||
		model.UpdateUserUsedQuotaAndRequestCount(userId, totalQuota)
 | 
					 | 
				
			||||||
		model.UpdateChannelUsedQuota(channelId, totalQuota)
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	if totalQuota <= 0 {
 | 
					 | 
				
			||||||
		common.LogError(ctx, fmt.Sprintf("totalQuota consumed is %d, something is wrong", totalQuota))
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func GetAPIVersion(c *gin.Context) string {
 | 
					 | 
				
			||||||
	query := c.Request.URL.Query()
 | 
					 | 
				
			||||||
	apiVersion := query.Get("api-version")
 | 
					 | 
				
			||||||
	if apiVersion == "" {
 | 
					 | 
				
			||||||
		apiVersion = c.GetString("api_version")
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	return apiVersion
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
							
								
								
									
										13
									
								
								relay/channel/openai/util.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										13
									
								
								relay/channel/openai/util.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,13 @@
 | 
				
			|||||||
 | 
					package openai
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func ErrorWrapper(err error, code string, statusCode int) *ErrorWithStatusCode {
 | 
				
			||||||
 | 
						Error := Error{
 | 
				
			||||||
 | 
							Message: err.Error(),
 | 
				
			||||||
 | 
							Type:    "one_api_error",
 | 
				
			||||||
 | 
							Code:    code,
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return &ErrorWithStatusCode{
 | 
				
			||||||
 | 
							Error:      Error,
 | 
				
			||||||
 | 
							StatusCode: statusCode,
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
@@ -1,4 +1,4 @@
 | 
				
			|||||||
package controller
 | 
					package tencent
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import (
 | 
					import (
 | 
				
			||||||
	"bufio"
 | 
						"bufio"
 | 
				
			||||||
@@ -12,6 +12,8 @@ import (
 | 
				
			|||||||
	"io"
 | 
						"io"
 | 
				
			||||||
	"net/http"
 | 
						"net/http"
 | 
				
			||||||
	"one-api/common"
 | 
						"one-api/common"
 | 
				
			||||||
 | 
						"one-api/relay/channel/openai"
 | 
				
			||||||
 | 
						"one-api/relay/constant"
 | 
				
			||||||
	"sort"
 | 
						"sort"
 | 
				
			||||||
	"strconv"
 | 
						"strconv"
 | 
				
			||||||
	"strings"
 | 
						"strings"
 | 
				
			||||||
@@ -19,80 +21,22 @@ import (
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
// https://cloud.tencent.com/document/product/1729/97732
 | 
					// https://cloud.tencent.com/document/product/1729/97732
 | 
				
			||||||
 | 
					
 | 
				
			||||||
type TencentMessage struct {
 | 
					func ConvertRequest(request openai.GeneralOpenAIRequest) *ChatRequest {
 | 
				
			||||||
	Role    string `json:"role"`
 | 
						messages := make([]Message, 0, len(request.Messages))
 | 
				
			||||||
	Content string `json:"content"`
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
type TencentChatRequest struct {
 | 
					 | 
				
			||||||
	AppId    int64  `json:"app_id"`    // 腾讯云账号的 APPID
 | 
					 | 
				
			||||||
	SecretId string `json:"secret_id"` // 官网 SecretId
 | 
					 | 
				
			||||||
	// Timestamp当前 UNIX 时间戳,单位为秒,可记录发起 API 请求的时间。
 | 
					 | 
				
			||||||
	// 例如1529223702,如果与当前时间相差过大,会引起签名过期错误
 | 
					 | 
				
			||||||
	Timestamp int64 `json:"timestamp"`
 | 
					 | 
				
			||||||
	// Expired 签名的有效期,是一个符合 UNIX Epoch 时间戳规范的数值,
 | 
					 | 
				
			||||||
	// 单位为秒;Expired 必须大于 Timestamp 且 Expired-Timestamp 小于90天
 | 
					 | 
				
			||||||
	Expired int64  `json:"expired"`
 | 
					 | 
				
			||||||
	QueryID string `json:"query_id"` //请求 Id,用于问题排查
 | 
					 | 
				
			||||||
	// Temperature 较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定
 | 
					 | 
				
			||||||
	// 默认 1.0,取值区间为[0.0,2.0],非必要不建议使用,不合理的取值会影响效果
 | 
					 | 
				
			||||||
	// 建议该参数和 top_p 只设置1个,不要同时更改 top_p
 | 
					 | 
				
			||||||
	Temperature float64 `json:"temperature"`
 | 
					 | 
				
			||||||
	// TopP 影响输出文本的多样性,取值越大,生成文本的多样性越强
 | 
					 | 
				
			||||||
	// 默认1.0,取值区间为[0.0, 1.0],非必要不建议使用, 不合理的取值会影响效果
 | 
					 | 
				
			||||||
	// 建议该参数和 temperature 只设置1个,不要同时更改
 | 
					 | 
				
			||||||
	TopP float64 `json:"top_p"`
 | 
					 | 
				
			||||||
	// Stream 0:同步,1:流式 (默认,协议:SSE)
 | 
					 | 
				
			||||||
	// 同步请求超时:60s,如果内容较长建议使用流式
 | 
					 | 
				
			||||||
	Stream int `json:"stream"`
 | 
					 | 
				
			||||||
	// Messages 会话内容, 长度最多为40, 按对话时间从旧到新在数组中排列
 | 
					 | 
				
			||||||
	// 输入 content 总数最大支持 3000 token。
 | 
					 | 
				
			||||||
	Messages []TencentMessage `json:"messages"`
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
type TencentError struct {
 | 
					 | 
				
			||||||
	Code    int    `json:"code"`
 | 
					 | 
				
			||||||
	Message string `json:"message"`
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
type TencentUsage struct {
 | 
					 | 
				
			||||||
	InputTokens  int `json:"input_tokens"`
 | 
					 | 
				
			||||||
	OutputTokens int `json:"output_tokens"`
 | 
					 | 
				
			||||||
	TotalTokens  int `json:"total_tokens"`
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
type TencentResponseChoices struct {
 | 
					 | 
				
			||||||
	FinishReason string         `json:"finish_reason,omitempty"` // 流式结束标志位,为 stop 则表示尾包
 | 
					 | 
				
			||||||
	Messages     TencentMessage `json:"messages,omitempty"`      // 内容,同步模式返回内容,流模式为 null 输出 content 内容总数最多支持 1024token。
 | 
					 | 
				
			||||||
	Delta        TencentMessage `json:"delta,omitempty"`         // 内容,流模式返回内容,同步模式为 null 输出 content 内容总数最多支持 1024token。
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
type TencentChatResponse struct {
 | 
					 | 
				
			||||||
	Choices []TencentResponseChoices `json:"choices,omitempty"` // 结果
 | 
					 | 
				
			||||||
	Created string                   `json:"created,omitempty"` // unix 时间戳的字符串
 | 
					 | 
				
			||||||
	Id      string                   `json:"id,omitempty"`      // 会话 id
 | 
					 | 
				
			||||||
	Usage   Usage                    `json:"usage,omitempty"`   // token 数量
 | 
					 | 
				
			||||||
	Error   TencentError             `json:"error,omitempty"`   // 错误信息 注意:此字段可能返回 null,表示取不到有效值
 | 
					 | 
				
			||||||
	Note    string                   `json:"note,omitempty"`    // 注释
 | 
					 | 
				
			||||||
	ReqID   string                   `json:"req_id,omitempty"`  // 唯一请求 Id,每次请求都会返回。用于反馈接口入参
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func requestOpenAI2Tencent(request GeneralOpenAIRequest) *TencentChatRequest {
 | 
					 | 
				
			||||||
	messages := make([]TencentMessage, 0, len(request.Messages))
 | 
					 | 
				
			||||||
	for i := 0; i < len(request.Messages); i++ {
 | 
						for i := 0; i < len(request.Messages); i++ {
 | 
				
			||||||
		message := request.Messages[i]
 | 
							message := request.Messages[i]
 | 
				
			||||||
		if message.Role == "system" {
 | 
							if message.Role == "system" {
 | 
				
			||||||
			messages = append(messages, TencentMessage{
 | 
								messages = append(messages, Message{
 | 
				
			||||||
				Role:    "user",
 | 
									Role:    "user",
 | 
				
			||||||
				Content: message.StringContent(),
 | 
									Content: message.StringContent(),
 | 
				
			||||||
			})
 | 
								})
 | 
				
			||||||
			messages = append(messages, TencentMessage{
 | 
								messages = append(messages, Message{
 | 
				
			||||||
				Role:    "assistant",
 | 
									Role:    "assistant",
 | 
				
			||||||
				Content: "Okay",
 | 
									Content: "Okay",
 | 
				
			||||||
			})
 | 
								})
 | 
				
			||||||
			continue
 | 
								continue
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
		messages = append(messages, TencentMessage{
 | 
							messages = append(messages, Message{
 | 
				
			||||||
			Content: message.StringContent(),
 | 
								Content: message.StringContent(),
 | 
				
			||||||
			Role:    message.Role,
 | 
								Role:    message.Role,
 | 
				
			||||||
		})
 | 
							})
 | 
				
			||||||
@@ -101,7 +45,7 @@ func requestOpenAI2Tencent(request GeneralOpenAIRequest) *TencentChatRequest {
 | 
				
			|||||||
	if request.Stream {
 | 
						if request.Stream {
 | 
				
			||||||
		stream = 1
 | 
							stream = 1
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	return &TencentChatRequest{
 | 
						return &ChatRequest{
 | 
				
			||||||
		Timestamp:   common.GetTimestamp(),
 | 
							Timestamp:   common.GetTimestamp(),
 | 
				
			||||||
		Expired:     common.GetTimestamp() + 24*60*60,
 | 
							Expired:     common.GetTimestamp() + 24*60*60,
 | 
				
			||||||
		QueryID:     common.GetUUID(),
 | 
							QueryID:     common.GetUUID(),
 | 
				
			||||||
@@ -112,16 +56,16 @@ func requestOpenAI2Tencent(request GeneralOpenAIRequest) *TencentChatRequest {
 | 
				
			|||||||
	}
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func responseTencent2OpenAI(response *TencentChatResponse) *OpenAITextResponse {
 | 
					func responseTencent2OpenAI(response *ChatResponse) *openai.TextResponse {
 | 
				
			||||||
	fullTextResponse := OpenAITextResponse{
 | 
						fullTextResponse := openai.TextResponse{
 | 
				
			||||||
		Object:  "chat.completion",
 | 
							Object:  "chat.completion",
 | 
				
			||||||
		Created: common.GetTimestamp(),
 | 
							Created: common.GetTimestamp(),
 | 
				
			||||||
		Usage:   response.Usage,
 | 
							Usage:   response.Usage,
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	if len(response.Choices) > 0 {
 | 
						if len(response.Choices) > 0 {
 | 
				
			||||||
		choice := OpenAITextResponseChoice{
 | 
							choice := openai.TextResponseChoice{
 | 
				
			||||||
			Index: 0,
 | 
								Index: 0,
 | 
				
			||||||
			Message: Message{
 | 
								Message: openai.Message{
 | 
				
			||||||
				Role:    "assistant",
 | 
									Role:    "assistant",
 | 
				
			||||||
				Content: response.Choices[0].Messages.Content,
 | 
									Content: response.Choices[0].Messages.Content,
 | 
				
			||||||
			},
 | 
								},
 | 
				
			||||||
@@ -132,24 +76,24 @@ func responseTencent2OpenAI(response *TencentChatResponse) *OpenAITextResponse {
 | 
				
			|||||||
	return &fullTextResponse
 | 
						return &fullTextResponse
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func streamResponseTencent2OpenAI(TencentResponse *TencentChatResponse) *ChatCompletionsStreamResponse {
 | 
					func streamResponseTencent2OpenAI(TencentResponse *ChatResponse) *openai.ChatCompletionsStreamResponse {
 | 
				
			||||||
	response := ChatCompletionsStreamResponse{
 | 
						response := openai.ChatCompletionsStreamResponse{
 | 
				
			||||||
		Object:  "chat.completion.chunk",
 | 
							Object:  "chat.completion.chunk",
 | 
				
			||||||
		Created: common.GetTimestamp(),
 | 
							Created: common.GetTimestamp(),
 | 
				
			||||||
		Model:   "tencent-hunyuan",
 | 
							Model:   "tencent-hunyuan",
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	if len(TencentResponse.Choices) > 0 {
 | 
						if len(TencentResponse.Choices) > 0 {
 | 
				
			||||||
		var choice ChatCompletionsStreamResponseChoice
 | 
							var choice openai.ChatCompletionsStreamResponseChoice
 | 
				
			||||||
		choice.Delta.Content = TencentResponse.Choices[0].Delta.Content
 | 
							choice.Delta.Content = TencentResponse.Choices[0].Delta.Content
 | 
				
			||||||
		if TencentResponse.Choices[0].FinishReason == "stop" {
 | 
							if TencentResponse.Choices[0].FinishReason == "stop" {
 | 
				
			||||||
			choice.FinishReason = &stopFinishReason
 | 
								choice.FinishReason = &constant.StopFinishReason
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
		response.Choices = append(response.Choices, choice)
 | 
							response.Choices = append(response.Choices, choice)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	return &response
 | 
						return &response
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func tencentStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, string) {
 | 
					func StreamHandler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, string) {
 | 
				
			||||||
	var responseText string
 | 
						var responseText string
 | 
				
			||||||
	scanner := bufio.NewScanner(resp.Body)
 | 
						scanner := bufio.NewScanner(resp.Body)
 | 
				
			||||||
	scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
 | 
						scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
 | 
				
			||||||
@@ -180,11 +124,11 @@ func tencentStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWith
 | 
				
			|||||||
		}
 | 
							}
 | 
				
			||||||
		stopChan <- true
 | 
							stopChan <- true
 | 
				
			||||||
	}()
 | 
						}()
 | 
				
			||||||
	setEventStreamHeaders(c)
 | 
						common.SetEventStreamHeaders(c)
 | 
				
			||||||
	c.Stream(func(w io.Writer) bool {
 | 
						c.Stream(func(w io.Writer) bool {
 | 
				
			||||||
		select {
 | 
							select {
 | 
				
			||||||
		case data := <-dataChan:
 | 
							case data := <-dataChan:
 | 
				
			||||||
			var TencentResponse TencentChatResponse
 | 
								var TencentResponse ChatResponse
 | 
				
			||||||
			err := json.Unmarshal([]byte(data), &TencentResponse)
 | 
								err := json.Unmarshal([]byte(data), &TencentResponse)
 | 
				
			||||||
			if err != nil {
 | 
								if err != nil {
 | 
				
			||||||
				common.SysError("error unmarshalling stream response: " + err.Error())
 | 
									common.SysError("error unmarshalling stream response: " + err.Error())
 | 
				
			||||||
@@ -208,28 +152,28 @@ func tencentStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWith
 | 
				
			|||||||
	})
 | 
						})
 | 
				
			||||||
	err := resp.Body.Close()
 | 
						err := resp.Body.Close()
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), ""
 | 
							return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), ""
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	return nil, responseText
 | 
						return nil, responseText
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func tencentHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) {
 | 
					func Handler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, *openai.Usage) {
 | 
				
			||||||
	var TencentResponse TencentChatResponse
 | 
						var TencentResponse ChatResponse
 | 
				
			||||||
	responseBody, err := io.ReadAll(resp.Body)
 | 
						responseBody, err := io.ReadAll(resp.Body)
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
 | 
							return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	err = resp.Body.Close()
 | 
						err = resp.Body.Close()
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
 | 
							return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	err = json.Unmarshal(responseBody, &TencentResponse)
 | 
						err = json.Unmarshal(responseBody, &TencentResponse)
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
 | 
							return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	if TencentResponse.Error.Code != 0 {
 | 
						if TencentResponse.Error.Code != 0 {
 | 
				
			||||||
		return &OpenAIErrorWithStatusCode{
 | 
							return &openai.ErrorWithStatusCode{
 | 
				
			||||||
			OpenAIError: OpenAIError{
 | 
								Error: openai.Error{
 | 
				
			||||||
				Message: TencentResponse.Error.Message,
 | 
									Message: TencentResponse.Error.Message,
 | 
				
			||||||
				Code:    TencentResponse.Error.Code,
 | 
									Code:    TencentResponse.Error.Code,
 | 
				
			||||||
			},
 | 
								},
 | 
				
			||||||
@@ -240,7 +184,7 @@ func tencentHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatus
 | 
				
			|||||||
	fullTextResponse.Model = "hunyuan"
 | 
						fullTextResponse.Model = "hunyuan"
 | 
				
			||||||
	jsonResponse, err := json.Marshal(fullTextResponse)
 | 
						jsonResponse, err := json.Marshal(fullTextResponse)
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
 | 
							return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	c.Writer.Header().Set("Content-Type", "application/json")
 | 
						c.Writer.Header().Set("Content-Type", "application/json")
 | 
				
			||||||
	c.Writer.WriteHeader(resp.StatusCode)
 | 
						c.Writer.WriteHeader(resp.StatusCode)
 | 
				
			||||||
@@ -248,7 +192,7 @@ func tencentHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatus
 | 
				
			|||||||
	return nil, &fullTextResponse.Usage
 | 
						return nil, &fullTextResponse.Usage
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func parseTencentConfig(config string) (appId int64, secretId string, secretKey string, err error) {
 | 
					func ParseConfig(config string) (appId int64, secretId string, secretKey string, err error) {
 | 
				
			||||||
	parts := strings.Split(config, "|")
 | 
						parts := strings.Split(config, "|")
 | 
				
			||||||
	if len(parts) != 3 {
 | 
						if len(parts) != 3 {
 | 
				
			||||||
		err = errors.New("invalid tencent config")
 | 
							err = errors.New("invalid tencent config")
 | 
				
			||||||
@@ -260,7 +204,7 @@ func parseTencentConfig(config string) (appId int64, secretId string, secretKey
 | 
				
			|||||||
	return
 | 
						return
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func getTencentSign(req TencentChatRequest, secretKey string) string {
 | 
					func GetSign(req ChatRequest, secretKey string) string {
 | 
				
			||||||
	params := make([]string, 0)
 | 
						params := make([]string, 0)
 | 
				
			||||||
	params = append(params, "app_id="+strconv.FormatInt(req.AppId, 10))
 | 
						params = append(params, "app_id="+strconv.FormatInt(req.AppId, 10))
 | 
				
			||||||
	params = append(params, "secret_id="+req.SecretId)
 | 
						params = append(params, "secret_id="+req.SecretId)
 | 
				
			||||||
							
								
								
									
										63
									
								
								relay/channel/tencent/model.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										63
									
								
								relay/channel/tencent/model.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,63 @@
 | 
				
			|||||||
 | 
					package tencent
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import (
 | 
				
			||||||
 | 
						"one-api/relay/channel/openai"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type Message struct {
 | 
				
			||||||
 | 
						Role    string `json:"role"`
 | 
				
			||||||
 | 
						Content string `json:"content"`
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type ChatRequest struct {
 | 
				
			||||||
 | 
						AppId    int64  `json:"app_id"`    // 腾讯云账号的 APPID
 | 
				
			||||||
 | 
						SecretId string `json:"secret_id"` // 官网 SecretId
 | 
				
			||||||
 | 
						// Timestamp当前 UNIX 时间戳,单位为秒,可记录发起 API 请求的时间。
 | 
				
			||||||
 | 
						// 例如1529223702,如果与当前时间相差过大,会引起签名过期错误
 | 
				
			||||||
 | 
						Timestamp int64 `json:"timestamp"`
 | 
				
			||||||
 | 
						// Expired 签名的有效期,是一个符合 UNIX Epoch 时间戳规范的数值,
 | 
				
			||||||
 | 
						// 单位为秒;Expired 必须大于 Timestamp 且 Expired-Timestamp 小于90天
 | 
				
			||||||
 | 
						Expired int64  `json:"expired"`
 | 
				
			||||||
 | 
						QueryID string `json:"query_id"` //请求 Id,用于问题排查
 | 
				
			||||||
 | 
						// Temperature 较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定
 | 
				
			||||||
 | 
						// 默认 1.0,取值区间为[0.0,2.0],非必要不建议使用,不合理的取值会影响效果
 | 
				
			||||||
 | 
						// 建议该参数和 top_p 只设置1个,不要同时更改 top_p
 | 
				
			||||||
 | 
						Temperature float64 `json:"temperature"`
 | 
				
			||||||
 | 
						// TopP 影响输出文本的多样性,取值越大,生成文本的多样性越强
 | 
				
			||||||
 | 
						// 默认1.0,取值区间为[0.0, 1.0],非必要不建议使用, 不合理的取值会影响效果
 | 
				
			||||||
 | 
						// 建议该参数和 temperature 只设置1个,不要同时更改
 | 
				
			||||||
 | 
						TopP float64 `json:"top_p"`
 | 
				
			||||||
 | 
						// Stream 0:同步,1:流式 (默认,协议:SSE)
 | 
				
			||||||
 | 
						// 同步请求超时:60s,如果内容较长建议使用流式
 | 
				
			||||||
 | 
						Stream int `json:"stream"`
 | 
				
			||||||
 | 
						// Messages 会话内容, 长度最多为40, 按对话时间从旧到新在数组中排列
 | 
				
			||||||
 | 
						// 输入 content 总数最大支持 3000 token。
 | 
				
			||||||
 | 
						Messages []Message `json:"messages"`
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type Error struct {
 | 
				
			||||||
 | 
						Code    int    `json:"code"`
 | 
				
			||||||
 | 
						Message string `json:"message"`
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type Usage struct {
 | 
				
			||||||
 | 
						InputTokens  int `json:"input_tokens"`
 | 
				
			||||||
 | 
						OutputTokens int `json:"output_tokens"`
 | 
				
			||||||
 | 
						TotalTokens  int `json:"total_tokens"`
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type ResponseChoices struct {
 | 
				
			||||||
 | 
						FinishReason string  `json:"finish_reason,omitempty"` // 流式结束标志位,为 stop 则表示尾包
 | 
				
			||||||
 | 
						Messages     Message `json:"messages,omitempty"`      // 内容,同步模式返回内容,流模式为 null 输出 content 内容总数最多支持 1024token。
 | 
				
			||||||
 | 
						Delta        Message `json:"delta,omitempty"`         // 内容,流模式返回内容,同步模式为 null 输出 content 内容总数最多支持 1024token。
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type ChatResponse struct {
 | 
				
			||||||
 | 
						Choices []ResponseChoices `json:"choices,omitempty"` // 结果
 | 
				
			||||||
 | 
						Created string            `json:"created,omitempty"` // unix 时间戳的字符串
 | 
				
			||||||
 | 
						Id      string            `json:"id,omitempty"`      // 会话 id
 | 
				
			||||||
 | 
						Usage   openai.Usage      `json:"usage,omitempty"`   // token 数量
 | 
				
			||||||
 | 
						Error   Error             `json:"error,omitempty"`   // 错误信息 注意:此字段可能返回 null,表示取不到有效值
 | 
				
			||||||
 | 
						Note    string            `json:"note,omitempty"`    // 注释
 | 
				
			||||||
 | 
						ReqID   string            `json:"req_id,omitempty"`  // 唯一请求 Id,每次请求都会返回。用于反馈接口入参
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
@@ -1,4 +1,4 @@
 | 
				
			|||||||
package controller
 | 
					package xunfei
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import (
 | 
					import (
 | 
				
			||||||
	"crypto/hmac"
 | 
						"crypto/hmac"
 | 
				
			||||||
@@ -12,6 +12,8 @@ import (
 | 
				
			|||||||
	"net/http"
 | 
						"net/http"
 | 
				
			||||||
	"net/url"
 | 
						"net/url"
 | 
				
			||||||
	"one-api/common"
 | 
						"one-api/common"
 | 
				
			||||||
 | 
						"one-api/relay/channel/openai"
 | 
				
			||||||
 | 
						"one-api/relay/constant"
 | 
				
			||||||
	"strings"
 | 
						"strings"
 | 
				
			||||||
	"time"
 | 
						"time"
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
@@ -19,82 +21,26 @@ import (
 | 
				
			|||||||
// https://console.xfyun.cn/services/cbm
 | 
					// https://console.xfyun.cn/services/cbm
 | 
				
			||||||
// https://www.xfyun.cn/doc/spark/Web.html
 | 
					// https://www.xfyun.cn/doc/spark/Web.html
 | 
				
			||||||
 | 
					
 | 
				
			||||||
type XunfeiMessage struct {
 | 
					func requestOpenAI2Xunfei(request openai.GeneralOpenAIRequest, xunfeiAppId string, domain string) *ChatRequest {
 | 
				
			||||||
	Role    string `json:"role"`
 | 
						messages := make([]Message, 0, len(request.Messages))
 | 
				
			||||||
	Content string `json:"content"`
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
type XunfeiChatRequest struct {
 | 
					 | 
				
			||||||
	Header struct {
 | 
					 | 
				
			||||||
		AppId string `json:"app_id"`
 | 
					 | 
				
			||||||
	} `json:"header"`
 | 
					 | 
				
			||||||
	Parameter struct {
 | 
					 | 
				
			||||||
		Chat struct {
 | 
					 | 
				
			||||||
			Domain      string  `json:"domain,omitempty"`
 | 
					 | 
				
			||||||
			Temperature float64 `json:"temperature,omitempty"`
 | 
					 | 
				
			||||||
			TopK        int     `json:"top_k,omitempty"`
 | 
					 | 
				
			||||||
			MaxTokens   int     `json:"max_tokens,omitempty"`
 | 
					 | 
				
			||||||
			Auditing    bool    `json:"auditing,omitempty"`
 | 
					 | 
				
			||||||
		} `json:"chat"`
 | 
					 | 
				
			||||||
	} `json:"parameter"`
 | 
					 | 
				
			||||||
	Payload struct {
 | 
					 | 
				
			||||||
		Message struct {
 | 
					 | 
				
			||||||
			Text []XunfeiMessage `json:"text"`
 | 
					 | 
				
			||||||
		} `json:"message"`
 | 
					 | 
				
			||||||
	} `json:"payload"`
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
type XunfeiChatResponseTextItem struct {
 | 
					 | 
				
			||||||
	Content string `json:"content"`
 | 
					 | 
				
			||||||
	Role    string `json:"role"`
 | 
					 | 
				
			||||||
	Index   int    `json:"index"`
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
type XunfeiChatResponse struct {
 | 
					 | 
				
			||||||
	Header struct {
 | 
					 | 
				
			||||||
		Code    int    `json:"code"`
 | 
					 | 
				
			||||||
		Message string `json:"message"`
 | 
					 | 
				
			||||||
		Sid     string `json:"sid"`
 | 
					 | 
				
			||||||
		Status  int    `json:"status"`
 | 
					 | 
				
			||||||
	} `json:"header"`
 | 
					 | 
				
			||||||
	Payload struct {
 | 
					 | 
				
			||||||
		Choices struct {
 | 
					 | 
				
			||||||
			Status int                          `json:"status"`
 | 
					 | 
				
			||||||
			Seq    int                          `json:"seq"`
 | 
					 | 
				
			||||||
			Text   []XunfeiChatResponseTextItem `json:"text"`
 | 
					 | 
				
			||||||
		} `json:"choices"`
 | 
					 | 
				
			||||||
		Usage struct {
 | 
					 | 
				
			||||||
			//Text struct {
 | 
					 | 
				
			||||||
			//	QuestionTokens   string `json:"question_tokens"`
 | 
					 | 
				
			||||||
			//	PromptTokens     string `json:"prompt_tokens"`
 | 
					 | 
				
			||||||
			//	CompletionTokens string `json:"completion_tokens"`
 | 
					 | 
				
			||||||
			//	TotalTokens      string `json:"total_tokens"`
 | 
					 | 
				
			||||||
			//} `json:"text"`
 | 
					 | 
				
			||||||
			Text Usage `json:"text"`
 | 
					 | 
				
			||||||
		} `json:"usage"`
 | 
					 | 
				
			||||||
	} `json:"payload"`
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func requestOpenAI2Xunfei(request GeneralOpenAIRequest, xunfeiAppId string, domain string) *XunfeiChatRequest {
 | 
					 | 
				
			||||||
	messages := make([]XunfeiMessage, 0, len(request.Messages))
 | 
					 | 
				
			||||||
	for _, message := range request.Messages {
 | 
						for _, message := range request.Messages {
 | 
				
			||||||
		if message.Role == "system" {
 | 
							if message.Role == "system" {
 | 
				
			||||||
			messages = append(messages, XunfeiMessage{
 | 
								messages = append(messages, Message{
 | 
				
			||||||
				Role:    "user",
 | 
									Role:    "user",
 | 
				
			||||||
				Content: message.StringContent(),
 | 
									Content: message.StringContent(),
 | 
				
			||||||
			})
 | 
								})
 | 
				
			||||||
			messages = append(messages, XunfeiMessage{
 | 
								messages = append(messages, Message{
 | 
				
			||||||
				Role:    "assistant",
 | 
									Role:    "assistant",
 | 
				
			||||||
				Content: "Okay",
 | 
									Content: "Okay",
 | 
				
			||||||
			})
 | 
								})
 | 
				
			||||||
		} else {
 | 
							} else {
 | 
				
			||||||
			messages = append(messages, XunfeiMessage{
 | 
								messages = append(messages, Message{
 | 
				
			||||||
				Role:    message.Role,
 | 
									Role:    message.Role,
 | 
				
			||||||
				Content: message.StringContent(),
 | 
									Content: message.StringContent(),
 | 
				
			||||||
			})
 | 
								})
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	xunfeiRequest := XunfeiChatRequest{}
 | 
						xunfeiRequest := ChatRequest{}
 | 
				
			||||||
	xunfeiRequest.Header.AppId = xunfeiAppId
 | 
						xunfeiRequest.Header.AppId = xunfeiAppId
 | 
				
			||||||
	xunfeiRequest.Parameter.Chat.Domain = domain
 | 
						xunfeiRequest.Parameter.Chat.Domain = domain
 | 
				
			||||||
	xunfeiRequest.Parameter.Chat.Temperature = request.Temperature
 | 
						xunfeiRequest.Parameter.Chat.Temperature = request.Temperature
 | 
				
			||||||
@@ -104,49 +50,49 @@ func requestOpenAI2Xunfei(request GeneralOpenAIRequest, xunfeiAppId string, doma
 | 
				
			|||||||
	return &xunfeiRequest
 | 
						return &xunfeiRequest
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func responseXunfei2OpenAI(response *XunfeiChatResponse) *OpenAITextResponse {
 | 
					func responseXunfei2OpenAI(response *ChatResponse) *openai.TextResponse {
 | 
				
			||||||
	if len(response.Payload.Choices.Text) == 0 {
 | 
						if len(response.Payload.Choices.Text) == 0 {
 | 
				
			||||||
		response.Payload.Choices.Text = []XunfeiChatResponseTextItem{
 | 
							response.Payload.Choices.Text = []ChatResponseTextItem{
 | 
				
			||||||
			{
 | 
								{
 | 
				
			||||||
				Content: "",
 | 
									Content: "",
 | 
				
			||||||
			},
 | 
								},
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	choice := OpenAITextResponseChoice{
 | 
						choice := openai.TextResponseChoice{
 | 
				
			||||||
		Index: 0,
 | 
							Index: 0,
 | 
				
			||||||
		Message: Message{
 | 
							Message: openai.Message{
 | 
				
			||||||
			Role:    "assistant",
 | 
								Role:    "assistant",
 | 
				
			||||||
			Content: response.Payload.Choices.Text[0].Content,
 | 
								Content: response.Payload.Choices.Text[0].Content,
 | 
				
			||||||
		},
 | 
							},
 | 
				
			||||||
		FinishReason: stopFinishReason,
 | 
							FinishReason: constant.StopFinishReason,
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	fullTextResponse := OpenAITextResponse{
 | 
						fullTextResponse := openai.TextResponse{
 | 
				
			||||||
		Object:  "chat.completion",
 | 
							Object:  "chat.completion",
 | 
				
			||||||
		Created: common.GetTimestamp(),
 | 
							Created: common.GetTimestamp(),
 | 
				
			||||||
		Choices: []OpenAITextResponseChoice{choice},
 | 
							Choices: []openai.TextResponseChoice{choice},
 | 
				
			||||||
		Usage:   response.Payload.Usage.Text,
 | 
							Usage:   response.Payload.Usage.Text,
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	return &fullTextResponse
 | 
						return &fullTextResponse
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func streamResponseXunfei2OpenAI(xunfeiResponse *XunfeiChatResponse) *ChatCompletionsStreamResponse {
 | 
					func streamResponseXunfei2OpenAI(xunfeiResponse *ChatResponse) *openai.ChatCompletionsStreamResponse {
 | 
				
			||||||
	if len(xunfeiResponse.Payload.Choices.Text) == 0 {
 | 
						if len(xunfeiResponse.Payload.Choices.Text) == 0 {
 | 
				
			||||||
		xunfeiResponse.Payload.Choices.Text = []XunfeiChatResponseTextItem{
 | 
							xunfeiResponse.Payload.Choices.Text = []ChatResponseTextItem{
 | 
				
			||||||
			{
 | 
								{
 | 
				
			||||||
				Content: "",
 | 
									Content: "",
 | 
				
			||||||
			},
 | 
								},
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	var choice ChatCompletionsStreamResponseChoice
 | 
						var choice openai.ChatCompletionsStreamResponseChoice
 | 
				
			||||||
	choice.Delta.Content = xunfeiResponse.Payload.Choices.Text[0].Content
 | 
						choice.Delta.Content = xunfeiResponse.Payload.Choices.Text[0].Content
 | 
				
			||||||
	if xunfeiResponse.Payload.Choices.Status == 2 {
 | 
						if xunfeiResponse.Payload.Choices.Status == 2 {
 | 
				
			||||||
		choice.FinishReason = &stopFinishReason
 | 
							choice.FinishReason = &constant.StopFinishReason
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	response := ChatCompletionsStreamResponse{
 | 
						response := openai.ChatCompletionsStreamResponse{
 | 
				
			||||||
		Object:  "chat.completion.chunk",
 | 
							Object:  "chat.completion.chunk",
 | 
				
			||||||
		Created: common.GetTimestamp(),
 | 
							Created: common.GetTimestamp(),
 | 
				
			||||||
		Model:   "SparkDesk",
 | 
							Model:   "SparkDesk",
 | 
				
			||||||
		Choices: []ChatCompletionsStreamResponseChoice{choice},
 | 
							Choices: []openai.ChatCompletionsStreamResponseChoice{choice},
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	return &response
 | 
						return &response
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
@@ -177,14 +123,14 @@ func buildXunfeiAuthUrl(hostUrl string, apiKey, apiSecret string) string {
 | 
				
			|||||||
	return callUrl
 | 
						return callUrl
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func xunfeiStreamHandler(c *gin.Context, textRequest GeneralOpenAIRequest, appId string, apiSecret string, apiKey string) (*OpenAIErrorWithStatusCode, *Usage) {
 | 
					func StreamHandler(c *gin.Context, textRequest openai.GeneralOpenAIRequest, appId string, apiSecret string, apiKey string) (*openai.ErrorWithStatusCode, *openai.Usage) {
 | 
				
			||||||
	domain, authUrl := getXunfeiAuthUrl(c, apiKey, apiSecret)
 | 
						domain, authUrl := getXunfeiAuthUrl(c, apiKey, apiSecret)
 | 
				
			||||||
	dataChan, stopChan, err := xunfeiMakeRequest(textRequest, domain, authUrl, appId)
 | 
						dataChan, stopChan, err := xunfeiMakeRequest(textRequest, domain, authUrl, appId)
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return errorWrapper(err, "make xunfei request err", http.StatusInternalServerError), nil
 | 
							return openai.ErrorWrapper(err, "make xunfei request err", http.StatusInternalServerError), nil
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	setEventStreamHeaders(c)
 | 
						common.SetEventStreamHeaders(c)
 | 
				
			||||||
	var usage Usage
 | 
						var usage openai.Usage
 | 
				
			||||||
	c.Stream(func(w io.Writer) bool {
 | 
						c.Stream(func(w io.Writer) bool {
 | 
				
			||||||
		select {
 | 
							select {
 | 
				
			||||||
		case xunfeiResponse := <-dataChan:
 | 
							case xunfeiResponse := <-dataChan:
 | 
				
			||||||
@@ -207,15 +153,15 @@ func xunfeiStreamHandler(c *gin.Context, textRequest GeneralOpenAIRequest, appId
 | 
				
			|||||||
	return nil, &usage
 | 
						return nil, &usage
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func xunfeiHandler(c *gin.Context, textRequest GeneralOpenAIRequest, appId string, apiSecret string, apiKey string) (*OpenAIErrorWithStatusCode, *Usage) {
 | 
					func Handler(c *gin.Context, textRequest openai.GeneralOpenAIRequest, appId string, apiSecret string, apiKey string) (*openai.ErrorWithStatusCode, *openai.Usage) {
 | 
				
			||||||
	domain, authUrl := getXunfeiAuthUrl(c, apiKey, apiSecret)
 | 
						domain, authUrl := getXunfeiAuthUrl(c, apiKey, apiSecret)
 | 
				
			||||||
	dataChan, stopChan, err := xunfeiMakeRequest(textRequest, domain, authUrl, appId)
 | 
						dataChan, stopChan, err := xunfeiMakeRequest(textRequest, domain, authUrl, appId)
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return errorWrapper(err, "make xunfei request err", http.StatusInternalServerError), nil
 | 
							return openai.ErrorWrapper(err, "make xunfei request err", http.StatusInternalServerError), nil
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	var usage Usage
 | 
						var usage openai.Usage
 | 
				
			||||||
	var content string
 | 
						var content string
 | 
				
			||||||
	var xunfeiResponse XunfeiChatResponse
 | 
						var xunfeiResponse ChatResponse
 | 
				
			||||||
	stop := false
 | 
						stop := false
 | 
				
			||||||
	for !stop {
 | 
						for !stop {
 | 
				
			||||||
		select {
 | 
							select {
 | 
				
			||||||
@@ -231,7 +177,7 @@ func xunfeiHandler(c *gin.Context, textRequest GeneralOpenAIRequest, appId strin
 | 
				
			|||||||
		}
 | 
							}
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	if len(xunfeiResponse.Payload.Choices.Text) == 0 {
 | 
						if len(xunfeiResponse.Payload.Choices.Text) == 0 {
 | 
				
			||||||
		xunfeiResponse.Payload.Choices.Text = []XunfeiChatResponseTextItem{
 | 
							xunfeiResponse.Payload.Choices.Text = []ChatResponseTextItem{
 | 
				
			||||||
			{
 | 
								{
 | 
				
			||||||
				Content: "",
 | 
									Content: "",
 | 
				
			||||||
			},
 | 
								},
 | 
				
			||||||
@@ -242,14 +188,14 @@ func xunfeiHandler(c *gin.Context, textRequest GeneralOpenAIRequest, appId strin
 | 
				
			|||||||
	response := responseXunfei2OpenAI(&xunfeiResponse)
 | 
						response := responseXunfei2OpenAI(&xunfeiResponse)
 | 
				
			||||||
	jsonResponse, err := json.Marshal(response)
 | 
						jsonResponse, err := json.Marshal(response)
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
 | 
							return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	c.Writer.Header().Set("Content-Type", "application/json")
 | 
						c.Writer.Header().Set("Content-Type", "application/json")
 | 
				
			||||||
	_, _ = c.Writer.Write(jsonResponse)
 | 
						_, _ = c.Writer.Write(jsonResponse)
 | 
				
			||||||
	return nil, &usage
 | 
						return nil, &usage
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func xunfeiMakeRequest(textRequest GeneralOpenAIRequest, domain, authUrl, appId string) (chan XunfeiChatResponse, chan bool, error) {
 | 
					func xunfeiMakeRequest(textRequest openai.GeneralOpenAIRequest, domain, authUrl, appId string) (chan ChatResponse, chan bool, error) {
 | 
				
			||||||
	d := websocket.Dialer{
 | 
						d := websocket.Dialer{
 | 
				
			||||||
		HandshakeTimeout: 5 * time.Second,
 | 
							HandshakeTimeout: 5 * time.Second,
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
@@ -263,7 +209,7 @@ func xunfeiMakeRequest(textRequest GeneralOpenAIRequest, domain, authUrl, appId
 | 
				
			|||||||
		return nil, nil, err
 | 
							return nil, nil, err
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	dataChan := make(chan XunfeiChatResponse)
 | 
						dataChan := make(chan ChatResponse)
 | 
				
			||||||
	stopChan := make(chan bool)
 | 
						stopChan := make(chan bool)
 | 
				
			||||||
	go func() {
 | 
						go func() {
 | 
				
			||||||
		for {
 | 
							for {
 | 
				
			||||||
@@ -272,7 +218,7 @@ func xunfeiMakeRequest(textRequest GeneralOpenAIRequest, domain, authUrl, appId
 | 
				
			|||||||
				common.SysError("error reading stream response: " + err.Error())
 | 
									common.SysError("error reading stream response: " + err.Error())
 | 
				
			||||||
				break
 | 
									break
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
			var response XunfeiChatResponse
 | 
								var response ChatResponse
 | 
				
			||||||
			err = json.Unmarshal(msg, &response)
 | 
								err = json.Unmarshal(msg, &response)
 | 
				
			||||||
			if err != nil {
 | 
								if err != nil {
 | 
				
			||||||
				common.SysError("error unmarshalling stream response: " + err.Error())
 | 
									common.SysError("error unmarshalling stream response: " + err.Error())
 | 
				
			||||||
							
								
								
									
										61
									
								
								relay/channel/xunfei/model.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										61
									
								
								relay/channel/xunfei/model.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,61 @@
 | 
				
			|||||||
 | 
					package xunfei
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import (
 | 
				
			||||||
 | 
						"one-api/relay/channel/openai"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type Message struct {
 | 
				
			||||||
 | 
						Role    string `json:"role"`
 | 
				
			||||||
 | 
						Content string `json:"content"`
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type ChatRequest struct {
 | 
				
			||||||
 | 
						Header struct {
 | 
				
			||||||
 | 
							AppId string `json:"app_id"`
 | 
				
			||||||
 | 
						} `json:"header"`
 | 
				
			||||||
 | 
						Parameter struct {
 | 
				
			||||||
 | 
							Chat struct {
 | 
				
			||||||
 | 
								Domain      string  `json:"domain,omitempty"`
 | 
				
			||||||
 | 
								Temperature float64 `json:"temperature,omitempty"`
 | 
				
			||||||
 | 
								TopK        int     `json:"top_k,omitempty"`
 | 
				
			||||||
 | 
								MaxTokens   int     `json:"max_tokens,omitempty"`
 | 
				
			||||||
 | 
								Auditing    bool    `json:"auditing,omitempty"`
 | 
				
			||||||
 | 
							} `json:"chat"`
 | 
				
			||||||
 | 
						} `json:"parameter"`
 | 
				
			||||||
 | 
						Payload struct {
 | 
				
			||||||
 | 
							Message struct {
 | 
				
			||||||
 | 
								Text []Message `json:"text"`
 | 
				
			||||||
 | 
							} `json:"message"`
 | 
				
			||||||
 | 
						} `json:"payload"`
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type ChatResponseTextItem struct {
 | 
				
			||||||
 | 
						Content string `json:"content"`
 | 
				
			||||||
 | 
						Role    string `json:"role"`
 | 
				
			||||||
 | 
						Index   int    `json:"index"`
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type ChatResponse struct {
 | 
				
			||||||
 | 
						Header struct {
 | 
				
			||||||
 | 
							Code    int    `json:"code"`
 | 
				
			||||||
 | 
							Message string `json:"message"`
 | 
				
			||||||
 | 
							Sid     string `json:"sid"`
 | 
				
			||||||
 | 
							Status  int    `json:"status"`
 | 
				
			||||||
 | 
						} `json:"header"`
 | 
				
			||||||
 | 
						Payload struct {
 | 
				
			||||||
 | 
							Choices struct {
 | 
				
			||||||
 | 
								Status int                    `json:"status"`
 | 
				
			||||||
 | 
								Seq    int                    `json:"seq"`
 | 
				
			||||||
 | 
								Text   []ChatResponseTextItem `json:"text"`
 | 
				
			||||||
 | 
							} `json:"choices"`
 | 
				
			||||||
 | 
							Usage struct {
 | 
				
			||||||
 | 
								//Text struct {
 | 
				
			||||||
 | 
								//	QuestionTokens   string `json:"question_tokens"`
 | 
				
			||||||
 | 
								//	PromptTokens     string `json:"prompt_tokens"`
 | 
				
			||||||
 | 
								//	CompletionTokens string `json:"completion_tokens"`
 | 
				
			||||||
 | 
								//	TotalTokens      string `json:"total_tokens"`
 | 
				
			||||||
 | 
								//} `json:"text"`
 | 
				
			||||||
 | 
								Text openai.Usage `json:"text"`
 | 
				
			||||||
 | 
							} `json:"usage"`
 | 
				
			||||||
 | 
						} `json:"payload"`
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
@@ -1,4 +1,4 @@
 | 
				
			|||||||
package controller
 | 
					package zhipu
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import (
 | 
					import (
 | 
				
			||||||
	"bufio"
 | 
						"bufio"
 | 
				
			||||||
@@ -8,6 +8,8 @@ import (
 | 
				
			|||||||
	"io"
 | 
						"io"
 | 
				
			||||||
	"net/http"
 | 
						"net/http"
 | 
				
			||||||
	"one-api/common"
 | 
						"one-api/common"
 | 
				
			||||||
 | 
						"one-api/relay/channel/openai"
 | 
				
			||||||
 | 
						"one-api/relay/constant"
 | 
				
			||||||
	"strings"
 | 
						"strings"
 | 
				
			||||||
	"sync"
 | 
						"sync"
 | 
				
			||||||
	"time"
 | 
						"time"
 | 
				
			||||||
@@ -18,53 +20,13 @@ import (
 | 
				
			|||||||
// https://open.bigmodel.cn/api/paas/v3/model-api/chatglm_std/invoke
 | 
					// https://open.bigmodel.cn/api/paas/v3/model-api/chatglm_std/invoke
 | 
				
			||||||
// https://open.bigmodel.cn/api/paas/v3/model-api/chatglm_std/sse-invoke
 | 
					// https://open.bigmodel.cn/api/paas/v3/model-api/chatglm_std/sse-invoke
 | 
				
			||||||
 | 
					
 | 
				
			||||||
type ZhipuMessage struct {
 | 
					 | 
				
			||||||
	Role    string `json:"role"`
 | 
					 | 
				
			||||||
	Content string `json:"content"`
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
type ZhipuRequest struct {
 | 
					 | 
				
			||||||
	Prompt      []ZhipuMessage `json:"prompt"`
 | 
					 | 
				
			||||||
	Temperature float64        `json:"temperature,omitempty"`
 | 
					 | 
				
			||||||
	TopP        float64        `json:"top_p,omitempty"`
 | 
					 | 
				
			||||||
	RequestId   string         `json:"request_id,omitempty"`
 | 
					 | 
				
			||||||
	Incremental bool           `json:"incremental,omitempty"`
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
type ZhipuResponseData struct {
 | 
					 | 
				
			||||||
	TaskId     string         `json:"task_id"`
 | 
					 | 
				
			||||||
	RequestId  string         `json:"request_id"`
 | 
					 | 
				
			||||||
	TaskStatus string         `json:"task_status"`
 | 
					 | 
				
			||||||
	Choices    []ZhipuMessage `json:"choices"`
 | 
					 | 
				
			||||||
	Usage      `json:"usage"`
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
type ZhipuResponse struct {
 | 
					 | 
				
			||||||
	Code    int               `json:"code"`
 | 
					 | 
				
			||||||
	Msg     string            `json:"msg"`
 | 
					 | 
				
			||||||
	Success bool              `json:"success"`
 | 
					 | 
				
			||||||
	Data    ZhipuResponseData `json:"data"`
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
type ZhipuStreamMetaResponse struct {
 | 
					 | 
				
			||||||
	RequestId  string `json:"request_id"`
 | 
					 | 
				
			||||||
	TaskId     string `json:"task_id"`
 | 
					 | 
				
			||||||
	TaskStatus string `json:"task_status"`
 | 
					 | 
				
			||||||
	Usage      `json:"usage"`
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
type zhipuTokenData struct {
 | 
					 | 
				
			||||||
	Token      string
 | 
					 | 
				
			||||||
	ExpiryTime time.Time
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
var zhipuTokens sync.Map
 | 
					var zhipuTokens sync.Map
 | 
				
			||||||
var expSeconds int64 = 24 * 3600
 | 
					var expSeconds int64 = 24 * 3600
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func getZhipuToken(apikey string) string {
 | 
					func GetToken(apikey string) string {
 | 
				
			||||||
	data, ok := zhipuTokens.Load(apikey)
 | 
						data, ok := zhipuTokens.Load(apikey)
 | 
				
			||||||
	if ok {
 | 
						if ok {
 | 
				
			||||||
		tokenData := data.(zhipuTokenData)
 | 
							tokenData := data.(tokenData)
 | 
				
			||||||
		if time.Now().Before(tokenData.ExpiryTime) {
 | 
							if time.Now().Before(tokenData.ExpiryTime) {
 | 
				
			||||||
			return tokenData.Token
 | 
								return tokenData.Token
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
@@ -100,7 +62,7 @@ func getZhipuToken(apikey string) string {
 | 
				
			|||||||
		return ""
 | 
							return ""
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	zhipuTokens.Store(apikey, zhipuTokenData{
 | 
						zhipuTokens.Store(apikey, tokenData{
 | 
				
			||||||
		Token:      tokenString,
 | 
							Token:      tokenString,
 | 
				
			||||||
		ExpiryTime: expiryTime,
 | 
							ExpiryTime: expiryTime,
 | 
				
			||||||
	})
 | 
						})
 | 
				
			||||||
@@ -108,26 +70,26 @@ func getZhipuToken(apikey string) string {
 | 
				
			|||||||
	return tokenString
 | 
						return tokenString
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func requestOpenAI2Zhipu(request GeneralOpenAIRequest) *ZhipuRequest {
 | 
					func ConvertRequest(request openai.GeneralOpenAIRequest) *Request {
 | 
				
			||||||
	messages := make([]ZhipuMessage, 0, len(request.Messages))
 | 
						messages := make([]Message, 0, len(request.Messages))
 | 
				
			||||||
	for _, message := range request.Messages {
 | 
						for _, message := range request.Messages {
 | 
				
			||||||
		if message.Role == "system" {
 | 
							if message.Role == "system" {
 | 
				
			||||||
			messages = append(messages, ZhipuMessage{
 | 
								messages = append(messages, Message{
 | 
				
			||||||
				Role:    "system",
 | 
									Role:    "system",
 | 
				
			||||||
				Content: message.StringContent(),
 | 
									Content: message.StringContent(),
 | 
				
			||||||
			})
 | 
								})
 | 
				
			||||||
			messages = append(messages, ZhipuMessage{
 | 
								messages = append(messages, Message{
 | 
				
			||||||
				Role:    "user",
 | 
									Role:    "user",
 | 
				
			||||||
				Content: "Okay",
 | 
									Content: "Okay",
 | 
				
			||||||
			})
 | 
								})
 | 
				
			||||||
		} else {
 | 
							} else {
 | 
				
			||||||
			messages = append(messages, ZhipuMessage{
 | 
								messages = append(messages, Message{
 | 
				
			||||||
				Role:    message.Role,
 | 
									Role:    message.Role,
 | 
				
			||||||
				Content: message.StringContent(),
 | 
									Content: message.StringContent(),
 | 
				
			||||||
			})
 | 
								})
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	return &ZhipuRequest{
 | 
						return &Request{
 | 
				
			||||||
		Prompt:      messages,
 | 
							Prompt:      messages,
 | 
				
			||||||
		Temperature: request.Temperature,
 | 
							Temperature: request.Temperature,
 | 
				
			||||||
		TopP:        request.TopP,
 | 
							TopP:        request.TopP,
 | 
				
			||||||
@@ -135,18 +97,18 @@ func requestOpenAI2Zhipu(request GeneralOpenAIRequest) *ZhipuRequest {
 | 
				
			|||||||
	}
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func responseZhipu2OpenAI(response *ZhipuResponse) *OpenAITextResponse {
 | 
					func responseZhipu2OpenAI(response *Response) *openai.TextResponse {
 | 
				
			||||||
	fullTextResponse := OpenAITextResponse{
 | 
						fullTextResponse := openai.TextResponse{
 | 
				
			||||||
		Id:      response.Data.TaskId,
 | 
							Id:      response.Data.TaskId,
 | 
				
			||||||
		Object:  "chat.completion",
 | 
							Object:  "chat.completion",
 | 
				
			||||||
		Created: common.GetTimestamp(),
 | 
							Created: common.GetTimestamp(),
 | 
				
			||||||
		Choices: make([]OpenAITextResponseChoice, 0, len(response.Data.Choices)),
 | 
							Choices: make([]openai.TextResponseChoice, 0, len(response.Data.Choices)),
 | 
				
			||||||
		Usage:   response.Data.Usage,
 | 
							Usage:   response.Data.Usage,
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	for i, choice := range response.Data.Choices {
 | 
						for i, choice := range response.Data.Choices {
 | 
				
			||||||
		openaiChoice := OpenAITextResponseChoice{
 | 
							openaiChoice := openai.TextResponseChoice{
 | 
				
			||||||
			Index: i,
 | 
								Index: i,
 | 
				
			||||||
			Message: Message{
 | 
								Message: openai.Message{
 | 
				
			||||||
				Role:    choice.Role,
 | 
									Role:    choice.Role,
 | 
				
			||||||
				Content: strings.Trim(choice.Content, "\""),
 | 
									Content: strings.Trim(choice.Content, "\""),
 | 
				
			||||||
			},
 | 
								},
 | 
				
			||||||
@@ -160,34 +122,34 @@ func responseZhipu2OpenAI(response *ZhipuResponse) *OpenAITextResponse {
 | 
				
			|||||||
	return &fullTextResponse
 | 
						return &fullTextResponse
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func streamResponseZhipu2OpenAI(zhipuResponse string) *ChatCompletionsStreamResponse {
 | 
					func streamResponseZhipu2OpenAI(zhipuResponse string) *openai.ChatCompletionsStreamResponse {
 | 
				
			||||||
	var choice ChatCompletionsStreamResponseChoice
 | 
						var choice openai.ChatCompletionsStreamResponseChoice
 | 
				
			||||||
	choice.Delta.Content = zhipuResponse
 | 
						choice.Delta.Content = zhipuResponse
 | 
				
			||||||
	response := ChatCompletionsStreamResponse{
 | 
						response := openai.ChatCompletionsStreamResponse{
 | 
				
			||||||
		Object:  "chat.completion.chunk",
 | 
							Object:  "chat.completion.chunk",
 | 
				
			||||||
		Created: common.GetTimestamp(),
 | 
							Created: common.GetTimestamp(),
 | 
				
			||||||
		Model:   "chatglm",
 | 
							Model:   "chatglm",
 | 
				
			||||||
		Choices: []ChatCompletionsStreamResponseChoice{choice},
 | 
							Choices: []openai.ChatCompletionsStreamResponseChoice{choice},
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	return &response
 | 
						return &response
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func streamMetaResponseZhipu2OpenAI(zhipuResponse *ZhipuStreamMetaResponse) (*ChatCompletionsStreamResponse, *Usage) {
 | 
					func streamMetaResponseZhipu2OpenAI(zhipuResponse *StreamMetaResponse) (*openai.ChatCompletionsStreamResponse, *openai.Usage) {
 | 
				
			||||||
	var choice ChatCompletionsStreamResponseChoice
 | 
						var choice openai.ChatCompletionsStreamResponseChoice
 | 
				
			||||||
	choice.Delta.Content = ""
 | 
						choice.Delta.Content = ""
 | 
				
			||||||
	choice.FinishReason = &stopFinishReason
 | 
						choice.FinishReason = &constant.StopFinishReason
 | 
				
			||||||
	response := ChatCompletionsStreamResponse{
 | 
						response := openai.ChatCompletionsStreamResponse{
 | 
				
			||||||
		Id:      zhipuResponse.RequestId,
 | 
							Id:      zhipuResponse.RequestId,
 | 
				
			||||||
		Object:  "chat.completion.chunk",
 | 
							Object:  "chat.completion.chunk",
 | 
				
			||||||
		Created: common.GetTimestamp(),
 | 
							Created: common.GetTimestamp(),
 | 
				
			||||||
		Model:   "chatglm",
 | 
							Model:   "chatglm",
 | 
				
			||||||
		Choices: []ChatCompletionsStreamResponseChoice{choice},
 | 
							Choices: []openai.ChatCompletionsStreamResponseChoice{choice},
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	return &response, &zhipuResponse.Usage
 | 
						return &response, &zhipuResponse.Usage
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func zhipuStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) {
 | 
					func StreamHandler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, *openai.Usage) {
 | 
				
			||||||
	var usage *Usage
 | 
						var usage *openai.Usage
 | 
				
			||||||
	scanner := bufio.NewScanner(resp.Body)
 | 
						scanner := bufio.NewScanner(resp.Body)
 | 
				
			||||||
	scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
 | 
						scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
 | 
				
			||||||
		if atEOF && len(data) == 0 {
 | 
							if atEOF && len(data) == 0 {
 | 
				
			||||||
@@ -224,7 +186,7 @@ func zhipuStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithSt
 | 
				
			|||||||
		}
 | 
							}
 | 
				
			||||||
		stopChan <- true
 | 
							stopChan <- true
 | 
				
			||||||
	}()
 | 
						}()
 | 
				
			||||||
	setEventStreamHeaders(c)
 | 
						common.SetEventStreamHeaders(c)
 | 
				
			||||||
	c.Stream(func(w io.Writer) bool {
 | 
						c.Stream(func(w io.Writer) bool {
 | 
				
			||||||
		select {
 | 
							select {
 | 
				
			||||||
		case data := <-dataChan:
 | 
							case data := <-dataChan:
 | 
				
			||||||
@@ -237,7 +199,7 @@ func zhipuStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithSt
 | 
				
			|||||||
			c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
 | 
								c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
 | 
				
			||||||
			return true
 | 
								return true
 | 
				
			||||||
		case data := <-metaChan:
 | 
							case data := <-metaChan:
 | 
				
			||||||
			var zhipuResponse ZhipuStreamMetaResponse
 | 
								var zhipuResponse StreamMetaResponse
 | 
				
			||||||
			err := json.Unmarshal([]byte(data), &zhipuResponse)
 | 
								err := json.Unmarshal([]byte(data), &zhipuResponse)
 | 
				
			||||||
			if err != nil {
 | 
								if err != nil {
 | 
				
			||||||
				common.SysError("error unmarshalling stream response: " + err.Error())
 | 
									common.SysError("error unmarshalling stream response: " + err.Error())
 | 
				
			||||||
@@ -259,28 +221,28 @@ func zhipuStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithSt
 | 
				
			|||||||
	})
 | 
						})
 | 
				
			||||||
	err := resp.Body.Close()
 | 
						err := resp.Body.Close()
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
 | 
							return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	return nil, usage
 | 
						return nil, usage
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func zhipuHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) {
 | 
					func Handler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, *openai.Usage) {
 | 
				
			||||||
	var zhipuResponse ZhipuResponse
 | 
						var zhipuResponse Response
 | 
				
			||||||
	responseBody, err := io.ReadAll(resp.Body)
 | 
						responseBody, err := io.ReadAll(resp.Body)
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
 | 
							return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	err = resp.Body.Close()
 | 
						err = resp.Body.Close()
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
 | 
							return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	err = json.Unmarshal(responseBody, &zhipuResponse)
 | 
						err = json.Unmarshal(responseBody, &zhipuResponse)
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
 | 
							return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	if !zhipuResponse.Success {
 | 
						if !zhipuResponse.Success {
 | 
				
			||||||
		return &OpenAIErrorWithStatusCode{
 | 
							return &openai.ErrorWithStatusCode{
 | 
				
			||||||
			OpenAIError: OpenAIError{
 | 
								Error: openai.Error{
 | 
				
			||||||
				Message: zhipuResponse.Msg,
 | 
									Message: zhipuResponse.Msg,
 | 
				
			||||||
				Type:    "zhipu_error",
 | 
									Type:    "zhipu_error",
 | 
				
			||||||
				Param:   "",
 | 
									Param:   "",
 | 
				
			||||||
@@ -293,7 +255,7 @@ func zhipuHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCo
 | 
				
			|||||||
	fullTextResponse.Model = "chatglm"
 | 
						fullTextResponse.Model = "chatglm"
 | 
				
			||||||
	jsonResponse, err := json.Marshal(fullTextResponse)
 | 
						jsonResponse, err := json.Marshal(fullTextResponse)
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
 | 
							return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	c.Writer.Header().Set("Content-Type", "application/json")
 | 
						c.Writer.Header().Set("Content-Type", "application/json")
 | 
				
			||||||
	c.Writer.WriteHeader(resp.StatusCode)
 | 
						c.Writer.WriteHeader(resp.StatusCode)
 | 
				
			||||||
							
								
								
									
										46
									
								
								relay/channel/zhipu/model.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										46
									
								
								relay/channel/zhipu/model.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,46 @@
 | 
				
			|||||||
 | 
					package zhipu
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import (
 | 
				
			||||||
 | 
						"one-api/relay/channel/openai"
 | 
				
			||||||
 | 
						"time"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type Message struct {
 | 
				
			||||||
 | 
						Role    string `json:"role"`
 | 
				
			||||||
 | 
						Content string `json:"content"`
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type Request struct {
 | 
				
			||||||
 | 
						Prompt      []Message `json:"prompt"`
 | 
				
			||||||
 | 
						Temperature float64   `json:"temperature,omitempty"`
 | 
				
			||||||
 | 
						TopP        float64   `json:"top_p,omitempty"`
 | 
				
			||||||
 | 
						RequestId   string    `json:"request_id,omitempty"`
 | 
				
			||||||
 | 
						Incremental bool      `json:"incremental,omitempty"`
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type ResponseData struct {
 | 
				
			||||||
 | 
						TaskId       string    `json:"task_id"`
 | 
				
			||||||
 | 
						RequestId    string    `json:"request_id"`
 | 
				
			||||||
 | 
						TaskStatus   string    `json:"task_status"`
 | 
				
			||||||
 | 
						Choices      []Message `json:"choices"`
 | 
				
			||||||
 | 
						openai.Usage `json:"usage"`
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type Response struct {
 | 
				
			||||||
 | 
						Code    int          `json:"code"`
 | 
				
			||||||
 | 
						Msg     string       `json:"msg"`
 | 
				
			||||||
 | 
						Success bool         `json:"success"`
 | 
				
			||||||
 | 
						Data    ResponseData `json:"data"`
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type StreamMetaResponse struct {
 | 
				
			||||||
 | 
						RequestId    string `json:"request_id"`
 | 
				
			||||||
 | 
						TaskId       string `json:"task_id"`
 | 
				
			||||||
 | 
						TaskStatus   string `json:"task_status"`
 | 
				
			||||||
 | 
						openai.Usage `json:"usage"`
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type tokenData struct {
 | 
				
			||||||
 | 
						Token      string
 | 
				
			||||||
 | 
						ExpiryTime time.Time
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
							
								
								
									
										16
									
								
								relay/constant/main.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										16
									
								
								relay/constant/main.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,16 @@
 | 
				
			|||||||
 | 
					package constant
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					const (
 | 
				
			||||||
 | 
						RelayModeUnknown = iota
 | 
				
			||||||
 | 
						RelayModeChatCompletions
 | 
				
			||||||
 | 
						RelayModeCompletions
 | 
				
			||||||
 | 
						RelayModeEmbeddings
 | 
				
			||||||
 | 
						RelayModeModerations
 | 
				
			||||||
 | 
						RelayModeImagesGenerations
 | 
				
			||||||
 | 
						RelayModeEdits
 | 
				
			||||||
 | 
						RelayModeAudioSpeech
 | 
				
			||||||
 | 
						RelayModeAudioTranscription
 | 
				
			||||||
 | 
						RelayModeAudioTranslation
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					var StopFinishReason = "stop"
 | 
				
			||||||
@@ -12,10 +12,13 @@ import (
 | 
				
			|||||||
	"net/http"
 | 
						"net/http"
 | 
				
			||||||
	"one-api/common"
 | 
						"one-api/common"
 | 
				
			||||||
	"one-api/model"
 | 
						"one-api/model"
 | 
				
			||||||
 | 
						"one-api/relay/channel/openai"
 | 
				
			||||||
 | 
						"one-api/relay/constant"
 | 
				
			||||||
 | 
						"one-api/relay/util"
 | 
				
			||||||
	"strings"
 | 
						"strings"
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 | 
					func RelayAudioHelper(c *gin.Context, relayMode int) *openai.ErrorWithStatusCode {
 | 
				
			||||||
	audioModel := "whisper-1"
 | 
						audioModel := "whisper-1"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	tokenId := c.GetInt("token_id")
 | 
						tokenId := c.GetInt("token_id")
 | 
				
			||||||
@@ -25,18 +28,18 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
 | 
				
			|||||||
	group := c.GetString("group")
 | 
						group := c.GetString("group")
 | 
				
			||||||
	tokenName := c.GetString("token_name")
 | 
						tokenName := c.GetString("token_name")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	var ttsRequest TextToSpeechRequest
 | 
						var ttsRequest openai.TextToSpeechRequest
 | 
				
			||||||
	if relayMode == RelayModeAudioSpeech {
 | 
						if relayMode == constant.RelayModeAudioSpeech {
 | 
				
			||||||
		// Read JSON
 | 
							// Read JSON
 | 
				
			||||||
		err := common.UnmarshalBodyReusable(c, &ttsRequest)
 | 
							err := common.UnmarshalBodyReusable(c, &ttsRequest)
 | 
				
			||||||
		// Check if JSON is valid
 | 
							// Check if JSON is valid
 | 
				
			||||||
		if err != nil {
 | 
							if err != nil {
 | 
				
			||||||
			return errorWrapper(err, "invalid_json", http.StatusBadRequest)
 | 
								return openai.ErrorWrapper(err, "invalid_json", http.StatusBadRequest)
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
		audioModel = ttsRequest.Model
 | 
							audioModel = ttsRequest.Model
 | 
				
			||||||
		// Check if text is too long 4096
 | 
							// Check if text is too long 4096
 | 
				
			||||||
		if len(ttsRequest.Input) > 4096 {
 | 
							if len(ttsRequest.Input) > 4096 {
 | 
				
			||||||
			return errorWrapper(errors.New("input is too long (over 4096 characters)"), "text_too_long", http.StatusBadRequest)
 | 
								return openai.ErrorWrapper(errors.New("input is too long (over 4096 characters)"), "text_too_long", http.StatusBadRequest)
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -46,7 +49,7 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
 | 
				
			|||||||
	var quota int
 | 
						var quota int
 | 
				
			||||||
	var preConsumedQuota int
 | 
						var preConsumedQuota int
 | 
				
			||||||
	switch relayMode {
 | 
						switch relayMode {
 | 
				
			||||||
	case RelayModeAudioSpeech:
 | 
						case constant.RelayModeAudioSpeech:
 | 
				
			||||||
		preConsumedQuota = int(float64(len(ttsRequest.Input)) * ratio)
 | 
							preConsumedQuota = int(float64(len(ttsRequest.Input)) * ratio)
 | 
				
			||||||
		quota = preConsumedQuota
 | 
							quota = preConsumedQuota
 | 
				
			||||||
	default:
 | 
						default:
 | 
				
			||||||
@@ -54,16 +57,16 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
 | 
				
			|||||||
	}
 | 
						}
 | 
				
			||||||
	userQuota, err := model.CacheGetUserQuota(userId)
 | 
						userQuota, err := model.CacheGetUserQuota(userId)
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return errorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError)
 | 
							return openai.ErrorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// Check if user quota is enough
 | 
						// Check if user quota is enough
 | 
				
			||||||
	if userQuota-preConsumedQuota < 0 {
 | 
						if userQuota-preConsumedQuota < 0 {
 | 
				
			||||||
		return errorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden)
 | 
							return openai.ErrorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	err = model.CacheDecreaseUserQuota(userId, preConsumedQuota)
 | 
						err = model.CacheDecreaseUserQuota(userId, preConsumedQuota)
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return errorWrapper(err, "decrease_user_quota_failed", http.StatusInternalServerError)
 | 
							return openai.ErrorWrapper(err, "decrease_user_quota_failed", http.StatusInternalServerError)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	if userQuota > 100*preConsumedQuota {
 | 
						if userQuota > 100*preConsumedQuota {
 | 
				
			||||||
		// in this case, we do not pre-consume quota
 | 
							// in this case, we do not pre-consume quota
 | 
				
			||||||
@@ -73,7 +76,7 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
 | 
				
			|||||||
	if preConsumedQuota > 0 {
 | 
						if preConsumedQuota > 0 {
 | 
				
			||||||
		err := model.PreConsumeTokenQuota(tokenId, preConsumedQuota)
 | 
							err := model.PreConsumeTokenQuota(tokenId, preConsumedQuota)
 | 
				
			||||||
		if err != nil {
 | 
							if err != nil {
 | 
				
			||||||
			return errorWrapper(err, "pre_consume_token_quota_failed", http.StatusForbidden)
 | 
								return openai.ErrorWrapper(err, "pre_consume_token_quota_failed", http.StatusForbidden)
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -83,7 +86,7 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
 | 
				
			|||||||
		modelMap := make(map[string]string)
 | 
							modelMap := make(map[string]string)
 | 
				
			||||||
		err := json.Unmarshal([]byte(modelMapping), &modelMap)
 | 
							err := json.Unmarshal([]byte(modelMapping), &modelMap)
 | 
				
			||||||
		if err != nil {
 | 
							if err != nil {
 | 
				
			||||||
			return errorWrapper(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError)
 | 
								return openai.ErrorWrapper(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError)
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
		if modelMap[audioModel] != "" {
 | 
							if modelMap[audioModel] != "" {
 | 
				
			||||||
			audioModel = modelMap[audioModel]
 | 
								audioModel = modelMap[audioModel]
 | 
				
			||||||
@@ -96,27 +99,27 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
 | 
				
			|||||||
		baseURL = c.GetString("base_url")
 | 
							baseURL = c.GetString("base_url")
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	fullRequestURL := getFullRequestURL(baseURL, requestURL, channelType)
 | 
						fullRequestURL := util.GetFullRequestURL(baseURL, requestURL, channelType)
 | 
				
			||||||
	if relayMode == RelayModeAudioTranscription && channelType == common.ChannelTypeAzure {
 | 
						if relayMode == constant.RelayModeAudioTranscription && channelType == common.ChannelTypeAzure {
 | 
				
			||||||
		// https://learn.microsoft.com/en-us/azure/ai-services/openai/whisper-quickstart?tabs=command-line#rest-api
 | 
							// https://learn.microsoft.com/en-us/azure/ai-services/openai/whisper-quickstart?tabs=command-line#rest-api
 | 
				
			||||||
		apiVersion := GetAPIVersion(c)
 | 
							apiVersion := util.GetAPIVersion(c)
 | 
				
			||||||
		fullRequestURL = fmt.Sprintf("%s/openai/deployments/%s/audio/transcriptions?api-version=%s", baseURL, audioModel, apiVersion)
 | 
							fullRequestURL = fmt.Sprintf("%s/openai/deployments/%s/audio/transcriptions?api-version=%s", baseURL, audioModel, apiVersion)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	requestBody := &bytes.Buffer{}
 | 
						requestBody := &bytes.Buffer{}
 | 
				
			||||||
	_, err = io.Copy(requestBody, c.Request.Body)
 | 
						_, err = io.Copy(requestBody, c.Request.Body)
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return errorWrapper(err, "new_request_body_failed", http.StatusInternalServerError)
 | 
							return openai.ErrorWrapper(err, "new_request_body_failed", http.StatusInternalServerError)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody.Bytes()))
 | 
						c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody.Bytes()))
 | 
				
			||||||
	responseFormat := c.DefaultPostForm("response_format", "json")
 | 
						responseFormat := c.DefaultPostForm("response_format", "json")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
 | 
						req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return errorWrapper(err, "new_request_failed", http.StatusInternalServerError)
 | 
							return openai.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	if relayMode == RelayModeAudioTranscription && channelType == common.ChannelTypeAzure {
 | 
						if relayMode == constant.RelayModeAudioTranscription && channelType == common.ChannelTypeAzure {
 | 
				
			||||||
		// https://learn.microsoft.com/en-us/azure/ai-services/openai/whisper-quickstart?tabs=command-line#rest-api
 | 
							// https://learn.microsoft.com/en-us/azure/ai-services/openai/whisper-quickstart?tabs=command-line#rest-api
 | 
				
			||||||
		apiKey := c.Request.Header.Get("Authorization")
 | 
							apiKey := c.Request.Header.Get("Authorization")
 | 
				
			||||||
		apiKey = strings.TrimPrefix(apiKey, "Bearer ")
 | 
							apiKey = strings.TrimPrefix(apiKey, "Bearer ")
 | 
				
			||||||
@@ -128,34 +131,34 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
 | 
				
			|||||||
	req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
 | 
						req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
 | 
				
			||||||
	req.Header.Set("Accept", c.Request.Header.Get("Accept"))
 | 
						req.Header.Set("Accept", c.Request.Header.Get("Accept"))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	resp, err := httpClient.Do(req)
 | 
						resp, err := util.HTTPClient.Do(req)
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return errorWrapper(err, "do_request_failed", http.StatusInternalServerError)
 | 
							return openai.ErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	err = req.Body.Close()
 | 
						err = req.Body.Close()
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return errorWrapper(err, "close_request_body_failed", http.StatusInternalServerError)
 | 
							return openai.ErrorWrapper(err, "close_request_body_failed", http.StatusInternalServerError)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	err = c.Request.Body.Close()
 | 
						err = c.Request.Body.Close()
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return errorWrapper(err, "close_request_body_failed", http.StatusInternalServerError)
 | 
							return openai.ErrorWrapper(err, "close_request_body_failed", http.StatusInternalServerError)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	if relayMode != RelayModeAudioSpeech {
 | 
						if relayMode != constant.RelayModeAudioSpeech {
 | 
				
			||||||
		responseBody, err := io.ReadAll(resp.Body)
 | 
							responseBody, err := io.ReadAll(resp.Body)
 | 
				
			||||||
		if err != nil {
 | 
							if err != nil {
 | 
				
			||||||
			return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
 | 
								return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
		err = resp.Body.Close()
 | 
							err = resp.Body.Close()
 | 
				
			||||||
		if err != nil {
 | 
							if err != nil {
 | 
				
			||||||
			return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError)
 | 
								return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError)
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		var openAIErr TextResponse
 | 
							var openAIErr openai.SlimTextResponse
 | 
				
			||||||
		if err = json.Unmarshal(responseBody, &openAIErr); err == nil {
 | 
							if err = json.Unmarshal(responseBody, &openAIErr); err == nil {
 | 
				
			||||||
			if openAIErr.Error.Message != "" {
 | 
								if openAIErr.Error.Message != "" {
 | 
				
			||||||
				return errorWrapper(fmt.Errorf("type %s, code %v, message %s", openAIErr.Error.Type, openAIErr.Error.Code, openAIErr.Error.Message), "request_error", http.StatusInternalServerError)
 | 
									return openai.ErrorWrapper(fmt.Errorf("type %s, code %v, message %s", openAIErr.Error.Type, openAIErr.Error.Code, openAIErr.Error.Message), "request_error", http.StatusInternalServerError)
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -172,12 +175,12 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
 | 
				
			|||||||
		case "vtt":
 | 
							case "vtt":
 | 
				
			||||||
			text, err = getTextFromVTT(responseBody)
 | 
								text, err = getTextFromVTT(responseBody)
 | 
				
			||||||
		default:
 | 
							default:
 | 
				
			||||||
			return errorWrapper(errors.New("unexpected_response_format"), "unexpected_response_format", http.StatusInternalServerError)
 | 
								return openai.ErrorWrapper(errors.New("unexpected_response_format"), "unexpected_response_format", http.StatusInternalServerError)
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
		if err != nil {
 | 
							if err != nil {
 | 
				
			||||||
			return errorWrapper(err, "get_text_from_body_err", http.StatusInternalServerError)
 | 
								return openai.ErrorWrapper(err, "get_text_from_body_err", http.StatusInternalServerError)
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
		quota = countTokenText(text, audioModel)
 | 
							quota = openai.CountTokenText(text, audioModel)
 | 
				
			||||||
		resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
 | 
							resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	if resp.StatusCode != http.StatusOK {
 | 
						if resp.StatusCode != http.StatusOK {
 | 
				
			||||||
@@ -193,11 +196,11 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
 | 
				
			|||||||
				}()
 | 
									}()
 | 
				
			||||||
			}(c.Request.Context())
 | 
								}(c.Request.Context())
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
		return relayErrorHandler(resp)
 | 
							return util.RelayErrorHandler(resp)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	quotaDelta := quota - preConsumedQuota
 | 
						quotaDelta := quota - preConsumedQuota
 | 
				
			||||||
	defer func(ctx context.Context) {
 | 
						defer func(ctx context.Context) {
 | 
				
			||||||
		go postConsumeQuota(ctx, tokenId, quotaDelta, quota, userId, channelId, modelRatio, groupRatio, audioModel, tokenName)
 | 
							go util.PostConsumeQuota(ctx, tokenId, quotaDelta, quota, userId, channelId, modelRatio, groupRatio, audioModel, tokenName)
 | 
				
			||||||
	}(c.Request.Context())
 | 
						}(c.Request.Context())
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	for k, v := range resp.Header {
 | 
						for k, v := range resp.Header {
 | 
				
			||||||
@@ -207,11 +210,11 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
	_, err = io.Copy(c.Writer, resp.Body)
 | 
						_, err = io.Copy(c.Writer, resp.Body)
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return errorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError)
 | 
							return openai.ErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	err = resp.Body.Close()
 | 
						err = resp.Body.Close()
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError)
 | 
							return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	return nil
 | 
						return nil
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
@@ -221,7 +224,7 @@ func getTextFromVTT(body []byte) (string, error) {
 | 
				
			|||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func getTextFromVerboseJSON(body []byte) (string, error) {
 | 
					func getTextFromVerboseJSON(body []byte) (string, error) {
 | 
				
			||||||
	var whisperResponse WhisperVerboseJSONResponse
 | 
						var whisperResponse openai.WhisperVerboseJSONResponse
 | 
				
			||||||
	if err := json.Unmarshal(body, &whisperResponse); err != nil {
 | 
						if err := json.Unmarshal(body, &whisperResponse); err != nil {
 | 
				
			||||||
		return "", fmt.Errorf("unmarshal_response_body_failed err :%w", err)
 | 
							return "", fmt.Errorf("unmarshal_response_body_failed err :%w", err)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
@@ -254,7 +257,7 @@ func getTextFromText(body []byte) (string, error) {
 | 
				
			|||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func getTextFromJSON(body []byte) (string, error) {
 | 
					func getTextFromJSON(body []byte) (string, error) {
 | 
				
			||||||
	var whisperResponse WhisperJSONResponse
 | 
						var whisperResponse openai.WhisperJSONResponse
 | 
				
			||||||
	if err := json.Unmarshal(body, &whisperResponse); err != nil {
 | 
						if err := json.Unmarshal(body, &whisperResponse); err != nil {
 | 
				
			||||||
		return "", fmt.Errorf("unmarshal_response_body_failed err :%w", err)
 | 
							return "", fmt.Errorf("unmarshal_response_body_failed err :%w", err)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
@@ -10,6 +10,8 @@ import (
 | 
				
			|||||||
	"net/http"
 | 
						"net/http"
 | 
				
			||||||
	"one-api/common"
 | 
						"one-api/common"
 | 
				
			||||||
	"one-api/model"
 | 
						"one-api/model"
 | 
				
			||||||
 | 
						"one-api/relay/channel/openai"
 | 
				
			||||||
 | 
						"one-api/relay/util"
 | 
				
			||||||
	"strings"
 | 
						"strings"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	"github.com/gin-gonic/gin"
 | 
						"github.com/gin-gonic/gin"
 | 
				
			||||||
@@ -25,7 +27,7 @@ func isWithinRange(element string, value int) bool {
 | 
				
			|||||||
	return value >= min && value <= max
 | 
						return value >= min && value <= max
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 | 
					func RelayImageHelper(c *gin.Context, relayMode int) *openai.ErrorWithStatusCode {
 | 
				
			||||||
	imageModel := "dall-e-2"
 | 
						imageModel := "dall-e-2"
 | 
				
			||||||
	imageSize := "1024x1024"
 | 
						imageSize := "1024x1024"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -35,10 +37,10 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
 | 
				
			|||||||
	userId := c.GetInt("id")
 | 
						userId := c.GetInt("id")
 | 
				
			||||||
	group := c.GetString("group")
 | 
						group := c.GetString("group")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	var imageRequest ImageRequest
 | 
						var imageRequest openai.ImageRequest
 | 
				
			||||||
	err := common.UnmarshalBodyReusable(c, &imageRequest)
 | 
						err := common.UnmarshalBodyReusable(c, &imageRequest)
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return errorWrapper(err, "bind_request_body_failed", http.StatusBadRequest)
 | 
							return openai.ErrorWrapper(err, "bind_request_body_failed", http.StatusBadRequest)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	if imageRequest.N == 0 {
 | 
						if imageRequest.N == 0 {
 | 
				
			||||||
@@ -67,24 +69,24 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
 | 
				
			|||||||
			}
 | 
								}
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
	} else {
 | 
						} else {
 | 
				
			||||||
		return errorWrapper(errors.New("size not supported for this image model"), "size_not_supported", http.StatusBadRequest)
 | 
							return openai.ErrorWrapper(errors.New("size not supported for this image model"), "size_not_supported", http.StatusBadRequest)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// Prompt validation
 | 
						// Prompt validation
 | 
				
			||||||
	if imageRequest.Prompt == "" {
 | 
						if imageRequest.Prompt == "" {
 | 
				
			||||||
		return errorWrapper(errors.New("prompt is required"), "prompt_missing", http.StatusBadRequest)
 | 
							return openai.ErrorWrapper(errors.New("prompt is required"), "prompt_missing", http.StatusBadRequest)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// Check prompt length
 | 
						// Check prompt length
 | 
				
			||||||
	if len(imageRequest.Prompt) > common.DalleImagePromptLengthLimitations[imageModel] {
 | 
						if len(imageRequest.Prompt) > common.DalleImagePromptLengthLimitations[imageModel] {
 | 
				
			||||||
		return errorWrapper(errors.New("prompt is too long"), "prompt_too_long", http.StatusBadRequest)
 | 
							return openai.ErrorWrapper(errors.New("prompt is too long"), "prompt_too_long", http.StatusBadRequest)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// Number of generated images validation
 | 
						// Number of generated images validation
 | 
				
			||||||
	if isWithinRange(imageModel, imageRequest.N) == false {
 | 
						if isWithinRange(imageModel, imageRequest.N) == false {
 | 
				
			||||||
		// channel not azure
 | 
							// channel not azure
 | 
				
			||||||
		if channelType != common.ChannelTypeAzure {
 | 
							if channelType != common.ChannelTypeAzure {
 | 
				
			||||||
			return errorWrapper(errors.New("invalid value of n"), "n_not_within_range", http.StatusBadRequest)
 | 
								return openai.ErrorWrapper(errors.New("invalid value of n"), "n_not_within_range", http.StatusBadRequest)
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -95,7 +97,7 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
 | 
				
			|||||||
		modelMap := make(map[string]string)
 | 
							modelMap := make(map[string]string)
 | 
				
			||||||
		err := json.Unmarshal([]byte(modelMapping), &modelMap)
 | 
							err := json.Unmarshal([]byte(modelMapping), &modelMap)
 | 
				
			||||||
		if err != nil {
 | 
							if err != nil {
 | 
				
			||||||
			return errorWrapper(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError)
 | 
								return openai.ErrorWrapper(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError)
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
		if modelMap[imageModel] != "" {
 | 
							if modelMap[imageModel] != "" {
 | 
				
			||||||
			imageModel = modelMap[imageModel]
 | 
								imageModel = modelMap[imageModel]
 | 
				
			||||||
@@ -107,10 +109,10 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
 | 
				
			|||||||
	if c.GetString("base_url") != "" {
 | 
						if c.GetString("base_url") != "" {
 | 
				
			||||||
		baseURL = c.GetString("base_url")
 | 
							baseURL = c.GetString("base_url")
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	fullRequestURL := getFullRequestURL(baseURL, requestURL, channelType)
 | 
						fullRequestURL := util.GetFullRequestURL(baseURL, requestURL, channelType)
 | 
				
			||||||
	if channelType == common.ChannelTypeAzure {
 | 
						if channelType == common.ChannelTypeAzure {
 | 
				
			||||||
		// https://learn.microsoft.com/en-us/azure/ai-services/openai/dall-e-quickstart?tabs=dalle3%2Ccommand-line&pivots=rest-api
 | 
							// https://learn.microsoft.com/en-us/azure/ai-services/openai/dall-e-quickstart?tabs=dalle3%2Ccommand-line&pivots=rest-api
 | 
				
			||||||
		apiVersion := GetAPIVersion(c)
 | 
							apiVersion := util.GetAPIVersion(c)
 | 
				
			||||||
		// https://{resource_name}.openai.azure.com/openai/deployments/dall-e-3/images/generations?api-version=2023-06-01-preview
 | 
							// https://{resource_name}.openai.azure.com/openai/deployments/dall-e-3/images/generations?api-version=2023-06-01-preview
 | 
				
			||||||
		fullRequestURL = fmt.Sprintf("%s/openai/deployments/%s/images/generations?api-version=%s", baseURL, imageModel, apiVersion)
 | 
							fullRequestURL = fmt.Sprintf("%s/openai/deployments/%s/images/generations?api-version=%s", baseURL, imageModel, apiVersion)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
@@ -119,7 +121,7 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
 | 
				
			|||||||
	if isModelMapped || channelType == common.ChannelTypeAzure { // make Azure channel request body
 | 
						if isModelMapped || channelType == common.ChannelTypeAzure { // make Azure channel request body
 | 
				
			||||||
		jsonStr, err := json.Marshal(imageRequest)
 | 
							jsonStr, err := json.Marshal(imageRequest)
 | 
				
			||||||
		if err != nil {
 | 
							if err != nil {
 | 
				
			||||||
			return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
 | 
								return openai.ErrorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
		requestBody = bytes.NewBuffer(jsonStr)
 | 
							requestBody = bytes.NewBuffer(jsonStr)
 | 
				
			||||||
	} else {
 | 
						} else {
 | 
				
			||||||
@@ -134,12 +136,12 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
 | 
				
			|||||||
	quota := int(ratio*imageCostRatio*1000) * imageRequest.N
 | 
						quota := int(ratio*imageCostRatio*1000) * imageRequest.N
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	if userQuota-quota < 0 {
 | 
						if userQuota-quota < 0 {
 | 
				
			||||||
		return errorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden)
 | 
							return openai.ErrorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
 | 
						req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return errorWrapper(err, "new_request_failed", http.StatusInternalServerError)
 | 
							return openai.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	token := c.Request.Header.Get("Authorization")
 | 
						token := c.Request.Header.Get("Authorization")
 | 
				
			||||||
	if channelType == common.ChannelTypeAzure { // Azure authentication
 | 
						if channelType == common.ChannelTypeAzure { // Azure authentication
 | 
				
			||||||
@@ -152,20 +154,20 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
 | 
				
			|||||||
	req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
 | 
						req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
 | 
				
			||||||
	req.Header.Set("Accept", c.Request.Header.Get("Accept"))
 | 
						req.Header.Set("Accept", c.Request.Header.Get("Accept"))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	resp, err := httpClient.Do(req)
 | 
						resp, err := util.HTTPClient.Do(req)
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return errorWrapper(err, "do_request_failed", http.StatusInternalServerError)
 | 
							return openai.ErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	err = req.Body.Close()
 | 
						err = req.Body.Close()
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return errorWrapper(err, "close_request_body_failed", http.StatusInternalServerError)
 | 
							return openai.ErrorWrapper(err, "close_request_body_failed", http.StatusInternalServerError)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	err = c.Request.Body.Close()
 | 
						err = c.Request.Body.Close()
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return errorWrapper(err, "close_request_body_failed", http.StatusInternalServerError)
 | 
							return openai.ErrorWrapper(err, "close_request_body_failed", http.StatusInternalServerError)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	var textResponse ImageResponse
 | 
						var textResponse openai.ImageResponse
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	defer func(ctx context.Context) {
 | 
						defer func(ctx context.Context) {
 | 
				
			||||||
		if resp.StatusCode != http.StatusOK {
 | 
							if resp.StatusCode != http.StatusOK {
 | 
				
			||||||
@@ -192,15 +194,15 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
 | 
				
			|||||||
	responseBody, err := io.ReadAll(resp.Body)
 | 
						responseBody, err := io.ReadAll(resp.Body)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
 | 
							return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	err = resp.Body.Close()
 | 
						err = resp.Body.Close()
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError)
 | 
							return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	err = json.Unmarshal(responseBody, &textResponse)
 | 
						err = json.Unmarshal(responseBody, &textResponse)
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError)
 | 
							return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
 | 
						resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
 | 
				
			||||||
@@ -212,11 +214,11 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
	_, err = io.Copy(c.Writer, resp.Body)
 | 
						_, err = io.Copy(c.Writer, resp.Body)
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return errorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError)
 | 
							return openai.ErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	err = resp.Body.Close()
 | 
						err = resp.Body.Close()
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError)
 | 
							return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	return nil
 | 
						return nil
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
@@ -6,15 +6,24 @@ import (
 | 
				
			|||||||
	"encoding/json"
 | 
						"encoding/json"
 | 
				
			||||||
	"errors"
 | 
						"errors"
 | 
				
			||||||
	"fmt"
 | 
						"fmt"
 | 
				
			||||||
 | 
						"github.com/gin-gonic/gin"
 | 
				
			||||||
	"io"
 | 
						"io"
 | 
				
			||||||
	"math"
 | 
						"math"
 | 
				
			||||||
	"net/http"
 | 
						"net/http"
 | 
				
			||||||
	"one-api/common"
 | 
						"one-api/common"
 | 
				
			||||||
	"one-api/model"
 | 
						"one-api/model"
 | 
				
			||||||
 | 
						"one-api/relay/channel/aiproxy"
 | 
				
			||||||
 | 
						"one-api/relay/channel/ali"
 | 
				
			||||||
 | 
						"one-api/relay/channel/anthropic"
 | 
				
			||||||
 | 
						"one-api/relay/channel/baidu"
 | 
				
			||||||
 | 
						"one-api/relay/channel/google"
 | 
				
			||||||
 | 
						"one-api/relay/channel/openai"
 | 
				
			||||||
 | 
						"one-api/relay/channel/tencent"
 | 
				
			||||||
 | 
						"one-api/relay/channel/xunfei"
 | 
				
			||||||
 | 
						"one-api/relay/channel/zhipu"
 | 
				
			||||||
 | 
						"one-api/relay/constant"
 | 
				
			||||||
 | 
						"one-api/relay/util"
 | 
				
			||||||
	"strings"
 | 
						"strings"
 | 
				
			||||||
	"time"
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	"github.com/gin-gonic/gin"
 | 
					 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
const (
 | 
					const (
 | 
				
			||||||
@@ -30,64 +39,47 @@ const (
 | 
				
			|||||||
	APITypeGemini
 | 
						APITypeGemini
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
var httpClient *http.Client
 | 
					func RelayTextHelper(c *gin.Context, relayMode int) *openai.ErrorWithStatusCode {
 | 
				
			||||||
var impatientHTTPClient *http.Client
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func init() {
 | 
					 | 
				
			||||||
	if common.RelayTimeout == 0 {
 | 
					 | 
				
			||||||
		httpClient = &http.Client{}
 | 
					 | 
				
			||||||
	} else {
 | 
					 | 
				
			||||||
		httpClient = &http.Client{
 | 
					 | 
				
			||||||
			Timeout: time.Duration(common.RelayTimeout) * time.Second,
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	impatientHTTPClient = &http.Client{
 | 
					 | 
				
			||||||
		Timeout: 5 * time.Second,
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 | 
					 | 
				
			||||||
	channelType := c.GetInt("channel")
 | 
						channelType := c.GetInt("channel")
 | 
				
			||||||
	channelId := c.GetInt("channel_id")
 | 
						channelId := c.GetInt("channel_id")
 | 
				
			||||||
	tokenId := c.GetInt("token_id")
 | 
						tokenId := c.GetInt("token_id")
 | 
				
			||||||
	userId := c.GetInt("id")
 | 
						userId := c.GetInt("id")
 | 
				
			||||||
	group := c.GetString("group")
 | 
						group := c.GetString("group")
 | 
				
			||||||
	var textRequest GeneralOpenAIRequest
 | 
						var textRequest openai.GeneralOpenAIRequest
 | 
				
			||||||
	err := common.UnmarshalBodyReusable(c, &textRequest)
 | 
						err := common.UnmarshalBodyReusable(c, &textRequest)
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return errorWrapper(err, "bind_request_body_failed", http.StatusBadRequest)
 | 
							return openai.ErrorWrapper(err, "bind_request_body_failed", http.StatusBadRequest)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	if textRequest.MaxTokens < 0 || textRequest.MaxTokens > math.MaxInt32/2 {
 | 
						if textRequest.MaxTokens < 0 || textRequest.MaxTokens > math.MaxInt32/2 {
 | 
				
			||||||
		return errorWrapper(errors.New("max_tokens is invalid"), "invalid_max_tokens", http.StatusBadRequest)
 | 
							return openai.ErrorWrapper(errors.New("max_tokens is invalid"), "invalid_max_tokens", http.StatusBadRequest)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	if relayMode == RelayModeModerations && textRequest.Model == "" {
 | 
						if relayMode == constant.RelayModeModerations && textRequest.Model == "" {
 | 
				
			||||||
		textRequest.Model = "text-moderation-latest"
 | 
							textRequest.Model = "text-moderation-latest"
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	if relayMode == RelayModeEmbeddings && textRequest.Model == "" {
 | 
						if relayMode == constant.RelayModeEmbeddings && textRequest.Model == "" {
 | 
				
			||||||
		textRequest.Model = c.Param("model")
 | 
							textRequest.Model = c.Param("model")
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	// request validation
 | 
						// request validation
 | 
				
			||||||
	if textRequest.Model == "" {
 | 
						if textRequest.Model == "" {
 | 
				
			||||||
		return errorWrapper(errors.New("model is required"), "required_field_missing", http.StatusBadRequest)
 | 
							return openai.ErrorWrapper(errors.New("model is required"), "required_field_missing", http.StatusBadRequest)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	switch relayMode {
 | 
						switch relayMode {
 | 
				
			||||||
	case RelayModeCompletions:
 | 
						case constant.RelayModeCompletions:
 | 
				
			||||||
		if textRequest.Prompt == "" {
 | 
							if textRequest.Prompt == "" {
 | 
				
			||||||
			return errorWrapper(errors.New("field prompt is required"), "required_field_missing", http.StatusBadRequest)
 | 
								return openai.ErrorWrapper(errors.New("field prompt is required"), "required_field_missing", http.StatusBadRequest)
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
	case RelayModeChatCompletions:
 | 
						case constant.RelayModeChatCompletions:
 | 
				
			||||||
		if textRequest.Messages == nil || len(textRequest.Messages) == 0 {
 | 
							if textRequest.Messages == nil || len(textRequest.Messages) == 0 {
 | 
				
			||||||
			return errorWrapper(errors.New("field messages is required"), "required_field_missing", http.StatusBadRequest)
 | 
								return openai.ErrorWrapper(errors.New("field messages is required"), "required_field_missing", http.StatusBadRequest)
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
	case RelayModeEmbeddings:
 | 
						case constant.RelayModeEmbeddings:
 | 
				
			||||||
	case RelayModeModerations:
 | 
						case constant.RelayModeModerations:
 | 
				
			||||||
		if textRequest.Input == "" {
 | 
							if textRequest.Input == "" {
 | 
				
			||||||
			return errorWrapper(errors.New("field input is required"), "required_field_missing", http.StatusBadRequest)
 | 
								return openai.ErrorWrapper(errors.New("field input is required"), "required_field_missing", http.StatusBadRequest)
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
	case RelayModeEdits:
 | 
						case constant.RelayModeEdits:
 | 
				
			||||||
		if textRequest.Instruction == "" {
 | 
							if textRequest.Instruction == "" {
 | 
				
			||||||
			return errorWrapper(errors.New("field instruction is required"), "required_field_missing", http.StatusBadRequest)
 | 
								return openai.ErrorWrapper(errors.New("field instruction is required"), "required_field_missing", http.StatusBadRequest)
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	// map model name
 | 
						// map model name
 | 
				
			||||||
@@ -97,7 +89,7 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 | 
				
			|||||||
		modelMap := make(map[string]string)
 | 
							modelMap := make(map[string]string)
 | 
				
			||||||
		err := json.Unmarshal([]byte(modelMapping), &modelMap)
 | 
							err := json.Unmarshal([]byte(modelMapping), &modelMap)
 | 
				
			||||||
		if err != nil {
 | 
							if err != nil {
 | 
				
			||||||
			return errorWrapper(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError)
 | 
								return openai.ErrorWrapper(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError)
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
		if modelMap[textRequest.Model] != "" {
 | 
							if modelMap[textRequest.Model] != "" {
 | 
				
			||||||
			textRequest.Model = modelMap[textRequest.Model]
 | 
								textRequest.Model = modelMap[textRequest.Model]
 | 
				
			||||||
@@ -130,12 +122,12 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 | 
				
			|||||||
	if c.GetString("base_url") != "" {
 | 
						if c.GetString("base_url") != "" {
 | 
				
			||||||
		baseURL = c.GetString("base_url")
 | 
							baseURL = c.GetString("base_url")
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	fullRequestURL := getFullRequestURL(baseURL, requestURL, channelType)
 | 
						fullRequestURL := util.GetFullRequestURL(baseURL, requestURL, channelType)
 | 
				
			||||||
	switch apiType {
 | 
						switch apiType {
 | 
				
			||||||
	case APITypeOpenAI:
 | 
						case APITypeOpenAI:
 | 
				
			||||||
		if channelType == common.ChannelTypeAzure {
 | 
							if channelType == common.ChannelTypeAzure {
 | 
				
			||||||
			// https://learn.microsoft.com/en-us/azure/cognitive-services/openai/chatgpt-quickstart?pivots=rest-api&tabs=command-line#rest-api
 | 
								// https://learn.microsoft.com/en-us/azure/cognitive-services/openai/chatgpt-quickstart?pivots=rest-api&tabs=command-line#rest-api
 | 
				
			||||||
			apiVersion := GetAPIVersion(c)
 | 
								apiVersion := util.GetAPIVersion(c)
 | 
				
			||||||
			requestURL := strings.Split(requestURL, "?")[0]
 | 
								requestURL := strings.Split(requestURL, "?")[0]
 | 
				
			||||||
			requestURL = fmt.Sprintf("%s?api-version=%s", requestURL, apiVersion)
 | 
								requestURL = fmt.Sprintf("%s?api-version=%s", requestURL, apiVersion)
 | 
				
			||||||
			baseURL = c.GetString("base_url")
 | 
								baseURL = c.GetString("base_url")
 | 
				
			||||||
@@ -148,7 +140,7 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 | 
				
			|||||||
			model_ = strings.TrimSuffix(model_, "-0613")
 | 
								model_ = strings.TrimSuffix(model_, "-0613")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
			requestURL = fmt.Sprintf("/openai/deployments/%s/%s", model_, task)
 | 
								requestURL = fmt.Sprintf("/openai/deployments/%s/%s", model_, task)
 | 
				
			||||||
			fullRequestURL = getFullRequestURL(baseURL, requestURL, channelType)
 | 
								fullRequestURL = util.GetFullRequestURL(baseURL, requestURL, channelType)
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
	case APITypeClaude:
 | 
						case APITypeClaude:
 | 
				
			||||||
		fullRequestURL = "https://api.anthropic.com/v1/complete"
 | 
							fullRequestURL = "https://api.anthropic.com/v1/complete"
 | 
				
			||||||
@@ -171,8 +163,8 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 | 
				
			|||||||
		apiKey := c.Request.Header.Get("Authorization")
 | 
							apiKey := c.Request.Header.Get("Authorization")
 | 
				
			||||||
		apiKey = strings.TrimPrefix(apiKey, "Bearer ")
 | 
							apiKey = strings.TrimPrefix(apiKey, "Bearer ")
 | 
				
			||||||
		var err error
 | 
							var err error
 | 
				
			||||||
		if apiKey, err = getBaiduAccessToken(apiKey); err != nil {
 | 
							if apiKey, err = baidu.GetAccessToken(apiKey); err != nil {
 | 
				
			||||||
			return errorWrapper(err, "invalid_baidu_config", http.StatusInternalServerError)
 | 
								return openai.ErrorWrapper(err, "invalid_baidu_config", http.StatusInternalServerError)
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
		fullRequestURL += "?access_token=" + apiKey
 | 
							fullRequestURL += "?access_token=" + apiKey
 | 
				
			||||||
	case APITypePaLM:
 | 
						case APITypePaLM:
 | 
				
			||||||
@@ -202,7 +194,7 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 | 
				
			|||||||
		fullRequestURL = fmt.Sprintf("https://open.bigmodel.cn/api/paas/v3/model-api/%s/%s", textRequest.Model, method)
 | 
							fullRequestURL = fmt.Sprintf("https://open.bigmodel.cn/api/paas/v3/model-api/%s/%s", textRequest.Model, method)
 | 
				
			||||||
	case APITypeAli:
 | 
						case APITypeAli:
 | 
				
			||||||
		fullRequestURL = "https://dashscope.aliyuncs.com/api/v1/services/aigc/text-generation/generation"
 | 
							fullRequestURL = "https://dashscope.aliyuncs.com/api/v1/services/aigc/text-generation/generation"
 | 
				
			||||||
		if relayMode == RelayModeEmbeddings {
 | 
							if relayMode == constant.RelayModeEmbeddings {
 | 
				
			||||||
			fullRequestURL = "https://dashscope.aliyuncs.com/api/v1/services/embeddings/text-embedding/text-embedding"
 | 
								fullRequestURL = "https://dashscope.aliyuncs.com/api/v1/services/embeddings/text-embedding/text-embedding"
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
	case APITypeTencent:
 | 
						case APITypeTencent:
 | 
				
			||||||
@@ -213,12 +205,12 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 | 
				
			|||||||
	var promptTokens int
 | 
						var promptTokens int
 | 
				
			||||||
	var completionTokens int
 | 
						var completionTokens int
 | 
				
			||||||
	switch relayMode {
 | 
						switch relayMode {
 | 
				
			||||||
	case RelayModeChatCompletions:
 | 
						case constant.RelayModeChatCompletions:
 | 
				
			||||||
		promptTokens = countTokenMessages(textRequest.Messages, textRequest.Model)
 | 
							promptTokens = openai.CountTokenMessages(textRequest.Messages, textRequest.Model)
 | 
				
			||||||
	case RelayModeCompletions:
 | 
						case constant.RelayModeCompletions:
 | 
				
			||||||
		promptTokens = countTokenInput(textRequest.Prompt, textRequest.Model)
 | 
							promptTokens = openai.CountTokenInput(textRequest.Prompt, textRequest.Model)
 | 
				
			||||||
	case RelayModeModerations:
 | 
						case constant.RelayModeModerations:
 | 
				
			||||||
		promptTokens = countTokenInput(textRequest.Input, textRequest.Model)
 | 
							promptTokens = openai.CountTokenInput(textRequest.Input, textRequest.Model)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	preConsumedTokens := common.PreConsumedQuota
 | 
						preConsumedTokens := common.PreConsumedQuota
 | 
				
			||||||
	if textRequest.MaxTokens != 0 {
 | 
						if textRequest.MaxTokens != 0 {
 | 
				
			||||||
@@ -230,14 +222,14 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 | 
				
			|||||||
	preConsumedQuota := int(float64(preConsumedTokens) * ratio)
 | 
						preConsumedQuota := int(float64(preConsumedTokens) * ratio)
 | 
				
			||||||
	userQuota, err := model.CacheGetUserQuota(userId)
 | 
						userQuota, err := model.CacheGetUserQuota(userId)
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return errorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError)
 | 
							return openai.ErrorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	if userQuota-preConsumedQuota < 0 {
 | 
						if userQuota-preConsumedQuota < 0 {
 | 
				
			||||||
		return errorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden)
 | 
							return openai.ErrorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	err = model.CacheDecreaseUserQuota(userId, preConsumedQuota)
 | 
						err = model.CacheDecreaseUserQuota(userId, preConsumedQuota)
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return errorWrapper(err, "decrease_user_quota_failed", http.StatusInternalServerError)
 | 
							return openai.ErrorWrapper(err, "decrease_user_quota_failed", http.StatusInternalServerError)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	if userQuota > 100*preConsumedQuota {
 | 
						if userQuota > 100*preConsumedQuota {
 | 
				
			||||||
		// in this case, we do not pre-consume quota
 | 
							// in this case, we do not pre-consume quota
 | 
				
			||||||
@@ -248,14 +240,14 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 | 
				
			|||||||
	if preConsumedQuota > 0 {
 | 
						if preConsumedQuota > 0 {
 | 
				
			||||||
		err := model.PreConsumeTokenQuota(tokenId, preConsumedQuota)
 | 
							err := model.PreConsumeTokenQuota(tokenId, preConsumedQuota)
 | 
				
			||||||
		if err != nil {
 | 
							if err != nil {
 | 
				
			||||||
			return errorWrapper(err, "pre_consume_token_quota_failed", http.StatusForbidden)
 | 
								return openai.ErrorWrapper(err, "pre_consume_token_quota_failed", http.StatusForbidden)
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	var requestBody io.Reader
 | 
						var requestBody io.Reader
 | 
				
			||||||
	if isModelMapped {
 | 
						if isModelMapped {
 | 
				
			||||||
		jsonStr, err := json.Marshal(textRequest)
 | 
							jsonStr, err := json.Marshal(textRequest)
 | 
				
			||||||
		if err != nil {
 | 
							if err != nil {
 | 
				
			||||||
			return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
 | 
								return openai.ErrorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
		requestBody = bytes.NewBuffer(jsonStr)
 | 
							requestBody = bytes.NewBuffer(jsonStr)
 | 
				
			||||||
	} else {
 | 
						} else {
 | 
				
			||||||
@@ -263,86 +255,86 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 | 
				
			|||||||
	}
 | 
						}
 | 
				
			||||||
	switch apiType {
 | 
						switch apiType {
 | 
				
			||||||
	case APITypeClaude:
 | 
						case APITypeClaude:
 | 
				
			||||||
		claudeRequest := requestOpenAI2Claude(textRequest)
 | 
							claudeRequest := anthropic.ConvertRequest(textRequest)
 | 
				
			||||||
		jsonStr, err := json.Marshal(claudeRequest)
 | 
							jsonStr, err := json.Marshal(claudeRequest)
 | 
				
			||||||
		if err != nil {
 | 
							if err != nil {
 | 
				
			||||||
			return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
 | 
								return openai.ErrorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
		requestBody = bytes.NewBuffer(jsonStr)
 | 
							requestBody = bytes.NewBuffer(jsonStr)
 | 
				
			||||||
	case APITypeBaidu:
 | 
						case APITypeBaidu:
 | 
				
			||||||
		var jsonData []byte
 | 
							var jsonData []byte
 | 
				
			||||||
		var err error
 | 
							var err error
 | 
				
			||||||
		switch relayMode {
 | 
							switch relayMode {
 | 
				
			||||||
		case RelayModeEmbeddings:
 | 
							case constant.RelayModeEmbeddings:
 | 
				
			||||||
			baiduEmbeddingRequest := embeddingRequestOpenAI2Baidu(textRequest)
 | 
								baiduEmbeddingRequest := baidu.ConvertEmbeddingRequest(textRequest)
 | 
				
			||||||
			jsonData, err = json.Marshal(baiduEmbeddingRequest)
 | 
								jsonData, err = json.Marshal(baiduEmbeddingRequest)
 | 
				
			||||||
		default:
 | 
							default:
 | 
				
			||||||
			baiduRequest := requestOpenAI2Baidu(textRequest)
 | 
								baiduRequest := baidu.ConvertRequest(textRequest)
 | 
				
			||||||
			jsonData, err = json.Marshal(baiduRequest)
 | 
								jsonData, err = json.Marshal(baiduRequest)
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
		if err != nil {
 | 
							if err != nil {
 | 
				
			||||||
			return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
 | 
								return openai.ErrorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
		requestBody = bytes.NewBuffer(jsonData)
 | 
							requestBody = bytes.NewBuffer(jsonData)
 | 
				
			||||||
	case APITypePaLM:
 | 
						case APITypePaLM:
 | 
				
			||||||
		palmRequest := requestOpenAI2PaLM(textRequest)
 | 
							palmRequest := google.ConvertPaLMRequest(textRequest)
 | 
				
			||||||
		jsonStr, err := json.Marshal(palmRequest)
 | 
							jsonStr, err := json.Marshal(palmRequest)
 | 
				
			||||||
		if err != nil {
 | 
							if err != nil {
 | 
				
			||||||
			return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
 | 
								return openai.ErrorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
		requestBody = bytes.NewBuffer(jsonStr)
 | 
							requestBody = bytes.NewBuffer(jsonStr)
 | 
				
			||||||
	case APITypeGemini:
 | 
						case APITypeGemini:
 | 
				
			||||||
		geminiChatRequest := requestOpenAI2Gemini(textRequest)
 | 
							geminiChatRequest := google.ConvertGeminiRequest(textRequest)
 | 
				
			||||||
		jsonStr, err := json.Marshal(geminiChatRequest)
 | 
							jsonStr, err := json.Marshal(geminiChatRequest)
 | 
				
			||||||
		if err != nil {
 | 
							if err != nil {
 | 
				
			||||||
			return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
 | 
								return openai.ErrorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
		requestBody = bytes.NewBuffer(jsonStr)
 | 
							requestBody = bytes.NewBuffer(jsonStr)
 | 
				
			||||||
	case APITypeZhipu:
 | 
						case APITypeZhipu:
 | 
				
			||||||
		zhipuRequest := requestOpenAI2Zhipu(textRequest)
 | 
							zhipuRequest := zhipu.ConvertRequest(textRequest)
 | 
				
			||||||
		jsonStr, err := json.Marshal(zhipuRequest)
 | 
							jsonStr, err := json.Marshal(zhipuRequest)
 | 
				
			||||||
		if err != nil {
 | 
							if err != nil {
 | 
				
			||||||
			return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
 | 
								return openai.ErrorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
		requestBody = bytes.NewBuffer(jsonStr)
 | 
							requestBody = bytes.NewBuffer(jsonStr)
 | 
				
			||||||
	case APITypeAli:
 | 
						case APITypeAli:
 | 
				
			||||||
		var jsonStr []byte
 | 
							var jsonStr []byte
 | 
				
			||||||
		var err error
 | 
							var err error
 | 
				
			||||||
		switch relayMode {
 | 
							switch relayMode {
 | 
				
			||||||
		case RelayModeEmbeddings:
 | 
							case constant.RelayModeEmbeddings:
 | 
				
			||||||
			aliEmbeddingRequest := embeddingRequestOpenAI2Ali(textRequest)
 | 
								aliEmbeddingRequest := ali.ConvertEmbeddingRequest(textRequest)
 | 
				
			||||||
			jsonStr, err = json.Marshal(aliEmbeddingRequest)
 | 
								jsonStr, err = json.Marshal(aliEmbeddingRequest)
 | 
				
			||||||
		default:
 | 
							default:
 | 
				
			||||||
			aliRequest := requestOpenAI2Ali(textRequest)
 | 
								aliRequest := ali.ConvertRequest(textRequest)
 | 
				
			||||||
			jsonStr, err = json.Marshal(aliRequest)
 | 
								jsonStr, err = json.Marshal(aliRequest)
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
		if err != nil {
 | 
							if err != nil {
 | 
				
			||||||
			return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
 | 
								return openai.ErrorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
		requestBody = bytes.NewBuffer(jsonStr)
 | 
							requestBody = bytes.NewBuffer(jsonStr)
 | 
				
			||||||
	case APITypeTencent:
 | 
						case APITypeTencent:
 | 
				
			||||||
		apiKey := c.Request.Header.Get("Authorization")
 | 
							apiKey := c.Request.Header.Get("Authorization")
 | 
				
			||||||
		apiKey = strings.TrimPrefix(apiKey, "Bearer ")
 | 
							apiKey = strings.TrimPrefix(apiKey, "Bearer ")
 | 
				
			||||||
		appId, secretId, secretKey, err := parseTencentConfig(apiKey)
 | 
							appId, secretId, secretKey, err := tencent.ParseConfig(apiKey)
 | 
				
			||||||
		if err != nil {
 | 
							if err != nil {
 | 
				
			||||||
			return errorWrapper(err, "invalid_tencent_config", http.StatusInternalServerError)
 | 
								return openai.ErrorWrapper(err, "invalid_tencent_config", http.StatusInternalServerError)
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
		tencentRequest := requestOpenAI2Tencent(textRequest)
 | 
							tencentRequest := tencent.ConvertRequest(textRequest)
 | 
				
			||||||
		tencentRequest.AppId = appId
 | 
							tencentRequest.AppId = appId
 | 
				
			||||||
		tencentRequest.SecretId = secretId
 | 
							tencentRequest.SecretId = secretId
 | 
				
			||||||
		jsonStr, err := json.Marshal(tencentRequest)
 | 
							jsonStr, err := json.Marshal(tencentRequest)
 | 
				
			||||||
		if err != nil {
 | 
							if err != nil {
 | 
				
			||||||
			return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
 | 
								return openai.ErrorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
		sign := getTencentSign(*tencentRequest, secretKey)
 | 
							sign := tencent.GetSign(*tencentRequest, secretKey)
 | 
				
			||||||
		c.Request.Header.Set("Authorization", sign)
 | 
							c.Request.Header.Set("Authorization", sign)
 | 
				
			||||||
		requestBody = bytes.NewBuffer(jsonStr)
 | 
							requestBody = bytes.NewBuffer(jsonStr)
 | 
				
			||||||
	case APITypeAIProxyLibrary:
 | 
						case APITypeAIProxyLibrary:
 | 
				
			||||||
		aiProxyLibraryRequest := requestOpenAI2AIProxyLibrary(textRequest)
 | 
							aiProxyLibraryRequest := aiproxy.ConvertRequest(textRequest)
 | 
				
			||||||
		aiProxyLibraryRequest.LibraryId = c.GetString("library_id")
 | 
							aiProxyLibraryRequest.LibraryId = c.GetString("library_id")
 | 
				
			||||||
		jsonStr, err := json.Marshal(aiProxyLibraryRequest)
 | 
							jsonStr, err := json.Marshal(aiProxyLibraryRequest)
 | 
				
			||||||
		if err != nil {
 | 
							if err != nil {
 | 
				
			||||||
			return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
 | 
								return openai.ErrorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
		requestBody = bytes.NewBuffer(jsonStr)
 | 
							requestBody = bytes.NewBuffer(jsonStr)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
@@ -354,7 +346,7 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 | 
				
			|||||||
	if apiType != APITypeXunfei { // cause xunfei use websocket
 | 
						if apiType != APITypeXunfei { // cause xunfei use websocket
 | 
				
			||||||
		req, err = http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
 | 
							req, err = http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
 | 
				
			||||||
		if err != nil {
 | 
							if err != nil {
 | 
				
			||||||
			return errorWrapper(err, "new_request_failed", http.StatusInternalServerError)
 | 
								return openai.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError)
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
		apiKey := c.Request.Header.Get("Authorization")
 | 
							apiKey := c.Request.Header.Get("Authorization")
 | 
				
			||||||
		apiKey = strings.TrimPrefix(apiKey, "Bearer ")
 | 
							apiKey = strings.TrimPrefix(apiKey, "Bearer ")
 | 
				
			||||||
@@ -377,7 +369,7 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 | 
				
			|||||||
			}
 | 
								}
 | 
				
			||||||
			req.Header.Set("anthropic-version", anthropicVersion)
 | 
								req.Header.Set("anthropic-version", anthropicVersion)
 | 
				
			||||||
		case APITypeZhipu:
 | 
							case APITypeZhipu:
 | 
				
			||||||
			token := getZhipuToken(apiKey)
 | 
								token := zhipu.GetToken(apiKey)
 | 
				
			||||||
			req.Header.Set("Authorization", token)
 | 
								req.Header.Set("Authorization", token)
 | 
				
			||||||
		case APITypeAli:
 | 
							case APITypeAli:
 | 
				
			||||||
			req.Header.Set("Authorization", "Bearer "+apiKey)
 | 
								req.Header.Set("Authorization", "Bearer "+apiKey)
 | 
				
			||||||
@@ -402,17 +394,17 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 | 
				
			|||||||
			req.Header.Set("Accept", "text/event-stream")
 | 
								req.Header.Set("Accept", "text/event-stream")
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
		//req.Header.Set("Connection", c.Request.Header.Get("Connection"))
 | 
							//req.Header.Set("Connection", c.Request.Header.Get("Connection"))
 | 
				
			||||||
		resp, err = httpClient.Do(req)
 | 
							resp, err = util.HTTPClient.Do(req)
 | 
				
			||||||
		if err != nil {
 | 
							if err != nil {
 | 
				
			||||||
			return errorWrapper(err, "do_request_failed", http.StatusInternalServerError)
 | 
								return openai.ErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
		err = req.Body.Close()
 | 
							err = req.Body.Close()
 | 
				
			||||||
		if err != nil {
 | 
							if err != nil {
 | 
				
			||||||
			return errorWrapper(err, "close_request_body_failed", http.StatusInternalServerError)
 | 
								return openai.ErrorWrapper(err, "close_request_body_failed", http.StatusInternalServerError)
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
		err = c.Request.Body.Close()
 | 
							err = c.Request.Body.Close()
 | 
				
			||||||
		if err != nil {
 | 
							if err != nil {
 | 
				
			||||||
			return errorWrapper(err, "close_request_body_failed", http.StatusInternalServerError)
 | 
								return openai.ErrorWrapper(err, "close_request_body_failed", http.StatusInternalServerError)
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
		isStream = isStream || strings.HasPrefix(resp.Header.Get("Content-Type"), "text/event-stream")
 | 
							isStream = isStream || strings.HasPrefix(resp.Header.Get("Content-Type"), "text/event-stream")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -426,11 +418,11 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 | 
				
			|||||||
					}
 | 
										}
 | 
				
			||||||
				}(c.Request.Context())
 | 
									}(c.Request.Context())
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
			return relayErrorHandler(resp)
 | 
								return util.RelayErrorHandler(resp)
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	var textResponse TextResponse
 | 
						var textResponse openai.SlimTextResponse
 | 
				
			||||||
	tokenName := c.GetString("token_name")
 | 
						tokenName := c.GetString("token_name")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	defer func(ctx context.Context) {
 | 
						defer func(ctx context.Context) {
 | 
				
			||||||
@@ -471,15 +463,15 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 | 
				
			|||||||
	switch apiType {
 | 
						switch apiType {
 | 
				
			||||||
	case APITypeOpenAI:
 | 
						case APITypeOpenAI:
 | 
				
			||||||
		if isStream {
 | 
							if isStream {
 | 
				
			||||||
			err, responseText := openaiStreamHandler(c, resp, relayMode)
 | 
								err, responseText := openai.StreamHandler(c, resp, relayMode)
 | 
				
			||||||
			if err != nil {
 | 
								if err != nil {
 | 
				
			||||||
				return err
 | 
									return err
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
			textResponse.Usage.PromptTokens = promptTokens
 | 
								textResponse.Usage.PromptTokens = promptTokens
 | 
				
			||||||
			textResponse.Usage.CompletionTokens = countTokenText(responseText, textRequest.Model)
 | 
								textResponse.Usage.CompletionTokens = openai.CountTokenText(responseText, textRequest.Model)
 | 
				
			||||||
			return nil
 | 
								return nil
 | 
				
			||||||
		} else {
 | 
							} else {
 | 
				
			||||||
			err, usage := openaiHandler(c, resp, promptTokens, textRequest.Model)
 | 
								err, usage := openai.Handler(c, resp, promptTokens, textRequest.Model)
 | 
				
			||||||
			if err != nil {
 | 
								if err != nil {
 | 
				
			||||||
				return err
 | 
									return err
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
@@ -490,15 +482,15 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 | 
				
			|||||||
		}
 | 
							}
 | 
				
			||||||
	case APITypeClaude:
 | 
						case APITypeClaude:
 | 
				
			||||||
		if isStream {
 | 
							if isStream {
 | 
				
			||||||
			err, responseText := claudeStreamHandler(c, resp)
 | 
								err, responseText := anthropic.StreamHandler(c, resp)
 | 
				
			||||||
			if err != nil {
 | 
								if err != nil {
 | 
				
			||||||
				return err
 | 
									return err
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
			textResponse.Usage.PromptTokens = promptTokens
 | 
								textResponse.Usage.PromptTokens = promptTokens
 | 
				
			||||||
			textResponse.Usage.CompletionTokens = countTokenText(responseText, textRequest.Model)
 | 
								textResponse.Usage.CompletionTokens = openai.CountTokenText(responseText, textRequest.Model)
 | 
				
			||||||
			return nil
 | 
								return nil
 | 
				
			||||||
		} else {
 | 
							} else {
 | 
				
			||||||
			err, usage := claudeHandler(c, resp, promptTokens, textRequest.Model)
 | 
								err, usage := anthropic.Handler(c, resp, promptTokens, textRequest.Model)
 | 
				
			||||||
			if err != nil {
 | 
								if err != nil {
 | 
				
			||||||
				return err
 | 
									return err
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
@@ -509,7 +501,7 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 | 
				
			|||||||
		}
 | 
							}
 | 
				
			||||||
	case APITypeBaidu:
 | 
						case APITypeBaidu:
 | 
				
			||||||
		if isStream {
 | 
							if isStream {
 | 
				
			||||||
			err, usage := baiduStreamHandler(c, resp)
 | 
								err, usage := baidu.StreamHandler(c, resp)
 | 
				
			||||||
			if err != nil {
 | 
								if err != nil {
 | 
				
			||||||
				return err
 | 
									return err
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
@@ -518,13 +510,13 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 | 
				
			|||||||
			}
 | 
								}
 | 
				
			||||||
			return nil
 | 
								return nil
 | 
				
			||||||
		} else {
 | 
							} else {
 | 
				
			||||||
			var err *OpenAIErrorWithStatusCode
 | 
								var err *openai.ErrorWithStatusCode
 | 
				
			||||||
			var usage *Usage
 | 
								var usage *openai.Usage
 | 
				
			||||||
			switch relayMode {
 | 
								switch relayMode {
 | 
				
			||||||
			case RelayModeEmbeddings:
 | 
								case constant.RelayModeEmbeddings:
 | 
				
			||||||
				err, usage = baiduEmbeddingHandler(c, resp)
 | 
									err, usage = baidu.EmbeddingHandler(c, resp)
 | 
				
			||||||
			default:
 | 
								default:
 | 
				
			||||||
				err, usage = baiduHandler(c, resp)
 | 
									err, usage = baidu.Handler(c, resp)
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
			if err != nil {
 | 
								if err != nil {
 | 
				
			||||||
				return err
 | 
									return err
 | 
				
			||||||
@@ -536,15 +528,15 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 | 
				
			|||||||
		}
 | 
							}
 | 
				
			||||||
	case APITypePaLM:
 | 
						case APITypePaLM:
 | 
				
			||||||
		if textRequest.Stream { // PaLM2 API does not support stream
 | 
							if textRequest.Stream { // PaLM2 API does not support stream
 | 
				
			||||||
			err, responseText := palmStreamHandler(c, resp)
 | 
								err, responseText := google.PaLMStreamHandler(c, resp)
 | 
				
			||||||
			if err != nil {
 | 
								if err != nil {
 | 
				
			||||||
				return err
 | 
									return err
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
			textResponse.Usage.PromptTokens = promptTokens
 | 
								textResponse.Usage.PromptTokens = promptTokens
 | 
				
			||||||
			textResponse.Usage.CompletionTokens = countTokenText(responseText, textRequest.Model)
 | 
								textResponse.Usage.CompletionTokens = openai.CountTokenText(responseText, textRequest.Model)
 | 
				
			||||||
			return nil
 | 
								return nil
 | 
				
			||||||
		} else {
 | 
							} else {
 | 
				
			||||||
			err, usage := palmHandler(c, resp, promptTokens, textRequest.Model)
 | 
								err, usage := google.PaLMHandler(c, resp, promptTokens, textRequest.Model)
 | 
				
			||||||
			if err != nil {
 | 
								if err != nil {
 | 
				
			||||||
				return err
 | 
									return err
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
@@ -555,15 +547,15 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 | 
				
			|||||||
		}
 | 
							}
 | 
				
			||||||
	case APITypeGemini:
 | 
						case APITypeGemini:
 | 
				
			||||||
		if textRequest.Stream {
 | 
							if textRequest.Stream {
 | 
				
			||||||
			err, responseText := geminiChatStreamHandler(c, resp)
 | 
								err, responseText := google.StreamHandler(c, resp)
 | 
				
			||||||
			if err != nil {
 | 
								if err != nil {
 | 
				
			||||||
				return err
 | 
									return err
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
			textResponse.Usage.PromptTokens = promptTokens
 | 
								textResponse.Usage.PromptTokens = promptTokens
 | 
				
			||||||
			textResponse.Usage.CompletionTokens = countTokenText(responseText, textRequest.Model)
 | 
								textResponse.Usage.CompletionTokens = openai.CountTokenText(responseText, textRequest.Model)
 | 
				
			||||||
			return nil
 | 
								return nil
 | 
				
			||||||
		} else {
 | 
							} else {
 | 
				
			||||||
			err, usage := geminiChatHandler(c, resp, promptTokens, textRequest.Model)
 | 
								err, usage := google.GeminiHandler(c, resp, promptTokens, textRequest.Model)
 | 
				
			||||||
			if err != nil {
 | 
								if err != nil {
 | 
				
			||||||
				return err
 | 
									return err
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
@@ -574,7 +566,7 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 | 
				
			|||||||
		}
 | 
							}
 | 
				
			||||||
	case APITypeZhipu:
 | 
						case APITypeZhipu:
 | 
				
			||||||
		if isStream {
 | 
							if isStream {
 | 
				
			||||||
			err, usage := zhipuStreamHandler(c, resp)
 | 
								err, usage := zhipu.StreamHandler(c, resp)
 | 
				
			||||||
			if err != nil {
 | 
								if err != nil {
 | 
				
			||||||
				return err
 | 
									return err
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
@@ -585,7 +577,7 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 | 
				
			|||||||
			textResponse.Usage.PromptTokens = textResponse.Usage.TotalTokens
 | 
								textResponse.Usage.PromptTokens = textResponse.Usage.TotalTokens
 | 
				
			||||||
			return nil
 | 
								return nil
 | 
				
			||||||
		} else {
 | 
							} else {
 | 
				
			||||||
			err, usage := zhipuHandler(c, resp)
 | 
								err, usage := zhipu.Handler(c, resp)
 | 
				
			||||||
			if err != nil {
 | 
								if err != nil {
 | 
				
			||||||
				return err
 | 
									return err
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
@@ -598,7 +590,7 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 | 
				
			|||||||
		}
 | 
							}
 | 
				
			||||||
	case APITypeAli:
 | 
						case APITypeAli:
 | 
				
			||||||
		if isStream {
 | 
							if isStream {
 | 
				
			||||||
			err, usage := aliStreamHandler(c, resp)
 | 
								err, usage := ali.StreamHandler(c, resp)
 | 
				
			||||||
			if err != nil {
 | 
								if err != nil {
 | 
				
			||||||
				return err
 | 
									return err
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
@@ -607,13 +599,13 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 | 
				
			|||||||
			}
 | 
								}
 | 
				
			||||||
			return nil
 | 
								return nil
 | 
				
			||||||
		} else {
 | 
							} else {
 | 
				
			||||||
			var err *OpenAIErrorWithStatusCode
 | 
								var err *openai.ErrorWithStatusCode
 | 
				
			||||||
			var usage *Usage
 | 
								var usage *openai.Usage
 | 
				
			||||||
			switch relayMode {
 | 
								switch relayMode {
 | 
				
			||||||
			case RelayModeEmbeddings:
 | 
								case constant.RelayModeEmbeddings:
 | 
				
			||||||
				err, usage = aliEmbeddingHandler(c, resp)
 | 
									err, usage = ali.EmbeddingHandler(c, resp)
 | 
				
			||||||
			default:
 | 
								default:
 | 
				
			||||||
				err, usage = aliHandler(c, resp)
 | 
									err, usage = ali.Handler(c, resp)
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
			if err != nil {
 | 
								if err != nil {
 | 
				
			||||||
				return err
 | 
									return err
 | 
				
			||||||
@@ -628,14 +620,14 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 | 
				
			|||||||
		auth = strings.TrimPrefix(auth, "Bearer ")
 | 
							auth = strings.TrimPrefix(auth, "Bearer ")
 | 
				
			||||||
		splits := strings.Split(auth, "|")
 | 
							splits := strings.Split(auth, "|")
 | 
				
			||||||
		if len(splits) != 3 {
 | 
							if len(splits) != 3 {
 | 
				
			||||||
			return errorWrapper(errors.New("invalid auth"), "invalid_auth", http.StatusBadRequest)
 | 
								return openai.ErrorWrapper(errors.New("invalid auth"), "invalid_auth", http.StatusBadRequest)
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
		var err *OpenAIErrorWithStatusCode
 | 
							var err *openai.ErrorWithStatusCode
 | 
				
			||||||
		var usage *Usage
 | 
							var usage *openai.Usage
 | 
				
			||||||
		if isStream {
 | 
							if isStream {
 | 
				
			||||||
			err, usage = xunfeiStreamHandler(c, textRequest, splits[0], splits[1], splits[2])
 | 
								err, usage = xunfei.StreamHandler(c, textRequest, splits[0], splits[1], splits[2])
 | 
				
			||||||
		} else {
 | 
							} else {
 | 
				
			||||||
			err, usage = xunfeiHandler(c, textRequest, splits[0], splits[1], splits[2])
 | 
								err, usage = xunfei.Handler(c, textRequest, splits[0], splits[1], splits[2])
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
		if err != nil {
 | 
							if err != nil {
 | 
				
			||||||
			return err
 | 
								return err
 | 
				
			||||||
@@ -646,7 +638,7 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 | 
				
			|||||||
		return nil
 | 
							return nil
 | 
				
			||||||
	case APITypeAIProxyLibrary:
 | 
						case APITypeAIProxyLibrary:
 | 
				
			||||||
		if isStream {
 | 
							if isStream {
 | 
				
			||||||
			err, usage := aiProxyLibraryStreamHandler(c, resp)
 | 
								err, usage := aiproxy.StreamHandler(c, resp)
 | 
				
			||||||
			if err != nil {
 | 
								if err != nil {
 | 
				
			||||||
				return err
 | 
									return err
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
@@ -655,7 +647,7 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 | 
				
			|||||||
			}
 | 
								}
 | 
				
			||||||
			return nil
 | 
								return nil
 | 
				
			||||||
		} else {
 | 
							} else {
 | 
				
			||||||
			err, usage := aiProxyLibraryHandler(c, resp)
 | 
								err, usage := aiproxy.Handler(c, resp)
 | 
				
			||||||
			if err != nil {
 | 
								if err != nil {
 | 
				
			||||||
				return err
 | 
									return err
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
@@ -666,15 +658,15 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 | 
				
			|||||||
		}
 | 
							}
 | 
				
			||||||
	case APITypeTencent:
 | 
						case APITypeTencent:
 | 
				
			||||||
		if isStream {
 | 
							if isStream {
 | 
				
			||||||
			err, responseText := tencentStreamHandler(c, resp)
 | 
								err, responseText := tencent.StreamHandler(c, resp)
 | 
				
			||||||
			if err != nil {
 | 
								if err != nil {
 | 
				
			||||||
				return err
 | 
									return err
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
			textResponse.Usage.PromptTokens = promptTokens
 | 
								textResponse.Usage.PromptTokens = promptTokens
 | 
				
			||||||
			textResponse.Usage.CompletionTokens = countTokenText(responseText, textRequest.Model)
 | 
								textResponse.Usage.CompletionTokens = openai.CountTokenText(responseText, textRequest.Model)
 | 
				
			||||||
			return nil
 | 
								return nil
 | 
				
			||||||
		} else {
 | 
							} else {
 | 
				
			||||||
			err, usage := tencentHandler(c, resp)
 | 
								err, usage := tencent.Handler(c, resp)
 | 
				
			||||||
			if err != nil {
 | 
								if err != nil {
 | 
				
			||||||
				return err
 | 
									return err
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
@@ -684,6 +676,6 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 | 
				
			|||||||
			return nil
 | 
								return nil
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
	default:
 | 
						default:
 | 
				
			||||||
		return errorWrapper(errors.New("unknown api type"), "unknown_api_type", http.StatusInternalServerError)
 | 
							return openai.ErrorWrapper(errors.New("unknown api type"), "unknown_api_type", http.StatusInternalServerError)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
							
								
								
									
										166
									
								
								relay/util/common.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										166
									
								
								relay/util/common.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,166 @@
 | 
				
			|||||||
 | 
					package util
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import (
 | 
				
			||||||
 | 
						"context"
 | 
				
			||||||
 | 
						"encoding/json"
 | 
				
			||||||
 | 
						"fmt"
 | 
				
			||||||
 | 
						"io"
 | 
				
			||||||
 | 
						"net/http"
 | 
				
			||||||
 | 
						"one-api/common"
 | 
				
			||||||
 | 
						"one-api/model"
 | 
				
			||||||
 | 
						"one-api/relay/channel/openai"
 | 
				
			||||||
 | 
						"strconv"
 | 
				
			||||||
 | 
						"strings"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						"github.com/gin-gonic/gin"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func ShouldDisableChannel(err *openai.Error, statusCode int) bool {
 | 
				
			||||||
 | 
						if !common.AutomaticDisableChannelEnabled {
 | 
				
			||||||
 | 
							return false
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						if err == nil {
 | 
				
			||||||
 | 
							return false
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						if statusCode == http.StatusUnauthorized {
 | 
				
			||||||
 | 
							return true
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						if err.Type == "insufficient_quota" || err.Code == "invalid_api_key" || err.Code == "account_deactivated" {
 | 
				
			||||||
 | 
							return true
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return false
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func ShouldEnableChannel(err error, openAIErr *openai.Error) bool {
 | 
				
			||||||
 | 
						if !common.AutomaticEnableChannelEnabled {
 | 
				
			||||||
 | 
							return false
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return false
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						if openAIErr != nil {
 | 
				
			||||||
 | 
							return false
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return true
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type GeneralErrorResponse struct {
 | 
				
			||||||
 | 
						Error    openai.Error `json:"error"`
 | 
				
			||||||
 | 
						Message  string       `json:"message"`
 | 
				
			||||||
 | 
						Msg      string       `json:"msg"`
 | 
				
			||||||
 | 
						Err      string       `json:"err"`
 | 
				
			||||||
 | 
						ErrorMsg string       `json:"error_msg"`
 | 
				
			||||||
 | 
						Header   struct {
 | 
				
			||||||
 | 
							Message string `json:"message"`
 | 
				
			||||||
 | 
						} `json:"header"`
 | 
				
			||||||
 | 
						Response struct {
 | 
				
			||||||
 | 
							Error struct {
 | 
				
			||||||
 | 
								Message string `json:"message"`
 | 
				
			||||||
 | 
							} `json:"error"`
 | 
				
			||||||
 | 
						} `json:"response"`
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (e GeneralErrorResponse) ToMessage() string {
 | 
				
			||||||
 | 
						if e.Error.Message != "" {
 | 
				
			||||||
 | 
							return e.Error.Message
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						if e.Message != "" {
 | 
				
			||||||
 | 
							return e.Message
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						if e.Msg != "" {
 | 
				
			||||||
 | 
							return e.Msg
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						if e.Err != "" {
 | 
				
			||||||
 | 
							return e.Err
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						if e.ErrorMsg != "" {
 | 
				
			||||||
 | 
							return e.ErrorMsg
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						if e.Header.Message != "" {
 | 
				
			||||||
 | 
							return e.Header.Message
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						if e.Response.Error.Message != "" {
 | 
				
			||||||
 | 
							return e.Response.Error.Message
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return ""
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func RelayErrorHandler(resp *http.Response) (ErrorWithStatusCode *openai.ErrorWithStatusCode) {
 | 
				
			||||||
 | 
						ErrorWithStatusCode = &openai.ErrorWithStatusCode{
 | 
				
			||||||
 | 
							StatusCode: resp.StatusCode,
 | 
				
			||||||
 | 
							Error: openai.Error{
 | 
				
			||||||
 | 
								Message: "",
 | 
				
			||||||
 | 
								Type:    "upstream_error",
 | 
				
			||||||
 | 
								Code:    "bad_response_status_code",
 | 
				
			||||||
 | 
								Param:   strconv.Itoa(resp.StatusCode),
 | 
				
			||||||
 | 
							},
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						responseBody, err := io.ReadAll(resp.Body)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						err = resp.Body.Close()
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						var errResponse GeneralErrorResponse
 | 
				
			||||||
 | 
						err = json.Unmarshal(responseBody, &errResponse)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						if errResponse.Error.Message != "" {
 | 
				
			||||||
 | 
							// OpenAI format error, so we override the default one
 | 
				
			||||||
 | 
							ErrorWithStatusCode.Error = errResponse.Error
 | 
				
			||||||
 | 
						} else {
 | 
				
			||||||
 | 
							ErrorWithStatusCode.Error.Message = errResponse.ToMessage()
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						if ErrorWithStatusCode.Error.Message == "" {
 | 
				
			||||||
 | 
							ErrorWithStatusCode.Error.Message = fmt.Sprintf("bad response status code %d", resp.StatusCode)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func GetFullRequestURL(baseURL string, requestURL string, channelType int) string {
 | 
				
			||||||
 | 
						fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if strings.HasPrefix(baseURL, "https://gateway.ai.cloudflare.com") {
 | 
				
			||||||
 | 
							switch channelType {
 | 
				
			||||||
 | 
							case common.ChannelTypeOpenAI:
 | 
				
			||||||
 | 
								fullRequestURL = fmt.Sprintf("%s%s", baseURL, strings.TrimPrefix(requestURL, "/v1"))
 | 
				
			||||||
 | 
							case common.ChannelTypeAzure:
 | 
				
			||||||
 | 
								fullRequestURL = fmt.Sprintf("%s%s", baseURL, strings.TrimPrefix(requestURL, "/openai/deployments"))
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return fullRequestURL
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func PostConsumeQuota(ctx context.Context, tokenId int, quotaDelta int, totalQuota int, userId int, channelId int, modelRatio float64, groupRatio float64, modelName string, tokenName string) {
 | 
				
			||||||
 | 
						// quotaDelta is remaining quota to be consumed
 | 
				
			||||||
 | 
						err := model.PostConsumeTokenQuota(tokenId, quotaDelta)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							common.SysError("error consuming token remain quota: " + err.Error())
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						err = model.CacheUpdateUserQuota(userId)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							common.SysError("error update user quota cache: " + err.Error())
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						// totalQuota is total quota consumed
 | 
				
			||||||
 | 
						if totalQuota != 0 {
 | 
				
			||||||
 | 
							logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
 | 
				
			||||||
 | 
							model.RecordConsumeLog(ctx, userId, channelId, totalQuota, 0, modelName, tokenName, totalQuota, logContent)
 | 
				
			||||||
 | 
							model.UpdateUserUsedQuotaAndRequestCount(userId, totalQuota)
 | 
				
			||||||
 | 
							model.UpdateChannelUsedQuota(channelId, totalQuota)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						if totalQuota <= 0 {
 | 
				
			||||||
 | 
							common.LogError(ctx, fmt.Sprintf("totalQuota consumed is %d, something is wrong", totalQuota))
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func GetAPIVersion(c *gin.Context) string {
 | 
				
			||||||
 | 
						query := c.Request.URL.Query()
 | 
				
			||||||
 | 
						apiVersion := query.Get("api-version")
 | 
				
			||||||
 | 
						if apiVersion == "" {
 | 
				
			||||||
 | 
							apiVersion = c.GetString("api_version")
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return apiVersion
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
							
								
								
									
										24
									
								
								relay/util/init.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										24
									
								
								relay/util/init.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,24 @@
 | 
				
			|||||||
 | 
					package util
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import (
 | 
				
			||||||
 | 
						"net/http"
 | 
				
			||||||
 | 
						"one-api/common"
 | 
				
			||||||
 | 
						"time"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					var HTTPClient *http.Client
 | 
				
			||||||
 | 
					var ImpatientHTTPClient *http.Client
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func init() {
 | 
				
			||||||
 | 
						if common.RelayTimeout == 0 {
 | 
				
			||||||
 | 
							HTTPClient = &http.Client{}
 | 
				
			||||||
 | 
						} else {
 | 
				
			||||||
 | 
							HTTPClient = &http.Client{
 | 
				
			||||||
 | 
								Timeout: time.Duration(common.RelayTimeout) * time.Second,
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						ImpatientHTTPClient = &http.Client{
 | 
				
			||||||
 | 
							Timeout: 5 * time.Second,
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
@@ -8,7 +8,9 @@
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
1. 在 `web` 文件夹下新建一个文件夹,文件夹名为主题名。
 | 
					1. 在 `web` 文件夹下新建一个文件夹,文件夹名为主题名。
 | 
				
			||||||
2. 把你的主题文件放到这个文件夹下。
 | 
					2. 把你的主题文件放到这个文件夹下。
 | 
				
			||||||
3. 修改 `package.json` 文件,把 `build` 命令改为:`"build": "react-scripts build && mv -f build ../build/default"`,其中 `default` 为你的主题名。
 | 
					3. 修改你的 `package.json` 文件,把 `build` 命令改为:`"build": "react-scripts build && mv -f build ../build/default"`,其中 `default` 为你的主题名。
 | 
				
			||||||
 | 
					4. 修改 `common/constants.go` 中的 `ValidThemes`,把你的主题名称注册进去。
 | 
				
			||||||
 | 
					5. 修改 `web/THEMES` 文件,这里也需要同步修改。
 | 
				
			||||||
 | 
					
 | 
				
			||||||
## 主题列表
 | 
					## 主题列表
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -26,7 +26,7 @@ const MinimalLayout = () => {
 | 
				
			|||||||
          <Header />
 | 
					          <Header />
 | 
				
			||||||
        </Toolbar>
 | 
					        </Toolbar>
 | 
				
			||||||
      </AppBar>
 | 
					      </AppBar>
 | 
				
			||||||
      <Box sx={{ flex: '1 1 auto', overflow: 'auto' }} paddingTop={'64px'}>
 | 
					      <Box sx={{ flex: '1 1 auto', overflow: 'auto' }} marginTop={'80px'}>
 | 
				
			||||||
        <Outlet />
 | 
					        <Outlet />
 | 
				
			||||||
      </Box>
 | 
					      </Box>
 | 
				
			||||||
      <Box sx={{ flex: 'none' }}>
 | 
					      <Box sx={{ flex: 'none' }}>
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -87,7 +87,7 @@ const panel = {
 | 
				
			|||||||
      url: '/panel/profile',
 | 
					      url: '/panel/profile',
 | 
				
			||||||
      icon: icons.IconUserScan,
 | 
					      icon: icons.IconUserScan,
 | 
				
			||||||
      breadcrumbs: false,
 | 
					      breadcrumbs: false,
 | 
				
			||||||
      isAdmin: true
 | 
					      isAdmin: false
 | 
				
			||||||
    },
 | 
					    },
 | 
				
			||||||
    {
 | 
					    {
 | 
				
			||||||
      id: 'setting',
 | 
					      id: 'setting',
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -15,7 +15,7 @@ import { useSelector } from 'react-redux';
 | 
				
			|||||||
const Logo = () => {
 | 
					const Logo = () => {
 | 
				
			||||||
  const siteInfo = useSelector((state) => state.siteInfo);
 | 
					  const siteInfo = useSelector((state) => state.siteInfo);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  return <img src={siteInfo.logo || logo} alt={siteInfo.system_name} width="80" />;
 | 
					  return <img src={siteInfo.logo || logo} alt={siteInfo.system_name} height="50" />;
 | 
				
			||||||
};
 | 
					};
 | 
				
			||||||
 | 
					
 | 
				
			||||||
export default Logo;
 | 
					export default Logo;
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -14,7 +14,7 @@ API.interceptors.response.use(
 | 
				
			|||||||
    if (error.response?.status === 401) {
 | 
					    if (error.response?.status === 401) {
 | 
				
			||||||
      localStorage.removeItem('user');
 | 
					      localStorage.removeItem('user');
 | 
				
			||||||
      store.dispatch({ type: LOGIN, payload: null });
 | 
					      store.dispatch({ type: LOGIN, payload: null });
 | 
				
			||||||
      window.location.href = config.basename + '/login';
 | 
					      window.location.href = config.basename + 'login';
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    if (error.response?.data?.message) {
 | 
					    if (error.response?.data?.message) {
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -92,7 +92,7 @@ const LoginForm = ({ ...others }) => {
 | 
				
			|||||||
                  <Box sx={{ mr: { xs: 1, sm: 2, width: 20 }, display: 'flex', alignItems: 'center' }}>
 | 
					                  <Box sx={{ mr: { xs: 1, sm: 2, width: 20 }, display: 'flex', alignItems: 'center' }}>
 | 
				
			||||||
                    <img src={Github} alt="github" width={25} height={25} style={{ marginRight: matchDownSM ? 8 : 16 }} />
 | 
					                    <img src={Github} alt="github" width={25} height={25} style={{ marginRight: matchDownSM ? 8 : 16 }} />
 | 
				
			||||||
                  </Box>
 | 
					                  </Box>
 | 
				
			||||||
                  使用 Github 登录
 | 
					                  使用 GitHub 登录
 | 
				
			||||||
                </Button>
 | 
					                </Button>
 | 
				
			||||||
              </AnimateButton>
 | 
					              </AnimateButton>
 | 
				
			||||||
            </Grid>
 | 
					            </Grid>
 | 
				
			||||||
@@ -115,7 +115,7 @@ const LoginForm = ({ ...others }) => {
 | 
				
			|||||||
                  <Box sx={{ mr: { xs: 1, sm: 2, width: 20 }, display: 'flex', alignItems: 'center' }}>
 | 
					                  <Box sx={{ mr: { xs: 1, sm: 2, width: 20 }, display: 'flex', alignItems: 'center' }}>
 | 
				
			||||||
                    <img src={Wechat} alt="Wechat" width={25} height={25} style={{ marginRight: matchDownSM ? 8 : 16 }} />
 | 
					                    <img src={Wechat} alt="Wechat" width={25} height={25} style={{ marginRight: matchDownSM ? 8 : 16 }} />
 | 
				
			||||||
                  </Box>
 | 
					                  </Box>
 | 
				
			||||||
                  使用 Wechat 登录
 | 
					                  使用微信登录
 | 
				
			||||||
                </Button>
 | 
					                </Button>
 | 
				
			||||||
              </AnimateButton>
 | 
					              </AnimateButton>
 | 
				
			||||||
              <WechatModal open={openWechat} handleClose={handleWechatClose} wechatLogin={wechatLogin} qrCode={siteInfo.wechat_qrcode} />
 | 
					              <WechatModal open={openWechat} handleClose={handleWechatClose} wechatLogin={wechatLogin} qrCode={siteInfo.wechat_qrcode} />
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -21,12 +21,18 @@ import {
 | 
				
			|||||||
  Container,
 | 
					  Container,
 | 
				
			||||||
  Autocomplete,
 | 
					  Autocomplete,
 | 
				
			||||||
  FormHelperText,
 | 
					  FormHelperText,
 | 
				
			||||||
 | 
					  Checkbox
 | 
				
			||||||
} from "@mui/material";
 | 
					} from "@mui/material";
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import { Formik } from "formik";
 | 
					import { Formik } from "formik";
 | 
				
			||||||
import * as Yup from "yup";
 | 
					import * as Yup from "yup";
 | 
				
			||||||
import { defaultConfig, typeConfig } from "../type/Config"; //typeConfig
 | 
					import { defaultConfig, typeConfig } from "../type/Config"; //typeConfig
 | 
				
			||||||
import { createFilterOptions } from "@mui/material/Autocomplete";
 | 
					import { createFilterOptions } from "@mui/material/Autocomplete";
 | 
				
			||||||
 | 
					import CheckBoxOutlineBlankIcon from '@mui/icons-material/CheckBoxOutlineBlank';
 | 
				
			||||||
 | 
					import CheckBoxIcon from '@mui/icons-material/CheckBox';
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					const icon = <CheckBoxOutlineBlankIcon fontSize="small" />;
 | 
				
			||||||
 | 
					const checkedIcon = <CheckBoxIcon fontSize="small" />;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
const filter = createFilterOptions();
 | 
					const filter = createFilterOptions();
 | 
				
			||||||
const validationSchema = Yup.object().shape({
 | 
					const validationSchema = Yup.object().shape({
 | 
				
			||||||
@@ -38,12 +44,10 @@ const validationSchema = Yup.object().shape({
 | 
				
			|||||||
    then: Yup.string().required("密钥 不能为空"),
 | 
					    then: Yup.string().required("密钥 不能为空"),
 | 
				
			||||||
  }),
 | 
					  }),
 | 
				
			||||||
  other: Yup.string(),
 | 
					  other: Yup.string(),
 | 
				
			||||||
  proxy: Yup.string(),
 | 
					 | 
				
			||||||
  test_model: Yup.string(),
 | 
					 | 
				
			||||||
  models: Yup.array().min(1, "模型 不能为空"),
 | 
					  models: Yup.array().min(1, "模型 不能为空"),
 | 
				
			||||||
  groups: Yup.array().min(1, "用户组 不能为空"),
 | 
					  groups: Yup.array().min(1, "用户组 不能为空"),
 | 
				
			||||||
  base_url: Yup.string().when("type", {
 | 
					  base_url: Yup.string().when("type", {
 | 
				
			||||||
    is: (value) => [3, 24, 8].includes(value),
 | 
					    is: (value) => [3, 8].includes(value),
 | 
				
			||||||
    then: Yup.string().required("渠道API地址 不能为空"), // base_url 是必需的
 | 
					    then: Yup.string().required("渠道API地址 不能为空"), // base_url 是必需的
 | 
				
			||||||
    otherwise: Yup.string(), // 在其他情况下,base_url 可以是任意字符串
 | 
					    otherwise: Yup.string(), // 在其他情况下,base_url 可以是任意字符串
 | 
				
			||||||
  }),
 | 
					  }),
 | 
				
			||||||
@@ -146,8 +150,23 @@ const EditModal = ({ open, channelId, onCancel, onOk }) => {
 | 
				
			|||||||
  const fetchModels = async () => {
 | 
					  const fetchModels = async () => {
 | 
				
			||||||
    try {
 | 
					    try {
 | 
				
			||||||
      let res = await API.get(`/api/channel/models`);
 | 
					      let res = await API.get(`/api/channel/models`);
 | 
				
			||||||
 | 
					      const { data } = res.data;
 | 
				
			||||||
 | 
					      data.forEach(item => {
 | 
				
			||||||
 | 
					        if (!item.owned_by) {
 | 
				
			||||||
 | 
					          item.owned_by = "未知";
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					      });
 | 
				
			||||||
 | 
					      // 先对data排序
 | 
				
			||||||
 | 
					      data.sort((a, b) => {
 | 
				
			||||||
 | 
					        const ownedByComparison = a.owned_by.localeCompare(b.owned_by);
 | 
				
			||||||
 | 
					        if (ownedByComparison === 0) {
 | 
				
			||||||
 | 
					          return a.id.localeCompare(b.id);
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					        return ownedByComparison;
 | 
				
			||||||
 | 
					      });
 | 
				
			||||||
 | 
					
 | 
				
			||||||
      setModelOptions(
 | 
					      setModelOptions(
 | 
				
			||||||
        res.data.data.map((model) => {
 | 
					         data.map((model) => {
 | 
				
			||||||
          return {
 | 
					          return {
 | 
				
			||||||
            id: model.id,
 | 
					            id: model.id,
 | 
				
			||||||
            group: model.owned_by,
 | 
					            group: model.owned_by,
 | 
				
			||||||
@@ -239,6 +258,7 @@ const EditModal = ({ open, channelId, onCancel, onOk }) => {
 | 
				
			|||||||
          2
 | 
					          2
 | 
				
			||||||
        );
 | 
					        );
 | 
				
			||||||
      }
 | 
					      }
 | 
				
			||||||
 | 
					      data.base_url = data.base_url ?? '';
 | 
				
			||||||
      data.is_edit = true;
 | 
					      data.is_edit = true;
 | 
				
			||||||
      initChannel(data.type);
 | 
					      initChannel(data.type);
 | 
				
			||||||
      setInitialInput(data);
 | 
					      setInitialInput(data);
 | 
				
			||||||
@@ -250,12 +270,16 @@ const EditModal = ({ open, channelId, onCancel, onOk }) => {
 | 
				
			|||||||
  useEffect(() => {
 | 
					  useEffect(() => {
 | 
				
			||||||
    fetchGroups().then();
 | 
					    fetchGroups().then();
 | 
				
			||||||
    fetchModels().then();
 | 
					    fetchModels().then();
 | 
				
			||||||
 | 
					  }, []);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  useEffect(() => {
 | 
				
			||||||
    if (channelId) {
 | 
					    if (channelId) {
 | 
				
			||||||
      loadChannel().then();
 | 
					      loadChannel().then();
 | 
				
			||||||
    } else {
 | 
					    } else {
 | 
				
			||||||
      initChannel(1);
 | 
					      initChannel(1);
 | 
				
			||||||
      setInitialInput({ ...defaultConfig.input, is_edit: false });
 | 
					      setInitialInput({ ...defaultConfig.input, is_edit: false });
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					    // eslint-disable-next-line react-hooks/exhaustive-deps
 | 
				
			||||||
  }, [channelId]);
 | 
					  }, [channelId]);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  return (
 | 
					  return (
 | 
				
			||||||
@@ -491,7 +515,8 @@ const EditModal = ({ open, channelId, onCancel, onOk }) => {
 | 
				
			|||||||
                    handleChange(event);
 | 
					                    handleChange(event);
 | 
				
			||||||
                  }}
 | 
					                  }}
 | 
				
			||||||
                  onBlur={handleBlur}
 | 
					                  onBlur={handleBlur}
 | 
				
			||||||
                  filterSelectedOptions
 | 
					                  // filterSelectedOptions
 | 
				
			||||||
 | 
					                  disableCloseOnSelect
 | 
				
			||||||
                  renderInput={(params) => (
 | 
					                  renderInput={(params) => (
 | 
				
			||||||
                    <TextField
 | 
					                    <TextField
 | 
				
			||||||
                      {...params}
 | 
					                      {...params}
 | 
				
			||||||
@@ -524,6 +549,12 @@ const EditModal = ({ open, channelId, onCancel, onOk }) => {
 | 
				
			|||||||
                    }
 | 
					                    }
 | 
				
			||||||
                    return filtered;
 | 
					                    return filtered;
 | 
				
			||||||
                  }}
 | 
					                  }}
 | 
				
			||||||
 | 
					                  renderOption={(props, option, { selected }) => (
 | 
				
			||||||
 | 
					                    <li {...props}>
 | 
				
			||||||
 | 
					                      <Checkbox icon={icon} checkedIcon={checkedIcon} style={{ marginRight: 8 }} checked={selected} />
 | 
				
			||||||
 | 
					                      {option.id}
 | 
				
			||||||
 | 
					                    </li>
 | 
				
			||||||
 | 
					                  )}
 | 
				
			||||||
                />
 | 
					                />
 | 
				
			||||||
                {errors.models ? (
 | 
					                {errors.models ? (
 | 
				
			||||||
                  <FormHelperText error id="helper-tex-channel-models-label">
 | 
					                  <FormHelperText error id="helper-tex-channel-models-label">
 | 
				
			||||||
@@ -623,71 +654,6 @@ const EditModal = ({ open, channelId, onCancel, onOk }) => {
 | 
				
			|||||||
                  </FormHelperText>
 | 
					                  </FormHelperText>
 | 
				
			||||||
                )}
 | 
					                )}
 | 
				
			||||||
              </FormControl>
 | 
					              </FormControl>
 | 
				
			||||||
              <FormControl
 | 
					 | 
				
			||||||
                fullWidth
 | 
					 | 
				
			||||||
                error={Boolean(touched.proxy && errors.proxy)}
 | 
					 | 
				
			||||||
                sx={{ ...theme.typography.otherInput }}
 | 
					 | 
				
			||||||
              >
 | 
					 | 
				
			||||||
                <InputLabel htmlFor="channel-proxy-label">
 | 
					 | 
				
			||||||
                  {inputLabel.proxy}
 | 
					 | 
				
			||||||
                </InputLabel>
 | 
					 | 
				
			||||||
                <OutlinedInput
 | 
					 | 
				
			||||||
                  id="channel-proxy-label"
 | 
					 | 
				
			||||||
                  label={inputLabel.proxy}
 | 
					 | 
				
			||||||
                  type="text"
 | 
					 | 
				
			||||||
                  value={values.proxy}
 | 
					 | 
				
			||||||
                  name="proxy"
 | 
					 | 
				
			||||||
                  onBlur={handleBlur}
 | 
					 | 
				
			||||||
                  onChange={handleChange}
 | 
					 | 
				
			||||||
                  inputProps={{}}
 | 
					 | 
				
			||||||
                  aria-describedby="helper-text-channel-proxy-label"
 | 
					 | 
				
			||||||
                />
 | 
					 | 
				
			||||||
                {touched.proxy && errors.proxy ? (
 | 
					 | 
				
			||||||
                  <FormHelperText error id="helper-tex-channel-proxy-label">
 | 
					 | 
				
			||||||
                    {errors.proxy}
 | 
					 | 
				
			||||||
                  </FormHelperText>
 | 
					 | 
				
			||||||
                ) : (
 | 
					 | 
				
			||||||
                  <FormHelperText id="helper-tex-channel-proxy-label">
 | 
					 | 
				
			||||||
                    {" "}
 | 
					 | 
				
			||||||
                    {inputPrompt.proxy}{" "}
 | 
					 | 
				
			||||||
                  </FormHelperText>
 | 
					 | 
				
			||||||
                )}
 | 
					 | 
				
			||||||
              </FormControl>
 | 
					 | 
				
			||||||
              {inputPrompt.test_model && (
 | 
					 | 
				
			||||||
                <FormControl
 | 
					 | 
				
			||||||
                  fullWidth
 | 
					 | 
				
			||||||
                  error={Boolean(touched.test_model && errors.test_model)}
 | 
					 | 
				
			||||||
                  sx={{ ...theme.typography.otherInput }}
 | 
					 | 
				
			||||||
                >
 | 
					 | 
				
			||||||
                  <InputLabel htmlFor="channel-test_model-label">
 | 
					 | 
				
			||||||
                    {inputLabel.test_model}
 | 
					 | 
				
			||||||
                  </InputLabel>
 | 
					 | 
				
			||||||
                  <OutlinedInput
 | 
					 | 
				
			||||||
                    id="channel-test_model-label"
 | 
					 | 
				
			||||||
                    label={inputLabel.test_model}
 | 
					 | 
				
			||||||
                    type="text"
 | 
					 | 
				
			||||||
                    value={values.test_model}
 | 
					 | 
				
			||||||
                    name="test_model"
 | 
					 | 
				
			||||||
                    onBlur={handleBlur}
 | 
					 | 
				
			||||||
                    onChange={handleChange}
 | 
					 | 
				
			||||||
                    inputProps={{}}
 | 
					 | 
				
			||||||
                    aria-describedby="helper-text-channel-test_model-label"
 | 
					 | 
				
			||||||
                  />
 | 
					 | 
				
			||||||
                  {touched.test_model && errors.test_model ? (
 | 
					 | 
				
			||||||
                    <FormHelperText
 | 
					 | 
				
			||||||
                      error
 | 
					 | 
				
			||||||
                      id="helper-tex-channel-test_model-label"
 | 
					 | 
				
			||||||
                    >
 | 
					 | 
				
			||||||
                      {errors.test_model}
 | 
					 | 
				
			||||||
                    </FormHelperText>
 | 
					 | 
				
			||||||
                  ) : (
 | 
					 | 
				
			||||||
                    <FormHelperText id="helper-tex-channel-test_model-label">
 | 
					 | 
				
			||||||
                      {" "}
 | 
					 | 
				
			||||||
                      {inputPrompt.test_model}{" "}
 | 
					 | 
				
			||||||
                    </FormHelperText>
 | 
					 | 
				
			||||||
                  )}
 | 
					 | 
				
			||||||
                </FormControl>
 | 
					 | 
				
			||||||
              )}
 | 
					 | 
				
			||||||
              <DialogActions>
 | 
					              <DialogActions>
 | 
				
			||||||
                <Button onClick={onCancel}>取消</Button>
 | 
					                <Button onClick={onCancel}>取消</Button>
 | 
				
			||||||
                <Button
 | 
					                <Button
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -12,7 +12,7 @@ import {
 | 
				
			|||||||
    DialogTitle,
 | 
					    DialogTitle,
 | 
				
			||||||
    DialogActions,
 | 
					    DialogActions,
 | 
				
			||||||
    DialogContent,
 | 
					    DialogContent,
 | 
				
			||||||
  Divider
 | 
					    Divider, Link
 | 
				
			||||||
} from '@mui/material';
 | 
					} from '@mui/material';
 | 
				
			||||||
import Grid from '@mui/material/Unstable_Grid2';
 | 
					import Grid from '@mui/material/Unstable_Grid2';
 | 
				
			||||||
import { showError, showSuccess } from 'utils/common'; //,
 | 
					import { showError, showSuccess } from 'utils/common'; //,
 | 
				
			||||||
@@ -26,7 +26,8 @@ const OtherSetting = () => {
 | 
				
			|||||||
    About: '',
 | 
					    About: '',
 | 
				
			||||||
    SystemName: '',
 | 
					    SystemName: '',
 | 
				
			||||||
    Logo: '',
 | 
					    Logo: '',
 | 
				
			||||||
    HomePageContent: ''
 | 
					    HomePageContent: '',
 | 
				
			||||||
 | 
					    Theme: '',
 | 
				
			||||||
  });
 | 
					  });
 | 
				
			||||||
  let [loading, setLoading] = useState(false);
 | 
					  let [loading, setLoading] = useState(false);
 | 
				
			||||||
  const [showUpdateModal, setShowUpdateModal] = useState(false);
 | 
					  const [showUpdateModal, setShowUpdateModal] = useState(false);
 | 
				
			||||||
@@ -88,6 +89,10 @@ const OtherSetting = () => {
 | 
				
			|||||||
    await updateOption('SystemName', inputs.SystemName);
 | 
					    await updateOption('SystemName', inputs.SystemName);
 | 
				
			||||||
  };
 | 
					  };
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  const submitTheme = async () => {
 | 
				
			||||||
 | 
					    await updateOption('Theme', inputs.Theme);
 | 
				
			||||||
 | 
					  };
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  const submitLogo = async () => {
 | 
					  const submitLogo = async () => {
 | 
				
			||||||
    await updateOption('Logo', inputs.Logo);
 | 
					    await updateOption('Logo', inputs.Logo);
 | 
				
			||||||
  };
 | 
					  };
 | 
				
			||||||
@@ -171,6 +176,25 @@ const OtherSetting = () => {
 | 
				
			|||||||
                设置系统名称
 | 
					                设置系统名称
 | 
				
			||||||
              </Button>
 | 
					              </Button>
 | 
				
			||||||
            </Grid>
 | 
					            </Grid>
 | 
				
			||||||
 | 
					            <Grid xs={12}>
 | 
				
			||||||
 | 
					              <FormControl fullWidth>
 | 
				
			||||||
 | 
					                <InputLabel htmlFor="Theme">主题名称</InputLabel>
 | 
				
			||||||
 | 
					                <OutlinedInput
 | 
				
			||||||
 | 
					                    id="Theme"
 | 
				
			||||||
 | 
					                    name="Theme"
 | 
				
			||||||
 | 
					                    value={inputs.Theme || ''}
 | 
				
			||||||
 | 
					                    onChange={handleInputChange}
 | 
				
			||||||
 | 
					                    label="主题名称"
 | 
				
			||||||
 | 
					                    placeholder="请输入主题名称"
 | 
				
			||||||
 | 
					                    disabled={loading}
 | 
				
			||||||
 | 
					                />
 | 
				
			||||||
 | 
					              </FormControl>
 | 
				
			||||||
 | 
					            </Grid>
 | 
				
			||||||
 | 
					            <Grid xs={12}>
 | 
				
			||||||
 | 
					              <Button variant="contained" onClick={submitTheme}>
 | 
				
			||||||
 | 
					                设置主题(重启生效)
 | 
				
			||||||
 | 
					              </Button>
 | 
				
			||||||
 | 
					            </Grid>
 | 
				
			||||||
            <Grid xs={12}>
 | 
					            <Grid xs={12}>
 | 
				
			||||||
              <FormControl fullWidth>
 | 
					              <FormControl fullWidth>
 | 
				
			||||||
                <InputLabel htmlFor="Logo">Logo 图片地址</InputLabel>
 | 
					                <InputLabel htmlFor="Logo">Logo 图片地址</InputLabel>
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -192,7 +192,7 @@ export default function TokensTableRow({ item, manageToken, handleOpenModal, set
 | 
				
			|||||||
              id={`switch-${item.id}`}
 | 
					              id={`switch-${item.id}`}
 | 
				
			||||||
              checked={statusSwitch === 1}
 | 
					              checked={statusSwitch === 1}
 | 
				
			||||||
              onChange={handleStatus}
 | 
					              onChange={handleStatus}
 | 
				
			||||||
              disabled={statusSwitch !== 1 && statusSwitch !== 2}
 | 
					              // disabled={statusSwitch !== 1 && statusSwitch !== 2}
 | 
				
			||||||
            />
 | 
					            />
 | 
				
			||||||
          </Tooltip>
 | 
					          </Tooltip>
 | 
				
			||||||
        </TableCell>
 | 
					        </TableCell>
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -2,6 +2,7 @@ import React, { useEffect, useState } from 'react';
 | 
				
			|||||||
import { Button, Divider, Form, Grid, Header, Message, Modal } from 'semantic-ui-react';
 | 
					import { Button, Divider, Form, Grid, Header, Message, Modal } from 'semantic-ui-react';
 | 
				
			||||||
import { API, showError, showSuccess } from '../helpers';
 | 
					import { API, showError, showSuccess } from '../helpers';
 | 
				
			||||||
import { marked } from 'marked';
 | 
					import { marked } from 'marked';
 | 
				
			||||||
 | 
					import { Link } from 'react-router-dom';
 | 
				
			||||||
 | 
					
 | 
				
			||||||
const OtherSetting = () => {
 | 
					const OtherSetting = () => {
 | 
				
			||||||
  let [inputs, setInputs] = useState({
 | 
					  let [inputs, setInputs] = useState({
 | 
				
			||||||
@@ -10,7 +11,8 @@ const OtherSetting = () => {
 | 
				
			|||||||
    About: '',
 | 
					    About: '',
 | 
				
			||||||
    SystemName: '',
 | 
					    SystemName: '',
 | 
				
			||||||
    Logo: '',
 | 
					    Logo: '',
 | 
				
			||||||
    HomePageContent: ''
 | 
					    HomePageContent: '',
 | 
				
			||||||
 | 
					    Theme: ''
 | 
				
			||||||
  });
 | 
					  });
 | 
				
			||||||
  let [loading, setLoading] = useState(false);
 | 
					  let [loading, setLoading] = useState(false);
 | 
				
			||||||
  const [showUpdateModal, setShowUpdateModal] = useState(false);
 | 
					  const [showUpdateModal, setShowUpdateModal] = useState(false);
 | 
				
			||||||
@@ -70,6 +72,10 @@ const OtherSetting = () => {
 | 
				
			|||||||
    await updateOption('SystemName', inputs.SystemName);
 | 
					    await updateOption('SystemName', inputs.SystemName);
 | 
				
			||||||
  };
 | 
					  };
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  const submitTheme = async () => {
 | 
				
			||||||
 | 
					    await updateOption('Theme', inputs.Theme);
 | 
				
			||||||
 | 
					  };
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  const submitLogo = async () => {
 | 
					  const submitLogo = async () => {
 | 
				
			||||||
    await updateOption('Logo', inputs.Logo);
 | 
					    await updateOption('Logo', inputs.Logo);
 | 
				
			||||||
  };
 | 
					  };
 | 
				
			||||||
@@ -132,6 +138,17 @@ const OtherSetting = () => {
 | 
				
			|||||||
            />
 | 
					            />
 | 
				
			||||||
          </Form.Group>
 | 
					          </Form.Group>
 | 
				
			||||||
          <Form.Button onClick={submitSystemName}>设置系统名称</Form.Button>
 | 
					          <Form.Button onClick={submitSystemName}>设置系统名称</Form.Button>
 | 
				
			||||||
 | 
					          <Form.Group widths='equal'>
 | 
				
			||||||
 | 
					            <Form.Input
 | 
				
			||||||
 | 
					              label={<label>主题名称(<Link
 | 
				
			||||||
 | 
					                to='https://github.com/songquanpeng/one-api/blob/main/web/README.md'>当前可用主题</Link>)</label>}
 | 
				
			||||||
 | 
					              placeholder='请输入主题名称'
 | 
				
			||||||
 | 
					              value={inputs.Theme}
 | 
				
			||||||
 | 
					              name='Theme'
 | 
				
			||||||
 | 
					              onChange={handleInputChange}
 | 
				
			||||||
 | 
					            />
 | 
				
			||||||
 | 
					          </Form.Group>
 | 
				
			||||||
 | 
					          <Form.Button onClick={submitTheme}>设置主题(重启生效)</Form.Button>
 | 
				
			||||||
          <Form.Group widths='equal'>
 | 
					          <Form.Group widths='equal'>
 | 
				
			||||||
            <Form.Input
 | 
					            <Form.Input
 | 
				
			||||||
              label='Logo 图片地址'
 | 
					              label='Logo 图片地址'
 | 
				
			||||||
@@ -165,7 +182,8 @@ const OtherSetting = () => {
 | 
				
			|||||||
            />
 | 
					            />
 | 
				
			||||||
          </Form.Group>
 | 
					          </Form.Group>
 | 
				
			||||||
          <Form.Button onClick={submitAbout}>保存关于</Form.Button>
 | 
					          <Form.Button onClick={submitAbout}>保存关于</Form.Button>
 | 
				
			||||||
          <Message>移除 One API 的版权标识必须首先获得授权,项目维护需要花费大量精力,如果本项目对你有意义,请主动支持本项目。</Message>
 | 
					          <Message>移除 One API
 | 
				
			||||||
 | 
					            的版权标识必须首先获得授权,项目维护需要花费大量精力,如果本项目对你有意义,请主动支持本项目。</Message>
 | 
				
			||||||
          <Form.Group widths='equal'>
 | 
					          <Form.Group widths='equal'>
 | 
				
			||||||
            <Form.Input
 | 
					            <Form.Input
 | 
				
			||||||
              label='页脚'
 | 
					              label='页脚'
 | 
				
			||||||
 
 | 
				
			|||||||
		Reference in New Issue
	
	Block a user