mirror of
				https://github.com/yangjian102621/geekai.git
				synced 2025-11-04 16:23:42 +08:00 
			
		
		
		
	feat: refactor LLM api request code, get API URL from ApiKey object
This commit is contained in:
		@@ -14,7 +14,7 @@ ChatGLM,讯飞星火,文心一言等多个平台的大语言模型。集成了
 | 
			
		||||
  绘画函数插件。
 | 
			
		||||
 | 
			
		||||
## 最新版本一键部署脚本
 | 
			
		||||
 | 
			
		||||
目前仅支持 Ubuntu 和 Centos 系统。
 | 
			
		||||
```shell
 | 
			
		||||
bash -c "$(curl -fsSL https://img.r9it.com/tmp/install-v3.2.3-8b588904ef.sh)"
 | 
			
		||||
```
 | 
			
		||||
 
 | 
			
		||||
@@ -62,7 +62,6 @@ type ApiError struct {
 | 
			
		||||
 | 
			
		||||
const PromptMsg = "prompt" // prompt message
 | 
			
		||||
const ReplyMsg = "reply"   // reply message
 | 
			
		||||
const MjMsg = "mj"
 | 
			
		||||
 | 
			
		||||
var ModelToTokens = map[string]int{
 | 
			
		||||
	"gpt-3.5-turbo":     4096,
 | 
			
		||||
@@ -75,4 +74,12 @@ var ModelToTokens = map[string]int{
 | 
			
		||||
	"ernie_bot_turbo":   8192, // 文心一言
 | 
			
		||||
	"general":           8192, // 科大讯飞
 | 
			
		||||
	"general2":          8192,
 | 
			
		||||
	"general3":          8192,
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func GetModelMaxToken(model string) int {
 | 
			
		||||
	if token, ok := ModelToTokens[model]; ok {
 | 
			
		||||
		return token
 | 
			
		||||
	}
 | 
			
		||||
	return 4096
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -141,7 +141,6 @@ type InviteReward struct {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type ModelAPIConfig struct {
 | 
			
		||||
	ApiURL      string  `json:"api_url,omitempty"`
 | 
			
		||||
	Temperature float32 `json:"temperature"`
 | 
			
		||||
	MaxTokens   int     `json:"max_tokens"`
 | 
			
		||||
	ApiKey      string  `json:"api_key"`
 | 
			
		||||
 
 | 
			
		||||
@@ -29,7 +29,7 @@ func (h *ChatHandler) sendAzureMessage(
 | 
			
		||||
	ws *types.WsClient) error {
 | 
			
		||||
	promptCreatedAt := time.Now() // 记录提问时间
 | 
			
		||||
	start := time.Now()
 | 
			
		||||
	var apiKey = userVo.ChatConfig.ApiKeys[session.Model.Platform]
 | 
			
		||||
	var apiKey = model.ApiKey{}
 | 
			
		||||
	response, err := h.doRequest(ctx, req, session.Model.Platform, &apiKey)
 | 
			
		||||
	logger.Info("HTTP请求完成,耗时:", time.Now().Sub(start))
 | 
			
		||||
	if err != nil {
 | 
			
		||||
 
 | 
			
		||||
@@ -46,7 +46,7 @@ func (h *ChatHandler) sendBaiduMessage(
 | 
			
		||||
	ws *types.WsClient) error {
 | 
			
		||||
	promptCreatedAt := time.Now() // 记录提问时间
 | 
			
		||||
	start := time.Now()
 | 
			
		||||
	var apiKey = userVo.ChatConfig.ApiKeys[session.Model.Platform]
 | 
			
		||||
	var apiKey = model.ApiKey{}
 | 
			
		||||
	response, err := h.doRequest(ctx, req, session.Model.Platform, &apiKey)
 | 
			
		||||
	logger.Info("HTTP请求完成,耗时:", time.Now().Sub(start))
 | 
			
		||||
	if err != nil {
 | 
			
		||||
 
 | 
			
		||||
@@ -275,7 +275,7 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio
 | 
			
		||||
			if err == nil {
 | 
			
		||||
				for _, v := range messages {
 | 
			
		||||
					tks, _ := utils.CalcTokens(v.Content, req.Model)
 | 
			
		||||
					if tokens+tks >= types.ModelToTokens[req.Model] {
 | 
			
		||||
					if tokens+tks >= types.GetModelMaxToken(req.Model) {
 | 
			
		||||
						break
 | 
			
		||||
					}
 | 
			
		||||
					tokens += tks
 | 
			
		||||
@@ -290,7 +290,7 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio
 | 
			
		||||
				if res.Error == nil {
 | 
			
		||||
					for i := len(historyMessages) - 1; i >= 0; i-- {
 | 
			
		||||
						msg := historyMessages[i]
 | 
			
		||||
						if tokens+msg.Tokens >= types.ModelToTokens[session.Model.Value] {
 | 
			
		||||
						if tokens+msg.Tokens >= types.GetModelMaxToken(session.Model.Value) {
 | 
			
		||||
							break
 | 
			
		||||
						}
 | 
			
		||||
						tokens += msg.Tokens
 | 
			
		||||
@@ -401,39 +401,33 @@ func (h *ChatHandler) StopGenerate(c *gin.Context) {
 | 
			
		||||
 | 
			
		||||
// 发送请求到 OpenAI 服务器
 | 
			
		||||
// useOwnApiKey: 是否使用了用户自己的 API KEY
 | 
			
		||||
func (h *ChatHandler) doRequest(ctx context.Context, req types.ApiRequest, platform types.Platform, apiKey *string) (*http.Response, error) {
 | 
			
		||||
 | 
			
		||||
func (h *ChatHandler) doRequest(ctx context.Context, req types.ApiRequest, platform types.Platform, apiKey *model.ApiKey) (*http.Response, error) {
 | 
			
		||||
	res := h.db.Where("platform = ?", platform).Where("type = ?", "chat").Where("enabled = ?", true).Order("last_used_at ASC").First(apiKey)
 | 
			
		||||
	if res.Error != nil {
 | 
			
		||||
		return nil, errors.New("no available key, please import key")
 | 
			
		||||
	}
 | 
			
		||||
	var apiURL string
 | 
			
		||||
	switch platform {
 | 
			
		||||
	case types.Azure:
 | 
			
		||||
		md := strings.Replace(req.Model, ".", "", 1)
 | 
			
		||||
		apiURL = strings.Replace(h.App.ChatConfig.Azure.ApiURL, "{model}", md, 1)
 | 
			
		||||
		apiURL = strings.Replace(apiKey.ApiURL, "{model}", md, 1)
 | 
			
		||||
		break
 | 
			
		||||
	case types.ChatGLM:
 | 
			
		||||
		apiURL = strings.Replace(h.App.ChatConfig.ChatGML.ApiURL, "{model}", req.Model, 1)
 | 
			
		||||
		apiURL = strings.Replace(apiKey.ApiURL, "{model}", req.Model, 1)
 | 
			
		||||
		req.Prompt = req.Messages // 使用 prompt 字段替代 message 字段
 | 
			
		||||
		req.Messages = nil
 | 
			
		||||
		break
 | 
			
		||||
	case types.Baidu:
 | 
			
		||||
		apiURL = strings.Replace(h.App.ChatConfig.Baidu.ApiURL, "{model}", req.Model, 1)
 | 
			
		||||
		apiURL = strings.Replace(apiKey.ApiURL, "{model}", req.Model, 1)
 | 
			
		||||
		break
 | 
			
		||||
	default:
 | 
			
		||||
		apiURL = h.App.ChatConfig.OpenAI.ApiURL
 | 
			
		||||
		apiURL = apiKey.ApiURL
 | 
			
		||||
	}
 | 
			
		||||
	if *apiKey == "" {
 | 
			
		||||
		var key model.ApiKey
 | 
			
		||||
		res := h.db.Where("platform = ? AND type = ?", platform, "chat").Order("last_used_at ASC").First(&key)
 | 
			
		||||
		if res.Error != nil {
 | 
			
		||||
			return nil, errors.New("no available key, please import key")
 | 
			
		||||
		}
 | 
			
		||||
		// 更新 API KEY 的最后使用时间
 | 
			
		||||
		h.db.Model(&key).UpdateColumn("last_used_at", time.Now().Unix())
 | 
			
		||||
		*apiKey = key.Value
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// 更新 API KEY 的最后使用时间
 | 
			
		||||
	h.db.Model(apiKey).UpdateColumn("last_used_at", time.Now().Unix())
 | 
			
		||||
	// 百度文心,需要串接 access_token
 | 
			
		||||
	if platform == types.Baidu {
 | 
			
		||||
		token, err := h.getBaiduToken(*apiKey)
 | 
			
		||||
		token, err := h.getBaiduToken(apiKey.Value)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return nil, err
 | 
			
		||||
		}
 | 
			
		||||
@@ -465,13 +459,13 @@ func (h *ChatHandler) doRequest(ctx context.Context, req types.ApiRequest, platf
 | 
			
		||||
	} else {
 | 
			
		||||
		client = http.DefaultClient
 | 
			
		||||
	}
 | 
			
		||||
	logger.Infof("Sending %s request, ApiURL:%s, PROXY: %s, Model: %s", platform, apiURL, proxyURL, req.Model)
 | 
			
		||||
	logger.Debugf("Sending %s request, ApiURL:%s, ApiKey:%s, PROXY: %s, Model: %s", platform, apiURL, apiKey.Value, proxyURL, req.Model)
 | 
			
		||||
	switch platform {
 | 
			
		||||
	case types.Azure:
 | 
			
		||||
		request.Header.Set("api-key", *apiKey)
 | 
			
		||||
		request.Header.Set("api-key", apiKey.Value)
 | 
			
		||||
		break
 | 
			
		||||
	case types.ChatGLM:
 | 
			
		||||
		token, err := h.getChatGLMToken(*apiKey)
 | 
			
		||||
		token, err := h.getChatGLMToken(apiKey.Value)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return nil, err
 | 
			
		||||
		}
 | 
			
		||||
@@ -480,7 +474,7 @@ func (h *ChatHandler) doRequest(ctx context.Context, req types.ApiRequest, platf
 | 
			
		||||
	case types.Baidu:
 | 
			
		||||
		request.RequestURI = ""
 | 
			
		||||
	case types.OpenAI:
 | 
			
		||||
		request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiKey))
 | 
			
		||||
		request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", apiKey.Value))
 | 
			
		||||
	}
 | 
			
		||||
	return client.Do(request)
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -30,7 +30,7 @@ func (h *ChatHandler) sendChatGLMMessage(
 | 
			
		||||
	ws *types.WsClient) error {
 | 
			
		||||
	promptCreatedAt := time.Now() // 记录提问时间
 | 
			
		||||
	start := time.Now()
 | 
			
		||||
	var apiKey = userVo.ChatConfig.ApiKeys[session.Model.Platform]
 | 
			
		||||
	var apiKey = model.ApiKey{}
 | 
			
		||||
	response, err := h.doRequest(ctx, req, session.Model.Platform, &apiKey)
 | 
			
		||||
	logger.Info("HTTP请求完成,耗时:", time.Now().Sub(start))
 | 
			
		||||
	if err != nil {
 | 
			
		||||
 
 | 
			
		||||
@@ -29,7 +29,7 @@ func (h *ChatHandler) sendOpenAiMessage(
 | 
			
		||||
	ws *types.WsClient) error {
 | 
			
		||||
	promptCreatedAt := time.Now() // 记录提问时间
 | 
			
		||||
	start := time.Now()
 | 
			
		||||
	var apiKey = userVo.ChatConfig.ApiKeys[session.Model.Platform]
 | 
			
		||||
	var apiKey = model.ApiKey{}
 | 
			
		||||
	response, err := h.doRequest(ctx, req, session.Model.Platform, &apiKey)
 | 
			
		||||
	logger.Info("HTTP请求完成,耗时:", time.Now().Sub(start))
 | 
			
		||||
	if err != nil {
 | 
			
		||||
 
 | 
			
		||||
@@ -67,29 +67,25 @@ func (h *ChatHandler) sendXunFeiMessage(
 | 
			
		||||
	prompt string,
 | 
			
		||||
	ws *types.WsClient) error {
 | 
			
		||||
	promptCreatedAt := time.Now() // 记录提问时间
 | 
			
		||||
	var apiKey = userVo.ChatConfig.ApiKeys[session.Model.Platform]
 | 
			
		||||
	if apiKey == "" {
 | 
			
		||||
		var key model.ApiKey
 | 
			
		||||
		res := h.db.Where("platform = ? AND type = ?", session.Model.Platform, "chat").Order("last_used_at ASC").First(&key)
 | 
			
		||||
		if res.Error != nil {
 | 
			
		||||
			utils.ReplyMessage(ws, "抱歉😔😔😔,系统已经没有可用的 API KEY,请联系管理员!")
 | 
			
		||||
			return nil
 | 
			
		||||
		}
 | 
			
		||||
		// 更新 API KEY 的最后使用时间
 | 
			
		||||
		h.db.Model(&key).UpdateColumn("last_used_at", time.Now().Unix())
 | 
			
		||||
		apiKey = key.Value
 | 
			
		||||
	var apiKey model.ApiKey
 | 
			
		||||
	res := h.db.Where("platform = ?", session.Model.Platform).Where("type = ?", "chat").Where("enabled = ?", true).Order("last_used_at ASC").First(&apiKey)
 | 
			
		||||
	if res.Error != nil {
 | 
			
		||||
		utils.ReplyMessage(ws, "抱歉😔😔😔,系统已经没有可用的 API KEY,请联系管理员!")
 | 
			
		||||
		return nil
 | 
			
		||||
	}
 | 
			
		||||
	// 更新 API KEY 的最后使用时间
 | 
			
		||||
	h.db.Model(&apiKey).UpdateColumn("last_used_at", time.Now().Unix())
 | 
			
		||||
 | 
			
		||||
	d := websocket.Dialer{
 | 
			
		||||
		HandshakeTimeout: 5 * time.Second,
 | 
			
		||||
	}
 | 
			
		||||
	key := strings.Split(apiKey, "|")
 | 
			
		||||
	key := strings.Split(apiKey.Value, "|")
 | 
			
		||||
	if len(key) != 3 {
 | 
			
		||||
		utils.ReplyMessage(ws, "非法的 API KEY!")
 | 
			
		||||
		return nil
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	apiURL := strings.Replace(h.App.ChatConfig.XunFei.ApiURL, "{version}", Model2URL[req.Model], 1)
 | 
			
		||||
	apiURL := strings.Replace(apiKey.ApiURL, "{version}", Model2URL[req.Model], 1)
 | 
			
		||||
	wsURL, err := assembleAuthUrl(apiURL, key[1], key[2])
 | 
			
		||||
	//握手并建立websocket 连接
 | 
			
		||||
	conn, resp, err := d.Dial(wsURL, nil)
 | 
			
		||||
 
 | 
			
		||||
@@ -208,7 +208,7 @@ func (h *FunctionHandler) Dall3(c *gin.Context) {
 | 
			
		||||
	prompt := utils.InterfaceToString(params["prompt"])
 | 
			
		||||
	// get image generation API KEY
 | 
			
		||||
	var apiKey model.ApiKey
 | 
			
		||||
	tx = h.db.Where("platform = ? AND type = ?", types.OpenAI, "img").Order("last_used_at ASC").First(&apiKey)
 | 
			
		||||
	tx = h.db.Where("platform = ?", types.OpenAI).Where("type = ?", "img").Where("enabled = ?", true).Order("last_used_at ASC").First(&apiKey)
 | 
			
		||||
	if tx.Error != nil {
 | 
			
		||||
		resp.ERROR(c, "获取绘图 API KEY 失败: "+tx.Error.Error())
 | 
			
		||||
		return
 | 
			
		||||
@@ -231,7 +231,7 @@ func (h *FunctionHandler) Dall3(c *gin.Context) {
 | 
			
		||||
 | 
			
		||||
	// translate prompt
 | 
			
		||||
	const translatePromptTemplate = "Translate the following painting prompt words into English keyword phrases. Without any explanation, directly output the keyword phrases separated by commas. The content to be translated is: [%s]"
 | 
			
		||||
	pt, err := utils.OpenAIRequest(fmt.Sprintf(translatePromptTemplate, params["prompt"]), apiKey.Value, h.App.Config.ProxyURL, chatConfig.OpenAI.ApiURL)
 | 
			
		||||
	pt, err := utils.OpenAIRequest(fmt.Sprintf(translatePromptTemplate, params["prompt"]), apiKey, h.App.Config.ProxyURL)
 | 
			
		||||
	if err == nil {
 | 
			
		||||
		prompt = pt
 | 
			
		||||
	}
 | 
			
		||||
 
 | 
			
		||||
@@ -66,10 +66,10 @@ func (h *PromptHandler) Translate(c *gin.Context) {
 | 
			
		||||
func (h *PromptHandler) request(prompt string, promptTemplate string) (string, error) {
 | 
			
		||||
	// 获取 OpenAI 的 API KEY
 | 
			
		||||
	var apiKey model.ApiKey
 | 
			
		||||
	res := h.db.Where("platform = ?", types.OpenAI).First(&apiKey)
 | 
			
		||||
	res := h.db.Where("platform = ?", types.OpenAI).Where("type = ?", "chat").Where("enabled = ?", true).First(&apiKey)
 | 
			
		||||
	if res.Error != nil {
 | 
			
		||||
		return "", fmt.Errorf("error with fetch OpenAI API KEY:%v", res.Error)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return utils.OpenAIRequest(fmt.Sprintf(promptTemplate, prompt), apiKey.Value, h.App.Config.ProxyURL, h.App.ChatConfig.OpenAI.ApiURL)
 | 
			
		||||
	return utils.OpenAIRequest(fmt.Sprintf(promptTemplate, prompt), apiKey, h.App.Config.ProxyURL)
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -1,5 +1,10 @@
 | 
			
		||||
package main
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"chatplus/utils"
 | 
			
		||||
	"fmt"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func main() {
 | 
			
		||||
	
 | 
			
		||||
	fmt.Println(utils.RandString(64))
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -3,6 +3,7 @@ package utils
 | 
			
		||||
import (
 | 
			
		||||
	"chatplus/core/types"
 | 
			
		||||
	logger2 "chatplus/logger"
 | 
			
		||||
	"chatplus/store/model"
 | 
			
		||||
	"encoding/json"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"github.com/imroc/req/v3"
 | 
			
		||||
@@ -85,7 +86,7 @@ type apiErrRes struct {
 | 
			
		||||
	} `json:"error"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func OpenAIRequest(prompt string, apiKey string, proxy string, apiURL string) (string, error) {
 | 
			
		||||
func OpenAIRequest(prompt string, apiKey model.ApiKey, proxy string) (string, error) {
 | 
			
		||||
	messages := make([]interface{}, 1)
 | 
			
		||||
	messages[0] = types.Message{
 | 
			
		||||
		Role:    "user",
 | 
			
		||||
@@ -94,8 +95,12 @@ func OpenAIRequest(prompt string, apiKey string, proxy string, apiURL string) (s
 | 
			
		||||
 | 
			
		||||
	var response apiRes
 | 
			
		||||
	var errRes apiErrRes
 | 
			
		||||
	r, err := req.C().SetProxyURL(proxy).R().SetHeader("Content-Type", "application/json").
 | 
			
		||||
		SetHeader("Authorization", "Bearer "+apiKey).
 | 
			
		||||
	client := req.C()
 | 
			
		||||
	if apiKey.UseProxy && proxy != "" {
 | 
			
		||||
		client.SetProxyURL(proxy)
 | 
			
		||||
	}
 | 
			
		||||
	r, err := client.R().SetHeader("Content-Type", "application/json").
 | 
			
		||||
		SetHeader("Authorization", "Bearer "+apiKey.Value).
 | 
			
		||||
		SetBody(types.ApiRequest{
 | 
			
		||||
			Model:       "gpt-3.5-turbo",
 | 
			
		||||
			Temperature: 0.9,
 | 
			
		||||
@@ -104,7 +109,7 @@ func OpenAIRequest(prompt string, apiKey string, proxy string, apiURL string) (s
 | 
			
		||||
			Messages:    messages,
 | 
			
		||||
		}).
 | 
			
		||||
		SetErrorResult(&errRes).
 | 
			
		||||
		SetSuccessResult(&response).Post(apiURL)
 | 
			
		||||
		SetSuccessResult(&response).Post(apiKey.ApiURL)
 | 
			
		||||
	if err != nil || r.IsErrorState() {
 | 
			
		||||
		return "", fmt.Errorf("error with http request: %v%v%s", err, r.Err, errRes.Error.Message)
 | 
			
		||||
	}
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user