add: add images edits and variations API

This commit is contained in:
Martial BE
2023-12-01 18:25:05 +08:00
parent 9dd92bbddd
commit 0f038d715d
11 changed files with 302 additions and 24 deletions

View File

@@ -11,10 +11,35 @@ import (
"github.com/gin-gonic/gin"
)
type ModelRequestInterface interface {
GetModel() string
SetModel(string)
}
type ModelRequest struct {
Model string `json:"model"`
}
func (m *ModelRequest) GetModel() string {
return m.Model
}
func (m *ModelRequest) SetModel(model string) {
m.Model = model
}
type ModelFormRequest struct {
Model string `form:"model"`
}
func (m *ModelFormRequest) GetModel() string {
return m.Model
}
func (m *ModelFormRequest) SetModel(model string) {
m.Model = model
}
func Distribute() func(c *gin.Context) {
return func(c *gin.Context) {
userId := c.GetInt("id")
@@ -39,35 +64,36 @@ func Distribute() func(c *gin.Context) {
}
} else {
// Select a channel for the user
var modelRequest ModelRequest
err := common.UnmarshalBodyReusable(c, &modelRequest)
modelRequest := getModelRequest(c)
err := common.UnmarshalBodyReusable(c, modelRequest)
if err != nil {
abortWithMessage(c, http.StatusBadRequest, "无效的请求")
return
}
if strings.HasPrefix(c.Request.URL.Path, "/v1/moderations") {
if modelRequest.Model == "" {
modelRequest.Model = "text-moderation-stable"
if modelRequest.GetModel() == "" {
modelRequest.SetModel("text-moderation-stable")
}
}
if strings.HasSuffix(c.Request.URL.Path, "embeddings") {
if modelRequest.Model == "" {
modelRequest.Model = c.Param("model")
if modelRequest.GetModel() == "" {
modelRequest.SetModel(c.Param("model"))
}
}
if strings.HasPrefix(c.Request.URL.Path, "/v1/images/generations") {
if modelRequest.Model == "" {
modelRequest.Model = "dall-e-2"
if modelRequest.GetModel() == "" {
modelRequest.SetModel("dall-e-2")
}
}
if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") || strings.HasPrefix(c.Request.URL.Path, "/v1/audio/translations") {
if modelRequest.Model == "" {
modelRequest.Model = "whisper-1"
if modelRequest.GetModel() == "" {
modelRequest.SetModel("whisper-1")
}
}
channel, err = model.CacheGetRandomSatisfiedChannel(userGroup, modelRequest.Model)
channel, err = model.CacheGetRandomSatisfiedChannel(userGroup, modelRequest.GetModel())
if err != nil {
message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", userGroup, modelRequest.Model)
message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", userGroup, modelRequest.GetModel())
if channel != nil {
common.SysError(fmt.Sprintf("渠道不存在:%d", channel.Id))
message = "数据库一致性已被破坏,请联系管理员"
@@ -94,3 +120,14 @@ func Distribute() func(c *gin.Context) {
c.Next()
}
}
func getModelRequest(c *gin.Context) (modelRequest ModelRequestInterface) {
contentType := c.Request.Header.Get("Content-Type")
if strings.HasPrefix(contentType, "application/json") {
modelRequest = &ModelRequest{}
} else if strings.HasPrefix(contentType, "multipart/form-data") {
modelRequest = &ModelFormRequest{}
}
return
}