mirror of
				https://github.com/songquanpeng/one-api.git
				synced 2025-11-04 07:43:41 +08:00 
			
		
		
		
	Compare commits
	
		
			12 Commits
		
	
	
		
			v0.3.4
			...
			v0.4.1-alp
		
	
	| Author | SHA1 | Date | |
|---|---|---|---|
| 
						 | 
					145bb14cb2 | ||
| 
						 | 
					8901f03864 | ||
| 
						 | 
					813bf0bd66 | ||
| 
						 | 
					45e9fd66e7 | ||
| 
						 | 
					e0d0674f81 | ||
| 
						 | 
					4b6adaec0b | ||
| 
						 | 
					9301b3fed3 | ||
| 
						 | 
					c6edb78ac9 | ||
| 
						 | 
					521ede2469 | ||
| 
						 | 
					2c53424db8 | ||
| 
						 | 
					2ad22e1425 | ||
| 
						 | 
					502515bbbd | 
							
								
								
									
										25
									
								
								README.md
									
									
									
									
									
								
							
							
						
						
									
										25
									
								
								README.md
									
									
									
									
									
								
							@@ -40,14 +40,17 @@ _✨ All in one 的 OpenAI 接口,整合各种 API 访问方式,开箱即用
 | 
			
		||||
  <a href="https://openai.justsong.cn/">在线演示</a>
 | 
			
		||||
  ·
 | 
			
		||||
  <a href="https://github.com/songquanpeng/one-api#常见问题">常见问题</a>
 | 
			
		||||
  ·
 | 
			
		||||
  <a href="https://iamazing.cn/page/reward">赞赏支持</a>
 | 
			
		||||
</p>
 | 
			
		||||
 | 
			
		||||
> **Warning**:从 `v0.2` 版本升级到 `v0.3` 版本需要手动迁移数据库,请手动执行[数据库迁移脚本](./bin/migration_v0.2-v0.3.sql)。
 | 
			
		||||
> **Warning**:使用 Docker 拉取的最新镜像可能是 `alpha` 版本,如果追求稳定性请手动指定版本。
 | 
			
		||||
 | 
			
		||||
> **Warning**:从 `v0.3` 版本升级到 `v0.4` 版本需要手动迁移数据库,请手动执行[数据库迁移脚本](./bin/migration_v0.3-v0.4.sql)。
 | 
			
		||||
 | 
			
		||||
## 功能
 | 
			
		||||
1. 支持多种 API 访问渠道,欢迎 PR 或提 issue 添加更多渠道:
 | 
			
		||||
   + [x] OpenAI 官方通道
 | 
			
		||||
   + [x] OpenAI 官方通道(支持配置代理)
 | 
			
		||||
   + [x] **Azure OpenAI API**
 | 
			
		||||
   + [x] [API2D](https://api2d.com/r/197971)
 | 
			
		||||
   + [x] [OhMyGPT](https://aigptx.top?aff=uFpUl2Kf)
 | 
			
		||||
@@ -56,23 +59,25 @@ _✨ All in one 的 OpenAI 接口,整合各种 API 访问方式,开箱即用
 | 
			
		||||
   + [x] [OpenAI Max](https://openaimax.com)
 | 
			
		||||
   + [x] [OpenAI-SB](https://openai-sb.com)
 | 
			
		||||
   + [x] [CloseAI](https://console.openai-asia.com/r/2412)
 | 
			
		||||
   + [x] 自定义渠道:例如使用自行搭建的 OpenAI 代理
 | 
			
		||||
   + [x] 自定义渠道:例如各种未收录的第三方代理服务
 | 
			
		||||
2. 支持通过**负载均衡**的方式访问多个渠道。
 | 
			
		||||
3. 支持 **stream 模式**,可以通过流式传输实现打字机效果。
 | 
			
		||||
4. 支持**多机部署**,[详见此处](#多机部署)。
 | 
			
		||||
5. 支持**令牌管理**,设置令牌的过期时间和使用次数。
 | 
			
		||||
6. 支持**兑换码管理**,支持批量生成和导出兑换码,可使用兑换码为账户进行充值。
 | 
			
		||||
7. 支持**通道管理**,批量创建通道。
 | 
			
		||||
8. 支持发布公告,设置充值链接,设置新用户初始额度。
 | 
			
		||||
9. 支持丰富的**自定义**设置,
 | 
			
		||||
   1. 支持自定义系统名称,logo 以及页脚。
 | 
			
		||||
   2. 支持自定义首页和关于页面,可以选择使用 HTML & Markdown 代码进行自定义,或者使用一个单独的网页通过 iframe 嵌入。
 | 
			
		||||
10. 支持通过系统访问令牌访问管理 API。
 | 
			
		||||
11. 支持用户管理,支持**多种用户登录注册方式**:
 | 
			
		||||
8. 支持**用户分组**以及**渠道分组**。
 | 
			
		||||
9. 支持渠道**设置模型列表**。
 | 
			
		||||
10. 支持发布公告,设置充值链接,设置新用户初始额度。
 | 
			
		||||
11. 支持丰富的**自定义**设置,
 | 
			
		||||
    1. 支持自定义系统名称,logo 以及页脚。
 | 
			
		||||
    2. 支持自定义首页和关于页面,可以选择使用 HTML & Markdown 代码进行自定义,或者使用一个单独的网页通过 iframe 嵌入。
 | 
			
		||||
12. 支持通过系统访问令牌访问管理 API。
 | 
			
		||||
13. 支持用户管理,支持**多种用户登录注册方式**:
 | 
			
		||||
    + 邮箱登录注册以及通过邮箱进行密码重置。
 | 
			
		||||
    + [GitHub 开放授权](https://github.com/settings/applications/new)。
 | 
			
		||||
    + 微信公众号授权(需要额外部署 [WeChat Server](https://github.com/songquanpeng/wechat-server))。
 | 
			
		||||
12. 未来其他大模型开放 API 后,将第一时间支持,并将其封装成同样的 API 访问方式。
 | 
			
		||||
14. 未来其他大模型开放 API 后,将第一时间支持,并将其封装成同样的 API 访问方式。
 | 
			
		||||
 | 
			
		||||
## 部署
 | 
			
		||||
### 基于 Docker 进行部署
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										17
									
								
								bin/migration_v0.3-v0.4.sql
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										17
									
								
								bin/migration_v0.3-v0.4.sql
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,17 @@
 | 
			
		||||
INSERT INTO abilities (`group`, model, channel_id, enabled)
 | 
			
		||||
SELECT c.`group`, m.model, c.id, 1
 | 
			
		||||
FROM channels c
 | 
			
		||||
CROSS JOIN (
 | 
			
		||||
    SELECT 'gpt-3.5-turbo' AS model UNION ALL
 | 
			
		||||
    SELECT 'gpt-3.5-turbo-0301' AS model UNION ALL
 | 
			
		||||
    SELECT 'gpt-4' AS model UNION ALL
 | 
			
		||||
    SELECT 'gpt-4-0314' AS model
 | 
			
		||||
) AS m
 | 
			
		||||
WHERE c.status = 1
 | 
			
		||||
  AND NOT EXISTS (
 | 
			
		||||
    SELECT 1
 | 
			
		||||
    FROM abilities a
 | 
			
		||||
    WHERE a.`group` = c.`group`
 | 
			
		||||
      AND a.model = m.model
 | 
			
		||||
      AND a.channel_id = c.id
 | 
			
		||||
);
 | 
			
		||||
@@ -25,6 +25,7 @@ var OptionMap map[string]string
 | 
			
		||||
var OptionMapRWMutex sync.RWMutex
 | 
			
		||||
 | 
			
		||||
var ItemsPerPage = 10
 | 
			
		||||
var MaxRecentItems = 100
 | 
			
		||||
 | 
			
		||||
var PasswordLoginEnabled = true
 | 
			
		||||
var PasswordRegisterEnabled = true
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										26
									
								
								common/gin.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										26
									
								
								common/gin.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,26 @@
 | 
			
		||||
package common
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"bytes"
 | 
			
		||||
	"encoding/json"
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
	"io"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func UnmarshalBodyReusable(c *gin.Context, v any) error {
 | 
			
		||||
	requestBody, err := io.ReadAll(c.Request.Body)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
	err = c.Request.Body.Close()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
	err = json.Unmarshal(requestBody, &v)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
	// Reset request body
 | 
			
		||||
	c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
@@ -10,7 +10,7 @@ var ModelRatio = map[string]float64{
 | 
			
		||||
	"gpt-4-0314":              15,
 | 
			
		||||
	"gpt-4-32k":               30,
 | 
			
		||||
	"gpt-4-32k-0314":          30,
 | 
			
		||||
	"gpt-3.5-turbo":           1,
 | 
			
		||||
	"gpt-3.5-turbo":           1, // $0.002 / 1K tokens
 | 
			
		||||
	"gpt-3.5-turbo-0301":      1,
 | 
			
		||||
	"text-ada-001":            0.2,
 | 
			
		||||
	"text-babbage-001":        0.25,
 | 
			
		||||
 
 | 
			
		||||
@@ -41,7 +41,9 @@ func updateChannelBalance(channel *model.Channel) (float64, error) {
 | 
			
		||||
	baseURL := common.ChannelBaseURLs[channel.Type]
 | 
			
		||||
	switch channel.Type {
 | 
			
		||||
	case common.ChannelTypeOpenAI:
 | 
			
		||||
		// do nothing
 | 
			
		||||
		if channel.BaseURL != "" {
 | 
			
		||||
			baseURL = channel.BaseURL
 | 
			
		||||
		}
 | 
			
		||||
	case common.ChannelTypeAzure:
 | 
			
		||||
		return 0, errors.New("尚未实现")
 | 
			
		||||
	case common.ChannelTypeCustom:
 | 
			
		||||
 
 | 
			
		||||
@@ -27,6 +27,8 @@ func testChannel(channel *model.Channel, request *ChatRequest) error {
 | 
			
		||||
	} else {
 | 
			
		||||
		if channel.Type == common.ChannelTypeCustom {
 | 
			
		||||
			requestURL = channel.BaseURL
 | 
			
		||||
		} else if channel.Type == common.ChannelTypeOpenAI && channel.BaseURL != "" {
 | 
			
		||||
			requestURL = channel.BaseURL
 | 
			
		||||
		}
 | 
			
		||||
		requestURL += "/v1/chat/completions"
 | 
			
		||||
	}
 | 
			
		||||
@@ -56,7 +58,7 @@ func testChannel(channel *model.Channel, request *ChatRequest) error {
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
	if response.Error.Message != "" || response.Error.Code != "" {
 | 
			
		||||
	if response.Usage.CompletionTokens == 0 {
 | 
			
		||||
		return errors.New(fmt.Sprintf("type %s, code %s, message %s", response.Error.Type, response.Error.Code, response.Error.Message))
 | 
			
		||||
	}
 | 
			
		||||
	return nil
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										86
									
								
								controller/log.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										86
									
								
								controller/log.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,86 @@
 | 
			
		||||
package controller
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
	"one-api/common"
 | 
			
		||||
	"one-api/model"
 | 
			
		||||
	"strconv"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func GetAllLogs(c *gin.Context) {
 | 
			
		||||
	p, _ := strconv.Atoi(c.Query("p"))
 | 
			
		||||
	if p < 0 {
 | 
			
		||||
		p = 0
 | 
			
		||||
	}
 | 
			
		||||
	logType, _ := strconv.Atoi(c.Query("type"))
 | 
			
		||||
	logs, err := model.GetAllLogs(logType, p*common.ItemsPerPage, common.ItemsPerPage)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		c.JSON(200, gin.H{
 | 
			
		||||
			"success": false,
 | 
			
		||||
			"message": err.Error(),
 | 
			
		||||
		})
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	c.JSON(200, gin.H{
 | 
			
		||||
		"success": true,
 | 
			
		||||
		"message": "",
 | 
			
		||||
		"data":    logs,
 | 
			
		||||
	})
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func GetUserLogs(c *gin.Context) {
 | 
			
		||||
	p, _ := strconv.Atoi(c.Query("p"))
 | 
			
		||||
	if p < 0 {
 | 
			
		||||
		p = 0
 | 
			
		||||
	}
 | 
			
		||||
	userId := c.GetInt("id")
 | 
			
		||||
	logType, _ := strconv.Atoi(c.Query("type"))
 | 
			
		||||
	logs, err := model.GetUserLogs(userId, logType, p*common.ItemsPerPage, common.ItemsPerPage)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		c.JSON(200, gin.H{
 | 
			
		||||
			"success": false,
 | 
			
		||||
			"message": err.Error(),
 | 
			
		||||
		})
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	c.JSON(200, gin.H{
 | 
			
		||||
		"success": true,
 | 
			
		||||
		"message": "",
 | 
			
		||||
		"data":    logs,
 | 
			
		||||
	})
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func SearchAllLogs(c *gin.Context) {
 | 
			
		||||
	keyword := c.Query("keyword")
 | 
			
		||||
	logs, err := model.SearchAllLogs(keyword)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		c.JSON(200, gin.H{
 | 
			
		||||
			"success": false,
 | 
			
		||||
			"message": err.Error(),
 | 
			
		||||
		})
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	c.JSON(200, gin.H{
 | 
			
		||||
		"success": true,
 | 
			
		||||
		"message": "",
 | 
			
		||||
		"data":    logs,
 | 
			
		||||
	})
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func SearchUserLogs(c *gin.Context) {
 | 
			
		||||
	keyword := c.Query("keyword")
 | 
			
		||||
	userId := c.GetInt("id")
 | 
			
		||||
	logs, err := model.SearchUserLogs(userId, keyword)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		c.JSON(200, gin.H{
 | 
			
		||||
			"success": false,
 | 
			
		||||
			"message": err.Error(),
 | 
			
		||||
		})
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	c.JSON(200, gin.H{
 | 
			
		||||
		"success": true,
 | 
			
		||||
		"message": "",
 | 
			
		||||
		"data":    logs,
 | 
			
		||||
	})
 | 
			
		||||
}
 | 
			
		||||
@@ -116,6 +116,51 @@ func init() {
 | 
			
		||||
			Root:       "text-embedding-ada-002",
 | 
			
		||||
			Parent:     nil,
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			Id:         "text-davinci-003",
 | 
			
		||||
			Object:     "model",
 | 
			
		||||
			Created:    1677649963,
 | 
			
		||||
			OwnedBy:    "openai",
 | 
			
		||||
			Permission: permission,
 | 
			
		||||
			Root:       "text-davinci-003",
 | 
			
		||||
			Parent:     nil,
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			Id:         "text-davinci-002",
 | 
			
		||||
			Object:     "model",
 | 
			
		||||
			Created:    1677649963,
 | 
			
		||||
			OwnedBy:    "openai",
 | 
			
		||||
			Permission: permission,
 | 
			
		||||
			Root:       "text-davinci-002",
 | 
			
		||||
			Parent:     nil,
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			Id:         "text-curie-001",
 | 
			
		||||
			Object:     "model",
 | 
			
		||||
			Created:    1677649963,
 | 
			
		||||
			OwnedBy:    "openai",
 | 
			
		||||
			Permission: permission,
 | 
			
		||||
			Root:       "text-curie-001",
 | 
			
		||||
			Parent:     nil,
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			Id:         "text-babbage-001",
 | 
			
		||||
			Object:     "model",
 | 
			
		||||
			Created:    1677649963,
 | 
			
		||||
			OwnedBy:    "openai",
 | 
			
		||||
			Permission: permission,
 | 
			
		||||
			Root:       "text-babbage-001",
 | 
			
		||||
			Parent:     nil,
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			Id:         "text-ada-001",
 | 
			
		||||
			Object:     "model",
 | 
			
		||||
			Created:    1677649963,
 | 
			
		||||
			OwnedBy:    "openai",
 | 
			
		||||
			Permission: permission,
 | 
			
		||||
			Root:       "text-ada-001",
 | 
			
		||||
			Parent:     nil,
 | 
			
		||||
		},
 | 
			
		||||
	}
 | 
			
		||||
	openAIModelsMap = make(map[string]OpenAIModels)
 | 
			
		||||
	for _, model := range openAIModels {
 | 
			
		||||
 
 | 
			
		||||
@@ -19,6 +19,13 @@ type Message struct {
 | 
			
		||||
	Name    *string `json:"name,omitempty"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
const (
 | 
			
		||||
	RelayModeUnknown = iota
 | 
			
		||||
	RelayModeChatCompletions
 | 
			
		||||
	RelayModeCompletions
 | 
			
		||||
	RelayModeEmbeddings
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// https://platform.openai.com/docs/api-reference/chat
 | 
			
		||||
 | 
			
		||||
type GeneralOpenAIRequest struct {
 | 
			
		||||
@@ -69,7 +76,7 @@ type TextResponse struct {
 | 
			
		||||
	Error OpenAIError `json:"error"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type StreamResponse struct {
 | 
			
		||||
type ChatCompletionsStreamResponse struct {
 | 
			
		||||
	Choices []struct {
 | 
			
		||||
		Delta struct {
 | 
			
		||||
			Content string `json:"content"`
 | 
			
		||||
@@ -78,8 +85,23 @@ type StreamResponse struct {
 | 
			
		||||
	} `json:"choices"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type CompletionsStreamResponse struct {
 | 
			
		||||
	Choices []struct {
 | 
			
		||||
		Text         string `json:"text"`
 | 
			
		||||
		FinishReason string `json:"finish_reason"`
 | 
			
		||||
	} `json:"choices"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func Relay(c *gin.Context) {
 | 
			
		||||
	err := relayHelper(c)
 | 
			
		||||
	relayMode := RelayModeUnknown
 | 
			
		||||
	if strings.HasPrefix(c.Request.URL.Path, "/v1/chat/completions") {
 | 
			
		||||
		relayMode = RelayModeChatCompletions
 | 
			
		||||
	} else if strings.HasPrefix(c.Request.URL.Path, "/v1/completions") {
 | 
			
		||||
		relayMode = RelayModeCompletions
 | 
			
		||||
	} else if strings.HasPrefix(c.Request.URL.Path, "/v1/embeddings") {
 | 
			
		||||
		relayMode = RelayModeEmbeddings
 | 
			
		||||
	}
 | 
			
		||||
	err := relayHelper(c, relayMode)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		if err.StatusCode == http.StatusTooManyRequests {
 | 
			
		||||
			err.OpenAIError.Message = "负载已满,请稍后再试,或升级账户以提升服务质量。"
 | 
			
		||||
@@ -110,31 +132,25 @@ func errorWrapper(err error, code string, statusCode int) *OpenAIErrorWithStatus
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func relayHelper(c *gin.Context) *OpenAIErrorWithStatusCode {
 | 
			
		||||
func relayHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 | 
			
		||||
	channelType := c.GetInt("channel")
 | 
			
		||||
	tokenId := c.GetInt("token_id")
 | 
			
		||||
	consumeQuota := c.GetBool("consume_quota")
 | 
			
		||||
	var textRequest GeneralOpenAIRequest
 | 
			
		||||
	if consumeQuota || channelType == common.ChannelTypeAzure || channelType == common.ChannelTypePaLM {
 | 
			
		||||
		requestBody, err := io.ReadAll(c.Request.Body)
 | 
			
		||||
		err := common.UnmarshalBodyReusable(c, &textRequest)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return errorWrapper(err, "read_request_body_failed", http.StatusBadRequest)
 | 
			
		||||
			return errorWrapper(err, "bind_request_body_failed", http.StatusBadRequest)
 | 
			
		||||
		}
 | 
			
		||||
		err = c.Request.Body.Close()
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return errorWrapper(err, "close_request_body_failed", http.StatusBadRequest)
 | 
			
		||||
		}
 | 
			
		||||
		err = json.Unmarshal(requestBody, &textRequest)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return errorWrapper(err, "unmarshal_request_body_failed", http.StatusBadRequest)
 | 
			
		||||
		}
 | 
			
		||||
		// Reset request body
 | 
			
		||||
		c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
 | 
			
		||||
	}
 | 
			
		||||
	baseURL := common.ChannelBaseURLs[channelType]
 | 
			
		||||
	requestURL := c.Request.URL.String()
 | 
			
		||||
	if channelType == common.ChannelTypeCustom {
 | 
			
		||||
		baseURL = c.GetString("base_url")
 | 
			
		||||
	} else if channelType == common.ChannelTypeOpenAI {
 | 
			
		||||
		if c.GetString("base_url") != "" {
 | 
			
		||||
			baseURL = c.GetString("base_url")
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL)
 | 
			
		||||
	if channelType == common.ChannelTypeAzure {
 | 
			
		||||
@@ -158,8 +174,13 @@ func relayHelper(c *gin.Context) *OpenAIErrorWithStatusCode {
 | 
			
		||||
		err := relayPaLM(textRequest, c)
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	promptTokens := countTokenMessages(textRequest.Messages, textRequest.Model)
 | 
			
		||||
	var promptTokens int
 | 
			
		||||
	switch relayMode {
 | 
			
		||||
	case RelayModeChatCompletions:
 | 
			
		||||
		promptTokens = countTokenMessages(textRequest.Messages, textRequest.Model)
 | 
			
		||||
	case RelayModeCompletions:
 | 
			
		||||
		promptTokens = countTokenText(textRequest.Prompt, textRequest.Model)
 | 
			
		||||
	}
 | 
			
		||||
	preConsumedTokens := common.PreConsumedQuota
 | 
			
		||||
	if textRequest.MaxTokens != 0 {
 | 
			
		||||
		preConsumedTokens = promptTokens + textRequest.MaxTokens
 | 
			
		||||
@@ -255,14 +276,27 @@ func relayHelper(c *gin.Context) *OpenAIErrorWithStatusCode {
 | 
			
		||||
				dataChan <- data
 | 
			
		||||
				data = data[6:]
 | 
			
		||||
				if !strings.HasPrefix(data, "[DONE]") {
 | 
			
		||||
					var streamResponse StreamResponse
 | 
			
		||||
					err = json.Unmarshal([]byte(data), &streamResponse)
 | 
			
		||||
					if err != nil {
 | 
			
		||||
						common.SysError("Error unmarshalling stream response: " + err.Error())
 | 
			
		||||
						return
 | 
			
		||||
					}
 | 
			
		||||
					for _, choice := range streamResponse.Choices {
 | 
			
		||||
						streamResponseText += choice.Delta.Content
 | 
			
		||||
					switch relayMode {
 | 
			
		||||
					case RelayModeChatCompletions:
 | 
			
		||||
						var streamResponse ChatCompletionsStreamResponse
 | 
			
		||||
						err = json.Unmarshal([]byte(data), &streamResponse)
 | 
			
		||||
						if err != nil {
 | 
			
		||||
							common.SysError("Error unmarshalling stream response: " + err.Error())
 | 
			
		||||
							return
 | 
			
		||||
						}
 | 
			
		||||
						for _, choice := range streamResponse.Choices {
 | 
			
		||||
							streamResponseText += choice.Delta.Content
 | 
			
		||||
						}
 | 
			
		||||
					case RelayModeCompletions:
 | 
			
		||||
						var streamResponse CompletionsStreamResponse
 | 
			
		||||
						err = json.Unmarshal([]byte(data), &streamResponse)
 | 
			
		||||
						if err != nil {
 | 
			
		||||
							common.SysError("Error unmarshalling stream response: " + err.Error())
 | 
			
		||||
							return
 | 
			
		||||
						}
 | 
			
		||||
						for _, choice := range streamResponse.Choices {
 | 
			
		||||
							streamResponseText += choice.Text
 | 
			
		||||
						}
 | 
			
		||||
					}
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
 
 | 
			
		||||
@@ -228,7 +228,7 @@ func GetUser(c *gin.Context) {
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	myRole := c.GetInt("role")
 | 
			
		||||
	if myRole <= user.Role {
 | 
			
		||||
	if myRole <= user.Role && myRole != common.RoleRootUser {
 | 
			
		||||
		c.JSON(http.StatusOK, gin.H{
 | 
			
		||||
			"success": false,
 | 
			
		||||
			"message": "无权获取同级或更高等级用户的信息",
 | 
			
		||||
@@ -326,14 +326,14 @@ func UpdateUser(c *gin.Context) {
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	myRole := c.GetInt("role")
 | 
			
		||||
	if myRole <= originUser.Role {
 | 
			
		||||
	if myRole <= originUser.Role && myRole != common.RoleRootUser {
 | 
			
		||||
		c.JSON(http.StatusOK, gin.H{
 | 
			
		||||
			"success": false,
 | 
			
		||||
			"message": "无权更新同权限等级或更高权限等级的用户信息",
 | 
			
		||||
		})
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	if myRole <= updatedUser.Role {
 | 
			
		||||
	if myRole <= updatedUser.Role && myRole != common.RoleRootUser {
 | 
			
		||||
		c.JSON(http.StatusOK, gin.H{
 | 
			
		||||
			"success": false,
 | 
			
		||||
			"message": "无权将其他用户权限等级提升到大于等于自己的权限等级",
 | 
			
		||||
 
 | 
			
		||||
@@ -9,6 +9,10 @@ import (
 | 
			
		||||
	"strconv"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type ModelRequest struct {
 | 
			
		||||
	Model string `json:"model"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func Distribute() func(c *gin.Context) {
 | 
			
		||||
	return func(c *gin.Context) {
 | 
			
		||||
		var channel *model.Channel
 | 
			
		||||
@@ -48,8 +52,21 @@ func Distribute() func(c *gin.Context) {
 | 
			
		||||
			}
 | 
			
		||||
		} else {
 | 
			
		||||
			// Select a channel for the user
 | 
			
		||||
			var err error
 | 
			
		||||
			channel, err = model.GetRandomChannel()
 | 
			
		||||
			var modelRequest ModelRequest
 | 
			
		||||
			err := common.UnmarshalBodyReusable(c, &modelRequest)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				c.JSON(200, gin.H{
 | 
			
		||||
					"error": gin.H{
 | 
			
		||||
						"message": "无效的请求",
 | 
			
		||||
						"type":    "one_api_error",
 | 
			
		||||
					},
 | 
			
		||||
				})
 | 
			
		||||
				c.Abort()
 | 
			
		||||
				return
 | 
			
		||||
			}
 | 
			
		||||
			userId := c.GetInt("id")
 | 
			
		||||
			userGroup, _ := model.GetUserGroup(userId)
 | 
			
		||||
			channel, err = model.GetRandomSatisfiedChannel(userGroup, modelRequest.Model)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				c.JSON(200, gin.H{
 | 
			
		||||
					"error": gin.H{
 | 
			
		||||
@@ -65,11 +82,9 @@ func Distribute() func(c *gin.Context) {
 | 
			
		||||
		c.Set("channel_id", channel.Id)
 | 
			
		||||
		c.Set("channel_name", channel.Name)
 | 
			
		||||
		c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key))
 | 
			
		||||
		if channel.Type == common.ChannelTypeCustom || channel.Type == common.ChannelTypeAzure {
 | 
			
		||||
			c.Set("base_url", channel.BaseURL)
 | 
			
		||||
			if channel.Type == common.ChannelTypeAzure {
 | 
			
		||||
				c.Set("api_version", channel.Other)
 | 
			
		||||
			}
 | 
			
		||||
		c.Set("base_url", channel.BaseURL)
 | 
			
		||||
		if channel.Type == common.ChannelTypeAzure {
 | 
			
		||||
			c.Set("api_version", channel.Other)
 | 
			
		||||
		}
 | 
			
		||||
		c.Next()
 | 
			
		||||
	}
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										69
									
								
								model/ability.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										69
									
								
								model/ability.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,69 @@
 | 
			
		||||
package model
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"one-api/common"
 | 
			
		||||
	"strings"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type Ability struct {
 | 
			
		||||
	Group     string `json:"group" gorm:"type:varchar(32);primaryKey;autoIncrement:false"`
 | 
			
		||||
	Model     string `json:"model" gorm:"primaryKey;autoIncrement:false"`
 | 
			
		||||
	ChannelId int    `json:"channel_id" gorm:"primaryKey;autoIncrement:false;index"`
 | 
			
		||||
	Enabled   bool   `json:"enabled"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func GetRandomSatisfiedChannel(group string, model string) (*Channel, error) {
 | 
			
		||||
	ability := Ability{}
 | 
			
		||||
	var err error = nil
 | 
			
		||||
	if common.UsingSQLite {
 | 
			
		||||
		err = DB.Where("`group` = ? and model = ? and enabled = 1", group, model).Order("RANDOM()").Limit(1).First(&ability).Error
 | 
			
		||||
	} else {
 | 
			
		||||
		err = DB.Where("`group` = ? and model = ? and enabled = 1", group, model).Order("RAND()").Limit(1).First(&ability).Error
 | 
			
		||||
	}
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
	channel := Channel{}
 | 
			
		||||
	err = DB.First(&channel, "id = ?", ability.ChannelId).Error
 | 
			
		||||
	return &channel, err
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (channel *Channel) AddAbilities() error {
 | 
			
		||||
	models_ := strings.Split(channel.Models, ",")
 | 
			
		||||
	abilities := make([]Ability, 0, len(models_))
 | 
			
		||||
	for _, model := range models_ {
 | 
			
		||||
		ability := Ability{
 | 
			
		||||
			Group:     channel.Group,
 | 
			
		||||
			Model:     model,
 | 
			
		||||
			ChannelId: channel.Id,
 | 
			
		||||
			Enabled:   channel.Status == common.ChannelStatusEnabled,
 | 
			
		||||
		}
 | 
			
		||||
		abilities = append(abilities, ability)
 | 
			
		||||
	}
 | 
			
		||||
	return DB.Create(&abilities).Error
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (channel *Channel) DeleteAbilities() error {
 | 
			
		||||
	return DB.Where("channel_id = ?", channel.Id).Delete(&Ability{}).Error
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// UpdateAbilities updates abilities of this channel.
 | 
			
		||||
// Make sure the channel is completed before calling this function.
 | 
			
		||||
func (channel *Channel) UpdateAbilities() error {
 | 
			
		||||
	// A quick and dirty way to update abilities
 | 
			
		||||
	// First delete all abilities of this channel
 | 
			
		||||
	err := channel.DeleteAbilities()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
	// Then add new abilities
 | 
			
		||||
	err = channel.AddAbilities()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func UpdateAbilityStatus(channelId int, status bool) error {
 | 
			
		||||
	return DB.Model(&Ability{}).Where("channel_id = ?", channelId).Select("enabled").Update("enabled", status).Error
 | 
			
		||||
}
 | 
			
		||||
@@ -1,7 +1,6 @@
 | 
			
		||||
package model
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	_ "gorm.io/driver/sqlite"
 | 
			
		||||
	"one-api/common"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
@@ -19,6 +18,8 @@ type Channel struct {
 | 
			
		||||
	Other              string  `json:"other"`
 | 
			
		||||
	Balance            float64 `json:"balance"` // in USD
 | 
			
		||||
	BalanceUpdatedTime int64   `json:"balance_updated_time" gorm:"bigint"`
 | 
			
		||||
	Models             string  `json:"models"`
 | 
			
		||||
	Group              string  `json:"group" gorm:"type:varchar(32);default:'default'"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func GetAllChannels(startIdx int, num int, selectAll bool) ([]*Channel, error) {
 | 
			
		||||
@@ -49,13 +50,12 @@ func GetChannelById(id int, selectAll bool) (*Channel, error) {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func GetRandomChannel() (*Channel, error) {
 | 
			
		||||
	// TODO: consider weight
 | 
			
		||||
	channel := Channel{}
 | 
			
		||||
	var err error = nil
 | 
			
		||||
	if common.UsingSQLite {
 | 
			
		||||
		err = DB.Where("status = ?", common.ChannelStatusEnabled).Order("RANDOM()").Limit(1).First(&channel).Error
 | 
			
		||||
		err = DB.Where("status = ? and `group` = ?", common.ChannelStatusEnabled, "default").Order("RANDOM()").Limit(1).First(&channel).Error
 | 
			
		||||
	} else {
 | 
			
		||||
		err = DB.Where("status = ?", common.ChannelStatusEnabled).Order("RAND()").Limit(1).First(&channel).Error
 | 
			
		||||
		err = DB.Where("status = ? and `group` = ?", common.ChannelStatusEnabled, "default").Order("RAND()").Limit(1).First(&channel).Error
 | 
			
		||||
	}
 | 
			
		||||
	return &channel, err
 | 
			
		||||
}
 | 
			
		||||
@@ -63,18 +63,36 @@ func GetRandomChannel() (*Channel, error) {
 | 
			
		||||
func BatchInsertChannels(channels []Channel) error {
 | 
			
		||||
	var err error
 | 
			
		||||
	err = DB.Create(&channels).Error
 | 
			
		||||
	return err
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
	for _, channel_ := range channels {
 | 
			
		||||
		err = channel_.AddAbilities()
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return err
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (channel *Channel) Insert() error {
 | 
			
		||||
	var err error
 | 
			
		||||
	err = DB.Create(channel).Error
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
	err = channel.AddAbilities()
 | 
			
		||||
	return err
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (channel *Channel) Update() error {
 | 
			
		||||
	var err error
 | 
			
		||||
	err = DB.Model(channel).Updates(channel).Error
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
	DB.Model(channel).First(channel, "id = ?", channel.Id)
 | 
			
		||||
	err = channel.UpdateAbilities()
 | 
			
		||||
	return err
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@@ -101,11 +119,19 @@ func (channel *Channel) UpdateBalance(balance float64) {
 | 
			
		||||
func (channel *Channel) Delete() error {
 | 
			
		||||
	var err error
 | 
			
		||||
	err = DB.Delete(channel).Error
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
	err = channel.DeleteAbilities()
 | 
			
		||||
	return err
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func UpdateChannelStatusById(id int, status int) {
 | 
			
		||||
	err := DB.Model(&Channel{}).Where("id = ?", id).Update("status", status).Error
 | 
			
		||||
	err := UpdateAbilityStatus(id, status == common.ChannelStatusEnabled)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		common.SysError("failed to update ability status: " + err.Error())
 | 
			
		||||
	}
 | 
			
		||||
	err = DB.Model(&Channel{}).Where("id = ?", id).Update("status", status).Error
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		common.SysError("failed to update channel status: " + err.Error())
 | 
			
		||||
	}
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										44
									
								
								model/log.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										44
									
								
								model/log.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,44 @@
 | 
			
		||||
package model
 | 
			
		||||
 | 
			
		||||
import "one-api/common"
 | 
			
		||||
 | 
			
		||||
type Log struct {
 | 
			
		||||
	Id        int    `json:"id"`
 | 
			
		||||
	UserId    int    `json:"user_id" gorm:"index"`
 | 
			
		||||
	CreatedAt int64  `json:"created_at" gorm:"bigint"`
 | 
			
		||||
	Type      int    `json:"type" gorm:"index"`
 | 
			
		||||
	Content   string `json:"content"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func RecordLog(userId int, logType int, content string) {
 | 
			
		||||
	log := &Log{
 | 
			
		||||
		UserId:    userId,
 | 
			
		||||
		CreatedAt: common.GetTimestamp(),
 | 
			
		||||
		Type:      logType,
 | 
			
		||||
		Content:   content,
 | 
			
		||||
	}
 | 
			
		||||
	err := DB.Create(log).Error
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		common.SysError("failed to record log: " + err.Error())
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func GetAllLogs(logType int, startIdx int, num int) (logs []*Log, err error) {
 | 
			
		||||
	err = DB.Where("type = ?", logType).Order("id desc").Limit(num).Offset(startIdx).Find(&logs).Error
 | 
			
		||||
	return logs, err
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func GetUserLogs(userId int, logType int, startIdx int, num int) (logs []*Log, err error) {
 | 
			
		||||
	err = DB.Where("user_id = ? and type = ?", userId, logType).Order("id desc").Limit(num).Offset(startIdx).Find(&logs).Error
 | 
			
		||||
	return logs, err
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func SearchAllLogs(keyword string) (logs []*Log, err error) {
 | 
			
		||||
	err = DB.Where("type = ? or content LIKE ?", keyword, keyword+"%").Order("id desc").Limit(common.MaxRecentItems).Find(&logs).Error
 | 
			
		||||
	return logs, err
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func SearchUserLogs(userId int, keyword string) (logs []*Log, err error) {
 | 
			
		||||
	err = DB.Where("user_id = ? and type = ?", userId, keyword).Order("id desc").Limit(common.MaxRecentItems).Find(&logs).Error
 | 
			
		||||
	return logs, err
 | 
			
		||||
}
 | 
			
		||||
@@ -75,6 +75,10 @@ func InitDB() (err error) {
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return err
 | 
			
		||||
		}
 | 
			
		||||
		err = db.AutoMigrate(&Ability{})
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return err
 | 
			
		||||
		}
 | 
			
		||||
		err = createRootAccountIfNeed()
 | 
			
		||||
		return err
 | 
			
		||||
	} else {
 | 
			
		||||
 
 | 
			
		||||
@@ -2,7 +2,6 @@ package model
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"errors"
 | 
			
		||||
	_ "gorm.io/driver/sqlite"
 | 
			
		||||
	"one-api/common"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
@@ -84,7 +83,7 @@ func (redemption *Redemption) SelectUpdate() error {
 | 
			
		||||
// Update Make sure your token's fields is completed, because this will update non-zero values
 | 
			
		||||
func (redemption *Redemption) Update() error {
 | 
			
		||||
	var err error
 | 
			
		||||
	err = DB.Model(redemption).Select("name", "status", "redeemed_time").Updates(redemption).Error
 | 
			
		||||
	err = DB.Model(redemption).Select("name", "status", "quota", "redeemed_time").Updates(redemption).Error
 | 
			
		||||
	return err
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
@@ -3,7 +3,6 @@ package model
 | 
			
		||||
import (
 | 
			
		||||
	"errors"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	_ "gorm.io/driver/sqlite"
 | 
			
		||||
	"gorm.io/gorm"
 | 
			
		||||
	"one-api/common"
 | 
			
		||||
)
 | 
			
		||||
 
 | 
			
		||||
@@ -22,6 +22,7 @@ type User struct {
 | 
			
		||||
	VerificationCode string `json:"verification_code" gorm:"-:all"`                                    // this field is only for Email verification, don't save it to database!
 | 
			
		||||
	AccessToken      string `json:"access_token" gorm:"type:char(32);column:access_token;uniqueIndex"` // this token is for system management
 | 
			
		||||
	Quota            int    `json:"quota" gorm:"type:int;default:0"`
 | 
			
		||||
	Group            string `json:"group" gorm:"type:varchar(32);default:'default'"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func GetMaxUserId() int {
 | 
			
		||||
@@ -229,6 +230,11 @@ func GetUserEmail(id int) (email string, err error) {
 | 
			
		||||
	return email, err
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func GetUserGroup(id int) (group string, err error) {
 | 
			
		||||
	err = DB.Model(&User{}).Where("id = ?", id).Select("`group`").Find(&group).Error
 | 
			
		||||
	return group, err
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func IncreaseUserQuota(id int, quota int) (err error) {
 | 
			
		||||
	if quota < 0 {
 | 
			
		||||
		return errors.New("quota 不能为负数!")
 | 
			
		||||
 
 | 
			
		||||
@@ -63,6 +63,7 @@ func SetApiRouter(router *gin.Engine) {
 | 
			
		||||
		{
 | 
			
		||||
			channelRoute.GET("/", controller.GetAllChannels)
 | 
			
		||||
			channelRoute.GET("/search", controller.SearchChannels)
 | 
			
		||||
			channelRoute.GET("/models", controller.ListModels)
 | 
			
		||||
			channelRoute.GET("/:id", controller.GetChannel)
 | 
			
		||||
			channelRoute.GET("/test", controller.TestAllChannels)
 | 
			
		||||
			channelRoute.GET("/test/:id", controller.TestChannel)
 | 
			
		||||
@@ -92,5 +93,10 @@ func SetApiRouter(router *gin.Engine) {
 | 
			
		||||
			redemptionRoute.PUT("/", controller.UpdateRedemption)
 | 
			
		||||
			redemptionRoute.DELETE("/:id", controller.DeleteRedemption)
 | 
			
		||||
		}
 | 
			
		||||
		logRoute := apiRouter.Group("/log")
 | 
			
		||||
		logRoute.GET("/", middleware.AdminAuth(), controller.GetAllLogs)
 | 
			
		||||
		logRoute.GET("/search", middleware.AdminAuth(), controller.SearchAllLogs)
 | 
			
		||||
		logRoute.GET("/self", middleware.UserAuth(), controller.GetUserLogs)
 | 
			
		||||
		logRoute.GET("/self/search", middleware.UserAuth(), controller.SearchUserLogs)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -8,12 +8,16 @@ import (
 | 
			
		||||
 | 
			
		||||
func SetRelayRouter(router *gin.Engine) {
 | 
			
		||||
	// https://platform.openai.com/docs/api-reference/introduction
 | 
			
		||||
	modelsRouter := router.Group("/v1/models")
 | 
			
		||||
	modelsRouter.Use(middleware.TokenAuth())
 | 
			
		||||
	{
 | 
			
		||||
		modelsRouter.GET("/", controller.ListModels)
 | 
			
		||||
		modelsRouter.GET("/:model", controller.RetrieveModel)
 | 
			
		||||
	}
 | 
			
		||||
	relayV1Router := router.Group("/v1")
 | 
			
		||||
	relayV1Router.Use(middleware.TokenAuth(), middleware.Distribute())
 | 
			
		||||
	{
 | 
			
		||||
		relayV1Router.GET("/models", controller.ListModels)
 | 
			
		||||
		relayV1Router.GET("/models/:model", controller.RetrieveModel)
 | 
			
		||||
		relayV1Router.POST("/completions", controller.RelayNotImplemented)
 | 
			
		||||
		relayV1Router.POST("/completions", controller.Relay)
 | 
			
		||||
		relayV1Router.POST("/chat/completions", controller.Relay)
 | 
			
		||||
		relayV1Router.POST("/edits", controller.RelayNotImplemented)
 | 
			
		||||
		relayV1Router.POST("/images/generations", controller.RelayNotImplemented)
 | 
			
		||||
 
 | 
			
		||||
@@ -4,6 +4,7 @@ import { Link } from 'react-router-dom';
 | 
			
		||||
import { API, showError, showInfo, showSuccess, timestamp2string } from '../helpers';
 | 
			
		||||
 | 
			
		||||
import { CHANNEL_OPTIONS, ITEMS_PER_PAGE } from '../constants';
 | 
			
		||||
import { renderGroup } from '../helpers/render';
 | 
			
		||||
 | 
			
		||||
function renderTimestamp(timestamp) {
 | 
			
		||||
  return (
 | 
			
		||||
@@ -264,6 +265,14 @@ const ChannelsTable = () => {
 | 
			
		||||
            >
 | 
			
		||||
              名称
 | 
			
		||||
            </Table.HeaderCell>
 | 
			
		||||
            <Table.HeaderCell
 | 
			
		||||
              style={{ cursor: 'pointer' }}
 | 
			
		||||
              onClick={() => {
 | 
			
		||||
                sortChannel('group');
 | 
			
		||||
              }}
 | 
			
		||||
            >
 | 
			
		||||
              分组
 | 
			
		||||
            </Table.HeaderCell>
 | 
			
		||||
            <Table.HeaderCell
 | 
			
		||||
              style={{ cursor: 'pointer' }}
 | 
			
		||||
              onClick={() => {
 | 
			
		||||
@@ -312,6 +321,7 @@ const ChannelsTable = () => {
 | 
			
		||||
                <Table.Row key={channel.id}>
 | 
			
		||||
                  <Table.Cell>{channel.id}</Table.Cell>
 | 
			
		||||
                  <Table.Cell>{channel.name ? channel.name : '无'}</Table.Cell>
 | 
			
		||||
                  <Table.Cell>{renderGroup(channel.group)}</Table.Cell>
 | 
			
		||||
                  <Table.Cell>{renderType(channel.type)}</Table.Cell>
 | 
			
		||||
                  <Table.Cell>{renderStatus(channel.status)}</Table.Cell>
 | 
			
		||||
                  <Table.Cell>
 | 
			
		||||
@@ -398,7 +408,7 @@ const ChannelsTable = () => {
 | 
			
		||||
 | 
			
		||||
        <Table.Footer>
 | 
			
		||||
          <Table.Row>
 | 
			
		||||
            <Table.HeaderCell colSpan='7'>
 | 
			
		||||
            <Table.HeaderCell colSpan='8'>
 | 
			
		||||
              <Button size='small' as={Link} to='/channel/add' loading={loading}>
 | 
			
		||||
                添加新的渠道
 | 
			
		||||
              </Button>
 | 
			
		||||
 
 | 
			
		||||
@@ -4,7 +4,7 @@ import { Link } from 'react-router-dom';
 | 
			
		||||
import { API, showError, showSuccess } from '../helpers';
 | 
			
		||||
 | 
			
		||||
import { ITEMS_PER_PAGE } from '../constants';
 | 
			
		||||
import { renderText } from '../helpers/render';
 | 
			
		||||
import { renderGroup, renderText } from '../helpers/render';
 | 
			
		||||
 | 
			
		||||
function renderRole(role) {
 | 
			
		||||
  switch (role) {
 | 
			
		||||
@@ -175,6 +175,14 @@ const UsersTable = () => {
 | 
			
		||||
            >
 | 
			
		||||
              用户名
 | 
			
		||||
            </Table.HeaderCell>
 | 
			
		||||
            <Table.HeaderCell
 | 
			
		||||
              style={{ cursor: 'pointer' }}
 | 
			
		||||
              onClick={() => {
 | 
			
		||||
                sortUser('group');
 | 
			
		||||
              }}
 | 
			
		||||
            >
 | 
			
		||||
              分组
 | 
			
		||||
            </Table.HeaderCell>
 | 
			
		||||
            <Table.HeaderCell
 | 
			
		||||
              style={{ cursor: 'pointer' }}
 | 
			
		||||
              onClick={() => {
 | 
			
		||||
@@ -231,6 +239,7 @@ const UsersTable = () => {
 | 
			
		||||
                      hoverable
 | 
			
		||||
                    />
 | 
			
		||||
                  </Table.Cell>
 | 
			
		||||
                  <Table.Cell>{renderGroup(user.group)}</Table.Cell>
 | 
			
		||||
                  <Table.Cell>{user.email ? renderText(user.email, 30) : '无'}</Table.Cell>
 | 
			
		||||
                  <Table.Cell>{user.quota}</Table.Cell>
 | 
			
		||||
                  <Table.Cell>{renderRole(user.role)}</Table.Cell>
 | 
			
		||||
@@ -293,7 +302,6 @@ const UsersTable = () => {
 | 
			
		||||
                        size={'small'}
 | 
			
		||||
                        as={Link}
 | 
			
		||||
                        to={'/user/edit/' + user.id}
 | 
			
		||||
                        disabled={user.role === 100}
 | 
			
		||||
                      >
 | 
			
		||||
                        编辑
 | 
			
		||||
                      </Button>
 | 
			
		||||
@@ -306,7 +314,7 @@ const UsersTable = () => {
 | 
			
		||||
 | 
			
		||||
        <Table.Footer>
 | 
			
		||||
          <Table.Row>
 | 
			
		||||
            <Table.HeaderCell colSpan='7'>
 | 
			
		||||
            <Table.HeaderCell colSpan='8'>
 | 
			
		||||
              <Button size='small' as={Link} to='/user/add' loading={loading}>
 | 
			
		||||
                添加新的用户
 | 
			
		||||
              </Button>
 | 
			
		||||
 
 | 
			
		||||
@@ -1,6 +1,15 @@
 | 
			
		||||
import { Label } from 'semantic-ui-react';
 | 
			
		||||
 | 
			
		||||
export function renderText(text, limit) {
 | 
			
		||||
  if (text.length > limit) {
 | 
			
		||||
    return text.slice(0, limit - 3) + '...';
 | 
			
		||||
  }
 | 
			
		||||
  return text;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
export function renderGroup(group) {
 | 
			
		||||
  if (group === "") {
 | 
			
		||||
    return <Label>default</Label>
 | 
			
		||||
  }
 | 
			
		||||
  return <Label>{group}</Label>
 | 
			
		||||
}
 | 
			
		||||
@@ -1,7 +1,7 @@
 | 
			
		||||
import React, { useEffect, useState } from 'react';
 | 
			
		||||
import { Button, Form, Header, Message, Segment } from 'semantic-ui-react';
 | 
			
		||||
import { useParams } from 'react-router-dom';
 | 
			
		||||
import { API, showError, showSuccess } from '../../helpers';
 | 
			
		||||
import { API, showError, showInfo, showSuccess } from '../../helpers';
 | 
			
		||||
import { CHANNEL_OPTIONS } from '../../constants';
 | 
			
		||||
 | 
			
		||||
const EditChannel = () => {
 | 
			
		||||
@@ -14,12 +14,16 @@ const EditChannel = () => {
 | 
			
		||||
    type: 1,
 | 
			
		||||
    key: '',
 | 
			
		||||
    base_url: '',
 | 
			
		||||
    other: ''
 | 
			
		||||
    other: '',
 | 
			
		||||
    group: 'default',
 | 
			
		||||
    models: [],
 | 
			
		||||
  };
 | 
			
		||||
  const [batch, setBatch] = useState(false);
 | 
			
		||||
  const [inputs, setInputs] = useState(originInputs);
 | 
			
		||||
  const [modelOptions, setModelOptions] = useState([]);
 | 
			
		||||
  const [basicModels, setBasicModels] = useState([]);
 | 
			
		||||
  const [fullModels, setFullModels] = useState([]);
 | 
			
		||||
  const handleInputChange = (e, { name, value }) => {
 | 
			
		||||
    console.log(name, value);
 | 
			
		||||
    setInputs((inputs) => ({ ...inputs, [name]: value }));
 | 
			
		||||
  };
 | 
			
		||||
 | 
			
		||||
@@ -27,21 +31,45 @@ const EditChannel = () => {
 | 
			
		||||
    let res = await API.get(`/api/channel/${channelId}`);
 | 
			
		||||
    const { success, message, data } = res.data;
 | 
			
		||||
    if (success) {
 | 
			
		||||
      data.password = '';
 | 
			
		||||
      if (data.models === "") {
 | 
			
		||||
        data.models = []
 | 
			
		||||
      } else {
 | 
			
		||||
        data.models = data.models.split(",")
 | 
			
		||||
      }
 | 
			
		||||
      setInputs(data);
 | 
			
		||||
    } else {
 | 
			
		||||
      showError(message);
 | 
			
		||||
    }
 | 
			
		||||
    setLoading(false);
 | 
			
		||||
  };
 | 
			
		||||
 | 
			
		||||
  const fetchModels = async () => {
 | 
			
		||||
    try {
 | 
			
		||||
      let res = await API.get(`/api/channel/models`);
 | 
			
		||||
      setModelOptions(res.data.data.map((model) => ({
 | 
			
		||||
        key: model.id,
 | 
			
		||||
        text: model.id,
 | 
			
		||||
        value: model.id,
 | 
			
		||||
      })));
 | 
			
		||||
      setFullModels(res.data.data.map((model) => model.id));
 | 
			
		||||
      setBasicModels(res.data.data.filter((model) => !model.id.startsWith("gpt-4")).map((model) => model.id));
 | 
			
		||||
    } catch (error) {
 | 
			
		||||
      showError(error.message);
 | 
			
		||||
    }
 | 
			
		||||
  };
 | 
			
		||||
 | 
			
		||||
  useEffect(() => {
 | 
			
		||||
    if (isEdit) {
 | 
			
		||||
      loadChannel().then();
 | 
			
		||||
    }
 | 
			
		||||
    fetchModels().then();
 | 
			
		||||
  }, []);
 | 
			
		||||
 | 
			
		||||
  const submit = async () => {
 | 
			
		||||
    if (!isEdit && (inputs.name === '' || inputs.key === '')) return;
 | 
			
		||||
    if (!isEdit && (inputs.name === '' || inputs.key === '')) {
 | 
			
		||||
      showInfo('请填写渠道名称和渠道密钥!');
 | 
			
		||||
      return;
 | 
			
		||||
    }
 | 
			
		||||
    let localInputs = inputs;
 | 
			
		||||
    if (localInputs.base_url.endsWith('/')) {
 | 
			
		||||
      localInputs.base_url = localInputs.base_url.slice(0, localInputs.base_url.length - 1);
 | 
			
		||||
@@ -50,6 +78,7 @@ const EditChannel = () => {
 | 
			
		||||
      localInputs.other = '2023-03-15-preview';
 | 
			
		||||
    }
 | 
			
		||||
    let res;
 | 
			
		||||
    localInputs.models = localInputs.models.join(",")
 | 
			
		||||
    if (isEdit) {
 | 
			
		||||
      res = await API.put(`/api/channel/`, { ...localInputs, id: parseInt(channelId) });
 | 
			
		||||
    } else {
 | 
			
		||||
@@ -137,6 +166,52 @@ const EditChannel = () => {
 | 
			
		||||
              autoComplete='new-password'
 | 
			
		||||
            />
 | 
			
		||||
          </Form.Field>
 | 
			
		||||
          <Form.Field>
 | 
			
		||||
            <Form.Input
 | 
			
		||||
              label='分组'
 | 
			
		||||
              name='group'
 | 
			
		||||
              placeholder={'请输入分组'}
 | 
			
		||||
              onChange={handleInputChange}
 | 
			
		||||
              value={inputs.group}
 | 
			
		||||
              autoComplete='new-password'
 | 
			
		||||
            />
 | 
			
		||||
          </Form.Field>
 | 
			
		||||
          <Form.Field>
 | 
			
		||||
            <Form.Dropdown
 | 
			
		||||
              label='模型'
 | 
			
		||||
              placeholder={'请选择该通道所支持的模型'}
 | 
			
		||||
              name='models'
 | 
			
		||||
              fluid
 | 
			
		||||
              multiple
 | 
			
		||||
              selection
 | 
			
		||||
              onChange={handleInputChange}
 | 
			
		||||
              value={inputs.models}
 | 
			
		||||
              autoComplete='new-password'
 | 
			
		||||
              options={modelOptions}
 | 
			
		||||
            />
 | 
			
		||||
          </Form.Field>
 | 
			
		||||
          <div style={{ lineHeight: '40px', marginBottom: '12px'}}>
 | 
			
		||||
            <Button type={'button'} onClick={() => {
 | 
			
		||||
              handleInputChange(null, { name: 'models', value: basicModels });
 | 
			
		||||
            }}>填入基础模型</Button>
 | 
			
		||||
            <Button type={'button'} onClick={() => {
 | 
			
		||||
              handleInputChange(null, { name: 'models', value: fullModels });
 | 
			
		||||
            }}>填入所有模型</Button>
 | 
			
		||||
          </div>
 | 
			
		||||
          {
 | 
			
		||||
            inputs.type === 1 && (
 | 
			
		||||
              <Form.Field>
 | 
			
		||||
                <Form.Input
 | 
			
		||||
                  label='代理'
 | 
			
		||||
                  name='base_url'
 | 
			
		||||
                  placeholder={'请输入 OpenAI API 代理地址,如果不需要请留空,格式为:https://api.openai.com'}
 | 
			
		||||
                  onChange={handleInputChange}
 | 
			
		||||
                  value={inputs.base_url}
 | 
			
		||||
                  autoComplete='new-password'
 | 
			
		||||
                />
 | 
			
		||||
              </Form.Field>
 | 
			
		||||
            )
 | 
			
		||||
          }
 | 
			
		||||
          {
 | 
			
		||||
            batch ? <Form.Field>
 | 
			
		||||
              <Form.TextArea
 | 
			
		||||
 
 | 
			
		||||
@@ -15,8 +15,9 @@ const EditUser = () => {
 | 
			
		||||
    wechat_id: '',
 | 
			
		||||
    email: '',
 | 
			
		||||
    quota: 0,
 | 
			
		||||
    group: 'default'
 | 
			
		||||
  });
 | 
			
		||||
  const { username, display_name, password, github_id, wechat_id, email, quota } =
 | 
			
		||||
  const { username, display_name, password, github_id, wechat_id, email, quota, group } =
 | 
			
		||||
    inputs;
 | 
			
		||||
  const handleInputChange = (e, { name, value }) => {
 | 
			
		||||
    setInputs((inputs) => ({ ...inputs, [name]: value }));
 | 
			
		||||
@@ -98,7 +99,17 @@ const EditUser = () => {
 | 
			
		||||
            />
 | 
			
		||||
          </Form.Field>
 | 
			
		||||
          {
 | 
			
		||||
            userId && (
 | 
			
		||||
            userId && <>
 | 
			
		||||
              <Form.Field>
 | 
			
		||||
                <Form.Input
 | 
			
		||||
                  label='分组'
 | 
			
		||||
                  name='group'
 | 
			
		||||
                  placeholder={'请输入用户分组'}
 | 
			
		||||
                  onChange={handleInputChange}
 | 
			
		||||
                  value={group}
 | 
			
		||||
                  autoComplete='new-password'
 | 
			
		||||
                />
 | 
			
		||||
              </Form.Field>
 | 
			
		||||
              <Form.Field>
 | 
			
		||||
                <Form.Input
 | 
			
		||||
                  label='剩余额度'
 | 
			
		||||
@@ -110,7 +121,7 @@ const EditUser = () => {
 | 
			
		||||
                  autoComplete='new-password'
 | 
			
		||||
                />
 | 
			
		||||
              </Form.Field>
 | 
			
		||||
            )
 | 
			
		||||
            </>
 | 
			
		||||
          }
 | 
			
		||||
          <Form.Field>
 | 
			
		||||
            <Form.Input
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user