mirror of
				https://github.com/songquanpeng/one-api.git
				synced 2025-10-26 11:23:43 +08:00 
			
		
		
		
	Compare commits
	
		
			10 Commits
		
	
	
		
			v0.6.5-alp
			...
			v0.6.5-alp
		
	
	| Author | SHA1 | Date | |
|---|---|---|---|
|  | ed70881a58 | ||
|  | 8b9fa3d6e4 | ||
|  | 8b9813d63b | ||
|  | dc7aaf2de5 | ||
|  | 065da8ef8c | ||
|  | e3cfb1fa52 | ||
|  | f89ae5ad58 | ||
|  | 06a3fc5421 | ||
|  | a9c464ec5a | ||
|  | 3f3c13c98c | 
| @@ -75,7 +75,7 @@ var ModelRatio = map[string]float64{ | ||||
| 	"ERNIE-Bot":       0.8572,     // ¥0.012 / 1k tokens | ||||
| 	"ERNIE-Bot-turbo": 0.5715,     // ¥0.008 / 1k tokens | ||||
| 	"ERNIE-Bot-4":     0.12 * RMB, // ¥0.12 / 1k tokens | ||||
| 	"ERNIE-Bot-8k":    0.024 * RMB, | ||||
| 	"ERNIE-Bot-8K":    0.024 * RMB, | ||||
| 	"Embedding-V1":    0.1429, // ¥0.002 / 1k tokens | ||||
| 	"bge-large-zh":    0.002 * RMB, | ||||
| 	"bge-large-en":    0.002 * RMB, | ||||
|   | ||||
| @@ -4,12 +4,14 @@ import ( | ||||
| 	"fmt" | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"github.com/songquanpeng/one-api/common" | ||||
| 	"github.com/songquanpeng/one-api/model" | ||||
| 	"github.com/songquanpeng/one-api/relay/channel/openai" | ||||
| 	"github.com/songquanpeng/one-api/relay/constant" | ||||
| 	"github.com/songquanpeng/one-api/relay/helper" | ||||
| 	relaymodel "github.com/songquanpeng/one-api/relay/model" | ||||
| 	"github.com/songquanpeng/one-api/relay/util" | ||||
| 	"net/http" | ||||
| 	"strings" | ||||
| ) | ||||
|  | ||||
| // https://platform.openai.com/docs/api-reference/models/list | ||||
| @@ -120,9 +122,41 @@ func DashboardListModels(c *gin.Context) { | ||||
| } | ||||
|  | ||||
| func ListModels(c *gin.Context) { | ||||
| 	ctx := c.Request.Context() | ||||
| 	var availableModels []string | ||||
| 	if c.GetString("available_models") != "" { | ||||
| 		availableModels = strings.Split(c.GetString("available_models"), ",") | ||||
| 	} else { | ||||
| 		userId := c.GetInt("id") | ||||
| 		userGroup, _ := model.CacheGetUserGroup(userId) | ||||
| 		availableModels, _ = model.CacheGetGroupModels(ctx, userGroup) | ||||
| 	} | ||||
| 	modelSet := make(map[string]bool) | ||||
| 	for _, availableModel := range availableModels { | ||||
| 		modelSet[availableModel] = true | ||||
| 	} | ||||
| 	var availableOpenAIModels []OpenAIModels | ||||
| 	for _, model := range openAIModels { | ||||
| 		if _, ok := modelSet[model.Id]; ok { | ||||
| 			modelSet[model.Id] = false | ||||
| 			availableOpenAIModels = append(availableOpenAIModels, model) | ||||
| 		} | ||||
| 	} | ||||
| 	for modelName, ok := range modelSet { | ||||
| 		if ok { | ||||
| 			availableOpenAIModels = append(availableOpenAIModels, OpenAIModels{ | ||||
| 				Id:      modelName, | ||||
| 				Object:  "model", | ||||
| 				Created: 1626777600, | ||||
| 				OwnedBy: "custom", | ||||
| 				Root:    modelName, | ||||
| 				Parent:  nil, | ||||
| 			}) | ||||
| 		} | ||||
| 	} | ||||
| 	c.JSON(200, gin.H{ | ||||
| 		"object": "list", | ||||
| 		"data":   openAIModels, | ||||
| 		"data":   availableOpenAIModels, | ||||
| 	}) | ||||
| } | ||||
|  | ||||
| @@ -142,3 +176,30 @@ func RetrieveModel(c *gin.Context) { | ||||
| 		}) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func GetUserAvailableModels(c *gin.Context) { | ||||
| 	ctx := c.Request.Context() | ||||
| 	id := c.GetInt("id") | ||||
| 	userGroup, err := model.CacheGetUserGroup(id) | ||||
| 	if err != nil { | ||||
| 		c.JSON(http.StatusOK, gin.H{ | ||||
| 			"success": false, | ||||
| 			"message": err.Error(), | ||||
| 		}) | ||||
| 		return | ||||
| 	} | ||||
| 	models, err := model.CacheGetGroupModels(ctx, userGroup) | ||||
| 	if err != nil { | ||||
| 		c.JSON(http.StatusOK, gin.H{ | ||||
| 			"success": false, | ||||
| 			"message": err.Error(), | ||||
| 		}) | ||||
| 		return | ||||
| 	} | ||||
| 	c.JSON(http.StatusOK, gin.H{ | ||||
| 		"success": true, | ||||
| 		"message": "", | ||||
| 		"data":    models, | ||||
| 	}) | ||||
| 	return | ||||
| } | ||||
|   | ||||
| @@ -130,6 +130,7 @@ func AddToken(c *gin.Context) { | ||||
| 		ExpiredTime:    token.ExpiredTime, | ||||
| 		RemainQuota:    token.RemainQuota, | ||||
| 		UnlimitedQuota: token.UnlimitedQuota, | ||||
| 		Models:         token.Models, | ||||
| 	} | ||||
| 	err = cleanToken.Insert() | ||||
| 	if err != nil { | ||||
| @@ -216,6 +217,7 @@ func UpdateToken(c *gin.Context) { | ||||
| 		cleanToken.ExpiredTime = token.ExpiredTime | ||||
| 		cleanToken.RemainQuota = token.RemainQuota | ||||
| 		cleanToken.UnlimitedQuota = token.UnlimitedQuota | ||||
| 		cleanToken.Models = token.Models | ||||
| 	} | ||||
| 	err = cleanToken.Update() | ||||
| 	if err != nil { | ||||
|   | ||||
| @@ -1,6 +1,7 @@ | ||||
| package middleware | ||||
|  | ||||
| import ( | ||||
| 	"fmt" | ||||
| 	"github.com/gin-contrib/sessions" | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"github.com/songquanpeng/one-api/common" | ||||
| @@ -107,6 +108,19 @@ func TokenAuth() func(c *gin.Context) { | ||||
| 			abortWithMessage(c, http.StatusForbidden, "用户已被封禁") | ||||
| 			return | ||||
| 		} | ||||
| 		requestModel, err := getRequestModel(c) | ||||
| 		if err != nil { | ||||
| 			abortWithMessage(c, http.StatusBadRequest, err.Error()) | ||||
| 			return | ||||
| 		} | ||||
| 		c.Set("request_model", requestModel) | ||||
| 		if token.Models != nil && *token.Models != "" { | ||||
| 			c.Set("available_models", *token.Models) | ||||
| 			if requestModel != "" && !isModelInList(requestModel, *token.Models) { | ||||
| 				abortWithMessage(c, http.StatusForbidden, fmt.Sprintf("该令牌无权使用模型:%s", requestModel)) | ||||
| 				return | ||||
| 			} | ||||
| 		} | ||||
| 		c.Set("id", token.UserId) | ||||
| 		c.Set("token_id", token.Id) | ||||
| 		c.Set("token_name", token.Name) | ||||
|   | ||||
| @@ -2,14 +2,12 @@ package middleware | ||||
|  | ||||
| import ( | ||||
| 	"fmt" | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"github.com/songquanpeng/one-api/common" | ||||
| 	"github.com/songquanpeng/one-api/common/logger" | ||||
| 	"github.com/songquanpeng/one-api/model" | ||||
| 	"net/http" | ||||
| 	"strconv" | ||||
| 	"strings" | ||||
|  | ||||
| 	"github.com/gin-gonic/gin" | ||||
| ) | ||||
|  | ||||
| type ModelRequest struct { | ||||
| @@ -40,37 +38,11 @@ func Distribute() func(c *gin.Context) { | ||||
| 				return | ||||
| 			} | ||||
| 		} else { | ||||
| 			// Select a channel for the user | ||||
| 			var modelRequest ModelRequest | ||||
| 			err := common.UnmarshalBodyReusable(c, &modelRequest) | ||||
| 			requestModel := c.GetString("request_model") | ||||
| 			var err error | ||||
| 			channel, err = model.CacheGetRandomSatisfiedChannel(userGroup, requestModel, false) | ||||
| 			if err != nil { | ||||
| 				abortWithMessage(c, http.StatusBadRequest, "无效的请求") | ||||
| 				return | ||||
| 			} | ||||
| 			if strings.HasPrefix(c.Request.URL.Path, "/v1/moderations") { | ||||
| 				if modelRequest.Model == "" { | ||||
| 					modelRequest.Model = "text-moderation-stable" | ||||
| 				} | ||||
| 			} | ||||
| 			if strings.HasSuffix(c.Request.URL.Path, "embeddings") { | ||||
| 				if modelRequest.Model == "" { | ||||
| 					modelRequest.Model = c.Param("model") | ||||
| 				} | ||||
| 			} | ||||
| 			if strings.HasPrefix(c.Request.URL.Path, "/v1/images/generations") { | ||||
| 				if modelRequest.Model == "" { | ||||
| 					modelRequest.Model = "dall-e-2" | ||||
| 				} | ||||
| 			} | ||||
| 			if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") || strings.HasPrefix(c.Request.URL.Path, "/v1/audio/translations") { | ||||
| 				if modelRequest.Model == "" { | ||||
| 					modelRequest.Model = "whisper-1" | ||||
| 				} | ||||
| 			} | ||||
| 			requestModel = modelRequest.Model | ||||
| 			channel, err = model.CacheGetRandomSatisfiedChannel(userGroup, modelRequest.Model, false) | ||||
| 			if err != nil { | ||||
| 				message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", userGroup, modelRequest.Model) | ||||
| 				message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", userGroup, requestModel) | ||||
| 				if channel != nil { | ||||
| 					logger.SysError(fmt.Sprintf("渠道不存在:%d", channel.Id)) | ||||
| 					message = "数据库一致性已被破坏,请联系管理员" | ||||
|   | ||||
| @@ -1,9 +1,12 @@ | ||||
| package middleware | ||||
|  | ||||
| import ( | ||||
| 	"fmt" | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"github.com/songquanpeng/one-api/common" | ||||
| 	"github.com/songquanpeng/one-api/common/helper" | ||||
| 	"github.com/songquanpeng/one-api/common/logger" | ||||
| 	"strings" | ||||
| ) | ||||
|  | ||||
| func abortWithMessage(c *gin.Context, statusCode int, message string) { | ||||
| @@ -16,3 +19,42 @@ func abortWithMessage(c *gin.Context, statusCode int, message string) { | ||||
| 	c.Abort() | ||||
| 	logger.Error(c.Request.Context(), message) | ||||
| } | ||||
|  | ||||
| func getRequestModel(c *gin.Context) (string, error) { | ||||
| 	var modelRequest ModelRequest | ||||
| 	err := common.UnmarshalBodyReusable(c, &modelRequest) | ||||
| 	if err != nil { | ||||
| 		return "", fmt.Errorf("common.UnmarshalBodyReusable failed: %w", err) | ||||
| 	} | ||||
| 	if strings.HasPrefix(c.Request.URL.Path, "/v1/moderations") { | ||||
| 		if modelRequest.Model == "" { | ||||
| 			modelRequest.Model = "text-moderation-stable" | ||||
| 		} | ||||
| 	} | ||||
| 	if strings.HasSuffix(c.Request.URL.Path, "embeddings") { | ||||
| 		if modelRequest.Model == "" { | ||||
| 			modelRequest.Model = c.Param("model") | ||||
| 		} | ||||
| 	} | ||||
| 	if strings.HasPrefix(c.Request.URL.Path, "/v1/images/generations") { | ||||
| 		if modelRequest.Model == "" { | ||||
| 			modelRequest.Model = "dall-e-2" | ||||
| 		} | ||||
| 	} | ||||
| 	if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") || strings.HasPrefix(c.Request.URL.Path, "/v1/audio/translations") { | ||||
| 		if modelRequest.Model == "" { | ||||
| 			modelRequest.Model = "whisper-1" | ||||
| 		} | ||||
| 	} | ||||
| 	return modelRequest.Model, nil | ||||
| } | ||||
|  | ||||
| func isModelInList(modelName string, models string) bool { | ||||
| 	modelList := strings.Split(models, ",") | ||||
| 	for _, model := range modelList { | ||||
| 		if modelName == model { | ||||
| 			return true | ||||
| 		} | ||||
| 	} | ||||
| 	return false | ||||
| } | ||||
|   | ||||
| @@ -1,8 +1,10 @@ | ||||
| package model | ||||
|  | ||||
| import ( | ||||
| 	"context" | ||||
| 	"github.com/songquanpeng/one-api/common" | ||||
| 	"gorm.io/gorm" | ||||
| 	"sort" | ||||
| 	"strings" | ||||
| ) | ||||
|  | ||||
| @@ -88,3 +90,19 @@ func (channel *Channel) UpdateAbilities() error { | ||||
| func UpdateAbilityStatus(channelId int, status bool) error { | ||||
| 	return DB.Model(&Ability{}).Where("channel_id = ?", channelId).Select("enabled").Update("enabled", status).Error | ||||
| } | ||||
|  | ||||
| func GetGroupModels(ctx context.Context, group string) ([]string, error) { | ||||
| 	groupCol := "`group`" | ||||
| 	trueVal := "1" | ||||
| 	if common.UsingPostgreSQL { | ||||
| 		groupCol = `"group"` | ||||
| 		trueVal = "true" | ||||
| 	} | ||||
| 	var models []string | ||||
| 	err := DB.Model(&Ability{}).Distinct("model").Where(groupCol+" = ? and enabled = "+trueVal, group).Pluck("model", &models).Error | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	sort.Strings(models) | ||||
| 	return models, err | ||||
| } | ||||
|   | ||||
| @@ -21,6 +21,7 @@ var ( | ||||
| 	UserId2GroupCacheSeconds  = config.SyncFrequency | ||||
| 	UserId2QuotaCacheSeconds  = config.SyncFrequency | ||||
| 	UserId2StatusCacheSeconds = config.SyncFrequency | ||||
| 	GroupModelsCacheSeconds   = config.SyncFrequency | ||||
| ) | ||||
|  | ||||
| func CacheGetTokenByKey(key string) (*Token, error) { | ||||
| @@ -146,6 +147,25 @@ func CacheIsUserEnabled(userId int) (bool, error) { | ||||
| 	return userEnabled, err | ||||
| } | ||||
|  | ||||
| func CacheGetGroupModels(ctx context.Context, group string) ([]string, error) { | ||||
| 	if !common.RedisEnabled { | ||||
| 		return GetGroupModels(ctx, group) | ||||
| 	} | ||||
| 	modelsStr, err := common.RedisGet(fmt.Sprintf("group_models:%s", group)) | ||||
| 	if err == nil { | ||||
| 		return strings.Split(modelsStr, ","), nil | ||||
| 	} | ||||
| 	models, err := GetGroupModels(ctx, group) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	err = common.RedisSet(fmt.Sprintf("group_models:%s", group), strings.Join(models, ","), time.Duration(GroupModelsCacheSeconds)*time.Second) | ||||
| 	if err != nil { | ||||
| 		logger.SysError("Redis set group models error: " + err.Error()) | ||||
| 	} | ||||
| 	return models, nil | ||||
| } | ||||
|  | ||||
| var group2model2channels map[string]map[string][]*Channel | ||||
| var channelSyncLock sync.RWMutex | ||||
|  | ||||
|   | ||||
| @@ -12,24 +12,25 @@ import ( | ||||
| ) | ||||
|  | ||||
| type Token struct { | ||||
| 	Id             int    `json:"id"` | ||||
| 	UserId         int    `json:"user_id"` | ||||
| 	Key            string `json:"key" gorm:"type:char(48);uniqueIndex"` | ||||
| 	Status         int    `json:"status" gorm:"default:1"` | ||||
| 	Name           string `json:"name" gorm:"index" ` | ||||
| 	CreatedTime    int64  `json:"created_time" gorm:"bigint"` | ||||
| 	AccessedTime   int64  `json:"accessed_time" gorm:"bigint"` | ||||
| 	ExpiredTime    int64  `json:"expired_time" gorm:"bigint;default:-1"` // -1 means never expired | ||||
| 	RemainQuota    int64  `json:"remain_quota" gorm:"bigint;default:0"` | ||||
| 	UnlimitedQuota bool   `json:"unlimited_quota" gorm:"default:false"` | ||||
| 	UsedQuota      int64  `json:"used_quota" gorm:"bigint;default:0"` // used quota | ||||
| 	Id             int     `json:"id"` | ||||
| 	UserId         int     `json:"user_id"` | ||||
| 	Key            string  `json:"key" gorm:"type:char(48);uniqueIndex"` | ||||
| 	Status         int     `json:"status" gorm:"default:1"` | ||||
| 	Name           string  `json:"name" gorm:"index" ` | ||||
| 	CreatedTime    int64   `json:"created_time" gorm:"bigint"` | ||||
| 	AccessedTime   int64   `json:"accessed_time" gorm:"bigint"` | ||||
| 	ExpiredTime    int64   `json:"expired_time" gorm:"bigint;default:-1"` // -1 means never expired | ||||
| 	RemainQuota    int64   `json:"remain_quota" gorm:"bigint;default:0"` | ||||
| 	UnlimitedQuota bool    `json:"unlimited_quota" gorm:"default:false"` | ||||
| 	UsedQuota      int64   `json:"used_quota" gorm:"bigint;default:0"` // used quota | ||||
| 	Models         *string `json:"models" gorm:"default:''"` | ||||
| } | ||||
|  | ||||
| func GetAllUserTokens(userId int, startIdx int, num int, order string) ([]*Token, error) { | ||||
| 	var tokens []*Token | ||||
| 	var err error | ||||
| 	query := DB.Where("user_id = ?", userId) | ||||
| 	 | ||||
|  | ||||
| 	switch order { | ||||
| 	case "remain_quota": | ||||
| 		query = query.Order("unlimited_quota desc, remain_quota desc") | ||||
| @@ -38,7 +39,7 @@ func GetAllUserTokens(userId int, startIdx int, num int, order string) ([]*Token | ||||
| 	default: | ||||
| 		query = query.Order("id desc") | ||||
| 	} | ||||
| 	 | ||||
|  | ||||
| 	err = query.Limit(num).Offset(startIdx).Find(&tokens).Error | ||||
| 	return tokens, err | ||||
| } | ||||
| @@ -121,7 +122,7 @@ func (token *Token) Insert() error { | ||||
| // Update Make sure your token's fields is completed, because this will update non-zero values | ||||
| func (token *Token) Update() error { | ||||
| 	var err error | ||||
| 	err = DB.Model(token).Select("name", "status", "expired_time", "remain_quota", "unlimited_quota").Updates(token).Error | ||||
| 	err = DB.Model(token).Select("name", "status", "expired_time", "remain_quota", "unlimited_quota", "models").Updates(token).Error | ||||
| 	return err | ||||
| } | ||||
|  | ||||
|   | ||||
| @@ -50,8 +50,8 @@ func ConvertRequest(request model.GeneralOpenAIRequest) *ChatRequest { | ||||
| 			TopP:              request.TopP, | ||||
| 			TopK:              request.TopK, | ||||
| 			ResultFormat:      "message", | ||||
| 			Tools:             request.Tools, | ||||
| 		}, | ||||
| 		Tools: request.Tools, | ||||
| 	} | ||||
| } | ||||
|  | ||||
|   | ||||
| @@ -16,21 +16,21 @@ type Input struct { | ||||
| } | ||||
|  | ||||
| type Parameters struct { | ||||
| 	TopP              float64 `json:"top_p,omitempty"` | ||||
| 	TopK              int     `json:"top_k,omitempty"` | ||||
| 	Seed              uint64  `json:"seed,omitempty"` | ||||
| 	EnableSearch      bool    `json:"enable_search,omitempty"` | ||||
| 	IncrementalOutput bool    `json:"incremental_output,omitempty"` | ||||
| 	MaxTokens         int     `json:"max_tokens,omitempty"` | ||||
| 	Temperature       float64 `json:"temperature,omitempty"` | ||||
| 	ResultFormat      string  `json:"result_format,omitempty"` | ||||
| 	TopP              float64      `json:"top_p,omitempty"` | ||||
| 	TopK              int          `json:"top_k,omitempty"` | ||||
| 	Seed              uint64       `json:"seed,omitempty"` | ||||
| 	EnableSearch      bool         `json:"enable_search,omitempty"` | ||||
| 	IncrementalOutput bool         `json:"incremental_output,omitempty"` | ||||
| 	MaxTokens         int          `json:"max_tokens,omitempty"` | ||||
| 	Temperature       float64      `json:"temperature,omitempty"` | ||||
| 	ResultFormat      string       `json:"result_format,omitempty"` | ||||
| 	Tools             []model.Tool `json:"tools,omitempty"` | ||||
| } | ||||
|  | ||||
| type ChatRequest struct { | ||||
| 	Model      string       `json:"model"` | ||||
| 	Input      Input        `json:"input"` | ||||
| 	Parameters Parameters   `json:"parameters,omitempty"` | ||||
| 	Tools      []model.Tool `json:"tools,omitempty"` | ||||
| 	Model      string     `json:"model"` | ||||
| 	Input      Input      `json:"input"` | ||||
| 	Parameters Parameters `json:"parameters,omitempty"` | ||||
| } | ||||
|  | ||||
| type EmbeddingRequest struct { | ||||
|   | ||||
| @@ -38,6 +38,7 @@ func ConvertRequest(textRequest model.GeneralOpenAIRequest) *Request { | ||||
| 		MaxTokens:   textRequest.MaxTokens, | ||||
| 		Temperature: textRequest.Temperature, | ||||
| 		TopP:        textRequest.TopP, | ||||
| 		TopK:        textRequest.TopK, | ||||
| 		Stream:      textRequest.Stream, | ||||
| 	} | ||||
| 	if claudeRequest.MaxTokens == 0 { | ||||
|   | ||||
| @@ -70,8 +70,10 @@ func (a *Adaptor) DoRequest(c *gin.Context, meta *util.RelayMeta, requestBody io | ||||
| func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.RelayMeta) (usage *model.Usage, err *model.ErrorWithStatusCode) { | ||||
| 	if meta.IsStream { | ||||
| 		var responseText string | ||||
| 		err, responseText, _ = StreamHandler(c, resp, meta.Mode) | ||||
| 		usage = ResponseText2Usage(responseText, meta.ActualModelName, meta.PromptTokens) | ||||
| 		err, responseText, usage = StreamHandler(c, resp, meta.Mode) | ||||
| 		if usage == nil { | ||||
| 			usage = ResponseText2Usage(responseText, meta.ActualModelName, meta.PromptTokens) | ||||
| 		} | ||||
| 	} else { | ||||
| 		err, usage = Handler(c, resp, meta.PromptTokens, meta.ActualModelName) | ||||
| 	} | ||||
|   | ||||
| @@ -26,7 +26,11 @@ import ( | ||||
|  | ||||
| func requestOpenAI2Xunfei(request model.GeneralOpenAIRequest, xunfeiAppId string, domain string) *ChatRequest { | ||||
| 	messages := make([]Message, 0, len(request.Messages)) | ||||
| 	var lastToolCalls []model.Tool | ||||
| 	for _, message := range request.Messages { | ||||
| 		if message.ToolCalls != nil { | ||||
| 			lastToolCalls = message.ToolCalls | ||||
| 		} | ||||
| 		messages = append(messages, Message{ | ||||
| 			Role:    message.Role, | ||||
| 			Content: message.StringContent(), | ||||
| @@ -39,9 +43,33 @@ func requestOpenAI2Xunfei(request model.GeneralOpenAIRequest, xunfeiAppId string | ||||
| 	xunfeiRequest.Parameter.Chat.TopK = request.N | ||||
| 	xunfeiRequest.Parameter.Chat.MaxTokens = request.MaxTokens | ||||
| 	xunfeiRequest.Payload.Message.Text = messages | ||||
| 	if len(lastToolCalls) != 0 { | ||||
| 		for _, toolCall := range lastToolCalls { | ||||
| 			xunfeiRequest.Payload.Functions.Text = append(xunfeiRequest.Payload.Functions.Text, toolCall.Function) | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	return &xunfeiRequest | ||||
| } | ||||
|  | ||||
| func getToolCalls(response *ChatResponse) []model.Tool { | ||||
| 	var toolCalls []model.Tool | ||||
| 	if len(response.Payload.Choices.Text) == 0 { | ||||
| 		return toolCalls | ||||
| 	} | ||||
| 	item := response.Payload.Choices.Text[0] | ||||
| 	if item.FunctionCall == nil { | ||||
| 		return toolCalls | ||||
| 	} | ||||
| 	toolCall := model.Tool{ | ||||
| 		Id:       fmt.Sprintf("call_%s", helper.GetUUID()), | ||||
| 		Type:     "function", | ||||
| 		Function: *item.FunctionCall, | ||||
| 	} | ||||
| 	toolCalls = append(toolCalls, toolCall) | ||||
| 	return toolCalls | ||||
| } | ||||
|  | ||||
| func responseXunfei2OpenAI(response *ChatResponse) *openai.TextResponse { | ||||
| 	if len(response.Payload.Choices.Text) == 0 { | ||||
| 		response.Payload.Choices.Text = []ChatResponseTextItem{ | ||||
| @@ -53,8 +81,9 @@ func responseXunfei2OpenAI(response *ChatResponse) *openai.TextResponse { | ||||
| 	choice := openai.TextResponseChoice{ | ||||
| 		Index: 0, | ||||
| 		Message: model.Message{ | ||||
| 			Role:    "assistant", | ||||
| 			Content: response.Payload.Choices.Text[0].Content, | ||||
| 			Role:      "assistant", | ||||
| 			Content:   response.Payload.Choices.Text[0].Content, | ||||
| 			ToolCalls: getToolCalls(response), | ||||
| 		}, | ||||
| 		FinishReason: constant.StopFinishReason, | ||||
| 	} | ||||
| @@ -78,6 +107,7 @@ func streamResponseXunfei2OpenAI(xunfeiResponse *ChatResponse) *openai.ChatCompl | ||||
| 	} | ||||
| 	var choice openai.ChatCompletionsStreamResponseChoice | ||||
| 	choice.Delta.Content = xunfeiResponse.Payload.Choices.Text[0].Content | ||||
| 	choice.Delta.ToolCalls = getToolCalls(xunfeiResponse) | ||||
| 	if xunfeiResponse.Payload.Choices.Status == 2 { | ||||
| 		choice.FinishReason = &constant.StopFinishReason | ||||
| 	} | ||||
|   | ||||
| @@ -26,13 +26,18 @@ type ChatRequest struct { | ||||
| 		Message struct { | ||||
| 			Text []Message `json:"text"` | ||||
| 		} `json:"message"` | ||||
| 		Functions struct { | ||||
| 			Text []model.Function `json:"text,omitempty"` | ||||
| 		} `json:"functions"` | ||||
| 	} `json:"payload"` | ||||
| } | ||||
|  | ||||
| type ChatResponseTextItem struct { | ||||
| 	Content string `json:"content"` | ||||
| 	Role    string `json:"role"` | ||||
| 	Index   int    `json:"index"` | ||||
| 	Content      string          `json:"content"` | ||||
| 	Role         string          `json:"role"` | ||||
| 	Index        int             `json:"index"` | ||||
| 	ContentType  string          `json:"content_type"` | ||||
| 	FunctionCall *model.Function `json:"function_call"` | ||||
| } | ||||
|  | ||||
| type ChatResponse struct { | ||||
|   | ||||
| @@ -24,6 +24,8 @@ type GeneralOpenAIRequest struct { | ||||
| 	User             string          `json:"user,omitempty"` | ||||
| 	Prompt           any             `json:"prompt,omitempty"` | ||||
| 	Input            any             `json:"input,omitempty"` | ||||
| 	EncodingFormat   string          `json:"encoding_format,omitempty"` | ||||
| 	Dimensions       int             `json:"dimensions,omitempty"` | ||||
| 	Instruction      string          `json:"instruction,omitempty"` | ||||
| 	Size             string          `json:"size,omitempty"` | ||||
| } | ||||
|   | ||||
| @@ -43,6 +43,7 @@ func SetApiRouter(router *gin.Engine) { | ||||
| 				selfRoute.GET("/token", controller.GenerateAccessToken) | ||||
| 				selfRoute.GET("/aff", controller.GetAffCode) | ||||
| 				selfRoute.POST("/topup", controller.TopUp) | ||||
| 				selfRoute.GET("/available_models", controller.GetUserAvailableModels) | ||||
| 			} | ||||
|  | ||||
| 			adminRoute := userRoute.Group("/") | ||||
|   | ||||
| @@ -1,19 +1,21 @@ | ||||
| import React, { useEffect, useState } from 'react'; | ||||
| import { Button, Form, Header, Message, Segment } from 'semantic-ui-react'; | ||||
| import { useParams, useNavigate } from 'react-router-dom'; | ||||
| import { API, showError, showSuccess, timestamp2string } from '../../helpers'; | ||||
| import { renderQuota, renderQuotaWithPrompt } from '../../helpers/render'; | ||||
| import { useNavigate, useParams } from 'react-router-dom'; | ||||
| import { API, copy, showError, showSuccess, timestamp2string } from '../../helpers'; | ||||
| import { renderQuotaWithPrompt } from '../../helpers/render'; | ||||
|  | ||||
| const EditToken = () => { | ||||
|   const params = useParams(); | ||||
|   const tokenId = params.id; | ||||
|   const isEdit = tokenId !== undefined; | ||||
|   const [loading, setLoading] = useState(isEdit); | ||||
|   const [modelOptions, setModelOptions] = useState([]); | ||||
|   const originInputs = { | ||||
|     name: '', | ||||
|     remain_quota: isEdit ? 0 : 500000, | ||||
|     expired_time: -1, | ||||
|     unlimited_quota: false | ||||
|     unlimited_quota: false, | ||||
|     models: [] | ||||
|   }; | ||||
|   const [inputs, setInputs] = useState(originInputs); | ||||
|   const { name, remain_quota, expired_time, unlimited_quota } = inputs; | ||||
| @@ -22,8 +24,8 @@ const EditToken = () => { | ||||
|     setInputs((inputs) => ({ ...inputs, [name]: value })); | ||||
|   }; | ||||
|   const handleCancel = () => { | ||||
|     navigate("/token"); | ||||
|   } | ||||
|     navigate('/token'); | ||||
|   }; | ||||
|   const setExpiredTime = (month, day, hour, minute) => { | ||||
|     let now = new Date(); | ||||
|     let timestamp = now.getTime() / 1000; | ||||
| @@ -50,6 +52,11 @@ const EditToken = () => { | ||||
|       if (data.expired_time !== -1) { | ||||
|         data.expired_time = timestamp2string(data.expired_time); | ||||
|       } | ||||
|       if (data.models === '') { | ||||
|         data.models = []; | ||||
|       } else { | ||||
|         data.models = data.models.split(','); | ||||
|       } | ||||
|       setInputs(data); | ||||
|     } else { | ||||
|       showError(message); | ||||
| @@ -60,8 +67,26 @@ const EditToken = () => { | ||||
|     if (isEdit) { | ||||
|       loadToken().then(); | ||||
|     } | ||||
|     loadAvailableModels().then(); | ||||
|   }, []); | ||||
|  | ||||
|   const loadAvailableModels = async () => { | ||||
|     let res = await API.get(`/api/user/available_models`); | ||||
|     const { success, message, data } = res.data; | ||||
|     if (success) { | ||||
|       let options = data.map((model) => { | ||||
|         return { | ||||
|           key: model, | ||||
|           text: model, | ||||
|           value: model | ||||
|         }; | ||||
|       }); | ||||
|       setModelOptions(options); | ||||
|     } else { | ||||
|       showError(message); | ||||
|     } | ||||
|   }; | ||||
|  | ||||
|   const submit = async () => { | ||||
|     if (!isEdit && inputs.name === '') return; | ||||
|     let localInputs = inputs; | ||||
| @@ -74,6 +99,7 @@ const EditToken = () => { | ||||
|       } | ||||
|       localInputs.expired_time = Math.ceil(time / 1000); | ||||
|     } | ||||
|     localInputs.models = localInputs.models.join(','); | ||||
|     let res; | ||||
|     if (isEdit) { | ||||
|       res = await API.put(`/api/token/`, { ...localInputs, id: parseInt(tokenId) }); | ||||
| @@ -109,6 +135,24 @@ const EditToken = () => { | ||||
|               required={!isEdit} | ||||
|             /> | ||||
|           </Form.Field> | ||||
|           <Form.Field> | ||||
|             <Form.Dropdown | ||||
|               label='模型范围' | ||||
|               placeholder={'请选择允许使用的模型,留空则不进行限制'} | ||||
|               name='models' | ||||
|               fluid | ||||
|               multiple | ||||
|               search | ||||
|               onLabelClick={(e, { value }) => { | ||||
|                 copy(value).then(); | ||||
|               }} | ||||
|               selection | ||||
|               onChange={handleInputChange} | ||||
|               value={inputs.models} | ||||
|               autoComplete='new-password' | ||||
|               options={modelOptions} | ||||
|             /> | ||||
|           </Form.Field> | ||||
|           <Form.Field> | ||||
|             <Form.Input | ||||
|               label='过期时间' | ||||
|   | ||||
		Reference in New Issue
	
	Block a user