fix: only reduce remain times when request /v1/chat/completions (close #15)

BREAKING CHANGE: now remain_times is -1 doesn't mean unlimited times anymore!
This commit is contained in:
JustSong
2023-04-26 10:45:34 +08:00
parent eb8f43acb5
commit 109736cc05
6 changed files with 83 additions and 40 deletions

View File

@@ -7,16 +7,20 @@ import (
"io"
"net/http"
"one-api/common"
"one-api/model"
"strings"
)
func Relay(c *gin.Context) {
channelType := c.GetInt("channel")
tokenId := c.GetInt("token_id")
isUnlimitedTimes := c.GetBool("unlimited_times")
baseURL := common.ChannelBaseURLs[channelType]
if channelType == common.ChannelTypeCustom {
baseURL = c.GetString("base_url")
}
req, err := http.NewRequest(c.Request.Method, fmt.Sprintf("%s%s", baseURL, c.Request.URL.String()), c.Request.Body)
requestURL := c.Request.URL.String()
req, err := http.NewRequest(c.Request.Method, fmt.Sprintf("%s%s", baseURL, requestURL), c.Request.Body)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"error": gin.H{
@@ -46,7 +50,19 @@ func Relay(c *gin.Context) {
})
return
}
defer resp.Body.Close()
defer func() {
err := req.Body.Close()
if err != nil {
common.SysError("Error closing request body: " + err.Error())
}
if !isUnlimitedTimes && requestURL == "/v1/chat/completions" {
err := model.DecreaseTokenRemainTimesById(tokenId)
if err != nil {
common.SysError("Error decreasing token remain times: " + err.Error())
}
}
}()
isStream := resp.Header.Get("Content-Type") == "text/event-stream"
if isStream {
scanner := bufio.NewScanner(resp.Body)

View File

@@ -93,13 +93,14 @@ func AddToken(c *gin.Context) {
return
}
cleanToken := model.Token{
UserId: c.GetInt("id"),
Name: token.Name,
Key: common.GetUUID(),
CreatedTime: common.GetTimestamp(),
AccessedTime: common.GetTimestamp(),
ExpiredTime: token.ExpiredTime,
RemainTimes: token.RemainTimes,
UserId: c.GetInt("id"),
Name: token.Name,
Key: common.GetUUID(),
CreatedTime: common.GetTimestamp(),
AccessedTime: common.GetTimestamp(),
ExpiredTime: token.ExpiredTime,
RemainTimes: token.RemainTimes,
UnlimitedTimes: token.UnlimitedTimes,
}
err = cleanToken.Insert()
if err != nil {
@@ -136,6 +137,7 @@ func DeleteToken(c *gin.Context) {
func UpdateToken(c *gin.Context) {
userId := c.GetInt("id")
statusOnly := c.Query("status_only")
token := model.Token{}
err := c.ShouldBindJSON(&token)
if err != nil {
@@ -161,19 +163,23 @@ func UpdateToken(c *gin.Context) {
})
return
}
if cleanToken.Status == common.TokenStatusExhausted && cleanToken.RemainTimes == 0 {
if cleanToken.Status == common.TokenStatusExhausted && cleanToken.RemainTimes <= 0 && !cleanToken.UnlimitedTimes {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "令牌可用次数已用尽,无法启用,请先修改令牌剩余次数",
"message": "令牌可用次数已用尽,无法启用,请先修改令牌剩余次数,或者设置为无限次数",
})
return
}
}
cleanToken.Name = token.Name
cleanToken.Status = token.Status
cleanToken.ExpiredTime = token.ExpiredTime
cleanToken.RemainTimes = token.RemainTimes
if statusOnly != "" {
cleanToken.Status = token.Status
} else {
// If you add more fields, please also update token.Update()
cleanToken.Name = token.Name
cleanToken.ExpiredTime = token.ExpiredTime
cleanToken.RemainTimes = token.RemainTimes
cleanToken.UnlimitedTimes = token.UnlimitedTimes
}
err = cleanToken.Update()
if err != nil {
c.JSON(http.StatusOK, gin.H{