diff --git a/api/handler/chatimpl/chat_handler.go b/api/handler/chatimpl/chat_handler.go index 4ad6965a..9c78f30d 100644 --- a/api/handler/chatimpl/chat_handler.go +++ b/api/handler/chatimpl/chat_handler.go @@ -330,7 +330,7 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio Content: prompt, }) req.Input["messages"] = reqMgs - } else if session.Model.Platform == types.OpenAI.Value { // extract image for gpt-vision model + } else if session.Model.Platform == types.OpenAI.Value || session.Model.Platform == types.Azure.Value { // extract image for gpt-vision model imgURLs := utils.ExtractImgURL(prompt) logger.Debugf("detected IMG: %+v", imgURLs) var content interface{} diff --git a/api/service/dalle/service.go b/api/service/dalle/service.go index f3e813b2..fa209f20 100644 --- a/api/service/dalle/service.go +++ b/api/service/dalle/service.go @@ -111,11 +111,11 @@ func (s *Service) Image(task types.DallTask, sync bool) (string, error) { // translate prompt if utils.HasChinese(task.Prompt) { content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.RewritePromptTemplate, task.Prompt)) - if err != nil { - return "", fmt.Errorf("error with translate prompt: %v", err) + if err == nil { + prompt = content + logger.Debugf("重写后提示词:%s", prompt) } - prompt = content - logger.Debugf("重写后提示词:%s", prompt) + } var user model.User @@ -126,8 +126,7 @@ func (s *Service) Image(task types.DallTask, sync bool) (string, error) { // get image generation API KEY var apiKey model.ApiKey - tx := s.db.Where("platform", types.OpenAI.Value). - Where("type", "img"). + tx := s.db.Where("type", "img"). Where("enabled", true). Order("last_used_at ASC").First(&apiKey) if tx.Error != nil { @@ -139,17 +138,23 @@ func (s *Service) Image(task types.DallTask, sync bool) (string, error) { if len(apiKey.ProxyURL) > 5 { s.httpClient.SetProxyURL(apiKey.ProxyURL).R() } - logger.Infof("Sending %s request, ApiURL:%s, API KEY:%s, PROXY: %s", apiKey.Platform, apiKey.ApiURL, apiKey.Value, apiKey.ProxyURL) - r, err := s.httpClient.R().SetHeader("Content-Type", "application/json"). - SetHeader("Authorization", "Bearer "+apiKey.Value). - SetBody(imgReq{ - Model: "dall-e-3", - Prompt: prompt, - N: 1, - Size: task.Size, - Style: task.Style, - Quality: task.Quality, - }). + reqBody := imgReq{ + Model: "dall-e-3", + Prompt: prompt, + N: 1, + Size: task.Size, + Style: task.Style, + Quality: task.Quality, + } + logger.Infof("Sending %s request, ApiURL:%s, API KEY:%s, BODY: %+v", apiKey.Platform, apiKey.ApiURL, apiKey.Value, reqBody) + request := s.httpClient.R().SetHeader("Content-Type", "application/json") + if apiKey.Platform == types.Azure.Value { + request = request.SetHeader("api-key", apiKey.Value) + } else { + request = request.SetHeader("Authorization", "Bearer "+apiKey.Value) + } + r, err := request.SetHeader("Authorization", "Bearer "+apiKey.Value). + SetBody(reqBody). SetErrorResult(&errRes). SetSuccessResult(&res).Post(apiKey.ApiURL) if err != nil { @@ -157,7 +162,7 @@ func (s *Service) Image(task types.DallTask, sync bool) (string, error) { } if r.IsErrorState() { - return "", fmt.Errorf("error with send request: %v", errRes.Error) + return "", fmt.Errorf("error with send request, status: %s, %+v", r.Status, errRes.Error) } // update the api key last use time s.db.Model(&apiKey).UpdateColumn("last_used_at", time.Now().Unix())