mirror of
https://github.com/yangjian102621/geekai.git
synced 2025-11-09 02:33:42 +08:00
add 'type' field for ChatModel, support Chat and Image model
This commit is contained in:
@@ -8,7 +8,6 @@ package dalle
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"geekai/core/types"
|
||||
logger2 "geekai/logger"
|
||||
@@ -17,9 +16,11 @@ import (
|
||||
"geekai/store"
|
||||
"geekai/store/model"
|
||||
"geekai/utils"
|
||||
"github.com/go-redis/redis/v8"
|
||||
"io"
|
||||
"time"
|
||||
|
||||
"github.com/go-redis/redis/v8"
|
||||
|
||||
"github.com/imroc/req/v3"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
@@ -100,17 +101,18 @@ func (s *Service) Run() {
|
||||
type imgReq struct {
|
||||
Model string `json:"model"`
|
||||
Prompt string `json:"prompt"`
|
||||
N int `json:"n"`
|
||||
Size string `json:"size"`
|
||||
Quality string `json:"quality"`
|
||||
Style string `json:"style"`
|
||||
N int `json:"n,omitempty"`
|
||||
Size string `json:"size,omitempty"`
|
||||
Quality string `json:"quality,omitempty"`
|
||||
Style string `json:"style,omitempty"`
|
||||
}
|
||||
|
||||
type imgRes struct {
|
||||
Created int64 `json:"created"`
|
||||
Data []struct {
|
||||
RevisedPrompt string `json:"revised_prompt"`
|
||||
Url string `json:"url"`
|
||||
RevisedPrompt string `json:"revised_prompt,omitempty"`
|
||||
Url string `json:"url,omitempty"`
|
||||
B64Json string `json:"b64_json,omitempty"`
|
||||
} `json:"data"`
|
||||
}
|
||||
|
||||
@@ -135,29 +137,20 @@ func (s *Service) Image(task types.DallTask, sync bool) (string, error) {
|
||||
}
|
||||
}
|
||||
|
||||
var user model.User
|
||||
s.db.Where("id", task.UserId).First(&user)
|
||||
if user.Power < task.Power {
|
||||
return "", errors.New("insufficient of power")
|
||||
}
|
||||
|
||||
// 扣减算力
|
||||
err := s.userService.DecreasePower(int(user.Id), task.Power, model.PowerLog{
|
||||
Type: types.PowerConsume,
|
||||
Model: "dall-e-3",
|
||||
Remark: fmt.Sprintf("绘画提示词:%s", utils.CutWords(task.Prompt, 10)),
|
||||
})
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("error with decrease power: %v", err)
|
||||
}
|
||||
var chatModel model.ChatModel
|
||||
s.db.Where("id = ?", task.ModelId).First(&chatModel)
|
||||
|
||||
// get image generation API KEY
|
||||
var apiKey model.ApiKey
|
||||
err = s.db.Where("type", "dalle").
|
||||
Where("enabled", true).
|
||||
Order("last_used_at ASC").First(&apiKey).Error
|
||||
session := s.db.Where("enabled", true)
|
||||
if chatModel.KeyId > 0 {
|
||||
session = session.Where("id = ?", chatModel.KeyId)
|
||||
} else {
|
||||
session = session.Where("type = ?", "dalle")
|
||||
}
|
||||
err := session.Order("last_used_at ASC").First(&apiKey).Error
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("no available DALL-E api key: %v", err)
|
||||
return "", fmt.Errorf("no available Image Generation api key: %v", err)
|
||||
}
|
||||
|
||||
var res imgRes
|
||||
@@ -167,7 +160,7 @@ func (s *Service) Image(task types.DallTask, sync bool) (string, error) {
|
||||
}
|
||||
apiURL := fmt.Sprintf("%s/v1/images/generations", apiKey.ApiURL)
|
||||
reqBody := imgReq{
|
||||
Model: "dall-e-3",
|
||||
Model: chatModel.Value,
|
||||
Prompt: prompt,
|
||||
N: 1,
|
||||
Size: task.Size,
|
||||
@@ -182,20 +175,39 @@ func (s *Service) Image(task types.DallTask, sync bool) (string, error) {
|
||||
SetSuccessResult(&res).
|
||||
Post(apiURL)
|
||||
if err != nil {
|
||||
logger.Errorf("error with send request: %v", err)
|
||||
return "", fmt.Errorf("error with send request: %v", err)
|
||||
}
|
||||
|
||||
if r.IsErrorState() {
|
||||
logger.Errorf("error with send request, status: %s, %+v", r.Status, errRes.Error)
|
||||
return "", fmt.Errorf("error with send request, status: %s, %+v", r.Status, errRes.Error)
|
||||
}
|
||||
|
||||
all, _ := io.ReadAll(r.Body)
|
||||
logger.Debugf("response: %+v", string(all))
|
||||
|
||||
// update the api key last use time
|
||||
s.db.Model(&apiKey).UpdateColumn("last_used_at", time.Now().Unix())
|
||||
// update task progress
|
||||
err = s.db.Model(&model.DallJob{Id: task.Id}).UpdateColumns(map[string]interface{}{
|
||||
var imgURL string
|
||||
var data = map[string]interface{}{
|
||||
"progress": 100,
|
||||
"org_url": res.Data[0].Url,
|
||||
"prompt": prompt,
|
||||
}).Error
|
||||
}
|
||||
// 如果返回的是base64,则需要上传到oss
|
||||
if res.Data[0].B64Json != "" {
|
||||
imgURL, err = s.uploadManager.GetUploadHandler().PutBase64(res.Data[0].B64Json)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("error with upload image: %v", err)
|
||||
}
|
||||
logger.Infof("upload image to oss: %s", imgURL)
|
||||
data["img_url"] = imgURL
|
||||
} else {
|
||||
imgURL = res.Data[0].Url
|
||||
}
|
||||
data["org_url"] = imgURL
|
||||
// update task progress
|
||||
err = s.db.Model(&model.DallJob{Id: task.Id}).UpdateColumns(data).Error
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("err with update database: %v", err)
|
||||
}
|
||||
@@ -252,9 +264,14 @@ func (s *Service) CheckTaskStatus() {
|
||||
// 找出失败的任务,并恢复其扣减算力
|
||||
s.db.Where("progress", service.FailTaskProgress).Where("power > ?", 0).Find(&jobs)
|
||||
for _, job := range jobs {
|
||||
err := s.userService.IncreasePower(int(job.UserId), job.Power, model.PowerLog{
|
||||
var task types.DallTask
|
||||
err := utils.JsonDecode(job.TaskInfo, &task)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
err = s.userService.IncreasePower(int(job.UserId), job.Power, model.PowerLog{
|
||||
Type: types.PowerRefund,
|
||||
Model: "dall-e-3",
|
||||
Model: task.ModelName,
|
||||
Remark: fmt.Sprintf("任务失败,退回算力。任务ID:%d,Err: %s", job.Id, job.ErrMsg),
|
||||
})
|
||||
if err != nil {
|
||||
|
||||
Reference in New Issue
Block a user