mirror of
https://github.com/yangjian102621/geekai.git
synced 2025-11-09 02:33:42 +08:00
feat: auto translate and rewrite prompt for midjourney and stable-diffusion
This commit is contained in:
@@ -3,15 +3,10 @@ package utils
|
||||
import (
|
||||
"chatplus/core/types"
|
||||
logger2 "chatplus/logger"
|
||||
"chatplus/store/model"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"github.com/imroc/req/v3"
|
||||
"gorm.io/gorm"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"time"
|
||||
)
|
||||
|
||||
var logger = logger2.GetLogger()
|
||||
@@ -66,64 +61,3 @@ func DownloadImage(imageURL string, proxy string) ([]byte, error) {
|
||||
|
||||
return imageBytes, nil
|
||||
}
|
||||
|
||||
type apiRes struct {
|
||||
Model string `json:"model"`
|
||||
Choices []struct {
|
||||
Index int `json:"index"`
|
||||
Message struct {
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content"`
|
||||
} `json:"message"`
|
||||
FinishReason string `json:"finish_reason"`
|
||||
} `json:"choices"`
|
||||
}
|
||||
|
||||
type apiErrRes struct {
|
||||
Error struct {
|
||||
Code interface{} `json:"code"`
|
||||
Message string `json:"message"`
|
||||
Param interface{} `json:"param"`
|
||||
Type string `json:"type"`
|
||||
} `json:"error"`
|
||||
}
|
||||
|
||||
func OpenAIRequest(db *gorm.DB, prompt string) (string, error) {
|
||||
var apiKey model.ApiKey
|
||||
res := db.Where("platform = ?", types.OpenAI).Where("type = ?", "chat").Where("enabled = ?", true).First(&apiKey)
|
||||
if res.Error != nil {
|
||||
return "", fmt.Errorf("error with fetch OpenAI API KEY:%v", res.Error)
|
||||
}
|
||||
|
||||
messages := make([]interface{}, 1)
|
||||
messages[0] = types.Message{
|
||||
Role: "user",
|
||||
Content: prompt,
|
||||
}
|
||||
|
||||
var response apiRes
|
||||
var errRes apiErrRes
|
||||
client := req.C()
|
||||
if apiKey.ProxyURL != "" {
|
||||
client.SetProxyURL(apiKey.ApiURL)
|
||||
}
|
||||
r, err := client.R().SetHeader("Content-Type", "application/json").
|
||||
SetHeader("Authorization", "Bearer "+apiKey.Value).
|
||||
SetBody(types.ApiRequest{
|
||||
Model: "gpt-3.5-turbo-0125",
|
||||
Temperature: 0.9,
|
||||
MaxTokens: 1024,
|
||||
Stream: false,
|
||||
Messages: messages,
|
||||
}).
|
||||
SetErrorResult(&errRes).
|
||||
SetSuccessResult(&response).Post(apiKey.ApiURL)
|
||||
if err != nil || r.IsErrorState() {
|
||||
return "", fmt.Errorf("error with http request: %v%v%s", err, r.Err, errRes.Error.Message)
|
||||
}
|
||||
|
||||
// 更新 API KEY 的最后使用时间
|
||||
db.Model(&apiKey).UpdateColumn("last_used_at", time.Now().Unix())
|
||||
|
||||
return response.Choices[0].Message.Content, nil
|
||||
}
|
||||
|
||||
@@ -1,8 +1,13 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"chatplus/core/types"
|
||||
"chatplus/store/model"
|
||||
"fmt"
|
||||
"github.com/imroc/req/v3"
|
||||
"github.com/pkoukk/tiktoken-go"
|
||||
"gorm.io/gorm"
|
||||
"time"
|
||||
)
|
||||
|
||||
func CalcTokens(text string, model string) (int, error) {
|
||||
@@ -18,3 +23,64 @@ func CalcTokens(text string, model string) (int, error) {
|
||||
token := tke.Encode(text, nil, nil)
|
||||
return len(token), nil
|
||||
}
|
||||
|
||||
type apiRes struct {
|
||||
Model string `json:"model"`
|
||||
Choices []struct {
|
||||
Index int `json:"index"`
|
||||
Message struct {
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content"`
|
||||
} `json:"message"`
|
||||
FinishReason string `json:"finish_reason"`
|
||||
} `json:"choices"`
|
||||
}
|
||||
|
||||
type apiErrRes struct {
|
||||
Error struct {
|
||||
Code interface{} `json:"code"`
|
||||
Message string `json:"message"`
|
||||
Param interface{} `json:"param"`
|
||||
Type string `json:"type"`
|
||||
} `json:"error"`
|
||||
}
|
||||
|
||||
func OpenAIRequest(db *gorm.DB, prompt string) (string, error) {
|
||||
var apiKey model.ApiKey
|
||||
res := db.Where("platform = ?", types.OpenAI).Where("type = ?", "chat").Where("enabled = ?", true).First(&apiKey)
|
||||
if res.Error != nil {
|
||||
return "", fmt.Errorf("error with fetch OpenAI API KEY:%v", res.Error)
|
||||
}
|
||||
|
||||
messages := make([]interface{}, 1)
|
||||
messages[0] = types.Message{
|
||||
Role: "user",
|
||||
Content: prompt,
|
||||
}
|
||||
|
||||
var response apiRes
|
||||
var errRes apiErrRes
|
||||
client := req.C()
|
||||
if apiKey.ProxyURL != "" {
|
||||
client.SetProxyURL(apiKey.ApiURL)
|
||||
}
|
||||
r, err := client.R().SetHeader("Content-Type", "application/json").
|
||||
SetHeader("Authorization", "Bearer "+apiKey.Value).
|
||||
SetBody(types.ApiRequest{
|
||||
Model: "gpt-3.5-turbo-0125",
|
||||
Temperature: 0.9,
|
||||
MaxTokens: 1024,
|
||||
Stream: false,
|
||||
Messages: messages,
|
||||
}).
|
||||
SetErrorResult(&errRes).
|
||||
SetSuccessResult(&response).Post(apiKey.ApiURL)
|
||||
if err != nil || r.IsErrorState() {
|
||||
return "", fmt.Errorf("error with http request: %v%v%s", err, r.Err, errRes.Error.Message)
|
||||
}
|
||||
|
||||
// 更新 API KEY 的最后使用时间
|
||||
db.Model(&apiKey).UpdateColumn("last_used_at", time.Now().Unix())
|
||||
|
||||
return response.Choices[0].Message.Content, nil
|
||||
}
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"math/rand"
|
||||
"strings"
|
||||
"time"
|
||||
"unicode"
|
||||
|
||||
"golang.org/x/crypto/sha3"
|
||||
)
|
||||
@@ -94,6 +95,7 @@ func InterfaceToString(value interface{}) string {
|
||||
return JsonEncode(value)
|
||||
}
|
||||
|
||||
// CutWords 截取前 N 个单词
|
||||
func CutWords(str string, num int) string {
|
||||
// 按空格分割字符串为单词切片
|
||||
words := strings.Fields(str)
|
||||
@@ -105,3 +107,13 @@ func CutWords(str string, num int) string {
|
||||
return str
|
||||
}
|
||||
}
|
||||
|
||||
// HasChinese 判断文本是否含有中文
|
||||
func HasChinese(text string) bool {
|
||||
for _, char := range text {
|
||||
if unicode.Is(unicode.Scripts["Han"], char) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user