Compare commits

...

10 Commits

Author SHA1 Message Date
JustSong
a1f61384c5 feat: automatically disable channel when error occurred (#59) 2023-05-15 17:34:09 +08:00
JustSong
44ebae1559 feat: add refresh button 2023-05-15 16:20:01 +08:00
JustSong
aae92683d7 fix: fix lock is not working 2023-05-15 16:19:39 +08:00
JustSong
cc3072c4df fix: remove version suffix for Azure (close #67) 2023-05-15 15:48:18 +08:00
JustSong
bffee4e91d fix: fix /v1/models not working (close #66) 2023-05-15 15:33:34 +08:00
JustSong
79dc53ff0d ci: build arm version 2023-05-15 15:14:33 +08:00
JustSong
68e53d3e10 chore: only show two digits 2023-05-15 12:56:28 +08:00
JustSong
d267211ee7 feat: able to test all enabled channels (#59) 2023-05-15 12:36:55 +08:00
JustSong
570b3bc71c ci: remove arm64 image builder 2023-05-15 11:36:50 +08:00
JustSong
225176aae9 feat: save response time & test time (#59) 2023-05-15 11:35:38 +08:00
12 changed files with 244 additions and 28 deletions

View File

@@ -52,6 +52,11 @@ var TurnstileSecretKey = ""
var QuotaForNewUser = 100
var ChannelDisableThreshold = 5.0
var AutomaticDisableChannelEnabled = false
var RootUserEmail = ""
const (
RoleGuestUser = 0
RoleCommonUser = 1

View File

@@ -11,6 +11,7 @@ import (
"one-api/model"
"strconv"
"strings"
"sync"
"time"
)
@@ -19,7 +20,7 @@ func GetAllChannels(c *gin.Context) {
if p < 0 {
p = 0
}
channels, err := model.GetAllChannels(p*common.ItemsPerPage, common.ItemsPerPage)
channels, err := model.GetAllChannels(p*common.ItemsPerPage, common.ItemsPerPage, false)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
@@ -89,7 +90,6 @@ func AddChannel(c *gin.Context) {
return
}
channel.CreatedTime = common.GetTimestamp()
channel.AccessedTime = common.GetTimestamp()
keys := strings.Split(channel.Key, "\n")
channels := make([]model.Channel, 0)
for _, key := range keys {
@@ -207,6 +207,19 @@ func testChannel(channel *model.Channel, request *ChatRequest) error {
return nil
}
func buildTestRequest(c *gin.Context) *ChatRequest {
model_ := c.Query("model")
testRequest := &ChatRequest{
Model: model_,
}
testMessage := Message{
Role: "user",
Content: "echo hi",
}
testRequest.Messages = append(testRequest.Messages, testMessage)
return testRequest
}
func TestChannel(c *gin.Context) {
id, err := strconv.Atoi(c.Param("id"))
if err != nil {
@@ -224,19 +237,13 @@ func TestChannel(c *gin.Context) {
})
return
}
model_ := c.Query("model")
chatRequest := &ChatRequest{
Model: model_,
}
testMessage := Message{
Role: "user",
Content: "echo hi",
}
chatRequest.Messages = append(chatRequest.Messages, testMessage)
testRequest := buildTestRequest(c)
tik := time.Now()
err = testChannel(channel, chatRequest)
err = testChannel(channel, testRequest)
tok := time.Now()
consumedTime := float64(tok.Sub(tik).Milliseconds()) / 1000.0
milliseconds := tok.Sub(tik).Milliseconds()
go channel.UpdateResponseTime(milliseconds)
consumedTime := float64(milliseconds) / 1000.0
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
@@ -252,3 +259,85 @@ func TestChannel(c *gin.Context) {
})
return
}
var testAllChannelsLock sync.Mutex
var testAllChannelsRunning bool = false
// disable & notify
func disableChannel(channelId int, channelName string, err error) {
if common.RootUserEmail == "" {
common.RootUserEmail = model.GetRootUserEmail()
}
model.UpdateChannelStatusById(channelId, common.ChannelStatusDisabled)
subject := fmt.Sprintf("通道「%s」#%d已被禁用", channelName, channelId)
content := fmt.Sprintf("通道「%s」#%d已被禁用原因%s", channelName, channelId, err.Error())
err = common.SendEmail(subject, common.RootUserEmail, content)
if err != nil {
common.SysError(fmt.Sprintf("发送邮件失败:%s", err.Error()))
}
}
func testAllChannels(c *gin.Context) error {
testAllChannelsLock.Lock()
if testAllChannelsRunning {
testAllChannelsLock.Unlock()
return errors.New("测试已在运行中")
}
testAllChannelsRunning = true
testAllChannelsLock.Unlock()
channels, err := model.GetAllChannels(0, 0, true)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return err
}
testRequest := buildTestRequest(c)
var disableThreshold = int64(common.ChannelDisableThreshold * 1000)
if disableThreshold == 0 {
disableThreshold = 10000000 // a impossible value
}
go func() {
for _, channel := range channels {
if channel.Status != common.ChannelStatusEnabled {
continue
}
tik := time.Now()
err := testChannel(channel, testRequest)
tok := time.Now()
milliseconds := tok.Sub(tik).Milliseconds()
if err != nil || milliseconds > disableThreshold {
if milliseconds > disableThreshold {
err = errors.New(fmt.Sprintf("响应时间 %.2fs 超过阈值 %.2fs", float64(milliseconds)/1000.0, float64(disableThreshold)/1000.0))
}
disableChannel(channel.Id, channel.Name, err)
}
channel.UpdateResponseTime(milliseconds)
}
err := common.SendEmail("通道测试完成", common.RootUserEmail, "通道测试完成,如果没有收到禁用通知,说明所有通道都正常")
if err != nil {
common.SysError(fmt.Sprintf("发送邮件失败:%s", err.Error()))
}
testAllChannelsLock.Lock()
testAllChannelsRunning = false
testAllChannelsLock.Unlock()
}()
return nil
}
func TestAllChannels(c *gin.Context) {
err := testAllChannels(c)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
})
return
}

View File

@@ -4,6 +4,7 @@ import (
"bufio"
"bytes"
"encoding/json"
"errors"
"fmt"
"github.com/gin-gonic/gin"
"github.com/pkoukk/tiktoken-go"
@@ -74,6 +75,11 @@ func Relay(c *gin.Context) {
"type": "one_api_error",
},
})
if common.AutomaticDisableChannelEnabled {
channelId := c.GetInt("channel_id")
channelName := c.GetString("channel_name")
disableChannel(channelId, channelName, err)
}
}
}
@@ -117,6 +123,9 @@ func relayHelper(c *gin.Context) error {
task := strings.TrimPrefix(requestURL, "/v1/")
model_ := textRequest.Model
model_ = strings.Replace(model_, ".", "", -1)
// https://github.com/songquanpeng/one-api/issues/67
model_ = strings.TrimSuffix(model_, "-0301")
model_ = strings.TrimSuffix(model_, "-0314")
fullRequestURL = fmt.Sprintf("%s/openai/deployments/%s/%s", baseURL, model_, task)
}
req, err := http.NewRequest(c.Request.Method, fullRequestURL, c.Request.Body)
@@ -253,6 +262,10 @@ func relayHelper(c *gin.Context) error {
if err != nil {
return err
}
if textResponse.Error.Type != "" {
return errors.New(fmt.Sprintf("type %s, code %s, message %s",
textResponse.Error.Type, textResponse.Error.Code, textResponse.Error.Message))
}
// Reset response body
resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
}

View File

@@ -112,7 +112,7 @@ func TokenAuth() func(c *gin.Context) {
c.Set("token_id", token.Id)
requestURL := c.Request.URL.String()
consumeQuota := !token.UnlimitedQuota
if strings.HasPrefix(requestURL, "/models") {
if strings.HasPrefix(requestURL, "/v1/models") {
consumeQuota = false
}
c.Set("consume_quota", consumeQuota)

View File

@@ -62,6 +62,8 @@ func Distribute() func(c *gin.Context) {
}
}
c.Set("channel", channel.Type)
c.Set("channel_id", channel.Id)
c.Set("channel_name", channel.Name)
c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key))
if channel.Type == common.ChannelTypeCustom || channel.Type == common.ChannelTypeAzure {
c.Set("base_url", channel.BaseURL)

View File

@@ -13,15 +13,20 @@ type Channel struct {
Name string `json:"name" gorm:"index"`
Weight int `json:"weight"`
CreatedTime int64 `json:"created_time" gorm:"bigint"`
AccessedTime int64 `json:"accessed_time" gorm:"bigint"`
TestTime int64 `json:"test_time" gorm:"bigint"`
ResponseTime int `json:"response_time"` // in milliseconds
BaseURL string `json:"base_url" gorm:"column:base_url"`
Other string `json:"other"`
}
func GetAllChannels(startIdx int, num int) ([]*Channel, error) {
func GetAllChannels(startIdx int, num int, selectAll bool) ([]*Channel, error) {
var channels []*Channel
var err error
err = DB.Order("id desc").Limit(num).Offset(startIdx).Omit("key").Find(&channels).Error
if selectAll {
err = DB.Order("id desc").Find(&channels).Error
} else {
err = DB.Order("id desc").Limit(num).Offset(startIdx).Omit("key").Find(&channels).Error
}
return channels, err
}
@@ -71,8 +76,25 @@ func (channel *Channel) Update() error {
return err
}
func (channel *Channel) UpdateResponseTime(responseTime int64) {
err := DB.Model(channel).Select("response_time", "test_time").Updates(Channel{
TestTime: common.GetTimestamp(),
ResponseTime: int(responseTime),
}).Error
if err != nil {
common.SysError("failed to update response time: " + err.Error())
}
}
func (channel *Channel) Delete() error {
var err error
err = DB.Delete(channel).Error
return err
}
func UpdateChannelStatusById(id int, status int) {
err := DB.Model(&Channel{}).Where("id = ?", id).Update("status", status).Error
if err != nil {
common.SysError("failed to update channel status: " + err.Error())
}
}

View File

@@ -32,6 +32,8 @@ func InitOptionMap() {
common.OptionMap["WeChatAuthEnabled"] = strconv.FormatBool(common.WeChatAuthEnabled)
common.OptionMap["TurnstileCheckEnabled"] = strconv.FormatBool(common.TurnstileCheckEnabled)
common.OptionMap["RegisterEnabled"] = strconv.FormatBool(common.RegisterEnabled)
common.OptionMap["AutomaticDisableChannelEnabled"] = strconv.FormatBool(common.AutomaticDisableChannelEnabled)
common.OptionMap["ChannelDisableThreshold"] = strconv.FormatFloat(common.ChannelDisableThreshold, 'f', -1, 64)
common.OptionMap["SMTPServer"] = ""
common.OptionMap["SMTPFrom"] = ""
common.OptionMap["SMTPPort"] = strconv.Itoa(common.SMTPPort)
@@ -114,6 +116,8 @@ func updateOptionMap(key string, value string) (err error) {
common.TurnstileCheckEnabled = boolValue
case "RegisterEnabled":
common.RegisterEnabled = boolValue
case "AutomaticDisableChannelEnabled":
common.AutomaticDisableChannelEnabled = boolValue
}
}
switch key {
@@ -156,6 +160,8 @@ func updateOptionMap(key string, value string) (err error) {
err = common.UpdateModelRatioByJSONString(value)
case "TopUpLink":
common.TopUpLink = value
case "ChannelDisableThreshold":
common.ChannelDisableThreshold, _ = strconv.ParseFloat(value, 64)
}
return err
}

View File

@@ -234,3 +234,8 @@ func DecreaseUserQuota(id int, quota int) (err error) {
err = DB.Model(&User{}).Where("id = ?", id).Update("quota", gorm.Expr("quota - ?", quota)).Error
return err
}
func GetRootUserEmail() (email string) {
DB.Model(&User{}).Where("role = ?", common.RoleRootUser).Select("email").Find(&email)
return email
}

View File

@@ -63,6 +63,7 @@ func SetApiRouter(router *gin.Engine) {
channelRoute.GET("/", controller.GetAllChannels)
channelRoute.GET("/search", controller.SearchChannels)
channelRoute.GET("/:id", controller.GetChannel)
channelRoute.GET("/test", controller.TestAllChannels)
channelRoute.GET("/test/:id", controller.TestChannel)
channelRoute.POST("/", controller.AddChannel)
channelRoute.PUT("/", controller.UpdateChannel)

View File

@@ -1,7 +1,7 @@
import React, { useEffect, useState } from 'react';
import { Button, Form, Label, Pagination, Popup, Table } from 'semantic-ui-react';
import { Link } from 'react-router-dom';
import { API, copy, showError, showInfo, showSuccess, timestamp2string } from '../helpers';
import { API, showError, showInfo, showSuccess, timestamp2string } from '../helpers';
import { CHANNEL_OPTIONS, ITEMS_PER_PAGE } from '../constants';
@@ -60,6 +60,11 @@ const ChannelsTable = () => {
})();
};
const refresh = async () => {
setLoading(true);
await loadChannels(0);
}
useEffect(() => {
loadChannels(0)
.then()
@@ -120,6 +125,22 @@ const ChannelsTable = () => {
}
};
const renderResponseTime = (responseTime) => {
let time = responseTime / 1000;
time = time.toFixed(2) + " 秒";
if (responseTime === 0) {
return <Label basic color='grey'>未测试</Label>;
} else if (responseTime <= 1000) {
return <Label basic color='green'>{time}</Label>;
} else if (responseTime <= 3000) {
return <Label basic color='olive'>{time}</Label>;
} else if (responseTime <= 5000) {
return <Label basic color='yellow'>{time}</Label>;
} else {
return <Label basic color='red'>{time}</Label>;
}
};
const searchChannels = async () => {
if (searchKeyword === '') {
// if keyword is blank, load files instead.
@@ -139,11 +160,26 @@ const ChannelsTable = () => {
setSearching(false);
};
const testChannel = async (id, name) => {
const testChannel = async (id, name, idx) => {
const res = await API.get(`/api/channel/test/${id}/`);
const { success, message, time } = res.data;
if (success) {
showInfo(`通道 ${name} 测试成功,耗时 ${time} 秒。`);
let newChannels = [...channels];
let realIdx = (activePage - 1) * ITEMS_PER_PAGE + idx;
newChannels[realIdx].response_time = time * 1000;
newChannels[realIdx].test_time = Date.now() / 1000;
setChannels(newChannels);
showInfo(`通道 ${name} 测试成功,耗时 ${time.toFixed(2)} 秒。`);
} else {
showError(message);
}
};
const testAllChannels = async () => {
const res = await API.get(`/api/channel/test`);
const { success, message } = res.data;
if (success) {
showInfo("已成功开始测试所有已启用通道,请刷新页面查看结果。");
} else {
showError(message);
}
@@ -219,18 +255,18 @@ const ChannelsTable = () => {
<Table.HeaderCell
style={{ cursor: 'pointer' }}
onClick={() => {
sortChannel('created_time');
sortChannel('response_time');
}}
>
创建时间
响应时间
</Table.HeaderCell>
<Table.HeaderCell
style={{ cursor: 'pointer' }}
onClick={() => {
sortChannel('accessed_time');
sortChannel('test_time');
}}
>
访问时间
测试时间
</Table.HeaderCell>
<Table.HeaderCell>操作</Table.HeaderCell>
</Table.Row>
@@ -250,15 +286,15 @@ const ChannelsTable = () => {
<Table.Cell>{channel.name ? channel.name : '无'}</Table.Cell>
<Table.Cell>{renderType(channel.type)}</Table.Cell>
<Table.Cell>{renderStatus(channel.status)}</Table.Cell>
<Table.Cell>{renderTimestamp(channel.created_time)}</Table.Cell>
<Table.Cell>{renderTimestamp(channel.accessed_time)}</Table.Cell>
<Table.Cell>{renderResponseTime(channel.response_time)}</Table.Cell>
<Table.Cell>{channel.test_time ? renderTimestamp(channel.test_time) : "未测试"}</Table.Cell>
<Table.Cell>
<div>
<Button
size={'small'}
positive
onClick={() => {
testChannel(channel.id, channel.name);
testChannel(channel.id, channel.name, idx);
}}
>
测试
@@ -314,6 +350,9 @@ const ChannelsTable = () => {
<Button size='small' as={Link} to='/channel/add' loading={loading}>
添加新的渠道
</Button>
<Button size='small' loading={loading} onClick={testAllChannels}>
测试所有已启用通道
</Button>
<Pagination
floated='right'
activePage={activePage}
@@ -325,6 +364,7 @@ const ChannelsTable = () => {
(channels.length % ITEMS_PER_PAGE === 0 ? 1 : 0)
}
/>
<Button size='small' onClick={refresh} loading={loading}>刷新</Button>
</Table.HeaderCell>
</Table.Row>
</Table.Footer>

View File

@@ -28,7 +28,9 @@ const SystemSetting = () => {
RegisterEnabled: '',
QuotaForNewUser: 0,
ModelRatio: '',
TopUpLink: ''
TopUpLink: '',
AutomaticDisableChannelEnabled: '',
ChannelDisableThreshold: 0,
});
let originInputs = {};
let [loading, setLoading] = useState(false);
@@ -62,6 +64,7 @@ const SystemSetting = () => {
case 'WeChatAuthEnabled':
case 'TurnstileCheckEnabled':
case 'RegisterEnabled':
case 'AutomaticDisableChannelEnabled':
value = inputs[key] === 'true' ? 'false' : 'true';
break;
default:
@@ -298,6 +301,30 @@ const SystemSetting = () => {
</Form.Group>
<Form.Button onClick={submitOperationConfig}>保存运营设置</Form.Button>
<Divider />
<Header as='h3'>
监控设置
</Header>
<Form.Group widths={3}>
<Form.Input
label='最长回应时间'
name='ChannelDisableThreshold'
onChange={handleInputChange}
autoComplete='new-password'
value={inputs.ChannelDisableThreshold}
type='number'
min='0'
placeholder='单位秒,当运行通道全部测试时,超过此时间将自动禁用通道'
/>
</Form.Group>
<Form.Group inline>
<Form.Checkbox
checked={inputs.AutomaticDisableChannelEnabled === 'true'}
label='失败时自动禁用通道'
name='AutomaticDisableChannelEnabled'
onChange={handleInputChange}
/>
</Form.Group>
<Divider />
<Header as='h3'>
配置 SMTP
<Header.Subheader>用以支持系统的邮件发送</Header.Subheader>

View File

@@ -66,6 +66,11 @@ const TokensTable = () => {
})();
};
const refresh = async () => {
setLoading(true);
await loadTokens(0);
}
useEffect(() => {
loadTokens(0)
.then()
@@ -334,6 +339,7 @@ const TokensTable = () => {
<Button size='small' as={Link} to='/token/add' loading={loading}>
添加新的令牌
</Button>
<Button size='small' onClick={refresh} loading={loading}>刷新</Button>
<Pagination
floated='right'
activePage={activePage}