mirror of
				https://github.com/songquanpeng/one-api.git
				synced 2025-10-31 22:03:41 +08:00 
			
		
		
		
	Compare commits
	
		
			26 Commits
		
	
	
		
			v0.6.7-alp
			...
			v0.6.8-alp
		
	
	| Author | SHA1 | Date | |
|---|---|---|---|
|  | 0fc07ea558 | ||
|  | 1ce1e529ee | ||
|  | d936817de9 | ||
|  | fecaece71b | ||
|  | c135d74f13 | ||
|  | d0369b114f | ||
|  | b21b3b5b46 | ||
|  | ae1cd29f94 | ||
|  | f25aaf7752 | ||
|  | b70a07e814 | ||
|  | 34cb147a74 | ||
|  | 8cc1ee6360 | ||
|  | 5a58426859 | ||
|  | 254b9777c0 | ||
|  | 114c44c6e7 | ||
|  | a3c7e15aed | ||
|  | 3777517f64 | ||
|  | 9fc5f427dc | ||
|  | 864a467886 | ||
|  | ed78b5340b | ||
|  | fee69e7c20 | ||
|  | 9d23a44dbf | ||
|  | 6e4cfb20d5 | ||
|  | ff196b75a7 | ||
|  | 279caf82dc | ||
|  | b1520b308b | 
							
								
								
									
										3
									
								
								.env.example
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										3
									
								
								.env.example
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,3 @@ | ||||
| PORT=3000 | ||||
| DEBUG=false | ||||
| HTTPS_PROXY=http://localhost:7890 | ||||
							
								
								
									
										47
									
								
								.github/workflows/ci.yml
									
									
									
									
										vendored
									
									
										Normal file
									
								
							
							
						
						
									
										47
									
								
								.github/workflows/ci.yml
									
									
									
									
										vendored
									
									
										Normal file
									
								
							| @@ -0,0 +1,47 @@ | ||||
| name: CI | ||||
|  | ||||
| # This setup assumes that you run the unit tests with code coverage in the same | ||||
| # workflow that will also print the coverage report as comment to the pull request.  | ||||
| # Therefore, you need to trigger this workflow when a pull request is (re)opened or | ||||
| # when new code is pushed to the branch of the pull request. In addition, you also | ||||
| # need to trigger this workflow when new code is pushed to the main branch because  | ||||
| # we need to upload the code coverage results as artifact for the main branch as | ||||
| # well since it will be the baseline code coverage. | ||||
| #  | ||||
| # We do not want to trigger the workflow for pushes to *any* branch because this | ||||
| # would trigger our jobs twice on pull requests (once from "push" event and once | ||||
| # from "pull_request->synchronize") | ||||
| on: | ||||
|   pull_request: | ||||
|     types: [opened, reopened, synchronize] | ||||
|   push: | ||||
|     branches: | ||||
|       - 'main' | ||||
|  | ||||
| jobs: | ||||
|   unit_tests: | ||||
|     name: "Unit tests" | ||||
|     runs-on: ubuntu-latest | ||||
|     steps: | ||||
|       - name: Checkout repository | ||||
|         uses: actions/checkout@v4 | ||||
|  | ||||
|       - name: Setup Go | ||||
|         uses: actions/setup-go@v4 | ||||
|         with: | ||||
|           go-version: ^1.22 | ||||
|  | ||||
|       # When you execute your unit tests, make sure to use the "-coverprofile" flag to write a  | ||||
|       # coverage profile to a file. You will need the name of the file (e.g. "coverage.txt") | ||||
|       # in the next step as well as the next job. | ||||
|       - name: Test | ||||
|         run: go test -cover -coverprofile=coverage.txt ./... | ||||
|       - uses: codecov/codecov-action@v4 | ||||
|         with: | ||||
|           token: ${{ secrets.CODECOV_TOKEN }} | ||||
|  | ||||
|   commit_lint: | ||||
|     runs-on: ubuntu-latest | ||||
|     steps: | ||||
|       - uses: actions/checkout@v3 | ||||
|       - uses: wagoid/commitlint-github-action@v6 | ||||
							
								
								
									
										1
									
								
								.gitignore
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										1
									
								
								.gitignore
									
									
									
									
										vendored
									
									
								
							| @@ -9,3 +9,4 @@ logs | ||||
| data | ||||
| /web/node_modules | ||||
| cmd.md | ||||
| .env | ||||
| @@ -287,7 +287,9 @@ If the channel ID is not provided, load balancing will be used to distribute the | ||||
|     + Double-check that your interface address and API Key are correct. | ||||
|  | ||||
| ## Related Projects | ||||
| [FastGPT](https://github.com/labring/FastGPT): Knowledge question answering system based on the LLM | ||||
| * [FastGPT](https://github.com/labring/FastGPT): Knowledge question answering system based on the LLM | ||||
| * [VChart](https://github.com/VisActor/VChart):  More than just a cross-platform charting library, but also an expressive data storyteller. | ||||
| * [VMind](https://github.com/VisActor/VMind):  Not just automatic, but also fantastic. Open-source solution for intelligent visualization. | ||||
|  | ||||
| ## Note | ||||
| This project is an open-source project. Please use it in compliance with OpenAI's [Terms of Use](https://openai.com/policies/terms-of-use) and **applicable laws and regulations**. It must not be used for illegal purposes. | ||||
|   | ||||
| @@ -340,6 +340,7 @@ graph LR | ||||
| 不加的话将会使用负载均衡的方式使用多个渠道。 | ||||
|  | ||||
| ### 环境变量 | ||||
| > One API 支持从 `.env` 文件中读取环境变量,请参照 `.env.example` 文件,使用时请将其重命名为 `.env`。 | ||||
| 1. `REDIS_CONN_STRING`:设置之后将使用 Redis 作为缓存使用。 | ||||
|    + 例子:`REDIS_CONN_STRING=redis://default:redispw@localhost:49153` | ||||
|    + 如果数据库访问延迟很低,没有必要启用 Redis,启用后反而会出现数据滞后的问题。 | ||||
| @@ -447,6 +448,8 @@ https://openai.justsong.cn | ||||
| ## 相关项目 | ||||
| * [FastGPT](https://github.com/labring/FastGPT): 基于 LLM 大语言模型的知识库问答系统 | ||||
| * [ChatGPT Next Web](https://github.com/Yidadaa/ChatGPT-Next-Web):  一键拥有你自己的跨平台 ChatGPT 应用 | ||||
| * [VChart](https://github.com/VisActor/VChart):  不只是开箱即用的多端图表库,更是生动灵活的数据故事讲述者。 | ||||
| * [VMind](https://github.com/VisActor/VMind):  不仅自动,还很智能。开源智能可视化解决方案。 | ||||
|  | ||||
| ## 注意 | ||||
|  | ||||
|   | ||||
| @@ -19,4 +19,5 @@ const ( | ||||
| 	TokenName         = "token_name" | ||||
| 	BaseURL           = "base_url" | ||||
| 	AvailableModels   = "available_models" | ||||
| 	KeyRequestBody    = "key_request_body" | ||||
| ) | ||||
|   | ||||
| @@ -4,14 +4,13 @@ import ( | ||||
| 	"bytes" | ||||
| 	"encoding/json" | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"github.com/songquanpeng/one-api/common/ctxkey" | ||||
| 	"io" | ||||
| 	"strings" | ||||
| ) | ||||
|  | ||||
| const KeyRequestBody = "key_request_body" | ||||
|  | ||||
| func GetRequestBody(c *gin.Context) ([]byte, error) { | ||||
| 	requestBody, _ := c.Get(KeyRequestBody) | ||||
| 	requestBody, _ := c.Get(ctxkey.KeyRequestBody) | ||||
| 	if requestBody != nil { | ||||
| 		return requestBody.([]byte), nil | ||||
| 	} | ||||
| @@ -20,7 +19,7 @@ func GetRequestBody(c *gin.Context) ([]byte, error) { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	_ = c.Request.Body.Close() | ||||
| 	c.Set(KeyRequestBody, requestBody) | ||||
| 	c.Set(ctxkey.KeyRequestBody, requestBody) | ||||
| 	return requestBody.([]byte), nil | ||||
| } | ||||
|  | ||||
|   | ||||
| @@ -2,6 +2,7 @@ package image_test | ||||
|  | ||||
| import ( | ||||
| 	"encoding/base64" | ||||
| 	"github.com/songquanpeng/one-api/common/client" | ||||
| 	"image" | ||||
| 	_ "image/gif" | ||||
| 	_ "image/jpeg" | ||||
| @@ -44,6 +45,11 @@ var ( | ||||
| 	} | ||||
| ) | ||||
|  | ||||
| func TestMain(m *testing.M) { | ||||
| 	client.Init() | ||||
| 	m.Run() | ||||
| } | ||||
|  | ||||
| func TestDecode(t *testing.T) { | ||||
| 	// Bytes read: varies sometimes | ||||
| 	// jpeg: 1063892 | ||||
|   | ||||
| @@ -24,7 +24,7 @@ func printHelp() { | ||||
| 	fmt.Println("Usage: one-api [--port <port>] [--log-dir <log directory>] [--version] [--help]") | ||||
| } | ||||
|  | ||||
| func init() { | ||||
| func Init() { | ||||
| 	flag.Parse() | ||||
|  | ||||
| 	if *PrintVersion { | ||||
|   | ||||
							
								
								
									
										29
									
								
								common/render/render.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										29
									
								
								common/render/render.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,29 @@ | ||||
| package render | ||||
|  | ||||
| import ( | ||||
| 	"encoding/json" | ||||
| 	"fmt" | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"github.com/songquanpeng/one-api/common" | ||||
| 	"strings" | ||||
| ) | ||||
|  | ||||
| func StringData(c *gin.Context, str string) { | ||||
| 	str = strings.TrimPrefix(str, "data: ") | ||||
| 	str = strings.TrimSuffix(str, "\r") | ||||
| 	c.Render(-1, common.CustomEvent{Data: "data: " + str}) | ||||
| 	c.Writer.Flush() | ||||
| } | ||||
|  | ||||
| func ObjectData(c *gin.Context, object interface{}) error { | ||||
| 	jsonData, err := json.Marshal(object) | ||||
| 	if err != nil { | ||||
| 		return fmt.Errorf("error marshalling object: %w", err) | ||||
| 	} | ||||
| 	StringData(c, string(jsonData)) | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func Done(c *gin.Context) { | ||||
| 	StringData(c, "[DONE]") | ||||
| } | ||||
| @@ -48,7 +48,7 @@ func Relay(c *gin.Context) { | ||||
| 		logger.Debugf(ctx, "request body: %s", string(requestBody)) | ||||
| 	} | ||||
| 	channelId := c.GetInt(ctxkey.ChannelId) | ||||
| 	userId := c.GetInt("id") | ||||
| 	userId := c.GetInt(ctxkey.Id) | ||||
| 	bizErr := relayHelper(c, relayMode) | ||||
| 	if bizErr == nil { | ||||
| 		monitor.Emit(channelId, true) | ||||
|   | ||||
							
								
								
									
										5
									
								
								go.mod
									
									
									
									
									
								
							
							
						
						
									
										5
									
								
								go.mod
									
									
									
									
									
								
							| @@ -18,12 +18,13 @@ require ( | ||||
| 	github.com/google/uuid v1.6.0 | ||||
| 	github.com/gorilla/websocket v1.5.1 | ||||
| 	github.com/jinzhu/copier v0.4.0 | ||||
| 	github.com/joho/godotenv v1.5.1 | ||||
| 	github.com/pkg/errors v0.9.1 | ||||
| 	github.com/pkoukk/tiktoken-go v0.1.7 | ||||
| 	github.com/smartystreets/goconvey v1.8.1 | ||||
| 	github.com/stretchr/testify v1.9.0 | ||||
| 	golang.org/x/crypto v0.23.0 | ||||
| 	golang.org/x/image v0.16.0 | ||||
| 	golang.org/x/image v0.18.0 | ||||
| 	gorm.io/driver/mysql v1.5.6 | ||||
| 	gorm.io/driver/postgres v1.5.7 | ||||
| 	gorm.io/driver/sqlite v1.5.5 | ||||
| @@ -79,7 +80,7 @@ require ( | ||||
| 	golang.org/x/net v0.25.0 // indirect | ||||
| 	golang.org/x/sync v0.7.0 // indirect | ||||
| 	golang.org/x/sys v0.20.0 // indirect | ||||
| 	golang.org/x/text v0.15.0 // indirect | ||||
| 	golang.org/x/text v0.16.0 // indirect | ||||
| 	google.golang.org/protobuf v1.34.1 // indirect | ||||
| 	gopkg.in/yaml.v3 v3.0.1 // indirect | ||||
| ) | ||||
|   | ||||
							
								
								
									
										59
									
								
								go.sum
									
									
									
									
									
								
							
							
						
						
									
										59
									
								
								go.sum
									
									
									
									
									
								
							| @@ -1,40 +1,25 @@ | ||||
| filippo.io/edwards25519 v1.1.0 h1:FNf4tywRC1HmFuKW5xopWpigGjJKiJSV0Cqo0cJWDaA= | ||||
| filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4= | ||||
| github.com/aws/aws-sdk-go-v2 v1.26.1 h1:5554eUqIYVWpU0YmeeYZ0wU64H2VLBs8TlhRB2L+EkA= | ||||
| github.com/aws/aws-sdk-go-v2 v1.26.1/go.mod h1:ffIFB97e2yNsv4aTSGkqtHnppsIJzw7G7BReUZ3jCXM= | ||||
| github.com/aws/aws-sdk-go-v2 v1.27.0 h1:7bZWKoXhzI+mMR/HjdMx8ZCC5+6fY0lS5tr0bbgiLlo= | ||||
| github.com/aws/aws-sdk-go-v2 v1.27.0/go.mod h1:ffIFB97e2yNsv4aTSGkqtHnppsIJzw7G7BReUZ3jCXM= | ||||
| github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.2 h1:x6xsQXGSmW6frevwDA+vi/wqhp1ct18mVXYN08/93to= | ||||
| github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.2/go.mod h1:lPprDr1e6cJdyYeGXnRaJoP4Md+cDBvi2eOj00BlGmg= | ||||
| github.com/aws/aws-sdk-go-v2/credentials v1.17.11 h1:YuIB1dJNf1Re822rriUOTxopaHHvIq0l/pX3fwO+Tzs= | ||||
| github.com/aws/aws-sdk-go-v2/credentials v1.17.11/go.mod h1:AQtFPsDH9bI2O+71anW6EKL+NcD7LG3dpKGMV4SShgo= | ||||
| github.com/aws/aws-sdk-go-v2/credentials v1.17.15 h1:YDexlvDRCA8ems2T5IP1xkMtOZ1uLJOCJdTr0igs5zo= | ||||
| github.com/aws/aws-sdk-go-v2/credentials v1.17.15/go.mod h1:vxHggqW6hFNaeNC0WyXS3VdyjcV0a4KMUY4dKJ96buU= | ||||
| github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.5 h1:aw39xVGeRWlWx9EzGVnhOR4yOjQDHPQ6o6NmBlscyQg= | ||||
| github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.5/go.mod h1:FSaRudD0dXiMPK2UjknVwwTYyZMRsHv3TtkabsZih5I= | ||||
| github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.7 h1:lf/8VTF2cM+N4SLzaYJERKEWAXq8MOMpZfU6wEPWsPk= | ||||
| github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.7/go.mod h1:4SjkU7QiqK2M9oozyMzfZ/23LmUY+h3oFqhdeP5OMiI= | ||||
| github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.5 h1:PG1F3OD1szkuQPzDw3CIQsRIrtTlUC3lP84taWzHlq0= | ||||
| github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.5/go.mod h1:jU1li6RFryMz+so64PpKtudI+QzbKoIEivqdf6LNpOc= | ||||
| github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.7 h1:4OYVp0705xu8yjdyoWix0r9wPIRXnIzzOoUpQVHIJ/g= | ||||
| github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.7/go.mod h1:vd7ESTEvI76T2Na050gODNmNU7+OyKrIKroYTu4ABiI= | ||||
| github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.7.4 h1:JgHnonzbnA3pbqj76wYsSZIZZQYBxkmMEjvL6GHy8XU= | ||||
| github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.7.4/go.mod h1:nZspkhg+9p8iApLFoyAqfyuMP0F38acy2Hm3r5r95Cg= | ||||
| github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.8.3 h1:Fihjyd6DeNjcawBEGLH9dkIEUi6AdhucDKPE9nJ4QiY= | ||||
| github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.8.3/go.mod h1:opvUj3ismqSCxYc+m4WIjPL0ewZGtvp0ess7cKvBPOQ= | ||||
| github.com/aws/smithy-go v1.20.2 h1:tbp628ireGtzcHDDmLT/6ADHidqnwgF57XOXZe6tp4Q= | ||||
| github.com/aws/smithy-go v1.20.2/go.mod h1:krry+ya/rV9RDcV/Q16kpu6ypI4K2czasz0NC3qS14E= | ||||
| github.com/bytedance/sonic v1.11.5 h1:G00FYjjqll5iQ1PYXynbg/hyzqBqavH8Mo9/oTopd9k= | ||||
| github.com/bytedance/sonic v1.11.5/go.mod h1:X2PC2giUdj/Cv2lliWFLk6c/DUQok5rViJSemeB0wDw= | ||||
| github.com/bytedance/sonic v1.11.6 h1:oUp34TzMlL+OY1OUWxHqsdkgC/Zfc85zGqw9siXjrc0= | ||||
| github.com/bytedance/sonic v1.11.6/go.mod h1:LysEHSvpvDySVdC2f87zGWf6CIKJcAvqab1ZaiQtds4= | ||||
| github.com/bytedance/sonic/loader v0.1.0/go.mod h1:UmRT+IRTGKz/DAkzcEGzyVqQFJ7H9BqwBO3pm9H/+HY= | ||||
| github.com/bytedance/sonic/loader v0.1.1 h1:c+e5Pt1k/cy5wMveRDyk2X4B9hF4g7an8N3zCYjJFNM= | ||||
| github.com/bytedance/sonic/loader v0.1.1/go.mod h1:ncP89zfokxS5LZrJxl5z0UJcsk4M4yY2JpfqGeCtNLU= | ||||
| github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= | ||||
| github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= | ||||
| github.com/cloudwego/base64x v0.1.3 h1:b5J/l8xolB7dyDTTmhJP2oTs5LdrjyrUFuNxdfq5hAg= | ||||
| github.com/cloudwego/base64x v0.1.3/go.mod h1:1+1K5BUHIQzyapgpF7LwvOGAEDicKtt1umPV+aN8pi8= | ||||
| github.com/cloudwego/base64x v0.1.4 h1:jwCgWpFanWmN8xoIUHa2rtzmkd5J2plF/dnLS6Xd/0Y= | ||||
| github.com/cloudwego/base64x v0.1.4/go.mod h1:0zlkT4Wn5C6NdauXdJRhSKRlJvmclQ1hhJgA0rcu/8w= | ||||
| github.com/cloudwego/iasm v0.2.0 h1:1KNIy1I1H9hNNFEEH3DVnI4UujN+1zjpuk6gwHLTssg= | ||||
| @@ -51,26 +36,16 @@ github.com/fsnotify/fsnotify v1.7.0 h1:8JEhPFa5W2WU7YfeZzPNqzMP6Lwt7L2715Ggo0nos | ||||
| github.com/fsnotify/fsnotify v1.7.0/go.mod h1:40Bi/Hjc2AVfZrqy+aj+yEI+/bRxZnMJyTJwOpGvigM= | ||||
| github.com/gabriel-vasile/mimetype v1.4.3 h1:in2uUcidCuFcDKtdcBxlR0rJ1+fsokWf+uqxgUFjbI0= | ||||
| github.com/gabriel-vasile/mimetype v1.4.3/go.mod h1:d8uq/6HKRL6CGdk+aubisF/M5GcPfT7nKyLpA0lbSSk= | ||||
| github.com/gin-contrib/cors v1.7.1 h1:s9SIppU/rk8enVvkzwiC2VK3UZ/0NNGsWfUKvV55rqs= | ||||
| github.com/gin-contrib/cors v1.7.1/go.mod h1:n/Zj7B4xyrgk/cX1WCX2dkzFfaNm/xJb6oIUk7WTtps= | ||||
| github.com/gin-contrib/cors v1.7.2 h1:oLDHxdg8W/XDoN/8zamqk/Drgt4oVZDvaV0YmvVICQw= | ||||
| github.com/gin-contrib/cors v1.7.2/go.mod h1:SUJVARKgQ40dmrzgXEVxj2m7Ig1v1qIboQkPDTQ9t2E= | ||||
| github.com/gin-contrib/gzip v1.0.0 h1:UKN586Po/92IDX6ie5CWLgMI81obiIp5nSP85T3wlTk= | ||||
| github.com/gin-contrib/gzip v1.0.0/go.mod h1:CtG7tQrPB3vIBo6Gat9FVUsis+1emjvQqd66ME5TdnE= | ||||
| github.com/gin-contrib/gzip v1.0.1 h1:HQ8ENHODeLY7a4g1Au/46Z92bdGFl74OhxcZble9WJE= | ||||
| github.com/gin-contrib/gzip v1.0.1/go.mod h1:njt428fdUNRvjuJf16tZMYZ2Yl+WQB53X5wmhDwXvC4= | ||||
| github.com/gin-contrib/sessions v1.0.0 h1:r5GLta4Oy5xo9rAwMHx8B4wLpeRGHMdz9NafzJAdP8Y= | ||||
| github.com/gin-contrib/sessions v1.0.0/go.mod h1:DN0f4bvpqMQElDdi+gNGScrP2QEI04IErRyMFyorUOI= | ||||
| github.com/gin-contrib/sessions v1.0.1 h1:3hsJyNs7v7N8OtelFmYXFrulAf6zSR7nW/putcPEHxI= | ||||
| github.com/gin-contrib/sessions v1.0.1/go.mod h1:ouxSFM24/OgIud5MJYQJLpy6AwxQ5EYO9yLhbtObGkM= | ||||
| github.com/gin-contrib/sse v0.1.0 h1:Y/yl/+YNO8GZSjAhjMsSuLt29uWRFHdHYUb5lYOV9qE= | ||||
| github.com/gin-contrib/sse v0.1.0/go.mod h1:RHrZQHXnP2xjPF+u1gW/2HnVO7nvIa9PG3Gm+fLHvGI= | ||||
| github.com/gin-contrib/static v1.1.1 h1:XEvBd4DDLG1HBlyPBQU1XO8NlTpw6mgdqcPteetYA5k= | ||||
| github.com/gin-contrib/static v1.1.1/go.mod h1:yRGmar7+JYvbMLRPIi4H5TVVSBwULfT9vetnVD0IO74= | ||||
| github.com/gin-contrib/static v1.1.2 h1:c3kT4bFkUJn2aoRU3s6XnMjJT8J6nNWJkR0NglqmlZ4= | ||||
| github.com/gin-contrib/static v1.1.2/go.mod h1:Fw90ozjHCmZBWbgrsqrDvO28YbhKEKzKp8GixhR4yLw= | ||||
| github.com/gin-gonic/gin v1.9.1 h1:4idEAncQnU5cB7BeOkPtxjfCSye0AAm1R0RVIqJ+Jmg= | ||||
| github.com/gin-gonic/gin v1.9.1/go.mod h1:hPrL7YrpYKXt5YId3A/Tnip5kqbEAP+KLuI3SUcPTeU= | ||||
| github.com/gin-gonic/gin v1.10.0 h1:nTuyha1TYqgedzytsKYqna+DfLos46nTv2ygFy86HFU= | ||||
| github.com/gin-gonic/gin v1.10.0/go.mod h1:4PMNQiOhvDRa013RKVbsiNwoyezlm2rm0uX/T7kzp5Y= | ||||
| github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s= | ||||
| @@ -78,8 +53,6 @@ github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/o | ||||
| github.com/go-playground/locales v0.14.1/go.mod h1:hxrqLVvrK65+Rwrd5Fc6F2O76J/NuW9t0sjnWqG1slY= | ||||
| github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJnYK9S473LQFuzCbDbfSFY= | ||||
| github.com/go-playground/universal-translator v0.18.1/go.mod h1:xekY+UJKNuX9WP91TpwSH2VMlDf28Uj24BCp08ZFTUY= | ||||
| github.com/go-playground/validator/v10 v10.19.0 h1:ol+5Fu+cSq9JD7SoSqe04GMI92cbn0+wvQ3bZ8b/AU4= | ||||
| github.com/go-playground/validator/v10 v10.19.0/go.mod h1:dbuPbCMFw/DrkbEynArYaCwl3amGuJotoKCe95atGMM= | ||||
| github.com/go-playground/validator/v10 v10.20.0 h1:K9ISHbSaI0lyB2eWMPJo+kOS/FBExVwjEviJTixqxL8= | ||||
| github.com/go-playground/validator/v10 v10.20.0/go.mod h1:dbuPbCMFw/DrkbEynArYaCwl3amGuJotoKCe95atGMM= | ||||
| github.com/go-redis/redis/v8 v8.11.5 h1:AcZZR7igkdvfVmQTPnu9WE37LRrO/YrBH5zWyjDC0oI= | ||||
| @@ -87,8 +60,6 @@ github.com/go-redis/redis/v8 v8.11.5/go.mod h1:gREzHqY1hg6oD9ngVRbLStwAWKhA0FEgq | ||||
| github.com/go-sql-driver/mysql v1.7.0/go.mod h1:OXbVy3sEdcQ2Doequ6Z5BW6fXNQTmx+9S1MCJN5yJMI= | ||||
| github.com/go-sql-driver/mysql v1.8.1 h1:LedoTUt/eveggdHS9qUFC1EFSa8bU2+1pZjSRpvNJ1Y= | ||||
| github.com/go-sql-driver/mysql v1.8.1/go.mod h1:wEBSXgmK//2ZFJyE+qWnIsVGmvmEKlqwuVSjsCm7DZg= | ||||
| github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU= | ||||
| github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I= | ||||
| github.com/goccy/go-json v0.10.3 h1:KZ5WoDbxAIgm2HNbYckL0se1fHD6rz5j4ywS6ebzDqA= | ||||
| github.com/goccy/go-json v0.10.3/go.mod h1:oq7eo15ShAhp70Anwd5lgX2pLfOS3QCiwU/PULtXL6M= | ||||
| github.com/golang-jwt/jwt v3.2.2+incompatible h1:IfV12K8xAKAnZqdXVzCZ+TOjboZ2keLg81eXfW3O+oY= | ||||
| @@ -122,6 +93,8 @@ github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD | ||||
| github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc= | ||||
| github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ= | ||||
| github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8= | ||||
| github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0= | ||||
| github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4= | ||||
| github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM= | ||||
| github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo= | ||||
| github.com/jtolds/gls v4.20.0+incompatible h1:xdiiI2gbIgH/gLH7ADydsJ1uDOEzR8yvV7C0MuV77Wo= | ||||
| @@ -147,14 +120,10 @@ github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjY | ||||
| github.com/nxadm/tail v1.4.8 h1:nPr65rt6Y5JFSKQO7qToXr7pePgD6Gwiw05lkbyAQTE= | ||||
| github.com/onsi/ginkgo v1.16.5 h1:8xi0RTUf59SOSfEtZMvwTvXYMzG4gV23XVHOZiXNtnE= | ||||
| github.com/onsi/gomega v1.18.1 h1:M1GfJqGRrBrrGGsbxzV5dqM2U2ApXefZCQpkukxYRLE= | ||||
| github.com/pelletier/go-toml/v2 v2.2.1 h1:9TA9+T8+8CUCO2+WYnDLCgrYi9+omqKXyjDtosvtEhg= | ||||
| github.com/pelletier/go-toml/v2 v2.2.1/go.mod h1:1t835xjRzz80PqgE6HHgN2JOsmgYu/h4qDAS4n929Rs= | ||||
| github.com/pelletier/go-toml/v2 v2.2.2 h1:aYUidT7k73Pcl9nb2gScu7NSrKCSHIDE89b3+6Wq+LM= | ||||
| github.com/pelletier/go-toml/v2 v2.2.2/go.mod h1:1t835xjRzz80PqgE6HHgN2JOsmgYu/h4qDAS4n929Rs= | ||||
| github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= | ||||
| github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= | ||||
| github.com/pkoukk/tiktoken-go v0.1.6 h1:JF0TlJzhTbrI30wCvFuiw6FzP2+/bR+FIxUdgEAcUsw= | ||||
| github.com/pkoukk/tiktoken-go v0.1.6/go.mod h1:9NiV+i9mJKGj1rYOT+njbv+ZwA/zJxYdewGl6qVatpg= | ||||
| github.com/pkoukk/tiktoken-go v0.1.7 h1:qOBHXX4PHtvIvmOtyg1EeKlwFRiMKAcoMp4Q+bLQDmw= | ||||
| github.com/pkoukk/tiktoken-go v0.1.7/go.mod h1:9NiV+i9mJKGj1rYOT+njbv+ZwA/zJxYdewGl6qVatpg= | ||||
| github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= | ||||
| @@ -181,37 +150,23 @@ github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2 | ||||
| github.com/ugorji/go/codec v1.2.12 h1:9LC83zGrHhuUA9l16C9AHXAqEV/2wBQ4nkvumAE65EE= | ||||
| github.com/ugorji/go/codec v1.2.12/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZgYf6w6lg= | ||||
| golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8= | ||||
| golang.org/x/arch v0.7.0 h1:pskyeJh/3AmoQ8CPE95vxHLqp1G1GfGNXTmcl9NEKTc= | ||||
| golang.org/x/arch v0.7.0/go.mod h1:FEVrYAQjsQXMVJ1nsMoVVXPZg6p2JE2mx8psSWTDQys= | ||||
| golang.org/x/arch v0.8.0 h1:3wRIsP3pM4yUptoR96otTUOXI367OS0+c9eeRi9doIc= | ||||
| golang.org/x/arch v0.8.0/go.mod h1:FEVrYAQjsQXMVJ1nsMoVVXPZg6p2JE2mx8psSWTDQys= | ||||
| golang.org/x/crypto v0.22.0 h1:g1v0xeRhjcugydODzvb3mEM9SQ0HGp9s/nh3COQ/C30= | ||||
| golang.org/x/crypto v0.22.0/go.mod h1:vr6Su+7cTlO45qkww3VDJlzDn0ctJvRgYbC2NvXHt+M= | ||||
| golang.org/x/crypto v0.23.0 h1:dIJU/v2J8Mdglj/8rJ6UUOM3Zc9zLZxVZwwxMooUSAI= | ||||
| golang.org/x/crypto v0.23.0/go.mod h1:CKFgDieR+mRhux2Lsu27y0fO304Db0wZe70UKqHu0v8= | ||||
| golang.org/x/image v0.15.0 h1:kOELfmgrmJlw4Cdb7g/QGuB3CvDrXbqEIww/pNtNBm8= | ||||
| golang.org/x/image v0.15.0/go.mod h1:HUYqC05R2ZcZ3ejNQsIHQDQiwWM4JBqmm6MKANTp4LE= | ||||
| golang.org/x/image v0.16.0 h1:9kloLAKhUufZhA12l5fwnx2NZW39/we1UhBesW433jw= | ||||
| golang.org/x/image v0.16.0/go.mod h1:ugSZItdV4nOxyqp56HmXwH0Ry0nBCpjnZdpDaIHdoPs= | ||||
| golang.org/x/net v0.24.0 h1:1PcaxkF854Fu3+lvBIx5SYn9wRlBzzcnHZSiaFFAb0w= | ||||
| golang.org/x/net v0.24.0/go.mod h1:2Q7sJY5mzlzWjKtYUEXSlBWCdyaioyXzRB2RtU8KVE8= | ||||
| golang.org/x/image v0.18.0 h1:jGzIakQa/ZXI1I0Fxvaa9W7yP25TqT6cHIHn+6CqvSQ= | ||||
| golang.org/x/image v0.18.0/go.mod h1:4yyo5vMFQjVjUcVk4jEQcU9MGy/rulF5WvUILseCM2E= | ||||
| golang.org/x/net v0.25.0 h1:d/OCCoBEUq33pjydKrGQhw7IlUPI2Oylr+8qLx49kac= | ||||
| golang.org/x/net v0.25.0/go.mod h1:JkAGAh7GEvH74S6FOH42FLoXpXbE/aqXSrIQjXgsiwM= | ||||
| golang.org/x/sync v0.7.0 h1:YsImfSBoP9QPYL0xyKJPq0gcaJdG3rInoqxTWbfQu9M= | ||||
| golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= | ||||
| golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= | ||||
| golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= | ||||
| golang.org/x/sys v0.19.0 h1:q5f1RH2jigJ1MoAWp2KTp3gm5zAGFUTarQZ5U386+4o= | ||||
| golang.org/x/sys v0.19.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= | ||||
| golang.org/x/sys v0.20.0 h1:Od9JTbYCk261bKm4M/mw7AklTlFYIa0bIp9BgSm1S8Y= | ||||
| golang.org/x/sys v0.20.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= | ||||
| golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ= | ||||
| golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= | ||||
| golang.org/x/text v0.15.0 h1:h1V/4gjBv8v9cjcR6+AR5+/cIYK5N/WAgiv4xlsEtAk= | ||||
| golang.org/x/text v0.15.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= | ||||
| golang.org/x/text v0.16.0 h1:a94ExnEXNtEwYLGJSIUxnWoxoRz/ZcCsV63ROupILh4= | ||||
| golang.org/x/text v0.16.0/go.mod h1:GhwF1Be+LQoKShO3cGOHzqOgRrGaYc9AvblQOmPVHnI= | ||||
| golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 h1:go1bK/D/BFZV2I8cIQd1NKEZ+0owSTG1fDTci4IqFcE= | ||||
| google.golang.org/protobuf v1.33.0 h1:uNO2rsAINq/JlFpSdYEKIZ0uKD/R9cpdv0T+yoGwGmI= | ||||
| google.golang.org/protobuf v1.33.0/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos= | ||||
| google.golang.org/protobuf v1.34.1 h1:9ddQBjfCyZPOHPUiPxpYESBLc+T8P3E+Vo4IbKZgFWg= | ||||
| google.golang.org/protobuf v1.34.1/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos= | ||||
| gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= | ||||
| @@ -228,8 +183,6 @@ gorm.io/driver/postgres v1.5.7/go.mod h1:3e019WlBaYI5o5LIdNV+LyxCMNtLOQETBXL2h4c | ||||
| gorm.io/driver/sqlite v1.5.5 h1:7MDMtUZhV065SilG62E0MquljeArQZNfJnjd9i9gx3E= | ||||
| gorm.io/driver/sqlite v1.5.5/go.mod h1:6NgQ7sQWAIFsPrJJl1lSNSu2TABh0ZZ/zm5fosATavE= | ||||
| gorm.io/gorm v1.25.7/go.mod h1:hbnx/Oo0ChWMn1BIhpy1oYozzpM15i4YPuHDmfYtwg8= | ||||
| gorm.io/gorm v1.25.9 h1:wct0gxZIELDk8+ZqF/MVnHLkA1rvYlBWUMv2EdsK1g8= | ||||
| gorm.io/gorm v1.25.9/go.mod h1:hbnx/Oo0ChWMn1BIhpy1oYozzpM15i4YPuHDmfYtwg8= | ||||
| gorm.io/gorm v1.25.10 h1:dQpO+33KalOA+aFYGlK+EfxcI5MbO7EP2yYygwh9h+s= | ||||
| gorm.io/gorm v1.25.10/go.mod h1:hbnx/Oo0ChWMn1BIhpy1oYozzpM15i4YPuHDmfYtwg8= | ||||
| nullprogram.com/x/optparse v1.0.0/go.mod h1:KdyPE+Igbe0jQUrVfMqDMeJQIJZEuyV7pjYmp6pbG50= | ||||
|   | ||||
							
								
								
									
										3
									
								
								main.go
									
									
									
									
									
								
							
							
						
						
									
										3
									
								
								main.go
									
									
									
									
									
								
							| @@ -6,6 +6,7 @@ import ( | ||||
| 	"github.com/gin-contrib/sessions" | ||||
| 	"github.com/gin-contrib/sessions/cookie" | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	_ "github.com/joho/godotenv/autoload" | ||||
| 	"github.com/songquanpeng/one-api/common" | ||||
| 	"github.com/songquanpeng/one-api/common/client" | ||||
| 	"github.com/songquanpeng/one-api/common/config" | ||||
| @@ -23,6 +24,7 @@ import ( | ||||
| var buildFS embed.FS | ||||
|  | ||||
| func main() { | ||||
| 	common.Init() | ||||
| 	logger.SetupLogger() | ||||
| 	logger.SysLogf("One API %s started", common.Version) | ||||
| 	if os.Getenv("GIN_MODE") != "debug" { | ||||
| @@ -113,6 +115,7 @@ func main() { | ||||
| 	if port == "" { | ||||
| 		port = strconv.Itoa(*common.Port) | ||||
| 	} | ||||
| 	logger.SysLogf("server started on http://localhost:%s", port) | ||||
| 	err = server.Run(":" + port) | ||||
| 	if err != nil { | ||||
| 		logger.FatalLog("failed to start HTTP server: " + err.Error()) | ||||
|   | ||||
| @@ -4,6 +4,12 @@ import ( | ||||
| 	"bufio" | ||||
| 	"encoding/json" | ||||
| 	"fmt" | ||||
| 	"github.com/songquanpeng/one-api/common/render" | ||||
| 	"io" | ||||
| 	"net/http" | ||||
| 	"strconv" | ||||
| 	"strings" | ||||
|  | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"github.com/songquanpeng/one-api/common" | ||||
| 	"github.com/songquanpeng/one-api/common/helper" | ||||
| @@ -12,10 +18,6 @@ import ( | ||||
| 	"github.com/songquanpeng/one-api/relay/adaptor/openai" | ||||
| 	"github.com/songquanpeng/one-api/relay/constant" | ||||
| 	"github.com/songquanpeng/one-api/relay/model" | ||||
| 	"io" | ||||
| 	"net/http" | ||||
| 	"strconv" | ||||
| 	"strings" | ||||
| ) | ||||
|  | ||||
| // https://docs.aiproxy.io/dev/library#使用已经定制好的知识库进行对话问答 | ||||
| @@ -89,6 +91,7 @@ func streamResponseAIProxyLibrary2OpenAI(response *LibraryStreamResponse) *opena | ||||
|  | ||||
| func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) { | ||||
| 	var usage model.Usage | ||||
| 	var documents []LibraryDocument | ||||
| 	scanner := bufio.NewScanner(resp.Body) | ||||
| 	scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { | ||||
| 		if atEOF && len(data) == 0 { | ||||
| @@ -102,60 +105,48 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC | ||||
| 		} | ||||
| 		return 0, nil, nil | ||||
| 	}) | ||||
| 	dataChan := make(chan string) | ||||
| 	stopChan := make(chan bool) | ||||
| 	go func() { | ||||
| 		for scanner.Scan() { | ||||
| 			data := scanner.Text() | ||||
| 			if len(data) < 5 { // ignore blank line or wrong format | ||||
| 				continue | ||||
| 			} | ||||
| 			if data[:5] != "data:" { | ||||
| 				continue | ||||
| 			} | ||||
| 			data = data[5:] | ||||
| 			dataChan <- data | ||||
| 		} | ||||
| 		stopChan <- true | ||||
| 	}() | ||||
|  | ||||
| 	common.SetEventStreamHeaders(c) | ||||
| 	var documents []LibraryDocument | ||||
| 	c.Stream(func(w io.Writer) bool { | ||||
| 		select { | ||||
| 		case data := <-dataChan: | ||||
| 			var AIProxyLibraryResponse LibraryStreamResponse | ||||
| 			err := json.Unmarshal([]byte(data), &AIProxyLibraryResponse) | ||||
| 			if err != nil { | ||||
| 				logger.SysError("error unmarshalling stream response: " + err.Error()) | ||||
| 				return true | ||||
| 			} | ||||
| 			if len(AIProxyLibraryResponse.Documents) != 0 { | ||||
| 				documents = AIProxyLibraryResponse.Documents | ||||
| 			} | ||||
| 			response := streamResponseAIProxyLibrary2OpenAI(&AIProxyLibraryResponse) | ||||
| 			jsonResponse, err := json.Marshal(response) | ||||
| 			if err != nil { | ||||
| 				logger.SysError("error marshalling stream response: " + err.Error()) | ||||
| 				return true | ||||
| 			} | ||||
| 			c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)}) | ||||
| 			return true | ||||
| 		case <-stopChan: | ||||
| 			response := documentsAIProxyLibrary(documents) | ||||
| 			jsonResponse, err := json.Marshal(response) | ||||
| 			if err != nil { | ||||
| 				logger.SysError("error marshalling stream response: " + err.Error()) | ||||
| 				return true | ||||
| 			} | ||||
| 			c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)}) | ||||
| 			c.Render(-1, common.CustomEvent{Data: "data: [DONE]"}) | ||||
| 			return false | ||||
|  | ||||
| 	for scanner.Scan() { | ||||
| 		data := scanner.Text() | ||||
| 		if len(data) < 5 || data[:5] != "data:" { | ||||
| 			continue | ||||
| 		} | ||||
| 	}) | ||||
| 	err := resp.Body.Close() | ||||
| 		data = data[5:] | ||||
|  | ||||
| 		var AIProxyLibraryResponse LibraryStreamResponse | ||||
| 		err := json.Unmarshal([]byte(data), &AIProxyLibraryResponse) | ||||
| 		if err != nil { | ||||
| 			logger.SysError("error unmarshalling stream response: " + err.Error()) | ||||
| 			continue | ||||
| 		} | ||||
| 		if len(AIProxyLibraryResponse.Documents) != 0 { | ||||
| 			documents = AIProxyLibraryResponse.Documents | ||||
| 		} | ||||
| 		response := streamResponseAIProxyLibrary2OpenAI(&AIProxyLibraryResponse) | ||||
| 		err = render.ObjectData(c, response) | ||||
| 		if err != nil { | ||||
| 			logger.SysError(err.Error()) | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	if err := scanner.Err(); err != nil { | ||||
| 		logger.SysError("error reading stream: " + err.Error()) | ||||
| 	} | ||||
|  | ||||
| 	response := documentsAIProxyLibrary(documents) | ||||
| 	err := render.ObjectData(c, response) | ||||
| 	if err != nil { | ||||
| 		logger.SysError(err.Error()) | ||||
| 	} | ||||
| 	render.Done(c) | ||||
|  | ||||
| 	err = resp.Body.Close() | ||||
| 	if err != nil { | ||||
| 		return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil | ||||
| 	} | ||||
|  | ||||
| 	return nil, &usage | ||||
| } | ||||
|  | ||||
|   | ||||
| @@ -3,15 +3,17 @@ package ali | ||||
| import ( | ||||
| 	"bufio" | ||||
| 	"encoding/json" | ||||
| 	"github.com/songquanpeng/one-api/common/render" | ||||
| 	"io" | ||||
| 	"net/http" | ||||
| 	"strings" | ||||
|  | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"github.com/songquanpeng/one-api/common" | ||||
| 	"github.com/songquanpeng/one-api/common/helper" | ||||
| 	"github.com/songquanpeng/one-api/common/logger" | ||||
| 	"github.com/songquanpeng/one-api/relay/adaptor/openai" | ||||
| 	"github.com/songquanpeng/one-api/relay/model" | ||||
| 	"io" | ||||
| 	"net/http" | ||||
| 	"strings" | ||||
| ) | ||||
|  | ||||
| // https://help.aliyun.com/document_detail/613695.html?spm=a2c4g.2399480.0.0.1adb778fAdzP9w#341800c0f8w0r | ||||
| @@ -181,56 +183,43 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC | ||||
| 		} | ||||
| 		return 0, nil, nil | ||||
| 	}) | ||||
| 	dataChan := make(chan string) | ||||
| 	stopChan := make(chan bool) | ||||
| 	go func() { | ||||
| 		for scanner.Scan() { | ||||
| 			data := scanner.Text() | ||||
| 			if len(data) < 5 { // ignore blank line or wrong format | ||||
| 				continue | ||||
| 			} | ||||
| 			if data[:5] != "data:" { | ||||
| 				continue | ||||
| 			} | ||||
| 			data = data[5:] | ||||
| 			dataChan <- data | ||||
| 		} | ||||
| 		stopChan <- true | ||||
| 	}() | ||||
|  | ||||
| 	common.SetEventStreamHeaders(c) | ||||
| 	//lastResponseText := "" | ||||
| 	c.Stream(func(w io.Writer) bool { | ||||
| 		select { | ||||
| 		case data := <-dataChan: | ||||
| 			var aliResponse ChatResponse | ||||
| 			err := json.Unmarshal([]byte(data), &aliResponse) | ||||
| 			if err != nil { | ||||
| 				logger.SysError("error unmarshalling stream response: " + err.Error()) | ||||
| 				return true | ||||
| 			} | ||||
| 			if aliResponse.Usage.OutputTokens != 0 { | ||||
| 				usage.PromptTokens = aliResponse.Usage.InputTokens | ||||
| 				usage.CompletionTokens = aliResponse.Usage.OutputTokens | ||||
| 				usage.TotalTokens = aliResponse.Usage.InputTokens + aliResponse.Usage.OutputTokens | ||||
| 			} | ||||
| 			response := streamResponseAli2OpenAI(&aliResponse) | ||||
| 			if response == nil { | ||||
| 				return true | ||||
| 			} | ||||
| 			//response.Choices[0].Delta.Content = strings.TrimPrefix(response.Choices[0].Delta.Content, lastResponseText) | ||||
| 			//lastResponseText = aliResponse.Output.Text | ||||
| 			jsonResponse, err := json.Marshal(response) | ||||
| 			if err != nil { | ||||
| 				logger.SysError("error marshalling stream response: " + err.Error()) | ||||
| 				return true | ||||
| 			} | ||||
| 			c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)}) | ||||
| 			return true | ||||
| 		case <-stopChan: | ||||
| 			c.Render(-1, common.CustomEvent{Data: "data: [DONE]"}) | ||||
| 			return false | ||||
|  | ||||
| 	for scanner.Scan() { | ||||
| 		data := scanner.Text() | ||||
| 		if len(data) < 5 || data[:5] != "data:" { | ||||
| 			continue | ||||
| 		} | ||||
| 	}) | ||||
| 		data = data[5:] | ||||
|  | ||||
| 		var aliResponse ChatResponse | ||||
| 		err := json.Unmarshal([]byte(data), &aliResponse) | ||||
| 		if err != nil { | ||||
| 			logger.SysError("error unmarshalling stream response: " + err.Error()) | ||||
| 			continue | ||||
| 		} | ||||
| 		if aliResponse.Usage.OutputTokens != 0 { | ||||
| 			usage.PromptTokens = aliResponse.Usage.InputTokens | ||||
| 			usage.CompletionTokens = aliResponse.Usage.OutputTokens | ||||
| 			usage.TotalTokens = aliResponse.Usage.InputTokens + aliResponse.Usage.OutputTokens | ||||
| 		} | ||||
| 		response := streamResponseAli2OpenAI(&aliResponse) | ||||
| 		if response == nil { | ||||
| 			continue | ||||
| 		} | ||||
| 		err = render.ObjectData(c, response) | ||||
| 		if err != nil { | ||||
| 			logger.SysError(err.Error()) | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	if err := scanner.Err(); err != nil { | ||||
| 		logger.SysError("error reading stream: " + err.Error()) | ||||
| 	} | ||||
|  | ||||
| 	render.Done(c) | ||||
|  | ||||
| 	err := resp.Body.Close() | ||||
| 	if err != nil { | ||||
| 		return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil | ||||
|   | ||||
| @@ -5,4 +5,5 @@ var ModelList = []string{ | ||||
| 	"claude-3-haiku-20240307", | ||||
| 	"claude-3-sonnet-20240229", | ||||
| 	"claude-3-opus-20240229", | ||||
| 	"claude-3-5-sonnet-20240620", | ||||
| } | ||||
|   | ||||
| @@ -4,6 +4,7 @@ import ( | ||||
| 	"bufio" | ||||
| 	"encoding/json" | ||||
| 	"fmt" | ||||
| 	"github.com/songquanpeng/one-api/common/render" | ||||
| 	"io" | ||||
| 	"net/http" | ||||
| 	"strings" | ||||
| @@ -28,12 +29,30 @@ func stopReasonClaude2OpenAI(reason *string) string { | ||||
| 		return "stop" | ||||
| 	case "max_tokens": | ||||
| 		return "length" | ||||
| 	case "tool_use": | ||||
| 		return "tool_calls" | ||||
| 	default: | ||||
| 		return *reason | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func ConvertRequest(textRequest model.GeneralOpenAIRequest) *Request { | ||||
| 	claudeTools := make([]Tool, 0, len(textRequest.Tools)) | ||||
|  | ||||
| 	for _, tool := range textRequest.Tools { | ||||
| 		if params, ok := tool.Function.Parameters.(map[string]any); ok { | ||||
| 			claudeTools = append(claudeTools, Tool{ | ||||
| 				Name:        tool.Function.Name, | ||||
| 				Description: tool.Function.Description, | ||||
| 				InputSchema: InputSchema{ | ||||
| 					Type:       params["type"].(string), | ||||
| 					Properties: params["properties"], | ||||
| 					Required:   params["required"], | ||||
| 				}, | ||||
| 			}) | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	claudeRequest := Request{ | ||||
| 		Model:       textRequest.Model, | ||||
| 		MaxTokens:   textRequest.MaxTokens, | ||||
| @@ -41,6 +60,24 @@ func ConvertRequest(textRequest model.GeneralOpenAIRequest) *Request { | ||||
| 		TopP:        textRequest.TopP, | ||||
| 		TopK:        textRequest.TopK, | ||||
| 		Stream:      textRequest.Stream, | ||||
| 		Tools:       claudeTools, | ||||
| 	} | ||||
| 	if len(claudeTools) > 0 { | ||||
| 		claudeToolChoice := struct { | ||||
| 			Type string `json:"type"` | ||||
| 			Name string `json:"name,omitempty"` | ||||
| 		}{Type: "auto"} // default value https://docs.anthropic.com/en/docs/build-with-claude/tool-use#controlling-claudes-output | ||||
| 		if choice, ok := textRequest.ToolChoice.(map[string]any); ok { | ||||
| 			if function, ok := choice["function"].(map[string]any); ok { | ||||
| 				claudeToolChoice.Type = "tool" | ||||
| 				claudeToolChoice.Name = function["name"].(string) | ||||
| 			} | ||||
| 		} else if toolChoiceType, ok := textRequest.ToolChoice.(string); ok { | ||||
| 			if toolChoiceType == "any" { | ||||
| 				claudeToolChoice.Type = toolChoiceType | ||||
| 			} | ||||
| 		} | ||||
| 		claudeRequest.ToolChoice = claudeToolChoice | ||||
| 	} | ||||
| 	if claudeRequest.MaxTokens == 0 { | ||||
| 		claudeRequest.MaxTokens = 4096 | ||||
| @@ -63,7 +100,24 @@ func ConvertRequest(textRequest model.GeneralOpenAIRequest) *Request { | ||||
| 		if message.IsStringContent() { | ||||
| 			content.Type = "text" | ||||
| 			content.Text = message.StringContent() | ||||
| 			if message.Role == "tool" { | ||||
| 				claudeMessage.Role = "user" | ||||
| 				content.Type = "tool_result" | ||||
| 				content.Content = content.Text | ||||
| 				content.Text = "" | ||||
| 				content.ToolUseId = message.ToolCallId | ||||
| 			} | ||||
| 			claudeMessage.Content = append(claudeMessage.Content, content) | ||||
| 			for i := range message.ToolCalls { | ||||
| 				inputParam := make(map[string]any) | ||||
| 				_ = json.Unmarshal([]byte(message.ToolCalls[i].Function.Arguments.(string)), &inputParam) | ||||
| 				claudeMessage.Content = append(claudeMessage.Content, Content{ | ||||
| 					Type:  "tool_use", | ||||
| 					Id:    message.ToolCalls[i].Id, | ||||
| 					Name:  message.ToolCalls[i].Function.Name, | ||||
| 					Input: inputParam, | ||||
| 				}) | ||||
| 			} | ||||
| 			claudeRequest.Messages = append(claudeRequest.Messages, claudeMessage) | ||||
| 			continue | ||||
| 		} | ||||
| @@ -96,16 +150,35 @@ func StreamResponseClaude2OpenAI(claudeResponse *StreamResponse) (*openai.ChatCo | ||||
| 	var response *Response | ||||
| 	var responseText string | ||||
| 	var stopReason string | ||||
| 	tools := make([]model.Tool, 0) | ||||
|  | ||||
| 	switch claudeResponse.Type { | ||||
| 	case "message_start": | ||||
| 		return nil, claudeResponse.Message | ||||
| 	case "content_block_start": | ||||
| 		if claudeResponse.ContentBlock != nil { | ||||
| 			responseText = claudeResponse.ContentBlock.Text | ||||
| 			if claudeResponse.ContentBlock.Type == "tool_use" { | ||||
| 				tools = append(tools, model.Tool{ | ||||
| 					Id:   claudeResponse.ContentBlock.Id, | ||||
| 					Type: "function", | ||||
| 					Function: model.Function{ | ||||
| 						Name:      claudeResponse.ContentBlock.Name, | ||||
| 						Arguments: "", | ||||
| 					}, | ||||
| 				}) | ||||
| 			} | ||||
| 		} | ||||
| 	case "content_block_delta": | ||||
| 		if claudeResponse.Delta != nil { | ||||
| 			responseText = claudeResponse.Delta.Text | ||||
| 			if claudeResponse.Delta.Type == "input_json_delta" { | ||||
| 				tools = append(tools, model.Tool{ | ||||
| 					Function: model.Function{ | ||||
| 						Arguments: claudeResponse.Delta.PartialJson, | ||||
| 					}, | ||||
| 				}) | ||||
| 			} | ||||
| 		} | ||||
| 	case "message_delta": | ||||
| 		if claudeResponse.Usage != nil { | ||||
| @@ -119,6 +192,10 @@ func StreamResponseClaude2OpenAI(claudeResponse *StreamResponse) (*openai.ChatCo | ||||
| 	} | ||||
| 	var choice openai.ChatCompletionsStreamResponseChoice | ||||
| 	choice.Delta.Content = responseText | ||||
| 	if len(tools) > 0 { | ||||
| 		choice.Delta.Content = nil // compatible with other OpenAI derivative applications, like LobeOpenAICompatibleFactory ... | ||||
| 		choice.Delta.ToolCalls = tools | ||||
| 	} | ||||
| 	choice.Delta.Role = "assistant" | ||||
| 	finishReason := stopReasonClaude2OpenAI(&stopReason) | ||||
| 	if finishReason != "null" { | ||||
| @@ -135,12 +212,27 @@ func ResponseClaude2OpenAI(claudeResponse *Response) *openai.TextResponse { | ||||
| 	if len(claudeResponse.Content) > 0 { | ||||
| 		responseText = claudeResponse.Content[0].Text | ||||
| 	} | ||||
| 	tools := make([]model.Tool, 0) | ||||
| 	for _, v := range claudeResponse.Content { | ||||
| 		if v.Type == "tool_use" { | ||||
| 			args, _ := json.Marshal(v.Input) | ||||
| 			tools = append(tools, model.Tool{ | ||||
| 				Id:   v.Id, | ||||
| 				Type: "function", // compatible with other OpenAI derivative applications | ||||
| 				Function: model.Function{ | ||||
| 					Name:      v.Name, | ||||
| 					Arguments: string(args), | ||||
| 				}, | ||||
| 			}) | ||||
| 		} | ||||
| 	} | ||||
| 	choice := openai.TextResponseChoice{ | ||||
| 		Index: 0, | ||||
| 		Message: model.Message{ | ||||
| 			Role:    "assistant", | ||||
| 			Content: responseText, | ||||
| 			Name:    nil, | ||||
| 			Role:      "assistant", | ||||
| 			Content:   responseText, | ||||
| 			Name:      nil, | ||||
| 			ToolCalls: tools, | ||||
| 		}, | ||||
| 		FinishReason: stopReasonClaude2OpenAI(claudeResponse.StopReason), | ||||
| 	} | ||||
| @@ -169,64 +261,77 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC | ||||
| 		} | ||||
| 		return 0, nil, nil | ||||
| 	}) | ||||
| 	dataChan := make(chan string) | ||||
| 	stopChan := make(chan bool) | ||||
| 	go func() { | ||||
| 		for scanner.Scan() { | ||||
| 			data := scanner.Text() | ||||
| 			if len(data) < 6 { | ||||
| 				continue | ||||
| 			} | ||||
| 			if !strings.HasPrefix(data, "data:") { | ||||
| 				continue | ||||
| 			} | ||||
| 			data = strings.TrimPrefix(data, "data:") | ||||
| 			dataChan <- data | ||||
| 		} | ||||
| 		stopChan <- true | ||||
| 	}() | ||||
|  | ||||
| 	common.SetEventStreamHeaders(c) | ||||
|  | ||||
| 	var usage model.Usage | ||||
| 	var modelName string | ||||
| 	var id string | ||||
| 	c.Stream(func(w io.Writer) bool { | ||||
| 		select { | ||||
| 		case data := <-dataChan: | ||||
| 			// some implementations may add \r at the end of data | ||||
| 			data = strings.TrimSpace(data) | ||||
| 			var claudeResponse StreamResponse | ||||
| 			err := json.Unmarshal([]byte(data), &claudeResponse) | ||||
| 			if err != nil { | ||||
| 				logger.SysError("error unmarshalling stream response: " + err.Error()) | ||||
| 				return true | ||||
| 			} | ||||
| 			response, meta := StreamResponseClaude2OpenAI(&claudeResponse) | ||||
| 			if meta != nil { | ||||
| 				usage.PromptTokens += meta.Usage.InputTokens | ||||
| 				usage.CompletionTokens += meta.Usage.OutputTokens | ||||
| 	var lastToolCallChoice openai.ChatCompletionsStreamResponseChoice | ||||
|  | ||||
| 	for scanner.Scan() { | ||||
| 		data := scanner.Text() | ||||
| 		if len(data) < 6 || !strings.HasPrefix(data, "data:") { | ||||
| 			continue | ||||
| 		} | ||||
| 		data = strings.TrimPrefix(data, "data:") | ||||
| 		data = strings.TrimSpace(data) | ||||
|  | ||||
| 		var claudeResponse StreamResponse | ||||
| 		err := json.Unmarshal([]byte(data), &claudeResponse) | ||||
| 		if err != nil { | ||||
| 			logger.SysError("error unmarshalling stream response: " + err.Error()) | ||||
| 			continue | ||||
| 		} | ||||
|  | ||||
| 		response, meta := StreamResponseClaude2OpenAI(&claudeResponse) | ||||
| 		if meta != nil { | ||||
| 			usage.PromptTokens += meta.Usage.InputTokens | ||||
| 			usage.CompletionTokens += meta.Usage.OutputTokens | ||||
| 			if len(meta.Id) > 0 { // only message_start has an id, otherwise it's a finish_reason event. | ||||
| 				modelName = meta.Model | ||||
| 				id = fmt.Sprintf("chatcmpl-%s", meta.Id) | ||||
| 				return true | ||||
| 				continue | ||||
| 			} else { // finish_reason case | ||||
| 				if len(lastToolCallChoice.Delta.ToolCalls) > 0 { | ||||
| 					lastArgs := &lastToolCallChoice.Delta.ToolCalls[len(lastToolCallChoice.Delta.ToolCalls)-1].Function | ||||
| 					if len(lastArgs.Arguments.(string)) == 0 { // compatible with OpenAI sending an empty object `{}` when no arguments. | ||||
| 						lastArgs.Arguments = "{}" | ||||
| 						response.Choices[len(response.Choices)-1].Delta.Content = nil | ||||
| 						response.Choices[len(response.Choices)-1].Delta.ToolCalls = lastToolCallChoice.Delta.ToolCalls | ||||
| 					} | ||||
| 				} | ||||
| 			} | ||||
| 			if response == nil { | ||||
| 				return true | ||||
| 			} | ||||
| 			response.Id = id | ||||
| 			response.Model = modelName | ||||
| 			response.Created = createdTime | ||||
| 			jsonStr, err := json.Marshal(response) | ||||
| 			if err != nil { | ||||
| 				logger.SysError("error marshalling stream response: " + err.Error()) | ||||
| 				return true | ||||
| 			} | ||||
| 			c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonStr)}) | ||||
| 			return true | ||||
| 		case <-stopChan: | ||||
| 			c.Render(-1, common.CustomEvent{Data: "data: [DONE]"}) | ||||
| 			return false | ||||
| 		} | ||||
| 	}) | ||||
| 	_ = resp.Body.Close() | ||||
| 		if response == nil { | ||||
| 			continue | ||||
| 		} | ||||
|  | ||||
| 		response.Id = id | ||||
| 		response.Model = modelName | ||||
| 		response.Created = createdTime | ||||
|  | ||||
| 		for _, choice := range response.Choices { | ||||
| 			if len(choice.Delta.ToolCalls) > 0 { | ||||
| 				lastToolCallChoice = choice | ||||
| 			} | ||||
| 		} | ||||
| 		err = render.ObjectData(c, response) | ||||
| 		if err != nil { | ||||
| 			logger.SysError(err.Error()) | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	if err := scanner.Err(); err != nil { | ||||
| 		logger.SysError("error reading stream: " + err.Error()) | ||||
| 	} | ||||
|  | ||||
| 	render.Done(c) | ||||
|  | ||||
| 	err := resp.Body.Close() | ||||
| 	if err != nil { | ||||
| 		return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil | ||||
| 	} | ||||
| 	return nil, &usage | ||||
| } | ||||
|  | ||||
|   | ||||
| @@ -16,6 +16,12 @@ type Content struct { | ||||
| 	Type   string       `json:"type"` | ||||
| 	Text   string       `json:"text,omitempty"` | ||||
| 	Source *ImageSource `json:"source,omitempty"` | ||||
| 	// tool_calls | ||||
| 	Id        string `json:"id,omitempty"` | ||||
| 	Name      string `json:"name,omitempty"` | ||||
| 	Input     any    `json:"input,omitempty"` | ||||
| 	Content   string `json:"content,omitempty"` | ||||
| 	ToolUseId string `json:"tool_use_id,omitempty"` | ||||
| } | ||||
|  | ||||
| type Message struct { | ||||
| @@ -23,6 +29,18 @@ type Message struct { | ||||
| 	Content []Content `json:"content"` | ||||
| } | ||||
|  | ||||
| type Tool struct { | ||||
| 	Name        string      `json:"name"` | ||||
| 	Description string      `json:"description,omitempty"` | ||||
| 	InputSchema InputSchema `json:"input_schema"` | ||||
| } | ||||
|  | ||||
| type InputSchema struct { | ||||
| 	Type       string `json:"type"` | ||||
| 	Properties any    `json:"properties,omitempty"` | ||||
| 	Required   any    `json:"required,omitempty"` | ||||
| } | ||||
|  | ||||
| type Request struct { | ||||
| 	Model         string    `json:"model"` | ||||
| 	Messages      []Message `json:"messages"` | ||||
| @@ -33,6 +51,8 @@ type Request struct { | ||||
| 	Temperature   float64   `json:"temperature,omitempty"` | ||||
| 	TopP          float64   `json:"top_p,omitempty"` | ||||
| 	TopK          int       `json:"top_k,omitempty"` | ||||
| 	Tools         []Tool    `json:"tools,omitempty"` | ||||
| 	ToolChoice    any       `json:"tool_choice,omitempty"` | ||||
| 	//Metadata    `json:"metadata,omitempty"` | ||||
| } | ||||
|  | ||||
| @@ -61,6 +81,7 @@ type Response struct { | ||||
| type Delta struct { | ||||
| 	Type         string  `json:"type"` | ||||
| 	Text         string  `json:"text"` | ||||
| 	PartialJson  string  `json:"partial_json,omitempty"` | ||||
| 	StopReason   *string `json:"stop_reason"` | ||||
| 	StopSequence *string `json:"stop_sequence"` | ||||
| } | ||||
|   | ||||
| @@ -6,6 +6,7 @@ import ( | ||||
| 	"encoding/json" | ||||
| 	"fmt" | ||||
| 	"github.com/songquanpeng/one-api/common/ctxkey" | ||||
| 	"github.com/songquanpeng/one-api/relay/adaptor/openai" | ||||
| 	"io" | ||||
| 	"net/http" | ||||
|  | ||||
| @@ -33,12 +34,13 @@ func wrapErr(err error) *relaymodel.ErrorWithStatusCode { | ||||
|  | ||||
| // https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids.html | ||||
| var awsModelIDMap = map[string]string{ | ||||
| 	"claude-instant-1.2":       "anthropic.claude-instant-v1", | ||||
| 	"claude-2.0":               "anthropic.claude-v2", | ||||
| 	"claude-2.1":               "anthropic.claude-v2:1", | ||||
| 	"claude-3-sonnet-20240229": "anthropic.claude-3-sonnet-20240229-v1:0", | ||||
| 	"claude-3-opus-20240229":   "anthropic.claude-3-opus-20240229-v1:0", | ||||
| 	"claude-3-haiku-20240307":  "anthropic.claude-3-haiku-20240307-v1:0", | ||||
| 	"claude-instant-1.2":         "anthropic.claude-instant-v1", | ||||
| 	"claude-2.0":                 "anthropic.claude-v2", | ||||
| 	"claude-2.1":                 "anthropic.claude-v2:1", | ||||
| 	"claude-3-sonnet-20240229":   "anthropic.claude-3-sonnet-20240229-v1:0", | ||||
| 	"claude-3-5-sonnet-20240620": "anthropic.claude-3-5-sonnet-20240620-v1:0", | ||||
| 	"claude-3-opus-20240229":     "anthropic.claude-3-opus-20240229-v1:0", | ||||
| 	"claude-3-haiku-20240307":    "anthropic.claude-3-haiku-20240307-v1:0", | ||||
| } | ||||
|  | ||||
| func awsModelID(requestModel string) (string, error) { | ||||
| @@ -142,6 +144,8 @@ func StreamHandler(c *gin.Context, awsCli *bedrockruntime.Client) (*relaymodel.E | ||||
| 	c.Writer.Header().Set("Content-Type", "text/event-stream") | ||||
| 	var usage relaymodel.Usage | ||||
| 	var id string | ||||
| 	var lastToolCallChoice openai.ChatCompletionsStreamResponseChoice | ||||
|  | ||||
| 	c.Stream(func(w io.Writer) bool { | ||||
| 		event, ok := <-stream.Events() | ||||
| 		if !ok { | ||||
| @@ -162,8 +166,19 @@ func StreamHandler(c *gin.Context, awsCli *bedrockruntime.Client) (*relaymodel.E | ||||
| 			if meta != nil { | ||||
| 				usage.PromptTokens += meta.Usage.InputTokens | ||||
| 				usage.CompletionTokens += meta.Usage.OutputTokens | ||||
| 				id = fmt.Sprintf("chatcmpl-%s", meta.Id) | ||||
| 				return true | ||||
| 				if len(meta.Id) > 0 { // only message_start has an id, otherwise it's a finish_reason event. | ||||
| 					id = fmt.Sprintf("chatcmpl-%s", meta.Id) | ||||
| 					return true | ||||
| 				} else { // finish_reason case | ||||
| 					if len(lastToolCallChoice.Delta.ToolCalls) > 0 { | ||||
| 						lastArgs := &lastToolCallChoice.Delta.ToolCalls[len(lastToolCallChoice.Delta.ToolCalls)-1].Function | ||||
| 						if len(lastArgs.Arguments.(string)) == 0 { // compatible with OpenAI sending an empty object `{}` when no arguments. | ||||
| 							lastArgs.Arguments = "{}" | ||||
| 							response.Choices[len(response.Choices)-1].Delta.Content = nil | ||||
| 							response.Choices[len(response.Choices)-1].Delta.ToolCalls = lastToolCallChoice.Delta.ToolCalls | ||||
| 						} | ||||
| 					} | ||||
| 				} | ||||
| 			} | ||||
| 			if response == nil { | ||||
| 				return true | ||||
| @@ -171,6 +186,12 @@ func StreamHandler(c *gin.Context, awsCli *bedrockruntime.Client) (*relaymodel.E | ||||
| 			response.Id = id | ||||
| 			response.Model = c.GetString(ctxkey.OriginalModel) | ||||
| 			response.Created = createdTime | ||||
|  | ||||
| 			for _, choice := range response.Choices { | ||||
| 				if len(choice.Delta.ToolCalls) > 0 { | ||||
| 					lastToolCallChoice = choice | ||||
| 				} | ||||
| 			} | ||||
| 			jsonStr, err := json.Marshal(response) | ||||
| 			if err != nil { | ||||
| 				logger.SysError("error marshalling stream response: " + err.Error()) | ||||
|   | ||||
| @@ -9,9 +9,12 @@ type Request struct { | ||||
| 	// AnthropicVersion should be "bedrock-2023-05-31" | ||||
| 	AnthropicVersion string              `json:"anthropic_version"` | ||||
| 	Messages         []anthropic.Message `json:"messages"` | ||||
| 	System           string              `json:"system,omitempty"` | ||||
| 	MaxTokens        int                 `json:"max_tokens,omitempty"` | ||||
| 	Temperature      float64             `json:"temperature,omitempty"` | ||||
| 	TopP             float64             `json:"top_p,omitempty"` | ||||
| 	TopK             int                 `json:"top_k,omitempty"` | ||||
| 	StopSequences    []string            `json:"stop_sequences,omitempty"` | ||||
| 	Tools            []anthropic.Tool    `json:"tools,omitempty"` | ||||
| 	ToolChoice       any                 `json:"tool_choice,omitempty"` | ||||
| } | ||||
|   | ||||
| @@ -5,6 +5,13 @@ import ( | ||||
| 	"encoding/json" | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
| 	"github.com/songquanpeng/one-api/common/render" | ||||
| 	"io" | ||||
| 	"net/http" | ||||
| 	"strings" | ||||
| 	"sync" | ||||
| 	"time" | ||||
|  | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"github.com/songquanpeng/one-api/common" | ||||
| 	"github.com/songquanpeng/one-api/common/client" | ||||
| @@ -12,11 +19,6 @@ import ( | ||||
| 	"github.com/songquanpeng/one-api/relay/adaptor/openai" | ||||
| 	"github.com/songquanpeng/one-api/relay/constant" | ||||
| 	"github.com/songquanpeng/one-api/relay/model" | ||||
| 	"io" | ||||
| 	"net/http" | ||||
| 	"strings" | ||||
| 	"sync" | ||||
| 	"time" | ||||
| ) | ||||
|  | ||||
| // https://cloud.baidu.com/doc/WENXINWORKSHOP/s/flfmc9do2 | ||||
| @@ -137,59 +139,41 @@ func embeddingResponseBaidu2OpenAI(response *EmbeddingResponse) *openai.Embeddin | ||||
| func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) { | ||||
| 	var usage model.Usage | ||||
| 	scanner := bufio.NewScanner(resp.Body) | ||||
| 	scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { | ||||
| 		if atEOF && len(data) == 0 { | ||||
| 			return 0, nil, nil | ||||
| 		} | ||||
| 		if i := strings.Index(string(data), "\n"); i >= 0 { | ||||
| 			return i + 1, data[0:i], nil | ||||
| 		} | ||||
| 		if atEOF { | ||||
| 			return len(data), data, nil | ||||
| 		} | ||||
| 		return 0, nil, nil | ||||
| 	}) | ||||
| 	dataChan := make(chan string) | ||||
| 	stopChan := make(chan bool) | ||||
| 	go func() { | ||||
| 		for scanner.Scan() { | ||||
| 			data := scanner.Text() | ||||
| 			if len(data) < 6 { // ignore blank line or wrong format | ||||
| 				continue | ||||
| 			} | ||||
| 			data = data[6:] | ||||
| 			dataChan <- data | ||||
| 		} | ||||
| 		stopChan <- true | ||||
| 	}() | ||||
| 	scanner.Split(bufio.ScanLines) | ||||
|  | ||||
| 	common.SetEventStreamHeaders(c) | ||||
| 	c.Stream(func(w io.Writer) bool { | ||||
| 		select { | ||||
| 		case data := <-dataChan: | ||||
| 			var baiduResponse ChatStreamResponse | ||||
| 			err := json.Unmarshal([]byte(data), &baiduResponse) | ||||
| 			if err != nil { | ||||
| 				logger.SysError("error unmarshalling stream response: " + err.Error()) | ||||
| 				return true | ||||
| 			} | ||||
| 			if baiduResponse.Usage.TotalTokens != 0 { | ||||
| 				usage.TotalTokens = baiduResponse.Usage.TotalTokens | ||||
| 				usage.PromptTokens = baiduResponse.Usage.PromptTokens | ||||
| 				usage.CompletionTokens = baiduResponse.Usage.TotalTokens - baiduResponse.Usage.PromptTokens | ||||
| 			} | ||||
| 			response := streamResponseBaidu2OpenAI(&baiduResponse) | ||||
| 			jsonResponse, err := json.Marshal(response) | ||||
| 			if err != nil { | ||||
| 				logger.SysError("error marshalling stream response: " + err.Error()) | ||||
| 				return true | ||||
| 			} | ||||
| 			c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)}) | ||||
| 			return true | ||||
| 		case <-stopChan: | ||||
| 			c.Render(-1, common.CustomEvent{Data: "data: [DONE]"}) | ||||
| 			return false | ||||
|  | ||||
| 	for scanner.Scan() { | ||||
| 		data := scanner.Text() | ||||
| 		if len(data) < 6 { | ||||
| 			continue | ||||
| 		} | ||||
| 	}) | ||||
| 		data = data[6:] | ||||
|  | ||||
| 		var baiduResponse ChatStreamResponse | ||||
| 		err := json.Unmarshal([]byte(data), &baiduResponse) | ||||
| 		if err != nil { | ||||
| 			logger.SysError("error unmarshalling stream response: " + err.Error()) | ||||
| 			continue | ||||
| 		} | ||||
| 		if baiduResponse.Usage.TotalTokens != 0 { | ||||
| 			usage.TotalTokens = baiduResponse.Usage.TotalTokens | ||||
| 			usage.PromptTokens = baiduResponse.Usage.PromptTokens | ||||
| 			usage.CompletionTokens = baiduResponse.Usage.TotalTokens - baiduResponse.Usage.PromptTokens | ||||
| 		} | ||||
| 		response := streamResponseBaidu2OpenAI(&baiduResponse) | ||||
| 		err = render.ObjectData(c, response) | ||||
| 		if err != nil { | ||||
| 			logger.SysError(err.Error()) | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	if err := scanner.Err(); err != nil { | ||||
| 		logger.SysError("error reading stream: " + err.Error()) | ||||
| 	} | ||||
|  | ||||
| 	render.Done(c) | ||||
|  | ||||
| 	err := resp.Body.Close() | ||||
| 	if err != nil { | ||||
| 		return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil | ||||
|   | ||||
| @@ -2,8 +2,8 @@ package cloudflare | ||||
|  | ||||
| import ( | ||||
| 	"bufio" | ||||
| 	"bytes" | ||||
| 	"encoding/json" | ||||
| 	"github.com/songquanpeng/one-api/common/render" | ||||
| 	"io" | ||||
| 	"net/http" | ||||
| 	"strings" | ||||
| @@ -17,21 +17,20 @@ import ( | ||||
| ) | ||||
|  | ||||
| func ConvertRequest(textRequest model.GeneralOpenAIRequest) *Request { | ||||
|     var promptBuilder strings.Builder | ||||
|     for _, message := range textRequest.Messages { | ||||
|         promptBuilder.WriteString(message.StringContent()) | ||||
|         promptBuilder.WriteString("\n")  // 添加换行符来分隔每个消息 | ||||
|     } | ||||
| 	var promptBuilder strings.Builder | ||||
| 	for _, message := range textRequest.Messages { | ||||
| 		promptBuilder.WriteString(message.StringContent()) | ||||
| 		promptBuilder.WriteString("\n") // 添加换行符来分隔每个消息 | ||||
| 	} | ||||
|  | ||||
|     return &Request{ | ||||
|         MaxTokens:   textRequest.MaxTokens, | ||||
|         Prompt:      promptBuilder.String(), | ||||
|         Stream:      textRequest.Stream, | ||||
|         Temperature: textRequest.Temperature, | ||||
|     } | ||||
| 	return &Request{ | ||||
| 		MaxTokens:   textRequest.MaxTokens, | ||||
| 		Prompt:      promptBuilder.String(), | ||||
| 		Stream:      textRequest.Stream, | ||||
| 		Temperature: textRequest.Temperature, | ||||
| 	} | ||||
| } | ||||
|  | ||||
|  | ||||
| func ResponseCloudflare2OpenAI(cloudflareResponse *Response) *openai.TextResponse { | ||||
| 	choice := openai.TextResponseChoice{ | ||||
| 		Index: 0, | ||||
| @@ -63,67 +62,54 @@ func StreamResponseCloudflare2OpenAI(cloudflareResponse *StreamResponse) *openai | ||||
|  | ||||
| func StreamHandler(c *gin.Context, resp *http.Response, promptTokens int, modelName string) (*model.ErrorWithStatusCode, *model.Usage) { | ||||
| 	scanner := bufio.NewScanner(resp.Body) | ||||
| 	scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { | ||||
| 		if atEOF && len(data) == 0 { | ||||
| 			return 0, nil, nil | ||||
| 		} | ||||
| 		if i := bytes.IndexByte(data, '\n'); i >= 0 { | ||||
| 			return i + 1, data[0:i], nil | ||||
| 		} | ||||
| 		if atEOF { | ||||
| 			return len(data), data, nil | ||||
| 		} | ||||
| 		return 0, nil, nil | ||||
| 	}) | ||||
| 	scanner.Split(bufio.ScanLines) | ||||
|  | ||||
| 	dataChan := make(chan string) | ||||
| 	stopChan := make(chan bool) | ||||
| 	go func() { | ||||
| 		for scanner.Scan() { | ||||
| 			data := scanner.Text() | ||||
| 			if len(data) < len("data: ") { | ||||
| 				continue | ||||
| 			} | ||||
| 			data = strings.TrimPrefix(data, "data: ") | ||||
| 			dataChan <- data | ||||
| 		} | ||||
| 		stopChan <- true | ||||
| 	}() | ||||
| 	common.SetEventStreamHeaders(c) | ||||
| 	id := helper.GetResponseID(c) | ||||
| 	responseModel := c.GetString("original_model") | ||||
| 	var responseText string | ||||
| 	c.Stream(func(w io.Writer) bool { | ||||
| 		select { | ||||
| 		case data := <-dataChan: | ||||
| 			// some implementations may add \r at the end of data | ||||
| 			data = strings.TrimSuffix(data, "\r") | ||||
| 			var cloudflareResponse StreamResponse | ||||
| 			err := json.Unmarshal([]byte(data), &cloudflareResponse) | ||||
| 			if err != nil { | ||||
| 				logger.SysError("error unmarshalling stream response: " + err.Error()) | ||||
| 				return true | ||||
| 			} | ||||
| 			response := StreamResponseCloudflare2OpenAI(&cloudflareResponse) | ||||
| 			if response == nil { | ||||
| 				return true | ||||
| 			} | ||||
| 			responseText += cloudflareResponse.Response | ||||
| 			response.Id = id | ||||
| 			response.Model = responseModel | ||||
| 			jsonStr, err := json.Marshal(response) | ||||
| 			if err != nil { | ||||
| 				logger.SysError("error marshalling stream response: " + err.Error()) | ||||
| 				return true | ||||
| 			} | ||||
| 			c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonStr)}) | ||||
| 			return true | ||||
| 		case <-stopChan: | ||||
| 			c.Render(-1, common.CustomEvent{Data: "data: [DONE]"}) | ||||
| 			return false | ||||
|  | ||||
| 	for scanner.Scan() { | ||||
| 		data := scanner.Text() | ||||
| 		if len(data) < len("data: ") { | ||||
| 			continue | ||||
| 		} | ||||
| 	}) | ||||
| 	_ = resp.Body.Close() | ||||
| 		data = strings.TrimPrefix(data, "data: ") | ||||
| 		data = strings.TrimSuffix(data, "\r") | ||||
|  | ||||
| 		var cloudflareResponse StreamResponse | ||||
| 		err := json.Unmarshal([]byte(data), &cloudflareResponse) | ||||
| 		if err != nil { | ||||
| 			logger.SysError("error unmarshalling stream response: " + err.Error()) | ||||
| 			continue | ||||
| 		} | ||||
|  | ||||
| 		response := StreamResponseCloudflare2OpenAI(&cloudflareResponse) | ||||
| 		if response == nil { | ||||
| 			continue | ||||
| 		} | ||||
|  | ||||
| 		responseText += cloudflareResponse.Response | ||||
| 		response.Id = id | ||||
| 		response.Model = responseModel | ||||
|  | ||||
| 		err = render.ObjectData(c, response) | ||||
| 		if err != nil { | ||||
| 			logger.SysError(err.Error()) | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	if err := scanner.Err(); err != nil { | ||||
| 		logger.SysError("error reading stream: " + err.Error()) | ||||
| 	} | ||||
|  | ||||
| 	render.Done(c) | ||||
|  | ||||
| 	err := resp.Body.Close() | ||||
| 	if err != nil { | ||||
| 		return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil | ||||
| 	} | ||||
|  | ||||
| 	usage := openai.ResponseText2Usage(responseText, responseModel, promptTokens) | ||||
| 	return nil, usage | ||||
| } | ||||
|   | ||||
| @@ -2,9 +2,9 @@ package cohere | ||||
|  | ||||
| import ( | ||||
| 	"bufio" | ||||
| 	"bytes" | ||||
| 	"encoding/json" | ||||
| 	"fmt" | ||||
| 	"github.com/songquanpeng/one-api/common/render" | ||||
| 	"io" | ||||
| 	"net/http" | ||||
| 	"strings" | ||||
| @@ -134,66 +134,53 @@ func ResponseCohere2OpenAI(cohereResponse *Response) *openai.TextResponse { | ||||
| func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) { | ||||
| 	createdTime := helper.GetTimestamp() | ||||
| 	scanner := bufio.NewScanner(resp.Body) | ||||
| 	scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { | ||||
| 		if atEOF && len(data) == 0 { | ||||
| 			return 0, nil, nil | ||||
| 		} | ||||
| 		if i := bytes.IndexByte(data, '\n'); i >= 0 { | ||||
| 			return i + 1, data[0:i], nil | ||||
| 		} | ||||
| 		if atEOF { | ||||
| 			return len(data), data, nil | ||||
| 		} | ||||
| 		return 0, nil, nil | ||||
| 	}) | ||||
| 	scanner.Split(bufio.ScanLines) | ||||
|  | ||||
| 	dataChan := make(chan string) | ||||
| 	stopChan := make(chan bool) | ||||
| 	go func() { | ||||
| 		for scanner.Scan() { | ||||
| 			data := scanner.Text() | ||||
| 			dataChan <- data | ||||
| 		} | ||||
| 		stopChan <- true | ||||
| 	}() | ||||
| 	common.SetEventStreamHeaders(c) | ||||
| 	var usage model.Usage | ||||
| 	c.Stream(func(w io.Writer) bool { | ||||
| 		select { | ||||
| 		case data := <-dataChan: | ||||
| 			// some implementations may add \r at the end of data | ||||
| 			data = strings.TrimSuffix(data, "\r") | ||||
| 			var cohereResponse StreamResponse | ||||
| 			err := json.Unmarshal([]byte(data), &cohereResponse) | ||||
| 			if err != nil { | ||||
| 				logger.SysError("error unmarshalling stream response: " + err.Error()) | ||||
| 				return true | ||||
| 			} | ||||
| 			response, meta := StreamResponseCohere2OpenAI(&cohereResponse) | ||||
| 			if meta != nil { | ||||
| 				usage.PromptTokens += meta.Meta.Tokens.InputTokens | ||||
| 				usage.CompletionTokens += meta.Meta.Tokens.OutputTokens | ||||
| 				return true | ||||
| 			} | ||||
| 			if response == nil { | ||||
| 				return true | ||||
| 			} | ||||
| 			response.Id = fmt.Sprintf("chatcmpl-%d", createdTime) | ||||
| 			response.Model = c.GetString("original_model") | ||||
| 			response.Created = createdTime | ||||
| 			jsonStr, err := json.Marshal(response) | ||||
| 			if err != nil { | ||||
| 				logger.SysError("error marshalling stream response: " + err.Error()) | ||||
| 				return true | ||||
| 			} | ||||
| 			c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonStr)}) | ||||
| 			return true | ||||
| 		case <-stopChan: | ||||
| 			c.Render(-1, common.CustomEvent{Data: "data: [DONE]"}) | ||||
| 			return false | ||||
|  | ||||
| 	for scanner.Scan() { | ||||
| 		data := scanner.Text() | ||||
| 		data = strings.TrimSuffix(data, "\r") | ||||
|  | ||||
| 		var cohereResponse StreamResponse | ||||
| 		err := json.Unmarshal([]byte(data), &cohereResponse) | ||||
| 		if err != nil { | ||||
| 			logger.SysError("error unmarshalling stream response: " + err.Error()) | ||||
| 			continue | ||||
| 		} | ||||
| 	}) | ||||
| 	_ = resp.Body.Close() | ||||
|  | ||||
| 		response, meta := StreamResponseCohere2OpenAI(&cohereResponse) | ||||
| 		if meta != nil { | ||||
| 			usage.PromptTokens += meta.Meta.Tokens.InputTokens | ||||
| 			usage.CompletionTokens += meta.Meta.Tokens.OutputTokens | ||||
| 			continue | ||||
| 		} | ||||
| 		if response == nil { | ||||
| 			continue | ||||
| 		} | ||||
|  | ||||
| 		response.Id = fmt.Sprintf("chatcmpl-%d", createdTime) | ||||
| 		response.Model = c.GetString("original_model") | ||||
| 		response.Created = createdTime | ||||
|  | ||||
| 		err = render.ObjectData(c, response) | ||||
| 		if err != nil { | ||||
| 			logger.SysError(err.Error()) | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	if err := scanner.Err(); err != nil { | ||||
| 		logger.SysError("error reading stream: " + err.Error()) | ||||
| 	} | ||||
|  | ||||
| 	render.Done(c) | ||||
|  | ||||
| 	err := resp.Body.Close() | ||||
| 	if err != nil { | ||||
| 		return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil | ||||
| 	} | ||||
|  | ||||
| 	return nil, &usage | ||||
| } | ||||
|  | ||||
|   | ||||
| @@ -4,6 +4,11 @@ import ( | ||||
| 	"bufio" | ||||
| 	"encoding/json" | ||||
| 	"fmt" | ||||
| 	"github.com/songquanpeng/one-api/common/render" | ||||
| 	"io" | ||||
| 	"net/http" | ||||
| 	"strings" | ||||
|  | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"github.com/songquanpeng/one-api/common" | ||||
| 	"github.com/songquanpeng/one-api/common/conv" | ||||
| @@ -12,9 +17,6 @@ import ( | ||||
| 	"github.com/songquanpeng/one-api/relay/adaptor/coze/constant/messagetype" | ||||
| 	"github.com/songquanpeng/one-api/relay/adaptor/openai" | ||||
| 	"github.com/songquanpeng/one-api/relay/model" | ||||
| 	"io" | ||||
| 	"net/http" | ||||
| 	"strings" | ||||
| ) | ||||
|  | ||||
| // https://www.coze.com/open | ||||
| @@ -109,69 +111,54 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC | ||||
| 	var responseText string | ||||
| 	createdTime := helper.GetTimestamp() | ||||
| 	scanner := bufio.NewScanner(resp.Body) | ||||
| 	scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { | ||||
| 		if atEOF && len(data) == 0 { | ||||
| 			return 0, nil, nil | ||||
| 		} | ||||
| 		if i := strings.Index(string(data), "\n"); i >= 0 { | ||||
| 			return i + 1, data[0:i], nil | ||||
| 		} | ||||
| 		if atEOF { | ||||
| 			return len(data), data, nil | ||||
| 		} | ||||
| 		return 0, nil, nil | ||||
| 	}) | ||||
| 	dataChan := make(chan string) | ||||
| 	stopChan := make(chan bool) | ||||
| 	go func() { | ||||
| 		for scanner.Scan() { | ||||
| 			data := scanner.Text() | ||||
| 			if len(data) < 5 { | ||||
| 				continue | ||||
| 			} | ||||
| 			if !strings.HasPrefix(data, "data:") { | ||||
| 				continue | ||||
| 			} | ||||
| 			data = strings.TrimPrefix(data, "data:") | ||||
| 			dataChan <- data | ||||
| 		} | ||||
| 		stopChan <- true | ||||
| 	}() | ||||
| 	scanner.Split(bufio.ScanLines) | ||||
|  | ||||
| 	common.SetEventStreamHeaders(c) | ||||
| 	var modelName string | ||||
| 	c.Stream(func(w io.Writer) bool { | ||||
| 		select { | ||||
| 		case data := <-dataChan: | ||||
| 			// some implementations may add \r at the end of data | ||||
| 			data = strings.TrimSuffix(data, "\r") | ||||
| 			var cozeResponse StreamResponse | ||||
| 			err := json.Unmarshal([]byte(data), &cozeResponse) | ||||
| 			if err != nil { | ||||
| 				logger.SysError("error unmarshalling stream response: " + err.Error()) | ||||
| 				return true | ||||
| 			} | ||||
| 			response, _ := StreamResponseCoze2OpenAI(&cozeResponse) | ||||
| 			if response == nil { | ||||
| 				return true | ||||
| 			} | ||||
| 			for _, choice := range response.Choices { | ||||
| 				responseText += conv.AsString(choice.Delta.Content) | ||||
| 			} | ||||
| 			response.Model = modelName | ||||
| 			response.Created = createdTime | ||||
| 			jsonStr, err := json.Marshal(response) | ||||
| 			if err != nil { | ||||
| 				logger.SysError("error marshalling stream response: " + err.Error()) | ||||
| 				return true | ||||
| 			} | ||||
| 			c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonStr)}) | ||||
| 			return true | ||||
| 		case <-stopChan: | ||||
| 			c.Render(-1, common.CustomEvent{Data: "data: [DONE]"}) | ||||
| 			return false | ||||
|  | ||||
| 	for scanner.Scan() { | ||||
| 		data := scanner.Text() | ||||
| 		if len(data) < 5 || !strings.HasPrefix(data, "data:") { | ||||
| 			continue | ||||
| 		} | ||||
| 	}) | ||||
| 	_ = resp.Body.Close() | ||||
| 		data = strings.TrimPrefix(data, "data:") | ||||
| 		data = strings.TrimSuffix(data, "\r") | ||||
|  | ||||
| 		var cozeResponse StreamResponse | ||||
| 		err := json.Unmarshal([]byte(data), &cozeResponse) | ||||
| 		if err != nil { | ||||
| 			logger.SysError("error unmarshalling stream response: " + err.Error()) | ||||
| 			continue | ||||
| 		} | ||||
|  | ||||
| 		response, _ := StreamResponseCoze2OpenAI(&cozeResponse) | ||||
| 		if response == nil { | ||||
| 			continue | ||||
| 		} | ||||
|  | ||||
| 		for _, choice := range response.Choices { | ||||
| 			responseText += conv.AsString(choice.Delta.Content) | ||||
| 		} | ||||
| 		response.Model = modelName | ||||
| 		response.Created = createdTime | ||||
|  | ||||
| 		err = render.ObjectData(c, response) | ||||
| 		if err != nil { | ||||
| 			logger.SysError(err.Error()) | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	if err := scanner.Err(); err != nil { | ||||
| 		logger.SysError("error reading stream: " + err.Error()) | ||||
| 	} | ||||
|  | ||||
| 	render.Done(c) | ||||
|  | ||||
| 	err := resp.Body.Close() | ||||
| 	if err != nil { | ||||
| 		return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil | ||||
| 	} | ||||
|  | ||||
| 	return nil, &responseText | ||||
| } | ||||
|  | ||||
|   | ||||
| @@ -4,6 +4,7 @@ import ( | ||||
| 	"bufio" | ||||
| 	"encoding/json" | ||||
| 	"fmt" | ||||
| 	"github.com/songquanpeng/one-api/common/render" | ||||
| 	"io" | ||||
| 	"net/http" | ||||
| 	"strings" | ||||
| @@ -245,8 +246,10 @@ func responseGeminiChat2OpenAI(response *ChatResponse) *openai.TextResponse { | ||||
| func streamResponseGeminiChat2OpenAI(geminiResponse *ChatResponse) *openai.ChatCompletionsStreamResponse { | ||||
| 	var choice openai.ChatCompletionsStreamResponseChoice | ||||
| 	choice.Delta.Content = geminiResponse.GetResponseText() | ||||
| 	choice.FinishReason = &constant.StopFinishReason | ||||
| 	//choice.FinishReason = &constant.StopFinishReason | ||||
| 	var response openai.ChatCompletionsStreamResponse | ||||
| 	response.Id = fmt.Sprintf("chatcmpl-%s", random.GetUUID()) | ||||
| 	response.Created = helper.GetTimestamp() | ||||
| 	response.Object = "chat.completion.chunk" | ||||
| 	response.Model = "gemini" | ||||
| 	response.Choices = []openai.ChatCompletionsStreamResponseChoice{choice} | ||||
| @@ -273,64 +276,50 @@ func embeddingResponseGemini2OpenAI(response *EmbeddingResponse) *openai.Embeddi | ||||
| func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, string) { | ||||
| 	responseText := "" | ||||
| 	scanner := bufio.NewScanner(resp.Body) | ||||
| 	scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { | ||||
| 		if atEOF && len(data) == 0 { | ||||
| 			return 0, nil, nil | ||||
| 		} | ||||
| 		if i := strings.Index(string(data), "\n"); i >= 0 { | ||||
| 			return i + 1, data[0:i], nil | ||||
| 		} | ||||
| 		if atEOF { | ||||
| 			return len(data), data, nil | ||||
| 		} | ||||
| 		return 0, nil, nil | ||||
| 	}) | ||||
| 	dataChan := make(chan string) | ||||
| 	stopChan := make(chan bool) | ||||
| 	go func() { | ||||
| 		for scanner.Scan() { | ||||
| 			data := scanner.Text() | ||||
| 			data = strings.TrimSpace(data) | ||||
| 			if !strings.HasPrefix(data, "data: ") { | ||||
| 				continue | ||||
| 			} | ||||
| 			data = strings.TrimPrefix(data, "data: ") | ||||
| 			data = strings.TrimSuffix(data, "\"") | ||||
| 			dataChan <- data | ||||
| 		} | ||||
| 		stopChan <- true | ||||
| 	}() | ||||
| 	scanner.Split(bufio.ScanLines) | ||||
|  | ||||
| 	common.SetEventStreamHeaders(c) | ||||
| 	c.Stream(func(w io.Writer) bool { | ||||
| 		select { | ||||
| 		case data := <-dataChan: | ||||
| 			var geminiResponse ChatResponse | ||||
| 			err := json.Unmarshal([]byte(data), &geminiResponse) | ||||
| 			if err != nil { | ||||
| 				logger.SysError("error unmarshalling stream response: " + err.Error()) | ||||
| 				return true | ||||
| 			} | ||||
| 			response := streamResponseGeminiChat2OpenAI(&geminiResponse) | ||||
| 			if response == nil { | ||||
| 				return true | ||||
| 			} | ||||
| 			responseText += response.Choices[0].Delta.StringContent() | ||||
| 			jsonResponse, err := json.Marshal(response) | ||||
| 			if err != nil { | ||||
| 				logger.SysError("error marshalling stream response: " + err.Error()) | ||||
| 				return true | ||||
| 			} | ||||
| 			c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)}) | ||||
| 			return true | ||||
| 		case <-stopChan: | ||||
| 			c.Render(-1, common.CustomEvent{Data: "data: [DONE]"}) | ||||
| 			return false | ||||
|  | ||||
| 	for scanner.Scan() { | ||||
| 		data := scanner.Text() | ||||
| 		data = strings.TrimSpace(data) | ||||
| 		if !strings.HasPrefix(data, "data: ") { | ||||
| 			continue | ||||
| 		} | ||||
| 	}) | ||||
| 		data = strings.TrimPrefix(data, "data: ") | ||||
| 		data = strings.TrimSuffix(data, "\"") | ||||
|  | ||||
| 		var geminiResponse ChatResponse | ||||
| 		err := json.Unmarshal([]byte(data), &geminiResponse) | ||||
| 		if err != nil { | ||||
| 			logger.SysError("error unmarshalling stream response: " + err.Error()) | ||||
| 			continue | ||||
| 		} | ||||
|  | ||||
| 		response := streamResponseGeminiChat2OpenAI(&geminiResponse) | ||||
| 		if response == nil { | ||||
| 			continue | ||||
| 		} | ||||
|  | ||||
| 		responseText += response.Choices[0].Delta.StringContent() | ||||
|  | ||||
| 		err = render.ObjectData(c, response) | ||||
| 		if err != nil { | ||||
| 			logger.SysError(err.Error()) | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	if err := scanner.Err(); err != nil { | ||||
| 		logger.SysError("error reading stream: " + err.Error()) | ||||
| 	} | ||||
|  | ||||
| 	render.Done(c) | ||||
|  | ||||
| 	err := resp.Body.Close() | ||||
| 	if err != nil { | ||||
| 		return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), "" | ||||
| 	} | ||||
|  | ||||
| 	return nil, responseText | ||||
| } | ||||
|  | ||||
|   | ||||
| @@ -5,12 +5,14 @@ import ( | ||||
| 	"context" | ||||
| 	"encoding/json" | ||||
| 	"fmt" | ||||
| 	"github.com/songquanpeng/one-api/common/helper" | ||||
| 	"github.com/songquanpeng/one-api/common/random" | ||||
| 	"github.com/songquanpeng/one-api/common/render" | ||||
| 	"io" | ||||
| 	"net/http" | ||||
| 	"strings" | ||||
|  | ||||
| 	"github.com/songquanpeng/one-api/common/helper" | ||||
| 	"github.com/songquanpeng/one-api/common/random" | ||||
|  | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"github.com/songquanpeng/one-api/common" | ||||
| 	"github.com/songquanpeng/one-api/common/image" | ||||
| @@ -105,54 +107,51 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC | ||||
| 			return 0, nil, nil | ||||
| 		} | ||||
| 		if i := strings.Index(string(data), "}\n"); i >= 0 { | ||||
| 			return i + 2, data[0:i], nil | ||||
| 			return i + 2, data[0 : i+1], nil | ||||
| 		} | ||||
| 		if atEOF { | ||||
| 			return len(data), data, nil | ||||
| 		} | ||||
| 		return 0, nil, nil | ||||
| 	}) | ||||
| 	dataChan := make(chan string) | ||||
| 	stopChan := make(chan bool) | ||||
| 	go func() { | ||||
| 		for scanner.Scan() { | ||||
| 			data := strings.TrimPrefix(scanner.Text(), "}") | ||||
| 			dataChan <- data + "}" | ||||
| 		} | ||||
| 		stopChan <- true | ||||
| 	}() | ||||
|  | ||||
| 	common.SetEventStreamHeaders(c) | ||||
| 	c.Stream(func(w io.Writer) bool { | ||||
| 		select { | ||||
| 		case data := <-dataChan: | ||||
| 			var ollamaResponse ChatResponse | ||||
| 			err := json.Unmarshal([]byte(data), &ollamaResponse) | ||||
| 			if err != nil { | ||||
| 				logger.SysError("error unmarshalling stream response: " + err.Error()) | ||||
| 				return true | ||||
| 			} | ||||
| 			if ollamaResponse.EvalCount != 0 { | ||||
| 				usage.PromptTokens = ollamaResponse.PromptEvalCount | ||||
| 				usage.CompletionTokens = ollamaResponse.EvalCount | ||||
| 				usage.TotalTokens = ollamaResponse.PromptEvalCount + ollamaResponse.EvalCount | ||||
| 			} | ||||
| 			response := streamResponseOllama2OpenAI(&ollamaResponse) | ||||
| 			jsonResponse, err := json.Marshal(response) | ||||
| 			if err != nil { | ||||
| 				logger.SysError("error marshalling stream response: " + err.Error()) | ||||
| 				return true | ||||
| 			} | ||||
| 			c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)}) | ||||
| 			return true | ||||
| 		case <-stopChan: | ||||
| 			c.Render(-1, common.CustomEvent{Data: "data: [DONE]"}) | ||||
| 			return false | ||||
|  | ||||
| 	for scanner.Scan() { | ||||
| 		data := strings.TrimPrefix(scanner.Text(), "}") | ||||
| 		data = data + "}" | ||||
|  | ||||
| 		var ollamaResponse ChatResponse | ||||
| 		err := json.Unmarshal([]byte(data), &ollamaResponse) | ||||
| 		if err != nil { | ||||
| 			logger.SysError("error unmarshalling stream response: " + err.Error()) | ||||
| 			continue | ||||
| 		} | ||||
| 	}) | ||||
|  | ||||
| 		if ollamaResponse.EvalCount != 0 { | ||||
| 			usage.PromptTokens = ollamaResponse.PromptEvalCount | ||||
| 			usage.CompletionTokens = ollamaResponse.EvalCount | ||||
| 			usage.TotalTokens = ollamaResponse.PromptEvalCount + ollamaResponse.EvalCount | ||||
| 		} | ||||
|  | ||||
| 		response := streamResponseOllama2OpenAI(&ollamaResponse) | ||||
| 		err = render.ObjectData(c, response) | ||||
| 		if err != nil { | ||||
| 			logger.SysError(err.Error()) | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	if err := scanner.Err(); err != nil { | ||||
| 		logger.SysError("error reading stream: " + err.Error()) | ||||
| 	} | ||||
|  | ||||
| 	render.Done(c) | ||||
|  | ||||
| 	err := resp.Body.Close() | ||||
| 	if err != nil { | ||||
| 		return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil | ||||
| 	} | ||||
|  | ||||
| 	return nil, &usage | ||||
| } | ||||
|  | ||||
|   | ||||
| @@ -4,15 +4,17 @@ import ( | ||||
| 	"bufio" | ||||
| 	"bytes" | ||||
| 	"encoding/json" | ||||
| 	"github.com/songquanpeng/one-api/common/render" | ||||
| 	"io" | ||||
| 	"net/http" | ||||
| 	"strings" | ||||
|  | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"github.com/songquanpeng/one-api/common" | ||||
| 	"github.com/songquanpeng/one-api/common/conv" | ||||
| 	"github.com/songquanpeng/one-api/common/logger" | ||||
| 	"github.com/songquanpeng/one-api/relay/model" | ||||
| 	"github.com/songquanpeng/one-api/relay/relaymode" | ||||
| 	"io" | ||||
| 	"net/http" | ||||
| 	"strings" | ||||
| ) | ||||
|  | ||||
| const ( | ||||
| @@ -24,88 +26,68 @@ const ( | ||||
| func StreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*model.ErrorWithStatusCode, string, *model.Usage) { | ||||
| 	responseText := "" | ||||
| 	scanner := bufio.NewScanner(resp.Body) | ||||
| 	scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { | ||||
| 		if atEOF && len(data) == 0 { | ||||
| 			return 0, nil, nil | ||||
| 		} | ||||
| 		if i := strings.Index(string(data), "\n"); i >= 0 { | ||||
| 			return i + 1, data[0:i], nil | ||||
| 		} | ||||
| 		if atEOF { | ||||
| 			return len(data), data, nil | ||||
| 		} | ||||
| 		return 0, nil, nil | ||||
| 	}) | ||||
| 	dataChan := make(chan string) | ||||
| 	stopChan := make(chan bool) | ||||
| 	scanner.Split(bufio.ScanLines) | ||||
| 	var usage *model.Usage | ||||
| 	go func() { | ||||
| 		for scanner.Scan() { | ||||
| 			data := scanner.Text() | ||||
| 			if len(data) < dataPrefixLength { // ignore blank line or wrong format | ||||
| 				continue | ||||
| 			} | ||||
| 			if data[:dataPrefixLength] != dataPrefix && data[:dataPrefixLength] != done { | ||||
| 				continue | ||||
| 			} | ||||
| 			if strings.HasPrefix(data[dataPrefixLength:], done) { | ||||
| 				dataChan <- data | ||||
| 				continue | ||||
| 			} | ||||
| 			switch relayMode { | ||||
| 			case relaymode.ChatCompletions: | ||||
| 				var streamResponse ChatCompletionsStreamResponse | ||||
| 				err := json.Unmarshal([]byte(data[dataPrefixLength:]), &streamResponse) | ||||
| 				if err != nil { | ||||
| 					logger.SysError("error unmarshalling stream response: " + err.Error()) | ||||
| 					dataChan <- data // if error happened, pass the data to client | ||||
| 					continue         // just ignore the error | ||||
| 				} | ||||
| 				if len(streamResponse.Choices) == 0 { | ||||
| 					// but for empty choice, we should not pass it to client, this is for azure | ||||
| 					continue // just ignore empty choice | ||||
| 				} | ||||
| 				dataChan <- data | ||||
| 				for _, choice := range streamResponse.Choices { | ||||
| 					responseText += conv.AsString(choice.Delta.Content) | ||||
| 				} | ||||
| 				if streamResponse.Usage != nil { | ||||
| 					usage = streamResponse.Usage | ||||
| 				} | ||||
| 			case relaymode.Completions: | ||||
| 				dataChan <- data | ||||
| 				var streamResponse CompletionsStreamResponse | ||||
| 				err := json.Unmarshal([]byte(data[dataPrefixLength:]), &streamResponse) | ||||
| 				if err != nil { | ||||
| 					logger.SysError("error unmarshalling stream response: " + err.Error()) | ||||
| 					continue | ||||
| 				} | ||||
| 				for _, choice := range streamResponse.Choices { | ||||
| 					responseText += choice.Text | ||||
| 				} | ||||
| 			} | ||||
| 		} | ||||
| 		stopChan <- true | ||||
| 	}() | ||||
|  | ||||
| 	common.SetEventStreamHeaders(c) | ||||
| 	c.Stream(func(w io.Writer) bool { | ||||
| 		select { | ||||
| 		case data := <-dataChan: | ||||
| 			if strings.HasPrefix(data, "data: [DONE]") { | ||||
| 				data = data[:12] | ||||
| 			} | ||||
| 			// some implementations may add \r at the end of data | ||||
| 			data = strings.TrimSuffix(data, "\r") | ||||
| 			c.Render(-1, common.CustomEvent{Data: data}) | ||||
| 			return true | ||||
| 		case <-stopChan: | ||||
| 			return false | ||||
|  | ||||
| 	for scanner.Scan() { | ||||
| 		data := scanner.Text() | ||||
| 		if len(data) < dataPrefixLength { // ignore blank line or wrong format | ||||
| 			continue | ||||
| 		} | ||||
| 	}) | ||||
| 		if data[:dataPrefixLength] != dataPrefix && data[:dataPrefixLength] != done { | ||||
| 			continue | ||||
| 		} | ||||
| 		if strings.HasPrefix(data[dataPrefixLength:], done) { | ||||
| 			render.StringData(c, data) | ||||
| 			continue | ||||
| 		} | ||||
| 		switch relayMode { | ||||
| 		case relaymode.ChatCompletions: | ||||
| 			var streamResponse ChatCompletionsStreamResponse | ||||
| 			err := json.Unmarshal([]byte(data[dataPrefixLength:]), &streamResponse) | ||||
| 			if err != nil { | ||||
| 				logger.SysError("error unmarshalling stream response: " + err.Error()) | ||||
| 				render.StringData(c, data) // if error happened, pass the data to client | ||||
| 				continue                   // just ignore the error | ||||
| 			} | ||||
| 			if len(streamResponse.Choices) == 0 { | ||||
| 				// but for empty choice, we should not pass it to client, this is for azure | ||||
| 				continue // just ignore empty choice | ||||
| 			} | ||||
| 			render.StringData(c, data) | ||||
| 			for _, choice := range streamResponse.Choices { | ||||
| 				responseText += conv.AsString(choice.Delta.Content) | ||||
| 			} | ||||
| 			if streamResponse.Usage != nil { | ||||
| 				usage = streamResponse.Usage | ||||
| 			} | ||||
| 		case relaymode.Completions: | ||||
| 			render.StringData(c, data) | ||||
| 			var streamResponse CompletionsStreamResponse | ||||
| 			err := json.Unmarshal([]byte(data[dataPrefixLength:]), &streamResponse) | ||||
| 			if err != nil { | ||||
| 				logger.SysError("error unmarshalling stream response: " + err.Error()) | ||||
| 				continue | ||||
| 			} | ||||
| 			for _, choice := range streamResponse.Choices { | ||||
| 				responseText += choice.Text | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	if err := scanner.Err(); err != nil { | ||||
| 		logger.SysError("error reading stream: " + err.Error()) | ||||
| 	} | ||||
|  | ||||
| 	render.Done(c) | ||||
|  | ||||
| 	err := resp.Body.Close() | ||||
| 	if err != nil { | ||||
| 		return ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), "", nil | ||||
| 	} | ||||
|  | ||||
| 	return nil, responseText, usage | ||||
| } | ||||
|  | ||||
| @@ -149,7 +131,7 @@ func Handler(c *gin.Context, resp *http.Response, promptTokens int, modelName st | ||||
| 		return ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil | ||||
| 	} | ||||
|  | ||||
| 	if textResponse.Usage.TotalTokens == 0 { | ||||
| 	if textResponse.Usage.TotalTokens == 0 || (textResponse.Usage.PromptTokens == 0 && textResponse.Usage.CompletionTokens == 0) { | ||||
| 		completionTokens := 0 | ||||
| 		for _, choice := range textResponse.Choices { | ||||
| 			completionTokens += CountTokenText(choice.Message.StringContent(), modelName) | ||||
|   | ||||
| @@ -3,6 +3,10 @@ package palm | ||||
| import ( | ||||
| 	"encoding/json" | ||||
| 	"fmt" | ||||
| 	"github.com/songquanpeng/one-api/common/render" | ||||
| 	"io" | ||||
| 	"net/http" | ||||
|  | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"github.com/songquanpeng/one-api/common" | ||||
| 	"github.com/songquanpeng/one-api/common/helper" | ||||
| @@ -11,8 +15,6 @@ import ( | ||||
| 	"github.com/songquanpeng/one-api/relay/adaptor/openai" | ||||
| 	"github.com/songquanpeng/one-api/relay/constant" | ||||
| 	"github.com/songquanpeng/one-api/relay/model" | ||||
| 	"io" | ||||
| 	"net/http" | ||||
| ) | ||||
|  | ||||
| // https://developers.generativeai.google/api/rest/generativelanguage/models/generateMessage#request-body | ||||
| @@ -77,58 +79,51 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC | ||||
| 	responseText := "" | ||||
| 	responseId := fmt.Sprintf("chatcmpl-%s", random.GetUUID()) | ||||
| 	createdTime := helper.GetTimestamp() | ||||
| 	dataChan := make(chan string) | ||||
| 	stopChan := make(chan bool) | ||||
| 	go func() { | ||||
| 		responseBody, err := io.ReadAll(resp.Body) | ||||
| 		if err != nil { | ||||
| 			logger.SysError("error reading stream response: " + err.Error()) | ||||
| 			stopChan <- true | ||||
| 			return | ||||
| 		} | ||||
| 		err = resp.Body.Close() | ||||
| 		if err != nil { | ||||
| 			logger.SysError("error closing stream response: " + err.Error()) | ||||
| 			stopChan <- true | ||||
| 			return | ||||
| 		} | ||||
| 		var palmResponse ChatResponse | ||||
| 		err = json.Unmarshal(responseBody, &palmResponse) | ||||
| 		if err != nil { | ||||
| 			logger.SysError("error unmarshalling stream response: " + err.Error()) | ||||
| 			stopChan <- true | ||||
| 			return | ||||
| 		} | ||||
| 		fullTextResponse := streamResponsePaLM2OpenAI(&palmResponse) | ||||
| 		fullTextResponse.Id = responseId | ||||
| 		fullTextResponse.Created = createdTime | ||||
| 		if len(palmResponse.Candidates) > 0 { | ||||
| 			responseText = palmResponse.Candidates[0].Content | ||||
| 		} | ||||
| 		jsonResponse, err := json.Marshal(fullTextResponse) | ||||
| 		if err != nil { | ||||
| 			logger.SysError("error marshalling stream response: " + err.Error()) | ||||
| 			stopChan <- true | ||||
| 			return | ||||
| 		} | ||||
| 		dataChan <- string(jsonResponse) | ||||
| 		stopChan <- true | ||||
| 	}() | ||||
|  | ||||
| 	common.SetEventStreamHeaders(c) | ||||
| 	c.Stream(func(w io.Writer) bool { | ||||
| 		select { | ||||
| 		case data := <-dataChan: | ||||
| 			c.Render(-1, common.CustomEvent{Data: "data: " + data}) | ||||
| 			return true | ||||
| 		case <-stopChan: | ||||
| 			c.Render(-1, common.CustomEvent{Data: "data: [DONE]"}) | ||||
| 			return false | ||||
|  | ||||
| 	responseBody, err := io.ReadAll(resp.Body) | ||||
| 	if err != nil { | ||||
| 		logger.SysError("error reading stream response: " + err.Error()) | ||||
| 		err := resp.Body.Close() | ||||
| 		if err != nil { | ||||
| 			return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), "" | ||||
| 		} | ||||
| 	}) | ||||
| 	err := resp.Body.Close() | ||||
| 		return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), "" | ||||
| 	} | ||||
|  | ||||
| 	err = resp.Body.Close() | ||||
| 	if err != nil { | ||||
| 		return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), "" | ||||
| 	} | ||||
|  | ||||
| 	var palmResponse ChatResponse | ||||
| 	err = json.Unmarshal(responseBody, &palmResponse) | ||||
| 	if err != nil { | ||||
| 		logger.SysError("error unmarshalling stream response: " + err.Error()) | ||||
| 		return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), "" | ||||
| 	} | ||||
|  | ||||
| 	fullTextResponse := streamResponsePaLM2OpenAI(&palmResponse) | ||||
| 	fullTextResponse.Id = responseId | ||||
| 	fullTextResponse.Created = createdTime | ||||
| 	if len(palmResponse.Candidates) > 0 { | ||||
| 		responseText = palmResponse.Candidates[0].Content | ||||
| 	} | ||||
|  | ||||
| 	jsonResponse, err := json.Marshal(fullTextResponse) | ||||
| 	if err != nil { | ||||
| 		logger.SysError("error marshalling stream response: " + err.Error()) | ||||
| 		return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), "" | ||||
| 	} | ||||
|  | ||||
| 	err = render.ObjectData(c, string(jsonResponse)) | ||||
| 	if err != nil { | ||||
| 		logger.SysError(err.Error()) | ||||
| 	} | ||||
|  | ||||
| 	render.Done(c) | ||||
|  | ||||
| 	return nil, responseText | ||||
| } | ||||
|  | ||||
|   | ||||
| @@ -2,35 +2,43 @@ package tencent | ||||
|  | ||||
| import ( | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"github.com/songquanpeng/one-api/common/helper" | ||||
| 	"github.com/songquanpeng/one-api/relay/adaptor" | ||||
| 	"github.com/songquanpeng/one-api/relay/adaptor/openai" | ||||
| 	"github.com/songquanpeng/one-api/relay/meta" | ||||
| 	"github.com/songquanpeng/one-api/relay/model" | ||||
| 	"io" | ||||
| 	"net/http" | ||||
| 	"strconv" | ||||
| 	"strings" | ||||
| ) | ||||
|  | ||||
| // https://cloud.tencent.com/document/api/1729/101837 | ||||
|  | ||||
| type Adaptor struct { | ||||
| 	Sign string | ||||
| 	Sign      string | ||||
| 	Action    string | ||||
| 	Version   string | ||||
| 	Timestamp int64 | ||||
| } | ||||
|  | ||||
| func (a *Adaptor) Init(meta *meta.Meta) { | ||||
|  | ||||
| 	a.Action = "ChatCompletions" | ||||
| 	a.Version = "2023-09-01" | ||||
| 	a.Timestamp = helper.GetTimestamp() | ||||
| } | ||||
|  | ||||
| func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) { | ||||
| 	return fmt.Sprintf("%s/hyllm/v1/chat/completions", meta.BaseURL), nil | ||||
| 	return meta.BaseURL + "/", nil | ||||
| } | ||||
|  | ||||
| func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error { | ||||
| 	adaptor.SetupCommonRequestHeader(c, req, meta) | ||||
| 	req.Header.Set("Authorization", a.Sign) | ||||
| 	req.Header.Set("X-TC-Action", meta.ActualModelName) | ||||
| 	req.Header.Set("X-TC-Action", a.Action) | ||||
| 	req.Header.Set("X-TC-Version", a.Version) | ||||
| 	req.Header.Set("X-TC-Timestamp", strconv.FormatInt(a.Timestamp, 10)) | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| @@ -40,15 +48,13 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G | ||||
| 	} | ||||
| 	apiKey := c.Request.Header.Get("Authorization") | ||||
| 	apiKey = strings.TrimPrefix(apiKey, "Bearer ") | ||||
| 	appId, secretId, secretKey, err := ParseConfig(apiKey) | ||||
| 	_, secretId, secretKey, err := ParseConfig(apiKey) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	tencentRequest := ConvertRequest(*request) | ||||
| 	tencentRequest.AppId = appId | ||||
| 	tencentRequest.SecretId = secretId | ||||
| 	// we have to calculate the sign here | ||||
| 	a.Sign = GetSign(*tencentRequest, secretKey) | ||||
| 	a.Sign = GetSign(*tencentRequest, a, secretId, secretKey) | ||||
| 	return tencentRequest, nil | ||||
| } | ||||
|  | ||||
|   | ||||
| @@ -1,7 +1,8 @@ | ||||
| package tencent | ||||
|  | ||||
| var ModelList = []string{ | ||||
| 	"ChatPro", | ||||
| 	"ChatStd", | ||||
| 	"hunyuan", | ||||
| 	"hunyuan-lite", | ||||
| 	"hunyuan-standard", | ||||
| 	"hunyuan-standard-256K", | ||||
| 	"hunyuan-pro", | ||||
| } | ||||
|   | ||||
| @@ -3,11 +3,18 @@ package tencent | ||||
| import ( | ||||
| 	"bufio" | ||||
| 	"crypto/hmac" | ||||
| 	"crypto/sha1" | ||||
| 	"encoding/base64" | ||||
| 	"crypto/sha256" | ||||
| 	"encoding/hex" | ||||
| 	"encoding/json" | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
| 	"github.com/songquanpeng/one-api/common/render" | ||||
| 	"io" | ||||
| 	"net/http" | ||||
| 	"strconv" | ||||
| 	"strings" | ||||
| 	"time" | ||||
|  | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"github.com/songquanpeng/one-api/common" | ||||
| 	"github.com/songquanpeng/one-api/common/conv" | ||||
| @@ -17,36 +24,23 @@ import ( | ||||
| 	"github.com/songquanpeng/one-api/relay/adaptor/openai" | ||||
| 	"github.com/songquanpeng/one-api/relay/constant" | ||||
| 	"github.com/songquanpeng/one-api/relay/model" | ||||
| 	"io" | ||||
| 	"net/http" | ||||
| 	"sort" | ||||
| 	"strconv" | ||||
| 	"strings" | ||||
| ) | ||||
|  | ||||
| // https://cloud.tencent.com/document/product/1729/97732 | ||||
|  | ||||
| func ConvertRequest(request model.GeneralOpenAIRequest) *ChatRequest { | ||||
| 	messages := make([]Message, 0, len(request.Messages)) | ||||
| 	messages := make([]*Message, 0, len(request.Messages)) | ||||
| 	for i := 0; i < len(request.Messages); i++ { | ||||
| 		message := request.Messages[i] | ||||
| 		messages = append(messages, Message{ | ||||
| 		messages = append(messages, &Message{ | ||||
| 			Content: message.StringContent(), | ||||
| 			Role:    message.Role, | ||||
| 		}) | ||||
| 	} | ||||
| 	stream := 0 | ||||
| 	if request.Stream { | ||||
| 		stream = 1 | ||||
| 	} | ||||
| 	return &ChatRequest{ | ||||
| 		Timestamp:   helper.GetTimestamp(), | ||||
| 		Expired:     helper.GetTimestamp() + 24*60*60, | ||||
| 		QueryID:     random.GetUUID(), | ||||
| 		Temperature: request.Temperature, | ||||
| 		TopP:        request.TopP, | ||||
| 		Stream:      stream, | ||||
| 		Model:       &request.Model, | ||||
| 		Stream:      &request.Stream, | ||||
| 		Messages:    messages, | ||||
| 		TopP:        &request.TopP, | ||||
| 		Temperature: &request.Temperature, | ||||
| 	} | ||||
| } | ||||
|  | ||||
| @@ -54,7 +48,11 @@ func responseTencent2OpenAI(response *ChatResponse) *openai.TextResponse { | ||||
| 	fullTextResponse := openai.TextResponse{ | ||||
| 		Object:  "chat.completion", | ||||
| 		Created: helper.GetTimestamp(), | ||||
| 		Usage:   response.Usage, | ||||
| 		Usage: model.Usage{ | ||||
| 			PromptTokens:     response.Usage.PromptTokens, | ||||
| 			CompletionTokens: response.Usage.CompletionTokens, | ||||
| 			TotalTokens:      response.Usage.TotalTokens, | ||||
| 		}, | ||||
| 	} | ||||
| 	if len(response.Choices) > 0 { | ||||
| 		choice := openai.TextResponseChoice{ | ||||
| @@ -91,69 +89,52 @@ func streamResponseTencent2OpenAI(TencentResponse *ChatResponse) *openai.ChatCom | ||||
| func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, string) { | ||||
| 	var responseText string | ||||
| 	scanner := bufio.NewScanner(resp.Body) | ||||
| 	scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { | ||||
| 		if atEOF && len(data) == 0 { | ||||
| 			return 0, nil, nil | ||||
| 		} | ||||
| 		if i := strings.Index(string(data), "\n"); i >= 0 { | ||||
| 			return i + 1, data[0:i], nil | ||||
| 		} | ||||
| 		if atEOF { | ||||
| 			return len(data), data, nil | ||||
| 		} | ||||
| 		return 0, nil, nil | ||||
| 	}) | ||||
| 	dataChan := make(chan string) | ||||
| 	stopChan := make(chan bool) | ||||
| 	go func() { | ||||
| 		for scanner.Scan() { | ||||
| 			data := scanner.Text() | ||||
| 			if len(data) < 5 { // ignore blank line or wrong format | ||||
| 				continue | ||||
| 			} | ||||
| 			if data[:5] != "data:" { | ||||
| 				continue | ||||
| 			} | ||||
| 			data = data[5:] | ||||
| 			dataChan <- data | ||||
| 		} | ||||
| 		stopChan <- true | ||||
| 	}() | ||||
| 	scanner.Split(bufio.ScanLines) | ||||
|  | ||||
| 	common.SetEventStreamHeaders(c) | ||||
| 	c.Stream(func(w io.Writer) bool { | ||||
| 		select { | ||||
| 		case data := <-dataChan: | ||||
| 			var TencentResponse ChatResponse | ||||
| 			err := json.Unmarshal([]byte(data), &TencentResponse) | ||||
| 			if err != nil { | ||||
| 				logger.SysError("error unmarshalling stream response: " + err.Error()) | ||||
| 				return true | ||||
| 			} | ||||
| 			response := streamResponseTencent2OpenAI(&TencentResponse) | ||||
| 			if len(response.Choices) != 0 { | ||||
| 				responseText += conv.AsString(response.Choices[0].Delta.Content) | ||||
| 			} | ||||
| 			jsonResponse, err := json.Marshal(response) | ||||
| 			if err != nil { | ||||
| 				logger.SysError("error marshalling stream response: " + err.Error()) | ||||
| 				return true | ||||
| 			} | ||||
| 			c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)}) | ||||
| 			return true | ||||
| 		case <-stopChan: | ||||
| 			c.Render(-1, common.CustomEvent{Data: "data: [DONE]"}) | ||||
| 			return false | ||||
|  | ||||
| 	for scanner.Scan() { | ||||
| 		data := scanner.Text() | ||||
| 		if len(data) < 5 || !strings.HasPrefix(data, "data:") { | ||||
| 			continue | ||||
| 		} | ||||
| 	}) | ||||
| 		data = strings.TrimPrefix(data, "data:") | ||||
|  | ||||
| 		var tencentResponse ChatResponse | ||||
| 		err := json.Unmarshal([]byte(data), &tencentResponse) | ||||
| 		if err != nil { | ||||
| 			logger.SysError("error unmarshalling stream response: " + err.Error()) | ||||
| 			continue | ||||
| 		} | ||||
|  | ||||
| 		response := streamResponseTencent2OpenAI(&tencentResponse) | ||||
| 		if len(response.Choices) != 0 { | ||||
| 			responseText += conv.AsString(response.Choices[0].Delta.Content) | ||||
| 		} | ||||
|  | ||||
| 		err = render.ObjectData(c, response) | ||||
| 		if err != nil { | ||||
| 			logger.SysError(err.Error()) | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	if err := scanner.Err(); err != nil { | ||||
| 		logger.SysError("error reading stream: " + err.Error()) | ||||
| 	} | ||||
|  | ||||
| 	render.Done(c) | ||||
|  | ||||
| 	err := resp.Body.Close() | ||||
| 	if err != nil { | ||||
| 		return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), "" | ||||
| 	} | ||||
|  | ||||
| 	return nil, responseText | ||||
| } | ||||
|  | ||||
| func Handler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) { | ||||
| 	var TencentResponse ChatResponse | ||||
| 	var responseP ChatResponseP | ||||
| 	responseBody, err := io.ReadAll(resp.Body) | ||||
| 	if err != nil { | ||||
| 		return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil | ||||
| @@ -162,10 +143,11 @@ func Handler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, * | ||||
| 	if err != nil { | ||||
| 		return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil | ||||
| 	} | ||||
| 	err = json.Unmarshal(responseBody, &TencentResponse) | ||||
| 	err = json.Unmarshal(responseBody, &responseP) | ||||
| 	if err != nil { | ||||
| 		return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil | ||||
| 	} | ||||
| 	TencentResponse = responseP.Response | ||||
| 	if TencentResponse.Error.Code != 0 { | ||||
| 		return &model.ErrorWithStatusCode{ | ||||
| 			Error: model.Error{ | ||||
| @@ -202,29 +184,62 @@ func ParseConfig(config string) (appId int64, secretId string, secretKey string, | ||||
| 	return | ||||
| } | ||||
|  | ||||
| func GetSign(req ChatRequest, secretKey string) string { | ||||
| 	params := make([]string, 0) | ||||
| 	params = append(params, "app_id="+strconv.FormatInt(req.AppId, 10)) | ||||
| 	params = append(params, "secret_id="+req.SecretId) | ||||
| 	params = append(params, "timestamp="+strconv.FormatInt(req.Timestamp, 10)) | ||||
| 	params = append(params, "query_id="+req.QueryID) | ||||
| 	params = append(params, "temperature="+strconv.FormatFloat(req.Temperature, 'f', -1, 64)) | ||||
| 	params = append(params, "top_p="+strconv.FormatFloat(req.TopP, 'f', -1, 64)) | ||||
| 	params = append(params, "stream="+strconv.Itoa(req.Stream)) | ||||
| 	params = append(params, "expired="+strconv.FormatInt(req.Expired, 10)) | ||||
|  | ||||
| 	var messageStr string | ||||
| 	for _, msg := range req.Messages { | ||||
| 		messageStr += fmt.Sprintf(`{"role":"%s","content":"%s"},`, msg.Role, msg.Content) | ||||
| 	} | ||||
| 	messageStr = strings.TrimSuffix(messageStr, ",") | ||||
| 	params = append(params, "messages=["+messageStr+"]") | ||||
|  | ||||
| 	sort.Strings(params) | ||||
| 	url := "hunyuan.cloud.tencent.com/hyllm/v1/chat/completions?" + strings.Join(params, "&") | ||||
| 	mac := hmac.New(sha1.New, []byte(secretKey)) | ||||
| 	signURL := url | ||||
| 	mac.Write([]byte(signURL)) | ||||
| 	sign := mac.Sum([]byte(nil)) | ||||
| 	return base64.StdEncoding.EncodeToString(sign) | ||||
| func sha256hex(s string) string { | ||||
| 	b := sha256.Sum256([]byte(s)) | ||||
| 	return hex.EncodeToString(b[:]) | ||||
| } | ||||
|  | ||||
| func hmacSha256(s, key string) string { | ||||
| 	hashed := hmac.New(sha256.New, []byte(key)) | ||||
| 	hashed.Write([]byte(s)) | ||||
| 	return string(hashed.Sum(nil)) | ||||
| } | ||||
|  | ||||
| func GetSign(req ChatRequest, adaptor *Adaptor, secId, secKey string) string { | ||||
| 	// build canonical request string | ||||
| 	host := "hunyuan.tencentcloudapi.com" | ||||
| 	httpRequestMethod := "POST" | ||||
| 	canonicalURI := "/" | ||||
| 	canonicalQueryString := "" | ||||
| 	canonicalHeaders := fmt.Sprintf("content-type:%s\nhost:%s\nx-tc-action:%s\n", | ||||
| 		"application/json", host, strings.ToLower(adaptor.Action)) | ||||
| 	signedHeaders := "content-type;host;x-tc-action" | ||||
| 	payload, _ := json.Marshal(req) | ||||
| 	hashedRequestPayload := sha256hex(string(payload)) | ||||
| 	canonicalRequest := fmt.Sprintf("%s\n%s\n%s\n%s\n%s\n%s", | ||||
| 		httpRequestMethod, | ||||
| 		canonicalURI, | ||||
| 		canonicalQueryString, | ||||
| 		canonicalHeaders, | ||||
| 		signedHeaders, | ||||
| 		hashedRequestPayload) | ||||
| 	// build string to sign | ||||
| 	algorithm := "TC3-HMAC-SHA256" | ||||
| 	requestTimestamp := strconv.FormatInt(adaptor.Timestamp, 10) | ||||
| 	timestamp, _ := strconv.ParseInt(requestTimestamp, 10, 64) | ||||
| 	t := time.Unix(timestamp, 0).UTC() | ||||
| 	// must be the format 2006-01-02, ref to package time for more info | ||||
| 	date := t.Format("2006-01-02") | ||||
| 	credentialScope := fmt.Sprintf("%s/%s/tc3_request", date, "hunyuan") | ||||
| 	hashedCanonicalRequest := sha256hex(canonicalRequest) | ||||
| 	string2sign := fmt.Sprintf("%s\n%s\n%s\n%s", | ||||
| 		algorithm, | ||||
| 		requestTimestamp, | ||||
| 		credentialScope, | ||||
| 		hashedCanonicalRequest) | ||||
|  | ||||
| 	// sign string | ||||
| 	secretDate := hmacSha256(date, "TC3"+secKey) | ||||
| 	secretService := hmacSha256("hunyuan", secretDate) | ||||
| 	secretKey := hmacSha256("tc3_request", secretService) | ||||
| 	signature := hex.EncodeToString([]byte(hmacSha256(string2sign, secretKey))) | ||||
|  | ||||
| 	// build authorization | ||||
| 	authorization := fmt.Sprintf("%s Credential=%s/%s, SignedHeaders=%s, Signature=%s", | ||||
| 		algorithm, | ||||
| 		secId, | ||||
| 		credentialScope, | ||||
| 		signedHeaders, | ||||
| 		signature) | ||||
| 	return authorization | ||||
| } | ||||
|   | ||||
| @@ -1,63 +1,75 @@ | ||||
| package tencent | ||||
|  | ||||
| import ( | ||||
| 	"github.com/songquanpeng/one-api/relay/model" | ||||
| ) | ||||
|  | ||||
| type Message struct { | ||||
| 	Role    string `json:"role"` | ||||
| 	Content string `json:"content"` | ||||
| 	Role    string `json:"Role"` | ||||
| 	Content string `json:"Content"` | ||||
| } | ||||
|  | ||||
| type ChatRequest struct { | ||||
| 	AppId    int64  `json:"app_id"`    // 腾讯云账号的 APPID | ||||
| 	SecretId string `json:"secret_id"` // 官网 SecretId | ||||
| 	// Timestamp当前 UNIX 时间戳,单位为秒,可记录发起 API 请求的时间。 | ||||
| 	// 例如1529223702,如果与当前时间相差过大,会引起签名过期错误 | ||||
| 	Timestamp int64 `json:"timestamp"` | ||||
| 	// Expired 签名的有效期,是一个符合 UNIX Epoch 时间戳规范的数值, | ||||
| 	// 单位为秒;Expired 必须大于 Timestamp 且 Expired-Timestamp 小于90天 | ||||
| 	Expired int64  `json:"expired"` | ||||
| 	QueryID string `json:"query_id"` //请求 Id,用于问题排查 | ||||
| 	// Temperature 较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定 | ||||
| 	// 默认 1.0,取值区间为[0.0,2.0],非必要不建议使用,不合理的取值会影响效果 | ||||
| 	// 建议该参数和 top_p 只设置1个,不要同时更改 top_p | ||||
| 	Temperature float64 `json:"temperature"` | ||||
| 	// TopP 影响输出文本的多样性,取值越大,生成文本的多样性越强 | ||||
| 	// 默认1.0,取值区间为[0.0, 1.0],非必要不建议使用, 不合理的取值会影响效果 | ||||
| 	// 建议该参数和 temperature 只设置1个,不要同时更改 | ||||
| 	TopP float64 `json:"top_p"` | ||||
| 	// Stream 0:同步,1:流式 (默认,协议:SSE) | ||||
| 	// 同步请求超时:60s,如果内容较长建议使用流式 | ||||
| 	Stream int `json:"stream"` | ||||
| 	// Messages 会话内容, 长度最多为40, 按对话时间从旧到新在数组中排列 | ||||
| 	// 输入 content 总数最大支持 3000 token。 | ||||
| 	Messages []Message `json:"messages"` | ||||
| 	// 模型名称,可选值包括 hunyuan-lite、hunyuan-standard、hunyuan-standard-256K、hunyuan-pro。 | ||||
| 	// 各模型介绍请阅读 [产品概述](https://cloud.tencent.com/document/product/1729/104753) 中的说明。 | ||||
| 	// | ||||
| 	// 注意: | ||||
| 	// 不同的模型计费不同,请根据 [购买指南](https://cloud.tencent.com/document/product/1729/97731) 按需调用。 | ||||
| 	Model *string `json:"Model"` | ||||
| 	// 聊天上下文信息。 | ||||
| 	// 说明: | ||||
| 	// 1. 长度最多为 40,按对话时间从旧到新在数组中排列。 | ||||
| 	// 2. Message.Role 可选值:system、user、assistant。 | ||||
| 	// 其中,system 角色可选,如存在则必须位于列表的最开始。user 和 assistant 需交替出现(一问一答),以 user 提问开始和结束,且 Content 不能为空。Role 的顺序示例:[system(可选) user assistant user assistant user ...]。 | ||||
| 	// 3. Messages 中 Content 总长度不能超过模型输入长度上限(可参考 [产品概述](https://cloud.tencent.com/document/product/1729/104753) 文档),超过则会截断最前面的内容,只保留尾部内容。 | ||||
| 	Messages []*Message `json:"Messages"` | ||||
| 	// 流式调用开关。 | ||||
| 	// 说明: | ||||
| 	// 1. 未传值时默认为非流式调用(false)。 | ||||
| 	// 2. 流式调用时以 SSE 协议增量返回结果(返回值取 Choices[n].Delta 中的值,需要拼接增量数据才能获得完整结果)。 | ||||
| 	// 3. 非流式调用时: | ||||
| 	// 调用方式与普通 HTTP 请求无异。 | ||||
| 	// 接口响应耗时较长,**如需更低时延建议设置为 true**。 | ||||
| 	// 只返回一次最终结果(返回值取 Choices[n].Message 中的值)。 | ||||
| 	// | ||||
| 	// 注意: | ||||
| 	// 通过 SDK 调用时,流式和非流式调用需用**不同的方式**获取返回值,具体参考 SDK 中的注释或示例(在各语言 SDK 代码仓库的 examples/hunyuan/v20230901/ 目录中)。 | ||||
| 	Stream *bool `json:"Stream"` | ||||
| 	// 说明: | ||||
| 	// 1. 影响输出文本的多样性,取值越大,生成文本的多样性越强。 | ||||
| 	// 2. 取值区间为 [0.0, 1.0],未传值时使用各模型推荐值。 | ||||
| 	// 3. 非必要不建议使用,不合理的取值会影响效果。 | ||||
| 	TopP *float64 `json:"TopP"` | ||||
| 	// 说明: | ||||
| 	// 1. 较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定。 | ||||
| 	// 2. 取值区间为 [0.0, 2.0],未传值时使用各模型推荐值。 | ||||
| 	// 3. 非必要不建议使用,不合理的取值会影响效果。 | ||||
| 	Temperature *float64 `json:"Temperature"` | ||||
| } | ||||
|  | ||||
| type Error struct { | ||||
| 	Code    int    `json:"code"` | ||||
| 	Message string `json:"message"` | ||||
| 	Code    int    `json:"Code"` | ||||
| 	Message string `json:"Message"` | ||||
| } | ||||
|  | ||||
| type Usage struct { | ||||
| 	InputTokens  int `json:"input_tokens"` | ||||
| 	OutputTokens int `json:"output_tokens"` | ||||
| 	TotalTokens  int `json:"total_tokens"` | ||||
| 	PromptTokens     int `json:"PromptTokens"` | ||||
| 	CompletionTokens int `json:"CompletionTokens"` | ||||
| 	TotalTokens      int `json:"TotalTokens"` | ||||
| } | ||||
|  | ||||
| type ResponseChoices struct { | ||||
| 	FinishReason string  `json:"finish_reason,omitempty"` // 流式结束标志位,为 stop 则表示尾包 | ||||
| 	Messages     Message `json:"messages,omitempty"`      // 内容,同步模式返回内容,流模式为 null 输出 content 内容总数最多支持 1024token。 | ||||
| 	Delta        Message `json:"delta,omitempty"`         // 内容,流模式返回内容,同步模式为 null 输出 content 内容总数最多支持 1024token。 | ||||
| 	FinishReason string  `json:"FinishReason,omitempty"` // 流式结束标志位,为 stop 则表示尾包 | ||||
| 	Messages     Message `json:"Message,omitempty"`      // 内容,同步模式返回内容,流模式为 null 输出 content 内容总数最多支持 1024token。 | ||||
| 	Delta        Message `json:"Delta,omitempty"`        // 内容,流模式返回内容,同步模式为 null 输出 content 内容总数最多支持 1024token。 | ||||
| } | ||||
|  | ||||
| type ChatResponse struct { | ||||
| 	Choices []ResponseChoices `json:"choices,omitempty"` // 结果 | ||||
| 	Created string            `json:"created,omitempty"` // unix 时间戳的字符串 | ||||
| 	Id      string            `json:"id,omitempty"`      // 会话 id | ||||
| 	Usage   model.Usage       `json:"usage,omitempty"`   // token 数量 | ||||
| 	Error   Error             `json:"error,omitempty"`   // 错误信息 注意:此字段可能返回 null,表示取不到有效值 | ||||
| 	Note    string            `json:"note,omitempty"`    // 注释 | ||||
| 	ReqID   string            `json:"req_id,omitempty"`  // 唯一请求 Id,每次请求都会返回。用于反馈接口入参 | ||||
| 	Choices []ResponseChoices `json:"Choices,omitempty"` // 结果 | ||||
| 	Created int64             `json:"Created,omitempty"` // unix 时间戳的字符串 | ||||
| 	Id      string            `json:"Id,omitempty"`      // 会话 id | ||||
| 	Usage   Usage             `json:"Usage,omitempty"`   // token 数量 | ||||
| 	Error   Error             `json:"Error,omitempty"`   // 错误信息 注意:此字段可能返回 null,表示取不到有效值 | ||||
| 	Note    string            `json:"Note,omitempty"`    // 注释 | ||||
| 	ReqID   string            `json:"Req_id,omitempty"`  // 唯一请求 Id,每次请求都会返回。用于反馈接口入参 | ||||
| } | ||||
|  | ||||
| type ChatResponseP struct { | ||||
| 	Response ChatResponse `json:"Response,omitempty"` | ||||
| } | ||||
|   | ||||
| @@ -6,4 +6,5 @@ var ModelList = []string{ | ||||
| 	"SparkDesk-v2.1", | ||||
| 	"SparkDesk-v3.1", | ||||
| 	"SparkDesk-v3.5", | ||||
| 	"SparkDesk-v4.0", | ||||
| } | ||||
|   | ||||
| @@ -44,9 +44,13 @@ func requestOpenAI2Xunfei(request model.GeneralOpenAIRequest, xunfeiAppId string | ||||
| 	xunfeiRequest.Parameter.Chat.MaxTokens = request.MaxTokens | ||||
| 	xunfeiRequest.Payload.Message.Text = messages | ||||
|  | ||||
| 	if strings.HasPrefix(domain, "generalv3") { | ||||
| 	if strings.HasPrefix(domain, "generalv3") || domain == "4.0Ultra" { | ||||
| 		functions := make([]model.Function, len(request.Tools)) | ||||
| 		for i, tool := range request.Tools { | ||||
| 			functions[i] = tool.Function | ||||
| 		} | ||||
| 		xunfeiRequest.Payload.Functions = &Functions{ | ||||
| 			Text: request.Tools, | ||||
| 			Text: functions, | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| @@ -286,6 +290,8 @@ func apiVersion2domain(apiVersion string) string { | ||||
| 		return "generalv3" | ||||
| 	case "v3.5": | ||||
| 		return "generalv3.5" | ||||
| 	case "v4.0": | ||||
| 		return "4.0Ultra" | ||||
| 	} | ||||
| 	return "general" + apiVersion | ||||
| } | ||||
|   | ||||
| @@ -10,7 +10,7 @@ type Message struct { | ||||
| } | ||||
|  | ||||
| type Functions struct { | ||||
| 	Text []model.Tool `json:"text,omitempty"` | ||||
| 	Text []model.Function `json:"text,omitempty"` | ||||
| } | ||||
|  | ||||
| type ChatRequest struct { | ||||
|   | ||||
| @@ -3,6 +3,13 @@ package zhipu | ||||
| import ( | ||||
| 	"bufio" | ||||
| 	"encoding/json" | ||||
| 	"github.com/songquanpeng/one-api/common/render" | ||||
| 	"io" | ||||
| 	"net/http" | ||||
| 	"strings" | ||||
| 	"sync" | ||||
| 	"time" | ||||
|  | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"github.com/golang-jwt/jwt" | ||||
| 	"github.com/songquanpeng/one-api/common" | ||||
| @@ -11,11 +18,6 @@ import ( | ||||
| 	"github.com/songquanpeng/one-api/relay/adaptor/openai" | ||||
| 	"github.com/songquanpeng/one-api/relay/constant" | ||||
| 	"github.com/songquanpeng/one-api/relay/model" | ||||
| 	"io" | ||||
| 	"net/http" | ||||
| 	"strings" | ||||
| 	"sync" | ||||
| 	"time" | ||||
| ) | ||||
|  | ||||
| // https://open.bigmodel.cn/doc/api#chatglm_std | ||||
| @@ -155,66 +157,55 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC | ||||
| 		} | ||||
| 		return 0, nil, nil | ||||
| 	}) | ||||
| 	dataChan := make(chan string) | ||||
| 	metaChan := make(chan string) | ||||
| 	stopChan := make(chan bool) | ||||
| 	go func() { | ||||
| 		for scanner.Scan() { | ||||
| 			data := scanner.Text() | ||||
| 			lines := strings.Split(data, "\n") | ||||
| 			for i, line := range lines { | ||||
| 				if len(line) < 5 { | ||||
|  | ||||
| 	common.SetEventStreamHeaders(c) | ||||
|  | ||||
| 	for scanner.Scan() { | ||||
| 		data := scanner.Text() | ||||
| 		lines := strings.Split(data, "\n") | ||||
| 		for i, line := range lines { | ||||
| 			if len(line) < 5 { | ||||
| 				continue | ||||
| 			} | ||||
| 			if strings.HasPrefix(line, "data:") { | ||||
| 				dataSegment := line[5:] | ||||
| 				if i != len(lines)-1 { | ||||
| 					dataSegment += "\n" | ||||
| 				} | ||||
| 				response := streamResponseZhipu2OpenAI(dataSegment) | ||||
| 				err := render.ObjectData(c, response) | ||||
| 				if err != nil { | ||||
| 					logger.SysError("error marshalling stream response: " + err.Error()) | ||||
| 				} | ||||
| 			} else if strings.HasPrefix(line, "meta:") { | ||||
| 				metaSegment := line[5:] | ||||
| 				var zhipuResponse StreamMetaResponse | ||||
| 				err := json.Unmarshal([]byte(metaSegment), &zhipuResponse) | ||||
| 				if err != nil { | ||||
| 					logger.SysError("error unmarshalling stream response: " + err.Error()) | ||||
| 					continue | ||||
| 				} | ||||
| 				if line[:5] == "data:" { | ||||
| 					dataChan <- line[5:] | ||||
| 					if i != len(lines)-1 { | ||||
| 						dataChan <- "\n" | ||||
| 					} | ||||
| 				} else if line[:5] == "meta:" { | ||||
| 					metaChan <- line[5:] | ||||
| 				response, zhipuUsage := streamMetaResponseZhipu2OpenAI(&zhipuResponse) | ||||
| 				err = render.ObjectData(c, response) | ||||
| 				if err != nil { | ||||
| 					logger.SysError("error marshalling stream response: " + err.Error()) | ||||
| 				} | ||||
| 				usage = zhipuUsage | ||||
| 			} | ||||
| 		} | ||||
| 		stopChan <- true | ||||
| 	}() | ||||
| 	common.SetEventStreamHeaders(c) | ||||
| 	c.Stream(func(w io.Writer) bool { | ||||
| 		select { | ||||
| 		case data := <-dataChan: | ||||
| 			response := streamResponseZhipu2OpenAI(data) | ||||
| 			jsonResponse, err := json.Marshal(response) | ||||
| 			if err != nil { | ||||
| 				logger.SysError("error marshalling stream response: " + err.Error()) | ||||
| 				return true | ||||
| 			} | ||||
| 			c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)}) | ||||
| 			return true | ||||
| 		case data := <-metaChan: | ||||
| 			var zhipuResponse StreamMetaResponse | ||||
| 			err := json.Unmarshal([]byte(data), &zhipuResponse) | ||||
| 			if err != nil { | ||||
| 				logger.SysError("error unmarshalling stream response: " + err.Error()) | ||||
| 				return true | ||||
| 			} | ||||
| 			response, zhipuUsage := streamMetaResponseZhipu2OpenAI(&zhipuResponse) | ||||
| 			jsonResponse, err := json.Marshal(response) | ||||
| 			if err != nil { | ||||
| 				logger.SysError("error marshalling stream response: " + err.Error()) | ||||
| 				return true | ||||
| 			} | ||||
| 			usage = zhipuUsage | ||||
| 			c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)}) | ||||
| 			return true | ||||
| 		case <-stopChan: | ||||
| 			c.Render(-1, common.CustomEvent{Data: "data: [DONE]"}) | ||||
| 			return false | ||||
| 		} | ||||
| 	}) | ||||
| 	} | ||||
|  | ||||
| 	if err := scanner.Err(); err != nil { | ||||
| 		logger.SysError("error reading stream: " + err.Error()) | ||||
| 	} | ||||
|  | ||||
| 	render.Done(c) | ||||
|  | ||||
| 	err := resp.Body.Close() | ||||
| 	if err != nil { | ||||
| 		return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil | ||||
| 	} | ||||
|  | ||||
| 	return nil, usage | ||||
| } | ||||
|  | ||||
|   | ||||
							
								
								
									
										16
									
								
								relay/adaptor_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										16
									
								
								relay/adaptor_test.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,16 @@ | ||||
| package relay | ||||
|  | ||||
| import ( | ||||
| 	. "github.com/smartystreets/goconvey/convey" | ||||
| 	"github.com/songquanpeng/one-api/relay/apitype" | ||||
| 	"testing" | ||||
| ) | ||||
|  | ||||
| func TestGetAdaptor(t *testing.T) { | ||||
| 	Convey("get adaptor", t, func() { | ||||
| 		for i := 0; i < apitype.Dummy; i++ { | ||||
| 			a := GetAdaptor(i) | ||||
| 			So(a, ShouldNotBeNil) | ||||
| 		} | ||||
| 	}) | ||||
| } | ||||
| @@ -70,12 +70,13 @@ var ModelRatio = map[string]float64{ | ||||
| 	"dall-e-2":                0.02 * USD, // $0.016 - $0.020 / image | ||||
| 	"dall-e-3":                0.04 * USD, // $0.040 - $0.120 / image | ||||
| 	// https://www.anthropic.com/api#pricing | ||||
| 	"claude-instant-1.2":       0.8 / 1000 * USD, | ||||
| 	"claude-2.0":               8.0 / 1000 * USD, | ||||
| 	"claude-2.1":               8.0 / 1000 * USD, | ||||
| 	"claude-3-haiku-20240307":  0.25 / 1000 * USD, | ||||
| 	"claude-3-sonnet-20240229": 3.0 / 1000 * USD, | ||||
| 	"claude-3-opus-20240229":   15.0 / 1000 * USD, | ||||
| 	"claude-instant-1.2":         0.8 / 1000 * USD, | ||||
| 	"claude-2.0":                 8.0 / 1000 * USD, | ||||
| 	"claude-2.1":                 8.0 / 1000 * USD, | ||||
| 	"claude-3-haiku-20240307":    0.25 / 1000 * USD, | ||||
| 	"claude-3-sonnet-20240229":   3.0 / 1000 * USD, | ||||
| 	"claude-3-5-sonnet-20240620": 3.0 / 1000 * USD, | ||||
| 	"claude-3-opus-20240229":     15.0 / 1000 * USD, | ||||
| 	// https://cloud.baidu.com/doc/WENXINWORKSHOP/s/hlrk4akp7 | ||||
| 	"ERNIE-4.0-8K":       0.120 * RMB, | ||||
| 	"ERNIE-3.5-8K":       0.012 * RMB, | ||||
| @@ -124,6 +125,7 @@ var ModelRatio = map[string]float64{ | ||||
| 	"SparkDesk-v2.1":            1.2858, // ¥0.018 / 1k tokens | ||||
| 	"SparkDesk-v3.1":            1.2858, // ¥0.018 / 1k tokens | ||||
| 	"SparkDesk-v3.5":            1.2858, // ¥0.018 / 1k tokens | ||||
| 	"SparkDesk-v4.0":            1.2858, // ¥0.018 / 1k tokens | ||||
| 	"360GPT_S2_V9":              0.8572, // ¥0.012 / 1k tokens | ||||
| 	"embedding-bert-512-v1":     0.0715, // ¥0.001 / 1k tokens | ||||
| 	"embedding_s1_v1":           0.0715, // ¥0.001 / 1k tokens | ||||
|   | ||||
| @@ -24,7 +24,7 @@ var ChannelBaseURLs = []string{ | ||||
| 	"https://openrouter.ai/api",                 // 20 | ||||
| 	"https://api.aiproxy.io",                    // 21 | ||||
| 	"https://fastgpt.run/api/openapi",           // 22 | ||||
| 	"https://hunyuan.cloud.tencent.com",         // 23 | ||||
| 	"https://hunyuan.tencentcloudapi.com",       // 23 | ||||
| 	"https://generativelanguage.googleapis.com", // 24 | ||||
| 	"https://api.moonshot.cn",                   // 25 | ||||
| 	"https://api.baichuan-ai.com",               // 26 | ||||
|   | ||||
							
								
								
									
										12
									
								
								relay/channeltype/url_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										12
									
								
								relay/channeltype/url_test.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,12 @@ | ||||
| package channeltype | ||||
|  | ||||
| import ( | ||||
| 	. "github.com/smartystreets/goconvey/convey" | ||||
| 	"testing" | ||||
| ) | ||||
|  | ||||
| func TestChannelBaseURLs(t *testing.T) { | ||||
| 	Convey("channel base urls", t, func() { | ||||
| 		So(len(ChannelBaseURLs), ShouldEqual, Dummy) | ||||
| 	}) | ||||
| } | ||||
| @@ -40,78 +40,6 @@ func getAndValidateTextRequest(c *gin.Context, relayMode int) (*relaymodel.Gener | ||||
| 	return textRequest, nil | ||||
| } | ||||
|  | ||||
| func getImageRequest(c *gin.Context, relayMode int) (*relaymodel.ImageRequest, error) { | ||||
| 	imageRequest := &relaymodel.ImageRequest{} | ||||
| 	err := common.UnmarshalBodyReusable(c, imageRequest) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	if imageRequest.N == 0 { | ||||
| 		imageRequest.N = 1 | ||||
| 	} | ||||
| 	if imageRequest.Size == "" { | ||||
| 		imageRequest.Size = "1024x1024" | ||||
| 	} | ||||
| 	if imageRequest.Model == "" { | ||||
| 		imageRequest.Model = "dall-e-2" | ||||
| 	} | ||||
| 	return imageRequest, nil | ||||
| } | ||||
|  | ||||
| func isValidImageSize(model string, size string) bool { | ||||
| 	if model == "cogview-3" { | ||||
| 		return true | ||||
| 	} | ||||
| 	_, ok := billingratio.ImageSizeRatios[model][size] | ||||
| 	return ok | ||||
| } | ||||
|  | ||||
| func getImageSizeRatio(model string, size string) float64 { | ||||
| 	ratio, ok := billingratio.ImageSizeRatios[model][size] | ||||
| 	if !ok { | ||||
| 		return 1 | ||||
| 	} | ||||
| 	return ratio | ||||
| } | ||||
|  | ||||
| func validateImageRequest(imageRequest *relaymodel.ImageRequest, meta *meta.Meta) *relaymodel.ErrorWithStatusCode { | ||||
| 	// model validation | ||||
| 	hasValidSize := isValidImageSize(imageRequest.Model, imageRequest.Size) | ||||
| 	if !hasValidSize { | ||||
| 		return openai.ErrorWrapper(errors.New("size not supported for this image model"), "size_not_supported", http.StatusBadRequest) | ||||
| 	} | ||||
| 	// check prompt length | ||||
| 	if imageRequest.Prompt == "" { | ||||
| 		return openai.ErrorWrapper(errors.New("prompt is required"), "prompt_missing", http.StatusBadRequest) | ||||
| 	} | ||||
| 	if len(imageRequest.Prompt) > billingratio.ImagePromptLengthLimitations[imageRequest.Model] { | ||||
| 		return openai.ErrorWrapper(errors.New("prompt is too long"), "prompt_too_long", http.StatusBadRequest) | ||||
| 	} | ||||
| 	// Number of generated images validation | ||||
| 	if !isWithinRange(imageRequest.Model, imageRequest.N) { | ||||
| 		// channel not azure | ||||
| 		if meta.ChannelType != channeltype.Azure { | ||||
| 			return openai.ErrorWrapper(errors.New("invalid value of n"), "n_not_within_range", http.StatusBadRequest) | ||||
| 		} | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func getImageCostRatio(imageRequest *relaymodel.ImageRequest) (float64, error) { | ||||
| 	if imageRequest == nil { | ||||
| 		return 0, errors.New("imageRequest is nil") | ||||
| 	} | ||||
| 	imageCostRatio := getImageSizeRatio(imageRequest.Model, imageRequest.Size) | ||||
| 	if imageRequest.Quality == "hd" && imageRequest.Model == "dall-e-3" { | ||||
| 		if imageRequest.Size == "1024x1024" { | ||||
| 			imageCostRatio *= 2 | ||||
| 		} else { | ||||
| 			imageCostRatio *= 1.5 | ||||
| 		} | ||||
| 	} | ||||
| 	return imageCostRatio, nil | ||||
| } | ||||
|  | ||||
| func getPromptTokens(textRequest *relaymodel.GeneralOpenAIRequest, relayMode int) int { | ||||
| 	switch relayMode { | ||||
| 	case relaymode.ChatCompletions: | ||||
|   | ||||
| @@ -7,6 +7,7 @@ import ( | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"github.com/songquanpeng/one-api/common" | ||||
| 	"github.com/songquanpeng/one-api/common/ctxkey" | ||||
| 	"github.com/songquanpeng/one-api/common/logger" | ||||
| 	"github.com/songquanpeng/one-api/model" | ||||
| @@ -20,13 +21,84 @@ import ( | ||||
| 	"net/http" | ||||
| ) | ||||
|  | ||||
| func isWithinRange(element string, value int) bool { | ||||
| 	if _, ok := billingratio.ImageGenerationAmounts[element]; !ok { | ||||
| 		return false | ||||
| func getImageRequest(c *gin.Context, relayMode int) (*relaymodel.ImageRequest, error) { | ||||
| 	imageRequest := &relaymodel.ImageRequest{} | ||||
| 	err := common.UnmarshalBodyReusable(c, imageRequest) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	min := billingratio.ImageGenerationAmounts[element][0] | ||||
| 	max := billingratio.ImageGenerationAmounts[element][1] | ||||
| 	return value >= min && value <= max | ||||
| 	if imageRequest.N == 0 { | ||||
| 		imageRequest.N = 1 | ||||
| 	} | ||||
| 	if imageRequest.Size == "" { | ||||
| 		imageRequest.Size = "1024x1024" | ||||
| 	} | ||||
| 	if imageRequest.Model == "" { | ||||
| 		imageRequest.Model = "dall-e-2" | ||||
| 	} | ||||
| 	return imageRequest, nil | ||||
| } | ||||
|  | ||||
| func isValidImageSize(model string, size string) bool { | ||||
| 	if model == "cogview-3" || billingratio.ImageSizeRatios[model] == nil { | ||||
| 		return true | ||||
| 	} | ||||
| 	_, ok := billingratio.ImageSizeRatios[model][size] | ||||
| 	return ok | ||||
| } | ||||
|  | ||||
| func isValidImagePromptLength(model string, promptLength int) bool { | ||||
| 	maxPromptLength, ok := billingratio.ImagePromptLengthLimitations[model] | ||||
| 	return !ok || promptLength <= maxPromptLength | ||||
| } | ||||
|  | ||||
| func isWithinRange(element string, value int) bool { | ||||
| 	amounts, ok := billingratio.ImageGenerationAmounts[element] | ||||
| 	return !ok || (value >= amounts[0] && value <= amounts[1]) | ||||
| } | ||||
|  | ||||
| func getImageSizeRatio(model string, size string) float64 { | ||||
| 	if ratio, ok := billingratio.ImageSizeRatios[model][size]; ok { | ||||
| 		return ratio | ||||
| 	} | ||||
| 	return 1 | ||||
| } | ||||
|  | ||||
| func validateImageRequest(imageRequest *relaymodel.ImageRequest, meta *meta.Meta) *relaymodel.ErrorWithStatusCode { | ||||
| 	// check prompt length | ||||
| 	if imageRequest.Prompt == "" { | ||||
| 		return openai.ErrorWrapper(errors.New("prompt is required"), "prompt_missing", http.StatusBadRequest) | ||||
| 	} | ||||
|  | ||||
| 	// model validation | ||||
| 	if !isValidImageSize(imageRequest.Model, imageRequest.Size) { | ||||
| 		return openai.ErrorWrapper(errors.New("size not supported for this image model"), "size_not_supported", http.StatusBadRequest) | ||||
| 	} | ||||
|  | ||||
| 	if !isValidImagePromptLength(imageRequest.Model, len(imageRequest.Prompt)) { | ||||
| 		return openai.ErrorWrapper(errors.New("prompt is too long"), "prompt_too_long", http.StatusBadRequest) | ||||
| 	} | ||||
|  | ||||
| 	// Number of generated images validation | ||||
| 	if !isWithinRange(imageRequest.Model, imageRequest.N) { | ||||
| 		return openai.ErrorWrapper(errors.New("invalid value of n"), "n_not_within_range", http.StatusBadRequest) | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func getImageCostRatio(imageRequest *relaymodel.ImageRequest) (float64, error) { | ||||
| 	if imageRequest == nil { | ||||
| 		return 0, errors.New("imageRequest is nil") | ||||
| 	} | ||||
| 	imageCostRatio := getImageSizeRatio(imageRequest.Model, imageRequest.Size) | ||||
| 	if imageRequest.Quality == "hd" && imageRequest.Model == "dall-e-3" { | ||||
| 		if imageRequest.Size == "1024x1024" { | ||||
| 			imageCostRatio *= 2 | ||||
| 		} else { | ||||
| 			imageCostRatio *= 1.5 | ||||
| 		} | ||||
| 	} | ||||
| 	return imageCostRatio, nil | ||||
| } | ||||
|  | ||||
| func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatusCode { | ||||
|   | ||||
| @@ -1,10 +1,11 @@ | ||||
| package model | ||||
|  | ||||
| type Message struct { | ||||
| 	Role      string  `json:"role,omitempty"` | ||||
| 	Content   any     `json:"content,omitempty"` | ||||
| 	Name      *string `json:"name,omitempty"` | ||||
| 	ToolCalls []Tool  `json:"tool_calls,omitempty"` | ||||
| 	Role       string  `json:"role,omitempty"` | ||||
| 	Content    any     `json:"content,omitempty"` | ||||
| 	Name       *string `json:"name,omitempty"` | ||||
| 	ToolCalls  []Tool  `json:"tool_calls,omitempty"` | ||||
| 	ToolCallId string  `json:"tool_call_id,omitempty"` | ||||
| } | ||||
|  | ||||
| func (m Message) IsStringContent() bool { | ||||
|   | ||||
| @@ -2,13 +2,13 @@ package model | ||||
|  | ||||
| type Tool struct { | ||||
| 	Id       string   `json:"id,omitempty"` | ||||
| 	Type     string   `json:"type"` | ||||
| 	Type     string   `json:"type,omitempty"` // when splicing claude tools stream messages, it is empty | ||||
| 	Function Function `json:"function"` | ||||
| } | ||||
|  | ||||
| type Function struct { | ||||
| 	Description string `json:"description,omitempty"` | ||||
| 	Name        string `json:"name"` | ||||
| 	Name        string `json:"name,omitempty"`       // when splicing claude tools stream messages, it is empty | ||||
| 	Parameters  any    `json:"parameters,omitempty"` // request | ||||
| 	Arguments   any    `json:"arguments,omitempty"`  // response | ||||
| } | ||||
|   | ||||
| @@ -63,7 +63,7 @@ const EditChannel = (props) => { | ||||
|             let localModels = []; | ||||
|             switch (value) { | ||||
|                 case 14: | ||||
|                     localModels = ["claude-instant-1.2", "claude-2", "claude-2.0", "claude-2.1", "claude-3-opus-20240229", "claude-3-sonnet-20240229", "claude-3-haiku-20240307"]; | ||||
|                     localModels = ["claude-instant-1.2", "claude-2", "claude-2.0", "claude-2.1", "claude-3-opus-20240229", "claude-3-sonnet-20240229", "claude-3-haiku-20240307", "claude-3-5-sonnet-20240620"]; | ||||
|                     break; | ||||
|                 case 11: | ||||
|                     localModels = ['PaLM-2']; | ||||
| @@ -78,7 +78,7 @@ const EditChannel = (props) => { | ||||
|                     localModels = ['chatglm_pro', 'chatglm_std', 'chatglm_lite']; | ||||
|                     break; | ||||
|                 case 18: | ||||
|                     localModels = ['SparkDesk', 'SparkDesk-v1.1', 'SparkDesk-v2.1', 'SparkDesk-v3.1', 'SparkDesk-v3.5']; | ||||
|                     localModels = ['SparkDesk', 'SparkDesk-v1.1', 'SparkDesk-v2.1', 'SparkDesk-v3.1', 'SparkDesk-v3.5', 'SparkDesk-v4.0']; | ||||
|                     break; | ||||
|                 case 19: | ||||
|                     localModels = ['360GPT_S2_V9', 'embedding-bert-512-v1', 'embedding_s1_v1', 'semantic_similarity_s1_v1']; | ||||
|   | ||||
| @@ -91,7 +91,7 @@ const typeConfig = { | ||||
|       other: '版本号' | ||||
|     }, | ||||
|     input: { | ||||
|       models: ['SparkDesk', 'SparkDesk-v1.1', 'SparkDesk-v2.1', 'SparkDesk-v3.1', 'SparkDesk-v3.5'] | ||||
|       models: ['SparkDesk', 'SparkDesk-v1.1', 'SparkDesk-v2.1', 'SparkDesk-v3.1', 'SparkDesk-v3.5', 'SparkDesk-v4.0'] | ||||
|     }, | ||||
|     prompt: { | ||||
|       key: '按照如下格式输入:APPID|APISecret|APIKey', | ||||
|   | ||||
		Reference in New Issue
	
	Block a user