mirror of
https://github.com/linux-do/new-api.git
synced 2025-11-17 19:13:42 +08:00
Compare commits
30 Commits
v0.2.7.0-a
...
v0.2.7.2-a
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
220ab412e2 | ||
|
|
7029065892 | ||
|
|
0f687aab9a | ||
|
|
5e936b3923 | ||
|
|
d55cb35c1c | ||
|
|
5be4cbcaaf | ||
|
|
e67aa370bc | ||
|
|
7b36a2b885 | ||
|
|
c88f3741e6 | ||
|
|
4e7e206290 | ||
|
|
579fc8129e | ||
|
|
f55f63f412 | ||
|
|
0526c85732 | ||
|
|
b75134ece4 | ||
|
|
a075598757 | ||
|
|
a984daa503 | ||
|
|
90abe7f27d | ||
|
|
bb313eb26f | ||
|
|
02545e4856 | ||
|
|
49cec50908 | ||
|
|
4f6710e50c | ||
|
|
03b130f2b5 | ||
|
|
45b9de9df9 | ||
|
|
e062cf32e3 | ||
|
|
52debe7572 | ||
|
|
df6502733c | ||
|
|
9896ba0a64 | ||
|
|
e8b93ed6ec | ||
|
|
b0e234e8f5 | ||
|
|
20d71711d3 |
1
.github/workflows/docker-image-arm64.yml
vendored
1
.github/workflows/docker-image-arm64.yml
vendored
@@ -4,6 +4,7 @@ on:
|
||||
push:
|
||||
tags:
|
||||
- '*'
|
||||
- '!*-alpha*'
|
||||
workflow_dispatch:
|
||||
inputs:
|
||||
name:
|
||||
|
||||
@@ -72,7 +72,8 @@
|
||||
|
||||
## 比原版One API多出的配置
|
||||
- `STREAMING_TIMEOUT`:设置流式一次回复的超时时间,默认为 30 秒
|
||||
|
||||
- `DIFY_DEBUG`:设置 Dify 渠道是否输出工作流和节点信息到客户端,默认为 `true`, 可选值为 `true` 和 `false`
|
||||
- `FORCE_STREAM_OPTION`:覆盖客户端stream_options参数,请求上游返回流模式usage,目前仅支持 `OpenAI` 渠道类型
|
||||
## 部署
|
||||
### 部署要求
|
||||
- 本地数据库(默认):SQLite(Docker 部署默认使用 SQLite,必须挂载 `/data` 目录到宿主机)
|
||||
|
||||
@@ -212,36 +212,37 @@ const (
|
||||
ChannelTypeSunoAPI = 36
|
||||
ChannelTypeDify = 37
|
||||
ChannelTypeJina = 38
|
||||
ChannelCloudflare = 39
|
||||
|
||||
ChannelTypeDummy // this one is only for count, do not add any channel after this
|
||||
|
||||
)
|
||||
|
||||
var ChannelBaseURLs = []string{
|
||||
"", // 0
|
||||
"https://api.openai.com", // 1
|
||||
"https://oa.api2d.net", // 2
|
||||
"", // 3
|
||||
"http://localhost:11434", // 4
|
||||
"https://api.openai-sb.com", // 5
|
||||
"https://api.openaimax.com", // 6
|
||||
"https://api.ohmygpt.com", // 7
|
||||
"", // 8
|
||||
"https://api.caipacity.com", // 9
|
||||
"https://api.aiproxy.io", // 10
|
||||
"", // 11
|
||||
"https://api.api2gpt.com", // 12
|
||||
"https://api.aigc2d.com", // 13
|
||||
"https://api.anthropic.com", // 14
|
||||
"https://aip.baidubce.com", // 15
|
||||
"https://open.bigmodel.cn", // 16
|
||||
"https://dashscope.aliyuncs.com", // 17
|
||||
"", // 18
|
||||
"https://ai.360.cn", // 19
|
||||
"https://openrouter.ai/api", // 20
|
||||
"https://api.aiproxy.io", // 21
|
||||
"https://fastgpt.run/api/openapi", // 22
|
||||
"https://hunyuan.cloud.tencent.com", //23
|
||||
"", // 0
|
||||
"https://api.openai.com", // 1
|
||||
"https://oa.api2d.net", // 2
|
||||
"", // 3
|
||||
"http://localhost:11434", // 4
|
||||
"https://api.openai-sb.com", // 5
|
||||
"https://api.openaimax.com", // 6
|
||||
"https://api.ohmygpt.com", // 7
|
||||
"", // 8
|
||||
"https://api.caipacity.com", // 9
|
||||
"https://api.aiproxy.io", // 10
|
||||
"", // 11
|
||||
"https://api.api2gpt.com", // 12
|
||||
"https://api.aigc2d.com", // 13
|
||||
"https://api.anthropic.com", // 14
|
||||
"https://aip.baidubce.com", // 15
|
||||
"https://open.bigmodel.cn", // 16
|
||||
"https://dashscope.aliyuncs.com", // 17
|
||||
"", // 18
|
||||
"https://ai.360.cn", // 19
|
||||
"https://openrouter.ai/api", // 20
|
||||
"https://api.aiproxy.io", // 21
|
||||
"https://fastgpt.run/api/openapi", // 22
|
||||
"https://hunyuan.tencentcloudapi.com", //23
|
||||
"https://generativelanguage.googleapis.com", //24
|
||||
"https://api.moonshot.cn", //25
|
||||
"https://open.bigmodel.cn", //26
|
||||
@@ -257,4 +258,5 @@ var ChannelBaseURLs = []string{
|
||||
"", //36
|
||||
"", //37
|
||||
"https://api.jina.ai", //38
|
||||
"https://api.cloudflare.com", //39
|
||||
}
|
||||
|
||||
@@ -24,3 +24,15 @@ func GetEnvOrDefaultString(env string, defaultValue string) string {
|
||||
}
|
||||
return os.Getenv(env)
|
||||
}
|
||||
|
||||
func GetEnvOrDefaultBool(env string, defaultValue bool) bool {
|
||||
if env == "" || os.Getenv(env) == "" {
|
||||
return defaultValue
|
||||
}
|
||||
b, err := strconv.ParseBool(os.Getenv(env))
|
||||
if err != nil {
|
||||
SysError(fmt.Sprintf("failed to parse %s: %s, using default value: %t", env, err.Error(), defaultValue))
|
||||
return defaultValue
|
||||
}
|
||||
return b
|
||||
}
|
||||
|
||||
@@ -5,3 +5,7 @@ import (
|
||||
)
|
||||
|
||||
var StreamingTimeout = common.GetEnvOrDefault("STREAMING_TIMEOUT", 30)
|
||||
var DifyDebug = common.GetEnvOrDefaultBool("DIFY_DEBUG", true)
|
||||
|
||||
// ForceStreamOption 覆盖请求参数,强制返回usage信息
|
||||
var ForceStreamOption = common.GetEnvOrDefaultBool("FORCE_STREAM_OPTION", true)
|
||||
|
||||
@@ -4,6 +4,7 @@ var MjNotifyEnabled = false
|
||||
var MjAccountFilterEnabled = false
|
||||
var MjModeClearEnabled = false
|
||||
var MjForwardUrlEnabled = true
|
||||
var MjActionCheckSuccessEnabled = true
|
||||
|
||||
const (
|
||||
MjErrorUnknown = 5
|
||||
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
"net/url"
|
||||
"one-api/common"
|
||||
"one-api/dto"
|
||||
"one-api/middleware"
|
||||
"one-api/model"
|
||||
"one-api/relay"
|
||||
relaycommon "one-api/relay/common"
|
||||
@@ -24,7 +25,7 @@ import (
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func testChannel(channel *model.Channel, testModel string) (err error, openaiErr *dto.OpenAIError) {
|
||||
func testChannel(channel *model.Channel, testModel string) (err error, openAIErrorWithStatusCode *dto.OpenAIErrorWithStatusCode) {
|
||||
tik := time.Now()
|
||||
if channel.Type == common.ChannelTypeMidjourney {
|
||||
return errors.New("midjourney channel test is not supported"), nil
|
||||
@@ -40,35 +41,13 @@ func testChannel(channel *model.Channel, testModel string) (err error, openaiErr
|
||||
Body: nil,
|
||||
Header: make(http.Header),
|
||||
}
|
||||
c.Request.Header.Set("Authorization", "Bearer "+channel.Key)
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
c.Set("channel", channel.Type)
|
||||
c.Set("base_url", channel.GetBaseURL())
|
||||
switch channel.Type {
|
||||
case common.ChannelTypeAzure:
|
||||
c.Set("api_version", channel.Other)
|
||||
case common.ChannelTypeXunfei:
|
||||
c.Set("api_version", channel.Other)
|
||||
//case common.ChannelTypeAIProxyLibrary:
|
||||
// c.Set("library_id", channel.Other)
|
||||
case common.ChannelTypeGemini:
|
||||
c.Set("api_version", channel.Other)
|
||||
case common.ChannelTypeAli:
|
||||
c.Set("plugin", channel.Other)
|
||||
}
|
||||
|
||||
meta := relaycommon.GenRelayInfo(c)
|
||||
apiType, _ := constant.ChannelType2APIType(channel.Type)
|
||||
adaptor := relay.GetAdaptor(apiType)
|
||||
if adaptor == nil {
|
||||
return fmt.Errorf("invalid api type: %d, adaptor is nil", apiType), nil
|
||||
}
|
||||
if testModel == "" {
|
||||
if channel.TestModel != nil && *channel.TestModel != "" {
|
||||
testModel = *channel.TestModel
|
||||
} else {
|
||||
if len(adaptor.GetModelList()) > 0 {
|
||||
testModel = adaptor.GetModelList()[0]
|
||||
if len(channel.GetModels()) > 0 {
|
||||
testModel = channel.GetModels()[0]
|
||||
} else {
|
||||
testModel = "gpt-3.5-turbo"
|
||||
}
|
||||
@@ -79,8 +58,7 @@ func testChannel(channel *model.Channel, testModel string) (err error, openaiErr
|
||||
modelMap := make(map[string]string)
|
||||
err := json.Unmarshal([]byte(modelMapping), &modelMap)
|
||||
if err != nil {
|
||||
openaiErr := service.OpenAIErrorWrapperLocal(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError).Error
|
||||
return err, &openaiErr
|
||||
return err, service.OpenAIErrorWrapperLocal(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError)
|
||||
}
|
||||
if modelMap[testModel] != "" {
|
||||
testModel = modelMap[testModel]
|
||||
@@ -88,6 +66,20 @@ func testChannel(channel *model.Channel, testModel string) (err error, openaiErr
|
||||
}
|
||||
}
|
||||
|
||||
c.Request.Header.Set("Authorization", "Bearer "+channel.Key)
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
c.Set("channel", channel.Type)
|
||||
c.Set("base_url", channel.GetBaseURL())
|
||||
|
||||
middleware.SetupContextForSelectedChannel(c, channel, testModel)
|
||||
|
||||
meta := relaycommon.GenRelayInfo(c)
|
||||
apiType, _ := constant.ChannelType2APIType(channel.Type)
|
||||
adaptor := relay.GetAdaptor(apiType)
|
||||
if adaptor == nil {
|
||||
return fmt.Errorf("invalid api type: %d, adaptor is nil", apiType), nil
|
||||
}
|
||||
|
||||
request := buildTestRequest()
|
||||
request.Model = testModel
|
||||
meta.UpstreamModelName = testModel
|
||||
@@ -95,7 +87,7 @@ func testChannel(channel *model.Channel, testModel string) (err error, openaiErr
|
||||
|
||||
adaptor.Init(meta, *request)
|
||||
|
||||
convertedRequest, err := adaptor.ConvertRequest(c, constant.RelayModeChatCompletions, request)
|
||||
convertedRequest, err := adaptor.ConvertRequest(c, meta, request)
|
||||
if err != nil {
|
||||
return err, nil
|
||||
}
|
||||
@@ -111,17 +103,16 @@ func testChannel(channel *model.Channel, testModel string) (err error, openaiErr
|
||||
}
|
||||
if resp != nil && resp.StatusCode != http.StatusOK {
|
||||
err := relaycommon.RelayErrorHandler(resp)
|
||||
return fmt.Errorf("status code %d: %s", resp.StatusCode, err.Error.Message), &err.Error
|
||||
return fmt.Errorf("status code %d: %s", resp.StatusCode, err.Error.Message), err
|
||||
}
|
||||
usage, respErr := adaptor.DoResponse(c, resp, meta)
|
||||
if respErr != nil {
|
||||
return fmt.Errorf("%s", respErr.Error.Message), &respErr.Error
|
||||
return fmt.Errorf("%s", respErr.Error.Message), respErr
|
||||
}
|
||||
if usage == nil {
|
||||
return errors.New("usage is nil"), nil
|
||||
}
|
||||
result := w.Result()
|
||||
// print result.Body
|
||||
respBody, err := io.ReadAll(result.Body)
|
||||
if err != nil {
|
||||
return err, nil
|
||||
@@ -230,7 +221,7 @@ func testAllChannels(notify bool) error {
|
||||
for _, channel := range channels {
|
||||
isChannelEnabled := channel.Status == common.ChannelStatusEnabled
|
||||
tik := time.Now()
|
||||
err, openaiErr := testChannel(channel, "")
|
||||
err, openaiWithStatusErr := testChannel(channel, "")
|
||||
tok := time.Now()
|
||||
milliseconds := tok.Sub(tik).Milliseconds()
|
||||
|
||||
@@ -239,27 +230,29 @@ func testAllChannels(notify bool) error {
|
||||
err = errors.New(fmt.Sprintf("响应时间 %.2fs 超过阈值 %.2fs", float64(milliseconds)/1000.0, float64(disableThreshold)/1000.0))
|
||||
ban = true
|
||||
}
|
||||
if openaiErr != nil {
|
||||
err = errors.New(fmt.Sprintf("type %s, code %v, message %s", openaiErr.Type, openaiErr.Code, openaiErr.Message))
|
||||
ban = true
|
||||
|
||||
// request error disables the channel
|
||||
if openaiWithStatusErr != nil {
|
||||
oaiErr := openaiWithStatusErr.Error
|
||||
err = errors.New(fmt.Sprintf("type %s, httpCode %d, code %v, message %s", oaiErr.Type, openaiWithStatusErr.StatusCode, oaiErr.Code, oaiErr.Message))
|
||||
ban = service.ShouldDisableChannel(channel.Type, openaiWithStatusErr)
|
||||
}
|
||||
|
||||
// parse *int to bool
|
||||
if channel.AutoBan != nil && *channel.AutoBan == 0 {
|
||||
ban = false
|
||||
}
|
||||
if openaiErr != nil {
|
||||
openAiErrWithStatus := dto.OpenAIErrorWithStatusCode{
|
||||
StatusCode: -1,
|
||||
Error: *openaiErr,
|
||||
LocalError: false,
|
||||
}
|
||||
if isChannelEnabled && service.ShouldDisableChannel(channel.Type, &openAiErrWithStatus) && ban {
|
||||
service.DisableChannel(channel.Id, channel.Name, err.Error())
|
||||
}
|
||||
if !isChannelEnabled && service.ShouldEnableChannel(err, openaiErr, channel.Status) {
|
||||
service.EnableChannel(channel.Id, channel.Name)
|
||||
}
|
||||
|
||||
// disable channel
|
||||
if ban && isChannelEnabled {
|
||||
service.DisableChannel(channel.Id, channel.Name, err.Error())
|
||||
}
|
||||
|
||||
// enable channel
|
||||
if !isChannelEnabled && service.ShouldEnableChannel(err, openaiWithStatusErr, channel.Status) {
|
||||
service.EnableChannel(channel.Id, channel.Name)
|
||||
}
|
||||
|
||||
channel.UpdateResponseTime(milliseconds)
|
||||
time.Sleep(common.RequestInterval)
|
||||
}
|
||||
|
||||
@@ -11,6 +11,7 @@ type GeneralOpenAIRequest struct {
|
||||
Messages []Message `json:"messages,omitempty"`
|
||||
Prompt any `json:"prompt,omitempty"`
|
||||
Stream bool `json:"stream,omitempty"`
|
||||
StreamOptions *StreamOptions `json:"stream_options,omitempty"`
|
||||
MaxTokens uint `json:"max_tokens,omitempty"`
|
||||
Temperature float64 `json:"temperature,omitempty"`
|
||||
TopP float64 `json:"top_p,omitempty"`
|
||||
@@ -43,8 +44,12 @@ type OpenAIFunction struct {
|
||||
Parameters any `json:"parameters,omitempty"`
|
||||
}
|
||||
|
||||
func (r GeneralOpenAIRequest) GetMaxTokens() int64 {
|
||||
return int64(r.MaxTokens)
|
||||
type StreamOptions struct {
|
||||
IncludeUsage bool `json:"include_usage,omitempty"`
|
||||
}
|
||||
|
||||
func (r GeneralOpenAIRequest) GetMaxTokens() int {
|
||||
return int(r.MaxTokens)
|
||||
}
|
||||
|
||||
func (r GeneralOpenAIRequest) ParseInput() []string {
|
||||
|
||||
@@ -66,10 +66,6 @@ type ChatCompletionsStreamResponseChoiceDelta struct {
|
||||
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
|
||||
}
|
||||
|
||||
func (c *ChatCompletionsStreamResponseChoiceDelta) IsEmpty() bool {
|
||||
return c.Content == nil && len(c.ToolCalls) == 0
|
||||
}
|
||||
|
||||
func (c *ChatCompletionsStreamResponseChoiceDelta) SetContentString(s string) {
|
||||
c.Content = &s
|
||||
}
|
||||
@@ -102,10 +98,23 @@ type ChatCompletionsStreamResponse struct {
|
||||
Model string `json:"model"`
|
||||
SystemFingerprint *string `json:"system_fingerprint"`
|
||||
Choices []ChatCompletionsStreamResponseChoice `json:"choices"`
|
||||
Usage *Usage `json:"usage"`
|
||||
}
|
||||
|
||||
func (c *ChatCompletionsStreamResponse) GetSystemFingerprint() string {
|
||||
if c.SystemFingerprint == nil {
|
||||
return ""
|
||||
}
|
||||
return *c.SystemFingerprint
|
||||
}
|
||||
|
||||
func (c *ChatCompletionsStreamResponse) SetSystemFingerprint(s string) {
|
||||
c.SystemFingerprint = &s
|
||||
}
|
||||
|
||||
type ChatCompletionsStreamResponseSimple struct {
|
||||
Choices []ChatCompletionsStreamResponseChoice `json:"choices"`
|
||||
Usage *Usage `json:"usage"`
|
||||
}
|
||||
|
||||
type CompletionsStreamResponse struct {
|
||||
|
||||
@@ -198,11 +198,11 @@ func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, mode
|
||||
c.Set("api_version", channel.Other)
|
||||
case common.ChannelTypeXunfei:
|
||||
c.Set("api_version", channel.Other)
|
||||
//case common.ChannelTypeAIProxyLibrary:
|
||||
// c.Set("library_id", channel.Other)
|
||||
case common.ChannelTypeGemini:
|
||||
c.Set("api_version", channel.Other)
|
||||
case common.ChannelTypeAli:
|
||||
c.Set("plugin", channel.Other)
|
||||
case common.ChannelCloudflare:
|
||||
c.Set("api_version", channel.Other)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"encoding/json"
|
||||
"gorm.io/gorm"
|
||||
"one-api/common"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type Channel struct {
|
||||
@@ -33,6 +34,13 @@ type Channel struct {
|
||||
OtherInfo string `json:"other_info"`
|
||||
}
|
||||
|
||||
func (channel *Channel) GetModels() []string {
|
||||
if channel.Models == "" {
|
||||
return []string{}
|
||||
}
|
||||
return strings.Split(strings.Trim(channel.Models, ","), ",")
|
||||
}
|
||||
|
||||
func (channel *Channel) GetOtherInfo() map[string]interface{} {
|
||||
otherInfo := make(map[string]interface{})
|
||||
if channel.OtherInfo != "" {
|
||||
|
||||
@@ -99,6 +99,7 @@ func InitOptionMap() {
|
||||
common.OptionMap["MjAccountFilterEnabled"] = strconv.FormatBool(constant.MjAccountFilterEnabled)
|
||||
common.OptionMap["MjModeClearEnabled"] = strconv.FormatBool(constant.MjModeClearEnabled)
|
||||
common.OptionMap["MjForwardUrlEnabled"] = strconv.FormatBool(constant.MjForwardUrlEnabled)
|
||||
common.OptionMap["MjActionCheckSuccessEnabled"] = strconv.FormatBool(constant.MjActionCheckSuccessEnabled)
|
||||
common.OptionMap["CheckSensitiveEnabled"] = strconv.FormatBool(constant.CheckSensitiveEnabled)
|
||||
common.OptionMap["CheckSensitiveOnPromptEnabled"] = strconv.FormatBool(constant.CheckSensitiveOnPromptEnabled)
|
||||
//common.OptionMap["CheckSensitiveOnCompletionEnabled"] = strconv.FormatBool(constant.CheckSensitiveOnCompletionEnabled)
|
||||
@@ -210,6 +211,8 @@ func updateOptionMap(key string, value string) (err error) {
|
||||
constant.MjModeClearEnabled = boolValue
|
||||
case "MjForwardUrlEnabled":
|
||||
constant.MjForwardUrlEnabled = boolValue
|
||||
case "MjActionCheckSuccessEnabled":
|
||||
constant.MjActionCheckSuccessEnabled = boolValue
|
||||
case "CheckSensitiveEnabled":
|
||||
constant.CheckSensitiveEnabled = boolValue
|
||||
case "CheckSensitiveOnPromptEnabled":
|
||||
|
||||
@@ -298,7 +298,8 @@ func (user *User) ValidateAndFill() (err error) {
|
||||
if user.Username == "" || password == "" {
|
||||
return errors.New("用户名或密码为空")
|
||||
}
|
||||
DB.Where(User{Username: user.Username}).First(user)
|
||||
// find buy username or email
|
||||
DB.Where("username = ? OR email = ?", user.Username, user.Username).First(user)
|
||||
okay := common.ValidatePasswordAndHash(password, user.Password)
|
||||
if !okay || user.Status != common.UserStatusEnabled {
|
||||
return errors.New("用户名或密码错误,或用户已被封禁")
|
||||
|
||||
@@ -14,7 +14,7 @@ type Adaptor interface {
|
||||
InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest)
|
||||
GetRequestURL(info *relaycommon.RelayInfo) (string, error)
|
||||
SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error
|
||||
ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error)
|
||||
ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error)
|
||||
ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error)
|
||||
DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error)
|
||||
DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode)
|
||||
|
||||
@@ -42,11 +42,11 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *re
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) {
|
||||
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
|
||||
if request == nil {
|
||||
return nil, errors.New("request is nil")
|
||||
}
|
||||
switch relayMode {
|
||||
switch info.RelayMode {
|
||||
case constant.RelayModeEmbeddings:
|
||||
baiduEmbeddingRequest := embeddingRequestOpenAI2Ali(*request)
|
||||
return baiduEmbeddingRequest, nil
|
||||
|
||||
@@ -41,7 +41,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *re
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) {
|
||||
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
|
||||
if request == nil {
|
||||
return nil, errors.New("request is nil")
|
||||
}
|
||||
@@ -68,7 +68,7 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
|
||||
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
|
||||
if info.IsStream {
|
||||
err, usage = awsStreamHandler(c, info, a.RequestMode)
|
||||
err, usage = awsStreamHandler(c, resp, info, a.RequestMode)
|
||||
} else {
|
||||
err, usage = awsHandler(c, info, a.RequestMode)
|
||||
}
|
||||
|
||||
@@ -13,6 +13,7 @@ import (
|
||||
relaymodel "one-api/dto"
|
||||
"one-api/relay/channel/claude"
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/service"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
@@ -112,7 +113,7 @@ func awsHandler(c *gin.Context, info *relaycommon.RelayInfo, requestMode int) (*
|
||||
return nil, &usage
|
||||
}
|
||||
|
||||
func awsStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, requestMode int) (*relaymodel.OpenAIErrorWithStatusCode, *relaymodel.Usage) {
|
||||
func awsStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, requestMode int) (*relaymodel.OpenAIErrorWithStatusCode, *relaymodel.Usage) {
|
||||
awsCli, err := newAwsClient(c, info)
|
||||
if err != nil {
|
||||
return wrapErr(errors.Wrap(err, "newAwsClient")), nil
|
||||
@@ -162,7 +163,6 @@ func awsStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, requestMode i
|
||||
c.Stream(func(w io.Writer) bool {
|
||||
event, ok := <-stream.Events()
|
||||
if !ok {
|
||||
c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -214,6 +214,17 @@ func awsStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, requestMode i
|
||||
return false
|
||||
}
|
||||
})
|
||||
|
||||
if info.ShouldIncludeUsage {
|
||||
response := service.GenerateFinalUsageResponse(id, createdTime, info.UpstreamModelName, usage)
|
||||
err := service.ObjectData(c, response)
|
||||
if err != nil {
|
||||
common.SysError("send final response failed: " + err.Error())
|
||||
}
|
||||
}
|
||||
service.Done(c)
|
||||
err = resp.Body.Close()
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapperLocal(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
return nil, &usage
|
||||
}
|
||||
|
||||
@@ -99,11 +99,11 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *re
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) {
|
||||
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
|
||||
if request == nil {
|
||||
return nil, errors.New("request is nil")
|
||||
}
|
||||
switch relayMode {
|
||||
switch info.RelayMode {
|
||||
case constant.RelayModeEmbeddings:
|
||||
baiduEmbeddingRequest := embeddingRequestOpenAI2Baidu(*request)
|
||||
return baiduEmbeddingRequest, nil
|
||||
|
||||
@@ -19,7 +19,7 @@ type BaiduChatRequest struct {
|
||||
System string `json:"system,omitempty"`
|
||||
DisableSearch bool `json:"disable_search,omitempty"`
|
||||
EnableCitation bool `json:"enable_citation,omitempty"`
|
||||
MaxOutputTokens int `json:"max_output_tokens,omitempty"`
|
||||
MaxOutputTokens *int `json:"max_output_tokens,omitempty"`
|
||||
UserId string `json:"user_id,omitempty"`
|
||||
}
|
||||
|
||||
|
||||
@@ -23,14 +23,20 @@ var baiduTokenStore sync.Map
|
||||
|
||||
func requestOpenAI2Baidu(request dto.GeneralOpenAIRequest) *BaiduChatRequest {
|
||||
baiduRequest := BaiduChatRequest{
|
||||
Temperature: request.Temperature,
|
||||
TopP: request.TopP,
|
||||
PenaltyScore: request.FrequencyPenalty,
|
||||
Stream: request.Stream,
|
||||
DisableSearch: false,
|
||||
EnableCitation: false,
|
||||
MaxOutputTokens: int(request.MaxTokens),
|
||||
UserId: request.User,
|
||||
Temperature: request.Temperature,
|
||||
TopP: request.TopP,
|
||||
PenaltyScore: request.FrequencyPenalty,
|
||||
Stream: request.Stream,
|
||||
DisableSearch: false,
|
||||
EnableCitation: false,
|
||||
UserId: request.User,
|
||||
}
|
||||
if request.MaxTokens != 0 {
|
||||
maxTokens := int(request.MaxTokens)
|
||||
if request.MaxTokens == 1 {
|
||||
maxTokens = 2
|
||||
}
|
||||
baiduRequest.MaxOutputTokens = &maxTokens
|
||||
}
|
||||
for _, message := range request.Messages {
|
||||
if message.Role == "system" {
|
||||
|
||||
@@ -53,7 +53,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *re
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) {
|
||||
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
|
||||
if request == nil {
|
||||
return nil, errors.New("request is nil")
|
||||
}
|
||||
|
||||
@@ -72,6 +72,19 @@ func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*ClaudeR
|
||||
if claudeRequest.MaxTokens == 0 {
|
||||
claudeRequest.MaxTokens = 4096
|
||||
}
|
||||
if textRequest.Stop != nil {
|
||||
// stop maybe string/array string, convert to array string
|
||||
switch textRequest.Stop.(type) {
|
||||
case string:
|
||||
claudeRequest.StopSequences = []string{textRequest.Stop.(string)}
|
||||
case []interface{}:
|
||||
stopSequences := make([]string, 0)
|
||||
for _, stop := range textRequest.Stop.([]interface{}) {
|
||||
stopSequences = append(stopSequences, stop.(string))
|
||||
}
|
||||
claudeRequest.StopSequences = stopSequences
|
||||
}
|
||||
}
|
||||
formatMessages := make([]dto.Message, 0)
|
||||
var lastMessage *dto.Message
|
||||
for i, message := range textRequest.Messages {
|
||||
@@ -330,22 +343,15 @@ func claudeStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.
|
||||
response.Created = createdTime
|
||||
response.Model = info.UpstreamModelName
|
||||
|
||||
jsonStr, err := json.Marshal(response)
|
||||
err = service.ObjectData(c, response)
|
||||
if err != nil {
|
||||
common.SysError("error marshalling stream response: " + err.Error())
|
||||
return true
|
||||
common.SysError(err.Error())
|
||||
}
|
||||
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonStr)})
|
||||
return true
|
||||
case <-stopChan:
|
||||
c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
|
||||
return false
|
||||
}
|
||||
})
|
||||
err := resp.Body.Close()
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
if requestMode == RequestModeCompletion {
|
||||
usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
|
||||
} else {
|
||||
@@ -356,6 +362,18 @@ func claudeStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.
|
||||
usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, usage.PromptTokens)
|
||||
}
|
||||
}
|
||||
if info.ShouldIncludeUsage {
|
||||
response := service.GenerateFinalUsageResponse(responseId, createdTime, info.UpstreamModelName, *usage)
|
||||
err := service.ObjectData(c, response)
|
||||
if err != nil {
|
||||
common.SysError("send final response failed: " + err.Error())
|
||||
}
|
||||
}
|
||||
service.Done(c)
|
||||
err := resp.Body.Close()
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapperLocal(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
return nil, usage
|
||||
}
|
||||
|
||||
|
||||
76
relay/channel/cloudflare/adaptor.go
Normal file
76
relay/channel/cloudflare/adaptor.go
Normal file
@@ -0,0 +1,76 @@
|
||||
package cloudflare
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/dto"
|
||||
"one-api/relay/channel"
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/relay/constant"
|
||||
)
|
||||
|
||||
type Adaptor struct {
|
||||
}
|
||||
|
||||
func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) {
|
||||
}
|
||||
|
||||
func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
|
||||
}
|
||||
|
||||
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
switch info.RelayMode {
|
||||
case constant.RelayModeChatCompletions:
|
||||
return fmt.Sprintf("%s/client/v4/accounts/%s/ai/v1/chat/completions", info.BaseUrl, info.ApiVersion), nil
|
||||
case constant.RelayModeEmbeddings:
|
||||
return fmt.Sprintf("%s/client/v4/accounts/%s/ai/v1/embeddings", info.BaseUrl, info.ApiVersion), nil
|
||||
default:
|
||||
return fmt.Sprintf("%s/client/v4/accounts/%s/ai/run/%s", info.BaseUrl, info.ApiVersion, info.UpstreamModelName), nil
|
||||
}
|
||||
}
|
||||
|
||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
|
||||
channel.SetupApiRequestHeader(info, c, req)
|
||||
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", info.ApiKey))
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
|
||||
if request == nil {
|
||||
return nil, errors.New("request is nil")
|
||||
}
|
||||
switch info.RelayMode {
|
||||
case constant.RelayModeCompletions:
|
||||
return convertCf2CompletionsRequest(*request), nil
|
||||
default:
|
||||
return request, nil
|
||||
}
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
|
||||
return channel.DoApiRequest(a, c, info, requestBody)
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
|
||||
return request, nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
|
||||
if info.IsStream {
|
||||
err, usage = cfStreamHandler(c, resp, info)
|
||||
} else {
|
||||
err, usage = cfHandler(c, resp, info)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (a *Adaptor) GetModelList() []string {
|
||||
return ModelList
|
||||
}
|
||||
|
||||
func (a *Adaptor) GetChannelName() string {
|
||||
return ChannelName
|
||||
}
|
||||
38
relay/channel/cloudflare/constant.go
Normal file
38
relay/channel/cloudflare/constant.go
Normal file
@@ -0,0 +1,38 @@
|
||||
package cloudflare
|
||||
|
||||
var ModelList = []string{
|
||||
"@cf/meta/llama-2-7b-chat-fp16",
|
||||
"@cf/meta/llama-2-7b-chat-int8",
|
||||
"@cf/mistral/mistral-7b-instruct-v0.1",
|
||||
"@hf/thebloke/deepseek-coder-6.7b-base-awq",
|
||||
"@hf/thebloke/deepseek-coder-6.7b-instruct-awq",
|
||||
"@cf/deepseek-ai/deepseek-math-7b-base",
|
||||
"@cf/deepseek-ai/deepseek-math-7b-instruct",
|
||||
"@cf/thebloke/discolm-german-7b-v1-awq",
|
||||
"@cf/tiiuae/falcon-7b-instruct",
|
||||
"@cf/google/gemma-2b-it-lora",
|
||||
"@hf/google/gemma-7b-it",
|
||||
"@cf/google/gemma-7b-it-lora",
|
||||
"@hf/nousresearch/hermes-2-pro-mistral-7b",
|
||||
"@hf/thebloke/llama-2-13b-chat-awq",
|
||||
"@cf/meta-llama/llama-2-7b-chat-hf-lora",
|
||||
"@cf/meta/llama-3-8b-instruct",
|
||||
"@hf/thebloke/llamaguard-7b-awq",
|
||||
"@hf/thebloke/mistral-7b-instruct-v0.1-awq",
|
||||
"@hf/mistralai/mistral-7b-instruct-v0.2",
|
||||
"@cf/mistral/mistral-7b-instruct-v0.2-lora",
|
||||
"@hf/thebloke/neural-chat-7b-v3-1-awq",
|
||||
"@cf/openchat/openchat-3.5-0106",
|
||||
"@hf/thebloke/openhermes-2.5-mistral-7b-awq",
|
||||
"@cf/microsoft/phi-2",
|
||||
"@cf/qwen/qwen1.5-0.5b-chat",
|
||||
"@cf/qwen/qwen1.5-1.8b-chat",
|
||||
"@cf/qwen/qwen1.5-14b-chat-awq",
|
||||
"@cf/qwen/qwen1.5-7b-chat-awq",
|
||||
"@cf/defog/sqlcoder-7b-2",
|
||||
"@hf/nexusflow/starling-lm-7b-beta",
|
||||
"@cf/tinyllama/tinyllama-1.1b-chat-v1.0",
|
||||
"@hf/thebloke/zephyr-7b-beta-awq",
|
||||
}
|
||||
|
||||
var ChannelName = "cloudflare"
|
||||
13
relay/channel/cloudflare/model.go
Normal file
13
relay/channel/cloudflare/model.go
Normal file
@@ -0,0 +1,13 @@
|
||||
package cloudflare
|
||||
|
||||
import "one-api/dto"
|
||||
|
||||
type CfRequest struct {
|
||||
Messages []dto.Message `json:"messages,omitempty"`
|
||||
Lora string `json:"lora,omitempty"`
|
||||
MaxTokens int `json:"max_tokens,omitempty"`
|
||||
Prompt string `json:"prompt,omitempty"`
|
||||
Raw bool `json:"raw,omitempty"`
|
||||
Stream bool `json:"stream,omitempty"`
|
||||
Temperature float64 `json:"temperature,omitempty"`
|
||||
}
|
||||
121
relay/channel/cloudflare/relay_cloudflare.go
Normal file
121
relay/channel/cloudflare/relay_cloudflare.go
Normal file
@@ -0,0 +1,121 @@
|
||||
package cloudflare
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"encoding/json"
|
||||
"github.com/gin-gonic/gin"
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/dto"
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/service"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
func convertCf2CompletionsRequest(textRequest dto.GeneralOpenAIRequest) *CfRequest {
|
||||
p, _ := textRequest.Prompt.(string)
|
||||
return &CfRequest{
|
||||
Prompt: p,
|
||||
MaxTokens: textRequest.GetMaxTokens(),
|
||||
Stream: textRequest.Stream,
|
||||
Temperature: textRequest.Temperature,
|
||||
}
|
||||
}
|
||||
|
||||
func cfStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
scanner.Split(bufio.ScanLines)
|
||||
|
||||
service.SetEventStreamHeaders(c)
|
||||
id := service.GetResponseID(c)
|
||||
var responseText string
|
||||
isFirst := true
|
||||
|
||||
for scanner.Scan() {
|
||||
data := scanner.Text()
|
||||
if len(data) < len("data: ") {
|
||||
continue
|
||||
}
|
||||
data = strings.TrimPrefix(data, "data: ")
|
||||
data = strings.TrimSuffix(data, "\r")
|
||||
|
||||
if data == "[DONE]" {
|
||||
break
|
||||
}
|
||||
|
||||
var response dto.ChatCompletionsStreamResponse
|
||||
err := json.Unmarshal([]byte(data), &response)
|
||||
if err != nil {
|
||||
common.LogError(c, "error_unmarshalling_stream_response: "+err.Error())
|
||||
continue
|
||||
}
|
||||
for _, choice := range response.Choices {
|
||||
choice.Delta.Role = "assistant"
|
||||
responseText += choice.Delta.GetContentString()
|
||||
}
|
||||
response.Id = id
|
||||
response.Model = info.UpstreamModelName
|
||||
err = service.ObjectData(c, response)
|
||||
if isFirst {
|
||||
isFirst = false
|
||||
info.FirstResponseTime = time.Now()
|
||||
}
|
||||
if err != nil {
|
||||
common.LogError(c, "error_rendering_stream_response: "+err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
if err := scanner.Err(); err != nil {
|
||||
common.LogError(c, "error_scanning_stream_response: "+err.Error())
|
||||
}
|
||||
usage, _ := service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
|
||||
if info.ShouldIncludeUsage {
|
||||
response := service.GenerateFinalUsageResponse(id, info.StartTime.Unix(), info.UpstreamModelName, *usage)
|
||||
err := service.ObjectData(c, response)
|
||||
if err != nil {
|
||||
common.LogError(c, "error_rendering_final_usage_response: "+err.Error())
|
||||
}
|
||||
}
|
||||
service.Done(c)
|
||||
|
||||
err := resp.Body.Close()
|
||||
if err != nil {
|
||||
common.LogError(c, "close_response_body_failed: "+err.Error())
|
||||
}
|
||||
|
||||
return nil, usage
|
||||
}
|
||||
|
||||
func cfHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
err = resp.Body.Close()
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapperLocal(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
var response dto.TextResponse
|
||||
err = json.Unmarshal(responseBody, &response)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
response.Model = info.UpstreamModelName
|
||||
var responseText string
|
||||
for _, choice := range response.Choices {
|
||||
responseText += choice.Message.StringContent()
|
||||
}
|
||||
usage, _ := service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
|
||||
response.Usage = *usage
|
||||
response.Id = service.GetResponseID(c)
|
||||
jsonResponse, err := json.Marshal(response)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
c.Writer.Header().Set("Content-Type", "application/json")
|
||||
c.Writer.WriteHeader(resp.StatusCode)
|
||||
_, _ = c.Writer.Write(jsonResponse)
|
||||
return nil, usage
|
||||
}
|
||||
@@ -34,7 +34,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *re
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) {
|
||||
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
|
||||
return requestOpenAI2Cohere(*request), nil
|
||||
}
|
||||
|
||||
|
||||
@@ -7,7 +7,7 @@ type CohereRequest struct {
|
||||
ChatHistory []ChatHistory `json:"chat_history"`
|
||||
Message string `json:"message"`
|
||||
Stream bool `json:"stream"`
|
||||
MaxTokens int64 `json:"max_tokens"`
|
||||
MaxTokens int `json:"max_tokens"`
|
||||
}
|
||||
|
||||
type ChatHistory struct {
|
||||
|
||||
@@ -32,7 +32,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *re
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) {
|
||||
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
|
||||
if request == nil {
|
||||
return nil, errors.New("request is nil")
|
||||
}
|
||||
|
||||
@@ -21,7 +21,7 @@ type DifyData struct {
|
||||
|
||||
type DifyChatCompletionResponse struct {
|
||||
ConversationId string `json:"conversation_id"`
|
||||
Answers string `json:"answers"`
|
||||
Answer string `json:"answer"`
|
||||
CreateAt int64 `json:"create_at"`
|
||||
MetaData DifyMetaData `json:"metadata"`
|
||||
}
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/constant"
|
||||
"one-api/dto"
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/service"
|
||||
@@ -48,9 +49,9 @@ func streamResponseDify2OpenAI(difyResponse DifyChunkChatCompletionResponse) *dt
|
||||
Model: "dify",
|
||||
}
|
||||
var choice dto.ChatCompletionsStreamResponseChoice
|
||||
if difyResponse.Event == "workflow_started" {
|
||||
if constant.DifyDebug && difyResponse.Event == "workflow_started" {
|
||||
choice.Delta.SetContentString("Workflow: " + difyResponse.Data.WorkflowId + "\n")
|
||||
} else if difyResponse.Event == "node_started" {
|
||||
} else if constant.DifyDebug && difyResponse.Event == "node_started" {
|
||||
choice.Delta.SetContentString("Node: " + difyResponse.Data.NodeId + "\n")
|
||||
} else if difyResponse.Event == "message" {
|
||||
choice.Delta.SetContentString(difyResponse.Answer)
|
||||
@@ -116,6 +117,7 @@ func difyStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Re
|
||||
func difyHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
||||
var difyResponse DifyChatCompletionResponse
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
@@ -133,7 +135,7 @@ func difyHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInf
|
||||
Created: common.GetTimestamp(),
|
||||
Usage: difyResponse.MetaData.Usage,
|
||||
}
|
||||
content, _ := json.Marshal(difyResponse.Answers)
|
||||
content, _ := json.Marshal(difyResponse.Answer)
|
||||
choice := dto.OpenAITextResponseChoice{
|
||||
Index: 0,
|
||||
Message: dto.Message{
|
||||
|
||||
@@ -9,7 +9,6 @@ import (
|
||||
"one-api/dto"
|
||||
"one-api/relay/channel"
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/service"
|
||||
)
|
||||
|
||||
type Adaptor struct {
|
||||
@@ -52,7 +51,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *re
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) {
|
||||
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
|
||||
if request == nil {
|
||||
return nil, errors.New("request is nil")
|
||||
}
|
||||
@@ -69,9 +68,7 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
|
||||
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
|
||||
if info.IsStream {
|
||||
var responseText string
|
||||
err, responseText = geminiChatStreamHandler(c, resp, info)
|
||||
usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
|
||||
err, usage = geminiChatStreamHandler(c, resp, info)
|
||||
} else {
|
||||
err, usage = geminiChatHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
|
||||
}
|
||||
|
||||
@@ -59,4 +59,11 @@ type GeminiChatPromptFeedback struct {
|
||||
type GeminiChatResponse struct {
|
||||
Candidates []GeminiChatCandidate `json:"candidates"`
|
||||
PromptFeedback GeminiChatPromptFeedback `json:"promptFeedback"`
|
||||
UsageMetadata GeminiUsageMetadata `json:"usageMetadata"`
|
||||
}
|
||||
|
||||
type GeminiUsageMetadata struct {
|
||||
PromptTokenCount int `json:"promptTokenCount"`
|
||||
CandidatesTokenCount int `json:"candidatesTokenCount"`
|
||||
TotalTokenCount int `json:"totalTokenCount"`
|
||||
}
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/constant"
|
||||
@@ -162,8 +163,12 @@ func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) *dto.Ch
|
||||
return &response
|
||||
}
|
||||
|
||||
func geminiChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, string) {
|
||||
func geminiChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
||||
responseText := ""
|
||||
responseJson := ""
|
||||
id := fmt.Sprintf("chatcmpl-%s", common.GetUUID())
|
||||
createAt := common.GetTimestamp()
|
||||
var usage = &dto.Usage{}
|
||||
dataChan := make(chan string, 5)
|
||||
stopChan := make(chan bool, 2)
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
@@ -182,6 +187,7 @@ func geminiChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycom
|
||||
go func() {
|
||||
for scanner.Scan() {
|
||||
data := scanner.Text()
|
||||
responseJson += data
|
||||
data = strings.TrimSpace(data)
|
||||
if !strings.HasPrefix(data, "\"text\": \"") {
|
||||
continue
|
||||
@@ -216,10 +222,10 @@ func geminiChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycom
|
||||
var choice dto.ChatCompletionsStreamResponseChoice
|
||||
choice.Delta.SetContentString(dummy.Content)
|
||||
response := dto.ChatCompletionsStreamResponse{
|
||||
Id: fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
|
||||
Id: id,
|
||||
Object: "chat.completion.chunk",
|
||||
Created: common.GetTimestamp(),
|
||||
Model: "gemini-pro",
|
||||
Created: createAt,
|
||||
Model: info.UpstreamModelName,
|
||||
Choices: []dto.ChatCompletionsStreamResponseChoice{choice},
|
||||
}
|
||||
jsonResponse, err := json.Marshal(response)
|
||||
@@ -230,15 +236,34 @@ func geminiChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycom
|
||||
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
|
||||
return true
|
||||
case <-stopChan:
|
||||
c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
|
||||
return false
|
||||
}
|
||||
})
|
||||
err := resp.Body.Close()
|
||||
var geminiChatResponses []GeminiChatResponse
|
||||
err := json.Unmarshal([]byte(responseJson), &geminiChatResponses)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), ""
|
||||
log.Printf("cannot get gemini usage: %s", err.Error())
|
||||
usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
|
||||
} else {
|
||||
for _, response := range geminiChatResponses {
|
||||
usage.PromptTokens = response.UsageMetadata.PromptTokenCount
|
||||
usage.CompletionTokens = response.UsageMetadata.CandidatesTokenCount
|
||||
}
|
||||
usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
|
||||
}
|
||||
return nil, responseText
|
||||
if info.ShouldIncludeUsage {
|
||||
response := service.GenerateFinalUsageResponse(id, createAt, info.UpstreamModelName, *usage)
|
||||
err := service.ObjectData(c, response)
|
||||
if err != nil {
|
||||
common.SysError("send final response failed: " + err.Error())
|
||||
}
|
||||
}
|
||||
service.Done(c)
|
||||
err = resp.Body.Close()
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), usage
|
||||
}
|
||||
return nil, usage
|
||||
}
|
||||
|
||||
func geminiChatHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
||||
@@ -267,11 +292,10 @@ func geminiChatHandler(c *gin.Context, resp *http.Response, promptTokens int, mo
|
||||
}, nil
|
||||
}
|
||||
fullTextResponse := responseGeminiChat2OpenAI(&geminiResponse)
|
||||
completionTokens, _ := service.CountTokenText(geminiResponse.GetResponseText(), model)
|
||||
usage := dto.Usage{
|
||||
PromptTokens: promptTokens,
|
||||
CompletionTokens: completionTokens,
|
||||
TotalTokens: promptTokens + completionTokens,
|
||||
PromptTokens: geminiResponse.UsageMetadata.PromptTokenCount,
|
||||
CompletionTokens: geminiResponse.UsageMetadata.CandidatesTokenCount,
|
||||
TotalTokens: geminiResponse.UsageMetadata.TotalTokenCount,
|
||||
}
|
||||
fullTextResponse.Usage = usage
|
||||
jsonResponse, err := json.Marshal(fullTextResponse)
|
||||
|
||||
@@ -36,7 +36,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *re
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) {
|
||||
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
|
||||
return request, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -36,11 +36,11 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *re
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) {
|
||||
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
|
||||
if request == nil {
|
||||
return nil, errors.New("request is nil")
|
||||
}
|
||||
switch relayMode {
|
||||
switch info.RelayMode {
|
||||
case relayconstant.RelayModeEmbeddings:
|
||||
return requestOpenAI2Embeddings(*request), nil
|
||||
default:
|
||||
@@ -59,8 +59,10 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
|
||||
if info.IsStream {
|
||||
var responseText string
|
||||
err, responseText, _ = openai.OpenaiStreamHandler(c, resp, info)
|
||||
usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
|
||||
err, usage, responseText, _ = openai.OpenaiStreamHandler(c, resp, info)
|
||||
if usage == nil || usage.TotalTokens == 0 || (usage.PromptTokens+usage.CompletionTokens) == 0 {
|
||||
usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
|
||||
}
|
||||
} else {
|
||||
if info.RelayMode == relayconstant.RelayModeEmbeddings {
|
||||
err, usage = ollamaEmbeddingHandler(c, resp, info.PromptTokens, info.UpstreamModelName, info.RelayMode)
|
||||
|
||||
@@ -14,7 +14,6 @@ import (
|
||||
"one-api/relay/channel/minimax"
|
||||
"one-api/relay/channel/moonshot"
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/service"
|
||||
"strings"
|
||||
)
|
||||
|
||||
@@ -74,10 +73,13 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *re
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) {
|
||||
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
|
||||
if request == nil {
|
||||
return nil, errors.New("request is nil")
|
||||
}
|
||||
if info.ChannelType != common.ChannelTypeOpenAI {
|
||||
request.StreamOptions = nil
|
||||
}
|
||||
return request, nil
|
||||
}
|
||||
|
||||
@@ -87,11 +89,7 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
|
||||
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
|
||||
if info.IsStream {
|
||||
var responseText string
|
||||
var toolCount int
|
||||
err, responseText, toolCount = OpenaiStreamHandler(c, resp, info)
|
||||
usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
|
||||
usage.CompletionTokens += toolCount * 7
|
||||
err, usage, _, _ = OpenaiStreamHandler(c, resp, info)
|
||||
} else {
|
||||
err, usage = OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
|
||||
}
|
||||
|
||||
@@ -14,37 +14,34 @@ import (
|
||||
relayconstant "one-api/relay/constant"
|
||||
"one-api/service"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
func OpenaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, string, int) {
|
||||
//checkSensitive := constant.ShouldCheckCompletionSensitive()
|
||||
func OpenaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage, string, int) {
|
||||
hasStreamUsage := false
|
||||
responseId := ""
|
||||
var createAt int64 = 0
|
||||
var systemFingerprint string
|
||||
|
||||
var responseTextBuilder strings.Builder
|
||||
var usage = &dto.Usage{}
|
||||
toolCount := 0
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
|
||||
if atEOF && len(data) == 0 {
|
||||
return 0, nil, nil
|
||||
}
|
||||
if i := strings.Index(string(data), "\n"); i >= 0 {
|
||||
return i + 1, data[0:i], nil
|
||||
}
|
||||
if atEOF {
|
||||
return len(data), data, nil
|
||||
}
|
||||
return 0, nil, nil
|
||||
})
|
||||
dataChan := make(chan string, 5)
|
||||
scanner.Split(bufio.ScanLines)
|
||||
var streamItems []string // store stream items
|
||||
|
||||
service.SetEventStreamHeaders(c)
|
||||
|
||||
ticker := time.NewTicker(time.Duration(constant.StreamingTimeout) * time.Second)
|
||||
defer ticker.Stop()
|
||||
|
||||
stopChan := make(chan bool, 2)
|
||||
defer close(stopChan)
|
||||
defer close(dataChan)
|
||||
var wg sync.WaitGroup
|
||||
|
||||
go func() {
|
||||
wg.Add(1)
|
||||
defer wg.Done()
|
||||
var streamItems []string // store stream items
|
||||
for scanner.Scan() {
|
||||
info.SetFirstResponseTime()
|
||||
ticker.Reset(time.Duration(constant.StreamingTimeout) * time.Second)
|
||||
data := scanner.Text()
|
||||
if len(data) < 6 { // ignore blank line or wrong format
|
||||
continue
|
||||
@@ -52,43 +49,43 @@ func OpenaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.
|
||||
if data[:6] != "data: " && data[:6] != "[DONE]" {
|
||||
continue
|
||||
}
|
||||
if !common.SafeSendStringTimeout(dataChan, data, constant.StreamingTimeout) {
|
||||
// send data timeout, stop the stream
|
||||
common.LogError(c, "send data timeout, stop the stream")
|
||||
break
|
||||
}
|
||||
data = data[6:]
|
||||
if !strings.HasPrefix(data, "[DONE]") {
|
||||
service.StringData(c, data)
|
||||
streamItems = append(streamItems, data)
|
||||
}
|
||||
}
|
||||
streamResp := "[" + strings.Join(streamItems, ",") + "]"
|
||||
switch info.RelayMode {
|
||||
case relayconstant.RelayModeChatCompletions:
|
||||
var streamResponses []dto.ChatCompletionsStreamResponseSimple
|
||||
err := json.Unmarshal(common.StringToByteSlice(streamResp), &streamResponses)
|
||||
if err != nil {
|
||||
common.SysError("error unmarshalling stream response: " + err.Error())
|
||||
for _, item := range streamItems {
|
||||
var streamResponse dto.ChatCompletionsStreamResponseSimple
|
||||
err := json.Unmarshal(common.StringToByteSlice(item), &streamResponse)
|
||||
if err == nil {
|
||||
for _, choice := range streamResponse.Choices {
|
||||
responseTextBuilder.WriteString(choice.Delta.GetContentString())
|
||||
if choice.Delta.ToolCalls != nil {
|
||||
if len(choice.Delta.ToolCalls) > toolCount {
|
||||
toolCount = len(choice.Delta.ToolCalls)
|
||||
}
|
||||
for _, tool := range choice.Delta.ToolCalls {
|
||||
responseTextBuilder.WriteString(tool.Function.Name)
|
||||
responseTextBuilder.WriteString(tool.Function.Arguments)
|
||||
}
|
||||
}
|
||||
}
|
||||
stopChan <- true
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-ticker.C:
|
||||
// 超时处理逻辑
|
||||
common.LogError(c, "streaming timeout")
|
||||
case <-stopChan:
|
||||
// 正常结束
|
||||
}
|
||||
|
||||
// 计算token
|
||||
streamResp := "[" + strings.Join(streamItems, ",") + "]"
|
||||
switch info.RelayMode {
|
||||
case relayconstant.RelayModeChatCompletions:
|
||||
var streamResponses []dto.ChatCompletionsStreamResponse
|
||||
err := json.Unmarshal(common.StringToByteSlice(streamResp), &streamResponses)
|
||||
if err != nil {
|
||||
// 一次性解析失败,逐个解析
|
||||
common.SysError("error unmarshalling stream response: " + err.Error())
|
||||
for _, item := range streamItems {
|
||||
var streamResponse dto.ChatCompletionsStreamResponse
|
||||
err := json.Unmarshal(common.StringToByteSlice(item), &streamResponse)
|
||||
if err == nil {
|
||||
responseId = streamResponse.Id
|
||||
createAt = streamResponse.Created
|
||||
systemFingerprint = streamResponse.GetSystemFingerprint()
|
||||
if service.ValidUsage(streamResponse.Usage) {
|
||||
usage = streamResponse.Usage
|
||||
hasStreamUsage = true
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for _, streamResponse := range streamResponses {
|
||||
for _, choice := range streamResponse.Choices {
|
||||
responseTextBuilder.WriteString(choice.Delta.GetContentString())
|
||||
if choice.Delta.ToolCalls != nil {
|
||||
@@ -103,66 +100,71 @@ func OpenaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.
|
||||
}
|
||||
}
|
||||
}
|
||||
case relayconstant.RelayModeCompletions:
|
||||
var streamResponses []dto.CompletionsStreamResponse
|
||||
err := json.Unmarshal(common.StringToByteSlice(streamResp), &streamResponses)
|
||||
if err != nil {
|
||||
common.SysError("error unmarshalling stream response: " + err.Error())
|
||||
for _, item := range streamItems {
|
||||
var streamResponse dto.CompletionsStreamResponse
|
||||
err := json.Unmarshal(common.StringToByteSlice(item), &streamResponse)
|
||||
if err == nil {
|
||||
for _, choice := range streamResponse.Choices {
|
||||
responseTextBuilder.WriteString(choice.Text)
|
||||
} else {
|
||||
for _, streamResponse := range streamResponses {
|
||||
responseId = streamResponse.Id
|
||||
createAt = streamResponse.Created
|
||||
systemFingerprint = streamResponse.GetSystemFingerprint()
|
||||
if service.ValidUsage(streamResponse.Usage) {
|
||||
usage = streamResponse.Usage
|
||||
hasStreamUsage = true
|
||||
}
|
||||
for _, choice := range streamResponse.Choices {
|
||||
responseTextBuilder.WriteString(choice.Delta.GetContentString())
|
||||
if choice.Delta.ToolCalls != nil {
|
||||
if len(choice.Delta.ToolCalls) > toolCount {
|
||||
toolCount = len(choice.Delta.ToolCalls)
|
||||
}
|
||||
for _, tool := range choice.Delta.ToolCalls {
|
||||
responseTextBuilder.WriteString(tool.Function.Name)
|
||||
responseTextBuilder.WriteString(tool.Function.Arguments)
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for _, streamResponse := range streamResponses {
|
||||
}
|
||||
}
|
||||
case relayconstant.RelayModeCompletions:
|
||||
var streamResponses []dto.CompletionsStreamResponse
|
||||
err := json.Unmarshal(common.StringToByteSlice(streamResp), &streamResponses)
|
||||
if err != nil {
|
||||
// 一次性解析失败,逐个解析
|
||||
common.SysError("error unmarshalling stream response: " + err.Error())
|
||||
for _, item := range streamItems {
|
||||
var streamResponse dto.CompletionsStreamResponse
|
||||
err := json.Unmarshal(common.StringToByteSlice(item), &streamResponse)
|
||||
if err == nil {
|
||||
for _, choice := range streamResponse.Choices {
|
||||
responseTextBuilder.WriteString(choice.Text)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if len(dataChan) > 0 {
|
||||
// wait data out
|
||||
time.Sleep(2 * time.Second)
|
||||
}
|
||||
common.SafeSendBool(stopChan, true)
|
||||
}()
|
||||
service.SetEventStreamHeaders(c)
|
||||
isFirst := true
|
||||
ticker := time.NewTicker(time.Duration(constant.StreamingTimeout) * time.Second)
|
||||
defer ticker.Stop()
|
||||
c.Stream(func(w io.Writer) bool {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
common.LogError(c, "reading data from upstream timeout")
|
||||
return false
|
||||
case data := <-dataChan:
|
||||
if isFirst {
|
||||
isFirst = false
|
||||
info.FirstResponseTime = time.Now()
|
||||
} else {
|
||||
for _, streamResponse := range streamResponses {
|
||||
for _, choice := range streamResponse.Choices {
|
||||
responseTextBuilder.WriteString(choice.Text)
|
||||
}
|
||||
}
|
||||
ticker.Reset(time.Duration(constant.StreamingTimeout) * time.Second)
|
||||
if strings.HasPrefix(data, "data: [DONE]") {
|
||||
data = data[:12]
|
||||
}
|
||||
// some implementations may add \r at the end of data
|
||||
data = strings.TrimSuffix(data, "\r")
|
||||
c.Render(-1, common.CustomEvent{Data: data})
|
||||
return true
|
||||
case <-stopChan:
|
||||
return false
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
if !hasStreamUsage {
|
||||
usage, _ = service.ResponseText2Usage(responseTextBuilder.String(), info.UpstreamModelName, info.PromptTokens)
|
||||
usage.CompletionTokens += toolCount * 7
|
||||
}
|
||||
|
||||
if info.ShouldIncludeUsage && !hasStreamUsage {
|
||||
response := service.GenerateFinalUsageResponse(responseId, createAt, info.UpstreamModelName, *usage)
|
||||
response.SetSystemFingerprint(systemFingerprint)
|
||||
service.ObjectData(c, response)
|
||||
}
|
||||
|
||||
service.Done(c)
|
||||
|
||||
err := resp.Body.Close()
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), "", toolCount
|
||||
return service.OpenAIErrorWrapperLocal(err, "close_response_body_failed", http.StatusInternalServerError), nil, "", toolCount
|
||||
}
|
||||
wg.Wait()
|
||||
return nil, responseTextBuilder.String(), toolCount
|
||||
return nil, usage, responseTextBuilder.String(), toolCount
|
||||
}
|
||||
|
||||
func OpenaiHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
||||
|
||||
@@ -33,7 +33,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *re
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) {
|
||||
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
|
||||
if request == nil {
|
||||
return nil, errors.New("request is nil")
|
||||
}
|
||||
|
||||
@@ -34,7 +34,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *re
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) {
|
||||
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
|
||||
if request == nil {
|
||||
return nil, errors.New("request is nil")
|
||||
}
|
||||
@@ -55,8 +55,10 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
|
||||
if info.IsStream {
|
||||
var responseText string
|
||||
err, responseText, _ = openai.OpenaiStreamHandler(c, resp, info)
|
||||
usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
|
||||
err, usage, responseText, _ = openai.OpenaiStreamHandler(c, resp, info)
|
||||
if usage == nil || usage.TotalTokens == 0 || (usage.PromptTokens+usage.CompletionTokens) == 0 {
|
||||
usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
|
||||
}
|
||||
} else {
|
||||
err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
|
||||
}
|
||||
|
||||
@@ -17,6 +17,7 @@ import (
|
||||
|
||||
type Adaptor struct {
|
||||
Sign string
|
||||
AppID int64
|
||||
Action string
|
||||
Version string
|
||||
Timestamp int64
|
||||
@@ -34,7 +35,7 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIReq
|
||||
}
|
||||
|
||||
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
return fmt.Sprintf("%s/hyllm/v1/chat/completions", info.BaseUrl), nil
|
||||
return fmt.Sprintf("%s/", info.BaseUrl), nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
|
||||
@@ -46,17 +47,18 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *re
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) {
|
||||
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
|
||||
if request == nil {
|
||||
return nil, errors.New("request is nil")
|
||||
}
|
||||
apiKey := c.Request.Header.Get("Authorization")
|
||||
apiKey = strings.TrimPrefix(apiKey, "Bearer ")
|
||||
_, secretId, secretKey, err := parseTencentConfig(apiKey)
|
||||
appId, secretId, secretKey, err := parseTencentConfig(apiKey)
|
||||
a.AppID = appId
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tencentRequest := requestOpenAI2Tencent(*request)
|
||||
tencentRequest := requestOpenAI2Tencent(a, *request)
|
||||
// we have to calculate the sign here
|
||||
a.Sign = getTencentSign(*tencentRequest, a, secretId, secretKey)
|
||||
return tencentRequest, nil
|
||||
|
||||
@@ -30,17 +30,17 @@ type TencentChatRequest struct {
|
||||
//
|
||||
// 注意:
|
||||
// 通过 SDK 调用时,流式和非流式调用需用**不同的方式**获取返回值,具体参考 SDK 中的注释或示例(在各语言 SDK 代码仓库的 examples/hunyuan/v20230901/ 目录中)。
|
||||
Stream *bool `json:"Stream"`
|
||||
Stream *bool `json:"Stream,omitempty"`
|
||||
// 说明:
|
||||
// 1. 影响输出文本的多样性,取值越大,生成文本的多样性越强。
|
||||
// 2. 取值区间为 [0.0, 1.0],未传值时使用各模型推荐值。
|
||||
// 3. 非必要不建议使用,不合理的取值会影响效果。
|
||||
TopP *float64 `json:"TopP"`
|
||||
TopP *float64 `json:"TopP,omitempty"`
|
||||
// 说明:
|
||||
// 1. 较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定。
|
||||
// 2. 取值区间为 [0.0, 2.0],未传值时使用各模型推荐值。
|
||||
// 3. 非必要不建议使用,不合理的取值会影响效果。
|
||||
Temperature *float64 `json:"Temperature"`
|
||||
Temperature *float64 `json:"Temperature,omitempty"`
|
||||
}
|
||||
|
||||
type TencentError struct {
|
||||
@@ -69,3 +69,7 @@ type TencentChatResponse struct {
|
||||
Note string `json:"Note,omitempty"` // 注释
|
||||
ReqID string `json:"Req_id,omitempty"` // 唯一请求 Id,每次请求都会返回。用于反馈接口入参
|
||||
}
|
||||
|
||||
type TencentChatResponseSB struct {
|
||||
Response TencentChatResponse `json:"Response,omitempty"`
|
||||
}
|
||||
|
||||
@@ -22,7 +22,7 @@ import (
|
||||
|
||||
// https://cloud.tencent.com/document/product/1729/97732
|
||||
|
||||
func requestOpenAI2Tencent(request dto.GeneralOpenAIRequest) *TencentChatRequest {
|
||||
func requestOpenAI2Tencent(a *Adaptor, request dto.GeneralOpenAIRequest) *TencentChatRequest {
|
||||
messages := make([]*TencentMessage, 0, len(request.Messages))
|
||||
for i := 0; i < len(request.Messages); i++ {
|
||||
message := request.Messages[i]
|
||||
@@ -31,17 +31,23 @@ func requestOpenAI2Tencent(request dto.GeneralOpenAIRequest) *TencentChatRequest
|
||||
Role: message.Role,
|
||||
})
|
||||
}
|
||||
return &TencentChatRequest{
|
||||
Temperature: &request.Temperature,
|
||||
TopP: &request.TopP,
|
||||
Stream: &request.Stream,
|
||||
Messages: messages,
|
||||
Model: &request.Model,
|
||||
var req = TencentChatRequest{
|
||||
Stream: &request.Stream,
|
||||
Messages: messages,
|
||||
Model: &request.Model,
|
||||
}
|
||||
if request.TopP != 0 {
|
||||
req.TopP = &request.TopP
|
||||
}
|
||||
if request.Temperature != 0 {
|
||||
req.Temperature = &request.Temperature
|
||||
}
|
||||
return &req
|
||||
}
|
||||
|
||||
func responseTencent2OpenAI(response *TencentChatResponse) *dto.OpenAITextResponse {
|
||||
fullTextResponse := dto.OpenAITextResponse{
|
||||
Id: response.Id,
|
||||
Object: "chat.completion",
|
||||
Created: common.GetTimestamp(),
|
||||
Usage: dto.Usage{
|
||||
@@ -129,7 +135,7 @@ func tencentStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIError
|
||||
}
|
||||
|
||||
func tencentHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
||||
var TencentResponse TencentChatResponse
|
||||
var tencentSb TencentChatResponseSB
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
||||
@@ -138,20 +144,20 @@ func tencentHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithSt
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
err = json.Unmarshal(responseBody, &TencentResponse)
|
||||
err = json.Unmarshal(responseBody, &tencentSb)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
if TencentResponse.Error.Code != 0 {
|
||||
if tencentSb.Response.Error.Code != 0 {
|
||||
return &dto.OpenAIErrorWithStatusCode{
|
||||
Error: dto.OpenAIError{
|
||||
Message: TencentResponse.Error.Message,
|
||||
Code: TencentResponse.Error.Code,
|
||||
Message: tencentSb.Response.Error.Message,
|
||||
Code: tencentSb.Response.Error.Code,
|
||||
},
|
||||
StatusCode: resp.StatusCode,
|
||||
}, nil
|
||||
}
|
||||
fullTextResponse := responseTencent2OpenAI(&TencentResponse)
|
||||
fullTextResponse := responseTencent2OpenAI(&tencentSb.Response)
|
||||
jsonResponse, err := json.Marshal(fullTextResponse)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
|
||||
@@ -33,7 +33,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *re
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) {
|
||||
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
|
||||
if request == nil {
|
||||
return nil, errors.New("request is nil")
|
||||
}
|
||||
|
||||
@@ -37,7 +37,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *re
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) {
|
||||
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
|
||||
if request == nil {
|
||||
return nil, errors.New("request is nil")
|
||||
}
|
||||
|
||||
@@ -35,7 +35,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *re
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) {
|
||||
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
|
||||
if request == nil {
|
||||
return nil, errors.New("request is nil")
|
||||
}
|
||||
@@ -57,9 +57,11 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
|
||||
if info.IsStream {
|
||||
var responseText string
|
||||
var toolCount int
|
||||
err, responseText, toolCount = openai.OpenaiStreamHandler(c, resp, info)
|
||||
usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
|
||||
usage.CompletionTokens += toolCount * 7
|
||||
err, usage, responseText, toolCount = openai.OpenaiStreamHandler(c, resp, info)
|
||||
if usage == nil || usage.TotalTokens == 0 || (usage.PromptTokens+usage.CompletionTokens) == 0 {
|
||||
usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
|
||||
usage.CompletionTokens += toolCount * 7
|
||||
}
|
||||
} else {
|
||||
err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
|
||||
}
|
||||
|
||||
@@ -9,24 +9,27 @@ import (
|
||||
)
|
||||
|
||||
type RelayInfo struct {
|
||||
ChannelType int
|
||||
ChannelId int
|
||||
TokenId int
|
||||
UserId int
|
||||
Group string
|
||||
TokenUnlimited bool
|
||||
StartTime time.Time
|
||||
FirstResponseTime time.Time
|
||||
ApiType int
|
||||
IsStream bool
|
||||
RelayMode int
|
||||
UpstreamModelName string
|
||||
RequestURLPath string
|
||||
ApiVersion string
|
||||
PromptTokens int
|
||||
ApiKey string
|
||||
Organization string
|
||||
BaseUrl string
|
||||
ChannelType int
|
||||
ChannelId int
|
||||
TokenId int
|
||||
UserId int
|
||||
Group string
|
||||
TokenUnlimited bool
|
||||
StartTime time.Time
|
||||
FirstResponseTime time.Time
|
||||
setFirstResponse bool
|
||||
ApiType int
|
||||
IsStream bool
|
||||
RelayMode int
|
||||
UpstreamModelName string
|
||||
RequestURLPath string
|
||||
ApiVersion string
|
||||
PromptTokens int
|
||||
ApiKey string
|
||||
Organization string
|
||||
BaseUrl string
|
||||
SupportStreamOptions bool
|
||||
ShouldIncludeUsage bool
|
||||
}
|
||||
|
||||
func GenRelayInfo(c *gin.Context) *RelayInfo {
|
||||
@@ -65,6 +68,11 @@ func GenRelayInfo(c *gin.Context) *RelayInfo {
|
||||
if info.ChannelType == common.ChannelTypeAzure {
|
||||
info.ApiVersion = GetAPIVersion(c)
|
||||
}
|
||||
if info.ChannelType == common.ChannelTypeOpenAI || info.ChannelType == common.ChannelTypeAnthropic ||
|
||||
info.ChannelType == common.ChannelTypeAws || info.ChannelType == common.ChannelTypeGemini ||
|
||||
info.ChannelType == common.ChannelCloudflare {
|
||||
info.SupportStreamOptions = true
|
||||
}
|
||||
return info
|
||||
}
|
||||
|
||||
@@ -76,6 +84,13 @@ func (info *RelayInfo) SetIsStream(isStream bool) {
|
||||
info.IsStream = isStream
|
||||
}
|
||||
|
||||
func (info *RelayInfo) SetFirstResponseTime() {
|
||||
if !info.setFirstResponse {
|
||||
info.FirstResponseTime = time.Now()
|
||||
info.setFirstResponse = true
|
||||
}
|
||||
}
|
||||
|
||||
type TaskRelayInfo struct {
|
||||
ChannelType int
|
||||
ChannelId int
|
||||
|
||||
@@ -38,6 +38,7 @@ func RelayErrorHandler(resp *http.Response) (OpenAIErrorWithStatusCode *dto.Open
|
||||
var textResponse dto.TextResponseWithError
|
||||
err = json.Unmarshal(responseBody, &textResponse)
|
||||
if err != nil {
|
||||
OpenAIErrorWithStatusCode.Error.Message = fmt.Sprintf("error unmarshalling response body: %s", responseBody)
|
||||
return
|
||||
}
|
||||
OpenAIErrorWithStatusCode.Error = textResponse.Error
|
||||
|
||||
@@ -22,6 +22,7 @@ const (
|
||||
APITypeCohere
|
||||
APITypeDify
|
||||
APITypeJina
|
||||
APITypeCloudflare
|
||||
|
||||
APITypeDummy // this one is only for count, do not add any channel after this
|
||||
)
|
||||
@@ -63,6 +64,8 @@ func ChannelType2APIType(channelType int) (int, bool) {
|
||||
apiType = APITypeDify
|
||||
case common.ChannelTypeJina:
|
||||
apiType = APITypeJina
|
||||
case common.ChannelCloudflare:
|
||||
apiType = APITypeCloudflare
|
||||
}
|
||||
if apiType == -1 {
|
||||
return APITypeOpenAI, false
|
||||
|
||||
@@ -415,9 +415,12 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
|
||||
originTask := model.GetByMJId(userId, mjId)
|
||||
if originTask == nil {
|
||||
return service.MidjourneyErrorWrapper(constant.MjRequestError, "task_not_found")
|
||||
} else if originTask.Status != "SUCCESS" && relayMode != relayconstant.RelayModeMidjourneyModal {
|
||||
return service.MidjourneyErrorWrapper(constant.MjRequestError, "task_status_not_success")
|
||||
} else { //原任务的Status=SUCCESS,则可以做放大UPSCALE、变换VARIATION等动作,此时必须使用原来的请求地址才能正确处理
|
||||
if constant.MjActionCheckSuccessEnabled {
|
||||
if originTask.Status != "SUCCESS" && relayMode != relayconstant.RelayModeMidjourneyModal {
|
||||
return service.MidjourneyErrorWrapper(constant.MjRequestError, "task_status_not_success")
|
||||
}
|
||||
}
|
||||
channel, err := model.GetChannelById(originTask.ChannelId, true)
|
||||
if err != nil {
|
||||
return service.MidjourneyErrorWrapper(constant.MjRequestError, "get_channel_info_failed")
|
||||
|
||||
@@ -77,7 +77,7 @@ func TextHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
|
||||
|
||||
// map model name
|
||||
modelMapping := c.GetString("model_mapping")
|
||||
isModelMapped := false
|
||||
//isModelMapped := false
|
||||
if modelMapping != "" && modelMapping != "{}" {
|
||||
modelMap := make(map[string]string)
|
||||
err := json.Unmarshal([]byte(modelMapping), &modelMap)
|
||||
@@ -87,7 +87,7 @@ func TextHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
|
||||
if modelMap[textRequest.Model] != "" {
|
||||
textRequest.Model = modelMap[textRequest.Model]
|
||||
// set upstream model name
|
||||
isModelMapped = true
|
||||
//isModelMapped = true
|
||||
}
|
||||
}
|
||||
relayInfo.UpstreamModelName = textRequest.Model
|
||||
@@ -130,33 +130,38 @@ func TextHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
|
||||
return openaiErr
|
||||
}
|
||||
|
||||
// 如果不支持StreamOptions,将StreamOptions设置为nil
|
||||
if !relayInfo.SupportStreamOptions || !textRequest.Stream {
|
||||
textRequest.StreamOptions = nil
|
||||
} else {
|
||||
// 如果支持StreamOptions,且请求中没有设置StreamOptions,根据配置文件设置StreamOptions
|
||||
if constant.ForceStreamOption {
|
||||
textRequest.StreamOptions = &dto.StreamOptions{
|
||||
IncludeUsage: true,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if textRequest.StreamOptions != nil && textRequest.StreamOptions.IncludeUsage {
|
||||
relayInfo.ShouldIncludeUsage = textRequest.StreamOptions.IncludeUsage
|
||||
}
|
||||
|
||||
adaptor := GetAdaptor(relayInfo.ApiType)
|
||||
if adaptor == nil {
|
||||
return service.OpenAIErrorWrapperLocal(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), "invalid_api_type", http.StatusBadRequest)
|
||||
}
|
||||
adaptor.Init(relayInfo, *textRequest)
|
||||
var requestBody io.Reader
|
||||
if relayInfo.ApiType == relayconstant.APITypeOpenAI {
|
||||
if isModelMapped {
|
||||
jsonStr, err := json.Marshal(textRequest)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapperLocal(err, "marshal_text_request_failed", http.StatusInternalServerError)
|
||||
}
|
||||
requestBody = bytes.NewBuffer(jsonStr)
|
||||
} else {
|
||||
requestBody = c.Request.Body
|
||||
}
|
||||
} else {
|
||||
convertedRequest, err := adaptor.ConvertRequest(c, relayInfo.RelayMode, textRequest)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapperLocal(err, "convert_request_failed", http.StatusInternalServerError)
|
||||
}
|
||||
jsonData, err := json.Marshal(convertedRequest)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapperLocal(err, "json_marshal_failed", http.StatusInternalServerError)
|
||||
}
|
||||
requestBody = bytes.NewBuffer(jsonData)
|
||||
|
||||
convertedRequest, err := adaptor.ConvertRequest(c, relayInfo, textRequest)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapperLocal(err, "convert_request_failed", http.StatusInternalServerError)
|
||||
}
|
||||
jsonData, err := json.Marshal(convertedRequest)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapperLocal(err, "json_marshal_failed", http.StatusInternalServerError)
|
||||
}
|
||||
requestBody = bytes.NewBuffer(jsonData)
|
||||
|
||||
statusCodeMappingStr := c.GetString("status_code_mapping")
|
||||
resp, err := adaptor.DoRequest(c, relayInfo, requestBody)
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"one-api/relay/channel/aws"
|
||||
"one-api/relay/channel/baidu"
|
||||
"one-api/relay/channel/claude"
|
||||
"one-api/relay/channel/cloudflare"
|
||||
"one-api/relay/channel/cohere"
|
||||
"one-api/relay/channel/dify"
|
||||
"one-api/relay/channel/gemini"
|
||||
@@ -59,6 +60,8 @@ func GetAdaptor(apiType int) channel.Adaptor {
|
||||
return &dify.Adaptor{}
|
||||
case constant.APITypeJina:
|
||||
return &jina.Adaptor{}
|
||||
case constant.APITypeCloudflare:
|
||||
return &cloudflare.Adaptor{}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -74,14 +74,14 @@ func ShouldDisableChannel(channelType int, err *relaymodel.OpenAIErrorWithStatus
|
||||
return false
|
||||
}
|
||||
|
||||
func ShouldEnableChannel(err error, openAIErr *relaymodel.OpenAIError, status int) bool {
|
||||
func ShouldEnableChannel(err error, openaiWithStatusErr *relaymodel.OpenAIErrorWithStatusCode, status int) bool {
|
||||
if !common.AutomaticEnableChannelEnabled {
|
||||
return false
|
||||
}
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
if openAIErr != nil {
|
||||
if openaiWithStatusErr != nil {
|
||||
return false
|
||||
}
|
||||
if status != common.ChannelStatusAutoDisabled {
|
||||
|
||||
@@ -35,3 +35,8 @@ func ObjectData(c *gin.Context, object interface{}) error {
|
||||
func Done(c *gin.Context) {
|
||||
StringData(c, "[DONE]")
|
||||
}
|
||||
|
||||
func GetResponseID(c *gin.Context) string {
|
||||
logID := c.GetString("X-Oneapi-Request-Id")
|
||||
return fmt.Sprintf("chatcmpl-%s", logID)
|
||||
}
|
||||
@@ -24,3 +24,19 @@ func ResponseText2Usage(responseText string, modeName string, promptTokens int)
|
||||
usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
|
||||
return usage, err
|
||||
}
|
||||
|
||||
func GenerateFinalUsageResponse(id string, createAt int64, model string, usage dto.Usage) *dto.ChatCompletionsStreamResponse {
|
||||
return &dto.ChatCompletionsStreamResponse{
|
||||
Id: id,
|
||||
Object: "chat.completion.chunk",
|
||||
Created: createAt,
|
||||
Model: model,
|
||||
SystemFingerprint: nil,
|
||||
Choices: make([]dto.ChatCompletionsStreamResponseChoice, 0),
|
||||
Usage: &usage,
|
||||
}
|
||||
}
|
||||
|
||||
func ValidUsage(usage *dto.Usage) bool {
|
||||
return usage != nil && (usage.PromptTokens != 0 || usage.CompletionTokens != 0)
|
||||
}
|
||||
|
||||
@@ -550,7 +550,7 @@ const ChannelsTable = () => {
|
||||
);
|
||||
const { success, message, data } = res.data;
|
||||
if (success) {
|
||||
setChannels(data);
|
||||
setChannelFormat(data);
|
||||
setActivePage(1);
|
||||
} else {
|
||||
showError(message);
|
||||
|
||||
@@ -42,6 +42,7 @@ const OperationSetting = () => {
|
||||
MjAccountFilterEnabled: false,
|
||||
MjModeClearEnabled: false,
|
||||
MjForwardUrlEnabled: false,
|
||||
MjActionCheckSuccessEnabled: false,
|
||||
DrawingEnabled: false,
|
||||
DataExportEnabled: false,
|
||||
DataExportDefaultTime: 'hour',
|
||||
|
||||
@@ -99,6 +99,7 @@ export const CHANNEL_OPTIONS = [
|
||||
color: 'orange',
|
||||
label: 'Google PaLM2',
|
||||
},
|
||||
{ key: 39, text: 'Cloudflare', value: 39, color: 'grey', label: 'Cloudflare' },
|
||||
{ key: 25, text: 'Moonshot', value: 25, color: 'green', label: 'Moonshot' },
|
||||
{ key: 19, text: '360 智脑', value: 19, color: 'blue', label: '360 智脑' },
|
||||
{ key: 23, text: '腾讯混元', value: 23, color: 'teal', label: '腾讯混元' },
|
||||
|
||||
@@ -153,8 +153,8 @@ export function renderModelPrice(
|
||||
let inputRatioPrice = modelRatio * 2.0;
|
||||
let completionRatioPrice = modelRatio * 2.0 * completionRatio;
|
||||
let price =
|
||||
(inputTokens / 1000000) * inputRatioPrice +
|
||||
(completionTokens / 1000000) * completionRatioPrice;
|
||||
(inputTokens / 1000000) * inputRatioPrice * groupRatio +
|
||||
(completionTokens / 1000000) * completionRatioPrice * groupRatio;
|
||||
return (
|
||||
<>
|
||||
<article>
|
||||
|
||||
@@ -605,6 +605,24 @@ const EditChannel = (props) => {
|
||||
/>
|
||||
</>
|
||||
)}
|
||||
{inputs.type === 39 && (
|
||||
<>
|
||||
<div style={{ marginTop: 10 }}>
|
||||
<Typography.Text strong>Account ID:</Typography.Text>
|
||||
</div>
|
||||
<Input
|
||||
name='other'
|
||||
placeholder={
|
||||
'请输入Account ID,例如:d6b5da8hk1awo8nap34ube6gh'
|
||||
}
|
||||
onChange={(value) => {
|
||||
handleInputChange('other', value);
|
||||
}}
|
||||
value={inputs.other}
|
||||
autoComplete='new-password'
|
||||
/>
|
||||
</>
|
||||
)}
|
||||
<div style={{ marginTop: 10 }}>
|
||||
<Typography.Text strong>模型:</Typography.Text>
|
||||
</div>
|
||||
|
||||
@@ -16,6 +16,7 @@ export default function SettingsDrawing(props) {
|
||||
MjAccountFilterEnabled: false,
|
||||
MjForwardUrlEnabled: false,
|
||||
MjModeClearEnabled: false,
|
||||
MjActionCheckSuccessEnabled: false,
|
||||
});
|
||||
const refForm = useRef();
|
||||
const [inputsRow, setInputsRow] = useState(inputs);
|
||||
@@ -156,6 +157,25 @@ export default function SettingsDrawing(props) {
|
||||
}
|
||||
/>
|
||||
</Col>
|
||||
<Col span={8}>
|
||||
<Form.Switch
|
||||
field={'MjActionCheckSuccessEnabled'}
|
||||
label={
|
||||
<>
|
||||
检测必须等待绘图成功才能进行放大等操作
|
||||
</>
|
||||
}
|
||||
size='large'
|
||||
checkedText='|'
|
||||
uncheckedText='〇'
|
||||
onChange={(value) =>
|
||||
setInputs({
|
||||
...inputs,
|
||||
MjActionCheckSuccessEnabled: value,
|
||||
})
|
||||
}
|
||||
/>
|
||||
</Col>
|
||||
</Row>
|
||||
<Row>
|
||||
<Button size='large' onClick={onSubmit}>
|
||||
|
||||
Reference in New Issue
Block a user