mirror of
				https://github.com/songquanpeng/one-api.git
				synced 2025-10-25 10:53:42 +08:00 
			
		
		
		
	Compare commits
	
		
			62 Commits
		
	
	
		
			v0.6.8-alp
			...
			v0.6.10
		
	
	| Author | SHA1 | Date | |
|---|---|---|---|
|  | 3915ce9814 | ||
|  | 999defc88b | ||
|  | b51c47bc77 | ||
|  | 4f25cde132 | ||
|  | d89e9d7e44 | ||
|  | a858292b54 | ||
|  | ff589b5e4a | ||
|  | 95e8c16338 | ||
|  | 381172cb36 | ||
|  | 59eae186a3 | ||
|  | ce52f355bb | ||
|  | cb9d0a74c9 | ||
|  | 49ffb1c60d | ||
|  | 2f16649896 | ||
|  | af3aa57bd6 | ||
|  | e9f117ff72 | ||
|  | 6bb5247bd6 | ||
|  | 305ce14fe3 | ||
|  | 36c8f4f15c | ||
|  | 45b51ea0ee | ||
|  | 7c8628bd95 | ||
|  | 6ab87f8a08 | ||
|  | 833fa7ad6f | ||
|  | 6eb0770a89 | ||
|  | 92cd46d64f | ||
|  | 2b2dc2c733 | ||
|  | a3d7df7f89 | ||
|  | c368232f50 | ||
|  | cbfc983dc3 | ||
|  | 8ec092ba44 | ||
|  | b0b88a79ff | ||
|  | 7e51b04221 | ||
|  | f75a17f8eb | ||
|  | 6f13a3bb3c | ||
|  | f092eed1db | ||
|  | 629378691b | ||
|  | 3716e1b0e6 | ||
|  | a4d6e7a886 | ||
|  | cb772e5d06 | ||
|  | e32cb0b844 | ||
|  | fdd7bf41c0 | ||
|  | 29389ed44f | ||
|  | 88acc5a614 | ||
|  | a21681096a | ||
|  | 32f90a79a8 | ||
|  | 99c8c77504 | ||
|  | 649ecbf29c | ||
|  | 3a27c90910 | ||
|  | cba82404ae | ||
|  | c9ac670ba1 | ||
|  | 15f815c23c | ||
|  | 89b63ca96f | ||
|  | 8cc54489b9 | ||
|  | 58bf60805e | ||
|  | 6714cf96d6 | ||
|  | f9774698e9 | ||
|  | 2af6f6a166 | ||
|  | 04bb3ef392 | ||
|  | b4bfa418a8 | ||
|  | e7e99e558a | ||
|  | 402fcf7f79 | ||
|  | 36039e329e | 
							
								
								
									
										2
									
								
								.github/workflows/ci.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										2
									
								
								.github/workflows/ci.yml
									
									
									
									
										vendored
									
									
								
							| @@ -12,8 +12,6 @@ name: CI | ||||
| # would trigger our jobs twice on pull requests (once from "push" event and once | ||||
| # from "pull_request->synchronize") | ||||
| on: | ||||
|   pull_request: | ||||
|     types: [opened, reopened, synchronize] | ||||
|   push: | ||||
|     branches: | ||||
|       - 'main' | ||||
|   | ||||
							
								
								
									
										1
									
								
								.gitignore
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										1
									
								
								.gitignore
									
									
									
									
										vendored
									
									
								
							| @@ -10,3 +10,4 @@ data | ||||
| /web/node_modules | ||||
| cmd.md | ||||
| .env | ||||
| /one-api | ||||
|   | ||||
							
								
								
									
										22
									
								
								README.md
									
									
									
									
									
								
							
							
						
						
									
										22
									
								
								README.md
									
									
									
									
									
								
							| @@ -89,6 +89,8 @@ _✨ 通过标准的 OpenAI API 格式访问所有的大模型,开箱即用  | ||||
|    + [x] [DeepL](https://www.deepl.com/) | ||||
|    + [x] [together.ai](https://www.together.ai/) | ||||
|    + [x] [novita.ai](https://www.novita.ai/) | ||||
|    + [x] [硅基流动 SiliconCloud](https://siliconflow.cn/siliconcloud) | ||||
|    + [x] [xAI](https://x.ai/) | ||||
| 2. 支持配置镜像以及众多[第三方代理服务](https://iamazing.cn/page/openai-api-third-party-services)。 | ||||
| 3. 支持通过**负载均衡**的方式访问多个渠道。 | ||||
| 4. 支持 **stream 模式**,可以通过流式传输实现打字机效果。 | ||||
| @@ -113,8 +115,8 @@ _✨ 通过标准的 OpenAI API 格式访问所有的大模型,开箱即用  | ||||
| 21. 支持 Cloudflare Turnstile 用户校验。 | ||||
| 22. 支持用户管理,支持**多种用户登录注册方式**: | ||||
|     + 邮箱登录注册(支持注册邮箱白名单)以及通过邮箱进行密码重置。 | ||||
|     + 支持使用飞书进行授权登录。 | ||||
|     + [GitHub 开放授权](https://github.com/settings/applications/new)。 | ||||
|     + 支持[飞书授权登录](https://open.feishu.cn/document/uAjLw4CM/ukTMukTMukTM/reference/authen-v1/authorize/get)([这里有 One API 的实现细节阐述供参考](https://iamazing.cn/page/feishu-oauth-login))。 | ||||
|     + 支持 [GitHub 授权登录](https://github.com/settings/applications/new)。 | ||||
|     + 微信公众号授权(需要额外部署 [WeChat Server](https://github.com/songquanpeng/wechat-server))。 | ||||
| 23. 支持主题切换,设置环境变量 `THEME` 即可,默认为 `default`,欢迎 PR 更多主题,具体参考[此处](./web/README.md)。 | ||||
| 24. 配合 [Message Pusher](https://github.com/songquanpeng/message-pusher) 可将报警信息推送到多种 App 上。 | ||||
| @@ -173,6 +175,10 @@ sudo service nginx restart | ||||
|  | ||||
| 初始账号用户名为 `root`,密码为 `123456`。 | ||||
|  | ||||
| ### 通过宝塔面板进行一键部署 | ||||
| 1. 安装宝塔面板9.2.0及以上版本,前往 [宝塔面板](https://www.bt.cn/new/download.html?r=dk_oneapi) 官网,选择正式版的脚本下载安装; | ||||
| 2. 安装后登录宝塔面板,在左侧菜单栏中点击 `Docker`,首次进入会提示安装 `Docker` 服务,点击立即安装,按提示完成安装; | ||||
| 3. 安装完成后在应用商店中搜索 `One-API`,点击安装,配置域名等基本信息即可完成安装; | ||||
|  | ||||
| ### 基于 Docker Compose 进行部署 | ||||
|  | ||||
| @@ -216,7 +222,7 @@ docker-compose ps | ||||
| 3. 所有从服务器必须设置 `NODE_TYPE` 为 `slave`,不设置则默认为主服务器。 | ||||
| 4. 设置 `SYNC_FREQUENCY` 后服务器将定期从数据库同步配置,在使用远程数据库的情况下,推荐设置该项并启用 Redis,无论主从。 | ||||
| 5. 从服务器可以选择设置 `FRONTEND_BASE_URL`,以重定向页面请求到主服务器。 | ||||
| 6. 从服务器上**分别**装好 Redis,设置好 `REDIS_CONN_STRING`,这样可以做到在缓存未过期的情况下数据库零访问,可以减少延迟。 | ||||
| 6. 从服务器上**分别**装好 Redis,设置好 `REDIS_CONN_STRING`,这样可以做到在缓存未过期的情况下数据库零访问,可以减少延迟(Redis 集群或者哨兵模式的支持请参考环境变量说明)。 | ||||
| 7. 如果主服务器访问数据库延迟也比较高,则也需要启用 Redis,并设置 `SYNC_FREQUENCY`,以定期从数据库同步配置。 | ||||
|  | ||||
| 环境变量的具体使用方法详见[此处](#环境变量)。 | ||||
| @@ -251,9 +257,9 @@ docker run --name chatgpt-web -d -p 3002:3002 -e OPENAI_API_BASE_URL=https://ope | ||||
| #### QChatGPT - QQ机器人 | ||||
| 项目主页:https://github.com/RockChinQ/QChatGPT | ||||
|  | ||||
| 根据文档完成部署后,在`config.py`设置配置项`openai_config`的`reverse_proxy`为 One API 后端地址,设置`api_key`为 One API 生成的key,并在配置项`completion_api_params`的`model`参数设置为 One API 支持的模型名称。 | ||||
| 根据[文档](https://qchatgpt.rockchin.top)完成部署后,在 `data/provider.json`设置`requester.openai-chat-completions.base-url`为 One API 实例地址,并填写 API Key 到 `keys.openai` 组中,设置 `model` 为要使用的模型名称。 | ||||
|  | ||||
| 可安装 [Switcher 插件](https://github.com/RockChinQ/Switcher)在运行时切换所使用的模型。 | ||||
| 运行期间可以通过`!model`命令查看、切换可用模型。 | ||||
|  | ||||
| ### 部署到第三方平台 | ||||
| <details> | ||||
| @@ -345,6 +351,11 @@ graph LR | ||||
| 1. `REDIS_CONN_STRING`:设置之后将使用 Redis 作为缓存使用。 | ||||
|    + 例子:`REDIS_CONN_STRING=redis://default:redispw@localhost:49153` | ||||
|    + 如果数据库访问延迟很低,没有必要启用 Redis,启用后反而会出现数据滞后的问题。 | ||||
|    + 如果需要使用哨兵或者集群模式: | ||||
|      + 则需要把该环境变量设置为节点列表,例如:`localhost:49153,localhost:49154,localhost:49155`。 | ||||
|      + 除此之外还需要设置以下环境变量: | ||||
|        + `REDIS_PASSWORD`:Redis 集群或者哨兵模式下的密码设置。 | ||||
|        + `REDIS_MASTER_NAME`:Redis 哨兵模式下主节点的名称。 | ||||
| 2. `SESSION_SECRET`:设置之后将使用固定的会话密钥,这样系统重新启动后已登录用户的 cookie 将依旧有效。 | ||||
|    + 例子:`SESSION_SECRET=random_string` | ||||
| 3. `SQL_DSN`:设置之后将使用指定数据库而非 SQLite,请使用 MySQL 或 PostgreSQL。 | ||||
| @@ -398,6 +409,7 @@ graph LR | ||||
| 26. `METRIC_SUCCESS_RATE_THRESHOLD`:请求成功率阈值,默认为 `0.8`。 | ||||
| 27. `INITIAL_ROOT_TOKEN`:如果设置了该值,则在系统首次启动时会自动创建一个值为该环境变量值的 root 用户令牌。 | ||||
| 28. `INITIAL_ROOT_ACCESS_TOKEN`:如果设置了该值,则在系统首次启动时会自动创建一个值为该环境变量的 root 用户创建系统管理令牌。 | ||||
| 29. `ENFORCE_INCLUDE_USAGE`:是否强制在 stream 模型下返回 usage,默认不开启,可选值为 `true` 和 `false`。 | ||||
|  | ||||
| ### 命令行参数 | ||||
| 1. `--port <port_number>`: 指定服务器监听的端口号,默认为 `3000`。 | ||||
|   | ||||
| @@ -35,6 +35,7 @@ var PasswordLoginEnabled = true | ||||
| var PasswordRegisterEnabled = true | ||||
| var EmailVerificationEnabled = false | ||||
| var GitHubOAuthEnabled = false | ||||
| var OidcEnabled = false | ||||
| var WeChatAuthEnabled = false | ||||
| var TurnstileCheckEnabled = false | ||||
| var RegisterEnabled = true | ||||
| @@ -70,6 +71,13 @@ var GitHubClientSecret = "" | ||||
| var LarkClientId = "" | ||||
| var LarkClientSecret = "" | ||||
|  | ||||
| var OidcClientId = "" | ||||
| var OidcClientSecret = "" | ||||
| var OidcWellKnown = "" | ||||
| var OidcAuthorizationEndpoint = "" | ||||
| var OidcTokenEndpoint = "" | ||||
| var OidcUserinfoEndpoint = "" | ||||
|  | ||||
| var WeChatServerAddress = "" | ||||
| var WeChatServerToken = "" | ||||
| var WeChatAccountQRCodeImageURL = "" | ||||
| @@ -152,3 +160,5 @@ var OnlyOneLogFile = env.Bool("ONLY_ONE_LOG_FILE", false) | ||||
| var RelayProxy = env.String("RELAY_PROXY", "") | ||||
| var UserContentRequestProxy = env.String("USER_CONTENT_REQUEST_PROXY", "") | ||||
| var UserContentRequestTimeout = env.Int("USER_CONTENT_REQUEST_TIMEOUT", 30) | ||||
|  | ||||
| var EnforceIncludeUsage = env.Bool("ENFORCE_INCLUDE_USAGE", false) | ||||
|   | ||||
| @@ -20,4 +20,5 @@ const ( | ||||
| 	BaseURL           = "base_url" | ||||
| 	AvailableModels   = "available_models" | ||||
| 	KeyRequestBody    = "key_request_body" | ||||
| 	SystemPrompt      = "system_prompt" | ||||
| ) | ||||
|   | ||||
| @@ -31,15 +31,15 @@ func UnmarshalBodyReusable(c *gin.Context, v any) error { | ||||
| 	contentType := c.Request.Header.Get("Content-Type") | ||||
| 	if strings.HasPrefix(contentType, "application/json") { | ||||
| 		err = json.Unmarshal(requestBody, &v) | ||||
| 		c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody)) | ||||
| 	} else { | ||||
| 		// skip for now | ||||
| 		// TODO: someday non json request have variant model, we will need to implementation this | ||||
| 		c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody)) | ||||
| 		err = c.ShouldBind(&v) | ||||
| 	} | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 	// Reset request body | ||||
| 	c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody)) | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
|   | ||||
| @@ -137,3 +137,23 @@ func String2Int(str string) int { | ||||
| 	} | ||||
| 	return num | ||||
| } | ||||
|  | ||||
| func Float64PtrMax(p *float64, maxValue float64) *float64 { | ||||
| 	if p == nil { | ||||
| 		return nil | ||||
| 	} | ||||
| 	if *p > maxValue { | ||||
| 		return &maxValue | ||||
| 	} | ||||
| 	return p | ||||
| } | ||||
|  | ||||
| func Float64PtrMin(p *float64, minValue float64) *float64 { | ||||
| 	if p == nil { | ||||
| 		return nil | ||||
| 	} | ||||
| 	if *p < minValue { | ||||
| 		return &minValue | ||||
| 	} | ||||
| 	return p | ||||
| } | ||||
|   | ||||
| @@ -2,13 +2,15 @@ package common | ||||
|  | ||||
| import ( | ||||
| 	"context" | ||||
| 	"os" | ||||
| 	"strings" | ||||
| 	"time" | ||||
|  | ||||
| 	"github.com/go-redis/redis/v8" | ||||
| 	"github.com/songquanpeng/one-api/common/logger" | ||||
| 	"os" | ||||
| 	"time" | ||||
| ) | ||||
|  | ||||
| var RDB *redis.Client | ||||
| var RDB redis.Cmdable | ||||
| var RedisEnabled = true | ||||
|  | ||||
| // InitRedisClient This function is called after init() | ||||
| @@ -23,13 +25,23 @@ func InitRedisClient() (err error) { | ||||
| 		logger.SysLog("SYNC_FREQUENCY not set, Redis is disabled") | ||||
| 		return nil | ||||
| 	} | ||||
| 	redisConnString := os.Getenv("REDIS_CONN_STRING") | ||||
| 	if os.Getenv("REDIS_MASTER_NAME") == "" { | ||||
| 		logger.SysLog("Redis is enabled") | ||||
| 	opt, err := redis.ParseURL(os.Getenv("REDIS_CONN_STRING")) | ||||
| 		opt, err := redis.ParseURL(redisConnString) | ||||
| 		if err != nil { | ||||
| 			logger.FatalLog("failed to parse Redis connection string: " + err.Error()) | ||||
| 		} | ||||
| 		RDB = redis.NewClient(opt) | ||||
|  | ||||
| 	} else { | ||||
| 		// cluster mode | ||||
| 		logger.SysLog("Redis cluster mode enabled") | ||||
| 		RDB = redis.NewUniversalClient(&redis.UniversalOptions{ | ||||
| 			Addrs:      strings.Split(redisConnString, ","), | ||||
| 			Password:   os.Getenv("REDIS_PASSWORD"), | ||||
| 			MasterName: os.Getenv("REDIS_MASTER_NAME"), | ||||
| 		}) | ||||
| 	} | ||||
| 	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) | ||||
| 	defer cancel() | ||||
|  | ||||
|   | ||||
| @@ -3,9 +3,10 @@ package render | ||||
| import ( | ||||
| 	"encoding/json" | ||||
| 	"fmt" | ||||
| 	"strings" | ||||
|  | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"github.com/songquanpeng/one-api/common" | ||||
| 	"strings" | ||||
| ) | ||||
|  | ||||
| func StringData(c *gin.Context, str string) { | ||||
|   | ||||
| @@ -40,7 +40,7 @@ func getLarkUserInfoByCode(code string) (*LarkUser, error) { | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	req, err := http.NewRequest("POST", "https://passport.feishu.cn/suite/passport/oauth/token", bytes.NewBuffer(jsonData)) | ||||
| 	req, err := http.NewRequest("POST", "https://open.feishu.cn/open-apis/authen/v2/oauth/token", bytes.NewBuffer(jsonData)) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
|   | ||||
							
								
								
									
										225
									
								
								controller/auth/oidc.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										225
									
								
								controller/auth/oidc.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,225 @@ | ||||
| package auth | ||||
|  | ||||
| import ( | ||||
| 	"bytes" | ||||
| 	"encoding/json" | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
| 	"github.com/gin-contrib/sessions" | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"github.com/songquanpeng/one-api/common/config" | ||||
| 	"github.com/songquanpeng/one-api/common/logger" | ||||
| 	"github.com/songquanpeng/one-api/controller" | ||||
| 	"github.com/songquanpeng/one-api/model" | ||||
| 	"net/http" | ||||
| 	"strconv" | ||||
| 	"time" | ||||
| ) | ||||
|  | ||||
| type OidcResponse struct { | ||||
| 	AccessToken  string `json:"access_token"` | ||||
| 	IDToken      string `json:"id_token"` | ||||
| 	RefreshToken string `json:"refresh_token"` | ||||
| 	TokenType    string `json:"token_type"` | ||||
| 	ExpiresIn    int    `json:"expires_in"` | ||||
| 	Scope        string `json:"scope"` | ||||
| } | ||||
|  | ||||
| type OidcUser struct { | ||||
| 	OpenID            string `json:"sub"` | ||||
| 	Email             string `json:"email"` | ||||
| 	Name              string `json:"name"` | ||||
| 	PreferredUsername string `json:"preferred_username"` | ||||
| 	Picture           string `json:"picture"` | ||||
| } | ||||
|  | ||||
| func getOidcUserInfoByCode(code string) (*OidcUser, error) { | ||||
| 	if code == "" { | ||||
| 		return nil, errors.New("无效的参数") | ||||
| 	} | ||||
| 	values := map[string]string{ | ||||
| 		"client_id":     config.OidcClientId, | ||||
| 		"client_secret": config.OidcClientSecret, | ||||
| 		"code":          code, | ||||
| 		"grant_type":    "authorization_code", | ||||
| 		"redirect_uri":  fmt.Sprintf("%s/oauth/oidc", config.ServerAddress), | ||||
| 	} | ||||
| 	jsonData, err := json.Marshal(values) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	req, err := http.NewRequest("POST", config.OidcTokenEndpoint, bytes.NewBuffer(jsonData)) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	req.Header.Set("Content-Type", "application/json") | ||||
| 	req.Header.Set("Accept", "application/json") | ||||
| 	client := http.Client{ | ||||
| 		Timeout: 5 * time.Second, | ||||
| 	} | ||||
| 	res, err := client.Do(req) | ||||
| 	if err != nil { | ||||
| 		logger.SysLog(err.Error()) | ||||
| 		return nil, errors.New("无法连接至 OIDC 服务器,请稍后重试!") | ||||
| 	} | ||||
| 	defer res.Body.Close() | ||||
| 	var oidcResponse OidcResponse | ||||
| 	err = json.NewDecoder(res.Body).Decode(&oidcResponse) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	req, err = http.NewRequest("GET", config.OidcUserinfoEndpoint, nil) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	req.Header.Set("Authorization", "Bearer "+oidcResponse.AccessToken) | ||||
| 	res2, err := client.Do(req) | ||||
| 	if err != nil { | ||||
| 		logger.SysLog(err.Error()) | ||||
| 		return nil, errors.New("无法连接至 OIDC 服务器,请稍后重试!") | ||||
| 	} | ||||
| 	var oidcUser OidcUser | ||||
| 	err = json.NewDecoder(res2.Body).Decode(&oidcUser) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	return &oidcUser, nil | ||||
| } | ||||
|  | ||||
| func OidcAuth(c *gin.Context) { | ||||
| 	session := sessions.Default(c) | ||||
| 	state := c.Query("state") | ||||
| 	if state == "" || session.Get("oauth_state") == nil || state != session.Get("oauth_state").(string) { | ||||
| 		c.JSON(http.StatusForbidden, gin.H{ | ||||
| 			"success": false, | ||||
| 			"message": "state is empty or not same", | ||||
| 		}) | ||||
| 		return | ||||
| 	} | ||||
| 	username := session.Get("username") | ||||
| 	if username != nil { | ||||
| 		OidcBind(c) | ||||
| 		return | ||||
| 	} | ||||
| 	if !config.OidcEnabled { | ||||
| 		c.JSON(http.StatusOK, gin.H{ | ||||
| 			"success": false, | ||||
| 			"message": "管理员未开启通过 OIDC 登录以及注册", | ||||
| 		}) | ||||
| 		return | ||||
| 	} | ||||
| 	code := c.Query("code") | ||||
| 	oidcUser, err := getOidcUserInfoByCode(code) | ||||
| 	if err != nil { | ||||
| 		c.JSON(http.StatusOK, gin.H{ | ||||
| 			"success": false, | ||||
| 			"message": err.Error(), | ||||
| 		}) | ||||
| 		return | ||||
| 	} | ||||
| 	user := model.User{ | ||||
| 		OidcId: oidcUser.OpenID, | ||||
| 	} | ||||
| 	if model.IsOidcIdAlreadyTaken(user.OidcId) { | ||||
| 		err := user.FillUserByOidcId() | ||||
| 		if err != nil { | ||||
| 			c.JSON(http.StatusOK, gin.H{ | ||||
| 				"success": false, | ||||
| 				"message": err.Error(), | ||||
| 			}) | ||||
| 			return | ||||
| 		} | ||||
| 	} else { | ||||
| 		if config.RegisterEnabled { | ||||
| 			user.Email = oidcUser.Email | ||||
| 			if oidcUser.PreferredUsername != "" { | ||||
| 				user.Username = oidcUser.PreferredUsername | ||||
| 			} else { | ||||
| 				user.Username = "oidc_" + strconv.Itoa(model.GetMaxUserId()+1) | ||||
| 			} | ||||
| 			if oidcUser.Name != "" { | ||||
| 				user.DisplayName = oidcUser.Name | ||||
| 			} else { | ||||
| 				user.DisplayName = "OIDC User" | ||||
| 			} | ||||
| 			err := user.Insert(0) | ||||
| 			if err != nil { | ||||
| 				c.JSON(http.StatusOK, gin.H{ | ||||
| 					"success": false, | ||||
| 					"message": err.Error(), | ||||
| 				}) | ||||
| 				return | ||||
| 			} | ||||
| 		} else { | ||||
| 			c.JSON(http.StatusOK, gin.H{ | ||||
| 				"success": false, | ||||
| 				"message": "管理员关闭了新用户注册", | ||||
| 			}) | ||||
| 			return | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	if user.Status != model.UserStatusEnabled { | ||||
| 		c.JSON(http.StatusOK, gin.H{ | ||||
| 			"message": "用户已被封禁", | ||||
| 			"success": false, | ||||
| 		}) | ||||
| 		return | ||||
| 	} | ||||
| 	controller.SetupLogin(&user, c) | ||||
| } | ||||
|  | ||||
| func OidcBind(c *gin.Context) { | ||||
| 	if !config.OidcEnabled { | ||||
| 		c.JSON(http.StatusOK, gin.H{ | ||||
| 			"success": false, | ||||
| 			"message": "管理员未开启通过 OIDC 登录以及注册", | ||||
| 		}) | ||||
| 		return | ||||
| 	} | ||||
| 	code := c.Query("code") | ||||
| 	oidcUser, err := getOidcUserInfoByCode(code) | ||||
| 	if err != nil { | ||||
| 		c.JSON(http.StatusOK, gin.H{ | ||||
| 			"success": false, | ||||
| 			"message": err.Error(), | ||||
| 		}) | ||||
| 		return | ||||
| 	} | ||||
| 	user := model.User{ | ||||
| 		OidcId: oidcUser.OpenID, | ||||
| 	} | ||||
| 	if model.IsOidcIdAlreadyTaken(user.OidcId) { | ||||
| 		c.JSON(http.StatusOK, gin.H{ | ||||
| 			"success": false, | ||||
| 			"message": "该 OIDC 账户已被绑定", | ||||
| 		}) | ||||
| 		return | ||||
| 	} | ||||
| 	session := sessions.Default(c) | ||||
| 	id := session.Get("id") | ||||
| 	// id := c.GetInt("id")  // critical bug! | ||||
| 	user.Id = id.(int) | ||||
| 	err = user.FillUserById() | ||||
| 	if err != nil { | ||||
| 		c.JSON(http.StatusOK, gin.H{ | ||||
| 			"success": false, | ||||
| 			"message": err.Error(), | ||||
| 		}) | ||||
| 		return | ||||
| 	} | ||||
| 	user.OidcId = oidcUser.OpenID | ||||
| 	err = user.Update(false) | ||||
| 	if err != nil { | ||||
| 		c.JSON(http.StatusOK, gin.H{ | ||||
| 			"success": false, | ||||
| 			"message": err.Error(), | ||||
| 		}) | ||||
| 		return | ||||
| 	} | ||||
| 	c.JSON(http.StatusOK, gin.H{ | ||||
| 		"success": true, | ||||
| 		"message": "bind", | ||||
| 	}) | ||||
| 	return | ||||
| } | ||||
| @@ -17,9 +17,11 @@ func GetSubscription(c *gin.Context) { | ||||
| 	if config.DisplayTokenStatEnabled { | ||||
| 		tokenId := c.GetInt(ctxkey.TokenId) | ||||
| 		token, err = model.GetTokenById(tokenId) | ||||
| 		if err == nil { | ||||
| 			expiredTime = token.ExpiredTime | ||||
| 			remainQuota = token.RemainQuota | ||||
| 			usedQuota = token.UsedQuota | ||||
| 		} | ||||
| 	} else { | ||||
| 		userId := c.GetInt(ctxkey.Id) | ||||
| 		remainQuota, err = model.GetUserQuota(userId) | ||||
|   | ||||
| @@ -4,16 +4,17 @@ import ( | ||||
| 	"encoding/json" | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
| 	"io" | ||||
| 	"net/http" | ||||
| 	"strconv" | ||||
| 	"time" | ||||
|  | ||||
| 	"github.com/songquanpeng/one-api/common/client" | ||||
| 	"github.com/songquanpeng/one-api/common/config" | ||||
| 	"github.com/songquanpeng/one-api/common/logger" | ||||
| 	"github.com/songquanpeng/one-api/model" | ||||
| 	"github.com/songquanpeng/one-api/monitor" | ||||
| 	"github.com/songquanpeng/one-api/relay/channeltype" | ||||
| 	"io" | ||||
| 	"net/http" | ||||
| 	"strconv" | ||||
| 	"time" | ||||
|  | ||||
| 	"github.com/gin-gonic/gin" | ||||
| ) | ||||
| @@ -81,6 +82,36 @@ type APGC2DGPTUsageResponse struct { | ||||
| 	TotalUsed      float64 `json:"total_used"` | ||||
| } | ||||
|  | ||||
| type SiliconFlowUsageResponse struct { | ||||
| 	Code    int    `json:"code"` | ||||
| 	Message string `json:"message"` | ||||
| 	Status  bool   `json:"status"` | ||||
| 	Data    struct { | ||||
| 		ID            string `json:"id"` | ||||
| 		Name          string `json:"name"` | ||||
| 		Image         string `json:"image"` | ||||
| 		Email         string `json:"email"` | ||||
| 		IsAdmin       bool   `json:"isAdmin"` | ||||
| 		Balance       string `json:"balance"` | ||||
| 		Status        string `json:"status"` | ||||
| 		Introduction  string `json:"introduction"` | ||||
| 		Role          string `json:"role"` | ||||
| 		ChargeBalance string `json:"chargeBalance"` | ||||
| 		TotalBalance  string `json:"totalBalance"` | ||||
| 		Category      string `json:"category"` | ||||
| 	} `json:"data"` | ||||
| } | ||||
|  | ||||
| type DeepSeekUsageResponse struct { | ||||
| 	IsAvailable  bool `json:"is_available"` | ||||
| 	BalanceInfos []struct { | ||||
| 		Currency        string `json:"currency"` | ||||
| 		TotalBalance    string `json:"total_balance"` | ||||
| 		GrantedBalance  string `json:"granted_balance"` | ||||
| 		ToppedUpBalance string `json:"topped_up_balance"` | ||||
| 	} `json:"balance_infos"` | ||||
| } | ||||
|  | ||||
| // GetAuthHeader get auth header | ||||
| func GetAuthHeader(token string) http.Header { | ||||
| 	h := http.Header{} | ||||
| @@ -203,6 +234,57 @@ func updateChannelAIGC2DBalance(channel *model.Channel) (float64, error) { | ||||
| 	return response.TotalAvailable, nil | ||||
| } | ||||
|  | ||||
| func updateChannelSiliconFlowBalance(channel *model.Channel) (float64, error) { | ||||
| 	url := "https://api.siliconflow.cn/v1/user/info" | ||||
| 	body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key)) | ||||
| 	if err != nil { | ||||
| 		return 0, err | ||||
| 	} | ||||
| 	response := SiliconFlowUsageResponse{} | ||||
| 	err = json.Unmarshal(body, &response) | ||||
| 	if err != nil { | ||||
| 		return 0, err | ||||
| 	} | ||||
| 	if response.Code != 20000 { | ||||
| 		return 0, fmt.Errorf("code: %d, message: %s", response.Code, response.Message) | ||||
| 	} | ||||
| 	balance, err := strconv.ParseFloat(response.Data.TotalBalance, 64) | ||||
| 	if err != nil { | ||||
| 		return 0, err | ||||
| 	} | ||||
| 	channel.UpdateBalance(balance) | ||||
| 	return balance, nil | ||||
| } | ||||
|  | ||||
| func updateChannelDeepSeekBalance(channel *model.Channel) (float64, error) { | ||||
| 	url := "https://api.deepseek.com/user/balance" | ||||
| 	body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key)) | ||||
| 	if err != nil { | ||||
| 		return 0, err | ||||
| 	} | ||||
| 	response := DeepSeekUsageResponse{} | ||||
| 	err = json.Unmarshal(body, &response) | ||||
| 	if err != nil { | ||||
| 		return 0, err | ||||
| 	} | ||||
| 	index := -1 | ||||
| 	for i, balanceInfo := range response.BalanceInfos { | ||||
| 		if balanceInfo.Currency == "CNY" { | ||||
| 			index = i | ||||
| 			break | ||||
| 		} | ||||
| 	} | ||||
| 	if index == -1 { | ||||
| 		return 0, errors.New("currency CNY not found") | ||||
| 	} | ||||
| 	balance, err := strconv.ParseFloat(response.BalanceInfos[index].TotalBalance, 64) | ||||
| 	if err != nil { | ||||
| 		return 0, err | ||||
| 	} | ||||
| 	channel.UpdateBalance(balance) | ||||
| 	return balance, nil | ||||
| } | ||||
|  | ||||
| func updateChannelBalance(channel *model.Channel) (float64, error) { | ||||
| 	baseURL := channeltype.ChannelBaseURLs[channel.Type] | ||||
| 	if channel.GetBaseURL() == "" { | ||||
| @@ -227,6 +309,10 @@ func updateChannelBalance(channel *model.Channel) (float64, error) { | ||||
| 		return updateChannelAPI2GPTBalance(channel) | ||||
| 	case channeltype.AIGC2D: | ||||
| 		return updateChannelAIGC2DBalance(channel) | ||||
| 	case channeltype.SiliconFlow: | ||||
| 		return updateChannelSiliconFlowBalance(channel) | ||||
| 	case channeltype.DeepSeek: | ||||
| 		return updateChannelDeepSeekBalance(channel) | ||||
| 	default: | ||||
| 		return 0, errors.New("尚未实现") | ||||
| 	} | ||||
|   | ||||
| @@ -76,10 +76,10 @@ func testChannel(channel *model.Channel, request *relaymodel.GeneralOpenAIReques | ||||
| 		if len(modelNames) > 0 { | ||||
| 			modelName = modelNames[0] | ||||
| 		} | ||||
| 	} | ||||
| 	if modelMap != nil && modelMap[modelName] != "" { | ||||
| 		modelName = modelMap[modelName] | ||||
| 	} | ||||
| 	} | ||||
| 	meta.OriginModelName, meta.ActualModelName = request.Model, modelName | ||||
| 	request.Model = modelName | ||||
| 	convertedRequest, err := adaptor.ConvertRequest(c, relaymode.ChatCompletions, request) | ||||
|   | ||||
| @@ -36,6 +36,12 @@ func GetStatus(c *gin.Context) { | ||||
| 			"chat_link":                   config.ChatLink, | ||||
| 			"quota_per_unit":              config.QuotaPerUnit, | ||||
| 			"display_in_currency":         config.DisplayInCurrencyEnabled, | ||||
| 			"oidc":                        config.OidcEnabled, | ||||
| 			"oidc_client_id":              config.OidcClientId, | ||||
| 			"oidc_well_known":             config.OidcWellKnown, | ||||
| 			"oidc_authorization_endpoint": config.OidcAuthorizationEndpoint, | ||||
| 			"oidc_token_endpoint":         config.OidcTokenEndpoint, | ||||
| 			"oidc_userinfo_endpoint":      config.OidcUserinfoEndpoint, | ||||
| 		}, | ||||
| 	}) | ||||
| 	return | ||||
|   | ||||
| @@ -60,7 +60,7 @@ func Relay(c *gin.Context) { | ||||
| 	channelName := c.GetString(ctxkey.ChannelName) | ||||
| 	group := c.GetString(ctxkey.Group) | ||||
| 	originalModel := c.GetString(ctxkey.OriginalModel) | ||||
| 	go processChannelRelayError(ctx, userId, channelId, channelName, bizErr) | ||||
| 	go processChannelRelayError(ctx, userId, channelId, channelName, *bizErr) | ||||
| 	requestId := c.GetString(helper.RequestIdKey) | ||||
| 	retryTimes := config.RetryTimes | ||||
| 	if !shouldRetry(c, bizErr.StatusCode) { | ||||
| @@ -87,8 +87,7 @@ func Relay(c *gin.Context) { | ||||
| 		channelId := c.GetInt(ctxkey.ChannelId) | ||||
| 		lastFailedChannelId = channelId | ||||
| 		channelName := c.GetString(ctxkey.ChannelName) | ||||
| 		// BUG: bizErr is in race condition | ||||
| 		go processChannelRelayError(ctx, userId, channelId, channelName, bizErr) | ||||
| 		go processChannelRelayError(ctx, userId, channelId, channelName, *bizErr) | ||||
| 	} | ||||
| 	if bizErr != nil { | ||||
| 		if bizErr.StatusCode == http.StatusTooManyRequests { | ||||
| @@ -122,7 +121,7 @@ func shouldRetry(c *gin.Context, statusCode int) bool { | ||||
| 	return true | ||||
| } | ||||
|  | ||||
| func processChannelRelayError(ctx context.Context, userId int, channelId int, channelName string, err *model.ErrorWithStatusCode) { | ||||
| func processChannelRelayError(ctx context.Context, userId int, channelId int, channelName string, err model.ErrorWithStatusCode) { | ||||
| 	logger.Errorf(ctx, "relay error (channel id %d, user id: %d): %s", channelId, userId, err.Message) | ||||
| 	// https://platform.openai.com/docs/guides/error-codes/api-errors | ||||
| 	if monitor.ShouldDisableChannel(&err.Error, err.StatusCode) { | ||||
|   | ||||
							
								
								
									
										8
									
								
								go.mod
									
									
									
									
									
								
							
							
						
						
									
										8
									
								
								go.mod
									
									
									
									
									
								
							| @@ -25,7 +25,7 @@ require ( | ||||
| 	github.com/pkoukk/tiktoken-go v0.1.7 | ||||
| 	github.com/smartystreets/goconvey v1.8.1 | ||||
| 	github.com/stretchr/testify v1.9.0 | ||||
| 	golang.org/x/crypto v0.24.0 | ||||
| 	golang.org/x/crypto v0.31.0 | ||||
| 	golang.org/x/image v0.18.0 | ||||
| 	google.golang.org/api v0.187.0 | ||||
| 	gorm.io/driver/mysql v1.5.6 | ||||
| @@ -99,9 +99,9 @@ require ( | ||||
| 	golang.org/x/arch v0.8.0 // indirect | ||||
| 	golang.org/x/net v0.26.0 // indirect | ||||
| 	golang.org/x/oauth2 v0.21.0 // indirect | ||||
| 	golang.org/x/sync v0.7.0 // indirect | ||||
| 	golang.org/x/sys v0.21.0 // indirect | ||||
| 	golang.org/x/text v0.16.0 // indirect | ||||
| 	golang.org/x/sync v0.10.0 // indirect | ||||
| 	golang.org/x/sys v0.28.0 // indirect | ||||
| 	golang.org/x/text v0.21.0 // indirect | ||||
| 	golang.org/x/time v0.5.0 // indirect | ||||
| 	google.golang.org/genproto/googleapis/api v0.0.0-20240617180043-68d350f18fd4 // indirect | ||||
| 	google.golang.org/genproto/googleapis/rpc v0.0.0-20240624140628-dc46fd24d27d // indirect | ||||
|   | ||||
							
								
								
									
										16
									
								
								go.sum
									
									
									
									
									
								
							
							
						
						
									
										16
									
								
								go.sum
									
									
									
									
									
								
							| @@ -222,8 +222,8 @@ golang.org/x/arch v0.8.0 h1:3wRIsP3pM4yUptoR96otTUOXI367OS0+c9eeRi9doIc= | ||||
| golang.org/x/arch v0.8.0/go.mod h1:FEVrYAQjsQXMVJ1nsMoVVXPZg6p2JE2mx8psSWTDQys= | ||||
| golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= | ||||
| golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= | ||||
| golang.org/x/crypto v0.24.0 h1:mnl8DM0o513X8fdIkmyFE/5hTYxbwYOjDS/+rK6qpRI= | ||||
| golang.org/x/crypto v0.24.0/go.mod h1:Z1PMYSOR5nyMcyAVAIQSKCDwalqy85Aqn1x3Ws4L5DM= | ||||
| golang.org/x/crypto v0.31.0 h1:ihbySMvVjLAeSH1IbfcRTkD/iNscyz8rGzjF/E5hV6U= | ||||
| golang.org/x/crypto v0.31.0/go.mod h1:kDsLvtWBEx7MV9tJOj9bnXsPbxwJQ6csT/x4KIN4Ssk= | ||||
| golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= | ||||
| golang.org/x/image v0.18.0 h1:jGzIakQa/ZXI1I0Fxvaa9W7yP25TqT6cHIHn+6CqvSQ= | ||||
| golang.org/x/image v0.18.0/go.mod h1:4yyo5vMFQjVjUcVk4jEQcU9MGy/rulF5WvUILseCM2E= | ||||
| @@ -244,20 +244,20 @@ golang.org/x/oauth2 v0.21.0/go.mod h1:XYTD2NtWslqkgxebSiOHnXEap4TF09sJSc7H1sXbht | ||||
| golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= | ||||
| golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= | ||||
| golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= | ||||
| golang.org/x/sync v0.7.0 h1:YsImfSBoP9QPYL0xyKJPq0gcaJdG3rInoqxTWbfQu9M= | ||||
| golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= | ||||
| golang.org/x/sync v0.10.0 h1:3NQrjDixjgGwUOCaF8w2+VYHv0Ve/vGYSbdkTa98gmQ= | ||||
| golang.org/x/sync v0.10.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= | ||||
| golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= | ||||
| golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= | ||||
| golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= | ||||
| golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= | ||||
| golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= | ||||
| golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= | ||||
| golang.org/x/sys v0.21.0 h1:rF+pYz3DAGSQAxAu1CbC7catZg4ebC4UIeIhKxBZvws= | ||||
| golang.org/x/sys v0.21.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= | ||||
| golang.org/x/sys v0.28.0 h1:Fksou7UEQUWlKvIdsqzJmUmCX3cZuD2+P3XyyzwMhlA= | ||||
| golang.org/x/sys v0.28.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= | ||||
| golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= | ||||
| golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= | ||||
| golang.org/x/text v0.16.0 h1:a94ExnEXNtEwYLGJSIUxnWoxoRz/ZcCsV63ROupILh4= | ||||
| golang.org/x/text v0.16.0/go.mod h1:GhwF1Be+LQoKShO3cGOHzqOgRrGaYc9AvblQOmPVHnI= | ||||
| golang.org/x/text v0.21.0 h1:zyQAAkrwaneQ066sspRyJaG9VNi/YJ1NfzcGB3hZ/qo= | ||||
| golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ= | ||||
| golang.org/x/time v0.5.0 h1:o7cqy6amK/52YcAKIPlM3a+Fpj35zvRj2TP+e1xFSfk= | ||||
| golang.org/x/time v0.5.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM= | ||||
| golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= | ||||
|   | ||||
| @@ -12,7 +12,7 @@ import ( | ||||
| ) | ||||
|  | ||||
| type ModelRequest struct { | ||||
| 	Model string `json:"model"` | ||||
| 	Model string `json:"model" form:"model"` | ||||
| } | ||||
|  | ||||
| func Distribute() func(c *gin.Context) { | ||||
| @@ -61,6 +61,9 @@ func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, mode | ||||
| 	c.Set(ctxkey.Channel, channel.Type) | ||||
| 	c.Set(ctxkey.ChannelId, channel.Id) | ||||
| 	c.Set(ctxkey.ChannelName, channel.Name) | ||||
| 	if channel.SystemPrompt != nil && *channel.SystemPrompt != "" { | ||||
| 		c.Set(ctxkey.SystemPrompt, *channel.SystemPrompt) | ||||
| 	} | ||||
| 	c.Set(ctxkey.ModelMapping, channel.GetModelMapping()) | ||||
| 	c.Set(ctxkey.OriginalModel, modelName) // for retry | ||||
| 	c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key)) | ||||
|   | ||||
							
								
								
									
										27
									
								
								middleware/gzip.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										27
									
								
								middleware/gzip.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,27 @@ | ||||
| package middleware | ||||
|  | ||||
| import ( | ||||
| 	"compress/gzip" | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"io" | ||||
| 	"net/http" | ||||
| ) | ||||
|  | ||||
| func GzipDecodeMiddleware() gin.HandlerFunc { | ||||
| 	return func(c *gin.Context) { | ||||
| 		if c.GetHeader("Content-Encoding") == "gzip" { | ||||
| 			gzipReader, err := gzip.NewReader(c.Request.Body) | ||||
| 			if err != nil { | ||||
| 				c.AbortWithStatus(http.StatusBadRequest) | ||||
| 				return | ||||
| 			} | ||||
| 			defer gzipReader.Close() | ||||
|  | ||||
| 			// Replace the request body with the decompressed data | ||||
| 			c.Request.Body = io.NopCloser(gzipReader) | ||||
| 		} | ||||
|  | ||||
| 		// Continue processing the request | ||||
| 		c.Next() | ||||
| 	} | ||||
| } | ||||
| @@ -37,6 +37,7 @@ type Channel struct { | ||||
| 	ModelMapping       *string `json:"model_mapping" gorm:"type:varchar(1024);default:''"` | ||||
| 	Priority           *int64  `json:"priority" gorm:"bigint;default:0"` | ||||
| 	Config             string  `json:"config"` | ||||
| 	SystemPrompt       *string `json:"system_prompt" gorm:"type:text"` | ||||
| } | ||||
|  | ||||
| type ChannelConfig struct { | ||||
|   | ||||
							
								
								
									
										13
									
								
								model/log.go
									
									
									
									
									
								
							
							
						
						
									
										13
									
								
								model/log.go
									
									
									
									
									
								
							| @@ -3,6 +3,7 @@ package model | ||||
| import ( | ||||
| 	"context" | ||||
| 	"fmt" | ||||
|  | ||||
| 	"github.com/songquanpeng/one-api/common" | ||||
| 	"github.com/songquanpeng/one-api/common/config" | ||||
| 	"github.com/songquanpeng/one-api/common/helper" | ||||
| @@ -152,7 +153,11 @@ func SearchUserLogs(userId int, keyword string) (logs []*Log, err error) { | ||||
| } | ||||
|  | ||||
| func SumUsedQuota(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string, channel int) (quota int64) { | ||||
| 	tx := LOG_DB.Table("logs").Select("ifnull(sum(quota),0)") | ||||
| 	ifnull := "ifnull" | ||||
| 	if common.UsingPostgreSQL { | ||||
| 		ifnull = "COALESCE" | ||||
| 	} | ||||
| 	tx := LOG_DB.Table("logs").Select(fmt.Sprintf("%s(sum(quota),0)", ifnull)) | ||||
| 	if username != "" { | ||||
| 		tx = tx.Where("username = ?", username) | ||||
| 	} | ||||
| @@ -176,7 +181,11 @@ func SumUsedQuota(logType int, startTimestamp int64, endTimestamp int64, modelNa | ||||
| } | ||||
|  | ||||
| func SumUsedToken(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string) (token int) { | ||||
| 	tx := LOG_DB.Table("logs").Select("ifnull(sum(prompt_tokens),0) + ifnull(sum(completion_tokens),0)") | ||||
| 	ifnull := "ifnull" | ||||
| 	if common.UsingPostgreSQL { | ||||
| 		ifnull = "COALESCE" | ||||
| 	} | ||||
| 	tx := LOG_DB.Table("logs").Select(fmt.Sprintf("%s(sum(prompt_tokens),0) + %s(sum(completion_tokens),0)", ifnull, ifnull)) | ||||
| 	if username != "" { | ||||
| 		tx = tx.Where("username = ?", username) | ||||
| 	} | ||||
|   | ||||
| @@ -28,6 +28,7 @@ func InitOptionMap() { | ||||
| 	config.OptionMap["PasswordRegisterEnabled"] = strconv.FormatBool(config.PasswordRegisterEnabled) | ||||
| 	config.OptionMap["EmailVerificationEnabled"] = strconv.FormatBool(config.EmailVerificationEnabled) | ||||
| 	config.OptionMap["GitHubOAuthEnabled"] = strconv.FormatBool(config.GitHubOAuthEnabled) | ||||
| 	config.OptionMap["OidcEnabled"] = strconv.FormatBool(config.OidcEnabled) | ||||
| 	config.OptionMap["WeChatAuthEnabled"] = strconv.FormatBool(config.WeChatAuthEnabled) | ||||
| 	config.OptionMap["TurnstileCheckEnabled"] = strconv.FormatBool(config.TurnstileCheckEnabled) | ||||
| 	config.OptionMap["RegisterEnabled"] = strconv.FormatBool(config.RegisterEnabled) | ||||
| @@ -130,6 +131,8 @@ func updateOptionMap(key string, value string) (err error) { | ||||
| 			config.EmailVerificationEnabled = boolValue | ||||
| 		case "GitHubOAuthEnabled": | ||||
| 			config.GitHubOAuthEnabled = boolValue | ||||
| 		case "OidcEnabled": | ||||
| 			config.OidcEnabled = boolValue | ||||
| 		case "WeChatAuthEnabled": | ||||
| 			config.WeChatAuthEnabled = boolValue | ||||
| 		case "TurnstileCheckEnabled": | ||||
| @@ -176,6 +179,18 @@ func updateOptionMap(key string, value string) (err error) { | ||||
| 		config.LarkClientId = value | ||||
| 	case "LarkClientSecret": | ||||
| 		config.LarkClientSecret = value | ||||
| 	case "OidcClientId": | ||||
| 		config.OidcClientId = value | ||||
| 	case "OidcClientSecret": | ||||
| 		config.OidcClientSecret = value | ||||
| 	case "OidcWellKnown": | ||||
| 		config.OidcWellKnown = value | ||||
| 	case "OidcAuthorizationEndpoint": | ||||
| 		config.OidcAuthorizationEndpoint = value | ||||
| 	case "OidcTokenEndpoint": | ||||
| 		config.OidcTokenEndpoint = value | ||||
| 	case "OidcUserinfoEndpoint": | ||||
| 		config.OidcUserinfoEndpoint = value | ||||
| 	case "Footer": | ||||
| 		config.Footer = value | ||||
| 	case "SystemName": | ||||
|   | ||||
| @@ -30,7 +30,7 @@ type Token struct { | ||||
| 	RemainQuota    int64   `json:"remain_quota" gorm:"bigint;default:0"` | ||||
| 	UnlimitedQuota bool    `json:"unlimited_quota" gorm:"default:false"` | ||||
| 	UsedQuota      int64   `json:"used_quota" gorm:"bigint;default:0"` // used quota | ||||
| 	Models         *string `json:"models" gorm:"default:''"`           // allowed models | ||||
| 	Models         *string `json:"models" gorm:"type:text"`            // allowed models | ||||
| 	Subnet         *string `json:"subnet" gorm:"default:''"`           // allowed subnet | ||||
| } | ||||
|  | ||||
| @@ -121,30 +121,40 @@ func GetTokenById(id int) (*Token, error) { | ||||
| 	return &token, err | ||||
| } | ||||
|  | ||||
| func (token *Token) Insert() error { | ||||
| func (t *Token) Insert() error { | ||||
| 	var err error | ||||
| 	err = DB.Create(token).Error | ||||
| 	err = DB.Create(t).Error | ||||
| 	return err | ||||
| } | ||||
|  | ||||
| // Update Make sure your token's fields is completed, because this will update non-zero values | ||||
| func (token *Token) Update() error { | ||||
| func (t *Token) Update() error { | ||||
| 	var err error | ||||
| 	err = DB.Model(token).Select("name", "status", "expired_time", "remain_quota", "unlimited_quota", "models", "subnet").Updates(token).Error | ||||
| 	err = DB.Model(t).Select("name", "status", "expired_time", "remain_quota", "unlimited_quota", "models", "subnet").Updates(t).Error | ||||
| 	return err | ||||
| } | ||||
|  | ||||
| func (token *Token) SelectUpdate() error { | ||||
| func (t *Token) SelectUpdate() error { | ||||
| 	// This can update zero values | ||||
| 	return DB.Model(token).Select("accessed_time", "status").Updates(token).Error | ||||
| 	return DB.Model(t).Select("accessed_time", "status").Updates(t).Error | ||||
| } | ||||
|  | ||||
| func (token *Token) Delete() error { | ||||
| func (t *Token) Delete() error { | ||||
| 	var err error | ||||
| 	err = DB.Delete(token).Error | ||||
| 	err = DB.Delete(t).Error | ||||
| 	return err | ||||
| } | ||||
|  | ||||
| func (t *Token) GetModels() string { | ||||
| 	if t == nil { | ||||
| 		return "" | ||||
| 	} | ||||
| 	if t.Models == nil { | ||||
| 		return "" | ||||
| 	} | ||||
| 	return *t.Models | ||||
| } | ||||
|  | ||||
| func DeleteTokenById(id int, userId int) (err error) { | ||||
| 	// Why we need userId here? In case user want to delete other's token. | ||||
| 	if id == 0 || userId == 0 { | ||||
| @@ -254,14 +264,14 @@ func PreConsumeTokenQuota(tokenId int, quota int64) (err error) { | ||||
|  | ||||
| func PostConsumeTokenQuota(tokenId int, quota int64) (err error) { | ||||
| 	token, err := GetTokenById(tokenId) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 	if quota > 0 { | ||||
| 		err = DecreaseUserQuota(token.UserId, quota) | ||||
| 	} else { | ||||
| 		err = IncreaseUserQuota(token.UserId, -quota) | ||||
| 	} | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 	if !token.UnlimitedQuota { | ||||
| 		if quota > 0 { | ||||
| 			err = DecreaseTokenQuota(tokenId, quota) | ||||
|   | ||||
| @@ -39,6 +39,7 @@ type User struct { | ||||
| 	GitHubId         string `json:"github_id" gorm:"column:github_id;index"` | ||||
| 	WeChatId         string `json:"wechat_id" gorm:"column:wechat_id;index"` | ||||
| 	LarkId           string `json:"lark_id" gorm:"column:lark_id;index"` | ||||
| 	OidcId           string `json:"oidc_id" gorm:"column:oidc_id;index"` | ||||
| 	VerificationCode string `json:"verification_code" gorm:"-:all"`                                    // this field is only for Email verification, don't save it to database! | ||||
| 	AccessToken      string `json:"access_token" gorm:"type:char(32);column:access_token;uniqueIndex"` // this token is for system management | ||||
| 	Quota            int64  `json:"quota" gorm:"bigint;default:0"` | ||||
| @@ -245,6 +246,14 @@ func (user *User) FillUserByLarkId() error { | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func (user *User) FillUserByOidcId() error { | ||||
| 	if user.OidcId == "" { | ||||
| 		return errors.New("oidc id 为空!") | ||||
| 	} | ||||
| 	DB.Where(User{OidcId: user.OidcId}).First(user) | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func (user *User) FillUserByWeChatId() error { | ||||
| 	if user.WeChatId == "" { | ||||
| 		return errors.New("WeChat id 为空!") | ||||
| @@ -277,6 +286,10 @@ func IsLarkIdAlreadyTaken(githubId string) bool { | ||||
| 	return DB.Where("lark_id = ?", githubId).Find(&User{}).RowsAffected == 1 | ||||
| } | ||||
|  | ||||
| func IsOidcIdAlreadyTaken(oidcId string) bool { | ||||
| 	return DB.Where("oidc_id = ?", oidcId).Find(&User{}).RowsAffected == 1 | ||||
| } | ||||
|  | ||||
| func IsUsernameAlreadyTaken(username string) bool { | ||||
| 	return DB.Where("username = ?", username).Find(&User{}).RowsAffected == 1 | ||||
| } | ||||
|   | ||||
| @@ -1,10 +1,11 @@ | ||||
| package monitor | ||||
|  | ||||
| import ( | ||||
| 	"github.com/songquanpeng/one-api/common/config" | ||||
| 	"github.com/songquanpeng/one-api/relay/model" | ||||
| 	"net/http" | ||||
| 	"strings" | ||||
|  | ||||
| 	"github.com/songquanpeng/one-api/common/config" | ||||
| 	"github.com/songquanpeng/one-api/relay/model" | ||||
| ) | ||||
|  | ||||
| func ShouldDisableChannel(err *model.Error, statusCode int) bool { | ||||
| @@ -18,31 +19,23 @@ func ShouldDisableChannel(err *model.Error, statusCode int) bool { | ||||
| 		return true | ||||
| 	} | ||||
| 	switch err.Type { | ||||
| 	case "insufficient_quota": | ||||
| 		return true | ||||
| 	// https://docs.anthropic.com/claude/reference/errors | ||||
| 	case "authentication_error": | ||||
| 		return true | ||||
| 	case "permission_error": | ||||
| 		return true | ||||
| 	case "forbidden": | ||||
| 	case "insufficient_quota", "authentication_error", "permission_error", "forbidden": | ||||
| 		return true | ||||
| 	} | ||||
| 	if err.Code == "invalid_api_key" || err.Code == "account_deactivated" { | ||||
| 		return true | ||||
| 	} | ||||
| 	if strings.HasPrefix(err.Message, "Your credit balance is too low") { // anthropic | ||||
| 		return true | ||||
| 	} else if strings.HasPrefix(err.Message, "This organization has been disabled.") { | ||||
| 		return true | ||||
| 	} | ||||
| 	//if strings.Contains(err.Message, "quota") { | ||||
| 	//	return true | ||||
| 	//} | ||||
| 	if strings.Contains(err.Message, "credit") { | ||||
| 		return true | ||||
| 	} | ||||
| 	if strings.Contains(err.Message, "balance") { | ||||
|  | ||||
| 	lowerMessage := strings.ToLower(err.Message) | ||||
| 	if strings.Contains(lowerMessage, "your access was terminated") || | ||||
| 		strings.Contains(lowerMessage, "violation of our policies") || | ||||
| 		strings.Contains(lowerMessage, "your credit balance is too low") || | ||||
| 		strings.Contains(lowerMessage, "organization has been disabled") || | ||||
| 		strings.Contains(lowerMessage, "credit") || | ||||
| 		strings.Contains(lowerMessage, "balance") || | ||||
| 		strings.Contains(lowerMessage, "permission denied") || | ||||
| 		strings.Contains(lowerMessage, "organization has been restricted") || // groq | ||||
| 		strings.Contains(lowerMessage, "已欠费") { | ||||
| 		return true | ||||
| 	} | ||||
| 	return false | ||||
|   | ||||
| @@ -16,6 +16,7 @@ import ( | ||||
| 	"github.com/songquanpeng/one-api/relay/adaptor/openai" | ||||
| 	"github.com/songquanpeng/one-api/relay/adaptor/palm" | ||||
| 	"github.com/songquanpeng/one-api/relay/adaptor/proxy" | ||||
| 	"github.com/songquanpeng/one-api/relay/adaptor/replicate" | ||||
| 	"github.com/songquanpeng/one-api/relay/adaptor/tencent" | ||||
| 	"github.com/songquanpeng/one-api/relay/adaptor/vertexai" | ||||
| 	"github.com/songquanpeng/one-api/relay/adaptor/xunfei" | ||||
| @@ -61,6 +62,8 @@ func GetAdaptor(apiType int) adaptor.Adaptor { | ||||
| 		return &vertexai.Adaptor{} | ||||
| 	case apitype.Proxy: | ||||
| 		return &proxy.Adaptor{} | ||||
| 	case apitype.Replicate: | ||||
| 		return &replicate.Adaptor{} | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
|   | ||||
| @@ -1,7 +1,23 @@ | ||||
| package ali | ||||
|  | ||||
| var ModelList = []string{ | ||||
| 	"qwen-turbo", "qwen-plus", "qwen-max", "qwen-max-longcontext", | ||||
| 	"text-embedding-v1", | ||||
| 	"qwen-turbo", "qwen-turbo-latest", | ||||
| 	"qwen-plus", "qwen-plus-latest", | ||||
| 	"qwen-max", "qwen-max-latest", | ||||
| 	"qwen-max-longcontext", | ||||
| 	"qwen-vl-max", "qwen-vl-max-latest", "qwen-vl-plus", "qwen-vl-plus-latest", | ||||
| 	"qwen-vl-ocr", "qwen-vl-ocr-latest", | ||||
| 	"qwen-audio-turbo", | ||||
| 	"qwen-math-plus", "qwen-math-plus-latest", "qwen-math-turbo", "qwen-math-turbo-latest", | ||||
| 	"qwen-coder-plus", "qwen-coder-plus-latest", "qwen-coder-turbo", "qwen-coder-turbo-latest", | ||||
| 	"qwq-32b-preview", "qwen2.5-72b-instruct", "qwen2.5-32b-instruct", "qwen2.5-14b-instruct", "qwen2.5-7b-instruct", "qwen2.5-3b-instruct", "qwen2.5-1.5b-instruct", "qwen2.5-0.5b-instruct", | ||||
| 	"qwen2-72b-instruct", "qwen2-57b-a14b-instruct", "qwen2-7b-instruct", "qwen2-1.5b-instruct", "qwen2-0.5b-instruct", | ||||
| 	"qwen1.5-110b-chat", "qwen1.5-72b-chat", "qwen1.5-32b-chat", "qwen1.5-14b-chat", "qwen1.5-7b-chat", "qwen1.5-1.8b-chat", "qwen1.5-0.5b-chat", | ||||
| 	"qwen-72b-chat", "qwen-14b-chat", "qwen-7b-chat", "qwen-1.8b-chat", "qwen-1.8b-longcontext-chat", | ||||
| 	"qwen2-vl-7b-instruct", "qwen2-vl-2b-instruct", "qwen-vl-v1", "qwen-vl-chat-v1", | ||||
| 	"qwen2-audio-instruct", "qwen-audio-chat", | ||||
| 	"qwen2.5-math-72b-instruct", "qwen2.5-math-7b-instruct", "qwen2.5-math-1.5b-instruct", "qwen2-math-72b-instruct", "qwen2-math-7b-instruct", "qwen2-math-1.5b-instruct", | ||||
| 	"qwen2.5-coder-32b-instruct", "qwen2.5-coder-14b-instruct", "qwen2.5-coder-7b-instruct", "qwen2.5-coder-3b-instruct", "qwen2.5-coder-1.5b-instruct", "qwen2.5-coder-0.5b-instruct", | ||||
| 	"text-embedding-v1", "text-embedding-v3", "text-embedding-v2", "text-embedding-async-v2", "text-embedding-async-v1", | ||||
| 	"ali-stable-diffusion-xl", "ali-stable-diffusion-v1.5", "wanx-v1", | ||||
| } | ||||
|   | ||||
| @@ -3,6 +3,7 @@ package ali | ||||
| import ( | ||||
| 	"bufio" | ||||
| 	"encoding/json" | ||||
| 	"github.com/songquanpeng/one-api/common/ctxkey" | ||||
| 	"github.com/songquanpeng/one-api/common/render" | ||||
| 	"io" | ||||
| 	"net/http" | ||||
| @@ -35,9 +36,7 @@ func ConvertRequest(request model.GeneralOpenAIRequest) *ChatRequest { | ||||
| 		enableSearch = true | ||||
| 		aliModel = strings.TrimSuffix(aliModel, EnableSearchModelSuffix) | ||||
| 	} | ||||
| 	if request.TopP >= 1 { | ||||
| 		request.TopP = 0.9999 | ||||
| 	} | ||||
| 	request.TopP = helper.Float64PtrMax(request.TopP, 0.9999) | ||||
| 	return &ChatRequest{ | ||||
| 		Model: aliModel, | ||||
| 		Input: Input{ | ||||
| @@ -59,7 +58,7 @@ func ConvertRequest(request model.GeneralOpenAIRequest) *ChatRequest { | ||||
|  | ||||
| func ConvertEmbeddingRequest(request model.GeneralOpenAIRequest) *EmbeddingRequest { | ||||
| 	return &EmbeddingRequest{ | ||||
| 		Model: "text-embedding-v1", | ||||
| 		Model: request.Model, | ||||
| 		Input: struct { | ||||
| 			Texts []string `json:"texts"` | ||||
| 		}{ | ||||
| @@ -102,8 +101,9 @@ func EmbeddingHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStat | ||||
| 			StatusCode: resp.StatusCode, | ||||
| 		}, nil | ||||
| 	} | ||||
|  | ||||
| 	requestModel := c.GetString(ctxkey.RequestModel) | ||||
| 	fullTextResponse := embeddingResponseAli2OpenAI(&aliResponse) | ||||
| 	fullTextResponse.Model = requestModel | ||||
| 	jsonResponse, err := json.Marshal(fullTextResponse) | ||||
| 	if err != nil { | ||||
| 		return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil | ||||
|   | ||||
| @@ -16,13 +16,13 @@ type Input struct { | ||||
| } | ||||
|  | ||||
| type Parameters struct { | ||||
| 	TopP              float64      `json:"top_p,omitempty"` | ||||
| 	TopP              *float64     `json:"top_p,omitempty"` | ||||
| 	TopK              int          `json:"top_k,omitempty"` | ||||
| 	Seed              uint64       `json:"seed,omitempty"` | ||||
| 	EnableSearch      bool         `json:"enable_search,omitempty"` | ||||
| 	IncrementalOutput bool         `json:"incremental_output,omitempty"` | ||||
| 	MaxTokens         int          `json:"max_tokens,omitempty"` | ||||
| 	Temperature       float64      `json:"temperature,omitempty"` | ||||
| 	Temperature       *float64     `json:"temperature,omitempty"` | ||||
| 	ResultFormat      string       `json:"result_format,omitempty"` | ||||
| 	Tools             []model.Tool `json:"tools,omitempty"` | ||||
| } | ||||
|   | ||||
| @@ -3,7 +3,10 @@ package anthropic | ||||
| var ModelList = []string{ | ||||
| 	"claude-instant-1.2", "claude-2.0", "claude-2.1", | ||||
| 	"claude-3-haiku-20240307", | ||||
| 	"claude-3-5-haiku-20241022", | ||||
| 	"claude-3-sonnet-20240229", | ||||
| 	"claude-3-opus-20240229", | ||||
| 	"claude-3-5-sonnet-20240620", | ||||
| 	"claude-3-5-sonnet-20241022", | ||||
| 	"claude-3-5-sonnet-latest", | ||||
| } | ||||
|   | ||||
| @@ -48,8 +48,8 @@ type Request struct { | ||||
| 	MaxTokens     int       `json:"max_tokens,omitempty"` | ||||
| 	StopSequences []string  `json:"stop_sequences,omitempty"` | ||||
| 	Stream        bool      `json:"stream,omitempty"` | ||||
| 	Temperature   float64   `json:"temperature,omitempty"` | ||||
| 	TopP          float64   `json:"top_p,omitempty"` | ||||
| 	Temperature   *float64  `json:"temperature,omitempty"` | ||||
| 	TopP          *float64  `json:"top_p,omitempty"` | ||||
| 	TopK          int       `json:"top_k,omitempty"` | ||||
| 	Tools         []Tool    `json:"tools,omitempty"` | ||||
| 	ToolChoice    any       `json:"tool_choice,omitempty"` | ||||
|   | ||||
| @@ -29,10 +29,13 @@ var AwsModelIDMap = map[string]string{ | ||||
| 	"claude-instant-1.2":         "anthropic.claude-instant-v1", | ||||
| 	"claude-2.0":                 "anthropic.claude-v2", | ||||
| 	"claude-2.1":                 "anthropic.claude-v2:1", | ||||
| 	"claude-3-sonnet-20240229":   "anthropic.claude-3-sonnet-20240229-v1:0", | ||||
| 	"claude-3-5-sonnet-20240620": "anthropic.claude-3-5-sonnet-20240620-v1:0", | ||||
| 	"claude-3-opus-20240229":     "anthropic.claude-3-opus-20240229-v1:0", | ||||
| 	"claude-3-haiku-20240307":    "anthropic.claude-3-haiku-20240307-v1:0", | ||||
| 	"claude-3-sonnet-20240229":   "anthropic.claude-3-sonnet-20240229-v1:0", | ||||
| 	"claude-3-opus-20240229":     "anthropic.claude-3-opus-20240229-v1:0", | ||||
| 	"claude-3-5-sonnet-20240620": "anthropic.claude-3-5-sonnet-20240620-v1:0", | ||||
| 	"claude-3-5-sonnet-20241022": "anthropic.claude-3-5-sonnet-20241022-v2:0", | ||||
| 	"claude-3-5-sonnet-latest":   "anthropic.claude-3-5-sonnet-20241022-v2:0", | ||||
| 	"claude-3-5-haiku-20241022":  "anthropic.claude-3-5-haiku-20241022-v1:0", | ||||
| } | ||||
|  | ||||
| func awsModelID(requestModel string) (string, error) { | ||||
|   | ||||
| @@ -11,8 +11,8 @@ type Request struct { | ||||
| 	Messages         []anthropic.Message `json:"messages"` | ||||
| 	System           string              `json:"system,omitempty"` | ||||
| 	MaxTokens        int                 `json:"max_tokens,omitempty"` | ||||
| 	Temperature      float64             `json:"temperature,omitempty"` | ||||
| 	TopP             float64             `json:"top_p,omitempty"` | ||||
| 	Temperature      *float64            `json:"temperature,omitempty"` | ||||
| 	TopP             *float64            `json:"top_p,omitempty"` | ||||
| 	TopK             int                 `json:"top_k,omitempty"` | ||||
| 	StopSequences    []string            `json:"stop_sequences,omitempty"` | ||||
| 	Tools            []anthropic.Tool    `json:"tools,omitempty"` | ||||
|   | ||||
| @@ -6,8 +6,8 @@ package aws | ||||
| type Request struct { | ||||
| 	Prompt      string   `json:"prompt"` | ||||
| 	MaxGenLen   int      `json:"max_gen_len,omitempty"` | ||||
| 	Temperature float64 `json:"temperature,omitempty"` | ||||
| 	TopP        float64 `json:"top_p,omitempty"` | ||||
| 	Temperature *float64 `json:"temperature,omitempty"` | ||||
| 	TopP        *float64 `json:"top_p,omitempty"` | ||||
| } | ||||
|  | ||||
| // Response is the response from AWS Llama3 | ||||
|   | ||||
| @@ -35,9 +35,9 @@ type Message struct { | ||||
|  | ||||
| type ChatRequest struct { | ||||
| 	Messages        []Message `json:"messages"` | ||||
| 	Temperature     float64   `json:"temperature,omitempty"` | ||||
| 	TopP            float64   `json:"top_p,omitempty"` | ||||
| 	PenaltyScore    float64   `json:"penalty_score,omitempty"` | ||||
| 	Temperature     *float64  `json:"temperature,omitempty"` | ||||
| 	TopP            *float64  `json:"top_p,omitempty"` | ||||
| 	PenaltyScore    *float64  `json:"penalty_score,omitempty"` | ||||
| 	Stream          bool      `json:"stream,omitempty"` | ||||
| 	System          string    `json:"system,omitempty"` | ||||
| 	DisableSearch   bool      `json:"disable_search,omitempty"` | ||||
|   | ||||
| @@ -1,6 +1,7 @@ | ||||
| package cloudflare | ||||
|  | ||||
| var ModelList = []string{ | ||||
| 	"@cf/meta/llama-3.1-8b-instruct", | ||||
| 	"@cf/meta/llama-2-7b-chat-fp16", | ||||
| 	"@cf/meta/llama-2-7b-chat-int8", | ||||
| 	"@cf/mistral/mistral-7b-instruct-v0.1", | ||||
|   | ||||
| @@ -9,5 +9,5 @@ type Request struct { | ||||
| 	Prompt      string          `json:"prompt,omitempty"` | ||||
| 	Raw         bool            `json:"raw,omitempty"` | ||||
| 	Stream      bool            `json:"stream,omitempty"` | ||||
| 	Temperature float64         `json:"temperature,omitempty"` | ||||
| 	Temperature *float64        `json:"temperature,omitempty"` | ||||
| } | ||||
|   | ||||
| @@ -43,7 +43,7 @@ func ConvertRequest(textRequest model.GeneralOpenAIRequest) *Request { | ||||
| 		K:                textRequest.TopK, | ||||
| 		Stream:           textRequest.Stream, | ||||
| 		FrequencyPenalty: textRequest.FrequencyPenalty, | ||||
| 		PresencePenalty:  textRequest.FrequencyPenalty, | ||||
| 		PresencePenalty:  textRequest.PresencePenalty, | ||||
| 		Seed:             int(textRequest.Seed), | ||||
| 	} | ||||
| 	if cohereRequest.Model == "" { | ||||
|   | ||||
| @@ -10,15 +10,15 @@ type Request struct { | ||||
| 	PromptTruncation string        `json:"prompt_truncation,omitempty"` // 默认值为"AUTO" | ||||
| 	Connectors       []Connector   `json:"connectors,omitempty"` | ||||
| 	Documents        []Document    `json:"documents,omitempty"` | ||||
| 	Temperature      float64       `json:"temperature,omitempty"` // 默认值为0.3 | ||||
| 	Temperature      *float64      `json:"temperature,omitempty"` // 默认值为0.3 | ||||
| 	MaxTokens        int           `json:"max_tokens,omitempty"` | ||||
| 	MaxInputTokens   int           `json:"max_input_tokens,omitempty"` | ||||
| 	K                int           `json:"k,omitempty"` // 默认值为0 | ||||
| 	P                float64       `json:"p,omitempty"` // 默认值为0.75 | ||||
| 	P                *float64      `json:"p,omitempty"` // 默认值为0.75 | ||||
| 	Seed             int           `json:"seed,omitempty"` | ||||
| 	StopSequences    []string      `json:"stop_sequences,omitempty"` | ||||
| 	FrequencyPenalty float64       `json:"frequency_penalty,omitempty"` // 默认值为0.0 | ||||
| 	PresencePenalty  float64       `json:"presence_penalty,omitempty"`  // 默认值为0.0 | ||||
| 	FrequencyPenalty *float64      `json:"frequency_penalty,omitempty"` // 默认值为0.0 | ||||
| 	PresencePenalty  *float64      `json:"presence_penalty,omitempty"`  // 默认值为0.0 | ||||
| 	Tools            []Tool        `json:"tools,omitempty"` | ||||
| 	ToolResults      []ToolResult  `json:"tool_results,omitempty"` | ||||
| } | ||||
|   | ||||
| @@ -24,7 +24,12 @@ func (a *Adaptor) Init(meta *meta.Meta) { | ||||
| } | ||||
|  | ||||
| func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) { | ||||
| 	version := helper.AssignOrDefault(meta.Config.APIVersion, config.GeminiVersion) | ||||
| 	defaultVersion := config.GeminiVersion | ||||
| 	if meta.ActualModelName == "gemini-2.0-flash-exp" { | ||||
| 		defaultVersion = "v1beta" | ||||
| 	} | ||||
|  | ||||
| 	version := helper.AssignOrDefault(meta.Config.APIVersion, defaultVersion) | ||||
| 	action := "" | ||||
| 	switch meta.Mode { | ||||
| 	case relaymode.Embeddings: | ||||
| @@ -36,6 +41,7 @@ func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) { | ||||
| 	if meta.IsStream { | ||||
| 		action = "streamGenerateContent?alt=sse" | ||||
| 	} | ||||
|  | ||||
| 	return fmt.Sprintf("%s/%s/models/%s:%s", meta.BaseURL, version, meta.ActualModelName, action), nil | ||||
| } | ||||
|  | ||||
|   | ||||
| @@ -3,6 +3,9 @@ package gemini | ||||
| // https://ai.google.dev/models/gemini | ||||
|  | ||||
| var ModelList = []string{ | ||||
| 	"gemini-pro", "gemini-1.0-pro-001", "gemini-1.5-pro", | ||||
| 	"gemini-pro-vision", "gemini-1.0-pro-vision-001", "embedding-001", "text-embedding-004", | ||||
| 	"gemini-pro", "gemini-1.0-pro", | ||||
| 	"gemini-1.5-flash", "gemini-1.5-pro", | ||||
| 	"text-embedding-004", "aqa", | ||||
| 	"gemini-2.0-flash-exp", | ||||
| 	"gemini-2.0-flash-thinking-exp", | ||||
| } | ||||
|   | ||||
| @@ -4,11 +4,12 @@ import ( | ||||
| 	"bufio" | ||||
| 	"encoding/json" | ||||
| 	"fmt" | ||||
| 	"github.com/songquanpeng/one-api/common/render" | ||||
| 	"io" | ||||
| 	"net/http" | ||||
| 	"strings" | ||||
|  | ||||
| 	"github.com/songquanpeng/one-api/common/render" | ||||
|  | ||||
| 	"github.com/songquanpeng/one-api/common" | ||||
| 	"github.com/songquanpeng/one-api/common/config" | ||||
| 	"github.com/songquanpeng/one-api/common/helper" | ||||
| @@ -28,6 +29,11 @@ const ( | ||||
| 	VisionMaxImageNum = 16 | ||||
| ) | ||||
|  | ||||
| var mimeTypeMap = map[string]string{ | ||||
| 	"json_object": "application/json", | ||||
| 	"text":        "text/plain", | ||||
| } | ||||
|  | ||||
| // Setting safety to the lowest possible values since Gemini is already powerless enough | ||||
| func ConvertRequest(textRequest model.GeneralOpenAIRequest) *ChatRequest { | ||||
| 	geminiRequest := ChatRequest{ | ||||
| @@ -49,6 +55,10 @@ func ConvertRequest(textRequest model.GeneralOpenAIRequest) *ChatRequest { | ||||
| 				Category:  "HARM_CATEGORY_DANGEROUS_CONTENT", | ||||
| 				Threshold: config.GeminiSafetySetting, | ||||
| 			}, | ||||
| 			{ | ||||
| 				Category:  "HARM_CATEGORY_CIVIC_INTEGRITY", | ||||
| 				Threshold: config.GeminiSafetySetting, | ||||
| 			}, | ||||
| 		}, | ||||
| 		GenerationConfig: ChatGenerationConfig{ | ||||
| 			Temperature:     textRequest.Temperature, | ||||
| @@ -56,6 +66,15 @@ func ConvertRequest(textRequest model.GeneralOpenAIRequest) *ChatRequest { | ||||
| 			MaxOutputTokens: textRequest.MaxTokens, | ||||
| 		}, | ||||
| 	} | ||||
| 	if textRequest.ResponseFormat != nil { | ||||
| 		if mimeType, ok := mimeTypeMap[textRequest.ResponseFormat.Type]; ok { | ||||
| 			geminiRequest.GenerationConfig.ResponseMimeType = mimeType | ||||
| 		} | ||||
| 		if textRequest.ResponseFormat.JsonSchema != nil { | ||||
| 			geminiRequest.GenerationConfig.ResponseSchema = textRequest.ResponseFormat.JsonSchema.Schema | ||||
| 			geminiRequest.GenerationConfig.ResponseMimeType = mimeTypeMap["json_object"] | ||||
| 		} | ||||
| 	} | ||||
| 	if textRequest.Tools != nil { | ||||
| 		functions := make([]model.Function, 0, len(textRequest.Tools)) | ||||
| 		for _, tool := range textRequest.Tools { | ||||
| @@ -232,7 +251,14 @@ func responseGeminiChat2OpenAI(response *ChatResponse) *openai.TextResponse { | ||||
| 			if candidate.Content.Parts[0].FunctionCall != nil { | ||||
| 				choice.Message.ToolCalls = getToolCalls(&candidate) | ||||
| 			} else { | ||||
| 				choice.Message.Content = candidate.Content.Parts[0].Text | ||||
| 				var builder strings.Builder | ||||
| 				for _, part := range candidate.Content.Parts { | ||||
| 					if i > 0 { | ||||
| 						builder.WriteString("\n") | ||||
| 					} | ||||
| 					builder.WriteString(part.Text) | ||||
| 				} | ||||
| 				choice.Message.Content = builder.String() | ||||
| 			} | ||||
| 		} else { | ||||
| 			choice.Message.Content = "" | ||||
|   | ||||
| @@ -65,8 +65,10 @@ type ChatTools struct { | ||||
| } | ||||
|  | ||||
| type ChatGenerationConfig struct { | ||||
| 	Temperature     float64  `json:"temperature,omitempty"` | ||||
| 	TopP            float64  `json:"topP,omitempty"` | ||||
| 	ResponseMimeType string   `json:"responseMimeType,omitempty"` | ||||
| 	ResponseSchema   any      `json:"responseSchema,omitempty"` | ||||
| 	Temperature      *float64 `json:"temperature,omitempty"` | ||||
| 	TopP             *float64 `json:"topP,omitempty"` | ||||
| 	TopK             float64  `json:"topK,omitempty"` | ||||
| 	MaxOutputTokens  int      `json:"maxOutputTokens,omitempty"` | ||||
| 	CandidateCount   int      `json:"candidateCount,omitempty"` | ||||
|   | ||||
| @@ -4,9 +4,24 @@ package groq | ||||
|  | ||||
| var ModelList = []string{ | ||||
| 	"gemma-7b-it", | ||||
| 	"llama2-7b-2048", | ||||
| 	"llama2-70b-4096", | ||||
| 	"mixtral-8x7b-32768", | ||||
| 	"llama3-8b-8192", | ||||
| 	"gemma2-9b-it", | ||||
| 	"llama-3.1-70b-versatile", | ||||
| 	"llama-3.1-8b-instant", | ||||
| 	"llama-3.2-11b-text-preview", | ||||
| 	"llama-3.2-11b-vision-preview", | ||||
| 	"llama-3.2-1b-preview", | ||||
| 	"llama-3.2-3b-preview", | ||||
| 	"llama-3.2-11b-vision-preview", | ||||
| 	"llama-3.2-90b-text-preview", | ||||
| 	"llama-3.2-90b-vision-preview", | ||||
| 	"llama-guard-3-8b", | ||||
| 	"llama3-70b-8192", | ||||
| 	"llama3-8b-8192", | ||||
| 	"llama3-groq-70b-8192-tool-use-preview", | ||||
| 	"llama3-groq-8b-8192-tool-use-preview", | ||||
| 	"llava-v1.5-7b-4096-preview", | ||||
| 	"mixtral-8x7b-32768", | ||||
| 	"distil-whisper-large-v3-en", | ||||
| 	"whisper-large-v3", | ||||
| 	"whisper-large-v3-turbo", | ||||
| } | ||||
|   | ||||
| @@ -24,7 +24,7 @@ func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) { | ||||
| 	// https://github.com/ollama/ollama/blob/main/docs/api.md | ||||
| 	fullRequestURL := fmt.Sprintf("%s/api/chat", meta.BaseURL) | ||||
| 	if meta.Mode == relaymode.Embeddings { | ||||
| 		fullRequestURL = fmt.Sprintf("%s/api/embeddings", meta.BaseURL) | ||||
| 		fullRequestURL = fmt.Sprintf("%s/api/embed", meta.BaseURL) | ||||
| 	} | ||||
| 	return fullRequestURL, nil | ||||
| } | ||||
|   | ||||
| @@ -31,6 +31,8 @@ func ConvertRequest(request model.GeneralOpenAIRequest) *ChatRequest { | ||||
| 			TopP:             request.TopP, | ||||
| 			FrequencyPenalty: request.FrequencyPenalty, | ||||
| 			PresencePenalty:  request.PresencePenalty, | ||||
| 			NumPredict:       request.MaxTokens, | ||||
| 			NumCtx:           request.NumCtx, | ||||
| 		}, | ||||
| 		Stream: request.Stream, | ||||
| 	} | ||||
| @@ -118,8 +120,10 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC | ||||
| 	common.SetEventStreamHeaders(c) | ||||
|  | ||||
| 	for scanner.Scan() { | ||||
| 		data := strings.TrimPrefix(scanner.Text(), "}") | ||||
| 		data = data + "}" | ||||
| 		data := scanner.Text() | ||||
| 		if strings.HasPrefix(data, "}") { | ||||
| 			data = strings.TrimPrefix(data, "}") + "}" | ||||
| 		} | ||||
|  | ||||
| 		var ollamaResponse ChatResponse | ||||
| 		err := json.Unmarshal([]byte(data), &ollamaResponse) | ||||
| @@ -158,7 +162,14 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC | ||||
| func ConvertEmbeddingRequest(request model.GeneralOpenAIRequest) *EmbeddingRequest { | ||||
| 	return &EmbeddingRequest{ | ||||
| 		Model: request.Model, | ||||
| 		Prompt: strings.Join(request.ParseInput(), " "), | ||||
| 		Input: request.ParseInput(), | ||||
| 		Options: &Options{ | ||||
| 			Seed:             int(request.Seed), | ||||
| 			Temperature:      request.Temperature, | ||||
| 			TopP:             request.TopP, | ||||
| 			FrequencyPenalty: request.FrequencyPenalty, | ||||
| 			PresencePenalty:  request.PresencePenalty, | ||||
| 		}, | ||||
| 	} | ||||
| } | ||||
|  | ||||
| @@ -201,15 +212,17 @@ func embeddingResponseOllama2OpenAI(response *EmbeddingResponse) *openai.Embeddi | ||||
| 	openAIEmbeddingResponse := openai.EmbeddingResponse{ | ||||
| 		Object: "list", | ||||
| 		Data:   make([]openai.EmbeddingResponseItem, 0, 1), | ||||
| 		Model:  "text-embedding-v1", | ||||
| 		Model:  response.Model, | ||||
| 		Usage:  model.Usage{TotalTokens: 0}, | ||||
| 	} | ||||
|  | ||||
| 	for i, embedding := range response.Embeddings { | ||||
| 		openAIEmbeddingResponse.Data = append(openAIEmbeddingResponse.Data, openai.EmbeddingResponseItem{ | ||||
| 			Object:    `embedding`, | ||||
| 		Index:     0, | ||||
| 		Embedding: response.Embedding, | ||||
| 			Index:     i, | ||||
| 			Embedding: embedding, | ||||
| 		}) | ||||
| 	} | ||||
| 	return &openAIEmbeddingResponse | ||||
| } | ||||
|  | ||||
|   | ||||
| @@ -2,11 +2,13 @@ package ollama | ||||
|  | ||||
| type Options struct { | ||||
| 	Seed             int      `json:"seed,omitempty"` | ||||
| 	Temperature      float64 `json:"temperature,omitempty"` | ||||
| 	Temperature      *float64 `json:"temperature,omitempty"` | ||||
| 	TopK             int      `json:"top_k,omitempty"` | ||||
| 	TopP             float64 `json:"top_p,omitempty"` | ||||
| 	FrequencyPenalty float64 `json:"frequency_penalty,omitempty"` | ||||
| 	PresencePenalty  float64 `json:"presence_penalty,omitempty"` | ||||
| 	TopP             *float64 `json:"top_p,omitempty"` | ||||
| 	FrequencyPenalty *float64 `json:"frequency_penalty,omitempty"` | ||||
| 	PresencePenalty  *float64 `json:"presence_penalty,omitempty"` | ||||
| 	NumPredict       int      `json:"num_predict,omitempty"` | ||||
| 	NumCtx           int      `json:"num_ctx,omitempty"` | ||||
| } | ||||
|  | ||||
| type Message struct { | ||||
| @@ -38,10 +40,14 @@ type ChatResponse struct { | ||||
|  | ||||
| type EmbeddingRequest struct { | ||||
| 	Model string   `json:"model"` | ||||
| 	Prompt string `json:"prompt"` | ||||
| 	Input []string `json:"input"` | ||||
| 	// Truncate  bool     `json:"truncate,omitempty"` | ||||
| 	Options *Options `json:"options,omitempty"` | ||||
| 	// KeepAlive string   `json:"keep_alive,omitempty"` | ||||
| } | ||||
|  | ||||
| type EmbeddingResponse struct { | ||||
| 	Error      string      `json:"error,omitempty"` | ||||
| 	Embedding []float64 `json:"embedding,omitempty"` | ||||
| 	Model      string      `json:"model"` | ||||
| 	Embeddings [][]float64 `json:"embeddings"` | ||||
| } | ||||
|   | ||||
| @@ -75,6 +75,13 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G | ||||
| 	if request == nil { | ||||
| 		return nil, errors.New("request is nil") | ||||
| 	} | ||||
| 	if request.Stream { | ||||
| 		// always return usage in stream mode | ||||
| 		if request.StreamOptions == nil { | ||||
| 			request.StreamOptions = &model.StreamOptions{} | ||||
| 		} | ||||
| 		request.StreamOptions.IncludeUsage = true | ||||
| 	} | ||||
| 	return request, nil | ||||
| } | ||||
|  | ||||
|   | ||||
| @@ -11,8 +11,10 @@ import ( | ||||
| 	"github.com/songquanpeng/one-api/relay/adaptor/mistral" | ||||
| 	"github.com/songquanpeng/one-api/relay/adaptor/moonshot" | ||||
| 	"github.com/songquanpeng/one-api/relay/adaptor/novita" | ||||
| 	"github.com/songquanpeng/one-api/relay/adaptor/siliconflow" | ||||
| 	"github.com/songquanpeng/one-api/relay/adaptor/stepfun" | ||||
| 	"github.com/songquanpeng/one-api/relay/adaptor/togetherai" | ||||
| 	"github.com/songquanpeng/one-api/relay/adaptor/xai" | ||||
| 	"github.com/songquanpeng/one-api/relay/channeltype" | ||||
| ) | ||||
|  | ||||
| @@ -30,6 +32,8 @@ var CompatibleChannels = []int{ | ||||
| 	channeltype.DeepSeek, | ||||
| 	channeltype.TogetherAI, | ||||
| 	channeltype.Novita, | ||||
| 	channeltype.SiliconFlow, | ||||
| 	channeltype.XAI, | ||||
| } | ||||
|  | ||||
| func GetCompatibleChannelMeta(channelType int) (string, []string) { | ||||
| @@ -60,6 +64,10 @@ func GetCompatibleChannelMeta(channelType int) (string, []string) { | ||||
| 		return "doubao", doubao.ModelList | ||||
| 	case channeltype.Novita: | ||||
| 		return "novita", novita.ModelList | ||||
| 	case channeltype.SiliconFlow: | ||||
| 		return "siliconflow", siliconflow.ModelList | ||||
| 	case channeltype.XAI: | ||||
| 		return "xai", xai.ModelList | ||||
| 	default: | ||||
| 		return "openai", ModelList | ||||
| 	} | ||||
|   | ||||
| @@ -8,6 +8,9 @@ var ModelList = []string{ | ||||
| 	"gpt-4-32k", "gpt-4-32k-0314", "gpt-4-32k-0613", | ||||
| 	"gpt-4-turbo-preview", "gpt-4-turbo", "gpt-4-turbo-2024-04-09", | ||||
| 	"gpt-4o", "gpt-4o-2024-05-13", | ||||
| 	"gpt-4o-2024-08-06", | ||||
| 	"gpt-4o-2024-11-20", | ||||
| 	"chatgpt-4o-latest", | ||||
| 	"gpt-4o-mini", "gpt-4o-mini-2024-07-18", | ||||
| 	"gpt-4-vision-preview", | ||||
| 	"text-embedding-ada-002", "text-embedding-3-small", "text-embedding-3-large", | ||||
| @@ -18,4 +21,7 @@ var ModelList = []string{ | ||||
| 	"dall-e-2", "dall-e-3", | ||||
| 	"whisper-1", | ||||
| 	"tts-1", "tts-1-1106", "tts-1-hd", "tts-1-hd-1106", | ||||
| 	"o1", "o1-2024-12-17", | ||||
| 	"o1-preview", "o1-preview-2024-09-12", | ||||
| 	"o1-mini", "o1-mini-2024-09-12", | ||||
| } | ||||
|   | ||||
| @@ -2,15 +2,16 @@ package openai | ||||
|  | ||||
| import ( | ||||
| 	"fmt" | ||||
| 	"strings" | ||||
|  | ||||
| 	"github.com/songquanpeng/one-api/relay/channeltype" | ||||
| 	"github.com/songquanpeng/one-api/relay/model" | ||||
| 	"strings" | ||||
| ) | ||||
|  | ||||
| func ResponseText2Usage(responseText string, modeName string, promptTokens int) *model.Usage { | ||||
| func ResponseText2Usage(responseText string, modelName string, promptTokens int) *model.Usage { | ||||
| 	usage := &model.Usage{} | ||||
| 	usage.PromptTokens = promptTokens | ||||
| 	usage.CompletionTokens = CountTokenText(responseText, modeName) | ||||
| 	usage.CompletionTokens = CountTokenText(responseText, modelName) | ||||
| 	usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens | ||||
| 	return usage | ||||
| } | ||||
|   | ||||
| @@ -55,8 +55,8 @@ func StreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*model.E | ||||
| 				render.StringData(c, data) // if error happened, pass the data to client | ||||
| 				continue                   // just ignore the error | ||||
| 			} | ||||
| 			if len(streamResponse.Choices) == 0 { | ||||
| 				// but for empty choice, we should not pass it to client, this is for azure | ||||
| 			if len(streamResponse.Choices) == 0 && streamResponse.Usage == nil { | ||||
| 				// but for empty choice and no usage, we should not pass it to client, this is for azure | ||||
| 				continue // just ignore empty choice | ||||
| 			} | ||||
| 			render.StringData(c, data) | ||||
|   | ||||
| @@ -1,8 +1,16 @@ | ||||
| package openai | ||||
|  | ||||
| import "github.com/songquanpeng/one-api/relay/model" | ||||
| import ( | ||||
| 	"context" | ||||
| 	"fmt" | ||||
|  | ||||
| 	"github.com/songquanpeng/one-api/common/logger" | ||||
| 	"github.com/songquanpeng/one-api/relay/model" | ||||
| ) | ||||
|  | ||||
| func ErrorWrapper(err error, code string, statusCode int) *model.ErrorWithStatusCode { | ||||
| 	logger.Error(context.TODO(), fmt.Sprintf("[%s]%+v", code, err)) | ||||
|  | ||||
| 	Error := model.Error{ | ||||
| 		Message: err.Error(), | ||||
| 		Type:    "one_api_error", | ||||
|   | ||||
| @@ -20,9 +20,9 @@ type Prompt struct { | ||||
|  | ||||
| type ChatRequest struct { | ||||
| 	Prompt         Prompt   `json:"prompt"` | ||||
| 	Temperature    float64 `json:"temperature,omitempty"` | ||||
| 	Temperature    *float64 `json:"temperature,omitempty"` | ||||
| 	CandidateCount int      `json:"candidateCount,omitempty"` | ||||
| 	TopP           float64 `json:"topP,omitempty"` | ||||
| 	TopP           *float64 `json:"topP,omitempty"` | ||||
| 	TopK           int      `json:"topK,omitempty"` | ||||
| } | ||||
|  | ||||
|   | ||||
							
								
								
									
										136
									
								
								relay/adaptor/replicate/adaptor.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										136
									
								
								relay/adaptor/replicate/adaptor.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,136 @@ | ||||
| package replicate | ||||
|  | ||||
| import ( | ||||
| 	"fmt" | ||||
| 	"io" | ||||
| 	"net/http" | ||||
| 	"slices" | ||||
| 	"strings" | ||||
| 	"time" | ||||
|  | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"github.com/pkg/errors" | ||||
| 	"github.com/songquanpeng/one-api/common/logger" | ||||
| 	"github.com/songquanpeng/one-api/relay/adaptor" | ||||
| 	"github.com/songquanpeng/one-api/relay/adaptor/openai" | ||||
| 	"github.com/songquanpeng/one-api/relay/meta" | ||||
| 	"github.com/songquanpeng/one-api/relay/model" | ||||
| 	"github.com/songquanpeng/one-api/relay/relaymode" | ||||
| ) | ||||
|  | ||||
| type Adaptor struct { | ||||
| 	meta *meta.Meta | ||||
| } | ||||
|  | ||||
| // ConvertImageRequest implements adaptor.Adaptor. | ||||
| func (*Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) { | ||||
| 	return DrawImageRequest{ | ||||
| 		Input: ImageInput{ | ||||
| 			Steps:           25, | ||||
| 			Prompt:          request.Prompt, | ||||
| 			Guidance:        3, | ||||
| 			Seed:            int(time.Now().UnixNano()), | ||||
| 			SafetyTolerance: 5, | ||||
| 			NImages:         1, // replicate will always return 1 image | ||||
| 			Width:           1440, | ||||
| 			Height:          1440, | ||||
| 			AspectRatio:     "1:1", | ||||
| 		}, | ||||
| 	}, nil | ||||
| } | ||||
|  | ||||
| func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) { | ||||
| 	if !request.Stream { | ||||
| 		// TODO: support non-stream mode | ||||
| 		return nil, errors.Errorf("replicate models only support stream mode now, please set stream=true") | ||||
| 	} | ||||
|  | ||||
| 	// Build the prompt from OpenAI messages | ||||
| 	var promptBuilder strings.Builder | ||||
| 	for _, message := range request.Messages { | ||||
| 		switch msgCnt := message.Content.(type) { | ||||
| 		case string: | ||||
| 			promptBuilder.WriteString(message.Role) | ||||
| 			promptBuilder.WriteString(": ") | ||||
| 			promptBuilder.WriteString(msgCnt) | ||||
| 			promptBuilder.WriteString("\n") | ||||
| 		default: | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	replicateRequest := ReplicateChatRequest{ | ||||
| 		Input: ChatInput{ | ||||
| 			Prompt:           promptBuilder.String(), | ||||
| 			MaxTokens:        request.MaxTokens, | ||||
| 			Temperature:      1.0, | ||||
| 			TopP:             1.0, | ||||
| 			PresencePenalty:  0.0, | ||||
| 			FrequencyPenalty: 0.0, | ||||
| 		}, | ||||
| 	} | ||||
|  | ||||
| 	// Map optional fields | ||||
| 	if request.Temperature != nil { | ||||
| 		replicateRequest.Input.Temperature = *request.Temperature | ||||
| 	} | ||||
| 	if request.TopP != nil { | ||||
| 		replicateRequest.Input.TopP = *request.TopP | ||||
| 	} | ||||
| 	if request.PresencePenalty != nil { | ||||
| 		replicateRequest.Input.PresencePenalty = *request.PresencePenalty | ||||
| 	} | ||||
| 	if request.FrequencyPenalty != nil { | ||||
| 		replicateRequest.Input.FrequencyPenalty = *request.FrequencyPenalty | ||||
| 	} | ||||
| 	if request.MaxTokens > 0 { | ||||
| 		replicateRequest.Input.MaxTokens = request.MaxTokens | ||||
| 	} else if request.MaxTokens == 0 { | ||||
| 		replicateRequest.Input.MaxTokens = 500 | ||||
| 	} | ||||
|  | ||||
| 	return replicateRequest, nil | ||||
| } | ||||
|  | ||||
| func (a *Adaptor) Init(meta *meta.Meta) { | ||||
| 	a.meta = meta | ||||
| } | ||||
|  | ||||
| func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) { | ||||
| 	if !slices.Contains(ModelList, meta.OriginModelName) { | ||||
| 		return "", errors.Errorf("model %s not supported", meta.OriginModelName) | ||||
| 	} | ||||
|  | ||||
| 	return fmt.Sprintf("https://api.replicate.com/v1/models/%s/predictions", meta.OriginModelName), nil | ||||
| } | ||||
|  | ||||
| func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error { | ||||
| 	adaptor.SetupCommonRequestHeader(c, req, meta) | ||||
| 	req.Header.Set("Authorization", "Bearer "+meta.APIKey) | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) { | ||||
| 	logger.Info(c, "send request to replicate") | ||||
| 	return adaptor.DoRequestHelper(a, c, meta, requestBody) | ||||
| } | ||||
|  | ||||
| func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) { | ||||
| 	switch meta.Mode { | ||||
| 	case relaymode.ImagesGenerations: | ||||
| 		err, usage = ImageHandler(c, resp) | ||||
| 	case relaymode.ChatCompletions: | ||||
| 		err, usage = ChatHandler(c, resp) | ||||
| 	default: | ||||
| 		err = openai.ErrorWrapper(errors.New("not implemented"), "not_implemented", http.StatusInternalServerError) | ||||
| 	} | ||||
|  | ||||
| 	return | ||||
| } | ||||
|  | ||||
| func (a *Adaptor) GetModelList() []string { | ||||
| 	return ModelList | ||||
| } | ||||
|  | ||||
| func (a *Adaptor) GetChannelName() string { | ||||
| 	return "replicate" | ||||
| } | ||||
							
								
								
									
										191
									
								
								relay/adaptor/replicate/chat.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										191
									
								
								relay/adaptor/replicate/chat.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,191 @@ | ||||
| package replicate | ||||
|  | ||||
| import ( | ||||
| 	"bufio" | ||||
| 	"encoding/json" | ||||
| 	"io" | ||||
| 	"net/http" | ||||
| 	"strings" | ||||
| 	"time" | ||||
|  | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"github.com/pkg/errors" | ||||
| 	"github.com/songquanpeng/one-api/common" | ||||
| 	"github.com/songquanpeng/one-api/common/render" | ||||
| 	"github.com/songquanpeng/one-api/relay/adaptor/openai" | ||||
| 	"github.com/songquanpeng/one-api/relay/meta" | ||||
| 	"github.com/songquanpeng/one-api/relay/model" | ||||
| ) | ||||
|  | ||||
| func ChatHandler(c *gin.Context, resp *http.Response) ( | ||||
| 	srvErr *model.ErrorWithStatusCode, usage *model.Usage) { | ||||
| 	if resp.StatusCode != http.StatusCreated { | ||||
| 		payload, _ := io.ReadAll(resp.Body) | ||||
| 		return openai.ErrorWrapper( | ||||
| 				errors.Errorf("bad_status_code [%d]%s", resp.StatusCode, string(payload)), | ||||
| 				"bad_status_code", http.StatusInternalServerError), | ||||
| 			nil | ||||
| 	} | ||||
|  | ||||
| 	respBody, err := io.ReadAll(resp.Body) | ||||
| 	if err != nil { | ||||
| 		return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil | ||||
| 	} | ||||
|  | ||||
| 	respData := new(ChatResponse) | ||||
| 	if err = json.Unmarshal(respBody, respData); err != nil { | ||||
| 		return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil | ||||
| 	} | ||||
|  | ||||
| 	for { | ||||
| 		err = func() error { | ||||
| 			// get task | ||||
| 			taskReq, err := http.NewRequestWithContext(c.Request.Context(), | ||||
| 				http.MethodGet, respData.URLs.Get, nil) | ||||
| 			if err != nil { | ||||
| 				return errors.Wrap(err, "new request") | ||||
| 			} | ||||
|  | ||||
| 			taskReq.Header.Set("Authorization", "Bearer "+meta.GetByContext(c).APIKey) | ||||
| 			taskResp, err := http.DefaultClient.Do(taskReq) | ||||
| 			if err != nil { | ||||
| 				return errors.Wrap(err, "get task") | ||||
| 			} | ||||
| 			defer taskResp.Body.Close() | ||||
|  | ||||
| 			if taskResp.StatusCode != http.StatusOK { | ||||
| 				payload, _ := io.ReadAll(taskResp.Body) | ||||
| 				return errors.Errorf("bad status code [%d]%s", | ||||
| 					taskResp.StatusCode, string(payload)) | ||||
| 			} | ||||
|  | ||||
| 			taskBody, err := io.ReadAll(taskResp.Body) | ||||
| 			if err != nil { | ||||
| 				return errors.Wrap(err, "read task response") | ||||
| 			} | ||||
|  | ||||
| 			taskData := new(ChatResponse) | ||||
| 			if err = json.Unmarshal(taskBody, taskData); err != nil { | ||||
| 				return errors.Wrap(err, "decode task response") | ||||
| 			} | ||||
|  | ||||
| 			switch taskData.Status { | ||||
| 			case "succeeded": | ||||
| 			case "failed", "canceled": | ||||
| 				return errors.Errorf("task failed, [%s]%s", taskData.Status, taskData.Error) | ||||
| 			default: | ||||
| 				time.Sleep(time.Second * 3) | ||||
| 				return errNextLoop | ||||
| 			} | ||||
|  | ||||
| 			if taskData.URLs.Stream == "" { | ||||
| 				return errors.New("stream url is empty") | ||||
| 			} | ||||
|  | ||||
| 			// request stream url | ||||
| 			responseText, err := chatStreamHandler(c, taskData.URLs.Stream) | ||||
| 			if err != nil { | ||||
| 				return errors.Wrap(err, "chat stream handler") | ||||
| 			} | ||||
|  | ||||
| 			ctxMeta := meta.GetByContext(c) | ||||
| 			usage = openai.ResponseText2Usage(responseText, | ||||
| 				ctxMeta.ActualModelName, ctxMeta.PromptTokens) | ||||
| 			return nil | ||||
| 		}() | ||||
| 		if err != nil { | ||||
| 			if errors.Is(err, errNextLoop) { | ||||
| 				continue | ||||
| 			} | ||||
|  | ||||
| 			return openai.ErrorWrapper(err, "chat_task_failed", http.StatusInternalServerError), nil | ||||
| 		} | ||||
|  | ||||
| 		break | ||||
| 	} | ||||
|  | ||||
| 	return nil, usage | ||||
| } | ||||
|  | ||||
| const ( | ||||
| 	eventPrefix = "event: " | ||||
| 	dataPrefix  = "data: " | ||||
| 	done        = "[DONE]" | ||||
| ) | ||||
|  | ||||
| func chatStreamHandler(c *gin.Context, streamUrl string) (responseText string, err error) { | ||||
| 	// request stream endpoint | ||||
| 	streamReq, err := http.NewRequestWithContext(c.Request.Context(), http.MethodGet, streamUrl, nil) | ||||
| 	if err != nil { | ||||
| 		return "", errors.Wrap(err, "new request to stream") | ||||
| 	} | ||||
|  | ||||
| 	streamReq.Header.Set("Authorization", "Bearer "+meta.GetByContext(c).APIKey) | ||||
| 	streamReq.Header.Set("Accept", "text/event-stream") | ||||
| 	streamReq.Header.Set("Cache-Control", "no-store") | ||||
|  | ||||
| 	resp, err := http.DefaultClient.Do(streamReq) | ||||
| 	if err != nil { | ||||
| 		return "", errors.Wrap(err, "do request to stream") | ||||
| 	} | ||||
| 	defer resp.Body.Close() | ||||
|  | ||||
| 	if resp.StatusCode != http.StatusOK { | ||||
| 		payload, _ := io.ReadAll(resp.Body) | ||||
| 		return "", errors.Errorf("bad status code [%d]%s", resp.StatusCode, string(payload)) | ||||
| 	} | ||||
|  | ||||
| 	scanner := bufio.NewScanner(resp.Body) | ||||
| 	scanner.Split(bufio.ScanLines) | ||||
|  | ||||
| 	common.SetEventStreamHeaders(c) | ||||
| 	doneRendered := false | ||||
| 	for scanner.Scan() { | ||||
| 		line := strings.TrimSpace(scanner.Text()) | ||||
| 		if line == "" { | ||||
| 			continue | ||||
| 		} | ||||
|  | ||||
| 		// Handle comments starting with ':' | ||||
| 		if strings.HasPrefix(line, ":") { | ||||
| 			continue | ||||
| 		} | ||||
|  | ||||
| 		// Parse SSE fields | ||||
| 		if strings.HasPrefix(line, eventPrefix) { | ||||
| 			event := strings.TrimSpace(line[len(eventPrefix):]) | ||||
| 			var data string | ||||
| 			// Read the following lines to get data and id | ||||
| 			for scanner.Scan() { | ||||
| 				nextLine := scanner.Text() | ||||
| 				if nextLine == "" { | ||||
| 					break | ||||
| 				} | ||||
| 				if strings.HasPrefix(nextLine, dataPrefix) { | ||||
| 					data = nextLine[len(dataPrefix):] | ||||
| 				} else if strings.HasPrefix(nextLine, "id:") { | ||||
| 					// id = strings.TrimSpace(nextLine[len("id:"):]) | ||||
| 				} | ||||
| 			} | ||||
|  | ||||
| 			if event == "output" { | ||||
| 				render.StringData(c, data) | ||||
| 				responseText += data | ||||
| 			} else if event == "done" { | ||||
| 				render.Done(c) | ||||
| 				doneRendered = true | ||||
| 				break | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	if err := scanner.Err(); err != nil { | ||||
| 		return "", errors.Wrap(err, "scan stream") | ||||
| 	} | ||||
|  | ||||
| 	if !doneRendered { | ||||
| 		render.Done(c) | ||||
| 	} | ||||
|  | ||||
| 	return responseText, nil | ||||
| } | ||||
							
								
								
									
										58
									
								
								relay/adaptor/replicate/constant.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										58
									
								
								relay/adaptor/replicate/constant.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,58 @@ | ||||
| package replicate | ||||
|  | ||||
| // ModelList is a list of models that can be used with Replicate. | ||||
| // | ||||
| // https://replicate.com/pricing | ||||
| var ModelList = []string{ | ||||
| 	// ------------------------------------- | ||||
| 	// image model | ||||
| 	// ------------------------------------- | ||||
| 	"black-forest-labs/flux-1.1-pro", | ||||
| 	"black-forest-labs/flux-1.1-pro-ultra", | ||||
| 	"black-forest-labs/flux-canny-dev", | ||||
| 	"black-forest-labs/flux-canny-pro", | ||||
| 	"black-forest-labs/flux-depth-dev", | ||||
| 	"black-forest-labs/flux-depth-pro", | ||||
| 	"black-forest-labs/flux-dev", | ||||
| 	"black-forest-labs/flux-dev-lora", | ||||
| 	"black-forest-labs/flux-fill-dev", | ||||
| 	"black-forest-labs/flux-fill-pro", | ||||
| 	"black-forest-labs/flux-pro", | ||||
| 	"black-forest-labs/flux-redux-dev", | ||||
| 	"black-forest-labs/flux-redux-schnell", | ||||
| 	"black-forest-labs/flux-schnell", | ||||
| 	"black-forest-labs/flux-schnell-lora", | ||||
| 	"ideogram-ai/ideogram-v2", | ||||
| 	"ideogram-ai/ideogram-v2-turbo", | ||||
| 	"recraft-ai/recraft-v3", | ||||
| 	"recraft-ai/recraft-v3-svg", | ||||
| 	"stability-ai/stable-diffusion-3", | ||||
| 	"stability-ai/stable-diffusion-3.5-large", | ||||
| 	"stability-ai/stable-diffusion-3.5-large-turbo", | ||||
| 	"stability-ai/stable-diffusion-3.5-medium", | ||||
| 	// ------------------------------------- | ||||
| 	// language model | ||||
| 	// ------------------------------------- | ||||
| 	"ibm-granite/granite-20b-code-instruct-8k", | ||||
| 	"ibm-granite/granite-3.0-2b-instruct", | ||||
| 	"ibm-granite/granite-3.0-8b-instruct", | ||||
| 	"ibm-granite/granite-8b-code-instruct-128k", | ||||
| 	"meta/llama-2-13b", | ||||
| 	"meta/llama-2-13b-chat", | ||||
| 	"meta/llama-2-70b", | ||||
| 	"meta/llama-2-70b-chat", | ||||
| 	"meta/llama-2-7b", | ||||
| 	"meta/llama-2-7b-chat", | ||||
| 	"meta/meta-llama-3.1-405b-instruct", | ||||
| 	"meta/meta-llama-3-70b", | ||||
| 	"meta/meta-llama-3-70b-instruct", | ||||
| 	"meta/meta-llama-3-8b", | ||||
| 	"meta/meta-llama-3-8b-instruct", | ||||
| 	"mistralai/mistral-7b-instruct-v0.2", | ||||
| 	"mistralai/mistral-7b-v0.1", | ||||
| 	"mistralai/mixtral-8x7b-instruct-v0.1", | ||||
| 	// ------------------------------------- | ||||
| 	// video model | ||||
| 	// ------------------------------------- | ||||
| 	// "minimax/video-01",  // TODO: implement the adaptor | ||||
| } | ||||
							
								
								
									
										222
									
								
								relay/adaptor/replicate/image.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										222
									
								
								relay/adaptor/replicate/image.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,222 @@ | ||||
| package replicate | ||||
|  | ||||
| import ( | ||||
| 	"bytes" | ||||
| 	"encoding/base64" | ||||
| 	"encoding/json" | ||||
| 	"fmt" | ||||
| 	"image" | ||||
| 	"image/png" | ||||
| 	"io" | ||||
| 	"net/http" | ||||
| 	"sync" | ||||
| 	"time" | ||||
|  | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"github.com/pkg/errors" | ||||
| 	"github.com/songquanpeng/one-api/common/logger" | ||||
| 	"github.com/songquanpeng/one-api/relay/adaptor/openai" | ||||
| 	"github.com/songquanpeng/one-api/relay/meta" | ||||
| 	"github.com/songquanpeng/one-api/relay/model" | ||||
| 	"golang.org/x/image/webp" | ||||
| 	"golang.org/x/sync/errgroup" | ||||
| ) | ||||
|  | ||||
| // ImagesEditsHandler just copy response body to client | ||||
| // | ||||
| // https://replicate.com/black-forest-labs/flux-fill-pro | ||||
| // func ImagesEditsHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) { | ||||
| // 	c.Writer.WriteHeader(resp.StatusCode) | ||||
| // 	for k, v := range resp.Header { | ||||
| // 		c.Writer.Header().Set(k, v[0]) | ||||
| // 	} | ||||
|  | ||||
| // 	if _, err := io.Copy(c.Writer, resp.Body); err != nil { | ||||
| // 		return ErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil | ||||
| // 	} | ||||
| // 	defer resp.Body.Close() | ||||
|  | ||||
| // 	return nil, nil | ||||
| // } | ||||
|  | ||||
| var errNextLoop = errors.New("next_loop") | ||||
|  | ||||
| func ImageHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) { | ||||
| 	if resp.StatusCode != http.StatusCreated { | ||||
| 		payload, _ := io.ReadAll(resp.Body) | ||||
| 		return openai.ErrorWrapper( | ||||
| 				errors.Errorf("bad_status_code [%d]%s", resp.StatusCode, string(payload)), | ||||
| 				"bad_status_code", http.StatusInternalServerError), | ||||
| 			nil | ||||
| 	} | ||||
|  | ||||
| 	respBody, err := io.ReadAll(resp.Body) | ||||
| 	if err != nil { | ||||
| 		return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil | ||||
| 	} | ||||
|  | ||||
| 	respData := new(ImageResponse) | ||||
| 	if err = json.Unmarshal(respBody, respData); err != nil { | ||||
| 		return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil | ||||
| 	} | ||||
|  | ||||
| 	for { | ||||
| 		err = func() error { | ||||
| 			// get task | ||||
| 			taskReq, err := http.NewRequestWithContext(c.Request.Context(), | ||||
| 				http.MethodGet, respData.URLs.Get, nil) | ||||
| 			if err != nil { | ||||
| 				return errors.Wrap(err, "new request") | ||||
| 			} | ||||
|  | ||||
| 			taskReq.Header.Set("Authorization", "Bearer "+meta.GetByContext(c).APIKey) | ||||
| 			taskResp, err := http.DefaultClient.Do(taskReq) | ||||
| 			if err != nil { | ||||
| 				return errors.Wrap(err, "get task") | ||||
| 			} | ||||
| 			defer taskResp.Body.Close() | ||||
|  | ||||
| 			if taskResp.StatusCode != http.StatusOK { | ||||
| 				payload, _ := io.ReadAll(taskResp.Body) | ||||
| 				return errors.Errorf("bad status code [%d]%s", | ||||
| 					taskResp.StatusCode, string(payload)) | ||||
| 			} | ||||
|  | ||||
| 			taskBody, err := io.ReadAll(taskResp.Body) | ||||
| 			if err != nil { | ||||
| 				return errors.Wrap(err, "read task response") | ||||
| 			} | ||||
|  | ||||
| 			taskData := new(ImageResponse) | ||||
| 			if err = json.Unmarshal(taskBody, taskData); err != nil { | ||||
| 				return errors.Wrap(err, "decode task response") | ||||
| 			} | ||||
|  | ||||
| 			switch taskData.Status { | ||||
| 			case "succeeded": | ||||
| 			case "failed", "canceled": | ||||
| 				return errors.Errorf("task failed: %s", taskData.Status) | ||||
| 			default: | ||||
| 				time.Sleep(time.Second * 3) | ||||
| 				return errNextLoop | ||||
| 			} | ||||
|  | ||||
| 			output, err := taskData.GetOutput() | ||||
| 			if err != nil { | ||||
| 				return errors.Wrap(err, "get output") | ||||
| 			} | ||||
| 			if len(output) == 0 { | ||||
| 				return errors.New("response output is empty") | ||||
| 			} | ||||
|  | ||||
| 			var mu sync.Mutex | ||||
| 			var pool errgroup.Group | ||||
| 			respBody := &openai.ImageResponse{ | ||||
| 				Created: taskData.CompletedAt.Unix(), | ||||
| 				Data:    []openai.ImageData{}, | ||||
| 			} | ||||
|  | ||||
| 			for _, imgOut := range output { | ||||
| 				imgOut := imgOut | ||||
| 				pool.Go(func() error { | ||||
| 					// download image | ||||
| 					downloadReq, err := http.NewRequestWithContext(c.Request.Context(), | ||||
| 						http.MethodGet, imgOut, nil) | ||||
| 					if err != nil { | ||||
| 						return errors.Wrap(err, "new request") | ||||
| 					} | ||||
|  | ||||
| 					imgResp, err := http.DefaultClient.Do(downloadReq) | ||||
| 					if err != nil { | ||||
| 						return errors.Wrap(err, "download image") | ||||
| 					} | ||||
| 					defer imgResp.Body.Close() | ||||
|  | ||||
| 					if imgResp.StatusCode != http.StatusOK { | ||||
| 						payload, _ := io.ReadAll(imgResp.Body) | ||||
| 						return errors.Errorf("bad status code [%d]%s", | ||||
| 							imgResp.StatusCode, string(payload)) | ||||
| 					} | ||||
|  | ||||
| 					imgData, err := io.ReadAll(imgResp.Body) | ||||
| 					if err != nil { | ||||
| 						return errors.Wrap(err, "read image") | ||||
| 					} | ||||
|  | ||||
| 					imgData, err = ConvertImageToPNG(imgData) | ||||
| 					if err != nil { | ||||
| 						return errors.Wrap(err, "convert image") | ||||
| 					} | ||||
|  | ||||
| 					mu.Lock() | ||||
| 					respBody.Data = append(respBody.Data, openai.ImageData{ | ||||
| 						B64Json: fmt.Sprintf("data:image/png;base64,%s", | ||||
| 							base64.StdEncoding.EncodeToString(imgData)), | ||||
| 					}) | ||||
| 					mu.Unlock() | ||||
|  | ||||
| 					return nil | ||||
| 				}) | ||||
| 			} | ||||
|  | ||||
| 			if err := pool.Wait(); err != nil { | ||||
| 				if len(respBody.Data) == 0 { | ||||
| 					return errors.WithStack(err) | ||||
| 				} | ||||
|  | ||||
| 				logger.Error(c, fmt.Sprintf("some images failed to download: %+v", err)) | ||||
| 			} | ||||
|  | ||||
| 			c.JSON(http.StatusOK, respBody) | ||||
| 			return nil | ||||
| 		}() | ||||
| 		if err != nil { | ||||
| 			if errors.Is(err, errNextLoop) { | ||||
| 				continue | ||||
| 			} | ||||
|  | ||||
| 			return openai.ErrorWrapper(err, "image_task_failed", http.StatusInternalServerError), nil | ||||
| 		} | ||||
|  | ||||
| 		break | ||||
| 	} | ||||
|  | ||||
| 	return nil, nil | ||||
| } | ||||
|  | ||||
| // ConvertImageToPNG converts a WebP image to PNG format | ||||
| func ConvertImageToPNG(webpData []byte) ([]byte, error) { | ||||
| 	// bypass if it's already a PNG image | ||||
| 	if bytes.HasPrefix(webpData, []byte("\x89PNG")) { | ||||
| 		return webpData, nil | ||||
| 	} | ||||
|  | ||||
| 	// check if is jpeg, convert to png | ||||
| 	if bytes.HasPrefix(webpData, []byte("\xff\xd8\xff")) { | ||||
| 		img, _, err := image.Decode(bytes.NewReader(webpData)) | ||||
| 		if err != nil { | ||||
| 			return nil, errors.Wrap(err, "decode jpeg") | ||||
| 		} | ||||
|  | ||||
| 		var pngBuffer bytes.Buffer | ||||
| 		if err := png.Encode(&pngBuffer, img); err != nil { | ||||
| 			return nil, errors.Wrap(err, "encode png") | ||||
| 		} | ||||
|  | ||||
| 		return pngBuffer.Bytes(), nil | ||||
| 	} | ||||
|  | ||||
| 	// Decode the WebP image | ||||
| 	img, err := webp.Decode(bytes.NewReader(webpData)) | ||||
| 	if err != nil { | ||||
| 		return nil, errors.Wrap(err, "decode webp") | ||||
| 	} | ||||
|  | ||||
| 	// Encode the image as PNG | ||||
| 	var pngBuffer bytes.Buffer | ||||
| 	if err := png.Encode(&pngBuffer, img); err != nil { | ||||
| 		return nil, errors.Wrap(err, "encode png") | ||||
| 	} | ||||
|  | ||||
| 	return pngBuffer.Bytes(), nil | ||||
| } | ||||
							
								
								
									
										159
									
								
								relay/adaptor/replicate/model.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										159
									
								
								relay/adaptor/replicate/model.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,159 @@ | ||||
| package replicate | ||||
|  | ||||
| import ( | ||||
| 	"time" | ||||
|  | ||||
| 	"github.com/pkg/errors" | ||||
| ) | ||||
|  | ||||
| // DrawImageRequest draw image by fluxpro | ||||
| // | ||||
| // https://replicate.com/black-forest-labs/flux-pro?prediction=kg1krwsdf9rg80ch1sgsrgq7h8&output=json | ||||
| type DrawImageRequest struct { | ||||
| 	Input ImageInput `json:"input"` | ||||
| } | ||||
|  | ||||
| // ImageInput is input of DrawImageByFluxProRequest | ||||
| // | ||||
| // https://replicate.com/black-forest-labs/flux-1.1-pro/api/schema | ||||
| type ImageInput struct { | ||||
| 	Steps           int    `json:"steps" binding:"required,min=1"` | ||||
| 	Prompt          string `json:"prompt" binding:"required,min=5"` | ||||
| 	ImagePrompt     string `json:"image_prompt"` | ||||
| 	Guidance        int    `json:"guidance" binding:"required,min=2,max=5"` | ||||
| 	Interval        int    `json:"interval" binding:"required,min=1,max=4"` | ||||
| 	AspectRatio     string `json:"aspect_ratio" binding:"required,oneof=1:1 16:9 2:3 3:2 4:5 5:4 9:16"` | ||||
| 	SafetyTolerance int    `json:"safety_tolerance" binding:"required,min=1,max=5"` | ||||
| 	Seed            int    `json:"seed"` | ||||
| 	NImages         int    `json:"n_images" binding:"required,min=1,max=8"` | ||||
| 	Width           int    `json:"width" binding:"required,min=256,max=1440"` | ||||
| 	Height          int    `json:"height" binding:"required,min=256,max=1440"` | ||||
| } | ||||
|  | ||||
| // InpaintingImageByFlusReplicateRequest is request to inpainting image by flux pro | ||||
| // | ||||
| // https://replicate.com/black-forest-labs/flux-fill-pro/api/schema | ||||
| type InpaintingImageByFlusReplicateRequest struct { | ||||
| 	Input FluxInpaintingInput `json:"input"` | ||||
| } | ||||
|  | ||||
| // FluxInpaintingInput is input of DrawImageByFluxProRequest | ||||
| // | ||||
| // https://replicate.com/black-forest-labs/flux-fill-pro/api/schema | ||||
| type FluxInpaintingInput struct { | ||||
| 	Mask             string `json:"mask" binding:"required"` | ||||
| 	Image            string `json:"image" binding:"required"` | ||||
| 	Seed             int    `json:"seed"` | ||||
| 	Steps            int    `json:"steps" binding:"required,min=1"` | ||||
| 	Prompt           string `json:"prompt" binding:"required,min=5"` | ||||
| 	Guidance         int    `json:"guidance" binding:"required,min=2,max=5"` | ||||
| 	OutputFormat     string `json:"output_format"` | ||||
| 	SafetyTolerance  int    `json:"safety_tolerance" binding:"required,min=1,max=5"` | ||||
| 	PromptUnsampling bool   `json:"prompt_unsampling"` | ||||
| } | ||||
|  | ||||
| // ImageResponse is response of DrawImageByFluxProRequest | ||||
| // | ||||
| // https://replicate.com/black-forest-labs/flux-pro?prediction=kg1krwsdf9rg80ch1sgsrgq7h8&output=json | ||||
| type ImageResponse struct { | ||||
| 	CompletedAt time.Time        `json:"completed_at"` | ||||
| 	CreatedAt   time.Time        `json:"created_at"` | ||||
| 	DataRemoved bool             `json:"data_removed"` | ||||
| 	Error       string           `json:"error"` | ||||
| 	ID          string           `json:"id"` | ||||
| 	Input       DrawImageRequest `json:"input"` | ||||
| 	Logs        string           `json:"logs"` | ||||
| 	Metrics     FluxMetrics      `json:"metrics"` | ||||
| 	// Output could be `string` or `[]string` | ||||
| 	Output    any       `json:"output"` | ||||
| 	StartedAt time.Time `json:"started_at"` | ||||
| 	Status    string    `json:"status"` | ||||
| 	URLs      FluxURLs  `json:"urls"` | ||||
| 	Version   string    `json:"version"` | ||||
| } | ||||
|  | ||||
| func (r *ImageResponse) GetOutput() ([]string, error) { | ||||
| 	switch v := r.Output.(type) { | ||||
| 	case string: | ||||
| 		return []string{v}, nil | ||||
| 	case []string: | ||||
| 		return v, nil | ||||
| 	case nil: | ||||
| 		return nil, nil | ||||
| 	case []interface{}: | ||||
| 		// convert []interface{} to []string | ||||
| 		ret := make([]string, len(v)) | ||||
| 		for idx, vv := range v { | ||||
| 			if vvv, ok := vv.(string); ok { | ||||
| 				ret[idx] = vvv | ||||
| 			} else { | ||||
| 				return nil, errors.Errorf("unknown output type: [%T]%v", vv, vv) | ||||
| 			} | ||||
| 		} | ||||
|  | ||||
| 		return ret, nil | ||||
| 	default: | ||||
| 		return nil, errors.Errorf("unknown output type: [%T]%v", r.Output, r.Output) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // FluxMetrics is metrics of ImageResponse | ||||
| type FluxMetrics struct { | ||||
| 	ImageCount  int     `json:"image_count"` | ||||
| 	PredictTime float64 `json:"predict_time"` | ||||
| 	TotalTime   float64 `json:"total_time"` | ||||
| } | ||||
|  | ||||
| // FluxURLs is urls of ImageResponse | ||||
| type FluxURLs struct { | ||||
| 	Get    string `json:"get"` | ||||
| 	Cancel string `json:"cancel"` | ||||
| } | ||||
|  | ||||
| type ReplicateChatRequest struct { | ||||
| 	Input ChatInput `json:"input" form:"input" binding:"required"` | ||||
| } | ||||
|  | ||||
| // ChatInput is input of ChatByReplicateRequest | ||||
| // | ||||
| // https://replicate.com/meta/meta-llama-3.1-405b-instruct/api/schema | ||||
| type ChatInput struct { | ||||
| 	TopK             int     `json:"top_k"` | ||||
| 	TopP             float64 `json:"top_p"` | ||||
| 	Prompt           string  `json:"prompt"` | ||||
| 	MaxTokens        int     `json:"max_tokens"` | ||||
| 	MinTokens        int     `json:"min_tokens"` | ||||
| 	Temperature      float64 `json:"temperature"` | ||||
| 	SystemPrompt     string  `json:"system_prompt"` | ||||
| 	StopSequences    string  `json:"stop_sequences"` | ||||
| 	PromptTemplate   string  `json:"prompt_template"` | ||||
| 	PresencePenalty  float64 `json:"presence_penalty"` | ||||
| 	FrequencyPenalty float64 `json:"frequency_penalty"` | ||||
| } | ||||
|  | ||||
| // ChatResponse is response of ChatByReplicateRequest | ||||
| // | ||||
| // https://replicate.com/meta/meta-llama-3.1-405b-instruct/examples?input=http&output=json | ||||
| type ChatResponse struct { | ||||
| 	CompletedAt time.Time   `json:"completed_at"` | ||||
| 	CreatedAt   time.Time   `json:"created_at"` | ||||
| 	DataRemoved bool        `json:"data_removed"` | ||||
| 	Error       string      `json:"error"` | ||||
| 	ID          string      `json:"id"` | ||||
| 	Input       ChatInput   `json:"input"` | ||||
| 	Logs        string      `json:"logs"` | ||||
| 	Metrics     FluxMetrics `json:"metrics"` | ||||
| 	// Output could be `string` or `[]string` | ||||
| 	Output    []string        `json:"output"` | ||||
| 	StartedAt time.Time       `json:"started_at"` | ||||
| 	Status    string          `json:"status"` | ||||
| 	URLs      ChatResponseUrl `json:"urls"` | ||||
| 	Version   string          `json:"version"` | ||||
| } | ||||
|  | ||||
| // ChatResponseUrl is task urls of ChatResponse | ||||
| type ChatResponseUrl struct { | ||||
| 	Stream string `json:"stream"` | ||||
| 	Get    string `json:"get"` | ||||
| 	Cancel string `json:"cancel"` | ||||
| } | ||||
							
								
								
									
										36
									
								
								relay/adaptor/siliconflow/constants.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										36
									
								
								relay/adaptor/siliconflow/constants.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,36 @@ | ||||
| package siliconflow | ||||
|  | ||||
| // https://docs.siliconflow.cn/docs/getting-started | ||||
|  | ||||
| var ModelList = []string{ | ||||
| 	"deepseek-ai/deepseek-llm-67b-chat", | ||||
| 	"Qwen/Qwen1.5-14B-Chat", | ||||
| 	"Qwen/Qwen1.5-7B-Chat", | ||||
| 	"Qwen/Qwen1.5-110B-Chat", | ||||
| 	"Qwen/Qwen1.5-32B-Chat", | ||||
| 	"01-ai/Yi-1.5-6B-Chat", | ||||
| 	"01-ai/Yi-1.5-9B-Chat-16K", | ||||
| 	"01-ai/Yi-1.5-34B-Chat-16K", | ||||
| 	"THUDM/chatglm3-6b", | ||||
| 	"deepseek-ai/DeepSeek-V2-Chat", | ||||
| 	"THUDM/glm-4-9b-chat", | ||||
| 	"Qwen/Qwen2-72B-Instruct", | ||||
| 	"Qwen/Qwen2-7B-Instruct", | ||||
| 	"Qwen/Qwen2-57B-A14B-Instruct", | ||||
| 	"deepseek-ai/DeepSeek-Coder-V2-Instruct", | ||||
| 	"Qwen/Qwen2-1.5B-Instruct", | ||||
| 	"internlm/internlm2_5-7b-chat", | ||||
| 	"BAAI/bge-large-en-v1.5", | ||||
| 	"BAAI/bge-large-zh-v1.5", | ||||
| 	"Pro/Qwen/Qwen2-7B-Instruct", | ||||
| 	"Pro/Qwen/Qwen2-1.5B-Instruct", | ||||
| 	"Pro/Qwen/Qwen1.5-7B-Chat", | ||||
| 	"Pro/THUDM/glm-4-9b-chat", | ||||
| 	"Pro/THUDM/chatglm3-6b", | ||||
| 	"Pro/01-ai/Yi-1.5-9B-Chat-16K", | ||||
| 	"Pro/01-ai/Yi-1.5-6B-Chat", | ||||
| 	"Pro/google/gemma-2-9b-it", | ||||
| 	"Pro/internlm/internlm2_5-7b-chat", | ||||
| 	"Pro/meta-llama/Meta-Llama-3-8B-Instruct", | ||||
| 	"Pro/mistralai/Mistral-7B-Instruct-v0.2", | ||||
| } | ||||
| @@ -1,7 +1,13 @@ | ||||
| package stepfun | ||||
|  | ||||
| var ModelList = []string{ | ||||
| 	"step-1-8k", | ||||
| 	"step-1-32k", | ||||
| 	"step-1-128k", | ||||
| 	"step-1-256k", | ||||
| 	"step-1-flash", | ||||
| 	"step-2-16k", | ||||
| 	"step-1v-8k", | ||||
| 	"step-1v-32k", | ||||
| 	"step-1-200k", | ||||
| 	"step-1x-medium", | ||||
| } | ||||
|   | ||||
| @@ -5,4 +5,5 @@ var ModelList = []string{ | ||||
| 	"hunyuan-standard", | ||||
| 	"hunyuan-standard-256K", | ||||
| 	"hunyuan-pro", | ||||
| 	"hunyuan-vision", | ||||
| } | ||||
|   | ||||
| @@ -39,8 +39,8 @@ func ConvertRequest(request model.GeneralOpenAIRequest) *ChatRequest { | ||||
| 		Model:       &request.Model, | ||||
| 		Stream:      &request.Stream, | ||||
| 		Messages:    messages, | ||||
| 		TopP:        &request.TopP, | ||||
| 		Temperature: &request.Temperature, | ||||
| 		TopP:        request.TopP, | ||||
| 		Temperature: request.Temperature, | ||||
| 	} | ||||
| } | ||||
|  | ||||
|   | ||||
| @@ -13,7 +13,12 @@ import ( | ||||
| ) | ||||
|  | ||||
| var ModelList = []string{ | ||||
| 	"claude-3-haiku@20240307", "claude-3-opus@20240229", "claude-3-5-sonnet@20240620", "claude-3-sonnet@20240229", | ||||
| 	"claude-3-haiku@20240307", | ||||
| 	"claude-3-sonnet@20240229", | ||||
| 	"claude-3-opus@20240229", | ||||
| 	"claude-3-5-sonnet@20240620", | ||||
| 	"claude-3-5-sonnet-v2@20241022", | ||||
| 	"claude-3-5-haiku@20241022", | ||||
| } | ||||
|  | ||||
| const anthropicVersion = "vertex-2023-10-16" | ||||
|   | ||||
| @@ -11,8 +11,8 @@ type Request struct { | ||||
| 	MaxTokens     int                 `json:"max_tokens,omitempty"` | ||||
| 	StopSequences []string            `json:"stop_sequences,omitempty"` | ||||
| 	Stream        bool                `json:"stream,omitempty"` | ||||
| 	Temperature   float64             `json:"temperature,omitempty"` | ||||
| 	TopP          float64             `json:"top_p,omitempty"` | ||||
| 	Temperature   *float64            `json:"temperature,omitempty"` | ||||
| 	TopP          *float64            `json:"top_p,omitempty"` | ||||
| 	TopK          int                 `json:"top_k,omitempty"` | ||||
| 	Tools         []anthropic.Tool    `json:"tools,omitempty"` | ||||
| 	ToolChoice    any                 `json:"tool_choice,omitempty"` | ||||
|   | ||||
| @@ -15,7 +15,10 @@ import ( | ||||
| ) | ||||
|  | ||||
| var ModelList = []string{ | ||||
| 	"gemini-1.5-pro-001", "gemini-1.5-flash-001", "gemini-pro", "gemini-pro-vision", | ||||
| 	"gemini-pro", "gemini-pro-vision", | ||||
| 	"gemini-1.5-pro-001", "gemini-1.5-flash-001", | ||||
| 	"gemini-1.5-pro-002", "gemini-1.5-flash-002", | ||||
| 	"gemini-2.0-flash-exp", "gemini-2.0-flash-thinking-exp", | ||||
| } | ||||
|  | ||||
| type Adaptor struct { | ||||
|   | ||||
							
								
								
									
										5
									
								
								relay/adaptor/xai/constants.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										5
									
								
								relay/adaptor/xai/constants.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,5 @@ | ||||
| package xai | ||||
|  | ||||
| var ModelList = []string{ | ||||
| 	"grok-beta", | ||||
| } | ||||
| @@ -5,6 +5,8 @@ var ModelList = []string{ | ||||
| 	"SparkDesk-v1.1", | ||||
| 	"SparkDesk-v2.1", | ||||
| 	"SparkDesk-v3.1", | ||||
| 	"SparkDesk-v3.1-128K", | ||||
| 	"SparkDesk-v3.5", | ||||
| 	"SparkDesk-v3.5-32K", | ||||
| 	"SparkDesk-v4.0", | ||||
| } | ||||
|   | ||||
| @@ -272,9 +272,9 @@ func xunfeiMakeRequest(textRequest model.GeneralOpenAIRequest, domain, authUrl, | ||||
| } | ||||
|  | ||||
| func parseAPIVersionByModelName(modelName string) string { | ||||
| 	parts := strings.Split(modelName, "-") | ||||
| 	if len(parts) == 2 { | ||||
| 		return parts[1] | ||||
| 	index := strings.IndexAny(modelName, "-") | ||||
| 	if index != -1 { | ||||
| 		return modelName[index+1:] | ||||
| 	} | ||||
| 	return "" | ||||
| } | ||||
| @@ -283,13 +283,17 @@ func parseAPIVersionByModelName(modelName string) string { | ||||
| func apiVersion2domain(apiVersion string) string { | ||||
| 	switch apiVersion { | ||||
| 	case "v1.1": | ||||
| 		return "general" | ||||
| 		return "lite" | ||||
| 	case "v2.1": | ||||
| 		return "generalv2" | ||||
| 	case "v3.1": | ||||
| 		return "generalv3" | ||||
| 	case "v3.1-128K": | ||||
| 		return "pro-128k" | ||||
| 	case "v3.5": | ||||
| 		return "generalv3.5" | ||||
| 	case "v3.5-32K": | ||||
| 		return "max-32k" | ||||
| 	case "v4.0": | ||||
| 		return "4.0Ultra" | ||||
| 	} | ||||
| @@ -297,7 +301,17 @@ func apiVersion2domain(apiVersion string) string { | ||||
| } | ||||
|  | ||||
| func getXunfeiAuthUrl(apiVersion string, apiKey string, apiSecret string) (string, string) { | ||||
| 	var authUrl string | ||||
| 	domain := apiVersion2domain(apiVersion) | ||||
| 	authUrl := buildXunfeiAuthUrl(fmt.Sprintf("wss://spark-api.xf-yun.com/%s/chat", apiVersion), apiKey, apiSecret) | ||||
| 	switch apiVersion { | ||||
| 	case "v3.1-128K": | ||||
| 		authUrl = buildXunfeiAuthUrl(fmt.Sprintf("wss://spark-api.xf-yun.com/chat/pro-128k"), apiKey, apiSecret) | ||||
| 		break | ||||
| 	case "v3.5-32K": | ||||
| 		authUrl = buildXunfeiAuthUrl(fmt.Sprintf("wss://spark-api.xf-yun.com/chat/max-32k"), apiKey, apiSecret) | ||||
| 		break | ||||
| 	default: | ||||
| 		authUrl = buildXunfeiAuthUrl(fmt.Sprintf("wss://spark-api.xf-yun.com/%s/chat", apiVersion), apiKey, apiSecret) | ||||
| 	} | ||||
| 	return domain, authUrl | ||||
| } | ||||
|   | ||||
| @@ -20,7 +20,7 @@ type ChatRequest struct { | ||||
| 	Parameter struct { | ||||
| 		Chat struct { | ||||
| 			Domain      string   `json:"domain,omitempty"` | ||||
| 			Temperature float64 `json:"temperature,omitempty"` | ||||
| 			Temperature *float64 `json:"temperature,omitempty"` | ||||
| 			TopK        int      `json:"top_k,omitempty"` | ||||
| 			MaxTokens   int      `json:"max_tokens,omitempty"` | ||||
| 			Auditing    bool     `json:"auditing,omitempty"` | ||||
|   | ||||
| @@ -4,13 +4,13 @@ import ( | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"github.com/songquanpeng/one-api/common/helper" | ||||
| 	"github.com/songquanpeng/one-api/relay/adaptor" | ||||
| 	"github.com/songquanpeng/one-api/relay/adaptor/openai" | ||||
| 	"github.com/songquanpeng/one-api/relay/meta" | ||||
| 	"github.com/songquanpeng/one-api/relay/model" | ||||
| 	"github.com/songquanpeng/one-api/relay/relaymode" | ||||
| 	"io" | ||||
| 	"math" | ||||
| 	"net/http" | ||||
| 	"strings" | ||||
| ) | ||||
| @@ -65,13 +65,13 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G | ||||
| 		baiduEmbeddingRequest, err := ConvertEmbeddingRequest(*request) | ||||
| 		return baiduEmbeddingRequest, err | ||||
| 	default: | ||||
| 		// TopP (0.0, 1.0) | ||||
| 		request.TopP = math.Min(0.99, request.TopP) | ||||
| 		request.TopP = math.Max(0.01, request.TopP) | ||||
| 		// TopP [0.0, 1.0] | ||||
| 		request.TopP = helper.Float64PtrMax(request.TopP, 1) | ||||
| 		request.TopP = helper.Float64PtrMin(request.TopP, 0) | ||||
|  | ||||
| 		// Temperature (0.0, 1.0) | ||||
| 		request.Temperature = math.Min(0.99, request.Temperature) | ||||
| 		request.Temperature = math.Max(0.01, request.Temperature) | ||||
| 		// Temperature [0.0, 1.0] | ||||
| 		request.Temperature = helper.Float64PtrMax(request.Temperature, 1) | ||||
| 		request.Temperature = helper.Float64PtrMin(request.Temperature, 0) | ||||
| 		a.SetVersionByModeName(request.Model) | ||||
| 		if a.APIVersion == "v4" { | ||||
| 			return request, nil | ||||
|   | ||||
| @@ -12,8 +12,8 @@ type Message struct { | ||||
|  | ||||
| type Request struct { | ||||
| 	Prompt      []Message `json:"prompt"` | ||||
| 	Temperature float64   `json:"temperature,omitempty"` | ||||
| 	TopP        float64   `json:"top_p,omitempty"` | ||||
| 	Temperature *float64  `json:"temperature,omitempty"` | ||||
| 	TopP        *float64  `json:"top_p,omitempty"` | ||||
| 	RequestId   string    `json:"request_id,omitempty"` | ||||
| 	Incremental bool      `json:"incremental,omitempty"` | ||||
| } | ||||
|   | ||||
| @@ -19,6 +19,7 @@ const ( | ||||
| 	DeepL | ||||
| 	VertexAI | ||||
| 	Proxy | ||||
| 	Replicate | ||||
|  | ||||
| 	Dummy // this one is only for count, do not add any channel after this | ||||
| ) | ||||
|   | ||||
| @@ -30,6 +30,14 @@ var ImageSizeRatios = map[string]map[string]float64{ | ||||
| 		"720x1280":  1, | ||||
| 		"1280x720":  1, | ||||
| 	}, | ||||
| 	"step-1x-medium": { | ||||
| 		"256x256":   1, | ||||
| 		"512x512":   1, | ||||
| 		"768x768":   1, | ||||
| 		"1024x1024": 1, | ||||
| 		"1280x800":  1, | ||||
| 		"800x1280":  1, | ||||
| 	}, | ||||
| } | ||||
|  | ||||
| var ImageGenerationAmounts = map[string][2]int{ | ||||
| @@ -39,6 +47,7 @@ var ImageGenerationAmounts = map[string][2]int{ | ||||
| 	"ali-stable-diffusion-v1.5": {1, 4}, // Ali | ||||
| 	"wanx-v1":                   {1, 4}, // Ali | ||||
| 	"cogview-3":                 {1, 1}, | ||||
| 	"step-1x-medium":            {1, 1}, | ||||
| } | ||||
|  | ||||
| var ImagePromptLengthLimitations = map[string]int{ | ||||
| @@ -48,6 +57,7 @@ var ImagePromptLengthLimitations = map[string]int{ | ||||
| 	"ali-stable-diffusion-v1.5": 4000, | ||||
| 	"wanx-v1":                   4000, | ||||
| 	"cogview-3":                 833, | ||||
| 	"step-1x-medium":            4000, | ||||
| } | ||||
|  | ||||
| var ImageOriginModelName = map[string]string{ | ||||
|   | ||||
| @@ -34,7 +34,10 @@ var ModelRatio = map[string]float64{ | ||||
| 	"gpt-4-turbo":             5,     // $0.01 / 1K tokens | ||||
| 	"gpt-4-turbo-2024-04-09":  5,     // $0.01 / 1K tokens | ||||
| 	"gpt-4o":                  2.5,   // $0.005 / 1K tokens | ||||
| 	"chatgpt-4o-latest":       2.5,   // $0.005 / 1K tokens | ||||
| 	"gpt-4o-2024-05-13":       2.5,   // $0.005 / 1K tokens | ||||
| 	"gpt-4o-2024-08-06":       1.25,  // $0.0025 / 1K tokens | ||||
| 	"gpt-4o-2024-11-20":       1.25,  // $0.0025 / 1K tokens | ||||
| 	"gpt-4o-mini":             0.075, // $0.00015 / 1K tokens | ||||
| 	"gpt-4o-mini-2024-07-18":  0.075, // $0.00015 / 1K tokens | ||||
| 	"gpt-4-vision-preview":    5,     // $0.01 / 1K tokens | ||||
| @@ -46,6 +49,12 @@ var ModelRatio = map[string]float64{ | ||||
| 	"gpt-3.5-turbo-instruct":  0.75, // $0.0015 / 1K tokens | ||||
| 	"gpt-3.5-turbo-1106":      0.5,  // $0.001 / 1K tokens | ||||
| 	"gpt-3.5-turbo-0125":      0.25, // $0.0005 / 1K tokens | ||||
| 	"o1":                      7.5,  // $15.00 / 1M input tokens | ||||
| 	"o1-2024-12-17":           7.5, | ||||
| 	"o1-preview":              7.5, // $15.00 / 1M input tokens | ||||
| 	"o1-preview-2024-09-12":   7.5, | ||||
| 	"o1-mini":                 1.5, // $3.00 / 1M input tokens | ||||
| 	"o1-mini-2024-09-12":      1.5, | ||||
| 	"davinci-002":             1,   // $0.002 / 1K tokens | ||||
| 	"babbage-002":             0.2, // $0.0004 / 1K tokens | ||||
| 	"text-ada-001":            0.2, | ||||
| @@ -77,8 +86,10 @@ var ModelRatio = map[string]float64{ | ||||
| 	"claude-2.0":                 8.0 / 1000 * USD, | ||||
| 	"claude-2.1":                 8.0 / 1000 * USD, | ||||
| 	"claude-3-haiku-20240307":    0.25 / 1000 * USD, | ||||
| 	"claude-3-5-haiku-20241022":  1.0 / 1000 * USD, | ||||
| 	"claude-3-sonnet-20240229":   3.0 / 1000 * USD, | ||||
| 	"claude-3-5-sonnet-20240620": 3.0 / 1000 * USD, | ||||
| 	"claude-3-5-sonnet-20241022": 3.0 / 1000 * USD, | ||||
| 	"claude-3-opus-20240229":     15.0 / 1000 * USD, | ||||
| 	// https://cloud.baidu.com/doc/WENXINWORKSHOP/s/hlrk4akp7 | ||||
| 	"ERNIE-4.0-8K":       0.120 * RMB, | ||||
| @@ -98,12 +109,15 @@ var ModelRatio = map[string]float64{ | ||||
| 	"bge-large-en":       0.002 * RMB, | ||||
| 	"tao-8k":             0.002 * RMB, | ||||
| 	// https://ai.google.dev/pricing | ||||
| 	"PaLM-2":                    1, | ||||
| 	"gemini-pro":                    1, // $0.00025 / 1k characters -> $0.001 / 1k tokens | ||||
| 	"gemini-pro-vision":         1, // $0.00025 / 1k characters -> $0.001 / 1k tokens | ||||
| 	"gemini-1.0-pro-vision-001": 1, | ||||
| 	"gemini-1.0-pro-001":        1, | ||||
| 	"gemini-1.0-pro":                1, | ||||
| 	"gemini-1.5-pro":                1, | ||||
| 	"gemini-1.5-pro-001":            1, | ||||
| 	"gemini-1.5-flash":              1, | ||||
| 	"gemini-1.5-flash-001":          1, | ||||
| 	"gemini-2.0-flash-exp":          1, | ||||
| 	"gemini-2.0-flash-thinking-exp": 1, | ||||
| 	"aqa":                           1, | ||||
| 	// https://open.bigmodel.cn/pricing | ||||
| 	"glm-4":         0.1 * RMB, | ||||
| 	"glm-4v":        0.1 * RMB, | ||||
| @@ -115,19 +129,86 @@ var ModelRatio = map[string]float64{ | ||||
| 	"chatglm_lite":  0.1429, // ¥0.002 / 1k tokens | ||||
| 	"cogview-3":     0.25 * RMB, | ||||
| 	// https://help.aliyun.com/zh/dashscope/developer-reference/tongyi-thousand-questions-metering-and-billing | ||||
| 	"qwen-turbo":                0.5715, // ¥0.008 / 1k tokens | ||||
| 	"qwen-plus":                 1.4286, // ¥0.02 / 1k tokens | ||||
| 	"qwen-max":                  1.4286, // ¥0.02 / 1k tokens | ||||
| 	"qwen-max-longcontext":      1.4286, // ¥0.02 / 1k tokens | ||||
| 	"qwen-turbo":                  1.4286, // ¥0.02 / 1k tokens | ||||
| 	"qwen-turbo-latest":           1.4286, | ||||
| 	"qwen-plus":                   1.4286, | ||||
| 	"qwen-plus-latest":            1.4286, | ||||
| 	"qwen-max":                    1.4286, | ||||
| 	"qwen-max-latest":             1.4286, | ||||
| 	"qwen-max-longcontext":        1.4286, | ||||
| 	"qwen-vl-max":                 1.4286, | ||||
| 	"qwen-vl-max-latest":          1.4286, | ||||
| 	"qwen-vl-plus":                1.4286, | ||||
| 	"qwen-vl-plus-latest":         1.4286, | ||||
| 	"qwen-vl-ocr":                 1.4286, | ||||
| 	"qwen-vl-ocr-latest":          1.4286, | ||||
| 	"qwen-audio-turbo":            1.4286, | ||||
| 	"qwen-math-plus":              1.4286, | ||||
| 	"qwen-math-plus-latest":       1.4286, | ||||
| 	"qwen-math-turbo":             1.4286, | ||||
| 	"qwen-math-turbo-latest":      1.4286, | ||||
| 	"qwen-coder-plus":             1.4286, | ||||
| 	"qwen-coder-plus-latest":      1.4286, | ||||
| 	"qwen-coder-turbo":            1.4286, | ||||
| 	"qwen-coder-turbo-latest":     1.4286, | ||||
| 	"qwq-32b-preview":             1.4286, | ||||
| 	"qwen2.5-72b-instruct":        1.4286, | ||||
| 	"qwen2.5-32b-instruct":        1.4286, | ||||
| 	"qwen2.5-14b-instruct":        1.4286, | ||||
| 	"qwen2.5-7b-instruct":         1.4286, | ||||
| 	"qwen2.5-3b-instruct":         1.4286, | ||||
| 	"qwen2.5-1.5b-instruct":       1.4286, | ||||
| 	"qwen2.5-0.5b-instruct":       1.4286, | ||||
| 	"qwen2-72b-instruct":          1.4286, | ||||
| 	"qwen2-57b-a14b-instruct":     1.4286, | ||||
| 	"qwen2-7b-instruct":           1.4286, | ||||
| 	"qwen2-1.5b-instruct":         1.4286, | ||||
| 	"qwen2-0.5b-instruct":         1.4286, | ||||
| 	"qwen1.5-110b-chat":           1.4286, | ||||
| 	"qwen1.5-72b-chat":            1.4286, | ||||
| 	"qwen1.5-32b-chat":            1.4286, | ||||
| 	"qwen1.5-14b-chat":            1.4286, | ||||
| 	"qwen1.5-7b-chat":             1.4286, | ||||
| 	"qwen1.5-1.8b-chat":           1.4286, | ||||
| 	"qwen1.5-0.5b-chat":           1.4286, | ||||
| 	"qwen-72b-chat":               1.4286, | ||||
| 	"qwen-14b-chat":               1.4286, | ||||
| 	"qwen-7b-chat":                1.4286, | ||||
| 	"qwen-1.8b-chat":              1.4286, | ||||
| 	"qwen-1.8b-longcontext-chat":  1.4286, | ||||
| 	"qwen2-vl-7b-instruct":        1.4286, | ||||
| 	"qwen2-vl-2b-instruct":        1.4286, | ||||
| 	"qwen-vl-v1":                  1.4286, | ||||
| 	"qwen-vl-chat-v1":             1.4286, | ||||
| 	"qwen2-audio-instruct":        1.4286, | ||||
| 	"qwen-audio-chat":             1.4286, | ||||
| 	"qwen2.5-math-72b-instruct":   1.4286, | ||||
| 	"qwen2.5-math-7b-instruct":    1.4286, | ||||
| 	"qwen2.5-math-1.5b-instruct":  1.4286, | ||||
| 	"qwen2-math-72b-instruct":     1.4286, | ||||
| 	"qwen2-math-7b-instruct":      1.4286, | ||||
| 	"qwen2-math-1.5b-instruct":    1.4286, | ||||
| 	"qwen2.5-coder-32b-instruct":  1.4286, | ||||
| 	"qwen2.5-coder-14b-instruct":  1.4286, | ||||
| 	"qwen2.5-coder-7b-instruct":   1.4286, | ||||
| 	"qwen2.5-coder-3b-instruct":   1.4286, | ||||
| 	"qwen2.5-coder-1.5b-instruct": 1.4286, | ||||
| 	"qwen2.5-coder-0.5b-instruct": 1.4286, | ||||
| 	"text-embedding-v1":           0.05, // ¥0.0007 / 1k tokens | ||||
| 	"ali-stable-diffusion-xl":   8, | ||||
| 	"ali-stable-diffusion-v1.5": 8, | ||||
| 	"wanx-v1":                   8, | ||||
| 	"text-embedding-v3":           0.05, | ||||
| 	"text-embedding-v2":           0.05, | ||||
| 	"text-embedding-async-v2":     0.05, | ||||
| 	"text-embedding-async-v1":     0.05, | ||||
| 	"ali-stable-diffusion-xl":     8.00, | ||||
| 	"ali-stable-diffusion-v1.5":   8.00, | ||||
| 	"wanx-v1":                     8.00, | ||||
| 	"SparkDesk":                   1.2858, // ¥0.018 / 1k tokens | ||||
| 	"SparkDesk-v1.1":              1.2858, // ¥0.018 / 1k tokens | ||||
| 	"SparkDesk-v2.1":              1.2858, // ¥0.018 / 1k tokens | ||||
| 	"SparkDesk-v3.1":              1.2858, // ¥0.018 / 1k tokens | ||||
| 	"SparkDesk-v3.1-128K":         1.2858, // ¥0.018 / 1k tokens | ||||
| 	"SparkDesk-v3.5":              1.2858, // ¥0.018 / 1k tokens | ||||
| 	"SparkDesk-v3.5-32K":          1.2858, // ¥0.018 / 1k tokens | ||||
| 	"SparkDesk-v4.0":              1.2858, // ¥0.018 / 1k tokens | ||||
| 	"360GPT_S2_V9":                0.8572, // ¥0.012 / 1k tokens | ||||
| 	"embedding-bert-512-v1":       0.0715, // ¥0.001 / 1k tokens | ||||
| @@ -158,20 +239,35 @@ var ModelRatio = map[string]float64{ | ||||
| 	"mistral-large-latest":  8.0 / 1000 * USD, | ||||
| 	"mistral-embed":         0.1 / 1000 * USD, | ||||
| 	// https://wow.groq.com/#:~:text=inquiries%C2%A0here.-,Model,-Current%20Speed | ||||
| 	"llama3-70b-8192":    0.59 / 1000 * USD, | ||||
| 	"mixtral-8x7b-32768": 0.27 / 1000 * USD, | ||||
| 	"llama3-8b-8192":     0.05 / 1000 * USD, | ||||
| 	"gemma-7b-it":        0.1 / 1000 * USD, | ||||
| 	"llama2-70b-4096":    0.64 / 1000 * USD, | ||||
| 	"llama2-7b-2048":     0.1 / 1000 * USD, | ||||
| 	"gemma-7b-it":                           0.07 / 1000000 * USD, | ||||
| 	"gemma2-9b-it":                          0.20 / 1000000 * USD, | ||||
| 	"llama-3.1-70b-versatile":               0.59 / 1000000 * USD, | ||||
| 	"llama-3.1-8b-instant":                  0.05 / 1000000 * USD, | ||||
| 	"llama-3.2-11b-text-preview":            0.05 / 1000000 * USD, | ||||
| 	"llama-3.2-11b-vision-preview":          0.05 / 1000000 * USD, | ||||
| 	"llama-3.2-1b-preview":                  0.05 / 1000000 * USD, | ||||
| 	"llama-3.2-3b-preview":                  0.05 / 1000000 * USD, | ||||
| 	"llama-3.2-90b-text-preview":            0.59 / 1000000 * USD, | ||||
| 	"llama-guard-3-8b":                      0.05 / 1000000 * USD, | ||||
| 	"llama3-70b-8192":                       0.59 / 1000000 * USD, | ||||
| 	"llama3-8b-8192":                        0.05 / 1000000 * USD, | ||||
| 	"llama3-groq-70b-8192-tool-use-preview": 0.89 / 1000000 * USD, | ||||
| 	"llama3-groq-8b-8192-tool-use-preview":  0.19 / 1000000 * USD, | ||||
| 	"mixtral-8x7b-32768":                    0.24 / 1000000 * USD, | ||||
|  | ||||
| 	// https://platform.lingyiwanwu.com/docs#-计费单元 | ||||
| 	"yi-34b-chat-0205": 2.5 / 1000 * RMB, | ||||
| 	"yi-34b-chat-200k": 12.0 / 1000 * RMB, | ||||
| 	"yi-vl-plus":       6.0 / 1000 * RMB, | ||||
| 	// stepfun todo | ||||
| 	"step-1v-32k": 0.024 * RMB, | ||||
| 	"step-1-32k":  0.024 * RMB, | ||||
| 	"step-1-200k": 0.15 * RMB, | ||||
| 	// https://platform.stepfun.com/docs/pricing/details | ||||
| 	"step-1-8k":    0.005 / 1000 * RMB, | ||||
| 	"step-1-32k":   0.015 / 1000 * RMB, | ||||
| 	"step-1-128k":  0.040 / 1000 * RMB, | ||||
| 	"step-1-256k":  0.095 / 1000 * RMB, | ||||
| 	"step-1-flash": 0.001 / 1000 * RMB, | ||||
| 	"step-2-16k":   0.038 / 1000 * RMB, | ||||
| 	"step-1v-8k":   0.005 / 1000 * RMB, | ||||
| 	"step-1v-32k":  0.015 / 1000 * RMB, | ||||
| 	// aws llama3 https://aws.amazon.com/cn/bedrock/pricing/ | ||||
| 	"llama3-8b-8192(33)":  0.0003 / 0.002,  // $0.0003 / 1K tokens | ||||
| 	"llama3-70b-8192(33)": 0.00265 / 0.002, // $0.00265 / 1K tokens | ||||
| @@ -189,6 +285,52 @@ var ModelRatio = map[string]float64{ | ||||
| 	"deepl-zh": 25.0 / 1000 * USD, | ||||
| 	"deepl-en": 25.0 / 1000 * USD, | ||||
| 	"deepl-ja": 25.0 / 1000 * USD, | ||||
| 	// https://console.x.ai/ | ||||
| 	"grok-beta": 5.0 / 1000 * USD, | ||||
| 	// replicate charges based on the number of generated images | ||||
| 	// https://replicate.com/pricing | ||||
| 	"black-forest-labs/flux-1.1-pro":                0.04 * USD, | ||||
| 	"black-forest-labs/flux-1.1-pro-ultra":          0.06 * USD, | ||||
| 	"black-forest-labs/flux-canny-dev":              0.025 * USD, | ||||
| 	"black-forest-labs/flux-canny-pro":              0.05 * USD, | ||||
| 	"black-forest-labs/flux-depth-dev":              0.025 * USD, | ||||
| 	"black-forest-labs/flux-depth-pro":              0.05 * USD, | ||||
| 	"black-forest-labs/flux-dev":                    0.025 * USD, | ||||
| 	"black-forest-labs/flux-dev-lora":               0.032 * USD, | ||||
| 	"black-forest-labs/flux-fill-dev":               0.04 * USD, | ||||
| 	"black-forest-labs/flux-fill-pro":               0.05 * USD, | ||||
| 	"black-forest-labs/flux-pro":                    0.055 * USD, | ||||
| 	"black-forest-labs/flux-redux-dev":              0.025 * USD, | ||||
| 	"black-forest-labs/flux-redux-schnell":          0.003 * USD, | ||||
| 	"black-forest-labs/flux-schnell":                0.003 * USD, | ||||
| 	"black-forest-labs/flux-schnell-lora":           0.02 * USD, | ||||
| 	"ideogram-ai/ideogram-v2":                       0.08 * USD, | ||||
| 	"ideogram-ai/ideogram-v2-turbo":                 0.05 * USD, | ||||
| 	"recraft-ai/recraft-v3":                         0.04 * USD, | ||||
| 	"recraft-ai/recraft-v3-svg":                     0.08 * USD, | ||||
| 	"stability-ai/stable-diffusion-3":               0.035 * USD, | ||||
| 	"stability-ai/stable-diffusion-3.5-large":       0.065 * USD, | ||||
| 	"stability-ai/stable-diffusion-3.5-large-turbo": 0.04 * USD, | ||||
| 	"stability-ai/stable-diffusion-3.5-medium":      0.035 * USD, | ||||
| 	// replicate chat models | ||||
| 	"ibm-granite/granite-20b-code-instruct-8k":  0.100 * USD, | ||||
| 	"ibm-granite/granite-3.0-2b-instruct":       0.030 * USD, | ||||
| 	"ibm-granite/granite-3.0-8b-instruct":       0.050 * USD, | ||||
| 	"ibm-granite/granite-8b-code-instruct-128k": 0.050 * USD, | ||||
| 	"meta/llama-2-13b":                          0.100 * USD, | ||||
| 	"meta/llama-2-13b-chat":                     0.100 * USD, | ||||
| 	"meta/llama-2-70b":                          0.650 * USD, | ||||
| 	"meta/llama-2-70b-chat":                     0.650 * USD, | ||||
| 	"meta/llama-2-7b":                           0.050 * USD, | ||||
| 	"meta/llama-2-7b-chat":                      0.050 * USD, | ||||
| 	"meta/meta-llama-3.1-405b-instruct":         9.500 * USD, | ||||
| 	"meta/meta-llama-3-70b":                     0.650 * USD, | ||||
| 	"meta/meta-llama-3-70b-instruct":            0.650 * USD, | ||||
| 	"meta/meta-llama-3-8b":                      0.050 * USD, | ||||
| 	"meta/meta-llama-3-8b-instruct":             0.050 * USD, | ||||
| 	"mistralai/mistral-7b-instruct-v0.2":        0.050 * USD, | ||||
| 	"mistralai/mistral-7b-v0.1":                 0.050 * USD, | ||||
| 	"mistralai/mixtral-8x7b-instruct-v0.1":      0.300 * USD, | ||||
| } | ||||
|  | ||||
| var CompletionRatio = map[string]float64{ | ||||
| @@ -197,8 +339,10 @@ var CompletionRatio = map[string]float64{ | ||||
| 	"llama3-70b-8192(33)": 0.0035 / 0.00265, | ||||
| } | ||||
|  | ||||
| var DefaultModelRatio map[string]float64 | ||||
| var DefaultCompletionRatio map[string]float64 | ||||
| var ( | ||||
| 	DefaultModelRatio      map[string]float64 | ||||
| 	DefaultCompletionRatio map[string]float64 | ||||
| ) | ||||
|  | ||||
| func init() { | ||||
| 	DefaultModelRatio = make(map[string]float64) | ||||
| @@ -310,16 +454,25 @@ func GetCompletionRatio(name string, channelType int) float64 { | ||||
| 		return 4.0 / 3.0 | ||||
| 	} | ||||
| 	if strings.HasPrefix(name, "gpt-4") { | ||||
| 		if strings.HasPrefix(name, "gpt-4o-mini") { | ||||
| 		if strings.HasPrefix(name, "gpt-4o") { | ||||
| 			if name == "gpt-4o-2024-05-13" { | ||||
| 				return 3 | ||||
| 			} | ||||
| 			return 4 | ||||
| 		} | ||||
| 		if strings.HasPrefix(name, "gpt-4-turbo") || | ||||
| 			strings.HasPrefix(name, "gpt-4o") || | ||||
| 			strings.HasSuffix(name, "preview") { | ||||
| 			return 3 | ||||
| 		} | ||||
| 		return 2 | ||||
| 	} | ||||
| 	// including o1, o1-preview, o1-mini | ||||
| 	if strings.HasPrefix(name, "o1") { | ||||
| 		return 4 | ||||
| 	} | ||||
| 	if name == "chatgpt-4o-latest" { | ||||
| 		return 3 | ||||
| 	} | ||||
| 	if strings.HasPrefix(name, "claude-3") { | ||||
| 		return 5 | ||||
| 	} | ||||
| @@ -335,6 +488,7 @@ func GetCompletionRatio(name string, channelType int) float64 { | ||||
| 	if strings.HasPrefix(name, "deepseek-") { | ||||
| 		return 2 | ||||
| 	} | ||||
|  | ||||
| 	switch name { | ||||
| 	case "llama2-70b-4096": | ||||
| 		return 0.8 / 0.64 | ||||
| @@ -348,6 +502,37 @@ func GetCompletionRatio(name string, channelType int) float64 { | ||||
| 		return 3 | ||||
| 	case "command-r-plus": | ||||
| 		return 5 | ||||
| 	case "grok-beta": | ||||
| 		return 3 | ||||
| 	// Replicate Models | ||||
| 	// https://replicate.com/pricing | ||||
| 	case "ibm-granite/granite-20b-code-instruct-8k": | ||||
| 		return 5 | ||||
| 	case "ibm-granite/granite-3.0-2b-instruct": | ||||
| 		return 8.333333333333334 | ||||
| 	case "ibm-granite/granite-3.0-8b-instruct", | ||||
| 		"ibm-granite/granite-8b-code-instruct-128k": | ||||
| 		return 5 | ||||
| 	case "meta/llama-2-13b", | ||||
| 		"meta/llama-2-13b-chat", | ||||
| 		"meta/llama-2-7b", | ||||
| 		"meta/llama-2-7b-chat", | ||||
| 		"meta/meta-llama-3-8b", | ||||
| 		"meta/meta-llama-3-8b-instruct": | ||||
| 		return 5 | ||||
| 	case "meta/llama-2-70b", | ||||
| 		"meta/llama-2-70b-chat", | ||||
| 		"meta/meta-llama-3-70b", | ||||
| 		"meta/meta-llama-3-70b-instruct": | ||||
| 		return 2.750 / 0.650 // ≈4.230769 | ||||
| 	case "meta/meta-llama-3.1-405b-instruct": | ||||
| 		return 1 | ||||
| 	case "mistralai/mistral-7b-instruct-v0.2", | ||||
| 		"mistralai/mistral-7b-v0.1": | ||||
| 		return 5 | ||||
| 	case "mistralai/mixtral-8x7b-instruct-v0.1": | ||||
| 		return 1.000 / 0.300 // ≈3.333333 | ||||
| 	} | ||||
|  | ||||
| 	return 1 | ||||
| } | ||||
|   | ||||
| @@ -45,5 +45,8 @@ const ( | ||||
| 	Novita | ||||
| 	VertextAI | ||||
| 	Proxy | ||||
| 	SiliconFlow | ||||
| 	XAI | ||||
| 	Replicate | ||||
| 	Dummy | ||||
| ) | ||||
|   | ||||
| @@ -37,6 +37,8 @@ func ToAPIType(channelType int) int { | ||||
| 		apiType = apitype.DeepL | ||||
| 	case VertextAI: | ||||
| 		apiType = apitype.VertexAI | ||||
| 	case Replicate: | ||||
| 		apiType = apitype.Replicate | ||||
| 	case Proxy: | ||||
| 		apiType = apitype.Proxy | ||||
| 	} | ||||
|   | ||||
| @@ -45,6 +45,9 @@ var ChannelBaseURLs = []string{ | ||||
| 	"https://api.novita.ai/v3/openai",           // 41 | ||||
| 	"",                                          // 42 | ||||
| 	"",                                          // 43 | ||||
| 	"https://api.siliconflow.cn",                // 44 | ||||
| 	"https://api.x.ai",                          // 45 | ||||
| 	"https://api.replicate.com/v1/models/",      // 46 | ||||
| } | ||||
|  | ||||
| func init() { | ||||
|   | ||||
| @@ -1,5 +1,6 @@ | ||||
| package role | ||||
|  | ||||
| const ( | ||||
| 	System    = "system" | ||||
| 	Assistant = "assistant" | ||||
| ) | ||||
|   | ||||
| @@ -110,16 +110,9 @@ func RelayAudioHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus | ||||
| 	}() | ||||
|  | ||||
| 	// map model name | ||||
| 	modelMapping := c.GetString(ctxkey.ModelMapping) | ||||
| 	if modelMapping != "" { | ||||
| 		modelMap := make(map[string]string) | ||||
| 		err := json.Unmarshal([]byte(modelMapping), &modelMap) | ||||
| 		if err != nil { | ||||
| 			return openai.ErrorWrapper(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError) | ||||
| 		} | ||||
| 		if modelMap[audioModel] != "" { | ||||
| 			audioModel = modelMap[audioModel] | ||||
| 		} | ||||
| 	modelMapping := c.GetStringMapString(ctxkey.ModelMapping) | ||||
| 	if modelMapping != nil && modelMapping[audioModel] != "" { | ||||
| 		audioModel = modelMapping[audioModel] | ||||
| 	} | ||||
|  | ||||
| 	baseURL := channeltype.ChannelBaseURLs[channelType] | ||||
|   | ||||
| @@ -4,6 +4,7 @@ import ( | ||||
| 	"context" | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
| 	"github.com/songquanpeng/one-api/relay/constant/role" | ||||
| 	"math" | ||||
| 	"net/http" | ||||
| 	"strings" | ||||
| @@ -90,7 +91,7 @@ func preConsumeQuota(ctx context.Context, textRequest *relaymodel.GeneralOpenAIR | ||||
| 	return preConsumedQuota, nil | ||||
| } | ||||
|  | ||||
| func postConsumeQuota(ctx context.Context, usage *relaymodel.Usage, meta *meta.Meta, textRequest *relaymodel.GeneralOpenAIRequest, ratio float64, preConsumedQuota int64, modelRatio float64, groupRatio float64) { | ||||
| func postConsumeQuota(ctx context.Context, usage *relaymodel.Usage, meta *meta.Meta, textRequest *relaymodel.GeneralOpenAIRequest, ratio float64, preConsumedQuota int64, modelRatio float64, groupRatio float64, systemPromptReset bool) { | ||||
| 	if usage == nil { | ||||
| 		logger.Error(ctx, "usage is nil, which is unexpected") | ||||
| 		return | ||||
| @@ -118,7 +119,11 @@ func postConsumeQuota(ctx context.Context, usage *relaymodel.Usage, meta *meta.M | ||||
| 	if err != nil { | ||||
| 		logger.Error(ctx, "error update user quota cache: "+err.Error()) | ||||
| 	} | ||||
| 	logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f,补全倍率 %.2f", modelRatio, groupRatio, completionRatio) | ||||
| 	var extraLog string | ||||
| 	if systemPromptReset { | ||||
| 		extraLog = " (注意系统提示词已被重置)" | ||||
| 	} | ||||
| 	logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f,补全倍率 %.2f%s", modelRatio, groupRatio, completionRatio, extraLog) | ||||
| 	model.RecordConsumeLog(ctx, meta.UserId, meta.ChannelId, promptTokens, completionTokens, textRequest.Model, meta.TokenName, quota, logContent) | ||||
| 	model.UpdateUserUsedQuotaAndRequestCount(meta.UserId, quota) | ||||
| 	model.UpdateChannelUsedQuota(meta.ChannelId, quota) | ||||
| @@ -142,15 +147,41 @@ func isErrorHappened(meta *meta.Meta, resp *http.Response) bool { | ||||
| 		} | ||||
| 		return true | ||||
| 	} | ||||
| 	if resp.StatusCode != http.StatusOK { | ||||
| 	if resp.StatusCode != http.StatusOK && | ||||
| 		// replicate return 201 to create a task | ||||
| 		resp.StatusCode != http.StatusCreated { | ||||
| 		return true | ||||
| 	} | ||||
| 	if meta.ChannelType == channeltype.DeepL { | ||||
| 		// skip stream check for deepl | ||||
| 		return false | ||||
| 	} | ||||
| 	if meta.IsStream && strings.HasPrefix(resp.Header.Get("Content-Type"), "application/json") { | ||||
|  | ||||
| 	if meta.IsStream && strings.HasPrefix(resp.Header.Get("Content-Type"), "application/json") && | ||||
| 		// Even if stream mode is enabled, replicate will first return a task info in JSON format, | ||||
| 		// requiring the client to request the stream endpoint in the task info | ||||
| 		meta.ChannelType != channeltype.Replicate { | ||||
| 		return true | ||||
| 	} | ||||
| 	return false | ||||
| } | ||||
|  | ||||
| func setSystemPrompt(ctx context.Context, request *relaymodel.GeneralOpenAIRequest, prompt string) (reset bool) { | ||||
| 	if prompt == "" { | ||||
| 		return false | ||||
| 	} | ||||
| 	if len(request.Messages) == 0 { | ||||
| 		return false | ||||
| 	} | ||||
| 	if request.Messages[0].Role == role.System { | ||||
| 		request.Messages[0].Content = prompt | ||||
| 		logger.Infof(ctx, "rewrite system prompt") | ||||
| 		return true | ||||
| 	} | ||||
| 	request.Messages = append([]relaymodel.Message{{ | ||||
| 		Role:    role.System, | ||||
| 		Content: prompt, | ||||
| 	}}, request.Messages...) | ||||
| 	logger.Infof(ctx, "add system prompt") | ||||
| 	return true | ||||
| } | ||||
|   | ||||
| @@ -22,7 +22,7 @@ import ( | ||||
| 	relaymodel "github.com/songquanpeng/one-api/relay/model" | ||||
| ) | ||||
|  | ||||
| func getImageRequest(c *gin.Context, relayMode int) (*relaymodel.ImageRequest, error) { | ||||
| func getImageRequest(c *gin.Context, _ int) (*relaymodel.ImageRequest, error) { | ||||
| 	imageRequest := &relaymodel.ImageRequest{} | ||||
| 	err := common.UnmarshalBodyReusable(c, imageRequest) | ||||
| 	if err != nil { | ||||
| @@ -65,7 +65,7 @@ func getImageSizeRatio(model string, size string) float64 { | ||||
| 	return 1 | ||||
| } | ||||
|  | ||||
| func validateImageRequest(imageRequest *relaymodel.ImageRequest, meta *meta.Meta) *relaymodel.ErrorWithStatusCode { | ||||
| func validateImageRequest(imageRequest *relaymodel.ImageRequest, _ *meta.Meta) *relaymodel.ErrorWithStatusCode { | ||||
| 	// check prompt length | ||||
| 	if imageRequest.Prompt == "" { | ||||
| 		return openai.ErrorWrapper(errors.New("prompt is required"), "prompt_missing", http.StatusBadRequest) | ||||
| @@ -150,12 +150,12 @@ func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus | ||||
| 	} | ||||
| 	adaptor.Init(meta) | ||||
|  | ||||
| 	// these adaptors need to convert the request | ||||
| 	switch meta.ChannelType { | ||||
| 	case channeltype.Ali: | ||||
| 		fallthrough | ||||
| 	case channeltype.Baidu: | ||||
| 		fallthrough | ||||
| 	case channeltype.Zhipu: | ||||
| 	case channeltype.Zhipu, | ||||
| 		channeltype.Ali, | ||||
| 		channeltype.Replicate, | ||||
| 		channeltype.Baidu: | ||||
| 		finalRequest, err := adaptor.ConvertImageRequest(imageRequest) | ||||
| 		if err != nil { | ||||
| 			return openai.ErrorWrapper(err, "convert_image_request_failed", http.StatusInternalServerError) | ||||
| @@ -172,7 +172,14 @@ func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus | ||||
| 	ratio := modelRatio * groupRatio | ||||
| 	userQuota, err := model.CacheGetUserQuota(ctx, meta.UserId) | ||||
|  | ||||
| 	quota := int64(ratio*imageCostRatio*1000) * int64(imageRequest.N) | ||||
| 	var quota int64 | ||||
| 	switch meta.ChannelType { | ||||
| 	case channeltype.Replicate: | ||||
| 		// replicate always return 1 image | ||||
| 		quota = int64(ratio * imageCostRatio * 1000) | ||||
| 	default: | ||||
| 		quota = int64(ratio*imageCostRatio*1000) * int64(imageRequest.N) | ||||
| 	} | ||||
|  | ||||
| 	if userQuota-quota < 0 { | ||||
| 		return openai.ErrorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden) | ||||
| @@ -186,7 +193,9 @@ func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus | ||||
| 	} | ||||
|  | ||||
| 	defer func(ctx context.Context) { | ||||
| 		if resp != nil && resp.StatusCode != http.StatusOK { | ||||
| 		if resp != nil && | ||||
| 			resp.StatusCode != http.StatusCreated && // replicate returns 201 | ||||
| 			resp.StatusCode != http.StatusOK { | ||||
| 			return | ||||
| 		} | ||||
|  | ||||
|   | ||||
| @@ -4,6 +4,7 @@ import ( | ||||
| 	"bytes" | ||||
| 	"encoding/json" | ||||
| 	"fmt" | ||||
| 	"github.com/songquanpeng/one-api/common/config" | ||||
| 	"io" | ||||
| 	"net/http" | ||||
|  | ||||
| @@ -35,6 +36,8 @@ func RelayTextHelper(c *gin.Context) *model.ErrorWithStatusCode { | ||||
| 	meta.OriginModelName = textRequest.Model | ||||
| 	textRequest.Model, _ = getMappedModelName(textRequest.Model, meta.ModelMapping) | ||||
| 	meta.ActualModelName = textRequest.Model | ||||
| 	// set system prompt if not empty | ||||
| 	systemPromptReset := setSystemPrompt(ctx, textRequest, meta.SystemPrompt) | ||||
| 	// get model ratio & group ratio | ||||
| 	modelRatio := billingratio.GetModelRatio(textRequest.Model, meta.ChannelType) | ||||
| 	groupRatio := billingratio.GetGroupRatio(meta.Group) | ||||
| @@ -79,12 +82,12 @@ func RelayTextHelper(c *gin.Context) *model.ErrorWithStatusCode { | ||||
| 		return respErr | ||||
| 	} | ||||
| 	// post-consume quota | ||||
| 	go postConsumeQuota(ctx, usage, meta, textRequest, ratio, preConsumedQuota, modelRatio, groupRatio) | ||||
| 	go postConsumeQuota(ctx, usage, meta, textRequest, ratio, preConsumedQuota, modelRatio, groupRatio, systemPromptReset) | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func getRequestBody(c *gin.Context, meta *meta.Meta, textRequest *model.GeneralOpenAIRequest, adaptor adaptor.Adaptor) (io.Reader, error) { | ||||
| 	if meta.APIType == apitype.OpenAI && meta.OriginModelName == meta.ActualModelName && meta.ChannelType != channeltype.Baichuan { | ||||
| 	if !config.EnforceIncludeUsage && meta.APIType == apitype.OpenAI && meta.OriginModelName == meta.ActualModelName && meta.ChannelType != channeltype.Baichuan { | ||||
| 		// no need to convert request for openai | ||||
| 		return c.Request.Body, nil | ||||
| 	} | ||||
|   | ||||
| @@ -30,6 +30,7 @@ type Meta struct { | ||||
| 	ActualModelName string | ||||
| 	RequestURLPath  string | ||||
| 	PromptTokens    int // only for DoResponse | ||||
| 	SystemPrompt    string | ||||
| } | ||||
|  | ||||
| func GetByContext(c *gin.Context) *Meta { | ||||
| @@ -46,6 +47,7 @@ func GetByContext(c *gin.Context) *Meta { | ||||
| 		BaseURL:         c.GetString(ctxkey.BaseURL), | ||||
| 		APIKey:          strings.TrimPrefix(c.Request.Header.Get("Authorization"), "Bearer "), | ||||
| 		RequestURLPath:  c.Request.URL.String(), | ||||
| 		SystemPrompt:    c.GetString(ctxkey.SystemPrompt), | ||||
| 	} | ||||
| 	cfg, ok := c.Get(ctxkey.Config) | ||||
| 	if ok { | ||||
|   | ||||
| @@ -3,4 +3,5 @@ package model | ||||
| const ( | ||||
| 	ContentTypeText       = "text" | ||||
| 	ContentTypeImageURL   = "image_url" | ||||
| 	ContentTypeInputAudio = "input_audio" | ||||
| ) | ||||
|   | ||||
| @@ -2,33 +2,69 @@ package model | ||||
|  | ||||
| type ResponseFormat struct { | ||||
| 	Type       string      `json:"type,omitempty"` | ||||
| 	JsonSchema *JSONSchema `json:"json_schema,omitempty"` | ||||
| } | ||||
|  | ||||
| type JSONSchema struct { | ||||
| 	Description string                 `json:"description,omitempty"` | ||||
| 	Name        string                 `json:"name"` | ||||
| 	Schema      map[string]interface{} `json:"schema,omitempty"` | ||||
| 	Strict      *bool                  `json:"strict,omitempty"` | ||||
| } | ||||
|  | ||||
| type Audio struct { | ||||
| 	Voice  string `json:"voice,omitempty"` | ||||
| 	Format string `json:"format,omitempty"` | ||||
| } | ||||
|  | ||||
| type StreamOptions struct { | ||||
| 	IncludeUsage bool `json:"include_usage,omitempty"` | ||||
| } | ||||
|  | ||||
| type GeneralOpenAIRequest struct { | ||||
| 	// https://platform.openai.com/docs/api-reference/chat/create | ||||
| 	Messages            []Message       `json:"messages,omitempty"` | ||||
| 	Model               string          `json:"model,omitempty"` | ||||
| 	FrequencyPenalty float64         `json:"frequency_penalty,omitempty"` | ||||
| 	Store               *bool           `json:"store,omitempty"` | ||||
| 	Metadata            any             `json:"metadata,omitempty"` | ||||
| 	FrequencyPenalty    *float64        `json:"frequency_penalty,omitempty"` | ||||
| 	LogitBias           any             `json:"logit_bias,omitempty"` | ||||
| 	Logprobs            *bool           `json:"logprobs,omitempty"` | ||||
| 	TopLogprobs         *int            `json:"top_logprobs,omitempty"` | ||||
| 	MaxTokens           int             `json:"max_tokens,omitempty"` | ||||
| 	MaxCompletionTokens *int            `json:"max_completion_tokens,omitempty"` | ||||
| 	N                   int             `json:"n,omitempty"` | ||||
| 	PresencePenalty  float64         `json:"presence_penalty,omitempty"` | ||||
| 	Modalities          []string        `json:"modalities,omitempty"` | ||||
| 	Prediction          any             `json:"prediction,omitempty"` | ||||
| 	Audio               *Audio          `json:"audio,omitempty"` | ||||
| 	PresencePenalty     *float64        `json:"presence_penalty,omitempty"` | ||||
| 	ResponseFormat      *ResponseFormat `json:"response_format,omitempty"` | ||||
| 	Seed                float64         `json:"seed,omitempty"` | ||||
| 	ServiceTier         *string         `json:"service_tier,omitempty"` | ||||
| 	Stop                any             `json:"stop,omitempty"` | ||||
| 	Stream              bool            `json:"stream,omitempty"` | ||||
| 	Temperature      float64         `json:"temperature,omitempty"` | ||||
| 	TopP             float64         `json:"top_p,omitempty"` | ||||
| 	StreamOptions       *StreamOptions  `json:"stream_options,omitempty"` | ||||
| 	Temperature         *float64        `json:"temperature,omitempty"` | ||||
| 	TopP                *float64        `json:"top_p,omitempty"` | ||||
| 	TopK                int             `json:"top_k,omitempty"` | ||||
| 	Tools               []Tool          `json:"tools,omitempty"` | ||||
| 	ToolChoice          any             `json:"tool_choice,omitempty"` | ||||
| 	ParallelTooCalls    *bool           `json:"parallel_tool_calls,omitempty"` | ||||
| 	User                string          `json:"user,omitempty"` | ||||
| 	FunctionCall        any             `json:"function_call,omitempty"` | ||||
| 	Functions           any             `json:"functions,omitempty"` | ||||
| 	User             string          `json:"user,omitempty"` | ||||
| 	Prompt           any             `json:"prompt,omitempty"` | ||||
| 	// https://platform.openai.com/docs/api-reference/embeddings/create | ||||
| 	Input          any    `json:"input,omitempty"` | ||||
| 	EncodingFormat string `json:"encoding_format,omitempty"` | ||||
| 	Dimensions     int    `json:"dimensions,omitempty"` | ||||
| 	Instruction      string          `json:"instruction,omitempty"` | ||||
| 	// https://platform.openai.com/docs/api-reference/images/create | ||||
| 	Prompt  any     `json:"prompt,omitempty"` | ||||
| 	Quality *string `json:"quality,omitempty"` | ||||
| 	Size    string  `json:"size,omitempty"` | ||||
| 	Style   *string `json:"style,omitempty"` | ||||
| 	// Others | ||||
| 	Instruction string `json:"instruction,omitempty"` | ||||
| 	NumCtx      int    `json:"num_ctx,omitempty"` | ||||
| } | ||||
|  | ||||
| func (r GeneralOpenAIRequest) ParseInput() []string { | ||||
|   | ||||
| @@ -23,6 +23,7 @@ func SetApiRouter(router *gin.Engine) { | ||||
| 		apiRouter.GET("/reset_password", middleware.CriticalRateLimit(), middleware.TurnstileCheck(), controller.SendPasswordResetEmail) | ||||
| 		apiRouter.POST("/user/reset", middleware.CriticalRateLimit(), controller.ResetPassword) | ||||
| 		apiRouter.GET("/oauth/github", middleware.CriticalRateLimit(), auth.GitHubOAuth) | ||||
| 		apiRouter.GET("/oauth/oidc", middleware.CriticalRateLimit(), auth.OidcAuth) | ||||
| 		apiRouter.GET("/oauth/lark", middleware.CriticalRateLimit(), auth.LarkOAuth) | ||||
| 		apiRouter.GET("/oauth/state", middleware.CriticalRateLimit(), auth.GenerateOAuthCode) | ||||
| 		apiRouter.GET("/oauth/wechat", middleware.CriticalRateLimit(), auth.WeChatAuth) | ||||
|   | ||||
| @@ -9,6 +9,7 @@ import ( | ||||
|  | ||||
| func SetRelayRouter(router *gin.Engine) { | ||||
| 	router.Use(middleware.CORS()) | ||||
| 	router.Use(middleware.GzipDecodeMiddleware()) | ||||
| 	// https://platform.openai.com/docs/api-reference/introduction | ||||
| 	modelsRouter := router.Group("/v1/models") | ||||
| 	modelsRouter.Use(middleware.TokenAuth()) | ||||
|   | ||||
| @@ -11,12 +11,14 @@ import EditToken from '../pages/Token/EditToken'; | ||||
| const COPY_OPTIONS = [ | ||||
|   { key: 'next', text: 'ChatGPT Next Web', value: 'next' }, | ||||
|   { key: 'ama', text: 'ChatGPT Web & Midjourney', value: 'ama' }, | ||||
|   { key: 'opencat', text: 'OpenCat', value: 'opencat' } | ||||
|   { key: 'opencat', text: 'OpenCat', value: 'opencat' }, | ||||
|   { key: 'lobechat', text: 'LobeChat', value: 'lobechat' }, | ||||
| ]; | ||||
|  | ||||
| const OPEN_LINK_OPTIONS = [ | ||||
|   { key: 'ama', text: 'ChatGPT Web & Midjourney', value: 'ama' }, | ||||
|   { key: 'opencat', text: 'OpenCat', value: 'opencat' } | ||||
|   { key: 'opencat', text: 'OpenCat', value: 'opencat' }, | ||||
|   { key: 'lobechat', text: 'LobeChat', value: 'lobechat' } | ||||
| ]; | ||||
|  | ||||
| function renderTimestamp(timestamp) { | ||||
| @@ -60,7 +62,12 @@ const TokensTable = () => { | ||||
|         onOpenLink('next-mj'); | ||||
|       } | ||||
|     }, | ||||
|     { node: 'item', key: 'opencat', name: 'OpenCat', value: 'opencat' } | ||||
|     { node: 'item', key: 'opencat', name: 'OpenCat', value: 'opencat' }, | ||||
|     { | ||||
|       node: 'item', key: 'lobechat', name: 'LobeChat', onClick: () => { | ||||
|         onOpenLink('lobechat'); | ||||
|       } | ||||
|     } | ||||
|   ]; | ||||
|  | ||||
|   const columns = [ | ||||
| @@ -177,6 +184,11 @@ const TokensTable = () => { | ||||
|                   node: 'item', key: 'opencat', name: 'OpenCat', onClick: () => { | ||||
|                     onOpenLink('opencat', record.key); | ||||
|                   } | ||||
|                 }, | ||||
|                 { | ||||
|                   node: 'item', key: 'lobechat', name: 'LobeChat', onClick: () => { | ||||
|                     onOpenLink('lobechat'); | ||||
|                   } | ||||
|                 } | ||||
|               ] | ||||
|             } | ||||
| @@ -382,6 +394,9 @@ const TokensTable = () => { | ||||
|       case 'next-mj': | ||||
|         url = mjLink + `/#/?settings={"key":"sk-${key}","url":"${serverAddress}"}`; | ||||
|         break; | ||||
|       case 'lobechat': | ||||
|         url = chatLink + `/?settings={"keyVaults":{"openai":{"apiKey":"sk-${key}","baseURL":"${serverAddress}/v1"}}}`; | ||||
|         break; | ||||
|       default: | ||||
|         if (!chatLink) { | ||||
|           showError('管理员未设置聊天链接'); | ||||
|   | ||||
| @@ -29,6 +29,9 @@ export const CHANNEL_OPTIONS = [ | ||||
|   { key: 39, text: 'together.ai', value: 39, color: 'blue' }, | ||||
|   { key: 42, text: 'VertexAI', value: 42, color: 'blue' }, | ||||
|   { key: 43, text: 'Proxy', value: 43, color: 'blue' }, | ||||
|   { key: 44, text: 'SiliconFlow', value: 44, color: 'blue' }, | ||||
|   { key: 45, text: 'xAI', value: 45, color: 'blue' }, | ||||
|   { key: 46, text: 'Replicate', value: 46, color: 'blue' }, | ||||
|   { key: 8, text: '自定义渠道', value: 8, color: 'pink' }, | ||||
|   { key: 22, text: '知识库:FastGPT', value: 22, color: 'blue' }, | ||||
|   { key: 21, text: '知识库:AI Proxy', value: 21, color: 'purple' }, | ||||
|   | ||||
| @@ -43,6 +43,7 @@ const EditChannel = (props) => { | ||||
|         base_url: '', | ||||
|         other: '', | ||||
|         model_mapping: '', | ||||
|         system_prompt: '', | ||||
|         models: [], | ||||
|         auto_ban: 1, | ||||
|         groups: ['default'] | ||||
| @@ -63,7 +64,7 @@ const EditChannel = (props) => { | ||||
|             let localModels = []; | ||||
|             switch (value) { | ||||
|                 case 14: | ||||
|                     localModels = ["claude-instant-1.2", "claude-2", "claude-2.0", "claude-2.1", "claude-3-opus-20240229", "claude-3-sonnet-20240229", "claude-3-haiku-20240307", "claude-3-5-sonnet-20240620"]; | ||||
|                     localModels = ["claude-instant-1.2", "claude-2", "claude-2.0", "claude-2.1", "claude-3-opus-20240229", "claude-3-sonnet-20240229", "claude-3-haiku-20240307", "claude-3-5-haiku-20241022", "claude-3-5-sonnet-20240620", "claude-3-5-sonnet-20241022"]; | ||||
|                     break; | ||||
|                 case 11: | ||||
|                     localModels = ['PaLM-2']; | ||||
| @@ -78,7 +79,7 @@ const EditChannel = (props) => { | ||||
|                     localModels = ['chatglm_pro', 'chatglm_std', 'chatglm_lite']; | ||||
|                     break; | ||||
|                 case 18: | ||||
|                     localModels = ['SparkDesk', 'SparkDesk-v1.1', 'SparkDesk-v2.1', 'SparkDesk-v3.1', 'SparkDesk-v3.5', 'SparkDesk-v4.0']; | ||||
|                     localModels = ['SparkDesk', 'SparkDesk-v1.1', 'SparkDesk-v2.1', 'SparkDesk-v3.1', 'SparkDesk-v3.1-128K', 'SparkDesk-v3.5', 'SparkDesk-v3.5-32K', 'SparkDesk-v4.0']; | ||||
|                     break; | ||||
|                 case 19: | ||||
|                     localModels = ['360GPT_S2_V9', 'embedding-bert-512-v1', 'embedding_s1_v1', 'semantic_similarity_s1_v1']; | ||||
| @@ -496,6 +497,19 @@ const EditChannel = (props) => { | ||||
|                       value={inputs.model_mapping} | ||||
|                       autoComplete='new-password' | ||||
|                     /> | ||||
|                     <div style={{ marginTop: 10 }}> | ||||
|                         <Typography.Text strong>系统提示词:</Typography.Text> | ||||
|                     </div> | ||||
|                     <TextArea | ||||
|                       placeholder={`此项可选,用于强制设置给定的系统提示词,请配合自定义模型 & 模型重定向使用,首先创建一个唯一的自定义模型名称并在上面填入,之后将该自定义模型重定向映射到该渠道一个原生支持的模型`} | ||||
|                       name='system_prompt' | ||||
|                       onChange={value => { | ||||
|                           handleInputChange('system_prompt', value) | ||||
|                       }} | ||||
|                       autosize | ||||
|                       value={inputs.system_prompt} | ||||
|                       autoComplete='new-password' | ||||
|                     /> | ||||
|                     <Typography.Text style={{ | ||||
|                         color: 'rgba(var(--semi-blue-5), 1)', | ||||
|                         userSelect: 'none', | ||||
|   | ||||
										
											
												File diff suppressed because one or more lines are too long
											
										
									
								
							| Before Width: | Height: | Size: 5.4 KiB After Width: | Height: | Size: 4.3 KiB | 
							
								
								
									
										7
									
								
								web/berry/src/assets/images/icons/oidc.svg
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										7
									
								
								web/berry/src/assets/images/icons/oidc.svg
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,7 @@ | ||||
| <svg t="1723135116886" class="icon" viewBox="0 0 1024 1024" version="1.1" xmlns="http://www.w3.org/2000/svg" | ||||
|      p-id="10969" width="200" height="200"> | ||||
|     <path d="M512 960C265 960 64 759 64 512S265 64 512 64s448 201 448 448-201 448-448 448z m0-882.6c-239.7 0-434.6 195-434.6 434.6s195 434.6 434.6 434.6 434.6-195 434.6-434.6S751.7 77.4 512 77.4z" | ||||
|           p-id="10970" fill="#2c2c2c" stroke="#2c2c2c" stroke-width="60"></path> | ||||
|     <path d="M197.7 512c0-78.3 31.6-98.8 87.2-98.8 56.2 0 87.2 20.5 87.2 98.8s-31 98.8-87.2 98.8c-55.7 0-87.2-20.5-87.2-98.8z m130.4 0c0-46.8-7.8-64.5-43.2-64.5-35.2 0-42.9 17.7-42.9 64.5 0 47.1 7.8 63.7 42.9 63.7 35.4 0 43.2-16.6 43.2-63.7zM409.7 415.9h42.1V608h-42.1V415.9zM653.9 512c0 74.2-37.1 96.1-93.6 96.1h-65.9V415.9h65.9c56.5 0 93.6 16.1 93.6 96.1z m-43.5 0c0-49.3-17.7-60.6-52.3-60.6h-21.6v120.7h21.6c35.4 0 52.3-13.3 52.3-60.1zM686.5 512c0-74.2 36.3-98.8 92.7-98.8 18.3 0 33.2 2.2 44.8 6.4v36.3c-11.9-4.2-26-6.6-42.1-6.6-34.6 0-49.8 15.5-49.8 62.6 0 50.1 15.2 62.6 49.3 62.6 15.8 0 30.2-2.2 44.8-7.5v36c-11.3 4.7-28.5 8-46.8 8-56.1-0.2-92.9-18.7-92.9-99z" | ||||
|           p-id="10971" fill="#2c2c2c" stroke="#2c2c2c" stroke-width="20"></path> | ||||
| </svg> | ||||
| After Width: | Height: | Size: 1.2 KiB | 
| @@ -22,7 +22,12 @@ const config = { | ||||
|     turnstile_site_key: '', | ||||
|     version: '', | ||||
|     wechat_login: false, | ||||
|     wechat_qrcode: '' | ||||
|     wechat_qrcode: '', | ||||
|     oidc: false, | ||||
|     oidc_client_id: '', | ||||
|     oidc_authorization_endpoint: '', | ||||
|     oidc_token_endpoint: '', | ||||
|     oidc_userinfo_endpoint: '', | ||||
|   } | ||||
| }; | ||||
|  | ||||
|   | ||||
| @@ -173,6 +173,24 @@ export const CHANNEL_OPTIONS = { | ||||
|     value: 43, | ||||
|     color: 'primary' | ||||
|   }, | ||||
|   44: { | ||||
|     key: 44, | ||||
|     text: 'SiliconFlow', | ||||
|     value: 44, | ||||
|     color: 'primary' | ||||
|   }, | ||||
|   45: { | ||||
|     key: 45, | ||||
|     text: 'xAI', | ||||
|     value: 45, | ||||
|     color: 'primary' | ||||
|   }, | ||||
|   45: { | ||||
|     key: 46, | ||||
|     text: 'Replicate', | ||||
|     value: 46, | ||||
|     color: 'primary' | ||||
|   }, | ||||
|   41: { | ||||
|     key: 41, | ||||
|     text: 'Novita', | ||||
|   | ||||
| @@ -70,6 +70,28 @@ const useLogin = () => { | ||||
|     } | ||||
|   }; | ||||
|  | ||||
|   const oidcLogin = async (code, state) => { | ||||
|     try { | ||||
|       const res = await API.get(`/api/oauth/oidc?code=${code}&state=${state}`); | ||||
|       const { success, message, data } = res.data; | ||||
|       if (success) { | ||||
|         if (message === 'bind') { | ||||
|           showSuccess('绑定成功!'); | ||||
|           navigate('/panel'); | ||||
|         } else { | ||||
|           dispatch({ type: LOGIN, payload: data }); | ||||
|           localStorage.setItem('user', JSON.stringify(data)); | ||||
|           showSuccess('登录成功!'); | ||||
|           navigate('/panel'); | ||||
|         } | ||||
|       } | ||||
|       return { success, message }; | ||||
|     } catch (err) { | ||||
|       // 请求失败,设置错误信息 | ||||
|       return { success: false, message: '' }; | ||||
|     } | ||||
|   } | ||||
|  | ||||
|   const wechatLogin = async (code) => { | ||||
|     try { | ||||
|       const res = await API.get(`/api/oauth/wechat?code=${code}`); | ||||
| @@ -94,7 +116,7 @@ const useLogin = () => { | ||||
|     navigate('/'); | ||||
|   }; | ||||
|  | ||||
|   return { login, logout, githubLogin, wechatLogin, larkLogin }; | ||||
|   return { login, logout, githubLogin, wechatLogin, larkLogin,oidcLogin }; | ||||
| }; | ||||
|  | ||||
| export default useLogin; | ||||
|   | ||||
| @@ -9,6 +9,7 @@ const AuthLogin = Loadable(lazy(() => import('views/Authentication/Auth/Login')) | ||||
| const AuthRegister = Loadable(lazy(() => import('views/Authentication/Auth/Register'))); | ||||
| const GitHubOAuth = Loadable(lazy(() => import('views/Authentication/Auth/GitHubOAuth'))); | ||||
| const LarkOAuth = Loadable(lazy(() => import('views/Authentication/Auth/LarkOAuth'))); | ||||
| const OidcOAuth = Loadable(lazy(() => import('views/Authentication/Auth/OidcOAuth'))); | ||||
| const ForgetPassword = Loadable(lazy(() => import('views/Authentication/Auth/ForgetPassword'))); | ||||
| const ResetPassword = Loadable(lazy(() => import('views/Authentication/Auth/ResetPassword'))); | ||||
| const Home = Loadable(lazy(() => import('views/Home'))); | ||||
| @@ -53,6 +54,10 @@ const OtherRoutes = { | ||||
|       path: '/oauth/lark', | ||||
|       element: <LarkOAuth /> | ||||
|     }, | ||||
|     { | ||||
|       path: 'oauth/oidc', | ||||
|       element: <OidcOAuth /> | ||||
|     }, | ||||
|     { | ||||
|       path: '/404', | ||||
|       element: <NotFoundView /> | ||||
|   | ||||
| @@ -95,7 +95,22 @@ export async function onLarkOAuthClicked(lark_client_id) { | ||||
|   const state = await getOAuthState(); | ||||
|   if (!state) return; | ||||
|   let redirect_uri = `${window.location.origin}/oauth/lark`; | ||||
|   window.open(`https://open.feishu.cn/open-apis/authen/v1/index?redirect_uri=${redirect_uri}&app_id=${lark_client_id}&state=${state}`); | ||||
|   window.open(`https://accounts.feishu.cn/open-apis/authen/v1/authorize?redirect_uri=${redirect_uri}&client_id=${lark_client_id}&state=${state}`); | ||||
| } | ||||
|  | ||||
| export async function onOidcClicked(auth_url, client_id, openInNewTab = false) { | ||||
|   const state = await getOAuthState(); | ||||
|   if (!state) return; | ||||
|   const redirect_uri = `${window.location.origin}/oauth/oidc`; | ||||
|   const response_type = "code"; | ||||
|   const scope = "openid profile email"; | ||||
|   const url = `${auth_url}?client_id=${client_id}&redirect_uri=${redirect_uri}&response_type=${response_type}&scope=${scope}&state=${state}`; | ||||
|   if (openInNewTab) { | ||||
|     window.open(url); | ||||
|   } else | ||||
|   { | ||||
|     window.location.href = url; | ||||
|   } | ||||
| } | ||||
|  | ||||
| export function isAdmin() { | ||||
|   | ||||
							
								
								
									
										94
									
								
								web/berry/src/views/Authentication/Auth/OidcOAuth.js
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										94
									
								
								web/berry/src/views/Authentication/Auth/OidcOAuth.js
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,94 @@ | ||||
| import { Link, useNavigate, useSearchParams } from 'react-router-dom'; | ||||
| import React, { useEffect, useState } from 'react'; | ||||
| import { showError } from 'utils/common'; | ||||
| import useLogin from 'hooks/useLogin'; | ||||
|  | ||||
| // material-ui | ||||
| import { useTheme } from '@mui/material/styles'; | ||||
| import { Grid, Stack, Typography, useMediaQuery, CircularProgress } from '@mui/material'; | ||||
|  | ||||
| // project imports | ||||
| import AuthWrapper from '../AuthWrapper'; | ||||
| import AuthCardWrapper from '../AuthCardWrapper'; | ||||
| import Logo from 'ui-component/Logo'; | ||||
|  | ||||
| // assets | ||||
|  | ||||
| // ================================|| AUTH3 - LOGIN ||================================ // | ||||
|  | ||||
| const OidcOAuth = () => { | ||||
|   const theme = useTheme(); | ||||
|   const matchDownSM = useMediaQuery(theme.breakpoints.down('md')); | ||||
|  | ||||
|   const [searchParams] = useSearchParams(); | ||||
|   const [prompt, setPrompt] = useState('处理中...'); | ||||
|   const { oidcLogin } = useLogin(); | ||||
|  | ||||
|   let navigate = useNavigate(); | ||||
|  | ||||
|   const sendCode = async (code, state, count) => { | ||||
|     const { success, message } = await oidcLogin(code, state); | ||||
|     if (!success) { | ||||
|       if (message) { | ||||
|         showError(message); | ||||
|       } | ||||
|       if (count === 0) { | ||||
|         setPrompt(`操作失败,重定向至登录界面中...`); | ||||
|         await new Promise((resolve) => setTimeout(resolve, 2000)); | ||||
|         navigate('/login'); | ||||
|         return; | ||||
|       } | ||||
|       count++; | ||||
|       setPrompt(`出现错误,第 ${count} 次重试中...`); | ||||
|       await new Promise((resolve) => setTimeout(resolve, 2000)); | ||||
|       await sendCode(code, state, count); | ||||
|     } | ||||
|   }; | ||||
|  | ||||
|   useEffect(() => { | ||||
|     let code = searchParams.get('code'); | ||||
|     let state = searchParams.get('state'); | ||||
|     sendCode(code, state, 0).then(); | ||||
|   }, []); | ||||
|  | ||||
|   return ( | ||||
|     <AuthWrapper> | ||||
|       <Grid container direction="column" justifyContent="flex-end"> | ||||
|         <Grid item xs={12}> | ||||
|           <Grid container justifyContent="center" alignItems="center" sx={{ minHeight: 'calc(100vh - 136px)' }}> | ||||
|             <Grid item sx={{ m: { xs: 1, sm: 3 }, mb: 0 }}> | ||||
|               <AuthCardWrapper> | ||||
|                 <Grid container spacing={2} alignItems="center" justifyContent="center"> | ||||
|                   <Grid item sx={{ mb: 3 }}> | ||||
|                     <Link to="#"> | ||||
|                       <Logo /> | ||||
|                     </Link> | ||||
|                   </Grid> | ||||
|                   <Grid item xs={12}> | ||||
|                     <Grid container direction={matchDownSM ? 'column-reverse' : 'row'} alignItems="center" justifyContent="center"> | ||||
|                       <Grid item> | ||||
|                         <Stack alignItems="center" justifyContent="center" spacing={1}> | ||||
|                           <Typography color={theme.palette.primary.main} gutterBottom variant={matchDownSM ? 'h3' : 'h2'}> | ||||
|                             OIDC 登录 | ||||
|                           </Typography> | ||||
|                         </Stack> | ||||
|                       </Grid> | ||||
|                     </Grid> | ||||
|                   </Grid> | ||||
|                   <Grid item xs={12} container direction="column" justifyContent="center" alignItems="center" style={{ height: '200px' }}> | ||||
|                     <CircularProgress /> | ||||
|                     <Typography variant="h3" paddingTop={'20px'}> | ||||
|                       {prompt} | ||||
|                     </Typography> | ||||
|                   </Grid> | ||||
|                 </Grid> | ||||
|               </AuthCardWrapper> | ||||
|             </Grid> | ||||
|           </Grid> | ||||
|         </Grid> | ||||
|       </Grid> | ||||
|     </AuthWrapper> | ||||
|   ); | ||||
| }; | ||||
|  | ||||
| export default OidcOAuth; | ||||
Some files were not shown because too many files have changed in this diff Show More
		Reference in New Issue
	
	Block a user