mirror of
				https://github.com/songquanpeng/one-api.git
				synced 2025-11-04 15:53:42 +08:00 
			
		
		
		
	fix: handel error response from server correctly (close #90)
This commit is contained in:
		@@ -265,14 +265,14 @@ var testAllChannelsLock sync.Mutex
 | 
			
		||||
var testAllChannelsRunning bool = false
 | 
			
		||||
 | 
			
		||||
// disable & notify
 | 
			
		||||
func disableChannel(channelId int, channelName string, err error) {
 | 
			
		||||
func disableChannel(channelId int, channelName string, reason string) {
 | 
			
		||||
	if common.RootUserEmail == "" {
 | 
			
		||||
		common.RootUserEmail = model.GetRootUserEmail()
 | 
			
		||||
	}
 | 
			
		||||
	model.UpdateChannelStatusById(channelId, common.ChannelStatusDisabled)
 | 
			
		||||
	subject := fmt.Sprintf("通道「%s」(#%d)已被禁用", channelName, channelId)
 | 
			
		||||
	content := fmt.Sprintf("通道「%s」(#%d)已被禁用,原因:%s", channelName, channelId, err.Error())
 | 
			
		||||
	err = common.SendEmail(subject, common.RootUserEmail, content)
 | 
			
		||||
	content := fmt.Sprintf("通道「%s」(#%d)已被禁用,原因:%s", channelName, channelId, reason)
 | 
			
		||||
	err := common.SendEmail(subject, common.RootUserEmail, content)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		common.SysError(fmt.Sprintf("发送邮件失败:%s", err.Error()))
 | 
			
		||||
	}
 | 
			
		||||
@@ -312,7 +312,7 @@ func testAllChannels(c *gin.Context) error {
 | 
			
		||||
				if milliseconds > disableThreshold {
 | 
			
		||||
					err = errors.New(fmt.Sprintf("响应时间 %.2fs 超过阈值 %.2fs", float64(milliseconds)/1000.0, float64(disableThreshold)/1000.0))
 | 
			
		||||
				}
 | 
			
		||||
				disableChannel(channel.Id, channel.Name, err)
 | 
			
		||||
				disableChannel(channel.Id, channel.Name, err.Error())
 | 
			
		||||
			}
 | 
			
		||||
			channel.UpdateResponseTime(milliseconds)
 | 
			
		||||
		}
 | 
			
		||||
 
 | 
			
		||||
@@ -4,7 +4,6 @@ import (
 | 
			
		||||
	"bufio"
 | 
			
		||||
	"bytes"
 | 
			
		||||
	"encoding/json"
 | 
			
		||||
	"errors"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
	"github.com/pkoukk/tiktoken-go"
 | 
			
		||||
@@ -47,6 +46,11 @@ type OpenAIError struct {
 | 
			
		||||
	Code    string `json:"code"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type OpenAIErrorWithStatusCode struct {
 | 
			
		||||
	OpenAIError
 | 
			
		||||
	StatusCode int `json:"status_code"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type TextResponse struct {
 | 
			
		||||
	Usage `json:"usage"`
 | 
			
		||||
	Error OpenAIError `json:"error"`
 | 
			
		||||
@@ -71,23 +75,33 @@ func countToken(text string) int {
 | 
			
		||||
func Relay(c *gin.Context) {
 | 
			
		||||
	err := relayHelper(c)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		c.JSON(http.StatusOK, gin.H{
 | 
			
		||||
			"error": gin.H{
 | 
			
		||||
				"message": err.Error(),
 | 
			
		||||
				"type":    "one_api_error",
 | 
			
		||||
			},
 | 
			
		||||
		c.JSON(err.StatusCode, gin.H{
 | 
			
		||||
			"error": err.OpenAIError,
 | 
			
		||||
		})
 | 
			
		||||
		channelId := c.GetInt("channel_id")
 | 
			
		||||
		common.SysError(fmt.Sprintf("Relay error: %s, channel id: %d", err.Error(), channelId))
 | 
			
		||||
		if common.AutomaticDisableChannelEnabled {
 | 
			
		||||
		common.SysError(fmt.Sprintf("Relay error (channel #%d): %s", channelId, err.Message))
 | 
			
		||||
		if err.Type != "invalid_request_error" && err.StatusCode != http.StatusTooManyRequests &&
 | 
			
		||||
			common.AutomaticDisableChannelEnabled {
 | 
			
		||||
			channelId := c.GetInt("channel_id")
 | 
			
		||||
			channelName := c.GetString("channel_name")
 | 
			
		||||
			disableChannel(channelId, channelName, err)
 | 
			
		||||
			disableChannel(channelId, channelName, err.Message)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func relayHelper(c *gin.Context) error {
 | 
			
		||||
func errorWrapper(err error, code string, statusCode int) *OpenAIErrorWithStatusCode {
 | 
			
		||||
	openAIError := OpenAIError{
 | 
			
		||||
		Message: err.Error(),
 | 
			
		||||
		Type:    "one_api_error",
 | 
			
		||||
		Code:    code,
 | 
			
		||||
	}
 | 
			
		||||
	return &OpenAIErrorWithStatusCode{
 | 
			
		||||
		OpenAIError: openAIError,
 | 
			
		||||
		StatusCode:  statusCode,
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func relayHelper(c *gin.Context) *OpenAIErrorWithStatusCode {
 | 
			
		||||
	channelType := c.GetInt("channel")
 | 
			
		||||
	tokenId := c.GetInt("token_id")
 | 
			
		||||
	consumeQuota := c.GetBool("consume_quota")
 | 
			
		||||
@@ -95,15 +109,15 @@ func relayHelper(c *gin.Context) error {
 | 
			
		||||
	if consumeQuota || channelType == common.ChannelTypeAzure {
 | 
			
		||||
		requestBody, err := io.ReadAll(c.Request.Body)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return err
 | 
			
		||||
			return errorWrapper(err, "read_request_body_failed", http.StatusBadRequest)
 | 
			
		||||
		}
 | 
			
		||||
		err = c.Request.Body.Close()
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return err
 | 
			
		||||
			return errorWrapper(err, "close_request_body_failed", http.StatusBadRequest)
 | 
			
		||||
		}
 | 
			
		||||
		err = json.Unmarshal(requestBody, &textRequest)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return err
 | 
			
		||||
			return errorWrapper(err, "unmarshal_request_body_failed", http.StatusBadRequest)
 | 
			
		||||
		}
 | 
			
		||||
		// Reset request body
 | 
			
		||||
		c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
 | 
			
		||||
@@ -146,12 +160,12 @@ func relayHelper(c *gin.Context) error {
 | 
			
		||||
	if consumeQuota {
 | 
			
		||||
		err := model.PreConsumeTokenQuota(tokenId, preConsumedQuota)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return err
 | 
			
		||||
			return errorWrapper(err, "pre_consume_token_quota_failed", http.StatusOK)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	req, err := http.NewRequest(c.Request.Method, fullRequestURL, c.Request.Body)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
		return errorWrapper(err, "new_request_failed", http.StatusOK)
 | 
			
		||||
	}
 | 
			
		||||
	if channelType == common.ChannelTypeAzure {
 | 
			
		||||
		key := c.Request.Header.Get("Authorization")
 | 
			
		||||
@@ -166,15 +180,15 @@ func relayHelper(c *gin.Context) error {
 | 
			
		||||
	client := &http.Client{}
 | 
			
		||||
	resp, err := client.Do(req)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
		return errorWrapper(err, "do_request_failed", http.StatusOK)
 | 
			
		||||
	}
 | 
			
		||||
	err = req.Body.Close()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
		return errorWrapper(err, "close_request_body_failed", http.StatusOK)
 | 
			
		||||
	}
 | 
			
		||||
	err = c.Request.Body.Close()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
		return errorWrapper(err, "close_request_body_failed", http.StatusOK)
 | 
			
		||||
	}
 | 
			
		||||
	var textResponse TextResponse
 | 
			
		||||
	isStream := strings.HasPrefix(resp.Header.Get("Content-Type"), "text/event-stream")
 | 
			
		||||
@@ -259,50 +273,60 @@ func relayHelper(c *gin.Context) error {
 | 
			
		||||
		})
 | 
			
		||||
		err = resp.Body.Close()
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return err
 | 
			
		||||
			return errorWrapper(err, "close_response_body_failed", http.StatusOK)
 | 
			
		||||
		}
 | 
			
		||||
		return nil
 | 
			
		||||
	} else {
 | 
			
		||||
		for k, v := range resp.Header {
 | 
			
		||||
			c.Writer.Header().Set(k, v[0])
 | 
			
		||||
		}
 | 
			
		||||
		if consumeQuota {
 | 
			
		||||
			responseBody, err := io.ReadAll(resp.Body)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				return err
 | 
			
		||||
				return errorWrapper(err, "read_response_body_failed", http.StatusOK)
 | 
			
		||||
			}
 | 
			
		||||
			err = resp.Body.Close()
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				return err
 | 
			
		||||
				return errorWrapper(err, "close_response_body_failed", http.StatusOK)
 | 
			
		||||
			}
 | 
			
		||||
			err = json.Unmarshal(responseBody, &textResponse)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				return err
 | 
			
		||||
				return errorWrapper(err, "unmarshal_response_body_failed", http.StatusOK)
 | 
			
		||||
			}
 | 
			
		||||
			if textResponse.Error.Type != "" {
 | 
			
		||||
				return errors.New(fmt.Sprintf("type %s, code %s, message %s",
 | 
			
		||||
					textResponse.Error.Type, textResponse.Error.Code, textResponse.Error.Message))
 | 
			
		||||
				return &OpenAIErrorWithStatusCode{
 | 
			
		||||
					OpenAIError: textResponse.Error,
 | 
			
		||||
					StatusCode:  resp.StatusCode,
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
			// Reset response body
 | 
			
		||||
			resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
 | 
			
		||||
		}
 | 
			
		||||
		// We shouldn't set the header before we parse the response body, because the parse part may fail.
 | 
			
		||||
		// And then we will have to send an error response, but in this case, the header has already been set.
 | 
			
		||||
		// So the client will be confused by the response.
 | 
			
		||||
		// For example, Postman will report error, and we cannot check the response at all.
 | 
			
		||||
		for k, v := range resp.Header {
 | 
			
		||||
			c.Writer.Header().Set(k, v[0])
 | 
			
		||||
		}
 | 
			
		||||
		c.Writer.WriteHeader(resp.StatusCode)
 | 
			
		||||
		_, err = io.Copy(c.Writer, resp.Body)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return err
 | 
			
		||||
			return errorWrapper(err, "copy_response_body_failed", http.StatusOK)
 | 
			
		||||
		}
 | 
			
		||||
		err = resp.Body.Close()
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return err
 | 
			
		||||
			return errorWrapper(err, "close_response_body_failed", http.StatusOK)
 | 
			
		||||
		}
 | 
			
		||||
		return nil
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func RelayNotImplemented(c *gin.Context) {
 | 
			
		||||
	err := OpenAIError{
 | 
			
		||||
		Message: "API not implemented",
 | 
			
		||||
		Type:    "one_api_error",
 | 
			
		||||
		Param:   "",
 | 
			
		||||
		Code:    "api_not_implemented",
 | 
			
		||||
	}
 | 
			
		||||
	c.JSON(http.StatusOK, gin.H{
 | 
			
		||||
		"error": gin.H{
 | 
			
		||||
			"message": "Not Implemented",
 | 
			
		||||
			"type":    "one_api_error",
 | 
			
		||||
		},
 | 
			
		||||
		"error": err,
 | 
			
		||||
	})
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user