websocket api refactor is ready

This commit is contained in:
RockYang
2024-09-29 19:28:47 +08:00
parent 00a8bc6784
commit e28a12a1ee
19 changed files with 210 additions and 464 deletions

View File

@@ -8,23 +8,15 @@ package handler
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"bufio"
"bytes"
"encoding/json"
"errors"
"fmt"
"geekai/core"
"geekai/core/types"
"geekai/service"
"geekai/store/model"
"geekai/utils"
"geekai/utils/resp"
"github.com/gin-gonic/gin"
"gorm.io/gorm"
"io"
"net/http"
"net/url"
"strings"
"time"
)
// MarkMapHandler 生成思维导图
@@ -44,23 +36,33 @@ func NewMarkMapHandler(app *core.AppServer, db *gorm.DB, userService *service.Us
// Generate 生成思维导图
func (h *MarkMapHandler) Generate(c *gin.Context) {
var data struct {
Prompt string `json:"prompt"`
ModelId int `json:"model_id"`
}
}
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
func (h *MarkMapHandler) sendMessage(client *types.WsClient, prompt string, modelId int, userId int) error {
userId := h.GetLoginUserId(c)
var user model.User
res := h.DB.Model(&model.User{}).First(&user, userId)
if res.Error != nil {
return fmt.Errorf("error with query user info: %v", res.Error)
err := h.DB.Where("id", userId).First(&user, userId).Error
if err != nil {
resp.ERROR(c, "error with query user info")
return
}
var chatModel model.ChatModel
res = h.DB.Where("id", modelId).First(&chatModel)
if res.Error != nil {
return fmt.Errorf("error with query chat model: %v", res.Error)
err = h.DB.Where("id", data.ModelId).First(&chatModel).Error
if err != nil {
resp.ERROR(c, "error with query chat model")
return
}
if user.Power < chatModel.Power {
return fmt.Errorf("您当前剩余算力(%d已不足以支付当前模型算力%d", user.Power, chatModel.Power)
resp.ERROR(c, fmt.Sprintf("您当前剩余算力(%d已不足以支付当前模型算力%d", user.Power, chatModel.Power))
return
}
messages := make([]interface{}, 0)
@@ -82,117 +84,27 @@ func (h *MarkMapHandler) sendMessage(client *types.WsClient, prompt string, mode
### 支付宝
### 微信
另外,除此之外不要任何解释性语句。
请直接生成结果,不要任何解释性语句。
`})
messages = append(messages, types.Message{Role: "user", Content: fmt.Sprintf("请生成一份有关【%s】一份思维导图要求结构清晰有条理", prompt)})
var req = types.ApiRequest{
Model: chatModel.Value,
Stream: true,
Messages: messages,
}
var apiKey model.ApiKey
response, err := h.doRequest(req, chatModel, &apiKey)
messages = append(messages, types.Message{Role: "user", Content: fmt.Sprintf("请生成一份有关【%s】一份思维导图要求结构清晰有条理", data.Prompt)})
content, err := utils.SendOpenAIMessage(h.DB, messages, chatModel.Value, chatModel.KeyId)
if err != nil {
return fmt.Errorf("请求 OpenAI API 失败: %s", err)
}
defer response.Body.Close()
contentType := response.Header.Get("Content-Type")
if strings.Contains(contentType, "text/event-stream") {
// 循环读取 Chunk 消息
scanner := bufio.NewScanner(response.Body)
for scanner.Scan() {
line := scanner.Text()
if !strings.Contains(line, "data:") || len(line) < 30 {
continue
}
var responseBody = types.ApiResponse{}
err = json.Unmarshal([]byte(line[6:]), &responseBody)
if err != nil { // 数据解析出错
return fmt.Errorf("error with decode data: %v", line)
}
if len(responseBody.Choices) == 0 { // Fixed: 兼容 Azure API 第一个输出空行
continue
}
if responseBody.Choices[0].FinishReason == "stop" {
break
}
utils.SendMsg(client, types.ReplyMessage{
Type: types.MsgTypeText,
Body: utils.InterfaceToString(responseBody.Choices[0].Delta.Content),
})
} // end for
utils.SendMsg(client, types.ReplyMessage{Type: types.MsgTypeEnd})
} else {
body, _ := io.ReadAll(response.Body)
return fmt.Errorf("请求 OpenAI API 失败:%s", string(body))
resp.ERROR(c, fmt.Sprintf("请求 OpenAI API 失败: %s", err))
return
}
// 扣减算力
if chatModel.Power > 0 {
err = h.userService.DecreasePower(userId, chatModel.Power, model.PowerLog{
err = h.userService.DecreasePower(int(userId), chatModel.Power, model.PowerLog{
Type: types.PowerConsume,
Model: chatModel.Value,
Remark: fmt.Sprintf("AI绘制思维导图模型名称%s, ", chatModel.Value),
})
if err != nil {
return err
resp.ERROR(c, "error with save power log, "+err.Error())
return
}
}
return nil
}
func (h *MarkMapHandler) doRequest(req types.ApiRequest, chatModel model.ChatModel, apiKey *model.ApiKey) (*http.Response, error) {
session := h.DB.Session(&gorm.Session{})
// if the chat model bind a KEY, use it directly
if chatModel.KeyId > 0 {
session = session.Where("id", chatModel.KeyId)
} else { // use the last unused key
session = session.Where("type", "chat").
Where("enabled", true).Order("last_used_at ASC")
}
res := session.First(apiKey)
if res.Error != nil {
return nil, errors.New("no available key, please import key")
}
apiURL := fmt.Sprintf("%s/v1/chat/completions", apiKey.ApiURL)
// 更新 API KEY 的最后使用时间
h.DB.Model(apiKey).UpdateColumn("last_used_at", time.Now().Unix())
// 创建 HttpClient 请求对象
var client *http.Client
requestBody, err := json.Marshal(req)
if err != nil {
return nil, err
}
request, err := http.NewRequest(http.MethodPost, apiURL, bytes.NewBuffer(requestBody))
if err != nil {
return nil, err
}
request.Header.Set("Content-Type", "application/json")
if len(apiKey.ProxyURL) > 5 { // 使用代理
proxy, _ := url.Parse(apiKey.ProxyURL)
client = &http.Client{
Transport: &http.Transport{
Proxy: http.ProxyURL(proxy),
},
}
} else {
client = http.DefaultClient
}
request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", apiKey.Value))
logger.Debugf("Sending %s request, API KEY:%s, PROXY: %s, Model: %s", apiKey.ApiURL, apiURL, apiKey.ProxyURL, req.Model)
return client.Do(request)
resp.SUCCESS(c, content)
}