merge upstream

Signed-off-by: wozulong <>
This commit is contained in:
wozulong 2024-07-01 15:14:22 +08:00
commit 895ee09b33
41 changed files with 2140 additions and 1671 deletions

View File

@ -5,8 +5,6 @@
> 本项目为开源项目,在[One API](https://github.com/songquanpeng/one-api)的基础上进行二次开发,感谢原作者的无私奉献。
> 使用者必须在遵循 OpenAI 的[使用条款](https://openai.com/policies/terms-of-use)以及**法律法规**的情况下使用,不得用于非法用途。
> [!WARNING]
> 本项目为个人学习使用,不保证稳定性,且不提供任何技术支持,使用者必须在遵循 OpenAI 的使用条款以及法律法规的情况下使用,不得用于非法用途。
> 根据[《生成式人工智能服务管理暂行办法》](http://www.cac.gov.cn/2023-07/13/c_1690898327029107.htm)的要求,请勿对中国地区公众提供一切未经备案的生成式人工智能服务。
@ -85,8 +83,13 @@
```
可以实现400错误转为500错误从而重试
## 比原版One API多出的配置
- `STREAMING_TIMEOUT`:设置流式一次回复的超时时间,默认为 30 秒
## 部署
### 部署要求
- 本地数据库默认SQLiteDocker 部署默认使用 SQLite必须挂载 `/data` 目录到宿主机)
- 远程数据库MySQL 版本 >= 5.7.8PgSQL 版本 >= 9.6
### 基于 Docker 进行部署
```shell
# 使用 SQLite 的部署命令:

View File

@ -120,14 +120,14 @@ var IsMasterNode = os.Getenv("NODE_TYPE") != "slave"
var requestInterval, _ = strconv.Atoi(os.Getenv("POLLING_INTERVAL"))
var RequestInterval = time.Duration(requestInterval) * time.Second
var SyncFrequency = GetOrDefault("SYNC_FREQUENCY", 60) // unit is second
var SyncFrequency = GetEnvOrDefault("SYNC_FREQUENCY", 60) // unit is second
var BatchUpdateEnabled = false
var BatchUpdateInterval = GetOrDefault("BATCH_UPDATE_INTERVAL", 5)
var BatchUpdateInterval = GetEnvOrDefault("BATCH_UPDATE_INTERVAL", 5)
var RelayTimeout = GetOrDefault("RELAY_TIMEOUT", 0) // unit is second
var RelayTimeout = GetEnvOrDefault("RELAY_TIMEOUT", 0) // unit is second
var GeminiSafetySetting = GetOrDefaultString("GEMINI_SAFETY_SETTING", "BLOCK_NONE")
var GeminiSafetySetting = GetEnvOrDefaultString("GEMINI_SAFETY_SETTING", "BLOCK_NONE")
const (
RequestIdKey = "X-Oneapi-Request-Id"
@ -150,10 +150,10 @@ var (
// All duration's unit is seconds
// Shouldn't larger then RateLimitKeyExpirationDuration
var (
GlobalApiRateLimitNum = GetOrDefault("GLOBAL_API_RATE_LIMIT", 180)
GlobalApiRateLimitNum = GetEnvOrDefault("GLOBAL_API_RATE_LIMIT", 180)
GlobalApiRateLimitDuration int64 = 3 * 60
GlobalWebRateLimitNum = GetOrDefault("GLOBAL_WEB_RATE_LIMIT", 60)
GlobalWebRateLimitNum = GetEnvOrDefault("GLOBAL_WEB_RATE_LIMIT", 60)
GlobalWebRateLimitDuration int64 = 3 * 60
UploadRateLimitNum = 10

26
common/env.go Normal file
View File

@ -0,0 +1,26 @@
package common
import (
"fmt"
"os"
"strconv"
)
func GetEnvOrDefault(env string, defaultValue int) int {
if env == "" || os.Getenv(env) == "" {
return defaultValue
}
num, err := strconv.Atoi(os.Getenv(env))
if err != nil {
SysError(fmt.Sprintf("failed to parse %s: %s, using default value: %d", env, err.Error(), defaultValue))
return defaultValue
}
return num
}
func GetEnvOrDefaultString(env string, defaultValue string) string {
if env == "" || os.Getenv(env) == "" {
return defaultValue
}
return os.Getenv(env)
}

View File

@ -3,6 +3,7 @@ package common
import (
"fmt"
"runtime/debug"
"time"
)
func SafeGoroutine(f func()) {
@ -45,3 +46,21 @@ func SafeSendString(ch chan string, value string) (closed bool) {
// If the code reaches here, then the channel was not closed.
return false
}
// SafeSendStringTimeout send, return true, else return false
func SafeSendStringTimeout(ch chan string, value string, timeout int) (closed bool) {
defer func() {
// Recover from panic if one occured. A panic would mean the channel was closed.
if recover() != nil {
closed = false
}
}()
// This will panic if the channel is closed.
select {
case ch <- value:
return true
case <-time.After(time.Duration(timeout) * time.Second):
return false
}
}

View File

@ -1,6 +1,8 @@
package common
import "encoding/json"
import (
"encoding/json"
)
var GroupRatio = map[string]float64{
"default": 1,

View File

@ -114,6 +114,7 @@ var defaultModelRatio = map[string]float64{
"SparkDesk-v2.1": 1.2858, // ¥0.018 / 1k tokens
"SparkDesk-v3.1": 1.2858, // ¥0.018 / 1k tokens
"SparkDesk-v3.5": 1.2858, // ¥0.018 / 1k tokens
"SparkDesk-v4.0": 1.2858,
"360GPT_S2_V9": 0.8572, // ¥0.012 / 1k tokens
"360gpt-turbo": 0.0858, // ¥0.0012 / 1k tokens
"360gpt-turbo-responsibility-8k": 0.8572, // ¥0.012 / 1k tokens

View File

@ -1,6 +1,8 @@
package common
import "encoding/json"
import (
"encoding/json"
)
var TopupGroupRatio = map[string]float64{
"default": 1,

View File

@ -13,7 +13,6 @@ import (
"net"
"net/http"
"net/url"
"os"
"os/exec"
"runtime"
"strconv"
@ -196,25 +195,6 @@ func Max(a int, b int) int {
}
}
func GetOrDefault(env string, defaultValue int) int {
if env == "" || os.Getenv(env) == "" {
return defaultValue
}
num, err := strconv.Atoi(os.Getenv(env))
if err != nil {
SysError(fmt.Sprintf("failed to parse %s: %s, using default value: %d", env, err.Error(), defaultValue))
return defaultValue
}
return num
}
func GetOrDefaultString(env string, defaultValue string) string {
if env == "" || os.Getenv(env) == "" {
return defaultValue
}
return os.Getenv(env)
}
func MessageWithRequestId(message string, id string) string {
return fmt.Sprintf("%s (request id: %s)", message, id)
}

7
constant/env.go Normal file
View File

@ -0,0 +1,7 @@
package constant
import (
"one-api/common"
)
var StreamingTimeout = common.GetEnvOrDefault("STREAMING_TIMEOUT", 30)

View File

@ -222,6 +222,7 @@ func testAllChannels(notify bool) error {
if channel.AutoBan != nil && *channel.AutoBan == 0 {
ban = false
}
if openaiErr != nil {
openAiErrWithStatus := dto.OpenAIErrorWithStatusCode{
StatusCode: -1,
Error: *openaiErr,
@ -233,6 +234,7 @@ func testAllChannels(notify bool) error {
if !isChannelEnabled && service.ShouldEnableChannel(err, openaiErr, channel.Status) {
service.EnableChannel(channel.Id, channel.Name)
}
}
channel.UpdateResponseTime(milliseconds)
time.Sleep(common.RequestInterval)
}

View File

@ -19,9 +19,6 @@ import (
)
func UpdateMidjourneyTaskBulk() {
if !common.IsMasterNode {
return
}
//imageModel := "midjourney"
ctx := context.TODO()
for {

View File

@ -24,14 +24,3 @@ type OpenAIModels struct {
Root string `json:"root"`
Parent *string `json:"parent"`
}
type ModelPricing struct {
Available bool `json:"available"`
ModelName string `json:"model_name"`
QuotaType int `json:"quota_type"`
ModelRatio float64 `json:"model_ratio"`
ModelPrice float64 `json:"model_price"`
OwnerBy string `json:"owner_by"`
CompletionRatio float64 `json:"completion_ratio"`
EnableGroup []string `json:"enable_group,omitempty"`
}

View File

@ -89,12 +89,14 @@ func main() {
}
go controller.AutomaticallyTestChannels(frequency)
}
if common.IsMasterNode {
common.SafeGoroutine(func() {
controller.UpdateMidjourneyTaskBulk()
})
common.SafeGoroutine(func() {
controller.UpdateTaskBulk()
})
}
if os.Getenv("BATCH_UPDATE_ENABLED") == "true" {
common.BatchUpdateEnabled = true
common.SysLog("batch update enabled with interval " + strconv.Itoa(common.BatchUpdateInterval) + "s")

View File

@ -56,6 +56,11 @@ func getPriority(group string, model string, retry int) (int, error) {
return 0, err
}
if len(priorities) == 0 {
// 如果没有查询到优先级,则返回错误
return 0, errors.New("数据库一致性被破坏")
}
// 确定要使用的优先级
var priorityToUse int
if retry >= len(priorities) {
@ -199,7 +204,7 @@ func FixAbility() (int, error) {
// Use channelIds to find channel not in abilities table
var abilityChannelIds []int
err = DB.Model(&Ability{}).Pluck("channel_id", &abilityChannelIds).Error
err = DB.Table("abilities").Distinct("channel_id").Pluck("channel_id", &abilityChannelIds).Error
if err != nil {
common.SysError(fmt.Sprintf("Get channel ids from abilities table failed: %s", err.Error()))
return 0, err

View File

@ -86,9 +86,9 @@ func InitDB() (err error) {
if err != nil {
return err
}
sqlDB.SetMaxIdleConns(common.GetOrDefault("SQL_MAX_IDLE_CONNS", 100))
sqlDB.SetMaxOpenConns(common.GetOrDefault("SQL_MAX_OPEN_CONNS", 1000))
sqlDB.SetConnMaxLifetime(time.Second * time.Duration(common.GetOrDefault("SQL_MAX_LIFETIME", 60)))
sqlDB.SetMaxIdleConns(common.GetEnvOrDefault("SQL_MAX_IDLE_CONNS", 100))
sqlDB.SetMaxOpenConns(common.GetEnvOrDefault("SQL_MAX_OPEN_CONNS", 1000))
sqlDB.SetConnMaxLifetime(time.Second * time.Duration(common.GetEnvOrDefault("SQL_MAX_LIFETIME", 60)))
if !common.IsMasterNode {
return nil

View File

@ -2,18 +2,28 @@ package model
import (
"one-api/common"
"one-api/dto"
"sync"
"time"
)
type Pricing struct {
Available bool `json:"available"`
ModelName string `json:"model_name"`
QuotaType int `json:"quota_type"`
ModelRatio float64 `json:"model_ratio"`
ModelPrice float64 `json:"model_price"`
OwnerBy string `json:"owner_by"`
CompletionRatio float64 `json:"completion_ratio"`
EnableGroup []string `json:"enable_group,omitempty"`
}
var (
pricingMap []dto.ModelPricing
pricingMap []Pricing
lastGetPricingTime time.Time
updatePricingLock sync.Mutex
)
func GetPricing(group string) []dto.ModelPricing {
func GetPricing(group string) []Pricing {
updatePricingLock.Lock()
defer updatePricingLock.Unlock()
@ -21,7 +31,7 @@ func GetPricing(group string) []dto.ModelPricing {
updatePricing()
}
if group != "" {
userPricingMap := make([]dto.ModelPricing, 0)
userPricingMap := make([]Pricing, 0)
models := GetGroupModels(group)
for _, pricing := range pricingMap {
if !common.StringsContains(models, pricing.ModelName) {
@ -42,9 +52,9 @@ func updatePricing() {
allModels[model] = i
}
pricingMap = make([]dto.ModelPricing, 0)
pricingMap = make([]Pricing, 0)
for model, _ := range allModels {
pricing := dto.ModelPricing{
pricing := Pricing{
Available: true,
ModelName: model,
}

View File

@ -14,6 +14,7 @@ import (
"one-api/relay/channel/claude"
relaycommon "one-api/relay/common"
"strings"
"time"
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/credentials"
@ -156,6 +157,7 @@ func awsStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, requestMode i
var usage relaymodel.Usage
var id string
var model string
isFirst := true
createdTime := common.GetTimestamp()
c.Stream(func(w io.Writer) bool {
event, ok := <-stream.Events()
@ -166,6 +168,10 @@ func awsStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, requestMode i
switch v := event.(type) {
case *types.ResponseStreamMemberChunk:
if isFirst {
isFirst = false
info.FirstResponseTime = time.Now()
}
claudeResp := new(claude.ClaudeResponse)
err := json.NewDecoder(bytes.NewReader(v.Value.Bytes)).Decode(claudeResp)
if err != nil {

View File

@ -65,7 +65,7 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
if info.IsStream {
err, usage = claudeStreamHandler(a.RequestMode, info.UpstreamModelName, info.PromptTokens, c, resp)
err, usage = claudeStreamHandler(c, resp, info, a.RequestMode)
} else {
err, usage = claudeHandler(a.RequestMode, c, resp, info.PromptTokens, info.UpstreamModelName)
}

View File

@ -8,6 +8,7 @@ var ModelList = []string{
"claude-3-sonnet-20240229",
"claude-3-opus-20240229",
"claude-3-haiku-20240307",
"claude-3-5-sonnet-20240620",
}
var ChannelName = "claude"

View File

@ -8,9 +8,12 @@ import (
"io"
"net/http"
"one-api/common"
"one-api/constant"
"one-api/dto"
relaycommon "one-api/relay/common"
"one-api/service"
"strings"
"time"
)
func stopReasonClaude2OpenAI(reason string) string {
@ -246,7 +249,7 @@ func ResponseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) *dto.Ope
return &fullTextResponse
}
func claudeStreamHandler(requestMode int, modelName string, promptTokens int, c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
func claudeStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, requestMode int) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
responseId := fmt.Sprintf("chatcmpl-%s", common.GetUUID())
var usage *dto.Usage
usage = &dto.Usage{}
@ -265,8 +268,8 @@ func claudeStreamHandler(requestMode int, modelName string, promptTokens int, c
}
return 0, nil, nil
})
dataChan := make(chan string)
stopChan := make(chan bool)
dataChan := make(chan string, 5)
stopChan := make(chan bool, 2)
go func() {
for scanner.Scan() {
data := scanner.Text()
@ -274,14 +277,23 @@ func claudeStreamHandler(requestMode int, modelName string, promptTokens int, c
continue
}
data = strings.TrimPrefix(data, "data: ")
dataChan <- data
if !common.SafeSendStringTimeout(dataChan, data, constant.StreamingTimeout) {
// send data timeout, stop the stream
common.LogError(c, "send data timeout, stop the stream")
break
}
}
stopChan <- true
}()
isFirst := true
service.SetEventStreamHeaders(c)
c.Stream(func(w io.Writer) bool {
select {
case data := <-dataChan:
if isFirst {
isFirst = false
info.FirstResponseTime = time.Now()
}
// some implementations may add \r at the end of data
data = strings.TrimSuffix(data, "\r")
var claudeResponse ClaudeResponse
@ -302,7 +314,7 @@ func claudeStreamHandler(requestMode int, modelName string, promptTokens int, c
if claudeResponse.Type == "message_start" {
// message_start, 获取usage
responseId = claudeResponse.Message.Id
modelName = claudeResponse.Message.Model
info.UpstreamModelName = claudeResponse.Message.Model
usage.PromptTokens = claudeUsage.InputTokens
} else if claudeResponse.Type == "content_block_delta" {
responseText += claudeResponse.Delta.Text
@ -316,7 +328,7 @@ func claudeStreamHandler(requestMode int, modelName string, promptTokens int, c
//response.Id = responseId
response.Id = responseId
response.Created = createdTime
response.Model = modelName
response.Model = info.UpstreamModelName
jsonStr, err := json.Marshal(response)
if err != nil {
@ -335,13 +347,13 @@ func claudeStreamHandler(requestMode int, modelName string, promptTokens int, c
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
if requestMode == RequestModeCompletion {
usage, _ = service.ResponseText2Usage(responseText, modelName, promptTokens)
usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
} else {
if usage.PromptTokens == 0 {
usage.PromptTokens = promptTokens
usage.PromptTokens = info.PromptTokens
}
if usage.CompletionTokens == 0 {
usage, _ = service.ResponseText2Usage(responseText, modelName, usage.PromptTokens)
usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, usage.PromptTokens)
}
}
return nil, usage

View File

@ -36,7 +36,7 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
if info.IsStream {
err, usage = cohereStreamHandler(c, resp, info.UpstreamModelName, info.PromptTokens)
err, usage = cohereStreamHandler(c, resp, info)
} else {
err, usage = cohereHandler(c, resp, info.UpstreamModelName, info.PromptTokens)
}

View File

@ -9,8 +9,10 @@ import (
"net/http"
"one-api/common"
"one-api/dto"
relaycommon "one-api/relay/common"
"one-api/service"
"strings"
"time"
)
func requestOpenAI2Cohere(textRequest dto.GeneralOpenAIRequest) *CohereRequest {
@ -56,7 +58,7 @@ func stopReasonCohere2OpenAI(reason string) string {
}
}
func cohereStreamHandler(c *gin.Context, resp *http.Response, modelName string, promptTokens int) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
func cohereStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
responseId := fmt.Sprintf("chatcmpl-%s", common.GetUUID())
createdTime := common.GetTimestamp()
usage := &dto.Usage{}
@ -84,9 +86,14 @@ func cohereStreamHandler(c *gin.Context, resp *http.Response, modelName string,
stopChan <- true
}()
service.SetEventStreamHeaders(c)
isFirst := true
c.Stream(func(w io.Writer) bool {
select {
case data := <-dataChan:
if isFirst {
isFirst = false
info.FirstResponseTime = time.Now()
}
data = strings.TrimSuffix(data, "\r")
var cohereResp CohereResponse
err := json.Unmarshal([]byte(data), &cohereResp)
@ -98,7 +105,7 @@ func cohereStreamHandler(c *gin.Context, resp *http.Response, modelName string,
openaiResp.Id = responseId
openaiResp.Created = createdTime
openaiResp.Object = "chat.completion.chunk"
openaiResp.Model = modelName
openaiResp.Model = info.UpstreamModelName
if cohereResp.IsFinished {
finishReason := stopReasonCohere2OpenAI(cohereResp.FinishReason)
openaiResp.Choices = []dto.ChatCompletionsStreamResponseChoice{
@ -137,7 +144,7 @@ func cohereStreamHandler(c *gin.Context, resp *http.Response, modelName string,
}
})
if usage.PromptTokens == 0 {
usage, _ = service.ResponseText2Usage(responseText, modelName, promptTokens)
usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
}
return nil, usage
}

View File

@ -63,7 +63,7 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
if info.IsStream {
var responseText string
err, responseText = geminiChatStreamHandler(c, resp)
err, responseText = geminiChatStreamHandler(c, resp, info)
usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
} else {
err, usage = geminiChatHandler(c, resp, info.PromptTokens, info.UpstreamModelName)

View File

@ -7,10 +7,12 @@ import (
"io"
"net/http"
"one-api/common"
"one-api/constant"
"one-api/dto"
relaycommon "one-api/relay/common"
"one-api/service"
"strings"
"time"
"github.com/gin-gonic/gin"
)
@ -160,10 +162,10 @@ func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) *dto.Ch
return &response
}
func geminiChatStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, string) {
func geminiChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, string) {
responseText := ""
dataChan := make(chan string)
stopChan := make(chan bool)
dataChan := make(chan string, 5)
stopChan := make(chan bool, 2)
scanner := bufio.NewScanner(resp.Body)
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
if atEOF && len(data) == 0 {
@ -186,14 +188,23 @@ func geminiChatStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIEr
}
data = strings.TrimPrefix(data, "\"text\": \"")
data = strings.TrimSuffix(data, "\"")
dataChan <- data
if !common.SafeSendStringTimeout(dataChan, data, constant.StreamingTimeout) {
// send data timeout, stop the stream
common.LogError(c, "send data timeout, stop the stream")
break
}
}
stopChan <- true
}()
isFirst := true
service.SetEventStreamHeaders(c)
c.Stream(func(w io.Writer) bool {
select {
case data := <-dataChan:
if isFirst {
isFirst = false
info.FirstResponseTime = time.Now()
}
// this is used to prevent annoying \ related format bug
data = fmt.Sprintf("{\"content\": \"%s\"}", data)
type dummyStruct struct {

View File

@ -52,7 +52,7 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
if info.IsStream {
var responseText string
err, responseText, _ = openai.OpenaiStreamHandler(c, resp, info.RelayMode)
err, responseText, _ = openai.OpenaiStreamHandler(c, resp, info)
usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
} else {
if info.RelayMode == relayconstant.RelayModeEmbeddings {

View File

@ -82,7 +82,7 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
if info.IsStream {
var responseText string
var toolCount int
err, responseText, toolCount = OpenaiStreamHandler(c, resp, info.RelayMode)
err, responseText, toolCount = OpenaiStreamHandler(c, resp, info)
usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
usage.CompletionTokens += toolCount * 7
} else {

View File

@ -8,7 +8,9 @@ import (
"io"
"net/http"
"one-api/common"
"one-api/constant"
"one-api/dto"
relaycommon "one-api/relay/common"
relayconstant "one-api/relay/constant"
"one-api/service"
"strings"
@ -16,7 +18,7 @@ import (
"time"
)
func OpenaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*dto.OpenAIErrorWithStatusCode, string, int) {
func OpenaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, string, int) {
//checkSensitive := constant.ShouldCheckCompletionSensitive()
var responseTextBuilder strings.Builder
toolCount := 0
@ -50,14 +52,18 @@ func OpenaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*d
if data[:6] != "data: " && data[:6] != "[DONE]" {
continue
}
common.SafeSendString(dataChan, data)
if !common.SafeSendStringTimeout(dataChan, data, constant.StreamingTimeout) {
// send data timeout, stop the stream
common.LogError(c, "send data timeout, stop the stream")
break
}
data = data[6:]
if !strings.HasPrefix(data, "[DONE]") {
streamItems = append(streamItems, data)
}
}
streamResp := "[" + strings.Join(streamItems, ",") + "]"
switch relayMode {
switch info.RelayMode {
case relayconstant.RelayModeChatCompletions:
var streamResponses []dto.ChatCompletionsStreamResponseSimple
err := json.Unmarshal(common.StringToByteSlice(streamResp), &streamResponses)
@ -126,9 +132,14 @@ func OpenaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*d
common.SafeSendBool(stopChan, true)
}()
service.SetEventStreamHeaders(c)
isFirst := true
c.Stream(func(w io.Writer) bool {
select {
case data := <-dataChan:
if isFirst {
isFirst = false
info.FirstResponseTime = time.Now()
}
if strings.HasPrefix(data, "data: [DONE]") {
data = data[:12]
}
@ -187,7 +198,7 @@ func OpenaiHandler(c *gin.Context, resp *http.Response, promptTokens int, model
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
if simpleResponse.Usage.TotalTokens == 0 {
if simpleResponse.Usage.TotalTokens == 0 || (simpleResponse.Usage.PromptTokens == 0 && simpleResponse.Usage.CompletionTokens == 0) {
completionTokens := 0
for _, choice := range simpleResponse.Choices {
ctkm, _ := service.CountTokenText(string(choice.Message.Content), model)

View File

@ -46,7 +46,7 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
if info.IsStream {
var responseText string
err, responseText, _ = openai.OpenaiStreamHandler(c, resp, info.RelayMode)
err, responseText, _ = openai.OpenaiStreamHandler(c, resp, info)
usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
} else {
err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)

View File

@ -6,6 +6,7 @@ var ModelList = []string{
"SparkDesk-v2.1",
"SparkDesk-v3.1",
"SparkDesk-v3.5",
"SparkDesk-v4.0",
}
var ChannelName = "xunfei"

View File

@ -252,6 +252,8 @@ func apiVersion2domain(apiVersion string) string {
return "generalv3"
case "v3.5":
return "generalv3.5"
case "v4.0":
return "4.0Ultra"
}
return "general" + apiVersion
}

View File

@ -48,7 +48,7 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
if info.IsStream {
var responseText string
var toolCount int
err, responseText, toolCount = openai.OpenaiStreamHandler(c, resp, info.RelayMode)
err, responseText, toolCount = openai.OpenaiStreamHandler(c, resp, info)
usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
usage.CompletionTokens += toolCount * 7
} else {

View File

@ -16,6 +16,7 @@ type RelayInfo struct {
Group string
TokenUnlimited bool
StartTime time.Time
FirstResponseTime time.Time
ApiType int
IsStream bool
RelayMode int
@ -37,6 +38,7 @@ func GenRelayInfo(c *gin.Context) *RelayInfo {
group := c.GetString("group")
tokenUnlimited := c.GetBool("token_unlimited_quota")
startTime := time.Now()
// firstResponseTime = time.Now() - 1 second
apiType, _ := constant.ChannelType2APIType(channelType)
@ -51,6 +53,7 @@ func GenRelayInfo(c *gin.Context) *RelayInfo {
Group: group,
TokenUnlimited: tokenUnlimited,
StartTime: startTime,
FirstResponseTime: startTime.Add(-time.Second),
ApiType: apiType,
ApiVersion: c.GetString("api_version"),
ApiKey: strings.TrimPrefix(c.Request.Header.Get("Authorization"), "Bearer "),

View File

@ -538,7 +538,16 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
ChannelId: c.GetInt("channel_id"),
Quota: quota,
}
if midjResponse.Code == 3 {
//无实例账号自动禁用渠道No available account instance
channel, err := model.GetChannelById(midjourneyTask.ChannelId, true)
if err != nil {
common.SysError("get_channel_null: " + err.Error())
}
if channel.AutoBan != nil && *channel.AutoBan == 1 {
model.UpdateChannelStatusById(midjourneyTask.ChannelId, 2, "No available account instance")
}
}
if midjResponse.Code != 1 && midjResponse.Code != 21 && midjResponse.Code != 22 {
//非1-提交成功,21-任务已存在和22-排队中,则记录错误原因
midjourneyTask.FailReason = midjResponse.Description

View File

@ -344,14 +344,7 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, textRe
logModel = "g-*"
logContent += fmt.Sprintf(",模型 %s", textRequest.Model)
}
other := make(map[string]interface{})
other["model_ratio"] = modelRatio
other["group_ratio"] = groupRatio
other["completion_ratio"] = completionRatio
other["model_price"] = modelPrice
adminInfo := make(map[string]interface{})
adminInfo["use_channel"] = ctx.GetStringSlice("use_channel")
other["admin_info"] = adminInfo
other := service.GenerateTextOtherInfo(ctx, relayInfo, modelRatio, groupRatio, completionRatio, modelPrice)
model.RecordConsumeLog(ctx, relayInfo.UserId, relayInfo.ChannelId, promptTokens, completionTokens, logModel, tokenName, quota, logContent, relayInfo.TokenId, userQuota, int(useTimeSeconds), relayInfo.IsStream, other)
//if quota != 0 {

19
service/log.go Normal file
View File

@ -0,0 +1,19 @@
package service
import (
"github.com/gin-gonic/gin"
relaycommon "one-api/relay/common"
)
func GenerateTextOtherInfo(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, modelRatio, groupRatio, completionRatio, modelPrice float64) map[string]interface{} {
other := make(map[string]interface{})
other["model_ratio"] = modelRatio
other["group_ratio"] = groupRatio
other["completion_ratio"] = completionRatio
other["model_price"] = modelPrice
other["frt"] = float64(relayInfo.FirstResponseTime.UnixMilli() - relayInfo.StartTime.UnixMilli())
adminInfo := make(map[string]interface{})
adminInfo["use_channel"] = ctx.GetStringSlice("use_channel")
other["admin_info"] = adminInfo
return other
}

View File

@ -3,7 +3,6 @@ package service
import (
"errors"
"fmt"
"one-api/common"
"one-api/constant"
"one-api/dto"
"strings"
@ -62,7 +61,7 @@ func SensitiveWordContains(text string) (bool, []string) {
}
checkText := strings.ToLower(text)
// 构建一个AC自动机
m := common.InitAc()
m := InitAc()
hits := m.MultiPatternSearch([]rune(checkText), false)
if len(hits) > 0 {
words := make([]string, 0)
@ -80,7 +79,7 @@ func SensitiveWordReplace(text string, returnImmediately bool) (bool, []string,
return false, nil, text
}
checkText := strings.ToLower(text)
m := common.InitAc()
m := InitAc()
hits := m.MultiPatternSearch([]rune(checkText), returnImmediately)
if len(hits) > 0 {
words := make([]string, 0)

View File

@ -1,4 +1,4 @@
package common
package service
import (
"bytes"

File diff suppressed because it is too large Load Diff

View File

@ -29,6 +29,7 @@ import {
stringToColor,
} from '../helpers/render';
import Paragraph from '@douyinfe/semi-ui/lib/es/typography/paragraph';
import { getLogOther } from '../helpers/other.js';
const { Header } = Layout;
@ -141,6 +142,33 @@ function renderUseTime(type) {
}
}
function renderFirstUseTime(type) {
let time = parseFloat(type) / 1000.0;
time = time.toFixed(1);
if (time < 3) {
return (
<Tag color='green' size='large'>
{' '}
{time} s{' '}
</Tag>
);
} else if (time < 10) {
return (
<Tag color='orange' size='large'>
{' '}
{time} s{' '}
</Tag>
);
} else {
return (
<Tag color='red' size='large'>
{' '}
{time} s{' '}
</Tag>
);
}
}
const LogsTable = () => {
const columns = [
{
@ -247,9 +275,21 @@ const LogsTable = () => {
},
},
{
title: '用时',
title: '用时/首字',
dataIndex: 'use_time',
render: (text, record, index) => {
if (record.is_stream) {
let other = getLogOther(record.other);
return (
<div>
<Space>
{renderUseTime(text)}
{renderFirstUseTime(other.frt)}
{renderIsStream(record.is_stream)}
</Space>
</div>
);
} else {
return (
<div>
<Space>
@ -258,6 +298,7 @@ const LogsTable = () => {
</Space>
</div>
);
}
},
},
{
@ -325,10 +366,7 @@ const LogsTable = () => {
title: '详情',
dataIndex: 'content',
render: (text, record, index) => {
if (record.other === '') {
record.other = '{}';
}
let other = JSON.parse(record.other);
let other = getLogOther(record.other);
if (other == null) {
return (
<Paragraph

7
web/src/helpers/other.js Normal file
View File

@ -0,0 +1,7 @@
export function getLogOther(otherStr) {
if (otherStr === undefined || otherStr === '') {
otherStr = '{}';
}
let other = JSON.parse(otherStr);
return other;
}

View File

@ -144,28 +144,42 @@ export function renderModelPrice(
) {
// 1 ratio = $0.002 / 1K tokens
if (modelPrice !== -1) {
return '模型价格:$' + modelPrice * groupRatio;
return (
'模型价格:$' +
modelPrice +
' * 分组倍率:' +
groupRatio +
' = $' +
modelPrice * groupRatio
);
} else {
if (completionRatio === undefined) {
completionRatio = 0;
}
// 这里的 *2 是因为 1倍率=0.002刀,请勿删除
let inputRatioPrice = modelRatio * 2.0 * groupRatio;
let completionRatioPrice = modelRatio * 2.0 * completionRatio * groupRatio;
let inputRatioPrice = modelRatio * 2.0;
let completionRatioPrice = modelRatio * 2.0 * completionRatio;
let price =
(inputTokens / 1000000) * inputRatioPrice +
(completionTokens / 1000000) * completionRatioPrice;
return (
<>
<article>
<p>提示 ${inputRatioPrice} / 1M tokens</p>
<p>补全 ${completionRatioPrice} / 1M tokens</p>
<p>
提示${inputRatioPrice} * {groupRatio} = $
{inputRatioPrice * groupRatio} / 1M tokens
</p>
<p>
补全${completionRatioPrice} * {groupRatio} = $
{completionRatioPrice * groupRatio} / 1M tokens
</p>
<p></p>
<p>
提示 {inputTokens} tokens / 1M tokens * ${inputRatioPrice} + 补全{' '}
{completionTokens} tokens / 1M tokens * ${completionRatioPrice} = $
{price.toFixed(6)}
{completionTokens} tokens / 1M tokens * ${completionRatioPrice} *
分组 {groupRatio} = ${price.toFixed(6)}
</p>
<p>仅供参考以实际扣费为准</p>
</article>
</>
);