Mirror of https://github.com/songquanpeng/one-api.git
Compare commits: v0.6.4-alp ... v0.6.5-alp (11 commits)
Commits (SHA1):
ed70881a58
8b9fa3d6e4
8b9813d63b
dc7aaf2de5
065da8ef8c
e3cfb1fa52
f89ae5ad58
06a3fc5421
a9c464ec5a
3f3c13c98c
2ba28c72cb

common/conv/any.go (new file, +6)
@@ -0,0 +1,6 @@
+package conv
+
+func AsString(v any) string {
+	str, _ := v.(string)
+	return str
+}
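
The new conv.AsString is a checked type assertion whose failure case falls back to the zero value, so nil, numbers, and structured content all come back as "". A minimal standalone sketch of that behavior, with the helper body copied from the new file:

package main

import "fmt"

// AsString, copied from common/conv/any.go: a checked type
// assertion whose failure case yields the zero value "".
func AsString(v any) string {
	str, _ := v.(string)
	return str
}

func main() {
	fmt.Printf("%q\n", AsString("hello"))          // "hello"
	fmt.Printf("%q\n", AsString(nil))              // "" since nil is not a string
	fmt.Printf("%q\n", AsString(42))               // "" since ints are not strings
	fmt.Printf("%q\n", AsString([]any{"chunked"})) // "" for structured content
}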

@@ -75,7 +75,7 @@ var ModelRatio = map[string]float64{
 	"ERNIE-Bot":       0.8572,     // ¥0.012 / 1k tokens
 	"ERNIE-Bot-turbo": 0.5715,     // ¥0.008 / 1k tokens
 	"ERNIE-Bot-4":     0.12 * RMB, // ¥0.12 / 1k tokens
-	"ERNIE-Bot-8k":    0.024 * RMB,
+	"ERNIE-Bot-8K":    0.024 * RMB,
 	"Embedding-V1":    0.1429, // ¥0.002 / 1k tokens
 	"bge-large-zh":    0.002 * RMB,
 	"bge-large-en":    0.002 * RMB,
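
The ratios in this table are multiples of a base price of $0.002 per 1k tokens. Assuming the constants the surrounding file usually defines (USD = 500 units per dollar, RMB = USD / 7; neither appears in this hunk, so treat them as an assumption), the ¥ comments line up with the numbers:

package main

import "fmt"

// Assumed constants from the surrounding file, not shown in this diff:
// one ratio unit is $0.002 per 1k tokens, so $1 equals 500 units, and
// RMB is derived with a fixed 7:1 exchange rate.
const (
	USD = 500.0
	RMB = USD / 7.0
)

func main() {
	fmt.Printf("%.4f\n", 0.012*RMB) // 0.8571, matching the "ERNIE-Bot" ratio (the table rounds to 0.8572)
	fmt.Printf("%.4f\n", 0.024*RMB) // 1.7143, the value behind "ERNIE-Bot-8K": 0.024 * RMB
}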

@@ -4,12 +4,14 @@ import (
 	"fmt"
 	"github.com/gin-gonic/gin"
 	"github.com/songquanpeng/one-api/common"
+	"github.com/songquanpeng/one-api/model"
 	"github.com/songquanpeng/one-api/relay/channel/openai"
 	"github.com/songquanpeng/one-api/relay/constant"
 	"github.com/songquanpeng/one-api/relay/helper"
 	relaymodel "github.com/songquanpeng/one-api/relay/model"
 	"github.com/songquanpeng/one-api/relay/util"
 	"net/http"
+	"strings"
 )

 // https://platform.openai.com/docs/api-reference/models/list
@@ -120,9 +122,41 @@ func DashboardListModels(c *gin.Context) {
 }

 func ListModels(c *gin.Context) {
+	ctx := c.Request.Context()
+	var availableModels []string
+	if c.GetString("available_models") != "" {
+		availableModels = strings.Split(c.GetString("available_models"), ",")
+	} else {
+		userId := c.GetInt("id")
+		userGroup, _ := model.CacheGetUserGroup(userId)
+		availableModels, _ = model.CacheGetGroupModels(ctx, userGroup)
+	}
+	modelSet := make(map[string]bool)
+	for _, availableModel := range availableModels {
+		modelSet[availableModel] = true
+	}
+	var availableOpenAIModels []OpenAIModels
+	for _, model := range openAIModels {
+		if _, ok := modelSet[model.Id]; ok {
+			modelSet[model.Id] = false
+			availableOpenAIModels = append(availableOpenAIModels, model)
+		}
+	}
+	for modelName, ok := range modelSet {
+		if ok {
+			availableOpenAIModels = append(availableOpenAIModels, OpenAIModels{
+				Id:      modelName,
+				Object:  "model",
+				Created: 1626777600,
+				OwnedBy: "custom",
+				Root:    modelName,
+				Parent:  nil,
+			})
+		}
+	}
 	c.JSON(200, gin.H{
 		"object": "list",
-		"data":   openAIModels,
+		"data":   availableOpenAIModels,
 	})
 }

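The rewritten ListModels works as a two-pass set: every model the caller may use starts out marked true, models that openAIModels knows get their full metadata and are flipped to false, and whatever stays true is emitted as a synthetic "custom" entry so channel-only models still appear in /v1/models. A stripped-down sketch of that logic with local stand-in types:

package main

import "fmt"

type OpenAIModels struct {
	Id      string
	OwnedBy string
}

func main() {
	known := []OpenAIModels{{Id: "gpt-3.5-turbo", OwnedBy: "openai"}}
	available := []string{"gpt-3.5-turbo", "my-private-model"}

	modelSet := make(map[string]bool)
	for _, m := range available {
		modelSet[m] = true
	}
	var out []OpenAIModels
	for _, m := range known {
		if _, ok := modelSet[m.Id]; ok {
			modelSet[m.Id] = false // claimed: keep the real metadata
			out = append(out, m)
		}
	}
	for name, unclaimed := range modelSet {
		if unclaimed { // no metadata on file: synthesize a "custom" entry
			out = append(out, OpenAIModels{Id: name, OwnedBy: "custom"})
		}
	}
	fmt.Println(out) // [{gpt-3.5-turbo openai} {my-private-model custom}]
}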

@@ -142,3 +176,30 @@ func RetrieveModel(c *gin.Context) {
 		})
 	}
 }
+
+func GetUserAvailableModels(c *gin.Context) {
+	ctx := c.Request.Context()
+	id := c.GetInt("id")
+	userGroup, err := model.CacheGetUserGroup(id)
+	if err != nil {
+		c.JSON(http.StatusOK, gin.H{
+			"success": false,
+			"message": err.Error(),
+		})
+		return
+	}
+	models, err := model.CacheGetGroupModels(ctx, userGroup)
+	if err != nil {
+		c.JSON(http.StatusOK, gin.H{
+			"success": false,
+			"message": err.Error(),
+		})
+		return
+	}
+	c.JSON(http.StatusOK, gin.H{
+		"success": true,
+		"message": "",
+		"data":    models,
+	})
+	return
+}
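
Note the handler answers HTTP 200 even on failure and signals errors in the body, like the rest of the /api handlers. A client-side sketch of decoding the success payload, with the struct shape inferred from the gin.H literal above:

package main

import (
	"encoding/json"
	"fmt"
)

// Shape of the GET /api/user/available_models response body,
// inferred from the gin.H literal in GetUserAvailableModels.
type availableModelsResponse struct {
	Success bool     `json:"success"`
	Message string   `json:"message"`
	Data    []string `json:"data"`
}

func main() {
	body := []byte(`{"success":true,"message":"","data":["gpt-3.5-turbo","gpt-4"]}`)
	var r availableModelsResponse
	if err := json.Unmarshal(body, &r); err != nil {
		panic(err)
	}
	fmt.Println(r.Success, r.Data) // true [gpt-3.5-turbo gpt-4]
}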

@@ -130,6 +130,7 @@ func AddToken(c *gin.Context) {
 		ExpiredTime:    token.ExpiredTime,
 		RemainQuota:    token.RemainQuota,
 		UnlimitedQuota: token.UnlimitedQuota,
+		Models:         token.Models,
 	}
 	err = cleanToken.Insert()
 	if err != nil {
@@ -216,6 +217,7 @@ func UpdateToken(c *gin.Context) {
 		cleanToken.ExpiredTime = token.ExpiredTime
 		cleanToken.RemainQuota = token.RemainQuota
 		cleanToken.UnlimitedQuota = token.UnlimitedQuota
+		cleanToken.Models = token.Models
 	}
 	err = cleanToken.Update()
 	if err != nil {

@@ -1,6 +1,7 @@
 package middleware

 import (
+	"fmt"
 	"github.com/gin-contrib/sessions"
 	"github.com/gin-gonic/gin"
 	"github.com/songquanpeng/one-api/common"
@@ -107,6 +108,19 @@ func TokenAuth() func(c *gin.Context) {
 			abortWithMessage(c, http.StatusForbidden, "用户已被封禁")
 			return
 		}
+		requestModel, err := getRequestModel(c)
+		if err != nil {
+			abortWithMessage(c, http.StatusBadRequest, err.Error())
+			return
+		}
+		c.Set("request_model", requestModel)
+		if token.Models != nil && *token.Models != "" {
+			c.Set("available_models", *token.Models)
+			if requestModel != "" && !isModelInList(requestModel, *token.Models) {
+				abortWithMessage(c, http.StatusForbidden, fmt.Sprintf("该令牌无权使用模型:%s", requestModel))
+				return
+			}
+		}
 		c.Set("id", token.UserId)
 		c.Set("token_id", token.Id)
 		c.Set("token_name", token.Name)

@@ -2,14 +2,12 @@ package middleware

 import (
 	"fmt"
+	"github.com/gin-gonic/gin"
 	"github.com/songquanpeng/one-api/common"
 	"github.com/songquanpeng/one-api/common/logger"
 	"github.com/songquanpeng/one-api/model"
 	"net/http"
 	"strconv"
-	"strings"
-
-	"github.com/gin-gonic/gin"
 )

 type ModelRequest struct {
@@ -40,37 +38,11 @@ func Distribute() func(c *gin.Context) {
 				return
 			}
 		} else {
-			// Select a channel for the user
-			var modelRequest ModelRequest
-			err := common.UnmarshalBodyReusable(c, &modelRequest)
+			requestModel := c.GetString("request_model")
+			var err error
+			channel, err = model.CacheGetRandomSatisfiedChannel(userGroup, requestModel, false)
 			if err != nil {
-				abortWithMessage(c, http.StatusBadRequest, "无效的请求")
-				return
-			}
-			if strings.HasPrefix(c.Request.URL.Path, "/v1/moderations") {
-				if modelRequest.Model == "" {
-					modelRequest.Model = "text-moderation-stable"
-				}
-			}
-			if strings.HasSuffix(c.Request.URL.Path, "embeddings") {
-				if modelRequest.Model == "" {
-					modelRequest.Model = c.Param("model")
-				}
-			}
-			if strings.HasPrefix(c.Request.URL.Path, "/v1/images/generations") {
-				if modelRequest.Model == "" {
-					modelRequest.Model = "dall-e-2"
-				}
-			}
-			if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") || strings.HasPrefix(c.Request.URL.Path, "/v1/audio/translations") {
-				if modelRequest.Model == "" {
-					modelRequest.Model = "whisper-1"
-				}
-			}
-			requestModel = modelRequest.Model
-			channel, err = model.CacheGetRandomSatisfiedChannel(userGroup, modelRequest.Model, false)
-			if err != nil {
-				message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", userGroup, modelRequest.Model)
+				message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", userGroup, requestModel)
 				if channel != nil {
 					logger.SysError(fmt.Sprintf("渠道不存在:%d", channel.Id))
 					message = "数据库一致性已被破坏,请联系管理员"

@@ -1,9 +1,12 @@
 package middleware

 import (
+	"fmt"
 	"github.com/gin-gonic/gin"
+	"github.com/songquanpeng/one-api/common"
 	"github.com/songquanpeng/one-api/common/helper"
 	"github.com/songquanpeng/one-api/common/logger"
+	"strings"
 )

 func abortWithMessage(c *gin.Context, statusCode int, message string) {
@@ -16,3 +19,42 @@ func abortWithMessage(c *gin.Context, statusCode int, message string) {
 	c.Abort()
 	logger.Error(c.Request.Context(), message)
 }
+
+func getRequestModel(c *gin.Context) (string, error) {
+	var modelRequest ModelRequest
+	err := common.UnmarshalBodyReusable(c, &modelRequest)
+	if err != nil {
+		return "", fmt.Errorf("common.UnmarshalBodyReusable failed: %w", err)
+	}
+	if strings.HasPrefix(c.Request.URL.Path, "/v1/moderations") {
+		if modelRequest.Model == "" {
+			modelRequest.Model = "text-moderation-stable"
+		}
+	}
+	if strings.HasSuffix(c.Request.URL.Path, "embeddings") {
+		if modelRequest.Model == "" {
+			modelRequest.Model = c.Param("model")
+		}
+	}
+	if strings.HasPrefix(c.Request.URL.Path, "/v1/images/generations") {
+		if modelRequest.Model == "" {
+			modelRequest.Model = "dall-e-2"
+		}
+	}
+	if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") || strings.HasPrefix(c.Request.URL.Path, "/v1/audio/translations") {
+		if modelRequest.Model == "" {
+			modelRequest.Model = "whisper-1"
+		}
+	}
+	return modelRequest.Model, nil
+}
+
+func isModelInList(modelName string, models string) bool {
+	modelList := strings.Split(models, ",")
+	for _, model := range modelList {
+		if modelName == model {
+			return true
+		}
+	}
+	return false
+}
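
getRequestModel centralizes the per-endpoint default-model guessing that Distribute used to inline, and isModelInList is an exact-match membership test over a comma-separated list. A standalone check of the membership helper, body copied from the diff:

package main

import (
	"fmt"
	"strings"
)

// isModelInList, as added to the middleware package in this diff.
func isModelInList(modelName string, models string) bool {
	modelList := strings.Split(models, ",")
	for _, model := range modelList {
		if modelName == model {
			return true
		}
	}
	return false
}

func main() {
	allowed := "gpt-3.5-turbo,gpt-4"
	fmt.Println(isModelInList("gpt-4", allowed))  // true
	fmt.Println(isModelInList("gpt-4o", allowed)) // false: exact match only, no prefixes
	fmt.Println(isModelInList("gpt-4 ", allowed)) // false: no whitespace trimming
}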

@@ -1,8 +1,10 @@
 package model

 import (
+	"context"
 	"github.com/songquanpeng/one-api/common"
 	"gorm.io/gorm"
+	"sort"
 	"strings"
 )
@@ -88,3 +90,19 @@ func (channel *Channel) UpdateAbilities() error {
 func UpdateAbilityStatus(channelId int, status bool) error {
 	return DB.Model(&Ability{}).Where("channel_id = ?", channelId).Select("enabled").Update("enabled", status).Error
 }
+
+func GetGroupModels(ctx context.Context, group string) ([]string, error) {
+	groupCol := "`group`"
+	trueVal := "1"
+	if common.UsingPostgreSQL {
+		groupCol = `"group"`
+		trueVal = "true"
+	}
+	var models []string
+	err := DB.Model(&Ability{}).Distinct("model").Where(groupCol+" = ? and enabled = "+trueVal, group).Pluck("model", &models).Error
+	if err != nil {
+		return nil, err
+	}
+	sort.Strings(models)
+	return models, err
+}
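
GetGroupModels has to quote the group column because group is a reserved word in SQL, and the boolean literal differs between backends. A small sketch of the WHERE fragment each branch builds (illustration only, extracted from the function above):

package main

import "fmt"

// Illustration only: the WHERE fragment GetGroupModels builds for each
// backend, since "group" is a reserved word and boolean literals differ.
func whereClause(usingPostgreSQL bool) string {
	groupCol := "`group`"
	trueVal := "1"
	if usingPostgreSQL {
		groupCol = `"group"`
		trueVal = "true"
	}
	return groupCol + " = ? and enabled = " + trueVal
}

func main() {
	fmt.Println(whereClause(false)) // `group` = ? and enabled = 1    (MySQL-style quoting)
	fmt.Println(whereClause(true))  // "group" = ? and enabled = true (PostgreSQL)
}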

@@ -21,6 +21,7 @@ var (
 	UserId2GroupCacheSeconds  = config.SyncFrequency
 	UserId2QuotaCacheSeconds  = config.SyncFrequency
 	UserId2StatusCacheSeconds = config.SyncFrequency
+	GroupModelsCacheSeconds   = config.SyncFrequency
 )

 func CacheGetTokenByKey(key string) (*Token, error) {
@@ -146,6 +147,25 @@ func CacheIsUserEnabled(userId int) (bool, error) {
 	return userEnabled, err
 }

+func CacheGetGroupModels(ctx context.Context, group string) ([]string, error) {
+	if !common.RedisEnabled {
+		return GetGroupModels(ctx, group)
+	}
+	modelsStr, err := common.RedisGet(fmt.Sprintf("group_models:%s", group))
+	if err == nil {
+		return strings.Split(modelsStr, ","), nil
+	}
+	models, err := GetGroupModels(ctx, group)
+	if err != nil {
+		return nil, err
+	}
+	err = common.RedisSet(fmt.Sprintf("group_models:%s", group), strings.Join(models, ","), time.Duration(GroupModelsCacheSeconds)*time.Second)
+	if err != nil {
+		logger.SysError("Redis set group models error: " + err.Error())
+	}
+	return models, nil
+}
+
 var group2model2channels map[string]map[string][]*Channel
 var channelSyncLock sync.RWMutex

@@ -12,24 +12,25 @@ import (
 )

 type Token struct {
 	Id             int    `json:"id"`
 	UserId         int    `json:"user_id"`
 	Key            string `json:"key" gorm:"type:char(48);uniqueIndex"`
 	Status         int    `json:"status" gorm:"default:1"`
 	Name           string `json:"name" gorm:"index" `
 	CreatedTime    int64  `json:"created_time" gorm:"bigint"`
 	AccessedTime   int64  `json:"accessed_time" gorm:"bigint"`
 	ExpiredTime    int64  `json:"expired_time" gorm:"bigint;default:-1"` // -1 means never expired
 	RemainQuota    int64  `json:"remain_quota" gorm:"bigint;default:0"`
 	UnlimitedQuota bool   `json:"unlimited_quota" gorm:"default:false"`
 	UsedQuota      int64  `json:"used_quota" gorm:"bigint;default:0"` // used quota
+	Models         *string `json:"models" gorm:"default:''"`
 }

 func GetAllUserTokens(userId int, startIdx int, num int, order string) ([]*Token, error) {
 	var tokens []*Token
 	var err error
 	query := DB.Where("user_id = ?", userId)

 	switch order {
 	case "remain_quota":
 		query = query.Order("unlimited_quota desc, remain_quota desc")
@@ -38,7 +39,7 @@ func GetAllUserTokens(userId int, startIdx int, num int, order string) ([]*Token
 	default:
 		query = query.Order("id desc")
 	}

 	err = query.Limit(num).Offset(startIdx).Find(&tokens).Error
 	return tokens, err
 }
@@ -121,7 +122,7 @@ func (token *Token) Insert() error {
 // Update Make sure your token's fields is completed, because this will update non-zero values
 func (token *Token) Update() error {
 	var err error
-	err = DB.Model(token).Select("name", "status", "expired_time", "remain_quota", "unlimited_quota").Updates(token).Error
+	err = DB.Model(token).Select("name", "status", "expired_time", "remain_quota", "unlimited_quota", "models").Updates(token).Error
 	return err
 }

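Token.Update writes through GORM's Select(...).Updates(...), which updates exactly the named columns, zero values included; forgetting to add "models" here would silently drop edits to the new field. A runnable sketch of that behavior against an in-memory SQLite database (local Token stand-in; gorm.io/driver/sqlite is assumed available):

package main

import (
	"fmt"

	"gorm.io/driver/sqlite"
	"gorm.io/gorm"
)

// Trimmed stand-in for the real Token model.
type Token struct {
	Id     int
	Name   string
	Models *string
}

func main() {
	db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{})
	if err != nil {
		panic(err)
	}
	_ = db.AutoMigrate(&Token{})
	m := "gpt-4"
	db.Create(&Token{Id: 1, Name: "t", Models: &m})

	// Select(...).Updates writes exactly the named columns, zero values
	// included, which is why "models" had to be added to the list for
	// token edits to persist (and clear) the field.
	empty := ""
	db.Model(&Token{Id: 1}).Select("name", "models").
		Updates(Token{Name: "t2", Models: &empty})

	var got Token
	db.First(&got, 1)
	fmt.Println(got.Name, *got.Models == "") // t2 true
}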

@@ -48,6 +48,9 @@ func ConvertRequest(request model.GeneralOpenAIRequest) *ChatRequest {
 			MaxTokens:    request.MaxTokens,
 			Temperature:  request.Temperature,
 			TopP:         request.TopP,
+			TopK:         request.TopK,
+			ResultFormat: "message",
+			Tools:        request.Tools,
 		},
 	}
 }
@@ -117,19 +120,11 @@ func embeddingResponseAli2OpenAI(response *EmbeddingResponse) *openai.EmbeddingR
 }

 func responseAli2OpenAI(response *ChatResponse) *openai.TextResponse {
-	choice := openai.TextResponseChoice{
-		Index: 0,
-		Message: model.Message{
-			Role:    "assistant",
-			Content: response.Output.Text,
-		},
-		FinishReason: response.Output.FinishReason,
-	}
 	fullTextResponse := openai.TextResponse{
 		Id:      response.RequestId,
 		Object:  "chat.completion",
 		Created: helper.GetTimestamp(),
-		Choices: []openai.TextResponseChoice{choice},
+		Choices: response.Output.Choices,
 		Usage: model.Usage{
 			PromptTokens:     response.Usage.InputTokens,
 			CompletionTokens: response.Usage.OutputTokens,
@@ -140,10 +135,14 @@ func responseAli2OpenAI(response *ChatResponse) *openai.TextResponse {
 }

 func streamResponseAli2OpenAI(aliResponse *ChatResponse) *openai.ChatCompletionsStreamResponse {
+	if len(aliResponse.Output.Choices) == 0 {
+		return nil
+	}
+	aliChoice := aliResponse.Output.Choices[0]
 	var choice openai.ChatCompletionsStreamResponseChoice
-	choice.Delta.Content = aliResponse.Output.Text
-	if aliResponse.Output.FinishReason != "null" {
-		finishReason := aliResponse.Output.FinishReason
+	choice.Delta = aliChoice.Message
+	if aliChoice.FinishReason != "null" {
+		finishReason := aliChoice.FinishReason
 		choice.FinishReason = &finishReason
 	}
 	response := openai.ChatCompletionsStreamResponse{
@@ -204,6 +203,9 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC
 				usage.TotalTokens = aliResponse.Usage.InputTokens + aliResponse.Usage.OutputTokens
 			}
 			response := streamResponseAli2OpenAI(&aliResponse)
+			if response == nil {
+				return true
+			}
 			//response.Choices[0].Delta.Content = strings.TrimPrefix(response.Choices[0].Delta.Content, lastResponseText)
 			//lastResponseText = aliResponse.Output.Text
 			jsonResponse, err := json.Marshal(response)
@@ -226,6 +228,7 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC
 }

 func Handler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
+	ctx := c.Request.Context()
 	var aliResponse ChatResponse
 	responseBody, err := io.ReadAll(resp.Body)
 	if err != nil {
@@ -235,6 +238,7 @@ func Handler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *
 	if err != nil {
 		return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
 	}
+	logger.Debugf(ctx, "response body: %s\n", responseBody)
 	err = json.Unmarshal(responseBody, &aliResponse)
 	if err != nil {
 		return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil

@@ -1,5 +1,10 @@
 package ali

+import (
+	"github.com/songquanpeng/one-api/relay/channel/openai"
+	"github.com/songquanpeng/one-api/relay/model"
+)
+
 type Message struct {
 	Content string `json:"content"`
 	Role    string `json:"role"`
@@ -11,13 +16,15 @@ type Input struct {
 }

 type Parameters struct {
 	TopP              float64      `json:"top_p,omitempty"`
 	TopK              int          `json:"top_k,omitempty"`
 	Seed              uint64       `json:"seed,omitempty"`
 	EnableSearch      bool         `json:"enable_search,omitempty"`
 	IncrementalOutput bool         `json:"incremental_output,omitempty"`
 	MaxTokens         int          `json:"max_tokens,omitempty"`
 	Temperature       float64      `json:"temperature,omitempty"`
+	ResultFormat      string       `json:"result_format,omitempty"`
+	Tools             []model.Tool `json:"tools,omitempty"`
 }

 type ChatRequest struct {
@@ -62,8 +69,9 @@ type Usage struct {
 }

 type Output struct {
-	Text         string `json:"text"`
-	FinishReason string `json:"finish_reason"`
+	//Text         string `json:"text"`
+	//FinishReason string `json:"finish_reason"`
+	Choices []openai.TextResponseChoice `json:"choices"`
 }

 type ChatResponse struct {
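
With ResultFormat set to "message", the upstream response carries output.choices instead of the flat output.text, which is why Output now holds openai.TextResponseChoice values and the old fields are commented out. A sketch decoding a payload of that shape; the types are local stand-ins and the sample JSON is illustrative rather than captured from the API:

package main

import (
	"encoding/json"
	"fmt"
)

// Local stand-ins mirroring the new Output/Choices shape.
type Message struct {
	Role    string `json:"role"`
	Content any    `json:"content"`
}

type TextResponseChoice struct {
	Index        int     `json:"index"`
	Message      Message `json:"message"`
	FinishReason string  `json:"finish_reason"`
}

type Output struct {
	Choices []TextResponseChoice `json:"choices"`
}

func main() {
	// Illustrative payload in the result_format=message shape.
	body := []byte(`{"choices":[{"index":0,"message":{"role":"assistant","content":"hi"},"finish_reason":"stop"}]}`)
	var out Output
	if err := json.Unmarshal(body, &out); err != nil {
		panic(err)
	}
	fmt.Println(out.Choices[0].Message.Content) // hi
}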

@@ -38,6 +38,7 @@ func ConvertRequest(textRequest model.GeneralOpenAIRequest) *Request {
 		MaxTokens:   textRequest.MaxTokens,
 		Temperature: textRequest.Temperature,
 		TopP:        textRequest.TopP,
+		TopK:        textRequest.TopK,
 		Stream:      textRequest.Stream,
 	}
 	if claudeRequest.MaxTokens == 0 {

@@ -70,8 +70,10 @@ func (a *Adaptor) DoRequest(c *gin.Context, meta *util.RelayMeta, requestBody io
 func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.RelayMeta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
 	if meta.IsStream {
 		var responseText string
-		err, responseText, _ = StreamHandler(c, resp, meta.Mode)
-		usage = ResponseText2Usage(responseText, meta.ActualModelName, meta.PromptTokens)
+		err, responseText, usage = StreamHandler(c, resp, meta.Mode)
+		if usage == nil {
+			usage = ResponseText2Usage(responseText, meta.ActualModelName, meta.PromptTokens)
+		}
 	} else {
 		err, usage = Handler(c, resp, meta.PromptTokens, meta.ActualModelName)
 	}
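
DoResponse now prefers usage reported inside the stream (the streamResponse.Usage branch in StreamHandler below) and only falls back to estimating from the accumulated text when none arrived. The fallback pattern in isolation, with estimate standing in for ResponseText2Usage:

package main

import "fmt"

type Usage struct{ TotalTokens int }

// estimate stands in for ResponseText2Usage: a local token-count
// fallback used only when the upstream stream reported no usage.
func estimate(text string) *Usage { return &Usage{TotalTokens: len(text) / 4} }

func relay(streamUsage *Usage, responseText string) *Usage {
	usage := streamUsage
	if usage == nil {
		usage = estimate(responseText)
	}
	return usage
}

func main() {
	fmt.Println(relay(&Usage{TotalTokens: 42}, "ignored").TotalTokens) // 42, trusted from upstream
	fmt.Println(relay(nil, "abcdefgh").TotalTokens)                    // 2, estimated locally
}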

@@ -6,6 +6,7 @@ import (
 	"encoding/json"
 	"github.com/gin-gonic/gin"
 	"github.com/songquanpeng/one-api/common"
+	"github.com/songquanpeng/one-api/common/conv"
 	"github.com/songquanpeng/one-api/common/logger"
 	"github.com/songquanpeng/one-api/relay/constant"
 	"github.com/songquanpeng/one-api/relay/model"
@@ -53,7 +54,7 @@ func StreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*model.E
 					continue // just ignore the error
 				}
 				for _, choice := range streamResponse.Choices {
-					responseText += choice.Delta.Content
+					responseText += conv.AsString(choice.Delta.Content)
 				}
 				if streamResponse.Usage != nil {
 					usage = streamResponse.Usage

@@ -118,12 +118,9 @@ type ImageResponse struct {
 }

 type ChatCompletionsStreamResponseChoice struct {
 	Index int `json:"index"`
-	Delta struct {
-		Content string `json:"content"`
-		Role    string `json:"role,omitempty"`
-	} `json:"delta"`
-	FinishReason *string `json:"finish_reason,omitempty"`
+	Delta        model.Message `json:"delta"`
+	FinishReason *string       `json:"finish_reason,omitempty"`
 }

 type ChatCompletionsStreamResponse struct {
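
Reusing model.Message as the stream Delta makes Content an any (a string in text chunks, absent in tool-call chunks), which is exactly why the accumulation sites above switched to conv.AsString, and the new omitempty tags keep empty fields off the wire. A sketch with local stand-ins:

package main

import (
	"encoding/json"
	"fmt"
)

// Stand-ins for relay/model.Message and Tool with the new omitempty tags.
type Tool struct {
	Id   string `json:"id,omitempty"`
	Type string `json:"type"`
}

type Message struct {
	Role      string `json:"role,omitempty"`
	Content   any    `json:"content,omitempty"`
	ToolCalls []Tool `json:"tool_calls,omitempty"`
}

func main() {
	text, _ := json.Marshal(Message{Content: "hel"})
	tool, _ := json.Marshal(Message{ToolCalls: []Tool{{Id: "call_1", Type: "function"}}})
	fmt.Println(string(text)) // {"content":"hel"}
	fmt.Println(string(tool)) // {"tool_calls":[{"id":"call_1","type":"function"}]}
}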

@@ -10,6 +10,7 @@ import (
 	"fmt"
 	"github.com/gin-gonic/gin"
 	"github.com/songquanpeng/one-api/common"
+	"github.com/songquanpeng/one-api/common/conv"
 	"github.com/songquanpeng/one-api/common/helper"
 	"github.com/songquanpeng/one-api/common/logger"
 	"github.com/songquanpeng/one-api/relay/channel/openai"
@@ -129,7 +130,7 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC
 			}
 			response := streamResponseTencent2OpenAI(&TencentResponse)
 			if len(response.Choices) != 0 {
-				responseText += response.Choices[0].Delta.Content
+				responseText += conv.AsString(response.Choices[0].Delta.Content)
 			}
 			jsonResponse, err := json.Marshal(response)
 			if err != nil {

@@ -26,7 +26,11 @@ import (

 func requestOpenAI2Xunfei(request model.GeneralOpenAIRequest, xunfeiAppId string, domain string) *ChatRequest {
 	messages := make([]Message, 0, len(request.Messages))
+	var lastToolCalls []model.Tool
 	for _, message := range request.Messages {
+		if message.ToolCalls != nil {
+			lastToolCalls = message.ToolCalls
+		}
 		messages = append(messages, Message{
 			Role:    message.Role,
 			Content: message.StringContent(),
@@ -39,9 +43,33 @@ func requestOpenAI2Xunfei(request model.GeneralOpenAIRequest, xunfeiAppId string
 	xunfeiRequest.Parameter.Chat.TopK = request.N
 	xunfeiRequest.Parameter.Chat.MaxTokens = request.MaxTokens
 	xunfeiRequest.Payload.Message.Text = messages
+	if len(lastToolCalls) != 0 {
+		for _, toolCall := range lastToolCalls {
+			xunfeiRequest.Payload.Functions.Text = append(xunfeiRequest.Payload.Functions.Text, toolCall.Function)
+		}
+	}
+
 	return &xunfeiRequest
 }

+func getToolCalls(response *ChatResponse) []model.Tool {
+	var toolCalls []model.Tool
+	if len(response.Payload.Choices.Text) == 0 {
+		return toolCalls
+	}
+	item := response.Payload.Choices.Text[0]
+	if item.FunctionCall == nil {
+		return toolCalls
+	}
+	toolCall := model.Tool{
+		Id:       fmt.Sprintf("call_%s", helper.GetUUID()),
+		Type:     "function",
+		Function: *item.FunctionCall,
+	}
+	toolCalls = append(toolCalls, toolCall)
+	return toolCalls
+}
+
 func responseXunfei2OpenAI(response *ChatResponse) *openai.TextResponse {
 	if len(response.Payload.Choices.Text) == 0 {
 		response.Payload.Choices.Text = []ChatResponseTextItem{
@@ -53,8 +81,9 @@ func responseXunfei2OpenAI(response *ChatResponse) *openai.TextResponse {
 	choice := openai.TextResponseChoice{
 		Index: 0,
 		Message: model.Message{
 			Role:      "assistant",
 			Content:   response.Payload.Choices.Text[0].Content,
+			ToolCalls: getToolCalls(response),
 		},
 		FinishReason: constant.StopFinishReason,
 	}
@@ -78,6 +107,7 @@ func streamResponseXunfei2OpenAI(xunfeiResponse *ChatResponse) *openai.ChatCompl
 	}
 	var choice openai.ChatCompletionsStreamResponseChoice
 	choice.Delta.Content = xunfeiResponse.Payload.Choices.Text[0].Content
+	choice.Delta.ToolCalls = getToolCalls(xunfeiResponse)
 	if xunfeiResponse.Payload.Choices.Status == 2 {
 		choice.FinishReason = &constant.StopFinishReason
 	}

@@ -26,13 +26,18 @@ type ChatRequest struct {
 		Message struct {
 			Text []Message `json:"text"`
 		} `json:"message"`
+		Functions struct {
+			Text []model.Function `json:"text,omitempty"`
+		} `json:"functions"`
 	} `json:"payload"`
 }

 type ChatResponseTextItem struct {
 	Content      string          `json:"content"`
 	Role         string          `json:"role"`
 	Index        int             `json:"index"`
+	ContentType  string          `json:"content_type"`
+	FunctionCall *model.Function `json:"function_call"`
 }

 type ChatResponse struct {

@@ -5,25 +5,29 @@ type ResponseFormat struct {
 }

 type GeneralOpenAIRequest struct {
-	Model            string          `json:"model,omitempty"`
 	Messages         []Message       `json:"messages,omitempty"`
-	Prompt           any             `json:"prompt,omitempty"`
-	Stream           bool            `json:"stream,omitempty"`
-	MaxTokens        int             `json:"max_tokens,omitempty"`
-	Temperature      float64         `json:"temperature,omitempty"`
-	TopP             float64         `json:"top_p,omitempty"`
-	N                int             `json:"n,omitempty"`
-	Input            any             `json:"input,omitempty"`
-	Instruction      string          `json:"instruction,omitempty"`
-	Size             string          `json:"size,omitempty"`
-	Functions        any             `json:"functions,omitempty"`
+	Model            string          `json:"model,omitempty"`
 	FrequencyPenalty float64         `json:"frequency_penalty,omitempty"`
+	MaxTokens        int             `json:"max_tokens,omitempty"`
+	N                int             `json:"n,omitempty"`
 	PresencePenalty  float64         `json:"presence_penalty,omitempty"`
 	ResponseFormat   *ResponseFormat `json:"response_format,omitempty"`
 	Seed             float64         `json:"seed,omitempty"`
-	Tools            any             `json:"tools,omitempty"`
+	Stream           bool            `json:"stream,omitempty"`
+	Temperature      float64         `json:"temperature,omitempty"`
+	TopP             float64         `json:"top_p,omitempty"`
+	TopK             int             `json:"top_k,omitempty"`
+	Tools            []Tool          `json:"tools,omitempty"`
 	ToolChoice       any             `json:"tool_choice,omitempty"`
+	FunctionCall     any             `json:"function_call,omitempty"`
+	Functions        any             `json:"functions,omitempty"`
 	User             string          `json:"user,omitempty"`
+	Prompt           any             `json:"prompt,omitempty"`
+	Input            any             `json:"input,omitempty"`
+	EncodingFormat   string          `json:"encoding_format,omitempty"`
+	Dimensions       int             `json:"dimensions,omitempty"`
+	Instruction      string          `json:"instruction,omitempty"`
+	Size             string          `json:"size,omitempty"`
 }

 func (r GeneralOpenAIRequest) ParseInput() []string {

@@ -1,9 +1,10 @@
 package model

 type Message struct {
-	Role    string  `json:"role"`
-	Content any     `json:"content"`
+	Role      string  `json:"role,omitempty"`
+	Content   any     `json:"content,omitempty"`
 	Name      *string `json:"name,omitempty"`
+	ToolCalls []Tool  `json:"tool_calls,omitempty"`
 }

 func (m Message) IsStringContent() bool {

relay/model/tool.go (new file, +14)
@@ -0,0 +1,14 @@
+package model
+
+type Tool struct {
+	Id       string   `json:"id,omitempty"`
+	Type     string   `json:"type"`
+	Function Function `json:"function"`
+}
+
+type Function struct {
+	Description string `json:"description,omitempty"`
+	Name        string `json:"name"`
+	Parameters  any    `json:"parameters,omitempty"` // request
+	Arguments   any    `json:"arguments,omitempty"`  // response
+}
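
One Function type serves both directions: Parameters carries the schema when a tool is declared in a request, Arguments carries the call arguments when the model responds, and each side omits the other. A demo with the new types copied locally:

package main

import (
	"encoding/json"
	"fmt"
)

// Copies of the new relay/model types for a standalone demo.
type Function struct {
	Description string `json:"description,omitempty"`
	Name        string `json:"name"`
	Parameters  any    `json:"parameters,omitempty"` // request
	Arguments   any    `json:"arguments,omitempty"`  // response
}

type Tool struct {
	Id       string   `json:"id,omitempty"`
	Type     string   `json:"type"`
	Function Function `json:"function"`
}

func main() {
	// Request side: declare the tool with a parameter schema.
	req := Tool{Type: "function", Function: Function{
		Name:       "get_weather",
		Parameters: map[string]any{"type": "object"},
	}}
	// Response side: the model's call carries arguments instead.
	resp := Tool{Id: "call_abc", Type: "function", Function: Function{
		Name:      "get_weather",
		Arguments: `{"city":"Beijing"}`,
	}}
	a, _ := json.Marshal(req)
	b, _ := json.Marshal(resp)
	fmt.Println(string(a))
	fmt.Println(string(b))
}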

@@ -43,6 +43,7 @@ func SetApiRouter(router *gin.Engine) {
 			selfRoute.GET("/token", controller.GenerateAccessToken)
 			selfRoute.GET("/aff", controller.GetAffCode)
 			selfRoute.POST("/topup", controller.TopUp)
+			selfRoute.GET("/available_models", controller.GetUserAvailableModels)
 		}

 		adminRoute := userRoute.Group("/")
@@ -1,19 +1,21 @@
 import React, { useEffect, useState } from 'react';
 import { Button, Form, Header, Message, Segment } from 'semantic-ui-react';
-import { useParams, useNavigate } from 'react-router-dom';
-import { API, showError, showSuccess, timestamp2string } from '../../helpers';
-import { renderQuota, renderQuotaWithPrompt } from '../../helpers/render';
+import { useNavigate, useParams } from 'react-router-dom';
+import { API, copy, showError, showSuccess, timestamp2string } from '../../helpers';
+import { renderQuotaWithPrompt } from '../../helpers/render';

 const EditToken = () => {
   const params = useParams();
   const tokenId = params.id;
   const isEdit = tokenId !== undefined;
   const [loading, setLoading] = useState(isEdit);
+  const [modelOptions, setModelOptions] = useState([]);
   const originInputs = {
     name: '',
     remain_quota: isEdit ? 0 : 500000,
     expired_time: -1,
-    unlimited_quota: false
+    unlimited_quota: false,
+    models: []
   };
   const [inputs, setInputs] = useState(originInputs);
   const { name, remain_quota, expired_time, unlimited_quota } = inputs;
@@ -22,8 +24,8 @@ const EditToken = () => {
     setInputs((inputs) => ({ ...inputs, [name]: value }));
   };
   const handleCancel = () => {
-    navigate("/token");
-  }
+    navigate('/token');
+  };
   const setExpiredTime = (month, day, hour, minute) => {
     let now = new Date();
     let timestamp = now.getTime() / 1000;
@@ -50,6 +52,11 @@ const EditToken = () => {
       if (data.expired_time !== -1) {
         data.expired_time = timestamp2string(data.expired_time);
       }
+      if (data.models === '') {
+        data.models = [];
+      } else {
+        data.models = data.models.split(',');
+      }
       setInputs(data);
     } else {
       showError(message);
@@ -60,8 +67,26 @@ const EditToken = () => {
     if (isEdit) {
       loadToken().then();
     }
+    loadAvailableModels().then();
   }, []);

+  const loadAvailableModels = async () => {
+    let res = await API.get(`/api/user/available_models`);
+    const { success, message, data } = res.data;
+    if (success) {
+      let options = data.map((model) => {
+        return {
+          key: model,
+          text: model,
+          value: model
+        };
+      });
+      setModelOptions(options);
+    } else {
+      showError(message);
+    }
+  };
+
   const submit = async () => {
     if (!isEdit && inputs.name === '') return;
     let localInputs = inputs;
@@ -74,6 +99,7 @@ const EditToken = () => {
       }
       localInputs.expired_time = Math.ceil(time / 1000);
     }
+    localInputs.models = localInputs.models.join(',');
     let res;
     if (isEdit) {
       res = await API.put(`/api/token/`, { ...localInputs, id: parseInt(tokenId) });
@@ -109,6 +135,24 @@ const EditToken = () => {
           required={!isEdit}
         />
       </Form.Field>
+      <Form.Field>
+        <Form.Dropdown
+          label='模型范围'
+          placeholder={'请选择允许使用的模型,留空则不进行限制'}
+          name='models'
+          fluid
+          multiple
+          search
+          onLabelClick={(e, { value }) => {
+            copy(value).then();
+          }}
+          selection
+          onChange={handleInputChange}
+          value={inputs.models}
+          autoComplete='new-password'
+          options={modelOptions}
+        />
+      </Form.Field>
       <Form.Field>
         <Form.Input
           label='过期时间'