Compare commits


59 Commits

Author SHA1 Message Date
CalciumIon
5d0d268c97 fix: epay 2024-08-04 00:18:32 +08:00
CalciumIon
0b4ef42d86 fix: channel type error 2024-08-03 22:41:47 +08:00
CalciumIon
0123ad4d61 fix: inconsistent request id after retry 2024-08-03 17:46:13 +08:00
CalciumIon
5acf074541 chore: improve the auto-disable code 2024-08-03 17:32:28 +08:00
Calcium-Ion
8af0d9f22f Merge pull request #409 from utopeadia/main
Fix README errors
2024-08-03 17:31:08 +08:00
HowieWu
afd328efcf Fix README errors 2024-08-03 17:19:44 +08:00
CalciumIon
dd12a0052f chore: clean up relay code 2024-08-03 17:12:16 +08:00
CalciumIon
fbe6cd75b1 chore: clean up relay code 2024-08-03 17:07:14 +08:00
CalciumIon
8a9ff36fbf chore: clean up relay code 2024-08-03 16:55:29 +08:00
CalciumIon
88ba8a840e feat: improve top-up order numbers 2024-08-03 01:28:18 +08:00
CalciumIon
e504665f68 feat: improve Gemini model version resolution 2024-08-02 17:23:59 +08:00
Calcium-Ion
54657ec27b Merge pull request #405 from utopeadia/main
Modify the GEMINI version acquisition logic and add support for more gemini1.5pro/flash interfaces
2024-08-02 17:13:46 +08:00
Calcium-Ion
ae6b4e0be2 Merge pull request #399 from kakingone/main
add-mjp-discord-upload
2024-08-02 17:10:35 +08:00
HowieWu
fc0db4505c Update README.md
Add documentation for the Gemini version variable
2024-08-02 11:25:41 +08:00
HowieWu
22a98c5879 Change the Gemini version resolution logic
Use the GEMINI_MODEL_API environment variable to override the default version mapping; separate model:version pairs with ","
-e GEMINI_MODEL_API="gemini-1.5-pro-latest:v1beta,gemini-1.5-pro-001:v1beta,gemini-1.5-pro:v1beta,gemini-1.5-flash-latest:v1beta,gemini-1.5-flash-001:v1beta,gemini-1.5-flash:v1beta,gemini-ultra:v1beta,gemini-1.5-pro-exp-0801:v1beta"
2024-08-02 11:20:26 +08:00
CalciumIon
f8f15bd1d0 fix: rpm fuzzy search 2024-08-01 18:14:10 +08:00
CalciumIon
b7690fe17d fix: log fuzzy search 2024-08-01 18:06:25 +08:00
CalciumIon
58b4c237a4 feat: improve rpm query 2024-08-01 17:39:18 +08:00
CalciumIon
54f6e660f1 feat: improve the default start time for log queries 2024-08-01 17:36:26 +08:00
CalciumIon
3b1745c712 feat: improve log query filters 2024-08-01 16:33:59 +08:00
CalciumIon
c92ab3b569 feat: add rpm and tpm data to logs. (close #384) 2024-08-01 16:13:08 +08:00
CalciumIon
1501ccb919 fix: wrong channel name in notifications. #338 2024-07-31 18:20:13 +08:00
Calcium-Ion
7f2a2a7de0 Merge pull request #400 from OswinWu/feat-gitignore-web-dist
feat: ignore npm build dir
2024-07-31 17:14:50 +08:00
Calcium-Ion
cce7d0258f Merge pull request #401 from HynoR/main
Support cloudflare llama3.1-8b
2024-07-31 17:14:30 +08:00
TAKO
c5e8d7ec20 Support cloudflare llama3.1-8b 2024-07-31 17:11:25 +08:00
OswinWu
fe16d51fe4 feat: ignore npm build dir 2024-07-31 16:50:19 +08:00
kakingone
2100d8ee0c addupload 2024-07-31 15:48:51 +08:00
CalciumIon
fbce36238e feat: support dify agent 2024-07-30 17:30:40 +08:00
CalciumIon
a6b6bcfe00 chore: remove useless code 2024-07-28 01:12:26 +08:00
CalciumIon
07e55cc999 chore: update token page 2024-07-28 00:05:53 +08:00
CalciumIon
b16e6bf423 fix: panic when get model ratio (close #392) 2024-07-27 18:09:09 +08:00
CalciumIon
b7bc205b73 feat: print user id when error 2024-07-27 17:55:36 +08:00
CalciumIon
88cc88c5d0 feat: support ollama tools 2024-07-27 17:51:05 +08:00
CalciumIon
ab1d61d910 feat: print user id when error 2024-07-27 17:47:30 +08:00
Calcium-Ion
d4a5df7373 Merge pull request #391 from OswinWu/fix-outlook-smtp
[fix] fix send email error using outlook smtp
2024-07-26 20:24:08 +08:00
CalciumIon
9e610c9429 fix: image quota (close #382) 2024-07-26 18:51:34 +08:00
Oswin
da490db6d3 [fix] fix send email error using outlook smtp 2024-07-26 17:47:36 +08:00
1808837298@qq.com
b8291dcd13 fix: gemini 2024-07-23 18:34:16 +08:00
Calcium-Ion
b0d9756c14 Merge pull request #380 from crabkun/main
fix: panic in the aws claude channel
2024-07-23 18:22:27 +08:00
Calcium-Ion
9dc07a8585 Merge pull request #383 from Yan-Zero/main
fix: the base64 format image_url for gemini
2024-07-23 18:22:06 +08:00
1808837298@qq.com
caaecb8d54 fix: first login error (close #385) 2024-07-23 18:25:43 +08:00
Yan Tau
b9454c3f14 fix: the base64 format image_url for gemini 2024-07-22 21:20:23 +08:00
crabkun
96bdf97194 fix: panic in the aws claude channel 2024-07-21 01:27:29 +08:00
CalciumIon
3875b141c6 fix: gemini stream finish reason (close #378) 2024-07-19 17:16:20 +08:00
CalciumIon
12da7f64cd feat: update log search 2024-07-19 16:04:56 +08:00
CalciumIon
9ef3212e6c feat: update stream_options again 2024-07-19 15:06:07 +08:00
CalciumIon
20da8228df feat: update stream_options 2024-07-19 14:46:25 +08:00
CalciumIon
436d08b48f feat: update stream_options 2024-07-19 14:06:10 +08:00
CalciumIon
ce815a98d0 fix: responses mixed up between users due to nginx caching 2024-07-19 13:39:05 +08:00
CalciumIon
e2cf6b1e14 feat: support gpt-4o-mini image tokens 2024-07-19 12:59:37 +08:00
CalciumIon
733b374596 Update README.md 2024-07-19 03:06:20 +08:00
CalciumIon
56afe47aa8 feat: update model ratio 2024-07-19 01:34:00 +08:00
CalciumIon
67b74ada00 feat: update model ratio 2024-07-19 01:29:08 +08:00
CalciumIon
e84300f4ae chore: gopool 2024-07-19 01:07:37 +08:00
CalciumIon
c9100b219f feat: support ali image 2024-07-19 00:45:52 +08:00
CalciumIon
f96291a25a feat: support gemini tool calling (close #368) 2024-07-18 20:28:47 +08:00
CalciumIon
14bf865034 feat: add UPDATE_TASK env 2024-07-18 17:26:21 +08:00
CalciumIon
70491ea1bb fix: image relay quota 2024-07-18 17:12:28 +08:00
CalciumIon
ae00a99cf5 feat: billing options for media requests 2024-07-18 17:04:19 +08:00
59 changed files with 1035 additions and 531 deletions

.gitignore

@@ -5,4 +5,5 @@ upload
*.db
build
*.db-journal
logs
logs
web/dist

README.md

@@ -2,15 +2,21 @@
# New API
> [!NOTE]
> This is an open-source project, developed on top of [One API](https://github.com/songquanpeng/one-api); thanks to the original author for their selfless contribution.
> Users must comply with OpenAI's [Terms of Use](https://openai.com/policies/terms-of-use) and **applicable laws and regulations**, and must not use it for illegal purposes.
> This is an open-source project, developed on top of [One API](https://github.com/songquanpeng/one-api)
> This project is for personal learning; stability is not guaranteed and no technical support is provided. Users must comply with OpenAI's Terms of Use and applicable laws and regulations, and must not use it for illegal purposes.
> [!IMPORTANT]
> Users must comply with OpenAI's [Terms of Use](https://openai.com/policies/terms-of-use) and **applicable laws and regulations**, and must not use it for illegal purposes.
> This project is for personal learning only; stability is not guaranteed and no technical support is provided.
> In accordance with the [Interim Measures for the Management of Generative AI Services](http://www.cac.gov.cn/2023-07/13/c_1690898327029107.htm), do not provide any unregistered generative AI service to the public in China.
> [!NOTE]
> Latest Docker image: calciumion/new-api:latest
> Update command: docker run --rm -v /var/run/docker.sock:/var/run/docker.sock containrrr/watchtower -cR
> [!TIP]
> Latest Docker image: `calciumion/new-api:latest`
> Default account: root, password: 123456
> Update command:
> ```
> docker run --rm -v /var/run/docker.sock:/var/run/docker.sock containrrr/watchtower -cR
> ```
## Key Changes
The main changes in this fork are:
@@ -18,9 +24,9 @@
1. Brand-new UI (some pages are still being updated)
2. Added support for the [Midjourney-Proxy(Plus)](https://github.com/novicezk/midjourney-proxy) API; see the [integration docs](Midjourney.md)
3. Online top-up support, configurable in the system settings; currently supported payment gateways:
+ [x] Yipay (易支付)
+ [x] Yipay (易支付)
4. Query used quota by key:
+ Together with the [neko-api-key-tool](https://github.com/Calcium-Ion/neko-api-key-tool) project, usage can be queried by key
+ Together with the [neko-api-key-tool](https://github.com/Calcium-Ion/neko-api-key-tool) project, usage can be queried by key
5. Channels display their used quota and support restricting access to a specified organization
6. Pagination supports choosing the number of items per page
7. Compatible with the original One API database; the original one-api.db can be used directly
@@ -51,29 +57,15 @@
You can add the custom model gpt-4-gizmo-* to a channel. This is not an official OpenAI model but a third-party one, and it cannot be called with an official OpenAI key.
## Channel Retry
Channel retry is implemented; the retry count can be configured under `Settings -> Operation Settings -> General Settings`. **Enabling the cache** is recommended.
When retry is enabled, the first retry uses the same priority, the second retry uses the next priority, and so on.
### Cache configuration
1. `REDIS_CONN_STRING`: when set, Redis is used as the cache.
+ Example: `REDIS_CONN_STRING=redis://default:redispw@localhost:49153`
2. `MEMORY_CACHE_ENABLED`: enables the in-memory cache (no need to set this manually when `REDIS_CONN_STRING` is set); user quota updates will lag slightly. Allowed values are `true` and `false`; defaults to `false` when unset.
+ Example: `MEMORY_CACHE_ENABLED=true`
### Why is a request sometimes not retried?
These status codes are not retried: 400, 504, 524
### I want 400 to be retried too
In `Channel -> Edit`, set `Status Code Override` to
```json
{
"400": "500"
}
```
This rewrites 400 errors to 500 errors so that they are retried.
## Extra configuration compared to the original One API
- `STREAMING_TIMEOUT`: timeout for a single streaming reply, default 30 seconds
- `DIFY_DEBUG`: whether the Dify channel sends workflow and node info to the client, default `true`
- `FORCE_STREAM_OPTION`: whether to override the client's stream_options parameter and ask upstream to return streaming usage, default `true`
- `STREAMING_TIMEOUT`: timeout for a single streaming reply, default 30 seconds
- `DIFY_DEBUG`: whether the Dify channel sends workflow and node info to the client, default `true`
- `FORCE_STREAM_OPTION`: whether to override the client's stream_options parameter and ask upstream to return streaming usage, default `true`; recommended to keep enabled, as it does not affect the response for clients that pass stream_options themselves.
- `GET_MEDIA_TOKEN`: whether to count image tokens, default `true`; when disabled, image tokens are no longer computed locally, which may make billing differ from upstream. This option overrides `GET_MEDIA_TOKEN_NOT_STREAM`.
- `GET_MEDIA_TOKEN_NOT_STREAM`: whether to count image tokens in the non-streaming (`stream=false`) case, default `true`
- `UPDATE_TASK`: whether to poll async tasks (Midjourney, Suno) for updates, default `true`; when disabled, task progress is not updated.
- `GEMINI_MODEL_MAP`: pins Gemini models to an API version (v1/v1beta) using "model:version" pairs separated by ",", e.g. -e GEMINI_MODEL_MAP="gemini-1.5-pro-latest:v1beta,gemini-1.5-pro-001:v1beta"; when empty, the default mapping is used
## Deployment
### Deployment requirements
- Local database (SQLite by default; Docker deployments use SQLite and must mount the `/data` directory to the host)
@@ -96,8 +88,25 @@ docker run --name new-api -d --restart always -p 3000:3000 -e TZ=Asia/Shanghai -
docker run --name new-api -d --restart always -p 3000:3000 -e SQL_DSN="root:123456@tcp(BT_panel_server_address:BT_panel_db_port)/BT_panel_db_name" -e TZ=Asia/Shanghai -v /www/wwwroot/new-api:/data calciumion/new-api:latest
# Note: the database must allow remote access, restricted to this server's IP only
```
### Default account and password
Default account: root, password: 123456
## Channel Retry
Channel retry is implemented; the retry count can be configured under `Settings -> Operation Settings -> General Settings`. **Enabling the cache** is recommended.
When retry is enabled, the first retry uses the same priority, the second retry uses the next priority, and so on.
### Cache configuration
1. `REDIS_CONN_STRING`: when set, Redis is used as the cache.
+ Example: `REDIS_CONN_STRING=redis://default:redispw@localhost:49153`
2. `MEMORY_CACHE_ENABLED`: enables the in-memory cache (no need to set this manually when `REDIS_CONN_STRING` is set); user quota updates will lag slightly. Allowed values are `true` and `false`; defaults to `false` when unset.
+ Example: `MEMORY_CACHE_ENABLED=true`
### Why is a request sometimes not retried?
These status codes are not retried: 400, 504, 524
### I want 400 to be retried too
In `Channel -> Edit`, set `Status Code Override` to
```json
{
"400": "500"
}
```
This rewrites 400 errors to 500 errors so that they are retried.
## Midjourney API setup docs
[Integration docs](Midjourney.md)


@@ -0,0 +1,32 @@
package common
import (
"errors"
"net/smtp"
)
type outlookAuth struct {
username, password string
}
func LoginAuth(username, password string) smtp.Auth {
return &outlookAuth{username, password}
}
func (a *outlookAuth) Start(_ *smtp.ServerInfo) (string, []byte, error) {
return "LOGIN", []byte{}, nil
}
func (a *outlookAuth) Next(fromServer []byte, more bool) ([]byte, error) {
if more {
switch string(fromServer) {
case "Username:":
return []byte(a.username), nil
case "Password:":
return []byte(a.password), nil
default:
return nil, errors.New("unknown fromServer")
}
}
return nil, nil
}
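
For context, a minimal sketch of how this authenticator would be wired into the standard library's net/smtp; the host, credentials, and message below are placeholders, not values from this changeset. The fix suggests outlook.com's SMTP servers do not accept the default PLAIN mechanism, so the LOGIN exchange above is substituted:

```go
package main

import (
	"net/smtp"

	"one-api/common" // assumed import path for the LoginAuth helper above
)

func main() {
	// Replaces smtp.PlainAuth for outlook.com accounts, which the PR
	// indicates fail with PLAIN authentication.
	auth := common.LoginAuth("user@outlook.com", "app-password")
	msg := []byte("Subject: test\r\n\r\nhello")
	// SendMail drives Start/Next during the AUTH LOGIN exchange.
	_ = smtp.SendMail("smtp-mail.outlook.com:587", auth, "user@outlook.com",
		[]string{"rcpt@example.com"}, msg)
}
```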


@@ -62,6 +62,9 @@ func SendEmail(subject string, receiver string, content string) error {
if err != nil {
return err
}
} else if strings.HasSuffix(SMTPAccount, "outlook.com") {
auth = LoginAuth(SMTPAccount, SMTPToken)
err = smtp.SendMail(addr, auth, SMTPAccount, to, mail)
} else {
err = smtp.SendMail(addr, auth, SMTPAccount, to, mail)
}


@@ -3,6 +3,7 @@ package common
import (
"encoding/json"
"strings"
"sync"
)
// from songquanpeng/one-api
@@ -31,13 +32,15 @@ var defaultModelRatio = map[string]float64{
"gpt-4-32k": 30,
//"gpt-4-32k-0314": 30, //deprecated
"gpt-4-32k-0613": 30,
"gpt-4-1106-preview": 5, // $0.01 / 1K tokens
"gpt-4-0125-preview": 5, // $0.01 / 1K tokens
"gpt-4-turbo-preview": 5, // $0.01 / 1K tokens
"gpt-4-vision-preview": 5, // $0.01 / 1K tokens
"gpt-4-1106-vision-preview": 5, // $0.01 / 1K tokens
"gpt-4o": 2.5, // $0.01 / 1K tokens
"gpt-4o-2024-05-13": 2.5, // $0.01 / 1K tokens
"gpt-4-1106-preview": 5, // $0.01 / 1K tokens
"gpt-4-0125-preview": 5, // $0.01 / 1K tokens
"gpt-4-turbo-preview": 5, // $0.01 / 1K tokens
"gpt-4-vision-preview": 5, // $0.01 / 1K tokens
"gpt-4-1106-vision-preview": 5, // $0.01 / 1K tokens
"gpt-4o": 2.5, // $0.01 / 1K tokens
"gpt-4o-2024-05-13": 2.5, // $0.01 / 1K tokens
"gpt-4o-mini": 0.075,
"gpt-4o-mini-2024-07-18": 0.075,
"gpt-4-turbo": 5, // $0.01 / 1K tokens
"gpt-4-turbo-2024-04-09": 5, // $0.01 / 1K tokens
"gpt-3.5-turbo": 0.25, // $0.0015 / 1K tokens
@@ -178,10 +181,17 @@ var defaultModelPrice = map[string]float64{
"mj_describe": 0.05,
"mj_upscale": 0.05,
"swap_face": 0.05,
"mj_upload": 0.05,
}
var modelPrice map[string]float64 = nil
var modelRatio map[string]float64 = nil
var (
modelPriceMap = make(map[string]float64)
modelPriceMapMutex = sync.RWMutex{}
)
var (
modelRatioMap map[string]float64 = nil
modelRatioMapMutex = sync.RWMutex{}
)
var CompletionRatio map[string]float64 = nil
var defaultCompletionRatio = map[string]float64{
@@ -189,11 +199,18 @@ var defaultCompletionRatio = map[string]float64{
"gpt-4-all": 2,
}
func ModelPrice2JSONString() string {
if modelPrice == nil {
modelPrice = defaultModelPrice
func GetModelPriceMap() map[string]float64 {
modelPriceMapMutex.Lock()
defer modelPriceMapMutex.Unlock()
if modelPriceMap == nil {
modelPriceMap = defaultModelPrice
}
jsonBytes, err := json.Marshal(modelPrice)
return modelPriceMap
}
func ModelPrice2JSONString() string {
GetModelPriceMap()
jsonBytes, err := json.Marshal(modelPriceMap)
if err != nil {
SysError("error marshalling model price: " + err.Error())
}
@@ -201,19 +218,19 @@ func ModelPrice2JSONString() string {
}
func UpdateModelPriceByJSONString(jsonStr string) error {
modelPrice = make(map[string]float64)
return json.Unmarshal([]byte(jsonStr), &modelPrice)
modelPriceMapMutex.Lock()
defer modelPriceMapMutex.Unlock()
modelPriceMap = make(map[string]float64)
return json.Unmarshal([]byte(jsonStr), &modelPriceMap)
}
// GetModelPrice returns the model's price; it returns -1, false if the model does not exist
func GetModelPrice(name string, printErr bool) (float64, bool) {
if modelPrice == nil {
modelPrice = defaultModelPrice
}
GetModelPriceMap()
if strings.HasPrefix(name, "gpt-4-gizmo") {
name = "gpt-4-gizmo-*"
}
price, ok := modelPrice[name]
price, ok := modelPriceMap[name]
if !ok {
if printErr {
SysError("model price not found: " + name)
@@ -223,18 +240,18 @@ func GetModelPrice(name string, printErr bool) (float64, bool) {
return price, true
}
func GetModelPriceMap() map[string]float64 {
if modelPrice == nil {
modelPrice = defaultModelPrice
func GetModelRatioMap() map[string]float64 {
modelRatioMapMutex.Lock()
defer modelRatioMapMutex.Unlock()
if modelRatioMap == nil {
modelRatioMap = defaultModelRatio
}
return modelPrice
return modelRatioMap
}
func ModelRatio2JSONString() string {
if modelRatio == nil {
modelRatio = defaultModelRatio
}
jsonBytes, err := json.Marshal(modelRatio)
GetModelRatioMap()
jsonBytes, err := json.Marshal(modelRatioMap)
if err != nil {
SysError("error marshalling model ratio: " + err.Error())
}
@@ -242,18 +259,18 @@ func ModelRatio2JSONString() string {
}
func UpdateModelRatioByJSONString(jsonStr string) error {
modelRatio = make(map[string]float64)
return json.Unmarshal([]byte(jsonStr), &modelRatio)
modelRatioMapMutex.Lock()
defer modelRatioMapMutex.Unlock()
modelRatioMap = make(map[string]float64)
return json.Unmarshal([]byte(jsonStr), &modelRatioMap)
}
func GetModelRatio(name string) float64 {
if modelRatio == nil {
modelRatio = defaultModelRatio
}
GetModelRatioMap()
if strings.HasPrefix(name, "gpt-4-gizmo") {
name = "gpt-4-gizmo-*"
}
ratio, ok := modelRatio[name]
ratio, ok := modelRatioMap[name]
if !ok {
SysError("model ratio not found: " + name)
return 30
@@ -305,7 +322,13 @@ func GetCompletionRatio(name string) float64 {
return 4.0 / 3.0
}
if strings.HasPrefix(name, "gpt-4") && !strings.HasSuffix(name, "-all") && !strings.HasSuffix(name, "-gizmo-*") {
if strings.HasPrefix(name, "gpt-4-turbo") || strings.HasSuffix(name, "preview") || strings.HasPrefix(name, "gpt-4o") {
if strings.HasPrefix(name, "gpt-4-turbo") || strings.HasSuffix(name, "preview") {
return 3
}
if strings.HasPrefix(name, "gpt-4o") {
if strings.Contains(name, "mini") {
return 4
}
return 3
}
return 2
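
A self-contained sketch of the locking pattern this change adopts: lazy initialization of a shared map behind a sync.RWMutex, so a JSON-driven writer can no longer race a reader and trigger Go's concurrent map access panic (all names below are illustrative):

```go
package main

import (
	"encoding/json"
	"fmt"
	"sync"
)

var (
	ratioMap   map[string]float64
	ratioMutex = sync.RWMutex{}
)

// getRatioMap mirrors GetModelRatioMap above: initialize lazily under the lock.
func getRatioMap() map[string]float64 {
	ratioMutex.Lock()
	defer ratioMutex.Unlock()
	if ratioMap == nil {
		ratioMap = map[string]float64{"gpt-4o": 2.5}
	}
	return ratioMap
}

// updateRatioMap mirrors UpdateModelRatioByJSONString: replace under the lock.
func updateRatioMap(jsonStr string) error {
	ratioMutex.Lock()
	defer ratioMutex.Unlock()
	ratioMap = make(map[string]float64)
	return json.Unmarshal([]byte(jsonStr), &ratioMap)
}

func main() {
	_ = updateRatioMap(`{"gpt-4o-mini": 0.075}`)
	fmt.Println(getRatioMap()["gpt-4o-mini"]) // 0.075
}
```

Note the getters take a full Lock even on the read path; once initialized, an RLock would suffice, but the full lock keeps the lazy-init branch simple.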


@@ -1,7 +1,10 @@
package constant
import (
"fmt"
"one-api/common"
"os"
"strings"
)
var StreamingTimeout = common.GetEnvOrDefault("STREAMING_TIMEOUT", 30)
@@ -9,3 +12,35 @@ var DifyDebug = common.GetEnvOrDefaultBool("DIFY_DEBUG", true)
// ForceStreamOption overrides the request parameters to force upstream to return usage info
var ForceStreamOption = common.GetEnvOrDefaultBool("FORCE_STREAM_OPTION", true)
var GetMediaToken = common.GetEnvOrDefaultBool("GET_MEDIA_TOKEN", true)
var GetMediaTokenNotStream = common.GetEnvOrDefaultBool("GET_MEDIA_TOKEN_NOT_STREAM", true)
var UpdateTask = common.GetEnvOrDefaultBool("UPDATE_TASK", true)
var GeminiModelMap = map[string]string{
"gemini-1.5-pro-latest": "v1beta",
"gemini-1.5-pro-001": "v1beta",
"gemini-1.5-pro": "v1beta",
"gemini-1.5-pro-exp-0801": "v1beta",
"gemini-1.5-flash-latest": "v1beta",
"gemini-1.5-flash-001": "v1beta",
"gemini-1.5-flash": "v1beta",
"gemini-ultra": "v1beta",
}
func InitEnv() {
modelVersionMapStr := strings.TrimSpace(os.Getenv("GEMINI_MODEL_MAP"))
if modelVersionMapStr == "" {
return
}
for _, pair := range strings.Split(modelVersionMapStr, ",") {
parts := strings.Split(pair, ":")
if len(parts) == 2 {
GeminiModelMap[parts[0]] = parts[1]
} else {
common.SysError(fmt.Sprintf("invalid model version map: %s", pair))
}
}
}
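
A minimal test-style sketch of the override behaviour (the model names are examples; the import path is assumed from the package declaration above):

```go
package constant_test

import (
	"os"
	"testing"

	"one-api/constant"
)

func TestGeminiModelMapOverride(t *testing.T) {
	// Pairs are "model:version", comma-separated, as documented in the README.
	os.Setenv("GEMINI_MODEL_MAP", "gemini-pro:v1,gemini-1.5-flash:v1beta")
	constant.InitEnv()
	if constant.GeminiModelMap["gemini-1.5-flash"] != "v1beta" {
		t.Fatal("override not applied")
	}
}
```

A malformed pair without a colon is logged via common.SysError and skipped rather than aborting startup.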


@@ -27,6 +27,7 @@ const (
MjActionLowVariation = "LOW_VARIATION"
MjActionPan = "PAN"
MjActionSwapFace = "SWAP_FACE"
MjActionUpload = "UPLOAD"
)
var MidjourneyModel2Action = map[string]string{
@@ -45,4 +46,5 @@ var MidjourneyModel2Action = map[string]string{
"mj_low_variation": MjActionLowVariation,
"mj_pan": MjActionPan,
"swap_face": MjActionSwapFace,
"mj_upload": MjActionUpload,
}


@@ -5,6 +5,7 @@ import (
"encoding/json"
"errors"
"fmt"
"github.com/bytedance/gopkg/util/gopool"
"io"
"math"
"net/http"
@@ -217,7 +218,7 @@ func testAllChannels(notify bool) error {
if disableThreshold == 0 {
disableThreshold = 10000000 // an impossible value
}
go func() {
gopool.Go(func() {
for _, channel := range channels {
isChannelEnabled := channel.Status == common.ChannelStatusEnabled
tik := time.Now()
@@ -239,7 +240,7 @@ func testAllChannels(notify bool) error {
}
// parse *int to bool
if channel.AutoBan != nil && *channel.AutoBan == 0 {
if !channel.GetAutoBan() {
ban = false
}
@@ -265,7 +266,7 @@ func testAllChannels(notify bool) error {
common.SysError(fmt.Sprintf("failed to send email: %s", err.Error()))
}
}
}()
})
return nil
}
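
The change from a bare `go func()` to `gopool.Go` runs the channel tests on bytedance/gopkg's shared worker pool; a minimal sketch of the substitution (task body illustrative):

```go
package main

import (
	"fmt"
	"sync"

	"github.com/bytedance/gopkg/util/gopool"
)

func main() {
	var wg sync.WaitGroup
	wg.Add(1)
	// Drop-in replacement for `go func() { ... }()`: the task is scheduled
	// on a capped, reusable worker pool instead of a fresh goroutine.
	gopool.Go(func() {
		defer wg.Done()
		fmt.Println("channel test running on the pool")
	})
	wg.Wait()
}
```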


@@ -2,6 +2,7 @@ package controller
import (
"bytes"
"errors"
"fmt"
"github.com/gin-gonic/gin"
"io"
@@ -39,43 +40,35 @@ func relayHandler(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode
func Relay(c *gin.Context) {
relayMode := constant.Path2RelayMode(c.Request.URL.Path)
retryTimes := common.RetryTimes
requestId := c.GetString(common.RequestIdKey)
channelId := c.GetInt("channel_id")
channelType := c.GetInt("channel_type")
group := c.GetString("group")
originalModel := c.GetString("original_model")
openaiErr := relayHandler(c, relayMode)
c.Set("use_channel", []string{fmt.Sprintf("%d", channelId)})
if openaiErr != nil {
go processChannelError(c, channelId, channelType, openaiErr)
} else {
retryTimes = 0
}
for i := 0; shouldRetry(c, channelId, openaiErr, retryTimes) && i < retryTimes; i++ {
channel, err := model.CacheGetRandomSatisfiedChannel(group, originalModel, i)
var openaiErr *dto.OpenAIErrorWithStatusCode
for i := 0; i <= common.RetryTimes; i++ {
channel, err := getChannel(c, group, originalModel, i)
if err != nil {
common.LogError(c.Request.Context(), fmt.Sprintf("CacheGetRandomSatisfiedChannel failed: %s", err.Error()))
common.LogError(c, err.Error())
openaiErr = service.OpenAIErrorWrapperLocal(err, "get_channel_failed", http.StatusInternalServerError)
break
}
channelId = channel.Id
useChannel := c.GetStringSlice("use_channel")
useChannel = append(useChannel, fmt.Sprintf("%d", channelId))
c.Set("use_channel", useChannel)
common.LogInfo(c.Request.Context(), fmt.Sprintf("using channel #%d to retry (remain times %d)", channel.Id, i))
middleware.SetupContextForSelectedChannel(c, channel, originalModel)
requestBody, err := common.GetRequestBody(c)
c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
openaiErr = relayHandler(c, relayMode)
if openaiErr != nil {
go processChannelError(c, channelId, channel.Type, openaiErr)
openaiErr = relayRequest(c, relayMode, channel)
if openaiErr == nil {
return // request handled successfully, return immediately
}
go processChannelError(c, channel.Id, channel.Type, channel.Name, channel.GetAutoBan(), openaiErr)
if !shouldRetry(c, openaiErr, common.RetryTimes-i) {
break
}
}
useChannel := c.GetStringSlice("use_channel")
if len(useChannel) > 1 {
retryLogStr := fmt.Sprintf("重试:%s", strings.Trim(strings.Join(strings.Fields(fmt.Sprint(useChannel)), "->"), "[]"))
common.LogInfo(c.Request.Context(), retryLogStr)
common.LogInfo(c, retryLogStr)
}
if openaiErr != nil {
@@ -89,7 +82,42 @@ func Relay(c *gin.Context) {
}
}
func shouldRetry(c *gin.Context, channelId int, openaiErr *dto.OpenAIErrorWithStatusCode, retryTimes int) bool {
func relayRequest(c *gin.Context, relayMode int, channel *model.Channel) *dto.OpenAIErrorWithStatusCode {
addUsedChannel(c, channel.Id)
requestBody, _ := common.GetRequestBody(c)
c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
return relayHandler(c, relayMode)
}
func addUsedChannel(c *gin.Context, channelId int) {
useChannel := c.GetStringSlice("use_channel")
useChannel = append(useChannel, fmt.Sprintf("%d", channelId))
c.Set("use_channel", useChannel)
}
func getChannel(c *gin.Context, group, originalModel string, retryCount int) (*model.Channel, error) {
if retryCount == 0 {
autoBan := c.GetBool("auto_ban")
autoBanInt := 1
if !autoBan {
autoBanInt = 0
}
return &model.Channel{
Id: c.GetInt("channel_id"),
Type: c.GetInt("channel_type"),
Name: c.GetString("channel_name"),
AutoBan: &autoBanInt,
}, nil
}
channel, err := model.CacheGetRandomSatisfiedChannel(group, originalModel, retryCount)
if err != nil {
return nil, errors.New(fmt.Sprintf("获取重试渠道失败: %s", err.Error()))
}
middleware.SetupContextForSelectedChannel(c, channel, originalModel)
return channel, nil
}
func shouldRetry(c *gin.Context, openaiErr *dto.OpenAIErrorWithStatusCode, retryTimes int) bool {
if openaiErr == nil {
return false
}
@@ -113,6 +141,10 @@ func shouldRetry(c *gin.Context, channelId int, openaiErr *dto.OpenAIErrorWithSt
return true
}
if openaiErr.StatusCode == http.StatusBadRequest {
channelType := c.GetInt("channel_type")
if channelType == common.ChannelTypeAnthropic {
return true
}
return false
}
if openaiErr.StatusCode == 408 {
@@ -128,11 +160,11 @@ func shouldRetry(c *gin.Context, channelId int, openaiErr *dto.OpenAIErrorWithSt
return true
}
func processChannelError(c *gin.Context, channelId int, channelType int, err *dto.OpenAIErrorWithStatusCode) {
autoBan := c.GetBool("auto_ban")
common.LogError(c.Request.Context(), fmt.Sprintf("relay error (channel #%d, status code: %d): %s", channelId, err.StatusCode, err.Error.Message))
func processChannelError(c *gin.Context, channelId int, channelType int, channelName string, autoBan bool, err *dto.OpenAIErrorWithStatusCode) {
// do not use context to get channel info, there may be inconsistent channel info when processing asynchronously
common.LogError(c, fmt.Sprintf("relay error (channel #%d, status code: %d): %s", channelId, err.StatusCode, err.Error.Message))
if service.ShouldDisableChannel(channelType, err) && autoBan {
channelName := c.GetString("channel_name")
service.DisableChannel(channelId, channelName, err.Error.Message)
}
}
@@ -208,14 +240,14 @@ func RelayTask(c *gin.Context) {
for i := 0; shouldRetryTaskRelay(c, channelId, taskErr, retryTimes) && i < retryTimes; i++ {
channel, err := model.CacheGetRandomSatisfiedChannel(group, originalModel, i)
if err != nil {
common.LogError(c.Request.Context(), fmt.Sprintf("CacheGetRandomSatisfiedChannel failed: %s", err.Error()))
common.LogError(c, fmt.Sprintf("CacheGetRandomSatisfiedChannel failed: %s", err.Error()))
break
}
channelId = channel.Id
useChannel := c.GetStringSlice("use_channel")
useChannel = append(useChannel, fmt.Sprintf("%d", channelId))
c.Set("use_channel", useChannel)
common.LogInfo(c.Request.Context(), fmt.Sprintf("using channel #%d to retry (remain times %d)", channel.Id, i))
common.LogInfo(c, fmt.Sprintf("using channel #%d to retry (remain times %d)", channel.Id, i))
middleware.SetupContextForSelectedChannel(c, channel, originalModel)
requestBody, err := common.GetRequestBody(c)
@@ -225,7 +257,7 @@ func RelayTask(c *gin.Context) {
useChannel := c.GetStringSlice("use_channel")
if len(useChannel) > 1 {
retryLogStr := fmt.Sprintf("重试:%s", strings.Trim(strings.Join(strings.Fields(fmt.Sprint(useChannel)), "->"), "[]"))
common.LogInfo(c.Request.Context(), retryLogStr)
common.LogInfo(c, retryLogStr)
}
if taskErr != nil {
if taskErr.StatusCode == http.StatusTooManyRequests {


@@ -94,6 +94,7 @@ func RequestEpay(c *gin.Context) {
returnUrl, _ := url.Parse(constant.ServerAddress + "/log")
notifyUrl, _ := url.Parse(callBackAddress + "/api/user/epay/notify")
tradeNo := fmt.Sprintf("%s%d", common.GetRandomString(6), time.Now().Unix())
tradeNo = fmt.Sprintf("USR%dNO%s", id, tradeNo)
client := GetEpayClient()
if client == nil {
c.JSON(200, gin.H{"message": "error", "data": "当前管理员未配置支付信息"})
@@ -101,8 +102,8 @@ func RequestEpay(c *gin.Context) {
}
uri, params, err := client.Purchase(&epay.PurchaseArgs{
Type: payType,
ServiceTradeNo: "A" + tradeNo,
Name: "B" + tradeNo,
ServiceTradeNo: tradeNo,
Name: fmt.Sprintf("TUC%d", req.Amount),
Money: strconv.FormatFloat(payMoney, 'f', 2, 64),
Device: epay.PC,
NotifyUrl: notifyUrl,
@@ -120,7 +121,7 @@ func RequestEpay(c *gin.Context) {
UserId: id,
Amount: amount,
Money: payMoney,
TradeNo: "A" + tradeNo,
TradeNo: tradeNo,
CreateTime: time.Now().Unix(),
Status: "pending",
}
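
The effect of the change is that the user id is embedded at the front of the trade number, replacing the old "A"/"B" prefixes, while the product name now carries the top-up amount (TUC&lt;amount&gt;). A small illustration with made-up values:

```go
package main

import (
	"fmt"
	"time"
)

func main() {
	id := 42           // illustrative user id
	random := "aB3dE9" // stands in for common.GetRandomString(6)
	tradeNo := fmt.Sprintf("%s%d", random, time.Now().Unix())
	tradeNo = fmt.Sprintf("USR%dNO%s", id, tradeNo)
	fmt.Println(tradeNo) // e.g. USR42NOaB3dE91722700000
}
```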


@@ -791,11 +791,11 @@ type topUpRequest struct {
Key string `json:"key"`
}
var lock = sync.Mutex{}
var topUpLock = sync.Mutex{}
func TopUp(c *gin.Context) {
lock.Lock()
defer lock.Unlock()
topUpLock.Lock()
defer topUpLock.Unlock()
req := topUpRequest{}
err := c.ShouldBindJSON(&req)
if err != nil {


@@ -12,9 +12,11 @@ type ImageRequest struct {
}
type ImageResponse struct {
Created int `json:"created"`
Data []struct {
Url string `json:"url"`
B64Json string `json:"b64_json"`
}
Data []ImageData `json:"data"`
Created int64 `json:"created"`
}
type ImageData struct {
Url string `json:"url"`
B64Json string `json:"b64_json"`
RevisedPrompt string `json:"revised_prompt"`
}


@@ -33,6 +33,12 @@ type MidjourneyResponse struct {
Result string `json:"result"`
}
type MidjourneyUploadResponse struct {
Code int `json:"code"`
Description string `json:"description"`
Result []string `json:"result"`
}
type MidjourneyResponseWithStatusCode struct {
StatusCode int `json:"statusCode"`
Response MidjourneyResponse


@@ -148,7 +148,7 @@ func (m Message) ParseContent() []MediaMessage {
if ok {
subObj["detail"] = detail.(string)
} else {
subObj["detail"] = "auto"
subObj["detail"] = "high"
}
contentList = append(contentList, MediaMessage{
Type: ContentTypeImageURL,
@@ -157,7 +157,16 @@ func (m Message) ParseContent() []MediaMessage {
Detail: subObj["detail"].(string),
},
})
} else if url, ok := contentMap["image_url"].(string); ok {
contentList = append(contentList, MediaMessage{
Type: ContentTypeImageURL,
ImageUrl: MessageImageUrl{
Url: url,
Detail: "high",
},
})
}
}
}
return contentList

go.mod

@@ -38,6 +38,7 @@ require (
github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.5 // indirect
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.5 // indirect
github.com/aws/smithy-go v1.20.2 // indirect
github.com/bytedance/gopkg v0.0.0-20220118071334-3db87571198b // indirect
github.com/bytedance/sonic v1.9.1 // indirect
github.com/cespare/xxhash/v2 v2.1.2 // indirect
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 // indirect

go.sum

@@ -18,6 +18,8 @@ github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.7.4 h1:JgHnonzbnA3pbqj76w
github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.7.4/go.mod h1:nZspkhg+9p8iApLFoyAqfyuMP0F38acy2Hm3r5r95Cg=
github.com/aws/smithy-go v1.20.2 h1:tbp628ireGtzcHDDmLT/6ADHidqnwgF57XOXZe6tp4Q=
github.com/aws/smithy-go v1.20.2/go.mod h1:krry+ya/rV9RDcV/Q16kpu6ypI4K2czasz0NC3qS14E=
github.com/bytedance/gopkg v0.0.0-20220118071334-3db87571198b h1:LTGVFpNmNHhj0vhOlfgWueFJ32eK9blaIlHR2ciXOT0=
github.com/bytedance/gopkg v0.0.0-20220118071334-3db87571198b/go.mod h1:2ZlV9BaUH4+NXIBF0aMdKKAnHTzqH+iMU4KUjAbL23Q=
github.com/bytedance/sonic v1.5.0/go.mod h1:ED5hyg4y6t3/9Ku1R6dU/4KyJ48DZ4jPhfY1O2AihPM=
github.com/bytedance/sonic v1.9.1 h1:6iJ6NqdoxCDr6mbY8h18oSO+cShGSMRGCEo7F2h0x8s=
github.com/bytedance/sonic v1.9.1/go.mod h1:i736AoUSYt75HyZLoJW9ERYxcy6eaN6h4BZXU064P/U=
@@ -198,6 +200,7 @@ golang.org/x/image v0.15.0/go.mod h1:HUYqC05R2ZcZ3ejNQsIHQDQiwWM4JBqmm6MKANTp4LE
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
golang.org/x/net v0.21.0 h1:AQyQV4dYCvJ7vGmJyKki9+PBdyvhkSd8EIx/qb0AYv4=
golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44=
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.7.0 h1:YsImfSBoP9QPYL0xyKJPq0gcaJdG3rInoqxTWbfQu9M=
golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
@@ -206,6 +209,7 @@ golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7w
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20210806184541-e5e7981a1069/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220110181412-a018aaa089fe/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220704084225-05e143d24a9e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=

main.go

@@ -3,12 +3,14 @@ package main
import (
"embed"
"fmt"
"github.com/bytedance/gopkg/util/gopool"
"github.com/gin-contrib/sessions"
"github.com/gin-contrib/sessions/cookie"
"github.com/gin-gonic/gin"
"log"
"net/http"
"one-api/common"
"one-api/constant"
"one-api/controller"
"one-api/middleware"
"one-api/model"
@@ -53,6 +55,8 @@ func main() {
common.FatalLog("failed to initialize Redis: " + err.Error())
}
// Initialize constants
constant.InitEnv()
// Initialize options
model.InitOptionMap()
if common.RedisEnabled {
@@ -89,11 +93,11 @@ func main() {
}
go controller.AutomaticallyTestChannels(frequency)
}
if common.IsMasterNode {
common.SafeGoroutine(func() {
if common.IsMasterNode && constant.UpdateTask {
gopool.Go(func() {
controller.UpdateMidjourneyTaskBulk()
})
common.SafeGoroutine(func() {
gopool.Go(func() {
controller.UpdateTaskBulk()
})
}


@@ -6,6 +6,7 @@ import (
"net/http"
"one-api/common"
"one-api/model"
"strconv"
"strings"
)
@@ -15,6 +16,7 @@ func authHelper(c *gin.Context, minRole int) {
role := session.Get("role")
id := session.Get("id")
status := session.Get("status")
useAccessToken := false
if username == nil {
// Check access token
accessToken := c.Request.Header.Get("Authorization")
@@ -33,6 +35,7 @@ func authHelper(c *gin.Context, minRole int) {
role = user.Role
id = user.Id
status = user.Status
useAccessToken = true
} else {
c.JSON(http.StatusOK, gin.H{
"success": false,
@@ -42,6 +45,36 @@ func authHelper(c *gin.Context, minRole int) {
return
}
}
if !useAccessToken {
// get header New-Api-User
apiUserIdStr := c.Request.Header.Get("New-Api-User")
if apiUserIdStr == "" {
c.JSON(http.StatusUnauthorized, gin.H{
"success": false,
"message": "无权进行此操作,请刷新页面或清空缓存后重试",
})
c.Abort()
return
}
apiUserId, err := strconv.Atoi(apiUserIdStr)
if err != nil {
c.JSON(http.StatusUnauthorized, gin.H{
"success": false,
"message": "无权进行此操作,登录信息无效,请重新登录",
})
c.Abort()
return
}
if id != apiUserId {
c.JSON(http.StatusUnauthorized, gin.H{
"success": false,
"message": "无权进行此操作,与登录用户不匹配,请重新登录",
})
c.Abort()
return
}
}
if status.(int) == common.UserStatusDisabled {
c.JSON(http.StatusOK, gin.H{
"success": false,


@@ -184,19 +184,13 @@ func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, mode
if channel == nil {
return
}
c.Set("channel", channel.Type)
c.Set("channel_id", channel.Id)
c.Set("channel_name", channel.Name)
c.Set("channel_type", channel.Type)
ban := true
// parse *int to bool
if channel.AutoBan != nil && *channel.AutoBan == 0 {
ban = false
}
if nil != channel.OpenAIOrganization && "" != *channel.OpenAIOrganization {
c.Set("channel_organization", *channel.OpenAIOrganization)
}
c.Set("auto_ban", ban)
c.Set("auto_ban", channel.GetAutoBan())
c.Set("model_mapping", channel.GetModelMapping())
c.Set("status_code_mapping", channel.GetStatusCodeMapping())
c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key))


@@ -1,11 +1,13 @@
package middleware
import (
"fmt"
"github.com/gin-gonic/gin"
"one-api/common"
)
func abortWithOpenAiMessage(c *gin.Context, statusCode int, message string) {
userId := c.GetInt("id")
c.JSON(statusCode, gin.H{
"error": gin.H{
"message": common.MessageWithRequestId(message, c.GetString(common.RequestIdKey)),
@@ -13,7 +15,7 @@ func abortWithOpenAiMessage(c *gin.Context, statusCode int, message string) {
},
})
c.Abort()
common.LogError(c.Request.Context(), message)
common.LogError(c.Request.Context(), fmt.Sprintf("user %d | %s", userId, message))
}
func abortWithMidjourneyMessage(c *gin.Context, statusCode int, code int, description string) {


@@ -61,6 +61,13 @@ func (channel *Channel) SetOtherInfo(otherInfo map[string]interface{}) {
channel.OtherInfo = string(otherInfoBytes)
}
func (channel *Channel) GetAutoBan() bool {
if channel.AutoBan == nil {
return false
}
return *channel.AutoBan == 1
}
func (channel *Channel) Save() error {
return DB.Save(channel).Error
}
@@ -100,8 +107,8 @@ func SearchChannels(keyword string, group string, model string) ([]*Channel, err
var whereClause string
var args []interface{}
if group != "" {
whereClause = "(id = ? OR name LIKE ? OR " + keyCol + " = ?) AND " + groupCol + " LIKE ? AND " + modelsCol + " LIKE ?"
args = append(args, common.String2Int(keyword), "%"+keyword+"%", keyword, "%"+group+"%", "%"+model+"%")
whereClause = "(id = ? OR name LIKE ? OR " + keyCol + " = ?) AND " + groupCol + " = ? AND " + modelsCol + " LIKE ?"
args = append(args, common.String2Int(keyword), "%"+keyword+"%", keyword, group, "%"+model+"%")
} else {
whereClause = "(id = ? OR name LIKE ? OR " + keyCol + " = ?) AND " + modelsCol + " LIKE ?"
args = append(args, common.String2Int(keyword), "%"+keyword+"%", keyword, "%"+model+"%")


@@ -3,9 +3,11 @@ package model
import (
"context"
"fmt"
"github.com/bytedance/gopkg/util/gopool"
"gorm.io/gorm"
"one-api/common"
"strings"
"time"
)
type Log struct {
@@ -87,7 +89,7 @@ func RecordConsumeLog(ctx context.Context, userId int, channelId int, promptToke
common.LogError(ctx, "failed to record log: "+err.Error())
}
if common.DataExportEnabled {
common.SafeGoroutine(func() {
gopool.Go(func() {
LogQuotaData(userId, username, modelName, quota, common.GetTimestamp(), promptTokens+completionTokens)
})
}
@@ -101,7 +103,7 @@ func GetAllLogs(logType int, startTimestamp int64, endTimestamp int64, modelName
tx = DB.Where("type = ?", logType)
}
if modelName != "" {
tx = tx.Where("model_name = ?", modelName)
tx = tx.Where("model_name like ?", modelName)
}
if username != "" {
tx = tx.Where("username = ?", username)
@@ -130,7 +132,7 @@ func GetUserLogs(userId int, logType int, startTimestamp int64, endTimestamp int
tx = DB.Where("user_id = ? and type = ?", userId, logType)
}
if modelName != "" {
tx = tx.Where("model_name = ?", modelName)
tx = tx.Where("model_name like ?", modelName)
}
if tokenName != "" {
tx = tx.Where("token_name = ?", tokenName)
@@ -171,12 +173,18 @@ type Stat struct {
}
func SumUsedQuota(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string, channel int) (stat Stat) {
tx := DB.Table("logs").Select("sum(quota) quota, count(*) rpm, sum(prompt_tokens) + sum(completion_tokens) tpm")
tx := DB.Table("logs").Select("sum(quota) quota")
// build a separate query for rpm and tpm
rpmTpmQuery := DB.Table("logs").Select("count(*) rpm, sum(prompt_tokens) + sum(completion_tokens) tpm")
if username != "" {
tx = tx.Where("username = ?", username)
rpmTpmQuery = rpmTpmQuery.Where("username = ?", username)
}
if tokenName != "" {
tx = tx.Where("token_name = ?", tokenName)
rpmTpmQuery = rpmTpmQuery.Where("token_name = ?", tokenName)
}
if startTimestamp != 0 {
tx = tx.Where("created_at >= ?", startTimestamp)
@@ -186,11 +194,23 @@ func SumUsedQuota(logType int, startTimestamp int64, endTimestamp int64, modelNa
}
if modelName != "" {
tx = tx.Where("model_name = ?", modelName)
rpmTpmQuery = rpmTpmQuery.Where("model_name = ?", modelName)
}
if channel != 0 {
tx = tx.Where("channel_id = ?", channel)
rpmTpmQuery = rpmTpmQuery.Where("channel_id = ?", channel)
}
tx.Where("type = ?", LogTypeConsume).Scan(&stat)
tx = tx.Where("type = ?", LogTypeConsume)
rpmTpmQuery = rpmTpmQuery.Where("type = ?", LogTypeConsume)
// only count rpm and tpm over the last 60 seconds
rpmTpmQuery = rpmTpmQuery.Where("created_at >= ?", time.Now().Add(-60*time.Second).Unix())
// run both queries
tx.Scan(&stat)
rpmTpmQuery.Scan(&stat)
return stat
}
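
Both queries Scan into the same Stat value, which works because each query selects a disjoint set of aliased columns, so the second scan does not overwrite the first. The field set below is an inference from the aliases, since the struct body is elided from this hunk:

```go
package model

// Assumed shape of the Stat struct both Scan calls target; gorm matches the
// selected aliases (quota, rpm, tpm) to fields by name, and each Scan writes
// only the columns its own query selected.
type Stat struct {
	Quota int `json:"quota"`
	Rpm   int `json:"rpm"`
	Tpm   int `json:"tpm"`
}
```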


@@ -2,6 +2,7 @@ package model
import (
"errors"
"github.com/bytedance/gopkg/util/gopool"
"gorm.io/gorm"
"one-api/common"
"sync"
@@ -28,12 +29,12 @@ func init() {
}
func InitBatchUpdater() {
go func() {
gopool.Go(func() {
for {
time.Sleep(time.Duration(common.BatchUpdateInterval) * time.Second)
batchUpdate()
}
}()
})
}
func addNewRecord(type_ int, id int, value int) {


@@ -8,6 +8,7 @@ import (
"net/http"
"one-api/dto"
"one-api/relay/channel"
"one-api/relay/channel/openai"
relaycommon "one-api/relay/common"
"one-api/relay/constant"
)
@@ -15,23 +16,18 @@ import (
type Adaptor struct {
}
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
//TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
//TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
}
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
fullRequestURL := fmt.Sprintf("%s/api/v1/services/aigc/text-generation/generation", info.BaseUrl)
if info.RelayMode == constant.RelayModeEmbeddings {
var fullRequestURL string
switch info.RelayMode {
case constant.RelayModeEmbeddings:
fullRequestURL = fmt.Sprintf("%s/api/v1/services/embeddings/text-embedding/text-embedding", info.BaseUrl)
case constant.RelayModeImagesGenerations:
fullRequestURL = fmt.Sprintf("%s/api/v1/services/aigc/text2image/image-synthesis", info.BaseUrl)
default:
fullRequestURL = fmt.Sprintf("%s/compatible-mode/v1/chat/completions", info.BaseUrl)
}
return fullRequestURL, nil
}
@@ -57,13 +53,23 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, re
baiduEmbeddingRequest := embeddingRequestOpenAI2Ali(*request)
return baiduEmbeddingRequest, nil
default:
baiduRequest := requestOpenAI2Ali(*request)
return baiduRequest, nil
aliReq := requestOpenAI2Ali(*request)
return aliReq, nil
}
}
func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
aliRequest := oaiImage2Ali(request)
return aliRequest, nil
}
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
return nil, nil
return nil, errors.New("not implemented")
}
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
//TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
@@ -71,14 +77,16 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
if info.IsStream {
err, usage = aliStreamHandler(c, resp)
} else {
switch info.RelayMode {
case constant.RelayModeEmbeddings:
err, usage = aliEmbeddingHandler(c, resp)
default:
err, usage = aliHandler(c, resp)
switch info.RelayMode {
case constant.RelayModeImagesGenerations:
err, usage = aliImageHandler(c, resp, info)
case constant.RelayModeEmbeddings:
err, usage = aliEmbeddingHandler(c, resp)
default:
if info.IsStream {
err, usage = openai.OaiStreamHandler(c, resp, info)
} else {
err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
}
}
return


@@ -60,13 +60,40 @@ type AliUsage struct {
TotalTokens int `json:"total_tokens"`
}
type AliOutput struct {
Text string `json:"text"`
FinishReason string `json:"finish_reason"`
type TaskResult struct {
B64Image string `json:"b64_image,omitempty"`
Url string `json:"url,omitempty"`
Code string `json:"code,omitempty"`
Message string `json:"message,omitempty"`
}
type AliChatResponse struct {
type AliOutput struct {
TaskId string `json:"task_id,omitempty"`
TaskStatus string `json:"task_status,omitempty"`
Text string `json:"text"`
FinishReason string `json:"finish_reason"`
Message string `json:"message,omitempty"`
Code string `json:"code,omitempty"`
Results []TaskResult `json:"results,omitempty"`
}
type AliResponse struct {
Output AliOutput `json:"output"`
Usage AliUsage `json:"usage"`
AliError
}
type AliImageRequest struct {
Model string `json:"model"`
Input struct {
Prompt string `json:"prompt"`
NegativePrompt string `json:"negative_prompt,omitempty"`
} `json:"input"`
Parameters struct {
Size string `json:"size,omitempty"`
N int `json:"n,omitempty"`
Steps string `json:"steps,omitempty"`
Scale string `json:"scale,omitempty"`
} `json:"parameters,omitempty"`
ResponseFormat string `json:"response_format,omitempty"`
}

relay/channel/ali/image.go (new file)

@@ -0,0 +1,177 @@
package ali
import (
"encoding/json"
"errors"
"fmt"
"github.com/gin-gonic/gin"
"io"
"net/http"
"one-api/common"
"one-api/dto"
relaycommon "one-api/relay/common"
"one-api/service"
"strings"
"time"
)
func oaiImage2Ali(request dto.ImageRequest) *AliImageRequest {
var imageRequest AliImageRequest
imageRequest.Input.Prompt = request.Prompt
imageRequest.Model = request.Model
imageRequest.Parameters.Size = strings.Replace(request.Size, "x", "*", -1)
imageRequest.Parameters.N = request.N
imageRequest.ResponseFormat = request.ResponseFormat
return &imageRequest
}
func updateTask(info *relaycommon.RelayInfo, taskID string, key string) (*AliResponse, error, []byte) {
url := fmt.Sprintf("/api/v1/tasks/%s", taskID)
var aliResponse AliResponse
req, err := http.NewRequest("GET", url, nil)
if err != nil {
return &aliResponse, err, nil
}
req.Header.Set("Authorization", "Bearer "+key)
client := &http.Client{}
resp, err := client.Do(req)
if err != nil {
common.SysError("updateTask client.Do err: " + err.Error())
return &aliResponse, err, nil
}
defer resp.Body.Close()
responseBody, err := io.ReadAll(resp.Body)
var response AliResponse
err = json.Unmarshal(responseBody, &response)
if err != nil {
common.SysError("updateTask NewDecoder err: " + err.Error())
return &aliResponse, err, nil
}
return &response, nil, responseBody
}
func asyncTaskWait(info *relaycommon.RelayInfo, taskID string, key string) (*AliResponse, []byte, error) {
waitSeconds := 3
step := 0
maxStep := 20
var taskResponse AliResponse
var responseBody []byte
for {
step++
rsp, err, body := updateTask(info, taskID, key)
responseBody = body
if err != nil {
return &taskResponse, responseBody, err
}
if rsp.Output.TaskStatus == "" {
return &taskResponse, responseBody, nil
}
switch rsp.Output.TaskStatus {
case "FAILED":
fallthrough
case "CANCELED":
fallthrough
case "SUCCEEDED":
fallthrough
case "UNKNOWN":
return rsp, responseBody, nil
}
if step >= maxStep {
break
}
time.Sleep(time.Duration(waitSeconds) * time.Second)
}
return nil, nil, fmt.Errorf("aliAsyncTaskWait timeout")
}
func responseAli2OpenAIImage(c *gin.Context, response *AliResponse, info *relaycommon.RelayInfo, responseFormat string) *dto.ImageResponse {
imageResponse := dto.ImageResponse{
Created: info.StartTime.Unix(),
}
for _, data := range response.Output.Results {
var b64Json string
if responseFormat == "b64_json" {
_, b64, err := service.GetImageFromUrl(data.Url)
if err != nil {
common.LogError(c, "get_image_data_failed: "+err.Error())
continue
}
b64Json = b64
} else {
b64Json = data.B64Image
}
imageResponse.Data = append(imageResponse.Data, dto.ImageData{
Url: data.Url,
B64Json: b64Json,
RevisedPrompt: "",
})
}
return &imageResponse
}
func aliImageHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
apiKey := c.Request.Header.Get("Authorization")
apiKey = strings.TrimPrefix(apiKey, "Bearer ")
responseFormat := c.GetString("response_format")
var aliTaskResponse AliResponse
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
}
err = resp.Body.Close()
if err != nil {
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
err = json.Unmarshal(responseBody, &aliTaskResponse)
if err != nil {
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
}
if aliTaskResponse.Message != "" {
common.LogError(c, "ali_async_task_failed: "+aliTaskResponse.Message)
return service.OpenAIErrorWrapper(errors.New(aliTaskResponse.Message), "ali_async_task_failed", http.StatusInternalServerError), nil
}
aliResponse, _, err := asyncTaskWait(info, aliTaskResponse.Output.TaskId, apiKey)
if err != nil {
return service.OpenAIErrorWrapper(err, "ali_async_task_wait_failed", http.StatusInternalServerError), nil
}
if aliResponse.Output.TaskStatus != "SUCCEEDED" {
return &dto.OpenAIErrorWithStatusCode{
Error: dto.OpenAIError{
Message: aliResponse.Output.Message,
Type: "ali_error",
Param: "",
Code: aliResponse.Output.Code,
},
StatusCode: resp.StatusCode,
}, nil
}
fullTextResponse := responseAli2OpenAIImage(c, aliResponse, info, responseFormat)
jsonResponse, err := json.Marshal(fullTextResponse)
if err != nil {
return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
}
c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode)
_, err = c.Writer.Write(jsonResponse)
return nil, nil
}
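
As a worked example of the request conversion at the top of this file: OpenAI image requests use sizes like "1024x1024", while Ali's image-synthesis API expects "1024*1024", hence the strings.Replace in oaiImage2Ali. A tiny sketch with an illustrative value:

```go
package main

import (
	"fmt"
	"strings"
)

func main() {
	openAISize := "1024x1024" // illustrative OpenAI-style size
	aliSize := strings.Replace(openAISize, "x", "*", -1)
	fmt.Println(aliSize) // 1024*1024
}
```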


@@ -16,34 +16,13 @@ import (
const EnableSearchModelSuffix = "-internet"
func requestOpenAI2Ali(request dto.GeneralOpenAIRequest) *AliChatRequest {
messages := make([]AliMessage, 0, len(request.Messages))
//prompt := ""
for i := 0; i < len(request.Messages); i++ {
message := request.Messages[i]
messages = append(messages, AliMessage{
Content: message.StringContent(),
Role: strings.ToLower(message.Role),
})
}
enableSearch := false
aliModel := request.Model
if strings.HasSuffix(aliModel, EnableSearchModelSuffix) {
enableSearch = true
aliModel = strings.TrimSuffix(aliModel, EnableSearchModelSuffix)
}
return &AliChatRequest{
Model: request.Model,
Input: AliInput{
//Prompt: prompt,
Messages: messages,
},
Parameters: AliParameters{
IncrementalOutput: request.Stream,
Seed: uint64(request.Seed),
EnableSearch: enableSearch,
},
func requestOpenAI2Ali(request dto.GeneralOpenAIRequest) *dto.GeneralOpenAIRequest {
if request.TopP >= 1 {
request.TopP = 0.999
} else if request.TopP <= 0 {
request.TopP = 0.001
}
return &request
}
func embeddingRequestOpenAI2Ali(request dto.GeneralOpenAIRequest) *AliEmbeddingRequest {
@@ -110,7 +89,7 @@ func embeddingResponseAli2OpenAI(response *AliEmbeddingResponse) *dto.OpenAIEmbe
return &openAIEmbeddingResponse
}
func responseAli2OpenAI(response *AliChatResponse) *dto.OpenAITextResponse {
func responseAli2OpenAI(response *AliResponse) *dto.OpenAITextResponse {
content, _ := json.Marshal(response.Output.Text)
choice := dto.OpenAITextResponseChoice{
Index: 0,
@@ -134,7 +113,7 @@ func responseAli2OpenAI(response *AliChatResponse) *dto.OpenAITextResponse {
return &fullTextResponse
}
func streamResponseAli2OpenAI(aliResponse *AliChatResponse) *dto.ChatCompletionsStreamResponse {
func streamResponseAli2OpenAI(aliResponse *AliResponse) *dto.ChatCompletionsStreamResponse {
var choice dto.ChatCompletionsStreamResponseChoice
choice.Delta.SetContentString(aliResponse.Output.Text)
if aliResponse.Output.FinishReason != "null" {
@@ -154,18 +133,7 @@ func streamResponseAli2OpenAI(aliResponse *AliChatResponse) *dto.ChatCompletions
func aliStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
var usage dto.Usage
scanner := bufio.NewScanner(resp.Body)
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
if atEOF && len(data) == 0 {
return 0, nil, nil
}
if i := strings.Index(string(data), "\n"); i >= 0 {
return i + 1, data[0:i], nil
}
if atEOF {
return len(data), data, nil
}
return 0, nil, nil
})
scanner.Split(bufio.ScanLines)
dataChan := make(chan string)
stopChan := make(chan bool)
go func() {
@@ -187,7 +155,7 @@ func aliStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWith
c.Stream(func(w io.Writer) bool {
select {
case data := <-dataChan:
var aliResponse AliChatResponse
var aliResponse AliResponse
err := json.Unmarshal([]byte(data), &aliResponse)
if err != nil {
common.SysError("error unmarshalling stream response: " + err.Error())
@@ -221,7 +189,7 @@ func aliStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWith
}
func aliHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
var aliResponse AliChatResponse
var aliResponse AliResponse
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil


@@ -222,9 +222,11 @@ func awsStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
}
}
service.Done(c)
err = resp.Body.Close()
if err != nil {
return service.OpenAIErrorWrapperLocal(err, "close_response_body_failed", http.StatusInternalServerError), nil
if resp != nil {
err = resp.Body.Close()
if err != nil {
return service.OpenAIErrorWrapperLocal(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
}
return nil, &usage
}


@@ -8,12 +8,10 @@ import (
"io"
"net/http"
"one-api/common"
"one-api/constant"
"one-api/dto"
relaycommon "one-api/relay/common"
"one-api/service"
"strings"
"time"
)
func stopReasonClaude2OpenAI(reason string) string {
@@ -332,91 +330,59 @@ func claudeStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.
responseText := ""
createdTime := common.GetTimestamp()
scanner := bufio.NewScanner(resp.Body)
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
if atEOF && len(data) == 0 {
return 0, nil, nil
scanner.Split(bufio.ScanLines)
service.SetEventStreamHeaders(c)
for scanner.Scan() {
data := scanner.Text()
info.SetFirstResponseTime()
if len(data) < 6 || !strings.HasPrefix(data, "data:") {
continue
}
if i := strings.Index(string(data), "\n"); i >= 0 {
return i + 1, data[0:i], nil
data = strings.TrimPrefix(data, "data:")
data = strings.TrimSpace(data)
var claudeResponse ClaudeResponse
err := json.Unmarshal([]byte(data), &claudeResponse)
if err != nil {
common.SysError("error unmarshalling stream response: " + err.Error())
continue
}
if atEOF {
return len(data), data, nil
response, claudeUsage := StreamResponseClaude2OpenAI(requestMode, &claudeResponse)
if response == nil {
continue
}
return 0, nil, nil
})
dataChan := make(chan string, 5)
stopChan := make(chan bool, 2)
go func() {
for scanner.Scan() {
data := scanner.Text()
if !strings.HasPrefix(data, "data: ") {
if requestMode == RequestModeCompletion {
responseText += claudeResponse.Completion
responseId = response.Id
} else {
if claudeResponse.Type == "message_start" {
// message_start: read usage
responseId = claudeResponse.Message.Id
info.UpstreamModelName = claudeResponse.Message.Model
usage.PromptTokens = claudeUsage.InputTokens
} else if claudeResponse.Type == "content_block_delta" {
responseText += claudeResponse.Delta.Text
} else if claudeResponse.Type == "message_delta" {
usage.CompletionTokens = claudeUsage.OutputTokens
usage.TotalTokens = claudeUsage.InputTokens + claudeUsage.OutputTokens
} else if claudeResponse.Type == "content_block_start" {
} else {
continue
}
data = strings.TrimPrefix(data, "data: ")
if !common.SafeSendStringTimeout(dataChan, data, constant.StreamingTimeout) {
// send data timeout, stop the stream
common.LogError(c, "send data timeout, stop the stream")
break
}
}
stopChan <- true
}()
isFirst := true
service.SetEventStreamHeaders(c)
c.Stream(func(w io.Writer) bool {
select {
case data := <-dataChan:
if isFirst {
isFirst = false
info.FirstResponseTime = time.Now()
}
// some implementations may add \r at the end of data
data = strings.TrimSuffix(data, "\r")
var claudeResponse ClaudeResponse
err := json.Unmarshal([]byte(data), &claudeResponse)
if err != nil {
common.SysError("error unmarshalling stream response: " + err.Error())
return true
}
//response.Id = responseId
response.Id = responseId
response.Created = createdTime
response.Model = info.UpstreamModelName
response, claudeUsage := StreamResponseClaude2OpenAI(requestMode, &claudeResponse)
if response == nil {
return true
}
if requestMode == RequestModeCompletion {
responseText += claudeResponse.Completion
responseId = response.Id
} else {
if claudeResponse.Type == "message_start" {
// message_start: read usage
responseId = claudeResponse.Message.Id
info.UpstreamModelName = claudeResponse.Message.Model
usage.PromptTokens = claudeUsage.InputTokens
} else if claudeResponse.Type == "content_block_delta" {
responseText += claudeResponse.Delta.Text
} else if claudeResponse.Type == "message_delta" {
usage.CompletionTokens = claudeUsage.OutputTokens
usage.TotalTokens = claudeUsage.InputTokens + claudeUsage.OutputTokens
} else if claudeResponse.Type == "content_block_start" {
} else {
return true
}
}
//response.Id = responseId
response.Id = responseId
response.Created = createdTime
response.Model = info.UpstreamModelName
err = service.ObjectData(c, response)
if err != nil {
common.SysError(err.Error())
}
return true
case <-stopChan:
return false
err = service.ObjectData(c, response)
if err != nil {
common.LogError(c, "send_stream_response_failed: "+err.Error())
}
})
}
if requestMode == RequestModeCompletion {
usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
} else {
@@ -435,10 +401,7 @@ func claudeStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.
}
}
service.Done(c)
err := resp.Body.Close()
if err != nil {
return service.OpenAIErrorWrapperLocal(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
resp.Body.Close()
return nil, usage
}


@@ -1,6 +1,7 @@
package cloudflare
var ModelList = []string{
"@cf/meta/llama-3.1-8b-instruct",
"@cf/meta/llama-2-7b-chat-fp16",
"@cf/meta/llama-2-7b-chat-int8",
"@cf/mistral/mistral-7b-instruct-v0.1",


@@ -53,7 +53,7 @@ func streamResponseDify2OpenAI(difyResponse DifyChunkChatCompletionResponse) *dt
choice.Delta.SetContentString("Workflow: " + difyResponse.Data.WorkflowId + "\n")
} else if constant.DifyDebug && difyResponse.Event == "node_started" {
choice.Delta.SetContentString("Node: " + difyResponse.Data.NodeId + "\n")
} else if difyResponse.Event == "message" {
} else if difyResponse.Event == "message" || difyResponse.Event == "agent_message" {
choice.Delta.SetContentString(difyResponse.Answer)
}
response.Choices = append(response.Choices, choice)


@@ -6,6 +6,7 @@ import (
"github.com/gin-gonic/gin"
"io"
"net/http"
"one-api/constant"
"one-api/dto"
"one-api/relay/channel"
relaycommon "one-api/relay/common"
@@ -25,18 +26,12 @@ func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInf
}
func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
}
// a map from model name to the API version it requires
var modelVersionMap = map[string]string{
"gemini-1.5-pro-latest": "v1beta",
"gemini-1.5-flash-latest": "v1beta",
"gemini-ultra": "v1beta",
}
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
// look up the model's version in the map; if absent, fall back to info.ApiVersion or the default "v1"
version, beta := modelVersionMap[info.UpstreamModelName]
version, beta := constant.GeminiModelMap[info.UpstreamModelName]
if !beta {
if info.ApiVersion != "" {
version = info.ApiVersion
@@ -47,7 +42,7 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
action := "generateContent"
if info.IsStream {
action = "streamGenerateContent"
action = "streamGenerateContent?alt=sse"
}
return fmt.Sprintf("%s/%s/models/%s:%s", info.BaseUrl, version, info.UpstreamModelName, action), nil
}
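
With `?alt=sse` appended, the Gemini API streams server-sent events rather than a chunked JSON array, so the handler can consume line-delimited `data:` chunks. An illustrative resolved URL, assuming the default public endpoint:

```go
package main

import "fmt"

func main() {
	baseURL := "https://generativelanguage.googleapis.com" // assumed default BaseUrl
	version := "v1beta"
	model := "gemini-1.5-pro-latest"
	action := "streamGenerateContent?alt=sse"
	fmt.Printf("%s/%s/models/%s:%s\n", baseURL, version, model, action)
	// https://generativelanguage.googleapis.com/v1beta/models/gemini-1.5-pro-latest:streamGenerateContent?alt=sse
}
```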


@@ -12,9 +12,15 @@ type GeminiInlineData struct {
Data string `json:"data"`
}
type FunctionCall struct {
FunctionName string `json:"name"`
Arguments any `json:"args"`
}
type GeminiPart struct {
Text string `json:"text,omitempty"`
InlineData *GeminiInlineData `json:"inlineData,omitempty"`
Text string `json:"text,omitempty"`
InlineData *GeminiInlineData `json:"inlineData,omitempty"`
FunctionCall *FunctionCall `json:"functionCall,omitempty"`
}
type GeminiChatContent struct {

View File

@@ -4,18 +4,14 @@ import (
"bufio"
"encoding/json"
"fmt"
"github.com/gin-gonic/gin"
"io"
"log"
"net/http"
"one-api/common"
"one-api/constant"
"one-api/dto"
relaycommon "one-api/relay/common"
"one-api/service"
"strings"
"time"
"github.com/gin-gonic/gin"
)
// Setting safety to the lowest possible values since Gemini is already powerless enough
@@ -46,7 +42,17 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest) *GeminiChatReques
MaxOutputTokens: textRequest.MaxTokens,
},
}
if textRequest.Functions != nil {
if textRequest.Tools != nil {
functions := make([]dto.FunctionCall, 0, len(textRequest.Tools))
for _, tool := range textRequest.Tools {
functions = append(functions, tool.Function)
}
geminiRequest.Tools = []GeminiChatTools{
{
FunctionDeclarations: functions,
},
}
} else if textRequest.Functions != nil {
geminiRequest.Tools = []GeminiChatTools{
{
FunctionDeclarations: textRequest.Functions,
@@ -77,13 +83,28 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest) *GeminiChatReques
if imageNum > GeminiVisionMaxImageNum {
continue
}
mimeType, data, _ := service.GetImageFromUrl(part.ImageUrl.(dto.MessageImageUrl).Url)
parts = append(parts, GeminiPart{
InlineData: &GeminiInlineData{
MimeType: mimeType,
Data: data,
},
})
// Check whether the image is a URL
if strings.HasPrefix(part.ImageUrl.(dto.MessageImageUrl).Url, "http") {
// It is a URL: fetch the image's MIME type and base64-encoded data
mimeType, data, _ := service.GetImageFromUrl(part.ImageUrl.(dto.MessageImageUrl).Url)
parts = append(parts, GeminiPart{
InlineData: &GeminiInlineData{
MimeType: mimeType,
Data: data,
},
})
} else {
_, format, base64String, err := service.DecodeBase64ImageData(part.ImageUrl.(dto.MessageImageUrl).Url)
if err != nil {
continue
}
parts = append(parts, GeminiPart{
InlineData: &GeminiInlineData{
MimeType: "image/" + format,
Data: base64String,
},
})
}
}
}
content.Parts = parts
@@ -126,6 +147,30 @@ func (g *GeminiChatResponse) GetResponseText() string {
return ""
}
func getToolCalls(candidate *GeminiChatCandidate) []dto.ToolCall {
var toolCalls []dto.ToolCall
item := candidate.Content.Parts[0]
if item.FunctionCall == nil {
return toolCalls
}
argsBytes, err := json.Marshal(item.FunctionCall.Arguments)
if err != nil {
//common.SysError("getToolCalls failed: " + err.Error())
return toolCalls
}
toolCall := dto.ToolCall{
ID: fmt.Sprintf("call_%s", common.GetUUID()),
Type: "function",
Function: dto.FunctionCall{
Arguments: string(argsBytes),
Name: item.FunctionCall.FunctionName,
},
}
toolCalls = append(toolCalls, toolCall)
return toolCalls
}
func responseGeminiChat2OpenAI(response *GeminiChatResponse) *dto.OpenAITextResponse {
fullTextResponse := dto.OpenAITextResponse{
Id: fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
@@ -144,8 +189,11 @@ func responseGeminiChat2OpenAI(response *GeminiChatResponse) *dto.OpenAITextResp
FinishReason: relaycommon.StopFinishReason,
}
if len(candidate.Content.Parts) > 0 {
content, _ = json.Marshal(candidate.Content.Parts[0].Text)
choice.Message.Content = content
if candidate.Content.Parts[0].FunctionCall != nil {
choice.Message.ToolCalls = getToolCalls(&candidate)
} else {
choice.Message.SetStringContent(candidate.Content.Parts[0].Text)
}
}
fullTextResponse.Choices = append(fullTextResponse.Choices, choice)
}
@@ -154,8 +202,17 @@ func responseGeminiChat2OpenAI(response *GeminiChatResponse) *dto.OpenAITextResp
func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) *dto.ChatCompletionsStreamResponse {
var choice dto.ChatCompletionsStreamResponseChoice
choice.Delta.SetContentString(geminiResponse.GetResponseText())
choice.FinishReason = &relaycommon.StopFinishReason
//choice.Delta.SetContentString(geminiResponse.GetResponseText())
if len(geminiResponse.Candidates) > 0 && len(geminiResponse.Candidates[0].Content.Parts) > 0 {
respFirst := geminiResponse.Candidates[0].Content.Parts[0]
if respFirst.FunctionCall != nil {
// function response
choice.Delta.ToolCalls = getToolCalls(&geminiResponse.Candidates[0])
} else {
// text response
choice.Delta.SetContentString(respFirst.Text)
}
}
var response dto.ChatCompletionsStreamResponse
response.Object = "chat.completion.chunk"
response.Model = "gemini"
@@ -165,104 +222,60 @@ func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) *dto.Ch
func geminiChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
responseText := ""
responseJson := ""
id := fmt.Sprintf("chatcmpl-%s", common.GetUUID())
createAt := common.GetTimestamp()
var usage = &dto.Usage{}
dataChan := make(chan string, 5)
stopChan := make(chan bool, 2)
scanner := bufio.NewScanner(resp.Body)
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
if atEOF && len(data) == 0 {
return 0, nil, nil
}
if i := strings.Index(string(data), "\n"); i >= 0 {
return i + 1, data[0:i], nil
}
if atEOF {
return len(data), data, nil
}
return 0, nil, nil
})
go func() {
for scanner.Scan() {
data := scanner.Text()
responseJson += data
data = strings.TrimSpace(data)
if !strings.HasPrefix(data, "\"text\": \"") {
continue
}
data = strings.TrimPrefix(data, "\"text\": \"")
data = strings.TrimSuffix(data, "\"")
if !common.SafeSendStringTimeout(dataChan, data, constant.StreamingTimeout) {
// send data timeout, stop the stream
common.LogError(c, "send data timeout, stop the stream")
break
}
}
stopChan <- true
}()
isFirst := true
scanner.Split(bufio.ScanLines)
service.SetEventStreamHeaders(c)
c.Stream(func(w io.Writer) bool {
select {
case data := <-dataChan:
if isFirst {
isFirst = false
info.FirstResponseTime = time.Now()
}
// this is used to prevent an annoying backslash-related format bug
data = fmt.Sprintf("{\"content\": \"%s\"}", data)
type dummyStruct struct {
Content string `json:"content"`
}
var dummy dummyStruct
err := json.Unmarshal([]byte(data), &dummy)
responseText += dummy.Content
var choice dto.ChatCompletionsStreamResponseChoice
choice.Delta.SetContentString(dummy.Content)
response := dto.ChatCompletionsStreamResponse{
Id: id,
Object: "chat.completion.chunk",
Created: createAt,
Model: info.UpstreamModelName,
Choices: []dto.ChatCompletionsStreamResponseChoice{choice},
}
jsonResponse, err := json.Marshal(response)
if err != nil {
common.SysError("error marshalling stream response: " + err.Error())
return true
}
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
return true
case <-stopChan:
return false
for scanner.Scan() {
data := scanner.Text()
info.SetFirstResponseTime()
data = strings.TrimSpace(data)
if !strings.HasPrefix(data, "data: ") {
continue
}
})
var geminiChatResponses []GeminiChatResponse
err := json.Unmarshal([]byte(responseJson), &geminiChatResponses)
if err != nil {
log.Printf("cannot get gemini usage: %s", err.Error())
usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
} else {
for _, response := range geminiChatResponses {
usage.PromptTokens = response.UsageMetadata.PromptTokenCount
usage.CompletionTokens = response.UsageMetadata.CandidatesTokenCount
data = strings.TrimPrefix(data, "data: ")
data = strings.TrimSuffix(data, "\"")
var geminiResponse GeminiChatResponse
err := json.Unmarshal([]byte(data), &geminiResponse)
if err != nil {
common.LogError(c, "error unmarshalling stream response: "+err.Error())
continue
}
response := streamResponseGeminiChat2OpenAI(&geminiResponse)
if response == nil {
continue
}
response.Id = id
response.Created = createAt
responseText += response.Choices[0].Delta.GetContentString()
if geminiResponse.UsageMetadata.TotalTokenCount != 0 {
usage.PromptTokens = geminiResponse.UsageMetadata.PromptTokenCount
usage.CompletionTokens = geminiResponse.UsageMetadata.CandidatesTokenCount
}
err = service.ObjectData(c, response)
if err != nil {
common.LogError(c, err.Error())
}
usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
}
response := service.GenerateStopResponse(id, createAt, info.UpstreamModelName, relaycommon.StopFinishReason)
service.ObjectData(c, response)
usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
if info.ShouldIncludeUsage {
response := service.GenerateFinalUsageResponse(id, createAt, info.UpstreamModelName, *usage)
response = service.GenerateFinalUsageResponse(id, createAt, info.UpstreamModelName, *usage)
err := service.ObjectData(c, response)
if err != nil {
common.SysError("send final response failed: " + err.Error())
}
}
service.Done(c)
err = resp.Body.Close()
if err != nil {
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), usage
}
resp.Body.Close()
return nil, usage
}
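With alt=sse each line of the Gemini response is a complete JSON chunk behind a "data: " prefix, which is why the handler above can drop the old custom split function and JSON-array reassembly. A stripped-down sketch of the parse loop; the chunk struct is a minimal stand-in for GeminiChatResponse:

package example

import (
	"bufio"
	"encoding/json"
	"io"
	"strings"
)

type chunk struct {
	Candidates []struct {
		Content struct {
			Parts []struct {
				Text string `json:"text"`
			} `json:"parts"`
		} `json:"content"`
	} `json:"candidates"`
}

// readSSE scans "data: {...}" lines and returns the concatenated text.
func readSSE(r io.Reader) (string, error) {
	var sb strings.Builder
	scanner := bufio.NewScanner(r)
	for scanner.Scan() {
		line := strings.TrimSpace(scanner.Text())
		if !strings.HasPrefix(line, "data: ") {
			continue // blank keep-alives, comments, etc.
		}
		var c chunk
		if err := json.Unmarshal([]byte(strings.TrimPrefix(line, "data: ")), &c); err != nil {
			continue // skip malformed chunks, as the handler above does
		}
		for _, cand := range c.Candidates {
			for _, p := range cand.Content.Parts {
				sb.WriteString(p.Text)
			}
		}
	}
	return sb.String(), scanner.Err()
}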

View File

@@ -64,7 +64,7 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
if info.IsStream {
err, usage = openai.OpenaiStreamHandler(c, resp, info)
err, usage = openai.OaiStreamHandler(c, resp, info)
} else {
if info.RelayMode == relayconstant.RelayModeEmbeddings {
err, usage = ollamaEmbeddingHandler(c, resp, info.PromptTokens, info.UpstreamModelName, info.RelayMode)

View File

@@ -3,14 +3,18 @@ package ollama
import "one-api/dto"
type OllamaRequest struct {
Model string `json:"model,omitempty"`
Messages []dto.Message `json:"messages,omitempty"`
Stream bool `json:"stream,omitempty"`
Temperature float64 `json:"temperature,omitempty"`
Seed float64 `json:"seed,omitempty"`
Topp float64 `json:"top_p,omitempty"`
TopK int `json:"top_k,omitempty"`
Stop any `json:"stop,omitempty"`
Model string `json:"model,omitempty"`
Messages []dto.Message `json:"messages,omitempty"`
Stream bool `json:"stream,omitempty"`
Temperature float64 `json:"temperature,omitempty"`
Seed float64 `json:"seed,omitempty"`
Topp float64 `json:"top_p,omitempty"`
TopK int `json:"top_k,omitempty"`
Stop any `json:"stop,omitempty"`
Tools []dto.ToolCall `json:"tools,omitempty"`
ResponseFormat *dto.ResponseFormat `json:"response_format,omitempty"`
FrequencyPenalty float64 `json:"frequency_penalty,omitempty"`
PresencePenalty float64 `json:"presence_penalty,omitempty"`
}
type OllamaEmbeddingRequest struct {
@@ -21,6 +25,3 @@ type OllamaEmbeddingRequest struct {
type OllamaEmbeddingResponse struct {
Embedding []float64 `json:"embedding,omitempty"`
}
//type OllamaOptions struct {
//}

View File

@@ -28,14 +28,18 @@ func requestOpenAI2Ollama(request dto.GeneralOpenAIRequest) *OllamaRequest {
Stop, _ = request.Stop.([]string)
}
return &OllamaRequest{
Model: request.Model,
Messages: messages,
Stream: request.Stream,
Temperature: request.Temperature,
Seed: request.Seed,
Topp: request.TopP,
TopK: request.TopK,
Stop: Stop,
Model: request.Model,
Messages: messages,
Stream: request.Stream,
Temperature: request.Temperature,
Seed: request.Seed,
Topp: request.TopP,
TopK: request.TopK,
Stop: Stop,
Tools: request.Tools,
ResponseFormat: request.ResponseFormat,
FrequencyPenalty: request.FrequencyPenalty,
PresencePenalty: request.PresencePenalty,
}
}
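A hedged usage sketch of the widened mapping above; the import path for the ollama package is assumed, and the field values are placeholders:

package example

import (
	"one-api/dto"
	"one-api/relay/channel/ollama" // assumed import path
)

// Illustrative only: the OpenAI-side knobs now survive the conversion
// instead of being silently dropped.
func exampleRequest() *ollama.OllamaRequest {
	return &ollama.OllamaRequest{
		Model:            "llama3",
		Stream:           true,
		Temperature:      0.7,
		FrequencyPenalty: 0.5,
		PresencePenalty:  0.2,
		Tools:            []dto.ToolCall{},      // forwarded verbatim
		ResponseFormat:   &dto.ResponseFormat{}, // e.g. JSON mode
	}
}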

View File

@@ -145,7 +145,7 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
err, usage = OpenaiTTSHandler(c, resp, info)
default:
if info.IsStream {
err, usage = OpenaiStreamHandler(c, resp, info)
err, usage = OaiStreamHandler(c, resp, info)
} else {
err, usage = OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
}

View File

@@ -9,6 +9,7 @@ var ModelList = []string{
"gpt-4-turbo-preview", "gpt-4-turbo", "gpt-4-turbo-2024-04-09",
"gpt-4-vision-preview",
"gpt-4o", "gpt-4o-2024-05-13",
"gpt-4o-mini", "gpt-4o-mini-2024-07-18",
"text-embedding-ada-002", "text-embedding-3-small", "text-embedding-3-large",
"text-curie-001", "text-babbage-001", "text-ada-001",
"text-moderation-latest", "text-moderation-stable",

View File

@@ -5,6 +5,7 @@ import (
"bytes"
"encoding/json"
"fmt"
"github.com/bytedance/gopkg/util/gopool"
"github.com/gin-gonic/gin"
"io"
"net/http"
@@ -15,12 +16,13 @@ import (
relayconstant "one-api/relay/constant"
"one-api/service"
"strings"
"sync"
"time"
)
func OpenaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
hasStreamUsage := false
responseId := ""
func OaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
containStreamUsage := false
var responseId string
var createAt int64 = 0
var systemFingerprint string
model := info.UpstreamModelName
@@ -40,8 +42,11 @@ func OpenaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.
stopChan := make(chan bool)
defer close(stopChan)
go func() {
var (
lastStreamData string
mu sync.Mutex
)
gopool.Go(func() {
for scanner.Scan() {
info.SetFirstResponseTime()
ticker.Reset(time.Duration(constant.StreamingTimeout) * time.Second)
@@ -52,17 +57,22 @@ func OpenaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.
if data[:6] != "data: " && data[:6] != "[DONE]" {
continue
}
mu.Lock()
data = data[6:]
if !strings.HasPrefix(data, "[DONE]") {
err := service.StringData(c, data)
if err != nil {
common.LogError(c, "streaming error: "+err.Error())
if lastStreamData != "" {
err := service.StringData(c, lastStreamData)
if err != nil {
common.LogError(c, "streaming error: "+err.Error())
}
}
lastStreamData = data
streamItems = append(streamItems, data)
}
mu.Unlock()
}
common.SafeSendBool(stopChan, true)
}()
})
select {
case <-ticker.C:
@@ -72,6 +82,26 @@ func OpenaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.
// normal end of stream
}
shouldSendLastResp := true
var lastStreamResponse dto.ChatCompletionsStreamResponse
err := json.Unmarshal(common.StringToByteSlice(lastStreamData), &lastStreamResponse)
if err == nil {
responseId = lastStreamResponse.Id
createAt = lastStreamResponse.Created
systemFingerprint = lastStreamResponse.GetSystemFingerprint()
model = lastStreamResponse.Model
if service.ValidUsage(lastStreamResponse.Usage) {
containStreamUsage = true
usage = lastStreamResponse.Usage
if !info.ShouldIncludeUsage {
shouldSendLastResp = false
}
}
}
if shouldSendLastResp {
service.StringData(c, lastStreamData)
}
// count tokens
streamResp := "[" + strings.Join(streamItems, ",") + "]"
switch info.RelayMode {
@@ -85,14 +115,9 @@ func OpenaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.
var streamResponse dto.ChatCompletionsStreamResponse
err := json.Unmarshal(common.StringToByteSlice(item), &streamResponse)
if err == nil {
responseId = streamResponse.Id
createAt = streamResponse.Created
systemFingerprint = streamResponse.GetSystemFingerprint()
model = streamResponse.Model
if service.ValidUsage(streamResponse.Usage) {
usage = streamResponse.Usage
hasStreamUsage = true
}
//if service.ValidUsage(streamResponse.Usage) {
// usage = streamResponse.Usage
//}
for _, choice := range streamResponse.Choices {
responseTextBuilder.WriteString(choice.Delta.GetContentString())
if choice.Delta.ToolCalls != nil {
@@ -109,14 +134,10 @@ func OpenaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.
}
} else {
for _, streamResponse := range streamResponses {
responseId = streamResponse.Id
createAt = streamResponse.Created
systemFingerprint = streamResponse.GetSystemFingerprint()
model = streamResponse.Model
if service.ValidUsage(streamResponse.Usage) {
usage = streamResponse.Usage
hasStreamUsage = true
}
//if service.ValidUsage(streamResponse.Usage) {
// usage = streamResponse.Usage
// containStreamUsage = true
//}
for _, choice := range streamResponse.Choices {
responseTextBuilder.WriteString(choice.Delta.GetContentString())
if choice.Delta.ToolCalls != nil {
@@ -155,12 +176,12 @@ func OpenaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.
}
}
if !hasStreamUsage {
if !containStreamUsage {
usage, _ = service.ResponseText2Usage(responseTextBuilder.String(), info.UpstreamModelName, info.PromptTokens)
usage.CompletionTokens += toolCount * 7
}
if info.ShouldIncludeUsage && !hasStreamUsage {
if info.ShouldIncludeUsage && !containStreamUsage {
response := service.GenerateFinalUsageResponse(responseId, createAt, model, *usage)
response.SetSystemFingerprint(systemFingerprint)
service.ObjectData(c, response)
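The rewrite above buffers one chunk behind the stream: each new "data:" line flushes the previous one, so the final chunk, which may carry only the usage block, can be inspected and optionally withheld before forwarding. A minimal, self-contained sketch of that hold-back pattern:

package example

import "fmt"

// relayWithHoldback forwards every chunk except the newest one, returning
// the held-back final chunk so the caller can inspect it (e.g. a
// usage-only chunk) before deciding whether to forward it.
func relayWithHoldback(chunks []string, send func(string)) (last string) {
	for _, data := range chunks {
		if last != "" {
			// flush the previously held chunk now that a newer one arrived
			send(last)
		}
		last = data
	}
	return last
}

func main() {
	held := relayWithHoldback(
		[]string{`{"id":"a"}`, `{"id":"b"}`, `{"usage":{"total_tokens":42}}`},
		func(s string) { fmt.Println("sent:", s) },
	)
	fmt.Println("held:", held)
}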

View File

@@ -58,7 +58,7 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
if info.IsStream {
err, usage = openai.OpenaiStreamHandler(c, resp, info)
err, usage = openai.OaiStreamHandler(c, resp, info)
} else {
err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
}

View File

@@ -153,18 +153,7 @@ func streamMetaResponseZhipu2OpenAI(zhipuResponse *ZhipuStreamMetaResponse) (*dt
func zhipuStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
var usage *dto.Usage
scanner := bufio.NewScanner(resp.Body)
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
if atEOF && len(data) == 0 {
return 0, nil, nil
}
if i := strings.Index(string(data), "\n\n"); i >= 0 && strings.Index(string(data), ":") >= 0 {
return i + 2, data[0:i], nil
}
if atEOF {
return len(data), data, nil
}
return 0, nil, nil
})
scanner.Split(bufio.ScanLines)
dataChan := make(chan string)
metaChan := make(chan string)
stopChan := make(chan bool)

View File

@@ -59,7 +59,7 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
if info.IsStream {
err, usage = openai.OpenaiStreamHandler(c, resp, info)
err, usage = openai.OaiStreamHandler(c, resp, info)
} else {
err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
}

View File

@@ -33,7 +33,7 @@ type RelayInfo struct {
}
func GenRelayInfo(c *gin.Context) *RelayInfo {
channelType := c.GetInt("channel")
channelType := c.GetInt("channel_type")
channelId := c.GetInt("channel_id")
tokenId := c.GetInt("token_id")
@@ -112,7 +112,7 @@ type TaskRelayInfo struct {
}
func GenTaskRelayInfo(c *gin.Context) *TaskRelayInfo {
channelType := c.GetInt("channel")
channelType := c.GetInt("channel_type")
channelId := c.GetInt("channel_id")
tokenId := c.GetInt("token_id")
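Both constructors now read the "channel_type" key, which must match whatever key the middleware writes, since gin's GetInt silently returns 0 on a miss. A minimal sketch of the contract, with a hypothetical middleware as the producer:

package example

import "github.com/gin-gonic/gin"

// Hypothetical middleware: it must store channel metadata under the exact
// keys GenRelayInfo reads; with a mismatched key ("channel" vs
// "channel_type"), c.GetInt yields 0 instead of failing loudly.
func withChannel(channelType, channelId int) gin.HandlerFunc {
	return func(c *gin.Context) {
		c.Set("channel_type", channelType)
		c.Set("channel_id", channelId)
		c.Next()
	}
}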

View File

@@ -27,6 +27,7 @@ const (
RelayModeMidjourneyModal
RelayModeMidjourneyShorten
RelayModeSwapFace
RelayModeMidjourneyUpload
RelayModeAudioSpeech // tts
RelayModeAudioTranscription // whisper
@@ -81,6 +82,9 @@ func Path2RelayModeMidjourney(path string) int {
} else if strings.HasSuffix(path, "/mj/insight-face/swap") {
// midjourney plus
relayMode = RelayModeSwapFace
} else if strings.HasSuffix(path, "/submit/upload-discord-images") {
// midjourney plus
relayMode = RelayModeMidjourneyUpload
} else if strings.HasSuffix(path, "/mj/submit/imagine") {
relayMode = RelayModeMidjourneyImagine
} else if strings.HasSuffix(path, "/mj/submit/blend") {

View File

@@ -121,7 +121,8 @@ func ImageHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode {
}
}
quota := int(modelPrice*groupRatio*common.QuotaPerUnit*sizeRatio*qualityRatio) * imageRequest.N
imageRatio := modelPrice * sizeRatio * qualityRatio * float64(imageRequest.N)
quota := int(imageRatio * groupRatio * common.QuotaPerUnit)
if userQuota-quota < 0 {
return service.OpenAIErrorWrapperLocal(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden)
@@ -170,8 +171,8 @@ func ImageHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode {
}
usage := &dto.Usage{
PromptTokens: relayInfo.PromptTokens,
TotalTokens: relayInfo.PromptTokens,
PromptTokens: imageRequest.N,
TotalTokens: imageRequest.N,
}
quality := "standard"
@@ -180,7 +181,7 @@ func ImageHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode {
}
logContent := fmt.Sprintf("大小 %s, 品质 %s", imageRequest.Size, quality)
postConsumeQuota(c, relayInfo, imageRequest.Model, usage, 0, 0, userQuota, 0, groupRatio, modelPrice, true, logContent)
postConsumeQuota(c, relayInfo, imageRequest.Model, usage, 0, 0, userQuota, 0, groupRatio, imageRatio, true, logContent)
return nil
}
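The new imageRatio folds price, size, quality, and image count into one factor before the int conversion, so truncation happens once for the whole batch instead of being amplified N times. A small worked sketch comparing the two formulas (prices and ratios are made-up):

package example

import "fmt"

func main() {
	modelPrice, sizeRatio, qualityRatio := 0.04, 2.0, 1.0
	groupRatio, quotaPerUnit := 1.0, 500000.0
	n := 3 // imageRequest.N

	// old: truncate the per-batch product, then multiply by N
	oldQuota := int(modelPrice*groupRatio*quotaPerUnit*sizeRatio*qualityRatio) * n

	// new: price the whole batch first, then truncate once
	imageRatio := modelPrice * sizeRatio * qualityRatio * float64(n)
	newQuota := int(imageRatio * groupRatio * quotaPerUnit)

	// Both print 120000 here; they diverge whenever the old form's
	// truncation drops a fraction that then gets multiplied by N.
	fmt.Println(oldQuota, newQuota)
}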

View File

@@ -382,6 +382,8 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
midjRequest.Action = constant.MjActionShorten
} else if relayMode == relayconstant.RelayModeMidjourneyBlend { // drawing task; this type may be repeated
midjRequest.Action = constant.MjActionBlend
} else if relayMode == relayconstant.RelayModeMidjourneyUpload { // drawing task; this type may be repeated
midjRequest.Action = constant.MjActionUpload
} else if midjRequest.TaskId != "" { // upscale/variation tasks: if repeated and a result already exists, the remote API returns the final result directly
mjId := ""
if relayMode == relayconstant.RelayModeMidjourneyChange {
@@ -547,7 +549,7 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
if err != nil {
common.SysError("get_channel_null: " + err.Error())
}
if channel.AutoBan != nil && *channel.AutoBan == 1 && common.AutomaticDisableChannelEnabled {
if channel.GetAutoBan() && common.AutomaticDisableChannelEnabled {
model.UpdateChannelStatusById(midjourneyTask.ChannelId, 2, "No available account instance")
}
}
@@ -580,7 +582,10 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
responseBody = []byte(newBody)
}
}
if midjResponse.Code == 1 && midjRequest.Action == "UPLOAD" {
midjourneyTask.Progress = "100%"
midjourneyTask.Status = "SUCCESS"
}
err = midjourneyTask.Insert()
if err != nil {
return &dto.MidjourneyResponse{
@@ -594,7 +599,6 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
newBody := strings.Replace(string(responseBody), `"code":22`, `"code":1`, -1)
responseBody = []byte(newBody)
}
//resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
bodyReader := io.NopCloser(bytes.NewBuffer(responseBody))

View File

@@ -130,6 +130,12 @@ func TextHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
return openaiErr
}
includeUsage := false
// Check whether the user asked for usage to be returned
if textRequest.StreamOptions != nil && textRequest.StreamOptions.IncludeUsage {
includeUsage = true
}
// If StreamOptions is not supported, set StreamOptions to nil
if !relayInfo.SupportStreamOptions || !textRequest.Stream {
textRequest.StreamOptions = nil
@@ -142,8 +148,8 @@ func TextHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
}
}
if textRequest.StreamOptions != nil && textRequest.StreamOptions.IncludeUsage {
relayInfo.ShouldIncludeUsage = textRequest.StreamOptions.IncludeUsage
if includeUsage {
relayInfo.ShouldIncludeUsage = true
}
adaptor := GetAdaptor(relayInfo.ApiType)
@@ -247,13 +253,13 @@ func preConsumeQuota(c *gin.Context, preConsumedQuota int, relayInfo *relaycommo
if tokenQuota > 100*preConsumedQuota {
// token quota is sufficient; trust the token
preConsumedQuota = 0
common.LogInfo(c.Request.Context(), fmt.Sprintf("user %d quota %d and token %d quota %d are enough, trusted and no need to pre-consume", relayInfo.UserId, userQuota, relayInfo.TokenId, tokenQuota))
common.LogInfo(c, fmt.Sprintf("user %d quota %d and token %d quota %d are enough, trusted and no need to pre-consume", relayInfo.UserId, userQuota, relayInfo.TokenId, tokenQuota))
}
} else {
// in this case, we do not pre-consume quota
// because the user has enough quota
preConsumedQuota = 0
common.LogInfo(c.Request.Context(), fmt.Sprintf("user %d with unlimited token has enough quota %d, trusted and no need to pre-consume", relayInfo.UserId, userQuota))
common.LogInfo(c, fmt.Sprintf("user %d with unlimited token has enough quota %d, trusted and no need to pre-consume", relayInfo.UserId, userQuota))
}
}
if preConsumedQuota > 0 {
@@ -280,7 +286,14 @@ func returnPreConsumedQuota(c *gin.Context, tokenId int, userQuota int, preConsu
func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, modelName string,
usage *dto.Usage, ratio float64, preConsumedQuota int, userQuota int, modelRatio float64, groupRatio float64,
modelPrice float64, usePrice bool, extraContent string) {
if usage == nil {
usage = &dto.Usage{
PromptTokens: relayInfo.PromptTokens,
CompletionTokens: 0,
TotalTokens: relayInfo.PromptTokens,
}
extraContent += " ,(可能是请求出错)"
}
useTimeSeconds := time.Now().Unix() - relayInfo.StartTime.Unix()
promptTokens := usage.PromptTokens
completionTokens := usage.CompletionTokens
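Earlier in this diff the include_usage flag is snapshotted before StreamOptions may be cleared for channels that do not support it, so the user's request survives the downgrade. A condensed sketch of that ordering, with stand-in types:

package example

type streamOptions struct{ IncludeUsage bool }

type relayInfo struct {
	SupportStreamOptions bool
	ShouldIncludeUsage   bool
}

// prepareStream reads the user's wish before possibly clearing the
// options, so an unsupported channel cannot erase the request.
func prepareStream(opts *streamOptions, stream bool, info *relayInfo) *streamOptions {
	includeUsage := opts != nil && opts.IncludeUsage // snapshot first
	if !info.SupportStreamOptions || !stream {
		opts = nil // channel cannot handle stream_options
	}
	if includeUsage {
		info.ShouldIncludeUsage = true // usage will be served locally
	}
	return opts
}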

View File

@@ -79,5 +79,6 @@ func registerMjRouterGroup(relayMjRouter *gin.RouterGroup) {
relayMjRouter.GET("/task/:id/image-seed", controller.RelayMidjourney)
relayMjRouter.POST("/task/list-by-condition", controller.RelayMidjourney)
relayMjRouter.POST("/insight-face/swap", controller.RelayMidjourney)
relayMjRouter.POST("/submit/upload-discord-images", controller.RelayMidjourney)
}
}

View File

@@ -49,6 +49,8 @@ func GetMjRequestModel(relayMode int, midjRequest *dto.MidjourneyRequest) (strin
action = constant.MjActionModal
case relayconstant.RelayModeSwapFace:
action = constant.MjActionSwapFace
case relayconstant.RelayModeMidjourneyUpload:
action = constant.MjActionUpload
case relayconstant.RelayModeMidjourneySimpleChange:
params := ConvertSimpleChangeParams(midjRequest.Content)
if params == nil {
@@ -220,7 +222,7 @@ func DoMidjourneyHttpRequest(c *gin.Context, timeout time.Duration, fullRequestU
return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "close_request_body_failed", statusCode), nullBytes, err
}
var midjResponse dto.MidjourneyResponse
var midjourneyUploadsResponse dto.MidjourneyUploadResponse
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "read_response_body_failed", statusCode), nullBytes, err
@@ -230,13 +232,16 @@ func DoMidjourneyHttpRequest(c *gin.Context, timeout time.Duration, fullRequestU
return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "close_response_body_failed", statusCode), responseBody, err
}
respStr := string(responseBody)
log.Printf("responseBody: %s", respStr)
log.Printf("respStr: %s", respStr)
if respStr == "" {
return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "empty_response_body", statusCode), responseBody, nil
} else {
err = json.Unmarshal(responseBody, &midjResponse)
if err != nil {
return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "unmarshal_response_body_failed", statusCode), responseBody, err
err2 := json.Unmarshal(responseBody, &midjourneyUploadsResponse)
if err2 != nil {
return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "unmarshal_response_body_failed", statusCode), responseBody, err
}
}
}
//log.Printf("midjResponse: %v", midjResponse)
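Upload submissions answer with a different JSON shape, hence the second unmarshal attempt above before the body is declared unparseable. A generic sketch of the fallback, with hypothetical minimal structs:

package example

import "encoding/json"

type taskResp struct {
	Code        int    `json:"code"`
	Description string `json:"description"`
}

type uploadResp struct {
	Code   int      `json:"code"`
	Result []string `json:"result"`
}

// decode tries the common shape first, then the upload shape; only if
// both fail is the body treated as an error.
func decode(body []byte) (any, error) {
	var t taskResp
	err := json.Unmarshal(body, &t)
	if err == nil {
		return t, nil
	}
	var u uploadResp
	if err2 := json.Unmarshal(body, &u); err2 == nil {
		return u, nil
	}
	return nil, err // report the first error, as the code above does
}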

View File

@@ -7,6 +7,7 @@ import (
"github.com/gin-gonic/gin"
"net/http"
"one-api/common"
"one-api/dto"
)
func SetEventStreamHeaders(c *gin.Context) {
@@ -45,3 +46,30 @@ func GetResponseID(c *gin.Context) string {
logID := c.GetString("X-Oneapi-Request-Id")
return fmt.Sprintf("chatcmpl-%s", logID)
}
func GenerateStopResponse(id string, createAt int64, model string, finishReason string) *dto.ChatCompletionsStreamResponse {
return &dto.ChatCompletionsStreamResponse{
Id: id,
Object: "chat.completion.chunk",
Created: createAt,
Model: model,
SystemFingerprint: nil,
Choices: []dto.ChatCompletionsStreamResponseChoice{
{
FinishReason: &finishReason,
},
},
}
}
func GenerateFinalUsageResponse(id string, createAt int64, model string, usage dto.Usage) *dto.ChatCompletionsStreamResponse {
return &dto.ChatCompletionsStreamResponse{
Id: id,
Object: "chat.completion.chunk",
Created: createAt,
Model: model,
SystemFingerprint: nil,
Choices: make([]dto.ChatCompletionsStreamResponseChoice, 0),
Usage: &usage,
}
}
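These helpers let handlers close a stream the way OpenAI does: an empty-delta chunk carrying the finish reason, then, if requested, a choice-less chunk carrying only usage, then the terminating Done event. A sketch of the tail sequence as the Gemini handler above uses it (error handling elided):

package example

import (
	"github.com/gin-gonic/gin"
	"one-api/dto"
	relaycommon "one-api/relay/common"
	"one-api/service"
)

// finishStream mirrors the tail of the Gemini handler: a stop chunk, an
// optional usage-only chunk, then the stream terminator.
func finishStream(c *gin.Context, id string, createAt int64, model string, usage *dto.Usage, includeUsage bool) {
	stop := service.GenerateStopResponse(id, createAt, model, relaycommon.StopFinishReason)
	_ = service.ObjectData(c, stop)
	if includeUsage {
		final := service.GenerateFinalUsageResponse(id, createAt, model, *usage)
		_ = service.ObjectData(c, final)
	}
	service.Done(c)
}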

View File

@@ -9,6 +9,7 @@ import (
"log"
"math"
"one-api/common"
"one-api/constant"
"one-api/dto"
"strings"
"unicode/utf8"
@@ -81,17 +82,31 @@ func getTokenNum(tokenEncoder *tiktoken.Tiktoken, text string) int {
}
func getImageToken(imageUrl *dto.MessageImageUrl, model string, stream bool) (int, error) {
// TODO: do not count image tokens in non-stream mode
baseTokens := 85
if model == "glm-4v" {
return 1047, nil
}
if imageUrl.Detail == "low" {
return 85, nil
return baseTokens, nil
}
// TODO: do not count image tokens in non-stream mode
if !constant.GetMediaTokenNotStream && !stream {
return 1000, nil
}
// whether to count image tokens at all
if !constant.GetMediaToken {
return 1000, nil
}
// align with One API's image billing logic
if imageUrl.Detail == "auto" || imageUrl.Detail == "" {
imageUrl.Detail = "high"
}
tileTokens := 170
if strings.HasPrefix(model, "gpt-4o-mini") {
tileTokens = 5667
baseTokens = 2833
}
var config image.Config
var err error
var format string
@@ -138,7 +153,7 @@ func getImageToken(imageUrl *dto.MessageImageUrl, model string, stream bool) (in
// compute the image token count (each side divided by 512, rounded up)
tiles := (shortSide + 511) / 512 * ((otherSide + 511) / 512)
log.Printf("tiles: %d", tiles)
return tiles*170 + 85, nil
return tiles*tileTokens + baseTokens, nil
}
func CountTokenChatRequest(request dto.GeneralOpenAIRequest, model string) (int, error) {
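With the gpt-4o-mini constants a tile costs 5667 tokens on top of a 2833 base, versus the default 170/85. A worked sketch, assuming the image has already been downscaled by the resizing step elided from this hunk:

package example

import "fmt"

// tokens computes tiles*tileTokens + baseTokens as in getImageToken above.
func tokens(shortSide, otherSide, tileTokens, baseTokens int) int {
	tiles := (shortSide + 511) / 512 * ((otherSide + 511) / 512) // ceiling division
	return tiles*tileTokens + baseTokens
}

func main() {
	// 768x1536 high-detail image -> 2*3 = 6 tiles
	fmt.Println(tokens(768, 1536, 170, 85))    // default models: 1105
	fmt.Println(tokens(768, 1536, 5667, 2833)) // gpt-4o-mini: 36835
}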

View File

@@ -25,18 +25,6 @@ func ResponseText2Usage(responseText string, modeName string, promptTokens int)
return usage, err
}
func GenerateFinalUsageResponse(id string, createAt int64, model string, usage dto.Usage) *dto.ChatCompletionsStreamResponse {
return &dto.ChatCompletionsStreamResponse{
Id: id,
Object: "chat.completion.chunk",
Created: createAt,
Model: model,
SystemFingerprint: nil,
Choices: make([]dto.ChatCompletionsStreamResponseChoice, 0),
Usage: &usage,
}
}
func ValidUsage(usage *dto.Usage) bool {
return usage != nil && (usage.PromptTokens != 0 || usage.CompletionTokens != 0)
}

View File

@@ -1,7 +1,7 @@
import React, { useContext, useEffect, useState } from 'react';
import { Link, useNavigate, useSearchParams } from 'react-router-dom';
import { UserContext } from '../context/User';
import { API, getLogo, showError, showInfo, showSuccess } from '../helpers';
import { API, getLogo, showError, showInfo, showSuccess, updateAPI } from '../helpers';
import { onGitHubOAuthClicked } from './utils';
import Turnstile from 'react-turnstile';
import {
@@ -101,6 +101,7 @@ const LoginForm = () => {
if (success) {
userDispatch({ type: 'login', payload: data });
setUserData(data);
updateAPI()
showSuccess('登录成功!');
if (username === 'root' && password === '123456') {
Modal.error({

View File

@@ -1,11 +1,11 @@
import React, { useEffect, useState } from 'react';
import {
API,
copy,
copy, getTodayStartTimestamp,
isAdmin,
showError,
showSuccess,
timestamp2string,
timestamp2string
} from '../helpers';
import {
@@ -419,12 +419,12 @@ const LogsTable = () => {
const [logType, setLogType] = useState(0);
const isAdminUser = isAdmin();
let now = new Date();
// initialize start_timestamp to the previous day
// initialize start_timestamp to today at 00:00
const [inputs, setInputs] = useState({
username: '',
token_name: '',
model_name: '',
start_timestamp: timestamp2string(now.getTime() / 1000 - 86400),
start_timestamp: timestamp2string(getTodayStartTimestamp()),
end_timestamp: timestamp2string(now.getTime() / 1000 + 3600),
channel: '',
});
@@ -449,8 +449,10 @@ const LogsTable = () => {
const getLogSelfStat = async () => {
let localStartTimestamp = Date.parse(start_timestamp) / 1000;
let localEndTimestamp = Date.parse(end_timestamp) / 1000;
let url = `/api/log/self/stat?type=${logType}&token_name=${token_name}&model_name=${model_name}&start_timestamp=${localStartTimestamp}&end_timestamp=${localEndTimestamp}`;
url = encodeURI(url);
let res = await API.get(
`/api/log/self/stat?type=${logType}&token_name=${token_name}&model_name=${model_name}&start_timestamp=${localStartTimestamp}&end_timestamp=${localEndTimestamp}`,
url,
);
const { success, message, data } = res.data;
if (success) {
@@ -463,8 +465,10 @@ const LogsTable = () => {
const getLogStat = async () => {
let localStartTimestamp = Date.parse(start_timestamp) / 1000;
let localEndTimestamp = Date.parse(end_timestamp) / 1000;
let url = `/api/log/stat?type=${logType}&username=${username}&token_name=${token_name}&model_name=${model_name}&start_timestamp=${localStartTimestamp}&end_timestamp=${localEndTimestamp}&channel=${channel}`;
url = encodeURI(url);
let res = await API.get(
`/api/log/stat?type=${logType}&username=${username}&token_name=${token_name}&model_name=${model_name}&start_timestamp=${localStartTimestamp}&end_timestamp=${localEndTimestamp}&channel=${channel}`,
url,
);
const { success, message, data } = res.data;
if (success) {
@@ -475,6 +479,9 @@ const LogsTable = () => {
};
const handleEyeClick = async () => {
if (loadingStat) {
return;
}
setLoadingStat(true);
if (isAdminUser) {
await getLogStat();
@@ -531,6 +538,7 @@ const LogsTable = () => {
} else {
url = `/api/log/self/?p=${startIdx}&page_size=${pageSize}&type=${logType}&token_name=${token_name}&model_name=${model_name}&start_timestamp=${localStartTimestamp}&end_timestamp=${localEndTimestamp}`;
}
url = encodeURI(url);
const res = await API.get(url);
const { success, message, data } = res.data;
if (success) {
@@ -574,6 +582,7 @@ const LogsTable = () => {
const refresh = async () => {
// setLoading(true);
setActivePage(1);
handleEyeClick();
await loadLogs(0, pageSize, logType);
};
@@ -596,6 +605,7 @@ const LogsTable = () => {
.catch((reason) => {
showError(reason);
});
handleEyeClick();
}, []);
const searchLogs = async () => {
@@ -622,19 +632,17 @@ const LogsTable = () => {
<Layout>
<Header>
<Spin spinning={loadingStat}>
<h3>
使用明细总消耗额度
<span
onClick={handleEyeClick}
style={{
cursor: 'pointer',
color: 'gray',
}}
>
{showStat ? renderQuota(stat.quota) : '点击查看'}
</span>
</h3>
<Space>
<Tag color='green' size='large' style={{ padding: 15 }}>
总消耗额度: {renderQuota(stat.quota)}
</Tag>
<Tag color='blue' size='large' style={{ padding: 15 }}>
RPM: {stat.rpm}
</Tag>
<Tag color='purple' size='large' style={{ padding: 15 }}>
TPM: {stat.tpm}
</Tag>
</Space>
</Spin>
</Header>
<Form layout='horizontal' style={{ marginTop: 10 }}>
@@ -700,17 +708,19 @@ const LogsTable = () => {
/>
</>
)}
<Button
label='查询'
type='primary'
htmlType='submit'
className='btn-margin-right'
onClick={refresh}
loading={loading}
style={{ marginTop: 24 }}
>
查询
</Button>
<Form.Section>
<Button
label='查询'
type='primary'
htmlType='submit'
className='btn-margin-right'
onClick={refresh}
loading={loading}
>
查询
</Button>
</Form.Section>
</>
</Form>

View File

@@ -90,6 +90,12 @@ function renderType(type) {
图混合
</Tag>
);
case 'UPLOAD':
return (
<Tag color='blue' size='large'>
上传文件
</Tag>
);
case 'SHORTEN':
return (
<Tag color='pink' size='large'>
@@ -239,7 +245,7 @@ const renderTimestamp = (timestampInSeconds) => {
// renderDuration now includes color logic
function renderDuration(submit_time, finishTime) {
// ensure submit_time and finishTime are valid timestamps
if (!submit_time || !finishTime) return 'N/A';
// convert the timestamps to Date objects
const start = new Date(submit_time);

View File

@@ -1,12 +1,26 @@
import { showError } from './utils';
import { getUserIdFromLocalStorage, showError } from './utils';
import axios from 'axios';
export const API = axios.create({
export let API = axios.create({
baseURL: import.meta.env.VITE_REACT_APP_SERVER_URL
? import.meta.env.VITE_REACT_APP_SERVER_URL
: '',
headers: {
'New-API-User': getUserIdFromLocalStorage()
}
});
export function updateAPI() {
API = axios.create({
baseURL: import.meta.env.VITE_REACT_APP_SERVER_URL
? import.meta.env.VITE_REACT_APP_SERVER_URL
: '',
headers: {
'New-API-User': getUserIdFromLocalStorage()
}
});
}
API.interceptors.response.use(
(response) => response,
(error) => {

View File

@@ -33,6 +33,13 @@ export function getLogo() {
return logo;
}
export function getUserIdFromLocalStorage() {
let user = localStorage.getItem('user');
if (!user) return -1;
user = JSON.parse(user);
return user.id;
}
export function getFooterHTML() {
return localStorage.getItem('footer_html');
}
@@ -133,6 +140,12 @@ export function removeTrailingSlash(url) {
}
}
export function getTodayStartTimestamp() {
var now = new Date();
now.setHours(0, 0, 0, 0);
return Math.floor(now.getTime() / 1000);
}
export function timestamp2string(timestamp) {
let date = new Date(timestamp * 1000);
let year = date.getFullYear().toString();

View File

@@ -1,11 +1,14 @@
import React from 'react';
import TokensTable from '../../components/TokensTable';
import { Layout } from '@douyinfe/semi-ui';
import { Banner, Layout } from '@douyinfe/semi-ui';
const Token = () => (
<>
<Layout>
<Layout.Header>
<h3>我的令牌</h3>
<Banner
type='warning'
description='令牌无法精确控制使用额度,请勿直接将令牌分发给用户。'
/>
</Layout.Header>
<Layout.Content>
<TokensTable />