finish baidu ai model api implementation

This commit is contained in:
RockYang 2023-10-11 14:21:16 +08:00
parent da38684cdd
commit c37902d3a4
9 changed files with 171 additions and 62 deletions

View File

@ -67,4 +67,8 @@ var ModelToTokens = map[string]int{
"gpt-3.5-turbo-16k": 16384,
"gpt-4": 8192,
"gpt-4-32k": 32768,
"chatglm_pro": 32768,
"chatglm_std": 16384,
"chatglm_lite": 4096,
"ernie_bot_turbo": 8192, // 文心一言
}

View File

@ -79,6 +79,7 @@ type ChatConfig struct {
OpenAI ModelAPIConfig `json:"open_ai"`
Azure ModelAPIConfig `json:"azure"`
ChatGML ModelAPIConfig `json:"chat_gml"`
Baidu ModelAPIConfig `json:"baidu"`
EnableContext bool `json:"enable_context"` // 是否开启聊天上下文
EnableHistory bool `json:"enable_history"` // 是否允许保存聊天记录

View File

@ -36,11 +36,11 @@ func (h *ApiKeyHandler) Save(c *gin.Context) {
apiKey := model.ApiKey{}
if data.Id > 0 {
h.db.Find(&apiKey)
h.db.Find(&apiKey, data.Id)
}
apiKey.Platform = data.Platform
apiKey.Value = data.Value
res := h.db.Save(&apiKey)
res := h.db.Debug().Save(&apiKey)
if res.Error != nil {
resp.ERROR(c, "更新数据库失败!")
return

View File

@ -9,14 +9,30 @@ import (
"context"
"encoding/json"
"fmt"
"github.com/golang-jwt/jwt/v5"
"gorm.io/gorm"
"io"
"net/http"
"strings"
"time"
"unicode/utf8"
)
type baiduResp struct {
Id string `json:"id"`
Object string `json:"object"`
Created int `json:"created"`
SentenceId int `json:"sentence_id"`
IsEnd bool `json:"is_end"`
IsTruncated bool `json:"is_truncated"`
Result string `json:"result"`
NeedClearHistory bool `json:"need_clear_history"`
Usage struct {
PromptTokens int `json:"prompt_tokens"`
CompletionTokens int `json:"completion_tokens"`
TotalTokens int `json:"total_tokens"`
} `json:"usage"`
}
// 将消息发送给百度文心一言大模型 API 并获取结果,通过 WebSocket 推送到客户端
func (h *ChatHandler) sendBaiduMessage(
chatCtx []interface{},
@ -56,38 +72,42 @@ func (h *ChatHandler) sendBaiduMessage(
// 循环读取 Chunk 消息
var message = types.Message{}
var contents = make([]string, 0)
var event, content string
var content string
scanner := bufio.NewScanner(response.Body)
for scanner.Scan() {
line := scanner.Text()
if len(line) < 5 || strings.HasPrefix(line, "id:") {
continue
}
if strings.HasPrefix(line, "event:") {
event = line[6:]
continue
}
if strings.HasPrefix(line, "data:") {
content = line[5:]
}
switch event {
case "add":
var resp baiduResp
err := utils.JsonDecode(content, &resp)
if err != nil {
logger.Error("error with parse data line: ", err)
utils.ReplyMessage(ws, fmt.Sprintf("**解析数据行失败:%s**", err))
break
}
if len(contents) == 0 {
utils.ReplyChunkMessage(ws, types.WsMessage{Type: types.WsStart})
}
utils.ReplyChunkMessage(ws, types.WsMessage{
Type: types.WsMiddle,
Content: utils.InterfaceToString(content),
Content: utils.InterfaceToString(resp.Result),
})
contents = append(contents, content)
case "finish":
contents = append(contents, resp.Result)
if resp.IsTruncated {
utils.ReplyMessage(ws, "AI 输出异常中断")
break
case "error":
utils.ReplyMessage(ws, fmt.Sprintf("**调用 ChatGLM API 出错:%s**", content))
}
if resp.IsEnd {
break
case "interrupted":
utils.ReplyMessage(ws, "**调用 ChatGLM API 出错,当前输出被中断!**")
}
} // end for
@ -192,17 +212,14 @@ func (h *ChatHandler) sendBaiduMessage(
}
var res struct {
Code int `json:"code"`
Success bool `json:"success"`
Msg string `json:"msg"`
Code int `json:"error_code"`
Msg string `json:"error_msg"`
}
err = json.Unmarshal(body, &res)
if err != nil {
return fmt.Errorf("error with decode response: %v", err)
}
if !res.Success {
utils.ReplyMessage(ws, "请求 ChatGLM 失败:"+res.Msg)
}
utils.ReplyMessage(ws, "请求百度文心大模型 API 失败:"+res.Msg)
}
return nil
@ -215,21 +232,41 @@ func (h *ChatHandler) getBaiduToken(apiKey string) (string, error) {
return tokenString, nil
}
expr := time.Hour * 2
key := strings.Split(apiKey, ".")
expr := time.Hour * 24 * 20 // access_token 有效期
key := strings.Split(apiKey, "|")
if len(key) != 2 {
return "", fmt.Errorf("invalid api key: %s", apiKey)
}
token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
"api_key": key[0],
"timestamp": time.Now().Unix(),
"exp": time.Now().Add(expr).Add(time.Second * 10).Unix(),
})
token.Header["alg"] = "HS256"
token.Header["sign_type"] = "SIGN"
delete(token.Header, "typ")
// Sign and get the complete encoded token as a string using the secret
tokenString, err = token.SignedString([]byte(key[1]))
h.redis.Set(ctx, apiKey, tokenString, expr)
return tokenString, err
url := fmt.Sprintf("https://aip.baidubce.com/oauth/2.0/token?client_id=%s&client_secret=%s&grant_type=client_credentials", key[0], key[1])
client := &http.Client{}
req, err := http.NewRequest("POST", url, nil)
if err != nil {
return "", err
}
req.Header.Add("Content-Type", "application/json")
req.Header.Add("Accept", "application/json")
res, err := client.Do(req)
if err != nil {
return "", fmt.Errorf("error with send request: %w", err)
}
defer res.Body.Close()
body, err := io.ReadAll(res.Body)
if err != nil {
return "", fmt.Errorf("error with read response: %w", err)
}
var r map[string]interface{}
err = json.Unmarshal(body, &r)
if err != nil {
return "", fmt.Errorf("error with parse response: %w", err)
}
if r["error"] != nil {
return "", fmt.Errorf("error with api response: %s", r["error_description"])
}
tokenString = fmt.Sprintf("%s", r["access_token"])
h.redis.Set(ctx, apiKey, tokenString, expr)
return tokenString, nil
}

View File

@ -202,6 +202,7 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio
case types.OpenAI:
req.Temperature = h.App.ChatConfig.OpenAI.Temperature
req.MaxTokens = h.App.ChatConfig.OpenAI.MaxTokens
// OpenAI 支持函数功能
var functions = make([]types.Function, 0)
for _, f := range types.InnerFunctions {
if !h.App.SysConfig.EnabledDraw && f.Name == types.FuncMidJourney {
@ -281,6 +282,9 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio
return h.sendOpenAiMessage(chatCtx, req, userVo, ctx, session, role, prompt, ws)
case types.ChatGLM:
return h.sendChatGLMMessage(chatCtx, req, userVo, ctx, session, role, prompt, ws)
case types.Baidu:
return h.sendBaiduMessage(chatCtx, req, userVo, ctx, session, role, prompt, ws)
}
utils.ReplyChunkMessage(ws, types.WsMessage{
Type: types.WsMiddle,
@ -364,12 +368,36 @@ func (h *ChatHandler) doRequest(ctx context.Context, req types.ApiRequest, platf
break
case types.ChatGLM:
apiURL = strings.Replace(h.App.ChatConfig.ChatGML.ApiURL, "{model}", req.Model, 1)
req.Prompt = req.Messages
req.Prompt = req.Messages // 使用 prompt 字段替代 message 字段
req.Messages = nil
break
case types.Baidu:
apiURL = h.App.ChatConfig.Baidu.ApiURL
break
default:
apiURL = h.App.ChatConfig.OpenAI.ApiURL
}
if *apiKey == "" {
var key model.ApiKey
res := h.db.Where("platform = ?", platform).Order("last_used_at ASC").First(&key)
if res.Error != nil {
return nil, errors.New("no available key, please import key")
}
// 更新 API KEY 的最后使用时间
h.db.Model(&key).UpdateColumn("last_used_at", time.Now().Unix())
*apiKey = key.Value
}
// 百度文心,需要串接 access_token
if platform == types.Baidu {
token, err := h.getBaiduToken(*apiKey)
if err != nil {
return nil, err
}
logger.Info("百度文心 Access_Token", token)
apiURL = fmt.Sprintf("%s?access_token=%s", apiURL, token)
}
// 创建 HttpClient 请求对象
var client *http.Client
requestBody, err := json.Marshal(req)
@ -394,17 +422,6 @@ func (h *ChatHandler) doRequest(ctx context.Context, req types.ApiRequest, platf
} else {
client = http.DefaultClient
}
if *apiKey == "" {
var key model.ApiKey
res := h.db.Where("platform = ?", platform).Order("last_used_at ASC").First(&key)
if res.Error != nil {
return nil, errors.New("no available key, please import key")
}
// 更新 API KEY 的最后使用时间
h.db.Model(&key).UpdateColumn("last_used_at", time.Now().Unix())
*apiKey = key.Value
}
logger.Infof("Sending %s request, KEY: %s, PROXY: %s, Model: %s", platform, *apiKey, proxyURL, req.Model)
switch platform {
case types.Azure:
@ -418,7 +435,9 @@ func (h *ChatHandler) doRequest(ctx context.Context, req types.ApiRequest, platf
logger.Info(token)
request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token))
break
default:
case types.Baidu:
request.RequestURI = ""
case types.OpenAI:
request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiKey))
}
return client.Do(request)

View File

@ -1,14 +1,55 @@
package main
import (
"encoding/json"
"fmt"
"os"
"io"
"log"
"net/http"
)
func main() {
bytes, err := os.ReadFile("res/text2img.json")
apiKey := "qjvqGdqpTY7qQaGBMenM7XgQ"
apiSecret := "3G1RzBGXywZv4VbYRTyAfNns1vIOAG8t"
token, err := getBaiduToken(apiKey, apiSecret)
if err != nil {
panic(err)
log.Fatal(err)
}
fmt.Println(string(bytes))
fmt.Println(token)
}
func getBaiduToken(apiKey string, apiSecret string) (string, error) {
url := fmt.Sprintf("https://aip.baidubce.com/oauth/2.0/token?client_id=%s&client_secret=%s&grant_type=client_credentials", apiKey, apiSecret)
client := &http.Client{}
req, err := http.NewRequest("POST", url, nil)
if err != nil {
return "", err
}
req.Header.Add("Content-Type", "application/json")
req.Header.Add("Accept", "application/json")
res, err := client.Do(req)
if err != nil {
return "", fmt.Errorf("error with send request: %w", err)
}
defer res.Body.Close()
body, err := io.ReadAll(res.Body)
if err != nil {
return "", fmt.Errorf("error with read response: %w", err)
}
var r map[string]interface{}
err = json.Unmarshal(body, &r)
if err != nil {
return "", fmt.Errorf("error with parse response: %w", err)
}
if r["error"] != nil {
return "", fmt.Errorf("error with api response: %s", r["error_description"])
}
return fmt.Sprintf("%s", r["access_token"]), nil
}

View File

@ -91,7 +91,6 @@ const platforms = ref([
{name: "【百度】文心一言", value: "Baidu"},
{name: "【微软】Azure", value: "Azure"},
{name: "【OpenAI】ChatGPT", value: "OpenAI"},
])
//

View File

@ -47,7 +47,7 @@
<el-form :model="item" label-width="120px" ref="formRef" :rules="rules">
<el-form-item label="所属平台:" prop="platform">
<el-select v-model="item.platform" placeholder="请选择平台">
<el-option v-for="item in platforms" :value="item" :key="item">{{ item }}</el-option>
<el-option v-for="item in platforms" :value="item.value" :key="item.value">{{ item.name }}</el-option>
</el-select>
</el-form-item>
@ -94,7 +94,12 @@ const rules = reactive({
})
const loading = ref(true)
const formRef = ref(null)
const platforms = ref(["Azure", "OpenAI", "ChatGLM"])
const platforms = ref([
{name: "【清华智普】ChatGLM", value: "ChatGLM"},
{name: "【百度】文心一言", value: "Baidu"},
{name: "【微软】Azure", value: "Azure"},
{name: "【OpenAI】ChatGPT", value: "OpenAI"},
])
//
httpGet('/api/admin/model/list').then((res) => {

View File

@ -158,6 +158,9 @@ onMounted(() => {
if (res.data.chat_gml) {
chat.value.chat_gml = res.data.chat_gml
}
if (res.data.baidu) {
chat.value.baidu = res.data.baidu
}
chat.value.context_deep = res.data.context_deep
chat.value.enable_context = res.data.enable_context
chat.value.enable_history = res.data.enable_history