mirror of
https://github.com/yangjian102621/geekai.git
synced 2025-09-18 01:06:39 +08:00
finish baidu ai model api implementation
This commit is contained in:
parent
4fc01f3f7b
commit
ba206bb387
@ -67,4 +67,8 @@ var ModelToTokens = map[string]int{
|
|||||||
"gpt-3.5-turbo-16k": 16384,
|
"gpt-3.5-turbo-16k": 16384,
|
||||||
"gpt-4": 8192,
|
"gpt-4": 8192,
|
||||||
"gpt-4-32k": 32768,
|
"gpt-4-32k": 32768,
|
||||||
|
"chatglm_pro": 32768,
|
||||||
|
"chatglm_std": 16384,
|
||||||
|
"chatglm_lite": 4096,
|
||||||
|
"ernie_bot_turbo": 8192, // 文心一言
|
||||||
}
|
}
|
||||||
|
@ -79,6 +79,7 @@ type ChatConfig struct {
|
|||||||
OpenAI ModelAPIConfig `json:"open_ai"`
|
OpenAI ModelAPIConfig `json:"open_ai"`
|
||||||
Azure ModelAPIConfig `json:"azure"`
|
Azure ModelAPIConfig `json:"azure"`
|
||||||
ChatGML ModelAPIConfig `json:"chat_gml"`
|
ChatGML ModelAPIConfig `json:"chat_gml"`
|
||||||
|
Baidu ModelAPIConfig `json:"baidu"`
|
||||||
|
|
||||||
EnableContext bool `json:"enable_context"` // 是否开启聊天上下文
|
EnableContext bool `json:"enable_context"` // 是否开启聊天上下文
|
||||||
EnableHistory bool `json:"enable_history"` // 是否允许保存聊天记录
|
EnableHistory bool `json:"enable_history"` // 是否允许保存聊天记录
|
||||||
|
@ -36,11 +36,11 @@ func (h *ApiKeyHandler) Save(c *gin.Context) {
|
|||||||
|
|
||||||
apiKey := model.ApiKey{}
|
apiKey := model.ApiKey{}
|
||||||
if data.Id > 0 {
|
if data.Id > 0 {
|
||||||
h.db.Find(&apiKey)
|
h.db.Find(&apiKey, data.Id)
|
||||||
}
|
}
|
||||||
apiKey.Platform = data.Platform
|
apiKey.Platform = data.Platform
|
||||||
apiKey.Value = data.Value
|
apiKey.Value = data.Value
|
||||||
res := h.db.Save(&apiKey)
|
res := h.db.Debug().Save(&apiKey)
|
||||||
if res.Error != nil {
|
if res.Error != nil {
|
||||||
resp.ERROR(c, "更新数据库失败!")
|
resp.ERROR(c, "更新数据库失败!")
|
||||||
return
|
return
|
||||||
|
@ -9,14 +9,30 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/golang-jwt/jwt/v5"
|
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
"io"
|
"io"
|
||||||
|
"net/http"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
"unicode/utf8"
|
"unicode/utf8"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type baiduResp struct {
|
||||||
|
Id string `json:"id"`
|
||||||
|
Object string `json:"object"`
|
||||||
|
Created int `json:"created"`
|
||||||
|
SentenceId int `json:"sentence_id"`
|
||||||
|
IsEnd bool `json:"is_end"`
|
||||||
|
IsTruncated bool `json:"is_truncated"`
|
||||||
|
Result string `json:"result"`
|
||||||
|
NeedClearHistory bool `json:"need_clear_history"`
|
||||||
|
Usage struct {
|
||||||
|
PromptTokens int `json:"prompt_tokens"`
|
||||||
|
CompletionTokens int `json:"completion_tokens"`
|
||||||
|
TotalTokens int `json:"total_tokens"`
|
||||||
|
} `json:"usage"`
|
||||||
|
}
|
||||||
|
|
||||||
// 将消息发送给百度文心一言大模型 API 并获取结果,通过 WebSocket 推送到客户端
|
// 将消息发送给百度文心一言大模型 API 并获取结果,通过 WebSocket 推送到客户端
|
||||||
func (h *ChatHandler) sendBaiduMessage(
|
func (h *ChatHandler) sendBaiduMessage(
|
||||||
chatCtx []interface{},
|
chatCtx []interface{},
|
||||||
@ -56,38 +72,42 @@ func (h *ChatHandler) sendBaiduMessage(
|
|||||||
// 循环读取 Chunk 消息
|
// 循环读取 Chunk 消息
|
||||||
var message = types.Message{}
|
var message = types.Message{}
|
||||||
var contents = make([]string, 0)
|
var contents = make([]string, 0)
|
||||||
var event, content string
|
var content string
|
||||||
scanner := bufio.NewScanner(response.Body)
|
scanner := bufio.NewScanner(response.Body)
|
||||||
for scanner.Scan() {
|
for scanner.Scan() {
|
||||||
line := scanner.Text()
|
line := scanner.Text()
|
||||||
if len(line) < 5 || strings.HasPrefix(line, "id:") {
|
if len(line) < 5 || strings.HasPrefix(line, "id:") {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
if strings.HasPrefix(line, "event:") {
|
|
||||||
event = line[6:]
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
if strings.HasPrefix(line, "data:") {
|
if strings.HasPrefix(line, "data:") {
|
||||||
content = line[5:]
|
content = line[5:]
|
||||||
}
|
}
|
||||||
switch event {
|
|
||||||
case "add":
|
var resp baiduResp
|
||||||
|
err := utils.JsonDecode(content, &resp)
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("error with parse data line: ", err)
|
||||||
|
utils.ReplyMessage(ws, fmt.Sprintf("**解析数据行失败:%s**", err))
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
if len(contents) == 0 {
|
if len(contents) == 0 {
|
||||||
utils.ReplyChunkMessage(ws, types.WsMessage{Type: types.WsStart})
|
utils.ReplyChunkMessage(ws, types.WsMessage{Type: types.WsStart})
|
||||||
}
|
}
|
||||||
utils.ReplyChunkMessage(ws, types.WsMessage{
|
utils.ReplyChunkMessage(ws, types.WsMessage{
|
||||||
Type: types.WsMiddle,
|
Type: types.WsMiddle,
|
||||||
Content: utils.InterfaceToString(content),
|
Content: utils.InterfaceToString(resp.Result),
|
||||||
})
|
})
|
||||||
contents = append(contents, content)
|
contents = append(contents, resp.Result)
|
||||||
case "finish":
|
|
||||||
|
if resp.IsTruncated {
|
||||||
|
utils.ReplyMessage(ws, "AI 输出异常中断")
|
||||||
break
|
break
|
||||||
case "error":
|
}
|
||||||
utils.ReplyMessage(ws, fmt.Sprintf("**调用 ChatGLM API 出错:%s**", content))
|
|
||||||
|
if resp.IsEnd {
|
||||||
break
|
break
|
||||||
case "interrupted":
|
|
||||||
utils.ReplyMessage(ws, "**调用 ChatGLM API 出错,当前输出被中断!**")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
} // end for
|
} // end for
|
||||||
@ -192,17 +212,14 @@ func (h *ChatHandler) sendBaiduMessage(
|
|||||||
}
|
}
|
||||||
|
|
||||||
var res struct {
|
var res struct {
|
||||||
Code int `json:"code"`
|
Code int `json:"error_code"`
|
||||||
Success bool `json:"success"`
|
Msg string `json:"error_msg"`
|
||||||
Msg string `json:"msg"`
|
|
||||||
}
|
}
|
||||||
err = json.Unmarshal(body, &res)
|
err = json.Unmarshal(body, &res)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("error with decode response: %v", err)
|
return fmt.Errorf("error with decode response: %v", err)
|
||||||
}
|
}
|
||||||
if !res.Success {
|
utils.ReplyMessage(ws, "请求百度文心大模型 API 失败:"+res.Msg)
|
||||||
utils.ReplyMessage(ws, "请求 ChatGLM 失败:"+res.Msg)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
@ -215,21 +232,41 @@ func (h *ChatHandler) getBaiduToken(apiKey string) (string, error) {
|
|||||||
return tokenString, nil
|
return tokenString, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
expr := time.Hour * 2
|
expr := time.Hour * 24 * 20 // access_token 有效期
|
||||||
key := strings.Split(apiKey, ".")
|
key := strings.Split(apiKey, "|")
|
||||||
if len(key) != 2 {
|
if len(key) != 2 {
|
||||||
return "", fmt.Errorf("invalid api key: %s", apiKey)
|
return "", fmt.Errorf("invalid api key: %s", apiKey)
|
||||||
}
|
}
|
||||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
|
url := fmt.Sprintf("https://aip.baidubce.com/oauth/2.0/token?client_id=%s&client_secret=%s&grant_type=client_credentials", key[0], key[1])
|
||||||
"api_key": key[0],
|
client := &http.Client{}
|
||||||
"timestamp": time.Now().Unix(),
|
req, err := http.NewRequest("POST", url, nil)
|
||||||
"exp": time.Now().Add(expr).Add(time.Second * 10).Unix(),
|
if err != nil {
|
||||||
})
|
return "", err
|
||||||
token.Header["alg"] = "HS256"
|
}
|
||||||
token.Header["sign_type"] = "SIGN"
|
req.Header.Add("Content-Type", "application/json")
|
||||||
delete(token.Header, "typ")
|
req.Header.Add("Accept", "application/json")
|
||||||
// Sign and get the complete encoded token as a string using the secret
|
|
||||||
tokenString, err = token.SignedString([]byte(key[1]))
|
res, err := client.Do(req)
|
||||||
h.redis.Set(ctx, apiKey, tokenString, expr)
|
if err != nil {
|
||||||
return tokenString, err
|
return "", fmt.Errorf("error with send request: %w", err)
|
||||||
|
}
|
||||||
|
defer res.Body.Close()
|
||||||
|
|
||||||
|
body, err := io.ReadAll(res.Body)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("error with read response: %w", err)
|
||||||
|
}
|
||||||
|
var r map[string]interface{}
|
||||||
|
err = json.Unmarshal(body, &r)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("error with parse response: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if r["error"] != nil {
|
||||||
|
return "", fmt.Errorf("error with api response: %s", r["error_description"])
|
||||||
|
}
|
||||||
|
|
||||||
|
tokenString = fmt.Sprintf("%s", r["access_token"])
|
||||||
|
h.redis.Set(ctx, apiKey, tokenString, expr)
|
||||||
|
return tokenString, nil
|
||||||
}
|
}
|
||||||
|
@ -202,6 +202,7 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio
|
|||||||
case types.OpenAI:
|
case types.OpenAI:
|
||||||
req.Temperature = h.App.ChatConfig.OpenAI.Temperature
|
req.Temperature = h.App.ChatConfig.OpenAI.Temperature
|
||||||
req.MaxTokens = h.App.ChatConfig.OpenAI.MaxTokens
|
req.MaxTokens = h.App.ChatConfig.OpenAI.MaxTokens
|
||||||
|
// OpenAI 支持函数功能
|
||||||
var functions = make([]types.Function, 0)
|
var functions = make([]types.Function, 0)
|
||||||
for _, f := range types.InnerFunctions {
|
for _, f := range types.InnerFunctions {
|
||||||
if !h.App.SysConfig.EnabledDraw && f.Name == types.FuncMidJourney {
|
if !h.App.SysConfig.EnabledDraw && f.Name == types.FuncMidJourney {
|
||||||
@ -281,6 +282,9 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio
|
|||||||
return h.sendOpenAiMessage(chatCtx, req, userVo, ctx, session, role, prompt, ws)
|
return h.sendOpenAiMessage(chatCtx, req, userVo, ctx, session, role, prompt, ws)
|
||||||
case types.ChatGLM:
|
case types.ChatGLM:
|
||||||
return h.sendChatGLMMessage(chatCtx, req, userVo, ctx, session, role, prompt, ws)
|
return h.sendChatGLMMessage(chatCtx, req, userVo, ctx, session, role, prompt, ws)
|
||||||
|
case types.Baidu:
|
||||||
|
return h.sendBaiduMessage(chatCtx, req, userVo, ctx, session, role, prompt, ws)
|
||||||
|
|
||||||
}
|
}
|
||||||
utils.ReplyChunkMessage(ws, types.WsMessage{
|
utils.ReplyChunkMessage(ws, types.WsMessage{
|
||||||
Type: types.WsMiddle,
|
Type: types.WsMiddle,
|
||||||
@ -364,12 +368,36 @@ func (h *ChatHandler) doRequest(ctx context.Context, req types.ApiRequest, platf
|
|||||||
break
|
break
|
||||||
case types.ChatGLM:
|
case types.ChatGLM:
|
||||||
apiURL = strings.Replace(h.App.ChatConfig.ChatGML.ApiURL, "{model}", req.Model, 1)
|
apiURL = strings.Replace(h.App.ChatConfig.ChatGML.ApiURL, "{model}", req.Model, 1)
|
||||||
req.Prompt = req.Messages
|
req.Prompt = req.Messages // 使用 prompt 字段替代 message 字段
|
||||||
req.Messages = nil
|
req.Messages = nil
|
||||||
break
|
break
|
||||||
|
case types.Baidu:
|
||||||
|
apiURL = h.App.ChatConfig.Baidu.ApiURL
|
||||||
|
break
|
||||||
default:
|
default:
|
||||||
apiURL = h.App.ChatConfig.OpenAI.ApiURL
|
apiURL = h.App.ChatConfig.OpenAI.ApiURL
|
||||||
}
|
}
|
||||||
|
if *apiKey == "" {
|
||||||
|
var key model.ApiKey
|
||||||
|
res := h.db.Where("platform = ?", platform).Order("last_used_at ASC").First(&key)
|
||||||
|
if res.Error != nil {
|
||||||
|
return nil, errors.New("no available key, please import key")
|
||||||
|
}
|
||||||
|
// 更新 API KEY 的最后使用时间
|
||||||
|
h.db.Model(&key).UpdateColumn("last_used_at", time.Now().Unix())
|
||||||
|
*apiKey = key.Value
|
||||||
|
}
|
||||||
|
|
||||||
|
// 百度文心,需要串接 access_token
|
||||||
|
if platform == types.Baidu {
|
||||||
|
token, err := h.getBaiduToken(*apiKey)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
logger.Info("百度文心 Access_Token:", token)
|
||||||
|
apiURL = fmt.Sprintf("%s?access_token=%s", apiURL, token)
|
||||||
|
}
|
||||||
|
|
||||||
// 创建 HttpClient 请求对象
|
// 创建 HttpClient 请求对象
|
||||||
var client *http.Client
|
var client *http.Client
|
||||||
requestBody, err := json.Marshal(req)
|
requestBody, err := json.Marshal(req)
|
||||||
@ -394,17 +422,6 @@ func (h *ChatHandler) doRequest(ctx context.Context, req types.ApiRequest, platf
|
|||||||
} else {
|
} else {
|
||||||
client = http.DefaultClient
|
client = http.DefaultClient
|
||||||
}
|
}
|
||||||
if *apiKey == "" {
|
|
||||||
var key model.ApiKey
|
|
||||||
res := h.db.Where("platform = ?", platform).Order("last_used_at ASC").First(&key)
|
|
||||||
if res.Error != nil {
|
|
||||||
return nil, errors.New("no available key, please import key")
|
|
||||||
}
|
|
||||||
// 更新 API KEY 的最后使用时间
|
|
||||||
h.db.Model(&key).UpdateColumn("last_used_at", time.Now().Unix())
|
|
||||||
*apiKey = key.Value
|
|
||||||
}
|
|
||||||
|
|
||||||
logger.Infof("Sending %s request, KEY: %s, PROXY: %s, Model: %s", platform, *apiKey, proxyURL, req.Model)
|
logger.Infof("Sending %s request, KEY: %s, PROXY: %s, Model: %s", platform, *apiKey, proxyURL, req.Model)
|
||||||
switch platform {
|
switch platform {
|
||||||
case types.Azure:
|
case types.Azure:
|
||||||
@ -418,7 +435,9 @@ func (h *ChatHandler) doRequest(ctx context.Context, req types.ApiRequest, platf
|
|||||||
logger.Info(token)
|
logger.Info(token)
|
||||||
request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token))
|
request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token))
|
||||||
break
|
break
|
||||||
default:
|
case types.Baidu:
|
||||||
|
request.RequestURI = ""
|
||||||
|
case types.OpenAI:
|
||||||
request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiKey))
|
request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiKey))
|
||||||
}
|
}
|
||||||
return client.Do(request)
|
return client.Do(request)
|
||||||
|
@ -1,14 +1,55 @@
|
|||||||
package main
|
package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"os"
|
"io"
|
||||||
|
"log"
|
||||||
|
"net/http"
|
||||||
)
|
)
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
bytes, err := os.ReadFile("res/text2img.json")
|
apiKey := "qjvqGdqpTY7qQaGBMenM7XgQ"
|
||||||
|
apiSecret := "3G1RzBGXywZv4VbYRTyAfNns1vIOAG8t"
|
||||||
|
token, err := getBaiduToken(apiKey, apiSecret)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic(err)
|
log.Fatal(err)
|
||||||
}
|
}
|
||||||
fmt.Println(string(bytes))
|
|
||||||
|
fmt.Println(token)
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
func getBaiduToken(apiKey string, apiSecret string) (string, error) {
|
||||||
|
|
||||||
|
url := fmt.Sprintf("https://aip.baidubce.com/oauth/2.0/token?client_id=%s&client_secret=%s&grant_type=client_credentials", apiKey, apiSecret)
|
||||||
|
client := &http.Client{}
|
||||||
|
req, err := http.NewRequest("POST", url, nil)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
req.Header.Add("Content-Type", "application/json")
|
||||||
|
req.Header.Add("Accept", "application/json")
|
||||||
|
|
||||||
|
res, err := client.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("error with send request: %w", err)
|
||||||
|
}
|
||||||
|
defer res.Body.Close()
|
||||||
|
|
||||||
|
body, err := io.ReadAll(res.Body)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("error with read response: %w", err)
|
||||||
|
}
|
||||||
|
var r map[string]interface{}
|
||||||
|
err = json.Unmarshal(body, &r)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("error with parse response: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if r["error"] != nil {
|
||||||
|
return "", fmt.Errorf("error with api response: %s", r["error_description"])
|
||||||
|
}
|
||||||
|
|
||||||
|
return fmt.Sprintf("%s", r["access_token"]), nil
|
||||||
}
|
}
|
||||||
|
@ -91,7 +91,6 @@ const platforms = ref([
|
|||||||
{name: "【百度】文心一言", value: "Baidu"},
|
{name: "【百度】文心一言", value: "Baidu"},
|
||||||
{name: "【微软】Azure", value: "Azure"},
|
{name: "【微软】Azure", value: "Azure"},
|
||||||
{name: "【OpenAI】ChatGPT", value: "OpenAI"},
|
{name: "【OpenAI】ChatGPT", value: "OpenAI"},
|
||||||
|
|
||||||
])
|
])
|
||||||
|
|
||||||
// 获取数据
|
// 获取数据
|
||||||
|
@ -47,7 +47,7 @@
|
|||||||
<el-form :model="item" label-width="120px" ref="formRef" :rules="rules">
|
<el-form :model="item" label-width="120px" ref="formRef" :rules="rules">
|
||||||
<el-form-item label="所属平台:" prop="platform">
|
<el-form-item label="所属平台:" prop="platform">
|
||||||
<el-select v-model="item.platform" placeholder="请选择平台">
|
<el-select v-model="item.platform" placeholder="请选择平台">
|
||||||
<el-option v-for="item in platforms" :value="item" :key="item">{{ item }}</el-option>
|
<el-option v-for="item in platforms" :value="item.value" :key="item.value">{{ item.name }}</el-option>
|
||||||
</el-select>
|
</el-select>
|
||||||
</el-form-item>
|
</el-form-item>
|
||||||
|
|
||||||
@ -94,7 +94,12 @@ const rules = reactive({
|
|||||||
})
|
})
|
||||||
const loading = ref(true)
|
const loading = ref(true)
|
||||||
const formRef = ref(null)
|
const formRef = ref(null)
|
||||||
const platforms = ref(["Azure", "OpenAI", "ChatGLM"])
|
const platforms = ref([
|
||||||
|
{name: "【清华智普】ChatGLM", value: "ChatGLM"},
|
||||||
|
{name: "【百度】文心一言", value: "Baidu"},
|
||||||
|
{name: "【微软】Azure", value: "Azure"},
|
||||||
|
{name: "【OpenAI】ChatGPT", value: "OpenAI"},
|
||||||
|
])
|
||||||
|
|
||||||
// 获取数据
|
// 获取数据
|
||||||
httpGet('/api/admin/model/list').then((res) => {
|
httpGet('/api/admin/model/list').then((res) => {
|
||||||
|
@ -158,6 +158,9 @@ onMounted(() => {
|
|||||||
if (res.data.chat_gml) {
|
if (res.data.chat_gml) {
|
||||||
chat.value.chat_gml = res.data.chat_gml
|
chat.value.chat_gml = res.data.chat_gml
|
||||||
}
|
}
|
||||||
|
if (res.data.baidu) {
|
||||||
|
chat.value.baidu = res.data.baidu
|
||||||
|
}
|
||||||
chat.value.context_deep = res.data.context_deep
|
chat.value.context_deep = res.data.context_deep
|
||||||
chat.value.enable_context = res.data.enable_context
|
chat.value.enable_context = res.data.enable_context
|
||||||
chat.value.enable_history = res.data.enable_history
|
chat.value.enable_history = res.data.enable_history
|
||||||
|
Loading…
Reference in New Issue
Block a user