Compare commits

...

22 Commits

Author SHA1 Message Date
CalciumIon
7b36a2b885 feat: support cloudflare worker ai 2024-07-13 19:55:22 +08:00
CalciumIon
c88f3741e6 feat: support claude stop_sequences 2024-07-11 18:44:45 +08:00
CalciumIon
4e7e206290 fix: gemini usage (close #354) 2024-07-10 16:01:09 +08:00
CalciumIon
579fc8129e fix: dify (close #355) 2024-07-10 15:36:17 +08:00
CalciumIon
f55f63f412 fix: email login 2024-07-09 21:36:31 +08:00
CalciumIon
0526c85732 feat: update stream options 2024-07-09 21:11:01 +08:00
CalciumIon
b75134ece4 fix: hunyuan 2024-07-08 23:42:16 +08:00
CalciumIon
a075598757 fix: stream options 2024-07-08 21:54:32 +08:00
CalciumIon
a984daa503 feat: update FORCE_STREAM_OPTION default value 2024-07-08 21:41:52 +08:00
CalciumIon
90abe7f27d fix: baidu max_output_tokens (#353) 2024-07-08 19:50:12 +08:00
CalciumIon
bb313eb26f ci: update ci 2024-07-08 19:48:03 +08:00
CalciumIon
02545e4856 fix: baidu max_output_tokens (close #353) 2024-07-08 19:46:45 +08:00
CalciumIon
49cec50908 fix: channel default test model 2024-07-08 17:06:29 +08:00
CalciumIon
4f6710e50c fix: 修复渠道晒筛选后无法展开测试模型 (close #297 #302) 2024-07-08 17:00:10 +08:00
CalciumIon
03b130f2b5 feat: 允许设置是否检测mj任务已完成才可进行action操作 (close #349) 2024-07-08 16:48:10 +08:00
CalciumIon
45b9de9df9 feat: able to use email to login (close #343,#348) 2024-07-08 16:28:56 +08:00
CalciumIon
e062cf32e3 fix: 日志详情 2024-07-08 15:48:28 +08:00
CalciumIon
52debe7572 feat: 完善stream_options 2024-07-08 02:04:21 +08:00
CalciumIon
df6502733c feat: 完善stream_options 2024-07-08 02:00:39 +08:00
CalciumIon
9896ba0a64 feat: support aws stream_options 2024-07-08 01:52:40 +08:00
CalciumIon
e8b93ed6ec feat: support claude stream_options 2024-07-08 01:45:43 +08:00
CalciumIon
b0e234e8f5 feat: support stream_options 2024-07-08 01:27:57 +08:00
49 changed files with 610 additions and 173 deletions

View File

@@ -4,6 +4,7 @@ on:
push:
tags:
- '*'
- '!*-alpha*'
workflow_dispatch:
inputs:
name:

View File

@@ -73,7 +73,7 @@
## 比原版One API多出的配置
- `STREAMING_TIMEOUT`:设置流式一次回复的超时时间,默认为 30 秒
- `DIFY_DEBUG`:设置 Dify 渠道是否输出工作流和节点信息到客户端,默认为 `true` 可选值为 `true``false`
- `FORCE_STREAM_OPTION`覆盖客户端stream_options参数请求上游返回流模式usage目前仅支持 `OpenAI` 渠道类型
## 部署
### 部署要求
- 本地数据库默认SQLiteDocker 部署默认使用 SQLite必须挂载 `/data` 目录到宿主机)

View File

@@ -212,36 +212,37 @@ const (
ChannelTypeSunoAPI = 36
ChannelTypeDify = 37
ChannelTypeJina = 38
ChannelCloudflare = 39
ChannelTypeDummy // this one is only for count, do not add any channel after this
)
var ChannelBaseURLs = []string{
"", // 0
"https://api.openai.com", // 1
"https://oa.api2d.net", // 2
"", // 3
"http://localhost:11434", // 4
"https://api.openai-sb.com", // 5
"https://api.openaimax.com", // 6
"https://api.ohmygpt.com", // 7
"", // 8
"https://api.caipacity.com", // 9
"https://api.aiproxy.io", // 10
"", // 11
"https://api.api2gpt.com", // 12
"https://api.aigc2d.com", // 13
"https://api.anthropic.com", // 14
"https://aip.baidubce.com", // 15
"https://open.bigmodel.cn", // 16
"https://dashscope.aliyuncs.com", // 17
"", // 18
"https://ai.360.cn", // 19
"https://openrouter.ai/api", // 20
"https://api.aiproxy.io", // 21
"https://fastgpt.run/api/openapi", // 22
"https://hunyuan.cloud.tencent.com", //23
"", // 0
"https://api.openai.com", // 1
"https://oa.api2d.net", // 2
"", // 3
"http://localhost:11434", // 4
"https://api.openai-sb.com", // 5
"https://api.openaimax.com", // 6
"https://api.ohmygpt.com", // 7
"", // 8
"https://api.caipacity.com", // 9
"https://api.aiproxy.io", // 10
"", // 11
"https://api.api2gpt.com", // 12
"https://api.aigc2d.com", // 13
"https://api.anthropic.com", // 14
"https://aip.baidubce.com", // 15
"https://open.bigmodel.cn", // 16
"https://dashscope.aliyuncs.com", // 17
"", // 18
"https://ai.360.cn", // 19
"https://openrouter.ai/api", // 20
"https://api.aiproxy.io", // 21
"https://fastgpt.run/api/openapi", // 22
"https://hunyuan.tencentcloudapi.com", //23
"https://generativelanguage.googleapis.com", //24
"https://api.moonshot.cn", //25
"https://open.bigmodel.cn", //26
@@ -257,4 +258,5 @@ var ChannelBaseURLs = []string{
"", //36
"", //37
"https://api.jina.ai", //38
"https://api.cloudflare.com", //39
}

View File

@@ -6,3 +6,6 @@ import (
var StreamingTimeout = common.GetEnvOrDefault("STREAMING_TIMEOUT", 30)
var DifyDebug = common.GetEnvOrDefaultBool("DIFY_DEBUG", true)
// ForceStreamOption 覆盖请求参数强制返回usage信息
var ForceStreamOption = common.GetEnvOrDefaultBool("FORCE_STREAM_OPTION", true)

View File

@@ -4,6 +4,7 @@ var MjNotifyEnabled = false
var MjAccountFilterEnabled = false
var MjModeClearEnabled = false
var MjForwardUrlEnabled = true
var MjActionCheckSuccessEnabled = true
const (
MjErrorUnknown = 5

View File

@@ -12,6 +12,7 @@ import (
"net/url"
"one-api/common"
"one-api/dto"
"one-api/middleware"
"one-api/model"
"one-api/relay"
relaycommon "one-api/relay/common"
@@ -40,35 +41,13 @@ func testChannel(channel *model.Channel, testModel string) (err error, openaiErr
Body: nil,
Header: make(http.Header),
}
c.Request.Header.Set("Authorization", "Bearer "+channel.Key)
c.Request.Header.Set("Content-Type", "application/json")
c.Set("channel", channel.Type)
c.Set("base_url", channel.GetBaseURL())
switch channel.Type {
case common.ChannelTypeAzure:
c.Set("api_version", channel.Other)
case common.ChannelTypeXunfei:
c.Set("api_version", channel.Other)
//case common.ChannelTypeAIProxyLibrary:
// c.Set("library_id", channel.Other)
case common.ChannelTypeGemini:
c.Set("api_version", channel.Other)
case common.ChannelTypeAli:
c.Set("plugin", channel.Other)
}
meta := relaycommon.GenRelayInfo(c)
apiType, _ := constant.ChannelType2APIType(channel.Type)
adaptor := relay.GetAdaptor(apiType)
if adaptor == nil {
return fmt.Errorf("invalid api type: %d, adaptor is nil", apiType), nil
}
if testModel == "" {
if channel.TestModel != nil && *channel.TestModel != "" {
testModel = *channel.TestModel
} else {
if len(adaptor.GetModelList()) > 0 {
testModel = adaptor.GetModelList()[0]
if len(channel.GetModels()) > 0 {
testModel = channel.GetModels()[0]
} else {
testModel = "gpt-3.5-turbo"
}
@@ -88,6 +67,20 @@ func testChannel(channel *model.Channel, testModel string) (err error, openaiErr
}
}
c.Request.Header.Set("Authorization", "Bearer "+channel.Key)
c.Request.Header.Set("Content-Type", "application/json")
c.Set("channel", channel.Type)
c.Set("base_url", channel.GetBaseURL())
middleware.SetupContextForSelectedChannel(c, channel, testModel)
meta := relaycommon.GenRelayInfo(c)
apiType, _ := constant.ChannelType2APIType(channel.Type)
adaptor := relay.GetAdaptor(apiType)
if adaptor == nil {
return fmt.Errorf("invalid api type: %d, adaptor is nil", apiType), nil
}
request := buildTestRequest()
request.Model = testModel
meta.UpstreamModelName = testModel
@@ -121,7 +114,6 @@ func testChannel(channel *model.Channel, testModel string) (err error, openaiErr
return errors.New("usage is nil"), nil
}
result := w.Result()
// print result.Body
respBody, err := io.ReadAll(result.Body)
if err != nil {
return err, nil

View File

@@ -11,6 +11,7 @@ type GeneralOpenAIRequest struct {
Messages []Message `json:"messages,omitempty"`
Prompt any `json:"prompt,omitempty"`
Stream bool `json:"stream,omitempty"`
StreamOptions *StreamOptions `json:"stream_options,omitempty"`
MaxTokens uint `json:"max_tokens,omitempty"`
Temperature float64 `json:"temperature,omitempty"`
TopP float64 `json:"top_p,omitempty"`
@@ -43,8 +44,12 @@ type OpenAIFunction struct {
Parameters any `json:"parameters,omitempty"`
}
func (r GeneralOpenAIRequest) GetMaxTokens() int64 {
return int64(r.MaxTokens)
type StreamOptions struct {
IncludeUsage bool `json:"include_usage,omitempty"`
}
func (r GeneralOpenAIRequest) GetMaxTokens() int {
return int(r.MaxTokens)
}
func (r GeneralOpenAIRequest) ParseInput() []string {

View File

@@ -102,10 +102,12 @@ type ChatCompletionsStreamResponse struct {
Model string `json:"model"`
SystemFingerprint *string `json:"system_fingerprint"`
Choices []ChatCompletionsStreamResponseChoice `json:"choices"`
Usage *Usage `json:"usage"`
}
type ChatCompletionsStreamResponseSimple struct {
Choices []ChatCompletionsStreamResponseChoice `json:"choices"`
Usage *Usage `json:"usage"`
}
type CompletionsStreamResponse struct {

View File

@@ -198,11 +198,11 @@ func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, mode
c.Set("api_version", channel.Other)
case common.ChannelTypeXunfei:
c.Set("api_version", channel.Other)
//case common.ChannelTypeAIProxyLibrary:
// c.Set("library_id", channel.Other)
case common.ChannelTypeGemini:
c.Set("api_version", channel.Other)
case common.ChannelTypeAli:
c.Set("plugin", channel.Other)
case common.ChannelCloudflare:
c.Set("api_version", channel.Other)
}
}

View File

@@ -4,6 +4,7 @@ import (
"encoding/json"
"gorm.io/gorm"
"one-api/common"
"strings"
)
type Channel struct {
@@ -33,6 +34,13 @@ type Channel struct {
OtherInfo string `json:"other_info"`
}
func (channel *Channel) GetModels() []string {
if channel.Models == "" {
return []string{}
}
return strings.Split(strings.Trim(channel.Models, ","), ",")
}
func (channel *Channel) GetOtherInfo() map[string]interface{} {
otherInfo := make(map[string]interface{})
if channel.OtherInfo != "" {

View File

@@ -99,6 +99,7 @@ func InitOptionMap() {
common.OptionMap["MjAccountFilterEnabled"] = strconv.FormatBool(constant.MjAccountFilterEnabled)
common.OptionMap["MjModeClearEnabled"] = strconv.FormatBool(constant.MjModeClearEnabled)
common.OptionMap["MjForwardUrlEnabled"] = strconv.FormatBool(constant.MjForwardUrlEnabled)
common.OptionMap["MjActionCheckSuccessEnabled"] = strconv.FormatBool(constant.MjActionCheckSuccessEnabled)
common.OptionMap["CheckSensitiveEnabled"] = strconv.FormatBool(constant.CheckSensitiveEnabled)
common.OptionMap["CheckSensitiveOnPromptEnabled"] = strconv.FormatBool(constant.CheckSensitiveOnPromptEnabled)
//common.OptionMap["CheckSensitiveOnCompletionEnabled"] = strconv.FormatBool(constant.CheckSensitiveOnCompletionEnabled)
@@ -210,6 +211,8 @@ func updateOptionMap(key string, value string) (err error) {
constant.MjModeClearEnabled = boolValue
case "MjForwardUrlEnabled":
constant.MjForwardUrlEnabled = boolValue
case "MjActionCheckSuccessEnabled":
constant.MjActionCheckSuccessEnabled = boolValue
case "CheckSensitiveEnabled":
constant.CheckSensitiveEnabled = boolValue
case "CheckSensitiveOnPromptEnabled":

View File

@@ -298,7 +298,8 @@ func (user *User) ValidateAndFill() (err error) {
if user.Username == "" || password == "" {
return errors.New("用户名或密码为空")
}
DB.Where(User{Username: user.Username}).First(user)
// find buy username or email
DB.Where("username = ? OR email = ?", user.Username, user.Username).First(user)
okay := common.ValidatePasswordAndHash(password, user.Password)
if !okay || user.Status != common.UserStatusEnabled {
return errors.New("用户名或密码错误,或用户已被封禁")

View File

@@ -68,7 +68,7 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
if info.IsStream {
err, usage = awsStreamHandler(c, info, a.RequestMode)
err, usage = awsStreamHandler(c, resp, info, a.RequestMode)
} else {
err, usage = awsHandler(c, info, a.RequestMode)
}

View File

@@ -13,6 +13,7 @@ import (
relaymodel "one-api/dto"
"one-api/relay/channel/claude"
relaycommon "one-api/relay/common"
"one-api/service"
"strings"
"time"
@@ -112,7 +113,7 @@ func awsHandler(c *gin.Context, info *relaycommon.RelayInfo, requestMode int) (*
return nil, &usage
}
func awsStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, requestMode int) (*relaymodel.OpenAIErrorWithStatusCode, *relaymodel.Usage) {
func awsStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, requestMode int) (*relaymodel.OpenAIErrorWithStatusCode, *relaymodel.Usage) {
awsCli, err := newAwsClient(c, info)
if err != nil {
return wrapErr(errors.Wrap(err, "newAwsClient")), nil
@@ -162,7 +163,6 @@ func awsStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, requestMode i
c.Stream(func(w io.Writer) bool {
event, ok := <-stream.Events()
if !ok {
c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
return false
}
@@ -214,6 +214,17 @@ func awsStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, requestMode i
return false
}
})
if info.ShouldIncludeUsage {
response := service.GenerateFinalUsageResponse(id, createdTime, info.UpstreamModelName, usage)
err := service.ObjectData(c, response)
if err != nil {
common.SysError("send final response failed: " + err.Error())
}
}
service.Done(c)
err = resp.Body.Close()
if err != nil {
return service.OpenAIErrorWrapperLocal(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
return nil, &usage
}

View File

@@ -19,7 +19,7 @@ type BaiduChatRequest struct {
System string `json:"system,omitempty"`
DisableSearch bool `json:"disable_search,omitempty"`
EnableCitation bool `json:"enable_citation,omitempty"`
MaxOutputTokens int `json:"max_output_tokens,omitempty"`
MaxOutputTokens *int `json:"max_output_tokens,omitempty"`
UserId string `json:"user_id,omitempty"`
}

View File

@@ -23,14 +23,20 @@ var baiduTokenStore sync.Map
func requestOpenAI2Baidu(request dto.GeneralOpenAIRequest) *BaiduChatRequest {
baiduRequest := BaiduChatRequest{
Temperature: request.Temperature,
TopP: request.TopP,
PenaltyScore: request.FrequencyPenalty,
Stream: request.Stream,
DisableSearch: false,
EnableCitation: false,
MaxOutputTokens: int(request.MaxTokens),
UserId: request.User,
Temperature: request.Temperature,
TopP: request.TopP,
PenaltyScore: request.FrequencyPenalty,
Stream: request.Stream,
DisableSearch: false,
EnableCitation: false,
UserId: request.User,
}
if request.MaxTokens != 0 {
maxTokens := int(request.MaxTokens)
if request.MaxTokens == 1 {
maxTokens = 2
}
baiduRequest.MaxOutputTokens = &maxTokens
}
for _, message := range request.Messages {
if message.Role == "system" {

View File

@@ -72,6 +72,19 @@ func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*ClaudeR
if claudeRequest.MaxTokens == 0 {
claudeRequest.MaxTokens = 4096
}
if textRequest.Stop != nil {
// stop maybe string/array string, convert to array string
switch textRequest.Stop.(type) {
case string:
claudeRequest.StopSequences = []string{textRequest.Stop.(string)}
case []interface{}:
stopSequences := make([]string, 0)
for _, stop := range textRequest.Stop.([]interface{}) {
stopSequences = append(stopSequences, stop.(string))
}
claudeRequest.StopSequences = stopSequences
}
}
formatMessages := make([]dto.Message, 0)
var lastMessage *dto.Message
for i, message := range textRequest.Messages {
@@ -330,22 +343,15 @@ func claudeStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.
response.Created = createdTime
response.Model = info.UpstreamModelName
jsonStr, err := json.Marshal(response)
err = service.ObjectData(c, response)
if err != nil {
common.SysError("error marshalling stream response: " + err.Error())
return true
common.SysError(err.Error())
}
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonStr)})
return true
case <-stopChan:
c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
return false
}
})
err := resp.Body.Close()
if err != nil {
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
if requestMode == RequestModeCompletion {
usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
} else {
@@ -356,6 +362,18 @@ func claudeStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.
usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, usage.PromptTokens)
}
}
if info.ShouldIncludeUsage {
response := service.GenerateFinalUsageResponse(responseId, createdTime, info.UpstreamModelName, *usage)
err := service.ObjectData(c, response)
if err != nil {
common.SysError("send final response failed: " + err.Error())
}
}
service.Done(c)
err := resp.Body.Close()
if err != nil {
return service.OpenAIErrorWrapperLocal(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
return nil, usage
}

View File

@@ -0,0 +1,76 @@
package cloudflare
import (
"errors"
"fmt"
"github.com/gin-gonic/gin"
"io"
"net/http"
"one-api/dto"
"one-api/relay/channel"
relaycommon "one-api/relay/common"
"one-api/relay/constant"
)
type Adaptor struct {
}
func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) {
}
func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
}
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
switch info.RelayMode {
case constant.RelayModeChatCompletions:
return fmt.Sprintf("%s/client/v4/accounts/%s/ai/v1/chat/completions", info.BaseUrl, info.ApiVersion), nil
case constant.RelayModeEmbeddings:
return fmt.Sprintf("%s/client/v4/accounts/%s/ai/v1/embeddings", info.BaseUrl, info.ApiVersion), nil
default:
return fmt.Sprintf("%s/client/v4/accounts/%s/ai/run/%s", info.BaseUrl, info.ApiVersion, info.UpstreamModelName), nil
}
}
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
channel.SetupApiRequestHeader(info, c, req)
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", info.ApiKey))
return nil
}
func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) {
if request == nil {
return nil, errors.New("request is nil")
}
switch relayMode {
case constant.RelayModeCompletions:
return convertCf2CompletionsRequest(*request), nil
default:
return request, nil
}
}
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
return channel.DoApiRequest(a, c, info, requestBody)
}
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
return request, nil
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
if info.IsStream {
err, usage = cfStreamHandler(c, resp, info)
} else {
err, usage = cfHandler(c, resp, info)
}
return
}
func (a *Adaptor) GetModelList() []string {
return ModelList
}
func (a *Adaptor) GetChannelName() string {
return ChannelName
}

View File

@@ -0,0 +1,38 @@
package cloudflare
var ModelList = []string{
"@cf/meta/llama-2-7b-chat-fp16",
"@cf/meta/llama-2-7b-chat-int8",
"@cf/mistral/mistral-7b-instruct-v0.1",
"@hf/thebloke/deepseek-coder-6.7b-base-awq",
"@hf/thebloke/deepseek-coder-6.7b-instruct-awq",
"@cf/deepseek-ai/deepseek-math-7b-base",
"@cf/deepseek-ai/deepseek-math-7b-instruct",
"@cf/thebloke/discolm-german-7b-v1-awq",
"@cf/tiiuae/falcon-7b-instruct",
"@cf/google/gemma-2b-it-lora",
"@hf/google/gemma-7b-it",
"@cf/google/gemma-7b-it-lora",
"@hf/nousresearch/hermes-2-pro-mistral-7b",
"@hf/thebloke/llama-2-13b-chat-awq",
"@cf/meta-llama/llama-2-7b-chat-hf-lora",
"@cf/meta/llama-3-8b-instruct",
"@hf/thebloke/llamaguard-7b-awq",
"@hf/thebloke/mistral-7b-instruct-v0.1-awq",
"@hf/mistralai/mistral-7b-instruct-v0.2",
"@cf/mistral/mistral-7b-instruct-v0.2-lora",
"@hf/thebloke/neural-chat-7b-v3-1-awq",
"@cf/openchat/openchat-3.5-0106",
"@hf/thebloke/openhermes-2.5-mistral-7b-awq",
"@cf/microsoft/phi-2",
"@cf/qwen/qwen1.5-0.5b-chat",
"@cf/qwen/qwen1.5-1.8b-chat",
"@cf/qwen/qwen1.5-14b-chat-awq",
"@cf/qwen/qwen1.5-7b-chat-awq",
"@cf/defog/sqlcoder-7b-2",
"@hf/nexusflow/starling-lm-7b-beta",
"@cf/tinyllama/tinyllama-1.1b-chat-v1.0",
"@hf/thebloke/zephyr-7b-beta-awq",
}
var ChannelName = "cloudflare"

View File

@@ -0,0 +1,13 @@
package cloudflare
import "one-api/dto"
type CfRequest struct {
Messages []dto.Message `json:"messages,omitempty"`
Lora string `json:"lora,omitempty"`
MaxTokens int `json:"max_tokens,omitempty"`
Prompt string `json:"prompt,omitempty"`
Raw bool `json:"raw,omitempty"`
Stream bool `json:"stream,omitempty"`
Temperature float64 `json:"temperature,omitempty"`
}

View File

@@ -0,0 +1,115 @@
package cloudflare
import (
"bufio"
"encoding/json"
"github.com/gin-gonic/gin"
"io"
"net/http"
"one-api/common"
"one-api/dto"
relaycommon "one-api/relay/common"
"one-api/service"
"strings"
)
func convertCf2CompletionsRequest(textRequest dto.GeneralOpenAIRequest) *CfRequest {
p, _ := textRequest.Prompt.(string)
return &CfRequest{
Prompt: p,
MaxTokens: textRequest.GetMaxTokens(),
Stream: textRequest.Stream,
Temperature: textRequest.Temperature,
}
}
func cfStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
scanner := bufio.NewScanner(resp.Body)
scanner.Split(bufio.ScanLines)
service.SetEventStreamHeaders(c)
id := service.GetResponseID(c)
var responseText string
for scanner.Scan() {
data := scanner.Text()
if len(data) < len("data: ") {
continue
}
data = strings.TrimPrefix(data, "data: ")
data = strings.TrimSuffix(data, "\r")
if data == "[DONE]" {
break
}
var response dto.ChatCompletionsStreamResponse
err := json.Unmarshal([]byte(data), &response)
if err != nil {
common.LogError(c, "error_unmarshalling_stream_response: "+err.Error())
continue
}
for _, choice := range response.Choices {
choice.Delta.Role = "assistant"
responseText += choice.Delta.GetContentString()
}
response.Id = id
response.Model = info.UpstreamModelName
err = service.ObjectData(c, response)
if err != nil {
common.LogError(c, "error_rendering_stream_response: "+err.Error())
}
}
if err := scanner.Err(); err != nil {
common.LogError(c, "error_scanning_stream_response: "+err.Error())
}
usage, _ := service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
if info.ShouldIncludeUsage {
response := service.GenerateFinalUsageResponse(id, info.StartTime.Unix(), info.UpstreamModelName, *usage)
err := service.ObjectData(c, response)
if err != nil {
common.LogError(c, "error_rendering_final_usage_response: "+err.Error())
}
}
service.Done(c)
err := resp.Body.Close()
if err != nil {
common.LogError(c, "close_response_body_failed: "+err.Error())
}
return nil, usage
}
func cfHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
}
err = resp.Body.Close()
if err != nil {
return service.OpenAIErrorWrapperLocal(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
var response dto.TextResponse
err = json.Unmarshal(responseBody, &response)
if err != nil {
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
}
response.Model = info.UpstreamModelName
var responseText string
for _, choice := range response.Choices {
responseText += choice.Message.StringContent()
}
usage, _ := service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
response.Usage = *usage
response.Id = service.GetResponseID(c)
jsonResponse, err := json.Marshal(response)
if err != nil {
return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
}
c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode)
_, _ = c.Writer.Write(jsonResponse)
return nil, usage
}

View File

@@ -7,7 +7,7 @@ type CohereRequest struct {
ChatHistory []ChatHistory `json:"chat_history"`
Message string `json:"message"`
Stream bool `json:"stream"`
MaxTokens int64 `json:"max_tokens"`
MaxTokens int `json:"max_tokens"`
}
type ChatHistory struct {

View File

@@ -21,7 +21,7 @@ type DifyData struct {
type DifyChatCompletionResponse struct {
ConversationId string `json:"conversation_id"`
Answers string `json:"answers"`
Answer string `json:"answer"`
CreateAt int64 `json:"create_at"`
MetaData DifyMetaData `json:"metadata"`
}

View File

@@ -117,6 +117,7 @@ func difyStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Re
func difyHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
var difyResponse DifyChatCompletionResponse
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
}
@@ -134,7 +135,7 @@ func difyHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInf
Created: common.GetTimestamp(),
Usage: difyResponse.MetaData.Usage,
}
content, _ := json.Marshal(difyResponse.Answers)
content, _ := json.Marshal(difyResponse.Answer)
choice := dto.OpenAITextResponseChoice{
Index: 0,
Message: dto.Message{

View File

@@ -9,7 +9,6 @@ import (
"one-api/dto"
"one-api/relay/channel"
relaycommon "one-api/relay/common"
"one-api/service"
)
type Adaptor struct {
@@ -69,9 +68,7 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
if info.IsStream {
var responseText string
err, responseText = geminiChatStreamHandler(c, resp, info)
usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
err, usage = geminiChatStreamHandler(c, resp, info)
} else {
err, usage = geminiChatHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
}

View File

@@ -59,4 +59,11 @@ type GeminiChatPromptFeedback struct {
type GeminiChatResponse struct {
Candidates []GeminiChatCandidate `json:"candidates"`
PromptFeedback GeminiChatPromptFeedback `json:"promptFeedback"`
UsageMetadata GeminiUsageMetadata `json:"usageMetadata"`
}
type GeminiUsageMetadata struct {
PromptTokenCount int `json:"promptTokenCount"`
CandidatesTokenCount int `json:"candidatesTokenCount"`
TotalTokenCount int `json:"totalTokenCount"`
}

View File

@@ -5,6 +5,7 @@ import (
"encoding/json"
"fmt"
"io"
"log"
"net/http"
"one-api/common"
"one-api/constant"
@@ -162,8 +163,12 @@ func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) *dto.Ch
return &response
}
func geminiChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, string) {
func geminiChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
responseText := ""
responseJson := ""
id := fmt.Sprintf("chatcmpl-%s", common.GetUUID())
createAt := common.GetTimestamp()
var usage = &dto.Usage{}
dataChan := make(chan string, 5)
stopChan := make(chan bool, 2)
scanner := bufio.NewScanner(resp.Body)
@@ -182,6 +187,7 @@ func geminiChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycom
go func() {
for scanner.Scan() {
data := scanner.Text()
responseJson += data
data = strings.TrimSpace(data)
if !strings.HasPrefix(data, "\"text\": \"") {
continue
@@ -216,10 +222,10 @@ func geminiChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycom
var choice dto.ChatCompletionsStreamResponseChoice
choice.Delta.SetContentString(dummy.Content)
response := dto.ChatCompletionsStreamResponse{
Id: fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
Id: id,
Object: "chat.completion.chunk",
Created: common.GetTimestamp(),
Model: "gemini-pro",
Created: createAt,
Model: info.UpstreamModelName,
Choices: []dto.ChatCompletionsStreamResponseChoice{choice},
}
jsonResponse, err := json.Marshal(response)
@@ -230,15 +236,34 @@ func geminiChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycom
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
return true
case <-stopChan:
c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
return false
}
})
err := resp.Body.Close()
var geminiChatResponses []GeminiChatResponse
err := json.Unmarshal([]byte(responseJson), &geminiChatResponses)
if err != nil {
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), ""
log.Printf("cannot get gemini usage: %s", err.Error())
usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
} else {
for _, response := range geminiChatResponses {
usage.PromptTokens = response.UsageMetadata.PromptTokenCount
usage.CompletionTokens = response.UsageMetadata.CandidatesTokenCount
}
usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
}
return nil, responseText
if info.ShouldIncludeUsage {
response := service.GenerateFinalUsageResponse(id, createAt, info.UpstreamModelName, *usage)
err := service.ObjectData(c, response)
if err != nil {
common.SysError("send final response failed: " + err.Error())
}
}
service.Done(c)
err = resp.Body.Close()
if err != nil {
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), usage
}
return nil, usage
}
func geminiChatHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
@@ -267,11 +292,10 @@ func geminiChatHandler(c *gin.Context, resp *http.Response, promptTokens int, mo
}, nil
}
fullTextResponse := responseGeminiChat2OpenAI(&geminiResponse)
completionTokens, _ := service.CountTokenText(geminiResponse.GetResponseText(), model)
usage := dto.Usage{
PromptTokens: promptTokens,
CompletionTokens: completionTokens,
TotalTokens: promptTokens + completionTokens,
PromptTokens: geminiResponse.UsageMetadata.PromptTokenCount,
CompletionTokens: geminiResponse.UsageMetadata.CandidatesTokenCount,
TotalTokens: geminiResponse.UsageMetadata.TotalTokenCount,
}
fullTextResponse.Usage = usage
jsonResponse, err := json.Marshal(fullTextResponse)

View File

@@ -59,8 +59,10 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
if info.IsStream {
var responseText string
err, responseText, _ = openai.OpenaiStreamHandler(c, resp, info)
usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
err, usage, responseText, _ = openai.OpenaiStreamHandler(c, resp, info)
if usage == nil || usage.TotalTokens == 0 || (usage.PromptTokens+usage.CompletionTokens) == 0 {
usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
}
} else {
if info.RelayMode == relayconstant.RelayModeEmbeddings {
err, usage = ollamaEmbeddingHandler(c, resp, info.PromptTokens, info.UpstreamModelName, info.RelayMode)

View File

@@ -89,9 +89,11 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
if info.IsStream {
var responseText string
var toolCount int
err, responseText, toolCount = OpenaiStreamHandler(c, resp, info)
usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
usage.CompletionTokens += toolCount * 7
err, usage, responseText, toolCount = OpenaiStreamHandler(c, resp, info)
if usage == nil || usage.TotalTokens == 0 || (usage.PromptTokens+usage.CompletionTokens) == 0 {
usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
usage.CompletionTokens += toolCount * 7
}
} else {
err, usage = OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
}

View File

@@ -18,9 +18,10 @@ import (
"time"
)
func OpenaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, string, int) {
func OpenaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage, string, int) {
//checkSensitive := constant.ShouldCheckCompletionSensitive()
var responseTextBuilder strings.Builder
var usage dto.Usage
toolCount := 0
scanner := bufio.NewScanner(resp.Body)
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
@@ -62,17 +63,24 @@ func OpenaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.
streamItems = append(streamItems, data)
}
}
// 计算token
streamResp := "[" + strings.Join(streamItems, ",") + "]"
switch info.RelayMode {
case relayconstant.RelayModeChatCompletions:
var streamResponses []dto.ChatCompletionsStreamResponseSimple
err := json.Unmarshal(common.StringToByteSlice(streamResp), &streamResponses)
if err != nil {
// 一次性解析失败,逐个解析
common.SysError("error unmarshalling stream response: " + err.Error())
for _, item := range streamItems {
var streamResponse dto.ChatCompletionsStreamResponseSimple
err := json.Unmarshal(common.StringToByteSlice(item), &streamResponse)
if err == nil {
if streamResponse.Usage != nil {
if streamResponse.Usage.TotalTokens != 0 {
usage = *streamResponse.Usage
}
}
for _, choice := range streamResponse.Choices {
responseTextBuilder.WriteString(choice.Delta.GetContentString())
if choice.Delta.ToolCalls != nil {
@@ -89,6 +97,11 @@ func OpenaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.
}
} else {
for _, streamResponse := range streamResponses {
if streamResponse.Usage != nil {
if streamResponse.Usage.TotalTokens != 0 {
usage = *streamResponse.Usage
}
}
for _, choice := range streamResponse.Choices {
responseTextBuilder.WriteString(choice.Delta.GetContentString())
if choice.Delta.ToolCalls != nil {
@@ -107,6 +120,7 @@ func OpenaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.
var streamResponses []dto.CompletionsStreamResponse
err := json.Unmarshal(common.StringToByteSlice(streamResp), &streamResponses)
if err != nil {
// 一次性解析失败,逐个解析
common.SysError("error unmarshalling stream response: " + err.Error())
for _, item := range streamItems {
var streamResponse dto.CompletionsStreamResponse
@@ -159,10 +173,10 @@ func OpenaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.
})
err := resp.Body.Close()
if err != nil {
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), "", toolCount
return service.OpenAIErrorWrapperLocal(err, "close_response_body_failed", http.StatusInternalServerError), nil, "", toolCount
}
wg.Wait()
return nil, responseTextBuilder.String(), toolCount
return nil, &usage, responseTextBuilder.String(), toolCount
}
func OpenaiHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {

View File

@@ -55,8 +55,10 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
if info.IsStream {
var responseText string
err, responseText, _ = openai.OpenaiStreamHandler(c, resp, info)
usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
err, usage, responseText, _ = openai.OpenaiStreamHandler(c, resp, info)
if usage == nil || usage.TotalTokens == 0 || (usage.PromptTokens+usage.CompletionTokens) == 0 {
usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
}
} else {
err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
}

View File

@@ -17,6 +17,7 @@ import (
type Adaptor struct {
Sign string
AppID int64
Action string
Version string
Timestamp int64
@@ -34,7 +35,7 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIReq
}
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
return fmt.Sprintf("%s/hyllm/v1/chat/completions", info.BaseUrl), nil
return fmt.Sprintf("%s/", info.BaseUrl), nil
}
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
@@ -52,11 +53,12 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.Gen
}
apiKey := c.Request.Header.Get("Authorization")
apiKey = strings.TrimPrefix(apiKey, "Bearer ")
_, secretId, secretKey, err := parseTencentConfig(apiKey)
appId, secretId, secretKey, err := parseTencentConfig(apiKey)
a.AppID = appId
if err != nil {
return nil, err
}
tencentRequest := requestOpenAI2Tencent(*request)
tencentRequest := requestOpenAI2Tencent(a, *request)
// we have to calculate the sign here
a.Sign = getTencentSign(*tencentRequest, a, secretId, secretKey)
return tencentRequest, nil

View File

@@ -30,17 +30,17 @@ type TencentChatRequest struct {
//
// 注意:
// 通过 SDK 调用时,流式和非流式调用需用**不同的方式**获取返回值,具体参考 SDK 中的注释或示例(在各语言 SDK 代码仓库的 examples/hunyuan/v20230901/ 目录中)。
Stream *bool `json:"Stream"`
Stream *bool `json:"Stream,omitempty"`
// 说明:
// 1. 影响输出文本的多样性,取值越大,生成文本的多样性越强。
// 2. 取值区间为 [0.0, 1.0],未传值时使用各模型推荐值。
// 3. 非必要不建议使用,不合理的取值会影响效果。
TopP *float64 `json:"TopP"`
TopP *float64 `json:"TopP,omitempty"`
// 说明:
// 1. 较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定。
// 2. 取值区间为 [0.0, 2.0],未传值时使用各模型推荐值。
// 3. 非必要不建议使用,不合理的取值会影响效果。
Temperature *float64 `json:"Temperature"`
Temperature *float64 `json:"Temperature,omitempty"`
}
type TencentError struct {
@@ -69,3 +69,7 @@ type TencentChatResponse struct {
Note string `json:"Note,omitempty"` // 注释
ReqID string `json:"Req_id,omitempty"` // 唯一请求 Id每次请求都会返回。用于反馈接口入参
}
type TencentChatResponseSB struct {
Response TencentChatResponse `json:"Response,omitempty"`
}

View File

@@ -22,7 +22,7 @@ import (
// https://cloud.tencent.com/document/product/1729/97732
func requestOpenAI2Tencent(request dto.GeneralOpenAIRequest) *TencentChatRequest {
func requestOpenAI2Tencent(a *Adaptor, request dto.GeneralOpenAIRequest) *TencentChatRequest {
messages := make([]*TencentMessage, 0, len(request.Messages))
for i := 0; i < len(request.Messages); i++ {
message := request.Messages[i]
@@ -31,17 +31,23 @@ func requestOpenAI2Tencent(request dto.GeneralOpenAIRequest) *TencentChatRequest
Role: message.Role,
})
}
return &TencentChatRequest{
Temperature: &request.Temperature,
TopP: &request.TopP,
Stream: &request.Stream,
Messages: messages,
Model: &request.Model,
var req = TencentChatRequest{
Stream: &request.Stream,
Messages: messages,
Model: &request.Model,
}
if request.TopP != 0 {
req.TopP = &request.TopP
}
if request.Temperature != 0 {
req.Temperature = &request.Temperature
}
return &req
}
func responseTencent2OpenAI(response *TencentChatResponse) *dto.OpenAITextResponse {
fullTextResponse := dto.OpenAITextResponse{
Id: response.Id,
Object: "chat.completion",
Created: common.GetTimestamp(),
Usage: dto.Usage{
@@ -129,7 +135,7 @@ func tencentStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIError
}
func tencentHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
var TencentResponse TencentChatResponse
var tencentSb TencentChatResponseSB
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
@@ -138,20 +144,20 @@ func tencentHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithSt
if err != nil {
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
err = json.Unmarshal(responseBody, &TencentResponse)
err = json.Unmarshal(responseBody, &tencentSb)
if err != nil {
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
}
if TencentResponse.Error.Code != 0 {
if tencentSb.Response.Error.Code != 0 {
return &dto.OpenAIErrorWithStatusCode{
Error: dto.OpenAIError{
Message: TencentResponse.Error.Message,
Code: TencentResponse.Error.Code,
Message: tencentSb.Response.Error.Message,
Code: tencentSb.Response.Error.Code,
},
StatusCode: resp.StatusCode,
}, nil
}
fullTextResponse := responseTencent2OpenAI(&TencentResponse)
fullTextResponse := responseTencent2OpenAI(&tencentSb.Response)
jsonResponse, err := json.Marshal(fullTextResponse)
if err != nil {
return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil

View File

@@ -57,9 +57,11 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
if info.IsStream {
var responseText string
var toolCount int
err, responseText, toolCount = openai.OpenaiStreamHandler(c, resp, info)
usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
usage.CompletionTokens += toolCount * 7
err, usage, responseText, toolCount = openai.OpenaiStreamHandler(c, resp, info)
if usage == nil || usage.TotalTokens == 0 || (usage.PromptTokens+usage.CompletionTokens) == 0 {
usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
usage.CompletionTokens += toolCount * 7
}
} else {
err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
}

View File

@@ -9,24 +9,26 @@ import (
)
type RelayInfo struct {
ChannelType int
ChannelId int
TokenId int
UserId int
Group string
TokenUnlimited bool
StartTime time.Time
FirstResponseTime time.Time
ApiType int
IsStream bool
RelayMode int
UpstreamModelName string
RequestURLPath string
ApiVersion string
PromptTokens int
ApiKey string
Organization string
BaseUrl string
ChannelType int
ChannelId int
TokenId int
UserId int
Group string
TokenUnlimited bool
StartTime time.Time
FirstResponseTime time.Time
ApiType int
IsStream bool
RelayMode int
UpstreamModelName string
RequestURLPath string
ApiVersion string
PromptTokens int
ApiKey string
Organization string
BaseUrl string
SupportStreamOptions bool
ShouldIncludeUsage bool
}
func GenRelayInfo(c *gin.Context) *RelayInfo {
@@ -65,6 +67,11 @@ func GenRelayInfo(c *gin.Context) *RelayInfo {
if info.ChannelType == common.ChannelTypeAzure {
info.ApiVersion = GetAPIVersion(c)
}
if info.ChannelType == common.ChannelTypeOpenAI || info.ChannelType == common.ChannelTypeAnthropic ||
info.ChannelType == common.ChannelTypeAws || info.ChannelType == common.ChannelTypeGemini ||
info.ChannelType == common.ChannelCloudflare {
info.SupportStreamOptions = true
}
return info
}

View File

@@ -38,6 +38,7 @@ func RelayErrorHandler(resp *http.Response) (OpenAIErrorWithStatusCode *dto.Open
var textResponse dto.TextResponseWithError
err = json.Unmarshal(responseBody, &textResponse)
if err != nil {
OpenAIErrorWithStatusCode.Error.Message = fmt.Sprintf("error unmarshalling response body: %s", responseBody)
return
}
OpenAIErrorWithStatusCode.Error = textResponse.Error

View File

@@ -22,6 +22,7 @@ const (
APITypeCohere
APITypeDify
APITypeJina
APITypeCloudflare
APITypeDummy // this one is only for count, do not add any channel after this
)
@@ -63,6 +64,8 @@ func ChannelType2APIType(channelType int) (int, bool) {
apiType = APITypeDify
case common.ChannelTypeJina:
apiType = APITypeJina
case common.ChannelCloudflare:
apiType = APITypeCloudflare
}
if apiType == -1 {
return APITypeOpenAI, false

View File

@@ -415,9 +415,12 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
originTask := model.GetByMJId(userId, mjId)
if originTask == nil {
return service.MidjourneyErrorWrapper(constant.MjRequestError, "task_not_found")
} else if originTask.Status != "SUCCESS" && relayMode != relayconstant.RelayModeMidjourneyModal {
return service.MidjourneyErrorWrapper(constant.MjRequestError, "task_status_not_success")
} else { //原任务的Status=SUCCESS则可以做放大UPSCALE、变换VARIATION等动作此时必须使用原来的请求地址才能正确处理
if constant.MjActionCheckSuccessEnabled {
if originTask.Status != "SUCCESS" && relayMode != relayconstant.RelayModeMidjourneyModal {
return service.MidjourneyErrorWrapper(constant.MjRequestError, "task_status_not_success")
}
}
channel, err := model.GetChannelById(originTask.ChannelId, true)
if err != nil {
return service.MidjourneyErrorWrapper(constant.MjRequestError, "get_channel_info_failed")

View File

@@ -77,7 +77,7 @@ func TextHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
// map model name
modelMapping := c.GetString("model_mapping")
isModelMapped := false
//isModelMapped := false
if modelMapping != "" && modelMapping != "{}" {
modelMap := make(map[string]string)
err := json.Unmarshal([]byte(modelMapping), &modelMap)
@@ -87,7 +87,7 @@ func TextHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
if modelMap[textRequest.Model] != "" {
textRequest.Model = modelMap[textRequest.Model]
// set upstream model name
isModelMapped = true
//isModelMapped = true
}
}
relayInfo.UpstreamModelName = textRequest.Model
@@ -130,33 +130,38 @@ func TextHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
return openaiErr
}
// 如果不支持StreamOptions将StreamOptions设置为nil
if !relayInfo.SupportStreamOptions || !textRequest.Stream {
textRequest.StreamOptions = nil
} else {
// 如果支持StreamOptions且请求中没有设置StreamOptions根据配置文件设置StreamOptions
if constant.ForceStreamOption {
textRequest.StreamOptions = &dto.StreamOptions{
IncludeUsage: true,
}
}
}
if textRequest.StreamOptions != nil && textRequest.StreamOptions.IncludeUsage {
relayInfo.ShouldIncludeUsage = textRequest.StreamOptions.IncludeUsage
}
adaptor := GetAdaptor(relayInfo.ApiType)
if adaptor == nil {
return service.OpenAIErrorWrapperLocal(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), "invalid_api_type", http.StatusBadRequest)
}
adaptor.Init(relayInfo, *textRequest)
var requestBody io.Reader
if relayInfo.ApiType == relayconstant.APITypeOpenAI {
if isModelMapped {
jsonStr, err := json.Marshal(textRequest)
if err != nil {
return service.OpenAIErrorWrapperLocal(err, "marshal_text_request_failed", http.StatusInternalServerError)
}
requestBody = bytes.NewBuffer(jsonStr)
} else {
requestBody = c.Request.Body
}
} else {
convertedRequest, err := adaptor.ConvertRequest(c, relayInfo.RelayMode, textRequest)
if err != nil {
return service.OpenAIErrorWrapperLocal(err, "convert_request_failed", http.StatusInternalServerError)
}
jsonData, err := json.Marshal(convertedRequest)
if err != nil {
return service.OpenAIErrorWrapperLocal(err, "json_marshal_failed", http.StatusInternalServerError)
}
requestBody = bytes.NewBuffer(jsonData)
convertedRequest, err := adaptor.ConvertRequest(c, relayInfo.RelayMode, textRequest)
if err != nil {
return service.OpenAIErrorWrapperLocal(err, "convert_request_failed", http.StatusInternalServerError)
}
jsonData, err := json.Marshal(convertedRequest)
if err != nil {
return service.OpenAIErrorWrapperLocal(err, "json_marshal_failed", http.StatusInternalServerError)
}
requestBody = bytes.NewBuffer(jsonData)
statusCodeMappingStr := c.GetString("status_code_mapping")
resp, err := adaptor.DoRequest(c, relayInfo, requestBody)

View File

@@ -7,6 +7,7 @@ import (
"one-api/relay/channel/aws"
"one-api/relay/channel/baidu"
"one-api/relay/channel/claude"
"one-api/relay/channel/cloudflare"
"one-api/relay/channel/cohere"
"one-api/relay/channel/dify"
"one-api/relay/channel/gemini"
@@ -59,6 +60,8 @@ func GetAdaptor(apiType int) channel.Adaptor {
return &dify.Adaptor{}
case constant.APITypeJina:
return &jina.Adaptor{}
case constant.APITypeCloudflare:
return &cloudflare.Adaptor{}
}
return nil
}

View File

@@ -35,3 +35,8 @@ func ObjectData(c *gin.Context, object interface{}) error {
func Done(c *gin.Context) {
StringData(c, "[DONE]")
}
func GetResponseID(c *gin.Context) string {
logID := c.GetString("X-Oneapi-Request-Id")
return fmt.Sprintf("chatcmpl-%s", logID)
}

View File

@@ -24,3 +24,15 @@ func ResponseText2Usage(responseText string, modeName string, promptTokens int)
usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
return usage, err
}
func GenerateFinalUsageResponse(id string, createAt int64, model string, usage dto.Usage) *dto.ChatCompletionsStreamResponse {
return &dto.ChatCompletionsStreamResponse{
Id: id,
Object: "chat.completion.chunk",
Created: createAt,
Model: model,
SystemFingerprint: nil,
Choices: make([]dto.ChatCompletionsStreamResponseChoice, 0),
Usage: &usage,
}
}

View File

@@ -550,7 +550,7 @@ const ChannelsTable = () => {
);
const { success, message, data } = res.data;
if (success) {
setChannels(data);
setChannelFormat(data);
setActivePage(1);
} else {
showError(message);

View File

@@ -42,6 +42,7 @@ const OperationSetting = () => {
MjAccountFilterEnabled: false,
MjModeClearEnabled: false,
MjForwardUrlEnabled: false,
MjActionCheckSuccessEnabled: false,
DrawingEnabled: false,
DataExportEnabled: false,
DataExportDefaultTime: 'hour',

View File

@@ -99,6 +99,7 @@ export const CHANNEL_OPTIONS = [
color: 'orange',
label: 'Google PaLM2',
},
{ key: 39, text: 'Cloudflare', value: 39, color: 'grey', label: 'Cloudflare' },
{ key: 25, text: 'Moonshot', value: 25, color: 'green', label: 'Moonshot' },
{ key: 19, text: '360 智脑', value: 19, color: 'blue', label: '360 智脑' },
{ key: 23, text: '腾讯混元', value: 23, color: 'teal', label: '腾讯混元' },

View File

@@ -153,8 +153,8 @@ export function renderModelPrice(
let inputRatioPrice = modelRatio * 2.0;
let completionRatioPrice = modelRatio * 2.0 * completionRatio;
let price =
(inputTokens / 1000000) * inputRatioPrice +
(completionTokens / 1000000) * completionRatioPrice;
(inputTokens / 1000000) * inputRatioPrice * groupRatio +
(completionTokens / 1000000) * completionRatioPrice * groupRatio;
return (
<>
<article>

View File

@@ -605,6 +605,24 @@ const EditChannel = (props) => {
/>
</>
)}
{inputs.type === 39 && (
<>
<div style={{ marginTop: 10 }}>
<Typography.Text strong>Account ID</Typography.Text>
</div>
<Input
name='other'
placeholder={
'请输入Account ID例如d6b5da8hk1awo8nap34ube6gh'
}
onChange={(value) => {
handleInputChange('other', value);
}}
value={inputs.other}
autoComplete='new-password'
/>
</>
)}
<div style={{ marginTop: 10 }}>
<Typography.Text strong>模型</Typography.Text>
</div>

View File

@@ -16,6 +16,7 @@ export default function SettingsDrawing(props) {
MjAccountFilterEnabled: false,
MjForwardUrlEnabled: false,
MjModeClearEnabled: false,
MjActionCheckSuccessEnabled: false,
});
const refForm = useRef();
const [inputsRow, setInputsRow] = useState(inputs);
@@ -156,6 +157,25 @@ export default function SettingsDrawing(props) {
}
/>
</Col>
<Col span={8}>
<Form.Switch
field={'MjActionCheckSuccessEnabled'}
label={
<>
检测必须等待绘图成功才能进行放大等操作
</>
}
size='large'
checkedText=''
uncheckedText=''
onChange={(value) =>
setInputs({
...inputs,
MjActionCheckSuccessEnabled: value,
})
}
/>
</Col>
</Row>
<Row>
<Button size='large' onClick={onSubmit}>