finish baidu ai model api implementation

This commit is contained in:
RockYang 2023-10-11 14:21:16 +08:00
parent da38684cdd
commit c37902d3a4
9 changed files with 171 additions and 62 deletions

View File

@ -67,4 +67,8 @@ var ModelToTokens = map[string]int{
"gpt-3.5-turbo-16k": 16384, "gpt-3.5-turbo-16k": 16384,
"gpt-4": 8192, "gpt-4": 8192,
"gpt-4-32k": 32768, "gpt-4-32k": 32768,
"chatglm_pro": 32768,
"chatglm_std": 16384,
"chatglm_lite": 4096,
"ernie_bot_turbo": 8192, // 文心一言
} }

View File

@ -79,6 +79,7 @@ type ChatConfig struct {
OpenAI ModelAPIConfig `json:"open_ai"` OpenAI ModelAPIConfig `json:"open_ai"`
Azure ModelAPIConfig `json:"azure"` Azure ModelAPIConfig `json:"azure"`
ChatGML ModelAPIConfig `json:"chat_gml"` ChatGML ModelAPIConfig `json:"chat_gml"`
Baidu ModelAPIConfig `json:"baidu"`
EnableContext bool `json:"enable_context"` // 是否开启聊天上下文 EnableContext bool `json:"enable_context"` // 是否开启聊天上下文
EnableHistory bool `json:"enable_history"` // 是否允许保存聊天记录 EnableHistory bool `json:"enable_history"` // 是否允许保存聊天记录

View File

@ -36,11 +36,11 @@ func (h *ApiKeyHandler) Save(c *gin.Context) {
apiKey := model.ApiKey{} apiKey := model.ApiKey{}
if data.Id > 0 { if data.Id > 0 {
h.db.Find(&apiKey) h.db.Find(&apiKey, data.Id)
} }
apiKey.Platform = data.Platform apiKey.Platform = data.Platform
apiKey.Value = data.Value apiKey.Value = data.Value
res := h.db.Save(&apiKey) res := h.db.Debug().Save(&apiKey)
if res.Error != nil { if res.Error != nil {
resp.ERROR(c, "更新数据库失败!") resp.ERROR(c, "更新数据库失败!")
return return

View File

@ -9,14 +9,30 @@ import (
"context" "context"
"encoding/json" "encoding/json"
"fmt" "fmt"
"github.com/golang-jwt/jwt/v5"
"gorm.io/gorm" "gorm.io/gorm"
"io" "io"
"net/http"
"strings" "strings"
"time" "time"
"unicode/utf8" "unicode/utf8"
) )
// baiduResp models one streamed ("data:") chunk returned by the Baidu
// ERNIE-Bot (文心一言) chat completion API. The stream reader parses each
// data line into this struct, forwards Result to the client, and stops on
// IsEnd or IsTruncated.
type baiduResp struct {
Id string `json:"id"` // response identifier
Object string `json:"object"` // payload object type
Created int `json:"created"` // creation time — presumably a unix timestamp; confirm against API docs
SentenceId int `json:"sentence_id"` // ordinal of this chunk in the stream — TODO confirm semantics
IsEnd bool `json:"is_end"` // true on the final chunk; reader breaks out of the scan loop
IsTruncated bool `json:"is_truncated"` // true when output was cut off abnormally; reader reports an error
Result string `json:"result"` // incremental text content pushed to the WebSocket client
NeedClearHistory bool `json:"need_clear_history"` // server hint to discard chat context — not acted on here
Usage struct { // token accounting reported by the API
PromptTokens int `json:"prompt_tokens"` // tokens consumed by the prompt
CompletionTokens int `json:"completion_tokens"` // tokens in the generated completion
TotalTokens int `json:"total_tokens"` // prompt + completion
} `json:"usage"`
}
// 将消息发送给百度文心一言大模型 API 并获取结果,通过 WebSocket 推送到客户端 // 将消息发送给百度文心一言大模型 API 并获取结果,通过 WebSocket 推送到客户端
func (h *ChatHandler) sendBaiduMessage( func (h *ChatHandler) sendBaiduMessage(
chatCtx []interface{}, chatCtx []interface{},
@ -56,38 +72,42 @@ func (h *ChatHandler) sendBaiduMessage(
// 循环读取 Chunk 消息 // 循环读取 Chunk 消息
var message = types.Message{} var message = types.Message{}
var contents = make([]string, 0) var contents = make([]string, 0)
var event, content string var content string
scanner := bufio.NewScanner(response.Body) scanner := bufio.NewScanner(response.Body)
for scanner.Scan() { for scanner.Scan() {
line := scanner.Text() line := scanner.Text()
if len(line) < 5 || strings.HasPrefix(line, "id:") { if len(line) < 5 || strings.HasPrefix(line, "id:") {
continue continue
} }
if strings.HasPrefix(line, "event:") {
event = line[6:]
continue
}
if strings.HasPrefix(line, "data:") { if strings.HasPrefix(line, "data:") {
content = line[5:] content = line[5:]
} }
switch event {
case "add": var resp baiduResp
if len(contents) == 0 { err := utils.JsonDecode(content, &resp)
utils.ReplyChunkMessage(ws, types.WsMessage{Type: types.WsStart}) if err != nil {
} logger.Error("error with parse data line: ", err)
utils.ReplyChunkMessage(ws, types.WsMessage{ utils.ReplyMessage(ws, fmt.Sprintf("**解析数据行失败:%s**", err))
Type: types.WsMiddle,
Content: utils.InterfaceToString(content),
})
contents = append(contents, content)
case "finish":
break break
case "error": }
utils.ReplyMessage(ws, fmt.Sprintf("**调用 ChatGLM API 出错:%s**", content))
if len(contents) == 0 {
utils.ReplyChunkMessage(ws, types.WsMessage{Type: types.WsStart})
}
utils.ReplyChunkMessage(ws, types.WsMessage{
Type: types.WsMiddle,
Content: utils.InterfaceToString(resp.Result),
})
contents = append(contents, resp.Result)
if resp.IsTruncated {
utils.ReplyMessage(ws, "AI 输出异常中断")
break
}
if resp.IsEnd {
break break
case "interrupted":
utils.ReplyMessage(ws, "**调用 ChatGLM API 出错,当前输出被中断!**")
} }
} // end for } // end for
@ -192,17 +212,14 @@ func (h *ChatHandler) sendBaiduMessage(
} }
var res struct { var res struct {
Code int `json:"code"` Code int `json:"error_code"`
Success bool `json:"success"` Msg string `json:"error_msg"`
Msg string `json:"msg"`
} }
err = json.Unmarshal(body, &res) err = json.Unmarshal(body, &res)
if err != nil { if err != nil {
return fmt.Errorf("error with decode response: %v", err) return fmt.Errorf("error with decode response: %v", err)
} }
if !res.Success { utils.ReplyMessage(ws, "请求百度文心大模型 API 失败:"+res.Msg)
utils.ReplyMessage(ws, "请求 ChatGLM 失败:"+res.Msg)
}
} }
return nil return nil
@ -215,21 +232,41 @@ func (h *ChatHandler) getBaiduToken(apiKey string) (string, error) {
return tokenString, nil return tokenString, nil
} }
expr := time.Hour * 2 expr := time.Hour * 24 * 20 // access_token 有效期
key := strings.Split(apiKey, ".") key := strings.Split(apiKey, "|")
if len(key) != 2 { if len(key) != 2 {
return "", fmt.Errorf("invalid api key: %s", apiKey) return "", fmt.Errorf("invalid api key: %s", apiKey)
} }
token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{ url := fmt.Sprintf("https://aip.baidubce.com/oauth/2.0/token?client_id=%s&client_secret=%s&grant_type=client_credentials", key[0], key[1])
"api_key": key[0], client := &http.Client{}
"timestamp": time.Now().Unix(), req, err := http.NewRequest("POST", url, nil)
"exp": time.Now().Add(expr).Add(time.Second * 10).Unix(), if err != nil {
}) return "", err
token.Header["alg"] = "HS256" }
token.Header["sign_type"] = "SIGN" req.Header.Add("Content-Type", "application/json")
delete(token.Header, "typ") req.Header.Add("Accept", "application/json")
// Sign and get the complete encoded token as a string using the secret
tokenString, err = token.SignedString([]byte(key[1])) res, err := client.Do(req)
if err != nil {
return "", fmt.Errorf("error with send request: %w", err)
}
defer res.Body.Close()
body, err := io.ReadAll(res.Body)
if err != nil {
return "", fmt.Errorf("error with read response: %w", err)
}
var r map[string]interface{}
err = json.Unmarshal(body, &r)
if err != nil {
return "", fmt.Errorf("error with parse response: %w", err)
}
if r["error"] != nil {
return "", fmt.Errorf("error with api response: %s", r["error_description"])
}
tokenString = fmt.Sprintf("%s", r["access_token"])
h.redis.Set(ctx, apiKey, tokenString, expr) h.redis.Set(ctx, apiKey, tokenString, expr)
return tokenString, err return tokenString, nil
} }

View File

@ -202,6 +202,7 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio
case types.OpenAI: case types.OpenAI:
req.Temperature = h.App.ChatConfig.OpenAI.Temperature req.Temperature = h.App.ChatConfig.OpenAI.Temperature
req.MaxTokens = h.App.ChatConfig.OpenAI.MaxTokens req.MaxTokens = h.App.ChatConfig.OpenAI.MaxTokens
// OpenAI 支持函数功能
var functions = make([]types.Function, 0) var functions = make([]types.Function, 0)
for _, f := range types.InnerFunctions { for _, f := range types.InnerFunctions {
if !h.App.SysConfig.EnabledDraw && f.Name == types.FuncMidJourney { if !h.App.SysConfig.EnabledDraw && f.Name == types.FuncMidJourney {
@ -281,6 +282,9 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio
return h.sendOpenAiMessage(chatCtx, req, userVo, ctx, session, role, prompt, ws) return h.sendOpenAiMessage(chatCtx, req, userVo, ctx, session, role, prompt, ws)
case types.ChatGLM: case types.ChatGLM:
return h.sendChatGLMMessage(chatCtx, req, userVo, ctx, session, role, prompt, ws) return h.sendChatGLMMessage(chatCtx, req, userVo, ctx, session, role, prompt, ws)
case types.Baidu:
return h.sendBaiduMessage(chatCtx, req, userVo, ctx, session, role, prompt, ws)
} }
utils.ReplyChunkMessage(ws, types.WsMessage{ utils.ReplyChunkMessage(ws, types.WsMessage{
Type: types.WsMiddle, Type: types.WsMiddle,
@ -364,12 +368,36 @@ func (h *ChatHandler) doRequest(ctx context.Context, req types.ApiRequest, platf
break break
case types.ChatGLM: case types.ChatGLM:
apiURL = strings.Replace(h.App.ChatConfig.ChatGML.ApiURL, "{model}", req.Model, 1) apiURL = strings.Replace(h.App.ChatConfig.ChatGML.ApiURL, "{model}", req.Model, 1)
req.Prompt = req.Messages req.Prompt = req.Messages // 使用 prompt 字段替代 message 字段
req.Messages = nil req.Messages = nil
break break
case types.Baidu:
apiURL = h.App.ChatConfig.Baidu.ApiURL
break
default: default:
apiURL = h.App.ChatConfig.OpenAI.ApiURL apiURL = h.App.ChatConfig.OpenAI.ApiURL
} }
if *apiKey == "" {
var key model.ApiKey
res := h.db.Where("platform = ?", platform).Order("last_used_at ASC").First(&key)
if res.Error != nil {
return nil, errors.New("no available key, please import key")
}
// 更新 API KEY 的最后使用时间
h.db.Model(&key).UpdateColumn("last_used_at", time.Now().Unix())
*apiKey = key.Value
}
// 百度文心,需要串接 access_token
if platform == types.Baidu {
token, err := h.getBaiduToken(*apiKey)
if err != nil {
return nil, err
}
logger.Info("百度文心 Access_Token", token)
apiURL = fmt.Sprintf("%s?access_token=%s", apiURL, token)
}
// 创建 HttpClient 请求对象 // 创建 HttpClient 请求对象
var client *http.Client var client *http.Client
requestBody, err := json.Marshal(req) requestBody, err := json.Marshal(req)
@ -394,17 +422,6 @@ func (h *ChatHandler) doRequest(ctx context.Context, req types.ApiRequest, platf
} else { } else {
client = http.DefaultClient client = http.DefaultClient
} }
if *apiKey == "" {
var key model.ApiKey
res := h.db.Where("platform = ?", platform).Order("last_used_at ASC").First(&key)
if res.Error != nil {
return nil, errors.New("no available key, please import key")
}
// 更新 API KEY 的最后使用时间
h.db.Model(&key).UpdateColumn("last_used_at", time.Now().Unix())
*apiKey = key.Value
}
logger.Infof("Sending %s request, KEY: %s, PROXY: %s, Model: %s", platform, *apiKey, proxyURL, req.Model) logger.Infof("Sending %s request, KEY: %s, PROXY: %s, Model: %s", platform, *apiKey, proxyURL, req.Model)
switch platform { switch platform {
case types.Azure: case types.Azure:
@ -418,7 +435,9 @@ func (h *ChatHandler) doRequest(ctx context.Context, req types.ApiRequest, platf
logger.Info(token) logger.Info(token)
request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token)) request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token))
break break
default: case types.Baidu:
request.RequestURI = ""
case types.OpenAI:
request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiKey)) request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiKey))
} }
return client.Do(request) return client.Do(request)

View File

@ -1,14 +1,55 @@
package main package main
import ( import (
"encoding/json"
"fmt" "fmt"
"os" "io"
"log"
"net/http"
) )
func main() { func main() {
bytes, err := os.ReadFile("res/text2img.json") apiKey := "qjvqGdqpTY7qQaGBMenM7XgQ"
apiSecret := "3G1RzBGXywZv4VbYRTyAfNns1vIOAG8t"
token, err := getBaiduToken(apiKey, apiSecret)
if err != nil { if err != nil {
panic(err) log.Fatal(err)
} }
fmt.Println(string(bytes))
fmt.Println(token)
}
// getBaiduToken exchanges a Baidu API key/secret pair for an OAuth2
// access_token using the client_credentials grant.
//
// Parameters:
//
//	apiKey    - Baidu application client_id
//	apiSecret - Baidu application client_secret
//
// Returns the access token string, or an error describing which step
// failed (request build, transport, body read, JSON decode, or an
// API-level error reported in the response payload).
func getBaiduToken(apiKey string, apiSecret string) (string, error) {
	// Credentials go into the query string; Baidu issues URL-safe
	// alphanumeric keys, so no extra escaping is applied here.
	endpoint := fmt.Sprintf("https://aip.baidubce.com/oauth/2.0/token?client_id=%s&client_secret=%s&grant_type=client_credentials", apiKey, apiSecret)
	req, err := http.NewRequest("POST", endpoint, nil)
	if err != nil {
		return "", err
	}
	req.Header.Add("Content-Type", "application/json")
	req.Header.Add("Accept", "application/json")

	res, err := http.DefaultClient.Do(req)
	if err != nil {
		return "", fmt.Errorf("error with send request: %w", err)
	}
	defer res.Body.Close()

	body, err := io.ReadAll(res.Body)
	if err != nil {
		return "", fmt.Errorf("error with read response: %w", err)
	}

	var r map[string]interface{}
	if err := json.Unmarshal(body, &r); err != nil {
		return "", fmt.Errorf("error with parse response: %w", err)
	}
	// Baidu reports failures in-band via an "error" field.
	if r["error"] != nil {
		return "", fmt.Errorf("error with api response: %s", r["error_description"])
	}
	// Fix: the original formatted r["access_token"] with %s unchecked, so a
	// well-formed-but-unexpected payload produced the bogus token
	// "%!s(<nil>)". Assert the field is a present string instead.
	token, ok := r["access_token"].(string)
	if !ok {
		return "", fmt.Errorf("unexpected token response: %s", body)
	}
	return token, nil
}

View File

@ -91,7 +91,6 @@ const platforms = ref([
{name: "【百度】文心一言", value: "Baidu"}, {name: "【百度】文心一言", value: "Baidu"},
{name: "【微软】Azure", value: "Azure"}, {name: "【微软】Azure", value: "Azure"},
{name: "【OpenAI】ChatGPT", value: "OpenAI"}, {name: "【OpenAI】ChatGPT", value: "OpenAI"},
]) ])
// //

View File

@ -47,7 +47,7 @@
<el-form :model="item" label-width="120px" ref="formRef" :rules="rules"> <el-form :model="item" label-width="120px" ref="formRef" :rules="rules">
<el-form-item label="所属平台:" prop="platform"> <el-form-item label="所属平台:" prop="platform">
<el-select v-model="item.platform" placeholder="请选择平台"> <el-select v-model="item.platform" placeholder="请选择平台">
<el-option v-for="item in platforms" :value="item" :key="item">{{ item }}</el-option> <el-option v-for="item in platforms" :value="item.value" :key="item.value">{{ item.name }}</el-option>
</el-select> </el-select>
</el-form-item> </el-form-item>
@ -94,7 +94,12 @@ const rules = reactive({
}) })
const loading = ref(true) const loading = ref(true)
const formRef = ref(null) const formRef = ref(null)
const platforms = ref(["Azure", "OpenAI", "ChatGLM"]) const platforms = ref([
{name: "【清华智普】ChatGLM", value: "ChatGLM"},
{name: "【百度】文心一言", value: "Baidu"},
{name: "【微软】Azure", value: "Azure"},
{name: "【OpenAI】ChatGPT", value: "OpenAI"},
])
// //
httpGet('/api/admin/model/list').then((res) => { httpGet('/api/admin/model/list').then((res) => {

View File

@ -158,6 +158,9 @@ onMounted(() => {
if (res.data.chat_gml) { if (res.data.chat_gml) {
chat.value.chat_gml = res.data.chat_gml chat.value.chat_gml = res.data.chat_gml
} }
if (res.data.baidu) {
chat.value.baidu = res.data.baidu
}
chat.value.context_deep = res.data.context_deep chat.value.context_deep = res.data.context_deep
chat.value.enable_context = res.data.enable_context chat.value.enable_context = res.data.enable_context
chat.value.enable_history = res.data.enable_history chat.value.enable_history = res.data.enable_history