Compare commits

...

10 Commits

Author SHA1 Message Date
Calcium-Ion
a9b978528e Merge pull request #335 from HynoR/fix/v1
fix testAllChannels nil pointer panic
2024-06-27 16:23:47 +08:00
CalciumIon
d1778bb20a feat: support Spark4.0 Ultra 2024-06-27 16:22:31 +08:00
HynoR
37a0930db4 fix testAllChannels nil pointer panic 2024-06-27 11:41:52 +08:00
CalciumIon
1117112225 feat: first response time support aws 2024-06-27 00:19:58 +08:00
CalciumIon
f2654692e8 feat: first response time support gemini and claude 2024-06-27 00:16:39 +08:00
CalciumIon
c834289f2c Update README.md 2024-06-27 00:16:04 +08:00
Calcium-Ion
bc649ddaa7 Merge pull request #331 from mageia/master
chore: Add Anthropic claude-3-5-sonnet-20240620 to model list
2024-06-26 22:12:23 +08:00
Calcium-Ion
c838beba3d Delete fly.toml 2024-06-26 22:11:58 +08:00
CalciumIon
1e9d64fd19 fix: sqlite too many SQL variables 2024-06-26 19:51:23 +08:00
Mageia
6b07e6fb97 chore: Add Anthropic claude-3-5-sonnet-20240620 to model list 2024-06-25 10:04:17 +08:00
13 changed files with 93 additions and 58 deletions

View File

@@ -5,8 +5,6 @@
> 本项目为开源项目,在[One API](https://github.com/songquanpeng/one-api)的基础上进行二次开发,感谢原作者的无私奉献。
> 使用者必须在遵循 OpenAI 的[使用条款](https://openai.com/policies/terms-of-use)以及**法律法规**的情况下使用,不得用于非法用途。
> [!WARNING]
> 本项目为个人学习使用,不保证稳定性,且不提供任何技术支持,使用者必须在遵循 OpenAI 的使用条款以及法律法规的情况下使用,不得用于非法用途。
> 根据[《生成式人工智能服务管理暂行办法》](http://www.cac.gov.cn/2023-07/13/c_1690898327029107.htm)的要求,请勿对中国地区公众提供一切未经备案的生成式人工智能服务。
@@ -87,6 +85,9 @@
## 部署
### 部署要求
- 本地数据库默认SQLiteDocker 部署默认使用 SQLite必须挂载 `/data` 目录到宿主机)
- 远程数据库MySQL 版本 >= 5.7.8PgSQL 版本 >= 9.6
### 基于 Docker 进行部署
```shell
# 使用 SQLite 的部署命令:

View File

@@ -72,11 +72,12 @@ var defaultModelRatio = map[string]float64{
"text-search-ada-doc-001": 10,
"text-moderation-stable": 0.1,
"text-moderation-latest": 0.1,
"claude-instant-1": 0.4, // $0.8 / 1M tokens
"claude-2.0": 4, // $8 / 1M tokens
"claude-2.1": 4, // $8 / 1M tokens
"claude-3-haiku-20240307": 0.125, // $0.25 / 1M tokens
"claude-3-sonnet-20240229": 1.5, // $3 / 1M tokens
"claude-instant-1": 0.4, // $0.8 / 1M tokens
"claude-2.0": 4, // $8 / 1M tokens
"claude-2.1": 4, // $8 / 1M tokens
"claude-3-haiku-20240307": 0.125, // $0.25 / 1M tokens
"claude-3-sonnet-20240229": 1.5, // $3 / 1M tokens
"claude-3-5-sonnet-20240620": 1.5,
"claude-3-opus-20240229": 7.5, // $15 / 1M tokens
"ERNIE-Bot": 0.8572, // ¥0.012 / 1k tokens //renamed to ERNIE-3.5-8K
"ERNIE-Bot-turbo": 0.5715, // ¥0.008 / 1k tokens //renamed to ERNIE-Lite-8K
@@ -114,6 +115,7 @@ var defaultModelRatio = map[string]float64{
"SparkDesk-v2.1": 1.2858, // ¥0.018 / 1k tokens
"SparkDesk-v3.1": 1.2858, // ¥0.018 / 1k tokens
"SparkDesk-v3.5": 1.2858, // ¥0.018 / 1k tokens
"SparkDesk-v4.0": 1.2858,
"360GPT_S2_V9": 0.8572, // ¥0.012 / 1k tokens
"360gpt-turbo": 0.0858, // ¥0.0012 / 1k tokens
"360gpt-turbo-responsibility-8k": 0.8572, // ¥0.012 / 1k tokens

View File

@@ -222,16 +222,18 @@ func testAllChannels(notify bool) error {
if channel.AutoBan != nil && *channel.AutoBan == 0 {
ban = false
}
openAiErrWithStatus := dto.OpenAIErrorWithStatusCode{
StatusCode: -1,
Error: *openaiErr,
LocalError: false,
}
if isChannelEnabled && service.ShouldDisableChannel(&openAiErrWithStatus) && ban {
service.DisableChannel(channel.Id, channel.Name, err.Error())
}
if !isChannelEnabled && service.ShouldEnableChannel(err, openaiErr, channel.Status) {
service.EnableChannel(channel.Id, channel.Name)
if openaiErr != nil {
openAiErrWithStatus := dto.OpenAIErrorWithStatusCode{
StatusCode: -1,
Error: *openaiErr,
LocalError: false,
}
if isChannelEnabled && service.ShouldDisableChannel(&openAiErrWithStatus) && ban {
service.DisableChannel(channel.Id, channel.Name, err.Error())
}
if !isChannelEnabled && service.ShouldEnableChannel(err, openaiErr, channel.Status) {
service.EnableChannel(channel.Id, channel.Name)
}
}
channel.UpdateResponseTime(milliseconds)
time.Sleep(common.RequestInterval)

View File

@@ -56,6 +56,11 @@ func getPriority(group string, model string, retry int) (int, error) {
return 0, err
}
if len(priorities) == 0 {
// 如果没有查询到优先级,则返回错误
return 0, errors.New("数据库一致性被破坏")
}
// 确定要使用的优先级
var priorityToUse int
if retry >= len(priorities) {
@@ -199,7 +204,7 @@ func FixAbility() (int, error) {
// Use channelIds to find channel not in abilities table
var abilityChannelIds []int
err = DB.Model(&Ability{}).Pluck("channel_id", &abilityChannelIds).Error
err = DB.Table("abilities").Distinct("channel_id").Pluck("channel_id", &abilityChannelIds).Error
if err != nil {
common.SysError(fmt.Sprintf("Get channel ids from abilities table failed: %s", err.Error()))
return 0, err

View File

@@ -14,6 +14,7 @@ import (
"one-api/relay/channel/claude"
relaycommon "one-api/relay/common"
"strings"
"time"
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/credentials"
@@ -156,6 +157,7 @@ func awsStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, requestMode i
var usage relaymodel.Usage
var id string
var model string
isFirst := true
createdTime := common.GetTimestamp()
c.Stream(func(w io.Writer) bool {
event, ok := <-stream.Events()
@@ -166,6 +168,10 @@ func awsStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, requestMode i
switch v := event.(type) {
case *types.ResponseStreamMemberChunk:
if isFirst {
isFirst = false
info.FirstResponseTime = time.Now()
}
claudeResp := new(claude.ClaudeResponse)
err := json.NewDecoder(bytes.NewReader(v.Value.Bytes)).Decode(claudeResp)
if err != nil {

View File

@@ -65,7 +65,7 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
if info.IsStream {
err, usage = claudeStreamHandler(a.RequestMode, info.UpstreamModelName, info.PromptTokens, c, resp)
err, usage = claudeStreamHandler(c, resp, info, a.RequestMode)
} else {
err, usage = claudeHandler(a.RequestMode, c, resp, info.PromptTokens, info.UpstreamModelName)
}

View File

@@ -8,6 +8,7 @@ var ModelList = []string{
"claude-3-sonnet-20240229",
"claude-3-opus-20240229",
"claude-3-haiku-20240307",
"claude-3-5-sonnet-20240620",
}
var ChannelName = "claude"

View File

@@ -9,8 +9,10 @@ import (
"net/http"
"one-api/common"
"one-api/dto"
relaycommon "one-api/relay/common"
"one-api/service"
"strings"
"time"
)
func stopReasonClaude2OpenAI(reason string) string {
@@ -246,7 +248,7 @@ func ResponseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) *dto.Ope
return &fullTextResponse
}
func claudeStreamHandler(requestMode int, modelName string, promptTokens int, c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
func claudeStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, requestMode int) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
responseId := fmt.Sprintf("chatcmpl-%s", common.GetUUID())
var usage *dto.Usage
usage = &dto.Usage{}
@@ -278,10 +280,15 @@ func claudeStreamHandler(requestMode int, modelName string, promptTokens int, c
}
stopChan <- true
}()
isFirst := true
service.SetEventStreamHeaders(c)
c.Stream(func(w io.Writer) bool {
select {
case data := <-dataChan:
if isFirst {
isFirst = false
info.FirstResponseTime = time.Now()
}
// some implementations may add \r at the end of data
data = strings.TrimSuffix(data, "\r")
var claudeResponse ClaudeResponse
@@ -302,7 +309,7 @@ func claudeStreamHandler(requestMode int, modelName string, promptTokens int, c
if claudeResponse.Type == "message_start" {
// message_start, 获取usage
responseId = claudeResponse.Message.Id
modelName = claudeResponse.Message.Model
info.UpstreamModelName = claudeResponse.Message.Model
usage.PromptTokens = claudeUsage.InputTokens
} else if claudeResponse.Type == "content_block_delta" {
responseText += claudeResponse.Delta.Text
@@ -316,7 +323,7 @@ func claudeStreamHandler(requestMode int, modelName string, promptTokens int, c
//response.Id = responseId
response.Id = responseId
response.Created = createdTime
response.Model = modelName
response.Model = info.UpstreamModelName
jsonStr, err := json.Marshal(response)
if err != nil {
@@ -335,13 +342,13 @@ func claudeStreamHandler(requestMode int, modelName string, promptTokens int, c
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
if requestMode == RequestModeCompletion {
usage, _ = service.ResponseText2Usage(responseText, modelName, promptTokens)
usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
} else {
if usage.PromptTokens == 0 {
usage.PromptTokens = promptTokens
usage.PromptTokens = info.PromptTokens
}
if usage.CompletionTokens == 0 {
usage, _ = service.ResponseText2Usage(responseText, modelName, usage.PromptTokens)
usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, usage.PromptTokens)
}
}
return nil, usage

View File

@@ -20,27 +20,27 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIReq
// 定义一个映射,存储模型名称和对应的版本
var modelVersionMap = map[string]string{
"gemini-1.5-pro-latest": "v1beta",
"gemini-1.5-flash-latest": "v1beta",
"gemini-ultra": "v1beta",
"gemini-1.5-pro-latest": "v1beta",
"gemini-1.5-flash-latest": "v1beta",
"gemini-ultra": "v1beta",
}
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
// 从映射中获取模型名称对应的版本,如果找不到就使用 info.ApiVersion 或默认的版本 "v1"
version, beta := modelVersionMap[info.UpstreamModelName]
if !beta {
if info.ApiVersion != "" {
version = info.ApiVersion
} else {
version = "v1"
}
}
// 从映射中获取模型名称对应的版本,如果找不到就使用 info.ApiVersion 或默认的版本 "v1"
version, beta := modelVersionMap[info.UpstreamModelName]
if !beta {
if info.ApiVersion != "" {
version = info.ApiVersion
} else {
version = "v1"
}
}
action := "generateContent"
if info.IsStream {
action = "streamGenerateContent"
}
return fmt.Sprintf("%s/%s/models/%s:%s", info.BaseUrl, version, info.UpstreamModelName, action), nil
action := "generateContent"
if info.IsStream {
action = "streamGenerateContent"
}
return fmt.Sprintf("%s/%s/models/%s:%s", info.BaseUrl, version, info.UpstreamModelName, action), nil
}
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
@@ -63,7 +63,7 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
if info.IsStream {
var responseText string
err, responseText = geminiChatStreamHandler(c, resp)
err, responseText = geminiChatStreamHandler(c, resp, info)
usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
} else {
err, usage = geminiChatHandler(c, resp, info.PromptTokens, info.UpstreamModelName)

View File

@@ -11,6 +11,7 @@ import (
relaycommon "one-api/relay/common"
"one-api/service"
"strings"
"time"
"github.com/gin-gonic/gin"
)
@@ -160,7 +161,7 @@ func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) *dto.Ch
return &response
}
func geminiChatStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, string) {
func geminiChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, string) {
responseText := ""
dataChan := make(chan string)
stopChan := make(chan bool)
@@ -190,10 +191,15 @@ func geminiChatStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIEr
}
stopChan <- true
}()
isFirst := true
service.SetEventStreamHeaders(c)
c.Stream(func(w io.Writer) bool {
select {
case data := <-dataChan:
if isFirst {
isFirst = false
info.FirstResponseTime = time.Now()
}
// this is used to prevent annoying \ related format bug
data = fmt.Sprintf("{\"content\": \"%s\"}", data)
type dummyStruct struct {

View File

@@ -6,6 +6,7 @@ var ModelList = []string{
"SparkDesk-v2.1",
"SparkDesk-v3.1",
"SparkDesk-v3.5",
"SparkDesk-v4.0",
}
var ChannelName = "xunfei"

View File

@@ -252,6 +252,8 @@ func apiVersion2domain(apiVersion string) string {
return "generalv3"
case "v3.5":
return "generalv3.5"
case "v4.0":
return "4.0Ultra"
}
return "general" + apiVersion
}

View File

@@ -38,24 +38,26 @@ func GenRelayInfo(c *gin.Context) *RelayInfo {
group := c.GetString("group")
tokenUnlimited := c.GetBool("token_unlimited_quota")
startTime := time.Now()
// firstResponseTime = time.Now() - 1 second
apiType, _ := constant.ChannelType2APIType(channelType)
info := &RelayInfo{
RelayMode: constant.Path2RelayMode(c.Request.URL.Path),
BaseUrl: c.GetString("base_url"),
RequestURLPath: c.Request.URL.String(),
ChannelType: channelType,
ChannelId: channelId,
TokenId: tokenId,
UserId: userId,
Group: group,
TokenUnlimited: tokenUnlimited,
StartTime: startTime,
ApiType: apiType,
ApiVersion: c.GetString("api_version"),
ApiKey: strings.TrimPrefix(c.Request.Header.Get("Authorization"), "Bearer "),
Organization: c.GetString("channel_organization"),
RelayMode: constant.Path2RelayMode(c.Request.URL.Path),
BaseUrl: c.GetString("base_url"),
RequestURLPath: c.Request.URL.String(),
ChannelType: channelType,
ChannelId: channelId,
TokenId: tokenId,
UserId: userId,
Group: group,
TokenUnlimited: tokenUnlimited,
StartTime: startTime,
FirstResponseTime: startTime.Add(-time.Second),
ApiType: apiType,
ApiVersion: c.GetString("api_version"),
ApiKey: strings.TrimPrefix(c.Request.Header.Get("Authorization"), "Bearer "),
Organization: c.GetString("channel_organization"),
}
if info.BaseUrl == "" {
info.BaseUrl = common.ChannelBaseURLs[channelType]