mirror of
				https://github.com/songquanpeng/one-api.git
				synced 2025-11-04 15:53:42 +08:00 
			
		
		
		
	Compare commits
	
		
			26 Commits
		
	
	
		
			v0.6.7-alp
			...
			v0.6.8-alp
		
	
	| Author | SHA1 | Date | |
|---|---|---|---|
| 
						 | 
					ec6ad24810 | ||
| 
						 | 
					c4fe57c165 | ||
| 
						 | 
					274fcf3d76 | ||
| 
						 | 
					0fc07ea558 | ||
| 
						 | 
					1ce1e529ee | ||
| 
						 | 
					d936817de9 | ||
| 
						 | 
					fecaece71b | ||
| 
						 | 
					c135d74f13 | ||
| 
						 | 
					d0369b114f | ||
| 
						 | 
					b21b3b5b46 | ||
| 
						 | 
					ae1cd29f94 | ||
| 
						 | 
					f25aaf7752 | ||
| 
						 | 
					b70a07e814 | ||
| 
						 | 
					34cb147a74 | ||
| 
						 | 
					8cc1ee6360 | ||
| 
						 | 
					5a58426859 | ||
| 
						 | 
					254b9777c0 | ||
| 
						 | 
					114c44c6e7 | ||
| 
						 | 
					a3c7e15aed | ||
| 
						 | 
					3777517f64 | ||
| 
						 | 
					9fc5f427dc | ||
| 
						 | 
					864a467886 | ||
| 
						 | 
					ed78b5340b | ||
| 
						 | 
					fee69e7c20 | ||
| 
						 | 
					9d23a44dbf | ||
| 
						 | 
					6e4cfb20d5 | 
							
								
								
									
										3
									
								
								.env.example
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										3
									
								
								.env.example
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,3 @@
 | 
			
		||||
PORT=3000
 | 
			
		||||
DEBUG=false
 | 
			
		||||
HTTPS_PROXY=http://localhost:7890
 | 
			
		||||
							
								
								
									
										47
									
								
								.github/workflows/ci.yml
									
									
									
									
										vendored
									
									
										Normal file
									
								
							
							
						
						
									
										47
									
								
								.github/workflows/ci.yml
									
									
									
									
										vendored
									
									
										Normal file
									
								
							@@ -0,0 +1,47 @@
 | 
			
		||||
name: CI
 | 
			
		||||
 | 
			
		||||
# This setup assumes that you run the unit tests with code coverage in the same
 | 
			
		||||
# workflow that will also print the coverage report as comment to the pull request. 
 | 
			
		||||
# Therefore, you need to trigger this workflow when a pull request is (re)opened or
 | 
			
		||||
# when new code is pushed to the branch of the pull request. In addition, you also
 | 
			
		||||
# need to trigger this workflow when new code is pushed to the main branch because 
 | 
			
		||||
# we need to upload the code coverage results as artifact for the main branch as
 | 
			
		||||
# well since it will be the baseline code coverage.
 | 
			
		||||
# 
 | 
			
		||||
# We do not want to trigger the workflow for pushes to *any* branch because this
 | 
			
		||||
# would trigger our jobs twice on pull requests (once from "push" event and once
 | 
			
		||||
# from "pull_request->synchronize")
 | 
			
		||||
on:
 | 
			
		||||
  pull_request:
 | 
			
		||||
    types: [opened, reopened, synchronize]
 | 
			
		||||
  push:
 | 
			
		||||
    branches:
 | 
			
		||||
      - 'main'
 | 
			
		||||
 | 
			
		||||
jobs:
 | 
			
		||||
  unit_tests:
 | 
			
		||||
    name: "Unit tests"
 | 
			
		||||
    runs-on: ubuntu-latest
 | 
			
		||||
    steps:
 | 
			
		||||
      - name: Checkout repository
 | 
			
		||||
        uses: actions/checkout@v4
 | 
			
		||||
 | 
			
		||||
      - name: Setup Go
 | 
			
		||||
        uses: actions/setup-go@v4
 | 
			
		||||
        with:
 | 
			
		||||
          go-version: ^1.22
 | 
			
		||||
 | 
			
		||||
      # When you execute your unit tests, make sure to use the "-coverprofile" flag to write a 
 | 
			
		||||
      # coverage profile to a file. You will need the name of the file (e.g. "coverage.txt")
 | 
			
		||||
      # in the next step as well as the next job.
 | 
			
		||||
      - name: Test
 | 
			
		||||
        run: go test -cover -coverprofile=coverage.txt ./...
 | 
			
		||||
      - uses: codecov/codecov-action@v4
 | 
			
		||||
        with:
 | 
			
		||||
          token: ${{ secrets.CODECOV_TOKEN }}
 | 
			
		||||
 | 
			
		||||
  commit_lint:
 | 
			
		||||
    runs-on: ubuntu-latest
 | 
			
		||||
    steps:
 | 
			
		||||
      - uses: actions/checkout@v3
 | 
			
		||||
      - uses: wagoid/commitlint-github-action@v6
 | 
			
		||||
							
								
								
									
										3
									
								
								.gitignore
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										3
									
								
								.gitignore
									
									
									
									
										vendored
									
									
								
							@@ -8,4 +8,5 @@ build
 | 
			
		||||
logs
 | 
			
		||||
data
 | 
			
		||||
/web/node_modules
 | 
			
		||||
cmd.md
 | 
			
		||||
cmd.md
 | 
			
		||||
.env
 | 
			
		||||
							
								
								
									
										10
									
								
								README.en.md
									
									
									
									
									
								
							
							
						
						
									
										10
									
								
								README.en.md
									
									
									
									
									
								
							@@ -101,7 +101,7 @@ Nginx reference configuration:
 | 
			
		||||
```
 | 
			
		||||
server{
 | 
			
		||||
   server_name openai.justsong.cn;  # Modify your domain name accordingly
 | 
			
		||||
   
 | 
			
		||||
 | 
			
		||||
   location / {
 | 
			
		||||
          client_max_body_size  64m;
 | 
			
		||||
          proxy_http_version 1.1;
 | 
			
		||||
@@ -132,12 +132,12 @@ The initial account username is `root` and password is `123456`.
 | 
			
		||||
1. Download the executable file from [GitHub Releases](https://github.com/songquanpeng/one-api/releases/latest) or compile from source:
 | 
			
		||||
   ```shell
 | 
			
		||||
   git clone https://github.com/songquanpeng/one-api.git
 | 
			
		||||
   
 | 
			
		||||
 | 
			
		||||
   # Build the frontend
 | 
			
		||||
   cd one-api/web/default
 | 
			
		||||
   npm install
 | 
			
		||||
   npm run build
 | 
			
		||||
   
 | 
			
		||||
 | 
			
		||||
   # Build the backend
 | 
			
		||||
   cd ../..
 | 
			
		||||
   go mod download
 | 
			
		||||
@@ -287,7 +287,9 @@ If the channel ID is not provided, load balancing will be used to distribute the
 | 
			
		||||
    + Double-check that your interface address and API Key are correct.
 | 
			
		||||
 | 
			
		||||
## Related Projects
 | 
			
		||||
[FastGPT](https://github.com/labring/FastGPT): Knowledge question answering system based on the LLM
 | 
			
		||||
* [FastGPT](https://github.com/labring/FastGPT): Knowledge question answering system based on the LLM
 | 
			
		||||
* [VChart](https://github.com/VisActor/VChart):  More than just a cross-platform charting library, but also an expressive data storyteller.
 | 
			
		||||
* [VMind](https://github.com/VisActor/VMind):  Not just automatic, but also fantastic. Open-source solution for intelligent visualization.
 | 
			
		||||
 | 
			
		||||
## Note
 | 
			
		||||
This project is an open-source project. Please use it in compliance with OpenAI's [Terms of Use](https://openai.com/policies/terms-of-use) and **applicable laws and regulations**. It must not be used for illegal purposes.
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										13
									
								
								README.md
									
									
									
									
									
								
							
							
						
						
									
										13
									
								
								README.md
									
									
									
									
									
								
							@@ -53,7 +53,7 @@ _✨ 通过标准的 OpenAI API 格式访问所有的大模型,开箱即用 
 | 
			
		||||
 | 
			
		||||
> [!NOTE]
 | 
			
		||||
> 本项目为开源项目,使用者必须在遵循 OpenAI 的[使用条款](https://openai.com/policies/terms-of-use)以及**法律法规**的情况下使用,不得用于非法用途。
 | 
			
		||||
> 
 | 
			
		||||
>
 | 
			
		||||
> 根据[《生成式人工智能服务管理暂行办法》](http://www.cac.gov.cn/2023-07/13/c_1690898327029107.htm)的要求,请勿对中国地区公众提供一切未经备案的生成式人工智能服务。
 | 
			
		||||
 | 
			
		||||
> [!WARNING]
 | 
			
		||||
@@ -144,7 +144,7 @@ Nginx 的参考配置:
 | 
			
		||||
```
 | 
			
		||||
server{
 | 
			
		||||
   server_name openai.justsong.cn;  # 请根据实际情况修改你的域名
 | 
			
		||||
   
 | 
			
		||||
 | 
			
		||||
   location / {
 | 
			
		||||
          client_max_body_size  64m;
 | 
			
		||||
          proxy_http_version 1.1;
 | 
			
		||||
@@ -189,12 +189,12 @@ docker-compose ps
 | 
			
		||||
1. 从 [GitHub Releases](https://github.com/songquanpeng/one-api/releases/latest) 下载可执行文件或者从源码编译:
 | 
			
		||||
   ```shell
 | 
			
		||||
   git clone https://github.com/songquanpeng/one-api.git
 | 
			
		||||
   
 | 
			
		||||
 | 
			
		||||
   # 构建前端
 | 
			
		||||
   cd one-api/web/default
 | 
			
		||||
   npm install
 | 
			
		||||
   npm run build
 | 
			
		||||
   
 | 
			
		||||
 | 
			
		||||
   # 构建后端
 | 
			
		||||
   cd ../..
 | 
			
		||||
   go mod download
 | 
			
		||||
@@ -321,7 +321,7 @@ Render 可以直接部署 docker 镜像,不需要 fork 仓库:https://dashbo
 | 
			
		||||
例如对于 OpenAI 的官方库:
 | 
			
		||||
```bash
 | 
			
		||||
OPENAI_API_KEY="sk-xxxxxx"
 | 
			
		||||
OPENAI_API_BASE="https://<HOST>:<PORT>/v1" 
 | 
			
		||||
OPENAI_API_BASE="https://<HOST>:<PORT>/v1"
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
```mermaid
 | 
			
		||||
@@ -340,6 +340,7 @@ graph LR
 | 
			
		||||
不加的话将会使用负载均衡的方式使用多个渠道。
 | 
			
		||||
 | 
			
		||||
### 环境变量
 | 
			
		||||
> One API 支持从 `.env` 文件中读取环境变量,请参照 `.env.example` 文件,使用时请将其重命名为 `.env`。
 | 
			
		||||
1. `REDIS_CONN_STRING`:设置之后将使用 Redis 作为缓存使用。
 | 
			
		||||
   + 例子:`REDIS_CONN_STRING=redis://default:redispw@localhost:49153`
 | 
			
		||||
   + 如果数据库访问延迟很低,没有必要启用 Redis,启用后反而会出现数据滞后的问题。
 | 
			
		||||
@@ -447,6 +448,8 @@ https://openai.justsong.cn
 | 
			
		||||
## 相关项目
 | 
			
		||||
* [FastGPT](https://github.com/labring/FastGPT): 基于 LLM 大语言模型的知识库问答系统
 | 
			
		||||
* [ChatGPT Next Web](https://github.com/Yidadaa/ChatGPT-Next-Web):  一键拥有你自己的跨平台 ChatGPT 应用
 | 
			
		||||
* [VChart](https://github.com/VisActor/VChart):  不只是开箱即用的多端图表库,更是生动灵活的数据故事讲述者。
 | 
			
		||||
* [VMind](https://github.com/VisActor/VMind):  不仅自动,还很智能。开源智能可视化解决方案。
 | 
			
		||||
 | 
			
		||||
## 注意
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
@@ -145,6 +145,9 @@ var InitialRootToken = os.Getenv("INITIAL_ROOT_TOKEN")
 | 
			
		||||
 | 
			
		||||
var GeminiVersion = env.String("GEMINI_VERSION", "v1")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
var OnlyOneLogFile = env.Bool("ONLY_ONE_LOG_FILE", false)
 | 
			
		||||
 | 
			
		||||
var RelayProxy = env.String("RELAY_PROXY", "")
 | 
			
		||||
var UserContentRequestProxy = env.String("USER_CONTENT_REQUEST_PROXY", "")
 | 
			
		||||
var UserContentRequestTimeout = env.Int("USER_CONTENT_REQUEST_TIMEOUT", 30)
 | 
			
		||||
 
 | 
			
		||||
@@ -19,4 +19,5 @@ const (
 | 
			
		||||
	TokenName         = "token_name"
 | 
			
		||||
	BaseURL           = "base_url"
 | 
			
		||||
	AvailableModels   = "available_models"
 | 
			
		||||
	KeyRequestBody    = "key_request_body"
 | 
			
		||||
)
 | 
			
		||||
 
 | 
			
		||||
@@ -4,14 +4,13 @@ import (
 | 
			
		||||
	"bytes"
 | 
			
		||||
	"encoding/json"
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
	"github.com/songquanpeng/one-api/common/ctxkey"
 | 
			
		||||
	"io"
 | 
			
		||||
	"strings"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
const KeyRequestBody = "key_request_body"
 | 
			
		||||
 | 
			
		||||
func GetRequestBody(c *gin.Context) ([]byte, error) {
 | 
			
		||||
	requestBody, _ := c.Get(KeyRequestBody)
 | 
			
		||||
	requestBody, _ := c.Get(ctxkey.KeyRequestBody)
 | 
			
		||||
	if requestBody != nil {
 | 
			
		||||
		return requestBody.([]byte), nil
 | 
			
		||||
	}
 | 
			
		||||
@@ -20,7 +19,7 @@ func GetRequestBody(c *gin.Context) ([]byte, error) {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
	_ = c.Request.Body.Close()
 | 
			
		||||
	c.Set(KeyRequestBody, requestBody)
 | 
			
		||||
	c.Set(ctxkey.KeyRequestBody, requestBody)
 | 
			
		||||
	return requestBody.([]byte), nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
@@ -2,6 +2,7 @@ package image_test
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"encoding/base64"
 | 
			
		||||
	"github.com/songquanpeng/one-api/common/client"
 | 
			
		||||
	"image"
 | 
			
		||||
	_ "image/gif"
 | 
			
		||||
	_ "image/jpeg"
 | 
			
		||||
@@ -44,6 +45,11 @@ var (
 | 
			
		||||
	}
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func TestMain(m *testing.M) {
 | 
			
		||||
	client.Init()
 | 
			
		||||
	m.Run()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestDecode(t *testing.T) {
 | 
			
		||||
	// Bytes read: varies sometimes
 | 
			
		||||
	// jpeg: 1063892
 | 
			
		||||
 
 | 
			
		||||
@@ -24,7 +24,7 @@ func printHelp() {
 | 
			
		||||
	fmt.Println("Usage: one-api [--port <port>] [--log-dir <log directory>] [--version] [--help]")
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func init() {
 | 
			
		||||
func Init() {
 | 
			
		||||
	flag.Parse()
 | 
			
		||||
 | 
			
		||||
	if *PrintVersion {
 | 
			
		||||
 
 | 
			
		||||
@@ -27,7 +27,12 @@ var setupLogOnce sync.Once
 | 
			
		||||
func SetupLogger() {
 | 
			
		||||
	setupLogOnce.Do(func() {
 | 
			
		||||
		if LogDir != "" {
 | 
			
		||||
			logPath := filepath.Join(LogDir, fmt.Sprintf("oneapi-%s.log", time.Now().Format("20060102")))
 | 
			
		||||
			var logPath string
 | 
			
		||||
			if config.OnlyOneLogFile {
 | 
			
		||||
				logPath = filepath.Join(LogDir, "oneapi.log")
 | 
			
		||||
			} else {
 | 
			
		||||
				logPath = filepath.Join(LogDir, fmt.Sprintf("oneapi-%s.log", time.Now().Format("20060102")))
 | 
			
		||||
			}
 | 
			
		||||
			fd, err := os.OpenFile(logPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				log.Fatal("failed to open log file")
 | 
			
		||||
 
 | 
			
		||||
@@ -6,11 +6,16 @@ import (
 | 
			
		||||
	"encoding/base64"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"github.com/songquanpeng/one-api/common/config"
 | 
			
		||||
	"net"
 | 
			
		||||
	"net/smtp"
 | 
			
		||||
	"strings"
 | 
			
		||||
	"time"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func shouldAuth() bool {
 | 
			
		||||
	return config.SMTPAccount != "" || config.SMTPToken != ""
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func SendEmail(subject string, receiver string, content string) error {
 | 
			
		||||
	if receiver == "" {
 | 
			
		||||
		return fmt.Errorf("receiver is empty")
 | 
			
		||||
@@ -41,16 +46,24 @@ func SendEmail(subject string, receiver string, content string) error {
 | 
			
		||||
		"Date: %s\r\n"+
 | 
			
		||||
		"Content-Type: text/html; charset=UTF-8\r\n\r\n%s\r\n",
 | 
			
		||||
		receiver, config.SystemName, config.SMTPFrom, encodedSubject, messageId, time.Now().Format(time.RFC1123Z), content))
 | 
			
		||||
 | 
			
		||||
	auth := smtp.PlainAuth("", config.SMTPAccount, config.SMTPToken, config.SMTPServer)
 | 
			
		||||
	addr := fmt.Sprintf("%s:%d", config.SMTPServer, config.SMTPPort)
 | 
			
		||||
	to := strings.Split(receiver, ";")
 | 
			
		||||
 | 
			
		||||
	if config.SMTPPort == 465 {
 | 
			
		||||
		tlsConfig := &tls.Config{
 | 
			
		||||
			InsecureSkipVerify: true,
 | 
			
		||||
			ServerName:         config.SMTPServer,
 | 
			
		||||
	if config.SMTPPort == 465 || !shouldAuth() {
 | 
			
		||||
		// need advanced client
 | 
			
		||||
		var conn net.Conn
 | 
			
		||||
		var err error
 | 
			
		||||
		if config.SMTPPort == 465 {
 | 
			
		||||
			tlsConfig := &tls.Config{
 | 
			
		||||
				InsecureSkipVerify: true,
 | 
			
		||||
				ServerName:         config.SMTPServer,
 | 
			
		||||
			}
 | 
			
		||||
			conn, err = tls.Dial("tcp", fmt.Sprintf("%s:%d", config.SMTPServer, config.SMTPPort), tlsConfig)
 | 
			
		||||
		} else {
 | 
			
		||||
			conn, err = net.Dial("tcp", fmt.Sprintf("%s:%d", config.SMTPServer, config.SMTPPort))
 | 
			
		||||
		}
 | 
			
		||||
		conn, err := tls.Dial("tcp", fmt.Sprintf("%s:%d", config.SMTPServer, config.SMTPPort), tlsConfig)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return err
 | 
			
		||||
		}
 | 
			
		||||
@@ -59,8 +72,10 @@ func SendEmail(subject string, receiver string, content string) error {
 | 
			
		||||
			return err
 | 
			
		||||
		}
 | 
			
		||||
		defer client.Close()
 | 
			
		||||
		if err = client.Auth(auth); err != nil {
 | 
			
		||||
			return err
 | 
			
		||||
		if shouldAuth() {
 | 
			
		||||
			if err = client.Auth(auth); err != nil {
 | 
			
		||||
				return err
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
		if err = client.Mail(config.SMTPFrom); err != nil {
 | 
			
		||||
			return err
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										29
									
								
								common/render/render.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										29
									
								
								common/render/render.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,29 @@
 | 
			
		||||
package render
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"encoding/json"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
	"github.com/songquanpeng/one-api/common"
 | 
			
		||||
	"strings"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func StringData(c *gin.Context, str string) {
 | 
			
		||||
	str = strings.TrimPrefix(str, "data: ")
 | 
			
		||||
	str = strings.TrimSuffix(str, "\r")
 | 
			
		||||
	c.Render(-1, common.CustomEvent{Data: "data: " + str})
 | 
			
		||||
	c.Writer.Flush()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func ObjectData(c *gin.Context, object interface{}) error {
 | 
			
		||||
	jsonData, err := json.Marshal(object)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return fmt.Errorf("error marshalling object: %w", err)
 | 
			
		||||
	}
 | 
			
		||||
	StringData(c, string(jsonData))
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func Done(c *gin.Context) {
 | 
			
		||||
	StringData(c, "[DONE]")
 | 
			
		||||
}
 | 
			
		||||
@@ -48,7 +48,7 @@ func Relay(c *gin.Context) {
 | 
			
		||||
		logger.Debugf(ctx, "request body: %s", string(requestBody))
 | 
			
		||||
	}
 | 
			
		||||
	channelId := c.GetInt(ctxkey.ChannelId)
 | 
			
		||||
	userId := c.GetInt("id")
 | 
			
		||||
	userId := c.GetInt(ctxkey.Id)
 | 
			
		||||
	bizErr := relayHelper(c, relayMode)
 | 
			
		||||
	if bizErr == nil {
 | 
			
		||||
		monitor.Emit(channelId, true)
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										5
									
								
								go.mod
									
									
									
									
									
								
							
							
						
						
									
										5
									
								
								go.mod
									
									
									
									
									
								
							@@ -18,12 +18,13 @@ require (
 | 
			
		||||
	github.com/google/uuid v1.6.0
 | 
			
		||||
	github.com/gorilla/websocket v1.5.1
 | 
			
		||||
	github.com/jinzhu/copier v0.4.0
 | 
			
		||||
	github.com/joho/godotenv v1.5.1
 | 
			
		||||
	github.com/pkg/errors v0.9.1
 | 
			
		||||
	github.com/pkoukk/tiktoken-go v0.1.7
 | 
			
		||||
	github.com/smartystreets/goconvey v1.8.1
 | 
			
		||||
	github.com/stretchr/testify v1.9.0
 | 
			
		||||
	golang.org/x/crypto v0.23.0
 | 
			
		||||
	golang.org/x/image v0.16.0
 | 
			
		||||
	golang.org/x/image v0.18.0
 | 
			
		||||
	gorm.io/driver/mysql v1.5.6
 | 
			
		||||
	gorm.io/driver/postgres v1.5.7
 | 
			
		||||
	gorm.io/driver/sqlite v1.5.5
 | 
			
		||||
@@ -79,7 +80,7 @@ require (
 | 
			
		||||
	golang.org/x/net v0.25.0 // indirect
 | 
			
		||||
	golang.org/x/sync v0.7.0 // indirect
 | 
			
		||||
	golang.org/x/sys v0.20.0 // indirect
 | 
			
		||||
	golang.org/x/text v0.15.0 // indirect
 | 
			
		||||
	golang.org/x/text v0.16.0 // indirect
 | 
			
		||||
	google.golang.org/protobuf v1.34.1 // indirect
 | 
			
		||||
	gopkg.in/yaml.v3 v3.0.1 // indirect
 | 
			
		||||
)
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										59
									
								
								go.sum
									
									
									
									
									
								
							
							
						
						
									
										59
									
								
								go.sum
									
									
									
									
									
								
							@@ -1,40 +1,25 @@
 | 
			
		||||
filippo.io/edwards25519 v1.1.0 h1:FNf4tywRC1HmFuKW5xopWpigGjJKiJSV0Cqo0cJWDaA=
 | 
			
		||||
filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4=
 | 
			
		||||
github.com/aws/aws-sdk-go-v2 v1.26.1 h1:5554eUqIYVWpU0YmeeYZ0wU64H2VLBs8TlhRB2L+EkA=
 | 
			
		||||
github.com/aws/aws-sdk-go-v2 v1.26.1/go.mod h1:ffIFB97e2yNsv4aTSGkqtHnppsIJzw7G7BReUZ3jCXM=
 | 
			
		||||
github.com/aws/aws-sdk-go-v2 v1.27.0 h1:7bZWKoXhzI+mMR/HjdMx8ZCC5+6fY0lS5tr0bbgiLlo=
 | 
			
		||||
github.com/aws/aws-sdk-go-v2 v1.27.0/go.mod h1:ffIFB97e2yNsv4aTSGkqtHnppsIJzw7G7BReUZ3jCXM=
 | 
			
		||||
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.2 h1:x6xsQXGSmW6frevwDA+vi/wqhp1ct18mVXYN08/93to=
 | 
			
		||||
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.2/go.mod h1:lPprDr1e6cJdyYeGXnRaJoP4Md+cDBvi2eOj00BlGmg=
 | 
			
		||||
github.com/aws/aws-sdk-go-v2/credentials v1.17.11 h1:YuIB1dJNf1Re822rriUOTxopaHHvIq0l/pX3fwO+Tzs=
 | 
			
		||||
github.com/aws/aws-sdk-go-v2/credentials v1.17.11/go.mod h1:AQtFPsDH9bI2O+71anW6EKL+NcD7LG3dpKGMV4SShgo=
 | 
			
		||||
github.com/aws/aws-sdk-go-v2/credentials v1.17.15 h1:YDexlvDRCA8ems2T5IP1xkMtOZ1uLJOCJdTr0igs5zo=
 | 
			
		||||
github.com/aws/aws-sdk-go-v2/credentials v1.17.15/go.mod h1:vxHggqW6hFNaeNC0WyXS3VdyjcV0a4KMUY4dKJ96buU=
 | 
			
		||||
github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.5 h1:aw39xVGeRWlWx9EzGVnhOR4yOjQDHPQ6o6NmBlscyQg=
 | 
			
		||||
github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.5/go.mod h1:FSaRudD0dXiMPK2UjknVwwTYyZMRsHv3TtkabsZih5I=
 | 
			
		||||
github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.7 h1:lf/8VTF2cM+N4SLzaYJERKEWAXq8MOMpZfU6wEPWsPk=
 | 
			
		||||
github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.7/go.mod h1:4SjkU7QiqK2M9oozyMzfZ/23LmUY+h3oFqhdeP5OMiI=
 | 
			
		||||
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.5 h1:PG1F3OD1szkuQPzDw3CIQsRIrtTlUC3lP84taWzHlq0=
 | 
			
		||||
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.5/go.mod h1:jU1li6RFryMz+so64PpKtudI+QzbKoIEivqdf6LNpOc=
 | 
			
		||||
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.7 h1:4OYVp0705xu8yjdyoWix0r9wPIRXnIzzOoUpQVHIJ/g=
 | 
			
		||||
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.7/go.mod h1:vd7ESTEvI76T2Na050gODNmNU7+OyKrIKroYTu4ABiI=
 | 
			
		||||
github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.7.4 h1:JgHnonzbnA3pbqj76wYsSZIZZQYBxkmMEjvL6GHy8XU=
 | 
			
		||||
github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.7.4/go.mod h1:nZspkhg+9p8iApLFoyAqfyuMP0F38acy2Hm3r5r95Cg=
 | 
			
		||||
github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.8.3 h1:Fihjyd6DeNjcawBEGLH9dkIEUi6AdhucDKPE9nJ4QiY=
 | 
			
		||||
github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.8.3/go.mod h1:opvUj3ismqSCxYc+m4WIjPL0ewZGtvp0ess7cKvBPOQ=
 | 
			
		||||
github.com/aws/smithy-go v1.20.2 h1:tbp628ireGtzcHDDmLT/6ADHidqnwgF57XOXZe6tp4Q=
 | 
			
		||||
github.com/aws/smithy-go v1.20.2/go.mod h1:krry+ya/rV9RDcV/Q16kpu6ypI4K2czasz0NC3qS14E=
 | 
			
		||||
github.com/bytedance/sonic v1.11.5 h1:G00FYjjqll5iQ1PYXynbg/hyzqBqavH8Mo9/oTopd9k=
 | 
			
		||||
github.com/bytedance/sonic v1.11.5/go.mod h1:X2PC2giUdj/Cv2lliWFLk6c/DUQok5rViJSemeB0wDw=
 | 
			
		||||
github.com/bytedance/sonic v1.11.6 h1:oUp34TzMlL+OY1OUWxHqsdkgC/Zfc85zGqw9siXjrc0=
 | 
			
		||||
github.com/bytedance/sonic v1.11.6/go.mod h1:LysEHSvpvDySVdC2f87zGWf6CIKJcAvqab1ZaiQtds4=
 | 
			
		||||
github.com/bytedance/sonic/loader v0.1.0/go.mod h1:UmRT+IRTGKz/DAkzcEGzyVqQFJ7H9BqwBO3pm9H/+HY=
 | 
			
		||||
github.com/bytedance/sonic/loader v0.1.1 h1:c+e5Pt1k/cy5wMveRDyk2X4B9hF4g7an8N3zCYjJFNM=
 | 
			
		||||
github.com/bytedance/sonic/loader v0.1.1/go.mod h1:ncP89zfokxS5LZrJxl5z0UJcsk4M4yY2JpfqGeCtNLU=
 | 
			
		||||
github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs=
 | 
			
		||||
github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
 | 
			
		||||
github.com/cloudwego/base64x v0.1.3 h1:b5J/l8xolB7dyDTTmhJP2oTs5LdrjyrUFuNxdfq5hAg=
 | 
			
		||||
github.com/cloudwego/base64x v0.1.3/go.mod h1:1+1K5BUHIQzyapgpF7LwvOGAEDicKtt1umPV+aN8pi8=
 | 
			
		||||
github.com/cloudwego/base64x v0.1.4 h1:jwCgWpFanWmN8xoIUHa2rtzmkd5J2plF/dnLS6Xd/0Y=
 | 
			
		||||
github.com/cloudwego/base64x v0.1.4/go.mod h1:0zlkT4Wn5C6NdauXdJRhSKRlJvmclQ1hhJgA0rcu/8w=
 | 
			
		||||
github.com/cloudwego/iasm v0.2.0 h1:1KNIy1I1H9hNNFEEH3DVnI4UujN+1zjpuk6gwHLTssg=
 | 
			
		||||
@@ -51,26 +36,16 @@ github.com/fsnotify/fsnotify v1.7.0 h1:8JEhPFa5W2WU7YfeZzPNqzMP6Lwt7L2715Ggo0nos
 | 
			
		||||
github.com/fsnotify/fsnotify v1.7.0/go.mod h1:40Bi/Hjc2AVfZrqy+aj+yEI+/bRxZnMJyTJwOpGvigM=
 | 
			
		||||
github.com/gabriel-vasile/mimetype v1.4.3 h1:in2uUcidCuFcDKtdcBxlR0rJ1+fsokWf+uqxgUFjbI0=
 | 
			
		||||
github.com/gabriel-vasile/mimetype v1.4.3/go.mod h1:d8uq/6HKRL6CGdk+aubisF/M5GcPfT7nKyLpA0lbSSk=
 | 
			
		||||
github.com/gin-contrib/cors v1.7.1 h1:s9SIppU/rk8enVvkzwiC2VK3UZ/0NNGsWfUKvV55rqs=
 | 
			
		||||
github.com/gin-contrib/cors v1.7.1/go.mod h1:n/Zj7B4xyrgk/cX1WCX2dkzFfaNm/xJb6oIUk7WTtps=
 | 
			
		||||
github.com/gin-contrib/cors v1.7.2 h1:oLDHxdg8W/XDoN/8zamqk/Drgt4oVZDvaV0YmvVICQw=
 | 
			
		||||
github.com/gin-contrib/cors v1.7.2/go.mod h1:SUJVARKgQ40dmrzgXEVxj2m7Ig1v1qIboQkPDTQ9t2E=
 | 
			
		||||
github.com/gin-contrib/gzip v1.0.0 h1:UKN586Po/92IDX6ie5CWLgMI81obiIp5nSP85T3wlTk=
 | 
			
		||||
github.com/gin-contrib/gzip v1.0.0/go.mod h1:CtG7tQrPB3vIBo6Gat9FVUsis+1emjvQqd66ME5TdnE=
 | 
			
		||||
github.com/gin-contrib/gzip v1.0.1 h1:HQ8ENHODeLY7a4g1Au/46Z92bdGFl74OhxcZble9WJE=
 | 
			
		||||
github.com/gin-contrib/gzip v1.0.1/go.mod h1:njt428fdUNRvjuJf16tZMYZ2Yl+WQB53X5wmhDwXvC4=
 | 
			
		||||
github.com/gin-contrib/sessions v1.0.0 h1:r5GLta4Oy5xo9rAwMHx8B4wLpeRGHMdz9NafzJAdP8Y=
 | 
			
		||||
github.com/gin-contrib/sessions v1.0.0/go.mod h1:DN0f4bvpqMQElDdi+gNGScrP2QEI04IErRyMFyorUOI=
 | 
			
		||||
github.com/gin-contrib/sessions v1.0.1 h1:3hsJyNs7v7N8OtelFmYXFrulAf6zSR7nW/putcPEHxI=
 | 
			
		||||
github.com/gin-contrib/sessions v1.0.1/go.mod h1:ouxSFM24/OgIud5MJYQJLpy6AwxQ5EYO9yLhbtObGkM=
 | 
			
		||||
github.com/gin-contrib/sse v0.1.0 h1:Y/yl/+YNO8GZSjAhjMsSuLt29uWRFHdHYUb5lYOV9qE=
 | 
			
		||||
github.com/gin-contrib/sse v0.1.0/go.mod h1:RHrZQHXnP2xjPF+u1gW/2HnVO7nvIa9PG3Gm+fLHvGI=
 | 
			
		||||
github.com/gin-contrib/static v1.1.1 h1:XEvBd4DDLG1HBlyPBQU1XO8NlTpw6mgdqcPteetYA5k=
 | 
			
		||||
github.com/gin-contrib/static v1.1.1/go.mod h1:yRGmar7+JYvbMLRPIi4H5TVVSBwULfT9vetnVD0IO74=
 | 
			
		||||
github.com/gin-contrib/static v1.1.2 h1:c3kT4bFkUJn2aoRU3s6XnMjJT8J6nNWJkR0NglqmlZ4=
 | 
			
		||||
github.com/gin-contrib/static v1.1.2/go.mod h1:Fw90ozjHCmZBWbgrsqrDvO28YbhKEKzKp8GixhR4yLw=
 | 
			
		||||
github.com/gin-gonic/gin v1.9.1 h1:4idEAncQnU5cB7BeOkPtxjfCSye0AAm1R0RVIqJ+Jmg=
 | 
			
		||||
github.com/gin-gonic/gin v1.9.1/go.mod h1:hPrL7YrpYKXt5YId3A/Tnip5kqbEAP+KLuI3SUcPTeU=
 | 
			
		||||
github.com/gin-gonic/gin v1.10.0 h1:nTuyha1TYqgedzytsKYqna+DfLos46nTv2ygFy86HFU=
 | 
			
		||||
github.com/gin-gonic/gin v1.10.0/go.mod h1:4PMNQiOhvDRa013RKVbsiNwoyezlm2rm0uX/T7kzp5Y=
 | 
			
		||||
github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s=
 | 
			
		||||
@@ -78,8 +53,6 @@ github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/o
 | 
			
		||||
github.com/go-playground/locales v0.14.1/go.mod h1:hxrqLVvrK65+Rwrd5Fc6F2O76J/NuW9t0sjnWqG1slY=
 | 
			
		||||
github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJnYK9S473LQFuzCbDbfSFY=
 | 
			
		||||
github.com/go-playground/universal-translator v0.18.1/go.mod h1:xekY+UJKNuX9WP91TpwSH2VMlDf28Uj24BCp08ZFTUY=
 | 
			
		||||
github.com/go-playground/validator/v10 v10.19.0 h1:ol+5Fu+cSq9JD7SoSqe04GMI92cbn0+wvQ3bZ8b/AU4=
 | 
			
		||||
github.com/go-playground/validator/v10 v10.19.0/go.mod h1:dbuPbCMFw/DrkbEynArYaCwl3amGuJotoKCe95atGMM=
 | 
			
		||||
github.com/go-playground/validator/v10 v10.20.0 h1:K9ISHbSaI0lyB2eWMPJo+kOS/FBExVwjEviJTixqxL8=
 | 
			
		||||
github.com/go-playground/validator/v10 v10.20.0/go.mod h1:dbuPbCMFw/DrkbEynArYaCwl3amGuJotoKCe95atGMM=
 | 
			
		||||
github.com/go-redis/redis/v8 v8.11.5 h1:AcZZR7igkdvfVmQTPnu9WE37LRrO/YrBH5zWyjDC0oI=
 | 
			
		||||
@@ -87,8 +60,6 @@ github.com/go-redis/redis/v8 v8.11.5/go.mod h1:gREzHqY1hg6oD9ngVRbLStwAWKhA0FEgq
 | 
			
		||||
github.com/go-sql-driver/mysql v1.7.0/go.mod h1:OXbVy3sEdcQ2Doequ6Z5BW6fXNQTmx+9S1MCJN5yJMI=
 | 
			
		||||
github.com/go-sql-driver/mysql v1.8.1 h1:LedoTUt/eveggdHS9qUFC1EFSa8bU2+1pZjSRpvNJ1Y=
 | 
			
		||||
github.com/go-sql-driver/mysql v1.8.1/go.mod h1:wEBSXgmK//2ZFJyE+qWnIsVGmvmEKlqwuVSjsCm7DZg=
 | 
			
		||||
github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU=
 | 
			
		||||
github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I=
 | 
			
		||||
github.com/goccy/go-json v0.10.3 h1:KZ5WoDbxAIgm2HNbYckL0se1fHD6rz5j4ywS6ebzDqA=
 | 
			
		||||
github.com/goccy/go-json v0.10.3/go.mod h1:oq7eo15ShAhp70Anwd5lgX2pLfOS3QCiwU/PULtXL6M=
 | 
			
		||||
github.com/golang-jwt/jwt v3.2.2+incompatible h1:IfV12K8xAKAnZqdXVzCZ+TOjboZ2keLg81eXfW3O+oY=
 | 
			
		||||
@@ -122,6 +93,8 @@ github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD
 | 
			
		||||
github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc=
 | 
			
		||||
github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ=
 | 
			
		||||
github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8=
 | 
			
		||||
github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0=
 | 
			
		||||
github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4=
 | 
			
		||||
github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM=
 | 
			
		||||
github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo=
 | 
			
		||||
github.com/jtolds/gls v4.20.0+incompatible h1:xdiiI2gbIgH/gLH7ADydsJ1uDOEzR8yvV7C0MuV77Wo=
 | 
			
		||||
@@ -147,14 +120,10 @@ github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjY
 | 
			
		||||
github.com/nxadm/tail v1.4.8 h1:nPr65rt6Y5JFSKQO7qToXr7pePgD6Gwiw05lkbyAQTE=
 | 
			
		||||
github.com/onsi/ginkgo v1.16.5 h1:8xi0RTUf59SOSfEtZMvwTvXYMzG4gV23XVHOZiXNtnE=
 | 
			
		||||
github.com/onsi/gomega v1.18.1 h1:M1GfJqGRrBrrGGsbxzV5dqM2U2ApXefZCQpkukxYRLE=
 | 
			
		||||
github.com/pelletier/go-toml/v2 v2.2.1 h1:9TA9+T8+8CUCO2+WYnDLCgrYi9+omqKXyjDtosvtEhg=
 | 
			
		||||
github.com/pelletier/go-toml/v2 v2.2.1/go.mod h1:1t835xjRzz80PqgE6HHgN2JOsmgYu/h4qDAS4n929Rs=
 | 
			
		||||
github.com/pelletier/go-toml/v2 v2.2.2 h1:aYUidT7k73Pcl9nb2gScu7NSrKCSHIDE89b3+6Wq+LM=
 | 
			
		||||
github.com/pelletier/go-toml/v2 v2.2.2/go.mod h1:1t835xjRzz80PqgE6HHgN2JOsmgYu/h4qDAS4n929Rs=
 | 
			
		||||
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
 | 
			
		||||
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
 | 
			
		||||
github.com/pkoukk/tiktoken-go v0.1.6 h1:JF0TlJzhTbrI30wCvFuiw6FzP2+/bR+FIxUdgEAcUsw=
 | 
			
		||||
github.com/pkoukk/tiktoken-go v0.1.6/go.mod h1:9NiV+i9mJKGj1rYOT+njbv+ZwA/zJxYdewGl6qVatpg=
 | 
			
		||||
github.com/pkoukk/tiktoken-go v0.1.7 h1:qOBHXX4PHtvIvmOtyg1EeKlwFRiMKAcoMp4Q+bLQDmw=
 | 
			
		||||
github.com/pkoukk/tiktoken-go v0.1.7/go.mod h1:9NiV+i9mJKGj1rYOT+njbv+ZwA/zJxYdewGl6qVatpg=
 | 
			
		||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
 | 
			
		||||
@@ -181,37 +150,23 @@ github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2
 | 
			
		||||
github.com/ugorji/go/codec v1.2.12 h1:9LC83zGrHhuUA9l16C9AHXAqEV/2wBQ4nkvumAE65EE=
 | 
			
		||||
github.com/ugorji/go/codec v1.2.12/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZgYf6w6lg=
 | 
			
		||||
golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
 | 
			
		||||
golang.org/x/arch v0.7.0 h1:pskyeJh/3AmoQ8CPE95vxHLqp1G1GfGNXTmcl9NEKTc=
 | 
			
		||||
golang.org/x/arch v0.7.0/go.mod h1:FEVrYAQjsQXMVJ1nsMoVVXPZg6p2JE2mx8psSWTDQys=
 | 
			
		||||
golang.org/x/arch v0.8.0 h1:3wRIsP3pM4yUptoR96otTUOXI367OS0+c9eeRi9doIc=
 | 
			
		||||
golang.org/x/arch v0.8.0/go.mod h1:FEVrYAQjsQXMVJ1nsMoVVXPZg6p2JE2mx8psSWTDQys=
 | 
			
		||||
golang.org/x/crypto v0.22.0 h1:g1v0xeRhjcugydODzvb3mEM9SQ0HGp9s/nh3COQ/C30=
 | 
			
		||||
golang.org/x/crypto v0.22.0/go.mod h1:vr6Su+7cTlO45qkww3VDJlzDn0ctJvRgYbC2NvXHt+M=
 | 
			
		||||
golang.org/x/crypto v0.23.0 h1:dIJU/v2J8Mdglj/8rJ6UUOM3Zc9zLZxVZwwxMooUSAI=
 | 
			
		||||
golang.org/x/crypto v0.23.0/go.mod h1:CKFgDieR+mRhux2Lsu27y0fO304Db0wZe70UKqHu0v8=
 | 
			
		||||
golang.org/x/image v0.15.0 h1:kOELfmgrmJlw4Cdb7g/QGuB3CvDrXbqEIww/pNtNBm8=
 | 
			
		||||
golang.org/x/image v0.15.0/go.mod h1:HUYqC05R2ZcZ3ejNQsIHQDQiwWM4JBqmm6MKANTp4LE=
 | 
			
		||||
golang.org/x/image v0.16.0 h1:9kloLAKhUufZhA12l5fwnx2NZW39/we1UhBesW433jw=
 | 
			
		||||
golang.org/x/image v0.16.0/go.mod h1:ugSZItdV4nOxyqp56HmXwH0Ry0nBCpjnZdpDaIHdoPs=
 | 
			
		||||
golang.org/x/net v0.24.0 h1:1PcaxkF854Fu3+lvBIx5SYn9wRlBzzcnHZSiaFFAb0w=
 | 
			
		||||
golang.org/x/net v0.24.0/go.mod h1:2Q7sJY5mzlzWjKtYUEXSlBWCdyaioyXzRB2RtU8KVE8=
 | 
			
		||||
golang.org/x/image v0.18.0 h1:jGzIakQa/ZXI1I0Fxvaa9W7yP25TqT6cHIHn+6CqvSQ=
 | 
			
		||||
golang.org/x/image v0.18.0/go.mod h1:4yyo5vMFQjVjUcVk4jEQcU9MGy/rulF5WvUILseCM2E=
 | 
			
		||||
golang.org/x/net v0.25.0 h1:d/OCCoBEUq33pjydKrGQhw7IlUPI2Oylr+8qLx49kac=
 | 
			
		||||
golang.org/x/net v0.25.0/go.mod h1:JkAGAh7GEvH74S6FOH42FLoXpXbE/aqXSrIQjXgsiwM=
 | 
			
		||||
golang.org/x/sync v0.7.0 h1:YsImfSBoP9QPYL0xyKJPq0gcaJdG3rInoqxTWbfQu9M=
 | 
			
		||||
golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
 | 
			
		||||
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
 | 
			
		||||
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
 | 
			
		||||
golang.org/x/sys v0.19.0 h1:q5f1RH2jigJ1MoAWp2KTp3gm5zAGFUTarQZ5U386+4o=
 | 
			
		||||
golang.org/x/sys v0.19.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
 | 
			
		||||
golang.org/x/sys v0.20.0 h1:Od9JTbYCk261bKm4M/mw7AklTlFYIa0bIp9BgSm1S8Y=
 | 
			
		||||
golang.org/x/sys v0.20.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
 | 
			
		||||
golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ=
 | 
			
		||||
golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
 | 
			
		||||
golang.org/x/text v0.15.0 h1:h1V/4gjBv8v9cjcR6+AR5+/cIYK5N/WAgiv4xlsEtAk=
 | 
			
		||||
golang.org/x/text v0.15.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
 | 
			
		||||
golang.org/x/text v0.16.0 h1:a94ExnEXNtEwYLGJSIUxnWoxoRz/ZcCsV63ROupILh4=
 | 
			
		||||
golang.org/x/text v0.16.0/go.mod h1:GhwF1Be+LQoKShO3cGOHzqOgRrGaYc9AvblQOmPVHnI=
 | 
			
		||||
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 h1:go1bK/D/BFZV2I8cIQd1NKEZ+0owSTG1fDTci4IqFcE=
 | 
			
		||||
google.golang.org/protobuf v1.33.0 h1:uNO2rsAINq/JlFpSdYEKIZ0uKD/R9cpdv0T+yoGwGmI=
 | 
			
		||||
google.golang.org/protobuf v1.33.0/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos=
 | 
			
		||||
google.golang.org/protobuf v1.34.1 h1:9ddQBjfCyZPOHPUiPxpYESBLc+T8P3E+Vo4IbKZgFWg=
 | 
			
		||||
google.golang.org/protobuf v1.34.1/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos=
 | 
			
		||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
 | 
			
		||||
@@ -228,8 +183,6 @@ gorm.io/driver/postgres v1.5.7/go.mod h1:3e019WlBaYI5o5LIdNV+LyxCMNtLOQETBXL2h4c
 | 
			
		||||
gorm.io/driver/sqlite v1.5.5 h1:7MDMtUZhV065SilG62E0MquljeArQZNfJnjd9i9gx3E=
 | 
			
		||||
gorm.io/driver/sqlite v1.5.5/go.mod h1:6NgQ7sQWAIFsPrJJl1lSNSu2TABh0ZZ/zm5fosATavE=
 | 
			
		||||
gorm.io/gorm v1.25.7/go.mod h1:hbnx/Oo0ChWMn1BIhpy1oYozzpM15i4YPuHDmfYtwg8=
 | 
			
		||||
gorm.io/gorm v1.25.9 h1:wct0gxZIELDk8+ZqF/MVnHLkA1rvYlBWUMv2EdsK1g8=
 | 
			
		||||
gorm.io/gorm v1.25.9/go.mod h1:hbnx/Oo0ChWMn1BIhpy1oYozzpM15i4YPuHDmfYtwg8=
 | 
			
		||||
gorm.io/gorm v1.25.10 h1:dQpO+33KalOA+aFYGlK+EfxcI5MbO7EP2yYygwh9h+s=
 | 
			
		||||
gorm.io/gorm v1.25.10/go.mod h1:hbnx/Oo0ChWMn1BIhpy1oYozzpM15i4YPuHDmfYtwg8=
 | 
			
		||||
nullprogram.com/x/optparse v1.0.0/go.mod h1:KdyPE+Igbe0jQUrVfMqDMeJQIJZEuyV7pjYmp6pbG50=
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										25
									
								
								main.go
									
									
									
									
									
								
							
							
						
						
									
										25
									
								
								main.go
									
									
									
									
									
								
							@@ -6,6 +6,7 @@ import (
 | 
			
		||||
	"github.com/gin-contrib/sessions"
 | 
			
		||||
	"github.com/gin-contrib/sessions/cookie"
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
	_ "github.com/joho/godotenv/autoload"
 | 
			
		||||
	"github.com/songquanpeng/one-api/common"
 | 
			
		||||
	"github.com/songquanpeng/one-api/common/client"
 | 
			
		||||
	"github.com/songquanpeng/one-api/common/config"
 | 
			
		||||
@@ -23,29 +24,22 @@ import (
 | 
			
		||||
var buildFS embed.FS
 | 
			
		||||
 | 
			
		||||
func main() {
 | 
			
		||||
	common.Init()
 | 
			
		||||
	logger.SetupLogger()
 | 
			
		||||
	logger.SysLogf("One API %s started", common.Version)
 | 
			
		||||
	if os.Getenv("GIN_MODE") != "debug" {
 | 
			
		||||
 | 
			
		||||
	if os.Getenv("GIN_MODE") != gin.DebugMode {
 | 
			
		||||
		gin.SetMode(gin.ReleaseMode)
 | 
			
		||||
	}
 | 
			
		||||
	if config.DebugEnabled {
 | 
			
		||||
		logger.SysLog("running in debug mode")
 | 
			
		||||
	}
 | 
			
		||||
	var err error
 | 
			
		||||
 | 
			
		||||
	// Initialize SQL Database
 | 
			
		||||
	model.DB, err = model.InitDB("SQL_DSN")
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		logger.FatalLog("failed to initialize database: " + err.Error())
 | 
			
		||||
	}
 | 
			
		||||
	if os.Getenv("LOG_SQL_DSN") != "" {
 | 
			
		||||
		logger.SysLog("using secondary database for table logs")
 | 
			
		||||
		model.LOG_DB, err = model.InitDB("LOG_SQL_DSN")
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			logger.FatalLog("failed to initialize secondary database: " + err.Error())
 | 
			
		||||
		}
 | 
			
		||||
	} else {
 | 
			
		||||
		model.LOG_DB = model.DB
 | 
			
		||||
	}
 | 
			
		||||
	model.InitDB()
 | 
			
		||||
	model.InitLogDB()
 | 
			
		||||
 | 
			
		||||
	var err error
 | 
			
		||||
	err = model.CreateRootAccountIfNeed()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		logger.FatalLog("database init error: " + err.Error())
 | 
			
		||||
@@ -113,6 +107,7 @@ func main() {
 | 
			
		||||
	if port == "" {
 | 
			
		||||
		port = strconv.Itoa(*common.Port)
 | 
			
		||||
	}
 | 
			
		||||
	logger.SysLogf("server started on http://localhost:%s", port)
 | 
			
		||||
	err = server.Run(":" + port)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		logger.FatalLog("failed to start HTTP server: " + err.Error())
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										219
									
								
								model/main.go
									
									
									
									
									
								
							
							
						
						
									
										219
									
								
								model/main.go
									
									
									
									
									
								
							@@ -1,6 +1,7 @@
 | 
			
		||||
package model
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"database/sql"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"github.com/songquanpeng/one-api/common"
 | 
			
		||||
	"github.com/songquanpeng/one-api/common/config"
 | 
			
		||||
@@ -60,90 +61,156 @@ func CreateRootAccountIfNeed() error {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func chooseDB(envName string) (*gorm.DB, error) {
 | 
			
		||||
	if os.Getenv(envName) != "" {
 | 
			
		||||
		dsn := os.Getenv(envName)
 | 
			
		||||
		if strings.HasPrefix(dsn, "postgres://") {
 | 
			
		||||
			// Use PostgreSQL
 | 
			
		||||
			logger.SysLog("using PostgreSQL as database")
 | 
			
		||||
			common.UsingPostgreSQL = true
 | 
			
		||||
			return gorm.Open(postgres.New(postgres.Config{
 | 
			
		||||
				DSN:                  dsn,
 | 
			
		||||
				PreferSimpleProtocol: true, // disables implicit prepared statement usage
 | 
			
		||||
			}), &gorm.Config{
 | 
			
		||||
				PrepareStmt: true, // precompile SQL
 | 
			
		||||
			})
 | 
			
		||||
		}
 | 
			
		||||
	dsn := os.Getenv(envName)
 | 
			
		||||
 | 
			
		||||
	switch {
 | 
			
		||||
	case strings.HasPrefix(dsn, "postgres://"):
 | 
			
		||||
		// Use PostgreSQL
 | 
			
		||||
		return openPostgreSQL(dsn)
 | 
			
		||||
	case dsn != "":
 | 
			
		||||
		// Use MySQL
 | 
			
		||||
		logger.SysLog("using MySQL as database")
 | 
			
		||||
		common.UsingMySQL = true
 | 
			
		||||
		return gorm.Open(mysql.Open(dsn), &gorm.Config{
 | 
			
		||||
			PrepareStmt: true, // precompile SQL
 | 
			
		||||
		})
 | 
			
		||||
		return openMySQL(dsn)
 | 
			
		||||
	default:
 | 
			
		||||
		// Use SQLite
 | 
			
		||||
		return openSQLite()
 | 
			
		||||
	}
 | 
			
		||||
	// Use SQLite
 | 
			
		||||
	logger.SysLog("SQL_DSN not set, using SQLite as database")
 | 
			
		||||
	common.UsingSQLite = true
 | 
			
		||||
	config := fmt.Sprintf("?_busy_timeout=%d", common.SQLiteBusyTimeout)
 | 
			
		||||
	return gorm.Open(sqlite.Open(common.SQLitePath+config), &gorm.Config{
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func openPostgreSQL(dsn string) (*gorm.DB, error) {
 | 
			
		||||
	logger.SysLog("using PostgreSQL as database")
 | 
			
		||||
	common.UsingPostgreSQL = true
 | 
			
		||||
	return gorm.Open(postgres.New(postgres.Config{
 | 
			
		||||
		DSN:                  dsn,
 | 
			
		||||
		PreferSimpleProtocol: true, // disables implicit prepared statement usage
 | 
			
		||||
	}), &gorm.Config{
 | 
			
		||||
		PrepareStmt: true, // precompile SQL
 | 
			
		||||
	})
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func InitDB(envName string) (db *gorm.DB, err error) {
 | 
			
		||||
	db, err = chooseDB(envName)
 | 
			
		||||
	if err == nil {
 | 
			
		||||
		if config.DebugSQLEnabled {
 | 
			
		||||
			db = db.Debug()
 | 
			
		||||
		}
 | 
			
		||||
		sqlDB, err := db.DB()
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return nil, err
 | 
			
		||||
		}
 | 
			
		||||
		sqlDB.SetMaxIdleConns(env.Int("SQL_MAX_IDLE_CONNS", 100))
 | 
			
		||||
		sqlDB.SetMaxOpenConns(env.Int("SQL_MAX_OPEN_CONNS", 1000))
 | 
			
		||||
		sqlDB.SetConnMaxLifetime(time.Second * time.Duration(env.Int("SQL_MAX_LIFETIME", 60)))
 | 
			
		||||
func openMySQL(dsn string) (*gorm.DB, error) {
 | 
			
		||||
	logger.SysLog("using MySQL as database")
 | 
			
		||||
	common.UsingMySQL = true
 | 
			
		||||
	return gorm.Open(mysql.Open(dsn), &gorm.Config{
 | 
			
		||||
		PrepareStmt: true, // precompile SQL
 | 
			
		||||
	})
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
		if !config.IsMasterNode {
 | 
			
		||||
			return db, err
 | 
			
		||||
		}
 | 
			
		||||
		if common.UsingMySQL {
 | 
			
		||||
			_, _ = sqlDB.Exec("DROP INDEX idx_channels_key ON channels;") // TODO: delete this line when most users have upgraded
 | 
			
		||||
		}
 | 
			
		||||
		logger.SysLog("database migration started")
 | 
			
		||||
		err = db.AutoMigrate(&Channel{})
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return nil, err
 | 
			
		||||
		}
 | 
			
		||||
		err = db.AutoMigrate(&Token{})
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return nil, err
 | 
			
		||||
		}
 | 
			
		||||
		err = db.AutoMigrate(&User{})
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return nil, err
 | 
			
		||||
		}
 | 
			
		||||
		err = db.AutoMigrate(&Option{})
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return nil, err
 | 
			
		||||
		}
 | 
			
		||||
		err = db.AutoMigrate(&Redemption{})
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return nil, err
 | 
			
		||||
		}
 | 
			
		||||
		err = db.AutoMigrate(&Ability{})
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return nil, err
 | 
			
		||||
		}
 | 
			
		||||
		err = db.AutoMigrate(&Log{})
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return nil, err
 | 
			
		||||
		}
 | 
			
		||||
		logger.SysLog("database migrated")
 | 
			
		||||
		return db, err
 | 
			
		||||
	} else {
 | 
			
		||||
		logger.FatalLog(err)
 | 
			
		||||
func openSQLite() (*gorm.DB, error) {
 | 
			
		||||
	logger.SysLog("SQL_DSN not set, using SQLite as database")
 | 
			
		||||
	common.UsingSQLite = true
 | 
			
		||||
	dsn := fmt.Sprintf("%s?_busy_timeout=%d", common.SQLitePath, common.SQLiteBusyTimeout)
 | 
			
		||||
	return gorm.Open(sqlite.Open(dsn), &gorm.Config{
 | 
			
		||||
		PrepareStmt: true, // precompile SQL
 | 
			
		||||
	})
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func InitDB() {
 | 
			
		||||
	var err error
 | 
			
		||||
	DB, err = chooseDB("SQL_DSN")
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		logger.FatalLog("failed to initialize database: " + err.Error())
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	return db, err
 | 
			
		||||
 | 
			
		||||
	sqlDB := setDBConns(DB)
 | 
			
		||||
 | 
			
		||||
	if !config.IsMasterNode {
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if common.UsingMySQL {
 | 
			
		||||
		_, _ = sqlDB.Exec("DROP INDEX idx_channels_key ON channels;") // TODO: delete this line when most users have upgraded
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	logger.SysLog("database migration started")
 | 
			
		||||
	if err = migrateDB(); err != nil {
 | 
			
		||||
		logger.FatalLog("failed to migrate database: " + err.Error())
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	logger.SysLog("database migrated")
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func migrateDB() error {
 | 
			
		||||
	var err error
 | 
			
		||||
	if err = DB.AutoMigrate(&Channel{}); err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
	if err = DB.AutoMigrate(&Token{}); err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
	if err = DB.AutoMigrate(&User{}); err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
	if err = DB.AutoMigrate(&Option{}); err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
	if err = DB.AutoMigrate(&Redemption{}); err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
	if err = DB.AutoMigrate(&Ability{}); err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
	if err = DB.AutoMigrate(&Log{}); err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
	if err = DB.AutoMigrate(&Channel{}); err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func InitLogDB() {
 | 
			
		||||
	if os.Getenv("LOG_SQL_DSN") == "" {
 | 
			
		||||
		LOG_DB = DB
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	logger.SysLog("using secondary database for table logs")
 | 
			
		||||
	var err error
 | 
			
		||||
	LOG_DB, err = chooseDB("LOG_SQL_DSN")
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		logger.FatalLog("failed to initialize secondary database: " + err.Error())
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	setDBConns(LOG_DB)
 | 
			
		||||
 | 
			
		||||
	if !config.IsMasterNode {
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	logger.SysLog("secondary database migration started")
 | 
			
		||||
	err = migrateLOGDB()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		logger.FatalLog("failed to migrate secondary database: " + err.Error())
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	logger.SysLog("secondary database migrated")
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func migrateLOGDB() error {
 | 
			
		||||
	var err error
 | 
			
		||||
	if err = LOG_DB.AutoMigrate(&Log{}); err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func setDBConns(db *gorm.DB) *sql.DB {
 | 
			
		||||
	if config.DebugSQLEnabled {
 | 
			
		||||
		db = db.Debug()
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	sqlDB, err := db.DB()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		logger.FatalLog("failed to connect database: " + err.Error())
 | 
			
		||||
		return nil
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	sqlDB.SetMaxIdleConns(env.Int("SQL_MAX_IDLE_CONNS", 100))
 | 
			
		||||
	sqlDB.SetMaxOpenConns(env.Int("SQL_MAX_OPEN_CONNS", 1000))
 | 
			
		||||
	sqlDB.SetConnMaxLifetime(time.Second * time.Duration(env.Int("SQL_MAX_LIFETIME", 60)))
 | 
			
		||||
	return sqlDB
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func closeDB(db *gorm.DB) error {
 | 
			
		||||
 
 | 
			
		||||
@@ -4,6 +4,12 @@ import (
 | 
			
		||||
	"bufio"
 | 
			
		||||
	"encoding/json"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"github.com/songquanpeng/one-api/common/render"
 | 
			
		||||
	"io"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"strconv"
 | 
			
		||||
	"strings"
 | 
			
		||||
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
	"github.com/songquanpeng/one-api/common"
 | 
			
		||||
	"github.com/songquanpeng/one-api/common/helper"
 | 
			
		||||
@@ -12,10 +18,6 @@ import (
 | 
			
		||||
	"github.com/songquanpeng/one-api/relay/adaptor/openai"
 | 
			
		||||
	"github.com/songquanpeng/one-api/relay/constant"
 | 
			
		||||
	"github.com/songquanpeng/one-api/relay/model"
 | 
			
		||||
	"io"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"strconv"
 | 
			
		||||
	"strings"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// https://docs.aiproxy.io/dev/library#使用已经定制好的知识库进行对话问答
 | 
			
		||||
@@ -89,6 +91,7 @@ func streamResponseAIProxyLibrary2OpenAI(response *LibraryStreamResponse) *opena
 | 
			
		||||
 | 
			
		||||
func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
 | 
			
		||||
	var usage model.Usage
 | 
			
		||||
	var documents []LibraryDocument
 | 
			
		||||
	scanner := bufio.NewScanner(resp.Body)
 | 
			
		||||
	scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
 | 
			
		||||
		if atEOF && len(data) == 0 {
 | 
			
		||||
@@ -102,60 +105,48 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC
 | 
			
		||||
		}
 | 
			
		||||
		return 0, nil, nil
 | 
			
		||||
	})
 | 
			
		||||
	dataChan := make(chan string)
 | 
			
		||||
	stopChan := make(chan bool)
 | 
			
		||||
	go func() {
 | 
			
		||||
		for scanner.Scan() {
 | 
			
		||||
			data := scanner.Text()
 | 
			
		||||
			if len(data) < 5 { // ignore blank line or wrong format
 | 
			
		||||
				continue
 | 
			
		||||
			}
 | 
			
		||||
			if data[:5] != "data:" {
 | 
			
		||||
				continue
 | 
			
		||||
			}
 | 
			
		||||
			data = data[5:]
 | 
			
		||||
			dataChan <- data
 | 
			
		||||
		}
 | 
			
		||||
		stopChan <- true
 | 
			
		||||
	}()
 | 
			
		||||
 | 
			
		||||
	common.SetEventStreamHeaders(c)
 | 
			
		||||
	var documents []LibraryDocument
 | 
			
		||||
	c.Stream(func(w io.Writer) bool {
 | 
			
		||||
		select {
 | 
			
		||||
		case data := <-dataChan:
 | 
			
		||||
			var AIProxyLibraryResponse LibraryStreamResponse
 | 
			
		||||
			err := json.Unmarshal([]byte(data), &AIProxyLibraryResponse)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				logger.SysError("error unmarshalling stream response: " + err.Error())
 | 
			
		||||
				return true
 | 
			
		||||
			}
 | 
			
		||||
			if len(AIProxyLibraryResponse.Documents) != 0 {
 | 
			
		||||
				documents = AIProxyLibraryResponse.Documents
 | 
			
		||||
			}
 | 
			
		||||
			response := streamResponseAIProxyLibrary2OpenAI(&AIProxyLibraryResponse)
 | 
			
		||||
			jsonResponse, err := json.Marshal(response)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				logger.SysError("error marshalling stream response: " + err.Error())
 | 
			
		||||
				return true
 | 
			
		||||
			}
 | 
			
		||||
			c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
 | 
			
		||||
			return true
 | 
			
		||||
		case <-stopChan:
 | 
			
		||||
			response := documentsAIProxyLibrary(documents)
 | 
			
		||||
			jsonResponse, err := json.Marshal(response)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				logger.SysError("error marshalling stream response: " + err.Error())
 | 
			
		||||
				return true
 | 
			
		||||
			}
 | 
			
		||||
			c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
 | 
			
		||||
			c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
 | 
			
		||||
			return false
 | 
			
		||||
 | 
			
		||||
	for scanner.Scan() {
 | 
			
		||||
		data := scanner.Text()
 | 
			
		||||
		if len(data) < 5 || data[:5] != "data:" {
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
	})
 | 
			
		||||
	err := resp.Body.Close()
 | 
			
		||||
		data = data[5:]
 | 
			
		||||
 | 
			
		||||
		var AIProxyLibraryResponse LibraryStreamResponse
 | 
			
		||||
		err := json.Unmarshal([]byte(data), &AIProxyLibraryResponse)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			logger.SysError("error unmarshalling stream response: " + err.Error())
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
		if len(AIProxyLibraryResponse.Documents) != 0 {
 | 
			
		||||
			documents = AIProxyLibraryResponse.Documents
 | 
			
		||||
		}
 | 
			
		||||
		response := streamResponseAIProxyLibrary2OpenAI(&AIProxyLibraryResponse)
 | 
			
		||||
		err = render.ObjectData(c, response)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			logger.SysError(err.Error())
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if err := scanner.Err(); err != nil {
 | 
			
		||||
		logger.SysError("error reading stream: " + err.Error())
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	response := documentsAIProxyLibrary(documents)
 | 
			
		||||
	err := render.ObjectData(c, response)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		logger.SysError(err.Error())
 | 
			
		||||
	}
 | 
			
		||||
	render.Done(c)
 | 
			
		||||
 | 
			
		||||
	err = resp.Body.Close()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return nil, &usage
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
@@ -3,15 +3,17 @@ package ali
 | 
			
		||||
import (
 | 
			
		||||
	"bufio"
 | 
			
		||||
	"encoding/json"
 | 
			
		||||
	"github.com/songquanpeng/one-api/common/render"
 | 
			
		||||
	"io"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"strings"
 | 
			
		||||
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
	"github.com/songquanpeng/one-api/common"
 | 
			
		||||
	"github.com/songquanpeng/one-api/common/helper"
 | 
			
		||||
	"github.com/songquanpeng/one-api/common/logger"
 | 
			
		||||
	"github.com/songquanpeng/one-api/relay/adaptor/openai"
 | 
			
		||||
	"github.com/songquanpeng/one-api/relay/model"
 | 
			
		||||
	"io"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"strings"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// https://help.aliyun.com/document_detail/613695.html?spm=a2c4g.2399480.0.0.1adb778fAdzP9w#341800c0f8w0r
 | 
			
		||||
@@ -181,56 +183,43 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC
 | 
			
		||||
		}
 | 
			
		||||
		return 0, nil, nil
 | 
			
		||||
	})
 | 
			
		||||
	dataChan := make(chan string)
 | 
			
		||||
	stopChan := make(chan bool)
 | 
			
		||||
	go func() {
 | 
			
		||||
		for scanner.Scan() {
 | 
			
		||||
			data := scanner.Text()
 | 
			
		||||
			if len(data) < 5 { // ignore blank line or wrong format
 | 
			
		||||
				continue
 | 
			
		||||
			}
 | 
			
		||||
			if data[:5] != "data:" {
 | 
			
		||||
				continue
 | 
			
		||||
			}
 | 
			
		||||
			data = data[5:]
 | 
			
		||||
			dataChan <- data
 | 
			
		||||
		}
 | 
			
		||||
		stopChan <- true
 | 
			
		||||
	}()
 | 
			
		||||
 | 
			
		||||
	common.SetEventStreamHeaders(c)
 | 
			
		||||
	//lastResponseText := ""
 | 
			
		||||
	c.Stream(func(w io.Writer) bool {
 | 
			
		||||
		select {
 | 
			
		||||
		case data := <-dataChan:
 | 
			
		||||
			var aliResponse ChatResponse
 | 
			
		||||
			err := json.Unmarshal([]byte(data), &aliResponse)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				logger.SysError("error unmarshalling stream response: " + err.Error())
 | 
			
		||||
				return true
 | 
			
		||||
			}
 | 
			
		||||
			if aliResponse.Usage.OutputTokens != 0 {
 | 
			
		||||
				usage.PromptTokens = aliResponse.Usage.InputTokens
 | 
			
		||||
				usage.CompletionTokens = aliResponse.Usage.OutputTokens
 | 
			
		||||
				usage.TotalTokens = aliResponse.Usage.InputTokens + aliResponse.Usage.OutputTokens
 | 
			
		||||
			}
 | 
			
		||||
			response := streamResponseAli2OpenAI(&aliResponse)
 | 
			
		||||
			if response == nil {
 | 
			
		||||
				return true
 | 
			
		||||
			}
 | 
			
		||||
			//response.Choices[0].Delta.Content = strings.TrimPrefix(response.Choices[0].Delta.Content, lastResponseText)
 | 
			
		||||
			//lastResponseText = aliResponse.Output.Text
 | 
			
		||||
			jsonResponse, err := json.Marshal(response)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				logger.SysError("error marshalling stream response: " + err.Error())
 | 
			
		||||
				return true
 | 
			
		||||
			}
 | 
			
		||||
			c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
 | 
			
		||||
			return true
 | 
			
		||||
		case <-stopChan:
 | 
			
		||||
			c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
 | 
			
		||||
			return false
 | 
			
		||||
 | 
			
		||||
	for scanner.Scan() {
 | 
			
		||||
		data := scanner.Text()
 | 
			
		||||
		if len(data) < 5 || data[:5] != "data:" {
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
	})
 | 
			
		||||
		data = data[5:]
 | 
			
		||||
 | 
			
		||||
		var aliResponse ChatResponse
 | 
			
		||||
		err := json.Unmarshal([]byte(data), &aliResponse)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			logger.SysError("error unmarshalling stream response: " + err.Error())
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
		if aliResponse.Usage.OutputTokens != 0 {
 | 
			
		||||
			usage.PromptTokens = aliResponse.Usage.InputTokens
 | 
			
		||||
			usage.CompletionTokens = aliResponse.Usage.OutputTokens
 | 
			
		||||
			usage.TotalTokens = aliResponse.Usage.InputTokens + aliResponse.Usage.OutputTokens
 | 
			
		||||
		}
 | 
			
		||||
		response := streamResponseAli2OpenAI(&aliResponse)
 | 
			
		||||
		if response == nil {
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
		err = render.ObjectData(c, response)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			logger.SysError(err.Error())
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if err := scanner.Err(); err != nil {
 | 
			
		||||
		logger.SysError("error reading stream: " + err.Error())
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	render.Done(c)
 | 
			
		||||
 | 
			
		||||
	err := resp.Body.Close()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
 | 
			
		||||
 
 | 
			
		||||
@@ -5,4 +5,5 @@ var ModelList = []string{
 | 
			
		||||
	"claude-3-haiku-20240307",
 | 
			
		||||
	"claude-3-sonnet-20240229",
 | 
			
		||||
	"claude-3-opus-20240229",
 | 
			
		||||
	"claude-3-5-sonnet-20240620",
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -4,6 +4,7 @@ import (
 | 
			
		||||
	"bufio"
 | 
			
		||||
	"encoding/json"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"github.com/songquanpeng/one-api/common/render"
 | 
			
		||||
	"io"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"strings"
 | 
			
		||||
@@ -28,12 +29,30 @@ func stopReasonClaude2OpenAI(reason *string) string {
 | 
			
		||||
		return "stop"
 | 
			
		||||
	case "max_tokens":
 | 
			
		||||
		return "length"
 | 
			
		||||
	case "tool_use":
 | 
			
		||||
		return "tool_calls"
 | 
			
		||||
	default:
 | 
			
		||||
		return *reason
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func ConvertRequest(textRequest model.GeneralOpenAIRequest) *Request {
 | 
			
		||||
	claudeTools := make([]Tool, 0, len(textRequest.Tools))
 | 
			
		||||
 | 
			
		||||
	for _, tool := range textRequest.Tools {
 | 
			
		||||
		if params, ok := tool.Function.Parameters.(map[string]any); ok {
 | 
			
		||||
			claudeTools = append(claudeTools, Tool{
 | 
			
		||||
				Name:        tool.Function.Name,
 | 
			
		||||
				Description: tool.Function.Description,
 | 
			
		||||
				InputSchema: InputSchema{
 | 
			
		||||
					Type:       params["type"].(string),
 | 
			
		||||
					Properties: params["properties"],
 | 
			
		||||
					Required:   params["required"],
 | 
			
		||||
				},
 | 
			
		||||
			})
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	claudeRequest := Request{
 | 
			
		||||
		Model:       textRequest.Model,
 | 
			
		||||
		MaxTokens:   textRequest.MaxTokens,
 | 
			
		||||
@@ -41,6 +60,24 @@ func ConvertRequest(textRequest model.GeneralOpenAIRequest) *Request {
 | 
			
		||||
		TopP:        textRequest.TopP,
 | 
			
		||||
		TopK:        textRequest.TopK,
 | 
			
		||||
		Stream:      textRequest.Stream,
 | 
			
		||||
		Tools:       claudeTools,
 | 
			
		||||
	}
 | 
			
		||||
	if len(claudeTools) > 0 {
 | 
			
		||||
		claudeToolChoice := struct {
 | 
			
		||||
			Type string `json:"type"`
 | 
			
		||||
			Name string `json:"name,omitempty"`
 | 
			
		||||
		}{Type: "auto"} // default value https://docs.anthropic.com/en/docs/build-with-claude/tool-use#controlling-claudes-output
 | 
			
		||||
		if choice, ok := textRequest.ToolChoice.(map[string]any); ok {
 | 
			
		||||
			if function, ok := choice["function"].(map[string]any); ok {
 | 
			
		||||
				claudeToolChoice.Type = "tool"
 | 
			
		||||
				claudeToolChoice.Name = function["name"].(string)
 | 
			
		||||
			}
 | 
			
		||||
		} else if toolChoiceType, ok := textRequest.ToolChoice.(string); ok {
 | 
			
		||||
			if toolChoiceType == "any" {
 | 
			
		||||
				claudeToolChoice.Type = toolChoiceType
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
		claudeRequest.ToolChoice = claudeToolChoice
 | 
			
		||||
	}
 | 
			
		||||
	if claudeRequest.MaxTokens == 0 {
 | 
			
		||||
		claudeRequest.MaxTokens = 4096
 | 
			
		||||
@@ -63,7 +100,24 @@ func ConvertRequest(textRequest model.GeneralOpenAIRequest) *Request {
 | 
			
		||||
		if message.IsStringContent() {
 | 
			
		||||
			content.Type = "text"
 | 
			
		||||
			content.Text = message.StringContent()
 | 
			
		||||
			if message.Role == "tool" {
 | 
			
		||||
				claudeMessage.Role = "user"
 | 
			
		||||
				content.Type = "tool_result"
 | 
			
		||||
				content.Content = content.Text
 | 
			
		||||
				content.Text = ""
 | 
			
		||||
				content.ToolUseId = message.ToolCallId
 | 
			
		||||
			}
 | 
			
		||||
			claudeMessage.Content = append(claudeMessage.Content, content)
 | 
			
		||||
			for i := range message.ToolCalls {
 | 
			
		||||
				inputParam := make(map[string]any)
 | 
			
		||||
				_ = json.Unmarshal([]byte(message.ToolCalls[i].Function.Arguments.(string)), &inputParam)
 | 
			
		||||
				claudeMessage.Content = append(claudeMessage.Content, Content{
 | 
			
		||||
					Type:  "tool_use",
 | 
			
		||||
					Id:    message.ToolCalls[i].Id,
 | 
			
		||||
					Name:  message.ToolCalls[i].Function.Name,
 | 
			
		||||
					Input: inputParam,
 | 
			
		||||
				})
 | 
			
		||||
			}
 | 
			
		||||
			claudeRequest.Messages = append(claudeRequest.Messages, claudeMessage)
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
@@ -96,16 +150,35 @@ func StreamResponseClaude2OpenAI(claudeResponse *StreamResponse) (*openai.ChatCo
 | 
			
		||||
	var response *Response
 | 
			
		||||
	var responseText string
 | 
			
		||||
	var stopReason string
 | 
			
		||||
	tools := make([]model.Tool, 0)
 | 
			
		||||
 | 
			
		||||
	switch claudeResponse.Type {
 | 
			
		||||
	case "message_start":
 | 
			
		||||
		return nil, claudeResponse.Message
 | 
			
		||||
	case "content_block_start":
 | 
			
		||||
		if claudeResponse.ContentBlock != nil {
 | 
			
		||||
			responseText = claudeResponse.ContentBlock.Text
 | 
			
		||||
			if claudeResponse.ContentBlock.Type == "tool_use" {
 | 
			
		||||
				tools = append(tools, model.Tool{
 | 
			
		||||
					Id:   claudeResponse.ContentBlock.Id,
 | 
			
		||||
					Type: "function",
 | 
			
		||||
					Function: model.Function{
 | 
			
		||||
						Name:      claudeResponse.ContentBlock.Name,
 | 
			
		||||
						Arguments: "",
 | 
			
		||||
					},
 | 
			
		||||
				})
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	case "content_block_delta":
 | 
			
		||||
		if claudeResponse.Delta != nil {
 | 
			
		||||
			responseText = claudeResponse.Delta.Text
 | 
			
		||||
			if claudeResponse.Delta.Type == "input_json_delta" {
 | 
			
		||||
				tools = append(tools, model.Tool{
 | 
			
		||||
					Function: model.Function{
 | 
			
		||||
						Arguments: claudeResponse.Delta.PartialJson,
 | 
			
		||||
					},
 | 
			
		||||
				})
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	case "message_delta":
 | 
			
		||||
		if claudeResponse.Usage != nil {
 | 
			
		||||
@@ -119,6 +192,10 @@ func StreamResponseClaude2OpenAI(claudeResponse *StreamResponse) (*openai.ChatCo
 | 
			
		||||
	}
 | 
			
		||||
	var choice openai.ChatCompletionsStreamResponseChoice
 | 
			
		||||
	choice.Delta.Content = responseText
 | 
			
		||||
	if len(tools) > 0 {
 | 
			
		||||
		choice.Delta.Content = nil // compatible with other OpenAI derivative applications, like LobeOpenAICompatibleFactory ...
 | 
			
		||||
		choice.Delta.ToolCalls = tools
 | 
			
		||||
	}
 | 
			
		||||
	choice.Delta.Role = "assistant"
 | 
			
		||||
	finishReason := stopReasonClaude2OpenAI(&stopReason)
 | 
			
		||||
	if finishReason != "null" {
 | 
			
		||||
@@ -135,12 +212,27 @@ func ResponseClaude2OpenAI(claudeResponse *Response) *openai.TextResponse {
 | 
			
		||||
	if len(claudeResponse.Content) > 0 {
 | 
			
		||||
		responseText = claudeResponse.Content[0].Text
 | 
			
		||||
	}
 | 
			
		||||
	tools := make([]model.Tool, 0)
 | 
			
		||||
	for _, v := range claudeResponse.Content {
 | 
			
		||||
		if v.Type == "tool_use" {
 | 
			
		||||
			args, _ := json.Marshal(v.Input)
 | 
			
		||||
			tools = append(tools, model.Tool{
 | 
			
		||||
				Id:   v.Id,
 | 
			
		||||
				Type: "function", // compatible with other OpenAI derivative applications
 | 
			
		||||
				Function: model.Function{
 | 
			
		||||
					Name:      v.Name,
 | 
			
		||||
					Arguments: string(args),
 | 
			
		||||
				},
 | 
			
		||||
			})
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	choice := openai.TextResponseChoice{
 | 
			
		||||
		Index: 0,
 | 
			
		||||
		Message: model.Message{
 | 
			
		||||
			Role:    "assistant",
 | 
			
		||||
			Content: responseText,
 | 
			
		||||
			Name:    nil,
 | 
			
		||||
			Role:      "assistant",
 | 
			
		||||
			Content:   responseText,
 | 
			
		||||
			Name:      nil,
 | 
			
		||||
			ToolCalls: tools,
 | 
			
		||||
		},
 | 
			
		||||
		FinishReason: stopReasonClaude2OpenAI(claudeResponse.StopReason),
 | 
			
		||||
	}
 | 
			
		||||
@@ -169,64 +261,77 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC
 | 
			
		||||
		}
 | 
			
		||||
		return 0, nil, nil
 | 
			
		||||
	})
 | 
			
		||||
	dataChan := make(chan string)
 | 
			
		||||
	stopChan := make(chan bool)
 | 
			
		||||
	go func() {
 | 
			
		||||
		for scanner.Scan() {
 | 
			
		||||
			data := scanner.Text()
 | 
			
		||||
			if len(data) < 6 {
 | 
			
		||||
				continue
 | 
			
		||||
			}
 | 
			
		||||
			if !strings.HasPrefix(data, "data:") {
 | 
			
		||||
				continue
 | 
			
		||||
			}
 | 
			
		||||
			data = strings.TrimPrefix(data, "data:")
 | 
			
		||||
			dataChan <- data
 | 
			
		||||
		}
 | 
			
		||||
		stopChan <- true
 | 
			
		||||
	}()
 | 
			
		||||
 | 
			
		||||
	common.SetEventStreamHeaders(c)
 | 
			
		||||
 | 
			
		||||
	var usage model.Usage
 | 
			
		||||
	var modelName string
 | 
			
		||||
	var id string
 | 
			
		||||
	c.Stream(func(w io.Writer) bool {
 | 
			
		||||
		select {
 | 
			
		||||
		case data := <-dataChan:
 | 
			
		||||
			// some implementations may add \r at the end of data
 | 
			
		||||
			data = strings.TrimSpace(data)
 | 
			
		||||
			var claudeResponse StreamResponse
 | 
			
		||||
			err := json.Unmarshal([]byte(data), &claudeResponse)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				logger.SysError("error unmarshalling stream response: " + err.Error())
 | 
			
		||||
				return true
 | 
			
		||||
			}
 | 
			
		||||
			response, meta := StreamResponseClaude2OpenAI(&claudeResponse)
 | 
			
		||||
			if meta != nil {
 | 
			
		||||
				usage.PromptTokens += meta.Usage.InputTokens
 | 
			
		||||
				usage.CompletionTokens += meta.Usage.OutputTokens
 | 
			
		||||
	var lastToolCallChoice openai.ChatCompletionsStreamResponseChoice
 | 
			
		||||
 | 
			
		||||
	for scanner.Scan() {
 | 
			
		||||
		data := scanner.Text()
 | 
			
		||||
		if len(data) < 6 || !strings.HasPrefix(data, "data:") {
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
		data = strings.TrimPrefix(data, "data:")
 | 
			
		||||
		data = strings.TrimSpace(data)
 | 
			
		||||
 | 
			
		||||
		var claudeResponse StreamResponse
 | 
			
		||||
		err := json.Unmarshal([]byte(data), &claudeResponse)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			logger.SysError("error unmarshalling stream response: " + err.Error())
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		response, meta := StreamResponseClaude2OpenAI(&claudeResponse)
 | 
			
		||||
		if meta != nil {
 | 
			
		||||
			usage.PromptTokens += meta.Usage.InputTokens
 | 
			
		||||
			usage.CompletionTokens += meta.Usage.OutputTokens
 | 
			
		||||
			if len(meta.Id) > 0 { // only message_start has an id, otherwise it's a finish_reason event.
 | 
			
		||||
				modelName = meta.Model
 | 
			
		||||
				id = fmt.Sprintf("chatcmpl-%s", meta.Id)
 | 
			
		||||
				return true
 | 
			
		||||
				continue
 | 
			
		||||
			} else { // finish_reason case
 | 
			
		||||
				if len(lastToolCallChoice.Delta.ToolCalls) > 0 {
 | 
			
		||||
					lastArgs := &lastToolCallChoice.Delta.ToolCalls[len(lastToolCallChoice.Delta.ToolCalls)-1].Function
 | 
			
		||||
					if len(lastArgs.Arguments.(string)) == 0 { // compatible with OpenAI sending an empty object `{}` when no arguments.
 | 
			
		||||
						lastArgs.Arguments = "{}"
 | 
			
		||||
						response.Choices[len(response.Choices)-1].Delta.Content = nil
 | 
			
		||||
						response.Choices[len(response.Choices)-1].Delta.ToolCalls = lastToolCallChoice.Delta.ToolCalls
 | 
			
		||||
					}
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
			if response == nil {
 | 
			
		||||
				return true
 | 
			
		||||
			}
 | 
			
		||||
			response.Id = id
 | 
			
		||||
			response.Model = modelName
 | 
			
		||||
			response.Created = createdTime
 | 
			
		||||
			jsonStr, err := json.Marshal(response)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				logger.SysError("error marshalling stream response: " + err.Error())
 | 
			
		||||
				return true
 | 
			
		||||
			}
 | 
			
		||||
			c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonStr)})
 | 
			
		||||
			return true
 | 
			
		||||
		case <-stopChan:
 | 
			
		||||
			c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
 | 
			
		||||
			return false
 | 
			
		||||
		}
 | 
			
		||||
	})
 | 
			
		||||
	_ = resp.Body.Close()
 | 
			
		||||
		if response == nil {
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		response.Id = id
 | 
			
		||||
		response.Model = modelName
 | 
			
		||||
		response.Created = createdTime
 | 
			
		||||
 | 
			
		||||
		for _, choice := range response.Choices {
 | 
			
		||||
			if len(choice.Delta.ToolCalls) > 0 {
 | 
			
		||||
				lastToolCallChoice = choice
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
		err = render.ObjectData(c, response)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			logger.SysError(err.Error())
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if err := scanner.Err(); err != nil {
 | 
			
		||||
		logger.SysError("error reading stream: " + err.Error())
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	render.Done(c)
 | 
			
		||||
 | 
			
		||||
	err := resp.Body.Close()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
 | 
			
		||||
	}
 | 
			
		||||
	return nil, &usage
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
@@ -16,6 +16,12 @@ type Content struct {
 | 
			
		||||
	Type   string       `json:"type"`
 | 
			
		||||
	Text   string       `json:"text,omitempty"`
 | 
			
		||||
	Source *ImageSource `json:"source,omitempty"`
 | 
			
		||||
	// tool_calls
 | 
			
		||||
	Id        string `json:"id,omitempty"`
 | 
			
		||||
	Name      string `json:"name,omitempty"`
 | 
			
		||||
	Input     any    `json:"input,omitempty"`
 | 
			
		||||
	Content   string `json:"content,omitempty"`
 | 
			
		||||
	ToolUseId string `json:"tool_use_id,omitempty"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type Message struct {
 | 
			
		||||
@@ -23,6 +29,18 @@ type Message struct {
 | 
			
		||||
	Content []Content `json:"content"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type Tool struct {
 | 
			
		||||
	Name        string      `json:"name"`
 | 
			
		||||
	Description string      `json:"description,omitempty"`
 | 
			
		||||
	InputSchema InputSchema `json:"input_schema"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type InputSchema struct {
 | 
			
		||||
	Type       string `json:"type"`
 | 
			
		||||
	Properties any    `json:"properties,omitempty"`
 | 
			
		||||
	Required   any    `json:"required,omitempty"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type Request struct {
 | 
			
		||||
	Model         string    `json:"model"`
 | 
			
		||||
	Messages      []Message `json:"messages"`
 | 
			
		||||
@@ -33,6 +51,8 @@ type Request struct {
 | 
			
		||||
	Temperature   float64   `json:"temperature,omitempty"`
 | 
			
		||||
	TopP          float64   `json:"top_p,omitempty"`
 | 
			
		||||
	TopK          int       `json:"top_k,omitempty"`
 | 
			
		||||
	Tools         []Tool    `json:"tools,omitempty"`
 | 
			
		||||
	ToolChoice    any       `json:"tool_choice,omitempty"`
 | 
			
		||||
	//Metadata    `json:"metadata,omitempty"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@@ -61,6 +81,7 @@ type Response struct {
 | 
			
		||||
type Delta struct {
 | 
			
		||||
	Type         string  `json:"type"`
 | 
			
		||||
	Text         string  `json:"text"`
 | 
			
		||||
	PartialJson  string  `json:"partial_json,omitempty"`
 | 
			
		||||
	StopReason   *string `json:"stop_reason"`
 | 
			
		||||
	StopSequence *string `json:"stop_sequence"`
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -6,6 +6,7 @@ import (
 | 
			
		||||
	"encoding/json"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"github.com/songquanpeng/one-api/common/ctxkey"
 | 
			
		||||
	"github.com/songquanpeng/one-api/relay/adaptor/openai"
 | 
			
		||||
	"io"
 | 
			
		||||
	"net/http"
 | 
			
		||||
 | 
			
		||||
@@ -33,12 +34,13 @@ func wrapErr(err error) *relaymodel.ErrorWithStatusCode {
 | 
			
		||||
 | 
			
		||||
// https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids.html
 | 
			
		||||
var awsModelIDMap = map[string]string{
 | 
			
		||||
	"claude-instant-1.2":       "anthropic.claude-instant-v1",
 | 
			
		||||
	"claude-2.0":               "anthropic.claude-v2",
 | 
			
		||||
	"claude-2.1":               "anthropic.claude-v2:1",
 | 
			
		||||
	"claude-3-sonnet-20240229": "anthropic.claude-3-sonnet-20240229-v1:0",
 | 
			
		||||
	"claude-3-opus-20240229":   "anthropic.claude-3-opus-20240229-v1:0",
 | 
			
		||||
	"claude-3-haiku-20240307":  "anthropic.claude-3-haiku-20240307-v1:0",
 | 
			
		||||
	"claude-instant-1.2":         "anthropic.claude-instant-v1",
 | 
			
		||||
	"claude-2.0":                 "anthropic.claude-v2",
 | 
			
		||||
	"claude-2.1":                 "anthropic.claude-v2:1",
 | 
			
		||||
	"claude-3-sonnet-20240229":   "anthropic.claude-3-sonnet-20240229-v1:0",
 | 
			
		||||
	"claude-3-5-sonnet-20240620": "anthropic.claude-3-5-sonnet-20240620-v1:0",
 | 
			
		||||
	"claude-3-opus-20240229":     "anthropic.claude-3-opus-20240229-v1:0",
 | 
			
		||||
	"claude-3-haiku-20240307":    "anthropic.claude-3-haiku-20240307-v1:0",
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func awsModelID(requestModel string) (string, error) {
 | 
			
		||||
@@ -142,6 +144,8 @@ func StreamHandler(c *gin.Context, awsCli *bedrockruntime.Client) (*relaymodel.E
 | 
			
		||||
	c.Writer.Header().Set("Content-Type", "text/event-stream")
 | 
			
		||||
	var usage relaymodel.Usage
 | 
			
		||||
	var id string
 | 
			
		||||
	var lastToolCallChoice openai.ChatCompletionsStreamResponseChoice
 | 
			
		||||
 | 
			
		||||
	c.Stream(func(w io.Writer) bool {
 | 
			
		||||
		event, ok := <-stream.Events()
 | 
			
		||||
		if !ok {
 | 
			
		||||
@@ -162,8 +166,19 @@ func StreamHandler(c *gin.Context, awsCli *bedrockruntime.Client) (*relaymodel.E
 | 
			
		||||
			if meta != nil {
 | 
			
		||||
				usage.PromptTokens += meta.Usage.InputTokens
 | 
			
		||||
				usage.CompletionTokens += meta.Usage.OutputTokens
 | 
			
		||||
				id = fmt.Sprintf("chatcmpl-%s", meta.Id)
 | 
			
		||||
				return true
 | 
			
		||||
				if len(meta.Id) > 0 { // only message_start has an id, otherwise it's a finish_reason event.
 | 
			
		||||
					id = fmt.Sprintf("chatcmpl-%s", meta.Id)
 | 
			
		||||
					return true
 | 
			
		||||
				} else { // finish_reason case
 | 
			
		||||
					if len(lastToolCallChoice.Delta.ToolCalls) > 0 {
 | 
			
		||||
						lastArgs := &lastToolCallChoice.Delta.ToolCalls[len(lastToolCallChoice.Delta.ToolCalls)-1].Function
 | 
			
		||||
						if len(lastArgs.Arguments.(string)) == 0 { // compatible with OpenAI sending an empty object `{}` when no arguments.
 | 
			
		||||
							lastArgs.Arguments = "{}"
 | 
			
		||||
							response.Choices[len(response.Choices)-1].Delta.Content = nil
 | 
			
		||||
							response.Choices[len(response.Choices)-1].Delta.ToolCalls = lastToolCallChoice.Delta.ToolCalls
 | 
			
		||||
						}
 | 
			
		||||
					}
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
			if response == nil {
 | 
			
		||||
				return true
 | 
			
		||||
@@ -171,6 +186,12 @@ func StreamHandler(c *gin.Context, awsCli *bedrockruntime.Client) (*relaymodel.E
 | 
			
		||||
			response.Id = id
 | 
			
		||||
			response.Model = c.GetString(ctxkey.OriginalModel)
 | 
			
		||||
			response.Created = createdTime
 | 
			
		||||
 | 
			
		||||
			for _, choice := range response.Choices {
 | 
			
		||||
				if len(choice.Delta.ToolCalls) > 0 {
 | 
			
		||||
					lastToolCallChoice = choice
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
			jsonStr, err := json.Marshal(response)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				logger.SysError("error marshalling stream response: " + err.Error())
 | 
			
		||||
 
 | 
			
		||||
@@ -9,9 +9,12 @@ type Request struct {
 | 
			
		||||
	// AnthropicVersion should be "bedrock-2023-05-31"
 | 
			
		||||
	AnthropicVersion string              `json:"anthropic_version"`
 | 
			
		||||
	Messages         []anthropic.Message `json:"messages"`
 | 
			
		||||
	System           string              `json:"system,omitempty"`
 | 
			
		||||
	MaxTokens        int                 `json:"max_tokens,omitempty"`
 | 
			
		||||
	Temperature      float64             `json:"temperature,omitempty"`
 | 
			
		||||
	TopP             float64             `json:"top_p,omitempty"`
 | 
			
		||||
	TopK             int                 `json:"top_k,omitempty"`
 | 
			
		||||
	StopSequences    []string            `json:"stop_sequences,omitempty"`
 | 
			
		||||
	Tools            []anthropic.Tool    `json:"tools,omitempty"`
 | 
			
		||||
	ToolChoice       any                 `json:"tool_choice,omitempty"`
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -5,6 +5,13 @@ import (
 | 
			
		||||
	"encoding/json"
 | 
			
		||||
	"errors"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"github.com/songquanpeng/one-api/common/render"
 | 
			
		||||
	"io"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"strings"
 | 
			
		||||
	"sync"
 | 
			
		||||
	"time"
 | 
			
		||||
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
	"github.com/songquanpeng/one-api/common"
 | 
			
		||||
	"github.com/songquanpeng/one-api/common/client"
 | 
			
		||||
@@ -12,11 +19,6 @@ import (
 | 
			
		||||
	"github.com/songquanpeng/one-api/relay/adaptor/openai"
 | 
			
		||||
	"github.com/songquanpeng/one-api/relay/constant"
 | 
			
		||||
	"github.com/songquanpeng/one-api/relay/model"
 | 
			
		||||
	"io"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"strings"
 | 
			
		||||
	"sync"
 | 
			
		||||
	"time"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// https://cloud.baidu.com/doc/WENXINWORKSHOP/s/flfmc9do2
 | 
			
		||||
@@ -137,59 +139,41 @@ func embeddingResponseBaidu2OpenAI(response *EmbeddingResponse) *openai.Embeddin
 | 
			
		||||
func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
 | 
			
		||||
	var usage model.Usage
 | 
			
		||||
	scanner := bufio.NewScanner(resp.Body)
 | 
			
		||||
	scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
 | 
			
		||||
		if atEOF && len(data) == 0 {
 | 
			
		||||
			return 0, nil, nil
 | 
			
		||||
		}
 | 
			
		||||
		if i := strings.Index(string(data), "\n"); i >= 0 {
 | 
			
		||||
			return i + 1, data[0:i], nil
 | 
			
		||||
		}
 | 
			
		||||
		if atEOF {
 | 
			
		||||
			return len(data), data, nil
 | 
			
		||||
		}
 | 
			
		||||
		return 0, nil, nil
 | 
			
		||||
	})
 | 
			
		||||
	dataChan := make(chan string)
 | 
			
		||||
	stopChan := make(chan bool)
 | 
			
		||||
	go func() {
 | 
			
		||||
		for scanner.Scan() {
 | 
			
		||||
			data := scanner.Text()
 | 
			
		||||
			if len(data) < 6 { // ignore blank line or wrong format
 | 
			
		||||
				continue
 | 
			
		||||
			}
 | 
			
		||||
			data = data[6:]
 | 
			
		||||
			dataChan <- data
 | 
			
		||||
		}
 | 
			
		||||
		stopChan <- true
 | 
			
		||||
	}()
 | 
			
		||||
	scanner.Split(bufio.ScanLines)
 | 
			
		||||
 | 
			
		||||
	common.SetEventStreamHeaders(c)
 | 
			
		||||
	c.Stream(func(w io.Writer) bool {
 | 
			
		||||
		select {
 | 
			
		||||
		case data := <-dataChan:
 | 
			
		||||
			var baiduResponse ChatStreamResponse
 | 
			
		||||
			err := json.Unmarshal([]byte(data), &baiduResponse)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				logger.SysError("error unmarshalling stream response: " + err.Error())
 | 
			
		||||
				return true
 | 
			
		||||
			}
 | 
			
		||||
			if baiduResponse.Usage.TotalTokens != 0 {
 | 
			
		||||
				usage.TotalTokens = baiduResponse.Usage.TotalTokens
 | 
			
		||||
				usage.PromptTokens = baiduResponse.Usage.PromptTokens
 | 
			
		||||
				usage.CompletionTokens = baiduResponse.Usage.TotalTokens - baiduResponse.Usage.PromptTokens
 | 
			
		||||
			}
 | 
			
		||||
			response := streamResponseBaidu2OpenAI(&baiduResponse)
 | 
			
		||||
			jsonResponse, err := json.Marshal(response)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				logger.SysError("error marshalling stream response: " + err.Error())
 | 
			
		||||
				return true
 | 
			
		||||
			}
 | 
			
		||||
			c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
 | 
			
		||||
			return true
 | 
			
		||||
		case <-stopChan:
 | 
			
		||||
			c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
 | 
			
		||||
			return false
 | 
			
		||||
 | 
			
		||||
	for scanner.Scan() {
 | 
			
		||||
		data := scanner.Text()
 | 
			
		||||
		if len(data) < 6 {
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
	})
 | 
			
		||||
		data = data[6:]
 | 
			
		||||
 | 
			
		||||
		var baiduResponse ChatStreamResponse
 | 
			
		||||
		err := json.Unmarshal([]byte(data), &baiduResponse)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			logger.SysError("error unmarshalling stream response: " + err.Error())
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
		if baiduResponse.Usage.TotalTokens != 0 {
 | 
			
		||||
			usage.TotalTokens = baiduResponse.Usage.TotalTokens
 | 
			
		||||
			usage.PromptTokens = baiduResponse.Usage.PromptTokens
 | 
			
		||||
			usage.CompletionTokens = baiduResponse.Usage.TotalTokens - baiduResponse.Usage.PromptTokens
 | 
			
		||||
		}
 | 
			
		||||
		response := streamResponseBaidu2OpenAI(&baiduResponse)
 | 
			
		||||
		err = render.ObjectData(c, response)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			logger.SysError(err.Error())
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if err := scanner.Err(); err != nil {
 | 
			
		||||
		logger.SysError("error reading stream: " + err.Error())
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	render.Done(c)
 | 
			
		||||
 | 
			
		||||
	err := resp.Body.Close()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
 | 
			
		||||
 
 | 
			
		||||
@@ -2,8 +2,8 @@ package cloudflare
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"bufio"
 | 
			
		||||
	"bytes"
 | 
			
		||||
	"encoding/json"
 | 
			
		||||
	"github.com/songquanpeng/one-api/common/render"
 | 
			
		||||
	"io"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"strings"
 | 
			
		||||
@@ -17,21 +17,20 @@ import (
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func ConvertRequest(textRequest model.GeneralOpenAIRequest) *Request {
 | 
			
		||||
    var promptBuilder strings.Builder
 | 
			
		||||
    for _, message := range textRequest.Messages {
 | 
			
		||||
        promptBuilder.WriteString(message.StringContent())
 | 
			
		||||
        promptBuilder.WriteString("\n")  // 添加换行符来分隔每个消息
 | 
			
		||||
    }
 | 
			
		||||
	var promptBuilder strings.Builder
 | 
			
		||||
	for _, message := range textRequest.Messages {
 | 
			
		||||
		promptBuilder.WriteString(message.StringContent())
 | 
			
		||||
		promptBuilder.WriteString("\n") // 添加换行符来分隔每个消息
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
    return &Request{
 | 
			
		||||
        MaxTokens:   textRequest.MaxTokens,
 | 
			
		||||
        Prompt:      promptBuilder.String(),
 | 
			
		||||
        Stream:      textRequest.Stream,
 | 
			
		||||
        Temperature: textRequest.Temperature,
 | 
			
		||||
    }
 | 
			
		||||
	return &Request{
 | 
			
		||||
		MaxTokens:   textRequest.MaxTokens,
 | 
			
		||||
		Prompt:      promptBuilder.String(),
 | 
			
		||||
		Stream:      textRequest.Stream,
 | 
			
		||||
		Temperature: textRequest.Temperature,
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
func ResponseCloudflare2OpenAI(cloudflareResponse *Response) *openai.TextResponse {
 | 
			
		||||
	choice := openai.TextResponseChoice{
 | 
			
		||||
		Index: 0,
 | 
			
		||||
@@ -63,67 +62,54 @@ func StreamResponseCloudflare2OpenAI(cloudflareResponse *StreamResponse) *openai
 | 
			
		||||
 | 
			
		||||
func StreamHandler(c *gin.Context, resp *http.Response, promptTokens int, modelName string) (*model.ErrorWithStatusCode, *model.Usage) {
 | 
			
		||||
	scanner := bufio.NewScanner(resp.Body)
 | 
			
		||||
	scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
 | 
			
		||||
		if atEOF && len(data) == 0 {
 | 
			
		||||
			return 0, nil, nil
 | 
			
		||||
		}
 | 
			
		||||
		if i := bytes.IndexByte(data, '\n'); i >= 0 {
 | 
			
		||||
			return i + 1, data[0:i], nil
 | 
			
		||||
		}
 | 
			
		||||
		if atEOF {
 | 
			
		||||
			return len(data), data, nil
 | 
			
		||||
		}
 | 
			
		||||
		return 0, nil, nil
 | 
			
		||||
	})
 | 
			
		||||
	scanner.Split(bufio.ScanLines)
 | 
			
		||||
 | 
			
		||||
	dataChan := make(chan string)
 | 
			
		||||
	stopChan := make(chan bool)
 | 
			
		||||
	go func() {
 | 
			
		||||
		for scanner.Scan() {
 | 
			
		||||
			data := scanner.Text()
 | 
			
		||||
			if len(data) < len("data: ") {
 | 
			
		||||
				continue
 | 
			
		||||
			}
 | 
			
		||||
			data = strings.TrimPrefix(data, "data: ")
 | 
			
		||||
			dataChan <- data
 | 
			
		||||
		}
 | 
			
		||||
		stopChan <- true
 | 
			
		||||
	}()
 | 
			
		||||
	common.SetEventStreamHeaders(c)
 | 
			
		||||
	id := helper.GetResponseID(c)
 | 
			
		||||
	responseModel := c.GetString("original_model")
 | 
			
		||||
	var responseText string
 | 
			
		||||
	c.Stream(func(w io.Writer) bool {
 | 
			
		||||
		select {
 | 
			
		||||
		case data := <-dataChan:
 | 
			
		||||
			// some implementations may add \r at the end of data
 | 
			
		||||
			data = strings.TrimSuffix(data, "\r")
 | 
			
		||||
			var cloudflareResponse StreamResponse
 | 
			
		||||
			err := json.Unmarshal([]byte(data), &cloudflareResponse)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				logger.SysError("error unmarshalling stream response: " + err.Error())
 | 
			
		||||
				return true
 | 
			
		||||
			}
 | 
			
		||||
			response := StreamResponseCloudflare2OpenAI(&cloudflareResponse)
 | 
			
		||||
			if response == nil {
 | 
			
		||||
				return true
 | 
			
		||||
			}
 | 
			
		||||
			responseText += cloudflareResponse.Response
 | 
			
		||||
			response.Id = id
 | 
			
		||||
			response.Model = responseModel
 | 
			
		||||
			jsonStr, err := json.Marshal(response)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				logger.SysError("error marshalling stream response: " + err.Error())
 | 
			
		||||
				return true
 | 
			
		||||
			}
 | 
			
		||||
			c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonStr)})
 | 
			
		||||
			return true
 | 
			
		||||
		case <-stopChan:
 | 
			
		||||
			c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
 | 
			
		||||
			return false
 | 
			
		||||
 | 
			
		||||
	for scanner.Scan() {
 | 
			
		||||
		data := scanner.Text()
 | 
			
		||||
		if len(data) < len("data: ") {
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
	})
 | 
			
		||||
	_ = resp.Body.Close()
 | 
			
		||||
		data = strings.TrimPrefix(data, "data: ")
 | 
			
		||||
		data = strings.TrimSuffix(data, "\r")
 | 
			
		||||
 | 
			
		||||
		var cloudflareResponse StreamResponse
 | 
			
		||||
		err := json.Unmarshal([]byte(data), &cloudflareResponse)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			logger.SysError("error unmarshalling stream response: " + err.Error())
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		response := StreamResponseCloudflare2OpenAI(&cloudflareResponse)
 | 
			
		||||
		if response == nil {
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		responseText += cloudflareResponse.Response
 | 
			
		||||
		response.Id = id
 | 
			
		||||
		response.Model = responseModel
 | 
			
		||||
 | 
			
		||||
		err = render.ObjectData(c, response)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			logger.SysError(err.Error())
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if err := scanner.Err(); err != nil {
 | 
			
		||||
		logger.SysError("error reading stream: " + err.Error())
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	render.Done(c)
 | 
			
		||||
 | 
			
		||||
	err := resp.Body.Close()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	usage := openai.ResponseText2Usage(responseText, responseModel, promptTokens)
 | 
			
		||||
	return nil, usage
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -2,9 +2,9 @@ package cohere
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"bufio"
 | 
			
		||||
	"bytes"
 | 
			
		||||
	"encoding/json"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"github.com/songquanpeng/one-api/common/render"
 | 
			
		||||
	"io"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"strings"
 | 
			
		||||
@@ -134,66 +134,53 @@ func ResponseCohere2OpenAI(cohereResponse *Response) *openai.TextResponse {
 | 
			
		||||
func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
 | 
			
		||||
	createdTime := helper.GetTimestamp()
 | 
			
		||||
	scanner := bufio.NewScanner(resp.Body)
 | 
			
		||||
	scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
 | 
			
		||||
		if atEOF && len(data) == 0 {
 | 
			
		||||
			return 0, nil, nil
 | 
			
		||||
		}
 | 
			
		||||
		if i := bytes.IndexByte(data, '\n'); i >= 0 {
 | 
			
		||||
			return i + 1, data[0:i], nil
 | 
			
		||||
		}
 | 
			
		||||
		if atEOF {
 | 
			
		||||
			return len(data), data, nil
 | 
			
		||||
		}
 | 
			
		||||
		return 0, nil, nil
 | 
			
		||||
	})
 | 
			
		||||
	scanner.Split(bufio.ScanLines)
 | 
			
		||||
 | 
			
		||||
	dataChan := make(chan string)
 | 
			
		||||
	stopChan := make(chan bool)
 | 
			
		||||
	go func() {
 | 
			
		||||
		for scanner.Scan() {
 | 
			
		||||
			data := scanner.Text()
 | 
			
		||||
			dataChan <- data
 | 
			
		||||
		}
 | 
			
		||||
		stopChan <- true
 | 
			
		||||
	}()
 | 
			
		||||
	common.SetEventStreamHeaders(c)
 | 
			
		||||
	var usage model.Usage
 | 
			
		||||
	c.Stream(func(w io.Writer) bool {
 | 
			
		||||
		select {
 | 
			
		||||
		case data := <-dataChan:
 | 
			
		||||
			// some implementations may add \r at the end of data
 | 
			
		||||
			data = strings.TrimSuffix(data, "\r")
 | 
			
		||||
			var cohereResponse StreamResponse
 | 
			
		||||
			err := json.Unmarshal([]byte(data), &cohereResponse)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				logger.SysError("error unmarshalling stream response: " + err.Error())
 | 
			
		||||
				return true
 | 
			
		||||
			}
 | 
			
		||||
			response, meta := StreamResponseCohere2OpenAI(&cohereResponse)
 | 
			
		||||
			if meta != nil {
 | 
			
		||||
				usage.PromptTokens += meta.Meta.Tokens.InputTokens
 | 
			
		||||
				usage.CompletionTokens += meta.Meta.Tokens.OutputTokens
 | 
			
		||||
				return true
 | 
			
		||||
			}
 | 
			
		||||
			if response == nil {
 | 
			
		||||
				return true
 | 
			
		||||
			}
 | 
			
		||||
			response.Id = fmt.Sprintf("chatcmpl-%d", createdTime)
 | 
			
		||||
			response.Model = c.GetString("original_model")
 | 
			
		||||
			response.Created = createdTime
 | 
			
		||||
			jsonStr, err := json.Marshal(response)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				logger.SysError("error marshalling stream response: " + err.Error())
 | 
			
		||||
				return true
 | 
			
		||||
			}
 | 
			
		||||
			c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonStr)})
 | 
			
		||||
			return true
 | 
			
		||||
		case <-stopChan:
 | 
			
		||||
			c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
 | 
			
		||||
			return false
 | 
			
		||||
 | 
			
		||||
	for scanner.Scan() {
 | 
			
		||||
		data := scanner.Text()
 | 
			
		||||
		data = strings.TrimSuffix(data, "\r")
 | 
			
		||||
 | 
			
		||||
		var cohereResponse StreamResponse
 | 
			
		||||
		err := json.Unmarshal([]byte(data), &cohereResponse)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			logger.SysError("error unmarshalling stream response: " + err.Error())
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
	})
 | 
			
		||||
	_ = resp.Body.Close()
 | 
			
		||||
 | 
			
		||||
		response, meta := StreamResponseCohere2OpenAI(&cohereResponse)
 | 
			
		||||
		if meta != nil {
 | 
			
		||||
			usage.PromptTokens += meta.Meta.Tokens.InputTokens
 | 
			
		||||
			usage.CompletionTokens += meta.Meta.Tokens.OutputTokens
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
		if response == nil {
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		response.Id = fmt.Sprintf("chatcmpl-%d", createdTime)
 | 
			
		||||
		response.Model = c.GetString("original_model")
 | 
			
		||||
		response.Created = createdTime
 | 
			
		||||
 | 
			
		||||
		err = render.ObjectData(c, response)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			logger.SysError(err.Error())
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if err := scanner.Err(); err != nil {
 | 
			
		||||
		logger.SysError("error reading stream: " + err.Error())
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	render.Done(c)
 | 
			
		||||
 | 
			
		||||
	err := resp.Body.Close()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return nil, &usage
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
@@ -4,6 +4,11 @@ import (
 | 
			
		||||
	"bufio"
 | 
			
		||||
	"encoding/json"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"github.com/songquanpeng/one-api/common/render"
 | 
			
		||||
	"io"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"strings"
 | 
			
		||||
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
	"github.com/songquanpeng/one-api/common"
 | 
			
		||||
	"github.com/songquanpeng/one-api/common/conv"
 | 
			
		||||
@@ -12,9 +17,6 @@ import (
 | 
			
		||||
	"github.com/songquanpeng/one-api/relay/adaptor/coze/constant/messagetype"
 | 
			
		||||
	"github.com/songquanpeng/one-api/relay/adaptor/openai"
 | 
			
		||||
	"github.com/songquanpeng/one-api/relay/model"
 | 
			
		||||
	"io"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"strings"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// https://www.coze.com/open
 | 
			
		||||
@@ -109,69 +111,54 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC
 | 
			
		||||
	var responseText string
 | 
			
		||||
	createdTime := helper.GetTimestamp()
 | 
			
		||||
	scanner := bufio.NewScanner(resp.Body)
 | 
			
		||||
	scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
 | 
			
		||||
		if atEOF && len(data) == 0 {
 | 
			
		||||
			return 0, nil, nil
 | 
			
		||||
		}
 | 
			
		||||
		if i := strings.Index(string(data), "\n"); i >= 0 {
 | 
			
		||||
			return i + 1, data[0:i], nil
 | 
			
		||||
		}
 | 
			
		||||
		if atEOF {
 | 
			
		||||
			return len(data), data, nil
 | 
			
		||||
		}
 | 
			
		||||
		return 0, nil, nil
 | 
			
		||||
	})
 | 
			
		||||
	dataChan := make(chan string)
 | 
			
		||||
	stopChan := make(chan bool)
 | 
			
		||||
	go func() {
 | 
			
		||||
		for scanner.Scan() {
 | 
			
		||||
			data := scanner.Text()
 | 
			
		||||
			if len(data) < 5 {
 | 
			
		||||
				continue
 | 
			
		||||
			}
 | 
			
		||||
			if !strings.HasPrefix(data, "data:") {
 | 
			
		||||
				continue
 | 
			
		||||
			}
 | 
			
		||||
			data = strings.TrimPrefix(data, "data:")
 | 
			
		||||
			dataChan <- data
 | 
			
		||||
		}
 | 
			
		||||
		stopChan <- true
 | 
			
		||||
	}()
 | 
			
		||||
	scanner.Split(bufio.ScanLines)
 | 
			
		||||
 | 
			
		||||
	common.SetEventStreamHeaders(c)
 | 
			
		||||
	var modelName string
 | 
			
		||||
	c.Stream(func(w io.Writer) bool {
 | 
			
		||||
		select {
 | 
			
		||||
		case data := <-dataChan:
 | 
			
		||||
			// some implementations may add \r at the end of data
 | 
			
		||||
			data = strings.TrimSuffix(data, "\r")
 | 
			
		||||
			var cozeResponse StreamResponse
 | 
			
		||||
			err := json.Unmarshal([]byte(data), &cozeResponse)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				logger.SysError("error unmarshalling stream response: " + err.Error())
 | 
			
		||||
				return true
 | 
			
		||||
			}
 | 
			
		||||
			response, _ := StreamResponseCoze2OpenAI(&cozeResponse)
 | 
			
		||||
			if response == nil {
 | 
			
		||||
				return true
 | 
			
		||||
			}
 | 
			
		||||
			for _, choice := range response.Choices {
 | 
			
		||||
				responseText += conv.AsString(choice.Delta.Content)
 | 
			
		||||
			}
 | 
			
		||||
			response.Model = modelName
 | 
			
		||||
			response.Created = createdTime
 | 
			
		||||
			jsonStr, err := json.Marshal(response)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				logger.SysError("error marshalling stream response: " + err.Error())
 | 
			
		||||
				return true
 | 
			
		||||
			}
 | 
			
		||||
			c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonStr)})
 | 
			
		||||
			return true
 | 
			
		||||
		case <-stopChan:
 | 
			
		||||
			c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
 | 
			
		||||
			return false
 | 
			
		||||
 | 
			
		||||
	for scanner.Scan() {
 | 
			
		||||
		data := scanner.Text()
 | 
			
		||||
		if len(data) < 5 || !strings.HasPrefix(data, "data:") {
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
	})
 | 
			
		||||
	_ = resp.Body.Close()
 | 
			
		||||
		data = strings.TrimPrefix(data, "data:")
 | 
			
		||||
		data = strings.TrimSuffix(data, "\r")
 | 
			
		||||
 | 
			
		||||
		var cozeResponse StreamResponse
 | 
			
		||||
		err := json.Unmarshal([]byte(data), &cozeResponse)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			logger.SysError("error unmarshalling stream response: " + err.Error())
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		response, _ := StreamResponseCoze2OpenAI(&cozeResponse)
 | 
			
		||||
		if response == nil {
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		for _, choice := range response.Choices {
 | 
			
		||||
			responseText += conv.AsString(choice.Delta.Content)
 | 
			
		||||
		}
 | 
			
		||||
		response.Model = modelName
 | 
			
		||||
		response.Created = createdTime
 | 
			
		||||
 | 
			
		||||
		err = render.ObjectData(c, response)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			logger.SysError(err.Error())
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if err := scanner.Err(); err != nil {
 | 
			
		||||
		logger.SysError("error reading stream: " + err.Error())
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	render.Done(c)
 | 
			
		||||
 | 
			
		||||
	err := resp.Body.Close()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return nil, &responseText
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
@@ -4,6 +4,7 @@ import (
 | 
			
		||||
	"bufio"
 | 
			
		||||
	"encoding/json"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"github.com/songquanpeng/one-api/common/render"
 | 
			
		||||
	"io"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"strings"
 | 
			
		||||
@@ -275,64 +276,50 @@ func embeddingResponseGemini2OpenAI(response *EmbeddingResponse) *openai.Embeddi
 | 
			
		||||
func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, string) {
 | 
			
		||||
	responseText := ""
 | 
			
		||||
	scanner := bufio.NewScanner(resp.Body)
 | 
			
		||||
	scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
 | 
			
		||||
		if atEOF && len(data) == 0 {
 | 
			
		||||
			return 0, nil, nil
 | 
			
		||||
		}
 | 
			
		||||
		if i := strings.Index(string(data), "\n"); i >= 0 {
 | 
			
		||||
			return i + 1, data[0:i], nil
 | 
			
		||||
		}
 | 
			
		||||
		if atEOF {
 | 
			
		||||
			return len(data), data, nil
 | 
			
		||||
		}
 | 
			
		||||
		return 0, nil, nil
 | 
			
		||||
	})
 | 
			
		||||
	dataChan := make(chan string)
 | 
			
		||||
	stopChan := make(chan bool)
 | 
			
		||||
	go func() {
 | 
			
		||||
		for scanner.Scan() {
 | 
			
		||||
			data := scanner.Text()
 | 
			
		||||
			data = strings.TrimSpace(data)
 | 
			
		||||
			if !strings.HasPrefix(data, "data: ") {
 | 
			
		||||
				continue
 | 
			
		||||
			}
 | 
			
		||||
			data = strings.TrimPrefix(data, "data: ")
 | 
			
		||||
			data = strings.TrimSuffix(data, "\"")
 | 
			
		||||
			dataChan <- data
 | 
			
		||||
		}
 | 
			
		||||
		stopChan <- true
 | 
			
		||||
	}()
 | 
			
		||||
	scanner.Split(bufio.ScanLines)
 | 
			
		||||
 | 
			
		||||
	common.SetEventStreamHeaders(c)
 | 
			
		||||
	c.Stream(func(w io.Writer) bool {
 | 
			
		||||
		select {
 | 
			
		||||
		case data := <-dataChan:
 | 
			
		||||
			var geminiResponse ChatResponse
 | 
			
		||||
			err := json.Unmarshal([]byte(data), &geminiResponse)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				logger.SysError("error unmarshalling stream response: " + err.Error())
 | 
			
		||||
				return true
 | 
			
		||||
			}
 | 
			
		||||
			response := streamResponseGeminiChat2OpenAI(&geminiResponse)
 | 
			
		||||
			if response == nil {
 | 
			
		||||
				return true
 | 
			
		||||
			}
 | 
			
		||||
			responseText += response.Choices[0].Delta.StringContent()
 | 
			
		||||
			jsonResponse, err := json.Marshal(response)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				logger.SysError("error marshalling stream response: " + err.Error())
 | 
			
		||||
				return true
 | 
			
		||||
			}
 | 
			
		||||
			c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
 | 
			
		||||
			return true
 | 
			
		||||
		case <-stopChan:
 | 
			
		||||
			c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
 | 
			
		||||
			return false
 | 
			
		||||
 | 
			
		||||
	for scanner.Scan() {
 | 
			
		||||
		data := scanner.Text()
 | 
			
		||||
		data = strings.TrimSpace(data)
 | 
			
		||||
		if !strings.HasPrefix(data, "data: ") {
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
	})
 | 
			
		||||
		data = strings.TrimPrefix(data, "data: ")
 | 
			
		||||
		data = strings.TrimSuffix(data, "\"")
 | 
			
		||||
 | 
			
		||||
		var geminiResponse ChatResponse
 | 
			
		||||
		err := json.Unmarshal([]byte(data), &geminiResponse)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			logger.SysError("error unmarshalling stream response: " + err.Error())
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		response := streamResponseGeminiChat2OpenAI(&geminiResponse)
 | 
			
		||||
		if response == nil {
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		responseText += response.Choices[0].Delta.StringContent()
 | 
			
		||||
 | 
			
		||||
		err = render.ObjectData(c, response)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			logger.SysError(err.Error())
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if err := scanner.Err(); err != nil {
 | 
			
		||||
		logger.SysError("error reading stream: " + err.Error())
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	render.Done(c)
 | 
			
		||||
 | 
			
		||||
	err := resp.Body.Close()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), ""
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return nil, responseText
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
@@ -5,12 +5,14 @@ import (
 | 
			
		||||
	"context"
 | 
			
		||||
	"encoding/json"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"github.com/songquanpeng/one-api/common/helper"
 | 
			
		||||
	"github.com/songquanpeng/one-api/common/random"
 | 
			
		||||
	"github.com/songquanpeng/one-api/common/render"
 | 
			
		||||
	"io"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"strings"
 | 
			
		||||
 | 
			
		||||
	"github.com/songquanpeng/one-api/common/helper"
 | 
			
		||||
	"github.com/songquanpeng/one-api/common/random"
 | 
			
		||||
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
	"github.com/songquanpeng/one-api/common"
 | 
			
		||||
	"github.com/songquanpeng/one-api/common/image"
 | 
			
		||||
@@ -105,54 +107,51 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC
 | 
			
		||||
			return 0, nil, nil
 | 
			
		||||
		}
 | 
			
		||||
		if i := strings.Index(string(data), "}\n"); i >= 0 {
 | 
			
		||||
			return i + 2, data[0:i], nil
 | 
			
		||||
			return i + 2, data[0 : i+1], nil
 | 
			
		||||
		}
 | 
			
		||||
		if atEOF {
 | 
			
		||||
			return len(data), data, nil
 | 
			
		||||
		}
 | 
			
		||||
		return 0, nil, nil
 | 
			
		||||
	})
 | 
			
		||||
	dataChan := make(chan string)
 | 
			
		||||
	stopChan := make(chan bool)
 | 
			
		||||
	go func() {
 | 
			
		||||
		for scanner.Scan() {
 | 
			
		||||
			data := strings.TrimPrefix(scanner.Text(), "}")
 | 
			
		||||
			dataChan <- data + "}"
 | 
			
		||||
		}
 | 
			
		||||
		stopChan <- true
 | 
			
		||||
	}()
 | 
			
		||||
 | 
			
		||||
	common.SetEventStreamHeaders(c)
 | 
			
		||||
	c.Stream(func(w io.Writer) bool {
 | 
			
		||||
		select {
 | 
			
		||||
		case data := <-dataChan:
 | 
			
		||||
			var ollamaResponse ChatResponse
 | 
			
		||||
			err := json.Unmarshal([]byte(data), &ollamaResponse)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				logger.SysError("error unmarshalling stream response: " + err.Error())
 | 
			
		||||
				return true
 | 
			
		||||
			}
 | 
			
		||||
			if ollamaResponse.EvalCount != 0 {
 | 
			
		||||
				usage.PromptTokens = ollamaResponse.PromptEvalCount
 | 
			
		||||
				usage.CompletionTokens = ollamaResponse.EvalCount
 | 
			
		||||
				usage.TotalTokens = ollamaResponse.PromptEvalCount + ollamaResponse.EvalCount
 | 
			
		||||
			}
 | 
			
		||||
			response := streamResponseOllama2OpenAI(&ollamaResponse)
 | 
			
		||||
			jsonResponse, err := json.Marshal(response)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				logger.SysError("error marshalling stream response: " + err.Error())
 | 
			
		||||
				return true
 | 
			
		||||
			}
 | 
			
		||||
			c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
 | 
			
		||||
			return true
 | 
			
		||||
		case <-stopChan:
 | 
			
		||||
			c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
 | 
			
		||||
			return false
 | 
			
		||||
 | 
			
		||||
	for scanner.Scan() {
 | 
			
		||||
		data := strings.TrimPrefix(scanner.Text(), "}")
 | 
			
		||||
		data = data + "}"
 | 
			
		||||
 | 
			
		||||
		var ollamaResponse ChatResponse
 | 
			
		||||
		err := json.Unmarshal([]byte(data), &ollamaResponse)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			logger.SysError("error unmarshalling stream response: " + err.Error())
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
	})
 | 
			
		||||
 | 
			
		||||
		if ollamaResponse.EvalCount != 0 {
 | 
			
		||||
			usage.PromptTokens = ollamaResponse.PromptEvalCount
 | 
			
		||||
			usage.CompletionTokens = ollamaResponse.EvalCount
 | 
			
		||||
			usage.TotalTokens = ollamaResponse.PromptEvalCount + ollamaResponse.EvalCount
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		response := streamResponseOllama2OpenAI(&ollamaResponse)
 | 
			
		||||
		err = render.ObjectData(c, response)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			logger.SysError(err.Error())
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if err := scanner.Err(); err != nil {
 | 
			
		||||
		logger.SysError("error reading stream: " + err.Error())
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	render.Done(c)
 | 
			
		||||
 | 
			
		||||
	err := resp.Body.Close()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return nil, &usage
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
@@ -4,15 +4,17 @@ import (
 | 
			
		||||
	"bufio"
 | 
			
		||||
	"bytes"
 | 
			
		||||
	"encoding/json"
 | 
			
		||||
	"github.com/songquanpeng/one-api/common/render"
 | 
			
		||||
	"io"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"strings"
 | 
			
		||||
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
	"github.com/songquanpeng/one-api/common"
 | 
			
		||||
	"github.com/songquanpeng/one-api/common/conv"
 | 
			
		||||
	"github.com/songquanpeng/one-api/common/logger"
 | 
			
		||||
	"github.com/songquanpeng/one-api/relay/model"
 | 
			
		||||
	"github.com/songquanpeng/one-api/relay/relaymode"
 | 
			
		||||
	"io"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"strings"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
const (
 | 
			
		||||
@@ -24,88 +26,68 @@ const (
 | 
			
		||||
func StreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*model.ErrorWithStatusCode, string, *model.Usage) {
 | 
			
		||||
	responseText := ""
 | 
			
		||||
	scanner := bufio.NewScanner(resp.Body)
 | 
			
		||||
	scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
 | 
			
		||||
		if atEOF && len(data) == 0 {
 | 
			
		||||
			return 0, nil, nil
 | 
			
		||||
		}
 | 
			
		||||
		if i := strings.Index(string(data), "\n"); i >= 0 {
 | 
			
		||||
			return i + 1, data[0:i], nil
 | 
			
		||||
		}
 | 
			
		||||
		if atEOF {
 | 
			
		||||
			return len(data), data, nil
 | 
			
		||||
		}
 | 
			
		||||
		return 0, nil, nil
 | 
			
		||||
	})
 | 
			
		||||
	dataChan := make(chan string)
 | 
			
		||||
	stopChan := make(chan bool)
 | 
			
		||||
	scanner.Split(bufio.ScanLines)
 | 
			
		||||
	var usage *model.Usage
 | 
			
		||||
	go func() {
 | 
			
		||||
		for scanner.Scan() {
 | 
			
		||||
			data := scanner.Text()
 | 
			
		||||
			if len(data) < dataPrefixLength { // ignore blank line or wrong format
 | 
			
		||||
				continue
 | 
			
		||||
			}
 | 
			
		||||
			if data[:dataPrefixLength] != dataPrefix && data[:dataPrefixLength] != done {
 | 
			
		||||
				continue
 | 
			
		||||
			}
 | 
			
		||||
			if strings.HasPrefix(data[dataPrefixLength:], done) {
 | 
			
		||||
				dataChan <- data
 | 
			
		||||
				continue
 | 
			
		||||
			}
 | 
			
		||||
			switch relayMode {
 | 
			
		||||
			case relaymode.ChatCompletions:
 | 
			
		||||
				var streamResponse ChatCompletionsStreamResponse
 | 
			
		||||
				err := json.Unmarshal([]byte(data[dataPrefixLength:]), &streamResponse)
 | 
			
		||||
				if err != nil {
 | 
			
		||||
					logger.SysError("error unmarshalling stream response: " + err.Error())
 | 
			
		||||
					dataChan <- data // if error happened, pass the data to client
 | 
			
		||||
					continue         // just ignore the error
 | 
			
		||||
				}
 | 
			
		||||
				if len(streamResponse.Choices) == 0 {
 | 
			
		||||
					// but for empty choice, we should not pass it to client, this is for azure
 | 
			
		||||
					continue // just ignore empty choice
 | 
			
		||||
				}
 | 
			
		||||
				dataChan <- data
 | 
			
		||||
				for _, choice := range streamResponse.Choices {
 | 
			
		||||
					responseText += conv.AsString(choice.Delta.Content)
 | 
			
		||||
				}
 | 
			
		||||
				if streamResponse.Usage != nil {
 | 
			
		||||
					usage = streamResponse.Usage
 | 
			
		||||
				}
 | 
			
		||||
			case relaymode.Completions:
 | 
			
		||||
				dataChan <- data
 | 
			
		||||
				var streamResponse CompletionsStreamResponse
 | 
			
		||||
				err := json.Unmarshal([]byte(data[dataPrefixLength:]), &streamResponse)
 | 
			
		||||
				if err != nil {
 | 
			
		||||
					logger.SysError("error unmarshalling stream response: " + err.Error())
 | 
			
		||||
					continue
 | 
			
		||||
				}
 | 
			
		||||
				for _, choice := range streamResponse.Choices {
 | 
			
		||||
					responseText += choice.Text
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
		stopChan <- true
 | 
			
		||||
	}()
 | 
			
		||||
 | 
			
		||||
	common.SetEventStreamHeaders(c)
 | 
			
		||||
	c.Stream(func(w io.Writer) bool {
 | 
			
		||||
		select {
 | 
			
		||||
		case data := <-dataChan:
 | 
			
		||||
			if strings.HasPrefix(data, "data: [DONE]") {
 | 
			
		||||
				data = data[:12]
 | 
			
		||||
			}
 | 
			
		||||
			// some implementations may add \r at the end of data
 | 
			
		||||
			data = strings.TrimSuffix(data, "\r")
 | 
			
		||||
			c.Render(-1, common.CustomEvent{Data: data})
 | 
			
		||||
			return true
 | 
			
		||||
		case <-stopChan:
 | 
			
		||||
			return false
 | 
			
		||||
 | 
			
		||||
	for scanner.Scan() {
 | 
			
		||||
		data := scanner.Text()
 | 
			
		||||
		if len(data) < dataPrefixLength { // ignore blank line or wrong format
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
	})
 | 
			
		||||
		if data[:dataPrefixLength] != dataPrefix && data[:dataPrefixLength] != done {
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
		if strings.HasPrefix(data[dataPrefixLength:], done) {
 | 
			
		||||
			render.StringData(c, data)
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
		switch relayMode {
 | 
			
		||||
		case relaymode.ChatCompletions:
 | 
			
		||||
			var streamResponse ChatCompletionsStreamResponse
 | 
			
		||||
			err := json.Unmarshal([]byte(data[dataPrefixLength:]), &streamResponse)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				logger.SysError("error unmarshalling stream response: " + err.Error())
 | 
			
		||||
				render.StringData(c, data) // if error happened, pass the data to client
 | 
			
		||||
				continue                   // just ignore the error
 | 
			
		||||
			}
 | 
			
		||||
			if len(streamResponse.Choices) == 0 {
 | 
			
		||||
				// but for empty choice, we should not pass it to client, this is for azure
 | 
			
		||||
				continue // just ignore empty choice
 | 
			
		||||
			}
 | 
			
		||||
			render.StringData(c, data)
 | 
			
		||||
			for _, choice := range streamResponse.Choices {
 | 
			
		||||
				responseText += conv.AsString(choice.Delta.Content)
 | 
			
		||||
			}
 | 
			
		||||
			if streamResponse.Usage != nil {
 | 
			
		||||
				usage = streamResponse.Usage
 | 
			
		||||
			}
 | 
			
		||||
		case relaymode.Completions:
 | 
			
		||||
			render.StringData(c, data)
 | 
			
		||||
			var streamResponse CompletionsStreamResponse
 | 
			
		||||
			err := json.Unmarshal([]byte(data[dataPrefixLength:]), &streamResponse)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				logger.SysError("error unmarshalling stream response: " + err.Error())
 | 
			
		||||
				continue
 | 
			
		||||
			}
 | 
			
		||||
			for _, choice := range streamResponse.Choices {
 | 
			
		||||
				responseText += choice.Text
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if err := scanner.Err(); err != nil {
 | 
			
		||||
		logger.SysError("error reading stream: " + err.Error())
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	render.Done(c)
 | 
			
		||||
 | 
			
		||||
	err := resp.Body.Close()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), "", nil
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return nil, responseText, usage
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@@ -149,7 +131,7 @@ func Handler(c *gin.Context, resp *http.Response, promptTokens int, modelName st
 | 
			
		||||
		return ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if textResponse.Usage.TotalTokens == 0 {
 | 
			
		||||
	if textResponse.Usage.TotalTokens == 0 || (textResponse.Usage.PromptTokens == 0 && textResponse.Usage.CompletionTokens == 0) {
 | 
			
		||||
		completionTokens := 0
 | 
			
		||||
		for _, choice := range textResponse.Choices {
 | 
			
		||||
			completionTokens += CountTokenText(choice.Message.StringContent(), modelName)
 | 
			
		||||
 
 | 
			
		||||
@@ -3,6 +3,10 @@ package palm
 | 
			
		||||
import (
 | 
			
		||||
	"encoding/json"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"github.com/songquanpeng/one-api/common/render"
 | 
			
		||||
	"io"
 | 
			
		||||
	"net/http"
 | 
			
		||||
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
	"github.com/songquanpeng/one-api/common"
 | 
			
		||||
	"github.com/songquanpeng/one-api/common/helper"
 | 
			
		||||
@@ -11,8 +15,6 @@ import (
 | 
			
		||||
	"github.com/songquanpeng/one-api/relay/adaptor/openai"
 | 
			
		||||
	"github.com/songquanpeng/one-api/relay/constant"
 | 
			
		||||
	"github.com/songquanpeng/one-api/relay/model"
 | 
			
		||||
	"io"
 | 
			
		||||
	"net/http"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// https://developers.generativeai.google/api/rest/generativelanguage/models/generateMessage#request-body
 | 
			
		||||
@@ -77,58 +79,51 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC
 | 
			
		||||
	responseText := ""
 | 
			
		||||
	responseId := fmt.Sprintf("chatcmpl-%s", random.GetUUID())
 | 
			
		||||
	createdTime := helper.GetTimestamp()
 | 
			
		||||
	dataChan := make(chan string)
 | 
			
		||||
	stopChan := make(chan bool)
 | 
			
		||||
	go func() {
 | 
			
		||||
		responseBody, err := io.ReadAll(resp.Body)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			logger.SysError("error reading stream response: " + err.Error())
 | 
			
		||||
			stopChan <- true
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
		err = resp.Body.Close()
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			logger.SysError("error closing stream response: " + err.Error())
 | 
			
		||||
			stopChan <- true
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
		var palmResponse ChatResponse
 | 
			
		||||
		err = json.Unmarshal(responseBody, &palmResponse)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			logger.SysError("error unmarshalling stream response: " + err.Error())
 | 
			
		||||
			stopChan <- true
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
		fullTextResponse := streamResponsePaLM2OpenAI(&palmResponse)
 | 
			
		||||
		fullTextResponse.Id = responseId
 | 
			
		||||
		fullTextResponse.Created = createdTime
 | 
			
		||||
		if len(palmResponse.Candidates) > 0 {
 | 
			
		||||
			responseText = palmResponse.Candidates[0].Content
 | 
			
		||||
		}
 | 
			
		||||
		jsonResponse, err := json.Marshal(fullTextResponse)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			logger.SysError("error marshalling stream response: " + err.Error())
 | 
			
		||||
			stopChan <- true
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
		dataChan <- string(jsonResponse)
 | 
			
		||||
		stopChan <- true
 | 
			
		||||
	}()
 | 
			
		||||
 | 
			
		||||
	common.SetEventStreamHeaders(c)
 | 
			
		||||
	c.Stream(func(w io.Writer) bool {
 | 
			
		||||
		select {
 | 
			
		||||
		case data := <-dataChan:
 | 
			
		||||
			c.Render(-1, common.CustomEvent{Data: "data: " + data})
 | 
			
		||||
			return true
 | 
			
		||||
		case <-stopChan:
 | 
			
		||||
			c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
 | 
			
		||||
			return false
 | 
			
		||||
 | 
			
		||||
	responseBody, err := io.ReadAll(resp.Body)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		logger.SysError("error reading stream response: " + err.Error())
 | 
			
		||||
		err := resp.Body.Close()
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), ""
 | 
			
		||||
		}
 | 
			
		||||
	})
 | 
			
		||||
	err := resp.Body.Close()
 | 
			
		||||
		return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), ""
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	err = resp.Body.Close()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), ""
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var palmResponse ChatResponse
 | 
			
		||||
	err = json.Unmarshal(responseBody, &palmResponse)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		logger.SysError("error unmarshalling stream response: " + err.Error())
 | 
			
		||||
		return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), ""
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	fullTextResponse := streamResponsePaLM2OpenAI(&palmResponse)
 | 
			
		||||
	fullTextResponse.Id = responseId
 | 
			
		||||
	fullTextResponse.Created = createdTime
 | 
			
		||||
	if len(palmResponse.Candidates) > 0 {
 | 
			
		||||
		responseText = palmResponse.Candidates[0].Content
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	jsonResponse, err := json.Marshal(fullTextResponse)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		logger.SysError("error marshalling stream response: " + err.Error())
 | 
			
		||||
		return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), ""
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	err = render.ObjectData(c, string(jsonResponse))
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		logger.SysError(err.Error())
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	render.Done(c)
 | 
			
		||||
 | 
			
		||||
	return nil, responseText
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
@@ -8,6 +8,13 @@ import (
 | 
			
		||||
	"encoding/json"
 | 
			
		||||
	"errors"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"github.com/songquanpeng/one-api/common/render"
 | 
			
		||||
	"io"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"strconv"
 | 
			
		||||
	"strings"
 | 
			
		||||
	"time"
 | 
			
		||||
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
	"github.com/songquanpeng/one-api/common"
 | 
			
		||||
	"github.com/songquanpeng/one-api/common/conv"
 | 
			
		||||
@@ -17,11 +24,6 @@ import (
 | 
			
		||||
	"github.com/songquanpeng/one-api/relay/adaptor/openai"
 | 
			
		||||
	"github.com/songquanpeng/one-api/relay/constant"
 | 
			
		||||
	"github.com/songquanpeng/one-api/relay/model"
 | 
			
		||||
	"io"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"strconv"
 | 
			
		||||
	"strings"
 | 
			
		||||
	"time"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func ConvertRequest(request model.GeneralOpenAIRequest) *ChatRequest {
 | 
			
		||||
@@ -87,64 +89,46 @@ func streamResponseTencent2OpenAI(TencentResponse *ChatResponse) *openai.ChatCom
 | 
			
		||||
func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, string) {
 | 
			
		||||
	var responseText string
 | 
			
		||||
	scanner := bufio.NewScanner(resp.Body)
 | 
			
		||||
	scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
 | 
			
		||||
		if atEOF && len(data) == 0 {
 | 
			
		||||
			return 0, nil, nil
 | 
			
		||||
		}
 | 
			
		||||
		if i := strings.Index(string(data), "\n"); i >= 0 {
 | 
			
		||||
			return i + 1, data[0:i], nil
 | 
			
		||||
		}
 | 
			
		||||
		if atEOF {
 | 
			
		||||
			return len(data), data, nil
 | 
			
		||||
		}
 | 
			
		||||
		return 0, nil, nil
 | 
			
		||||
	})
 | 
			
		||||
	dataChan := make(chan string)
 | 
			
		||||
	stopChan := make(chan bool)
 | 
			
		||||
	go func() {
 | 
			
		||||
		for scanner.Scan() {
 | 
			
		||||
			data := scanner.Text()
 | 
			
		||||
			if len(data) < 5 { // ignore blank line or wrong format
 | 
			
		||||
				continue
 | 
			
		||||
			}
 | 
			
		||||
			if data[:5] != "data:" {
 | 
			
		||||
				continue
 | 
			
		||||
			}
 | 
			
		||||
			data = data[5:]
 | 
			
		||||
			dataChan <- data
 | 
			
		||||
		}
 | 
			
		||||
		stopChan <- true
 | 
			
		||||
	}()
 | 
			
		||||
	scanner.Split(bufio.ScanLines)
 | 
			
		||||
 | 
			
		||||
	common.SetEventStreamHeaders(c)
 | 
			
		||||
	c.Stream(func(w io.Writer) bool {
 | 
			
		||||
		select {
 | 
			
		||||
		case data := <-dataChan:
 | 
			
		||||
			var TencentResponse ChatResponse
 | 
			
		||||
			err := json.Unmarshal([]byte(data), &TencentResponse)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				logger.SysError("error unmarshalling stream response: " + err.Error())
 | 
			
		||||
				return true
 | 
			
		||||
			}
 | 
			
		||||
			response := streamResponseTencent2OpenAI(&TencentResponse)
 | 
			
		||||
			if len(response.Choices) != 0 {
 | 
			
		||||
				responseText += conv.AsString(response.Choices[0].Delta.Content)
 | 
			
		||||
			}
 | 
			
		||||
			jsonResponse, err := json.Marshal(response)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				logger.SysError("error marshalling stream response: " + err.Error())
 | 
			
		||||
				return true
 | 
			
		||||
			}
 | 
			
		||||
			c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
 | 
			
		||||
			return true
 | 
			
		||||
		case <-stopChan:
 | 
			
		||||
			c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
 | 
			
		||||
			return false
 | 
			
		||||
 | 
			
		||||
	for scanner.Scan() {
 | 
			
		||||
		data := scanner.Text()
 | 
			
		||||
		if len(data) < 5 || !strings.HasPrefix(data, "data:") {
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
	})
 | 
			
		||||
		data = strings.TrimPrefix(data, "data:")
 | 
			
		||||
 | 
			
		||||
		var tencentResponse ChatResponse
 | 
			
		||||
		err := json.Unmarshal([]byte(data), &tencentResponse)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			logger.SysError("error unmarshalling stream response: " + err.Error())
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		response := streamResponseTencent2OpenAI(&tencentResponse)
 | 
			
		||||
		if len(response.Choices) != 0 {
 | 
			
		||||
			responseText += conv.AsString(response.Choices[0].Delta.Content)
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		err = render.ObjectData(c, response)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			logger.SysError(err.Error())
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if err := scanner.Err(); err != nil {
 | 
			
		||||
		logger.SysError("error reading stream: " + err.Error())
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	render.Done(c)
 | 
			
		||||
 | 
			
		||||
	err := resp.Body.Close()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), ""
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return nil, responseText
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
@@ -6,4 +6,5 @@ var ModelList = []string{
 | 
			
		||||
	"SparkDesk-v2.1",
 | 
			
		||||
	"SparkDesk-v3.1",
 | 
			
		||||
	"SparkDesk-v3.5",
 | 
			
		||||
	"SparkDesk-v4.0",
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -44,7 +44,7 @@ func requestOpenAI2Xunfei(request model.GeneralOpenAIRequest, xunfeiAppId string
 | 
			
		||||
	xunfeiRequest.Parameter.Chat.MaxTokens = request.MaxTokens
 | 
			
		||||
	xunfeiRequest.Payload.Message.Text = messages
 | 
			
		||||
 | 
			
		||||
	if strings.HasPrefix(domain, "generalv3") {
 | 
			
		||||
	if strings.HasPrefix(domain, "generalv3") || domain == "4.0Ultra" {
 | 
			
		||||
		functions := make([]model.Function, len(request.Tools))
 | 
			
		||||
		for i, tool := range request.Tools {
 | 
			
		||||
			functions[i] = tool.Function
 | 
			
		||||
@@ -290,6 +290,8 @@ func apiVersion2domain(apiVersion string) string {
 | 
			
		||||
		return "generalv3"
 | 
			
		||||
	case "v3.5":
 | 
			
		||||
		return "generalv3.5"
 | 
			
		||||
	case "v4.0":
 | 
			
		||||
		return "4.0Ultra"
 | 
			
		||||
	}
 | 
			
		||||
	return "general" + apiVersion
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -3,6 +3,13 @@ package zhipu
 | 
			
		||||
import (
 | 
			
		||||
	"bufio"
 | 
			
		||||
	"encoding/json"
 | 
			
		||||
	"github.com/songquanpeng/one-api/common/render"
 | 
			
		||||
	"io"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"strings"
 | 
			
		||||
	"sync"
 | 
			
		||||
	"time"
 | 
			
		||||
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
	"github.com/golang-jwt/jwt"
 | 
			
		||||
	"github.com/songquanpeng/one-api/common"
 | 
			
		||||
@@ -11,11 +18,6 @@ import (
 | 
			
		||||
	"github.com/songquanpeng/one-api/relay/adaptor/openai"
 | 
			
		||||
	"github.com/songquanpeng/one-api/relay/constant"
 | 
			
		||||
	"github.com/songquanpeng/one-api/relay/model"
 | 
			
		||||
	"io"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"strings"
 | 
			
		||||
	"sync"
 | 
			
		||||
	"time"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// https://open.bigmodel.cn/doc/api#chatglm_std
 | 
			
		||||
@@ -155,66 +157,55 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC
 | 
			
		||||
		}
 | 
			
		||||
		return 0, nil, nil
 | 
			
		||||
	})
 | 
			
		||||
	dataChan := make(chan string)
 | 
			
		||||
	metaChan := make(chan string)
 | 
			
		||||
	stopChan := make(chan bool)
 | 
			
		||||
	go func() {
 | 
			
		||||
		for scanner.Scan() {
 | 
			
		||||
			data := scanner.Text()
 | 
			
		||||
			lines := strings.Split(data, "\n")
 | 
			
		||||
			for i, line := range lines {
 | 
			
		||||
				if len(line) < 5 {
 | 
			
		||||
 | 
			
		||||
	common.SetEventStreamHeaders(c)
 | 
			
		||||
 | 
			
		||||
	for scanner.Scan() {
 | 
			
		||||
		data := scanner.Text()
 | 
			
		||||
		lines := strings.Split(data, "\n")
 | 
			
		||||
		for i, line := range lines {
 | 
			
		||||
			if len(line) < 5 {
 | 
			
		||||
				continue
 | 
			
		||||
			}
 | 
			
		||||
			if strings.HasPrefix(line, "data:") {
 | 
			
		||||
				dataSegment := line[5:]
 | 
			
		||||
				if i != len(lines)-1 {
 | 
			
		||||
					dataSegment += "\n"
 | 
			
		||||
				}
 | 
			
		||||
				response := streamResponseZhipu2OpenAI(dataSegment)
 | 
			
		||||
				err := render.ObjectData(c, response)
 | 
			
		||||
				if err != nil {
 | 
			
		||||
					logger.SysError("error marshalling stream response: " + err.Error())
 | 
			
		||||
				}
 | 
			
		||||
			} else if strings.HasPrefix(line, "meta:") {
 | 
			
		||||
				metaSegment := line[5:]
 | 
			
		||||
				var zhipuResponse StreamMetaResponse
 | 
			
		||||
				err := json.Unmarshal([]byte(metaSegment), &zhipuResponse)
 | 
			
		||||
				if err != nil {
 | 
			
		||||
					logger.SysError("error unmarshalling stream response: " + err.Error())
 | 
			
		||||
					continue
 | 
			
		||||
				}
 | 
			
		||||
				if line[:5] == "data:" {
 | 
			
		||||
					dataChan <- line[5:]
 | 
			
		||||
					if i != len(lines)-1 {
 | 
			
		||||
						dataChan <- "\n"
 | 
			
		||||
					}
 | 
			
		||||
				} else if line[:5] == "meta:" {
 | 
			
		||||
					metaChan <- line[5:]
 | 
			
		||||
				response, zhipuUsage := streamMetaResponseZhipu2OpenAI(&zhipuResponse)
 | 
			
		||||
				err = render.ObjectData(c, response)
 | 
			
		||||
				if err != nil {
 | 
			
		||||
					logger.SysError("error marshalling stream response: " + err.Error())
 | 
			
		||||
				}
 | 
			
		||||
				usage = zhipuUsage
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
		stopChan <- true
 | 
			
		||||
	}()
 | 
			
		||||
	common.SetEventStreamHeaders(c)
 | 
			
		||||
	c.Stream(func(w io.Writer) bool {
 | 
			
		||||
		select {
 | 
			
		||||
		case data := <-dataChan:
 | 
			
		||||
			response := streamResponseZhipu2OpenAI(data)
 | 
			
		||||
			jsonResponse, err := json.Marshal(response)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				logger.SysError("error marshalling stream response: " + err.Error())
 | 
			
		||||
				return true
 | 
			
		||||
			}
 | 
			
		||||
			c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
 | 
			
		||||
			return true
 | 
			
		||||
		case data := <-metaChan:
 | 
			
		||||
			var zhipuResponse StreamMetaResponse
 | 
			
		||||
			err := json.Unmarshal([]byte(data), &zhipuResponse)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				logger.SysError("error unmarshalling stream response: " + err.Error())
 | 
			
		||||
				return true
 | 
			
		||||
			}
 | 
			
		||||
			response, zhipuUsage := streamMetaResponseZhipu2OpenAI(&zhipuResponse)
 | 
			
		||||
			jsonResponse, err := json.Marshal(response)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				logger.SysError("error marshalling stream response: " + err.Error())
 | 
			
		||||
				return true
 | 
			
		||||
			}
 | 
			
		||||
			usage = zhipuUsage
 | 
			
		||||
			c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
 | 
			
		||||
			return true
 | 
			
		||||
		case <-stopChan:
 | 
			
		||||
			c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
 | 
			
		||||
			return false
 | 
			
		||||
		}
 | 
			
		||||
	})
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if err := scanner.Err(); err != nil {
 | 
			
		||||
		logger.SysError("error reading stream: " + err.Error())
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	render.Done(c)
 | 
			
		||||
 | 
			
		||||
	err := resp.Body.Close()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return nil, usage
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										16
									
								
								relay/adaptor_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										16
									
								
								relay/adaptor_test.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,16 @@
 | 
			
		||||
package relay
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	. "github.com/smartystreets/goconvey/convey"
 | 
			
		||||
	"github.com/songquanpeng/one-api/relay/apitype"
 | 
			
		||||
	"testing"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func TestGetAdaptor(t *testing.T) {
 | 
			
		||||
	Convey("get adaptor", t, func() {
 | 
			
		||||
		for i := 0; i < apitype.Dummy; i++ {
 | 
			
		||||
			a := GetAdaptor(i)
 | 
			
		||||
			So(a, ShouldNotBeNil)
 | 
			
		||||
		}
 | 
			
		||||
	})
 | 
			
		||||
}
 | 
			
		||||
@@ -70,12 +70,13 @@ var ModelRatio = map[string]float64{
 | 
			
		||||
	"dall-e-2":                0.02 * USD, // $0.016 - $0.020 / image
 | 
			
		||||
	"dall-e-3":                0.04 * USD, // $0.040 - $0.120 / image
 | 
			
		||||
	// https://www.anthropic.com/api#pricing
 | 
			
		||||
	"claude-instant-1.2":       0.8 / 1000 * USD,
 | 
			
		||||
	"claude-2.0":               8.0 / 1000 * USD,
 | 
			
		||||
	"claude-2.1":               8.0 / 1000 * USD,
 | 
			
		||||
	"claude-3-haiku-20240307":  0.25 / 1000 * USD,
 | 
			
		||||
	"claude-3-sonnet-20240229": 3.0 / 1000 * USD,
 | 
			
		||||
	"claude-3-opus-20240229":   15.0 / 1000 * USD,
 | 
			
		||||
	"claude-instant-1.2":         0.8 / 1000 * USD,
 | 
			
		||||
	"claude-2.0":                 8.0 / 1000 * USD,
 | 
			
		||||
	"claude-2.1":                 8.0 / 1000 * USD,
 | 
			
		||||
	"claude-3-haiku-20240307":    0.25 / 1000 * USD,
 | 
			
		||||
	"claude-3-sonnet-20240229":   3.0 / 1000 * USD,
 | 
			
		||||
	"claude-3-5-sonnet-20240620": 3.0 / 1000 * USD,
 | 
			
		||||
	"claude-3-opus-20240229":     15.0 / 1000 * USD,
 | 
			
		||||
	// https://cloud.baidu.com/doc/WENXINWORKSHOP/s/hlrk4akp7
 | 
			
		||||
	"ERNIE-4.0-8K":       0.120 * RMB,
 | 
			
		||||
	"ERNIE-3.5-8K":       0.012 * RMB,
 | 
			
		||||
@@ -124,6 +125,7 @@ var ModelRatio = map[string]float64{
 | 
			
		||||
	"SparkDesk-v2.1":            1.2858, // ¥0.018 / 1k tokens
 | 
			
		||||
	"SparkDesk-v3.1":            1.2858, // ¥0.018 / 1k tokens
 | 
			
		||||
	"SparkDesk-v3.5":            1.2858, // ¥0.018 / 1k tokens
 | 
			
		||||
	"SparkDesk-v4.0":            1.2858, // ¥0.018 / 1k tokens
 | 
			
		||||
	"360GPT_S2_V9":              0.8572, // ¥0.012 / 1k tokens
 | 
			
		||||
	"embedding-bert-512-v1":     0.0715, // ¥0.001 / 1k tokens
 | 
			
		||||
	"embedding_s1_v1":           0.0715, // ¥0.001 / 1k tokens
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										12
									
								
								relay/channeltype/url_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										12
									
								
								relay/channeltype/url_test.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,12 @@
 | 
			
		||||
package channeltype
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	. "github.com/smartystreets/goconvey/convey"
 | 
			
		||||
	"testing"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func TestChannelBaseURLs(t *testing.T) {
 | 
			
		||||
	Convey("channel base urls", t, func() {
 | 
			
		||||
		So(len(ChannelBaseURLs), ShouldEqual, Dummy)
 | 
			
		||||
	})
 | 
			
		||||
}
 | 
			
		||||
@@ -40,78 +40,6 @@ func getAndValidateTextRequest(c *gin.Context, relayMode int) (*relaymodel.Gener
 | 
			
		||||
	return textRequest, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func getImageRequest(c *gin.Context, relayMode int) (*relaymodel.ImageRequest, error) {
 | 
			
		||||
	imageRequest := &relaymodel.ImageRequest{}
 | 
			
		||||
	err := common.UnmarshalBodyReusable(c, imageRequest)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
	if imageRequest.N == 0 {
 | 
			
		||||
		imageRequest.N = 1
 | 
			
		||||
	}
 | 
			
		||||
	if imageRequest.Size == "" {
 | 
			
		||||
		imageRequest.Size = "1024x1024"
 | 
			
		||||
	}
 | 
			
		||||
	if imageRequest.Model == "" {
 | 
			
		||||
		imageRequest.Model = "dall-e-2"
 | 
			
		||||
	}
 | 
			
		||||
	return imageRequest, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func isValidImageSize(model string, size string) bool {
 | 
			
		||||
	if model == "cogview-3" {
 | 
			
		||||
		return true
 | 
			
		||||
	}
 | 
			
		||||
	_, ok := billingratio.ImageSizeRatios[model][size]
 | 
			
		||||
	return ok
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func getImageSizeRatio(model string, size string) float64 {
 | 
			
		||||
	ratio, ok := billingratio.ImageSizeRatios[model][size]
 | 
			
		||||
	if !ok {
 | 
			
		||||
		return 1
 | 
			
		||||
	}
 | 
			
		||||
	return ratio
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func validateImageRequest(imageRequest *relaymodel.ImageRequest, meta *meta.Meta) *relaymodel.ErrorWithStatusCode {
 | 
			
		||||
	// model validation
 | 
			
		||||
	hasValidSize := isValidImageSize(imageRequest.Model, imageRequest.Size)
 | 
			
		||||
	if !hasValidSize {
 | 
			
		||||
		return openai.ErrorWrapper(errors.New("size not supported for this image model"), "size_not_supported", http.StatusBadRequest)
 | 
			
		||||
	}
 | 
			
		||||
	// check prompt length
 | 
			
		||||
	if imageRequest.Prompt == "" {
 | 
			
		||||
		return openai.ErrorWrapper(errors.New("prompt is required"), "prompt_missing", http.StatusBadRequest)
 | 
			
		||||
	}
 | 
			
		||||
	if len(imageRequest.Prompt) > billingratio.ImagePromptLengthLimitations[imageRequest.Model] {
 | 
			
		||||
		return openai.ErrorWrapper(errors.New("prompt is too long"), "prompt_too_long", http.StatusBadRequest)
 | 
			
		||||
	}
 | 
			
		||||
	// Number of generated images validation
 | 
			
		||||
	if !isWithinRange(imageRequest.Model, imageRequest.N) {
 | 
			
		||||
		// channel not azure
 | 
			
		||||
		if meta.ChannelType != channeltype.Azure {
 | 
			
		||||
			return openai.ErrorWrapper(errors.New("invalid value of n"), "n_not_within_range", http.StatusBadRequest)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func getImageCostRatio(imageRequest *relaymodel.ImageRequest) (float64, error) {
 | 
			
		||||
	if imageRequest == nil {
 | 
			
		||||
		return 0, errors.New("imageRequest is nil")
 | 
			
		||||
	}
 | 
			
		||||
	imageCostRatio := getImageSizeRatio(imageRequest.Model, imageRequest.Size)
 | 
			
		||||
	if imageRequest.Quality == "hd" && imageRequest.Model == "dall-e-3" {
 | 
			
		||||
		if imageRequest.Size == "1024x1024" {
 | 
			
		||||
			imageCostRatio *= 2
 | 
			
		||||
		} else {
 | 
			
		||||
			imageCostRatio *= 1.5
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	return imageCostRatio, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func getPromptTokens(textRequest *relaymodel.GeneralOpenAIRequest, relayMode int) int {
 | 
			
		||||
	switch relayMode {
 | 
			
		||||
	case relaymode.ChatCompletions:
 | 
			
		||||
 
 | 
			
		||||
@@ -7,6 +7,7 @@ import (
 | 
			
		||||
	"errors"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
	"github.com/songquanpeng/one-api/common"
 | 
			
		||||
	"github.com/songquanpeng/one-api/common/ctxkey"
 | 
			
		||||
	"github.com/songquanpeng/one-api/common/logger"
 | 
			
		||||
	"github.com/songquanpeng/one-api/model"
 | 
			
		||||
@@ -20,13 +21,84 @@ import (
 | 
			
		||||
	"net/http"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func isWithinRange(element string, value int) bool {
 | 
			
		||||
	if _, ok := billingratio.ImageGenerationAmounts[element]; !ok {
 | 
			
		||||
		return false
 | 
			
		||||
func getImageRequest(c *gin.Context, relayMode int) (*relaymodel.ImageRequest, error) {
 | 
			
		||||
	imageRequest := &relaymodel.ImageRequest{}
 | 
			
		||||
	err := common.UnmarshalBodyReusable(c, imageRequest)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
	min := billingratio.ImageGenerationAmounts[element][0]
 | 
			
		||||
	max := billingratio.ImageGenerationAmounts[element][1]
 | 
			
		||||
	return value >= min && value <= max
 | 
			
		||||
	if imageRequest.N == 0 {
 | 
			
		||||
		imageRequest.N = 1
 | 
			
		||||
	}
 | 
			
		||||
	if imageRequest.Size == "" {
 | 
			
		||||
		imageRequest.Size = "1024x1024"
 | 
			
		||||
	}
 | 
			
		||||
	if imageRequest.Model == "" {
 | 
			
		||||
		imageRequest.Model = "dall-e-2"
 | 
			
		||||
	}
 | 
			
		||||
	return imageRequest, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func isValidImageSize(model string, size string) bool {
 | 
			
		||||
	if model == "cogview-3" || billingratio.ImageSizeRatios[model] == nil {
 | 
			
		||||
		return true
 | 
			
		||||
	}
 | 
			
		||||
	_, ok := billingratio.ImageSizeRatios[model][size]
 | 
			
		||||
	return ok
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func isValidImagePromptLength(model string, promptLength int) bool {
 | 
			
		||||
	maxPromptLength, ok := billingratio.ImagePromptLengthLimitations[model]
 | 
			
		||||
	return !ok || promptLength <= maxPromptLength
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func isWithinRange(element string, value int) bool {
 | 
			
		||||
	amounts, ok := billingratio.ImageGenerationAmounts[element]
 | 
			
		||||
	return !ok || (value >= amounts[0] && value <= amounts[1])
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func getImageSizeRatio(model string, size string) float64 {
 | 
			
		||||
	if ratio, ok := billingratio.ImageSizeRatios[model][size]; ok {
 | 
			
		||||
		return ratio
 | 
			
		||||
	}
 | 
			
		||||
	return 1
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func validateImageRequest(imageRequest *relaymodel.ImageRequest, meta *meta.Meta) *relaymodel.ErrorWithStatusCode {
 | 
			
		||||
	// check prompt length
 | 
			
		||||
	if imageRequest.Prompt == "" {
 | 
			
		||||
		return openai.ErrorWrapper(errors.New("prompt is required"), "prompt_missing", http.StatusBadRequest)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// model validation
 | 
			
		||||
	if !isValidImageSize(imageRequest.Model, imageRequest.Size) {
 | 
			
		||||
		return openai.ErrorWrapper(errors.New("size not supported for this image model"), "size_not_supported", http.StatusBadRequest)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if !isValidImagePromptLength(imageRequest.Model, len(imageRequest.Prompt)) {
 | 
			
		||||
		return openai.ErrorWrapper(errors.New("prompt is too long"), "prompt_too_long", http.StatusBadRequest)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Number of generated images validation
 | 
			
		||||
	if !isWithinRange(imageRequest.Model, imageRequest.N) {
 | 
			
		||||
		return openai.ErrorWrapper(errors.New("invalid value of n"), "n_not_within_range", http.StatusBadRequest)
 | 
			
		||||
	}
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func getImageCostRatio(imageRequest *relaymodel.ImageRequest) (float64, error) {
 | 
			
		||||
	if imageRequest == nil {
 | 
			
		||||
		return 0, errors.New("imageRequest is nil")
 | 
			
		||||
	}
 | 
			
		||||
	imageCostRatio := getImageSizeRatio(imageRequest.Model, imageRequest.Size)
 | 
			
		||||
	if imageRequest.Quality == "hd" && imageRequest.Model == "dall-e-3" {
 | 
			
		||||
		if imageRequest.Size == "1024x1024" {
 | 
			
		||||
			imageCostRatio *= 2
 | 
			
		||||
		} else {
 | 
			
		||||
			imageCostRatio *= 1.5
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	return imageCostRatio, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatusCode {
 | 
			
		||||
 
 | 
			
		||||
@@ -1,10 +1,11 @@
 | 
			
		||||
package model
 | 
			
		||||
 | 
			
		||||
type Message struct {
 | 
			
		||||
	Role      string  `json:"role,omitempty"`
 | 
			
		||||
	Content   any     `json:"content,omitempty"`
 | 
			
		||||
	Name      *string `json:"name,omitempty"`
 | 
			
		||||
	ToolCalls []Tool  `json:"tool_calls,omitempty"`
 | 
			
		||||
	Role       string  `json:"role,omitempty"`
 | 
			
		||||
	Content    any     `json:"content,omitempty"`
 | 
			
		||||
	Name       *string `json:"name,omitempty"`
 | 
			
		||||
	ToolCalls  []Tool  `json:"tool_calls,omitempty"`
 | 
			
		||||
	ToolCallId string  `json:"tool_call_id,omitempty"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (m Message) IsStringContent() bool {
 | 
			
		||||
 
 | 
			
		||||
@@ -2,13 +2,13 @@ package model
 | 
			
		||||
 | 
			
		||||
type Tool struct {
 | 
			
		||||
	Id       string   `json:"id,omitempty"`
 | 
			
		||||
	Type     string   `json:"type"`
 | 
			
		||||
	Type     string   `json:"type,omitempty"` // when splicing claude tools stream messages, it is empty
 | 
			
		||||
	Function Function `json:"function"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type Function struct {
 | 
			
		||||
	Description string `json:"description,omitempty"`
 | 
			
		||||
	Name        string `json:"name"`
 | 
			
		||||
	Name        string `json:"name,omitempty"`       // when splicing claude tools stream messages, it is empty
 | 
			
		||||
	Parameters  any    `json:"parameters,omitempty"` // request
 | 
			
		||||
	Arguments   any    `json:"arguments,omitempty"`  // response
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -63,7 +63,7 @@ const EditChannel = (props) => {
 | 
			
		||||
            let localModels = [];
 | 
			
		||||
            switch (value) {
 | 
			
		||||
                case 14:
 | 
			
		||||
                    localModels = ["claude-instant-1.2", "claude-2", "claude-2.0", "claude-2.1", "claude-3-opus-20240229", "claude-3-sonnet-20240229", "claude-3-haiku-20240307"];
 | 
			
		||||
                    localModels = ["claude-instant-1.2", "claude-2", "claude-2.0", "claude-2.1", "claude-3-opus-20240229", "claude-3-sonnet-20240229", "claude-3-haiku-20240307", "claude-3-5-sonnet-20240620"];
 | 
			
		||||
                    break;
 | 
			
		||||
                case 11:
 | 
			
		||||
                    localModels = ['PaLM-2'];
 | 
			
		||||
@@ -78,7 +78,7 @@ const EditChannel = (props) => {
 | 
			
		||||
                    localModels = ['chatglm_pro', 'chatglm_std', 'chatglm_lite'];
 | 
			
		||||
                    break;
 | 
			
		||||
                case 18:
 | 
			
		||||
                    localModels = ['SparkDesk', 'SparkDesk-v1.1', 'SparkDesk-v2.1', 'SparkDesk-v3.1', 'SparkDesk-v3.5'];
 | 
			
		||||
                    localModels = ['SparkDesk', 'SparkDesk-v1.1', 'SparkDesk-v2.1', 'SparkDesk-v3.1', 'SparkDesk-v3.5', 'SparkDesk-v4.0'];
 | 
			
		||||
                    break;
 | 
			
		||||
                case 19:
 | 
			
		||||
                    localModels = ['360GPT_S2_V9', 'embedding-bert-512-v1', 'embedding_s1_v1', 'semantic_similarity_s1_v1'];
 | 
			
		||||
 
 | 
			
		||||
@@ -91,7 +91,7 @@ const typeConfig = {
 | 
			
		||||
      other: '版本号'
 | 
			
		||||
    },
 | 
			
		||||
    input: {
 | 
			
		||||
      models: ['SparkDesk', 'SparkDesk-v1.1', 'SparkDesk-v2.1', 'SparkDesk-v3.1', 'SparkDesk-v3.5']
 | 
			
		||||
      models: ['SparkDesk', 'SparkDesk-v1.1', 'SparkDesk-v2.1', 'SparkDesk-v3.1', 'SparkDesk-v3.5', 'SparkDesk-v4.0']
 | 
			
		||||
    },
 | 
			
		||||
    prompt: {
 | 
			
		||||
      key: '按照如下格式输入:APPID|APISecret|APIKey',
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user