mirror of
https://github.com/linux-do/new-api.git
synced 2025-11-18 11:33:42 +08:00
Compare commits
27 Commits
v0.2.5-alp
...
v0.2.5-alp
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
80af3718d0 | ||
|
|
77ea6bec46 | ||
|
|
87919b032d | ||
|
|
f7a4f18aff | ||
|
|
706449dede | ||
|
|
36d164be0e | ||
|
|
d80a7d3c97 | ||
|
|
44a8ade4ba | ||
|
|
2cca2a989e | ||
|
|
3065bf92ae | ||
|
|
2e595bdafb | ||
|
|
49df4b6eed | ||
|
|
5c39f54040 | ||
|
|
c0ab8ae953 | ||
|
|
923c2dee32 | ||
|
|
786ccc7da0 | ||
|
|
8eedad9470 | ||
|
|
319e97d677 | ||
|
|
6114c9bb96 | ||
|
|
3cf2f0d5cb | ||
|
|
2a345ae070 | ||
|
|
d8c91fa448 | ||
|
|
cc8cc8b386 | ||
|
|
1587ea565b | ||
|
|
a7a1fc615d | ||
|
|
b2a280c1ec | ||
|
|
f1fb7b32a3 |
4
.github/workflows/docker-image-amd64.yml
vendored
4
.github/workflows/docker-image-amd64.yml
vendored
@@ -1,10 +1,6 @@
|
||||
name: Publish Docker image (amd64)
|
||||
|
||||
on:
|
||||
push:
|
||||
tags:
|
||||
- '*'
|
||||
- '!*-alpha*'
|
||||
workflow_dispatch:
|
||||
inputs:
|
||||
name:
|
||||
|
||||
1
.github/workflows/docker-image-arm64.yml
vendored
1
.github/workflows/docker-image-arm64.yml
vendored
@@ -4,7 +4,6 @@ on:
|
||||
push:
|
||||
tags:
|
||||
- '*'
|
||||
- '!*-alpha*'
|
||||
workflow_dispatch:
|
||||
inputs:
|
||||
name:
|
||||
|
||||
@@ -55,6 +55,7 @@ var WeChatAuthEnabled = false
|
||||
var TelegramOAuthEnabled = false
|
||||
var TurnstileCheckEnabled = false
|
||||
var RegisterEnabled = true
|
||||
var UserSelfDeletionEnabled = false
|
||||
|
||||
var EmailDomainRestrictionEnabled = false
|
||||
var EmailDomainWhitelist = []string{
|
||||
@@ -76,6 +77,7 @@ var LogConsumeEnabled = true
|
||||
|
||||
var SMTPServer = ""
|
||||
var SMTPPort = 587
|
||||
var SMTPSSLEnabled = false
|
||||
var SMTPAccount = ""
|
||||
var SMTPFrom = ""
|
||||
var SMTPToken = ""
|
||||
|
||||
@@ -24,7 +24,7 @@ func SendEmail(subject string, receiver string, content string) error {
|
||||
addr := fmt.Sprintf("%s:%d", SMTPServer, SMTPPort)
|
||||
to := strings.Split(receiver, ";")
|
||||
var err error
|
||||
if SMTPPort == 465 {
|
||||
if SMTPPort == 465 || SMTPSSLEnabled {
|
||||
tlsConfig := &tls.Config{
|
||||
InsecureSkipVerify: true,
|
||||
ServerName: SMTPServer,
|
||||
|
||||
@@ -71,8 +71,11 @@ var DefaultModelRatio = map[string]float64{
|
||||
"ERNIE-Bot-4": 8.572, // ¥0.12 / 1k tokens
|
||||
"Embedding-V1": 0.1429, // ¥0.002 / 1k tokens
|
||||
"PaLM-2": 1,
|
||||
"gemini-pro": 1, // $0.00025 / 1k characters -> $0.001 / 1k tokens
|
||||
"gemini-pro-vision": 1, // $0.00025 / 1k characters -> $0.001 / 1k tokens
|
||||
"gemini-pro": 1, // $0.00025 / 1k characters -> $0.001 / 1k tokens
|
||||
"gemini-pro-vision": 1, // $0.00025 / 1k characters -> $0.001 / 1k tokens
|
||||
"gemini-1.0-pro-vision-001": 1,
|
||||
"gemini-1.0-pro-001": 1,
|
||||
"gemini-1.5-pro": 1,
|
||||
"chatglm_turbo": 0.3572, // ¥0.005 / 1k tokens
|
||||
"chatglm_pro": 0.7143, // ¥0.01 / 1k tokens
|
||||
"chatglm_std": 0.3572, // ¥0.005 / 1k tokens
|
||||
@@ -212,5 +215,15 @@ func GetCompletionRatio(name string) float64 {
|
||||
} else if strings.HasPrefix(name, "claude-3") {
|
||||
return 5
|
||||
}
|
||||
if strings.HasPrefix(name, "mistral-") {
|
||||
return 3
|
||||
}
|
||||
if strings.HasPrefix(name, "gemini-") {
|
||||
return 3
|
||||
}
|
||||
switch name {
|
||||
case "llama2-70b-4096":
|
||||
return 0.8 / 0.7
|
||||
}
|
||||
return 1
|
||||
}
|
||||
|
||||
@@ -18,9 +18,8 @@ func InitRedisClient() (err error) {
|
||||
return nil
|
||||
}
|
||||
if os.Getenv("SYNC_FREQUENCY") == "" {
|
||||
RedisEnabled = false
|
||||
SysLog("SYNC_FREQUENCY not set, Redis is disabled")
|
||||
return nil
|
||||
SysLog("SYNC_FREQUENCY not set, use default value 60")
|
||||
SyncFrequency = 60
|
||||
}
|
||||
SysLog("Redis is enabled")
|
||||
opt, err := redis.ParseURL(os.Getenv("REDIS_CONN_STRING"))
|
||||
|
||||
@@ -4,7 +4,8 @@ import "strings"
|
||||
|
||||
var CheckSensitiveEnabled = true
|
||||
var CheckSensitiveOnPromptEnabled = true
|
||||
var CheckSensitiveOnCompletionEnabled = true
|
||||
|
||||
//var CheckSensitiveOnCompletionEnabled = true
|
||||
|
||||
// StopOnSensitiveEnabled 如果检测到敏感词,是否立刻停止生成,否则替换敏感词
|
||||
var StopOnSensitiveEnabled = true
|
||||
@@ -37,6 +38,6 @@ func ShouldCheckPromptSensitive() bool {
|
||||
return CheckSensitiveEnabled && CheckSensitiveOnPromptEnabled
|
||||
}
|
||||
|
||||
func ShouldCheckCompletionSensitive() bool {
|
||||
return CheckSensitiveEnabled && CheckSensitiveOnCompletionEnabled
|
||||
}
|
||||
//func ShouldCheckCompletionSensitive() bool {
|
||||
// return CheckSensitiveEnabled && CheckSensitiveOnCompletionEnabled
|
||||
//}
|
||||
|
||||
@@ -87,7 +87,7 @@ func testChannel(channel *model.Channel, testModel string) (err error, openaiErr
|
||||
err := relaycommon.RelayErrorHandler(resp)
|
||||
return fmt.Errorf("status code %d: %s", resp.StatusCode, err.Error.Message), &err.Error
|
||||
}
|
||||
usage, respErr, _ := adaptor.DoResponse(c, resp, meta)
|
||||
usage, respErr := adaptor.DoResponse(c, resp, meta)
|
||||
if respErr != nil {
|
||||
return fmt.Errorf("%s", respErr.Error.Message), &respErr.Error
|
||||
}
|
||||
@@ -108,6 +108,7 @@ func buildTestRequest() *dto.GeneralOpenAIRequest {
|
||||
testRequest := &dto.GeneralOpenAIRequest{
|
||||
Model: "", // this will be set later
|
||||
MaxTokens: 1,
|
||||
Stream: false,
|
||||
}
|
||||
content, _ := json.Marshal("hi")
|
||||
testMessage := dto.Message{
|
||||
|
||||
@@ -147,7 +147,7 @@ func UpdateMidjourneyTaskBulk() {
|
||||
task.Buttons = string(buttonStr)
|
||||
}
|
||||
|
||||
if task.Progress != "100%" && responseItem.FailReason != "" {
|
||||
if (task.Progress != "100%" && responseItem.FailReason != "") || (task.Progress == "100%" && task.Status == "FAILURE") {
|
||||
common.LogInfo(ctx, task.MjId+" 构建失败,"+task.FailReason)
|
||||
task.Progress = "100%"
|
||||
err = model.CacheUpdateUserQuota(task.UserId)
|
||||
|
||||
@@ -558,6 +558,14 @@ func HardDeleteUser(c *gin.Context) {
|
||||
}
|
||||
|
||||
func DeleteSelf(c *gin.Context) {
|
||||
if !common.UserSelfDeletionEnabled {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "当前设置不允许用户自我删除账号",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
id := c.GetInt("id")
|
||||
user, _ := model.GetUserById(id, false)
|
||||
|
||||
|
||||
@@ -11,6 +11,12 @@ type TextResponseWithError struct {
|
||||
Error OpenAIError `json:"error"`
|
||||
}
|
||||
|
||||
type SimpleResponse struct {
|
||||
Usage `json:"usage"`
|
||||
Error OpenAIError `json:"error"`
|
||||
Choices []OpenAITextResponseChoice `json:"choices"`
|
||||
}
|
||||
|
||||
type TextResponse struct {
|
||||
Id string `json:"id"`
|
||||
Object string `json:"object"`
|
||||
|
||||
@@ -44,7 +44,7 @@ func Distribute() func(c *gin.Context) {
|
||||
// Select a channel for the user
|
||||
var modelRequest ModelRequest
|
||||
var err error
|
||||
if strings.HasPrefix(c.Request.URL.Path, "/mj") {
|
||||
if strings.Contains(c.Request.URL.Path, "/mj/") {
|
||||
relayMode := relayconstant.Path2RelayModeMidjourney(c.Request.URL.Path)
|
||||
if relayMode == relayconstant.RelayModeMidjourneyTaskFetch ||
|
||||
relayMode == relayconstant.RelayModeMidjourneyTaskFetchByCondition ||
|
||||
|
||||
@@ -36,6 +36,7 @@ func InitOptionMap() {
|
||||
common.OptionMap["WeChatAuthEnabled"] = strconv.FormatBool(common.WeChatAuthEnabled)
|
||||
common.OptionMap["TurnstileCheckEnabled"] = strconv.FormatBool(common.TurnstileCheckEnabled)
|
||||
common.OptionMap["RegisterEnabled"] = strconv.FormatBool(common.RegisterEnabled)
|
||||
common.OptionMap["UserSelfDeletionEnabled"] = strconv.FormatBool(common.UserSelfDeletionEnabled)
|
||||
common.OptionMap["AutomaticDisableChannelEnabled"] = strconv.FormatBool(common.AutomaticDisableChannelEnabled)
|
||||
common.OptionMap["AutomaticEnableChannelEnabled"] = strconv.FormatBool(common.AutomaticEnableChannelEnabled)
|
||||
common.OptionMap["LogConsumeEnabled"] = strconv.FormatBool(common.LogConsumeEnabled)
|
||||
@@ -51,6 +52,7 @@ func InitOptionMap() {
|
||||
common.OptionMap["SMTPPort"] = strconv.Itoa(common.SMTPPort)
|
||||
common.OptionMap["SMTPAccount"] = ""
|
||||
common.OptionMap["SMTPToken"] = ""
|
||||
common.OptionMap["SMTPSSLEnabled"] = strconv.FormatBool(common.SMTPSSLEnabled)
|
||||
common.OptionMap["Notice"] = ""
|
||||
common.OptionMap["About"] = ""
|
||||
common.OptionMap["HomePageContent"] = ""
|
||||
@@ -96,7 +98,7 @@ func InitOptionMap() {
|
||||
common.OptionMap["MjNotifyEnabled"] = strconv.FormatBool(constant.MjNotifyEnabled)
|
||||
common.OptionMap["CheckSensitiveEnabled"] = strconv.FormatBool(constant.CheckSensitiveEnabled)
|
||||
common.OptionMap["CheckSensitiveOnPromptEnabled"] = strconv.FormatBool(constant.CheckSensitiveOnPromptEnabled)
|
||||
common.OptionMap["CheckSensitiveOnCompletionEnabled"] = strconv.FormatBool(constant.CheckSensitiveOnCompletionEnabled)
|
||||
//common.OptionMap["CheckSensitiveOnCompletionEnabled"] = strconv.FormatBool(constant.CheckSensitiveOnCompletionEnabled)
|
||||
common.OptionMap["StopOnSensitiveEnabled"] = strconv.FormatBool(constant.StopOnSensitiveEnabled)
|
||||
common.OptionMap["SensitiveWords"] = constant.SensitiveWordsToString()
|
||||
common.OptionMap["StreamCacheQueueLength"] = strconv.Itoa(constant.StreamCacheQueueLength)
|
||||
@@ -177,6 +179,8 @@ func updateOptionMap(key string, value string) (err error) {
|
||||
common.TurnstileCheckEnabled = boolValue
|
||||
case "RegisterEnabled":
|
||||
common.RegisterEnabled = boolValue
|
||||
case "UserSelfDeletionEnabled":
|
||||
common.UserSelfDeletionEnabled = boolValue
|
||||
case "EmailDomainRestrictionEnabled":
|
||||
common.EmailDomainRestrictionEnabled = boolValue
|
||||
case "AutomaticDisableChannelEnabled":
|
||||
@@ -201,10 +205,12 @@ func updateOptionMap(key string, value string) (err error) {
|
||||
constant.CheckSensitiveEnabled = boolValue
|
||||
case "CheckSensitiveOnPromptEnabled":
|
||||
constant.CheckSensitiveOnPromptEnabled = boolValue
|
||||
case "CheckSensitiveOnCompletionEnabled":
|
||||
constant.CheckSensitiveOnCompletionEnabled = boolValue
|
||||
//case "CheckSensitiveOnCompletionEnabled":
|
||||
// constant.CheckSensitiveOnCompletionEnabled = boolValue
|
||||
case "StopOnSensitiveEnabled":
|
||||
constant.StopOnSensitiveEnabled = boolValue
|
||||
case "SMTPSSLEnabled":
|
||||
common.SMTPSSLEnabled = boolValue
|
||||
}
|
||||
}
|
||||
switch key {
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"one-api/common"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
@@ -75,8 +76,26 @@ func GetAllUsers(startIdx int, num int) (users []*User, err error) {
|
||||
return users, err
|
||||
}
|
||||
|
||||
func SearchUsers(keyword string) (users []*User, err error) {
|
||||
err = DB.Omit("password").Where("id = ? or username LIKE ? or email LIKE ? or display_name LIKE ?", keyword, keyword+"%", keyword+"%", keyword+"%").Find(&users).Error
|
||||
func SearchUsers(keyword string) ([]*User, error) {
|
||||
var users []*User
|
||||
var err error
|
||||
|
||||
// 尝试将关键字转换为整数ID
|
||||
keywordInt, err := strconv.Atoi(keyword)
|
||||
if err == nil {
|
||||
// 如果转换成功,按照ID搜索用户
|
||||
err = DB.Unscoped().Omit("password").Where("id = ?", keywordInt).Find(&users).Error
|
||||
if err != nil || len(users) > 0 {
|
||||
// 如果依据ID找到用户或者发生错误,返回结果或错误
|
||||
return users, err
|
||||
}
|
||||
}
|
||||
|
||||
// 如果ID转换失败或者没有找到用户,依据其他字段进行模糊搜索
|
||||
err = DB.Unscoped().Omit("password").
|
||||
Where("username LIKE ? OR email LIKE ? OR display_name LIKE ?", keyword+"%", keyword+"%", keyword+"%").
|
||||
Find(&users).Error
|
||||
|
||||
return users, err
|
||||
}
|
||||
|
||||
|
||||
@@ -15,7 +15,7 @@ type Adaptor interface {
|
||||
SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error
|
||||
ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error)
|
||||
DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error)
|
||||
DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode, sensitiveResp *dto.SensitiveResponse)
|
||||
DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode)
|
||||
GetModelList() []string
|
||||
GetChannelName() string
|
||||
}
|
||||
|
||||
@@ -57,7 +57,7 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
|
||||
return channel.DoApiRequest(a, c, info, requestBody)
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode, sensitiveResp *dto.SensitiveResponse) {
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
|
||||
if info.IsStream {
|
||||
err, usage = aliStreamHandler(c, resp)
|
||||
} else {
|
||||
|
||||
@@ -69,7 +69,7 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
|
||||
return channel.DoApiRequest(a, c, info, requestBody)
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode, sensitiveResp *dto.SensitiveResponse) {
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
|
||||
if info.IsStream {
|
||||
err, usage = baiduStreamHandler(c, resp)
|
||||
} else {
|
||||
|
||||
@@ -63,7 +63,7 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
|
||||
return channel.DoApiRequest(a, c, info, requestBody)
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode, sensitiveResp *dto.SensitiveResponse) {
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
|
||||
if info.IsStream {
|
||||
err, usage = claudeStreamHandler(a.RequestMode, info.UpstreamModelName, info.PromptTokens, c, resp)
|
||||
} else {
|
||||
|
||||
@@ -8,7 +8,6 @@ import (
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/constant"
|
||||
"one-api/dto"
|
||||
"one-api/service"
|
||||
"strings"
|
||||
@@ -317,7 +316,7 @@ func claudeHandler(requestMode int, c *gin.Context, resp *http.Response, promptT
|
||||
}, nil
|
||||
}
|
||||
fullTextResponse := responseClaude2OpenAI(requestMode, &claudeResponse)
|
||||
completionTokens, err, _ := service.CountTokenText(claudeResponse.Completion, model, constant.ShouldCheckCompletionSensitive())
|
||||
completionTokens, err, _ := service.CountTokenText(claudeResponse.Completion, model, false)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "count_token_text_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
|
||||
@@ -47,7 +47,7 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
|
||||
return channel.DoApiRequest(a, c, info, requestBody)
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode, sensitiveResp *dto.SensitiveResponse) {
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
|
||||
if info.IsStream {
|
||||
var responseText string
|
||||
err, responseText = geminiChatStreamHandler(c, resp)
|
||||
|
||||
@@ -5,8 +5,8 @@ const (
|
||||
)
|
||||
|
||||
var ModelList = []string{
|
||||
"gemini-pro",
|
||||
"gemini-pro-vision",
|
||||
"gemini-pro", "gemini-1.0-pro-001", "gemini-1.5-pro",
|
||||
"gemini-pro-vision", "gemini-1.0-pro-vision-001",
|
||||
}
|
||||
|
||||
var ChannelName = "google gemini"
|
||||
|
||||
@@ -7,7 +7,6 @@ import (
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/constant"
|
||||
"one-api/dto"
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/service"
|
||||
@@ -257,7 +256,7 @@ func geminiChatHandler(c *gin.Context, resp *http.Response, promptTokens int, mo
|
||||
}, nil
|
||||
}
|
||||
fullTextResponse := responseGeminiChat2OpenAI(&geminiResponse)
|
||||
completionTokens, _, _ := service.CountTokenText(geminiResponse.GetResponseText(), model, constant.ShouldCheckCompletionSensitive())
|
||||
completionTokens, _, _ := service.CountTokenText(geminiResponse.GetResponseText(), model, false)
|
||||
usage := dto.Usage{
|
||||
PromptTokens: promptTokens,
|
||||
CompletionTokens: completionTokens,
|
||||
|
||||
@@ -2,7 +2,6 @@ package ollama
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"io"
|
||||
"net/http"
|
||||
@@ -10,6 +9,7 @@ import (
|
||||
"one-api/relay/channel"
|
||||
"one-api/relay/channel/openai"
|
||||
relaycommon "one-api/relay/common"
|
||||
relayconstant "one-api/relay/constant"
|
||||
"one-api/service"
|
||||
)
|
||||
|
||||
@@ -20,7 +20,12 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIReq
|
||||
}
|
||||
|
||||
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
return fmt.Sprintf("%s/api/chat", info.BaseUrl), nil
|
||||
switch info.RelayMode {
|
||||
case relayconstant.RelayModeEmbeddings:
|
||||
return info.BaseUrl + "/api/embeddings", nil
|
||||
default:
|
||||
return relaycommon.GetFullRequestURL(info.BaseUrl, info.RequestURLPath, info.ChannelType), nil
|
||||
}
|
||||
}
|
||||
|
||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
|
||||
@@ -32,20 +37,29 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.Gen
|
||||
if request == nil {
|
||||
return nil, errors.New("request is nil")
|
||||
}
|
||||
return requestOpenAI2Ollama(*request), nil
|
||||
switch relayMode {
|
||||
case relayconstant.RelayModeEmbeddings:
|
||||
return requestOpenAI2Embeddings(*request), nil
|
||||
default:
|
||||
return requestOpenAI2Ollama(*request), nil
|
||||
}
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
|
||||
return channel.DoApiRequest(a, c, info, requestBody)
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode, sensitiveResp *dto.SensitiveResponse) {
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
|
||||
if info.IsStream {
|
||||
var responseText string
|
||||
err, responseText = openai.OpenaiStreamHandler(c, resp, info.RelayMode)
|
||||
usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
|
||||
} else {
|
||||
err, usage, sensitiveResp = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName, info.RelayMode)
|
||||
if info.RelayMode == relayconstant.RelayModeEmbeddings {
|
||||
err, usage = ollamaEmbeddingHandler(c, resp, info.PromptTokens, info.UpstreamModelName, info.RelayMode)
|
||||
} else {
|
||||
err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
@@ -3,16 +3,24 @@ package ollama
|
||||
import "one-api/dto"
|
||||
|
||||
type OllamaRequest struct {
|
||||
Model string `json:"model,omitempty"`
|
||||
Messages []dto.Message `json:"messages,omitempty"`
|
||||
Stream bool `json:"stream,omitempty"`
|
||||
Options *OllamaOptions `json:"options,omitempty"`
|
||||
Model string `json:"model,omitempty"`
|
||||
Messages []dto.Message `json:"messages,omitempty"`
|
||||
Stream bool `json:"stream,omitempty"`
|
||||
Temperature float64 `json:"temperature,omitempty"`
|
||||
Seed float64 `json:"seed,omitempty"`
|
||||
Topp float64 `json:"top_p,omitempty"`
|
||||
TopK int `json:"top_k,omitempty"`
|
||||
Stop any `json:"stop,omitempty"`
|
||||
}
|
||||
|
||||
type OllamaOptions struct {
|
||||
Temperature float64 `json:"temperature,omitempty"`
|
||||
Seed float64 `json:"seed,omitempty"`
|
||||
Topp float64 `json:"top_p,omitempty"`
|
||||
TopK int `json:"top_k,omitempty"`
|
||||
Stop any `json:"stop,omitempty"`
|
||||
type OllamaEmbeddingRequest struct {
|
||||
Model string `json:"model,omitempty"`
|
||||
Prompt any `json:"prompt,omitempty"`
|
||||
}
|
||||
|
||||
type OllamaEmbeddingResponse struct {
|
||||
Embedding []float64 `json:"embedding,omitempty"`
|
||||
}
|
||||
|
||||
//type OllamaOptions struct {
|
||||
//}
|
||||
|
||||
@@ -1,6 +1,15 @@
|
||||
package ollama
|
||||
|
||||
import "one-api/dto"
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/dto"
|
||||
"one-api/service"
|
||||
)
|
||||
|
||||
func requestOpenAI2Ollama(request dto.GeneralOpenAIRequest) *OllamaRequest {
|
||||
messages := make([]dto.Message, 0, len(request.Messages))
|
||||
@@ -18,15 +27,82 @@ func requestOpenAI2Ollama(request dto.GeneralOpenAIRequest) *OllamaRequest {
|
||||
Stop, _ = request.Stop.([]string)
|
||||
}
|
||||
return &OllamaRequest{
|
||||
Model: request.Model,
|
||||
Messages: messages,
|
||||
Stream: request.Stream,
|
||||
Options: &OllamaOptions{
|
||||
Temperature: request.Temperature,
|
||||
Seed: request.Seed,
|
||||
Topp: request.TopP,
|
||||
TopK: request.TopK,
|
||||
Stop: Stop,
|
||||
},
|
||||
Model: request.Model,
|
||||
Messages: messages,
|
||||
Stream: request.Stream,
|
||||
Temperature: request.Temperature,
|
||||
Seed: request.Seed,
|
||||
Topp: request.TopP,
|
||||
TopK: request.TopK,
|
||||
Stop: Stop,
|
||||
}
|
||||
}
|
||||
|
||||
func requestOpenAI2Embeddings(request dto.GeneralOpenAIRequest) *OllamaEmbeddingRequest {
|
||||
return &OllamaEmbeddingRequest{
|
||||
Model: request.Model,
|
||||
Prompt: request.Input,
|
||||
}
|
||||
}
|
||||
|
||||
func ollamaEmbeddingHandler(c *gin.Context, resp *http.Response, promptTokens int, model string, relayMode int) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
||||
var ollamaEmbeddingResponse OllamaEmbeddingResponse
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
err = resp.Body.Close()
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
err = json.Unmarshal(responseBody, &ollamaEmbeddingResponse)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
data := make([]dto.OpenAIEmbeddingResponseItem, 0, 1)
|
||||
data = append(data, dto.OpenAIEmbeddingResponseItem{
|
||||
Embedding: ollamaEmbeddingResponse.Embedding,
|
||||
Object: "embedding",
|
||||
})
|
||||
usage := &dto.Usage{
|
||||
TotalTokens: promptTokens,
|
||||
CompletionTokens: 0,
|
||||
PromptTokens: promptTokens,
|
||||
}
|
||||
embeddingResponse := &dto.OpenAIEmbeddingResponse{
|
||||
Object: "list",
|
||||
Data: data,
|
||||
Model: model,
|
||||
Usage: *usage,
|
||||
}
|
||||
doResponseBody, err := json.Marshal(embeddingResponse)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
resp.Body = io.NopCloser(bytes.NewBuffer(doResponseBody))
|
||||
// We shouldn't set the header before we parse the response body, because the parse part may fail.
|
||||
// And then we will have to send an error response, but in this case, the header has already been set.
|
||||
// So the httpClient will be confused by the response.
|
||||
// For example, Postman will report error, and we cannot check the response at all.
|
||||
// Copy headers
|
||||
for k, v := range resp.Header {
|
||||
// 删除任何现有的相同头部,以防止重复添加头部
|
||||
c.Writer.Header().Del(k)
|
||||
for _, vv := range v {
|
||||
c.Writer.Header().Add(k, vv)
|
||||
}
|
||||
}
|
||||
// reset content length
|
||||
c.Writer.Header().Del("Content-Length")
|
||||
c.Writer.Header().Set("Content-Length", fmt.Sprintf("%d", len(doResponseBody)))
|
||||
c.Writer.WriteHeader(resp.StatusCode)
|
||||
_, err = io.Copy(c.Writer, resp.Body)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
err = resp.Body.Close()
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
return nil, usage
|
||||
}
|
||||
|
||||
@@ -10,8 +10,8 @@ import (
|
||||
"one-api/dto"
|
||||
"one-api/relay/channel"
|
||||
"one-api/relay/channel/ai360"
|
||||
"one-api/relay/channel/moonshot"
|
||||
"one-api/relay/channel/lingyiwanwu"
|
||||
"one-api/relay/channel/moonshot"
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/service"
|
||||
"strings"
|
||||
@@ -34,9 +34,6 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
model_ := info.UpstreamModelName
|
||||
model_ = strings.Replace(model_, ".", "", -1)
|
||||
// https://github.com/songquanpeng/one-api/issues/67
|
||||
model_ = strings.TrimSuffix(model_, "-0301")
|
||||
model_ = strings.TrimSuffix(model_, "-0314")
|
||||
model_ = strings.TrimSuffix(model_, "-0613")
|
||||
|
||||
requestURL = fmt.Sprintf("/openai/deployments/%s/%s", model_, task)
|
||||
return relaycommon.GetFullRequestURL(info.BaseUrl, requestURL, info.ChannelType), nil
|
||||
@@ -72,13 +69,13 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
|
||||
return channel.DoApiRequest(a, c, info, requestBody)
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode, sensitiveResp *dto.SensitiveResponse) {
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
|
||||
if info.IsStream {
|
||||
var responseText string
|
||||
err, responseText = OpenaiStreamHandler(c, resp, info.RelayMode)
|
||||
usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
|
||||
} else {
|
||||
err, usage, sensitiveResp = OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName, info.RelayMode)
|
||||
err, usage = OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
@@ -4,14 +4,10 @@ import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/constant"
|
||||
"one-api/dto"
|
||||
relayconstant "one-api/relay/constant"
|
||||
"one-api/service"
|
||||
@@ -21,7 +17,7 @@ import (
|
||||
)
|
||||
|
||||
func OpenaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*dto.OpenAIErrorWithStatusCode, string) {
|
||||
checkSensitive := constant.ShouldCheckCompletionSensitive()
|
||||
//checkSensitive := constant.ShouldCheckCompletionSensitive()
|
||||
var responseTextBuilder strings.Builder
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
|
||||
@@ -53,20 +49,11 @@ func OpenaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*d
|
||||
if data[:6] != "data: " && data[:6] != "[DONE]" {
|
||||
continue
|
||||
}
|
||||
sensitive := false
|
||||
if checkSensitive {
|
||||
// check sensitive
|
||||
sensitive, _, data = service.SensitiveWordReplace(data, false)
|
||||
}
|
||||
dataChan <- data
|
||||
data = data[6:]
|
||||
if !strings.HasPrefix(data, "[DONE]") {
|
||||
streamItems = append(streamItems, data)
|
||||
}
|
||||
if sensitive && constant.StopOnSensitiveEnabled {
|
||||
dataChan <- "data: [DONE]"
|
||||
break
|
||||
}
|
||||
}
|
||||
streamResp := "[" + strings.Join(streamItems, ",") + "]"
|
||||
switch relayMode {
|
||||
@@ -142,118 +129,56 @@ func OpenaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*d
|
||||
return nil, responseTextBuilder.String()
|
||||
}
|
||||
|
||||
func OpenaiHandler(c *gin.Context, resp *http.Response, promptTokens int, model string, relayMode int) (*dto.OpenAIErrorWithStatusCode, *dto.Usage, *dto.SensitiveResponse) {
|
||||
var responseWithError dto.TextResponseWithError
|
||||
func OpenaiHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
||||
var simpleResponse dto.SimpleResponse
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil, nil
|
||||
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
err = resp.Body.Close()
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil, nil
|
||||
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
err = json.Unmarshal(responseBody, &responseWithError)
|
||||
err = json.Unmarshal(responseBody, &simpleResponse)
|
||||
if err != nil {
|
||||
log.Printf("unmarshal_response_body_failed: body: %s, err: %v", string(responseBody), err)
|
||||
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil, nil
|
||||
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
if responseWithError.Error.Type != "" {
|
||||
if simpleResponse.Error.Type != "" {
|
||||
return &dto.OpenAIErrorWithStatusCode{
|
||||
Error: responseWithError.Error,
|
||||
Error: simpleResponse.Error,
|
||||
StatusCode: resp.StatusCode,
|
||||
}, nil, nil
|
||||
}, nil
|
||||
}
|
||||
// Reset response body
|
||||
resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
|
||||
// We shouldn't set the header before we parse the response body, because the parse part may fail.
|
||||
// And then we will have to send an error response, but in this case, the header has already been set.
|
||||
// So the httpClient will be confused by the response.
|
||||
// For example, Postman will report error, and we cannot check the response at all.
|
||||
for k, v := range resp.Header {
|
||||
c.Writer.Header().Set(k, v[0])
|
||||
}
|
||||
c.Writer.WriteHeader(resp.StatusCode)
|
||||
_, err = io.Copy(c.Writer, resp.Body)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
err = resp.Body.Close()
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
|
||||
checkSensitive := constant.ShouldCheckCompletionSensitive()
|
||||
sensitiveWords := make([]string, 0)
|
||||
triggerSensitive := false
|
||||
|
||||
usage := &responseWithError.Usage
|
||||
|
||||
//textResponse := &dto.TextResponse{
|
||||
// Choices: responseWithError.Choices,
|
||||
// Usage: responseWithError.Usage,
|
||||
//}
|
||||
var doResponseBody []byte
|
||||
|
||||
switch relayMode {
|
||||
case relayconstant.RelayModeEmbeddings:
|
||||
embeddingResponse := &dto.OpenAIEmbeddingResponse{
|
||||
Object: responseWithError.Object,
|
||||
Data: responseWithError.Data,
|
||||
Model: responseWithError.Model,
|
||||
Usage: *usage,
|
||||
if simpleResponse.Usage.TotalTokens == 0 {
|
||||
completionTokens := 0
|
||||
for _, choice := range simpleResponse.Choices {
|
||||
ctkm, _, _ := service.CountTokenText(string(choice.Message.Content), model, false)
|
||||
completionTokens += ctkm
|
||||
}
|
||||
doResponseBody, err = json.Marshal(embeddingResponse)
|
||||
default:
|
||||
if responseWithError.Usage.TotalTokens == 0 || checkSensitive {
|
||||
completionTokens := 0
|
||||
for i, choice := range responseWithError.Choices {
|
||||
stringContent := string(choice.Message.Content)
|
||||
ctkm, _, _ := service.CountTokenText(stringContent, model, false)
|
||||
completionTokens += ctkm
|
||||
if checkSensitive {
|
||||
sensitive, words, stringContent := service.SensitiveWordReplace(stringContent, false)
|
||||
if sensitive {
|
||||
triggerSensitive = true
|
||||
msg := choice.Message
|
||||
msg.Content = common.StringToByteSlice(stringContent)
|
||||
responseWithError.Choices[i].Message = msg
|
||||
sensitiveWords = append(sensitiveWords, words...)
|
||||
}
|
||||
}
|
||||
}
|
||||
responseWithError.Usage = dto.Usage{
|
||||
PromptTokens: promptTokens,
|
||||
CompletionTokens: completionTokens,
|
||||
TotalTokens: promptTokens + completionTokens,
|
||||
}
|
||||
}
|
||||
textResponse := &dto.TextResponse{
|
||||
Id: responseWithError.Id,
|
||||
Created: responseWithError.Created,
|
||||
Object: responseWithError.Object,
|
||||
Choices: responseWithError.Choices,
|
||||
Model: responseWithError.Model,
|
||||
Usage: *usage,
|
||||
}
|
||||
doResponseBody, err = json.Marshal(textResponse)
|
||||
}
|
||||
|
||||
if checkSensitive && triggerSensitive && constant.StopOnSensitiveEnabled {
|
||||
sensitiveWords = common.RemoveDuplicate(sensitiveWords)
|
||||
return service.OpenAIErrorWrapper(errors.New(fmt.Sprintf("sensitive words detected on response: %s",
|
||||
strings.Join(sensitiveWords, ", "))), "sensitive_words_detected", http.StatusBadRequest),
|
||||
usage, &dto.SensitiveResponse{
|
||||
SensitiveWords: sensitiveWords,
|
||||
}
|
||||
} else {
|
||||
// Reset response body
|
||||
resp.Body = io.NopCloser(bytes.NewBuffer(doResponseBody))
|
||||
// We shouldn't set the header before we parse the response body, because the parse part may fail.
|
||||
// And then we will have to send an error response, but in this case, the header has already been set.
|
||||
// So the httpClient will be confused by the response.
|
||||
// For example, Postman will report error, and we cannot check the response at all.
|
||||
// Copy headers
|
||||
for k, v := range resp.Header {
|
||||
// 删除任何现有的相同头部,以防止重复添加头部
|
||||
c.Writer.Header().Del(k)
|
||||
for _, vv := range v {
|
||||
c.Writer.Header().Add(k, vv)
|
||||
}
|
||||
}
|
||||
// reset content length
|
||||
c.Writer.Header().Del("Content-Length")
|
||||
c.Writer.Header().Set("Content-Length", fmt.Sprintf("%d", len(doResponseBody)))
|
||||
c.Writer.WriteHeader(resp.StatusCode)
|
||||
_, err = io.Copy(c.Writer, resp.Body)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil, nil
|
||||
}
|
||||
err = resp.Body.Close()
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil, nil
|
||||
simpleResponse.Usage = dto.Usage{
|
||||
PromptTokens: promptTokens,
|
||||
CompletionTokens: completionTokens,
|
||||
TotalTokens: promptTokens + completionTokens,
|
||||
}
|
||||
}
|
||||
return nil, usage, nil
|
||||
return nil, &simpleResponse.Usage
|
||||
}
|
||||
|
||||
@@ -39,7 +39,7 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
|
||||
return channel.DoApiRequest(a, c, info, requestBody)
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode, sensitiveResp *dto.SensitiveResponse) {
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
|
||||
if info.IsStream {
|
||||
var responseText string
|
||||
err, responseText = palmStreamHandler(c, resp)
|
||||
|
||||
@@ -7,7 +7,6 @@ import (
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/constant"
|
||||
"one-api/dto"
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/service"
|
||||
@@ -157,7 +156,7 @@ func palmHandler(c *gin.Context, resp *http.Response, promptTokens int, model st
|
||||
}, nil
|
||||
}
|
||||
fullTextResponse := responsePaLM2OpenAI(&palmResponse)
|
||||
completionTokens, _, _ := service.CountTokenText(palmResponse.Candidates[0].Content, model, constant.ShouldCheckCompletionSensitive())
|
||||
completionTokens, _, _ := service.CountTokenText(palmResponse.Candidates[0].Content, model, false)
|
||||
usage := dto.Usage{
|
||||
PromptTokens: promptTokens,
|
||||
CompletionTokens: completionTokens,
|
||||
|
||||
@@ -43,13 +43,13 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
|
||||
return channel.DoApiRequest(a, c, info, requestBody)
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode, sensitiveResp *dto.SensitiveResponse) {
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
|
||||
if info.IsStream {
|
||||
var responseText string
|
||||
err, responseText = openai.OpenaiStreamHandler(c, resp, info.RelayMode)
|
||||
usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
|
||||
} else {
|
||||
err, usage, sensitiveResp = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName, info.RelayMode)
|
||||
err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
@@ -53,7 +53,7 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
|
||||
return channel.DoApiRequest(a, c, info, requestBody)
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode, sensitiveResp *dto.SensitiveResponse) {
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
|
||||
if info.IsStream {
|
||||
var responseText string
|
||||
err, responseText = tencentStreamHandler(c, resp)
|
||||
|
||||
@@ -43,13 +43,13 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
|
||||
return dummyResp, nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode, sensitiveResp *dto.SensitiveResponse) {
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
|
||||
splits := strings.Split(info.ApiKey, "|")
|
||||
if len(splits) != 3 {
|
||||
return nil, service.OpenAIErrorWrapper(errors.New("invalid auth"), "invalid_auth", http.StatusBadRequest), nil
|
||||
return nil, service.OpenAIErrorWrapper(errors.New("invalid auth"), "invalid_auth", http.StatusBadRequest)
|
||||
}
|
||||
if a.request == nil {
|
||||
return nil, service.OpenAIErrorWrapper(errors.New("request is nil"), "request_is_nil", http.StatusBadRequest), nil
|
||||
return nil, service.OpenAIErrorWrapper(errors.New("request is nil"), "request_is_nil", http.StatusBadRequest)
|
||||
}
|
||||
if info.IsStream {
|
||||
err, usage = xunfeiStreamHandler(c, *a.request, splits[0], splits[1], splits[2])
|
||||
|
||||
@@ -46,7 +46,7 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
|
||||
return channel.DoApiRequest(a, c, info, requestBody)
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode, sensitiveResp *dto.SensitiveResponse) {
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
|
||||
if info.IsStream {
|
||||
err, usage = zhipuStreamHandler(c, resp)
|
||||
} else {
|
||||
|
||||
@@ -44,13 +44,13 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
|
||||
return channel.DoApiRequest(a, c, info, requestBody)
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode, sensitiveResp *dto.SensitiveResponse) {
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
|
||||
if info.IsStream {
|
||||
var responseText string
|
||||
err, responseText = openai.OpenaiStreamHandler(c, resp, info.RelayMode)
|
||||
usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
|
||||
} else {
|
||||
err, usage, sensitiveResp = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName, info.RelayMode)
|
||||
err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
@@ -56,29 +56,29 @@ func Path2RelayMode(path string) int {
|
||||
|
||||
func Path2RelayModeMidjourney(path string) int {
|
||||
relayMode := RelayModeUnknown
|
||||
if strings.HasPrefix(path, "/mj/submit/action") {
|
||||
if strings.HasSuffix(path, "/mj/submit/action") {
|
||||
// midjourney plus
|
||||
relayMode = RelayModeMidjourneyAction
|
||||
} else if strings.HasPrefix(path, "/mj/submit/modal") {
|
||||
} else if strings.HasSuffix(path, "/mj/submit/modal") {
|
||||
// midjourney plus
|
||||
relayMode = RelayModeMidjourneyModal
|
||||
} else if strings.HasPrefix(path, "/mj/submit/shorten") {
|
||||
} else if strings.HasSuffix(path, "/mj/submit/shorten") {
|
||||
// midjourney plus
|
||||
relayMode = RelayModeMidjourneyShorten
|
||||
} else if strings.HasPrefix(path, "/mj/insight-face/swap") {
|
||||
} else if strings.HasSuffix(path, "/mj/insight-face/swap") {
|
||||
// midjourney plus
|
||||
relayMode = RelayModeSwapFace
|
||||
} else if strings.HasPrefix(path, "/mj/submit/imagine") {
|
||||
} else if strings.HasSuffix(path, "/mj/submit/imagine") {
|
||||
relayMode = RelayModeMidjourneyImagine
|
||||
} else if strings.HasPrefix(path, "/mj/submit/blend") {
|
||||
} else if strings.HasSuffix(path, "/mj/submit/blend") {
|
||||
relayMode = RelayModeMidjourneyBlend
|
||||
} else if strings.HasPrefix(path, "/mj/submit/describe") {
|
||||
} else if strings.HasSuffix(path, "/mj/submit/describe") {
|
||||
relayMode = RelayModeMidjourneyDescribe
|
||||
} else if strings.HasPrefix(path, "/mj/notify") {
|
||||
} else if strings.HasSuffix(path, "/mj/notify") {
|
||||
relayMode = RelayModeMidjourneyNotify
|
||||
} else if strings.HasPrefix(path, "/mj/submit/change") {
|
||||
} else if strings.HasSuffix(path, "/mj/submit/change") {
|
||||
relayMode = RelayModeMidjourneyChange
|
||||
} else if strings.HasPrefix(path, "/mj/submit/simple-change") {
|
||||
} else if strings.HasSuffix(path, "/mj/submit/simple-change") {
|
||||
relayMode = RelayModeMidjourneyChange
|
||||
} else if strings.HasSuffix(path, "/fetch") {
|
||||
relayMode = RelayModeMidjourneyTaskFetch
|
||||
|
||||
@@ -173,7 +173,7 @@ func AudioHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode {
|
||||
if strings.HasPrefix(audioRequest.Model, "tts-1") {
|
||||
quota = promptTokens
|
||||
} else {
|
||||
quota, err, _ = service.CountAudioToken(audioResponse.Text, audioRequest.Model, constant.ShouldCheckCompletionSensitive())
|
||||
quota, err, _ = service.CountAudioToken(audioResponse.Text, audioRequest.Model, false)
|
||||
}
|
||||
quota = int(float64(quota) * ratio)
|
||||
if ratio != 0 && quota <= 0 {
|
||||
|
||||
@@ -180,7 +180,7 @@ func RelaySwapFace(c *gin.Context) *dto.MidjourneyResponse {
|
||||
Description: "quota_not_enough",
|
||||
}
|
||||
}
|
||||
requestURL := c.Request.URL.String()
|
||||
requestURL := getMjRequestPath(c.Request.URL.String())
|
||||
baseURL := c.GetString("base_url")
|
||||
fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL)
|
||||
mjResp, _, err := service.DoMidjourneyHttpRequest(c, time.Second*60, fullRequestURL)
|
||||
@@ -260,7 +260,7 @@ func RelayMidjourneyTaskImageSeed(c *gin.Context) *dto.MidjourneyResponse {
|
||||
c.Set("channel_id", originTask.ChannelId)
|
||||
c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key))
|
||||
|
||||
requestURL := c.Request.URL.String()
|
||||
requestURL := getMjRequestPath(c.Request.URL.String())
|
||||
fullRequestURL := fmt.Sprintf("%s%s", channel.GetBaseURL(), requestURL)
|
||||
midjResponseWithStatus, _, err := service.DoMidjourneyHttpRequest(c, time.Second*30, fullRequestURL)
|
||||
if err != nil {
|
||||
@@ -440,7 +440,7 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
|
||||
}
|
||||
|
||||
//baseURL := common.ChannelBaseURLs[channelType]
|
||||
requestURL := c.Request.URL.String()
|
||||
requestURL := getMjRequestPath(c.Request.URL.String())
|
||||
|
||||
baseURL := c.GetString("base_url")
|
||||
|
||||
@@ -605,3 +605,15 @@ type taskChangeParams struct {
|
||||
Action string
|
||||
Index int
|
||||
}
|
||||
|
||||
func getMjRequestPath(path string) string {
|
||||
requestURL := path
|
||||
if strings.Contains(requestURL, "/mj-") {
|
||||
urls := strings.Split(requestURL, "/mj/")
|
||||
if len(urls) < 2 {
|
||||
return requestURL
|
||||
}
|
||||
requestURL = "/mj/" + urls[1]
|
||||
}
|
||||
return requestURL
|
||||
}
|
||||
|
||||
@@ -162,24 +162,15 @@ func TextHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
returnPreConsumedQuota(c, relayInfo.TokenId, userQuota, preConsumedQuota)
|
||||
return service.RelayErrorHandler(resp)
|
||||
return service.OpenAIErrorWrapper(fmt.Errorf("bad response status code: %d", resp.StatusCode), "bad_response_status_code", resp.StatusCode)
|
||||
}
|
||||
|
||||
usage, openaiErr, sensitiveResp := adaptor.DoResponse(c, resp, relayInfo)
|
||||
usage, openaiErr := adaptor.DoResponse(c, resp, relayInfo)
|
||||
if openaiErr != nil {
|
||||
if sensitiveResp == nil { // 如果没有敏感词检查结果
|
||||
returnPreConsumedQuota(c, relayInfo.TokenId, userQuota, preConsumedQuota)
|
||||
return openaiErr
|
||||
} else {
|
||||
// 如果有敏感词检查结果,不返回预消耗配额,继续消耗配额
|
||||
postConsumeQuota(c, relayInfo, *textRequest, usage, ratio, preConsumedQuota, userQuota, modelRatio, groupRatio, modelPrice, sensitiveResp)
|
||||
if constant.StopOnSensitiveEnabled { // 是否直接返回错误
|
||||
return openaiErr
|
||||
}
|
||||
return nil
|
||||
}
|
||||
returnPreConsumedQuota(c, relayInfo.TokenId, userQuota, preConsumedQuota)
|
||||
return openaiErr
|
||||
}
|
||||
postConsumeQuota(c, relayInfo, *textRequest, usage, ratio, preConsumedQuota, userQuota, modelRatio, groupRatio, modelPrice, nil)
|
||||
postConsumeQuota(c, relayInfo, *textRequest, usage, ratio, preConsumedQuota, userQuota, modelRatio, groupRatio, modelPrice)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -258,7 +249,7 @@ func returnPreConsumedQuota(c *gin.Context, tokenId int, userQuota int, preConsu
|
||||
|
||||
func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, textRequest dto.GeneralOpenAIRequest,
|
||||
usage *dto.Usage, ratio float64, preConsumedQuota int, userQuota int, modelRatio float64, groupRatio float64,
|
||||
modelPrice float64, sensitiveResp *dto.SensitiveResponse) {
|
||||
modelPrice float64) {
|
||||
|
||||
useTimeSeconds := time.Now().Unix() - relayInfo.StartTime.Unix()
|
||||
promptTokens := usage.PromptTokens
|
||||
@@ -293,9 +284,9 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, textRe
|
||||
logContent += fmt.Sprintf("(可能是上游超时)")
|
||||
common.LogError(ctx, fmt.Sprintf("total tokens is 0, cannot consume quota, userId %d, channelId %d, tokenId %d, model %s, pre-consumed quota %d", relayInfo.UserId, relayInfo.ChannelId, relayInfo.TokenId, textRequest.Model, preConsumedQuota))
|
||||
} else {
|
||||
if sensitiveResp != nil {
|
||||
logContent += fmt.Sprintf(",敏感词:%s", strings.Join(sensitiveResp.SensitiveWords, ", "))
|
||||
}
|
||||
//if sensitiveResp != nil {
|
||||
// logContent += fmt.Sprintf(",敏感词:%s", strings.Join(sensitiveResp.SensitiveWords, ", "))
|
||||
//}
|
||||
quotaDelta := quota - preConsumedQuota
|
||||
err := model.PostConsumeTokenQuota(relayInfo.TokenId, userQuota, quotaDelta, preConsumedQuota, true)
|
||||
if err != nil {
|
||||
|
||||
@@ -43,7 +43,16 @@ func SetRelayRouter(router *gin.Engine) {
|
||||
relayV1Router.DELETE("/models/:model", controller.RelayNotImplemented)
|
||||
relayV1Router.POST("/moderations", controller.Relay)
|
||||
}
|
||||
|
||||
relayMjRouter := router.Group("/mj")
|
||||
registerMjRouterGroup(relayMjRouter)
|
||||
|
||||
relayMjModeRouter := router.Group("/:mode/mj")
|
||||
registerMjRouterGroup(relayMjModeRouter)
|
||||
//relayMjRouter.Use()
|
||||
}
|
||||
|
||||
func registerMjRouterGroup(relayMjRouter *gin.RouterGroup) {
|
||||
relayMjRouter.GET("/image/:id", relay.RelayMidjourneyImage)
|
||||
relayMjRouter.Use(middleware.TokenAuth(), middleware.Distribute())
|
||||
{
|
||||
@@ -61,5 +70,4 @@ func SetRelayRouter(router *gin.Engine) {
|
||||
relayMjRouter.POST("/task/list-by-condition", controller.RelayMidjourney)
|
||||
relayMjRouter.POST("/insight-face/swap", controller.RelayMidjourney)
|
||||
}
|
||||
//relayMjRouter.Use()
|
||||
}
|
||||
|
||||
@@ -29,7 +29,7 @@ func MidjourneyErrorWithStatusCodeWrapper(code int, desc string, statusCode int)
|
||||
func OpenAIErrorWrapper(err error, code string, statusCode int) *dto.OpenAIErrorWithStatusCode {
|
||||
text := err.Error()
|
||||
// 定义一个正则表达式匹配URL
|
||||
if strings.Contains(text, "Post") {
|
||||
if strings.Contains(text, "Post") || strings.Contains(text, "dial") {
|
||||
common.SysLog(fmt.Sprintf("error: %s", text))
|
||||
text = "请求上游地址失败"
|
||||
}
|
||||
|
||||
@@ -208,7 +208,7 @@ func CountTokenInput(input any, model string, check bool) (int, error, bool) {
|
||||
}
|
||||
return CountTokenText(text, model, check)
|
||||
}
|
||||
return 0, errors.New("unsupported input type"), false
|
||||
return CountTokenInput(fmt.Sprintf("%v", input), model, check)
|
||||
}
|
||||
|
||||
func CountAudioToken(text string, model string, check bool) (int, error, bool) {
|
||||
|
||||
@@ -330,21 +330,21 @@ const OperationSetting = () => {
|
||||
name='CheckSensitiveOnPromptEnabled'
|
||||
onChange={handleInputChange}
|
||||
/>
|
||||
<Form.Checkbox
|
||||
checked={inputs.CheckSensitiveOnCompletionEnabled === 'true'}
|
||||
label='启用生成内容检查'
|
||||
name='CheckSensitiveOnCompletionEnabled'
|
||||
onChange={handleInputChange}
|
||||
/>
|
||||
</Form.Group>
|
||||
<Form.Group inline>
|
||||
<Form.Checkbox
|
||||
checked={inputs.StopOnSensitiveEnabled === 'true'}
|
||||
label='在检测到屏蔽词时,立刻停止生成,否则替换屏蔽词'
|
||||
name='StopOnSensitiveEnabled'
|
||||
onChange={handleInputChange}
|
||||
/>
|
||||
{/*<Form.Checkbox*/}
|
||||
{/* checked={inputs.CheckSensitiveOnCompletionEnabled === 'true'}*/}
|
||||
{/* label='启用生成内容检查'*/}
|
||||
{/* name='CheckSensitiveOnCompletionEnabled'*/}
|
||||
{/* onChange={handleInputChange}*/}
|
||||
{/*/>*/}
|
||||
</Form.Group>
|
||||
{/*<Form.Group inline>*/}
|
||||
{/* <Form.Checkbox*/}
|
||||
{/* checked={inputs.StopOnSensitiveEnabled === 'true'}*/}
|
||||
{/* label='在检测到屏蔽词时,立刻停止生成,否则替换屏蔽词'*/}
|
||||
{/* name='StopOnSensitiveEnabled'*/}
|
||||
{/* onChange={handleInputChange}*/}
|
||||
{/* />*/}
|
||||
{/*</Form.Group>*/}
|
||||
{/*<Form.Group>*/}
|
||||
{/* <Form.Input*/}
|
||||
{/* label="流模式下缓存队列,默认不缓存,设置越大检测越准确,但是回复会有卡顿感"*/}
|
||||
|
||||
@@ -45,7 +45,9 @@ const SystemSetting = () => {
|
||||
TurnstileSiteKey: '',
|
||||
TurnstileSecretKey: '',
|
||||
RegisterEnabled: '',
|
||||
UserSelfDeletionEnabled: false,
|
||||
EmailDomainRestrictionEnabled: '',
|
||||
SMTPSSLEnabled: '',
|
||||
EmailDomainWhitelist: [],
|
||||
// telegram login
|
||||
TelegramOAuthEnabled: '',
|
||||
@@ -103,7 +105,9 @@ const SystemSetting = () => {
|
||||
case 'TelegramOAuthEnabled':
|
||||
case 'TurnstileCheckEnabled':
|
||||
case 'EmailDomainRestrictionEnabled':
|
||||
case 'SMTPSSLEnabled':
|
||||
case 'RegisterEnabled':
|
||||
case 'UserSelfDeletionEnabled':
|
||||
case 'PaymentEnabled':
|
||||
value = inputs[key] === 'true' ? 'false' : 'true';
|
||||
break;
|
||||
@@ -137,7 +141,7 @@ const SystemSetting = () => {
|
||||
}
|
||||
if (
|
||||
name === 'Notice' ||
|
||||
name.startsWith('SMTP') ||
|
||||
(name.startsWith('SMTP') && name !== 'SMTPSSLEnabled') ||
|
||||
name === 'ServerAddress' ||
|
||||
name === 'StripeApiSecret' ||
|
||||
name === 'StripeWebhookSecret' ||
|
||||
@@ -536,6 +540,12 @@ const SystemSetting = () => {
|
||||
name='TurnstileCheckEnabled'
|
||||
onChange={handleInputChange}
|
||||
/>
|
||||
<Form.Checkbox
|
||||
checked={inputs.UserSelfDeletionEnabled === 'true'}
|
||||
label='允许用户自行删除账户'
|
||||
name='UserSelfDeletionEnabled'
|
||||
onChange={handleInputChange}
|
||||
/>
|
||||
</Form.Group>
|
||||
<Divider />
|
||||
<Header as='h3'>
|
||||
@@ -644,6 +654,14 @@ const SystemSetting = () => {
|
||||
placeholder='敏感信息不会发送到前端显示'
|
||||
/>
|
||||
</Form.Group>
|
||||
<Form.Group widths={3}>
|
||||
<Form.Checkbox
|
||||
label='启用SMTP SSL(465端口强制开启)'
|
||||
name='SMTPSSLEnabled'
|
||||
onChange={handleInputChange}
|
||||
checked={inputs.SMTPSSLEnabled === 'true'}
|
||||
/>
|
||||
</Form.Group>
|
||||
<Form.Button onClick={submitSMTP}>保存 SMTP 设置</Form.Button>
|
||||
<Divider />
|
||||
<Header as='h3'>
|
||||
|
||||
@@ -426,8 +426,11 @@ const TokensTable = () => {
|
||||
if (await copy(text)) {
|
||||
showSuccess('已复制到剪贴板!');
|
||||
} else {
|
||||
// setSearchKeyword(text);
|
||||
Modal.error({ title: '无法复制到剪贴板,请手动复制', content: text });
|
||||
Modal.error({
|
||||
title: '无法复制到剪贴板,请手动复制',
|
||||
content: text,
|
||||
size: 'large',
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
Reference in New Issue
Block a user