From c88f3741e64ae9c945c4ef54f775f264b47f951f Mon Sep 17 00:00:00 2001
From: CalciumIon <1808837298@qq.com>
Date: Thu, 11 Jul 2024 18:44:45 +0800
Subject: [PATCH 01/34] feat: support claude stop_sequences
---
relay/channel/claude/relay-claude.go | 13 +++++++++++++
1 file changed, 13 insertions(+)
diff --git a/relay/channel/claude/relay-claude.go b/relay/channel/claude/relay-claude.go
index 9457f1e..945b20d 100644
--- a/relay/channel/claude/relay-claude.go
+++ b/relay/channel/claude/relay-claude.go
@@ -72,6 +72,19 @@ func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*ClaudeR
if claudeRequest.MaxTokens == 0 {
claudeRequest.MaxTokens = 4096
}
+ if textRequest.Stop != nil {
+ // stop may be a string or an array of strings; convert it to an array of strings
+ switch textRequest.Stop.(type) {
+ case string:
+ claudeRequest.StopSequences = []string{textRequest.Stop.(string)}
+ case []interface{}:
+ stopSequences := make([]string, 0)
+ for _, stop := range textRequest.Stop.([]interface{}) {
+ stopSequences = append(stopSequences, stop.(string))
+ }
+ claudeRequest.StopSequences = stopSequences
+ }
+ }
formatMessages := make([]dto.Message, 0)
var lastMessage *dto.Message
for i, message := range textRequest.Messages {
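
For reference, a self-contained sketch of the normalization this patch performs. The claudeRequest stand-in and the helper name normalizeStop are illustrative, not the project's actual types; the point is that a decoded OpenAI "stop" value arrives as either a string or a []interface{} and is folded into Claude's stop_sequences array:

package main

import "fmt"

// Illustrative stand-in for the stop_sequences field of the real ClaudeRequest DTO.
type claudeRequest struct {
	StopSequences []string
}

// normalizeStop mirrors the type switch added above: after JSON decoding,
// an OpenAI "stop" value is either a single string or a []interface{}.
func normalizeStop(stop interface{}) []string {
	switch v := stop.(type) {
	case string:
		return []string{v}
	case []interface{}:
		sequences := make([]string, 0, len(v))
		for _, s := range v {
			if str, ok := s.(string); ok { // tolerate non-string entries instead of panicking
				sequences = append(sequences, str)
			}
		}
		return sequences
	default:
		return nil
	}
}

func main() {
	req := claudeRequest{StopSequences: normalizeStop([]interface{}{"\n\nHuman:", "END"})}
	fmt.Println(req.StopSequences)
}
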
From 7b36a2b885f5264ae1d00c2a99d9fec52f0eb1c3 Mon Sep 17 00:00:00 2001
From: CalciumIon <1808837298@qq.com>
Date: Sat, 13 Jul 2024 19:55:22 +0800
Subject: [PATCH 02/34] feat: support cloudflare worker ai
---
common/constants.go | 2 +
controller/channel-test.go | 37 +++---
dto/text_request.go | 4 +-
middleware/distributor.go | 4 +-
relay/channel/cloudflare/adaptor.go | 76 ++++++++++++
relay/channel/cloudflare/constant.go | 38 ++++++
relay/channel/cloudflare/model.go | 13 +++
relay/channel/cloudflare/relay_cloudflare.go | 115 +++++++++++++++++++
relay/channel/cohere/dto.go | 2 +-
relay/common/relay_info.go | 3 +-
relay/constant/api_type.go | 3 +
relay/relay_adaptor.go | 3 +
service/{sse.go => relay.go} | 5 +
web/src/constants/channel.constants.js | 1 +
web/src/pages/Channel/EditChannel.js | 18 +++
15 files changed, 296 insertions(+), 28 deletions(-)
create mode 100644 relay/channel/cloudflare/adaptor.go
create mode 100644 relay/channel/cloudflare/constant.go
create mode 100644 relay/channel/cloudflare/model.go
create mode 100644 relay/channel/cloudflare/relay_cloudflare.go
rename service/{sse.go => relay.go} (87%)
diff --git a/common/constants.go b/common/constants.go
index 66cc10d..97e8583 100644
--- a/common/constants.go
+++ b/common/constants.go
@@ -212,6 +212,7 @@ const (
ChannelTypeSunoAPI = 36
ChannelTypeDify = 37
ChannelTypeJina = 38
+ ChannelCloudflare = 39
ChannelTypeDummy // this one is only for count, do not add any channel after this
@@ -257,4 +258,5 @@ var ChannelBaseURLs = []string{
"", //36
"", //37
"https://api.jina.ai", //38
+ "https://api.cloudflare.com", //39
}
diff --git a/controller/channel-test.go b/controller/channel-test.go
index 000d7f2..268dac2 100644
--- a/controller/channel-test.go
+++ b/controller/channel-test.go
@@ -12,6 +12,7 @@ import (
"net/url"
"one-api/common"
"one-api/dto"
+ "one-api/middleware"
"one-api/model"
"one-api/relay"
relaycommon "one-api/relay/common"
@@ -40,29 +41,7 @@ func testChannel(channel *model.Channel, testModel string) (err error, openaiErr
Body: nil,
Header: make(http.Header),
}
- c.Request.Header.Set("Authorization", "Bearer "+channel.Key)
- c.Request.Header.Set("Content-Type", "application/json")
- c.Set("channel", channel.Type)
- c.Set("base_url", channel.GetBaseURL())
- switch channel.Type {
- case common.ChannelTypeAzure:
- c.Set("api_version", channel.Other)
- case common.ChannelTypeXunfei:
- c.Set("api_version", channel.Other)
- //case common.ChannelTypeAIProxyLibrary:
- // c.Set("library_id", channel.Other)
- case common.ChannelTypeGemini:
- c.Set("api_version", channel.Other)
- case common.ChannelTypeAli:
- c.Set("plugin", channel.Other)
- }
- meta := relaycommon.GenRelayInfo(c)
- apiType, _ := constant.ChannelType2APIType(channel.Type)
- adaptor := relay.GetAdaptor(apiType)
- if adaptor == nil {
- return fmt.Errorf("invalid api type: %d, adaptor is nil", apiType), nil
- }
if testModel == "" {
if channel.TestModel != nil && *channel.TestModel != "" {
testModel = *channel.TestModel
@@ -88,6 +67,20 @@ func testChannel(channel *model.Channel, testModel string) (err error, openaiErr
}
}
+ c.Request.Header.Set("Authorization", "Bearer "+channel.Key)
+ c.Request.Header.Set("Content-Type", "application/json")
+ c.Set("channel", channel.Type)
+ c.Set("base_url", channel.GetBaseURL())
+
+ middleware.SetupContextForSelectedChannel(c, channel, testModel)
+
+ meta := relaycommon.GenRelayInfo(c)
+ apiType, _ := constant.ChannelType2APIType(channel.Type)
+ adaptor := relay.GetAdaptor(apiType)
+ if adaptor == nil {
+ return fmt.Errorf("invalid api type: %d, adaptor is nil", apiType), nil
+ }
+
request := buildTestRequest()
request.Model = testModel
meta.UpstreamModelName = testModel
diff --git a/dto/text_request.go b/dto/text_request.go
index e12c9b4..ed36988 100644
--- a/dto/text_request.go
+++ b/dto/text_request.go
@@ -48,8 +48,8 @@ type StreamOptions struct {
IncludeUsage bool `json:"include_usage,omitempty"`
}
-func (r GeneralOpenAIRequest) GetMaxTokens() int64 {
- return int64(r.MaxTokens)
+func (r GeneralOpenAIRequest) GetMaxTokens() int {
+ return int(r.MaxTokens)
}
func (r GeneralOpenAIRequest) ParseInput() []string {
diff --git a/middleware/distributor.go b/middleware/distributor.go
index 61361e6..9f75207 100644
--- a/middleware/distributor.go
+++ b/middleware/distributor.go
@@ -198,11 +198,11 @@ func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, mode
c.Set("api_version", channel.Other)
case common.ChannelTypeXunfei:
c.Set("api_version", channel.Other)
- //case common.ChannelTypeAIProxyLibrary:
- // c.Set("library_id", channel.Other)
case common.ChannelTypeGemini:
c.Set("api_version", channel.Other)
case common.ChannelTypeAli:
c.Set("plugin", channel.Other)
+ case common.ChannelCloudflare:
+ c.Set("api_version", channel.Other)
}
}
diff --git a/relay/channel/cloudflare/adaptor.go b/relay/channel/cloudflare/adaptor.go
new file mode 100644
index 0000000..571f50c
--- /dev/null
+++ b/relay/channel/cloudflare/adaptor.go
@@ -0,0 +1,76 @@
+package cloudflare
+
+import (
+ "errors"
+ "fmt"
+ "github.com/gin-gonic/gin"
+ "io"
+ "net/http"
+ "one-api/dto"
+ "one-api/relay/channel"
+ relaycommon "one-api/relay/common"
+ "one-api/relay/constant"
+)
+
+type Adaptor struct {
+}
+
+func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) {
+}
+
+func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
+}
+
+func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
+ switch info.RelayMode {
+ case constant.RelayModeChatCompletions:
+ return fmt.Sprintf("%s/client/v4/accounts/%s/ai/v1/chat/completions", info.BaseUrl, info.ApiVersion), nil
+ case constant.RelayModeEmbeddings:
+ return fmt.Sprintf("%s/client/v4/accounts/%s/ai/v1/embeddings", info.BaseUrl, info.ApiVersion), nil
+ default:
+ return fmt.Sprintf("%s/client/v4/accounts/%s/ai/run/%s", info.BaseUrl, info.ApiVersion, info.UpstreamModelName), nil
+ }
+}
+
+func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
+ channel.SetupApiRequestHeader(info, c, req)
+ req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", info.ApiKey))
+ return nil
+}
+
+func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) {
+ if request == nil {
+ return nil, errors.New("request is nil")
+ }
+ switch relayMode {
+ case constant.RelayModeCompletions:
+ return convertCf2CompletionsRequest(*request), nil
+ default:
+ return request, nil
+ }
+}
+
+func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
+ return channel.DoApiRequest(a, c, info, requestBody)
+}
+
+func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
+ return request, nil
+}
+
+func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
+ if info.IsStream {
+ err, usage = cfStreamHandler(c, resp, info)
+ } else {
+ err, usage = cfHandler(c, resp, info)
+ }
+ return
+}
+
+func (a *Adaptor) GetModelList() []string {
+ return ModelList
+}
+
+func (a *Adaptor) GetChannelName() string {
+ return ChannelName
+}
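
A standalone sketch of the three URL shapes GetRequestURL above builds, useful for checking an Account ID configuration. The relayInfo struct and mode constants below are reduced stand-ins for the project's types; note that the Cloudflare Account ID is carried in ApiVersion, i.e. the channel's "other" field:

package main

import "fmt"

// Reduced stand-ins; the real RelayInfo and relay mode constants live in relay/common and relay/constant.
const (
	relayModeChatCompletions = iota
	relayModeEmbeddings
	relayModeOther
)

type relayInfo struct {
	BaseUrl           string
	ApiVersion        string // Cloudflare Account ID, taken from the channel's "other" field
	UpstreamModelName string
	RelayMode         int
}

func requestURL(info relayInfo) string {
	switch info.RelayMode {
	case relayModeChatCompletions:
		return fmt.Sprintf("%s/client/v4/accounts/%s/ai/v1/chat/completions", info.BaseUrl, info.ApiVersion)
	case relayModeEmbeddings:
		return fmt.Sprintf("%s/client/v4/accounts/%s/ai/v1/embeddings", info.BaseUrl, info.ApiVersion)
	default:
		return fmt.Sprintf("%s/client/v4/accounts/%s/ai/run/%s", info.BaseUrl, info.ApiVersion, info.UpstreamModelName)
	}
}

func main() {
	info := relayInfo{
		BaseUrl:           "https://api.cloudflare.com",
		ApiVersion:        "your-account-id",
		UpstreamModelName: "@cf/meta/llama-3-8b-instruct",
		RelayMode:         relayModeChatCompletions,
	}
	// https://api.cloudflare.com/client/v4/accounts/your-account-id/ai/v1/chat/completions
	fmt.Println(requestURL(info))
}
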
diff --git a/relay/channel/cloudflare/constant.go b/relay/channel/cloudflare/constant.go
new file mode 100644
index 0000000..a874685
--- /dev/null
+++ b/relay/channel/cloudflare/constant.go
@@ -0,0 +1,38 @@
+package cloudflare
+
+var ModelList = []string{
+ "@cf/meta/llama-2-7b-chat-fp16",
+ "@cf/meta/llama-2-7b-chat-int8",
+ "@cf/mistral/mistral-7b-instruct-v0.1",
+ "@hf/thebloke/deepseek-coder-6.7b-base-awq",
+ "@hf/thebloke/deepseek-coder-6.7b-instruct-awq",
+ "@cf/deepseek-ai/deepseek-math-7b-base",
+ "@cf/deepseek-ai/deepseek-math-7b-instruct",
+ "@cf/thebloke/discolm-german-7b-v1-awq",
+ "@cf/tiiuae/falcon-7b-instruct",
+ "@cf/google/gemma-2b-it-lora",
+ "@hf/google/gemma-7b-it",
+ "@cf/google/gemma-7b-it-lora",
+ "@hf/nousresearch/hermes-2-pro-mistral-7b",
+ "@hf/thebloke/llama-2-13b-chat-awq",
+ "@cf/meta-llama/llama-2-7b-chat-hf-lora",
+ "@cf/meta/llama-3-8b-instruct",
+ "@hf/thebloke/llamaguard-7b-awq",
+ "@hf/thebloke/mistral-7b-instruct-v0.1-awq",
+ "@hf/mistralai/mistral-7b-instruct-v0.2",
+ "@cf/mistral/mistral-7b-instruct-v0.2-lora",
+ "@hf/thebloke/neural-chat-7b-v3-1-awq",
+ "@cf/openchat/openchat-3.5-0106",
+ "@hf/thebloke/openhermes-2.5-mistral-7b-awq",
+ "@cf/microsoft/phi-2",
+ "@cf/qwen/qwen1.5-0.5b-chat",
+ "@cf/qwen/qwen1.5-1.8b-chat",
+ "@cf/qwen/qwen1.5-14b-chat-awq",
+ "@cf/qwen/qwen1.5-7b-chat-awq",
+ "@cf/defog/sqlcoder-7b-2",
+ "@hf/nexusflow/starling-lm-7b-beta",
+ "@cf/tinyllama/tinyllama-1.1b-chat-v1.0",
+ "@hf/thebloke/zephyr-7b-beta-awq",
+}
+
+var ChannelName = "cloudflare"
diff --git a/relay/channel/cloudflare/model.go b/relay/channel/cloudflare/model.go
new file mode 100644
index 0000000..c870813
--- /dev/null
+++ b/relay/channel/cloudflare/model.go
@@ -0,0 +1,13 @@
+package cloudflare
+
+import "one-api/dto"
+
+type CfRequest struct {
+ Messages []dto.Message `json:"messages,omitempty"`
+ Lora string `json:"lora,omitempty"`
+ MaxTokens int `json:"max_tokens,omitempty"`
+ Prompt string `json:"prompt,omitempty"`
+ Raw bool `json:"raw,omitempty"`
+ Stream bool `json:"stream,omitempty"`
+ Temperature float64 `json:"temperature,omitempty"`
+}
diff --git a/relay/channel/cloudflare/relay_cloudflare.go b/relay/channel/cloudflare/relay_cloudflare.go
new file mode 100644
index 0000000..94a7ea0
--- /dev/null
+++ b/relay/channel/cloudflare/relay_cloudflare.go
@@ -0,0 +1,115 @@
+package cloudflare
+
+import (
+ "bufio"
+ "encoding/json"
+ "github.com/gin-gonic/gin"
+ "io"
+ "net/http"
+ "one-api/common"
+ "one-api/dto"
+ relaycommon "one-api/relay/common"
+ "one-api/service"
+ "strings"
+)
+
+func convertCf2CompletionsRequest(textRequest dto.GeneralOpenAIRequest) *CfRequest {
+ p, _ := textRequest.Prompt.(string)
+ return &CfRequest{
+ Prompt: p,
+ MaxTokens: textRequest.GetMaxTokens(),
+ Stream: textRequest.Stream,
+ Temperature: textRequest.Temperature,
+ }
+}
+
+func cfStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
+ scanner := bufio.NewScanner(resp.Body)
+ scanner.Split(bufio.ScanLines)
+
+ service.SetEventStreamHeaders(c)
+ id := service.GetResponseID(c)
+ var responseText string
+
+ for scanner.Scan() {
+ data := scanner.Text()
+ if len(data) < len("data: ") {
+ continue
+ }
+ data = strings.TrimPrefix(data, "data: ")
+ data = strings.TrimSuffix(data, "\r")
+
+ if data == "[DONE]" {
+ break
+ }
+
+ var response dto.ChatCompletionsStreamResponse
+ err := json.Unmarshal([]byte(data), &response)
+ if err != nil {
+ common.LogError(c, "error_unmarshalling_stream_response: "+err.Error())
+ continue
+ }
+ for _, choice := range response.Choices {
+ choice.Delta.Role = "assistant"
+ responseText += choice.Delta.GetContentString()
+ }
+ response.Id = id
+ response.Model = info.UpstreamModelName
+ err = service.ObjectData(c, response)
+ if err != nil {
+ common.LogError(c, "error_rendering_stream_response: "+err.Error())
+ }
+ }
+
+ if err := scanner.Err(); err != nil {
+ common.LogError(c, "error_scanning_stream_response: "+err.Error())
+ }
+ usage, _ := service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
+ if info.ShouldIncludeUsage {
+ response := service.GenerateFinalUsageResponse(id, info.StartTime.Unix(), info.UpstreamModelName, *usage)
+ err := service.ObjectData(c, response)
+ if err != nil {
+ common.LogError(c, "error_rendering_final_usage_response: "+err.Error())
+ }
+ }
+ service.Done(c)
+
+ err := resp.Body.Close()
+ if err != nil {
+ common.LogError(c, "close_response_body_failed: "+err.Error())
+ }
+
+ return nil, usage
+}
+
+func cfHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
+ responseBody, err := io.ReadAll(resp.Body)
+ if err != nil {
+ return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
+ }
+ err = resp.Body.Close()
+ if err != nil {
+ return service.OpenAIErrorWrapperLocal(err, "close_response_body_failed", http.StatusInternalServerError), nil
+ }
+ var response dto.TextResponse
+ err = json.Unmarshal(responseBody, &response)
+ if err != nil {
+ return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
+ }
+ response.Model = info.UpstreamModelName
+ var responseText string
+ for _, choice := range response.Choices {
+ responseText += choice.Message.StringContent()
+ }
+ usage, _ := service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
+ response.Usage = *usage
+ response.Id = service.GetResponseID(c)
+ jsonResponse, err := json.Marshal(response)
+ if err != nil {
+ return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
+ }
+ c.Writer.Header().Set("Content-Type", "application/json")
+ c.Writer.WriteHeader(resp.StatusCode)
+ _, _ = c.Writer.Write(jsonResponse)
+ return nil, usage
+}
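
cfStreamHandler above follows the usual SSE line protocol: read line by line, strip the "data: " prefix and an optional trailing "\r", stop at "[DONE]", and unmarshal everything else while skipping malformed events. A reduced, self-contained sketch of that loop, reading from an in-memory body rather than an HTTP response:

package main

import (
	"bufio"
	"encoding/json"
	"fmt"
	"strings"
)

// Minimal stand-in for the delta part of a chat completions stream chunk.
type streamChunk struct {
	Choices []struct {
		Delta struct {
			Content string `json:"content"`
		} `json:"delta"`
	} `json:"choices"`
}

func main() {
	body := "data: {\"choices\":[{\"delta\":{\"content\":\"Hel\"}}]}\n" +
		"data: {\"choices\":[{\"delta\":{\"content\":\"lo\"}}]}\r\n" +
		"data: [DONE]\n"

	scanner := bufio.NewScanner(strings.NewReader(body))
	var text strings.Builder
	for scanner.Scan() {
		line := scanner.Text()
		if len(line) < len("data: ") {
			continue // keep-alive or blank line
		}
		line = strings.TrimPrefix(line, "data: ")
		line = strings.TrimSuffix(line, "\r") // some upstreams terminate lines with CRLF
		if line == "[DONE]" {
			break
		}
		var chunk streamChunk
		if err := json.Unmarshal([]byte(line), &chunk); err != nil {
			continue // skip malformed events, as the handler does
		}
		for _, choice := range chunk.Choices {
			text.WriteString(choice.Delta.Content)
		}
	}
	fmt.Println(text.String()) // Hello
}
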
diff --git a/relay/channel/cohere/dto.go b/relay/channel/cohere/dto.go
index fc6c445..b2c2739 100644
--- a/relay/channel/cohere/dto.go
+++ b/relay/channel/cohere/dto.go
@@ -7,7 +7,7 @@ type CohereRequest struct {
ChatHistory []ChatHistory `json:"chat_history"`
Message string `json:"message"`
Stream bool `json:"stream"`
- MaxTokens int64 `json:"max_tokens"`
+ MaxTokens int `json:"max_tokens"`
}
type ChatHistory struct {
diff --git a/relay/common/relay_info.go b/relay/common/relay_info.go
index 42c8381..e07434a 100644
--- a/relay/common/relay_info.go
+++ b/relay/common/relay_info.go
@@ -68,7 +68,8 @@ func GenRelayInfo(c *gin.Context) *RelayInfo {
info.ApiVersion = GetAPIVersion(c)
}
if info.ChannelType == common.ChannelTypeOpenAI || info.ChannelType == common.ChannelTypeAnthropic ||
- info.ChannelType == common.ChannelTypeAws || info.ChannelType == common.ChannelTypeGemini {
+ info.ChannelType == common.ChannelTypeAws || info.ChannelType == common.ChannelTypeGemini ||
+ info.ChannelType == common.ChannelCloudflare {
info.SupportStreamOptions = true
}
return info
diff --git a/relay/constant/api_type.go b/relay/constant/api_type.go
index 0ce2657..6bd93c4 100644
--- a/relay/constant/api_type.go
+++ b/relay/constant/api_type.go
@@ -22,6 +22,7 @@ const (
APITypeCohere
APITypeDify
APITypeJina
+ APITypeCloudflare
APITypeDummy // this one is only for count, do not add any channel after this
)
@@ -63,6 +64,8 @@ func ChannelType2APIType(channelType int) (int, bool) {
apiType = APITypeDify
case common.ChannelTypeJina:
apiType = APITypeJina
+ case common.ChannelCloudflare:
+ apiType = APITypeCloudflare
}
if apiType == -1 {
return APITypeOpenAI, false
diff --git a/relay/relay_adaptor.go b/relay/relay_adaptor.go
index 8998540..4c0aef1 100644
--- a/relay/relay_adaptor.go
+++ b/relay/relay_adaptor.go
@@ -7,6 +7,7 @@ import (
"one-api/relay/channel/aws"
"one-api/relay/channel/baidu"
"one-api/relay/channel/claude"
+ "one-api/relay/channel/cloudflare"
"one-api/relay/channel/cohere"
"one-api/relay/channel/dify"
"one-api/relay/channel/gemini"
@@ -59,6 +60,8 @@ func GetAdaptor(apiType int) channel.Adaptor {
return &dify.Adaptor{}
case constant.APITypeJina:
return &jina.Adaptor{}
+ case constant.APITypeCloudflare:
+ return &cloudflare.Adaptor{}
}
return nil
}
diff --git a/service/sse.go b/service/relay.go
similarity index 87%
rename from service/sse.go
rename to service/relay.go
index 2d531a4..22f9ce3 100644
--- a/service/sse.go
+++ b/service/relay.go
@@ -35,3 +35,8 @@ func ObjectData(c *gin.Context, object interface{}) error {
func Done(c *gin.Context) {
StringData(c, "[DONE]")
}
+
+func GetResponseID(c *gin.Context) string {
+ logID := c.GetString("X-Oneapi-Request-Id")
+ return fmt.Sprintf("chatcmpl-%s", logID)
+}
diff --git a/web/src/constants/channel.constants.js b/web/src/constants/channel.constants.js
index ff1d281..88614b0 100644
--- a/web/src/constants/channel.constants.js
+++ b/web/src/constants/channel.constants.js
@@ -99,6 +99,7 @@ export const CHANNEL_OPTIONS = [
color: 'orange',
label: 'Google PaLM2',
},
+ { key: 39, text: 'Cloudflare', value: 39, color: 'grey', label: 'Cloudflare' },
{ key: 25, text: 'Moonshot', value: 25, color: 'green', label: 'Moonshot' },
{ key: 19, text: '360 智脑', value: 19, color: 'blue', label: '360 智脑' },
{ key: 23, text: '腾讯混元', value: 23, color: 'teal', label: '腾讯混元' },
diff --git a/web/src/pages/Channel/EditChannel.js b/web/src/pages/Channel/EditChannel.js
index aec3768..900fdf3 100644
--- a/web/src/pages/Channel/EditChannel.js
+++ b/web/src/pages/Channel/EditChannel.js
@@ -605,6 +605,24 @@ const EditChannel = (props) => {
/>
</>
)}
+ {inputs.type === 39 && (
+ <>
+ <div style={{ marginTop: 10 }}>
+ <Typography.Text strong>Account ID:</Typography.Text>
+ </div>
+ <Input
+ onChange={(value) => {
+ handleInputChange('other', value);
+ }}
+ value={inputs.other}
+ autoComplete='new-password'
+ />
+ </>
+ )}
模型:
From e67aa370bc14c8c4b89419ce11630533f27f0b25 Mon Sep 17 00:00:00 2001
From: FENG
Date: Sun, 14 Jul 2024 00:14:07 +0800
Subject: [PATCH 03/34] fix: channel timeout auto-ban and auto-enable
---
controller/channel-test.go | 32 +++++++++++++++++++-------------
1 file changed, 19 insertions(+), 13 deletions(-)
diff --git a/controller/channel-test.go b/controller/channel-test.go
index 268dac2..6f82cd7 100644
--- a/controller/channel-test.go
+++ b/controller/channel-test.go
@@ -231,27 +231,33 @@ func testAllChannels(notify bool) error {
err = errors.New(fmt.Sprintf("响应时间 %.2fs 超过阈值 %.2fs", float64(milliseconds)/1000.0, float64(disableThreshold)/1000.0))
ban = true
}
+
+ // request error disables the channel
if openaiErr != nil {
err = errors.New(fmt.Sprintf("type %s, code %v, message %s", openaiErr.Type, openaiErr.Code, openaiErr.Message))
- ban = true
- }
- // parse *int to bool
- if channel.AutoBan != nil && *channel.AutoBan == 0 {
- ban = false
- }
- if openaiErr != nil {
openAiErrWithStatus := dto.OpenAIErrorWithStatusCode{
StatusCode: -1,
Error: *openaiErr,
LocalError: false,
}
- if isChannelEnabled && service.ShouldDisableChannel(channel.Type, &openAiErrWithStatus) && ban {
- service.DisableChannel(channel.Id, channel.Name, err.Error())
- }
- if !isChannelEnabled && service.ShouldEnableChannel(err, openaiErr, channel.Status) {
- service.EnableChannel(channel.Id, channel.Name)
- }
+ ban = service.ShouldDisableChannel(channel.Type, &openAiErrWithStatus)
}
+
+ // parse *int to bool
+ if channel.AutoBan != nil && *channel.AutoBan == 0 {
+ ban = false
+ }
+
+ // disable channel
+ if ban && isChannelEnabled {
+ service.DisableChannel(channel.Id, channel.Name, err.Error())
+ }
+
+ // enable channel
+ if !isChannelEnabled && service.ShouldEnableChannel(err, openaiErr, channel.Status) {
+ service.EnableChannel(channel.Id, channel.Name)
+ }
+
channel.UpdateResponseTime(milliseconds)
time.Sleep(common.RequestInterval)
}
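
The decision flow after this patch can be summarized as a small pure function: an error or a response time above the threshold proposes a ban, the proposal is dropped when the channel's auto-ban flag is off, and disable/enable actions only fire on a state change. A sketch of that flow with illustrative types; the real checks live in service.ShouldDisableChannel and service.ShouldEnableChannel:

package main

import "fmt"

type testResult struct {
	err              error // non-nil when the test request failed
	overThreshold    bool  // response time exceeded the configured threshold
	shouldDisable    bool  // what service.ShouldDisableChannel would report for this error
	shouldEnable     bool  // what service.ShouldEnableChannel would report for a clean run
	autoBanDisabled  bool  // channel.AutoBan != nil && *channel.AutoBan == 0
	currentlyEnabled bool
}

// decide returns whether to disable or enable the channel, mirroring the
// restructured logic in testAllChannels.
func decide(r testResult) (disable, enable bool) {
	ban := r.overThreshold
	if r.err != nil {
		ban = r.shouldDisable
	}
	if r.autoBanDisabled {
		ban = false // operator opted the channel out of auto-ban
	}
	disable = ban && r.currentlyEnabled
	enable = !r.currentlyEnabled && r.shouldEnable
	return
}

func main() {
	disable, enable := decide(testResult{overThreshold: true, currentlyEnabled: true})
	fmt.Println(disable, enable) // true false
}
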
From d55cb35c1c45b3b7f89239f6ea43e0ef39088a56 Mon Sep 17 00:00:00 2001
From: FENG
Date: Sun, 14 Jul 2024 01:21:05 +0800
Subject: [PATCH 04/34] fix: http status code not properly considered when auto-disabling channels
---
controller/channel-test.go | 25 ++++++++++---------------
service/channel.go | 4 ++--
2 files changed, 12 insertions(+), 17 deletions(-)
diff --git a/controller/channel-test.go b/controller/channel-test.go
index 6f82cd7..2174ff1 100644
--- a/controller/channel-test.go
+++ b/controller/channel-test.go
@@ -25,7 +25,7 @@ import (
"github.com/gin-gonic/gin"
)
-func testChannel(channel *model.Channel, testModel string) (err error, openaiErr *dto.OpenAIError) {
+func testChannel(channel *model.Channel, testModel string) (err error, openAIErrorWithStatusCode *dto.OpenAIErrorWithStatusCode) {
tik := time.Now()
if channel.Type == common.ChannelTypeMidjourney {
return errors.New("midjourney channel test is not supported"), nil
@@ -58,8 +58,7 @@ func testChannel(channel *model.Channel, testModel string) (err error, openaiErr
modelMap := make(map[string]string)
err := json.Unmarshal([]byte(modelMapping), &modelMap)
if err != nil {
- openaiErr := service.OpenAIErrorWrapperLocal(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError).Error
- return err, &openaiErr
+ return err, service.OpenAIErrorWrapperLocal(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError)
}
if modelMap[testModel] != "" {
testModel = modelMap[testModel]
@@ -104,11 +103,11 @@ func testChannel(channel *model.Channel, testModel string) (err error, openaiErr
}
if resp != nil && resp.StatusCode != http.StatusOK {
err := relaycommon.RelayErrorHandler(resp)
- return fmt.Errorf("status code %d: %s", resp.StatusCode, err.Error.Message), &err.Error
+ return fmt.Errorf("status code %d: %s", resp.StatusCode, err.Error.Message), err
}
usage, respErr := adaptor.DoResponse(c, resp, meta)
if respErr != nil {
- return fmt.Errorf("%s", respErr.Error.Message), &respErr.Error
+ return fmt.Errorf("%s", respErr.Error.Message), respErr
}
if usage == nil {
return errors.New("usage is nil"), nil
@@ -222,7 +221,7 @@ func testAllChannels(notify bool) error {
for _, channel := range channels {
isChannelEnabled := channel.Status == common.ChannelStatusEnabled
tik := time.Now()
- err, openaiErr := testChannel(channel, "")
+ err, openaiWithStatusErr := testChannel(channel, "")
tok := time.Now()
milliseconds := tok.Sub(tik).Milliseconds()
@@ -233,14 +232,10 @@ func testAllChannels(notify bool) error {
}
// request error disables the channel
- if openaiErr != nil {
- err = errors.New(fmt.Sprintf("type %s, code %v, message %s", openaiErr.Type, openaiErr.Code, openaiErr.Message))
- openAiErrWithStatus := dto.OpenAIErrorWithStatusCode{
- StatusCode: -1,
- Error: *openaiErr,
- LocalError: false,
- }
- ban = service.ShouldDisableChannel(channel.Type, &openAiErrWithStatus)
+ if openaiWithStatusErr != nil {
+ oaiErr := openaiWithStatusErr.Error
+ err = errors.New(fmt.Sprintf("type %s, httpCode %d, code %v, message %s", oaiErr.Type, openaiWithStatusErr.StatusCode, oaiErr.Code, oaiErr.Message))
+ ban = service.ShouldDisableChannel(channel.Type, openaiWithStatusErr)
}
// parse *int to bool
@@ -254,7 +249,7 @@ func testAllChannels(notify bool) error {
}
// enable channel
- if !isChannelEnabled && service.ShouldEnableChannel(err, openaiErr, channel.Status) {
+ if !isChannelEnabled && service.ShouldEnableChannel(err, openaiWithStatusErr, channel.Status) {
service.EnableChannel(channel.Id, channel.Name)
}
diff --git a/service/channel.go b/service/channel.go
index 76be271..5716a6d 100644
--- a/service/channel.go
+++ b/service/channel.go
@@ -74,14 +74,14 @@ func ShouldDisableChannel(channelType int, err *relaymodel.OpenAIErrorWithStatus
return false
}
-func ShouldEnableChannel(err error, openAIErr *relaymodel.OpenAIError, status int) bool {
+func ShouldEnableChannel(err error, openaiWithStatusErr *relaymodel.OpenAIErrorWithStatusCode, status int) bool {
if !common.AutomaticEnableChannelEnabled {
return false
}
if err != nil {
return false
}
- if openAIErr != nil {
+ if openaiWithStatusErr != nil {
return false
}
if status != common.ChannelStatusAutoDisabled {
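
The reason for threading the full *OpenAIErrorWithStatusCode through testChannel is that the disable policy needs the HTTP status, not just the error body. A sketch of a status-aware check under an assumed policy; the project's actual rules live in service/channel.go and may differ:

package main

import (
	"fmt"
	"net/http"
)

// Reduced stand-in for dto.OpenAIErrorWithStatusCode.
type openAIErrorWithStatusCode struct {
	StatusCode int
	Message    string
	LocalError bool
}

// shouldDisable is an illustrative policy: local errors never disable a channel,
// while authentication and quota failures do. The real policy in
// service.ShouldDisableChannel is more nuanced.
func shouldDisable(e *openAIErrorWithStatusCode) bool {
	if e == nil || e.LocalError {
		return false
	}
	switch e.StatusCode {
	case http.StatusUnauthorized, http.StatusForbidden, http.StatusPaymentRequired:
		return true
	default:
		return false
	}
}

func main() {
	err := &openAIErrorWithStatusCode{StatusCode: http.StatusUnauthorized, Message: "invalid api key"}
	fmt.Println(shouldDisable(err)) // true
}
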
From 0f687aab9a39687d80b40aa61cd116458028582e Mon Sep 17 00:00:00 2001
From: CalciumIon <1808837298@qq.com>
Date: Mon, 15 Jul 2024 16:05:30 +0800
Subject: [PATCH 05/34] fix: azure stream options
---
controller/channel-test.go | 2 +-
relay/channel/adapter.go | 2 +-
relay/channel/ali/adaptor.go | 4 ++--
relay/channel/aws/adaptor.go | 2 +-
relay/channel/baidu/adaptor.go | 4 ++--
relay/channel/claude/adaptor.go | 2 +-
relay/channel/cloudflare/adaptor.go | 4 ++--
relay/channel/cloudflare/relay_cloudflare.go | 6 ++++++
relay/channel/cohere/adaptor.go | 2 +-
relay/channel/dify/adaptor.go | 2 +-
relay/channel/gemini/adaptor.go | 2 +-
relay/channel/jina/adaptor.go | 2 +-
relay/channel/ollama/adaptor.go | 4 ++--
relay/channel/openai/adaptor.go | 5 ++++-
relay/channel/palm/adaptor.go | 2 +-
relay/channel/perplexity/adaptor.go | 2 +-
relay/channel/tencent/adaptor.go | 2 +-
relay/channel/xunfei/adaptor.go | 2 +-
relay/channel/zhipu/adaptor.go | 2 +-
relay/channel/zhipu_4v/adaptor.go | 2 +-
relay/relay-text.go | 2 +-
21 files changed, 33 insertions(+), 24 deletions(-)
diff --git a/controller/channel-test.go b/controller/channel-test.go
index 2174ff1..4ad7457 100644
--- a/controller/channel-test.go
+++ b/controller/channel-test.go
@@ -87,7 +87,7 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr
adaptor.Init(meta, *request)
- convertedRequest, err := adaptor.ConvertRequest(c, constant.RelayModeChatCompletions, request)
+ convertedRequest, err := adaptor.ConvertRequest(c, meta, request)
if err != nil {
return err, nil
}
diff --git a/relay/channel/adapter.go b/relay/channel/adapter.go
index e222a70..7064b88 100644
--- a/relay/channel/adapter.go
+++ b/relay/channel/adapter.go
@@ -14,7 +14,7 @@ type Adaptor interface {
InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest)
GetRequestURL(info *relaycommon.RelayInfo) (string, error)
SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error
- ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error)
+ ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error)
ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error)
DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error)
DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode)
diff --git a/relay/channel/ali/adaptor.go b/relay/channel/ali/adaptor.go
index fbaf546..e03d29f 100644
--- a/relay/channel/ali/adaptor.go
+++ b/relay/channel/ali/adaptor.go
@@ -42,11 +42,11 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *re
return nil
}
-func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) {
+func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
if request == nil {
return nil, errors.New("request is nil")
}
- switch relayMode {
+ switch info.RelayMode {
case constant.RelayModeEmbeddings:
baiduEmbeddingRequest := embeddingRequestOpenAI2Ali(*request)
return baiduEmbeddingRequest, nil
diff --git a/relay/channel/aws/adaptor.go b/relay/channel/aws/adaptor.go
index 6452392..8214777 100644
--- a/relay/channel/aws/adaptor.go
+++ b/relay/channel/aws/adaptor.go
@@ -41,7 +41,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *re
return nil
}
-func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) {
+func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
if request == nil {
return nil, errors.New("request is nil")
}
diff --git a/relay/channel/baidu/adaptor.go b/relay/channel/baidu/adaptor.go
index 17f5384..40a0696 100644
--- a/relay/channel/baidu/adaptor.go
+++ b/relay/channel/baidu/adaptor.go
@@ -99,11 +99,11 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *re
return nil
}
-func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) {
+func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
if request == nil {
return nil, errors.New("request is nil")
}
- switch relayMode {
+ switch info.RelayMode {
case constant.RelayModeEmbeddings:
baiduEmbeddingRequest := embeddingRequestOpenAI2Baidu(*request)
return baiduEmbeddingRequest, nil
diff --git a/relay/channel/claude/adaptor.go b/relay/channel/claude/adaptor.go
index 4623318..8e4c75d 100644
--- a/relay/channel/claude/adaptor.go
+++ b/relay/channel/claude/adaptor.go
@@ -53,7 +53,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *re
return nil
}
-func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) {
+func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
if request == nil {
return nil, errors.New("request is nil")
}
diff --git a/relay/channel/cloudflare/adaptor.go b/relay/channel/cloudflare/adaptor.go
index 571f50c..53b5a91 100644
--- a/relay/channel/cloudflare/adaptor.go
+++ b/relay/channel/cloudflare/adaptor.go
@@ -38,11 +38,11 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *re
return nil
}
-func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) {
+func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
if request == nil {
return nil, errors.New("request is nil")
}
- switch relayMode {
+ switch info.RelayMode {
case constant.RelayModeCompletions:
return convertCf2CompletionsRequest(*request), nil
default:
diff --git a/relay/channel/cloudflare/relay_cloudflare.go b/relay/channel/cloudflare/relay_cloudflare.go
index 94a7ea0..d9319ef 100644
--- a/relay/channel/cloudflare/relay_cloudflare.go
+++ b/relay/channel/cloudflare/relay_cloudflare.go
@@ -11,6 +11,7 @@ import (
relaycommon "one-api/relay/common"
"one-api/service"
"strings"
+ "time"
)
func convertCf2CompletionsRequest(textRequest dto.GeneralOpenAIRequest) *CfRequest {
@@ -30,6 +31,7 @@ func cfStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rela
service.SetEventStreamHeaders(c)
id := service.GetResponseID(c)
var responseText string
+ isFirst := true
for scanner.Scan() {
data := scanner.Text()
@@ -56,6 +58,10 @@ func cfStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rela
response.Id = id
response.Model = info.UpstreamModelName
err = service.ObjectData(c, response)
+ if isFirst {
+ isFirst = false
+ info.FirstResponseTime = time.Now()
+ }
if err != nil {
common.LogError(c, "error_rendering_stream_response: "+err.Error())
}
diff --git a/relay/channel/cohere/adaptor.go b/relay/channel/cohere/adaptor.go
index b5f3521..84243aa 100644
--- a/relay/channel/cohere/adaptor.go
+++ b/relay/channel/cohere/adaptor.go
@@ -34,7 +34,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *re
return nil
}
-func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) {
+func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
return requestOpenAI2Cohere(*request), nil
}
diff --git a/relay/channel/dify/adaptor.go b/relay/channel/dify/adaptor.go
index a54b95b..8dbe8b8 100644
--- a/relay/channel/dify/adaptor.go
+++ b/relay/channel/dify/adaptor.go
@@ -32,7 +32,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *re
return nil
}
-func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) {
+func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
if request == nil {
return nil, errors.New("request is nil")
}
diff --git a/relay/channel/gemini/adaptor.go b/relay/channel/gemini/adaptor.go
index 9755163..f223fbf 100644
--- a/relay/channel/gemini/adaptor.go
+++ b/relay/channel/gemini/adaptor.go
@@ -51,7 +51,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *re
return nil
}
-func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) {
+func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
if request == nil {
return nil, errors.New("request is nil")
}
diff --git a/relay/channel/jina/adaptor.go b/relay/channel/jina/adaptor.go
index 48616b6..d0a379a 100644
--- a/relay/channel/jina/adaptor.go
+++ b/relay/channel/jina/adaptor.go
@@ -36,7 +36,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *re
return nil
}
-func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) {
+func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
return request, nil
}
diff --git a/relay/channel/ollama/adaptor.go b/relay/channel/ollama/adaptor.go
index 76de148..b0550ca 100644
--- a/relay/channel/ollama/adaptor.go
+++ b/relay/channel/ollama/adaptor.go
@@ -36,11 +36,11 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *re
return nil
}
-func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) {
+func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
if request == nil {
return nil, errors.New("request is nil")
}
- switch relayMode {
+ switch info.RelayMode {
case relayconstant.RelayModeEmbeddings:
return requestOpenAI2Embeddings(*request), nil
default:
diff --git a/relay/channel/openai/adaptor.go b/relay/channel/openai/adaptor.go
index 00f01fd..e327027 100644
--- a/relay/channel/openai/adaptor.go
+++ b/relay/channel/openai/adaptor.go
@@ -74,10 +74,13 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *re
return nil
}
-func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) {
+func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
if request == nil {
return nil, errors.New("request is nil")
}
+ if info.ChannelType != common.ChannelTypeOpenAI {
+ request.StreamOptions = nil
+ }
return request, nil
}
diff --git a/relay/channel/palm/adaptor.go b/relay/channel/palm/adaptor.go
index 8f6dd0a..51d1399 100644
--- a/relay/channel/palm/adaptor.go
+++ b/relay/channel/palm/adaptor.go
@@ -33,7 +33,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *re
return nil
}
-func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) {
+func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
if request == nil {
return nil, errors.New("request is nil")
}
diff --git a/relay/channel/perplexity/adaptor.go b/relay/channel/perplexity/adaptor.go
index 3c65b2d..a220076 100644
--- a/relay/channel/perplexity/adaptor.go
+++ b/relay/channel/perplexity/adaptor.go
@@ -34,7 +34,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *re
return nil
}
-func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) {
+func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
if request == nil {
return nil, errors.New("request is nil")
}
diff --git a/relay/channel/tencent/adaptor.go b/relay/channel/tencent/adaptor.go
index d79330e..3dd9115 100644
--- a/relay/channel/tencent/adaptor.go
+++ b/relay/channel/tencent/adaptor.go
@@ -47,7 +47,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *re
return nil
}
-func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) {
+func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
if request == nil {
return nil, errors.New("request is nil")
}
diff --git a/relay/channel/xunfei/adaptor.go b/relay/channel/xunfei/adaptor.go
index 9852aa1..adb054e 100644
--- a/relay/channel/xunfei/adaptor.go
+++ b/relay/channel/xunfei/adaptor.go
@@ -33,7 +33,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *re
return nil
}
-func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) {
+func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
if request == nil {
return nil, errors.New("request is nil")
}
diff --git a/relay/channel/zhipu/adaptor.go b/relay/channel/zhipu/adaptor.go
index 0893a83..09345ca 100644
--- a/relay/channel/zhipu/adaptor.go
+++ b/relay/channel/zhipu/adaptor.go
@@ -37,7 +37,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *re
return nil
}
-func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) {
+func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
if request == nil {
return nil, errors.New("request is nil")
}
diff --git a/relay/channel/zhipu_4v/adaptor.go b/relay/channel/zhipu_4v/adaptor.go
index 508861f..9b8bd49 100644
--- a/relay/channel/zhipu_4v/adaptor.go
+++ b/relay/channel/zhipu_4v/adaptor.go
@@ -35,7 +35,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *re
return nil
}
-func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) {
+func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
if request == nil {
return nil, errors.New("request is nil")
}
diff --git a/relay/relay-text.go b/relay/relay-text.go
index 6e74fbb..ef169fa 100644
--- a/relay/relay-text.go
+++ b/relay/relay-text.go
@@ -153,7 +153,7 @@ func TextHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
adaptor.Init(relayInfo, *textRequest)
var requestBody io.Reader
- convertedRequest, err := adaptor.ConvertRequest(c, relayInfo.RelayMode, textRequest)
+ convertedRequest, err := adaptor.ConvertRequest(c, relayInfo, textRequest)
if err != nil {
return service.OpenAIErrorWrapperLocal(err, "convert_request_failed", http.StatusInternalServerError)
}
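
The behavioural fix in this patch sits in the OpenAI adaptor: passing the full RelayInfo into ConvertRequest lets it see the channel type and drop stream_options for non-OpenAI channels such as Azure, which reject the field. A standalone sketch of that guard; the field and constant names are simplified stand-ins:

package main

import "fmt"

const (
	channelTypeOpenAI = iota
	channelTypeAzure
)

type streamOptions struct {
	IncludeUsage bool `json:"include_usage,omitempty"`
}

type generalRequest struct {
	Model         string         `json:"model"`
	StreamOptions *streamOptions `json:"stream_options,omitempty"`
}

// convertRequest mirrors the guard added to the OpenAI adaptor: only the real
// OpenAI channel keeps stream_options; other channel types (e.g. Azure) get it stripped.
func convertRequest(channelType int, req *generalRequest) *generalRequest {
	if channelType != channelTypeOpenAI {
		req.StreamOptions = nil
	}
	return req
}

func main() {
	req := &generalRequest{Model: "gpt-4o", StreamOptions: &streamOptions{IncludeUsage: true}}
	convertRequest(channelTypeAzure, req)
	fmt.Println(req.StreamOptions == nil) // true
}
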
From 7029065892142e8d240477dd888f297ba018ac94 Mon Sep 17 00:00:00 2001
From: CalciumIon <1808837298@qq.com>
Date: Mon, 15 Jul 2024 18:04:05 +0800
Subject: [PATCH 06/34] refactor: rework streaming mode logic
---
dto/text_response.go | 15 +-
relay/channel/openai/adaptor.go | 9 +-
relay/channel/openai/relay-openai.go | 209 +++++++++++++--------------
service/usage_helpr.go | 4 +
4 files changed, 114 insertions(+), 123 deletions(-)
diff --git a/dto/text_response.go b/dto/text_response.go
index 3310d02..e1f0cc0 100644
--- a/dto/text_response.go
+++ b/dto/text_response.go
@@ -66,10 +66,6 @@ type ChatCompletionsStreamResponseChoiceDelta struct {
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
}
-func (c *ChatCompletionsStreamResponseChoiceDelta) IsEmpty() bool {
- return c.Content == nil && len(c.ToolCalls) == 0
-}
-
func (c *ChatCompletionsStreamResponseChoiceDelta) SetContentString(s string) {
c.Content = &s
}
@@ -105,6 +101,17 @@ type ChatCompletionsStreamResponse struct {
Usage *Usage `json:"usage"`
}
+func (c *ChatCompletionsStreamResponse) GetSystemFingerprint() string {
+ if c.SystemFingerprint == nil {
+ return ""
+ }
+ return *c.SystemFingerprint
+}
+
+func (c *ChatCompletionsStreamResponse) SetSystemFingerprint(s string) {
+ c.SystemFingerprint = &s
+}
+
type ChatCompletionsStreamResponseSimple struct {
Choices []ChatCompletionsStreamResponseChoice `json:"choices"`
Usage *Usage `json:"usage"`
diff --git a/relay/channel/openai/adaptor.go b/relay/channel/openai/adaptor.go
index e327027..688dedc 100644
--- a/relay/channel/openai/adaptor.go
+++ b/relay/channel/openai/adaptor.go
@@ -14,7 +14,6 @@ import (
"one-api/relay/channel/minimax"
"one-api/relay/channel/moonshot"
relaycommon "one-api/relay/common"
- "one-api/service"
"strings"
)
@@ -90,13 +89,7 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
if info.IsStream {
- var responseText string
- var toolCount int
- err, usage, responseText, toolCount = OpenaiStreamHandler(c, resp, info)
- if usage == nil || usage.TotalTokens == 0 || (usage.PromptTokens+usage.CompletionTokens) == 0 {
- usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
- usage.CompletionTokens += toolCount * 7
- }
+ err, usage, _, _ = OpenaiStreamHandler(c, resp, info)
} else {
err, usage = OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
}
diff --git a/relay/channel/openai/relay-openai.go b/relay/channel/openai/relay-openai.go
index dace39c..3fd7f03 100644
--- a/relay/channel/openai/relay-openai.go
+++ b/relay/channel/openai/relay-openai.go
@@ -14,38 +14,33 @@ import (
relayconstant "one-api/relay/constant"
"one-api/service"
"strings"
- "sync"
"time"
)
func OpenaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage, string, int) {
- //checkSensitive := constant.ShouldCheckCompletionSensitive()
+ hasStreamUsage := false
+ responseId := ""
+ var createAt int64 = 0
+ var systemFingerprint string
+
var responseTextBuilder strings.Builder
- var usage dto.Usage
+ var usage = &dto.Usage{}
toolCount := 0
scanner := bufio.NewScanner(resp.Body)
- scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
- if atEOF && len(data) == 0 {
- return 0, nil, nil
- }
- if i := strings.Index(string(data), "\n"); i >= 0 {
- return i + 1, data[0:i], nil
- }
- if atEOF {
- return len(data), data, nil
- }
- return 0, nil, nil
- })
- dataChan := make(chan string, 5)
+ scanner.Split(bufio.ScanLines)
+ var streamItems []string // store stream items
+
+ service.SetEventStreamHeaders(c)
+
+ ticker := time.NewTicker(time.Duration(constant.StreamingTimeout) * time.Second)
+ defer ticker.Stop()
+
stopChan := make(chan bool, 2)
defer close(stopChan)
- defer close(dataChan)
- var wg sync.WaitGroup
+
go func() {
- wg.Add(1)
- defer wg.Done()
- var streamItems []string // store stream items
for scanner.Scan() {
+ ticker.Reset(time.Duration(constant.StreamingTimeout) * time.Second)
data := scanner.Text()
if len(data) < 6 { // ignore blank line or wrong format
continue
@@ -53,54 +48,42 @@ func OpenaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.
if data[:6] != "data: " && data[:6] != "[DONE]" {
continue
}
- if !common.SafeSendStringTimeout(dataChan, data, constant.StreamingTimeout) {
- // send data timeout, stop the stream
- common.LogError(c, "send data timeout, stop the stream")
- break
- }
data = data[6:]
if !strings.HasPrefix(data, "[DONE]") {
+ service.StringData(c, data)
streamItems = append(streamItems, data)
}
}
- // 计算token
- streamResp := "[" + strings.Join(streamItems, ",") + "]"
- switch info.RelayMode {
- case relayconstant.RelayModeChatCompletions:
- var streamResponses []dto.ChatCompletionsStreamResponseSimple
- err := json.Unmarshal(common.StringToByteSlice(streamResp), &streamResponses)
- if err != nil {
- // 一次性解析失败,逐个解析
- common.SysError("error unmarshalling stream response: " + err.Error())
- for _, item := range streamItems {
- var streamResponse dto.ChatCompletionsStreamResponseSimple
- err := json.Unmarshal(common.StringToByteSlice(item), &streamResponse)
- if err == nil {
- if streamResponse.Usage != nil {
- if streamResponse.Usage.TotalTokens != 0 {
- usage = *streamResponse.Usage
- }
- }
- for _, choice := range streamResponse.Choices {
- responseTextBuilder.WriteString(choice.Delta.GetContentString())
- if choice.Delta.ToolCalls != nil {
- if len(choice.Delta.ToolCalls) > toolCount {
- toolCount = len(choice.Delta.ToolCalls)
- }
- for _, tool := range choice.Delta.ToolCalls {
- responseTextBuilder.WriteString(tool.Function.Name)
- responseTextBuilder.WriteString(tool.Function.Arguments)
- }
- }
- }
- }
- }
- } else {
- for _, streamResponse := range streamResponses {
- if streamResponse.Usage != nil {
- if streamResponse.Usage.TotalTokens != 0 {
- usage = *streamResponse.Usage
- }
+ stopChan <- true
+ }()
+
+ select {
+ case <-ticker.C:
+ // 超时处理逻辑
+ common.LogError(c, "streaming timeout")
+ case <-stopChan:
+ // 正常结束
+ }
+
+ // 计算token
+ streamResp := "[" + strings.Join(streamItems, ",") + "]"
+ switch info.RelayMode {
+ case relayconstant.RelayModeChatCompletions:
+ var streamResponses []dto.ChatCompletionsStreamResponse
+ err := json.Unmarshal(common.StringToByteSlice(streamResp), &streamResponses)
+ if err != nil {
+ // 一次性解析失败,逐个解析
+ common.SysError("error unmarshalling stream response: " + err.Error())
+ for _, item := range streamItems {
+ var streamResponse dto.ChatCompletionsStreamResponse
+ err := json.Unmarshal(common.StringToByteSlice(item), &streamResponse)
+ if err == nil {
+ responseId = streamResponse.Id
+ createAt = streamResponse.Created
+ systemFingerprint = streamResponse.GetSystemFingerprint()
+ if service.ValidUsage(streamResponse.Usage) {
+ usage = streamResponse.Usage
+ hasStreamUsage = true
}
for _, choice := range streamResponse.Choices {
responseTextBuilder.WriteString(choice.Delta.GetContentString())
@@ -116,67 +99,71 @@ func OpenaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.
}
}
}
- case relayconstant.RelayModeCompletions:
- var streamResponses []dto.CompletionsStreamResponse
- err := json.Unmarshal(common.StringToByteSlice(streamResp), &streamResponses)
- if err != nil {
- // 一次性解析失败,逐个解析
- common.SysError("error unmarshalling stream response: " + err.Error())
- for _, item := range streamItems {
- var streamResponse dto.CompletionsStreamResponse
- err := json.Unmarshal(common.StringToByteSlice(item), &streamResponse)
- if err == nil {
- for _, choice := range streamResponse.Choices {
- responseTextBuilder.WriteString(choice.Text)
+ } else {
+ for _, streamResponse := range streamResponses {
+ responseId = streamResponse.Id
+ createAt = streamResponse.Created
+ systemFingerprint = streamResponse.GetSystemFingerprint()
+ if service.ValidUsage(streamResponse.Usage) {
+ usage = streamResponse.Usage
+ hasStreamUsage = true
+ }
+ for _, choice := range streamResponse.Choices {
+ responseTextBuilder.WriteString(choice.Delta.GetContentString())
+ if choice.Delta.ToolCalls != nil {
+ if len(choice.Delta.ToolCalls) > toolCount {
+ toolCount = len(choice.Delta.ToolCalls)
+ }
+ for _, tool := range choice.Delta.ToolCalls {
+ responseTextBuilder.WriteString(tool.Function.Name)
+ responseTextBuilder.WriteString(tool.Function.Arguments)
}
}
}
- } else {
- for _, streamResponse := range streamResponses {
+ }
+ }
+ case relayconstant.RelayModeCompletions:
+ var streamResponses []dto.CompletionsStreamResponse
+ err := json.Unmarshal(common.StringToByteSlice(streamResp), &streamResponses)
+ if err != nil {
+ // 一次性解析失败,逐个解析
+ common.SysError("error unmarshalling stream response: " + err.Error())
+ for _, item := range streamItems {
+ var streamResponse dto.CompletionsStreamResponse
+ err := json.Unmarshal(common.StringToByteSlice(item), &streamResponse)
+ if err == nil {
for _, choice := range streamResponse.Choices {
responseTextBuilder.WriteString(choice.Text)
}
}
}
- }
- if len(dataChan) > 0 {
- // wait data out
- time.Sleep(2 * time.Second)
- }
- common.SafeSendBool(stopChan, true)
- }()
- service.SetEventStreamHeaders(c)
- isFirst := true
- ticker := time.NewTicker(time.Duration(constant.StreamingTimeout) * time.Second)
- defer ticker.Stop()
- c.Stream(func(w io.Writer) bool {
- select {
- case <-ticker.C:
- common.LogError(c, "reading data from upstream timeout")
- return false
- case data := <-dataChan:
- if isFirst {
- isFirst = false
- info.FirstResponseTime = time.Now()
+ } else {
+ for _, streamResponse := range streamResponses {
+ for _, choice := range streamResponse.Choices {
+ responseTextBuilder.WriteString(choice.Text)
+ }
}
- ticker.Reset(time.Duration(constant.StreamingTimeout) * time.Second)
- if strings.HasPrefix(data, "data: [DONE]") {
- data = data[:12]
- }
- // some implementations may add \r at the end of data
- data = strings.TrimSuffix(data, "\r")
- c.Render(-1, common.CustomEvent{Data: data})
- return true
- case <-stopChan:
- return false
}
- })
+ }
+
+ if !hasStreamUsage {
+ usage, _ = service.ResponseText2Usage(responseTextBuilder.String(), info.UpstreamModelName, info.PromptTokens)
+ usage.CompletionTokens += toolCount * 7
+ }
+
+ if info.ShouldIncludeUsage && !hasStreamUsage {
+ response := service.GenerateFinalUsageResponse(responseId, createAt, info.UpstreamModelName, *usage)
+ response.SetSystemFingerprint(systemFingerprint)
+ service.ObjectData(c, response)
+ }
+
+ service.Done(c)
+
err := resp.Body.Close()
if err != nil {
return service.OpenAIErrorWrapperLocal(err, "close_response_body_failed", http.StatusInternalServerError), nil, "", toolCount
}
- wg.Wait()
- return nil, &usage, responseTextBuilder.String(), toolCount
+ return nil, usage, responseTextBuilder.String(), toolCount
}
func OpenaiHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
diff --git a/service/usage_helpr.go b/service/usage_helpr.go
index 528f3d4..adec566 100644
--- a/service/usage_helpr.go
+++ b/service/usage_helpr.go
@@ -36,3 +36,7 @@ func GenerateFinalUsageResponse(id string, createAt int64, model string, usage d
Usage: &usage,
}
}
+
+func ValidUsage(usage *dto.Usage) bool {
+ return usage != nil && (usage.PromptTokens != 0 || usage.CompletionTokens != 0)
+}
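
The new streaming flow replaces the buffered dataChan/WaitGroup pipeline with one goroutine that relays each event as it arrives and resets a ticker, while the caller selects between the ticker (timeout) and a stop signal (normal end). A self-contained sketch of that pattern with a short timeout and an in-memory stream; the constant name is illustrative:

package main

import (
	"bufio"
	"fmt"
	"strings"
	"time"
)

const streamingTimeout = 2 * time.Second // illustrative; the project reads constant.StreamingTimeout

func main() {
	body := strings.NewReader("data: one\ndata: two\ndata: [DONE]\n")
	scanner := bufio.NewScanner(body)

	ticker := time.NewTicker(streamingTimeout)
	defer ticker.Stop()

	stopChan := make(chan bool, 1)

	go func() {
		for scanner.Scan() {
			// Every received line proves the upstream is alive, so push the deadline forward.
			ticker.Reset(streamingTimeout)
			line := strings.TrimPrefix(scanner.Text(), "data: ")
			if line == "[DONE]" {
				break
			}
			fmt.Println("relay:", line) // in the real handler: service.StringData(c, line)
		}
		stopChan <- true
	}()

	select {
	case <-ticker.C:
		fmt.Println("streaming timeout, giving up on the upstream")
	case <-stopChan:
		fmt.Println("stream finished normally")
	}
}
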
From 220ab412e26aafc246d11ca8cf42b0019759c1c9 Mon Sep 17 00:00:00 2001
From: CalciumIon <1808837298@qq.com>
Date: Mon, 15 Jul 2024 18:14:07 +0800
Subject: [PATCH 07/34] fix: openai response time
---
relay/channel/openai/relay-openai.go | 1 +
relay/common/relay_info.go | 8 ++++++++
2 files changed, 9 insertions(+)
diff --git a/relay/channel/openai/relay-openai.go b/relay/channel/openai/relay-openai.go
index 3fd7f03..16cbb0c 100644
--- a/relay/channel/openai/relay-openai.go
+++ b/relay/channel/openai/relay-openai.go
@@ -40,6 +40,7 @@ func OpenaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.
go func() {
for scanner.Scan() {
+ info.SetFirstResponseTime()
ticker.Reset(time.Duration(constant.StreamingTimeout) * time.Second)
data := scanner.Text()
if len(data) < 6 { // ignore blank line or wrong format
diff --git a/relay/common/relay_info.go b/relay/common/relay_info.go
index e07434a..564a7ad 100644
--- a/relay/common/relay_info.go
+++ b/relay/common/relay_info.go
@@ -17,6 +17,7 @@ type RelayInfo struct {
TokenUnlimited bool
StartTime time.Time
FirstResponseTime time.Time
+ setFirstResponse bool
ApiType int
IsStream bool
RelayMode int
@@ -83,6 +84,13 @@ func (info *RelayInfo) SetIsStream(isStream bool) {
info.IsStream = isStream
}
+func (info *RelayInfo) SetFirstResponseTime() {
+ if !info.setFirstResponse {
+ info.FirstResponseTime = time.Now()
+ info.setFirstResponse = true
+ }
+}
+
type TaskRelayInfo struct {
ChannelType int
ChannelId int
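
SetFirstResponseTime is deliberately write-once: the stream handler calls it on every scanned line, but only the first call records the timestamp, which yields the time-to-first-response metric. A tiny sketch of the idempotent setter and how it is read back:

package main

import (
	"fmt"
	"time"
)

type relayInfo struct {
	StartTime         time.Time
	FirstResponseTime time.Time
	setFirstResponse  bool
}

// SetFirstResponseTime records the timestamp only once, no matter how often it is called.
func (info *relayInfo) SetFirstResponseTime() {
	if !info.setFirstResponse {
		info.FirstResponseTime = time.Now()
		info.setFirstResponse = true
	}
}

func main() {
	info := &relayInfo{StartTime: time.Now()}
	for i := 0; i < 3; i++ { // simulate one call per streamed line
		time.Sleep(10 * time.Millisecond)
		info.SetFirstResponseTime()
	}
	fmt.Println("time to first response:", info.FirstResponseTime.Sub(info.StartTime))
}
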
From e2b906165086ded46c21bc8033aa51375d0c532e Mon Sep 17 00:00:00 2001
From: CalciumIon <1808837298@qq.com>
Date: Mon, 15 Jul 2024 19:06:13 +0800
Subject: [PATCH 08/34] fix: openai stream response
---
relay/channel/openai/relay-openai.go | 14 +++++++++-----
1 file changed, 9 insertions(+), 5 deletions(-)
diff --git a/relay/channel/openai/relay-openai.go b/relay/channel/openai/relay-openai.go
index 16cbb0c..8fc4f6f 100644
--- a/relay/channel/openai/relay-openai.go
+++ b/relay/channel/openai/relay-openai.go
@@ -22,20 +22,22 @@ func OpenaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.
responseId := ""
var createAt int64 = 0
var systemFingerprint string
+ model := info.UpstreamModelName
var responseTextBuilder strings.Builder
var usage = &dto.Usage{}
+ var streamItems []string // store stream items
+
toolCount := 0
scanner := bufio.NewScanner(resp.Body)
scanner.Split(bufio.ScanLines)
- var streamItems []string // store stream items
service.SetEventStreamHeaders(c)
ticker := time.NewTicker(time.Duration(constant.StreamingTimeout) * time.Second)
defer ticker.Stop()
- stopChan := make(chan bool, 2)
+ stopChan := make(chan bool)
defer close(stopChan)
go func() {
@@ -55,7 +57,7 @@ func OpenaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.
streamItems = append(streamItems, data)
}
}
- stopChan <- true
+ common.SafeSendBool(stopChan, true)
}()
select {
@@ -82,6 +84,7 @@ func OpenaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.
responseId = streamResponse.Id
createAt = streamResponse.Created
systemFingerprint = streamResponse.GetSystemFingerprint()
+ model = streamResponse.Model
if service.ValidUsage(streamResponse.Usage) {
usage = streamResponse.Usage
hasStreamUsage = true
@@ -105,6 +108,7 @@ func OpenaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.
responseId = streamResponse.Id
createAt = streamResponse.Created
systemFingerprint = streamResponse.GetSystemFingerprint()
+ model = streamResponse.Model
if service.ValidUsage(streamResponse.Usage) {
usage = streamResponse.Usage
hasStreamUsage = true
@@ -153,7 +157,7 @@ func OpenaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.
}
if info.ShouldIncludeUsage && !hasStreamUsage {
- response := service.GenerateFinalUsageResponse(responseId, createAt, info.UpstreamModelName, *usage)
+ response := service.GenerateFinalUsageResponse(responseId, createAt, model, *usage)
response.SetSystemFingerprint(systemFingerprint)
service.ObjectData(c, response)
}
@@ -162,7 +166,7 @@ func OpenaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.
err := resp.Body.Close()
if err != nil {
- return service.OpenAIErrorWrapperLocal(err, "close_response_body_failed", http.StatusInternalServerError), nil, "", toolCount
+ common.LogError(c, "close_response_body_failed: "+err.Error())
}
return nil, usage, responseTextBuilder.String(), toolCount
}
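
The switch from a plain "stopChan <- true" to common.SafeSendBool matters because the receiving side may have already returned on the timeout branch; with an unbuffered channel a bare send would then block the scanner goroutine forever. The sketch below shows one common way such a helper is written, a non-blocking send guarded against a closed channel; this is an assumption about the helper's shape, not the project's exact implementation:

package main

import "fmt"

// safeSendBool attempts a non-blocking send and reports whether the value was delivered.
// The deferred recover also protects against the channel having been closed by the receiver.
func safeSendBool(ch chan bool, value bool) (ok bool) {
	defer func() {
		if recover() != nil {
			ok = false // send on closed channel
		}
	}()
	select {
	case ch <- value:
		return true
	default:
		return false // nobody is listening; drop the signal instead of blocking
	}
}

func main() {
	ch := make(chan bool)               // unbuffered, no receiver
	fmt.Println(safeSendBool(ch, true)) // false: the send is simply dropped

	close(ch)
	fmt.Println(safeSendBool(ch, true)) // false: recovered from send on a closed channel
}
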
From 9bbe8e7d1ba584295f322c46ce9cde3180acc342 Mon Sep 17 00:00:00 2001
From: CalciumIon <1808837298@qq.com>
Date: Mon, 15 Jul 2024 20:23:19 +0800
Subject: [PATCH 09/34] fix: log detail rendered incorrectly for non-consumption log types
---
web/src/components/LogsTable.js | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/web/src/components/LogsTable.js b/web/src/components/LogsTable.js
index 4bbacf0..55106f2 100644
--- a/web/src/components/LogsTable.js
+++ b/web/src/components/LogsTable.js
@@ -367,7 +367,7 @@ const LogsTable = () => {
dataIndex: 'content',
render: (text, record, index) => {
let other = getLogOther(record.other);
- if (other == null) {
+ if (other == null || record.type !== 2) {
return (
Date: Mon, 15 Jul 2024 22:07:50 +0800
Subject: [PATCH 10/34] chore: openai stream
---
common/model-ratio.go | 13 +++++++------
relay/channel/ollama/adaptor.go | 7 +------
relay/channel/openai/adaptor.go | 2 +-
relay/channel/openai/relay-openai.go | 4 ++--
relay/channel/perplexity/adaptor.go | 7 +------
relay/channel/zhipu_4v/adaptor.go | 9 +--------
relay/channel/zhipu_4v/constants.go | 2 +-
7 files changed, 14 insertions(+), 30 deletions(-)
diff --git a/common/model-ratio.go b/common/model-ratio.go
index c554036..294a0cc 100644
--- a/common/model-ratio.go
+++ b/common/model-ratio.go
@@ -105,12 +105,13 @@ var defaultModelRatio = map[string]float64{
"gemini-1.0-pro-latest": 1,
"gemini-1.0-pro-vision-latest": 1,
"gemini-ultra": 1,
- "chatglm_turbo": 0.3572, // ¥0.005 / 1k tokens
- "chatglm_pro": 0.7143, // ¥0.01 / 1k tokens
- "chatglm_std": 0.3572, // ¥0.005 / 1k tokens
- "chatglm_lite": 0.1429, // ¥0.002 / 1k tokens
- "glm-4": 7.143, // ¥0.1 / 1k tokens
- "glm-4v": 7.143, // ¥0.1 / 1k tokens
+ "chatglm_turbo": 0.3572, // ¥0.005 / 1k tokens
+ "chatglm_pro": 0.7143, // ¥0.01 / 1k tokens
+ "chatglm_std": 0.3572, // ¥0.005 / 1k tokens
+ "chatglm_lite": 0.1429, // ¥0.002 / 1k tokens
+ "glm-4": 7.143, // ¥0.1 / 1k tokens
+ "glm-4v": 0.05 * RMB, // ¥0.05 / 1k tokens
+ "glm-4-alltools": 0.1 * RMB, // ¥0.1 / 1k tokens
"glm-3-turbo": 0.3572,
"qwen-turbo": 0.8572, // ¥0.012 / 1k tokens
"qwen-plus": 10, // ¥0.14 / 1k tokens
diff --git a/relay/channel/ollama/adaptor.go b/relay/channel/ollama/adaptor.go
index b0550ca..15ced27 100644
--- a/relay/channel/ollama/adaptor.go
+++ b/relay/channel/ollama/adaptor.go
@@ -10,7 +10,6 @@ import (
"one-api/relay/channel/openai"
relaycommon "one-api/relay/common"
relayconstant "one-api/relay/constant"
- "one-api/service"
)
type Adaptor struct {
@@ -58,11 +57,7 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
if info.IsStream {
- var responseText string
- err, usage, responseText, _ = openai.OpenaiStreamHandler(c, resp, info)
- if usage == nil || usage.TotalTokens == 0 || (usage.PromptTokens+usage.CompletionTokens) == 0 {
- usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
- }
+ err, usage = openai.OpenaiStreamHandler(c, resp, info)
} else {
if info.RelayMode == relayconstant.RelayModeEmbeddings {
err, usage = ollamaEmbeddingHandler(c, resp, info.PromptTokens, info.UpstreamModelName, info.RelayMode)
diff --git a/relay/channel/openai/adaptor.go b/relay/channel/openai/adaptor.go
index 688dedc..6dc56d0 100644
--- a/relay/channel/openai/adaptor.go
+++ b/relay/channel/openai/adaptor.go
@@ -89,7 +89,7 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
if info.IsStream {
- err, usage, _, _ = OpenaiStreamHandler(c, resp, info)
+ err, usage = OpenaiStreamHandler(c, resp, info)
} else {
err, usage = OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
}
diff --git a/relay/channel/openai/relay-openai.go b/relay/channel/openai/relay-openai.go
index 8fc4f6f..b71fcce 100644
--- a/relay/channel/openai/relay-openai.go
+++ b/relay/channel/openai/relay-openai.go
@@ -17,7 +17,7 @@ import (
"time"
)
-func OpenaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage, string, int) {
+func OpenaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
hasStreamUsage := false
responseId := ""
var createAt int64 = 0
@@ -168,7 +168,7 @@ func OpenaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.
if err != nil {
common.LogError(c, "close_response_body_failed: "+err.Error())
}
- return nil, usage, responseTextBuilder.String(), toolCount
+ return nil, usage
}
func OpenaiHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
diff --git a/relay/channel/perplexity/adaptor.go b/relay/channel/perplexity/adaptor.go
index a220076..40aa0f4 100644
--- a/relay/channel/perplexity/adaptor.go
+++ b/relay/channel/perplexity/adaptor.go
@@ -10,7 +10,6 @@ import (
"one-api/relay/channel"
"one-api/relay/channel/openai"
relaycommon "one-api/relay/common"
- "one-api/service"
)
type Adaptor struct {
@@ -54,11 +53,7 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
if info.IsStream {
- var responseText string
- err, usage, responseText, _ = openai.OpenaiStreamHandler(c, resp, info)
- if usage == nil || usage.TotalTokens == 0 || (usage.PromptTokens+usage.CompletionTokens) == 0 {
- usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
- }
+ err, usage = openai.OpenaiStreamHandler(c, resp, info)
} else {
err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
}
diff --git a/relay/channel/zhipu_4v/adaptor.go b/relay/channel/zhipu_4v/adaptor.go
index 9b8bd49..bdce639 100644
--- a/relay/channel/zhipu_4v/adaptor.go
+++ b/relay/channel/zhipu_4v/adaptor.go
@@ -10,7 +10,6 @@ import (
"one-api/relay/channel"
"one-api/relay/channel/openai"
relaycommon "one-api/relay/common"
- "one-api/service"
)
type Adaptor struct {
@@ -55,13 +54,7 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
if info.IsStream {
- var responseText string
- var toolCount int
- err, usage, responseText, toolCount = openai.OpenaiStreamHandler(c, resp, info)
- if usage == nil || usage.TotalTokens == 0 || (usage.PromptTokens+usage.CompletionTokens) == 0 {
- usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
- usage.CompletionTokens += toolCount * 7
- }
+ err, usage = openai.OpenaiStreamHandler(c, resp, info)
} else {
err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
}
diff --git a/relay/channel/zhipu_4v/constants.go b/relay/channel/zhipu_4v/constants.go
index 1b0b0cc..3383eb3 100644
--- a/relay/channel/zhipu_4v/constants.go
+++ b/relay/channel/zhipu_4v/constants.go
@@ -1,7 +1,7 @@
package zhipu_4v
var ModelList = []string{
- "glm-4", "glm-4v", "glm-3-turbo",
+ "glm-4", "glm-4v", "glm-3-turbo", "glm-4-alltools",
}
var ChannelName = "zhipu_4v"
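The signature change in this patch drops the responseText and toolCount return values, so the ollama, perplexity and zhipu_4v adaptors no longer estimate usage themselves. The corresponding hunk of relay-openai.go is not shown here, so the sketch below is only a guess at where that fallback presumably moved; ensureUsage is an invented name, while ResponseText2Usage and the toolCount*7 adjustment are taken from the removed adaptor code.

package openai

import (
    "one-api/dto"
    relaycommon "one-api/relay/common"
    "one-api/service"
)

// ensureUsage returns the upstream-reported usage when it is present and
// non-empty, and otherwise rebuilds an estimate from the accumulated response
// text, mirroring what the removed adaptor-side code used to do.
func ensureUsage(usage *dto.Usage, responseText string, toolCount int, info *relaycommon.RelayInfo) *dto.Usage {
    if usage != nil && usage.TotalTokens != 0 && usage.PromptTokens+usage.CompletionTokens != 0 {
        return usage
    }
    estimated, _ := service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
    estimated.CompletionTokens += toolCount * 7 // same tool-call adjustment the zhipu_4v adaptor applied
    return estimated
}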
From ba27da9e2cee5850851f62e4228544220aa25e50 Mon Sep 17 00:00:00 2001
From: CalciumIon <1808837298@qq.com>
Date: Mon, 15 Jul 2024 22:09:11 +0800
Subject: [PATCH 11/34] fix: try to fix midjourney quota refund (refund only after task update succeeds)
---
controller/midjourney.go | 21 ++++++++++++---------
1 file changed, 12 insertions(+), 9 deletions(-)
diff --git a/controller/midjourney.go b/controller/midjourney.go
index 508c5dd..1a8cd36 100644
--- a/controller/midjourney.go
+++ b/controller/midjourney.go
@@ -146,7 +146,7 @@ func UpdateMidjourneyTaskBulk() {
buttonStr, _ := json.Marshal(responseItem.Buttons)
task.Buttons = string(buttonStr)
}
-
+ shouldReturnQuota := false
if (task.Progress != "100%" && responseItem.FailReason != "") || (task.Progress == "100%" && task.Status == "FAILURE") {
common.LogInfo(ctx, task.MjId+" 构建失败,"+task.FailReason)
task.Progress = "100%"
@@ -154,20 +154,23 @@ func UpdateMidjourneyTaskBulk() {
if err != nil {
common.LogError(ctx, "error update user quota cache: "+err.Error())
} else {
- quota := task.Quota
- if quota != 0 {
- err = model.IncreaseUserQuota(task.UserId, quota)
- if err != nil {
- common.LogError(ctx, "fail to increase user quota: "+err.Error())
- }
- logContent := fmt.Sprintf("构图失败 %s,补偿 %s", task.MjId, common.LogQuota(quota))
- model.RecordLog(task.UserId, model.LogTypeSystem, logContent)
+ if task.Quota != 0 {
+ shouldReturnQuota = true
}
}
}
err = task.Update()
if err != nil {
common.LogError(ctx, "UpdateMidjourneyTask task error: "+err.Error())
+ } else {
+ if shouldReturnQuota {
+ err = model.IncreaseUserQuota(task.UserId, task.Quota)
+ if err != nil {
+ common.LogError(ctx, "fail to increase user quota: "+err.Error())
+ }
+ logContent := fmt.Sprintf("构图失败 %s,补偿 %s", task.MjId, common.LogQuota(task.Quota))
+ model.RecordLog(task.UserId, model.LogTypeSystem, logContent)
+ }
}
}
}
From a3880d558ad0aaa5913809348fac829255c62fea Mon Sep 17 00:00:00 2001
From: CalciumIon <1808837298@qq.com>
Date: Mon, 15 Jul 2024 22:14:30 +0800
Subject: [PATCH 12/34] chore: simplify midjourney quota refund check
---
controller/midjourney.go | 9 ++-------
1 file changed, 2 insertions(+), 7 deletions(-)
diff --git a/controller/midjourney.go b/controller/midjourney.go
index 1a8cd36..01ddb2f 100644
--- a/controller/midjourney.go
+++ b/controller/midjourney.go
@@ -150,13 +150,8 @@ func UpdateMidjourneyTaskBulk() {
if (task.Progress != "100%" && responseItem.FailReason != "") || (task.Progress == "100%" && task.Status == "FAILURE") {
common.LogInfo(ctx, task.MjId+" 构建失败,"+task.FailReason)
task.Progress = "100%"
- err = model.CacheUpdateUserQuota(task.UserId)
- if err != nil {
- common.LogError(ctx, "error update user quota cache: "+err.Error())
- } else {
- if task.Quota != 0 {
- shouldReturnQuota = true
- }
+ if task.Quota != 0 {
+ shouldReturnQuota = true
}
}
err = task.Update()
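Taken together, patches 11 and 12 change the failure path so that the quota refund happens only after task.Update() has persisted the failed state, and the CacheUpdateUserQuota round trip is dropped. The condensed sketch below shows the resulting flow of the failure branch; settleFailedTask is an invented helper name and model.Midjourney is assumed as the task type, since the real code keeps this logic inline in UpdateMidjourneyTaskBulk.

package controller

import (
    "context"
    "fmt"
    "one-api/common"
    "one-api/model"
)

// settleFailedTask persists the failed task first and only refunds the user's
// quota afterwards, so a failed DB update cannot leave the task looking
// unfinished while the compensation has already been paid out.
func settleFailedTask(ctx context.Context, task *model.Midjourney) {
    task.Progress = "100%"
    shouldReturnQuota := task.Quota != 0
    if err := task.Update(); err != nil {
        common.LogError(ctx, "UpdateMidjourneyTask task error: "+err.Error())
        return // leave the refund for the next polling cycle
    }
    if shouldReturnQuota {
        if err := model.IncreaseUserQuota(task.UserId, task.Quota); err != nil {
            common.LogError(ctx, "fail to increase user quota: "+err.Error())
        }
        logContent := fmt.Sprintf("构图失败 %s,补偿 %s", task.MjId, common.LogQuota(task.Quota))
        model.RecordLog(task.UserId, model.LogTypeSystem, logContent)
    }
}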
From 963985e76c05124aea21430f9bb1d2af4cfcbbe0 Mon Sep 17 00:00:00 2001
From: CalciumIon <1808837298@qq.com>
Date: Tue, 16 Jul 2024 14:54:03 +0800
Subject: [PATCH 13/34] chore: update model ratio
---
common/model-ratio.go | 2 ++
1 file changed, 2 insertions(+)
diff --git a/common/model-ratio.go b/common/model-ratio.go
index 294a0cc..1200310 100644
--- a/common/model-ratio.go
+++ b/common/model-ratio.go
@@ -159,6 +159,8 @@ var defaultModelRatio = map[string]float64{
}
var defaultModelPrice = map[string]float64{
+ "suno_music": 0.1,
+ "suno_lyrics": 0.01,
"dall-e-3": 0.04,
"gpt-4-gizmo-*": 0.1,
"mj_imagine": 0.1,
From eb9b4b07ad94cc2378756e8eaba8a984b45ac61c Mon Sep 17 00:00:00 2001
From: CalciumIon <1808837298@qq.com>
Date: Tue, 16 Jul 2024 15:48:56 +0800
Subject: [PATCH 14/34] feat: update register page
---
web/src/components/RegisterForm.js | 217 +++++++++++++++--------------
1 file changed, 113 insertions(+), 104 deletions(-)
diff --git a/web/src/components/RegisterForm.js b/web/src/components/RegisterForm.js
index fcd2638..5ff2588 100644
--- a/web/src/components/RegisterForm.js
+++ b/web/src/components/RegisterForm.js
@@ -1,16 +1,10 @@
import React, { useEffect, useState } from 'react';
-import {
- Button,
- Form,
- Grid,
- Header,
- Image,
- Message,
- Segment,
-} from 'semantic-ui-react';
import { Link, useNavigate } from 'react-router-dom';
import { API, getLogo, showError, showInfo, showSuccess } from '../helpers';
import Turnstile from 'react-turnstile';
+import { Button, Card, Form, Layout } from '@douyinfe/semi-ui';
+import Title from '@douyinfe/semi-ui/lib/es/typography/title';
+import Text from '@douyinfe/semi-ui/lib/es/typography/text';
const RegisterForm = () => {
const [inputs, setInputs] = useState({
@@ -18,7 +12,7 @@ const RegisterForm = () => {
password: '',
password2: '',
email: '',
- verification_code: '',
+ verification_code: ''
});
const { username, password, password2 } = inputs;
const [showEmailVerification, setShowEmailVerification] = useState(false);
@@ -46,9 +40,7 @@ const RegisterForm = () => {
let navigate = useNavigate();
- function handleChange(e) {
- const { name, value } = e.target;
- console.log(name, value);
+ function handleChange(name, value) {
setInputs((inputs) => ({ ...inputs, [name]: value }));
}
@@ -73,7 +65,7 @@ const RegisterForm = () => {
inputs.aff_code = affCode;
const res = await API.post(
`/api/user/register?turnstile=${turnstileToken}`,
- inputs,
+ inputs
);
const { success, message } = res.data;
if (success) {
@@ -94,7 +86,7 @@ const RegisterForm = () => {
}
setLoading(true);
const res = await API.get(
- `/api/verification?email=${inputs.email}&turnstile=${turnstileToken}`,
+ `/api/verification?email=${inputs.email}&turnstile=${turnstileToken}`
);
const { success, message } = res.data;
if (success) {
@@ -106,96 +98,113 @@ const RegisterForm = () => {
};
return (
-
-
-
-