Compare commits

..

10 Commits

Author SHA1 Message Date
JustSong
ed717211aa chore: adjust default rate limit config 2024-06-13 00:35:37 +08:00
JustSong
6ccf3f3cfc chore: add logger.SysLogf function 2024-06-13 00:28:56 +08:00
jinjianming
f74577141c fix: fix default token not created in some cases (#1510)
* 修复git、微信等用户注册不会创建默认令牌问题

修复git、微信等用户注册不会创建默认令牌问题

* 修复git、微信等用户注册不会创建默认令牌问题

删除普通用户注册代码

* fix: do not block if error happened

---------

Co-authored-by: JustSong <songquanpeng@foxmail.com>
2024-06-13 00:20:48 +08:00
Buer
6aafb7a99e fix: channel edit settings key error (#1496) 2024-06-13 00:08:49 +08:00
Zhong Liu
c1971870fa fix: support for Spark Lite model (#1526)
* fix: Support for Spark Lite model

* fix: fix panic

* fix: fix xunfei version config

---------

Co-authored-by: JustSong <39998050+songquanpeng@users.noreply.github.com>
Co-authored-by: JustSong <songquanpeng@foxmail.com>
2024-06-13 00:07:26 +08:00
wagxuebing
f83894c83f fix: xunfei interface call 4001 error (#1499)
Co-authored-by: lynnssb <lynntobing@gmail.com>
2024-06-12 23:12:58 +08:00
fxsome
e9981fff36 feat: post all messages for cloudflare (#1515) 2024-06-08 13:34:23 +08:00
取梦为饮
98669d5d48 feat: add support for bytedance's doubao (#1438)
* 增加豆包大模型支持

* chore: update channel options & add prompt

---------

Co-authored-by: 康龙彪 <longbiao.kang@i-tudou.com>
Co-authored-by: JustSong <songquanpeng@foxmail.com>
2024-06-08 13:26:26 +08:00
Wei Tingjiang
9321427c6e feat: support gemini embeddings (text-embedding-004,embedding-001) (#1475)
* Refactor Gemini Adaptor to Support Embeddings

* Add new models to ModelList
2024-05-29 01:17:32 +08:00
JustSong
ceea4c6d4a feat: support user content download proxy & relay proxy now 2024-05-29 01:14:00 +08:00
33 changed files with 353 additions and 128 deletions

View File

@@ -68,6 +68,7 @@ _✨ 通过标准的 OpenAI API 格式访问所有的大模型,开箱即用
+ [x] [Anthropic Claude 系列模型](https://anthropic.com) (支持 AWS Claude)
+ [x] [Google PaLM2/Gemini 系列模型](https://developers.generativeai.google)
+ [x] [Mistral 系列模型](https://mistral.ai/)
+ [x] [字节跳动豆包大模型](https://console.volcengine.com/ark/region:ark+cn-beijing/model)
+ [x] [百度文心一言系列模型](https://cloud.baidu.com/doc/WENXINWORKSHOP/index.html)
+ [x] [阿里通义千问系列模型](https://help.aliyun.com/document_detail/2400395.html)
+ [x] [讯飞星火认知大模型](https://www.xfyun.cn/doc/spark/Web.html)
@@ -76,7 +77,6 @@ _✨ 通过标准的 OpenAI API 格式访问所有的大模型,开箱即用
+ [x] [腾讯混元大模型](https://cloud.tencent.com/document/product/1729)
+ [x] [Moonshot AI](https://platform.moonshot.cn/)
+ [x] [百川大模型](https://platform.baichuan-ai.com)
+ [ ] [字节云雀大模型](https://www.volcengine.com/product/ark) (WIP)
+ [x] [MINIMAX](https://api.minimax.chat/)
+ [x] [Groq](https://wow.groq.com/)
+ [x] [Ollama](https://github.com/ollama/ollama)
@@ -384,14 +384,17 @@ graph LR
+ `TIKTOKEN_CACHE_DIR`:默认程序启动时会联网下载一些通用的词元的编码,如:`gpt-3.5-turbo`,在一些网络环境不稳定,或者离线情况,可能会导致启动有问题,可以配置此目录缓存数据,可迁移到离线环境。
+ `DATA_GYM_CACHE_DIR`:目前该配置作用与 `TIKTOKEN_CACHE_DIR` 一致,但是优先级没有它高。
17. `RELAY_TIMEOUT`:中继超时设置,单位为秒,默认不设置超时时间。
18. `SQLITE_BUSY_TIMEOUT`SQLite 锁等待超时设置,单位为毫秒,默认 `3000`
19. `GEMINI_SAFETY_SETTING`Gemini 的安全设置,默认 `BLOCK_NONE`
20. `GEMINI_VERSION`One API 所使用的 Gemini 版本,默认为 `v1`
21. `THEME`:系统的主题设置,默认 `default`,具体可选值参考[此处](./web/README.md)
22. `ENABLE_METRIC`:是否根据请求成功率禁用渠道,默认不开启,可选值为 `true` 和 `false`。
23. `METRIC_QUEUE_SIZE`:请求成功率统计队列大小,默认为 `10`。
24. `METRIC_SUCCESS_RATE_THRESHOLD`:请求成功率阈值,默认为 `0.8`
25. `INITIAL_ROOT_TOKEN`:如果设置了该值,则在系统首次启动时会自动创建一个值为该环境变量值的 root 用户令牌
18. `RELAY_PROXY`:设置后使用该代理来请求 API
19. `USER_CONTENT_REQUEST_TIMEOUT`:用户上传内容下载超时时间,单位为秒
20. `USER_CONTENT_REQUEST_PROXY`:设置后使用该代理来请求用户上传的内容,例如图片
21. `SQLITE_BUSY_TIMEOUT`SQLite 锁等待超时设置,单位为毫秒,默认 `3000`
22. `GEMINI_SAFETY_SETTING`Gemini 的安全设置,默认 `BLOCK_NONE`。
23. `GEMINI_VERSION`One API 所使用的 Gemini 版本,默认为 `v1`。
24. `THEME`:系统的主题设置,默认为 `default`,具体可选值参考[此处](./web/README.md)
25. `ENABLE_METRIC`:是否根据请求成功率禁用渠道,默认不开启,可选值为 `true` 和 `false`
26. `METRIC_QUEUE_SIZE`:请求成功率统计队列大小,默认为 `10`。
27. `METRIC_SUCCESS_RATE_THRESHOLD`:请求成功率阈值,默认为 `0.8`。
28. `INITIAL_ROOT_TOKEN`:如果设置了该值,则在系统首次启动时会自动创建一个值为该环境变量值的 root 用户令牌。
### 命令行参数
1. `--port <port_number>`: 指定服务器监听的端口号,默认为 `3000`。

60
common/client/init.go Normal file
View File

@@ -0,0 +1,60 @@
package client
import (
"fmt"
"github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/logger"
"net/http"
"net/url"
"time"
)
var HTTPClient *http.Client
var ImpatientHTTPClient *http.Client
var UserContentRequestHTTPClient *http.Client
func Init() {
if config.UserContentRequestProxy != "" {
logger.SysLog(fmt.Sprintf("using %s as proxy to fetch user content", config.UserContentRequestProxy))
proxyURL, err := url.Parse(config.UserContentRequestProxy)
if err != nil {
logger.FatalLog(fmt.Sprintf("USER_CONTENT_REQUEST_PROXY set but invalid: %s", config.UserContentRequestProxy))
}
transport := &http.Transport{
Proxy: http.ProxyURL(proxyURL),
}
UserContentRequestHTTPClient = &http.Client{
Transport: transport,
Timeout: time.Second * time.Duration(config.UserContentRequestTimeout),
}
} else {
UserContentRequestHTTPClient = &http.Client{}
}
var transport http.RoundTripper
if config.RelayProxy != "" {
logger.SysLog(fmt.Sprintf("using %s as api relay proxy", config.RelayProxy))
proxyURL, err := url.Parse(config.RelayProxy)
if err != nil {
logger.FatalLog(fmt.Sprintf("USER_CONTENT_REQUEST_PROXY set but invalid: %s", config.UserContentRequestProxy))
}
transport = &http.Transport{
Proxy: http.ProxyURL(proxyURL),
}
}
if config.RelayTimeout == 0 {
HTTPClient = &http.Client{
Transport: transport,
}
} else {
HTTPClient = &http.Client{
Timeout: time.Duration(config.RelayTimeout) * time.Second,
Transport: transport,
}
}
ImpatientHTTPClient = &http.Client{
Timeout: 5 * time.Second,
Transport: transport,
}
}

View File

@@ -117,10 +117,10 @@ var ValidThemes = map[string]bool{
// All duration's unit is seconds
// Shouldn't larger then RateLimitKeyExpirationDuration
var (
GlobalApiRateLimitNum = env.Int("GLOBAL_API_RATE_LIMIT", 180)
GlobalApiRateLimitNum = env.Int("GLOBAL_API_RATE_LIMIT", 240)
GlobalApiRateLimitDuration int64 = 3 * 60
GlobalWebRateLimitNum = env.Int("GLOBAL_WEB_RATE_LIMIT", 60)
GlobalWebRateLimitNum = env.Int("GLOBAL_WEB_RATE_LIMIT", 120)
GlobalWebRateLimitDuration int64 = 3 * 60
UploadRateLimitNum = 10
@@ -144,3 +144,7 @@ var MetricFailChanSize = env.Int("METRIC_FAIL_CHAN_SIZE", 128)
var InitialRootToken = os.Getenv("INITIAL_ROOT_TOKEN")
var GeminiVersion = env.String("GEMINI_VERSION", "v1")
var RelayProxy = env.String("RELAY_PROXY", "")
var UserContentRequestProxy = env.String("USER_CONTENT_REQUEST_PROXY", "")
var UserContentRequestTimeout = env.Int("USER_CONTENT_REQUEST_TIMEOUT", 30)

View File

@@ -3,6 +3,7 @@ package image
import (
"bytes"
"encoding/base64"
"github.com/songquanpeng/one-api/common/client"
"image"
_ "image/gif"
_ "image/jpeg"
@@ -19,7 +20,7 @@ import (
var dataURLPattern = regexp.MustCompile(`data:image/([^;]+);base64,(.*)`)
func IsImageUrl(url string) (bool, error) {
resp, err := http.Head(url)
resp, err := client.UserContentRequestHTTPClient.Head(url)
if err != nil {
return false, err
}
@@ -34,7 +35,7 @@ func GetImageSizeFromUrl(url string) (width int, height int, err error) {
if !isImage {
return
}
resp, err := http.Get(url)
resp, err := client.UserContentRequestHTTPClient.Get(url)
if err != nil {
return
}

View File

@@ -43,11 +43,19 @@ func SysLog(s string) {
_, _ = fmt.Fprintf(gin.DefaultWriter, "[SYS] %v | %s \n", t.Format("2006/01/02 - 15:04:05"), s)
}
func SysLogf(format string, a ...any) {
SysLog(fmt.Sprintf(format, a...))
}
func SysError(s string) {
t := time.Now()
_, _ = fmt.Fprintf(gin.DefaultErrorWriter, "[SYS] %v | %s \n", t.Format("2006/01/02 - 15:04:05"), s)
}
func SysErrorf(format string, a ...any) {
SysError(fmt.Sprintf(format, a...))
}
func Debug(ctx context.Context, msg string) {
if config.DebugEnabled {
logHelper(ctx, loggerDEBUG, msg)

View File

@@ -4,12 +4,12 @@ import (
"encoding/json"
"errors"
"fmt"
"github.com/songquanpeng/one-api/common/client"
"github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/model"
"github.com/songquanpeng/one-api/monitor"
"github.com/songquanpeng/one-api/relay/channeltype"
"github.com/songquanpeng/one-api/relay/client"
"io"
"net/http"
"strconv"

View File

@@ -6,8 +6,6 @@ import (
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/ctxkey"
"github.com/songquanpeng/one-api/common/helper"
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/common/random"
"github.com/songquanpeng/one-api/model"
"net/http"
@@ -111,7 +109,6 @@ func Logout(c *gin.Context) {
}
func Register(c *gin.Context) {
ctx := c.Request.Context()
if !config.RegisterEnabled {
c.JSON(http.StatusOK, gin.H{
"message": "管理员关闭了新用户注册",
@@ -176,28 +173,7 @@ func Register(c *gin.Context) {
})
return
}
go func() {
err := user.ValidateAndFill()
if err != nil {
logger.Errorf(ctx, "user.ValidateAndFill failed: %w", err)
return
}
cleanToken := model.Token{
UserId: user.Id,
Name: "default",
Key: random.GenerateKey(),
CreatedTime: helper.GetTimestamp(),
AccessedTime: helper.GetTimestamp(),
ExpiredTime: -1,
RemainQuota: -1,
UnlimitedQuota: true,
}
err = cleanToken.Insert()
if err != nil {
logger.Errorf(ctx, "cleanToken.Insert failed: %w", err)
return
}
}()
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",

View File

@@ -7,6 +7,7 @@ import (
"github.com/gin-contrib/sessions/cookie"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/client"
"github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/controller"
@@ -23,7 +24,7 @@ var buildFS embed.FS
func main() {
logger.SetupLogger()
logger.SysLog(fmt.Sprintf("One API %s started", common.Version))
logger.SysLogf("One API %s started", common.Version)
if os.Getenv("GIN_MODE") != "debug" {
gin.SetMode(gin.ReleaseMode)
}
@@ -94,6 +95,7 @@ func main() {
logger.SysLog("metric enabled, will disable channel if too much request failed")
}
openai.InitTokenEncoders()
client.Init()
// Initialize HTTP server
server := gin.New()

View File

@@ -67,26 +67,28 @@ func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, mode
c.Set(ctxkey.BaseURL, channel.GetBaseURL())
cfg, _ := channel.LoadConfig()
// this is for backward compatibility
switch channel.Type {
case channeltype.Azure:
if cfg.APIVersion == "" {
cfg.APIVersion = channel.Other
}
case channeltype.Xunfei:
if cfg.APIVersion == "" {
cfg.APIVersion = channel.Other
}
case channeltype.Gemini:
if cfg.APIVersion == "" {
cfg.APIVersion = channel.Other
}
case channeltype.AIProxyLibrary:
if cfg.LibraryID == "" {
cfg.LibraryID = channel.Other
}
case channeltype.Ali:
if cfg.Plugin == "" {
cfg.Plugin = channel.Other
if channel.Other != nil {
switch channel.Type {
case channeltype.Azure:
if cfg.APIVersion == "" {
cfg.APIVersion = *channel.Other
}
case channeltype.Xunfei:
if cfg.APIVersion == "" {
cfg.APIVersion = *channel.Other
}
case channeltype.Gemini:
if cfg.APIVersion == "" {
cfg.APIVersion = *channel.Other
}
case channeltype.AIProxyLibrary:
if cfg.LibraryID == "" {
cfg.LibraryID = *channel.Other
}
case channeltype.Ali:
if cfg.Plugin == "" {
cfg.Plugin = *channel.Other
}
}
}
c.Set(ctxkey.Config, cfg)

View File

@@ -27,7 +27,7 @@ type Channel struct {
TestTime int64 `json:"test_time" gorm:"bigint"`
ResponseTime int `json:"response_time"` // in milliseconds
BaseURL *string `json:"base_url" gorm:"column:base_url;default:''"`
Other string `json:"other"` // DEPRECATED: please save config to field Config
Other *string `json:"other"` // DEPRECATED: please save config to field Config
Balance float64 `json:"balance"` // in USD
BalanceUpdatedTime int64 `json:"balance_updated_time" gorm:"bigint"`
Models string `json:"models"`

View File

@@ -6,6 +6,7 @@ import (
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/blacklist"
"github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/helper"
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/common/random"
"gorm.io/gorm"
@@ -140,6 +141,22 @@ func (user *User) Insert(inviterId int) error {
RecordLog(inviterId, LogTypeSystem, fmt.Sprintf("邀请用户赠送 %s", common.LogQuota(config.QuotaForInviter)))
}
}
// create default token
cleanToken := Token{
UserId: user.Id,
Name: "default",
Key: random.GenerateKey(),
CreatedTime: helper.GetTimestamp(),
AccessedTime: helper.GetTimestamp(),
ExpiredTime: -1,
RemainQuota: -1,
UnlimitedQuota: true,
}
result.Error = cleanToken.Insert()
if result.Error != nil {
// do not block
logger.SysError(fmt.Sprintf("create default token for user %d failed: %s", user.Id, result.Error.Error()))
}
return nil
}

View File

@@ -7,9 +7,9 @@ import (
"fmt"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/client"
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/relay/adaptor/openai"
"github.com/songquanpeng/one-api/relay/client"
"github.com/songquanpeng/one-api/relay/constant"
"github.com/songquanpeng/one-api/relay/model"
"io"

View File

@@ -17,15 +17,21 @@ import (
)
func ConvertRequest(textRequest model.GeneralOpenAIRequest) *Request {
lastMessage := textRequest.Messages[len(textRequest.Messages)-1]
return &Request{
MaxTokens: textRequest.MaxTokens,
Prompt: lastMessage.StringContent(),
Stream: textRequest.Stream,
Temperature: textRequest.Temperature,
}
var promptBuilder strings.Builder
for _, message := range textRequest.Messages {
promptBuilder.WriteString(message.StringContent())
promptBuilder.WriteString("\n") // 添加换行符来分隔每个消息
}
return &Request{
MaxTokens: textRequest.MaxTokens,
Prompt: promptBuilder.String(),
Stream: textRequest.Stream,
Temperature: textRequest.Temperature,
}
}
func ResponseCloudflare2OpenAI(cloudflareResponse *Response) *openai.TextResponse {
choice := openai.TextResponseChoice{
Index: 0,

View File

@@ -4,7 +4,7 @@ import (
"errors"
"fmt"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/relay/client"
"github.com/songquanpeng/one-api/common/client"
"github.com/songquanpeng/one-api/relay/meta"
"io"
"net/http"

View File

@@ -0,0 +1,13 @@
package doubao
// https://console.volcengine.com/ark/region:ark+cn-beijing/model
var ModelList = []string{
"Doubao-pro-128k",
"Doubao-pro-32k",
"Doubao-pro-4k",
"Doubao-lite-128k",
"Doubao-lite-32k",
"Doubao-lite-4k",
"Doubao-embedding",
}

View File

@@ -0,0 +1,14 @@
package doubao
import (
"fmt"
"github.com/songquanpeng/one-api/relay/meta"
"github.com/songquanpeng/one-api/relay/relaymode"
)
func GetRequestURL(meta *meta.Meta) (string, error) {
if meta.Mode == relaymode.ChatCompletions {
return fmt.Sprintf("%s/api/v3/chat/completions", meta.BaseURL), nil
}
return "", fmt.Errorf("unsupported relay mode %d for doubao", meta.Mode)
}

View File

@@ -13,6 +13,7 @@ import (
"github.com/songquanpeng/one-api/relay/adaptor/openai"
"github.com/songquanpeng/one-api/relay/meta"
"github.com/songquanpeng/one-api/relay/model"
"github.com/songquanpeng/one-api/relay/relaymode"
)
type Adaptor struct {
@@ -24,7 +25,14 @@ func (a *Adaptor) Init(meta *meta.Meta) {
func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) {
version := helper.AssignOrDefault(meta.Config.APIVersion, config.GeminiVersion)
action := "generateContent"
action := ""
switch meta.Mode {
case relaymode.Embeddings:
action = "batchEmbedContents"
default:
action = "generateContent"
}
if meta.IsStream {
action = "streamGenerateContent?alt=sse"
}
@@ -41,7 +49,14 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G
if request == nil {
return nil, errors.New("request is nil")
}
return ConvertRequest(*request), nil
switch relayMode {
case relaymode.Embeddings:
geminiEmbeddingRequest := ConvertEmbeddingRequest(*request)
return geminiEmbeddingRequest, nil
default:
geminiRequest := ConvertRequest(*request)
return geminiRequest, nil
}
}
func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) {
@@ -61,7 +76,12 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Met
err, responseText = StreamHandler(c, resp)
usage = openai.ResponseText2Usage(responseText, meta.ActualModelName, meta.PromptTokens)
} else {
err, usage = Handler(c, resp, meta.PromptTokens, meta.ActualModelName)
switch meta.Mode {
case relaymode.Embeddings:
err, usage = EmbeddingHandler(c, resp)
default:
err, usage = Handler(c, resp, meta.PromptTokens, meta.ActualModelName)
}
}
return
}

View File

@@ -4,5 +4,5 @@ package gemini
var ModelList = []string{
"gemini-pro", "gemini-1.0-pro-001", "gemini-1.5-pro",
"gemini-pro-vision", "gemini-1.0-pro-vision-001",
"gemini-pro-vision", "gemini-1.0-pro-vision-001", "embedding-001", "text-embedding-004",
}

View File

@@ -134,6 +134,29 @@ func ConvertRequest(textRequest model.GeneralOpenAIRequest) *ChatRequest {
return &geminiRequest
}
func ConvertEmbeddingRequest(request model.GeneralOpenAIRequest) *BatchEmbeddingRequest {
inputs := request.ParseInput()
requests := make([]EmbeddingRequest, len(inputs))
model := fmt.Sprintf("models/%s", request.Model)
for i, input := range inputs {
requests[i] = EmbeddingRequest{
Model: model,
Content: ChatContent{
Parts: []Part{
{
Text: input,
},
},
},
}
}
return &BatchEmbeddingRequest{
Requests: requests,
}
}
type ChatResponse struct {
Candidates []ChatCandidate `json:"candidates"`
PromptFeedback ChatPromptFeedback `json:"promptFeedback"`
@@ -230,6 +253,23 @@ func streamResponseGeminiChat2OpenAI(geminiResponse *ChatResponse) *openai.ChatC
return &response
}
func embeddingResponseGemini2OpenAI(response *EmbeddingResponse) *openai.EmbeddingResponse {
openAIEmbeddingResponse := openai.EmbeddingResponse{
Object: "list",
Data: make([]openai.EmbeddingResponseItem, 0, len(response.Embeddings)),
Model: "gemini-embedding",
Usage: model.Usage{TotalTokens: 0},
}
for _, item := range response.Embeddings {
openAIEmbeddingResponse.Data = append(openAIEmbeddingResponse.Data, openai.EmbeddingResponseItem{
Object: `embedding`,
Index: 0,
Embedding: item.Values,
})
}
return &openAIEmbeddingResponse
}
func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, string) {
responseText := ""
scanner := bufio.NewScanner(resp.Body)
@@ -337,3 +377,39 @@ func Handler(c *gin.Context, resp *http.Response, promptTokens int, modelName st
_, err = c.Writer.Write(jsonResponse)
return nil, &usage
}
func EmbeddingHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
var geminiEmbeddingResponse EmbeddingResponse
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
}
err = resp.Body.Close()
if err != nil {
return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
err = json.Unmarshal(responseBody, &geminiEmbeddingResponse)
if err != nil {
return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
}
if geminiEmbeddingResponse.Error != nil {
return &model.ErrorWithStatusCode{
Error: model.Error{
Message: geminiEmbeddingResponse.Error.Message,
Type: "gemini_error",
Param: "",
Code: geminiEmbeddingResponse.Error.Code,
},
StatusCode: resp.StatusCode,
}, nil
}
fullTextResponse := embeddingResponseGemini2OpenAI(&geminiEmbeddingResponse)
jsonResponse, err := json.Marshal(fullTextResponse)
if err != nil {
return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
}
c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode)
_, err = c.Writer.Write(jsonResponse)
return nil, &fullTextResponse.Usage
}

View File

@@ -7,6 +7,33 @@ type ChatRequest struct {
Tools []ChatTools `json:"tools,omitempty"`
}
type EmbeddingRequest struct {
Model string `json:"model"`
Content ChatContent `json:"content"`
TaskType string `json:"taskType,omitempty"`
Title string `json:"title,omitempty"`
OutputDimensionality int `json:"outputDimensionality,omitempty"`
}
type BatchEmbeddingRequest struct {
Requests []EmbeddingRequest `json:"requests"`
}
type EmbeddingData struct {
Values []float64 `json:"values"`
}
type EmbeddingResponse struct {
Embeddings []EmbeddingData `json:"embeddings"`
Error *Error `json:"error,omitempty"`
}
type Error struct {
Code int `json:"code,omitempty"`
Message string `json:"message,omitempty"`
Status string `json:"status,omitempty"`
}
type InlineData struct {
MimeType string `json:"mimeType"`
Data string `json:"data"`

View File

@@ -5,6 +5,7 @@ import (
"fmt"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/relay/adaptor"
"github.com/songquanpeng/one-api/relay/adaptor/doubao"
"github.com/songquanpeng/one-api/relay/adaptor/minimax"
"github.com/songquanpeng/one-api/relay/channeltype"
"github.com/songquanpeng/one-api/relay/meta"
@@ -45,6 +46,8 @@ func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) {
return GetFullRequestURL(meta.BaseURL, requestURL, meta.ChannelType), nil
case channeltype.Minimax:
return minimax.GetRequestURL(meta)
case channeltype.Doubao:
return doubao.GetRequestURL(meta)
default:
return GetFullRequestURL(meta.BaseURL, meta.RequestURLPath, meta.ChannelType), nil
}

View File

@@ -4,6 +4,7 @@ import (
"github.com/songquanpeng/one-api/relay/adaptor/ai360"
"github.com/songquanpeng/one-api/relay/adaptor/baichuan"
"github.com/songquanpeng/one-api/relay/adaptor/deepseek"
"github.com/songquanpeng/one-api/relay/adaptor/doubao"
"github.com/songquanpeng/one-api/relay/adaptor/groq"
"github.com/songquanpeng/one-api/relay/adaptor/lingyiwanwu"
"github.com/songquanpeng/one-api/relay/adaptor/minimax"
@@ -20,6 +21,7 @@ var CompatibleChannels = []int{
channeltype.Moonshot,
channeltype.Baichuan,
channeltype.Minimax,
channeltype.Doubao,
channeltype.Mistral,
channeltype.Groq,
channeltype.LingYiWanWu,
@@ -52,6 +54,8 @@ func GetCompatibleChannelMeta(channelType int) (string, []string) {
return "deepseek", deepseek.ModelList
case channeltype.TogetherAI:
return "together.ai", togetherai.ModelList
case channeltype.Doubao:
return "doubao", doubao.ModelList
default:
return "openai", ModelList
}

View File

@@ -27,14 +27,6 @@ func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) {
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error {
adaptor.SetupCommonRequestHeader(c, req, meta)
version := parseAPIVersionByModelName(meta.ActualModelName)
if version == "" {
version = a.meta.Config.APIVersion
}
if version == "" {
version = "v1.1"
}
a.meta.Config.APIVersion = version
// check DoResponse for auth part
return nil
}
@@ -69,6 +61,14 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Met
if a.request == nil {
return nil, openai.ErrorWrapper(errors.New("request is nil"), "request_is_nil", http.StatusBadRequest)
}
version := parseAPIVersionByModelName(meta.ActualModelName)
if version == "" {
version = a.meta.Config.APIVersion
}
if version == "" {
version = "v1.1"
}
a.meta.Config.APIVersion = version
if meta.IsStream {
err, usage = StreamHandler(c, meta, *a.request, splits[0], splits[1], splits[2])
} else {

View File

@@ -5,7 +5,14 @@ import (
"crypto/sha256"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"net/url"
"strings"
"time"
"github.com/gin-gonic/gin"
"github.com/gorilla/websocket"
"github.com/songquanpeng/one-api/common"
@@ -16,11 +23,6 @@ import (
"github.com/songquanpeng/one-api/relay/constant"
"github.com/songquanpeng/one-api/relay/meta"
"github.com/songquanpeng/one-api/relay/model"
"io"
"net/http"
"net/url"
"strings"
"time"
)
// https://console.xfyun.cn/services/cbm
@@ -28,11 +30,7 @@ import (
func requestOpenAI2Xunfei(request model.GeneralOpenAIRequest, xunfeiAppId string, domain string) *ChatRequest {
messages := make([]Message, 0, len(request.Messages))
var lastToolCalls []model.Tool
for _, message := range request.Messages {
if message.ToolCalls != nil {
lastToolCalls = message.ToolCalls
}
messages = append(messages, Message{
Role: message.Role,
Content: message.StringContent(),
@@ -45,9 +43,10 @@ func requestOpenAI2Xunfei(request model.GeneralOpenAIRequest, xunfeiAppId string
xunfeiRequest.Parameter.Chat.TopK = request.N
xunfeiRequest.Parameter.Chat.MaxTokens = request.MaxTokens
xunfeiRequest.Payload.Message.Text = messages
if len(lastToolCalls) != 0 {
for _, toolCall := range lastToolCalls {
xunfeiRequest.Payload.Functions.Text = append(xunfeiRequest.Payload.Functions.Text, toolCall.Function)
if strings.HasPrefix(domain, "generalv3") {
xunfeiRequest.Payload.Functions = &Functions{
Text: request.Tools,
}
}
@@ -203,7 +202,7 @@ func Handler(c *gin.Context, meta *meta.Meta, textRequest model.GeneralOpenAIReq
}
}
if len(xunfeiResponse.Payload.Choices.Text) == 0 {
return openai.ErrorWrapper(err, "xunfei_empty_response_detected", http.StatusInternalServerError), nil
return openai.ErrorWrapper(errors.New("xunfei empty response detected"), "xunfei_empty_response_detected", http.StatusInternalServerError), nil
}
xunfeiResponse.Payload.Choices.Text[0].Content = content

View File

@@ -9,6 +9,10 @@ type Message struct {
Content string `json:"content"`
}
type Functions struct {
Text []model.Tool `json:"text,omitempty"`
}
type ChatRequest struct {
Header struct {
AppId string `json:"app_id"`
@@ -26,9 +30,7 @@ type ChatRequest struct {
Message struct {
Text []Message `json:"text"`
} `json:"message"`
Functions struct {
Text []model.Function `json:"text,omitempty"`
} `json:"functions,omitempty"`
Functions *Functions `json:"functions,omitempty"`
} `json:"payload"`
}

View File

@@ -41,6 +41,6 @@ const (
Cloudflare
DeepL
TogetherAI
Doubao
Dummy
)

View File

@@ -41,6 +41,7 @@ var ChannelBaseURLs = []string{
"https://api.cloudflare.com", // 37
"https://api-free.deepl.com", // 38
"https://api.together.xyz", // 39
"https://ark.cn-beijing.volces.com", // 40
}
func init() {

View File

@@ -1,24 +0,0 @@
package client
import (
"github.com/songquanpeng/one-api/common/config"
"net/http"
"time"
)
var HTTPClient *http.Client
var ImpatientHTTPClient *http.Client
func init() {
if config.RelayTimeout == 0 {
HTTPClient = &http.Client{}
} else {
HTTPClient = &http.Client{
Timeout: time.Duration(config.RelayTimeout) * time.Second,
}
}
ImpatientHTTPClient = &http.Client{
Timeout: 5 * time.Second,
}
}

View File

@@ -9,6 +9,7 @@ import (
"fmt"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/client"
"github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/ctxkey"
"github.com/songquanpeng/one-api/common/logger"
@@ -17,7 +18,6 @@ import (
"github.com/songquanpeng/one-api/relay/billing"
billingratio "github.com/songquanpeng/one-api/relay/billing/ratio"
"github.com/songquanpeng/one-api/relay/channeltype"
"github.com/songquanpeng/one-api/relay/client"
"github.com/songquanpeng/one-api/relay/meta"
relaymodel "github.com/songquanpeng/one-api/relay/model"
"github.com/songquanpeng/one-api/relay/relaymode"

View File

@@ -47,6 +47,12 @@ export const CHANNEL_OPTIONS = {
value: 28,
color: 'warning'
},
40: {
key: 40,
text: '字节跳动豆包',
value: 40,
color: 'primary'
},
15: {
key: 15,
text: '百度文心千帆',

View File

@@ -163,7 +163,7 @@ const EditModal = ({ open, channelId, onCancel, onOk }) => {
values.other = 'v2.1';
}
if (values.key === '') {
if (values.config.ak !== '' && values.config.sk !== '' && values.config.region !== '') {
if (values.config.ak && values.config.sk && values.config.region) {
values.key = `${values.config.ak}|${values.config.sk}|${values.config.region}`;
}
}

View File

@@ -6,6 +6,7 @@ export const CHANNEL_OPTIONS = [
{key: 11, text: 'Google PaLM2', value: 11, color: 'orange'},
{key: 24, text: 'Google Gemini', value: 24, color: 'orange'},
{key: 28, text: 'Mistral AI', value: 28, color: 'orange'},
{key: 40, text: '字节跳动豆包', value: 40, color: 'blue'},
{key: 15, text: '百度文心千帆', value: 15, color: 'blue'},
{key: 17, text: '阿里通义千问', value: 17, color: 'orange'},
{key: 18, text: '讯飞星火认知', value: 18, color: 'blue'},

View File

@@ -181,9 +181,6 @@ const EditChannel = () => {
if (localInputs.type === 3 && localInputs.other === '') {
localInputs.other = '2024-03-01-preview';
}
if (localInputs.type === 18 && localInputs.other === '') {
localInputs.other = 'v2.1';
}
let res;
localInputs.models = localInputs.models.join(',');
localInputs.group = localInputs.groups.join(',');
@@ -362,6 +359,13 @@ const EditChannel = () => {
</Message>
)
}
{
inputs.type === 40 && (
<Message>
对于豆包而言需要手动去 <a target="_blank" href="https://console.volcengine.com/ark/region:ark+cn-beijing/endpoint">模型推理页面</a> `ep-20240608051426-tkxvl`
</Message>
)
}
<Form.Field>
<Form.Dropdown
label='模型'