remove other platform supports, ONLY use chatGPT API

This commit is contained in:
RockYang 2024-07-22 18:36:58 +08:00
parent 09f44e6d9b
commit e17dcf4d5f
14 changed files with 74 additions and 1147 deletions

View File

@ -155,45 +155,6 @@ func (c RedisConfig) Url() string {
return fmt.Sprintf("%s:%d", c.Host, c.Port)
}
type Platform struct {
Name string `json:"name"`
Value string `json:"value"`
ChatURL string `json:"chat_url"`
ImgURL string `json:"img_url"`
}
var OpenAI = Platform{
Name: "OpenAI - GPT",
Value: "OpenAI",
ChatURL: "https://api.chat-plus.net/v1/chat/completions",
ImgURL: "https://api.chat-plus.net/v1/images/generations",
}
var Azure = Platform{
Name: "微软 - Azure",
Value: "Azure",
ChatURL: "https://chat-bot-api.openai.azure.com/openai/deployments/{model}/chat/completions?api-version=2023-05-15",
}
var ChatGLM = Platform{
Name: "智谱 - ChatGLM",
Value: "ChatGLM",
ChatURL: "https://open.bigmodel.cn/api/paas/v3/model-api/{model}/sse-invoke",
}
var Baidu = Platform{
Name: "百度 - 文心大模型",
Value: "Baidu",
ChatURL: "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/{model}",
}
var XunFei = Platform{
Name: "讯飞 - 星火大模型",
Value: "XunFei",
ChatURL: "wss://spark-api.xf-yun.com/{version}/chat",
}
var QWen = Platform{
Name: "阿里 - 通义千问",
Value: "QWen",
ChatURL: "https://dashscope.aliyuncs.com/api/v1/services/aigc/text-generation/generation",
}
type SystemConfig struct {
Title string `json:"title,omitempty"` // 网站标题
Slogan string `json:"slogan,omitempty"` // 网站 slogan

View File

@ -150,10 +150,9 @@ func (h *ConfigHandler) GetLicense(c *gin.Context) {
// GetAppConfig 获取内置配置
func (h *ConfigHandler) GetAppConfig(c *gin.Context) {
resp.SUCCESS(c, gin.H{
"mj_plus": h.App.Config.MjPlusConfigs,
"mj_proxy": h.App.Config.MjProxyConfigs,
"sd": h.App.Config.SdConfigs,
"platforms": Platforms,
"mj_plus": h.App.Config.MjPlusConfigs,
"mj_proxy": h.App.Config.MjProxyConfigs,
"sd": h.App.Config.SdConfigs,
})
}

View File

@ -1,12 +0,0 @@
package admin
import "geekai/core/types"
var Platforms = []types.Platform{
types.OpenAI,
types.QWen,
types.XunFei,
types.ChatGLM,
types.Baidu,
types.Azure,
}

View File

@ -1,111 +0,0 @@
package chatimpl
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"bufio"
"context"
"encoding/json"
"errors"
"fmt"
"geekai/core/types"
"geekai/store/model"
"geekai/store/vo"
"geekai/utils"
"io"
"strings"
"time"
)
// 微软 Azure 模型消息发送实现
func (h *ChatHandler) sendAzureMessage(
chatCtx []types.Message,
req types.ApiRequest,
userVo vo.User,
ctx context.Context,
session *types.ChatSession,
role model.ChatRole,
prompt string,
ws *types.WsClient) error {
promptCreatedAt := time.Now() // 记录提问时间
start := time.Now()
var apiKey = model.ApiKey{}
response, err := h.doRequest(ctx, req, session, &apiKey)
logger.Info("HTTP请求完成耗时", time.Now().Sub(start))
if err != nil {
if strings.Contains(err.Error(), "context canceled") {
return fmt.Errorf("用户取消了请求:%s", prompt)
} else if strings.Contains(err.Error(), "no available key") {
return errors.New("抱歉😔😔😔,系统已经没有可用的 API KEY请联系管理员")
}
return err
} else {
defer response.Body.Close()
}
contentType := response.Header.Get("Content-Type")
if strings.Contains(contentType, "text/event-stream") {
replyCreatedAt := time.Now() // 记录回复时间
// 循环读取 Chunk 消息
var message = types.Message{}
var contents = make([]string, 0)
scanner := bufio.NewScanner(response.Body)
for scanner.Scan() {
line := scanner.Text()
if !strings.Contains(line, "data:") || len(line) < 30 {
continue
}
var responseBody = types.ApiResponse{}
err = json.Unmarshal([]byte(line[6:]), &responseBody)
if err != nil { // 数据解析出错
return errors.New(line)
}
if len(responseBody.Choices) == 0 {
continue
}
// 初始化 role
if responseBody.Choices[0].Delta.Role != "" && message.Role == "" {
message.Role = responseBody.Choices[0].Delta.Role
utils.ReplyChunkMessage(ws, types.WsMessage{Type: types.WsStart})
continue
} else if responseBody.Choices[0].FinishReason != "" {
break // 输出完成或者输出中断了
} else {
content := responseBody.Choices[0].Delta.Content
contents = append(contents, utils.InterfaceToString(content))
utils.ReplyChunkMessage(ws, types.WsMessage{
Type: types.WsMiddle,
Content: utils.InterfaceToString(responseBody.Choices[0].Delta.Content),
})
}
} // end for
if err := scanner.Err(); err != nil {
if strings.Contains(err.Error(), "context canceled") {
logger.Info("用户取消了请求:", prompt)
} else {
logger.Error("信息读取出错:", err)
}
}
// 消息发送成功
if len(contents) > 0 {
h.saveChatHistory(req, prompt, contents, message, chatCtx, session, role, userVo, promptCreatedAt, replyCreatedAt)
}
} else {
body, _ := io.ReadAll(response.Body)
return fmt.Errorf("请求大模型 API 失败:%s", body)
}
return nil
}

View File

@ -1,185 +0,0 @@
package chatimpl
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"bufio"
"context"
"encoding/json"
"errors"
"fmt"
"geekai/core/types"
"geekai/store/model"
"geekai/store/vo"
"geekai/utils"
"io"
"net/http"
"strings"
"time"
)
type baiduResp struct {
Id string `json:"id"`
Object string `json:"object"`
Created int `json:"created"`
SentenceId int `json:"sentence_id"`
IsEnd bool `json:"is_end"`
IsTruncated bool `json:"is_truncated"`
Result string `json:"result"`
NeedClearHistory bool `json:"need_clear_history"`
Usage struct {
PromptTokens int `json:"prompt_tokens"`
CompletionTokens int `json:"completion_tokens"`
TotalTokens int `json:"total_tokens"`
} `json:"usage"`
}
// 百度文心一言消息发送实现
func (h *ChatHandler) sendBaiduMessage(
chatCtx []types.Message,
req types.ApiRequest,
userVo vo.User,
ctx context.Context,
session *types.ChatSession,
role model.ChatRole,
prompt string,
ws *types.WsClient) error {
promptCreatedAt := time.Now() // 记录提问时间
start := time.Now()
var apiKey = model.ApiKey{}
response, err := h.doRequest(ctx, req, session, &apiKey)
logger.Info("HTTP请求完成耗时", time.Now().Sub(start))
if err != nil {
logger.Error(err)
if strings.Contains(err.Error(), "context canceled") {
return fmt.Errorf("用户取消了请求:%s", prompt)
} else if strings.Contains(err.Error(), "no available key") {
return errors.New("抱歉😔😔😔,系统已经没有可用的 API KEY请联系管理员")
}
return err
} else {
defer response.Body.Close()
}
contentType := response.Header.Get("Content-Type")
if strings.Contains(contentType, "text/event-stream") {
replyCreatedAt := time.Now() // 记录回复时间
// 循环读取 Chunk 消息
var message = types.Message{}
var contents = make([]string, 0)
var content string
scanner := bufio.NewScanner(response.Body)
for scanner.Scan() {
line := scanner.Text()
if len(line) < 5 || strings.HasPrefix(line, "id:") {
continue
}
if strings.HasPrefix(line, "data:") {
content = line[5:]
}
// 处理代码换行
if len(content) == 0 {
content = "\n"
}
var resp baiduResp
err := utils.JsonDecode(content, &resp)
if err != nil {
logger.Error("error with parse data line: ", err)
utils.ReplyMessage(ws, fmt.Sprintf("**解析数据行失败:%s**", err))
break
}
if len(contents) == 0 {
utils.ReplyChunkMessage(ws, types.WsMessage{Type: types.WsStart})
}
utils.ReplyChunkMessage(ws, types.WsMessage{
Type: types.WsMiddle,
Content: utils.InterfaceToString(resp.Result),
})
contents = append(contents, resp.Result)
if resp.IsTruncated {
utils.ReplyMessage(ws, "AI 输出异常中断")
break
}
if resp.IsEnd {
break
}
} // end for
if err := scanner.Err(); err != nil {
if strings.Contains(err.Error(), "context canceled") {
logger.Info("用户取消了请求:", prompt)
} else {
logger.Error("信息读取出错:", err)
}
}
// 消息发送成功
if len(contents) > 0 {
h.saveChatHistory(req, prompt, contents, message, chatCtx, session, role, userVo, promptCreatedAt, replyCreatedAt)
}
} else {
body, _ := io.ReadAll(response.Body)
return fmt.Errorf("请求大模型 API 失败:%s", body)
}
return nil
}
func (h *ChatHandler) getBaiduToken(apiKey string) (string, error) {
ctx := context.Background()
tokenString, err := h.redis.Get(ctx, apiKey).Result()
if err == nil {
return tokenString, nil
}
expr := time.Hour * 24 * 20 // access_token 有效期
key := strings.Split(apiKey, "|")
if len(key) != 2 {
return "", fmt.Errorf("invalid api key: %s", apiKey)
}
url := fmt.Sprintf("https://aip.baidubce.com/oauth/2.0/token?client_id=%s&client_secret=%s&grant_type=client_credentials", key[0], key[1])
client := &http.Client{}
req, err := http.NewRequest("POST", url, nil)
if err != nil {
return "", err
}
req.Header.Add("Content-Type", "application/json")
req.Header.Add("Accept", "application/json")
res, err := client.Do(req)
if err != nil {
return "", fmt.Errorf("error with send request: %w", err)
}
defer res.Body.Close()
body, err := io.ReadAll(res.Body)
if err != nil {
return "", fmt.Errorf("error with read response: %w", err)
}
var r map[string]interface{}
err = json.Unmarshal(body, &r)
if err != nil {
return "", fmt.Errorf("error with parse response: %w", err)
}
if r["error"] != nil {
return "", fmt.Errorf("error with api response: %s", r["error_description"])
}
tokenString = fmt.Sprintf("%s", r["access_token"])
h.redis.Set(ctx, apiKey, tokenString, expr)
return tokenString, nil
}

View File

@ -208,21 +208,12 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio
Model: session.Model.Value,
Stream: true,
}
switch session.Model.Platform {
case types.Azure.Value, types.ChatGLM.Value, types.Baidu.Value, types.XunFei.Value:
req.Temperature = session.Model.Temperature
req.MaxTokens = session.Model.MaxTokens
break
case types.OpenAI.Value:
req.Temperature = session.Model.Temperature
req.MaxTokens = session.Model.MaxTokens
// OpenAI 支持函数功能
var items []model.Function
res := h.DB.Where("enabled", true).Find(&items)
if res.Error != nil {
break
}
req.Temperature = session.Model.Temperature
req.MaxTokens = session.Model.MaxTokens
// OpenAI 支持函数功能
var items []model.Function
res = h.DB.Where("enabled", true).Find(&items)
if res.Error == nil {
var tools = make([]types.Tool, 0)
for _, v := range items {
var parameters map[string]interface{}
@ -248,15 +239,6 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio
req.Tools = tools
req.ToolChoice = "auto"
}
case types.QWen.Value:
req.Parameters = map[string]interface{}{
"max_tokens": session.Model.MaxTokens,
"temperature": session.Model.Temperature,
}
break
default:
return fmt.Errorf("不支持的平台:%s", session.Model.Platform)
}
// 加载聊天上下文
@ -344,65 +326,37 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio
}
logger.Debug("最终Prompt", fullPrompt)
if session.Model.Platform == types.QWen.Value {
req.Input = make(map[string]interface{})
reqMgs = append(reqMgs, types.Message{
Role: "user",
Content: fullPrompt,
})
req.Input["messages"] = reqMgs
} else if session.Model.Platform == types.OpenAI.Value || session.Model.Platform == types.Azure.Value { // extract image for gpt-vision model
imgURLs := utils.ExtractImgURLs(prompt)
logger.Debugf("detected IMG: %+v", imgURLs)
var content interface{}
if len(imgURLs) > 0 {
data := make([]interface{}, 0)
for _, v := range imgURLs {
text = strings.Replace(text, v, "", 1)
data = append(data, gin.H{
"type": "image_url",
"image_url": gin.H{
"url": v,
},
})
}
// extract images from prompt
imgURLs := utils.ExtractImgURLs(prompt)
logger.Debugf("detected IMG: %+v", imgURLs)
var content interface{}
if len(imgURLs) > 0 {
data := make([]interface{}, 0)
for _, v := range imgURLs {
text = strings.Replace(text, v, "", 1)
data = append(data, gin.H{
"type": "text",
"text": strings.TrimSpace(text),
"type": "image_url",
"image_url": gin.H{
"url": v,
},
})
content = data
} else {
content = fullPrompt
}
req.Messages = append(reqMgs, map[string]interface{}{
"role": "user",
"content": content,
data = append(data, gin.H{
"type": "text",
"text": strings.TrimSpace(text),
})
content = data
} else {
req.Messages = append(reqMgs, map[string]interface{}{
"role": "user",
"content": fullPrompt,
})
content = fullPrompt
}
req.Messages = append(reqMgs, map[string]interface{}{
"role": "user",
"content": content,
})
logger.Debugf("%+v", req.Messages)
switch session.Model.Platform {
case types.Azure.Value:
return h.sendAzureMessage(chatCtx, req, userVo, ctx, session, role, prompt, ws)
case types.OpenAI.Value:
return h.sendOpenAiMessage(chatCtx, req, userVo, ctx, session, role, prompt, ws)
case types.ChatGLM.Value:
return h.sendChatGLMMessage(chatCtx, req, userVo, ctx, session, role, prompt, ws)
case types.Baidu.Value:
return h.sendBaiduMessage(chatCtx, req, userVo, ctx, session, role, prompt, ws)
case types.XunFei.Value:
return h.sendXunFeiMessage(chatCtx, req, userVo, ctx, session, role, prompt, ws)
case types.QWen.Value:
return h.sendQWenMessage(chatCtx, req, userVo, ctx, session, role, prompt, ws)
}
return nil
return h.sendOpenAiMessage(chatCtx, req, userVo, ctx, session, role, prompt, ws)
}
// Tokens 统计 token 数量
@ -485,48 +439,13 @@ func (h *ChatHandler) doRequest(ctx context.Context, req types.ApiRequest, sessi
}
// ONLY allow apiURL in blank list
if session.Model.Platform == types.OpenAI.Value {
err := h.licenseService.IsValidApiURL(apiKey.ApiURL)
if err != nil {
return nil, err
}
err := h.licenseService.IsValidApiURL(apiKey.ApiURL)
if err != nil {
return nil, err
}
var apiURL string
switch session.Model.Platform {
case types.Azure.Value:
md := strings.Replace(req.Model, ".", "", 1)
apiURL = strings.Replace(apiKey.ApiURL, "{model}", md, 1)
break
case types.ChatGLM.Value:
apiURL = strings.Replace(apiKey.ApiURL, "{model}", req.Model, 1)
req.Prompt = req.Messages // 使用 prompt 字段替代 message 字段
req.Messages = nil
break
case types.Baidu.Value:
apiURL = strings.Replace(apiKey.ApiURL, "{model}", req.Model, 1)
break
case types.QWen.Value:
apiURL = apiKey.ApiURL
req.Messages = nil
break
default:
apiURL = apiKey.ApiURL
}
// 更新 API KEY 的最后使用时间
h.DB.Model(apiKey).UpdateColumn("last_used_at", time.Now().Unix())
// 百度文心,需要串接 access_token
if session.Model.Platform == types.Baidu.Value {
token, err := h.getBaiduToken(apiKey.Value)
if err != nil {
return nil, err
}
logger.Info("百度文心 Access_Token", token)
apiURL = fmt.Sprintf("%s?access_token=%s", apiURL, token)
}
logger.Debugf(utils.JsonEncode(req))
apiURL := fmt.Sprintf("%s/v1/chat/completions", apiKey.ApiURL)
// 创建 HttpClient 请求对象
var client *http.Client
requestBody, err := json.Marshal(req)
@ -550,28 +469,8 @@ func (h *ChatHandler) doRequest(ctx context.Context, req types.ApiRequest, sessi
} else {
client = http.DefaultClient
}
logger.Debugf("Sending %s request, ApiURL:%s, API KEY:%s, PROXY: %s, Model: %s", session.Model.Platform, apiURL, apiKey.Value, apiKey.ProxyURL, req.Model)
switch session.Model.Platform {
case types.Azure.Value:
request.Header.Set("api-key", apiKey.Value)
break
case types.ChatGLM.Value:
token, err := h.getChatGLMToken(apiKey.Value)
if err != nil {
return nil, err
}
request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token))
break
case types.Baidu.Value:
request.RequestURI = ""
case types.OpenAI.Value:
request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", apiKey.Value))
break
case types.QWen.Value:
request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", apiKey.Value))
request.Header.Set("X-DashScope-SSE", "enable")
break
}
logger.Debugf("Sending %s request, ApiURL:%s, API KEY:%s, PROXY: %s, Model: %s", session.Model.Platform, apiKey.ApiURL, apiURL, apiKey.ProxyURL, req.Model)
request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", apiKey.Value))
return client.Do(request)
}

View File

@ -1,142 +0,0 @@
package chatimpl
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"bufio"
"context"
"errors"
"fmt"
"geekai/core/types"
"geekai/store/model"
"geekai/store/vo"
"geekai/utils"
"github.com/golang-jwt/jwt/v5"
"io"
"strings"
"time"
)
// 清华大学 ChatGML 消息发送实现
func (h *ChatHandler) sendChatGLMMessage(
chatCtx []types.Message,
req types.ApiRequest,
userVo vo.User,
ctx context.Context,
session *types.ChatSession,
role model.ChatRole,
prompt string,
ws *types.WsClient) error {
promptCreatedAt := time.Now() // 记录提问时间
start := time.Now()
var apiKey = model.ApiKey{}
response, err := h.doRequest(ctx, req, session, &apiKey)
logger.Info("HTTP请求完成耗时", time.Now().Sub(start))
if err != nil {
if strings.Contains(err.Error(), "context canceled") {
return fmt.Errorf("用户取消了请求:%s", prompt)
} else if strings.Contains(err.Error(), "no available key") {
return errors.New("抱歉😔😔😔,系统已经没有可用的 API KEY请联系管理员")
}
return err
} else {
defer response.Body.Close()
}
contentType := response.Header.Get("Content-Type")
if strings.Contains(contentType, "text/event-stream") {
replyCreatedAt := time.Now() // 记录回复时间
// 循环读取 Chunk 消息
var message = types.Message{}
var contents = make([]string, 0)
var event, content string
scanner := bufio.NewScanner(response.Body)
for scanner.Scan() {
line := scanner.Text()
if len(line) < 5 || strings.HasPrefix(line, "id:") {
continue
}
if strings.HasPrefix(line, "event:") {
event = line[6:]
continue
}
if strings.HasPrefix(line, "data:") {
content = line[5:]
}
// 处理代码换行
if len(content) == 0 {
content = "\n"
}
switch event {
case "add":
if len(contents) == 0 {
utils.ReplyChunkMessage(ws, types.WsMessage{Type: types.WsStart})
}
utils.ReplyChunkMessage(ws, types.WsMessage{
Type: types.WsMiddle,
Content: utils.InterfaceToString(content),
})
contents = append(contents, content)
case "finish":
break
case "error":
utils.ReplyMessage(ws, fmt.Sprintf("**调用 ChatGLM API 出错:%s**", content))
break
case "interrupted":
utils.ReplyMessage(ws, "**调用 ChatGLM API 出错,当前输出被中断!**")
}
} // end for
if err := scanner.Err(); err != nil {
if strings.Contains(err.Error(), "context canceled") {
logger.Info("用户取消了请求:", prompt)
} else {
logger.Error("信息读取出错:", err)
}
}
// 消息发送成功
if len(contents) > 0 {
h.saveChatHistory(req, prompt, contents, message, chatCtx, session, role, userVo, promptCreatedAt, replyCreatedAt)
}
} else {
body, _ := io.ReadAll(response.Body)
return fmt.Errorf("请求大模型 API 失败:%s", body)
}
return nil
}
func (h *ChatHandler) getChatGLMToken(apiKey string) (string, error) {
ctx := context.Background()
tokenString, err := h.redis.Get(ctx, apiKey).Result()
if err == nil {
return tokenString, nil
}
expr := time.Hour * 2
key := strings.Split(apiKey, ".")
if len(key) != 2 {
return "", fmt.Errorf("invalid api key: %s", apiKey)
}
token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
"api_key": key[0],
"timestamp": time.Now().Unix(),
"exp": time.Now().Add(expr).Add(time.Second * 10).Unix(),
})
token.Header["alg"] = "HS256"
token.Header["sign_type"] = "SIGN"
delete(token.Header, "typ")
// Sign and get the complete encoded token as a string using the secret
tokenString, err = token.SignedString([]byte(key[1]))
h.redis.Set(ctx, apiKey, tokenString, expr)
return tokenString, err
}

View File

@ -1,150 +0,0 @@
package chatimpl
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"bufio"
"context"
"fmt"
"geekai/core/types"
"geekai/store/model"
"geekai/store/vo"
"geekai/utils"
"github.com/syndtr/goleveldb/leveldb/errors"
"io"
"strings"
"time"
)
type qWenResp struct {
Output struct {
FinishReason string `json:"finish_reason"`
Text string `json:"text"`
} `json:"output,omitempty"`
Usage struct {
TotalTokens int `json:"total_tokens"`
InputTokens int `json:"input_tokens"`
OutputTokens int `json:"output_tokens"`
} `json:"usage,omitempty"`
RequestID string `json:"request_id"`
Code string `json:"code,omitempty"`
Message string `json:"message,omitempty"`
}
// 通义千问消息发送实现
func (h *ChatHandler) sendQWenMessage(
chatCtx []types.Message,
req types.ApiRequest,
userVo vo.User,
ctx context.Context,
session *types.ChatSession,
role model.ChatRole,
prompt string,
ws *types.WsClient) error {
promptCreatedAt := time.Now() // 记录提问时间
start := time.Now()
var apiKey = model.ApiKey{}
response, err := h.doRequest(ctx, req, session, &apiKey)
logger.Info("HTTP请求完成耗时", time.Now().Sub(start))
if err != nil {
if strings.Contains(err.Error(), "context canceled") {
return fmt.Errorf("用户取消了请求:%s", prompt)
} else if strings.Contains(err.Error(), "no available key") {
return errors.New("抱歉😔😔😔,系统已经没有可用的 API KEY请联系管理员")
}
return err
} else {
defer response.Body.Close()
}
contentType := response.Header.Get("Content-Type")
if strings.Contains(contentType, "text/event-stream") {
replyCreatedAt := time.Now() // 记录回复时间
// 循环读取 Chunk 消息
var message = types.Message{}
var contents = make([]string, 0)
scanner := bufio.NewScanner(response.Body)
var content, lastText, newText string
var outPutStart = false
for scanner.Scan() {
line := scanner.Text()
if len(line) < 5 || strings.HasPrefix(line, "id:") ||
strings.HasPrefix(line, "event:") || strings.HasPrefix(line, ":HTTP_STATUS/200") {
continue
}
if !strings.HasPrefix(line, "data:") {
continue
}
content = line[5:]
var resp qWenResp
if len(contents) == 0 { // 发送消息头
if !outPutStart {
utils.ReplyChunkMessage(ws, types.WsMessage{Type: types.WsStart})
outPutStart = true
continue
} else {
// 处理代码换行
content = "\n"
}
} else {
err := utils.JsonDecode(content, &resp)
if err != nil {
logger.Error("error with parse data line: ", content)
utils.ReplyMessage(ws, fmt.Sprintf("**解析数据行失败:%s**", err))
break
}
if resp.Message != "" {
utils.ReplyMessage(ws, fmt.Sprintf("**API 返回错误:%s**", resp.Message))
break
}
}
//通过比较 lastText上一次的文本和 currentText当前的文本
//提取出新添加的文本部分。然后只将这部分新文本发送到客户端。
//每次循环结束后lastText 会更新为当前的完整文本,以便于下一次循环进行比较。
currentText := resp.Output.Text
if currentText != lastText {
// 提取新增文本
newText = strings.Replace(currentText, lastText, "", 1)
utils.ReplyChunkMessage(ws, types.WsMessage{
Type: types.WsMiddle,
Content: utils.InterfaceToString(newText),
})
lastText = currentText // 更新 lastText
}
contents = append(contents, newText)
if resp.Output.FinishReason == "stop" {
break
}
} //end for
if err := scanner.Err(); err != nil {
if strings.Contains(err.Error(), "context canceled") {
logger.Info("用户取消了请求:", prompt)
} else {
logger.Error("信息读取出错:", err)
}
}
// 消息发送成功
if len(contents) > 0 {
h.saveChatHistory(req, prompt, contents, message, chatCtx, session, role, userVo, promptCreatedAt, replyCreatedAt)
}
} else {
body, _ := io.ReadAll(response.Body)
return fmt.Errorf("请求大模型 API 失败:%s", body)
}
return nil
}

View File

@ -1,255 +0,0 @@
package chatimpl
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"context"
"crypto/hmac"
"crypto/sha256"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"geekai/core/types"
"geekai/store/model"
"geekai/store/vo"
"geekai/utils"
"github.com/gorilla/websocket"
"gorm.io/gorm"
"io"
"net/http"
"net/url"
"strings"
"time"
)
type xunFeiResp struct {
Header struct {
Code int `json:"code"`
Message string `json:"message"`
Sid string `json:"sid"`
Status int `json:"status"`
} `json:"header"`
Payload struct {
Choices struct {
Status int `json:"status"`
Seq int `json:"seq"`
Text []struct {
Content string `json:"content"`
Role string `json:"role"`
Index int `json:"index"`
} `json:"text"`
} `json:"choices"`
Usage struct {
Text struct {
QuestionTokens int `json:"question_tokens"`
PromptTokens int `json:"prompt_tokens"`
CompletionTokens int `json:"completion_tokens"`
TotalTokens int `json:"total_tokens"`
} `json:"text"`
} `json:"usage"`
} `json:"payload"`
}
var Model2URL = map[string]string{
"general": "v1.1",
"generalv2": "v2.1",
"generalv3": "v3.1",
"generalv3.5": "v3.5",
}
// 科大讯飞消息发送实现
func (h *ChatHandler) sendXunFeiMessage(
chatCtx []types.Message,
req types.ApiRequest,
userVo vo.User,
ctx context.Context,
session *types.ChatSession,
role model.ChatRole,
prompt string,
ws *types.WsClient) error {
promptCreatedAt := time.Now() // 记录提问时间
var apiKey model.ApiKey
var res *gorm.DB
// use the bind key
if session.Model.KeyId > 0 {
res = h.DB.Where("id", session.Model.KeyId).Find(&apiKey)
}
// use the last unused key
if apiKey.Id == 0 {
res = h.DB.Where("platform", session.Model.Platform).Where("type", "chat").Where("enabled", true).Order("last_used_at ASC").First(&apiKey)
}
if res.Error != nil {
return errors.New("抱歉😔😔😔,系统已经没有可用的 API KEY请联系管理员")
}
// 更新 API KEY 的最后使用时间
h.DB.Model(&apiKey).UpdateColumn("last_used_at", time.Now().Unix())
d := websocket.Dialer{
HandshakeTimeout: 5 * time.Second,
}
key := strings.Split(apiKey.Value, "|")
if len(key) != 3 {
utils.ReplyMessage(ws, "非法的 API KEY")
return nil
}
apiURL := strings.Replace(apiKey.ApiURL, "{version}", Model2URL[req.Model], 1)
logger.Debugf("Sending %s request, ApiURL:%s, API KEY:%s, PROXY: %s, Model: %s", session.Model.Platform, apiURL, apiKey.Value, apiKey.ProxyURL, req.Model)
wsURL, err := assembleAuthUrl(apiURL, key[1], key[2])
//握手并建立websocket 连接
conn, resp, err := d.Dial(wsURL, nil)
if err != nil {
logger.Error(readResp(resp) + err.Error())
utils.ReplyMessage(ws, "请求讯飞星火模型 API 失败:"+readResp(resp)+err.Error())
return nil
} else if resp.StatusCode != 101 {
utils.ReplyMessage(ws, "请求讯飞星火模型 API 失败:"+readResp(resp)+err.Error())
return nil
}
data := buildRequest(key[0], req)
fmt.Printf("%+v", data)
fmt.Println(apiURL)
err = conn.WriteJSON(data)
if err != nil {
utils.ReplyMessage(ws, "发送消息失败:"+err.Error())
return nil
}
replyCreatedAt := time.Now() // 记录回复时间
// 循环读取 Chunk 消息
var message = types.Message{}
var contents = make([]string, 0)
var content string
for {
_, msg, err := conn.ReadMessage()
if err != nil {
logger.Error("error with read message:", err)
utils.ReplyMessage(ws, fmt.Sprintf("**数据读取失败:%s**", err))
break
}
// 解析数据
var result xunFeiResp
err = json.Unmarshal(msg, &result)
if err != nil {
logger.Error("error with parsing JSON:", err)
utils.ReplyMessage(ws, fmt.Sprintf("**解析数据行失败:%s**", err))
return nil
}
if result.Header.Code != 0 {
utils.ReplyMessage(ws, fmt.Sprintf("**请求 API 返回错误:%s**", result.Header.Message))
return nil
}
content = result.Payload.Choices.Text[0].Content
// 处理代码换行
if len(content) == 0 {
content = "\n"
}
contents = append(contents, content)
// 第一个结果
if result.Payload.Choices.Status == 0 {
utils.ReplyChunkMessage(ws, types.WsMessage{Type: types.WsStart})
}
utils.ReplyChunkMessage(ws, types.WsMessage{
Type: types.WsMiddle,
Content: utils.InterfaceToString(content),
})
if result.Payload.Choices.Status == 2 { // 最终结果
_ = conn.Close() // 关闭连接
break
}
select {
case <-ctx.Done():
utils.ReplyMessage(ws, "**用户取消了生成指令!**")
return nil
default:
continue
}
}
// 消息发送成功
if len(contents) > 0 {
h.saveChatHistory(req, prompt, contents, message, chatCtx, session, role, userVo, promptCreatedAt, replyCreatedAt)
}
return nil
}
// 构建 websocket 请求实体
func buildRequest(appid string, req types.ApiRequest) map[string]interface{} {
return map[string]interface{}{
"header": map[string]interface{}{
"app_id": appid,
},
"parameter": map[string]interface{}{
"chat": map[string]interface{}{
"domain": req.Model,
"temperature": req.Temperature,
"top_k": int64(6),
"max_tokens": int64(req.MaxTokens),
"auditing": "default",
},
},
"payload": map[string]interface{}{
"message": map[string]interface{}{
"text": req.Messages,
},
},
}
}
// 创建鉴权 URL
func assembleAuthUrl(hostURL string, apiKey, apiSecret string) (string, error) {
ul, err := url.Parse(hostURL)
if err != nil {
return "", err
}
date := time.Now().UTC().Format(time.RFC1123)
signString := []string{"host: " + ul.Host, "date: " + date, "GET " + ul.Path + " HTTP/1.1"}
//拼接签名字符串
signStr := strings.Join(signString, "\n")
sha := hmacWithSha256(signStr, apiSecret)
authUrl := fmt.Sprintf("hmac username=\"%s\", algorithm=\"%s\", headers=\"%s\", signature=\"%s\"", apiKey,
"hmac-sha256", "host date request-line", sha)
//将请求参数使用base64编码
authorization := base64.StdEncoding.EncodeToString([]byte(authUrl))
v := url.Values{}
v.Add("host", ul.Host)
v.Add("date", date)
v.Add("authorization", authorization)
//将编码后的字符串url encode后添加到url后面
return hostURL + "?" + v.Encode(), nil
}
// 使用 sha256 签名
func hmacWithSha256(data, key string) string {
mac := hmac.New(sha256.New, []byte(key))
mac.Write([]byte(data))
encodeData := mac.Sum(nil)
return base64.StdEncoding.EncodeToString(encodeData)
}
// 读取响应
func readResp(resp *http.Response) string {
if resp == nil {
return ""
}
b, err := io.ReadAll(resp.Body)
if err != nil {
panic(err)
}
return fmt.Sprintf("code=%d,body=%s", resp.StatusCode, string(b))
}

View File

@ -212,21 +212,21 @@ func (h *MarkMapHandler) sendMessage(client *types.WsClient, prompt string, mode
}
func (h *MarkMapHandler) doRequest(req types.ApiRequest, chatModel model.ChatModel, apiKey *model.ApiKey) (*http.Response, error) {
session := h.DB.Session(&gorm.Session{})
// if the chat model bind a KEY, use it directly
var res *gorm.DB
if chatModel.KeyId > 0 {
res = h.DB.Where("id", chatModel.KeyId).Find(apiKey)
}
// use the last unused key
if apiKey.Id == 0 {
res = h.DB.Where("platform", types.OpenAI.Value).
Where("type", "chat").
Where("enabled", true).Order("last_used_at ASC").First(apiKey)
session = session.Where("id", chatModel.KeyId)
} else { // use the last unused key
session = session.Where("type", "chat").
Where("enabled", true).Order("last_used_at ASC")
}
res := session.First(apiKey)
if res.Error != nil {
return nil, errors.New("no available key, please import key")
}
apiURL := apiKey.ApiURL
apiURL := fmt.Sprintf("%s/v1/chat/completions", apiKey.ApiURL)
// 更新 API KEY 的最后使用时间
h.DB.Model(apiKey).UpdateColumn("last_used_at", time.Now().Unix())

View File

@ -145,7 +145,7 @@ func (s *Service) Image(task types.DallTask, sync bool) (string, error) {
// get image generation API KEY
var apiKey model.ApiKey
tx = s.db.Where("type", "img").
tx = s.db.Where("type", "dalle").
Where("enabled", true).
Order("last_used_at ASC").First(&apiKey)
if tx.Error != nil {
@ -157,6 +157,7 @@ func (s *Service) Image(task types.DallTask, sync bool) (string, error) {
if len(apiKey.ProxyURL) > 5 {
s.httpClient.SetProxyURL(apiKey.ProxyURL).R()
}
apiURL := fmt.Sprintf("%s/v1/images/generations", apiKey.ApiURL)
reqBody := imgReq{
Model: "dall-e-3",
Prompt: prompt,
@ -165,14 +166,13 @@ func (s *Service) Image(task types.DallTask, sync bool) (string, error) {
Style: task.Style,
Quality: task.Quality,
}
logger.Infof("Sending %s request, ApiURL:%s, API KEY:%s, BODY: %+v", apiKey.Platform, apiKey.ApiURL, apiKey.Value, reqBody)
request := s.httpClient.R().SetHeader("Content-Type", "application/json")
if apiKey.Platform == types.Azure.Value {
request = request.SetHeader("api-key", apiKey.Value)
} else {
request = request.SetHeader("Authorization", "Bearer "+apiKey.Value)
}
r, err := request.SetBody(reqBody).SetErrorResult(&errRes).SetSuccessResult(&res).Post(apiKey.ApiURL)
logger.Infof("Sending %s request, ApiURL:%s, API KEY:%s, BODY: %+v", apiKey.Platform, apiURL, apiKey.Value, reqBody)
r, err := s.httpClient.R().SetHeader("Content-Type", "application/json").
SetHeader("Authorization", "Bearer "+apiKey.Value).
SetBody(reqBody).
SetErrorResult(&errRes).
SetSuccessResult(&res).
Post(apiURL)
if err != nil {
return "", fmt.Errorf("error with send request: %v", err)
}

View File

@ -54,7 +54,7 @@ type apiErrRes struct {
func OpenAIRequest(db *gorm.DB, prompt string) (string, error) {
var apiKey model.ApiKey
res := db.Where("platform", types.OpenAI.Value).Where("type", "chat").Where("enabled", true).First(&apiKey)
res := db.Where("type", "chat").Where("enabled", true).First(&apiKey)
if res.Error != nil {
return "", fmt.Errorf("error with fetch OpenAI API KEY%v", res.Error)
}

View File

@ -2,19 +2,11 @@
<div class="container list" v-loading="loading">
<div class="handle-box">
<el-select v-model="query.platform" placeholder="平台" class="handle-input">
<el-option
v-for="item in platforms"
:key="item.value"
:label="item.name"
:value="item.value"
/>
</el-select>
<el-select v-model="query.type" placeholder="类型" class="handle-input">
<el-option
v-for="item in types"
:key="item.value"
:label="item.name"
:label="item.label"
:value="item.value"
/>
</el-select>
@ -28,7 +20,6 @@
<el-row>
<el-table :data="items" :row-key="row => row.id" table-layout="auto">
<el-table-column prop="platform" label="所属平台"/>
<el-table-column prop="name" label="名称"/>
<el-table-column prop="value" label="API KEY">
<template #default="scope">
@ -46,10 +37,9 @@
</el-icon>
</template>
</el-table-column>
<el-table-column prop="type" label="用途">
<el-table-column prop="type" label="类型">
<template #default="scope">
<el-tag v-if="scope.row.type === 'chat'">聊天</el-tag>
<el-tag v-else-if="scope.row.type === 'img'" type="success">绘图</el-tag>
{{getTypeName(scope.row.type)}}
</template>
</el-table-column>
<el-table-column prop="proxy_url" label="代理地址"/>
@ -84,23 +74,7 @@
:close-on-click-modal="false"
:title="title"
>
<el-alert
type="warning"
:closable="false"
show-icon
style="margin-bottom: 10px; font-size:14px;">
<p><b>注意</b>如果是百度文心一言平台API-KEY APIKey|SecretKey中间用竖线|连接</p>
<p><b>注意</b>如果是讯飞星火大模型API-KEY AppId|APIKey|APISecret中间用竖线|连接</p>
</el-alert>
<el-form :model="item" label-width="120px" ref="formRef" :rules="rules">
<el-form-item label="所属平台:" prop="platform">
<el-select v-model="item.platform" placeholder="请选择平台" @change="changePlatform">
<el-option v-for="item in platforms" :value="item.value" :label="item.name" :key="item.value">{{
item.name
}}
</el-option>
</el-select>
</el-form-item>
<el-form-item label="名称:" prop="name">
<el-input v-model="item.name" autocomplete="off"/>
</el-form-item>
@ -117,12 +91,12 @@
</el-form-item>
<el-form-item label="API URL" prop="api_url">
<el-input v-model="item.api_url" autocomplete="off"
placeholder="必须填土完整的 Chat API URLhttps://api.openai.com/v1/chat/completions"/>
<div class="info">如果你使用了第三方中转这里就填写中转地址</div>
placeholder="只填 BASE URL 即可https://api.openai.com"/>
</el-form-item>
<el-form-item label="代理地址:" prop="proxy_url">
<el-input v-model="item.proxy_url" autocomplete="off"/>
<div class="info">如果想要通过代理来访问 API请填写代理地址http://127.0.0.1:7890</div>
</el-form-item>
<el-form-item label="启用状态:" prop="enable">
@ -150,11 +124,10 @@ import ClipboardJS from "clipboard";
//
const items = ref([])
const query = ref({type: '',platform:''})
const query = ref({type: ''})
const item = ref({})
const showDialog = ref(false)
const rules = reactive({
platform: [{required: true, message: '请选择平台', trigger: 'change',}],
name: [{required: true, message: '请输入名称', trigger: 'change',}],
type: [{required: true, message: '请选择用途', trigger: 'change',}],
value: [{required: true, message: '请输入 API KEY 值', trigger: 'change',}]
@ -163,11 +136,10 @@ const rules = reactive({
const loading = ref(true)
const formRef = ref(null)
const title = ref("")
const platforms = ref([])
const types = ref([
{label: "对话", value:"chat"},
{label: "Midjourney", value:"mj"},
{label: "DALL-E", value:"dall"},
{label: "DALL-E", value:"dalle"},
{label: "Suno文生歌", value:"suno"},
{label: "Luma视频", value:"luma"},
])
@ -184,12 +156,6 @@ onMounted(() => {
ElMessage.error('复制失败!');
})
httpGet("/api/admin/config/get/app").then(res => {
platforms.value = res.data.platforms
}).catch(e =>{
ElMessage.error("获取配置失败:"+e.message)
})
fetchData()
})
@ -197,6 +163,15 @@ onUnmounted(() => {
clipboard.value.destroy()
})
const getTypeName = (type) => {
for (let v of types.value) {
if (v.value === type) {
return v.label
}
}
return ""
}
//
const fetchData = () => {
@ -266,26 +241,6 @@ const set = (filed, row) => {
})
}
const selectedPlatform = ref(null)
const changePlatform = (value) => {
console.log(value)
for (let v of platforms.value) {
if (v.value === value) {
selectedPlatform.value = v
item.value.api_url = v.chat_url
}
}
}
const changeType = (value) => {
if (selectedPlatform.value) {
if(value === 'img') {
item.value.api_url = selectedPlatform.value.img_url
} else {
item.value.api_url = selectedPlatform.value.chat_url
}
}
}
</script>
<style lang="stylus" scoped>

View File

@ -17,11 +17,6 @@
<el-row>
<el-table :data="items" :row-key="row => row.id" table-layout="auto">
<el-table-column prop="platform" label="所属平台">
<template #default="scope">
<span class="sort" :data-id="scope.row.id">{{ scope.row.platform }}</span>
</template>
</el-table-column>
<el-table-column prop="name" label="模型名称"/>
<el-table-column prop="value" label="模型值">
<template #default="scope">
@ -67,15 +62,6 @@
style="width: 90%; max-width: 600px;"
>
<el-form :model="item" label-width="120px" ref="formRef" :rules="rules">
<el-form-item label="所属平台:" prop="platform">
<el-select v-model="item.platform" placeholder="请选择平台">
<el-option v-for="item in platforms" :value="item.value" :label="item.name" :key="item.value">{{
item.name
}}
</el-option>
</el-select>
</el-form-item>
<el-form-item label="模型名称:" prop="name">
<el-input v-model="item.name" autocomplete="off"/>
</el-form-item>
@ -116,18 +102,7 @@
class="box-item"
effect="dark"
raw-content
content="gpt-3.5-turbo:4096 <br/>
gpt-3.5-turbo-16k: 16384 <br/>
gpt-4: 8192 <br/>
gpt-4-32k: 32768 <br/>
chatglm_pro: 32768 <br/>
chatglm_std: 16384 <br/>
chatglm_lite: 4096 <br/>
qwen-turbo: 8192 <br/>
qwen-plus: 32768 <br/>
文心一言: 8192 <br/>
星火1.0: 4096 <br/>
星火2.0-星火3.5: 8192"
content="去各大模型的官方 API 文档查询模型支持的最大上下文长度"
placement="right"
>
<el-icon>
@ -220,7 +195,6 @@ const rules = reactive({
})
const loading = ref(true)
const formRef = ref(null)
const platforms = ref([])
// API KEY
const apiKeys = ref([])
@ -285,12 +259,6 @@ onMounted(() => {
clipboard.value.on('error', () => {
ElMessage.error('复制失败!');
})
httpGet("/api/admin/config/get/app").then(res => {
platforms.value = res.data.platforms
}).catch(e =>{
ElMessage.error("获取配置失败:"+e.message)
})
})
onUnmounted(() => {
@ -300,7 +268,7 @@ onUnmounted(() => {
const add = function () {
title.value = "新增模型"
showDialog.value = true
item.value = {enabled: true, weight: 1, open: true}
item.value = {enabled: true, power: 1, open: true,max_tokens: 1024,max_context: 8192, temperature: 0.9,}
}
const edit = function (row) {