Compare commits

...

24 Commits

Author SHA1 Message Date
JustSong
fa79e8b7a3 fix: use gpt-3.5's encoder if not found (close #110) 2023-05-21 11:11:19 +08:00
JustSong
1cc7c20183 chore: prompt user if redemption code not input 2023-05-21 10:32:47 +08:00
JustSong
2eee97e9b6 style: add comma to quota stat 2023-05-21 10:15:30 +08:00
JustSong
a3a1b612b0 chore: set initial quota for root user 2023-05-21 10:05:34 +08:00
JustSong
61e682ca47 feat: able to manage user's quota now 2023-05-21 10:01:02 +08:00
JustSong
b383983106 docs: update README 2023-05-21 09:18:23 +08:00
JustSong
cfd587117e feat: support channel AI Proxy now 2023-05-20 17:24:56 +08:00
JustSong
ef9dca28f5 chore: set default value for Azure's api version if not set 2023-05-19 22:13:29 +08:00
JustSong
741c0b9c18 docs: update README (#103) 2023-05-19 15:58:01 +08:00
JustSong
3711f4a741 feat: support channel ai.ls now (close #99) 2023-05-19 11:07:17 +08:00
quzard
7c6bf3e97b fix: make the token number calculation more accurate (#101)
* Make token calculation more accurate.

* fix: make the token number calculation more accurate

---------

Co-authored-by: JustSong <songquanpeng@foxmail.com>
2023-05-19 09:41:26 +08:00
JustSong
481ba41fbd docs: update README 2023-05-18 18:13:57 +08:00
JustSong
2779d6629c fix: add X-Accel-Buffering header on SSE response 2023-05-18 17:16:34 +08:00
JustSong
e509899daf docs: update README (close #97) 2023-05-18 16:18:45 +08:00
JustSong
b53cdbaf05 docs: update README 2023-05-18 15:57:40 +08:00
JustSong
ced89398a5 chore: rewrite 429 prompt text (close #96) 2023-05-18 15:27:15 +08:00
JustSong
09c2e3bcec docs: fix typo 2023-05-18 12:50:47 +08:00
JustSong
5cba800fa6 docs: fix typo 2023-05-18 12:50:19 +08:00
JustSong
2d39a135f2 feat: now slave server can sync options with master server (close #88) 2023-05-18 12:48:20 +08:00
JustSong
3c6834a79c feat: support redirecting frontend url now (close #89) 2023-05-18 12:26:18 +08:00
JustSong
6da3410823 fix: fix channel test error checking 2023-05-18 11:41:03 +08:00
JustSong
ceb289cb4d fix: handel error response from server correctly (close #90) 2023-05-18 11:11:15 +08:00
JustSong
6f8cc712b0 docs: update README 2023-05-17 23:26:30 +08:00
JustSong
ad01e1f3b3 fix: fix error log not recorded (close #83) 2023-05-17 20:20:48 +08:00
15 changed files with 250 additions and 78 deletions

View File

@@ -38,6 +38,8 @@ _✨ All in one 的 OpenAI 接口,整合各种 API 访问方式,开箱即用
<a href="https://github.com/songquanpeng/one-api#截图展示">截图展示</a> <a href="https://github.com/songquanpeng/one-api#截图展示">截图展示</a>
· ·
<a href="https://openai.justsong.cn/">在线演示</a> <a href="https://openai.justsong.cn/">在线演示</a>
·
<a href="https://github.com/songquanpeng/one-api#常见问题">常见问题</a>
</p> </p>
> **Warning**:从 `v0.2` 版本升级到 `v0.3` 版本需要手动迁移数据库,请手动执行[数据库迁移脚本](./bin/migration_v0.2-v0.3.sql)。 > **Warning**:从 `v0.2` 版本升级到 `v0.3` 版本需要手动迁移数据库,请手动执行[数据库迁移脚本](./bin/migration_v0.2-v0.3.sql)。
@@ -48,26 +50,29 @@ _✨ All in one 的 OpenAI 接口,整合各种 API 访问方式,开箱即用
+ [x] OpenAI 官方通道 + [x] OpenAI 官方通道
+ [x] **Azure OpenAI API** + [x] **Azure OpenAI API**
+ [x] [API2D](https://api2d.com/r/197971) + [x] [API2D](https://api2d.com/r/197971)
+ [x] [CloseAI](https://console.openai-asia.com) + [x] [OhMyGPT](https://aigptx.top?aff=uFpUl2Kf)
+ [x] [OpenAI-SB](https://openai-sb.com) + [x] [AI Proxy](https://aiproxy.io/?i=OneAPI) (邀请码:`OneAPI`
+ [x] [AI.LS](https://ai.ls)
+ [x] [OpenAI Max](https://openaimax.com) + [x] [OpenAI Max](https://openaimax.com)
+ [x] [OhMyGPT](https://www.ohmygpt.com) + [x] [OpenAI-SB](https://openai-sb.com)
+ [x] [CloseAI](https://console.openai-asia.com/r/2412)
+ [x] 自定义渠道:例如使用自行搭建的 OpenAI 代理 + [x] 自定义渠道:例如使用自行搭建的 OpenAI 代理
2. 支持通过**负载均衡**的方式访问多个渠道。 2. 支持通过**负载均衡**的方式访问多个渠道。
3. 支持 **stream 模式**,可以通过流式传输实现打字机效果。 3. 支持 **stream 模式**,可以通过流式传输实现打字机效果。
4. 支持**令牌管理**,设置令牌的过期时间和使用次数 4. 支持**多机部署**[详见此处](#多机部署)
5. 支持**兑换码管理**支持批量生成和导出兑换码,可使用兑换码为令牌进行充值 5. 支持**令牌管理**设置令牌的过期时间和使用次数
6. 支持**通道管理**批量创建通道 6. 支持**兑换码管理**支持批量生成和导出兑换码,可使用兑换码为账户进行充值
7. 支持发布公告,设置充值链接,设置新用户初始额度 7. 支持**通道管理**,批量创建通道
8. 支持丰富的**自定义**设置, 8. 支持发布公告,设置充值链接,设置新用户初始额度。
9. 支持丰富的**自定义**设置,
1. 支持自定义系统名称logo 以及页脚。 1. 支持自定义系统名称logo 以及页脚。
2. 支持自定义首页和关于页面,可以选择使用 HTML & Markdown 代码进行自定义,或者使用一个单独的网页通过 iframe 嵌入。 2. 支持自定义首页和关于页面,可以选择使用 HTML & Markdown 代码进行自定义,或者使用一个单独的网页通过 iframe 嵌入。
9. 支持通过系统访问令牌访问管理 API。 10. 支持通过系统访问令牌访问管理 API。
10. 支持用户管理,支持**多种用户登录注册方式** 11. 支持用户管理,支持**多种用户登录注册方式**
+ 邮箱登录注册以及通过邮箱进行密码重置。 + 邮箱登录注册以及通过邮箱进行密码重置。
+ [GitHub 开放授权](https://github.com/settings/applications/new)。 + [GitHub 开放授权](https://github.com/settings/applications/new)。
+ 微信公众号授权(需要额外部署 [WeChat Server](https://github.com/songquanpeng/wechat-server))。 + 微信公众号授权(需要额外部署 [WeChat Server](https://github.com/songquanpeng/wechat-server))。
11. 未来其他大模型开放 API 后,将第一时间支持,并将其封装成同样的 API 访问方式。 12. 未来其他大模型开放 API 后,将第一时间支持,并将其封装成同样的 API 访问方式。
## 部署 ## 部署
### 基于 Docker 进行部署 ### 基于 Docker 进行部署
@@ -90,13 +95,10 @@ server{
proxy_set_header X-Forwarded-For $remote_addr; proxy_set_header X-Forwarded-For $remote_addr;
proxy_cache_bypass $http_upgrade; proxy_cache_bypass $http_upgrade;
proxy_set_header Accept-Encoding gzip; proxy_set_header Accept-Encoding gzip;
proxy_buffering off; # 重要:关闭代理缓冲
} }
} }
``` ```
注意,为了 SSE 正常工作,需要关闭 Nginx 的代理缓冲。
之后使用 Let's Encrypt 的 certbot 配置 HTTPS 之后使用 Let's Encrypt 的 certbot 配置 HTTPS
```bash ```bash
# Ubuntu 安装 certbot # Ubuntu 安装 certbot
@@ -133,6 +135,14 @@ sudo service nginx restart
更加详细的部署教程[参见此处](https://iamazing.cn/page/how-to-deploy-a-website)。 更加详细的部署教程[参见此处](https://iamazing.cn/page/how-to-deploy-a-website)。
### 多机部署
1. 所有服务器 `SESSION_SECRET` 设置一样的值。
2. 必须设置 `SQL_DSN`,使用 MySQL 数据库而非 SQLite请自行配置主备数据库同步。
3. 所有从服务器必须设置 `SYNC_FREQUENCY`,以定期从数据库同步配置。
4. 从服务器可以选择设置 `FRONTEND_BASE_URL`,以重定向页面请求到主服务器。
环境变量的具体使用方法详见[此处](#环境变量)。
## 配置 ## 配置
系统本身开箱即用。 系统本身开箱即用。
@@ -157,6 +167,10 @@ sudo service nginx restart
+ 例子:`SESSION_SECRET=random_string` + 例子:`SESSION_SECRET=random_string`
3. `SQL_DSN`:设置之后将使用指定数据库而非 SQLite。 3. `SQL_DSN`:设置之后将使用指定数据库而非 SQLite。
+ 例子:`SQL_DSN=root:123456@tcp(localhost:3306)/one-api` + 例子:`SQL_DSN=root:123456@tcp(localhost:3306)/one-api`
4. `FRONTEND_BASE_URL`:设置之后将使用指定的前端地址,而非后端地址。
+ 例子:`FRONTEND_BASE_URL=https://openai.justsong.cn`
5. `SYNC_FREQUENCY`:设置之后将定期与数据库同步配置,单位为秒,未设置则不进行同步。
+ 例子:`SYNC_FREQUENCY=60`
### 命令行参数 ### 命令行参数
1. `--port <port_number>`: 指定服务器监听的端口号,默认为 `3000`。 1. `--port <port_number>`: 指定服务器监听的端口号,默认为 `3000`。
@@ -174,3 +188,10 @@ https://openai.justsong.cn
### 截图展示 ### 截图展示
![channel](https://user-images.githubusercontent.com/39998050/233837954-ae6683aa-5c4f-429f-a949-6645a83c9490.png) ![channel](https://user-images.githubusercontent.com/39998050/233837954-ae6683aa-5c4f-429f-a949-6645a83c9490.png)
![token](https://user-images.githubusercontent.com/39998050/233837971-dab488b7-6d96-43af-b640-a168e8d1c9bf.png) ![token](https://user-images.githubusercontent.com/39998050/233837971-dab488b7-6d96-43af-b640-a168e8d1c9bf.png)
## 常见问题
1. 账户额度足够为什么提示额度不足?
+ 请检查你的令牌额度是否足够,这个和账户额度是分开的。
+ 令牌额度仅供用户设置最大使用量,用户可自由设置。
2. 宝塔部署后访问出现空白页面?
+ 自动配置的问题,详见[#97](https://github.com/songquanpeng/one-api/issues/97)。

View File

@@ -127,6 +127,8 @@ const (
ChannelTypeOpenAIMax = 6 ChannelTypeOpenAIMax = 6
ChannelTypeOhMyGPT = 7 ChannelTypeOhMyGPT = 7
ChannelTypeCustom = 8 ChannelTypeCustom = 8
ChannelTypeAILS = 9
ChannelTypeAIProxy = 10
) )
var ChannelBaseURLs = []string{ var ChannelBaseURLs = []string{
@@ -139,4 +141,6 @@ var ChannelBaseURLs = []string{
"https://api.openaimax.com", // 6 "https://api.openaimax.com", // 6
"https://api.ohmygpt.com", // 7 "https://api.ohmygpt.com", // 7
"", // 8 "", // 8
"https://api.caipacity.com", // 9
"https://api.aiproxy.io", // 10
} }

View File

@@ -201,7 +201,7 @@ func testChannel(channel *model.Channel, request *ChatRequest) error {
if err != nil { if err != nil {
return err return err
} }
if response.Error.Type != "" { if response.Error.Message != "" {
return errors.New(fmt.Sprintf("type %s, code %s, message %s", response.Error.Type, response.Error.Code, response.Error.Message)) return errors.New(fmt.Sprintf("type %s, code %s, message %s", response.Error.Type, response.Error.Code, response.Error.Message))
} }
return nil return nil
@@ -265,14 +265,14 @@ var testAllChannelsLock sync.Mutex
var testAllChannelsRunning bool = false var testAllChannelsRunning bool = false
// disable & notify // disable & notify
func disableChannel(channelId int, channelName string, err error) { func disableChannel(channelId int, channelName string, reason string) {
if common.RootUserEmail == "" { if common.RootUserEmail == "" {
common.RootUserEmail = model.GetRootUserEmail() common.RootUserEmail = model.GetRootUserEmail()
} }
model.UpdateChannelStatusById(channelId, common.ChannelStatusDisabled) model.UpdateChannelStatusById(channelId, common.ChannelStatusDisabled)
subject := fmt.Sprintf("通道「%s」#%d已被禁用", channelName, channelId) subject := fmt.Sprintf("通道「%s」#%d已被禁用", channelName, channelId)
content := fmt.Sprintf("通道「%s」#%d已被禁用原因%s", channelName, channelId, err.Error()) content := fmt.Sprintf("通道「%s」#%d已被禁用原因%s", channelName, channelId, reason)
err = common.SendEmail(subject, common.RootUserEmail, content) err := common.SendEmail(subject, common.RootUserEmail, content)
if err != nil { if err != nil {
common.SysError(fmt.Sprintf("发送邮件失败:%s", err.Error())) common.SysError(fmt.Sprintf("发送邮件失败:%s", err.Error()))
} }
@@ -312,7 +312,7 @@ func testAllChannels(c *gin.Context) error {
if milliseconds > disableThreshold { if milliseconds > disableThreshold {
err = errors.New(fmt.Sprintf("响应时间 %.2fs 超过阈值 %.2fs", float64(milliseconds)/1000.0, float64(disableThreshold)/1000.0)) err = errors.New(fmt.Sprintf("响应时间 %.2fs 超过阈值 %.2fs", float64(milliseconds)/1000.0, float64(disableThreshold)/1000.0))
} }
disableChannel(channel.Id, channel.Name, err) disableChannel(channel.Id, channel.Name, err.Error())
} }
channel.UpdateResponseTime(milliseconds) channel.UpdateResponseTime(milliseconds)
} }

65
controller/relay-utils.go Normal file
View File

@@ -0,0 +1,65 @@
package controller
import (
"fmt"
"github.com/pkoukk/tiktoken-go"
"one-api/common"
"strings"
)
var tokenEncoderMap = map[string]*tiktoken.Tiktoken{}
func getTokenEncoder(model string) *tiktoken.Tiktoken {
if tokenEncoder, ok := tokenEncoderMap[model]; ok {
return tokenEncoder
}
tokenEncoder, err := tiktoken.EncodingForModel(model)
if err != nil {
common.SysError(fmt.Sprintf("failed to get token encoder for model %s: %s, using encoder for gpt-3.5-turbo", model, err.Error()))
tokenEncoder, err = tiktoken.EncodingForModel("gpt-3.5-turbo")
if err != nil {
common.FatalLog(fmt.Sprintf("failed to get token encoder for model gpt-3.5-turbo: %s", err.Error()))
}
}
tokenEncoderMap[model] = tokenEncoder
return tokenEncoder
}
func countTokenMessages(messages []Message, model string) int {
tokenEncoder := getTokenEncoder(model)
// Reference:
// https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
// https://github.com/pkoukk/tiktoken-go/issues/6
//
// Every message follows <|start|>{role/name}\n{content}<|end|>\n
var tokensPerMessage int
var tokensPerName int
if strings.HasPrefix(model, "gpt-3.5") {
tokensPerMessage = 4
tokensPerName = -1 // If there's a name, the role is omitted
} else if strings.HasPrefix(model, "gpt-4") {
tokensPerMessage = 3
tokensPerName = 1
} else {
tokensPerMessage = 3
tokensPerName = 1
}
tokenNum := 0
for _, message := range messages {
tokenNum += tokensPerMessage
tokenNum += len(tokenEncoder.Encode(message.Content, nil, nil))
tokenNum += len(tokenEncoder.Encode(message.Role, nil, nil))
if message.Name != nil {
tokenNum += tokensPerName
tokenNum += len(tokenEncoder.Encode(*message.Name, nil, nil))
}
}
tokenNum += 3 // Every reply is primed with <|start|>assistant<|message|>
return tokenNum
}
func countTokenText(text string, model string) int {
tokenEncoder := getTokenEncoder(model)
token := tokenEncoder.Encode(text, nil, nil)
return len(token)
}

View File

@@ -4,10 +4,8 @@ import (
"bufio" "bufio"
"bytes" "bytes"
"encoding/json" "encoding/json"
"errors"
"fmt" "fmt"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/pkoukk/tiktoken-go"
"io" "io"
"net/http" "net/http"
"one-api/common" "one-api/common"
@@ -16,8 +14,9 @@ import (
) )
type Message struct { type Message struct {
Role string `json:"role"` Role string `json:"role"`
Content string `json:"content"` Content string `json:"content"`
Name *string `json:"name,omitempty"`
} }
type ChatRequest struct { type ChatRequest struct {
@@ -47,6 +46,11 @@ type OpenAIError struct {
Code string `json:"code"` Code string `json:"code"`
} }
type OpenAIErrorWithStatusCode struct {
OpenAIError
StatusCode int `json:"status_code"`
}
type TextResponse struct { type TextResponse struct {
Usage `json:"usage"` Usage `json:"usage"`
Error OpenAIError `json:"error"` Error OpenAIError `json:"error"`
@@ -61,31 +65,39 @@ type StreamResponse struct {
} `json:"choices"` } `json:"choices"`
} }
var tokenEncoder, _ = tiktoken.GetEncoding("cl100k_base")
func countToken(text string) int {
token := tokenEncoder.Encode(text, nil, nil)
return len(token)
}
func Relay(c *gin.Context) { func Relay(c *gin.Context) {
err := relayHelper(c) err := relayHelper(c)
if err != nil { if err != nil {
c.JSON(http.StatusOK, gin.H{ if err.StatusCode == http.StatusTooManyRequests {
"error": gin.H{ err.OpenAIError.Message = "负载已满,请稍后再试,或升级账户以提升服务质量。"
"message": err.Error(), }
"type": "one_api_error", c.JSON(err.StatusCode, gin.H{
}, "error": err.OpenAIError,
}) })
if common.AutomaticDisableChannelEnabled { channelId := c.GetInt("channel_id")
common.SysError(fmt.Sprintf("Relay error (channel #%d): %s", channelId, err.Message))
if err.Type != "invalid_request_error" && err.StatusCode != http.StatusTooManyRequests &&
common.AutomaticDisableChannelEnabled {
channelId := c.GetInt("channel_id") channelId := c.GetInt("channel_id")
channelName := c.GetString("channel_name") channelName := c.GetString("channel_name")
disableChannel(channelId, channelName, err) disableChannel(channelId, channelName, err.Message)
} }
} }
} }
func relayHelper(c *gin.Context) error { func errorWrapper(err error, code string, statusCode int) *OpenAIErrorWithStatusCode {
openAIError := OpenAIError{
Message: err.Error(),
Type: "one_api_error",
Code: code,
}
return &OpenAIErrorWithStatusCode{
OpenAIError: openAIError,
StatusCode: statusCode,
}
}
func relayHelper(c *gin.Context) *OpenAIErrorWithStatusCode {
channelType := c.GetInt("channel") channelType := c.GetInt("channel")
tokenId := c.GetInt("token_id") tokenId := c.GetInt("token_id")
consumeQuota := c.GetBool("consume_quota") consumeQuota := c.GetBool("consume_quota")
@@ -93,15 +105,15 @@ func relayHelper(c *gin.Context) error {
if consumeQuota || channelType == common.ChannelTypeAzure { if consumeQuota || channelType == common.ChannelTypeAzure {
requestBody, err := io.ReadAll(c.Request.Body) requestBody, err := io.ReadAll(c.Request.Body)
if err != nil { if err != nil {
return err return errorWrapper(err, "read_request_body_failed", http.StatusBadRequest)
} }
err = c.Request.Body.Close() err = c.Request.Body.Close()
if err != nil { if err != nil {
return err return errorWrapper(err, "close_request_body_failed", http.StatusBadRequest)
} }
err = json.Unmarshal(requestBody, &textRequest) err = json.Unmarshal(requestBody, &textRequest)
if err != nil { if err != nil {
return err return errorWrapper(err, "unmarshal_request_body_failed", http.StatusBadRequest)
} }
// Reset request body // Reset request body
c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody)) c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
@@ -130,11 +142,8 @@ func relayHelper(c *gin.Context) error {
model_ = strings.TrimSuffix(model_, "-0314") model_ = strings.TrimSuffix(model_, "-0314")
fullRequestURL = fmt.Sprintf("%s/openai/deployments/%s/%s", baseURL, model_, task) fullRequestURL = fmt.Sprintf("%s/openai/deployments/%s/%s", baseURL, model_, task)
} }
var promptText string
for _, message := range textRequest.Messages { promptTokens := countTokenMessages(textRequest.Messages, textRequest.Model)
promptText += fmt.Sprintf("%s: %s\n", message.Role, message.Content)
}
promptTokens := countToken(promptText) + 3
preConsumedTokens := common.PreConsumedQuota preConsumedTokens := common.PreConsumedQuota
if textRequest.MaxTokens != 0 { if textRequest.MaxTokens != 0 {
preConsumedTokens = promptTokens + textRequest.MaxTokens preConsumedTokens = promptTokens + textRequest.MaxTokens
@@ -144,12 +153,12 @@ func relayHelper(c *gin.Context) error {
if consumeQuota { if consumeQuota {
err := model.PreConsumeTokenQuota(tokenId, preConsumedQuota) err := model.PreConsumeTokenQuota(tokenId, preConsumedQuota)
if err != nil { if err != nil {
return err return errorWrapper(err, "pre_consume_token_quota_failed", http.StatusOK)
} }
} }
req, err := http.NewRequest(c.Request.Method, fullRequestURL, c.Request.Body) req, err := http.NewRequest(c.Request.Method, fullRequestURL, c.Request.Body)
if err != nil { if err != nil {
return err return errorWrapper(err, "new_request_failed", http.StatusOK)
} }
if channelType == common.ChannelTypeAzure { if channelType == common.ChannelTypeAzure {
key := c.Request.Header.Get("Authorization") key := c.Request.Header.Get("Authorization")
@@ -164,15 +173,15 @@ func relayHelper(c *gin.Context) error {
client := &http.Client{} client := &http.Client{}
resp, err := client.Do(req) resp, err := client.Do(req)
if err != nil { if err != nil {
return err return errorWrapper(err, "do_request_failed", http.StatusOK)
} }
err = req.Body.Close() err = req.Body.Close()
if err != nil { if err != nil {
return err return errorWrapper(err, "close_request_body_failed", http.StatusOK)
} }
err = c.Request.Body.Close() err = c.Request.Body.Close()
if err != nil { if err != nil {
return err return errorWrapper(err, "close_request_body_failed", http.StatusOK)
} }
var textResponse TextResponse var textResponse TextResponse
isStream := strings.HasPrefix(resp.Header.Get("Content-Type"), "text/event-stream") isStream := strings.HasPrefix(resp.Header.Get("Content-Type"), "text/event-stream")
@@ -187,8 +196,8 @@ func relayHelper(c *gin.Context) error {
completionRatio = 2 completionRatio = 2
} }
if isStream { if isStream {
completionText := fmt.Sprintf("%s: %s\n", "assistant", streamResponseText) responseTokens := countTokenText(streamResponseText, textRequest.Model)
quota = promptTokens + countToken(completionText)*completionRatio quota = promptTokens + responseTokens*completionRatio
} else { } else {
quota = textResponse.Usage.PromptTokens + textResponse.Usage.CompletionTokens*completionRatio quota = textResponse.Usage.PromptTokens + textResponse.Usage.CompletionTokens*completionRatio
} }
@@ -223,6 +232,10 @@ func relayHelper(c *gin.Context) error {
go func() { go func() {
for scanner.Scan() { for scanner.Scan() {
data := scanner.Text() data := scanner.Text()
if len(data) < 6 { // must be something wrong!
common.SysError("Invalid stream response: " + data)
continue
}
dataChan <- data dataChan <- data
data = data[6:] data = data[6:]
if !strings.HasPrefix(data, "[DONE]") { if !strings.HasPrefix(data, "[DONE]") {
@@ -243,6 +256,7 @@ func relayHelper(c *gin.Context) error {
c.Writer.Header().Set("Cache-Control", "no-cache") c.Writer.Header().Set("Cache-Control", "no-cache")
c.Writer.Header().Set("Connection", "keep-alive") c.Writer.Header().Set("Connection", "keep-alive")
c.Writer.Header().Set("Transfer-Encoding", "chunked") c.Writer.Header().Set("Transfer-Encoding", "chunked")
c.Writer.Header().Set("X-Accel-Buffering", "no")
c.Stream(func(w io.Writer) bool { c.Stream(func(w io.Writer) bool {
select { select {
case data := <-dataChan: case data := <-dataChan:
@@ -257,50 +271,60 @@ func relayHelper(c *gin.Context) error {
}) })
err = resp.Body.Close() err = resp.Body.Close()
if err != nil { if err != nil {
return err return errorWrapper(err, "close_response_body_failed", http.StatusOK)
} }
return nil return nil
} else { } else {
for k, v := range resp.Header {
c.Writer.Header().Set(k, v[0])
}
if consumeQuota { if consumeQuota {
responseBody, err := io.ReadAll(resp.Body) responseBody, err := io.ReadAll(resp.Body)
if err != nil { if err != nil {
return err return errorWrapper(err, "read_response_body_failed", http.StatusOK)
} }
err = resp.Body.Close() err = resp.Body.Close()
if err != nil { if err != nil {
return err return errorWrapper(err, "close_response_body_failed", http.StatusOK)
} }
err = json.Unmarshal(responseBody, &textResponse) err = json.Unmarshal(responseBody, &textResponse)
if err != nil { if err != nil {
return err return errorWrapper(err, "unmarshal_response_body_failed", http.StatusOK)
} }
if textResponse.Error.Type != "" { if textResponse.Error.Type != "" {
return errors.New(fmt.Sprintf("type %s, code %s, message %s", return &OpenAIErrorWithStatusCode{
textResponse.Error.Type, textResponse.Error.Code, textResponse.Error.Message)) OpenAIError: textResponse.Error,
StatusCode: resp.StatusCode,
}
} }
// Reset response body // Reset response body
resp.Body = io.NopCloser(bytes.NewBuffer(responseBody)) resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
} }
// We shouldn't set the header before we parse the response body, because the parse part may fail.
// And then we will have to send an error response, but in this case, the header has already been set.
// So the client will be confused by the response.
// For example, Postman will report error, and we cannot check the response at all.
for k, v := range resp.Header {
c.Writer.Header().Set(k, v[0])
}
c.Writer.WriteHeader(resp.StatusCode)
_, err = io.Copy(c.Writer, resp.Body) _, err = io.Copy(c.Writer, resp.Body)
if err != nil { if err != nil {
return err return errorWrapper(err, "copy_response_body_failed", http.StatusOK)
} }
err = resp.Body.Close() err = resp.Body.Close()
if err != nil { if err != nil {
return err return errorWrapper(err, "close_response_body_failed", http.StatusOK)
} }
return nil return nil
} }
} }
func RelayNotImplemented(c *gin.Context) { func RelayNotImplemented(c *gin.Context) {
err := OpenAIError{
Message: "API not implemented",
Type: "one_api_error",
Param: "",
Code: "api_not_implemented",
}
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"error": gin.H{ "error": err,
"message": "Not Implemented",
"type": "one_api_error",
},
}) })
} }

View File

@@ -47,6 +47,13 @@ func main() {
// Initialize options // Initialize options
model.InitOptionMap() model.InitOptionMap()
if os.Getenv("SYNC_FREQUENCY") != "" {
frequency, err := strconv.Atoi(os.Getenv("SYNC_FREQUENCY"))
if err != nil {
common.FatalLog(err)
}
go model.SyncOptions(frequency)
}
// Initialize HTTP server // Initialize HTTP server
server := gin.Default() server := gin.Default()

View File

@@ -26,6 +26,7 @@ func createRootAccountIfNeed() error {
Status: common.UserStatusEnabled, Status: common.UserStatusEnabled,
DisplayName: "Root User", DisplayName: "Root User",
AccessToken: common.GetUUID(), AccessToken: common.GetUUID(),
Quota: 100000000,
} }
DB.Create(&rootUser) DB.Create(&rootUser)
} }

View File

@@ -4,6 +4,7 @@ import (
"one-api/common" "one-api/common"
"strconv" "strconv"
"strings" "strings"
"time"
) )
type Option struct { type Option struct {
@@ -59,6 +60,10 @@ func InitOptionMap() {
common.OptionMap["ModelRatio"] = common.ModelRatio2JSONString() common.OptionMap["ModelRatio"] = common.ModelRatio2JSONString()
common.OptionMap["TopUpLink"] = common.TopUpLink common.OptionMap["TopUpLink"] = common.TopUpLink
common.OptionMapRWMutex.Unlock() common.OptionMapRWMutex.Unlock()
loadOptionsFromDatabase()
}
func loadOptionsFromDatabase() {
options, _ := AllOption() options, _ := AllOption()
for _, option := range options { for _, option := range options {
err := updateOptionMap(option.Key, option.Value) err := updateOptionMap(option.Key, option.Value)
@@ -68,6 +73,14 @@ func InitOptionMap() {
} }
} }
func SyncOptions(frequency int) {
for {
time.Sleep(time.Duration(frequency) * time.Second)
common.SysLog("Syncing options from database")
loadOptionsFromDatabase()
}
}
func UpdateOption(key string, value string) error { func UpdateOption(key string, value string) error {
// Save to database first // Save to database first
option := Option{ option := Option{

View File

@@ -19,8 +19,7 @@ type User struct {
Email string `json:"email" gorm:"index" validate:"max=50"` Email string `json:"email" gorm:"index" validate:"max=50"`
GitHubId string `json:"github_id" gorm:"column:github_id;index"` GitHubId string `json:"github_id" gorm:"column:github_id;index"`
WeChatId string `json:"wechat_id" gorm:"column:wechat_id;index"` WeChatId string `json:"wechat_id" gorm:"column:wechat_id;index"`
VerificationCode string `json:"verification_code" gorm:"-:all"` // this field is only for Email verification, don't save it to database! VerificationCode string `json:"verification_code" gorm:"-:all"` // this field is only for Email verification, don't save it to database!
Balance int `json:"balance" gorm:"type:int;default:0"`
AccessToken string `json:"access_token" gorm:"type:char(32);column:access_token;uniqueIndex"` // this token is for system management AccessToken string `json:"access_token" gorm:"type:char(32);column:access_token;uniqueIndex"` // this token is for system management
Quota int `json:"quota" gorm:"type:int;default:0"` Quota int `json:"quota" gorm:"type:int;default:0"`
} }

View File

@@ -2,12 +2,24 @@ package router
import ( import (
"embed" "embed"
"fmt"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"net/http"
"os"
"strings"
) )
func SetRouter(router *gin.Engine, buildFS embed.FS, indexPage []byte) { func SetRouter(router *gin.Engine, buildFS embed.FS, indexPage []byte) {
SetApiRouter(router) SetApiRouter(router)
SetDashboardRouter(router) SetDashboardRouter(router)
SetRelayRouter(router) SetRelayRouter(router)
setWebRouter(router, buildFS, indexPage) frontendBaseUrl := os.Getenv("FRONTEND_BASE_URL")
if frontendBaseUrl == "" {
SetWebRouter(router, buildFS, indexPage)
} else {
frontendBaseUrl = strings.TrimSuffix(frontendBaseUrl, "/")
router.NoRoute(func(c *gin.Context) {
c.Redirect(http.StatusMovedPermanently, fmt.Sprintf("%s%s", frontendBaseUrl, c.Request.RequestURI))
})
}
} }

View File

@@ -10,7 +10,7 @@ import (
"one-api/middleware" "one-api/middleware"
) )
func setWebRouter(router *gin.Engine, buildFS embed.FS, indexPage []byte) { func SetWebRouter(router *gin.Engine, buildFS embed.FS, indexPage []byte) {
router.Use(gzip.Gzip(gzip.DefaultCompression)) router.Use(gzip.Gzip(gzip.DefaultCompression))
router.Use(middleware.GlobalWebRateLimit()) router.Use(middleware.GlobalWebRateLimit())
router.Use(middleware.Cache()) router.Use(middleware.Cache())

View File

@@ -1,10 +1,12 @@
export const CHANNEL_OPTIONS = [ export const CHANNEL_OPTIONS = [
{ key: 1, text: 'OpenAI', value: 1, color: 'green' }, { key: 1, text: 'OpenAI', value: 1, color: 'green' },
{ key: 2, text: 'API2D', value: 2, color: 'blue' }, { key: 8, text: '自定义', value: 8, color: 'pink' },
{ key: 3, text: 'Azure', value: 3, color: 'olive' }, { key: 3, text: 'Azure', value: 3, color: 'olive' },
{ key: 2, text: 'API2D', value: 2, color: 'blue' },
{ key: 4, text: 'CloseAI', value: 4, color: 'teal' }, { key: 4, text: 'CloseAI', value: 4, color: 'teal' },
{ key: 5, text: 'OpenAI-SB', value: 5, color: 'brown' }, { key: 5, text: 'OpenAI-SB', value: 5, color: 'brown' },
{ key: 6, text: 'OpenAI Max', value: 6, color: 'violet' }, { key: 6, text: 'OpenAI Max', value: 6, color: 'violet' },
{ key: 7, text: 'OhMyGPT', value: 7, color: 'purple' }, { key: 7, text: 'OhMyGPT', value: 7, color: 'purple' },
{ key: 8, text: '自定义', value: 8, color: 'pink' } { key: 9, text: 'AI.LS', value: 9, color: 'yellow' },
{ key: 10, text: 'AI Proxy', value: 10, color: 'purple' }
]; ];

View File

@@ -46,6 +46,9 @@ const EditChannel = () => {
if (localInputs.base_url.endsWith('/')) { if (localInputs.base_url.endsWith('/')) {
localInputs.base_url = localInputs.base_url.slice(0, localInputs.base_url.length - 1); localInputs.base_url = localInputs.base_url.slice(0, localInputs.base_url.length - 1);
} }
if (localInputs.type === 3 && localInputs.other === '') {
localInputs.other = '2023-03-15-preview';
}
let res; let res;
if (isEdit) { if (isEdit) {
res = await API.put(`/api/channel/`, { ...localInputs, id: parseInt(channelId) }); res = await API.put(`/api/channel/`, { ...localInputs, id: parseInt(channelId) });

View File

@@ -1,6 +1,6 @@
import React, { useEffect, useState } from 'react'; import React, { useEffect, useState } from 'react';
import { Button, Form, Grid, Header, Segment, Statistic } from 'semantic-ui-react'; import { Button, Form, Grid, Header, Segment, Statistic } from 'semantic-ui-react';
import { API, showError, showSuccess } from '../../helpers'; import { API, showError, showInfo, showSuccess } from '../../helpers';
const TopUp = () => { const TopUp = () => {
const [redemptionCode, setRedemptionCode] = useState(''); const [redemptionCode, setRedemptionCode] = useState('');
@@ -9,6 +9,7 @@ const TopUp = () => {
const topUp = async () => { const topUp = async () => {
if (redemptionCode === '') { if (redemptionCode === '') {
showInfo('请输入充值码!')
return; return;
} }
const res = await API.post('/api/user/topup', { const res = await API.post('/api/user/topup', {
@@ -80,7 +81,7 @@ const TopUp = () => {
<Grid.Column> <Grid.Column>
<Statistic.Group widths='one'> <Statistic.Group widths='one'>
<Statistic> <Statistic>
<Statistic.Value>{userQuota}</Statistic.Value> <Statistic.Value>{userQuota.toLocaleString()}</Statistic.Value>
<Statistic.Label>剩余额度</Statistic.Label> <Statistic.Label>剩余额度</Statistic.Label>
</Statistic> </Statistic>
</Statistic.Group> </Statistic.Group>

View File

@@ -14,8 +14,9 @@ const EditUser = () => {
github_id: '', github_id: '',
wechat_id: '', wechat_id: '',
email: '', email: '',
quota: 0,
}); });
const { username, display_name, password, github_id, wechat_id, email } = const { username, display_name, password, github_id, wechat_id, email, quota } =
inputs; inputs;
const handleInputChange = (e, { name, value }) => { const handleInputChange = (e, { name, value }) => {
setInputs((inputs) => ({ ...inputs, [name]: value })); setInputs((inputs) => ({ ...inputs, [name]: value }));
@@ -44,7 +45,11 @@ const EditUser = () => {
const submit = async () => { const submit = async () => {
let res = undefined; let res = undefined;
if (userId) { if (userId) {
res = await API.put(`/api/user/`, { ...inputs, id: parseInt(userId) }); let data = { ...inputs, id: parseInt(userId) };
if (typeof data.quota === 'string') {
data.quota = parseInt(data.quota);
}
res = await API.put(`/api/user/`, data);
} else { } else {
res = await API.put(`/api/user/self`, inputs); res = await API.put(`/api/user/self`, inputs);
} }
@@ -92,6 +97,21 @@ const EditUser = () => {
autoComplete='new-password' autoComplete='new-password'
/> />
</Form.Field> </Form.Field>
{
userId && (
<Form.Field>
<Form.Input
label='剩余额度'
name='quota'
placeholder={'请输入新的剩余额度'}
onChange={handleInputChange}
value={quota}
type={'number'}
autoComplete='new-password'
/>
</Form.Field>
)
}
<Form.Field> <Form.Field>
<Form.Input <Form.Input
label='已绑定的 GitHub 账户' label='已绑定的 GitHub 账户'