mirror of
https://github.com/songquanpeng/one-api.git
synced 2025-11-10 10:33:41 +08:00
Compare commits
2 Commits
v0.5.5-alp
...
v0.5.5-alp
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
c3dc315e75 | ||
|
|
04acdb1ccb |
@@ -306,6 +306,10 @@ graph LR
|
|||||||
+ 例子:`CHANNEL_TEST_FREQUENCY=1440`
|
+ 例子:`CHANNEL_TEST_FREQUENCY=1440`
|
||||||
9. `POLLING_INTERVAL`:批量更新渠道余额以及测试可用性时的请求间隔,单位为秒,默认无间隔。
|
9. `POLLING_INTERVAL`:批量更新渠道余额以及测试可用性时的请求间隔,单位为秒,默认无间隔。
|
||||||
+ 例子:`POLLING_INTERVAL=5`
|
+ 例子:`POLLING_INTERVAL=5`
|
||||||
|
10. `BATCH_UPDATE_ENABLED`:启用数据库批量更新聚合,会导致用户额度的更新存在一定的延迟可选值为 `true` 和 `false`,未设置则默认为 `false`。
|
||||||
|
+ 例子:`BATCH_UPDATE_ENABLED=true`
|
||||||
|
11. `BATCH_UPDATE_INTERVAL=5`:批量更新聚合的时间间隔,单位为秒,默认为 `5`。
|
||||||
|
+ 例子:`BATCH_UPDATE_INTERVAL=5`
|
||||||
|
|
||||||
### 命令行参数
|
### 命令行参数
|
||||||
1. `--port <port_number>`: 指定服务器监听的端口号,默认为 `3000`。
|
1. `--port <port_number>`: 指定服务器监听的端口号,默认为 `3000`。
|
||||||
|
|||||||
@@ -94,6 +94,9 @@ var RequestInterval = time.Duration(requestInterval) * time.Second
|
|||||||
|
|
||||||
var SyncFrequency = 10 * 60 // unit is second, will be overwritten by SYNC_FREQUENCY
|
var SyncFrequency = 10 * 60 // unit is second, will be overwritten by SYNC_FREQUENCY
|
||||||
|
|
||||||
|
var BatchUpdateEnabled = false
|
||||||
|
var BatchUpdateInterval = GetOrDefault("BATCH_UPDATE_INTERVAL", 5)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
RoleGuestUser = 0
|
RoleGuestUser = 0
|
||||||
RoleCommonUser = 1
|
RoleCommonUser = 1
|
||||||
@@ -175,6 +178,7 @@ const (
|
|||||||
ChannelTypeXunfei = 18
|
ChannelTypeXunfei = 18
|
||||||
ChannelType360 = 19
|
ChannelType360 = 19
|
||||||
ChannelTypeOpenRouter = 20
|
ChannelTypeOpenRouter = 20
|
||||||
|
ChannelTypeAIProxyLibrary = 21
|
||||||
)
|
)
|
||||||
|
|
||||||
var ChannelBaseURLs = []string{
|
var ChannelBaseURLs = []string{
|
||||||
@@ -199,4 +203,5 @@ var ChannelBaseURLs = []string{
|
|||||||
"", // 18
|
"", // 18
|
||||||
"https://ai.360.cn", // 19
|
"https://ai.360.cn", // 19
|
||||||
"https://openrouter.ai/api", // 20
|
"https://openrouter.ai/api", // 20
|
||||||
|
"https://api.aiproxy.io", // 21
|
||||||
}
|
}
|
||||||
|
|||||||
220
controller/relay-aiproxy.go
Normal file
220
controller/relay-aiproxy.go
Normal file
@@ -0,0 +1,220 @@
|
|||||||
|
package controller
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bufio"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"one-api/common"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
// https://docs.aiproxy.io/dev/library#使用已经定制好的知识库进行对话问答
|
||||||
|
|
||||||
|
type AIProxyLibraryRequest struct {
|
||||||
|
Model string `json:"model"`
|
||||||
|
Query string `json:"query"`
|
||||||
|
LibraryId string `json:"libraryId"`
|
||||||
|
Stream bool `json:"stream"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type AIProxyLibraryError struct {
|
||||||
|
ErrCode int `json:"errCode"`
|
||||||
|
Message string `json:"message"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type AIProxyLibraryDocument struct {
|
||||||
|
Title string `json:"title"`
|
||||||
|
URL string `json:"url"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type AIProxyLibraryResponse struct {
|
||||||
|
Success bool `json:"success"`
|
||||||
|
Answer string `json:"answer"`
|
||||||
|
Documents []AIProxyLibraryDocument `json:"documents"`
|
||||||
|
AIProxyLibraryError
|
||||||
|
}
|
||||||
|
|
||||||
|
type AIProxyLibraryStreamResponse struct {
|
||||||
|
Content string `json:"content"`
|
||||||
|
Finish bool `json:"finish"`
|
||||||
|
Model string `json:"model"`
|
||||||
|
Documents []AIProxyLibraryDocument `json:"documents"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func requestOpenAI2AIProxyLibrary(request GeneralOpenAIRequest) *AIProxyLibraryRequest {
|
||||||
|
query := ""
|
||||||
|
if len(request.Messages) != 0 {
|
||||||
|
query = request.Messages[len(request.Messages)-1].Content
|
||||||
|
}
|
||||||
|
return &AIProxyLibraryRequest{
|
||||||
|
Model: request.Model,
|
||||||
|
Stream: request.Stream,
|
||||||
|
Query: query,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func aiProxyDocuments2Markdown(documents []AIProxyLibraryDocument) string {
|
||||||
|
if len(documents) == 0 {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
content := "\n\n参考文档:\n"
|
||||||
|
for i, document := range documents {
|
||||||
|
content += fmt.Sprintf("%d. [%s](%s)\n", i+1, document.Title, document.URL)
|
||||||
|
}
|
||||||
|
return content
|
||||||
|
}
|
||||||
|
|
||||||
|
func responseAIProxyLibrary2OpenAI(response *AIProxyLibraryResponse) *OpenAITextResponse {
|
||||||
|
content := response.Answer + aiProxyDocuments2Markdown(response.Documents)
|
||||||
|
choice := OpenAITextResponseChoice{
|
||||||
|
Index: 0,
|
||||||
|
Message: Message{
|
||||||
|
Role: "assistant",
|
||||||
|
Content: content,
|
||||||
|
},
|
||||||
|
FinishReason: "stop",
|
||||||
|
}
|
||||||
|
fullTextResponse := OpenAITextResponse{
|
||||||
|
Id: common.GetUUID(),
|
||||||
|
Object: "chat.completion",
|
||||||
|
Created: common.GetTimestamp(),
|
||||||
|
Choices: []OpenAITextResponseChoice{choice},
|
||||||
|
}
|
||||||
|
return &fullTextResponse
|
||||||
|
}
|
||||||
|
|
||||||
|
func documentsAIProxyLibrary(documents []AIProxyLibraryDocument) *ChatCompletionsStreamResponse {
|
||||||
|
var choice ChatCompletionsStreamResponseChoice
|
||||||
|
choice.Delta.Content = aiProxyDocuments2Markdown(documents)
|
||||||
|
choice.FinishReason = &stopFinishReason
|
||||||
|
return &ChatCompletionsStreamResponse{
|
||||||
|
Id: common.GetUUID(),
|
||||||
|
Object: "chat.completion.chunk",
|
||||||
|
Created: common.GetTimestamp(),
|
||||||
|
Model: "",
|
||||||
|
Choices: []ChatCompletionsStreamResponseChoice{choice},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func streamResponseAIProxyLibrary2OpenAI(response *AIProxyLibraryStreamResponse) *ChatCompletionsStreamResponse {
|
||||||
|
var choice ChatCompletionsStreamResponseChoice
|
||||||
|
choice.Delta.Content = response.Content
|
||||||
|
return &ChatCompletionsStreamResponse{
|
||||||
|
Id: common.GetUUID(),
|
||||||
|
Object: "chat.completion.chunk",
|
||||||
|
Created: common.GetTimestamp(),
|
||||||
|
Model: response.Model,
|
||||||
|
Choices: []ChatCompletionsStreamResponseChoice{choice},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func aiProxyLibraryStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) {
|
||||||
|
var usage Usage
|
||||||
|
scanner := bufio.NewScanner(resp.Body)
|
||||||
|
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
|
||||||
|
if atEOF && len(data) == 0 {
|
||||||
|
return 0, nil, nil
|
||||||
|
}
|
||||||
|
if i := strings.Index(string(data), "\n"); i >= 0 {
|
||||||
|
return i + 1, data[0:i], nil
|
||||||
|
}
|
||||||
|
if atEOF {
|
||||||
|
return len(data), data, nil
|
||||||
|
}
|
||||||
|
return 0, nil, nil
|
||||||
|
})
|
||||||
|
dataChan := make(chan string)
|
||||||
|
stopChan := make(chan bool)
|
||||||
|
go func() {
|
||||||
|
for scanner.Scan() {
|
||||||
|
data := scanner.Text()
|
||||||
|
if len(data) < 5 { // ignore blank line or wrong format
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if data[:5] != "data:" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
data = data[5:]
|
||||||
|
dataChan <- data
|
||||||
|
}
|
||||||
|
stopChan <- true
|
||||||
|
}()
|
||||||
|
setEventStreamHeaders(c)
|
||||||
|
var documents []AIProxyLibraryDocument
|
||||||
|
c.Stream(func(w io.Writer) bool {
|
||||||
|
select {
|
||||||
|
case data := <-dataChan:
|
||||||
|
var AIProxyLibraryResponse AIProxyLibraryStreamResponse
|
||||||
|
err := json.Unmarshal([]byte(data), &AIProxyLibraryResponse)
|
||||||
|
if err != nil {
|
||||||
|
common.SysError("error unmarshalling stream response: " + err.Error())
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
if len(AIProxyLibraryResponse.Documents) != 0 {
|
||||||
|
documents = AIProxyLibraryResponse.Documents
|
||||||
|
}
|
||||||
|
response := streamResponseAIProxyLibrary2OpenAI(&AIProxyLibraryResponse)
|
||||||
|
jsonResponse, err := json.Marshal(response)
|
||||||
|
if err != nil {
|
||||||
|
common.SysError("error marshalling stream response: " + err.Error())
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
|
||||||
|
return true
|
||||||
|
case <-stopChan:
|
||||||
|
response := documentsAIProxyLibrary(documents)
|
||||||
|
jsonResponse, err := json.Marshal(response)
|
||||||
|
if err != nil {
|
||||||
|
common.SysError("error marshalling stream response: " + err.Error())
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
|
||||||
|
c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
})
|
||||||
|
err := resp.Body.Close()
|
||||||
|
if err != nil {
|
||||||
|
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
||||||
|
}
|
||||||
|
return nil, &usage
|
||||||
|
}
|
||||||
|
|
||||||
|
func aiProxyLibraryHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) {
|
||||||
|
var AIProxyLibraryResponse AIProxyLibraryResponse
|
||||||
|
responseBody, err := io.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
||||||
|
}
|
||||||
|
err = resp.Body.Close()
|
||||||
|
if err != nil {
|
||||||
|
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
||||||
|
}
|
||||||
|
err = json.Unmarshal(responseBody, &AIProxyLibraryResponse)
|
||||||
|
if err != nil {
|
||||||
|
return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
||||||
|
}
|
||||||
|
if AIProxyLibraryResponse.ErrCode != 0 {
|
||||||
|
return &OpenAIErrorWithStatusCode{
|
||||||
|
OpenAIError: OpenAIError{
|
||||||
|
Message: AIProxyLibraryResponse.Message,
|
||||||
|
Type: strconv.Itoa(AIProxyLibraryResponse.ErrCode),
|
||||||
|
Code: AIProxyLibraryResponse.ErrCode,
|
||||||
|
},
|
||||||
|
StatusCode: resp.StatusCode,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
fullTextResponse := responseAIProxyLibrary2OpenAI(&AIProxyLibraryResponse)
|
||||||
|
jsonResponse, err := json.Marshal(fullTextResponse)
|
||||||
|
if err != nil {
|
||||||
|
return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
|
||||||
|
}
|
||||||
|
c.Writer.Header().Set("Content-Type", "application/json")
|
||||||
|
c.Writer.WriteHeader(resp.StatusCode)
|
||||||
|
_, err = c.Writer.Write(jsonResponse)
|
||||||
|
return nil, &fullTextResponse.Usage
|
||||||
|
}
|
||||||
@@ -22,6 +22,7 @@ const (
|
|||||||
APITypeZhipu
|
APITypeZhipu
|
||||||
APITypeAli
|
APITypeAli
|
||||||
APITypeXunfei
|
APITypeXunfei
|
||||||
|
APITypeAIProxyLibrary
|
||||||
)
|
)
|
||||||
|
|
||||||
var httpClient *http.Client
|
var httpClient *http.Client
|
||||||
@@ -104,6 +105,8 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
|
|||||||
apiType = APITypeAli
|
apiType = APITypeAli
|
||||||
case common.ChannelTypeXunfei:
|
case common.ChannelTypeXunfei:
|
||||||
apiType = APITypeXunfei
|
apiType = APITypeXunfei
|
||||||
|
case common.ChannelTypeAIProxyLibrary:
|
||||||
|
apiType = APITypeAIProxyLibrary
|
||||||
}
|
}
|
||||||
baseURL := common.ChannelBaseURLs[channelType]
|
baseURL := common.ChannelBaseURLs[channelType]
|
||||||
requestURL := c.Request.URL.String()
|
requestURL := c.Request.URL.String()
|
||||||
@@ -171,6 +174,8 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
|
|||||||
fullRequestURL = fmt.Sprintf("https://open.bigmodel.cn/api/paas/v3/model-api/%s/%s", textRequest.Model, method)
|
fullRequestURL = fmt.Sprintf("https://open.bigmodel.cn/api/paas/v3/model-api/%s/%s", textRequest.Model, method)
|
||||||
case APITypeAli:
|
case APITypeAli:
|
||||||
fullRequestURL = "https://dashscope.aliyuncs.com/api/v1/services/aigc/text-generation/generation"
|
fullRequestURL = "https://dashscope.aliyuncs.com/api/v1/services/aigc/text-generation/generation"
|
||||||
|
case APITypeAIProxyLibrary:
|
||||||
|
fullRequestURL = fmt.Sprintf("%s/api/library/ask", baseURL)
|
||||||
}
|
}
|
||||||
var promptTokens int
|
var promptTokens int
|
||||||
var completionTokens int
|
var completionTokens int
|
||||||
@@ -263,6 +268,14 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
|
|||||||
return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
|
return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
|
||||||
}
|
}
|
||||||
requestBody = bytes.NewBuffer(jsonStr)
|
requestBody = bytes.NewBuffer(jsonStr)
|
||||||
|
case APITypeAIProxyLibrary:
|
||||||
|
aiProxyLibraryRequest := requestOpenAI2AIProxyLibrary(textRequest)
|
||||||
|
aiProxyLibraryRequest.LibraryId = c.GetString("library_id")
|
||||||
|
jsonStr, err := json.Marshal(aiProxyLibraryRequest)
|
||||||
|
if err != nil {
|
||||||
|
return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
requestBody = bytes.NewBuffer(jsonStr)
|
||||||
}
|
}
|
||||||
|
|
||||||
var req *http.Request
|
var req *http.Request
|
||||||
@@ -302,6 +315,8 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
|
|||||||
if textRequest.Stream {
|
if textRequest.Stream {
|
||||||
req.Header.Set("X-DashScope-SSE", "enable")
|
req.Header.Set("X-DashScope-SSE", "enable")
|
||||||
}
|
}
|
||||||
|
default:
|
||||||
|
req.Header.Set("Authorization", "Bearer "+apiKey)
|
||||||
}
|
}
|
||||||
req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
|
req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
|
||||||
req.Header.Set("Accept", c.Request.Header.Get("Accept"))
|
req.Header.Set("Accept", c.Request.Header.Get("Accept"))
|
||||||
@@ -516,6 +531,26 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
|
|||||||
} else {
|
} else {
|
||||||
return errorWrapper(errors.New("xunfei api does not support non-stream mode"), "invalid_api_type", http.StatusBadRequest)
|
return errorWrapper(errors.New("xunfei api does not support non-stream mode"), "invalid_api_type", http.StatusBadRequest)
|
||||||
}
|
}
|
||||||
|
case APITypeAIProxyLibrary:
|
||||||
|
if isStream {
|
||||||
|
err, usage := aiProxyLibraryStreamHandler(c, resp)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if usage != nil {
|
||||||
|
textResponse.Usage = *usage
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
} else {
|
||||||
|
err, usage := aiProxyLibraryHandler(c, resp)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if usage != nil {
|
||||||
|
textResponse.Usage = *usage
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
default:
|
default:
|
||||||
return errorWrapper(errors.New("unknown api type"), "unknown_api_type", http.StatusInternalServerError)
|
return errorWrapper(errors.New("unknown api type"), "unknown_api_type", http.StatusInternalServerError)
|
||||||
}
|
}
|
||||||
|
|||||||
5
main.go
5
main.go
@@ -77,6 +77,11 @@ func main() {
|
|||||||
}
|
}
|
||||||
go controller.AutomaticallyTestChannels(frequency)
|
go controller.AutomaticallyTestChannels(frequency)
|
||||||
}
|
}
|
||||||
|
if os.Getenv("BATCH_UPDATE_ENABLED") == "true" {
|
||||||
|
common.BatchUpdateEnabled = true
|
||||||
|
common.SysLog("batch update enabled with interval " + strconv.Itoa(common.BatchUpdateInterval) + "s")
|
||||||
|
model.InitBatchUpdater()
|
||||||
|
}
|
||||||
controller.InitTokenEncoders()
|
controller.InitTokenEncoders()
|
||||||
|
|
||||||
// Initialize HTTP server
|
// Initialize HTTP server
|
||||||
|
|||||||
@@ -115,8 +115,13 @@ func Distribute() func(c *gin.Context) {
|
|||||||
c.Set("model_mapping", channel.ModelMapping)
|
c.Set("model_mapping", channel.ModelMapping)
|
||||||
c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key))
|
c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key))
|
||||||
c.Set("base_url", channel.BaseURL)
|
c.Set("base_url", channel.BaseURL)
|
||||||
if channel.Type == common.ChannelTypeAzure || channel.Type == common.ChannelTypeXunfei {
|
switch channel.Type {
|
||||||
|
case common.ChannelTypeAzure:
|
||||||
c.Set("api_version", channel.Other)
|
c.Set("api_version", channel.Other)
|
||||||
|
case common.ChannelTypeXunfei:
|
||||||
|
c.Set("api_version", channel.Other)
|
||||||
|
case common.ChannelTypeAIProxyLibrary:
|
||||||
|
c.Set("library_id", channel.Other)
|
||||||
}
|
}
|
||||||
c.Next()
|
c.Next()
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -141,6 +141,14 @@ func UpdateChannelStatusById(id int, status int) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func UpdateChannelUsedQuota(id int, quota int) {
|
func UpdateChannelUsedQuota(id int, quota int) {
|
||||||
|
if common.BatchUpdateEnabled {
|
||||||
|
addNewRecord(BatchUpdateTypeChannelUsedQuota, id, quota)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
updateChannelUsedQuota(id, quota)
|
||||||
|
}
|
||||||
|
|
||||||
|
func updateChannelUsedQuota(id int, quota int) {
|
||||||
err := DB.Model(&Channel{}).Where("id = ?", id).Update("used_quota", gorm.Expr("used_quota + ?", quota)).Error
|
err := DB.Model(&Channel{}).Where("id = ?", id).Update("used_quota", gorm.Expr("used_quota + ?", quota)).Error
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.SysError("failed to update channel used quota: " + err.Error())
|
common.SysError("failed to update channel used quota: " + err.Error())
|
||||||
|
|||||||
@@ -131,6 +131,14 @@ func IncreaseTokenQuota(id int, quota int) (err error) {
|
|||||||
if quota < 0 {
|
if quota < 0 {
|
||||||
return errors.New("quota 不能为负数!")
|
return errors.New("quota 不能为负数!")
|
||||||
}
|
}
|
||||||
|
if common.BatchUpdateEnabled {
|
||||||
|
addNewRecord(BatchUpdateTypeTokenQuota, id, quota)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return increaseTokenQuota(id, quota)
|
||||||
|
}
|
||||||
|
|
||||||
|
func increaseTokenQuota(id int, quota int) (err error) {
|
||||||
err = DB.Model(&Token{}).Where("id = ?", id).Updates(
|
err = DB.Model(&Token{}).Where("id = ?", id).Updates(
|
||||||
map[string]interface{}{
|
map[string]interface{}{
|
||||||
"remain_quota": gorm.Expr("remain_quota + ?", quota),
|
"remain_quota": gorm.Expr("remain_quota + ?", quota),
|
||||||
@@ -144,6 +152,14 @@ func DecreaseTokenQuota(id int, quota int) (err error) {
|
|||||||
if quota < 0 {
|
if quota < 0 {
|
||||||
return errors.New("quota 不能为负数!")
|
return errors.New("quota 不能为负数!")
|
||||||
}
|
}
|
||||||
|
if common.BatchUpdateEnabled {
|
||||||
|
addNewRecord(BatchUpdateTypeTokenQuota, id, -quota)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return decreaseTokenQuota(id, quota)
|
||||||
|
}
|
||||||
|
|
||||||
|
func decreaseTokenQuota(id int, quota int) (err error) {
|
||||||
err = DB.Model(&Token{}).Where("id = ?", id).Updates(
|
err = DB.Model(&Token{}).Where("id = ?", id).Updates(
|
||||||
map[string]interface{}{
|
map[string]interface{}{
|
||||||
"remain_quota": gorm.Expr("remain_quota - ?", quota),
|
"remain_quota": gorm.Expr("remain_quota - ?", quota),
|
||||||
|
|||||||
@@ -275,6 +275,14 @@ func IncreaseUserQuota(id int, quota int) (err error) {
|
|||||||
if quota < 0 {
|
if quota < 0 {
|
||||||
return errors.New("quota 不能为负数!")
|
return errors.New("quota 不能为负数!")
|
||||||
}
|
}
|
||||||
|
if common.BatchUpdateEnabled {
|
||||||
|
addNewRecord(BatchUpdateTypeUserQuota, id, quota)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return increaseUserQuota(id, quota)
|
||||||
|
}
|
||||||
|
|
||||||
|
func increaseUserQuota(id int, quota int) (err error) {
|
||||||
err = DB.Model(&User{}).Where("id = ?", id).Update("quota", gorm.Expr("quota + ?", quota)).Error
|
err = DB.Model(&User{}).Where("id = ?", id).Update("quota", gorm.Expr("quota + ?", quota)).Error
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -283,6 +291,14 @@ func DecreaseUserQuota(id int, quota int) (err error) {
|
|||||||
if quota < 0 {
|
if quota < 0 {
|
||||||
return errors.New("quota 不能为负数!")
|
return errors.New("quota 不能为负数!")
|
||||||
}
|
}
|
||||||
|
if common.BatchUpdateEnabled {
|
||||||
|
addNewRecord(BatchUpdateTypeUserQuota, id, -quota)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return decreaseUserQuota(id, quota)
|
||||||
|
}
|
||||||
|
|
||||||
|
func decreaseUserQuota(id int, quota int) (err error) {
|
||||||
err = DB.Model(&User{}).Where("id = ?", id).Update("quota", gorm.Expr("quota - ?", quota)).Error
|
err = DB.Model(&User{}).Where("id = ?", id).Update("quota", gorm.Expr("quota - ?", quota)).Error
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -293,10 +309,18 @@ func GetRootUserEmail() (email string) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func UpdateUserUsedQuotaAndRequestCount(id int, quota int) {
|
func UpdateUserUsedQuotaAndRequestCount(id int, quota int) {
|
||||||
|
if common.BatchUpdateEnabled {
|
||||||
|
addNewRecord(BatchUpdateTypeUsedQuotaAndRequestCount, id, quota)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
updateUserUsedQuotaAndRequestCount(id, quota, 1)
|
||||||
|
}
|
||||||
|
|
||||||
|
func updateUserUsedQuotaAndRequestCount(id int, quota int, count int) {
|
||||||
err := DB.Model(&User{}).Where("id = ?", id).Updates(
|
err := DB.Model(&User{}).Where("id = ?", id).Updates(
|
||||||
map[string]interface{}{
|
map[string]interface{}{
|
||||||
"used_quota": gorm.Expr("used_quota + ?", quota),
|
"used_quota": gorm.Expr("used_quota + ?", quota),
|
||||||
"request_count": gorm.Expr("request_count + ?", 1),
|
"request_count": gorm.Expr("request_count + ?", count),
|
||||||
},
|
},
|
||||||
).Error
|
).Error
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
75
model/utils.go
Normal file
75
model/utils.go
Normal file
@@ -0,0 +1,75 @@
|
|||||||
|
package model
|
||||||
|
|
||||||
|
import (
|
||||||
|
"one-api/common"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
const BatchUpdateTypeCount = 4 // if you add a new type, you need to add a new map and a new lock
|
||||||
|
|
||||||
|
const (
|
||||||
|
BatchUpdateTypeUserQuota = iota
|
||||||
|
BatchUpdateTypeTokenQuota
|
||||||
|
BatchUpdateTypeUsedQuotaAndRequestCount
|
||||||
|
BatchUpdateTypeChannelUsedQuota
|
||||||
|
)
|
||||||
|
|
||||||
|
var batchUpdateStores []map[int]int
|
||||||
|
var batchUpdateLocks []sync.Mutex
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
for i := 0; i < BatchUpdateTypeCount; i++ {
|
||||||
|
batchUpdateStores = append(batchUpdateStores, make(map[int]int))
|
||||||
|
batchUpdateLocks = append(batchUpdateLocks, sync.Mutex{})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func InitBatchUpdater() {
|
||||||
|
go func() {
|
||||||
|
for {
|
||||||
|
time.Sleep(time.Duration(common.BatchUpdateInterval) * time.Second)
|
||||||
|
batchUpdate()
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
func addNewRecord(type_ int, id int, value int) {
|
||||||
|
batchUpdateLocks[type_].Lock()
|
||||||
|
defer batchUpdateLocks[type_].Unlock()
|
||||||
|
if _, ok := batchUpdateStores[type_][id]; !ok {
|
||||||
|
batchUpdateStores[type_][id] = value
|
||||||
|
} else {
|
||||||
|
batchUpdateStores[type_][id] += value
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func batchUpdate() {
|
||||||
|
common.SysLog("batch update started")
|
||||||
|
for i := 0; i < BatchUpdateTypeCount; i++ {
|
||||||
|
batchUpdateLocks[i].Lock()
|
||||||
|
store := batchUpdateStores[i]
|
||||||
|
batchUpdateStores[i] = make(map[int]int)
|
||||||
|
batchUpdateLocks[i].Unlock()
|
||||||
|
|
||||||
|
for key, value := range store {
|
||||||
|
switch i {
|
||||||
|
case BatchUpdateTypeUserQuota:
|
||||||
|
err := increaseUserQuota(key, value)
|
||||||
|
if err != nil {
|
||||||
|
common.SysError("failed to batch update user quota: " + err.Error())
|
||||||
|
}
|
||||||
|
case BatchUpdateTypeTokenQuota:
|
||||||
|
err := increaseTokenQuota(key, value)
|
||||||
|
if err != nil {
|
||||||
|
common.SysError("failed to batch update token quota: " + err.Error())
|
||||||
|
}
|
||||||
|
case BatchUpdateTypeUsedQuotaAndRequestCount:
|
||||||
|
updateUserUsedQuotaAndRequestCount(key, value, 1) // TODO: count is incorrect
|
||||||
|
case BatchUpdateTypeChannelUsedQuota:
|
||||||
|
updateChannelUsedQuota(key, value)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
common.SysLog("batch update finished")
|
||||||
|
}
|
||||||
@@ -9,6 +9,7 @@ export const CHANNEL_OPTIONS = [
|
|||||||
{ key: 16, text: '智谱 ChatGLM', value: 16, color: 'violet' },
|
{ key: 16, text: '智谱 ChatGLM', value: 16, color: 'violet' },
|
||||||
{ key: 19, text: '360 智脑', value: 19, color: 'blue' },
|
{ key: 19, text: '360 智脑', value: 19, color: 'blue' },
|
||||||
{ key: 8, text: '自定义渠道', value: 8, color: 'pink' },
|
{ key: 8, text: '自定义渠道', value: 8, color: 'pink' },
|
||||||
|
{ key: 21, text: '知识库:AI Proxy', value: 21, color: 'purple' },
|
||||||
{ key: 20, text: '代理:OpenRouter', value: 20, color: 'black' },
|
{ key: 20, text: '代理:OpenRouter', value: 20, color: 'black' },
|
||||||
{ key: 2, text: '代理:API2D', value: 2, color: 'blue' },
|
{ key: 2, text: '代理:API2D', value: 2, color: 'blue' },
|
||||||
{ key: 5, text: '代理:OpenAI-SB', value: 5, color: 'brown' },
|
{ key: 5, text: '代理:OpenAI-SB', value: 5, color: 'brown' },
|
||||||
|
|||||||
@@ -295,6 +295,20 @@ const EditChannel = () => {
|
|||||||
</Form.Field>
|
</Form.Field>
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
{
|
||||||
|
inputs.type === 21 && (
|
||||||
|
<Form.Field>
|
||||||
|
<Form.Input
|
||||||
|
label='知识库 ID'
|
||||||
|
name='other'
|
||||||
|
placeholder={'请输入知识库 ID,例如:123456'}
|
||||||
|
onChange={handleInputChange}
|
||||||
|
value={inputs.other}
|
||||||
|
autoComplete='new-password'
|
||||||
|
/>
|
||||||
|
</Form.Field>
|
||||||
|
)
|
||||||
|
}
|
||||||
<Form.Field>
|
<Form.Field>
|
||||||
<Form.Dropdown
|
<Form.Dropdown
|
||||||
label='模型'
|
label='模型'
|
||||||
|
|||||||
Reference in New Issue
Block a user