mirror of
				https://github.com/songquanpeng/one-api.git
				synced 2025-10-25 02:43:41 +08:00 
			
		
		
		
	Compare commits
	
		
			18 Commits
		
	
	
		
			v0.6.8-alp
			...
			v0.6.8-alp
		
	
	| Author | SHA1 | Date | |
|---|---|---|---|
|  | f636c50c84 | ||
|  | 720fe2dfeb | ||
|  | e090e76c86 | ||
|  | 6a941748f8 | ||
|  | 46a0773580 | ||
|  | ffdb0b0c81 | ||
|  | efd30a40b3 | ||
|  | d7a78f3397 | ||
|  | 273be55797 | ||
|  | ec6ad24810 | ||
|  | c4fe57c165 | ||
|  | 274fcf3d76 | ||
|  | 0fc07ea558 | ||
|  | 1ce1e529ee | ||
|  | d936817de9 | ||
|  | fecaece71b | ||
|  | c135d74f13 | ||
|  | d0369b114f | 
							
								
								
									
										16
									
								
								.github/workflows/ci.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										16
									
								
								.github/workflows/ci.yml
									
									
									
									
										vendored
									
									
								
							| @@ -36,21 +36,9 @@ jobs: | ||||
|       # in the next step as well as the next job. | ||||
|       - name: Test | ||||
|         run: go test -cover -coverprofile=coverage.txt ./... | ||||
|  | ||||
|       - name: Archive code coverage results | ||||
|         uses: actions/upload-artifact@v4 | ||||
|       - uses: codecov/codecov-action@v4 | ||||
|         with: | ||||
|           name: code-coverage | ||||
|           path: coverage.txt # Make sure to use the same file name you chose for the "-coverprofile" in the "Test" step | ||||
|  | ||||
|   code_coverage: | ||||
|     name: "Code coverage report" | ||||
|     runs-on: ubuntu-latest | ||||
|     needs: unit_tests # Depends on the artifact uploaded by the "unit_tests" job | ||||
|     steps: | ||||
|     - uses: codecov/codecov-action@v4 | ||||
|       with: | ||||
|         token: ${{ secrets.CODECOV_TOKEN }} | ||||
|           token: ${{ secrets.CODECOV_TOKEN }} | ||||
|  | ||||
|   commit_lint: | ||||
|     runs-on: ubuntu-latest | ||||
|   | ||||
| @@ -16,7 +16,9 @@ WORKDIR /web/air | ||||
| RUN npm install | ||||
| RUN DISABLE_ESLINT_PLUGIN='true' REACT_APP_VERSION=$(cat VERSION) npm run build | ||||
|  | ||||
| FROM golang AS builder2 | ||||
| FROM golang:alpine AS builder2 | ||||
|  | ||||
| RUN apk add --no-cache g++ | ||||
|  | ||||
| ENV GO111MODULE=on \ | ||||
|     CGO_ENABLED=1 \ | ||||
| @@ -27,7 +29,7 @@ ADD go.mod go.sum ./ | ||||
| RUN go mod download | ||||
| COPY . . | ||||
| COPY --from=builder /web/build ./web/build | ||||
| RUN go build -ldflags "-s -w -X 'github.com/songquanpeng/one-api/common.Version=$(cat VERSION)' -extldflags '-static'" -o one-api | ||||
| RUN go build -trimpath -ldflags "-s -w -X 'github.com/songquanpeng/one-api/common.Version=$(cat VERSION)' -extldflags '-static'" -o one-api | ||||
|  | ||||
| FROM alpine | ||||
|  | ||||
|   | ||||
							
								
								
									
										39
									
								
								README.en.md
									
									
									
									
									
								
							
							
						
						
									
										39
									
								
								README.en.md
									
									
									
									
									
								
							| @@ -245,16 +245,41 @@ If the channel ID is not provided, load balancing will be used to distribute the | ||||
|     + Example: `LOG_SQL_DSN=root:123456@tcp(localhost:3306)/oneapi-logs` | ||||
| 5. `FRONTEND_BASE_URL`: When set, the specified frontend address will be used instead of the backend address. | ||||
|     + Example: `FRONTEND_BASE_URL=https://openai.justsong.cn` | ||||
| 6. `SYNC_FREQUENCY`: When set, the system will periodically sync configurations from the database, with the unit in seconds. If not set, no sync will happen. | ||||
| 6. 'MEMORY_CACHE_ENABLED': Enabling memory caching can cause a certain delay in updating user quotas, with optional values of 'true' and 'false'. If not set, it defaults to 'false'. | ||||
| 7. `SYNC_FREQUENCY`: When set, the system will periodically sync configurations from the database, with the unit in seconds. If not set, no sync will happen. | ||||
|     + Example: `SYNC_FREQUENCY=60` | ||||
| 7. `NODE_TYPE`: When set, specifies the node type. Valid values are `master` and `slave`. If not set, it defaults to `master`. | ||||
| 8. `NODE_TYPE`: When set, specifies the node type. Valid values are `master` and `slave`. If not set, it defaults to `master`. | ||||
|     + Example: `NODE_TYPE=slave` | ||||
| 8. `CHANNEL_UPDATE_FREQUENCY`: When set, it periodically updates the channel balances, with the unit in minutes. If not set, no update will happen. | ||||
| 9. `CHANNEL_UPDATE_FREQUENCY`: When set, it periodically updates the channel balances, with the unit in minutes. If not set, no update will happen. | ||||
|     + Example: `CHANNEL_UPDATE_FREQUENCY=1440` | ||||
| 9. `CHANNEL_TEST_FREQUENCY`: When set, it periodically tests the channels, with the unit in minutes. If not set, no test will happen. | ||||
| 10. `CHANNEL_TEST_FREQUENCY`: When set, it periodically tests the channels, with the unit in minutes. If not set, no test will happen. | ||||
|     + Example: `CHANNEL_TEST_FREQUENCY=1440` | ||||
| 10. `POLLING_INTERVAL`: The time interval (in seconds) between requests when updating channel balances and testing channel availability. Default is no interval. | ||||
| 11. `POLLING_INTERVAL`: The time interval (in seconds) between requests when updating channel balances and testing channel availability. Default is no interval. | ||||
|     + Example: `POLLING_INTERVAL=5` | ||||
| 12. `BATCH_UPDATE_ENABLED`: Enabling batch database update aggregation can cause a certain delay in updating user quotas. The optional values are 'true' and 'false', but if not set, it defaults to 'false'. | ||||
|     +Example: ` BATCH_UPDATE_ENABLED=true` | ||||
|     +If you encounter an issue with too many database connections, you can try enabling this option. | ||||
| 13. `BATCH_UPDATE_INTERVAL=5`: The time interval for batch updating aggregates, measured in seconds, defaults to '5'. | ||||
|     +Example: ` BATCH_UPDATE_INTERVAL=5` | ||||
| 14. Request frequency limit: | ||||
|     + `GLOBAL_API_RATE_LIMIT`: Global API rate limit (excluding relay requests), the maximum number of requests within three minutes per IP, default to 180. | ||||
|     + `GLOBAL_WEL_RATE_LIMIT`: Global web speed limit, the maximum number of requests within three minutes per IP, default to 60. | ||||
| 15. Encoder cache settings: | ||||
|     +`TIKTOKEN_CACHE_DIR`: By default, when the program starts, it will download the encoding of some common word elements online, such as' gpt-3.5 turbo '. In some unstable network environments or offline situations, it may cause startup problems. This directory can be configured to cache data and can be migrated to an offline environment. | ||||
|     +`DATA_GYM_CACHE_DIR`: Currently, this configuration has the same function as' TIKTOKEN-CACHE-DIR ', but its priority is not as high as it. | ||||
| 16. `RELAY_TIMEOUT`: Relay timeout setting, measured in seconds, with no default timeout time set. | ||||
| 17. `RELAY_PROXY`: After setting up, use this proxy to request APIs. | ||||
| 18. `USER_CONTENT_REQUEST_TIMEOUT`: The timeout period for users to upload and download content, measured in seconds. | ||||
| 19. `USER_CONTENT_REQUEST_PROXY`: After setting up, use this agent to request content uploaded by users, such as images. | ||||
| 20. `SQLITE_BUSY_TIMEOUT`: SQLite lock wait timeout setting, measured in milliseconds, default to '3000'. | ||||
| 21. `GEMINI_SAFETY_SETTING`: Gemini's security settings are set to 'BLOCK-NONE' by default. | ||||
| 22. `GEMINI_VERSION`: The Gemini version used by the One API, which defaults to 'v1'. | ||||
| 23. `THE`: The system's theme setting, default to 'default', specific optional values refer to [here] (./web/README. md). | ||||
| 24. `ENABLE_METRIC`: Whether to disable channels based on request success rate, default not enabled, optional values are 'true' and 'false'. | ||||
| 25. `METRIC_QUEUE_SIZE`: Request success rate statistics queue size, default to '10'. | ||||
| 26. `METRIC_SUCCESS_RATE_THRESHOLD`: Request success rate threshold, default to '0.8'. | ||||
| 27. `INITIAL_ROOT_TOKEN`: If this value is set, a root user token with the value of the environment variable will be automatically created when the system starts for the first time. | ||||
| 28. `INITIAL_ROOT_ACCESS_TOKEN`: If this value is set, a system management token will be automatically created for the root user with a value of the environment variable when the system starts for the first time. | ||||
|  | ||||
| ### Command Line Parameters | ||||
| 1. `--port <port_number>`: Specifies the port number on which the server listens. Defaults to `3000`. | ||||
| @@ -287,7 +312,9 @@ If the channel ID is not provided, load balancing will be used to distribute the | ||||
|     + Double-check that your interface address and API Key are correct. | ||||
|  | ||||
| ## Related Projects | ||||
| [FastGPT](https://github.com/labring/FastGPT): Knowledge question answering system based on the LLM | ||||
| * [FastGPT](https://github.com/labring/FastGPT): Knowledge question answering system based on the LLM | ||||
| * [VChart](https://github.com/VisActor/VChart):  More than just a cross-platform charting library, but also an expressive data storyteller. | ||||
| * [VMind](https://github.com/VisActor/VMind):  Not just automatic, but also fantastic. Open-source solution for intelligent visualization. | ||||
|  | ||||
| ## Note | ||||
| This project is an open-source project. Please use it in compliance with OpenAI's [Terms of Use](https://openai.com/policies/terms-of-use) and **applicable laws and regulations**. It must not be used for illegal purposes. | ||||
|   | ||||
							
								
								
									
										40
									
								
								README.md
									
									
									
									
									
								
							
							
						
						
									
										40
									
								
								README.md
									
									
									
									
									
								
							| @@ -88,6 +88,7 @@ _✨ 通过标准的 OpenAI API 格式访问所有的大模型,开箱即用  | ||||
|    + [x] [Cloudflare Workers AI](https://developers.cloudflare.com/workers-ai/) | ||||
|    + [x] [DeepL](https://www.deepl.com/) | ||||
|    + [x] [together.ai](https://www.together.ai/) | ||||
|    + [x] [novita.ai](https://www.novita.ai/) | ||||
| 2. 支持配置镜像以及众多[第三方代理服务](https://iamazing.cn/page/openai-api-third-party-services)。 | ||||
| 3. 支持通过**负载均衡**的方式访问多个渠道。 | ||||
| 4. 支持 **stream 模式**,可以通过流式传输实现打字机效果。 | ||||
| @@ -370,32 +371,33 @@ graph LR | ||||
| 9. `CHANNEL_UPDATE_FREQUENCY`:设置之后将定期更新渠道余额,单位为分钟,未设置则不进行更新。 | ||||
|    + 例子:`CHANNEL_UPDATE_FREQUENCY=1440` | ||||
| 10. `CHANNEL_TEST_FREQUENCY`:设置之后将定期检查渠道,单位为分钟,未设置则不进行检查。  | ||||
| 11. 例子:`CHANNEL_TEST_FREQUENCY=1440` | ||||
| 12. `POLLING_INTERVAL`:批量更新渠道余额以及测试可用性时的请求间隔,单位为秒,默认无间隔。 | ||||
|    +例子:`CHANNEL_TEST_FREQUENCY=1440` | ||||
| 11. `POLLING_INTERVAL`:批量更新渠道余额以及测试可用性时的请求间隔,单位为秒,默认无间隔。 | ||||
|     + 例子:`POLLING_INTERVAL=5` | ||||
| 13. `BATCH_UPDATE_ENABLED`:启用数据库批量更新聚合,会导致用户额度的更新存在一定的延迟可选值为 `true` 和 `false`,未设置则默认为 `false`。 | ||||
| 12. `BATCH_UPDATE_ENABLED`:启用数据库批量更新聚合,会导致用户额度的更新存在一定的延迟可选值为 `true` 和 `false`,未设置则默认为 `false`。 | ||||
|     + 例子:`BATCH_UPDATE_ENABLED=true` | ||||
|     + 如果你遇到了数据库连接数过多的问题,可以尝试启用该选项。 | ||||
| 14. `BATCH_UPDATE_INTERVAL=5`:批量更新聚合的时间间隔,单位为秒,默认为 `5`。 | ||||
| 13. `BATCH_UPDATE_INTERVAL=5`:批量更新聚合的时间间隔,单位为秒,默认为 `5`。 | ||||
|     + 例子:`BATCH_UPDATE_INTERVAL=5` | ||||
| 15. 请求频率限制: | ||||
| 14. 请求频率限制: | ||||
|     + `GLOBAL_API_RATE_LIMIT`:全局 API 速率限制(除中继请求外),单 ip 三分钟内的最大请求数,默认为 `180`。 | ||||
|     + `GLOBAL_WEB_RATE_LIMIT`:全局 Web 速率限制,单 ip 三分钟内的最大请求数,默认为 `60`。 | ||||
| 16. 编码器缓存设置: | ||||
| 15. 编码器缓存设置: | ||||
|     + `TIKTOKEN_CACHE_DIR`:默认程序启动时会联网下载一些通用的词元的编码,如:`gpt-3.5-turbo`,在一些网络环境不稳定,或者离线情况,可能会导致启动有问题,可以配置此目录缓存数据,可迁移到离线环境。 | ||||
|     + `DATA_GYM_CACHE_DIR`:目前该配置作用与 `TIKTOKEN_CACHE_DIR` 一致,但是优先级没有它高。 | ||||
| 17. `RELAY_TIMEOUT`:中继超时设置,单位为秒,默认不设置超时时间。 | ||||
| 18. `RELAY_PROXY`:设置后使用该代理来请求 API。 | ||||
| 19. `USER_CONTENT_REQUEST_TIMEOUT`:用户上传内容下载超时时间,单位为秒。 | ||||
| 20. `USER_CONTENT_REQUEST_PROXY`:设置后使用该代理来请求用户上传的内容,例如图片。 | ||||
| 21. `SQLITE_BUSY_TIMEOUT`:SQLite 锁等待超时设置,单位为毫秒,默认 `3000`。 | ||||
| 22. `GEMINI_SAFETY_SETTING`:Gemini 的安全设置,默认 `BLOCK_NONE`。 | ||||
| 23. `GEMINI_VERSION`:One API 所使用的 Gemini 版本,默认为 `v1`。 | ||||
| 24. `THEME`:系统的主题设置,默认为 `default`,具体可选值参考[此处](./web/README.md)。 | ||||
| 25. `ENABLE_METRIC`:是否根据请求成功率禁用渠道,默认不开启,可选值为 `true` 和 `false`。 | ||||
| 26. `METRIC_QUEUE_SIZE`:请求成功率统计队列大小,默认为 `10`。 | ||||
| 27. `METRIC_SUCCESS_RATE_THRESHOLD`:请求成功率阈值,默认为 `0.8`。 | ||||
| 28. `INITIAL_ROOT_TOKEN`:如果设置了该值,则在系统首次启动时会自动创建一个值为该环境变量值的 root 用户令牌。 | ||||
| 16. `RELAY_TIMEOUT`:中继超时设置,单位为秒,默认不设置超时时间。 | ||||
| 17. `RELAY_PROXY`:设置后使用该代理来请求 API。 | ||||
| 18. `USER_CONTENT_REQUEST_TIMEOUT`:用户上传内容下载超时时间,单位为秒。 | ||||
| 19. `USER_CONTENT_REQUEST_PROXY`:设置后使用该代理来请求用户上传的内容,例如图片。 | ||||
| 20. `SQLITE_BUSY_TIMEOUT`:SQLite 锁等待超时设置,单位为毫秒,默认 `3000`。 | ||||
| 21. `GEMINI_SAFETY_SETTING`:Gemini 的安全设置,默认 `BLOCK_NONE`。 | ||||
| 22. `GEMINI_VERSION`:One API 所使用的 Gemini 版本,默认为 `v1`。 | ||||
| 23. `THEME`:系统的主题设置,默认为 `default`,具体可选值参考[此处](./web/README.md)。 | ||||
| 24. `ENABLE_METRIC`:是否根据请求成功率禁用渠道,默认不开启,可选值为 `true` 和 `false`。 | ||||
| 25. `METRIC_QUEUE_SIZE`:请求成功率统计队列大小,默认为 `10`。 | ||||
| 26. `METRIC_SUCCESS_RATE_THRESHOLD`:请求成功率阈值,默认为 `0.8`。 | ||||
| 27. `INITIAL_ROOT_TOKEN`:如果设置了该值,则在系统首次启动时会自动创建一个值为该环境变量值的 root 用户令牌。 | ||||
| 28. `INITIAL_ROOT_ACCESS_TOKEN`:如果设置了该值,则在系统首次启动时会自动创建一个值为该环境变量的 root 用户创建系统管理令牌。 | ||||
|  | ||||
| ### 命令行参数 | ||||
| 1. `--port <port_number>`: 指定服务器监听的端口号,默认为 `3000`。 | ||||
| @@ -448,6 +450,8 @@ https://openai.justsong.cn | ||||
| ## 相关项目 | ||||
| * [FastGPT](https://github.com/labring/FastGPT): 基于 LLM 大语言模型的知识库问答系统 | ||||
| * [ChatGPT Next Web](https://github.com/Yidadaa/ChatGPT-Next-Web):  一键拥有你自己的跨平台 ChatGPT 应用 | ||||
| * [VChart](https://github.com/VisActor/VChart):  不只是开箱即用的多端图表库,更是生动灵活的数据故事讲述者。 | ||||
| * [VMind](https://github.com/VisActor/VMind):  不仅自动,还很智能。开源智能可视化解决方案。 | ||||
|  | ||||
| ## 注意 | ||||
|  | ||||
|   | ||||
| @@ -143,8 +143,13 @@ var MetricFailChanSize = env.Int("METRIC_FAIL_CHAN_SIZE", 128) | ||||
|  | ||||
| var InitialRootToken = os.Getenv("INITIAL_ROOT_TOKEN") | ||||
|  | ||||
| var InitialRootAccessToken = os.Getenv("INITIAL_ROOT_ACCESS_TOKEN") | ||||
|  | ||||
| var GeminiVersion = env.String("GEMINI_VERSION", "v1") | ||||
|  | ||||
|  | ||||
| var OnlyOneLogFile = env.Bool("ONLY_ONE_LOG_FILE", false) | ||||
|  | ||||
| var RelayProxy = env.String("RELAY_PROXY", "") | ||||
| var UserContentRequestProxy = env.String("USER_CONTENT_REQUEST_PROXY", "") | ||||
| var UserContentRequestTimeout = env.Int("USER_CONTENT_REQUEST_TIMEOUT", 30) | ||||
|   | ||||
| @@ -27,7 +27,12 @@ var setupLogOnce sync.Once | ||||
| func SetupLogger() { | ||||
| 	setupLogOnce.Do(func() { | ||||
| 		if LogDir != "" { | ||||
| 			logPath := filepath.Join(LogDir, fmt.Sprintf("oneapi-%s.log", time.Now().Format("20060102"))) | ||||
| 			var logPath string | ||||
| 			if config.OnlyOneLogFile { | ||||
| 				logPath = filepath.Join(LogDir, "oneapi.log") | ||||
| 			} else { | ||||
| 				logPath = filepath.Join(LogDir, fmt.Sprintf("oneapi-%s.log", time.Now().Format("20060102"))) | ||||
| 			} | ||||
| 			fd, err := os.OpenFile(logPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644) | ||||
| 			if err != nil { | ||||
| 				log.Fatal("failed to open log file") | ||||
|   | ||||
| @@ -6,11 +6,16 @@ import ( | ||||
| 	"encoding/base64" | ||||
| 	"fmt" | ||||
| 	"github.com/songquanpeng/one-api/common/config" | ||||
| 	"net" | ||||
| 	"net/smtp" | ||||
| 	"strings" | ||||
| 	"time" | ||||
| ) | ||||
|  | ||||
| func shouldAuth() bool { | ||||
| 	return config.SMTPAccount != "" || config.SMTPToken != "" | ||||
| } | ||||
|  | ||||
| func SendEmail(subject string, receiver string, content string) error { | ||||
| 	if receiver == "" { | ||||
| 		return fmt.Errorf("receiver is empty") | ||||
| @@ -41,16 +46,24 @@ func SendEmail(subject string, receiver string, content string) error { | ||||
| 		"Date: %s\r\n"+ | ||||
| 		"Content-Type: text/html; charset=UTF-8\r\n\r\n%s\r\n", | ||||
| 		receiver, config.SystemName, config.SMTPFrom, encodedSubject, messageId, time.Now().Format(time.RFC1123Z), content)) | ||||
|  | ||||
| 	auth := smtp.PlainAuth("", config.SMTPAccount, config.SMTPToken, config.SMTPServer) | ||||
| 	addr := fmt.Sprintf("%s:%d", config.SMTPServer, config.SMTPPort) | ||||
| 	to := strings.Split(receiver, ";") | ||||
|  | ||||
| 	if config.SMTPPort == 465 { | ||||
| 		tlsConfig := &tls.Config{ | ||||
| 			InsecureSkipVerify: true, | ||||
| 			ServerName:         config.SMTPServer, | ||||
| 	if config.SMTPPort == 465 || !shouldAuth() { | ||||
| 		// need advanced client | ||||
| 		var conn net.Conn | ||||
| 		var err error | ||||
| 		if config.SMTPPort == 465 { | ||||
| 			tlsConfig := &tls.Config{ | ||||
| 				InsecureSkipVerify: true, | ||||
| 				ServerName:         config.SMTPServer, | ||||
| 			} | ||||
| 			conn, err = tls.Dial("tcp", fmt.Sprintf("%s:%d", config.SMTPServer, config.SMTPPort), tlsConfig) | ||||
| 		} else { | ||||
| 			conn, err = net.Dial("tcp", fmt.Sprintf("%s:%d", config.SMTPServer, config.SMTPPort)) | ||||
| 		} | ||||
| 		conn, err := tls.Dial("tcp", fmt.Sprintf("%s:%d", config.SMTPServer, config.SMTPPort), tlsConfig) | ||||
| 		if err != nil { | ||||
| 			return err | ||||
| 		} | ||||
| @@ -59,8 +72,10 @@ func SendEmail(subject string, receiver string, content string) error { | ||||
| 			return err | ||||
| 		} | ||||
| 		defer client.Close() | ||||
| 		if err = client.Auth(auth); err != nil { | ||||
| 			return err | ||||
| 		if shouldAuth() { | ||||
| 			if err = client.Auth(auth); err != nil { | ||||
| 				return err | ||||
| 			} | ||||
| 		} | ||||
| 		if err = client.Mail(config.SMTPFrom); err != nil { | ||||
| 			return err | ||||
|   | ||||
| @@ -14,6 +14,7 @@ import ( | ||||
| 	"sync" | ||||
| 	"time" | ||||
|  | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"github.com/songquanpeng/one-api/common/config" | ||||
| 	"github.com/songquanpeng/one-api/common/ctxkey" | ||||
| 	"github.com/songquanpeng/one-api/common/logger" | ||||
| @@ -27,15 +28,15 @@ import ( | ||||
| 	"github.com/songquanpeng/one-api/relay/meta" | ||||
| 	relaymodel "github.com/songquanpeng/one-api/relay/model" | ||||
| 	"github.com/songquanpeng/one-api/relay/relaymode" | ||||
|  | ||||
| 	"github.com/gin-gonic/gin" | ||||
| ) | ||||
|  | ||||
| func buildTestRequest() *relaymodel.GeneralOpenAIRequest { | ||||
| func buildTestRequest(model string) *relaymodel.GeneralOpenAIRequest { | ||||
| 	if model == "" { | ||||
| 		model = "gpt-3.5-turbo" | ||||
| 	} | ||||
| 	testRequest := &relaymodel.GeneralOpenAIRequest{ | ||||
| 		MaxTokens: 2, | ||||
| 		Stream:    false, | ||||
| 		Model:     "gpt-3.5-turbo", | ||||
| 		Model:     model, | ||||
| 	} | ||||
| 	testMessage := relaymodel.Message{ | ||||
| 		Role:    "user", | ||||
| @@ -45,7 +46,7 @@ func buildTestRequest() *relaymodel.GeneralOpenAIRequest { | ||||
| 	return testRequest | ||||
| } | ||||
|  | ||||
| func testChannel(channel *model.Channel) (err error, openaiErr *relaymodel.Error) { | ||||
| func testChannel(channel *model.Channel, request *relaymodel.GeneralOpenAIRequest) (err error, openaiErr *relaymodel.Error) { | ||||
| 	w := httptest.NewRecorder() | ||||
| 	c, _ := gin.CreateTestContext(w) | ||||
| 	c.Request = &http.Request{ | ||||
| @@ -68,12 +69,8 @@ func testChannel(channel *model.Channel) (err error, openaiErr *relaymodel.Error | ||||
| 		return fmt.Errorf("invalid api type: %d, adaptor is nil", apiType), nil | ||||
| 	} | ||||
| 	adaptor.Init(meta) | ||||
| 	var modelName string | ||||
| 	modelList := adaptor.GetModelList() | ||||
| 	modelName := request.Model | ||||
| 	modelMap := channel.GetModelMapping() | ||||
| 	if len(modelList) != 0 { | ||||
| 		modelName = modelList[0] | ||||
| 	} | ||||
| 	if modelName == "" || !strings.Contains(channel.Models, modelName) { | ||||
| 		modelNames := strings.Split(channel.Models, ",") | ||||
| 		if len(modelNames) > 0 { | ||||
| @@ -83,9 +80,8 @@ func testChannel(channel *model.Channel) (err error, openaiErr *relaymodel.Error | ||||
| 			modelName = modelMap[modelName] | ||||
| 		} | ||||
| 	} | ||||
| 	request := buildTestRequest() | ||||
| 	meta.OriginModelName, meta.ActualModelName = request.Model, modelName | ||||
| 	request.Model = modelName | ||||
| 	meta.OriginModelName, meta.ActualModelName = modelName, modelName | ||||
| 	convertedRequest, err := adaptor.ConvertRequest(c, relaymode.ChatCompletions, request) | ||||
| 	if err != nil { | ||||
| 		return err, nil | ||||
| @@ -139,10 +135,15 @@ func TestChannel(c *gin.Context) { | ||||
| 		}) | ||||
| 		return | ||||
| 	} | ||||
| 	model := c.Query("model") | ||||
| 	testRequest := buildTestRequest(model) | ||||
| 	tik := time.Now() | ||||
| 	err, _ = testChannel(channel) | ||||
| 	err, _ = testChannel(channel, testRequest) | ||||
| 	tok := time.Now() | ||||
| 	milliseconds := tok.Sub(tik).Milliseconds() | ||||
| 	if err != nil { | ||||
| 		milliseconds = 0 | ||||
| 	} | ||||
| 	go channel.UpdateResponseTime(milliseconds) | ||||
| 	consumedTime := float64(milliseconds) / 1000.0 | ||||
| 	if err != nil { | ||||
| @@ -150,6 +151,7 @@ func TestChannel(c *gin.Context) { | ||||
| 			"success": false, | ||||
| 			"message": err.Error(), | ||||
| 			"time":    consumedTime, | ||||
| 			"model":   model, | ||||
| 		}) | ||||
| 		return | ||||
| 	} | ||||
| @@ -157,6 +159,7 @@ func TestChannel(c *gin.Context) { | ||||
| 		"success": true, | ||||
| 		"message": "", | ||||
| 		"time":    consumedTime, | ||||
| 		"model":   model, | ||||
| 	}) | ||||
| 	return | ||||
| } | ||||
| @@ -187,11 +190,12 @@ func testChannels(notify bool, scope string) error { | ||||
| 		for _, channel := range channels { | ||||
| 			isChannelEnabled := channel.Status == model.ChannelStatusEnabled | ||||
| 			tik := time.Now() | ||||
| 			err, openaiErr := testChannel(channel) | ||||
| 			testRequest := buildTestRequest("") | ||||
| 			err, openaiErr := testChannel(channel, testRequest) | ||||
| 			tok := time.Now() | ||||
| 			milliseconds := tok.Sub(tik).Milliseconds() | ||||
| 			if isChannelEnabled && milliseconds > disableThreshold { | ||||
| 				err = errors.New(fmt.Sprintf("响应时间 %.2fs 超过阈值 %.2fs", float64(milliseconds)/1000.0, float64(disableThreshold)/1000.0)) | ||||
| 				err = fmt.Errorf("响应时间 %.2fs 超过阈值 %.2fs", float64(milliseconds)/1000.0, float64(disableThreshold)/1000.0) | ||||
| 				if config.AutomaticDisableChannelEnabled { | ||||
| 					monitor.DisableChannel(channel.Id, channel.Name, err.Error()) | ||||
| 				} else { | ||||
|   | ||||
							
								
								
									
										2
									
								
								go.mod
									
									
									
									
									
								
							
							
						
						
									
										2
									
								
								go.mod
									
									
									
									
									
								
							| @@ -68,7 +68,7 @@ require ( | ||||
| 	github.com/kr/text v0.2.0 // indirect | ||||
| 	github.com/leodido/go-urn v1.4.0 // indirect | ||||
| 	github.com/mattn/go-isatty v0.0.20 // indirect | ||||
| 	github.com/mattn/go-sqlite3 v2.0.3+incompatible // indirect | ||||
| 	github.com/mattn/go-sqlite3 v1.14.22 // indirect | ||||
| 	github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect | ||||
| 	github.com/modern-go/reflect2 v1.0.2 // indirect | ||||
| 	github.com/pelletier/go-toml/v2 v2.2.2 // indirect | ||||
|   | ||||
							
								
								
									
										4
									
								
								go.sum
									
									
									
									
									
								
							
							
						
						
									
										4
									
								
								go.sum
									
									
									
									
									
								
							| @@ -110,8 +110,8 @@ github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ= | ||||
| github.com/leodido/go-urn v1.4.0/go.mod h1:bvxc+MVxLKB4z00jd1z+Dvzr47oO32F/QSNjSBOlFxI= | ||||
| github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= | ||||
| github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= | ||||
| github.com/mattn/go-sqlite3 v2.0.3+incompatible h1:gXHsfypPkaMZrKbD5209QV9jbUTJKjyR5WD3HYQSd+U= | ||||
| github.com/mattn/go-sqlite3 v2.0.3+incompatible/go.mod h1:FPy6KqzDD04eiIsT53CuJW3U88zkxoIYsOqkbpncsNc= | ||||
| github.com/mattn/go-sqlite3 v1.14.22 h1:2gZY6PC6kBnID23Tichd1K+Z0oS6nE/XwU+Vz/5o4kU= | ||||
| github.com/mattn/go-sqlite3 v1.14.22/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y= | ||||
| github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= | ||||
| github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg= | ||||
| github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= | ||||
|   | ||||
							
								
								
									
										22
									
								
								main.go
									
									
									
									
									
								
							
							
						
						
									
										22
									
								
								main.go
									
									
									
									
									
								
							| @@ -27,27 +27,19 @@ func main() { | ||||
| 	common.Init() | ||||
| 	logger.SetupLogger() | ||||
| 	logger.SysLogf("One API %s started", common.Version) | ||||
| 	if os.Getenv("GIN_MODE") != "debug" { | ||||
|  | ||||
| 	if os.Getenv("GIN_MODE") != gin.DebugMode { | ||||
| 		gin.SetMode(gin.ReleaseMode) | ||||
| 	} | ||||
| 	if config.DebugEnabled { | ||||
| 		logger.SysLog("running in debug mode") | ||||
| 	} | ||||
| 	var err error | ||||
|  | ||||
| 	// Initialize SQL Database | ||||
| 	model.DB, err = model.InitDB("SQL_DSN") | ||||
| 	if err != nil { | ||||
| 		logger.FatalLog("failed to initialize database: " + err.Error()) | ||||
| 	} | ||||
| 	if os.Getenv("LOG_SQL_DSN") != "" { | ||||
| 		logger.SysLog("using secondary database for table logs") | ||||
| 		model.LOG_DB, err = model.InitDB("LOG_SQL_DSN") | ||||
| 		if err != nil { | ||||
| 			logger.FatalLog("failed to initialize secondary database: " + err.Error()) | ||||
| 		} | ||||
| 	} else { | ||||
| 		model.LOG_DB = model.DB | ||||
| 	} | ||||
| 	model.InitDB() | ||||
| 	model.InitLogDB() | ||||
|  | ||||
| 	var err error | ||||
| 	err = model.CreateRootAccountIfNeed() | ||||
| 	if err != nil { | ||||
| 		logger.FatalLog("database init error: " + err.Error()) | ||||
|   | ||||
							
								
								
									
										225
									
								
								model/main.go
									
									
									
									
									
								
							
							
						
						
									
										225
									
								
								model/main.go
									
									
									
									
									
								
							| @@ -1,6 +1,7 @@ | ||||
| package model | ||||
|  | ||||
| import ( | ||||
| 	"database/sql" | ||||
| 	"fmt" | ||||
| 	"github.com/songquanpeng/one-api/common" | ||||
| 	"github.com/songquanpeng/one-api/common/config" | ||||
| @@ -29,13 +30,17 @@ func CreateRootAccountIfNeed() error { | ||||
| 		if err != nil { | ||||
| 			return err | ||||
| 		} | ||||
| 		accessToken := random.GetUUID() | ||||
| 		if config.InitialRootAccessToken != "" { | ||||
| 			accessToken = config.InitialRootAccessToken | ||||
| 		} | ||||
| 		rootUser := User{ | ||||
| 			Username:    "root", | ||||
| 			Password:    hashedPassword, | ||||
| 			Role:        RoleRootUser, | ||||
| 			Status:      UserStatusEnabled, | ||||
| 			DisplayName: "Root User", | ||||
| 			AccessToken: random.GetUUID(), | ||||
| 			AccessToken: accessToken, | ||||
| 			Quota:       500000000000000, | ||||
| 		} | ||||
| 		DB.Create(&rootUser) | ||||
| @@ -60,90 +65,156 @@ func CreateRootAccountIfNeed() error { | ||||
| } | ||||
|  | ||||
| func chooseDB(envName string) (*gorm.DB, error) { | ||||
| 	if os.Getenv(envName) != "" { | ||||
| 		dsn := os.Getenv(envName) | ||||
| 		if strings.HasPrefix(dsn, "postgres://") { | ||||
| 			// Use PostgreSQL | ||||
| 			logger.SysLog("using PostgreSQL as database") | ||||
| 			common.UsingPostgreSQL = true | ||||
| 			return gorm.Open(postgres.New(postgres.Config{ | ||||
| 				DSN:                  dsn, | ||||
| 				PreferSimpleProtocol: true, // disables implicit prepared statement usage | ||||
| 			}), &gorm.Config{ | ||||
| 				PrepareStmt: true, // precompile SQL | ||||
| 			}) | ||||
| 		} | ||||
| 	dsn := os.Getenv(envName) | ||||
|  | ||||
| 	switch { | ||||
| 	case strings.HasPrefix(dsn, "postgres://"): | ||||
| 		// Use PostgreSQL | ||||
| 		return openPostgreSQL(dsn) | ||||
| 	case dsn != "": | ||||
| 		// Use MySQL | ||||
| 		logger.SysLog("using MySQL as database") | ||||
| 		common.UsingMySQL = true | ||||
| 		return gorm.Open(mysql.Open(dsn), &gorm.Config{ | ||||
| 			PrepareStmt: true, // precompile SQL | ||||
| 		}) | ||||
| 		return openMySQL(dsn) | ||||
| 	default: | ||||
| 		// Use SQLite | ||||
| 		return openSQLite() | ||||
| 	} | ||||
| 	// Use SQLite | ||||
| 	logger.SysLog("SQL_DSN not set, using SQLite as database") | ||||
| 	common.UsingSQLite = true | ||||
| 	config := fmt.Sprintf("?_busy_timeout=%d", common.SQLiteBusyTimeout) | ||||
| 	return gorm.Open(sqlite.Open(common.SQLitePath+config), &gorm.Config{ | ||||
| } | ||||
|  | ||||
| func openPostgreSQL(dsn string) (*gorm.DB, error) { | ||||
| 	logger.SysLog("using PostgreSQL as database") | ||||
| 	common.UsingPostgreSQL = true | ||||
| 	return gorm.Open(postgres.New(postgres.Config{ | ||||
| 		DSN:                  dsn, | ||||
| 		PreferSimpleProtocol: true, // disables implicit prepared statement usage | ||||
| 	}), &gorm.Config{ | ||||
| 		PrepareStmt: true, // precompile SQL | ||||
| 	}) | ||||
| } | ||||
|  | ||||
| func InitDB(envName string) (db *gorm.DB, err error) { | ||||
| 	db, err = chooseDB(envName) | ||||
| 	if err == nil { | ||||
| 		if config.DebugSQLEnabled { | ||||
| 			db = db.Debug() | ||||
| 		} | ||||
| 		sqlDB, err := db.DB() | ||||
| 		if err != nil { | ||||
| 			return nil, err | ||||
| 		} | ||||
| 		sqlDB.SetMaxIdleConns(env.Int("SQL_MAX_IDLE_CONNS", 100)) | ||||
| 		sqlDB.SetMaxOpenConns(env.Int("SQL_MAX_OPEN_CONNS", 1000)) | ||||
| 		sqlDB.SetConnMaxLifetime(time.Second * time.Duration(env.Int("SQL_MAX_LIFETIME", 60))) | ||||
| func openMySQL(dsn string) (*gorm.DB, error) { | ||||
| 	logger.SysLog("using MySQL as database") | ||||
| 	common.UsingMySQL = true | ||||
| 	return gorm.Open(mysql.Open(dsn), &gorm.Config{ | ||||
| 		PrepareStmt: true, // precompile SQL | ||||
| 	}) | ||||
| } | ||||
|  | ||||
| 		if !config.IsMasterNode { | ||||
| 			return db, err | ||||
| 		} | ||||
| 		if common.UsingMySQL { | ||||
| 			_, _ = sqlDB.Exec("DROP INDEX idx_channels_key ON channels;") // TODO: delete this line when most users have upgraded | ||||
| 		} | ||||
| 		logger.SysLog("database migration started") | ||||
| 		err = db.AutoMigrate(&Channel{}) | ||||
| 		if err != nil { | ||||
| 			return nil, err | ||||
| 		} | ||||
| 		err = db.AutoMigrate(&Token{}) | ||||
| 		if err != nil { | ||||
| 			return nil, err | ||||
| 		} | ||||
| 		err = db.AutoMigrate(&User{}) | ||||
| 		if err != nil { | ||||
| 			return nil, err | ||||
| 		} | ||||
| 		err = db.AutoMigrate(&Option{}) | ||||
| 		if err != nil { | ||||
| 			return nil, err | ||||
| 		} | ||||
| 		err = db.AutoMigrate(&Redemption{}) | ||||
| 		if err != nil { | ||||
| 			return nil, err | ||||
| 		} | ||||
| 		err = db.AutoMigrate(&Ability{}) | ||||
| 		if err != nil { | ||||
| 			return nil, err | ||||
| 		} | ||||
| 		err = db.AutoMigrate(&Log{}) | ||||
| 		if err != nil { | ||||
| 			return nil, err | ||||
| 		} | ||||
| 		logger.SysLog("database migrated") | ||||
| 		return db, err | ||||
| 	} else { | ||||
| 		logger.FatalLog(err) | ||||
| func openSQLite() (*gorm.DB, error) { | ||||
| 	logger.SysLog("SQL_DSN not set, using SQLite as database") | ||||
| 	common.UsingSQLite = true | ||||
| 	dsn := fmt.Sprintf("%s?_busy_timeout=%d", common.SQLitePath, common.SQLiteBusyTimeout) | ||||
| 	return gorm.Open(sqlite.Open(dsn), &gorm.Config{ | ||||
| 		PrepareStmt: true, // precompile SQL | ||||
| 	}) | ||||
| } | ||||
|  | ||||
| func InitDB() { | ||||
| 	var err error | ||||
| 	DB, err = chooseDB("SQL_DSN") | ||||
| 	if err != nil { | ||||
| 		logger.FatalLog("failed to initialize database: " + err.Error()) | ||||
| 		return | ||||
| 	} | ||||
| 	return db, err | ||||
|  | ||||
| 	sqlDB := setDBConns(DB) | ||||
|  | ||||
| 	if !config.IsMasterNode { | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	if common.UsingMySQL { | ||||
| 		_, _ = sqlDB.Exec("DROP INDEX idx_channels_key ON channels;") // TODO: delete this line when most users have upgraded | ||||
| 	} | ||||
|  | ||||
| 	logger.SysLog("database migration started") | ||||
| 	if err = migrateDB(); err != nil { | ||||
| 		logger.FatalLog("failed to migrate database: " + err.Error()) | ||||
| 		return | ||||
| 	} | ||||
| 	logger.SysLog("database migrated") | ||||
| } | ||||
|  | ||||
| func migrateDB() error { | ||||
| 	var err error | ||||
| 	if err = DB.AutoMigrate(&Channel{}); err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 	if err = DB.AutoMigrate(&Token{}); err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 	if err = DB.AutoMigrate(&User{}); err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 	if err = DB.AutoMigrate(&Option{}); err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 	if err = DB.AutoMigrate(&Redemption{}); err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 	if err = DB.AutoMigrate(&Ability{}); err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 	if err = DB.AutoMigrate(&Log{}); err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 	if err = DB.AutoMigrate(&Channel{}); err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func InitLogDB() { | ||||
| 	if os.Getenv("LOG_SQL_DSN") == "" { | ||||
| 		LOG_DB = DB | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	logger.SysLog("using secondary database for table logs") | ||||
| 	var err error | ||||
| 	LOG_DB, err = chooseDB("LOG_SQL_DSN") | ||||
| 	if err != nil { | ||||
| 		logger.FatalLog("failed to initialize secondary database: " + err.Error()) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	setDBConns(LOG_DB) | ||||
|  | ||||
| 	if !config.IsMasterNode { | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	logger.SysLog("secondary database migration started") | ||||
| 	err = migrateLOGDB() | ||||
| 	if err != nil { | ||||
| 		logger.FatalLog("failed to migrate secondary database: " + err.Error()) | ||||
| 		return | ||||
| 	} | ||||
| 	logger.SysLog("secondary database migrated") | ||||
| } | ||||
|  | ||||
| func migrateLOGDB() error { | ||||
| 	var err error | ||||
| 	if err = LOG_DB.AutoMigrate(&Log{}); err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func setDBConns(db *gorm.DB) *sql.DB { | ||||
| 	if config.DebugSQLEnabled { | ||||
| 		db = db.Debug() | ||||
| 	} | ||||
|  | ||||
| 	sqlDB, err := db.DB() | ||||
| 	if err != nil { | ||||
| 		logger.FatalLog("failed to connect database: " + err.Error()) | ||||
| 		return nil | ||||
| 	} | ||||
|  | ||||
| 	sqlDB.SetMaxIdleConns(env.Int("SQL_MAX_IDLE_CONNS", 100)) | ||||
| 	sqlDB.SetMaxOpenConns(env.Int("SQL_MAX_OPEN_CONNS", 1000)) | ||||
| 	sqlDB.SetConnMaxLifetime(time.Second * time.Duration(env.Int("SQL_MAX_LIFETIME", 60))) | ||||
| 	return sqlDB | ||||
| } | ||||
|  | ||||
| func closeDB(db *gorm.DB) error { | ||||
|   | ||||
| @@ -29,12 +29,30 @@ func stopReasonClaude2OpenAI(reason *string) string { | ||||
| 		return "stop" | ||||
| 	case "max_tokens": | ||||
| 		return "length" | ||||
| 	case "tool_use": | ||||
| 		return "tool_calls" | ||||
| 	default: | ||||
| 		return *reason | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func ConvertRequest(textRequest model.GeneralOpenAIRequest) *Request { | ||||
| 	claudeTools := make([]Tool, 0, len(textRequest.Tools)) | ||||
|  | ||||
| 	for _, tool := range textRequest.Tools { | ||||
| 		if params, ok := tool.Function.Parameters.(map[string]any); ok { | ||||
| 			claudeTools = append(claudeTools, Tool{ | ||||
| 				Name:        tool.Function.Name, | ||||
| 				Description: tool.Function.Description, | ||||
| 				InputSchema: InputSchema{ | ||||
| 					Type:       params["type"].(string), | ||||
| 					Properties: params["properties"], | ||||
| 					Required:   params["required"], | ||||
| 				}, | ||||
| 			}) | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	claudeRequest := Request{ | ||||
| 		Model:       textRequest.Model, | ||||
| 		MaxTokens:   textRequest.MaxTokens, | ||||
| @@ -42,6 +60,24 @@ func ConvertRequest(textRequest model.GeneralOpenAIRequest) *Request { | ||||
| 		TopP:        textRequest.TopP, | ||||
| 		TopK:        textRequest.TopK, | ||||
| 		Stream:      textRequest.Stream, | ||||
| 		Tools:       claudeTools, | ||||
| 	} | ||||
| 	if len(claudeTools) > 0 { | ||||
| 		claudeToolChoice := struct { | ||||
| 			Type string `json:"type"` | ||||
| 			Name string `json:"name,omitempty"` | ||||
| 		}{Type: "auto"} // default value https://docs.anthropic.com/en/docs/build-with-claude/tool-use#controlling-claudes-output | ||||
| 		if choice, ok := textRequest.ToolChoice.(map[string]any); ok { | ||||
| 			if function, ok := choice["function"].(map[string]any); ok { | ||||
| 				claudeToolChoice.Type = "tool" | ||||
| 				claudeToolChoice.Name = function["name"].(string) | ||||
| 			} | ||||
| 		} else if toolChoiceType, ok := textRequest.ToolChoice.(string); ok { | ||||
| 			if toolChoiceType == "any" { | ||||
| 				claudeToolChoice.Type = toolChoiceType | ||||
| 			} | ||||
| 		} | ||||
| 		claudeRequest.ToolChoice = claudeToolChoice | ||||
| 	} | ||||
| 	if claudeRequest.MaxTokens == 0 { | ||||
| 		claudeRequest.MaxTokens = 4096 | ||||
| @@ -64,7 +100,24 @@ func ConvertRequest(textRequest model.GeneralOpenAIRequest) *Request { | ||||
| 		if message.IsStringContent() { | ||||
| 			content.Type = "text" | ||||
| 			content.Text = message.StringContent() | ||||
| 			if message.Role == "tool" { | ||||
| 				claudeMessage.Role = "user" | ||||
| 				content.Type = "tool_result" | ||||
| 				content.Content = content.Text | ||||
| 				content.Text = "" | ||||
| 				content.ToolUseId = message.ToolCallId | ||||
| 			} | ||||
| 			claudeMessage.Content = append(claudeMessage.Content, content) | ||||
| 			for i := range message.ToolCalls { | ||||
| 				inputParam := make(map[string]any) | ||||
| 				_ = json.Unmarshal([]byte(message.ToolCalls[i].Function.Arguments.(string)), &inputParam) | ||||
| 				claudeMessage.Content = append(claudeMessage.Content, Content{ | ||||
| 					Type:  "tool_use", | ||||
| 					Id:    message.ToolCalls[i].Id, | ||||
| 					Name:  message.ToolCalls[i].Function.Name, | ||||
| 					Input: inputParam, | ||||
| 				}) | ||||
| 			} | ||||
| 			claudeRequest.Messages = append(claudeRequest.Messages, claudeMessage) | ||||
| 			continue | ||||
| 		} | ||||
| @@ -97,16 +150,35 @@ func StreamResponseClaude2OpenAI(claudeResponse *StreamResponse) (*openai.ChatCo | ||||
| 	var response *Response | ||||
| 	var responseText string | ||||
| 	var stopReason string | ||||
| 	tools := make([]model.Tool, 0) | ||||
|  | ||||
| 	switch claudeResponse.Type { | ||||
| 	case "message_start": | ||||
| 		return nil, claudeResponse.Message | ||||
| 	case "content_block_start": | ||||
| 		if claudeResponse.ContentBlock != nil { | ||||
| 			responseText = claudeResponse.ContentBlock.Text | ||||
| 			if claudeResponse.ContentBlock.Type == "tool_use" { | ||||
| 				tools = append(tools, model.Tool{ | ||||
| 					Id:   claudeResponse.ContentBlock.Id, | ||||
| 					Type: "function", | ||||
| 					Function: model.Function{ | ||||
| 						Name:      claudeResponse.ContentBlock.Name, | ||||
| 						Arguments: "", | ||||
| 					}, | ||||
| 				}) | ||||
| 			} | ||||
| 		} | ||||
| 	case "content_block_delta": | ||||
| 		if claudeResponse.Delta != nil { | ||||
| 			responseText = claudeResponse.Delta.Text | ||||
| 			if claudeResponse.Delta.Type == "input_json_delta" { | ||||
| 				tools = append(tools, model.Tool{ | ||||
| 					Function: model.Function{ | ||||
| 						Arguments: claudeResponse.Delta.PartialJson, | ||||
| 					}, | ||||
| 				}) | ||||
| 			} | ||||
| 		} | ||||
| 	case "message_delta": | ||||
| 		if claudeResponse.Usage != nil { | ||||
| @@ -120,6 +192,10 @@ func StreamResponseClaude2OpenAI(claudeResponse *StreamResponse) (*openai.ChatCo | ||||
| 	} | ||||
| 	var choice openai.ChatCompletionsStreamResponseChoice | ||||
| 	choice.Delta.Content = responseText | ||||
| 	if len(tools) > 0 { | ||||
| 		choice.Delta.Content = nil // compatible with other OpenAI derivative applications, like LobeOpenAICompatibleFactory ... | ||||
| 		choice.Delta.ToolCalls = tools | ||||
| 	} | ||||
| 	choice.Delta.Role = "assistant" | ||||
| 	finishReason := stopReasonClaude2OpenAI(&stopReason) | ||||
| 	if finishReason != "null" { | ||||
| @@ -136,12 +212,27 @@ func ResponseClaude2OpenAI(claudeResponse *Response) *openai.TextResponse { | ||||
| 	if len(claudeResponse.Content) > 0 { | ||||
| 		responseText = claudeResponse.Content[0].Text | ||||
| 	} | ||||
| 	tools := make([]model.Tool, 0) | ||||
| 	for _, v := range claudeResponse.Content { | ||||
| 		if v.Type == "tool_use" { | ||||
| 			args, _ := json.Marshal(v.Input) | ||||
| 			tools = append(tools, model.Tool{ | ||||
| 				Id:   v.Id, | ||||
| 				Type: "function", // compatible with other OpenAI derivative applications | ||||
| 				Function: model.Function{ | ||||
| 					Name:      v.Name, | ||||
| 					Arguments: string(args), | ||||
| 				}, | ||||
| 			}) | ||||
| 		} | ||||
| 	} | ||||
| 	choice := openai.TextResponseChoice{ | ||||
| 		Index: 0, | ||||
| 		Message: model.Message{ | ||||
| 			Role:    "assistant", | ||||
| 			Content: responseText, | ||||
| 			Name:    nil, | ||||
| 			Role:      "assistant", | ||||
| 			Content:   responseText, | ||||
| 			Name:      nil, | ||||
| 			ToolCalls: tools, | ||||
| 		}, | ||||
| 		FinishReason: stopReasonClaude2OpenAI(claudeResponse.StopReason), | ||||
| 	} | ||||
| @@ -176,6 +267,7 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC | ||||
| 	var usage model.Usage | ||||
| 	var modelName string | ||||
| 	var id string | ||||
| 	var lastToolCallChoice openai.ChatCompletionsStreamResponseChoice | ||||
|  | ||||
| 	for scanner.Scan() { | ||||
| 		data := scanner.Text() | ||||
| @@ -196,9 +288,20 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC | ||||
| 		if meta != nil { | ||||
| 			usage.PromptTokens += meta.Usage.InputTokens | ||||
| 			usage.CompletionTokens += meta.Usage.OutputTokens | ||||
| 			modelName = meta.Model | ||||
| 			id = fmt.Sprintf("chatcmpl-%s", meta.Id) | ||||
| 			continue | ||||
| 			if len(meta.Id) > 0 { // only message_start has an id, otherwise it's a finish_reason event. | ||||
| 				modelName = meta.Model | ||||
| 				id = fmt.Sprintf("chatcmpl-%s", meta.Id) | ||||
| 				continue | ||||
| 			} else { // finish_reason case | ||||
| 				if len(lastToolCallChoice.Delta.ToolCalls) > 0 { | ||||
| 					lastArgs := &lastToolCallChoice.Delta.ToolCalls[len(lastToolCallChoice.Delta.ToolCalls)-1].Function | ||||
| 					if len(lastArgs.Arguments.(string)) == 0 { // compatible with OpenAI sending an empty object `{}` when no arguments. | ||||
| 						lastArgs.Arguments = "{}" | ||||
| 						response.Choices[len(response.Choices)-1].Delta.Content = nil | ||||
| 						response.Choices[len(response.Choices)-1].Delta.ToolCalls = lastToolCallChoice.Delta.ToolCalls | ||||
| 					} | ||||
| 				} | ||||
| 			} | ||||
| 		} | ||||
| 		if response == nil { | ||||
| 			continue | ||||
| @@ -207,6 +310,12 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC | ||||
| 		response.Id = id | ||||
| 		response.Model = modelName | ||||
| 		response.Created = createdTime | ||||
|  | ||||
| 		for _, choice := range response.Choices { | ||||
| 			if len(choice.Delta.ToolCalls) > 0 { | ||||
| 				lastToolCallChoice = choice | ||||
| 			} | ||||
| 		} | ||||
| 		err = render.ObjectData(c, response) | ||||
| 		if err != nil { | ||||
| 			logger.SysError(err.Error()) | ||||
|   | ||||
| @@ -16,6 +16,12 @@ type Content struct { | ||||
| 	Type   string       `json:"type"` | ||||
| 	Text   string       `json:"text,omitempty"` | ||||
| 	Source *ImageSource `json:"source,omitempty"` | ||||
| 	// tool_calls | ||||
| 	Id        string `json:"id,omitempty"` | ||||
| 	Name      string `json:"name,omitempty"` | ||||
| 	Input     any    `json:"input,omitempty"` | ||||
| 	Content   string `json:"content,omitempty"` | ||||
| 	ToolUseId string `json:"tool_use_id,omitempty"` | ||||
| } | ||||
|  | ||||
| type Message struct { | ||||
| @@ -23,6 +29,18 @@ type Message struct { | ||||
| 	Content []Content `json:"content"` | ||||
| } | ||||
|  | ||||
| type Tool struct { | ||||
| 	Name        string      `json:"name"` | ||||
| 	Description string      `json:"description,omitempty"` | ||||
| 	InputSchema InputSchema `json:"input_schema"` | ||||
| } | ||||
|  | ||||
| type InputSchema struct { | ||||
| 	Type       string `json:"type"` | ||||
| 	Properties any    `json:"properties,omitempty"` | ||||
| 	Required   any    `json:"required,omitempty"` | ||||
| } | ||||
|  | ||||
| type Request struct { | ||||
| 	Model         string    `json:"model"` | ||||
| 	Messages      []Message `json:"messages"` | ||||
| @@ -33,6 +51,8 @@ type Request struct { | ||||
| 	Temperature   float64   `json:"temperature,omitempty"` | ||||
| 	TopP          float64   `json:"top_p,omitempty"` | ||||
| 	TopK          int       `json:"top_k,omitempty"` | ||||
| 	Tools         []Tool    `json:"tools,omitempty"` | ||||
| 	ToolChoice    any       `json:"tool_choice,omitempty"` | ||||
| 	//Metadata    `json:"metadata,omitempty"` | ||||
| } | ||||
|  | ||||
| @@ -61,6 +81,7 @@ type Response struct { | ||||
| type Delta struct { | ||||
| 	Type         string  `json:"type"` | ||||
| 	Text         string  `json:"text"` | ||||
| 	PartialJson  string  `json:"partial_json,omitempty"` | ||||
| 	StopReason   *string `json:"stop_reason"` | ||||
| 	StopSequence *string `json:"stop_sequence"` | ||||
| } | ||||
|   | ||||
| @@ -1,17 +1,16 @@ | ||||
| package aws | ||||
| 
 | ||||
| import ( | ||||
| 	"github.com/aws/aws-sdk-go-v2/aws" | ||||
| 	"github.com/aws/aws-sdk-go-v2/credentials" | ||||
| 	"github.com/aws/aws-sdk-go-v2/service/bedrockruntime" | ||||
| 	"github.com/songquanpeng/one-api/common/ctxkey" | ||||
| 	"errors" | ||||
| 	"io" | ||||
| 	"net/http" | ||||
| 
 | ||||
| 	"github.com/aws/aws-sdk-go-v2/aws" | ||||
| 	"github.com/aws/aws-sdk-go-v2/credentials" | ||||
| 	"github.com/aws/aws-sdk-go-v2/service/bedrockruntime" | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"github.com/pkg/errors" | ||||
| 	"github.com/songquanpeng/one-api/relay/adaptor" | ||||
| 	"github.com/songquanpeng/one-api/relay/adaptor/anthropic" | ||||
| 	"github.com/songquanpeng/one-api/relay/adaptor/aws/utils" | ||||
| 	"github.com/songquanpeng/one-api/relay/meta" | ||||
| 	"github.com/songquanpeng/one-api/relay/model" | ||||
| ) | ||||
| @@ -19,18 +18,52 @@ import ( | ||||
| var _ adaptor.Adaptor = new(Adaptor) | ||||
| 
 | ||||
| type Adaptor struct { | ||||
| 	meta      *meta.Meta | ||||
| 	awsClient *bedrockruntime.Client | ||||
| 	awsAdapter utils.AwsAdapter | ||||
| 
 | ||||
| 	Meta      *meta.Meta | ||||
| 	AwsClient *bedrockruntime.Client | ||||
| } | ||||
| 
 | ||||
| func (a *Adaptor) Init(meta *meta.Meta) { | ||||
| 	a.meta = meta | ||||
| 	a.awsClient = bedrockruntime.New(bedrockruntime.Options{ | ||||
| 	a.Meta = meta | ||||
| 	a.AwsClient = bedrockruntime.New(bedrockruntime.Options{ | ||||
| 		Region:      meta.Config.Region, | ||||
| 		Credentials: aws.NewCredentialsCache(credentials.NewStaticCredentialsProvider(meta.Config.AK, meta.Config.SK, "")), | ||||
| 	}) | ||||
| } | ||||
| 
 | ||||
| func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) { | ||||
| 	if request == nil { | ||||
| 		return nil, errors.New("request is nil") | ||||
| 	} | ||||
| 
 | ||||
| 	adaptor := GetAdaptor(request.Model) | ||||
| 	if adaptor == nil { | ||||
| 		return nil, errors.New("adaptor not found") | ||||
| 	} | ||||
| 
 | ||||
| 	a.awsAdapter = adaptor | ||||
| 	return adaptor.ConvertRequest(c, relayMode, request) | ||||
| } | ||||
| 
 | ||||
| func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) { | ||||
| 	if a.awsAdapter == nil { | ||||
| 		return nil, utils.WrapErr(errors.New("awsAdapter is nil")) | ||||
| 	} | ||||
| 	return a.awsAdapter.DoResponse(c, a.AwsClient, meta) | ||||
| } | ||||
| 
 | ||||
| func (a *Adaptor) GetModelList() (models []string) { | ||||
| 	for model := range adaptors { | ||||
| 		models = append(models, model) | ||||
| 	} | ||||
| 	return | ||||
| } | ||||
| 
 | ||||
| func (a *Adaptor) GetChannelName() string { | ||||
| 	return "aws" | ||||
| } | ||||
| 
 | ||||
| func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) { | ||||
| 	return "", nil | ||||
| } | ||||
| @@ -39,17 +72,6 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *me | ||||
| 	return nil | ||||
| } | ||||
| 
 | ||||
| func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) { | ||||
| 	if request == nil { | ||||
| 		return nil, errors.New("request is nil") | ||||
| 	} | ||||
| 
 | ||||
| 	claudeReq := anthropic.ConvertRequest(*request) | ||||
| 	c.Set(ctxkey.RequestModel, request.Model) | ||||
| 	c.Set(ctxkey.ConvertedRequest, claudeReq) | ||||
| 	return claudeReq, nil | ||||
| } | ||||
| 
 | ||||
| func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) { | ||||
| 	if request == nil { | ||||
| 		return nil, errors.New("request is nil") | ||||
| @@ -60,23 +82,3 @@ func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) | ||||
| func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) { | ||||
| 	return nil, nil | ||||
| } | ||||
| 
 | ||||
| func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) { | ||||
| 	if meta.IsStream { | ||||
| 		err, usage = StreamHandler(c, a.awsClient) | ||||
| 	} else { | ||||
| 		err, usage = Handler(c, a.awsClient, meta.ActualModelName) | ||||
| 	} | ||||
| 	return | ||||
| } | ||||
| 
 | ||||
| func (a *Adaptor) GetModelList() (models []string) { | ||||
| 	for n := range awsModelIDMap { | ||||
| 		models = append(models, n) | ||||
| 	} | ||||
| 	return | ||||
| } | ||||
| 
 | ||||
| func (a *Adaptor) GetChannelName() string { | ||||
| 	return "aws" | ||||
| } | ||||
							
								
								
									
										37
									
								
								relay/adaptor/aws/claude/adapter.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										37
									
								
								relay/adaptor/aws/claude/adapter.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,37 @@ | ||||
| package aws | ||||
|  | ||||
| import ( | ||||
| 	"github.com/aws/aws-sdk-go-v2/service/bedrockruntime" | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"github.com/pkg/errors" | ||||
| 	"github.com/songquanpeng/one-api/common/ctxkey" | ||||
| 	"github.com/songquanpeng/one-api/relay/adaptor/anthropic" | ||||
| 	"github.com/songquanpeng/one-api/relay/adaptor/aws/utils" | ||||
| 	"github.com/songquanpeng/one-api/relay/meta" | ||||
| 	"github.com/songquanpeng/one-api/relay/model" | ||||
| ) | ||||
|  | ||||
| var _ utils.AwsAdapter = new(Adaptor) | ||||
|  | ||||
| type Adaptor struct { | ||||
| } | ||||
|  | ||||
| func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) { | ||||
| 	if request == nil { | ||||
| 		return nil, errors.New("request is nil") | ||||
| 	} | ||||
|  | ||||
| 	claudeReq := anthropic.ConvertRequest(*request) | ||||
| 	c.Set(ctxkey.RequestModel, request.Model) | ||||
| 	c.Set(ctxkey.ConvertedRequest, claudeReq) | ||||
| 	return claudeReq, nil | ||||
| } | ||||
|  | ||||
| func (a *Adaptor) DoResponse(c *gin.Context, awsCli *bedrockruntime.Client, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) { | ||||
| 	if meta.IsStream { | ||||
| 		err, usage = StreamHandler(c, awsCli) | ||||
| 	} else { | ||||
| 		err, usage = Handler(c, awsCli, meta.ActualModelName) | ||||
| 	} | ||||
| 	return | ||||
| } | ||||
| @@ -5,7 +5,6 @@ import ( | ||||
| 	"bytes" | ||||
| 	"encoding/json" | ||||
| 	"fmt" | ||||
| 	"github.com/songquanpeng/one-api/common/ctxkey" | ||||
| 	"io" | ||||
| 	"net/http" | ||||
| 
 | ||||
| @@ -16,23 +15,17 @@ import ( | ||||
| 	"github.com/jinzhu/copier" | ||||
| 	"github.com/pkg/errors" | ||||
| 	"github.com/songquanpeng/one-api/common" | ||||
| 	"github.com/songquanpeng/one-api/common/ctxkey" | ||||
| 	"github.com/songquanpeng/one-api/common/helper" | ||||
| 	"github.com/songquanpeng/one-api/common/logger" | ||||
| 	"github.com/songquanpeng/one-api/relay/adaptor/anthropic" | ||||
| 	"github.com/songquanpeng/one-api/relay/adaptor/aws/utils" | ||||
| 	"github.com/songquanpeng/one-api/relay/adaptor/openai" | ||||
| 	relaymodel "github.com/songquanpeng/one-api/relay/model" | ||||
| ) | ||||
| 
 | ||||
| func wrapErr(err error) *relaymodel.ErrorWithStatusCode { | ||||
| 	return &relaymodel.ErrorWithStatusCode{ | ||||
| 		StatusCode: http.StatusInternalServerError, | ||||
| 		Error: relaymodel.Error{ | ||||
| 			Message: fmt.Sprintf("%s", err.Error()), | ||||
| 		}, | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| // https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids.html | ||||
| var awsModelIDMap = map[string]string{ | ||||
| var AwsModelIDMap = map[string]string{ | ||||
| 	"claude-instant-1.2":         "anthropic.claude-instant-v1", | ||||
| 	"claude-2.0":                 "anthropic.claude-v2", | ||||
| 	"claude-2.1":                 "anthropic.claude-v2:1", | ||||
| @@ -43,7 +36,7 @@ var awsModelIDMap = map[string]string{ | ||||
| } | ||||
| 
 | ||||
| func awsModelID(requestModel string) (string, error) { | ||||
| 	if awsModelID, ok := awsModelIDMap[requestModel]; ok { | ||||
| 	if awsModelID, ok := AwsModelIDMap[requestModel]; ok { | ||||
| 		return awsModelID, nil | ||||
| 	} | ||||
| 
 | ||||
| @@ -53,7 +46,7 @@ func awsModelID(requestModel string) (string, error) { | ||||
| func Handler(c *gin.Context, awsCli *bedrockruntime.Client, modelName string) (*relaymodel.ErrorWithStatusCode, *relaymodel.Usage) { | ||||
| 	awsModelId, err := awsModelID(c.GetString(ctxkey.RequestModel)) | ||||
| 	if err != nil { | ||||
| 		return wrapErr(errors.Wrap(err, "awsModelID")), nil | ||||
| 		return utils.WrapErr(errors.Wrap(err, "awsModelID")), nil | ||||
| 	} | ||||
| 
 | ||||
| 	awsReq := &bedrockruntime.InvokeModelInput{ | ||||
| @@ -64,30 +57,30 @@ func Handler(c *gin.Context, awsCli *bedrockruntime.Client, modelName string) (* | ||||
| 
 | ||||
| 	claudeReq_, ok := c.Get(ctxkey.ConvertedRequest) | ||||
| 	if !ok { | ||||
| 		return wrapErr(errors.New("request not found")), nil | ||||
| 		return utils.WrapErr(errors.New("request not found")), nil | ||||
| 	} | ||||
| 	claudeReq := claudeReq_.(*anthropic.Request) | ||||
| 	awsClaudeReq := &Request{ | ||||
| 		AnthropicVersion: "bedrock-2023-05-31", | ||||
| 	} | ||||
| 	if err = copier.Copy(awsClaudeReq, claudeReq); err != nil { | ||||
| 		return wrapErr(errors.Wrap(err, "copy request")), nil | ||||
| 		return utils.WrapErr(errors.Wrap(err, "copy request")), nil | ||||
| 	} | ||||
| 
 | ||||
| 	awsReq.Body, err = json.Marshal(awsClaudeReq) | ||||
| 	if err != nil { | ||||
| 		return wrapErr(errors.Wrap(err, "marshal request")), nil | ||||
| 		return utils.WrapErr(errors.Wrap(err, "marshal request")), nil | ||||
| 	} | ||||
| 
 | ||||
| 	awsResp, err := awsCli.InvokeModel(c.Request.Context(), awsReq) | ||||
| 	if err != nil { | ||||
| 		return wrapErr(errors.Wrap(err, "InvokeModel")), nil | ||||
| 		return utils.WrapErr(errors.Wrap(err, "InvokeModel")), nil | ||||
| 	} | ||||
| 
 | ||||
| 	claudeResponse := new(anthropic.Response) | ||||
| 	err = json.Unmarshal(awsResp.Body, claudeResponse) | ||||
| 	if err != nil { | ||||
| 		return wrapErr(errors.Wrap(err, "unmarshal response")), nil | ||||
| 		return utils.WrapErr(errors.Wrap(err, "unmarshal response")), nil | ||||
| 	} | ||||
| 
 | ||||
| 	openaiResp := anthropic.ResponseClaude2OpenAI(claudeResponse) | ||||
| @@ -107,7 +100,7 @@ func StreamHandler(c *gin.Context, awsCli *bedrockruntime.Client) (*relaymodel.E | ||||
| 	createdTime := helper.GetTimestamp() | ||||
| 	awsModelId, err := awsModelID(c.GetString(ctxkey.RequestModel)) | ||||
| 	if err != nil { | ||||
| 		return wrapErr(errors.Wrap(err, "awsModelID")), nil | ||||
| 		return utils.WrapErr(errors.Wrap(err, "awsModelID")), nil | ||||
| 	} | ||||
| 
 | ||||
| 	awsReq := &bedrockruntime.InvokeModelWithResponseStreamInput{ | ||||
| @@ -118,7 +111,7 @@ func StreamHandler(c *gin.Context, awsCli *bedrockruntime.Client) (*relaymodel.E | ||||
| 
 | ||||
| 	claudeReq_, ok := c.Get(ctxkey.ConvertedRequest) | ||||
| 	if !ok { | ||||
| 		return wrapErr(errors.New("request not found")), nil | ||||
| 		return utils.WrapErr(errors.New("request not found")), nil | ||||
| 	} | ||||
| 	claudeReq := claudeReq_.(*anthropic.Request) | ||||
| 
 | ||||
| @@ -126,16 +119,16 @@ func StreamHandler(c *gin.Context, awsCli *bedrockruntime.Client) (*relaymodel.E | ||||
| 		AnthropicVersion: "bedrock-2023-05-31", | ||||
| 	} | ||||
| 	if err = copier.Copy(awsClaudeReq, claudeReq); err != nil { | ||||
| 		return wrapErr(errors.Wrap(err, "copy request")), nil | ||||
| 		return utils.WrapErr(errors.Wrap(err, "copy request")), nil | ||||
| 	} | ||||
| 	awsReq.Body, err = json.Marshal(awsClaudeReq) | ||||
| 	if err != nil { | ||||
| 		return wrapErr(errors.Wrap(err, "marshal request")), nil | ||||
| 		return utils.WrapErr(errors.Wrap(err, "marshal request")), nil | ||||
| 	} | ||||
| 
 | ||||
| 	awsResp, err := awsCli.InvokeModelWithResponseStream(c.Request.Context(), awsReq) | ||||
| 	if err != nil { | ||||
| 		return wrapErr(errors.Wrap(err, "InvokeModelWithResponseStream")), nil | ||||
| 		return utils.WrapErr(errors.Wrap(err, "InvokeModelWithResponseStream")), nil | ||||
| 	} | ||||
| 	stream := awsResp.GetStream() | ||||
| 	defer stream.Close() | ||||
| @@ -143,6 +136,8 @@ func StreamHandler(c *gin.Context, awsCli *bedrockruntime.Client) (*relaymodel.E | ||||
| 	c.Writer.Header().Set("Content-Type", "text/event-stream") | ||||
| 	var usage relaymodel.Usage | ||||
| 	var id string | ||||
| 	var lastToolCallChoice openai.ChatCompletionsStreamResponseChoice | ||||
| 
 | ||||
| 	c.Stream(func(w io.Writer) bool { | ||||
| 		event, ok := <-stream.Events() | ||||
| 		if !ok { | ||||
| @@ -163,8 +158,19 @@ func StreamHandler(c *gin.Context, awsCli *bedrockruntime.Client) (*relaymodel.E | ||||
| 			if meta != nil { | ||||
| 				usage.PromptTokens += meta.Usage.InputTokens | ||||
| 				usage.CompletionTokens += meta.Usage.OutputTokens | ||||
| 				id = fmt.Sprintf("chatcmpl-%s", meta.Id) | ||||
| 				return true | ||||
| 				if len(meta.Id) > 0 { // only message_start has an id, otherwise it's a finish_reason event. | ||||
| 					id = fmt.Sprintf("chatcmpl-%s", meta.Id) | ||||
| 					return true | ||||
| 				} else { // finish_reason case | ||||
| 					if len(lastToolCallChoice.Delta.ToolCalls) > 0 { | ||||
| 						lastArgs := &lastToolCallChoice.Delta.ToolCalls[len(lastToolCallChoice.Delta.ToolCalls)-1].Function | ||||
| 						if len(lastArgs.Arguments.(string)) == 0 { // compatible with OpenAI sending an empty object `{}` when no arguments. | ||||
| 							lastArgs.Arguments = "{}" | ||||
| 							response.Choices[len(response.Choices)-1].Delta.Content = nil | ||||
| 							response.Choices[len(response.Choices)-1].Delta.ToolCalls = lastToolCallChoice.Delta.ToolCalls | ||||
| 						} | ||||
| 					} | ||||
| 				} | ||||
| 			} | ||||
| 			if response == nil { | ||||
| 				return true | ||||
| @@ -172,6 +178,12 @@ func StreamHandler(c *gin.Context, awsCli *bedrockruntime.Client) (*relaymodel.E | ||||
| 			response.Id = id | ||||
| 			response.Model = c.GetString(ctxkey.OriginalModel) | ||||
| 			response.Created = createdTime | ||||
| 
 | ||||
| 			for _, choice := range response.Choices { | ||||
| 				if len(choice.Delta.ToolCalls) > 0 { | ||||
| 					lastToolCallChoice = choice | ||||
| 				} | ||||
| 			} | ||||
| 			jsonStr, err := json.Marshal(response) | ||||
| 			if err != nil { | ||||
| 				logger.SysError("error marshalling stream response: " + err.Error()) | ||||
| @@ -9,9 +9,12 @@ type Request struct { | ||||
| 	// AnthropicVersion should be "bedrock-2023-05-31" | ||||
| 	AnthropicVersion string              `json:"anthropic_version"` | ||||
| 	Messages         []anthropic.Message `json:"messages"` | ||||
| 	System           string              `json:"system,omitempty"` | ||||
| 	MaxTokens        int                 `json:"max_tokens,omitempty"` | ||||
| 	Temperature      float64             `json:"temperature,omitempty"` | ||||
| 	TopP             float64             `json:"top_p,omitempty"` | ||||
| 	TopK             int                 `json:"top_k,omitempty"` | ||||
| 	StopSequences    []string            `json:"stop_sequences,omitempty"` | ||||
| 	Tools            []anthropic.Tool    `json:"tools,omitempty"` | ||||
| 	ToolChoice       any                 `json:"tool_choice,omitempty"` | ||||
| } | ||||
							
								
								
									
										37
									
								
								relay/adaptor/aws/llama3/adapter.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										37
									
								
								relay/adaptor/aws/llama3/adapter.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,37 @@ | ||||
| package aws | ||||
|  | ||||
| import ( | ||||
| 	"github.com/aws/aws-sdk-go-v2/service/bedrockruntime" | ||||
| 	"github.com/songquanpeng/one-api/common/ctxkey" | ||||
|  | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"github.com/pkg/errors" | ||||
| 	"github.com/songquanpeng/one-api/relay/adaptor/aws/utils" | ||||
| 	"github.com/songquanpeng/one-api/relay/meta" | ||||
| 	"github.com/songquanpeng/one-api/relay/model" | ||||
| ) | ||||
|  | ||||
| var _ utils.AwsAdapter = new(Adaptor) | ||||
|  | ||||
| type Adaptor struct { | ||||
| } | ||||
|  | ||||
| func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) { | ||||
| 	if request == nil { | ||||
| 		return nil, errors.New("request is nil") | ||||
| 	} | ||||
|  | ||||
| 	llamaReq := ConvertRequest(*request) | ||||
| 	c.Set(ctxkey.RequestModel, request.Model) | ||||
| 	c.Set(ctxkey.ConvertedRequest, llamaReq) | ||||
| 	return llamaReq, nil | ||||
| } | ||||
|  | ||||
| func (a *Adaptor) DoResponse(c *gin.Context, awsCli *bedrockruntime.Client, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) { | ||||
| 	if meta.IsStream { | ||||
| 		err, usage = StreamHandler(c, awsCli) | ||||
| 	} else { | ||||
| 		err, usage = Handler(c, awsCli, meta.ActualModelName) | ||||
| 	} | ||||
| 	return | ||||
| } | ||||
							
								
								
									
										231
									
								
								relay/adaptor/aws/llama3/main.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										231
									
								
								relay/adaptor/aws/llama3/main.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,231 @@ | ||||
| // Package aws provides the AWS adaptor for the relay service. | ||||
| package aws | ||||
|  | ||||
| import ( | ||||
| 	"bytes" | ||||
| 	"encoding/json" | ||||
| 	"fmt" | ||||
| 	"io" | ||||
| 	"net/http" | ||||
| 	"text/template" | ||||
|  | ||||
| 	"github.com/songquanpeng/one-api/common/ctxkey" | ||||
| 	"github.com/songquanpeng/one-api/common/random" | ||||
|  | ||||
| 	"github.com/aws/aws-sdk-go-v2/aws" | ||||
| 	"github.com/aws/aws-sdk-go-v2/service/bedrockruntime" | ||||
| 	"github.com/aws/aws-sdk-go-v2/service/bedrockruntime/types" | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"github.com/pkg/errors" | ||||
| 	"github.com/songquanpeng/one-api/common" | ||||
| 	"github.com/songquanpeng/one-api/common/helper" | ||||
| 	"github.com/songquanpeng/one-api/common/logger" | ||||
| 	"github.com/songquanpeng/one-api/relay/adaptor/aws/utils" | ||||
| 	"github.com/songquanpeng/one-api/relay/adaptor/openai" | ||||
| 	relaymodel "github.com/songquanpeng/one-api/relay/model" | ||||
| ) | ||||
|  | ||||
| // Only support llama-3-8b and llama-3-70b instruction models | ||||
| // https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids.html | ||||
| var AwsModelIDMap = map[string]string{ | ||||
| 	"llama3-8b-8192":  "meta.llama3-8b-instruct-v1:0", | ||||
| 	"llama3-70b-8192": "meta.llama3-70b-instruct-v1:0", | ||||
| } | ||||
|  | ||||
| func awsModelID(requestModel string) (string, error) { | ||||
| 	if awsModelID, ok := AwsModelIDMap[requestModel]; ok { | ||||
| 		return awsModelID, nil | ||||
| 	} | ||||
|  | ||||
| 	return "", errors.Errorf("model %s not found", requestModel) | ||||
| } | ||||
|  | ||||
| // promptTemplate with range | ||||
| const promptTemplate = `<|begin_of_text|>{{range .Messages}}<|start_header_id|>{{.Role}}<|end_header_id|>{{.StringContent}}<|eot_id|>{{end}}<|start_header_id|>assistant<|end_header_id|> | ||||
| ` | ||||
|  | ||||
| var promptTpl = template.Must(template.New("llama3-chat").Parse(promptTemplate)) | ||||
|  | ||||
| func RenderPrompt(messages []relaymodel.Message) string { | ||||
| 	var buf bytes.Buffer | ||||
| 	err := promptTpl.Execute(&buf, struct{ Messages []relaymodel.Message }{messages}) | ||||
| 	if err != nil { | ||||
| 		logger.SysError("error rendering prompt messages: " + err.Error()) | ||||
| 	} | ||||
| 	return buf.String() | ||||
| } | ||||
|  | ||||
| func ConvertRequest(textRequest relaymodel.GeneralOpenAIRequest) *Request { | ||||
| 	llamaRequest := Request{ | ||||
| 		MaxGenLen:   textRequest.MaxTokens, | ||||
| 		Temperature: textRequest.Temperature, | ||||
| 		TopP:        textRequest.TopP, | ||||
| 	} | ||||
| 	if llamaRequest.MaxGenLen == 0 { | ||||
| 		llamaRequest.MaxGenLen = 2048 | ||||
| 	} | ||||
| 	prompt := RenderPrompt(textRequest.Messages) | ||||
| 	llamaRequest.Prompt = prompt | ||||
| 	return &llamaRequest | ||||
| } | ||||
|  | ||||
| func Handler(c *gin.Context, awsCli *bedrockruntime.Client, modelName string) (*relaymodel.ErrorWithStatusCode, *relaymodel.Usage) { | ||||
| 	awsModelId, err := awsModelID(c.GetString(ctxkey.RequestModel)) | ||||
| 	if err != nil { | ||||
| 		return utils.WrapErr(errors.Wrap(err, "awsModelID")), nil | ||||
| 	} | ||||
|  | ||||
| 	awsReq := &bedrockruntime.InvokeModelInput{ | ||||
| 		ModelId:     aws.String(awsModelId), | ||||
| 		Accept:      aws.String("application/json"), | ||||
| 		ContentType: aws.String("application/json"), | ||||
| 	} | ||||
|  | ||||
| 	llamaReq, ok := c.Get(ctxkey.ConvertedRequest) | ||||
| 	if !ok { | ||||
| 		return utils.WrapErr(errors.New("request not found")), nil | ||||
| 	} | ||||
|  | ||||
| 	awsReq.Body, err = json.Marshal(llamaReq) | ||||
| 	if err != nil { | ||||
| 		return utils.WrapErr(errors.Wrap(err, "marshal request")), nil | ||||
| 	} | ||||
|  | ||||
| 	awsResp, err := awsCli.InvokeModel(c.Request.Context(), awsReq) | ||||
| 	if err != nil { | ||||
| 		return utils.WrapErr(errors.Wrap(err, "InvokeModel")), nil | ||||
| 	} | ||||
|  | ||||
| 	var llamaResponse Response | ||||
| 	err = json.Unmarshal(awsResp.Body, &llamaResponse) | ||||
| 	if err != nil { | ||||
| 		return utils.WrapErr(errors.Wrap(err, "unmarshal response")), nil | ||||
| 	} | ||||
|  | ||||
| 	openaiResp := ResponseLlama2OpenAI(&llamaResponse) | ||||
| 	openaiResp.Model = modelName | ||||
| 	usage := relaymodel.Usage{ | ||||
| 		PromptTokens:     llamaResponse.PromptTokenCount, | ||||
| 		CompletionTokens: llamaResponse.GenerationTokenCount, | ||||
| 		TotalTokens:      llamaResponse.PromptTokenCount + llamaResponse.GenerationTokenCount, | ||||
| 	} | ||||
| 	openaiResp.Usage = usage | ||||
|  | ||||
| 	c.JSON(http.StatusOK, openaiResp) | ||||
| 	return nil, &usage | ||||
| } | ||||
|  | ||||
| func ResponseLlama2OpenAI(llamaResponse *Response) *openai.TextResponse { | ||||
| 	var responseText string | ||||
| 	if len(llamaResponse.Generation) > 0 { | ||||
| 		responseText = llamaResponse.Generation | ||||
| 	} | ||||
| 	choice := openai.TextResponseChoice{ | ||||
| 		Index: 0, | ||||
| 		Message: relaymodel.Message{ | ||||
| 			Role:    "assistant", | ||||
| 			Content: responseText, | ||||
| 			Name:    nil, | ||||
| 		}, | ||||
| 		FinishReason: llamaResponse.StopReason, | ||||
| 	} | ||||
| 	fullTextResponse := openai.TextResponse{ | ||||
| 		Id:      fmt.Sprintf("chatcmpl-%s", random.GetUUID()), | ||||
| 		Object:  "chat.completion", | ||||
| 		Created: helper.GetTimestamp(), | ||||
| 		Choices: []openai.TextResponseChoice{choice}, | ||||
| 	} | ||||
| 	return &fullTextResponse | ||||
| } | ||||
|  | ||||
| func StreamHandler(c *gin.Context, awsCli *bedrockruntime.Client) (*relaymodel.ErrorWithStatusCode, *relaymodel.Usage) { | ||||
| 	createdTime := helper.GetTimestamp() | ||||
| 	awsModelId, err := awsModelID(c.GetString(ctxkey.RequestModel)) | ||||
| 	if err != nil { | ||||
| 		return utils.WrapErr(errors.Wrap(err, "awsModelID")), nil | ||||
| 	} | ||||
|  | ||||
| 	awsReq := &bedrockruntime.InvokeModelWithResponseStreamInput{ | ||||
| 		ModelId:     aws.String(awsModelId), | ||||
| 		Accept:      aws.String("application/json"), | ||||
| 		ContentType: aws.String("application/json"), | ||||
| 	} | ||||
|  | ||||
| 	llamaReq, ok := c.Get(ctxkey.ConvertedRequest) | ||||
| 	if !ok { | ||||
| 		return utils.WrapErr(errors.New("request not found")), nil | ||||
| 	} | ||||
|  | ||||
| 	awsReq.Body, err = json.Marshal(llamaReq) | ||||
| 	if err != nil { | ||||
| 		return utils.WrapErr(errors.Wrap(err, "marshal request")), nil | ||||
| 	} | ||||
|  | ||||
| 	awsResp, err := awsCli.InvokeModelWithResponseStream(c.Request.Context(), awsReq) | ||||
| 	if err != nil { | ||||
| 		return utils.WrapErr(errors.Wrap(err, "InvokeModelWithResponseStream")), nil | ||||
| 	} | ||||
| 	stream := awsResp.GetStream() | ||||
| 	defer stream.Close() | ||||
|  | ||||
| 	c.Writer.Header().Set("Content-Type", "text/event-stream") | ||||
| 	var usage relaymodel.Usage | ||||
| 	c.Stream(func(w io.Writer) bool { | ||||
| 		event, ok := <-stream.Events() | ||||
| 		if !ok { | ||||
| 			c.Render(-1, common.CustomEvent{Data: "data: [DONE]"}) | ||||
| 			return false | ||||
| 		} | ||||
|  | ||||
| 		switch v := event.(type) { | ||||
| 		case *types.ResponseStreamMemberChunk: | ||||
| 			var llamaResp StreamResponse | ||||
| 			err := json.NewDecoder(bytes.NewReader(v.Value.Bytes)).Decode(&llamaResp) | ||||
| 			if err != nil { | ||||
| 				logger.SysError("error unmarshalling stream response: " + err.Error()) | ||||
| 				return false | ||||
| 			} | ||||
|  | ||||
| 			if llamaResp.PromptTokenCount > 0 { | ||||
| 				usage.PromptTokens = llamaResp.PromptTokenCount | ||||
| 			} | ||||
| 			if llamaResp.StopReason == "stop" { | ||||
| 				usage.CompletionTokens = llamaResp.GenerationTokenCount | ||||
| 				usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens | ||||
| 			} | ||||
| 			response := StreamResponseLlama2OpenAI(&llamaResp) | ||||
| 			response.Id = fmt.Sprintf("chatcmpl-%s", random.GetUUID()) | ||||
| 			response.Model = c.GetString(ctxkey.OriginalModel) | ||||
| 			response.Created = createdTime | ||||
| 			jsonStr, err := json.Marshal(response) | ||||
| 			if err != nil { | ||||
| 				logger.SysError("error marshalling stream response: " + err.Error()) | ||||
| 				return true | ||||
| 			} | ||||
| 			c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonStr)}) | ||||
| 			return true | ||||
| 		case *types.UnknownUnionMember: | ||||
| 			fmt.Println("unknown tag:", v.Tag) | ||||
| 			return false | ||||
| 		default: | ||||
| 			fmt.Println("union is nil or unknown type") | ||||
| 			return false | ||||
| 		} | ||||
| 	}) | ||||
|  | ||||
| 	return nil, &usage | ||||
| } | ||||
|  | ||||
| func StreamResponseLlama2OpenAI(llamaResponse *StreamResponse) *openai.ChatCompletionsStreamResponse { | ||||
| 	var choice openai.ChatCompletionsStreamResponseChoice | ||||
| 	choice.Delta.Content = llamaResponse.Generation | ||||
| 	choice.Delta.Role = "assistant" | ||||
| 	finishReason := llamaResponse.StopReason | ||||
| 	if finishReason != "null" { | ||||
| 		choice.FinishReason = &finishReason | ||||
| 	} | ||||
| 	var openaiResponse openai.ChatCompletionsStreamResponse | ||||
| 	openaiResponse.Object = "chat.completion.chunk" | ||||
| 	openaiResponse.Choices = []openai.ChatCompletionsStreamResponseChoice{choice} | ||||
| 	return &openaiResponse | ||||
| } | ||||
							
								
								
									
										45
									
								
								relay/adaptor/aws/llama3/main_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										45
									
								
								relay/adaptor/aws/llama3/main_test.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,45 @@ | ||||
| package aws_test | ||||
|  | ||||
| import ( | ||||
| 	"testing" | ||||
|  | ||||
| 	aws "github.com/songquanpeng/one-api/relay/adaptor/aws/llama3" | ||||
| 	relaymodel "github.com/songquanpeng/one-api/relay/model" | ||||
| 	"github.com/stretchr/testify/assert" | ||||
| ) | ||||
|  | ||||
| func TestRenderPrompt(t *testing.T) { | ||||
| 	messages := []relaymodel.Message{ | ||||
| 		{ | ||||
| 			Role:    "user", | ||||
| 			Content: "What's your name?", | ||||
| 		}, | ||||
| 	} | ||||
| 	prompt := aws.RenderPrompt(messages) | ||||
| 	expected := `<|begin_of_text|><|start_header_id|>user<|end_header_id|>What's your name?<|eot_id|><|start_header_id|>assistant<|end_header_id|> | ||||
| ` | ||||
| 	assert.Equal(t, expected, prompt) | ||||
|  | ||||
| 	messages = []relaymodel.Message{ | ||||
| 		{ | ||||
| 			Role:    "system", | ||||
| 			Content: "Your name is Kat. You are a detective.", | ||||
| 		}, | ||||
| 		{ | ||||
| 			Role:    "user", | ||||
| 			Content: "What's your name?", | ||||
| 		}, | ||||
| 		{ | ||||
| 			Role:    "assistant", | ||||
| 			Content: "Kat", | ||||
| 		}, | ||||
| 		{ | ||||
| 			Role:    "user", | ||||
| 			Content: "What's your job?", | ||||
| 		}, | ||||
| 	} | ||||
| 	prompt = aws.RenderPrompt(messages) | ||||
| 	expected = `<|begin_of_text|><|start_header_id|>system<|end_header_id|>Your name is Kat. You are a detective.<|eot_id|><|start_header_id|>user<|end_header_id|>What's your name?<|eot_id|><|start_header_id|>assistant<|end_header_id|>Kat<|eot_id|><|start_header_id|>user<|end_header_id|>What's your job?<|eot_id|><|start_header_id|>assistant<|end_header_id|> | ||||
| ` | ||||
| 	assert.Equal(t, expected, prompt) | ||||
| } | ||||
							
								
								
									
										29
									
								
								relay/adaptor/aws/llama3/model.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										29
									
								
								relay/adaptor/aws/llama3/model.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,29 @@ | ||||
| package aws | ||||
|  | ||||
| // Request is the request to AWS Llama3 | ||||
| // | ||||
| // https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-meta.html | ||||
| type Request struct { | ||||
| 	Prompt      string  `json:"prompt"` | ||||
| 	MaxGenLen   int     `json:"max_gen_len,omitempty"` | ||||
| 	Temperature float64 `json:"temperature,omitempty"` | ||||
| 	TopP        float64 `json:"top_p,omitempty"` | ||||
| } | ||||
|  | ||||
| // Response is the response from AWS Llama3 | ||||
| // | ||||
| // https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-meta.html | ||||
| type Response struct { | ||||
| 	Generation           string `json:"generation"` | ||||
| 	PromptTokenCount     int    `json:"prompt_token_count"` | ||||
| 	GenerationTokenCount int    `json:"generation_token_count"` | ||||
| 	StopReason           string `json:"stop_reason"` | ||||
| } | ||||
|  | ||||
| // {'generation': 'Hi', 'prompt_token_count': 15, 'generation_token_count': 1, 'stop_reason': None} | ||||
| type StreamResponse struct { | ||||
| 	Generation           string `json:"generation"` | ||||
| 	PromptTokenCount     int    `json:"prompt_token_count"` | ||||
| 	GenerationTokenCount int    `json:"generation_token_count"` | ||||
| 	StopReason           string `json:"stop_reason"` | ||||
| } | ||||
							
								
								
									
										39
									
								
								relay/adaptor/aws/registry.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										39
									
								
								relay/adaptor/aws/registry.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,39 @@ | ||||
| package aws | ||||
|  | ||||
| import ( | ||||
| 	claude "github.com/songquanpeng/one-api/relay/adaptor/aws/claude" | ||||
| 	llama3 "github.com/songquanpeng/one-api/relay/adaptor/aws/llama3" | ||||
| 	"github.com/songquanpeng/one-api/relay/adaptor/aws/utils" | ||||
| ) | ||||
|  | ||||
| type AwsModelType int | ||||
|  | ||||
| const ( | ||||
| 	AwsClaude AwsModelType = iota + 1 | ||||
| 	AwsLlama3 | ||||
| ) | ||||
|  | ||||
| var ( | ||||
| 	adaptors = map[string]AwsModelType{} | ||||
| ) | ||||
|  | ||||
| func init() { | ||||
| 	for model := range claude.AwsModelIDMap { | ||||
| 		adaptors[model] = AwsClaude | ||||
| 	} | ||||
| 	for model := range llama3.AwsModelIDMap { | ||||
| 		adaptors[model] = AwsLlama3 | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func GetAdaptor(model string) utils.AwsAdapter { | ||||
| 	adaptorType := adaptors[model] | ||||
| 	switch adaptorType { | ||||
| 	case AwsClaude: | ||||
| 		return &claude.Adaptor{} | ||||
| 	case AwsLlama3: | ||||
| 		return &llama3.Adaptor{} | ||||
| 	default: | ||||
| 		return nil | ||||
| 	} | ||||
| } | ||||
							
								
								
									
										51
									
								
								relay/adaptor/aws/utils/adaptor.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										51
									
								
								relay/adaptor/aws/utils/adaptor.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,51 @@ | ||||
| package utils | ||||
|  | ||||
| import ( | ||||
| 	"errors" | ||||
| 	"io" | ||||
| 	"net/http" | ||||
|  | ||||
| 	"github.com/aws/aws-sdk-go-v2/aws" | ||||
| 	"github.com/aws/aws-sdk-go-v2/credentials" | ||||
| 	"github.com/aws/aws-sdk-go-v2/service/bedrockruntime" | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"github.com/songquanpeng/one-api/relay/meta" | ||||
| 	"github.com/songquanpeng/one-api/relay/model" | ||||
| ) | ||||
|  | ||||
| type AwsAdapter interface { | ||||
| 	ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) | ||||
| 	DoResponse(c *gin.Context, awsCli *bedrockruntime.Client, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) | ||||
| } | ||||
|  | ||||
| type Adaptor struct { | ||||
| 	Meta      *meta.Meta | ||||
| 	AwsClient *bedrockruntime.Client | ||||
| } | ||||
|  | ||||
| func (a *Adaptor) Init(meta *meta.Meta) { | ||||
| 	a.Meta = meta | ||||
| 	a.AwsClient = bedrockruntime.New(bedrockruntime.Options{ | ||||
| 		Region:      meta.Config.Region, | ||||
| 		Credentials: aws.NewCredentialsCache(credentials.NewStaticCredentialsProvider(meta.Config.AK, meta.Config.SK, "")), | ||||
| 	}) | ||||
| } | ||||
|  | ||||
| func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) { | ||||
| 	return "", nil | ||||
| } | ||||
|  | ||||
| func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error { | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) { | ||||
| 	if request == nil { | ||||
| 		return nil, errors.New("request is nil") | ||||
| 	} | ||||
| 	return request, nil | ||||
| } | ||||
|  | ||||
| func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) { | ||||
| 	return nil, nil | ||||
| } | ||||
							
								
								
									
										16
									
								
								relay/adaptor/aws/utils/utils.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										16
									
								
								relay/adaptor/aws/utils/utils.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,16 @@ | ||||
| package utils | ||||
|  | ||||
| import ( | ||||
| 	"net/http" | ||||
|  | ||||
| 	relaymodel "github.com/songquanpeng/one-api/relay/model" | ||||
| ) | ||||
|  | ||||
| func WrapErr(err error) *relaymodel.ErrorWithStatusCode { | ||||
| 	return &relaymodel.ErrorWithStatusCode{ | ||||
| 		StatusCode: http.StatusInternalServerError, | ||||
| 		Error: relaymodel.Error{ | ||||
| 			Message: err.Error(), | ||||
| 		}, | ||||
| 	} | ||||
| } | ||||
| @@ -10,6 +10,7 @@ import ( | ||||
| 	"github.com/songquanpeng/one-api/relay/adaptor" | ||||
| 	"github.com/songquanpeng/one-api/relay/meta" | ||||
| 	"github.com/songquanpeng/one-api/relay/model" | ||||
| 	"github.com/songquanpeng/one-api/relay/relaymode" | ||||
| ) | ||||
|  | ||||
| type Adaptor struct { | ||||
| @@ -28,7 +29,14 @@ func (a *Adaptor) Init(meta *meta.Meta) { | ||||
| } | ||||
|  | ||||
| func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) { | ||||
| 	return fmt.Sprintf("%s/client/v4/accounts/%s/ai/run/%s", meta.BaseURL, meta.Config.UserID, meta.ActualModelName), nil | ||||
| 	switch meta.Mode { | ||||
| 	case relaymode.ChatCompletions: | ||||
| 		return fmt.Sprintf("%s/client/v4/accounts/%s/ai/v1/chat/completions", meta.BaseURL, meta.Config.UserID), nil | ||||
| 	case relaymode.Embeddings: | ||||
| 		return fmt.Sprintf("%s/client/v4/accounts/%s/ai/v1/embeddings", meta.BaseURL, meta.Config.UserID), nil | ||||
| 	default: | ||||
| 		return fmt.Sprintf("%s/client/v4/accounts/%s/ai/run/%s", meta.BaseURL, meta.Config.UserID, meta.ActualModelName), nil | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error { | ||||
| @@ -41,7 +49,14 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G | ||||
| 	if request == nil { | ||||
| 		return nil, errors.New("request is nil") | ||||
| 	} | ||||
| 	return ConvertRequest(*request), nil | ||||
| 	switch relayMode { | ||||
| 	case relaymode.Completions: | ||||
| 		return ConvertCompletionsRequest(*request), nil | ||||
| 	case relaymode.ChatCompletions, relaymode.Embeddings: | ||||
| 		return request, nil | ||||
| 	default: | ||||
| 		return nil, errors.New("not implemented") | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) { | ||||
|   | ||||
| @@ -3,11 +3,13 @@ package cloudflare | ||||
| import ( | ||||
| 	"bufio" | ||||
| 	"encoding/json" | ||||
| 	"github.com/songquanpeng/one-api/common/render" | ||||
| 	"io" | ||||
| 	"net/http" | ||||
| 	"strings" | ||||
|  | ||||
| 	"github.com/songquanpeng/one-api/common/ctxkey" | ||||
| 	"github.com/songquanpeng/one-api/common/render" | ||||
|  | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"github.com/songquanpeng/one-api/common" | ||||
| 	"github.com/songquanpeng/one-api/common/helper" | ||||
| @@ -16,57 +18,23 @@ import ( | ||||
| 	"github.com/songquanpeng/one-api/relay/model" | ||||
| ) | ||||
|  | ||||
| func ConvertRequest(textRequest model.GeneralOpenAIRequest) *Request { | ||||
| 	var promptBuilder strings.Builder | ||||
| 	for _, message := range textRequest.Messages { | ||||
| 		promptBuilder.WriteString(message.StringContent()) | ||||
| 		promptBuilder.WriteString("\n") // 添加换行符来分隔每个消息 | ||||
| 	} | ||||
|  | ||||
| func ConvertCompletionsRequest(textRequest model.GeneralOpenAIRequest) *Request { | ||||
| 	p, _ := textRequest.Prompt.(string) | ||||
| 	return &Request{ | ||||
| 		Prompt:      p, | ||||
| 		MaxTokens:   textRequest.MaxTokens, | ||||
| 		Prompt:      promptBuilder.String(), | ||||
| 		Stream:      textRequest.Stream, | ||||
| 		Temperature: textRequest.Temperature, | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func ResponseCloudflare2OpenAI(cloudflareResponse *Response) *openai.TextResponse { | ||||
| 	choice := openai.TextResponseChoice{ | ||||
| 		Index: 0, | ||||
| 		Message: model.Message{ | ||||
| 			Role:    "assistant", | ||||
| 			Content: cloudflareResponse.Result.Response, | ||||
| 		}, | ||||
| 		FinishReason: "stop", | ||||
| 	} | ||||
| 	fullTextResponse := openai.TextResponse{ | ||||
| 		Object:  "chat.completion", | ||||
| 		Created: helper.GetTimestamp(), | ||||
| 		Choices: []openai.TextResponseChoice{choice}, | ||||
| 	} | ||||
| 	return &fullTextResponse | ||||
| } | ||||
|  | ||||
| func StreamResponseCloudflare2OpenAI(cloudflareResponse *StreamResponse) *openai.ChatCompletionsStreamResponse { | ||||
| 	var choice openai.ChatCompletionsStreamResponseChoice | ||||
| 	choice.Delta.Content = cloudflareResponse.Response | ||||
| 	choice.Delta.Role = "assistant" | ||||
| 	openaiResponse := openai.ChatCompletionsStreamResponse{ | ||||
| 		Object:  "chat.completion.chunk", | ||||
| 		Choices: []openai.ChatCompletionsStreamResponseChoice{choice}, | ||||
| 		Created: helper.GetTimestamp(), | ||||
| 	} | ||||
| 	return &openaiResponse | ||||
| } | ||||
|  | ||||
| func StreamHandler(c *gin.Context, resp *http.Response, promptTokens int, modelName string) (*model.ErrorWithStatusCode, *model.Usage) { | ||||
| 	scanner := bufio.NewScanner(resp.Body) | ||||
| 	scanner.Split(bufio.ScanLines) | ||||
|  | ||||
| 	common.SetEventStreamHeaders(c) | ||||
| 	id := helper.GetResponseID(c) | ||||
| 	responseModel := c.GetString("original_model") | ||||
| 	responseModel := c.GetString(ctxkey.OriginalModel) | ||||
| 	var responseText string | ||||
|  | ||||
| 	for scanner.Scan() { | ||||
| @@ -77,22 +45,22 @@ func StreamHandler(c *gin.Context, resp *http.Response, promptTokens int, modelN | ||||
| 		data = strings.TrimPrefix(data, "data: ") | ||||
| 		data = strings.TrimSuffix(data, "\r") | ||||
|  | ||||
| 		var cloudflareResponse StreamResponse | ||||
| 		err := json.Unmarshal([]byte(data), &cloudflareResponse) | ||||
| 		if data == "[DONE]" { | ||||
| 			break | ||||
| 		} | ||||
|  | ||||
| 		var response openai.ChatCompletionsStreamResponse | ||||
| 		err := json.Unmarshal([]byte(data), &response) | ||||
| 		if err != nil { | ||||
| 			logger.SysError("error unmarshalling stream response: " + err.Error()) | ||||
| 			continue | ||||
| 		} | ||||
|  | ||||
| 		response := StreamResponseCloudflare2OpenAI(&cloudflareResponse) | ||||
| 		if response == nil { | ||||
| 			continue | ||||
| 		for _, v := range response.Choices { | ||||
| 			v.Delta.Role = "assistant" | ||||
| 			responseText += v.Delta.StringContent() | ||||
| 		} | ||||
|  | ||||
| 		responseText += cloudflareResponse.Response | ||||
| 		response.Id = id | ||||
| 		response.Model = responseModel | ||||
|  | ||||
| 		response.Model = modelName | ||||
| 		err = render.ObjectData(c, response) | ||||
| 		if err != nil { | ||||
| 			logger.SysError(err.Error()) | ||||
| @@ -123,22 +91,25 @@ func Handler(c *gin.Context, resp *http.Response, promptTokens int, modelName st | ||||
| 	if err != nil { | ||||
| 		return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil | ||||
| 	} | ||||
| 	var cloudflareResponse Response | ||||
| 	err = json.Unmarshal(responseBody, &cloudflareResponse) | ||||
| 	var response openai.TextResponse | ||||
| 	err = json.Unmarshal(responseBody, &response) | ||||
| 	if err != nil { | ||||
| 		return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil | ||||
| 	} | ||||
| 	fullTextResponse := ResponseCloudflare2OpenAI(&cloudflareResponse) | ||||
| 	fullTextResponse.Model = modelName | ||||
| 	usage := openai.ResponseText2Usage(cloudflareResponse.Result.Response, modelName, promptTokens) | ||||
| 	fullTextResponse.Usage = *usage | ||||
| 	fullTextResponse.Id = helper.GetResponseID(c) | ||||
| 	jsonResponse, err := json.Marshal(fullTextResponse) | ||||
| 	response.Model = modelName | ||||
| 	var responseText string | ||||
| 	for _, v := range response.Choices { | ||||
| 		responseText += v.Message.Content.(string) | ||||
| 	} | ||||
| 	usage := openai.ResponseText2Usage(responseText, modelName, promptTokens) | ||||
| 	response.Usage = *usage | ||||
| 	response.Id = helper.GetResponseID(c) | ||||
| 	jsonResponse, err := json.Marshal(response) | ||||
| 	if err != nil { | ||||
| 		return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil | ||||
| 	} | ||||
| 	c.Writer.Header().Set("Content-Type", "application/json") | ||||
| 	c.Writer.WriteHeader(resp.StatusCode) | ||||
| 	_, err = c.Writer.Write(jsonResponse) | ||||
| 	_, _ = c.Writer.Write(jsonResponse) | ||||
| 	return nil, usage | ||||
| } | ||||
|   | ||||
| @@ -1,25 +1,13 @@ | ||||
| package cloudflare | ||||
|  | ||||
| import "github.com/songquanpeng/one-api/relay/model" | ||||
|  | ||||
| type Request struct { | ||||
| 	Lora        string  `json:"lora,omitempty"` | ||||
| 	MaxTokens   int     `json:"max_tokens,omitempty"` | ||||
| 	Prompt      string  `json:"prompt,omitempty"` | ||||
| 	Raw         bool    `json:"raw,omitempty"` | ||||
| 	Stream      bool    `json:"stream,omitempty"` | ||||
| 	Temperature float64 `json:"temperature,omitempty"` | ||||
| } | ||||
|  | ||||
| type Result struct { | ||||
| 	Response string `json:"response"` | ||||
| } | ||||
|  | ||||
| type Response struct { | ||||
| 	Result   Result   `json:"result"` | ||||
| 	Success  bool     `json:"success"` | ||||
| 	Errors   []string `json:"errors"` | ||||
| 	Messages []string `json:"messages"` | ||||
| } | ||||
|  | ||||
| type StreamResponse struct { | ||||
| 	Response string `json:"response"` | ||||
| 	Messages    []model.Message `json:"messages,omitempty"` | ||||
| 	Lora        string          `json:"lora,omitempty"` | ||||
| 	MaxTokens   int             `json:"max_tokens,omitempty"` | ||||
| 	Prompt      string          `json:"prompt,omitempty"` | ||||
| 	Raw         bool            `json:"raw,omitempty"` | ||||
| 	Stream      bool            `json:"stream,omitempty"` | ||||
| 	Temperature float64         `json:"temperature,omitempty"` | ||||
| } | ||||
|   | ||||
							
								
								
									
										19
									
								
								relay/adaptor/novita/constants.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										19
									
								
								relay/adaptor/novita/constants.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,19 @@ | ||||
| package novita | ||||
|  | ||||
| // https://novita.ai/llm-api | ||||
|  | ||||
| var ModelList = []string{ | ||||
| 	"meta-llama/llama-3-8b-instruct", | ||||
| 	"meta-llama/llama-3-70b-instruct", | ||||
| 	"nousresearch/hermes-2-pro-llama-3-8b", | ||||
| 	"nousresearch/nous-hermes-llama2-13b", | ||||
| 	"mistralai/mistral-7b-instruct", | ||||
| 	"cognitivecomputations/dolphin-mixtral-8x22b", | ||||
| 	"sao10k/l3-70b-euryale-v2.1", | ||||
| 	"sophosympatheia/midnight-rose-70b", | ||||
| 	"gryphe/mythomax-l2-13b", | ||||
| 	"Nous-Hermes-2-Mixtral-8x7B-DPO", | ||||
| 	"lzlv_70b", | ||||
| 	"teknium/openhermes-2.5-mistral-7b", | ||||
| 	"microsoft/wizardlm-2-8x22b", | ||||
| } | ||||
							
								
								
									
										15
									
								
								relay/adaptor/novita/main.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										15
									
								
								relay/adaptor/novita/main.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,15 @@ | ||||
| package novita | ||||
|  | ||||
| import ( | ||||
| 	"fmt" | ||||
|  | ||||
| 	"github.com/songquanpeng/one-api/relay/meta" | ||||
| 	"github.com/songquanpeng/one-api/relay/relaymode" | ||||
| ) | ||||
|  | ||||
| func GetRequestURL(meta *meta.Meta) (string, error) { | ||||
| 	if meta.Mode == relaymode.ChatCompletions { | ||||
| 		return fmt.Sprintf("%s/chat/completions", meta.BaseURL), nil | ||||
| 	} | ||||
| 	return "", fmt.Errorf("unsupported relay mode %d for novita", meta.Mode) | ||||
| } | ||||
| @@ -3,17 +3,19 @@ package openai | ||||
| import ( | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
| 	"io" | ||||
| 	"net/http" | ||||
| 	"strings" | ||||
|  | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"github.com/songquanpeng/one-api/relay/adaptor" | ||||
| 	"github.com/songquanpeng/one-api/relay/adaptor/doubao" | ||||
| 	"github.com/songquanpeng/one-api/relay/adaptor/minimax" | ||||
| 	"github.com/songquanpeng/one-api/relay/adaptor/novita" | ||||
| 	"github.com/songquanpeng/one-api/relay/channeltype" | ||||
| 	"github.com/songquanpeng/one-api/relay/meta" | ||||
| 	"github.com/songquanpeng/one-api/relay/model" | ||||
| 	"github.com/songquanpeng/one-api/relay/relaymode" | ||||
| 	"io" | ||||
| 	"net/http" | ||||
| 	"strings" | ||||
| ) | ||||
|  | ||||
| type Adaptor struct { | ||||
| @@ -48,6 +50,8 @@ func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) { | ||||
| 		return minimax.GetRequestURL(meta) | ||||
| 	case channeltype.Doubao: | ||||
| 		return doubao.GetRequestURL(meta) | ||||
| 	case channeltype.Novita: | ||||
| 		return novita.GetRequestURL(meta) | ||||
| 	default: | ||||
| 		return GetFullRequestURL(meta.BaseURL, meta.RequestURLPath, meta.ChannelType), nil | ||||
| 	} | ||||
|   | ||||
| @@ -10,6 +10,7 @@ import ( | ||||
| 	"github.com/songquanpeng/one-api/relay/adaptor/minimax" | ||||
| 	"github.com/songquanpeng/one-api/relay/adaptor/mistral" | ||||
| 	"github.com/songquanpeng/one-api/relay/adaptor/moonshot" | ||||
| 	"github.com/songquanpeng/one-api/relay/adaptor/novita" | ||||
| 	"github.com/songquanpeng/one-api/relay/adaptor/stepfun" | ||||
| 	"github.com/songquanpeng/one-api/relay/adaptor/togetherai" | ||||
| 	"github.com/songquanpeng/one-api/relay/channeltype" | ||||
| @@ -28,6 +29,7 @@ var CompatibleChannels = []int{ | ||||
| 	channeltype.StepFun, | ||||
| 	channeltype.DeepSeek, | ||||
| 	channeltype.TogetherAI, | ||||
| 	channeltype.Novita, | ||||
| } | ||||
|  | ||||
| func GetCompatibleChannelMeta(channelType int) (string, []string) { | ||||
| @@ -56,6 +58,8 @@ func GetCompatibleChannelMeta(channelType int) (string, []string) { | ||||
| 		return "together.ai", togetherai.ModelList | ||||
| 	case channeltype.Doubao: | ||||
| 		return "doubao", doubao.ModelList | ||||
| 	case channeltype.Novita: | ||||
| 		return "novita", novita.ModelList | ||||
| 	default: | ||||
| 		return "openai", ModelList | ||||
| 	} | ||||
|   | ||||
| @@ -4,11 +4,12 @@ import ( | ||||
| 	"bufio" | ||||
| 	"bytes" | ||||
| 	"encoding/json" | ||||
| 	"github.com/songquanpeng/one-api/common/render" | ||||
| 	"io" | ||||
| 	"net/http" | ||||
| 	"strings" | ||||
|  | ||||
| 	"github.com/songquanpeng/one-api/common/render" | ||||
|  | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"github.com/songquanpeng/one-api/common" | ||||
| 	"github.com/songquanpeng/one-api/common/conv" | ||||
| @@ -31,6 +32,7 @@ func StreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*model.E | ||||
|  | ||||
| 	common.SetEventStreamHeaders(c) | ||||
|  | ||||
| 	doneRendered := false | ||||
| 	for scanner.Scan() { | ||||
| 		data := scanner.Text() | ||||
| 		if len(data) < dataPrefixLength { // ignore blank line or wrong format | ||||
| @@ -41,6 +43,7 @@ func StreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*model.E | ||||
| 		} | ||||
| 		if strings.HasPrefix(data[dataPrefixLength:], done) { | ||||
| 			render.StringData(c, data) | ||||
| 			doneRendered = true | ||||
| 			continue | ||||
| 		} | ||||
| 		switch relayMode { | ||||
| @@ -81,7 +84,9 @@ func StreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*model.E | ||||
| 		logger.SysError("error reading stream: " + err.Error()) | ||||
| 	} | ||||
|  | ||||
| 	render.Done(c) | ||||
| 	if !doneRendered { | ||||
| 		render.Done(c) | ||||
| 	} | ||||
|  | ||||
| 	err := resp.Body.Close() | ||||
| 	if err != nil { | ||||
|   | ||||
| @@ -6,4 +6,5 @@ var ModelList = []string{ | ||||
| 	"SparkDesk-v2.1", | ||||
| 	"SparkDesk-v3.1", | ||||
| 	"SparkDesk-v3.5", | ||||
| 	"SparkDesk-v4.0", | ||||
| } | ||||
|   | ||||
| @@ -44,7 +44,7 @@ func requestOpenAI2Xunfei(request model.GeneralOpenAIRequest, xunfeiAppId string | ||||
| 	xunfeiRequest.Parameter.Chat.MaxTokens = request.MaxTokens | ||||
| 	xunfeiRequest.Payload.Message.Text = messages | ||||
|  | ||||
| 	if strings.HasPrefix(domain, "generalv3") { | ||||
| 	if strings.HasPrefix(domain, "generalv3") || domain == "4.0Ultra" { | ||||
| 		functions := make([]model.Function, len(request.Tools)) | ||||
| 		for i, tool := range request.Tools { | ||||
| 			functions[i] = tool.Function | ||||
| @@ -290,6 +290,8 @@ func apiVersion2domain(apiVersion string) string { | ||||
| 		return "generalv3" | ||||
| 	case "v3.5": | ||||
| 		return "generalv3.5" | ||||
| 	case "v4.0": | ||||
| 		return "4.0Ultra" | ||||
| 	} | ||||
| 	return "general" + apiVersion | ||||
| } | ||||
|   | ||||
| @@ -2,6 +2,7 @@ package ratio | ||||
|  | ||||
| import ( | ||||
| 	"encoding/json" | ||||
| 	"fmt" | ||||
| 	"strings" | ||||
|  | ||||
| 	"github.com/songquanpeng/one-api/common/logger" | ||||
| @@ -125,6 +126,7 @@ var ModelRatio = map[string]float64{ | ||||
| 	"SparkDesk-v2.1":            1.2858, // ¥0.018 / 1k tokens | ||||
| 	"SparkDesk-v3.1":            1.2858, // ¥0.018 / 1k tokens | ||||
| 	"SparkDesk-v3.5":            1.2858, // ¥0.018 / 1k tokens | ||||
| 	"SparkDesk-v4.0":            1.2858, // ¥0.018 / 1k tokens | ||||
| 	"360GPT_S2_V9":              0.8572, // ¥0.012 / 1k tokens | ||||
| 	"embedding-bert-512-v1":     0.0715, // ¥0.001 / 1k tokens | ||||
| 	"embedding_s1_v1":           0.0715, // ¥0.001 / 1k tokens | ||||
| @@ -168,6 +170,9 @@ var ModelRatio = map[string]float64{ | ||||
| 	"step-1v-32k": 0.024 * RMB, | ||||
| 	"step-1-32k":  0.024 * RMB, | ||||
| 	"step-1-200k": 0.15 * RMB, | ||||
| 	// aws llama3 https://aws.amazon.com/cn/bedrock/pricing/ | ||||
| 	"llama3-8b-8192(33)":  0.0003 / 0.002,  // $0.0003 / 1K tokens | ||||
| 	"llama3-70b-8192(33)": 0.00265 / 0.002, // $0.00265 / 1K tokens | ||||
| 	// https://cohere.com/pricing | ||||
| 	"command":               0.5, | ||||
| 	"command-nightly":       0.5, | ||||
| @@ -184,7 +189,11 @@ var ModelRatio = map[string]float64{ | ||||
| 	"deepl-ja": 25.0 / 1000 * USD, | ||||
| } | ||||
|  | ||||
| var CompletionRatio = map[string]float64{} | ||||
| var CompletionRatio = map[string]float64{ | ||||
| 	// aws llama3 | ||||
| 	"llama3-8b-8192(33)":  0.0006 / 0.0003, | ||||
| 	"llama3-70b-8192(33)": 0.0035 / 0.00265, | ||||
| } | ||||
|  | ||||
| var DefaultModelRatio map[string]float64 | ||||
| var DefaultCompletionRatio map[string]float64 | ||||
| @@ -233,22 +242,28 @@ func UpdateModelRatioByJSONString(jsonStr string) error { | ||||
| 	return json.Unmarshal([]byte(jsonStr), &ModelRatio) | ||||
| } | ||||
|  | ||||
| func GetModelRatio(name string) float64 { | ||||
| func GetModelRatio(name string, channelType int) float64 { | ||||
| 	if strings.HasPrefix(name, "qwen-") && strings.HasSuffix(name, "-internet") { | ||||
| 		name = strings.TrimSuffix(name, "-internet") | ||||
| 	} | ||||
| 	if strings.HasPrefix(name, "command-") && strings.HasSuffix(name, "-internet") { | ||||
| 		name = strings.TrimSuffix(name, "-internet") | ||||
| 	} | ||||
| 	ratio, ok := ModelRatio[name] | ||||
| 	if !ok { | ||||
| 		ratio, ok = DefaultModelRatio[name] | ||||
| 	model := fmt.Sprintf("%s(%d)", name, channelType) | ||||
| 	if ratio, ok := ModelRatio[model]; ok { | ||||
| 		return ratio | ||||
| 	} | ||||
| 	if !ok { | ||||
| 		logger.SysError("model ratio not found: " + name) | ||||
| 		return 30 | ||||
| 	if ratio, ok := DefaultModelRatio[model]; ok { | ||||
| 		return ratio | ||||
| 	} | ||||
| 	return ratio | ||||
| 	if ratio, ok := ModelRatio[name]; ok { | ||||
| 		return ratio | ||||
| 	} | ||||
| 	if ratio, ok := DefaultModelRatio[name]; ok { | ||||
| 		return ratio | ||||
| 	} | ||||
| 	logger.SysError("model ratio not found: " + name) | ||||
| 	return 30 | ||||
| } | ||||
|  | ||||
| func CompletionRatio2JSONString() string { | ||||
| @@ -264,7 +279,17 @@ func UpdateCompletionRatioByJSONString(jsonStr string) error { | ||||
| 	return json.Unmarshal([]byte(jsonStr), &CompletionRatio) | ||||
| } | ||||
|  | ||||
| func GetCompletionRatio(name string) float64 { | ||||
| func GetCompletionRatio(name string, channelType int) float64 { | ||||
| 	if strings.HasPrefix(name, "qwen-") && strings.HasSuffix(name, "-internet") { | ||||
| 		name = strings.TrimSuffix(name, "-internet") | ||||
| 	} | ||||
| 	model := fmt.Sprintf("%s(%d)", name, channelType) | ||||
| 	if ratio, ok := CompletionRatio[model]; ok { | ||||
| 		return ratio | ||||
| 	} | ||||
| 	if ratio, ok := DefaultCompletionRatio[model]; ok { | ||||
| 		return ratio | ||||
| 	} | ||||
| 	if ratio, ok := CompletionRatio[name]; ok { | ||||
| 		return ratio | ||||
| 	} | ||||
|   | ||||
| @@ -42,5 +42,6 @@ const ( | ||||
| 	DeepL | ||||
| 	TogetherAI | ||||
| 	Doubao | ||||
| 	Novita | ||||
| 	Dummy | ||||
| ) | ||||
|   | ||||
| @@ -42,6 +42,7 @@ var ChannelBaseURLs = []string{ | ||||
| 	"https://api-free.deepl.com",                // 38 | ||||
| 	"https://api.together.xyz",                  // 39 | ||||
| 	"https://ark.cn-beijing.volces.com",         // 40 | ||||
| 	"https://api.novita.ai/v3/openai",           // 41 | ||||
| } | ||||
|  | ||||
| func init() { | ||||
|   | ||||
| @@ -7,6 +7,10 @@ import ( | ||||
| 	"encoding/json" | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
| 	"io" | ||||
| 	"net/http" | ||||
| 	"strings" | ||||
|  | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"github.com/songquanpeng/one-api/common" | ||||
| 	"github.com/songquanpeng/one-api/common/client" | ||||
| @@ -21,9 +25,6 @@ import ( | ||||
| 	"github.com/songquanpeng/one-api/relay/meta" | ||||
| 	relaymodel "github.com/songquanpeng/one-api/relay/model" | ||||
| 	"github.com/songquanpeng/one-api/relay/relaymode" | ||||
| 	"io" | ||||
| 	"net/http" | ||||
| 	"strings" | ||||
| ) | ||||
|  | ||||
| func RelayAudioHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatusCode { | ||||
| @@ -53,7 +54,7 @@ func RelayAudioHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	modelRatio := billingratio.GetModelRatio(audioModel) | ||||
| 	modelRatio := billingratio.GetModelRatio(audioModel, channelType) | ||||
| 	groupRatio := billingratio.GetGroupRatio(group) | ||||
| 	ratio := modelRatio * groupRatio | ||||
| 	var quota int64 | ||||
|   | ||||
| @@ -4,6 +4,10 @@ import ( | ||||
| 	"context" | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
| 	"math" | ||||
| 	"net/http" | ||||
| 	"strings" | ||||
|  | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"github.com/songquanpeng/one-api/common" | ||||
| 	"github.com/songquanpeng/one-api/common/config" | ||||
| @@ -16,9 +20,6 @@ import ( | ||||
| 	"github.com/songquanpeng/one-api/relay/meta" | ||||
| 	relaymodel "github.com/songquanpeng/one-api/relay/model" | ||||
| 	"github.com/songquanpeng/one-api/relay/relaymode" | ||||
| 	"math" | ||||
| 	"net/http" | ||||
| 	"strings" | ||||
| ) | ||||
|  | ||||
| func getAndValidateTextRequest(c *gin.Context, relayMode int) (*relaymodel.GeneralOpenAIRequest, error) { | ||||
| @@ -40,78 +41,6 @@ func getAndValidateTextRequest(c *gin.Context, relayMode int) (*relaymodel.Gener | ||||
| 	return textRequest, nil | ||||
| } | ||||
|  | ||||
| func getImageRequest(c *gin.Context, relayMode int) (*relaymodel.ImageRequest, error) { | ||||
| 	imageRequest := &relaymodel.ImageRequest{} | ||||
| 	err := common.UnmarshalBodyReusable(c, imageRequest) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	if imageRequest.N == 0 { | ||||
| 		imageRequest.N = 1 | ||||
| 	} | ||||
| 	if imageRequest.Size == "" { | ||||
| 		imageRequest.Size = "1024x1024" | ||||
| 	} | ||||
| 	if imageRequest.Model == "" { | ||||
| 		imageRequest.Model = "dall-e-2" | ||||
| 	} | ||||
| 	return imageRequest, nil | ||||
| } | ||||
|  | ||||
| func isValidImageSize(model string, size string) bool { | ||||
| 	if model == "cogview-3" { | ||||
| 		return true | ||||
| 	} | ||||
| 	_, ok := billingratio.ImageSizeRatios[model][size] | ||||
| 	return ok | ||||
| } | ||||
|  | ||||
| func getImageSizeRatio(model string, size string) float64 { | ||||
| 	ratio, ok := billingratio.ImageSizeRatios[model][size] | ||||
| 	if !ok { | ||||
| 		return 1 | ||||
| 	} | ||||
| 	return ratio | ||||
| } | ||||
|  | ||||
| func validateImageRequest(imageRequest *relaymodel.ImageRequest, meta *meta.Meta) *relaymodel.ErrorWithStatusCode { | ||||
| 	// model validation | ||||
| 	hasValidSize := isValidImageSize(imageRequest.Model, imageRequest.Size) | ||||
| 	if !hasValidSize { | ||||
| 		return openai.ErrorWrapper(errors.New("size not supported for this image model"), "size_not_supported", http.StatusBadRequest) | ||||
| 	} | ||||
| 	// check prompt length | ||||
| 	if imageRequest.Prompt == "" { | ||||
| 		return openai.ErrorWrapper(errors.New("prompt is required"), "prompt_missing", http.StatusBadRequest) | ||||
| 	} | ||||
| 	if len(imageRequest.Prompt) > billingratio.ImagePromptLengthLimitations[imageRequest.Model] { | ||||
| 		return openai.ErrorWrapper(errors.New("prompt is too long"), "prompt_too_long", http.StatusBadRequest) | ||||
| 	} | ||||
| 	// Number of generated images validation | ||||
| 	if !isWithinRange(imageRequest.Model, imageRequest.N) { | ||||
| 		// channel not azure | ||||
| 		if meta.ChannelType != channeltype.Azure { | ||||
| 			return openai.ErrorWrapper(errors.New("invalid value of n"), "n_not_within_range", http.StatusBadRequest) | ||||
| 		} | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func getImageCostRatio(imageRequest *relaymodel.ImageRequest) (float64, error) { | ||||
| 	if imageRequest == nil { | ||||
| 		return 0, errors.New("imageRequest is nil") | ||||
| 	} | ||||
| 	imageCostRatio := getImageSizeRatio(imageRequest.Model, imageRequest.Size) | ||||
| 	if imageRequest.Quality == "hd" && imageRequest.Model == "dall-e-3" { | ||||
| 		if imageRequest.Size == "1024x1024" { | ||||
| 			imageCostRatio *= 2 | ||||
| 		} else { | ||||
| 			imageCostRatio *= 1.5 | ||||
| 		} | ||||
| 	} | ||||
| 	return imageCostRatio, nil | ||||
| } | ||||
|  | ||||
| func getPromptTokens(textRequest *relaymodel.GeneralOpenAIRequest, relayMode int) int { | ||||
| 	switch relayMode { | ||||
| 	case relaymode.ChatCompletions: | ||||
| @@ -167,7 +96,7 @@ func postConsumeQuota(ctx context.Context, usage *relaymodel.Usage, meta *meta.M | ||||
| 		return | ||||
| 	} | ||||
| 	var quota int64 | ||||
| 	completionRatio := billingratio.GetCompletionRatio(textRequest.Model) | ||||
| 	completionRatio := billingratio.GetCompletionRatio(textRequest.Model, meta.ChannelType) | ||||
| 	promptTokens := usage.PromptTokens | ||||
| 	completionTokens := usage.CompletionTokens | ||||
| 	quota = int64(math.Ceil((float64(promptTokens) + float64(completionTokens)*completionRatio) * ratio)) | ||||
|   | ||||
| @@ -6,7 +6,11 @@ import ( | ||||
| 	"encoding/json" | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
| 	"io" | ||||
| 	"net/http" | ||||
|  | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"github.com/songquanpeng/one-api/common" | ||||
| 	"github.com/songquanpeng/one-api/common/ctxkey" | ||||
| 	"github.com/songquanpeng/one-api/common/logger" | ||||
| 	"github.com/songquanpeng/one-api/model" | ||||
| @@ -16,17 +20,86 @@ import ( | ||||
| 	"github.com/songquanpeng/one-api/relay/channeltype" | ||||
| 	"github.com/songquanpeng/one-api/relay/meta" | ||||
| 	relaymodel "github.com/songquanpeng/one-api/relay/model" | ||||
| 	"io" | ||||
| 	"net/http" | ||||
| ) | ||||
|  | ||||
| func isWithinRange(element string, value int) bool { | ||||
| 	if _, ok := billingratio.ImageGenerationAmounts[element]; !ok { | ||||
| 		return false | ||||
| func getImageRequest(c *gin.Context, relayMode int) (*relaymodel.ImageRequest, error) { | ||||
| 	imageRequest := &relaymodel.ImageRequest{} | ||||
| 	err := common.UnmarshalBodyReusable(c, imageRequest) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	min := billingratio.ImageGenerationAmounts[element][0] | ||||
| 	max := billingratio.ImageGenerationAmounts[element][1] | ||||
| 	return value >= min && value <= max | ||||
| 	if imageRequest.N == 0 { | ||||
| 		imageRequest.N = 1 | ||||
| 	} | ||||
| 	if imageRequest.Size == "" { | ||||
| 		imageRequest.Size = "1024x1024" | ||||
| 	} | ||||
| 	if imageRequest.Model == "" { | ||||
| 		imageRequest.Model = "dall-e-2" | ||||
| 	} | ||||
| 	return imageRequest, nil | ||||
| } | ||||
|  | ||||
| func isValidImageSize(model string, size string) bool { | ||||
| 	if model == "cogview-3" || billingratio.ImageSizeRatios[model] == nil { | ||||
| 		return true | ||||
| 	} | ||||
| 	_, ok := billingratio.ImageSizeRatios[model][size] | ||||
| 	return ok | ||||
| } | ||||
|  | ||||
| func isValidImagePromptLength(model string, promptLength int) bool { | ||||
| 	maxPromptLength, ok := billingratio.ImagePromptLengthLimitations[model] | ||||
| 	return !ok || promptLength <= maxPromptLength | ||||
| } | ||||
|  | ||||
| func isWithinRange(element string, value int) bool { | ||||
| 	amounts, ok := billingratio.ImageGenerationAmounts[element] | ||||
| 	return !ok || (value >= amounts[0] && value <= amounts[1]) | ||||
| } | ||||
|  | ||||
| func getImageSizeRatio(model string, size string) float64 { | ||||
| 	if ratio, ok := billingratio.ImageSizeRatios[model][size]; ok { | ||||
| 		return ratio | ||||
| 	} | ||||
| 	return 1 | ||||
| } | ||||
|  | ||||
| func validateImageRequest(imageRequest *relaymodel.ImageRequest, meta *meta.Meta) *relaymodel.ErrorWithStatusCode { | ||||
| 	// check prompt length | ||||
| 	if imageRequest.Prompt == "" { | ||||
| 		return openai.ErrorWrapper(errors.New("prompt is required"), "prompt_missing", http.StatusBadRequest) | ||||
| 	} | ||||
|  | ||||
| 	// model validation | ||||
| 	if !isValidImageSize(imageRequest.Model, imageRequest.Size) { | ||||
| 		return openai.ErrorWrapper(errors.New("size not supported for this image model"), "size_not_supported", http.StatusBadRequest) | ||||
| 	} | ||||
|  | ||||
| 	if !isValidImagePromptLength(imageRequest.Model, len(imageRequest.Prompt)) { | ||||
| 		return openai.ErrorWrapper(errors.New("prompt is too long"), "prompt_too_long", http.StatusBadRequest) | ||||
| 	} | ||||
|  | ||||
| 	// Number of generated images validation | ||||
| 	if !isWithinRange(imageRequest.Model, imageRequest.N) { | ||||
| 		return openai.ErrorWrapper(errors.New("invalid value of n"), "n_not_within_range", http.StatusBadRequest) | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func getImageCostRatio(imageRequest *relaymodel.ImageRequest) (float64, error) { | ||||
| 	if imageRequest == nil { | ||||
| 		return 0, errors.New("imageRequest is nil") | ||||
| 	} | ||||
| 	imageCostRatio := getImageSizeRatio(imageRequest.Model, imageRequest.Size) | ||||
| 	if imageRequest.Quality == "hd" && imageRequest.Model == "dall-e-3" { | ||||
| 		if imageRequest.Size == "1024x1024" { | ||||
| 			imageCostRatio *= 2 | ||||
| 		} else { | ||||
| 			imageCostRatio *= 1.5 | ||||
| 		} | ||||
| 	} | ||||
| 	return imageCostRatio, nil | ||||
| } | ||||
|  | ||||
| func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatusCode { | ||||
| @@ -94,7 +167,7 @@ func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus | ||||
| 		requestBody = bytes.NewBuffer(jsonStr) | ||||
| 	} | ||||
|  | ||||
| 	modelRatio := billingratio.GetModelRatio(imageModel) | ||||
| 	modelRatio := billingratio.GetModelRatio(imageModel, meta.ChannelType) | ||||
| 	groupRatio := billingratio.GetGroupRatio(meta.Group) | ||||
| 	ratio := modelRatio * groupRatio | ||||
| 	userQuota, err := model.CacheGetUserQuota(ctx, meta.UserId) | ||||
|   | ||||
| @@ -4,6 +4,9 @@ import ( | ||||
| 	"bytes" | ||||
| 	"encoding/json" | ||||
| 	"fmt" | ||||
| 	"io" | ||||
| 	"net/http" | ||||
|  | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"github.com/songquanpeng/one-api/common/logger" | ||||
| 	"github.com/songquanpeng/one-api/relay" | ||||
| @@ -14,8 +17,6 @@ import ( | ||||
| 	"github.com/songquanpeng/one-api/relay/channeltype" | ||||
| 	"github.com/songquanpeng/one-api/relay/meta" | ||||
| 	"github.com/songquanpeng/one-api/relay/model" | ||||
| 	"io" | ||||
| 	"net/http" | ||||
| ) | ||||
|  | ||||
| func RelayTextHelper(c *gin.Context) *model.ErrorWithStatusCode { | ||||
| @@ -35,7 +36,7 @@ func RelayTextHelper(c *gin.Context) *model.ErrorWithStatusCode { | ||||
| 	textRequest.Model, isModelMapped = getMappedModelName(textRequest.Model, meta.ModelMapping) | ||||
| 	meta.ActualModelName = textRequest.Model | ||||
| 	// get model ratio & group ratio | ||||
| 	modelRatio := billingratio.GetModelRatio(textRequest.Model) | ||||
| 	modelRatio := billingratio.GetModelRatio(textRequest.Model, meta.ChannelType) | ||||
| 	groupRatio := billingratio.GetGroupRatio(meta.Group) | ||||
| 	ratio := modelRatio * groupRatio | ||||
| 	// pre-consume quota | ||||
|   | ||||
| @@ -1,10 +1,11 @@ | ||||
| package model | ||||
|  | ||||
| type Message struct { | ||||
| 	Role      string  `json:"role,omitempty"` | ||||
| 	Content   any     `json:"content,omitempty"` | ||||
| 	Name      *string `json:"name,omitempty"` | ||||
| 	ToolCalls []Tool  `json:"tool_calls,omitempty"` | ||||
| 	Role       string  `json:"role,omitempty"` | ||||
| 	Content    any     `json:"content,omitempty"` | ||||
| 	Name       *string `json:"name,omitempty"` | ||||
| 	ToolCalls  []Tool  `json:"tool_calls,omitempty"` | ||||
| 	ToolCallId string  `json:"tool_call_id,omitempty"` | ||||
| } | ||||
|  | ||||
| func (m Message) IsStringContent() bool { | ||||
|   | ||||
| @@ -2,13 +2,13 @@ package model | ||||
|  | ||||
| type Tool struct { | ||||
| 	Id       string   `json:"id,omitempty"` | ||||
| 	Type     string   `json:"type"` | ||||
| 	Type     string   `json:"type,omitempty"` // when splicing claude tools stream messages, it is empty | ||||
| 	Function Function `json:"function"` | ||||
| } | ||||
|  | ||||
| type Function struct { | ||||
| 	Description string `json:"description,omitempty"` | ||||
| 	Name        string `json:"name"` | ||||
| 	Name        string `json:"name,omitempty"`       // when splicing claude tools stream messages, it is empty | ||||
| 	Parameters  any    `json:"parameters,omitempty"` // request | ||||
| 	Arguments   any    `json:"arguments,omitempty"`  // response | ||||
| } | ||||
|   | ||||
| @@ -47,7 +47,7 @@ const PersonalSetting = () => { | ||||
|   const [countdown, setCountdown] = useState(30); | ||||
|   const [affLink, setAffLink] = useState(''); | ||||
|   const [systemToken, setSystemToken] = useState(''); | ||||
|   // const [models, setModels] = useState([]); | ||||
|   const [models, setModels] = useState([]); | ||||
|   const [openTransfer, setOpenTransfer] = useState(false); | ||||
|   const [transferAmount, setTransferAmount] = useState(0); | ||||
|  | ||||
| @@ -72,7 +72,7 @@ const PersonalSetting = () => { | ||||
|         console.log(userState); | ||||
|       } | ||||
|     ); | ||||
|     // loadModels().then(); | ||||
|     loadModels().then(); | ||||
|     getAffLink().then(); | ||||
|     setTransferAmount(getQuotaPerUnit()); | ||||
|   }, []); | ||||
| @@ -127,16 +127,16 @@ const PersonalSetting = () => { | ||||
|     } | ||||
|   }; | ||||
|  | ||||
|   // const loadModels = async () => { | ||||
|   //   let res = await API.get(`/api/user/models`); | ||||
|   //   const { success, message, data } = res.data; | ||||
|   //   if (success) { | ||||
|   //     setModels(data); | ||||
|   //     console.log(data); | ||||
|   //   } else { | ||||
|   //     showError(message); | ||||
|   //   } | ||||
|   // }; | ||||
|   const loadModels = async () => { | ||||
|     let res = await API.get(`/api/user/available_models`); | ||||
|     const { success, message, data } = res.data; | ||||
|     if (success) { | ||||
|       setModels(data); | ||||
|       console.log(data); | ||||
|     } else { | ||||
|       showError(message); | ||||
|     } | ||||
|   }; | ||||
|  | ||||
|   const handleAffLinkClick = async (e) => { | ||||
|     e.target.select(); | ||||
| @@ -344,7 +344,7 @@ const PersonalSetting = () => { | ||||
|               } | ||||
|             > | ||||
|               <Typography.Title heading={6}>调用信息</Typography.Title> | ||||
|               {/* <Typography.Title heading={6}>可用模型</Typography.Title> | ||||
|               <p>可用模型(可点击复制)</p> | ||||
|               <div style={{ marginTop: 10 }}> | ||||
|                 <Space wrap> | ||||
|                   {models.map((model) => ( | ||||
| @@ -355,7 +355,7 @@ const PersonalSetting = () => { | ||||
|                     </Tag> | ||||
|                   ))} | ||||
|                 </Space> | ||||
|               </div> */} | ||||
|               </div> | ||||
|             </Card> | ||||
|             {/* <Card | ||||
|               footer={ | ||||
|   | ||||
| @@ -78,7 +78,7 @@ const EditChannel = (props) => { | ||||
|                     localModels = ['chatglm_pro', 'chatglm_std', 'chatglm_lite']; | ||||
|                     break; | ||||
|                 case 18: | ||||
|                     localModels = ['SparkDesk', 'SparkDesk-v1.1', 'SparkDesk-v2.1', 'SparkDesk-v3.1', 'SparkDesk-v3.5']; | ||||
|                     localModels = ['SparkDesk', 'SparkDesk-v1.1', 'SparkDesk-v2.1', 'SparkDesk-v3.1', 'SparkDesk-v3.5', 'SparkDesk-v4.0']; | ||||
|                     break; | ||||
|                 case 19: | ||||
|                     localModels = ['360GPT_S2_V9', 'embedding-bert-512-v1', 'embedding_s1_v1', 'semantic_similarity_s1_v1']; | ||||
|   | ||||
| @@ -13,7 +13,7 @@ export const CHANNEL_OPTIONS = { | ||||
|   }, | ||||
|   33: { | ||||
|     key: 33, | ||||
|     text: 'AWS Claude', | ||||
|     text: 'AWS', | ||||
|     value: 33, | ||||
|     color: 'primary' | ||||
|   }, | ||||
| @@ -161,6 +161,12 @@ export const CHANNEL_OPTIONS = { | ||||
|     value: 39, | ||||
|     color: 'primary' | ||||
|   }, | ||||
|   41: { | ||||
|     key: 41, | ||||
|     text: 'Novita', | ||||
|     value: 41, | ||||
|     color: 'purple' | ||||
|   }, | ||||
|   8: { | ||||
|     key: 8, | ||||
|     text: '自定义渠道', | ||||
|   | ||||
| @@ -91,7 +91,7 @@ const typeConfig = { | ||||
|       other: '版本号' | ||||
|     }, | ||||
|     input: { | ||||
|       models: ['SparkDesk', 'SparkDesk-v1.1', 'SparkDesk-v2.1', 'SparkDesk-v3.1', 'SparkDesk-v3.5'] | ||||
|       models: ['SparkDesk', 'SparkDesk-v1.1', 'SparkDesk-v2.1', 'SparkDesk-v3.1', 'SparkDesk-v3.5', 'SparkDesk-v4.0'] | ||||
|     }, | ||||
|     prompt: { | ||||
|       key: '按照如下格式输入:APPID|APISecret|APIKey', | ||||
|   | ||||
| @@ -1,5 +1,5 @@ | ||||
| import React, { useEffect, useState } from 'react'; | ||||
| import { Button, Form, Input, Label, Message, Pagination, Popup, Table } from 'semantic-ui-react'; | ||||
| import { Button, Dropdown, Form, Input, Label, Message, Pagination, Popup, Table } from 'semantic-ui-react'; | ||||
| import { Link } from 'react-router-dom'; | ||||
| import { | ||||
|   API, | ||||
| @@ -70,13 +70,33 @@ const ChannelsTable = () => { | ||||
|     const res = await API.get(`/api/channel/?p=${startIdx}`); | ||||
|     const { success, message, data } = res.data; | ||||
|     if (success) { | ||||
|       if (startIdx === 0) { | ||||
|         setChannels(data); | ||||
|       } else { | ||||
|         let newChannels = [...channels]; | ||||
|         newChannels.splice(startIdx * ITEMS_PER_PAGE, data.length, ...data); | ||||
|         setChannels(newChannels); | ||||
|       } | ||||
|         let localChannels = data.map((channel) => { | ||||
|             if (channel.models === '') { | ||||
|                 channel.models = []; | ||||
|                 channel.test_model = ""; | ||||
|             } else { | ||||
|                 channel.models = channel.models.split(','); | ||||
|                 if (channel.models.length > 0) { | ||||
|                     channel.test_model = channel.models[0]; | ||||
|                 } | ||||
|                 channel.model_options = channel.models.map((model) => { | ||||
|                     return { | ||||
|                         key: model, | ||||
|                         text: model, | ||||
|                         value: model, | ||||
|                     } | ||||
|                 }) | ||||
|                 console.log('channel', channel) | ||||
|             } | ||||
|             return channel; | ||||
|         }); | ||||
|         if (startIdx === 0) { | ||||
|             setChannels(localChannels); | ||||
|         } else { | ||||
|             let newChannels = [...channels]; | ||||
|             newChannels.splice(startIdx * ITEMS_PER_PAGE, data.length, ...localChannels); | ||||
|             setChannels(newChannels); | ||||
|         } | ||||
|     } else { | ||||
|       showError(message); | ||||
|     } | ||||
| @@ -225,19 +245,31 @@ const ChannelsTable = () => { | ||||
|     setSearching(false); | ||||
|   }; | ||||
|  | ||||
|   const testChannel = async (id, name, idx) => { | ||||
|     const res = await API.get(`/api/channel/test/${id}/`); | ||||
|     const { success, message, time } = res.data; | ||||
|   const switchTestModel = async (idx, model) => { | ||||
|     let newChannels = [...channels]; | ||||
|     let realIdx = (activePage - 1) * ITEMS_PER_PAGE + idx; | ||||
|     newChannels[realIdx].test_model = model; | ||||
|     setChannels(newChannels); | ||||
|   }; | ||||
|  | ||||
|   const testChannel = async (id, name, idx, m) => { | ||||
|     const res = await API.get(`/api/channel/test/${id}?model=${m}`); | ||||
|     const { success, message, time, model } = res.data; | ||||
|     if (success) { | ||||
|       let newChannels = [...channels]; | ||||
|       let realIdx = (activePage - 1) * ITEMS_PER_PAGE + idx; | ||||
|       newChannels[realIdx].response_time = time * 1000; | ||||
|       newChannels[realIdx].test_time = Date.now() / 1000; | ||||
|       setChannels(newChannels); | ||||
|       showInfo(`渠道 ${name} 测试成功,耗时 ${time.toFixed(2)} 秒。`); | ||||
|       showInfo(`渠道 ${name} 测试成功,模型 ${model},耗时 ${time.toFixed(2)} 秒。`); | ||||
|     } else { | ||||
|       showError(message); | ||||
|     } | ||||
|     let newChannels = [...channels]; | ||||
|     let realIdx = (activePage - 1) * ITEMS_PER_PAGE + idx; | ||||
|     newChannels[realIdx].response_time = time * 1000; | ||||
|     newChannels[realIdx].test_time = Date.now() / 1000; | ||||
|     setChannels(newChannels); | ||||
|   }; | ||||
|  | ||||
|   const testChannels = async (scope) => { | ||||
| @@ -405,6 +437,7 @@ const ChannelsTable = () => { | ||||
|             > | ||||
|               优先级 | ||||
|             </Table.HeaderCell> | ||||
|             <Table.HeaderCell>测试模型</Table.HeaderCell> | ||||
|             <Table.HeaderCell>操作</Table.HeaderCell> | ||||
|           </Table.Row> | ||||
|         </Table.Header> | ||||
| @@ -459,13 +492,24 @@ const ChannelsTable = () => { | ||||
|                       basic | ||||
|                     /> | ||||
|                   </Table.Cell> | ||||
|                   <Table.Cell> | ||||
|                     <Dropdown | ||||
|                       placeholder='请选择测试模型' | ||||
|                       selection | ||||
|                       options={channel.model_options} | ||||
|                       defaultValue={channel.test_model} | ||||
|                       onChange={(event, data) => { | ||||
|                         switchTestModel(idx, data.value); | ||||
|                       }} | ||||
|                     /> | ||||
|                   </Table.Cell> | ||||
|                   <Table.Cell> | ||||
|                     <div> | ||||
|                       <Button | ||||
|                         size={'small'} | ||||
|                         positive | ||||
|                         onClick={() => { | ||||
|                           testChannel(channel.id, channel.name, idx); | ||||
|                           testChannel(channel.id, channel.name, idx, channel.test_model); | ||||
|                         }} | ||||
|                       > | ||||
|                         测试 | ||||
|   | ||||
| @@ -1,11 +1,12 @@ | ||||
| export const CHANNEL_OPTIONS = [ | ||||
|     {key: 1, text: 'OpenAI', value: 1, color: 'green'}, | ||||
|     {key: 14, text: 'Anthropic Claude', value: 14, color: 'black'}, | ||||
|     {key: 33, text: 'AWS Claude', value: 33, color: 'black'}, | ||||
|     {key: 33, text: 'AWS', value: 33, color: 'black'}, | ||||
|     {key: 3, text: 'Azure OpenAI', value: 3, color: 'olive'}, | ||||
|     {key: 11, text: 'Google PaLM2', value: 11, color: 'orange'}, | ||||
|     {key: 24, text: 'Google Gemini', value: 24, color: 'orange'}, | ||||
|     {key: 28, text: 'Mistral AI', value: 28, color: 'orange'}, | ||||
|     {key: 41, text: 'Novita', value: 41, color: 'purple'}, | ||||
|     {key: 40, text: '字节跳动豆包', value: 40, color: 'blue'}, | ||||
|     {key: 15, text: '百度文心千帆', value: 15, color: 'blue'}, | ||||
|     {key: 17, text: '阿里通义千问', value: 17, color: 'orange'}, | ||||
|   | ||||
		Reference in New Issue
	
	Block a user