mirror of
				https://github.com/songquanpeng/one-api.git
				synced 2025-10-31 22:03:41 +08:00 
			
		
		
		
	Compare commits
	
		
			112 Commits
		
	
	
		
			v0.6.8
			...
			v0.6.11-al
		
	
	| Author | SHA1 | Date | |
|---|---|---|---|
|  | a316ed7abc | ||
|  | 0895d8660e | ||
|  | be1ed114f4 | ||
|  | eb6da573a3 | ||
|  | 0a6273fc08 | ||
|  | 5997fce454 | ||
|  | 0df6d7a131 | ||
|  | 93fdb60de5 | ||
|  | 4db834da95 | ||
|  | 6818ed5ca8 | ||
|  | 7be3b5547d | ||
|  | 2d7ea61d67 | ||
|  | 83b34be067 | ||
|  | d5d879afdc | ||
|  | 0f205a3aa3 | ||
|  | 76c3f87351 | ||
|  | 6d9a92f8f7 | ||
|  | 835f0e0d67 | ||
|  | a6981f0d51 | ||
|  | 678d613179 | ||
|  | be089a072b | ||
|  | 45d10aa3df | ||
|  | 9cdd48ac22 | ||
|  | 310e7120e5 | ||
|  | 3d29713268 | ||
|  | f2c7c424e9 | ||
|  | 38a42bb265 | ||
|  | fa2e8f44b1 | ||
|  | 9f74101543 | ||
|  | 28a271a896 | ||
|  | e8ea87fff3 | ||
|  | abe2d2dba8 | ||
|  | 4bcaa064d6 | ||
|  | 52d81e0e24 | ||
|  | dc8c3bc69e | ||
|  | b4e69df802 | ||
|  | d9f74bdff3 | ||
|  | fa2a772731 | ||
|  | 4f68f3e1b3 | ||
|  | 0bab887b2d | ||
|  | 0230d36643 | ||
|  | bad57d049a | ||
|  | dc470ce82e | ||
|  | ea0721d525 | ||
|  | d0402f9086 | ||
|  | 1fead8e7f7 | ||
|  | 09911a301d | ||
|  | f95e6b78b8 | ||
|  | 605bb06667 | ||
|  | d88e07fd9a | ||
|  | 3915ce9814 | ||
|  | 999defc88b | ||
|  | b51c47bc77 | ||
|  | 4f25cde132 | ||
|  | d89e9d7e44 | ||
|  | a858292b54 | ||
|  | ff589b5e4a | ||
|  | 95e8c16338 | ||
|  | 381172cb36 | ||
|  | 59eae186a3 | ||
|  | ce52f355bb | ||
|  | cb9d0a74c9 | ||
|  | 49ffb1c60d | ||
|  | 2f16649896 | ||
|  | af3aa57bd6 | ||
|  | e9f117ff72 | ||
|  | 6bb5247bd6 | ||
|  | 305ce14fe3 | ||
|  | 36c8f4f15c | ||
|  | 45b51ea0ee | ||
|  | 7c8628bd95 | ||
|  | 6ab87f8a08 | ||
|  | 833fa7ad6f | ||
|  | 6eb0770a89 | ||
|  | 92cd46d64f | ||
|  | 2b2dc2c733 | ||
|  | a3d7df7f89 | ||
|  | c368232f50 | ||
|  | cbfc983dc3 | ||
|  | 8ec092ba44 | ||
|  | b0b88a79ff | ||
|  | 7e51b04221 | ||
|  | f75a17f8eb | ||
|  | 6f13a3bb3c | ||
|  | f092eed1db | ||
|  | 629378691b | ||
|  | 3716e1b0e6 | ||
|  | a4d6e7a886 | ||
|  | cb772e5d06 | ||
|  | e32cb0b844 | ||
|  | fdd7bf41c0 | ||
|  | 29389ed44f | ||
|  | 88acc5a614 | ||
|  | a21681096a | ||
|  | 32f90a79a8 | ||
|  | 99c8c77504 | ||
|  | 649ecbf29c | ||
|  | 3a27c90910 | ||
|  | cba82404ae | ||
|  | c9ac670ba1 | ||
|  | 15f815c23c | ||
|  | 89b63ca96f | ||
|  | 8cc54489b9 | ||
|  | 58bf60805e | ||
|  | 6714cf96d6 | ||
|  | f9774698e9 | ||
|  | 2af6f6a166 | ||
|  | 04bb3ef392 | ||
|  | b4bfa418a8 | ||
|  | e7e99e558a | ||
|  | 402fcf7f79 | ||
|  | 36039e329e | 
							
								
								
									
										2
									
								
								.github/workflows/ci.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										2
									
								
								.github/workflows/ci.yml
									
									
									
									
										vendored
									
									
								
							| @@ -12,8 +12,6 @@ name: CI | ||||
| # would trigger our jobs twice on pull requests (once from "push" event and once | ||||
| # from "pull_request->synchronize") | ||||
| on: | ||||
|   pull_request: | ||||
|     types: [opened, reopened, synchronize] | ||||
|   push: | ||||
|     branches: | ||||
|       - 'main' | ||||
|   | ||||
							
								
								
									
										1
									
								
								.gitignore
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										1
									
								
								.gitignore
									
									
									
									
										vendored
									
									
								
							| @@ -10,3 +10,4 @@ data | ||||
| /web/node_modules | ||||
| cmd.md | ||||
| .env | ||||
| /one-api | ||||
|   | ||||
							
								
								
									
										19
									
								
								Dockerfile
									
									
									
									
									
								
							
							
						
						
									
										19
									
								
								Dockerfile
									
									
									
									
									
								
							| @@ -4,21 +4,20 @@ WORKDIR /web | ||||
| COPY ./VERSION . | ||||
| COPY ./web . | ||||
|  | ||||
| WORKDIR /web/default | ||||
| RUN npm install | ||||
| RUN DISABLE_ESLINT_PLUGIN='true' REACT_APP_VERSION=$(cat VERSION) npm run build | ||||
| RUN npm install --prefix /web/default & \ | ||||
|     npm install --prefix /web/berry & \ | ||||
|     npm install --prefix /web/air & \ | ||||
|     wait | ||||
|  | ||||
| WORKDIR /web/berry | ||||
| RUN npm install | ||||
| RUN DISABLE_ESLINT_PLUGIN='true' REACT_APP_VERSION=$(cat VERSION) npm run build | ||||
|  | ||||
| WORKDIR /web/air | ||||
| RUN npm install | ||||
| RUN DISABLE_ESLINT_PLUGIN='true' REACT_APP_VERSION=$(cat VERSION) npm run build | ||||
| RUN DISABLE_ESLINT_PLUGIN='true' REACT_APP_VERSION=$(cat /web/default/VERSION) npm run build --prefix /web/default & \ | ||||
|     DISABLE_ESLINT_PLUGIN='true' REACT_APP_VERSION=$(cat /web/berry/VERSION) npm run build --prefix /web/berry & \ | ||||
|     DISABLE_ESLINT_PLUGIN='true' REACT_APP_VERSION=$(cat /web/air/VERSION) npm run build --prefix /web/air & \ | ||||
|     wait | ||||
|  | ||||
| FROM golang:alpine AS builder2 | ||||
|  | ||||
| RUN apk add --no-cache g++ | ||||
| RUN apk add --no-cache gcc musl-dev libc-dev sqlite-dev | ||||
|  | ||||
| ENV GO111MODULE=on \ | ||||
|     CGO_ENABLED=1 \ | ||||
|   | ||||
							
								
								
									
										23
									
								
								README.md
									
									
									
									
									
								
							
							
						
						
									
										23
									
								
								README.md
									
									
									
									
									
								
							| @@ -89,6 +89,8 @@ _✨ 通过标准的 OpenAI API 格式访问所有的大模型,开箱即用  | ||||
|    + [x] [DeepL](https://www.deepl.com/) | ||||
|    + [x] [together.ai](https://www.together.ai/) | ||||
|    + [x] [novita.ai](https://www.novita.ai/) | ||||
|    + [x] [硅基流动 SiliconCloud](https://siliconflow.cn/siliconcloud) | ||||
|    + [x] [xAI](https://x.ai/) | ||||
| 2. 支持配置镜像以及众多[第三方代理服务](https://iamazing.cn/page/openai-api-third-party-services)。 | ||||
| 3. 支持通过**负载均衡**的方式访问多个渠道。 | ||||
| 4. 支持 **stream 模式**,可以通过流式传输实现打字机效果。 | ||||
| @@ -113,8 +115,8 @@ _✨ 通过标准的 OpenAI API 格式访问所有的大模型,开箱即用  | ||||
| 21. 支持 Cloudflare Turnstile 用户校验。 | ||||
| 22. 支持用户管理,支持**多种用户登录注册方式**: | ||||
|     + 邮箱登录注册(支持注册邮箱白名单)以及通过邮箱进行密码重置。 | ||||
|     + 支持使用飞书进行授权登录。 | ||||
|     + [GitHub 开放授权](https://github.com/settings/applications/new)。 | ||||
|     + 支持[飞书授权登录](https://open.feishu.cn/document/uAjLw4CM/ukTMukTMukTM/reference/authen-v1/authorize/get)([这里有 One API 的实现细节阐述供参考](https://iamazing.cn/page/feishu-oauth-login))。 | ||||
|     + 支持 [GitHub 授权登录](https://github.com/settings/applications/new)。 | ||||
|     + 微信公众号授权(需要额外部署 [WeChat Server](https://github.com/songquanpeng/wechat-server))。 | ||||
| 23. 支持主题切换,设置环境变量 `THEME` 即可,默认为 `default`,欢迎 PR 更多主题,具体参考[此处](./web/README.md)。 | ||||
| 24. 配合 [Message Pusher](https://github.com/songquanpeng/message-pusher) 可将报警信息推送到多种 App 上。 | ||||
| @@ -173,6 +175,10 @@ sudo service nginx restart | ||||
|  | ||||
| 初始账号用户名为 `root`,密码为 `123456`。 | ||||
|  | ||||
| ### 通过宝塔面板进行一键部署 | ||||
| 1. 安装宝塔面板9.2.0及以上版本,前往 [宝塔面板](https://www.bt.cn/new/download.html?r=dk_oneapi) 官网,选择正式版的脚本下载安装; | ||||
| 2. 安装后登录宝塔面板,在左侧菜单栏中点击 `Docker`,首次进入会提示安装 `Docker` 服务,点击立即安装,按提示完成安装; | ||||
| 3. 安装完成后在应用商店中搜索 `One-API`,点击安装,配置域名等基本信息即可完成安装; | ||||
|  | ||||
| ### 基于 Docker Compose 进行部署 | ||||
|  | ||||
| @@ -216,7 +222,7 @@ docker-compose ps | ||||
| 3. 所有从服务器必须设置 `NODE_TYPE` 为 `slave`,不设置则默认为主服务器。 | ||||
| 4. 设置 `SYNC_FREQUENCY` 后服务器将定期从数据库同步配置,在使用远程数据库的情况下,推荐设置该项并启用 Redis,无论主从。 | ||||
| 5. 从服务器可以选择设置 `FRONTEND_BASE_URL`,以重定向页面请求到主服务器。 | ||||
| 6. 从服务器上**分别**装好 Redis,设置好 `REDIS_CONN_STRING`,这样可以做到在缓存未过期的情况下数据库零访问,可以减少延迟。 | ||||
| 6. 从服务器上**分别**装好 Redis,设置好 `REDIS_CONN_STRING`,这样可以做到在缓存未过期的情况下数据库零访问,可以减少延迟(Redis 集群或者哨兵模式的支持请参考环境变量说明)。 | ||||
| 7. 如果主服务器访问数据库延迟也比较高,则也需要启用 Redis,并设置 `SYNC_FREQUENCY`,以定期从数据库同步配置。 | ||||
|  | ||||
| 环境变量的具体使用方法详见[此处](#环境变量)。 | ||||
| @@ -251,9 +257,9 @@ docker run --name chatgpt-web -d -p 3002:3002 -e OPENAI_API_BASE_URL=https://ope | ||||
| #### QChatGPT - QQ机器人 | ||||
| 项目主页:https://github.com/RockChinQ/QChatGPT | ||||
|  | ||||
| 根据文档完成部署后,在`config.py`设置配置项`openai_config`的`reverse_proxy`为 One API 后端地址,设置`api_key`为 One API 生成的key,并在配置项`completion_api_params`的`model`参数设置为 One API 支持的模型名称。 | ||||
| 根据[文档](https://qchatgpt.rockchin.top)完成部署后,在 `data/provider.json`设置`requester.openai-chat-completions.base-url`为 One API 实例地址,并填写 API Key 到 `keys.openai` 组中,设置 `model` 为要使用的模型名称。 | ||||
|  | ||||
| 可安装 [Switcher 插件](https://github.com/RockChinQ/Switcher)在运行时切换所使用的模型。 | ||||
| 运行期间可以通过`!model`命令查看、切换可用模型。 | ||||
|  | ||||
| ### 部署到第三方平台 | ||||
| <details> | ||||
| @@ -345,6 +351,11 @@ graph LR | ||||
| 1. `REDIS_CONN_STRING`:设置之后将使用 Redis 作为缓存使用。 | ||||
|    + 例子:`REDIS_CONN_STRING=redis://default:redispw@localhost:49153` | ||||
|    + 如果数据库访问延迟很低,没有必要启用 Redis,启用后反而会出现数据滞后的问题。 | ||||
|    + 如果需要使用哨兵或者集群模式: | ||||
|      + 则需要把该环境变量设置为节点列表,例如:`localhost:49153,localhost:49154,localhost:49155`。 | ||||
|      + 除此之外还需要设置以下环境变量: | ||||
|        + `REDIS_PASSWORD`:Redis 集群或者哨兵模式下的密码设置。 | ||||
|        + `REDIS_MASTER_NAME`:Redis 哨兵模式下主节点的名称。 | ||||
| 2. `SESSION_SECRET`:设置之后将使用固定的会话密钥,这样系统重新启动后已登录用户的 cookie 将依旧有效。 | ||||
|    + 例子:`SESSION_SECRET=random_string` | ||||
| 3. `SQL_DSN`:设置之后将使用指定数据库而非 SQLite,请使用 MySQL 或 PostgreSQL。 | ||||
| @@ -398,6 +409,8 @@ graph LR | ||||
| 26. `METRIC_SUCCESS_RATE_THRESHOLD`:请求成功率阈值,默认为 `0.8`。 | ||||
| 27. `INITIAL_ROOT_TOKEN`:如果设置了该值,则在系统首次启动时会自动创建一个值为该环境变量值的 root 用户令牌。 | ||||
| 28. `INITIAL_ROOT_ACCESS_TOKEN`:如果设置了该值,则在系统首次启动时会自动创建一个值为该环境变量的 root 用户创建系统管理令牌。 | ||||
| 29. `ENFORCE_INCLUDE_USAGE`:是否强制在 stream 模型下返回 usage,默认不开启,可选值为 `true` 和 `false`。 | ||||
| 30. `TEST_PROMPT`:测试模型时的用户 prompt,默认为 `Print your model name exactly and do not output without any other text.`。 | ||||
|  | ||||
| ### 命令行参数 | ||||
| 1. `--port <port_number>`: 指定服务器监听的端口号,默认为 `3000`。 | ||||
|   | ||||
| @@ -1,13 +1,14 @@ | ||||
| package config | ||||
|  | ||||
| import ( | ||||
| 	"github.com/songquanpeng/one-api/common/env" | ||||
| 	"os" | ||||
| 	"strconv" | ||||
| 	"strings" | ||||
| 	"sync" | ||||
| 	"time" | ||||
|  | ||||
| 	"github.com/songquanpeng/one-api/common/env" | ||||
|  | ||||
| 	"github.com/google/uuid" | ||||
| ) | ||||
|  | ||||
| @@ -35,6 +36,7 @@ var PasswordLoginEnabled = true | ||||
| var PasswordRegisterEnabled = true | ||||
| var EmailVerificationEnabled = false | ||||
| var GitHubOAuthEnabled = false | ||||
| var OidcEnabled = false | ||||
| var WeChatAuthEnabled = false | ||||
| var TurnstileCheckEnabled = false | ||||
| var RegisterEnabled = true | ||||
| @@ -70,6 +72,13 @@ var GitHubClientSecret = "" | ||||
| var LarkClientId = "" | ||||
| var LarkClientSecret = "" | ||||
|  | ||||
| var OidcClientId = "" | ||||
| var OidcClientSecret = "" | ||||
| var OidcWellKnown = "" | ||||
| var OidcAuthorizationEndpoint = "" | ||||
| var OidcTokenEndpoint = "" | ||||
| var OidcUserinfoEndpoint = "" | ||||
|  | ||||
| var WeChatServerAddress = "" | ||||
| var WeChatServerToken = "" | ||||
| var WeChatAccountQRCodeImageURL = "" | ||||
| @@ -152,3 +161,6 @@ var OnlyOneLogFile = env.Bool("ONLY_ONE_LOG_FILE", false) | ||||
| var RelayProxy = env.String("RELAY_PROXY", "") | ||||
| var UserContentRequestProxy = env.String("USER_CONTENT_REQUEST_PROXY", "") | ||||
| var UserContentRequestTimeout = env.Int("USER_CONTENT_REQUEST_TIMEOUT", 30) | ||||
|  | ||||
| var EnforceIncludeUsage = env.Bool("ENFORCE_INCLUDE_USAGE", false) | ||||
| var TestPrompt = env.String("TEST_PROMPT", "Print your model name exactly and do not output without any other text.") | ||||
|   | ||||
| @@ -20,4 +20,5 @@ const ( | ||||
| 	BaseURL           = "base_url" | ||||
| 	AvailableModels   = "available_models" | ||||
| 	KeyRequestBody    = "key_request_body" | ||||
| 	SystemPrompt      = "system_prompt" | ||||
| ) | ||||
|   | ||||
| @@ -31,15 +31,15 @@ func UnmarshalBodyReusable(c *gin.Context, v any) error { | ||||
| 	contentType := c.Request.Header.Get("Content-Type") | ||||
| 	if strings.HasPrefix(contentType, "application/json") { | ||||
| 		err = json.Unmarshal(requestBody, &v) | ||||
| 		c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody)) | ||||
| 	} else { | ||||
| 		// skip for now | ||||
| 		// TODO: someday non json request have variant model, we will need to implementation this | ||||
| 		c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody)) | ||||
| 		err = c.ShouldBind(&v) | ||||
| 	} | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 	// Reset request body | ||||
| 	c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody)) | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
|   | ||||
| @@ -1,9 +1,8 @@ | ||||
| package helper | ||||
|  | ||||
| import ( | ||||
| 	"context" | ||||
| 	"fmt" | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"github.com/songquanpeng/one-api/common/random" | ||||
| 	"html/template" | ||||
| 	"log" | ||||
| 	"net" | ||||
| @@ -11,6 +10,10 @@ import ( | ||||
| 	"runtime" | ||||
| 	"strconv" | ||||
| 	"strings" | ||||
|  | ||||
| 	"github.com/gin-gonic/gin" | ||||
|  | ||||
| 	"github.com/songquanpeng/one-api/common/random" | ||||
| ) | ||||
|  | ||||
| func OpenBrowser(url string) { | ||||
| @@ -106,6 +109,18 @@ func GenRequestID() string { | ||||
| 	return GetTimeString() + random.GetRandomNumberString(8) | ||||
| } | ||||
|  | ||||
| func SetRequestID(ctx context.Context, id string) context.Context { | ||||
| 	return context.WithValue(ctx, RequestIdKey, id) | ||||
| } | ||||
|  | ||||
| func GetRequestID(ctx context.Context) string { | ||||
| 	rawRequestId := ctx.Value(RequestIdKey) | ||||
| 	if rawRequestId == nil { | ||||
| 		return "" | ||||
| 	} | ||||
| 	return rawRequestId.(string) | ||||
| } | ||||
|  | ||||
| func GetResponseID(c *gin.Context) string { | ||||
| 	logID := c.GetString(RequestIdKey) | ||||
| 	return fmt.Sprintf("chatcmpl-%s", logID) | ||||
| @@ -137,3 +152,23 @@ func String2Int(str string) int { | ||||
| 	} | ||||
| 	return num | ||||
| } | ||||
|  | ||||
// Float64PtrMax caps the value pointed to by p at maxValue.
// It returns p unchanged when p is nil or *p does not exceed maxValue;
// otherwise it returns a pointer to maxValue.
func Float64PtrMax(p *float64, maxValue float64) *float64 {
	if p == nil || *p <= maxValue {
		return p
	}
	return &maxValue
}
|  | ||||
// Float64PtrMin raises the value pointed to by p to at least minValue.
// It returns p unchanged when p is nil or *p is not below minValue;
// otherwise it returns a pointer to minValue.
func Float64PtrMin(p *float64, minValue float64) *float64 {
	if p == nil || *p >= minValue {
		return p
	}
	return &minValue
}
|   | ||||
| @@ -13,3 +13,8 @@ func GetTimeString() string { | ||||
| 	now := time.Now() | ||||
| 	return fmt.Sprintf("%s%d", now.Format("20060102150405"), now.UnixNano()%1e9) | ||||
| } | ||||
|  | ||||
| // CalcElapsedTime return the elapsed time in milliseconds (ms) | ||||
| func CalcElapsedTime(start time.Time) int64 { | ||||
| 	return time.Now().Sub(start).Milliseconds() | ||||
| } | ||||
|   | ||||
| @@ -7,19 +7,25 @@ import ( | ||||
| 	"log" | ||||
| 	"os" | ||||
| 	"path/filepath" | ||||
| 	"runtime" | ||||
| 	"strings" | ||||
| 	"sync" | ||||
| 	"time" | ||||
|  | ||||
| 	"github.com/gin-gonic/gin" | ||||
|  | ||||
| 	"github.com/songquanpeng/one-api/common/config" | ||||
| 	"github.com/songquanpeng/one-api/common/helper" | ||||
| ) | ||||
|  | ||||
| type loggerLevel string | ||||
|  | ||||
| const ( | ||||
| 	loggerDEBUG = "DEBUG" | ||||
| 	loggerINFO  = "INFO" | ||||
| 	loggerWarn  = "WARN" | ||||
| 	loggerError = "ERR" | ||||
| 	loggerDEBUG loggerLevel = "DEBUG" | ||||
| 	loggerINFO  loggerLevel = "INFO" | ||||
| 	loggerWarn  loggerLevel = "WARN" | ||||
| 	loggerError loggerLevel = "ERROR" | ||||
| 	loggerFatal loggerLevel = "FATAL" | ||||
| ) | ||||
|  | ||||
| var setupLogOnce sync.Once | ||||
| @@ -44,27 +50,26 @@ func SetupLogger() { | ||||
| } | ||||
|  | ||||
| func SysLog(s string) { | ||||
| 	t := time.Now() | ||||
| 	_, _ = fmt.Fprintf(gin.DefaultWriter, "[SYS] %v | %s \n", t.Format("2006/01/02 - 15:04:05"), s) | ||||
| 	logHelper(nil, loggerINFO, s) | ||||
| } | ||||
|  | ||||
| func SysLogf(format string, a ...any) { | ||||
| 	SysLog(fmt.Sprintf(format, a...)) | ||||
| 	logHelper(nil, loggerINFO, fmt.Sprintf(format, a...)) | ||||
| } | ||||
|  | ||||
| func SysError(s string) { | ||||
| 	t := time.Now() | ||||
| 	_, _ = fmt.Fprintf(gin.DefaultErrorWriter, "[SYS] %v | %s \n", t.Format("2006/01/02 - 15:04:05"), s) | ||||
| 	logHelper(nil, loggerError, s) | ||||
| } | ||||
|  | ||||
| func SysErrorf(format string, a ...any) { | ||||
| 	SysError(fmt.Sprintf(format, a...)) | ||||
| 	logHelper(nil, loggerError, fmt.Sprintf(format, a...)) | ||||
| } | ||||
|  | ||||
| func Debug(ctx context.Context, msg string) { | ||||
| 	if config.DebugEnabled { | ||||
| 		logHelper(ctx, loggerDEBUG, msg) | ||||
| 	if !config.DebugEnabled { | ||||
| 		return | ||||
| 	} | ||||
| 	logHelper(ctx, loggerDEBUG, msg) | ||||
| } | ||||
|  | ||||
| func Info(ctx context.Context, msg string) { | ||||
| @@ -80,37 +85,65 @@ func Error(ctx context.Context, msg string) { | ||||
| } | ||||
|  | ||||
| func Debugf(ctx context.Context, format string, a ...any) { | ||||
| 	Debug(ctx, fmt.Sprintf(format, a...)) | ||||
| 	logHelper(ctx, loggerDEBUG, fmt.Sprintf(format, a...)) | ||||
| } | ||||
|  | ||||
| func Infof(ctx context.Context, format string, a ...any) { | ||||
| 	Info(ctx, fmt.Sprintf(format, a...)) | ||||
| 	logHelper(ctx, loggerINFO, fmt.Sprintf(format, a...)) | ||||
| } | ||||
|  | ||||
| func Warnf(ctx context.Context, format string, a ...any) { | ||||
| 	Warn(ctx, fmt.Sprintf(format, a...)) | ||||
| 	logHelper(ctx, loggerWarn, fmt.Sprintf(format, a...)) | ||||
| } | ||||
|  | ||||
| func Errorf(ctx context.Context, format string, a ...any) { | ||||
| 	Error(ctx, fmt.Sprintf(format, a...)) | ||||
| 	logHelper(ctx, loggerError, fmt.Sprintf(format, a...)) | ||||
| } | ||||
|  | ||||
| func logHelper(ctx context.Context, level string, msg string) { | ||||
// FatalLog logs s at FATAL level; logHelper then terminates the process
// with exit code 1.
func FatalLog(s string) {
	logHelper(nil, loggerFatal, s)
}

// FatalLogf logs a printf-formatted message at FATAL level and terminates
// the process with exit code 1.
func FatalLogf(format string, a ...any) {
	logHelper(nil, loggerFatal, fmt.Sprintf(format, a...))
}
|  | ||||
// logHelper writes one formatted log line to gin's default writer (INFO)
// or default error writer (every other level), optionally tagged with the
// request id carried in ctx, plus the caller's file:line and function name.
// On FATAL it exits the process with code 1.
func logHelper(ctx context.Context, level loggerLevel, msg string) {
	writer := gin.DefaultErrorWriter
	if level == loggerINFO {
		writer = gin.DefaultWriter
	}
	// Prefix with " | <request id>" only when one is attached to ctx.
	var requestId string
	if ctx != nil {
		rawRequestId := helper.GetRequestID(ctx)
		if rawRequestId != "" {
			requestId = fmt.Sprintf(" | %s", rawRequestId)
		}
	}
	lineInfo, funcName := getLineInfo()
	now := time.Now()
	_, _ = fmt.Fprintf(writer, "[%s] %v%s%s %s%s \n", level, now.Format("2006/01/02 - 15:04:05"), requestId, lineInfo, funcName, msg)
	// NOTE(review): SetupLogger runs after the write above, so the very first
	// log line may miss the log file — presumably intentional lazy init; verify.
	SetupLogger()
	if level == loggerFatal {
		os.Exit(1)
	}
}
|  | ||||
| func FatalLog(v ...any) { | ||||
| 	t := time.Now() | ||||
| 	_, _ = fmt.Fprintf(gin.DefaultErrorWriter, "[FATAL] %v | %v \n", t.Format("2006/01/02 - 15:04:05"), v) | ||||
| 	os.Exit(1) | ||||
// getLineInfo reports the source position of logHelper's caller as a
// " | file:line" suffix and a "[funcName] " prefix. The file path is
// trimmed to the part after "one-api/" when present; when the caller
// cannot be resolved the placeholders "unknown" / "[unknown] " are used.
func getLineInfo() (string, string) {
	name := "[unknown] "
	pc, file, line, ok := runtime.Caller(3)
	if !ok {
		file, line = "unknown", 0
	} else if fn := runtime.FuncForPC(pc); fn != nil {
		// Keep only the bare function name (text after the last dot).
		segments := strings.Split(fn.Name(), ".")
		name = "[" + segments[len(segments)-1] + "] "
	}
	if segments := strings.Split(file, "one-api/"); len(segments) > 1 {
		file = segments[1]
	}
	return fmt.Sprintf(" | %s:%d", file, line), name
}
|   | ||||
| @@ -2,13 +2,15 @@ package common | ||||
|  | ||||
| import ( | ||||
| 	"context" | ||||
| 	"os" | ||||
| 	"strings" | ||||
| 	"time" | ||||
|  | ||||
| 	"github.com/go-redis/redis/v8" | ||||
| 	"github.com/songquanpeng/one-api/common/logger" | ||||
| 	"os" | ||||
| 	"time" | ||||
| ) | ||||
|  | ||||
| var RDB *redis.Client | ||||
| var RDB redis.Cmdable | ||||
| var RedisEnabled = true | ||||
|  | ||||
| // InitRedisClient This function is called after init() | ||||
| @@ -23,13 +25,23 @@ func InitRedisClient() (err error) { | ||||
| 		logger.SysLog("SYNC_FREQUENCY not set, Redis is disabled") | ||||
| 		return nil | ||||
| 	} | ||||
| 	logger.SysLog("Redis is enabled") | ||||
| 	opt, err := redis.ParseURL(os.Getenv("REDIS_CONN_STRING")) | ||||
| 	if err != nil { | ||||
| 		logger.FatalLog("failed to parse Redis connection string: " + err.Error()) | ||||
| 	redisConnString := os.Getenv("REDIS_CONN_STRING") | ||||
| 	if os.Getenv("REDIS_MASTER_NAME") == "" { | ||||
| 		logger.SysLog("Redis is enabled") | ||||
| 		opt, err := redis.ParseURL(redisConnString) | ||||
| 		if err != nil { | ||||
| 			logger.FatalLog("failed to parse Redis connection string: " + err.Error()) | ||||
| 		} | ||||
| 		RDB = redis.NewClient(opt) | ||||
| 	} else { | ||||
| 		// cluster mode | ||||
| 		logger.SysLog("Redis cluster mode enabled") | ||||
| 		RDB = redis.NewUniversalClient(&redis.UniversalOptions{ | ||||
| 			Addrs:      strings.Split(redisConnString, ","), | ||||
| 			Password:   os.Getenv("REDIS_PASSWORD"), | ||||
| 			MasterName: os.Getenv("REDIS_MASTER_NAME"), | ||||
| 		}) | ||||
| 	} | ||||
| 	RDB = redis.NewClient(opt) | ||||
|  | ||||
| 	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) | ||||
| 	defer cancel() | ||||
|  | ||||
|   | ||||
| @@ -3,9 +3,10 @@ package render | ||||
| import ( | ||||
| 	"encoding/json" | ||||
| 	"fmt" | ||||
| 	"strings" | ||||
|  | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"github.com/songquanpeng/one-api/common" | ||||
| 	"strings" | ||||
| ) | ||||
|  | ||||
| func StringData(c *gin.Context, str string) { | ||||
|   | ||||
| @@ -5,16 +5,18 @@ import ( | ||||
| 	"encoding/json" | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
| 	"net/http" | ||||
| 	"strconv" | ||||
| 	"time" | ||||
|  | ||||
| 	"github.com/gin-contrib/sessions" | ||||
| 	"github.com/gin-gonic/gin" | ||||
|  | ||||
| 	"github.com/songquanpeng/one-api/common/config" | ||||
| 	"github.com/songquanpeng/one-api/common/logger" | ||||
| 	"github.com/songquanpeng/one-api/common/random" | ||||
| 	"github.com/songquanpeng/one-api/controller" | ||||
| 	"github.com/songquanpeng/one-api/model" | ||||
| 	"net/http" | ||||
| 	"strconv" | ||||
| 	"time" | ||||
| ) | ||||
|  | ||||
| type GitHubOAuthResponse struct { | ||||
| @@ -81,6 +83,7 @@ func getGitHubUserInfoByCode(code string) (*GitHubUser, error) { | ||||
| } | ||||
|  | ||||
| func GitHubOAuth(c *gin.Context) { | ||||
| 	ctx := c.Request.Context() | ||||
| 	session := sessions.Default(c) | ||||
| 	state := c.Query("state") | ||||
| 	if state == "" || session.Get("oauth_state") == nil || state != session.Get("oauth_state").(string) { | ||||
| @@ -136,7 +139,7 @@ func GitHubOAuth(c *gin.Context) { | ||||
| 			user.Role = model.RoleCommonUser | ||||
| 			user.Status = model.UserStatusEnabled | ||||
|  | ||||
| 			if err := user.Insert(0); err != nil { | ||||
| 			if err := user.Insert(ctx, 0); err != nil { | ||||
| 				c.JSON(http.StatusOK, gin.H{ | ||||
| 					"success": false, | ||||
| 					"message": err.Error(), | ||||
|   | ||||
| @@ -5,15 +5,17 @@ import ( | ||||
| 	"encoding/json" | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
| 	"net/http" | ||||
| 	"strconv" | ||||
| 	"time" | ||||
|  | ||||
| 	"github.com/gin-contrib/sessions" | ||||
| 	"github.com/gin-gonic/gin" | ||||
|  | ||||
| 	"github.com/songquanpeng/one-api/common/config" | ||||
| 	"github.com/songquanpeng/one-api/common/logger" | ||||
| 	"github.com/songquanpeng/one-api/controller" | ||||
| 	"github.com/songquanpeng/one-api/model" | ||||
| 	"net/http" | ||||
| 	"strconv" | ||||
| 	"time" | ||||
| ) | ||||
|  | ||||
| type LarkOAuthResponse struct { | ||||
| @@ -40,7 +42,7 @@ func getLarkUserInfoByCode(code string) (*LarkUser, error) { | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	req, err := http.NewRequest("POST", "https://passport.feishu.cn/suite/passport/oauth/token", bytes.NewBuffer(jsonData)) | ||||
| 	req, err := http.NewRequest("POST", "https://open.feishu.cn/open-apis/authen/v2/oauth/token", bytes.NewBuffer(jsonData)) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| @@ -79,6 +81,7 @@ func getLarkUserInfoByCode(code string) (*LarkUser, error) { | ||||
| } | ||||
|  | ||||
| func LarkOAuth(c *gin.Context) { | ||||
| 	ctx := c.Request.Context() | ||||
| 	session := sessions.Default(c) | ||||
| 	state := c.Query("state") | ||||
| 	if state == "" || session.Get("oauth_state") == nil || state != session.Get("oauth_state").(string) { | ||||
| @@ -125,7 +128,7 @@ func LarkOAuth(c *gin.Context) { | ||||
| 			user.Role = model.RoleCommonUser | ||||
| 			user.Status = model.UserStatusEnabled | ||||
|  | ||||
| 			if err := user.Insert(0); err != nil { | ||||
| 			if err := user.Insert(ctx, 0); err != nil { | ||||
| 				c.JSON(http.StatusOK, gin.H{ | ||||
| 					"success": false, | ||||
| 					"message": err.Error(), | ||||
|   | ||||
							
								
								
									
										228
									
								
								controller/auth/oidc.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										228
									
								
								controller/auth/oidc.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,228 @@ | ||||
| package auth | ||||
|  | ||||
| import ( | ||||
| 	"bytes" | ||||
| 	"encoding/json" | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
| 	"net/http" | ||||
| 	"strconv" | ||||
| 	"time" | ||||
|  | ||||
| 	"github.com/gin-contrib/sessions" | ||||
| 	"github.com/gin-gonic/gin" | ||||
|  | ||||
| 	"github.com/songquanpeng/one-api/common/config" | ||||
| 	"github.com/songquanpeng/one-api/common/logger" | ||||
| 	"github.com/songquanpeng/one-api/controller" | ||||
| 	"github.com/songquanpeng/one-api/model" | ||||
| ) | ||||
|  | ||||
// OidcResponse models the OIDC token endpoint response: the standard
// RFC 6749 token fields plus the OIDC id_token.
type OidcResponse struct {
	AccessToken  string `json:"access_token"`
	IDToken      string `json:"id_token"`
	RefreshToken string `json:"refresh_token"`
	TokenType    string `json:"token_type"`
	ExpiresIn    int    `json:"expires_in"`
	Scope        string `json:"scope"`
}

// OidcUser holds the userinfo-endpoint claims this server consumes;
// OpenID maps the standard "sub" (subject) claim.
type OidcUser struct {
	OpenID            string `json:"sub"`
	Email             string `json:"email"`
	Name              string `json:"name"`
	PreferredUsername string `json:"preferred_username"`
	Picture           string `json:"picture"`
}
|  | ||||
| func getOidcUserInfoByCode(code string) (*OidcUser, error) { | ||||
| 	if code == "" { | ||||
| 		return nil, errors.New("无效的参数") | ||||
| 	} | ||||
| 	values := map[string]string{ | ||||
| 		"client_id":     config.OidcClientId, | ||||
| 		"client_secret": config.OidcClientSecret, | ||||
| 		"code":          code, | ||||
| 		"grant_type":    "authorization_code", | ||||
| 		"redirect_uri":  fmt.Sprintf("%s/oauth/oidc", config.ServerAddress), | ||||
| 	} | ||||
| 	jsonData, err := json.Marshal(values) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	req, err := http.NewRequest("POST", config.OidcTokenEndpoint, bytes.NewBuffer(jsonData)) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	req.Header.Set("Content-Type", "application/json") | ||||
| 	req.Header.Set("Accept", "application/json") | ||||
| 	client := http.Client{ | ||||
| 		Timeout: 5 * time.Second, | ||||
| 	} | ||||
| 	res, err := client.Do(req) | ||||
| 	if err != nil { | ||||
| 		logger.SysLog(err.Error()) | ||||
| 		return nil, errors.New("无法连接至 OIDC 服务器,请稍后重试!") | ||||
| 	} | ||||
| 	defer res.Body.Close() | ||||
| 	var oidcResponse OidcResponse | ||||
| 	err = json.NewDecoder(res.Body).Decode(&oidcResponse) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	req, err = http.NewRequest("GET", config.OidcUserinfoEndpoint, nil) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	req.Header.Set("Authorization", "Bearer "+oidcResponse.AccessToken) | ||||
| 	res2, err := client.Do(req) | ||||
| 	if err != nil { | ||||
| 		logger.SysLog(err.Error()) | ||||
| 		return nil, errors.New("无法连接至 OIDC 服务器,请稍后重试!") | ||||
| 	} | ||||
| 	var oidcUser OidcUser | ||||
| 	err = json.NewDecoder(res2.Body).Decode(&oidcUser) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	return &oidcUser, nil | ||||
| } | ||||
|  | ||||
// OidcAuth handles the OIDC OAuth callback. It validates the anti-CSRF
// state, then either delegates to OidcBind for an already-logged-in user,
// logs in an existing account matched by OIDC id, or registers a new
// account from the OIDC profile when registration is enabled.
func OidcAuth(c *gin.Context) {
	ctx := c.Request.Context()
	session := sessions.Default(c)
	state := c.Query("state")
	// CSRF guard: the state must echo the value stored in the session.
	if state == "" || session.Get("oauth_state") == nil || state != session.Get("oauth_state").(string) {
		c.JSON(http.StatusForbidden, gin.H{
			"success": false,
			"message": "state is empty or not same",
		})
		return
	}
	// Already logged in: treat this callback as an account-bind request.
	username := session.Get("username")
	if username != nil {
		OidcBind(c)
		return
	}
	if !config.OidcEnabled {
		c.JSON(http.StatusOK, gin.H{
			"success": false,
			"message": "管理员未开启通过 OIDC 登录以及注册",
		})
		return
	}
	code := c.Query("code")
	oidcUser, err := getOidcUserInfoByCode(code)
	if err != nil {
		c.JSON(http.StatusOK, gin.H{
			"success": false,
			"message": err.Error(),
		})
		return
	}
	user := model.User{
		OidcId: oidcUser.OpenID,
	}
	if model.IsOidcIdAlreadyTaken(user.OidcId) {
		// Known OIDC id: load the existing account.
		err := user.FillUserByOidcId()
		if err != nil {
			c.JSON(http.StatusOK, gin.H{
				"success": false,
				"message": err.Error(),
			})
			return
		}
	} else {
		if config.RegisterEnabled {
			// First login: create an account from the OIDC profile,
			// falling back to generated values for missing claims.
			user.Email = oidcUser.Email
			if oidcUser.PreferredUsername != "" {
				user.Username = oidcUser.PreferredUsername
			} else {
				user.Username = "oidc_" + strconv.Itoa(model.GetMaxUserId()+1)
			}
			if oidcUser.Name != "" {
				user.DisplayName = oidcUser.Name
			} else {
				user.DisplayName = "OIDC User"
			}
			err := user.Insert(ctx, 0)
			if err != nil {
				c.JSON(http.StatusOK, gin.H{
					"success": false,
					"message": err.Error(),
				})
				return
			}
		} else {
			c.JSON(http.StatusOK, gin.H{
				"success": false,
				"message": "管理员关闭了新用户注册",
			})
			return
		}
	}

	// Banned accounts may exist in the database but must not log in.
	if user.Status != model.UserStatusEnabled {
		c.JSON(http.StatusOK, gin.H{
			"message": "用户已被封禁",
			"success": false,
		})
		return
	}
	controller.SetupLogin(&user, c)
}
|  | ||||
| func OidcBind(c *gin.Context) { | ||||
| 	if !config.OidcEnabled { | ||||
| 		c.JSON(http.StatusOK, gin.H{ | ||||
| 			"success": false, | ||||
| 			"message": "管理员未开启通过 OIDC 登录以及注册", | ||||
| 		}) | ||||
| 		return | ||||
| 	} | ||||
| 	code := c.Query("code") | ||||
| 	oidcUser, err := getOidcUserInfoByCode(code) | ||||
| 	if err != nil { | ||||
| 		c.JSON(http.StatusOK, gin.H{ | ||||
| 			"success": false, | ||||
| 			"message": err.Error(), | ||||
| 		}) | ||||
| 		return | ||||
| 	} | ||||
| 	user := model.User{ | ||||
| 		OidcId: oidcUser.OpenID, | ||||
| 	} | ||||
| 	if model.IsOidcIdAlreadyTaken(user.OidcId) { | ||||
| 		c.JSON(http.StatusOK, gin.H{ | ||||
| 			"success": false, | ||||
| 			"message": "该 OIDC 账户已被绑定", | ||||
| 		}) | ||||
| 		return | ||||
| 	} | ||||
| 	session := sessions.Default(c) | ||||
| 	id := session.Get("id") | ||||
| 	// id := c.GetInt("id")  // critical bug! | ||||
| 	user.Id = id.(int) | ||||
| 	err = user.FillUserById() | ||||
| 	if err != nil { | ||||
| 		c.JSON(http.StatusOK, gin.H{ | ||||
| 			"success": false, | ||||
| 			"message": err.Error(), | ||||
| 		}) | ||||
| 		return | ||||
| 	} | ||||
| 	user.OidcId = oidcUser.OpenID | ||||
| 	err = user.Update(false) | ||||
| 	if err != nil { | ||||
| 		c.JSON(http.StatusOK, gin.H{ | ||||
| 			"success": false, | ||||
| 			"message": err.Error(), | ||||
| 		}) | ||||
| 		return | ||||
| 	} | ||||
| 	c.JSON(http.StatusOK, gin.H{ | ||||
| 		"success": true, | ||||
| 		"message": "bind", | ||||
| 	}) | ||||
| 	return | ||||
| } | ||||
| @@ -4,14 +4,16 @@ import ( | ||||
| 	"encoding/json" | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
| 	"net/http" | ||||
| 	"strconv" | ||||
| 	"time" | ||||
|  | ||||
| 	"github.com/gin-gonic/gin" | ||||
|  | ||||
| 	"github.com/songquanpeng/one-api/common/config" | ||||
| 	"github.com/songquanpeng/one-api/common/ctxkey" | ||||
| 	"github.com/songquanpeng/one-api/controller" | ||||
| 	"github.com/songquanpeng/one-api/model" | ||||
| 	"net/http" | ||||
| 	"strconv" | ||||
| 	"time" | ||||
| ) | ||||
|  | ||||
| type wechatLoginResponse struct { | ||||
| @@ -52,6 +54,7 @@ func getWeChatIdByCode(code string) (string, error) { | ||||
| } | ||||
|  | ||||
| func WeChatAuth(c *gin.Context) { | ||||
| 	ctx := c.Request.Context() | ||||
| 	if !config.WeChatAuthEnabled { | ||||
| 		c.JSON(http.StatusOK, gin.H{ | ||||
| 			"message": "管理员未开启通过微信登录以及注册", | ||||
| @@ -87,7 +90,7 @@ func WeChatAuth(c *gin.Context) { | ||||
| 			user.Role = model.RoleCommonUser | ||||
| 			user.Status = model.UserStatusEnabled | ||||
|  | ||||
| 			if err := user.Insert(0); err != nil { | ||||
| 			if err := user.Insert(ctx, 0); err != nil { | ||||
| 				c.JSON(http.StatusOK, gin.H{ | ||||
| 					"success": false, | ||||
| 					"message": err.Error(), | ||||
|   | ||||
| @@ -17,9 +17,11 @@ func GetSubscription(c *gin.Context) { | ||||
| 	if config.DisplayTokenStatEnabled { | ||||
| 		tokenId := c.GetInt(ctxkey.TokenId) | ||||
| 		token, err = model.GetTokenById(tokenId) | ||||
| 		expiredTime = token.ExpiredTime | ||||
| 		remainQuota = token.RemainQuota | ||||
| 		usedQuota = token.UsedQuota | ||||
| 		if err == nil { | ||||
| 			expiredTime = token.ExpiredTime | ||||
| 			remainQuota = token.RemainQuota | ||||
| 			usedQuota = token.UsedQuota | ||||
| 		} | ||||
| 	} else { | ||||
| 		userId := c.GetInt(ctxkey.Id) | ||||
| 		remainQuota, err = model.GetUserQuota(userId) | ||||
|   | ||||
| @@ -4,16 +4,17 @@ import ( | ||||
| 	"encoding/json" | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
| 	"io" | ||||
| 	"net/http" | ||||
| 	"strconv" | ||||
| 	"time" | ||||
|  | ||||
| 	"github.com/songquanpeng/one-api/common/client" | ||||
| 	"github.com/songquanpeng/one-api/common/config" | ||||
| 	"github.com/songquanpeng/one-api/common/logger" | ||||
| 	"github.com/songquanpeng/one-api/model" | ||||
| 	"github.com/songquanpeng/one-api/monitor" | ||||
| 	"github.com/songquanpeng/one-api/relay/channeltype" | ||||
| 	"io" | ||||
| 	"net/http" | ||||
| 	"strconv" | ||||
| 	"time" | ||||
|  | ||||
| 	"github.com/gin-gonic/gin" | ||||
| ) | ||||
| @@ -81,6 +82,36 @@ type APGC2DGPTUsageResponse struct { | ||||
| 	TotalUsed      float64 `json:"total_used"` | ||||
| } | ||||
|  | ||||
| type SiliconFlowUsageResponse struct { | ||||
| 	Code    int    `json:"code"` | ||||
| 	Message string `json:"message"` | ||||
| 	Status  bool   `json:"status"` | ||||
| 	Data    struct { | ||||
| 		ID            string `json:"id"` | ||||
| 		Name          string `json:"name"` | ||||
| 		Image         string `json:"image"` | ||||
| 		Email         string `json:"email"` | ||||
| 		IsAdmin       bool   `json:"isAdmin"` | ||||
| 		Balance       string `json:"balance"` | ||||
| 		Status        string `json:"status"` | ||||
| 		Introduction  string `json:"introduction"` | ||||
| 		Role          string `json:"role"` | ||||
| 		ChargeBalance string `json:"chargeBalance"` | ||||
| 		TotalBalance  string `json:"totalBalance"` | ||||
| 		Category      string `json:"category"` | ||||
| 	} `json:"data"` | ||||
| } | ||||
|  | ||||
| type DeepSeekUsageResponse struct { | ||||
| 	IsAvailable  bool `json:"is_available"` | ||||
| 	BalanceInfos []struct { | ||||
| 		Currency        string `json:"currency"` | ||||
| 		TotalBalance    string `json:"total_balance"` | ||||
| 		GrantedBalance  string `json:"granted_balance"` | ||||
| 		ToppedUpBalance string `json:"topped_up_balance"` | ||||
| 	} `json:"balance_infos"` | ||||
| } | ||||
|  | ||||
| // GetAuthHeader get auth header | ||||
| func GetAuthHeader(token string) http.Header { | ||||
| 	h := http.Header{} | ||||
| @@ -203,6 +234,57 @@ func updateChannelAIGC2DBalance(channel *model.Channel) (float64, error) { | ||||
| 	return response.TotalAvailable, nil | ||||
| } | ||||
|  | ||||
| func updateChannelSiliconFlowBalance(channel *model.Channel) (float64, error) { | ||||
| 	url := "https://api.siliconflow.cn/v1/user/info" | ||||
| 	body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key)) | ||||
| 	if err != nil { | ||||
| 		return 0, err | ||||
| 	} | ||||
| 	response := SiliconFlowUsageResponse{} | ||||
| 	err = json.Unmarshal(body, &response) | ||||
| 	if err != nil { | ||||
| 		return 0, err | ||||
| 	} | ||||
| 	if response.Code != 20000 { | ||||
| 		return 0, fmt.Errorf("code: %d, message: %s", response.Code, response.Message) | ||||
| 	} | ||||
| 	balance, err := strconv.ParseFloat(response.Data.TotalBalance, 64) | ||||
| 	if err != nil { | ||||
| 		return 0, err | ||||
| 	} | ||||
| 	channel.UpdateBalance(balance) | ||||
| 	return balance, nil | ||||
| } | ||||
|  | ||||
| func updateChannelDeepSeekBalance(channel *model.Channel) (float64, error) { | ||||
| 	url := "https://api.deepseek.com/user/balance" | ||||
| 	body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key)) | ||||
| 	if err != nil { | ||||
| 		return 0, err | ||||
| 	} | ||||
| 	response := DeepSeekUsageResponse{} | ||||
| 	err = json.Unmarshal(body, &response) | ||||
| 	if err != nil { | ||||
| 		return 0, err | ||||
| 	} | ||||
| 	index := -1 | ||||
| 	for i, balanceInfo := range response.BalanceInfos { | ||||
| 		if balanceInfo.Currency == "CNY" { | ||||
| 			index = i | ||||
| 			break | ||||
| 		} | ||||
| 	} | ||||
| 	if index == -1 { | ||||
| 		return 0, errors.New("currency CNY not found") | ||||
| 	} | ||||
| 	balance, err := strconv.ParseFloat(response.BalanceInfos[index].TotalBalance, 64) | ||||
| 	if err != nil { | ||||
| 		return 0, err | ||||
| 	} | ||||
| 	channel.UpdateBalance(balance) | ||||
| 	return balance, nil | ||||
| } | ||||
|  | ||||
| func updateChannelBalance(channel *model.Channel) (float64, error) { | ||||
| 	baseURL := channeltype.ChannelBaseURLs[channel.Type] | ||||
| 	if channel.GetBaseURL() == "" { | ||||
| @@ -227,6 +309,10 @@ func updateChannelBalance(channel *model.Channel) (float64, error) { | ||||
| 		return updateChannelAPI2GPTBalance(channel) | ||||
| 	case channeltype.AIGC2D: | ||||
| 		return updateChannelAIGC2DBalance(channel) | ||||
| 	case channeltype.SiliconFlow: | ||||
| 		return updateChannelSiliconFlowBalance(channel) | ||||
| 	case channeltype.DeepSeek: | ||||
| 		return updateChannelDeepSeekBalance(channel) | ||||
| 	default: | ||||
| 		return 0, errors.New("尚未实现") | ||||
| 	} | ||||
|   | ||||
| @@ -2,6 +2,7 @@ package controller | ||||
|  | ||||
| import ( | ||||
| 	"bytes" | ||||
| 	"context" | ||||
| 	"encoding/json" | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
| @@ -15,14 +16,17 @@ import ( | ||||
| 	"time" | ||||
|  | ||||
| 	"github.com/gin-gonic/gin" | ||||
|  | ||||
| 	"github.com/songquanpeng/one-api/common/config" | ||||
| 	"github.com/songquanpeng/one-api/common/ctxkey" | ||||
| 	"github.com/songquanpeng/one-api/common/helper" | ||||
| 	"github.com/songquanpeng/one-api/common/logger" | ||||
| 	"github.com/songquanpeng/one-api/common/message" | ||||
| 	"github.com/songquanpeng/one-api/middleware" | ||||
| 	"github.com/songquanpeng/one-api/model" | ||||
| 	"github.com/songquanpeng/one-api/monitor" | ||||
| 	relay "github.com/songquanpeng/one-api/relay" | ||||
| 	"github.com/songquanpeng/one-api/relay" | ||||
| 	"github.com/songquanpeng/one-api/relay/adaptor/openai" | ||||
| 	"github.com/songquanpeng/one-api/relay/channeltype" | ||||
| 	"github.com/songquanpeng/one-api/relay/controller" | ||||
| 	"github.com/songquanpeng/one-api/relay/meta" | ||||
| @@ -35,18 +39,34 @@ func buildTestRequest(model string) *relaymodel.GeneralOpenAIRequest { | ||||
| 		model = "gpt-3.5-turbo" | ||||
| 	} | ||||
| 	testRequest := &relaymodel.GeneralOpenAIRequest{ | ||||
| 		MaxTokens: 2, | ||||
| 		Model:     model, | ||||
| 		Model: model, | ||||
| 	} | ||||
| 	testMessage := relaymodel.Message{ | ||||
| 		Role:    "user", | ||||
| 		Content: "hi", | ||||
| 		Content: config.TestPrompt, | ||||
| 	} | ||||
| 	testRequest.Messages = append(testRequest.Messages, testMessage) | ||||
| 	return testRequest | ||||
| } | ||||
|  | ||||
| func testChannel(channel *model.Channel, request *relaymodel.GeneralOpenAIRequest) (err error, openaiErr *relaymodel.Error) { | ||||
| func parseTestResponse(resp string) (*openai.TextResponse, string, error) { | ||||
| 	var response openai.TextResponse | ||||
| 	err := json.Unmarshal([]byte(resp), &response) | ||||
| 	if err != nil { | ||||
| 		return nil, "", err | ||||
| 	} | ||||
| 	if len(response.Choices) == 0 { | ||||
| 		return nil, "", errors.New("response has no choices") | ||||
| 	} | ||||
| 	stringContent, ok := response.Choices[0].Content.(string) | ||||
| 	if !ok { | ||||
| 		return nil, "", errors.New("response content is not string") | ||||
| 	} | ||||
| 	return &response, stringContent, nil | ||||
| } | ||||
|  | ||||
| func testChannel(ctx context.Context, channel *model.Channel, request *relaymodel.GeneralOpenAIRequest) (responseMessage string, err error, openaiErr *relaymodel.Error) { | ||||
| 	startTime := time.Now() | ||||
| 	w := httptest.NewRecorder() | ||||
| 	c, _ := gin.CreateTestContext(w) | ||||
| 	c.Request = &http.Request{ | ||||
| @@ -66,7 +86,7 @@ func testChannel(channel *model.Channel, request *relaymodel.GeneralOpenAIReques | ||||
| 	apiType := channeltype.ToAPIType(channel.Type) | ||||
| 	adaptor := relay.GetAdaptor(apiType) | ||||
| 	if adaptor == nil { | ||||
| 		return fmt.Errorf("invalid api type: %d, adaptor is nil", apiType), nil | ||||
| 		return "", fmt.Errorf("invalid api type: %d, adaptor is nil", apiType), nil | ||||
| 	} | ||||
| 	adaptor.Init(meta) | ||||
| 	modelName := request.Model | ||||
| @@ -76,49 +96,77 @@ func testChannel(channel *model.Channel, request *relaymodel.GeneralOpenAIReques | ||||
| 		if len(modelNames) > 0 { | ||||
| 			modelName = modelNames[0] | ||||
| 		} | ||||
| 		if modelMap != nil && modelMap[modelName] != "" { | ||||
| 			modelName = modelMap[modelName] | ||||
| 		} | ||||
| 	} | ||||
| 	if modelMap != nil && modelMap[modelName] != "" { | ||||
| 		modelName = modelMap[modelName] | ||||
| 	} | ||||
| 	meta.OriginModelName, meta.ActualModelName = request.Model, modelName | ||||
| 	request.Model = modelName | ||||
| 	convertedRequest, err := adaptor.ConvertRequest(c, relaymode.ChatCompletions, request) | ||||
| 	if err != nil { | ||||
| 		return err, nil | ||||
| 		return "", err, nil | ||||
| 	} | ||||
| 	jsonData, err := json.Marshal(convertedRequest) | ||||
| 	if err != nil { | ||||
| 		return err, nil | ||||
| 		return "", err, nil | ||||
| 	} | ||||
| 	defer func() { | ||||
| 		logContent := fmt.Sprintf("渠道 %s 测试成功,响应:%s", channel.Name, responseMessage) | ||||
| 		if err != nil || openaiErr != nil { | ||||
| 			errorMessage := "" | ||||
| 			if err != nil { | ||||
| 				errorMessage = err.Error() | ||||
| 			} else { | ||||
| 				errorMessage = openaiErr.Message | ||||
| 			} | ||||
| 			logContent = fmt.Sprintf("渠道 %s 测试失败,错误:%s", channel.Name, errorMessage) | ||||
| 		} | ||||
| 		go model.RecordTestLog(ctx, &model.Log{ | ||||
| 			ChannelId:   channel.Id, | ||||
| 			ModelName:   modelName, | ||||
| 			Content:     logContent, | ||||
| 			ElapsedTime: helper.CalcElapsedTime(startTime), | ||||
| 		}) | ||||
| 	}() | ||||
| 	logger.SysLog(string(jsonData)) | ||||
| 	requestBody := bytes.NewBuffer(jsonData) | ||||
| 	c.Request.Body = io.NopCloser(requestBody) | ||||
| 	resp, err := adaptor.DoRequest(c, meta, requestBody) | ||||
| 	if err != nil { | ||||
| 		return err, nil | ||||
| 		return "", err, nil | ||||
| 	} | ||||
| 	if resp != nil && resp.StatusCode != http.StatusOK { | ||||
| 		err := controller.RelayErrorHandler(resp) | ||||
| 		return fmt.Errorf("status code %d: %s", resp.StatusCode, err.Error.Message), &err.Error | ||||
| 		errorMessage := err.Error.Message | ||||
| 		if errorMessage != "" { | ||||
| 			errorMessage = ", error message: " + errorMessage | ||||
| 		} | ||||
| 		return "", fmt.Errorf("http status code: %d%s", resp.StatusCode, errorMessage), &err.Error | ||||
| 	} | ||||
| 	usage, respErr := adaptor.DoResponse(c, resp, meta) | ||||
| 	if respErr != nil { | ||||
| 		return fmt.Errorf("%s", respErr.Error.Message), &respErr.Error | ||||
| 		return "", fmt.Errorf("%s", respErr.Error.Message), &respErr.Error | ||||
| 	} | ||||
| 	if usage == nil { | ||||
| 		return errors.New("usage is nil"), nil | ||||
| 		return "", errors.New("usage is nil"), nil | ||||
| 	} | ||||
| 	rawResponse := w.Body.String() | ||||
| 	_, responseMessage, err = parseTestResponse(rawResponse) | ||||
| 	if err != nil { | ||||
| 		return "", err, nil | ||||
| 	} | ||||
| 	result := w.Result() | ||||
| 	// print result.Body | ||||
| 	respBody, err := io.ReadAll(result.Body) | ||||
| 	if err != nil { | ||||
| 		return err, nil | ||||
| 		return "", err, nil | ||||
| 	} | ||||
| 	logger.SysLog(fmt.Sprintf("testing channel #%d, response: \n%s", channel.Id, string(respBody))) | ||||
| 	return nil, nil | ||||
| 	return responseMessage, nil, nil | ||||
| } | ||||
|  | ||||
| func TestChannel(c *gin.Context) { | ||||
| 	ctx := c.Request.Context() | ||||
| 	id, err := strconv.Atoi(c.Param("id")) | ||||
| 	if err != nil { | ||||
| 		c.JSON(http.StatusOK, gin.H{ | ||||
| @@ -135,10 +183,10 @@ func TestChannel(c *gin.Context) { | ||||
| 		}) | ||||
| 		return | ||||
| 	} | ||||
| 	model := c.Query("model") | ||||
| 	testRequest := buildTestRequest(model) | ||||
| 	modelName := c.Query("model") | ||||
| 	testRequest := buildTestRequest(modelName) | ||||
| 	tik := time.Now() | ||||
| 	err, _ = testChannel(channel, testRequest) | ||||
| 	responseMessage, err, _ := testChannel(ctx, channel, testRequest) | ||||
| 	tok := time.Now() | ||||
| 	milliseconds := tok.Sub(tik).Milliseconds() | ||||
| 	if err != nil { | ||||
| @@ -148,18 +196,18 @@ func TestChannel(c *gin.Context) { | ||||
| 	consumedTime := float64(milliseconds) / 1000.0 | ||||
| 	if err != nil { | ||||
| 		c.JSON(http.StatusOK, gin.H{ | ||||
| 			"success": false, | ||||
| 			"message": err.Error(), | ||||
| 			"time":    consumedTime, | ||||
| 			"model":   model, | ||||
| 			"success":   false, | ||||
| 			"message":   err.Error(), | ||||
| 			"time":      consumedTime, | ||||
| 			"modelName": modelName, | ||||
| 		}) | ||||
| 		return | ||||
| 	} | ||||
| 	c.JSON(http.StatusOK, gin.H{ | ||||
| 		"success": true, | ||||
| 		"message": "", | ||||
| 		"time":    consumedTime, | ||||
| 		"model":   model, | ||||
| 		"success":   true, | ||||
| 		"message":   responseMessage, | ||||
| 		"time":      consumedTime, | ||||
| 		"modelName": modelName, | ||||
| 	}) | ||||
| 	return | ||||
| } | ||||
| @@ -167,7 +215,7 @@ func TestChannel(c *gin.Context) { | ||||
| var testAllChannelsLock sync.Mutex | ||||
| var testAllChannelsRunning bool = false | ||||
|  | ||||
| func testChannels(notify bool, scope string) error { | ||||
| func testChannels(ctx context.Context, notify bool, scope string) error { | ||||
| 	if config.RootUserEmail == "" { | ||||
| 		config.RootUserEmail = model.GetRootUserEmail() | ||||
| 	} | ||||
| @@ -191,7 +239,7 @@ func testChannels(notify bool, scope string) error { | ||||
| 			isChannelEnabled := channel.Status == model.ChannelStatusEnabled | ||||
| 			tik := time.Now() | ||||
| 			testRequest := buildTestRequest("") | ||||
| 			err, openaiErr := testChannel(channel, testRequest) | ||||
| 			_, err, openaiErr := testChannel(ctx, channel, testRequest) | ||||
| 			tok := time.Now() | ||||
| 			milliseconds := tok.Sub(tik).Milliseconds() | ||||
| 			if isChannelEnabled && milliseconds > disableThreshold { | ||||
| @@ -225,11 +273,12 @@ func testChannels(notify bool, scope string) error { | ||||
| } | ||||
|  | ||||
| func TestChannels(c *gin.Context) { | ||||
| 	ctx := c.Request.Context() | ||||
| 	scope := c.Query("scope") | ||||
| 	if scope == "" { | ||||
| 		scope = "all" | ||||
| 	} | ||||
| 	err := testChannels(true, scope) | ||||
| 	err := testChannels(ctx, true, scope) | ||||
| 	if err != nil { | ||||
| 		c.JSON(http.StatusOK, gin.H{ | ||||
| 			"success": false, | ||||
| @@ -245,10 +294,11 @@ func TestChannels(c *gin.Context) { | ||||
| } | ||||
|  | ||||
| func AutomaticallyTestChannels(frequency int) { | ||||
| 	ctx := context.Background() | ||||
| 	for { | ||||
| 		time.Sleep(time.Duration(frequency) * time.Minute) | ||||
| 		logger.SysLog("testing all channels") | ||||
| 		_ = testChannels(false, "all") | ||||
| 		_ = testChannels(ctx, false, "all") | ||||
| 		logger.SysLog("channel test finished") | ||||
| 	} | ||||
| } | ||||
|   | ||||
| @@ -18,24 +18,30 @@ func GetStatus(c *gin.Context) { | ||||
| 		"success": true, | ||||
| 		"message": "", | ||||
| 		"data": gin.H{ | ||||
| 			"version":             common.Version, | ||||
| 			"start_time":          common.StartTime, | ||||
| 			"email_verification":  config.EmailVerificationEnabled, | ||||
| 			"github_oauth":        config.GitHubOAuthEnabled, | ||||
| 			"github_client_id":    config.GitHubClientId, | ||||
| 			"lark_client_id":      config.LarkClientId, | ||||
| 			"system_name":         config.SystemName, | ||||
| 			"logo":                config.Logo, | ||||
| 			"footer_html":         config.Footer, | ||||
| 			"wechat_qrcode":       config.WeChatAccountQRCodeImageURL, | ||||
| 			"wechat_login":        config.WeChatAuthEnabled, | ||||
| 			"server_address":      config.ServerAddress, | ||||
| 			"turnstile_check":     config.TurnstileCheckEnabled, | ||||
| 			"turnstile_site_key":  config.TurnstileSiteKey, | ||||
| 			"top_up_link":         config.TopUpLink, | ||||
| 			"chat_link":           config.ChatLink, | ||||
| 			"quota_per_unit":      config.QuotaPerUnit, | ||||
| 			"display_in_currency": config.DisplayInCurrencyEnabled, | ||||
| 			"version":                     common.Version, | ||||
| 			"start_time":                  common.StartTime, | ||||
| 			"email_verification":          config.EmailVerificationEnabled, | ||||
| 			"github_oauth":                config.GitHubOAuthEnabled, | ||||
| 			"github_client_id":            config.GitHubClientId, | ||||
| 			"lark_client_id":              config.LarkClientId, | ||||
| 			"system_name":                 config.SystemName, | ||||
| 			"logo":                        config.Logo, | ||||
| 			"footer_html":                 config.Footer, | ||||
| 			"wechat_qrcode":               config.WeChatAccountQRCodeImageURL, | ||||
| 			"wechat_login":                config.WeChatAuthEnabled, | ||||
| 			"server_address":              config.ServerAddress, | ||||
| 			"turnstile_check":             config.TurnstileCheckEnabled, | ||||
| 			"turnstile_site_key":          config.TurnstileSiteKey, | ||||
| 			"top_up_link":                 config.TopUpLink, | ||||
| 			"chat_link":                   config.ChatLink, | ||||
| 			"quota_per_unit":              config.QuotaPerUnit, | ||||
| 			"display_in_currency":         config.DisplayInCurrencyEnabled, | ||||
| 			"oidc":                        config.OidcEnabled, | ||||
| 			"oidc_client_id":              config.OidcClientId, | ||||
| 			"oidc_well_known":             config.OidcWellKnown, | ||||
| 			"oidc_authorization_endpoint": config.OidcAuthorizationEndpoint, | ||||
| 			"oidc_token_endpoint":         config.OidcTokenEndpoint, | ||||
| 			"oidc_userinfo_endpoint":      config.OidcUserinfoEndpoint, | ||||
| 		}, | ||||
| 	}) | ||||
| 	return | ||||
|   | ||||
| @@ -60,7 +60,7 @@ func Relay(c *gin.Context) { | ||||
| 	channelName := c.GetString(ctxkey.ChannelName) | ||||
| 	group := c.GetString(ctxkey.Group) | ||||
| 	originalModel := c.GetString(ctxkey.OriginalModel) | ||||
| 	go processChannelRelayError(ctx, userId, channelId, channelName, bizErr) | ||||
| 	go processChannelRelayError(ctx, userId, channelId, channelName, *bizErr) | ||||
| 	requestId := c.GetString(helper.RequestIdKey) | ||||
| 	retryTimes := config.RetryTimes | ||||
| 	if !shouldRetry(c, bizErr.StatusCode) { | ||||
| @@ -87,8 +87,7 @@ func Relay(c *gin.Context) { | ||||
| 		channelId := c.GetInt(ctxkey.ChannelId) | ||||
| 		lastFailedChannelId = channelId | ||||
| 		channelName := c.GetString(ctxkey.ChannelName) | ||||
| 		// BUG: bizErr is in race condition | ||||
| 		go processChannelRelayError(ctx, userId, channelId, channelName, bizErr) | ||||
| 		go processChannelRelayError(ctx, userId, channelId, channelName, *bizErr) | ||||
| 	} | ||||
| 	if bizErr != nil { | ||||
| 		if bizErr.StatusCode == http.StatusTooManyRequests { | ||||
| @@ -122,7 +121,7 @@ func shouldRetry(c *gin.Context, statusCode int) bool { | ||||
| 	return true | ||||
| } | ||||
|  | ||||
| func processChannelRelayError(ctx context.Context, userId int, channelId int, channelName string, err *model.ErrorWithStatusCode) { | ||||
| func processChannelRelayError(ctx context.Context, userId int, channelId int, channelName string, err model.ErrorWithStatusCode) { | ||||
| 	logger.Errorf(ctx, "relay error (channel id %d, user id: %d): %s", channelId, userId, err.Message) | ||||
| 	// https://platform.openai.com/docs/guides/error-codes/api-errors | ||||
| 	if monitor.ShouldDisableChannel(&err.Error, err.StatusCode) { | ||||
|   | ||||
| @@ -109,6 +109,7 @@ func Logout(c *gin.Context) { | ||||
| } | ||||
|  | ||||
| func Register(c *gin.Context) { | ||||
| 	ctx := c.Request.Context() | ||||
| 	if !config.RegisterEnabled { | ||||
| 		c.JSON(http.StatusOK, gin.H{ | ||||
| 			"message": "管理员关闭了新用户注册", | ||||
| @@ -166,7 +167,7 @@ func Register(c *gin.Context) { | ||||
| 	if config.EmailVerificationEnabled { | ||||
| 		cleanUser.Email = user.Email | ||||
| 	} | ||||
| 	if err := cleanUser.Insert(inviterId); err != nil { | ||||
| 	if err := cleanUser.Insert(ctx, inviterId); err != nil { | ||||
| 		c.JSON(http.StatusOK, gin.H{ | ||||
| 			"success": false, | ||||
| 			"message": err.Error(), | ||||
| @@ -362,6 +363,7 @@ func GetSelf(c *gin.Context) { | ||||
| } | ||||
|  | ||||
| func UpdateUser(c *gin.Context) { | ||||
| 	ctx := c.Request.Context() | ||||
| 	var updatedUser model.User | ||||
| 	err := json.NewDecoder(c.Request.Body).Decode(&updatedUser) | ||||
| 	if err != nil || updatedUser.Id == 0 { | ||||
| @@ -416,7 +418,7 @@ func UpdateUser(c *gin.Context) { | ||||
| 		return | ||||
| 	} | ||||
| 	if originUser.Quota != updatedUser.Quota { | ||||
| 		model.RecordLog(originUser.Id, model.LogTypeManage, fmt.Sprintf("管理员将用户额度从 %s修改为 %s", common.LogQuota(originUser.Quota), common.LogQuota(updatedUser.Quota))) | ||||
| 		model.RecordLog(ctx, originUser.Id, model.LogTypeManage, fmt.Sprintf("管理员将用户额度从 %s修改为 %s", common.LogQuota(originUser.Quota), common.LogQuota(updatedUser.Quota))) | ||||
| 	} | ||||
| 	c.JSON(http.StatusOK, gin.H{ | ||||
| 		"success": true, | ||||
| @@ -535,6 +537,7 @@ func DeleteSelf(c *gin.Context) { | ||||
| } | ||||
|  | ||||
| func CreateUser(c *gin.Context) { | ||||
| 	ctx := c.Request.Context() | ||||
| 	var user model.User | ||||
| 	err := json.NewDecoder(c.Request.Body).Decode(&user) | ||||
| 	if err != nil || user.Username == "" || user.Password == "" { | ||||
| @@ -568,7 +571,7 @@ func CreateUser(c *gin.Context) { | ||||
| 		Password:    user.Password, | ||||
| 		DisplayName: user.DisplayName, | ||||
| 	} | ||||
| 	if err := cleanUser.Insert(0); err != nil { | ||||
| 	if err := cleanUser.Insert(ctx, 0); err != nil { | ||||
| 		c.JSON(http.StatusOK, gin.H{ | ||||
| 			"success": false, | ||||
| 			"message": err.Error(), | ||||
| @@ -747,6 +750,7 @@ type topUpRequest struct { | ||||
| } | ||||
|  | ||||
| func TopUp(c *gin.Context) { | ||||
| 	ctx := c.Request.Context() | ||||
| 	req := topUpRequest{} | ||||
| 	err := c.ShouldBindJSON(&req) | ||||
| 	if err != nil { | ||||
| @@ -757,7 +761,7 @@ func TopUp(c *gin.Context) { | ||||
| 		return | ||||
| 	} | ||||
| 	id := c.GetInt("id") | ||||
| 	quota, err := model.Redeem(req.Key, id) | ||||
| 	quota, err := model.Redeem(ctx, req.Key, id) | ||||
| 	if err != nil { | ||||
| 		c.JSON(http.StatusOK, gin.H{ | ||||
| 			"success": false, | ||||
| @@ -780,6 +784,7 @@ type adminTopUpRequest struct { | ||||
| } | ||||
|  | ||||
| func AdminTopUp(c *gin.Context) { | ||||
| 	ctx := c.Request.Context() | ||||
| 	req := adminTopUpRequest{} | ||||
| 	err := c.ShouldBindJSON(&req) | ||||
| 	if err != nil { | ||||
| @@ -800,7 +805,7 @@ func AdminTopUp(c *gin.Context) { | ||||
| 	if req.Remark == "" { | ||||
| 		req.Remark = fmt.Sprintf("通过 API 充值 %s", common.LogQuota(int64(req.Quota))) | ||||
| 	} | ||||
| 	model.RecordTopupLog(req.UserId, req.Remark, req.Quota) | ||||
| 	model.RecordTopupLog(ctx, req.UserId, req.Remark, req.Quota) | ||||
| 	c.JSON(http.StatusOK, gin.H{ | ||||
| 		"success": true, | ||||
| 		"message": "", | ||||
|   | ||||
							
								
								
									
										8
									
								
								go.mod
									
									
									
									
									
								
							
							
						
						
									
										8
									
								
								go.mod
									
									
									
									
									
								
							| @@ -25,7 +25,7 @@ require ( | ||||
| 	github.com/pkoukk/tiktoken-go v0.1.7 | ||||
| 	github.com/smartystreets/goconvey v1.8.1 | ||||
| 	github.com/stretchr/testify v1.9.0 | ||||
| 	golang.org/x/crypto v0.24.0 | ||||
| 	golang.org/x/crypto v0.31.0 | ||||
| 	golang.org/x/image v0.18.0 | ||||
| 	google.golang.org/api v0.187.0 | ||||
| 	gorm.io/driver/mysql v1.5.6 | ||||
| @@ -99,9 +99,9 @@ require ( | ||||
| 	golang.org/x/arch v0.8.0 // indirect | ||||
| 	golang.org/x/net v0.26.0 // indirect | ||||
| 	golang.org/x/oauth2 v0.21.0 // indirect | ||||
| 	golang.org/x/sync v0.7.0 // indirect | ||||
| 	golang.org/x/sys v0.21.0 // indirect | ||||
| 	golang.org/x/text v0.16.0 // indirect | ||||
| 	golang.org/x/sync v0.10.0 // indirect | ||||
| 	golang.org/x/sys v0.28.0 // indirect | ||||
| 	golang.org/x/text v0.21.0 // indirect | ||||
| 	golang.org/x/time v0.5.0 // indirect | ||||
| 	google.golang.org/genproto/googleapis/api v0.0.0-20240617180043-68d350f18fd4 // indirect | ||||
| 	google.golang.org/genproto/googleapis/rpc v0.0.0-20240624140628-dc46fd24d27d // indirect | ||||
|   | ||||
							
								
								
									
										16
									
								
								go.sum
									
									
									
									
									
								
							
							
						
						
									
										16
									
								
								go.sum
									
									
									
									
									
								
							| @@ -222,8 +222,8 @@ golang.org/x/arch v0.8.0 h1:3wRIsP3pM4yUptoR96otTUOXI367OS0+c9eeRi9doIc= | ||||
| golang.org/x/arch v0.8.0/go.mod h1:FEVrYAQjsQXMVJ1nsMoVVXPZg6p2JE2mx8psSWTDQys= | ||||
| golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= | ||||
| golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= | ||||
| golang.org/x/crypto v0.24.0 h1:mnl8DM0o513X8fdIkmyFE/5hTYxbwYOjDS/+rK6qpRI= | ||||
| golang.org/x/crypto v0.24.0/go.mod h1:Z1PMYSOR5nyMcyAVAIQSKCDwalqy85Aqn1x3Ws4L5DM= | ||||
| golang.org/x/crypto v0.31.0 h1:ihbySMvVjLAeSH1IbfcRTkD/iNscyz8rGzjF/E5hV6U= | ||||
| golang.org/x/crypto v0.31.0/go.mod h1:kDsLvtWBEx7MV9tJOj9bnXsPbxwJQ6csT/x4KIN4Ssk= | ||||
| golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= | ||||
| golang.org/x/image v0.18.0 h1:jGzIakQa/ZXI1I0Fxvaa9W7yP25TqT6cHIHn+6CqvSQ= | ||||
| golang.org/x/image v0.18.0/go.mod h1:4yyo5vMFQjVjUcVk4jEQcU9MGy/rulF5WvUILseCM2E= | ||||
| @@ -244,20 +244,20 @@ golang.org/x/oauth2 v0.21.0/go.mod h1:XYTD2NtWslqkgxebSiOHnXEap4TF09sJSc7H1sXbht | ||||
| golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= | ||||
| golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= | ||||
| golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= | ||||
| golang.org/x/sync v0.7.0 h1:YsImfSBoP9QPYL0xyKJPq0gcaJdG3rInoqxTWbfQu9M= | ||||
| golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= | ||||
| golang.org/x/sync v0.10.0 h1:3NQrjDixjgGwUOCaF8w2+VYHv0Ve/vGYSbdkTa98gmQ= | ||||
| golang.org/x/sync v0.10.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= | ||||
| golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= | ||||
| golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= | ||||
| golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= | ||||
| golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= | ||||
| golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= | ||||
| golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= | ||||
| golang.org/x/sys v0.21.0 h1:rF+pYz3DAGSQAxAu1CbC7catZg4ebC4UIeIhKxBZvws= | ||||
| golang.org/x/sys v0.21.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= | ||||
| golang.org/x/sys v0.28.0 h1:Fksou7UEQUWlKvIdsqzJmUmCX3cZuD2+P3XyyzwMhlA= | ||||
| golang.org/x/sys v0.28.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= | ||||
| golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= | ||||
| golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= | ||||
| golang.org/x/text v0.16.0 h1:a94ExnEXNtEwYLGJSIUxnWoxoRz/ZcCsV63ROupILh4= | ||||
| golang.org/x/text v0.16.0/go.mod h1:GhwF1Be+LQoKShO3cGOHzqOgRrGaYc9AvblQOmPVHnI= | ||||
| golang.org/x/text v0.21.0 h1:zyQAAkrwaneQ066sspRyJaG9VNi/YJ1NfzcGB3hZ/qo= | ||||
| golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ= | ||||
| golang.org/x/time v0.5.0 h1:o7cqy6amK/52YcAKIPlM3a+Fpj35zvRj2TP+e1xFSfk= | ||||
| golang.org/x/time v0.5.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM= | ||||
| golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= | ||||
|   | ||||
| @@ -2,21 +2,24 @@ package middleware | ||||
|  | ||||
| import ( | ||||
| 	"fmt" | ||||
| 	"net/http" | ||||
| 	"strconv" | ||||
|  | ||||
| 	"github.com/gin-gonic/gin" | ||||
|  | ||||
| 	"github.com/songquanpeng/one-api/common/ctxkey" | ||||
| 	"github.com/songquanpeng/one-api/common/logger" | ||||
| 	"github.com/songquanpeng/one-api/model" | ||||
| 	"github.com/songquanpeng/one-api/relay/channeltype" | ||||
| 	"net/http" | ||||
| 	"strconv" | ||||
| ) | ||||
|  | ||||
| type ModelRequest struct { | ||||
| 	Model string `json:"model"` | ||||
| 	Model string `json:"model" form:"model"` | ||||
| } | ||||
|  | ||||
| func Distribute() func(c *gin.Context) { | ||||
| 	return func(c *gin.Context) { | ||||
| 		ctx := c.Request.Context() | ||||
| 		userId := c.GetInt(ctxkey.Id) | ||||
| 		userGroup, _ := model.CacheGetUserGroup(userId) | ||||
| 		c.Set(ctxkey.Group, userGroup) | ||||
| @@ -52,6 +55,7 @@ func Distribute() func(c *gin.Context) { | ||||
| 				return | ||||
| 			} | ||||
| 		} | ||||
| 		logger.Debugf(ctx, "user id %d, user group: %s, request model: %s, using channel #%d", userId, userGroup, requestModel, channel.Id) | ||||
| 		SetupContextForSelectedChannel(c, channel, requestModel) | ||||
| 		c.Next() | ||||
| 	} | ||||
| @@ -61,6 +65,9 @@ func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, mode | ||||
| 	c.Set(ctxkey.Channel, channel.Type) | ||||
| 	c.Set(ctxkey.ChannelId, channel.Id) | ||||
| 	c.Set(ctxkey.ChannelName, channel.Name) | ||||
| 	if channel.SystemPrompt != nil && *channel.SystemPrompt != "" { | ||||
| 		c.Set(ctxkey.SystemPrompt, *channel.SystemPrompt) | ||||
| 	} | ||||
| 	c.Set(ctxkey.ModelMapping, channel.GetModelMapping()) | ||||
| 	c.Set(ctxkey.OriginalModel, modelName) // for retry | ||||
| 	c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key)) | ||||
|   | ||||
							
								
								
									
										27
									
								
								middleware/gzip.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										27
									
								
								middleware/gzip.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,27 @@ | ||||
| package middleware | ||||
|  | ||||
| import ( | ||||
| 	"compress/gzip" | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"io" | ||||
| 	"net/http" | ||||
| ) | ||||
|  | ||||
| func GzipDecodeMiddleware() gin.HandlerFunc { | ||||
| 	return func(c *gin.Context) { | ||||
| 		if c.GetHeader("Content-Encoding") == "gzip" { | ||||
| 			gzipReader, err := gzip.NewReader(c.Request.Body) | ||||
| 			if err != nil { | ||||
| 				c.AbortWithStatus(http.StatusBadRequest) | ||||
| 				return | ||||
| 			} | ||||
| 			defer gzipReader.Close() | ||||
|  | ||||
| 			// Replace the request body with the decompressed data | ||||
| 			c.Request.Body = io.NopCloser(gzipReader) | ||||
| 		} | ||||
|  | ||||
| 		// Continue processing the request | ||||
| 		c.Next() | ||||
| 	} | ||||
| } | ||||
| @@ -1,8 +1,8 @@ | ||||
| package middleware | ||||
|  | ||||
| import ( | ||||
| 	"context" | ||||
| 	"github.com/gin-gonic/gin" | ||||
|  | ||||
| 	"github.com/songquanpeng/one-api/common/helper" | ||||
| ) | ||||
|  | ||||
| @@ -10,7 +10,7 @@ func RequestId() func(c *gin.Context) { | ||||
| 	return func(c *gin.Context) { | ||||
| 		id := helper.GenRequestID() | ||||
| 		c.Set(helper.RequestIdKey, id) | ||||
| 		ctx := context.WithValue(c.Request.Context(), helper.RequestIdKey, id) | ||||
| 		ctx := helper.SetRequestID(c.Request.Context(), id) | ||||
| 		c.Request = c.Request.WithContext(ctx) | ||||
| 		c.Header(helper.RequestIdKey, id) | ||||
| 		c.Next() | ||||
|   | ||||
| @@ -37,6 +37,7 @@ type Channel struct { | ||||
| 	ModelMapping       *string `json:"model_mapping" gorm:"type:varchar(1024);default:''"` | ||||
| 	Priority           *int64  `json:"priority" gorm:"bigint;default:0"` | ||||
| 	Config             string  `json:"config"` | ||||
| 	SystemPrompt       *string `json:"system_prompt" gorm:"type:text"` | ||||
| } | ||||
|  | ||||
| type ChannelConfig struct { | ||||
|   | ||||
							
								
								
									
										100
									
								
								model/log.go
									
									
									
									
									
								
							
							
						
						
									
										100
									
								
								model/log.go
									
									
									
									
									
								
							| @@ -3,26 +3,32 @@ package model | ||||
| import ( | ||||
| 	"context" | ||||
| 	"fmt" | ||||
|  | ||||
| 	"gorm.io/gorm" | ||||
|  | ||||
| 	"github.com/songquanpeng/one-api/common" | ||||
| 	"github.com/songquanpeng/one-api/common/config" | ||||
| 	"github.com/songquanpeng/one-api/common/helper" | ||||
| 	"github.com/songquanpeng/one-api/common/logger" | ||||
| 	"gorm.io/gorm" | ||||
| ) | ||||
|  | ||||
| type Log struct { | ||||
| 	Id               int    `json:"id"` | ||||
| 	UserId           int    `json:"user_id" gorm:"index"` | ||||
| 	CreatedAt        int64  `json:"created_at" gorm:"bigint;index:idx_created_at_type"` | ||||
| 	Type             int    `json:"type" gorm:"index:idx_created_at_type"` | ||||
| 	Content          string `json:"content"` | ||||
| 	Username         string `json:"username" gorm:"index:index_username_model_name,priority:2;default:''"` | ||||
| 	TokenName        string `json:"token_name" gorm:"index;default:''"` | ||||
| 	ModelName        string `json:"model_name" gorm:"index;index:index_username_model_name,priority:1;default:''"` | ||||
| 	Quota            int    `json:"quota" gorm:"default:0"` | ||||
| 	PromptTokens     int    `json:"prompt_tokens" gorm:"default:0"` | ||||
| 	CompletionTokens int    `json:"completion_tokens" gorm:"default:0"` | ||||
| 	ChannelId        int    `json:"channel" gorm:"index"` | ||||
| 	Id                int    `json:"id"` | ||||
| 	UserId            int    `json:"user_id" gorm:"index"` | ||||
| 	CreatedAt         int64  `json:"created_at" gorm:"bigint;index:idx_created_at_type"` | ||||
| 	Type              int    `json:"type" gorm:"index:idx_created_at_type"` | ||||
| 	Content           string `json:"content"` | ||||
| 	Username          string `json:"username" gorm:"index:index_username_model_name,priority:2;default:''"` | ||||
| 	TokenName         string `json:"token_name" gorm:"index;default:''"` | ||||
| 	ModelName         string `json:"model_name" gorm:"index;index:index_username_model_name,priority:1;default:''"` | ||||
| 	Quota             int    `json:"quota" gorm:"default:0"` | ||||
| 	PromptTokens      int    `json:"prompt_tokens" gorm:"default:0"` | ||||
| 	CompletionTokens  int    `json:"completion_tokens" gorm:"default:0"` | ||||
| 	ChannelId         int    `json:"channel" gorm:"index"` | ||||
| 	RequestId         string `json:"request_id" gorm:"default:''"` | ||||
| 	ElapsedTime       int64  `json:"elapsed_time" gorm:"default:0"` // unit is ms | ||||
| 	IsStream          bool   `json:"is_stream" gorm:"default:false"` | ||||
| 	SystemPromptReset bool   `json:"system_prompt_reset" gorm:"default:false"` | ||||
| } | ||||
|  | ||||
| const ( | ||||
| @@ -31,9 +37,21 @@ const ( | ||||
| 	LogTypeConsume | ||||
| 	LogTypeManage | ||||
| 	LogTypeSystem | ||||
| 	LogTypeTest | ||||
| ) | ||||
|  | ||||
| func RecordLog(userId int, logType int, content string) { | ||||
| func recordLogHelper(ctx context.Context, log *Log) { | ||||
| 	requestId := helper.GetRequestID(ctx) | ||||
| 	log.RequestId = requestId | ||||
| 	err := LOG_DB.Create(log).Error | ||||
| 	if err != nil { | ||||
| 		logger.Error(ctx, "failed to record log: "+err.Error()) | ||||
| 		return | ||||
| 	} | ||||
| 	logger.Infof(ctx, "record log: %+v", log) | ||||
| } | ||||
|  | ||||
| func RecordLog(ctx context.Context, userId int, logType int, content string) { | ||||
| 	if logType == LogTypeConsume && !config.LogConsumeEnabled { | ||||
| 		return | ||||
| 	} | ||||
| @@ -44,13 +62,10 @@ func RecordLog(userId int, logType int, content string) { | ||||
| 		Type:      logType, | ||||
| 		Content:   content, | ||||
| 	} | ||||
| 	err := LOG_DB.Create(log).Error | ||||
| 	if err != nil { | ||||
| 		logger.SysError("failed to record log: " + err.Error()) | ||||
| 	} | ||||
| 	recordLogHelper(ctx, log) | ||||
| } | ||||
|  | ||||
| func RecordTopupLog(userId int, content string, quota int) { | ||||
| func RecordTopupLog(ctx context.Context, userId int, content string, quota int) { | ||||
| 	log := &Log{ | ||||
| 		UserId:    userId, | ||||
| 		Username:  GetUsernameById(userId), | ||||
| @@ -59,34 +74,23 @@ func RecordTopupLog(userId int, content string, quota int) { | ||||
| 		Content:   content, | ||||
| 		Quota:     quota, | ||||
| 	} | ||||
| 	err := LOG_DB.Create(log).Error | ||||
| 	if err != nil { | ||||
| 		logger.SysError("failed to record log: " + err.Error()) | ||||
| 	} | ||||
| 	recordLogHelper(ctx, log) | ||||
| } | ||||
|  | ||||
| func RecordConsumeLog(ctx context.Context, userId int, channelId int, promptTokens int, completionTokens int, modelName string, tokenName string, quota int64, content string) { | ||||
| 	logger.Info(ctx, fmt.Sprintf("record consume log: userId=%d, channelId=%d, promptTokens=%d, completionTokens=%d, modelName=%s, tokenName=%s, quota=%d, content=%s", userId, channelId, promptTokens, completionTokens, modelName, tokenName, quota, content)) | ||||
| func RecordConsumeLog(ctx context.Context, log *Log) { | ||||
| 	if !config.LogConsumeEnabled { | ||||
| 		return | ||||
| 	} | ||||
| 	log := &Log{ | ||||
| 		UserId:           userId, | ||||
| 		Username:         GetUsernameById(userId), | ||||
| 		CreatedAt:        helper.GetTimestamp(), | ||||
| 		Type:             LogTypeConsume, | ||||
| 		Content:          content, | ||||
| 		PromptTokens:     promptTokens, | ||||
| 		CompletionTokens: completionTokens, | ||||
| 		TokenName:        tokenName, | ||||
| 		ModelName:        modelName, | ||||
| 		Quota:            int(quota), | ||||
| 		ChannelId:        channelId, | ||||
| 	} | ||||
| 	err := LOG_DB.Create(log).Error | ||||
| 	if err != nil { | ||||
| 		logger.Error(ctx, "failed to record log: "+err.Error()) | ||||
| 	} | ||||
| 	log.Username = GetUsernameById(log.UserId) | ||||
| 	log.CreatedAt = helper.GetTimestamp() | ||||
| 	log.Type = LogTypeConsume | ||||
| 	recordLogHelper(ctx, log) | ||||
| } | ||||
|  | ||||
| func RecordTestLog(ctx context.Context, log *Log) { | ||||
| 	log.CreatedAt = helper.GetTimestamp() | ||||
| 	log.Type = LogTypeTest | ||||
| 	recordLogHelper(ctx, log) | ||||
| } | ||||
|  | ||||
| func GetAllLogs(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string, startIdx int, num int, channel int) (logs []*Log, err error) { | ||||
| @@ -152,7 +156,11 @@ func SearchUserLogs(userId int, keyword string) (logs []*Log, err error) { | ||||
| } | ||||
|  | ||||
| func SumUsedQuota(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string, channel int) (quota int64) { | ||||
| 	tx := LOG_DB.Table("logs").Select("ifnull(sum(quota),0)") | ||||
| 	ifnull := "ifnull" | ||||
| 	if common.UsingPostgreSQL { | ||||
| 		ifnull = "COALESCE" | ||||
| 	} | ||||
| 	tx := LOG_DB.Table("logs").Select(fmt.Sprintf("%s(sum(quota),0)", ifnull)) | ||||
| 	if username != "" { | ||||
| 		tx = tx.Where("username = ?", username) | ||||
| 	} | ||||
| @@ -176,7 +184,11 @@ func SumUsedQuota(logType int, startTimestamp int64, endTimestamp int64, modelNa | ||||
| } | ||||
|  | ||||
| func SumUsedToken(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string) (token int) { | ||||
| 	tx := LOG_DB.Table("logs").Select("ifnull(sum(prompt_tokens),0) + ifnull(sum(completion_tokens),0)") | ||||
| 	ifnull := "ifnull" | ||||
| 	if common.UsingPostgreSQL { | ||||
| 		ifnull = "COALESCE" | ||||
| 	} | ||||
| 	tx := LOG_DB.Table("logs").Select(fmt.Sprintf("%s(sum(prompt_tokens),0) + %s(sum(completion_tokens),0)", ifnull, ifnull)) | ||||
| 	if username != "" { | ||||
| 		tx = tx.Where("username = ?", username) | ||||
| 	} | ||||
|   | ||||
| @@ -28,6 +28,7 @@ func InitOptionMap() { | ||||
| 	config.OptionMap["PasswordRegisterEnabled"] = strconv.FormatBool(config.PasswordRegisterEnabled) | ||||
| 	config.OptionMap["EmailVerificationEnabled"] = strconv.FormatBool(config.EmailVerificationEnabled) | ||||
| 	config.OptionMap["GitHubOAuthEnabled"] = strconv.FormatBool(config.GitHubOAuthEnabled) | ||||
| 	config.OptionMap["OidcEnabled"] = strconv.FormatBool(config.OidcEnabled) | ||||
| 	config.OptionMap["WeChatAuthEnabled"] = strconv.FormatBool(config.WeChatAuthEnabled) | ||||
| 	config.OptionMap["TurnstileCheckEnabled"] = strconv.FormatBool(config.TurnstileCheckEnabled) | ||||
| 	config.OptionMap["RegisterEnabled"] = strconv.FormatBool(config.RegisterEnabled) | ||||
| @@ -130,6 +131,8 @@ func updateOptionMap(key string, value string) (err error) { | ||||
| 			config.EmailVerificationEnabled = boolValue | ||||
| 		case "GitHubOAuthEnabled": | ||||
| 			config.GitHubOAuthEnabled = boolValue | ||||
| 		case "OidcEnabled": | ||||
| 			config.OidcEnabled = boolValue | ||||
| 		case "WeChatAuthEnabled": | ||||
| 			config.WeChatAuthEnabled = boolValue | ||||
| 		case "TurnstileCheckEnabled": | ||||
| @@ -176,6 +179,18 @@ func updateOptionMap(key string, value string) (err error) { | ||||
| 		config.LarkClientId = value | ||||
| 	case "LarkClientSecret": | ||||
| 		config.LarkClientSecret = value | ||||
| 	case "OidcClientId": | ||||
| 		config.OidcClientId = value | ||||
| 	case "OidcClientSecret": | ||||
| 		config.OidcClientSecret = value | ||||
| 	case "OidcWellKnown": | ||||
| 		config.OidcWellKnown = value | ||||
| 	case "OidcAuthorizationEndpoint": | ||||
| 		config.OidcAuthorizationEndpoint = value | ||||
| 	case "OidcTokenEndpoint": | ||||
| 		config.OidcTokenEndpoint = value | ||||
| 	case "OidcUserinfoEndpoint": | ||||
| 		config.OidcUserinfoEndpoint = value | ||||
| 	case "Footer": | ||||
| 		config.Footer = value | ||||
| 	case "SystemName": | ||||
|   | ||||
| @@ -1,11 +1,14 @@ | ||||
| package model | ||||
|  | ||||
| import ( | ||||
| 	"context" | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
|  | ||||
| 	"gorm.io/gorm" | ||||
|  | ||||
| 	"github.com/songquanpeng/one-api/common" | ||||
| 	"github.com/songquanpeng/one-api/common/helper" | ||||
| 	"gorm.io/gorm" | ||||
| ) | ||||
|  | ||||
| const ( | ||||
| @@ -48,7 +51,7 @@ func GetRedemptionById(id int) (*Redemption, error) { | ||||
| 	return &redemption, err | ||||
| } | ||||
|  | ||||
| func Redeem(key string, userId int) (quota int64, err error) { | ||||
| func Redeem(ctx context.Context, key string, userId int) (quota int64, err error) { | ||||
| 	if key == "" { | ||||
| 		return 0, errors.New("未提供兑换码") | ||||
| 	} | ||||
| @@ -82,7 +85,7 @@ func Redeem(key string, userId int) (quota int64, err error) { | ||||
| 	if err != nil { | ||||
| 		return 0, errors.New("兑换失败," + err.Error()) | ||||
| 	} | ||||
| 	RecordLog(userId, LogTypeTopup, fmt.Sprintf("通过兑换码充值 %s", common.LogQuota(redemption.Quota))) | ||||
| 	RecordLog(ctx, userId, LogTypeTopup, fmt.Sprintf("通过兑换码充值 %s", common.LogQuota(redemption.Quota))) | ||||
| 	return redemption.Quota, nil | ||||
| } | ||||
|  | ||||
|   | ||||
| @@ -30,7 +30,7 @@ type Token struct { | ||||
| 	RemainQuota    int64   `json:"remain_quota" gorm:"bigint;default:0"` | ||||
| 	UnlimitedQuota bool    `json:"unlimited_quota" gorm:"default:false"` | ||||
| 	UsedQuota      int64   `json:"used_quota" gorm:"bigint;default:0"` // used quota | ||||
| 	Models         *string `json:"models" gorm:"default:''"`           // allowed models | ||||
| 	Models         *string `json:"models" gorm:"type:text"`            // allowed models | ||||
| 	Subnet         *string `json:"subnet" gorm:"default:''"`           // allowed subnet | ||||
| } | ||||
|  | ||||
| @@ -121,30 +121,40 @@ func GetTokenById(id int) (*Token, error) { | ||||
| 	return &token, err | ||||
| } | ||||
|  | ||||
| func (token *Token) Insert() error { | ||||
| func (t *Token) Insert() error { | ||||
| 	var err error | ||||
| 	err = DB.Create(token).Error | ||||
| 	err = DB.Create(t).Error | ||||
| 	return err | ||||
| } | ||||
|  | ||||
| // Update Make sure your token's fields is completed, because this will update non-zero values | ||||
| func (token *Token) Update() error { | ||||
| func (t *Token) Update() error { | ||||
| 	var err error | ||||
| 	err = DB.Model(token).Select("name", "status", "expired_time", "remain_quota", "unlimited_quota", "models", "subnet").Updates(token).Error | ||||
| 	err = DB.Model(t).Select("name", "status", "expired_time", "remain_quota", "unlimited_quota", "models", "subnet").Updates(t).Error | ||||
| 	return err | ||||
| } | ||||
|  | ||||
| func (token *Token) SelectUpdate() error { | ||||
| func (t *Token) SelectUpdate() error { | ||||
| 	// This can update zero values | ||||
| 	return DB.Model(token).Select("accessed_time", "status").Updates(token).Error | ||||
| 	return DB.Model(t).Select("accessed_time", "status").Updates(t).Error | ||||
| } | ||||
|  | ||||
| func (token *Token) Delete() error { | ||||
| func (t *Token) Delete() error { | ||||
| 	var err error | ||||
| 	err = DB.Delete(token).Error | ||||
| 	err = DB.Delete(t).Error | ||||
| 	return err | ||||
| } | ||||
|  | ||||
| func (t *Token) GetModels() string { | ||||
| 	if t == nil { | ||||
| 		return "" | ||||
| 	} | ||||
| 	if t.Models == nil { | ||||
| 		return "" | ||||
| 	} | ||||
| 	return *t.Models | ||||
| } | ||||
|  | ||||
| func DeleteTokenById(id int, userId int) (err error) { | ||||
| 	// Why we need userId here? In case user want to delete other's token. | ||||
| 	if id == 0 || userId == 0 { | ||||
| @@ -254,14 +264,14 @@ func PreConsumeTokenQuota(tokenId int, quota int64) (err error) { | ||||
|  | ||||
| func PostConsumeTokenQuota(tokenId int, quota int64) (err error) { | ||||
| 	token, err := GetTokenById(tokenId) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 	if quota > 0 { | ||||
| 		err = DecreaseUserQuota(token.UserId, quota) | ||||
| 	} else { | ||||
| 		err = IncreaseUserQuota(token.UserId, -quota) | ||||
| 	} | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 	if !token.UnlimitedQuota { | ||||
| 		if quota > 0 { | ||||
| 			err = DecreaseTokenQuota(tokenId, quota) | ||||
|   | ||||
| @@ -1,16 +1,19 @@ | ||||
| package model | ||||
|  | ||||
| import ( | ||||
| 	"context" | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
| 	"strings" | ||||
|  | ||||
| 	"gorm.io/gorm" | ||||
|  | ||||
| 	"github.com/songquanpeng/one-api/common" | ||||
| 	"github.com/songquanpeng/one-api/common/blacklist" | ||||
| 	"github.com/songquanpeng/one-api/common/config" | ||||
| 	"github.com/songquanpeng/one-api/common/helper" | ||||
| 	"github.com/songquanpeng/one-api/common/logger" | ||||
| 	"github.com/songquanpeng/one-api/common/random" | ||||
| 	"gorm.io/gorm" | ||||
| 	"strings" | ||||
| ) | ||||
|  | ||||
| const ( | ||||
| @@ -39,6 +42,7 @@ type User struct { | ||||
| 	GitHubId         string `json:"github_id" gorm:"column:github_id;index"` | ||||
| 	WeChatId         string `json:"wechat_id" gorm:"column:wechat_id;index"` | ||||
| 	LarkId           string `json:"lark_id" gorm:"column:lark_id;index"` | ||||
| 	OidcId           string `json:"oidc_id" gorm:"column:oidc_id;index"` | ||||
| 	VerificationCode string `json:"verification_code" gorm:"-:all"`                                    // this field is only for Email verification, don't save it to database! | ||||
| 	AccessToken      string `json:"access_token" gorm:"type:char(32);column:access_token;uniqueIndex"` // this token is for system management | ||||
| 	Quota            int64  `json:"quota" gorm:"bigint;default:0"` | ||||
| @@ -91,7 +95,7 @@ func GetUserById(id int, selectAll bool) (*User, error) { | ||||
| 	if selectAll { | ||||
| 		err = DB.First(&user, "id = ?", id).Error | ||||
| 	} else { | ||||
| 		err = DB.Omit("password").First(&user, "id = ?", id).Error | ||||
| 		err = DB.Omit("password", "access_token").First(&user, "id = ?", id).Error | ||||
| 	} | ||||
| 	return &user, err | ||||
| } | ||||
| @@ -113,7 +117,7 @@ func DeleteUserById(id int) (err error) { | ||||
| 	return user.Delete() | ||||
| } | ||||
|  | ||||
| func (user *User) Insert(inviterId int) error { | ||||
| func (user *User) Insert(ctx context.Context, inviterId int) error { | ||||
| 	var err error | ||||
| 	if user.Password != "" { | ||||
| 		user.Password, err = common.Password2Hash(user.Password) | ||||
| @@ -129,16 +133,16 @@ func (user *User) Insert(inviterId int) error { | ||||
| 		return result.Error | ||||
| 	} | ||||
| 	if config.QuotaForNewUser > 0 { | ||||
| 		RecordLog(user.Id, LogTypeSystem, fmt.Sprintf("新用户注册赠送 %s", common.LogQuota(config.QuotaForNewUser))) | ||||
| 		RecordLog(ctx, user.Id, LogTypeSystem, fmt.Sprintf("新用户注册赠送 %s", common.LogQuota(config.QuotaForNewUser))) | ||||
| 	} | ||||
| 	if inviterId != 0 { | ||||
| 		if config.QuotaForInvitee > 0 { | ||||
| 			_ = IncreaseUserQuota(user.Id, config.QuotaForInvitee) | ||||
| 			RecordLog(user.Id, LogTypeSystem, fmt.Sprintf("使用邀请码赠送 %s", common.LogQuota(config.QuotaForInvitee))) | ||||
| 			RecordLog(ctx, user.Id, LogTypeSystem, fmt.Sprintf("使用邀请码赠送 %s", common.LogQuota(config.QuotaForInvitee))) | ||||
| 		} | ||||
| 		if config.QuotaForInviter > 0 { | ||||
| 			_ = IncreaseUserQuota(inviterId, config.QuotaForInviter) | ||||
| 			RecordLog(inviterId, LogTypeSystem, fmt.Sprintf("邀请用户赠送 %s", common.LogQuota(config.QuotaForInviter))) | ||||
| 			RecordLog(ctx, inviterId, LogTypeSystem, fmt.Sprintf("邀请用户赠送 %s", common.LogQuota(config.QuotaForInviter))) | ||||
| 		} | ||||
| 	} | ||||
| 	// create default token | ||||
| @@ -245,6 +249,14 @@ func (user *User) FillUserByLarkId() error { | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func (user *User) FillUserByOidcId() error { | ||||
| 	if user.OidcId == "" { | ||||
| 		return errors.New("oidc id 为空!") | ||||
| 	} | ||||
| 	DB.Where(User{OidcId: user.OidcId}).First(user) | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func (user *User) FillUserByWeChatId() error { | ||||
| 	if user.WeChatId == "" { | ||||
| 		return errors.New("WeChat id 为空!") | ||||
| @@ -277,6 +289,10 @@ func IsLarkIdAlreadyTaken(githubId string) bool { | ||||
| 	return DB.Where("lark_id = ?", githubId).Find(&User{}).RowsAffected == 1 | ||||
| } | ||||
|  | ||||
| func IsOidcIdAlreadyTaken(oidcId string) bool { | ||||
| 	return DB.Where("oidc_id = ?", oidcId).Find(&User{}).RowsAffected == 1 | ||||
| } | ||||
|  | ||||
| func IsUsernameAlreadyTaken(username string) bool { | ||||
| 	return DB.Where("username = ?", username).Find(&User{}).RowsAffected == 1 | ||||
| } | ||||
|   | ||||
| @@ -1,10 +1,11 @@ | ||||
| package monitor | ||||
|  | ||||
| import ( | ||||
| 	"github.com/songquanpeng/one-api/common/config" | ||||
| 	"github.com/songquanpeng/one-api/relay/model" | ||||
| 	"net/http" | ||||
| 	"strings" | ||||
|  | ||||
| 	"github.com/songquanpeng/one-api/common/config" | ||||
| 	"github.com/songquanpeng/one-api/relay/model" | ||||
| ) | ||||
|  | ||||
| func ShouldDisableChannel(err *model.Error, statusCode int) bool { | ||||
| @@ -18,31 +19,23 @@ func ShouldDisableChannel(err *model.Error, statusCode int) bool { | ||||
| 		return true | ||||
| 	} | ||||
| 	switch err.Type { | ||||
| 	case "insufficient_quota": | ||||
| 		return true | ||||
| 	// https://docs.anthropic.com/claude/reference/errors | ||||
| 	case "authentication_error": | ||||
| 		return true | ||||
| 	case "permission_error": | ||||
| 		return true | ||||
| 	case "forbidden": | ||||
| 	case "insufficient_quota", "authentication_error", "permission_error", "forbidden": | ||||
| 		return true | ||||
| 	} | ||||
| 	if err.Code == "invalid_api_key" || err.Code == "account_deactivated" { | ||||
| 		return true | ||||
| 	} | ||||
| 	if strings.HasPrefix(err.Message, "Your credit balance is too low") { // anthropic | ||||
| 		return true | ||||
| 	} else if strings.HasPrefix(err.Message, "This organization has been disabled.") { | ||||
| 		return true | ||||
| 	} | ||||
| 	//if strings.Contains(err.Message, "quota") { | ||||
| 	//	return true | ||||
| 	//} | ||||
| 	if strings.Contains(err.Message, "credit") { | ||||
| 		return true | ||||
| 	} | ||||
| 	if strings.Contains(err.Message, "balance") { | ||||
|  | ||||
| 	lowerMessage := strings.ToLower(err.Message) | ||||
| 	if strings.Contains(lowerMessage, "your access was terminated") || | ||||
| 		strings.Contains(lowerMessage, "violation of our policies") || | ||||
| 		strings.Contains(lowerMessage, "your credit balance is too low") || | ||||
| 		strings.Contains(lowerMessage, "organization has been disabled") || | ||||
| 		strings.Contains(lowerMessage, "credit") || | ||||
| 		strings.Contains(lowerMessage, "balance") || | ||||
| 		strings.Contains(lowerMessage, "permission denied") || | ||||
| 		strings.Contains(lowerMessage, "organization has been restricted") || // groq | ||||
| 		strings.Contains(lowerMessage, "已欠费") { | ||||
| 		return true | ||||
| 	} | ||||
| 	return false | ||||
|   | ||||
| @@ -16,6 +16,7 @@ import ( | ||||
| 	"github.com/songquanpeng/one-api/relay/adaptor/openai" | ||||
| 	"github.com/songquanpeng/one-api/relay/adaptor/palm" | ||||
| 	"github.com/songquanpeng/one-api/relay/adaptor/proxy" | ||||
| 	"github.com/songquanpeng/one-api/relay/adaptor/replicate" | ||||
| 	"github.com/songquanpeng/one-api/relay/adaptor/tencent" | ||||
| 	"github.com/songquanpeng/one-api/relay/adaptor/vertexai" | ||||
| 	"github.com/songquanpeng/one-api/relay/adaptor/xunfei" | ||||
| @@ -61,6 +62,8 @@ func GetAdaptor(apiType int) adaptor.Adaptor { | ||||
| 		return &vertexai.Adaptor{} | ||||
| 	case apitype.Proxy: | ||||
| 		return &proxy.Adaptor{} | ||||
| 	case apitype.Replicate: | ||||
| 		return &replicate.Adaptor{} | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
|   | ||||
| @@ -1,7 +1,23 @@ | ||||
| package ali | ||||
|  | ||||
| var ModelList = []string{ | ||||
| 	"qwen-turbo", "qwen-plus", "qwen-max", "qwen-max-longcontext", | ||||
| 	"text-embedding-v1", | ||||
| 	"qwen-turbo", "qwen-turbo-latest", | ||||
| 	"qwen-plus", "qwen-plus-latest", | ||||
| 	"qwen-max", "qwen-max-latest", | ||||
| 	"qwen-max-longcontext", | ||||
| 	"qwen-vl-max", "qwen-vl-max-latest", "qwen-vl-plus", "qwen-vl-plus-latest", | ||||
| 	"qwen-vl-ocr", "qwen-vl-ocr-latest", | ||||
| 	"qwen-audio-turbo", | ||||
| 	"qwen-math-plus", "qwen-math-plus-latest", "qwen-math-turbo", "qwen-math-turbo-latest", | ||||
| 	"qwen-coder-plus", "qwen-coder-plus-latest", "qwen-coder-turbo", "qwen-coder-turbo-latest", | ||||
| 	"qwq-32b-preview", "qwen2.5-72b-instruct", "qwen2.5-32b-instruct", "qwen2.5-14b-instruct", "qwen2.5-7b-instruct", "qwen2.5-3b-instruct", "qwen2.5-1.5b-instruct", "qwen2.5-0.5b-instruct", | ||||
| 	"qwen2-72b-instruct", "qwen2-57b-a14b-instruct", "qwen2-7b-instruct", "qwen2-1.5b-instruct", "qwen2-0.5b-instruct", | ||||
| 	"qwen1.5-110b-chat", "qwen1.5-72b-chat", "qwen1.5-32b-chat", "qwen1.5-14b-chat", "qwen1.5-7b-chat", "qwen1.5-1.8b-chat", "qwen1.5-0.5b-chat", | ||||
| 	"qwen-72b-chat", "qwen-14b-chat", "qwen-7b-chat", "qwen-1.8b-chat", "qwen-1.8b-longcontext-chat", | ||||
| 	"qwen2-vl-7b-instruct", "qwen2-vl-2b-instruct", "qwen-vl-v1", "qwen-vl-chat-v1", | ||||
| 	"qwen2-audio-instruct", "qwen-audio-chat", | ||||
| 	"qwen2.5-math-72b-instruct", "qwen2.5-math-7b-instruct", "qwen2.5-math-1.5b-instruct", "qwen2-math-72b-instruct", "qwen2-math-7b-instruct", "qwen2-math-1.5b-instruct", | ||||
| 	"qwen2.5-coder-32b-instruct", "qwen2.5-coder-14b-instruct", "qwen2.5-coder-7b-instruct", "qwen2.5-coder-3b-instruct", "qwen2.5-coder-1.5b-instruct", "qwen2.5-coder-0.5b-instruct", | ||||
| 	"text-embedding-v1", "text-embedding-v3", "text-embedding-v2", "text-embedding-async-v2", "text-embedding-async-v1", | ||||
| 	"ali-stable-diffusion-xl", "ali-stable-diffusion-v1.5", "wanx-v1", | ||||
| } | ||||
|   | ||||
| @@ -3,6 +3,7 @@ package ali | ||||
| import ( | ||||
| 	"bufio" | ||||
| 	"encoding/json" | ||||
| 	"github.com/songquanpeng/one-api/common/ctxkey" | ||||
| 	"github.com/songquanpeng/one-api/common/render" | ||||
| 	"io" | ||||
| 	"net/http" | ||||
| @@ -35,9 +36,7 @@ func ConvertRequest(request model.GeneralOpenAIRequest) *ChatRequest { | ||||
| 		enableSearch = true | ||||
| 		aliModel = strings.TrimSuffix(aliModel, EnableSearchModelSuffix) | ||||
| 	} | ||||
| 	if request.TopP >= 1 { | ||||
| 		request.TopP = 0.9999 | ||||
| 	} | ||||
| 	request.TopP = helper.Float64PtrMax(request.TopP, 0.9999) | ||||
| 	return &ChatRequest{ | ||||
| 		Model: aliModel, | ||||
| 		Input: Input{ | ||||
| @@ -59,7 +58,7 @@ func ConvertRequest(request model.GeneralOpenAIRequest) *ChatRequest { | ||||
|  | ||||
| func ConvertEmbeddingRequest(request model.GeneralOpenAIRequest) *EmbeddingRequest { | ||||
| 	return &EmbeddingRequest{ | ||||
| 		Model: "text-embedding-v1", | ||||
| 		Model: request.Model, | ||||
| 		Input: struct { | ||||
| 			Texts []string `json:"texts"` | ||||
| 		}{ | ||||
| @@ -102,8 +101,9 @@ func EmbeddingHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStat | ||||
| 			StatusCode: resp.StatusCode, | ||||
| 		}, nil | ||||
| 	} | ||||
|  | ||||
| 	requestModel := c.GetString(ctxkey.RequestModel) | ||||
| 	fullTextResponse := embeddingResponseAli2OpenAI(&aliResponse) | ||||
| 	fullTextResponse.Model = requestModel | ||||
| 	jsonResponse, err := json.Marshal(fullTextResponse) | ||||
| 	if err != nil { | ||||
| 		return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil | ||||
|   | ||||
| @@ -16,13 +16,13 @@ type Input struct { | ||||
| } | ||||
|  | ||||
| type Parameters struct { | ||||
| 	TopP              float64      `json:"top_p,omitempty"` | ||||
| 	TopP              *float64     `json:"top_p,omitempty"` | ||||
| 	TopK              int          `json:"top_k,omitempty"` | ||||
| 	Seed              uint64       `json:"seed,omitempty"` | ||||
| 	EnableSearch      bool         `json:"enable_search,omitempty"` | ||||
| 	IncrementalOutput bool         `json:"incremental_output,omitempty"` | ||||
| 	MaxTokens         int          `json:"max_tokens,omitempty"` | ||||
| 	Temperature       float64      `json:"temperature,omitempty"` | ||||
| 	Temperature       *float64     `json:"temperature,omitempty"` | ||||
| 	ResultFormat      string       `json:"result_format,omitempty"` | ||||
| 	Tools             []model.Tool `json:"tools,omitempty"` | ||||
| } | ||||
|   | ||||
| @@ -3,7 +3,10 @@ package anthropic | ||||
| var ModelList = []string{ | ||||
| 	"claude-instant-1.2", "claude-2.0", "claude-2.1", | ||||
| 	"claude-3-haiku-20240307", | ||||
| 	"claude-3-5-haiku-20241022", | ||||
| 	"claude-3-sonnet-20240229", | ||||
| 	"claude-3-opus-20240229", | ||||
| 	"claude-3-5-sonnet-20240620", | ||||
| 	"claude-3-5-sonnet-20241022", | ||||
| 	"claude-3-5-sonnet-latest", | ||||
| } | ||||
|   | ||||
| @@ -48,8 +48,8 @@ type Request struct { | ||||
| 	MaxTokens     int       `json:"max_tokens,omitempty"` | ||||
| 	StopSequences []string  `json:"stop_sequences,omitempty"` | ||||
| 	Stream        bool      `json:"stream,omitempty"` | ||||
| 	Temperature   float64   `json:"temperature,omitempty"` | ||||
| 	TopP          float64   `json:"top_p,omitempty"` | ||||
| 	Temperature   *float64  `json:"temperature,omitempty"` | ||||
| 	TopP          *float64  `json:"top_p,omitempty"` | ||||
| 	TopK          int       `json:"top_k,omitempty"` | ||||
| 	Tools         []Tool    `json:"tools,omitempty"` | ||||
| 	ToolChoice    any       `json:"tool_choice,omitempty"` | ||||
|   | ||||
| @@ -29,10 +29,13 @@ var AwsModelIDMap = map[string]string{ | ||||
| 	"claude-instant-1.2":         "anthropic.claude-instant-v1", | ||||
| 	"claude-2.0":                 "anthropic.claude-v2", | ||||
| 	"claude-2.1":                 "anthropic.claude-v2:1", | ||||
| 	"claude-3-sonnet-20240229":   "anthropic.claude-3-sonnet-20240229-v1:0", | ||||
| 	"claude-3-5-sonnet-20240620": "anthropic.claude-3-5-sonnet-20240620-v1:0", | ||||
| 	"claude-3-opus-20240229":     "anthropic.claude-3-opus-20240229-v1:0", | ||||
| 	"claude-3-haiku-20240307":    "anthropic.claude-3-haiku-20240307-v1:0", | ||||
| 	"claude-3-sonnet-20240229":   "anthropic.claude-3-sonnet-20240229-v1:0", | ||||
| 	"claude-3-opus-20240229":     "anthropic.claude-3-opus-20240229-v1:0", | ||||
| 	"claude-3-5-sonnet-20240620": "anthropic.claude-3-5-sonnet-20240620-v1:0", | ||||
| 	"claude-3-5-sonnet-20241022": "anthropic.claude-3-5-sonnet-20241022-v2:0", | ||||
| 	"claude-3-5-sonnet-latest":   "anthropic.claude-3-5-sonnet-20241022-v2:0", | ||||
| 	"claude-3-5-haiku-20241022":  "anthropic.claude-3-5-haiku-20241022-v1:0", | ||||
| } | ||||
|  | ||||
| func awsModelID(requestModel string) (string, error) { | ||||
|   | ||||
| @@ -11,8 +11,8 @@ type Request struct { | ||||
| 	Messages         []anthropic.Message `json:"messages"` | ||||
| 	System           string              `json:"system,omitempty"` | ||||
| 	MaxTokens        int                 `json:"max_tokens,omitempty"` | ||||
| 	Temperature      float64             `json:"temperature,omitempty"` | ||||
| 	TopP             float64             `json:"top_p,omitempty"` | ||||
| 	Temperature      *float64            `json:"temperature,omitempty"` | ||||
| 	TopP             *float64            `json:"top_p,omitempty"` | ||||
| 	TopK             int                 `json:"top_k,omitempty"` | ||||
| 	StopSequences    []string            `json:"stop_sequences,omitempty"` | ||||
| 	Tools            []anthropic.Tool    `json:"tools,omitempty"` | ||||
|   | ||||
| @@ -4,10 +4,10 @@ package aws | ||||
| // | ||||
| // https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-meta.html | ||||
| type Request struct { | ||||
| 	Prompt      string  `json:"prompt"` | ||||
| 	MaxGenLen   int     `json:"max_gen_len,omitempty"` | ||||
| 	Temperature float64 `json:"temperature,omitempty"` | ||||
| 	TopP        float64 `json:"top_p,omitempty"` | ||||
| 	Prompt      string   `json:"prompt"` | ||||
| 	MaxGenLen   int      `json:"max_gen_len,omitempty"` | ||||
| 	Temperature *float64 `json:"temperature,omitempty"` | ||||
| 	TopP        *float64 `json:"top_p,omitempty"` | ||||
| } | ||||
|  | ||||
| // Response is the response from AWS Llama3 | ||||
|   | ||||
| @@ -35,9 +35,9 @@ type Message struct { | ||||
|  | ||||
| type ChatRequest struct { | ||||
| 	Messages        []Message `json:"messages"` | ||||
| 	Temperature     float64   `json:"temperature,omitempty"` | ||||
| 	TopP            float64   `json:"top_p,omitempty"` | ||||
| 	PenaltyScore    float64   `json:"penalty_score,omitempty"` | ||||
| 	Temperature     *float64  `json:"temperature,omitempty"` | ||||
| 	TopP            *float64  `json:"top_p,omitempty"` | ||||
| 	PenaltyScore    *float64  `json:"penalty_score,omitempty"` | ||||
| 	Stream          bool      `json:"stream,omitempty"` | ||||
| 	System          string    `json:"system,omitempty"` | ||||
| 	DisableSearch   bool      `json:"disable_search,omitempty"` | ||||
|   | ||||
| @@ -1,6 +1,7 @@ | ||||
| package cloudflare | ||||
|  | ||||
| var ModelList = []string{ | ||||
| 	"@cf/meta/llama-3.1-8b-instruct", | ||||
| 	"@cf/meta/llama-2-7b-chat-fp16", | ||||
| 	"@cf/meta/llama-2-7b-chat-int8", | ||||
| 	"@cf/mistral/mistral-7b-instruct-v0.1", | ||||
|   | ||||
| @@ -9,5 +9,5 @@ type Request struct { | ||||
| 	Prompt      string          `json:"prompt,omitempty"` | ||||
| 	Raw         bool            `json:"raw,omitempty"` | ||||
| 	Stream      bool            `json:"stream,omitempty"` | ||||
| 	Temperature float64         `json:"temperature,omitempty"` | ||||
| 	Temperature *float64        `json:"temperature,omitempty"` | ||||
| } | ||||
|   | ||||
| @@ -43,7 +43,7 @@ func ConvertRequest(textRequest model.GeneralOpenAIRequest) *Request { | ||||
| 		K:                textRequest.TopK, | ||||
| 		Stream:           textRequest.Stream, | ||||
| 		FrequencyPenalty: textRequest.FrequencyPenalty, | ||||
| 		PresencePenalty:  textRequest.FrequencyPenalty, | ||||
| 		PresencePenalty:  textRequest.PresencePenalty, | ||||
| 		Seed:             int(textRequest.Seed), | ||||
| 	} | ||||
| 	if cohereRequest.Model == "" { | ||||
|   | ||||
| @@ -10,15 +10,15 @@ type Request struct { | ||||
| 	PromptTruncation string        `json:"prompt_truncation,omitempty"` // 默认值为"AUTO" | ||||
| 	Connectors       []Connector   `json:"connectors,omitempty"` | ||||
| 	Documents        []Document    `json:"documents,omitempty"` | ||||
| 	Temperature      float64       `json:"temperature,omitempty"` // 默认值为0.3 | ||||
| 	Temperature      *float64      `json:"temperature,omitempty"` // 默认值为0.3 | ||||
| 	MaxTokens        int           `json:"max_tokens,omitempty"` | ||||
| 	MaxInputTokens   int           `json:"max_input_tokens,omitempty"` | ||||
| 	K                int           `json:"k,omitempty"` // 默认值为0 | ||||
| 	P                float64       `json:"p,omitempty"` // 默认值为0.75 | ||||
| 	P                *float64      `json:"p,omitempty"` // 默认值为0.75 | ||||
| 	Seed             int           `json:"seed,omitempty"` | ||||
| 	StopSequences    []string      `json:"stop_sequences,omitempty"` | ||||
| 	FrequencyPenalty float64       `json:"frequency_penalty,omitempty"` // 默认值为0.0 | ||||
| 	PresencePenalty  float64       `json:"presence_penalty,omitempty"`  // 默认值为0.0 | ||||
| 	FrequencyPenalty *float64      `json:"frequency_penalty,omitempty"` // 默认值为0.0 | ||||
| 	PresencePenalty  *float64      `json:"presence_penalty,omitempty"`  // 默认值为0.0 | ||||
| 	Tools            []Tool        `json:"tools,omitempty"` | ||||
| 	ToolResults      []ToolResult  `json:"tool_results,omitempty"` | ||||
| } | ||||
|   | ||||
| @@ -7,7 +7,6 @@ import ( | ||||
| 	"net/http" | ||||
|  | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"github.com/songquanpeng/one-api/common/config" | ||||
| 	"github.com/songquanpeng/one-api/common/helper" | ||||
| 	channelhelper "github.com/songquanpeng/one-api/relay/adaptor" | ||||
| 	"github.com/songquanpeng/one-api/relay/adaptor/openai" | ||||
| @@ -24,7 +23,15 @@ func (a *Adaptor) Init(meta *meta.Meta) { | ||||
| } | ||||
|  | ||||
| func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) { | ||||
| 	version := helper.AssignOrDefault(meta.Config.APIVersion, config.GeminiVersion) | ||||
| 	var defaultVersion string | ||||
| 	switch meta.ActualModelName { | ||||
| 	case "gemini-2.0-flash-exp", | ||||
| 		"gemini-2.0-flash-thinking-exp", | ||||
| 		"gemini-2.0-flash-thinking-exp-01-21": | ||||
| 		defaultVersion = "v1beta" | ||||
| 	} | ||||
|  | ||||
| 	version := helper.AssignOrDefault(meta.Config.APIVersion, defaultVersion) | ||||
| 	action := "" | ||||
| 	switch meta.Mode { | ||||
| 	case relaymode.Embeddings: | ||||
| @@ -36,6 +43,7 @@ func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) { | ||||
| 	if meta.IsStream { | ||||
| 		action = "streamGenerateContent?alt=sse" | ||||
| 	} | ||||
|  | ||||
| 	return fmt.Sprintf("%s/%s/models/%s:%s", meta.BaseURL, version, meta.ActualModelName, action), nil | ||||
| } | ||||
|  | ||||
|   | ||||
| @@ -3,6 +3,9 @@ package gemini | ||||
| // https://ai.google.dev/models/gemini | ||||
|  | ||||
| var ModelList = []string{ | ||||
| 	"gemini-pro", "gemini-1.0-pro-001", "gemini-1.5-pro", | ||||
| 	"gemini-pro-vision", "gemini-1.0-pro-vision-001", "embedding-001", "text-embedding-004", | ||||
| 	"gemini-pro", "gemini-1.0-pro", | ||||
| 	"gemini-1.5-flash", "gemini-1.5-pro", | ||||
| 	"text-embedding-004", "aqa", | ||||
| 	"gemini-2.0-flash-exp", | ||||
| 	"gemini-2.0-flash-thinking-exp", "gemini-2.0-flash-thinking-exp-01-21", | ||||
| } | ||||
|   | ||||
| @@ -4,11 +4,12 @@ import ( | ||||
| 	"bufio" | ||||
| 	"encoding/json" | ||||
| 	"fmt" | ||||
| 	"github.com/songquanpeng/one-api/common/render" | ||||
| 	"io" | ||||
| 	"net/http" | ||||
| 	"strings" | ||||
|  | ||||
| 	"github.com/songquanpeng/one-api/common/render" | ||||
|  | ||||
| 	"github.com/songquanpeng/one-api/common" | ||||
| 	"github.com/songquanpeng/one-api/common/config" | ||||
| 	"github.com/songquanpeng/one-api/common/helper" | ||||
| @@ -28,6 +29,11 @@ const ( | ||||
| 	VisionMaxImageNum = 16 | ||||
| ) | ||||
|  | ||||
| var mimeTypeMap = map[string]string{ | ||||
| 	"json_object": "application/json", | ||||
| 	"text":        "text/plain", | ||||
| } | ||||
|  | ||||
| // Setting safety to the lowest possible values since Gemini is already powerless enough | ||||
| func ConvertRequest(textRequest model.GeneralOpenAIRequest) *ChatRequest { | ||||
| 	geminiRequest := ChatRequest{ | ||||
| @@ -49,6 +55,10 @@ func ConvertRequest(textRequest model.GeneralOpenAIRequest) *ChatRequest { | ||||
| 				Category:  "HARM_CATEGORY_DANGEROUS_CONTENT", | ||||
| 				Threshold: config.GeminiSafetySetting, | ||||
| 			}, | ||||
| 			{ | ||||
| 				Category:  "HARM_CATEGORY_CIVIC_INTEGRITY", | ||||
| 				Threshold: config.GeminiSafetySetting, | ||||
| 			}, | ||||
| 		}, | ||||
| 		GenerationConfig: ChatGenerationConfig{ | ||||
| 			Temperature:     textRequest.Temperature, | ||||
| @@ -56,6 +66,15 @@ func ConvertRequest(textRequest model.GeneralOpenAIRequest) *ChatRequest { | ||||
| 			MaxOutputTokens: textRequest.MaxTokens, | ||||
| 		}, | ||||
| 	} | ||||
| 	if textRequest.ResponseFormat != nil { | ||||
| 		if mimeType, ok := mimeTypeMap[textRequest.ResponseFormat.Type]; ok { | ||||
| 			geminiRequest.GenerationConfig.ResponseMimeType = mimeType | ||||
| 		} | ||||
| 		if textRequest.ResponseFormat.JsonSchema != nil { | ||||
| 			geminiRequest.GenerationConfig.ResponseSchema = textRequest.ResponseFormat.JsonSchema.Schema | ||||
| 			geminiRequest.GenerationConfig.ResponseMimeType = mimeTypeMap["json_object"] | ||||
| 		} | ||||
| 	} | ||||
| 	if textRequest.Tools != nil { | ||||
| 		functions := make([]model.Function, 0, len(textRequest.Tools)) | ||||
| 		for _, tool := range textRequest.Tools { | ||||
| @@ -232,7 +251,14 @@ func responseGeminiChat2OpenAI(response *ChatResponse) *openai.TextResponse { | ||||
| 			if candidate.Content.Parts[0].FunctionCall != nil { | ||||
| 				choice.Message.ToolCalls = getToolCalls(&candidate) | ||||
| 			} else { | ||||
| 				choice.Message.Content = candidate.Content.Parts[0].Text | ||||
| 				var builder strings.Builder | ||||
| 				for _, part := range candidate.Content.Parts { | ||||
| 					if i > 0 { | ||||
| 						builder.WriteString("\n") | ||||
| 					} | ||||
| 					builder.WriteString(part.Text) | ||||
| 				} | ||||
| 				choice.Message.Content = builder.String() | ||||
| 			} | ||||
| 		} else { | ||||
| 			choice.Message.Content = "" | ||||
|   | ||||
| @@ -65,10 +65,12 @@ type ChatTools struct { | ||||
| } | ||||
|  | ||||
| type ChatGenerationConfig struct { | ||||
| 	Temperature     float64  `json:"temperature,omitempty"` | ||||
| 	TopP            float64  `json:"topP,omitempty"` | ||||
| 	TopK            float64  `json:"topK,omitempty"` | ||||
| 	MaxOutputTokens int      `json:"maxOutputTokens,omitempty"` | ||||
| 	CandidateCount  int      `json:"candidateCount,omitempty"` | ||||
| 	StopSequences   []string `json:"stopSequences,omitempty"` | ||||
| 	ResponseMimeType string   `json:"responseMimeType,omitempty"` | ||||
| 	ResponseSchema   any      `json:"responseSchema,omitempty"` | ||||
| 	Temperature      *float64 `json:"temperature,omitempty"` | ||||
| 	TopP             *float64 `json:"topP,omitempty"` | ||||
| 	TopK             float64  `json:"topK,omitempty"` | ||||
| 	MaxOutputTokens  int      `json:"maxOutputTokens,omitempty"` | ||||
| 	CandidateCount   int      `json:"candidateCount,omitempty"` | ||||
| 	StopSequences    []string `json:"stopSequences,omitempty"` | ||||
| } | ||||
|   | ||||
| @@ -4,9 +4,24 @@ package groq | ||||
|  | ||||
| var ModelList = []string{ | ||||
| 	"gemma-7b-it", | ||||
| 	"llama2-7b-2048", | ||||
| 	"llama2-70b-4096", | ||||
| 	"mixtral-8x7b-32768", | ||||
| 	"llama3-8b-8192", | ||||
| 	"gemma2-9b-it", | ||||
| 	"llama-3.1-70b-versatile", | ||||
| 	"llama-3.1-8b-instant", | ||||
| 	"llama-3.2-11b-text-preview", | ||||
| 	"llama-3.2-11b-vision-preview", | ||||
| 	"llama-3.2-1b-preview", | ||||
| 	"llama-3.2-3b-preview", | ||||
| 	"llama-3.2-11b-vision-preview", | ||||
| 	"llama-3.2-90b-text-preview", | ||||
| 	"llama-3.2-90b-vision-preview", | ||||
| 	"llama-guard-3-8b", | ||||
| 	"llama3-70b-8192", | ||||
| 	"llama3-8b-8192", | ||||
| 	"llama3-groq-70b-8192-tool-use-preview", | ||||
| 	"llama3-groq-8b-8192-tool-use-preview", | ||||
| 	"llava-v1.5-7b-4096-preview", | ||||
| 	"mixtral-8x7b-32768", | ||||
| 	"distil-whisper-large-v3-en", | ||||
| 	"whisper-large-v3", | ||||
| 	"whisper-large-v3-turbo", | ||||
| } | ||||
|   | ||||
| @@ -24,7 +24,7 @@ func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) { | ||||
| 	// https://github.com/ollama/ollama/blob/main/docs/api.md | ||||
| 	fullRequestURL := fmt.Sprintf("%s/api/chat", meta.BaseURL) | ||||
| 	if meta.Mode == relaymode.Embeddings { | ||||
| 		fullRequestURL = fmt.Sprintf("%s/api/embeddings", meta.BaseURL) | ||||
| 		fullRequestURL = fmt.Sprintf("%s/api/embed", meta.BaseURL) | ||||
| 	} | ||||
| 	return fullRequestURL, nil | ||||
| } | ||||
|   | ||||
| @@ -31,6 +31,8 @@ func ConvertRequest(request model.GeneralOpenAIRequest) *ChatRequest { | ||||
| 			TopP:             request.TopP, | ||||
| 			FrequencyPenalty: request.FrequencyPenalty, | ||||
| 			PresencePenalty:  request.PresencePenalty, | ||||
| 			NumPredict:       request.MaxTokens, | ||||
| 			NumCtx:           request.NumCtx, | ||||
| 		}, | ||||
| 		Stream: request.Stream, | ||||
| 	} | ||||
| @@ -118,8 +120,10 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC | ||||
| 	common.SetEventStreamHeaders(c) | ||||
|  | ||||
| 	for scanner.Scan() { | ||||
| 		data := strings.TrimPrefix(scanner.Text(), "}") | ||||
| 		data = data + "}" | ||||
| 		data := scanner.Text() | ||||
| 		if strings.HasPrefix(data, "}") { | ||||
| 			data = strings.TrimPrefix(data, "}") + "}" | ||||
| 		} | ||||
|  | ||||
| 		var ollamaResponse ChatResponse | ||||
| 		err := json.Unmarshal([]byte(data), &ollamaResponse) | ||||
| @@ -157,8 +161,15 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC | ||||
|  | ||||
| func ConvertEmbeddingRequest(request model.GeneralOpenAIRequest) *EmbeddingRequest { | ||||
| 	return &EmbeddingRequest{ | ||||
| 		Model:  request.Model, | ||||
| 		Prompt: strings.Join(request.ParseInput(), " "), | ||||
| 		Model: request.Model, | ||||
| 		Input: request.ParseInput(), | ||||
| 		Options: &Options{ | ||||
| 			Seed:             int(request.Seed), | ||||
| 			Temperature:      request.Temperature, | ||||
| 			TopP:             request.TopP, | ||||
| 			FrequencyPenalty: request.FrequencyPenalty, | ||||
| 			PresencePenalty:  request.PresencePenalty, | ||||
| 		}, | ||||
| 	} | ||||
| } | ||||
|  | ||||
| @@ -201,15 +212,17 @@ func embeddingResponseOllama2OpenAI(response *EmbeddingResponse) *openai.Embeddi | ||||
| 	openAIEmbeddingResponse := openai.EmbeddingResponse{ | ||||
| 		Object: "list", | ||||
| 		Data:   make([]openai.EmbeddingResponseItem, 0, 1), | ||||
| 		Model:  "text-embedding-v1", | ||||
| 		Model:  response.Model, | ||||
| 		Usage:  model.Usage{TotalTokens: 0}, | ||||
| 	} | ||||
|  | ||||
| 	openAIEmbeddingResponse.Data = append(openAIEmbeddingResponse.Data, openai.EmbeddingResponseItem{ | ||||
| 		Object:    `embedding`, | ||||
| 		Index:     0, | ||||
| 		Embedding: response.Embedding, | ||||
| 	}) | ||||
| 	for i, embedding := range response.Embeddings { | ||||
| 		openAIEmbeddingResponse.Data = append(openAIEmbeddingResponse.Data, openai.EmbeddingResponseItem{ | ||||
| 			Object:    `embedding`, | ||||
| 			Index:     i, | ||||
| 			Embedding: embedding, | ||||
| 		}) | ||||
| 	} | ||||
| 	return &openAIEmbeddingResponse | ||||
| } | ||||
|  | ||||
|   | ||||
| @@ -1,12 +1,14 @@ | ||||
| package ollama | ||||
|  | ||||
| type Options struct { | ||||
| 	Seed             int     `json:"seed,omitempty"` | ||||
| 	Temperature      float64 `json:"temperature,omitempty"` | ||||
| 	TopK             int     `json:"top_k,omitempty"` | ||||
| 	TopP             float64 `json:"top_p,omitempty"` | ||||
| 	FrequencyPenalty float64 `json:"frequency_penalty,omitempty"` | ||||
| 	PresencePenalty  float64 `json:"presence_penalty,omitempty"` | ||||
| 	Seed             int      `json:"seed,omitempty"` | ||||
| 	Temperature      *float64 `json:"temperature,omitempty"` | ||||
| 	TopK             int      `json:"top_k,omitempty"` | ||||
| 	TopP             *float64 `json:"top_p,omitempty"` | ||||
| 	FrequencyPenalty *float64 `json:"frequency_penalty,omitempty"` | ||||
| 	PresencePenalty  *float64 `json:"presence_penalty,omitempty"` | ||||
| 	NumPredict       int      `json:"num_predict,omitempty"` | ||||
| 	NumCtx           int      `json:"num_ctx,omitempty"` | ||||
| } | ||||
|  | ||||
| type Message struct { | ||||
| @@ -37,11 +39,15 @@ type ChatResponse struct { | ||||
| } | ||||
|  | ||||
| type EmbeddingRequest struct { | ||||
| 	Model  string `json:"model"` | ||||
| 	Prompt string `json:"prompt"` | ||||
| 	Model string   `json:"model"` | ||||
| 	Input []string `json:"input"` | ||||
| 	// Truncate  bool     `json:"truncate,omitempty"` | ||||
| 	Options *Options `json:"options,omitempty"` | ||||
| 	// KeepAlive string   `json:"keep_alive,omitempty"` | ||||
| } | ||||
|  | ||||
| type EmbeddingResponse struct { | ||||
| 	Error     string    `json:"error,omitempty"` | ||||
| 	Embedding []float64 `json:"embedding,omitempty"` | ||||
| 	Error      string      `json:"error,omitempty"` | ||||
| 	Model      string      `json:"model"` | ||||
| 	Embeddings [][]float64 `json:"embeddings"` | ||||
| } | ||||
|   | ||||
| @@ -75,6 +75,13 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G | ||||
| 	if request == nil { | ||||
| 		return nil, errors.New("request is nil") | ||||
| 	} | ||||
| 	if request.Stream { | ||||
| 		// always return usage in stream mode | ||||
| 		if request.StreamOptions == nil { | ||||
| 			request.StreamOptions = &model.StreamOptions{} | ||||
| 		} | ||||
| 		request.StreamOptions.IncludeUsage = true | ||||
| 	} | ||||
| 	return request, nil | ||||
| } | ||||
|  | ||||
|   | ||||
| @@ -11,8 +11,10 @@ import ( | ||||
| 	"github.com/songquanpeng/one-api/relay/adaptor/mistral" | ||||
| 	"github.com/songquanpeng/one-api/relay/adaptor/moonshot" | ||||
| 	"github.com/songquanpeng/one-api/relay/adaptor/novita" | ||||
| 	"github.com/songquanpeng/one-api/relay/adaptor/siliconflow" | ||||
| 	"github.com/songquanpeng/one-api/relay/adaptor/stepfun" | ||||
| 	"github.com/songquanpeng/one-api/relay/adaptor/togetherai" | ||||
| 	"github.com/songquanpeng/one-api/relay/adaptor/xai" | ||||
| 	"github.com/songquanpeng/one-api/relay/channeltype" | ||||
| ) | ||||
|  | ||||
| @@ -30,6 +32,8 @@ var CompatibleChannels = []int{ | ||||
| 	channeltype.DeepSeek, | ||||
| 	channeltype.TogetherAI, | ||||
| 	channeltype.Novita, | ||||
| 	channeltype.SiliconFlow, | ||||
| 	channeltype.XAI, | ||||
| } | ||||
|  | ||||
| func GetCompatibleChannelMeta(channelType int) (string, []string) { | ||||
| @@ -60,6 +64,10 @@ func GetCompatibleChannelMeta(channelType int) (string, []string) { | ||||
| 		return "doubao", doubao.ModelList | ||||
| 	case channeltype.Novita: | ||||
| 		return "novita", novita.ModelList | ||||
| 	case channeltype.SiliconFlow: | ||||
| 		return "siliconflow", siliconflow.ModelList | ||||
| 	case channeltype.XAI: | ||||
| 		return "xai", xai.ModelList | ||||
| 	default: | ||||
| 		return "openai", ModelList | ||||
| 	} | ||||
|   | ||||
| @@ -8,6 +8,9 @@ var ModelList = []string{ | ||||
| 	"gpt-4-32k", "gpt-4-32k-0314", "gpt-4-32k-0613", | ||||
| 	"gpt-4-turbo-preview", "gpt-4-turbo", "gpt-4-turbo-2024-04-09", | ||||
| 	"gpt-4o", "gpt-4o-2024-05-13", | ||||
| 	"gpt-4o-2024-08-06", | ||||
| 	"gpt-4o-2024-11-20", | ||||
| 	"chatgpt-4o-latest", | ||||
| 	"gpt-4o-mini", "gpt-4o-mini-2024-07-18", | ||||
| 	"gpt-4-vision-preview", | ||||
| 	"text-embedding-ada-002", "text-embedding-3-small", "text-embedding-3-large", | ||||
| @@ -18,4 +21,7 @@ var ModelList = []string{ | ||||
| 	"dall-e-2", "dall-e-3", | ||||
| 	"whisper-1", | ||||
| 	"tts-1", "tts-1-1106", "tts-1-hd", "tts-1-hd-1106", | ||||
| 	"o1", "o1-2024-12-17", | ||||
| 	"o1-preview", "o1-preview-2024-09-12", | ||||
| 	"o1-mini", "o1-mini-2024-09-12", | ||||
| } | ||||
|   | ||||
| @@ -2,15 +2,16 @@ package openai | ||||
|  | ||||
| import ( | ||||
| 	"fmt" | ||||
| 	"strings" | ||||
|  | ||||
| 	"github.com/songquanpeng/one-api/relay/channeltype" | ||||
| 	"github.com/songquanpeng/one-api/relay/model" | ||||
| 	"strings" | ||||
| ) | ||||
|  | ||||
| func ResponseText2Usage(responseText string, modeName string, promptTokens int) *model.Usage { | ||||
| func ResponseText2Usage(responseText string, modelName string, promptTokens int) *model.Usage { | ||||
| 	usage := &model.Usage{} | ||||
| 	usage.PromptTokens = promptTokens | ||||
| 	usage.CompletionTokens = CountTokenText(responseText, modeName) | ||||
| 	usage.CompletionTokens = CountTokenText(responseText, modelName) | ||||
| 	usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens | ||||
| 	return usage | ||||
| } | ||||
|   | ||||
| @@ -55,8 +55,8 @@ func StreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*model.E | ||||
| 				render.StringData(c, data) // if error happened, pass the data to client | ||||
| 				continue                   // just ignore the error | ||||
| 			} | ||||
| 			if len(streamResponse.Choices) == 0 { | ||||
| 				// but for empty choice, we should not pass it to client, this is for azure | ||||
| 			if len(streamResponse.Choices) == 0 && streamResponse.Usage == nil { | ||||
| 				// but for empty choice and no usage, we should not pass it to client, this is for azure | ||||
| 				continue // just ignore empty choice | ||||
| 			} | ||||
| 			render.StringData(c, data) | ||||
|   | ||||
| @@ -1,8 +1,16 @@ | ||||
| package openai | ||||
|  | ||||
| import "github.com/songquanpeng/one-api/relay/model" | ||||
| import ( | ||||
| 	"context" | ||||
| 	"fmt" | ||||
|  | ||||
| 	"github.com/songquanpeng/one-api/common/logger" | ||||
| 	"github.com/songquanpeng/one-api/relay/model" | ||||
| ) | ||||
|  | ||||
| func ErrorWrapper(err error, code string, statusCode int) *model.ErrorWithStatusCode { | ||||
| 	logger.Error(context.TODO(), fmt.Sprintf("[%s]%+v", code, err)) | ||||
|  | ||||
| 	Error := model.Error{ | ||||
| 		Message: err.Error(), | ||||
| 		Type:    "one_api_error", | ||||
|   | ||||
| @@ -19,11 +19,11 @@ type Prompt struct { | ||||
| } | ||||
|  | ||||
| type ChatRequest struct { | ||||
| 	Prompt         Prompt  `json:"prompt"` | ||||
| 	Temperature    float64 `json:"temperature,omitempty"` | ||||
| 	CandidateCount int     `json:"candidateCount,omitempty"` | ||||
| 	TopP           float64 `json:"topP,omitempty"` | ||||
| 	TopK           int     `json:"topK,omitempty"` | ||||
| 	Prompt         Prompt   `json:"prompt"` | ||||
| 	Temperature    *float64 `json:"temperature,omitempty"` | ||||
| 	CandidateCount int      `json:"candidateCount,omitempty"` | ||||
| 	TopP           *float64 `json:"topP,omitempty"` | ||||
| 	TopK           int      `json:"topK,omitempty"` | ||||
| } | ||||
|  | ||||
| type Error struct { | ||||
|   | ||||
							
								
								
									
										136
									
								
								relay/adaptor/replicate/adaptor.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										136
									
								
								relay/adaptor/replicate/adaptor.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,136 @@ | ||||
| package replicate | ||||
|  | ||||
| import ( | ||||
| 	"fmt" | ||||
| 	"io" | ||||
| 	"net/http" | ||||
| 	"slices" | ||||
| 	"strings" | ||||
| 	"time" | ||||
|  | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"github.com/pkg/errors" | ||||
| 	"github.com/songquanpeng/one-api/common/logger" | ||||
| 	"github.com/songquanpeng/one-api/relay/adaptor" | ||||
| 	"github.com/songquanpeng/one-api/relay/adaptor/openai" | ||||
| 	"github.com/songquanpeng/one-api/relay/meta" | ||||
| 	"github.com/songquanpeng/one-api/relay/model" | ||||
| 	"github.com/songquanpeng/one-api/relay/relaymode" | ||||
| ) | ||||
|  | ||||
| type Adaptor struct { | ||||
| 	meta *meta.Meta | ||||
| } | ||||
|  | ||||
| // ConvertImageRequest implements adaptor.Adaptor. | ||||
| func (*Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) { | ||||
| 	return DrawImageRequest{ | ||||
| 		Input: ImageInput{ | ||||
| 			Steps:           25, | ||||
| 			Prompt:          request.Prompt, | ||||
| 			Guidance:        3, | ||||
| 			Seed:            int(time.Now().UnixNano()), | ||||
| 			SafetyTolerance: 5, | ||||
| 			NImages:         1, // replicate will always return 1 image | ||||
| 			Width:           1440, | ||||
| 			Height:          1440, | ||||
| 			AspectRatio:     "1:1", | ||||
| 		}, | ||||
| 	}, nil | ||||
| } | ||||
|  | ||||
| func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) { | ||||
| 	if !request.Stream { | ||||
| 		// TODO: support non-stream mode | ||||
| 		return nil, errors.Errorf("replicate models only support stream mode now, please set stream=true") | ||||
| 	} | ||||
|  | ||||
| 	// Build the prompt from OpenAI messages | ||||
| 	var promptBuilder strings.Builder | ||||
| 	for _, message := range request.Messages { | ||||
| 		switch msgCnt := message.Content.(type) { | ||||
| 		case string: | ||||
| 			promptBuilder.WriteString(message.Role) | ||||
| 			promptBuilder.WriteString(": ") | ||||
| 			promptBuilder.WriteString(msgCnt) | ||||
| 			promptBuilder.WriteString("\n") | ||||
| 		default: | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	replicateRequest := ReplicateChatRequest{ | ||||
| 		Input: ChatInput{ | ||||
| 			Prompt:           promptBuilder.String(), | ||||
| 			MaxTokens:        request.MaxTokens, | ||||
| 			Temperature:      1.0, | ||||
| 			TopP:             1.0, | ||||
| 			PresencePenalty:  0.0, | ||||
| 			FrequencyPenalty: 0.0, | ||||
| 		}, | ||||
| 	} | ||||
|  | ||||
| 	// Map optional fields | ||||
| 	if request.Temperature != nil { | ||||
| 		replicateRequest.Input.Temperature = *request.Temperature | ||||
| 	} | ||||
| 	if request.TopP != nil { | ||||
| 		replicateRequest.Input.TopP = *request.TopP | ||||
| 	} | ||||
| 	if request.PresencePenalty != nil { | ||||
| 		replicateRequest.Input.PresencePenalty = *request.PresencePenalty | ||||
| 	} | ||||
| 	if request.FrequencyPenalty != nil { | ||||
| 		replicateRequest.Input.FrequencyPenalty = *request.FrequencyPenalty | ||||
| 	} | ||||
| 	if request.MaxTokens > 0 { | ||||
| 		replicateRequest.Input.MaxTokens = request.MaxTokens | ||||
| 	} else if request.MaxTokens == 0 { | ||||
| 		replicateRequest.Input.MaxTokens = 500 | ||||
| 	} | ||||
|  | ||||
| 	return replicateRequest, nil | ||||
| } | ||||
|  | ||||
| func (a *Adaptor) Init(meta *meta.Meta) { | ||||
| 	a.meta = meta | ||||
| } | ||||
|  | ||||
| func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) { | ||||
| 	if !slices.Contains(ModelList, meta.OriginModelName) { | ||||
| 		return "", errors.Errorf("model %s not supported", meta.OriginModelName) | ||||
| 	} | ||||
|  | ||||
| 	return fmt.Sprintf("https://api.replicate.com/v1/models/%s/predictions", meta.OriginModelName), nil | ||||
| } | ||||
|  | ||||
| func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error { | ||||
| 	adaptor.SetupCommonRequestHeader(c, req, meta) | ||||
| 	req.Header.Set("Authorization", "Bearer "+meta.APIKey) | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) { | ||||
| 	logger.Info(c, "send request to replicate") | ||||
| 	return adaptor.DoRequestHelper(a, c, meta, requestBody) | ||||
| } | ||||
|  | ||||
| func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) { | ||||
| 	switch meta.Mode { | ||||
| 	case relaymode.ImagesGenerations: | ||||
| 		err, usage = ImageHandler(c, resp) | ||||
| 	case relaymode.ChatCompletions: | ||||
| 		err, usage = ChatHandler(c, resp) | ||||
| 	default: | ||||
| 		err = openai.ErrorWrapper(errors.New("not implemented"), "not_implemented", http.StatusInternalServerError) | ||||
| 	} | ||||
|  | ||||
| 	return | ||||
| } | ||||
|  | ||||
| func (a *Adaptor) GetModelList() []string { | ||||
| 	return ModelList | ||||
| } | ||||
|  | ||||
| func (a *Adaptor) GetChannelName() string { | ||||
| 	return "replicate" | ||||
| } | ||||
							
								
								
									
										191
									
								
								relay/adaptor/replicate/chat.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										191
									
								
								relay/adaptor/replicate/chat.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,191 @@ | ||||
| package replicate | ||||
|  | ||||
| import ( | ||||
| 	"bufio" | ||||
| 	"encoding/json" | ||||
| 	"io" | ||||
| 	"net/http" | ||||
| 	"strings" | ||||
| 	"time" | ||||
|  | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"github.com/pkg/errors" | ||||
| 	"github.com/songquanpeng/one-api/common" | ||||
| 	"github.com/songquanpeng/one-api/common/render" | ||||
| 	"github.com/songquanpeng/one-api/relay/adaptor/openai" | ||||
| 	"github.com/songquanpeng/one-api/relay/meta" | ||||
| 	"github.com/songquanpeng/one-api/relay/model" | ||||
| ) | ||||
|  | ||||
| func ChatHandler(c *gin.Context, resp *http.Response) ( | ||||
| 	srvErr *model.ErrorWithStatusCode, usage *model.Usage) { | ||||
| 	if resp.StatusCode != http.StatusCreated { | ||||
| 		payload, _ := io.ReadAll(resp.Body) | ||||
| 		return openai.ErrorWrapper( | ||||
| 				errors.Errorf("bad_status_code [%d]%s", resp.StatusCode, string(payload)), | ||||
| 				"bad_status_code", http.StatusInternalServerError), | ||||
| 			nil | ||||
| 	} | ||||
|  | ||||
| 	respBody, err := io.ReadAll(resp.Body) | ||||
| 	if err != nil { | ||||
| 		return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil | ||||
| 	} | ||||
|  | ||||
| 	respData := new(ChatResponse) | ||||
| 	if err = json.Unmarshal(respBody, respData); err != nil { | ||||
| 		return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil | ||||
| 	} | ||||
|  | ||||
| 	for { | ||||
| 		err = func() error { | ||||
| 			// get task | ||||
| 			taskReq, err := http.NewRequestWithContext(c.Request.Context(), | ||||
| 				http.MethodGet, respData.URLs.Get, nil) | ||||
| 			if err != nil { | ||||
| 				return errors.Wrap(err, "new request") | ||||
| 			} | ||||
|  | ||||
| 			taskReq.Header.Set("Authorization", "Bearer "+meta.GetByContext(c).APIKey) | ||||
| 			taskResp, err := http.DefaultClient.Do(taskReq) | ||||
| 			if err != nil { | ||||
| 				return errors.Wrap(err, "get task") | ||||
| 			} | ||||
| 			defer taskResp.Body.Close() | ||||
|  | ||||
| 			if taskResp.StatusCode != http.StatusOK { | ||||
| 				payload, _ := io.ReadAll(taskResp.Body) | ||||
| 				return errors.Errorf("bad status code [%d]%s", | ||||
| 					taskResp.StatusCode, string(payload)) | ||||
| 			} | ||||
|  | ||||
| 			taskBody, err := io.ReadAll(taskResp.Body) | ||||
| 			if err != nil { | ||||
| 				return errors.Wrap(err, "read task response") | ||||
| 			} | ||||
|  | ||||
| 			taskData := new(ChatResponse) | ||||
| 			if err = json.Unmarshal(taskBody, taskData); err != nil { | ||||
| 				return errors.Wrap(err, "decode task response") | ||||
| 			} | ||||
|  | ||||
| 			switch taskData.Status { | ||||
| 			case "succeeded": | ||||
| 			case "failed", "canceled": | ||||
| 				return errors.Errorf("task failed, [%s]%s", taskData.Status, taskData.Error) | ||||
| 			default: | ||||
| 				time.Sleep(time.Second * 3) | ||||
| 				return errNextLoop | ||||
| 			} | ||||
|  | ||||
| 			if taskData.URLs.Stream == "" { | ||||
| 				return errors.New("stream url is empty") | ||||
| 			} | ||||
|  | ||||
| 			// request stream url | ||||
| 			responseText, err := chatStreamHandler(c, taskData.URLs.Stream) | ||||
| 			if err != nil { | ||||
| 				return errors.Wrap(err, "chat stream handler") | ||||
| 			} | ||||
|  | ||||
| 			ctxMeta := meta.GetByContext(c) | ||||
| 			usage = openai.ResponseText2Usage(responseText, | ||||
| 				ctxMeta.ActualModelName, ctxMeta.PromptTokens) | ||||
| 			return nil | ||||
| 		}() | ||||
| 		if err != nil { | ||||
| 			if errors.Is(err, errNextLoop) { | ||||
| 				continue | ||||
| 			} | ||||
|  | ||||
| 			return openai.ErrorWrapper(err, "chat_task_failed", http.StatusInternalServerError), nil | ||||
| 		} | ||||
|  | ||||
| 		break | ||||
| 	} | ||||
|  | ||||
| 	return nil, usage | ||||
| } | ||||
|  | ||||
// SSE (server-sent events) protocol markers used when parsing the
// Replicate prediction stream.
const (
	eventPrefix = "event: "
	dataPrefix  = "data: "
	// done is the OpenAI-style terminal sentinel. NOTE(review): it is not
	// referenced in the visible stream handler — confirm it is used elsewhere
	// in this file before removing.
	done = "[DONE]"
)
|  | ||||
// chatStreamHandler proxies a Replicate SSE prediction stream to the client.
//
// It requests streamUrl with the channel's API key, forwards every "output"
// event to the client as OpenAI-style stream chunks, and returns the
// concatenated output text so the caller can derive token usage. A terminal
// [DONE] chunk is rendered exactly once: on the upstream "done" event, or
// after the stream closes without one.
func chatStreamHandler(c *gin.Context, streamUrl string) (responseText string, err error) {
	// request stream endpoint
	streamReq, err := http.NewRequestWithContext(c.Request.Context(), http.MethodGet, streamUrl, nil)
	if err != nil {
		return "", errors.Wrap(err, "new request to stream")
	}

	streamReq.Header.Set("Authorization", "Bearer "+meta.GetByContext(c).APIKey)
	streamReq.Header.Set("Accept", "text/event-stream")
	streamReq.Header.Set("Cache-Control", "no-store")

	resp, err := http.DefaultClient.Do(streamReq)
	if err != nil {
		return "", errors.Wrap(err, "do request to stream")
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		payload, _ := io.ReadAll(resp.Body)
		return "", errors.Errorf("bad status code [%d]%s", resp.StatusCode, string(payload))
	}

	scanner := bufio.NewScanner(resp.Body)
	scanner.Split(bufio.ScanLines)

	common.SetEventStreamHeaders(c)
	doneRendered := false
	for scanner.Scan() {
		line := strings.TrimSpace(scanner.Text())
		if line == "" {
			continue
		}

		// SSE comment lines start with ':'; skip them.
		if strings.HasPrefix(line, ":") {
			continue
		}

		// Parse SSE fields: an "event: <name>" line followed by its
		// field lines, terminated by a blank separator line.
		if strings.HasPrefix(line, eventPrefix) {
			event := strings.TrimSpace(line[len(eventPrefix):])
			var data string
			// Consume the field lines belonging to this event up to the blank
			// separator. NOTE(review): if one event carries multiple "data:"
			// lines only the last is kept — confirm Replicate never splits
			// output across data lines.
			for scanner.Scan() {
				nextLine := scanner.Text()
				if nextLine == "" {
					break
				}
				if strings.HasPrefix(nextLine, dataPrefix) {
					data = nextLine[len(dataPrefix):]
				} else if strings.HasPrefix(nextLine, "id:") {
					// id = strings.TrimSpace(nextLine[len("id:"):])
				}
			}

			if event == "output" {
				render.StringData(c, data)
				responseText += data
			} else if event == "done" {
				render.Done(c)
				doneRendered = true
				break
			}
		}
	}

	if err := scanner.Err(); err != nil {
		return "", errors.Wrap(err, "scan stream")
	}

	// Upstream closed without a "done" event; still terminate the client
	// stream properly.
	if !doneRendered {
		render.Done(c)
	}

	return responseText, nil
}
							
								
								
									
										58
									
								
								relay/adaptor/replicate/constant.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										58
									
								
								relay/adaptor/replicate/constant.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,58 @@ | ||||
| package replicate | ||||
|  | ||||
// ModelList is a list of models that can be used with Replicate.
//
// Entries are grouped by modality (image / language / video); the slice
// order is preserved as-is because it may be surfaced to clients.
//
// https://replicate.com/pricing
var ModelList = []string{
	// -------------------------------------
	// image model
	// -------------------------------------
	"black-forest-labs/flux-1.1-pro",
	"black-forest-labs/flux-1.1-pro-ultra",
	"black-forest-labs/flux-canny-dev",
	"black-forest-labs/flux-canny-pro",
	"black-forest-labs/flux-depth-dev",
	"black-forest-labs/flux-depth-pro",
	"black-forest-labs/flux-dev",
	"black-forest-labs/flux-dev-lora",
	"black-forest-labs/flux-fill-dev",
	"black-forest-labs/flux-fill-pro",
	"black-forest-labs/flux-pro",
	"black-forest-labs/flux-redux-dev",
	"black-forest-labs/flux-redux-schnell",
	"black-forest-labs/flux-schnell",
	"black-forest-labs/flux-schnell-lora",
	"ideogram-ai/ideogram-v2",
	"ideogram-ai/ideogram-v2-turbo",
	"recraft-ai/recraft-v3",
	"recraft-ai/recraft-v3-svg",
	"stability-ai/stable-diffusion-3",
	"stability-ai/stable-diffusion-3.5-large",
	"stability-ai/stable-diffusion-3.5-large-turbo",
	"stability-ai/stable-diffusion-3.5-medium",
	// -------------------------------------
	// language model
	// -------------------------------------
	"ibm-granite/granite-20b-code-instruct-8k",
	"ibm-granite/granite-3.0-2b-instruct",
	"ibm-granite/granite-3.0-8b-instruct",
	"ibm-granite/granite-8b-code-instruct-128k",
	"meta/llama-2-13b",
	"meta/llama-2-13b-chat",
	"meta/llama-2-70b",
	"meta/llama-2-70b-chat",
	"meta/llama-2-7b",
	"meta/llama-2-7b-chat",
	"meta/meta-llama-3.1-405b-instruct",
	"meta/meta-llama-3-70b",
	"meta/meta-llama-3-70b-instruct",
	"meta/meta-llama-3-8b",
	"meta/meta-llama-3-8b-instruct",
	"mistralai/mistral-7b-instruct-v0.2",
	"mistralai/mistral-7b-v0.1",
	"mistralai/mixtral-8x7b-instruct-v0.1",
	// -------------------------------------
	// video model
	// -------------------------------------
	// "minimax/video-01",  // TODO: implement the adaptor
}
							
								
								
									
										222
									
								
								relay/adaptor/replicate/image.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										222
									
								
								relay/adaptor/replicate/image.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,222 @@ | ||||
| package replicate | ||||
|  | ||||
import (
	"bytes"
	"encoding/base64"
	"encoding/json"
	"fmt"
	"image/jpeg"
	"image/png"
	"io"
	"net/http"
	"sync"
	"time"

	"github.com/gin-gonic/gin"
	"github.com/pkg/errors"
	"github.com/songquanpeng/one-api/common/logger"
	"github.com/songquanpeng/one-api/relay/adaptor/openai"
	"github.com/songquanpeng/one-api/relay/meta"
	"github.com/songquanpeng/one-api/relay/model"
	"golang.org/x/image/webp"
	"golang.org/x/sync/errgroup"
)
|  | ||||
| // ImagesEditsHandler just copy response body to client | ||||
| // | ||||
| // https://replicate.com/black-forest-labs/flux-fill-pro | ||||
| // func ImagesEditsHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) { | ||||
| // 	c.Writer.WriteHeader(resp.StatusCode) | ||||
| // 	for k, v := range resp.Header { | ||||
| // 		c.Writer.Header().Set(k, v[0]) | ||||
| // 	} | ||||
|  | ||||
| // 	if _, err := io.Copy(c.Writer, resp.Body); err != nil { | ||||
| // 		return ErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil | ||||
| // 	} | ||||
| // 	defer resp.Body.Close() | ||||
|  | ||||
| // 	return nil, nil | ||||
| // } | ||||
|  | ||||
// errNextLoop is a sentinel returned by the per-iteration closures in the
// polling loops below to signal "task not finished yet, poll again"; it is
// matched with errors.Is and is never surfaced to the caller.
var errNextLoop = errors.New("next_loop")
|  | ||||
| func ImageHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) { | ||||
| 	if resp.StatusCode != http.StatusCreated { | ||||
| 		payload, _ := io.ReadAll(resp.Body) | ||||
| 		return openai.ErrorWrapper( | ||||
| 				errors.Errorf("bad_status_code [%d]%s", resp.StatusCode, string(payload)), | ||||
| 				"bad_status_code", http.StatusInternalServerError), | ||||
| 			nil | ||||
| 	} | ||||
|  | ||||
| 	respBody, err := io.ReadAll(resp.Body) | ||||
| 	if err != nil { | ||||
| 		return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil | ||||
| 	} | ||||
|  | ||||
| 	respData := new(ImageResponse) | ||||
| 	if err = json.Unmarshal(respBody, respData); err != nil { | ||||
| 		return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil | ||||
| 	} | ||||
|  | ||||
| 	for { | ||||
| 		err = func() error { | ||||
| 			// get task | ||||
| 			taskReq, err := http.NewRequestWithContext(c.Request.Context(), | ||||
| 				http.MethodGet, respData.URLs.Get, nil) | ||||
| 			if err != nil { | ||||
| 				return errors.Wrap(err, "new request") | ||||
| 			} | ||||
|  | ||||
| 			taskReq.Header.Set("Authorization", "Bearer "+meta.GetByContext(c).APIKey) | ||||
| 			taskResp, err := http.DefaultClient.Do(taskReq) | ||||
| 			if err != nil { | ||||
| 				return errors.Wrap(err, "get task") | ||||
| 			} | ||||
| 			defer taskResp.Body.Close() | ||||
|  | ||||
| 			if taskResp.StatusCode != http.StatusOK { | ||||
| 				payload, _ := io.ReadAll(taskResp.Body) | ||||
| 				return errors.Errorf("bad status code [%d]%s", | ||||
| 					taskResp.StatusCode, string(payload)) | ||||
| 			} | ||||
|  | ||||
| 			taskBody, err := io.ReadAll(taskResp.Body) | ||||
| 			if err != nil { | ||||
| 				return errors.Wrap(err, "read task response") | ||||
| 			} | ||||
|  | ||||
| 			taskData := new(ImageResponse) | ||||
| 			if err = json.Unmarshal(taskBody, taskData); err != nil { | ||||
| 				return errors.Wrap(err, "decode task response") | ||||
| 			} | ||||
|  | ||||
| 			switch taskData.Status { | ||||
| 			case "succeeded": | ||||
| 			case "failed", "canceled": | ||||
| 				return errors.Errorf("task failed: %s", taskData.Status) | ||||
| 			default: | ||||
| 				time.Sleep(time.Second * 3) | ||||
| 				return errNextLoop | ||||
| 			} | ||||
|  | ||||
| 			output, err := taskData.GetOutput() | ||||
| 			if err != nil { | ||||
| 				return errors.Wrap(err, "get output") | ||||
| 			} | ||||
| 			if len(output) == 0 { | ||||
| 				return errors.New("response output is empty") | ||||
| 			} | ||||
|  | ||||
| 			var mu sync.Mutex | ||||
| 			var pool errgroup.Group | ||||
| 			respBody := &openai.ImageResponse{ | ||||
| 				Created: taskData.CompletedAt.Unix(), | ||||
| 				Data:    []openai.ImageData{}, | ||||
| 			} | ||||
|  | ||||
| 			for _, imgOut := range output { | ||||
| 				imgOut := imgOut | ||||
| 				pool.Go(func() error { | ||||
| 					// download image | ||||
| 					downloadReq, err := http.NewRequestWithContext(c.Request.Context(), | ||||
| 						http.MethodGet, imgOut, nil) | ||||
| 					if err != nil { | ||||
| 						return errors.Wrap(err, "new request") | ||||
| 					} | ||||
|  | ||||
| 					imgResp, err := http.DefaultClient.Do(downloadReq) | ||||
| 					if err != nil { | ||||
| 						return errors.Wrap(err, "download image") | ||||
| 					} | ||||
| 					defer imgResp.Body.Close() | ||||
|  | ||||
| 					if imgResp.StatusCode != http.StatusOK { | ||||
| 						payload, _ := io.ReadAll(imgResp.Body) | ||||
| 						return errors.Errorf("bad status code [%d]%s", | ||||
| 							imgResp.StatusCode, string(payload)) | ||||
| 					} | ||||
|  | ||||
| 					imgData, err := io.ReadAll(imgResp.Body) | ||||
| 					if err != nil { | ||||
| 						return errors.Wrap(err, "read image") | ||||
| 					} | ||||
|  | ||||
| 					imgData, err = ConvertImageToPNG(imgData) | ||||
| 					if err != nil { | ||||
| 						return errors.Wrap(err, "convert image") | ||||
| 					} | ||||
|  | ||||
| 					mu.Lock() | ||||
| 					respBody.Data = append(respBody.Data, openai.ImageData{ | ||||
| 						B64Json: fmt.Sprintf("data:image/png;base64,%s", | ||||
| 							base64.StdEncoding.EncodeToString(imgData)), | ||||
| 					}) | ||||
| 					mu.Unlock() | ||||
|  | ||||
| 					return nil | ||||
| 				}) | ||||
| 			} | ||||
|  | ||||
| 			if err := pool.Wait(); err != nil { | ||||
| 				if len(respBody.Data) == 0 { | ||||
| 					return errors.WithStack(err) | ||||
| 				} | ||||
|  | ||||
| 				logger.Error(c, fmt.Sprintf("some images failed to download: %+v", err)) | ||||
| 			} | ||||
|  | ||||
| 			c.JSON(http.StatusOK, respBody) | ||||
| 			return nil | ||||
| 		}() | ||||
| 		if err != nil { | ||||
| 			if errors.Is(err, errNextLoop) { | ||||
| 				continue | ||||
| 			} | ||||
|  | ||||
| 			return openai.ErrorWrapper(err, "image_task_failed", http.StatusInternalServerError), nil | ||||
| 		} | ||||
|  | ||||
| 		break | ||||
| 	} | ||||
|  | ||||
| 	return nil, nil | ||||
| } | ||||
|  | ||||
| // ConvertImageToPNG converts a WebP image to PNG format | ||||
| func ConvertImageToPNG(webpData []byte) ([]byte, error) { | ||||
| 	// bypass if it's already a PNG image | ||||
| 	if bytes.HasPrefix(webpData, []byte("\x89PNG")) { | ||||
| 		return webpData, nil | ||||
| 	} | ||||
|  | ||||
| 	// check if is jpeg, convert to png | ||||
| 	if bytes.HasPrefix(webpData, []byte("\xff\xd8\xff")) { | ||||
| 		img, _, err := image.Decode(bytes.NewReader(webpData)) | ||||
| 		if err != nil { | ||||
| 			return nil, errors.Wrap(err, "decode jpeg") | ||||
| 		} | ||||
|  | ||||
| 		var pngBuffer bytes.Buffer | ||||
| 		if err := png.Encode(&pngBuffer, img); err != nil { | ||||
| 			return nil, errors.Wrap(err, "encode png") | ||||
| 		} | ||||
|  | ||||
| 		return pngBuffer.Bytes(), nil | ||||
| 	} | ||||
|  | ||||
| 	// Decode the WebP image | ||||
| 	img, err := webp.Decode(bytes.NewReader(webpData)) | ||||
| 	if err != nil { | ||||
| 		return nil, errors.Wrap(err, "decode webp") | ||||
| 	} | ||||
|  | ||||
| 	// Encode the image as PNG | ||||
| 	var pngBuffer bytes.Buffer | ||||
| 	if err := png.Encode(&pngBuffer, img); err != nil { | ||||
| 		return nil, errors.Wrap(err, "encode png") | ||||
| 	} | ||||
|  | ||||
| 	return pngBuffer.Bytes(), nil | ||||
| } | ||||
							
								
								
									
										159
									
								
								relay/adaptor/replicate/model.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										159
									
								
								relay/adaptor/replicate/model.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,159 @@ | ||||
| package replicate | ||||
|  | ||||
| import ( | ||||
| 	"time" | ||||
|  | ||||
| 	"github.com/pkg/errors" | ||||
| ) | ||||
|  | ||||
// DrawImageRequest is the request body for a Replicate flux-pro prediction.
//
// https://replicate.com/black-forest-labs/flux-pro?prediction=kg1krwsdf9rg80ch1sgsrgq7h8&output=json
type DrawImageRequest struct {
	Input ImageInput `json:"input"`
}

// ImageInput is the input payload of DrawImageRequest. The binding tags
// mirror the value ranges published in Replicate's model schema.
//
// https://replicate.com/black-forest-labs/flux-1.1-pro/api/schema
type ImageInput struct {
	Steps           int    `json:"steps" binding:"required,min=1"`
	Prompt          string `json:"prompt" binding:"required,min=5"`
	ImagePrompt     string `json:"image_prompt"`
	Guidance        int    `json:"guidance" binding:"required,min=2,max=5"`
	Interval        int    `json:"interval" binding:"required,min=1,max=4"`
	AspectRatio     string `json:"aspect_ratio" binding:"required,oneof=1:1 16:9 2:3 3:2 4:5 5:4 9:16"`
	SafetyTolerance int    `json:"safety_tolerance" binding:"required,min=1,max=5"`
	Seed            int    `json:"seed"` // 0 lets Replicate pick a random seed
	NImages         int    `json:"n_images" binding:"required,min=1,max=8"`
	Width           int    `json:"width" binding:"required,min=256,max=1440"`
	Height          int    `json:"height" binding:"required,min=256,max=1440"`
}
|  | ||||
// InpaintingImageByFlusReplicateRequest is the request to inpaint an image
// with flux-fill-pro. NOTE(review): "Flus" in the type name looks like a typo
// for "Flux", but renaming would break callers.
//
// https://replicate.com/black-forest-labs/flux-fill-pro/api/schema
type InpaintingImageByFlusReplicateRequest struct {
	Input FluxInpaintingInput `json:"input"`
}

// FluxInpaintingInput is the input payload of the inpainting request.
//
// https://replicate.com/black-forest-labs/flux-fill-pro/api/schema
type FluxInpaintingInput struct {
	Mask             string `json:"mask" binding:"required"`
	Image            string `json:"image" binding:"required"`
	Seed             int    `json:"seed"`
	Steps            int    `json:"steps" binding:"required,min=1"`
	Prompt           string `json:"prompt" binding:"required,min=5"`
	Guidance         int    `json:"guidance" binding:"required,min=2,max=5"`
	OutputFormat     string `json:"output_format"`
	SafetyTolerance  int    `json:"safety_tolerance" binding:"required,min=1,max=5"`
	// NOTE(review): Replicate's flux-fill-pro schema names this parameter
	// "prompt_upsampling" — verify the JSON tag against the API before
	// relying on this field taking effect.
	PromptUnsampling bool   `json:"prompt_unsampling"`
}
|  | ||||
// ImageResponse is a Replicate prediction object for an image task; the same
// shape is returned both by the initial create call and by subsequent polls
// of URLs.Get.
//
// https://replicate.com/black-forest-labs/flux-pro?prediction=kg1krwsdf9rg80ch1sgsrgq7h8&output=json
type ImageResponse struct {
	CompletedAt time.Time        `json:"completed_at"`
	CreatedAt   time.Time        `json:"created_at"`
	DataRemoved bool             `json:"data_removed"`
	Error       string           `json:"error"` // upstream failure detail; empty on success
	ID          string           `json:"id"`
	Input       DrawImageRequest `json:"input"`
	Logs        string           `json:"logs"`
	Metrics     FluxMetrics      `json:"metrics"`
	// Output could be `string` or `[]string`; normalize via GetOutput.
	Output    any       `json:"output"`
	StartedAt time.Time `json:"started_at"`
	Status    string    `json:"status"` // e.g. "starting", "processing", "succeeded", "failed", "canceled"
	URLs      FluxURLs  `json:"urls"`
	Version   string    `json:"version"`
}
|  | ||||
| func (r *ImageResponse) GetOutput() ([]string, error) { | ||||
| 	switch v := r.Output.(type) { | ||||
| 	case string: | ||||
| 		return []string{v}, nil | ||||
| 	case []string: | ||||
| 		return v, nil | ||||
| 	case nil: | ||||
| 		return nil, nil | ||||
| 	case []interface{}: | ||||
| 		// convert []interface{} to []string | ||||
| 		ret := make([]string, len(v)) | ||||
| 		for idx, vv := range v { | ||||
| 			if vvv, ok := vv.(string); ok { | ||||
| 				ret[idx] = vvv | ||||
| 			} else { | ||||
| 				return nil, errors.Errorf("unknown output type: [%T]%v", vv, vv) | ||||
| 			} | ||||
| 		} | ||||
|  | ||||
| 		return ret, nil | ||||
| 	default: | ||||
| 		return nil, errors.Errorf("unknown output type: [%T]%v", r.Output, r.Output) | ||||
| 	} | ||||
| } | ||||
|  | ||||
// FluxMetrics holds the timing/count metrics Replicate attaches to a
// finished prediction.
type FluxMetrics struct {
	ImageCount  int     `json:"image_count"`
	PredictTime float64 `json:"predict_time"` // seconds spent predicting
	TotalTime   float64 `json:"total_time"`   // seconds including queueing
}

// FluxURLs are the task-management endpoints of an ImageResponse: Get is
// polled for status, Cancel aborts the prediction.
type FluxURLs struct {
	Get    string `json:"get"`
	Cancel string `json:"cancel"`
}
|  | ||||
// ReplicateChatRequest is the request body for a Replicate language-model
// prediction.
type ReplicateChatRequest struct {
	Input ChatInput `json:"input" form:"input" binding:"required"`
}

// ChatInput is the input payload of ReplicateChatRequest.
//
// https://replicate.com/meta/meta-llama-3.1-405b-instruct/api/schema
type ChatInput struct {
	TopK             int     `json:"top_k"`
	TopP             float64 `json:"top_p"`
	Prompt           string  `json:"prompt"`
	MaxTokens        int     `json:"max_tokens"`
	MinTokens        int     `json:"min_tokens"`
	Temperature      float64 `json:"temperature"`
	SystemPrompt     string  `json:"system_prompt"`
	StopSequences    string  `json:"stop_sequences"` // comma-separated per Replicate's schema
	PromptTemplate   string  `json:"prompt_template"`
	PresencePenalty  float64 `json:"presence_penalty"`
	FrequencyPenalty float64 `json:"frequency_penalty"`
}
|  | ||||
// ChatResponse is a Replicate prediction object for a chat task, returned by
// the create call and by polls of URLs.Get.
//
// https://replicate.com/meta/meta-llama-3.1-405b-instruct/examples?input=http&output=json
type ChatResponse struct {
	CompletedAt time.Time   `json:"completed_at"`
	CreatedAt   time.Time   `json:"created_at"`
	DataRemoved bool        `json:"data_removed"`
	Error       string      `json:"error"` // upstream failure detail; empty on success
	ID          string      `json:"id"`
	Input       ChatInput   `json:"input"`
	Logs        string      `json:"logs"`
	Metrics     FluxMetrics `json:"metrics"`
	// Output could be `string` or `[]string`
	Output    []string        `json:"output"`
	StartedAt time.Time       `json:"started_at"`
	Status    string          `json:"status"` // e.g. "starting", "processing", "succeeded", "failed", "canceled"
	URLs      ChatResponseUrl `json:"urls"`
	Version   string          `json:"version"`
}

// ChatResponseUrl holds the task-management endpoints of a ChatResponse:
// Stream is the SSE endpoint for incremental output, Get polls status,
// Cancel aborts the prediction.
type ChatResponseUrl struct {
	Stream string `json:"stream"`
	Get    string `json:"get"`
	Cancel string `json:"cancel"`
}
							
								
								
									
										36
									
								
								relay/adaptor/siliconflow/constants.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										36
									
								
								relay/adaptor/siliconflow/constants.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,36 @@ | ||||
| package siliconflow | ||||
|  | ||||
// https://docs.siliconflow.cn/docs/getting-started

// ModelList enumerates the SiliconFlow chat and embedding (BAAI/bge-*)
// models supported by this adaptor; "Pro/"-prefixed entries are the paid-tier
// variants of the same base models.
var ModelList = []string{
	"deepseek-ai/deepseek-llm-67b-chat",
	"Qwen/Qwen1.5-14B-Chat",
	"Qwen/Qwen1.5-7B-Chat",
	"Qwen/Qwen1.5-110B-Chat",
	"Qwen/Qwen1.5-32B-Chat",
	"01-ai/Yi-1.5-6B-Chat",
	"01-ai/Yi-1.5-9B-Chat-16K",
	"01-ai/Yi-1.5-34B-Chat-16K",
	"THUDM/chatglm3-6b",
	"deepseek-ai/DeepSeek-V2-Chat",
	"THUDM/glm-4-9b-chat",
	"Qwen/Qwen2-72B-Instruct",
	"Qwen/Qwen2-7B-Instruct",
	"Qwen/Qwen2-57B-A14B-Instruct",
	"deepseek-ai/DeepSeek-Coder-V2-Instruct",
	"Qwen/Qwen2-1.5B-Instruct",
	"internlm/internlm2_5-7b-chat",
	"BAAI/bge-large-en-v1.5",
	"BAAI/bge-large-zh-v1.5",
	"Pro/Qwen/Qwen2-7B-Instruct",
	"Pro/Qwen/Qwen2-1.5B-Instruct",
	"Pro/Qwen/Qwen1.5-7B-Chat",
	"Pro/THUDM/glm-4-9b-chat",
	"Pro/THUDM/chatglm3-6b",
	"Pro/01-ai/Yi-1.5-9B-Chat-16K",
	"Pro/01-ai/Yi-1.5-6B-Chat",
	"Pro/google/gemma-2-9b-it",
	"Pro/internlm/internlm2_5-7b-chat",
	"Pro/meta-llama/Meta-Llama-3-8B-Instruct",
	"Pro/mistralai/Mistral-7B-Instruct-v0.2",
}
| @@ -1,7 +1,13 @@ | ||||
| package stepfun | ||||
|  | ||||
// ModelList enumerates the StepFun chat (step-1*, step-2*), vision
// (step-1v*), and image-generation (step-1x*) models supported by this
// adaptor.
var ModelList = []string{
	"step-1-8k",
	"step-1-32k",
	"step-1-128k",
	"step-1-256k",
	"step-1-flash",
	"step-2-16k",
	"step-1v-8k",
	"step-1v-32k",
	"step-1-200k",
	"step-1x-medium",
}
|   | ||||
| @@ -2,16 +2,19 @@ package tencent | ||||
|  | ||||
| import ( | ||||
| 	"errors" | ||||
| 	"io" | ||||
| 	"net/http" | ||||
| 	"strconv" | ||||
| 	"strings" | ||||
|  | ||||
| 	"github.com/gin-gonic/gin" | ||||
|  | ||||
| 	"github.com/songquanpeng/one-api/common/helper" | ||||
| 	"github.com/songquanpeng/one-api/relay/adaptor" | ||||
| 	"github.com/songquanpeng/one-api/relay/adaptor/openai" | ||||
| 	"github.com/songquanpeng/one-api/relay/meta" | ||||
| 	"github.com/songquanpeng/one-api/relay/model" | ||||
| 	"io" | ||||
| 	"net/http" | ||||
| 	"strconv" | ||||
| 	"strings" | ||||
| 	"github.com/songquanpeng/one-api/relay/relaymode" | ||||
| ) | ||||
|  | ||||
| // https://cloud.tencent.com/document/api/1729/101837 | ||||
| @@ -52,10 +55,18 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	tencentRequest := ConvertRequest(*request) | ||||
| 	var convertedRequest any | ||||
| 	switch relayMode { | ||||
| 	case relaymode.Embeddings: | ||||
| 		a.Action = "GetEmbedding" | ||||
| 		convertedRequest = ConvertEmbeddingRequest(*request) | ||||
| 	default: | ||||
| 		a.Action = "ChatCompletions" | ||||
| 		convertedRequest = ConvertRequest(*request) | ||||
| 	} | ||||
| 	// we have to calculate the sign here | ||||
| 	a.Sign = GetSign(*tencentRequest, a, secretId, secretKey) | ||||
| 	return tencentRequest, nil | ||||
| 	a.Sign = GetSign(convertedRequest, a, secretId, secretKey) | ||||
| 	return convertedRequest, nil | ||||
| } | ||||
|  | ||||
| func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) { | ||||
| @@ -75,7 +86,12 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Met | ||||
| 		err, responseText = StreamHandler(c, resp) | ||||
| 		usage = openai.ResponseText2Usage(responseText, meta.ActualModelName, meta.PromptTokens) | ||||
| 	} else { | ||||
| 		err, usage = Handler(c, resp) | ||||
| 		switch meta.Mode { | ||||
| 		case relaymode.Embeddings: | ||||
| 			err, usage = EmbeddingHandler(c, resp) | ||||
| 		default: | ||||
| 			err, usage = Handler(c, resp) | ||||
| 		} | ||||
| 	} | ||||
| 	return | ||||
| } | ||||
|   | ||||
| @@ -5,4 +5,6 @@ var ModelList = []string{ | ||||
| 	"hunyuan-standard", | ||||
| 	"hunyuan-standard-256K", | ||||
| 	"hunyuan-pro", | ||||
| 	"hunyuan-vision", | ||||
| 	"hunyuan-embedding", | ||||
| } | ||||
|   | ||||
| @@ -8,7 +8,6 @@ import ( | ||||
| 	"encoding/json" | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
| 	"github.com/songquanpeng/one-api/common/render" | ||||
| 	"io" | ||||
| 	"net/http" | ||||
| 	"strconv" | ||||
| @@ -16,11 +15,14 @@ import ( | ||||
| 	"time" | ||||
|  | ||||
| 	"github.com/gin-gonic/gin" | ||||
|  | ||||
| 	"github.com/songquanpeng/one-api/common" | ||||
| 	"github.com/songquanpeng/one-api/common/conv" | ||||
| 	"github.com/songquanpeng/one-api/common/ctxkey" | ||||
| 	"github.com/songquanpeng/one-api/common/helper" | ||||
| 	"github.com/songquanpeng/one-api/common/logger" | ||||
| 	"github.com/songquanpeng/one-api/common/random" | ||||
| 	"github.com/songquanpeng/one-api/common/render" | ||||
| 	"github.com/songquanpeng/one-api/relay/adaptor/openai" | ||||
| 	"github.com/songquanpeng/one-api/relay/constant" | ||||
| 	"github.com/songquanpeng/one-api/relay/model" | ||||
| @@ -39,13 +41,73 @@ func ConvertRequest(request model.GeneralOpenAIRequest) *ChatRequest { | ||||
| 		Model:       &request.Model, | ||||
| 		Stream:      &request.Stream, | ||||
| 		Messages:    messages, | ||||
| 		TopP:        &request.TopP, | ||||
| 		Temperature: &request.Temperature, | ||||
| 		TopP:        request.TopP, | ||||
| 		Temperature: request.Temperature, | ||||
| 	} | ||||
| } | ||||
|  | ||||
// ConvertEmbeddingRequest maps an OpenAI embedding request onto Tencent
// Hunyuan's GetEmbedding payload; only the input strings are carried over.
func ConvertEmbeddingRequest(request model.GeneralOpenAIRequest) *EmbeddingRequest {
	return &EmbeddingRequest{
		InputList: request.ParseInput(),
	}
}
|  | ||||
| func EmbeddingHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) { | ||||
| 	var tencentResponseP EmbeddingResponseP | ||||
| 	err := json.NewDecoder(resp.Body).Decode(&tencentResponseP) | ||||
| 	if err != nil { | ||||
| 		return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil | ||||
| 	} | ||||
|  | ||||
| 	err = resp.Body.Close() | ||||
| 	if err != nil { | ||||
| 		return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil | ||||
| 	} | ||||
|  | ||||
| 	tencentResponse := tencentResponseP.Response | ||||
| 	if tencentResponse.Error.Code != "" { | ||||
| 		return &model.ErrorWithStatusCode{ | ||||
| 			Error: model.Error{ | ||||
| 				Message: tencentResponse.Error.Message, | ||||
| 				Code:    tencentResponse.Error.Code, | ||||
| 			}, | ||||
| 			StatusCode: resp.StatusCode, | ||||
| 		}, nil | ||||
| 	} | ||||
| 	requestModel := c.GetString(ctxkey.RequestModel) | ||||
| 	fullTextResponse := embeddingResponseTencent2OpenAI(&tencentResponse) | ||||
| 	fullTextResponse.Model = requestModel | ||||
| 	jsonResponse, err := json.Marshal(fullTextResponse) | ||||
| 	if err != nil { | ||||
| 		return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil | ||||
| 	} | ||||
| 	c.Writer.Header().Set("Content-Type", "application/json") | ||||
| 	c.Writer.WriteHeader(resp.StatusCode) | ||||
| 	_, err = c.Writer.Write(jsonResponse) | ||||
| 	return nil, &fullTextResponse.Usage | ||||
| } | ||||
|  | ||||
| func embeddingResponseTencent2OpenAI(response *EmbeddingResponse) *openai.EmbeddingResponse { | ||||
| 	openAIEmbeddingResponse := openai.EmbeddingResponse{ | ||||
| 		Object: "list", | ||||
| 		Data:   make([]openai.EmbeddingResponseItem, 0, len(response.Data)), | ||||
| 		Model:  "hunyuan-embedding", | ||||
| 		Usage:  model.Usage{TotalTokens: response.EmbeddingUsage.TotalTokens}, | ||||
| 	} | ||||
|  | ||||
| 	for _, item := range response.Data { | ||||
| 		openAIEmbeddingResponse.Data = append(openAIEmbeddingResponse.Data, openai.EmbeddingResponseItem{ | ||||
| 			Object:    item.Object, | ||||
| 			Index:     item.Index, | ||||
| 			Embedding: item.Embedding, | ||||
| 		}) | ||||
| 	} | ||||
| 	return &openAIEmbeddingResponse | ||||
| } | ||||
|  | ||||
| func responseTencent2OpenAI(response *ChatResponse) *openai.TextResponse { | ||||
| 	fullTextResponse := openai.TextResponse{ | ||||
| 		Id:      response.ReqID, | ||||
| 		Object:  "chat.completion", | ||||
| 		Created: helper.GetTimestamp(), | ||||
| 		Usage: model.Usage{ | ||||
| @@ -148,7 +210,7 @@ func Handler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, * | ||||
| 		return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil | ||||
| 	} | ||||
| 	TencentResponse = responseP.Response | ||||
| 	if TencentResponse.Error.Code != 0 { | ||||
| 	if TencentResponse.Error.Code != "" { | ||||
| 		return &model.ErrorWithStatusCode{ | ||||
| 			Error: model.Error{ | ||||
| 				Message: TencentResponse.Error.Message, | ||||
| @@ -195,7 +257,7 @@ func hmacSha256(s, key string) string { | ||||
| 	return string(hashed.Sum(nil)) | ||||
| } | ||||
|  | ||||
| func GetSign(req ChatRequest, adaptor *Adaptor, secId, secKey string) string { | ||||
| func GetSign(req any, adaptor *Adaptor, secId, secKey string) string { | ||||
| 	// build canonical request string | ||||
| 	host := "hunyuan.tencentcloudapi.com" | ||||
| 	httpRequestMethod := "POST" | ||||
|   | ||||
| @@ -35,16 +35,16 @@ type ChatRequest struct { | ||||
| 	// 1. 影响输出文本的多样性,取值越大,生成文本的多样性越强。 | ||||
| 	// 2. 取值区间为 [0.0, 1.0],未传值时使用各模型推荐值。 | ||||
| 	// 3. 非必要不建议使用,不合理的取值会影响效果。 | ||||
| 	TopP *float64 `json:"TopP"` | ||||
| 	TopP *float64 `json:"TopP,omitempty"` | ||||
| 	// 说明: | ||||
| 	// 1. 较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定。 | ||||
| 	// 2. 取值区间为 [0.0, 2.0],未传值时使用各模型推荐值。 | ||||
| 	// 3. 非必要不建议使用,不合理的取值会影响效果。 | ||||
| 	Temperature *float64 `json:"Temperature"` | ||||
| 	Temperature *float64 `json:"Temperature,omitempty"` | ||||
| } | ||||
|  | ||||
| type Error struct { | ||||
| 	Code    int    `json:"Code"` | ||||
| 	Code    string `json:"Code"` | ||||
| 	Message string `json:"Message"` | ||||
| } | ||||
|  | ||||
| @@ -61,15 +61,41 @@ type ResponseChoices struct { | ||||
| } | ||||
|  | ||||
| type ChatResponse struct { | ||||
| 	Choices []ResponseChoices `json:"Choices,omitempty"` // 结果 | ||||
| 	Created int64             `json:"Created,omitempty"` // unix 时间戳的字符串 | ||||
| 	Id      string            `json:"Id,omitempty"`      // 会话 id | ||||
| 	Usage   Usage             `json:"Usage,omitempty"`   // token 数量 | ||||
| 	Error   Error             `json:"Error,omitempty"`   // 错误信息 注意:此字段可能返回 null,表示取不到有效值 | ||||
| 	Note    string            `json:"Note,omitempty"`    // 注释 | ||||
| 	ReqID   string            `json:"Req_id,omitempty"`  // 唯一请求 Id,每次请求都会返回。用于反馈接口入参 | ||||
| 	Choices []ResponseChoices `json:"Choices,omitempty"`   // 结果 | ||||
| 	Created int64             `json:"Created,omitempty"`   // unix 时间戳的字符串 | ||||
| 	Id      string            `json:"Id,omitempty"`        // 会话 id | ||||
| 	Usage   Usage             `json:"Usage,omitempty"`     // token 数量 | ||||
| 	Error   Error             `json:"Error,omitempty"`     // 错误信息 注意:此字段可能返回 null,表示取不到有效值 | ||||
| 	Note    string            `json:"Note,omitempty"`      // 注释 | ||||
| 	ReqID   string            `json:"RequestId,omitempty"` // 唯一请求 Id,每次请求都会返回。用于反馈接口入参 | ||||
| } | ||||
|  | ||||
| type ChatResponseP struct { | ||||
| 	Response ChatResponse `json:"Response,omitempty"` | ||||
| } | ||||
|  | ||||
| type EmbeddingRequest struct { | ||||
| 	InputList []string `json:"InputList"` | ||||
| } | ||||
|  | ||||
| type EmbeddingData struct { | ||||
| 	Embedding []float64 `json:"Embedding"` | ||||
| 	Index     int       `json:"Index"` | ||||
| 	Object    string    `json:"Object"` | ||||
| } | ||||
|  | ||||
| type EmbeddingUsage struct { | ||||
| 	PromptTokens int `json:"PromptTokens"` | ||||
| 	TotalTokens  int `json:"TotalTokens"` | ||||
| } | ||||
|  | ||||
| type EmbeddingResponse struct { | ||||
| 	Data           []EmbeddingData `json:"Data"` | ||||
| 	EmbeddingUsage EmbeddingUsage  `json:"Usage,omitempty"` | ||||
| 	RequestId      string          `json:"RequestId,omitempty"` | ||||
| 	Error          Error           `json:"Error,omitempty"` | ||||
| } | ||||
|  | ||||
| type EmbeddingResponseP struct { | ||||
| 	Response EmbeddingResponse `json:"Response,omitempty"` | ||||
| } | ||||
|   | ||||
| @@ -13,7 +13,12 @@ import ( | ||||
| ) | ||||
|  | ||||
| var ModelList = []string{ | ||||
| 	"claude-3-haiku@20240307", "claude-3-opus@20240229", "claude-3-5-sonnet@20240620", "claude-3-sonnet@20240229", | ||||
| 	"claude-3-haiku@20240307", | ||||
| 	"claude-3-sonnet@20240229", | ||||
| 	"claude-3-opus@20240229", | ||||
| 	"claude-3-5-sonnet@20240620", | ||||
| 	"claude-3-5-sonnet-v2@20241022", | ||||
| 	"claude-3-5-haiku@20241022", | ||||
| } | ||||
|  | ||||
| const anthropicVersion = "vertex-2023-10-16" | ||||
|   | ||||
| @@ -11,8 +11,8 @@ type Request struct { | ||||
| 	MaxTokens     int                 `json:"max_tokens,omitempty"` | ||||
| 	StopSequences []string            `json:"stop_sequences,omitempty"` | ||||
| 	Stream        bool                `json:"stream,omitempty"` | ||||
| 	Temperature   float64             `json:"temperature,omitempty"` | ||||
| 	TopP          float64             `json:"top_p,omitempty"` | ||||
| 	Temperature   *float64            `json:"temperature,omitempty"` | ||||
| 	TopP          *float64            `json:"top_p,omitempty"` | ||||
| 	TopK          int                 `json:"top_k,omitempty"` | ||||
| 	Tools         []anthropic.Tool    `json:"tools,omitempty"` | ||||
| 	ToolChoice    any                 `json:"tool_choice,omitempty"` | ||||
|   | ||||
| @@ -15,7 +15,11 @@ import ( | ||||
| ) | ||||
|  | ||||
| var ModelList = []string{ | ||||
| 	"gemini-1.5-pro-001", "gemini-1.5-flash-001", "gemini-pro", "gemini-pro-vision", | ||||
| 	"gemini-pro", "gemini-pro-vision", | ||||
| 	"gemini-1.5-pro-001", "gemini-1.5-flash-001", | ||||
| 	"gemini-1.5-pro-002", "gemini-1.5-flash-002", | ||||
| 	"gemini-2.0-flash-exp", | ||||
| 	"gemini-2.0-flash-thinking-exp", "gemini-2.0-flash-thinking-exp-01-21", | ||||
| } | ||||
|  | ||||
| type Adaptor struct { | ||||
|   | ||||
							
								
								
									
										5
									
								
								relay/adaptor/xai/constants.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										5
									
								
								relay/adaptor/xai/constants.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,5 @@ | ||||
| package xai | ||||
|  | ||||
| var ModelList = []string{ | ||||
| 	"grok-beta", | ||||
| } | ||||
| @@ -5,6 +5,8 @@ var ModelList = []string{ | ||||
| 	"SparkDesk-v1.1", | ||||
| 	"SparkDesk-v2.1", | ||||
| 	"SparkDesk-v3.1", | ||||
| 	"SparkDesk-v3.1-128K", | ||||
| 	"SparkDesk-v3.5", | ||||
| 	"SparkDesk-v3.5-32K", | ||||
| 	"SparkDesk-v4.0", | ||||
| } | ||||
|   | ||||
| @@ -272,9 +272,9 @@ func xunfeiMakeRequest(textRequest model.GeneralOpenAIRequest, domain, authUrl, | ||||
| } | ||||
|  | ||||
| func parseAPIVersionByModelName(modelName string) string { | ||||
| 	parts := strings.Split(modelName, "-") | ||||
| 	if len(parts) == 2 { | ||||
| 		return parts[1] | ||||
| 	index := strings.IndexAny(modelName, "-") | ||||
| 	if index != -1 { | ||||
| 		return modelName[index+1:] | ||||
| 	} | ||||
| 	return "" | ||||
| } | ||||
| @@ -283,13 +283,17 @@ func parseAPIVersionByModelName(modelName string) string { | ||||
| func apiVersion2domain(apiVersion string) string { | ||||
| 	switch apiVersion { | ||||
| 	case "v1.1": | ||||
| 		return "general" | ||||
| 		return "lite" | ||||
| 	case "v2.1": | ||||
| 		return "generalv2" | ||||
| 	case "v3.1": | ||||
| 		return "generalv3" | ||||
| 	case "v3.1-128K": | ||||
| 		return "pro-128k" | ||||
| 	case "v3.5": | ||||
| 		return "generalv3.5" | ||||
| 	case "v3.5-32K": | ||||
| 		return "max-32k" | ||||
| 	case "v4.0": | ||||
| 		return "4.0Ultra" | ||||
| 	} | ||||
| @@ -297,7 +301,17 @@ func apiVersion2domain(apiVersion string) string { | ||||
| } | ||||
|  | ||||
| func getXunfeiAuthUrl(apiVersion string, apiKey string, apiSecret string) (string, string) { | ||||
| 	var authUrl string | ||||
| 	domain := apiVersion2domain(apiVersion) | ||||
| 	authUrl := buildXunfeiAuthUrl(fmt.Sprintf("wss://spark-api.xf-yun.com/%s/chat", apiVersion), apiKey, apiSecret) | ||||
| 	switch apiVersion { | ||||
| 	case "v3.1-128K": | ||||
| 		authUrl = buildXunfeiAuthUrl(fmt.Sprintf("wss://spark-api.xf-yun.com/chat/pro-128k"), apiKey, apiSecret) | ||||
| 		break | ||||
| 	case "v3.5-32K": | ||||
| 		authUrl = buildXunfeiAuthUrl(fmt.Sprintf("wss://spark-api.xf-yun.com/chat/max-32k"), apiKey, apiSecret) | ||||
| 		break | ||||
| 	default: | ||||
| 		authUrl = buildXunfeiAuthUrl(fmt.Sprintf("wss://spark-api.xf-yun.com/%s/chat", apiVersion), apiKey, apiSecret) | ||||
| 	} | ||||
| 	return domain, authUrl | ||||
| } | ||||
|   | ||||
| @@ -19,11 +19,11 @@ type ChatRequest struct { | ||||
| 	} `json:"header"` | ||||
| 	Parameter struct { | ||||
| 		Chat struct { | ||||
| 			Domain      string  `json:"domain,omitempty"` | ||||
| 			Temperature float64 `json:"temperature,omitempty"` | ||||
| 			TopK        int     `json:"top_k,omitempty"` | ||||
| 			MaxTokens   int     `json:"max_tokens,omitempty"` | ||||
| 			Auditing    bool    `json:"auditing,omitempty"` | ||||
| 			Domain      string   `json:"domain,omitempty"` | ||||
| 			Temperature *float64 `json:"temperature,omitempty"` | ||||
| 			TopK        int      `json:"top_k,omitempty"` | ||||
| 			MaxTokens   int      `json:"max_tokens,omitempty"` | ||||
| 			Auditing    bool     `json:"auditing,omitempty"` | ||||
| 		} `json:"chat"` | ||||
| 	} `json:"parameter"` | ||||
| 	Payload struct { | ||||
|   | ||||
| @@ -4,13 +4,13 @@ import ( | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"github.com/songquanpeng/one-api/common/helper" | ||||
| 	"github.com/songquanpeng/one-api/relay/adaptor" | ||||
| 	"github.com/songquanpeng/one-api/relay/adaptor/openai" | ||||
| 	"github.com/songquanpeng/one-api/relay/meta" | ||||
| 	"github.com/songquanpeng/one-api/relay/model" | ||||
| 	"github.com/songquanpeng/one-api/relay/relaymode" | ||||
| 	"io" | ||||
| 	"math" | ||||
| 	"net/http" | ||||
| 	"strings" | ||||
| ) | ||||
| @@ -65,13 +65,13 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G | ||||
| 		baiduEmbeddingRequest, err := ConvertEmbeddingRequest(*request) | ||||
| 		return baiduEmbeddingRequest, err | ||||
| 	default: | ||||
| 		// TopP (0.0, 1.0) | ||||
| 		request.TopP = math.Min(0.99, request.TopP) | ||||
| 		request.TopP = math.Max(0.01, request.TopP) | ||||
| 		// TopP [0.0, 1.0] | ||||
| 		request.TopP = helper.Float64PtrMax(request.TopP, 1) | ||||
| 		request.TopP = helper.Float64PtrMin(request.TopP, 0) | ||||
|  | ||||
| 		// Temperature (0.0, 1.0) | ||||
| 		request.Temperature = math.Min(0.99, request.Temperature) | ||||
| 		request.Temperature = math.Max(0.01, request.Temperature) | ||||
| 		// Temperature [0.0, 1.0] | ||||
| 		request.Temperature = helper.Float64PtrMax(request.Temperature, 1) | ||||
| 		request.Temperature = helper.Float64PtrMin(request.Temperature, 0) | ||||
| 		a.SetVersionByModeName(request.Model) | ||||
| 		if a.APIVersion == "v4" { | ||||
| 			return request, nil | ||||
|   | ||||
| @@ -12,8 +12,8 @@ type Message struct { | ||||
|  | ||||
| type Request struct { | ||||
| 	Prompt      []Message `json:"prompt"` | ||||
| 	Temperature float64   `json:"temperature,omitempty"` | ||||
| 	TopP        float64   `json:"top_p,omitempty"` | ||||
| 	Temperature *float64  `json:"temperature,omitempty"` | ||||
| 	TopP        *float64  `json:"top_p,omitempty"` | ||||
| 	RequestId   string    `json:"request_id,omitempty"` | ||||
| 	Incremental bool      `json:"incremental,omitempty"` | ||||
| } | ||||
|   | ||||
| @@ -19,6 +19,7 @@ const ( | ||||
| 	DeepL | ||||
| 	VertexAI | ||||
| 	Proxy | ||||
| 	Replicate | ||||
|  | ||||
| 	Dummy // this one is only for count, do not add any channel after this | ||||
| ) | ||||
|   | ||||
| @@ -3,6 +3,7 @@ package billing | ||||
| import ( | ||||
| 	"context" | ||||
| 	"fmt" | ||||
|  | ||||
| 	"github.com/songquanpeng/one-api/common/logger" | ||||
| 	"github.com/songquanpeng/one-api/model" | ||||
| ) | ||||
| @@ -31,8 +32,17 @@ func PostConsumeQuota(ctx context.Context, tokenId int, quotaDelta int64, totalQ | ||||
| 	} | ||||
| 	// totalQuota is total quota consumed | ||||
| 	if totalQuota != 0 { | ||||
| 		logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio) | ||||
| 		model.RecordConsumeLog(ctx, userId, channelId, int(totalQuota), 0, modelName, tokenName, totalQuota, logContent) | ||||
| 		logContent := fmt.Sprintf("倍率:%.2f × %.2f", modelRatio, groupRatio) | ||||
| 		model.RecordConsumeLog(ctx, &model.Log{ | ||||
| 			UserId:           userId, | ||||
| 			ChannelId:        channelId, | ||||
| 			PromptTokens:     int(totalQuota), | ||||
| 			CompletionTokens: 0, | ||||
| 			ModelName:        modelName, | ||||
| 			TokenName:        tokenName, | ||||
| 			Quota:            int(totalQuota), | ||||
| 			Content:          logContent, | ||||
| 		}) | ||||
| 		model.UpdateUserUsedQuotaAndRequestCount(userId, totalQuota) | ||||
| 		model.UpdateChannelUsedQuota(channelId, totalQuota) | ||||
| 	} | ||||
|   | ||||
| @@ -30,6 +30,14 @@ var ImageSizeRatios = map[string]map[string]float64{ | ||||
| 		"720x1280":  1, | ||||
| 		"1280x720":  1, | ||||
| 	}, | ||||
| 	"step-1x-medium": { | ||||
| 		"256x256":   1, | ||||
| 		"512x512":   1, | ||||
| 		"768x768":   1, | ||||
| 		"1024x1024": 1, | ||||
| 		"1280x800":  1, | ||||
| 		"800x1280":  1, | ||||
| 	}, | ||||
| } | ||||
|  | ||||
| var ImageGenerationAmounts = map[string][2]int{ | ||||
| @@ -39,6 +47,7 @@ var ImageGenerationAmounts = map[string][2]int{ | ||||
| 	"ali-stable-diffusion-v1.5": {1, 4}, // Ali | ||||
| 	"wanx-v1":                   {1, 4}, // Ali | ||||
| 	"cogview-3":                 {1, 1}, | ||||
| 	"step-1x-medium":            {1, 1}, | ||||
| } | ||||
|  | ||||
| var ImagePromptLengthLimitations = map[string]int{ | ||||
| @@ -48,6 +57,7 @@ var ImagePromptLengthLimitations = map[string]int{ | ||||
| 	"ali-stable-diffusion-v1.5": 4000, | ||||
| 	"wanx-v1":                   4000, | ||||
| 	"cogview-3":                 833, | ||||
| 	"step-1x-medium":            4000, | ||||
| } | ||||
|  | ||||
| var ImageOriginModelName = map[string]string{ | ||||
|   | ||||
| @@ -9,9 +9,10 @@ import ( | ||||
| ) | ||||
|  | ||||
| const ( | ||||
| 	USD2RMB = 7 | ||||
| 	USD     = 500 // $0.002 = 1 -> $1 = 500 | ||||
| 	RMB     = USD / USD2RMB | ||||
| 	USD2RMB   = 7 | ||||
| 	USD       = 500 // $0.002 = 1 -> $1 = 500 | ||||
| 	MILLI_USD = 1.0 / 1000 * USD | ||||
| 	RMB       = USD / USD2RMB | ||||
| ) | ||||
|  | ||||
| // ModelRatio | ||||
| @@ -34,7 +35,10 @@ var ModelRatio = map[string]float64{ | ||||
| 	"gpt-4-turbo":             5,     // $0.01 / 1K tokens | ||||
| 	"gpt-4-turbo-2024-04-09":  5,     // $0.01 / 1K tokens | ||||
| 	"gpt-4o":                  2.5,   // $0.005 / 1K tokens | ||||
| 	"chatgpt-4o-latest":       2.5,   // $0.005 / 1K tokens | ||||
| 	"gpt-4o-2024-05-13":       2.5,   // $0.005 / 1K tokens | ||||
| 	"gpt-4o-2024-08-06":       1.25,  // $0.0025 / 1K tokens | ||||
| 	"gpt-4o-2024-11-20":       1.25,  // $0.0025 / 1K tokens | ||||
| 	"gpt-4o-mini":             0.075, // $0.00015 / 1K tokens | ||||
| 	"gpt-4o-mini-2024-07-18":  0.075, // $0.00015 / 1K tokens | ||||
| 	"gpt-4-vision-preview":    5,     // $0.01 / 1K tokens | ||||
| @@ -46,8 +50,14 @@ var ModelRatio = map[string]float64{ | ||||
| 	"gpt-3.5-turbo-instruct":  0.75, // $0.0015 / 1K tokens | ||||
| 	"gpt-3.5-turbo-1106":      0.5,  // $0.001 / 1K tokens | ||||
| 	"gpt-3.5-turbo-0125":      0.25, // $0.0005 / 1K tokens | ||||
| 	"davinci-002":             1,    // $0.002 / 1K tokens | ||||
| 	"babbage-002":             0.2,  // $0.0004 / 1K tokens | ||||
| 	"o1":                      7.5,  // $15.00 / 1M input tokens | ||||
| 	"o1-2024-12-17":           7.5, | ||||
| 	"o1-preview":              7.5, // $15.00 / 1M input tokens | ||||
| 	"o1-preview-2024-09-12":   7.5, | ||||
| 	"o1-mini":                 1.5, // $3.00 / 1M input tokens | ||||
| 	"o1-mini-2024-09-12":      1.5, | ||||
| 	"davinci-002":             1,   // $0.002 / 1K tokens | ||||
| 	"babbage-002":             0.2, // $0.0004 / 1K tokens | ||||
| 	"text-ada-001":            0.2, | ||||
| 	"text-babbage-001":        0.25, | ||||
| 	"text-curie-001":          1, | ||||
| @@ -77,8 +87,10 @@ var ModelRatio = map[string]float64{ | ||||
| 	"claude-2.0":                 8.0 / 1000 * USD, | ||||
| 	"claude-2.1":                 8.0 / 1000 * USD, | ||||
| 	"claude-3-haiku-20240307":    0.25 / 1000 * USD, | ||||
| 	"claude-3-5-haiku-20241022":  1.0 / 1000 * USD, | ||||
| 	"claude-3-sonnet-20240229":   3.0 / 1000 * USD, | ||||
| 	"claude-3-5-sonnet-20240620": 3.0 / 1000 * USD, | ||||
| 	"claude-3-5-sonnet-20241022": 3.0 / 1000 * USD, | ||||
| 	"claude-3-opus-20240229":     15.0 / 1000 * USD, | ||||
| 	// https://cloud.baidu.com/doc/WENXINWORKSHOP/s/hlrk4akp7 | ||||
| 	"ERNIE-4.0-8K":       0.120 * RMB, | ||||
| @@ -98,12 +110,16 @@ var ModelRatio = map[string]float64{ | ||||
| 	"bge-large-en":       0.002 * RMB, | ||||
| 	"tao-8k":             0.002 * RMB, | ||||
| 	// https://ai.google.dev/pricing | ||||
| 	"PaLM-2":                    1, | ||||
| 	"gemini-pro":                1, // $0.00025 / 1k characters -> $0.001 / 1k tokens | ||||
| 	"gemini-pro-vision":         1, // $0.00025 / 1k characters -> $0.001 / 1k tokens | ||||
| 	"gemini-1.0-pro-vision-001": 1, | ||||
| 	"gemini-1.0-pro-001":        1, | ||||
| 	"gemini-1.5-pro":            1, | ||||
| 	"gemini-pro":                          1, // $0.00025 / 1k characters -> $0.001 / 1k tokens | ||||
| 	"gemini-1.0-pro":                      1, | ||||
| 	"gemini-1.5-pro":                      1, | ||||
| 	"gemini-1.5-pro-001":                  1, | ||||
| 	"gemini-1.5-flash":                    1, | ||||
| 	"gemini-1.5-flash-001":                1, | ||||
| 	"gemini-2.0-flash-exp":                1, | ||||
| 	"gemini-2.0-flash-thinking-exp":       1, | ||||
| 	"gemini-2.0-flash-thinking-exp-01-21": 1, | ||||
| 	"aqa":                                 1, | ||||
| 	// https://open.bigmodel.cn/pricing | ||||
| 	"glm-4":         0.1 * RMB, | ||||
| 	"glm-4v":        0.1 * RMB, | ||||
| @@ -115,27 +131,94 @@ var ModelRatio = map[string]float64{ | ||||
| 	"chatglm_lite":  0.1429, // ¥0.002 / 1k tokens | ||||
| 	"cogview-3":     0.25 * RMB, | ||||
| 	// https://help.aliyun.com/zh/dashscope/developer-reference/tongyi-thousand-questions-metering-and-billing | ||||
| 	"qwen-turbo":                0.5715, // ¥0.008 / 1k tokens | ||||
| 	"qwen-plus":                 1.4286, // ¥0.02 / 1k tokens | ||||
| 	"qwen-max":                  1.4286, // ¥0.02 / 1k tokens | ||||
| 	"qwen-max-longcontext":      1.4286, // ¥0.02 / 1k tokens | ||||
| 	"text-embedding-v1":         0.05,   // ¥0.0007 / 1k tokens | ||||
| 	"ali-stable-diffusion-xl":   8, | ||||
| 	"ali-stable-diffusion-v1.5": 8, | ||||
| 	"wanx-v1":                   8, | ||||
| 	"SparkDesk":                 1.2858, // ¥0.018 / 1k tokens | ||||
| 	"SparkDesk-v1.1":            1.2858, // ¥0.018 / 1k tokens | ||||
| 	"SparkDesk-v2.1":            1.2858, // ¥0.018 / 1k tokens | ||||
| 	"SparkDesk-v3.1":            1.2858, // ¥0.018 / 1k tokens | ||||
| 	"SparkDesk-v3.5":            1.2858, // ¥0.018 / 1k tokens | ||||
| 	"SparkDesk-v4.0":            1.2858, // ¥0.018 / 1k tokens | ||||
| 	"360GPT_S2_V9":              0.8572, // ¥0.012 / 1k tokens | ||||
| 	"embedding-bert-512-v1":     0.0715, // ¥0.001 / 1k tokens | ||||
| 	"embedding_s1_v1":           0.0715, // ¥0.001 / 1k tokens | ||||
| 	"semantic_similarity_s1_v1": 0.0715, // ¥0.001 / 1k tokens | ||||
| 	"hunyuan":                   7.143,  // ¥0.1 / 1k tokens  // https://cloud.tencent.com/document/product/1729/97731#e0e6be58-60c8-469f-bdeb-6c264ce3b4d0 | ||||
| 	"ChatStd":                   0.01 * RMB, | ||||
| 	"ChatPro":                   0.1 * RMB, | ||||
| 	"qwen-turbo":                  1.4286, // ¥0.02 / 1k tokens | ||||
| 	"qwen-turbo-latest":           1.4286, | ||||
| 	"qwen-plus":                   1.4286, | ||||
| 	"qwen-plus-latest":            1.4286, | ||||
| 	"qwen-max":                    1.4286, | ||||
| 	"qwen-max-latest":             1.4286, | ||||
| 	"qwen-max-longcontext":        1.4286, | ||||
| 	"qwen-vl-max":                 1.4286, | ||||
| 	"qwen-vl-max-latest":          1.4286, | ||||
| 	"qwen-vl-plus":                1.4286, | ||||
| 	"qwen-vl-plus-latest":         1.4286, | ||||
| 	"qwen-vl-ocr":                 1.4286, | ||||
| 	"qwen-vl-ocr-latest":          1.4286, | ||||
| 	"qwen-audio-turbo":            1.4286, | ||||
| 	"qwen-math-plus":              1.4286, | ||||
| 	"qwen-math-plus-latest":       1.4286, | ||||
| 	"qwen-math-turbo":             1.4286, | ||||
| 	"qwen-math-turbo-latest":      1.4286, | ||||
| 	"qwen-coder-plus":             1.4286, | ||||
| 	"qwen-coder-plus-latest":      1.4286, | ||||
| 	"qwen-coder-turbo":            1.4286, | ||||
| 	"qwen-coder-turbo-latest":     1.4286, | ||||
| 	"qwq-32b-preview":             1.4286, | ||||
| 	"qwen2.5-72b-instruct":        1.4286, | ||||
| 	"qwen2.5-32b-instruct":        1.4286, | ||||
| 	"qwen2.5-14b-instruct":        1.4286, | ||||
| 	"qwen2.5-7b-instruct":         1.4286, | ||||
| 	"qwen2.5-3b-instruct":         1.4286, | ||||
| 	"qwen2.5-1.5b-instruct":       1.4286, | ||||
| 	"qwen2.5-0.5b-instruct":       1.4286, | ||||
| 	"qwen2-72b-instruct":          1.4286, | ||||
| 	"qwen2-57b-a14b-instruct":     1.4286, | ||||
| 	"qwen2-7b-instruct":           1.4286, | ||||
| 	"qwen2-1.5b-instruct":         1.4286, | ||||
| 	"qwen2-0.5b-instruct":         1.4286, | ||||
| 	"qwen1.5-110b-chat":           1.4286, | ||||
| 	"qwen1.5-72b-chat":            1.4286, | ||||
| 	"qwen1.5-32b-chat":            1.4286, | ||||
| 	"qwen1.5-14b-chat":            1.4286, | ||||
| 	"qwen1.5-7b-chat":             1.4286, | ||||
| 	"qwen1.5-1.8b-chat":           1.4286, | ||||
| 	"qwen1.5-0.5b-chat":           1.4286, | ||||
| 	"qwen-72b-chat":               1.4286, | ||||
| 	"qwen-14b-chat":               1.4286, | ||||
| 	"qwen-7b-chat":                1.4286, | ||||
| 	"qwen-1.8b-chat":              1.4286, | ||||
| 	"qwen-1.8b-longcontext-chat":  1.4286, | ||||
| 	"qwen2-vl-7b-instruct":        1.4286, | ||||
| 	"qwen2-vl-2b-instruct":        1.4286, | ||||
| 	"qwen-vl-v1":                  1.4286, | ||||
| 	"qwen-vl-chat-v1":             1.4286, | ||||
| 	"qwen2-audio-instruct":        1.4286, | ||||
| 	"qwen-audio-chat":             1.4286, | ||||
| 	"qwen2.5-math-72b-instruct":   1.4286, | ||||
| 	"qwen2.5-math-7b-instruct":    1.4286, | ||||
| 	"qwen2.5-math-1.5b-instruct":  1.4286, | ||||
| 	"qwen2-math-72b-instruct":     1.4286, | ||||
| 	"qwen2-math-7b-instruct":      1.4286, | ||||
| 	"qwen2-math-1.5b-instruct":    1.4286, | ||||
| 	"qwen2.5-coder-32b-instruct":  1.4286, | ||||
| 	"qwen2.5-coder-14b-instruct":  1.4286, | ||||
| 	"qwen2.5-coder-7b-instruct":   1.4286, | ||||
| 	"qwen2.5-coder-3b-instruct":   1.4286, | ||||
| 	"qwen2.5-coder-1.5b-instruct": 1.4286, | ||||
| 	"qwen2.5-coder-0.5b-instruct": 1.4286, | ||||
| 	"text-embedding-v1":           0.05, // ¥0.0007 / 1k tokens | ||||
| 	"text-embedding-v3":           0.05, | ||||
| 	"text-embedding-v2":           0.05, | ||||
| 	"text-embedding-async-v2":     0.05, | ||||
| 	"text-embedding-async-v1":     0.05, | ||||
| 	"ali-stable-diffusion-xl":     8.00, | ||||
| 	"ali-stable-diffusion-v1.5":   8.00, | ||||
| 	"wanx-v1":                     8.00, | ||||
| 	"SparkDesk":                   1.2858, // ¥0.018 / 1k tokens | ||||
| 	"SparkDesk-v1.1":              1.2858, // ¥0.018 / 1k tokens | ||||
| 	"SparkDesk-v2.1":              1.2858, // ¥0.018 / 1k tokens | ||||
| 	"SparkDesk-v3.1":              1.2858, // ¥0.018 / 1k tokens | ||||
| 	"SparkDesk-v3.1-128K":         1.2858, // ¥0.018 / 1k tokens | ||||
| 	"SparkDesk-v3.5":              1.2858, // ¥0.018 / 1k tokens | ||||
| 	"SparkDesk-v3.5-32K":          1.2858, // ¥0.018 / 1k tokens | ||||
| 	"SparkDesk-v4.0":              1.2858, // ¥0.018 / 1k tokens | ||||
| 	"360GPT_S2_V9":                0.8572, // ¥0.012 / 1k tokens | ||||
| 	"embedding-bert-512-v1":       0.0715, // ¥0.001 / 1k tokens | ||||
| 	"embedding_s1_v1":             0.0715, // ¥0.001 / 1k tokens | ||||
| 	"semantic_similarity_s1_v1":   0.0715, // ¥0.001 / 1k tokens | ||||
| 	"hunyuan":                     7.143,  // ¥0.1 / 1k tokens  // https://cloud.tencent.com/document/product/1729/97731#e0e6be58-60c8-469f-bdeb-6c264ce3b4d0 | ||||
| 	"ChatStd":                     0.01 * RMB, | ||||
| 	"ChatPro":                     0.1 * RMB, | ||||
| 	// https://platform.moonshot.cn/pricing | ||||
| 	"moonshot-v1-8k":   0.012 * RMB, | ||||
| 	"moonshot-v1-32k":  0.024 * RMB, | ||||
| @@ -158,20 +241,35 @@ var ModelRatio = map[string]float64{ | ||||
| 	"mistral-large-latest":  8.0 / 1000 * USD, | ||||
| 	"mistral-embed":         0.1 / 1000 * USD, | ||||
| 	// https://wow.groq.com/#:~:text=inquiries%C2%A0here.-,Model,-Current%20Speed | ||||
| 	"llama3-70b-8192":    0.59 / 1000 * USD, | ||||
| 	"mixtral-8x7b-32768": 0.27 / 1000 * USD, | ||||
| 	"llama3-8b-8192":     0.05 / 1000 * USD, | ||||
| 	"gemma-7b-it":        0.1 / 1000 * USD, | ||||
| 	"llama2-70b-4096":    0.64 / 1000 * USD, | ||||
| 	"llama2-7b-2048":     0.1 / 1000 * USD, | ||||
| 	"gemma-7b-it":                           0.07 / 1000000 * USD, | ||||
| 	"gemma2-9b-it":                          0.20 / 1000000 * USD, | ||||
| 	"llama-3.1-70b-versatile":               0.59 / 1000000 * USD, | ||||
| 	"llama-3.1-8b-instant":                  0.05 / 1000000 * USD, | ||||
| 	"llama-3.2-11b-text-preview":            0.05 / 1000000 * USD, | ||||
| 	"llama-3.2-11b-vision-preview":          0.05 / 1000000 * USD, | ||||
| 	"llama-3.2-1b-preview":                  0.05 / 1000000 * USD, | ||||
| 	"llama-3.2-3b-preview":                  0.05 / 1000000 * USD, | ||||
| 	"llama-3.2-90b-text-preview":            0.59 / 1000000 * USD, | ||||
| 	"llama-guard-3-8b":                      0.05 / 1000000 * USD, | ||||
| 	"llama3-70b-8192":                       0.59 / 1000000 * USD, | ||||
| 	"llama3-8b-8192":                        0.05 / 1000000 * USD, | ||||
| 	"llama3-groq-70b-8192-tool-use-preview": 0.89 / 1000000 * USD, | ||||
| 	"llama3-groq-8b-8192-tool-use-preview":  0.19 / 1000000 * USD, | ||||
| 	"mixtral-8x7b-32768":                    0.24 / 1000000 * USD, | ||||
|  | ||||
| 	// https://platform.lingyiwanwu.com/docs#-计费单元 | ||||
| 	"yi-34b-chat-0205": 2.5 / 1000 * RMB, | ||||
| 	"yi-34b-chat-200k": 12.0 / 1000 * RMB, | ||||
| 	"yi-vl-plus":       6.0 / 1000 * RMB, | ||||
| 	// stepfun todo | ||||
| 	"step-1v-32k": 0.024 * RMB, | ||||
| 	"step-1-32k":  0.024 * RMB, | ||||
| 	"step-1-200k": 0.15 * RMB, | ||||
| 	// https://platform.stepfun.com/docs/pricing/details | ||||
| 	"step-1-8k":    0.005 / 1000 * RMB, | ||||
| 	"step-1-32k":   0.015 / 1000 * RMB, | ||||
| 	"step-1-128k":  0.040 / 1000 * RMB, | ||||
| 	"step-1-256k":  0.095 / 1000 * RMB, | ||||
| 	"step-1-flash": 0.001 / 1000 * RMB, | ||||
| 	"step-2-16k":   0.038 / 1000 * RMB, | ||||
| 	"step-1v-8k":   0.005 / 1000 * RMB, | ||||
| 	"step-1v-32k":  0.015 / 1000 * RMB, | ||||
| 	// aws llama3 https://aws.amazon.com/cn/bedrock/pricing/ | ||||
| 	"llama3-8b-8192(33)":  0.0003 / 0.002,  // $0.0003 / 1K tokens | ||||
| 	"llama3-70b-8192(33)": 0.00265 / 0.002, // $0.00265 / 1K tokens | ||||
| @@ -183,22 +281,75 @@ var ModelRatio = map[string]float64{ | ||||
| 	"command-r":             0.5 / 1000 * USD, | ||||
| 	"command-r-plus":        3.0 / 1000 * USD, | ||||
| 	// https://platform.deepseek.com/api-docs/pricing/ | ||||
| 	"deepseek-chat":  1.0 / 1000 * RMB, | ||||
| 	"deepseek-coder": 1.0 / 1000 * RMB, | ||||
| 	"deepseek-chat":     0.14 * MILLI_USD, | ||||
| 	"deepseek-reasoner": 0.55 * MILLI_USD, | ||||
| 	// https://www.deepl.com/pro?cta=header-prices | ||||
| 	"deepl-zh": 25.0 / 1000 * USD, | ||||
| 	"deepl-en": 25.0 / 1000 * USD, | ||||
| 	"deepl-ja": 25.0 / 1000 * USD, | ||||
| 	// https://console.x.ai/ | ||||
| 	"grok-beta": 5.0 / 1000 * USD, | ||||
| 	// replicate charges based on the number of generated images | ||||
| 	// https://replicate.com/pricing | ||||
| 	"black-forest-labs/flux-1.1-pro":                0.04 * USD, | ||||
| 	"black-forest-labs/flux-1.1-pro-ultra":          0.06 * USD, | ||||
| 	"black-forest-labs/flux-canny-dev":              0.025 * USD, | ||||
| 	"black-forest-labs/flux-canny-pro":              0.05 * USD, | ||||
| 	"black-forest-labs/flux-depth-dev":              0.025 * USD, | ||||
| 	"black-forest-labs/flux-depth-pro":              0.05 * USD, | ||||
| 	"black-forest-labs/flux-dev":                    0.025 * USD, | ||||
| 	"black-forest-labs/flux-dev-lora":               0.032 * USD, | ||||
| 	"black-forest-labs/flux-fill-dev":               0.04 * USD, | ||||
| 	"black-forest-labs/flux-fill-pro":               0.05 * USD, | ||||
| 	"black-forest-labs/flux-pro":                    0.055 * USD, | ||||
| 	"black-forest-labs/flux-redux-dev":              0.025 * USD, | ||||
| 	"black-forest-labs/flux-redux-schnell":          0.003 * USD, | ||||
| 	"black-forest-labs/flux-schnell":                0.003 * USD, | ||||
| 	"black-forest-labs/flux-schnell-lora":           0.02 * USD, | ||||
| 	"ideogram-ai/ideogram-v2":                       0.08 * USD, | ||||
| 	"ideogram-ai/ideogram-v2-turbo":                 0.05 * USD, | ||||
| 	"recraft-ai/recraft-v3":                         0.04 * USD, | ||||
| 	"recraft-ai/recraft-v3-svg":                     0.08 * USD, | ||||
| 	"stability-ai/stable-diffusion-3":               0.035 * USD, | ||||
| 	"stability-ai/stable-diffusion-3.5-large":       0.065 * USD, | ||||
| 	"stability-ai/stable-diffusion-3.5-large-turbo": 0.04 * USD, | ||||
| 	"stability-ai/stable-diffusion-3.5-medium":      0.035 * USD, | ||||
| 	// replicate chat models | ||||
| 	"ibm-granite/granite-20b-code-instruct-8k":  0.100 * USD, | ||||
| 	"ibm-granite/granite-3.0-2b-instruct":       0.030 * USD, | ||||
| 	"ibm-granite/granite-3.0-8b-instruct":       0.050 * USD, | ||||
| 	"ibm-granite/granite-8b-code-instruct-128k": 0.050 * USD, | ||||
| 	"meta/llama-2-13b":                          0.100 * USD, | ||||
| 	"meta/llama-2-13b-chat":                     0.100 * USD, | ||||
| 	"meta/llama-2-70b":                          0.650 * USD, | ||||
| 	"meta/llama-2-70b-chat":                     0.650 * USD, | ||||
| 	"meta/llama-2-7b":                           0.050 * USD, | ||||
| 	"meta/llama-2-7b-chat":                      0.050 * USD, | ||||
| 	"meta/meta-llama-3.1-405b-instruct":         9.500 * USD, | ||||
| 	"meta/meta-llama-3-70b":                     0.650 * USD, | ||||
| 	"meta/meta-llama-3-70b-instruct":            0.650 * USD, | ||||
| 	"meta/meta-llama-3-8b":                      0.050 * USD, | ||||
| 	"meta/meta-llama-3-8b-instruct":             0.050 * USD, | ||||
| 	"mistralai/mistral-7b-instruct-v0.2":        0.050 * USD, | ||||
| 	"mistralai/mistral-7b-v0.1":                 0.050 * USD, | ||||
| 	"mistralai/mixtral-8x7b-instruct-v0.1":      0.300 * USD, | ||||
| } | ||||
|  | ||||
| var CompletionRatio = map[string]float64{ | ||||
| 	// aws llama3 | ||||
| 	"llama3-8b-8192(33)":  0.0006 / 0.0003, | ||||
| 	"llama3-70b-8192(33)": 0.0035 / 0.00265, | ||||
| 	// whisper | ||||
| 	"whisper-1": 0, // only count input tokens | ||||
| 	// deepseek | ||||
| 	"deepseek-chat":     0.28 / 0.14, | ||||
| 	"deepseek-reasoner": 2.19 / 0.55, | ||||
| } | ||||
|  | ||||
| var DefaultModelRatio map[string]float64 | ||||
| var DefaultCompletionRatio map[string]float64 | ||||
| var ( | ||||
| 	DefaultModelRatio      map[string]float64 | ||||
| 	DefaultCompletionRatio map[string]float64 | ||||
| ) | ||||
|  | ||||
| func init() { | ||||
| 	DefaultModelRatio = make(map[string]float64) | ||||
| @@ -310,16 +461,25 @@ func GetCompletionRatio(name string, channelType int) float64 { | ||||
| 		return 4.0 / 3.0 | ||||
| 	} | ||||
| 	if strings.HasPrefix(name, "gpt-4") { | ||||
| 		if strings.HasPrefix(name, "gpt-4o-mini") { | ||||
| 		if strings.HasPrefix(name, "gpt-4o") { | ||||
| 			if name == "gpt-4o-2024-05-13" { | ||||
| 				return 3 | ||||
| 			} | ||||
| 			return 4 | ||||
| 		} | ||||
| 		if strings.HasPrefix(name, "gpt-4-turbo") || | ||||
| 			strings.HasPrefix(name, "gpt-4o") || | ||||
| 			strings.HasSuffix(name, "preview") { | ||||
| 			return 3 | ||||
| 		} | ||||
| 		return 2 | ||||
| 	} | ||||
| 	// including o1, o1-preview, o1-mini | ||||
| 	if strings.HasPrefix(name, "o1") { | ||||
| 		return 4 | ||||
| 	} | ||||
| 	if name == "chatgpt-4o-latest" { | ||||
| 		return 3 | ||||
| 	} | ||||
| 	if strings.HasPrefix(name, "claude-3") { | ||||
| 		return 5 | ||||
| 	} | ||||
| @@ -335,6 +495,7 @@ func GetCompletionRatio(name string, channelType int) float64 { | ||||
| 	if strings.HasPrefix(name, "deepseek-") { | ||||
| 		return 2 | ||||
| 	} | ||||
|  | ||||
| 	switch name { | ||||
| 	case "llama2-70b-4096": | ||||
| 		return 0.8 / 0.64 | ||||
| @@ -348,6 +509,37 @@ func GetCompletionRatio(name string, channelType int) float64 { | ||||
| 		return 3 | ||||
| 	case "command-r-plus": | ||||
| 		return 5 | ||||
| 	case "grok-beta": | ||||
| 		return 3 | ||||
| 	// Replicate Models | ||||
| 	// https://replicate.com/pricing | ||||
| 	case "ibm-granite/granite-20b-code-instruct-8k": | ||||
| 		return 5 | ||||
| 	case "ibm-granite/granite-3.0-2b-instruct": | ||||
| 		return 8.333333333333334 | ||||
| 	case "ibm-granite/granite-3.0-8b-instruct", | ||||
| 		"ibm-granite/granite-8b-code-instruct-128k": | ||||
| 		return 5 | ||||
| 	case "meta/llama-2-13b", | ||||
| 		"meta/llama-2-13b-chat", | ||||
| 		"meta/llama-2-7b", | ||||
| 		"meta/llama-2-7b-chat", | ||||
| 		"meta/meta-llama-3-8b", | ||||
| 		"meta/meta-llama-3-8b-instruct": | ||||
| 		return 5 | ||||
| 	case "meta/llama-2-70b", | ||||
| 		"meta/llama-2-70b-chat", | ||||
| 		"meta/meta-llama-3-70b", | ||||
| 		"meta/meta-llama-3-70b-instruct": | ||||
| 		return 2.750 / 0.650 // ≈4.230769 | ||||
| 	case "meta/meta-llama-3.1-405b-instruct": | ||||
| 		return 1 | ||||
| 	case "mistralai/mistral-7b-instruct-v0.2", | ||||
| 		"mistralai/mistral-7b-v0.1": | ||||
| 		return 5 | ||||
| 	case "mistralai/mixtral-8x7b-instruct-v0.1": | ||||
| 		return 1.000 / 0.300 // ≈3.333333 | ||||
| 	} | ||||
|  | ||||
| 	return 1 | ||||
| } | ||||
|   | ||||
| @@ -45,5 +45,8 @@ const ( | ||||
| 	Novita | ||||
| 	VertextAI | ||||
| 	Proxy | ||||
| 	SiliconFlow | ||||
| 	XAI | ||||
| 	Replicate | ||||
| 	Dummy | ||||
| ) | ||||
|   | ||||
| @@ -37,6 +37,8 @@ func ToAPIType(channelType int) int { | ||||
| 		apiType = apitype.DeepL | ||||
| 	case VertextAI: | ||||
| 		apiType = apitype.VertexAI | ||||
| 	case Replicate: | ||||
| 		apiType = apitype.Replicate | ||||
| 	case Proxy: | ||||
| 		apiType = apitype.Proxy | ||||
| 	} | ||||
|   | ||||
| @@ -45,6 +45,9 @@ var ChannelBaseURLs = []string{ | ||||
| 	"https://api.novita.ai/v3/openai",           // 41 | ||||
| 	"",                                          // 42 | ||||
| 	"",                                          // 43 | ||||
| 	"https://api.siliconflow.cn",                // 44 | ||||
| 	"https://api.x.ai",                          // 45 | ||||
| 	"https://api.replicate.com/v1/models/",      // 46 | ||||
| } | ||||
|  | ||||
| func init() { | ||||
|   | ||||
| @@ -1,5 +1,6 @@ | ||||
| package role | ||||
|  | ||||
| const ( | ||||
| 	System    = "system" | ||||
| 	Assistant = "assistant" | ||||
| ) | ||||
|   | ||||
| @@ -110,16 +110,9 @@ func RelayAudioHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus | ||||
| 	}() | ||||
|  | ||||
| 	// map model name | ||||
| 	modelMapping := c.GetString(ctxkey.ModelMapping) | ||||
| 	if modelMapping != "" { | ||||
| 		modelMap := make(map[string]string) | ||||
| 		err := json.Unmarshal([]byte(modelMapping), &modelMap) | ||||
| 		if err != nil { | ||||
| 			return openai.ErrorWrapper(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError) | ||||
| 		} | ||||
| 		if modelMap[audioModel] != "" { | ||||
| 			audioModel = modelMap[audioModel] | ||||
| 		} | ||||
| 	modelMapping := c.GetStringMapString(ctxkey.ModelMapping) | ||||
| 	if modelMapping != nil && modelMapping[audioModel] != "" { | ||||
| 		audioModel = modelMapping[audioModel] | ||||
| 	} | ||||
|  | ||||
| 	baseURL := channeltype.ChannelBaseURLs[channelType] | ||||
|   | ||||
| @@ -8,7 +8,11 @@ import ( | ||||
| 	"net/http" | ||||
| 	"strings" | ||||
|  | ||||
| 	"github.com/songquanpeng/one-api/common/helper" | ||||
| 	"github.com/songquanpeng/one-api/relay/constant/role" | ||||
|  | ||||
| 	"github.com/gin-gonic/gin" | ||||
|  | ||||
| 	"github.com/songquanpeng/one-api/common" | ||||
| 	"github.com/songquanpeng/one-api/common/config" | ||||
| 	"github.com/songquanpeng/one-api/common/logger" | ||||
| @@ -90,7 +94,7 @@ func preConsumeQuota(ctx context.Context, textRequest *relaymodel.GeneralOpenAIR | ||||
| 	return preConsumedQuota, nil | ||||
| } | ||||
|  | ||||
| func postConsumeQuota(ctx context.Context, usage *relaymodel.Usage, meta *meta.Meta, textRequest *relaymodel.GeneralOpenAIRequest, ratio float64, preConsumedQuota int64, modelRatio float64, groupRatio float64) { | ||||
| func postConsumeQuota(ctx context.Context, usage *relaymodel.Usage, meta *meta.Meta, textRequest *relaymodel.GeneralOpenAIRequest, ratio float64, preConsumedQuota int64, modelRatio float64, groupRatio float64, systemPromptReset bool) { | ||||
| 	if usage == nil { | ||||
| 		logger.Error(ctx, "usage is nil, which is unexpected") | ||||
| 		return | ||||
| @@ -118,8 +122,20 @@ func postConsumeQuota(ctx context.Context, usage *relaymodel.Usage, meta *meta.M | ||||
| 	if err != nil { | ||||
| 		logger.Error(ctx, "error update user quota cache: "+err.Error()) | ||||
| 	} | ||||
| 	logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f,补全倍率 %.2f", modelRatio, groupRatio, completionRatio) | ||||
| 	model.RecordConsumeLog(ctx, meta.UserId, meta.ChannelId, promptTokens, completionTokens, textRequest.Model, meta.TokenName, quota, logContent) | ||||
| 	logContent := fmt.Sprintf("倍率:%.2f × %.2f × %.2f", modelRatio, groupRatio, completionRatio) | ||||
| 	model.RecordConsumeLog(ctx, &model.Log{ | ||||
| 		UserId:            meta.UserId, | ||||
| 		ChannelId:         meta.ChannelId, | ||||
| 		PromptTokens:      promptTokens, | ||||
| 		CompletionTokens:  completionTokens, | ||||
| 		ModelName:         textRequest.Model, | ||||
| 		TokenName:         meta.TokenName, | ||||
| 		Quota:             int(quota), | ||||
| 		Content:           logContent, | ||||
| 		IsStream:          meta.IsStream, | ||||
| 		ElapsedTime:       helper.CalcElapsedTime(meta.StartTime), | ||||
| 		SystemPromptReset: systemPromptReset, | ||||
| 	}) | ||||
| 	model.UpdateUserUsedQuotaAndRequestCount(meta.UserId, quota) | ||||
| 	model.UpdateChannelUsedQuota(meta.ChannelId, quota) | ||||
| } | ||||
| @@ -142,15 +158,41 @@ func isErrorHappened(meta *meta.Meta, resp *http.Response) bool { | ||||
| 		} | ||||
| 		return true | ||||
| 	} | ||||
| 	if resp.StatusCode != http.StatusOK { | ||||
| 	if resp.StatusCode != http.StatusOK && | ||||
| 		// replicate return 201 to create a task | ||||
| 		resp.StatusCode != http.StatusCreated { | ||||
| 		return true | ||||
| 	} | ||||
| 	if meta.ChannelType == channeltype.DeepL { | ||||
| 		// skip stream check for deepl | ||||
| 		return false | ||||
| 	} | ||||
| 	if meta.IsStream && strings.HasPrefix(resp.Header.Get("Content-Type"), "application/json") { | ||||
|  | ||||
| 	if meta.IsStream && strings.HasPrefix(resp.Header.Get("Content-Type"), "application/json") && | ||||
| 		// Even if stream mode is enabled, replicate will first return a task info in JSON format, | ||||
| 		// requiring the client to request the stream endpoint in the task info | ||||
| 		meta.ChannelType != channeltype.Replicate { | ||||
| 		return true | ||||
| 	} | ||||
| 	return false | ||||
| } | ||||
|  | ||||
| func setSystemPrompt(ctx context.Context, request *relaymodel.GeneralOpenAIRequest, prompt string) (reset bool) { | ||||
| 	if prompt == "" { | ||||
| 		return false | ||||
| 	} | ||||
| 	if len(request.Messages) == 0 { | ||||
| 		return false | ||||
| 	} | ||||
| 	if request.Messages[0].Role == role.System { | ||||
| 		request.Messages[0].Content = prompt | ||||
| 		logger.Infof(ctx, "rewrite system prompt") | ||||
| 		return true | ||||
| 	} | ||||
| 	request.Messages = append([]relaymodel.Message{{ | ||||
| 		Role:    role.System, | ||||
| 		Content: prompt, | ||||
| 	}}, request.Messages...) | ||||
| 	logger.Infof(ctx, "add system prompt") | ||||
| 	return true | ||||
| } | ||||
|   | ||||
| @@ -10,6 +10,7 @@ import ( | ||||
| 	"net/http" | ||||
|  | ||||
| 	"github.com/gin-gonic/gin" | ||||
|  | ||||
| 	"github.com/songquanpeng/one-api/common" | ||||
| 	"github.com/songquanpeng/one-api/common/ctxkey" | ||||
| 	"github.com/songquanpeng/one-api/common/logger" | ||||
| @@ -22,7 +23,7 @@ import ( | ||||
| 	relaymodel "github.com/songquanpeng/one-api/relay/model" | ||||
| ) | ||||
|  | ||||
| func getImageRequest(c *gin.Context, relayMode int) (*relaymodel.ImageRequest, error) { | ||||
| func getImageRequest(c *gin.Context, _ int) (*relaymodel.ImageRequest, error) { | ||||
| 	imageRequest := &relaymodel.ImageRequest{} | ||||
| 	err := common.UnmarshalBodyReusable(c, imageRequest) | ||||
| 	if err != nil { | ||||
| @@ -65,7 +66,7 @@ func getImageSizeRatio(model string, size string) float64 { | ||||
| 	return 1 | ||||
| } | ||||
|  | ||||
| func validateImageRequest(imageRequest *relaymodel.ImageRequest, meta *meta.Meta) *relaymodel.ErrorWithStatusCode { | ||||
| func validateImageRequest(imageRequest *relaymodel.ImageRequest, _ *meta.Meta) *relaymodel.ErrorWithStatusCode { | ||||
| 	// check prompt length | ||||
| 	if imageRequest.Prompt == "" { | ||||
| 		return openai.ErrorWrapper(errors.New("prompt is required"), "prompt_missing", http.StatusBadRequest) | ||||
| @@ -150,12 +151,12 @@ func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus | ||||
| 	} | ||||
| 	adaptor.Init(meta) | ||||
|  | ||||
| 	// these adaptors need to convert the request | ||||
| 	switch meta.ChannelType { | ||||
| 	case channeltype.Ali: | ||||
| 		fallthrough | ||||
| 	case channeltype.Baidu: | ||||
| 		fallthrough | ||||
| 	case channeltype.Zhipu: | ||||
| 	case channeltype.Zhipu, | ||||
| 		channeltype.Ali, | ||||
| 		channeltype.Replicate, | ||||
| 		channeltype.Baidu: | ||||
| 		finalRequest, err := adaptor.ConvertImageRequest(imageRequest) | ||||
| 		if err != nil { | ||||
| 			return openai.ErrorWrapper(err, "convert_image_request_failed", http.StatusInternalServerError) | ||||
| @@ -172,7 +173,14 @@ func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus | ||||
| 	ratio := modelRatio * groupRatio | ||||
| 	userQuota, err := model.CacheGetUserQuota(ctx, meta.UserId) | ||||
|  | ||||
| 	quota := int64(ratio*imageCostRatio*1000) * int64(imageRequest.N) | ||||
| 	var quota int64 | ||||
| 	switch meta.ChannelType { | ||||
| 	case channeltype.Replicate: | ||||
| 		// replicate always return 1 image | ||||
| 		quota = int64(ratio * imageCostRatio * 1000) | ||||
| 	default: | ||||
| 		quota = int64(ratio*imageCostRatio*1000) * int64(imageRequest.N) | ||||
| 	} | ||||
|  | ||||
| 	if userQuota-quota < 0 { | ||||
| 		return openai.ErrorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden) | ||||
| @@ -186,7 +194,9 @@ func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus | ||||
| 	} | ||||
|  | ||||
| 	defer func(ctx context.Context) { | ||||
| 		if resp != nil && resp.StatusCode != http.StatusOK { | ||||
| 		if resp != nil && | ||||
| 			resp.StatusCode != http.StatusCreated && // replicate returns 201 | ||||
| 			resp.StatusCode != http.StatusOK { | ||||
| 			return | ||||
| 		} | ||||
|  | ||||
| @@ -200,8 +210,17 @@ func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus | ||||
| 		} | ||||
| 		if quota != 0 { | ||||
| 			tokenName := c.GetString(ctxkey.TokenName) | ||||
| 			logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio) | ||||
| 			model.RecordConsumeLog(ctx, meta.UserId, meta.ChannelId, 0, 0, imageRequest.Model, tokenName, quota, logContent) | ||||
| 			logContent := fmt.Sprintf("倍率:%.2f × %.2f", modelRatio, groupRatio) | ||||
| 			model.RecordConsumeLog(ctx, &model.Log{ | ||||
| 				UserId:           meta.UserId, | ||||
| 				ChannelId:        meta.ChannelId, | ||||
| 				PromptTokens:     0, | ||||
| 				CompletionTokens: 0, | ||||
| 				ModelName:        imageRequest.Model, | ||||
| 				TokenName:        tokenName, | ||||
| 				Quota:            int(quota), | ||||
| 				Content:          logContent, | ||||
| 			}) | ||||
| 			model.UpdateUserUsedQuotaAndRequestCount(meta.UserId, quota) | ||||
| 			channelId := c.GetInt(ctxkey.ChannelId) | ||||
| 			model.UpdateChannelUsedQuota(channelId, quota) | ||||
|   | ||||
| @@ -8,6 +8,8 @@ import ( | ||||
| 	"net/http" | ||||
|  | ||||
| 	"github.com/gin-gonic/gin" | ||||
|  | ||||
| 	"github.com/songquanpeng/one-api/common/config" | ||||
| 	"github.com/songquanpeng/one-api/common/logger" | ||||
| 	"github.com/songquanpeng/one-api/relay" | ||||
| 	"github.com/songquanpeng/one-api/relay/adaptor" | ||||
| @@ -35,6 +37,8 @@ func RelayTextHelper(c *gin.Context) *model.ErrorWithStatusCode { | ||||
| 	meta.OriginModelName = textRequest.Model | ||||
| 	textRequest.Model, _ = getMappedModelName(textRequest.Model, meta.ModelMapping) | ||||
| 	meta.ActualModelName = textRequest.Model | ||||
| 	// set system prompt if not empty | ||||
| 	systemPromptReset := setSystemPrompt(ctx, textRequest, meta.SystemPrompt) | ||||
| 	// get model ratio & group ratio | ||||
| 	modelRatio := billingratio.GetModelRatio(textRequest.Model, meta.ChannelType) | ||||
| 	groupRatio := billingratio.GetGroupRatio(meta.Group) | ||||
| @@ -79,12 +83,12 @@ func RelayTextHelper(c *gin.Context) *model.ErrorWithStatusCode { | ||||
| 		return respErr | ||||
| 	} | ||||
| 	// post-consume quota | ||||
| 	go postConsumeQuota(ctx, usage, meta, textRequest, ratio, preConsumedQuota, modelRatio, groupRatio) | ||||
| 	go postConsumeQuota(ctx, usage, meta, textRequest, ratio, preConsumedQuota, modelRatio, groupRatio, systemPromptReset) | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func getRequestBody(c *gin.Context, meta *meta.Meta, textRequest *model.GeneralOpenAIRequest, adaptor adaptor.Adaptor) (io.Reader, error) { | ||||
| 	if meta.APIType == apitype.OpenAI && meta.OriginModelName == meta.ActualModelName && meta.ChannelType != channeltype.Baichuan { | ||||
| 	if !config.EnforceIncludeUsage && meta.APIType == apitype.OpenAI && meta.OriginModelName == meta.ActualModelName && meta.ChannelType != channeltype.Baichuan { | ||||
| 		// no need to convert request for openai | ||||
| 		return c.Request.Body, nil | ||||
| 	} | ||||
|   | ||||
| @@ -1,12 +1,15 @@ | ||||
| package meta | ||||
|  | ||||
| import ( | ||||
| 	"strings" | ||||
| 	"time" | ||||
|  | ||||
| 	"github.com/gin-gonic/gin" | ||||
|  | ||||
| 	"github.com/songquanpeng/one-api/common/ctxkey" | ||||
| 	"github.com/songquanpeng/one-api/model" | ||||
| 	"github.com/songquanpeng/one-api/relay/channeltype" | ||||
| 	"github.com/songquanpeng/one-api/relay/relaymode" | ||||
| 	"strings" | ||||
| ) | ||||
|  | ||||
| type Meta struct { | ||||
| @@ -30,6 +33,8 @@ type Meta struct { | ||||
| 	ActualModelName string | ||||
| 	RequestURLPath  string | ||||
| 	PromptTokens    int // only for DoResponse | ||||
| 	SystemPrompt    string | ||||
| 	StartTime       time.Time | ||||
| } | ||||
|  | ||||
| func GetByContext(c *gin.Context) *Meta { | ||||
| @@ -46,6 +51,8 @@ func GetByContext(c *gin.Context) *Meta { | ||||
| 		BaseURL:         c.GetString(ctxkey.BaseURL), | ||||
| 		APIKey:          strings.TrimPrefix(c.Request.Header.Get("Authorization"), "Bearer "), | ||||
| 		RequestURLPath:  c.Request.URL.String(), | ||||
| 		SystemPrompt:    c.GetString(ctxkey.SystemPrompt), | ||||
| 		StartTime:       time.Now(), | ||||
| 	} | ||||
| 	cfg, ok := c.Get(ctxkey.Config) | ||||
| 	if ok { | ||||
|   | ||||
| @@ -1,6 +1,7 @@ | ||||
| package model | ||||
|  | ||||
| const ( | ||||
| 	ContentTypeText     = "text" | ||||
| 	ContentTypeImageURL = "image_url" | ||||
| 	ContentTypeText       = "text" | ||||
| 	ContentTypeImageURL   = "image_url" | ||||
| 	ContentTypeInputAudio = "input_audio" | ||||
| ) | ||||
|   | ||||
| @@ -1,34 +1,70 @@ | ||||
| package model | ||||
|  | ||||
| type ResponseFormat struct { | ||||
| 	Type string `json:"type,omitempty"` | ||||
| 	Type       string      `json:"type,omitempty"` | ||||
| 	JsonSchema *JSONSchema `json:"json_schema,omitempty"` | ||||
| } | ||||
|  | ||||
| type JSONSchema struct { | ||||
| 	Description string                 `json:"description,omitempty"` | ||||
| 	Name        string                 `json:"name"` | ||||
| 	Schema      map[string]interface{} `json:"schema,omitempty"` | ||||
| 	Strict      *bool                  `json:"strict,omitempty"` | ||||
| } | ||||
|  | ||||
| type Audio struct { | ||||
| 	Voice  string `json:"voice,omitempty"` | ||||
| 	Format string `json:"format,omitempty"` | ||||
| } | ||||
|  | ||||
| type StreamOptions struct { | ||||
| 	IncludeUsage bool `json:"include_usage,omitempty"` | ||||
| } | ||||
|  | ||||
| type GeneralOpenAIRequest struct { | ||||
| 	Messages         []Message       `json:"messages,omitempty"` | ||||
| 	Model            string          `json:"model,omitempty"` | ||||
| 	FrequencyPenalty float64         `json:"frequency_penalty,omitempty"` | ||||
| 	MaxTokens        int             `json:"max_tokens,omitempty"` | ||||
| 	N                int             `json:"n,omitempty"` | ||||
| 	PresencePenalty  float64         `json:"presence_penalty,omitempty"` | ||||
| 	ResponseFormat   *ResponseFormat `json:"response_format,omitempty"` | ||||
| 	Seed             float64         `json:"seed,omitempty"` | ||||
| 	Stop             any             `json:"stop,omitempty"` | ||||
| 	Stream           bool            `json:"stream,omitempty"` | ||||
| 	Temperature      float64         `json:"temperature,omitempty"` | ||||
| 	TopP             float64         `json:"top_p,omitempty"` | ||||
| 	TopK             int             `json:"top_k,omitempty"` | ||||
| 	Tools            []Tool          `json:"tools,omitempty"` | ||||
| 	ToolChoice       any             `json:"tool_choice,omitempty"` | ||||
| 	FunctionCall     any             `json:"function_call,omitempty"` | ||||
| 	Functions        any             `json:"functions,omitempty"` | ||||
| 	User             string          `json:"user,omitempty"` | ||||
| 	Prompt           any             `json:"prompt,omitempty"` | ||||
| 	Input            any             `json:"input,omitempty"` | ||||
| 	EncodingFormat   string          `json:"encoding_format,omitempty"` | ||||
| 	Dimensions       int             `json:"dimensions,omitempty"` | ||||
| 	Instruction      string          `json:"instruction,omitempty"` | ||||
| 	Size             string          `json:"size,omitempty"` | ||||
| 	// https://platform.openai.com/docs/api-reference/chat/create | ||||
| 	Messages            []Message       `json:"messages,omitempty"` | ||||
| 	Model               string          `json:"model,omitempty"` | ||||
| 	Store               *bool           `json:"store,omitempty"` | ||||
| 	Metadata            any             `json:"metadata,omitempty"` | ||||
| 	FrequencyPenalty    *float64        `json:"frequency_penalty,omitempty"` | ||||
| 	LogitBias           any             `json:"logit_bias,omitempty"` | ||||
| 	Logprobs            *bool           `json:"logprobs,omitempty"` | ||||
| 	TopLogprobs         *int            `json:"top_logprobs,omitempty"` | ||||
| 	MaxTokens           int             `json:"max_tokens,omitempty"` | ||||
| 	MaxCompletionTokens *int            `json:"max_completion_tokens,omitempty"` | ||||
| 	N                   int             `json:"n,omitempty"` | ||||
| 	Modalities          []string        `json:"modalities,omitempty"` | ||||
| 	Prediction          any             `json:"prediction,omitempty"` | ||||
| 	Audio               *Audio          `json:"audio,omitempty"` | ||||
| 	PresencePenalty     *float64        `json:"presence_penalty,omitempty"` | ||||
| 	ResponseFormat      *ResponseFormat `json:"response_format,omitempty"` | ||||
| 	Seed                float64         `json:"seed,omitempty"` | ||||
| 	ServiceTier         *string         `json:"service_tier,omitempty"` | ||||
| 	Stop                any             `json:"stop,omitempty"` | ||||
| 	Stream              bool            `json:"stream,omitempty"` | ||||
| 	StreamOptions       *StreamOptions  `json:"stream_options,omitempty"` | ||||
| 	Temperature         *float64        `json:"temperature,omitempty"` | ||||
| 	TopP                *float64        `json:"top_p,omitempty"` | ||||
| 	TopK                int             `json:"top_k,omitempty"` | ||||
| 	Tools               []Tool          `json:"tools,omitempty"` | ||||
| 	ToolChoice          any             `json:"tool_choice,omitempty"` | ||||
| 	ParallelTooCalls    *bool           `json:"parallel_tool_calls,omitempty"` | ||||
| 	User                string          `json:"user,omitempty"` | ||||
| 	FunctionCall        any             `json:"function_call,omitempty"` | ||||
| 	Functions           any             `json:"functions,omitempty"` | ||||
| 	// https://platform.openai.com/docs/api-reference/embeddings/create | ||||
| 	Input          any    `json:"input,omitempty"` | ||||
| 	EncodingFormat string `json:"encoding_format,omitempty"` | ||||
| 	Dimensions     int    `json:"dimensions,omitempty"` | ||||
| 	// https://platform.openai.com/docs/api-reference/images/create | ||||
| 	Prompt  any     `json:"prompt,omitempty"` | ||||
| 	Quality *string `json:"quality,omitempty"` | ||||
| 	Size    string  `json:"size,omitempty"` | ||||
| 	Style   *string `json:"style,omitempty"` | ||||
| 	// Others | ||||
| 	Instruction string `json:"instruction,omitempty"` | ||||
| 	NumCtx      int    `json:"num_ctx,omitempty"` | ||||
| } | ||||
|  | ||||
| func (r GeneralOpenAIRequest) ParseInput() []string { | ||||
|   | ||||
| @@ -23,6 +23,7 @@ func SetApiRouter(router *gin.Engine) { | ||||
| 		apiRouter.GET("/reset_password", middleware.CriticalRateLimit(), middleware.TurnstileCheck(), controller.SendPasswordResetEmail) | ||||
| 		apiRouter.POST("/user/reset", middleware.CriticalRateLimit(), controller.ResetPassword) | ||||
| 		apiRouter.GET("/oauth/github", middleware.CriticalRateLimit(), auth.GitHubOAuth) | ||||
| 		apiRouter.GET("/oauth/oidc", middleware.CriticalRateLimit(), auth.OidcAuth) | ||||
| 		apiRouter.GET("/oauth/lark", middleware.CriticalRateLimit(), auth.LarkOAuth) | ||||
| 		apiRouter.GET("/oauth/state", middleware.CriticalRateLimit(), auth.GenerateOAuthCode) | ||||
| 		apiRouter.GET("/oauth/wechat", middleware.CriticalRateLimit(), auth.WeChatAuth) | ||||
|   | ||||
| @@ -9,6 +9,7 @@ import ( | ||||
|  | ||||
| func SetRelayRouter(router *gin.Engine) { | ||||
| 	router.Use(middleware.CORS()) | ||||
| 	router.Use(middleware.GzipDecodeMiddleware()) | ||||
| 	// https://platform.openai.com/docs/api-reference/introduction | ||||
| 	modelsRouter := router.Group("/v1/models") | ||||
| 	modelsRouter.Use(middleware.TokenAuth()) | ||||
|   | ||||
Some files were not shown because too many files have changed in this diff Show More
		Reference in New Issue
	
	Block a user