merge upstream

Signed-off-by: wozulong <>
This commit is contained in:
wozulong
2024-07-19 10:58:21 +08:00
72 changed files with 1989 additions and 1193 deletions

View File

@@ -74,14 +74,14 @@ func ShouldDisableChannel(channelType int, err *relaymodel.OpenAIErrorWithStatus
return false
}
func ShouldEnableChannel(err error, openAIErr *relaymodel.OpenAIError, status int) bool {
func ShouldEnableChannel(err error, openaiWithStatusErr *relaymodel.OpenAIErrorWithStatusCode, status int) bool {
if !common.AutomaticEnableChannelEnabled {
return false
}
if err != nil {
return false
}
if openAIErr != nil {
if openaiWithStatusErr != nil {
return false
}
if status != common.ChannelStatusAutoDisabled {

View File

@@ -56,10 +56,9 @@ func RelayErrorHandler(resp *http.Response) (errWithStatusCode *dto.OpenAIErrorW
errWithStatusCode = &dto.OpenAIErrorWithStatusCode{
StatusCode: resp.StatusCode,
Error: dto.OpenAIError{
Message: "",
Type: "upstream_error",
Code: "bad_response_status_code",
Param: strconv.Itoa(resp.StatusCode),
Type: "upstream_error",
Code: "bad_response_status_code",
Param: strconv.Itoa(resp.StatusCode),
},
}
responseBody, err := io.ReadAll(resp.Body)

View File

@@ -2,10 +2,11 @@ package service
import (
"encoding/json"
"errors"
"fmt"
"github.com/gin-gonic/gin"
"net/http"
"one-api/common"
"strings"
)
func SetEventStreamHeaders(c *gin.Context) {
@@ -16,11 +17,16 @@ func SetEventStreamHeaders(c *gin.Context) {
c.Writer.Header().Set("X-Accel-Buffering", "no")
}
func StringData(c *gin.Context, str string) {
str = strings.TrimPrefix(str, "data: ")
str = strings.TrimSuffix(str, "\r")
func StringData(c *gin.Context, str string) error {
//str = strings.TrimPrefix(str, "data: ")
//str = strings.TrimSuffix(str, "\r")
c.Render(-1, common.CustomEvent{Data: "data: " + str})
c.Writer.Flush()
if flusher, ok := c.Writer.(http.Flusher); ok {
flusher.Flush()
} else {
return errors.New("streaming error: flusher not found")
}
return nil
}
func ObjectData(c *gin.Context, object interface{}) error {
@@ -28,10 +34,14 @@ func ObjectData(c *gin.Context, object interface{}) error {
if err != nil {
return fmt.Errorf("error marshalling object: %w", err)
}
StringData(c, string(jsonData))
return nil
return StringData(c, string(jsonData))
}
func Done(c *gin.Context) {
StringData(c, "[DONE]")
_ = StringData(c, "[DONE]")
}
func GetResponseID(c *gin.Context) string {
logID := c.GetString("X-Oneapi-Request-Id")
return fmt.Sprintf("chatcmpl-%s", logID)
}

View File

@@ -9,6 +9,7 @@ import (
"log"
"math"
"one-api/common"
"one-api/constant"
"one-api/dto"
"strings"
"unicode/utf8"
@@ -71,13 +72,20 @@ func getTokenNum(tokenEncoder *tiktoken.Tiktoken, text string) int {
}
func getImageToken(imageUrl *dto.MessageImageUrl, model string, stream bool) (int, error) {
// TODO: 非流模式下不计算图片token数量
if model == "glm-4v" {
return 1047, nil
}
if imageUrl.Detail == "low" {
return 85, nil
}
// TODO: 非流模式下不计算图片token数量
if !constant.GetMediaTokenNotStream && !stream {
return 1000, nil
}
// 是否统计图片token
if !constant.GetMediaToken {
return 1000, nil
}
// 同步One API的图片计费逻辑
if imageUrl.Detail == "auto" || imageUrl.Detail == "" {
imageUrl.Detail = "high"

View File

@@ -36,3 +36,7 @@ func GenerateFinalUsageResponse(id string, createAt int64, model string, usage d
Usage: &usage,
}
}
func ValidUsage(usage *dto.Usage) bool {
return usage != nil && (usage.PromptTokens != 0 || usage.CompletionTokens != 0)
}