Compare commits

...

33 Commits

Author SHA1 Message Date
JustSong
159b9e3369 fix: fix unable to set zero value for base url & model mapping 2023-09-18 22:07:17 +08:00
JustSong
92001986db Merge branch 'main' of https://github.com/songquanpeng/one-api 2023-09-18 21:44:36 +08:00
JustSong
a5647b1ea7 fix: fix priority not updated & random choice not working 2023-09-18 21:43:45 +08:00
Redreamality
215e54fc96 docs: update readme (#502)
* Update README.md

* docs: update README

---------

Co-authored-by: JustSong <songquanpeng@foxmail.com>
2023-09-17 21:39:21 +08:00
Xyfacai
ecf8a6d875 feat: supprt channel priority now & record channel id in log (#484)
* feat: 支持设置渠道优先级 & 日志中显示使用的渠道ID

* fix: 设置渠道优先级未更新 ability

* chore: update implementation

---------

Co-authored-by: Xiangyuan Liu <xiangyuan.liu@ui.com>
Co-authored-by: JustSong <songquanpeng@foxmail.com>
Co-authored-by: JustSong <39998050+songquanpeng@users.noreply.github.com>
2023-09-17 19:18:16 +08:00
igophper
24df3e5f62 feat: support non-stream mode for xunfei (#498)
* feat:xunfei suport none stream

* fix:join content ignore seq

---------

Co-authored-by: igophper <admin@jialilgu.cn>
2023-09-17 18:16:12 +08:00
JustSong
12ef9679a7 fix: fix url not passing when using custom chat_link 2023-09-17 17:19:12 +08:00
JustSong
328aa68255 feat: able to delete logs now (close #486) 2023-09-17 17:09:56 +08:00
JustSong
4335f005a6 feat: create new log file when too many logs recorded 2023-09-17 16:35:30 +08:00
JustSong
fe26a1448d docs: update README 2023-09-17 15:41:01 +08:00
JustSong
42451d9d02 refactor: update logging related logic 2023-09-17 15:39:46 +08:00
JustSong
25c4c111ab fix: only enable cors for relay routers to avoid csrf attack 2023-09-17 11:44:38 +08:00
JustSong
0d50ad4b2b chore: update channel test prompt 2023-09-17 11:34:06 +08:00
JustSong
959bcdef88 chore: update error code 2023-09-17 11:30:20 +08:00
JustSong
39ae8075e4 fix: fix oauth2 state not checking 2023-09-15 00:24:20 +08:00
JustSong
b57a0eca16 docs: update readme 2023-09-13 23:22:53 +08:00
Junyan Qin
1b4cc78890 docs: add QChatGPT (#522)
* doc(README.md): 添加支持One API的项目QChatGPT

* Update README.md

---------

Co-authored-by: JustSong <39998050+songquanpeng@users.noreply.github.com>
2023-09-13 23:18:53 +08:00
JustSong
420c375140 perf: only return quota when it's not zero 2023-09-13 22:05:10 +08:00
JustSong
01863d3e44 fix: fix quota not return when error occurred (close #518) 2023-09-13 21:50:45 +08:00
igophper
d0a0e871e1 fix: support ali's embedding model (#481, close #469)
* feat:支持阿里的 embedding 模型

* fix: add to model list

---------

Co-authored-by: JustSong <songquanpeng@foxmail.com>
Co-authored-by: JustSong <39998050+songquanpeng@users.noreply.github.com>
2023-09-03 22:12:35 +08:00
JustSong
bd6fe1e93c feat: able to config rate limit (close #477) 2023-09-03 21:56:37 +08:00
JustSong
c55bb67818 docs: update README (close #482) 2023-09-03 21:50:00 +08:00
JustSong
0f949c3782 docs: update README (close #482) 2023-09-03 21:49:41 +08:00
JustSong
a721a5b6f9 chore: add error prompt for Azure 2023-09-03 21:46:07 +08:00
JustSong
276163affd fix: press enter to submit custom model name 2023-09-03 21:40:58 +08:00
JustSong
621eb91b46 chore: pass through error out 2023-09-03 21:31:58 +08:00
JustSong
7e575abb95 feat: add channel type FastGPT 2023-09-03 15:50:49 +08:00
JustSong
9db93316c4 docs: update README 2023-09-03 15:12:54 +08:00
JustSong
c3dc315e75 feat: add batch update support (close #414) 2023-09-03 14:58:20 +08:00
JustSong
04acdb1ccb feat: support aiproxy's library 2023-09-03 12:51:59 +08:00
JustSong
f0d5e102a3 fix: fix log table use created_at as key instead of id
Co-authored-by: 13714733197 <13714733197@163.com>
2023-08-30 21:43:01 +08:00
JustSong
abbf2fded0 perf: preallocate array capacity 2023-08-30 21:15:56 +08:00
JustSong
ef2c5abb5b docs: update README 2023-08-30 20:51:37 +08:00
49 changed files with 1320 additions and 377 deletions

3
.gitignore vendored
View File

@@ -4,4 +4,5 @@ upload
*.exe
*.db
build
*.db-journal
*.db-journal
logs

View File

@@ -71,10 +71,10 @@ _✨ 通过标准的 OpenAI API 格式访问所有的大模型,开箱即用
+ [x] [360 智脑](https://ai.360.cn)
2. 支持配置镜像以及众多第三方代理服务:
+ [x] [OpenAI-SB](https://openai-sb.com)
+ [x] [CloseAI](https://console.closeai-asia.com/r/2412)
+ [x] [API2D](https://api2d.com/r/197971)
+ [x] [OhMyGPT](https://aigptx.top?aff=uFpUl2Kf)
+ [x] [AI Proxy](https://aiproxy.io/?i=OneAPI) (邀请码:`OneAPI`
+ [x] [CloseAI](https://console.closeai-asia.com/r/2412)
+ [x] 自定义渠道:例如各种未收录的第三方代理服务
3. 支持通过**负载均衡**的方式访问多个渠道。
4. 支持 **stream 模式**,可以通过流式传输实现打字机效果。
@@ -109,6 +109,8 @@ _✨ 通过标准的 OpenAI API 格式访问所有的大模型,开箱即用
数据将会保存在宿主机的 `/home/ubuntu/data/one-api` 目录,请确保该目录存在且具有写入权限,或者更改为合适的目录。
如果启动失败,请添加 `--privileged=true`,具体参考 https://github.com/songquanpeng/one-api/issues/482 。
如果上面的镜像无法拉取,可以尝试使用 GitHub 的 Docker 镜像,将上面的 `justsong/one-api` 替换为 `ghcr.io/songquanpeng/one-api` 即可。
如果你的并发量较大,**务必**设置 `SQL_DSN`,详见下面[环境变量](#环境变量)一节。
@@ -209,6 +211,13 @@ docker run --name chatgpt-web -d -p 3002:3002 -e OPENAI_API_BASE_URL=https://ope
注意修改端口号、`OPENAI_API_BASE_URL` 和 `OPENAI_API_KEY`。
#### QChatGPT - QQ机器人
项目主页https://github.com/RockChinQ/QChatGPT
根据文档完成部署后,在`config.py`设置配置项`openai_config`的`reverse_proxy`为 One API 后端地址,设置`api_key`为 One API 生成的key并在配置项`completion_api_params`的`model`参数设置为 One API 支持的模型名称。
可安装 [Switcher 插件](https://github.com/RockChinQ/Switcher)在运行时切换所使用的模型。
### 部署到第三方平台
<details>
<summary><strong>部署到 Sealos </strong></summary>
@@ -260,6 +269,12 @@ docker run --name chatgpt-web -d -p 3002:3002 -e OPENAI_API_BASE_URL=https://ope
注意,具体的 API Base 的格式取决于你所使用的客户端。
例如对于 OpenAI 的官方库:
```bash
OPENAI_API_KEY="sk-xxxxxx"
OPENAI_API_BASE="https://<HOST>:<PORT>/v1"
```
```mermaid
graph LR
A(用户)
@@ -275,8 +290,9 @@ graph LR
不加的话将会使用负载均衡的方式使用多个渠道。
### 环境变量
1. `REDIS_CONN_STRING`:设置之后将使用 Redis 作为请求频率限制的存储,而非使用内存存储
1. `REDIS_CONN_STRING`:设置之后将使用 Redis 作为缓存使用
+ 例子:`REDIS_CONN_STRING=redis://default:redispw@localhost:49153`
+ 如果数据库访问延迟很低,没有必要启用 Redis启用后反而会出现数据滞后的问题。
2. `SESSION_SECRET`:设置之后将使用固定的会话密钥,这样系统重新启动后已登录用户的 cookie 将依旧有效。
+ 例子:`SESSION_SECRET=random_string`
3. `SQL_DSN`:设置之后将使用指定数据库而非 SQLite请使用 MySQL 或 PostgreSQL。
@@ -303,11 +319,19 @@ graph LR
+ 例子:`CHANNEL_TEST_FREQUENCY=1440`
9. `POLLING_INTERVAL`:批量更新渠道余额以及测试可用性时的请求间隔,单位为秒,默认无间隔。
+ 例子:`POLLING_INTERVAL=5`
10. `BATCH_UPDATE_ENABLED`:启用数据库批量更新聚合,会导致用户额度的更新存在一定的延迟可选值为 `true` 和 `false`,未设置则默认为 `false`。
+ 例子:`BATCH_UPDATE_ENABLED=true`
+ 如果你遇到了数据库连接数过多的问题,可以尝试启用该选项。
11. `BATCH_UPDATE_INTERVAL=5`:批量更新聚合的时间间隔,单位为秒,默认为 `5`。
+ 例子:`BATCH_UPDATE_INTERVAL=5`
12. 请求频率限制:
+ `GLOBAL_API_RATE_LIMIT`:全局 API 速率限制(除中继请求外),单 ip 三分钟内的最大请求数,默认为 `180`。
+ `GLOBAL_WEB_RATE_LIMIT`:全局 Web 速率限制,单 ip 三分钟内的最大请求数,默认为 `60`。
### 命令行参数
1. `--port <port_number>`: 指定服务器监听的端口号,默认为 `3000`。
+ 例子:`--port 3000`
2. `--log-dir <log_dir>`: 指定日志文件夹,如果没有设置,日志将不会被保存
2. `--log-dir <log_dir>`: 指定日志文件夹,如果没有设置,默认保存至工作目录的 `logs` 文件夹下
+ 例子:`--log-dir ./logs`
3. `--version`: 打印系统版本号并退出。
4. `--help`: 查看命令的使用帮助和参数说明。
@@ -339,6 +363,7 @@ https://openai.justsong.cn
5. ChatGPT Next Web 报错:`Failed to fetch`
+ 部署的时候不要设置 `BASE_URL`。
+ 检查你的接口地址和 API Key 有没有填对。
+ 检查是否启用了 HTTPS浏览器会拦截 HTTPS 域名下的 HTTP 请求。
6. 报错:`当前分组负载已饱和,请稍后再试`
+ 上游通道 429 了。
@@ -352,4 +377,4 @@ https://openai.justsong.cn
同样适用于基于本项目的二开项目。
依据 MIT 协议,使用者需自行承担使用本项目的风险与责任,本开源项目开发者与此无关。
依据 MIT 协议,使用者需自行承担使用本项目的风险与责任,本开源项目开发者与此无关。

View File

@@ -94,6 +94,13 @@ var RequestInterval = time.Duration(requestInterval) * time.Second
var SyncFrequency = 10 * 60 // unit is second, will be overwritten by SYNC_FREQUENCY
var BatchUpdateEnabled = false
var BatchUpdateInterval = GetOrDefault("BATCH_UPDATE_INTERVAL", 5)
const (
RequestIdKey = "X-Oneapi-Request-Id"
)
const (
RoleGuestUser = 0
RoleCommonUser = 1
@@ -111,10 +118,10 @@ var (
// All duration's unit is seconds
// Shouldn't larger then RateLimitKeyExpirationDuration
var (
GlobalApiRateLimitNum = 180
GlobalApiRateLimitNum = GetOrDefault("GLOBAL_API_RATE_LIMIT", 180)
GlobalApiRateLimitDuration int64 = 3 * 60
GlobalWebRateLimitNum = 60
GlobalWebRateLimitNum = GetOrDefault("GLOBAL_WEB_RATE_LIMIT", 60)
GlobalWebRateLimitDuration int64 = 3 * 60
UploadRateLimitNum = 10
@@ -154,49 +161,53 @@ const (
)
const (
ChannelTypeUnknown = 0
ChannelTypeOpenAI = 1
ChannelTypeAPI2D = 2
ChannelTypeAzure = 3
ChannelTypeCloseAI = 4
ChannelTypeOpenAISB = 5
ChannelTypeOpenAIMax = 6
ChannelTypeOhMyGPT = 7
ChannelTypeCustom = 8
ChannelTypeAILS = 9
ChannelTypeAIProxy = 10
ChannelTypePaLM = 11
ChannelTypeAPI2GPT = 12
ChannelTypeAIGC2D = 13
ChannelTypeAnthropic = 14
ChannelTypeBaidu = 15
ChannelTypeZhipu = 16
ChannelTypeAli = 17
ChannelTypeXunfei = 18
ChannelType360 = 19
ChannelTypeOpenRouter = 20
ChannelTypeUnknown = 0
ChannelTypeOpenAI = 1
ChannelTypeAPI2D = 2
ChannelTypeAzure = 3
ChannelTypeCloseAI = 4
ChannelTypeOpenAISB = 5
ChannelTypeOpenAIMax = 6
ChannelTypeOhMyGPT = 7
ChannelTypeCustom = 8
ChannelTypeAILS = 9
ChannelTypeAIProxy = 10
ChannelTypePaLM = 11
ChannelTypeAPI2GPT = 12
ChannelTypeAIGC2D = 13
ChannelTypeAnthropic = 14
ChannelTypeBaidu = 15
ChannelTypeZhipu = 16
ChannelTypeAli = 17
ChannelTypeXunfei = 18
ChannelType360 = 19
ChannelTypeOpenRouter = 20
ChannelTypeAIProxyLibrary = 21
ChannelTypeFastGPT = 22
)
var ChannelBaseURLs = []string{
"", // 0
"https://api.openai.com", // 1
"https://oa.api2d.net", // 2
"", // 3
"https://api.closeai-proxy.xyz", // 4
"https://api.openai-sb.com", // 5
"https://api.openaimax.com", // 6
"https://api.ohmygpt.com", // 7
"", // 8
"https://api.caipacity.com", // 9
"https://api.aiproxy.io", // 10
"", // 11
"https://api.api2gpt.com", // 12
"https://api.aigc2d.com", // 13
"https://api.anthropic.com", // 14
"https://aip.baidubce.com", // 15
"https://open.bigmodel.cn", // 16
"https://dashscope.aliyuncs.com", // 17
"", // 18
"https://ai.360.cn", // 19
"https://openrouter.ai/api", // 20
"", // 0
"https://api.openai.com", // 1
"https://oa.api2d.net", // 2
"", // 3
"https://api.closeai-proxy.xyz", // 4
"https://api.openai-sb.com", // 5
"https://api.openaimax.com", // 6
"https://api.ohmygpt.com", // 7
"", // 8
"https://api.caipacity.com", // 9
"https://api.aiproxy.io", // 10
"", // 11
"https://api.api2gpt.com", // 12
"https://api.aigc2d.com", // 13
"https://api.anthropic.com", // 14
"https://aip.baidubce.com", // 15
"https://open.bigmodel.cn", // 16
"https://dashscope.aliyuncs.com", // 17
"", // 18
"https://ai.360.cn", // 19
"https://openrouter.ai/api", // 20
"https://api.aiproxy.io", // 21
"https://fastgpt.run/api/openapi", // 22
}

View File

@@ -12,7 +12,7 @@ var (
Port = flag.Int("port", 3000, "the listening port")
PrintVersion = flag.Bool("version", false, "print version and exit")
PrintHelp = flag.Bool("help", false, "print help and exit")
LogDir = flag.String("log-dir", "", "specify the log directory")
LogDir = flag.String("log-dir", "./logs", "specify the log directory")
)
func printHelp() {

View File

@@ -1,29 +1,47 @@
package common
import (
"context"
"fmt"
"github.com/gin-gonic/gin"
"io"
"log"
"os"
"path/filepath"
"sync"
"time"
)
func SetupGinLog() {
const (
loggerINFO = "INFO"
loggerWarn = "WARN"
loggerError = "ERR"
)
const maxLogCount = 1000000
var logCount int
var setupLogLock sync.Mutex
var setupLogWorking bool
func SetupLogger() {
if *LogDir != "" {
commonLogPath := filepath.Join(*LogDir, "common.log")
errorLogPath := filepath.Join(*LogDir, "error.log")
commonFd, err := os.OpenFile(commonLogPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
ok := setupLogLock.TryLock()
if !ok {
log.Println("setup log is already working")
return
}
defer func() {
setupLogLock.Unlock()
setupLogWorking = false
}()
logPath := filepath.Join(*LogDir, fmt.Sprintf("oneapi-%s.log", time.Now().Format("20060102")))
fd, err := os.OpenFile(logPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
if err != nil {
log.Fatal("failed to open log file")
}
errorFd, err := os.OpenFile(errorLogPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
if err != nil {
log.Fatal("failed to open log file")
}
gin.DefaultWriter = io.MultiWriter(os.Stdout, commonFd)
gin.DefaultErrorWriter = io.MultiWriter(os.Stderr, errorFd)
gin.DefaultWriter = io.MultiWriter(os.Stdout, fd)
gin.DefaultErrorWriter = io.MultiWriter(os.Stderr, fd)
}
}
@@ -37,6 +55,36 @@ func SysError(s string) {
_, _ = fmt.Fprintf(gin.DefaultErrorWriter, "[SYS] %v | %s \n", t.Format("2006/01/02 - 15:04:05"), s)
}
func LogInfo(ctx context.Context, msg string) {
logHelper(ctx, loggerINFO, msg)
}
func LogWarn(ctx context.Context, msg string) {
logHelper(ctx, loggerWarn, msg)
}
func LogError(ctx context.Context, msg string) {
logHelper(ctx, loggerError, msg)
}
func logHelper(ctx context.Context, level string, msg string) {
writer := gin.DefaultErrorWriter
if level == loggerINFO {
writer = gin.DefaultWriter
}
id := ctx.Value(RequestIdKey)
now := time.Now()
_, _ = fmt.Fprintf(writer, "[%s] %v | %s | %s \n", level, now.Format("2006/01/02 - 15:04:05"), id, msg)
logCount++ // we don't need accurate count, so no lock here
if logCount > maxLogCount && !setupLogWorking {
logCount = 0
setupLogWorking = true
go func() {
SetupLogger()
}()
}
}
func FatalLog(v ...any) {
t := time.Now()
_, _ = fmt.Fprintf(gin.DefaultErrorWriter, "[FATAL] %v | %v \n", t.Format("2006/01/02 - 15:04:05"), v)

View File

@@ -50,9 +50,10 @@ var ModelRatio = map[string]float64{
"chatglm_pro": 0.7143, // ¥0.01 / 1k tokens
"chatglm_std": 0.3572, // ¥0.005 / 1k tokens
"chatglm_lite": 0.1429, // ¥0.002 / 1k tokens
"qwen-v1": 0.8572, // TBD: https://help.aliyun.com/document_detail/2399482.html?spm=a2c4g.2399482.0.0.1ad347feilAgag
"qwen-plus-v1": 0.5715, // Same as above
"SparkDesk": 0.8572, // TBD
"qwen-v1": 0.8572, // ¥0.012 / 1k tokens
"qwen-plus-v1": 1, // ¥0.014 / 1k tokens
"text-embedding-v1": 0.05, // ¥0.0007 / 1k tokens
"SparkDesk": 1.2858, // ¥0.018 / 1k tokens
"360GPT_S2_V9": 0.8572, // ¥0.012 / 1k tokens
"embedding-bert-512-v1": 0.0715, // ¥0.001 / 1k tokens
"embedding_s1_v1": 0.0715, // ¥0.001 / 1k tokens

View File

@@ -171,6 +171,11 @@ func GetTimestamp() int64 {
return time.Now().Unix()
}
func GetTimeString() string {
now := time.Now()
return fmt.Sprintf("%s%d", now.Format("20060102150405"), now.UnixNano()%1e9)
}
func Max(a int, b int) int {
if a >= b {
return a
@@ -190,3 +195,7 @@ func GetOrDefault(env string, defaultValue int) int {
}
return num
}
func MessageWithRequestId(message string, id string) string {
return fmt.Sprintf("%s (request id: %s)", message, id)
}

View File

@@ -29,7 +29,7 @@ func GetSubscription(c *gin.Context) {
if err != nil {
openAIError := OpenAIError{
Message: err.Error(),
Type: "one_api_error",
Type: "upstream_error",
}
c.JSON(200, gin.H{
"error": openAIError,

View File

@@ -111,7 +111,7 @@ func GetResponseBody(method, url string, channel *model.Channel, headers http.He
}
func updateChannelCloseAIBalance(channel *model.Channel) (float64, error) {
url := fmt.Sprintf("%s/dashboard/billing/credit_grants", channel.BaseURL)
url := fmt.Sprintf("%s/dashboard/billing/credit_grants", channel.GetBaseURL())
body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key))
if err != nil {
@@ -201,18 +201,18 @@ func updateChannelAIGC2DBalance(channel *model.Channel) (float64, error) {
func updateChannelBalance(channel *model.Channel) (float64, error) {
baseURL := common.ChannelBaseURLs[channel.Type]
if channel.BaseURL == "" {
channel.BaseURL = baseURL
if channel.GetBaseURL() == "" {
channel.BaseURL = &baseURL
}
switch channel.Type {
case common.ChannelTypeOpenAI:
if channel.BaseURL != "" {
baseURL = channel.BaseURL
if channel.GetBaseURL() != "" {
baseURL = channel.GetBaseURL()
}
case common.ChannelTypeAzure:
return 0, errors.New("尚未实现")
case common.ChannelTypeCustom:
baseURL = channel.BaseURL
baseURL = channel.GetBaseURL()
case common.ChannelTypeCloseAI:
return updateChannelCloseAIBalance(channel)
case common.ChannelTypeOpenAISB:

View File

@@ -14,7 +14,7 @@ import (
"time"
)
func testChannel(channel *model.Channel, request ChatRequest) (error, *OpenAIError) {
func testChannel(channel *model.Channel, request ChatRequest) (err error, openaiErr *OpenAIError) {
switch channel.Type {
case common.ChannelTypePaLM:
fallthrough
@@ -32,15 +32,20 @@ func testChannel(channel *model.Channel, request ChatRequest) (error, *OpenAIErr
return errors.New("该渠道类型当前版本不支持测试,请手动测试"), nil
case common.ChannelTypeAzure:
request.Model = "gpt-35-turbo"
defer func() {
if err != nil {
err = errors.New("请确保已在 Azure 上创建了 gpt-35-turbo 模型,并且 apiVersion 已正确填写!")
}
}()
default:
request.Model = "gpt-3.5-turbo"
}
requestURL := common.ChannelBaseURLs[channel.Type]
if channel.Type == common.ChannelTypeAzure {
requestURL = fmt.Sprintf("%s/openai/deployments/%s/chat/completions?api-version=2023-03-15-preview", channel.BaseURL, request.Model)
requestURL = fmt.Sprintf("%s/openai/deployments/%s/chat/completions?api-version=2023-03-15-preview", channel.GetBaseURL(), request.Model)
} else {
if channel.BaseURL != "" {
requestURL = channel.BaseURL
if channel.GetBaseURL() != "" {
requestURL = channel.GetBaseURL()
}
requestURL += "/v1/chat/completions"
}

View File

@@ -85,7 +85,7 @@ func AddChannel(c *gin.Context) {
}
channel.CreatedTime = common.GetTimestamp()
keys := strings.Split(channel.Key, "\n")
channels := make([]model.Channel, 0)
channels := make([]model.Channel, 0, len(keys))
for _, key := range keys {
if key == "" {
continue

View File

@@ -79,6 +79,14 @@ func getGitHubUserInfoByCode(code string) (*GitHubUser, error) {
func GitHubOAuth(c *gin.Context) {
session := sessions.Default(c)
state := c.Query("state")
if state == "" || session.Get("oauth_state") == nil || state != session.Get("oauth_state").(string) {
c.JSON(http.StatusForbidden, gin.H{
"success": false,
"message": "state is empty or not same",
})
return
}
username := session.Get("username")
if username != nil {
GitHubBind(c)
@@ -205,3 +213,22 @@ func GitHubBind(c *gin.Context) {
})
return
}
func GenerateOAuthCode(c *gin.Context) {
session := sessions.Default(c)
state := common.GetRandomString(12)
session.Set("oauth_state", state)
err := session.Save()
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
"data": state,
})
}

View File

@@ -2,6 +2,7 @@ package controller
import (
"github.com/gin-gonic/gin"
"net/http"
"one-api/common"
"one-api/model"
"strconv"
@@ -18,19 +19,21 @@ func GetAllLogs(c *gin.Context) {
username := c.Query("username")
tokenName := c.Query("token_name")
modelName := c.Query("model_name")
logs, err := model.GetAllLogs(logType, startTimestamp, endTimestamp, modelName, username, tokenName, p*common.ItemsPerPage, common.ItemsPerPage)
channel, _ := strconv.Atoi(c.Query("channel"))
logs, err := model.GetAllLogs(logType, startTimestamp, endTimestamp, modelName, username, tokenName, p*common.ItemsPerPage, common.ItemsPerPage, channel)
if err != nil {
c.JSON(200, gin.H{
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
c.JSON(200, gin.H{
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
"data": logs,
})
return
}
func GetUserLogs(c *gin.Context) {
@@ -46,34 +49,36 @@ func GetUserLogs(c *gin.Context) {
modelName := c.Query("model_name")
logs, err := model.GetUserLogs(userId, logType, startTimestamp, endTimestamp, modelName, tokenName, p*common.ItemsPerPage, common.ItemsPerPage)
if err != nil {
c.JSON(200, gin.H{
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
c.JSON(200, gin.H{
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
"data": logs,
})
return
}
func SearchAllLogs(c *gin.Context) {
keyword := c.Query("keyword")
logs, err := model.SearchAllLogs(keyword)
if err != nil {
c.JSON(200, gin.H{
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
c.JSON(200, gin.H{
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
"data": logs,
})
return
}
func SearchUserLogs(c *gin.Context) {
@@ -81,17 +86,18 @@ func SearchUserLogs(c *gin.Context) {
userId := c.GetInt("id")
logs, err := model.SearchUserLogs(userId, keyword)
if err != nil {
c.JSON(200, gin.H{
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
c.JSON(200, gin.H{
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
"data": logs,
})
return
}
func GetLogsStat(c *gin.Context) {
@@ -101,9 +107,10 @@ func GetLogsStat(c *gin.Context) {
tokenName := c.Query("token_name")
username := c.Query("username")
modelName := c.Query("model_name")
quotaNum := model.SumUsedQuota(logType, startTimestamp, endTimestamp, modelName, username, tokenName)
channel, _ := strconv.Atoi(c.Query("channel"))
quotaNum := model.SumUsedQuota(logType, startTimestamp, endTimestamp, modelName, username, tokenName, channel)
//tokenNum := model.SumUsedToken(logType, startTimestamp, endTimestamp, modelName, username, "")
c.JSON(200, gin.H{
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
"data": gin.H{
@@ -111,6 +118,7 @@ func GetLogsStat(c *gin.Context) {
//"token": tokenNum,
},
})
return
}
func GetLogsSelfStat(c *gin.Context) {
@@ -120,9 +128,10 @@ func GetLogsSelfStat(c *gin.Context) {
endTimestamp, _ := strconv.ParseInt(c.Query("end_timestamp"), 10, 64)
tokenName := c.Query("token_name")
modelName := c.Query("model_name")
quotaNum := model.SumUsedQuota(logType, startTimestamp, endTimestamp, modelName, username, tokenName)
channel, _ := strconv.Atoi(c.Query("channel"))
quotaNum := model.SumUsedQuota(logType, startTimestamp, endTimestamp, modelName, username, tokenName, channel)
//tokenNum := model.SumUsedToken(logType, startTimestamp, endTimestamp, modelName, username, tokenName)
c.JSON(200, gin.H{
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
"data": gin.H{
@@ -130,4 +139,30 @@ func GetLogsSelfStat(c *gin.Context) {
//"token": tokenNum,
},
})
return
}
func DeleteHistoryLogs(c *gin.Context) {
targetTimestamp, _ := strconv.ParseInt(c.Query("target_timestamp"), 10, 64)
if targetTimestamp == 0 {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "target timestamp is required",
})
return
}
count, err := model.DeleteOldLog(targetTimestamp)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
"data": count,
})
return
}

View File

@@ -360,6 +360,15 @@ func init() {
Root: "qwen-plus-v1",
Parent: nil,
},
{
Id: "text-embedding-v1",
Object: "model",
Created: 1677649963,
OwnedBy: "ali",
Permission: permission,
Root: "text-embedding-v1",
Parent: nil,
},
{
Id: "SparkDesk",
Object: "model",

220
controller/relay-aiproxy.go Normal file
View File

@@ -0,0 +1,220 @@
package controller
import (
"bufio"
"encoding/json"
"fmt"
"github.com/gin-gonic/gin"
"io"
"net/http"
"one-api/common"
"strconv"
"strings"
)
// https://docs.aiproxy.io/dev/library#使用已经定制好的知识库进行对话问答
type AIProxyLibraryRequest struct {
Model string `json:"model"`
Query string `json:"query"`
LibraryId string `json:"libraryId"`
Stream bool `json:"stream"`
}
type AIProxyLibraryError struct {
ErrCode int `json:"errCode"`
Message string `json:"message"`
}
type AIProxyLibraryDocument struct {
Title string `json:"title"`
URL string `json:"url"`
}
type AIProxyLibraryResponse struct {
Success bool `json:"success"`
Answer string `json:"answer"`
Documents []AIProxyLibraryDocument `json:"documents"`
AIProxyLibraryError
}
type AIProxyLibraryStreamResponse struct {
Content string `json:"content"`
Finish bool `json:"finish"`
Model string `json:"model"`
Documents []AIProxyLibraryDocument `json:"documents"`
}
func requestOpenAI2AIProxyLibrary(request GeneralOpenAIRequest) *AIProxyLibraryRequest {
query := ""
if len(request.Messages) != 0 {
query = request.Messages[len(request.Messages)-1].Content
}
return &AIProxyLibraryRequest{
Model: request.Model,
Stream: request.Stream,
Query: query,
}
}
func aiProxyDocuments2Markdown(documents []AIProxyLibraryDocument) string {
if len(documents) == 0 {
return ""
}
content := "\n\n参考文档\n"
for i, document := range documents {
content += fmt.Sprintf("%d. [%s](%s)\n", i+1, document.Title, document.URL)
}
return content
}
func responseAIProxyLibrary2OpenAI(response *AIProxyLibraryResponse) *OpenAITextResponse {
content := response.Answer + aiProxyDocuments2Markdown(response.Documents)
choice := OpenAITextResponseChoice{
Index: 0,
Message: Message{
Role: "assistant",
Content: content,
},
FinishReason: "stop",
}
fullTextResponse := OpenAITextResponse{
Id: common.GetUUID(),
Object: "chat.completion",
Created: common.GetTimestamp(),
Choices: []OpenAITextResponseChoice{choice},
}
return &fullTextResponse
}
func documentsAIProxyLibrary(documents []AIProxyLibraryDocument) *ChatCompletionsStreamResponse {
var choice ChatCompletionsStreamResponseChoice
choice.Delta.Content = aiProxyDocuments2Markdown(documents)
choice.FinishReason = &stopFinishReason
return &ChatCompletionsStreamResponse{
Id: common.GetUUID(),
Object: "chat.completion.chunk",
Created: common.GetTimestamp(),
Model: "",
Choices: []ChatCompletionsStreamResponseChoice{choice},
}
}
func streamResponseAIProxyLibrary2OpenAI(response *AIProxyLibraryStreamResponse) *ChatCompletionsStreamResponse {
var choice ChatCompletionsStreamResponseChoice
choice.Delta.Content = response.Content
return &ChatCompletionsStreamResponse{
Id: common.GetUUID(),
Object: "chat.completion.chunk",
Created: common.GetTimestamp(),
Model: response.Model,
Choices: []ChatCompletionsStreamResponseChoice{choice},
}
}
func aiProxyLibraryStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) {
var usage Usage
scanner := bufio.NewScanner(resp.Body)
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
if atEOF && len(data) == 0 {
return 0, nil, nil
}
if i := strings.Index(string(data), "\n"); i >= 0 {
return i + 1, data[0:i], nil
}
if atEOF {
return len(data), data, nil
}
return 0, nil, nil
})
dataChan := make(chan string)
stopChan := make(chan bool)
go func() {
for scanner.Scan() {
data := scanner.Text()
if len(data) < 5 { // ignore blank line or wrong format
continue
}
if data[:5] != "data:" {
continue
}
data = data[5:]
dataChan <- data
}
stopChan <- true
}()
setEventStreamHeaders(c)
var documents []AIProxyLibraryDocument
c.Stream(func(w io.Writer) bool {
select {
case data := <-dataChan:
var AIProxyLibraryResponse AIProxyLibraryStreamResponse
err := json.Unmarshal([]byte(data), &AIProxyLibraryResponse)
if err != nil {
common.SysError("error unmarshalling stream response: " + err.Error())
return true
}
if len(AIProxyLibraryResponse.Documents) != 0 {
documents = AIProxyLibraryResponse.Documents
}
response := streamResponseAIProxyLibrary2OpenAI(&AIProxyLibraryResponse)
jsonResponse, err := json.Marshal(response)
if err != nil {
common.SysError("error marshalling stream response: " + err.Error())
return true
}
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
return true
case <-stopChan:
response := documentsAIProxyLibrary(documents)
jsonResponse, err := json.Marshal(response)
if err != nil {
common.SysError("error marshalling stream response: " + err.Error())
return true
}
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
return false
}
})
err := resp.Body.Close()
if err != nil {
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
return nil, &usage
}
func aiProxyLibraryHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) {
var AIProxyLibraryResponse AIProxyLibraryResponse
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
}
err = resp.Body.Close()
if err != nil {
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
err = json.Unmarshal(responseBody, &AIProxyLibraryResponse)
if err != nil {
return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
}
if AIProxyLibraryResponse.ErrCode != 0 {
return &OpenAIErrorWithStatusCode{
OpenAIError: OpenAIError{
Message: AIProxyLibraryResponse.Message,
Type: strconv.Itoa(AIProxyLibraryResponse.ErrCode),
Code: AIProxyLibraryResponse.ErrCode,
},
StatusCode: resp.StatusCode,
}, nil
}
fullTextResponse := responseAIProxyLibrary2OpenAI(&AIProxyLibraryResponse)
jsonResponse, err := json.Marshal(fullTextResponse)
if err != nil {
return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
}
c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode)
_, err = c.Writer.Write(jsonResponse)
return nil, &fullTextResponse.Usage
}

View File

@@ -35,6 +35,29 @@ type AliChatRequest struct {
Parameters AliParameters `json:"parameters,omitempty"`
}
type AliEmbeddingRequest struct {
Model string `json:"model"`
Input struct {
Texts []string `json:"texts"`
} `json:"input"`
Parameters *struct {
TextType string `json:"text_type,omitempty"`
} `json:"parameters,omitempty"`
}
type AliEmbedding struct {
Embedding []float64 `json:"embedding"`
TextIndex int `json:"text_index"`
}
type AliEmbeddingResponse struct {
Output struct {
Embeddings []AliEmbedding `json:"embeddings"`
} `json:"output"`
Usage AliUsage `json:"usage"`
AliError
}
type AliError struct {
Code string `json:"code"`
Message string `json:"message"`
@@ -44,6 +67,7 @@ type AliError struct {
type AliUsage struct {
InputTokens int `json:"input_tokens"`
OutputTokens int `json:"output_tokens"`
TotalTokens int `json:"total_tokens"`
}
type AliOutput struct {
@@ -95,6 +119,70 @@ func requestOpenAI2Ali(request GeneralOpenAIRequest) *AliChatRequest {
}
}
func embeddingRequestOpenAI2Ali(request GeneralOpenAIRequest) *AliEmbeddingRequest {
return &AliEmbeddingRequest{
Model: "text-embedding-v1",
Input: struct {
Texts []string `json:"texts"`
}{
Texts: request.ParseInput(),
},
}
}
func aliEmbeddingHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) {
var aliResponse AliEmbeddingResponse
err := json.NewDecoder(resp.Body).Decode(&aliResponse)
if err != nil {
return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
}
err = resp.Body.Close()
if err != nil {
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
if aliResponse.Code != "" {
return &OpenAIErrorWithStatusCode{
OpenAIError: OpenAIError{
Message: aliResponse.Message,
Type: aliResponse.Code,
Param: aliResponse.RequestId,
Code: aliResponse.Code,
},
StatusCode: resp.StatusCode,
}, nil
}
fullTextResponse := embeddingResponseAli2OpenAI(&aliResponse)
jsonResponse, err := json.Marshal(fullTextResponse)
if err != nil {
return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
}
c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode)
_, err = c.Writer.Write(jsonResponse)
return nil, &fullTextResponse.Usage
}
func embeddingResponseAli2OpenAI(response *AliEmbeddingResponse) *OpenAIEmbeddingResponse {
openAIEmbeddingResponse := OpenAIEmbeddingResponse{
Object: "list",
Data: make([]OpenAIEmbeddingResponseItem, 0, len(response.Output.Embeddings)),
Model: "text-embedding-v1",
Usage: Usage{TotalTokens: response.Usage.TotalTokens},
}
for _, item := range response.Output.Embeddings {
openAIEmbeddingResponse.Data = append(openAIEmbeddingResponse.Data, OpenAIEmbeddingResponseItem{
Object: `embedding`,
Index: item.TextIndex,
Embedding: item.Embedding,
})
}
return &openAIEmbeddingResponse
}
func responseAli2OpenAI(response *AliChatResponse) *OpenAITextResponse {
choice := OpenAITextResponseChoice{
Index: 0,

View File

@@ -2,6 +2,7 @@ package controller
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
@@ -17,6 +18,7 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
tokenId := c.GetInt("token_id")
channelType := c.GetInt("channel")
channelId := c.GetInt("channel_id")
userId := c.GetInt("id")
group := c.GetString("group")
@@ -91,7 +93,7 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
}
var audioResponse AudioResponse
defer func() {
defer func(ctx context.Context) {
go func() {
quota := countTokenText(audioResponse.Text, audioModel)
quotaDelta := quota - preConsumedQuota
@@ -106,13 +108,13 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
if quota != 0 {
tokenName := c.GetString("token_name")
logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
model.RecordConsumeLog(userId, 0, 0, audioModel, tokenName, quota, logContent)
model.RecordConsumeLog(ctx, userId, channelId, 0, 0, audioModel, tokenName, quota, logContent)
model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
channelId := c.GetInt("channel_id")
model.UpdateChannelUsedQuota(channelId, quota)
}
}()
}()
}(c.Request.Context())
responseBody, err := io.ReadAll(resp.Body)

View File

@@ -144,20 +144,9 @@ func streamResponseBaidu2OpenAI(baiduResponse *BaiduChatStreamResponse) *ChatCom
}
func embeddingRequestOpenAI2Baidu(request GeneralOpenAIRequest) *BaiduEmbeddingRequest {
baiduEmbeddingRequest := BaiduEmbeddingRequest{
Input: nil,
return &BaiduEmbeddingRequest{
Input: request.ParseInput(),
}
switch request.Input.(type) {
case string:
baiduEmbeddingRequest.Input = []string{request.Input.(string)}
case []any:
for _, item := range request.Input.([]any) {
if str, ok := item.(string); ok {
baiduEmbeddingRequest.Input = append(baiduEmbeddingRequest.Input, str)
}
}
}
return &baiduEmbeddingRequest
}
func embeddingResponseBaidu2OpenAI(response *BaiduEmbeddingResponse) *OpenAIEmbeddingResponse {

View File

@@ -2,6 +2,7 @@ package controller
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
@@ -18,6 +19,7 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
tokenId := c.GetInt("token_id")
channelType := c.GetInt("channel")
channelId := c.GetInt("channel_id")
userId := c.GetInt("id")
consumeQuota := c.GetBool("consume_quota")
group := c.GetString("group")
@@ -124,7 +126,7 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
}
var textResponse ImageResponse
defer func() {
defer func(ctx context.Context) {
if consumeQuota {
err := model.PostConsumeTokenQuota(tokenId, quota)
if err != nil {
@@ -137,13 +139,13 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
if quota != 0 {
tokenName := c.GetString("token_name")
logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
model.RecordConsumeLog(userId, 0, 0, imageModel, tokenName, quota, logContent)
model.RecordConsumeLog(ctx, userId, channelId, 0, 0, imageModel, tokenName, quota, logContent)
model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
channelId := c.GetInt("channel_id")
model.UpdateChannelUsedQuota(channelId, quota)
}
}
}()
}(c.Request.Context())
if consumeQuota {
responseBody, err := io.ReadAll(resp.Body)

View File

@@ -2,6 +2,7 @@ package controller
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
@@ -22,6 +23,7 @@ const (
APITypeZhipu
APITypeAli
APITypeXunfei
APITypeAIProxyLibrary
)
var httpClient *http.Client
@@ -36,6 +38,7 @@ func init() {
func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
channelType := c.GetInt("channel")
channelId := c.GetInt("channel_id")
tokenId := c.GetInt("token_id")
userId := c.GetInt("id")
consumeQuota := c.GetBool("consume_quota")
@@ -104,6 +107,8 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
apiType = APITypeAli
case common.ChannelTypeXunfei:
apiType = APITypeXunfei
case common.ChannelTypeAIProxyLibrary:
apiType = APITypeAIProxyLibrary
}
baseURL := common.ChannelBaseURLs[channelType]
requestURL := c.Request.URL.String()
@@ -171,6 +176,11 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
fullRequestURL = fmt.Sprintf("https://open.bigmodel.cn/api/paas/v3/model-api/%s/%s", textRequest.Model, method)
case APITypeAli:
fullRequestURL = "https://dashscope.aliyuncs.com/api/v1/services/aigc/text-generation/generation"
if relayMode == RelayModeEmbeddings {
fullRequestURL = "https://dashscope.aliyuncs.com/api/v1/services/embeddings/text-embedding/text-embedding"
}
case APITypeAIProxyLibrary:
fullRequestURL = fmt.Sprintf("%s/api/library/ask", baseURL)
}
var promptTokens int
var completionTokens int
@@ -202,6 +212,7 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
// in this case, we do not pre-consume quota
// because the user has enough quota
preConsumedQuota = 0
common.LogInfo(c.Request.Context(), fmt.Sprintf("user %d has enough quota %d, trusted and no need to pre-consume", userId, userQuota))
}
if consumeQuota && preConsumedQuota > 0 {
err := model.PreConsumeTokenQuota(tokenId, preConsumedQuota)
@@ -257,8 +268,24 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
}
requestBody = bytes.NewBuffer(jsonStr)
case APITypeAli:
aliRequest := requestOpenAI2Ali(textRequest)
jsonStr, err := json.Marshal(aliRequest)
var jsonStr []byte
var err error
switch relayMode {
case RelayModeEmbeddings:
aliEmbeddingRequest := embeddingRequestOpenAI2Ali(textRequest)
jsonStr, err = json.Marshal(aliEmbeddingRequest)
default:
aliRequest := requestOpenAI2Ali(textRequest)
jsonStr, err = json.Marshal(aliRequest)
}
if err != nil {
return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
}
requestBody = bytes.NewBuffer(jsonStr)
case APITypeAIProxyLibrary:
aiProxyLibraryRequest := requestOpenAI2AIProxyLibrary(textRequest)
aiProxyLibraryRequest.LibraryId = c.GetString("library_id")
jsonStr, err := json.Marshal(aiProxyLibraryRequest)
if err != nil {
return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
}
@@ -302,6 +329,8 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
if textRequest.Stream {
req.Header.Set("X-DashScope-SSE", "enable")
}
default:
req.Header.Set("Authorization", "Bearer "+apiKey)
}
req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
req.Header.Set("Accept", c.Request.Header.Get("Accept"))
@@ -321,15 +350,23 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
isStream = isStream || strings.HasPrefix(resp.Header.Get("Content-Type"), "text/event-stream")
if resp.StatusCode != http.StatusOK {
if preConsumedQuota != 0 {
go func(ctx context.Context) {
// return pre-consumed quota
err := model.PostConsumeTokenQuota(tokenId, -preConsumedQuota)
if err != nil {
common.LogError(ctx, "error return pre-consumed quota: "+err.Error())
}
}(c.Request.Context())
}
return relayErrorHandler(resp)
}
}
var textResponse TextResponse
tokenName := c.GetString("token_name")
channelId := c.GetInt("channel_id")
defer func() {
defer func(ctx context.Context) {
// c.Writer.Flush()
go func() {
if consumeQuota {
@@ -352,22 +389,21 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
quotaDelta := quota - preConsumedQuota
err := model.PostConsumeTokenQuota(tokenId, quotaDelta)
if err != nil {
common.SysError("error consuming token remain quota: " + err.Error())
common.LogError(ctx, "error consuming token remain quota: "+err.Error())
}
err = model.CacheUpdateUserQuota(userId)
if err != nil {
common.SysError("error update user quota cache: " + err.Error())
common.LogError(ctx, "error update user quota cache: "+err.Error())
}
if quota != 0 {
logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
model.RecordConsumeLog(userId, promptTokens, completionTokens, textRequest.Model, tokenName, quota, logContent)
model.RecordConsumeLog(ctx, userId, channelId, promptTokens, completionTokens, textRequest.Model, tokenName, quota, logContent)
model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
model.UpdateChannelUsedQuota(channelId, quota)
}
}
}()
}()
}(c.Request.Context())
switch apiType {
case APITypeOpenAI:
if isStream {
@@ -488,7 +524,14 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
}
return nil
} else {
err, usage := aliHandler(c, resp)
var err *OpenAIErrorWithStatusCode
var usage *Usage
switch relayMode {
case RelayModeEmbeddings:
err, usage = aliEmbeddingHandler(c, resp)
default:
err, usage = aliHandler(c, resp)
}
if err != nil {
return err
}
@@ -498,14 +541,29 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
return nil
}
case APITypeXunfei:
auth := c.Request.Header.Get("Authorization")
auth = strings.TrimPrefix(auth, "Bearer ")
splits := strings.Split(auth, "|")
if len(splits) != 3 {
return errorWrapper(errors.New("invalid auth"), "invalid_auth", http.StatusBadRequest)
}
var err *OpenAIErrorWithStatusCode
var usage *Usage
if isStream {
auth := c.Request.Header.Get("Authorization")
auth = strings.TrimPrefix(auth, "Bearer ")
splits := strings.Split(auth, "|")
if len(splits) != 3 {
return errorWrapper(errors.New("invalid auth"), "invalid_auth", http.StatusBadRequest)
}
err, usage := xunfeiStreamHandler(c, textRequest, splits[0], splits[1], splits[2])
err, usage = xunfeiStreamHandler(c, textRequest, splits[0], splits[1], splits[2])
} else {
err, usage = xunfeiHandler(c, textRequest, splits[0], splits[1], splits[2])
}
if err != nil {
return err
}
if usage != nil {
textResponse.Usage = *usage
}
return nil
case APITypeAIProxyLibrary:
if isStream {
err, usage := aiProxyLibraryStreamHandler(c, resp)
if err != nil {
return err
}
@@ -514,7 +572,14 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
}
return nil
} else {
return errorWrapper(errors.New("xunfei api does not support non-stream mode"), "invalid_api_type", http.StatusBadRequest)
err, usage := aiProxyLibraryHandler(c, resp)
if err != nil {
return err
}
if usage != nil {
textResponse.Usage = *usage
}
return nil
}
default:
return errorWrapper(errors.New("unknown api type"), "unknown_api_type", http.StatusInternalServerError)

View File

@@ -146,7 +146,7 @@ func relayErrorHandler(resp *http.Response) (openAIErrorWithStatusCode *OpenAIEr
StatusCode: resp.StatusCode,
OpenAIError: OpenAIError{
Message: fmt.Sprintf("bad response status code %d", resp.StatusCode),
Type: "one_api_error",
Type: "upstream_error",
Code: "bad_response_status_code",
Param: strconv.Itoa(resp.StatusCode),
},

View File

@@ -118,6 +118,7 @@ func responseXunfei2OpenAI(response *XunfeiChatResponse) *OpenAITextResponse {
Role: "assistant",
Content: response.Payload.Choices.Text[0].Content,
},
FinishReason: stopFinishReason,
}
fullTextResponse := OpenAITextResponse{
Object: "chat.completion",
@@ -177,33 +178,82 @@ func buildXunfeiAuthUrl(hostUrl string, apiKey, apiSecret string) string {
}
func xunfeiStreamHandler(c *gin.Context, textRequest GeneralOpenAIRequest, appId string, apiSecret string, apiKey string) (*OpenAIErrorWithStatusCode, *Usage) {
domain, authUrl := getXunfeiAuthUrl(c, apiKey, apiSecret)
dataChan, stopChan, err := xunfeiMakeRequest(textRequest, domain, authUrl, appId)
if err != nil {
return errorWrapper(err, "make xunfei request err", http.StatusInternalServerError), nil
}
setEventStreamHeaders(c)
var usage Usage
query := c.Request.URL.Query()
apiVersion := query.Get("api-version")
if apiVersion == "" {
apiVersion = c.GetString("api_version")
c.Stream(func(w io.Writer) bool {
select {
case xunfeiResponse := <-dataChan:
usage.PromptTokens += xunfeiResponse.Payload.Usage.Text.PromptTokens
usage.CompletionTokens += xunfeiResponse.Payload.Usage.Text.CompletionTokens
usage.TotalTokens += xunfeiResponse.Payload.Usage.Text.TotalTokens
response := streamResponseXunfei2OpenAI(&xunfeiResponse)
jsonResponse, err := json.Marshal(response)
if err != nil {
common.SysError("error marshalling stream response: " + err.Error())
return true
}
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
return true
case <-stopChan:
c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
return false
}
})
return nil, &usage
}
func xunfeiHandler(c *gin.Context, textRequest GeneralOpenAIRequest, appId string, apiSecret string, apiKey string) (*OpenAIErrorWithStatusCode, *Usage) {
domain, authUrl := getXunfeiAuthUrl(c, apiKey, apiSecret)
dataChan, stopChan, err := xunfeiMakeRequest(textRequest, domain, authUrl, appId)
if err != nil {
return errorWrapper(err, "make xunfei request err", http.StatusInternalServerError), nil
}
if apiVersion == "" {
apiVersion = "v1.1"
common.SysLog("api_version not found, use default: " + apiVersion)
var usage Usage
var content string
var xunfeiResponse XunfeiChatResponse
stop := false
for !stop {
select {
case xunfeiResponse = <-dataChan:
content += xunfeiResponse.Payload.Choices.Text[0].Content
usage.PromptTokens += xunfeiResponse.Payload.Usage.Text.PromptTokens
usage.CompletionTokens += xunfeiResponse.Payload.Usage.Text.CompletionTokens
usage.TotalTokens += xunfeiResponse.Payload.Usage.Text.TotalTokens
case stop = <-stopChan:
}
}
domain := "general"
if apiVersion == "v2.1" {
domain = "generalv2"
xunfeiResponse.Payload.Choices.Text[0].Content = content
response := responseXunfei2OpenAI(&xunfeiResponse)
jsonResponse, err := json.Marshal(response)
if err != nil {
return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
}
hostUrl := fmt.Sprintf("wss://spark-api.xf-yun.com/%s/chat", apiVersion)
c.Writer.Header().Set("Content-Type", "application/json")
_, _ = c.Writer.Write(jsonResponse)
return nil, &usage
}
func xunfeiMakeRequest(textRequest GeneralOpenAIRequest, domain, authUrl, appId string) (chan XunfeiChatResponse, chan bool, error) {
d := websocket.Dialer{
HandshakeTimeout: 5 * time.Second,
}
conn, resp, err := d.Dial(buildXunfeiAuthUrl(hostUrl, apiKey, apiSecret), nil)
conn, resp, err := d.Dial(authUrl, nil)
if err != nil || resp.StatusCode != 101 {
return errorWrapper(err, "dial_failed", http.StatusInternalServerError), nil
return nil, nil, err
}
data := requestOpenAI2Xunfei(textRequest, appId, domain)
err = conn.WriteJSON(data)
if err != nil {
return errorWrapper(err, "write_json_failed", http.StatusInternalServerError), nil
return nil, nil, err
}
dataChan := make(chan XunfeiChatResponse)
stopChan := make(chan bool)
go func() {
@@ -230,61 +280,24 @@ func xunfeiStreamHandler(c *gin.Context, textRequest GeneralOpenAIRequest, appId
}
stopChan <- true
}()
setEventStreamHeaders(c)
c.Stream(func(w io.Writer) bool {
select {
case xunfeiResponse := <-dataChan:
usage.PromptTokens += xunfeiResponse.Payload.Usage.Text.PromptTokens
usage.CompletionTokens += xunfeiResponse.Payload.Usage.Text.CompletionTokens
usage.TotalTokens += xunfeiResponse.Payload.Usage.Text.TotalTokens
response := streamResponseXunfei2OpenAI(&xunfeiResponse)
jsonResponse, err := json.Marshal(response)
if err != nil {
common.SysError("error marshalling stream response: " + err.Error())
return true
}
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
return true
case <-stopChan:
c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
return false
}
})
return nil, &usage
return dataChan, stopChan, nil
}
func xunfeiHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) {
var xunfeiResponse XunfeiChatResponse
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
func getXunfeiAuthUrl(c *gin.Context, apiKey string, apiSecret string) (string, string) {
query := c.Request.URL.Query()
apiVersion := query.Get("api-version")
if apiVersion == "" {
apiVersion = c.GetString("api_version")
}
err = resp.Body.Close()
if err != nil {
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
if apiVersion == "" {
apiVersion = "v1.1"
common.SysLog("api_version not found, use default: " + apiVersion)
}
err = json.Unmarshal(responseBody, &xunfeiResponse)
if err != nil {
return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
domain := "general"
if apiVersion == "v2.1" {
domain = "generalv2"
}
if xunfeiResponse.Header.Code != 0 {
return &OpenAIErrorWithStatusCode{
OpenAIError: OpenAIError{
Message: xunfeiResponse.Header.Message,
Type: "xunfei_error",
Param: "",
Code: xunfeiResponse.Header.Code,
},
StatusCode: resp.StatusCode,
}, nil
}
fullTextResponse := responseXunfei2OpenAI(&xunfeiResponse)
jsonResponse, err := json.Marshal(fullTextResponse)
if err != nil {
return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
}
c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode)
_, err = c.Writer.Write(jsonResponse)
return nil, &fullTextResponse.Usage
authUrl := buildXunfeiAuthUrl(fmt.Sprintf("wss://spark-api.xf-yun.com/%s/chat", apiVersion), apiKey, apiSecret)
return domain, authUrl
}

View File

@@ -44,6 +44,25 @@ type GeneralOpenAIRequest struct {
Functions any `json:"functions,omitempty"`
}
func (r GeneralOpenAIRequest) ParseInput() []string {
if r.Input == nil {
return nil
}
var input []string
switch r.Input.(type) {
case string:
input = []string{r.Input.(string)}
case []any:
input = make([]string, 0, len(r.Input.([]any)))
for _, item := range r.Input.([]any) {
if str, ok := item.(string); ok {
input = append(input, str)
}
}
}
return input
}
type ChatRequest struct {
Model string `json:"model"`
Messages []Message `json:"messages"`
@@ -177,6 +196,7 @@ func Relay(c *gin.Context) {
err = relayTextHelper(c, relayMode)
}
if err != nil {
requestId := c.GetString(common.RequestIdKey)
retryTimesStr := c.Query("retry")
retryTimes, _ := strconv.Atoi(retryTimesStr)
if retryTimesStr == "" {
@@ -188,12 +208,13 @@ func Relay(c *gin.Context) {
if err.StatusCode == http.StatusTooManyRequests {
err.OpenAIError.Message = "当前分组上游负载已饱和,请稍后再试"
}
err.OpenAIError.Message = common.MessageWithRequestId(err.OpenAIError.Message, requestId)
c.JSON(err.StatusCode, gin.H{
"error": err.OpenAIError,
})
}
channelId := c.GetInt("channel_id")
common.SysError(fmt.Sprintf("relay error (channel #%d): %s", channelId, err.Message))
common.LogError(c.Request.Context(), fmt.Sprintf("relay error (channel #%d): %s", channelId, err.Message))
// https://platform.openai.com/docs/guides/error-codes/api-errors
if shouldDisableChannel(&err.OpenAIError, err.StatusCode) {
channelId := c.GetInt("channel_id")

View File

@@ -523,5 +523,6 @@
"按照如下格式输入:": "Enter in the following format:",
"模型版本": "Model version",
"请输入星火大模型版本注意是接口地址中的版本号例如v2.1": "Please enter the version of the Starfire model, note that it is the version number in the interface address, for example: v2.1",
"点击查看": "click to view"
"点击查看": "click to view",
"请确保已在 Azure 上创建了 gpt-35-turbo 模型,并且 apiVersion 已正确填写!": "Please make sure that the gpt-35-turbo model has been created on Azure, and the apiVersion has been filled in correctly!"
}

14
main.go
View File

@@ -21,7 +21,7 @@ var buildFS embed.FS
var indexPage []byte
func main() {
common.SetupGinLog()
common.SetupLogger()
common.SysLog("One API " + common.Version + " started")
if os.Getenv("GIN_MODE") != "debug" {
gin.SetMode(gin.ReleaseMode)
@@ -77,14 +77,20 @@ func main() {
}
go controller.AutomaticallyTestChannels(frequency)
}
if os.Getenv("BATCH_UPDATE_ENABLED") == "true" {
common.BatchUpdateEnabled = true
common.SysLog("batch update enabled with interval " + strconv.Itoa(common.BatchUpdateInterval) + "s")
model.InitBatchUpdater()
}
controller.InitTokenEncoders()
// Initialize HTTP server
server := gin.Default()
server := gin.New()
server.Use(gin.Recovery())
// This will cause SSE not to work!!!
//server.Use(gzip.Gzip(gzip.DefaultCompression))
server.Use(middleware.CORS())
server.Use(middleware.RequestId())
middleware.SetUpLogger(server)
// Initialize session store
store := cookie.NewStore([]byte(common.SessionSecret))
server.Use(sessions.Sessions("session", store))

View File

@@ -91,23 +91,16 @@ func TokenAuth() func(c *gin.Context) {
key = parts[0]
token, err := model.ValidateUserToken(key)
if err != nil {
c.JSON(http.StatusUnauthorized, gin.H{
"error": gin.H{
"message": err.Error(),
"type": "one_api_error",
},
})
c.Abort()
abortWithMessage(c, http.StatusUnauthorized, err.Error())
return
}
if !model.CacheIsUserEnabled(token.UserId) {
c.JSON(http.StatusForbidden, gin.H{
"error": gin.H{
"message": "用户已被封禁",
"type": "one_api_error",
},
})
c.Abort()
userEnabled, err := model.IsUserEnabled(token.UserId)
if err != nil {
abortWithMessage(c, http.StatusInternalServerError, err.Error())
return
}
if !userEnabled {
abortWithMessage(c, http.StatusForbidden, "用户已被封禁")
return
}
c.Set("id", token.UserId)
@@ -123,13 +116,7 @@ func TokenAuth() func(c *gin.Context) {
if model.IsAdmin(token.UserId) {
c.Set("channelId", parts[1])
} else {
c.JSON(http.StatusForbidden, gin.H{
"error": gin.H{
"message": "普通用户不支持指定渠道",
"type": "one_api_error",
},
})
c.Abort()
abortWithMessage(c, http.StatusForbidden, "普通用户不支持指定渠道")
return
}
}

View File

@@ -25,34 +25,16 @@ func Distribute() func(c *gin.Context) {
if ok {
id, err := strconv.Atoi(channelId.(string))
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{
"error": gin.H{
"message": "无效的渠道 ID",
"type": "one_api_error",
},
})
c.Abort()
abortWithMessage(c, http.StatusBadRequest, "无效的渠道 ID")
return
}
channel, err = model.GetChannelById(id, true)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{
"error": gin.H{
"message": "无效的渠道 ID",
"type": "one_api_error",
},
})
c.Abort()
abortWithMessage(c, http.StatusBadRequest, "无效的渠道 ID")
return
}
if channel.Status != common.ChannelStatusEnabled {
c.JSON(http.StatusForbidden, gin.H{
"error": gin.H{
"message": "该渠道已被禁用",
"type": "one_api_error",
},
})
c.Abort()
abortWithMessage(c, http.StatusForbidden, "该渠道已被禁用")
return
}
} else {
@@ -63,13 +45,7 @@ func Distribute() func(c *gin.Context) {
err = common.UnmarshalBodyReusable(c, &modelRequest)
}
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{
"error": gin.H{
"message": "无效的请求",
"type": "one_api_error",
},
})
c.Abort()
abortWithMessage(c, http.StatusBadRequest, "无效的请求")
return
}
if strings.HasPrefix(c.Request.URL.Path, "/v1/moderations") {
@@ -99,24 +75,23 @@ func Distribute() func(c *gin.Context) {
common.SysError(fmt.Sprintf("渠道不存在:%d", channel.Id))
message = "数据库一致性已被破坏,请联系管理员"
}
c.JSON(http.StatusServiceUnavailable, gin.H{
"error": gin.H{
"message": message,
"type": "one_api_error",
},
})
c.Abort()
abortWithMessage(c, http.StatusServiceUnavailable, message)
return
}
}
c.Set("channel", channel.Type)
c.Set("channel_id", channel.Id)
c.Set("channel_name", channel.Name)
c.Set("model_mapping", channel.ModelMapping)
c.Set("model_mapping", channel.GetModelMapping())
c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key))
c.Set("base_url", channel.BaseURL)
if channel.Type == common.ChannelTypeAzure || channel.Type == common.ChannelTypeXunfei {
c.Set("base_url", channel.GetBaseURL())
switch channel.Type {
case common.ChannelTypeAzure:
c.Set("api_version", channel.Other)
case common.ChannelTypeXunfei:
c.Set("api_version", channel.Other)
case common.ChannelTypeAIProxyLibrary:
c.Set("library_id", channel.Other)
}
c.Next()
}

25
middleware/logger.go Normal file
View File

@@ -0,0 +1,25 @@
package middleware
import (
"fmt"
"github.com/gin-gonic/gin"
"one-api/common"
)
func SetUpLogger(server *gin.Engine) {
server.Use(gin.LoggerWithFormatter(func(param gin.LogFormatterParams) string {
var requestID string
if param.Keys != nil {
requestID = param.Keys[common.RequestIdKey].(string)
}
return fmt.Sprintf("[GIN] %s | %s | %3d | %13v | %15s | %7s %s\n",
param.TimeStamp.Format("2006/01/02 - 15:04:05"),
requestID,
param.StatusCode,
param.Latency,
param.ClientIP,
param.Method,
param.Path,
)
}))
}

18
middleware/request-id.go Normal file
View File

@@ -0,0 +1,18 @@
package middleware
import (
"context"
"github.com/gin-gonic/gin"
"one-api/common"
)
func RequestId() func(c *gin.Context) {
return func(c *gin.Context) {
id := common.GetTimeString() + common.GetRandomString(8)
c.Set(common.RequestIdKey, id)
ctx := context.WithValue(c.Request.Context(), common.RequestIdKey, id)
c.Request = c.Request.WithContext(ctx)
c.Header(common.RequestIdKey, id)
c.Next()
}
}

17
middleware/utils.go Normal file
View File

@@ -0,0 +1,17 @@
package middleware
import (
"github.com/gin-gonic/gin"
"one-api/common"
)
func abortWithMessage(c *gin.Context, statusCode int, message string) {
c.JSON(statusCode, gin.H{
"error": gin.H{
"message": common.MessageWithRequestId(message, c.GetString(common.RequestIdKey)),
"type": "one_api_error",
},
})
c.Abort()
common.LogError(c.Request.Context(), message)
}

View File

@@ -10,15 +10,16 @@ type Ability struct {
Model string `json:"model" gorm:"primaryKey;autoIncrement:false"`
ChannelId int `json:"channel_id" gorm:"primaryKey;autoIncrement:false;index"`
Enabled bool `json:"enabled"`
Priority *int64 `json:"priority" gorm:"bigint;default:0"`
}
func GetRandomSatisfiedChannel(group string, model string) (*Channel, error) {
ability := Ability{}
var err error = nil
if common.UsingSQLite {
err = DB.Where("`group` = ? and model = ? and enabled = 1", group, model).Order("RANDOM()").Limit(1).First(&ability).Error
err = DB.Where("`group` = ? and model = ? and enabled = 1", group, model).Order("CASE WHEN priority <> 0 THEN priority ELSE RANDOM() END DESC ").Limit(1).First(&ability).Error
} else {
err = DB.Where("`group` = ? and model = ? and enabled = 1", group, model).Order("RAND()").Limit(1).First(&ability).Error
err = DB.Where("`group` = ? and model = ? and enabled = 1", group, model).Order("CASE WHEN priority <> 0 THEN priority ELSE RAND() END DESC").Limit(1).First(&ability).Error
}
if err != nil {
return nil, err
@@ -40,6 +41,7 @@ func (channel *Channel) AddAbilities() error {
Model: model,
ChannelId: channel.Id,
Enabled: channel.Status == common.ChannelStatusEnabled,
Priority: channel.Priority,
}
abilities = append(abilities, ability)
}

View File

@@ -6,6 +6,7 @@ import (
"fmt"
"math/rand"
"one-api/common"
"sort"
"strconv"
"strings"
"sync"
@@ -103,23 +104,28 @@ func CacheDecreaseUserQuota(id int, quota int) error {
return err
}
func CacheIsUserEnabled(userId int) bool {
func CacheIsUserEnabled(userId int) (bool, error) {
if !common.RedisEnabled {
return IsUserEnabled(userId)
}
enabled, err := common.RedisGet(fmt.Sprintf("user_enabled:%d", userId))
if err != nil {
status := common.UserStatusDisabled
if IsUserEnabled(userId) {
status = common.UserStatusEnabled
}
enabled = fmt.Sprintf("%d", status)
err = common.RedisSet(fmt.Sprintf("user_enabled:%d", userId), enabled, time.Duration(UserId2StatusCacheSeconds)*time.Second)
if err != nil {
common.SysError("Redis set user enabled error: " + err.Error())
}
if err == nil {
return enabled == "1", nil
}
return enabled == "1"
userEnabled, err := IsUserEnabled(userId)
if err != nil {
return false, err
}
enabled = "0"
if userEnabled {
enabled = "1"
}
err = common.RedisSet(fmt.Sprintf("user_enabled:%d", userId), enabled, time.Duration(UserId2StatusCacheSeconds)*time.Second)
if err != nil {
common.SysError("Redis set user enabled error: " + err.Error())
}
return userEnabled, err
}
var group2model2channels map[string]map[string][]*Channel
@@ -154,6 +160,17 @@ func InitChannelCache() {
}
}
}
// sort by priority
for group, model2channels := range newGroup2model2channels {
for model, channels := range model2channels {
sort.Slice(channels, func(i, j int) bool {
return channels[i].GetPriority() > channels[j].GetPriority()
})
newGroup2model2channels[group][model] = channels
}
}
channelSyncLock.Lock()
group2model2channels = newGroup2model2channels
channelSyncLock.Unlock()
@@ -178,6 +195,17 @@ func CacheGetRandomSatisfiedChannel(group string, model string) (*Channel, error
if len(channels) == 0 {
return nil, errors.New("channel not found")
}
idx := rand.Intn(len(channels))
endIdx := len(channels)
// choose by priority
firstChannel := channels[0]
if firstChannel.GetPriority() > 0 {
for i := range channels {
if channels[i].GetPriority() != firstChannel.GetPriority() {
endIdx = i
break
}
}
}
idx := rand.Intn(endIdx)
return channels[idx], nil
}

View File

@@ -15,14 +15,15 @@ type Channel struct {
CreatedTime int64 `json:"created_time" gorm:"bigint"`
TestTime int64 `json:"test_time" gorm:"bigint"`
ResponseTime int `json:"response_time"` // in milliseconds
BaseURL string `json:"base_url" gorm:"column:base_url"`
BaseURL *string `json:"base_url" gorm:"column:base_url"`
Other string `json:"other"`
Balance float64 `json:"balance"` // in USD
BalanceUpdatedTime int64 `json:"balance_updated_time" gorm:"bigint"`
Models string `json:"models"`
Group string `json:"group" gorm:"type:varchar(32);default:'default'"`
UsedQuota int64 `json:"used_quota" gorm:"bigint;default:0"`
ModelMapping string `json:"model_mapping" gorm:"type:varchar(1024);default:''"`
ModelMapping *string `json:"model_mapping" gorm:"type:varchar(1024);default:''"`
Priority *int64 `json:"priority" gorm:"bigint;default:0"`
}
func GetAllChannels(startIdx int, num int, selectAll bool) ([]*Channel, error) {
@@ -78,6 +79,27 @@ func BatchInsertChannels(channels []Channel) error {
return nil
}
func (channel *Channel) GetPriority() int64 {
if channel.Priority == nil {
return 0
}
return *channel.Priority
}
func (channel *Channel) GetBaseURL() string {
if channel.BaseURL == nil {
return ""
}
return *channel.BaseURL
}
func (channel *Channel) GetModelMapping() string {
if channel.ModelMapping == nil {
return ""
}
return *channel.ModelMapping
}
func (channel *Channel) Insert() error {
var err error
err = DB.Create(channel).Error
@@ -141,6 +163,14 @@ func UpdateChannelStatusById(id int, status int) {
}
func UpdateChannelUsedQuota(id int, quota int) {
if common.BatchUpdateEnabled {
addNewRecord(BatchUpdateTypeChannelUsedQuota, id, quota)
return
}
updateChannelUsedQuota(id, quota)
}
func updateChannelUsedQuota(id int, quota int) {
err := DB.Model(&Channel{}).Where("id = ?", id).Update("used_quota", gorm.Expr("used_quota + ?", quota)).Error
if err != nil {
common.SysError("failed to update channel used quota: " + err.Error())

View File

@@ -1,6 +1,8 @@
package model
import (
"context"
"fmt"
"gorm.io/gorm"
"one-api/common"
)
@@ -17,6 +19,7 @@ type Log struct {
Quota int `json:"quota" gorm:"default:0"`
PromptTokens int `json:"prompt_tokens" gorm:"default:0"`
CompletionTokens int `json:"completion_tokens" gorm:"default:0"`
Channel int `json:"channel" gorm:"default:0"`
}
const (
@@ -44,7 +47,9 @@ func RecordLog(userId int, logType int, content string) {
}
}
func RecordConsumeLog(userId int, promptTokens int, completionTokens int, modelName string, tokenName string, quota int, content string) {
func RecordConsumeLog(ctx context.Context, userId int, channelId int, promptTokens int, completionTokens int, modelName string, tokenName string, quota int, content string) {
common.LogInfo(ctx, fmt.Sprintf("record consume log: userId=%d, channelId=%d, promptTokens=%d, completionTokens=%d, modelName=%s, tokenName=%s, quota=%d, content=%s", userId, channelId, promptTokens, completionTokens, modelName, tokenName, quota, content))
if !common.LogConsumeEnabled {
return
}
@@ -59,14 +64,15 @@ func RecordConsumeLog(userId int, promptTokens int, completionTokens int, modelN
TokenName: tokenName,
ModelName: modelName,
Quota: quota,
Channel: channelId,
}
err := DB.Create(log).Error
if err != nil {
common.SysError("failed to record log: " + err.Error())
common.LogError(ctx, "failed to record log: "+err.Error())
}
}
func GetAllLogs(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string, startIdx int, num int) (logs []*Log, err error) {
func GetAllLogs(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string, startIdx int, num int, channel int) (logs []*Log, err error) {
var tx *gorm.DB
if logType == LogTypeUnknown {
tx = DB
@@ -88,6 +94,9 @@ func GetAllLogs(logType int, startTimestamp int64, endTimestamp int64, modelName
if endTimestamp != 0 {
tx = tx.Where("created_at <= ?", endTimestamp)
}
if channel != 0 {
tx = tx.Where("channel = ?", channel)
}
err = tx.Order("id desc").Limit(num).Offset(startIdx).Find(&logs).Error
return logs, err
}
@@ -125,7 +134,7 @@ func SearchUserLogs(userId int, keyword string) (logs []*Log, err error) {
return logs, err
}
func SumUsedQuota(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string) (quota int) {
func SumUsedQuota(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string, channel int) (quota int) {
tx := DB.Table("logs").Select("sum(quota)")
if username != "" {
tx = tx.Where("username = ?", username)
@@ -142,6 +151,9 @@ func SumUsedQuota(logType int, startTimestamp int64, endTimestamp int64, modelNa
if modelName != "" {
tx = tx.Where("model_name = ?", modelName)
}
if channel != 0 {
tx = tx.Where("channel = ?", channel)
}
tx.Where("type = ?", LogTypeConsume).Scan(&quota)
return quota
}
@@ -166,3 +178,8 @@ func SumUsedToken(logType int, startTimestamp int64, endTimestamp int64, modelNa
tx.Where("type = ?", LogTypeConsume).Scan(&token)
return token
}
func DeleteOldLog(targetTimestamp int64) (int64, error) {
result := DB.Where("created_at < ?", targetTimestamp).Delete(&Log{})
return result.RowsAffected, result.Error
}

View File

@@ -39,32 +39,35 @@ func ValidateUserToken(key string) (token *Token, err error) {
}
token, err = CacheGetTokenByKey(key)
if err == nil {
if token.Status == common.TokenStatusExhausted {
return nil, errors.New("该令牌额度已用尽")
} else if token.Status == common.TokenStatusExpired {
return nil, errors.New("该令牌已过期")
}
if token.Status != common.TokenStatusEnabled {
return nil, errors.New("该令牌状态不可用")
}
if token.ExpiredTime != -1 && token.ExpiredTime < common.GetTimestamp() {
token.Status = common.TokenStatusExpired
err := token.SelectUpdate()
if err != nil {
common.SysError("failed to update token status" + err.Error())
if !common.RedisEnabled {
token.Status = common.TokenStatusExpired
err := token.SelectUpdate()
if err != nil {
common.SysError("failed to update token status" + err.Error())
}
}
return nil, errors.New("该令牌已过期")
}
if !token.UnlimitedQuota && token.RemainQuota <= 0 {
token.Status = common.TokenStatusExhausted
err := token.SelectUpdate()
if err != nil {
common.SysError("failed to update token status" + err.Error())
if !common.RedisEnabled {
// in this case, we can make sure the token is exhausted
token.Status = common.TokenStatusExhausted
err := token.SelectUpdate()
if err != nil {
common.SysError("failed to update token status" + err.Error())
}
}
return nil, errors.New("该令牌额度已用尽")
}
go func() {
token.AccessedTime = common.GetTimestamp()
err := token.SelectUpdate()
if err != nil {
common.SysError("failed to update token" + err.Error())
}
}()
return token, nil
}
return nil, errors.New("无效的令牌")
@@ -131,10 +134,19 @@ func IncreaseTokenQuota(id int, quota int) (err error) {
if quota < 0 {
return errors.New("quota 不能为负数!")
}
if common.BatchUpdateEnabled {
addNewRecord(BatchUpdateTypeTokenQuota, id, quota)
return nil
}
return increaseTokenQuota(id, quota)
}
func increaseTokenQuota(id int, quota int) (err error) {
err = DB.Model(&Token{}).Where("id = ?", id).Updates(
map[string]interface{}{
"remain_quota": gorm.Expr("remain_quota + ?", quota),
"used_quota": gorm.Expr("used_quota - ?", quota),
"remain_quota": gorm.Expr("remain_quota + ?", quota),
"used_quota": gorm.Expr("used_quota - ?", quota),
"accessed_time": common.GetTimestamp(),
},
).Error
return err
@@ -144,10 +156,19 @@ func DecreaseTokenQuota(id int, quota int) (err error) {
if quota < 0 {
return errors.New("quota 不能为负数!")
}
if common.BatchUpdateEnabled {
addNewRecord(BatchUpdateTypeTokenQuota, id, -quota)
return nil
}
return decreaseTokenQuota(id, quota)
}
func decreaseTokenQuota(id int, quota int) (err error) {
err = DB.Model(&Token{}).Where("id = ?", id).Updates(
map[string]interface{}{
"remain_quota": gorm.Expr("remain_quota - ?", quota),
"used_quota": gorm.Expr("used_quota + ?", quota),
"remain_quota": gorm.Expr("remain_quota - ?", quota),
"used_quota": gorm.Expr("used_quota + ?", quota),
"accessed_time": common.GetTimestamp(),
},
).Error
return err

View File

@@ -226,17 +226,16 @@ func IsAdmin(userId int) bool {
return user.Role >= common.RoleAdminUser
}
func IsUserEnabled(userId int) bool {
func IsUserEnabled(userId int) (bool, error) {
if userId == 0 {
return false
return false, errors.New("user id is empty")
}
var user User
err := DB.Where("id = ?", userId).Select("status").Find(&user).Error
if err != nil {
common.SysError("no such user " + err.Error())
return false
return false, err
}
return user.Status == common.UserStatusEnabled
return user.Status == common.UserStatusEnabled, nil
}
func ValidateAccessToken(token string) (user *User) {
@@ -275,6 +274,14 @@ func IncreaseUserQuota(id int, quota int) (err error) {
if quota < 0 {
return errors.New("quota 不能为负数!")
}
if common.BatchUpdateEnabled {
addNewRecord(BatchUpdateTypeUserQuota, id, quota)
return nil
}
return increaseUserQuota(id, quota)
}
func increaseUserQuota(id int, quota int) (err error) {
err = DB.Model(&User{}).Where("id = ?", id).Update("quota", gorm.Expr("quota + ?", quota)).Error
return err
}
@@ -283,6 +290,14 @@ func DecreaseUserQuota(id int, quota int) (err error) {
if quota < 0 {
return errors.New("quota 不能为负数!")
}
if common.BatchUpdateEnabled {
addNewRecord(BatchUpdateTypeUserQuota, id, -quota)
return nil
}
return decreaseUserQuota(id, quota)
}
func decreaseUserQuota(id int, quota int) (err error) {
err = DB.Model(&User{}).Where("id = ?", id).Update("quota", gorm.Expr("quota - ?", quota)).Error
return err
}
@@ -293,10 +308,18 @@ func GetRootUserEmail() (email string) {
}
func UpdateUserUsedQuotaAndRequestCount(id int, quota int) {
if common.BatchUpdateEnabled {
addNewRecord(BatchUpdateTypeUsedQuotaAndRequestCount, id, quota)
return
}
updateUserUsedQuotaAndRequestCount(id, quota, 1)
}
func updateUserUsedQuotaAndRequestCount(id int, quota int, count int) {
err := DB.Model(&User{}).Where("id = ?", id).Updates(
map[string]interface{}{
"used_quota": gorm.Expr("used_quota + ?", quota),
"request_count": gorm.Expr("request_count + ?", 1),
"request_count": gorm.Expr("request_count + ?", count),
},
).Error
if err != nil {

75
model/utils.go Normal file
View File

@@ -0,0 +1,75 @@
package model
import (
"one-api/common"
"sync"
"time"
)
const BatchUpdateTypeCount = 4 // if you add a new type, you need to add a new map and a new lock
const (
BatchUpdateTypeUserQuota = iota
BatchUpdateTypeTokenQuota
BatchUpdateTypeUsedQuotaAndRequestCount
BatchUpdateTypeChannelUsedQuota
)
var batchUpdateStores []map[int]int
var batchUpdateLocks []sync.Mutex
func init() {
for i := 0; i < BatchUpdateTypeCount; i++ {
batchUpdateStores = append(batchUpdateStores, make(map[int]int))
batchUpdateLocks = append(batchUpdateLocks, sync.Mutex{})
}
}
func InitBatchUpdater() {
go func() {
for {
time.Sleep(time.Duration(common.BatchUpdateInterval) * time.Second)
batchUpdate()
}
}()
}
func addNewRecord(type_ int, id int, value int) {
batchUpdateLocks[type_].Lock()
defer batchUpdateLocks[type_].Unlock()
if _, ok := batchUpdateStores[type_][id]; !ok {
batchUpdateStores[type_][id] = value
} else {
batchUpdateStores[type_][id] += value
}
}
func batchUpdate() {
common.SysLog("batch update started")
for i := 0; i < BatchUpdateTypeCount; i++ {
batchUpdateLocks[i].Lock()
store := batchUpdateStores[i]
batchUpdateStores[i] = make(map[int]int)
batchUpdateLocks[i].Unlock()
for key, value := range store {
switch i {
case BatchUpdateTypeUserQuota:
err := increaseUserQuota(key, value)
if err != nil {
common.SysError("failed to batch update user quota: " + err.Error())
}
case BatchUpdateTypeTokenQuota:
err := increaseTokenQuota(key, value)
if err != nil {
common.SysError("failed to batch update token quota: " + err.Error())
}
case BatchUpdateTypeUsedQuotaAndRequestCount:
updateUserUsedQuotaAndRequestCount(key, value, 1) // TODO: count is incorrect
case BatchUpdateTypeChannelUsedQuota:
updateChannelUsedQuota(key, value)
}
}
}
common.SysLog("batch update finished")
}

View File

@@ -21,6 +21,7 @@ func SetApiRouter(router *gin.Engine) {
apiRouter.GET("/reset_password", middleware.CriticalRateLimit(), middleware.TurnstileCheck(), controller.SendPasswordResetEmail)
apiRouter.POST("/user/reset", middleware.CriticalRateLimit(), controller.ResetPassword)
apiRouter.GET("/oauth/github", middleware.CriticalRateLimit(), controller.GitHubOAuth)
apiRouter.GET("/oauth/state", middleware.CriticalRateLimit(), controller.GenerateOAuthCode)
apiRouter.GET("/oauth/wechat", middleware.CriticalRateLimit(), controller.WeChatAuth)
apiRouter.GET("/oauth/wechat/bind", middleware.CriticalRateLimit(), middleware.UserAuth(), controller.WeChatBind)
apiRouter.GET("/oauth/email/bind", middleware.CriticalRateLimit(), middleware.UserAuth(), controller.EmailBind)
@@ -97,6 +98,7 @@ func SetApiRouter(router *gin.Engine) {
}
logRoute := apiRouter.Group("/log")
logRoute.GET("/", middleware.AdminAuth(), controller.GetAllLogs)
logRoute.DELETE("/", middleware.AdminAuth(), controller.DeleteHistoryLogs)
logRoute.GET("/stat", middleware.AdminAuth(), controller.GetLogsStat)
logRoute.GET("/self/stat", middleware.UserAuth(), controller.GetLogsSelfStat)
logRoute.GET("/search", middleware.AdminAuth(), controller.SearchAllLogs)

View File

@@ -8,6 +8,7 @@ import (
)
func SetRelayRouter(router *gin.Engine) {
router.Use(middleware.CORS())
// https://platform.openai.com/docs/api-reference/introduction
modelsRouter := router.Group("/v1/models")
modelsRouter.Use(middleware.TokenAuth())

View File

@@ -1,7 +1,7 @@
import React, { useEffect, useState } from 'react';
import { Button, Form, Label, Pagination, Popup, Table } from 'semantic-ui-react';
import {Button, Form, Input, Label, Pagination, Popup, Table} from 'semantic-ui-react';
import { Link } from 'react-router-dom';
import { API, showError, showInfo, showSuccess, timestamp2string } from '../helpers';
import { API, showError, showInfo, showNotice, showSuccess, timestamp2string } from '../helpers';
import { CHANNEL_OPTIONS, ITEMS_PER_PAGE } from '../constants';
import { renderGroup, renderNumber } from '../helpers/render';
@@ -24,7 +24,7 @@ function renderType(type) {
}
type2label[0] = { value: 0, text: '未知类型', color: 'grey' };
}
return <Label basic color={type2label[type].color}>{type2label[type].text}</Label>;
return <Label basic color={type2label[type]?.color}>{type2label[type]?.text}</Label>;
}
function renderBalance(type, balance) {
@@ -96,7 +96,7 @@ const ChannelsTable = () => {
});
}, []);
const manageChannel = async (id, action, idx) => {
const manageChannel = async (id, action, idx, priority) => {
let data = { id };
let res;
switch (action) {
@@ -111,6 +111,13 @@ const ChannelsTable = () => {
data.status = 2;
res = await API.put('/api/channel/', data);
break;
case 'priority':
if (priority === '') {
return;
}
data.priority = parseInt(priority);
res = await API.put('/api/channel/', data);
break;
}
const { success, message } = res.data;
if (success) {
@@ -195,6 +202,7 @@ const ChannelsTable = () => {
showInfo(`通道 ${name} 测试成功,耗时 ${time.toFixed(2)} 秒。`);
} else {
showError(message);
showNotice("当前版本测试是通过按照 OpenAI API 格式使用 gpt-3.5-turbo 模型进行非流式请求实现的,因此测试报错并不一定代表通道不可用,该功能后续会修复。")
}
};
@@ -334,6 +342,14 @@ const ChannelsTable = () => {
>
余额
</Table.HeaderCell>
<Table.HeaderCell
style={{ cursor: 'pointer' }}
onClick={() => {
sortChannel('priority');
}}
>
优先级
</Table.HeaderCell>
<Table.HeaderCell>操作</Table.HeaderCell>
</Table.Row>
</Table.Header>
@@ -372,6 +388,22 @@ const ChannelsTable = () => {
basic
/>
</Table.Cell>
<Table.Cell>
<Popup
trigger={<Input type="number" defaultValue={channel.priority} onBlur={(event) => {
manageChannel(
channel.id,
'priority',
idx,
event.target.value,
);
}}>
<input style={{maxWidth:'60px'}} />
</Input>}
content='渠道选择优先级,越高越优先'
basic
/>
</Table.Cell>
<Table.Cell>
<div>
<Button
@@ -440,7 +472,7 @@ const ChannelsTable = () => {
<Table.Footer>
<Table.Row>
<Table.HeaderCell colSpan='8'>
<Table.HeaderCell colSpan='9'>
<Button size='small' as={Link} to='/channel/add' loading={loading}>
添加新的渠道
</Button>

View File

@@ -13,8 +13,8 @@ const GitHubOAuth = () => {
let navigate = useNavigate();
const sendCode = async (code, count) => {
const res = await API.get(`/api/oauth/github?code=${code}`);
const sendCode = async (code, state, count) => {
const res = await API.get(`/api/oauth/github?code=${code}&state=${state}`);
const { success, message, data } = res.data;
if (success) {
if (message === 'bind') {
@@ -36,13 +36,14 @@ const GitHubOAuth = () => {
count++;
setPrompt(`出现错误,第 ${count} 次重试中...`);
await new Promise((resolve) => setTimeout(resolve, count * 2000));
await sendCode(code, count);
await sendCode(code, state, count);
}
};
useEffect(() => {
let code = searchParams.get('code');
sendCode(code, 0).then();
let state = searchParams.get('state');
sendCode(code, state, 0).then();
}, []);
return (

View File

@@ -3,6 +3,7 @@ import { Button, Divider, Form, Grid, Header, Image, Message, Modal, Segment } f
import { Link, useNavigate, useSearchParams } from 'react-router-dom';
import { UserContext } from '../context/User';
import { API, getLogo, showError, showSuccess } from '../helpers';
import { getOAuthState, onGitHubOAuthClicked } from './utils';
const LoginForm = () => {
const [inputs, setInputs] = useState({
@@ -31,12 +32,6 @@ const LoginForm = () => {
const [showWeChatLoginModal, setShowWeChatLoginModal] = useState(false);
const onGitHubOAuthClicked = () => {
window.open(
`https://github.com/login/oauth/authorize?client_id=${status.github_client_id}&scope=user:email`
);
};
const onWeChatLoginClicked = () => {
setShowWeChatLoginModal(true);
};
@@ -131,7 +126,7 @@ const LoginForm = () => {
circular
color='black'
icon='github'
onClick={onGitHubOAuthClicked}
onClick={()=>onGitHubOAuthClicked(status.github_client_id)}
/>
) : (
<></>

View File

@@ -56,9 +56,10 @@ const LogsTable = () => {
token_name: '',
model_name: '',
start_timestamp: timestamp2string(0),
end_timestamp: timestamp2string(now.getTime() / 1000 + 3600)
end_timestamp: timestamp2string(now.getTime() / 1000 + 3600),
channel: ''
});
const { username, token_name, model_name, start_timestamp, end_timestamp } = inputs;
const { username, token_name, model_name, start_timestamp, end_timestamp, channel } = inputs;
const [stat, setStat] = useState({
quota: 0,
@@ -84,7 +85,7 @@ const LogsTable = () => {
const getLogStat = async () => {
let localStartTimestamp = Date.parse(start_timestamp) / 1000;
let localEndTimestamp = Date.parse(end_timestamp) / 1000;
let res = await API.get(`/api/log/stat?type=${logType}&username=${username}&token_name=${token_name}&model_name=${model_name}&start_timestamp=${localStartTimestamp}&end_timestamp=${localEndTimestamp}`);
let res = await API.get(`/api/log/stat?type=${logType}&username=${username}&token_name=${token_name}&model_name=${model_name}&start_timestamp=${localStartTimestamp}&end_timestamp=${localEndTimestamp}&channel=${channel}`);
const { success, message, data } = res.data;
if (success) {
setStat(data);
@@ -109,7 +110,7 @@ const LogsTable = () => {
let localStartTimestamp = Date.parse(start_timestamp) / 1000;
let localEndTimestamp = Date.parse(end_timestamp) / 1000;
if (isAdminUser) {
url = `/api/log/?p=${startIdx}&type=${logType}&username=${username}&token_name=${token_name}&model_name=${model_name}&start_timestamp=${localStartTimestamp}&end_timestamp=${localEndTimestamp}`;
url = `/api/log/?p=${startIdx}&type=${logType}&username=${username}&token_name=${token_name}&model_name=${model_name}&start_timestamp=${localStartTimestamp}&end_timestamp=${localEndTimestamp}&channel=${channel}`;
} else {
url = `/api/log/self/?p=${startIdx}&type=${logType}&token_name=${token_name}&model_name=${model_name}&start_timestamp=${localStartTimestamp}&end_timestamp=${localEndTimestamp}`;
}
@@ -205,16 +206,9 @@ const LogsTable = () => {
</Header>
<Form>
<Form.Group>
{
isAdminUser && (
<Form.Input fluid label={'用户名称'} width={2} value={username}
placeholder={'可选值'} name='username'
onChange={handleInputChange} />
)
}
<Form.Input fluid label={'令牌名称'} width={isAdminUser ? 2 : 3} value={token_name}
<Form.Input fluid label={'令牌名称'} width={3} value={token_name}
placeholder={'可选值'} name='token_name' onChange={handleInputChange} />
<Form.Input fluid label='模型名称' width={isAdminUser ? 2 : 3} value={model_name} placeholder='可选值'
<Form.Input fluid label='模型名称' width={3} value={model_name} placeholder='可选值'
name='model_name'
onChange={handleInputChange} />
<Form.Input fluid label='起始时间' width={4} value={start_timestamp} type='datetime-local'
@@ -225,6 +219,19 @@ const LogsTable = () => {
onChange={handleInputChange} />
<Form.Button fluid label='操作' width={2} onClick={refresh}>查询</Form.Button>
</Form.Group>
{
isAdminUser && <>
<Form.Group>
<Form.Input fluid label={'渠道 ID'} width={3} value={channel}
placeholder='可选值' name='channel'
onChange={handleInputChange} />
<Form.Input fluid label={'用户名称'} width={3} value={username}
placeholder={'可选值'} name='username'
onChange={handleInputChange} />
</Form.Group>
</>
}
</Form>
<Table basic compact size='small'>
<Table.Header>
@@ -238,6 +245,17 @@ const LogsTable = () => {
>
时间
</Table.HeaderCell>
{
isAdminUser && <Table.HeaderCell
style={{ cursor: 'pointer' }}
onClick={() => {
sortLog('channel');
}}
width={1}
>
渠道
</Table.HeaderCell>
}
{
isAdminUser && <Table.HeaderCell
style={{ cursor: 'pointer' }}
@@ -299,16 +317,16 @@ const LogsTable = () => {
onClick={() => {
sortLog('quota');
}}
width={2}
width={1}
>
消耗额度
额度
</Table.HeaderCell>
<Table.HeaderCell
style={{ cursor: 'pointer' }}
onClick={() => {
sortLog('content');
}}
width={isAdminUser ? 4 : 5}
width={isAdminUser ? 4 : 6}
>
详情
</Table.HeaderCell>
@@ -324,8 +342,13 @@ const LogsTable = () => {
.map((log, idx) => {
if (log.deleted) return <></>;
return (
<Table.Row key={log.created_at}>
<Table.Row key={log.id}>
<Table.Cell>{renderTimestamp(log.created_at)}</Table.Cell>
{
isAdminUser && (
<Table.Cell>{log.channel ? <Label basic>{log.channel}</Label> : ''}</Table.Cell>
)
}
{
isAdminUser && (
<Table.Cell>{log.username ? <Label>{log.username}</Label> : ''}</Table.Cell>
@@ -345,7 +368,7 @@ const LogsTable = () => {
<Table.Footer>
<Table.Row>
<Table.HeaderCell colSpan={'9'}>
<Table.HeaderCell colSpan={'10'}>
<Select
placeholder='选择明细分类'
options={LOG_OPTIONS}

View File

@@ -1,8 +1,9 @@
import React, { useEffect, useState } from 'react';
import { Divider, Form, Grid, Header } from 'semantic-ui-react';
import { API, showError, verifyJSON } from '../helpers';
import { API, showError, showSuccess, timestamp2string, verifyJSON } from '../helpers';
const OperationSetting = () => {
let now = new Date();
let [inputs, setInputs] = useState({
QuotaForNewUser: 0,
QuotaForInviter: 0,
@@ -20,10 +21,11 @@ const OperationSetting = () => {
DisplayInCurrencyEnabled: '',
DisplayTokenStatEnabled: '',
ApproximateTokenEnabled: '',
RetryTimes: 0,
RetryTimes: 0
});
const [originInputs, setOriginInputs] = useState({});
let [loading, setLoading] = useState(false);
let [historyTimestamp, setHistoryTimestamp] = useState(timestamp2string(now.getTime() / 1000 - 30 * 24 * 3600)); // a month ago
const getOptions = async () => {
const res = await API.get('/api/option/');
@@ -130,6 +132,17 @@ const OperationSetting = () => {
}
};
const deleteHistoryLogs = async () => {
console.log(inputs);
const res = await API.delete(`/api/log/?target_timestamp=${Date.parse(historyTimestamp) / 1000}`);
const { success, message, data } = res.data;
if (success) {
showSuccess(`${data} 条日志已清理!`);
return;
}
showError('日志清理失败:' + message);
};
return (
<Grid columns={1}>
<Grid.Column>
@@ -179,12 +192,6 @@ const OperationSetting = () => {
/>
</Form.Group>
<Form.Group inline>
<Form.Checkbox
checked={inputs.LogConsumeEnabled === 'true'}
label='启用额度消费日志记录'
name='LogConsumeEnabled'
onChange={handleInputChange}
/>
<Form.Checkbox
checked={inputs.DisplayInCurrencyEnabled === 'true'}
label='以货币形式显示额度'
@@ -208,6 +215,28 @@ const OperationSetting = () => {
submitConfig('general').then();
}}>保存通用设置</Form.Button>
<Divider />
<Header as='h3'>
日志设置
</Header>
<Form.Group inline>
<Form.Checkbox
checked={inputs.LogConsumeEnabled === 'true'}
label='启用额度消费日志记录'
name='LogConsumeEnabled'
onChange={handleInputChange}
/>
</Form.Group>
<Form.Group widths={4}>
<Form.Input label='目标时间' value={historyTimestamp} type='datetime-local'
name='history_timestamp'
onChange={(e, { name, value }) => {
setHistoryTimestamp(value);
}} />
</Form.Group>
<Form.Button onClick={() => {
deleteHistoryLogs().then();
}}>清理历史日志</Form.Button>
<Divider />
<Header as='h3'>
监控设置
</Header>

View File

@@ -4,6 +4,7 @@ import { Link, useNavigate } from 'react-router-dom';
import { API, copy, showError, showInfo, showNotice, showSuccess } from '../helpers';
import Turnstile from 'react-turnstile';
import { UserContext } from '../context/User';
import { onGitHubOAuthClicked } from './utils';
const PersonalSetting = () => {
const [userState, userDispatch] = useContext(UserContext);
@@ -130,12 +131,6 @@ const PersonalSetting = () => {
}
};
const openGitHubOAuth = () => {
window.open(
`https://github.com/login/oauth/authorize?client_id=${status.github_client_id}&scope=user:email`
);
};
const sendVerificationCode = async () => {
setDisableButton(true);
if (inputs.email === '') return;
@@ -249,7 +244,7 @@ const PersonalSetting = () => {
</Modal>
{
status.github_oauth && (
<Button onClick={openGitHubOAuth}>绑定 GitHub 账号</Button>
<Button onClick={()=>{onGitHubOAuthClicked(status.github_client_id)}}>绑定 GitHub 账号</Button>
)
}
<Button

View File

@@ -96,7 +96,7 @@ const TokensTable = () => {
let nextUrl;
if (nextLink) {
nextUrl = nextLink + `/#/?settings={"key":"sk-${key}"}`;
nextUrl = nextLink + `/#/?settings={"key":"sk-${key}","url":"${serverAddress}"}`;
} else {
nextUrl = `https://chat.oneapi.pro/#/?settings={"key":"sk-${key}","url":"${serverAddress}"}`;
}

View File

@@ -0,0 +1,20 @@
import { API, showError } from '../helpers';
export async function getOAuthState() {
const res = await API.get('/api/oauth/state');
const { success, message, data } = res.data;
if (success) {
return data;
} else {
showError(message);
return '';
}
}
export async function onGitHubOAuthClicked(github_client_id) {
const state = await getOAuthState();
if (!state) return;
window.open(
`https://github.com/login/oauth/authorize?client_id=${github_client_id}&state=${state}&scope=user:email`
);
}

View File

@@ -9,6 +9,8 @@ export const CHANNEL_OPTIONS = [
{ key: 16, text: '智谱 ChatGLM', value: 16, color: 'violet' },
{ key: 19, text: '360 智脑', value: 19, color: 'blue' },
{ key: 8, text: '自定义渠道', value: 8, color: 'pink' },
{ key: 22, text: '知识库FastGPT', value: 22, color: 'blue' },
{ key: 21, text: '知识库AI Proxy', value: 21, color: 'purple' },
{ key: 20, text: '代理OpenRouter', value: 20, color: 'black' },
{ key: 2, text: '代理API2D', value: 2, color: 'blue' },
{ key: 5, text: '代理OpenAI-SB', value: 5, color: 'brown' },

View File

@@ -10,6 +10,20 @@ const MODEL_MAPPING_EXAMPLE = {
'gpt-4-32k-0314': 'gpt-4-32k'
};
function type2secretPrompt(type) {
// inputs.type === 15 ? '按照如下格式输入APIKey|SecretKey' : (inputs.type === 18 ? '按照如下格式输入APPID|APISecret|APIKey' : '请输入渠道对应的鉴权密钥')
switch (type) {
case 15:
return '按照如下格式输入APIKey|SecretKey';
case 18:
return '按照如下格式输入APPID|APISecret|APIKey';
case 22:
return '按照如下格式输入APIKey-AppId例如fastgpt-0sp2gtvfdgyi4k30jwlgwf1i-64f335d84283f05518e9e041';
default:
return '请输入渠道对应的鉴权密钥';
}
}
const EditChannel = () => {
const params = useParams();
const navigate = useNavigate();
@@ -53,7 +67,7 @@ const EditChannel = () => {
localModels = ['ERNIE-Bot', 'ERNIE-Bot-turbo', 'Embedding-V1'];
break;
case 17:
localModels = ['qwen-v1', 'qwen-plus-v1'];
localModels = ['qwen-v1', 'qwen-plus-v1', 'text-embedding-v1'];
break;
case 16:
localModels = ['chatglm_pro', 'chatglm_std', 'chatglm_lite'];
@@ -169,9 +183,6 @@ const EditChannel = () => {
if (localInputs.type === 18 && localInputs.other === '') {
localInputs.other = 'v2.1';
}
if (localInputs.model_mapping === '') {
localInputs.model_mapping = '{}';
}
let res;
localInputs.models = localInputs.models.join(',');
localInputs.group = localInputs.groups.join(',');
@@ -193,6 +204,24 @@ const EditChannel = () => {
}
};
const addCustomModel = () => {
if (customModel.trim() === '') return;
if (inputs.models.includes(customModel)) return;
let localModels = [...inputs.models];
localModels.push(customModel);
let localModelOptions = [];
localModelOptions.push({
key: customModel,
text: customModel,
value: customModel
});
setModelOptions(modelOptions => {
return [...modelOptions, ...localModelOptions];
});
setCustomModel('');
handleInputChange(null, { name: 'models', value: localModels });
};
return (
<>
<Segment loading={loading}>
@@ -295,6 +324,20 @@ const EditChannel = () => {
</Form.Field>
)
}
{
inputs.type === 21 && (
<Form.Field>
<Form.Input
label='知识库 ID'
name='other'
placeholder={'请输入知识库 ID例如123456'}
onChange={handleInputChange}
value={inputs.other}
autoComplete='new-password'
/>
</Form.Field>
)
}
<Form.Field>
<Form.Dropdown
label='模型'
@@ -322,29 +365,19 @@ const EditChannel = () => {
}}>清除所有模型</Button>
<Input
action={
<Button type={'button'} onClick={() => {
if (customModel.trim() === '') return;
if (inputs.models.includes(customModel)) return;
let localModels = [...inputs.models];
localModels.push(customModel);
let localModelOptions = [];
localModelOptions.push({
key: customModel,
text: customModel,
value: customModel
});
setModelOptions(modelOptions => {
return [...modelOptions, ...localModelOptions];
});
setCustomModel('');
handleInputChange(null, { name: 'models', value: localModels });
}}>填入</Button>
<Button type={'button'} onClick={addCustomModel}>填入</Button>
}
placeholder='输入自定义模型名称'
value={customModel}
onChange={(e, { value }) => {
setCustomModel(value);
}}
onKeyDown={(e) => {
if (e.key === 'Enter') {
addCustomModel();
e.preventDefault();
}
}}
/>
</div>
<Form.Field>
@@ -375,7 +408,7 @@ const EditChannel = () => {
label='密钥'
name='key'
required
placeholder={inputs.type === 15 ? '按照如下格式输入APIKey|SecretKey' : (inputs.type === 18 ? '按照如下格式输入APPID|APISecret|APIKey' : '请输入渠道对应的鉴权密钥')}
placeholder={type2secretPrompt(inputs.type)}
onChange={handleInputChange}
value={inputs.key}
autoComplete='new-password'
@@ -393,7 +426,7 @@ const EditChannel = () => {
)
}
{
inputs.type !== 3 && inputs.type !== 8 && (
inputs.type !== 3 && inputs.type !== 8 && inputs.type !== 22 && (
<Form.Field>
<Form.Input
label='代理'
@@ -406,6 +439,20 @@ const EditChannel = () => {
</Form.Field>
)
}
{
inputs.type === 22 && (
<Form.Field>
<Form.Input
label='私有部署地址'
name='base_url'
placeholder={'请输入私有部署地址格式为https://fastgpt.run/api/openapi'}
onChange={handleInputChange}
value={inputs.base_url}
autoComplete='new-password'
/>
</Form.Field>
)
}
<Button onClick={handleCancel}>取消</Button>
<Button type={isEdit ? 'button' : 'submit'} positive onClick={submit}>提交</Button>
</Form>