mirror of
				https://github.com/songquanpeng/one-api.git
				synced 2025-11-04 07:43:41 +08:00 
			
		
		
		
	Compare commits
	
		
			26 Commits
		
	
	
		
			v0.5.0-alp
			...
			v0.5.2-alp
		
	
	| Author | SHA1 | Date | |
|---|---|---|---|
| 
						 | 
					c9d2e42a9e | ||
| 
						 | 
					3fca6ff534 | ||
| 
						 | 
					8cbbeb784f | ||
| 
						 | 
					ec88c0c240 | ||
| 
						 | 
					065147b440 | ||
| 
						 | 
					fe8f216dd9 | ||
| 
						 | 
					b7d0616ae0 | ||
| 
						 | 
					ce9c8024a6 | ||
| 
						 | 
					8a866078b2 | ||
| 
						 | 
					3e81d8af45 | ||
| 
						 | 
					b8cb86c2c1 | ||
| 
						 | 
					f45d586400 | ||
| 
						 | 
					50dec03ff3 | ||
| 
						 | 
					f31d400b6f | ||
| 
						 | 
					130e6bfd83 | ||
| 
						 | 
					d1335ebc01 | ||
| 
						 | 
					e92da7928b | ||
| 
						 | 
					d1b6f492b6 | ||
| 
						 | 
					b9f6461dd4 | ||
| 
						 | 
					0a39521a3d | ||
| 
						 | 
					c134604cee | ||
| 
						 | 
					929e43ef81 | ||
| 
						 | 
					dce8bbe1ca | ||
| 
						 | 
					bc2f48b1f2 | ||
| 
						 | 
					889af8b2db | ||
| 
						 | 
					4eea096654 | 
							
								
								
									
										16
									
								
								README.en.md
									
									
									
									
									
								
							
							
						
						
									
										16
									
								
								README.en.md
									
									
									
									
									
								
							@@ -57,15 +57,13 @@ _✨ Access all LLM through the standard OpenAI API format, easy to deploy & use
 | 
			
		||||
> **Note**: The latest image pulled from Docker may be an `alpha` release. Specify the version manually if you require stability.
 | 
			
		||||
 | 
			
		||||
## Features
 | 
			
		||||
1. Supports multiple API access channels:
 | 
			
		||||
    + [x] Official OpenAI channel (support proxy configuration)
 | 
			
		||||
    + [x] **Azure OpenAI API**
 | 
			
		||||
    + [x] [API Distribute](https://api.gptjk.top/register?aff=QGxj)
 | 
			
		||||
    + [x] [OpenAI-SB](https://openai-sb.com)
 | 
			
		||||
    + [x] [API2D](https://api2d.com/r/197971)
 | 
			
		||||
    + [x] [OhMyGPT](https://aigptx.top?aff=uFpUl2Kf)
 | 
			
		||||
    + [x] [AI Proxy](https://aiproxy.io/?i=OneAPI) (invitation code: `OneAPI`)
 | 
			
		||||
    + [x] Custom channel: Various third-party proxy services not included in the list
 | 
			
		||||
1. Support for multiple large models:
 | 
			
		||||
   + [x] [OpenAI ChatGPT Series Models](https://platform.openai.com/docs/guides/gpt/chat-completions-api) (Supports [Azure OpenAI API](https://learn.microsoft.com/en-us/azure/ai-services/openai/reference))
 | 
			
		||||
   + [x] [Anthropic Claude Series Models](https://anthropic.com)
 | 
			
		||||
   + [x] [Google PaLM2 Series Models](https://developers.generativeai.google)
 | 
			
		||||
   + [x] [Baidu Wenxin Yiyuan Series Models](https://cloud.baidu.com/doc/WENXINWORKSHOP/index.html)
 | 
			
		||||
   + [x] [Alibaba Tongyi Qianwen Series Models](https://help.aliyun.com/document_detail/2400395.html)
 | 
			
		||||
   + [x] [Zhipu ChatGLM Series Models](https://bigmodel.cn)
 | 
			
		||||
2. Supports access to multiple channels through **load balancing**.
 | 
			
		||||
3. Supports **stream mode** that enables typewriter-like effect through stream transmission.
 | 
			
		||||
4. Supports **multi-machine deployment**. [See here](#multi-machine-deployment) for more details.
 | 
			
		||||
 
 | 
			
		||||
@@ -63,9 +63,10 @@ _✨ 通过标准的 OpenAI API 格式访问所有的大模型,开箱即用 
 | 
			
		||||
   + [x] [Anthropic Claude 系列模型](https://anthropic.com)
 | 
			
		||||
   + [x] [Google PaLM2 系列模型](https://developers.generativeai.google)
 | 
			
		||||
   + [x] [百度文心一言系列模型](https://cloud.baidu.com/doc/WENXINWORKSHOP/index.html)
 | 
			
		||||
   + [x] [阿里通义千问系列模型](https://help.aliyun.com/document_detail/2400395.html)
 | 
			
		||||
   + [x] [讯飞星火认知大模型](https://www.xfyun.cn/doc/spark/Web.html)
 | 
			
		||||
   + [x] [智谱 ChatGLM 系列模型](https://bigmodel.cn)
 | 
			
		||||
2. 支持配置镜像以及众多第三方代理服务:
 | 
			
		||||
   + [x] [API Distribute](https://api.gptjk.top/register?aff=QGxj)
 | 
			
		||||
   + [x] [OpenAI-SB](https://openai-sb.com)
 | 
			
		||||
   + [x] [API2D](https://api2d.com/r/197971)
 | 
			
		||||
   + [x] [OhMyGPT](https://aigptx.top?aff=uFpUl2Kf)
 | 
			
		||||
@@ -93,7 +94,7 @@ _✨ 通过标准的 OpenAI API 格式访问所有的大模型,开箱即用 
 | 
			
		||||
19. 支持通过系统访问令牌访问管理 API。
 | 
			
		||||
20. 支持 Cloudflare Turnstile 用户校验。
 | 
			
		||||
21. 支持用户管理,支持**多种用户登录注册方式**:
 | 
			
		||||
    + 邮箱登录注册以及通过邮箱进行密码重置。
 | 
			
		||||
    + 邮箱登录注册(支持注册邮箱白名单)以及通过邮箱进行密码重置。
 | 
			
		||||
    + [GitHub 开放授权](https://github.com/settings/applications/new)。
 | 
			
		||||
    + 微信公众号授权(需要额外部署 [WeChat Server](https://github.com/songquanpeng/wechat-server))。
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
@@ -42,6 +42,19 @@ var WeChatAuthEnabled = false
 | 
			
		||||
var TurnstileCheckEnabled = false
 | 
			
		||||
var RegisterEnabled = true
 | 
			
		||||
 | 
			
		||||
var EmailDomainRestrictionEnabled = false
 | 
			
		||||
var EmailDomainWhitelist = []string{
 | 
			
		||||
	"gmail.com",
 | 
			
		||||
	"163.com",
 | 
			
		||||
	"126.com",
 | 
			
		||||
	"qq.com",
 | 
			
		||||
	"outlook.com",
 | 
			
		||||
	"hotmail.com",
 | 
			
		||||
	"icloud.com",
 | 
			
		||||
	"yahoo.com",
 | 
			
		||||
	"foxmail.com",
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
var LogConsumeEnabled = true
 | 
			
		||||
 | 
			
		||||
var SMTPServer = ""
 | 
			
		||||
@@ -77,6 +90,8 @@ var IsMasterNode = os.Getenv("NODE_TYPE") != "slave"
 | 
			
		||||
var requestInterval, _ = strconv.Atoi(os.Getenv("POLLING_INTERVAL"))
 | 
			
		||||
var RequestInterval = time.Duration(requestInterval) * time.Second
 | 
			
		||||
 | 
			
		||||
var SyncFrequency = 10 * 60 // unit is second, will be overwritten by SYNC_FREQUENCY
 | 
			
		||||
 | 
			
		||||
const (
 | 
			
		||||
	RoleGuestUser  = 0
 | 
			
		||||
	RoleCommonUser = 1
 | 
			
		||||
@@ -154,24 +169,28 @@ const (
 | 
			
		||||
	ChannelTypeAnthropic = 14
 | 
			
		||||
	ChannelTypeBaidu     = 15
 | 
			
		||||
	ChannelTypeZhipu     = 16
 | 
			
		||||
	ChannelTypeAli       = 17
 | 
			
		||||
	ChannelTypeXunfei    = 18
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
var ChannelBaseURLs = []string{
 | 
			
		||||
	"",                              // 0
 | 
			
		||||
	"https://api.openai.com",        // 1
 | 
			
		||||
	"https://oa.api2d.net",          // 2
 | 
			
		||||
	"",                              // 3
 | 
			
		||||
	"https://api.closeai-proxy.xyz", // 4
 | 
			
		||||
	"https://api.openai-sb.com",     // 5
 | 
			
		||||
	"https://api.openaimax.com",     // 6
 | 
			
		||||
	"https://api.ohmygpt.com",       // 7
 | 
			
		||||
	"",                              // 8
 | 
			
		||||
	"https://api.caipacity.com",     // 9
 | 
			
		||||
	"https://api.aiproxy.io",        // 10
 | 
			
		||||
	"",                              // 11
 | 
			
		||||
	"https://api.api2gpt.com",       // 12
 | 
			
		||||
	"https://api.aigc2d.com",        // 13
 | 
			
		||||
	"https://api.anthropic.com",     // 14
 | 
			
		||||
	"https://aip.baidubce.com",      // 15
 | 
			
		||||
	"https://open.bigmodel.cn",      // 16
 | 
			
		||||
	"",                               // 0
 | 
			
		||||
	"https://api.openai.com",         // 1
 | 
			
		||||
	"https://oa.api2d.net",           // 2
 | 
			
		||||
	"",                               // 3
 | 
			
		||||
	"https://api.closeai-proxy.xyz",  // 4
 | 
			
		||||
	"https://api.openai-sb.com",      // 5
 | 
			
		||||
	"https://api.openaimax.com",      // 6
 | 
			
		||||
	"https://api.ohmygpt.com",        // 7
 | 
			
		||||
	"",                               // 8
 | 
			
		||||
	"https://api.caipacity.com",      // 9
 | 
			
		||||
	"https://api.aiproxy.io",         // 10
 | 
			
		||||
	"",                               // 11
 | 
			
		||||
	"https://api.api2gpt.com",        // 12
 | 
			
		||||
	"https://api.aigc2d.com",         // 13
 | 
			
		||||
	"https://api.anthropic.com",      // 14
 | 
			
		||||
	"https://aip.baidubce.com",       // 15
 | 
			
		||||
	"https://open.bigmodel.cn",       // 16
 | 
			
		||||
	"https://dashscope.aliyuncs.com", // 17
 | 
			
		||||
	"",                               // 18
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -42,10 +42,14 @@ var ModelRatio = map[string]float64{
 | 
			
		||||
	"claude-2":                30,
 | 
			
		||||
	"ERNIE-Bot":               0.8572, // ¥0.012 / 1k tokens
 | 
			
		||||
	"ERNIE-Bot-turbo":         0.5715, // ¥0.008 / 1k tokens
 | 
			
		||||
	"Embedding-V1":            0.1429, // ¥0.002 / 1k tokens
 | 
			
		||||
	"PaLM-2":                  1,
 | 
			
		||||
	"chatglm_pro":             0.7143, // ¥0.01 / 1k tokens
 | 
			
		||||
	"chatglm_std":             0.3572, // ¥0.005 / 1k tokens
 | 
			
		||||
	"chatglm_lite":            0.1429, // ¥0.002 / 1k tokens
 | 
			
		||||
	"qwen-v1":                 0.8572, // TBD: https://help.aliyun.com/document_detail/2399482.html?spm=a2c4g.2399482.0.0.1ad347feilAgag
 | 
			
		||||
	"qwen-plus-v1":            0.5715, // Same as above
 | 
			
		||||
	"SparkDesk":               0.8572, // TBD
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func ModelRatio2JSONString() string {
 | 
			
		||||
 
 | 
			
		||||
@@ -16,6 +16,16 @@ import (
 | 
			
		||||
 | 
			
		||||
func testChannel(channel *model.Channel, request ChatRequest) (error, *OpenAIError) {
 | 
			
		||||
	switch channel.Type {
 | 
			
		||||
	case common.ChannelTypePaLM:
 | 
			
		||||
		fallthrough
 | 
			
		||||
	case common.ChannelTypeAnthropic:
 | 
			
		||||
		fallthrough
 | 
			
		||||
	case common.ChannelTypeBaidu:
 | 
			
		||||
		fallthrough
 | 
			
		||||
	case common.ChannelTypeZhipu:
 | 
			
		||||
		fallthrough
 | 
			
		||||
	case common.ChannelTypeXunfei:
 | 
			
		||||
		return errors.New("该渠道类型当前版本不支持测试,请手动测试"), nil
 | 
			
		||||
	case common.ChannelTypeAzure:
 | 
			
		||||
		request.Model = "gpt-35-turbo"
 | 
			
		||||
	default:
 | 
			
		||||
 
 | 
			
		||||
@@ -3,10 +3,12 @@ package controller
 | 
			
		||||
import (
 | 
			
		||||
	"encoding/json"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"one-api/common"
 | 
			
		||||
	"one-api/model"
 | 
			
		||||
	"strings"
 | 
			
		||||
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func GetStatus(c *gin.Context) {
 | 
			
		||||
@@ -78,6 +80,22 @@ func SendEmailVerification(c *gin.Context) {
 | 
			
		||||
		})
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	if common.EmailDomainRestrictionEnabled {
 | 
			
		||||
		allowed := false
 | 
			
		||||
		for _, domain := range common.EmailDomainWhitelist {
 | 
			
		||||
			if strings.HasSuffix(email, "@"+domain) {
 | 
			
		||||
				allowed = true
 | 
			
		||||
				break
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
		if !allowed {
 | 
			
		||||
			c.JSON(http.StatusOK, gin.H{
 | 
			
		||||
				"success": false,
 | 
			
		||||
				"message": "管理员启用了邮箱域名白名单,您的邮箱地址的域名不在白名单中",
 | 
			
		||||
			})
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	if model.IsEmailAlreadyTaken(email) {
 | 
			
		||||
		c.JSON(http.StatusOK, gin.H{
 | 
			
		||||
			"success": false,
 | 
			
		||||
 
 | 
			
		||||
@@ -288,6 +288,15 @@ func init() {
 | 
			
		||||
			Root:       "ERNIE-Bot-turbo",
 | 
			
		||||
			Parent:     nil,
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			Id:         "Embedding-V1",
 | 
			
		||||
			Object:     "model",
 | 
			
		||||
			Created:    1677649963,
 | 
			
		||||
			OwnedBy:    "baidu",
 | 
			
		||||
			Permission: permission,
 | 
			
		||||
			Root:       "Embedding-V1",
 | 
			
		||||
			Parent:     nil,
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			Id:         "PaLM-2",
 | 
			
		||||
			Object:     "model",
 | 
			
		||||
@@ -324,6 +333,33 @@ func init() {
 | 
			
		||||
			Root:       "chatglm_lite",
 | 
			
		||||
			Parent:     nil,
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			Id:         "qwen-v1",
 | 
			
		||||
			Object:     "model",
 | 
			
		||||
			Created:    1677649963,
 | 
			
		||||
			OwnedBy:    "ali",
 | 
			
		||||
			Permission: permission,
 | 
			
		||||
			Root:       "qwen-v1",
 | 
			
		||||
			Parent:     nil,
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			Id:         "qwen-plus-v1",
 | 
			
		||||
			Object:     "model",
 | 
			
		||||
			Created:    1677649963,
 | 
			
		||||
			OwnedBy:    "ali",
 | 
			
		||||
			Permission: permission,
 | 
			
		||||
			Root:       "qwen-plus-v1",
 | 
			
		||||
			Parent:     nil,
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			Id:         "SparkDesk",
 | 
			
		||||
			Object:     "model",
 | 
			
		||||
			Created:    1677649963,
 | 
			
		||||
			OwnedBy:    "xunfei",
 | 
			
		||||
			Permission: permission,
 | 
			
		||||
			Root:       "SparkDesk",
 | 
			
		||||
			Parent:     nil,
 | 
			
		||||
		},
 | 
			
		||||
	}
 | 
			
		||||
	openAIModelsMap = make(map[string]OpenAIModels)
 | 
			
		||||
	for _, model := range openAIModels {
 | 
			
		||||
 
 | 
			
		||||
@@ -2,11 +2,12 @@ package controller
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"encoding/json"
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"one-api/common"
 | 
			
		||||
	"one-api/model"
 | 
			
		||||
	"strings"
 | 
			
		||||
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func GetOptions(c *gin.Context) {
 | 
			
		||||
@@ -49,6 +50,14 @@ func UpdateOption(c *gin.Context) {
 | 
			
		||||
			})
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
	case "EmailDomainRestrictionEnabled":
 | 
			
		||||
		if option.Value == "true" && len(common.EmailDomainWhitelist) == 0 {
 | 
			
		||||
			c.JSON(http.StatusOK, gin.H{
 | 
			
		||||
				"success": false,
 | 
			
		||||
				"message": "无法启用邮箱域名限制,请先填入限制的邮箱域名!",
 | 
			
		||||
			})
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
	case "WeChatAuthEnabled":
 | 
			
		||||
		if option.Value == "true" && common.WeChatServerAddress == "" {
 | 
			
		||||
			c.JSON(http.StatusOK, gin.H{
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										240
									
								
								controller/relay-ali.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										240
									
								
								controller/relay-ali.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,240 @@
 | 
			
		||||
package controller
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"bufio"
 | 
			
		||||
	"encoding/json"
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
	"io"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"one-api/common"
 | 
			
		||||
	"strings"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// https://help.aliyun.com/document_detail/613695.html?spm=a2c4g.2399480.0.0.1adb778fAdzP9w#341800c0f8w0r
 | 
			
		||||
 | 
			
		||||
type AliMessage struct {
 | 
			
		||||
	User string `json:"user"`
 | 
			
		||||
	Bot  string `json:"bot"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type AliInput struct {
 | 
			
		||||
	Prompt  string       `json:"prompt"`
 | 
			
		||||
	History []AliMessage `json:"history"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type AliParameters struct {
 | 
			
		||||
	TopP         float64 `json:"top_p,omitempty"`
 | 
			
		||||
	TopK         int     `json:"top_k,omitempty"`
 | 
			
		||||
	Seed         uint64  `json:"seed,omitempty"`
 | 
			
		||||
	EnableSearch bool    `json:"enable_search,omitempty"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type AliChatRequest struct {
 | 
			
		||||
	Model      string        `json:"model"`
 | 
			
		||||
	Input      AliInput      `json:"input"`
 | 
			
		||||
	Parameters AliParameters `json:"parameters,omitempty"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type AliError struct {
 | 
			
		||||
	Code      string `json:"code"`
 | 
			
		||||
	Message   string `json:"message"`
 | 
			
		||||
	RequestId string `json:"request_id"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type AliUsage struct {
 | 
			
		||||
	InputTokens  int `json:"input_tokens"`
 | 
			
		||||
	OutputTokens int `json:"output_tokens"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type AliOutput struct {
 | 
			
		||||
	Text         string `json:"text"`
 | 
			
		||||
	FinishReason string `json:"finish_reason"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type AliChatResponse struct {
 | 
			
		||||
	Output AliOutput `json:"output"`
 | 
			
		||||
	Usage  AliUsage  `json:"usage"`
 | 
			
		||||
	AliError
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func requestOpenAI2Ali(request GeneralOpenAIRequest) *AliChatRequest {
 | 
			
		||||
	messages := make([]AliMessage, 0, len(request.Messages))
 | 
			
		||||
	prompt := ""
 | 
			
		||||
	for i := 0; i < len(request.Messages); i++ {
 | 
			
		||||
		message := request.Messages[i]
 | 
			
		||||
		if message.Role == "system" {
 | 
			
		||||
			messages = append(messages, AliMessage{
 | 
			
		||||
				User: message.Content,
 | 
			
		||||
				Bot:  "Okay",
 | 
			
		||||
			})
 | 
			
		||||
			continue
 | 
			
		||||
		} else {
 | 
			
		||||
			if i == len(request.Messages)-1 {
 | 
			
		||||
				prompt = message.Content
 | 
			
		||||
				break
 | 
			
		||||
			}
 | 
			
		||||
			messages = append(messages, AliMessage{
 | 
			
		||||
				User: message.Content,
 | 
			
		||||
				Bot:  request.Messages[i+1].Content,
 | 
			
		||||
			})
 | 
			
		||||
			i++
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	return &AliChatRequest{
 | 
			
		||||
		Model: request.Model,
 | 
			
		||||
		Input: AliInput{
 | 
			
		||||
			Prompt:  prompt,
 | 
			
		||||
			History: messages,
 | 
			
		||||
		},
 | 
			
		||||
		//Parameters: AliParameters{  // ChatGPT's parameters are not compatible with Ali's
 | 
			
		||||
		//	TopP: request.TopP,
 | 
			
		||||
		//	TopK: 50,
 | 
			
		||||
		//	//Seed:         0,
 | 
			
		||||
		//	//EnableSearch: false,
 | 
			
		||||
		//},
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func responseAli2OpenAI(response *AliChatResponse) *OpenAITextResponse {
 | 
			
		||||
	choice := OpenAITextResponseChoice{
 | 
			
		||||
		Index: 0,
 | 
			
		||||
		Message: Message{
 | 
			
		||||
			Role:    "assistant",
 | 
			
		||||
			Content: response.Output.Text,
 | 
			
		||||
		},
 | 
			
		||||
		FinishReason: response.Output.FinishReason,
 | 
			
		||||
	}
 | 
			
		||||
	fullTextResponse := OpenAITextResponse{
 | 
			
		||||
		Id:      response.RequestId,
 | 
			
		||||
		Object:  "chat.completion",
 | 
			
		||||
		Created: common.GetTimestamp(),
 | 
			
		||||
		Choices: []OpenAITextResponseChoice{choice},
 | 
			
		||||
		Usage: Usage{
 | 
			
		||||
			PromptTokens:     response.Usage.InputTokens,
 | 
			
		||||
			CompletionTokens: response.Usage.OutputTokens,
 | 
			
		||||
			TotalTokens:      response.Usage.InputTokens + response.Usage.OutputTokens,
 | 
			
		||||
		},
 | 
			
		||||
	}
 | 
			
		||||
	return &fullTextResponse
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func streamResponseAli2OpenAI(aliResponse *AliChatResponse) *ChatCompletionsStreamResponse {
 | 
			
		||||
	var choice ChatCompletionsStreamResponseChoice
 | 
			
		||||
	choice.Delta.Content = aliResponse.Output.Text
 | 
			
		||||
	choice.FinishReason = aliResponse.Output.FinishReason
 | 
			
		||||
	response := ChatCompletionsStreamResponse{
 | 
			
		||||
		Id:      aliResponse.RequestId,
 | 
			
		||||
		Object:  "chat.completion.chunk",
 | 
			
		||||
		Created: common.GetTimestamp(),
 | 
			
		||||
		Model:   "ernie-bot",
 | 
			
		||||
		Choices: []ChatCompletionsStreamResponseChoice{choice},
 | 
			
		||||
	}
 | 
			
		||||
	return &response
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func aliStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) {
 | 
			
		||||
	var usage Usage
 | 
			
		||||
	scanner := bufio.NewScanner(resp.Body)
 | 
			
		||||
	scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
 | 
			
		||||
		if atEOF && len(data) == 0 {
 | 
			
		||||
			return 0, nil, nil
 | 
			
		||||
		}
 | 
			
		||||
		if i := strings.Index(string(data), "\n"); i >= 0 {
 | 
			
		||||
			return i + 1, data[0:i], nil
 | 
			
		||||
		}
 | 
			
		||||
		if atEOF {
 | 
			
		||||
			return len(data), data, nil
 | 
			
		||||
		}
 | 
			
		||||
		return 0, nil, nil
 | 
			
		||||
	})
 | 
			
		||||
	dataChan := make(chan string)
 | 
			
		||||
	stopChan := make(chan bool)
 | 
			
		||||
	go func() {
 | 
			
		||||
		for scanner.Scan() {
 | 
			
		||||
			data := scanner.Text()
 | 
			
		||||
			if len(data) < 5 { // ignore blank line or wrong format
 | 
			
		||||
				continue
 | 
			
		||||
			}
 | 
			
		||||
			if data[:5] != "data:" {
 | 
			
		||||
				continue
 | 
			
		||||
			}
 | 
			
		||||
			data = data[5:]
 | 
			
		||||
			dataChan <- data
 | 
			
		||||
		}
 | 
			
		||||
		stopChan <- true
 | 
			
		||||
	}()
 | 
			
		||||
	c.Writer.Header().Set("Content-Type", "text/event-stream")
 | 
			
		||||
	c.Writer.Header().Set("Cache-Control", "no-cache")
 | 
			
		||||
	c.Writer.Header().Set("Connection", "keep-alive")
 | 
			
		||||
	c.Writer.Header().Set("Transfer-Encoding", "chunked")
 | 
			
		||||
	c.Writer.Header().Set("X-Accel-Buffering", "no")
 | 
			
		||||
	lastResponseText := ""
 | 
			
		||||
	c.Stream(func(w io.Writer) bool {
 | 
			
		||||
		select {
 | 
			
		||||
		case data := <-dataChan:
 | 
			
		||||
			var aliResponse AliChatResponse
 | 
			
		||||
			err := json.Unmarshal([]byte(data), &aliResponse)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				common.SysError("error unmarshalling stream response: " + err.Error())
 | 
			
		||||
				return true
 | 
			
		||||
			}
 | 
			
		||||
			usage.PromptTokens += aliResponse.Usage.InputTokens
 | 
			
		||||
			usage.CompletionTokens += aliResponse.Usage.OutputTokens
 | 
			
		||||
			usage.TotalTokens += aliResponse.Usage.InputTokens + aliResponse.Usage.OutputTokens
 | 
			
		||||
			response := streamResponseAli2OpenAI(&aliResponse)
 | 
			
		||||
			response.Choices[0].Delta.Content = strings.TrimPrefix(response.Choices[0].Delta.Content, lastResponseText)
 | 
			
		||||
			lastResponseText = aliResponse.Output.Text
 | 
			
		||||
			jsonResponse, err := json.Marshal(response)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				common.SysError("error marshalling stream response: " + err.Error())
 | 
			
		||||
				return true
 | 
			
		||||
			}
 | 
			
		||||
			c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
 | 
			
		||||
			return true
 | 
			
		||||
		case <-stopChan:
 | 
			
		||||
			c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
 | 
			
		||||
			return false
 | 
			
		||||
		}
 | 
			
		||||
	})
 | 
			
		||||
	err := resp.Body.Close()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
 | 
			
		||||
	}
 | 
			
		||||
	return nil, &usage
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func aliHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) {
 | 
			
		||||
	var aliResponse AliChatResponse
 | 
			
		||||
	responseBody, err := io.ReadAll(resp.Body)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
 | 
			
		||||
	}
 | 
			
		||||
	err = resp.Body.Close()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
 | 
			
		||||
	}
 | 
			
		||||
	err = json.Unmarshal(responseBody, &aliResponse)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
 | 
			
		||||
	}
 | 
			
		||||
	if aliResponse.Code != "" {
 | 
			
		||||
		return &OpenAIErrorWithStatusCode{
 | 
			
		||||
			OpenAIError: OpenAIError{
 | 
			
		||||
				Message: aliResponse.Message,
 | 
			
		||||
				Type:    aliResponse.Code,
 | 
			
		||||
				Param:   aliResponse.RequestId,
 | 
			
		||||
				Code:    aliResponse.Code,
 | 
			
		||||
			},
 | 
			
		||||
			StatusCode: resp.StatusCode,
 | 
			
		||||
		}, nil
 | 
			
		||||
	}
 | 
			
		||||
	fullTextResponse := responseAli2OpenAI(&aliResponse)
 | 
			
		||||
	jsonResponse, err := json.Marshal(fullTextResponse)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
 | 
			
		||||
	}
 | 
			
		||||
	c.Writer.Header().Set("Content-Type", "application/json")
 | 
			
		||||
	c.Writer.WriteHeader(resp.StatusCode)
 | 
			
		||||
	_, err = c.Writer.Write(jsonResponse)
 | 
			
		||||
	return nil, &fullTextResponse.Usage
 | 
			
		||||
}
 | 
			
		||||
@@ -54,13 +54,43 @@ type BaiduChatStreamResponse struct {
 | 
			
		||||
	IsEnd      bool `json:"is_end"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type BaiduEmbeddingRequest struct {
 | 
			
		||||
	Input []string `json:"input"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type BaiduEmbeddingData struct {
 | 
			
		||||
	Object    string    `json:"object"`
 | 
			
		||||
	Embedding []float64 `json:"embedding"`
 | 
			
		||||
	Index     int       `json:"index"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type BaiduEmbeddingResponse struct {
 | 
			
		||||
	Id      string               `json:"id"`
 | 
			
		||||
	Object  string               `json:"object"`
 | 
			
		||||
	Created int64                `json:"created"`
 | 
			
		||||
	Data    []BaiduEmbeddingData `json:"data"`
 | 
			
		||||
	Usage   Usage                `json:"usage"`
 | 
			
		||||
	BaiduError
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func requestOpenAI2Baidu(request GeneralOpenAIRequest) *BaiduChatRequest {
 | 
			
		||||
	messages := make([]BaiduMessage, 0, len(request.Messages))
 | 
			
		||||
	for _, message := range request.Messages {
 | 
			
		||||
		messages = append(messages, BaiduMessage{
 | 
			
		||||
			Role:    message.Role,
 | 
			
		||||
			Content: message.Content,
 | 
			
		||||
		})
 | 
			
		||||
		if message.Role == "system" {
 | 
			
		||||
			messages = append(messages, BaiduMessage{
 | 
			
		||||
				Role:    "user",
 | 
			
		||||
				Content: message.Content,
 | 
			
		||||
			})
 | 
			
		||||
			messages = append(messages, BaiduMessage{
 | 
			
		||||
				Role:    "assistant",
 | 
			
		||||
				Content: "Okay",
 | 
			
		||||
			})
 | 
			
		||||
		} else {
 | 
			
		||||
			messages = append(messages, BaiduMessage{
 | 
			
		||||
				Role:    message.Role,
 | 
			
		||||
				Content: message.Content,
 | 
			
		||||
			})
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	return &BaiduChatRequest{
 | 
			
		||||
		Messages: messages,
 | 
			
		||||
@@ -101,6 +131,36 @@ func streamResponseBaidu2OpenAI(baiduResponse *BaiduChatStreamResponse) *ChatCom
 | 
			
		||||
	return &response
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func embeddingRequestOpenAI2Baidu(request GeneralOpenAIRequest) *BaiduEmbeddingRequest {
 | 
			
		||||
	baiduEmbeddingRequest := BaiduEmbeddingRequest{
 | 
			
		||||
		Input: nil,
 | 
			
		||||
	}
 | 
			
		||||
	switch request.Input.(type) {
 | 
			
		||||
	case string:
 | 
			
		||||
		baiduEmbeddingRequest.Input = []string{request.Input.(string)}
 | 
			
		||||
	case []string:
 | 
			
		||||
		baiduEmbeddingRequest.Input = request.Input.([]string)
 | 
			
		||||
	}
 | 
			
		||||
	return &baiduEmbeddingRequest
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func embeddingResponseBaidu2OpenAI(response *BaiduEmbeddingResponse) *OpenAIEmbeddingResponse {
 | 
			
		||||
	openAIEmbeddingResponse := OpenAIEmbeddingResponse{
 | 
			
		||||
		Object: "list",
 | 
			
		||||
		Data:   make([]OpenAIEmbeddingResponseItem, 0, len(response.Data)),
 | 
			
		||||
		Model:  "baidu-embedding",
 | 
			
		||||
		Usage:  response.Usage,
 | 
			
		||||
	}
 | 
			
		||||
	for _, item := range response.Data {
 | 
			
		||||
		openAIEmbeddingResponse.Data = append(openAIEmbeddingResponse.Data, OpenAIEmbeddingResponseItem{
 | 
			
		||||
			Object:    item.Object,
 | 
			
		||||
			Index:     item.Index,
 | 
			
		||||
			Embedding: item.Embedding,
 | 
			
		||||
		})
 | 
			
		||||
	}
 | 
			
		||||
	return &openAIEmbeddingResponse
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func baiduStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) {
 | 
			
		||||
	var usage Usage
 | 
			
		||||
	scanner := bufio.NewScanner(resp.Body)
 | 
			
		||||
@@ -201,3 +261,39 @@ func baiduHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCo
 | 
			
		||||
	_, err = c.Writer.Write(jsonResponse)
 | 
			
		||||
	return nil, &fullTextResponse.Usage
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func baiduEmbeddingHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) {
 | 
			
		||||
	var baiduResponse BaiduEmbeddingResponse
 | 
			
		||||
	responseBody, err := io.ReadAll(resp.Body)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
 | 
			
		||||
	}
 | 
			
		||||
	err = resp.Body.Close()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
 | 
			
		||||
	}
 | 
			
		||||
	err = json.Unmarshal(responseBody, &baiduResponse)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
 | 
			
		||||
	}
 | 
			
		||||
	if baiduResponse.ErrorMsg != "" {
 | 
			
		||||
		return &OpenAIErrorWithStatusCode{
 | 
			
		||||
			OpenAIError: OpenAIError{
 | 
			
		||||
				Message: baiduResponse.ErrorMsg,
 | 
			
		||||
				Type:    "baidu_error",
 | 
			
		||||
				Param:   "",
 | 
			
		||||
				Code:    baiduResponse.ErrorCode,
 | 
			
		||||
			},
 | 
			
		||||
			StatusCode: resp.StatusCode,
 | 
			
		||||
		}, nil
 | 
			
		||||
	}
 | 
			
		||||
	fullTextResponse := embeddingResponseBaidu2OpenAI(&baiduResponse)
 | 
			
		||||
	jsonResponse, err := json.Marshal(fullTextResponse)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
 | 
			
		||||
	}
 | 
			
		||||
	c.Writer.Header().Set("Content-Type", "application/json")
 | 
			
		||||
	c.Writer.WriteHeader(resp.StatusCode)
 | 
			
		||||
	_, err = c.Writer.Write(jsonResponse)
 | 
			
		||||
	return nil, &fullTextResponse.Usage
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -69,11 +69,11 @@ func requestOpenAI2Claude(textRequest GeneralOpenAIRequest) *ClaudeRequest {
 | 
			
		||||
			prompt += fmt.Sprintf("\n\nHuman: %s", message.Content)
 | 
			
		||||
		} else if message.Role == "assistant" {
 | 
			
		||||
			prompt += fmt.Sprintf("\n\nAssistant: %s", message.Content)
 | 
			
		||||
		} else {
 | 
			
		||||
			// ignore other roles
 | 
			
		||||
		} else if message.Role == "system" {
 | 
			
		||||
			prompt += fmt.Sprintf("\n\nSystem: %s", message.Content)
 | 
			
		||||
		}
 | 
			
		||||
		prompt += "\n\nAssistant:"
 | 
			
		||||
	}
 | 
			
		||||
	prompt += "\n\nAssistant:"
 | 
			
		||||
	claudeRequest.Prompt = prompt
 | 
			
		||||
	return &claudeRequest
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -34,6 +34,9 @@ func openaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*O
 | 
			
		||||
			if len(data) < 6 { // ignore blank line or wrong format
 | 
			
		||||
				continue
 | 
			
		||||
			}
 | 
			
		||||
			if data[:6] != "data: " && data[:6] != "[DONE]" {
 | 
			
		||||
				continue
 | 
			
		||||
			}
 | 
			
		||||
			dataChan <- data
 | 
			
		||||
			data = data[6:]
 | 
			
		||||
			if !strings.HasPrefix(data, "[DONE]") {
 | 
			
		||||
@@ -43,7 +46,7 @@ func openaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*O
 | 
			
		||||
					err := json.Unmarshal([]byte(data), &streamResponse)
 | 
			
		||||
					if err != nil {
 | 
			
		||||
						common.SysError("error unmarshalling stream response: " + err.Error())
 | 
			
		||||
						return
 | 
			
		||||
						continue // just ignore the error
 | 
			
		||||
					}
 | 
			
		||||
					for _, choice := range streamResponse.Choices {
 | 
			
		||||
						responseText += choice.Delta.Content
 | 
			
		||||
@@ -53,7 +56,7 @@ func openaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*O
 | 
			
		||||
					err := json.Unmarshal([]byte(data), &streamResponse)
 | 
			
		||||
					if err != nil {
 | 
			
		||||
						common.SysError("error unmarshalling stream response: " + err.Error())
 | 
			
		||||
						return
 | 
			
		||||
						continue
 | 
			
		||||
					}
 | 
			
		||||
					for _, choice := range streamResponse.Choices {
 | 
			
		||||
						responseText += choice.Text
 | 
			
		||||
 
 | 
			
		||||
@@ -20,6 +20,8 @@ const (
 | 
			
		||||
	APITypePaLM
 | 
			
		||||
	APITypeBaidu
 | 
			
		||||
	APITypeZhipu
 | 
			
		||||
	APITypeAli
 | 
			
		||||
	APITypeXunfei
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
var httpClient *http.Client
 | 
			
		||||
@@ -73,7 +75,7 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 | 
			
		||||
	// map model name
 | 
			
		||||
	modelMapping := c.GetString("model_mapping")
 | 
			
		||||
	isModelMapped := false
 | 
			
		||||
	if modelMapping != "" {
 | 
			
		||||
	if modelMapping != "" && modelMapping != "{}" {
 | 
			
		||||
		modelMap := make(map[string]string)
 | 
			
		||||
		err := json.Unmarshal([]byte(modelMapping), &modelMap)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
@@ -85,14 +87,19 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	apiType := APITypeOpenAI
 | 
			
		||||
	if strings.HasPrefix(textRequest.Model, "claude") {
 | 
			
		||||
	switch channelType {
 | 
			
		||||
	case common.ChannelTypeAnthropic:
 | 
			
		||||
		apiType = APITypeClaude
 | 
			
		||||
	} else if strings.HasPrefix(textRequest.Model, "ERNIE") {
 | 
			
		||||
	case common.ChannelTypeBaidu:
 | 
			
		||||
		apiType = APITypeBaidu
 | 
			
		||||
	} else if strings.HasPrefix(textRequest.Model, "PaLM") {
 | 
			
		||||
	case common.ChannelTypePaLM:
 | 
			
		||||
		apiType = APITypePaLM
 | 
			
		||||
	} else if strings.HasPrefix(textRequest.Model, "chatglm_") {
 | 
			
		||||
	case common.ChannelTypeZhipu:
 | 
			
		||||
		apiType = APITypeZhipu
 | 
			
		||||
	case common.ChannelTypeAli:
 | 
			
		||||
		apiType = APITypeAli
 | 
			
		||||
	case common.ChannelTypeXunfei:
 | 
			
		||||
		apiType = APITypeXunfei
 | 
			
		||||
	}
 | 
			
		||||
	baseURL := common.ChannelBaseURLs[channelType]
 | 
			
		||||
	requestURL := c.Request.URL.String()
 | 
			
		||||
@@ -134,12 +141,17 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 | 
			
		||||
			fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/eb-instant"
 | 
			
		||||
		case "BLOOMZ-7B":
 | 
			
		||||
			fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/bloomz_7b1"
 | 
			
		||||
		case "Embedding-V1":
 | 
			
		||||
			fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/embeddings/embedding-v1"
 | 
			
		||||
		}
 | 
			
		||||
		apiKey := c.Request.Header.Get("Authorization")
 | 
			
		||||
		apiKey = strings.TrimPrefix(apiKey, "Bearer ")
 | 
			
		||||
		fullRequestURL += "?access_token=" + apiKey // TODO: access token expire in 30 days
 | 
			
		||||
	case APITypePaLM:
 | 
			
		||||
		fullRequestURL = "https://generativelanguage.googleapis.com/v1beta2/models/chat-bison-001:generateMessage"
 | 
			
		||||
		if baseURL != "" {
 | 
			
		||||
			fullRequestURL = fmt.Sprintf("%s/v1beta2/models/chat-bison-001:generateMessage", baseURL)
 | 
			
		||||
		}
 | 
			
		||||
		apiKey := c.Request.Header.Get("Authorization")
 | 
			
		||||
		apiKey = strings.TrimPrefix(apiKey, "Bearer ")
 | 
			
		||||
		fullRequestURL += "?key=" + apiKey
 | 
			
		||||
@@ -149,6 +161,8 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 | 
			
		||||
			method = "sse-invoke"
 | 
			
		||||
		}
 | 
			
		||||
		fullRequestURL = fmt.Sprintf("https://open.bigmodel.cn/api/paas/v3/model-api/%s/%s", textRequest.Model, method)
 | 
			
		||||
	case APITypeAli:
 | 
			
		||||
		fullRequestURL = "https://dashscope.aliyuncs.com/api/v1/services/aigc/text-generation/generation"
 | 
			
		||||
	}
 | 
			
		||||
	var promptTokens int
 | 
			
		||||
	var completionTokens int
 | 
			
		||||
@@ -202,12 +216,20 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 | 
			
		||||
		}
 | 
			
		||||
		requestBody = bytes.NewBuffer(jsonStr)
 | 
			
		||||
	case APITypeBaidu:
 | 
			
		||||
		baiduRequest := requestOpenAI2Baidu(textRequest)
 | 
			
		||||
		jsonStr, err := json.Marshal(baiduRequest)
 | 
			
		||||
		var jsonData []byte
 | 
			
		||||
		var err error
 | 
			
		||||
		switch relayMode {
 | 
			
		||||
		case RelayModeEmbeddings:
 | 
			
		||||
			baiduEmbeddingRequest := embeddingRequestOpenAI2Baidu(textRequest)
 | 
			
		||||
			jsonData, err = json.Marshal(baiduEmbeddingRequest)
 | 
			
		||||
		default:
 | 
			
		||||
			baiduRequest := requestOpenAI2Baidu(textRequest)
 | 
			
		||||
			jsonData, err = json.Marshal(baiduRequest)
 | 
			
		||||
		}
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
 | 
			
		||||
		}
 | 
			
		||||
		requestBody = bytes.NewBuffer(jsonStr)
 | 
			
		||||
		requestBody = bytes.NewBuffer(jsonData)
 | 
			
		||||
	case APITypePaLM:
 | 
			
		||||
		palmRequest := requestOpenAI2PaLM(textRequest)
 | 
			
		||||
		jsonStr, err := json.Marshal(palmRequest)
 | 
			
		||||
@@ -222,49 +244,68 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 | 
			
		||||
			return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
 | 
			
		||||
		}
 | 
			
		||||
		requestBody = bytes.NewBuffer(jsonStr)
 | 
			
		||||
	}
 | 
			
		||||
	req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return errorWrapper(err, "new_request_failed", http.StatusInternalServerError)
 | 
			
		||||
	}
 | 
			
		||||
	apiKey := c.Request.Header.Get("Authorization")
 | 
			
		||||
	apiKey = strings.TrimPrefix(apiKey, "Bearer ")
 | 
			
		||||
	switch apiType {
 | 
			
		||||
	case APITypeOpenAI:
 | 
			
		||||
		if channelType == common.ChannelTypeAzure {
 | 
			
		||||
			req.Header.Set("api-key", apiKey)
 | 
			
		||||
		} else {
 | 
			
		||||
			req.Header.Set("Authorization", c.Request.Header.Get("Authorization"))
 | 
			
		||||
	case APITypeAli:
 | 
			
		||||
		aliRequest := requestOpenAI2Ali(textRequest)
 | 
			
		||||
		jsonStr, err := json.Marshal(aliRequest)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
 | 
			
		||||
		}
 | 
			
		||||
	case APITypeClaude:
 | 
			
		||||
		req.Header.Set("x-api-key", apiKey)
 | 
			
		||||
		anthropicVersion := c.Request.Header.Get("anthropic-version")
 | 
			
		||||
		if anthropicVersion == "" {
 | 
			
		||||
			anthropicVersion = "2023-06-01"
 | 
			
		||||
		requestBody = bytes.NewBuffer(jsonStr)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var req *http.Request
 | 
			
		||||
	var resp *http.Response
 | 
			
		||||
	isStream := textRequest.Stream
 | 
			
		||||
 | 
			
		||||
	if apiType != APITypeXunfei { // cause xunfei use websocket
 | 
			
		||||
		req, err = http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return errorWrapper(err, "new_request_failed", http.StatusInternalServerError)
 | 
			
		||||
		}
 | 
			
		||||
		req.Header.Set("anthropic-version", anthropicVersion)
 | 
			
		||||
	case APITypeZhipu:
 | 
			
		||||
		token := getZhipuToken(apiKey)
 | 
			
		||||
		req.Header.Set("Authorization", token)
 | 
			
		||||
	}
 | 
			
		||||
	req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
 | 
			
		||||
	req.Header.Set("Accept", c.Request.Header.Get("Accept"))
 | 
			
		||||
	//req.Header.Set("Connection", c.Request.Header.Get("Connection"))
 | 
			
		||||
	resp, err := httpClient.Do(req)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return errorWrapper(err, "do_request_failed", http.StatusInternalServerError)
 | 
			
		||||
	}
 | 
			
		||||
	err = req.Body.Close()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return errorWrapper(err, "close_request_body_failed", http.StatusInternalServerError)
 | 
			
		||||
	}
 | 
			
		||||
	err = c.Request.Body.Close()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return errorWrapper(err, "close_request_body_failed", http.StatusInternalServerError)
 | 
			
		||||
		apiKey := c.Request.Header.Get("Authorization")
 | 
			
		||||
		apiKey = strings.TrimPrefix(apiKey, "Bearer ")
 | 
			
		||||
		switch apiType {
 | 
			
		||||
		case APITypeOpenAI:
 | 
			
		||||
			if channelType == common.ChannelTypeAzure {
 | 
			
		||||
				req.Header.Set("api-key", apiKey)
 | 
			
		||||
			} else {
 | 
			
		||||
				req.Header.Set("Authorization", c.Request.Header.Get("Authorization"))
 | 
			
		||||
			}
 | 
			
		||||
		case APITypeClaude:
 | 
			
		||||
			req.Header.Set("x-api-key", apiKey)
 | 
			
		||||
			anthropicVersion := c.Request.Header.Get("anthropic-version")
 | 
			
		||||
			if anthropicVersion == "" {
 | 
			
		||||
				anthropicVersion = "2023-06-01"
 | 
			
		||||
			}
 | 
			
		||||
			req.Header.Set("anthropic-version", anthropicVersion)
 | 
			
		||||
		case APITypeZhipu:
 | 
			
		||||
			token := getZhipuToken(apiKey)
 | 
			
		||||
			req.Header.Set("Authorization", token)
 | 
			
		||||
		case APITypeAli:
 | 
			
		||||
			req.Header.Set("Authorization", "Bearer "+apiKey)
 | 
			
		||||
			if textRequest.Stream {
 | 
			
		||||
				req.Header.Set("X-DashScope-SSE", "enable")
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
		req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
 | 
			
		||||
		req.Header.Set("Accept", c.Request.Header.Get("Accept"))
 | 
			
		||||
		//req.Header.Set("Connection", c.Request.Header.Get("Connection"))
 | 
			
		||||
		resp, err = httpClient.Do(req)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return errorWrapper(err, "do_request_failed", http.StatusInternalServerError)
 | 
			
		||||
		}
 | 
			
		||||
		err = req.Body.Close()
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return errorWrapper(err, "close_request_body_failed", http.StatusInternalServerError)
 | 
			
		||||
		}
 | 
			
		||||
		err = c.Request.Body.Close()
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return errorWrapper(err, "close_request_body_failed", http.StatusInternalServerError)
 | 
			
		||||
		}
 | 
			
		||||
		isStream = strings.HasPrefix(resp.Header.Get("Content-Type"), "text/event-stream")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var textResponse TextResponse
 | 
			
		||||
	isStream := strings.HasPrefix(resp.Header.Get("Content-Type"), "text/event-stream")
 | 
			
		||||
	var streamResponseText string
 | 
			
		||||
 | 
			
		||||
	defer func() {
 | 
			
		||||
		if consumeQuota {
 | 
			
		||||
@@ -276,16 +317,10 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 | 
			
		||||
			if strings.HasPrefix(textRequest.Model, "gpt-4") {
 | 
			
		||||
				completionRatio = 2
 | 
			
		||||
			}
 | 
			
		||||
			if isStream && apiType != APITypeBaidu && apiType != APITypeZhipu {
 | 
			
		||||
				completionTokens = countTokenText(streamResponseText, textRequest.Model)
 | 
			
		||||
			} else {
 | 
			
		||||
				promptTokens = textResponse.Usage.PromptTokens
 | 
			
		||||
				completionTokens = textResponse.Usage.CompletionTokens
 | 
			
		||||
				if apiType == APITypeZhipu {
 | 
			
		||||
					// zhipu's API does not return prompt tokens & completion tokens
 | 
			
		||||
					promptTokens = textResponse.Usage.TotalTokens
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			promptTokens = textResponse.Usage.PromptTokens
 | 
			
		||||
			completionTokens = textResponse.Usage.CompletionTokens
 | 
			
		||||
 | 
			
		||||
			quota = promptTokens + int(float64(completionTokens)*completionRatio)
 | 
			
		||||
			quota = int(float64(quota) * ratio)
 | 
			
		||||
			if ratio != 0 && quota <= 0 {
 | 
			
		||||
@@ -323,7 +358,8 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				return err
 | 
			
		||||
			}
 | 
			
		||||
			streamResponseText = responseText
 | 
			
		||||
			textResponse.Usage.PromptTokens = promptTokens
 | 
			
		||||
			textResponse.Usage.CompletionTokens = countTokenText(responseText, textRequest.Model)
 | 
			
		||||
			return nil
 | 
			
		||||
		} else {
 | 
			
		||||
			err, usage := openaiHandler(c, resp, consumeQuota)
 | 
			
		||||
@@ -341,7 +377,8 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				return err
 | 
			
		||||
			}
 | 
			
		||||
			streamResponseText = responseText
 | 
			
		||||
			textResponse.Usage.PromptTokens = promptTokens
 | 
			
		||||
			textResponse.Usage.CompletionTokens = countTokenText(responseText, textRequest.Model)
 | 
			
		||||
			return nil
 | 
			
		||||
		} else {
 | 
			
		||||
			err, usage := claudeHandler(c, resp, promptTokens, textRequest.Model)
 | 
			
		||||
@@ -364,7 +401,14 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 | 
			
		||||
			}
 | 
			
		||||
			return nil
 | 
			
		||||
		} else {
 | 
			
		||||
			err, usage := baiduHandler(c, resp)
 | 
			
		||||
			var err *OpenAIErrorWithStatusCode
 | 
			
		||||
			var usage *Usage
 | 
			
		||||
			switch relayMode {
 | 
			
		||||
			case RelayModeEmbeddings:
 | 
			
		||||
				err, usage = baiduEmbeddingHandler(c, resp)
 | 
			
		||||
			default:
 | 
			
		||||
				err, usage = baiduHandler(c, resp)
 | 
			
		||||
			}
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				return err
 | 
			
		||||
			}
 | 
			
		||||
@@ -379,7 +423,8 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				return err
 | 
			
		||||
			}
 | 
			
		||||
			streamResponseText = responseText
 | 
			
		||||
			textResponse.Usage.PromptTokens = promptTokens
 | 
			
		||||
			textResponse.Usage.CompletionTokens = countTokenText(responseText, textRequest.Model)
 | 
			
		||||
			return nil
 | 
			
		||||
		} else {
 | 
			
		||||
			err, usage := palmHandler(c, resp, promptTokens, textRequest.Model)
 | 
			
		||||
@@ -400,6 +445,8 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 | 
			
		||||
			if usage != nil {
 | 
			
		||||
				textResponse.Usage = *usage
 | 
			
		||||
			}
 | 
			
		||||
			// zhipu's API does not return prompt tokens & completion tokens
 | 
			
		||||
			textResponse.Usage.PromptTokens = textResponse.Usage.TotalTokens
 | 
			
		||||
			return nil
 | 
			
		||||
		} else {
 | 
			
		||||
			err, usage := zhipuHandler(c, resp)
 | 
			
		||||
@@ -409,8 +456,49 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 | 
			
		||||
			if usage != nil {
 | 
			
		||||
				textResponse.Usage = *usage
 | 
			
		||||
			}
 | 
			
		||||
			// zhipu's API does not return prompt tokens & completion tokens
 | 
			
		||||
			textResponse.Usage.PromptTokens = textResponse.Usage.TotalTokens
 | 
			
		||||
			return nil
 | 
			
		||||
		}
 | 
			
		||||
	case APITypeAli:
 | 
			
		||||
		if isStream {
 | 
			
		||||
			err, usage := aliStreamHandler(c, resp)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				return err
 | 
			
		||||
			}
 | 
			
		||||
			if usage != nil {
 | 
			
		||||
				textResponse.Usage = *usage
 | 
			
		||||
			}
 | 
			
		||||
			return nil
 | 
			
		||||
		} else {
 | 
			
		||||
			err, usage := aliHandler(c, resp)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				return err
 | 
			
		||||
			}
 | 
			
		||||
			if usage != nil {
 | 
			
		||||
				textResponse.Usage = *usage
 | 
			
		||||
			}
 | 
			
		||||
			return nil
 | 
			
		||||
		}
 | 
			
		||||
	case APITypeXunfei:
 | 
			
		||||
		if isStream {
 | 
			
		||||
			auth := c.Request.Header.Get("Authorization")
 | 
			
		||||
			auth = strings.TrimPrefix(auth, "Bearer ")
 | 
			
		||||
			splits := strings.Split(auth, "|")
 | 
			
		||||
			if len(splits) != 3 {
 | 
			
		||||
				return errorWrapper(errors.New("invalid auth"), "invalid_auth", http.StatusBadRequest)
 | 
			
		||||
			}
 | 
			
		||||
			err, usage := xunfeiStreamHandler(c, textRequest, splits[0], splits[1], splits[2])
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				return err
 | 
			
		||||
			}
 | 
			
		||||
			if usage != nil {
 | 
			
		||||
				textResponse.Usage = *usage
 | 
			
		||||
			}
 | 
			
		||||
			return nil
 | 
			
		||||
		} else {
 | 
			
		||||
			return errorWrapper(errors.New("xunfei api does not support non-stream mode"), "invalid_api_type", http.StatusBadRequest)
 | 
			
		||||
		}
 | 
			
		||||
	default:
 | 
			
		||||
		return errorWrapper(errors.New("unknown api type"), "unknown_api_type", http.StatusInternalServerError)
 | 
			
		||||
	}
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										278
									
								
								controller/relay-xunfei.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										278
									
								
								controller/relay-xunfei.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,278 @@
 | 
			
		||||
package controller
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"crypto/hmac"
 | 
			
		||||
	"crypto/sha256"
 | 
			
		||||
	"encoding/base64"
 | 
			
		||||
	"encoding/json"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
	"github.com/gorilla/websocket"
 | 
			
		||||
	"io"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"net/url"
 | 
			
		||||
	"one-api/common"
 | 
			
		||||
	"strings"
 | 
			
		||||
	"time"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// https://console.xfyun.cn/services/cbm
 | 
			
		||||
// https://www.xfyun.cn/doc/spark/Web.html
 | 
			
		||||
 | 
			
		||||
type XunfeiMessage struct {
 | 
			
		||||
	Role    string `json:"role"`
 | 
			
		||||
	Content string `json:"content"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type XunfeiChatRequest struct {
 | 
			
		||||
	Header struct {
 | 
			
		||||
		AppId string `json:"app_id"`
 | 
			
		||||
	} `json:"header"`
 | 
			
		||||
	Parameter struct {
 | 
			
		||||
		Chat struct {
 | 
			
		||||
			Domain      string  `json:"domain,omitempty"`
 | 
			
		||||
			Temperature float64 `json:"temperature,omitempty"`
 | 
			
		||||
			TopK        int     `json:"top_k,omitempty"`
 | 
			
		||||
			MaxTokens   int     `json:"max_tokens,omitempty"`
 | 
			
		||||
			Auditing    bool    `json:"auditing,omitempty"`
 | 
			
		||||
		} `json:"chat"`
 | 
			
		||||
	} `json:"parameter"`
 | 
			
		||||
	Payload struct {
 | 
			
		||||
		Message struct {
 | 
			
		||||
			Text []XunfeiMessage `json:"text"`
 | 
			
		||||
		} `json:"message"`
 | 
			
		||||
	} `json:"payload"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type XunfeiChatResponseTextItem struct {
 | 
			
		||||
	Content string `json:"content"`
 | 
			
		||||
	Role    string `json:"role"`
 | 
			
		||||
	Index   int    `json:"index"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type XunfeiChatResponse struct {
 | 
			
		||||
	Header struct {
 | 
			
		||||
		Code    int    `json:"code"`
 | 
			
		||||
		Message string `json:"message"`
 | 
			
		||||
		Sid     string `json:"sid"`
 | 
			
		||||
		Status  int    `json:"status"`
 | 
			
		||||
	} `json:"header"`
 | 
			
		||||
	Payload struct {
 | 
			
		||||
		Choices struct {
 | 
			
		||||
			Status int                          `json:"status"`
 | 
			
		||||
			Seq    int                          `json:"seq"`
 | 
			
		||||
			Text   []XunfeiChatResponseTextItem `json:"text"`
 | 
			
		||||
		} `json:"choices"`
 | 
			
		||||
	} `json:"payload"`
 | 
			
		||||
	Usage struct {
 | 
			
		||||
		//Text struct {
 | 
			
		||||
		//	QuestionTokens   string `json:"question_tokens"`
 | 
			
		||||
		//	PromptTokens     string `json:"prompt_tokens"`
 | 
			
		||||
		//	CompletionTokens string `json:"completion_tokens"`
 | 
			
		||||
		//	TotalTokens      string `json:"total_tokens"`
 | 
			
		||||
		//} `json:"text"`
 | 
			
		||||
		Text Usage `json:"text"`
 | 
			
		||||
	} `json:"usage"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func requestOpenAI2Xunfei(request GeneralOpenAIRequest, xunfeiAppId string) *XunfeiChatRequest {
 | 
			
		||||
	messages := make([]XunfeiMessage, 0, len(request.Messages))
 | 
			
		||||
	for _, message := range request.Messages {
 | 
			
		||||
		if message.Role == "system" {
 | 
			
		||||
			messages = append(messages, XunfeiMessage{
 | 
			
		||||
				Role:    "user",
 | 
			
		||||
				Content: message.Content,
 | 
			
		||||
			})
 | 
			
		||||
			messages = append(messages, XunfeiMessage{
 | 
			
		||||
				Role:    "assistant",
 | 
			
		||||
				Content: "Okay",
 | 
			
		||||
			})
 | 
			
		||||
		} else {
 | 
			
		||||
			messages = append(messages, XunfeiMessage{
 | 
			
		||||
				Role:    message.Role,
 | 
			
		||||
				Content: message.Content,
 | 
			
		||||
			})
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	xunfeiRequest := XunfeiChatRequest{}
 | 
			
		||||
	xunfeiRequest.Header.AppId = xunfeiAppId
 | 
			
		||||
	xunfeiRequest.Parameter.Chat.Domain = "general"
 | 
			
		||||
	xunfeiRequest.Parameter.Chat.Temperature = request.Temperature
 | 
			
		||||
	xunfeiRequest.Parameter.Chat.TopK = request.N
 | 
			
		||||
	xunfeiRequest.Parameter.Chat.MaxTokens = request.MaxTokens
 | 
			
		||||
	xunfeiRequest.Payload.Message.Text = messages
 | 
			
		||||
	return &xunfeiRequest
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func responseXunfei2OpenAI(response *XunfeiChatResponse) *OpenAITextResponse {
 | 
			
		||||
	if len(response.Payload.Choices.Text) == 0 {
 | 
			
		||||
		response.Payload.Choices.Text = []XunfeiChatResponseTextItem{
 | 
			
		||||
			{
 | 
			
		||||
				Content: "",
 | 
			
		||||
			},
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	choice := OpenAITextResponseChoice{
 | 
			
		||||
		Index: 0,
 | 
			
		||||
		Message: Message{
 | 
			
		||||
			Role:    "assistant",
 | 
			
		||||
			Content: response.Payload.Choices.Text[0].Content,
 | 
			
		||||
		},
 | 
			
		||||
	}
 | 
			
		||||
	fullTextResponse := OpenAITextResponse{
 | 
			
		||||
		Object:  "chat.completion",
 | 
			
		||||
		Created: common.GetTimestamp(),
 | 
			
		||||
		Choices: []OpenAITextResponseChoice{choice},
 | 
			
		||||
		Usage:   response.Usage.Text,
 | 
			
		||||
	}
 | 
			
		||||
	return &fullTextResponse
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func streamResponseXunfei2OpenAI(xunfeiResponse *XunfeiChatResponse) *ChatCompletionsStreamResponse {
 | 
			
		||||
	if len(xunfeiResponse.Payload.Choices.Text) == 0 {
 | 
			
		||||
		xunfeiResponse.Payload.Choices.Text = []XunfeiChatResponseTextItem{
 | 
			
		||||
			{
 | 
			
		||||
				Content: "",
 | 
			
		||||
			},
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	var choice ChatCompletionsStreamResponseChoice
 | 
			
		||||
	choice.Delta.Content = xunfeiResponse.Payload.Choices.Text[0].Content
 | 
			
		||||
	response := ChatCompletionsStreamResponse{
 | 
			
		||||
		Object:  "chat.completion.chunk",
 | 
			
		||||
		Created: common.GetTimestamp(),
 | 
			
		||||
		Model:   "SparkDesk",
 | 
			
		||||
		Choices: []ChatCompletionsStreamResponseChoice{choice},
 | 
			
		||||
	}
 | 
			
		||||
	return &response
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func buildXunfeiAuthUrl(hostUrl string, apiKey, apiSecret string) string {
 | 
			
		||||
	HmacWithShaToBase64 := func(algorithm, data, key string) string {
 | 
			
		||||
		mac := hmac.New(sha256.New, []byte(key))
 | 
			
		||||
		mac.Write([]byte(data))
 | 
			
		||||
		encodeData := mac.Sum(nil)
 | 
			
		||||
		return base64.StdEncoding.EncodeToString(encodeData)
 | 
			
		||||
	}
 | 
			
		||||
	ul, err := url.Parse(hostUrl)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		fmt.Println(err)
 | 
			
		||||
	}
 | 
			
		||||
	date := time.Now().UTC().Format(time.RFC1123)
 | 
			
		||||
	signString := []string{"host: " + ul.Host, "date: " + date, "GET " + ul.Path + " HTTP/1.1"}
 | 
			
		||||
	sign := strings.Join(signString, "\n")
 | 
			
		||||
	sha := HmacWithShaToBase64("hmac-sha256", sign, apiSecret)
 | 
			
		||||
	authUrl := fmt.Sprintf("hmac username=\"%s\", algorithm=\"%s\", headers=\"%s\", signature=\"%s\"", apiKey,
 | 
			
		||||
		"hmac-sha256", "host date request-line", sha)
 | 
			
		||||
	authorization := base64.StdEncoding.EncodeToString([]byte(authUrl))
 | 
			
		||||
	v := url.Values{}
 | 
			
		||||
	v.Add("host", ul.Host)
 | 
			
		||||
	v.Add("date", date)
 | 
			
		||||
	v.Add("authorization", authorization)
 | 
			
		||||
	callUrl := hostUrl + "?" + v.Encode()
 | 
			
		||||
	return callUrl
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func xunfeiStreamHandler(c *gin.Context, textRequest GeneralOpenAIRequest, appId string, apiSecret string, apiKey string) (*OpenAIErrorWithStatusCode, *Usage) {
 | 
			
		||||
	var usage Usage
 | 
			
		||||
	d := websocket.Dialer{
 | 
			
		||||
		HandshakeTimeout: 5 * time.Second,
 | 
			
		||||
	}
 | 
			
		||||
	hostUrl := "wss://aichat.xf-yun.com/v1/chat"
 | 
			
		||||
	conn, resp, err := d.Dial(buildXunfeiAuthUrl(hostUrl, apiKey, apiSecret), nil)
 | 
			
		||||
	if err != nil || resp.StatusCode != 101 {
 | 
			
		||||
		return errorWrapper(err, "dial_failed", http.StatusInternalServerError), nil
 | 
			
		||||
	}
 | 
			
		||||
	data := requestOpenAI2Xunfei(textRequest, appId)
 | 
			
		||||
	err = conn.WriteJSON(data)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return errorWrapper(err, "write_json_failed", http.StatusInternalServerError), nil
 | 
			
		||||
	}
 | 
			
		||||
	dataChan := make(chan XunfeiChatResponse)
 | 
			
		||||
	stopChan := make(chan bool)
 | 
			
		||||
	go func() {
 | 
			
		||||
		for {
 | 
			
		||||
			_, msg, err := conn.ReadMessage()
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				common.SysError("error reading stream response: " + err.Error())
 | 
			
		||||
				break
 | 
			
		||||
			}
 | 
			
		||||
			var response XunfeiChatResponse
 | 
			
		||||
			err = json.Unmarshal(msg, &response)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				common.SysError("error unmarshalling stream response: " + err.Error())
 | 
			
		||||
				break
 | 
			
		||||
			}
 | 
			
		||||
			dataChan <- response
 | 
			
		||||
			if response.Payload.Choices.Status == 2 {
 | 
			
		||||
				err := conn.Close()
 | 
			
		||||
				if err != nil {
 | 
			
		||||
					common.SysError("error closing websocket connection: " + err.Error())
 | 
			
		||||
				}
 | 
			
		||||
				break
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
		stopChan <- true
 | 
			
		||||
	}()
 | 
			
		||||
	c.Writer.Header().Set("Content-Type", "text/event-stream")
 | 
			
		||||
	c.Writer.Header().Set("Cache-Control", "no-cache")
 | 
			
		||||
	c.Writer.Header().Set("Connection", "keep-alive")
 | 
			
		||||
	c.Writer.Header().Set("Transfer-Encoding", "chunked")
 | 
			
		||||
	c.Writer.Header().Set("X-Accel-Buffering", "no")
 | 
			
		||||
	c.Stream(func(w io.Writer) bool {
 | 
			
		||||
		select {
 | 
			
		||||
		case xunfeiResponse := <-dataChan:
 | 
			
		||||
			usage.PromptTokens += xunfeiResponse.Usage.Text.PromptTokens
 | 
			
		||||
			usage.CompletionTokens += xunfeiResponse.Usage.Text.CompletionTokens
 | 
			
		||||
			usage.TotalTokens += xunfeiResponse.Usage.Text.TotalTokens
 | 
			
		||||
			response := streamResponseXunfei2OpenAI(&xunfeiResponse)
 | 
			
		||||
			jsonResponse, err := json.Marshal(response)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				common.SysError("error marshalling stream response: " + err.Error())
 | 
			
		||||
				return true
 | 
			
		||||
			}
 | 
			
		||||
			c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
 | 
			
		||||
			return true
 | 
			
		||||
		case <-stopChan:
 | 
			
		||||
			c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
 | 
			
		||||
			return false
 | 
			
		||||
		}
 | 
			
		||||
	})
 | 
			
		||||
	return nil, &usage
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func xunfeiHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) {
 | 
			
		||||
	var xunfeiResponse XunfeiChatResponse
 | 
			
		||||
	responseBody, err := io.ReadAll(resp.Body)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
 | 
			
		||||
	}
 | 
			
		||||
	err = resp.Body.Close()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
 | 
			
		||||
	}
 | 
			
		||||
	err = json.Unmarshal(responseBody, &xunfeiResponse)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
 | 
			
		||||
	}
 | 
			
		||||
	if xunfeiResponse.Header.Code != 0 {
 | 
			
		||||
		return &OpenAIErrorWithStatusCode{
 | 
			
		||||
			OpenAIError: OpenAIError{
 | 
			
		||||
				Message: xunfeiResponse.Header.Message,
 | 
			
		||||
				Type:    "xunfei_error",
 | 
			
		||||
				Param:   "",
 | 
			
		||||
				Code:    xunfeiResponse.Header.Code,
 | 
			
		||||
			},
 | 
			
		||||
			StatusCode: resp.StatusCode,
 | 
			
		||||
		}, nil
 | 
			
		||||
	}
 | 
			
		||||
	fullTextResponse := responseXunfei2OpenAI(&xunfeiResponse)
 | 
			
		||||
	jsonResponse, err := json.Marshal(fullTextResponse)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
 | 
			
		||||
	}
 | 
			
		||||
	c.Writer.Header().Set("Content-Type", "application/json")
 | 
			
		||||
	c.Writer.WriteHeader(resp.StatusCode)
 | 
			
		||||
	_, err = c.Writer.Write(jsonResponse)
 | 
			
		||||
	return nil, &fullTextResponse.Usage
 | 
			
		||||
}
 | 
			
		||||
@@ -111,10 +111,21 @@ func getZhipuToken(apikey string) string {
 | 
			
		||||
func requestOpenAI2Zhipu(request GeneralOpenAIRequest) *ZhipuRequest {
 | 
			
		||||
	messages := make([]ZhipuMessage, 0, len(request.Messages))
 | 
			
		||||
	for _, message := range request.Messages {
 | 
			
		||||
		messages = append(messages, ZhipuMessage{
 | 
			
		||||
			Role:    message.Role,
 | 
			
		||||
			Content: message.Content,
 | 
			
		||||
		})
 | 
			
		||||
		if message.Role == "system" {
 | 
			
		||||
			messages = append(messages, ZhipuMessage{
 | 
			
		||||
				Role:    "system",
 | 
			
		||||
				Content: message.Content,
 | 
			
		||||
			})
 | 
			
		||||
			messages = append(messages, ZhipuMessage{
 | 
			
		||||
				Role:    "user",
 | 
			
		||||
				Content: "Okay",
 | 
			
		||||
			})
 | 
			
		||||
		} else {
 | 
			
		||||
			messages = append(messages, ZhipuMessage{
 | 
			
		||||
				Role:    message.Role,
 | 
			
		||||
				Content: message.Content,
 | 
			
		||||
			})
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	return &ZhipuRequest{
 | 
			
		||||
		Prompt:      messages,
 | 
			
		||||
 
 | 
			
		||||
@@ -99,6 +99,19 @@ type OpenAITextResponse struct {
 | 
			
		||||
	Usage   `json:"usage"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type OpenAIEmbeddingResponseItem struct {
 | 
			
		||||
	Object    string    `json:"object"`
 | 
			
		||||
	Index     int       `json:"index"`
 | 
			
		||||
	Embedding []float64 `json:"embedding"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type OpenAIEmbeddingResponse struct {
 | 
			
		||||
	Object string                        `json:"object"`
 | 
			
		||||
	Data   []OpenAIEmbeddingResponseItem `json:"data"`
 | 
			
		||||
	Model  string                        `json:"model"`
 | 
			
		||||
	Usage  `json:"usage"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type ImageResponse struct {
 | 
			
		||||
	Created int `json:"created"`
 | 
			
		||||
	Data    []struct {
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										1
									
								
								go.mod
									
									
									
									
									
								
							
							
						
						
									
										1
									
								
								go.mod
									
									
									
									
									
								
							@@ -13,6 +13,7 @@ require (
 | 
			
		||||
	github.com/go-redis/redis/v8 v8.11.5
 | 
			
		||||
	github.com/golang-jwt/jwt v3.2.2+incompatible
 | 
			
		||||
	github.com/google/uuid v1.3.0
 | 
			
		||||
	github.com/gorilla/websocket v1.5.0
 | 
			
		||||
	github.com/pkoukk/tiktoken-go v0.1.1
 | 
			
		||||
	golang.org/x/crypto v0.9.0
 | 
			
		||||
	gorm.io/driver/mysql v1.4.3
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										2
									
								
								go.sum
									
									
									
									
									
								
							
							
						
						
									
										2
									
								
								go.sum
									
									
									
									
									
								
							@@ -67,6 +67,8 @@ github.com/gorilla/securecookie v1.1.1 h1:miw7JPhV+b/lAHSXz4qd/nN9jRiAFV5FwjeKyC
 | 
			
		||||
github.com/gorilla/securecookie v1.1.1/go.mod h1:ra0sb63/xPlUeL+yeDciTfxMRAA+MP+HVt/4epWDjd4=
 | 
			
		||||
github.com/gorilla/sessions v1.2.1 h1:DHd3rPN5lE3Ts3D8rKkQ8x/0kqfeNmBAaiSi+o7FsgI=
 | 
			
		||||
github.com/gorilla/sessions v1.2.1/go.mod h1:dk2InVEVJ0sfLlnXv9EAgkf6ecYs/i80K/zI+bUmuGM=
 | 
			
		||||
github.com/gorilla/websocket v1.5.0 h1:PPwGk2jz7EePpoHN/+ClbZu8SPxiqlu12wZP/3sWmnc=
 | 
			
		||||
github.com/gorilla/websocket v1.5.0/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
 | 
			
		||||
github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E=
 | 
			
		||||
github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc=
 | 
			
		||||
github.com/jinzhu/now v1.1.4/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8=
 | 
			
		||||
 
 | 
			
		||||
@@ -503,5 +503,12 @@
 | 
			
		||||
  "请输入 AZURE_OPENAI_ENDPOINT": "Please enter AZURE_OPENAI_ENDPOINT",
 | 
			
		||||
  "请输入自定义渠道的 Base URL": "Please enter the Base URL of the custom channel",
 | 
			
		||||
  "Homepage URL 填": "Fill in the Homepage URL",
 | 
			
		||||
  "Authorization callback URL 填": "Fill in the Authorization callback URL"
 | 
			
		||||
  "Authorization callback URL 填": "Fill in the Authorization callback URL",
 | 
			
		||||
  "请为通道命名": "Please name the channel",
 | 
			
		||||
  "此项可选,用于修改请求体中的模型名称,为一个 JSON 字符串,键为请求中模型名称,值为要替换的模型名称,例如:": "This is optional, used to modify the model name in the request body, it's a JSON string, the key is the model name in the request, and the value is the model name to be replaced, for example:",
 | 
			
		||||
  "模型重定向": "Model redirection",
 | 
			
		||||
  "请输入渠道对应的鉴权密钥": "Please enter the authentication key corresponding to the channel",
 | 
			
		||||
  "注意,": "Note that, ",
 | 
			
		||||
  ",图片演示。": "related image demo.",
 | 
			
		||||
  "令牌创建成功,请在列表页面点击复制获取令牌!": "Token created successfully, please click copy on the list page to get the token!"
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										1
									
								
								main.go
									
									
									
									
									
								
							
							
						
						
									
										1
									
								
								main.go
									
									
									
									
									
								
							@@ -54,6 +54,7 @@ func main() {
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			common.FatalLog("failed to parse SYNC_FREQUENCY: " + err.Error())
 | 
			
		||||
		}
 | 
			
		||||
		common.SyncFrequency = frequency
 | 
			
		||||
		go model.SyncOptions(frequency)
 | 
			
		||||
		if common.RedisEnabled {
 | 
			
		||||
			go model.SyncChannelCache(frequency)
 | 
			
		||||
 
 | 
			
		||||
@@ -12,11 +12,11 @@ import (
 | 
			
		||||
	"time"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
const (
 | 
			
		||||
	TokenCacheSeconds         = 60 * 60
 | 
			
		||||
	UserId2GroupCacheSeconds  = 60 * 60
 | 
			
		||||
	UserId2QuotaCacheSeconds  = 10 * 60
 | 
			
		||||
	UserId2StatusCacheSeconds = 60 * 60
 | 
			
		||||
var (
 | 
			
		||||
	TokenCacheSeconds         = common.SyncFrequency
 | 
			
		||||
	UserId2GroupCacheSeconds  = common.SyncFrequency
 | 
			
		||||
	UserId2QuotaCacheSeconds  = common.SyncFrequency
 | 
			
		||||
	UserId2StatusCacheSeconds = common.SyncFrequency
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func CacheGetTokenByKey(key string) (*Token, error) {
 | 
			
		||||
@@ -35,7 +35,7 @@ func CacheGetTokenByKey(key string) (*Token, error) {
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return nil, err
 | 
			
		||||
		}
 | 
			
		||||
		err = common.RedisSet(fmt.Sprintf("token:%s", key), string(jsonBytes), TokenCacheSeconds*time.Second)
 | 
			
		||||
		err = common.RedisSet(fmt.Sprintf("token:%s", key), string(jsonBytes), time.Duration(TokenCacheSeconds)*time.Second)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			common.SysError("Redis set token error: " + err.Error())
 | 
			
		||||
		}
 | 
			
		||||
@@ -55,7 +55,7 @@ func CacheGetUserGroup(id int) (group string, err error) {
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return "", err
 | 
			
		||||
		}
 | 
			
		||||
		err = common.RedisSet(fmt.Sprintf("user_group:%d", id), group, UserId2GroupCacheSeconds*time.Second)
 | 
			
		||||
		err = common.RedisSet(fmt.Sprintf("user_group:%d", id), group, time.Duration(UserId2GroupCacheSeconds)*time.Second)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			common.SysError("Redis set user group error: " + err.Error())
 | 
			
		||||
		}
 | 
			
		||||
@@ -73,7 +73,7 @@ func CacheGetUserQuota(id int) (quota int, err error) {
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return 0, err
 | 
			
		||||
		}
 | 
			
		||||
		err = common.RedisSet(fmt.Sprintf("user_quota:%d", id), fmt.Sprintf("%d", quota), UserId2QuotaCacheSeconds*time.Second)
 | 
			
		||||
		err = common.RedisSet(fmt.Sprintf("user_quota:%d", id), fmt.Sprintf("%d", quota), time.Duration(UserId2QuotaCacheSeconds)*time.Second)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			common.SysError("Redis set user quota error: " + err.Error())
 | 
			
		||||
		}
 | 
			
		||||
@@ -91,7 +91,7 @@ func CacheUpdateUserQuota(id int) error {
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
	err = common.RedisSet(fmt.Sprintf("user_quota:%d", id), fmt.Sprintf("%d", quota), UserId2QuotaCacheSeconds*time.Second)
 | 
			
		||||
	err = common.RedisSet(fmt.Sprintf("user_quota:%d", id), fmt.Sprintf("%d", quota), time.Duration(UserId2QuotaCacheSeconds)*time.Second)
 | 
			
		||||
	return err
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@@ -106,7 +106,7 @@ func CacheIsUserEnabled(userId int) bool {
 | 
			
		||||
			status = common.UserStatusEnabled
 | 
			
		||||
		}
 | 
			
		||||
		enabled = fmt.Sprintf("%d", status)
 | 
			
		||||
		err = common.RedisSet(fmt.Sprintf("user_enabled:%d", userId), enabled, UserId2StatusCacheSeconds*time.Second)
 | 
			
		||||
		err = common.RedisSet(fmt.Sprintf("user_enabled:%d", userId), enabled, time.Duration(UserId2StatusCacheSeconds)*time.Second)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			common.SysError("Redis set user enabled error: " + err.Error())
 | 
			
		||||
		}
 | 
			
		||||
 
 | 
			
		||||
@@ -39,6 +39,8 @@ func InitOptionMap() {
 | 
			
		||||
	common.OptionMap["DisplayInCurrencyEnabled"] = strconv.FormatBool(common.DisplayInCurrencyEnabled)
 | 
			
		||||
	common.OptionMap["DisplayTokenStatEnabled"] = strconv.FormatBool(common.DisplayTokenStatEnabled)
 | 
			
		||||
	common.OptionMap["ChannelDisableThreshold"] = strconv.FormatFloat(common.ChannelDisableThreshold, 'f', -1, 64)
 | 
			
		||||
	common.OptionMap["EmailDomainRestrictionEnabled"] = strconv.FormatBool(common.EmailDomainRestrictionEnabled)
 | 
			
		||||
	common.OptionMap["EmailDomainWhitelist"] = strings.Join(common.EmailDomainWhitelist, ",")
 | 
			
		||||
	common.OptionMap["SMTPServer"] = ""
 | 
			
		||||
	common.OptionMap["SMTPFrom"] = ""
 | 
			
		||||
	common.OptionMap["SMTPPort"] = strconv.Itoa(common.SMTPPort)
 | 
			
		||||
@@ -141,6 +143,8 @@ func updateOptionMap(key string, value string) (err error) {
 | 
			
		||||
			common.TurnstileCheckEnabled = boolValue
 | 
			
		||||
		case "RegisterEnabled":
 | 
			
		||||
			common.RegisterEnabled = boolValue
 | 
			
		||||
		case "EmailDomainRestrictionEnabled":
 | 
			
		||||
			common.EmailDomainRestrictionEnabled = boolValue
 | 
			
		||||
		case "AutomaticDisableChannelEnabled":
 | 
			
		||||
			common.AutomaticDisableChannelEnabled = boolValue
 | 
			
		||||
		case "ApproximateTokenEnabled":
 | 
			
		||||
@@ -154,6 +158,8 @@ func updateOptionMap(key string, value string) (err error) {
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	switch key {
 | 
			
		||||
	case "EmailDomainWhitelist":
 | 
			
		||||
		common.EmailDomainWhitelist = strings.Split(value, ",")
 | 
			
		||||
	case "SMTPServer":
 | 
			
		||||
		common.SMTPServer = value
 | 
			
		||||
	case "SMTPPort":
 | 
			
		||||
 
 | 
			
		||||
@@ -51,20 +51,21 @@ func Redeem(key string, userId int) (quota int, err error) {
 | 
			
		||||
	redemption := &Redemption{}
 | 
			
		||||
 | 
			
		||||
	err = DB.Transaction(func(tx *gorm.DB) error {
 | 
			
		||||
		err := DB.Where("`key` = ?", key).First(redemption).Error
 | 
			
		||||
		err := tx.Set("gorm:query_option", "FOR UPDATE").Where("`key` = ?", key).First(redemption).Error
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return errors.New("无效的兑换码")
 | 
			
		||||
		}
 | 
			
		||||
		if redemption.Status != common.RedemptionCodeStatusEnabled {
 | 
			
		||||
			return errors.New("该兑换码已被使用")
 | 
			
		||||
		}
 | 
			
		||||
		err = DB.Model(&User{}).Where("id = ?", userId).Update("quota", gorm.Expr("quota + ?", redemption.Quota)).Error
 | 
			
		||||
		err = tx.Model(&User{}).Where("id = ?", userId).Update("quota", gorm.Expr("quota + ?", redemption.Quota)).Error
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return err
 | 
			
		||||
		}
 | 
			
		||||
		redemption.RedeemedTime = common.GetTimestamp()
 | 
			
		||||
		redemption.Status = common.RedemptionCodeStatusUsed
 | 
			
		||||
		return redemption.SelectUpdate()
 | 
			
		||||
		err = tx.Save(redemption).Error
 | 
			
		||||
		return err
 | 
			
		||||
	})
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return 0, errors.New("兑换失败," + err.Error())
 | 
			
		||||
 
 | 
			
		||||
@@ -12,7 +12,7 @@ func SetRelayRouter(router *gin.Engine) {
 | 
			
		||||
	modelsRouter := router.Group("/v1/models")
 | 
			
		||||
	modelsRouter.Use(middleware.TokenAuth())
 | 
			
		||||
	{
 | 
			
		||||
		modelsRouter.GET("/", controller.ListModels)
 | 
			
		||||
		modelsRouter.GET("", controller.ListModels)
 | 
			
		||||
		modelsRouter.GET("/:model", controller.RetrieveModel)
 | 
			
		||||
	}
 | 
			
		||||
	relayV1Router := router.Group("/v1")
 | 
			
		||||
 
 | 
			
		||||
@@ -363,9 +363,12 @@ const ChannelsTable = () => {
 | 
			
		||||
                  </Table.Cell>
 | 
			
		||||
                  <Table.Cell>
 | 
			
		||||
                    <Popup
 | 
			
		||||
                      content={channel.balance_updated_time ? renderTimestamp(channel.balance_updated_time) : '未更新'}
 | 
			
		||||
                      key={channel.id}
 | 
			
		||||
                      trigger={renderBalance(channel.type, channel.balance)}
 | 
			
		||||
                      trigger={<span onClick={() => {
 | 
			
		||||
                        updateChannelBalance(channel.id, channel.name, idx);
 | 
			
		||||
                      }} style={{ cursor: 'pointer' }}>
 | 
			
		||||
                      {renderBalance(channel.type, channel.balance)}
 | 
			
		||||
                    </span>}
 | 
			
		||||
                      content="点击更新"
 | 
			
		||||
                      basic
 | 
			
		||||
                    />
 | 
			
		||||
                  </Table.Cell>
 | 
			
		||||
@@ -380,16 +383,16 @@ const ChannelsTable = () => {
 | 
			
		||||
                      >
 | 
			
		||||
                        测试
 | 
			
		||||
                      </Button>
 | 
			
		||||
                      <Button
 | 
			
		||||
                        size={'small'}
 | 
			
		||||
                        positive
 | 
			
		||||
                        loading={updatingBalance}
 | 
			
		||||
                        onClick={() => {
 | 
			
		||||
                          updateChannelBalance(channel.id, channel.name, idx);
 | 
			
		||||
                        }}
 | 
			
		||||
                      >
 | 
			
		||||
                        更新余额
 | 
			
		||||
                      </Button>
 | 
			
		||||
                      {/*<Button*/}
 | 
			
		||||
                      {/*  size={'small'}*/}
 | 
			
		||||
                      {/*  positive*/}
 | 
			
		||||
                      {/*  loading={updatingBalance}*/}
 | 
			
		||||
                      {/*  onClick={() => {*/}
 | 
			
		||||
                      {/*    updateChannelBalance(channel.id, channel.name, idx);*/}
 | 
			
		||||
                      {/*  }}*/}
 | 
			
		||||
                      {/*>*/}
 | 
			
		||||
                      {/*  更新余额*/}
 | 
			
		||||
                      {/*</Button>*/}
 | 
			
		||||
                      <Popup
 | 
			
		||||
                        trigger={
 | 
			
		||||
                          <Button size='small' negative>
 | 
			
		||||
 
 | 
			
		||||
@@ -1,6 +1,6 @@
 | 
			
		||||
import React, { useEffect, useState } from 'react';
 | 
			
		||||
import { Divider, Form, Grid, Header, Message } from 'semantic-ui-react';
 | 
			
		||||
import { API, removeTrailingSlash, showError, verifyJSON } from '../helpers';
 | 
			
		||||
import { Button, Divider, Form, Grid, Header, Input, Message } from 'semantic-ui-react';
 | 
			
		||||
import { API, removeTrailingSlash, showError } from '../helpers';
 | 
			
		||||
 | 
			
		||||
const SystemSetting = () => {
 | 
			
		||||
  let [inputs, setInputs] = useState({
 | 
			
		||||
@@ -26,9 +26,13 @@ const SystemSetting = () => {
 | 
			
		||||
    TurnstileSiteKey: '',
 | 
			
		||||
    TurnstileSecretKey: '',
 | 
			
		||||
    RegisterEnabled: '',
 | 
			
		||||
    EmailDomainRestrictionEnabled: '',
 | 
			
		||||
    EmailDomainWhitelist: ''
 | 
			
		||||
  });
 | 
			
		||||
  const [originInputs, setOriginInputs] = useState({});
 | 
			
		||||
  let [loading, setLoading] = useState(false);
 | 
			
		||||
  const [EmailDomainWhitelist, setEmailDomainWhitelist] = useState([]);
 | 
			
		||||
  const [restrictedDomainInput, setRestrictedDomainInput] = useState('');
 | 
			
		||||
 | 
			
		||||
  const getOptions = async () => {
 | 
			
		||||
    const res = await API.get('/api/option/');
 | 
			
		||||
@@ -38,8 +42,15 @@ const SystemSetting = () => {
 | 
			
		||||
      data.forEach((item) => {
 | 
			
		||||
        newInputs[item.key] = item.value;
 | 
			
		||||
      });
 | 
			
		||||
      setInputs(newInputs);
 | 
			
		||||
      setInputs({
 | 
			
		||||
        ...newInputs,
 | 
			
		||||
        EmailDomainWhitelist: newInputs.EmailDomainWhitelist.split(',')
 | 
			
		||||
      });
 | 
			
		||||
      setOriginInputs(newInputs);
 | 
			
		||||
 | 
			
		||||
      setEmailDomainWhitelist(newInputs.EmailDomainWhitelist.split(',').map((item) => {
 | 
			
		||||
        return { key: item, text: item, value: item };
 | 
			
		||||
      }));
 | 
			
		||||
    } else {
 | 
			
		||||
      showError(message);
 | 
			
		||||
    }
 | 
			
		||||
@@ -58,6 +69,7 @@ const SystemSetting = () => {
 | 
			
		||||
      case 'GitHubOAuthEnabled':
 | 
			
		||||
      case 'WeChatAuthEnabled':
 | 
			
		||||
      case 'TurnstileCheckEnabled':
 | 
			
		||||
      case 'EmailDomainRestrictionEnabled':
 | 
			
		||||
      case 'RegisterEnabled':
 | 
			
		||||
        value = inputs[key] === 'true' ? 'false' : 'true';
 | 
			
		||||
        break;
 | 
			
		||||
@@ -70,7 +82,12 @@ const SystemSetting = () => {
 | 
			
		||||
    });
 | 
			
		||||
    const { success, message } = res.data;
 | 
			
		||||
    if (success) {
 | 
			
		||||
      setInputs((inputs) => ({ ...inputs, [key]: value }));
 | 
			
		||||
      if (key === 'EmailDomainWhitelist') {
 | 
			
		||||
        value = value.split(',');
 | 
			
		||||
      }
 | 
			
		||||
      setInputs((inputs) => ({
 | 
			
		||||
        ...inputs, [key]: value
 | 
			
		||||
      }));
 | 
			
		||||
    } else {
 | 
			
		||||
      showError(message);
 | 
			
		||||
    }
 | 
			
		||||
@@ -88,7 +105,8 @@ const SystemSetting = () => {
 | 
			
		||||
      name === 'WeChatServerToken' ||
 | 
			
		||||
      name === 'WeChatAccountQRCodeImageURL' ||
 | 
			
		||||
      name === 'TurnstileSiteKey' ||
 | 
			
		||||
      name === 'TurnstileSecretKey'
 | 
			
		||||
      name === 'TurnstileSecretKey' ||
 | 
			
		||||
      name === 'EmailDomainWhitelist'
 | 
			
		||||
    ) {
 | 
			
		||||
      setInputs((inputs) => ({ ...inputs, [name]: value }));
 | 
			
		||||
    } else {
 | 
			
		||||
@@ -125,6 +143,16 @@ const SystemSetting = () => {
 | 
			
		||||
    }
 | 
			
		||||
  };
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
  const submitEmailDomainWhitelist = async () => {
 | 
			
		||||
    if (
 | 
			
		||||
      originInputs['EmailDomainWhitelist'] !== inputs.EmailDomainWhitelist.join(',') &&
 | 
			
		||||
      inputs.SMTPToken !== ''
 | 
			
		||||
    ) {
 | 
			
		||||
      await updateOption('EmailDomainWhitelist', inputs.EmailDomainWhitelist.join(','));
 | 
			
		||||
    }
 | 
			
		||||
  };
 | 
			
		||||
 | 
			
		||||
  const submitWeChat = async () => {
 | 
			
		||||
    if (originInputs['WeChatServerAddress'] !== inputs.WeChatServerAddress) {
 | 
			
		||||
      await updateOption(
 | 
			
		||||
@@ -173,6 +201,22 @@ const SystemSetting = () => {
 | 
			
		||||
    }
 | 
			
		||||
  };
 | 
			
		||||
 | 
			
		||||
  const submitNewRestrictedDomain = () => {
 | 
			
		||||
    const localDomainList = inputs.EmailDomainWhitelist;
 | 
			
		||||
    if (restrictedDomainInput !== '' && !localDomainList.includes(restrictedDomainInput)) {
 | 
			
		||||
      setRestrictedDomainInput('');
 | 
			
		||||
      setInputs({
 | 
			
		||||
        ...inputs,
 | 
			
		||||
        EmailDomainWhitelist: [...localDomainList, restrictedDomainInput],
 | 
			
		||||
      });
 | 
			
		||||
      setEmailDomainWhitelist([...EmailDomainWhitelist, {
 | 
			
		||||
        key: restrictedDomainInput,
 | 
			
		||||
        text: restrictedDomainInput,
 | 
			
		||||
        value: restrictedDomainInput,
 | 
			
		||||
      }]);
 | 
			
		||||
    }
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  return (
 | 
			
		||||
    <Grid columns={1}>
 | 
			
		||||
      <Grid.Column>
 | 
			
		||||
@@ -239,6 +283,54 @@ const SystemSetting = () => {
 | 
			
		||||
            />
 | 
			
		||||
          </Form.Group>
 | 
			
		||||
          <Divider />
 | 
			
		||||
          <Header as='h3'>
 | 
			
		||||
            配置邮箱域名白名单
 | 
			
		||||
            <Header.Subheader>用以防止恶意用户利用临时邮箱批量注册</Header.Subheader>
 | 
			
		||||
          </Header>
 | 
			
		||||
          <Form.Group widths={3}>
 | 
			
		||||
            <Form.Checkbox
 | 
			
		||||
              label='启用邮箱域名白名单'
 | 
			
		||||
              name='EmailDomainRestrictionEnabled'
 | 
			
		||||
              onChange={handleInputChange}
 | 
			
		||||
              checked={inputs.EmailDomainRestrictionEnabled === 'true'}
 | 
			
		||||
            />
 | 
			
		||||
          </Form.Group>
 | 
			
		||||
          <Form.Group widths={2}>
 | 
			
		||||
            <Form.Dropdown
 | 
			
		||||
              label='允许的邮箱域名'
 | 
			
		||||
              placeholder='允许的邮箱域名'
 | 
			
		||||
              name='EmailDomainWhitelist'
 | 
			
		||||
              required
 | 
			
		||||
              fluid
 | 
			
		||||
              multiple
 | 
			
		||||
              selection
 | 
			
		||||
              onChange={handleInputChange}
 | 
			
		||||
              value={inputs.EmailDomainWhitelist}
 | 
			
		||||
              autoComplete='new-password'
 | 
			
		||||
              options={EmailDomainWhitelist}
 | 
			
		||||
            />
 | 
			
		||||
            <Form.Input
 | 
			
		||||
              label='添加新的允许的邮箱域名'
 | 
			
		||||
              action={
 | 
			
		||||
                <Button type='button' onClick={() => {
 | 
			
		||||
                  submitNewRestrictedDomain();
 | 
			
		||||
                }}>填入</Button>
 | 
			
		||||
              }
 | 
			
		||||
              onKeyDown={(e) => {
 | 
			
		||||
                if (e.key === 'Enter') {
 | 
			
		||||
                  submitNewRestrictedDomain();
 | 
			
		||||
                }
 | 
			
		||||
              }}
 | 
			
		||||
              autoComplete='new-password'
 | 
			
		||||
              placeholder='输入新的允许的邮箱域名'
 | 
			
		||||
              value={restrictedDomainInput}
 | 
			
		||||
              onChange={(e, { value }) => {
 | 
			
		||||
                setRestrictedDomainInput(value);
 | 
			
		||||
              }}
 | 
			
		||||
            />
 | 
			
		||||
          </Form.Group>
 | 
			
		||||
          <Form.Button onClick={submitEmailDomainWhitelist}>保存邮箱域名白名单设置</Form.Button>
 | 
			
		||||
          <Divider />
 | 
			
		||||
          <Header as='h3'>
 | 
			
		||||
            配置 SMTP
 | 
			
		||||
            <Header.Subheader>用以支持系统的邮件发送</Header.Subheader>
 | 
			
		||||
@@ -284,7 +376,7 @@ const SystemSetting = () => {
 | 
			
		||||
              onChange={handleInputChange}
 | 
			
		||||
              type='password'
 | 
			
		||||
              autoComplete='new-password'
 | 
			
		||||
              value={inputs.SMTPToken}
 | 
			
		||||
              checked={inputs.RegisterEnabled === 'true'}
 | 
			
		||||
              placeholder='敏感信息不会发送到前端显示'
 | 
			
		||||
            />
 | 
			
		||||
          </Form.Group>
 | 
			
		||||
 
 | 
			
		||||
@@ -227,7 +227,7 @@ const UsersTable = () => {
 | 
			
		||||
                      content={user.email ? user.email : '未绑定邮箱地址'}
 | 
			
		||||
                      key={user.username}
 | 
			
		||||
                      header={user.display_name ? user.display_name : user.username}
 | 
			
		||||
                      trigger={<span>{renderText(user.username, 10)}</span>}
 | 
			
		||||
                      trigger={<span>{renderText(user.username, 15)}</span>}
 | 
			
		||||
                      hoverable
 | 
			
		||||
                    />
 | 
			
		||||
                  </Table.Cell>
 | 
			
		||||
 
 | 
			
		||||
@@ -4,6 +4,8 @@ export const CHANNEL_OPTIONS = [
 | 
			
		||||
  { key: 3, text: 'Azure OpenAI', value: 3, color: 'olive' },
 | 
			
		||||
  { key: 11, text: 'Google PaLM2', value: 11, color: 'orange' },
 | 
			
		||||
  { key: 15, text: '百度文心千帆', value: 15, color: 'blue' },
 | 
			
		||||
  { key: 17, text: '阿里通义千问', value: 17, color: 'orange' },
 | 
			
		||||
  { key: 18, text: '讯飞星火认知', value: 18, color: 'blue' },
 | 
			
		||||
  { key: 16, text: '智谱 ChatGLM', value: 16, color: 'violet' },
 | 
			
		||||
  { key: 8, text: '自定义渠道', value: 8, color: 'pink' },
 | 
			
		||||
  { key: 2, text: '代理:API2D', value: 2, color: 'blue' },
 | 
			
		||||
 
 | 
			
		||||
@@ -1,5 +1,5 @@
 | 
			
		||||
export const toastConstants = {
 | 
			
		||||
  SUCCESS_TIMEOUT: 500,
 | 
			
		||||
  SUCCESS_TIMEOUT: 1500,
 | 
			
		||||
  INFO_TIMEOUT: 3000,
 | 
			
		||||
  ERROR_TIMEOUT: 5000,
 | 
			
		||||
  WARNING_TIMEOUT: 10000,
 | 
			
		||||
 
 | 
			
		||||
@@ -35,6 +35,30 @@ const EditChannel = () => {
 | 
			
		||||
  const [customModel, setCustomModel] = useState('');
 | 
			
		||||
  const handleInputChange = (e, { name, value }) => {
 | 
			
		||||
    setInputs((inputs) => ({ ...inputs, [name]: value }));
 | 
			
		||||
    if (name === 'type' && inputs.models.length === 0) {
 | 
			
		||||
      let localModels = [];
 | 
			
		||||
      switch (value) {
 | 
			
		||||
        case 14:
 | 
			
		||||
          localModels = ['claude-instant-1', 'claude-2'];
 | 
			
		||||
          break;
 | 
			
		||||
        case 11:
 | 
			
		||||
          localModels = ['PaLM-2'];
 | 
			
		||||
          break;
 | 
			
		||||
        case 15:
 | 
			
		||||
          localModels = ['ERNIE-Bot', 'ERNIE-Bot-turbo', 'Embedding-V1'];
 | 
			
		||||
          break;
 | 
			
		||||
        case 17:
 | 
			
		||||
          localModels = ['qwen-v1', 'qwen-plus-v1'];
 | 
			
		||||
          break;
 | 
			
		||||
        case 16:
 | 
			
		||||
          localModels = ['chatglm_pro', 'chatglm_std', 'chatglm_lite'];
 | 
			
		||||
          break;
 | 
			
		||||
        case 18:
 | 
			
		||||
          localModels = ['SparkDesk'];
 | 
			
		||||
          break;
 | 
			
		||||
      }
 | 
			
		||||
      setInputs((inputs) => ({ ...inputs, models: localModels }));
 | 
			
		||||
    }
 | 
			
		||||
  };
 | 
			
		||||
 | 
			
		||||
  const loadChannel = async () => {
 | 
			
		||||
@@ -132,7 +156,10 @@ const EditChannel = () => {
 | 
			
		||||
      localInputs.base_url = localInputs.base_url.slice(0, localInputs.base_url.length - 1);
 | 
			
		||||
    }
 | 
			
		||||
    if (localInputs.type === 3 && localInputs.other === '') {
 | 
			
		||||
      localInputs.other = '2023-03-15-preview';
 | 
			
		||||
      localInputs.other = '2023-06-01-preview';
 | 
			
		||||
    }
 | 
			
		||||
    if (localInputs.model_mapping === '') {
 | 
			
		||||
      localInputs.model_mapping = '{}';
 | 
			
		||||
    }
 | 
			
		||||
    let res;
 | 
			
		||||
    localInputs.models = localInputs.models.join(',');
 | 
			
		||||
@@ -192,7 +219,7 @@ const EditChannel = () => {
 | 
			
		||||
                  <Form.Input
 | 
			
		||||
                    label='默认 API 版本'
 | 
			
		||||
                    name='other'
 | 
			
		||||
                    placeholder={'请输入默认 API 版本,例如:2023-03-15-preview,该配置可以被实际的请求查询参数所覆盖'}
 | 
			
		||||
                    placeholder={'请输入默认 API 版本,例如:2023-06-01-preview,该配置可以被实际的请求查询参数所覆盖'}
 | 
			
		||||
                    onChange={handleInputChange}
 | 
			
		||||
                    value={inputs.other}
 | 
			
		||||
                    autoComplete='new-password'
 | 
			
		||||
@@ -270,8 +297,8 @@ const EditChannel = () => {
 | 
			
		||||
            }}>清除所有模型</Button>
 | 
			
		||||
            <Input
 | 
			
		||||
              action={
 | 
			
		||||
                <Button type={'button'} onClick={()=>{
 | 
			
		||||
                  if (customModel.trim() === "") return;
 | 
			
		||||
                <Button type={'button'} onClick={() => {
 | 
			
		||||
                  if (customModel.trim() === '') return;
 | 
			
		||||
                  if (inputs.models.includes(customModel)) return;
 | 
			
		||||
                  let localModels = [...inputs.models];
 | 
			
		||||
                  localModels.push(customModel);
 | 
			
		||||
@@ -279,9 +306,9 @@ const EditChannel = () => {
 | 
			
		||||
                  localModelOptions.push({
 | 
			
		||||
                    key: customModel,
 | 
			
		||||
                    text: customModel,
 | 
			
		||||
                    value: customModel,
 | 
			
		||||
                    value: customModel
 | 
			
		||||
                  });
 | 
			
		||||
                  setModelOptions(modelOptions=>{
 | 
			
		||||
                  setModelOptions(modelOptions => {
 | 
			
		||||
                    return [...modelOptions, ...localModelOptions];
 | 
			
		||||
                  });
 | 
			
		||||
                  setCustomModel('');
 | 
			
		||||
@@ -297,7 +324,7 @@ const EditChannel = () => {
 | 
			
		||||
          </div>
 | 
			
		||||
          <Form.Field>
 | 
			
		||||
            <Form.TextArea
 | 
			
		||||
              label='模型映射'
 | 
			
		||||
              label='模型重定向'
 | 
			
		||||
              placeholder={`此项可选,用于修改请求体中的模型名称,为一个 JSON 字符串,键为请求中模型名称,值为要替换的模型名称,例如:\n${JSON.stringify(MODEL_MAPPING_EXAMPLE, null, 2)}`}
 | 
			
		||||
              name='model_mapping'
 | 
			
		||||
              onChange={handleInputChange}
 | 
			
		||||
@@ -323,7 +350,7 @@ const EditChannel = () => {
 | 
			
		||||
                label='密钥'
 | 
			
		||||
                name='key'
 | 
			
		||||
                required
 | 
			
		||||
                placeholder={inputs.type === 15 ? "请输入 access token,当前版本暂不支持自动刷新,请每 30 天更新一次" : '请输入渠道对应的鉴权密钥'}
 | 
			
		||||
                placeholder={inputs.type === 15 ? '请输入 access token,当前版本暂不支持自动刷新,请每 30 天更新一次' : (inputs.type === 18 ? '按照如下格式输入:APPID|APISecret|APIKey' : '请输入渠道对应的鉴权密钥')}
 | 
			
		||||
                onChange={handleInputChange}
 | 
			
		||||
                value={inputs.key}
 | 
			
		||||
                autoComplete='new-password'
 | 
			
		||||
@@ -354,7 +381,7 @@ const EditChannel = () => {
 | 
			
		||||
              </Form.Field>
 | 
			
		||||
            )
 | 
			
		||||
          }
 | 
			
		||||
          <Button type={isEdit ? "button" : "submit"} positive onClick={submit}>提交</Button>
 | 
			
		||||
          <Button type={isEdit ? 'button' : 'submit'} positive onClick={submit}>提交</Button>
 | 
			
		||||
        </Form>
 | 
			
		||||
      </Segment>
 | 
			
		||||
    </>
 | 
			
		||||
 
 | 
			
		||||
@@ -83,7 +83,7 @@ const EditToken = () => {
 | 
			
		||||
      if (isEdit) {
 | 
			
		||||
        showSuccess('令牌更新成功!');
 | 
			
		||||
      } else {
 | 
			
		||||
        showSuccess('令牌创建成功!');
 | 
			
		||||
        showSuccess('令牌创建成功,请在列表页面点击复制获取令牌!');
 | 
			
		||||
        setInputs(originInputs);
 | 
			
		||||
      }
 | 
			
		||||
    } else {
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user