mirror of
				https://github.com/yangjian102621/geekai.git
				synced 2025-11-04 08:13:43 +08:00 
			
		
		
		
	add baidu ai model api configrations
This commit is contained in:
		@@ -2,9 +2,9 @@ package types
 | 
			
		||||
 | 
			
		||||
// ApiRequest API 请求实体
 | 
			
		||||
type ApiRequest struct {
 | 
			
		||||
	Model       string        `json:"model"`
 | 
			
		||||
	Model       string        `json:"model,omitempty"` // 兼容百度文心一言
 | 
			
		||||
	Temperature float32       `json:"temperature"`
 | 
			
		||||
	MaxTokens   int           `json:"max_tokens"`
 | 
			
		||||
	MaxTokens   int           `json:"max_tokens,omitempty"` // 兼容百度文心一言
 | 
			
		||||
	Stream      bool          `json:"stream"`
 | 
			
		||||
	Messages    []interface{} `json:"messages,omitempty"`
 | 
			
		||||
	Prompt      []interface{} `json:"prompt,omitempty"` // 兼容 ChatGLM
 | 
			
		||||
 
 | 
			
		||||
@@ -90,6 +90,7 @@ type Platform string
 | 
			
		||||
const OpenAI = Platform("OpenAI")
 | 
			
		||||
const Azure = Platform("Azure")
 | 
			
		||||
const ChatGLM = Platform("ChatGLM")
 | 
			
		||||
const Baidu = Platform("Baidu")
 | 
			
		||||
 | 
			
		||||
// UserChatConfig 用户的聊天配置
 | 
			
		||||
type UserChatConfig struct {
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										235
									
								
								api/handler/baidu_handler.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										235
									
								
								api/handler/baidu_handler.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,235 @@
 | 
			
		||||
package handler
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"bufio"
 | 
			
		||||
	"chatplus/core/types"
 | 
			
		||||
	"chatplus/store/model"
 | 
			
		||||
	"chatplus/store/vo"
 | 
			
		||||
	"chatplus/utils"
 | 
			
		||||
	"context"
 | 
			
		||||
	"encoding/json"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"github.com/golang-jwt/jwt/v5"
 | 
			
		||||
	"gorm.io/gorm"
 | 
			
		||||
	"io"
 | 
			
		||||
	"strings"
 | 
			
		||||
	"time"
 | 
			
		||||
	"unicode/utf8"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// 将消息发送给百度文心一言大模型 API 并获取结果,通过 WebSocket 推送到客户端
 | 
			
		||||
func (h *ChatHandler) sendBaiduMessage(
 | 
			
		||||
	chatCtx []interface{},
 | 
			
		||||
	req types.ApiRequest,
 | 
			
		||||
	userVo vo.User,
 | 
			
		||||
	ctx context.Context,
 | 
			
		||||
	session *types.ChatSession,
 | 
			
		||||
	role model.ChatRole,
 | 
			
		||||
	prompt string,
 | 
			
		||||
	ws *types.WsClient) error {
 | 
			
		||||
	promptCreatedAt := time.Now() // 记录提问时间
 | 
			
		||||
	start := time.Now()
 | 
			
		||||
	var apiKey = userVo.ChatConfig.ApiKeys[session.Model.Platform]
 | 
			
		||||
	response, err := h.doRequest(ctx, req, session.Model.Platform, &apiKey)
 | 
			
		||||
	logger.Info("HTTP请求完成,耗时:", time.Now().Sub(start))
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		if strings.Contains(err.Error(), "context canceled") {
 | 
			
		||||
			logger.Info("用户取消了请求:", prompt)
 | 
			
		||||
			return nil
 | 
			
		||||
		} else if strings.Contains(err.Error(), "no available key") {
 | 
			
		||||
			utils.ReplyMessage(ws, "抱歉😔😔😔,系统已经没有可用的 API KEY,请联系管理员!")
 | 
			
		||||
			return nil
 | 
			
		||||
		} else {
 | 
			
		||||
			logger.Error(err)
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		utils.ReplyMessage(ws, ErrorMsg)
 | 
			
		||||
		utils.ReplyMessage(ws, "")
 | 
			
		||||
		return err
 | 
			
		||||
	} else {
 | 
			
		||||
		defer response.Body.Close()
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	contentType := response.Header.Get("Content-Type")
 | 
			
		||||
	if strings.Contains(contentType, "text/event-stream") {
 | 
			
		||||
		replyCreatedAt := time.Now() // 记录回复时间
 | 
			
		||||
		// 循环读取 Chunk 消息
 | 
			
		||||
		var message = types.Message{}
 | 
			
		||||
		var contents = make([]string, 0)
 | 
			
		||||
		var event, content string
 | 
			
		||||
		scanner := bufio.NewScanner(response.Body)
 | 
			
		||||
		for scanner.Scan() {
 | 
			
		||||
			line := scanner.Text()
 | 
			
		||||
			if len(line) < 5 || strings.HasPrefix(line, "id:") {
 | 
			
		||||
				continue
 | 
			
		||||
			}
 | 
			
		||||
			if strings.HasPrefix(line, "event:") {
 | 
			
		||||
				event = line[6:]
 | 
			
		||||
				continue
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			if strings.HasPrefix(line, "data:") {
 | 
			
		||||
				content = line[5:]
 | 
			
		||||
			}
 | 
			
		||||
			switch event {
 | 
			
		||||
			case "add":
 | 
			
		||||
				if len(contents) == 0 {
 | 
			
		||||
					utils.ReplyChunkMessage(ws, types.WsMessage{Type: types.WsStart})
 | 
			
		||||
				}
 | 
			
		||||
				utils.ReplyChunkMessage(ws, types.WsMessage{
 | 
			
		||||
					Type:    types.WsMiddle,
 | 
			
		||||
					Content: utils.InterfaceToString(content),
 | 
			
		||||
				})
 | 
			
		||||
				contents = append(contents, content)
 | 
			
		||||
			case "finish":
 | 
			
		||||
				break
 | 
			
		||||
			case "error":
 | 
			
		||||
				utils.ReplyMessage(ws, fmt.Sprintf("**调用 ChatGLM API 出错:%s**", content))
 | 
			
		||||
				break
 | 
			
		||||
			case "interrupted":
 | 
			
		||||
				utils.ReplyMessage(ws, "**调用 ChatGLM API 出错,当前输出被中断!**")
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
		} // end for
 | 
			
		||||
 | 
			
		||||
		if err := scanner.Err(); err != nil {
 | 
			
		||||
			if strings.Contains(err.Error(), "context canceled") {
 | 
			
		||||
				logger.Info("用户取消了请求:", prompt)
 | 
			
		||||
			} else {
 | 
			
		||||
				logger.Error("信息读取出错:", err)
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		// 消息发送成功
 | 
			
		||||
		if len(contents) > 0 {
 | 
			
		||||
			// 更新用户的对话次数
 | 
			
		||||
			if userVo.ChatConfig.ApiKeys[session.Model.Platform] == "" {
 | 
			
		||||
				h.db.Model(&model.User{}).Where("id = ?", userVo.Id).UpdateColumn("calls", gorm.Expr("calls - ?", 1))
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			if message.Role == "" {
 | 
			
		||||
				message.Role = "assistant"
 | 
			
		||||
			}
 | 
			
		||||
			message.Content = strings.Join(contents, "")
 | 
			
		||||
			useMsg := types.Message{Role: "user", Content: prompt}
 | 
			
		||||
 | 
			
		||||
			// 更新上下文消息,如果是调用函数则不需要更新上下文
 | 
			
		||||
			if h.App.ChatConfig.EnableContext {
 | 
			
		||||
				chatCtx = append(chatCtx, useMsg)  // 提问消息
 | 
			
		||||
				chatCtx = append(chatCtx, message) // 回复消息
 | 
			
		||||
				h.App.ChatContexts.Put(session.ChatId, chatCtx)
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			// 追加聊天记录
 | 
			
		||||
			if h.App.ChatConfig.EnableHistory {
 | 
			
		||||
				// for prompt
 | 
			
		||||
				promptToken, err := utils.CalcTokens(prompt, req.Model)
 | 
			
		||||
				if err != nil {
 | 
			
		||||
					logger.Error(err)
 | 
			
		||||
				}
 | 
			
		||||
				historyUserMsg := model.HistoryMessage{
 | 
			
		||||
					UserId:     userVo.Id,
 | 
			
		||||
					ChatId:     session.ChatId,
 | 
			
		||||
					RoleId:     role.Id,
 | 
			
		||||
					Type:       types.PromptMsg,
 | 
			
		||||
					Icon:       userVo.Avatar,
 | 
			
		||||
					Content:    prompt,
 | 
			
		||||
					Tokens:     promptToken,
 | 
			
		||||
					UseContext: true,
 | 
			
		||||
				}
 | 
			
		||||
				historyUserMsg.CreatedAt = promptCreatedAt
 | 
			
		||||
				historyUserMsg.UpdatedAt = promptCreatedAt
 | 
			
		||||
				res := h.db.Save(&historyUserMsg)
 | 
			
		||||
				if res.Error != nil {
 | 
			
		||||
					logger.Error("failed to save prompt history message: ", res.Error)
 | 
			
		||||
				}
 | 
			
		||||
 | 
			
		||||
				// for reply
 | 
			
		||||
				// 计算本次对话消耗的总 token 数量
 | 
			
		||||
				replyToken, _ := utils.CalcTokens(message.Content, req.Model)
 | 
			
		||||
				totalTokens := replyToken + getTotalTokens(req)
 | 
			
		||||
				historyReplyMsg := model.HistoryMessage{
 | 
			
		||||
					UserId:     userVo.Id,
 | 
			
		||||
					ChatId:     session.ChatId,
 | 
			
		||||
					RoleId:     role.Id,
 | 
			
		||||
					Type:       types.ReplyMsg,
 | 
			
		||||
					Icon:       role.Icon,
 | 
			
		||||
					Content:    message.Content,
 | 
			
		||||
					Tokens:     totalTokens,
 | 
			
		||||
					UseContext: true,
 | 
			
		||||
				}
 | 
			
		||||
				historyReplyMsg.CreatedAt = replyCreatedAt
 | 
			
		||||
				historyReplyMsg.UpdatedAt = replyCreatedAt
 | 
			
		||||
				res = h.db.Create(&historyReplyMsg)
 | 
			
		||||
				if res.Error != nil {
 | 
			
		||||
					logger.Error("failed to save reply history message: ", res.Error)
 | 
			
		||||
				}
 | 
			
		||||
				// 更新用户信息
 | 
			
		||||
				h.db.Model(&model.User{}).Where("id = ?", userVo.Id).
 | 
			
		||||
					UpdateColumn("total_tokens", gorm.Expr("total_tokens + ?", totalTokens))
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			// 保存当前会话
 | 
			
		||||
			var chatItem model.ChatItem
 | 
			
		||||
			res := h.db.Where("chat_id = ?", session.ChatId).First(&chatItem)
 | 
			
		||||
			if res.Error != nil {
 | 
			
		||||
				chatItem.ChatId = session.ChatId
 | 
			
		||||
				chatItem.UserId = session.UserId
 | 
			
		||||
				chatItem.RoleId = role.Id
 | 
			
		||||
				chatItem.ModelId = session.Model.Id
 | 
			
		||||
				if utf8.RuneCountInString(prompt) > 30 {
 | 
			
		||||
					chatItem.Title = string([]rune(prompt)[:30]) + "..."
 | 
			
		||||
				} else {
 | 
			
		||||
					chatItem.Title = prompt
 | 
			
		||||
				}
 | 
			
		||||
				h.db.Create(&chatItem)
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	} else {
 | 
			
		||||
		body, err := io.ReadAll(response.Body)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return fmt.Errorf("error with reading response: %v", err)
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		var res struct {
 | 
			
		||||
			Code    int    `json:"code"`
 | 
			
		||||
			Success bool   `json:"success"`
 | 
			
		||||
			Msg     string `json:"msg"`
 | 
			
		||||
		}
 | 
			
		||||
		err = json.Unmarshal(body, &res)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return fmt.Errorf("error with decode response: %v", err)
 | 
			
		||||
		}
 | 
			
		||||
		if !res.Success {
 | 
			
		||||
			utils.ReplyMessage(ws, "请求 ChatGLM 失败:"+res.Msg)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (h *ChatHandler) getBaiduToken(apiKey string) (string, error) {
 | 
			
		||||
	ctx := context.Background()
 | 
			
		||||
	tokenString, err := h.redis.Get(ctx, apiKey).Result()
 | 
			
		||||
	if err == nil {
 | 
			
		||||
		return tokenString, nil
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	expr := time.Hour * 2
 | 
			
		||||
	key := strings.Split(apiKey, ".")
 | 
			
		||||
	if len(key) != 2 {
 | 
			
		||||
		return "", fmt.Errorf("invalid api key: %s", apiKey)
 | 
			
		||||
	}
 | 
			
		||||
	token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
 | 
			
		||||
		"api_key":   key[0],
 | 
			
		||||
		"timestamp": time.Now().Unix(),
 | 
			
		||||
		"exp":       time.Now().Add(expr).Add(time.Second * 10).Unix(),
 | 
			
		||||
	})
 | 
			
		||||
	token.Header["alg"] = "HS256"
 | 
			
		||||
	token.Header["sign_type"] = "SIGN"
 | 
			
		||||
	delete(token.Header, "typ")
 | 
			
		||||
	// Sign and get the complete encoded token as a string using the secret
 | 
			
		||||
	tokenString, err = token.SignedString([]byte(key[1]))
 | 
			
		||||
	h.redis.Set(ctx, apiKey, tokenString, expr)
 | 
			
		||||
	return tokenString, err
 | 
			
		||||
}
 | 
			
		||||
@@ -196,7 +196,10 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio
 | 
			
		||||
		req.Temperature = h.App.ChatConfig.ChatGML.Temperature
 | 
			
		||||
		req.MaxTokens = h.App.ChatConfig.ChatGML.MaxTokens
 | 
			
		||||
		break
 | 
			
		||||
	default:
 | 
			
		||||
	case types.Baidu:
 | 
			
		||||
		req.Temperature = h.App.ChatConfig.OpenAI.Temperature
 | 
			
		||||
		// TODO: 目前只支持 ERNIE-Bot-turbo 模型,如果是 ERNIE-Bot 模型则需要增加函数支持
 | 
			
		||||
	case types.OpenAI:
 | 
			
		||||
		req.Temperature = h.App.ChatConfig.OpenAI.Temperature
 | 
			
		||||
		req.MaxTokens = h.App.ChatConfig.OpenAI.MaxTokens
 | 
			
		||||
		var functions = make([]types.Function, 0)
 | 
			
		||||
@@ -207,6 +210,10 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio
 | 
			
		||||
			functions = append(functions, f)
 | 
			
		||||
		}
 | 
			
		||||
		req.Functions = functions
 | 
			
		||||
	default:
 | 
			
		||||
		utils.ReplyMessage(ws, "不支持的平台:"+session.Model.Platform+",请联系管理员!")
 | 
			
		||||
		utils.ReplyMessage(ws, "")
 | 
			
		||||
		return nil
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// 加载聊天上下文
 | 
			
		||||
 
 | 
			
		||||
@@ -39,12 +39,16 @@
 | 
			
		||||
    <el-dialog
 | 
			
		||||
        v-model="showDialog"
 | 
			
		||||
        :title="title"
 | 
			
		||||
        style="width: 90%; max-width: 600px;"
 | 
			
		||||
    >
 | 
			
		||||
      <el-alert title="注意:如果是百度文心一言平台,需要用竖线(|)将 API Key 和 Secret Key 串接起来填入!"
 | 
			
		||||
                type="warning"
 | 
			
		||||
                :closable="false"
 | 
			
		||||
                show-icon
 | 
			
		||||
                style="margin-bottom: 10px; font-size:14px;"/>
 | 
			
		||||
      <el-form :model="item" label-width="120px" ref="formRef" :rules="rules">
 | 
			
		||||
        <el-form-item label="所属平台:" prop="platform">
 | 
			
		||||
          <el-select v-model="item.platform" placeholder="请选择平台">
 | 
			
		||||
            <el-option v-for="item in platforms" :value="item" :key="item">{{ item }}</el-option>
 | 
			
		||||
            <el-option v-for="item in platforms" :value="item.value" :key="item.value">{{ item.name }}</el-option>
 | 
			
		||||
          </el-select>
 | 
			
		||||
        </el-form-item>
 | 
			
		||||
 | 
			
		||||
@@ -82,7 +86,13 @@ const rules = reactive({
 | 
			
		||||
const loading = ref(true)
 | 
			
		||||
const formRef = ref(null)
 | 
			
		||||
const title = ref("")
 | 
			
		||||
const platforms = ref(["Azure", "OpenAI", "ChatGLM"])
 | 
			
		||||
const platforms = ref([
 | 
			
		||||
  {name: "【清华智普】ChatGLM", value: "ChatGLM"},
 | 
			
		||||
  {name: "【百度】文心一言", value: "Baidu"},
 | 
			
		||||
  {name: "【微软】Azure", value: "Azure"},
 | 
			
		||||
  {name: "【OpenAI】ChatGPT", value: "OpenAI"},
 | 
			
		||||
 | 
			
		||||
])
 | 
			
		||||
 | 
			
		||||
// 获取数据
 | 
			
		||||
httpGet('/api/admin/apikey/list').then((res) => {
 | 
			
		||||
 
 | 
			
		||||
@@ -90,13 +90,25 @@
 | 
			
		||||
          <el-input v-model="chat['chat_gml']['api_url']" placeholder="支持变量,{model} => 模型名称"/>
 | 
			
		||||
        </el-form-item>
 | 
			
		||||
        <el-form-item label="模型创意度">
 | 
			
		||||
          <el-slider v-model="chat['chat_gml']['temperature']" :max="2" :step="0.1"/>
 | 
			
		||||
          <el-slider v-model="chat['chat_gml']['temperature']" :max="1" :step="0.01"/>
 | 
			
		||||
          <div class="tip">值越大 AI 回答越发散,值越小回答越保守,建议保持默认值</div>
 | 
			
		||||
        </el-form-item>
 | 
			
		||||
        <el-form-item label="最大响应长度">
 | 
			
		||||
          <el-input v-model.number="chat['chat_gml']['max_tokens']" placeholder="回复的最大字数,最大4096"/>
 | 
			
		||||
        </el-form-item>
 | 
			
		||||
 | 
			
		||||
        <el-divider content-position="center">文心一言</el-divider>
 | 
			
		||||
        <el-form-item label="API 地址" prop="baidu.api_url">
 | 
			
		||||
          <el-input v-model="chat['baidu']['api_url']" placeholder="支持变量,{model} => 模型名称"/>
 | 
			
		||||
        </el-form-item>
 | 
			
		||||
        <el-form-item label="模型创意度">
 | 
			
		||||
          <el-slider v-model="chat['baidu']['temperature']" :max="1" :step="0.01"/>
 | 
			
		||||
          <div class="tip">值越大 AI 回答越发散,值越小回答越保守,建议保持默认值</div>
 | 
			
		||||
        </el-form-item>
 | 
			
		||||
        <el-form-item label="最大响应长度">
 | 
			
		||||
          <el-input v-model.number="chat['baidu']['max_tokens']" placeholder="回复的最大字数,最大4096"/>
 | 
			
		||||
        </el-form-item>
 | 
			
		||||
 | 
			
		||||
        <el-form-item style="text-align: right">
 | 
			
		||||
          <el-button type="primary" @click="save('chat')">保存</el-button>
 | 
			
		||||
        </el-form-item>
 | 
			
		||||
@@ -116,7 +128,8 @@ const system = ref({models: []})
 | 
			
		||||
const chat = ref({
 | 
			
		||||
  open_ai: {api_url: "", temperature: 1, max_tokens: 1024},
 | 
			
		||||
  azure: {api_url: "", temperature: 1, max_tokens: 1024},
 | 
			
		||||
  chat_gml: {api_url: "", temperature: 1, max_tokens: 1024},
 | 
			
		||||
  chat_gml: {api_url: "", temperature: 0.95, max_tokens: 1024},
 | 
			
		||||
  baidu: {api_url: "", temperature: 0.95, max_tokens: 1024},
 | 
			
		||||
  context_deep: 0,
 | 
			
		||||
  enable_context: true,
 | 
			
		||||
  enable_history: true,
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user