mirror of
https://github.com/yangjian102621/geekai.git
synced 2025-11-10 11:13:42 +08:00
remove other platform supports, ONLY use chatGPT API
This commit is contained in:
@@ -208,21 +208,12 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio
|
||||
Model: session.Model.Value,
|
||||
Stream: true,
|
||||
}
|
||||
switch session.Model.Platform {
|
||||
case types.Azure.Value, types.ChatGLM.Value, types.Baidu.Value, types.XunFei.Value:
|
||||
req.Temperature = session.Model.Temperature
|
||||
req.MaxTokens = session.Model.MaxTokens
|
||||
break
|
||||
case types.OpenAI.Value:
|
||||
req.Temperature = session.Model.Temperature
|
||||
req.MaxTokens = session.Model.MaxTokens
|
||||
// OpenAI 支持函数功能
|
||||
var items []model.Function
|
||||
res := h.DB.Where("enabled", true).Find(&items)
|
||||
if res.Error != nil {
|
||||
break
|
||||
}
|
||||
|
||||
req.Temperature = session.Model.Temperature
|
||||
req.MaxTokens = session.Model.MaxTokens
|
||||
// OpenAI 支持函数功能
|
||||
var items []model.Function
|
||||
res = h.DB.Where("enabled", true).Find(&items)
|
||||
if res.Error == nil {
|
||||
var tools = make([]types.Tool, 0)
|
||||
for _, v := range items {
|
||||
var parameters map[string]interface{}
|
||||
@@ -248,15 +239,6 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio
|
||||
req.Tools = tools
|
||||
req.ToolChoice = "auto"
|
||||
}
|
||||
case types.QWen.Value:
|
||||
req.Parameters = map[string]interface{}{
|
||||
"max_tokens": session.Model.MaxTokens,
|
||||
"temperature": session.Model.Temperature,
|
||||
}
|
||||
break
|
||||
|
||||
default:
|
||||
return fmt.Errorf("不支持的平台:%s", session.Model.Platform)
|
||||
}
|
||||
|
||||
// 加载聊天上下文
|
||||
@@ -344,65 +326,37 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio
|
||||
}
|
||||
logger.Debug("最终Prompt:", fullPrompt)
|
||||
|
||||
if session.Model.Platform == types.QWen.Value {
|
||||
req.Input = make(map[string]interface{})
|
||||
reqMgs = append(reqMgs, types.Message{
|
||||
Role: "user",
|
||||
Content: fullPrompt,
|
||||
})
|
||||
req.Input["messages"] = reqMgs
|
||||
} else if session.Model.Platform == types.OpenAI.Value || session.Model.Platform == types.Azure.Value { // extract image for gpt-vision model
|
||||
imgURLs := utils.ExtractImgURLs(prompt)
|
||||
logger.Debugf("detected IMG: %+v", imgURLs)
|
||||
var content interface{}
|
||||
if len(imgURLs) > 0 {
|
||||
data := make([]interface{}, 0)
|
||||
for _, v := range imgURLs {
|
||||
text = strings.Replace(text, v, "", 1)
|
||||
data = append(data, gin.H{
|
||||
"type": "image_url",
|
||||
"image_url": gin.H{
|
||||
"url": v,
|
||||
},
|
||||
})
|
||||
}
|
||||
// extract images from prompt
|
||||
imgURLs := utils.ExtractImgURLs(prompt)
|
||||
logger.Debugf("detected IMG: %+v", imgURLs)
|
||||
var content interface{}
|
||||
if len(imgURLs) > 0 {
|
||||
data := make([]interface{}, 0)
|
||||
for _, v := range imgURLs {
|
||||
text = strings.Replace(text, v, "", 1)
|
||||
data = append(data, gin.H{
|
||||
"type": "text",
|
||||
"text": strings.TrimSpace(text),
|
||||
"type": "image_url",
|
||||
"image_url": gin.H{
|
||||
"url": v,
|
||||
},
|
||||
})
|
||||
content = data
|
||||
} else {
|
||||
content = fullPrompt
|
||||
}
|
||||
req.Messages = append(reqMgs, map[string]interface{}{
|
||||
"role": "user",
|
||||
"content": content,
|
||||
data = append(data, gin.H{
|
||||
"type": "text",
|
||||
"text": strings.TrimSpace(text),
|
||||
})
|
||||
content = data
|
||||
} else {
|
||||
req.Messages = append(reqMgs, map[string]interface{}{
|
||||
"role": "user",
|
||||
"content": fullPrompt,
|
||||
})
|
||||
content = fullPrompt
|
||||
}
|
||||
req.Messages = append(reqMgs, map[string]interface{}{
|
||||
"role": "user",
|
||||
"content": content,
|
||||
})
|
||||
|
||||
logger.Debugf("%+v", req.Messages)
|
||||
|
||||
switch session.Model.Platform {
|
||||
case types.Azure.Value:
|
||||
return h.sendAzureMessage(chatCtx, req, userVo, ctx, session, role, prompt, ws)
|
||||
case types.OpenAI.Value:
|
||||
return h.sendOpenAiMessage(chatCtx, req, userVo, ctx, session, role, prompt, ws)
|
||||
case types.ChatGLM.Value:
|
||||
return h.sendChatGLMMessage(chatCtx, req, userVo, ctx, session, role, prompt, ws)
|
||||
case types.Baidu.Value:
|
||||
return h.sendBaiduMessage(chatCtx, req, userVo, ctx, session, role, prompt, ws)
|
||||
case types.XunFei.Value:
|
||||
return h.sendXunFeiMessage(chatCtx, req, userVo, ctx, session, role, prompt, ws)
|
||||
case types.QWen.Value:
|
||||
return h.sendQWenMessage(chatCtx, req, userVo, ctx, session, role, prompt, ws)
|
||||
}
|
||||
|
||||
return nil
|
||||
return h.sendOpenAiMessage(chatCtx, req, userVo, ctx, session, role, prompt, ws)
|
||||
}
|
||||
|
||||
// Tokens 统计 token 数量
|
||||
@@ -485,48 +439,13 @@ func (h *ChatHandler) doRequest(ctx context.Context, req types.ApiRequest, sessi
|
||||
}
|
||||
|
||||
// ONLY allow apiURL in blank list
|
||||
if session.Model.Platform == types.OpenAI.Value {
|
||||
err := h.licenseService.IsValidApiURL(apiKey.ApiURL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
err := h.licenseService.IsValidApiURL(apiKey.ApiURL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var apiURL string
|
||||
switch session.Model.Platform {
|
||||
case types.Azure.Value:
|
||||
md := strings.Replace(req.Model, ".", "", 1)
|
||||
apiURL = strings.Replace(apiKey.ApiURL, "{model}", md, 1)
|
||||
break
|
||||
case types.ChatGLM.Value:
|
||||
apiURL = strings.Replace(apiKey.ApiURL, "{model}", req.Model, 1)
|
||||
req.Prompt = req.Messages // 使用 prompt 字段替代 message 字段
|
||||
req.Messages = nil
|
||||
break
|
||||
case types.Baidu.Value:
|
||||
apiURL = strings.Replace(apiKey.ApiURL, "{model}", req.Model, 1)
|
||||
break
|
||||
case types.QWen.Value:
|
||||
apiURL = apiKey.ApiURL
|
||||
req.Messages = nil
|
||||
break
|
||||
default:
|
||||
apiURL = apiKey.ApiURL
|
||||
}
|
||||
// 更新 API KEY 的最后使用时间
|
||||
h.DB.Model(apiKey).UpdateColumn("last_used_at", time.Now().Unix())
|
||||
// 百度文心,需要串接 access_token
|
||||
if session.Model.Platform == types.Baidu.Value {
|
||||
token, err := h.getBaiduToken(apiKey.Value)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
logger.Info("百度文心 Access_Token:", token)
|
||||
apiURL = fmt.Sprintf("%s?access_token=%s", apiURL, token)
|
||||
}
|
||||
|
||||
logger.Debugf(utils.JsonEncode(req))
|
||||
|
||||
apiURL := fmt.Sprintf("%s/v1/chat/completions", apiKey.ApiURL)
|
||||
// 创建 HttpClient 请求对象
|
||||
var client *http.Client
|
||||
requestBody, err := json.Marshal(req)
|
||||
@@ -550,28 +469,8 @@ func (h *ChatHandler) doRequest(ctx context.Context, req types.ApiRequest, sessi
|
||||
} else {
|
||||
client = http.DefaultClient
|
||||
}
|
||||
logger.Debugf("Sending %s request, ApiURL:%s, API KEY:%s, PROXY: %s, Model: %s", session.Model.Platform, apiURL, apiKey.Value, apiKey.ProxyURL, req.Model)
|
||||
switch session.Model.Platform {
|
||||
case types.Azure.Value:
|
||||
request.Header.Set("api-key", apiKey.Value)
|
||||
break
|
||||
case types.ChatGLM.Value:
|
||||
token, err := h.getChatGLMToken(apiKey.Value)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token))
|
||||
break
|
||||
case types.Baidu.Value:
|
||||
request.RequestURI = ""
|
||||
case types.OpenAI.Value:
|
||||
request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", apiKey.Value))
|
||||
break
|
||||
case types.QWen.Value:
|
||||
request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", apiKey.Value))
|
||||
request.Header.Set("X-DashScope-SSE", "enable")
|
||||
break
|
||||
}
|
||||
logger.Debugf("Sending %s request, ApiURL:%s, API KEY:%s, PROXY: %s, Model: %s", session.Model.Platform, apiKey.ApiURL, apiURL, apiKey.ProxyURL, req.Model)
|
||||
request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", apiKey.Value))
|
||||
return client.Do(request)
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user