mirror of
https://github.com/songquanpeng/one-api.git
synced 2025-10-29 04:43:41 +08:00
Compare commits
7 Commits
v0.5.5-alp
...
v0.5.6-alp
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
92001986db | ||
|
|
a5647b1ea7 | ||
|
|
215e54fc96 | ||
|
|
ecf8a6d875 | ||
|
|
24df3e5f62 | ||
|
|
12ef9679a7 | ||
|
|
328aa68255 |
@@ -269,6 +269,12 @@ docker run --name chatgpt-web -d -p 3002:3002 -e OPENAI_API_BASE_URL=https://ope
|
||||
|
||||
注意,具体的 API Base 的格式取决于你所使用的客户端。
|
||||
|
||||
例如对于 OpenAI 的官方库:
|
||||
```bash
|
||||
OPENAI_API_KEY="sk-xxxxxx"
|
||||
OPENAI_API_BASE="https://<HOST>:<PORT>/v1"
|
||||
```
|
||||
|
||||
```mermaid
|
||||
graph LR
|
||||
A(用户)
|
||||
|
||||
@@ -2,6 +2,7 @@ package controller
|
||||
|
||||
import (
|
||||
"github.com/gin-gonic/gin"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/model"
|
||||
"strconv"
|
||||
@@ -18,19 +19,21 @@ func GetAllLogs(c *gin.Context) {
|
||||
username := c.Query("username")
|
||||
tokenName := c.Query("token_name")
|
||||
modelName := c.Query("model_name")
|
||||
logs, err := model.GetAllLogs(logType, startTimestamp, endTimestamp, modelName, username, tokenName, p*common.ItemsPerPage, common.ItemsPerPage)
|
||||
channel, _ := strconv.Atoi(c.Query("channel"))
|
||||
logs, err := model.GetAllLogs(logType, startTimestamp, endTimestamp, modelName, username, tokenName, p*common.ItemsPerPage, common.ItemsPerPage, channel)
|
||||
if err != nil {
|
||||
c.JSON(200, gin.H{
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
c.JSON(200, gin.H{
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
"data": logs,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
func GetUserLogs(c *gin.Context) {
|
||||
@@ -46,34 +49,36 @@ func GetUserLogs(c *gin.Context) {
|
||||
modelName := c.Query("model_name")
|
||||
logs, err := model.GetUserLogs(userId, logType, startTimestamp, endTimestamp, modelName, tokenName, p*common.ItemsPerPage, common.ItemsPerPage)
|
||||
if err != nil {
|
||||
c.JSON(200, gin.H{
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
c.JSON(200, gin.H{
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
"data": logs,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
func SearchAllLogs(c *gin.Context) {
|
||||
keyword := c.Query("keyword")
|
||||
logs, err := model.SearchAllLogs(keyword)
|
||||
if err != nil {
|
||||
c.JSON(200, gin.H{
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
c.JSON(200, gin.H{
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
"data": logs,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
func SearchUserLogs(c *gin.Context) {
|
||||
@@ -81,17 +86,18 @@ func SearchUserLogs(c *gin.Context) {
|
||||
userId := c.GetInt("id")
|
||||
logs, err := model.SearchUserLogs(userId, keyword)
|
||||
if err != nil {
|
||||
c.JSON(200, gin.H{
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
c.JSON(200, gin.H{
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
"data": logs,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
func GetLogsStat(c *gin.Context) {
|
||||
@@ -101,9 +107,10 @@ func GetLogsStat(c *gin.Context) {
|
||||
tokenName := c.Query("token_name")
|
||||
username := c.Query("username")
|
||||
modelName := c.Query("model_name")
|
||||
quotaNum := model.SumUsedQuota(logType, startTimestamp, endTimestamp, modelName, username, tokenName)
|
||||
channel, _ := strconv.Atoi(c.Query("channel"))
|
||||
quotaNum := model.SumUsedQuota(logType, startTimestamp, endTimestamp, modelName, username, tokenName, channel)
|
||||
//tokenNum := model.SumUsedToken(logType, startTimestamp, endTimestamp, modelName, username, "")
|
||||
c.JSON(200, gin.H{
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
"data": gin.H{
|
||||
@@ -111,6 +118,7 @@ func GetLogsStat(c *gin.Context) {
|
||||
//"token": tokenNum,
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
func GetLogsSelfStat(c *gin.Context) {
|
||||
@@ -120,9 +128,10 @@ func GetLogsSelfStat(c *gin.Context) {
|
||||
endTimestamp, _ := strconv.ParseInt(c.Query("end_timestamp"), 10, 64)
|
||||
tokenName := c.Query("token_name")
|
||||
modelName := c.Query("model_name")
|
||||
quotaNum := model.SumUsedQuota(logType, startTimestamp, endTimestamp, modelName, username, tokenName)
|
||||
channel, _ := strconv.Atoi(c.Query("channel"))
|
||||
quotaNum := model.SumUsedQuota(logType, startTimestamp, endTimestamp, modelName, username, tokenName, channel)
|
||||
//tokenNum := model.SumUsedToken(logType, startTimestamp, endTimestamp, modelName, username, tokenName)
|
||||
c.JSON(200, gin.H{
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
"data": gin.H{
|
||||
@@ -130,4 +139,30 @@ func GetLogsSelfStat(c *gin.Context) {
|
||||
//"token": tokenNum,
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
func DeleteHistoryLogs(c *gin.Context) {
|
||||
targetTimestamp, _ := strconv.ParseInt(c.Query("target_timestamp"), 10, 64)
|
||||
if targetTimestamp == 0 {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "target timestamp is required",
|
||||
})
|
||||
return
|
||||
}
|
||||
count, err := model.DeleteOldLog(targetTimestamp)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
"data": count,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
@@ -18,6 +18,7 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
|
||||
|
||||
tokenId := c.GetInt("token_id")
|
||||
channelType := c.GetInt("channel")
|
||||
channelId := c.GetInt("channel_id")
|
||||
userId := c.GetInt("id")
|
||||
group := c.GetString("group")
|
||||
|
||||
@@ -107,7 +108,7 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
|
||||
if quota != 0 {
|
||||
tokenName := c.GetString("token_name")
|
||||
logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
|
||||
model.RecordConsumeLog(ctx, userId, 0, 0, audioModel, tokenName, quota, logContent)
|
||||
model.RecordConsumeLog(ctx, userId, channelId, 0, 0, audioModel, tokenName, quota, logContent)
|
||||
model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
|
||||
channelId := c.GetInt("channel_id")
|
||||
model.UpdateChannelUsedQuota(channelId, quota)
|
||||
|
||||
@@ -19,6 +19,7 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
|
||||
|
||||
tokenId := c.GetInt("token_id")
|
||||
channelType := c.GetInt("channel")
|
||||
channelId := c.GetInt("channel_id")
|
||||
userId := c.GetInt("id")
|
||||
consumeQuota := c.GetBool("consume_quota")
|
||||
group := c.GetString("group")
|
||||
@@ -138,7 +139,7 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
|
||||
if quota != 0 {
|
||||
tokenName := c.GetString("token_name")
|
||||
logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
|
||||
model.RecordConsumeLog(ctx, userId, 0, 0, imageModel, tokenName, quota, logContent)
|
||||
model.RecordConsumeLog(ctx, userId, channelId, 0, 0, imageModel, tokenName, quota, logContent)
|
||||
model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
|
||||
channelId := c.GetInt("channel_id")
|
||||
model.UpdateChannelUsedQuota(channelId, quota)
|
||||
|
||||
@@ -38,6 +38,7 @@ func init() {
|
||||
|
||||
func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
|
||||
channelType := c.GetInt("channel")
|
||||
channelId := c.GetInt("channel_id")
|
||||
tokenId := c.GetInt("token_id")
|
||||
userId := c.GetInt("id")
|
||||
consumeQuota := c.GetBool("consume_quota")
|
||||
@@ -364,7 +365,6 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
|
||||
|
||||
var textResponse TextResponse
|
||||
tokenName := c.GetString("token_name")
|
||||
channelId := c.GetInt("channel_id")
|
||||
|
||||
defer func(ctx context.Context) {
|
||||
// c.Writer.Flush()
|
||||
@@ -397,7 +397,7 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
|
||||
}
|
||||
if quota != 0 {
|
||||
logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
|
||||
model.RecordConsumeLog(ctx, userId, promptTokens, completionTokens, textRequest.Model, tokenName, quota, logContent)
|
||||
model.RecordConsumeLog(ctx, userId, channelId, promptTokens, completionTokens, textRequest.Model, tokenName, quota, logContent)
|
||||
model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
|
||||
model.UpdateChannelUsedQuota(channelId, quota)
|
||||
}
|
||||
@@ -541,24 +541,26 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
|
||||
return nil
|
||||
}
|
||||
case APITypeXunfei:
|
||||
if isStream {
|
||||
auth := c.Request.Header.Get("Authorization")
|
||||
auth = strings.TrimPrefix(auth, "Bearer ")
|
||||
splits := strings.Split(auth, "|")
|
||||
if len(splits) != 3 {
|
||||
return errorWrapper(errors.New("invalid auth"), "invalid_auth", http.StatusBadRequest)
|
||||
}
|
||||
err, usage := xunfeiStreamHandler(c, textRequest, splits[0], splits[1], splits[2])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if usage != nil {
|
||||
textResponse.Usage = *usage
|
||||
}
|
||||
return nil
|
||||
} else {
|
||||
return errorWrapper(errors.New("xunfei api does not support non-stream mode"), "invalid_api_type", http.StatusBadRequest)
|
||||
auth := c.Request.Header.Get("Authorization")
|
||||
auth = strings.TrimPrefix(auth, "Bearer ")
|
||||
splits := strings.Split(auth, "|")
|
||||
if len(splits) != 3 {
|
||||
return errorWrapper(errors.New("invalid auth"), "invalid_auth", http.StatusBadRequest)
|
||||
}
|
||||
var err *OpenAIErrorWithStatusCode
|
||||
var usage *Usage
|
||||
if isStream {
|
||||
err, usage = xunfeiStreamHandler(c, textRequest, splits[0], splits[1], splits[2])
|
||||
} else {
|
||||
err, usage = xunfeiHandler(c, textRequest, splits[0], splits[1], splits[2])
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if usage != nil {
|
||||
textResponse.Usage = *usage
|
||||
}
|
||||
return nil
|
||||
case APITypeAIProxyLibrary:
|
||||
if isStream {
|
||||
err, usage := aiProxyLibraryStreamHandler(c, resp)
|
||||
|
||||
@@ -118,6 +118,7 @@ func responseXunfei2OpenAI(response *XunfeiChatResponse) *OpenAITextResponse {
|
||||
Role: "assistant",
|
||||
Content: response.Payload.Choices.Text[0].Content,
|
||||
},
|
||||
FinishReason: stopFinishReason,
|
||||
}
|
||||
fullTextResponse := OpenAITextResponse{
|
||||
Object: "chat.completion",
|
||||
@@ -177,33 +178,82 @@ func buildXunfeiAuthUrl(hostUrl string, apiKey, apiSecret string) string {
|
||||
}
|
||||
|
||||
func xunfeiStreamHandler(c *gin.Context, textRequest GeneralOpenAIRequest, appId string, apiSecret string, apiKey string) (*OpenAIErrorWithStatusCode, *Usage) {
|
||||
domain, authUrl := getXunfeiAuthUrl(c, apiKey, apiSecret)
|
||||
dataChan, stopChan, err := xunfeiMakeRequest(textRequest, domain, authUrl, appId)
|
||||
if err != nil {
|
||||
return errorWrapper(err, "make xunfei request err", http.StatusInternalServerError), nil
|
||||
}
|
||||
setEventStreamHeaders(c)
|
||||
var usage Usage
|
||||
query := c.Request.URL.Query()
|
||||
apiVersion := query.Get("api-version")
|
||||
if apiVersion == "" {
|
||||
apiVersion = c.GetString("api_version")
|
||||
c.Stream(func(w io.Writer) bool {
|
||||
select {
|
||||
case xunfeiResponse := <-dataChan:
|
||||
usage.PromptTokens += xunfeiResponse.Payload.Usage.Text.PromptTokens
|
||||
usage.CompletionTokens += xunfeiResponse.Payload.Usage.Text.CompletionTokens
|
||||
usage.TotalTokens += xunfeiResponse.Payload.Usage.Text.TotalTokens
|
||||
response := streamResponseXunfei2OpenAI(&xunfeiResponse)
|
||||
jsonResponse, err := json.Marshal(response)
|
||||
if err != nil {
|
||||
common.SysError("error marshalling stream response: " + err.Error())
|
||||
return true
|
||||
}
|
||||
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
|
||||
return true
|
||||
case <-stopChan:
|
||||
c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
|
||||
return false
|
||||
}
|
||||
})
|
||||
return nil, &usage
|
||||
}
|
||||
|
||||
func xunfeiHandler(c *gin.Context, textRequest GeneralOpenAIRequest, appId string, apiSecret string, apiKey string) (*OpenAIErrorWithStatusCode, *Usage) {
|
||||
domain, authUrl := getXunfeiAuthUrl(c, apiKey, apiSecret)
|
||||
dataChan, stopChan, err := xunfeiMakeRequest(textRequest, domain, authUrl, appId)
|
||||
if err != nil {
|
||||
return errorWrapper(err, "make xunfei request err", http.StatusInternalServerError), nil
|
||||
}
|
||||
if apiVersion == "" {
|
||||
apiVersion = "v1.1"
|
||||
common.SysLog("api_version not found, use default: " + apiVersion)
|
||||
var usage Usage
|
||||
var content string
|
||||
var xunfeiResponse XunfeiChatResponse
|
||||
stop := false
|
||||
for !stop {
|
||||
select {
|
||||
case xunfeiResponse = <-dataChan:
|
||||
content += xunfeiResponse.Payload.Choices.Text[0].Content
|
||||
usage.PromptTokens += xunfeiResponse.Payload.Usage.Text.PromptTokens
|
||||
usage.CompletionTokens += xunfeiResponse.Payload.Usage.Text.CompletionTokens
|
||||
usage.TotalTokens += xunfeiResponse.Payload.Usage.Text.TotalTokens
|
||||
case stop = <-stopChan:
|
||||
}
|
||||
}
|
||||
domain := "general"
|
||||
if apiVersion == "v2.1" {
|
||||
domain = "generalv2"
|
||||
|
||||
xunfeiResponse.Payload.Choices.Text[0].Content = content
|
||||
|
||||
response := responseXunfei2OpenAI(&xunfeiResponse)
|
||||
jsonResponse, err := json.Marshal(response)
|
||||
if err != nil {
|
||||
return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
hostUrl := fmt.Sprintf("wss://spark-api.xf-yun.com/%s/chat", apiVersion)
|
||||
c.Writer.Header().Set("Content-Type", "application/json")
|
||||
_, _ = c.Writer.Write(jsonResponse)
|
||||
return nil, &usage
|
||||
}
|
||||
|
||||
func xunfeiMakeRequest(textRequest GeneralOpenAIRequest, domain, authUrl, appId string) (chan XunfeiChatResponse, chan bool, error) {
|
||||
d := websocket.Dialer{
|
||||
HandshakeTimeout: 5 * time.Second,
|
||||
}
|
||||
conn, resp, err := d.Dial(buildXunfeiAuthUrl(hostUrl, apiKey, apiSecret), nil)
|
||||
conn, resp, err := d.Dial(authUrl, nil)
|
||||
if err != nil || resp.StatusCode != 101 {
|
||||
return errorWrapper(err, "dial_failed", http.StatusInternalServerError), nil
|
||||
return nil, nil, err
|
||||
}
|
||||
data := requestOpenAI2Xunfei(textRequest, appId, domain)
|
||||
err = conn.WriteJSON(data)
|
||||
if err != nil {
|
||||
return errorWrapper(err, "write_json_failed", http.StatusInternalServerError), nil
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
dataChan := make(chan XunfeiChatResponse)
|
||||
stopChan := make(chan bool)
|
||||
go func() {
|
||||
@@ -230,61 +280,24 @@ func xunfeiStreamHandler(c *gin.Context, textRequest GeneralOpenAIRequest, appId
|
||||
}
|
||||
stopChan <- true
|
||||
}()
|
||||
setEventStreamHeaders(c)
|
||||
c.Stream(func(w io.Writer) bool {
|
||||
select {
|
||||
case xunfeiResponse := <-dataChan:
|
||||
usage.PromptTokens += xunfeiResponse.Payload.Usage.Text.PromptTokens
|
||||
usage.CompletionTokens += xunfeiResponse.Payload.Usage.Text.CompletionTokens
|
||||
usage.TotalTokens += xunfeiResponse.Payload.Usage.Text.TotalTokens
|
||||
response := streamResponseXunfei2OpenAI(&xunfeiResponse)
|
||||
jsonResponse, err := json.Marshal(response)
|
||||
if err != nil {
|
||||
common.SysError("error marshalling stream response: " + err.Error())
|
||||
return true
|
||||
}
|
||||
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
|
||||
return true
|
||||
case <-stopChan:
|
||||
c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
|
||||
return false
|
||||
}
|
||||
})
|
||||
return nil, &usage
|
||||
|
||||
return dataChan, stopChan, nil
|
||||
}
|
||||
|
||||
func xunfeiHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) {
|
||||
var xunfeiResponse XunfeiChatResponse
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
||||
func getXunfeiAuthUrl(c *gin.Context, apiKey string, apiSecret string) (string, string) {
|
||||
query := c.Request.URL.Query()
|
||||
apiVersion := query.Get("api-version")
|
||||
if apiVersion == "" {
|
||||
apiVersion = c.GetString("api_version")
|
||||
}
|
||||
err = resp.Body.Close()
|
||||
if err != nil {
|
||||
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
||||
if apiVersion == "" {
|
||||
apiVersion = "v1.1"
|
||||
common.SysLog("api_version not found, use default: " + apiVersion)
|
||||
}
|
||||
err = json.Unmarshal(responseBody, &xunfeiResponse)
|
||||
if err != nil {
|
||||
return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
domain := "general"
|
||||
if apiVersion == "v2.1" {
|
||||
domain = "generalv2"
|
||||
}
|
||||
if xunfeiResponse.Header.Code != 0 {
|
||||
return &OpenAIErrorWithStatusCode{
|
||||
OpenAIError: OpenAIError{
|
||||
Message: xunfeiResponse.Header.Message,
|
||||
Type: "xunfei_error",
|
||||
Param: "",
|
||||
Code: xunfeiResponse.Header.Code,
|
||||
},
|
||||
StatusCode: resp.StatusCode,
|
||||
}, nil
|
||||
}
|
||||
fullTextResponse := responseXunfei2OpenAI(&xunfeiResponse)
|
||||
jsonResponse, err := json.Marshal(fullTextResponse)
|
||||
if err != nil {
|
||||
return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
c.Writer.Header().Set("Content-Type", "application/json")
|
||||
c.Writer.WriteHeader(resp.StatusCode)
|
||||
_, err = c.Writer.Write(jsonResponse)
|
||||
return nil, &fullTextResponse.Usage
|
||||
authUrl := buildXunfeiAuthUrl(fmt.Sprintf("wss://spark-api.xf-yun.com/%s/chat", apiVersion), apiKey, apiSecret)
|
||||
return domain, authUrl
|
||||
}
|
||||
|
||||
@@ -10,15 +10,16 @@ type Ability struct {
|
||||
Model string `json:"model" gorm:"primaryKey;autoIncrement:false"`
|
||||
ChannelId int `json:"channel_id" gorm:"primaryKey;autoIncrement:false;index"`
|
||||
Enabled bool `json:"enabled"`
|
||||
Priority *int64 `json:"priority" gorm:"bigint;default:0"`
|
||||
}
|
||||
|
||||
func GetRandomSatisfiedChannel(group string, model string) (*Channel, error) {
|
||||
ability := Ability{}
|
||||
var err error = nil
|
||||
if common.UsingSQLite {
|
||||
err = DB.Where("`group` = ? and model = ? and enabled = 1", group, model).Order("RANDOM()").Limit(1).First(&ability).Error
|
||||
err = DB.Where("`group` = ? and model = ? and enabled = 1", group, model).Order("CASE WHEN priority <> 0 THEN priority ELSE RANDOM() END DESC ").Limit(1).First(&ability).Error
|
||||
} else {
|
||||
err = DB.Where("`group` = ? and model = ? and enabled = 1", group, model).Order("RAND()").Limit(1).First(&ability).Error
|
||||
err = DB.Where("`group` = ? and model = ? and enabled = 1", group, model).Order("CASE WHEN priority <> 0 THEN priority ELSE RAND() END DESC").Limit(1).First(&ability).Error
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -40,6 +41,7 @@ func (channel *Channel) AddAbilities() error {
|
||||
Model: model,
|
||||
ChannelId: channel.Id,
|
||||
Enabled: channel.Status == common.ChannelStatusEnabled,
|
||||
Priority: channel.Priority,
|
||||
}
|
||||
abilities = append(abilities, ability)
|
||||
}
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"one-api/common"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
@@ -159,6 +160,17 @@ func InitChannelCache() {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// sort by priority
|
||||
for group, model2channels := range newGroup2model2channels {
|
||||
for model, channels := range model2channels {
|
||||
sort.Slice(channels, func(i, j int) bool {
|
||||
return channels[i].GetPriority() > channels[j].GetPriority()
|
||||
})
|
||||
newGroup2model2channels[group][model] = channels
|
||||
}
|
||||
}
|
||||
|
||||
channelSyncLock.Lock()
|
||||
group2model2channels = newGroup2model2channels
|
||||
channelSyncLock.Unlock()
|
||||
@@ -183,6 +195,17 @@ func CacheGetRandomSatisfiedChannel(group string, model string) (*Channel, error
|
||||
if len(channels) == 0 {
|
||||
return nil, errors.New("channel not found")
|
||||
}
|
||||
idx := rand.Intn(len(channels))
|
||||
endIdx := len(channels)
|
||||
// choose by priority
|
||||
firstChannel := channels[0]
|
||||
if firstChannel.GetPriority() > 0 {
|
||||
for i := range channels {
|
||||
if channels[i].GetPriority() != firstChannel.GetPriority() {
|
||||
endIdx = i
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
idx := rand.Intn(endIdx)
|
||||
return channels[idx], nil
|
||||
}
|
||||
|
||||
@@ -23,6 +23,7 @@ type Channel struct {
|
||||
Group string `json:"group" gorm:"type:varchar(32);default:'default'"`
|
||||
UsedQuota int64 `json:"used_quota" gorm:"bigint;default:0"`
|
||||
ModelMapping string `json:"model_mapping" gorm:"type:varchar(1024);default:''"`
|
||||
Priority *int64 `json:"priority" gorm:"bigint;default:0"`
|
||||
}
|
||||
|
||||
func GetAllChannels(startIdx int, num int, selectAll bool) ([]*Channel, error) {
|
||||
@@ -78,6 +79,13 @@ func BatchInsertChannels(channels []Channel) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (channel *Channel) GetPriority() int64 {
|
||||
if channel == nil {
|
||||
return 0
|
||||
}
|
||||
return *channel.Priority
|
||||
}
|
||||
|
||||
func (channel *Channel) Insert() error {
|
||||
var err error
|
||||
err = DB.Create(channel).Error
|
||||
|
||||
22
model/log.go
22
model/log.go
@@ -19,6 +19,7 @@ type Log struct {
|
||||
Quota int `json:"quota" gorm:"default:0"`
|
||||
PromptTokens int `json:"prompt_tokens" gorm:"default:0"`
|
||||
CompletionTokens int `json:"completion_tokens" gorm:"default:0"`
|
||||
Channel int `json:"channel" gorm:"default:0"`
|
||||
}
|
||||
|
||||
const (
|
||||
@@ -46,8 +47,9 @@ func RecordLog(userId int, logType int, content string) {
|
||||
}
|
||||
}
|
||||
|
||||
func RecordConsumeLog(ctx context.Context, userId int, promptTokens int, completionTokens int, modelName string, tokenName string, quota int, content string) {
|
||||
common.LogInfo(ctx, fmt.Sprintf("record consume log: userId=%d, promptTokens=%d, completionTokens=%d, modelName=%s, tokenName=%s, quota=%d, content=%s", userId, promptTokens, completionTokens, modelName, tokenName, quota, content))
|
||||
|
||||
func RecordConsumeLog(ctx context.Context, userId int, channelId int, promptTokens int, completionTokens int, modelName string, tokenName string, quota int, content string) {
|
||||
common.LogInfo(ctx, fmt.Sprintf("record consume log: userId=%d, channelId=%d, promptTokens=%d, completionTokens=%d, modelName=%s, tokenName=%s, quota=%d, content=%s", userId, channelId, promptTokens, completionTokens, modelName, tokenName, quota, content))
|
||||
if !common.LogConsumeEnabled {
|
||||
return
|
||||
}
|
||||
@@ -62,6 +64,7 @@ func RecordConsumeLog(ctx context.Context, userId int, promptTokens int, complet
|
||||
TokenName: tokenName,
|
||||
ModelName: modelName,
|
||||
Quota: quota,
|
||||
Channel: channelId,
|
||||
}
|
||||
err := DB.Create(log).Error
|
||||
if err != nil {
|
||||
@@ -69,7 +72,7 @@ func RecordConsumeLog(ctx context.Context, userId int, promptTokens int, complet
|
||||
}
|
||||
}
|
||||
|
||||
func GetAllLogs(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string, startIdx int, num int) (logs []*Log, err error) {
|
||||
func GetAllLogs(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string, startIdx int, num int, channel int) (logs []*Log, err error) {
|
||||
var tx *gorm.DB
|
||||
if logType == LogTypeUnknown {
|
||||
tx = DB
|
||||
@@ -91,6 +94,9 @@ func GetAllLogs(logType int, startTimestamp int64, endTimestamp int64, modelName
|
||||
if endTimestamp != 0 {
|
||||
tx = tx.Where("created_at <= ?", endTimestamp)
|
||||
}
|
||||
if channel != 0 {
|
||||
tx = tx.Where("channel = ?", channel)
|
||||
}
|
||||
err = tx.Order("id desc").Limit(num).Offset(startIdx).Find(&logs).Error
|
||||
return logs, err
|
||||
}
|
||||
@@ -128,7 +134,7 @@ func SearchUserLogs(userId int, keyword string) (logs []*Log, err error) {
|
||||
return logs, err
|
||||
}
|
||||
|
||||
func SumUsedQuota(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string) (quota int) {
|
||||
func SumUsedQuota(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string, channel int) (quota int) {
|
||||
tx := DB.Table("logs").Select("sum(quota)")
|
||||
if username != "" {
|
||||
tx = tx.Where("username = ?", username)
|
||||
@@ -145,6 +151,9 @@ func SumUsedQuota(logType int, startTimestamp int64, endTimestamp int64, modelNa
|
||||
if modelName != "" {
|
||||
tx = tx.Where("model_name = ?", modelName)
|
||||
}
|
||||
if channel != 0 {
|
||||
tx = tx.Where("channel = ?", channel)
|
||||
}
|
||||
tx.Where("type = ?", LogTypeConsume).Scan("a)
|
||||
return quota
|
||||
}
|
||||
@@ -169,3 +178,8 @@ func SumUsedToken(logType int, startTimestamp int64, endTimestamp int64, modelNa
|
||||
tx.Where("type = ?", LogTypeConsume).Scan(&token)
|
||||
return token
|
||||
}
|
||||
|
||||
func DeleteOldLog(targetTimestamp int64) (int64, error) {
|
||||
result := DB.Where("created_at < ?", targetTimestamp).Delete(&Log{})
|
||||
return result.RowsAffected, result.Error
|
||||
}
|
||||
|
||||
@@ -98,6 +98,7 @@ func SetApiRouter(router *gin.Engine) {
|
||||
}
|
||||
logRoute := apiRouter.Group("/log")
|
||||
logRoute.GET("/", middleware.AdminAuth(), controller.GetAllLogs)
|
||||
logRoute.DELETE("/", middleware.AdminAuth(), controller.DeleteHistoryLogs)
|
||||
logRoute.GET("/stat", middleware.AdminAuth(), controller.GetLogsStat)
|
||||
logRoute.GET("/self/stat", middleware.UserAuth(), controller.GetLogsSelfStat)
|
||||
logRoute.GET("/search", middleware.AdminAuth(), controller.SearchAllLogs)
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import React, { useEffect, useState } from 'react';
|
||||
import { Button, Form, Label, Pagination, Popup, Table } from 'semantic-ui-react';
|
||||
import {Button, Form, Input, Label, Pagination, Popup, Table} from 'semantic-ui-react';
|
||||
import { Link } from 'react-router-dom';
|
||||
import { API, showError, showInfo, showNotice, showSuccess, timestamp2string } from '../helpers';
|
||||
|
||||
@@ -24,7 +24,7 @@ function renderType(type) {
|
||||
}
|
||||
type2label[0] = { value: 0, text: '未知类型', color: 'grey' };
|
||||
}
|
||||
return <Label basic color={type2label[type].color}>{type2label[type].text}</Label>;
|
||||
return <Label basic color={type2label[type]?.color}>{type2label[type]?.text}</Label>;
|
||||
}
|
||||
|
||||
function renderBalance(type, balance) {
|
||||
@@ -96,7 +96,7 @@ const ChannelsTable = () => {
|
||||
});
|
||||
}, []);
|
||||
|
||||
const manageChannel = async (id, action, idx) => {
|
||||
const manageChannel = async (id, action, idx, priority) => {
|
||||
let data = { id };
|
||||
let res;
|
||||
switch (action) {
|
||||
@@ -111,6 +111,13 @@ const ChannelsTable = () => {
|
||||
data.status = 2;
|
||||
res = await API.put('/api/channel/', data);
|
||||
break;
|
||||
case 'priority':
|
||||
if (priority === '') {
|
||||
return;
|
||||
}
|
||||
data.priority = parseInt(priority);
|
||||
res = await API.put('/api/channel/', data);
|
||||
break;
|
||||
}
|
||||
const { success, message } = res.data;
|
||||
if (success) {
|
||||
@@ -335,6 +342,14 @@ const ChannelsTable = () => {
|
||||
>
|
||||
余额
|
||||
</Table.HeaderCell>
|
||||
<Table.HeaderCell
|
||||
style={{ cursor: 'pointer' }}
|
||||
onClick={() => {
|
||||
sortChannel('priority');
|
||||
}}
|
||||
>
|
||||
优先级
|
||||
</Table.HeaderCell>
|
||||
<Table.HeaderCell>操作</Table.HeaderCell>
|
||||
</Table.Row>
|
||||
</Table.Header>
|
||||
@@ -373,6 +388,22 @@ const ChannelsTable = () => {
|
||||
basic
|
||||
/>
|
||||
</Table.Cell>
|
||||
<Table.Cell>
|
||||
<Popup
|
||||
trigger={<Input type="number" defaultValue={channel.priority} onBlur={(event) => {
|
||||
manageChannel(
|
||||
channel.id,
|
||||
'priority',
|
||||
idx,
|
||||
event.target.value,
|
||||
);
|
||||
}}>
|
||||
<input style={{maxWidth:'60px'}} />
|
||||
</Input>}
|
||||
content='渠道选择优先级,越高越优先'
|
||||
basic
|
||||
/>
|
||||
</Table.Cell>
|
||||
<Table.Cell>
|
||||
<div>
|
||||
<Button
|
||||
@@ -441,7 +472,7 @@ const ChannelsTable = () => {
|
||||
|
||||
<Table.Footer>
|
||||
<Table.Row>
|
||||
<Table.HeaderCell colSpan='8'>
|
||||
<Table.HeaderCell colSpan='9'>
|
||||
<Button size='small' as={Link} to='/channel/add' loading={loading}>
|
||||
添加新的渠道
|
||||
</Button>
|
||||
|
||||
@@ -56,9 +56,10 @@ const LogsTable = () => {
|
||||
token_name: '',
|
||||
model_name: '',
|
||||
start_timestamp: timestamp2string(0),
|
||||
end_timestamp: timestamp2string(now.getTime() / 1000 + 3600)
|
||||
end_timestamp: timestamp2string(now.getTime() / 1000 + 3600),
|
||||
channel: ''
|
||||
});
|
||||
const { username, token_name, model_name, start_timestamp, end_timestamp } = inputs;
|
||||
const { username, token_name, model_name, start_timestamp, end_timestamp, channel } = inputs;
|
||||
|
||||
const [stat, setStat] = useState({
|
||||
quota: 0,
|
||||
@@ -84,7 +85,7 @@ const LogsTable = () => {
|
||||
const getLogStat = async () => {
|
||||
let localStartTimestamp = Date.parse(start_timestamp) / 1000;
|
||||
let localEndTimestamp = Date.parse(end_timestamp) / 1000;
|
||||
let res = await API.get(`/api/log/stat?type=${logType}&username=${username}&token_name=${token_name}&model_name=${model_name}&start_timestamp=${localStartTimestamp}&end_timestamp=${localEndTimestamp}`);
|
||||
let res = await API.get(`/api/log/stat?type=${logType}&username=${username}&token_name=${token_name}&model_name=${model_name}&start_timestamp=${localStartTimestamp}&end_timestamp=${localEndTimestamp}&channel=${channel}`);
|
||||
const { success, message, data } = res.data;
|
||||
if (success) {
|
||||
setStat(data);
|
||||
@@ -109,7 +110,7 @@ const LogsTable = () => {
|
||||
let localStartTimestamp = Date.parse(start_timestamp) / 1000;
|
||||
let localEndTimestamp = Date.parse(end_timestamp) / 1000;
|
||||
if (isAdminUser) {
|
||||
url = `/api/log/?p=${startIdx}&type=${logType}&username=${username}&token_name=${token_name}&model_name=${model_name}&start_timestamp=${localStartTimestamp}&end_timestamp=${localEndTimestamp}`;
|
||||
url = `/api/log/?p=${startIdx}&type=${logType}&username=${username}&token_name=${token_name}&model_name=${model_name}&start_timestamp=${localStartTimestamp}&end_timestamp=${localEndTimestamp}&channel=${channel}`;
|
||||
} else {
|
||||
url = `/api/log/self/?p=${startIdx}&type=${logType}&token_name=${token_name}&model_name=${model_name}&start_timestamp=${localStartTimestamp}&end_timestamp=${localEndTimestamp}`;
|
||||
}
|
||||
@@ -205,16 +206,9 @@ const LogsTable = () => {
|
||||
</Header>
|
||||
<Form>
|
||||
<Form.Group>
|
||||
{
|
||||
isAdminUser && (
|
||||
<Form.Input fluid label={'用户名称'} width={2} value={username}
|
||||
placeholder={'可选值'} name='username'
|
||||
onChange={handleInputChange} />
|
||||
)
|
||||
}
|
||||
<Form.Input fluid label={'令牌名称'} width={isAdminUser ? 2 : 3} value={token_name}
|
||||
<Form.Input fluid label={'令牌名称'} width={3} value={token_name}
|
||||
placeholder={'可选值'} name='token_name' onChange={handleInputChange} />
|
||||
<Form.Input fluid label='模型名称' width={isAdminUser ? 2 : 3} value={model_name} placeholder='可选值'
|
||||
<Form.Input fluid label='模型名称' width={3} value={model_name} placeholder='可选值'
|
||||
name='model_name'
|
||||
onChange={handleInputChange} />
|
||||
<Form.Input fluid label='起始时间' width={4} value={start_timestamp} type='datetime-local'
|
||||
@@ -225,6 +219,19 @@ const LogsTable = () => {
|
||||
onChange={handleInputChange} />
|
||||
<Form.Button fluid label='操作' width={2} onClick={refresh}>查询</Form.Button>
|
||||
</Form.Group>
|
||||
{
|
||||
isAdminUser && <>
|
||||
<Form.Group>
|
||||
<Form.Input fluid label={'渠道 ID'} width={3} value={channel}
|
||||
placeholder='可选值' name='channel'
|
||||
onChange={handleInputChange} />
|
||||
<Form.Input fluid label={'用户名称'} width={3} value={username}
|
||||
placeholder={'可选值'} name='username'
|
||||
onChange={handleInputChange} />
|
||||
|
||||
</Form.Group>
|
||||
</>
|
||||
}
|
||||
</Form>
|
||||
<Table basic compact size='small'>
|
||||
<Table.Header>
|
||||
@@ -238,6 +245,17 @@ const LogsTable = () => {
|
||||
>
|
||||
时间
|
||||
</Table.HeaderCell>
|
||||
{
|
||||
isAdminUser && <Table.HeaderCell
|
||||
style={{ cursor: 'pointer' }}
|
||||
onClick={() => {
|
||||
sortLog('channel');
|
||||
}}
|
||||
width={1}
|
||||
>
|
||||
渠道
|
||||
</Table.HeaderCell>
|
||||
}
|
||||
{
|
||||
isAdminUser && <Table.HeaderCell
|
||||
style={{ cursor: 'pointer' }}
|
||||
@@ -299,16 +317,16 @@ const LogsTable = () => {
|
||||
onClick={() => {
|
||||
sortLog('quota');
|
||||
}}
|
||||
width={2}
|
||||
width={1}
|
||||
>
|
||||
消耗额度
|
||||
额度
|
||||
</Table.HeaderCell>
|
||||
<Table.HeaderCell
|
||||
style={{ cursor: 'pointer' }}
|
||||
onClick={() => {
|
||||
sortLog('content');
|
||||
}}
|
||||
width={isAdminUser ? 4 : 5}
|
||||
width={isAdminUser ? 4 : 6}
|
||||
>
|
||||
详情
|
||||
</Table.HeaderCell>
|
||||
@@ -326,6 +344,11 @@ const LogsTable = () => {
|
||||
return (
|
||||
<Table.Row key={log.id}>
|
||||
<Table.Cell>{renderTimestamp(log.created_at)}</Table.Cell>
|
||||
{
|
||||
isAdminUser && (
|
||||
<Table.Cell>{log.channel ? <Label basic>{log.channel}</Label> : ''}</Table.Cell>
|
||||
)
|
||||
}
|
||||
{
|
||||
isAdminUser && (
|
||||
<Table.Cell>{log.username ? <Label>{log.username}</Label> : ''}</Table.Cell>
|
||||
@@ -345,7 +368,7 @@ const LogsTable = () => {
|
||||
|
||||
<Table.Footer>
|
||||
<Table.Row>
|
||||
<Table.HeaderCell colSpan={'9'}>
|
||||
<Table.HeaderCell colSpan={'10'}>
|
||||
<Select
|
||||
placeholder='选择明细分类'
|
||||
options={LOG_OPTIONS}
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
import React, { useEffect, useState } from 'react';
|
||||
import { Divider, Form, Grid, Header } from 'semantic-ui-react';
|
||||
import { API, showError, verifyJSON } from '../helpers';
|
||||
import { API, showError, showSuccess, timestamp2string, verifyJSON } from '../helpers';
|
||||
|
||||
const OperationSetting = () => {
|
||||
let now = new Date();
|
||||
let [inputs, setInputs] = useState({
|
||||
QuotaForNewUser: 0,
|
||||
QuotaForInviter: 0,
|
||||
@@ -20,10 +21,11 @@ const OperationSetting = () => {
|
||||
DisplayInCurrencyEnabled: '',
|
||||
DisplayTokenStatEnabled: '',
|
||||
ApproximateTokenEnabled: '',
|
||||
RetryTimes: 0,
|
||||
RetryTimes: 0
|
||||
});
|
||||
const [originInputs, setOriginInputs] = useState({});
|
||||
let [loading, setLoading] = useState(false);
|
||||
let [historyTimestamp, setHistoryTimestamp] = useState(timestamp2string(now.getTime() / 1000 - 30 * 24 * 3600)); // a month ago
|
||||
|
||||
const getOptions = async () => {
|
||||
const res = await API.get('/api/option/');
|
||||
@@ -130,6 +132,17 @@ const OperationSetting = () => {
|
||||
}
|
||||
};
|
||||
|
||||
const deleteHistoryLogs = async () => {
|
||||
console.log(inputs);
|
||||
const res = await API.delete(`/api/log/?target_timestamp=${Date.parse(historyTimestamp) / 1000}`);
|
||||
const { success, message, data } = res.data;
|
||||
if (success) {
|
||||
showSuccess(`${data} 条日志已清理!`);
|
||||
return;
|
||||
}
|
||||
showError('日志清理失败:' + message);
|
||||
};
|
||||
|
||||
return (
|
||||
<Grid columns={1}>
|
||||
<Grid.Column>
|
||||
@@ -179,12 +192,6 @@ const OperationSetting = () => {
|
||||
/>
|
||||
</Form.Group>
|
||||
<Form.Group inline>
|
||||
<Form.Checkbox
|
||||
checked={inputs.LogConsumeEnabled === 'true'}
|
||||
label='启用额度消费日志记录'
|
||||
name='LogConsumeEnabled'
|
||||
onChange={handleInputChange}
|
||||
/>
|
||||
<Form.Checkbox
|
||||
checked={inputs.DisplayInCurrencyEnabled === 'true'}
|
||||
label='以货币形式显示额度'
|
||||
@@ -208,6 +215,28 @@ const OperationSetting = () => {
|
||||
submitConfig('general').then();
|
||||
}}>保存通用设置</Form.Button>
|
||||
<Divider />
|
||||
<Header as='h3'>
|
||||
日志设置
|
||||
</Header>
|
||||
<Form.Group inline>
|
||||
<Form.Checkbox
|
||||
checked={inputs.LogConsumeEnabled === 'true'}
|
||||
label='启用额度消费日志记录'
|
||||
name='LogConsumeEnabled'
|
||||
onChange={handleInputChange}
|
||||
/>
|
||||
</Form.Group>
|
||||
<Form.Group widths={4}>
|
||||
<Form.Input label='目标时间' value={historyTimestamp} type='datetime-local'
|
||||
name='history_timestamp'
|
||||
onChange={(e, { name, value }) => {
|
||||
setHistoryTimestamp(value);
|
||||
}} />
|
||||
</Form.Group>
|
||||
<Form.Button onClick={() => {
|
||||
deleteHistoryLogs().then();
|
||||
}}>清理历史日志</Form.Button>
|
||||
<Divider />
|
||||
<Header as='h3'>
|
||||
监控设置
|
||||
</Header>
|
||||
|
||||
@@ -96,7 +96,7 @@ const TokensTable = () => {
|
||||
let nextUrl;
|
||||
|
||||
if (nextLink) {
|
||||
nextUrl = nextLink + `/#/?settings={"key":"sk-${key}"}`;
|
||||
nextUrl = nextLink + `/#/?settings={"key":"sk-${key}","url":"${serverAddress}"}`;
|
||||
} else {
|
||||
nextUrl = `https://chat.oneapi.pro/#/?settings={"key":"sk-${key}","url":"${serverAddress}"}`;
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user