Mirror of https://github.com/yangjian102621/geekai.git (synced 2025-09-17 16:56:38 +08:00)
feat: refactor LLM api request code, get API URL from ApiKey object
parent 21f2622a4b
commit 4b1c4f7ccc
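The refactor moves the per-platform request URL (and proxy preference) onto the `ApiKey` record and lets the request layer pick a key from the database instead of reading it out of the user's chat config. The struct below is only a sketch of what `model.ApiKey` is assumed to look like, inferred from the fields and columns used in this diff (`Value`, `ApiURL`, `UseProxy`, and the `platform`/`type`/`enabled`/`last_used_at` columns); the real definition is not part of this commit and may differ.

```go
package model

// Sketch only: a plausible shape for model.ApiKey, inferred from how it is
// used in this commit. Field names, GORM/JSON tags and any extra fields in
// the real store/model package are assumptions.
type ApiKey struct {
	Id         uint   `json:"id"`
	Platform   string `json:"platform"`     // e.g. OpenAI / Azure / ChatGLM / Baidu / XunFei
	Type       string `json:"type"`         // "chat" or "img"
	Value      string `json:"value"`        // the key itself ("appid|key|secret" for XunFei)
	ApiURL     string `json:"api_url"`      // request URL, may contain {model} / {version} placeholders
	Enabled    bool   `json:"enabled"`      // only enabled keys are selected
	UseProxy   bool   `json:"use_proxy"`    // whether requests with this key go through the configured proxy
	LastUsedAt int64  `json:"last_used_at"` // unix timestamp, used for least-recently-used rotation
}
```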
@@ -14,7 +14,7 @@ ChatGLM,讯飞星火,文心一言等多个平台的大语言模型。集成了
 绘画函数插件。

 ## 最新版本一键部署脚本
+目前仅支持 Ubuntu 和 Centos 系统。
 ```shell
 bash -c "$(curl -fsSL https://img.r9it.com/tmp/install-v3.2.3-8b588904ef.sh)"
 ```
@@ -62,7 +62,7 @@ type ApiError struct {

 const PromptMsg = "prompt" // prompt message
 const ReplyMsg = "reply" // reply message
-const MjMsg = "mj"

 var ModelToTokens = map[string]int{
     "gpt-3.5-turbo": 4096,
@@ -75,4 +74,12 @@ var ModelToTokens = map[string]int{
     "ernie_bot_turbo": 8192, // 文心一言
     "general": 8192, // 科大讯飞
     "general2": 8192,
+    "general3": 8192,
+}
+
+func GetModelMaxToken(model string) int {
+    if token, ok := ModelToTokens[model]; ok {
+        return token
+    }
+    return 4096
 }
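Later hunks in this diff switch callers from indexing `types.ModelToTokens[...]` directly to `types.GetModelMaxToken(...)`. The reason is that a missing key in a Go map yields the zero value `0`, which would make every `tokens+tks >= limit` check fire immediately and truncate the whole context; the helper falls back to 4096 instead. A minimal, self-contained sketch of that behaviour (the `qwen-turbo` name is just an arbitrary model not present in the table):

```go
package main

import "fmt"

// Copied from the hunk above, trimmed to a single entry for brevity.
var ModelToTokens = map[string]int{"gpt-3.5-turbo": 4096}

func GetModelMaxToken(model string) int {
	if token, ok := ModelToTokens[model]; ok {
		return token
	}
	return 4096
}

func main() {
	fmt.Println(ModelToTokens["qwen-turbo"])       // 0: a missing key breaks the token-budget check
	fmt.Println(GetModelMaxToken("qwen-turbo"))    // 4096: safe fallback
	fmt.Println(GetModelMaxToken("gpt-3.5-turbo")) // 4096: value from the table
}
```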
@@ -141,7 +141,6 @@ type InviteReward struct {
 }

 type ModelAPIConfig struct {
-    ApiURL string `json:"api_url,omitempty"`
     Temperature float32 `json:"temperature"`
     MaxTokens int `json:"max_tokens"`
     ApiKey string `json:"api_key"`
@@ -29,7 +29,7 @@ func (h *ChatHandler) sendAzureMessage(
     ws *types.WsClient) error {
     promptCreatedAt := time.Now() // 记录提问时间
     start := time.Now()
-    var apiKey = userVo.ChatConfig.ApiKeys[session.Model.Platform]
+    var apiKey = model.ApiKey{}
     response, err := h.doRequest(ctx, req, session.Model.Platform, &apiKey)
     logger.Info("HTTP请求完成,耗时:", time.Now().Sub(start))
     if err != nil {
@@ -46,7 +46,7 @@ func (h *ChatHandler) sendBaiduMessage(
     ws *types.WsClient) error {
     promptCreatedAt := time.Now() // 记录提问时间
     start := time.Now()
-    var apiKey = userVo.ChatConfig.ApiKeys[session.Model.Platform]
+    var apiKey = model.ApiKey{}
     response, err := h.doRequest(ctx, req, session.Model.Platform, &apiKey)
     logger.Info("HTTP请求完成,耗时:", time.Now().Sub(start))
     if err != nil {
@@ -275,7 +275,7 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio
     if err == nil {
         for _, v := range messages {
             tks, _ := utils.CalcTokens(v.Content, req.Model)
-            if tokens+tks >= types.ModelToTokens[req.Model] {
+            if tokens+tks >= types.GetModelMaxToken(req.Model) {
                 break
             }
             tokens += tks
@@ -290,7 +290,7 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio
     if res.Error == nil {
         for i := len(historyMessages) - 1; i >= 0; i-- {
             msg := historyMessages[i]
-            if tokens+msg.Tokens >= types.ModelToTokens[session.Model.Value] {
+            if tokens+msg.Tokens >= types.GetModelMaxToken(session.Model.Value) {
                 break
             }
             tokens += msg.Tokens
@@ -401,39 +401,33 @@ func (h *ChatHandler) StopGenerate(c *gin.Context) {

 // 发送请求到 OpenAI 服务器
 // useOwnApiKey: 是否使用了用户自己的 API KEY
-func (h *ChatHandler) doRequest(ctx context.Context, req types.ApiRequest, platform types.Platform, apiKey *string) (*http.Response, error) {
+func (h *ChatHandler) doRequest(ctx context.Context, req types.ApiRequest, platform types.Platform, apiKey *model.ApiKey) (*http.Response, error) {
+    res := h.db.Where("platform = ?", platform).Where("type = ?", "chat").Where("enabled = ?", true).Order("last_used_at ASC").First(apiKey)
+    if res.Error != nil {
+        return nil, errors.New("no available key, please import key")
+    }
     var apiURL string
     switch platform {
     case types.Azure:
         md := strings.Replace(req.Model, ".", "", 1)
-        apiURL = strings.Replace(h.App.ChatConfig.Azure.ApiURL, "{model}", md, 1)
+        apiURL = strings.Replace(apiKey.ApiURL, "{model}", md, 1)
         break
     case types.ChatGLM:
-        apiURL = strings.Replace(h.App.ChatConfig.ChatGML.ApiURL, "{model}", req.Model, 1)
+        apiURL = strings.Replace(apiKey.ApiURL, "{model}", req.Model, 1)
         req.Prompt = req.Messages // 使用 prompt 字段替代 message 字段
         req.Messages = nil
         break
     case types.Baidu:
-        apiURL = strings.Replace(h.App.ChatConfig.Baidu.ApiURL, "{model}", req.Model, 1)
+        apiURL = strings.Replace(apiKey.ApiURL, "{model}", req.Model, 1)
         break
     default:
-        apiURL = h.App.ChatConfig.OpenAI.ApiURL
+        apiURL = apiKey.ApiURL
-    }
-    if *apiKey == "" {
-        var key model.ApiKey
-        res := h.db.Where("platform = ? AND type = ?", platform, "chat").Order("last_used_at ASC").First(&key)
-        if res.Error != nil {
-            return nil, errors.New("no available key, please import key")
     }
     // 更新 API KEY 的最后使用时间
-    h.db.Model(&key).UpdateColumn("last_used_at", time.Now().Unix())
+    h.db.Model(apiKey).UpdateColumn("last_used_at", time.Now().Unix())
-    *apiKey = key.Value
-    }

     // 百度文心,需要串接 access_token
     if platform == types.Baidu {
-        token, err := h.getBaiduToken(*apiKey)
+        token, err := h.getBaiduToken(apiKey.Value)
         if err != nil {
             return nil, err
         }
@@ -465,13 +459,13 @@ func (h *ChatHandler) doRequest(ctx context.Context, req types.ApiRequest, platf
     } else {
         client = http.DefaultClient
     }
-    logger.Infof("Sending %s request, ApiURL:%s, PROXY: %s, Model: %s", platform, apiURL, proxyURL, req.Model)
+    logger.Debugf("Sending %s request, ApiURL:%s, ApiKey:%s, PROXY: %s, Model: %s", platform, apiURL, apiKey.Value, proxyURL, req.Model)
     switch platform {
     case types.Azure:
-        request.Header.Set("api-key", *apiKey)
+        request.Header.Set("api-key", apiKey.Value)
         break
     case types.ChatGLM:
-        token, err := h.getChatGLMToken(*apiKey)
+        token, err := h.getChatGLMToken(apiKey.Value)
         if err != nil {
             return nil, err
         }
@@ -480,7 +474,7 @@ func (h *ChatHandler) doRequest(ctx context.Context, req types.ApiRequest, platf
     case types.Baidu:
         request.RequestURI = ""
     case types.OpenAI:
-        request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiKey))
+        request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", apiKey.Value))
     }
     return client.Do(request)
 }
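Taken together, the `doRequest` hunks above change the flow to: load the least recently used, enabled chat key for the platform, build the request URL from that key's `ApiURL`, bump its `last_used_at`, and sign the request with its `Value`. The sketch below condenses that selection policy; it assumes `gorm.io/gorm` and the `model.ApiKey` fields used in this diff, and the package and helper name (`pickChatKey`) are illustrative only, not part of the codebase.

```go
package handler

import (
	"errors"
	"time"

	"chatplus/core/types"
	"chatplus/store/model"

	"gorm.io/gorm"
)

// pickChatKey mirrors the key-selection policy that doRequest now applies
// before building the outgoing request.
func pickChatKey(db *gorm.DB, platform types.Platform) (model.ApiKey, error) {
	var key model.ApiKey
	res := db.Where("platform = ?", platform).
		Where("type = ?", "chat").
		Where("enabled = ?", true).
		Order("last_used_at ASC"). // least recently used key first, so keys rotate
		First(&key)
	if res.Error != nil {
		return model.ApiKey{}, errors.New("no available key, please import key")
	}
	// Touch last_used_at so the next request prefers a different key.
	db.Model(&key).UpdateColumn("last_used_at", time.Now().Unix())
	return key, nil
}
```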
@@ -30,7 +30,7 @@ func (h *ChatHandler) sendChatGLMMessage(
     ws *types.WsClient) error {
     promptCreatedAt := time.Now() // 记录提问时间
     start := time.Now()
-    var apiKey = userVo.ChatConfig.ApiKeys[session.Model.Platform]
+    var apiKey = model.ApiKey{}
     response, err := h.doRequest(ctx, req, session.Model.Platform, &apiKey)
     logger.Info("HTTP请求完成,耗时:", time.Now().Sub(start))
     if err != nil {
@@ -29,7 +29,7 @@ func (h *ChatHandler) sendOpenAiMessage(
     ws *types.WsClient) error {
     promptCreatedAt := time.Now() // 记录提问时间
     start := time.Now()
-    var apiKey = userVo.ChatConfig.ApiKeys[session.Model.Platform]
+    var apiKey = model.ApiKey{}
     response, err := h.doRequest(ctx, req, session.Model.Platform, &apiKey)
     logger.Info("HTTP请求完成,耗时:", time.Now().Sub(start))
     if err != nil {
@@ -67,29 +67,25 @@ func (h *ChatHandler) sendXunFeiMessage(
     prompt string,
     ws *types.WsClient) error {
     promptCreatedAt := time.Now() // 记录提问时间
-    var apiKey = userVo.ChatConfig.ApiKeys[session.Model.Platform]
+    var apiKey model.ApiKey
-    if apiKey == "" {
+    res := h.db.Where("platform = ?", session.Model.Platform).Where("type = ?", "chat").Where("enabled = ?", true).Order("last_used_at ASC").First(&apiKey)
-        var key model.ApiKey
-        res := h.db.Where("platform = ? AND type = ?", session.Model.Platform, "chat").Order("last_used_at ASC").First(&key)
     if res.Error != nil {
         utils.ReplyMessage(ws, "抱歉😔😔😔,系统已经没有可用的 API KEY,请联系管理员!")
         return nil
     }
     // 更新 API KEY 的最后使用时间
-    h.db.Model(&key).UpdateColumn("last_used_at", time.Now().Unix())
+    h.db.Model(&apiKey).UpdateColumn("last_used_at", time.Now().Unix())
-    apiKey = key.Value
-    }

     d := websocket.Dialer{
         HandshakeTimeout: 5 * time.Second,
     }
-    key := strings.Split(apiKey, "|")
+    key := strings.Split(apiKey.Value, "|")
     if len(key) != 3 {
         utils.ReplyMessage(ws, "非法的 API KEY!")
         return nil
     }

-    apiURL := strings.Replace(h.App.ChatConfig.XunFei.ApiURL, "{version}", Model2URL[req.Model], 1)
+    apiURL := strings.Replace(apiKey.ApiURL, "{version}", Model2URL[req.Model], 1)
     wsURL, err := assembleAuthUrl(apiURL, key[1], key[2])
     //握手并建立websocket 连接
     conn, resp, err := d.Dial(wsURL, nil)
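For the XunFei platform the whole credential set is packed into a single `ApiKey.Value` and split on `|`, which is why the handler rejects any value that does not yield exactly three parts. Below is a tiny self-contained sketch of that convention; the concrete values and the appid/key/secret ordering are assumptions based on `len(key) != 3` and `assembleAuthUrl(apiURL, key[1], key[2])`, not something documented in this commit.

```go
package main

import (
	"fmt"
	"strings"
)

func main() {
	// Hypothetical stored value; the real ordering of the three parts is assumed.
	value := "my-app-id|my-api-key|my-api-secret"
	parts := strings.Split(value, "|")
	if len(parts) != 3 {
		fmt.Println("非法的 API KEY!") // same check as the handler above
		return
	}
	fmt.Println("appid:", parts[0])
	fmt.Println("auth parts passed to assembleAuthUrl:", parts[1], parts[2])
}
```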
@@ -208,7 +208,7 @@ func (h *FunctionHandler) Dall3(c *gin.Context) {
     prompt := utils.InterfaceToString(params["prompt"])
     // get image generation API KEY
     var apiKey model.ApiKey
-    tx = h.db.Where("platform = ? AND type = ?", types.OpenAI, "img").Order("last_used_at ASC").First(&apiKey)
+    tx = h.db.Where("platform = ?", types.OpenAI).Where("type = ?", "img").Where("enabled = ?", true).Order("last_used_at ASC").First(&apiKey)
     if tx.Error != nil {
         resp.ERROR(c, "获取绘图 API KEY 失败: "+tx.Error.Error())
         return
@@ -231,7 +231,7 @@ func (h *FunctionHandler) Dall3(c *gin.Context) {

     // translate prompt
     const translatePromptTemplate = "Translate the following painting prompt words into English keyword phrases. Without any explanation, directly output the keyword phrases separated by commas. The content to be translated is: [%s]"
-    pt, err := utils.OpenAIRequest(fmt.Sprintf(translatePromptTemplate, params["prompt"]), apiKey.Value, h.App.Config.ProxyURL, chatConfig.OpenAI.ApiURL)
+    pt, err := utils.OpenAIRequest(fmt.Sprintf(translatePromptTemplate, params["prompt"]), apiKey, h.App.Config.ProxyURL)
     if err == nil {
         prompt = pt
     }
@@ -66,10 +66,10 @@ func (h *PromptHandler) Translate(c *gin.Context) {
 func (h *PromptHandler) request(prompt string, promptTemplate string) (string, error) {
     // 获取 OpenAI 的 API KEY
     var apiKey model.ApiKey
-    res := h.db.Where("platform = ?", types.OpenAI).First(&apiKey)
+    res := h.db.Where("platform = ?", types.OpenAI).Where("type = ?", "chat").Where("enabled = ?", true).First(&apiKey)
     if res.Error != nil {
         return "", fmt.Errorf("error with fetch OpenAI API KEY:%v", res.Error)
     }

-    return utils.OpenAIRequest(fmt.Sprintf(promptTemplate, prompt), apiKey.Value, h.App.Config.ProxyURL, h.App.ChatConfig.OpenAI.ApiURL)
+    return utils.OpenAIRequest(fmt.Sprintf(promptTemplate, prompt), apiKey, h.App.Config.ProxyURL)
 }
@@ -1,5 +1,10 @@
 package main

-func main() {
+import (
+    "chatplus/utils"
+    "fmt"
+)

+func main() {
+    fmt.Println(utils.RandString(64))
 }
@@ -3,6 +3,7 @@ package utils
 import (
     "chatplus/core/types"
     logger2 "chatplus/logger"
+    "chatplus/store/model"
     "encoding/json"
     "fmt"
     "github.com/imroc/req/v3"
@@ -85,7 +86,7 @@ type apiErrRes struct {
     } `json:"error"`
 }

-func OpenAIRequest(prompt string, apiKey string, proxy string, apiURL string) (string, error) {
+func OpenAIRequest(prompt string, apiKey model.ApiKey, proxy string) (string, error) {
     messages := make([]interface{}, 1)
     messages[0] = types.Message{
         Role: "user",
@@ -94,8 +95,12 @@ func OpenAIRequest(prompt string, apiKey string, proxy string, apiURL string) (s

     var response apiRes
     var errRes apiErrRes
-    r, err := req.C().SetProxyURL(proxy).R().SetHeader("Content-Type", "application/json").
+    client := req.C()
-        SetHeader("Authorization", "Bearer "+apiKey).
+    if apiKey.UseProxy && proxy != "" {
+        client.SetProxyURL(proxy)
+    }
+    r, err := client.R().SetHeader("Content-Type", "application/json").
+        SetHeader("Authorization", "Bearer "+apiKey.Value).
         SetBody(types.ApiRequest{
             Model: "gpt-3.5-turbo",
             Temperature: 0.9,
@@ -104,7 +109,7 @@ func OpenAIRequest(prompt string, apiKey string, proxy string, apiURL string) (s
             Messages: messages,
         }).
         SetErrorResult(&errRes).
-        SetSuccessResult(&response).Post(apiURL)
+        SetSuccessResult(&response).Post(apiKey.ApiURL)
     if err != nil || r.IsErrorState() {
         return "", fmt.Errorf("error with http request: %v%v%s", err, r.Err, errRes.Error.Message)
     }
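Callers of `utils.OpenAIRequest` now hand over the whole `model.ApiKey`: the bearer token comes from `Value`, the endpoint from `ApiURL`, and the proxy is applied only when the key has `UseProxy` set. A usage sketch under those assumptions; the key literal, endpoint, and proxy address below are made-up examples, not values from the repository.

```go
package main

import (
	"fmt"

	"chatplus/store/model"
	"chatplus/utils"
)

func main() {
	// Hypothetical key record; in the application it is loaded from the database.
	key := model.ApiKey{
		Value:    "sk-xxxx",                                    // bearer token
		ApiURL:   "https://api.openai.com/v1/chat/completions", // request endpoint
		UseProxy: true,                                         // proxy is only used when this is set
	}
	reply, err := utils.OpenAIRequest("Hello", key, "http://127.0.0.1:7890")
	if err != nil {
		fmt.Println("request failed:", err)
		return
	}
	fmt.Println(reply)
}
```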