mirror of
https://github.com/linux-do/new-api.git
synced 2025-09-17 07:56:38 +08:00
merge upstream
Signed-off-by: wozulong <>
This commit is contained in:
commit
895ee09b33
@ -5,8 +5,6 @@
|
|||||||
> 本项目为开源项目,在[One API](https://github.com/songquanpeng/one-api)的基础上进行二次开发,感谢原作者的无私奉献。
|
> 本项目为开源项目,在[One API](https://github.com/songquanpeng/one-api)的基础上进行二次开发,感谢原作者的无私奉献。
|
||||||
> 使用者必须在遵循 OpenAI 的[使用条款](https://openai.com/policies/terms-of-use)以及**法律法规**的情况下使用,不得用于非法用途。
|
> 使用者必须在遵循 OpenAI 的[使用条款](https://openai.com/policies/terms-of-use)以及**法律法规**的情况下使用,不得用于非法用途。
|
||||||
|
|
||||||
|
|
||||||
> [!WARNING]
|
|
||||||
> 本项目为个人学习使用,不保证稳定性,且不提供任何技术支持,使用者必须在遵循 OpenAI 的使用条款以及法律法规的情况下使用,不得用于非法用途。
|
> 本项目为个人学习使用,不保证稳定性,且不提供任何技术支持,使用者必须在遵循 OpenAI 的使用条款以及法律法规的情况下使用,不得用于非法用途。
|
||||||
> 根据[《生成式人工智能服务管理暂行办法》](http://www.cac.gov.cn/2023-07/13/c_1690898327029107.htm)的要求,请勿对中国地区公众提供一切未经备案的生成式人工智能服务。
|
> 根据[《生成式人工智能服务管理暂行办法》](http://www.cac.gov.cn/2023-07/13/c_1690898327029107.htm)的要求,请勿对中国地区公众提供一切未经备案的生成式人工智能服务。
|
||||||
|
|
||||||
@ -85,8 +83,13 @@
|
|||||||
```
|
```
|
||||||
可以实现400错误转为500错误,从而重试
|
可以实现400错误转为500错误,从而重试
|
||||||
|
|
||||||
|
## 比原版One API多出的配置
|
||||||
|
- `STREAMING_TIMEOUT`:设置流式一次回复的超时时间,默认为 30 秒
|
||||||
|
|
||||||
## 部署
|
## 部署
|
||||||
|
### 部署要求
|
||||||
|
- 本地数据库(默认):SQLite(Docker 部署默认使用 SQLite,必须挂载 `/data` 目录到宿主机)
|
||||||
|
- 远程数据库:MySQL 版本 >= 5.7.8,PgSQL 版本 >= 9.6
|
||||||
### 基于 Docker 进行部署
|
### 基于 Docker 进行部署
|
||||||
```shell
|
```shell
|
||||||
# 使用 SQLite 的部署命令:
|
# 使用 SQLite 的部署命令:
|
||||||
|
@ -120,14 +120,14 @@ var IsMasterNode = os.Getenv("NODE_TYPE") != "slave"
|
|||||||
var requestInterval, _ = strconv.Atoi(os.Getenv("POLLING_INTERVAL"))
|
var requestInterval, _ = strconv.Atoi(os.Getenv("POLLING_INTERVAL"))
|
||||||
var RequestInterval = time.Duration(requestInterval) * time.Second
|
var RequestInterval = time.Duration(requestInterval) * time.Second
|
||||||
|
|
||||||
var SyncFrequency = GetOrDefault("SYNC_FREQUENCY", 60) // unit is second
|
var SyncFrequency = GetEnvOrDefault("SYNC_FREQUENCY", 60) // unit is second
|
||||||
|
|
||||||
var BatchUpdateEnabled = false
|
var BatchUpdateEnabled = false
|
||||||
var BatchUpdateInterval = GetOrDefault("BATCH_UPDATE_INTERVAL", 5)
|
var BatchUpdateInterval = GetEnvOrDefault("BATCH_UPDATE_INTERVAL", 5)
|
||||||
|
|
||||||
var RelayTimeout = GetOrDefault("RELAY_TIMEOUT", 0) // unit is second
|
var RelayTimeout = GetEnvOrDefault("RELAY_TIMEOUT", 0) // unit is second
|
||||||
|
|
||||||
var GeminiSafetySetting = GetOrDefaultString("GEMINI_SAFETY_SETTING", "BLOCK_NONE")
|
var GeminiSafetySetting = GetEnvOrDefaultString("GEMINI_SAFETY_SETTING", "BLOCK_NONE")
|
||||||
|
|
||||||
const (
|
const (
|
||||||
RequestIdKey = "X-Oneapi-Request-Id"
|
RequestIdKey = "X-Oneapi-Request-Id"
|
||||||
@ -150,10 +150,10 @@ var (
|
|||||||
// All duration's unit is seconds
|
// All duration's unit is seconds
|
||||||
// Shouldn't larger then RateLimitKeyExpirationDuration
|
// Shouldn't larger then RateLimitKeyExpirationDuration
|
||||||
var (
|
var (
|
||||||
GlobalApiRateLimitNum = GetOrDefault("GLOBAL_API_RATE_LIMIT", 180)
|
GlobalApiRateLimitNum = GetEnvOrDefault("GLOBAL_API_RATE_LIMIT", 180)
|
||||||
GlobalApiRateLimitDuration int64 = 3 * 60
|
GlobalApiRateLimitDuration int64 = 3 * 60
|
||||||
|
|
||||||
GlobalWebRateLimitNum = GetOrDefault("GLOBAL_WEB_RATE_LIMIT", 60)
|
GlobalWebRateLimitNum = GetEnvOrDefault("GLOBAL_WEB_RATE_LIMIT", 60)
|
||||||
GlobalWebRateLimitDuration int64 = 3 * 60
|
GlobalWebRateLimitDuration int64 = 3 * 60
|
||||||
|
|
||||||
UploadRateLimitNum = 10
|
UploadRateLimitNum = 10
|
||||||
|
26
common/env.go
Normal file
26
common/env.go
Normal file
@ -0,0 +1,26 @@
|
|||||||
|
package common
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"strconv"
|
||||||
|
)
|
||||||
|
|
||||||
|
func GetEnvOrDefault(env string, defaultValue int) int {
|
||||||
|
if env == "" || os.Getenv(env) == "" {
|
||||||
|
return defaultValue
|
||||||
|
}
|
||||||
|
num, err := strconv.Atoi(os.Getenv(env))
|
||||||
|
if err != nil {
|
||||||
|
SysError(fmt.Sprintf("failed to parse %s: %s, using default value: %d", env, err.Error(), defaultValue))
|
||||||
|
return defaultValue
|
||||||
|
}
|
||||||
|
return num
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetEnvOrDefaultString(env string, defaultValue string) string {
|
||||||
|
if env == "" || os.Getenv(env) == "" {
|
||||||
|
return defaultValue
|
||||||
|
}
|
||||||
|
return os.Getenv(env)
|
||||||
|
}
|
@ -3,6 +3,7 @@ package common
|
|||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"runtime/debug"
|
"runtime/debug"
|
||||||
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
func SafeGoroutine(f func()) {
|
func SafeGoroutine(f func()) {
|
||||||
@ -45,3 +46,21 @@ func SafeSendString(ch chan string, value string) (closed bool) {
|
|||||||
// If the code reaches here, then the channel was not closed.
|
// If the code reaches here, then the channel was not closed.
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SafeSendStringTimeout send, return true, else return false
|
||||||
|
func SafeSendStringTimeout(ch chan string, value string, timeout int) (closed bool) {
|
||||||
|
defer func() {
|
||||||
|
// Recover from panic if one occured. A panic would mean the channel was closed.
|
||||||
|
if recover() != nil {
|
||||||
|
closed = false
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
// This will panic if the channel is closed.
|
||||||
|
select {
|
||||||
|
case ch <- value:
|
||||||
|
return true
|
||||||
|
case <-time.After(time.Duration(timeout) * time.Second):
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
@ -1,6 +1,8 @@
|
|||||||
package common
|
package common
|
||||||
|
|
||||||
import "encoding/json"
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
)
|
||||||
|
|
||||||
var GroupRatio = map[string]float64{
|
var GroupRatio = map[string]float64{
|
||||||
"default": 1,
|
"default": 1,
|
||||||
|
@ -114,6 +114,7 @@ var defaultModelRatio = map[string]float64{
|
|||||||
"SparkDesk-v2.1": 1.2858, // ¥0.018 / 1k tokens
|
"SparkDesk-v2.1": 1.2858, // ¥0.018 / 1k tokens
|
||||||
"SparkDesk-v3.1": 1.2858, // ¥0.018 / 1k tokens
|
"SparkDesk-v3.1": 1.2858, // ¥0.018 / 1k tokens
|
||||||
"SparkDesk-v3.5": 1.2858, // ¥0.018 / 1k tokens
|
"SparkDesk-v3.5": 1.2858, // ¥0.018 / 1k tokens
|
||||||
|
"SparkDesk-v4.0": 1.2858,
|
||||||
"360GPT_S2_V9": 0.8572, // ¥0.012 / 1k tokens
|
"360GPT_S2_V9": 0.8572, // ¥0.012 / 1k tokens
|
||||||
"360gpt-turbo": 0.0858, // ¥0.0012 / 1k tokens
|
"360gpt-turbo": 0.0858, // ¥0.0012 / 1k tokens
|
||||||
"360gpt-turbo-responsibility-8k": 0.8572, // ¥0.012 / 1k tokens
|
"360gpt-turbo-responsibility-8k": 0.8572, // ¥0.012 / 1k tokens
|
||||||
|
@ -1,6 +1,8 @@
|
|||||||
package common
|
package common
|
||||||
|
|
||||||
import "encoding/json"
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
)
|
||||||
|
|
||||||
var TopupGroupRatio = map[string]float64{
|
var TopupGroupRatio = map[string]float64{
|
||||||
"default": 1,
|
"default": 1,
|
||||||
|
@ -13,7 +13,6 @@ import (
|
|||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
"os"
|
|
||||||
"os/exec"
|
"os/exec"
|
||||||
"runtime"
|
"runtime"
|
||||||
"strconv"
|
"strconv"
|
||||||
@ -196,25 +195,6 @@ func Max(a int, b int) int {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetOrDefault(env string, defaultValue int) int {
|
|
||||||
if env == "" || os.Getenv(env) == "" {
|
|
||||||
return defaultValue
|
|
||||||
}
|
|
||||||
num, err := strconv.Atoi(os.Getenv(env))
|
|
||||||
if err != nil {
|
|
||||||
SysError(fmt.Sprintf("failed to parse %s: %s, using default value: %d", env, err.Error(), defaultValue))
|
|
||||||
return defaultValue
|
|
||||||
}
|
|
||||||
return num
|
|
||||||
}
|
|
||||||
|
|
||||||
func GetOrDefaultString(env string, defaultValue string) string {
|
|
||||||
if env == "" || os.Getenv(env) == "" {
|
|
||||||
return defaultValue
|
|
||||||
}
|
|
||||||
return os.Getenv(env)
|
|
||||||
}
|
|
||||||
|
|
||||||
func MessageWithRequestId(message string, id string) string {
|
func MessageWithRequestId(message string, id string) string {
|
||||||
return fmt.Sprintf("%s (request id: %s)", message, id)
|
return fmt.Sprintf("%s (request id: %s)", message, id)
|
||||||
}
|
}
|
||||||
|
7
constant/env.go
Normal file
7
constant/env.go
Normal file
@ -0,0 +1,7 @@
|
|||||||
|
package constant
|
||||||
|
|
||||||
|
import (
|
||||||
|
"one-api/common"
|
||||||
|
)
|
||||||
|
|
||||||
|
var StreamingTimeout = common.GetEnvOrDefault("STREAMING_TIMEOUT", 30)
|
@ -222,16 +222,18 @@ func testAllChannels(notify bool) error {
|
|||||||
if channel.AutoBan != nil && *channel.AutoBan == 0 {
|
if channel.AutoBan != nil && *channel.AutoBan == 0 {
|
||||||
ban = false
|
ban = false
|
||||||
}
|
}
|
||||||
openAiErrWithStatus := dto.OpenAIErrorWithStatusCode{
|
if openaiErr != nil {
|
||||||
StatusCode: -1,
|
openAiErrWithStatus := dto.OpenAIErrorWithStatusCode{
|
||||||
Error: *openaiErr,
|
StatusCode: -1,
|
||||||
LocalError: false,
|
Error: *openaiErr,
|
||||||
}
|
LocalError: false,
|
||||||
if isChannelEnabled && service.ShouldDisableChannel(&openAiErrWithStatus) && ban {
|
}
|
||||||
service.DisableChannel(channel.Id, channel.Name, err.Error())
|
if isChannelEnabled && service.ShouldDisableChannel(&openAiErrWithStatus) && ban {
|
||||||
}
|
service.DisableChannel(channel.Id, channel.Name, err.Error())
|
||||||
if !isChannelEnabled && service.ShouldEnableChannel(err, openaiErr, channel.Status) {
|
}
|
||||||
service.EnableChannel(channel.Id, channel.Name)
|
if !isChannelEnabled && service.ShouldEnableChannel(err, openaiErr, channel.Status) {
|
||||||
|
service.EnableChannel(channel.Id, channel.Name)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
channel.UpdateResponseTime(milliseconds)
|
channel.UpdateResponseTime(milliseconds)
|
||||||
time.Sleep(common.RequestInterval)
|
time.Sleep(common.RequestInterval)
|
||||||
|
@ -19,9 +19,6 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func UpdateMidjourneyTaskBulk() {
|
func UpdateMidjourneyTaskBulk() {
|
||||||
if !common.IsMasterNode {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
//imageModel := "midjourney"
|
//imageModel := "midjourney"
|
||||||
ctx := context.TODO()
|
ctx := context.TODO()
|
||||||
for {
|
for {
|
||||||
|
@ -24,14 +24,3 @@ type OpenAIModels struct {
|
|||||||
Root string `json:"root"`
|
Root string `json:"root"`
|
||||||
Parent *string `json:"parent"`
|
Parent *string `json:"parent"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type ModelPricing struct {
|
|
||||||
Available bool `json:"available"`
|
|
||||||
ModelName string `json:"model_name"`
|
|
||||||
QuotaType int `json:"quota_type"`
|
|
||||||
ModelRatio float64 `json:"model_ratio"`
|
|
||||||
ModelPrice float64 `json:"model_price"`
|
|
||||||
OwnerBy string `json:"owner_by"`
|
|
||||||
CompletionRatio float64 `json:"completion_ratio"`
|
|
||||||
EnableGroup []string `json:"enable_group,omitempty"`
|
|
||||||
}
|
|
||||||
|
14
main.go
14
main.go
@ -89,12 +89,14 @@ func main() {
|
|||||||
}
|
}
|
||||||
go controller.AutomaticallyTestChannels(frequency)
|
go controller.AutomaticallyTestChannels(frequency)
|
||||||
}
|
}
|
||||||
common.SafeGoroutine(func() {
|
if common.IsMasterNode {
|
||||||
controller.UpdateMidjourneyTaskBulk()
|
common.SafeGoroutine(func() {
|
||||||
})
|
controller.UpdateMidjourneyTaskBulk()
|
||||||
common.SafeGoroutine(func() {
|
})
|
||||||
controller.UpdateTaskBulk()
|
common.SafeGoroutine(func() {
|
||||||
})
|
controller.UpdateTaskBulk()
|
||||||
|
})
|
||||||
|
}
|
||||||
if os.Getenv("BATCH_UPDATE_ENABLED") == "true" {
|
if os.Getenv("BATCH_UPDATE_ENABLED") == "true" {
|
||||||
common.BatchUpdateEnabled = true
|
common.BatchUpdateEnabled = true
|
||||||
common.SysLog("batch update enabled with interval " + strconv.Itoa(common.BatchUpdateInterval) + "s")
|
common.SysLog("batch update enabled with interval " + strconv.Itoa(common.BatchUpdateInterval) + "s")
|
||||||
|
@ -56,6 +56,11 @@ func getPriority(group string, model string, retry int) (int, error) {
|
|||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if len(priorities) == 0 {
|
||||||
|
// 如果没有查询到优先级,则返回错误
|
||||||
|
return 0, errors.New("数据库一致性被破坏")
|
||||||
|
}
|
||||||
|
|
||||||
// 确定要使用的优先级
|
// 确定要使用的优先级
|
||||||
var priorityToUse int
|
var priorityToUse int
|
||||||
if retry >= len(priorities) {
|
if retry >= len(priorities) {
|
||||||
@ -199,7 +204,7 @@ func FixAbility() (int, error) {
|
|||||||
|
|
||||||
// Use channelIds to find channel not in abilities table
|
// Use channelIds to find channel not in abilities table
|
||||||
var abilityChannelIds []int
|
var abilityChannelIds []int
|
||||||
err = DB.Model(&Ability{}).Pluck("channel_id", &abilityChannelIds).Error
|
err = DB.Table("abilities").Distinct("channel_id").Pluck("channel_id", &abilityChannelIds).Error
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.SysError(fmt.Sprintf("Get channel ids from abilities table failed: %s", err.Error()))
|
common.SysError(fmt.Sprintf("Get channel ids from abilities table failed: %s", err.Error()))
|
||||||
return 0, err
|
return 0, err
|
||||||
|
@ -86,9 +86,9 @@ func InitDB() (err error) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
sqlDB.SetMaxIdleConns(common.GetOrDefault("SQL_MAX_IDLE_CONNS", 100))
|
sqlDB.SetMaxIdleConns(common.GetEnvOrDefault("SQL_MAX_IDLE_CONNS", 100))
|
||||||
sqlDB.SetMaxOpenConns(common.GetOrDefault("SQL_MAX_OPEN_CONNS", 1000))
|
sqlDB.SetMaxOpenConns(common.GetEnvOrDefault("SQL_MAX_OPEN_CONNS", 1000))
|
||||||
sqlDB.SetConnMaxLifetime(time.Second * time.Duration(common.GetOrDefault("SQL_MAX_LIFETIME", 60)))
|
sqlDB.SetConnMaxLifetime(time.Second * time.Duration(common.GetEnvOrDefault("SQL_MAX_LIFETIME", 60)))
|
||||||
|
|
||||||
if !common.IsMasterNode {
|
if !common.IsMasterNode {
|
||||||
return nil
|
return nil
|
||||||
|
@ -2,18 +2,28 @@ package model
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
"one-api/dto"
|
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type Pricing struct {
|
||||||
|
Available bool `json:"available"`
|
||||||
|
ModelName string `json:"model_name"`
|
||||||
|
QuotaType int `json:"quota_type"`
|
||||||
|
ModelRatio float64 `json:"model_ratio"`
|
||||||
|
ModelPrice float64 `json:"model_price"`
|
||||||
|
OwnerBy string `json:"owner_by"`
|
||||||
|
CompletionRatio float64 `json:"completion_ratio"`
|
||||||
|
EnableGroup []string `json:"enable_group,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
var (
|
var (
|
||||||
pricingMap []dto.ModelPricing
|
pricingMap []Pricing
|
||||||
lastGetPricingTime time.Time
|
lastGetPricingTime time.Time
|
||||||
updatePricingLock sync.Mutex
|
updatePricingLock sync.Mutex
|
||||||
)
|
)
|
||||||
|
|
||||||
func GetPricing(group string) []dto.ModelPricing {
|
func GetPricing(group string) []Pricing {
|
||||||
updatePricingLock.Lock()
|
updatePricingLock.Lock()
|
||||||
defer updatePricingLock.Unlock()
|
defer updatePricingLock.Unlock()
|
||||||
|
|
||||||
@ -21,7 +31,7 @@ func GetPricing(group string) []dto.ModelPricing {
|
|||||||
updatePricing()
|
updatePricing()
|
||||||
}
|
}
|
||||||
if group != "" {
|
if group != "" {
|
||||||
userPricingMap := make([]dto.ModelPricing, 0)
|
userPricingMap := make([]Pricing, 0)
|
||||||
models := GetGroupModels(group)
|
models := GetGroupModels(group)
|
||||||
for _, pricing := range pricingMap {
|
for _, pricing := range pricingMap {
|
||||||
if !common.StringsContains(models, pricing.ModelName) {
|
if !common.StringsContains(models, pricing.ModelName) {
|
||||||
@ -42,9 +52,9 @@ func updatePricing() {
|
|||||||
allModels[model] = i
|
allModels[model] = i
|
||||||
}
|
}
|
||||||
|
|
||||||
pricingMap = make([]dto.ModelPricing, 0)
|
pricingMap = make([]Pricing, 0)
|
||||||
for model, _ := range allModels {
|
for model, _ := range allModels {
|
||||||
pricing := dto.ModelPricing{
|
pricing := Pricing{
|
||||||
Available: true,
|
Available: true,
|
||||||
ModelName: model,
|
ModelName: model,
|
||||||
}
|
}
|
||||||
|
@ -14,6 +14,7 @@ import (
|
|||||||
"one-api/relay/channel/claude"
|
"one-api/relay/channel/claude"
|
||||||
relaycommon "one-api/relay/common"
|
relaycommon "one-api/relay/common"
|
||||||
"strings"
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/aws/aws-sdk-go-v2/aws"
|
"github.com/aws/aws-sdk-go-v2/aws"
|
||||||
"github.com/aws/aws-sdk-go-v2/credentials"
|
"github.com/aws/aws-sdk-go-v2/credentials"
|
||||||
@ -156,6 +157,7 @@ func awsStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, requestMode i
|
|||||||
var usage relaymodel.Usage
|
var usage relaymodel.Usage
|
||||||
var id string
|
var id string
|
||||||
var model string
|
var model string
|
||||||
|
isFirst := true
|
||||||
createdTime := common.GetTimestamp()
|
createdTime := common.GetTimestamp()
|
||||||
c.Stream(func(w io.Writer) bool {
|
c.Stream(func(w io.Writer) bool {
|
||||||
event, ok := <-stream.Events()
|
event, ok := <-stream.Events()
|
||||||
@ -166,6 +168,10 @@ func awsStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, requestMode i
|
|||||||
|
|
||||||
switch v := event.(type) {
|
switch v := event.(type) {
|
||||||
case *types.ResponseStreamMemberChunk:
|
case *types.ResponseStreamMemberChunk:
|
||||||
|
if isFirst {
|
||||||
|
isFirst = false
|
||||||
|
info.FirstResponseTime = time.Now()
|
||||||
|
}
|
||||||
claudeResp := new(claude.ClaudeResponse)
|
claudeResp := new(claude.ClaudeResponse)
|
||||||
err := json.NewDecoder(bytes.NewReader(v.Value.Bytes)).Decode(claudeResp)
|
err := json.NewDecoder(bytes.NewReader(v.Value.Bytes)).Decode(claudeResp)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -65,7 +65,7 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
|
|||||||
|
|
||||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
|
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
|
||||||
if info.IsStream {
|
if info.IsStream {
|
||||||
err, usage = claudeStreamHandler(a.RequestMode, info.UpstreamModelName, info.PromptTokens, c, resp)
|
err, usage = claudeStreamHandler(c, resp, info, a.RequestMode)
|
||||||
} else {
|
} else {
|
||||||
err, usage = claudeHandler(a.RequestMode, c, resp, info.PromptTokens, info.UpstreamModelName)
|
err, usage = claudeHandler(a.RequestMode, c, resp, info.PromptTokens, info.UpstreamModelName)
|
||||||
}
|
}
|
||||||
|
@ -8,6 +8,7 @@ var ModelList = []string{
|
|||||||
"claude-3-sonnet-20240229",
|
"claude-3-sonnet-20240229",
|
||||||
"claude-3-opus-20240229",
|
"claude-3-opus-20240229",
|
||||||
"claude-3-haiku-20240307",
|
"claude-3-haiku-20240307",
|
||||||
|
"claude-3-5-sonnet-20240620",
|
||||||
}
|
}
|
||||||
|
|
||||||
var ChannelName = "claude"
|
var ChannelName = "claude"
|
||||||
|
@ -8,9 +8,12 @@ import (
|
|||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
|
"one-api/constant"
|
||||||
"one-api/dto"
|
"one-api/dto"
|
||||||
|
relaycommon "one-api/relay/common"
|
||||||
"one-api/service"
|
"one-api/service"
|
||||||
"strings"
|
"strings"
|
||||||
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
func stopReasonClaude2OpenAI(reason string) string {
|
func stopReasonClaude2OpenAI(reason string) string {
|
||||||
@ -246,7 +249,7 @@ func ResponseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) *dto.Ope
|
|||||||
return &fullTextResponse
|
return &fullTextResponse
|
||||||
}
|
}
|
||||||
|
|
||||||
func claudeStreamHandler(requestMode int, modelName string, promptTokens int, c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
func claudeStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, requestMode int) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
||||||
responseId := fmt.Sprintf("chatcmpl-%s", common.GetUUID())
|
responseId := fmt.Sprintf("chatcmpl-%s", common.GetUUID())
|
||||||
var usage *dto.Usage
|
var usage *dto.Usage
|
||||||
usage = &dto.Usage{}
|
usage = &dto.Usage{}
|
||||||
@ -265,8 +268,8 @@ func claudeStreamHandler(requestMode int, modelName string, promptTokens int, c
|
|||||||
}
|
}
|
||||||
return 0, nil, nil
|
return 0, nil, nil
|
||||||
})
|
})
|
||||||
dataChan := make(chan string)
|
dataChan := make(chan string, 5)
|
||||||
stopChan := make(chan bool)
|
stopChan := make(chan bool, 2)
|
||||||
go func() {
|
go func() {
|
||||||
for scanner.Scan() {
|
for scanner.Scan() {
|
||||||
data := scanner.Text()
|
data := scanner.Text()
|
||||||
@ -274,14 +277,23 @@ func claudeStreamHandler(requestMode int, modelName string, promptTokens int, c
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
data = strings.TrimPrefix(data, "data: ")
|
data = strings.TrimPrefix(data, "data: ")
|
||||||
dataChan <- data
|
if !common.SafeSendStringTimeout(dataChan, data, constant.StreamingTimeout) {
|
||||||
|
// send data timeout, stop the stream
|
||||||
|
common.LogError(c, "send data timeout, stop the stream")
|
||||||
|
break
|
||||||
|
}
|
||||||
}
|
}
|
||||||
stopChan <- true
|
stopChan <- true
|
||||||
}()
|
}()
|
||||||
|
isFirst := true
|
||||||
service.SetEventStreamHeaders(c)
|
service.SetEventStreamHeaders(c)
|
||||||
c.Stream(func(w io.Writer) bool {
|
c.Stream(func(w io.Writer) bool {
|
||||||
select {
|
select {
|
||||||
case data := <-dataChan:
|
case data := <-dataChan:
|
||||||
|
if isFirst {
|
||||||
|
isFirst = false
|
||||||
|
info.FirstResponseTime = time.Now()
|
||||||
|
}
|
||||||
// some implementations may add \r at the end of data
|
// some implementations may add \r at the end of data
|
||||||
data = strings.TrimSuffix(data, "\r")
|
data = strings.TrimSuffix(data, "\r")
|
||||||
var claudeResponse ClaudeResponse
|
var claudeResponse ClaudeResponse
|
||||||
@ -302,7 +314,7 @@ func claudeStreamHandler(requestMode int, modelName string, promptTokens int, c
|
|||||||
if claudeResponse.Type == "message_start" {
|
if claudeResponse.Type == "message_start" {
|
||||||
// message_start, 获取usage
|
// message_start, 获取usage
|
||||||
responseId = claudeResponse.Message.Id
|
responseId = claudeResponse.Message.Id
|
||||||
modelName = claudeResponse.Message.Model
|
info.UpstreamModelName = claudeResponse.Message.Model
|
||||||
usage.PromptTokens = claudeUsage.InputTokens
|
usage.PromptTokens = claudeUsage.InputTokens
|
||||||
} else if claudeResponse.Type == "content_block_delta" {
|
} else if claudeResponse.Type == "content_block_delta" {
|
||||||
responseText += claudeResponse.Delta.Text
|
responseText += claudeResponse.Delta.Text
|
||||||
@ -316,7 +328,7 @@ func claudeStreamHandler(requestMode int, modelName string, promptTokens int, c
|
|||||||
//response.Id = responseId
|
//response.Id = responseId
|
||||||
response.Id = responseId
|
response.Id = responseId
|
||||||
response.Created = createdTime
|
response.Created = createdTime
|
||||||
response.Model = modelName
|
response.Model = info.UpstreamModelName
|
||||||
|
|
||||||
jsonStr, err := json.Marshal(response)
|
jsonStr, err := json.Marshal(response)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -335,13 +347,13 @@ func claudeStreamHandler(requestMode int, modelName string, promptTokens int, c
|
|||||||
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
||||||
}
|
}
|
||||||
if requestMode == RequestModeCompletion {
|
if requestMode == RequestModeCompletion {
|
||||||
usage, _ = service.ResponseText2Usage(responseText, modelName, promptTokens)
|
usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
|
||||||
} else {
|
} else {
|
||||||
if usage.PromptTokens == 0 {
|
if usage.PromptTokens == 0 {
|
||||||
usage.PromptTokens = promptTokens
|
usage.PromptTokens = info.PromptTokens
|
||||||
}
|
}
|
||||||
if usage.CompletionTokens == 0 {
|
if usage.CompletionTokens == 0 {
|
||||||
usage, _ = service.ResponseText2Usage(responseText, modelName, usage.PromptTokens)
|
usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, usage.PromptTokens)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return nil, usage
|
return nil, usage
|
||||||
|
@ -36,7 +36,7 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
|
|||||||
|
|
||||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
|
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
|
||||||
if info.IsStream {
|
if info.IsStream {
|
||||||
err, usage = cohereStreamHandler(c, resp, info.UpstreamModelName, info.PromptTokens)
|
err, usage = cohereStreamHandler(c, resp, info)
|
||||||
} else {
|
} else {
|
||||||
err, usage = cohereHandler(c, resp, info.UpstreamModelName, info.PromptTokens)
|
err, usage = cohereHandler(c, resp, info.UpstreamModelName, info.PromptTokens)
|
||||||
}
|
}
|
||||||
|
@ -9,8 +9,10 @@ import (
|
|||||||
"net/http"
|
"net/http"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
"one-api/dto"
|
"one-api/dto"
|
||||||
|
relaycommon "one-api/relay/common"
|
||||||
"one-api/service"
|
"one-api/service"
|
||||||
"strings"
|
"strings"
|
||||||
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
func requestOpenAI2Cohere(textRequest dto.GeneralOpenAIRequest) *CohereRequest {
|
func requestOpenAI2Cohere(textRequest dto.GeneralOpenAIRequest) *CohereRequest {
|
||||||
@ -56,7 +58,7 @@ func stopReasonCohere2OpenAI(reason string) string {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func cohereStreamHandler(c *gin.Context, resp *http.Response, modelName string, promptTokens int) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
func cohereStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
||||||
responseId := fmt.Sprintf("chatcmpl-%s", common.GetUUID())
|
responseId := fmt.Sprintf("chatcmpl-%s", common.GetUUID())
|
||||||
createdTime := common.GetTimestamp()
|
createdTime := common.GetTimestamp()
|
||||||
usage := &dto.Usage{}
|
usage := &dto.Usage{}
|
||||||
@ -84,9 +86,14 @@ func cohereStreamHandler(c *gin.Context, resp *http.Response, modelName string,
|
|||||||
stopChan <- true
|
stopChan <- true
|
||||||
}()
|
}()
|
||||||
service.SetEventStreamHeaders(c)
|
service.SetEventStreamHeaders(c)
|
||||||
|
isFirst := true
|
||||||
c.Stream(func(w io.Writer) bool {
|
c.Stream(func(w io.Writer) bool {
|
||||||
select {
|
select {
|
||||||
case data := <-dataChan:
|
case data := <-dataChan:
|
||||||
|
if isFirst {
|
||||||
|
isFirst = false
|
||||||
|
info.FirstResponseTime = time.Now()
|
||||||
|
}
|
||||||
data = strings.TrimSuffix(data, "\r")
|
data = strings.TrimSuffix(data, "\r")
|
||||||
var cohereResp CohereResponse
|
var cohereResp CohereResponse
|
||||||
err := json.Unmarshal([]byte(data), &cohereResp)
|
err := json.Unmarshal([]byte(data), &cohereResp)
|
||||||
@ -98,7 +105,7 @@ func cohereStreamHandler(c *gin.Context, resp *http.Response, modelName string,
|
|||||||
openaiResp.Id = responseId
|
openaiResp.Id = responseId
|
||||||
openaiResp.Created = createdTime
|
openaiResp.Created = createdTime
|
||||||
openaiResp.Object = "chat.completion.chunk"
|
openaiResp.Object = "chat.completion.chunk"
|
||||||
openaiResp.Model = modelName
|
openaiResp.Model = info.UpstreamModelName
|
||||||
if cohereResp.IsFinished {
|
if cohereResp.IsFinished {
|
||||||
finishReason := stopReasonCohere2OpenAI(cohereResp.FinishReason)
|
finishReason := stopReasonCohere2OpenAI(cohereResp.FinishReason)
|
||||||
openaiResp.Choices = []dto.ChatCompletionsStreamResponseChoice{
|
openaiResp.Choices = []dto.ChatCompletionsStreamResponseChoice{
|
||||||
@ -137,7 +144,7 @@ func cohereStreamHandler(c *gin.Context, resp *http.Response, modelName string,
|
|||||||
}
|
}
|
||||||
})
|
})
|
||||||
if usage.PromptTokens == 0 {
|
if usage.PromptTokens == 0 {
|
||||||
usage, _ = service.ResponseText2Usage(responseText, modelName, promptTokens)
|
usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
|
||||||
}
|
}
|
||||||
return nil, usage
|
return nil, usage
|
||||||
}
|
}
|
||||||
|
@ -20,27 +20,27 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIReq
|
|||||||
|
|
||||||
// 定义一个映射,存储模型名称和对应的版本
|
// 定义一个映射,存储模型名称和对应的版本
|
||||||
var modelVersionMap = map[string]string{
|
var modelVersionMap = map[string]string{
|
||||||
"gemini-1.5-pro-latest": "v1beta",
|
"gemini-1.5-pro-latest": "v1beta",
|
||||||
"gemini-1.5-flash-latest": "v1beta",
|
"gemini-1.5-flash-latest": "v1beta",
|
||||||
"gemini-ultra": "v1beta",
|
"gemini-ultra": "v1beta",
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||||
// 从映射中获取模型名称对应的版本,如果找不到就使用 info.ApiVersion 或默认的版本 "v1"
|
// 从映射中获取模型名称对应的版本,如果找不到就使用 info.ApiVersion 或默认的版本 "v1"
|
||||||
version, beta := modelVersionMap[info.UpstreamModelName]
|
version, beta := modelVersionMap[info.UpstreamModelName]
|
||||||
if !beta {
|
if !beta {
|
||||||
if info.ApiVersion != "" {
|
if info.ApiVersion != "" {
|
||||||
version = info.ApiVersion
|
version = info.ApiVersion
|
||||||
} else {
|
} else {
|
||||||
version = "v1"
|
version = "v1"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
action := "generateContent"
|
action := "generateContent"
|
||||||
if info.IsStream {
|
if info.IsStream {
|
||||||
action = "streamGenerateContent"
|
action = "streamGenerateContent"
|
||||||
}
|
}
|
||||||
return fmt.Sprintf("%s/%s/models/%s:%s", info.BaseUrl, version, info.UpstreamModelName, action), nil
|
return fmt.Sprintf("%s/%s/models/%s:%s", info.BaseUrl, version, info.UpstreamModelName, action), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
|
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
|
||||||
@ -63,7 +63,7 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
|
|||||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
|
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
|
||||||
if info.IsStream {
|
if info.IsStream {
|
||||||
var responseText string
|
var responseText string
|
||||||
err, responseText = geminiChatStreamHandler(c, resp)
|
err, responseText = geminiChatStreamHandler(c, resp, info)
|
||||||
usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
|
usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
|
||||||
} else {
|
} else {
|
||||||
err, usage = geminiChatHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
|
err, usage = geminiChatHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
|
||||||
|
@ -7,10 +7,12 @@ import (
|
|||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
|
"one-api/constant"
|
||||||
"one-api/dto"
|
"one-api/dto"
|
||||||
relaycommon "one-api/relay/common"
|
relaycommon "one-api/relay/common"
|
||||||
"one-api/service"
|
"one-api/service"
|
||||||
"strings"
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
)
|
)
|
||||||
@ -160,10 +162,10 @@ func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) *dto.Ch
|
|||||||
return &response
|
return &response
|
||||||
}
|
}
|
||||||
|
|
||||||
func geminiChatStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, string) {
|
func geminiChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, string) {
|
||||||
responseText := ""
|
responseText := ""
|
||||||
dataChan := make(chan string)
|
dataChan := make(chan string, 5)
|
||||||
stopChan := make(chan bool)
|
stopChan := make(chan bool, 2)
|
||||||
scanner := bufio.NewScanner(resp.Body)
|
scanner := bufio.NewScanner(resp.Body)
|
||||||
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
|
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
|
||||||
if atEOF && len(data) == 0 {
|
if atEOF && len(data) == 0 {
|
||||||
@ -186,14 +188,23 @@ func geminiChatStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIEr
|
|||||||
}
|
}
|
||||||
data = strings.TrimPrefix(data, "\"text\": \"")
|
data = strings.TrimPrefix(data, "\"text\": \"")
|
||||||
data = strings.TrimSuffix(data, "\"")
|
data = strings.TrimSuffix(data, "\"")
|
||||||
dataChan <- data
|
if !common.SafeSendStringTimeout(dataChan, data, constant.StreamingTimeout) {
|
||||||
|
// send data timeout, stop the stream
|
||||||
|
common.LogError(c, "send data timeout, stop the stream")
|
||||||
|
break
|
||||||
|
}
|
||||||
}
|
}
|
||||||
stopChan <- true
|
stopChan <- true
|
||||||
}()
|
}()
|
||||||
|
isFirst := true
|
||||||
service.SetEventStreamHeaders(c)
|
service.SetEventStreamHeaders(c)
|
||||||
c.Stream(func(w io.Writer) bool {
|
c.Stream(func(w io.Writer) bool {
|
||||||
select {
|
select {
|
||||||
case data := <-dataChan:
|
case data := <-dataChan:
|
||||||
|
if isFirst {
|
||||||
|
isFirst = false
|
||||||
|
info.FirstResponseTime = time.Now()
|
||||||
|
}
|
||||||
// this is used to prevent annoying \ related format bug
|
// this is used to prevent annoying \ related format bug
|
||||||
data = fmt.Sprintf("{\"content\": \"%s\"}", data)
|
data = fmt.Sprintf("{\"content\": \"%s\"}", data)
|
||||||
type dummyStruct struct {
|
type dummyStruct struct {
|
||||||
|
@ -52,7 +52,7 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
|
|||||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
|
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
|
||||||
if info.IsStream {
|
if info.IsStream {
|
||||||
var responseText string
|
var responseText string
|
||||||
err, responseText, _ = openai.OpenaiStreamHandler(c, resp, info.RelayMode)
|
err, responseText, _ = openai.OpenaiStreamHandler(c, resp, info)
|
||||||
usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
|
usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
|
||||||
} else {
|
} else {
|
||||||
if info.RelayMode == relayconstant.RelayModeEmbeddings {
|
if info.RelayMode == relayconstant.RelayModeEmbeddings {
|
||||||
|
@ -82,7 +82,7 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
|
|||||||
if info.IsStream {
|
if info.IsStream {
|
||||||
var responseText string
|
var responseText string
|
||||||
var toolCount int
|
var toolCount int
|
||||||
err, responseText, toolCount = OpenaiStreamHandler(c, resp, info.RelayMode)
|
err, responseText, toolCount = OpenaiStreamHandler(c, resp, info)
|
||||||
usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
|
usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
|
||||||
usage.CompletionTokens += toolCount * 7
|
usage.CompletionTokens += toolCount * 7
|
||||||
} else {
|
} else {
|
||||||
|
@ -8,7 +8,9 @@ import (
|
|||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
|
"one-api/constant"
|
||||||
"one-api/dto"
|
"one-api/dto"
|
||||||
|
relaycommon "one-api/relay/common"
|
||||||
relayconstant "one-api/relay/constant"
|
relayconstant "one-api/relay/constant"
|
||||||
"one-api/service"
|
"one-api/service"
|
||||||
"strings"
|
"strings"
|
||||||
@ -16,7 +18,7 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
func OpenaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*dto.OpenAIErrorWithStatusCode, string, int) {
|
func OpenaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, string, int) {
|
||||||
//checkSensitive := constant.ShouldCheckCompletionSensitive()
|
//checkSensitive := constant.ShouldCheckCompletionSensitive()
|
||||||
var responseTextBuilder strings.Builder
|
var responseTextBuilder strings.Builder
|
||||||
toolCount := 0
|
toolCount := 0
|
||||||
@ -50,14 +52,18 @@ func OpenaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*d
|
|||||||
if data[:6] != "data: " && data[:6] != "[DONE]" {
|
if data[:6] != "data: " && data[:6] != "[DONE]" {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
common.SafeSendString(dataChan, data)
|
if !common.SafeSendStringTimeout(dataChan, data, constant.StreamingTimeout) {
|
||||||
|
// send data timeout, stop the stream
|
||||||
|
common.LogError(c, "send data timeout, stop the stream")
|
||||||
|
break
|
||||||
|
}
|
||||||
data = data[6:]
|
data = data[6:]
|
||||||
if !strings.HasPrefix(data, "[DONE]") {
|
if !strings.HasPrefix(data, "[DONE]") {
|
||||||
streamItems = append(streamItems, data)
|
streamItems = append(streamItems, data)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
streamResp := "[" + strings.Join(streamItems, ",") + "]"
|
streamResp := "[" + strings.Join(streamItems, ",") + "]"
|
||||||
switch relayMode {
|
switch info.RelayMode {
|
||||||
case relayconstant.RelayModeChatCompletions:
|
case relayconstant.RelayModeChatCompletions:
|
||||||
var streamResponses []dto.ChatCompletionsStreamResponseSimple
|
var streamResponses []dto.ChatCompletionsStreamResponseSimple
|
||||||
err := json.Unmarshal(common.StringToByteSlice(streamResp), &streamResponses)
|
err := json.Unmarshal(common.StringToByteSlice(streamResp), &streamResponses)
|
||||||
@ -126,9 +132,14 @@ func OpenaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*d
|
|||||||
common.SafeSendBool(stopChan, true)
|
common.SafeSendBool(stopChan, true)
|
||||||
}()
|
}()
|
||||||
service.SetEventStreamHeaders(c)
|
service.SetEventStreamHeaders(c)
|
||||||
|
isFirst := true
|
||||||
c.Stream(func(w io.Writer) bool {
|
c.Stream(func(w io.Writer) bool {
|
||||||
select {
|
select {
|
||||||
case data := <-dataChan:
|
case data := <-dataChan:
|
||||||
|
if isFirst {
|
||||||
|
isFirst = false
|
||||||
|
info.FirstResponseTime = time.Now()
|
||||||
|
}
|
||||||
if strings.HasPrefix(data, "data: [DONE]") {
|
if strings.HasPrefix(data, "data: [DONE]") {
|
||||||
data = data[:12]
|
data = data[:12]
|
||||||
}
|
}
|
||||||
@ -187,7 +198,7 @@ func OpenaiHandler(c *gin.Context, resp *http.Response, promptTokens int, model
|
|||||||
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if simpleResponse.Usage.TotalTokens == 0 {
|
if simpleResponse.Usage.TotalTokens == 0 || (simpleResponse.Usage.PromptTokens == 0 && simpleResponse.Usage.CompletionTokens == 0) {
|
||||||
completionTokens := 0
|
completionTokens := 0
|
||||||
for _, choice := range simpleResponse.Choices {
|
for _, choice := range simpleResponse.Choices {
|
||||||
ctkm, _ := service.CountTokenText(string(choice.Message.Content), model)
|
ctkm, _ := service.CountTokenText(string(choice.Message.Content), model)
|
||||||
|
@ -46,7 +46,7 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
|
|||||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
|
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
|
||||||
if info.IsStream {
|
if info.IsStream {
|
||||||
var responseText string
|
var responseText string
|
||||||
err, responseText, _ = openai.OpenaiStreamHandler(c, resp, info.RelayMode)
|
err, responseText, _ = openai.OpenaiStreamHandler(c, resp, info)
|
||||||
usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
|
usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
|
||||||
} else {
|
} else {
|
||||||
err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
|
err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
|
||||||
|
@ -6,6 +6,7 @@ var ModelList = []string{
|
|||||||
"SparkDesk-v2.1",
|
"SparkDesk-v2.1",
|
||||||
"SparkDesk-v3.1",
|
"SparkDesk-v3.1",
|
||||||
"SparkDesk-v3.5",
|
"SparkDesk-v3.5",
|
||||||
|
"SparkDesk-v4.0",
|
||||||
}
|
}
|
||||||
|
|
||||||
var ChannelName = "xunfei"
|
var ChannelName = "xunfei"
|
||||||
|
@ -252,6 +252,8 @@ func apiVersion2domain(apiVersion string) string {
|
|||||||
return "generalv3"
|
return "generalv3"
|
||||||
case "v3.5":
|
case "v3.5":
|
||||||
return "generalv3.5"
|
return "generalv3.5"
|
||||||
|
case "v4.0":
|
||||||
|
return "4.0Ultra"
|
||||||
}
|
}
|
||||||
return "general" + apiVersion
|
return "general" + apiVersion
|
||||||
}
|
}
|
||||||
|
@ -48,7 +48,7 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
|
|||||||
if info.IsStream {
|
if info.IsStream {
|
||||||
var responseText string
|
var responseText string
|
||||||
var toolCount int
|
var toolCount int
|
||||||
err, responseText, toolCount = openai.OpenaiStreamHandler(c, resp, info.RelayMode)
|
err, responseText, toolCount = openai.OpenaiStreamHandler(c, resp, info)
|
||||||
usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
|
usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
|
||||||
usage.CompletionTokens += toolCount * 7
|
usage.CompletionTokens += toolCount * 7
|
||||||
} else {
|
} else {
|
||||||
|
@ -16,6 +16,7 @@ type RelayInfo struct {
|
|||||||
Group string
|
Group string
|
||||||
TokenUnlimited bool
|
TokenUnlimited bool
|
||||||
StartTime time.Time
|
StartTime time.Time
|
||||||
|
FirstResponseTime time.Time
|
||||||
ApiType int
|
ApiType int
|
||||||
IsStream bool
|
IsStream bool
|
||||||
RelayMode int
|
RelayMode int
|
||||||
@ -37,24 +38,26 @@ func GenRelayInfo(c *gin.Context) *RelayInfo {
|
|||||||
group := c.GetString("group")
|
group := c.GetString("group")
|
||||||
tokenUnlimited := c.GetBool("token_unlimited_quota")
|
tokenUnlimited := c.GetBool("token_unlimited_quota")
|
||||||
startTime := time.Now()
|
startTime := time.Now()
|
||||||
|
// firstResponseTime = time.Now() - 1 second
|
||||||
|
|
||||||
apiType, _ := constant.ChannelType2APIType(channelType)
|
apiType, _ := constant.ChannelType2APIType(channelType)
|
||||||
|
|
||||||
info := &RelayInfo{
|
info := &RelayInfo{
|
||||||
RelayMode: constant.Path2RelayMode(c.Request.URL.Path),
|
RelayMode: constant.Path2RelayMode(c.Request.URL.Path),
|
||||||
BaseUrl: c.GetString("base_url"),
|
BaseUrl: c.GetString("base_url"),
|
||||||
RequestURLPath: c.Request.URL.String(),
|
RequestURLPath: c.Request.URL.String(),
|
||||||
ChannelType: channelType,
|
ChannelType: channelType,
|
||||||
ChannelId: channelId,
|
ChannelId: channelId,
|
||||||
TokenId: tokenId,
|
TokenId: tokenId,
|
||||||
UserId: userId,
|
UserId: userId,
|
||||||
Group: group,
|
Group: group,
|
||||||
TokenUnlimited: tokenUnlimited,
|
TokenUnlimited: tokenUnlimited,
|
||||||
StartTime: startTime,
|
StartTime: startTime,
|
||||||
ApiType: apiType,
|
FirstResponseTime: startTime.Add(-time.Second),
|
||||||
ApiVersion: c.GetString("api_version"),
|
ApiType: apiType,
|
||||||
ApiKey: strings.TrimPrefix(c.Request.Header.Get("Authorization"), "Bearer "),
|
ApiVersion: c.GetString("api_version"),
|
||||||
Organization: c.GetString("channel_organization"),
|
ApiKey: strings.TrimPrefix(c.Request.Header.Get("Authorization"), "Bearer "),
|
||||||
|
Organization: c.GetString("channel_organization"),
|
||||||
}
|
}
|
||||||
if info.BaseUrl == "" {
|
if info.BaseUrl == "" {
|
||||||
info.BaseUrl = common.ChannelBaseURLs[channelType]
|
info.BaseUrl = common.ChannelBaseURLs[channelType]
|
||||||
|
@ -538,7 +538,16 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
|
|||||||
ChannelId: c.GetInt("channel_id"),
|
ChannelId: c.GetInt("channel_id"),
|
||||||
Quota: quota,
|
Quota: quota,
|
||||||
}
|
}
|
||||||
|
if midjResponse.Code == 3 {
|
||||||
|
//无实例账号自动禁用渠道(No available account instance)
|
||||||
|
channel, err := model.GetChannelById(midjourneyTask.ChannelId, true)
|
||||||
|
if err != nil {
|
||||||
|
common.SysError("get_channel_null: " + err.Error())
|
||||||
|
}
|
||||||
|
if channel.AutoBan != nil && *channel.AutoBan == 1 {
|
||||||
|
model.UpdateChannelStatusById(midjourneyTask.ChannelId, 2, "No available account instance")
|
||||||
|
}
|
||||||
|
}
|
||||||
if midjResponse.Code != 1 && midjResponse.Code != 21 && midjResponse.Code != 22 {
|
if midjResponse.Code != 1 && midjResponse.Code != 21 && midjResponse.Code != 22 {
|
||||||
//非1-提交成功,21-任务已存在和22-排队中,则记录错误原因
|
//非1-提交成功,21-任务已存在和22-排队中,则记录错误原因
|
||||||
midjourneyTask.FailReason = midjResponse.Description
|
midjourneyTask.FailReason = midjResponse.Description
|
||||||
|
@ -344,14 +344,7 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, textRe
|
|||||||
logModel = "g-*"
|
logModel = "g-*"
|
||||||
logContent += fmt.Sprintf(",模型 %s", textRequest.Model)
|
logContent += fmt.Sprintf(",模型 %s", textRequest.Model)
|
||||||
}
|
}
|
||||||
other := make(map[string]interface{})
|
other := service.GenerateTextOtherInfo(ctx, relayInfo, modelRatio, groupRatio, completionRatio, modelPrice)
|
||||||
other["model_ratio"] = modelRatio
|
|
||||||
other["group_ratio"] = groupRatio
|
|
||||||
other["completion_ratio"] = completionRatio
|
|
||||||
other["model_price"] = modelPrice
|
|
||||||
adminInfo := make(map[string]interface{})
|
|
||||||
adminInfo["use_channel"] = ctx.GetStringSlice("use_channel")
|
|
||||||
other["admin_info"] = adminInfo
|
|
||||||
model.RecordConsumeLog(ctx, relayInfo.UserId, relayInfo.ChannelId, promptTokens, completionTokens, logModel, tokenName, quota, logContent, relayInfo.TokenId, userQuota, int(useTimeSeconds), relayInfo.IsStream, other)
|
model.RecordConsumeLog(ctx, relayInfo.UserId, relayInfo.ChannelId, promptTokens, completionTokens, logModel, tokenName, quota, logContent, relayInfo.TokenId, userQuota, int(useTimeSeconds), relayInfo.IsStream, other)
|
||||||
|
|
||||||
//if quota != 0 {
|
//if quota != 0 {
|
||||||
|
19
service/log.go
Normal file
19
service/log.go
Normal file
@ -0,0 +1,19 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
relaycommon "one-api/relay/common"
|
||||||
|
)
|
||||||
|
|
||||||
|
func GenerateTextOtherInfo(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, modelRatio, groupRatio, completionRatio, modelPrice float64) map[string]interface{} {
|
||||||
|
other := make(map[string]interface{})
|
||||||
|
other["model_ratio"] = modelRatio
|
||||||
|
other["group_ratio"] = groupRatio
|
||||||
|
other["completion_ratio"] = completionRatio
|
||||||
|
other["model_price"] = modelPrice
|
||||||
|
other["frt"] = float64(relayInfo.FirstResponseTime.UnixMilli() - relayInfo.StartTime.UnixMilli())
|
||||||
|
adminInfo := make(map[string]interface{})
|
||||||
|
adminInfo["use_channel"] = ctx.GetStringSlice("use_channel")
|
||||||
|
other["admin_info"] = adminInfo
|
||||||
|
return other
|
||||||
|
}
|
@ -3,7 +3,6 @@ package service
|
|||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"one-api/common"
|
|
||||||
"one-api/constant"
|
"one-api/constant"
|
||||||
"one-api/dto"
|
"one-api/dto"
|
||||||
"strings"
|
"strings"
|
||||||
@ -62,7 +61,7 @@ func SensitiveWordContains(text string) (bool, []string) {
|
|||||||
}
|
}
|
||||||
checkText := strings.ToLower(text)
|
checkText := strings.ToLower(text)
|
||||||
// 构建一个AC自动机
|
// 构建一个AC自动机
|
||||||
m := common.InitAc()
|
m := InitAc()
|
||||||
hits := m.MultiPatternSearch([]rune(checkText), false)
|
hits := m.MultiPatternSearch([]rune(checkText), false)
|
||||||
if len(hits) > 0 {
|
if len(hits) > 0 {
|
||||||
words := make([]string, 0)
|
words := make([]string, 0)
|
||||||
@ -80,7 +79,7 @@ func SensitiveWordReplace(text string, returnImmediately bool) (bool, []string,
|
|||||||
return false, nil, text
|
return false, nil, text
|
||||||
}
|
}
|
||||||
checkText := strings.ToLower(text)
|
checkText := strings.ToLower(text)
|
||||||
m := common.InitAc()
|
m := InitAc()
|
||||||
hits := m.MultiPatternSearch([]rune(checkText), returnImmediately)
|
hits := m.MultiPatternSearch([]rune(checkText), returnImmediately)
|
||||||
if len(hits) > 0 {
|
if len(hits) > 0 {
|
||||||
words := make([]string, 0)
|
words := make([]string, 0)
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
package common
|
package service
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
3311
web/pnpm-lock.yaml
3311
web/pnpm-lock.yaml
File diff suppressed because it is too large
Load Diff
@ -29,6 +29,7 @@ import {
|
|||||||
stringToColor,
|
stringToColor,
|
||||||
} from '../helpers/render';
|
} from '../helpers/render';
|
||||||
import Paragraph from '@douyinfe/semi-ui/lib/es/typography/paragraph';
|
import Paragraph from '@douyinfe/semi-ui/lib/es/typography/paragraph';
|
||||||
|
import { getLogOther } from '../helpers/other.js';
|
||||||
|
|
||||||
const { Header } = Layout;
|
const { Header } = Layout;
|
||||||
|
|
||||||
@ -141,6 +142,33 @@ function renderUseTime(type) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
function renderFirstUseTime(type) {
|
||||||
|
let time = parseFloat(type) / 1000.0;
|
||||||
|
time = time.toFixed(1);
|
||||||
|
if (time < 3) {
|
||||||
|
return (
|
||||||
|
<Tag color='green' size='large'>
|
||||||
|
{' '}
|
||||||
|
{time} s{' '}
|
||||||
|
</Tag>
|
||||||
|
);
|
||||||
|
} else if (time < 10) {
|
||||||
|
return (
|
||||||
|
<Tag color='orange' size='large'>
|
||||||
|
{' '}
|
||||||
|
{time} s{' '}
|
||||||
|
</Tag>
|
||||||
|
);
|
||||||
|
} else {
|
||||||
|
return (
|
||||||
|
<Tag color='red' size='large'>
|
||||||
|
{' '}
|
||||||
|
{time} s{' '}
|
||||||
|
</Tag>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
const LogsTable = () => {
|
const LogsTable = () => {
|
||||||
const columns = [
|
const columns = [
|
||||||
{
|
{
|
||||||
@ -247,17 +275,30 @@ const LogsTable = () => {
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
title: '用时',
|
title: '用时/首字',
|
||||||
dataIndex: 'use_time',
|
dataIndex: 'use_time',
|
||||||
render: (text, record, index) => {
|
render: (text, record, index) => {
|
||||||
return (
|
if (record.is_stream) {
|
||||||
<div>
|
let other = getLogOther(record.other);
|
||||||
<Space>
|
return (
|
||||||
{renderUseTime(text)}
|
<div>
|
||||||
{renderIsStream(record.is_stream)}
|
<Space>
|
||||||
</Space>
|
{renderUseTime(text)}
|
||||||
</div>
|
{renderFirstUseTime(other.frt)}
|
||||||
);
|
{renderIsStream(record.is_stream)}
|
||||||
|
</Space>
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
} else {
|
||||||
|
return (
|
||||||
|
<div>
|
||||||
|
<Space>
|
||||||
|
{renderUseTime(text)}
|
||||||
|
{renderIsStream(record.is_stream)}
|
||||||
|
</Space>
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
}
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -325,10 +366,7 @@ const LogsTable = () => {
|
|||||||
title: '详情',
|
title: '详情',
|
||||||
dataIndex: 'content',
|
dataIndex: 'content',
|
||||||
render: (text, record, index) => {
|
render: (text, record, index) => {
|
||||||
if (record.other === '') {
|
let other = getLogOther(record.other);
|
||||||
record.other = '{}';
|
|
||||||
}
|
|
||||||
let other = JSON.parse(record.other);
|
|
||||||
if (other == null) {
|
if (other == null) {
|
||||||
return (
|
return (
|
||||||
<Paragraph
|
<Paragraph
|
||||||
|
7
web/src/helpers/other.js
Normal file
7
web/src/helpers/other.js
Normal file
@ -0,0 +1,7 @@
|
|||||||
|
export function getLogOther(otherStr) {
|
||||||
|
if (otherStr === undefined || otherStr === '') {
|
||||||
|
otherStr = '{}';
|
||||||
|
}
|
||||||
|
let other = JSON.parse(otherStr);
|
||||||
|
return other;
|
||||||
|
}
|
@ -144,28 +144,42 @@ export function renderModelPrice(
|
|||||||
) {
|
) {
|
||||||
// 1 ratio = $0.002 / 1K tokens
|
// 1 ratio = $0.002 / 1K tokens
|
||||||
if (modelPrice !== -1) {
|
if (modelPrice !== -1) {
|
||||||
return '模型价格:$' + modelPrice * groupRatio;
|
return (
|
||||||
|
'模型价格:$' +
|
||||||
|
modelPrice +
|
||||||
|
' * 分组倍率:' +
|
||||||
|
groupRatio +
|
||||||
|
' = $' +
|
||||||
|
modelPrice * groupRatio
|
||||||
|
);
|
||||||
} else {
|
} else {
|
||||||
if (completionRatio === undefined) {
|
if (completionRatio === undefined) {
|
||||||
completionRatio = 0;
|
completionRatio = 0;
|
||||||
}
|
}
|
||||||
// 这里的 *2 是因为 1倍率=0.002刀,请勿删除
|
// 这里的 *2 是因为 1倍率=0.002刀,请勿删除
|
||||||
let inputRatioPrice = modelRatio * 2.0 * groupRatio;
|
let inputRatioPrice = modelRatio * 2.0;
|
||||||
let completionRatioPrice = modelRatio * 2.0 * completionRatio * groupRatio;
|
let completionRatioPrice = modelRatio * 2.0 * completionRatio;
|
||||||
let price =
|
let price =
|
||||||
(inputTokens / 1000000) * inputRatioPrice +
|
(inputTokens / 1000000) * inputRatioPrice +
|
||||||
(completionTokens / 1000000) * completionRatioPrice;
|
(completionTokens / 1000000) * completionRatioPrice;
|
||||||
return (
|
return (
|
||||||
<>
|
<>
|
||||||
<article>
|
<article>
|
||||||
<p>提示 ${inputRatioPrice} / 1M tokens</p>
|
<p>
|
||||||
<p>补全 ${completionRatioPrice} / 1M tokens</p>
|
提示:${inputRatioPrice} * {groupRatio} = $
|
||||||
|
{inputRatioPrice * groupRatio} / 1M tokens
|
||||||
|
</p>
|
||||||
|
<p>
|
||||||
|
补全:${completionRatioPrice} * {groupRatio} = $
|
||||||
|
{completionRatioPrice * groupRatio} / 1M tokens
|
||||||
|
</p>
|
||||||
<p></p>
|
<p></p>
|
||||||
<p>
|
<p>
|
||||||
提示 {inputTokens} tokens / 1M tokens * ${inputRatioPrice} + 补全{' '}
|
提示 {inputTokens} tokens / 1M tokens * ${inputRatioPrice} + 补全{' '}
|
||||||
{completionTokens} tokens / 1M tokens * ${completionRatioPrice} = $
|
{completionTokens} tokens / 1M tokens * ${completionRatioPrice} *
|
||||||
{price.toFixed(6)}
|
分组 {groupRatio} = ${price.toFixed(6)}
|
||||||
</p>
|
</p>
|
||||||
|
<p>仅供参考,以实际扣费为准</p>
|
||||||
</article>
|
</article>
|
||||||
</>
|
</>
|
||||||
);
|
);
|
||||||
|
Loading…
Reference in New Issue
Block a user