mirror of
				https://github.com/yangjian102621/geekai.git
				synced 2025-11-04 16:23:42 +08:00 
			
		
		
		
	feat: XunFei ai mode api implements is ready
This commit is contained in:
		@@ -67,8 +67,10 @@ var ModelToTokens = map[string]int{
 | 
			
		||||
	"gpt-3.5-turbo-16k": 16384,
 | 
			
		||||
	"gpt-4":             8192,
 | 
			
		||||
	"gpt-4-32k":         32768,
 | 
			
		||||
	"chatglm_pro":       32768,
 | 
			
		||||
	"chatglm_pro":       32768, // 清华智普
 | 
			
		||||
	"chatglm_std":       16384,
 | 
			
		||||
	"chatglm_lite":      4096,
 | 
			
		||||
	"ernie_bot_turbo":   8192, // 文心一言
 | 
			
		||||
	"general":           8192, // 科大讯飞
 | 
			
		||||
	"general2":          8192,
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -36,6 +36,16 @@ func (wc *WsClient) Send(message []byte) error {
 | 
			
		||||
	return wc.Conn.WriteMessage(wc.mt, message)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (wc *WsClient) SendJson(value interface{}) error {
 | 
			
		||||
	wc.lock.Lock()
 | 
			
		||||
	defer wc.lock.Unlock()
 | 
			
		||||
 | 
			
		||||
	if wc.Closed {
 | 
			
		||||
		return ErrConClosed
 | 
			
		||||
	}
 | 
			
		||||
	return wc.Conn.WriteJSON(value)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (wc *WsClient) Receive() (int, []byte, error) {
 | 
			
		||||
	if wc.Closed {
 | 
			
		||||
		return 0, nil, ErrConClosed
 | 
			
		||||
 
 | 
			
		||||
@@ -39,7 +39,12 @@ type ChatHandler struct {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func NewChatHandler(app *core.AppServer, db *gorm.DB, levelDB *store.LevelDB, redis *redis.Client, service *mj.Service) *ChatHandler {
 | 
			
		||||
	h := ChatHandler{db: db, leveldb: levelDB, redis: redis, mjService: service}
 | 
			
		||||
	h := ChatHandler{
 | 
			
		||||
		db:        db,
 | 
			
		||||
		leveldb:   levelDB,
 | 
			
		||||
		redis:     redis,
 | 
			
		||||
		mjService: service,
 | 
			
		||||
	}
 | 
			
		||||
	h.App = app
 | 
			
		||||
	return &h
 | 
			
		||||
}
 | 
			
		||||
@@ -127,7 +132,11 @@ func (h *ChatHandler) ChatHandle(c *gin.Context) {
 | 
			
		||||
				logger.Error(err)
 | 
			
		||||
				client.Close()
 | 
			
		||||
				h.App.ChatClients.Delete(sessionId)
 | 
			
		||||
				h.App.ReqCancelFunc.Delete(sessionId)
 | 
			
		||||
				cancelFunc := h.App.ReqCancelFunc.Get(sessionId)
 | 
			
		||||
				if cancelFunc != nil {
 | 
			
		||||
					cancelFunc()
 | 
			
		||||
					h.App.ReqCancelFunc.Delete(sessionId)
 | 
			
		||||
				}
 | 
			
		||||
				return
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
@@ -217,6 +226,9 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio
 | 
			
		||||
			}
 | 
			
		||||
			req.Functions = functions
 | 
			
		||||
		}
 | 
			
		||||
	case types.XunFei:
 | 
			
		||||
		req.Temperature = h.App.ChatConfig.XunFei.Temperature
 | 
			
		||||
		req.MaxTokens = h.App.ChatConfig.XunFei.MaxTokens
 | 
			
		||||
	default:
 | 
			
		||||
		utils.ReplyMessage(ws, "不支持的平台:"+session.Model.Platform+",请联系管理员!")
 | 
			
		||||
		utils.ReplyMessage(ws, "")
 | 
			
		||||
@@ -291,6 +303,8 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio
 | 
			
		||||
		return h.sendChatGLMMessage(chatCtx, req, userVo, ctx, session, role, prompt, ws)
 | 
			
		||||
	case types.Baidu:
 | 
			
		||||
		return h.sendBaiduMessage(chatCtx, req, userVo, ctx, session, role, prompt, ws)
 | 
			
		||||
	case types.XunFei:
 | 
			
		||||
		return h.sendXunFeiMessage(chatCtx, req, userVo, ctx, session, role, prompt, ws)
 | 
			
		||||
 | 
			
		||||
	}
 | 
			
		||||
	utils.ReplyChunkMessage(ws, types.WsMessage{
 | 
			
		||||
 
 | 
			
		||||
@@ -1,22 +1,54 @@
 | 
			
		||||
package chatimpl
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"bufio"
 | 
			
		||||
	"chatplus/core/types"
 | 
			
		||||
	"chatplus/store/model"
 | 
			
		||||
	"chatplus/store/vo"
 | 
			
		||||
	"chatplus/utils"
 | 
			
		||||
	"context"
 | 
			
		||||
	"crypto/hmac"
 | 
			
		||||
	"crypto/sha256"
 | 
			
		||||
	"encoding/base64"
 | 
			
		||||
	"encoding/json"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"github.com/gorilla/websocket"
 | 
			
		||||
	"gorm.io/gorm"
 | 
			
		||||
	"io"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"net/url"
 | 
			
		||||
	"strings"
 | 
			
		||||
	"time"
 | 
			
		||||
	"unicode/utf8"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type xunFeiResp struct {
 | 
			
		||||
	Header struct {
 | 
			
		||||
		Code    int    `json:"code"`
 | 
			
		||||
		Message string `json:"message"`
 | 
			
		||||
		Sid     string `json:"sid"`
 | 
			
		||||
		Status  int    `json:"status"`
 | 
			
		||||
	} `json:"header"`
 | 
			
		||||
	Payload struct {
 | 
			
		||||
		Choices struct {
 | 
			
		||||
			Status int `json:"status"`
 | 
			
		||||
			Seq    int `json:"seq"`
 | 
			
		||||
			Text   []struct {
 | 
			
		||||
				Content string `json:"content"`
 | 
			
		||||
				Role    string `json:"role"`
 | 
			
		||||
				Index   int    `json:"index"`
 | 
			
		||||
			} `json:"text"`
 | 
			
		||||
		} `json:"choices"`
 | 
			
		||||
		Usage struct {
 | 
			
		||||
			Text struct {
 | 
			
		||||
				QuestionTokens   int `json:"question_tokens"`
 | 
			
		||||
				PromptTokens     int `json:"prompt_tokens"`
 | 
			
		||||
				CompletionTokens int `json:"completion_tokens"`
 | 
			
		||||
				TotalTokens      int `json:"total_tokens"`
 | 
			
		||||
			} `json:"text"`
 | 
			
		||||
		} `json:"usage"`
 | 
			
		||||
	} `json:"payload"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// 科大讯飞消息发送实现
 | 
			
		||||
 | 
			
		||||
func (h *ChatHandler) sendXunFeiMessage(
 | 
			
		||||
@@ -29,229 +61,261 @@ func (h *ChatHandler) sendXunFeiMessage(
 | 
			
		||||
	prompt string,
 | 
			
		||||
	ws *types.WsClient) error {
 | 
			
		||||
	promptCreatedAt := time.Now() // 记录提问时间
 | 
			
		||||
	start := time.Now()
 | 
			
		||||
	var apiKey = userVo.ChatConfig.ApiKeys[session.Model.Platform]
 | 
			
		||||
	response, err := h.doRequest(ctx, req, session.Model.Platform, &apiKey)
 | 
			
		||||
	logger.Info("HTTP请求完成,耗时:", time.Now().Sub(start))
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		if strings.Contains(err.Error(), "context canceled") {
 | 
			
		||||
			logger.Info("用户取消了请求:", prompt)
 | 
			
		||||
			return nil
 | 
			
		||||
		} else if strings.Contains(err.Error(), "no available key") {
 | 
			
		||||
	if apiKey == "" {
 | 
			
		||||
		var key model.ApiKey
 | 
			
		||||
		res := h.db.Where("platform = ?", session.Model.Platform).Order("last_used_at ASC").First(&key)
 | 
			
		||||
		if res.Error != nil {
 | 
			
		||||
			utils.ReplyMessage(ws, "抱歉😔😔😔,系统已经没有可用的 API KEY,请联系管理员!")
 | 
			
		||||
			return nil
 | 
			
		||||
		} else {
 | 
			
		||||
			logger.Error(err)
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		utils.ReplyMessage(ws, ErrorMsg)
 | 
			
		||||
		utils.ReplyMessage(ws, "")
 | 
			
		||||
		return err
 | 
			
		||||
	} else {
 | 
			
		||||
		defer response.Body.Close()
 | 
			
		||||
		// 更新 API KEY 的最后使用时间
 | 
			
		||||
		h.db.Model(&key).UpdateColumn("last_used_at", time.Now().Unix())
 | 
			
		||||
		apiKey = key.Value
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	contentType := response.Header.Get("Content-Type")
 | 
			
		||||
	if strings.Contains(contentType, "text/event-stream") {
 | 
			
		||||
		replyCreatedAt := time.Now() // 记录回复时间
 | 
			
		||||
		// 循环读取 Chunk 消息
 | 
			
		||||
		var message = types.Message{}
 | 
			
		||||
		var contents = make([]string, 0)
 | 
			
		||||
		var content string
 | 
			
		||||
		scanner := bufio.NewScanner(response.Body)
 | 
			
		||||
		for scanner.Scan() {
 | 
			
		||||
			line := scanner.Text()
 | 
			
		||||
			if len(line) < 5 || strings.HasPrefix(line, "id:") {
 | 
			
		||||
				continue
 | 
			
		||||
			}
 | 
			
		||||
	d := websocket.Dialer{
 | 
			
		||||
		HandshakeTimeout: 5 * time.Second,
 | 
			
		||||
	}
 | 
			
		||||
	key := strings.Split(apiKey, "|")
 | 
			
		||||
	if len(key) != 3 {
 | 
			
		||||
		utils.ReplyMessage(ws, "非法的 API KEY!")
 | 
			
		||||
		return nil
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
			if strings.HasPrefix(line, "data:") {
 | 
			
		||||
				content = line[5:]
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			var resp baiduResp
 | 
			
		||||
			err := utils.JsonDecode(content, &resp)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				logger.Error("error with parse data line: ", err)
 | 
			
		||||
				utils.ReplyMessage(ws, fmt.Sprintf("**解析数据行失败:%s**", err))
 | 
			
		||||
				break
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			if len(contents) == 0 {
 | 
			
		||||
				utils.ReplyChunkMessage(ws, types.WsMessage{Type: types.WsStart})
 | 
			
		||||
			}
 | 
			
		||||
			utils.ReplyChunkMessage(ws, types.WsMessage{
 | 
			
		||||
				Type:    types.WsMiddle,
 | 
			
		||||
				Content: utils.InterfaceToString(resp.Result),
 | 
			
		||||
			})
 | 
			
		||||
			contents = append(contents, resp.Result)
 | 
			
		||||
 | 
			
		||||
			if resp.IsTruncated {
 | 
			
		||||
				utils.ReplyMessage(ws, "AI 输出异常中断")
 | 
			
		||||
				break
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			if resp.IsEnd {
 | 
			
		||||
				break
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
		} // end for
 | 
			
		||||
 | 
			
		||||
		if err := scanner.Err(); err != nil {
 | 
			
		||||
			if strings.Contains(err.Error(), "context canceled") {
 | 
			
		||||
				logger.Info("用户取消了请求:", prompt)
 | 
			
		||||
			} else {
 | 
			
		||||
				logger.Error("信息读取出错:", err)
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		// 消息发送成功
 | 
			
		||||
		if len(contents) > 0 {
 | 
			
		||||
			// 更新用户的对话次数
 | 
			
		||||
			if userVo.ChatConfig.ApiKeys[session.Model.Platform] == "" {
 | 
			
		||||
				h.db.Model(&model.User{}).Where("id = ?", userVo.Id).UpdateColumn("calls", gorm.Expr("calls - ?", 1))
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			if message.Role == "" {
 | 
			
		||||
				message.Role = "assistant"
 | 
			
		||||
			}
 | 
			
		||||
			message.Content = strings.Join(contents, "")
 | 
			
		||||
			useMsg := types.Message{Role: "user", Content: prompt}
 | 
			
		||||
 | 
			
		||||
			// 更新上下文消息,如果是调用函数则不需要更新上下文
 | 
			
		||||
			if h.App.ChatConfig.EnableContext {
 | 
			
		||||
				chatCtx = append(chatCtx, useMsg)  // 提问消息
 | 
			
		||||
				chatCtx = append(chatCtx, message) // 回复消息
 | 
			
		||||
				h.App.ChatContexts.Put(session.ChatId, chatCtx)
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			// 追加聊天记录
 | 
			
		||||
			if h.App.ChatConfig.EnableHistory {
 | 
			
		||||
				// for prompt
 | 
			
		||||
				promptToken, err := utils.CalcTokens(prompt, req.Model)
 | 
			
		||||
				if err != nil {
 | 
			
		||||
					logger.Error(err)
 | 
			
		||||
				}
 | 
			
		||||
				historyUserMsg := model.HistoryMessage{
 | 
			
		||||
					UserId:     userVo.Id,
 | 
			
		||||
					ChatId:     session.ChatId,
 | 
			
		||||
					RoleId:     role.Id,
 | 
			
		||||
					Type:       types.PromptMsg,
 | 
			
		||||
					Icon:       userVo.Avatar,
 | 
			
		||||
					Content:    prompt,
 | 
			
		||||
					Tokens:     promptToken,
 | 
			
		||||
					UseContext: true,
 | 
			
		||||
				}
 | 
			
		||||
				historyUserMsg.CreatedAt = promptCreatedAt
 | 
			
		||||
				historyUserMsg.UpdatedAt = promptCreatedAt
 | 
			
		||||
				res := h.db.Save(&historyUserMsg)
 | 
			
		||||
				if res.Error != nil {
 | 
			
		||||
					logger.Error("failed to save prompt history message: ", res.Error)
 | 
			
		||||
				}
 | 
			
		||||
 | 
			
		||||
				// for reply
 | 
			
		||||
				// 计算本次对话消耗的总 token 数量
 | 
			
		||||
				replyToken, _ := utils.CalcTokens(message.Content, req.Model)
 | 
			
		||||
				totalTokens := replyToken + getTotalTokens(req)
 | 
			
		||||
				historyReplyMsg := model.HistoryMessage{
 | 
			
		||||
					UserId:     userVo.Id,
 | 
			
		||||
					ChatId:     session.ChatId,
 | 
			
		||||
					RoleId:     role.Id,
 | 
			
		||||
					Type:       types.ReplyMsg,
 | 
			
		||||
					Icon:       role.Icon,
 | 
			
		||||
					Content:    message.Content,
 | 
			
		||||
					Tokens:     totalTokens,
 | 
			
		||||
					UseContext: true,
 | 
			
		||||
				}
 | 
			
		||||
				historyReplyMsg.CreatedAt = replyCreatedAt
 | 
			
		||||
				historyReplyMsg.UpdatedAt = replyCreatedAt
 | 
			
		||||
				res = h.db.Create(&historyReplyMsg)
 | 
			
		||||
				if res.Error != nil {
 | 
			
		||||
					logger.Error("failed to save reply history message: ", res.Error)
 | 
			
		||||
				}
 | 
			
		||||
				// 更新用户信息
 | 
			
		||||
				h.db.Model(&model.User{}).Where("id = ?", userVo.Id).
 | 
			
		||||
					UpdateColumn("total_tokens", gorm.Expr("total_tokens + ?", totalTokens))
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			// 保存当前会话
 | 
			
		||||
			var chatItem model.ChatItem
 | 
			
		||||
			res := h.db.Where("chat_id = ?", session.ChatId).First(&chatItem)
 | 
			
		||||
			if res.Error != nil {
 | 
			
		||||
				chatItem.ChatId = session.ChatId
 | 
			
		||||
				chatItem.UserId = session.UserId
 | 
			
		||||
				chatItem.RoleId = role.Id
 | 
			
		||||
				chatItem.ModelId = session.Model.Id
 | 
			
		||||
				if utf8.RuneCountInString(prompt) > 30 {
 | 
			
		||||
					chatItem.Title = string([]rune(prompt)[:30]) + "..."
 | 
			
		||||
				} else {
 | 
			
		||||
					chatItem.Title = prompt
 | 
			
		||||
				}
 | 
			
		||||
				h.db.Create(&chatItem)
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	var apiURL string
 | 
			
		||||
	if req.Model == "generalv2" {
 | 
			
		||||
		apiURL = strings.Replace(h.App.ChatConfig.XunFei.ApiURL, "{version}", "v2.1", 1)
 | 
			
		||||
	} else {
 | 
			
		||||
		body, err := io.ReadAll(response.Body)
 | 
			
		||||
		apiURL = strings.Replace(h.App.ChatConfig.XunFei.ApiURL, "{version}", "v1.1", 1)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	wsURL, err := assembleAuthUrl(apiURL, key[1], key[2])
 | 
			
		||||
	//握手并建立websocket 连接
 | 
			
		||||
	conn, resp, err := d.Dial(wsURL, nil)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		logger.Error(readResp(resp) + err.Error())
 | 
			
		||||
		utils.ReplyMessage(ws, "请求讯飞星火模型 API 失败:"+readResp(resp)+err.Error())
 | 
			
		||||
		return nil
 | 
			
		||||
	} else if resp.StatusCode != 101 {
 | 
			
		||||
		utils.ReplyMessage(ws, "请求讯飞星火模型 API 失败:"+readResp(resp)+err.Error())
 | 
			
		||||
		return nil
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	data := buildRequest(key[0], req)
 | 
			
		||||
	fmt.Printf("%+v", data)
 | 
			
		||||
	fmt.Println(apiURL)
 | 
			
		||||
	err = conn.WriteJSON(data)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		utils.ReplyMessage(ws, "发送消息失败:"+err.Error())
 | 
			
		||||
		return nil
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	replyCreatedAt := time.Now() // 记录回复时间
 | 
			
		||||
	// 循环读取 Chunk 消息
 | 
			
		||||
	var message = types.Message{}
 | 
			
		||||
	var contents = make([]string, 0)
 | 
			
		||||
	var content string
 | 
			
		||||
	for {
 | 
			
		||||
		_, msg, err := conn.ReadMessage()
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return fmt.Errorf("error with reading response: %v", err)
 | 
			
		||||
			logger.Error("error with read message:", err)
 | 
			
		||||
			utils.ReplyMessage(ws, fmt.Sprintf("**数据读取失败:%s**", err))
 | 
			
		||||
			break
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		var res struct {
 | 
			
		||||
			Code int    `json:"error_code"`
 | 
			
		||||
			Msg  string `json:"error_msg"`
 | 
			
		||||
		}
 | 
			
		||||
		err = json.Unmarshal(body, &res)
 | 
			
		||||
		// 解析数据
 | 
			
		||||
		var result xunFeiResp
 | 
			
		||||
		err = json.Unmarshal(msg, &result)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return fmt.Errorf("error with decode response: %v", err)
 | 
			
		||||
			logger.Error("error with parsing JSON:", err)
 | 
			
		||||
			utils.ReplyMessage(ws, fmt.Sprintf("**解析数据行失败:%s**", err))
 | 
			
		||||
			return nil
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		if result.Header.Code != 0 {
 | 
			
		||||
			utils.ReplyMessage(ws, fmt.Sprintf("**请求 API 返回错误:%s**", result.Header.Message))
 | 
			
		||||
			return nil
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		content = result.Payload.Choices.Text[0].Content
 | 
			
		||||
		contents = append(contents, content)
 | 
			
		||||
		// 第一个结果
 | 
			
		||||
		if result.Payload.Choices.Status == 0 {
 | 
			
		||||
			utils.ReplyChunkMessage(ws, types.WsMessage{Type: types.WsStart})
 | 
			
		||||
		}
 | 
			
		||||
		utils.ReplyChunkMessage(ws, types.WsMessage{
 | 
			
		||||
			Type:    types.WsMiddle,
 | 
			
		||||
			Content: utils.InterfaceToString(content),
 | 
			
		||||
		})
 | 
			
		||||
 | 
			
		||||
		if result.Payload.Choices.Status == 2 { // 最终结果
 | 
			
		||||
			break
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		select {
 | 
			
		||||
		case <-ctx.Done():
 | 
			
		||||
			utils.ReplyMessage(ws, "**用户取消了生成指令!**")
 | 
			
		||||
			return nil
 | 
			
		||||
		default:
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// 消息发送成功
 | 
			
		||||
	if len(contents) > 0 {
 | 
			
		||||
		// 更新用户的对话次数
 | 
			
		||||
		if userVo.ChatConfig.ApiKeys[session.Model.Platform] == "" {
 | 
			
		||||
			h.db.Model(&model.User{}).Where("id = ?", userVo.Id).UpdateColumn("calls", gorm.Expr("calls - ?", 1))
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		if message.Role == "" {
 | 
			
		||||
			message.Role = "assistant"
 | 
			
		||||
		}
 | 
			
		||||
		message.Content = strings.Join(contents, "")
 | 
			
		||||
		useMsg := types.Message{Role: "user", Content: prompt}
 | 
			
		||||
 | 
			
		||||
		// 更新上下文消息,如果是调用函数则不需要更新上下文
 | 
			
		||||
		if h.App.ChatConfig.EnableContext {
 | 
			
		||||
			chatCtx = append(chatCtx, useMsg)  // 提问消息
 | 
			
		||||
			chatCtx = append(chatCtx, message) // 回复消息
 | 
			
		||||
			h.App.ChatContexts.Put(session.ChatId, chatCtx)
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		// 追加聊天记录
 | 
			
		||||
		if h.App.ChatConfig.EnableHistory {
 | 
			
		||||
			// for prompt
 | 
			
		||||
			promptToken, err := utils.CalcTokens(prompt, req.Model)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				logger.Error(err)
 | 
			
		||||
			}
 | 
			
		||||
			historyUserMsg := model.HistoryMessage{
 | 
			
		||||
				UserId:     userVo.Id,
 | 
			
		||||
				ChatId:     session.ChatId,
 | 
			
		||||
				RoleId:     role.Id,
 | 
			
		||||
				Type:       types.PromptMsg,
 | 
			
		||||
				Icon:       userVo.Avatar,
 | 
			
		||||
				Content:    prompt,
 | 
			
		||||
				Tokens:     promptToken,
 | 
			
		||||
				UseContext: true,
 | 
			
		||||
			}
 | 
			
		||||
			historyUserMsg.CreatedAt = promptCreatedAt
 | 
			
		||||
			historyUserMsg.UpdatedAt = promptCreatedAt
 | 
			
		||||
			res := h.db.Save(&historyUserMsg)
 | 
			
		||||
			if res.Error != nil {
 | 
			
		||||
				logger.Error("failed to save prompt history message: ", res.Error)
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			// for reply
 | 
			
		||||
			// 计算本次对话消耗的总 token 数量
 | 
			
		||||
			replyToken, _ := utils.CalcTokens(message.Content, req.Model)
 | 
			
		||||
			totalTokens := replyToken + getTotalTokens(req)
 | 
			
		||||
			historyReplyMsg := model.HistoryMessage{
 | 
			
		||||
				UserId:     userVo.Id,
 | 
			
		||||
				ChatId:     session.ChatId,
 | 
			
		||||
				RoleId:     role.Id,
 | 
			
		||||
				Type:       types.ReplyMsg,
 | 
			
		||||
				Icon:       role.Icon,
 | 
			
		||||
				Content:    message.Content,
 | 
			
		||||
				Tokens:     totalTokens,
 | 
			
		||||
				UseContext: true,
 | 
			
		||||
			}
 | 
			
		||||
			historyReplyMsg.CreatedAt = replyCreatedAt
 | 
			
		||||
			historyReplyMsg.UpdatedAt = replyCreatedAt
 | 
			
		||||
			res = h.db.Create(&historyReplyMsg)
 | 
			
		||||
			if res.Error != nil {
 | 
			
		||||
				logger.Error("failed to save reply history message: ", res.Error)
 | 
			
		||||
			}
 | 
			
		||||
			// 更新用户信息
 | 
			
		||||
			h.db.Model(&model.User{}).Where("id = ?", userVo.Id).
 | 
			
		||||
				UpdateColumn("total_tokens", gorm.Expr("total_tokens + ?", totalTokens))
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		// 保存当前会话
 | 
			
		||||
		var chatItem model.ChatItem
 | 
			
		||||
		res := h.db.Where("chat_id = ?", session.ChatId).First(&chatItem)
 | 
			
		||||
		if res.Error != nil {
 | 
			
		||||
			chatItem.ChatId = session.ChatId
 | 
			
		||||
			chatItem.UserId = session.UserId
 | 
			
		||||
			chatItem.RoleId = role.Id
 | 
			
		||||
			chatItem.ModelId = session.Model.Id
 | 
			
		||||
			if utf8.RuneCountInString(prompt) > 30 {
 | 
			
		||||
				chatItem.Title = string([]rune(prompt)[:30]) + "..."
 | 
			
		||||
			} else {
 | 
			
		||||
				chatItem.Title = prompt
 | 
			
		||||
			}
 | 
			
		||||
			h.db.Create(&chatItem)
 | 
			
		||||
		}
 | 
			
		||||
		utils.ReplyMessage(ws, "请求百度文心大模型 API 失败:"+res.Msg)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (h *ChatHandler) getXunFeiToken(apiKey string) (string, error) {
 | 
			
		||||
	ctx := context.Background()
 | 
			
		||||
	tokenString, err := h.redis.Get(ctx, apiKey).Result()
 | 
			
		||||
	if err == nil {
 | 
			
		||||
		return tokenString, nil
 | 
			
		||||
// 构建 websocket 请求实体
 | 
			
		||||
func buildRequest(appid string, req types.ApiRequest) map[string]interface{} {
 | 
			
		||||
	return map[string]interface{}{
 | 
			
		||||
		"header": map[string]interface{}{
 | 
			
		||||
			"app_id": appid,
 | 
			
		||||
		},
 | 
			
		||||
		"parameter": map[string]interface{}{
 | 
			
		||||
			"chat": map[string]interface{}{
 | 
			
		||||
				"domain":      req.Model,
 | 
			
		||||
				"temperature": float64(req.Temperature),
 | 
			
		||||
				"top_k":       int64(6),
 | 
			
		||||
				"max_tokens":  int64(req.MaxTokens),
 | 
			
		||||
				"auditing":    "default",
 | 
			
		||||
			},
 | 
			
		||||
		},
 | 
			
		||||
		"payload": map[string]interface{}{
 | 
			
		||||
			"message": map[string]interface{}{
 | 
			
		||||
				"text": req.Messages,
 | 
			
		||||
			},
 | 
			
		||||
		},
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
	expr := time.Hour * 24 * 20 // access_token 有效期
 | 
			
		||||
	key := strings.Split(apiKey, "|")
 | 
			
		||||
	if len(key) != 2 {
 | 
			
		||||
		return "", fmt.Errorf("invalid api key: %s", apiKey)
 | 
			
		||||
	}
 | 
			
		||||
	url := fmt.Sprintf("https://aip.baidubce.com/oauth/2.0/token?client_id=%s&client_secret=%s&grant_type=client_credentials", key[0], key[1])
 | 
			
		||||
	client := &http.Client{}
 | 
			
		||||
	req, err := http.NewRequest("POST", url, nil)
 | 
			
		||||
// 创建鉴权 URL
 | 
			
		||||
func assembleAuthUrl(hostURL string, apiKey, apiSecret string) (string, error) {
 | 
			
		||||
	ul, err := url.Parse(hostURL)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return "", err
 | 
			
		||||
	}
 | 
			
		||||
	req.Header.Add("Content-Type", "application/json")
 | 
			
		||||
	req.Header.Add("Accept", "application/json")
 | 
			
		||||
 | 
			
		||||
	res, err := client.Do(req)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return "", fmt.Errorf("error with send request: %w", err)
 | 
			
		||||
	}
 | 
			
		||||
	defer res.Body.Close()
 | 
			
		||||
	date := time.Now().UTC().Format(time.RFC1123)
 | 
			
		||||
	signString := []string{"host: " + ul.Host, "date: " + date, "GET " + ul.Path + " HTTP/1.1"}
 | 
			
		||||
	//拼接签名字符串
 | 
			
		||||
	signStr := strings.Join(signString, "\n")
 | 
			
		||||
	sha := hmacWithSha256(signStr, apiSecret)
 | 
			
		||||
 | 
			
		||||
	body, err := io.ReadAll(res.Body)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return "", fmt.Errorf("error with read response: %w", err)
 | 
			
		||||
	}
 | 
			
		||||
	var r map[string]interface{}
 | 
			
		||||
	err = json.Unmarshal(body, &r)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return "", fmt.Errorf("error with parse response: %w", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if r["error"] != nil {
 | 
			
		||||
		return "", fmt.Errorf("error with api response: %s", r["error_description"])
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	tokenString = fmt.Sprintf("%s", r["access_token"])
 | 
			
		||||
	h.redis.Set(ctx, apiKey, tokenString, expr)
 | 
			
		||||
	return tokenString, nil
 | 
			
		||||
	authUrl := fmt.Sprintf("hmac username=\"%s\", algorithm=\"%s\", headers=\"%s\", signature=\"%s\"", apiKey,
 | 
			
		||||
		"hmac-sha256", "host date request-line", sha)
 | 
			
		||||
	//将请求参数使用base64编码
 | 
			
		||||
	authorization := base64.StdEncoding.EncodeToString([]byte(authUrl))
 | 
			
		||||
	v := url.Values{}
 | 
			
		||||
	v.Add("host", ul.Host)
 | 
			
		||||
	v.Add("date", date)
 | 
			
		||||
	v.Add("authorization", authorization)
 | 
			
		||||
	//将编码后的字符串url encode后添加到url后面
 | 
			
		||||
	return hostURL + "?" + v.Encode(), nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// 使用 sha256 签名
 | 
			
		||||
func hmacWithSha256(data, key string) string {
 | 
			
		||||
	mac := hmac.New(sha256.New, []byte(key))
 | 
			
		||||
	mac.Write([]byte(data))
 | 
			
		||||
	encodeData := mac.Sum(nil)
 | 
			
		||||
	return base64.StdEncoding.EncodeToString(encodeData)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// 读取响应
 | 
			
		||||
func readResp(resp *http.Response) string {
 | 
			
		||||
	if resp == nil {
 | 
			
		||||
		return ""
 | 
			
		||||
	}
 | 
			
		||||
	b, err := io.ReadAll(resp.Body)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		panic(err)
 | 
			
		||||
	}
 | 
			
		||||
	return fmt.Sprintf("code=%d,body=%s", resp.StatusCode, string(b))
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user