mirror of
https://github.com/linux-do/new-api.git
synced 2025-11-09 15:43:41 +08:00
feat: support InsightFace (close #60)
This commit is contained in:
@@ -138,6 +138,111 @@ func coverMidjourneyTaskDto(c *gin.Context, originTask *model.Midjourney) (midjo
|
||||
return
|
||||
}
|
||||
|
||||
func RelaySwapFace(c *gin.Context) *dto.MidjourneyResponse {
|
||||
startTime := time.Now().UnixNano() / int64(time.Millisecond)
|
||||
tokenId := c.GetInt("token_id")
|
||||
userId := c.GetInt("id")
|
||||
group := c.GetString("group")
|
||||
channelId := c.GetInt("channel_id")
|
||||
var swapFaceRequest dto.SwapFaceRequest
|
||||
err := common.UnmarshalBodyReusable(c, &swapFaceRequest)
|
||||
if err != nil {
|
||||
return service.MidjourneyErrorWrapper(constant.MjRequestError, "bind_request_body_failed")
|
||||
}
|
||||
if swapFaceRequest.SourceBase64 == "" || swapFaceRequest.TargetBase64 == "" {
|
||||
return service.MidjourneyErrorWrapper(constant.MjRequestError, "sour_base64_and_target_base64_is_required")
|
||||
}
|
||||
modelName := service.CoverActionToModelName(constant.MjActionSwapFace)
|
||||
modelPrice := common.GetModelPrice(modelName, true)
|
||||
// 如果没有配置价格,则使用默认价格
|
||||
if modelPrice == -1 {
|
||||
defaultPrice, ok := common.DefaultModelPrice[modelName]
|
||||
if !ok {
|
||||
modelPrice = 0.1
|
||||
} else {
|
||||
modelPrice = defaultPrice
|
||||
}
|
||||
}
|
||||
groupRatio := common.GetGroupRatio(group)
|
||||
ratio := modelPrice * groupRatio
|
||||
userQuota, err := model.CacheGetUserQuota(userId)
|
||||
if err != nil {
|
||||
return &dto.MidjourneyResponse{
|
||||
Code: 4,
|
||||
Description: err.Error(),
|
||||
}
|
||||
}
|
||||
quota := int(ratio * common.QuotaPerUnit)
|
||||
|
||||
if userQuota-quota < 0 {
|
||||
return &dto.MidjourneyResponse{
|
||||
Code: 4,
|
||||
Description: "quota_not_enough",
|
||||
}
|
||||
}
|
||||
requestURL := c.Request.URL.String()
|
||||
baseURL := c.GetString("base_url")
|
||||
fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL)
|
||||
mjResp, _, err := service.DoMidjourneyHttpRequest(c, time.Second*120, fullRequestURL)
|
||||
if err != nil {
|
||||
return &mjResp.Response
|
||||
}
|
||||
defer func(ctx context.Context) {
|
||||
if mjResp.StatusCode == 200 && mjResp.Response.Code == 1 {
|
||||
err := model.PostConsumeTokenQuota(tokenId, userQuota, quota, 0, true)
|
||||
if err != nil {
|
||||
common.SysError("error consuming token remain quota: " + err.Error())
|
||||
}
|
||||
err = model.CacheUpdateUserQuota(userId)
|
||||
if err != nil {
|
||||
common.SysError("error update user quota cache: " + err.Error())
|
||||
}
|
||||
if quota != 0 {
|
||||
tokenName := c.GetString("token_name")
|
||||
logContent := fmt.Sprintf("模型固定价格 %.2f,分组倍率 %.2f,操作 %s", modelPrice, groupRatio, constant.MjActionSwapFace)
|
||||
model.RecordConsumeLog(ctx, userId, channelId, 0, 0, modelName, tokenName, quota, logContent, tokenId, userQuota, 0, false)
|
||||
model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
|
||||
channelId := c.GetInt("channel_id")
|
||||
model.UpdateChannelUsedQuota(channelId, quota)
|
||||
}
|
||||
}
|
||||
}(c.Request.Context())
|
||||
midjResponse := &mjResp.Response
|
||||
midjourneyTask := &model.Midjourney{
|
||||
UserId: userId,
|
||||
Code: midjResponse.Code,
|
||||
Action: constant.MjActionSwapFace,
|
||||
MjId: midjResponse.Result,
|
||||
Prompt: "swap_face",
|
||||
PromptEn: "",
|
||||
Description: midjResponse.Description,
|
||||
State: "",
|
||||
SubmitTime: startTime,
|
||||
StartTime: time.Now().UnixNano() / int64(time.Millisecond),
|
||||
FinishTime: 0,
|
||||
ImageUrl: "",
|
||||
Status: "",
|
||||
Progress: "0%",
|
||||
FailReason: "",
|
||||
ChannelId: c.GetInt("channel_id"),
|
||||
Quota: quota,
|
||||
}
|
||||
err = midjourneyTask.Insert()
|
||||
if err != nil {
|
||||
return service.MidjourneyErrorWrapper(constant.MjRequestError, "insert_midjourney_task_failed")
|
||||
}
|
||||
c.Writer.WriteHeader(mjResp.StatusCode)
|
||||
respBody, err := json.Marshal(midjResponse)
|
||||
if err != nil {
|
||||
return service.MidjourneyErrorWrapper(constant.MjRequestError, "unmarshal_response_body_failed")
|
||||
}
|
||||
_, err = io.Copy(c.Writer, bytes.NewBuffer(respBody))
|
||||
if err != nil {
|
||||
return service.MidjourneyErrorWrapper(constant.MjRequestError, "copy_response_body_failed")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func RelayMidjourneyTaskImageSeed(c *gin.Context) *dto.MidjourneyResponse {
|
||||
taskId := c.Param("id")
|
||||
userId := c.GetInt("id")
|
||||
@@ -157,10 +262,28 @@ func RelayMidjourneyTaskImageSeed(c *gin.Context) *dto.MidjourneyResponse {
|
||||
|
||||
requestURL := c.Request.URL.String()
|
||||
fullRequestURL := fmt.Sprintf("%s%s", channel.GetBaseURL(), requestURL)
|
||||
midjResponseWithStatus, _, err := service.DoMidjourneyHttpRequest(c, time.Second*30, fullRequestURL, nil)
|
||||
midjResponseWithStatus, _, err := service.DoMidjourneyHttpRequest(c, time.Second*30, fullRequestURL)
|
||||
if err != nil {
|
||||
return &midjResponseWithStatus.Response
|
||||
}
|
||||
//defer func(ctx context.Context) {
|
||||
// err := model.PostConsumeTokenQuota(tokenId, userQuota, quota, 0, true)
|
||||
// if err != nil {
|
||||
// common.SysError("error consuming token remain quota: " + err.Error())
|
||||
// }
|
||||
// err = model.CacheUpdateUserQuota(userId)
|
||||
// if err != nil {
|
||||
// common.SysError("error update user quota cache: " + err.Error())
|
||||
// }
|
||||
// if quota != 0 {
|
||||
// tokenName := c.GetString("token_name")
|
||||
// logContent := fmt.Sprintf("模型固定价格 %.2f,分组倍率 %.2f,操作 %s", modelPrice, groupRatio, midjRequest.Action)
|
||||
// model.RecordConsumeLog(ctx, userId, channelId, 0, 0, modelName, tokenName, quota, logContent, tokenId, userQuota, 0, false)
|
||||
// model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
|
||||
// channelId := c.GetInt("channel_id")
|
||||
// model.UpdateChannelUsedQuota(channelId, quota)
|
||||
// }
|
||||
//}(c.Request.Context())
|
||||
midjResponse := &midjResponseWithStatus.Response
|
||||
c.Writer.WriteHeader(midjResponseWithStatus.StatusCode)
|
||||
respBody, err := json.Marshal(midjResponse)
|
||||
@@ -372,14 +495,14 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
|
||||
}
|
||||
}
|
||||
|
||||
midjResponseWithStatus, responseBody, err := service.DoMidjourneyHttpRequest(c, time.Second*30, fullRequestURL, &midjRequest)
|
||||
midjResponseWithStatus, responseBody, err := service.DoMidjourneyHttpRequest(c, time.Second*30, fullRequestURL)
|
||||
if err != nil {
|
||||
return &midjResponseWithStatus.Response
|
||||
}
|
||||
midjResponse := &midjResponseWithStatus.Response
|
||||
|
||||
defer func(ctx context.Context) {
|
||||
if consumeQuota {
|
||||
if consumeQuota && midjResponseWithStatus.StatusCode == 200 {
|
||||
err := model.PostConsumeTokenQuota(tokenId, userQuota, quota, 0, true)
|
||||
if err != nil {
|
||||
common.SysError("error consuming token remain quota: " + err.Error())
|
||||
|
||||
Reference in New Issue
Block a user