dalle3 and gptt-4o api compatible with azure

This commit is contained in:
RockYang 2024-06-03 18:34:37 +08:00
parent 3cc2263dc7
commit 66ccb387e8
2 changed files with 24 additions and 19 deletions

View File

@ -330,7 +330,7 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio
Content: prompt, Content: prompt,
}) })
req.Input["messages"] = reqMgs req.Input["messages"] = reqMgs
} else if session.Model.Platform == types.OpenAI.Value { // extract image for gpt-vision model } else if session.Model.Platform == types.OpenAI.Value || session.Model.Platform == types.Azure.Value { // extract image for gpt-vision model
imgURLs := utils.ExtractImgURL(prompt) imgURLs := utils.ExtractImgURL(prompt)
logger.Debugf("detected IMG: %+v", imgURLs) logger.Debugf("detected IMG: %+v", imgURLs)
var content interface{} var content interface{}

View File

@ -111,13 +111,13 @@ func (s *Service) Image(task types.DallTask, sync bool) (string, error) {
// translate prompt // translate prompt
if utils.HasChinese(task.Prompt) { if utils.HasChinese(task.Prompt) {
content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.RewritePromptTemplate, task.Prompt)) content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.RewritePromptTemplate, task.Prompt))
if err != nil { if err == nil {
return "", fmt.Errorf("error with translate prompt: %v", err)
}
prompt = content prompt = content
logger.Debugf("重写后提示词:%s", prompt) logger.Debugf("重写后提示词:%s", prompt)
} }
}
var user model.User var user model.User
s.db.Where("id", task.UserId).First(&user) s.db.Where("id", task.UserId).First(&user)
if user.Power < task.Power { if user.Power < task.Power {
@ -126,8 +126,7 @@ func (s *Service) Image(task types.DallTask, sync bool) (string, error) {
// get image generation API KEY // get image generation API KEY
var apiKey model.ApiKey var apiKey model.ApiKey
tx := s.db.Where("platform", types.OpenAI.Value). tx := s.db.Where("type", "img").
Where("type", "img").
Where("enabled", true). Where("enabled", true).
Order("last_used_at ASC").First(&apiKey) Order("last_used_at ASC").First(&apiKey)
if tx.Error != nil { if tx.Error != nil {
@ -139,17 +138,23 @@ func (s *Service) Image(task types.DallTask, sync bool) (string, error) {
if len(apiKey.ProxyURL) > 5 { if len(apiKey.ProxyURL) > 5 {
s.httpClient.SetProxyURL(apiKey.ProxyURL).R() s.httpClient.SetProxyURL(apiKey.ProxyURL).R()
} }
logger.Infof("Sending %s request, ApiURL:%s, API KEY:%s, PROXY: %s", apiKey.Platform, apiKey.ApiURL, apiKey.Value, apiKey.ProxyURL) reqBody := imgReq{
r, err := s.httpClient.R().SetHeader("Content-Type", "application/json").
SetHeader("Authorization", "Bearer "+apiKey.Value).
SetBody(imgReq{
Model: "dall-e-3", Model: "dall-e-3",
Prompt: prompt, Prompt: prompt,
N: 1, N: 1,
Size: task.Size, Size: task.Size,
Style: task.Style, Style: task.Style,
Quality: task.Quality, Quality: task.Quality,
}). }
logger.Infof("Sending %s request, ApiURL:%s, API KEY:%s, BODY: %+v", apiKey.Platform, apiKey.ApiURL, apiKey.Value, reqBody)
request := s.httpClient.R().SetHeader("Content-Type", "application/json")
if apiKey.Platform == types.Azure.Value {
request = request.SetHeader("api-key", apiKey.Value)
} else {
request = request.SetHeader("Authorization", "Bearer "+apiKey.Value)
}
r, err := request.SetHeader("Authorization", "Bearer "+apiKey.Value).
SetBody(reqBody).
SetErrorResult(&errRes). SetErrorResult(&errRes).
SetSuccessResult(&res).Post(apiKey.ApiURL) SetSuccessResult(&res).Post(apiKey.ApiURL)
if err != nil { if err != nil {
@ -157,7 +162,7 @@ func (s *Service) Image(task types.DallTask, sync bool) (string, error) {
} }
if r.IsErrorState() { if r.IsErrorState() {
return "", fmt.Errorf("error with send request: %v", errRes.Error) return "", fmt.Errorf("error with send request, status: %s, %+v", r.Status, errRes.Error)
} }
// update the api key last use time // update the api key last use time
s.db.Model(&apiKey).UpdateColumn("last_used_at", time.Now().Unix()) s.db.Model(&apiKey).UpdateColumn("last_used_at", time.Now().Unix())