mirror of
				https://github.com/songquanpeng/one-api.git
				synced 2025-10-30 21:33:41 +08:00 
			
		
		
		
	Compare commits
	
		
			9 Commits
		
	
	
		
			v0.3.0
			...
			v0.3.2-alp
		
	
	| Author | SHA1 | Date | |
|---|---|---|---|
|  | 3c6834a79c | ||
|  | 6da3410823 | ||
|  | ceb289cb4d | ||
|  | 6f8cc712b0 | ||
|  | ad01e1f3b3 | ||
|  | cc1ef2ffd5 | ||
|  | 7201bd1c97 | ||
|  | 73d5e0f283 | ||
|  | efc744ca35 | 
| @@ -48,10 +48,10 @@ _✨ All in one 的 OpenAI 接口,整合各种 API 访问方式,开箱即用 | |||||||
|    + [x] OpenAI 官方通道 |    + [x] OpenAI 官方通道 | ||||||
|    + [x] **Azure OpenAI API** |    + [x] **Azure OpenAI API** | ||||||
|    + [x] [API2D](https://api2d.com/r/197971) |    + [x] [API2D](https://api2d.com/r/197971) | ||||||
|  |    + [x] [OhMyGPT](https://aigptx.top?aff=uFpUl2Kf) | ||||||
|    + [x] [CloseAI](https://console.openai-asia.com) |    + [x] [CloseAI](https://console.openai-asia.com) | ||||||
|    + [x] [OpenAI-SB](https://openai-sb.com) |    + [x] [OpenAI-SB](https://openai-sb.com) | ||||||
|    + [x] [OpenAI Max](https://openaimax.com) |    + [x] [OpenAI Max](https://openaimax.com) | ||||||
|    + [x] [OhMyGPT](https://www.ohmygpt.com) |  | ||||||
|    + [x] 自定义渠道:例如使用自行搭建的 OpenAI 代理 |    + [x] 自定义渠道:例如使用自行搭建的 OpenAI 代理 | ||||||
| 2. 支持通过**负载均衡**的方式访问多个渠道。 | 2. 支持通过**负载均衡**的方式访问多个渠道。 | ||||||
| 3. 支持 **stream 模式**,可以通过流式传输实现打字机效果。 | 3. 支持 **stream 模式**,可以通过流式传输实现打字机效果。 | ||||||
| @@ -157,6 +157,8 @@ sudo service nginx restart | |||||||
|    + 例子:`SESSION_SECRET=random_string` |    + 例子:`SESSION_SECRET=random_string` | ||||||
| 3. `SQL_DSN`:设置之后将使用指定数据库而非 SQLite。 | 3. `SQL_DSN`:设置之后将使用指定数据库而非 SQLite。 | ||||||
|    + 例子:`SQL_DSN=root:123456@tcp(localhost:3306)/one-api` |    + 例子:`SQL_DSN=root:123456@tcp(localhost:3306)/one-api` | ||||||
|  | 4. `FRONTEND_BASE_URL`:设置之后将使用指定的前端地址,而非后端地址。 | ||||||
|  |    + 例子:`FRONTEND_BASE_URL=https://openai.justsong.cn` | ||||||
|  |  | ||||||
| ### 命令行参数 | ### 命令行参数 | ||||||
| 1. `--port <port_number>`: 指定服务器监听的端口号,默认为 `3000`。 | 1. `--port <port_number>`: 指定服务器监听的端口号,默认为 `3000`。 | ||||||
|   | |||||||
| @@ -132,7 +132,7 @@ const ( | |||||||
| var ChannelBaseURLs = []string{ | var ChannelBaseURLs = []string{ | ||||||
| 	"",                            // 0 | 	"",                            // 0 | ||||||
| 	"https://api.openai.com",      // 1 | 	"https://api.openai.com",      // 1 | ||||||
| 	"https://openai.api2d.net",    // 2 | 	"https://oa.api2d.net",        // 2 | ||||||
| 	"",                            // 3 | 	"",                            // 3 | ||||||
| 	"https://api.openai-asia.com", // 4 | 	"https://api.openai-asia.com", // 4 | ||||||
| 	"https://api.openai-sb.com",   // 5 | 	"https://api.openai-sb.com",   // 5 | ||||||
|   | |||||||
| @@ -201,7 +201,7 @@ func testChannel(channel *model.Channel, request *ChatRequest) error { | |||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return err | 		return err | ||||||
| 	} | 	} | ||||||
| 	if response.Error.Type != "" { | 	if response.Error.Message != "" { | ||||||
| 		return errors.New(fmt.Sprintf("type %s, code %s, message %s", response.Error.Type, response.Error.Code, response.Error.Message)) | 		return errors.New(fmt.Sprintf("type %s, code %s, message %s", response.Error.Type, response.Error.Code, response.Error.Message)) | ||||||
| 	} | 	} | ||||||
| 	return nil | 	return nil | ||||||
| @@ -265,14 +265,14 @@ var testAllChannelsLock sync.Mutex | |||||||
| var testAllChannelsRunning bool = false | var testAllChannelsRunning bool = false | ||||||
|  |  | ||||||
| // disable & notify | // disable & notify | ||||||
| func disableChannel(channelId int, channelName string, err error) { | func disableChannel(channelId int, channelName string, reason string) { | ||||||
| 	if common.RootUserEmail == "" { | 	if common.RootUserEmail == "" { | ||||||
| 		common.RootUserEmail = model.GetRootUserEmail() | 		common.RootUserEmail = model.GetRootUserEmail() | ||||||
| 	} | 	} | ||||||
| 	model.UpdateChannelStatusById(channelId, common.ChannelStatusDisabled) | 	model.UpdateChannelStatusById(channelId, common.ChannelStatusDisabled) | ||||||
| 	subject := fmt.Sprintf("通道「%s」(#%d)已被禁用", channelName, channelId) | 	subject := fmt.Sprintf("通道「%s」(#%d)已被禁用", channelName, channelId) | ||||||
| 	content := fmt.Sprintf("通道「%s」(#%d)已被禁用,原因:%s", channelName, channelId, err.Error()) | 	content := fmt.Sprintf("通道「%s」(#%d)已被禁用,原因:%s", channelName, channelId, reason) | ||||||
| 	err = common.SendEmail(subject, common.RootUserEmail, content) | 	err := common.SendEmail(subject, common.RootUserEmail, content) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		common.SysError(fmt.Sprintf("发送邮件失败:%s", err.Error())) | 		common.SysError(fmt.Sprintf("发送邮件失败:%s", err.Error())) | ||||||
| 	} | 	} | ||||||
| @@ -312,7 +312,7 @@ func testAllChannels(c *gin.Context) error { | |||||||
| 				if milliseconds > disableThreshold { | 				if milliseconds > disableThreshold { | ||||||
| 					err = errors.New(fmt.Sprintf("响应时间 %.2fs 超过阈值 %.2fs", float64(milliseconds)/1000.0, float64(disableThreshold)/1000.0)) | 					err = errors.New(fmt.Sprintf("响应时间 %.2fs 超过阈值 %.2fs", float64(milliseconds)/1000.0, float64(disableThreshold)/1000.0)) | ||||||
| 				} | 				} | ||||||
| 				disableChannel(channel.Id, channel.Name, err) | 				disableChannel(channel.Id, channel.Name, err.Error()) | ||||||
| 			} | 			} | ||||||
| 			channel.UpdateResponseTime(milliseconds) | 			channel.UpdateResponseTime(milliseconds) | ||||||
| 		} | 		} | ||||||
|   | |||||||
							
								
								
									
										153
									
								
								controller/model.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										153
									
								
								controller/model.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,153 @@ | |||||||
|  | package controller | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"fmt" | ||||||
|  | 	"github.com/gin-gonic/gin" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | // https://platform.openai.com/docs/api-reference/models/list | ||||||
|  |  | ||||||
|  | type OpenAIModelPermission struct { | ||||||
|  | 	Id                 string  `json:"id"` | ||||||
|  | 	Object             string  `json:"object"` | ||||||
|  | 	Created            int     `json:"created"` | ||||||
|  | 	AllowCreateEngine  bool    `json:"allow_create_engine"` | ||||||
|  | 	AllowSampling      bool    `json:"allow_sampling"` | ||||||
|  | 	AllowLogprobs      bool    `json:"allow_logprobs"` | ||||||
|  | 	AllowSearchIndices bool    `json:"allow_search_indices"` | ||||||
|  | 	AllowView          bool    `json:"allow_view"` | ||||||
|  | 	AllowFineTuning    bool    `json:"allow_fine_tuning"` | ||||||
|  | 	Organization       string  `json:"organization"` | ||||||
|  | 	Group              *string `json:"group"` | ||||||
|  | 	IsBlocking         bool    `json:"is_blocking"` | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type OpenAIModels struct { | ||||||
|  | 	Id         string                `json:"id"` | ||||||
|  | 	Object     string                `json:"object"` | ||||||
|  | 	Created    int                   `json:"created"` | ||||||
|  | 	OwnedBy    string                `json:"owned_by"` | ||||||
|  | 	Permission OpenAIModelPermission `json:"permission"` | ||||||
|  | 	Root       string                `json:"root"` | ||||||
|  | 	Parent     *string               `json:"parent"` | ||||||
|  | } | ||||||
|  |  | ||||||
|  | var openAIModels []OpenAIModels | ||||||
|  | var openAIModelsMap map[string]OpenAIModels | ||||||
|  |  | ||||||
|  | func init() { | ||||||
|  | 	permission := OpenAIModelPermission{ | ||||||
|  | 		Id:                 "modelperm-LwHkVFn8AcMItP432fKKDIKJ", | ||||||
|  | 		Object:             "model_permission", | ||||||
|  | 		Created:            1626777600, | ||||||
|  | 		AllowCreateEngine:  true, | ||||||
|  | 		AllowSampling:      true, | ||||||
|  | 		AllowLogprobs:      true, | ||||||
|  | 		AllowSearchIndices: false, | ||||||
|  | 		AllowView:          true, | ||||||
|  | 		AllowFineTuning:    false, | ||||||
|  | 		Organization:       "*", | ||||||
|  | 		Group:              nil, | ||||||
|  | 		IsBlocking:         false, | ||||||
|  | 	} | ||||||
|  | 	// https://platform.openai.com/docs/models/model-endpoint-compatibility | ||||||
|  | 	openAIModels = []OpenAIModels{ | ||||||
|  | 		{ | ||||||
|  | 			Id:         "gpt-3.5-turbo", | ||||||
|  | 			Object:     "model", | ||||||
|  | 			Created:    1677649963, | ||||||
|  | 			OwnedBy:    "openai", | ||||||
|  | 			Permission: permission, | ||||||
|  | 			Root:       "gpt-3.5-turbo", | ||||||
|  | 			Parent:     nil, | ||||||
|  | 		}, | ||||||
|  | 		{ | ||||||
|  | 			Id:         "gpt-3.5-turbo-0301", | ||||||
|  | 			Object:     "model", | ||||||
|  | 			Created:    1677649963, | ||||||
|  | 			OwnedBy:    "openai", | ||||||
|  | 			Permission: permission, | ||||||
|  | 			Root:       "gpt-3.5-turbo-0301", | ||||||
|  | 			Parent:     nil, | ||||||
|  | 		}, | ||||||
|  | 		{ | ||||||
|  | 			Id:         "gpt-4", | ||||||
|  | 			Object:     "model", | ||||||
|  | 			Created:    1677649963, | ||||||
|  | 			OwnedBy:    "openai", | ||||||
|  | 			Permission: permission, | ||||||
|  | 			Root:       "gpt-4", | ||||||
|  | 			Parent:     nil, | ||||||
|  | 		}, | ||||||
|  | 		{ | ||||||
|  | 			Id:         "gpt-4-0314", | ||||||
|  | 			Object:     "model", | ||||||
|  | 			Created:    1677649963, | ||||||
|  | 			OwnedBy:    "openai", | ||||||
|  | 			Permission: permission, | ||||||
|  | 			Root:       "gpt-4-0314", | ||||||
|  | 			Parent:     nil, | ||||||
|  | 		}, | ||||||
|  | 		{ | ||||||
|  | 			Id:         "gpt-4-32k", | ||||||
|  | 			Object:     "model", | ||||||
|  | 			Created:    1677649963, | ||||||
|  | 			OwnedBy:    "openai", | ||||||
|  | 			Permission: permission, | ||||||
|  | 			Root:       "gpt-4-32k", | ||||||
|  | 			Parent:     nil, | ||||||
|  | 		}, | ||||||
|  | 		{ | ||||||
|  | 			Id:         "gpt-4-32k-0314", | ||||||
|  | 			Object:     "model", | ||||||
|  | 			Created:    1677649963, | ||||||
|  | 			OwnedBy:    "openai", | ||||||
|  | 			Permission: permission, | ||||||
|  | 			Root:       "gpt-4-32k-0314", | ||||||
|  | 			Parent:     nil, | ||||||
|  | 		}, | ||||||
|  | 		{ | ||||||
|  | 			Id:         "gpt-3.5-turbo", | ||||||
|  | 			Object:     "model", | ||||||
|  | 			Created:    1677649963, | ||||||
|  | 			OwnedBy:    "openai", | ||||||
|  | 			Permission: permission, | ||||||
|  | 			Root:       "gpt-3.5-turbo", | ||||||
|  | 			Parent:     nil, | ||||||
|  | 		}, | ||||||
|  | 		{ | ||||||
|  | 			Id:         "text-embedding-ada-002", | ||||||
|  | 			Object:     "model", | ||||||
|  | 			Created:    1677649963, | ||||||
|  | 			OwnedBy:    "openai", | ||||||
|  | 			Permission: permission, | ||||||
|  | 			Root:       "text-embedding-ada-002", | ||||||
|  | 			Parent:     nil, | ||||||
|  | 		}, | ||||||
|  | 	} | ||||||
|  | 	openAIModelsMap = make(map[string]OpenAIModels) | ||||||
|  | 	for _, model := range openAIModels { | ||||||
|  | 		openAIModelsMap[model.Id] = model | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func ListModels(c *gin.Context) { | ||||||
|  | 	c.JSON(200, openAIModels) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func RetrieveModel(c *gin.Context) { | ||||||
|  | 	modelId := c.Param("model") | ||||||
|  | 	if model, ok := openAIModelsMap[modelId]; ok { | ||||||
|  | 		c.JSON(200, model) | ||||||
|  | 	} else { | ||||||
|  | 		openAIError := OpenAIError{ | ||||||
|  | 			Message: fmt.Sprintf("The model '%s' does not exist", modelId), | ||||||
|  | 			Type:    "invalid_request_error", | ||||||
|  | 			Param:   "model", | ||||||
|  | 			Code:    "model_not_found", | ||||||
|  | 		} | ||||||
|  | 		c.JSON(200, gin.H{ | ||||||
|  | 			"error": openAIError, | ||||||
|  | 		}) | ||||||
|  | 	} | ||||||
|  | } | ||||||
| @@ -4,7 +4,6 @@ import ( | |||||||
| 	"bufio" | 	"bufio" | ||||||
| 	"bytes" | 	"bytes" | ||||||
| 	"encoding/json" | 	"encoding/json" | ||||||
| 	"errors" |  | ||||||
| 	"fmt" | 	"fmt" | ||||||
| 	"github.com/gin-gonic/gin" | 	"github.com/gin-gonic/gin" | ||||||
| 	"github.com/pkoukk/tiktoken-go" | 	"github.com/pkoukk/tiktoken-go" | ||||||
| @@ -47,6 +46,11 @@ type OpenAIError struct { | |||||||
| 	Code    string `json:"code"` | 	Code    string `json:"code"` | ||||||
| } | } | ||||||
|  |  | ||||||
|  | type OpenAIErrorWithStatusCode struct { | ||||||
|  | 	OpenAIError | ||||||
|  | 	StatusCode int `json:"status_code"` | ||||||
|  | } | ||||||
|  |  | ||||||
| type TextResponse struct { | type TextResponse struct { | ||||||
| 	Usage `json:"usage"` | 	Usage `json:"usage"` | ||||||
| 	Error OpenAIError `json:"error"` | 	Error OpenAIError `json:"error"` | ||||||
| @@ -71,21 +75,33 @@ func countToken(text string) int { | |||||||
| func Relay(c *gin.Context) { | func Relay(c *gin.Context) { | ||||||
| 	err := relayHelper(c) | 	err := relayHelper(c) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		c.JSON(http.StatusOK, gin.H{ | 		c.JSON(err.StatusCode, gin.H{ | ||||||
| 			"error": gin.H{ | 			"error": err.OpenAIError, | ||||||
| 				"message": err.Error(), |  | ||||||
| 				"type":    "one_api_error", |  | ||||||
| 			}, |  | ||||||
| 		}) | 		}) | ||||||
| 		if common.AutomaticDisableChannelEnabled { | 		channelId := c.GetInt("channel_id") | ||||||
|  | 		common.SysError(fmt.Sprintf("Relay error (channel #%d): %s", channelId, err.Message)) | ||||||
|  | 		if err.Type != "invalid_request_error" && err.StatusCode != http.StatusTooManyRequests && | ||||||
|  | 			common.AutomaticDisableChannelEnabled { | ||||||
| 			channelId := c.GetInt("channel_id") | 			channelId := c.GetInt("channel_id") | ||||||
| 			channelName := c.GetString("channel_name") | 			channelName := c.GetString("channel_name") | ||||||
| 			disableChannel(channelId, channelName, err) | 			disableChannel(channelId, channelName, err.Message) | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  |  | ||||||
| func relayHelper(c *gin.Context) error { | func errorWrapper(err error, code string, statusCode int) *OpenAIErrorWithStatusCode { | ||||||
|  | 	openAIError := OpenAIError{ | ||||||
|  | 		Message: err.Error(), | ||||||
|  | 		Type:    "one_api_error", | ||||||
|  | 		Code:    code, | ||||||
|  | 	} | ||||||
|  | 	return &OpenAIErrorWithStatusCode{ | ||||||
|  | 		OpenAIError: openAIError, | ||||||
|  | 		StatusCode:  statusCode, | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func relayHelper(c *gin.Context) *OpenAIErrorWithStatusCode { | ||||||
| 	channelType := c.GetInt("channel") | 	channelType := c.GetInt("channel") | ||||||
| 	tokenId := c.GetInt("token_id") | 	tokenId := c.GetInt("token_id") | ||||||
| 	consumeQuota := c.GetBool("consume_quota") | 	consumeQuota := c.GetBool("consume_quota") | ||||||
| @@ -93,15 +109,15 @@ func relayHelper(c *gin.Context) error { | |||||||
| 	if consumeQuota || channelType == common.ChannelTypeAzure { | 	if consumeQuota || channelType == common.ChannelTypeAzure { | ||||||
| 		requestBody, err := io.ReadAll(c.Request.Body) | 		requestBody, err := io.ReadAll(c.Request.Body) | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			return err | 			return errorWrapper(err, "read_request_body_failed", http.StatusBadRequest) | ||||||
| 		} | 		} | ||||||
| 		err = c.Request.Body.Close() | 		err = c.Request.Body.Close() | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			return err | 			return errorWrapper(err, "close_request_body_failed", http.StatusBadRequest) | ||||||
| 		} | 		} | ||||||
| 		err = json.Unmarshal(requestBody, &textRequest) | 		err = json.Unmarshal(requestBody, &textRequest) | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			return err | 			return errorWrapper(err, "unmarshal_request_body_failed", http.StatusBadRequest) | ||||||
| 		} | 		} | ||||||
| 		// Reset request body | 		// Reset request body | ||||||
| 		c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody)) | 		c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody)) | ||||||
| @@ -144,12 +160,12 @@ func relayHelper(c *gin.Context) error { | |||||||
| 	if consumeQuota { | 	if consumeQuota { | ||||||
| 		err := model.PreConsumeTokenQuota(tokenId, preConsumedQuota) | 		err := model.PreConsumeTokenQuota(tokenId, preConsumedQuota) | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			return err | 			return errorWrapper(err, "pre_consume_token_quota_failed", http.StatusOK) | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
| 	req, err := http.NewRequest(c.Request.Method, fullRequestURL, c.Request.Body) | 	req, err := http.NewRequest(c.Request.Method, fullRequestURL, c.Request.Body) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return err | 		return errorWrapper(err, "new_request_failed", http.StatusOK) | ||||||
| 	} | 	} | ||||||
| 	if channelType == common.ChannelTypeAzure { | 	if channelType == common.ChannelTypeAzure { | ||||||
| 		key := c.Request.Header.Get("Authorization") | 		key := c.Request.Header.Get("Authorization") | ||||||
| @@ -164,18 +180,18 @@ func relayHelper(c *gin.Context) error { | |||||||
| 	client := &http.Client{} | 	client := &http.Client{} | ||||||
| 	resp, err := client.Do(req) | 	resp, err := client.Do(req) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return err | 		return errorWrapper(err, "do_request_failed", http.StatusOK) | ||||||
| 	} | 	} | ||||||
| 	err = req.Body.Close() | 	err = req.Body.Close() | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return err | 		return errorWrapper(err, "close_request_body_failed", http.StatusOK) | ||||||
| 	} | 	} | ||||||
| 	err = c.Request.Body.Close() | 	err = c.Request.Body.Close() | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return err | 		return errorWrapper(err, "close_request_body_failed", http.StatusOK) | ||||||
| 	} | 	} | ||||||
| 	var textResponse TextResponse | 	var textResponse TextResponse | ||||||
| 	isStream := resp.Header.Get("Content-Type") == "text/event-stream" | 	isStream := strings.HasPrefix(resp.Header.Get("Content-Type"), "text/event-stream") | ||||||
| 	var streamResponseText string | 	var streamResponseText string | ||||||
|  |  | ||||||
| 	defer func() { | 	defer func() { | ||||||
| @@ -257,50 +273,60 @@ func relayHelper(c *gin.Context) error { | |||||||
| 		}) | 		}) | ||||||
| 		err = resp.Body.Close() | 		err = resp.Body.Close() | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			return err | 			return errorWrapper(err, "close_response_body_failed", http.StatusOK) | ||||||
| 		} | 		} | ||||||
| 		return nil | 		return nil | ||||||
| 	} else { | 	} else { | ||||||
| 		for k, v := range resp.Header { |  | ||||||
| 			c.Writer.Header().Set(k, v[0]) |  | ||||||
| 		} |  | ||||||
| 		if consumeQuota { | 		if consumeQuota { | ||||||
| 			responseBody, err := io.ReadAll(resp.Body) | 			responseBody, err := io.ReadAll(resp.Body) | ||||||
| 			if err != nil { | 			if err != nil { | ||||||
| 				return err | 				return errorWrapper(err, "read_response_body_failed", http.StatusOK) | ||||||
| 			} | 			} | ||||||
| 			err = resp.Body.Close() | 			err = resp.Body.Close() | ||||||
| 			if err != nil { | 			if err != nil { | ||||||
| 				return err | 				return errorWrapper(err, "close_response_body_failed", http.StatusOK) | ||||||
| 			} | 			} | ||||||
| 			err = json.Unmarshal(responseBody, &textResponse) | 			err = json.Unmarshal(responseBody, &textResponse) | ||||||
| 			if err != nil { | 			if err != nil { | ||||||
| 				return err | 				return errorWrapper(err, "unmarshal_response_body_failed", http.StatusOK) | ||||||
| 			} | 			} | ||||||
| 			if textResponse.Error.Type != "" { | 			if textResponse.Error.Type != "" { | ||||||
| 				return errors.New(fmt.Sprintf("type %s, code %s, message %s", | 				return &OpenAIErrorWithStatusCode{ | ||||||
| 					textResponse.Error.Type, textResponse.Error.Code, textResponse.Error.Message)) | 					OpenAIError: textResponse.Error, | ||||||
|  | 					StatusCode:  resp.StatusCode, | ||||||
|  | 				} | ||||||
| 			} | 			} | ||||||
| 			// Reset response body | 			// Reset response body | ||||||
| 			resp.Body = io.NopCloser(bytes.NewBuffer(responseBody)) | 			resp.Body = io.NopCloser(bytes.NewBuffer(responseBody)) | ||||||
| 		} | 		} | ||||||
|  | 		// We shouldn't set the header before we parse the response body, because the parse part may fail. | ||||||
|  | 		// And then we will have to send an error response, but in this case, the header has already been set. | ||||||
|  | 		// So the client will be confused by the response. | ||||||
|  | 		// For example, Postman will report error, and we cannot check the response at all. | ||||||
|  | 		for k, v := range resp.Header { | ||||||
|  | 			c.Writer.Header().Set(k, v[0]) | ||||||
|  | 		} | ||||||
|  | 		c.Writer.WriteHeader(resp.StatusCode) | ||||||
| 		_, err = io.Copy(c.Writer, resp.Body) | 		_, err = io.Copy(c.Writer, resp.Body) | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			return err | 			return errorWrapper(err, "copy_response_body_failed", http.StatusOK) | ||||||
| 		} | 		} | ||||||
| 		err = resp.Body.Close() | 		err = resp.Body.Close() | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			return err | 			return errorWrapper(err, "close_response_body_failed", http.StatusOK) | ||||||
| 		} | 		} | ||||||
| 		return nil | 		return nil | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  |  | ||||||
| func RelayNotImplemented(c *gin.Context) { | func RelayNotImplemented(c *gin.Context) { | ||||||
|  | 	err := OpenAIError{ | ||||||
|  | 		Message: "API not implemented", | ||||||
|  | 		Type:    "one_api_error", | ||||||
|  | 		Param:   "", | ||||||
|  | 		Code:    "api_not_implemented", | ||||||
|  | 	} | ||||||
| 	c.JSON(http.StatusOK, gin.H{ | 	c.JSON(http.StatusOK, gin.H{ | ||||||
| 		"error": gin.H{ | 		"error": err, | ||||||
| 			"message": "Not Implemented", |  | ||||||
| 			"type":    "one_api_error", |  | ||||||
| 		}, |  | ||||||
| 	}) | 	}) | ||||||
| } | } | ||||||
|   | |||||||
| @@ -85,6 +85,8 @@ func RootAuth() func(c *gin.Context) { | |||||||
| func TokenAuth() func(c *gin.Context) { | func TokenAuth() func(c *gin.Context) { | ||||||
| 	return func(c *gin.Context) { | 	return func(c *gin.Context) { | ||||||
| 		key := c.Request.Header.Get("Authorization") | 		key := c.Request.Header.Get("Authorization") | ||||||
|  | 		key = strings.TrimPrefix(key, "Bearer ") | ||||||
|  | 		key = strings.TrimPrefix(key, "sk-") | ||||||
| 		parts := strings.Split(key, "-") | 		parts := strings.Split(key, "-") | ||||||
| 		key = parts[0] | 		key = parts[0] | ||||||
| 		token, err := model.ValidateUserToken(key) | 		token, err := model.ValidateUserToken(key) | ||||||
|   | |||||||
| @@ -6,7 +6,6 @@ import ( | |||||||
| 	_ "gorm.io/driver/sqlite" | 	_ "gorm.io/driver/sqlite" | ||||||
| 	"gorm.io/gorm" | 	"gorm.io/gorm" | ||||||
| 	"one-api/common" | 	"one-api/common" | ||||||
| 	"strings" |  | ||||||
| ) | ) | ||||||
|  |  | ||||||
| type Token struct { | type Token struct { | ||||||
| @@ -38,7 +37,6 @@ func ValidateUserToken(key string) (token *Token, err error) { | |||||||
| 	if key == "" { | 	if key == "" { | ||||||
| 		return nil, errors.New("未提供 token") | 		return nil, errors.New("未提供 token") | ||||||
| 	} | 	} | ||||||
| 	key = strings.Replace(key, "Bearer ", "", 1) |  | ||||||
| 	token = &Token{} | 	token = &Token{} | ||||||
| 	err = DB.Where("`key` = ?", key).First(token).Error | 	err = DB.Where("`key` = ?", key).First(token).Error | ||||||
| 	if err == nil { | 	if err == nil { | ||||||
|   | |||||||
| @@ -2,12 +2,24 @@ package router | |||||||
|  |  | ||||||
| import ( | import ( | ||||||
| 	"embed" | 	"embed" | ||||||
|  | 	"fmt" | ||||||
| 	"github.com/gin-gonic/gin" | 	"github.com/gin-gonic/gin" | ||||||
|  | 	"net/http" | ||||||
|  | 	"os" | ||||||
|  | 	"strings" | ||||||
| ) | ) | ||||||
|  |  | ||||||
| func SetRouter(router *gin.Engine, buildFS embed.FS, indexPage []byte) { | func SetRouter(router *gin.Engine, buildFS embed.FS, indexPage []byte) { | ||||||
| 	SetApiRouter(router) | 	SetApiRouter(router) | ||||||
| 	SetDashboardRouter(router) | 	SetDashboardRouter(router) | ||||||
| 	SetRelayRouter(router) | 	SetRelayRouter(router) | ||||||
| 	setWebRouter(router, buildFS, indexPage) | 	frontendBaseUrl := os.Getenv("FRONTEND_BASE_URL") | ||||||
|  | 	if frontendBaseUrl == "" { | ||||||
|  | 		SetWebRouter(router, buildFS, indexPage) | ||||||
|  | 	} else { | ||||||
|  | 		frontendBaseUrl = strings.TrimSuffix(frontendBaseUrl, "/") | ||||||
|  | 		router.NoRoute(func(c *gin.Context) { | ||||||
|  | 			c.Redirect(http.StatusMovedPermanently, fmt.Sprintf("%s%s", frontendBaseUrl, c.Request.RequestURI)) | ||||||
|  | 		}) | ||||||
|  | 	} | ||||||
| } | } | ||||||
|   | |||||||
| @@ -11,8 +11,8 @@ func SetRelayRouter(router *gin.Engine) { | |||||||
| 	relayV1Router := router.Group("/v1") | 	relayV1Router := router.Group("/v1") | ||||||
| 	relayV1Router.Use(middleware.TokenAuth(), middleware.Distribute()) | 	relayV1Router.Use(middleware.TokenAuth(), middleware.Distribute()) | ||||||
| 	{ | 	{ | ||||||
| 		relayV1Router.GET("/models", controller.Relay) | 		relayV1Router.GET("/models", controller.ListModels) | ||||||
| 		relayV1Router.GET("/models/:model", controller.Relay) | 		relayV1Router.GET("/models/:model", controller.RetrieveModel) | ||||||
| 		relayV1Router.POST("/completions", controller.RelayNotImplemented) | 		relayV1Router.POST("/completions", controller.RelayNotImplemented) | ||||||
| 		relayV1Router.POST("/chat/completions", controller.Relay) | 		relayV1Router.POST("/chat/completions", controller.Relay) | ||||||
| 		relayV1Router.POST("/edits", controller.RelayNotImplemented) | 		relayV1Router.POST("/edits", controller.RelayNotImplemented) | ||||||
|   | |||||||
| @@ -10,7 +10,7 @@ import ( | |||||||
| 	"one-api/middleware" | 	"one-api/middleware" | ||||||
| ) | ) | ||||||
|  |  | ||||||
| func setWebRouter(router *gin.Engine, buildFS embed.FS, indexPage []byte) { | func SetWebRouter(router *gin.Engine, buildFS embed.FS, indexPage []byte) { | ||||||
| 	router.Use(gzip.Gzip(gzip.DefaultCompression)) | 	router.Use(gzip.Gzip(gzip.DefaultCompression)) | ||||||
| 	router.Use(middleware.GlobalWebRateLimit()) | 	router.Use(middleware.GlobalWebRateLimit()) | ||||||
| 	router.Use(middleware.Cache()) | 	router.Use(middleware.Cache()) | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user