Compare commits

...

18 Commits

Author SHA1 Message Date
CalciumIon
ba27da9e2c fix: try to fix mj 2024-07-15 22:09:11 +08:00
CalciumIon
e262a9bd2c chore: openai stream 2024-07-15 22:07:50 +08:00
CalciumIon
9bbe8e7d1b fix: 日志详情非消费类型显示错误 2024-07-15 20:23:19 +08:00
CalciumIon
e2b9061650 fix: openai stream response 2024-07-15 19:06:13 +08:00
CalciumIon
220ab412e2 fix: openai response time 2024-07-15 18:14:07 +08:00
CalciumIon
7029065892 refactor: 重构流模式逻辑 2024-07-15 18:04:05 +08:00
CalciumIon
0f687aab9a fix: azure stream options 2024-07-15 16:05:30 +08:00
Calcium-Ion
5e936b3923 Merge pull request #363 from dalefengs/main
fix: http code is not properly disabled channel
2024-07-14 15:35:27 +08:00
FENG
d55cb35c1c fix: http code is not properly disabled 2024-07-14 01:21:05 +08:00
Calcium-Ion
5be4cbcaaf Merge pull request #362 from dalefengs/main
fix: channel timeout auto-ban and auto-enable
2024-07-14 00:30:17 +08:00
FENG
e67aa370bc fix: channel timeout auto-ban and auto-enable 2024-07-14 00:14:07 +08:00
CalciumIon
7b36a2b885 feat: support cloudflare worker ai 2024-07-13 19:55:22 +08:00
CalciumIon
c88f3741e6 feat: support claude stop_sequences 2024-07-11 18:44:45 +08:00
CalciumIon
4e7e206290 fix: gemini usage (close #354) 2024-07-10 16:01:09 +08:00
CalciumIon
579fc8129e fix: dify (close #355) 2024-07-10 15:36:17 +08:00
CalciumIon
f55f63f412 fix: email login 2024-07-09 21:36:31 +08:00
CalciumIon
0526c85732 feat: update stream options 2024-07-09 21:11:01 +08:00
CalciumIon
b75134ece4 fix: hunyuan 2024-07-08 23:42:16 +08:00
50 changed files with 624 additions and 308 deletions

View File

@@ -212,36 +212,37 @@ const (
ChannelTypeSunoAPI = 36
ChannelTypeDify = 37
ChannelTypeJina = 38
ChannelCloudflare = 39
ChannelTypeDummy // this one is only for count, do not add any channel after this
)
var ChannelBaseURLs = []string{
"", // 0
"https://api.openai.com", // 1
"https://oa.api2d.net", // 2
"", // 3
"http://localhost:11434", // 4
"https://api.openai-sb.com", // 5
"https://api.openaimax.com", // 6
"https://api.ohmygpt.com", // 7
"", // 8
"https://api.caipacity.com", // 9
"https://api.aiproxy.io", // 10
"", // 11
"https://api.api2gpt.com", // 12
"https://api.aigc2d.com", // 13
"https://api.anthropic.com", // 14
"https://aip.baidubce.com", // 15
"https://open.bigmodel.cn", // 16
"https://dashscope.aliyuncs.com", // 17
"", // 18
"https://ai.360.cn", // 19
"https://openrouter.ai/api", // 20
"https://api.aiproxy.io", // 21
"https://fastgpt.run/api/openapi", // 22
"https://hunyuan.cloud.tencent.com", //23
"", // 0
"https://api.openai.com", // 1
"https://oa.api2d.net", // 2
"", // 3
"http://localhost:11434", // 4
"https://api.openai-sb.com", // 5
"https://api.openaimax.com", // 6
"https://api.ohmygpt.com", // 7
"", // 8
"https://api.caipacity.com", // 9
"https://api.aiproxy.io", // 10
"", // 11
"https://api.api2gpt.com", // 12
"https://api.aigc2d.com", // 13
"https://api.anthropic.com", // 14
"https://aip.baidubce.com", // 15
"https://open.bigmodel.cn", // 16
"https://dashscope.aliyuncs.com", // 17
"", // 18
"https://ai.360.cn", // 19
"https://openrouter.ai/api", // 20
"https://api.aiproxy.io", // 21
"https://fastgpt.run/api/openapi", // 22
"https://hunyuan.tencentcloudapi.com", //23
"https://generativelanguage.googleapis.com", //24
"https://api.moonshot.cn", //25
"https://open.bigmodel.cn", //26
@@ -257,4 +258,5 @@ var ChannelBaseURLs = []string{
"", //36
"", //37
"https://api.jina.ai", //38
"https://api.cloudflare.com", //39
}

View File

@@ -105,12 +105,13 @@ var defaultModelRatio = map[string]float64{
"gemini-1.0-pro-latest": 1,
"gemini-1.0-pro-vision-latest": 1,
"gemini-ultra": 1,
"chatglm_turbo": 0.3572, // ¥0.005 / 1k tokens
"chatglm_pro": 0.7143, // ¥0.01 / 1k tokens
"chatglm_std": 0.3572, // ¥0.005 / 1k tokens
"chatglm_lite": 0.1429, // ¥0.002 / 1k tokens
"glm-4": 7.143, // ¥0.1 / 1k tokens
"glm-4v": 7.143, // ¥0.1 / 1k tokens
"chatglm_turbo": 0.3572, // ¥0.005 / 1k tokens
"chatglm_pro": 0.7143, // ¥0.01 / 1k tokens
"chatglm_std": 0.3572, // ¥0.005 / 1k tokens
"chatglm_lite": 0.1429, // ¥0.002 / 1k tokens
"glm-4": 7.143, // ¥0.1 / 1k tokens
"glm-4v": 0.05 * RMB, // ¥0.05 / 1k tokens
"glm-4-alltools": 0.1 * RMB, // ¥0.1 / 1k tokens
"glm-3-turbo": 0.3572,
"qwen-turbo": 0.8572, // ¥0.012 / 1k tokens
"qwen-plus": 10, // ¥0.14 / 1k tokens

View File

@@ -12,6 +12,7 @@ import (
"net/url"
"one-api/common"
"one-api/dto"
"one-api/middleware"
"one-api/model"
"one-api/relay"
relaycommon "one-api/relay/common"
@@ -24,7 +25,7 @@ import (
"github.com/gin-gonic/gin"
)
func testChannel(channel *model.Channel, testModel string) (err error, openaiErr *dto.OpenAIError) {
func testChannel(channel *model.Channel, testModel string) (err error, openAIErrorWithStatusCode *dto.OpenAIErrorWithStatusCode) {
tik := time.Now()
if channel.Type == common.ChannelTypeMidjourney {
return errors.New("midjourney channel test is not supported"), nil
@@ -40,29 +41,7 @@ func testChannel(channel *model.Channel, testModel string) (err error, openaiErr
Body: nil,
Header: make(http.Header),
}
c.Request.Header.Set("Authorization", "Bearer "+channel.Key)
c.Request.Header.Set("Content-Type", "application/json")
c.Set("channel", channel.Type)
c.Set("base_url", channel.GetBaseURL())
switch channel.Type {
case common.ChannelTypeAzure:
c.Set("api_version", channel.Other)
case common.ChannelTypeXunfei:
c.Set("api_version", channel.Other)
//case common.ChannelTypeAIProxyLibrary:
// c.Set("library_id", channel.Other)
case common.ChannelTypeGemini:
c.Set("api_version", channel.Other)
case common.ChannelTypeAli:
c.Set("plugin", channel.Other)
}
meta := relaycommon.GenRelayInfo(c)
apiType, _ := constant.ChannelType2APIType(channel.Type)
adaptor := relay.GetAdaptor(apiType)
if adaptor == nil {
return fmt.Errorf("invalid api type: %d, adaptor is nil", apiType), nil
}
if testModel == "" {
if channel.TestModel != nil && *channel.TestModel != "" {
testModel = *channel.TestModel
@@ -79,8 +58,7 @@ func testChannel(channel *model.Channel, testModel string) (err error, openaiErr
modelMap := make(map[string]string)
err := json.Unmarshal([]byte(modelMapping), &modelMap)
if err != nil {
openaiErr := service.OpenAIErrorWrapperLocal(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError).Error
return err, &openaiErr
return err, service.OpenAIErrorWrapperLocal(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError)
}
if modelMap[testModel] != "" {
testModel = modelMap[testModel]
@@ -88,6 +66,20 @@ func testChannel(channel *model.Channel, testModel string) (err error, openaiErr
}
}
c.Request.Header.Set("Authorization", "Bearer "+channel.Key)
c.Request.Header.Set("Content-Type", "application/json")
c.Set("channel", channel.Type)
c.Set("base_url", channel.GetBaseURL())
middleware.SetupContextForSelectedChannel(c, channel, testModel)
meta := relaycommon.GenRelayInfo(c)
apiType, _ := constant.ChannelType2APIType(channel.Type)
adaptor := relay.GetAdaptor(apiType)
if adaptor == nil {
return fmt.Errorf("invalid api type: %d, adaptor is nil", apiType), nil
}
request := buildTestRequest()
request.Model = testModel
meta.UpstreamModelName = testModel
@@ -95,7 +87,7 @@ func testChannel(channel *model.Channel, testModel string) (err error, openaiErr
adaptor.Init(meta, *request)
convertedRequest, err := adaptor.ConvertRequest(c, constant.RelayModeChatCompletions, request)
convertedRequest, err := adaptor.ConvertRequest(c, meta, request)
if err != nil {
return err, nil
}
@@ -111,17 +103,16 @@ func testChannel(channel *model.Channel, testModel string) (err error, openaiErr
}
if resp != nil && resp.StatusCode != http.StatusOK {
err := relaycommon.RelayErrorHandler(resp)
return fmt.Errorf("status code %d: %s", resp.StatusCode, err.Error.Message), &err.Error
return fmt.Errorf("status code %d: %s", resp.StatusCode, err.Error.Message), err
}
usage, respErr := adaptor.DoResponse(c, resp, meta)
if respErr != nil {
return fmt.Errorf("%s", respErr.Error.Message), &respErr.Error
return fmt.Errorf("%s", respErr.Error.Message), respErr
}
if usage == nil {
return errors.New("usage is nil"), nil
}
result := w.Result()
// print result.Body
respBody, err := io.ReadAll(result.Body)
if err != nil {
return err, nil
@@ -230,7 +221,7 @@ func testAllChannels(notify bool) error {
for _, channel := range channels {
isChannelEnabled := channel.Status == common.ChannelStatusEnabled
tik := time.Now()
err, openaiErr := testChannel(channel, "")
err, openaiWithStatusErr := testChannel(channel, "")
tok := time.Now()
milliseconds := tok.Sub(tik).Milliseconds()
@@ -239,27 +230,29 @@ func testAllChannels(notify bool) error {
err = errors.New(fmt.Sprintf("响应时间 %.2fs 超过阈值 %.2fs", float64(milliseconds)/1000.0, float64(disableThreshold)/1000.0))
ban = true
}
if openaiErr != nil {
err = errors.New(fmt.Sprintf("type %s, code %v, message %s", openaiErr.Type, openaiErr.Code, openaiErr.Message))
ban = true
// request error disables the channel
if openaiWithStatusErr != nil {
oaiErr := openaiWithStatusErr.Error
err = errors.New(fmt.Sprintf("type %s, httpCode %d, code %v, message %s", oaiErr.Type, openaiWithStatusErr.StatusCode, oaiErr.Code, oaiErr.Message))
ban = service.ShouldDisableChannel(channel.Type, openaiWithStatusErr)
}
// parse *int to bool
if channel.AutoBan != nil && *channel.AutoBan == 0 {
ban = false
}
if openaiErr != nil {
openAiErrWithStatus := dto.OpenAIErrorWithStatusCode{
StatusCode: -1,
Error: *openaiErr,
LocalError: false,
}
if isChannelEnabled && service.ShouldDisableChannel(channel.Type, &openAiErrWithStatus) && ban {
service.DisableChannel(channel.Id, channel.Name, err.Error())
}
if !isChannelEnabled && service.ShouldEnableChannel(err, openaiErr, channel.Status) {
service.EnableChannel(channel.Id, channel.Name)
}
// disable channel
if ban && isChannelEnabled {
service.DisableChannel(channel.Id, channel.Name, err.Error())
}
// enable channel
if !isChannelEnabled && service.ShouldEnableChannel(err, openaiWithStatusErr, channel.Status) {
service.EnableChannel(channel.Id, channel.Name)
}
channel.UpdateResponseTime(milliseconds)
time.Sleep(common.RequestInterval)
}

View File

@@ -146,7 +146,7 @@ func UpdateMidjourneyTaskBulk() {
buttonStr, _ := json.Marshal(responseItem.Buttons)
task.Buttons = string(buttonStr)
}
shouldReturnQuota := false
if (task.Progress != "100%" && responseItem.FailReason != "") || (task.Progress == "100%" && task.Status == "FAILURE") {
common.LogInfo(ctx, task.MjId+" 构建失败,"+task.FailReason)
task.Progress = "100%"
@@ -154,20 +154,23 @@ func UpdateMidjourneyTaskBulk() {
if err != nil {
common.LogError(ctx, "error update user quota cache: "+err.Error())
} else {
quota := task.Quota
if quota != 0 {
err = model.IncreaseUserQuota(task.UserId, quota)
if err != nil {
common.LogError(ctx, "fail to increase user quota: "+err.Error())
}
logContent := fmt.Sprintf("构图失败 %s补偿 %s", task.MjId, common.LogQuota(quota))
model.RecordLog(task.UserId, model.LogTypeSystem, logContent)
if task.Quota != 0 {
shouldReturnQuota = true
}
}
}
err = task.Update()
if err != nil {
common.LogError(ctx, "UpdateMidjourneyTask task error: "+err.Error())
} else {
if shouldReturnQuota {
err = model.IncreaseUserQuota(task.UserId, task.Quota)
if err != nil {
common.LogError(ctx, "fail to increase user quota: "+err.Error())
}
logContent := fmt.Sprintf("构图失败 %s补偿 %s", task.MjId, common.LogQuota(task.Quota))
model.RecordLog(task.UserId, model.LogTypeSystem, logContent)
}
}
}
}

View File

@@ -48,8 +48,8 @@ type StreamOptions struct {
IncludeUsage bool `json:"include_usage,omitempty"`
}
func (r GeneralOpenAIRequest) GetMaxTokens() int64 {
return int64(r.MaxTokens)
func (r GeneralOpenAIRequest) GetMaxTokens() int {
return int(r.MaxTokens)
}
func (r GeneralOpenAIRequest) ParseInput() []string {

View File

@@ -66,10 +66,6 @@ type ChatCompletionsStreamResponseChoiceDelta struct {
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
}
func (c *ChatCompletionsStreamResponseChoiceDelta) IsEmpty() bool {
return c.Content == nil && len(c.ToolCalls) == 0
}
func (c *ChatCompletionsStreamResponseChoiceDelta) SetContentString(s string) {
c.Content = &s
}
@@ -105,6 +101,17 @@ type ChatCompletionsStreamResponse struct {
Usage *Usage `json:"usage"`
}
func (c *ChatCompletionsStreamResponse) GetSystemFingerprint() string {
if c.SystemFingerprint == nil {
return ""
}
return *c.SystemFingerprint
}
func (c *ChatCompletionsStreamResponse) SetSystemFingerprint(s string) {
c.SystemFingerprint = &s
}
type ChatCompletionsStreamResponseSimple struct {
Choices []ChatCompletionsStreamResponseChoice `json:"choices"`
Usage *Usage `json:"usage"`

View File

@@ -198,11 +198,11 @@ func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, mode
c.Set("api_version", channel.Other)
case common.ChannelTypeXunfei:
c.Set("api_version", channel.Other)
//case common.ChannelTypeAIProxyLibrary:
// c.Set("library_id", channel.Other)
case common.ChannelTypeGemini:
c.Set("api_version", channel.Other)
case common.ChannelTypeAli:
c.Set("plugin", channel.Other)
case common.ChannelCloudflare:
c.Set("api_version", channel.Other)
}
}

View File

@@ -295,15 +295,11 @@ func (user *User) ValidateAndFill() (err error) {
// that means if your fields value is 0, '', false or other zero values,
// it wont be used to build query conditions
password := user.Password
if (user.Username == "" && user.Email == "") || password == "" {
if user.Username == "" || password == "" {
return errors.New("用户名或密码为空")
}
// find buy username or email
if user.Username != "" {
DB.Where(User{Username: user.Username}).First(user)
} else if user.Email != "" {
DB.Where(User{Email: user.Email}).First(user)
}
DB.Where("username = ? OR email = ?", user.Username, user.Username).First(user)
okay := common.ValidatePasswordAndHash(password, user.Password)
if !okay || user.Status != common.UserStatusEnabled {
return errors.New("用户名或密码错误,或用户已被封禁")

View File

@@ -14,7 +14,7 @@ type Adaptor interface {
InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest)
GetRequestURL(info *relaycommon.RelayInfo) (string, error)
SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error
ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error)
ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error)
ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error)
DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error)
DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode)

View File

@@ -42,11 +42,11 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *re
return nil
}
func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) {
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
if request == nil {
return nil, errors.New("request is nil")
}
switch relayMode {
switch info.RelayMode {
case constant.RelayModeEmbeddings:
baiduEmbeddingRequest := embeddingRequestOpenAI2Ali(*request)
return baiduEmbeddingRequest, nil

View File

@@ -41,7 +41,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *re
return nil
}
func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) {
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
if request == nil {
return nil, errors.New("request is nil")
}

View File

@@ -99,11 +99,11 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *re
return nil
}
func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) {
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
if request == nil {
return nil, errors.New("request is nil")
}
switch relayMode {
switch info.RelayMode {
case constant.RelayModeEmbeddings:
baiduEmbeddingRequest := embeddingRequestOpenAI2Baidu(*request)
return baiduEmbeddingRequest, nil

View File

@@ -53,7 +53,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *re
return nil
}
func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) {
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
if request == nil {
return nil, errors.New("request is nil")
}

View File

@@ -72,6 +72,19 @@ func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*ClaudeR
if claudeRequest.MaxTokens == 0 {
claudeRequest.MaxTokens = 4096
}
if textRequest.Stop != nil {
// stop maybe string/array string, convert to array string
switch textRequest.Stop.(type) {
case string:
claudeRequest.StopSequences = []string{textRequest.Stop.(string)}
case []interface{}:
stopSequences := make([]string, 0)
for _, stop := range textRequest.Stop.([]interface{}) {
stopSequences = append(stopSequences, stop.(string))
}
claudeRequest.StopSequences = stopSequences
}
}
formatMessages := make([]dto.Message, 0)
var lastMessage *dto.Message
for i, message := range textRequest.Messages {

View File

@@ -0,0 +1,76 @@
package cloudflare
import (
"errors"
"fmt"
"github.com/gin-gonic/gin"
"io"
"net/http"
"one-api/dto"
"one-api/relay/channel"
relaycommon "one-api/relay/common"
"one-api/relay/constant"
)
type Adaptor struct {
}
func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) {
}
func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
}
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
switch info.RelayMode {
case constant.RelayModeChatCompletions:
return fmt.Sprintf("%s/client/v4/accounts/%s/ai/v1/chat/completions", info.BaseUrl, info.ApiVersion), nil
case constant.RelayModeEmbeddings:
return fmt.Sprintf("%s/client/v4/accounts/%s/ai/v1/embeddings", info.BaseUrl, info.ApiVersion), nil
default:
return fmt.Sprintf("%s/client/v4/accounts/%s/ai/run/%s", info.BaseUrl, info.ApiVersion, info.UpstreamModelName), nil
}
}
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
channel.SetupApiRequestHeader(info, c, req)
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", info.ApiKey))
return nil
}
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
if request == nil {
return nil, errors.New("request is nil")
}
switch info.RelayMode {
case constant.RelayModeCompletions:
return convertCf2CompletionsRequest(*request), nil
default:
return request, nil
}
}
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
return channel.DoApiRequest(a, c, info, requestBody)
}
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
return request, nil
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
if info.IsStream {
err, usage = cfStreamHandler(c, resp, info)
} else {
err, usage = cfHandler(c, resp, info)
}
return
}
func (a *Adaptor) GetModelList() []string {
return ModelList
}
func (a *Adaptor) GetChannelName() string {
return ChannelName
}

View File

@@ -0,0 +1,38 @@
package cloudflare
var ModelList = []string{
"@cf/meta/llama-2-7b-chat-fp16",
"@cf/meta/llama-2-7b-chat-int8",
"@cf/mistral/mistral-7b-instruct-v0.1",
"@hf/thebloke/deepseek-coder-6.7b-base-awq",
"@hf/thebloke/deepseek-coder-6.7b-instruct-awq",
"@cf/deepseek-ai/deepseek-math-7b-base",
"@cf/deepseek-ai/deepseek-math-7b-instruct",
"@cf/thebloke/discolm-german-7b-v1-awq",
"@cf/tiiuae/falcon-7b-instruct",
"@cf/google/gemma-2b-it-lora",
"@hf/google/gemma-7b-it",
"@cf/google/gemma-7b-it-lora",
"@hf/nousresearch/hermes-2-pro-mistral-7b",
"@hf/thebloke/llama-2-13b-chat-awq",
"@cf/meta-llama/llama-2-7b-chat-hf-lora",
"@cf/meta/llama-3-8b-instruct",
"@hf/thebloke/llamaguard-7b-awq",
"@hf/thebloke/mistral-7b-instruct-v0.1-awq",
"@hf/mistralai/mistral-7b-instruct-v0.2",
"@cf/mistral/mistral-7b-instruct-v0.2-lora",
"@hf/thebloke/neural-chat-7b-v3-1-awq",
"@cf/openchat/openchat-3.5-0106",
"@hf/thebloke/openhermes-2.5-mistral-7b-awq",
"@cf/microsoft/phi-2",
"@cf/qwen/qwen1.5-0.5b-chat",
"@cf/qwen/qwen1.5-1.8b-chat",
"@cf/qwen/qwen1.5-14b-chat-awq",
"@cf/qwen/qwen1.5-7b-chat-awq",
"@cf/defog/sqlcoder-7b-2",
"@hf/nexusflow/starling-lm-7b-beta",
"@cf/tinyllama/tinyllama-1.1b-chat-v1.0",
"@hf/thebloke/zephyr-7b-beta-awq",
}
var ChannelName = "cloudflare"

View File

@@ -0,0 +1,13 @@
package cloudflare
import "one-api/dto"
type CfRequest struct {
Messages []dto.Message `json:"messages,omitempty"`
Lora string `json:"lora,omitempty"`
MaxTokens int `json:"max_tokens,omitempty"`
Prompt string `json:"prompt,omitempty"`
Raw bool `json:"raw,omitempty"`
Stream bool `json:"stream,omitempty"`
Temperature float64 `json:"temperature,omitempty"`
}

View File

@@ -0,0 +1,121 @@
package cloudflare
import (
"bufio"
"encoding/json"
"github.com/gin-gonic/gin"
"io"
"net/http"
"one-api/common"
"one-api/dto"
relaycommon "one-api/relay/common"
"one-api/service"
"strings"
"time"
)
func convertCf2CompletionsRequest(textRequest dto.GeneralOpenAIRequest) *CfRequest {
p, _ := textRequest.Prompt.(string)
return &CfRequest{
Prompt: p,
MaxTokens: textRequest.GetMaxTokens(),
Stream: textRequest.Stream,
Temperature: textRequest.Temperature,
}
}
func cfStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
scanner := bufio.NewScanner(resp.Body)
scanner.Split(bufio.ScanLines)
service.SetEventStreamHeaders(c)
id := service.GetResponseID(c)
var responseText string
isFirst := true
for scanner.Scan() {
data := scanner.Text()
if len(data) < len("data: ") {
continue
}
data = strings.TrimPrefix(data, "data: ")
data = strings.TrimSuffix(data, "\r")
if data == "[DONE]" {
break
}
var response dto.ChatCompletionsStreamResponse
err := json.Unmarshal([]byte(data), &response)
if err != nil {
common.LogError(c, "error_unmarshalling_stream_response: "+err.Error())
continue
}
for _, choice := range response.Choices {
choice.Delta.Role = "assistant"
responseText += choice.Delta.GetContentString()
}
response.Id = id
response.Model = info.UpstreamModelName
err = service.ObjectData(c, response)
if isFirst {
isFirst = false
info.FirstResponseTime = time.Now()
}
if err != nil {
common.LogError(c, "error_rendering_stream_response: "+err.Error())
}
}
if err := scanner.Err(); err != nil {
common.LogError(c, "error_scanning_stream_response: "+err.Error())
}
usage, _ := service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
if info.ShouldIncludeUsage {
response := service.GenerateFinalUsageResponse(id, info.StartTime.Unix(), info.UpstreamModelName, *usage)
err := service.ObjectData(c, response)
if err != nil {
common.LogError(c, "error_rendering_final_usage_response: "+err.Error())
}
}
service.Done(c)
err := resp.Body.Close()
if err != nil {
common.LogError(c, "close_response_body_failed: "+err.Error())
}
return nil, usage
}
func cfHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
}
err = resp.Body.Close()
if err != nil {
return service.OpenAIErrorWrapperLocal(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
var response dto.TextResponse
err = json.Unmarshal(responseBody, &response)
if err != nil {
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
}
response.Model = info.UpstreamModelName
var responseText string
for _, choice := range response.Choices {
responseText += choice.Message.StringContent()
}
usage, _ := service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
response.Usage = *usage
response.Id = service.GetResponseID(c)
jsonResponse, err := json.Marshal(response)
if err != nil {
return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
}
c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode)
_, _ = c.Writer.Write(jsonResponse)
return nil, usage
}

View File

@@ -34,7 +34,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *re
return nil
}
func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) {
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
return requestOpenAI2Cohere(*request), nil
}

View File

@@ -7,7 +7,7 @@ type CohereRequest struct {
ChatHistory []ChatHistory `json:"chat_history"`
Message string `json:"message"`
Stream bool `json:"stream"`
MaxTokens int64 `json:"max_tokens"`
MaxTokens int `json:"max_tokens"`
}
type ChatHistory struct {

View File

@@ -32,7 +32,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *re
return nil
}
func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) {
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
if request == nil {
return nil, errors.New("request is nil")
}

View File

@@ -21,7 +21,7 @@ type DifyData struct {
type DifyChatCompletionResponse struct {
ConversationId string `json:"conversation_id"`
Answers string `json:"answers"`
Answer string `json:"answer"`
CreateAt int64 `json:"create_at"`
MetaData DifyMetaData `json:"metadata"`
}

View File

@@ -117,6 +117,7 @@ func difyStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Re
func difyHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
var difyResponse DifyChatCompletionResponse
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
}
@@ -134,7 +135,7 @@ func difyHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInf
Created: common.GetTimestamp(),
Usage: difyResponse.MetaData.Usage,
}
content, _ := json.Marshal(difyResponse.Answers)
content, _ := json.Marshal(difyResponse.Answer)
choice := dto.OpenAITextResponseChoice{
Index: 0,
Message: dto.Message{

View File

@@ -9,7 +9,6 @@ import (
"one-api/dto"
"one-api/relay/channel"
relaycommon "one-api/relay/common"
"one-api/service"
)
type Adaptor struct {
@@ -52,7 +51,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *re
return nil
}
func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) {
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
if request == nil {
return nil, errors.New("request is nil")
}
@@ -69,9 +68,7 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
if info.IsStream {
var responseText string
err, responseText = geminiChatStreamHandler(c, resp, info)
usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
err, usage = geminiChatStreamHandler(c, resp, info)
} else {
err, usage = geminiChatHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
}

View File

@@ -59,4 +59,11 @@ type GeminiChatPromptFeedback struct {
type GeminiChatResponse struct {
Candidates []GeminiChatCandidate `json:"candidates"`
PromptFeedback GeminiChatPromptFeedback `json:"promptFeedback"`
UsageMetadata GeminiUsageMetadata `json:"usageMetadata"`
}
type GeminiUsageMetadata struct {
PromptTokenCount int `json:"promptTokenCount"`
CandidatesTokenCount int `json:"candidatesTokenCount"`
TotalTokenCount int `json:"totalTokenCount"`
}

View File

@@ -5,6 +5,7 @@ import (
"encoding/json"
"fmt"
"io"
"log"
"net/http"
"one-api/common"
"one-api/constant"
@@ -162,8 +163,12 @@ func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) *dto.Ch
return &response
}
func geminiChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, string) {
func geminiChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
responseText := ""
responseJson := ""
id := fmt.Sprintf("chatcmpl-%s", common.GetUUID())
createAt := common.GetTimestamp()
var usage = &dto.Usage{}
dataChan := make(chan string, 5)
stopChan := make(chan bool, 2)
scanner := bufio.NewScanner(resp.Body)
@@ -182,6 +187,7 @@ func geminiChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycom
go func() {
for scanner.Scan() {
data := scanner.Text()
responseJson += data
data = strings.TrimSpace(data)
if !strings.HasPrefix(data, "\"text\": \"") {
continue
@@ -216,10 +222,10 @@ func geminiChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycom
var choice dto.ChatCompletionsStreamResponseChoice
choice.Delta.SetContentString(dummy.Content)
response := dto.ChatCompletionsStreamResponse{
Id: fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
Id: id,
Object: "chat.completion.chunk",
Created: common.GetTimestamp(),
Model: "gemini-pro",
Created: createAt,
Model: info.UpstreamModelName,
Choices: []dto.ChatCompletionsStreamResponseChoice{choice},
}
jsonResponse, err := json.Marshal(response)
@@ -230,15 +236,34 @@ func geminiChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycom
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
return true
case <-stopChan:
c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
return false
}
})
err := resp.Body.Close()
var geminiChatResponses []GeminiChatResponse
err := json.Unmarshal([]byte(responseJson), &geminiChatResponses)
if err != nil {
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), ""
log.Printf("cannot get gemini usage: %s", err.Error())
usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
} else {
for _, response := range geminiChatResponses {
usage.PromptTokens = response.UsageMetadata.PromptTokenCount
usage.CompletionTokens = response.UsageMetadata.CandidatesTokenCount
}
usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
}
return nil, responseText
if info.ShouldIncludeUsage {
response := service.GenerateFinalUsageResponse(id, createAt, info.UpstreamModelName, *usage)
err := service.ObjectData(c, response)
if err != nil {
common.SysError("send final response failed: " + err.Error())
}
}
service.Done(c)
err = resp.Body.Close()
if err != nil {
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), usage
}
return nil, usage
}
func geminiChatHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
@@ -267,11 +292,10 @@ func geminiChatHandler(c *gin.Context, resp *http.Response, promptTokens int, mo
}, nil
}
fullTextResponse := responseGeminiChat2OpenAI(&geminiResponse)
completionTokens, _ := service.CountTokenText(geminiResponse.GetResponseText(), model)
usage := dto.Usage{
PromptTokens: promptTokens,
CompletionTokens: completionTokens,
TotalTokens: promptTokens + completionTokens,
PromptTokens: geminiResponse.UsageMetadata.PromptTokenCount,
CompletionTokens: geminiResponse.UsageMetadata.CandidatesTokenCount,
TotalTokens: geminiResponse.UsageMetadata.TotalTokenCount,
}
fullTextResponse.Usage = usage
jsonResponse, err := json.Marshal(fullTextResponse)

View File

@@ -36,7 +36,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *re
return nil
}
func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) {
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
return request, nil
}

View File

@@ -10,7 +10,6 @@ import (
"one-api/relay/channel/openai"
relaycommon "one-api/relay/common"
relayconstant "one-api/relay/constant"
"one-api/service"
)
type Adaptor struct {
@@ -36,11 +35,11 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *re
return nil
}
func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) {
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
if request == nil {
return nil, errors.New("request is nil")
}
switch relayMode {
switch info.RelayMode {
case relayconstant.RelayModeEmbeddings:
return requestOpenAI2Embeddings(*request), nil
default:
@@ -58,11 +57,7 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
if info.IsStream {
var responseText string
err, usage, responseText, _ = openai.OpenaiStreamHandler(c, resp, info)
if usage == nil || usage.TotalTokens == 0 || (usage.PromptTokens+usage.CompletionTokens) == 0 {
usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
}
err, usage = openai.OpenaiStreamHandler(c, resp, info)
} else {
if info.RelayMode == relayconstant.RelayModeEmbeddings {
err, usage = ollamaEmbeddingHandler(c, resp, info.PromptTokens, info.UpstreamModelName, info.RelayMode)

View File

@@ -14,7 +14,6 @@ import (
"one-api/relay/channel/minimax"
"one-api/relay/channel/moonshot"
relaycommon "one-api/relay/common"
"one-api/service"
"strings"
)
@@ -74,10 +73,13 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *re
return nil
}
func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) {
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
if request == nil {
return nil, errors.New("request is nil")
}
if info.ChannelType != common.ChannelTypeOpenAI {
request.StreamOptions = nil
}
return request, nil
}
@@ -87,13 +89,7 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
if info.IsStream {
var responseText string
var toolCount int
err, usage, responseText, toolCount = OpenaiStreamHandler(c, resp, info)
if usage == nil || usage.TotalTokens == 0 || (usage.PromptTokens+usage.CompletionTokens) == 0 {
usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
usage.CompletionTokens += toolCount * 7
}
err, usage = OpenaiStreamHandler(c, resp, info)
} else {
err, usage = OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
}

View File

@@ -14,38 +14,36 @@ import (
relayconstant "one-api/relay/constant"
"one-api/service"
"strings"
"sync"
"time"
)
func OpenaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage, string, int) {
//checkSensitive := constant.ShouldCheckCompletionSensitive()
func OpenaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
hasStreamUsage := false
responseId := ""
var createAt int64 = 0
var systemFingerprint string
model := info.UpstreamModelName
var responseTextBuilder strings.Builder
var usage dto.Usage
var usage = &dto.Usage{}
var streamItems []string // store stream items
toolCount := 0
scanner := bufio.NewScanner(resp.Body)
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
if atEOF && len(data) == 0 {
return 0, nil, nil
}
if i := strings.Index(string(data), "\n"); i >= 0 {
return i + 1, data[0:i], nil
}
if atEOF {
return len(data), data, nil
}
return 0, nil, nil
})
dataChan := make(chan string, 5)
stopChan := make(chan bool, 2)
scanner.Split(bufio.ScanLines)
service.SetEventStreamHeaders(c)
ticker := time.NewTicker(time.Duration(constant.StreamingTimeout) * time.Second)
defer ticker.Stop()
stopChan := make(chan bool)
defer close(stopChan)
defer close(dataChan)
var wg sync.WaitGroup
go func() {
wg.Add(1)
defer wg.Done()
var streamItems []string // store stream items
for scanner.Scan() {
info.SetFirstResponseTime()
ticker.Reset(time.Duration(constant.StreamingTimeout) * time.Second)
data := scanner.Text()
if len(data) < 6 { // ignore blank line or wrong format
continue
@@ -53,58 +51,43 @@ func OpenaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.
if data[:6] != "data: " && data[:6] != "[DONE]" {
continue
}
if !common.SafeSendStringTimeout(dataChan, data, constant.StreamingTimeout) {
// send data timeout, stop the stream
common.LogError(c, "send data timeout, stop the stream")
break
}
data = data[6:]
if !strings.HasPrefix(data, "[DONE]") {
service.StringData(c, data)
streamItems = append(streamItems, data)
}
}
// 计算token
streamResp := "[" + strings.Join(streamItems, ",") + "]"
switch info.RelayMode {
case relayconstant.RelayModeChatCompletions:
var streamResponses []dto.ChatCompletionsStreamResponseSimple
err := json.Unmarshal(common.StringToByteSlice(streamResp), &streamResponses)
if err != nil {
// 一次性解析失败,逐个解析
common.SysError("error unmarshalling stream response: " + err.Error())
for _, item := range streamItems {
var streamResponse dto.ChatCompletionsStreamResponseSimple
err := json.Unmarshal(common.StringToByteSlice(item), &streamResponse)
if err == nil {
if streamResponse.Usage != nil {
if streamResponse.Usage.TotalTokens != 0 {
usage.PromptTokens += streamResponse.Usage.PromptTokens
usage.CompletionTokens += streamResponse.Usage.CompletionTokens
usage.TotalTokens += streamResponse.Usage.TotalTokens
}
}
for _, choice := range streamResponse.Choices {
responseTextBuilder.WriteString(choice.Delta.GetContentString())
if choice.Delta.ToolCalls != nil {
if len(choice.Delta.ToolCalls) > toolCount {
toolCount = len(choice.Delta.ToolCalls)
}
for _, tool := range choice.Delta.ToolCalls {
responseTextBuilder.WriteString(tool.Function.Name)
responseTextBuilder.WriteString(tool.Function.Arguments)
}
}
}
}
}
} else {
for _, streamResponse := range streamResponses {
if streamResponse.Usage != nil {
if streamResponse.Usage.TotalTokens != 0 {
usage.PromptTokens += streamResponse.Usage.PromptTokens
usage.CompletionTokens += streamResponse.Usage.CompletionTokens
usage.TotalTokens += streamResponse.Usage.TotalTokens
}
common.SafeSendBool(stopChan, true)
}()
select {
case <-ticker.C:
// 超时处理逻辑
common.LogError(c, "streaming timeout")
case <-stopChan:
// 正常结束
}
// 计算token
streamResp := "[" + strings.Join(streamItems, ",") + "]"
switch info.RelayMode {
case relayconstant.RelayModeChatCompletions:
var streamResponses []dto.ChatCompletionsStreamResponse
err := json.Unmarshal(common.StringToByteSlice(streamResp), &streamResponses)
if err != nil {
// 一次性解析失败,逐个解析
common.SysError("error unmarshalling stream response: " + err.Error())
for _, item := range streamItems {
var streamResponse dto.ChatCompletionsStreamResponse
err := json.Unmarshal(common.StringToByteSlice(item), &streamResponse)
if err == nil {
responseId = streamResponse.Id
createAt = streamResponse.Created
systemFingerprint = streamResponse.GetSystemFingerprint()
model = streamResponse.Model
if service.ValidUsage(streamResponse.Usage) {
usage = streamResponse.Usage
hasStreamUsage = true
}
for _, choice := range streamResponse.Choices {
responseTextBuilder.WriteString(choice.Delta.GetContentString())
@@ -120,67 +103,72 @@ func OpenaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.
}
}
}
case relayconstant.RelayModeCompletions:
var streamResponses []dto.CompletionsStreamResponse
err := json.Unmarshal(common.StringToByteSlice(streamResp), &streamResponses)
if err != nil {
// 一次性解析失败,逐个解析
common.SysError("error unmarshalling stream response: " + err.Error())
for _, item := range streamItems {
var streamResponse dto.CompletionsStreamResponse
err := json.Unmarshal(common.StringToByteSlice(item), &streamResponse)
if err == nil {
for _, choice := range streamResponse.Choices {
responseTextBuilder.WriteString(choice.Text)
} else {
for _, streamResponse := range streamResponses {
responseId = streamResponse.Id
createAt = streamResponse.Created
systemFingerprint = streamResponse.GetSystemFingerprint()
model = streamResponse.Model
if service.ValidUsage(streamResponse.Usage) {
usage = streamResponse.Usage
hasStreamUsage = true
}
for _, choice := range streamResponse.Choices {
responseTextBuilder.WriteString(choice.Delta.GetContentString())
if choice.Delta.ToolCalls != nil {
if len(choice.Delta.ToolCalls) > toolCount {
toolCount = len(choice.Delta.ToolCalls)
}
for _, tool := range choice.Delta.ToolCalls {
responseTextBuilder.WriteString(tool.Function.Name)
responseTextBuilder.WriteString(tool.Function.Arguments)
}
}
}
} else {
for _, streamResponse := range streamResponses {
}
}
case relayconstant.RelayModeCompletions:
var streamResponses []dto.CompletionsStreamResponse
err := json.Unmarshal(common.StringToByteSlice(streamResp), &streamResponses)
if err != nil {
// 一次性解析失败,逐个解析
common.SysError("error unmarshalling stream response: " + err.Error())
for _, item := range streamItems {
var streamResponse dto.CompletionsStreamResponse
err := json.Unmarshal(common.StringToByteSlice(item), &streamResponse)
if err == nil {
for _, choice := range streamResponse.Choices {
responseTextBuilder.WriteString(choice.Text)
}
}
}
}
if len(dataChan) > 0 {
// wait data out
time.Sleep(2 * time.Second)
}
common.SafeSendBool(stopChan, true)
}()
service.SetEventStreamHeaders(c)
isFirst := true
ticker := time.NewTicker(time.Duration(constant.StreamingTimeout) * time.Second)
defer ticker.Stop()
c.Stream(func(w io.Writer) bool {
select {
case <-ticker.C:
common.LogError(c, "reading data from upstream timeout")
return false
case data := <-dataChan:
if isFirst {
isFirst = false
info.FirstResponseTime = time.Now()
} else {
for _, streamResponse := range streamResponses {
for _, choice := range streamResponse.Choices {
responseTextBuilder.WriteString(choice.Text)
}
}
ticker.Reset(time.Duration(constant.StreamingTimeout) * time.Second)
if strings.HasPrefix(data, "data: [DONE]") {
data = data[:12]
}
// some implementations may add \r at the end of data
data = strings.TrimSuffix(data, "\r")
c.Render(-1, common.CustomEvent{Data: data})
return true
case <-stopChan:
return false
}
})
}
if !hasStreamUsage {
usage, _ = service.ResponseText2Usage(responseTextBuilder.String(), info.UpstreamModelName, info.PromptTokens)
usage.CompletionTokens += toolCount * 7
}
if info.ShouldIncludeUsage && !hasStreamUsage {
response := service.GenerateFinalUsageResponse(responseId, createAt, model, *usage)
response.SetSystemFingerprint(systemFingerprint)
service.ObjectData(c, response)
}
service.Done(c)
err := resp.Body.Close()
if err != nil {
return service.OpenAIErrorWrapperLocal(err, "close_response_body_failed", http.StatusInternalServerError), nil, "", toolCount
common.LogError(c, "close_response_body_failed: "+err.Error())
}
wg.Wait()
return nil, &usage, responseTextBuilder.String(), toolCount
return nil, usage
}
func OpenaiHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {

View File

@@ -33,7 +33,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *re
return nil
}
func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) {
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
if request == nil {
return nil, errors.New("request is nil")
}

View File

@@ -10,7 +10,6 @@ import (
"one-api/relay/channel"
"one-api/relay/channel/openai"
relaycommon "one-api/relay/common"
"one-api/service"
)
type Adaptor struct {
@@ -34,7 +33,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *re
return nil
}
func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) {
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
if request == nil {
return nil, errors.New("request is nil")
}
@@ -54,11 +53,7 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
if info.IsStream {
var responseText string
err, usage, responseText, _ = openai.OpenaiStreamHandler(c, resp, info)
if usage == nil || usage.TotalTokens == 0 || (usage.PromptTokens+usage.CompletionTokens) == 0 {
usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
}
err, usage = openai.OpenaiStreamHandler(c, resp, info)
} else {
err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
}

View File

@@ -17,6 +17,7 @@ import (
type Adaptor struct {
Sign string
AppID int64
Action string
Version string
Timestamp int64
@@ -34,7 +35,7 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIReq
}
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
return fmt.Sprintf("%s/hyllm/v1/chat/completions", info.BaseUrl), nil
return fmt.Sprintf("%s/", info.BaseUrl), nil
}
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
@@ -46,17 +47,18 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *re
return nil
}
func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) {
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
if request == nil {
return nil, errors.New("request is nil")
}
apiKey := c.Request.Header.Get("Authorization")
apiKey = strings.TrimPrefix(apiKey, "Bearer ")
_, secretId, secretKey, err := parseTencentConfig(apiKey)
appId, secretId, secretKey, err := parseTencentConfig(apiKey)
a.AppID = appId
if err != nil {
return nil, err
}
tencentRequest := requestOpenAI2Tencent(*request)
tencentRequest := requestOpenAI2Tencent(a, *request)
// we have to calculate the sign here
a.Sign = getTencentSign(*tencentRequest, a, secretId, secretKey)
return tencentRequest, nil

View File

@@ -30,17 +30,17 @@ type TencentChatRequest struct {
//
// 注意:
// 通过 SDK 调用时,流式和非流式调用需用**不同的方式**获取返回值,具体参考 SDK 中的注释或示例(在各语言 SDK 代码仓库的 examples/hunyuan/v20230901/ 目录中)。
Stream *bool `json:"Stream"`
Stream *bool `json:"Stream,omitempty"`
// 说明:
// 1. 影响输出文本的多样性,取值越大,生成文本的多样性越强。
// 2. 取值区间为 [0.0, 1.0],未传值时使用各模型推荐值。
// 3. 非必要不建议使用,不合理的取值会影响效果。
TopP *float64 `json:"TopP"`
TopP *float64 `json:"TopP,omitempty"`
// 说明:
// 1. 较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定。
// 2. 取值区间为 [0.0, 2.0],未传值时使用各模型推荐值。
// 3. 非必要不建议使用,不合理的取值会影响效果。
Temperature *float64 `json:"Temperature"`
Temperature *float64 `json:"Temperature,omitempty"`
}
type TencentError struct {
@@ -69,3 +69,7 @@ type TencentChatResponse struct {
Note string `json:"Note,omitempty"` // 注释
ReqID string `json:"Req_id,omitempty"` // 唯一请求 Id每次请求都会返回。用于反馈接口入参
}
type TencentChatResponseSB struct {
Response TencentChatResponse `json:"Response,omitempty"`
}

View File

@@ -22,7 +22,7 @@ import (
// https://cloud.tencent.com/document/product/1729/97732
func requestOpenAI2Tencent(request dto.GeneralOpenAIRequest) *TencentChatRequest {
func requestOpenAI2Tencent(a *Adaptor, request dto.GeneralOpenAIRequest) *TencentChatRequest {
messages := make([]*TencentMessage, 0, len(request.Messages))
for i := 0; i < len(request.Messages); i++ {
message := request.Messages[i]
@@ -31,17 +31,23 @@ func requestOpenAI2Tencent(request dto.GeneralOpenAIRequest) *TencentChatRequest
Role: message.Role,
})
}
return &TencentChatRequest{
Temperature: &request.Temperature,
TopP: &request.TopP,
Stream: &request.Stream,
Messages: messages,
Model: &request.Model,
var req = TencentChatRequest{
Stream: &request.Stream,
Messages: messages,
Model: &request.Model,
}
if request.TopP != 0 {
req.TopP = &request.TopP
}
if request.Temperature != 0 {
req.Temperature = &request.Temperature
}
return &req
}
func responseTencent2OpenAI(response *TencentChatResponse) *dto.OpenAITextResponse {
fullTextResponse := dto.OpenAITextResponse{
Id: response.Id,
Object: "chat.completion",
Created: common.GetTimestamp(),
Usage: dto.Usage{
@@ -129,7 +135,7 @@ func tencentStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIError
}
func tencentHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
var TencentResponse TencentChatResponse
var tencentSb TencentChatResponseSB
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
@@ -138,20 +144,20 @@ func tencentHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithSt
if err != nil {
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
err = json.Unmarshal(responseBody, &TencentResponse)
err = json.Unmarshal(responseBody, &tencentSb)
if err != nil {
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
}
if TencentResponse.Error.Code != 0 {
if tencentSb.Response.Error.Code != 0 {
return &dto.OpenAIErrorWithStatusCode{
Error: dto.OpenAIError{
Message: TencentResponse.Error.Message,
Code: TencentResponse.Error.Code,
Message: tencentSb.Response.Error.Message,
Code: tencentSb.Response.Error.Code,
},
StatusCode: resp.StatusCode,
}, nil
}
fullTextResponse := responseTencent2OpenAI(&TencentResponse)
fullTextResponse := responseTencent2OpenAI(&tencentSb.Response)
jsonResponse, err := json.Marshal(fullTextResponse)
if err != nil {
return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil

View File

@@ -33,7 +33,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *re
return nil
}
func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) {
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
if request == nil {
return nil, errors.New("request is nil")
}

View File

@@ -37,7 +37,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *re
return nil
}
func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) {
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
if request == nil {
return nil, errors.New("request is nil")
}

View File

@@ -10,7 +10,6 @@ import (
"one-api/relay/channel"
"one-api/relay/channel/openai"
relaycommon "one-api/relay/common"
"one-api/service"
)
type Adaptor struct {
@@ -35,7 +34,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *re
return nil
}
func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) {
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
if request == nil {
return nil, errors.New("request is nil")
}
@@ -55,13 +54,7 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
if info.IsStream {
var responseText string
var toolCount int
err, usage, responseText, toolCount = openai.OpenaiStreamHandler(c, resp, info)
if usage == nil || usage.TotalTokens == 0 || (usage.PromptTokens+usage.CompletionTokens) == 0 {
usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
usage.CompletionTokens += toolCount * 7
}
err, usage = openai.OpenaiStreamHandler(c, resp, info)
} else {
err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
}

View File

@@ -1,7 +1,7 @@
package zhipu_4v
var ModelList = []string{
"glm-4", "glm-4v", "glm-3-turbo",
"glm-4", "glm-4v", "glm-3-turbo", "glm-4-alltools",
}
var ChannelName = "zhipu_4v"

View File

@@ -17,6 +17,7 @@ type RelayInfo struct {
TokenUnlimited bool
StartTime time.Time
FirstResponseTime time.Time
setFirstResponse bool
ApiType int
IsStream bool
RelayMode int
@@ -67,7 +68,9 @@ func GenRelayInfo(c *gin.Context) *RelayInfo {
if info.ChannelType == common.ChannelTypeAzure {
info.ApiVersion = GetAPIVersion(c)
}
if info.ChannelType == common.ChannelTypeOpenAI || info.ChannelType == common.ChannelTypeAnthropic || info.ChannelType == common.ChannelTypeAws {
if info.ChannelType == common.ChannelTypeOpenAI || info.ChannelType == common.ChannelTypeAnthropic ||
info.ChannelType == common.ChannelTypeAws || info.ChannelType == common.ChannelTypeGemini ||
info.ChannelType == common.ChannelCloudflare {
info.SupportStreamOptions = true
}
return info
@@ -81,6 +84,13 @@ func (info *RelayInfo) SetIsStream(isStream bool) {
info.IsStream = isStream
}
func (info *RelayInfo) SetFirstResponseTime() {
if !info.setFirstResponse {
info.FirstResponseTime = time.Now()
info.setFirstResponse = true
}
}
type TaskRelayInfo struct {
ChannelType int
ChannelId int

View File

@@ -38,6 +38,7 @@ func RelayErrorHandler(resp *http.Response) (OpenAIErrorWithStatusCode *dto.Open
var textResponse dto.TextResponseWithError
err = json.Unmarshal(responseBody, &textResponse)
if err != nil {
OpenAIErrorWithStatusCode.Error.Message = fmt.Sprintf("error unmarshalling response body: %s", responseBody)
return
}
OpenAIErrorWithStatusCode.Error = textResponse.Error

View File

@@ -22,6 +22,7 @@ const (
APITypeCohere
APITypeDify
APITypeJina
APITypeCloudflare
APITypeDummy // this one is only for count, do not add any channel after this
)
@@ -63,6 +64,8 @@ func ChannelType2APIType(channelType int) (int, bool) {
apiType = APITypeDify
case common.ChannelTypeJina:
apiType = APITypeJina
case common.ChannelCloudflare:
apiType = APITypeCloudflare
}
if apiType == -1 {
return APITypeOpenAI, false

View File

@@ -153,7 +153,7 @@ func TextHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
adaptor.Init(relayInfo, *textRequest)
var requestBody io.Reader
convertedRequest, err := adaptor.ConvertRequest(c, relayInfo.RelayMode, textRequest)
convertedRequest, err := adaptor.ConvertRequest(c, relayInfo, textRequest)
if err != nil {
return service.OpenAIErrorWrapperLocal(err, "convert_request_failed", http.StatusInternalServerError)
}

View File

@@ -7,6 +7,7 @@ import (
"one-api/relay/channel/aws"
"one-api/relay/channel/baidu"
"one-api/relay/channel/claude"
"one-api/relay/channel/cloudflare"
"one-api/relay/channel/cohere"
"one-api/relay/channel/dify"
"one-api/relay/channel/gemini"
@@ -59,6 +60,8 @@ func GetAdaptor(apiType int) channel.Adaptor {
return &dify.Adaptor{}
case constant.APITypeJina:
return &jina.Adaptor{}
case constant.APITypeCloudflare:
return &cloudflare.Adaptor{}
}
return nil
}

View File

@@ -74,14 +74,14 @@ func ShouldDisableChannel(channelType int, err *relaymodel.OpenAIErrorWithStatus
return false
}
func ShouldEnableChannel(err error, openAIErr *relaymodel.OpenAIError, status int) bool {
func ShouldEnableChannel(err error, openaiWithStatusErr *relaymodel.OpenAIErrorWithStatusCode, status int) bool {
if !common.AutomaticEnableChannelEnabled {
return false
}
if err != nil {
return false
}
if openAIErr != nil {
if openaiWithStatusErr != nil {
return false
}
if status != common.ChannelStatusAutoDisabled {

View File

@@ -35,3 +35,8 @@ func ObjectData(c *gin.Context, object interface{}) error {
func Done(c *gin.Context) {
StringData(c, "[DONE]")
}
func GetResponseID(c *gin.Context) string {
logID := c.GetString("X-Oneapi-Request-Id")
return fmt.Sprintf("chatcmpl-%s", logID)
}

View File

@@ -36,3 +36,7 @@ func GenerateFinalUsageResponse(id string, createAt int64, model string, usage d
Usage: &usage,
}
}
func ValidUsage(usage *dto.Usage) bool {
return usage != nil && (usage.PromptTokens != 0 || usage.CompletionTokens != 0)
}

View File

@@ -367,7 +367,7 @@ const LogsTable = () => {
dataIndex: 'content',
render: (text, record, index) => {
let other = getLogOther(record.other);
if (other == null) {
if (other == null || record.type !== 2) {
return (
<Paragraph
ellipsis={{

View File

@@ -99,6 +99,7 @@ export const CHANNEL_OPTIONS = [
color: 'orange',
label: 'Google PaLM2',
},
{ key: 39, text: 'Cloudflare', value: 39, color: 'grey', label: 'Cloudflare' },
{ key: 25, text: 'Moonshot', value: 25, color: 'green', label: 'Moonshot' },
{ key: 19, text: '360 智脑', value: 19, color: 'blue', label: '360 智脑' },
{ key: 23, text: '腾讯混元', value: 23, color: 'teal', label: '腾讯混元' },

View File

@@ -605,6 +605,24 @@ const EditChannel = (props) => {
/>
</>
)}
{inputs.type === 39 && (
<>
<div style={{ marginTop: 10 }}>
<Typography.Text strong>Account ID</Typography.Text>
</div>
<Input
name='other'
placeholder={
'请输入Account ID例如d6b5da8hk1awo8nap34ube6gh'
}
onChange={(value) => {
handleInputChange('other', value);
}}
value={inputs.other}
autoComplete='new-password'
/>
</>
)}
<div style={{ marginTop: 10 }}>
<Typography.Text strong>模型</Typography.Text>
</div>