mirror of
https://github.com/songquanpeng/one-api.git
synced 2025-12-25 17:25:56 +08:00
Compare commits
8 Commits
v0.4.5-alp
...
v0.4.7-alp
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
8a4cd403fd | ||
|
|
9ac5410d06 | ||
|
|
7edc2b5376 | ||
|
|
d4869dfad2 | ||
|
|
4463224f04 | ||
|
|
ad1049b0cf | ||
|
|
d0c454c78e | ||
|
|
fe135fd508 |
@@ -250,6 +250,12 @@ graph LR
|
|||||||
+ 例子:`SYNC_FREQUENCY=60`
|
+ 例子:`SYNC_FREQUENCY=60`
|
||||||
6. `NODE_TYPE`:设置之后将指定节点类型,可选值为 `master` 和 `slave`,未设置则默认为 `master`。
|
6. `NODE_TYPE`:设置之后将指定节点类型,可选值为 `master` 和 `slave`,未设置则默认为 `master`。
|
||||||
+ 例子:`NODE_TYPE=slave`
|
+ 例子:`NODE_TYPE=slave`
|
||||||
|
7. `CHANNEL_UPDATE_FREQUENCY`:设置之后将定期更新渠道余额,单位为分钟,未设置则不进行更新。
|
||||||
|
+ 例子:`CHANNEL_UPDATE_FREQUENCY=1440`
|
||||||
|
8. `CHANNEL_TEST_FREQUENCY`:设置之后将定期检查渠道,单位为分钟,未设置则不进行检查。
|
||||||
|
+ 例子:`CHANNEL_TEST_FREQUENCY=1440`
|
||||||
|
9. `REQUEST_INTERVAL`:批量更新渠道余额以及测试可用性时的请求间隔,单位为秒,默认无间隔。
|
||||||
|
+ 例子:`POLLING_INTERVAL=5`
|
||||||
|
|
||||||
### 命令行参数
|
### 命令行参数
|
||||||
1. `--port <port_number>`: 指定服务器监听的端口号,默认为 `3000`。
|
1. `--port <port_number>`: 指定服务器监听的端口号,默认为 `3000`。
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ package common
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"os"
|
"os"
|
||||||
|
"strconv"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -17,7 +18,8 @@ var Logo = ""
|
|||||||
var TopUpLink = ""
|
var TopUpLink = ""
|
||||||
var ChatLink = ""
|
var ChatLink = ""
|
||||||
var QuotaPerUnit = 500 * 1000.0 // $0.002 / 1K tokens
|
var QuotaPerUnit = 500 * 1000.0 // $0.002 / 1K tokens
|
||||||
var DisplayInCurrencyEnabled = false
|
var DisplayInCurrencyEnabled = true
|
||||||
|
var DisplayTokenStatEnabled = true
|
||||||
|
|
||||||
var UsingSQLite = false
|
var UsingSQLite = false
|
||||||
|
|
||||||
@@ -70,6 +72,9 @@ var RootUserEmail = ""
|
|||||||
|
|
||||||
var IsMasterNode = os.Getenv("NODE_TYPE") != "slave"
|
var IsMasterNode = os.Getenv("NODE_TYPE") != "slave"
|
||||||
|
|
||||||
|
var requestInterval, _ = strconv.Atoi(os.Getenv("REQUEST_INTERVAL"))
|
||||||
|
var RequestInterval = time.Duration(requestInterval) * time.Second
|
||||||
|
|
||||||
const (
|
const (
|
||||||
RoleGuestUser = 0
|
RoleGuestUser = 0
|
||||||
RoleCommonUser = 1
|
RoleCommonUser = 1
|
||||||
|
|||||||
@@ -7,8 +7,17 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func GetSubscription(c *gin.Context) {
|
func GetSubscription(c *gin.Context) {
|
||||||
userId := c.GetInt("id")
|
var quota int
|
||||||
quota, err := model.GetUserQuota(userId)
|
var err error
|
||||||
|
var token *model.Token
|
||||||
|
if common.DisplayTokenStatEnabled {
|
||||||
|
tokenId := c.GetInt("token_id")
|
||||||
|
token, err = model.GetTokenById(tokenId)
|
||||||
|
quota = token.RemainQuota
|
||||||
|
} else {
|
||||||
|
userId := c.GetInt("id")
|
||||||
|
quota, err = model.GetUserQuota(userId)
|
||||||
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
openAIError := OpenAIError{
|
openAIError := OpenAIError{
|
||||||
Message: err.Error(),
|
Message: err.Error(),
|
||||||
@@ -35,8 +44,17 @@ func GetSubscription(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func GetUsage(c *gin.Context) {
|
func GetUsage(c *gin.Context) {
|
||||||
userId := c.GetInt("id")
|
var quota int
|
||||||
quota, err := model.GetUserUsedQuota(userId)
|
var err error
|
||||||
|
var token *model.Token
|
||||||
|
if common.DisplayTokenStatEnabled {
|
||||||
|
tokenId := c.GetInt("token_id")
|
||||||
|
token, err = model.GetTokenById(tokenId)
|
||||||
|
quota = token.UsedQuota
|
||||||
|
} else {
|
||||||
|
userId := c.GetInt("id")
|
||||||
|
quota, err = model.GetUserUsedQuota(userId)
|
||||||
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
openAIError := OpenAIError{
|
openAIError := OpenAIError{
|
||||||
Message: err.Error(),
|
Message: err.Error(),
|
||||||
|
|||||||
@@ -257,6 +257,7 @@ func updateAllChannelsBalance() error {
|
|||||||
disableChannel(channel.Id, channel.Name, "余额不足")
|
disableChannel(channel.Id, channel.Name, "余额不足")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
time.Sleep(common.RequestInterval)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -277,3 +278,12 @@ func UpdateAllChannelsBalance(c *gin.Context) {
|
|||||||
})
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func AutomaticallyUpdateChannels(frequency int) {
|
||||||
|
for {
|
||||||
|
time.Sleep(time.Duration(frequency) * time.Minute)
|
||||||
|
common.SysLog("updating all channels")
|
||||||
|
_ = updateAllChannelsBalance()
|
||||||
|
common.SysLog("channels update done")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -62,10 +62,9 @@ func testChannel(channel *model.Channel, request ChatRequest) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func buildTestRequest(c *gin.Context) *ChatRequest {
|
func buildTestRequest() *ChatRequest {
|
||||||
model_ := c.Query("model")
|
|
||||||
testRequest := &ChatRequest{
|
testRequest := &ChatRequest{
|
||||||
Model: model_,
|
Model: "", // this will be set later
|
||||||
MaxTokens: 1,
|
MaxTokens: 1,
|
||||||
}
|
}
|
||||||
testMessage := Message{
|
testMessage := Message{
|
||||||
@@ -93,7 +92,7 @@ func TestChannel(c *gin.Context) {
|
|||||||
})
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
testRequest := buildTestRequest(c)
|
testRequest := buildTestRequest()
|
||||||
tik := time.Now()
|
tik := time.Now()
|
||||||
err = testChannel(channel, *testRequest)
|
err = testChannel(channel, *testRequest)
|
||||||
tok := time.Now()
|
tok := time.Now()
|
||||||
@@ -133,7 +132,7 @@ func disableChannel(channelId int, channelName string, reason string) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func testAllChannels(c *gin.Context) error {
|
func testAllChannels(notify bool) error {
|
||||||
if common.RootUserEmail == "" {
|
if common.RootUserEmail == "" {
|
||||||
common.RootUserEmail = model.GetRootUserEmail()
|
common.RootUserEmail = model.GetRootUserEmail()
|
||||||
}
|
}
|
||||||
@@ -146,13 +145,9 @@ func testAllChannels(c *gin.Context) error {
|
|||||||
testAllChannelsLock.Unlock()
|
testAllChannelsLock.Unlock()
|
||||||
channels, err := model.GetAllChannels(0, 0, true)
|
channels, err := model.GetAllChannels(0, 0, true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
|
||||||
"success": false,
|
|
||||||
"message": err.Error(),
|
|
||||||
})
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
testRequest := buildTestRequest(c)
|
testRequest := buildTestRequest()
|
||||||
var disableThreshold = int64(common.ChannelDisableThreshold * 1000)
|
var disableThreshold = int64(common.ChannelDisableThreshold * 1000)
|
||||||
if disableThreshold == 0 {
|
if disableThreshold == 0 {
|
||||||
disableThreshold = 10000000 // a impossible value
|
disableThreshold = 10000000 // a impossible value
|
||||||
@@ -173,20 +168,23 @@ func testAllChannels(c *gin.Context) error {
|
|||||||
disableChannel(channel.Id, channel.Name, err.Error())
|
disableChannel(channel.Id, channel.Name, err.Error())
|
||||||
}
|
}
|
||||||
channel.UpdateResponseTime(milliseconds)
|
channel.UpdateResponseTime(milliseconds)
|
||||||
}
|
time.Sleep(common.RequestInterval)
|
||||||
err := common.SendEmail("通道测试完成", common.RootUserEmail, "通道测试完成,如果没有收到禁用通知,说明所有通道都正常")
|
|
||||||
if err != nil {
|
|
||||||
common.SysError(fmt.Sprintf("failed to send email: %s", err.Error()))
|
|
||||||
}
|
}
|
||||||
testAllChannelsLock.Lock()
|
testAllChannelsLock.Lock()
|
||||||
testAllChannelsRunning = false
|
testAllChannelsRunning = false
|
||||||
testAllChannelsLock.Unlock()
|
testAllChannelsLock.Unlock()
|
||||||
|
if notify {
|
||||||
|
err := common.SendEmail("通道测试完成", common.RootUserEmail, "通道测试完成,如果没有收到禁用通知,说明所有通道都正常")
|
||||||
|
if err != nil {
|
||||||
|
common.SysError(fmt.Sprintf("failed to send email: %s", err.Error()))
|
||||||
|
}
|
||||||
|
}
|
||||||
}()
|
}()
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestAllChannels(c *gin.Context) {
|
func TestAllChannels(c *gin.Context) {
|
||||||
err := testAllChannels(c)
|
err := testAllChannels(true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
"success": false,
|
"success": false,
|
||||||
@@ -200,3 +198,12 @@ func TestAllChannels(c *gin.Context) {
|
|||||||
})
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func AutomaticallyTestChannels(frequency int) {
|
||||||
|
for {
|
||||||
|
time.Sleep(time.Duration(frequency) * time.Minute)
|
||||||
|
common.SysLog("testing all channels")
|
||||||
|
_ = testAllChannels(false)
|
||||||
|
common.SysLog("channel test finished")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -13,7 +13,7 @@ func GetOptions(c *gin.Context) {
|
|||||||
var options []*model.Option
|
var options []*model.Option
|
||||||
common.OptionMapRWMutex.Lock()
|
common.OptionMapRWMutex.Lock()
|
||||||
for k, v := range common.OptionMap {
|
for k, v := range common.OptionMap {
|
||||||
if strings.Contains(k, "Token") || strings.Contains(k, "Secret") {
|
if strings.HasSuffix(k, "Token") || strings.HasSuffix(k, "Secret") {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
options = append(options, &model.Option{
|
options = append(options, &model.Option{
|
||||||
|
|||||||
@@ -76,7 +76,7 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
|
|||||||
preConsumedQuota := int(float64(preConsumedTokens) * ratio)
|
preConsumedQuota := int(float64(preConsumedTokens) * ratio)
|
||||||
userQuota, err := model.CacheGetUserQuota(userId)
|
userQuota, err := model.CacheGetUserQuota(userId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errorWrapper(err, "get_user_quota_failed", http.StatusOK)
|
return errorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError)
|
||||||
}
|
}
|
||||||
if userQuota > 10*preConsumedQuota {
|
if userQuota > 10*preConsumedQuota {
|
||||||
// in this case, we do not pre-consume quota
|
// in this case, we do not pre-consume quota
|
||||||
@@ -86,12 +86,12 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
|
|||||||
if consumeQuota && preConsumedQuota > 0 {
|
if consumeQuota && preConsumedQuota > 0 {
|
||||||
err := model.PreConsumeTokenQuota(tokenId, preConsumedQuota)
|
err := model.PreConsumeTokenQuota(tokenId, preConsumedQuota)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errorWrapper(err, "pre_consume_token_quota_failed", http.StatusOK)
|
return errorWrapper(err, "pre_consume_token_quota_failed", http.StatusForbidden)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
req, err := http.NewRequest(c.Request.Method, fullRequestURL, c.Request.Body)
|
req, err := http.NewRequest(c.Request.Method, fullRequestURL, c.Request.Body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errorWrapper(err, "new_request_failed", http.StatusOK)
|
return errorWrapper(err, "new_request_failed", http.StatusInternalServerError)
|
||||||
}
|
}
|
||||||
if channelType == common.ChannelTypeAzure {
|
if channelType == common.ChannelTypeAzure {
|
||||||
key := c.Request.Header.Get("Authorization")
|
key := c.Request.Header.Get("Authorization")
|
||||||
@@ -106,15 +106,15 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
|
|||||||
client := &http.Client{}
|
client := &http.Client{}
|
||||||
resp, err := client.Do(req)
|
resp, err := client.Do(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errorWrapper(err, "do_request_failed", http.StatusOK)
|
return errorWrapper(err, "do_request_failed", http.StatusInternalServerError)
|
||||||
}
|
}
|
||||||
err = req.Body.Close()
|
err = req.Body.Close()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errorWrapper(err, "close_request_body_failed", http.StatusOK)
|
return errorWrapper(err, "close_request_body_failed", http.StatusInternalServerError)
|
||||||
}
|
}
|
||||||
err = c.Request.Body.Close()
|
err = c.Request.Body.Close()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errorWrapper(err, "close_request_body_failed", http.StatusOK)
|
return errorWrapper(err, "close_request_body_failed", http.StatusInternalServerError)
|
||||||
}
|
}
|
||||||
var textResponse TextResponse
|
var textResponse TextResponse
|
||||||
isStream := strings.HasPrefix(resp.Header.Get("Content-Type"), "text/event-stream")
|
isStream := strings.HasPrefix(resp.Header.Get("Content-Type"), "text/event-stream")
|
||||||
@@ -224,22 +224,22 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
|
|||||||
})
|
})
|
||||||
err = resp.Body.Close()
|
err = resp.Body.Close()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errorWrapper(err, "close_response_body_failed", http.StatusOK)
|
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
} else {
|
} else {
|
||||||
if consumeQuota {
|
if consumeQuota {
|
||||||
responseBody, err := io.ReadAll(resp.Body)
|
responseBody, err := io.ReadAll(resp.Body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errorWrapper(err, "read_response_body_failed", http.StatusOK)
|
return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
|
||||||
}
|
}
|
||||||
err = resp.Body.Close()
|
err = resp.Body.Close()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errorWrapper(err, "close_response_body_failed", http.StatusOK)
|
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError)
|
||||||
}
|
}
|
||||||
err = json.Unmarshal(responseBody, &textResponse)
|
err = json.Unmarshal(responseBody, &textResponse)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errorWrapper(err, "unmarshal_response_body_failed", http.StatusOK)
|
return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError)
|
||||||
}
|
}
|
||||||
if textResponse.Error.Type != "" {
|
if textResponse.Error.Type != "" {
|
||||||
return &OpenAIErrorWithStatusCode{
|
return &OpenAIErrorWithStatusCode{
|
||||||
@@ -260,11 +260,11 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
|
|||||||
c.Writer.WriteHeader(resp.StatusCode)
|
c.Writer.WriteHeader(resp.StatusCode)
|
||||||
_, err = io.Copy(c.Writer, resp.Body)
|
_, err = io.Copy(c.Writer, resp.Body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errorWrapper(err, "copy_response_body_failed", http.StatusOK)
|
return errorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError)
|
||||||
}
|
}
|
||||||
err = resp.Body.Close()
|
err = resp.Body.Close()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errorWrapper(err, "close_response_body_failed", http.StatusOK)
|
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -135,7 +135,7 @@ func RelayNotImplemented(c *gin.Context) {
|
|||||||
Param: "",
|
Param: "",
|
||||||
Code: "api_not_implemented",
|
Code: "api_not_implemented",
|
||||||
}
|
}
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusNotImplemented, gin.H{
|
||||||
"error": err,
|
"error": err,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@@ -147,7 +147,7 @@ func RelayNotFound(c *gin.Context) {
|
|||||||
Param: "",
|
Param: "",
|
||||||
Code: "api_not_found",
|
Code: "api_not_found",
|
||||||
}
|
}
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusNotFound, gin.H{
|
||||||
"error": err,
|
"error": err,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|||||||
15
main.go
15
main.go
@@ -7,6 +7,7 @@ import (
|
|||||||
"github.com/gin-contrib/sessions/redis"
|
"github.com/gin-contrib/sessions/redis"
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
|
"one-api/controller"
|
||||||
"one-api/middleware"
|
"one-api/middleware"
|
||||||
"one-api/model"
|
"one-api/model"
|
||||||
"one-api/router"
|
"one-api/router"
|
||||||
@@ -59,6 +60,20 @@ func main() {
|
|||||||
go model.SyncChannelCache(frequency)
|
go model.SyncChannelCache(frequency)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if os.Getenv("CHANNEL_UPDATE_FREQUENCY") != "" {
|
||||||
|
frequency, err := strconv.Atoi(os.Getenv("CHANNEL_UPDATE_FREQUENCY"))
|
||||||
|
if err != nil {
|
||||||
|
common.FatalLog("failed to parse CHANNEL_UPDATE_FREQUENCY: " + err.Error())
|
||||||
|
}
|
||||||
|
go controller.AutomaticallyUpdateChannels(frequency)
|
||||||
|
}
|
||||||
|
if os.Getenv("CHANNEL_TEST_FREQUENCY") != "" {
|
||||||
|
frequency, err := strconv.Atoi(os.Getenv("CHANNEL_TEST_FREQUENCY"))
|
||||||
|
if err != nil {
|
||||||
|
common.FatalLog("failed to parse CHANNEL_TEST_FREQUENCY: " + err.Error())
|
||||||
|
}
|
||||||
|
go controller.AutomaticallyTestChannels(frequency)
|
||||||
|
}
|
||||||
|
|
||||||
// Initialize HTTP server
|
// Initialize HTTP server
|
||||||
server := gin.Default()
|
server := gin.Default()
|
||||||
|
|||||||
@@ -91,7 +91,7 @@ func TokenAuth() func(c *gin.Context) {
|
|||||||
key = parts[0]
|
key = parts[0]
|
||||||
token, err := model.ValidateUserToken(key)
|
token, err := model.ValidateUserToken(key)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusUnauthorized, gin.H{
|
||||||
"error": gin.H{
|
"error": gin.H{
|
||||||
"message": err.Error(),
|
"message": err.Error(),
|
||||||
"type": "one_api_error",
|
"type": "one_api_error",
|
||||||
@@ -101,7 +101,7 @@ func TokenAuth() func(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
if !model.CacheIsUserEnabled(token.UserId) {
|
if !model.CacheIsUserEnabled(token.UserId) {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusForbidden, gin.H{
|
||||||
"error": gin.H{
|
"error": gin.H{
|
||||||
"message": "用户已被封禁",
|
"message": "用户已被封禁",
|
||||||
"type": "one_api_error",
|
"type": "one_api_error",
|
||||||
@@ -123,7 +123,7 @@ func TokenAuth() func(c *gin.Context) {
|
|||||||
if model.IsAdmin(token.UserId) {
|
if model.IsAdmin(token.UserId) {
|
||||||
c.Set("channelId", parts[1])
|
c.Set("channelId", parts[1])
|
||||||
} else {
|
} else {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusForbidden, gin.H{
|
||||||
"error": gin.H{
|
"error": gin.H{
|
||||||
"message": "普通用户不支持指定渠道",
|
"message": "普通用户不支持指定渠道",
|
||||||
"type": "one_api_error",
|
"type": "one_api_error",
|
||||||
|
|||||||
@@ -24,7 +24,7 @@ func Distribute() func(c *gin.Context) {
|
|||||||
if ok {
|
if ok {
|
||||||
id, err := strconv.Atoi(channelId.(string))
|
id, err := strconv.Atoi(channelId.(string))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusBadRequest, gin.H{
|
||||||
"error": gin.H{
|
"error": gin.H{
|
||||||
"message": "无效的渠道 ID",
|
"message": "无效的渠道 ID",
|
||||||
"type": "one_api_error",
|
"type": "one_api_error",
|
||||||
@@ -35,7 +35,7 @@ func Distribute() func(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
channel, err = model.GetChannelById(id, true)
|
channel, err = model.GetChannelById(id, true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(200, gin.H{
|
c.JSON(http.StatusBadRequest, gin.H{
|
||||||
"error": gin.H{
|
"error": gin.H{
|
||||||
"message": "无效的渠道 ID",
|
"message": "无效的渠道 ID",
|
||||||
"type": "one_api_error",
|
"type": "one_api_error",
|
||||||
@@ -45,7 +45,7 @@ func Distribute() func(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
if channel.Status != common.ChannelStatusEnabled {
|
if channel.Status != common.ChannelStatusEnabled {
|
||||||
c.JSON(200, gin.H{
|
c.JSON(http.StatusForbidden, gin.H{
|
||||||
"error": gin.H{
|
"error": gin.H{
|
||||||
"message": "该渠道已被禁用",
|
"message": "该渠道已被禁用",
|
||||||
"type": "one_api_error",
|
"type": "one_api_error",
|
||||||
@@ -59,7 +59,7 @@ func Distribute() func(c *gin.Context) {
|
|||||||
var modelRequest ModelRequest
|
var modelRequest ModelRequest
|
||||||
err := common.UnmarshalBodyReusable(c, &modelRequest)
|
err := common.UnmarshalBodyReusable(c, &modelRequest)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(200, gin.H{
|
c.JSON(http.StatusBadRequest, gin.H{
|
||||||
"error": gin.H{
|
"error": gin.H{
|
||||||
"message": "无效的请求",
|
"message": "无效的请求",
|
||||||
"type": "one_api_error",
|
"type": "one_api_error",
|
||||||
@@ -75,7 +75,7 @@ func Distribute() func(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
channel, err = model.CacheGetRandomSatisfiedChannel(userGroup, modelRequest.Model)
|
channel, err = model.CacheGetRandomSatisfiedChannel(userGroup, modelRequest.Model)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(200, gin.H{
|
c.JSON(http.StatusServiceUnavailable, gin.H{
|
||||||
"error": gin.H{
|
"error": gin.H{
|
||||||
"message": "无可用渠道",
|
"message": "无可用渠道",
|
||||||
"type": "one_api_error",
|
"type": "one_api_error",
|
||||||
|
|||||||
@@ -8,7 +8,7 @@ import (
|
|||||||
type Channel struct {
|
type Channel struct {
|
||||||
Id int `json:"id"`
|
Id int `json:"id"`
|
||||||
Type int `json:"type" gorm:"default:0"`
|
Type int `json:"type" gorm:"default:0"`
|
||||||
Key string `json:"key" gorm:"not null"`
|
Key string `json:"key" gorm:"not null;index"`
|
||||||
Status int `json:"status" gorm:"default:1"`
|
Status int `json:"status" gorm:"default:1"`
|
||||||
Name string `json:"name" gorm:"index"`
|
Name string `json:"name" gorm:"index"`
|
||||||
Weight int `json:"weight"`
|
Weight int `json:"weight"`
|
||||||
@@ -36,7 +36,7 @@ func GetAllChannels(startIdx int, num int, selectAll bool) ([]*Channel, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func SearchChannels(keyword string) (channels []*Channel, err error) {
|
func SearchChannels(keyword string) (channels []*Channel, err error) {
|
||||||
err = DB.Omit("key").Where("id = ? or name LIKE ?", keyword, keyword+"%").Find(&channels).Error
|
err = DB.Omit("key").Where("id = ? or name LIKE ? or key = ?", keyword, keyword+"%", keyword).Find(&channels).Error
|
||||||
return channels, err
|
return channels, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -36,6 +36,7 @@ func InitOptionMap() {
|
|||||||
common.OptionMap["AutomaticDisableChannelEnabled"] = strconv.FormatBool(common.AutomaticDisableChannelEnabled)
|
common.OptionMap["AutomaticDisableChannelEnabled"] = strconv.FormatBool(common.AutomaticDisableChannelEnabled)
|
||||||
common.OptionMap["LogConsumeEnabled"] = strconv.FormatBool(common.LogConsumeEnabled)
|
common.OptionMap["LogConsumeEnabled"] = strconv.FormatBool(common.LogConsumeEnabled)
|
||||||
common.OptionMap["DisplayInCurrencyEnabled"] = strconv.FormatBool(common.DisplayInCurrencyEnabled)
|
common.OptionMap["DisplayInCurrencyEnabled"] = strconv.FormatBool(common.DisplayInCurrencyEnabled)
|
||||||
|
common.OptionMap["DisplayTokenStatEnabled"] = strconv.FormatBool(common.DisplayTokenStatEnabled)
|
||||||
common.OptionMap["ChannelDisableThreshold"] = strconv.FormatFloat(common.ChannelDisableThreshold, 'f', -1, 64)
|
common.OptionMap["ChannelDisableThreshold"] = strconv.FormatFloat(common.ChannelDisableThreshold, 'f', -1, 64)
|
||||||
common.OptionMap["SMTPServer"] = ""
|
common.OptionMap["SMTPServer"] = ""
|
||||||
common.OptionMap["SMTPFrom"] = ""
|
common.OptionMap["SMTPFrom"] = ""
|
||||||
@@ -144,6 +145,8 @@ func updateOptionMap(key string, value string) (err error) {
|
|||||||
common.LogConsumeEnabled = boolValue
|
common.LogConsumeEnabled = boolValue
|
||||||
case "DisplayInCurrencyEnabled":
|
case "DisplayInCurrencyEnabled":
|
||||||
common.DisplayInCurrencyEnabled = boolValue
|
common.DisplayInCurrencyEnabled = boolValue
|
||||||
|
case "DisplayTokenStatEnabled":
|
||||||
|
common.DisplayTokenStatEnabled = boolValue
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
switch key {
|
switch key {
|
||||||
|
|||||||
@@ -18,6 +18,7 @@ type Token struct {
|
|||||||
ExpiredTime int64 `json:"expired_time" gorm:"bigint;default:-1"` // -1 means never expired
|
ExpiredTime int64 `json:"expired_time" gorm:"bigint;default:-1"` // -1 means never expired
|
||||||
RemainQuota int `json:"remain_quota" gorm:"default:0"`
|
RemainQuota int `json:"remain_quota" gorm:"default:0"`
|
||||||
UnlimitedQuota bool `json:"unlimited_quota" gorm:"default:false"`
|
UnlimitedQuota bool `json:"unlimited_quota" gorm:"default:false"`
|
||||||
|
UsedQuota int `json:"used_quota" gorm:"default:0"` // used quota
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetAllUserTokens(userId int, startIdx int, num int) ([]*Token, error) {
|
func GetAllUserTokens(userId int, startIdx int, num int) ([]*Token, error) {
|
||||||
@@ -130,7 +131,12 @@ func IncreaseTokenQuota(id int, quota int) (err error) {
|
|||||||
if quota < 0 {
|
if quota < 0 {
|
||||||
return errors.New("quota 不能为负数!")
|
return errors.New("quota 不能为负数!")
|
||||||
}
|
}
|
||||||
err = DB.Model(&Token{}).Where("id = ?", id).Update("remain_quota", gorm.Expr("remain_quota + ?", quota)).Error
|
err = DB.Model(&Token{}).Where("id = ?", id).Updates(
|
||||||
|
map[string]interface{}{
|
||||||
|
"remain_quota": gorm.Expr("remain_quota + ?", quota),
|
||||||
|
"used_quota": gorm.Expr("used_quota - ?", quota),
|
||||||
|
},
|
||||||
|
).Error
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -138,7 +144,12 @@ func DecreaseTokenQuota(id int, quota int) (err error) {
|
|||||||
if quota < 0 {
|
if quota < 0 {
|
||||||
return errors.New("quota 不能为负数!")
|
return errors.New("quota 不能为负数!")
|
||||||
}
|
}
|
||||||
err = DB.Model(&Token{}).Where("id = ?", id).Update("remain_quota", gorm.Expr("remain_quota - ?", quota)).Error
|
err = DB.Model(&Token{}).Where("id = ?", id).Updates(
|
||||||
|
map[string]interface{}{
|
||||||
|
"remain_quota": gorm.Expr("remain_quota - ?", quota),
|
||||||
|
"used_quota": gorm.Expr("used_quota + ?", quota),
|
||||||
|
},
|
||||||
|
).Error
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -263,7 +263,7 @@ const ChannelsTable = () => {
|
|||||||
icon='search'
|
icon='search'
|
||||||
fluid
|
fluid
|
||||||
iconPosition='left'
|
iconPosition='left'
|
||||||
placeholder='搜索渠道的 ID 和名称 ...'
|
placeholder='搜索渠道的 ID,名称和密钥 ...'
|
||||||
value={searchKeyword}
|
value={searchKeyword}
|
||||||
loading={searching}
|
loading={searching}
|
||||||
onChange={handleKeywordChange}
|
onChange={handleKeywordChange}
|
||||||
|
|||||||
@@ -17,7 +17,8 @@ const OperationSetting = () => {
|
|||||||
AutomaticDisableChannelEnabled: '',
|
AutomaticDisableChannelEnabled: '',
|
||||||
ChannelDisableThreshold: 0,
|
ChannelDisableThreshold: 0,
|
||||||
LogConsumeEnabled: '',
|
LogConsumeEnabled: '',
|
||||||
DisplayInCurrencyEnabled: ''
|
DisplayInCurrencyEnabled: '',
|
||||||
|
DisplayTokenStatEnabled: ''
|
||||||
});
|
});
|
||||||
const [originInputs, setOriginInputs] = useState({});
|
const [originInputs, setOriginInputs] = useState({});
|
||||||
let [loading, setLoading] = useState(false);
|
let [loading, setLoading] = useState(false);
|
||||||
@@ -177,6 +178,12 @@ const OperationSetting = () => {
|
|||||||
name='DisplayInCurrencyEnabled'
|
name='DisplayInCurrencyEnabled'
|
||||||
onChange={handleInputChange}
|
onChange={handleInputChange}
|
||||||
/>
|
/>
|
||||||
|
<Form.Checkbox
|
||||||
|
checked={inputs.DisplayTokenStatEnabled === 'true'}
|
||||||
|
label='Billing 相关 API 显示令牌额度而非用户额度'
|
||||||
|
name='DisplayTokenStatEnabled'
|
||||||
|
onChange={handleInputChange}
|
||||||
|
/>
|
||||||
</Form.Group>
|
</Form.Group>
|
||||||
<Form.Button onClick={() => {
|
<Form.Button onClick={() => {
|
||||||
submitConfig('general').then();
|
submitConfig('general').then();
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
import React, { useEffect, useState } from 'react';
|
import React, { useEffect, useState } from 'react';
|
||||||
import { Button, Form, Grid, Header, Image, Segment } from 'semantic-ui-react';
|
import { Button, Form, Grid, Header, Image, Segment } from 'semantic-ui-react';
|
||||||
import { API, copy, showError, showSuccess } from '../helpers';
|
import { API, copy, showError, showInfo, showNotice, showSuccess } from '../helpers';
|
||||||
import { useSearchParams } from 'react-router-dom';
|
import { useSearchParams } from 'react-router-dom';
|
||||||
|
|
||||||
const PasswordResetConfirm = () => {
|
const PasswordResetConfirm = () => {
|
||||||
@@ -33,7 +33,7 @@ const PasswordResetConfirm = () => {
|
|||||||
if (success) {
|
if (success) {
|
||||||
let password = res.data.data;
|
let password = res.data.data;
|
||||||
await copy(password);
|
await copy(password);
|
||||||
showSuccess(`密码已重置并已复制到剪贴板:${password}`);
|
showNotice(`密码已重置并已复制到剪贴板:${password}`);
|
||||||
} else {
|
} else {
|
||||||
showError(message);
|
showError(message);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -181,13 +181,21 @@ const TokensTable = () => {
|
|||||||
>
|
>
|
||||||
状态
|
状态
|
||||||
</Table.HeaderCell>
|
</Table.HeaderCell>
|
||||||
|
<Table.HeaderCell
|
||||||
|
style={{ cursor: 'pointer' }}
|
||||||
|
onClick={() => {
|
||||||
|
sortToken('used_quota');
|
||||||
|
}}
|
||||||
|
>
|
||||||
|
已用额度
|
||||||
|
</Table.HeaderCell>
|
||||||
<Table.HeaderCell
|
<Table.HeaderCell
|
||||||
style={{ cursor: 'pointer' }}
|
style={{ cursor: 'pointer' }}
|
||||||
onClick={() => {
|
onClick={() => {
|
||||||
sortToken('remain_quota');
|
sortToken('remain_quota');
|
||||||
}}
|
}}
|
||||||
>
|
>
|
||||||
额度
|
剩余额度
|
||||||
</Table.HeaderCell>
|
</Table.HeaderCell>
|
||||||
<Table.HeaderCell
|
<Table.HeaderCell
|
||||||
style={{ cursor: 'pointer' }}
|
style={{ cursor: 'pointer' }}
|
||||||
@@ -221,6 +229,7 @@ const TokensTable = () => {
|
|||||||
<Table.Row key={token.id}>
|
<Table.Row key={token.id}>
|
||||||
<Table.Cell>{token.name ? token.name : '无'}</Table.Cell>
|
<Table.Cell>{token.name ? token.name : '无'}</Table.Cell>
|
||||||
<Table.Cell>{renderStatus(token.status)}</Table.Cell>
|
<Table.Cell>{renderStatus(token.status)}</Table.Cell>
|
||||||
|
<Table.Cell>{renderQuota(token.used_quota)}</Table.Cell>
|
||||||
<Table.Cell>{token.unlimited_quota ? '无限制' : renderQuota(token.remain_quota, 2)}</Table.Cell>
|
<Table.Cell>{token.unlimited_quota ? '无限制' : renderQuota(token.remain_quota, 2)}</Table.Cell>
|
||||||
<Table.Cell>{renderTimestamp(token.created_time)}</Table.Cell>
|
<Table.Cell>{renderTimestamp(token.created_time)}</Table.Cell>
|
||||||
<Table.Cell>{token.expired_time === -1 ? '永不过期' : renderTimestamp(token.expired_time)}</Table.Cell>
|
<Table.Cell>{token.expired_time === -1 ? '永不过期' : renderTimestamp(token.expired_time)}</Table.Cell>
|
||||||
|
|||||||
@@ -32,15 +32,15 @@ const EditChannel = () => {
|
|||||||
let res = await API.get(`/api/channel/${channelId}`);
|
let res = await API.get(`/api/channel/${channelId}`);
|
||||||
const { success, message, data } = res.data;
|
const { success, message, data } = res.data;
|
||||||
if (success) {
|
if (success) {
|
||||||
if (data.models === "") {
|
if (data.models === '') {
|
||||||
data.models = []
|
data.models = [];
|
||||||
} else {
|
} else {
|
||||||
data.models = data.models.split(",")
|
data.models = data.models.split(',');
|
||||||
}
|
}
|
||||||
if (data.group === "") {
|
if (data.group === '') {
|
||||||
data.groups = []
|
data.groups = [];
|
||||||
} else {
|
} else {
|
||||||
data.groups = data.group.split(",")
|
data.groups = data.group.split(',');
|
||||||
}
|
}
|
||||||
setInputs(data);
|
setInputs(data);
|
||||||
} else {
|
} else {
|
||||||
@@ -55,10 +55,10 @@ const EditChannel = () => {
|
|||||||
setModelOptions(res.data.data.map((model) => ({
|
setModelOptions(res.data.data.map((model) => ({
|
||||||
key: model.id,
|
key: model.id,
|
||||||
text: model.id,
|
text: model.id,
|
||||||
value: model.id,
|
value: model.id
|
||||||
})));
|
})));
|
||||||
setFullModels(res.data.data.map((model) => model.id));
|
setFullModels(res.data.data.map((model) => model.id));
|
||||||
setBasicModels(res.data.data.filter((model) => !model.id.startsWith("gpt-4")).map((model) => model.id));
|
setBasicModels(res.data.data.filter((model) => !model.id.startsWith('gpt-4')).map((model) => model.id));
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
showError(error.message);
|
showError(error.message);
|
||||||
}
|
}
|
||||||
@@ -70,7 +70,7 @@ const EditChannel = () => {
|
|||||||
setGroupOptions(res.data.data.map((group) => ({
|
setGroupOptions(res.data.data.map((group) => ({
|
||||||
key: group,
|
key: group,
|
||||||
text: group,
|
text: group,
|
||||||
value: group,
|
value: group
|
||||||
})));
|
})));
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
showError(error.message);
|
showError(error.message);
|
||||||
@@ -90,6 +90,10 @@ const EditChannel = () => {
|
|||||||
showInfo('请填写渠道名称和渠道密钥!');
|
showInfo('请填写渠道名称和渠道密钥!');
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
if (inputs.models.length === 0) {
|
||||||
|
showInfo('请至少选择一个模型!');
|
||||||
|
return;
|
||||||
|
}
|
||||||
let localInputs = inputs;
|
let localInputs = inputs;
|
||||||
if (localInputs.base_url.endsWith('/')) {
|
if (localInputs.base_url.endsWith('/')) {
|
||||||
localInputs.base_url = localInputs.base_url.slice(0, localInputs.base_url.length - 1);
|
localInputs.base_url = localInputs.base_url.slice(0, localInputs.base_url.length - 1);
|
||||||
@@ -98,8 +102,8 @@ const EditChannel = () => {
|
|||||||
localInputs.other = '2023-03-15-preview';
|
localInputs.other = '2023-03-15-preview';
|
||||||
}
|
}
|
||||||
let res;
|
let res;
|
||||||
localInputs.models = localInputs.models.join(",")
|
localInputs.models = localInputs.models.join(',');
|
||||||
localInputs.group = localInputs.groups.join(",")
|
localInputs.group = localInputs.groups.join(',');
|
||||||
if (isEdit) {
|
if (isEdit) {
|
||||||
res = await API.put(`/api/channel/`, { ...localInputs, id: parseInt(channelId) });
|
res = await API.put(`/api/channel/`, { ...localInputs, id: parseInt(channelId) });
|
||||||
} else {
|
} else {
|
||||||
@@ -181,9 +185,9 @@ const EditChannel = () => {
|
|||||||
inputs.type !== 3 && inputs.type !== 8 && (
|
inputs.type !== 3 && inputs.type !== 8 && (
|
||||||
<Form.Field>
|
<Form.Field>
|
||||||
<Form.Input
|
<Form.Input
|
||||||
label='Base URL'
|
label='镜像'
|
||||||
name='base_url'
|
name='base_url'
|
||||||
placeholder={'请输入自定义 Base URL,格式为:https://domain.com,可不填,不填使用渠道默认值'}
|
placeholder={'请输入镜像站地址,格式为:https://domain.com,可不填,不填则使用渠道默认值'}
|
||||||
onChange={handleInputChange}
|
onChange={handleInputChange}
|
||||||
value={inputs.base_url}
|
value={inputs.base_url}
|
||||||
autoComplete='new-password'
|
autoComplete='new-password'
|
||||||
@@ -231,28 +235,17 @@ const EditChannel = () => {
|
|||||||
options={modelOptions}
|
options={modelOptions}
|
||||||
/>
|
/>
|
||||||
</Form.Field>
|
</Form.Field>
|
||||||
<div style={{ lineHeight: '40px', marginBottom: '12px'}}>
|
<div style={{ lineHeight: '40px', marginBottom: '12px' }}>
|
||||||
<Button type={'button'} onClick={() => {
|
<Button type={'button'} onClick={() => {
|
||||||
handleInputChange(null, { name: 'models', value: basicModels });
|
handleInputChange(null, { name: 'models', value: basicModels });
|
||||||
}}>填入基础模型</Button>
|
}}>填入基础模型</Button>
|
||||||
<Button type={'button'} onClick={() => {
|
<Button type={'button'} onClick={() => {
|
||||||
handleInputChange(null, { name: 'models', value: fullModels });
|
handleInputChange(null, { name: 'models', value: fullModels });
|
||||||
}}>填入所有模型</Button>
|
}}>填入所有模型</Button>
|
||||||
|
<Button type={'button'} onClick={() => {
|
||||||
|
handleInputChange(null, { name: 'models', value: [] });
|
||||||
|
}}>清除所有模型</Button>
|
||||||
</div>
|
</div>
|
||||||
{
|
|
||||||
inputs.type === 1 && (
|
|
||||||
<Form.Field>
|
|
||||||
<Form.Input
|
|
||||||
label='代理'
|
|
||||||
name='base_url'
|
|
||||||
placeholder={'请输入 OpenAI API 代理地址,如果不需要请留空,格式为:https://api.openai.com'}
|
|
||||||
onChange={handleInputChange}
|
|
||||||
value={inputs.base_url}
|
|
||||||
autoComplete='new-password'
|
|
||||||
/>
|
|
||||||
</Form.Field>
|
|
||||||
)
|
|
||||||
}
|
|
||||||
{
|
{
|
||||||
batch ? <Form.Field>
|
batch ? <Form.Field>
|
||||||
<Form.TextArea
|
<Form.TextArea
|
||||||
|
|||||||
Reference in New Issue
Block a user