Compare commits

..

11 Commits

Author SHA1 Message Date
JustSong
89d458b9cf feat: able to set RELAY_TIMEOUT 2023-10-22 20:39:49 +08:00
JustSong
63fafba112 feat: support ERNIE-Bot-4 (close #608) 2023-10-22 18:48:35 +08:00
Bryan
a398f35968 fix: fix postgresql support (#606)
* fix postgresql support

fixes #517

* fix: fix pg support

* chore: delete useless code

---------

Co-authored-by: JustSong <songquanpeng@foxmail.com>
2023-10-22 18:38:29 +08:00
yiGmMk
57aa637c77 fix: set Accept header if not given (#615)
* fix: fastgpt调用通义千问问答失败

* refactor: Dockerfile

* Revert "refactor: Dockerfile"

This reverts commit a538c4f28e.

* chore: update implementation

---------

Co-authored-by: JustSong <songquanpeng@foxmail.com>
2023-10-22 17:56:20 +08:00
vc
3b483639a4 feat: add cloudflare ai gateway support for image & audio (#607)
* Update channel-test.go

* Update relay-audio.go

* Update relay-image.go

* chore: using a util function

---------

Co-authored-by: JustSong <songquanpeng@foxmail.com>
2023-10-22 17:50:52 +08:00
subnew
22980b4c44 docs: add description for TIKTOKEN_CACHE_DIR (#612)
* Update README.md

* Update README.md
2023-10-22 17:31:27 +08:00
Pluto
64cdb7eafb fix: docker compose healthcheck failed (#593) 2023-10-14 21:55:16 -05:00
JustSong
824444244b feat: able to delete all disabled channels 2023-10-14 17:25:48 +08:00
JustSong
fbe9985f57 chore: show prompt to let the user know 2023-10-14 16:32:01 +08:00
JustSong
a27a5bcc06 fix: fix array index not checked (close #588) 2023-10-14 16:11:15 +08:00
JustSong
e28d4b1741 feat: support cloudflare AI gateway now (close #565, #598) 2023-10-14 15:26:28 +08:00
24 changed files with 145 additions and 55 deletions

View File

@@ -95,12 +95,13 @@ _✨ 通过标准的 OpenAI API 格式访问所有的大模型,开箱即用
15. 支持模型映射,重定向用户的请求模型。
16. 支持失败自动重试。
17. 支持绘图接口。
18. 支持丰富的**自定义**设置,
18. 支持 [Cloudflare AI Gateway](https://developers.cloudflare.com/ai-gateway/providers/openai/),渠道设置的代理部分填写 `https://gateway.ai.cloudflare.com/v1/ACCOUNT_TAG/GATEWAY/openai` 即可。
19. 支持丰富的**自定义**设置,
1. 支持自定义系统名称logo 以及页脚。
2. 支持自定义首页和关于页面,可以选择使用 HTML & Markdown 代码进行自定义,或者使用一个单独的网页通过 iframe 嵌入。
19. 支持通过系统访问令牌访问管理 API。
20. 支持 Cloudflare Turnstile 用户校验。
21. 支持用户管理,支持**多种用户登录注册方式**
20. 支持通过系统访问令牌访问管理 API。
21. 支持 Cloudflare Turnstile 用户校验。
22. 支持用户管理,支持**多种用户登录注册方式**
+ 邮箱登录注册(支持注册邮箱白名单)以及通过邮箱进行密码重置。
+ [GitHub 开放授权](https://github.com/settings/applications/new)。
+ 微信公众号授权(需要额外部署 [WeChat Server](https://github.com/songquanpeng/wechat-server))。
@@ -351,6 +352,10 @@ graph LR
13. 请求频率限制:
+ `GLOBAL_API_RATE_LIMIT`:全局 API 速率限制(除中继请求外),单 ip 三分钟内的最大请求数,默认为 `180`。
+ `GLOBAL_WEB_RATE_LIMIT`:全局 Web 速率限制,单 ip 三分钟内的最大请求数,默认为 `60`。
14. 编码器缓存设置:
+ `TIKTOKEN_CACHE_DIR`:默认程序启动时会联网下载一些通用的词元的编码,如:`gpt-3.5-turbo`,在一些网络环境不稳定,或者离线情况,可能会导致启动有问题,可以配置此目录缓存数据,可迁移到离线环境。
+ `DATA_GYM_CACHE_DIR`:目前该配置作用与 `TIKTOKEN_CACHE_DIR` 一致,但是优先级没有它高。
15. `RELAY_TIMEOUT`:中继超时设置,单位为秒,默认不设置超时时间。
### 命令行参数
1. `--port <port_number>`: 指定服务器监听的端口号,默认为 `3000`。

View File

@@ -21,12 +21,9 @@ var QuotaPerUnit = 500 * 1000.0 // $0.002 / 1K tokens
var DisplayInCurrencyEnabled = true
var DisplayTokenStatEnabled = true
var UsingSQLite = false
// Any options with "Secret", "Token" in its key won't be return by GetOptions
var SessionSecret = uuid.New().String()
var SQLitePath = "one-api.db"
var OptionMap map[string]string
var OptionMapRWMutex sync.RWMutex
@@ -98,6 +95,8 @@ var SyncFrequency = GetOrDefault("SYNC_FREQUENCY", 10*60) // unit is second
var BatchUpdateEnabled = false
var BatchUpdateInterval = GetOrDefault("BATCH_UPDATE_INTERVAL", 5)
var RelayTimeout = GetOrDefault("RELAY_TIMEOUT", 0) // unit is second
const (
RequestIdKey = "X-Oneapi-Request-Id"
)

6
common/database.go Normal file
View File

@@ -0,0 +1,6 @@
package common
var UsingSQLite = false
var UsingPostgreSQL = false
var SQLitePath = "one-api.db"

View File

@@ -46,6 +46,7 @@ var ModelRatio = map[string]float64{
"claude-2": 5.51, // $11.02 / 1M tokens
"ERNIE-Bot": 0.8572, // ¥0.012 / 1k tokens
"ERNIE-Bot-turbo": 0.5715, // ¥0.008 / 1k tokens
"ERNIE-Bot-4": 8.572, // ¥0.12 / 1k tokens
"Embedding-V1": 0.1429, // ¥0.002 / 1k tokens
"PaLM-2": 1,
"chatglm_pro": 0.7143, // ¥0.01 / 1k tokens

View File

@@ -199,3 +199,11 @@ func GetOrDefault(env string, defaultValue int) int {
func MessageWithRequestId(message string, id string) string {
return fmt.Sprintf("%s (request id: %s)", message, id)
}
func String2Int(str string) int {
num, err := strconv.Atoi(str)
if err != nil {
return 0
}
return num
}

View File

@@ -5,13 +5,14 @@ import (
"encoding/json"
"errors"
"fmt"
"github.com/gin-gonic/gin"
"net/http"
"one-api/common"
"one-api/model"
"strconv"
"sync"
"time"
"github.com/gin-gonic/gin"
)
func testChannel(channel *model.Channel, request ChatRequest) (err error, openaiErr *OpenAIError) {

View File

@@ -127,8 +127,8 @@ func DeleteChannel(c *gin.Context) {
return
}
func DeleteManuallyDisabledChannel(c *gin.Context) {
rows, err := model.DeleteChannelByStatus(common.ChannelStatusManuallyDisabled)
func DeleteDisabledChannel(c *gin.Context) {
rows, err := model.DeleteDisabledChannel()
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,

View File

@@ -306,6 +306,15 @@ func init() {
Root: "ERNIE-Bot-turbo",
Parent: nil,
},
{
Id: "ERNIE-Bot-4",
Object: "model",
Created: 1677649963,
OwnedBy: "baidu",
Permission: permission,
Root: "ERNIE-Bot-4",
Parent: nil,
},
{
Id: "Embedding-V1",
Object: "model",

View File

@@ -6,12 +6,11 @@ import (
"encoding/json"
"errors"
"fmt"
"github.com/gin-gonic/gin"
"io"
"net/http"
"one-api/common"
"one-api/model"
"github.com/gin-gonic/gin"
)
func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
@@ -66,12 +65,11 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
baseURL := common.ChannelBaseURLs[channelType]
requestURL := c.Request.URL.String()
if c.GetString("base_url") != "" {
baseURL = c.GetString("base_url")
}
fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL)
fullRequestURL := getFullRequestURL(baseURL, requestURL, channelType)
requestBody := c.Request.Body
req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody)

View File

@@ -6,12 +6,11 @@ import (
"encoding/json"
"errors"
"fmt"
"github.com/gin-gonic/gin"
"io"
"net/http"
"one-api/common"
"one-api/model"
"github.com/gin-gonic/gin"
)
func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
@@ -61,16 +60,12 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
isModelMapped = true
}
}
baseURL := common.ChannelBaseURLs[channelType]
requestURL := c.Request.URL.String()
if c.GetString("base_url") != "" {
baseURL = c.GetString("base_url")
}
fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL)
fullRequestURL := getFullRequestURL(baseURL, requestURL, channelType)
var requestBody io.Reader
if isModelMapped {
jsonStr, err := json.Marshal(imageRequest)

View File

@@ -6,13 +6,14 @@ import (
"encoding/json"
"errors"
"fmt"
"github.com/gin-gonic/gin"
"io"
"net/http"
"one-api/common"
"one-api/model"
"strings"
"time"
"github.com/gin-gonic/gin"
)
const (
@@ -31,7 +32,14 @@ var httpClient *http.Client
var impatientHTTPClient *http.Client
func init() {
httpClient = &http.Client{}
if common.RelayTimeout == 0 {
httpClient = &http.Client{}
} else {
httpClient = &http.Client{
Timeout: time.Duration(common.RelayTimeout) * time.Second,
}
}
impatientHTTPClient = &http.Client{
Timeout: 5 * time.Second,
}
@@ -118,7 +126,7 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
if c.GetString("base_url") != "" {
baseURL = c.GetString("base_url")
}
fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL)
fullRequestURL := getFullRequestURL(baseURL, requestURL, channelType)
switch apiType {
case APITypeOpenAI:
if channelType == common.ChannelTypeAzure {
@@ -151,6 +159,8 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions"
case "ERNIE-Bot-turbo":
fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/eb-instant"
case "ERNIE-Bot-4":
fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro"
case "BLOOMZ-7B":
fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/bloomz_7b1"
case "Embedding-V1":
@@ -361,6 +371,9 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
}
req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
req.Header.Set("Accept", c.Request.Header.Get("Accept"))
if isStream && c.Request.Header.Get("Accept") == "" {
req.Header.Set("Accept", "text/event-stream")
}
//req.Header.Set("Connection", c.Request.Header.Get("Connection"))
resp, err = httpClient.Do(req)
if err != nil {

View File

@@ -176,3 +176,13 @@ func relayErrorHandler(resp *http.Response) (openAIErrorWithStatusCode *OpenAIEr
openAIErrorWithStatusCode.OpenAIError = textResponse.Error
return
}
func getFullRequestURL(baseURL string, requestURL string, channelType int) string {
fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL)
if channelType == common.ChannelTypeOpenAI {
if strings.HasPrefix(baseURL, "https://gateway.ai.cloudflare.com") {
fullRequestURL = fmt.Sprintf("%s%s", baseURL, strings.TrimPrefix(requestURL, "/v1"))
}
}
return fullRequestURL
}

View File

@@ -220,6 +220,9 @@ func xunfeiHandler(c *gin.Context, textRequest GeneralOpenAIRequest, appId strin
for !stop {
select {
case xunfeiResponse = <-dataChan:
if len(xunfeiResponse.Payload.Choices.Text) == 0 {
continue
}
content += xunfeiResponse.Payload.Choices.Text[0].Content
usage.PromptTokens += xunfeiResponse.Payload.Usage.Text.PromptTokens
usage.CompletionTokens += xunfeiResponse.Payload.Usage.Text.CompletionTokens

View File

@@ -23,7 +23,7 @@ services:
depends_on:
- redis
healthcheck:
test: [ "CMD-SHELL", "curl -s http://localhost:3000/api/status | grep -o '\"success\":\\s*true' | awk '{print $2}' | grep 'true'" ]
test: [ "CMD-SHELL", "wget -q -O - http://localhost:3000/api/status | grep -o '\"success\":\\s*true' | awk -F: '{print $2}'" ]
interval: 30s
timeout: 10s
retries: 3

View File

@@ -15,10 +15,17 @@ type Ability struct {
func GetRandomSatisfiedChannel(group string, model string) (*Channel, error) {
ability := Ability{}
groupCol := "`group`"
trueVal := "1"
if common.UsingPostgreSQL {
groupCol = `"group"`
trueVal = "true"
}
var err error = nil
maxPrioritySubQuery := DB.Model(&Ability{}).Select("MAX(priority)").Where("`group` = ? and model = ? and enabled = 1", group, model)
channelQuery := DB.Where("`group` = ? and model = ? and enabled = 1 and priority = (?)", group, model, maxPrioritySubQuery)
if common.UsingSQLite {
maxPrioritySubQuery := DB.Model(&Ability{}).Select("MAX(priority)").Where(groupCol+" = ? and model = ? and enabled = "+trueVal, group, model)
channelQuery := DB.Where(groupCol+" = ? and model = ? and enabled = "+trueVal+" and priority = (?)", group, model, maxPrioritySubQuery)
if common.UsingSQLite || common.UsingPostgreSQL {
err = channelQuery.Order("RANDOM()").First(&ability).Error
} else {
err = channelQuery.Order("RAND()").First(&ability).Error

View File

@@ -21,14 +21,18 @@ var (
)
func CacheGetTokenByKey(key string) (*Token, error) {
keyCol := "`key`"
if common.UsingPostgreSQL {
keyCol = `"key"`
}
var token Token
if !common.RedisEnabled {
err := DB.Where("`key` = ?", key).First(&token).Error
err := DB.Where(keyCol+" = ?", key).First(&token).Error
return &token, err
}
tokenObjectString, err := common.RedisGet(fmt.Sprintf("token:%s", key))
if err != nil {
err := DB.Where("`key` = ?", key).First(&token).Error
err := DB.Where(keyCol+" = ?", key).First(&token).Error
if err != nil {
return nil, err
}

View File

@@ -38,7 +38,11 @@ func GetAllChannels(startIdx int, num int, selectAll bool) ([]*Channel, error) {
}
func SearchChannels(keyword string) (channels []*Channel, err error) {
err = DB.Omit("key").Where("id = ? or name LIKE ? or `key` = ?", keyword, keyword+"%", keyword).Find(&channels).Error
keyCol := "`key`"
if common.UsingPostgreSQL {
keyCol = `"key"`
}
err = DB.Omit("key").Where("id = ? or name LIKE ? or "+keyCol+" = ?", common.String2Int(keyword), keyword+"%", keyword).Find(&channels).Error
return channels, err
}
@@ -53,17 +57,6 @@ func GetChannelById(id int, selectAll bool) (*Channel, error) {
return &channel, err
}
func GetRandomChannel() (*Channel, error) {
channel := Channel{}
var err error = nil
if common.UsingSQLite {
err = DB.Where("status = ? and `group` = ?", common.ChannelStatusEnabled, "default").Order("RANDOM()").Limit(1).First(&channel).Error
} else {
err = DB.Where("status = ? and `group` = ?", common.ChannelStatusEnabled, "default").Order("RAND()").Limit(1).First(&channel).Error
}
return &channel, err
}
func BatchInsertChannels(channels []Channel) error {
var err error
err = DB.Create(&channels).Error
@@ -181,3 +174,8 @@ func DeleteChannelByStatus(status int64) (int64, error) {
result := DB.Where("status = ?", status).Delete(&Channel{})
return result.RowsAffected, result.Error
}
func DeleteDisabledChannel() (int64, error) {
result := DB.Where("status = ? or status = ?", common.ChannelStatusAutoDisabled, common.ChannelStatusManuallyDisabled).Delete(&Channel{})
return result.RowsAffected, result.Error
}

View File

@@ -42,6 +42,7 @@ func chooseDB() (*gorm.DB, error) {
if strings.HasPrefix(dsn, "postgres://") {
// Use PostgreSQL
common.SysLog("using PostgreSQL as database")
common.UsingPostgreSQL = true
return gorm.Open(postgres.New(postgres.Config{
DSN: dsn,
PreferSimpleProtocol: true, // disables implicit prepared statement usage

View File

@@ -50,8 +50,13 @@ func Redeem(key string, userId int) (quota int, err error) {
}
redemption := &Redemption{}
keyCol := "`key`"
if common.UsingPostgreSQL {
keyCol = `"key"`
}
err = DB.Transaction(func(tx *gorm.DB) error {
err := tx.Set("gorm:query_option", "FOR UPDATE").Where("`key` = ?", key).First(redemption).Error
err := tx.Set("gorm:query_option", "FOR UPDATE").Where(keyCol+" = ?", key).First(redemption).Error
if err != nil {
return errors.New("无效的兑换码")
}

View File

@@ -266,7 +266,12 @@ func GetUserEmail(id int) (email string, err error) {
}
func GetUserGroup(id int) (group string, err error) {
err = DB.Model(&User{}).Where("id = ?", id).Select("`group`").Find(&group).Error
groupCol := "`group`"
if common.UsingPostgreSQL {
groupCol = `"group"`
}
err = DB.Model(&User{}).Where("id = ?", id).Select(groupCol).Find(&group).Error
return group, err
}

View File

@@ -74,7 +74,7 @@ func SetApiRouter(router *gin.Engine) {
channelRoute.GET("/update_balance/:id", controller.UpdateChannelBalance)
channelRoute.POST("/", controller.AddChannel)
channelRoute.PUT("/", controller.UpdateChannel)
channelRoute.DELETE("/manually_disabled", controller.DeleteManuallyDisabledChannel)
channelRoute.DELETE("/disabled", controller.DeleteDisabledChannel)
channelRoute.DELETE("/:id", controller.DeleteChannel)
}
tokenRoute := apiRouter.Group("/token")

View File

@@ -1,7 +1,7 @@
import React, { useEffect, useState } from 'react';
import { Button, Form, Input, Label, Pagination, Popup, Table } from 'semantic-ui-react';
import { Button, Form, Input, Label, Message, Pagination, Popup, Table } from 'semantic-ui-react';
import { Link } from 'react-router-dom';
import { API, showError, showInfo, showNotice, showSuccess, timestamp2string } from '../helpers';
import { API, setPromptShown, shouldShowPrompt, showError, showInfo, showSuccess, timestamp2string } from '../helpers';
import { CHANNEL_OPTIONS, ITEMS_PER_PAGE } from '../constants';
import { renderGroup, renderNumber } from '../helpers/render';
@@ -55,6 +55,7 @@ const ChannelsTable = () => {
const [searchKeyword, setSearchKeyword] = useState('');
const [searching, setSearching] = useState(false);
const [updatingBalance, setUpdatingBalance] = useState(false);
const [showPrompt, setShowPrompt] = useState(shouldShowPrompt("channel-test"));
const loadChannels = async (startIdx) => {
const res = await API.get(`/api/channel/?p=${startIdx}`);
@@ -226,7 +227,6 @@ const ChannelsTable = () => {
showInfo(`通道 ${name} 测试成功,耗时 ${time.toFixed(2)} 秒。`);
} else {
showError(message);
showNotice('当前版本测试是通过按照 OpenAI API 格式使用 gpt-3.5-turbo 模型进行非流式请求实现的,因此测试报错并不一定代表通道不可用,该功能后续会修复。');
}
};
@@ -240,11 +240,11 @@ const ChannelsTable = () => {
}
};
const deleteAllManuallyDisabledChannels = async () => {
const res = await API.delete(`/api/channel/manually_disabled`);
const deleteAllDisabledChannels = async () => {
const res = await API.delete(`/api/channel/disabled`);
const { success, message, data } = res.data;
if (success) {
showSuccess(`已删除所有手动禁用渠道,共计 ${data}`);
showSuccess(`已删除所有禁用渠道,共计 ${data}`);
await refresh();
} else {
showError(message);
@@ -317,7 +317,19 @@ const ChannelsTable = () => {
onChange={handleKeywordChange}
/>
</Form>
{
showPrompt && (
<Message onDismiss={() => {
setShowPrompt(false);
setPromptShown("channel-test");
}}>
当前版本测试是通过按照 OpenAI API 格式使用 gpt-3.5-turbo
模型进行非流式请求实现的因此测试报错并不一定代表通道不可用该功能后续会修复
另外OpenAI 渠道已经不再支持通过 key 获取余额因此余额显示为 0对于支持的渠道类型请点击余额进行刷新
</Message>
)
}
<Table basic compact size='small'>
<Table.Header>
<Table.Row>
@@ -519,14 +531,14 @@ const ChannelsTable = () => {
<Popup
trigger={
<Button size='small' loading={loading}>
删除所有手动禁用渠道
删除禁用渠道
</Button>
}
on='click'
flowing
hoverable
>
<Button size='small' loading={loading} negative onClick={deleteAllManuallyDisabledChannels}>
<Button size='small' loading={loading} negative onClick={deleteAllDisabledChannels}>
确认删除
</Button>
</Popup>

View File

@@ -186,4 +186,14 @@ export const verifyJSON = (str) => {
return false;
}
return true;
};
};
export function shouldShowPrompt(id) {
let prompt = localStorage.getItem(`prompt-${id}`);
return !prompt;
}
export function setPromptShown(id) {
localStorage.setItem(`prompt-${id}`, 'true');
}

View File

@@ -66,7 +66,7 @@ const EditChannel = () => {
localModels = ['PaLM-2'];
break;
case 15:
localModels = ['ERNIE-Bot', 'ERNIE-Bot-turbo', 'Embedding-V1'];
localModels = ['ERNIE-Bot', 'ERNIE-Bot-turbo', 'ERNIE-Bot-4', 'Embedding-V1'];
break;
case 17:
localModels = ['qwen-turbo', 'qwen-plus', 'text-embedding-v1'];