Compare commits

..

4 Commits

Author SHA1 Message Date
JustSong
4463224f04 feat: support automatic channel testing & balance updates (close #11, close #59) 2023-06-22 22:01:03 +08:00
JustSong
ad1049b0cf feat: support search channels by key (close #185) 2023-06-22 21:19:43 +08:00
JustSong
d0c454c78e chore: able to clear all models now 2023-06-22 20:53:21 +08:00
JustSong
fe135fd508 chore: update base url setting 2023-06-22 20:49:55 +08:00
8 changed files with 81 additions and 46 deletions

View File

@@ -250,6 +250,12 @@ graph LR
+ 例子:`SYNC_FREQUENCY=60` + 例子:`SYNC_FREQUENCY=60`
6. `NODE_TYPE`:设置之后将指定节点类型,可选值为 `master` 和 `slave`,未设置则默认为 `master`。 6. `NODE_TYPE`:设置之后将指定节点类型,可选值为 `master` 和 `slave`,未设置则默认为 `master`。
+ 例子:`NODE_TYPE=slave` + 例子:`NODE_TYPE=slave`
7. `CHANNEL_UPDATE_FREQUENCY`:设置之后将定期更新渠道余额,单位为分钟,未设置则不进行更新。
+ 例子:`CHANNEL_UPDATE_FREQUENCY=1440`
8. `CHANNEL_TEST_FREQUENCY`:设置之后将定期检查渠道,单位为分钟,未设置则不进行检查。
+ 例子:`CHANNEL_TEST_FREQUENCY=1440`
9. `REQUEST_INTERVAL`:批量更新渠道余额以及测试可用性时的请求间隔,单位为秒,默认无间隔。
+ 例子:`POLLING_INTERVAL=5`
### 命令行参数 ### 命令行参数
1. `--port <port_number>`: 指定服务器监听的端口号,默认为 `3000`。 1. `--port <port_number>`: 指定服务器监听的端口号,默认为 `3000`。

View File

@@ -2,6 +2,7 @@ package common
import ( import (
"os" "os"
"strconv"
"sync" "sync"
"time" "time"
@@ -70,6 +71,9 @@ var RootUserEmail = ""
var IsMasterNode = os.Getenv("NODE_TYPE") != "slave" var IsMasterNode = os.Getenv("NODE_TYPE") != "slave"
var requestInterval, _ = strconv.Atoi(os.Getenv("REQUEST_INTERVAL"))
var RequestInterval = time.Duration(requestInterval) * time.Second
const ( const (
RoleGuestUser = 0 RoleGuestUser = 0
RoleCommonUser = 1 RoleCommonUser = 1

View File

@@ -257,6 +257,7 @@ func updateAllChannelsBalance() error {
disableChannel(channel.Id, channel.Name, "余额不足") disableChannel(channel.Id, channel.Name, "余额不足")
} }
} }
time.Sleep(common.RequestInterval)
} }
return nil return nil
} }
@@ -277,3 +278,12 @@ func UpdateAllChannelsBalance(c *gin.Context) {
}) })
return return
} }
func AutomaticallyUpdateChannels(frequency int) {
for {
time.Sleep(time.Duration(frequency) * time.Minute)
common.SysLog("updating all channels")
_ = updateAllChannelsBalance()
common.SysLog("channels update done")
}
}

View File

@@ -62,10 +62,9 @@ func testChannel(channel *model.Channel, request ChatRequest) error {
return nil return nil
} }
func buildTestRequest(c *gin.Context) *ChatRequest { func buildTestRequest() *ChatRequest {
model_ := c.Query("model")
testRequest := &ChatRequest{ testRequest := &ChatRequest{
Model: model_, Model: "", // this will be set later
MaxTokens: 1, MaxTokens: 1,
} }
testMessage := Message{ testMessage := Message{
@@ -93,7 +92,7 @@ func TestChannel(c *gin.Context) {
}) })
return return
} }
testRequest := buildTestRequest(c) testRequest := buildTestRequest()
tik := time.Now() tik := time.Now()
err = testChannel(channel, *testRequest) err = testChannel(channel, *testRequest)
tok := time.Now() tok := time.Now()
@@ -133,7 +132,7 @@ func disableChannel(channelId int, channelName string, reason string) {
} }
} }
func testAllChannels(c *gin.Context) error { func testAllChannels(notify bool) error {
if common.RootUserEmail == "" { if common.RootUserEmail == "" {
common.RootUserEmail = model.GetRootUserEmail() common.RootUserEmail = model.GetRootUserEmail()
} }
@@ -146,13 +145,9 @@ func testAllChannels(c *gin.Context) error {
testAllChannelsLock.Unlock() testAllChannelsLock.Unlock()
channels, err := model.GetAllChannels(0, 0, true) channels, err := model.GetAllChannels(0, 0, true)
if err != nil { if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return err return err
} }
testRequest := buildTestRequest(c) testRequest := buildTestRequest()
var disableThreshold = int64(common.ChannelDisableThreshold * 1000) var disableThreshold = int64(common.ChannelDisableThreshold * 1000)
if disableThreshold == 0 { if disableThreshold == 0 {
disableThreshold = 10000000 // a impossible value disableThreshold = 10000000 // a impossible value
@@ -173,20 +168,23 @@ func testAllChannels(c *gin.Context) error {
disableChannel(channel.Id, channel.Name, err.Error()) disableChannel(channel.Id, channel.Name, err.Error())
} }
channel.UpdateResponseTime(milliseconds) channel.UpdateResponseTime(milliseconds)
} time.Sleep(common.RequestInterval)
err := common.SendEmail("通道测试完成", common.RootUserEmail, "通道测试完成,如果没有收到禁用通知,说明所有通道都正常")
if err != nil {
common.SysError(fmt.Sprintf("failed to send email: %s", err.Error()))
} }
testAllChannelsLock.Lock() testAllChannelsLock.Lock()
testAllChannelsRunning = false testAllChannelsRunning = false
testAllChannelsLock.Unlock() testAllChannelsLock.Unlock()
if notify {
err := common.SendEmail("通道测试完成", common.RootUserEmail, "通道测试完成,如果没有收到禁用通知,说明所有通道都正常")
if err != nil {
common.SysError(fmt.Sprintf("failed to send email: %s", err.Error()))
}
}
}() }()
return nil return nil
} }
func TestAllChannels(c *gin.Context) { func TestAllChannels(c *gin.Context) {
err := testAllChannels(c) err := testAllChannels(true)
if err != nil { if err != nil {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": false, "success": false,
@@ -200,3 +198,12 @@ func TestAllChannels(c *gin.Context) {
}) })
return return
} }
func AutomaticallyTestChannels(frequency int) {
for {
time.Sleep(time.Duration(frequency) * time.Minute)
common.SysLog("testing all channels")
_ = testAllChannels(false)
common.SysLog("channel test finished")
}
}

15
main.go
View File

@@ -7,6 +7,7 @@ import (
"github.com/gin-contrib/sessions/redis" "github.com/gin-contrib/sessions/redis"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"one-api/common" "one-api/common"
"one-api/controller"
"one-api/middleware" "one-api/middleware"
"one-api/model" "one-api/model"
"one-api/router" "one-api/router"
@@ -59,6 +60,20 @@ func main() {
go model.SyncChannelCache(frequency) go model.SyncChannelCache(frequency)
} }
} }
if os.Getenv("CHANNEL_UPDATE_FREQUENCY") != "" {
frequency, err := strconv.Atoi(os.Getenv("CHANNEL_UPDATE_FREQUENCY"))
if err != nil {
common.FatalLog("failed to parse CHANNEL_UPDATE_FREQUENCY: " + err.Error())
}
go controller.AutomaticallyUpdateChannels(frequency)
}
if os.Getenv("CHANNEL_TEST_FREQUENCY") != "" {
frequency, err := strconv.Atoi(os.Getenv("CHANNEL_TEST_FREQUENCY"))
if err != nil {
common.FatalLog("failed to parse CHANNEL_TEST_FREQUENCY: " + err.Error())
}
go controller.AutomaticallyTestChannels(frequency)
}
// Initialize HTTP server // Initialize HTTP server
server := gin.Default() server := gin.Default()

View File

@@ -8,7 +8,7 @@ import (
type Channel struct { type Channel struct {
Id int `json:"id"` Id int `json:"id"`
Type int `json:"type" gorm:"default:0"` Type int `json:"type" gorm:"default:0"`
Key string `json:"key" gorm:"not null"` Key string `json:"key" gorm:"not null;index"`
Status int `json:"status" gorm:"default:1"` Status int `json:"status" gorm:"default:1"`
Name string `json:"name" gorm:"index"` Name string `json:"name" gorm:"index"`
Weight int `json:"weight"` Weight int `json:"weight"`
@@ -36,7 +36,7 @@ func GetAllChannels(startIdx int, num int, selectAll bool) ([]*Channel, error) {
} }
func SearchChannels(keyword string) (channels []*Channel, err error) { func SearchChannels(keyword string) (channels []*Channel, err error) {
err = DB.Omit("key").Where("id = ? or name LIKE ?", keyword, keyword+"%").Find(&channels).Error err = DB.Omit("key").Where("id = ? or name LIKE ? or key = ?", keyword, keyword+"%", keyword).Find(&channels).Error
return channels, err return channels, err
} }

View File

@@ -263,7 +263,7 @@ const ChannelsTable = () => {
icon='search' icon='search'
fluid fluid
iconPosition='left' iconPosition='left'
placeholder='搜索渠道的 ID 和名称 ...' placeholder='搜索渠道的 ID,名称和密钥 ...'
value={searchKeyword} value={searchKeyword}
loading={searching} loading={searching}
onChange={handleKeywordChange} onChange={handleKeywordChange}

View File

@@ -32,15 +32,15 @@ const EditChannel = () => {
let res = await API.get(`/api/channel/${channelId}`); let res = await API.get(`/api/channel/${channelId}`);
const { success, message, data } = res.data; const { success, message, data } = res.data;
if (success) { if (success) {
if (data.models === "") { if (data.models === '') {
data.models = [] data.models = [];
} else { } else {
data.models = data.models.split(",") data.models = data.models.split(',');
} }
if (data.group === "") { if (data.group === '') {
data.groups = [] data.groups = [];
} else { } else {
data.groups = data.group.split(",") data.groups = data.group.split(',');
} }
setInputs(data); setInputs(data);
} else { } else {
@@ -55,10 +55,10 @@ const EditChannel = () => {
setModelOptions(res.data.data.map((model) => ({ setModelOptions(res.data.data.map((model) => ({
key: model.id, key: model.id,
text: model.id, text: model.id,
value: model.id, value: model.id
}))); })));
setFullModels(res.data.data.map((model) => model.id)); setFullModels(res.data.data.map((model) => model.id));
setBasicModels(res.data.data.filter((model) => !model.id.startsWith("gpt-4")).map((model) => model.id)); setBasicModels(res.data.data.filter((model) => !model.id.startsWith('gpt-4')).map((model) => model.id));
} catch (error) { } catch (error) {
showError(error.message); showError(error.message);
} }
@@ -70,7 +70,7 @@ const EditChannel = () => {
setGroupOptions(res.data.data.map((group) => ({ setGroupOptions(res.data.data.map((group) => ({
key: group, key: group,
text: group, text: group,
value: group, value: group
}))); })));
} catch (error) { } catch (error) {
showError(error.message); showError(error.message);
@@ -90,6 +90,10 @@ const EditChannel = () => {
showInfo('请填写渠道名称和渠道密钥!'); showInfo('请填写渠道名称和渠道密钥!');
return; return;
} }
if (inputs.models.length === 0) {
showInfo('请至少选择一个模型!');
return;
}
let localInputs = inputs; let localInputs = inputs;
if (localInputs.base_url.endsWith('/')) { if (localInputs.base_url.endsWith('/')) {
localInputs.base_url = localInputs.base_url.slice(0, localInputs.base_url.length - 1); localInputs.base_url = localInputs.base_url.slice(0, localInputs.base_url.length - 1);
@@ -98,8 +102,8 @@ const EditChannel = () => {
localInputs.other = '2023-03-15-preview'; localInputs.other = '2023-03-15-preview';
} }
let res; let res;
localInputs.models = localInputs.models.join(",") localInputs.models = localInputs.models.join(',');
localInputs.group = localInputs.groups.join(",") localInputs.group = localInputs.groups.join(',');
if (isEdit) { if (isEdit) {
res = await API.put(`/api/channel/`, { ...localInputs, id: parseInt(channelId) }); res = await API.put(`/api/channel/`, { ...localInputs, id: parseInt(channelId) });
} else { } else {
@@ -181,9 +185,9 @@ const EditChannel = () => {
inputs.type !== 3 && inputs.type !== 8 && ( inputs.type !== 3 && inputs.type !== 8 && (
<Form.Field> <Form.Field>
<Form.Input <Form.Input
label='Base URL' label='镜像'
name='base_url' name='base_url'
placeholder={'请输入自定义 Base URL格式为https://domain.com可不填不填使用渠道默认值'} placeholder={'请输入镜像站地址格式为https://domain.com可不填不填使用渠道默认值'}
onChange={handleInputChange} onChange={handleInputChange}
value={inputs.base_url} value={inputs.base_url}
autoComplete='new-password' autoComplete='new-password'
@@ -231,28 +235,17 @@ const EditChannel = () => {
options={modelOptions} options={modelOptions}
/> />
</Form.Field> </Form.Field>
<div style={{ lineHeight: '40px', marginBottom: '12px'}}> <div style={{ lineHeight: '40px', marginBottom: '12px' }}>
<Button type={'button'} onClick={() => { <Button type={'button'} onClick={() => {
handleInputChange(null, { name: 'models', value: basicModels }); handleInputChange(null, { name: 'models', value: basicModels });
}}>填入基础模型</Button> }}>填入基础模型</Button>
<Button type={'button'} onClick={() => { <Button type={'button'} onClick={() => {
handleInputChange(null, { name: 'models', value: fullModels }); handleInputChange(null, { name: 'models', value: fullModels });
}}>填入所有模型</Button> }}>填入所有模型</Button>
<Button type={'button'} onClick={() => {
handleInputChange(null, { name: 'models', value: [] });
}}>清除所有模型</Button>
</div> </div>
{
inputs.type === 1 && (
<Form.Field>
<Form.Input
label='代理'
name='base_url'
placeholder={'请输入 OpenAI API 代理地址如果不需要请留空格式为https://api.openai.com'}
onChange={handleInputChange}
value={inputs.base_url}
autoComplete='new-password'
/>
</Form.Field>
)
}
{ {
batch ? <Form.Field> batch ? <Form.Field>
<Form.TextArea <Form.TextArea