mirror of
				https://github.com/songquanpeng/one-api.git
				synced 2025-10-27 03:43:43 +08:00 
			
		
		
		
	Compare commits
	
		
			25 Commits
		
	
	
		
			v0.6.10-al
			...
			v0.6.10
		
	
	| Author | SHA1 | Date | |
|---|---|---|---|
|  | 3915ce9814 | ||
|  | 999defc88b | ||
|  | b51c47bc77 | ||
|  | 4f25cde132 | ||
|  | d89e9d7e44 | ||
|  | a858292b54 | ||
|  | ff589b5e4a | ||
|  | 95e8c16338 | ||
|  | 381172cb36 | ||
|  | 59eae186a3 | ||
|  | ce52f355bb | ||
|  | cb9d0a74c9 | ||
|  | 49ffb1c60d | ||
|  | 2f16649896 | ||
|  | af3aa57bd6 | ||
|  | e9f117ff72 | ||
|  | 6bb5247bd6 | ||
|  | 305ce14fe3 | ||
|  | 36c8f4f15c | ||
|  | 45b51ea0ee | ||
|  | 7c8628bd95 | ||
|  | 6ab87f8a08 | ||
|  | 833fa7ad6f | ||
|  | 6eb0770a89 | ||
|  | 92cd46d64f | 
							
								
								
									
										10
									
								
								.github/workflows/ci.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										10
									
								
								.github/workflows/ci.yml
									
									
									
									
										vendored
									
									
								
							| @@ -1,19 +1,17 @@ | |||||||
| name: CI | name: CI | ||||||
|  |  | ||||||
| # This setup assumes that you run the unit tests with code coverage in the same | # This setup assumes that you run the unit tests with code coverage in the same | ||||||
| # workflow that will also print the coverage report as comment to the pull request.  | # workflow that will also print the coverage report as comment to the pull request. | ||||||
| # Therefore, you need to trigger this workflow when a pull request is (re)opened or | # Therefore, you need to trigger this workflow when a pull request is (re)opened or | ||||||
| # when new code is pushed to the branch of the pull request. In addition, you also | # when new code is pushed to the branch of the pull request. In addition, you also | ||||||
| # need to trigger this workflow when new code is pushed to the main branch because  | # need to trigger this workflow when new code is pushed to the main branch because | ||||||
| # we need to upload the code coverage results as artifact for the main branch as | # we need to upload the code coverage results as artifact for the main branch as | ||||||
| # well since it will be the baseline code coverage. | # well since it will be the baseline code coverage. | ||||||
| #  | # | ||||||
| # We do not want to trigger the workflow for pushes to *any* branch because this | # We do not want to trigger the workflow for pushes to *any* branch because this | ||||||
| # would trigger our jobs twice on pull requests (once from "push" event and once | # would trigger our jobs twice on pull requests (once from "push" event and once | ||||||
| # from "pull_request->synchronize") | # from "pull_request->synchronize") | ||||||
| on: | on: | ||||||
|   pull_request: |  | ||||||
|     types: [opened, reopened, synchronize] |  | ||||||
|   push: |   push: | ||||||
|     branches: |     branches: | ||||||
|       - 'main' |       - 'main' | ||||||
| @@ -31,7 +29,7 @@ jobs: | |||||||
|         with: |         with: | ||||||
|           go-version: ^1.22 |           go-version: ^1.22 | ||||||
|  |  | ||||||
|       # When you execute your unit tests, make sure to use the "-coverprofile" flag to write a  |       # When you execute your unit tests, make sure to use the "-coverprofile" flag to write a | ||||||
|       # coverage profile to a file. You will need the name of the file (e.g. "coverage.txt") |       # coverage profile to a file. You will need the name of the file (e.g. "coverage.txt") | ||||||
|       # in the next step as well as the next job. |       # in the next step as well as the next job. | ||||||
|       - name: Test |       - name: Test | ||||||
|   | |||||||
							
								
								
									
										3
									
								
								.gitignore
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										3
									
								
								.gitignore
									
									
									
									
										vendored
									
									
								
							| @@ -9,4 +9,5 @@ logs | |||||||
| data | data | ||||||
| /web/node_modules | /web/node_modules | ||||||
| cmd.md | cmd.md | ||||||
| .env | .env | ||||||
|  | /one-api | ||||||
|   | |||||||
							
								
								
									
										16
									
								
								README.md
									
									
									
									
									
								
							
							
						
						
									
										16
									
								
								README.md
									
									
									
									
									
								
							| @@ -115,8 +115,8 @@ _✨ 通过标准的 OpenAI API 格式访问所有的大模型,开箱即用  | |||||||
| 21. 支持 Cloudflare Turnstile 用户校验。 | 21. 支持 Cloudflare Turnstile 用户校验。 | ||||||
| 22. 支持用户管理,支持**多种用户登录注册方式**: | 22. 支持用户管理,支持**多种用户登录注册方式**: | ||||||
|     + 邮箱登录注册(支持注册邮箱白名单)以及通过邮箱进行密码重置。 |     + 邮箱登录注册(支持注册邮箱白名单)以及通过邮箱进行密码重置。 | ||||||
|     + 支持使用飞书进行授权登录。 |     + 支持[飞书授权登录](https://open.feishu.cn/document/uAjLw4CM/ukTMukTMukTM/reference/authen-v1/authorize/get)([这里有 One API 的实现细节阐述供参考](https://iamazing.cn/page/feishu-oauth-login))。 | ||||||
|     + [GitHub 开放授权](https://github.com/settings/applications/new)。 |     + 支持 [GitHub 授权登录](https://github.com/settings/applications/new)。 | ||||||
|     + 微信公众号授权(需要额外部署 [WeChat Server](https://github.com/songquanpeng/wechat-server))。 |     + 微信公众号授权(需要额外部署 [WeChat Server](https://github.com/songquanpeng/wechat-server))。 | ||||||
| 23. 支持主题切换,设置环境变量 `THEME` 即可,默认为 `default`,欢迎 PR 更多主题,具体参考[此处](./web/README.md)。 | 23. 支持主题切换,设置环境变量 `THEME` 即可,默认为 `default`,欢迎 PR 更多主题,具体参考[此处](./web/README.md)。 | ||||||
| 24. 配合 [Message Pusher](https://github.com/songquanpeng/message-pusher) 可将报警信息推送到多种 App 上。 | 24. 配合 [Message Pusher](https://github.com/songquanpeng/message-pusher) 可将报警信息推送到多种 App 上。 | ||||||
| @@ -175,6 +175,10 @@ sudo service nginx restart | |||||||
|  |  | ||||||
| 初始账号用户名为 `root`,密码为 `123456`。 | 初始账号用户名为 `root`,密码为 `123456`。 | ||||||
|  |  | ||||||
|  | ### 通过宝塔面板进行一键部署 | ||||||
|  | 1. 安装宝塔面板9.2.0及以上版本,前往 [宝塔面板](https://www.bt.cn/new/download.html?r=dk_oneapi) 官网,选择正式版的脚本下载安装; | ||||||
|  | 2. 安装后登录宝塔面板,在左侧菜单栏中点击 `Docker`,首次进入会提示安装 `Docker` 服务,点击立即安装,按提示完成安装; | ||||||
|  | 3. 安装完成后在应用商店中搜索 `One-API`,点击安装,配置域名等基本信息即可完成安装; | ||||||
|  |  | ||||||
| ### 基于 Docker Compose 进行部署 | ### 基于 Docker Compose 进行部署 | ||||||
|  |  | ||||||
| @@ -218,7 +222,7 @@ docker-compose ps | |||||||
| 3. 所有从服务器必须设置 `NODE_TYPE` 为 `slave`,不设置则默认为主服务器。 | 3. 所有从服务器必须设置 `NODE_TYPE` 为 `slave`,不设置则默认为主服务器。 | ||||||
| 4. 设置 `SYNC_FREQUENCY` 后服务器将定期从数据库同步配置,在使用远程数据库的情况下,推荐设置该项并启用 Redis,无论主从。 | 4. 设置 `SYNC_FREQUENCY` 后服务器将定期从数据库同步配置,在使用远程数据库的情况下,推荐设置该项并启用 Redis,无论主从。 | ||||||
| 5. 从服务器可以选择设置 `FRONTEND_BASE_URL`,以重定向页面请求到主服务器。 | 5. 从服务器可以选择设置 `FRONTEND_BASE_URL`,以重定向页面请求到主服务器。 | ||||||
| 6. 从服务器上**分别**装好 Redis,设置好 `REDIS_CONN_STRING`,这样可以做到在缓存未过期的情况下数据库零访问,可以减少延迟。 | 6. 从服务器上**分别**装好 Redis,设置好 `REDIS_CONN_STRING`,这样可以做到在缓存未过期的情况下数据库零访问,可以减少延迟(Redis 集群或者哨兵模式的支持请参考环境变量说明)。 | ||||||
| 7. 如果主服务器访问数据库延迟也比较高,则也需要启用 Redis,并设置 `SYNC_FREQUENCY`,以定期从数据库同步配置。 | 7. 如果主服务器访问数据库延迟也比较高,则也需要启用 Redis,并设置 `SYNC_FREQUENCY`,以定期从数据库同步配置。 | ||||||
|  |  | ||||||
| 环境变量的具体使用方法详见[此处](#环境变量)。 | 环境变量的具体使用方法详见[此处](#环境变量)。 | ||||||
| @@ -347,6 +351,11 @@ graph LR | |||||||
| 1. `REDIS_CONN_STRING`:设置之后将使用 Redis 作为缓存使用。 | 1. `REDIS_CONN_STRING`:设置之后将使用 Redis 作为缓存使用。 | ||||||
|    + 例子:`REDIS_CONN_STRING=redis://default:redispw@localhost:49153` |    + 例子:`REDIS_CONN_STRING=redis://default:redispw@localhost:49153` | ||||||
|    + 如果数据库访问延迟很低,没有必要启用 Redis,启用后反而会出现数据滞后的问题。 |    + 如果数据库访问延迟很低,没有必要启用 Redis,启用后反而会出现数据滞后的问题。 | ||||||
|  |    + 如果需要使用哨兵或者集群模式: | ||||||
|  |      + 则需要把该环境变量设置为节点列表,例如:`localhost:49153,localhost:49154,localhost:49155`。 | ||||||
|  |      + 除此之外还需要设置以下环境变量: | ||||||
|  |        + `REDIS_PASSWORD`:Redis 集群或者哨兵模式下的密码设置。 | ||||||
|  |        + `REDIS_MASTER_NAME`:Redis 哨兵模式下主节点的名称。 | ||||||
| 2. `SESSION_SECRET`:设置之后将使用固定的会话密钥,这样系统重新启动后已登录用户的 cookie 将依旧有效。 | 2. `SESSION_SECRET`:设置之后将使用固定的会话密钥,这样系统重新启动后已登录用户的 cookie 将依旧有效。 | ||||||
|    + 例子:`SESSION_SECRET=random_string` |    + 例子:`SESSION_SECRET=random_string` | ||||||
| 3. `SQL_DSN`:设置之后将使用指定数据库而非 SQLite,请使用 MySQL 或 PostgreSQL。 | 3. `SQL_DSN`:设置之后将使用指定数据库而非 SQLite,请使用 MySQL 或 PostgreSQL。 | ||||||
| @@ -400,6 +409,7 @@ graph LR | |||||||
| 26. `METRIC_SUCCESS_RATE_THRESHOLD`:请求成功率阈值,默认为 `0.8`。 | 26. `METRIC_SUCCESS_RATE_THRESHOLD`:请求成功率阈值,默认为 `0.8`。 | ||||||
| 27. `INITIAL_ROOT_TOKEN`:如果设置了该值,则在系统首次启动时会自动创建一个值为该环境变量值的 root 用户令牌。 | 27. `INITIAL_ROOT_TOKEN`:如果设置了该值,则在系统首次启动时会自动创建一个值为该环境变量值的 root 用户令牌。 | ||||||
| 28. `INITIAL_ROOT_ACCESS_TOKEN`:如果设置了该值,则在系统首次启动时会自动创建一个值为该环境变量的 root 用户创建系统管理令牌。 | 28. `INITIAL_ROOT_ACCESS_TOKEN`:如果设置了该值,则在系统首次启动时会自动创建一个值为该环境变量的 root 用户创建系统管理令牌。 | ||||||
|  | 29. `ENFORCE_INCLUDE_USAGE`:是否强制在 stream 模型下返回 usage,默认不开启,可选值为 `true` 和 `false`。 | ||||||
|  |  | ||||||
| ### 命令行参数 | ### 命令行参数 | ||||||
| 1. `--port <port_number>`: 指定服务器监听的端口号,默认为 `3000`。 | 1. `--port <port_number>`: 指定服务器监听的端口号,默认为 `3000`。 | ||||||
|   | |||||||
| @@ -160,3 +160,5 @@ var OnlyOneLogFile = env.Bool("ONLY_ONE_LOG_FILE", false) | |||||||
| var RelayProxy = env.String("RELAY_PROXY", "") | var RelayProxy = env.String("RELAY_PROXY", "") | ||||||
| var UserContentRequestProxy = env.String("USER_CONTENT_REQUEST_PROXY", "") | var UserContentRequestProxy = env.String("USER_CONTENT_REQUEST_PROXY", "") | ||||||
| var UserContentRequestTimeout = env.Int("USER_CONTENT_REQUEST_TIMEOUT", 30) | var UserContentRequestTimeout = env.Int("USER_CONTENT_REQUEST_TIMEOUT", 30) | ||||||
|  |  | ||||||
|  | var EnforceIncludeUsage = env.Bool("ENFORCE_INCLUDE_USAGE", false) | ||||||
|   | |||||||
| @@ -20,4 +20,5 @@ const ( | |||||||
| 	BaseURL           = "base_url" | 	BaseURL           = "base_url" | ||||||
| 	AvailableModels   = "available_models" | 	AvailableModels   = "available_models" | ||||||
| 	KeyRequestBody    = "key_request_body" | 	KeyRequestBody    = "key_request_body" | ||||||
|  | 	SystemPrompt      = "system_prompt" | ||||||
| ) | ) | ||||||
|   | |||||||
| @@ -2,13 +2,15 @@ package common | |||||||
|  |  | ||||||
| import ( | import ( | ||||||
| 	"context" | 	"context" | ||||||
|  | 	"os" | ||||||
|  | 	"strings" | ||||||
|  | 	"time" | ||||||
|  |  | ||||||
| 	"github.com/go-redis/redis/v8" | 	"github.com/go-redis/redis/v8" | ||||||
| 	"github.com/songquanpeng/one-api/common/logger" | 	"github.com/songquanpeng/one-api/common/logger" | ||||||
| 	"os" |  | ||||||
| 	"time" |  | ||||||
| ) | ) | ||||||
|  |  | ||||||
| var RDB *redis.Client | var RDB redis.Cmdable | ||||||
| var RedisEnabled = true | var RedisEnabled = true | ||||||
|  |  | ||||||
| // InitRedisClient This function is called after init() | // InitRedisClient This function is called after init() | ||||||
| @@ -23,13 +25,23 @@ func InitRedisClient() (err error) { | |||||||
| 		logger.SysLog("SYNC_FREQUENCY not set, Redis is disabled") | 		logger.SysLog("SYNC_FREQUENCY not set, Redis is disabled") | ||||||
| 		return nil | 		return nil | ||||||
| 	} | 	} | ||||||
| 	logger.SysLog("Redis is enabled") | 	redisConnString := os.Getenv("REDIS_CONN_STRING") | ||||||
| 	opt, err := redis.ParseURL(os.Getenv("REDIS_CONN_STRING")) | 	if os.Getenv("REDIS_MASTER_NAME") == "" { | ||||||
| 	if err != nil { | 		logger.SysLog("Redis is enabled") | ||||||
| 		logger.FatalLog("failed to parse Redis connection string: " + err.Error()) | 		opt, err := redis.ParseURL(redisConnString) | ||||||
|  | 		if err != nil { | ||||||
|  | 			logger.FatalLog("failed to parse Redis connection string: " + err.Error()) | ||||||
|  | 		} | ||||||
|  | 		RDB = redis.NewClient(opt) | ||||||
|  | 	} else { | ||||||
|  | 		// cluster mode | ||||||
|  | 		logger.SysLog("Redis cluster mode enabled") | ||||||
|  | 		RDB = redis.NewUniversalClient(&redis.UniversalOptions{ | ||||||
|  | 			Addrs:      strings.Split(redisConnString, ","), | ||||||
|  | 			Password:   os.Getenv("REDIS_PASSWORD"), | ||||||
|  | 			MasterName: os.Getenv("REDIS_MASTER_NAME"), | ||||||
|  | 		}) | ||||||
| 	} | 	} | ||||||
| 	RDB = redis.NewClient(opt) |  | ||||||
|  |  | ||||||
| 	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) | 	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) | ||||||
| 	defer cancel() | 	defer cancel() | ||||||
|  |  | ||||||
|   | |||||||
| @@ -3,9 +3,10 @@ package render | |||||||
| import ( | import ( | ||||||
| 	"encoding/json" | 	"encoding/json" | ||||||
| 	"fmt" | 	"fmt" | ||||||
|  | 	"strings" | ||||||
|  |  | ||||||
| 	"github.com/gin-gonic/gin" | 	"github.com/gin-gonic/gin" | ||||||
| 	"github.com/songquanpeng/one-api/common" | 	"github.com/songquanpeng/one-api/common" | ||||||
| 	"strings" |  | ||||||
| ) | ) | ||||||
|  |  | ||||||
| func StringData(c *gin.Context, str string) { | func StringData(c *gin.Context, str string) { | ||||||
|   | |||||||
| @@ -40,7 +40,7 @@ func getLarkUserInfoByCode(code string) (*LarkUser, error) { | |||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, err | 		return nil, err | ||||||
| 	} | 	} | ||||||
| 	req, err := http.NewRequest("POST", "https://passport.feishu.cn/suite/passport/oauth/token", bytes.NewBuffer(jsonData)) | 	req, err := http.NewRequest("POST", "https://open.feishu.cn/open-apis/authen/v2/oauth/token", bytes.NewBuffer(jsonData)) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, err | 		return nil, err | ||||||
| 	} | 	} | ||||||
|   | |||||||
| @@ -4,16 +4,17 @@ import ( | |||||||
| 	"encoding/json" | 	"encoding/json" | ||||||
| 	"errors" | 	"errors" | ||||||
| 	"fmt" | 	"fmt" | ||||||
|  | 	"io" | ||||||
|  | 	"net/http" | ||||||
|  | 	"strconv" | ||||||
|  | 	"time" | ||||||
|  |  | ||||||
| 	"github.com/songquanpeng/one-api/common/client" | 	"github.com/songquanpeng/one-api/common/client" | ||||||
| 	"github.com/songquanpeng/one-api/common/config" | 	"github.com/songquanpeng/one-api/common/config" | ||||||
| 	"github.com/songquanpeng/one-api/common/logger" | 	"github.com/songquanpeng/one-api/common/logger" | ||||||
| 	"github.com/songquanpeng/one-api/model" | 	"github.com/songquanpeng/one-api/model" | ||||||
| 	"github.com/songquanpeng/one-api/monitor" | 	"github.com/songquanpeng/one-api/monitor" | ||||||
| 	"github.com/songquanpeng/one-api/relay/channeltype" | 	"github.com/songquanpeng/one-api/relay/channeltype" | ||||||
| 	"io" |  | ||||||
| 	"net/http" |  | ||||||
| 	"strconv" |  | ||||||
| 	"time" |  | ||||||
|  |  | ||||||
| 	"github.com/gin-gonic/gin" | 	"github.com/gin-gonic/gin" | ||||||
| ) | ) | ||||||
| @@ -101,6 +102,16 @@ type SiliconFlowUsageResponse struct { | |||||||
| 	} `json:"data"` | 	} `json:"data"` | ||||||
| } | } | ||||||
|  |  | ||||||
|  | type DeepSeekUsageResponse struct { | ||||||
|  | 	IsAvailable  bool `json:"is_available"` | ||||||
|  | 	BalanceInfos []struct { | ||||||
|  | 		Currency        string `json:"currency"` | ||||||
|  | 		TotalBalance    string `json:"total_balance"` | ||||||
|  | 		GrantedBalance  string `json:"granted_balance"` | ||||||
|  | 		ToppedUpBalance string `json:"topped_up_balance"` | ||||||
|  | 	} `json:"balance_infos"` | ||||||
|  | } | ||||||
|  |  | ||||||
| // GetAuthHeader get auth header | // GetAuthHeader get auth header | ||||||
| func GetAuthHeader(token string) http.Header { | func GetAuthHeader(token string) http.Header { | ||||||
| 	h := http.Header{} | 	h := http.Header{} | ||||||
| @@ -237,7 +248,36 @@ func updateChannelSiliconFlowBalance(channel *model.Channel) (float64, error) { | |||||||
| 	if response.Code != 20000 { | 	if response.Code != 20000 { | ||||||
| 		return 0, fmt.Errorf("code: %d, message: %s", response.Code, response.Message) | 		return 0, fmt.Errorf("code: %d, message: %s", response.Code, response.Message) | ||||||
| 	} | 	} | ||||||
| 	balance, err := strconv.ParseFloat(response.Data.Balance, 64) | 	balance, err := strconv.ParseFloat(response.Data.TotalBalance, 64) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return 0, err | ||||||
|  | 	} | ||||||
|  | 	channel.UpdateBalance(balance) | ||||||
|  | 	return balance, nil | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func updateChannelDeepSeekBalance(channel *model.Channel) (float64, error) { | ||||||
|  | 	url := "https://api.deepseek.com/user/balance" | ||||||
|  | 	body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key)) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return 0, err | ||||||
|  | 	} | ||||||
|  | 	response := DeepSeekUsageResponse{} | ||||||
|  | 	err = json.Unmarshal(body, &response) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return 0, err | ||||||
|  | 	} | ||||||
|  | 	index := -1 | ||||||
|  | 	for i, balanceInfo := range response.BalanceInfos { | ||||||
|  | 		if balanceInfo.Currency == "CNY" { | ||||||
|  | 			index = i | ||||||
|  | 			break | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 	if index == -1 { | ||||||
|  | 		return 0, errors.New("currency CNY not found") | ||||||
|  | 	} | ||||||
|  | 	balance, err := strconv.ParseFloat(response.BalanceInfos[index].TotalBalance, 64) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return 0, err | 		return 0, err | ||||||
| 	} | 	} | ||||||
| @@ -271,6 +311,8 @@ func updateChannelBalance(channel *model.Channel) (float64, error) { | |||||||
| 		return updateChannelAIGC2DBalance(channel) | 		return updateChannelAIGC2DBalance(channel) | ||||||
| 	case channeltype.SiliconFlow: | 	case channeltype.SiliconFlow: | ||||||
| 		return updateChannelSiliconFlowBalance(channel) | 		return updateChannelSiliconFlowBalance(channel) | ||||||
|  | 	case channeltype.DeepSeek: | ||||||
|  | 		return updateChannelDeepSeekBalance(channel) | ||||||
| 	default: | 	default: | ||||||
| 		return 0, errors.New("尚未实现") | 		return 0, errors.New("尚未实现") | ||||||
| 	} | 	} | ||||||
|   | |||||||
| @@ -60,7 +60,7 @@ func Relay(c *gin.Context) { | |||||||
| 	channelName := c.GetString(ctxkey.ChannelName) | 	channelName := c.GetString(ctxkey.ChannelName) | ||||||
| 	group := c.GetString(ctxkey.Group) | 	group := c.GetString(ctxkey.Group) | ||||||
| 	originalModel := c.GetString(ctxkey.OriginalModel) | 	originalModel := c.GetString(ctxkey.OriginalModel) | ||||||
| 	go processChannelRelayError(ctx, userId, channelId, channelName, bizErr) | 	go processChannelRelayError(ctx, userId, channelId, channelName, *bizErr) | ||||||
| 	requestId := c.GetString(helper.RequestIdKey) | 	requestId := c.GetString(helper.RequestIdKey) | ||||||
| 	retryTimes := config.RetryTimes | 	retryTimes := config.RetryTimes | ||||||
| 	if !shouldRetry(c, bizErr.StatusCode) { | 	if !shouldRetry(c, bizErr.StatusCode) { | ||||||
| @@ -87,8 +87,7 @@ func Relay(c *gin.Context) { | |||||||
| 		channelId := c.GetInt(ctxkey.ChannelId) | 		channelId := c.GetInt(ctxkey.ChannelId) | ||||||
| 		lastFailedChannelId = channelId | 		lastFailedChannelId = channelId | ||||||
| 		channelName := c.GetString(ctxkey.ChannelName) | 		channelName := c.GetString(ctxkey.ChannelName) | ||||||
| 		// BUG: bizErr is in race condition | 		go processChannelRelayError(ctx, userId, channelId, channelName, *bizErr) | ||||||
| 		go processChannelRelayError(ctx, userId, channelId, channelName, bizErr) |  | ||||||
| 	} | 	} | ||||||
| 	if bizErr != nil { | 	if bizErr != nil { | ||||||
| 		if bizErr.StatusCode == http.StatusTooManyRequests { | 		if bizErr.StatusCode == http.StatusTooManyRequests { | ||||||
| @@ -122,7 +121,7 @@ func shouldRetry(c *gin.Context, statusCode int) bool { | |||||||
| 	return true | 	return true | ||||||
| } | } | ||||||
|  |  | ||||||
| func processChannelRelayError(ctx context.Context, userId int, channelId int, channelName string, err *model.ErrorWithStatusCode) { | func processChannelRelayError(ctx context.Context, userId int, channelId int, channelName string, err model.ErrorWithStatusCode) { | ||||||
| 	logger.Errorf(ctx, "relay error (channel id %d, user id: %d): %s", channelId, userId, err.Message) | 	logger.Errorf(ctx, "relay error (channel id %d, user id: %d): %s", channelId, userId, err.Message) | ||||||
| 	// https://platform.openai.com/docs/guides/error-codes/api-errors | 	// https://platform.openai.com/docs/guides/error-codes/api-errors | ||||||
| 	if monitor.ShouldDisableChannel(&err.Error, err.StatusCode) { | 	if monitor.ShouldDisableChannel(&err.Error, err.StatusCode) { | ||||||
|   | |||||||
							
								
								
									
										8
									
								
								go.mod
									
									
									
									
									
								
							
							
						
						
									
										8
									
								
								go.mod
									
									
									
									
									
								
							| @@ -25,7 +25,7 @@ require ( | |||||||
| 	github.com/pkoukk/tiktoken-go v0.1.7 | 	github.com/pkoukk/tiktoken-go v0.1.7 | ||||||
| 	github.com/smartystreets/goconvey v1.8.1 | 	github.com/smartystreets/goconvey v1.8.1 | ||||||
| 	github.com/stretchr/testify v1.9.0 | 	github.com/stretchr/testify v1.9.0 | ||||||
| 	golang.org/x/crypto v0.24.0 | 	golang.org/x/crypto v0.31.0 | ||||||
| 	golang.org/x/image v0.18.0 | 	golang.org/x/image v0.18.0 | ||||||
| 	google.golang.org/api v0.187.0 | 	google.golang.org/api v0.187.0 | ||||||
| 	gorm.io/driver/mysql v1.5.6 | 	gorm.io/driver/mysql v1.5.6 | ||||||
| @@ -99,9 +99,9 @@ require ( | |||||||
| 	golang.org/x/arch v0.8.0 // indirect | 	golang.org/x/arch v0.8.0 // indirect | ||||||
| 	golang.org/x/net v0.26.0 // indirect | 	golang.org/x/net v0.26.0 // indirect | ||||||
| 	golang.org/x/oauth2 v0.21.0 // indirect | 	golang.org/x/oauth2 v0.21.0 // indirect | ||||||
| 	golang.org/x/sync v0.7.0 // indirect | 	golang.org/x/sync v0.10.0 // indirect | ||||||
| 	golang.org/x/sys v0.21.0 // indirect | 	golang.org/x/sys v0.28.0 // indirect | ||||||
| 	golang.org/x/text v0.16.0 // indirect | 	golang.org/x/text v0.21.0 // indirect | ||||||
| 	golang.org/x/time v0.5.0 // indirect | 	golang.org/x/time v0.5.0 // indirect | ||||||
| 	google.golang.org/genproto/googleapis/api v0.0.0-20240617180043-68d350f18fd4 // indirect | 	google.golang.org/genproto/googleapis/api v0.0.0-20240617180043-68d350f18fd4 // indirect | ||||||
| 	google.golang.org/genproto/googleapis/rpc v0.0.0-20240624140628-dc46fd24d27d // indirect | 	google.golang.org/genproto/googleapis/rpc v0.0.0-20240624140628-dc46fd24d27d // indirect | ||||||
|   | |||||||
							
								
								
									
										16
									
								
								go.sum
									
									
									
									
									
								
							
							
						
						
									
										16
									
								
								go.sum
									
									
									
									
									
								
							| @@ -222,8 +222,8 @@ golang.org/x/arch v0.8.0 h1:3wRIsP3pM4yUptoR96otTUOXI367OS0+c9eeRi9doIc= | |||||||
| golang.org/x/arch v0.8.0/go.mod h1:FEVrYAQjsQXMVJ1nsMoVVXPZg6p2JE2mx8psSWTDQys= | golang.org/x/arch v0.8.0/go.mod h1:FEVrYAQjsQXMVJ1nsMoVVXPZg6p2JE2mx8psSWTDQys= | ||||||
| golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= | golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= | ||||||
| golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= | golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= | ||||||
| golang.org/x/crypto v0.24.0 h1:mnl8DM0o513X8fdIkmyFE/5hTYxbwYOjDS/+rK6qpRI= | golang.org/x/crypto v0.31.0 h1:ihbySMvVjLAeSH1IbfcRTkD/iNscyz8rGzjF/E5hV6U= | ||||||
| golang.org/x/crypto v0.24.0/go.mod h1:Z1PMYSOR5nyMcyAVAIQSKCDwalqy85Aqn1x3Ws4L5DM= | golang.org/x/crypto v0.31.0/go.mod h1:kDsLvtWBEx7MV9tJOj9bnXsPbxwJQ6csT/x4KIN4Ssk= | ||||||
| golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= | golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= | ||||||
| golang.org/x/image v0.18.0 h1:jGzIakQa/ZXI1I0Fxvaa9W7yP25TqT6cHIHn+6CqvSQ= | golang.org/x/image v0.18.0 h1:jGzIakQa/ZXI1I0Fxvaa9W7yP25TqT6cHIHn+6CqvSQ= | ||||||
| golang.org/x/image v0.18.0/go.mod h1:4yyo5vMFQjVjUcVk4jEQcU9MGy/rulF5WvUILseCM2E= | golang.org/x/image v0.18.0/go.mod h1:4yyo5vMFQjVjUcVk4jEQcU9MGy/rulF5WvUILseCM2E= | ||||||
| @@ -244,20 +244,20 @@ golang.org/x/oauth2 v0.21.0/go.mod h1:XYTD2NtWslqkgxebSiOHnXEap4TF09sJSc7H1sXbht | |||||||
| golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= | golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= | ||||||
| golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= | golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= | ||||||
| golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= | golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= | ||||||
| golang.org/x/sync v0.7.0 h1:YsImfSBoP9QPYL0xyKJPq0gcaJdG3rInoqxTWbfQu9M= | golang.org/x/sync v0.10.0 h1:3NQrjDixjgGwUOCaF8w2+VYHv0Ve/vGYSbdkTa98gmQ= | ||||||
| golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= | golang.org/x/sync v0.10.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= | ||||||
| golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= | golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= | ||||||
| golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= | golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= | ||||||
| golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= | golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= | ||||||
| golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= | golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= | ||||||
| golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= | golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= | ||||||
| golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= | golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= | ||||||
| golang.org/x/sys v0.21.0 h1:rF+pYz3DAGSQAxAu1CbC7catZg4ebC4UIeIhKxBZvws= | golang.org/x/sys v0.28.0 h1:Fksou7UEQUWlKvIdsqzJmUmCX3cZuD2+P3XyyzwMhlA= | ||||||
| golang.org/x/sys v0.21.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= | golang.org/x/sys v0.28.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= | ||||||
| golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= | golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= | ||||||
| golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= | golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= | ||||||
| golang.org/x/text v0.16.0 h1:a94ExnEXNtEwYLGJSIUxnWoxoRz/ZcCsV63ROupILh4= | golang.org/x/text v0.21.0 h1:zyQAAkrwaneQ066sspRyJaG9VNi/YJ1NfzcGB3hZ/qo= | ||||||
| golang.org/x/text v0.16.0/go.mod h1:GhwF1Be+LQoKShO3cGOHzqOgRrGaYc9AvblQOmPVHnI= | golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ= | ||||||
| golang.org/x/time v0.5.0 h1:o7cqy6amK/52YcAKIPlM3a+Fpj35zvRj2TP+e1xFSfk= | golang.org/x/time v0.5.0 h1:o7cqy6amK/52YcAKIPlM3a+Fpj35zvRj2TP+e1xFSfk= | ||||||
| golang.org/x/time v0.5.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM= | golang.org/x/time v0.5.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM= | ||||||
| golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= | golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= | ||||||
|   | |||||||
| @@ -61,6 +61,9 @@ func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, mode | |||||||
| 	c.Set(ctxkey.Channel, channel.Type) | 	c.Set(ctxkey.Channel, channel.Type) | ||||||
| 	c.Set(ctxkey.ChannelId, channel.Id) | 	c.Set(ctxkey.ChannelId, channel.Id) | ||||||
| 	c.Set(ctxkey.ChannelName, channel.Name) | 	c.Set(ctxkey.ChannelName, channel.Name) | ||||||
|  | 	if channel.SystemPrompt != nil && *channel.SystemPrompt != "" { | ||||||
|  | 		c.Set(ctxkey.SystemPrompt, *channel.SystemPrompt) | ||||||
|  | 	} | ||||||
| 	c.Set(ctxkey.ModelMapping, channel.GetModelMapping()) | 	c.Set(ctxkey.ModelMapping, channel.GetModelMapping()) | ||||||
| 	c.Set(ctxkey.OriginalModel, modelName) // for retry | 	c.Set(ctxkey.OriginalModel, modelName) // for retry | ||||||
| 	c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key)) | 	c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key)) | ||||||
|   | |||||||
							
								
								
									
										27
									
								
								middleware/gzip.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										27
									
								
								middleware/gzip.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,27 @@ | |||||||
|  | package middleware | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"compress/gzip" | ||||||
|  | 	"github.com/gin-gonic/gin" | ||||||
|  | 	"io" | ||||||
|  | 	"net/http" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | func GzipDecodeMiddleware() gin.HandlerFunc { | ||||||
|  | 	return func(c *gin.Context) { | ||||||
|  | 		if c.GetHeader("Content-Encoding") == "gzip" { | ||||||
|  | 			gzipReader, err := gzip.NewReader(c.Request.Body) | ||||||
|  | 			if err != nil { | ||||||
|  | 				c.AbortWithStatus(http.StatusBadRequest) | ||||||
|  | 				return | ||||||
|  | 			} | ||||||
|  | 			defer gzipReader.Close() | ||||||
|  |  | ||||||
|  | 			// Replace the request body with the decompressed data | ||||||
|  | 			c.Request.Body = io.NopCloser(gzipReader) | ||||||
|  | 		} | ||||||
|  |  | ||||||
|  | 		// Continue processing the request | ||||||
|  | 		c.Next() | ||||||
|  | 	} | ||||||
|  | } | ||||||
| @@ -37,6 +37,7 @@ type Channel struct { | |||||||
| 	ModelMapping       *string `json:"model_mapping" gorm:"type:varchar(1024);default:''"` | 	ModelMapping       *string `json:"model_mapping" gorm:"type:varchar(1024);default:''"` | ||||||
| 	Priority           *int64  `json:"priority" gorm:"bigint;default:0"` | 	Priority           *int64  `json:"priority" gorm:"bigint;default:0"` | ||||||
| 	Config             string  `json:"config"` | 	Config             string  `json:"config"` | ||||||
|  | 	SystemPrompt       *string `json:"system_prompt" gorm:"type:text"` | ||||||
| } | } | ||||||
|  |  | ||||||
| type ChannelConfig struct { | type ChannelConfig struct { | ||||||
|   | |||||||
| @@ -34,7 +34,7 @@ func ShouldDisableChannel(err *model.Error, statusCode int) bool { | |||||||
| 		strings.Contains(lowerMessage, "credit") || | 		strings.Contains(lowerMessage, "credit") || | ||||||
| 		strings.Contains(lowerMessage, "balance") || | 		strings.Contains(lowerMessage, "balance") || | ||||||
| 		strings.Contains(lowerMessage, "permission denied") || | 		strings.Contains(lowerMessage, "permission denied") || | ||||||
|   	strings.Contains(lowerMessage, "organization has been restricted") || // groq | 		strings.Contains(lowerMessage, "organization has been restricted") || // groq | ||||||
| 		strings.Contains(lowerMessage, "已欠费") { | 		strings.Contains(lowerMessage, "已欠费") { | ||||||
| 		return true | 		return true | ||||||
| 	} | 	} | ||||||
|   | |||||||
| @@ -16,6 +16,7 @@ import ( | |||||||
| 	"github.com/songquanpeng/one-api/relay/adaptor/openai" | 	"github.com/songquanpeng/one-api/relay/adaptor/openai" | ||||||
| 	"github.com/songquanpeng/one-api/relay/adaptor/palm" | 	"github.com/songquanpeng/one-api/relay/adaptor/palm" | ||||||
| 	"github.com/songquanpeng/one-api/relay/adaptor/proxy" | 	"github.com/songquanpeng/one-api/relay/adaptor/proxy" | ||||||
|  | 	"github.com/songquanpeng/one-api/relay/adaptor/replicate" | ||||||
| 	"github.com/songquanpeng/one-api/relay/adaptor/tencent" | 	"github.com/songquanpeng/one-api/relay/adaptor/tencent" | ||||||
| 	"github.com/songquanpeng/one-api/relay/adaptor/vertexai" | 	"github.com/songquanpeng/one-api/relay/adaptor/vertexai" | ||||||
| 	"github.com/songquanpeng/one-api/relay/adaptor/xunfei" | 	"github.com/songquanpeng/one-api/relay/adaptor/xunfei" | ||||||
| @@ -61,6 +62,8 @@ func GetAdaptor(apiType int) adaptor.Adaptor { | |||||||
| 		return &vertexai.Adaptor{} | 		return &vertexai.Adaptor{} | ||||||
| 	case apitype.Proxy: | 	case apitype.Proxy: | ||||||
| 		return &proxy.Adaptor{} | 		return &proxy.Adaptor{} | ||||||
|  | 	case apitype.Replicate: | ||||||
|  | 		return &replicate.Adaptor{} | ||||||
| 	} | 	} | ||||||
| 	return nil | 	return nil | ||||||
| } | } | ||||||
|   | |||||||
| @@ -1,7 +1,23 @@ | |||||||
| package ali | package ali | ||||||
|  |  | ||||||
| var ModelList = []string{ | var ModelList = []string{ | ||||||
| 	"qwen-turbo", "qwen-plus", "qwen-max", "qwen-max-longcontext", | 	"qwen-turbo", "qwen-turbo-latest", | ||||||
| 	"text-embedding-v1", | 	"qwen-plus", "qwen-plus-latest", | ||||||
|  | 	"qwen-max", "qwen-max-latest", | ||||||
|  | 	"qwen-max-longcontext", | ||||||
|  | 	"qwen-vl-max", "qwen-vl-max-latest", "qwen-vl-plus", "qwen-vl-plus-latest", | ||||||
|  | 	"qwen-vl-ocr", "qwen-vl-ocr-latest", | ||||||
|  | 	"qwen-audio-turbo", | ||||||
|  | 	"qwen-math-plus", "qwen-math-plus-latest", "qwen-math-turbo", "qwen-math-turbo-latest", | ||||||
|  | 	"qwen-coder-plus", "qwen-coder-plus-latest", "qwen-coder-turbo", "qwen-coder-turbo-latest", | ||||||
|  | 	"qwq-32b-preview", "qwen2.5-72b-instruct", "qwen2.5-32b-instruct", "qwen2.5-14b-instruct", "qwen2.5-7b-instruct", "qwen2.5-3b-instruct", "qwen2.5-1.5b-instruct", "qwen2.5-0.5b-instruct", | ||||||
|  | 	"qwen2-72b-instruct", "qwen2-57b-a14b-instruct", "qwen2-7b-instruct", "qwen2-1.5b-instruct", "qwen2-0.5b-instruct", | ||||||
|  | 	"qwen1.5-110b-chat", "qwen1.5-72b-chat", "qwen1.5-32b-chat", "qwen1.5-14b-chat", "qwen1.5-7b-chat", "qwen1.5-1.8b-chat", "qwen1.5-0.5b-chat", | ||||||
|  | 	"qwen-72b-chat", "qwen-14b-chat", "qwen-7b-chat", "qwen-1.8b-chat", "qwen-1.8b-longcontext-chat", | ||||||
|  | 	"qwen2-vl-7b-instruct", "qwen2-vl-2b-instruct", "qwen-vl-v1", "qwen-vl-chat-v1", | ||||||
|  | 	"qwen2-audio-instruct", "qwen-audio-chat", | ||||||
|  | 	"qwen2.5-math-72b-instruct", "qwen2.5-math-7b-instruct", "qwen2.5-math-1.5b-instruct", "qwen2-math-72b-instruct", "qwen2-math-7b-instruct", "qwen2-math-1.5b-instruct", | ||||||
|  | 	"qwen2.5-coder-32b-instruct", "qwen2.5-coder-14b-instruct", "qwen2.5-coder-7b-instruct", "qwen2.5-coder-3b-instruct", "qwen2.5-coder-1.5b-instruct", "qwen2.5-coder-0.5b-instruct", | ||||||
|  | 	"text-embedding-v1", "text-embedding-v3", "text-embedding-v2", "text-embedding-async-v2", "text-embedding-async-v1", | ||||||
| 	"ali-stable-diffusion-xl", "ali-stable-diffusion-v1.5", "wanx-v1", | 	"ali-stable-diffusion-xl", "ali-stable-diffusion-v1.5", "wanx-v1", | ||||||
| } | } | ||||||
|   | |||||||
| @@ -9,5 +9,4 @@ var ModelList = []string{ | |||||||
| 	"claude-3-5-sonnet-20240620", | 	"claude-3-5-sonnet-20240620", | ||||||
| 	"claude-3-5-sonnet-20241022", | 	"claude-3-5-sonnet-20241022", | ||||||
| 	"claude-3-5-sonnet-latest", | 	"claude-3-5-sonnet-latest", | ||||||
| 	"claude-3-5-haiku-20241022", |  | ||||||
| } | } | ||||||
|   | |||||||
| @@ -24,7 +24,12 @@ func (a *Adaptor) Init(meta *meta.Meta) { | |||||||
| } | } | ||||||
|  |  | ||||||
| func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) { | func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) { | ||||||
| 	version := helper.AssignOrDefault(meta.Config.APIVersion, config.GeminiVersion) | 	defaultVersion := config.GeminiVersion | ||||||
|  | 	if meta.ActualModelName == "gemini-2.0-flash-exp" { | ||||||
|  | 		defaultVersion = "v1beta" | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	version := helper.AssignOrDefault(meta.Config.APIVersion, defaultVersion) | ||||||
| 	action := "" | 	action := "" | ||||||
| 	switch meta.Mode { | 	switch meta.Mode { | ||||||
| 	case relaymode.Embeddings: | 	case relaymode.Embeddings: | ||||||
| @@ -36,6 +41,7 @@ func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) { | |||||||
| 	if meta.IsStream { | 	if meta.IsStream { | ||||||
| 		action = "streamGenerateContent?alt=sse" | 		action = "streamGenerateContent?alt=sse" | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	return fmt.Sprintf("%s/%s/models/%s:%s", meta.BaseURL, version, meta.ActualModelName, action), nil | 	return fmt.Sprintf("%s/%s/models/%s:%s", meta.BaseURL, version, meta.ActualModelName, action), nil | ||||||
| } | } | ||||||
|  |  | ||||||
|   | |||||||
| @@ -3,5 +3,9 @@ package gemini | |||||||
| // https://ai.google.dev/models/gemini | // https://ai.google.dev/models/gemini | ||||||
|  |  | ||||||
| var ModelList = []string{ | var ModelList = []string{ | ||||||
| 	"gemini-pro", "gemini-1.0-pro", "gemini-1.5-flash", "gemini-1.5-pro", "text-embedding-004", "aqa", | 	"gemini-pro", "gemini-1.0-pro", | ||||||
|  | 	"gemini-1.5-flash", "gemini-1.5-pro", | ||||||
|  | 	"text-embedding-004", "aqa", | ||||||
|  | 	"gemini-2.0-flash-exp", | ||||||
|  | 	"gemini-2.0-flash-thinking-exp", | ||||||
| } | } | ||||||
|   | |||||||
| @@ -55,6 +55,10 @@ func ConvertRequest(textRequest model.GeneralOpenAIRequest) *ChatRequest { | |||||||
| 				Category:  "HARM_CATEGORY_DANGEROUS_CONTENT", | 				Category:  "HARM_CATEGORY_DANGEROUS_CONTENT", | ||||||
| 				Threshold: config.GeminiSafetySetting, | 				Threshold: config.GeminiSafetySetting, | ||||||
| 			}, | 			}, | ||||||
|  | 			{ | ||||||
|  | 				Category:  "HARM_CATEGORY_CIVIC_INTEGRITY", | ||||||
|  | 				Threshold: config.GeminiSafetySetting, | ||||||
|  | 			}, | ||||||
| 		}, | 		}, | ||||||
| 		GenerationConfig: ChatGenerationConfig{ | 		GenerationConfig: ChatGenerationConfig{ | ||||||
| 			Temperature:     textRequest.Temperature, | 			Temperature:     textRequest.Temperature, | ||||||
| @@ -247,7 +251,14 @@ func responseGeminiChat2OpenAI(response *ChatResponse) *openai.TextResponse { | |||||||
| 			if candidate.Content.Parts[0].FunctionCall != nil { | 			if candidate.Content.Parts[0].FunctionCall != nil { | ||||||
| 				choice.Message.ToolCalls = getToolCalls(&candidate) | 				choice.Message.ToolCalls = getToolCalls(&candidate) | ||||||
| 			} else { | 			} else { | ||||||
| 				choice.Message.Content = candidate.Content.Parts[0].Text | 				var builder strings.Builder | ||||||
|  | 				for _, part := range candidate.Content.Parts { | ||||||
|  | 					if i > 0 { | ||||||
|  | 						builder.WriteString("\n") | ||||||
|  | 					} | ||||||
|  | 					builder.WriteString(part.Text) | ||||||
|  | 				} | ||||||
|  | 				choice.Message.Content = builder.String() | ||||||
| 			} | 			} | ||||||
| 		} else { | 		} else { | ||||||
| 			choice.Message.Content = "" | 			choice.Message.Content = "" | ||||||
|   | |||||||
| @@ -31,8 +31,8 @@ func ConvertRequest(request model.GeneralOpenAIRequest) *ChatRequest { | |||||||
| 			TopP:             request.TopP, | 			TopP:             request.TopP, | ||||||
| 			FrequencyPenalty: request.FrequencyPenalty, | 			FrequencyPenalty: request.FrequencyPenalty, | ||||||
| 			PresencePenalty:  request.PresencePenalty, | 			PresencePenalty:  request.PresencePenalty, | ||||||
| 			NumPredict:  	  request.MaxTokens, | 			NumPredict:       request.MaxTokens, | ||||||
| 			NumCtx:  	  request.NumCtx, | 			NumCtx:           request.NumCtx, | ||||||
| 		}, | 		}, | ||||||
| 		Stream: request.Stream, | 		Stream: request.Stream, | ||||||
| 	} | 	} | ||||||
| @@ -122,7 +122,7 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC | |||||||
| 	for scanner.Scan() { | 	for scanner.Scan() { | ||||||
| 		data := scanner.Text() | 		data := scanner.Text() | ||||||
| 		if strings.HasPrefix(data, "}") { | 		if strings.HasPrefix(data, "}") { | ||||||
| 		    data = strings.TrimPrefix(data, "}") + "}" | 			data = strings.TrimPrefix(data, "}") + "}" | ||||||
| 		} | 		} | ||||||
|  |  | ||||||
| 		var ollamaResponse ChatResponse | 		var ollamaResponse ChatResponse | ||||||
|   | |||||||
| @@ -9,6 +9,7 @@ var ModelList = []string{ | |||||||
| 	"gpt-4-turbo-preview", "gpt-4-turbo", "gpt-4-turbo-2024-04-09", | 	"gpt-4-turbo-preview", "gpt-4-turbo", "gpt-4-turbo-2024-04-09", | ||||||
| 	"gpt-4o", "gpt-4o-2024-05-13", | 	"gpt-4o", "gpt-4o-2024-05-13", | ||||||
| 	"gpt-4o-2024-08-06", | 	"gpt-4o-2024-08-06", | ||||||
|  | 	"gpt-4o-2024-11-20", | ||||||
| 	"chatgpt-4o-latest", | 	"chatgpt-4o-latest", | ||||||
| 	"gpt-4o-mini", "gpt-4o-mini-2024-07-18", | 	"gpt-4o-mini", "gpt-4o-mini-2024-07-18", | ||||||
| 	"gpt-4-vision-preview", | 	"gpt-4-vision-preview", | ||||||
| @@ -20,4 +21,7 @@ var ModelList = []string{ | |||||||
| 	"dall-e-2", "dall-e-3", | 	"dall-e-2", "dall-e-3", | ||||||
| 	"whisper-1", | 	"whisper-1", | ||||||
| 	"tts-1", "tts-1-1106", "tts-1-hd", "tts-1-hd-1106", | 	"tts-1", "tts-1-1106", "tts-1-hd", "tts-1-hd-1106", | ||||||
|  | 	"o1", "o1-2024-12-17", | ||||||
|  | 	"o1-preview", "o1-preview-2024-09-12", | ||||||
|  | 	"o1-mini", "o1-mini-2024-09-12", | ||||||
| } | } | ||||||
|   | |||||||
| @@ -2,15 +2,16 @@ package openai | |||||||
|  |  | ||||||
| import ( | import ( | ||||||
| 	"fmt" | 	"fmt" | ||||||
|  | 	"strings" | ||||||
|  |  | ||||||
| 	"github.com/songquanpeng/one-api/relay/channeltype" | 	"github.com/songquanpeng/one-api/relay/channeltype" | ||||||
| 	"github.com/songquanpeng/one-api/relay/model" | 	"github.com/songquanpeng/one-api/relay/model" | ||||||
| 	"strings" |  | ||||||
| ) | ) | ||||||
|  |  | ||||||
| func ResponseText2Usage(responseText string, modeName string, promptTokens int) *model.Usage { | func ResponseText2Usage(responseText string, modelName string, promptTokens int) *model.Usage { | ||||||
| 	usage := &model.Usage{} | 	usage := &model.Usage{} | ||||||
| 	usage.PromptTokens = promptTokens | 	usage.PromptTokens = promptTokens | ||||||
| 	usage.CompletionTokens = CountTokenText(responseText, modeName) | 	usage.CompletionTokens = CountTokenText(responseText, modelName) | ||||||
| 	usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens | 	usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens | ||||||
| 	return usage | 	return usage | ||||||
| } | } | ||||||
|   | |||||||
| @@ -1,8 +1,16 @@ | |||||||
| package openai | package openai | ||||||
|  |  | ||||||
| import "github.com/songquanpeng/one-api/relay/model" | import ( | ||||||
|  | 	"context" | ||||||
|  | 	"fmt" | ||||||
|  |  | ||||||
|  | 	"github.com/songquanpeng/one-api/common/logger" | ||||||
|  | 	"github.com/songquanpeng/one-api/relay/model" | ||||||
|  | ) | ||||||
|  |  | ||||||
| func ErrorWrapper(err error, code string, statusCode int) *model.ErrorWithStatusCode { | func ErrorWrapper(err error, code string, statusCode int) *model.ErrorWithStatusCode { | ||||||
|  | 	logger.Error(context.TODO(), fmt.Sprintf("[%s]%+v", code, err)) | ||||||
|  |  | ||||||
| 	Error := model.Error{ | 	Error := model.Error{ | ||||||
| 		Message: err.Error(), | 		Message: err.Error(), | ||||||
| 		Type:    "one_api_error", | 		Type:    "one_api_error", | ||||||
|   | |||||||
							
								
								
									
										136
									
								
								relay/adaptor/replicate/adaptor.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										136
									
								
								relay/adaptor/replicate/adaptor.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,136 @@ | |||||||
|  | package replicate | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"fmt" | ||||||
|  | 	"io" | ||||||
|  | 	"net/http" | ||||||
|  | 	"slices" | ||||||
|  | 	"strings" | ||||||
|  | 	"time" | ||||||
|  |  | ||||||
|  | 	"github.com/gin-gonic/gin" | ||||||
|  | 	"github.com/pkg/errors" | ||||||
|  | 	"github.com/songquanpeng/one-api/common/logger" | ||||||
|  | 	"github.com/songquanpeng/one-api/relay/adaptor" | ||||||
|  | 	"github.com/songquanpeng/one-api/relay/adaptor/openai" | ||||||
|  | 	"github.com/songquanpeng/one-api/relay/meta" | ||||||
|  | 	"github.com/songquanpeng/one-api/relay/model" | ||||||
|  | 	"github.com/songquanpeng/one-api/relay/relaymode" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | type Adaptor struct { | ||||||
|  | 	meta *meta.Meta | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // ConvertImageRequest implements adaptor.Adaptor. | ||||||
|  | func (*Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) { | ||||||
|  | 	return DrawImageRequest{ | ||||||
|  | 		Input: ImageInput{ | ||||||
|  | 			Steps:           25, | ||||||
|  | 			Prompt:          request.Prompt, | ||||||
|  | 			Guidance:        3, | ||||||
|  | 			Seed:            int(time.Now().UnixNano()), | ||||||
|  | 			SafetyTolerance: 5, | ||||||
|  | 			NImages:         1, // replicate will always return 1 image | ||||||
|  | 			Width:           1440, | ||||||
|  | 			Height:          1440, | ||||||
|  | 			AspectRatio:     "1:1", | ||||||
|  | 		}, | ||||||
|  | 	}, nil | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) { | ||||||
|  | 	if !request.Stream { | ||||||
|  | 		// TODO: support non-stream mode | ||||||
|  | 		return nil, errors.Errorf("replicate models only support stream mode now, please set stream=true") | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// Build the prompt from OpenAI messages | ||||||
|  | 	var promptBuilder strings.Builder | ||||||
|  | 	for _, message := range request.Messages { | ||||||
|  | 		switch msgCnt := message.Content.(type) { | ||||||
|  | 		case string: | ||||||
|  | 			promptBuilder.WriteString(message.Role) | ||||||
|  | 			promptBuilder.WriteString(": ") | ||||||
|  | 			promptBuilder.WriteString(msgCnt) | ||||||
|  | 			promptBuilder.WriteString("\n") | ||||||
|  | 		default: | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	replicateRequest := ReplicateChatRequest{ | ||||||
|  | 		Input: ChatInput{ | ||||||
|  | 			Prompt:           promptBuilder.String(), | ||||||
|  | 			MaxTokens:        request.MaxTokens, | ||||||
|  | 			Temperature:      1.0, | ||||||
|  | 			TopP:             1.0, | ||||||
|  | 			PresencePenalty:  0.0, | ||||||
|  | 			FrequencyPenalty: 0.0, | ||||||
|  | 		}, | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// Map optional fields | ||||||
|  | 	if request.Temperature != nil { | ||||||
|  | 		replicateRequest.Input.Temperature = *request.Temperature | ||||||
|  | 	} | ||||||
|  | 	if request.TopP != nil { | ||||||
|  | 		replicateRequest.Input.TopP = *request.TopP | ||||||
|  | 	} | ||||||
|  | 	if request.PresencePenalty != nil { | ||||||
|  | 		replicateRequest.Input.PresencePenalty = *request.PresencePenalty | ||||||
|  | 	} | ||||||
|  | 	if request.FrequencyPenalty != nil { | ||||||
|  | 		replicateRequest.Input.FrequencyPenalty = *request.FrequencyPenalty | ||||||
|  | 	} | ||||||
|  | 	if request.MaxTokens > 0 { | ||||||
|  | 		replicateRequest.Input.MaxTokens = request.MaxTokens | ||||||
|  | 	} else if request.MaxTokens == 0 { | ||||||
|  | 		replicateRequest.Input.MaxTokens = 500 | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	return replicateRequest, nil | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (a *Adaptor) Init(meta *meta.Meta) { | ||||||
|  | 	a.meta = meta | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) { | ||||||
|  | 	if !slices.Contains(ModelList, meta.OriginModelName) { | ||||||
|  | 		return "", errors.Errorf("model %s not supported", meta.OriginModelName) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	return fmt.Sprintf("https://api.replicate.com/v1/models/%s/predictions", meta.OriginModelName), nil | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error { | ||||||
|  | 	adaptor.SetupCommonRequestHeader(c, req, meta) | ||||||
|  | 	req.Header.Set("Authorization", "Bearer "+meta.APIKey) | ||||||
|  | 	return nil | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) { | ||||||
|  | 	logger.Info(c, "send request to replicate") | ||||||
|  | 	return adaptor.DoRequestHelper(a, c, meta, requestBody) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) { | ||||||
|  | 	switch meta.Mode { | ||||||
|  | 	case relaymode.ImagesGenerations: | ||||||
|  | 		err, usage = ImageHandler(c, resp) | ||||||
|  | 	case relaymode.ChatCompletions: | ||||||
|  | 		err, usage = ChatHandler(c, resp) | ||||||
|  | 	default: | ||||||
|  | 		err = openai.ErrorWrapper(errors.New("not implemented"), "not_implemented", http.StatusInternalServerError) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	return | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (a *Adaptor) GetModelList() []string { | ||||||
|  | 	return ModelList | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (a *Adaptor) GetChannelName() string { | ||||||
|  | 	return "replicate" | ||||||
|  | } | ||||||
							
								
								
									
										191
									
								
								relay/adaptor/replicate/chat.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										191
									
								
								relay/adaptor/replicate/chat.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,191 @@ | |||||||
|  | package replicate | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"bufio" | ||||||
|  | 	"encoding/json" | ||||||
|  | 	"io" | ||||||
|  | 	"net/http" | ||||||
|  | 	"strings" | ||||||
|  | 	"time" | ||||||
|  |  | ||||||
|  | 	"github.com/gin-gonic/gin" | ||||||
|  | 	"github.com/pkg/errors" | ||||||
|  | 	"github.com/songquanpeng/one-api/common" | ||||||
|  | 	"github.com/songquanpeng/one-api/common/render" | ||||||
|  | 	"github.com/songquanpeng/one-api/relay/adaptor/openai" | ||||||
|  | 	"github.com/songquanpeng/one-api/relay/meta" | ||||||
|  | 	"github.com/songquanpeng/one-api/relay/model" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | func ChatHandler(c *gin.Context, resp *http.Response) ( | ||||||
|  | 	srvErr *model.ErrorWithStatusCode, usage *model.Usage) { | ||||||
|  | 	if resp.StatusCode != http.StatusCreated { | ||||||
|  | 		payload, _ := io.ReadAll(resp.Body) | ||||||
|  | 		return openai.ErrorWrapper( | ||||||
|  | 				errors.Errorf("bad_status_code [%d]%s", resp.StatusCode, string(payload)), | ||||||
|  | 				"bad_status_code", http.StatusInternalServerError), | ||||||
|  | 			nil | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	respBody, err := io.ReadAll(resp.Body) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	respData := new(ChatResponse) | ||||||
|  | 	if err = json.Unmarshal(respBody, respData); err != nil { | ||||||
|  | 		return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	for { | ||||||
|  | 		err = func() error { | ||||||
|  | 			// get task | ||||||
|  | 			taskReq, err := http.NewRequestWithContext(c.Request.Context(), | ||||||
|  | 				http.MethodGet, respData.URLs.Get, nil) | ||||||
|  | 			if err != nil { | ||||||
|  | 				return errors.Wrap(err, "new request") | ||||||
|  | 			} | ||||||
|  |  | ||||||
|  | 			taskReq.Header.Set("Authorization", "Bearer "+meta.GetByContext(c).APIKey) | ||||||
|  | 			taskResp, err := http.DefaultClient.Do(taskReq) | ||||||
|  | 			if err != nil { | ||||||
|  | 				return errors.Wrap(err, "get task") | ||||||
|  | 			} | ||||||
|  | 			defer taskResp.Body.Close() | ||||||
|  |  | ||||||
|  | 			if taskResp.StatusCode != http.StatusOK { | ||||||
|  | 				payload, _ := io.ReadAll(taskResp.Body) | ||||||
|  | 				return errors.Errorf("bad status code [%d]%s", | ||||||
|  | 					taskResp.StatusCode, string(payload)) | ||||||
|  | 			} | ||||||
|  |  | ||||||
|  | 			taskBody, err := io.ReadAll(taskResp.Body) | ||||||
|  | 			if err != nil { | ||||||
|  | 				return errors.Wrap(err, "read task response") | ||||||
|  | 			} | ||||||
|  |  | ||||||
|  | 			taskData := new(ChatResponse) | ||||||
|  | 			if err = json.Unmarshal(taskBody, taskData); err != nil { | ||||||
|  | 				return errors.Wrap(err, "decode task response") | ||||||
|  | 			} | ||||||
|  |  | ||||||
|  | 			switch taskData.Status { | ||||||
|  | 			case "succeeded": | ||||||
|  | 			case "failed", "canceled": | ||||||
|  | 				return errors.Errorf("task failed, [%s]%s", taskData.Status, taskData.Error) | ||||||
|  | 			default: | ||||||
|  | 				time.Sleep(time.Second * 3) | ||||||
|  | 				return errNextLoop | ||||||
|  | 			} | ||||||
|  |  | ||||||
|  | 			if taskData.URLs.Stream == "" { | ||||||
|  | 				return errors.New("stream url is empty") | ||||||
|  | 			} | ||||||
|  |  | ||||||
|  | 			// request stream url | ||||||
|  | 			responseText, err := chatStreamHandler(c, taskData.URLs.Stream) | ||||||
|  | 			if err != nil { | ||||||
|  | 				return errors.Wrap(err, "chat stream handler") | ||||||
|  | 			} | ||||||
|  |  | ||||||
|  | 			ctxMeta := meta.GetByContext(c) | ||||||
|  | 			usage = openai.ResponseText2Usage(responseText, | ||||||
|  | 				ctxMeta.ActualModelName, ctxMeta.PromptTokens) | ||||||
|  | 			return nil | ||||||
|  | 		}() | ||||||
|  | 		if err != nil { | ||||||
|  | 			if errors.Is(err, errNextLoop) { | ||||||
|  | 				continue | ||||||
|  | 			} | ||||||
|  |  | ||||||
|  | 			return openai.ErrorWrapper(err, "chat_task_failed", http.StatusInternalServerError), nil | ||||||
|  | 		} | ||||||
|  |  | ||||||
|  | 		break | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	return nil, usage | ||||||
|  | } | ||||||
|  |  | ||||||
// Markers used when reading the prediction's server-sent event stream.

// eventPrefix introduces the event-name line of a stream event.
const eventPrefix = "event: "

// dataPrefix introduces the payload line of a stream event.
const dataPrefix = "data: "

// done is the "[DONE]" marker string.
const done = "[DONE]"
|  |  | ||||||
|  | func chatStreamHandler(c *gin.Context, streamUrl string) (responseText string, err error) { | ||||||
|  | 	// request stream endpoint | ||||||
|  | 	streamReq, err := http.NewRequestWithContext(c.Request.Context(), http.MethodGet, streamUrl, nil) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return "", errors.Wrap(err, "new request to stream") | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	streamReq.Header.Set("Authorization", "Bearer "+meta.GetByContext(c).APIKey) | ||||||
|  | 	streamReq.Header.Set("Accept", "text/event-stream") | ||||||
|  | 	streamReq.Header.Set("Cache-Control", "no-store") | ||||||
|  |  | ||||||
|  | 	resp, err := http.DefaultClient.Do(streamReq) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return "", errors.Wrap(err, "do request to stream") | ||||||
|  | 	} | ||||||
|  | 	defer resp.Body.Close() | ||||||
|  |  | ||||||
|  | 	if resp.StatusCode != http.StatusOK { | ||||||
|  | 		payload, _ := io.ReadAll(resp.Body) | ||||||
|  | 		return "", errors.Errorf("bad status code [%d]%s", resp.StatusCode, string(payload)) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	scanner := bufio.NewScanner(resp.Body) | ||||||
|  | 	scanner.Split(bufio.ScanLines) | ||||||
|  |  | ||||||
|  | 	common.SetEventStreamHeaders(c) | ||||||
|  | 	doneRendered := false | ||||||
|  | 	for scanner.Scan() { | ||||||
|  | 		line := strings.TrimSpace(scanner.Text()) | ||||||
|  | 		if line == "" { | ||||||
|  | 			continue | ||||||
|  | 		} | ||||||
|  |  | ||||||
|  | 		// Handle comments starting with ':' | ||||||
|  | 		if strings.HasPrefix(line, ":") { | ||||||
|  | 			continue | ||||||
|  | 		} | ||||||
|  |  | ||||||
|  | 		// Parse SSE fields | ||||||
|  | 		if strings.HasPrefix(line, eventPrefix) { | ||||||
|  | 			event := strings.TrimSpace(line[len(eventPrefix):]) | ||||||
|  | 			var data string | ||||||
|  | 			// Read the following lines to get data and id | ||||||
|  | 			for scanner.Scan() { | ||||||
|  | 				nextLine := scanner.Text() | ||||||
|  | 				if nextLine == "" { | ||||||
|  | 					break | ||||||
|  | 				} | ||||||
|  | 				if strings.HasPrefix(nextLine, dataPrefix) { | ||||||
|  | 					data = nextLine[len(dataPrefix):] | ||||||
|  | 				} else if strings.HasPrefix(nextLine, "id:") { | ||||||
|  | 					// id = strings.TrimSpace(nextLine[len("id:"):]) | ||||||
|  | 				} | ||||||
|  | 			} | ||||||
|  |  | ||||||
|  | 			if event == "output" { | ||||||
|  | 				render.StringData(c, data) | ||||||
|  | 				responseText += data | ||||||
|  | 			} else if event == "done" { | ||||||
|  | 				render.Done(c) | ||||||
|  | 				doneRendered = true | ||||||
|  | 				break | ||||||
|  | 			} | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	if err := scanner.Err(); err != nil { | ||||||
|  | 		return "", errors.Wrap(err, "scan stream") | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	if !doneRendered { | ||||||
|  | 		render.Done(c) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	return responseText, nil | ||||||
|  | } | ||||||
							
								
								
									
										58
									
								
								relay/adaptor/replicate/constant.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										58
									
								
								relay/adaptor/replicate/constant.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,58 @@ | |||||||
|  | package replicate | ||||||
|  |  | ||||||
// ModelList enumerates the Replicate models this adaptor accepts, grouped by
// modality and, within each group, by publisher.
//
// https://replicate.com/pricing
var ModelList = []string{
	// ===== image models =====

	// black-forest-labs
	"black-forest-labs/flux-1.1-pro",
	"black-forest-labs/flux-1.1-pro-ultra",
	"black-forest-labs/flux-canny-dev",
	"black-forest-labs/flux-canny-pro",
	"black-forest-labs/flux-depth-dev",
	"black-forest-labs/flux-depth-pro",
	"black-forest-labs/flux-dev",
	"black-forest-labs/flux-dev-lora",
	"black-forest-labs/flux-fill-dev",
	"black-forest-labs/flux-fill-pro",
	"black-forest-labs/flux-pro",
	"black-forest-labs/flux-redux-dev",
	"black-forest-labs/flux-redux-schnell",
	"black-forest-labs/flux-schnell",
	"black-forest-labs/flux-schnell-lora",
	// ideogram-ai
	"ideogram-ai/ideogram-v2",
	"ideogram-ai/ideogram-v2-turbo",
	// recraft-ai
	"recraft-ai/recraft-v3",
	"recraft-ai/recraft-v3-svg",
	// stability-ai
	"stability-ai/stable-diffusion-3",
	"stability-ai/stable-diffusion-3.5-large",
	"stability-ai/stable-diffusion-3.5-large-turbo",
	"stability-ai/stable-diffusion-3.5-medium",

	// ===== language models =====

	// ibm-granite
	"ibm-granite/granite-20b-code-instruct-8k",
	"ibm-granite/granite-3.0-2b-instruct",
	"ibm-granite/granite-3.0-8b-instruct",
	"ibm-granite/granite-8b-code-instruct-128k",
	// meta
	"meta/llama-2-13b",
	"meta/llama-2-13b-chat",
	"meta/llama-2-70b",
	"meta/llama-2-70b-chat",
	"meta/llama-2-7b",
	"meta/llama-2-7b-chat",
	"meta/meta-llama-3.1-405b-instruct",
	"meta/meta-llama-3-70b",
	"meta/meta-llama-3-70b-instruct",
	"meta/meta-llama-3-8b",
	"meta/meta-llama-3-8b-instruct",
	// mistralai
	"mistralai/mistral-7b-instruct-v0.2",
	"mistralai/mistral-7b-v0.1",
	"mistralai/mixtral-8x7b-instruct-v0.1",

	// ===== video models =====

	// "minimax/video-01",  // TODO: implement the adaptor
}
							
								
								
									
										222
									
								
								relay/adaptor/replicate/image.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										222
									
								
								relay/adaptor/replicate/image.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,222 @@ | |||||||
|  | package replicate | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"bytes" | ||||||
|  | 	"encoding/base64" | ||||||
|  | 	"encoding/json" | ||||||
|  | 	"fmt" | ||||||
|  | 	"image" | ||||||
|  | 	"image/png" | ||||||
|  | 	"io" | ||||||
|  | 	"net/http" | ||||||
|  | 	"sync" | ||||||
|  | 	"time" | ||||||
|  |  | ||||||
|  | 	"github.com/gin-gonic/gin" | ||||||
|  | 	"github.com/pkg/errors" | ||||||
|  | 	"github.com/songquanpeng/one-api/common/logger" | ||||||
|  | 	"github.com/songquanpeng/one-api/relay/adaptor/openai" | ||||||
|  | 	"github.com/songquanpeng/one-api/relay/meta" | ||||||
|  | 	"github.com/songquanpeng/one-api/relay/model" | ||||||
|  | 	"golang.org/x/image/webp" | ||||||
|  | 	"golang.org/x/sync/errgroup" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | // ImagesEditsHandler just copy response body to client | ||||||
|  | // | ||||||
|  | // https://replicate.com/black-forest-labs/flux-fill-pro | ||||||
|  | // func ImagesEditsHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) { | ||||||
|  | // 	c.Writer.WriteHeader(resp.StatusCode) | ||||||
|  | // 	for k, v := range resp.Header { | ||||||
|  | // 		c.Writer.Header().Set(k, v[0]) | ||||||
|  | // 	} | ||||||
|  |  | ||||||
|  | // 	if _, err := io.Copy(c.Writer, resp.Body); err != nil { | ||||||
|  | // 		return ErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil | ||||||
|  | // 	} | ||||||
|  | // 	defer resp.Body.Close() | ||||||
|  |  | ||||||
|  | // 	return nil, nil | ||||||
|  | // } | ||||||
|  |  | ||||||
// errNextLoop is the sentinel a poll attempt returns when the task has not
// reached a final state yet, telling the polling loop to try again.
var errNextLoop = errors.New("next_loop")
|  |  | ||||||
|  | // ImageHandler relays the result of a Replicate image prediction to the client. | ||||||
|  | // | ||||||
|  | // resp is the response to the "create prediction" call. It must carry status | ||||||
|  | // 201 and a JSON body decodable into ImageResponse. The handler then polls | ||||||
|  | // the task URL (URLs.Get) every 3 seconds until the task status becomes | ||||||
|  | // "succeeded", "failed" or "canceled". On success every output URL is | ||||||
|  | // downloaded concurrently, converted to PNG and written to the client as an | ||||||
|  | // openai.ImageResponse whose B64Json fields hold base64 PNG data URLs. | ||||||
|  | // | ||||||
|  | // The returned usage is always nil. A non-nil error is returned when the | ||||||
|  | // initial response is invalid, when polling fails, when the task ends in a | ||||||
|  | // non-success status, or when no image at all could be downloaded. | ||||||
|  | func ImageHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) { | ||||||
|  | 	// The handler owns the upstream body: close it on every return path so | ||||||
|  | 	// the underlying connection is released. | ||||||
|  | 	defer resp.Body.Close() | ||||||
|  |  | ||||||
|  | 	if resp.StatusCode != http.StatusCreated { | ||||||
|  | 		// best effort: the payload is only used to enrich the error message | ||||||
|  | 		payload, _ := io.ReadAll(resp.Body) | ||||||
|  | 		return openai.ErrorWrapper( | ||||||
|  | 				errors.Errorf("bad_status_code [%d]%s", resp.StatusCode, string(payload)), | ||||||
|  | 				"bad_status_code", http.StatusInternalServerError), | ||||||
|  | 			nil | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	respBody, err := io.ReadAll(resp.Body) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	respData := new(ImageResponse) | ||||||
|  | 	if err = json.Unmarshal(respBody, respData); err != nil { | ||||||
|  | 		return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// Poll the task until it reaches a terminal status. Each iteration runs | ||||||
|  | 	// in a closure so that the deferred Close of the poll response fires per | ||||||
|  | 	// iteration instead of piling up until the handler returns. | ||||||
|  | 	for { | ||||||
|  | 		err = func() error { | ||||||
|  | 			// get task | ||||||
|  | 			taskReq, err := http.NewRequestWithContext(c.Request.Context(), | ||||||
|  | 				http.MethodGet, respData.URLs.Get, nil) | ||||||
|  | 			if err != nil { | ||||||
|  | 				return errors.Wrap(err, "new request") | ||||||
|  | 			} | ||||||
|  |  | ||||||
|  | 			taskReq.Header.Set("Authorization", "Bearer "+meta.GetByContext(c).APIKey) | ||||||
|  | 			taskResp, err := http.DefaultClient.Do(taskReq) | ||||||
|  | 			if err != nil { | ||||||
|  | 				return errors.Wrap(err, "get task") | ||||||
|  | 			} | ||||||
|  | 			defer taskResp.Body.Close() | ||||||
|  |  | ||||||
|  | 			if taskResp.StatusCode != http.StatusOK { | ||||||
|  | 				payload, _ := io.ReadAll(taskResp.Body) | ||||||
|  | 				return errors.Errorf("bad status code [%d]%s", | ||||||
|  | 					taskResp.StatusCode, string(payload)) | ||||||
|  | 			} | ||||||
|  |  | ||||||
|  | 			taskBody, err := io.ReadAll(taskResp.Body) | ||||||
|  | 			if err != nil { | ||||||
|  | 				return errors.Wrap(err, "read task response") | ||||||
|  | 			} | ||||||
|  |  | ||||||
|  | 			taskData := new(ImageResponse) | ||||||
|  | 			if err = json.Unmarshal(taskBody, taskData); err != nil { | ||||||
|  | 				return errors.Wrap(err, "decode task response") | ||||||
|  | 			} | ||||||
|  |  | ||||||
|  | 			switch taskData.Status { | ||||||
|  | 			case "succeeded": | ||||||
|  | 				// terminal success: fall through to the download phase below | ||||||
|  | 			case "failed", "canceled": | ||||||
|  | 				return errors.Errorf("task failed: %s", taskData.Status) | ||||||
|  | 			default: | ||||||
|  | 				// any other status is treated as "still running" | ||||||
|  | 				time.Sleep(time.Second * 3) | ||||||
|  | 				return errNextLoop | ||||||
|  | 			} | ||||||
|  |  | ||||||
|  | 			output, err := taskData.GetOutput() | ||||||
|  | 			if err != nil { | ||||||
|  | 				return errors.Wrap(err, "get output") | ||||||
|  | 			} | ||||||
|  | 			if len(output) == 0 { | ||||||
|  | 				return errors.New("response output is empty") | ||||||
|  | 			} | ||||||
|  |  | ||||||
|  | 			// mu guards respBody.Data, which the download goroutines append to. | ||||||
|  | 			var mu sync.Mutex | ||||||
|  | 			var pool errgroup.Group | ||||||
|  | 			respBody := &openai.ImageResponse{ | ||||||
|  | 				Created: taskData.CompletedAt.Unix(), | ||||||
|  | 				Data:    []openai.ImageData{}, | ||||||
|  | 			} | ||||||
|  |  | ||||||
|  | 			for _, imgOut := range output { | ||||||
|  | 				imgOut := imgOut // per-iteration copy captured by the goroutine | ||||||
|  | 				pool.Go(func() error { | ||||||
|  | 					// download image | ||||||
|  | 					downloadReq, err := http.NewRequestWithContext(c.Request.Context(), | ||||||
|  | 						http.MethodGet, imgOut, nil) | ||||||
|  | 					if err != nil { | ||||||
|  | 						return errors.Wrap(err, "new request") | ||||||
|  | 					} | ||||||
|  |  | ||||||
|  | 					imgResp, err := http.DefaultClient.Do(downloadReq) | ||||||
|  | 					if err != nil { | ||||||
|  | 						return errors.Wrap(err, "download image") | ||||||
|  | 					} | ||||||
|  | 					defer imgResp.Body.Close() | ||||||
|  |  | ||||||
|  | 					if imgResp.StatusCode != http.StatusOK { | ||||||
|  | 						payload, _ := io.ReadAll(imgResp.Body) | ||||||
|  | 						return errors.Errorf("bad status code [%d]%s", | ||||||
|  | 							imgResp.StatusCode, string(payload)) | ||||||
|  | 					} | ||||||
|  |  | ||||||
|  | 					imgData, err := io.ReadAll(imgResp.Body) | ||||||
|  | 					if err != nil { | ||||||
|  | 						return errors.Wrap(err, "read image") | ||||||
|  | 					} | ||||||
|  |  | ||||||
|  | 					imgData, err = ConvertImageToPNG(imgData) | ||||||
|  | 					if err != nil { | ||||||
|  | 						return errors.Wrap(err, "convert image") | ||||||
|  | 					} | ||||||
|  |  | ||||||
|  | 					// Images are appended in completion order, so the order | ||||||
|  | 					// of Data may differ from the order of output. | ||||||
|  | 					mu.Lock() | ||||||
|  | 					respBody.Data = append(respBody.Data, openai.ImageData{ | ||||||
|  | 						B64Json: fmt.Sprintf("data:image/png;base64,%s", | ||||||
|  | 							base64.StdEncoding.EncodeToString(imgData)), | ||||||
|  | 					}) | ||||||
|  | 					mu.Unlock() | ||||||
|  |  | ||||||
|  | 					return nil | ||||||
|  | 				}) | ||||||
|  | 			} | ||||||
|  |  | ||||||
|  | 			if err := pool.Wait(); err != nil { | ||||||
|  | 				// Fail only when nothing was downloaded; otherwise log the | ||||||
|  | 				// error and return the images that did succeed. | ||||||
|  | 				if len(respBody.Data) == 0 { | ||||||
|  | 					return errors.WithStack(err) | ||||||
|  | 				} | ||||||
|  |  | ||||||
|  | 				logger.Error(c, fmt.Sprintf("some images failed to download: %+v", err)) | ||||||
|  | 			} | ||||||
|  |  | ||||||
|  | 			c.JSON(http.StatusOK, respBody) | ||||||
|  | 			return nil | ||||||
|  | 		}() | ||||||
|  | 		if err != nil { | ||||||
|  | 			if errors.Is(err, errNextLoop) { | ||||||
|  | 				continue | ||||||
|  | 			} | ||||||
|  |  | ||||||
|  | 			return openai.ErrorWrapper(err, "image_task_failed", http.StatusInternalServerError), nil | ||||||
|  | 		} | ||||||
|  |  | ||||||
|  | 		break | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	return nil, nil | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // ConvertImageToPNG returns the given image bytes re-encoded as PNG. | ||||||
|  | // | ||||||
|  | // The input format is sniffed from its leading bytes: | ||||||
|  | //   - PNG signature: the data is returned unchanged; | ||||||
|  | //   - JPEG signature (ff d8 ff): decoded through image.Decode; | ||||||
|  | //   - anything else: decoded as WebP. | ||||||
|  | // | ||||||
|  | // NOTE(review): image.Decode only understands formats whose decoder has been | ||||||
|  | // registered; confirm that image/jpeg is imported somewhere in the package. | ||||||
|  | func ConvertImageToPNG(webpData []byte) ([]byte, error) { | ||||||
|  | 	var ( | ||||||
|  | 		decoded image.Image | ||||||
|  | 		err     error | ||||||
|  | 	) | ||||||
|  |  | ||||||
|  | 	switch { | ||||||
|  | 	case bytes.HasPrefix(webpData, []byte("\x89PNG")): | ||||||
|  | 		// already PNG, nothing to convert | ||||||
|  | 		return webpData, nil | ||||||
|  | 	case bytes.HasPrefix(webpData, []byte("\xff\xd8\xff")): | ||||||
|  | 		// JPEG input | ||||||
|  | 		decoded, _, err = image.Decode(bytes.NewReader(webpData)) | ||||||
|  | 		if err != nil { | ||||||
|  | 			return nil, errors.Wrap(err, "decode jpeg") | ||||||
|  | 		} | ||||||
|  | 	default: | ||||||
|  | 		// everything else is assumed to be WebP | ||||||
|  | 		decoded, err = webp.Decode(bytes.NewReader(webpData)) | ||||||
|  | 		if err != nil { | ||||||
|  | 			return nil, errors.Wrap(err, "decode webp") | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// single shared encode path for both decoded formats | ||||||
|  | 	var encoded bytes.Buffer | ||||||
|  | 	if err = png.Encode(&encoded, decoded); err != nil { | ||||||
|  | 		return nil, errors.Wrap(err, "encode png") | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	return encoded.Bytes(), nil | ||||||
|  | } | ||||||
							
								
								
									
										159
									
								
								relay/adaptor/replicate/model.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										159
									
								
								relay/adaptor/replicate/model.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,159 @@ | |||||||
|  | package replicate | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"time" | ||||||
|  |  | ||||||
|  | 	"github.com/pkg/errors" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | // DrawImageRequest is the request body for drawing an image with a Replicate | ||||||
|  | // flux model; the model parameters are nested under the "input" key. | ||||||
|  | // | ||||||
|  | // https://replicate.com/black-forest-labs/flux-pro?prediction=kg1krwsdf9rg80ch1sgsrgq7h8&output=json | ||||||
|  | type DrawImageRequest struct { | ||||||
|  | 	Input ImageInput `json:"input"` | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // ImageInput is the "input" payload of DrawImageRequest. | ||||||
|  | // | ||||||
|  | // The `binding` tags declare required fields and allowed value ranges; | ||||||
|  | // presumably they are enforced by gin's validator when the struct is bound | ||||||
|  | // (TODO confirm at the call site). | ||||||
|  | // | ||||||
|  | // https://replicate.com/black-forest-labs/flux-1.1-pro/api/schema | ||||||
|  | type ImageInput struct { | ||||||
|  | 	Steps           int    `json:"steps" binding:"required,min=1"` | ||||||
|  | 	Prompt          string `json:"prompt" binding:"required,min=5"` | ||||||
|  | 	ImagePrompt     string `json:"image_prompt"` | ||||||
|  | 	Guidance        int    `json:"guidance" binding:"required,min=2,max=5"` | ||||||
|  | 	Interval        int    `json:"interval" binding:"required,min=1,max=4"` | ||||||
|  | 	AspectRatio     string `json:"aspect_ratio" binding:"required,oneof=1:1 16:9 2:3 3:2 4:5 5:4 9:16"` | ||||||
|  | 	SafetyTolerance int    `json:"safety_tolerance" binding:"required,min=1,max=5"` | ||||||
|  | 	Seed            int    `json:"seed"` | ||||||
|  | 	NImages         int    `json:"n_images" binding:"required,min=1,max=8"` | ||||||
|  | 	Width           int    `json:"width" binding:"required,min=256,max=1440"` | ||||||
|  | 	Height          int    `json:"height" binding:"required,min=256,max=1440"` | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // InpaintingImageByFlusReplicateRequest is the request body for inpainting an | ||||||
|  | // image with flux-fill-pro; the parameters are nested under the "input" key. | ||||||
|  | // | ||||||
|  | // https://replicate.com/black-forest-labs/flux-fill-pro/api/schema | ||||||
|  | type InpaintingImageByFlusReplicateRequest struct { | ||||||
|  | 	Input FluxInpaintingInput `json:"input"` | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // FluxInpaintingInput is the "input" payload of | ||||||
|  | // InpaintingImageByFlusReplicateRequest. | ||||||
|  | // | ||||||
|  | // https://replicate.com/black-forest-labs/flux-fill-pro/api/schema | ||||||
|  | type FluxInpaintingInput struct { | ||||||
|  | 	Mask             string `json:"mask" binding:"required"` | ||||||
|  | 	Image            string `json:"image" binding:"required"` | ||||||
|  | 	Seed             int    `json:"seed"` | ||||||
|  | 	Steps            int    `json:"steps" binding:"required,min=1"` | ||||||
|  | 	Prompt           string `json:"prompt" binding:"required,min=5"` | ||||||
|  | 	Guidance         int    `json:"guidance" binding:"required,min=2,max=5"` | ||||||
|  | 	OutputFormat     string `json:"output_format"` | ||||||
|  | 	SafetyTolerance  int    `json:"safety_tolerance" binding:"required,min=1,max=5"` | ||||||
|  | 	PromptUnsampling bool   `json:"prompt_unsampling"` | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // ImageResponse is the prediction object decoded from Replicate's JSON, both | ||||||
|  | // for the initial create call and for later polls of the task URL. | ||||||
|  | // | ||||||
|  | // https://replicate.com/black-forest-labs/flux-pro?prediction=kg1krwsdf9rg80ch1sgsrgq7h8&output=json | ||||||
|  | type ImageResponse struct { | ||||||
|  | 	CompletedAt time.Time        `json:"completed_at"` | ||||||
|  | 	CreatedAt   time.Time        `json:"created_at"` | ||||||
|  | 	DataRemoved bool             `json:"data_removed"` | ||||||
|  | 	Error       string           `json:"error"` | ||||||
|  | 	ID          string           `json:"id"` | ||||||
|  | 	// NOTE(review): DrawImageRequest itself nests ImageInput under "input", so | ||||||
|  | 	// this field decodes `input.input`; confirm against the response schema. | ||||||
|  | 	Input       DrawImageRequest `json:"input"` | ||||||
|  | 	Logs        string           `json:"logs"` | ||||||
|  | 	Metrics     FluxMetrics      `json:"metrics"` | ||||||
|  | 	// Output could be `string` or `[]string`; use GetOutput to read it as a | ||||||
|  | 	// slice of strings. | ||||||
|  | 	Output    any       `json:"output"` | ||||||
|  | 	StartedAt time.Time `json:"started_at"` | ||||||
|  | 	Status    string    `json:"status"` | ||||||
|  | 	URLs      FluxURLs  `json:"urls"` | ||||||
|  | 	Version   string    `json:"version"` | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // GetOutput normalises the loosely typed Output field into a string slice. | ||||||
|  | // | ||||||
|  | // A nil Output yields (nil, nil). A single string becomes a one-element | ||||||
|  | // slice. A []string is returned as is. A []interface{} is converted element | ||||||
|  | // by element and fails on the first non-string element. Any other dynamic | ||||||
|  | // type is reported as an error. | ||||||
|  | func (r *ImageResponse) GetOutput() ([]string, error) { | ||||||
|  | 	switch out := r.Output.(type) { | ||||||
|  | 	case nil: | ||||||
|  | 		return nil, nil | ||||||
|  | 	case string: | ||||||
|  | 		return []string{out}, nil | ||||||
|  | 	case []string: | ||||||
|  | 		return out, nil | ||||||
|  | 	case []interface{}: | ||||||
|  | 		// a generically decoded JSON array: every element must be a string | ||||||
|  | 		urls := make([]string, 0, len(out)) | ||||||
|  | 		for _, item := range out { | ||||||
|  | 			s, ok := item.(string) | ||||||
|  | 			if !ok { | ||||||
|  | 				return nil, errors.Errorf("unknown output type: [%T]%v", item, item) | ||||||
|  | 			} | ||||||
|  | 			urls = append(urls, s) | ||||||
|  | 		} | ||||||
|  |  | ||||||
|  | 		return urls, nil | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	return nil, errors.Errorf("unknown output type: [%T]%v", r.Output, r.Output) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // FluxMetrics is the "metrics" object of a prediction response; it is used | ||||||
|  | // by both ImageResponse and ChatResponse. | ||||||
|  | type FluxMetrics struct { | ||||||
|  | 	ImageCount  int     `json:"image_count"` | ||||||
|  | 	PredictTime float64 `json:"predict_time"` | ||||||
|  | 	TotalTime   float64 `json:"total_time"` | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // FluxURLs is the "urls" object of ImageResponse: Get is the task URL to | ||||||
|  | // fetch the prediction from, Cancel is decoded from the "cancel" key. | ||||||
|  | type FluxURLs struct { | ||||||
|  | 	Get    string `json:"get"` | ||||||
|  | 	Cancel string `json:"cancel"` | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // ReplicateChatRequest is the request body for a Replicate chat model; the | ||||||
|  | // parameters are nested under the required "input" key. | ||||||
|  | type ReplicateChatRequest struct { | ||||||
|  | 	Input ChatInput `json:"input" form:"input" binding:"required"` | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // ChatInput is the "input" payload of ReplicateChatRequest. | ||||||
|  | // | ||||||
|  | // https://replicate.com/meta/meta-llama-3.1-405b-instruct/api/schema | ||||||
|  | type ChatInput struct { | ||||||
|  | 	TopK             int     `json:"top_k"` | ||||||
|  | 	TopP             float64 `json:"top_p"` | ||||||
|  | 	Prompt           string  `json:"prompt"` | ||||||
|  | 	MaxTokens        int     `json:"max_tokens"` | ||||||
|  | 	MinTokens        int     `json:"min_tokens"` | ||||||
|  | 	Temperature      float64 `json:"temperature"` | ||||||
|  | 	SystemPrompt     string  `json:"system_prompt"` | ||||||
|  | 	StopSequences    string  `json:"stop_sequences"` | ||||||
|  | 	PromptTemplate   string  `json:"prompt_template"` | ||||||
|  | 	PresencePenalty  float64 `json:"presence_penalty"` | ||||||
|  | 	FrequencyPenalty float64 `json:"frequency_penalty"` | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // ChatResponse is the prediction object returned for a ReplicateChatRequest. | ||||||
|  | // | ||||||
|  | // https://replicate.com/meta/meta-llama-3.1-405b-instruct/examples?input=http&output=json | ||||||
|  | type ChatResponse struct { | ||||||
|  | 	CompletedAt time.Time   `json:"completed_at"` | ||||||
|  | 	CreatedAt   time.Time   `json:"created_at"` | ||||||
|  | 	DataRemoved bool        `json:"data_removed"` | ||||||
|  | 	Error       string      `json:"error"` | ||||||
|  | 	ID          string      `json:"id"` | ||||||
|  | 	Input       ChatInput   `json:"input"` | ||||||
|  | 	Logs        string      `json:"logs"` | ||||||
|  | 	Metrics     FluxMetrics `json:"metrics"` | ||||||
|  | 	// Output is decoded as a list of strings. | ||||||
|  | 	// NOTE(review): unlike ImageResponse.Output this field is typed []string, | ||||||
|  | 	// so a bare JSON string would fail to decode; confirm the API never | ||||||
|  | 	// returns one for chat models. | ||||||
|  | 	Output    []string        `json:"output"` | ||||||
|  | 	StartedAt time.Time       `json:"started_at"` | ||||||
|  | 	Status    string          `json:"status"` | ||||||
|  | 	URLs      ChatResponseUrl `json:"urls"` | ||||||
|  | 	Version   string          `json:"version"` | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // ChatResponseUrl is the "urls" object of ChatResponse, decoded from the | ||||||
|  | // "stream", "get" and "cancel" keys. | ||||||
|  | type ChatResponseUrl struct { | ||||||
|  | 	Stream string `json:"stream"` | ||||||
|  | 	Get    string `json:"get"` | ||||||
|  | 	Cancel string `json:"cancel"` | ||||||
|  | } | ||||||
| @@ -15,7 +15,10 @@ import ( | |||||||
| ) | ) | ||||||
|  |  | ||||||
| var ModelList = []string{ | var ModelList = []string{ | ||||||
| 	"gemini-1.5-pro-001", "gemini-1.5-flash-001", "gemini-pro", "gemini-pro-vision", "gemini-1.5-pro-002", "gemini-1.5-flash-002",  | 	"gemini-pro", "gemini-pro-vision", | ||||||
|  | 	"gemini-1.5-pro-001", "gemini-1.5-flash-001", | ||||||
|  | 	"gemini-1.5-pro-002", "gemini-1.5-flash-002", | ||||||
|  | 	"gemini-2.0-flash-exp", "gemini-2.0-flash-thinking-exp", | ||||||
| } | } | ||||||
|  |  | ||||||
| type Adaptor struct { | type Adaptor struct { | ||||||
|   | |||||||
| @@ -19,6 +19,7 @@ const ( | |||||||
| 	DeepL | 	DeepL | ||||||
| 	VertexAI | 	VertexAI | ||||||
| 	Proxy | 	Proxy | ||||||
|  | 	Replicate | ||||||
|  |  | ||||||
| 	Dummy // this one is only for count, do not add any channel after this | 	Dummy // this one is only for count, do not add any channel after this | ||||||
| ) | ) | ||||||
|   | |||||||
| @@ -37,6 +37,7 @@ var ModelRatio = map[string]float64{ | |||||||
| 	"chatgpt-4o-latest":       2.5,   // $0.005 / 1K tokens | 	"chatgpt-4o-latest":       2.5,   // $0.005 / 1K tokens | ||||||
| 	"gpt-4o-2024-05-13":       2.5,   // $0.005 / 1K tokens | 	"gpt-4o-2024-05-13":       2.5,   // $0.005 / 1K tokens | ||||||
| 	"gpt-4o-2024-08-06":       1.25,  // $0.0025 / 1K tokens | 	"gpt-4o-2024-08-06":       1.25,  // $0.0025 / 1K tokens | ||||||
|  | 	"gpt-4o-2024-11-20":       1.25,  // $0.0025 / 1K tokens | ||||||
| 	"gpt-4o-mini":             0.075, // $0.00015 / 1K tokens | 	"gpt-4o-mini":             0.075, // $0.00015 / 1K tokens | ||||||
| 	"gpt-4o-mini-2024-07-18":  0.075, // $0.00015 / 1K tokens | 	"gpt-4o-mini-2024-07-18":  0.075, // $0.00015 / 1K tokens | ||||||
| 	"gpt-4-vision-preview":    5,     // $0.01 / 1K tokens | 	"gpt-4-vision-preview":    5,     // $0.01 / 1K tokens | ||||||
| @@ -48,8 +49,14 @@ var ModelRatio = map[string]float64{ | |||||||
| 	"gpt-3.5-turbo-instruct":  0.75, // $0.0015 / 1K tokens | 	"gpt-3.5-turbo-instruct":  0.75, // $0.0015 / 1K tokens | ||||||
| 	"gpt-3.5-turbo-1106":      0.5,  // $0.001 / 1K tokens | 	"gpt-3.5-turbo-1106":      0.5,  // $0.001 / 1K tokens | ||||||
| 	"gpt-3.5-turbo-0125":      0.25, // $0.0005 / 1K tokens | 	"gpt-3.5-turbo-0125":      0.25, // $0.0005 / 1K tokens | ||||||
| 	"davinci-002":             1,    // $0.002 / 1K tokens | 	"o1":                      7.5,  // $15.00 / 1M input tokens | ||||||
| 	"babbage-002":             0.2,  // $0.0004 / 1K tokens | 	"o1-2024-12-17":           7.5, | ||||||
|  | 	"o1-preview":              7.5, // $15.00 / 1M input tokens | ||||||
|  | 	"o1-preview-2024-09-12":   7.5, | ||||||
|  | 	"o1-mini":                 1.5, // $3.00 / 1M input tokens | ||||||
|  | 	"o1-mini-2024-09-12":      1.5, | ||||||
|  | 	"davinci-002":             1,   // $0.002 / 1K tokens | ||||||
|  | 	"babbage-002":             0.2, // $0.0004 / 1K tokens | ||||||
| 	"text-ada-001":            0.2, | 	"text-ada-001":            0.2, | ||||||
| 	"text-babbage-001":        0.25, | 	"text-babbage-001":        0.25, | ||||||
| 	"text-curie-001":          1, | 	"text-curie-001":          1, | ||||||
| @@ -102,11 +109,15 @@ var ModelRatio = map[string]float64{ | |||||||
| 	"bge-large-en":       0.002 * RMB, | 	"bge-large-en":       0.002 * RMB, | ||||||
| 	"tao-8k":             0.002 * RMB, | 	"tao-8k":             0.002 * RMB, | ||||||
| 	// https://ai.google.dev/pricing | 	// https://ai.google.dev/pricing | ||||||
| 	"gemini-pro":       1, // $0.00025 / 1k characters -> $0.001 / 1k tokens | 	"gemini-pro":                    1, // $0.00025 / 1k characters -> $0.001 / 1k tokens | ||||||
| 	"gemini-1.0-pro":   1, | 	"gemini-1.0-pro":                1, | ||||||
| 	"gemini-1.5-flash": 1, | 	"gemini-1.5-pro":                1, | ||||||
| 	"gemini-1.5-pro":   1, | 	"gemini-1.5-pro-001":            1, | ||||||
| 	"aqa":              1, | 	"gemini-1.5-flash":              1, | ||||||
|  | 	"gemini-1.5-flash-001":          1, | ||||||
|  | 	"gemini-2.0-flash-exp":          1, | ||||||
|  | 	"gemini-2.0-flash-thinking-exp": 1, | ||||||
|  | 	"aqa":                           1, | ||||||
| 	// https://open.bigmodel.cn/pricing | 	// https://open.bigmodel.cn/pricing | ||||||
| 	"glm-4":         0.1 * RMB, | 	"glm-4":         0.1 * RMB, | ||||||
| 	"glm-4v":        0.1 * RMB, | 	"glm-4v":        0.1 * RMB, | ||||||
| @@ -118,29 +129,94 @@ var ModelRatio = map[string]float64{ | |||||||
| 	"chatglm_lite":  0.1429, // ¥0.002 / 1k tokens | 	"chatglm_lite":  0.1429, // ¥0.002 / 1k tokens | ||||||
| 	"cogview-3":     0.25 * RMB, | 	"cogview-3":     0.25 * RMB, | ||||||
| 	// https://help.aliyun.com/zh/dashscope/developer-reference/tongyi-thousand-questions-metering-and-billing | 	// https://help.aliyun.com/zh/dashscope/developer-reference/tongyi-thousand-questions-metering-and-billing | ||||||
| 	"qwen-turbo":                0.5715, // ¥0.008 / 1k tokens | 	"qwen-turbo":                  1.4286, // ¥0.02 / 1k tokens | ||||||
| 	"qwen-plus":                 1.4286, // ¥0.02 / 1k tokens | 	"qwen-turbo-latest":           1.4286, | ||||||
| 	"qwen-max":                  1.4286, // ¥0.02 / 1k tokens | 	"qwen-plus":                   1.4286, | ||||||
| 	"qwen-max-longcontext":      1.4286, // ¥0.02 / 1k tokens | 	"qwen-plus-latest":            1.4286, | ||||||
| 	"text-embedding-v1":         0.05,   // ¥0.0007 / 1k tokens | 	"qwen-max":                    1.4286, | ||||||
| 	"ali-stable-diffusion-xl":   8, | 	"qwen-max-latest":             1.4286, | ||||||
| 	"ali-stable-diffusion-v1.5": 8, | 	"qwen-max-longcontext":        1.4286, | ||||||
| 	"wanx-v1":                   8, | 	"qwen-vl-max":                 1.4286, | ||||||
| 	"SparkDesk":                 1.2858, // ¥0.018 / 1k tokens | 	"qwen-vl-max-latest":          1.4286, | ||||||
| 	"SparkDesk-v1.1":            1.2858, // ¥0.018 / 1k tokens | 	"qwen-vl-plus":                1.4286, | ||||||
| 	"SparkDesk-v2.1":            1.2858, // ¥0.018 / 1k tokens | 	"qwen-vl-plus-latest":         1.4286, | ||||||
| 	"SparkDesk-v3.1":            1.2858, // ¥0.018 / 1k tokens | 	"qwen-vl-ocr":                 1.4286, | ||||||
| 	"SparkDesk-v3.1-128K":       1.2858, // ¥0.018 / 1k tokens | 	"qwen-vl-ocr-latest":          1.4286, | ||||||
| 	"SparkDesk-v3.5":            1.2858, // ¥0.018 / 1k tokens | 	"qwen-audio-turbo":            1.4286, | ||||||
| 	"SparkDesk-v3.5-32K":        1.2858, // ¥0.018 / 1k tokens | 	"qwen-math-plus":              1.4286, | ||||||
| 	"SparkDesk-v4.0":            1.2858, // ¥0.018 / 1k tokens | 	"qwen-math-plus-latest":       1.4286, | ||||||
| 	"360GPT_S2_V9":              0.8572, // ¥0.012 / 1k tokens | 	"qwen-math-turbo":             1.4286, | ||||||
| 	"embedding-bert-512-v1":     0.0715, // ¥0.001 / 1k tokens | 	"qwen-math-turbo-latest":      1.4286, | ||||||
| 	"embedding_s1_v1":           0.0715, // ¥0.001 / 1k tokens | 	"qwen-coder-plus":             1.4286, | ||||||
| 	"semantic_similarity_s1_v1": 0.0715, // ¥0.001 / 1k tokens | 	"qwen-coder-plus-latest":      1.4286, | ||||||
| 	"hunyuan":                   7.143,  // ¥0.1 / 1k tokens  // https://cloud.tencent.com/document/product/1729/97731#e0e6be58-60c8-469f-bdeb-6c264ce3b4d0 | 	"qwen-coder-turbo":            1.4286, | ||||||
| 	"ChatStd":                   0.01 * RMB, | 	"qwen-coder-turbo-latest":     1.4286, | ||||||
| 	"ChatPro":                   0.1 * RMB, | 	"qwq-32b-preview":             1.4286, | ||||||
|  | 	"qwen2.5-72b-instruct":        1.4286, | ||||||
|  | 	"qwen2.5-32b-instruct":        1.4286, | ||||||
|  | 	"qwen2.5-14b-instruct":        1.4286, | ||||||
|  | 	"qwen2.5-7b-instruct":         1.4286, | ||||||
|  | 	"qwen2.5-3b-instruct":         1.4286, | ||||||
|  | 	"qwen2.5-1.5b-instruct":       1.4286, | ||||||
|  | 	"qwen2.5-0.5b-instruct":       1.4286, | ||||||
|  | 	"qwen2-72b-instruct":          1.4286, | ||||||
|  | 	"qwen2-57b-a14b-instruct":     1.4286, | ||||||
|  | 	"qwen2-7b-instruct":           1.4286, | ||||||
|  | 	"qwen2-1.5b-instruct":         1.4286, | ||||||
|  | 	"qwen2-0.5b-instruct":         1.4286, | ||||||
|  | 	"qwen1.5-110b-chat":           1.4286, | ||||||
|  | 	"qwen1.5-72b-chat":            1.4286, | ||||||
|  | 	"qwen1.5-32b-chat":            1.4286, | ||||||
|  | 	"qwen1.5-14b-chat":            1.4286, | ||||||
|  | 	"qwen1.5-7b-chat":             1.4286, | ||||||
|  | 	"qwen1.5-1.8b-chat":           1.4286, | ||||||
|  | 	"qwen1.5-0.5b-chat":           1.4286, | ||||||
|  | 	"qwen-72b-chat":               1.4286, | ||||||
|  | 	"qwen-14b-chat":               1.4286, | ||||||
|  | 	"qwen-7b-chat":                1.4286, | ||||||
|  | 	"qwen-1.8b-chat":              1.4286, | ||||||
|  | 	"qwen-1.8b-longcontext-chat":  1.4286, | ||||||
|  | 	"qwen2-vl-7b-instruct":        1.4286, | ||||||
|  | 	"qwen2-vl-2b-instruct":        1.4286, | ||||||
|  | 	"qwen-vl-v1":                  1.4286, | ||||||
|  | 	"qwen-vl-chat-v1":             1.4286, | ||||||
|  | 	"qwen2-audio-instruct":        1.4286, | ||||||
|  | 	"qwen-audio-chat":             1.4286, | ||||||
|  | 	"qwen2.5-math-72b-instruct":   1.4286, | ||||||
|  | 	"qwen2.5-math-7b-instruct":    1.4286, | ||||||
|  | 	"qwen2.5-math-1.5b-instruct":  1.4286, | ||||||
|  | 	"qwen2-math-72b-instruct":     1.4286, | ||||||
|  | 	"qwen2-math-7b-instruct":      1.4286, | ||||||
|  | 	"qwen2-math-1.5b-instruct":    1.4286, | ||||||
|  | 	"qwen2.5-coder-32b-instruct":  1.4286, | ||||||
|  | 	"qwen2.5-coder-14b-instruct":  1.4286, | ||||||
|  | 	"qwen2.5-coder-7b-instruct":   1.4286, | ||||||
|  | 	"qwen2.5-coder-3b-instruct":   1.4286, | ||||||
|  | 	"qwen2.5-coder-1.5b-instruct": 1.4286, | ||||||
|  | 	"qwen2.5-coder-0.5b-instruct": 1.4286, | ||||||
|  | 	"text-embedding-v1":           0.05, // ¥0.0007 / 1k tokens | ||||||
|  | 	"text-embedding-v3":           0.05, | ||||||
|  | 	"text-embedding-v2":           0.05, | ||||||
|  | 	"text-embedding-async-v2":     0.05, | ||||||
|  | 	"text-embedding-async-v1":     0.05, | ||||||
|  | 	"ali-stable-diffusion-xl":     8.00, | ||||||
|  | 	"ali-stable-diffusion-v1.5":   8.00, | ||||||
|  | 	"wanx-v1":                     8.00, | ||||||
|  | 	"SparkDesk":                   1.2858, // ¥0.018 / 1k tokens | ||||||
|  | 	"SparkDesk-v1.1":              1.2858, // ¥0.018 / 1k tokens | ||||||
|  | 	"SparkDesk-v2.1":              1.2858, // ¥0.018 / 1k tokens | ||||||
|  | 	"SparkDesk-v3.1":              1.2858, // ¥0.018 / 1k tokens | ||||||
|  | 	"SparkDesk-v3.1-128K":         1.2858, // ¥0.018 / 1k tokens | ||||||
|  | 	"SparkDesk-v3.5":              1.2858, // ¥0.018 / 1k tokens | ||||||
|  | 	"SparkDesk-v3.5-32K":          1.2858, // ¥0.018 / 1k tokens | ||||||
|  | 	"SparkDesk-v4.0":              1.2858, // ¥0.018 / 1k tokens | ||||||
|  | 	"360GPT_S2_V9":                0.8572, // ¥0.012 / 1k tokens | ||||||
|  | 	"embedding-bert-512-v1":       0.0715, // ¥0.001 / 1k tokens | ||||||
|  | 	"embedding_s1_v1":             0.0715, // ¥0.001 / 1k tokens | ||||||
|  | 	"semantic_similarity_s1_v1":   0.0715, // ¥0.001 / 1k tokens | ||||||
|  | 	"hunyuan":                     7.143,  // ¥0.1 / 1k tokens  // https://cloud.tencent.com/document/product/1729/97731#e0e6be58-60c8-469f-bdeb-6c264ce3b4d0 | ||||||
|  | 	"ChatStd":                     0.01 * RMB, | ||||||
|  | 	"ChatPro":                     0.1 * RMB, | ||||||
| 	// https://platform.moonshot.cn/pricing | 	// https://platform.moonshot.cn/pricing | ||||||
| 	"moonshot-v1-8k":   0.012 * RMB, | 	"moonshot-v1-8k":   0.012 * RMB, | ||||||
| 	"moonshot-v1-32k":  0.024 * RMB, | 	"moonshot-v1-32k":  0.024 * RMB, | ||||||
| @@ -211,6 +287,50 @@ var ModelRatio = map[string]float64{ | |||||||
| 	"deepl-ja": 25.0 / 1000 * USD, | 	"deepl-ja": 25.0 / 1000 * USD, | ||||||
| 	// https://console.x.ai/ | 	// https://console.x.ai/ | ||||||
| 	"grok-beta": 5.0 / 1000 * USD, | 	"grok-beta": 5.0 / 1000 * USD, | ||||||
|  | 	// replicate charges based on the number of generated images | ||||||
|  | 	// https://replicate.com/pricing | ||||||
|  | 	"black-forest-labs/flux-1.1-pro":                0.04 * USD, | ||||||
|  | 	"black-forest-labs/flux-1.1-pro-ultra":          0.06 * USD, | ||||||
|  | 	"black-forest-labs/flux-canny-dev":              0.025 * USD, | ||||||
|  | 	"black-forest-labs/flux-canny-pro":              0.05 * USD, | ||||||
|  | 	"black-forest-labs/flux-depth-dev":              0.025 * USD, | ||||||
|  | 	"black-forest-labs/flux-depth-pro":              0.05 * USD, | ||||||
|  | 	"black-forest-labs/flux-dev":                    0.025 * USD, | ||||||
|  | 	"black-forest-labs/flux-dev-lora":               0.032 * USD, | ||||||
|  | 	"black-forest-labs/flux-fill-dev":               0.04 * USD, | ||||||
|  | 	"black-forest-labs/flux-fill-pro":               0.05 * USD, | ||||||
|  | 	"black-forest-labs/flux-pro":                    0.055 * USD, | ||||||
|  | 	"black-forest-labs/flux-redux-dev":              0.025 * USD, | ||||||
|  | 	"black-forest-labs/flux-redux-schnell":          0.003 * USD, | ||||||
|  | 	"black-forest-labs/flux-schnell":                0.003 * USD, | ||||||
|  | 	"black-forest-labs/flux-schnell-lora":           0.02 * USD, | ||||||
|  | 	"ideogram-ai/ideogram-v2":                       0.08 * USD, | ||||||
|  | 	"ideogram-ai/ideogram-v2-turbo":                 0.05 * USD, | ||||||
|  | 	"recraft-ai/recraft-v3":                         0.04 * USD, | ||||||
|  | 	"recraft-ai/recraft-v3-svg":                     0.08 * USD, | ||||||
|  | 	"stability-ai/stable-diffusion-3":               0.035 * USD, | ||||||
|  | 	"stability-ai/stable-diffusion-3.5-large":       0.065 * USD, | ||||||
|  | 	"stability-ai/stable-diffusion-3.5-large-turbo": 0.04 * USD, | ||||||
|  | 	"stability-ai/stable-diffusion-3.5-medium":      0.035 * USD, | ||||||
|  | 	// replicate chat models | ||||||
|  | 	"ibm-granite/granite-20b-code-instruct-8k":  0.100 * USD, | ||||||
|  | 	"ibm-granite/granite-3.0-2b-instruct":       0.030 * USD, | ||||||
|  | 	"ibm-granite/granite-3.0-8b-instruct":       0.050 * USD, | ||||||
|  | 	"ibm-granite/granite-8b-code-instruct-128k": 0.050 * USD, | ||||||
|  | 	"meta/llama-2-13b":                          0.100 * USD, | ||||||
|  | 	"meta/llama-2-13b-chat":                     0.100 * USD, | ||||||
|  | 	"meta/llama-2-70b":                          0.650 * USD, | ||||||
|  | 	"meta/llama-2-70b-chat":                     0.650 * USD, | ||||||
|  | 	"meta/llama-2-7b":                           0.050 * USD, | ||||||
|  | 	"meta/llama-2-7b-chat":                      0.050 * USD, | ||||||
|  | 	"meta/meta-llama-3.1-405b-instruct":         9.500 * USD, | ||||||
|  | 	"meta/meta-llama-3-70b":                     0.650 * USD, | ||||||
|  | 	"meta/meta-llama-3-70b-instruct":            0.650 * USD, | ||||||
|  | 	"meta/meta-llama-3-8b":                      0.050 * USD, | ||||||
|  | 	"meta/meta-llama-3-8b-instruct":             0.050 * USD, | ||||||
|  | 	"mistralai/mistral-7b-instruct-v0.2":        0.050 * USD, | ||||||
|  | 	"mistralai/mistral-7b-v0.1":                 0.050 * USD, | ||||||
|  | 	"mistralai/mixtral-8x7b-instruct-v0.1":      0.300 * USD, | ||||||
| } | } | ||||||
|  |  | ||||||
| var CompletionRatio = map[string]float64{ | var CompletionRatio = map[string]float64{ | ||||||
| @@ -334,16 +454,22 @@ func GetCompletionRatio(name string, channelType int) float64 { | |||||||
| 		return 4.0 / 3.0 | 		return 4.0 / 3.0 | ||||||
| 	} | 	} | ||||||
| 	if strings.HasPrefix(name, "gpt-4") { | 	if strings.HasPrefix(name, "gpt-4") { | ||||||
| 		if strings.HasPrefix(name, "gpt-4o-mini") || name == "gpt-4o-2024-08-06" { | 		if strings.HasPrefix(name, "gpt-4o") { | ||||||
|  | 			if name == "gpt-4o-2024-05-13" { | ||||||
|  | 				return 3 | ||||||
|  | 			} | ||||||
| 			return 4 | 			return 4 | ||||||
| 		} | 		} | ||||||
| 		if strings.HasPrefix(name, "gpt-4-turbo") || | 		if strings.HasPrefix(name, "gpt-4-turbo") || | ||||||
| 			strings.HasPrefix(name, "gpt-4o") || |  | ||||||
| 			strings.HasSuffix(name, "preview") { | 			strings.HasSuffix(name, "preview") { | ||||||
| 			return 3 | 			return 3 | ||||||
| 		} | 		} | ||||||
| 		return 2 | 		return 2 | ||||||
| 	} | 	} | ||||||
|  | 	// including o1, o1-preview, o1-mini | ||||||
|  | 	if strings.HasPrefix(name, "o1") { | ||||||
|  | 		return 4 | ||||||
|  | 	} | ||||||
| 	if name == "chatgpt-4o-latest" { | 	if name == "chatgpt-4o-latest" { | ||||||
| 		return 3 | 		return 3 | ||||||
| 	} | 	} | ||||||
| @@ -362,6 +488,7 @@ func GetCompletionRatio(name string, channelType int) float64 { | |||||||
| 	if strings.HasPrefix(name, "deepseek-") { | 	if strings.HasPrefix(name, "deepseek-") { | ||||||
| 		return 2 | 		return 2 | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	switch name { | 	switch name { | ||||||
| 	case "llama2-70b-4096": | 	case "llama2-70b-4096": | ||||||
| 		return 0.8 / 0.64 | 		return 0.8 / 0.64 | ||||||
| @@ -377,6 +504,35 @@ func GetCompletionRatio(name string, channelType int) float64 { | |||||||
| 		return 5 | 		return 5 | ||||||
| 	case "grok-beta": | 	case "grok-beta": | ||||||
| 		return 3 | 		return 3 | ||||||
|  | 	// Replicate Models | ||||||
|  | 	// https://replicate.com/pricing | ||||||
|  | 	case "ibm-granite/granite-20b-code-instruct-8k": | ||||||
|  | 		return 5 | ||||||
|  | 	case "ibm-granite/granite-3.0-2b-instruct": | ||||||
|  | 		return 8.333333333333334 | ||||||
|  | 	case "ibm-granite/granite-3.0-8b-instruct", | ||||||
|  | 		"ibm-granite/granite-8b-code-instruct-128k": | ||||||
|  | 		return 5 | ||||||
|  | 	case "meta/llama-2-13b", | ||||||
|  | 		"meta/llama-2-13b-chat", | ||||||
|  | 		"meta/llama-2-7b", | ||||||
|  | 		"meta/llama-2-7b-chat", | ||||||
|  | 		"meta/meta-llama-3-8b", | ||||||
|  | 		"meta/meta-llama-3-8b-instruct": | ||||||
|  | 		return 5 | ||||||
|  | 	case "meta/llama-2-70b", | ||||||
|  | 		"meta/llama-2-70b-chat", | ||||||
|  | 		"meta/meta-llama-3-70b", | ||||||
|  | 		"meta/meta-llama-3-70b-instruct": | ||||||
|  | 		return 2.750 / 0.650 // ≈4.230769 | ||||||
|  | 	case "meta/meta-llama-3.1-405b-instruct": | ||||||
|  | 		return 1 | ||||||
|  | 	case "mistralai/mistral-7b-instruct-v0.2", | ||||||
|  | 		"mistralai/mistral-7b-v0.1": | ||||||
|  | 		return 5 | ||||||
|  | 	case "mistralai/mixtral-8x7b-instruct-v0.1": | ||||||
|  | 		return 1.000 / 0.300 // ≈3.333333 | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	return 1 | 	return 1 | ||||||
| } | } | ||||||
|   | |||||||
| @@ -47,5 +47,6 @@ const ( | |||||||
| 	Proxy | 	Proxy | ||||||
| 	SiliconFlow | 	SiliconFlow | ||||||
| 	XAI | 	XAI | ||||||
|  | 	Replicate | ||||||
| 	Dummy | 	Dummy | ||||||
| ) | ) | ||||||
|   | |||||||
| @@ -37,6 +37,8 @@ func ToAPIType(channelType int) int { | |||||||
| 		apiType = apitype.DeepL | 		apiType = apitype.DeepL | ||||||
| 	case VertextAI: | 	case VertextAI: | ||||||
| 		apiType = apitype.VertexAI | 		apiType = apitype.VertexAI | ||||||
|  | 	case Replicate: | ||||||
|  | 		apiType = apitype.Replicate | ||||||
| 	case Proxy: | 	case Proxy: | ||||||
| 		apiType = apitype.Proxy | 		apiType = apitype.Proxy | ||||||
| 	} | 	} | ||||||
|   | |||||||
| @@ -47,6 +47,7 @@ var ChannelBaseURLs = []string{ | |||||||
| 	"",                                          // 43 | 	"",                                          // 43 | ||||||
| 	"https://api.siliconflow.cn",                // 44 | 	"https://api.siliconflow.cn",                // 44 | ||||||
| 	"https://api.x.ai",                          // 45 | 	"https://api.x.ai",                          // 45 | ||||||
|  | 	"https://api.replicate.com/v1/models/",      // 46 | ||||||
| } | } | ||||||
|  |  | ||||||
| func init() { | func init() { | ||||||
|   | |||||||
| @@ -1,5 +1,6 @@ | |||||||
| package role | package role | ||||||
|  |  | ||||||
| const ( | const ( | ||||||
|  | 	System    = "system" | ||||||
| 	Assistant = "assistant" | 	Assistant = "assistant" | ||||||
| ) | ) | ||||||
|   | |||||||
| @@ -110,16 +110,9 @@ func RelayAudioHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus | |||||||
| 	}() | 	}() | ||||||
|  |  | ||||||
| 	// map model name | 	// map model name | ||||||
| 	modelMapping := c.GetString(ctxkey.ModelMapping) | 	modelMapping := c.GetStringMapString(ctxkey.ModelMapping) | ||||||
| 	if modelMapping != "" { | 	if modelMapping != nil && modelMapping[audioModel] != "" { | ||||||
| 		modelMap := make(map[string]string) | 		audioModel = modelMapping[audioModel] | ||||||
| 		err := json.Unmarshal([]byte(modelMapping), &modelMap) |  | ||||||
| 		if err != nil { |  | ||||||
| 			return openai.ErrorWrapper(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError) |  | ||||||
| 		} |  | ||||||
| 		if modelMap[audioModel] != "" { |  | ||||||
| 			audioModel = modelMap[audioModel] |  | ||||||
| 		} |  | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	baseURL := channeltype.ChannelBaseURLs[channelType] | 	baseURL := channeltype.ChannelBaseURLs[channelType] | ||||||
|   | |||||||
| @@ -4,6 +4,7 @@ import ( | |||||||
| 	"context" | 	"context" | ||||||
| 	"errors" | 	"errors" | ||||||
| 	"fmt" | 	"fmt" | ||||||
|  | 	"github.com/songquanpeng/one-api/relay/constant/role" | ||||||
| 	"math" | 	"math" | ||||||
| 	"net/http" | 	"net/http" | ||||||
| 	"strings" | 	"strings" | ||||||
| @@ -90,7 +91,7 @@ func preConsumeQuota(ctx context.Context, textRequest *relaymodel.GeneralOpenAIR | |||||||
| 	return preConsumedQuota, nil | 	return preConsumedQuota, nil | ||||||
| } | } | ||||||
|  |  | ||||||
| func postConsumeQuota(ctx context.Context, usage *relaymodel.Usage, meta *meta.Meta, textRequest *relaymodel.GeneralOpenAIRequest, ratio float64, preConsumedQuota int64, modelRatio float64, groupRatio float64) { | func postConsumeQuota(ctx context.Context, usage *relaymodel.Usage, meta *meta.Meta, textRequest *relaymodel.GeneralOpenAIRequest, ratio float64, preConsumedQuota int64, modelRatio float64, groupRatio float64, systemPromptReset bool) { | ||||||
| 	if usage == nil { | 	if usage == nil { | ||||||
| 		logger.Error(ctx, "usage is nil, which is unexpected") | 		logger.Error(ctx, "usage is nil, which is unexpected") | ||||||
| 		return | 		return | ||||||
| @@ -118,7 +119,11 @@ func postConsumeQuota(ctx context.Context, usage *relaymodel.Usage, meta *meta.M | |||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		logger.Error(ctx, "error update user quota cache: "+err.Error()) | 		logger.Error(ctx, "error update user quota cache: "+err.Error()) | ||||||
| 	} | 	} | ||||||
| 	logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f,补全倍率 %.2f", modelRatio, groupRatio, completionRatio) | 	var extraLog string | ||||||
|  | 	if systemPromptReset { | ||||||
|  | 		extraLog = " (注意系统提示词已被重置)" | ||||||
|  | 	} | ||||||
|  | 	logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f,补全倍率 %.2f%s", modelRatio, groupRatio, completionRatio, extraLog) | ||||||
| 	model.RecordConsumeLog(ctx, meta.UserId, meta.ChannelId, promptTokens, completionTokens, textRequest.Model, meta.TokenName, quota, logContent) | 	model.RecordConsumeLog(ctx, meta.UserId, meta.ChannelId, promptTokens, completionTokens, textRequest.Model, meta.TokenName, quota, logContent) | ||||||
| 	model.UpdateUserUsedQuotaAndRequestCount(meta.UserId, quota) | 	model.UpdateUserUsedQuotaAndRequestCount(meta.UserId, quota) | ||||||
| 	model.UpdateChannelUsedQuota(meta.ChannelId, quota) | 	model.UpdateChannelUsedQuota(meta.ChannelId, quota) | ||||||
| @@ -142,15 +147,41 @@ func isErrorHappened(meta *meta.Meta, resp *http.Response) bool { | |||||||
| 		} | 		} | ||||||
| 		return true | 		return true | ||||||
| 	} | 	} | ||||||
| 	if resp.StatusCode != http.StatusOK { | 	if resp.StatusCode != http.StatusOK && | ||||||
|  | 		// replicate return 201 to create a task | ||||||
|  | 		resp.StatusCode != http.StatusCreated { | ||||||
| 		return true | 		return true | ||||||
| 	} | 	} | ||||||
| 	if meta.ChannelType == channeltype.DeepL { | 	if meta.ChannelType == channeltype.DeepL { | ||||||
| 		// skip stream check for deepl | 		// skip stream check for deepl | ||||||
| 		return false | 		return false | ||||||
| 	} | 	} | ||||||
| 	if meta.IsStream && strings.HasPrefix(resp.Header.Get("Content-Type"), "application/json") { |  | ||||||
|  | 	if meta.IsStream && strings.HasPrefix(resp.Header.Get("Content-Type"), "application/json") && | ||||||
|  | 		// Even if stream mode is enabled, replicate will first return a task info in JSON format, | ||||||
|  | 		// requiring the client to request the stream endpoint in the task info | ||||||
|  | 		meta.ChannelType != channeltype.Replicate { | ||||||
| 		return true | 		return true | ||||||
| 	} | 	} | ||||||
| 	return false | 	return false | ||||||
| } | } | ||||||
|  |  | ||||||
|  | func setSystemPrompt(ctx context.Context, request *relaymodel.GeneralOpenAIRequest, prompt string) (reset bool) { | ||||||
|  | 	if prompt == "" { | ||||||
|  | 		return false | ||||||
|  | 	} | ||||||
|  | 	if len(request.Messages) == 0 { | ||||||
|  | 		return false | ||||||
|  | 	} | ||||||
|  | 	if request.Messages[0].Role == role.System { | ||||||
|  | 		request.Messages[0].Content = prompt | ||||||
|  | 		logger.Infof(ctx, "rewrite system prompt") | ||||||
|  | 		return true | ||||||
|  | 	} | ||||||
|  | 	request.Messages = append([]relaymodel.Message{{ | ||||||
|  | 		Role:    role.System, | ||||||
|  | 		Content: prompt, | ||||||
|  | 	}}, request.Messages...) | ||||||
|  | 	logger.Infof(ctx, "add system prompt") | ||||||
|  | 	return true | ||||||
|  | } | ||||||
|   | |||||||
| @@ -22,7 +22,7 @@ import ( | |||||||
| 	relaymodel "github.com/songquanpeng/one-api/relay/model" | 	relaymodel "github.com/songquanpeng/one-api/relay/model" | ||||||
| ) | ) | ||||||
|  |  | ||||||
| func getImageRequest(c *gin.Context, relayMode int) (*relaymodel.ImageRequest, error) { | func getImageRequest(c *gin.Context, _ int) (*relaymodel.ImageRequest, error) { | ||||||
| 	imageRequest := &relaymodel.ImageRequest{} | 	imageRequest := &relaymodel.ImageRequest{} | ||||||
| 	err := common.UnmarshalBodyReusable(c, imageRequest) | 	err := common.UnmarshalBodyReusable(c, imageRequest) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| @@ -65,7 +65,7 @@ func getImageSizeRatio(model string, size string) float64 { | |||||||
| 	return 1 | 	return 1 | ||||||
| } | } | ||||||
|  |  | ||||||
| func validateImageRequest(imageRequest *relaymodel.ImageRequest, meta *meta.Meta) *relaymodel.ErrorWithStatusCode { | func validateImageRequest(imageRequest *relaymodel.ImageRequest, _ *meta.Meta) *relaymodel.ErrorWithStatusCode { | ||||||
| 	// check prompt length | 	// check prompt length | ||||||
| 	if imageRequest.Prompt == "" { | 	if imageRequest.Prompt == "" { | ||||||
| 		return openai.ErrorWrapper(errors.New("prompt is required"), "prompt_missing", http.StatusBadRequest) | 		return openai.ErrorWrapper(errors.New("prompt is required"), "prompt_missing", http.StatusBadRequest) | ||||||
| @@ -150,12 +150,12 @@ func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus | |||||||
| 	} | 	} | ||||||
| 	adaptor.Init(meta) | 	adaptor.Init(meta) | ||||||
|  |  | ||||||
|  | 	// these adaptors need to convert the request | ||||||
| 	switch meta.ChannelType { | 	switch meta.ChannelType { | ||||||
| 	case channeltype.Ali: | 	case channeltype.Zhipu, | ||||||
| 		fallthrough | 		channeltype.Ali, | ||||||
| 	case channeltype.Baidu: | 		channeltype.Replicate, | ||||||
| 		fallthrough | 		channeltype.Baidu: | ||||||
| 	case channeltype.Zhipu: |  | ||||||
| 		finalRequest, err := adaptor.ConvertImageRequest(imageRequest) | 		finalRequest, err := adaptor.ConvertImageRequest(imageRequest) | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			return openai.ErrorWrapper(err, "convert_image_request_failed", http.StatusInternalServerError) | 			return openai.ErrorWrapper(err, "convert_image_request_failed", http.StatusInternalServerError) | ||||||
| @@ -172,7 +172,14 @@ func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus | |||||||
| 	ratio := modelRatio * groupRatio | 	ratio := modelRatio * groupRatio | ||||||
| 	userQuota, err := model.CacheGetUserQuota(ctx, meta.UserId) | 	userQuota, err := model.CacheGetUserQuota(ctx, meta.UserId) | ||||||
|  |  | ||||||
| 	quota := int64(ratio*imageCostRatio*1000) * int64(imageRequest.N) | 	var quota int64 | ||||||
|  | 	switch meta.ChannelType { | ||||||
|  | 	case channeltype.Replicate: | ||||||
|  | 		// replicate always return 1 image | ||||||
|  | 		quota = int64(ratio * imageCostRatio * 1000) | ||||||
|  | 	default: | ||||||
|  | 		quota = int64(ratio*imageCostRatio*1000) * int64(imageRequest.N) | ||||||
|  | 	} | ||||||
|  |  | ||||||
| 	if userQuota-quota < 0 { | 	if userQuota-quota < 0 { | ||||||
| 		return openai.ErrorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden) | 		return openai.ErrorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden) | ||||||
| @@ -186,7 +193,9 @@ func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus | |||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	defer func(ctx context.Context) { | 	defer func(ctx context.Context) { | ||||||
| 		if resp != nil && resp.StatusCode != http.StatusOK { | 		if resp != nil && | ||||||
|  | 			resp.StatusCode != http.StatusCreated && // replicate returns 201 | ||||||
|  | 			resp.StatusCode != http.StatusOK { | ||||||
| 			return | 			return | ||||||
| 		} | 		} | ||||||
|  |  | ||||||
|   | |||||||
| @@ -4,6 +4,7 @@ import ( | |||||||
| 	"bytes" | 	"bytes" | ||||||
| 	"encoding/json" | 	"encoding/json" | ||||||
| 	"fmt" | 	"fmt" | ||||||
|  | 	"github.com/songquanpeng/one-api/common/config" | ||||||
| 	"io" | 	"io" | ||||||
| 	"net/http" | 	"net/http" | ||||||
|  |  | ||||||
| @@ -35,6 +36,8 @@ func RelayTextHelper(c *gin.Context) *model.ErrorWithStatusCode { | |||||||
| 	meta.OriginModelName = textRequest.Model | 	meta.OriginModelName = textRequest.Model | ||||||
| 	textRequest.Model, _ = getMappedModelName(textRequest.Model, meta.ModelMapping) | 	textRequest.Model, _ = getMappedModelName(textRequest.Model, meta.ModelMapping) | ||||||
| 	meta.ActualModelName = textRequest.Model | 	meta.ActualModelName = textRequest.Model | ||||||
|  | 	// set system prompt if not empty | ||||||
|  | 	systemPromptReset := setSystemPrompt(ctx, textRequest, meta.SystemPrompt) | ||||||
| 	// get model ratio & group ratio | 	// get model ratio & group ratio | ||||||
| 	modelRatio := billingratio.GetModelRatio(textRequest.Model, meta.ChannelType) | 	modelRatio := billingratio.GetModelRatio(textRequest.Model, meta.ChannelType) | ||||||
| 	groupRatio := billingratio.GetGroupRatio(meta.Group) | 	groupRatio := billingratio.GetGroupRatio(meta.Group) | ||||||
| @@ -79,12 +82,12 @@ func RelayTextHelper(c *gin.Context) *model.ErrorWithStatusCode { | |||||||
| 		return respErr | 		return respErr | ||||||
| 	} | 	} | ||||||
| 	// post-consume quota | 	// post-consume quota | ||||||
| 	go postConsumeQuota(ctx, usage, meta, textRequest, ratio, preConsumedQuota, modelRatio, groupRatio) | 	go postConsumeQuota(ctx, usage, meta, textRequest, ratio, preConsumedQuota, modelRatio, groupRatio, systemPromptReset) | ||||||
| 	return nil | 	return nil | ||||||
| } | } | ||||||
|  |  | ||||||
| func getRequestBody(c *gin.Context, meta *meta.Meta, textRequest *model.GeneralOpenAIRequest, adaptor adaptor.Adaptor) (io.Reader, error) { | func getRequestBody(c *gin.Context, meta *meta.Meta, textRequest *model.GeneralOpenAIRequest, adaptor adaptor.Adaptor) (io.Reader, error) { | ||||||
| 	if meta.APIType == apitype.OpenAI && meta.OriginModelName == meta.ActualModelName && meta.ChannelType != channeltype.Baichuan { | 	if !config.EnforceIncludeUsage && meta.APIType == apitype.OpenAI && meta.OriginModelName == meta.ActualModelName && meta.ChannelType != channeltype.Baichuan { | ||||||
| 		// no need to convert request for openai | 		// no need to convert request for openai | ||||||
| 		return c.Request.Body, nil | 		return c.Request.Body, nil | ||||||
| 	} | 	} | ||||||
|   | |||||||
| @@ -30,6 +30,7 @@ type Meta struct { | |||||||
| 	ActualModelName string | 	ActualModelName string | ||||||
| 	RequestURLPath  string | 	RequestURLPath  string | ||||||
| 	PromptTokens    int // only for DoResponse | 	PromptTokens    int // only for DoResponse | ||||||
|  | 	SystemPrompt    string | ||||||
| } | } | ||||||
|  |  | ||||||
| func GetByContext(c *gin.Context) *Meta { | func GetByContext(c *gin.Context) *Meta { | ||||||
| @@ -46,6 +47,7 @@ func GetByContext(c *gin.Context) *Meta { | |||||||
| 		BaseURL:         c.GetString(ctxkey.BaseURL), | 		BaseURL:         c.GetString(ctxkey.BaseURL), | ||||||
| 		APIKey:          strings.TrimPrefix(c.Request.Header.Get("Authorization"), "Bearer "), | 		APIKey:          strings.TrimPrefix(c.Request.Header.Get("Authorization"), "Bearer "), | ||||||
| 		RequestURLPath:  c.Request.URL.String(), | 		RequestURLPath:  c.Request.URL.String(), | ||||||
|  | 		SystemPrompt:    c.GetString(ctxkey.SystemPrompt), | ||||||
| 	} | 	} | ||||||
| 	cfg, ok := c.Get(ctxkey.Config) | 	cfg, ok := c.Get(ctxkey.Config) | ||||||
| 	if ok { | 	if ok { | ||||||
|   | |||||||
| @@ -9,6 +9,7 @@ import ( | |||||||
|  |  | ||||||
| func SetRelayRouter(router *gin.Engine) { | func SetRelayRouter(router *gin.Engine) { | ||||||
| 	router.Use(middleware.CORS()) | 	router.Use(middleware.CORS()) | ||||||
|  | 	router.Use(middleware.GzipDecodeMiddleware()) | ||||||
| 	// https://platform.openai.com/docs/api-reference/introduction | 	// https://platform.openai.com/docs/api-reference/introduction | ||||||
| 	modelsRouter := router.Group("/v1/models") | 	modelsRouter := router.Group("/v1/models") | ||||||
| 	modelsRouter.Use(middleware.TokenAuth()) | 	modelsRouter.Use(middleware.TokenAuth()) | ||||||
|   | |||||||
| @@ -31,6 +31,7 @@ export const CHANNEL_OPTIONS = [ | |||||||
|   { key: 43, text: 'Proxy', value: 43, color: 'blue' }, |   { key: 43, text: 'Proxy', value: 43, color: 'blue' }, | ||||||
|   { key: 44, text: 'SiliconFlow', value: 44, color: 'blue' }, |   { key: 44, text: 'SiliconFlow', value: 44, color: 'blue' }, | ||||||
|   { key: 45, text: 'xAI', value: 45, color: 'blue' }, |   { key: 45, text: 'xAI', value: 45, color: 'blue' }, | ||||||
|  |   { key: 46, text: 'Replicate', value: 46, color: 'blue' }, | ||||||
|   { key: 8, text: '自定义渠道', value: 8, color: 'pink' }, |   { key: 8, text: '自定义渠道', value: 8, color: 'pink' }, | ||||||
|   { key: 22, text: '知识库:FastGPT', value: 22, color: 'blue' }, |   { key: 22, text: '知识库:FastGPT', value: 22, color: 'blue' }, | ||||||
|   { key: 21, text: '知识库:AI Proxy', value: 21, color: 'purple' }, |   { key: 21, text: '知识库:AI Proxy', value: 21, color: 'purple' }, | ||||||
|   | |||||||
| @@ -43,6 +43,7 @@ const EditChannel = (props) => { | |||||||
|         base_url: '', |         base_url: '', | ||||||
|         other: '', |         other: '', | ||||||
|         model_mapping: '', |         model_mapping: '', | ||||||
|  |         system_prompt: '', | ||||||
|         models: [], |         models: [], | ||||||
|         auto_ban: 1, |         auto_ban: 1, | ||||||
|         groups: ['default'] |         groups: ['default'] | ||||||
| @@ -304,163 +305,163 @@ const EditChannel = (props) => { | |||||||
|                 width={isMobile() ? '100%' : 600} |                 width={isMobile() ? '100%' : 600} | ||||||
|             > |             > | ||||||
|                 <Spin spinning={loading}> |                 <Spin spinning={loading}> | ||||||
|                     <div style={{marginTop: 10}}> |                     <div style={{ marginTop: 10 }}> | ||||||
|                         <Typography.Text strong>类型:</Typography.Text> |                         <Typography.Text strong>类型:</Typography.Text> | ||||||
|                     </div> |                     </div> | ||||||
|                     <Select |                     <Select | ||||||
|                         name='type' |                       name='type' | ||||||
|                         required |                       required | ||||||
|                         optionList={CHANNEL_OPTIONS} |                       optionList={CHANNEL_OPTIONS} | ||||||
|                         value={inputs.type} |                       value={inputs.type} | ||||||
|                         onChange={value => handleInputChange('type', value)} |                       onChange={value => handleInputChange('type', value)} | ||||||
|                         style={{width: '50%'}} |                       style={{ width: '50%' }} | ||||||
|                     /> |                     /> | ||||||
|                     { |                     { | ||||||
|                         inputs.type === 3 && ( |                       inputs.type === 3 && ( | ||||||
|                             <> |                         <> | ||||||
|                                 <div style={{marginTop: 10}}> |                             <div style={{ marginTop: 10 }}> | ||||||
|                                     <Banner type={"warning"} description={ |                                 <Banner type={"warning"} description={ | ||||||
|                                         <> |                                     <> | ||||||
|                                             注意,<strong>模型部署名称必须和模型名称保持一致</strong>,因为 One API 会把请求体中的 |                                         注意,<strong>模型部署名称必须和模型名称保持一致</strong>,因为 One API 会把请求体中的 | ||||||
|                                             model |                                         model | ||||||
|                                             参数替换为你的部署名称(模型名称中的点会被剔除),<a target='_blank' |                                         参数替换为你的部署名称(模型名称中的点会被剔除),<a target='_blank' | ||||||
|                                                                                               href='https://github.com/songquanpeng/one-api/issues/133?notification_referrer_id=NT_kwDOAmJSYrM2NjIwMzI3NDgyOjM5OTk4MDUw#issuecomment-1571602271'>图片演示</a>。 |                                                                                           href='https://github.com/songquanpeng/one-api/issues/133?notification_referrer_id=NT_kwDOAmJSYrM2NjIwMzI3NDgyOjM5OTk4MDUw#issuecomment-1571602271'>图片演示</a>。 | ||||||
|                                         </> |                                     </> | ||||||
|                                     }> |                                 }> | ||||||
|                                     </Banner> |                                 </Banner> | ||||||
|                                 </div> |                             </div> | ||||||
|                                 <div style={{marginTop: 10}}> |                             <div style={{ marginTop: 10 }}> | ||||||
|                                     <Typography.Text strong>AZURE_OPENAI_ENDPOINT:</Typography.Text> |                                 <Typography.Text strong>AZURE_OPENAI_ENDPOINT:</Typography.Text> | ||||||
|                                 </div> |                             </div> | ||||||
|                                 <Input |                             <Input | ||||||
|                                     label='AZURE_OPENAI_ENDPOINT' |                               label='AZURE_OPENAI_ENDPOINT' | ||||||
|                                     name='azure_base_url' |                               name='azure_base_url' | ||||||
|                                     placeholder={'请输入 AZURE_OPENAI_ENDPOINT,例如:https://docs-test-001.openai.azure.com'} |                               placeholder={'请输入 AZURE_OPENAI_ENDPOINT,例如:https://docs-test-001.openai.azure.com'} | ||||||
|                                     onChange={value => { |                               onChange={value => { | ||||||
|                                         handleInputChange('base_url', value) |                                   handleInputChange('base_url', value) | ||||||
|                                     }} |                               }} | ||||||
|                                     value={inputs.base_url} |                               value={inputs.base_url} | ||||||
|                                     autoComplete='new-password' |                               autoComplete='new-password' | ||||||
|                                 /> |                             /> | ||||||
|                                 <div style={{marginTop: 10}}> |                             <div style={{ marginTop: 10 }}> | ||||||
|                                     <Typography.Text strong>默认 API 版本:</Typography.Text> |                                 <Typography.Text strong>默认 API 版本:</Typography.Text> | ||||||
|                                 </div> |                             </div> | ||||||
|                                 <Input |                             <Input | ||||||
|                                     label='默认 API 版本' |                               label='默认 API 版本' | ||||||
|                                     name='azure_other' |                               name='azure_other' | ||||||
|                                     placeholder={'请输入默认 API 版本,例如:2024-03-01-preview,该配置可以被实际的请求查询参数所覆盖'} |                               placeholder={'请输入默认 API 版本,例如:2024-03-01-preview,该配置可以被实际的请求查询参数所覆盖'} | ||||||
|                                     onChange={value => { |                               onChange={value => { | ||||||
|                                         handleInputChange('other', value) |                                   handleInputChange('other', value) | ||||||
|                                     }} |                               }} | ||||||
|                                     value={inputs.other} |                               value={inputs.other} | ||||||
|                                     autoComplete='new-password' |                               autoComplete='new-password' | ||||||
|                                 /> |                             /> | ||||||
|                             </> |                         </> | ||||||
|                         ) |                       ) | ||||||
|                     } |                     } | ||||||
|                     { |                     { | ||||||
|                         inputs.type === 8 && ( |                       inputs.type === 8 && ( | ||||||
|                             <> |                         <> | ||||||
|                                 <div style={{marginTop: 10}}> |                             <div style={{ marginTop: 10 }}> | ||||||
|                                     <Typography.Text strong>Base URL:</Typography.Text> |                                 <Typography.Text strong>Base URL:</Typography.Text> | ||||||
|                                 </div> |                             </div> | ||||||
|                                 <Input |                             <Input | ||||||
|                                     name='base_url' |                               name='base_url' | ||||||
|                                     placeholder={'请输入自定义渠道的 Base URL'} |                               placeholder={'请输入自定义渠道的 Base URL'} | ||||||
|                                     onChange={value => { |                               onChange={value => { | ||||||
|                                         handleInputChange('base_url', value) |                                   handleInputChange('base_url', value) | ||||||
|                                     }} |                               }} | ||||||
|                                     value={inputs.base_url} |                               value={inputs.base_url} | ||||||
|                                     autoComplete='new-password' |                               autoComplete='new-password' | ||||||
|                                 /> |                             /> | ||||||
|                             </> |                         </> | ||||||
|                         ) |                       ) | ||||||
|                     } |                     } | ||||||
|                     <div style={{marginTop: 10}}> |                     <div style={{ marginTop: 10 }}> | ||||||
|                         <Typography.Text strong>名称:</Typography.Text> |                         <Typography.Text strong>名称:</Typography.Text> | ||||||
|                     </div> |                     </div> | ||||||
|                     <Input |                     <Input | ||||||
|                         required |                       required | ||||||
|                         name='name' |                       name='name' | ||||||
|                         placeholder={'请为渠道命名'} |                       placeholder={'请为渠道命名'} | ||||||
|                         onChange={value => { |                       onChange={value => { | ||||||
|                             handleInputChange('name', value) |                           handleInputChange('name', value) | ||||||
|                         }} |                       }} | ||||||
|                         value={inputs.name} |                       value={inputs.name} | ||||||
|                         autoComplete='new-password' |                       autoComplete='new-password' | ||||||
|                     /> |                     /> | ||||||
|                     <div style={{marginTop: 10}}> |                     <div style={{ marginTop: 10 }}> | ||||||
|                         <Typography.Text strong>分组:</Typography.Text> |                         <Typography.Text strong>分组:</Typography.Text> | ||||||
|                     </div> |                     </div> | ||||||
|                     <Select |                     <Select | ||||||
|                         placeholder={'请选择可以使用该渠道的分组'} |                       placeholder={'请选择可以使用该渠道的分组'} | ||||||
|                         name='groups' |                       name='groups' | ||||||
|                         required |                       required | ||||||
|                         multiple |                       multiple | ||||||
|                         selection |                       selection | ||||||
|                         allowAdditions |                       allowAdditions | ||||||
|                         additionLabel={'请在系统设置页面编辑分组倍率以添加新的分组:'} |                       additionLabel={'请在系统设置页面编辑分组倍率以添加新的分组:'} | ||||||
|                         onChange={value => { |                       onChange={value => { | ||||||
|                             handleInputChange('groups', value) |                           handleInputChange('groups', value) | ||||||
|                         }} |                       }} | ||||||
|                         value={inputs.groups} |                       value={inputs.groups} | ||||||
|                         autoComplete='new-password' |                       autoComplete='new-password' | ||||||
|                         optionList={groupOptions} |                       optionList={groupOptions} | ||||||
|                     /> |                     /> | ||||||
|                     { |                     { | ||||||
|                         inputs.type === 18 && ( |                       inputs.type === 18 && ( | ||||||
|                             <> |                         <> | ||||||
|                                 <div style={{marginTop: 10}}> |                             <div style={{ marginTop: 10 }}> | ||||||
|                                     <Typography.Text strong>模型版本:</Typography.Text> |                                 <Typography.Text strong>模型版本:</Typography.Text> | ||||||
|                                 </div> |                             </div> | ||||||
|                                 <Input |                             <Input | ||||||
|                                     name='other' |                               name='other' | ||||||
|                                     placeholder={'请输入星火大模型版本,注意是接口地址中的版本号,例如:v2.1'} |                               placeholder={'请输入星火大模型版本,注意是接口地址中的版本号,例如:v2.1'} | ||||||
|                                     onChange={value => { |                               onChange={value => { | ||||||
|                                         handleInputChange('other', value) |                                   handleInputChange('other', value) | ||||||
|                                     }} |                               }} | ||||||
|                                     value={inputs.other} |                               value={inputs.other} | ||||||
|                                     autoComplete='new-password' |                               autoComplete='new-password' | ||||||
|                                 /> |                             /> | ||||||
|                             </> |                         </> | ||||||
|                         ) |                       ) | ||||||
|                     } |                     } | ||||||
|                     { |                     { | ||||||
|                         inputs.type === 21 && ( |                       inputs.type === 21 && ( | ||||||
|                             <> |                         <> | ||||||
|                                 <div style={{marginTop: 10}}> |                             <div style={{ marginTop: 10 }}> | ||||||
|                                     <Typography.Text strong>知识库 ID:</Typography.Text> |                                 <Typography.Text strong>知识库 ID:</Typography.Text> | ||||||
|                                 </div> |                             </div> | ||||||
|                                 <Input |                             <Input | ||||||
|                                     label='知识库 ID' |                               label='知识库 ID' | ||||||
|                                     name='other' |                               name='other' | ||||||
|                                     placeholder={'请输入知识库 ID,例如:123456'} |                               placeholder={'请输入知识库 ID,例如:123456'} | ||||||
|                                     onChange={value => { |                               onChange={value => { | ||||||
|                                         handleInputChange('other', value) |                                   handleInputChange('other', value) | ||||||
|                                     }} |                               }} | ||||||
|                                     value={inputs.other} |                               value={inputs.other} | ||||||
|                                     autoComplete='new-password' |                               autoComplete='new-password' | ||||||
|                                 /> |                             /> | ||||||
|                             </> |                         </> | ||||||
|                         ) |                       ) | ||||||
|                     } |                     } | ||||||
|                     <div style={{marginTop: 10}}> |                     <div style={{ marginTop: 10 }}> | ||||||
|                         <Typography.Text strong>模型:</Typography.Text> |                         <Typography.Text strong>模型:</Typography.Text> | ||||||
|                     </div> |                     </div> | ||||||
|                     <Select |                     <Select | ||||||
|                         placeholder={'请选择该渠道所支持的模型'} |                       placeholder={'请选择该渠道所支持的模型'} | ||||||
|                         name='models' |                       name='models' | ||||||
|                         required |                       required | ||||||
|                         multiple |                       multiple | ||||||
|                         selection |                       selection | ||||||
|                         onChange={value => { |                       onChange={value => { | ||||||
|                             handleInputChange('models', value) |                           handleInputChange('models', value) | ||||||
|                         }} |                       }} | ||||||
|                         value={inputs.models} |                       value={inputs.models} | ||||||
|                         autoComplete='new-password' |                       autoComplete='new-password' | ||||||
|                         optionList={modelOptions} |                       optionList={modelOptions} | ||||||
|                     /> |                     /> | ||||||
|                     <div style={{lineHeight: '40px', marginBottom: '12px'}}> |                     <div style={{ lineHeight: '40px', marginBottom: '12px' }}> | ||||||
|                         <Space> |                         <Space> | ||||||
|                             <Button type='primary' onClick={() => { |                             <Button type='primary' onClick={() => { | ||||||
|                                 handleInputChange('models', basicModels); |                                 handleInputChange('models', basicModels); | ||||||
| @@ -473,28 +474,41 @@ const EditChannel = (props) => { | |||||||
|                             }}>清除所有模型</Button> |                             }}>清除所有模型</Button> | ||||||
|                         </Space> |                         </Space> | ||||||
|                         <Input |                         <Input | ||||||
|                             addonAfter={ |                           addonAfter={ | ||||||
|                                 <Button type='primary' onClick={addCustomModel}>填入</Button> |                               <Button type='primary' onClick={addCustomModel}>填入</Button> | ||||||
|                             } |                           } | ||||||
|                             placeholder='输入自定义模型名称' |                           placeholder='输入自定义模型名称' | ||||||
|                             value={customModel} |                           value={customModel} | ||||||
|                             onChange={(value) => { |                           onChange={(value) => { | ||||||
|                                 setCustomModel(value.trim()); |                               setCustomModel(value.trim()); | ||||||
|                             }} |                           }} | ||||||
|                         /> |                         /> | ||||||
|                     </div> |                     </div> | ||||||
|                     <div style={{marginTop: 10}}> |                     <div style={{ marginTop: 10 }}> | ||||||
|                         <Typography.Text strong>模型重定向:</Typography.Text> |                         <Typography.Text strong>模型重定向:</Typography.Text> | ||||||
|                     </div> |                     </div> | ||||||
|                     <TextArea |                     <TextArea | ||||||
|                         placeholder={`此项可选,用于修改请求体中的模型名称,为一个 JSON 字符串,键为请求中模型名称,值为要替换的模型名称,例如:\n${JSON.stringify(MODEL_MAPPING_EXAMPLE, null, 2)}`} |                       placeholder={`此项可选,用于修改请求体中的模型名称,为一个 JSON 字符串,键为请求中模型名称,值为要替换的模型名称,例如:\n${JSON.stringify(MODEL_MAPPING_EXAMPLE, null, 2)}`} | ||||||
|                         name='model_mapping' |                       name='model_mapping' | ||||||
|                         onChange={value => { |                       onChange={value => { | ||||||
|                             handleInputChange('model_mapping', value) |                           handleInputChange('model_mapping', value) | ||||||
|                         }} |                       }} | ||||||
|                         autosize |                       autosize | ||||||
|                         value={inputs.model_mapping} |                       value={inputs.model_mapping} | ||||||
|                         autoComplete='new-password' |                       autoComplete='new-password' | ||||||
|  |                     /> | ||||||
|  |                     <div style={{ marginTop: 10 }}> | ||||||
|  |                         <Typography.Text strong>系统提示词:</Typography.Text> | ||||||
|  |                     </div> | ||||||
|  |                     <TextArea | ||||||
|  |                       placeholder={`此项可选,用于强制设置给定的系统提示词,请配合自定义模型 & 模型重定向使用,首先创建一个唯一的自定义模型名称并在上面填入,之后将该自定义模型重定向映射到该渠道一个原生支持的模型`} | ||||||
|  |                       name='system_prompt' | ||||||
|  |                       onChange={value => { | ||||||
|  |                           handleInputChange('system_prompt', value) | ||||||
|  |                       }} | ||||||
|  |                       autosize | ||||||
|  |                       value={inputs.system_prompt} | ||||||
|  |                       autoComplete='new-password' | ||||||
|                     /> |                     /> | ||||||
|                     <Typography.Text style={{ |                     <Typography.Text style={{ | ||||||
|                         color: 'rgba(var(--semi-blue-5), 1)', |                         color: 'rgba(var(--semi-blue-5), 1)', | ||||||
| @@ -507,116 +521,116 @@ const EditChannel = (props) => { | |||||||
|                     }> |                     }> | ||||||
|                         填入模板 |                         填入模板 | ||||||
|                     </Typography.Text> |                     </Typography.Text> | ||||||
|                     <div style={{marginTop: 10}}> |                     <div style={{ marginTop: 10 }}> | ||||||
|                         <Typography.Text strong>密钥:</Typography.Text> |                         <Typography.Text strong>密钥:</Typography.Text> | ||||||
|                     </div> |                     </div> | ||||||
|                     { |                     { | ||||||
|                         batch ? |                         batch ? | ||||||
|                             <TextArea |                           <TextArea | ||||||
|                                 label='密钥' |                             label='密钥' | ||||||
|                                 name='key' |                             name='key' | ||||||
|                                 required |                             required | ||||||
|                                 placeholder={'请输入密钥,一行一个'} |                             placeholder={'请输入密钥,一行一个'} | ||||||
|                                 onChange={value => { |                             onChange={value => { | ||||||
|                                     handleInputChange('key', value) |                                 handleInputChange('key', value) | ||||||
|                                 }} |                             }} | ||||||
|                                 value={inputs.key} |                             value={inputs.key} | ||||||
|                                 style={{minHeight: 150, fontFamily: 'JetBrains Mono, Consolas'}} |                             style={{ minHeight: 150, fontFamily: 'JetBrains Mono, Consolas' }} | ||||||
|                                 autoComplete='new-password' |                             autoComplete='new-password' | ||||||
|                             /> |                           /> | ||||||
|                             : |                           : | ||||||
|                             <Input |                           <Input | ||||||
|                                 label='密钥' |                             label='密钥' | ||||||
|                                 name='key' |                             name='key' | ||||||
|                                 required |                             required | ||||||
|                                 placeholder={type2secretPrompt(inputs.type)} |                             placeholder={type2secretPrompt(inputs.type)} | ||||||
|                                 onChange={value => { |                             onChange={value => { | ||||||
|                                     handleInputChange('key', value) |                                 handleInputChange('key', value) | ||||||
|                                 }} |                             }} | ||||||
|                                 value={inputs.key} |                             value={inputs.key} | ||||||
|                                 autoComplete='new-password' |                             autoComplete='new-password' | ||||||
|                             /> |                           /> | ||||||
|                     } |                     } | ||||||
|                     <div style={{marginTop: 10}}> |                     <div style={{ marginTop: 10 }}> | ||||||
|                         <Typography.Text strong>组织:</Typography.Text> |                         <Typography.Text strong>组织:</Typography.Text> | ||||||
|                     </div> |                     </div> | ||||||
|                     <Input |                     <Input | ||||||
|                         label='组织,可选,不填则为默认组织' |                       label='组织,可选,不填则为默认组织' | ||||||
|                         name='openai_organization' |                       name='openai_organization' | ||||||
|                         placeholder='请输入组织org-xxx' |                       placeholder='请输入组织org-xxx' | ||||||
|                         onChange={value => { |                       onChange={value => { | ||||||
|                             handleInputChange('openai_organization', value) |                           handleInputChange('openai_organization', value) | ||||||
|                         }} |                       }} | ||||||
|                         value={inputs.openai_organization} |                       value={inputs.openai_organization} | ||||||
|                     /> |                     /> | ||||||
|                     <div style={{marginTop: 10, display: 'flex'}}> |                     <div style={{ marginTop: 10, display: 'flex' }}> | ||||||
|                         <Space> |                         <Space> | ||||||
|                             <Checkbox |                             <Checkbox | ||||||
|                                 name='auto_ban' |                               name='auto_ban' | ||||||
|                                 checked={autoBan} |                               checked={autoBan} | ||||||
|                                 onChange={ |                               onChange={ | ||||||
|                                     () => { |                                   () => { | ||||||
|                                         setAutoBan(!autoBan); |                                       setAutoBan(!autoBan); | ||||||
|                                     } |                                   } | ||||||
|                                 } |                               } | ||||||
|                                 // onChange={handleInputChange} |                               // onChange={handleInputChange} | ||||||
|                             /> |                             /> | ||||||
|                             <Typography.Text |                             <Typography.Text | ||||||
|                                 strong>是否自动禁用(仅当自动禁用开启时有效),关闭后不会自动禁用该渠道:</Typography.Text> |                               strong>是否自动禁用(仅当自动禁用开启时有效),关闭后不会自动禁用该渠道:</Typography.Text> | ||||||
|                         </Space> |                         </Space> | ||||||
|                     </div> |                     </div> | ||||||
|  |  | ||||||
|                     { |                     { | ||||||
|                         !isEdit && ( |                       !isEdit && ( | ||||||
|                             <div style={{marginTop: 10, display: 'flex'}}> |                         <div style={{ marginTop: 10, display: 'flex' }}> | ||||||
|                                 <Space> |                             <Space> | ||||||
|                                     <Checkbox |                                 <Checkbox | ||||||
|                                         checked={batch} |                                   checked={batch} | ||||||
|                                         label='批量创建' |                                   label='批量创建' | ||||||
|                                         name='batch' |                                   name='batch' | ||||||
|                                         onChange={() => setBatch(!batch)} |                                   onChange={() => setBatch(!batch)} | ||||||
|                                     /> |                                 /> | ||||||
|                                     <Typography.Text strong>批量创建</Typography.Text> |                                 <Typography.Text strong>批量创建</Typography.Text> | ||||||
|                                 </Space> |                             </Space> | ||||||
|  |                         </div> | ||||||
|  |                       ) | ||||||
|  |                     } | ||||||
|  |                     { | ||||||
|  |                       inputs.type !== 3 && inputs.type !== 8 && inputs.type !== 22 && ( | ||||||
|  |                         <> | ||||||
|  |                             <div style={{ marginTop: 10 }}> | ||||||
|  |                                 <Typography.Text strong>代理:</Typography.Text> | ||||||
|                             </div> |                             </div> | ||||||
|                         ) |                             <Input | ||||||
|  |                               label='代理' | ||||||
|  |                               name='base_url' | ||||||
|  |                               placeholder={'此项可选,用于通过代理站来进行 API 调用'} | ||||||
|  |                               onChange={value => { | ||||||
|  |                                   handleInputChange('base_url', value) | ||||||
|  |                               }} | ||||||
|  |                               value={inputs.base_url} | ||||||
|  |                               autoComplete='new-password' | ||||||
|  |                             /> | ||||||
|  |                         </> | ||||||
|  |                       ) | ||||||
|                     } |                     } | ||||||
|                     { |                     { | ||||||
|                         inputs.type !== 3 && inputs.type !== 8 && inputs.type !== 22 && ( |                       inputs.type === 22 && ( | ||||||
|                             <> |                         <> | ||||||
|                                 <div style={{marginTop: 10}}> |                             <div style={{ marginTop: 10 }}> | ||||||
|                                     <Typography.Text strong>代理:</Typography.Text> |                                 <Typography.Text strong>私有部署地址:</Typography.Text> | ||||||
|                                 </div> |                             </div> | ||||||
|                                 <Input |                             <Input | ||||||
|                                     label='代理' |                               name='base_url' | ||||||
|                                     name='base_url' |                               placeholder={'请输入私有部署地址,格式为:https://fastgpt.run/api/openapi'} | ||||||
|                                     placeholder={'此项可选,用于通过代理站来进行 API 调用'} |                               onChange={value => { | ||||||
|                                     onChange={value => { |                                   handleInputChange('base_url', value) | ||||||
|                                         handleInputChange('base_url', value) |                               }} | ||||||
|                                     }} |                               value={inputs.base_url} | ||||||
|                                     value={inputs.base_url} |                               autoComplete='new-password' | ||||||
|                                     autoComplete='new-password' |                             /> | ||||||
|                                 /> |                         </> | ||||||
|                             </> |                       ) | ||||||
|                         ) |  | ||||||
|                     } |  | ||||||
|                     { |  | ||||||
|                         inputs.type === 22 && ( |  | ||||||
|                             <> |  | ||||||
|                                 <div style={{marginTop: 10}}> |  | ||||||
|                                     <Typography.Text strong>私有部署地址:</Typography.Text> |  | ||||||
|                                 </div> |  | ||||||
|                                 <Input |  | ||||||
|                                     name='base_url' |  | ||||||
|                                     placeholder={'请输入私有部署地址,格式为:https://fastgpt.run/api/openapi'} |  | ||||||
|                                     onChange={value => { |  | ||||||
|                                         handleInputChange('base_url', value) |  | ||||||
|                                     }} |  | ||||||
|                                     value={inputs.base_url} |  | ||||||
|                                     autoComplete='new-password' |  | ||||||
|                                 /> |  | ||||||
|                             </> |  | ||||||
|                         ) |  | ||||||
|                     } |                     } | ||||||
|  |  | ||||||
|                 </Spin> |                 </Spin> | ||||||
|   | |||||||
| @@ -185,6 +185,12 @@ export const CHANNEL_OPTIONS = { | |||||||
|     value: 45, |     value: 45, | ||||||
|     color: 'primary' |     color: 'primary' | ||||||
|   }, |   }, | ||||||
|  |   46: { | ||||||
|  |     key: 46, | ||||||
|  |     text: 'Replicate', | ||||||
|  |     value: 46, | ||||||
|  |     color: 'primary' | ||||||
|  |   }, | ||||||
|   41: { |   41: { | ||||||
|     key: 41, |     key: 41, | ||||||
|     text: 'Novita', |     text: 'Novita', | ||||||
|   | |||||||
| @@ -95,7 +95,7 @@ export async function onLarkOAuthClicked(lark_client_id) { | |||||||
|   const state = await getOAuthState(); |   const state = await getOAuthState(); | ||||||
|   if (!state) return; |   if (!state) return; | ||||||
|   let redirect_uri = `${window.location.origin}/oauth/lark`; |   let redirect_uri = `${window.location.origin}/oauth/lark`; | ||||||
|   window.open(`https://open.feishu.cn/open-apis/authen/v1/index?redirect_uri=${redirect_uri}&app_id=${lark_client_id}&state=${state}`); |   window.open(`https://accounts.feishu.cn/open-apis/authen/v1/authorize?redirect_uri=${redirect_uri}&client_id=${lark_client_id}&state=${state}`); | ||||||
| } | } | ||||||
|  |  | ||||||
| export async function onOidcClicked(auth_url, client_id, openInNewTab = false) { | export async function onOidcClicked(auth_url, client_id, openInNewTab = false) { | ||||||
|   | |||||||
| @@ -595,6 +595,28 @@ const EditModal = ({ open, channelId, onCancel, onOk }) => { | |||||||
|                   <FormHelperText id="helper-tex-channel-model_mapping-label"> {inputPrompt.model_mapping} </FormHelperText> |                   <FormHelperText id="helper-tex-channel-model_mapping-label"> {inputPrompt.model_mapping} </FormHelperText> | ||||||
|                 )} |                 )} | ||||||
|               </FormControl> |               </FormControl> | ||||||
|  |               <FormControl fullWidth error={Boolean(touched.system_prompt && errors.system_prompt)} sx={{ ...theme.typography.otherInput }}> | ||||||
|  |                 {/* <InputLabel htmlFor="channel-model_mapping-label">{inputLabel.model_mapping}</InputLabel> */} | ||||||
|  |                 <TextField | ||||||
|  |                   multiline | ||||||
|  |                   id="channel-system_prompt-label" | ||||||
|  |                   label={inputLabel.system_prompt} | ||||||
|  |                   value={values.system_prompt} | ||||||
|  |                   name="system_prompt" | ||||||
|  |                   onBlur={handleBlur} | ||||||
|  |                   onChange={handleChange} | ||||||
|  |                   aria-describedby="helper-text-channel-system_prompt-label" | ||||||
|  |                   minRows={5} | ||||||
|  |                   placeholder={inputPrompt.system_prompt} | ||||||
|  |                 /> | ||||||
|  |                 {touched.system_prompt && errors.system_prompt ? ( | ||||||
|  |                   <FormHelperText error id="helper-text-channel-system_prompt-label"> | ||||||
|  |                     {errors.system_prompt} | ||||||
|  |                   </FormHelperText> | ||||||
|  |                 ) : ( | ||||||
|  |                   <FormHelperText id="helper-text-channel-system_prompt-label"> {inputPrompt.system_prompt} </FormHelperText> | ||||||
|  |                 )} | ||||||
|  |               </FormControl> | ||||||
|               <DialogActions> |               <DialogActions> | ||||||
|                 <Button onClick={onCancel}>取消</Button> |                 <Button onClick={onCancel}>取消</Button> | ||||||
|                 <Button disableElevation disabled={isSubmitting} type="submit" variant="contained" color="primary"> |                 <Button disableElevation disabled={isSubmitting} type="submit" variant="contained" color="primary"> | ||||||
|   | |||||||
| @@ -268,6 +268,8 @@ function renderBalance(type, balance) { | |||||||
|       return <span>¥{balance.toFixed(2)}</span>; |       return <span>¥{balance.toFixed(2)}</span>; | ||||||
|     case 13: // AIGC2D |     case 13: // AIGC2D | ||||||
|       return <span>{renderNumber(balance)}</span>; |       return <span>{renderNumber(balance)}</span>; | ||||||
|  |     case 36: // DeepSeek | ||||||
|  |       return <span>¥{balance.toFixed(2)}</span>; | ||||||
|     case 44: // SiliconFlow |     case 44: // SiliconFlow | ||||||
|       return <span>¥{balance.toFixed(2)}</span>; |       return <span>¥{balance.toFixed(2)}</span>; | ||||||
|     default: |     default: | ||||||
|   | |||||||
| @@ -18,6 +18,7 @@ const defaultConfig = { | |||||||
|     other: '其他参数', |     other: '其他参数', | ||||||
|     models: '模型', |     models: '模型', | ||||||
|     model_mapping: '模型映射关系', |     model_mapping: '模型映射关系', | ||||||
|  |     system_prompt: '系统提示词', | ||||||
|     groups: '用户组', |     groups: '用户组', | ||||||
|     config: null |     config: null | ||||||
|   }, |   }, | ||||||
| @@ -30,6 +31,7 @@ const defaultConfig = { | |||||||
|     models: '请选择该渠道所支持的模型', |     models: '请选择该渠道所支持的模型', | ||||||
|     model_mapping: |     model_mapping: | ||||||
|       '请输入要修改的模型映射关系,格式为:api请求模型ID:实际转发给渠道的模型ID,使用JSON数组表示,例如:{"gpt-3.5": "gpt-35"}', |       '请输入要修改的模型映射关系,格式为:api请求模型ID:实际转发给渠道的模型ID,使用JSON数组表示,例如:{"gpt-3.5": "gpt-35"}', | ||||||
|  |     system_prompt:"此项可选,用于强制设置给定的系统提示词,请配合自定义模型 & 模型重定向使用,首先创建一个唯一的自定义模型名称并在上面填入,之后将该自定义模型重定向映射到该渠道一个原生支持的模型", | ||||||
|     groups: '请选择该渠道所支持的用户组', |     groups: '请选择该渠道所支持的用户组', | ||||||
|     config: null |     config: null | ||||||
|   }, |   }, | ||||||
|   | |||||||
							
								
								
									
										0
									
								
								web/build.sh
									
									
									
									
									
										
										
										Normal file → Executable file
									
								
							
							
						
						
									
										0
									
								
								web/build.sh
									
									
									
									
									
										
										
										Normal file → Executable file
									
								
							| @@ -52,6 +52,8 @@ function renderBalance(type, balance) { | |||||||
|       return <span>¥{balance.toFixed(2)}</span>; |       return <span>¥{balance.toFixed(2)}</span>; | ||||||
|     case 13: // AIGC2D |     case 13: // AIGC2D | ||||||
|       return <span>{renderNumber(balance)}</span>; |       return <span>{renderNumber(balance)}</span>; | ||||||
|  |     case 36: // DeepSeek | ||||||
|  |       return <span>¥{balance.toFixed(2)}</span>; | ||||||
|     case 44: // SiliconFlow |     case 44: // SiliconFlow | ||||||
|       return <span>¥{balance.toFixed(2)}</span>; |       return <span>¥{balance.toFixed(2)}</span>; | ||||||
|     default: |     default: | ||||||
|   | |||||||
| @@ -31,6 +31,7 @@ export const CHANNEL_OPTIONS = [ | |||||||
|     { key: 43, text: 'Proxy', value: 43, color: 'blue' }, |     { key: 43, text: 'Proxy', value: 43, color: 'blue' }, | ||||||
|     { key: 44, text: 'SiliconFlow', value: 44, color: 'blue' }, |     { key: 44, text: 'SiliconFlow', value: 44, color: 'blue' }, | ||||||
|     { key: 45, text: 'xAI', value: 45, color: 'blue' }, |     { key: 45, text: 'xAI', value: 45, color: 'blue' }, | ||||||
|  |     { key: 46, text: 'Replicate', value: 46, color: 'blue' }, | ||||||
|     { key: 8, text: '自定义渠道', value: 8, color: 'pink' }, |     { key: 8, text: '自定义渠道', value: 8, color: 'pink' }, | ||||||
|     { key: 22, text: '知识库:FastGPT', value: 22, color: 'blue' }, |     { key: 22, text: '知识库:FastGPT', value: 22, color: 'blue' }, | ||||||
|     { key: 21, text: '知识库:AI Proxy', value: 21, color: 'purple' }, |     { key: 21, text: '知识库:AI Proxy', value: 21, color: 'purple' }, | ||||||
|   | |||||||
| @@ -43,6 +43,7 @@ const EditChannel = () => { | |||||||
|     base_url: '', |     base_url: '', | ||||||
|     other: '', |     other: '', | ||||||
|     model_mapping: '', |     model_mapping: '', | ||||||
|  |     system_prompt: '', | ||||||
|     models: [], |     models: [], | ||||||
|     groups: ['default'] |     groups: ['default'] | ||||||
|   }; |   }; | ||||||
| @@ -425,7 +426,7 @@ const EditChannel = () => { | |||||||
|             ) |             ) | ||||||
|           } |           } | ||||||
|           { |           { | ||||||
|           inputs.type !== 43 && ( |           inputs.type !== 43 && (<> | ||||||
|               <Form.Field> |               <Form.Field> | ||||||
|                 <Form.TextArea |                 <Form.TextArea | ||||||
|                   label='模型重定向' |                   label='模型重定向' | ||||||
| @@ -437,6 +438,18 @@ const EditChannel = () => { | |||||||
|                   autoComplete='new-password' |                   autoComplete='new-password' | ||||||
|                 /> |                 /> | ||||||
|               </Form.Field> |               </Form.Field> | ||||||
|  |             <Form.Field> | ||||||
|  |                 <Form.TextArea | ||||||
|  |                   label='系统提示词' | ||||||
|  |                   placeholder={`此项可选,用于强制设置给定的系统提示词,请配合自定义模型 & 模型重定向使用,首先创建一个唯一的自定义模型名称并在上面填入,之后将该自定义模型重定向映射到该渠道一个原生支持的模型`} | ||||||
|  |                   name='system_prompt' | ||||||
|  |                   onChange={handleInputChange} | ||||||
|  |                   value={inputs.system_prompt} | ||||||
|  |                   style={{ minHeight: 150, fontFamily: 'JetBrains Mono, Consolas' }} | ||||||
|  |                   autoComplete='new-password' | ||||||
|  |                 /> | ||||||
|  |               </Form.Field> | ||||||
|  |               </> | ||||||
|             ) |             ) | ||||||
|           } |           } | ||||||
|           { |           { | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user