Compare commits

...

9 Commits

Author SHA1 Message Date
CalciumIon
6c5b3b51b0 fix: try to fix tencent hunyuan #336 2024-07-05 20:00:52 +08:00
CalciumIon
d306aea9e5 feat: log mj task id 2024-07-05 17:22:36 +08:00
CalciumIon
d4578e28b3 fix: channel auto ban 2024-07-04 22:46:43 +08:00
CalciumIon
584eefec3e feat: 完善日志扣费计算过程 2024-07-01 00:56:37 +08:00
CalciumIon
a7e3168c17 feat: support cohere first response time 2024-06-28 23:32:02 +08:00
CalciumIon
d767ae04ff chore: 重构 2024-06-27 19:30:17 +08:00
CalciumIon
402a415c79 feat: 支持设置流模式超时时间(gemini, claude) 2024-06-27 17:24:48 +08:00
CalciumIon
55c28b2f98 Merge remote-tracking branch 'origin/main' 2024-06-27 17:17:48 +08:00
CalciumIon
fc6ae6bf34 feat: 支持设置流模式超时时间 2024-06-27 17:17:23 +08:00
32 changed files with 351 additions and 241 deletions

View File

@@ -83,6 +83,8 @@
```
可以实现400错误转为500错误从而重试
## 比原版One API多出的配置
- `STREAMING_TIMEOUT`:设置流式一次回复的超时时间,默认为 30 秒
## 部署
### 部署要求

View File

@@ -103,14 +103,14 @@ var IsMasterNode = os.Getenv("NODE_TYPE") != "slave"
var requestInterval, _ = strconv.Atoi(os.Getenv("POLLING_INTERVAL"))
var RequestInterval = time.Duration(requestInterval) * time.Second
var SyncFrequency = GetOrDefault("SYNC_FREQUENCY", 60) // unit is second
var SyncFrequency = GetEnvOrDefault("SYNC_FREQUENCY", 60) // unit is second
var BatchUpdateEnabled = false
var BatchUpdateInterval = GetOrDefault("BATCH_UPDATE_INTERVAL", 5)
var BatchUpdateInterval = GetEnvOrDefault("BATCH_UPDATE_INTERVAL", 5)
var RelayTimeout = GetOrDefault("RELAY_TIMEOUT", 0) // unit is second
var RelayTimeout = GetEnvOrDefault("RELAY_TIMEOUT", 0) // unit is second
var GeminiSafetySetting = GetOrDefaultString("GEMINI_SAFETY_SETTING", "BLOCK_NONE")
var GeminiSafetySetting = GetEnvOrDefaultString("GEMINI_SAFETY_SETTING", "BLOCK_NONE")
const (
RequestIdKey = "X-Oneapi-Request-Id"
@@ -133,10 +133,10 @@ var (
// All durations are in seconds
// They shouldn't be larger than RateLimitKeyExpirationDuration
var (
GlobalApiRateLimitNum = GetOrDefault("GLOBAL_API_RATE_LIMIT", 180)
GlobalApiRateLimitNum = GetEnvOrDefault("GLOBAL_API_RATE_LIMIT", 180)
GlobalApiRateLimitDuration int64 = 3 * 60
GlobalWebRateLimitNum = GetOrDefault("GLOBAL_WEB_RATE_LIMIT", 60)
GlobalWebRateLimitNum = GetEnvOrDefault("GLOBAL_WEB_RATE_LIMIT", 60)
GlobalWebRateLimitDuration int64 = 3 * 60
UploadRateLimitNum = 10

26
common/env.go Normal file
View File

@@ -0,0 +1,26 @@
package common
import (
"fmt"
"os"
"strconv"
)
// GetEnvOrDefault reads the environment variable named env and parses it as a
// base-10 integer.
//
// defaultValue is returned when env is the empty string, when the variable is
// unset or empty, or when its value is not a valid integer. In the last case
// the parse failure is also reported through SysError.
func GetEnvOrDefault(env string, defaultValue int) int {
	if env == "" {
		return defaultValue
	}
	raw := os.Getenv(env)
	if raw == "" {
		return defaultValue
	}
	parsed, parseErr := strconv.Atoi(raw)
	if parseErr == nil {
		return parsed
	}
	// Fall back instead of failing so a bad value never prevents start-up.
	SysError(fmt.Sprintf("failed to parse %s: %s, using default value: %d", env, parseErr.Error(), defaultValue))
	return defaultValue
}
// GetEnvOrDefaultString returns the value of the environment variable named
// env, or defaultValue when env is the empty string or the variable is unset
// or empty.
func GetEnvOrDefaultString(env string, defaultValue string) string {
	if env != "" {
		if value := os.Getenv(env); value != "" {
			return value
		}
	}
	return defaultValue
}

View File

@@ -3,6 +3,7 @@ package common
import (
"fmt"
"runtime/debug"
"time"
)
func SafeGoroutine(f func()) {
@@ -45,3 +46,21 @@ func SafeSendString(ch chan string, value string) (closed bool) {
// If the code reaches here, then the channel was not closed.
return false
}
// SafeSendStringTimeout tries to send value on ch, waiting at most timeout
// seconds for the send to proceed.
//
// It returns true only when the value was handed to the channel. It returns
// false when the timeout elapses first, or when ch has already been closed
// (the panic raised by sending on a closed channel is recovered).
func SafeSendStringTimeout(ch chan string, value string, timeout int) (sent bool) {
	defer func() {
		// Sending on a closed channel panics; report that as "not sent".
		if recover() != nil {
			sent = false
		}
	}()

	// An explicit timer that is stopped on return is used instead of
	// time.After: this function runs once per streamed chunk, and a timer
	// created by time.After stays allocated until it fires even when the send
	// succeeds immediately.
	timer := time.NewTimer(time.Duration(timeout) * time.Second)
	defer timer.Stop()

	select {
	case ch <- value:
		return true
	case <-timer.C:
		return false
	}
}

View File

@@ -1,6 +1,8 @@
package common
import "encoding/json"
import (
"encoding/json"
)
var GroupRatio = map[string]float64{
"default": 1,

View File

@@ -1,6 +1,8 @@
package common
import "encoding/json"
import (
"encoding/json"
)
var TopupGroupRatio = map[string]float64{
"default": 1,

View File

@@ -8,7 +8,6 @@ import (
"log"
"math/rand"
"net"
"os"
"os/exec"
"runtime"
"strconv"
@@ -191,25 +190,6 @@ func Max(a int, b int) int {
}
}
func GetOrDefault(env string, defaultValue int) int {
if env == "" || os.Getenv(env) == "" {
return defaultValue
}
num, err := strconv.Atoi(os.Getenv(env))
if err != nil {
SysError(fmt.Sprintf("failed to parse %s: %s, using default value: %d", env, err.Error(), defaultValue))
return defaultValue
}
return num
}
func GetOrDefaultString(env string, defaultValue string) string {
if env == "" || os.Getenv(env) == "" {
return defaultValue
}
return os.Getenv(env)
}
func MessageWithRequestId(message string, id string) string {
return fmt.Sprintf("%s (request id: %s)", message, id)
}

7
constant/env.go Normal file
View File

@@ -0,0 +1,7 @@
package constant
import (
"one-api/common"
)
// StreamingTimeout is the timeout, in seconds, that the streaming relay
// handlers pass to common.SafeSendStringTimeout for each chunk they forward.
// It is read from the STREAMING_TIMEOUT environment variable and defaults to 30.
var StreamingTimeout = common.GetEnvOrDefault("STREAMING_TIMEOUT", 30)

View File

@@ -228,7 +228,7 @@ func testAllChannels(notify bool) error {
Error: *openaiErr,
LocalError: false,
}
if isChannelEnabled && service.ShouldDisableChannel(&openAiErrWithStatus) && ban {
if isChannelEnabled && service.ShouldDisableChannel(channel.Type, &openAiErrWithStatus) && ban {
service.DisableChannel(channel.Id, channel.Name, err.Error())
}
if !isChannelEnabled && service.ShouldEnableChannel(err, openaiErr, channel.Status) {

View File

@@ -40,12 +40,13 @@ func Relay(c *gin.Context) {
retryTimes := common.RetryTimes
requestId := c.GetString(common.RequestIdKey)
channelId := c.GetInt("channel_id")
channelType := c.GetInt("channel_type")
group := c.GetString("group")
originalModel := c.GetString("original_model")
openaiErr := relayHandler(c, relayMode)
c.Set("use_channel", []string{fmt.Sprintf("%d", channelId)})
if openaiErr != nil {
go processChannelError(c, channelId, openaiErr)
go processChannelError(c, channelId, channelType, openaiErr)
} else {
retryTimes = 0
}
@@ -66,7 +67,7 @@ func Relay(c *gin.Context) {
c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
openaiErr = relayHandler(c, relayMode)
if openaiErr != nil {
go processChannelError(c, channelId, openaiErr)
go processChannelError(c, channelId, channel.Type, openaiErr)
}
}
useChannel := c.GetStringSlice("use_channel")
@@ -125,10 +126,10 @@ func shouldRetry(c *gin.Context, channelId int, openaiErr *dto.OpenAIErrorWithSt
return true
}
func processChannelError(c *gin.Context, channelId int, err *dto.OpenAIErrorWithStatusCode) {
func processChannelError(c *gin.Context, channelId int, channelType int, err *dto.OpenAIErrorWithStatusCode) {
autoBan := c.GetBool("auto_ban")
common.LogError(c.Request.Context(), fmt.Sprintf("relay error (channel #%d, status code: %d): %s", channelId, err.StatusCode, err.Error.Message))
if service.ShouldDisableChannel(err) && autoBan {
if service.ShouldDisableChannel(channelType, err) && autoBan {
channelName := c.GetString("channel_name")
service.DisableChannel(channelId, channelName, err.Error.Message)
}

View File

@@ -24,7 +24,7 @@ func UpdateTaskBulk() {
//imageModel := "midjourney"
for {
time.Sleep(time.Duration(15) * time.Second)
common.SysLog("任务进度轮询开始")
common.SysLog(" 任务进度轮询开始")
ctx := context.TODO()
allTasks := model.GetAllUnFinishSyncTasks(500)
platformTask := make(map[constant.TaskPlatform][]*model.Task)

View File

@@ -5,11 +5,10 @@ import (
"github.com/Calcium-Ion/go-epay/epay"
"github.com/gin-gonic/gin"
"github.com/samber/lo"
"one-api/constant"
"log"
"net/url"
"one-api/common"
"one-api/constant"
"one-api/model"
"one-api/service"
"strconv"

View File

@@ -24,14 +24,3 @@ type OpenAIModels struct {
Root string `json:"root"`
Parent *string `json:"parent"`
}
type ModelPricing struct {
Available bool `json:"available"`
ModelName string `json:"model_name"`
QuotaType int `json:"quota_type"`
ModelRatio float64 `json:"model_ratio"`
ModelPrice float64 `json:"model_price"`
OwnerBy string `json:"owner_by"`
CompletionRatio float64 `json:"completion_ratio"`
EnableGroup []string `json:"enable_group,omitempty"`
}

View File

@@ -178,6 +178,7 @@ func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, mode
c.Set("channel", channel.Type)
c.Set("channel_id", channel.Id)
c.Set("channel_name", channel.Name)
c.Set("channel_type", channel.Type)
ban := true
// parse *int to bool
if channel.AutoBan != nil && *channel.AutoBan == 0 {

View File

@@ -86,9 +86,9 @@ func InitDB() (err error) {
if err != nil {
return err
}
sqlDB.SetMaxIdleConns(common.GetOrDefault("SQL_MAX_IDLE_CONNS", 100))
sqlDB.SetMaxOpenConns(common.GetOrDefault("SQL_MAX_OPEN_CONNS", 1000))
sqlDB.SetConnMaxLifetime(time.Second * time.Duration(common.GetOrDefault("SQL_MAX_LIFETIME", 60)))
sqlDB.SetMaxIdleConns(common.GetEnvOrDefault("SQL_MAX_IDLE_CONNS", 100))
sqlDB.SetMaxOpenConns(common.GetEnvOrDefault("SQL_MAX_OPEN_CONNS", 1000))
sqlDB.SetConnMaxLifetime(time.Second * time.Duration(common.GetEnvOrDefault("SQL_MAX_LIFETIME", 60)))
if !common.IsMasterNode {
return nil

View File

@@ -2,18 +2,28 @@ package model
import (
"one-api/common"
"one-api/dto"
"sync"
"time"
)
// Pricing is one entry of the model pricing list that is cached in pricingMap
// and returned by GetPricing. Every field is serialized to JSON under the
// snake_case name given in its tag.
type Pricing struct {
// Available is set to true for every entry built by updatePricing.
Available bool `json:"available"`
// ModelName is the model identifier; GetPricing matches it against the
// models enabled for a group.
ModelName string `json:"model_name"`
// NOTE(review): presumably selects between ratio-based and fixed-price
// billing — confirm against the code that fills this field.
QuotaType int `json:"quota_type"`
// NOTE(review): presumably the per-token ratio used when billing by usage —
// confirm.
ModelRatio float64 `json:"model_ratio"`
// NOTE(review): presumably the fixed per-call price used when billing by
// request — confirm.
ModelPrice float64 `json:"model_price"`
OwnerBy string `json:"owner_by"`
// NOTE(review): presumably the multiplier applied to completion tokens
// relative to prompt tokens — confirm.
CompletionRatio float64 `json:"completion_ratio"`
// EnableGroup is left out of the JSON output when empty (omitempty).
EnableGroup []string `json:"enable_group,omitempty"`
}
var (
pricingMap []dto.ModelPricing
pricingMap []Pricing
lastGetPricingTime time.Time
updatePricingLock sync.Mutex
)
func GetPricing(group string) []dto.ModelPricing {
func GetPricing(group string) []Pricing {
updatePricingLock.Lock()
defer updatePricingLock.Unlock()
@@ -21,7 +31,7 @@ func GetPricing(group string) []dto.ModelPricing {
updatePricing()
}
if group != "" {
userPricingMap := make([]dto.ModelPricing, 0)
userPricingMap := make([]Pricing, 0)
models := GetGroupModels(group)
for _, pricing := range pricingMap {
if !common.StringsContains(models, pricing.ModelName) {
@@ -42,9 +52,9 @@ func updatePricing() {
allModels[model] = i
}
pricingMap = make([]dto.ModelPricing, 0)
pricingMap = make([]Pricing, 0)
for model, _ := range allModels {
pricing := dto.ModelPricing{
pricing := Pricing{
Available: true,
ModelName: model,
}

View File

@@ -8,6 +8,7 @@ import (
"io"
"net/http"
"one-api/common"
"one-api/constant"
"one-api/dto"
relaycommon "one-api/relay/common"
"one-api/service"
@@ -267,8 +268,8 @@ func claudeStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.
}
return 0, nil, nil
})
dataChan := make(chan string)
stopChan := make(chan bool)
dataChan := make(chan string, 5)
stopChan := make(chan bool, 2)
go func() {
for scanner.Scan() {
data := scanner.Text()
@@ -276,7 +277,11 @@ func claudeStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.
continue
}
data = strings.TrimPrefix(data, "data: ")
dataChan <- data
if !common.SafeSendStringTimeout(dataChan, data, constant.StreamingTimeout) {
// send data timeout, stop the stream
common.LogError(c, "send data timeout, stop the stream")
break
}
}
stopChan <- true
}()

View File

@@ -36,7 +36,7 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
if info.IsStream {
err, usage = cohereStreamHandler(c, resp, info.UpstreamModelName, info.PromptTokens)
err, usage = cohereStreamHandler(c, resp, info)
} else {
err, usage = cohereHandler(c, resp, info.UpstreamModelName, info.PromptTokens)
}

View File

@@ -9,8 +9,10 @@ import (
"net/http"
"one-api/common"
"one-api/dto"
relaycommon "one-api/relay/common"
"one-api/service"
"strings"
"time"
)
func requestOpenAI2Cohere(textRequest dto.GeneralOpenAIRequest) *CohereRequest {
@@ -56,7 +58,7 @@ func stopReasonCohere2OpenAI(reason string) string {
}
}
func cohereStreamHandler(c *gin.Context, resp *http.Response, modelName string, promptTokens int) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
func cohereStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
responseId := fmt.Sprintf("chatcmpl-%s", common.GetUUID())
createdTime := common.GetTimestamp()
usage := &dto.Usage{}
@@ -84,9 +86,14 @@ func cohereStreamHandler(c *gin.Context, resp *http.Response, modelName string,
stopChan <- true
}()
service.SetEventStreamHeaders(c)
isFirst := true
c.Stream(func(w io.Writer) bool {
select {
case data := <-dataChan:
if isFirst {
isFirst = false
info.FirstResponseTime = time.Now()
}
data = strings.TrimSuffix(data, "\r")
var cohereResp CohereResponse
err := json.Unmarshal([]byte(data), &cohereResp)
@@ -98,7 +105,7 @@ func cohereStreamHandler(c *gin.Context, resp *http.Response, modelName string,
openaiResp.Id = responseId
openaiResp.Created = createdTime
openaiResp.Object = "chat.completion.chunk"
openaiResp.Model = modelName
openaiResp.Model = info.UpstreamModelName
if cohereResp.IsFinished {
finishReason := stopReasonCohere2OpenAI(cohereResp.FinishReason)
openaiResp.Choices = []dto.ChatCompletionsStreamResponseChoice{
@@ -137,7 +144,7 @@ func cohereStreamHandler(c *gin.Context, resp *http.Response, modelName string,
}
})
if usage.PromptTokens == 0 {
usage, _ = service.ResponseText2Usage(responseText, modelName, promptTokens)
usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
}
return nil, usage
}

View File

@@ -7,6 +7,7 @@ import (
"io"
"net/http"
"one-api/common"
"one-api/constant"
"one-api/dto"
relaycommon "one-api/relay/common"
"one-api/service"
@@ -163,8 +164,8 @@ func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) *dto.Ch
func geminiChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, string) {
responseText := ""
dataChan := make(chan string)
stopChan := make(chan bool)
dataChan := make(chan string, 5)
stopChan := make(chan bool, 2)
scanner := bufio.NewScanner(resp.Body)
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
if atEOF && len(data) == 0 {
@@ -187,7 +188,11 @@ func geminiChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycom
}
data = strings.TrimPrefix(data, "\"text\": \"")
data = strings.TrimSuffix(data, "\"")
dataChan <- data
if !common.SafeSendStringTimeout(dataChan, data, constant.StreamingTimeout) {
// send data timeout, stop the stream
common.LogError(c, "send data timeout, stop the stream")
break
}
}
stopChan <- true
}()

View File

@@ -8,6 +8,7 @@ import (
"io"
"net/http"
"one-api/common"
"one-api/constant"
"one-api/dto"
relaycommon "one-api/relay/common"
relayconstant "one-api/relay/constant"
@@ -51,7 +52,11 @@ func OpenaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.
if data[:6] != "data: " && data[:6] != "[DONE]" {
continue
}
common.SafeSendString(dataChan, data)
if !common.SafeSendStringTimeout(dataChan, data, constant.StreamingTimeout) {
// send data timeout, stop the stream
common.LogError(c, "send data timeout, stop the stream")
break
}
data = data[6:]
if !strings.HasPrefix(data, "[DONE]") {
streamItems = append(streamItems, data)

View File

@@ -6,18 +6,26 @@ import (
"github.com/gin-gonic/gin"
"io"
"net/http"
"one-api/common"
"one-api/dto"
"one-api/relay/channel"
relaycommon "one-api/relay/common"
"one-api/service"
"strconv"
"strings"
)
type Adaptor struct {
Sign string
Sign string
Action string
Version string
Timestamp int64
}
// Init fills in the per-request fields of the adaptor: the fixed API action
// and version strings, and the current timestamp taken from
// common.GetTimestamp. SetupRequestHeader later sends these values as the
// X-TC-Action, X-TC-Version and X-TC-Timestamp headers, and getTencentSign
// reads Action and Timestamp when it builds the signature.
//
// info and request are not used.
func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
	a.Timestamp = common.GetTimestamp()
	a.Action, a.Version = "ChatCompletions", "2023-09-01"
}
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
@@ -27,7 +35,9 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
channel.SetupApiRequestHeader(info, c, req)
req.Header.Set("Authorization", a.Sign)
req.Header.Set("X-TC-Action", info.UpstreamModelName)
req.Header.Set("X-TC-Action", a.Action)
req.Header.Set("X-TC-Version", a.Version)
req.Header.Set("X-TC-Timestamp", strconv.FormatInt(a.Timestamp, 10))
return nil
}
@@ -37,15 +47,13 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.Gen
}
apiKey := c.Request.Header.Get("Authorization")
apiKey = strings.TrimPrefix(apiKey, "Bearer ")
appId, secretId, secretKey, err := parseTencentConfig(apiKey)
_, secretId, secretKey, err := parseTencentConfig(apiKey)
if err != nil {
return nil, err
}
tencentRequest := requestOpenAI2Tencent(*request)
tencentRequest.AppId = appId
tencentRequest.SecretId = secretId
// we have to calculate the sign here
a.Sign = getTencentSign(*tencentRequest, secretKey)
a.Sign = getTencentSign(*tencentRequest, a, secretId, secretKey)
return tencentRequest, nil
}

View File

@@ -1,62 +1,71 @@
package tencent
import "one-api/dto"
type TencentMessage struct {
Role string `json:"role"`
Content string `json:"content"`
Role string `json:"Role"`
Content string `json:"Content"`
}
type TencentChatRequest struct {
AppId int64 `json:"app_id"` // 腾讯云账号的 APPID
SecretId string `json:"secret_id"` // 官网 SecretId
// Timestamp当前 UNIX 时间戳,单位为秒,可记录发起 API 请求的时间。
// 例如1529223702如果与当前时间相差过大会引起签名过期错误
Timestamp int64 `json:"timestamp"`
// Expired 签名的有效期,是一个符合 UNIX Epoch 时间戳规范的数值,
// 单位为秒Expired 必须大于 Timestamp 且 Expired-Timestamp 小于90天
Expired int64 `json:"expired"`
QueryID string `json:"query_id"` //请求 Id用于问题排查
// Temperature 较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定
// 默认 1.0,取值区间为[0.0,2.0],非必要不建议使用,不合理的取值会影响效果
// 建议该参数和 top_p 只设置1个不要同时更改 top_p
Temperature float64 `json:"temperature"`
// TopP 影响输出文本的多样性,取值越大,生成文本的多样性越强
// 默认1.0,取值区间为[0.0, 1.0],非必要不建议使用, 不合理的取值会影响效果
// 建议该参数和 temperature 只设置1个不要同时更改
TopP float64 `json:"top_p"`
// Stream 0同步1流式 默认协议SSE)
// 同步请求超时60s如果内容较长建议使用流式
Stream int `json:"stream"`
// Messages 会话内容, 长度最多为40, 按对话时间从旧到新在数组中排列
// 输入 content 总数最大支持 3000 token。
Messages []TencentMessage `json:"messages"`
Model string `json:"model"` // 模型名称
// 模型名称,可选值包括 hunyuan-lite、hunyuan-standard、hunyuan-standard-256K、hunyuan-pro。
// 各模型介绍请阅读 [产品概述](https://cloud.tencent.com/document/product/1729/104753) 中的说明。
//
// 注意:
// 不同的模型计费不同,请根据 [购买指南](https://cloud.tencent.com/document/product/1729/97731) 按需调用。
Model *string `json:"Model"`
// 聊天上下文信息。
// 说明:
// 1. 长度最多为 40按对话时间从旧到新在数组中排列。
// 2. Message.Role 可选值system、user、assistant。
// 其中system 角色可选如存在则必须位于列表的最开始。user 和 assistant 需交替出现(一问一答),以 user 提问开始和结束,且 Content 不能为空。Role 的顺序示例:[system可选 user assistant user assistant user ...]。
// 3. Messages 中 Content 总长度不能超过模型输入长度上限(可参考 [产品概述](https://cloud.tencent.com/document/product/1729/104753) 文档),超过则会截断最前面的内容,只保留尾部内容。
Messages []*TencentMessage `json:"Messages"`
// 流式调用开关。
// 说明:
// 1. 未传值时默认为非流式调用false
// 2. 流式调用时以 SSE 协议增量返回结果(返回值取 Choices[n].Delta 中的值,需要拼接增量数据才能获得完整结果)。
// 3. 非流式调用时:
// 调用方式与普通 HTTP 请求无异。
// 接口响应耗时较长,**如需更低时延建议设置为 true**。
// 只返回一次最终结果(返回值取 Choices[n].Message 中的值)。
//
// 注意:
// 通过 SDK 调用时,流式和非流式调用需用**不同的方式**获取返回值,具体参考 SDK 中的注释或示例(在各语言 SDK 代码仓库的 examples/hunyuan/v20230901/ 目录中)。
Stream *bool `json:"Stream"`
// 说明:
// 1. 影响输出文本的多样性,取值越大,生成文本的多样性越强。
// 2. 取值区间为 [0.0, 1.0],未传值时使用各模型推荐值。
// 3. 非必要不建议使用,不合理的取值会影响效果。
TopP *float64 `json:"TopP"`
// 说明:
// 1. 较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定。
// 2. 取值区间为 [0.0, 2.0],未传值时使用各模型推荐值。
// 3. 非必要不建议使用,不合理的取值会影响效果。
Temperature *float64 `json:"Temperature"`
}
type TencentError struct {
Code int `json:"code"`
Message string `json:"message"`
Code int `json:"Code"`
Message string `json:"Message"`
}
type TencentUsage struct {
InputTokens int `json:"input_tokens"`
OutputTokens int `json:"output_tokens"`
TotalTokens int `json:"total_tokens"`
PromptTokens int `json:"PromptTokens"`
CompletionTokens int `json:"CompletionTokens"`
TotalTokens int `json:"TotalTokens"`
}
type TencentResponseChoices struct {
FinishReason string `json:"finish_reason,omitempty"` // 流式结束标志位,为 stop 则表示尾包
Messages TencentMessage `json:"messages,omitempty"` // 内容,同步模式返回内容,流模式为 null 输出 content 内容总数最多支持 1024token。
Delta TencentMessage `json:"delta,omitempty"` // 内容,流模式返回内容,同步模式为 null 输出 content 内容总数最多支持 1024token。
FinishReason string `json:"FinishReason,omitempty"` // 流式结束标志位,为 stop 则表示尾包
Messages TencentMessage `json:"Message,omitempty"` // 内容,同步模式返回内容,流模式为 null 输出 content 内容总数最多支持 1024token。
Delta TencentMessage `json:"Delta,omitempty"` // 内容,流模式返回内容,同步模式为 null 输出 content 内容总数最多支持 1024token。
}
type TencentChatResponse struct {
Choices []TencentResponseChoices `json:"choices,omitempty"` // 结果
Created string `json:"created,omitempty"` // unix 时间戳的字符串
Id string `json:"id,omitempty"` // 会话 id
Usage dto.Usage `json:"usage,omitempty"` // token 数量
Error TencentError `json:"error,omitempty"` // 错误信息 注意:此字段可能返回 null表示取不到有效值
Note string `json:"note,omitempty"` // 注释
ReqID string `json:"req_id,omitempty"` // 唯一请求 Id每次请求都会返回。用于反馈接口入参
Choices []TencentResponseChoices `json:"Choices,omitempty"` // 结果
Created int64 `json:"Created,omitempty"` // unix 时间戳的字符串
Id string `json:"Id,omitempty"` // 会话 id
Usage TencentUsage `json:"Usage,omitempty"` // token 数量
Error TencentError `json:"Error,omitempty"` // 错误信息 注意:此字段可能返回 null表示取不到有效值
Note string `json:"Note,omitempty"` // 注释
ReqID string `json:"Req_id,omitempty"` // 唯一请求 Id每次请求都会返回。用于反馈接口入参
}

View File

@@ -3,8 +3,8 @@ package tencent
import (
"bufio"
"crypto/hmac"
"crypto/sha1"
"encoding/base64"
"crypto/sha256"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
@@ -15,46 +15,28 @@ import (
"one-api/dto"
relaycommon "one-api/relay/common"
"one-api/service"
"sort"
"strconv"
"strings"
"time"
)
// https://cloud.tencent.com/document/product/1729/97732
func requestOpenAI2Tencent(request dto.GeneralOpenAIRequest) *TencentChatRequest {
messages := make([]TencentMessage, 0, len(request.Messages))
messages := make([]*TencentMessage, 0, len(request.Messages))
for i := 0; i < len(request.Messages); i++ {
message := request.Messages[i]
if message.Role == "system" {
messages = append(messages, TencentMessage{
Role: "user",
Content: message.StringContent(),
})
messages = append(messages, TencentMessage{
Role: "assistant",
Content: "Okay",
})
continue
}
messages = append(messages, TencentMessage{
messages = append(messages, &TencentMessage{
Content: message.StringContent(),
Role: message.Role,
})
}
stream := 0
if request.Stream {
stream = 1
}
return &TencentChatRequest{
Timestamp: common.GetTimestamp(),
Expired: common.GetTimestamp() + 24*60*60,
QueryID: common.GetUUID(),
Temperature: request.Temperature,
TopP: request.TopP,
Stream: stream,
Temperature: &request.Temperature,
TopP: &request.TopP,
Stream: &request.Stream,
Messages: messages,
Model: request.Model,
Model: &request.Model,
}
}
@@ -62,7 +44,11 @@ func responseTencent2OpenAI(response *TencentChatResponse) *dto.OpenAITextRespon
fullTextResponse := dto.OpenAITextResponse{
Object: "chat.completion",
Created: common.GetTimestamp(),
Usage: response.Usage,
Usage: dto.Usage{
PromptTokens: response.Usage.PromptTokens,
CompletionTokens: response.Usage.CompletionTokens,
TotalTokens: response.Usage.TotalTokens,
},
}
if len(response.Choices) > 0 {
content, _ := json.Marshal(response.Choices[0].Messages.Content)
@@ -99,64 +85,46 @@ func streamResponseTencent2OpenAI(TencentResponse *TencentChatResponse) *dto.Cha
func tencentStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, string) {
var responseText string
scanner := bufio.NewScanner(resp.Body)
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
if atEOF && len(data) == 0 {
return 0, nil, nil
}
if i := strings.Index(string(data), "\n"); i >= 0 {
return i + 1, data[0:i], nil
}
if atEOF {
return len(data), data, nil
}
return 0, nil, nil
})
dataChan := make(chan string)
stopChan := make(chan bool)
go func() {
for scanner.Scan() {
data := scanner.Text()
if len(data) < 5 { // ignore blank line or wrong format
continue
}
if data[:5] != "data:" {
continue
}
data = data[5:]
dataChan <- data
}
stopChan <- true
}()
scanner.Split(bufio.ScanLines)
service.SetEventStreamHeaders(c)
c.Stream(func(w io.Writer) bool {
select {
case data := <-dataChan:
var TencentResponse TencentChatResponse
err := json.Unmarshal([]byte(data), &TencentResponse)
if err != nil {
common.SysError("error unmarshalling stream response: " + err.Error())
return true
}
response := streamResponseTencent2OpenAI(&TencentResponse)
if len(response.Choices) != 0 {
responseText += response.Choices[0].Delta.GetContentString()
}
jsonResponse, err := json.Marshal(response)
if err != nil {
common.SysError("error marshalling stream response: " + err.Error())
return true
}
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
return true
case <-stopChan:
c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
return false
for scanner.Scan() {
data := scanner.Text()
if len(data) < 5 || !strings.HasPrefix(data, "data:") {
continue
}
})
data = strings.TrimPrefix(data, "data:")
var tencentResponse TencentChatResponse
err := json.Unmarshal([]byte(data), &tencentResponse)
if err != nil {
common.SysError("error unmarshalling stream response: " + err.Error())
continue
}
response := streamResponseTencent2OpenAI(&tencentResponse)
if len(response.Choices) != 0 {
responseText += response.Choices[0].Delta.GetContentString()
}
err = service.ObjectData(c, response)
if err != nil {
common.SysError(err.Error())
}
}
if err := scanner.Err(); err != nil {
common.SysError("error reading stream: " + err.Error())
}
service.Done(c)
err := resp.Body.Close()
if err != nil {
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), ""
}
return nil, responseText
}
@@ -206,29 +174,62 @@ func parseTencentConfig(config string) (appId int64, secretId string, secretKey
return
}
func getTencentSign(req TencentChatRequest, secretKey string) string {
params := make([]string, 0)
params = append(params, "app_id="+strconv.FormatInt(req.AppId, 10))
params = append(params, "secret_id="+req.SecretId)
params = append(params, "timestamp="+strconv.FormatInt(req.Timestamp, 10))
params = append(params, "query_id="+req.QueryID)
params = append(params, "temperature="+strconv.FormatFloat(req.Temperature, 'f', -1, 64))
params = append(params, "top_p="+strconv.FormatFloat(req.TopP, 'f', -1, 64))
params = append(params, "stream="+strconv.Itoa(req.Stream))
params = append(params, "expired="+strconv.FormatInt(req.Expired, 10))
var messageStr string
for _, msg := range req.Messages {
messageStr += fmt.Sprintf(`{"role":"%s","content":"%s"},`, msg.Role, msg.Content)
}
messageStr = strings.TrimSuffix(messageStr, ",")
params = append(params, "messages=["+messageStr+"]")
sort.Sort(sort.StringSlice(params))
url := "hunyuan.cloud.tencent.com/hyllm/v1/chat/completions?" + strings.Join(params, "&")
mac := hmac.New(sha1.New, []byte(secretKey))
signURL := url
mac.Write([]byte(signURL))
sign := mac.Sum([]byte(nil))
return base64.StdEncoding.EncodeToString(sign)
// sha256hex returns the SHA-256 digest of s encoded as a lowercase
// hexadecimal string.
func sha256hex(s string) string {
	hasher := sha256.New()
	hasher.Write([]byte(s)) // hash.Hash.Write never returns an error
	return hex.EncodeToString(hasher.Sum(nil))
}
// hmacSha256 computes the HMAC-SHA256 of message s under key and returns the
// raw 32-byte digest packed into a string. The result is binary, not
// hex-encoded, so that it can be fed directly as the key of the next
// hmacSha256 call when deriving a chained signing key.
func hmacSha256(s, key string) string {
	mac := hmac.New(sha256.New, []byte(key))
	mac.Write([]byte(s)) // hash.Hash.Write never returns an error
	digest := mac.Sum(nil)
	return string(digest)
}
// getTencentSign builds the value of the Authorization header for a Hunyuan
// request using the TC3-HMAC-SHA256 scheme.
//
// req is serialized with json.Marshal and its SHA-256 hash becomes part of the
// canonical request. adaptor supplies the action name (lower-cased into the
// x-tc-action canonical header) and the request timestamp. secId is embedded
// in the credential, and secKey seeds the derived signing key.
//
// NOTE(review): the signed payload hash is taken from json.Marshal(req) here,
// so the body that is actually sent presumably has to be byte-identical to
// that encoding, and the request's Content-Type and Host have to match the
// values hard-coded below — confirm against the caller.
func getTencentSign(req TencentChatRequest, adaptor *Adaptor, secId, secKey string) string {
	const (
		host          = "hunyuan.tencentcloudapi.com"
		serviceName   = "hunyuan"
		terminator    = "tc3_request"
		algorithm     = "TC3-HMAC-SHA256"
		signedHeaders = "content-type;host;x-tc-action"
	)

	// Step 1: canonical request (method, URI, empty query string, headers,
	// signed header list, hashed payload).
	// NOTE(review): a marshal error is discarded; on failure the payload is
	// nil and the hash of the empty string is signed.
	body, _ := json.Marshal(req)
	canonicalHeaders := fmt.Sprintf("content-type:%s\nhost:%s\nx-tc-action:%s\n",
		"application/json", host, strings.ToLower(adaptor.Action))
	canonicalRequest := fmt.Sprintf("%s\n%s\n%s\n%s\n%s\n%s",
		"POST",
		"/",
		"",
		canonicalHeaders,
		signedHeaders,
		sha256hex(string(body)))

	// Step 2: string to sign. The credential scope uses the UTC calendar date
	// of the request timestamp in 2006-01-02 layout.
	requestTimestamp := strconv.FormatInt(adaptor.Timestamp, 10)
	date := time.Unix(adaptor.Timestamp, 0).UTC().Format("2006-01-02")
	credentialScope := fmt.Sprintf("%s/%s/tc3_request", date, serviceName)
	string2sign := fmt.Sprintf("%s\n%s\n%s\n%s",
		algorithm,
		requestTimestamp,
		credentialScope,
		sha256hex(canonicalRequest))

	// Step 3: derive the signing key by chaining HMACs over date, service and
	// terminator, then sign.
	signingKey := hmacSha256(terminator, hmacSha256(serviceName, hmacSha256(date, "TC3"+secKey)))
	signature := hex.EncodeToString([]byte(hmacSha256(string2sign, signingKey)))

	// Step 4: assemble the Authorization header value.
	return fmt.Sprintf("%s Credential=%s/%s, SignedHeaders=%s, Signature=%s",
		algorithm,
		secId,
		credentialScope,
		signedHeaders,
		signature)
}

View File

@@ -73,14 +73,14 @@ func AudioHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode {
preConsumedQuota := int(float64(preConsumedTokens) * ratio)
userQuota, err := model.CacheGetUserQuota(userId)
if err != nil {
return service.OpenAIErrorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError)
return service.OpenAIErrorWrapperLocal(err, "get_user_quota_failed", http.StatusInternalServerError)
}
if userQuota-preConsumedQuota < 0 {
return service.OpenAIErrorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden)
return service.OpenAIErrorWrapperLocal(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden)
}
err = model.CacheDecreaseUserQuota(userId, preConsumedQuota)
if err != nil {
return service.OpenAIErrorWrapper(err, "decrease_user_quota_failed", http.StatusInternalServerError)
return service.OpenAIErrorWrapperLocal(err, "decrease_user_quota_failed", http.StatusInternalServerError)
}
if userQuota > 100*preConsumedQuota {
// in this case, we do not pre-consume quota
@@ -90,7 +90,7 @@ func AudioHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode {
if preConsumedQuota > 0 {
userQuota, err = model.PreConsumeTokenQuota(tokenId, preConsumedQuota)
if err != nil {
return service.OpenAIErrorWrapper(err, "pre_consume_token_quota_failed", http.StatusForbidden)
return service.OpenAIErrorWrapperLocal(err, "pre_consume_token_quota_failed", http.StatusForbidden)
}
}

View File

@@ -147,7 +147,7 @@ func RelayImageHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusC
quota := int(modelPrice*groupRatio*common.QuotaPerUnit*sizeRatio*qualityRatio) * imageRequest.N
if userQuota-quota < 0 {
return service.OpenAIErrorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden)
return service.OpenAIErrorWrapperLocal(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden)
}
req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody)

View File

@@ -500,7 +500,7 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
}
if quota != 0 {
tokenName := c.GetString("token_name")
logContent := fmt.Sprintf("模型固定价格 %.2f,分组倍率 %.2f,操作 %s", modelPrice, groupRatio, midjRequest.Action)
logContent := fmt.Sprintf("模型固定价格 %.2f,分组倍率 %.2f,操作 %sID %s", modelPrice, groupRatio, midjRequest.Action, midjResponse.Result)
other := make(map[string]interface{})
other["model_price"] = modelPrice
other["group_ratio"] = groupRatio

View File

@@ -24,7 +24,7 @@ func EnableChannel(channelId int, channelName string) {
notifyRootUser(subject, content)
}
func ShouldDisableChannel(err *relaymodel.OpenAIErrorWithStatusCode) bool {
func ShouldDisableChannel(channelType int, err *relaymodel.OpenAIErrorWithStatusCode) bool {
if !common.AutomaticDisableChannelEnabled {
return false
}
@@ -34,9 +34,15 @@ func ShouldDisableChannel(err *relaymodel.OpenAIErrorWithStatusCode) bool {
if err.LocalError {
return false
}
if err.StatusCode == http.StatusUnauthorized || err.StatusCode == http.StatusForbidden {
if err.StatusCode == http.StatusUnauthorized {
return true
}
if err.StatusCode == http.StatusForbidden {
switch channelType {
case common.ChannelTypeGemini:
return true
}
}
switch err.Error.Code {
case "invalid_api_key":
return true

View File

@@ -3,7 +3,6 @@ package service
import (
"errors"
"fmt"
"one-api/common"
"one-api/constant"
"one-api/dto"
"strings"
@@ -62,7 +61,7 @@ func SensitiveWordContains(text string) (bool, []string) {
}
checkText := strings.ToLower(text)
// 构建一个AC自动机
m := common.InitAc()
m := InitAc()
hits := m.MultiPatternSearch([]rune(checkText), false)
if len(hits) > 0 {
words := make([]string, 0)
@@ -80,7 +79,7 @@ func SensitiveWordReplace(text string, returnImmediately bool) (bool, []string,
return false, nil, text
}
checkText := strings.ToLower(text)
m := common.InitAc()
m := InitAc()
hits := m.MultiPatternSearch([]rune(checkText), returnImmediately)
if len(hits) > 0 {
words := make([]string, 0)

View File

@@ -1,6 +1,12 @@
package service
import "github.com/gin-gonic/gin"
import (
"encoding/json"
"fmt"
"github.com/gin-gonic/gin"
"one-api/common"
"strings"
)
func SetEventStreamHeaders(c *gin.Context) {
c.Writer.Header().Set("Content-Type", "text/event-stream")
@@ -9,3 +15,23 @@ func SetEventStreamHeaders(c *gin.Context) {
c.Writer.Header().Set("Transfer-Encoding", "chunked")
c.Writer.Header().Set("X-Accel-Buffering", "no")
}
// StringData writes str to the client as a single "data: ..." event and
// flushes the response writer right away.
//
// A leading "data: " prefix and a trailing "\r" are stripped from str first,
// so the prefix is not duplicated when the input already carries one.
func StringData(c *gin.Context, str string) {
	payload := strings.TrimSuffix(strings.TrimPrefix(str, "data: "), "\r")
	c.Render(-1, common.CustomEvent{Data: "data: " + payload})
	c.Writer.Flush()
}
// ObjectData serializes object to JSON and emits it to the client through
// StringData.
//
// It returns a wrapped error, and writes nothing, when object cannot be
// marshalled; otherwise it returns nil.
func ObjectData(c *gin.Context, object interface{}) error {
	encoded, marshalErr := json.Marshal(object)
	if marshalErr != nil {
		return fmt.Errorf("error marshalling object: %w", marshalErr)
	}
	StringData(c, string(encoded))
	return nil
}
// Done emits the terminating "data: [DONE]" event to the client and flushes
// the response writer.
func Done(c *gin.Context) {
	c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
	c.Writer.Flush()
}

View File

@@ -1,4 +1,4 @@
package common
package service
import (
"bytes"

View File

@@ -144,28 +144,29 @@ export function renderModelPrice(
) {
// 1 ratio = $0.002 / 1K tokens
if (modelPrice !== -1) {
return '模型价格:$' + modelPrice * groupRatio;
return '模型价格:$' + modelPrice + ' * 分组倍率:' + groupRatio + ' = $' + modelPrice * groupRatio;
} else {
if (completionRatio === undefined) {
completionRatio = 0;
}
// 这里的 *2 是因为 1倍率=0.002刀,请勿删除
let inputRatioPrice = modelRatio * 2.0 * groupRatio;
let completionRatioPrice = modelRatio * 2.0 * completionRatio * groupRatio;
let inputRatioPrice = modelRatio * 2.0;
let completionRatioPrice = modelRatio * 2.0 * completionRatio;
let price =
(inputTokens / 1000000) * inputRatioPrice +
(completionTokens / 1000000) * completionRatioPrice;
return (
<>
<article>
<p>提示 ${inputRatioPrice} / 1M tokens</p>
<p>补全 ${completionRatioPrice} / 1M tokens</p>
<p>提示${inputRatioPrice} * {groupRatio} = ${inputRatioPrice * groupRatio} / 1M tokens</p>
<p>补全${completionRatioPrice} * {groupRatio} = ${completionRatioPrice * groupRatio} / 1M tokens</p>
<p></p>
<p>
提示 {inputTokens} tokens / 1M tokens * ${inputRatioPrice} + 补全{' '}
{completionTokens} tokens / 1M tokens * ${completionRatioPrice} = $
{price.toFixed(6)}
{completionTokens} tokens / 1M tokens * ${completionRatioPrice} * 分组 {groupRatio} =
${price.toFixed(6)}
</p>
<p>仅供参考以实际扣费为准</p>
</article>
</>
);