Mirror of https://github.com/songquanpeng/one-api.git, synced 2025-10-31 13:53:41 +08:00.
Compare commits: v0.6.4-alp...v0.6.5-alp (21 commits)

Commits: f8cc63f00b, 0a37aa4cbd, 054b00b725, 76569bb0b6, 1994256bac, 1f80b0a39f, f73f2e51df, 6f036bd0c9, fb90747c23, ed70881a58, 8b9fa3d6e4, 8b9813d63b, dc7aaf2de5, 065da8ef8c, e3cfb1fa52, f89ae5ad58, 06a3fc5421, a9c464ec5a, 3f3c13c98c, 2ba28c72cb, 5e81e19bc8
```diff
@@ -109,6 +109,7 @@ _✨ Access all major models through the standard OpenAI API format, ready to use out of the box ✨_
     + WeChat Official Account authorization (requires additionally deploying [WeChat Server](https://github.com/songquanpeng/wechat-server)).
 23. Supports theme switching via the `THEME` environment variable; the default is `default`. PRs for more themes are welcome, see [here](./web/README.md) for details.
 24. Works with [Message Pusher](https://github.com/songquanpeng/message-pusher) to push alert messages to a variety of apps.
+25. Supports **extensions**; see the [API documentation](./docs/API.md) for details.
 
 ## Deployment
 ### Deploying with Docker
```
							
								
								
									
common/conv/any.go (new file, +6)
```diff
@@ -0,0 +1,6 @@
+package conv
+
+func AsString(v any) string {
+	str, _ := v.(string)
+	return str
+}
```
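This tiny helper underpins a breaking type change later in this compare: `ChatCompletionsStreamResponseChoice.Delta` becomes a full `model.Message`, whose `Content` is an `any` rather than a plain string, so stream handlers need a panic-free way to read it as text. A minimal standalone sketch of the helper's semantics:

```go
package main

import "fmt"

// AsString mirrors common/conv.AsString: a comma-ok type assertion that
// yields "" instead of panicking when v is not a string.
func AsString(v any) string {
	str, _ := v.(string)
	return str
}

func main() {
	fmt.Println(AsString("hello")) // "hello"
	fmt.Println(AsString(42))      // "" (non-string -> zero value)
	fmt.Println(AsString(nil))     // ""
}
```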
```diff
@@ -72,14 +72,22 @@ var ModelRatio = map[string]float64{
 	"claude-3-sonnet-20240229": 3.0 / 1000 * USD,
 	"claude-3-opus-20240229":   15.0 / 1000 * USD,
 	// https://cloud.baidu.com/doc/WENXINWORKSHOP/s/hlrk4akp7
-	"ERNIE-Bot":       0.8572,     // ¥0.012 / 1k tokens
-	"ERNIE-Bot-turbo": 0.5715,     // ¥0.008 / 1k tokens
-	"ERNIE-Bot-4":     0.12 * RMB, // ¥0.12 / 1k tokens
-	"ERNIE-Bot-8k":    0.024 * RMB,
-	"Embedding-V1":    0.1429, // ¥0.002 / 1k tokens
-	"bge-large-zh":    0.002 * RMB,
-	"bge-large-en":    0.002 * RMB,
-	"bge-large-8k":    0.002 * RMB,
+	"ERNIE-4.0-8K":       0.120 * RMB,
+	"ERNIE-Bot-8K-0922":  0.024 * RMB,
+	"ERNIE-3.5-8K":       0.012 * RMB,
+	"ERNIE-Lite-8K-0922": 0.008 * RMB,
+	"ERNIE-Speed-8K":     0.004 * RMB,
+	"ERNIE-3.5-4K-0205":  0.012 * RMB,
+	"ERNIE-3.5-8K-0205":  0.024 * RMB,
+	"ERNIE-3.5-8K-1222":  0.012 * RMB,
+	"ERNIE-Lite-8K":      0.003 * RMB,
+	"ERNIE-Speed-128K":   0.004 * RMB,
+	"ERNIE-Tiny-8K":      0.001 * RMB,
+	"BLOOMZ-7B":          0.004 * RMB,
+	"Embedding-V1":       0.002 * RMB,
+	"bge-large-zh":       0.002 * RMB,
+	"bge-large-en":       0.002 * RMB,
+	"tao-8k":             0.002 * RMB,
 	// https://ai.google.dev/pricing
 	"PaLM-2":                    1,
 	"gemini-pro":                1, // $0.00025 / 1k characters -> $0.001 / 1k tokens
@@ -91,6 +99,7 @@ var ModelRatio = map[string]float64{
 	"glm-4":                     0.1 * RMB,
 	"glm-4v":                    0.1 * RMB,
 	"glm-3-turbo":               0.005 * RMB,
+	"embedding-2":               0.0005 * RMB,
 	"chatglm_turbo":             0.3572, // ¥0.005 / 1k tokens
 	"chatglm_pro":               0.7143, // ¥0.01 / 1k tokens
 	"chatglm_std":               0.3572, // ¥0.005 / 1k tokens
```
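For orientation: judging from the inline comments on the removed lines (e.g. `"ERNIE-Bot-4": 0.12 * RMB, // ¥0.12 / 1k tokens`), a ratio is a per-1k-token price multiplied by a currency constant, so the retired hard-coded values are just the `price * RMB` form written out. A quick check, assuming `USD = 500` and `RMB = USD / 7` (values not shown in this hunk):

```go
package main

import "fmt"

// Assumed quota constants, inferred from the comments in the hunk above;
// treat them as an illustration, not the repo's definitive values.
const (
	USD = 500.0     // assumption: $1 / 1k tokens corresponds to ratio 500
	RMB = USD / 7.0 // assumption: the 7:1 USD/RMB conversion used in the comments
)

func main() {
	fmt.Printf("%.4f\n", 0.012*RMB) // ≈ 0.8571, old "ERNIE-Bot": 0.8572 (¥0.012 / 1k tokens)
	fmt.Printf("%.4f\n", 0.008*RMB) // ≈ 0.5714, old "ERNIE-Bot-turbo": 0.5715 (¥0.008 / 1k tokens)
	fmt.Printf("%.4f\n", 0.002*RMB) // ≈ 0.1429, old "Embedding-V1": 0.1429 (¥0.002 / 1k tokens)
}
```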
```diff
@@ -4,12 +4,14 @@ import (
 	"fmt"
 	"github.com/gin-gonic/gin"
 	"github.com/songquanpeng/one-api/common"
+	"github.com/songquanpeng/one-api/model"
 	"github.com/songquanpeng/one-api/relay/channel/openai"
 	"github.com/songquanpeng/one-api/relay/constant"
 	"github.com/songquanpeng/one-api/relay/helper"
 	relaymodel "github.com/songquanpeng/one-api/relay/model"
 	"github.com/songquanpeng/one-api/relay/util"
 	"net/http"
+	"strings"
 )
 
 // https://platform.openai.com/docs/api-reference/models/list
@@ -120,9 +122,41 @@ func DashboardListModels(c *gin.Context) {
 }
 
 func ListModels(c *gin.Context) {
+	ctx := c.Request.Context()
+	var availableModels []string
+	if c.GetString("available_models") != "" {
+		availableModels = strings.Split(c.GetString("available_models"), ",")
+	} else {
+		userId := c.GetInt("id")
+		userGroup, _ := model.CacheGetUserGroup(userId)
+		availableModels, _ = model.CacheGetGroupModels(ctx, userGroup)
+	}
+	modelSet := make(map[string]bool)
+	for _, availableModel := range availableModels {
+		modelSet[availableModel] = true
+	}
+	availableOpenAIModels := make([]OpenAIModels, 0)
+	for _, model := range openAIModels {
+		if _, ok := modelSet[model.Id]; ok {
+			modelSet[model.Id] = false
+			availableOpenAIModels = append(availableOpenAIModels, model)
+		}
+	}
+	for modelName, ok := range modelSet {
+		if ok {
+			availableOpenAIModels = append(availableOpenAIModels, OpenAIModels{
+				Id:      modelName,
+				Object:  "model",
+				Created: 1626777600,
+				OwnedBy: "custom",
+				Root:    modelName,
+				Parent:  nil,
+			})
+		}
+	}
 	c.JSON(200, gin.H{
 		"object": "list",
-		"data":   openAIModels,
+		"data":   availableOpenAIModels,
 	})
 }
@@ -142,3 +176,30 @@ func RetrieveModel(c *gin.Context) {
 		})
 	}
 }
+
+func GetUserAvailableModels(c *gin.Context) {
+	ctx := c.Request.Context()
+	id := c.GetInt("id")
+	userGroup, err := model.CacheGetUserGroup(id)
+	if err != nil {
+		c.JSON(http.StatusOK, gin.H{
+			"success": false,
+			"message": err.Error(),
+		})
+		return
+	}
+	models, err := model.CacheGetGroupModels(ctx, userGroup)
+	if err != nil {
+		c.JSON(http.StatusOK, gin.H{
+			"success": false,
+			"message": err.Error(),
+		})
+		return
+	}
+	c.JSON(http.StatusOK, gin.H{
+		"success": true,
+		"message": "",
+		"data":    models,
+	})
+	return
+}
```
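The rewritten `ListModels` performs a two-pass merge: allow-listed names the relay already knows keep their full metadata, and anything left over is emitted as a bare `custom` stub. A standalone sketch of that merge, with simplified stand-in types rather than the controller's actual structs:

```go
package main

import "fmt"

type OpenAIModel struct {
	Id      string
	OwnedBy string
}

// mergeAvailable mirrors the set logic above: known models keep their
// metadata; leftover allow-listed names become "custom" stubs.
func mergeAvailable(known []OpenAIModel, available []string) []OpenAIModel {
	modelSet := make(map[string]bool)
	for _, name := range available {
		modelSet[name] = true
	}
	result := make([]OpenAIModel, 0)
	for _, m := range known {
		if modelSet[m.Id] {
			modelSet[m.Id] = false // mark as already emitted
			result = append(result, m)
		}
	}
	for name, pending := range modelSet {
		if pending { // allow-listed but unknown to the relay
			result = append(result, OpenAIModel{Id: name, OwnedBy: "custom"})
		}
	}
	return result
}

func main() {
	known := []OpenAIModel{{"gpt-3.5-turbo", "openai"}, {"gpt-4", "openai"}}
	fmt.Println(mergeAvailable(known, []string{"gpt-4", "my-private-model"}))
}
```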
```diff
@@ -130,6 +130,7 @@ func AddToken(c *gin.Context) {
 		ExpiredTime:    token.ExpiredTime,
 		RemainQuota:    token.RemainQuota,
 		UnlimitedQuota: token.UnlimitedQuota,
+		Models:         token.Models,
 	}
 	err = cleanToken.Insert()
 	if err != nil {
@@ -216,6 +217,7 @@ func UpdateToken(c *gin.Context) {
 		cleanToken.ExpiredTime = token.ExpiredTime
 		cleanToken.RemainQuota = token.RemainQuota
 		cleanToken.UnlimitedQuota = token.UnlimitedQuota
+		cleanToken.Models = token.Models
 	}
 	err = cleanToken.Update()
 	if err != nil {
```
```diff
@@ -180,27 +180,27 @@ func Register(c *gin.Context) {
 }
 
 func GetAllUsers(c *gin.Context) {
-    p, _ := strconv.Atoi(c.Query("p"))
-    if p < 0 {
-        p = 0
-    }
-
-    order := c.DefaultQuery("order", "")
-    users, err := model.GetAllUsers(p*config.ItemsPerPage, config.ItemsPerPage, order)
-
-    if err != nil {
-        c.JSON(http.StatusOK, gin.H{
-            "success": false,
-            "message": err.Error(),
-        })
-        return
-    }
-
-    c.JSON(http.StatusOK, gin.H{
-        "success": true,
-        "message": "",
-        "data":    users,
-    })
+	p, _ := strconv.Atoi(c.Query("p"))
+	if p < 0 {
+		p = 0
+	}
+
+	order := c.DefaultQuery("order", "")
+	users, err := model.GetAllUsers(p*config.ItemsPerPage, config.ItemsPerPage, order)
+
+	if err != nil {
+		c.JSON(http.StatusOK, gin.H{
+			"success": false,
+			"message": err.Error(),
+		})
+		return
+	}
+
+	c.JSON(http.StatusOK, gin.H{
+		"success": true,
+		"message": "",
+		"data":    users,
+	})
 }
 
 func SearchUsers(c *gin.Context) {
@@ -770,3 +770,38 @@ func TopUp(c *gin.Context) {
 	})
 	return
 }
+
+type adminTopUpRequest struct {
+	UserId int    `json:"user_id"`
+	Quota  int    `json:"quota"`
+	Remark string `json:"remark"`
+}
+
+func AdminTopUp(c *gin.Context) {
+	req := adminTopUpRequest{}
+	err := c.ShouldBindJSON(&req)
+	if err != nil {
+		c.JSON(http.StatusOK, gin.H{
+			"success": false,
+			"message": err.Error(),
+		})
+		return
+	}
+	err = model.IncreaseUserQuota(req.UserId, int64(req.Quota))
+	if err != nil {
+		c.JSON(http.StatusOK, gin.H{
+			"success": false,
+			"message": err.Error(),
+		})
+		return
+	}
+	if req.Remark == "" {
+		req.Remark = fmt.Sprintf("通过 API 充值 %s", common.LogQuota(int64(req.Quota)))
+	}
+	model.RecordTopupLog(req.UserId, req.Remark, req.Quota)
+	c.JSON(http.StatusOK, gin.H{
+		"success": true,
+		"message": "",
+	})
+	return
+}
```
							
								
								
									
docs/API.md (new file, +44)
````diff
@@ -0,0 +1,44 @@
+# Controlling & Extending One API via the API
+> PRs adding your extension projects here are welcome.
+
+For example, although One API itself does not support payment directly, you can implement payment through the system's extension API.
+
+Or, if you want a custom channel-management policy, you can also enable and disable channels via the API.
+
+## Authentication
+One API supports two authentication methods: Cookie and Token. For a Token, obtain it as shown below:
+
+
+
+Then put the Token as the value of the Authorization header of your requests, for example when calling the channel-test API with a Token:
+
+
+## Request & Response Format
+One API uses JSON for both requests and responses.
+
+A response body generally has the following format:
+```json
+{
+  "message": "request message",
+  "success": true,
+  "data": {}
+}
+```
+
+## API List
+> The current API list is incomplete; please capture the frontend's requests in your browser yourself.
+
+If the existing APIs cannot meet your needs, feel free to open an issue for discussion.
+
+### Get the currently logged-in user's info
+**GET** `/api/user/self`
+
+### Top up quota for a given user
+**POST** `/api/topup`
+```json
+{
+  "user_id": 1,
+  "quota": 100000,
+  "remark": "top up 100000 quota"
+}
+```
````
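A minimal Go client for the top-up endpoint documented above. The base URL and access token are placeholders; per the doc, the token goes directly in the `Authorization` header:

```go
package main

import (
	"bytes"
	"fmt"
	"io"
	"net/http"
)

func main() {
	// Placeholders: point at your deployment and use a real access token
	// obtained from the console, as described in docs/API.md above.
	body := []byte(`{"user_id": 1, "quota": 100000, "remark": "top up 100000 quota"}`)
	req, err := http.NewRequest(http.MethodPost, "http://localhost:3000/api/topup", bytes.NewReader(body))
	if err != nil {
		panic(err)
	}
	req.Header.Set("Authorization", "<your-access-token>") // token as the header value, per the doc
	req.Header.Set("Content-Type", "application/json")
	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()
	data, _ := io.ReadAll(resp.Body)
	fmt.Println(resp.Status, string(data)) // expect {"success": true, ...}
}
```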
```diff
@@ -1,6 +1,7 @@
 package middleware
 
 import (
+	"fmt"
 	"github.com/gin-contrib/sessions"
 	"github.com/gin-gonic/gin"
 	"github.com/songquanpeng/one-api/common"
@@ -107,6 +108,19 @@ func TokenAuth() func(c *gin.Context) {
 			abortWithMessage(c, http.StatusForbidden, "用户已被封禁")
 			return
 		}
+		requestModel, err := getRequestModel(c)
+		if err != nil {
+			abortWithMessage(c, http.StatusBadRequest, err.Error())
+			return
+		}
+		c.Set("request_model", requestModel)
+		if token.Models != nil && *token.Models != "" {
+			c.Set("available_models", *token.Models)
+			if requestModel != "" && !isModelInList(requestModel, *token.Models) {
+				abortWithMessage(c, http.StatusForbidden, fmt.Sprintf("该令牌无权使用模型:%s", requestModel))
+				return
+			}
+		}
 		c.Set("id", token.UserId)
 		c.Set("token_id", token.Id)
 		c.Set("token_name", token.Name)
```
```diff
@@ -2,14 +2,12 @@ package middleware
 
 import (
 	"fmt"
+	"github.com/gin-gonic/gin"
 	"github.com/songquanpeng/one-api/common"
 	"github.com/songquanpeng/one-api/common/logger"
 	"github.com/songquanpeng/one-api/model"
 	"net/http"
 	"strconv"
-	"strings"
-
-	"github.com/gin-gonic/gin"
 )
 
 type ModelRequest struct {
@@ -40,37 +38,11 @@ func Distribute() func(c *gin.Context) {
 				return
 			}
 		} else {
-			// Select a channel for the user
-			var modelRequest ModelRequest
-			err := common.UnmarshalBodyReusable(c, &modelRequest)
-			if err != nil {
-				abortWithMessage(c, http.StatusBadRequest, "无效的请求")
-				return
-			}
-			if strings.HasPrefix(c.Request.URL.Path, "/v1/moderations") {
-				if modelRequest.Model == "" {
-					modelRequest.Model = "text-moderation-stable"
-				}
-			}
-			if strings.HasSuffix(c.Request.URL.Path, "embeddings") {
-				if modelRequest.Model == "" {
-					modelRequest.Model = c.Param("model")
-				}
-			}
-			if strings.HasPrefix(c.Request.URL.Path, "/v1/images/generations") {
-				if modelRequest.Model == "" {
-					modelRequest.Model = "dall-e-2"
-				}
-			}
-			if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") || strings.HasPrefix(c.Request.URL.Path, "/v1/audio/translations") {
-				if modelRequest.Model == "" {
-					modelRequest.Model = "whisper-1"
-				}
-			}
-			requestModel = modelRequest.Model
-			channel, err = model.CacheGetRandomSatisfiedChannel(userGroup, modelRequest.Model, false)
-			if err != nil {
-				message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", userGroup, modelRequest.Model)
+			requestModel := c.GetString("request_model")
+			var err error
+			channel, err = model.CacheGetRandomSatisfiedChannel(userGroup, requestModel, false)
+			if err != nil {
+				message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", userGroup, requestModel)
 				if channel != nil {
 					logger.SysError(fmt.Sprintf("渠道不存在:%d", channel.Id))
 					message = "数据库一致性已被破坏,请联系管理员"
```
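Moving body parsing out of Distribute makes middleware ordering load-bearing: Distribute can only read `request_model` if TokenAuth ran earlier in the chain and stored it. A hypothetical wiring sketch, not the repo's actual router file (group and handler names assumed):

```go
// gin executes middleware in registration order, so TokenAuth must precede
// Distribute for c.GetString("request_model") to be populated.
relayV1Router := router.Group("/v1")
relayV1Router.Use(middleware.TokenAuth(), middleware.Distribute())
relayV1Router.POST("/chat/completions", controller.Relay) // hypothetical route
```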
```diff
@@ -1,9 +1,12 @@
 package middleware
 
 import (
+	"fmt"
 	"github.com/gin-gonic/gin"
+	"github.com/songquanpeng/one-api/common"
 	"github.com/songquanpeng/one-api/common/helper"
 	"github.com/songquanpeng/one-api/common/logger"
+	"strings"
 )
 
 func abortWithMessage(c *gin.Context, statusCode int, message string) {
@@ -16,3 +19,42 @@ func abortWithMessage(c *gin.Context, statusCode int, message string) {
 	c.Abort()
 	logger.Error(c.Request.Context(), message)
 }
+
+func getRequestModel(c *gin.Context) (string, error) {
+	var modelRequest ModelRequest
+	err := common.UnmarshalBodyReusable(c, &modelRequest)
+	if err != nil {
+		return "", fmt.Errorf("common.UnmarshalBodyReusable failed: %w", err)
+	}
+	if strings.HasPrefix(c.Request.URL.Path, "/v1/moderations") {
+		if modelRequest.Model == "" {
+			modelRequest.Model = "text-moderation-stable"
+		}
+	}
+	if strings.HasSuffix(c.Request.URL.Path, "embeddings") {
+		if modelRequest.Model == "" {
+			modelRequest.Model = c.Param("model")
+		}
+	}
+	if strings.HasPrefix(c.Request.URL.Path, "/v1/images/generations") {
+		if modelRequest.Model == "" {
+			modelRequest.Model = "dall-e-2"
+		}
+	}
+	if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") || strings.HasPrefix(c.Request.URL.Path, "/v1/audio/translations") {
+		if modelRequest.Model == "" {
+			modelRequest.Model = "whisper-1"
+		}
+	}
+	return modelRequest.Model, nil
+}
+
+func isModelInList(modelName string, models string) bool {
+	modelList := strings.Split(models, ",")
+	for _, model := range modelList {
+		if modelName == model {
+			return true
+		}
+	}
+	return false
+}
```
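Worth noting about `isModelInList`: the allow-list is matched by exact string comparison after splitting on commas, so prefixes do not match and whitespace around commas is significant. A runnable check of those semantics:

```go
package main

import (
	"fmt"
	"strings"
)

// Same logic as middleware's isModelInList: exact match per comma-separated entry.
func isModelInList(modelName string, models string) bool {
	for _, m := range strings.Split(models, ",") {
		if modelName == m {
			return true
		}
	}
	return false
}

func main() {
	allow := "gpt-3.5-turbo,gpt-4"
	fmt.Println(isModelInList("gpt-4", allow))                  // true
	fmt.Println(isModelInList("gpt-4-turbo", allow))            // false: exact match only
	fmt.Println(isModelInList("gpt-4", "gpt-3.5-turbo, gpt-4")) // false: " gpt-4" != "gpt-4"
}
```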
```diff
@@ -1,7 +1,10 @@
 package model
 
 import (
+	"context"
 	"github.com/songquanpeng/one-api/common"
+	"gorm.io/gorm"
+	"sort"
 	"strings"
 )
 
@@ -13,7 +16,7 @@ type Ability struct {
 	Priority  *int64 `json:"priority" gorm:"bigint;default:0;index"`
 }
 
-func GetRandomSatisfiedChannel(group string, model string) (*Channel, error) {
+func GetRandomSatisfiedChannel(group string, model string, ignoreFirstPriority bool) (*Channel, error) {
 	ability := Ability{}
 	groupCol := "`group`"
 	trueVal := "1"
@@ -23,8 +26,13 @@ func GetRandomSatisfiedChannel(group string, model string) (*Channel, error) {
 	}
 
 	var err error = nil
-	maxPrioritySubQuery := DB.Model(&Ability{}).Select("MAX(priority)").Where(groupCol+" = ? and model = ? and enabled = "+trueVal, group, model)
-	channelQuery := DB.Where(groupCol+" = ? and model = ? and enabled = "+trueVal+" and priority = (?)", group, model, maxPrioritySubQuery)
+	var channelQuery *gorm.DB
+	if ignoreFirstPriority {
+		channelQuery = DB.Where(groupCol+" = ? and model = ? and enabled = "+trueVal, group, model)
+	} else {
+		maxPrioritySubQuery := DB.Model(&Ability{}).Select("MAX(priority)").Where(groupCol+" = ? and model = ? and enabled = "+trueVal, group, model)
+		channelQuery = DB.Where(groupCol+" = ? and model = ? and enabled = "+trueVal+" and priority = (?)", group, model, maxPrioritySubQuery)
+	}
 	if common.UsingSQLite || common.UsingPostgreSQL {
 		err = channelQuery.Order("RANDOM()").First(&ability).Error
 	} else {
@@ -82,3 +90,19 @@ func (channel *Channel) UpdateAbilities() error {
 func UpdateAbilityStatus(channelId int, status bool) error {
 	return DB.Model(&Ability{}).Where("channel_id = ?", channelId).Select("enabled").Update("enabled", status).Error
 }
+
+func GetGroupModels(ctx context.Context, group string) ([]string, error) {
+	groupCol := "`group`"
+	trueVal := "1"
+	if common.UsingPostgreSQL {
+		groupCol = `"group"`
+		trueVal = "true"
+	}
+	var models []string
+	err := DB.Model(&Ability{}).Distinct("model").Where(groupCol+" = ? and enabled = "+trueVal, group).Pluck("model", &models).Error
+	if err != nil {
+		return nil, err
+	}
+	sort.Strings(models)
+	return models, err
+}
```
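The new `ignoreFirstPriority` flag presumably exists so a caller can first stay on the highest-priority channels for a model and then, on failure, fall back to any enabled channel. A hypothetical caller sketching that pattern (stubbed, not the repo's actual retry logic):

```go
package main

import (
	"errors"
	"fmt"
)

type Channel struct{ Id int }

// pickChannel stands in for model.GetRandomSatisfiedChannel; imagine the
// gorm query from ability.go behind it.
func pickChannel(group, model string, ignoreFirstPriority bool) (*Channel, error) {
	return nil, errors.New("no available channel")
}

// pickWithFallback: the first try is confined to the top-priority tier; the
// retry may use any enabled channel for the model.
func pickWithFallback(group, model string) (*Channel, error) {
	ch, err := pickChannel(group, model, false)
	if err != nil {
		ch, err = pickChannel(group, model, true)
	}
	return ch, err
}

func main() {
	_, err := pickWithFallback("default", "gpt-4")
	fmt.Println(err)
}
```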
```diff
@@ -21,6 +21,7 @@ var (
 	UserId2GroupCacheSeconds  = config.SyncFrequency
 	UserId2QuotaCacheSeconds  = config.SyncFrequency
 	UserId2StatusCacheSeconds = config.SyncFrequency
+	GroupModelsCacheSeconds   = config.SyncFrequency
 )
 
 func CacheGetTokenByKey(key string) (*Token, error) {
@@ -146,6 +147,25 @@ func CacheIsUserEnabled(userId int) (bool, error) {
 	return userEnabled, err
 }
 
+func CacheGetGroupModels(ctx context.Context, group string) ([]string, error) {
+	if !common.RedisEnabled {
+		return GetGroupModels(ctx, group)
+	}
+	modelsStr, err := common.RedisGet(fmt.Sprintf("group_models:%s", group))
+	if err == nil {
+		return strings.Split(modelsStr, ","), nil
+	}
+	models, err := GetGroupModels(ctx, group)
+	if err != nil {
+		return nil, err
+	}
+	err = common.RedisSet(fmt.Sprintf("group_models:%s", group), strings.Join(models, ","), time.Duration(GroupModelsCacheSeconds)*time.Second)
+	if err != nil {
+		logger.SysError("Redis set group models error: " + err.Error())
+	}
+	return models, nil
+}
+
 var group2model2channels map[string]map[string][]*Channel
 var channelSyncLock sync.RWMutex
 
@@ -205,7 +225,7 @@ func SyncChannelCache(frequency int) {
 
 func CacheGetRandomSatisfiedChannel(group string, model string, ignoreFirstPriority bool) (*Channel, error) {
 	if !config.MemoryCacheEnabled {
-		return GetRandomSatisfiedChannel(group, model)
+		return GetRandomSatisfiedChannel(group, model, ignoreFirstPriority)
 	}
 	channelSyncLock.RLock()
 	defer channelSyncLock.RUnlock()
```
							
								
								
									
model/log.go (+15)
```diff
@@ -51,6 +51,21 @@ func RecordLog(userId int, logType int, content string) {
 	}
 }
 
+func RecordTopupLog(userId int, content string, quota int) {
+	log := &Log{
+		UserId:    userId,
+		Username:  GetUsernameById(userId),
+		CreatedAt: helper.GetTimestamp(),
+		Type:      LogTypeTopup,
+		Content:   content,
+		Quota:     quota,
+	}
+	err := LOG_DB.Create(log).Error
+	if err != nil {
+		logger.SysError("failed to record log: " + err.Error())
+	}
+}
+
 func RecordConsumeLog(ctx context.Context, userId int, channelId int, promptTokens int, completionTokens int, modelName string, tokenName string, quota int64, content string) {
 	logger.Info(ctx, fmt.Sprintf("record consume log: userId=%d, channelId=%d, promptTokens=%d, completionTokens=%d, modelName=%s, tokenName=%s, quota=%d, content=%s", userId, channelId, promptTokens, completionTokens, modelName, tokenName, quota, content))
 	if !config.LogConsumeEnabled {
```
```diff
@@ -12,17 +12,18 @@ import (
 )
 
 type Token struct {
-	Id             int    `json:"id"`
-	UserId         int    `json:"user_id"`
-	Key            string `json:"key" gorm:"type:char(48);uniqueIndex"`
-	Status         int    `json:"status" gorm:"default:1"`
-	Name           string `json:"name" gorm:"index" `
-	CreatedTime    int64  `json:"created_time" gorm:"bigint"`
-	AccessedTime   int64  `json:"accessed_time" gorm:"bigint"`
-	ExpiredTime    int64  `json:"expired_time" gorm:"bigint;default:-1"` // -1 means never expired
-	RemainQuota    int64  `json:"remain_quota" gorm:"bigint;default:0"`
-	UnlimitedQuota bool   `json:"unlimited_quota" gorm:"default:false"`
-	UsedQuota      int64  `json:"used_quota" gorm:"bigint;default:0"` // used quota
+	Id             int     `json:"id"`
+	UserId         int     `json:"user_id"`
+	Key            string  `json:"key" gorm:"type:char(48);uniqueIndex"`
+	Status         int     `json:"status" gorm:"default:1"`
+	Name           string  `json:"name" gorm:"index" `
+	CreatedTime    int64   `json:"created_time" gorm:"bigint"`
+	AccessedTime   int64   `json:"accessed_time" gorm:"bigint"`
+	ExpiredTime    int64   `json:"expired_time" gorm:"bigint;default:-1"` // -1 means never expired
+	RemainQuota    int64   `json:"remain_quota" gorm:"bigint;default:0"`
+	UnlimitedQuota bool    `json:"unlimited_quota" gorm:"default:false"`
+	UsedQuota      int64   `json:"used_quota" gorm:"bigint;default:0"` // used quota
+	Models         *string `json:"models" gorm:"default:''"`
 }
 
 func GetAllUserTokens(userId int, startIdx int, num int, order string) ([]*Token, error) {
@@ -121,7 +122,7 @@ func (token *Token) Insert() error {
 // Update Make sure your token's fields is completed, because this will update non-zero values
 func (token *Token) Update() error {
 	var err error
-	err = DB.Model(token).Select("name", "status", "expired_time", "remain_quota", "unlimited_quota").Updates(token).Error
+	err = DB.Model(token).Select("name", "status", "expired_time", "remain_quota", "unlimited_quota", "models").Updates(token).Error
 	return err
 }
```
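`Models` is a `*string` rather than a `string`, which lets legacy rows with a NULL column and tokens with an explicitly cleared list both mean "no restriction" — exactly the condition the TokenAuth hunk above checks. A standalone illustration:

```go
package main

import "fmt"

// restricted mirrors the TokenAuth condition: only a non-nil, non-empty
// Models value activates the per-token allow-list.
func restricted(models *string) bool {
	return models != nil && *models != ""
}

func main() {
	var unset *string
	empty := ""
	list := "gpt-3.5-turbo,gpt-4"
	fmt.Println(restricted(unset))  // false: legacy rows with NULL models
	fmt.Println(restricted(&empty)) // false: explicitly cleared
	fmt.Println(restricted(&list))  // true: allow-list enforced
}
```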
```diff
@@ -48,6 +48,9 @@ func ConvertRequest(request model.GeneralOpenAIRequest) *ChatRequest {
 			MaxTokens:         request.MaxTokens,
 			Temperature:       request.Temperature,
 			TopP:              request.TopP,
+			TopK:              request.TopK,
+			ResultFormat:      "message",
+			Tools:             request.Tools,
 		},
 	}
 }
@@ -117,19 +120,11 @@ func embeddingResponseAli2OpenAI(response *EmbeddingResponse) *openai.EmbeddingR
 }
 
 func responseAli2OpenAI(response *ChatResponse) *openai.TextResponse {
-	choice := openai.TextResponseChoice{
-		Index: 0,
-		Message: model.Message{
-			Role:    "assistant",
-			Content: response.Output.Text,
-		},
-		FinishReason: response.Output.FinishReason,
-	}
 	fullTextResponse := openai.TextResponse{
 		Id:      response.RequestId,
 		Object:  "chat.completion",
 		Created: helper.GetTimestamp(),
-		Choices: []openai.TextResponseChoice{choice},
+		Choices: response.Output.Choices,
 		Usage: model.Usage{
 			PromptTokens:     response.Usage.InputTokens,
 			CompletionTokens: response.Usage.OutputTokens,
@@ -140,10 +135,14 @@ func responseAli2OpenAI(response *ChatResponse) *openai.TextResponse {
 }
 
 func streamResponseAli2OpenAI(aliResponse *ChatResponse) *openai.ChatCompletionsStreamResponse {
+	if len(aliResponse.Output.Choices) == 0 {
+		return nil
+	}
+	aliChoice := aliResponse.Output.Choices[0]
 	var choice openai.ChatCompletionsStreamResponseChoice
-	choice.Delta.Content = aliResponse.Output.Text
-	if aliResponse.Output.FinishReason != "null" {
-		finishReason := aliResponse.Output.FinishReason
+	choice.Delta = aliChoice.Message
+	if aliChoice.FinishReason != "null" {
+		finishReason := aliChoice.FinishReason
 		choice.FinishReason = &finishReason
 	}
 	response := openai.ChatCompletionsStreamResponse{
@@ -204,6 +203,9 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC
 				usage.TotalTokens = aliResponse.Usage.InputTokens + aliResponse.Usage.OutputTokens
 			}
 			response := streamResponseAli2OpenAI(&aliResponse)
+			if response == nil {
+				return true
+			}
 			//response.Choices[0].Delta.Content = strings.TrimPrefix(response.Choices[0].Delta.Content, lastResponseText)
 			//lastResponseText = aliResponse.Output.Text
 			jsonResponse, err := json.Marshal(response)
@@ -226,6 +228,7 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC
 }
 
 func Handler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
+	ctx := c.Request.Context()
 	var aliResponse ChatResponse
 	responseBody, err := io.ReadAll(resp.Body)
 	if err != nil {
@@ -235,6 +238,7 @@ func Handler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *
 	if err != nil {
 		return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
 	}
+	logger.Debugf(ctx, "response body: %s\n", responseBody)
 	err = json.Unmarshal(responseBody, &aliResponse)
 	if err != nil {
 		return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
```
```diff
@@ -1,5 +1,10 @@
 package ali
 
+import (
+	"github.com/songquanpeng/one-api/relay/channel/openai"
+	"github.com/songquanpeng/one-api/relay/model"
+)
+
 type Message struct {
 	Content string `json:"content"`
 	Role    string `json:"role"`
@@ -11,13 +16,15 @@ type Input struct {
 }
 
 type Parameters struct {
-	TopP              float64 `json:"top_p,omitempty"`
-	TopK              int     `json:"top_k,omitempty"`
-	Seed              uint64  `json:"seed,omitempty"`
-	EnableSearch      bool    `json:"enable_search,omitempty"`
-	IncrementalOutput bool    `json:"incremental_output,omitempty"`
-	MaxTokens         int     `json:"max_tokens,omitempty"`
-	Temperature       float64 `json:"temperature,omitempty"`
+	TopP              float64      `json:"top_p,omitempty"`
+	TopK              int          `json:"top_k,omitempty"`
+	Seed              uint64       `json:"seed,omitempty"`
+	EnableSearch      bool         `json:"enable_search,omitempty"`
+	IncrementalOutput bool         `json:"incremental_output,omitempty"`
+	MaxTokens         int          `json:"max_tokens,omitempty"`
+	Temperature       float64      `json:"temperature,omitempty"`
+	ResultFormat      string       `json:"result_format,omitempty"`
+	Tools             []model.Tool `json:"tools,omitempty"`
 }
 
 type ChatRequest struct {
@@ -62,8 +69,9 @@ type Usage struct {
 }
 
 type Output struct {
-	Text         string `json:"text"`
-	FinishReason string `json:"finish_reason"`
+	//Text         string                      `json:"text"`
+	//FinishReason string                      `json:"finish_reason"`
+	Choices []openai.TextResponseChoice `json:"choices"`
 }
 
 type ChatResponse struct {
```
```diff
@@ -38,6 +38,7 @@ func ConvertRequest(textRequest model.GeneralOpenAIRequest) *Request {
 		MaxTokens:   textRequest.MaxTokens,
 		Temperature: textRequest.Temperature,
 		TopP:        textRequest.TopP,
+		TopK:        textRequest.TopK,
 		Stream:      textRequest.Stream,
 	}
 	if claudeRequest.MaxTokens == 0 {
```
```diff
@@ -38,16 +38,26 @@ func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) {
 		suffix += "completions_pro"
 	case "ERNIE-Bot-4":
 		suffix += "completions_pro"
-	case "ERNIE-3.5-8K":
-		suffix += "completions"
-	case "ERNIE-Bot-8K":
-		suffix += "ernie_bot_8k"
 	case "ERNIE-Bot":
 		suffix += "completions"
-	case "ERNIE-Speed":
-		suffix += "ernie_speed"
 	case "ERNIE-Bot-turbo":
 		suffix += "eb-instant"
+	case "ERNIE-Speed":
+		suffix += "ernie_speed"
+	case "ERNIE-Bot-8K":
+		suffix += "ernie_bot_8k"
+	case "ERNIE-4.0-8K":
+		suffix += "completions_pro"
+	case "ERNIE-3.5-8K":
+		suffix += "completions"
+	case "ERNIE-Speed-8K":
+		suffix += "ernie_speed"
+	case "ERNIE-Speed-128K":
+		suffix += "ernie-speed-128k"
+	case "ERNIE-Lite-8K":
+		suffix += "ernie-lite-8k"
+	case "ERNIE-Tiny-8K":
+		suffix += "ernie-tiny-8k"
 	case "BLOOMZ-7B":
 		suffix += "bloomz_7b1"
 	case "Embedding-V1":
@@ -59,7 +69,7 @@ func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) {
 	case "tao-8k":
 		suffix += "tao_8k"
 	default:
-		suffix += meta.ActualModelName
+		suffix += strings.ToLower(meta.ActualModelName)
 	}
 	fullRequestURL := fmt.Sprintf("%s/rpc/2.0/ai_custom/v1/wenxinworkshop/%s", meta.BaseURL, suffix)
 	var accessToken string
```
```diff
@@ -1,11 +1,18 @@
 package baidu
 
 var ModelList = []string{
-	"ERNIE-Bot-4",
-	"ERNIE-Bot-8K",
-	"ERNIE-Bot",
-	"ERNIE-Speed",
-	"ERNIE-Bot-turbo",
+	"ERNIE-4.0-8K",
+	"ERNIE-Bot-8K-0922",
+	"ERNIE-3.5-8K",
+	"ERNIE-Lite-8K-0922",
+	"ERNIE-Speed-8K",
+	"ERNIE-3.5-4K-0205",
+	"ERNIE-3.5-8K-0205",
+	"ERNIE-3.5-8K-1222",
+	"ERNIE-Lite-8K",
+	"ERNIE-Speed-128K",
+	"ERNIE-Tiny-8K",
+	"BLOOMZ-7B",
 	"Embedding-V1",
 	"bge-large-zh",
 	"bge-large-en",
```
```diff
@@ -70,8 +70,10 @@ func (a *Adaptor) DoRequest(c *gin.Context, meta *util.RelayMeta, requestBody io
 func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.RelayMeta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
 	if meta.IsStream {
 		var responseText string
-		err, responseText, _ = StreamHandler(c, resp, meta.Mode)
-		usage = ResponseText2Usage(responseText, meta.ActualModelName, meta.PromptTokens)
+		err, responseText, usage = StreamHandler(c, resp, meta.Mode)
+		if usage == nil {
+			usage = ResponseText2Usage(responseText, meta.ActualModelName, meta.PromptTokens)
+		}
 	} else {
 		err, usage = Handler(c, resp, meta.PromptTokens, meta.ActualModelName)
 	}
```
```diff
@@ -6,6 +6,7 @@ import (
 	"encoding/json"
 	"github.com/gin-gonic/gin"
 	"github.com/songquanpeng/one-api/common"
+	"github.com/songquanpeng/one-api/common/conv"
 	"github.com/songquanpeng/one-api/common/logger"
 	"github.com/songquanpeng/one-api/relay/constant"
 	"github.com/songquanpeng/one-api/relay/model"
@@ -53,7 +54,7 @@ func StreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*model.E
 						continue // just ignore the error
 					}
 					for _, choice := range streamResponse.Choices {
-						responseText += choice.Delta.Content
+						responseText += conv.AsString(choice.Delta.Content)
 					}
 					if streamResponse.Usage != nil {
 						usage = streamResponse.Usage
```
```diff
@@ -118,12 +118,9 @@ type ImageResponse struct {
 }
 
 type ChatCompletionsStreamResponseChoice struct {
-	Index int `json:"index"`
-	Delta struct {
-		Content string `json:"content"`
-		Role    string `json:"role,omitempty"`
-	} `json:"delta"`
-	FinishReason *string `json:"finish_reason,omitempty"`
+	Index        int           `json:"index"`
+	Delta        model.Message `json:"delta"`
+	FinishReason *string       `json:"finish_reason,omitempty"`
 }
 
 type ChatCompletionsStreamResponse struct {
```
```diff
@@ -10,6 +10,7 @@ import (
 	"fmt"
 	"github.com/gin-gonic/gin"
 	"github.com/songquanpeng/one-api/common"
+	"github.com/songquanpeng/one-api/common/conv"
 	"github.com/songquanpeng/one-api/common/helper"
 	"github.com/songquanpeng/one-api/common/logger"
 	"github.com/songquanpeng/one-api/relay/channel/openai"
@@ -129,7 +130,7 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC
 			}
 			response := streamResponseTencent2OpenAI(&TencentResponse)
 			if len(response.Choices) != 0 {
-				responseText += response.Choices[0].Delta.Content
+				responseText += conv.AsString(response.Choices[0].Delta.Content)
 			}
 			jsonResponse, err := json.Marshal(response)
 			if err != nil {
```
```diff
@@ -26,7 +26,11 @@ import (
 
 func requestOpenAI2Xunfei(request model.GeneralOpenAIRequest, xunfeiAppId string, domain string) *ChatRequest {
 	messages := make([]Message, 0, len(request.Messages))
+	var lastToolCalls []model.Tool
 	for _, message := range request.Messages {
+		if message.ToolCalls != nil {
+			lastToolCalls = message.ToolCalls
+		}
 		messages = append(messages, Message{
 			Role:    message.Role,
 			Content: message.StringContent(),
@@ -39,9 +43,33 @@ func requestOpenAI2Xunfei(request model.GeneralOpenAIRequest, xunfeiAppId string
 	xunfeiRequest.Parameter.Chat.TopK = request.N
 	xunfeiRequest.Parameter.Chat.MaxTokens = request.MaxTokens
 	xunfeiRequest.Payload.Message.Text = messages
+	if len(lastToolCalls) != 0 {
+		for _, toolCall := range lastToolCalls {
+			xunfeiRequest.Payload.Functions.Text = append(xunfeiRequest.Payload.Functions.Text, toolCall.Function)
+		}
+	}
+
 	return &xunfeiRequest
 }
 
+func getToolCalls(response *ChatResponse) []model.Tool {
+	var toolCalls []model.Tool
+	if len(response.Payload.Choices.Text) == 0 {
+		return toolCalls
+	}
+	item := response.Payload.Choices.Text[0]
+	if item.FunctionCall == nil {
+		return toolCalls
+	}
+	toolCall := model.Tool{
+		Id:       fmt.Sprintf("call_%s", helper.GetUUID()),
+		Type:     "function",
+		Function: *item.FunctionCall,
+	}
+	toolCalls = append(toolCalls, toolCall)
+	return toolCalls
+}
+
 func responseXunfei2OpenAI(response *ChatResponse) *openai.TextResponse {
 	if len(response.Payload.Choices.Text) == 0 {
 		response.Payload.Choices.Text = []ChatResponseTextItem{
@@ -53,8 +81,9 @@ func responseXunfei2OpenAI(response *ChatResponse) *openai.TextResponse {
 	choice := openai.TextResponseChoice{
 		Index: 0,
 		Message: model.Message{
-			Role:    "assistant",
-			Content: response.Payload.Choices.Text[0].Content,
+			Role:      "assistant",
+			Content:   response.Payload.Choices.Text[0].Content,
+			ToolCalls: getToolCalls(response),
 		},
 		FinishReason: constant.StopFinishReason,
 	}
@@ -78,6 +107,7 @@ func streamResponseXunfei2OpenAI(xunfeiResponse *ChatResponse) *openai.ChatCompl
 	}
 	var choice openai.ChatCompletionsStreamResponseChoice
 	choice.Delta.Content = xunfeiResponse.Payload.Choices.Text[0].Content
+	choice.Delta.ToolCalls = getToolCalls(xunfeiResponse)
 	if xunfeiResponse.Payload.Choices.Status == 2 {
 		choice.FinishReason = &constant.StopFinishReason
 	}
```
```diff
@@ -26,13 +26,18 @@ type ChatRequest struct {
 		Message struct {
 			Text []Message `json:"text"`
 		} `json:"message"`
+		Functions struct {
+			Text []model.Function `json:"text,omitempty"`
+		} `json:"functions,omitempty"`
 	} `json:"payload"`
 }
 
 type ChatResponseTextItem struct {
-	Content string `json:"content"`
-	Role    string `json:"role"`
-	Index   int    `json:"index"`
+	Content      string          `json:"content"`
+	Role         string          `json:"role"`
+	Index        int             `json:"index"`
+	ContentType  string          `json:"content_type"`
+	FunctionCall *model.Function `json:"function_call"`
 }
 
 type ChatResponse struct {
```
| @@ -6,6 +6,7 @@ import ( | |||||||
| 	"github.com/gin-gonic/gin" | 	"github.com/gin-gonic/gin" | ||||||
| 	"github.com/songquanpeng/one-api/relay/channel" | 	"github.com/songquanpeng/one-api/relay/channel" | ||||||
| 	"github.com/songquanpeng/one-api/relay/channel/openai" | 	"github.com/songquanpeng/one-api/relay/channel/openai" | ||||||
|  | 	"github.com/songquanpeng/one-api/relay/constant" | ||||||
| 	"github.com/songquanpeng/one-api/relay/model" | 	"github.com/songquanpeng/one-api/relay/model" | ||||||
| 	"github.com/songquanpeng/one-api/relay/util" | 	"github.com/songquanpeng/one-api/relay/util" | ||||||
| 	"io" | 	"io" | ||||||
| @@ -35,6 +36,9 @@ func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) { | |||||||
| 	if a.APIVersion == "v4" { | 	if a.APIVersion == "v4" { | ||||||
| 		return fmt.Sprintf("%s/api/paas/v4/chat/completions", meta.BaseURL), nil | 		return fmt.Sprintf("%s/api/paas/v4/chat/completions", meta.BaseURL), nil | ||||||
| 	} | 	} | ||||||
|  | 	if meta.Mode == constant.RelayModeEmbeddings { | ||||||
|  | 		return fmt.Sprintf("%s/api/paas/v4/embeddings", meta.BaseURL), nil | ||||||
|  | 	} | ||||||
| 	method := "invoke" | 	method := "invoke" | ||||||
| 	if meta.IsStream { | 	if meta.IsStream { | ||||||
| 		method = "sse-invoke" | 		method = "sse-invoke" | ||||||
| @@ -53,18 +57,24 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G | |||||||
| 	if request == nil { | 	if request == nil { | ||||||
| 		return nil, errors.New("request is nil") | 		return nil, errors.New("request is nil") | ||||||
| 	} | 	} | ||||||
| 	// TopP (0.0, 1.0) | 	switch relayMode { | ||||||
| 	request.TopP = math.Min(0.99, request.TopP) | 	case constant.RelayModeEmbeddings: | ||||||
| 	request.TopP = math.Max(0.01, request.TopP) | 		baiduEmbeddingRequest := ConvertEmbeddingRequest(*request) | ||||||
|  | 		return baiduEmbeddingRequest, nil | ||||||
|  | 	default: | ||||||
|  | 		// TopP (0.0, 1.0) | ||||||
|  | 		request.TopP = math.Min(0.99, request.TopP) | ||||||
|  | 		request.TopP = math.Max(0.01, request.TopP) | ||||||
|  |  | ||||||
| 	// Temperature (0.0, 1.0) | 		// Temperature (0.0, 1.0) | ||||||
| 	request.Temperature = math.Min(0.99, request.Temperature) | 		request.Temperature = math.Min(0.99, request.Temperature) | ||||||
| 	request.Temperature = math.Max(0.01, request.Temperature) | 		request.Temperature = math.Max(0.01, request.Temperature) | ||||||
| 	a.SetVersionByModeName(request.Model) | 		a.SetVersionByModeName(request.Model) | ||||||
| 	if a.APIVersion == "v4" { | 		if a.APIVersion == "v4" { | ||||||
| 		return request, nil | 			return request, nil | ||||||
|  | 		} | ||||||
|  | 		return ConvertRequest(*request), nil | ||||||
| 	} | 	} | ||||||
| 	return ConvertRequest(*request), nil |  | ||||||
| } | } | ||||||
|  |  | ||||||
| func (a *Adaptor) DoRequest(c *gin.Context, meta *util.RelayMeta, requestBody io.Reader) (*http.Response, error) { | func (a *Adaptor) DoRequest(c *gin.Context, meta *util.RelayMeta, requestBody io.Reader) (*http.Response, error) { | ||||||
| @@ -84,14 +94,26 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.Rel | |||||||
| 	if a.APIVersion == "v4" { | 	if a.APIVersion == "v4" { | ||||||
| 		return a.DoResponseV4(c, resp, meta) | 		return a.DoResponseV4(c, resp, meta) | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	if meta.IsStream { | 	if meta.IsStream { | ||||||
| 		err, usage = StreamHandler(c, resp) | 		err, usage = StreamHandler(c, resp) | ||||||
| 	} else { | 	} else { | ||||||
| 		err, usage = Handler(c, resp) | 		if meta.Mode == constant.RelayModeEmbeddings { | ||||||
|  | 			err, usage = EmbeddingsHandler(c, resp) | ||||||
|  | 		} else { | ||||||
|  | 			err, usage = Handler(c, resp) | ||||||
|  | 		} | ||||||
| 	} | 	} | ||||||
| 	return | 	return | ||||||
| } | } | ||||||
|  |  | ||||||
|  | func ConvertEmbeddingRequest(request model.GeneralOpenAIRequest) *EmbeddingRequest { | ||||||
|  | 	// comma-ok assertion avoids a panic when Input is not a plain string | ||||||
|  | 	input, _ := request.Input.(string) | ||||||
|  | 	return &EmbeddingRequest{ | ||||||
|  | 		Model: "embedding-2", | ||||||
|  | 		Input: input, | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
| func (a *Adaptor) GetModelList() []string { | func (a *Adaptor) GetModelList() []string { | ||||||
| 	return ModelList | 	return ModelList | ||||||
| } | } | ||||||
|   | |||||||
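The adaptor hunk above wires embeddings into URL selection: v4 channels keep the v4 chat completions endpoint, embedding requests get the new v4 embeddings path, and remaining chat traffic falls through to invoke/sse-invoke. A minimal sketch of that branching, with RelayMeta and the relay-mode constant reduced to local stand-ins (their names and the constant's value are assumptions; the real definitions live in relay/util and relay/constant):

```go
package main

import "fmt"

// Stand-ins for relay/constant and relay/util.RelayMeta; the real
// types carry more fields and the constant's value is assumed.
const relayModeEmbeddings = 2

type relayMeta struct {
	BaseURL  string
	Mode     int
	IsStream bool
}

// requestURL mirrors the branching added to Adaptor.GetRequestURL.
func requestURL(apiVersion string, meta relayMeta) string {
	if apiVersion == "v4" {
		return fmt.Sprintf("%s/api/paas/v4/chat/completions", meta.BaseURL)
	}
	if meta.Mode == relayModeEmbeddings {
		return fmt.Sprintf("%s/api/paas/v4/embeddings", meta.BaseURL)
	}
	method := "invoke"
	if meta.IsStream {
		method = "sse-invoke"
	}
	// the rest of the chat URL is elided in the hunk; placeholder only
	return fmt.Sprintf("%s/api/paas/v3/model-api/<model>/%s", meta.BaseURL, method)
}

func main() {
	meta := relayMeta{BaseURL: "https://open.bigmodel.cn", Mode: relayModeEmbeddings}
	fmt.Println(requestURL("v3", meta)) // https://open.bigmodel.cn/api/paas/v4/embeddings
}
```

Note the order matters: the v4 check runs first, so only channels that have not been switched to v4 reach the embeddings branch as written.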
| @@ -2,5 +2,5 @@ package zhipu | |||||||
|  |  | ||||||
| var ModelList = []string{ | var ModelList = []string{ | ||||||
| 	"chatglm_turbo", "chatglm_pro", "chatglm_std", "chatglm_lite", | 	"chatglm_turbo", "chatglm_pro", "chatglm_std", "chatglm_lite", | ||||||
| 	"glm-4", "glm-4v", "glm-3-turbo", | 	"glm-4", "glm-4v", "glm-3-turbo", "embedding-2", | ||||||
| } | } | ||||||
|   | |||||||
| @@ -254,3 +254,50 @@ func Handler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, * | |||||||
| 	_, err = c.Writer.Write(jsonResponse) | 	_, err = c.Writer.Write(jsonResponse) | ||||||
| 	return nil, &fullTextResponse.Usage | 	return nil, &fullTextResponse.Usage | ||||||
| } | } | ||||||
|  |  | ||||||
|  | func EmbeddingsHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) { | ||||||
|  | 	var zhipuResponse EmbeddingResponse | ||||||
|  | 	responseBody, err := io.ReadAll(resp.Body) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil | ||||||
|  | 	} | ||||||
|  | 	err = resp.Body.Close() | ||||||
|  | 	if err != nil { | ||||||
|  | 		return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil | ||||||
|  | 	} | ||||||
|  | 	err = json.Unmarshal(responseBody, &zhipuResponse) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil | ||||||
|  | 	} | ||||||
|  | 	fullTextResponse := embeddingResponseZhipu2OpenAI(&zhipuResponse) | ||||||
|  | 	jsonResponse, err := json.Marshal(fullTextResponse) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil | ||||||
|  | 	} | ||||||
|  | 	c.Writer.Header().Set("Content-Type", "application/json") | ||||||
|  | 	c.Writer.WriteHeader(resp.StatusCode) | ||||||
|  | 	_, err = c.Writer.Write(jsonResponse) | ||||||
|  | 	return nil, &fullTextResponse.Usage | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func embeddingResponseZhipu2OpenAI(response *EmbeddingResponse) *openai.EmbeddingResponse { | ||||||
|  | 	openAIEmbeddingResponse := openai.EmbeddingResponse{ | ||||||
|  | 		Object: "list", | ||||||
|  | 		Data:   make([]openai.EmbeddingResponseItem, 0, len(response.Embeddings)), | ||||||
|  | 		Model:  response.Model, | ||||||
|  | 		Usage: model.Usage{ | ||||||
|  | 			PromptTokens:     response.PromptTokens, | ||||||
|  | 			CompletionTokens: response.CompletionTokens, | ||||||
|  | 			TotalTokens:      response.TotalTokens, | ||||||
|  | 		}, | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	for _, item := range response.Embeddings { | ||||||
|  | 		openAIEmbeddingResponse.Data = append(openAIEmbeddingResponse.Data, openai.EmbeddingResponseItem{ | ||||||
|  | 			Object:    "embedding", | ||||||
|  | 			Index:     item.Index, | ||||||
|  | 			Embedding: item.Embedding, | ||||||
|  | 		}) | ||||||
|  | 	} | ||||||
|  | 	return &openAIEmbeddingResponse | ||||||
|  | } | ||||||
|   | |||||||
| @@ -44,3 +44,21 @@ type tokenData struct { | |||||||
| 	Token      string | 	Token      string | ||||||
| 	ExpiryTime time.Time | 	ExpiryTime time.Time | ||||||
| } | } | ||||||
|  |  | ||||||
|  | type EmbeddingRequest struct { | ||||||
|  | 	Model string `json:"model"` | ||||||
|  | 	Input string `json:"input"` | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type EmbeddingResponse struct { | ||||||
|  | 	Model       string          `json:"model"` | ||||||
|  | 	Object      string          `json:"object"` | ||||||
|  | 	Embeddings  []EmbeddingData `json:"data"` | ||||||
|  | 	model.Usage `json:"usage"` | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type EmbeddingData struct { | ||||||
|  | 	Index     int       `json:"index"` | ||||||
|  | 	Object    string    `json:"object"` | ||||||
|  | 	Embedding []float64 `json:"embedding"` | ||||||
|  | } | ||||||
|   | |||||||
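With the request/response types in place, the conversion path is unmarshal, convert, re-marshal. A self-contained sketch using local mirrors of the structs above (the sample payload is hypothetical, shaped only to match the json tags; model.Usage is flattened into a local Usage):

```go
package main

import (
	"encoding/json"
	"fmt"
)

// Local mirrors of the zhipu types added in this change.
type Usage struct {
	PromptTokens     int `json:"prompt_tokens"`
	CompletionTokens int `json:"completion_tokens"`
	TotalTokens      int `json:"total_tokens"`
}

type EmbeddingData struct {
	Index     int       `json:"index"`
	Object    string    `json:"object"`
	Embedding []float64 `json:"embedding"`
}

type EmbeddingResponse struct {
	Model      string          `json:"model"`
	Object     string          `json:"object"`
	Embeddings []EmbeddingData `json:"data"`
	Usage      `json:"usage"`
}

func main() {
	// Hypothetical upstream payload, shaped only to match the tags above.
	payload := `{"model":"embedding-2","object":"list",
	  "data":[{"index":0,"object":"embedding","embedding":[0.12,0.34]}],
	  "usage":{"prompt_tokens":3,"total_tokens":3}}`

	var resp EmbeddingResponse
	if err := json.Unmarshal([]byte(payload), &resp); err != nil {
		panic(err)
	}
	// Usage is embedded, which is why the converter can read the
	// promoted resp.PromptTokens / resp.TotalTokens directly.
	fmt.Println(resp.Model, len(resp.Embeddings), resp.TotalTokens)
}
```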
| @@ -5,25 +5,29 @@ type ResponseFormat struct { | |||||||
| } | } | ||||||
|  |  | ||||||
| type GeneralOpenAIRequest struct { | type GeneralOpenAIRequest struct { | ||||||
| 	Model            string          `json:"model,omitempty"` |  | ||||||
| 	Messages         []Message       `json:"messages,omitempty"` | 	Messages         []Message       `json:"messages,omitempty"` | ||||||
| 	Prompt           any             `json:"prompt,omitempty"` | 	Model            string          `json:"model,omitempty"` | ||||||
| 	Stream           bool            `json:"stream,omitempty"` |  | ||||||
| 	MaxTokens        int             `json:"max_tokens,omitempty"` |  | ||||||
| 	Temperature      float64         `json:"temperature,omitempty"` |  | ||||||
| 	TopP             float64         `json:"top_p,omitempty"` |  | ||||||
| 	N                int             `json:"n,omitempty"` |  | ||||||
| 	Input            any             `json:"input,omitempty"` |  | ||||||
| 	Instruction      string          `json:"instruction,omitempty"` |  | ||||||
| 	Size             string          `json:"size,omitempty"` |  | ||||||
| 	Functions        any             `json:"functions,omitempty"` |  | ||||||
| 	FrequencyPenalty float64         `json:"frequency_penalty,omitempty"` | 	FrequencyPenalty float64         `json:"frequency_penalty,omitempty"` | ||||||
|  | 	MaxTokens        int             `json:"max_tokens,omitempty"` | ||||||
|  | 	N                int             `json:"n,omitempty"` | ||||||
| 	PresencePenalty  float64         `json:"presence_penalty,omitempty"` | 	PresencePenalty  float64         `json:"presence_penalty,omitempty"` | ||||||
| 	ResponseFormat   *ResponseFormat `json:"response_format,omitempty"` | 	ResponseFormat   *ResponseFormat `json:"response_format,omitempty"` | ||||||
| 	Seed             float64         `json:"seed,omitempty"` | 	Seed             float64         `json:"seed,omitempty"` | ||||||
| 	Tools            any             `json:"tools,omitempty"` | 	Stream           bool            `json:"stream,omitempty"` | ||||||
|  | 	Temperature      float64         `json:"temperature,omitempty"` | ||||||
|  | 	TopP             float64         `json:"top_p,omitempty"` | ||||||
|  | 	TopK             int             `json:"top_k,omitempty"` | ||||||
|  | 	Tools            []Tool          `json:"tools,omitempty"` | ||||||
| 	ToolChoice       any             `json:"tool_choice,omitempty"` | 	ToolChoice       any             `json:"tool_choice,omitempty"` | ||||||
|  | 	FunctionCall     any             `json:"function_call,omitempty"` | ||||||
|  | 	Functions        any             `json:"functions,omitempty"` | ||||||
| 	User             string          `json:"user,omitempty"` | 	User             string          `json:"user,omitempty"` | ||||||
|  | 	Prompt           any             `json:"prompt,omitempty"` | ||||||
|  | 	Input            any             `json:"input,omitempty"` | ||||||
|  | 	EncodingFormat   string          `json:"encoding_format,omitempty"` | ||||||
|  | 	Dimensions       int             `json:"dimensions,omitempty"` | ||||||
|  | 	Instruction      string          `json:"instruction,omitempty"` | ||||||
|  | 	Size             string          `json:"size,omitempty"` | ||||||
| } | } | ||||||
|  |  | ||||||
| func (r GeneralOpenAIRequest) ParseInput() []string { | func (r GeneralOpenAIRequest) ParseInput() []string { | ||||||
|   | |||||||
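The request struct is regrouped roughly by purpose (sampling knobs, tool calling, then embedding and image fields) and gains top_k, function_call, encoding_format, and dimensions. Since every field is omitempty, an embeddings call serializes to just what it sets; a sketch with a mirrored subset of the struct:

```go
package main

import (
	"encoding/json"
	"fmt"
)

// Mirrored subset of relay/model.GeneralOpenAIRequest.
type request struct {
	Model          string  `json:"model,omitempty"`
	Input          any     `json:"input,omitempty"`
	EncodingFormat string  `json:"encoding_format,omitempty"`
	Dimensions     int     `json:"dimensions,omitempty"`
	Temperature    float64 `json:"temperature,omitempty"`
}

func main() {
	req := request{
		Model:          "text-embedding-3-small",
		Input:          "hello world",
		EncodingFormat: "float",
	}
	out, _ := json.Marshal(req)
	// Unset fields (dimensions, temperature) are dropped by omitempty.
	fmt.Println(string(out))
}
```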
| @@ -1,9 +1,10 @@ | |||||||
| package model | package model | ||||||
|  |  | ||||||
| type Message struct { | type Message struct { | ||||||
| 	Role    string  `json:"role"` | 	Role      string  `json:"role,omitempty"` | ||||||
| 	Content any     `json:"content"` | 	Content   any     `json:"content,omitempty"` | ||||||
| 	Name    *string `json:"name,omitempty"` | 	Name      *string `json:"name,omitempty"` | ||||||
|  | 	ToolCalls []Tool  `json:"tool_calls,omitempty"` | ||||||
| } | } | ||||||
|  |  | ||||||
| func (m Message) IsStringContent() bool { | func (m Message) IsStringContent() bool { | ||||||
|   | |||||||
							
								
								
									
relay/model/tool.go (new file, 14 lines)
							| @@ -0,0 +1,14 @@ | |||||||
|  | package model | ||||||
|  |  | ||||||
|  | type Tool struct { | ||||||
|  | 	Id       string   `json:"id,omitempty"` | ||||||
|  | 	Type     string   `json:"type"` | ||||||
|  | 	Function Function `json:"function"` | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type Function struct { | ||||||
|  | 	Description string `json:"description,omitempty"` | ||||||
|  | 	Name        string `json:"name"` | ||||||
|  | 	Parameters  any    `json:"parameters,omitempty"` // request | ||||||
|  | 	Arguments   any    `json:"arguments,omitempty"`  // response | ||||||
|  | } | ||||||
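The new Tool/Function pair serves both directions, which is why Function carries Parameters (request side, typically a JSON Schema object) and Arguments (response side). A sketch serializing a hypothetical tool definition through local mirrors of these structs:

```go
package main

import (
	"encoding/json"
	"os"
)

// Local mirrors of relay/model.Tool and Function from the diff.
type Function struct {
	Description string `json:"description,omitempty"`
	Name        string `json:"name"`
	Parameters  any    `json:"parameters,omitempty"` // request
	Arguments   any    `json:"arguments,omitempty"`  // response
}

type Tool struct {
	Id       string   `json:"id,omitempty"`
	Type     string   `json:"type"`
	Function Function `json:"function"`
}

func main() {
	// Hypothetical function definition on the request side; Arguments
	// stays nil and Id empty, so omitempty drops both from the output.
	tool := Tool{
		Type: "function",
		Function: Function{
			Name:        "get_weather",
			Description: "Look up current weather for a city",
			Parameters: map[string]any{
				"type": "object",
				"properties": map[string]any{
					"city": map[string]any{"type": "string"},
				},
			},
		},
	}
	enc := json.NewEncoder(os.Stdout)
	enc.SetIndent("", "  ")
	_ = enc.Encode(tool)
}
```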
| @@ -46,6 +46,15 @@ func ShouldDisableChannel(err *relaymodel.Error, statusCode int) bool { | |||||||
| 	} else if strings.HasPrefix(err.Message, "This organization has been disabled.") { | 	} else if strings.HasPrefix(err.Message, "This organization has been disabled.") { | ||||||
| 		return true | 		return true | ||||||
| 	} | 	} | ||||||
|  | 	if strings.Contains(err.Message, "quota") { | ||||||
|  | 		return true | ||||||
|  | 	} | ||||||
|  | 	if strings.Contains(err.Message, "credit") { | ||||||
|  | 		return true | ||||||
|  | 	} | ||||||
|  | 	if strings.Contains(err.Message, "balance") { | ||||||
|  | 		return true | ||||||
|  | 	} | ||||||
| 	return false | 	return false | ||||||
| } | } | ||||||
|  |  | ||||||
|   | |||||||
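The three added checks are plain substring matches, so any upstream error message mentioning quota, credit, or balance now disables the channel. A reduced sketch of just that logic:

```go
package main

import (
	"fmt"
	"strings"
)

// shouldDisable reproduces only the substring checks added in this
// change; the real ShouldDisableChannel also inspects status codes
// and specific OpenAI error messages.
func shouldDisable(message string) bool {
	for _, needle := range []string{"quota", "credit", "balance"} {
		if strings.Contains(message, needle) {
			return true
		}
	}
	return false
}

func main() {
	fmt.Println(shouldDisable("Your credit balance is too low")) // true
	fmt.Println(shouldDisable("rate limit exceeded"))            // false
}
```

As written the matching is case-sensitive, so a message like "Insufficient Balance" would slip past the balance check.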
| @@ -26,6 +26,7 @@ func SetApiRouter(router *gin.Engine) { | |||||||
| 		apiRouter.GET("/oauth/wechat", middleware.CriticalRateLimit(), controller.WeChatAuth) | 		apiRouter.GET("/oauth/wechat", middleware.CriticalRateLimit(), controller.WeChatAuth) | ||||||
| 		apiRouter.GET("/oauth/wechat/bind", middleware.CriticalRateLimit(), middleware.UserAuth(), controller.WeChatBind) | 		apiRouter.GET("/oauth/wechat/bind", middleware.CriticalRateLimit(), middleware.UserAuth(), controller.WeChatBind) | ||||||
| 		apiRouter.GET("/oauth/email/bind", middleware.CriticalRateLimit(), middleware.UserAuth(), controller.EmailBind) | 		apiRouter.GET("/oauth/email/bind", middleware.CriticalRateLimit(), middleware.UserAuth(), controller.EmailBind) | ||||||
|  | 		apiRouter.POST("/topup", middleware.AdminAuth(), controller.AdminTopUp) | ||||||
|  |  | ||||||
| 		userRoute := apiRouter.Group("/user") | 		userRoute := apiRouter.Group("/user") | ||||||
| 		{ | 		{ | ||||||
| @@ -43,6 +44,7 @@ func SetApiRouter(router *gin.Engine) { | |||||||
| 				selfRoute.GET("/token", controller.GenerateAccessToken) | 				selfRoute.GET("/token", controller.GenerateAccessToken) | ||||||
| 				selfRoute.GET("/aff", controller.GetAffCode) | 				selfRoute.GET("/aff", controller.GetAffCode) | ||||||
| 				selfRoute.POST("/topup", controller.TopUp) | 				selfRoute.POST("/topup", controller.TopUp) | ||||||
|  | 				selfRoute.GET("/available_models", controller.GetUserAvailableModels) | ||||||
| 			} | 			} | ||||||
|  |  | ||||||
| 			adminRoute := userRoute.Group("/") | 			adminRoute := userRoute.Group("/") | ||||||
|   | |||||||
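Two routes are added: an admin-only POST /api/topup and a self-service GET /api/user/available_models, which feeds the token editor's model dropdown below. A hypothetical client call to the latter; the auth header is an assumption, since whatever credentials UserAuth() accepts depends on the deployment:

```go
package main

import (
	"fmt"
	"io"
	"net/http"
)

func main() {
	// Hypothetical deployment URL and token; adjust to the instance.
	req, err := http.NewRequest(http.MethodGet,
		"http://localhost:3000/api/user/available_models", nil)
	if err != nil {
		panic(err)
	}
	req.Header.Set("Authorization", "Bearer <access token>") // assumption
	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()

	body, _ := io.ReadAll(resp.Body)
	// Expected shape, per the frontend: {"success":true,"message":"","data":[...]}
	fmt.Println(resp.Status, string(body))
}
```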
| @@ -1,19 +1,21 @@ | |||||||
| import React, { useEffect, useState } from 'react'; | import React, { useEffect, useState } from 'react'; | ||||||
| import { Button, Form, Header, Message, Segment } from 'semantic-ui-react'; | import { Button, Form, Header, Message, Segment } from 'semantic-ui-react'; | ||||||
| import { useParams, useNavigate } from 'react-router-dom'; | import { useNavigate, useParams } from 'react-router-dom'; | ||||||
| import { API, showError, showSuccess, timestamp2string } from '../../helpers'; | import { API, copy, showError, showSuccess, timestamp2string } from '../../helpers'; | ||||||
| import { renderQuota, renderQuotaWithPrompt } from '../../helpers/render'; | import { renderQuotaWithPrompt } from '../../helpers/render'; | ||||||
|  |  | ||||||
| const EditToken = () => { | const EditToken = () => { | ||||||
|   const params = useParams(); |   const params = useParams(); | ||||||
|   const tokenId = params.id; |   const tokenId = params.id; | ||||||
|   const isEdit = tokenId !== undefined; |   const isEdit = tokenId !== undefined; | ||||||
|   const [loading, setLoading] = useState(isEdit); |   const [loading, setLoading] = useState(isEdit); | ||||||
|  |   const [modelOptions, setModelOptions] = useState([]); | ||||||
|   const originInputs = { |   const originInputs = { | ||||||
|     name: '', |     name: '', | ||||||
|     remain_quota: isEdit ? 0 : 500000, |     remain_quota: isEdit ? 0 : 500000, | ||||||
|     expired_time: -1, |     expired_time: -1, | ||||||
|     unlimited_quota: false |     unlimited_quota: false, | ||||||
|  |     models: [] | ||||||
|   }; |   }; | ||||||
|   const [inputs, setInputs] = useState(originInputs); |   const [inputs, setInputs] = useState(originInputs); | ||||||
|   const { name, remain_quota, expired_time, unlimited_quota } = inputs; |   const { name, remain_quota, expired_time, unlimited_quota } = inputs; | ||||||
| @@ -22,8 +24,8 @@ const EditToken = () => { | |||||||
|     setInputs((inputs) => ({ ...inputs, [name]: value })); |     setInputs((inputs) => ({ ...inputs, [name]: value })); | ||||||
|   }; |   }; | ||||||
|   const handleCancel = () => { |   const handleCancel = () => { | ||||||
|     navigate("/token"); |     navigate('/token'); | ||||||
|   } |   }; | ||||||
|   const setExpiredTime = (month, day, hour, minute) => { |   const setExpiredTime = (month, day, hour, minute) => { | ||||||
|     let now = new Date(); |     let now = new Date(); | ||||||
|     let timestamp = now.getTime() / 1000; |     let timestamp = now.getTime() / 1000; | ||||||
| @@ -50,6 +52,11 @@ const EditToken = () => { | |||||||
|       if (data.expired_time !== -1) { |       if (data.expired_time !== -1) { | ||||||
|         data.expired_time = timestamp2string(data.expired_time); |         data.expired_time = timestamp2string(data.expired_time); | ||||||
|       } |       } | ||||||
|  |       if (data.models === '') { | ||||||
|  |         data.models = []; | ||||||
|  |       } else { | ||||||
|  |         data.models = data.models.split(','); | ||||||
|  |       } | ||||||
|       setInputs(data); |       setInputs(data); | ||||||
|     } else { |     } else { | ||||||
|       showError(message); |       showError(message); | ||||||
| @@ -60,8 +67,26 @@ const EditToken = () => { | |||||||
|     if (isEdit) { |     if (isEdit) { | ||||||
|       loadToken().then(); |       loadToken().then(); | ||||||
|     } |     } | ||||||
|  |     loadAvailableModels().then(); | ||||||
|   }, []); |   }, []); | ||||||
|  |  | ||||||
|  |   const loadAvailableModels = async () => { | ||||||
|  |     let res = await API.get(`/api/user/available_models`); | ||||||
|  |     const { success, message, data } = res.data; | ||||||
|  |     if (success) { | ||||||
|  |       let options = data.map((model) => { | ||||||
|  |         return { | ||||||
|  |           key: model, | ||||||
|  |           text: model, | ||||||
|  |           value: model | ||||||
|  |         }; | ||||||
|  |       }); | ||||||
|  |       setModelOptions(options); | ||||||
|  |     } else { | ||||||
|  |       showError(message); | ||||||
|  |     } | ||||||
|  |   }; | ||||||
|  |  | ||||||
|   const submit = async () => { |   const submit = async () => { | ||||||
|     if (!isEdit && inputs.name === '') return; |     if (!isEdit && inputs.name === '') return; | ||||||
|     let localInputs = inputs; |     let localInputs = inputs; | ||||||
| @@ -74,6 +99,7 @@ const EditToken = () => { | |||||||
|       } |       } | ||||||
|       localInputs.expired_time = Math.ceil(time / 1000); |       localInputs.expired_time = Math.ceil(time / 1000); | ||||||
|     } |     } | ||||||
|  |     localInputs.models = localInputs.models.join(','); | ||||||
|     let res; |     let res; | ||||||
|     if (isEdit) { |     if (isEdit) { | ||||||
|       res = await API.put(`/api/token/`, { ...localInputs, id: parseInt(tokenId) }); |       res = await API.put(`/api/token/`, { ...localInputs, id: parseInt(tokenId) }); | ||||||
| @@ -109,6 +135,24 @@ const EditToken = () => { | |||||||
|               required={!isEdit} |               required={!isEdit} | ||||||
|             /> |             /> | ||||||
|           </Form.Field> |           </Form.Field> | ||||||
|  |           <Form.Field> | ||||||
|  |             <Form.Dropdown | ||||||
|  |               label='模型范围' | ||||||
|  |               placeholder={'请选择允许使用的模型,留空则不进行限制'} | ||||||
|  |               name='models' | ||||||
|  |               fluid | ||||||
|  |               multiple | ||||||
|  |               search | ||||||
|  |               onLabelClick={(e, { value }) => { | ||||||
|  |                 copy(value).then(); | ||||||
|  |               }} | ||||||
|  |               selection | ||||||
|  |               onChange={handleInputChange} | ||||||
|  |               value={inputs.models} | ||||||
|  |               autoComplete='new-password' | ||||||
|  |               options={modelOptions} | ||||||
|  |             /> | ||||||
|  |           </Form.Field> | ||||||
|           <Form.Field> |           <Form.Field> | ||||||
|             <Form.Input |             <Form.Input | ||||||
|               label='过期时间' |               label='过期时间' | ||||||
|   | |||||||
| @@ -8,6 +8,7 @@ const TopUp = () => { | |||||||
|   const [topUpLink, setTopUpLink] = useState(''); |   const [topUpLink, setTopUpLink] = useState(''); | ||||||
|   const [userQuota, setUserQuota] = useState(0); |   const [userQuota, setUserQuota] = useState(0); | ||||||
|   const [isSubmitting, setIsSubmitting] = useState(false); |   const [isSubmitting, setIsSubmitting] = useState(false); | ||||||
|  |   const [user, setUser] = useState({}); | ||||||
|  |  | ||||||
|   const topUp = async () => { |   const topUp = async () => { | ||||||
|     if (redemptionCode === '') { |     if (redemptionCode === '') { | ||||||
| @@ -41,7 +42,14 @@ const TopUp = () => { | |||||||
|       showError('超级管理员未设置充值链接!'); |       showError('超级管理员未设置充值链接!'); | ||||||
|       return; |       return; | ||||||
|     } |     } | ||||||
|     window.open(topUpLink, '_blank'); |     let url = new URL(topUpLink); | ||||||
|  |     let username = user.username; | ||||||
|  |     let user_id = user.id; | ||||||
|  |     // append username, user_id and a random transaction_id to the top-up link | ||||||
|  |     url.searchParams.append('username', username); | ||||||
|  |     url.searchParams.append('user_id', user_id); | ||||||
|  |     url.searchParams.append('transaction_id', crypto.randomUUID()); | ||||||
|  |     window.open(url.toString(), '_blank'); | ||||||
|   }; |   }; | ||||||
|  |  | ||||||
|   const getUserQuota = async ()=>{ |   const getUserQuota = async ()=>{ | ||||||
| @@ -49,6 +57,7 @@ const TopUp = () => { | |||||||
|     const {success, message, data} = res.data; |     const {success, message, data} = res.data; | ||||||
|     if (success) { |     if (success) { | ||||||
|       setUserQuota(data.quota); |       setUserQuota(data.quota); | ||||||
|  |       setUser(data); | ||||||
|     } else { |     } else { | ||||||
|       showError(message); |       showError(message); | ||||||
|     } |     } | ||||||
| @@ -80,7 +89,7 @@ const TopUp = () => { | |||||||
|               }} |               }} | ||||||
|             /> |             /> | ||||||
|             <Button color='green' onClick={openTopUpLink}> |             <Button color='green' onClick={openTopUpLink}> | ||||||
|               获取兑换码 |               充值 | ||||||
|             </Button> |             </Button> | ||||||
|             <Button color='yellow' onClick={topUp} disabled={isSubmitting}> |             <Button color='yellow' onClick={topUp} disabled={isSubmitting}> | ||||||
|                 {isSubmitting ? '兑换中...' : '兑换'} |                 {isSubmitting ? '兑换中...' : '兑换'} | ||||||
|   | |||||||
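openTopUpLink now tags the configured link with username, user_id, and a random transaction_id so the payment page can attribute the purchase. The same construction expressed in Go with net/url, using a hex token in place of the browser's crypto.randomUUID() (URL and values are placeholders):

```go
package main

import (
	"crypto/rand"
	"encoding/hex"
	"fmt"
	"net/url"
)

// randomToken plays the role of the browser's crypto.randomUUID().
func randomToken() string {
	b := make([]byte, 16)
	_, _ = rand.Read(b)
	return hex.EncodeToString(b)
}

func main() {
	// Placeholder top-up link; in one-api this comes from the admin-set option.
	u, err := url.Parse("https://pay.example.com/topup")
	if err != nil {
		panic(err)
	}
	q := u.Query()
	q.Set("username", "alice") // user.username in the frontend
	q.Set("user_id", "42")     // user.id
	q.Set("transaction_id", randomToken())
	u.RawQuery = q.Encode()
	fmt.Println(u.String())
}
```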