merge upstream

Signed-off-by: wozulong <>
This commit is contained in:
wozulong 2024-07-01 15:14:22 +08:00
commit 895ee09b33
41 changed files with 2140 additions and 1671 deletions

View File

@ -5,8 +5,6 @@
> 本项目为开源项目,在[One API](https://github.com/songquanpeng/one-api)的基础上进行二次开发,感谢原作者的无私奉献。 > 本项目为开源项目,在[One API](https://github.com/songquanpeng/one-api)的基础上进行二次开发,感谢原作者的无私奉献。
> 使用者必须在遵循 OpenAI 的[使用条款](https://openai.com/policies/terms-of-use)以及**法律法规**的情况下使用,不得用于非法用途。 > 使用者必须在遵循 OpenAI 的[使用条款](https://openai.com/policies/terms-of-use)以及**法律法规**的情况下使用,不得用于非法用途。
> [!WARNING]
> 本项目为个人学习使用,不保证稳定性,且不提供任何技术支持,使用者必须在遵循 OpenAI 的使用条款以及法律法规的情况下使用,不得用于非法用途。 > 本项目为个人学习使用,不保证稳定性,且不提供任何技术支持,使用者必须在遵循 OpenAI 的使用条款以及法律法规的情况下使用,不得用于非法用途。
> 根据[《生成式人工智能服务管理暂行办法》](http://www.cac.gov.cn/2023-07/13/c_1690898327029107.htm)的要求,请勿对中国地区公众提供一切未经备案的生成式人工智能服务。 > 根据[《生成式人工智能服务管理暂行办法》](http://www.cac.gov.cn/2023-07/13/c_1690898327029107.htm)的要求,请勿对中国地区公众提供一切未经备案的生成式人工智能服务。
@ -85,8 +83,13 @@
``` ```
可以实现400错误转为500错误从而重试 可以实现400错误转为500错误从而重试
## 比原版One API多出的配置
- `STREAMING_TIMEOUT`:设置流式一次回复的超时时间,默认为 30 秒
## 部署 ## 部署
### 部署要求
- 本地数据库默认SQLiteDocker 部署默认使用 SQLite必须挂载 `/data` 目录到宿主机)
- 远程数据库MySQL 版本 >= 5.7.8PgSQL 版本 >= 9.6
### 基于 Docker 进行部署 ### 基于 Docker 进行部署
```shell ```shell
# 使用 SQLite 的部署命令: # 使用 SQLite 的部署命令:

View File

@ -120,14 +120,14 @@ var IsMasterNode = os.Getenv("NODE_TYPE") != "slave"
var requestInterval, _ = strconv.Atoi(os.Getenv("POLLING_INTERVAL")) var requestInterval, _ = strconv.Atoi(os.Getenv("POLLING_INTERVAL"))
var RequestInterval = time.Duration(requestInterval) * time.Second var RequestInterval = time.Duration(requestInterval) * time.Second
var SyncFrequency = GetOrDefault("SYNC_FREQUENCY", 60) // unit is second var SyncFrequency = GetEnvOrDefault("SYNC_FREQUENCY", 60) // unit is second
var BatchUpdateEnabled = false var BatchUpdateEnabled = false
var BatchUpdateInterval = GetOrDefault("BATCH_UPDATE_INTERVAL", 5) var BatchUpdateInterval = GetEnvOrDefault("BATCH_UPDATE_INTERVAL", 5)
var RelayTimeout = GetOrDefault("RELAY_TIMEOUT", 0) // unit is second var RelayTimeout = GetEnvOrDefault("RELAY_TIMEOUT", 0) // unit is second
var GeminiSafetySetting = GetOrDefaultString("GEMINI_SAFETY_SETTING", "BLOCK_NONE") var GeminiSafetySetting = GetEnvOrDefaultString("GEMINI_SAFETY_SETTING", "BLOCK_NONE")
const ( const (
RequestIdKey = "X-Oneapi-Request-Id" RequestIdKey = "X-Oneapi-Request-Id"
@ -150,10 +150,10 @@ var (
// All duration's unit is seconds // All duration's unit is seconds
// Shouldn't larger then RateLimitKeyExpirationDuration // Shouldn't larger then RateLimitKeyExpirationDuration
var ( var (
GlobalApiRateLimitNum = GetOrDefault("GLOBAL_API_RATE_LIMIT", 180) GlobalApiRateLimitNum = GetEnvOrDefault("GLOBAL_API_RATE_LIMIT", 180)
GlobalApiRateLimitDuration int64 = 3 * 60 GlobalApiRateLimitDuration int64 = 3 * 60
GlobalWebRateLimitNum = GetOrDefault("GLOBAL_WEB_RATE_LIMIT", 60) GlobalWebRateLimitNum = GetEnvOrDefault("GLOBAL_WEB_RATE_LIMIT", 60)
GlobalWebRateLimitDuration int64 = 3 * 60 GlobalWebRateLimitDuration int64 = 3 * 60
UploadRateLimitNum = 10 UploadRateLimitNum = 10

26
common/env.go Normal file
View File

@ -0,0 +1,26 @@
package common
import (
"fmt"
"os"
"strconv"
)
func GetEnvOrDefault(env string, defaultValue int) int {
if env == "" || os.Getenv(env) == "" {
return defaultValue
}
num, err := strconv.Atoi(os.Getenv(env))
if err != nil {
SysError(fmt.Sprintf("failed to parse %s: %s, using default value: %d", env, err.Error(), defaultValue))
return defaultValue
}
return num
}
func GetEnvOrDefaultString(env string, defaultValue string) string {
if env == "" || os.Getenv(env) == "" {
return defaultValue
}
return os.Getenv(env)
}

View File

@ -3,6 +3,7 @@ package common
import ( import (
"fmt" "fmt"
"runtime/debug" "runtime/debug"
"time"
) )
func SafeGoroutine(f func()) { func SafeGoroutine(f func()) {
@ -45,3 +46,21 @@ func SafeSendString(ch chan string, value string) (closed bool) {
// If the code reaches here, then the channel was not closed. // If the code reaches here, then the channel was not closed.
return false return false
} }
// SafeSendStringTimeout send, return true, else return false
func SafeSendStringTimeout(ch chan string, value string, timeout int) (closed bool) {
defer func() {
// Recover from panic if one occured. A panic would mean the channel was closed.
if recover() != nil {
closed = false
}
}()
// This will panic if the channel is closed.
select {
case ch <- value:
return true
case <-time.After(time.Duration(timeout) * time.Second):
return false
}
}

View File

@ -1,6 +1,8 @@
package common package common
import "encoding/json" import (
"encoding/json"
)
var GroupRatio = map[string]float64{ var GroupRatio = map[string]float64{
"default": 1, "default": 1,

View File

@ -114,6 +114,7 @@ var defaultModelRatio = map[string]float64{
"SparkDesk-v2.1": 1.2858, // ¥0.018 / 1k tokens "SparkDesk-v2.1": 1.2858, // ¥0.018 / 1k tokens
"SparkDesk-v3.1": 1.2858, // ¥0.018 / 1k tokens "SparkDesk-v3.1": 1.2858, // ¥0.018 / 1k tokens
"SparkDesk-v3.5": 1.2858, // ¥0.018 / 1k tokens "SparkDesk-v3.5": 1.2858, // ¥0.018 / 1k tokens
"SparkDesk-v4.0": 1.2858,
"360GPT_S2_V9": 0.8572, // ¥0.012 / 1k tokens "360GPT_S2_V9": 0.8572, // ¥0.012 / 1k tokens
"360gpt-turbo": 0.0858, // ¥0.0012 / 1k tokens "360gpt-turbo": 0.0858, // ¥0.0012 / 1k tokens
"360gpt-turbo-responsibility-8k": 0.8572, // ¥0.012 / 1k tokens "360gpt-turbo-responsibility-8k": 0.8572, // ¥0.012 / 1k tokens

View File

@ -1,6 +1,8 @@
package common package common
import "encoding/json" import (
"encoding/json"
)
var TopupGroupRatio = map[string]float64{ var TopupGroupRatio = map[string]float64{
"default": 1, "default": 1,

View File

@ -13,7 +13,6 @@ import (
"net" "net"
"net/http" "net/http"
"net/url" "net/url"
"os"
"os/exec" "os/exec"
"runtime" "runtime"
"strconv" "strconv"
@ -196,25 +195,6 @@ func Max(a int, b int) int {
} }
} }
func GetOrDefault(env string, defaultValue int) int {
if env == "" || os.Getenv(env) == "" {
return defaultValue
}
num, err := strconv.Atoi(os.Getenv(env))
if err != nil {
SysError(fmt.Sprintf("failed to parse %s: %s, using default value: %d", env, err.Error(), defaultValue))
return defaultValue
}
return num
}
func GetOrDefaultString(env string, defaultValue string) string {
if env == "" || os.Getenv(env) == "" {
return defaultValue
}
return os.Getenv(env)
}
func MessageWithRequestId(message string, id string) string { func MessageWithRequestId(message string, id string) string {
return fmt.Sprintf("%s (request id: %s)", message, id) return fmt.Sprintf("%s (request id: %s)", message, id)
} }

7
constant/env.go Normal file
View File

@ -0,0 +1,7 @@
package constant
import (
"one-api/common"
)
var StreamingTimeout = common.GetEnvOrDefault("STREAMING_TIMEOUT", 30)

View File

@ -222,16 +222,18 @@ func testAllChannels(notify bool) error {
if channel.AutoBan != nil && *channel.AutoBan == 0 { if channel.AutoBan != nil && *channel.AutoBan == 0 {
ban = false ban = false
} }
openAiErrWithStatus := dto.OpenAIErrorWithStatusCode{ if openaiErr != nil {
StatusCode: -1, openAiErrWithStatus := dto.OpenAIErrorWithStatusCode{
Error: *openaiErr, StatusCode: -1,
LocalError: false, Error: *openaiErr,
} LocalError: false,
if isChannelEnabled && service.ShouldDisableChannel(&openAiErrWithStatus) && ban { }
service.DisableChannel(channel.Id, channel.Name, err.Error()) if isChannelEnabled && service.ShouldDisableChannel(&openAiErrWithStatus) && ban {
} service.DisableChannel(channel.Id, channel.Name, err.Error())
if !isChannelEnabled && service.ShouldEnableChannel(err, openaiErr, channel.Status) { }
service.EnableChannel(channel.Id, channel.Name) if !isChannelEnabled && service.ShouldEnableChannel(err, openaiErr, channel.Status) {
service.EnableChannel(channel.Id, channel.Name)
}
} }
channel.UpdateResponseTime(milliseconds) channel.UpdateResponseTime(milliseconds)
time.Sleep(common.RequestInterval) time.Sleep(common.RequestInterval)

View File

@ -19,9 +19,6 @@ import (
) )
func UpdateMidjourneyTaskBulk() { func UpdateMidjourneyTaskBulk() {
if !common.IsMasterNode {
return
}
//imageModel := "midjourney" //imageModel := "midjourney"
ctx := context.TODO() ctx := context.TODO()
for { for {

View File

@ -24,14 +24,3 @@ type OpenAIModels struct {
Root string `json:"root"` Root string `json:"root"`
Parent *string `json:"parent"` Parent *string `json:"parent"`
} }
type ModelPricing struct {
Available bool `json:"available"`
ModelName string `json:"model_name"`
QuotaType int `json:"quota_type"`
ModelRatio float64 `json:"model_ratio"`
ModelPrice float64 `json:"model_price"`
OwnerBy string `json:"owner_by"`
CompletionRatio float64 `json:"completion_ratio"`
EnableGroup []string `json:"enable_group,omitempty"`
}

14
main.go
View File

@ -89,12 +89,14 @@ func main() {
} }
go controller.AutomaticallyTestChannels(frequency) go controller.AutomaticallyTestChannels(frequency)
} }
common.SafeGoroutine(func() { if common.IsMasterNode {
controller.UpdateMidjourneyTaskBulk() common.SafeGoroutine(func() {
}) controller.UpdateMidjourneyTaskBulk()
common.SafeGoroutine(func() { })
controller.UpdateTaskBulk() common.SafeGoroutine(func() {
}) controller.UpdateTaskBulk()
})
}
if os.Getenv("BATCH_UPDATE_ENABLED") == "true" { if os.Getenv("BATCH_UPDATE_ENABLED") == "true" {
common.BatchUpdateEnabled = true common.BatchUpdateEnabled = true
common.SysLog("batch update enabled with interval " + strconv.Itoa(common.BatchUpdateInterval) + "s") common.SysLog("batch update enabled with interval " + strconv.Itoa(common.BatchUpdateInterval) + "s")

View File

@ -56,6 +56,11 @@ func getPriority(group string, model string, retry int) (int, error) {
return 0, err return 0, err
} }
if len(priorities) == 0 {
// 如果没有查询到优先级,则返回错误
return 0, errors.New("数据库一致性被破坏")
}
// 确定要使用的优先级 // 确定要使用的优先级
var priorityToUse int var priorityToUse int
if retry >= len(priorities) { if retry >= len(priorities) {
@ -199,7 +204,7 @@ func FixAbility() (int, error) {
// Use channelIds to find channel not in abilities table // Use channelIds to find channel not in abilities table
var abilityChannelIds []int var abilityChannelIds []int
err = DB.Model(&Ability{}).Pluck("channel_id", &abilityChannelIds).Error err = DB.Table("abilities").Distinct("channel_id").Pluck("channel_id", &abilityChannelIds).Error
if err != nil { if err != nil {
common.SysError(fmt.Sprintf("Get channel ids from abilities table failed: %s", err.Error())) common.SysError(fmt.Sprintf("Get channel ids from abilities table failed: %s", err.Error()))
return 0, err return 0, err

View File

@ -86,9 +86,9 @@ func InitDB() (err error) {
if err != nil { if err != nil {
return err return err
} }
sqlDB.SetMaxIdleConns(common.GetOrDefault("SQL_MAX_IDLE_CONNS", 100)) sqlDB.SetMaxIdleConns(common.GetEnvOrDefault("SQL_MAX_IDLE_CONNS", 100))
sqlDB.SetMaxOpenConns(common.GetOrDefault("SQL_MAX_OPEN_CONNS", 1000)) sqlDB.SetMaxOpenConns(common.GetEnvOrDefault("SQL_MAX_OPEN_CONNS", 1000))
sqlDB.SetConnMaxLifetime(time.Second * time.Duration(common.GetOrDefault("SQL_MAX_LIFETIME", 60))) sqlDB.SetConnMaxLifetime(time.Second * time.Duration(common.GetEnvOrDefault("SQL_MAX_LIFETIME", 60)))
if !common.IsMasterNode { if !common.IsMasterNode {
return nil return nil

View File

@ -2,18 +2,28 @@ package model
import ( import (
"one-api/common" "one-api/common"
"one-api/dto"
"sync" "sync"
"time" "time"
) )
type Pricing struct {
Available bool `json:"available"`
ModelName string `json:"model_name"`
QuotaType int `json:"quota_type"`
ModelRatio float64 `json:"model_ratio"`
ModelPrice float64 `json:"model_price"`
OwnerBy string `json:"owner_by"`
CompletionRatio float64 `json:"completion_ratio"`
EnableGroup []string `json:"enable_group,omitempty"`
}
var ( var (
pricingMap []dto.ModelPricing pricingMap []Pricing
lastGetPricingTime time.Time lastGetPricingTime time.Time
updatePricingLock sync.Mutex updatePricingLock sync.Mutex
) )
func GetPricing(group string) []dto.ModelPricing { func GetPricing(group string) []Pricing {
updatePricingLock.Lock() updatePricingLock.Lock()
defer updatePricingLock.Unlock() defer updatePricingLock.Unlock()
@ -21,7 +31,7 @@ func GetPricing(group string) []dto.ModelPricing {
updatePricing() updatePricing()
} }
if group != "" { if group != "" {
userPricingMap := make([]dto.ModelPricing, 0) userPricingMap := make([]Pricing, 0)
models := GetGroupModels(group) models := GetGroupModels(group)
for _, pricing := range pricingMap { for _, pricing := range pricingMap {
if !common.StringsContains(models, pricing.ModelName) { if !common.StringsContains(models, pricing.ModelName) {
@ -42,9 +52,9 @@ func updatePricing() {
allModels[model] = i allModels[model] = i
} }
pricingMap = make([]dto.ModelPricing, 0) pricingMap = make([]Pricing, 0)
for model, _ := range allModels { for model, _ := range allModels {
pricing := dto.ModelPricing{ pricing := Pricing{
Available: true, Available: true,
ModelName: model, ModelName: model,
} }

View File

@ -14,6 +14,7 @@ import (
"one-api/relay/channel/claude" "one-api/relay/channel/claude"
relaycommon "one-api/relay/common" relaycommon "one-api/relay/common"
"strings" "strings"
"time"
"github.com/aws/aws-sdk-go-v2/aws" "github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/credentials" "github.com/aws/aws-sdk-go-v2/credentials"
@ -156,6 +157,7 @@ func awsStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, requestMode i
var usage relaymodel.Usage var usage relaymodel.Usage
var id string var id string
var model string var model string
isFirst := true
createdTime := common.GetTimestamp() createdTime := common.GetTimestamp()
c.Stream(func(w io.Writer) bool { c.Stream(func(w io.Writer) bool {
event, ok := <-stream.Events() event, ok := <-stream.Events()
@ -166,6 +168,10 @@ func awsStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, requestMode i
switch v := event.(type) { switch v := event.(type) {
case *types.ResponseStreamMemberChunk: case *types.ResponseStreamMemberChunk:
if isFirst {
isFirst = false
info.FirstResponseTime = time.Now()
}
claudeResp := new(claude.ClaudeResponse) claudeResp := new(claude.ClaudeResponse)
err := json.NewDecoder(bytes.NewReader(v.Value.Bytes)).Decode(claudeResp) err := json.NewDecoder(bytes.NewReader(v.Value.Bytes)).Decode(claudeResp)
if err != nil { if err != nil {

View File

@ -65,7 +65,7 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) { func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
if info.IsStream { if info.IsStream {
err, usage = claudeStreamHandler(a.RequestMode, info.UpstreamModelName, info.PromptTokens, c, resp) err, usage = claudeStreamHandler(c, resp, info, a.RequestMode)
} else { } else {
err, usage = claudeHandler(a.RequestMode, c, resp, info.PromptTokens, info.UpstreamModelName) err, usage = claudeHandler(a.RequestMode, c, resp, info.PromptTokens, info.UpstreamModelName)
} }

View File

@ -8,6 +8,7 @@ var ModelList = []string{
"claude-3-sonnet-20240229", "claude-3-sonnet-20240229",
"claude-3-opus-20240229", "claude-3-opus-20240229",
"claude-3-haiku-20240307", "claude-3-haiku-20240307",
"claude-3-5-sonnet-20240620",
} }
var ChannelName = "claude" var ChannelName = "claude"

View File

@ -8,9 +8,12 @@ import (
"io" "io"
"net/http" "net/http"
"one-api/common" "one-api/common"
"one-api/constant"
"one-api/dto" "one-api/dto"
relaycommon "one-api/relay/common"
"one-api/service" "one-api/service"
"strings" "strings"
"time"
) )
func stopReasonClaude2OpenAI(reason string) string { func stopReasonClaude2OpenAI(reason string) string {
@ -246,7 +249,7 @@ func ResponseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) *dto.Ope
return &fullTextResponse return &fullTextResponse
} }
func claudeStreamHandler(requestMode int, modelName string, promptTokens int, c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { func claudeStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, requestMode int) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
responseId := fmt.Sprintf("chatcmpl-%s", common.GetUUID()) responseId := fmt.Sprintf("chatcmpl-%s", common.GetUUID())
var usage *dto.Usage var usage *dto.Usage
usage = &dto.Usage{} usage = &dto.Usage{}
@ -265,8 +268,8 @@ func claudeStreamHandler(requestMode int, modelName string, promptTokens int, c
} }
return 0, nil, nil return 0, nil, nil
}) })
dataChan := make(chan string) dataChan := make(chan string, 5)
stopChan := make(chan bool) stopChan := make(chan bool, 2)
go func() { go func() {
for scanner.Scan() { for scanner.Scan() {
data := scanner.Text() data := scanner.Text()
@ -274,14 +277,23 @@ func claudeStreamHandler(requestMode int, modelName string, promptTokens int, c
continue continue
} }
data = strings.TrimPrefix(data, "data: ") data = strings.TrimPrefix(data, "data: ")
dataChan <- data if !common.SafeSendStringTimeout(dataChan, data, constant.StreamingTimeout) {
// send data timeout, stop the stream
common.LogError(c, "send data timeout, stop the stream")
break
}
} }
stopChan <- true stopChan <- true
}() }()
isFirst := true
service.SetEventStreamHeaders(c) service.SetEventStreamHeaders(c)
c.Stream(func(w io.Writer) bool { c.Stream(func(w io.Writer) bool {
select { select {
case data := <-dataChan: case data := <-dataChan:
if isFirst {
isFirst = false
info.FirstResponseTime = time.Now()
}
// some implementations may add \r at the end of data // some implementations may add \r at the end of data
data = strings.TrimSuffix(data, "\r") data = strings.TrimSuffix(data, "\r")
var claudeResponse ClaudeResponse var claudeResponse ClaudeResponse
@ -302,7 +314,7 @@ func claudeStreamHandler(requestMode int, modelName string, promptTokens int, c
if claudeResponse.Type == "message_start" { if claudeResponse.Type == "message_start" {
// message_start, 获取usage // message_start, 获取usage
responseId = claudeResponse.Message.Id responseId = claudeResponse.Message.Id
modelName = claudeResponse.Message.Model info.UpstreamModelName = claudeResponse.Message.Model
usage.PromptTokens = claudeUsage.InputTokens usage.PromptTokens = claudeUsage.InputTokens
} else if claudeResponse.Type == "content_block_delta" { } else if claudeResponse.Type == "content_block_delta" {
responseText += claudeResponse.Delta.Text responseText += claudeResponse.Delta.Text
@ -316,7 +328,7 @@ func claudeStreamHandler(requestMode int, modelName string, promptTokens int, c
//response.Id = responseId //response.Id = responseId
response.Id = responseId response.Id = responseId
response.Created = createdTime response.Created = createdTime
response.Model = modelName response.Model = info.UpstreamModelName
jsonStr, err := json.Marshal(response) jsonStr, err := json.Marshal(response)
if err != nil { if err != nil {
@ -335,13 +347,13 @@ func claudeStreamHandler(requestMode int, modelName string, promptTokens int, c
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
} }
if requestMode == RequestModeCompletion { if requestMode == RequestModeCompletion {
usage, _ = service.ResponseText2Usage(responseText, modelName, promptTokens) usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
} else { } else {
if usage.PromptTokens == 0 { if usage.PromptTokens == 0 {
usage.PromptTokens = promptTokens usage.PromptTokens = info.PromptTokens
} }
if usage.CompletionTokens == 0 { if usage.CompletionTokens == 0 {
usage, _ = service.ResponseText2Usage(responseText, modelName, usage.PromptTokens) usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, usage.PromptTokens)
} }
} }
return nil, usage return nil, usage

View File

@ -36,7 +36,7 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) { func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
if info.IsStream { if info.IsStream {
err, usage = cohereStreamHandler(c, resp, info.UpstreamModelName, info.PromptTokens) err, usage = cohereStreamHandler(c, resp, info)
} else { } else {
err, usage = cohereHandler(c, resp, info.UpstreamModelName, info.PromptTokens) err, usage = cohereHandler(c, resp, info.UpstreamModelName, info.PromptTokens)
} }

View File

@ -9,8 +9,10 @@ import (
"net/http" "net/http"
"one-api/common" "one-api/common"
"one-api/dto" "one-api/dto"
relaycommon "one-api/relay/common"
"one-api/service" "one-api/service"
"strings" "strings"
"time"
) )
func requestOpenAI2Cohere(textRequest dto.GeneralOpenAIRequest) *CohereRequest { func requestOpenAI2Cohere(textRequest dto.GeneralOpenAIRequest) *CohereRequest {
@ -56,7 +58,7 @@ func stopReasonCohere2OpenAI(reason string) string {
} }
} }
func cohereStreamHandler(c *gin.Context, resp *http.Response, modelName string, promptTokens int) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { func cohereStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
responseId := fmt.Sprintf("chatcmpl-%s", common.GetUUID()) responseId := fmt.Sprintf("chatcmpl-%s", common.GetUUID())
createdTime := common.GetTimestamp() createdTime := common.GetTimestamp()
usage := &dto.Usage{} usage := &dto.Usage{}
@ -84,9 +86,14 @@ func cohereStreamHandler(c *gin.Context, resp *http.Response, modelName string,
stopChan <- true stopChan <- true
}() }()
service.SetEventStreamHeaders(c) service.SetEventStreamHeaders(c)
isFirst := true
c.Stream(func(w io.Writer) bool { c.Stream(func(w io.Writer) bool {
select { select {
case data := <-dataChan: case data := <-dataChan:
if isFirst {
isFirst = false
info.FirstResponseTime = time.Now()
}
data = strings.TrimSuffix(data, "\r") data = strings.TrimSuffix(data, "\r")
var cohereResp CohereResponse var cohereResp CohereResponse
err := json.Unmarshal([]byte(data), &cohereResp) err := json.Unmarshal([]byte(data), &cohereResp)
@ -98,7 +105,7 @@ func cohereStreamHandler(c *gin.Context, resp *http.Response, modelName string,
openaiResp.Id = responseId openaiResp.Id = responseId
openaiResp.Created = createdTime openaiResp.Created = createdTime
openaiResp.Object = "chat.completion.chunk" openaiResp.Object = "chat.completion.chunk"
openaiResp.Model = modelName openaiResp.Model = info.UpstreamModelName
if cohereResp.IsFinished { if cohereResp.IsFinished {
finishReason := stopReasonCohere2OpenAI(cohereResp.FinishReason) finishReason := stopReasonCohere2OpenAI(cohereResp.FinishReason)
openaiResp.Choices = []dto.ChatCompletionsStreamResponseChoice{ openaiResp.Choices = []dto.ChatCompletionsStreamResponseChoice{
@ -137,7 +144,7 @@ func cohereStreamHandler(c *gin.Context, resp *http.Response, modelName string,
} }
}) })
if usage.PromptTokens == 0 { if usage.PromptTokens == 0 {
usage, _ = service.ResponseText2Usage(responseText, modelName, promptTokens) usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
} }
return nil, usage return nil, usage
} }

View File

@ -20,27 +20,27 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIReq
// 定义一个映射,存储模型名称和对应的版本 // 定义一个映射,存储模型名称和对应的版本
var modelVersionMap = map[string]string{ var modelVersionMap = map[string]string{
"gemini-1.5-pro-latest": "v1beta", "gemini-1.5-pro-latest": "v1beta",
"gemini-1.5-flash-latest": "v1beta", "gemini-1.5-flash-latest": "v1beta",
"gemini-ultra": "v1beta", "gemini-ultra": "v1beta",
} }
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
// 从映射中获取模型名称对应的版本,如果找不到就使用 info.ApiVersion 或默认的版本 "v1" // 从映射中获取模型名称对应的版本,如果找不到就使用 info.ApiVersion 或默认的版本 "v1"
version, beta := modelVersionMap[info.UpstreamModelName] version, beta := modelVersionMap[info.UpstreamModelName]
if !beta { if !beta {
if info.ApiVersion != "" { if info.ApiVersion != "" {
version = info.ApiVersion version = info.ApiVersion
} else { } else {
version = "v1" version = "v1"
} }
} }
action := "generateContent" action := "generateContent"
if info.IsStream { if info.IsStream {
action = "streamGenerateContent" action = "streamGenerateContent"
} }
return fmt.Sprintf("%s/%s/models/%s:%s", info.BaseUrl, version, info.UpstreamModelName, action), nil return fmt.Sprintf("%s/%s/models/%s:%s", info.BaseUrl, version, info.UpstreamModelName, action), nil
} }
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error { func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
@ -63,7 +63,7 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) { func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
if info.IsStream { if info.IsStream {
var responseText string var responseText string
err, responseText = geminiChatStreamHandler(c, resp) err, responseText = geminiChatStreamHandler(c, resp, info)
usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens) usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
} else { } else {
err, usage = geminiChatHandler(c, resp, info.PromptTokens, info.UpstreamModelName) err, usage = geminiChatHandler(c, resp, info.PromptTokens, info.UpstreamModelName)

View File

@ -7,10 +7,12 @@ import (
"io" "io"
"net/http" "net/http"
"one-api/common" "one-api/common"
"one-api/constant"
"one-api/dto" "one-api/dto"
relaycommon "one-api/relay/common" relaycommon "one-api/relay/common"
"one-api/service" "one-api/service"
"strings" "strings"
"time"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
) )
@ -160,10 +162,10 @@ func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) *dto.Ch
return &response return &response
} }
func geminiChatStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, string) { func geminiChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, string) {
responseText := "" responseText := ""
dataChan := make(chan string) dataChan := make(chan string, 5)
stopChan := make(chan bool) stopChan := make(chan bool, 2)
scanner := bufio.NewScanner(resp.Body) scanner := bufio.NewScanner(resp.Body)
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
if atEOF && len(data) == 0 { if atEOF && len(data) == 0 {
@ -186,14 +188,23 @@ func geminiChatStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIEr
} }
data = strings.TrimPrefix(data, "\"text\": \"") data = strings.TrimPrefix(data, "\"text\": \"")
data = strings.TrimSuffix(data, "\"") data = strings.TrimSuffix(data, "\"")
dataChan <- data if !common.SafeSendStringTimeout(dataChan, data, constant.StreamingTimeout) {
// send data timeout, stop the stream
common.LogError(c, "send data timeout, stop the stream")
break
}
} }
stopChan <- true stopChan <- true
}() }()
isFirst := true
service.SetEventStreamHeaders(c) service.SetEventStreamHeaders(c)
c.Stream(func(w io.Writer) bool { c.Stream(func(w io.Writer) bool {
select { select {
case data := <-dataChan: case data := <-dataChan:
if isFirst {
isFirst = false
info.FirstResponseTime = time.Now()
}
// this is used to prevent annoying \ related format bug // this is used to prevent annoying \ related format bug
data = fmt.Sprintf("{\"content\": \"%s\"}", data) data = fmt.Sprintf("{\"content\": \"%s\"}", data)
type dummyStruct struct { type dummyStruct struct {

View File

@ -52,7 +52,7 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) { func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
if info.IsStream { if info.IsStream {
var responseText string var responseText string
err, responseText, _ = openai.OpenaiStreamHandler(c, resp, info.RelayMode) err, responseText, _ = openai.OpenaiStreamHandler(c, resp, info)
usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens) usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
} else { } else {
if info.RelayMode == relayconstant.RelayModeEmbeddings { if info.RelayMode == relayconstant.RelayModeEmbeddings {

View File

@ -82,7 +82,7 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
if info.IsStream { if info.IsStream {
var responseText string var responseText string
var toolCount int var toolCount int
err, responseText, toolCount = OpenaiStreamHandler(c, resp, info.RelayMode) err, responseText, toolCount = OpenaiStreamHandler(c, resp, info)
usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens) usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
usage.CompletionTokens += toolCount * 7 usage.CompletionTokens += toolCount * 7
} else { } else {

View File

@ -8,7 +8,9 @@ import (
"io" "io"
"net/http" "net/http"
"one-api/common" "one-api/common"
"one-api/constant"
"one-api/dto" "one-api/dto"
relaycommon "one-api/relay/common"
relayconstant "one-api/relay/constant" relayconstant "one-api/relay/constant"
"one-api/service" "one-api/service"
"strings" "strings"
@ -16,7 +18,7 @@ import (
"time" "time"
) )
func OpenaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*dto.OpenAIErrorWithStatusCode, string, int) { func OpenaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, string, int) {
//checkSensitive := constant.ShouldCheckCompletionSensitive() //checkSensitive := constant.ShouldCheckCompletionSensitive()
var responseTextBuilder strings.Builder var responseTextBuilder strings.Builder
toolCount := 0 toolCount := 0
@ -50,14 +52,18 @@ func OpenaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*d
if data[:6] != "data: " && data[:6] != "[DONE]" { if data[:6] != "data: " && data[:6] != "[DONE]" {
continue continue
} }
common.SafeSendString(dataChan, data) if !common.SafeSendStringTimeout(dataChan, data, constant.StreamingTimeout) {
// send data timeout, stop the stream
common.LogError(c, "send data timeout, stop the stream")
break
}
data = data[6:] data = data[6:]
if !strings.HasPrefix(data, "[DONE]") { if !strings.HasPrefix(data, "[DONE]") {
streamItems = append(streamItems, data) streamItems = append(streamItems, data)
} }
} }
streamResp := "[" + strings.Join(streamItems, ",") + "]" streamResp := "[" + strings.Join(streamItems, ",") + "]"
switch relayMode { switch info.RelayMode {
case relayconstant.RelayModeChatCompletions: case relayconstant.RelayModeChatCompletions:
var streamResponses []dto.ChatCompletionsStreamResponseSimple var streamResponses []dto.ChatCompletionsStreamResponseSimple
err := json.Unmarshal(common.StringToByteSlice(streamResp), &streamResponses) err := json.Unmarshal(common.StringToByteSlice(streamResp), &streamResponses)
@ -126,9 +132,14 @@ func OpenaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*d
common.SafeSendBool(stopChan, true) common.SafeSendBool(stopChan, true)
}() }()
service.SetEventStreamHeaders(c) service.SetEventStreamHeaders(c)
isFirst := true
c.Stream(func(w io.Writer) bool { c.Stream(func(w io.Writer) bool {
select { select {
case data := <-dataChan: case data := <-dataChan:
if isFirst {
isFirst = false
info.FirstResponseTime = time.Now()
}
if strings.HasPrefix(data, "data: [DONE]") { if strings.HasPrefix(data, "data: [DONE]") {
data = data[:12] data = data[:12]
} }
@ -187,7 +198,7 @@ func OpenaiHandler(c *gin.Context, resp *http.Response, promptTokens int, model
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
} }
if simpleResponse.Usage.TotalTokens == 0 { if simpleResponse.Usage.TotalTokens == 0 || (simpleResponse.Usage.PromptTokens == 0 && simpleResponse.Usage.CompletionTokens == 0) {
completionTokens := 0 completionTokens := 0
for _, choice := range simpleResponse.Choices { for _, choice := range simpleResponse.Choices {
ctkm, _ := service.CountTokenText(string(choice.Message.Content), model) ctkm, _ := service.CountTokenText(string(choice.Message.Content), model)

View File

@ -46,7 +46,7 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) { func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
if info.IsStream { if info.IsStream {
var responseText string var responseText string
err, responseText, _ = openai.OpenaiStreamHandler(c, resp, info.RelayMode) err, responseText, _ = openai.OpenaiStreamHandler(c, resp, info)
usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens) usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
} else { } else {
err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName) err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)

View File

@ -6,6 +6,7 @@ var ModelList = []string{
"SparkDesk-v2.1", "SparkDesk-v2.1",
"SparkDesk-v3.1", "SparkDesk-v3.1",
"SparkDesk-v3.5", "SparkDesk-v3.5",
"SparkDesk-v4.0",
} }
var ChannelName = "xunfei" var ChannelName = "xunfei"

View File

@ -252,6 +252,8 @@ func apiVersion2domain(apiVersion string) string {
return "generalv3" return "generalv3"
case "v3.5": case "v3.5":
return "generalv3.5" return "generalv3.5"
case "v4.0":
return "4.0Ultra"
} }
return "general" + apiVersion return "general" + apiVersion
} }

View File

@ -48,7 +48,7 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
if info.IsStream { if info.IsStream {
var responseText string var responseText string
var toolCount int var toolCount int
err, responseText, toolCount = openai.OpenaiStreamHandler(c, resp, info.RelayMode) err, responseText, toolCount = openai.OpenaiStreamHandler(c, resp, info)
usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens) usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
usage.CompletionTokens += toolCount * 7 usage.CompletionTokens += toolCount * 7
} else { } else {

View File

@ -16,6 +16,7 @@ type RelayInfo struct {
Group string Group string
TokenUnlimited bool TokenUnlimited bool
StartTime time.Time StartTime time.Time
FirstResponseTime time.Time
ApiType int ApiType int
IsStream bool IsStream bool
RelayMode int RelayMode int
@ -37,24 +38,26 @@ func GenRelayInfo(c *gin.Context) *RelayInfo {
group := c.GetString("group") group := c.GetString("group")
tokenUnlimited := c.GetBool("token_unlimited_quota") tokenUnlimited := c.GetBool("token_unlimited_quota")
startTime := time.Now() startTime := time.Now()
// firstResponseTime = time.Now() - 1 second
apiType, _ := constant.ChannelType2APIType(channelType) apiType, _ := constant.ChannelType2APIType(channelType)
info := &RelayInfo{ info := &RelayInfo{
RelayMode: constant.Path2RelayMode(c.Request.URL.Path), RelayMode: constant.Path2RelayMode(c.Request.URL.Path),
BaseUrl: c.GetString("base_url"), BaseUrl: c.GetString("base_url"),
RequestURLPath: c.Request.URL.String(), RequestURLPath: c.Request.URL.String(),
ChannelType: channelType, ChannelType: channelType,
ChannelId: channelId, ChannelId: channelId,
TokenId: tokenId, TokenId: tokenId,
UserId: userId, UserId: userId,
Group: group, Group: group,
TokenUnlimited: tokenUnlimited, TokenUnlimited: tokenUnlimited,
StartTime: startTime, StartTime: startTime,
ApiType: apiType, FirstResponseTime: startTime.Add(-time.Second),
ApiVersion: c.GetString("api_version"), ApiType: apiType,
ApiKey: strings.TrimPrefix(c.Request.Header.Get("Authorization"), "Bearer "), ApiVersion: c.GetString("api_version"),
Organization: c.GetString("channel_organization"), ApiKey: strings.TrimPrefix(c.Request.Header.Get("Authorization"), "Bearer "),
Organization: c.GetString("channel_organization"),
} }
if info.BaseUrl == "" { if info.BaseUrl == "" {
info.BaseUrl = common.ChannelBaseURLs[channelType] info.BaseUrl = common.ChannelBaseURLs[channelType]

View File

@ -538,7 +538,16 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
ChannelId: c.GetInt("channel_id"), ChannelId: c.GetInt("channel_id"),
Quota: quota, Quota: quota,
} }
if midjResponse.Code == 3 {
//无实例账号自动禁用渠道No available account instance
channel, err := model.GetChannelById(midjourneyTask.ChannelId, true)
if err != nil {
common.SysError("get_channel_null: " + err.Error())
}
if channel.AutoBan != nil && *channel.AutoBan == 1 {
model.UpdateChannelStatusById(midjourneyTask.ChannelId, 2, "No available account instance")
}
}
if midjResponse.Code != 1 && midjResponse.Code != 21 && midjResponse.Code != 22 { if midjResponse.Code != 1 && midjResponse.Code != 21 && midjResponse.Code != 22 {
//非1-提交成功,21-任务已存在和22-排队中,则记录错误原因 //非1-提交成功,21-任务已存在和22-排队中,则记录错误原因
midjourneyTask.FailReason = midjResponse.Description midjourneyTask.FailReason = midjResponse.Description

View File

@ -344,14 +344,7 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, textRe
logModel = "g-*" logModel = "g-*"
logContent += fmt.Sprintf(",模型 %s", textRequest.Model) logContent += fmt.Sprintf(",模型 %s", textRequest.Model)
} }
other := make(map[string]interface{}) other := service.GenerateTextOtherInfo(ctx, relayInfo, modelRatio, groupRatio, completionRatio, modelPrice)
other["model_ratio"] = modelRatio
other["group_ratio"] = groupRatio
other["completion_ratio"] = completionRatio
other["model_price"] = modelPrice
adminInfo := make(map[string]interface{})
adminInfo["use_channel"] = ctx.GetStringSlice("use_channel")
other["admin_info"] = adminInfo
model.RecordConsumeLog(ctx, relayInfo.UserId, relayInfo.ChannelId, promptTokens, completionTokens, logModel, tokenName, quota, logContent, relayInfo.TokenId, userQuota, int(useTimeSeconds), relayInfo.IsStream, other) model.RecordConsumeLog(ctx, relayInfo.UserId, relayInfo.ChannelId, promptTokens, completionTokens, logModel, tokenName, quota, logContent, relayInfo.TokenId, userQuota, int(useTimeSeconds), relayInfo.IsStream, other)
//if quota != 0 { //if quota != 0 {

19
service/log.go Normal file
View File

@ -0,0 +1,19 @@
package service
import (
"github.com/gin-gonic/gin"
relaycommon "one-api/relay/common"
)
func GenerateTextOtherInfo(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, modelRatio, groupRatio, completionRatio, modelPrice float64) map[string]interface{} {
other := make(map[string]interface{})
other["model_ratio"] = modelRatio
other["group_ratio"] = groupRatio
other["completion_ratio"] = completionRatio
other["model_price"] = modelPrice
other["frt"] = float64(relayInfo.FirstResponseTime.UnixMilli() - relayInfo.StartTime.UnixMilli())
adminInfo := make(map[string]interface{})
adminInfo["use_channel"] = ctx.GetStringSlice("use_channel")
other["admin_info"] = adminInfo
return other
}

View File

@ -3,7 +3,6 @@ package service
import ( import (
"errors" "errors"
"fmt" "fmt"
"one-api/common"
"one-api/constant" "one-api/constant"
"one-api/dto" "one-api/dto"
"strings" "strings"
@ -62,7 +61,7 @@ func SensitiveWordContains(text string) (bool, []string) {
} }
checkText := strings.ToLower(text) checkText := strings.ToLower(text)
// 构建一个AC自动机 // 构建一个AC自动机
m := common.InitAc() m := InitAc()
hits := m.MultiPatternSearch([]rune(checkText), false) hits := m.MultiPatternSearch([]rune(checkText), false)
if len(hits) > 0 { if len(hits) > 0 {
words := make([]string, 0) words := make([]string, 0)
@ -80,7 +79,7 @@ func SensitiveWordReplace(text string, returnImmediately bool) (bool, []string,
return false, nil, text return false, nil, text
} }
checkText := strings.ToLower(text) checkText := strings.ToLower(text)
m := common.InitAc() m := InitAc()
hits := m.MultiPatternSearch([]rune(checkText), returnImmediately) hits := m.MultiPatternSearch([]rune(checkText), returnImmediately)
if len(hits) > 0 { if len(hits) > 0 {
words := make([]string, 0) words := make([]string, 0)

View File

@ -1,4 +1,4 @@
package common package service
import ( import (
"bytes" "bytes"

File diff suppressed because it is too large Load Diff

View File

@ -29,6 +29,7 @@ import {
stringToColor, stringToColor,
} from '../helpers/render'; } from '../helpers/render';
import Paragraph from '@douyinfe/semi-ui/lib/es/typography/paragraph'; import Paragraph from '@douyinfe/semi-ui/lib/es/typography/paragraph';
import { getLogOther } from '../helpers/other.js';
const { Header } = Layout; const { Header } = Layout;
@ -141,6 +142,33 @@ function renderUseTime(type) {
} }
} }
function renderFirstUseTime(type) {
let time = parseFloat(type) / 1000.0;
time = time.toFixed(1);
if (time < 3) {
return (
<Tag color='green' size='large'>
{' '}
{time} s{' '}
</Tag>
);
} else if (time < 10) {
return (
<Tag color='orange' size='large'>
{' '}
{time} s{' '}
</Tag>
);
} else {
return (
<Tag color='red' size='large'>
{' '}
{time} s{' '}
</Tag>
);
}
}
const LogsTable = () => { const LogsTable = () => {
const columns = [ const columns = [
{ {
@ -247,17 +275,30 @@ const LogsTable = () => {
}, },
}, },
{ {
title: '用时', title: '用时/首字',
dataIndex: 'use_time', dataIndex: 'use_time',
render: (text, record, index) => { render: (text, record, index) => {
return ( if (record.is_stream) {
<div> let other = getLogOther(record.other);
<Space> return (
{renderUseTime(text)} <div>
{renderIsStream(record.is_stream)} <Space>
</Space> {renderUseTime(text)}
</div> {renderFirstUseTime(other.frt)}
); {renderIsStream(record.is_stream)}
</Space>
</div>
);
} else {
return (
<div>
<Space>
{renderUseTime(text)}
{renderIsStream(record.is_stream)}
</Space>
</div>
);
}
}, },
}, },
{ {
@ -325,10 +366,7 @@ const LogsTable = () => {
title: '详情', title: '详情',
dataIndex: 'content', dataIndex: 'content',
render: (text, record, index) => { render: (text, record, index) => {
if (record.other === '') { let other = getLogOther(record.other);
record.other = '{}';
}
let other = JSON.parse(record.other);
if (other == null) { if (other == null) {
return ( return (
<Paragraph <Paragraph

7
web/src/helpers/other.js Normal file
View File

@ -0,0 +1,7 @@
export function getLogOther(otherStr) {
if (otherStr === undefined || otherStr === '') {
otherStr = '{}';
}
let other = JSON.parse(otherStr);
return other;
}

View File

@ -144,28 +144,42 @@ export function renderModelPrice(
) { ) {
// 1 ratio = $0.002 / 1K tokens // 1 ratio = $0.002 / 1K tokens
if (modelPrice !== -1) { if (modelPrice !== -1) {
return '模型价格:$' + modelPrice * groupRatio; return (
'模型价格:$' +
modelPrice +
' * 分组倍率:' +
groupRatio +
' = $' +
modelPrice * groupRatio
);
} else { } else {
if (completionRatio === undefined) { if (completionRatio === undefined) {
completionRatio = 0; completionRatio = 0;
} }
// 这里的 *2 是因为 1倍率=0.002刀,请勿删除 // 这里的 *2 是因为 1倍率=0.002刀,请勿删除
let inputRatioPrice = modelRatio * 2.0 * groupRatio; let inputRatioPrice = modelRatio * 2.0;
let completionRatioPrice = modelRatio * 2.0 * completionRatio * groupRatio; let completionRatioPrice = modelRatio * 2.0 * completionRatio;
let price = let price =
(inputTokens / 1000000) * inputRatioPrice + (inputTokens / 1000000) * inputRatioPrice +
(completionTokens / 1000000) * completionRatioPrice; (completionTokens / 1000000) * completionRatioPrice;
return ( return (
<> <>
<article> <article>
<p>提示 ${inputRatioPrice} / 1M tokens</p> <p>
<p>补全 ${completionRatioPrice} / 1M tokens</p> 提示${inputRatioPrice} * {groupRatio} = $
{inputRatioPrice * groupRatio} / 1M tokens
</p>
<p>
补全${completionRatioPrice} * {groupRatio} = $
{completionRatioPrice * groupRatio} / 1M tokens
</p>
<p></p> <p></p>
<p> <p>
提示 {inputTokens} tokens / 1M tokens * ${inputRatioPrice} + 补全{' '} 提示 {inputTokens} tokens / 1M tokens * ${inputRatioPrice} + 补全{' '}
{completionTokens} tokens / 1M tokens * ${completionRatioPrice} = $ {completionTokens} tokens / 1M tokens * ${completionRatioPrice} *
{price.toFixed(6)} 分组 {groupRatio} = ${price.toFixed(6)}
</p> </p>
<p>仅供参考以实际扣费为准</p>
</article> </article>
</> </>
); );