dalle3 and gptt-4o api compatible with azure

This commit is contained in:
RockYang 2024-06-03 18:34:37 +08:00
parent 3cc2263dc7
commit 66ccb387e8
2 changed files with 24 additions and 19 deletions

View File

@ -330,7 +330,7 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio
Content: prompt,
})
req.Input["messages"] = reqMgs
} else if session.Model.Platform == types.OpenAI.Value { // extract image for gpt-vision model
} else if session.Model.Platform == types.OpenAI.Value || session.Model.Platform == types.Azure.Value { // extract image for gpt-vision model
imgURLs := utils.ExtractImgURL(prompt)
logger.Debugf("detected IMG: %+v", imgURLs)
var content interface{}

View File

@ -111,11 +111,11 @@ func (s *Service) Image(task types.DallTask, sync bool) (string, error) {
// translate prompt
if utils.HasChinese(task.Prompt) {
content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.RewritePromptTemplate, task.Prompt))
if err != nil {
return "", fmt.Errorf("error with translate prompt: %v", err)
if err == nil {
prompt = content
logger.Debugf("重写后提示词:%s", prompt)
}
prompt = content
logger.Debugf("重写后提示词:%s", prompt)
}
var user model.User
@ -126,8 +126,7 @@ func (s *Service) Image(task types.DallTask, sync bool) (string, error) {
// get image generation API KEY
var apiKey model.ApiKey
tx := s.db.Where("platform", types.OpenAI.Value).
Where("type", "img").
tx := s.db.Where("type", "img").
Where("enabled", true).
Order("last_used_at ASC").First(&apiKey)
if tx.Error != nil {
@ -139,17 +138,23 @@ func (s *Service) Image(task types.DallTask, sync bool) (string, error) {
if len(apiKey.ProxyURL) > 5 {
s.httpClient.SetProxyURL(apiKey.ProxyURL).R()
}
logger.Infof("Sending %s request, ApiURL:%s, API KEY:%s, PROXY: %s", apiKey.Platform, apiKey.ApiURL, apiKey.Value, apiKey.ProxyURL)
r, err := s.httpClient.R().SetHeader("Content-Type", "application/json").
SetHeader("Authorization", "Bearer "+apiKey.Value).
SetBody(imgReq{
Model: "dall-e-3",
Prompt: prompt,
N: 1,
Size: task.Size,
Style: task.Style,
Quality: task.Quality,
}).
reqBody := imgReq{
Model: "dall-e-3",
Prompt: prompt,
N: 1,
Size: task.Size,
Style: task.Style,
Quality: task.Quality,
}
logger.Infof("Sending %s request, ApiURL:%s, API KEY:%s, BODY: %+v", apiKey.Platform, apiKey.ApiURL, apiKey.Value, reqBody)
request := s.httpClient.R().SetHeader("Content-Type", "application/json")
if apiKey.Platform == types.Azure.Value {
request = request.SetHeader("api-key", apiKey.Value)
} else {
request = request.SetHeader("Authorization", "Bearer "+apiKey.Value)
}
r, err := request.SetHeader("Authorization", "Bearer "+apiKey.Value).
SetBody(reqBody).
SetErrorResult(&errRes).
SetSuccessResult(&res).Post(apiKey.ApiURL)
if err != nil {
@ -157,7 +162,7 @@ func (s *Service) Image(task types.DallTask, sync bool) (string, error) {
}
if r.IsErrorState() {
return "", fmt.Errorf("error with send request: %v", errRes.Error)
return "", fmt.Errorf("error with send request, status: %s, %+v", r.Status, errRes.Error)
}
// update the api key last use time
s.db.Model(&apiKey).UpdateColumn("last_used_at", time.Now().Unix())