mirror of
https://github.com/yangjian102621/geekai.git
synced 2025-11-06 17:23:42 +08:00
feat: allow chat model bind a fixed api key
This commit is contained in:
@@ -30,7 +30,7 @@ func (h *ChatHandler) sendAzureMessage(
|
||||
promptCreatedAt := time.Now() // 记录提问时间
|
||||
start := time.Now()
|
||||
var apiKey = model.ApiKey{}
|
||||
response, err := h.doRequest(ctx, req, session.Model.Platform, &apiKey)
|
||||
response, err := h.doRequest(ctx, req, session, &apiKey)
|
||||
logger.Info("HTTP请求完成,耗时:", time.Now().Sub(start))
|
||||
if err != nil {
|
||||
if strings.Contains(err.Error(), "context canceled") {
|
||||
|
||||
@@ -47,7 +47,7 @@ func (h *ChatHandler) sendBaiduMessage(
|
||||
promptCreatedAt := time.Now() // 记录提问时间
|
||||
start := time.Now()
|
||||
var apiKey = model.ApiKey{}
|
||||
response, err := h.doRequest(ctx, req, session.Model.Platform, &apiKey)
|
||||
response, err := h.doRequest(ctx, req, session, &apiKey)
|
||||
logger.Info("HTTP请求完成,耗时:", time.Now().Sub(start))
|
||||
if err != nil {
|
||||
if strings.Contains(err.Error(), "context canceled") {
|
||||
|
||||
@@ -122,6 +122,7 @@ func (h *ChatHandler) ChatHandle(c *gin.Context) {
|
||||
MaxTokens: chatModel.MaxTokens,
|
||||
MaxContext: chatModel.MaxContext,
|
||||
Temperature: chatModel.Temperature,
|
||||
KeyId: chatModel.KeyId,
|
||||
Platform: types.Platform(chatModel.Platform)}
|
||||
logger.Infof("New websocket connected, IP: %s, Username: %s", c.ClientIP(), session.Username)
|
||||
|
||||
@@ -463,13 +464,21 @@ func (h *ChatHandler) StopGenerate(c *gin.Context) {
|
||||
|
||||
// 发送请求到 OpenAI 服务器
|
||||
// useOwnApiKey: 是否使用了用户自己的 API KEY
|
||||
func (h *ChatHandler) doRequest(ctx context.Context, req types.ApiRequest, platform types.Platform, apiKey *model.ApiKey) (*http.Response, error) {
|
||||
res := h.DB.Where("platform = ?", platform).Where("type = ?", "chat").Where("enabled = ?", true).Order("last_used_at ASC").First(apiKey)
|
||||
func (h *ChatHandler) doRequest(ctx context.Context, req types.ApiRequest, session *types.ChatSession, apiKey *model.ApiKey) (*http.Response, error) {
|
||||
// if the chat model bind a KEY, use it directly
|
||||
var res *gorm.DB
|
||||
if session.Model.KeyId > 0 {
|
||||
res = h.DB.Where("id", session.Model.KeyId).Find(apiKey)
|
||||
}
|
||||
// use the last unused key
|
||||
if res.Error != nil {
|
||||
res = h.DB.Where("platform = ?", session.Model.Platform).Where("type = ?", "chat").Where("enabled = ?", true).Order("last_used_at ASC").First(apiKey)
|
||||
}
|
||||
if res.Error != nil {
|
||||
return nil, errors.New("no available key, please import key")
|
||||
}
|
||||
var apiURL string
|
||||
switch platform {
|
||||
switch session.Model.Platform {
|
||||
case types.Azure:
|
||||
md := strings.Replace(req.Model, ".", "", 1)
|
||||
apiURL = strings.Replace(apiKey.ApiURL, "{model}", md, 1)
|
||||
@@ -492,7 +501,7 @@ func (h *ChatHandler) doRequest(ctx context.Context, req types.ApiRequest, platf
|
||||
// 更新 API KEY 的最后使用时间
|
||||
h.DB.Model(apiKey).UpdateColumn("last_used_at", time.Now().Unix())
|
||||
// 百度文心,需要串接 access_token
|
||||
if platform == types.Baidu {
|
||||
if session.Model.Platform == types.Baidu {
|
||||
token, err := h.getBaiduToken(apiKey.Value)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -527,8 +536,8 @@ func (h *ChatHandler) doRequest(ctx context.Context, req types.ApiRequest, platf
|
||||
} else {
|
||||
client = http.DefaultClient
|
||||
}
|
||||
logger.Debugf("Sending %s request, ApiURL:%s, API KEY:%s, PROXY: %s, Model: %s", platform, apiURL, apiKey.Value, proxyURL, req.Model)
|
||||
switch platform {
|
||||
logger.Debugf("Sending %s request, ApiURL:%s, API KEY:%s, PROXY: %s, Model: %s", session.Model.Platform, apiURL, apiKey.Value, proxyURL, req.Model)
|
||||
switch session.Model.Platform {
|
||||
case types.Azure:
|
||||
request.Header.Set("api-key", apiKey.Value)
|
||||
break
|
||||
|
||||
@@ -31,7 +31,7 @@ func (h *ChatHandler) sendChatGLMMessage(
|
||||
promptCreatedAt := time.Now() // 记录提问时间
|
||||
start := time.Now()
|
||||
var apiKey = model.ApiKey{}
|
||||
response, err := h.doRequest(ctx, req, session.Model.Platform, &apiKey)
|
||||
response, err := h.doRequest(ctx, req, session, &apiKey)
|
||||
logger.Info("HTTP请求完成,耗时:", time.Now().Sub(start))
|
||||
if err != nil {
|
||||
if strings.Contains(err.Error(), "context canceled") {
|
||||
|
||||
@@ -31,7 +31,7 @@ func (h *ChatHandler) sendOpenAiMessage(
|
||||
promptCreatedAt := time.Now() // 记录提问时间
|
||||
start := time.Now()
|
||||
var apiKey = model.ApiKey{}
|
||||
response, err := h.doRequest(ctx, req, session.Model.Platform, &apiKey)
|
||||
response, err := h.doRequest(ctx, req, session, &apiKey)
|
||||
logger.Info("HTTP请求完成,耗时:", time.Now().Sub(start))
|
||||
if err != nil {
|
||||
if strings.Contains(err.Error(), "context canceled") {
|
||||
|
||||
@@ -45,7 +45,7 @@ func (h *ChatHandler) sendQWenMessage(
|
||||
promptCreatedAt := time.Now() // 记录提问时间
|
||||
start := time.Now()
|
||||
var apiKey = model.ApiKey{}
|
||||
response, err := h.doRequest(ctx, req, session.Model.Platform, &apiKey)
|
||||
response, err := h.doRequest(ctx, req, session, &apiKey)
|
||||
logger.Info("HTTP请求完成,耗时:", time.Now().Sub(start))
|
||||
if err != nil {
|
||||
if strings.Contains(err.Error(), "context canceled") {
|
||||
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"github.com/gorilla/websocket"
|
||||
"gorm.io/gorm"
|
||||
"html/template"
|
||||
"io"
|
||||
"net/http"
|
||||
@@ -69,7 +70,15 @@ func (h *ChatHandler) sendXunFeiMessage(
|
||||
ws *types.WsClient) error {
|
||||
promptCreatedAt := time.Now() // 记录提问时间
|
||||
var apiKey model.ApiKey
|
||||
res := h.DB.Where("platform = ?", session.Model.Platform).Where("type = ?", "chat").Where("enabled = ?", true).Order("last_used_at ASC").First(&apiKey)
|
||||
var res *gorm.DB
|
||||
// use the bind key
|
||||
if session.Model.KeyId > 0 {
|
||||
res = h.DB.Where("id", session.Model.KeyId).Find(&apiKey)
|
||||
}
|
||||
// use the last unused key
|
||||
if res.Error != nil {
|
||||
res = h.DB.Where("platform = ?", session.Model.Platform).Where("type = ?", "chat").Where("enabled = ?", true).Order("last_used_at ASC").First(&apiKey)
|
||||
}
|
||||
if res.Error != nil {
|
||||
utils.ReplyMessage(ws, "抱歉😔😔😔,系统已经没有可用的 API KEY,请联系管理员!")
|
||||
return nil
|
||||
|
||||
Reference in New Issue
Block a user