mirror of https://github.com/songquanpeng/one-api.git
synced 2025-10-23 01:43:42 +08:00

Compare commits: 10 commits, v0.6.7-alp...v0.6.7-alp
Commits (SHA1):
ed717211aa
6ccf3f3cfc
f74577141c
6aafb7a99e
c1971870fa
f83894c83f
e9981fff36
98669d5d48
9321427c6e
ceea4c6d4a
README.md (21 lines changed)
@@ -68,6 +68,7 @@ _✨ Access all large models through the standard OpenAI API format, ready to use out of the box
 + [x] [Anthropic Claude series models](https://anthropic.com) (AWS Claude supported)
 + [x] [Google PaLM2/Gemini series models](https://developers.generativeai.google)
 + [x] [Mistral series models](https://mistral.ai/)
++ [x] [ByteDance Doubao large models](https://console.volcengine.com/ark/region:ark+cn-beijing/model)
 + [x] [Baidu Wenxin Yiyan (ERNIE) series models](https://cloud.baidu.com/doc/WENXINWORKSHOP/index.html)
 + [x] [Alibaba Tongyi Qianwen (Qwen) series models](https://help.aliyun.com/document_detail/2400395.html)
 + [x] [iFlytek Spark series models](https://www.xfyun.cn/doc/spark/Web.html)
@@ -76,7 +77,6 @@ _✨ Access all large models through the standard OpenAI API format, ready to use out of the box
 + [x] [Tencent Hunyuan models](https://cloud.tencent.com/document/product/1729)
 + [x] [Moonshot AI](https://platform.moonshot.cn/)
 + [x] [Baichuan models](https://platform.baichuan-ai.com)
-+ [ ] [ByteDance Skylark models](https://www.volcengine.com/product/ark) (WIP)
 + [x] [MINIMAX](https://api.minimax.chat/)
 + [x] [Groq](https://wow.groq.com/)
 + [x] [Ollama](https://github.com/ollama/ollama)
@@ -384,14 +384,17 @@ graph LR
 + `TIKTOKEN_CACHE_DIR`: at startup the program downloads the token encodings for some common models (e.g. `gpt-3.5-turbo`) from the network; on an unstable network, or offline, this can break startup. Set this directory to cache the data so it can also be migrated to offline environments.
 + `DATA_GYM_CACHE_DIR`: currently behaves the same as `TIKTOKEN_CACHE_DIR`, but with lower priority.
 17. `RELAY_TIMEOUT`: relay timeout in seconds; by default no timeout is set.
-18. `SQLITE_BUSY_TIMEOUT`: SQLite lock-wait timeout in milliseconds, default `3000`.
-19. `GEMINI_SAFETY_SETTING`: Gemini safety setting, default `BLOCK_NONE`.
-20. `GEMINI_VERSION`: the Gemini version used by One API, default `v1`.
-21. `THEME`: the system theme, default `default`; see [here](./web/README.md) for the available values.
-22. `ENABLE_METRIC`: whether to disable channels based on request success rate; off by default; valid values are `true` and `false`.
-23. `METRIC_QUEUE_SIZE`: size of the request success-rate statistics queue, default `10`.
-24. `METRIC_SUCCESS_RATE_THRESHOLD`: request success-rate threshold, default `0.8`.
-25. `INITIAL_ROOT_TOKEN`: if set, a root user token with this value is created automatically on the system's first startup.
+18. `RELAY_PROXY`: if set, this proxy is used for API requests.
+19. `USER_CONTENT_REQUEST_TIMEOUT`: timeout in seconds for downloading user-uploaded content.
+20. `USER_CONTENT_REQUEST_PROXY`: if set, this proxy is used to fetch user-uploaded content, e.g. images.
+21. `SQLITE_BUSY_TIMEOUT`: SQLite lock-wait timeout in milliseconds, default `3000`.
+22. `GEMINI_SAFETY_SETTING`: Gemini safety setting, default `BLOCK_NONE`.
+23. `GEMINI_VERSION`: the Gemini version used by One API, default `v1`.
+24. `THEME`: the system theme, default `default`; see [here](./web/README.md) for the available values.
+25. `ENABLE_METRIC`: whether to disable channels based on request success rate; off by default; valid values are `true` and `false`.
+26. `METRIC_QUEUE_SIZE`: size of the request success-rate statistics queue, default `10`.
+27. `METRIC_SUCCESS_RATE_THRESHOLD`: request success-rate threshold, default `0.8`.
+28. `INITIAL_ROOT_TOKEN`: if set, a root user token with this value is created automatically on the system's first startup.

 ### Command-line arguments

 1. `--port <port_number>`: the port the server listens on, default `3000`.
common/client/init.go (new file, 60 lines)
@@ -0,0 +1,60 @@
package client

import (
    "fmt"
    "github.com/songquanpeng/one-api/common/config"
    "github.com/songquanpeng/one-api/common/logger"
    "net/http"
    "net/url"
    "time"
)

var HTTPClient *http.Client
var ImpatientHTTPClient *http.Client
var UserContentRequestHTTPClient *http.Client

func Init() {
    if config.UserContentRequestProxy != "" {
        logger.SysLog(fmt.Sprintf("using %s as proxy to fetch user content", config.UserContentRequestProxy))
        proxyURL, err := url.Parse(config.UserContentRequestProxy)
        if err != nil {
            logger.FatalLog(fmt.Sprintf("USER_CONTENT_REQUEST_PROXY set but invalid: %s", config.UserContentRequestProxy))
        }
        transport := &http.Transport{
            Proxy: http.ProxyURL(proxyURL),
        }
        UserContentRequestHTTPClient = &http.Client{
            Transport: transport,
            Timeout:   time.Second * time.Duration(config.UserContentRequestTimeout),
        }
    } else {
        UserContentRequestHTTPClient = &http.Client{}
    }
    var transport http.RoundTripper
    if config.RelayProxy != "" {
        logger.SysLog(fmt.Sprintf("using %s as api relay proxy", config.RelayProxy))
        proxyURL, err := url.Parse(config.RelayProxy)
        if err != nil {
            logger.FatalLog(fmt.Sprintf("RELAY_PROXY set but invalid: %s", config.RelayProxy))
        }
        transport = &http.Transport{
            Proxy: http.ProxyURL(proxyURL),
        }
    }

    if config.RelayTimeout == 0 {
        HTTPClient = &http.Client{
            Transport: transport,
        }
    } else {
        HTTPClient = &http.Client{
            Timeout:   time.Duration(config.RelayTimeout) * time.Second,
            Transport: transport,
        }
    }

    ImpatientHTTPClient = &http.Client{
        Timeout:   5 * time.Second,
        Transport: transport,
    }
}
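The new `Init()` must be called explicitly (see the `client.Init()` call added to main.go below) instead of relying on the package-level `init()` of the old relay/client package, which is deleted at the end of this compare. A minimal sketch of how the three clients divide responsibilities once `Init()` has run; the request targets are placeholders, not calls made by this commit:

```go
package main

import (
    "net/http"

    "github.com/songquanpeng/one-api/common/client"
)

func main() {
    client.Init() // wires up proxies and timeouts from the config package

    // Upstream relay traffic: honors RELAY_PROXY and RELAY_TIMEOUT.
    resp, err := client.HTTPClient.Get("https://api.openai.com/v1/models")
    if err == nil {
        resp.Body.Close()
    }

    // Short-lived checks: hard-coded 5-second timeout.
    resp, err = client.ImpatientHTTPClient.Head("https://example.com/health")
    if err == nil {
        resp.Body.Close()
    }

    // User-supplied content: honors USER_CONTENT_REQUEST_PROXY and
    // USER_CONTENT_REQUEST_TIMEOUT.
    req, _ := http.NewRequest(http.MethodGet, "https://example.com/image.png", nil)
    resp, err = client.UserContentRequestHTTPClient.Do(req)
    if err == nil {
        resp.Body.Close()
    }
}
```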
@@ -117,10 +117,10 @@ var ValidThemes = map[string]bool{
 // All durations are in seconds
 // Shouldn't be larger than RateLimitKeyExpirationDuration
 var (
-    GlobalApiRateLimitNum            = env.Int("GLOBAL_API_RATE_LIMIT", 180)
+    GlobalApiRateLimitNum            = env.Int("GLOBAL_API_RATE_LIMIT", 240)
     GlobalApiRateLimitDuration int64 = 3 * 60

-    GlobalWebRateLimitNum            = env.Int("GLOBAL_WEB_RATE_LIMIT", 60)
+    GlobalWebRateLimitNum            = env.Int("GLOBAL_WEB_RATE_LIMIT", 120)
     GlobalWebRateLimitDuration int64 = 3 * 60

     UploadRateLimitNum = 10
@@ -144,3 +144,7 @@ var MetricFailChanSize = env.Int("METRIC_FAIL_CHAN_SIZE", 128)
 var InitialRootToken = os.Getenv("INITIAL_ROOT_TOKEN")

 var GeminiVersion = env.String("GEMINI_VERSION", "v1")
+
+var RelayProxy = env.String("RELAY_PROXY", "")
+var UserContentRequestProxy = env.String("USER_CONTENT_REQUEST_PROXY", "")
+var UserContentRequestTimeout = env.Int("USER_CONTENT_REQUEST_TIMEOUT", 30)
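For context, `env.Int` and `env.String` read an environment variable and fall back to a default; the real helpers live in the repository's common/env package and may differ in detail. A sketch of the assumed behavior:

```go
package env

import (
    "os"
    "strconv"
)

// Int returns the named environment variable parsed as an integer,
// or defaultValue when it is unset or malformed.
func Int(name string, defaultValue int) int {
    raw, ok := os.LookupEnv(name)
    if !ok {
        return defaultValue
    }
    v, err := strconv.Atoi(raw)
    if err != nil {
        return defaultValue
    }
    return v
}

// String returns the named environment variable, or defaultValue
// when it is unset.
func String(name string, defaultValue string) string {
    if raw, ok := os.LookupEnv(name); ok {
        return raw
    }
    return defaultValue
}
```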
@@ -3,6 +3,7 @@ package image
 import (
     "bytes"
     "encoding/base64"
+    "github.com/songquanpeng/one-api/common/client"
     "image"
     _ "image/gif"
     _ "image/jpeg"
@@ -19,7 +20,7 @@ import (
 var dataURLPattern = regexp.MustCompile(`data:image/([^;]+);base64,(.*)`)

 func IsImageUrl(url string) (bool, error) {
-    resp, err := http.Head(url)
+    resp, err := client.UserContentRequestHTTPClient.Head(url)
     if err != nil {
         return false, err
     }
@@ -34,7 +35,7 @@ func GetImageSizeFromUrl(url string) (width int, height int, err error) {
     if !isImage {
         return
     }
-    resp, err := http.Get(url)
+    resp, err := client.UserContentRequestHTTPClient.Get(url)
     if err != nil {
         return
     }
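After this change, user-supplied image URLs go through `UserContentRequestHTTPClient`, so `USER_CONTENT_REQUEST_PROXY` and `USER_CONTENT_REQUEST_TIMEOUT` apply to them while relay traffic stays on the main client. A small usage sketch; the URL is a placeholder and the import path of the image package is assumed:

```go
package main

import (
    "fmt"

    "github.com/songquanpeng/one-api/common/client"
    "github.com/songquanpeng/one-api/common/image"
)

func main() {
    client.Init() // must run before the image helpers fetch anything

    // Placeholder URL for illustration only.
    width, height, err := image.GetImageSizeFromUrl("https://example.com/cat.png")
    if err != nil {
        fmt.Println("fetch failed:", err)
        return
    }
    fmt.Printf("image is %dx%d\n", width, height)
}
```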
@@ -43,11 +43,19 @@ func SysLog(s string) {
     _, _ = fmt.Fprintf(gin.DefaultWriter, "[SYS] %v | %s \n", t.Format("2006/01/02 - 15:04:05"), s)
 }

+func SysLogf(format string, a ...any) {
+    SysLog(fmt.Sprintf(format, a...))
+}
+
 func SysError(s string) {
     t := time.Now()
     _, _ = fmt.Fprintf(gin.DefaultErrorWriter, "[SYS] %v | %s \n", t.Format("2006/01/02 - 15:04:05"), s)
 }

+func SysErrorf(format string, a ...any) {
+    SysError(fmt.Sprintf(format, a...))
+}
+
 func Debug(ctx context.Context, msg string) {
     if config.DebugEnabled {
         logHelper(ctx, loggerDEBUG, msg)
@@ -4,12 +4,12 @@ import (
     "encoding/json"
     "errors"
     "fmt"
+    "github.com/songquanpeng/one-api/common/client"
     "github.com/songquanpeng/one-api/common/config"
     "github.com/songquanpeng/one-api/common/logger"
     "github.com/songquanpeng/one-api/model"
     "github.com/songquanpeng/one-api/monitor"
     "github.com/songquanpeng/one-api/relay/channeltype"
-    "github.com/songquanpeng/one-api/relay/client"
     "io"
     "net/http"
     "strconv"
@@ -6,8 +6,6 @@ import (
     "github.com/songquanpeng/one-api/common"
     "github.com/songquanpeng/one-api/common/config"
     "github.com/songquanpeng/one-api/common/ctxkey"
-    "github.com/songquanpeng/one-api/common/helper"
     "github.com/songquanpeng/one-api/common/logger"
-    "github.com/songquanpeng/one-api/common/random"
     "github.com/songquanpeng/one-api/model"
     "net/http"
@@ -111,7 +109,6 @@ func Logout(c *gin.Context) {
 }

 func Register(c *gin.Context) {
-    ctx := c.Request.Context()
     if !config.RegisterEnabled {
         c.JSON(http.StatusOK, gin.H{
             "message": "管理员关闭了新用户注册",
@@ -176,28 +173,7 @@ func Register(c *gin.Context) {
         })
         return
     }
-    go func() {
-        err := user.ValidateAndFill()
-        if err != nil {
-            logger.Errorf(ctx, "user.ValidateAndFill failed: %w", err)
-            return
-        }
-        cleanToken := model.Token{
-            UserId:         user.Id,
-            Name:           "default",
-            Key:            random.GenerateKey(),
-            CreatedTime:    helper.GetTimestamp(),
-            AccessedTime:   helper.GetTimestamp(),
-            ExpiredTime:    -1,
-            RemainQuota:    -1,
-            UnlimitedQuota: true,
-        }
-        err = cleanToken.Insert()
-        if err != nil {
-            logger.Errorf(ctx, "cleanToken.Insert failed: %w", err)
-            return
-        }
-    }()
     c.JSON(http.StatusOK, gin.H{
         "success": true,
         "message": "",
main.go (4 lines changed)
@@ -7,6 +7,7 @@ import (
     "github.com/gin-contrib/sessions/cookie"
     "github.com/gin-gonic/gin"
     "github.com/songquanpeng/one-api/common"
+    "github.com/songquanpeng/one-api/common/client"
     "github.com/songquanpeng/one-api/common/config"
     "github.com/songquanpeng/one-api/common/logger"
     "github.com/songquanpeng/one-api/controller"
@@ -23,7 +24,7 @@ var buildFS embed.FS

 func main() {
     logger.SetupLogger()
-    logger.SysLog(fmt.Sprintf("One API %s started", common.Version))
+    logger.SysLogf("One API %s started", common.Version)
     if os.Getenv("GIN_MODE") != "debug" {
         gin.SetMode(gin.ReleaseMode)
     }
@@ -94,6 +95,7 @@ func main() {
         logger.SysLog("metric enabled, will disable channel if too much request failed")
     }
     openai.InitTokenEncoders()
+    client.Init()

     // Initialize HTTP server
     server := gin.New()
@@ -67,26 +67,28 @@ func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, mode
     c.Set(ctxkey.BaseURL, channel.GetBaseURL())
     cfg, _ := channel.LoadConfig()
     // this is for backward compatibility
-    switch channel.Type {
-    case channeltype.Azure:
-        if cfg.APIVersion == "" {
-            cfg.APIVersion = channel.Other
-        }
-    case channeltype.Xunfei:
-        if cfg.APIVersion == "" {
-            cfg.APIVersion = channel.Other
-        }
-    case channeltype.Gemini:
-        if cfg.APIVersion == "" {
-            cfg.APIVersion = channel.Other
-        }
-    case channeltype.AIProxyLibrary:
-        if cfg.LibraryID == "" {
-            cfg.LibraryID = channel.Other
-        }
-    case channeltype.Ali:
-        if cfg.Plugin == "" {
-            cfg.Plugin = channel.Other
-        }
-    }
+    if channel.Other != nil {
+        switch channel.Type {
+        case channeltype.Azure:
+            if cfg.APIVersion == "" {
+                cfg.APIVersion = *channel.Other
+            }
+        case channeltype.Xunfei:
+            if cfg.APIVersion == "" {
+                cfg.APIVersion = *channel.Other
+            }
+        case channeltype.Gemini:
+            if cfg.APIVersion == "" {
+                cfg.APIVersion = *channel.Other
+            }
+        case channeltype.AIProxyLibrary:
+            if cfg.LibraryID == "" {
+                cfg.LibraryID = *channel.Other
+            }
+        case channeltype.Ali:
+            if cfg.Plugin == "" {
+                cfg.Plugin = *channel.Other
+            }
+        }
+    }
     c.Set(ctxkey.Config, cfg)
@@ -27,7 +27,7 @@ type Channel struct {
     TestTime           int64   `json:"test_time" gorm:"bigint"`
     ResponseTime       int     `json:"response_time"` // in milliseconds
     BaseURL            *string `json:"base_url" gorm:"column:base_url;default:''"`
-    Other              string  `json:"other"`   // DEPRECATED: please save config to field Config
+    Other              *string `json:"other"`   // DEPRECATED: please save config to field Config
     Balance            float64 `json:"balance"` // in USD
     BalanceUpdatedTime int64   `json:"balance_updated_time" gorm:"bigint"`
     Models             string  `json:"models"`
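Turning `Other` into a `*string` lets callers tell "legacy field never set" (nil) apart from "set to an empty string", which is exactly what the new `channel.Other != nil` guard above relies on. An illustrative sketch, not code from the commit:

```go
package main

import "fmt"

// describe mirrors the backward-compatibility logic: a nil pointer
// means there is no legacy value to fall back to.
func describe(other *string) string {
    if other == nil {
        return "legacy field absent; skip the compatibility switch"
    }
    return fmt.Sprintf("legacy field %q; may fill cfg.APIVersion etc.", *other)
}

func main() {
    fmt.Println(describe(nil))
    v := "v2.1"
    fmt.Println(describe(&v))
}
```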
@@ -6,6 +6,7 @@ import (
     "github.com/songquanpeng/one-api/common"
     "github.com/songquanpeng/one-api/common/blacklist"
     "github.com/songquanpeng/one-api/common/config"
+    "github.com/songquanpeng/one-api/common/helper"
     "github.com/songquanpeng/one-api/common/logger"
     "github.com/songquanpeng/one-api/common/random"
     "gorm.io/gorm"
@@ -140,6 +141,22 @@ func (user *User) Insert(inviterId int) error {
             RecordLog(inviterId, LogTypeSystem, fmt.Sprintf("邀请用户赠送 %s", common.LogQuota(config.QuotaForInviter)))
         }
     }
+    // create default token
+    cleanToken := Token{
+        UserId:         user.Id,
+        Name:           "default",
+        Key:            random.GenerateKey(),
+        CreatedTime:    helper.GetTimestamp(),
+        AccessedTime:   helper.GetTimestamp(),
+        ExpiredTime:    -1,
+        RemainQuota:    -1,
+        UnlimitedQuota: true,
+    }
+    result.Error = cleanToken.Insert()
+    if result.Error != nil {
+        // do not block
+        logger.SysError(fmt.Sprintf("create default token for user %d failed: %s", user.Id, result.Error.Error()))
+    }
     return nil
 }
@@ -7,9 +7,9 @@ import (
     "fmt"
     "github.com/gin-gonic/gin"
     "github.com/songquanpeng/one-api/common"
+    "github.com/songquanpeng/one-api/common/client"
     "github.com/songquanpeng/one-api/common/logger"
     "github.com/songquanpeng/one-api/relay/adaptor/openai"
-    "github.com/songquanpeng/one-api/relay/client"
     "github.com/songquanpeng/one-api/relay/constant"
     "github.com/songquanpeng/one-api/relay/model"
     "io"
@@ -17,15 +17,21 @@
 )

 func ConvertRequest(textRequest model.GeneralOpenAIRequest) *Request {
-    lastMessage := textRequest.Messages[len(textRequest.Messages)-1]
-    return &Request{
-        MaxTokens:   textRequest.MaxTokens,
-        Prompt:      lastMessage.StringContent(),
-        Stream:      textRequest.Stream,
-        Temperature: textRequest.Temperature,
-    }
+    var promptBuilder strings.Builder
+    for _, message := range textRequest.Messages {
+        promptBuilder.WriteString(message.StringContent())
+        promptBuilder.WriteString("\n") // a newline separates the messages
+    }
+
+    return &Request{
+        MaxTokens:   textRequest.MaxTokens,
+        Prompt:      promptBuilder.String(),
+        Stream:      textRequest.Stream,
+        Temperature: textRequest.Temperature,
+    }
 }

 func ResponseCloudflare2OpenAI(cloudflareResponse *Response) *openai.TextResponse {
     choice := openai.TextResponseChoice{
         Index: 0,
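The old conversion silently dropped every message except the last; the new loop keeps the whole conversation, one message per line. A standalone sketch with plain strings standing in for `message.StringContent()`:

```go
package main

import (
    "fmt"
    "strings"
)

func main() {
    messages := []string{
        "You are a helpful assistant.",
        "What is the capital of France?",
    }
    var promptBuilder strings.Builder
    for _, m := range messages {
        promptBuilder.WriteString(m)
        promptBuilder.WriteString("\n")
    }
    // Before the fix only the second line would have survived.
    fmt.Print(promptBuilder.String())
}
```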
@@ -4,7 +4,7 @@ import (
     "errors"
     "fmt"
     "github.com/gin-gonic/gin"
-    "github.com/songquanpeng/one-api/relay/client"
+    "github.com/songquanpeng/one-api/common/client"
     "github.com/songquanpeng/one-api/relay/meta"
     "io"
     "net/http"
relay/adaptor/doubao/constants.go (new file, 13 lines)
@@ -0,0 +1,13 @@
package doubao

// https://console.volcengine.com/ark/region:ark+cn-beijing/model

var ModelList = []string{
    "Doubao-pro-128k",
    "Doubao-pro-32k",
    "Doubao-pro-4k",
    "Doubao-lite-128k",
    "Doubao-lite-32k",
    "Doubao-lite-4k",
    "Doubao-embedding",
}
relay/adaptor/doubao/main.go (new file, 14 lines)
@@ -0,0 +1,14 @@
package doubao

import (
    "fmt"
    "github.com/songquanpeng/one-api/relay/meta"
    "github.com/songquanpeng/one-api/relay/relaymode"
)

func GetRequestURL(meta *meta.Meta) (string, error) {
    if meta.Mode == relaymode.ChatCompletions {
        return fmt.Sprintf("%s/api/v3/chat/completions", meta.BaseURL), nil
    }
    return "", fmt.Errorf("unsupported relay mode %d for doubao", meta.Mode)
}
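Combined with the Doubao base URL registered in `ChannelBaseURLs` further down, this yields an Ark chat-completions endpoint. A usage sketch; the `meta.Meta` field values are illustrative:

```go
package main

import (
    "fmt"

    "github.com/songquanpeng/one-api/relay/adaptor/doubao"
    "github.com/songquanpeng/one-api/relay/meta"
    "github.com/songquanpeng/one-api/relay/relaymode"
)

func main() {
    m := &meta.Meta{
        Mode:    relaymode.ChatCompletions,
        BaseURL: "https://ark.cn-beijing.volces.com",
    }
    url, err := doubao.GetRequestURL(m)
    if err != nil {
        panic(err)
    }
    fmt.Println(url) // https://ark.cn-beijing.volces.com/api/v3/chat/completions
}
```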
@@ -13,6 +13,7 @@ import (
     "github.com/songquanpeng/one-api/relay/adaptor/openai"
     "github.com/songquanpeng/one-api/relay/meta"
     "github.com/songquanpeng/one-api/relay/model"
+    "github.com/songquanpeng/one-api/relay/relaymode"
 )

 type Adaptor struct {
@@ -24,7 +25,14 @@ func (a *Adaptor) Init(meta *meta.Meta) {

 func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) {
     version := helper.AssignOrDefault(meta.Config.APIVersion, config.GeminiVersion)
-    action := "generateContent"
+    action := ""
+    switch meta.Mode {
+    case relaymode.Embeddings:
+        action = "batchEmbedContents"
+    default:
+        action = "generateContent"
+    }
+
     if meta.IsStream {
         action = "streamGenerateContent?alt=sse"
     }
@@ -41,7 +49,14 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G
     if request == nil {
         return nil, errors.New("request is nil")
     }
-    return ConvertRequest(*request), nil
+    switch relayMode {
+    case relaymode.Embeddings:
+        geminiEmbeddingRequest := ConvertEmbeddingRequest(*request)
+        return geminiEmbeddingRequest, nil
+    default:
+        geminiRequest := ConvertRequest(*request)
+        return geminiRequest, nil
+    }
 }

 func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) {
@@ -61,7 +76,12 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Met
         err, responseText = StreamHandler(c, resp)
         usage = openai.ResponseText2Usage(responseText, meta.ActualModelName, meta.PromptTokens)
     } else {
-        err, usage = Handler(c, resp, meta.PromptTokens, meta.ActualModelName)
+        switch meta.Mode {
+        case relaymode.Embeddings:
+            err, usage = EmbeddingHandler(c, resp)
+        default:
+            err, usage = Handler(c, resp, meta.PromptTokens, meta.ActualModelName)
+        }
     }
     return
 }
@@ -4,5 +4,5 @@ package gemini

 var ModelList = []string{
     "gemini-pro", "gemini-1.0-pro-001", "gemini-1.5-pro",
-    "gemini-pro-vision", "gemini-1.0-pro-vision-001",
+    "gemini-pro-vision", "gemini-1.0-pro-vision-001", "embedding-001", "text-embedding-004",
 }
@@ -134,6 +134,29 @@ func ConvertRequest(textRequest model.GeneralOpenAIRequest) *ChatRequest {
     return &geminiRequest
 }

+func ConvertEmbeddingRequest(request model.GeneralOpenAIRequest) *BatchEmbeddingRequest {
+    inputs := request.ParseInput()
+    requests := make([]EmbeddingRequest, len(inputs))
+    model := fmt.Sprintf("models/%s", request.Model)
+
+    for i, input := range inputs {
+        requests[i] = EmbeddingRequest{
+            Model: model,
+            Content: ChatContent{
+                Parts: []Part{
+                    {
+                        Text: input,
+                    },
+                },
+            },
+        }
+    }
+
+    return &BatchEmbeddingRequest{
+        Requests: requests,
+    }
+}
+
 type ChatResponse struct {
     Candidates     []ChatCandidate    `json:"candidates"`
     PromptFeedback ChatPromptFeedback `json:"promptFeedback"`
@@ -230,6 +253,23 @@ func streamResponseGeminiChat2OpenAI(geminiResponse *ChatResponse) *openai.ChatC
     return &response
 }

+func embeddingResponseGemini2OpenAI(response *EmbeddingResponse) *openai.EmbeddingResponse {
+    openAIEmbeddingResponse := openai.EmbeddingResponse{
+        Object: "list",
+        Data:   make([]openai.EmbeddingResponseItem, 0, len(response.Embeddings)),
+        Model:  "gemini-embedding",
+        Usage:  model.Usage{TotalTokens: 0},
+    }
+    for _, item := range response.Embeddings {
+        openAIEmbeddingResponse.Data = append(openAIEmbeddingResponse.Data, openai.EmbeddingResponseItem{
+            Object:    `embedding`,
+            Index:     0,
+            Embedding: item.Values,
+        })
+    }
+    return &openAIEmbeddingResponse
+}
+
 func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, string) {
     responseText := ""
     scanner := bufio.NewScanner(resp.Body)
@@ -337,3 +377,39 @@ func Handler(c *gin.Context, resp *http.Response, promptTokens int, modelName st
     _, err = c.Writer.Write(jsonResponse)
     return nil, &usage
 }
+
+func EmbeddingHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
+    var geminiEmbeddingResponse EmbeddingResponse
+    responseBody, err := io.ReadAll(resp.Body)
+    if err != nil {
+        return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
+    }
+    err = resp.Body.Close()
+    if err != nil {
+        return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
+    }
+    err = json.Unmarshal(responseBody, &geminiEmbeddingResponse)
+    if err != nil {
+        return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
+    }
+    if geminiEmbeddingResponse.Error != nil {
+        return &model.ErrorWithStatusCode{
+            Error: model.Error{
+                Message: geminiEmbeddingResponse.Error.Message,
+                Type:    "gemini_error",
+                Param:   "",
+                Code:    geminiEmbeddingResponse.Error.Code,
+            },
+            StatusCode: resp.StatusCode,
+        }, nil
+    }
+    fullTextResponse := embeddingResponseGemini2OpenAI(&geminiEmbeddingResponse)
+    jsonResponse, err := json.Marshal(fullTextResponse)
+    if err != nil {
+        return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
+    }
+    c.Writer.Header().Set("Content-Type", "application/json")
+    c.Writer.WriteHeader(resp.StatusCode)
+    _, err = c.Writer.Write(jsonResponse)
+    return nil, &fullTextResponse.Usage
+}
@@ -7,6 +7,33 @@ type ChatRequest struct {
     Tools []ChatTools `json:"tools,omitempty"`
 }

+type EmbeddingRequest struct {
+    Model                string      `json:"model"`
+    Content              ChatContent `json:"content"`
+    TaskType             string      `json:"taskType,omitempty"`
+    Title                string      `json:"title,omitempty"`
+    OutputDimensionality int         `json:"outputDimensionality,omitempty"`
+}
+
+type BatchEmbeddingRequest struct {
+    Requests []EmbeddingRequest `json:"requests"`
+}
+
+type EmbeddingData struct {
+    Values []float64 `json:"values"`
+}
+
+type EmbeddingResponse struct {
+    Embeddings []EmbeddingData `json:"embeddings"`
+    Error      *Error          `json:"error,omitempty"`
+}
+
+type Error struct {
+    Code    int    `json:"code,omitempty"`
+    Message string `json:"message,omitempty"`
+    Status  string `json:"status,omitempty"`
+}
+
 type InlineData struct {
     MimeType string `json:"mimeType"`
     Data     string `json:"data"`
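Marshalling a `BatchEmbeddingRequest` built from these types produces the JSON body of Gemini's `batchEmbedContents` call. A self-contained sketch with trimmed copies of the structs inlined so it runs on its own:

```go
package main

import (
    "encoding/json"
    "fmt"
)

// Trimmed stand-ins for the types in the diff.
type Part struct {
    Text string `json:"text,omitempty"`
}

type ChatContent struct {
    Parts []Part `json:"parts"`
}

type EmbeddingRequest struct {
    Model   string      `json:"model"`
    Content ChatContent `json:"content"`
}

type BatchEmbeddingRequest struct {
    Requests []EmbeddingRequest `json:"requests"`
}

func main() {
    req := BatchEmbeddingRequest{
        Requests: []EmbeddingRequest{{
            Model:   "models/text-embedding-004",
            Content: ChatContent{Parts: []Part{{Text: "hello world"}}},
        }},
    }
    out, _ := json.Marshal(req)
    fmt.Println(string(out))
    // {"requests":[{"model":"models/text-embedding-004","content":{"parts":[{"text":"hello world"}]}}]}
}
```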
@@ -5,6 +5,7 @@ import (
     "fmt"
     "github.com/gin-gonic/gin"
     "github.com/songquanpeng/one-api/relay/adaptor"
+    "github.com/songquanpeng/one-api/relay/adaptor/doubao"
     "github.com/songquanpeng/one-api/relay/adaptor/minimax"
     "github.com/songquanpeng/one-api/relay/channeltype"
     "github.com/songquanpeng/one-api/relay/meta"
@@ -45,6 +46,8 @@ func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) {
         return GetFullRequestURL(meta.BaseURL, requestURL, meta.ChannelType), nil
     case channeltype.Minimax:
         return minimax.GetRequestURL(meta)
+    case channeltype.Doubao:
+        return doubao.GetRequestURL(meta)
     default:
         return GetFullRequestURL(meta.BaseURL, meta.RequestURLPath, meta.ChannelType), nil
     }
@@ -4,6 +4,7 @@ import (
     "github.com/songquanpeng/one-api/relay/adaptor/ai360"
     "github.com/songquanpeng/one-api/relay/adaptor/baichuan"
     "github.com/songquanpeng/one-api/relay/adaptor/deepseek"
+    "github.com/songquanpeng/one-api/relay/adaptor/doubao"
    "github.com/songquanpeng/one-api/relay/adaptor/groq"
     "github.com/songquanpeng/one-api/relay/adaptor/lingyiwanwu"
     "github.com/songquanpeng/one-api/relay/adaptor/minimax"
@@ -20,6 +21,7 @@ var CompatibleChannels = []int{
     channeltype.Moonshot,
     channeltype.Baichuan,
     channeltype.Minimax,
+    channeltype.Doubao,
     channeltype.Mistral,
     channeltype.Groq,
     channeltype.LingYiWanWu,
@@ -52,6 +54,8 @@ func GetCompatibleChannelMeta(channelType int) (string, []string) {
         return "deepseek", deepseek.ModelList
     case channeltype.TogetherAI:
         return "together.ai", togetherai.ModelList
+    case channeltype.Doubao:
+        return "doubao", doubao.ModelList
     default:
         return "openai", ModelList
     }
@@ -27,14 +27,6 @@ func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) {

 func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error {
     adaptor.SetupCommonRequestHeader(c, req, meta)
-    version := parseAPIVersionByModelName(meta.ActualModelName)
-    if version == "" {
-        version = a.meta.Config.APIVersion
-    }
-    if version == "" {
-        version = "v1.1"
-    }
-    a.meta.Config.APIVersion = version
     // check DoResponse for auth part
     return nil
 }
@@ -69,6 +61,14 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Met
     if a.request == nil {
         return nil, openai.ErrorWrapper(errors.New("request is nil"), "request_is_nil", http.StatusBadRequest)
     }
+    version := parseAPIVersionByModelName(meta.ActualModelName)
+    if version == "" {
+        version = a.meta.Config.APIVersion
+    }
+    if version == "" {
+        version = "v1.1"
+    }
+    a.meta.Config.APIVersion = version
     if meta.IsStream {
         err, usage = StreamHandler(c, meta, *a.request, splits[0], splits[1], splits[2])
     } else {
@@ -5,7 +5,14 @@ import (
     "crypto/sha256"
     "encoding/base64"
     "encoding/json"
+    "errors"
     "fmt"
+    "io"
+    "net/http"
+    "net/url"
+    "strings"
+    "time"
+
     "github.com/gin-gonic/gin"
     "github.com/gorilla/websocket"
     "github.com/songquanpeng/one-api/common"
@@ -16,11 +23,6 @@ import (
     "github.com/songquanpeng/one-api/relay/constant"
     "github.com/songquanpeng/one-api/relay/meta"
     "github.com/songquanpeng/one-api/relay/model"
-    "io"
-    "net/http"
-    "net/url"
-    "strings"
-    "time"
 )

 // https://console.xfyun.cn/services/cbm
@@ -28,11 +30,7 @@ import (

 func requestOpenAI2Xunfei(request model.GeneralOpenAIRequest, xunfeiAppId string, domain string) *ChatRequest {
     messages := make([]Message, 0, len(request.Messages))
-    var lastToolCalls []model.Tool
     for _, message := range request.Messages {
-        if message.ToolCalls != nil {
-            lastToolCalls = message.ToolCalls
-        }
         messages = append(messages, Message{
             Role:    message.Role,
             Content: message.StringContent(),
@@ -45,9 +43,10 @@ func requestOpenAI2Xunfei(request model.GeneralOpenAIRequest, xunfeiAppId string
     xunfeiRequest.Parameter.Chat.TopK = request.N
     xunfeiRequest.Parameter.Chat.MaxTokens = request.MaxTokens
     xunfeiRequest.Payload.Message.Text = messages
-    if len(lastToolCalls) != 0 {
-        for _, toolCall := range lastToolCalls {
-            xunfeiRequest.Payload.Functions.Text = append(xunfeiRequest.Payload.Functions.Text, toolCall.Function)
-        }
-    }
+
+    if strings.HasPrefix(domain, "generalv3") {
+        xunfeiRequest.Payload.Functions = &Functions{
+            Text: request.Tools,
+        }
+    }
@@ -203,7 +202,7 @@ func Handler(c *gin.Context, meta *meta.Meta, textRequest model.GeneralOpenAIReq
         }
     }
     if len(xunfeiResponse.Payload.Choices.Text) == 0 {
-        return openai.ErrorWrapper(err, "xunfei_empty_response_detected", http.StatusInternalServerError), nil
+        return openai.ErrorWrapper(errors.New("xunfei empty response detected"), "xunfei_empty_response_detected", http.StatusInternalServerError), nil
     }
     xunfeiResponse.Payload.Choices.Text[0].Content = content
@@ -9,6 +9,10 @@ type Message struct {
     Content string `json:"content"`
 }

+type Functions struct {
+    Text []model.Tool `json:"text,omitempty"`
+}
+
 type ChatRequest struct {
     Header struct {
         AppId string `json:"app_id"`
@@ -26,9 +30,7 @@ type ChatRequest struct {
         Message struct {
             Text []Message `json:"text"`
         } `json:"message"`
-        Functions struct {
-            Text []model.Function `json:"text,omitempty"`
-        } `json:"functions,omitempty"`
+        Functions *Functions `json:"functions,omitempty"`
     } `json:"payload"`
 }
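Swapping the anonymous struct for a `*Functions` pointer changes the serialized payload: `encoding/json` never treats a non-pointer struct value as empty, so the old shape always emitted a `"functions"` key, while a nil pointer is omitted entirely, which is presumably why the request builder above only sets it for `generalv3*` domains. A simplified, self-contained sketch of the difference:

```go
package main

import (
    "encoding/json"
    "fmt"
)

type Functions struct {
    Text []string `json:"text,omitempty"`
}

type payloadOld struct {
    Functions struct {
        Text []string `json:"text,omitempty"`
    } `json:"functions,omitempty"` // omitempty never fires for a struct value
}

type payloadNew struct {
    Functions *Functions `json:"functions,omitempty"` // nil pointer is omitted
}

func main() {
    oldJSON, _ := json.Marshal(payloadOld{})
    newJSON, _ := json.Marshal(payloadNew{})
    fmt.Println(string(oldJSON)) // {"functions":{}}
    fmt.Println(string(newJSON)) // {}
}
```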
@@ -41,6 +41,6 @@ const (
     Cloudflare
     DeepL
     TogetherAI
-
+    Doubao
     Dummy
 )
@@ -41,6 +41,7 @@ var ChannelBaseURLs = []string{
     "https://api.cloudflare.com",        // 37
     "https://api-free.deepl.com",        // 38
     "https://api.together.xyz",          // 39
+    "https://ark.cn-beijing.volces.com", // 40
 }

 func init() {
@@ -1,24 +0,0 @@
package client

import (
    "github.com/songquanpeng/one-api/common/config"
    "net/http"
    "time"
)

var HTTPClient *http.Client
var ImpatientHTTPClient *http.Client

func init() {
    if config.RelayTimeout == 0 {
        HTTPClient = &http.Client{}
    } else {
        HTTPClient = &http.Client{
            Timeout: time.Duration(config.RelayTimeout) * time.Second,
        }
    }

    ImpatientHTTPClient = &http.Client{
        Timeout: 5 * time.Second,
    }
}
@@ -9,6 +9,7 @@ import (
     "fmt"
     "github.com/gin-gonic/gin"
     "github.com/songquanpeng/one-api/common"
+    "github.com/songquanpeng/one-api/common/client"
     "github.com/songquanpeng/one-api/common/config"
     "github.com/songquanpeng/one-api/common/ctxkey"
     "github.com/songquanpeng/one-api/common/logger"
@@ -17,7 +18,6 @@ import (
     "github.com/songquanpeng/one-api/relay/billing"
     billingratio "github.com/songquanpeng/one-api/relay/billing/ratio"
     "github.com/songquanpeng/one-api/relay/channeltype"
-    "github.com/songquanpeng/one-api/relay/client"
     "github.com/songquanpeng/one-api/relay/meta"
     relaymodel "github.com/songquanpeng/one-api/relay/model"
     "github.com/songquanpeng/one-api/relay/relaymode"
@@ -47,6 +47,12 @@ export const CHANNEL_OPTIONS = {
     value: 28,
     color: 'warning'
   },
+  40: {
+    key: 40,
+    text: '字节跳动豆包',
+    value: 40,
+    color: 'primary'
+  },
   15: {
     key: 15,
     text: '百度文心千帆',
@@ -163,7 +163,7 @@ const EditModal = ({ open, channelId, onCancel, onOk }) => {
       values.other = 'v2.1';
     }
     if (values.key === '') {
-      if (values.config.ak !== '' && values.config.sk !== '' && values.config.region !== '') {
+      if (values.config.ak && values.config.sk && values.config.region) {
         values.key = `${values.config.ak}|${values.config.sk}|${values.config.region}`;
       }
     }
@@ -6,6 +6,7 @@ export const CHANNEL_OPTIONS = [
   {key: 11, text: 'Google PaLM2', value: 11, color: 'orange'},
   {key: 24, text: 'Google Gemini', value: 24, color: 'orange'},
   {key: 28, text: 'Mistral AI', value: 28, color: 'orange'},
+  {key: 40, text: '字节跳动豆包', value: 40, color: 'blue'},
   {key: 15, text: '百度文心千帆', value: 15, color: 'blue'},
   {key: 17, text: '阿里通义千问', value: 17, color: 'orange'},
   {key: 18, text: '讯飞星火认知', value: 18, color: 'blue'},
@@ -181,9 +181,6 @@ const EditChannel = () => {
     if (localInputs.type === 3 && localInputs.other === '') {
       localInputs.other = '2024-03-01-preview';
     }
-    if (localInputs.type === 18 && localInputs.other === '') {
-      localInputs.other = 'v2.1';
-    }
     let res;
     localInputs.models = localInputs.models.join(',');
     localInputs.group = localInputs.groups.join(',');
@@ -362,6 +359,13 @@ const EditChannel = () => {
           </Message>
         )
       }
+      {
+        inputs.type === 40 && (
+          <Message>
+            对于豆包而言,需要手动去 <a target="_blank" href="https://console.volcengine.com/ark/region:ark+cn-beijing/endpoint">模型推理页面</a> 创建推理接入点,以接入点名称作为模型名称,例如:`ep-20240608051426-tkxvl`。
+          </Message>
+        )
+      }
       <Form.Field>
         <Form.Dropdown
           label='模型'