Mirror of https://github.com/songquanpeng/one-api.git (synced 2025-11-04 15:53:42 +08:00)
Compare commits: v0.6.9-alp ... v0.6.10-al (33 commits)
| SHA1 |
|---|
| 49ffb1c60d |
| 2f16649896 |
| af3aa57bd6 |
| e9f117ff72 |
| 6bb5247bd6 |
| 305ce14fe3 |
| 36c8f4f15c |
| 45b51ea0ee |
| 7c8628bd95 |
| 6ab87f8a08 |
| 833fa7ad6f |
| 6eb0770a89 |
| 92cd46d64f |
| 2b2dc2c733 |
| a3d7df7f89 |
| c368232f50 |
| cbfc983dc3 |
| 8ec092ba44 |
| b0b88a79ff |
| 7e51b04221 |
| f75a17f8eb |
| 6f13a3bb3c |
| f092eed1db |
| 629378691b |
| 3716e1b0e6 |
| a4d6e7a886 |
| cb772e5d06 |
| e32cb0b844 |
| fdd7bf41c0 |
| 29389ed44f |
| 88acc5a614 |
| a21681096a |
| 32f90a79a8 |
.github/workflows/ci.yml (8 changes, vendored)

@@ -1,13 +1,13 @@
 name: CI
 
 # This setup assumes that you run the unit tests with code coverage in the same
-# workflow that will also print the coverage report as comment to the pull request. 
+# workflow that will also print the coverage report as comment to the pull request.
 # Therefore, you need to trigger this workflow when a pull request is (re)opened or
 # when new code is pushed to the branch of the pull request. In addition, you also
-# need to trigger this workflow when new code is pushed to the main branch because 
+# need to trigger this workflow when new code is pushed to the main branch because
 # we need to upload the code coverage results as artifact for the main branch as
 # well since it will be the baseline code coverage.
-# 
+#
 # We do not want to trigger the workflow for pushes to *any* branch because this
 # would trigger our jobs twice on pull requests (once from "push" event and once
 # from "pull_request->synchronize")
@@ -31,7 +31,7 @@ jobs:
         with:
           go-version: ^1.22
 
-      # When you execute your unit tests, make sure to use the "-coverprofile" flag to write a 
+      # When you execute your unit tests, make sure to use the "-coverprofile" flag to write a
       # coverage profile to a file. You will need the name of the file (e.g. "coverage.txt")
       # in the next step as well as the next job.
       - name: Test
.gitignore (3 changes, vendored)

@@ -9,4 +9,5 @@ logs
 data
 /web/node_modules
 cmd.md
-.env
+.env
+/one-api
README.md

@@ -90,6 +90,7 @@ _✨ 通过标准的 OpenAI API 格式访问所有的大模型,开箱即用 
    + [x] [together.ai](https://www.together.ai/)
    + [x] [novita.ai](https://www.novita.ai/)
    + [x] [硅基流动 SiliconCloud](https://siliconflow.cn/siliconcloud)
+   + [x] [xAI](https://x.ai/)
 2. 支持配置镜像以及众多[第三方代理服务](https://iamazing.cn/page/openai-api-third-party-services)。
 3. 支持通过**负载均衡**的方式访问多个渠道。
 4. 支持 **stream 模式**,可以通过流式传输实现打字机效果。
@@ -114,8 +115,8 @@ _✨ 通过标准的 OpenAI API 格式访问所有的大模型,开箱即用 
 21. 支持 Cloudflare Turnstile 用户校验。
 22. 支持用户管理,支持**多种用户登录注册方式**:
     + 邮箱登录注册(支持注册邮箱白名单)以及通过邮箱进行密码重置。
-    + 支持使用飞书进行授权登录。
-    + [GitHub 开放授权](https://github.com/settings/applications/new)。
+    + 支持[飞书授权登录](https://open.feishu.cn/document/uAjLw4CM/ukTMukTMukTM/reference/authen-v1/authorize/get)([这里有 One API 的实现细节阐述供参考](https://iamazing.cn/page/feishu-oauth-login))。
+    + 支持 [GitHub 授权登录](https://github.com/settings/applications/new)。
     + 微信公众号授权(需要额外部署 [WeChat Server](https://github.com/songquanpeng/wechat-server))。
 23. 支持主题切换,设置环境变量 `THEME` 即可,默认为 `default`,欢迎 PR 更多主题,具体参考[此处](./web/README.md)。
 24. 配合 [Message Pusher](https://github.com/songquanpeng/message-pusher) 可将报警信息推送到多种 App 上。
@@ -399,6 +400,7 @@ graph LR
 26. `METRIC_SUCCESS_RATE_THRESHOLD`:请求成功率阈值,默认为 `0.8`。
 27. `INITIAL_ROOT_TOKEN`:如果设置了该值,则在系统首次启动时会自动创建一个值为该环境变量值的 root 用户令牌。
 28. `INITIAL_ROOT_ACCESS_TOKEN`:如果设置了该值,则在系统首次启动时会自动创建一个值为该环境变量的 root 用户创建系统管理令牌。
+29. `ENFORCE_INCLUDE_USAGE`:是否强制在 stream 模型下返回 usage,默认不开启,可选值为 `true` 和 `false`。
 
 ### 命令行参数
 1. `--port <port_number>`: 指定服务器监听的端口号,默认为 `3000`。
common/config/config.go

@@ -160,3 +160,5 @@ var OnlyOneLogFile = env.Bool("ONLY_ONE_LOG_FILE", false)
 var RelayProxy = env.String("RELAY_PROXY", "")
 var UserContentRequestProxy = env.String("USER_CONTENT_REQUEST_PROXY", "")
 var UserContentRequestTimeout = env.Int("USER_CONTENT_REQUEST_TIMEOUT", 30)
+
+var EnforceIncludeUsage = env.Bool("ENFORCE_INCLUDE_USAGE", false)
common/ctxkey/key.go

@@ -20,4 +20,5 @@ const (
 	BaseURL           = "base_url"
 	AvailableModels   = "available_models"
 	KeyRequestBody    = "key_request_body"
+	SystemPrompt      = "system_prompt"
 )
common/gin.go

@@ -31,15 +31,15 @@ func UnmarshalBodyReusable(c *gin.Context, v any) error {
 	contentType := c.Request.Header.Get("Content-Type")
 	if strings.HasPrefix(contentType, "application/json") {
 		err = json.Unmarshal(requestBody, &v)
-		c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
 	} else {
-		// skip for now
+		// TODO: someday non json request have variant model, we will need to implementation this
+		c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
 		err = c.ShouldBind(&v)
 	}
 	if err != nil {
 		return err
 	}
 	// Reset request body
 	c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
 	return nil
 }
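The relocation above is easier to follow with the underlying constraint in mind: an http.Request.Body is a one-shot stream, so every consumer must put a fresh reader back before the next one runs. A minimal standalone sketch of the pattern (plain net/http, no gin, all names local to the example):

```go
package main

import (
	"bytes"
	"fmt"
	"io"
	"net/http"
	"strings"
)

func main() {
	req, _ := http.NewRequest("POST", "http://example.local", strings.NewReader(`{"model":"m"}`))

	// First read drains the stream; a second ReadAll would return nothing.
	body, _ := io.ReadAll(req.Body)
	fmt.Println(len(body)) // 13

	// Rewind by installing a fresh reader over the buffered bytes,
	// which is what the repeated io.NopCloser(bytes.NewBuffer(...)) calls do.
	req.Body = io.NopCloser(bytes.NewBuffer(body))

	again, _ := io.ReadAll(req.Body)
	fmt.Println(string(again)) // {"model":"m"}
}
```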
common/helper/helper.go

@@ -137,3 +137,23 @@ func String2Int(str string) int {
 	}
 	return num
 }
+
+func Float64PtrMax(p *float64, maxValue float64) *float64 {
+	if p == nil {
+		return nil
+	}
+	if *p > maxValue {
+		return &maxValue
+	}
+	return p
+}
+
+func Float64PtrMin(p *float64, minValue float64) *float64 {
+	if p == nil {
+		return nil
+	}
+	if *p < minValue {
+		return &minValue
+	}
+	return p
+}
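A quick sketch of how the new helpers behave at the two boundary cases; Float64PtrMax is copied verbatim from the hunk above, and the same call shape appears later in the ali adaptor as helper.Float64PtrMax(request.TopP, 0.9999):

```go
package main

import "fmt"

func Float64PtrMax(p *float64, maxValue float64) *float64 {
	if p == nil {
		return nil // unset stays unset instead of being forced to a default
	}
	if *p > maxValue {
		return &maxValue
	}
	return p
}

func main() {
	var unset *float64
	one := 1.0
	fmt.Println(Float64PtrMax(unset, 0.9999)) // <nil>: the caller never sent the field
	fmt.Println(*Float64PtrMax(&one, 0.9999)) // 0.9999: out-of-range value clamped
}
```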
common/render/render.go

@@ -3,9 +3,10 @@ package render
 import (
 	"encoding/json"
 	"fmt"
+	"strings"
+
 	"github.com/gin-gonic/gin"
 	"github.com/songquanpeng/one-api/common"
-	"strings"
 )
 
 func StringData(c *gin.Context, str string) {
controller/auth/lark.go

@@ -40,7 +40,7 @@ func getLarkUserInfoByCode(code string) (*LarkUser, error) {
 	if err != nil {
 		return nil, err
 	}
-	req, err := http.NewRequest("POST", "https://passport.feishu.cn/suite/passport/oauth/token", bytes.NewBuffer(jsonData))
+	req, err := http.NewRequest("POST", "https://open.feishu.cn/open-apis/authen/v2/oauth/token", bytes.NewBuffer(jsonData))
 	if err != nil {
 		return nil, err
 	}
controller/channel-billing.go

@@ -81,6 +81,26 @@ type APGC2DGPTUsageResponse struct {
 	TotalUsed      float64 `json:"total_used"`
 }
 
+type SiliconFlowUsageResponse struct {
+	Code    int    `json:"code"`
+	Message string `json:"message"`
+	Status  bool   `json:"status"`
+	Data    struct {
+		ID            string `json:"id"`
+		Name          string `json:"name"`
+		Image         string `json:"image"`
+		Email         string `json:"email"`
+		IsAdmin       bool   `json:"isAdmin"`
+		Balance       string `json:"balance"`
+		Status        string `json:"status"`
+		Introduction  string `json:"introduction"`
+		Role          string `json:"role"`
+		ChargeBalance string `json:"chargeBalance"`
+		TotalBalance  string `json:"totalBalance"`
+		Category      string `json:"category"`
+	} `json:"data"`
+}
+
 // GetAuthHeader get auth header
 func GetAuthHeader(token string) http.Header {
 	h := http.Header{}
@@ -203,6 +223,28 @@ func updateChannelAIGC2DBalance(channel *model.Channel) (float64, error) {
 	return response.TotalAvailable, nil
 }
 
+func updateChannelSiliconFlowBalance(channel *model.Channel) (float64, error) {
+	url := "https://api.siliconflow.cn/v1/user/info"
+	body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key))
+	if err != nil {
+		return 0, err
+	}
+	response := SiliconFlowUsageResponse{}
+	err = json.Unmarshal(body, &response)
+	if err != nil {
+		return 0, err
+	}
+	if response.Code != 20000 {
+		return 0, fmt.Errorf("code: %d, message: %s", response.Code, response.Message)
+	}
+	balance, err := strconv.ParseFloat(response.Data.Balance, 64)
+	if err != nil {
+		return 0, err
+	}
+	channel.UpdateBalance(balance)
+	return balance, nil
+}
+
 func updateChannelBalance(channel *model.Channel) (float64, error) {
 	baseURL := channeltype.ChannelBaseURLs[channel.Type]
 	if channel.GetBaseURL() == "" {
@@ -227,6 +269,8 @@ func updateChannelBalance(channel *model.Channel) (float64, error) {
 		return updateChannelAPI2GPTBalance(channel)
 	case channeltype.AIGC2D:
 		return updateChannelAIGC2DBalance(channel)
+	case channeltype.SiliconFlow:
+		return updateChannelSiliconFlowBalance(channel)
 	default:
 		return 0, errors.New("尚未实现")
 	}
controller/channel-test.go

@@ -76,9 +76,9 @@ func testChannel(channel *model.Channel, request *relaymodel.GeneralOpenAIReques
 		if len(modelNames) > 0 {
 			modelName = modelNames[0]
 		}
+		if modelMap != nil && modelMap[modelName] != "" {
+			modelName = modelMap[modelName]
+		}
 	}
-	if modelMap != nil && modelMap[modelName] != "" {
-		modelName = modelMap[modelName]
-	}
 	meta.OriginModelName, meta.ActualModelName = request.Model, modelName
 	request.Model = modelName
controller/relay.go

@@ -60,7 +60,7 @@ func Relay(c *gin.Context) {
 	channelName := c.GetString(ctxkey.ChannelName)
 	group := c.GetString(ctxkey.Group)
 	originalModel := c.GetString(ctxkey.OriginalModel)
-	go processChannelRelayError(ctx, userId, channelId, channelName, bizErr)
+	go processChannelRelayError(ctx, userId, channelId, channelName, *bizErr)
 	requestId := c.GetString(helper.RequestIdKey)
 	retryTimes := config.RetryTimes
 	if !shouldRetry(c, bizErr.StatusCode) {
@@ -87,8 +87,7 @@ func Relay(c *gin.Context) {
 		channelId := c.GetInt(ctxkey.ChannelId)
 		lastFailedChannelId = channelId
 		channelName := c.GetString(ctxkey.ChannelName)
-		// BUG: bizErr is in race condition
-		go processChannelRelayError(ctx, userId, channelId, channelName, bizErr)
+		go processChannelRelayError(ctx, userId, channelId, channelName, *bizErr)
 	}
 	if bizErr != nil {
 		if bizErr.StatusCode == http.StatusTooManyRequests {
@@ -122,7 +121,7 @@ func shouldRetry(c *gin.Context, statusCode int) bool {
 	return true
 }
 
-func processChannelRelayError(ctx context.Context, userId int, channelId int, channelName string, err *model.ErrorWithStatusCode) {
+func processChannelRelayError(ctx context.Context, userId int, channelId int, channelName string, err model.ErrorWithStatusCode) {
 	logger.Errorf(ctx, "relay error (channel id %d, user id: %d): %s", channelId, userId, err.Message)
 	// https://platform.openai.com/docs/guides/error-codes/api-errors
 	if monitor.ShouldDisableChannel(&err.Error, err.StatusCode) {
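The switch from *bizErr to a value parameter removes the race flagged by the deleted BUG comment: the retry loop keeps overwriting the error while previously launched goroutines still read it through the shared pointer. An illustrative sketch of the fixed pattern, with simplified stand-in types:

```go
package main

import "sync"

type errorWithStatusCode struct{ StatusCode int }

func main() {
	var wg sync.WaitGroup
	bizErr := &errorWithStatusCode{}
	for i := 0; i < 3; i++ {
		bizErr.StatusCode = 500 + i // the loop mutates the shared value between retries
		wg.Add(1)
		// Passing *bizErr copies the struct at launch time, so each goroutine
		// reports the error it was started with; passing the pointer would let
		// later iterations change what an earlier goroutine reads.
		go func(e errorWithStatusCode) {
			defer wg.Done()
			_ = e.StatusCode
		}(*bizErr)
	}
	wg.Wait()
}
```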
go.mod (8 changes)

@@ -25,7 +25,7 @@ require (
 	github.com/pkoukk/tiktoken-go v0.1.7
 	github.com/smartystreets/goconvey v1.8.1
 	github.com/stretchr/testify v1.9.0
-	golang.org/x/crypto v0.24.0
+	golang.org/x/crypto v0.31.0
 	golang.org/x/image v0.18.0
 	google.golang.org/api v0.187.0
 	gorm.io/driver/mysql v1.5.6
@@ -99,9 +99,9 @@ require (
 	golang.org/x/arch v0.8.0 // indirect
 	golang.org/x/net v0.26.0 // indirect
 	golang.org/x/oauth2 v0.21.0 // indirect
-	golang.org/x/sync v0.7.0 // indirect
-	golang.org/x/sys v0.21.0 // indirect
-	golang.org/x/text v0.16.0 // indirect
+	golang.org/x/sync v0.10.0 // indirect
+	golang.org/x/sys v0.28.0 // indirect
+	golang.org/x/text v0.21.0 // indirect
 	golang.org/x/time v0.5.0 // indirect
 	google.golang.org/genproto/googleapis/api v0.0.0-20240617180043-68d350f18fd4 // indirect
 	google.golang.org/genproto/googleapis/rpc v0.0.0-20240624140628-dc46fd24d27d // indirect
go.sum (16 changes)

@@ -222,8 +222,8 @@ golang.org/x/arch v0.8.0 h1:3wRIsP3pM4yUptoR96otTUOXI367OS0+c9eeRi9doIc=
 golang.org/x/arch v0.8.0/go.mod h1:FEVrYAQjsQXMVJ1nsMoVVXPZg6p2JE2mx8psSWTDQys=
 golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
 golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
-golang.org/x/crypto v0.24.0 h1:mnl8DM0o513X8fdIkmyFE/5hTYxbwYOjDS/+rK6qpRI=
-golang.org/x/crypto v0.24.0/go.mod h1:Z1PMYSOR5nyMcyAVAIQSKCDwalqy85Aqn1x3Ws4L5DM=
+golang.org/x/crypto v0.31.0 h1:ihbySMvVjLAeSH1IbfcRTkD/iNscyz8rGzjF/E5hV6U=
+golang.org/x/crypto v0.31.0/go.mod h1:kDsLvtWBEx7MV9tJOj9bnXsPbxwJQ6csT/x4KIN4Ssk=
 golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
 golang.org/x/image v0.18.0 h1:jGzIakQa/ZXI1I0Fxvaa9W7yP25TqT6cHIHn+6CqvSQ=
 golang.org/x/image v0.18.0/go.mod h1:4yyo5vMFQjVjUcVk4jEQcU9MGy/rulF5WvUILseCM2E=
@@ -244,20 +244,20 @@ golang.org/x/oauth2 v0.21.0/go.mod h1:XYTD2NtWslqkgxebSiOHnXEap4TF09sJSc7H1sXbht
 golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
 golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
 golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
-golang.org/x/sync v0.7.0 h1:YsImfSBoP9QPYL0xyKJPq0gcaJdG3rInoqxTWbfQu9M=
-golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
+golang.org/x/sync v0.10.0 h1:3NQrjDixjgGwUOCaF8w2+VYHv0Ve/vGYSbdkTa98gmQ=
+golang.org/x/sync v0.10.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
 golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
 golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
 golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
 golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
 golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
 golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
-golang.org/x/sys v0.21.0 h1:rF+pYz3DAGSQAxAu1CbC7catZg4ebC4UIeIhKxBZvws=
-golang.org/x/sys v0.21.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
+golang.org/x/sys v0.28.0 h1:Fksou7UEQUWlKvIdsqzJmUmCX3cZuD2+P3XyyzwMhlA=
+golang.org/x/sys v0.28.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
 golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
 golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
-golang.org/x/text v0.16.0 h1:a94ExnEXNtEwYLGJSIUxnWoxoRz/ZcCsV63ROupILh4=
-golang.org/x/text v0.16.0/go.mod h1:GhwF1Be+LQoKShO3cGOHzqOgRrGaYc9AvblQOmPVHnI=
+golang.org/x/text v0.21.0 h1:zyQAAkrwaneQ066sspRyJaG9VNi/YJ1NfzcGB3hZ/qo=
+golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ=
 golang.org/x/time v0.5.0 h1:o7cqy6amK/52YcAKIPlM3a+Fpj35zvRj2TP+e1xFSfk=
 golang.org/x/time v0.5.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM=
 golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
middleware/distributor.go

@@ -12,7 +12,7 @@ import (
 )
 
 type ModelRequest struct {
-	Model string `json:"model"`
+	Model string `json:"model" form:"model"`
 }
 
 func Distribute() func(c *gin.Context) {
@@ -61,6 +61,9 @@ func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, mode
 	c.Set(ctxkey.Channel, channel.Type)
 	c.Set(ctxkey.ChannelId, channel.Id)
 	c.Set(ctxkey.ChannelName, channel.Name)
+	if channel.SystemPrompt != nil && *channel.SystemPrompt != "" {
+		c.Set(ctxkey.SystemPrompt, *channel.SystemPrompt)
+	}
 	c.Set(ctxkey.ModelMapping, channel.GetModelMapping())
 	c.Set(ctxkey.OriginalModel, modelName) // for retry
 	c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key))
middleware/gzip.go (27 lines, new file)

@@ -0,0 +1,27 @@
+package middleware
+
+import (
+	"compress/gzip"
+	"github.com/gin-gonic/gin"
+	"io"
+	"net/http"
+)
+
+func GzipDecodeMiddleware() gin.HandlerFunc {
+	return func(c *gin.Context) {
+		if c.GetHeader("Content-Encoding") == "gzip" {
+			gzipReader, err := gzip.NewReader(c.Request.Body)
+			if err != nil {
+				c.AbortWithStatus(http.StatusBadRequest)
+				return
+			}
+			defer gzipReader.Close()
+
+			// Replace the request body with the decompressed data
+			c.Request.Body = io.NopCloser(gzipReader)
+		}
+
+		// Continue processing the request
+		c.Next()
+	}
+}
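A hedged usage sketch for the new middleware; the engine setup and route below are illustrative, not taken from the repository's router code:

```go
package main

import (
	"github.com/gin-gonic/gin"
	"github.com/songquanpeng/one-api/middleware"
)

func main() {
	router := gin.New()
	// Transparently inflate gzip-encoded request bodies before any handler,
	// including UnmarshalBodyReusable, reads them.
	router.Use(middleware.GzipDecodeMiddleware())
	router.POST("/v1/chat/completions", func(c *gin.Context) { /* relay handler */ })
	_ = router.Run(":3000")
}
```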
model/channel.go

@@ -37,6 +37,7 @@ type Channel struct {
 	ModelMapping       *string `json:"model_mapping" gorm:"type:varchar(1024);default:''"`
 	Priority           *int64  `json:"priority" gorm:"bigint;default:0"`
 	Config             string  `json:"config"`
+	SystemPrompt       *string `json:"system_prompt" gorm:"type:text"`
 }
 
 type ChannelConfig struct {
model/token.go

@@ -30,7 +30,7 @@ type Token struct {
 	RemainQuota    int64   `json:"remain_quota" gorm:"bigint;default:0"`
 	UnlimitedQuota bool    `json:"unlimited_quota" gorm:"default:false"`
 	UsedQuota      int64   `json:"used_quota" gorm:"bigint;default:0"` // used quota
-	Models         *string `json:"models" gorm:"default:''"`           // allowed models
+	Models         *string `json:"models" gorm:"type:text"`            // allowed models
 	Subnet         *string `json:"subnet" gorm:"default:''"`           // allowed subnet
 }
@@ -121,30 +121,40 @@ func GetTokenById(id int) (*Token, error) {
 	return &token, err
 }
 
-func (token *Token) Insert() error {
+func (t *Token) Insert() error {
 	var err error
-	err = DB.Create(token).Error
+	err = DB.Create(t).Error
 	return err
 }
 
 // Update Make sure your token's fields is completed, because this will update non-zero values
-func (token *Token) Update() error {
+func (t *Token) Update() error {
 	var err error
-	err = DB.Model(token).Select("name", "status", "expired_time", "remain_quota", "unlimited_quota", "models", "subnet").Updates(token).Error
+	err = DB.Model(t).Select("name", "status", "expired_time", "remain_quota", "unlimited_quota", "models", "subnet").Updates(t).Error
 	return err
 }
 
-func (token *Token) SelectUpdate() error {
+func (t *Token) SelectUpdate() error {
 	// This can update zero values
-	return DB.Model(token).Select("accessed_time", "status").Updates(token).Error
+	return DB.Model(t).Select("accessed_time", "status").Updates(t).Error
 }
 
-func (token *Token) Delete() error {
+func (t *Token) Delete() error {
 	var err error
-	err = DB.Delete(token).Error
+	err = DB.Delete(t).Error
 	return err
 }
 
+func (t *Token) GetModels() string {
+	if t == nil {
+		return ""
+	}
+	if t.Models == nil {
+		return ""
+	}
+	return *t.Models
+}
+
 func DeleteTokenById(id int, userId int) (err error) {
 	// Why we need userId here? In case user want to delete other's token.
 	if id == 0 || userId == 0 {
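GetModels is written to be safe on a nil receiver, which is why callers can use it without first checking whether the token was loaded. A minimal sketch of that Go property (local stand-in type, illustrative model names):

```go
package main

import "fmt"

type Token struct{ Models *string }

func (t *Token) GetModels() string {
	// A pointer-receiver method still runs when t is nil; it only panics
	// if it dereferences t, so guard before touching fields.
	if t == nil || t.Models == nil {
		return ""
	}
	return *t.Models
}

func main() {
	var missing *Token
	fmt.Printf("%q\n", missing.GetModels()) // "" with no panic

	m := "gpt-4o,gpt-4o-mini"
	fmt.Println((&Token{Models: &m}).GetModels()) // gpt-4o,gpt-4o-mini
}
```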
monitor/manage.go

@@ -34,7 +34,7 @@ func ShouldDisableChannel(err *model.Error, statusCode int) bool {
 		strings.Contains(lowerMessage, "credit") ||
 		strings.Contains(lowerMessage, "balance") ||
 		strings.Contains(lowerMessage, "permission denied") ||
-  	strings.Contains(lowerMessage, "organization has been restricted") || // groq
+		strings.Contains(lowerMessage, "organization has been restricted") || // groq
 		strings.Contains(lowerMessage, "已欠费") {
 		return true
 	}
relay/adaptor.go

@@ -16,6 +16,7 @@ import (
 	"github.com/songquanpeng/one-api/relay/adaptor/openai"
 	"github.com/songquanpeng/one-api/relay/adaptor/palm"
 	"github.com/songquanpeng/one-api/relay/adaptor/proxy"
+	"github.com/songquanpeng/one-api/relay/adaptor/replicate"
 	"github.com/songquanpeng/one-api/relay/adaptor/tencent"
 	"github.com/songquanpeng/one-api/relay/adaptor/vertexai"
 	"github.com/songquanpeng/one-api/relay/adaptor/xunfei"
@@ -61,6 +62,8 @@ func GetAdaptor(apiType int) adaptor.Adaptor {
 		return &vertexai.Adaptor{}
 	case apitype.Proxy:
 		return &proxy.Adaptor{}
+	case apitype.Replicate:
+		return &replicate.Adaptor{}
 	}
 	return nil
 }
relay/adaptor/ali/constants.go

@@ -1,7 +1,23 @@
 package ali
 
 var ModelList = []string{
-	"qwen-turbo", "qwen-plus", "qwen-max", "qwen-max-longcontext",
-	"text-embedding-v1",
+	"qwen-turbo", "qwen-turbo-latest",
+	"qwen-plus", "qwen-plus-latest",
+	"qwen-max", "qwen-max-latest",
+	"qwen-max-longcontext",
+	"qwen-vl-max", "qwen-vl-max-latest", "qwen-vl-plus", "qwen-vl-plus-latest",
+	"qwen-vl-ocr", "qwen-vl-ocr-latest",
+	"qwen-audio-turbo",
+	"qwen-math-plus", "qwen-math-plus-latest", "qwen-math-turbo", "qwen-math-turbo-latest",
+	"qwen-coder-plus", "qwen-coder-plus-latest", "qwen-coder-turbo", "qwen-coder-turbo-latest",
+	"qwq-32b-preview", "qwen2.5-72b-instruct", "qwen2.5-32b-instruct", "qwen2.5-14b-instruct", "qwen2.5-7b-instruct", "qwen2.5-3b-instruct", "qwen2.5-1.5b-instruct", "qwen2.5-0.5b-instruct",
+	"qwen2-72b-instruct", "qwen2-57b-a14b-instruct", "qwen2-7b-instruct", "qwen2-1.5b-instruct", "qwen2-0.5b-instruct",
+	"qwen1.5-110b-chat", "qwen1.5-72b-chat", "qwen1.5-32b-chat", "qwen1.5-14b-chat", "qwen1.5-7b-chat", "qwen1.5-1.8b-chat", "qwen1.5-0.5b-chat",
+	"qwen-72b-chat", "qwen-14b-chat", "qwen-7b-chat", "qwen-1.8b-chat", "qwen-1.8b-longcontext-chat",
+	"qwen2-vl-7b-instruct", "qwen2-vl-2b-instruct", "qwen-vl-v1", "qwen-vl-chat-v1",
+	"qwen2-audio-instruct", "qwen-audio-chat",
+	"qwen2.5-math-72b-instruct", "qwen2.5-math-7b-instruct", "qwen2.5-math-1.5b-instruct", "qwen2-math-72b-instruct", "qwen2-math-7b-instruct", "qwen2-math-1.5b-instruct",
+	"qwen2.5-coder-32b-instruct", "qwen2.5-coder-14b-instruct", "qwen2.5-coder-7b-instruct", "qwen2.5-coder-3b-instruct", "qwen2.5-coder-1.5b-instruct", "qwen2.5-coder-0.5b-instruct",
+	"text-embedding-v1", "text-embedding-v3", "text-embedding-v2", "text-embedding-async-v2", "text-embedding-async-v1",
 	"ali-stable-diffusion-xl", "ali-stable-diffusion-v1.5", "wanx-v1",
 }
relay/adaptor/ali/main.go

@@ -36,9 +36,7 @@ func ConvertRequest(request model.GeneralOpenAIRequest) *ChatRequest {
 		enableSearch = true
 		aliModel = strings.TrimSuffix(aliModel, EnableSearchModelSuffix)
 	}
-	if request.TopP >= 1 {
-		request.TopP = 0.9999
-	}
+	request.TopP = helper.Float64PtrMax(request.TopP, 0.9999)
 	return &ChatRequest{
 		Model: aliModel,
 		Input: Input{
relay/adaptor/ali/model.go

@@ -16,13 +16,13 @@ type Input struct {
 }
 
 type Parameters struct {
-	TopP              float64      `json:"top_p,omitempty"`
+	TopP              *float64     `json:"top_p,omitempty"`
 	TopK              int          `json:"top_k,omitempty"`
 	Seed              uint64       `json:"seed,omitempty"`
 	EnableSearch      bool         `json:"enable_search,omitempty"`
 	IncrementalOutput bool         `json:"incremental_output,omitempty"`
 	MaxTokens         int          `json:"max_tokens,omitempty"`
-	Temperature       float64      `json:"temperature,omitempty"`
+	Temperature       *float64     `json:"temperature,omitempty"`
 	ResultFormat      string       `json:"result_format,omitempty"`
 	Tools             []model.Tool `json:"tools,omitempty"`
 }
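The same float64-to-*float64 change repeats across the anthropic, aws, baidu, cloudflare, cohere, gemini, ollama, and palm structs below; the motivation is how encoding/json treats omitempty. A self-contained sketch:

```go
package main

import (
	"encoding/json"
	"fmt"
)

type byValue struct {
	Temperature float64 `json:"temperature,omitempty"`
}

type byPointer struct {
	Temperature *float64 `json:"temperature,omitempty"`
}

func main() {
	zero := 0.0

	v, _ := json.Marshal(byValue{Temperature: 0})
	fmt.Println(string(v)) // {} : an explicit 0 is dropped as the zero value

	p, _ := json.Marshal(byPointer{Temperature: &zero})
	fmt.Println(string(p)) // {"temperature":0} : an explicit 0 survives

	q, _ := json.Marshal(byPointer{})
	fmt.Println(string(q)) // {} : a nil pointer still means "not set"
}
```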
relay/adaptor/anthropic/constants.go

@@ -3,7 +3,11 @@ package anthropic
 var ModelList = []string{
 	"claude-instant-1.2", "claude-2.0", "claude-2.1",
 	"claude-3-haiku-20240307",
+	"claude-3-5-haiku-20241022",
 	"claude-3-sonnet-20240229",
 	"claude-3-opus-20240229",
 	"claude-3-5-sonnet-20240620",
+	"claude-3-5-sonnet-20241022",
+	"claude-3-5-sonnet-latest",
+	"claude-3-5-haiku-20241022",
 }
relay/adaptor/anthropic/model.go

@@ -48,8 +48,8 @@ type Request struct {
 	MaxTokens     int       `json:"max_tokens,omitempty"`
 	StopSequences []string  `json:"stop_sequences,omitempty"`
 	Stream        bool      `json:"stream,omitempty"`
-	Temperature   float64   `json:"temperature,omitempty"`
-	TopP          float64   `json:"top_p,omitempty"`
+	Temperature   *float64  `json:"temperature,omitempty"`
+	TopP          *float64  `json:"top_p,omitempty"`
 	TopK          int       `json:"top_k,omitempty"`
 	Tools         []Tool    `json:"tools,omitempty"`
 	ToolChoice    any       `json:"tool_choice,omitempty"`
relay/adaptor/aws/claude/main.go

@@ -29,10 +29,13 @@ var AwsModelIDMap = map[string]string{
 	"claude-instant-1.2":         "anthropic.claude-instant-v1",
 	"claude-2.0":                 "anthropic.claude-v2",
 	"claude-2.1":                 "anthropic.claude-v2:1",
-	"claude-3-sonnet-20240229":   "anthropic.claude-3-sonnet-20240229-v1:0",
-	"claude-3-5-sonnet-20240620": "anthropic.claude-3-5-sonnet-20240620-v1:0",
-	"claude-3-opus-20240229":     "anthropic.claude-3-opus-20240229-v1:0",
 	"claude-3-haiku-20240307":    "anthropic.claude-3-haiku-20240307-v1:0",
+	"claude-3-sonnet-20240229":   "anthropic.claude-3-sonnet-20240229-v1:0",
+	"claude-3-opus-20240229":     "anthropic.claude-3-opus-20240229-v1:0",
+	"claude-3-5-sonnet-20240620": "anthropic.claude-3-5-sonnet-20240620-v1:0",
+	"claude-3-5-sonnet-20241022": "anthropic.claude-3-5-sonnet-20241022-v2:0",
+	"claude-3-5-sonnet-latest":   "anthropic.claude-3-5-sonnet-20241022-v2:0",
+	"claude-3-5-haiku-20241022":  "anthropic.claude-3-5-haiku-20241022-v1:0",
 }
 
 func awsModelID(requestModel string) (string, error) {
relay/adaptor/aws/claude/model.go

@@ -11,8 +11,8 @@ type Request struct {
 	Messages         []anthropic.Message `json:"messages"`
 	System           string              `json:"system,omitempty"`
 	MaxTokens        int                 `json:"max_tokens,omitempty"`
-	Temperature      float64             `json:"temperature,omitempty"`
-	TopP             float64             `json:"top_p,omitempty"`
+	Temperature      *float64            `json:"temperature,omitempty"`
+	TopP             *float64            `json:"top_p,omitempty"`
 	TopK             int                 `json:"top_k,omitempty"`
 	StopSequences    []string            `json:"stop_sequences,omitempty"`
 	Tools            []anthropic.Tool    `json:"tools,omitempty"`
relay/adaptor/aws/llama3/model.go

@@ -4,10 +4,10 @@ package aws
 //
 // https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-meta.html
 type Request struct {
-	Prompt      string  `json:"prompt"`
-	MaxGenLen   int     `json:"max_gen_len,omitempty"`
-	Temperature float64 `json:"temperature,omitempty"`
-	TopP        float64 `json:"top_p,omitempty"`
+	Prompt      string   `json:"prompt"`
+	MaxGenLen   int      `json:"max_gen_len,omitempty"`
+	Temperature *float64 `json:"temperature,omitempty"`
+	TopP        *float64 `json:"top_p,omitempty"`
 }
 
 // Response is the response from AWS Llama3
relay/adaptor/baidu/model.go

@@ -35,9 +35,9 @@ type Message struct {
 
 type ChatRequest struct {
 	Messages        []Message `json:"messages"`
-	Temperature     float64   `json:"temperature,omitempty"`
-	TopP            float64   `json:"top_p,omitempty"`
-	PenaltyScore    float64   `json:"penalty_score,omitempty"`
+	Temperature     *float64  `json:"temperature,omitempty"`
+	TopP            *float64  `json:"top_p,omitempty"`
+	PenaltyScore    *float64  `json:"penalty_score,omitempty"`
 	Stream          bool      `json:"stream,omitempty"`
 	System          string    `json:"system,omitempty"`
 	DisableSearch   bool      `json:"disable_search,omitempty"`
relay/adaptor/cloudflare/model.go

@@ -9,5 +9,5 @@ type Request struct {
 	Prompt      string          `json:"prompt,omitempty"`
 	Raw         bool            `json:"raw,omitempty"`
 	Stream      bool            `json:"stream,omitempty"`
-	Temperature float64         `json:"temperature,omitempty"`
+	Temperature *float64        `json:"temperature,omitempty"`
 }
relay/adaptor/cohere/main.go

@@ -43,7 +43,7 @@ func ConvertRequest(textRequest model.GeneralOpenAIRequest) *Request {
 		K:                textRequest.TopK,
 		Stream:           textRequest.Stream,
 		FrequencyPenalty: textRequest.FrequencyPenalty,
-		PresencePenalty:  textRequest.FrequencyPenalty,
+		PresencePenalty:  textRequest.PresencePenalty,
 		Seed:             int(textRequest.Seed),
 	}
 	if cohereRequest.Model == "" {
relay/adaptor/cohere/model.go

@@ -10,15 +10,15 @@ type Request struct {
 	PromptTruncation string        `json:"prompt_truncation,omitempty"` // 默认值为"AUTO"
 	Connectors       []Connector   `json:"connectors,omitempty"`
 	Documents        []Document    `json:"documents,omitempty"`
-	Temperature      float64       `json:"temperature,omitempty"` // 默认值为0.3
+	Temperature      *float64      `json:"temperature,omitempty"` // 默认值为0.3
 	MaxTokens        int           `json:"max_tokens,omitempty"`
 	MaxInputTokens   int           `json:"max_input_tokens,omitempty"`
 	K                int           `json:"k,omitempty"` // 默认值为0
-	P                float64       `json:"p,omitempty"` // 默认值为0.75
+	P                *float64      `json:"p,omitempty"` // 默认值为0.75
 	Seed             int           `json:"seed,omitempty"`
 	StopSequences    []string      `json:"stop_sequences,omitempty"`
-	FrequencyPenalty float64       `json:"frequency_penalty,omitempty"` // 默认值为0.0
-	PresencePenalty  float64       `json:"presence_penalty,omitempty"`  // 默认值为0.0
+	FrequencyPenalty *float64      `json:"frequency_penalty,omitempty"` // 默认值为0.0
+	PresencePenalty  *float64      `json:"presence_penalty,omitempty"`  // 默认值为0.0
 	Tools            []Tool        `json:"tools,omitempty"`
 	ToolResults      []ToolResult  `json:"tool_results,omitempty"`
 }
relay/adaptor/gemini/adaptor.go

@@ -24,7 +24,12 @@ func (a *Adaptor) Init(meta *meta.Meta) {
 }
 
 func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) {
-	version := helper.AssignOrDefault(meta.Config.APIVersion, config.GeminiVersion)
+	defaultVersion := config.GeminiVersion
+	if meta.ActualModelName == "gemini-2.0-flash-exp" {
+		defaultVersion = "v1beta"
+	}
+
+	version := helper.AssignOrDefault(meta.Config.APIVersion, defaultVersion)
 	action := ""
 	switch meta.Mode {
 	case relaymode.Embeddings:
@@ -36,6 +41,7 @@ func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) {
 	if meta.IsStream {
 		action = "streamGenerateContent?alt=sse"
 	}
+
 	return fmt.Sprintf("%s/%s/models/%s:%s", meta.BaseURL, version, meta.ActualModelName, action), nil
 }
 
relay/adaptor/gemini/constants.go

@@ -3,5 +3,9 @@ package gemini
 // https://ai.google.dev/models/gemini
 
 var ModelList = []string{
-	"gemini-pro", "gemini-1.0-pro", "gemini-1.5-flash", "gemini-1.5-pro", "text-embedding-004", "aqa",
+	"gemini-pro", "gemini-1.0-pro",
+	"gemini-1.5-flash", "gemini-1.5-pro",
+	"text-embedding-004", "aqa",
+	"gemini-2.0-flash-exp",
+	"gemini-2.0-flash-thinking-exp",
 }
relay/adaptor/gemini/main.go

@@ -4,11 +4,12 @@ import (
 	"bufio"
 	"encoding/json"
 	"fmt"
-	"github.com/songquanpeng/one-api/common/render"
 	"io"
 	"net/http"
 	"strings"
 
+	"github.com/songquanpeng/one-api/common/render"
+
 	"github.com/songquanpeng/one-api/common"
 	"github.com/songquanpeng/one-api/common/config"
 	"github.com/songquanpeng/one-api/common/helper"
@@ -28,6 +29,11 @@ const (
 	VisionMaxImageNum = 16
 )
 
+var mimeTypeMap = map[string]string{
+	"json_object": "application/json",
+	"text":        "text/plain",
+}
+
 // Setting safety to the lowest possible values since Gemini is already powerless enough
 func ConvertRequest(textRequest model.GeneralOpenAIRequest) *ChatRequest {
 	geminiRequest := ChatRequest{
@@ -49,6 +55,10 @@ func ConvertRequest(textRequest model.GeneralOpenAIRequest) *ChatRequest {
 				Category:  "HARM_CATEGORY_DANGEROUS_CONTENT",
 				Threshold: config.GeminiSafetySetting,
 			},
+			{
+				Category:  "HARM_CATEGORY_CIVIC_INTEGRITY",
+				Threshold: config.GeminiSafetySetting,
+			},
 		},
 		GenerationConfig: ChatGenerationConfig{
 			Temperature:     textRequest.Temperature,
@@ -56,6 +66,15 @@ func ConvertRequest(textRequest model.GeneralOpenAIRequest) *ChatRequest {
 			MaxOutputTokens: textRequest.MaxTokens,
 		},
 	}
+	if textRequest.ResponseFormat != nil {
+		if mimeType, ok := mimeTypeMap[textRequest.ResponseFormat.Type]; ok {
+			geminiRequest.GenerationConfig.ResponseMimeType = mimeType
+		}
+		if textRequest.ResponseFormat.JsonSchema != nil {
+			geminiRequest.GenerationConfig.ResponseSchema = textRequest.ResponseFormat.JsonSchema.Schema
+			geminiRequest.GenerationConfig.ResponseMimeType = mimeTypeMap["json_object"]
+		}
+	}
 	if textRequest.Tools != nil {
 		functions := make([]model.Function, 0, len(textRequest.Tools))
 		for _, tool := range textRequest.Tools {
@@ -232,7 +251,14 @@ func responseGeminiChat2OpenAI(response *ChatResponse) *openai.TextResponse {
 			if candidate.Content.Parts[0].FunctionCall != nil {
 				choice.Message.ToolCalls = getToolCalls(&candidate)
 			} else {
-				choice.Message.Content = candidate.Content.Parts[0].Text
+				var builder strings.Builder
+				for _, part := range candidate.Content.Parts {
+					if i > 0 {
+						builder.WriteString("\n")
+					}
+					builder.WriteString(part.Text)
+				}
+				choice.Message.Content = builder.String()
 			}
 		} else {
 			choice.Message.Content = ""
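A sketch of the wire-level effect of the new ResponseFormat branch: an OpenAI-style response_format type is translated into Gemini's generationConfig.responseMimeType. The map is copied from the hunk; the request literal around it is illustrative:

```go
package main

import (
	"encoding/json"
	"fmt"
)

var mimeTypeMap = map[string]string{
	"json_object": "application/json",
	"text":        "text/plain",
}

func main() {
	openAIFormatType := "json_object" // what a client sends in response_format.type

	generationConfig := map[string]any{}
	if mimeType, ok := mimeTypeMap[openAIFormatType]; ok {
		generationConfig["responseMimeType"] = mimeType
	}

	out, _ := json.Marshal(map[string]any{"generationConfig": generationConfig})
	fmt.Println(string(out)) // {"generationConfig":{"responseMimeType":"application/json"}}
}
```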
relay/adaptor/gemini/model.go

@@ -65,10 +65,12 @@ type ChatTools struct {
 }
 
 type ChatGenerationConfig struct {
-	Temperature     float64  `json:"temperature,omitempty"`
-	TopP            float64  `json:"topP,omitempty"`
-	TopK            float64  `json:"topK,omitempty"`
-	MaxOutputTokens int      `json:"maxOutputTokens,omitempty"`
-	CandidateCount  int      `json:"candidateCount,omitempty"`
-	StopSequences   []string `json:"stopSequences,omitempty"`
+	ResponseMimeType string   `json:"responseMimeType,omitempty"`
+	ResponseSchema   any      `json:"responseSchema,omitempty"`
+	Temperature      *float64 `json:"temperature,omitempty"`
+	TopP             *float64 `json:"topP,omitempty"`
+	TopK             float64  `json:"topK,omitempty"`
+	MaxOutputTokens  int      `json:"maxOutputTokens,omitempty"`
+	CandidateCount   int      `json:"candidateCount,omitempty"`
+	StopSequences    []string `json:"stopSequences,omitempty"`
 }
relay/adaptor/groq/constants.go

@@ -4,14 +4,24 @@ package groq
 
 var ModelList = []string{
 	"gemma-7b-it",
-	"mixtral-8x7b-32768",
-	"llama3-8b-8192",
-	"llama3-70b-8192",
 	"gemma2-9b-it",
 	"llama-3.1-405b-reasoning",
 	"llama-3.1-70b-versatile",
 	"llama-3.1-8b-instant",
 	"llama-3.2-11b-text-preview",
 	"llama-3.2-11b-vision-preview",
+	"llama-3.2-1b-preview",
+	"llama-3.2-3b-preview",
+	"llama-3.2-11b-vision-preview",
 	"llama-3.2-90b-text-preview",
+	"llama-3.2-90b-vision-preview",
+	"llama-guard-3-8b",
+	"llama3-70b-8192",
+	"llama3-8b-8192",
+	"llama3-groq-70b-8192-tool-use-preview",
+	"llama3-groq-8b-8192-tool-use-preview",
+	"llava-v1.5-7b-4096-preview",
+	"mixtral-8x7b-32768",
+	"distil-whisper-large-v3-en",
+	"whisper-large-v3",
+	"whisper-large-v3-turbo",
 }
relay/adaptor/ollama/main.go

@@ -31,8 +31,8 @@ func ConvertRequest(request model.GeneralOpenAIRequest) *ChatRequest {
 			TopP:             request.TopP,
 			FrequencyPenalty: request.FrequencyPenalty,
 			PresencePenalty:  request.PresencePenalty,
-			NumPredict:  	  request.MaxTokens,
-			NumCtx:  	  request.NumCtx,
+			NumPredict:       request.MaxTokens,
+			NumCtx:           request.NumCtx,
 		},
 		Stream: request.Stream,
 	}
@@ -122,7 +122,7 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC
 	for scanner.Scan() {
 		data := scanner.Text()
 		if strings.HasPrefix(data, "}") {
-		    data = strings.TrimPrefix(data, "}") + "}"
+			data = strings.TrimPrefix(data, "}") + "}"
 		}
 
 		var ollamaResponse ChatResponse
relay/adaptor/ollama/model.go

@@ -1,14 +1,14 @@
 package ollama
 
 type Options struct {
-	Seed             int     `json:"seed,omitempty"`
-	Temperature      float64 `json:"temperature,omitempty"`
-	TopK             int     `json:"top_k,omitempty"`
-	TopP             float64 `json:"top_p,omitempty"`
-	FrequencyPenalty float64 `json:"frequency_penalty,omitempty"`
-	PresencePenalty  float64 `json:"presence_penalty,omitempty"`
-	NumPredict  	 int 	 `json:"num_predict,omitempty"`
-	NumCtx  	 int 	 `json:"num_ctx,omitempty"`
+	Seed             int      `json:"seed,omitempty"`
+	Temperature      *float64 `json:"temperature,omitempty"`
+	TopK             int      `json:"top_k,omitempty"`
+	TopP             *float64 `json:"top_p,omitempty"`
+	FrequencyPenalty *float64 `json:"frequency_penalty,omitempty"`
+	PresencePenalty  *float64 `json:"presence_penalty,omitempty"`
+	NumPredict       int      `json:"num_predict,omitempty"`
+	NumCtx           int      `json:"num_ctx,omitempty"`
 }
 
 type Message struct {
relay/adaptor/openai/adaptor.go

@@ -75,6 +75,13 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G
 	if request == nil {
 		return nil, errors.New("request is nil")
 	}
+	if request.Stream {
+		// always return usage in stream mode
+		if request.StreamOptions == nil {
+			request.StreamOptions = &model.StreamOptions{}
+		}
+		request.StreamOptions.IncludeUsage = true
+	}
 	return request, nil
 }
 
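The new ConvertRequest branch guarantees that every streaming request asks the upstream for usage accounting. A sketch of what ends up on the wire, with simplified stand-ins for the relay/model structs:

```go
package main

import (
	"encoding/json"
	"fmt"
)

type StreamOptions struct {
	IncludeUsage bool `json:"include_usage,omitempty"`
}

type Request struct {
	Stream        bool           `json:"stream"`
	StreamOptions *StreamOptions `json:"stream_options,omitempty"`
}

func main() {
	req := Request{Stream: true}
	if req.Stream {
		// Mirror of the diff: create stream_options when absent, then force the flag.
		if req.StreamOptions == nil {
			req.StreamOptions = &StreamOptions{}
		}
		req.StreamOptions.IncludeUsage = true
	}
	out, _ := json.Marshal(req)
	fmt.Println(string(out)) // {"stream":true,"stream_options":{"include_usage":true}}
}
```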
relay/adaptor/openai/compatible.go

@@ -11,9 +11,10 @@ import (
 	"github.com/songquanpeng/one-api/relay/adaptor/mistral"
 	"github.com/songquanpeng/one-api/relay/adaptor/moonshot"
 	"github.com/songquanpeng/one-api/relay/adaptor/novita"
+	"github.com/songquanpeng/one-api/relay/adaptor/siliconflow"
 	"github.com/songquanpeng/one-api/relay/adaptor/stepfun"
 	"github.com/songquanpeng/one-api/relay/adaptor/togetherai"
-	"github.com/songquanpeng/one-api/relay/adaptor/siliconflow"
+	"github.com/songquanpeng/one-api/relay/adaptor/xai"
 	"github.com/songquanpeng/one-api/relay/channeltype"
 )
 
@@ -32,6 +33,7 @@ var CompatibleChannels = []int{
 	channeltype.TogetherAI,
 	channeltype.Novita,
 	channeltype.SiliconFlow,
+	channeltype.XAI,
 }
 
 func GetCompatibleChannelMeta(channelType int) (string, []string) {
@@ -64,6 +66,8 @@ func GetCompatibleChannelMeta(channelType int) (string, []string) {
 		return "novita", novita.ModelList
 	case channeltype.SiliconFlow:
 		return "siliconflow", siliconflow.ModelList
+	case channeltype.XAI:
+		return "xai", xai.ModelList
 	default:
 		return "openai", ModelList
 	}
relay/adaptor/openai/constants.go

@@ -20,4 +20,7 @@ var ModelList = []string{
 	"dall-e-2", "dall-e-3",
 	"whisper-1",
 	"tts-1", "tts-1-1106", "tts-1-hd", "tts-1-hd-1106",
+	"o1", "o1-2024-12-17",
+	"o1-preview", "o1-preview-2024-09-12",
+	"o1-mini", "o1-mini-2024-09-12",
 }
relay/adaptor/openai/helper.go

@@ -2,15 +2,16 @@ package openai
 
 import (
 	"fmt"
+	"strings"
+
 	"github.com/songquanpeng/one-api/relay/channeltype"
 	"github.com/songquanpeng/one-api/relay/model"
-	"strings"
 )
 
-func ResponseText2Usage(responseText string, modeName string, promptTokens int) *model.Usage {
+func ResponseText2Usage(responseText string, modelName string, promptTokens int) *model.Usage {
 	usage := &model.Usage{}
 	usage.PromptTokens = promptTokens
-	usage.CompletionTokens = CountTokenText(responseText, modeName)
+	usage.CompletionTokens = CountTokenText(responseText, modelName)
 	usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
 	return usage
 }
relay/adaptor/openai/main.go

@@ -55,8 +55,8 @@ func StreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*model.E
 				render.StringData(c, data) // if error happened, pass the data to client
 				continue                   // just ignore the error
 			}
-			if len(streamResponse.Choices) == 0 {
-				// but for empty choice, we should not pass it to client, this is for azure
+			if len(streamResponse.Choices) == 0 && streamResponse.Usage == nil {
+				// but for empty choice and no usage, we should not pass it to client, this is for azure
 				continue // just ignore empty choice
 			}
 			render.StringData(c, data)
relay/adaptor/openai/util.go

@@ -1,8 +1,16 @@
 package openai
 
-import "github.com/songquanpeng/one-api/relay/model"
+import (
+	"context"
+	"fmt"
+
+	"github.com/songquanpeng/one-api/common/logger"
+	"github.com/songquanpeng/one-api/relay/model"
+)
 
 func ErrorWrapper(err error, code string, statusCode int) *model.ErrorWithStatusCode {
+	logger.Error(context.TODO(), fmt.Sprintf("[%s]%+v", code, err))
+
 	Error := model.Error{
 		Message: err.Error(),
 		Type:    "one_api_error",
relay/adaptor/palm/model.go

@@ -19,11 +19,11 @@ type Prompt struct {
 }
 
 type ChatRequest struct {
-	Prompt         Prompt  `json:"prompt"`
-	Temperature    float64 `json:"temperature,omitempty"`
-	CandidateCount int     `json:"candidateCount,omitempty"`
-	TopP           float64 `json:"topP,omitempty"`
-	TopK           int     `json:"topK,omitempty"`
+	Prompt         Prompt   `json:"prompt"`
+	Temperature    *float64 `json:"temperature,omitempty"`
+	CandidateCount int      `json:"candidateCount,omitempty"`
+	TopP           *float64 `json:"topP,omitempty"`
+	TopK           int      `json:"topK,omitempty"`
 }
 
 type Error struct {
relay/adaptor/replicate/adaptor.go (136 lines, new file)

@@ -0,0 +1,136 @@
+package replicate
+
+import (
+	"fmt"
+	"io"
+	"net/http"
+	"slices"
+	"strings"
+	"time"
+
+	"github.com/gin-gonic/gin"
+	"github.com/pkg/errors"
+	"github.com/songquanpeng/one-api/common/logger"
+	"github.com/songquanpeng/one-api/relay/adaptor"
+	"github.com/songquanpeng/one-api/relay/adaptor/openai"
+	"github.com/songquanpeng/one-api/relay/meta"
+	"github.com/songquanpeng/one-api/relay/model"
+	"github.com/songquanpeng/one-api/relay/relaymode"
+)
+
+type Adaptor struct {
+	meta *meta.Meta
+}
+
+// ConvertImageRequest implements adaptor.Adaptor.
+func (*Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) {
+	return DrawImageRequest{
+		Input: ImageInput{
+			Steps:           25,
+			Prompt:          request.Prompt,
+			Guidance:        3,
+			Seed:            int(time.Now().UnixNano()),
+			SafetyTolerance: 5,
+			NImages:         1, // replicate will always return 1 image
+			Width:           1440,
+			Height:          1440,
+			AspectRatio:     "1:1",
+		},
+	}, nil
+}
+
+func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) {
+	if !request.Stream {
+		// TODO: support non-stream mode
+		return nil, errors.Errorf("replicate models only support stream mode now, please set stream=true")
+	}
+
+	// Build the prompt from OpenAI messages
+	var promptBuilder strings.Builder
+	for _, message := range request.Messages {
+		switch msgCnt := message.Content.(type) {
+		case string:
+			promptBuilder.WriteString(message.Role)
+			promptBuilder.WriteString(": ")
+			promptBuilder.WriteString(msgCnt)
+			promptBuilder.WriteString("\n")
+		default:
+		}
+	}
+
+	replicateRequest := ReplicateChatRequest{
+		Input: ChatInput{
+			Prompt:           promptBuilder.String(),
+			MaxTokens:        request.MaxTokens,
+			Temperature:      1.0,
+			TopP:             1.0,
+			PresencePenalty:  0.0,
+			FrequencyPenalty: 0.0,
+		},
+	}
+
+	// Map optional fields
+	if request.Temperature != nil {
+		replicateRequest.Input.Temperature = *request.Temperature
+	}
+	if request.TopP != nil {
+		replicateRequest.Input.TopP = *request.TopP
+	}
+	if request.PresencePenalty != nil {
+		replicateRequest.Input.PresencePenalty = *request.PresencePenalty
+	}
+	if request.FrequencyPenalty != nil {
+		replicateRequest.Input.FrequencyPenalty = *request.FrequencyPenalty
+	}
+	if request.MaxTokens > 0 {
+		replicateRequest.Input.MaxTokens = request.MaxTokens
+	} else if request.MaxTokens == 0 {
+		replicateRequest.Input.MaxTokens = 500
+	}
+
+	return replicateRequest, nil
+}
+
+func (a *Adaptor) Init(meta *meta.Meta) {
+	a.meta = meta
+}
+
+func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) {
+	if !slices.Contains(ModelList, meta.OriginModelName) {
+		return "", errors.Errorf("model %s not supported", meta.OriginModelName)
+	}
+
+	return fmt.Sprintf("https://api.replicate.com/v1/models/%s/predictions", meta.OriginModelName), nil
+}
+
+func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error {
+	adaptor.SetupCommonRequestHeader(c, req, meta)
+	req.Header.Set("Authorization", "Bearer "+meta.APIKey)
+	return nil
+}
+
+func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) {
+	logger.Info(c, "send request to replicate")
+	return adaptor.DoRequestHelper(a, c, meta, requestBody)
+}
+
+func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
		||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
 | 
			
		||||
	switch meta.Mode {
 | 
			
		||||
	case relaymode.ImagesGenerations:
 | 
			
		||||
		err, usage = ImageHandler(c, resp)
 | 
			
		||||
	case relaymode.ChatCompletions:
 | 
			
		||||
		err, usage = ChatHandler(c, resp)
 | 
			
		||||
	default:
 | 
			
		||||
		err = openai.ErrorWrapper(errors.New("not implemented"), "not_implemented", http.StatusInternalServerError)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *Adaptor) GetModelList() []string {
 | 
			
		||||
	return ModelList
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *Adaptor) GetChannelName() string {
 | 
			
		||||
	return "replicate"
 | 
			
		||||
}
 | 
			
		||||
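A minimal sketch (assumptions, not the commit's code) of the JSON body that ConvertRequest above produces for Replicate's predictions endpoint. The trimmed chatInput type here mirrors a subset of the ChatInput struct defined in relay/adaptor/replicate/model.go further down.

package main

import (
	"encoding/json"
	"fmt"
)

// chatInput is a trimmed, illustrative copy of ChatInput.
type chatInput struct {
	Prompt      string  `json:"prompt"`
	MaxTokens   int     `json:"max_tokens"`
	Temperature float64 `json:"temperature"`
	TopP        float64 `json:"top_p"`
}

func main() {
	body, _ := json.MarshalIndent(map[string]chatInput{
		"input": {
			// Roles are flattened into one prompt string, one "role: content" line per message.
			Prompt:      "system: You are a helpful assistant.\nuser: Hello!\n",
			MaxTokens:   500, // the adaptor's fallback when the client sends max_tokens=0
			Temperature: 1.0,
			TopP:        1.0,
		},
	}, "", "  ")
	fmt.Println(string(body))
	// The adaptor POSTs a body of this shape to
	// https://api.replicate.com/v1/models/<model>/predictions
}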

191	relay/adaptor/replicate/chat.go	Normal file
@@ -0,0 +1,191 @@
package replicate

import (
	"bufio"
	"encoding/json"
	"io"
	"net/http"
	"strings"
	"time"

	"github.com/gin-gonic/gin"
	"github.com/pkg/errors"
	"github.com/songquanpeng/one-api/common"
	"github.com/songquanpeng/one-api/common/render"
	"github.com/songquanpeng/one-api/relay/adaptor/openai"
	"github.com/songquanpeng/one-api/relay/meta"
	"github.com/songquanpeng/one-api/relay/model"
)

func ChatHandler(c *gin.Context, resp *http.Response) (
	srvErr *model.ErrorWithStatusCode, usage *model.Usage) {
	if resp.StatusCode != http.StatusCreated {
		payload, _ := io.ReadAll(resp.Body)
		return openai.ErrorWrapper(
				errors.Errorf("bad_status_code [%d]%s", resp.StatusCode, string(payload)),
				"bad_status_code", http.StatusInternalServerError),
			nil
	}

	respBody, err := io.ReadAll(resp.Body)
	if err != nil {
		return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
	}

	respData := new(ChatResponse)
	if err = json.Unmarshal(respBody, respData); err != nil {
		return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
	}

	for {
		err = func() error {
			// get task
			taskReq, err := http.NewRequestWithContext(c.Request.Context(),
				http.MethodGet, respData.URLs.Get, nil)
			if err != nil {
				return errors.Wrap(err, "new request")
			}

			taskReq.Header.Set("Authorization", "Bearer "+meta.GetByContext(c).APIKey)
			taskResp, err := http.DefaultClient.Do(taskReq)
			if err != nil {
				return errors.Wrap(err, "get task")
			}
			defer taskResp.Body.Close()

			if taskResp.StatusCode != http.StatusOK {
				payload, _ := io.ReadAll(taskResp.Body)
				return errors.Errorf("bad status code [%d]%s",
					taskResp.StatusCode, string(payload))
			}

			taskBody, err := io.ReadAll(taskResp.Body)
			if err != nil {
				return errors.Wrap(err, "read task response")
			}

			taskData := new(ChatResponse)
			if err = json.Unmarshal(taskBody, taskData); err != nil {
				return errors.Wrap(err, "decode task response")
			}

			switch taskData.Status {
			case "succeeded":
			case "failed", "canceled":
				return errors.Errorf("task failed, [%s]%s", taskData.Status, taskData.Error)
			default:
				time.Sleep(time.Second * 3)
				return errNextLoop
			}

			if taskData.URLs.Stream == "" {
				return errors.New("stream url is empty")
			}

			// request stream url
			responseText, err := chatStreamHandler(c, taskData.URLs.Stream)
			if err != nil {
				return errors.Wrap(err, "chat stream handler")
			}

			ctxMeta := meta.GetByContext(c)
			usage = openai.ResponseText2Usage(responseText,
				ctxMeta.ActualModelName, ctxMeta.PromptTokens)
			return nil
		}()
		if err != nil {
			if errors.Is(err, errNextLoop) {
				continue
			}

			return openai.ErrorWrapper(err, "chat_task_failed", http.StatusInternalServerError), nil
		}

		break
	}

	return nil, usage
}

const (
	eventPrefix = "event: "
	dataPrefix  = "data: "
	done        = "[DONE]"
)

func chatStreamHandler(c *gin.Context, streamUrl string) (responseText string, err error) {
	// request stream endpoint
	streamReq, err := http.NewRequestWithContext(c.Request.Context(), http.MethodGet, streamUrl, nil)
	if err != nil {
		return "", errors.Wrap(err, "new request to stream")
	}

	streamReq.Header.Set("Authorization", "Bearer "+meta.GetByContext(c).APIKey)
	streamReq.Header.Set("Accept", "text/event-stream")
	streamReq.Header.Set("Cache-Control", "no-store")

	resp, err := http.DefaultClient.Do(streamReq)
	if err != nil {
		return "", errors.Wrap(err, "do request to stream")
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		payload, _ := io.ReadAll(resp.Body)
		return "", errors.Errorf("bad status code [%d]%s", resp.StatusCode, string(payload))
	}

	scanner := bufio.NewScanner(resp.Body)
	scanner.Split(bufio.ScanLines)

	common.SetEventStreamHeaders(c)
	doneRendered := false
	for scanner.Scan() {
		line := strings.TrimSpace(scanner.Text())
		if line == "" {
			continue
		}

		// Handle comments starting with ':'
		if strings.HasPrefix(line, ":") {
			continue
		}

		// Parse SSE fields
		if strings.HasPrefix(line, eventPrefix) {
			event := strings.TrimSpace(line[len(eventPrefix):])
			var data string
			// Read the following lines to get data and id
			for scanner.Scan() {
				nextLine := scanner.Text()
				if nextLine == "" {
					break
				}
				if strings.HasPrefix(nextLine, dataPrefix) {
					data = nextLine[len(dataPrefix):]
				} else if strings.HasPrefix(nextLine, "id:") {
					// id = strings.TrimSpace(nextLine[len("id:"):])
				}
			}

			if event == "output" {
				render.StringData(c, data)
				responseText += data
			} else if event == "done" {
				render.Done(c)
				doneRendered = true
				break
			}
		}
	}

	if err := scanner.Err(); err != nil {
		return "", errors.Wrap(err, "scan stream")
	}

	if !doneRendered {
		render.Done(c)
	}

	return responseText, nil
}
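Both ChatHandler above and ImageHandler below implement the same Replicate prediction lifecycle: the create call returns 201 with a urls.get link, which is polled every 3 seconds until the task leaves its pending states, and errNextLoop is the sentinel that restarts the loop. A standalone sketch of just the polling step (pollPrediction and its types are illustrative, not the commit's code):

package main

import (
	"encoding/json"
	"fmt"
	"net/http"
	"time"
)

type prediction struct {
	Status string `json:"status"`
	URLs   struct {
		Get    string `json:"get"`
		Stream string `json:"stream"`
	} `json:"urls"`
}

// pollPrediction assumes getURL and token come from the create call's response.
func pollPrediction(getURL, token string) (*prediction, error) {
	for {
		req, err := http.NewRequest(http.MethodGet, getURL, nil)
		if err != nil {
			return nil, err
		}
		req.Header.Set("Authorization", "Bearer "+token)
		resp, err := http.DefaultClient.Do(req)
		if err != nil {
			return nil, err
		}
		var p prediction
		err = json.NewDecoder(resp.Body).Decode(&p)
		resp.Body.Close()
		if err != nil {
			return nil, err
		}
		switch p.Status {
		case "succeeded":
			return &p, nil
		case "failed", "canceled":
			return nil, fmt.Errorf("task ended with status %s", p.Status)
		default: // "starting", "processing"
			time.Sleep(3 * time.Second) // same cadence as the handlers here
		}
	}
}

func main() {
	// Example usage (hypothetical URL and token):
	// p, err := pollPrediction("https://api.replicate.com/v1/predictions/<id>", "r8_...")
	_ = pollPrediction
}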

58	relay/adaptor/replicate/constant.go	Normal file
@@ -0,0 +1,58 @@
package replicate

// ModelList is a list of models that can be used with Replicate.
//
// https://replicate.com/pricing
var ModelList = []string{
	// -------------------------------------
	// image model
	// -------------------------------------
	"black-forest-labs/flux-1.1-pro",
	"black-forest-labs/flux-1.1-pro-ultra",
	"black-forest-labs/flux-canny-dev",
	"black-forest-labs/flux-canny-pro",
	"black-forest-labs/flux-depth-dev",
	"black-forest-labs/flux-depth-pro",
	"black-forest-labs/flux-dev",
	"black-forest-labs/flux-dev-lora",
	"black-forest-labs/flux-fill-dev",
	"black-forest-labs/flux-fill-pro",
	"black-forest-labs/flux-pro",
	"black-forest-labs/flux-redux-dev",
	"black-forest-labs/flux-redux-schnell",
	"black-forest-labs/flux-schnell",
	"black-forest-labs/flux-schnell-lora",
	"ideogram-ai/ideogram-v2",
	"ideogram-ai/ideogram-v2-turbo",
	"recraft-ai/recraft-v3",
	"recraft-ai/recraft-v3-svg",
	"stability-ai/stable-diffusion-3",
	"stability-ai/stable-diffusion-3.5-large",
	"stability-ai/stable-diffusion-3.5-large-turbo",
	"stability-ai/stable-diffusion-3.5-medium",
	// -------------------------------------
	// language model
	// -------------------------------------
	"ibm-granite/granite-20b-code-instruct-8k",
	"ibm-granite/granite-3.0-2b-instruct",
	"ibm-granite/granite-3.0-8b-instruct",
	"ibm-granite/granite-8b-code-instruct-128k",
	"meta/llama-2-13b",
	"meta/llama-2-13b-chat",
	"meta/llama-2-70b",
	"meta/llama-2-70b-chat",
	"meta/llama-2-7b",
	"meta/llama-2-7b-chat",
	"meta/meta-llama-3.1-405b-instruct",
	"meta/meta-llama-3-70b",
	"meta/meta-llama-3-70b-instruct",
	"meta/meta-llama-3-8b",
	"meta/meta-llama-3-8b-instruct",
	"mistralai/mistral-7b-instruct-v0.2",
	"mistralai/mistral-7b-v0.1",
	"mistralai/mixtral-8x7b-instruct-v0.1",
	// -------------------------------------
	// video model
	// -------------------------------------
	// "minimax/video-01",  // TODO: implement the adaptor
}

222	relay/adaptor/replicate/image.go	Normal file
@@ -0,0 +1,222 @@
package replicate

import (
	"bytes"
	"encoding/base64"
	"encoding/json"
	"fmt"
	"image"
	"image/png"
	"io"
	"net/http"
	"sync"
	"time"

	"github.com/gin-gonic/gin"
	"github.com/pkg/errors"
	"github.com/songquanpeng/one-api/common/logger"
	"github.com/songquanpeng/one-api/relay/adaptor/openai"
	"github.com/songquanpeng/one-api/relay/meta"
	"github.com/songquanpeng/one-api/relay/model"
	"golang.org/x/image/webp"
	"golang.org/x/sync/errgroup"
)

// ImagesEditsHandler just copy response body to client
//
// https://replicate.com/black-forest-labs/flux-fill-pro
// func ImagesEditsHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
// 	c.Writer.WriteHeader(resp.StatusCode)
// 	for k, v := range resp.Header {
// 		c.Writer.Header().Set(k, v[0])
// 	}

// 	if _, err := io.Copy(c.Writer, resp.Body); err != nil {
// 		return ErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil
// 	}
// 	defer resp.Body.Close()

// 	return nil, nil
// }

var errNextLoop = errors.New("next_loop")

func ImageHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
	if resp.StatusCode != http.StatusCreated {
		payload, _ := io.ReadAll(resp.Body)
		return openai.ErrorWrapper(
				errors.Errorf("bad_status_code [%d]%s", resp.StatusCode, string(payload)),
				"bad_status_code", http.StatusInternalServerError),
			nil
	}

	respBody, err := io.ReadAll(resp.Body)
	if err != nil {
		return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
	}

	respData := new(ImageResponse)
	if err = json.Unmarshal(respBody, respData); err != nil {
		return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
	}

	for {
		err = func() error {
			// get task
			taskReq, err := http.NewRequestWithContext(c.Request.Context(),
				http.MethodGet, respData.URLs.Get, nil)
			if err != nil {
				return errors.Wrap(err, "new request")
			}

			taskReq.Header.Set("Authorization", "Bearer "+meta.GetByContext(c).APIKey)
			taskResp, err := http.DefaultClient.Do(taskReq)
			if err != nil {
				return errors.Wrap(err, "get task")
			}
			defer taskResp.Body.Close()

			if taskResp.StatusCode != http.StatusOK {
				payload, _ := io.ReadAll(taskResp.Body)
				return errors.Errorf("bad status code [%d]%s",
					taskResp.StatusCode, string(payload))
			}

			taskBody, err := io.ReadAll(taskResp.Body)
			if err != nil {
				return errors.Wrap(err, "read task response")
			}

			taskData := new(ImageResponse)
			if err = json.Unmarshal(taskBody, taskData); err != nil {
				return errors.Wrap(err, "decode task response")
			}

			switch taskData.Status {
			case "succeeded":
			case "failed", "canceled":
				return errors.Errorf("task failed: %s", taskData.Status)
			default:
				time.Sleep(time.Second * 3)
				return errNextLoop
			}

			output, err := taskData.GetOutput()
			if err != nil {
				return errors.Wrap(err, "get output")
			}
			if len(output) == 0 {
				return errors.New("response output is empty")
			}

			var mu sync.Mutex
			var pool errgroup.Group
			respBody := &openai.ImageResponse{
				Created: taskData.CompletedAt.Unix(),
				Data:    []openai.ImageData{},
			}

			for _, imgOut := range output {
				imgOut := imgOut
				pool.Go(func() error {
					// download image
					downloadReq, err := http.NewRequestWithContext(c.Request.Context(),
						http.MethodGet, imgOut, nil)
					if err != nil {
						return errors.Wrap(err, "new request")
					}

					imgResp, err := http.DefaultClient.Do(downloadReq)
					if err != nil {
						return errors.Wrap(err, "download image")
					}
					defer imgResp.Body.Close()

					if imgResp.StatusCode != http.StatusOK {
						payload, _ := io.ReadAll(imgResp.Body)
						return errors.Errorf("bad status code [%d]%s",
							imgResp.StatusCode, string(payload))
					}

					imgData, err := io.ReadAll(imgResp.Body)
					if err != nil {
						return errors.Wrap(err, "read image")
					}

					imgData, err = ConvertImageToPNG(imgData)
					if err != nil {
						return errors.Wrap(err, "convert image")
					}

					mu.Lock()
					respBody.Data = append(respBody.Data, openai.ImageData{
						B64Json: fmt.Sprintf("data:image/png;base64,%s",
							base64.StdEncoding.EncodeToString(imgData)),
					})
					mu.Unlock()

					return nil
				})
			}

			if err := pool.Wait(); err != nil {
				if len(respBody.Data) == 0 {
					return errors.WithStack(err)
				}

				logger.Error(c, fmt.Sprintf("some images failed to download: %+v", err))
			}

			c.JSON(http.StatusOK, respBody)
			return nil
		}()
		if err != nil {
			if errors.Is(err, errNextLoop) {
				continue
			}

			return openai.ErrorWrapper(err, "image_task_failed", http.StatusInternalServerError), nil
		}

		break
	}

	return nil, nil
}

// ConvertImageToPNG converts a WebP image to PNG format
func ConvertImageToPNG(webpData []byte) ([]byte, error) {
	// bypass if it's already a PNG image
	if bytes.HasPrefix(webpData, []byte("\x89PNG")) {
		return webpData, nil
	}

	// check if is jpeg, convert to png
	if bytes.HasPrefix(webpData, []byte("\xff\xd8\xff")) {
		img, _, err := image.Decode(bytes.NewReader(webpData))
		if err != nil {
			return nil, errors.Wrap(err, "decode jpeg")
		}

		var pngBuffer bytes.Buffer
		if err := png.Encode(&pngBuffer, img); err != nil {
			return nil, errors.Wrap(err, "encode png")
		}

		return pngBuffer.Bytes(), nil
	}

	// Decode the WebP image
	img, err := webp.Decode(bytes.NewReader(webpData))
	if err != nil {
		return nil, errors.Wrap(err, "decode webp")
	}

	// Encode the image as PNG
	var pngBuffer bytes.Buffer
	if err := png.Encode(&pngBuffer, img); err != nil {
		return nil, errors.Wrap(err, "encode png")
	}

	return pngBuffer.Bytes(), nil
}
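ConvertImageToPNG above dispatches on the file's leading magic bytes: \x89PNG marks PNG (returned unchanged), \xff\xd8\xff marks JPEG (re-encoded), and anything else is assumed to be WebP. A small illustrative sketch of that dispatch (sniff is a hypothetical helper, not part of the commit):

package main

import (
	"bytes"
	"fmt"
)

func sniff(data []byte) string {
	switch {
	case bytes.HasPrefix(data, []byte("\x89PNG")):
		return "png (returned unchanged)"
	case bytes.HasPrefix(data, []byte("\xff\xd8\xff")):
		return "jpeg (re-encoded to png via image.Decode)"
	default:
		return "assumed webp (decoded with golang.org/x/image/webp)"
	}
}

func main() {
	fmt.Println(sniff([]byte("\x89PNG\r\n\x1a\n...")))
	fmt.Println(sniff([]byte("\xff\xd8\xff\xe0...")))
	fmt.Println(sniff([]byte("RIFF....WEBP")))
}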

159	relay/adaptor/replicate/model.go	Normal file
@@ -0,0 +1,159 @@
package replicate

import (
	"time"

	"github.com/pkg/errors"
)

// DrawImageRequest draw image by fluxpro
//
// https://replicate.com/black-forest-labs/flux-pro?prediction=kg1krwsdf9rg80ch1sgsrgq7h8&output=json
type DrawImageRequest struct {
	Input ImageInput `json:"input"`
}

// ImageInput is input of DrawImageByFluxProRequest
//
// https://replicate.com/black-forest-labs/flux-1.1-pro/api/schema
type ImageInput struct {
	Steps           int    `json:"steps" binding:"required,min=1"`
	Prompt          string `json:"prompt" binding:"required,min=5"`
	ImagePrompt     string `json:"image_prompt"`
	Guidance        int    `json:"guidance" binding:"required,min=2,max=5"`
	Interval        int    `json:"interval" binding:"required,min=1,max=4"`
	AspectRatio     string `json:"aspect_ratio" binding:"required,oneof=1:1 16:9 2:3 3:2 4:5 5:4 9:16"`
	SafetyTolerance int    `json:"safety_tolerance" binding:"required,min=1,max=5"`
	Seed            int    `json:"seed"`
	NImages         int    `json:"n_images" binding:"required,min=1,max=8"`
	Width           int    `json:"width" binding:"required,min=256,max=1440"`
	Height          int    `json:"height" binding:"required,min=256,max=1440"`
}

// InpaintingImageByFlusReplicateRequest is request to inpainting image by flux pro
//
// https://replicate.com/black-forest-labs/flux-fill-pro/api/schema
type InpaintingImageByFlusReplicateRequest struct {
	Input FluxInpaintingInput `json:"input"`
}

// FluxInpaintingInput is input of DrawImageByFluxProRequest
//
// https://replicate.com/black-forest-labs/flux-fill-pro/api/schema
type FluxInpaintingInput struct {
	Mask             string `json:"mask" binding:"required"`
	Image            string `json:"image" binding:"required"`
	Seed             int    `json:"seed"`
	Steps            int    `json:"steps" binding:"required,min=1"`
	Prompt           string `json:"prompt" binding:"required,min=5"`
	Guidance         int    `json:"guidance" binding:"required,min=2,max=5"`
	OutputFormat     string `json:"output_format"`
	SafetyTolerance  int    `json:"safety_tolerance" binding:"required,min=1,max=5"`
	PromptUnsampling bool   `json:"prompt_unsampling"`
}

// ImageResponse is response of DrawImageByFluxProRequest
//
// https://replicate.com/black-forest-labs/flux-pro?prediction=kg1krwsdf9rg80ch1sgsrgq7h8&output=json
type ImageResponse struct {
	CompletedAt time.Time        `json:"completed_at"`
	CreatedAt   time.Time        `json:"created_at"`
	DataRemoved bool             `json:"data_removed"`
	Error       string           `json:"error"`
	ID          string           `json:"id"`
	Input       DrawImageRequest `json:"input"`
	Logs        string           `json:"logs"`
	Metrics     FluxMetrics      `json:"metrics"`
	// Output could be `string` or `[]string`
	Output    any       `json:"output"`
	StartedAt time.Time `json:"started_at"`
	Status    string    `json:"status"`
	URLs      FluxURLs  `json:"urls"`
	Version   string    `json:"version"`
}

func (r *ImageResponse) GetOutput() ([]string, error) {
	switch v := r.Output.(type) {
	case string:
		return []string{v}, nil
	case []string:
		return v, nil
	case nil:
		return nil, nil
	case []interface{}:
		// convert []interface{} to []string
		ret := make([]string, len(v))
		for idx, vv := range v {
			if vvv, ok := vv.(string); ok {
				ret[idx] = vvv
			} else {
				return nil, errors.Errorf("unknown output type: [%T]%v", vv, vv)
			}
		}

		return ret, nil
	default:
		return nil, errors.Errorf("unknown output type: [%T]%v", r.Output, r.Output)
	}
}

// FluxMetrics is metrics of ImageResponse
type FluxMetrics struct {
	ImageCount  int     `json:"image_count"`
	PredictTime float64 `json:"predict_time"`
	TotalTime   float64 `json:"total_time"`
}

// FluxURLs is urls of ImageResponse
type FluxURLs struct {
	Get    string `json:"get"`
	Cancel string `json:"cancel"`
}

type ReplicateChatRequest struct {
	Input ChatInput `json:"input" form:"input" binding:"required"`
}

// ChatInput is input of ChatByReplicateRequest
//
// https://replicate.com/meta/meta-llama-3.1-405b-instruct/api/schema
type ChatInput struct {
	TopK             int     `json:"top_k"`
	TopP             float64 `json:"top_p"`
	Prompt           string  `json:"prompt"`
	MaxTokens        int     `json:"max_tokens"`
	MinTokens        int     `json:"min_tokens"`
	Temperature      float64 `json:"temperature"`
	SystemPrompt     string  `json:"system_prompt"`
	StopSequences    string  `json:"stop_sequences"`
	PromptTemplate   string  `json:"prompt_template"`
	PresencePenalty  float64 `json:"presence_penalty"`
	FrequencyPenalty float64 `json:"frequency_penalty"`
}

// ChatResponse is response of ChatByReplicateRequest
//
// https://replicate.com/meta/meta-llama-3.1-405b-instruct/examples?input=http&output=json
type ChatResponse struct {
	CompletedAt time.Time   `json:"completed_at"`
	CreatedAt   time.Time   `json:"created_at"`
	DataRemoved bool        `json:"data_removed"`
	Error       string      `json:"error"`
	ID          string      `json:"id"`
	Input       ChatInput   `json:"input"`
	Logs        string      `json:"logs"`
	Metrics     FluxMetrics `json:"metrics"`
	// Output could be `string` or `[]string`
	Output    []string        `json:"output"`
	StartedAt time.Time       `json:"started_at"`
	Status    string          `json:"status"`
	URLs      ChatResponseUrl `json:"urls"`
	Version   string          `json:"version"`
}

// ChatResponseUrl is task urls of ChatResponse
type ChatResponseUrl struct {
	Stream string `json:"stream"`
	Get    string `json:"get"`
	Cancel string `json:"cancel"`
}
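GetOutput above needs the type switch because `output` is declared as `any`: encoding/json decodes a JSON string into string and a JSON array into []interface{}, never []string, so the conversion loop does the narrowing. A minimal standalone sketch (illustrative, not part of the commit):

package main

import (
	"encoding/json"
	"fmt"
)

type resp struct {
	Output any `json:"output"`
}

func main() {
	var single, multi resp
	_ = json.Unmarshal([]byte(`{"output":"https://example/a.webp"}`), &single)
	_ = json.Unmarshal([]byte(`{"output":["https://example/a.webp","https://example/b.webp"]}`), &multi)
	fmt.Printf("%T\n", single.Output) // string
	fmt.Printf("%T\n", multi.Output)  // []interface {}, never []string, hence the conversion loop
}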
@@ -5,4 +5,5 @@ var ModelList = []string{
	"hunyuan-standard",
	"hunyuan-standard-256K",
	"hunyuan-pro",
	"hunyuan-vision",
}

@@ -39,8 +39,8 @@ func ConvertRequest(request model.GeneralOpenAIRequest) *ChatRequest {
		Model:       &request.Model,
		Stream:      &request.Stream,
		Messages:    messages,
		TopP:        &request.TopP,
		Temperature: &request.Temperature,
		TopP:        request.TopP,
		Temperature: request.Temperature,
	}
}

@@ -13,7 +13,12 @@ import (
)

var ModelList = []string{
	"claude-3-haiku@20240307", "claude-3-opus@20240229", "claude-3-5-sonnet@20240620", "claude-3-sonnet@20240229",
	"claude-3-haiku@20240307",
	"claude-3-sonnet@20240229",
	"claude-3-opus@20240229",
	"claude-3-5-sonnet@20240620",
	"claude-3-5-sonnet-v2@20241022",
	"claude-3-5-haiku@20241022",
}

const anthropicVersion = "vertex-2023-10-16"

@@ -11,8 +11,8 @@ type Request struct {
	MaxTokens     int                 `json:"max_tokens,omitempty"`
	StopSequences []string            `json:"stop_sequences,omitempty"`
	Stream        bool                `json:"stream,omitempty"`
	Temperature   float64             `json:"temperature,omitempty"`
	TopP          float64             `json:"top_p,omitempty"`
	Temperature   *float64            `json:"temperature,omitempty"`
	TopP          *float64            `json:"top_p,omitempty"`
	TopK          int                 `json:"top_k,omitempty"`
	Tools         []anthropic.Tool    `json:"tools,omitempty"`
	ToolChoice    any                 `json:"tool_choice,omitempty"`

@@ -15,7 +15,10 @@ import (
)

var ModelList = []string{
	"gemini-1.5-pro-001", "gemini-1.5-flash-001", "gemini-pro", "gemini-pro-vision",
	"gemini-pro", "gemini-pro-vision",
	"gemini-1.5-pro-001", "gemini-1.5-flash-001",
	"gemini-1.5-pro-002", "gemini-1.5-flash-002",
	"gemini-2.0-flash-exp", "gemini-2.0-flash-thinking-exp",
}

type Adaptor struct {


5	relay/adaptor/xai/constants.go	Normal file
@@ -0,0 +1,5 @@
package xai

var ModelList = []string{
	"grok-beta",
}

@@ -5,6 +5,8 @@ var ModelList = []string{
	"SparkDesk-v1.1",
	"SparkDesk-v2.1",
	"SparkDesk-v3.1",
	"SparkDesk-v3.1-128K",
	"SparkDesk-v3.5",
	"SparkDesk-v3.5-32K",
	"SparkDesk-v4.0",
}

@@ -272,9 +272,9 @@ func xunfeiMakeRequest(textRequest model.GeneralOpenAIRequest, domain, authUrl,
}

func parseAPIVersionByModelName(modelName string) string {
	parts := strings.Split(modelName, "-")
	if len(parts) == 2 {
		return parts[1]
	index := strings.IndexAny(modelName, "-")
	if index != -1 {
		return modelName[index+1:]
	}
	return ""
}
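The rewrite above replaces a split that only handled names with exactly one hyphen; the IndexAny version keeps everything after the first hyphen, which the new multi-hyphen SparkDesk names need. A standalone sketch contrasting the two (oldParse/newParse are illustrative names, not the commit's identifiers):

package main

import (
	"fmt"
	"strings"
)

func oldParse(name string) string {
	parts := strings.Split(name, "-")
	if len(parts) == 2 {
		return parts[1]
	}
	return ""
}

func newParse(name string) string {
	if i := strings.IndexAny(name, "-"); i != -1 {
		return name[i+1:]
	}
	return ""
}

func main() {
	for _, name := range []string{"SparkDesk-v3.5", "SparkDesk-v3.1-128K"} {
		fmt.Printf("%-22s old=%q new=%q\n", name, oldParse(name), newParse(name))
	}
	// SparkDesk-v3.5         old="v3.5" new="v3.5"
	// SparkDesk-v3.1-128K    old=""     new="v3.1-128K"
}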
@@ -283,13 +283,17 @@
func apiVersion2domain(apiVersion string) string {
	switch apiVersion {
	case "v1.1":
		return "general"
		return "lite"
	case "v2.1":
		return "generalv2"
	case "v3.1":
		return "generalv3"
	case "v3.1-128K":
		return "pro-128k"
	case "v3.5":
		return "generalv3.5"
	case "v3.5-32K":
		return "max-32k"
	case "v4.0":
		return "4.0Ultra"
	}
@@ -297,7 +301,17 @@
}

func getXunfeiAuthUrl(apiVersion string, apiKey string, apiSecret string) (string, string) {
	var authUrl string
	domain := apiVersion2domain(apiVersion)
	authUrl := buildXunfeiAuthUrl(fmt.Sprintf("wss://spark-api.xf-yun.com/%s/chat", apiVersion), apiKey, apiSecret)
	switch apiVersion {
	case "v3.1-128K":
		authUrl = buildXunfeiAuthUrl(fmt.Sprintf("wss://spark-api.xf-yun.com/chat/pro-128k"), apiKey, apiSecret)
		break
	case "v3.5-32K":
		authUrl = buildXunfeiAuthUrl(fmt.Sprintf("wss://spark-api.xf-yun.com/chat/max-32k"), apiKey, apiSecret)
		break
	default:
		authUrl = buildXunfeiAuthUrl(fmt.Sprintf("wss://spark-api.xf-yun.com/%s/chat", apiVersion), apiKey, apiSecret)
	}
	return domain, authUrl
}

@@ -19,11 +19,11 @@ type ChatRequest struct {
	} `json:"header"`
	Parameter struct {
		Chat struct {
			Domain      string  `json:"domain,omitempty"`
			Temperature float64 `json:"temperature,omitempty"`
			TopK        int     `json:"top_k,omitempty"`
			MaxTokens   int     `json:"max_tokens,omitempty"`
			Auditing    bool    `json:"auditing,omitempty"`
			Domain      string   `json:"domain,omitempty"`
			Temperature *float64 `json:"temperature,omitempty"`
			TopK        int      `json:"top_k,omitempty"`
			MaxTokens   int      `json:"max_tokens,omitempty"`
			Auditing    bool     `json:"auditing,omitempty"`
		} `json:"chat"`
	} `json:"parameter"`
	Payload struct {

@@ -4,13 +4,13 @@ import (
	"errors"
	"fmt"
	"github.com/gin-gonic/gin"
	"github.com/songquanpeng/one-api/common/helper"
	"github.com/songquanpeng/one-api/relay/adaptor"
	"github.com/songquanpeng/one-api/relay/adaptor/openai"
	"github.com/songquanpeng/one-api/relay/meta"
	"github.com/songquanpeng/one-api/relay/model"
	"github.com/songquanpeng/one-api/relay/relaymode"
	"io"
	"math"
	"net/http"
	"strings"
)
@@ -65,13 +65,13 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G
		baiduEmbeddingRequest, err := ConvertEmbeddingRequest(*request)
		return baiduEmbeddingRequest, err
	default:
		// TopP (0.0, 1.0)
		request.TopP = math.Min(0.99, request.TopP)
		request.TopP = math.Max(0.01, request.TopP)
		// TopP [0.0, 1.0]
		request.TopP = helper.Float64PtrMax(request.TopP, 1)
		request.TopP = helper.Float64PtrMin(request.TopP, 0)

		// Temperature (0.0, 1.0)
		request.Temperature = math.Min(0.99, request.Temperature)
		request.Temperature = math.Max(0.01, request.Temperature)
		// Temperature [0.0, 1.0]
		request.Temperature = helper.Float64PtrMax(request.Temperature, 1)
		request.Temperature = helper.Float64PtrMin(request.Temperature, 0)
		a.SetVersionByModeName(request.Model)
		if a.APIVersion == "v4" {
			return request, nil

@@ -12,8 +12,8 @@ type Message struct {

type Request struct {
	Prompt      []Message `json:"prompt"`
	Temperature float64   `json:"temperature,omitempty"`
	TopP        float64   `json:"top_p,omitempty"`
	Temperature *float64  `json:"temperature,omitempty"`
	TopP        *float64  `json:"top_p,omitempty"`
	RequestId   string    `json:"request_id,omitempty"`
	Incremental bool      `json:"incremental,omitempty"`
}

@@ -19,6 +19,7 @@ const (
	DeepL
	VertexAI
	Proxy
	Replicate

	Dummy // this one is only for count, do not add any channel after this
)

@@ -48,8 +48,14 @@ var ModelRatio = map[string]float64{
	"gpt-3.5-turbo-instruct":  0.75, // $0.0015 / 1K tokens
	"gpt-3.5-turbo-1106":      0.5,  // $0.001 / 1K tokens
	"gpt-3.5-turbo-0125":      0.25, // $0.0005 / 1K tokens
	"davinci-002":             1,    // $0.002 / 1K tokens
	"babbage-002":             0.2,  // $0.0004 / 1K tokens
	"o1":                      7.5,  // $15.00 / 1M input tokens
	"o1-2024-12-17":           7.5,
	"o1-preview":              7.5, // $15.00 / 1M input tokens
	"o1-preview-2024-09-12":   7.5,
	"o1-mini":                 1.5, // $3.00 / 1M input tokens
	"o1-mini-2024-09-12":      1.5,
	"davinci-002":             1,   // $0.002 / 1K tokens
	"babbage-002":             0.2, // $0.0004 / 1K tokens
	"text-ada-001":            0.2,
	"text-babbage-001":        0.25,
	"text-curie-001":          1,
@@ -79,8 +85,10 @@ var ModelRatio = map[string]float64{
	"claude-2.0":                 8.0 / 1000 * USD,
	"claude-2.1":                 8.0 / 1000 * USD,
	"claude-3-haiku-20240307":    0.25 / 1000 * USD,
	"claude-3-5-haiku-20241022":  1.0 / 1000 * USD,
	"claude-3-sonnet-20240229":   3.0 / 1000 * USD,
	"claude-3-5-sonnet-20240620": 3.0 / 1000 * USD,
	"claude-3-5-sonnet-20241022": 3.0 / 1000 * USD,
	"claude-3-opus-20240229":     15.0 / 1000 * USD,
	// https://cloud.baidu.com/doc/WENXINWORKSHOP/s/hlrk4akp7
	"ERNIE-4.0-8K":       0.120 * RMB,
@@ -100,11 +108,15 @@ var ModelRatio = map[string]float64{
	"bge-large-en":       0.002 * RMB,
	"tao-8k":             0.002 * RMB,
	// https://ai.google.dev/pricing
	"gemini-pro":       1, // $0.00025 / 1k characters -> $0.001 / 1k tokens
	"gemini-1.0-pro":   1,
	"gemini-1.5-flash": 1,
	"gemini-1.5-pro":   1,
	"aqa":              1,
	"gemini-pro":                    1, // $0.00025 / 1k characters -> $0.001 / 1k tokens
	"gemini-1.0-pro":                1,
	"gemini-1.5-pro":                1,
	"gemini-1.5-pro-001":            1,
	"gemini-1.5-flash":              1,
	"gemini-1.5-flash-001":          1,
	"gemini-2.0-flash-exp":          1,
	"gemini-2.0-flash-thinking-exp": 1,
	"aqa":                           1,
	// https://open.bigmodel.cn/pricing
	"glm-4":         0.1 * RMB,
	"glm-4v":        0.1 * RMB,
@@ -116,27 +128,94 @@ var ModelRatio = map[string]float64{
	"chatglm_lite":  0.1429, // ¥0.002 / 1k tokens
	"cogview-3":     0.25 * RMB,
	// https://help.aliyun.com/zh/dashscope/developer-reference/tongyi-thousand-questions-metering-and-billing
	"qwen-turbo":                0.5715, // ¥0.008 / 1k tokens
	"qwen-plus":                 1.4286, // ¥0.02 / 1k tokens
	"qwen-max":                  1.4286, // ¥0.02 / 1k tokens
	"qwen-max-longcontext":      1.4286, // ¥0.02 / 1k tokens
	"text-embedding-v1":         0.05,   // ¥0.0007 / 1k tokens
	"ali-stable-diffusion-xl":   8,
	"ali-stable-diffusion-v1.5": 8,
	"wanx-v1":                   8,
	"SparkDesk":                 1.2858, // ¥0.018 / 1k tokens
	"SparkDesk-v1.1":            1.2858, // ¥0.018 / 1k tokens
	"SparkDesk-v2.1":            1.2858, // ¥0.018 / 1k tokens
	"SparkDesk-v3.1":            1.2858, // ¥0.018 / 1k tokens
	"SparkDesk-v3.5":            1.2858, // ¥0.018 / 1k tokens
	"SparkDesk-v4.0":            1.2858, // ¥0.018 / 1k tokens
	"360GPT_S2_V9":              0.8572, // ¥0.012 / 1k tokens
	"embedding-bert-512-v1":     0.0715, // ¥0.001 / 1k tokens
	"embedding_s1_v1":           0.0715, // ¥0.001 / 1k tokens
	"semantic_similarity_s1_v1": 0.0715, // ¥0.001 / 1k tokens
	"hunyuan":                   7.143,  // ¥0.1 / 1k tokens  // https://cloud.tencent.com/document/product/1729/97731#e0e6be58-60c8-469f-bdeb-6c264ce3b4d0
	"ChatStd":                   0.01 * RMB,
	"ChatPro":                   0.1 * RMB,
	"qwen-turbo":                  1.4286, // ¥0.02 / 1k tokens
	"qwen-turbo-latest":           1.4286,
	"qwen-plus":                   1.4286,
	"qwen-plus-latest":            1.4286,
	"qwen-max":                    1.4286,
	"qwen-max-latest":             1.4286,
	"qwen-max-longcontext":        1.4286,
	"qwen-vl-max":                 1.4286,
	"qwen-vl-max-latest":          1.4286,
	"qwen-vl-plus":                1.4286,
	"qwen-vl-plus-latest":         1.4286,
	"qwen-vl-ocr":                 1.4286,
	"qwen-vl-ocr-latest":          1.4286,
	"qwen-audio-turbo":            1.4286,
	"qwen-math-plus":              1.4286,
	"qwen-math-plus-latest":       1.4286,
	"qwen-math-turbo":             1.4286,
	"qwen-math-turbo-latest":      1.4286,
	"qwen-coder-plus":             1.4286,
	"qwen-coder-plus-latest":      1.4286,
	"qwen-coder-turbo":            1.4286,
	"qwen-coder-turbo-latest":     1.4286,
	"qwq-32b-preview":             1.4286,
	"qwen2.5-72b-instruct":        1.4286,
	"qwen2.5-32b-instruct":        1.4286,
	"qwen2.5-14b-instruct":        1.4286,
	"qwen2.5-7b-instruct":         1.4286,
	"qwen2.5-3b-instruct":         1.4286,
	"qwen2.5-1.5b-instruct":       1.4286,
	"qwen2.5-0.5b-instruct":       1.4286,
	"qwen2-72b-instruct":          1.4286,
	"qwen2-57b-a14b-instruct":     1.4286,
	"qwen2-7b-instruct":           1.4286,
	"qwen2-1.5b-instruct":         1.4286,
	"qwen2-0.5b-instruct":         1.4286,
	"qwen1.5-110b-chat":           1.4286,
	"qwen1.5-72b-chat":            1.4286,
	"qwen1.5-32b-chat":            1.4286,
	"qwen1.5-14b-chat":            1.4286,
	"qwen1.5-7b-chat":             1.4286,
	"qwen1.5-1.8b-chat":           1.4286,
	"qwen1.5-0.5b-chat":           1.4286,
	"qwen-72b-chat":               1.4286,
	"qwen-14b-chat":               1.4286,
	"qwen-7b-chat":                1.4286,
	"qwen-1.8b-chat":              1.4286,
	"qwen-1.8b-longcontext-chat":  1.4286,
	"qwen2-vl-7b-instruct":        1.4286,
	"qwen2-vl-2b-instruct":        1.4286,
	"qwen-vl-v1":                  1.4286,
	"qwen-vl-chat-v1":             1.4286,
	"qwen2-audio-instruct":        1.4286,
	"qwen-audio-chat":             1.4286,
	"qwen2.5-math-72b-instruct":   1.4286,
	"qwen2.5-math-7b-instruct":    1.4286,
	"qwen2.5-math-1.5b-instruct":  1.4286,
	"qwen2-math-72b-instruct":     1.4286,
	"qwen2-math-7b-instruct":      1.4286,
	"qwen2-math-1.5b-instruct":    1.4286,
	"qwen2.5-coder-32b-instruct":  1.4286,
	"qwen2.5-coder-14b-instruct":  1.4286,
	"qwen2.5-coder-7b-instruct":   1.4286,
	"qwen2.5-coder-3b-instruct":   1.4286,
	"qwen2.5-coder-1.5b-instruct": 1.4286,
	"qwen2.5-coder-0.5b-instruct": 1.4286,
	"text-embedding-v1":           0.05, // ¥0.0007 / 1k tokens
	"text-embedding-v3":           0.05,
	"text-embedding-v2":           0.05,
	"text-embedding-async-v2":     0.05,
	"text-embedding-async-v1":     0.05,
	"ali-stable-diffusion-xl":     8.00,
	"ali-stable-diffusion-v1.5":   8.00,
	"wanx-v1":                     8.00,
	"SparkDesk":                   1.2858, // ¥0.018 / 1k tokens
	"SparkDesk-v1.1":              1.2858, // ¥0.018 / 1k tokens
	"SparkDesk-v2.1":              1.2858, // ¥0.018 / 1k tokens
	"SparkDesk-v3.1":              1.2858, // ¥0.018 / 1k tokens
	"SparkDesk-v3.1-128K":         1.2858, // ¥0.018 / 1k tokens
	"SparkDesk-v3.5":              1.2858, // ¥0.018 / 1k tokens
	"SparkDesk-v3.5-32K":          1.2858, // ¥0.018 / 1k tokens
	"SparkDesk-v4.0":              1.2858, // ¥0.018 / 1k tokens
	"360GPT_S2_V9":                0.8572, // ¥0.012 / 1k tokens
	"embedding-bert-512-v1":       0.0715, // ¥0.001 / 1k tokens
	"embedding_s1_v1":             0.0715, // ¥0.001 / 1k tokens
	"semantic_similarity_s1_v1":   0.0715, // ¥0.001 / 1k tokens
	"hunyuan":                     7.143,  // ¥0.1 / 1k tokens  // https://cloud.tencent.com/document/product/1729/97731#e0e6be58-60c8-469f-bdeb-6c264ce3b4d0
	"ChatStd":                     0.01 * RMB,
	"ChatPro":                     0.1 * RMB,
	// https://platform.moonshot.cn/pricing
	"moonshot-v1-8k":   0.012 * RMB,
	"moonshot-v1-32k":  0.024 * RMB,
@@ -160,15 +239,21 @@ var ModelRatio = map[string]float64{
	"mistral-embed":         0.1 / 1000 * USD,
	// https://wow.groq.com/#:~:text=inquiries%C2%A0here.-,Model,-Current%20Speed
	"gemma-7b-it":                           0.07 / 1000000 * USD,
	"mixtral-8x7b-32768":                    0.24 / 1000000 * USD,
	"llama3-8b-8192":                        0.05 / 1000000 * USD,
	"llama3-70b-8192":                       0.59 / 1000000 * USD,
	"gemma2-9b-it":                          0.20 / 1000000 * USD,
	"llama-3.1-405b-reasoning":              0.89 / 1000000 * USD,
	"llama-3.1-70b-versatile":               0.59 / 1000000 * USD,
	"llama-3.1-8b-instant":                  0.05 / 1000000 * USD,
	"llama-3.2-11b-text-preview":            0.05 / 1000000 * USD,
	"llama-3.2-11b-vision-preview":          0.05 / 1000000 * USD,
	"llama-3.2-1b-preview":                  0.05 / 1000000 * USD,
	"llama-3.2-3b-preview":                  0.05 / 1000000 * USD,
	"llama-3.2-90b-text-preview":            0.59 / 1000000 * USD,
	"llama-guard-3-8b":                      0.05 / 1000000 * USD,
	"llama3-70b-8192":                       0.59 / 1000000 * USD,
	"llama3-8b-8192":                        0.05 / 1000000 * USD,
	"llama3-groq-70b-8192-tool-use-preview": 0.89 / 1000000 * USD,
	"llama3-groq-8b-8192-tool-use-preview":  0.19 / 1000000 * USD,
	"mixtral-8x7b-32768":                    0.24 / 1000000 * USD,

	// https://platform.lingyiwanwu.com/docs#-计费单元
	"yi-34b-chat-0205": 2.5 / 1000 * RMB,
	"yi-34b-chat-200k": 12.0 / 1000 * RMB,
@@ -199,6 +284,52 @@ var ModelRatio = map[string]float64{
	"deepl-zh": 25.0 / 1000 * USD,
	"deepl-en": 25.0 / 1000 * USD,
	"deepl-ja": 25.0 / 1000 * USD,
	// https://console.x.ai/
	"grok-beta": 5.0 / 1000 * USD,
	// replicate charges based on the number of generated images
	// https://replicate.com/pricing
	"black-forest-labs/flux-1.1-pro":                0.04 * USD,
	"black-forest-labs/flux-1.1-pro-ultra":          0.06 * USD,
	"black-forest-labs/flux-canny-dev":              0.025 * USD,
	"black-forest-labs/flux-canny-pro":              0.05 * USD,
	"black-forest-labs/flux-depth-dev":              0.025 * USD,
	"black-forest-labs/flux-depth-pro":              0.05 * USD,
	"black-forest-labs/flux-dev":                    0.025 * USD,
	"black-forest-labs/flux-dev-lora":               0.032 * USD,
	"black-forest-labs/flux-fill-dev":               0.04 * USD,
	"black-forest-labs/flux-fill-pro":               0.05 * USD,
	"black-forest-labs/flux-pro":                    0.055 * USD,
	"black-forest-labs/flux-redux-dev":              0.025 * USD,
	"black-forest-labs/flux-redux-schnell":          0.003 * USD,
	"black-forest-labs/flux-schnell":                0.003 * USD,
	"black-forest-labs/flux-schnell-lora":           0.02 * USD,
	"ideogram-ai/ideogram-v2":                       0.08 * USD,
	"ideogram-ai/ideogram-v2-turbo":                 0.05 * USD,
	"recraft-ai/recraft-v3":                         0.04 * USD,
	"recraft-ai/recraft-v3-svg":                     0.08 * USD,
	"stability-ai/stable-diffusion-3":               0.035 * USD,
	"stability-ai/stable-diffusion-3.5-large":       0.065 * USD,
	"stability-ai/stable-diffusion-3.5-large-turbo": 0.04 * USD,
	"stability-ai/stable-diffusion-3.5-medium":      0.035 * USD,
	// replicate chat models
	"ibm-granite/granite-20b-code-instruct-8k":  0.100 * USD,
	"ibm-granite/granite-3.0-2b-instruct":       0.030 * USD,
	"ibm-granite/granite-3.0-8b-instruct":       0.050 * USD,
	"ibm-granite/granite-8b-code-instruct-128k": 0.050 * USD,
	"meta/llama-2-13b":                          0.100 * USD,
	"meta/llama-2-13b-chat":                     0.100 * USD,
	"meta/llama-2-70b":                          0.650 * USD,
	"meta/llama-2-70b-chat":                     0.650 * USD,
	"meta/llama-2-7b":                           0.050 * USD,
	"meta/llama-2-7b-chat":                      0.050 * USD,
	"meta/meta-llama-3.1-405b-instruct":         9.500 * USD,
	"meta/meta-llama-3-70b":                     0.650 * USD,
	"meta/meta-llama-3-70b-instruct":            0.650 * USD,
	"meta/meta-llama-3-8b":                      0.050 * USD,
	"meta/meta-llama-3-8b-instruct":             0.050 * USD,
	"mistralai/mistral-7b-instruct-v0.2":        0.050 * USD,
	"mistralai/mistral-7b-v0.1":                 0.050 * USD,
	"mistralai/mixtral-8x7b-instruct-v0.1":      0.300 * USD,
}

var CompletionRatio = map[string]float64{
@@ -332,6 +463,10 @@ func GetCompletionRatio(name string, channelType int) float64 {
 | 
			
		||||
		}
 | 
			
		||||
		return 2
 | 
			
		||||
	}
 | 
			
		||||
	// including o1, o1-preview, o1-mini
 | 
			
		||||
	if strings.HasPrefix(name, "o1") {
 | 
			
		||||
		return 4
 | 
			
		||||
	}
 | 
			
		||||
	if name == "chatgpt-4o-latest" {
 | 
			
		||||
		return 3
 | 
			
		||||
	}
 | 
			
		||||
@@ -350,6 +485,7 @@ func GetCompletionRatio(name string, channelType int) float64 {
 | 
			
		||||
	if strings.HasPrefix(name, "deepseek-") {
 | 
			
		||||
		return 2
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	switch name {
 | 
			
		||||
	case "llama2-70b-4096":
 | 
			
		||||
		return 0.8 / 0.64
 | 
			
		||||
@@ -363,6 +499,37 @@ func GetCompletionRatio(name string, channelType int) float64 {
 | 
			
		||||
		return 3
 | 
			
		||||
	case "command-r-plus":
 | 
			
		||||
		return 5
 | 
			
		||||
	case "grok-beta":
 | 
			
		||||
		return 3
 | 
			
		||||
	// Replicate Models
 | 
			
		||||
	// https://replicate.com/pricing
 | 
			
		||||
	case "ibm-granite/granite-20b-code-instruct-8k":
 | 
			
		||||
		return 5
 | 
			
		||||
	case "ibm-granite/granite-3.0-2b-instruct":
 | 
			
		||||
		return 8.333333333333334
 | 
			
		||||
	case "ibm-granite/granite-3.0-8b-instruct",
 | 
			
		||||
		"ibm-granite/granite-8b-code-instruct-128k":
 | 
			
		||||
		return 5
 | 
			
		||||
	case "meta/llama-2-13b",
 | 
			
		||||
		"meta/llama-2-13b-chat",
 | 
			
		||||
		"meta/llama-2-7b",
 | 
			
		||||
		"meta/llama-2-7b-chat",
 | 
			
		||||
		"meta/meta-llama-3-8b",
 | 
			
		||||
		"meta/meta-llama-3-8b-instruct":
 | 
			
		||||
		return 5
 | 
			
		||||
	case "meta/llama-2-70b",
 | 
			
		||||
		"meta/llama-2-70b-chat",
 | 
			
		||||
		"meta/meta-llama-3-70b",
 | 
			
		||||
		"meta/meta-llama-3-70b-instruct":
 | 
			
		||||
		return 2.750 / 0.650 // ≈4.230769
 | 
			
		||||
	case "meta/meta-llama-3.1-405b-instruct":
 | 
			
		||||
		return 1
 | 
			
		||||
	case "mistralai/mistral-7b-instruct-v0.2",
 | 
			
		||||
		"mistralai/mistral-7b-v0.1":
 | 
			
		||||
		return 5
 | 
			
		||||
	case "mistralai/mixtral-8x7b-instruct-v0.1":
 | 
			
		||||
		return 1.000 / 0.300 // ≈3.333333
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return 1
 | 
			
		||||
}
 | 
			
		||||
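
Reading the new Replicate entries together: ModelRatio prices prompt tokens, and GetCompletionRatio scales completion tokens by the output/input price gap (hence returns like 2.750 / 0.650 above). A minimal sketch of how the two ratios combine into a quota charge, assuming the usual prompt + completion * completionRatio rule; quotaFor and its arguments are illustrative names, not code from this compare:

package main

import "fmt"

// quotaFor sketches the billing rule implied by the ratio tables above:
// prompt tokens are charged at the model ratio, completion tokens at the
// model ratio times the completion ratio (output price / input price).
func quotaFor(promptTokens, completionTokens int, modelRatio, completionRatio, groupRatio float64) int64 {
	return int64((float64(promptTokens) + float64(completionTokens)*completionRatio) * modelRatio * groupRatio)
}

func main() {
	// Illustrative numbers: 1200 prompt tokens, 300 completion tokens, and the
	// meta/meta-llama-3-70b completion ratio of 2.750/0.650 from the case above.
	fmt.Println(quotaFor(1200, 300, 1.0, 2.750/0.650, 1.0)) // 2469
}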
 
@@ -46,5 +46,7 @@ const (
	VertextAI
	Proxy
	SiliconFlow
	XAI
	Replicate
	Dummy
)
 
@@ -37,6 +37,8 @@ func ToAPIType(channelType int) int {
		apiType = apitype.DeepL
	case VertextAI:
		apiType = apitype.VertexAI
	case Replicate:
		apiType = apitype.Replicate
	case Proxy:
		apiType = apitype.Proxy
	}
 
@@ -45,7 +45,9 @@ var ChannelBaseURLs = []string{
	"https://api.novita.ai/v3/openai",           // 41
	"",                                          // 42
	"",                                          // 43
	"https://api.siliconflow.cn",                 // 44
	"https://api.siliconflow.cn",                // 44
	"https://api.x.ai",                          // 45
	"https://api.replicate.com/v1/models/",      // 46
}

func init() {
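
One detail in the new entry for channel 46: the Replicate base URL keeps a trailing slash because the model name is appended onto it. A sketch of the resulting endpoint, assuming Replicate's models API layout (the adaptor that actually builds the URL is not part of this compare):

package main

import "fmt"

func main() {
	base := "https://api.replicate.com/v1/models/" // entry 46 above, trailing slash intact
	model := "black-forest-labs/flux-schnell"      // any model priced in the ratio table
	// Assumed target path; Replicate creates prediction tasks under the model path.
	fmt.Println(base + model + "/predictions")
}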
 
@@ -1,5 +1,6 @@
package role

const (
	System    = "system"
	Assistant = "assistant"
)
 
@@ -4,6 +4,7 @@ import (
	"context"
	"errors"
	"fmt"
	"github.com/songquanpeng/one-api/relay/constant/role"
	"math"
	"net/http"
	"strings"
@@ -90,7 +91,7 @@ func preConsumeQuota(ctx context.Context, textRequest *relaymodel.GeneralOpenAIR
	return preConsumedQuota, nil
}

func postConsumeQuota(ctx context.Context, usage *relaymodel.Usage, meta *meta.Meta, textRequest *relaymodel.GeneralOpenAIRequest, ratio float64, preConsumedQuota int64, modelRatio float64, groupRatio float64) {
func postConsumeQuota(ctx context.Context, usage *relaymodel.Usage, meta *meta.Meta, textRequest *relaymodel.GeneralOpenAIRequest, ratio float64, preConsumedQuota int64, modelRatio float64, groupRatio float64, systemPromptReset bool) {
	if usage == nil {
		logger.Error(ctx, "usage is nil, which is unexpected")
		return
@@ -118,7 +119,11 @@ func postConsumeQuota(ctx context.Context, usage *relaymodel.Usage, meta *meta.M
	if err != nil {
		logger.Error(ctx, "error update user quota cache: "+err.Error())
	}
	logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f,补全倍率 %.2f", modelRatio, groupRatio, completionRatio)
	var extraLog string
	if systemPromptReset {
		extraLog = " (注意系统提示词已被重置)"
	}
	logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f,补全倍率 %.2f%s", modelRatio, groupRatio, completionRatio, extraLog)
	model.RecordConsumeLog(ctx, meta.UserId, meta.ChannelId, promptTokens, completionTokens, textRequest.Model, meta.TokenName, quota, logContent)
	model.UpdateUserUsedQuotaAndRequestCount(meta.UserId, quota)
	model.UpdateChannelUsedQuota(meta.ChannelId, quota)
@@ -142,15 +147,41 @@ func isErrorHappened(meta *meta.Meta, resp *http.Response) bool {
		}
		return true
	}
	if resp.StatusCode != http.StatusOK {
	if resp.StatusCode != http.StatusOK &&
		// Replicate returns 201 when it creates a task
		resp.StatusCode != http.StatusCreated {
		return true
	}
	if meta.ChannelType == channeltype.DeepL {
		// skip stream check for DeepL
		return false
	}
	if meta.IsStream && strings.HasPrefix(resp.Header.Get("Content-Type"), "application/json") {

	if meta.IsStream && strings.HasPrefix(resp.Header.Get("Content-Type"), "application/json") &&
		// Even if stream mode is enabled, Replicate will first return task info in JSON format,
		// requiring the client to request the stream endpoint given in the task info
		meta.ChannelType != channeltype.Replicate {
		return true
	}
	return false
}

func setSystemPrompt(ctx context.Context, request *relaymodel.GeneralOpenAIRequest, prompt string) (reset bool) {
	if prompt == "" {
		return false
	}
	if len(request.Messages) == 0 {
		return false
	}
	if request.Messages[0].Role == role.System {
		request.Messages[0].Content = prompt
		logger.Infof(ctx, "rewrite system prompt")
		return true
	}
	request.Messages = append([]relaymodel.Message{{
		Role:    role.System,
		Content: prompt,
	}}, request.Messages...)
	logger.Infof(ctx, "add system prompt")
	return true
}
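
setSystemPrompt is small enough to pin down with an example. Below is a self-contained sketch of its two paths (rewrite an existing leading system message, otherwise prepend one); Message and Request are simplified stand-ins for the relaymodel types:

package main

import "fmt"

type Message struct {
	Role    string
	Content any
}

type Request struct{ Messages []Message }

// setSystemPrompt mirrors the function above: rewrite the leading system
// message if present, otherwise prepend one; report whether anything changed.
func setSystemPrompt(req *Request, prompt string) bool {
	if prompt == "" || len(req.Messages) == 0 {
		return false
	}
	if req.Messages[0].Role == "system" {
		req.Messages[0].Content = prompt // rewrite path: flagged in the billing log above
		return true
	}
	req.Messages = append([]Message{{Role: "system", Content: prompt}}, req.Messages...) // prepend path
	return true
}

func main() {
	r := &Request{Messages: []Message{{Role: "user", Content: "hi"}}}
	fmt.Println(setSystemPrompt(r, "You are a helpful assistant."), len(r.Messages)) // true 2
}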
 
@@ -22,7 +22,7 @@ import (
	relaymodel "github.com/songquanpeng/one-api/relay/model"
)

func getImageRequest(c *gin.Context, relayMode int) (*relaymodel.ImageRequest, error) {
func getImageRequest(c *gin.Context, _ int) (*relaymodel.ImageRequest, error) {
	imageRequest := &relaymodel.ImageRequest{}
	err := common.UnmarshalBodyReusable(c, imageRequest)
	if err != nil {
@@ -65,7 +65,7 @@ func getImageSizeRatio(model string, size string) float64 {
	return 1
}

func validateImageRequest(imageRequest *relaymodel.ImageRequest, meta *meta.Meta) *relaymodel.ErrorWithStatusCode {
func validateImageRequest(imageRequest *relaymodel.ImageRequest, _ *meta.Meta) *relaymodel.ErrorWithStatusCode {
	// check prompt length
	if imageRequest.Prompt == "" {
		return openai.ErrorWrapper(errors.New("prompt is required"), "prompt_missing", http.StatusBadRequest)
@@ -150,12 +150,12 @@ func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus
	}
	adaptor.Init(meta)

	// these adaptors need to convert the request
	switch meta.ChannelType {
	case channeltype.Ali:
		fallthrough
	case channeltype.Baidu:
		fallthrough
	case channeltype.Zhipu:
	case channeltype.Zhipu,
		channeltype.Ali,
		channeltype.Replicate,
		channeltype.Baidu:
		finalRequest, err := adaptor.ConvertImageRequest(imageRequest)
		if err != nil {
			return openai.ErrorWrapper(err, "convert_image_request_failed", http.StatusInternalServerError)
@@ -172,7 +172,14 @@ func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus
	ratio := modelRatio * groupRatio
	userQuota, err := model.CacheGetUserQuota(ctx, meta.UserId)

	quota := int64(ratio*imageCostRatio*1000) * int64(imageRequest.N)
	var quota int64
	switch meta.ChannelType {
	case channeltype.Replicate:
		// Replicate always returns 1 image
		quota = int64(ratio * imageCostRatio * 1000)
	default:
		quota = int64(ratio*imageCostRatio*1000) * int64(imageRequest.N)
	}

	if userQuota-quota < 0 {
		return openai.ErrorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden)
@@ -186,7 +193,9 @@ func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus
	}

	defer func(ctx context.Context) {
		if resp != nil && resp.StatusCode != http.StatusOK {
		if resp != nil &&
			resp.StatusCode != http.StatusCreated && // Replicate returns 201
			resp.StatusCode != http.StatusOK {
			return
		}
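
The quota switch above encodes Replicate's per-image billing: one task yields one image, so the requested n must not multiply the charge for that channel. The rule in isolation (imageQuota is an illustrative helper, not code from this compare):

package main

import "fmt"

// imageQuota mirrors the switch added in RelayImageHelper: Replicate bills per
// generated image and always returns one, so n only multiplies elsewhere.
func imageQuota(isReplicate bool, ratio, imageCostRatio float64, n int) int64 {
	if isReplicate {
		return int64(ratio * imageCostRatio * 1000)
	}
	return int64(ratio*imageCostRatio*1000) * int64(n)
}

func main() {
	fmt.Println(imageQuota(true, 1.0, 0.04, 4))  // 40: n is ignored
	fmt.Println(imageQuota(false, 1.0, 0.04, 4)) // 160: n multiplies
}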
 
@@ -4,6 +4,7 @@ import (
	"bytes"
	"encoding/json"
	"fmt"
	"github.com/songquanpeng/one-api/common/config"
	"io"
	"net/http"

@@ -35,6 +36,8 @@ func RelayTextHelper(c *gin.Context) *model.ErrorWithStatusCode {
	meta.OriginModelName = textRequest.Model
	textRequest.Model, _ = getMappedModelName(textRequest.Model, meta.ModelMapping)
	meta.ActualModelName = textRequest.Model
	// set system prompt if not empty
	systemPromptReset := setSystemPrompt(ctx, textRequest, meta.SystemPrompt)
	// get model ratio & group ratio
	modelRatio := billingratio.GetModelRatio(textRequest.Model, meta.ChannelType)
	groupRatio := billingratio.GetGroupRatio(meta.Group)
@@ -79,12 +82,12 @@ func RelayTextHelper(c *gin.Context) *model.ErrorWithStatusCode {
		return respErr
	}
	// post-consume quota
	go postConsumeQuota(ctx, usage, meta, textRequest, ratio, preConsumedQuota, modelRatio, groupRatio)
	go postConsumeQuota(ctx, usage, meta, textRequest, ratio, preConsumedQuota, modelRatio, groupRatio, systemPromptReset)
	return nil
}

func getRequestBody(c *gin.Context, meta *meta.Meta, textRequest *model.GeneralOpenAIRequest, adaptor adaptor.Adaptor) (io.Reader, error) {
	if meta.APIType == apitype.OpenAI && meta.OriginModelName == meta.ActualModelName && meta.ChannelType != channeltype.Baichuan {
	if !config.EnforceIncludeUsage && meta.APIType == apitype.OpenAI && meta.OriginModelName == meta.ActualModelName && meta.ChannelType != channeltype.Baichuan {
		// no need to convert request for openai
		return c.Request.Body, nil
	}
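
The new !config.EnforceIncludeUsage guard means that when usage reporting is enforced, even a plain OpenAI-to-OpenAI call is no longer passed through verbatim: the body is re-marshaled so stream_options can be injected. A sketch of the shape that conversion targets; StreamOptions copies the struct added in relay/model, while the injection site itself is assumed rather than shown in this compare:

package main

import (
	"encoding/json"
	"fmt"
)

// StreamOptions mirrors the struct added to relay/model in this release.
type StreamOptions struct {
	IncludeUsage bool `json:"include_usage,omitempty"`
}

type request struct {
	Stream        bool           `json:"stream,omitempty"`
	StreamOptions *StreamOptions `json:"stream_options,omitempty"`
}

func main() {
	req := request{Stream: true, StreamOptions: &StreamOptions{IncludeUsage: true}}
	b, _ := json.Marshal(req)
	fmt.Println(string(b)) // {"stream":true,"stream_options":{"include_usage":true}}
}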
 
@@ -30,6 +30,7 @@ type Meta struct {
	ActualModelName string
	RequestURLPath  string
	PromptTokens    int // only for DoResponse
	SystemPrompt    string
}

func GetByContext(c *gin.Context) *Meta {
@@ -46,6 +47,7 @@ func GetByContext(c *gin.Context) *Meta {
		BaseURL:         c.GetString(ctxkey.BaseURL),
		APIKey:          strings.TrimPrefix(c.Request.Header.Get("Authorization"), "Bearer "),
		RequestURLPath:  c.Request.URL.String(),
		SystemPrompt:    c.GetString(ctxkey.SystemPrompt),
	}
	cfg, ok := c.Get(ctxkey.Config)
	if ok {
 
@@ -1,6 +1,7 @@
package model

const (
	ContentTypeText     = "text"
	ContentTypeImageURL = "image_url"
	ContentTypeText       = "text"
	ContentTypeImageURL   = "image_url"
	ContentTypeInputAudio = "input_audio"
)
 
@@ -12,32 +12,59 @@ type JSONSchema struct {
	Strict      *bool                  `json:"strict,omitempty"`
}

type Audio struct {
	Voice  string `json:"voice,omitempty"`
	Format string `json:"format,omitempty"`
}

type StreamOptions struct {
	IncludeUsage bool `json:"include_usage,omitempty"`
}

type GeneralOpenAIRequest struct {
	Messages         []Message       `json:"messages,omitempty"`
	Model            string          `json:"model,omitempty"`
	FrequencyPenalty float64         `json:"frequency_penalty,omitempty"`
	MaxTokens        int             `json:"max_tokens,omitempty"`
	N                int             `json:"n,omitempty"`
	PresencePenalty  float64         `json:"presence_penalty,omitempty"`
	ResponseFormat   *ResponseFormat `json:"response_format,omitempty"`
	Seed             float64         `json:"seed,omitempty"`
	Stop             any             `json:"stop,omitempty"`
	Stream           bool            `json:"stream,omitempty"`
	Temperature      float64         `json:"temperature,omitempty"`
	TopP             float64         `json:"top_p,omitempty"`
	TopK             int             `json:"top_k,omitempty"`
	Tools            []Tool          `json:"tools,omitempty"`
	ToolChoice       any             `json:"tool_choice,omitempty"`
	FunctionCall     any             `json:"function_call,omitempty"`
	Functions        any             `json:"functions,omitempty"`
	User             string          `json:"user,omitempty"`
	Prompt           any             `json:"prompt,omitempty"`
	Input            any             `json:"input,omitempty"`
	EncodingFormat   string          `json:"encoding_format,omitempty"`
	Dimensions       int             `json:"dimensions,omitempty"`
	Instruction      string          `json:"instruction,omitempty"`
	Size             string          `json:"size,omitempty"`
	NumCtx           int             `json:"num_ctx,omitempty"`
	// https://platform.openai.com/docs/api-reference/chat/create
	Messages            []Message       `json:"messages,omitempty"`
	Model               string          `json:"model,omitempty"`
	Store               *bool           `json:"store,omitempty"`
	Metadata            any             `json:"metadata,omitempty"`
	FrequencyPenalty    *float64        `json:"frequency_penalty,omitempty"`
	LogitBias           any             `json:"logit_bias,omitempty"`
	Logprobs            *bool           `json:"logprobs,omitempty"`
	TopLogprobs         *int            `json:"top_logprobs,omitempty"`
	MaxTokens           int             `json:"max_tokens,omitempty"`
	MaxCompletionTokens *int            `json:"max_completion_tokens,omitempty"`
	N                   int             `json:"n,omitempty"`
	Modalities          []string        `json:"modalities,omitempty"`
	Prediction          any             `json:"prediction,omitempty"`
	Audio               *Audio          `json:"audio,omitempty"`
	PresencePenalty     *float64        `json:"presence_penalty,omitempty"`
	ResponseFormat      *ResponseFormat `json:"response_format,omitempty"`
	Seed                float64         `json:"seed,omitempty"`
	ServiceTier         *string         `json:"service_tier,omitempty"`
	Stop                any             `json:"stop,omitempty"`
	Stream              bool            `json:"stream,omitempty"`
	StreamOptions       *StreamOptions  `json:"stream_options,omitempty"`
	Temperature         *float64        `json:"temperature,omitempty"`
	TopP                *float64        `json:"top_p,omitempty"`
	TopK                int             `json:"top_k,omitempty"`
	Tools               []Tool          `json:"tools,omitempty"`
	ToolChoice          any             `json:"tool_choice,omitempty"`
	ParallelTooCalls    *bool           `json:"parallel_tool_calls,omitempty"`
	User                string          `json:"user,omitempty"`
	FunctionCall        any             `json:"function_call,omitempty"`
	Functions           any             `json:"functions,omitempty"`
	// https://platform.openai.com/docs/api-reference/embeddings/create
	Input          any    `json:"input,omitempty"`
	EncodingFormat string `json:"encoding_format,omitempty"`
	Dimensions     int    `json:"dimensions,omitempty"`
	// https://platform.openai.com/docs/api-reference/images/create
	Prompt  any     `json:"prompt,omitempty"`
	Quality *string `json:"quality,omitempty"`
	Size    string  `json:"size,omitempty"`
	Style   *string `json:"style,omitempty"`
	// Others
	Instruction string `json:"instruction,omitempty"`
	NumCtx      int    `json:"num_ctx,omitempty"`
}

func (r GeneralOpenAIRequest) ParseInput() []string {
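
For reference, a request body the widened GeneralOpenAIRequest can now carry, exercising the new audio-output fields; the struct below is a trimmed stand-in and the model name is illustrative:

package main

import (
	"encoding/json"
	"fmt"
)

type Audio struct {
	Voice  string `json:"voice,omitempty"`
	Format string `json:"format,omitempty"`
}

type StreamOptions struct {
	IncludeUsage bool `json:"include_usage,omitempty"`
}

// chatRequest trims GeneralOpenAIRequest down to the fields added above.
type chatRequest struct {
	Model         string         `json:"model,omitempty"`
	Modalities    []string       `json:"modalities,omitempty"`
	Audio         *Audio         `json:"audio,omitempty"`
	Stream        bool           `json:"stream,omitempty"`
	StreamOptions *StreamOptions `json:"stream_options,omitempty"`
}

func main() {
	req := chatRequest{
		Model:         "gpt-4o-audio-preview", // illustrative model name
		Modalities:    []string{"text", "audio"},
		Audio:         &Audio{Voice: "alloy", Format: "wav"},
		Stream:        true,
		StreamOptions: &StreamOptions{IncludeUsage: true},
	}
	b, _ := json.MarshalIndent(req, "", "  ")
	fmt.Println(string(b))
}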
 
@@ -9,6 +9,7 @@ import (

func SetRelayRouter(router *gin.Engine) {
	router.Use(middleware.CORS())
	router.Use(middleware.GzipDecodeMiddleware())
	// https://platform.openai.com/docs/api-reference/introduction
	modelsRouter := router.Group("/v1/models")
	modelsRouter.Use(middleware.TokenAuth())
 
@@ -395,7 +395,7 @@ const TokensTable = () => {
        url = mjLink + `/#/?settings={"key":"sk-${key}","url":"${serverAddress}"}`;
        break;
      case 'lobechat':
        url = chatLink + `/?settings={"keyVaults":{"openai":{"apiKey":"sk-${key}","baseURL":"${serverAddress}"/v1"}}}`;
        url = chatLink + `/?settings={"keyVaults":{"openai":{"apiKey":"sk-${key}","baseURL":"${serverAddress}/v1"}}}`;
        break;
      default:
        if (!chatLink) {
 
@@ -30,6 +30,8 @@ export const CHANNEL_OPTIONS = [
  { key: 42, text: 'VertexAI', value: 42, color: 'blue' },
  { key: 43, text: 'Proxy', value: 43, color: 'blue' },
  { key: 44, text: 'SiliconFlow', value: 44, color: 'blue' },
  { key: 45, text: 'xAI', value: 45, color: 'blue' },
  { key: 46, text: 'Replicate', value: 46, color: 'blue' },
  { key: 8, text: '自定义渠道', value: 8, color: 'pink' },
  { key: 22, text: '知识库:FastGPT', value: 22, color: 'blue' },
  { key: 21, text: '知识库:AI Proxy', value: 21, color: 'purple' },
 
@@ -43,6 +43,7 @@ const EditChannel = (props) => {
        base_url: '',
        other: '',
        model_mapping: '',
        system_prompt: '',
        models: [],
        auto_ban: 1,
        groups: ['default']
@@ -63,7 +64,7 @@ const EditChannel = (props) => {
            let localModels = [];
            switch (value) {
                case 14:
                    localModels = ["claude-instant-1.2", "claude-2", "claude-2.0", "claude-2.1", "claude-3-opus-20240229", "claude-3-sonnet-20240229", "claude-3-haiku-20240307", "claude-3-5-sonnet-20240620"];
                    localModels = ["claude-instant-1.2", "claude-2", "claude-2.0", "claude-2.1", "claude-3-opus-20240229", "claude-3-sonnet-20240229", "claude-3-haiku-20240307", "claude-3-5-haiku-20241022", "claude-3-5-sonnet-20240620", "claude-3-5-sonnet-20241022"];
                    break;
                case 11:
                    localModels = ['PaLM-2'];
@@ -78,7 +79,7 @@ const EditChannel = (props) => {
                    localModels = ['chatglm_pro', 'chatglm_std', 'chatglm_lite'];
                    break;
                case 18:
                    localModels = ['SparkDesk', 'SparkDesk-v1.1', 'SparkDesk-v2.1', 'SparkDesk-v3.1', 'SparkDesk-v3.5', 'SparkDesk-v4.0'];
                    localModels = ['SparkDesk', 'SparkDesk-v1.1', 'SparkDesk-v2.1', 'SparkDesk-v3.1', 'SparkDesk-v3.1-128K', 'SparkDesk-v3.5', 'SparkDesk-v3.5-32K', 'SparkDesk-v4.0'];
                    break;
                case 19:
                    localModels = ['360GPT_S2_V9', 'embedding-bert-512-v1', 'embedding_s1_v1', 'semantic_similarity_s1_v1'];
@@ -304,163 +305,163 @@ const EditChannel = (props) => {
                width={isMobile() ? '100%' : 600}
            >
                <Spin spinning={loading}>
                    <div style={{marginTop: 10}}>
                    <div style={{ marginTop: 10 }}>
                        <Typography.Text strong>类型:</Typography.Text>
                    </div>
                    <Select
                        name='type'
                        required
                        optionList={CHANNEL_OPTIONS}
                        value={inputs.type}
                        onChange={value => handleInputChange('type', value)}
                        style={{width: '50%'}}
                      name='type'
                      required
                      optionList={CHANNEL_OPTIONS}
                      value={inputs.type}
                      onChange={value => handleInputChange('type', value)}
                      style={{ width: '50%' }}
                    />
                    {
                        inputs.type === 3 && (
                            <>
                                <div style={{marginTop: 10}}>
                                    <Banner type={"warning"} description={
                                        <>
                                            注意,<strong>模型部署名称必须和模型名称保持一致</strong>,因为 One API 会把请求体中的
                                            model
                                            参数替换为你的部署名称(模型名称中的点会被剔除),<a target='_blank'
                                                                                              href='https://github.com/songquanpeng/one-api/issues/133?notification_referrer_id=NT_kwDOAmJSYrM2NjIwMzI3NDgyOjM5OTk4MDUw#issuecomment-1571602271'>图片演示</a>。
                                        </>
                                    }>
                                    </Banner>
                                </div>
                                <div style={{marginTop: 10}}>
                                    <Typography.Text strong>AZURE_OPENAI_ENDPOINT:</Typography.Text>
                                </div>
                                <Input
                                    label='AZURE_OPENAI_ENDPOINT'
                                    name='azure_base_url'
                                    placeholder={'请输入 AZURE_OPENAI_ENDPOINT,例如:https://docs-test-001.openai.azure.com'}
                                    onChange={value => {
                                        handleInputChange('base_url', value)
                                    }}
                                    value={inputs.base_url}
                                    autoComplete='new-password'
                                />
                                <div style={{marginTop: 10}}>
                                    <Typography.Text strong>默认 API 版本:</Typography.Text>
                                </div>
                                <Input
                                    label='默认 API 版本'
                                    name='azure_other'
                                    placeholder={'请输入默认 API 版本,例如:2024-03-01-preview,该配置可以被实际的请求查询参数所覆盖'}
                                    onChange={value => {
                                        handleInputChange('other', value)
                                    }}
                                    value={inputs.other}
                                    autoComplete='new-password'
                                />
                            </>
                        )
                      inputs.type === 3 && (
                        <>
                            <div style={{ marginTop: 10 }}>
                                <Banner type={"warning"} description={
                                    <>
                                        注意,<strong>模型部署名称必须和模型名称保持一致</strong>,因为 One API 会把请求体中的
                                        model
                                        参数替换为你的部署名称(模型名称中的点会被剔除),<a target='_blank'
                                                                                          href='https://github.com/songquanpeng/one-api/issues/133?notification_referrer_id=NT_kwDOAmJSYrM2NjIwMzI3NDgyOjM5OTk4MDUw#issuecomment-1571602271'>图片演示</a>。
                                    </>
                                }>
                                </Banner>
                            </div>
                            <div style={{ marginTop: 10 }}>
                                <Typography.Text strong>AZURE_OPENAI_ENDPOINT:</Typography.Text>
                            </div>
                            <Input
                              label='AZURE_OPENAI_ENDPOINT'
                              name='azure_base_url'
                              placeholder={'请输入 AZURE_OPENAI_ENDPOINT,例如:https://docs-test-001.openai.azure.com'}
                              onChange={value => {
                                  handleInputChange('base_url', value)
                              }}
                              value={inputs.base_url}
                              autoComplete='new-password'
                            />
                            <div style={{ marginTop: 10 }}>
                                <Typography.Text strong>默认 API 版本:</Typography.Text>
                            </div>
                            <Input
                              label='默认 API 版本'
                              name='azure_other'
                              placeholder={'请输入默认 API 版本,例如:2024-03-01-preview,该配置可以被实际的请求查询参数所覆盖'}
                              onChange={value => {
                                  handleInputChange('other', value)
                              }}
                              value={inputs.other}
                              autoComplete='new-password'
                            />
                        </>
                      )
                    }
                    {
                        inputs.type === 8 && (
                            <>
                                <div style={{marginTop: 10}}>
                                    <Typography.Text strong>Base URL:</Typography.Text>
                                </div>
                                <Input
                                    name='base_url'
                                    placeholder={'请输入自定义渠道的 Base URL'}
                                    onChange={value => {
                                        handleInputChange('base_url', value)
                                    }}
                                    value={inputs.base_url}
                                    autoComplete='new-password'
                                />
                            </>
                        )
                      inputs.type === 8 && (
                        <>
                            <div style={{ marginTop: 10 }}>
                                <Typography.Text strong>Base URL:</Typography.Text>
                            </div>
                            <Input
                              name='base_url'
                              placeholder={'请输入自定义渠道的 Base URL'}
                              onChange={value => {
                                  handleInputChange('base_url', value)
                              }}
                              value={inputs.base_url}
                              autoComplete='new-password'
                            />
                        </>
                      )
                    }
                    <div style={{marginTop: 10}}>
                    <div style={{ marginTop: 10 }}>
                        <Typography.Text strong>名称:</Typography.Text>
                    </div>
                    <Input
                        required
                        name='name'
                        placeholder={'请为渠道命名'}
                        onChange={value => {
                            handleInputChange('name', value)
                        }}
                        value={inputs.name}
                        autoComplete='new-password'
                      required
                      name='name'
                      placeholder={'请为渠道命名'}
                      onChange={value => {
                          handleInputChange('name', value)
                      }}
                      value={inputs.name}
                      autoComplete='new-password'
                    />
                    <div style={{marginTop: 10}}>
                    <div style={{ marginTop: 10 }}>
                        <Typography.Text strong>分组:</Typography.Text>
                    </div>
                    <Select
                        placeholder={'请选择可以使用该渠道的分组'}
                        name='groups'
                        required
                        multiple
                        selection
                        allowAdditions
                        additionLabel={'请在系统设置页面编辑分组倍率以添加新的分组:'}
                        onChange={value => {
                            handleInputChange('groups', value)
                        }}
                        value={inputs.groups}
                        autoComplete='new-password'
                        optionList={groupOptions}
                      placeholder={'请选择可以使用该渠道的分组'}
                      name='groups'
                      required
                      multiple
                      selection
                      allowAdditions
                      additionLabel={'请在系统设置页面编辑分组倍率以添加新的分组:'}
                      onChange={value => {
                          handleInputChange('groups', value)
                      }}
                      value={inputs.groups}
                      autoComplete='new-password'
                      optionList={groupOptions}
                    />
                    {
                        inputs.type === 18 && (
                            <>
                                <div style={{marginTop: 10}}>
                                    <Typography.Text strong>模型版本:</Typography.Text>
                                </div>
                                <Input
                                    name='other'
                                    placeholder={'请输入星火大模型版本,注意是接口地址中的版本号,例如:v2.1'}
                                    onChange={value => {
                                        handleInputChange('other', value)
                                    }}
                                    value={inputs.other}
                                    autoComplete='new-password'
                                />
                            </>
                        )
                      inputs.type === 18 && (
                        <>
                            <div style={{ marginTop: 10 }}>
                                <Typography.Text strong>模型版本:</Typography.Text>
                            </div>
                            <Input
                              name='other'
                              placeholder={'请输入星火大模型版本,注意是接口地址中的版本号,例如:v2.1'}
                              onChange={value => {
                                  handleInputChange('other', value)
                              }}
                              value={inputs.other}
                              autoComplete='new-password'
                            />
                        </>
                      )
                    }
                    {
                        inputs.type === 21 && (
                            <>
                                <div style={{marginTop: 10}}>
                                    <Typography.Text strong>知识库 ID:</Typography.Text>
                                </div>
                                <Input
                                    label='知识库 ID'
                                    name='other'
                                    placeholder={'请输入知识库 ID,例如:123456'}
                                    onChange={value => {
                                        handleInputChange('other', value)
                                    }}
                                    value={inputs.other}
                                    autoComplete='new-password'
                                />
                            </>
                        )
                      inputs.type === 21 && (
                        <>
                            <div style={{ marginTop: 10 }}>
                                <Typography.Text strong>知识库 ID:</Typography.Text>
                            </div>
                            <Input
                              label='知识库 ID'
                              name='other'
                              placeholder={'请输入知识库 ID,例如:123456'}
                              onChange={value => {
                                  handleInputChange('other', value)
                              }}
                              value={inputs.other}
                              autoComplete='new-password'
                            />
                        </>
                      )
                    }
                    <div style={{marginTop: 10}}>
                    <div style={{ marginTop: 10 }}>
                        <Typography.Text strong>模型:</Typography.Text>
                    </div>
                    <Select
                        placeholder={'请选择该渠道所支持的模型'}
                        name='models'
                        required
                        multiple
                        selection
                        onChange={value => {
                            handleInputChange('models', value)
                        }}
                        value={inputs.models}
                        autoComplete='new-password'
                        optionList={modelOptions}
                      placeholder={'请选择该渠道所支持的模型'}
                      name='models'
                      required
                      multiple
                      selection
                      onChange={value => {
                          handleInputChange('models', value)
                      }}
                      value={inputs.models}
                      autoComplete='new-password'
                      optionList={modelOptions}
                    />
                    <div style={{lineHeight: '40px', marginBottom: '12px'}}>
                    <div style={{ lineHeight: '40px', marginBottom: '12px' }}>
                        <Space>
                            <Button type='primary' onClick={() => {
                                handleInputChange('models', basicModels);
@@ -473,28 +474,41 @@ const EditChannel = (props) => {
                            }}>清除所有模型</Button>
                        </Space>
                        <Input
                            addonAfter={
                                <Button type='primary' onClick={addCustomModel}>填入</Button>
                            }
                            placeholder='输入自定义模型名称'
                            value={customModel}
                            onChange={(value) => {
                                setCustomModel(value.trim());
                            }}
                          addonAfter={
                              <Button type='primary' onClick={addCustomModel}>填入</Button>
                          }
                          placeholder='输入自定义模型名称'
                          value={customModel}
                          onChange={(value) => {
                              setCustomModel(value.trim());
                          }}
                        />
                    </div>
                    <div style={{marginTop: 10}}>
                    <div style={{ marginTop: 10 }}>
                        <Typography.Text strong>模型重定向:</Typography.Text>
                    </div>
                    <TextArea
                        placeholder={`此项可选,用于修改请求体中的模型名称,为一个 JSON 字符串,键为请求中模型名称,值为要替换的模型名称,例如:\n${JSON.stringify(MODEL_MAPPING_EXAMPLE, null, 2)}`}
                        name='model_mapping'
                        onChange={value => {
                            handleInputChange('model_mapping', value)
                        }}
                        autosize
                        value={inputs.model_mapping}
                        autoComplete='new-password'
                      placeholder={`此项可选,用于修改请求体中的模型名称,为一个 JSON 字符串,键为请求中模型名称,值为要替换的模型名称,例如:\n${JSON.stringify(MODEL_MAPPING_EXAMPLE, null, 2)}`}
                      name='model_mapping'
                      onChange={value => {
                          handleInputChange('model_mapping', value)
                      }}
                      autosize
                      value={inputs.model_mapping}
                      autoComplete='new-password'
                    />
                    <div style={{ marginTop: 10 }}>
                        <Typography.Text strong>系统提示词:</Typography.Text>
                    </div>
                    <TextArea
                      placeholder={`此项可选,用于强制设置给定的系统提示词,请配合自定义模型 & 模型重定向使用,首先创建一个唯一的自定义模型名称并在上面填入,之后将该自定义模型重定向映射到该渠道一个原生支持的模型`}
                      name='system_prompt'
                      onChange={value => {
                          handleInputChange('system_prompt', value)
                      }}
                      autosize
                      value={inputs.system_prompt}
                      autoComplete='new-password'
                    />
                    <Typography.Text style={{
                        color: 'rgba(var(--semi-blue-5), 1)',
@@ -507,116 +521,116 @@ const EditChannel = (props) => {
                    }>
                        填入模板
                    </Typography.Text>
                    <div style={{marginTop: 10}}>
                    <div style={{ marginTop: 10 }}>
                        <Typography.Text strong>密钥:</Typography.Text>
                    </div>
                    {
                        batch ?
                            <TextArea
                                label='密钥'
                                name='key'
                                required
                                placeholder={'请输入密钥,一行一个'}
                                onChange={value => {
                                    handleInputChange('key', value)
                                }}
                                value={inputs.key}
                                style={{minHeight: 150, fontFamily: 'JetBrains Mono, Consolas'}}
                                autoComplete='new-password'
                            />
                            :
                            <Input
                                label='密钥'
                                name='key'
                                required
                                placeholder={type2secretPrompt(inputs.type)}
                                onChange={value => {
                                    handleInputChange('key', value)
                                }}
                                value={inputs.key}
                                autoComplete='new-password'
                            />
                          <TextArea
                            label='密钥'
                            name='key'
                            required
                            placeholder={'请输入密钥,一行一个'}
                            onChange={value => {
                                handleInputChange('key', value)
                            }}
                            value={inputs.key}
                            style={{ minHeight: 150, fontFamily: 'JetBrains Mono, Consolas' }}
                            autoComplete='new-password'
                          />
                          :
                          <Input
                            label='密钥'
                            name='key'
                            required
                            placeholder={type2secretPrompt(inputs.type)}
                            onChange={value => {
                                handleInputChange('key', value)
                            }}
                            value={inputs.key}
                            autoComplete='new-password'
                          />
                    }
                    <div style={{marginTop: 10}}>
                    <div style={{ marginTop: 10 }}>
                        <Typography.Text strong>组织:</Typography.Text>
                    </div>
                    <Input
                        label='组织,可选,不填则为默认组织'
                        name='openai_organization'
                        placeholder='请输入组织org-xxx'
                        onChange={value => {
                            handleInputChange('openai_organization', value)
                        }}
                        value={inputs.openai_organization}
                      label='组织,可选,不填则为默认组织'
                      name='openai_organization'
                      placeholder='请输入组织org-xxx'
                      onChange={value => {
                          handleInputChange('openai_organization', value)
                      }}
                      value={inputs.openai_organization}
                    />
                    <div style={{marginTop: 10, display: 'flex'}}>
                    <div style={{ marginTop: 10, display: 'flex' }}>
                        <Space>
                            <Checkbox
                                name='auto_ban'
                                checked={autoBan}
                                onChange={
                                    () => {
                                        setAutoBan(!autoBan);
                                    }
                                }
                                // onChange={handleInputChange}
                              name='auto_ban'
                              checked={autoBan}
                              onChange={
                                  () => {
                                      setAutoBan(!autoBan);
                                  }
                              }
                              // onChange={handleInputChange}
                            />
                            <Typography.Text
                                strong>是否自动禁用(仅当自动禁用开启时有效),关闭后不会自动禁用该渠道:</Typography.Text>
                              strong>是否自动禁用(仅当自动禁用开启时有效),关闭后不会自动禁用该渠道:</Typography.Text>
                        </Space>
                    </div>

                    {
                        !isEdit && (
                            <div style={{marginTop: 10, display: 'flex'}}>
                                <Space>
                                    <Checkbox
                                        checked={batch}
                                        label='批量创建'
                                        name='batch'
                                        onChange={() => setBatch(!batch)}
                                    />
                                    <Typography.Text strong>批量创建</Typography.Text>
                                </Space>
                      !isEdit && (
                        <div style={{ marginTop: 10, display: 'flex' }}>
                            <Space>
                                <Checkbox
                                  checked={batch}
                                  label='批量创建'
                                  name='batch'
                                  onChange={() => setBatch(!batch)}
                                />
                                <Typography.Text strong>批量创建</Typography.Text>
                            </Space>
                        </div>
                      )
                    }
                    {
                      inputs.type !== 3 && inputs.type !== 8 && inputs.type !== 22 && (
                        <>
                            <div style={{ marginTop: 10 }}>
                                <Typography.Text strong>代理:</Typography.Text>
                            </div>
                        )
                            <Input
                              label='代理'
                              name='base_url'
                              placeholder={'此项可选,用于通过代理站来进行 API 调用'}
                              onChange={value => {
                                  handleInputChange('base_url', value)
                              }}
                              value={inputs.base_url}
                              autoComplete='new-password'
                            />
                        </>
                      )
                    }
                    {
                        inputs.type !== 3 && inputs.type !== 8 && inputs.type !== 22 && (
                            <>
                                <div style={{marginTop: 10}}>
                                    <Typography.Text strong>代理:</Typography.Text>
                                </div>
                                <Input
                                    label='代理'
                                    name='base_url'
                                    placeholder={'此项可选,用于通过代理站来进行 API 调用'}
                                    onChange={value => {
                                        handleInputChange('base_url', value)
                                    }}
                                    value={inputs.base_url}
                                    autoComplete='new-password'
                                />
                            </>
                        )
                    }
                    {
                        inputs.type === 22 && (
                            <>
                                <div style={{marginTop: 10}}>
                                    <Typography.Text strong>私有部署地址:</Typography.Text>
                                </div>
                                <Input
                                    name='base_url'
                                    placeholder={'请输入私有部署地址,格式为:https://fastgpt.run/api/openapi'}
                                    onChange={value => {
                                        handleInputChange('base_url', value)
                                    }}
                                    value={inputs.base_url}
                                    autoComplete='new-password'
 | 
			
		||||
                                />
 | 
			
		||||
                            </>
 | 
			
		||||
                        )
 | 
			
		||||
                      inputs.type === 22 && (
 | 
			
		||||
                        <>
 | 
			
		||||
                            <div style={{ marginTop: 10 }}>
 | 
			
		||||
                                <Typography.Text strong>私有部署地址:</Typography.Text>
 | 
			
		||||
                            </div>
 | 
			
		||||
                            <Input
 | 
			
		||||
                              name='base_url'
 | 
			
		||||
                              placeholder={'请输入私有部署地址,格式为:https://fastgpt.run/api/openapi'}
 | 
			
		||||
                              onChange={value => {
 | 
			
		||||
                                  handleInputChange('base_url', value)
 | 
			
		||||
                              }}
 | 
			
		||||
                              value={inputs.base_url}
 | 
			
		||||
                              autoComplete='new-password'
 | 
			
		||||
                            />
 | 
			
		||||
                        </>
 | 
			
		||||
                      )
 | 
			
		||||
                    }
 | 
			
		||||
 | 
			
		||||
                </Spin>
 | 
			
		||||
 
 | 
			
		||||
@@ -179,6 +179,18 @@ export const CHANNEL_OPTIONS = {
    value: 44,
    color: 'primary'
  },
  45: {
    key: 45,
    text: 'xAI',
    value: 45,
    color: 'primary'
  },
  46: {
    key: 46,
    text: 'Replicate',
    value: 46,
    color: 'primary'
  },
  41: {
    key: 41,
    text: 'Novita',

@@ -95,7 +95,7 @@ export async function onLarkOAuthClicked(lark_client_id) {
  const state = await getOAuthState();
  if (!state) return;
  let redirect_uri = `${window.location.origin}/oauth/lark`;
  window.open(`https://open.feishu.cn/open-apis/authen/v1/index?redirect_uri=${redirect_uri}&app_id=${lark_client_id}&state=${state}`);
  window.open(`https://accounts.feishu.cn/open-apis/authen/v1/authorize?redirect_uri=${redirect_uri}&client_id=${lark_client_id}&state=${state}`);
}

export async function onOidcClicked(auth_url, client_id, openInNewTab = false) {

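Editor's note: the Lark login change swaps both the endpoint (open.feishu.cn/...​/v1/index to accounts.feishu.cn/...​/v1/authorize) and the query parameter name (app_id to client_id). A minimal sketch of the URL this now produces, with placeholder IDs, and URLSearchParams doing the encoding the raw template literal leaves to the browser:

// Sketch only. `cli_xxx` and the origin are placeholder values; the endpoint
// and parameter names are taken from the diff above.
function buildLarkAuthorizeUrl(origin, larkClientId, state) {
  const redirectUri = `${origin}/oauth/lark`;
  const params = new URLSearchParams({
    redirect_uri: redirectUri,
    client_id: larkClientId, // was sent as `app_id` to the old open.feishu.cn endpoint
    state,
  });
  return `https://accounts.feishu.cn/open-apis/authen/v1/authorize?${params.toString()}`;
}

console.log(buildLarkAuthorizeUrl('https://example.com', 'cli_xxx', 'state123'));
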
@@ -595,6 +595,28 @@ const EditModal = ({ open, channelId, onCancel, onOk }) => {
                  <FormHelperText id="helper-tex-channel-model_mapping-label"> {inputPrompt.model_mapping} </FormHelperText>
                )}
              </FormControl>
              <FormControl fullWidth error={Boolean(touched.system_prompt && errors.system_prompt)} sx={{ ...theme.typography.otherInput }}>
                {/* <InputLabel htmlFor="channel-model_mapping-label">{inputLabel.model_mapping}</InputLabel> */}
                <TextField
                  multiline
                  id="channel-system_prompt-label"
                  label={inputLabel.system_prompt}
                  value={values.system_prompt}
                  name="system_prompt"
                  onBlur={handleBlur}
                  onChange={handleChange}
                  aria-describedby="helper-text-channel-system_prompt-label"
                  minRows={5}
                  placeholder={inputPrompt.system_prompt}
                />
                {touched.system_prompt && errors.system_prompt ? (
                  <FormHelperText error id="helper-tex-channel-system_prompt-label">
                    {errors.system_prompt}
                  </FormHelperText>
                ) : (
                  <FormHelperText id="helper-tex-channel-system_prompt-label"> {inputPrompt.system_prompt} </FormHelperText>
                )}
              </FormControl>
              <DialogActions>
                <Button onClick={onCancel}>取消</Button>
                <Button disableElevation disabled={isSubmitting} type="submit" variant="contained" color="primary">

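Editor's note: the helper-text branching for system_prompt mirrors the model_mapping block above it, in what looks like Formik-managed state: show the validation error only once the field has been touched, otherwise show the default prompt. A condensed sketch of that decision (object shapes are assumed, not from the repository):

// Assumed shapes: `touched` and `errors` keyed by field name, Formik-style;
// `prompts` holds the default helper strings (inputPrompt above).
function helperTextFor(field, touched, errors, prompts) {
  return touched[field] && errors[field] ? errors[field] : prompts[field];
}

const touched = { system_prompt: true };
const errors = { system_prompt: '系统提示词格式不正确' };
const prompts = { system_prompt: '系统提示词' };

console.log(helperTextFor('system_prompt', touched, errors, prompts));
// -> 系统提示词格式不正确 (touched and invalid: show the error)
console.log(helperTextFor('system_prompt', {}, errors, prompts));
// -> 系统提示词 (untouched: show the default prompt)
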
@@ -268,6 +268,8 @@ function renderBalance(type, balance) {
      return <span>¥{balance.toFixed(2)}</span>;
    case 13: // AIGC2D
      return <span>{renderNumber(balance)}</span>;
    case 44: // SiliconFlow
      return <span>¥{balance.toFixed(2)}</span>;
    default:
      return <span>不支持</span>;
  }

@@ -18,6 +18,7 @@ const defaultConfig = {
    other: '其他参数',
    models: '模型',
    model_mapping: '模型映射关系',
    system_prompt: '系统提示词',
    groups: '用户组',
    config: null
  },
@@ -30,6 +31,7 @@ const defaultConfig = {
    models: '请选择该渠道所支持的模型',
    model_mapping:
      '请输入要修改的模型映射关系,格式为:api请求模型ID:实际转发给渠道的模型ID,使用JSON数组表示,例如:{"gpt-3.5": "gpt-35"}',
    system_prompt: '此项可选,用于强制设置给定的系统提示词,请配合自定义模型 & 模型重定向使用,首先创建一个唯一的自定义模型名称并在上面填入,之后将该自定义模型重定向映射到该渠道一个原生支持的模型',
    groups: '请选择该渠道所支持的用户组',
    config: null
  },
@@ -91,7 +93,7 @@ const typeConfig = {
      other: '版本号'
    },
    input: {
      models: ['SparkDesk', 'SparkDesk-v1.1', 'SparkDesk-v2.1', 'SparkDesk-v3.1', 'SparkDesk-v3.5', 'SparkDesk-v4.0']
      models: ['SparkDesk', 'SparkDesk-v1.1', 'SparkDesk-v2.1', 'SparkDesk-v3.1', 'SparkDesk-v3.1-128K', 'SparkDesk-v3.5', 'SparkDesk-v3.5-32K', 'SparkDesk-v4.0']
    },
    prompt: {
      key: '按照如下格式输入:APPID|APISecret|APIKey',
@@ -223,6 +225,9 @@ const typeConfig = {
    },
    modelGroup: 'anthropic'
  },
  45: {
    modelGroup: 'xai'
  },
};

export { defaultConfig, typeConfig };

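Editor's note: the model_mapping prompt packs the format into one sentence; the idea is a plain lookup from the model ID a caller requests to the model ID actually forwarded to the channel. A hypothetical sketch of that reading (the helper name is illustrative, not a one-api function):

// Illustrative reading of a model_mapping value such as {"gpt-3.5": "gpt-35"}:
// keys are caller-facing model IDs, values are what the channel receives.
const modelMapping = JSON.parse('{"gpt-3.5": "gpt-35"}');

function resolveUpstreamModel(requestedModel, mapping) {
  return mapping[requestedModel] ?? requestedModel; // unmapped IDs pass through unchanged
}

console.log(resolveUpstreamModel('gpt-3.5', modelMapping)); // "gpt-35"
console.log(resolveUpstreamModel('gpt-4', modelMapping));   // "gpt-4"
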
@@ -33,7 +33,7 @@ const COPY_OPTIONS = [
  },
  { key: 'ama', text: 'BotGem', url: 'ama://set-api-key?server={serverAddress}&key=sk-{key}', encode: true },
  { key: 'opencat', text: 'OpenCat', url: 'opencat://team/join?domain={serverAddress}&token=sk-{key}', encode: true },
  { key: 'lobechat', text: 'LobeChat', url: 'https://lobehub.com/?settings={"keyVaults":{"openai":{"apiKey":"user-key","baseURL":"https://your-proxy.com/v1"}}}', encode: true }
  { key: 'lobechat', text: 'LobeChat', url: 'https://lobehub.com/?settings={"keyVaults":{"openai":{"apiKey":"sk-{key}","baseURL":"{serverAddress}"}}}', encode: true }
];

function replacePlaceholders(text, key, serverAddress) {

0 web/build.sh (Normal file → Executable file)

@@ -52,11 +52,19 @@ function renderBalance(type, balance) {
      return <span>¥{balance.toFixed(2)}</span>;
    case 13: // AIGC2D
      return <span>{renderNumber(balance)}</span>;
    case 44: // SiliconFlow
      return <span>¥{balance.toFixed(2)}</span>;
    default:
      return <span>不支持</span>;
  }
}

function isShowDetail() {
  return localStorage.getItem("show_detail") === "true";
}

const promptID = "detail"

const ChannelsTable = () => {
  const [channels, setChannels] = useState([]);
  const [loading, setLoading] = useState(true);
@@ -64,7 +72,8 @@ const ChannelsTable = () => {
  const [searchKeyword, setSearchKeyword] = useState('');
  const [searching, setSearching] = useState(false);
  const [updatingBalance, setUpdatingBalance] = useState(false);
  const [showPrompt, setShowPrompt] = useState(shouldShowPrompt("channel-test"));
  const [showPrompt, setShowPrompt] = useState(shouldShowPrompt(promptID));
  const [showDetail, setShowDetail] = useState(isShowDetail());

  const loadChannels = async (startIdx) => {
    const res = await API.get(`/api/channel/?p=${startIdx}`);
@@ -118,6 +127,11 @@ const ChannelsTable = () => {
    await loadChannels(activePage - 1);
  };

  const toggleShowDetail = () => {
    setShowDetail(!showDetail);
    localStorage.setItem("show_detail", (!showDetail).toString());
  }

  useEffect(() => {
    loadChannels(0)
      .then()
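Editor's note: isShowDetail and toggleShowDetail together form a small persisted-boolean pattern. localStorage only stores strings, hence the explicit .toString() on write and the === "true" comparison on read (so a missing key defaults to false). The same pattern in isolation, as a browser-only sketch with the key name taken from the diff:

// Persisted boolean flag, mirroring isShowDetail/toggleShowDetail above.
function readFlag(key) {
  return localStorage.getItem(key) === 'true';
}

function writeFlag(key, value) {
  localStorage.setItem(key, value.toString());
}

writeFlag('show_detail', true);
console.log(readFlag('show_detail')); // true
console.log(readFlag('never_set'));   // false (getItem returns null)
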
@@ -362,11 +376,13 @@ const ChannelsTable = () => {
        showPrompt && (
          <Message onDismiss={() => {
            setShowPrompt(false);
            setPromptShown("channel-test");
            setPromptShown(promptID);
          }}>
            OpenAI 渠道已经不再支持通过 key 获取余额,因此余额显示为 0。对于支持的渠道类型,请点击余额进行刷新。
            <br/>
            渠道测试仅支持 chat 模型,优先使用 gpt-3.5-turbo,如果该模型不可用则使用你所配置的模型列表中的第一个模型。
            <br/>
            点击下方详情按钮可以显示余额以及设置额外的测试模型。
          </Message>
        )
      }
@@ -426,6 +442,7 @@ const ChannelsTable = () => {
              onClick={() => {
                sortChannel('balance');
              }}
              hidden={!showDetail}
            >
              余额
            </Table.HeaderCell>
@@ -437,7 +454,7 @@ const ChannelsTable = () => {
            >
              优先级
            </Table.HeaderCell>
            <Table.HeaderCell>测试模型</Table.HeaderCell>
            <Table.HeaderCell hidden={!showDetail}>测试模型</Table.HeaderCell>
            <Table.HeaderCell>操作</Table.HeaderCell>
          </Table.Row>
        </Table.Header>
@@ -465,7 +482,7 @@ const ChannelsTable = () => {
                      basic
                    />
                  </Table.Cell>
                  <Table.Cell>
                  <Table.Cell hidden={!showDetail}>
                    <Popup
                      trigger={<span onClick={() => {
                        updateChannelBalance(channel.id, channel.name, idx);
@@ -492,7 +509,7 @@ const ChannelsTable = () => {
                      basic
                    />
                  </Table.Cell>
                  <Table.Cell>
                  <Table.Cell hidden={!showDetail}>
                    <Dropdown
                      placeholder='请选择测试模型'
                      selection
@@ -571,7 +588,7 @@ const ChannelsTable = () => {

        <Table.Footer>
          <Table.Row>
            <Table.HeaderCell colSpan='9'>
            <Table.HeaderCell colSpan={showDetail ? "10" : "8"}>
              <Button size='small' as={Link} to='/channel/add' loading={loading}>
                添加新的渠道
              </Button>
@@ -609,6 +626,7 @@ const ChannelsTable = () => {
                }
              />
              <Button size='small' onClick={refresh} loading={loading}>刷新</Button>
              <Button size='small' onClick={toggleShowDetail}>{showDetail ? "隐藏详情" : "详情"}</Button>
            </Table.HeaderCell>
          </Table.Row>
        </Table.Footer>

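Editor's note: the footer cell spans the full table width, so its colSpan has to track the two detail-only columns (余额 and 测试模型) that the showDetail flag hides. Reading the counts off the conditional in the diff, this appears to be 8 always-visible columns plus 2 detail columns:

// Footer span as implied by the conditional above (counts are inferred).
const footerColSpan = (showDetail) => String(8 + (showDetail ? 2 : 0));

console.log(footerColSpan(true));  // "10"
console.log(footerColSpan(false)); // "8"
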
@@ -117,7 +117,7 @@ const TokensTable = () => {
        url = nextUrl;
        break;
      case 'lobechat':
        url = nextLink + `/?settings={"keyVaults":{"openai":{"apiKey":"sk-${key}","baseURL":"${serverAddress}"/v1"}}}`;
        url = nextLink + `/?settings={"keyVaults":{"openai":{"apiKey":"sk-${key}","baseURL":"${serverAddress}/v1"}}}`;
        break;
      default:
        url = `sk-${key}`;
@@ -160,7 +160,7 @@ const TokensTable = () => {
        break;

      case 'lobechat':
        url = chatLink + `/?settings={"keyVaults":{"openai":{"apiKey":"sk-${key}","baseURL":"${serverAddress}"/v1"}}}`;
        url = chatLink + `/?settings={"keyVaults":{"openai":{"apiKey":"sk-${key}","baseURL":"${serverAddress}/v1"}}}`;
        break;

      default:

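Editor's note: the one-character fix above matters because the settings value is evidently meant to be parsed as JSON by the receiving app; with the stray quote after ${serverAddress}, the payload is not valid JSON at all. A quick demonstration with placeholder values:

// Before/after the stray-quote fix; serverAddress and key are placeholders.
const serverAddress = 'https://example.com';
const key = 'abc123';

const broken = `{"keyVaults":{"openai":{"apiKey":"sk-${key}","baseURL":"${serverAddress}"/v1"}}}`;
const fixed = `{"keyVaults":{"openai":{"apiKey":"sk-${key}","baseURL":"${serverAddress}/v1"}}}`;

try {
  JSON.parse(broken);
} catch (e) {
  console.log('broken payload rejected:', e.constructor.name); // SyntaxError
}
console.log('fixed baseURL:', JSON.parse(fixed).keyVaults.openai.baseURL);
// fixed baseURL: https://example.com/v1
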
@@ -30,6 +30,8 @@ export const CHANNEL_OPTIONS = [
    { key: 42, text: 'VertexAI', value: 42, color: 'blue' },
    { key: 43, text: 'Proxy', value: 43, color: 'blue' },
    { key: 44, text: 'SiliconFlow', value: 44, color: 'blue' },
    { key: 45, text: 'xAI', value: 45, color: 'blue' },
    { key: 46, text: 'Replicate', value: 46, color: 'blue' },
    { key: 8, text: '自定义渠道', value: 8, color: 'pink' },
    { key: 22, text: '知识库:FastGPT', value: 22, color: 'blue' },
    { key: 21, text: '知识库:AI Proxy', value: 21, color: 'purple' },

@@ -43,6 +43,7 @@ const EditChannel = () => {
    base_url: '',
    other: '',
    model_mapping: '',
    system_prompt: '',
    models: [],
    groups: ['default']
  };
@@ -425,7 +426,7 @@ const EditChannel = () => {
            )
          }
          {
          inputs.type !== 43 && (
          inputs.type !== 43 && (<>
              <Form.Field>
                <Form.TextArea
                  label='模型重定向'
@@ -437,6 +438,18 @@ const EditChannel = () => {
                  autoComplete='new-password'
                />
              </Form.Field>
            <Form.Field>
                <Form.TextArea
                  label='系统提示词'
                  placeholder={`此项可选,用于强制设置给定的系统提示词,请配合自定义模型 & 模型重定向使用,首先创建一个唯一的自定义模型名称并在上面填入,之后将该自定义模型重定向映射到该渠道一个原生支持的模型`}
                  name='system_prompt'
                  onChange={handleInputChange}
                  value={inputs.system_prompt}
                  style={{ minHeight: 150, fontFamily: 'JetBrains Mono, Consolas' }}
                  autoComplete='new-password'
                />
              </Form.Field>
              </>
            )
          }
          {

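Editor's note: the 系统提示词 placeholder describes the intended workflow: invent a unique custom model name, redirect it via 模型重定向 to a model the channel natively supports, and let the channel force the system prompt on requests for it. A hypothetical sketch of that rewrite on the request path (every identifier here is illustrative, none of this is repository code):

// Hypothetical request rewrite combining model_mapping with the new
// system_prompt, per the placeholder text. Names are made up.
const channel = {
  model_mapping: { 'my-gpt4-with-persona': 'gpt-4' },
  system_prompt: 'You are a terse assistant.',
};

function rewriteRequest(req, ch) {
  const model = ch.model_mapping[req.model] ?? req.model; // redirect the custom name
  const messages = ch.system_prompt
    ? [{ role: 'system', content: ch.system_prompt },
       ...req.messages.filter((m) => m.role !== 'system')] // forced prompt wins
    : req.messages;
  return { ...req, model, messages };
}

console.log(JSON.stringify(rewriteRequest(
  { model: 'my-gpt4-with-persona', messages: [{ role: 'user', content: 'hi' }] },
  channel,
), null, 2));
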
@@ -2,7 +2,7 @@ import React from 'react';
import { Header, Segment } from 'semantic-ui-react';
import ChannelsTable from '../../components/ChannelsTable';

const File = () => (
const Channel = () => (
  <>
    <Segment>
      <Header as='h3'>管理渠道</Header>
@@ -11,4 +11,4 @@ const File = () => (
  </>
);

export default File;
export default Channel;
 