merge upstream

Signed-off-by: wozulong <>
This commit is contained in:
wozulong 2024-07-22 10:51:49 +08:00
commit f0008e95fa
11 changed files with 128 additions and 30 deletions

View File

@ -58,12 +58,12 @@
您可以在渠道中添加自定义模型gpt-4-gizmo-*或g-*此模型并非OpenAI官方模型而是第三方模型使用官方key无法调用。 您可以在渠道中添加自定义模型gpt-4-gizmo-*或g-*此模型并非OpenAI官方模型而是第三方模型使用官方key无法调用。
## 比原版One API多出的配置 ## 比原版One API多出的配置
- `STREAMING_TIMEOUT`:设置流式一次回复的超时时间,默认为 30 秒 - `STREAMING_TIMEOUT`:设置流式一次回复的超时时间,默认为 30 秒
- `DIFY_DEBUG`:设置 Dify 渠道是否输出工作流和节点信息到客户端,默认为 `true` - `DIFY_DEBUG`:设置 Dify 渠道是否输出工作流和节点信息到客户端,默认为 `true`
- `FORCE_STREAM_OPTION`是否覆盖客户端stream_options参数请求上游返回流模式usage默认为 `true` - `FORCE_STREAM_OPTION`是否覆盖客户端stream_options参数请求上游返回流模式usage默认为 `true`建议开启不影响客户端传入stream_options参数返回结果。
- `GET_MEDIA_TOKEN`是统计图片token默认为 `true`关闭后将不再在本地计算图片token可能会导致和上游计费不同此项覆盖 `GET_MEDIA_TOKEN_NOT_STREAM` 选项作用 - `GET_MEDIA_TOKEN`是统计图片token默认为 `true`关闭后将不再在本地计算图片token可能会导致和上游计费不同此项覆盖 `GET_MEDIA_TOKEN_NOT_STREAM` 选项作用
- `GET_MEDIA_TOKEN_NOT_STREAM`:是否在非流(`stream=false`情况下统计图片token默认为 `true` - `GET_MEDIA_TOKEN_NOT_STREAM`:是否在非流(`stream=false`情况下统计图片token默认为 `true`
- `UPDATE_TASK`是否更新异步任务Midjourney、Suno默认为 `true`,关闭后将不会更新任务进度 - `UPDATE_TASK`是否更新异步任务Midjourney、Suno默认为 `true`,关闭后将不会更新任务进度
## 部署 ## 部署
### 部署要求 ### 部署要求

View File

@ -6,6 +6,7 @@ import (
"net/http" "net/http"
"one-api/common" "one-api/common"
"one-api/model" "one-api/model"
"strconv"
"strings" "strings"
) )
@ -16,6 +17,7 @@ func authHelper(c *gin.Context, minRole int) {
id := session.Get("id") id := session.Get("id")
status := session.Get("status") status := session.Get("status")
linuxDoEnable := session.Get("linuxdo_enable") linuxDoEnable := session.Get("linuxdo_enable")
useAccessToken := false
if username == nil { if username == nil {
// Check access token // Check access token
accessToken := c.Request.Header.Get("Authorization") accessToken := c.Request.Header.Get("Authorization")
@ -35,6 +37,7 @@ func authHelper(c *gin.Context, minRole int) {
id = user.Id id = user.Id
status = user.Status status = user.Status
linuxDoEnable = user.LinuxDoId == "" || user.LinuxDoLevel >= common.LinuxDoMinLevel linuxDoEnable = user.LinuxDoId == "" || user.LinuxDoLevel >= common.LinuxDoMinLevel
useAccessToken = true
} else { } else {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": false, "success": false,
@ -44,6 +47,36 @@ func authHelper(c *gin.Context, minRole int) {
return return
} }
} }
if !useAccessToken {
// get header New-Api-User
apiUserIdStr := c.Request.Header.Get("New-Api-User")
if apiUserIdStr == "" {
c.JSON(http.StatusUnauthorized, gin.H{
"success": false,
"message": "无权进行此操作,请刷新页面或清空缓存后重试",
})
c.Abort()
return
}
apiUserId, err := strconv.Atoi(apiUserIdStr)
if err != nil {
c.JSON(http.StatusUnauthorized, gin.H{
"success": false,
"message": "无权进行此操作,登录信息无效,请重新登录",
})
c.Abort()
return
}
if id != apiUserId {
c.JSON(http.StatusUnauthorized, gin.H{
"success": false,
"message": "无权进行此操作,与登录用户不匹配,请重新登录",
})
c.Abort()
return
}
}
if status.(int) == common.UserStatusDisabled { if status.(int) == common.UserStatusDisabled {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": false, "success": false,

View File

@ -102,7 +102,7 @@ func GetAllLogs(logType int, startTimestamp int64, endTimestamp int64, modelName
tx = DB.Where("type = ?", logType) tx = DB.Where("type = ?", logType)
} }
if modelName != "" { if modelName != "" {
tx = tx.Where("model_name = ?", modelName) tx = tx.Where("model_name like ?", "%"+modelName+"%")
} }
if username != "" { if username != "" {
tx = tx.Where("username = ?", username) tx = tx.Where("username = ?", username)

View File

@ -198,7 +198,6 @@ func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) *dto.Ch
choice.Delta.SetContentString(respFirst.Text) choice.Delta.SetContentString(respFirst.Text)
} }
} }
choice.FinishReason = &relaycommon.StopFinishReason
var response dto.ChatCompletionsStreamResponse var response dto.ChatCompletionsStreamResponse
response.Object = "chat.completion.chunk" response.Object = "chat.completion.chunk"
response.Model = "gemini" response.Model = "gemini"
@ -247,10 +246,14 @@ func geminiChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycom
common.LogError(c, err.Error()) common.LogError(c, err.Error())
} }
} }
response := service.GenerateStopResponse(id, createAt, info.UpstreamModelName, relaycommon.StopFinishReason)
service.ObjectData(c, response)
usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
if info.ShouldIncludeUsage { if info.ShouldIncludeUsage {
response := service.GenerateFinalUsageResponse(id, createAt, info.UpstreamModelName, *usage) response = service.GenerateFinalUsageResponse(id, createAt, info.UpstreamModelName, *usage)
err := service.ObjectData(c, response) err := service.ObjectData(c, response)
if err != nil { if err != nil {
common.SysError("send final response failed: " + err.Error()) common.SysError("send final response failed: " + err.Error())

View File

@ -16,6 +16,7 @@ import (
relayconstant "one-api/relay/constant" relayconstant "one-api/relay/constant"
"one-api/service" "one-api/service"
"strings" "strings"
"sync"
"time" "time"
) )
@ -41,7 +42,10 @@ func OaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
stopChan := make(chan bool) stopChan := make(chan bool)
defer close(stopChan) defer close(stopChan)
var (
lastStreamData string
mu sync.Mutex
)
gopool.Go(func() { gopool.Go(func() {
for scanner.Scan() { for scanner.Scan() {
info.SetFirstResponseTime() info.SetFirstResponseTime()
@ -53,14 +57,19 @@ func OaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
if data[:6] != "data: " && data[:6] != "[DONE]" { if data[:6] != "data: " && data[:6] != "[DONE]" {
continue continue
} }
mu.Lock()
data = data[6:] data = data[6:]
if !strings.HasPrefix(data, "[DONE]") { if !strings.HasPrefix(data, "[DONE]") {
err := service.StringData(c, data) if lastStreamData != "" {
err := service.StringData(c, lastStreamData)
if err != nil { if err != nil {
common.LogError(c, "streaming error: "+err.Error()) common.LogError(c, "streaming error: "+err.Error())
} }
}
lastStreamData = data
streamItems = append(streamItems, data) streamItems = append(streamItems, data)
} }
mu.Unlock()
} }
common.SafeSendBool(stopChan, true) common.SafeSendBool(stopChan, true)
}) })
@ -73,6 +82,20 @@ func OaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
// 正常结束 // 正常结束
} }
shouldSendLastResp := true
var lastStreamResponse dto.ChatCompletionsStreamResponse
err := json.Unmarshal(common.StringToByteSlice(lastStreamData), &lastStreamResponse)
if err == nil {
if lastStreamResponse.Usage != nil && service.ValidUsage(lastStreamResponse.Usage) {
if !info.ShouldIncludeUsage {
shouldSendLastResp = false
}
}
}
if shouldSendLastResp {
service.StringData(c, lastStreamData)
}
// 计算token // 计算token
streamResp := "[" + strings.Join(streamItems, ",") + "]" streamResp := "[" + strings.Join(streamItems, ",") + "]"
switch info.RelayMode { switch info.RelayMode {

View File

@ -130,6 +130,12 @@ func TextHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
return openaiErr return openaiErr
} }
includeUsage := false
// 判断用户是否需要返回使用情况
if textRequest.StreamOptions != nil && textRequest.StreamOptions.IncludeUsage {
includeUsage = true
}
// 如果不支持StreamOptions将StreamOptions设置为nil // 如果不支持StreamOptions将StreamOptions设置为nil
if !relayInfo.SupportStreamOptions || !textRequest.Stream { if !relayInfo.SupportStreamOptions || !textRequest.Stream {
textRequest.StreamOptions = nil textRequest.StreamOptions = nil
@ -142,8 +148,8 @@ func TextHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
} }
} }
if textRequest.StreamOptions != nil && textRequest.StreamOptions.IncludeUsage { if includeUsage {
relayInfo.ShouldIncludeUsage = textRequest.StreamOptions.IncludeUsage relayInfo.ShouldIncludeUsage = true
} }
adaptor := GetAdaptor(relayInfo.ApiType) adaptor := GetAdaptor(relayInfo.ApiType)

View File

@ -7,6 +7,7 @@ import (
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"net/http" "net/http"
"one-api/common" "one-api/common"
"one-api/dto"
) )
func SetEventStreamHeaders(c *gin.Context) { func SetEventStreamHeaders(c *gin.Context) {
@ -45,3 +46,30 @@ func GetResponseID(c *gin.Context) string {
logID := c.GetString("X-Oneapi-Request-Id") logID := c.GetString("X-Oneapi-Request-Id")
return fmt.Sprintf("chatcmpl-%s", logID) return fmt.Sprintf("chatcmpl-%s", logID)
} }
func GenerateStopResponse(id string, createAt int64, model string, finishReason string) *dto.ChatCompletionsStreamResponse {
return &dto.ChatCompletionsStreamResponse{
Id: id,
Object: "chat.completion.chunk",
Created: createAt,
Model: model,
SystemFingerprint: nil,
Choices: []dto.ChatCompletionsStreamResponseChoice{
{
FinishReason: &finishReason,
},
},
}
}
func GenerateFinalUsageResponse(id string, createAt int64, model string, usage dto.Usage) *dto.ChatCompletionsStreamResponse {
return &dto.ChatCompletionsStreamResponse{
Id: id,
Object: "chat.completion.chunk",
Created: createAt,
Model: model,
SystemFingerprint: nil,
Choices: make([]dto.ChatCompletionsStreamResponseChoice, 0),
Usage: &usage,
}
}

View File

@ -72,11 +72,12 @@ func getTokenNum(tokenEncoder *tiktoken.Tiktoken, text string) int {
} }
func getImageToken(imageUrl *dto.MessageImageUrl, model string, stream bool) (int, error) { func getImageToken(imageUrl *dto.MessageImageUrl, model string, stream bool) (int, error) {
baseTokens := 85
if model == "glm-4v" { if model == "glm-4v" {
return 1047, nil return 1047, nil
} }
if imageUrl.Detail == "low" { if imageUrl.Detail == "low" {
return 85, nil return baseTokens, nil
} }
// TODO: 非流模式下不计算图片token数量 // TODO: 非流模式下不计算图片token数量
if !constant.GetMediaTokenNotStream && !stream { if !constant.GetMediaTokenNotStream && !stream {
@ -90,6 +91,12 @@ func getImageToken(imageUrl *dto.MessageImageUrl, model string, stream bool) (in
if imageUrl.Detail == "auto" || imageUrl.Detail == "" { if imageUrl.Detail == "auto" || imageUrl.Detail == "" {
imageUrl.Detail = "high" imageUrl.Detail = "high"
} }
tileTokens := 170
if strings.HasPrefix(model, "gpt-4o-mini") {
tileTokens = 5667
baseTokens = 2833
}
var config image.Config var config image.Config
var err error var err error
var format string var format string
@ -137,7 +144,7 @@ func getImageToken(imageUrl *dto.MessageImageUrl, model string, stream bool) (in
// 计算图片的token数量(边的长度除以512向上取整) // 计算图片的token数量(边的长度除以512向上取整)
tiles := (shortSide + 511) / 512 * ((otherSide + 511) / 512) tiles := (shortSide + 511) / 512 * ((otherSide + 511) / 512)
log.Printf("tiles: %d", tiles) log.Printf("tiles: %d", tiles)
return tiles*170 + 85, nil return tiles*tileTokens + baseTokens, nil
} }
func CountTokenChatRequest(request dto.GeneralOpenAIRequest, model string) (int, error) { func CountTokenChatRequest(request dto.GeneralOpenAIRequest, model string) (int, error) {

View File

@ -25,18 +25,6 @@ func ResponseText2Usage(responseText string, modeName string, promptTokens int)
return usage, err return usage, err
} }
func GenerateFinalUsageResponse(id string, createAt int64, model string, usage dto.Usage) *dto.ChatCompletionsStreamResponse {
return &dto.ChatCompletionsStreamResponse{
Id: id,
Object: "chat.completion.chunk",
Created: createAt,
Model: model,
SystemFingerprint: nil,
Choices: make([]dto.ChatCompletionsStreamResponseChoice, 0),
Usage: &usage,
}
}
func ValidUsage(usage *dto.Usage) bool { func ValidUsage(usage *dto.Usage) bool {
return usage != nil && (usage.PromptTokens != 0 || usage.CompletionTokens != 0) return usage != nil && (usage.PromptTokens != 0 || usage.CompletionTokens != 0)
} }

View File

@ -1,10 +1,13 @@
import { showError } from './utils'; import { getUserIdFromLocalStorage, showError } from './utils';
import axios from 'axios'; import axios from 'axios';
export const API = axios.create({ export const API = axios.create({
baseURL: import.meta.env.VITE_REACT_APP_SERVER_URL baseURL: import.meta.env.VITE_REACT_APP_SERVER_URL
? import.meta.env.VITE_REACT_APP_SERVER_URL ? import.meta.env.VITE_REACT_APP_SERVER_URL
: '', : '',
headers: {
'New-API-User': getUserIdFromLocalStorage(),
},
}); });
API.interceptors.response.use( API.interceptors.response.use(

View File

@ -33,6 +33,13 @@ export function getLogo() {
return logo; return logo;
} }
export function getUserIdFromLocalStorage() {
let user = localStorage.getItem('user');
if (!user) return -1;
user = JSON.parse(user);
return user.id;
}
export function getFooterHTML() { export function getFooterHTML() {
return localStorage.getItem('footer_html'); return localStorage.getItem('footer_html');
} }