remove other platform supports, ONLY use chatGPT API

This commit is contained in:
RockYang
2024-07-22 18:36:58 +08:00
parent 4910be3403
commit a0aee80c63
14 changed files with 74 additions and 1147 deletions

View File

@@ -208,21 +208,12 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio
Model: session.Model.Value,
Stream: true,
}
switch session.Model.Platform {
case types.Azure.Value, types.ChatGLM.Value, types.Baidu.Value, types.XunFei.Value:
req.Temperature = session.Model.Temperature
req.MaxTokens = session.Model.MaxTokens
break
case types.OpenAI.Value:
req.Temperature = session.Model.Temperature
req.MaxTokens = session.Model.MaxTokens
// OpenAI 支持函数功能
var items []model.Function
res := h.DB.Where("enabled", true).Find(&items)
if res.Error != nil {
break
}
req.Temperature = session.Model.Temperature
req.MaxTokens = session.Model.MaxTokens
// OpenAI 支持函数功能
var items []model.Function
res = h.DB.Where("enabled", true).Find(&items)
if res.Error == nil {
var tools = make([]types.Tool, 0)
for _, v := range items {
var parameters map[string]interface{}
@@ -248,15 +239,6 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio
req.Tools = tools
req.ToolChoice = "auto"
}
case types.QWen.Value:
req.Parameters = map[string]interface{}{
"max_tokens": session.Model.MaxTokens,
"temperature": session.Model.Temperature,
}
break
default:
return fmt.Errorf("不支持的平台:%s", session.Model.Platform)
}
// 加载聊天上下文
@@ -344,65 +326,37 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio
}
logger.Debug("最终Prompt", fullPrompt)
if session.Model.Platform == types.QWen.Value {
req.Input = make(map[string]interface{})
reqMgs = append(reqMgs, types.Message{
Role: "user",
Content: fullPrompt,
})
req.Input["messages"] = reqMgs
} else if session.Model.Platform == types.OpenAI.Value || session.Model.Platform == types.Azure.Value { // extract image for gpt-vision model
imgURLs := utils.ExtractImgURLs(prompt)
logger.Debugf("detected IMG: %+v", imgURLs)
var content interface{}
if len(imgURLs) > 0 {
data := make([]interface{}, 0)
for _, v := range imgURLs {
text = strings.Replace(text, v, "", 1)
data = append(data, gin.H{
"type": "image_url",
"image_url": gin.H{
"url": v,
},
})
}
// extract images from prompt
imgURLs := utils.ExtractImgURLs(prompt)
logger.Debugf("detected IMG: %+v", imgURLs)
var content interface{}
if len(imgURLs) > 0 {
data := make([]interface{}, 0)
for _, v := range imgURLs {
text = strings.Replace(text, v, "", 1)
data = append(data, gin.H{
"type": "text",
"text": strings.TrimSpace(text),
"type": "image_url",
"image_url": gin.H{
"url": v,
},
})
content = data
} else {
content = fullPrompt
}
req.Messages = append(reqMgs, map[string]interface{}{
"role": "user",
"content": content,
data = append(data, gin.H{
"type": "text",
"text": strings.TrimSpace(text),
})
content = data
} else {
req.Messages = append(reqMgs, map[string]interface{}{
"role": "user",
"content": fullPrompt,
})
content = fullPrompt
}
req.Messages = append(reqMgs, map[string]interface{}{
"role": "user",
"content": content,
})
logger.Debugf("%+v", req.Messages)
switch session.Model.Platform {
case types.Azure.Value:
return h.sendAzureMessage(chatCtx, req, userVo, ctx, session, role, prompt, ws)
case types.OpenAI.Value:
return h.sendOpenAiMessage(chatCtx, req, userVo, ctx, session, role, prompt, ws)
case types.ChatGLM.Value:
return h.sendChatGLMMessage(chatCtx, req, userVo, ctx, session, role, prompt, ws)
case types.Baidu.Value:
return h.sendBaiduMessage(chatCtx, req, userVo, ctx, session, role, prompt, ws)
case types.XunFei.Value:
return h.sendXunFeiMessage(chatCtx, req, userVo, ctx, session, role, prompt, ws)
case types.QWen.Value:
return h.sendQWenMessage(chatCtx, req, userVo, ctx, session, role, prompt, ws)
}
return nil
return h.sendOpenAiMessage(chatCtx, req, userVo, ctx, session, role, prompt, ws)
}
// Tokens 统计 token 数量
@@ -485,48 +439,13 @@ func (h *ChatHandler) doRequest(ctx context.Context, req types.ApiRequest, sessi
}
// ONLY allow apiURL in blank list
if session.Model.Platform == types.OpenAI.Value {
err := h.licenseService.IsValidApiURL(apiKey.ApiURL)
if err != nil {
return nil, err
}
err := h.licenseService.IsValidApiURL(apiKey.ApiURL)
if err != nil {
return nil, err
}
var apiURL string
switch session.Model.Platform {
case types.Azure.Value:
md := strings.Replace(req.Model, ".", "", 1)
apiURL = strings.Replace(apiKey.ApiURL, "{model}", md, 1)
break
case types.ChatGLM.Value:
apiURL = strings.Replace(apiKey.ApiURL, "{model}", req.Model, 1)
req.Prompt = req.Messages // 使用 prompt 字段替代 message 字段
req.Messages = nil
break
case types.Baidu.Value:
apiURL = strings.Replace(apiKey.ApiURL, "{model}", req.Model, 1)
break
case types.QWen.Value:
apiURL = apiKey.ApiURL
req.Messages = nil
break
default:
apiURL = apiKey.ApiURL
}
// 更新 API KEY 的最后使用时间
h.DB.Model(apiKey).UpdateColumn("last_used_at", time.Now().Unix())
// 百度文心,需要串接 access_token
if session.Model.Platform == types.Baidu.Value {
token, err := h.getBaiduToken(apiKey.Value)
if err != nil {
return nil, err
}
logger.Info("百度文心 Access_Token", token)
apiURL = fmt.Sprintf("%s?access_token=%s", apiURL, token)
}
logger.Debugf(utils.JsonEncode(req))
apiURL := fmt.Sprintf("%s/v1/chat/completions", apiKey.ApiURL)
// 创建 HttpClient 请求对象
var client *http.Client
requestBody, err := json.Marshal(req)
@@ -550,28 +469,8 @@ func (h *ChatHandler) doRequest(ctx context.Context, req types.ApiRequest, sessi
} else {
client = http.DefaultClient
}
logger.Debugf("Sending %s request, ApiURL:%s, API KEY:%s, PROXY: %s, Model: %s", session.Model.Platform, apiURL, apiKey.Value, apiKey.ProxyURL, req.Model)
switch session.Model.Platform {
case types.Azure.Value:
request.Header.Set("api-key", apiKey.Value)
break
case types.ChatGLM.Value:
token, err := h.getChatGLMToken(apiKey.Value)
if err != nil {
return nil, err
}
request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token))
break
case types.Baidu.Value:
request.RequestURI = ""
case types.OpenAI.Value:
request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", apiKey.Value))
break
case types.QWen.Value:
request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", apiKey.Value))
request.Header.Set("X-DashScope-SSE", "enable")
break
}
logger.Debugf("Sending %s request, ApiURL:%s, API KEY:%s, PROXY: %s, Model: %s", session.Model.Platform, apiKey.ApiURL, apiURL, apiKey.ProxyURL, req.Model)
request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", apiKey.Value))
return client.Do(request)
}