From 36e99cf6ec35a34fc59a8df1189d7fb04f10d976 Mon Sep 17 00:00:00 2001 From: "Laisky.Cai" Date: Thu, 28 Nov 2024 09:15:40 +0000 Subject: [PATCH] fix: update image request handling to always return one image and improve error logging --- relay/adaptor/replicate/adaptor.go | 4 +++- relay/adaptor/replicate/image.go | 9 +++++++-- relay/controller/image.go | 9 ++++++++- 3 files changed, 18 insertions(+), 4 deletions(-) diff --git a/relay/adaptor/replicate/adaptor.go b/relay/adaptor/replicate/adaptor.go index 42f0ae02..7ab0c59d 100644 --- a/relay/adaptor/replicate/adaptor.go +++ b/relay/adaptor/replicate/adaptor.go @@ -9,6 +9,7 @@ import ( "github.com/gin-gonic/gin" "github.com/pkg/errors" + "github.com/songquanpeng/one-api/common/logger" "github.com/songquanpeng/one-api/relay/adaptor" "github.com/songquanpeng/one-api/relay/adaptor/openai" "github.com/songquanpeng/one-api/relay/meta" @@ -29,7 +30,7 @@ func (*Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) { Guidance: 3, Seed: int(time.Now().UnixNano()), SafetyTolerance: 5, - NImages: request.N, + NImages: 1, // replicate will always return 1 image Width: 1440, Height: 1440, AspectRatio: "1:1", @@ -60,6 +61,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *me } func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) { + logger.Info(c, "send image request to replicate") return adaptor.DoRequestHelper(a, c, meta, requestBody) } diff --git a/relay/adaptor/replicate/image.go b/relay/adaptor/replicate/image.go index cd62936b..3687249a 100644 --- a/relay/adaptor/replicate/image.go +++ b/relay/adaptor/replicate/image.go @@ -39,6 +39,8 @@ import ( // return nil, nil // } +var errNextLoop = errors.New("next_loop") + func ImageHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) { if resp.StatusCode != http.StatusCreated { payload, _ := io.ReadAll(resp.Body) @@ -67,7 +69,6 @@ func ImageHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCo return errors.Wrap(err, "new request") } - logger.Debug(c, "send image request to replicate") taskReq.Header.Set("Authorization", "Bearer "+meta.GetByContext(c).APIKey) taskResp, err := http.DefaultClient.Do(taskReq) if err != nil { @@ -97,7 +98,7 @@ func ImageHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCo return errors.Errorf("task failed: %s", taskData.Status) default: time.Sleep(time.Second * 3) - return nil + return errNextLoop } output, err := taskData.GetOutput() @@ -170,6 +171,10 @@ func ImageHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCo return nil }() if err != nil { + if errors.Is(err, errNextLoop) { + continue + } + return openai.ErrorWrapper(err, "image_task_failed", http.StatusInternalServerError), nil } diff --git a/relay/controller/image.go b/relay/controller/image.go index 18c87ce1..d02c9552 100644 --- a/relay/controller/image.go +++ b/relay/controller/image.go @@ -175,7 +175,14 @@ func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus ratio := modelRatio * groupRatio userQuota, err := model.CacheGetUserQuota(ctx, meta.UserId) - quota := int64(ratio*imageCostRatio*1000) * int64(imageRequest.N) + var quota int64 + switch meta.ChannelType { + case channeltype.Replicate: + // replicate always return 1 image + quota = int64(ratio * imageCostRatio * 1000) + default: + quota = int64(ratio*imageCostRatio*1000) * int64(imageRequest.N) + } if userQuota-quota < 0 { return openai.ErrorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden)