Compare commits

..

25 Commits

Author SHA1 Message Date
quzard
e87ad1f402 chore: remove -0613 suffix for Azure (#163) 2023-06-14 16:33:03 +08:00
JustSong
07cccdc8c0 docs: update issue template 2023-06-14 15:13:05 +08:00
JustSong
f71f01662c docs: update issue template 2023-06-14 15:03:51 +08:00
JustSong
54d7a1c2e8 docs: update issue template 2023-06-14 15:02:36 +08:00
JustSong
f426f31bd7 docs: update issue template 2023-06-14 14:59:24 +08:00
JustSong
2930577cd6 docs: update issue template 2023-06-14 14:51:48 +08:00
JustSong
e09512177a docs: add issue templates 2023-06-14 14:48:31 +08:00
JustSong
d6dbaff3c2 fix: fix file not committed 2023-06-14 12:52:56 +08:00
JustSong
7f9577a386 feat: now one channel can belong to multiple groups (close #153) 2023-06-14 12:14:08 +08:00
JustSong
38668e7331 chore: update gpt3.5 completion ratio 2023-06-14 09:41:06 +08:00
JustSong
323f3d263a feat: add new released models 2023-06-14 09:12:14 +08:00
JustSong
0c34ed4c61 docs: update README 2023-06-13 17:45:01 +08:00
JustSong
7c7eb6b7ec fix: now the input field can be array type now (close #149) 2023-06-12 16:11:57 +08:00
JustSong
8b2ef666ef fix: fix OpenAI-SB balance not correct 2023-06-12 09:40:49 +08:00
JustSong
955d5f8707 fix: fix group list not correct (close #147) 2023-06-12 09:11:48 +08:00
quzard
47ca449e32 feat: add support for updating balance of channel type OpenAI-SB (#146, close #125)
* Add support for updating channel balance in OpenAISB

* fix: handle error

---------

Co-authored-by: JustSong <songquanpeng@foxmail.com>
2023-06-11 21:04:41 +08:00
JustSong
39481eb6c0 chore: add trailing slash for API calling 2023-06-11 16:33:40 +08:00
JustSong
69153e7231 docs: update README 2023-06-11 12:37:15 +08:00
JustSong
cdef10cad8 docs: update README 2023-06-11 11:11:47 +08:00
JustSong
077853416d chore: record ratio detail in log 2023-06-11 11:11:19 +08:00
JustSong
596446dba4 feat: able to set group ratio now (close #62, close #142) 2023-06-11 11:08:16 +08:00
JustSong
9d0bec83df chore: update prompt 2023-06-11 09:55:50 +08:00
JustSong
f97a9ce597 fix: correct OpenAI error code's type 2023-06-11 09:49:57 +08:00
JustSong
4339f45f74 feat: support /v1/moderations now (close #117) 2023-06-11 09:37:36 +08:00
JustSong
e398e0756b docs: update README 2023-06-10 20:43:32 +08:00
22 changed files with 414 additions and 70 deletions

.github/ISSUE_TEMPLATE/bug_report.md (new file, 23 lines)

@@ -0,0 +1,23 @@
---
name: Report a bug
about: Describe the problem you ran into in concise, detailed language
title: ''
labels: bug
assignees: ''
---

**Routine checks**
+ [ ] I have confirmed that no similar issue already exists
+ [ ] I have confirmed that I have upgraded to the latest version
+ [ ] I understand and am willing to follow up on this issue, help with testing, and provide feedback
+ [ ] I understand and accept the above, and I understand that the maintainers have limited time, so issues that do not follow the rules may be ignored or closed directly

**Problem description**

**Steps to reproduce**

**Expected result**

**Related screenshots**
If there are none, please delete this section.

.github/ISSUE_TEMPLATE/config.yml (new file, 11 lines)

@@ -0,0 +1,11 @@
blank_issues_enabled: false
contact_links:
- name: 项目群聊
url: https://openai.justsong.cn/
about: 演示站首页有官方群聊信息
- name: 赞赏支持
url: https://iamazing.cn/page/reward
about: 请作者喝杯咖啡,以激励作者持续开发
- name: 付费部署或定制功能
url: https://openai.justsong.cn/
about: 加群后联系群主


@@ -0,0 +1,18 @@
---
name: Feature request
about: Describe the new feature you would like to see in concise, detailed language
title: ''
labels: enhancement
assignees: ''
---

**Routine checks**
+ [ ] I have confirmed that no similar issue already exists
+ [ ] I have confirmed that I have upgraded to the latest version
+ [ ] I understand and am willing to follow up on this issue, help with testing, and provide feedback
+ [ ] I understand and accept the above, and I understand that the maintainers have limited time, so issues that do not follow the rules may be ignored or closed directly

**Feature description**

**Use cases**


@@ -66,18 +66,19 @@ _✨ All-in-one OpenAI interface, integrating various API access methods, ready to use out of the box ✨_
 5. Supports **token management**: set a token's expiry time and usage count.
 6. Supports **redemption code management**: batch-generate and export redemption codes that can be used to top up accounts.
 7. Supports **channel management**: batch-create channels.
-8. Supports **user groups** and **channel groups**.
+8. Supports **user groups** and **channel groups**, with configurable per-group ratios.
 9. Supports configuring a **model list** per channel.
-10. Supports publishing announcements, setting the top-up link, and setting the initial quota for new users
-11. Supports rich **customization** settings:
+10. Supports **viewing quota details**.
+11. Supports publishing announcements, setting the top-up link, and setting the initial quota for new users.
+12. Supports rich **customization** settings:
     1. Customize the system name, logo, and footer.
     2. Customize the home page and about page using HTML & Markdown, or embed a standalone web page via iframe.
-12. Supports accessing the management API via system access tokens.
-13. Supports user management and **multiple login/registration methods**:
+13. Supports accessing the management API via system access tokens.
+14. Supports user management and **multiple login/registration methods**:
     + Email login/registration and password reset via email.
     + [GitHub OAuth](https://github.com/settings/applications/new).
     + WeChat Official Account authorization (requires an extra [WeChat Server](https://github.com/songquanpeng/wechat-server) deployment).
-14. When other large models open up their APIs, support will be added as soon as possible, wrapped in the same API access style.
+15. When other large models open up their APIs, support will be added as soon as possible, wrapped in the same API access style.

 ## Deployment
 ### Docker-based deployment
@@ -116,6 +117,8 @@ sudo certbot --nginx
 sudo service nginx restart
 ```
+
+The initial account username is `root` and the password is `123456`.

 ### Manual deployment
 1. Download the executable from [GitHub Releases](https://github.com/songquanpeng/one-api/releases/latest) or build from source:
 ```shell
@@ -200,3 +203,11 @@ https://openai.justsong.cn
 + The token quota only sets a maximum usage amount; users can configure it freely.
 2. Blank page after deploying with BT Panel (宝塔)?
 + An auto-configuration issue; see [#97](https://github.com/songquanpeng/one-api/issues/97).
+3. Getting a "no available channels" message?
++ Check the user's group and the channel's group settings,
++ as well as the channel's model settings.
+
+## Notice
+This is an open-source project. Use it in compliance with OpenAI's [Terms of Use](https://openai.com/policies/terms-of-use) and applicable laws and regulations; it must not be used for illegal purposes.
+This project is released under the MIT license; please retain the One API copyright notice in some form.

common/group-ratio.go (new file, 31 lines)

@@ -0,0 +1,31 @@
package common

import "encoding/json"

var GroupRatio = map[string]float64{
	"default": 1,
	"vip":     1,
	"svip":    1,
}

func GroupRatio2JSONString() string {
	jsonBytes, err := json.Marshal(GroupRatio)
	if err != nil {
		SysError("Error marshalling group ratio: " + err.Error())
	}
	return string(jsonBytes)
}

func UpdateGroupRatioByJSONString(jsonStr string) error {
	GroupRatio = make(map[string]float64)
	return json.Unmarshal([]byte(jsonStr), &GroupRatio)
}

func GetGroupRatio(name string) float64 {
	ratio, ok := GroupRatio[name]
	if !ok {
		SysError("Group ratio not found: " + name)
		return 1
	}
	return ratio
}
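The group-ratio table above is the basis for per-group billing. A minimal standalone sketch of its behavior (the `vip` value below is an assumed example; the shipped defaults are all 1):

```go
package main

import (
	"encoding/json"
	"fmt"
)

// GroupRatio maps a group name to its billing multiplier, mirroring common/group-ratio.go.
var GroupRatio = map[string]float64{
	"default": 1,
	"vip":     1,
	"svip":    1,
}

// UpdateGroupRatioByJSONString replaces the whole table from an admin-supplied JSON string.
func UpdateGroupRatioByJSONString(jsonStr string) error {
	GroupRatio = make(map[string]float64)
	return json.Unmarshal([]byte(jsonStr), &GroupRatio)
}

// GetGroupRatio falls back to 1 for unknown groups instead of failing the request.
func GetGroupRatio(name string) float64 {
	ratio, ok := GroupRatio[name]
	if !ok {
		return 1
	}
	return ratio
}

func main() {
	// Assumed example configuration: vip pays half price.
	if err := UpdateGroupRatioByJSONString(`{"default": 1, "vip": 0.5}`); err != nil {
		panic(err)
	}
	fmt.Println(GetGroupRatio("vip"))     // 0.5
	fmt.Println(GetGroupRatio("unknown")) // 1 (fallback)
}
```

Note that `UpdateGroupRatioByJSONString` resets the map first, so groups omitted from the new JSON silently fall back to ratio 1.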


@@ -2,16 +2,23 @@ package common

 import "encoding/json"

+// ModelRatio
 // https://platform.openai.com/docs/models/model-endpoint-compatibility
 // https://openai.com/pricing
 // TODO: when a new api is enabled, check the pricing here
+// 1 === $0.002 / 1K tokens
 var ModelRatio = map[string]float64{
 	"gpt-4":          15,
 	"gpt-4-0314":     15,
+	"gpt-4-0613":     15,
 	"gpt-4-32k":      30,
 	"gpt-4-32k-0314": 30,
+	"gpt-4-32k-0613": 30,
-	"gpt-3.5-turbo":      1, // $0.002 / 1K tokens
-	"gpt-3.5-turbo-0301": 1,
+	"gpt-3.5-turbo":          0.75, // $0.0015 / 1K tokens
+	"gpt-3.5-turbo-0301":     0.75,
+	"gpt-3.5-turbo-0613":     0.75,
+	"gpt-3.5-turbo-16k":      1.5, // $0.003 / 1K tokens
+	"gpt-3.5-turbo-16k-0613": 1.5,
 	"text-ada-001":     0.2,
 	"text-babbage-001": 0.25,
 	"text-curie-001":   1,
@@ -26,8 +33,8 @@ var ModelRatio = map[string]float64{
 	"ada":                     10,
 	"text-embedding-ada-002":  0.2,
 	"text-search-ada-doc-001": 10,
-	"text-moderation-stable": 10,
-	"text-moderation-latest": 10,
+	"text-moderation-stable": 0.1,
+	"text-moderation-latest": 0.1,
 }

 func ModelRatio2JSONString() string {
@@ -39,6 +46,7 @@ func ModelRatio2JSONString() string {

 func UpdateModelRatioByJSONString(jsonStr string) error {
+	ModelRatio = make(map[string]float64)
 	return json.Unmarshal([]byte(jsonStr), &ModelRatio)
 }
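Per the `1 === $0.002 / 1K tokens` comment, the updated ratios can be cross-checked arithmetically: gpt-3.5-turbo at 0.75 works out to 0.75 × $0.002 = $0.0015 per 1K tokens, matching OpenAI's June 2023 price cut. A small sketch (only entries shown in the diff are repeated here):

```go
package main

import "fmt"

// A ratio of 1 corresponds to $0.002 per 1K tokens (see the comment in common/model-ratio.go).
const dollarsPer1KAtRatio1 = 0.002

// Entries repeated from the diff above.
var modelRatio = map[string]float64{
	"gpt-3.5-turbo":          0.75, // 0.75 * 0.002 = $0.0015 / 1K tokens
	"gpt-3.5-turbo-16k":      1.5,  // 1.5  * 0.002 = $0.003  / 1K tokens
	"gpt-4":                  15,   // 15   * 0.002 = $0.03   / 1K tokens
	"text-moderation-stable": 0.1,
}

// pricePer1K converts a billing ratio back into dollars per 1K tokens.
func pricePer1K(model string) float64 {
	return modelRatio[model] * dollarsPer1KAtRatio1
}

func main() {
	for _, m := range []string{"gpt-3.5-turbo", "gpt-3.5-turbo-16k", "gpt-4"} {
		fmt.Printf("%s: $%.4f / 1K tokens\n", m, pricePer1K(m))
	}
}
```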


@@ -37,6 +37,58 @@ type OpenAIUsageResponse struct {
 	TotalUsage float64 `json:"total_usage"` // unit: 0.01 dollar
 }

+type OpenAISBUsageResponse struct {
+	Msg  string `json:"msg"`
+	Data *struct {
+		Credit string `json:"credit"`
+	} `json:"data"`
+}
+
+func GetResponseBody(method, url string, channel *model.Channel) ([]byte, error) {
+	client := &http.Client{}
+	req, err := http.NewRequest(method, url, nil)
+	if err != nil {
+		return nil, err
+	}
+	auth := fmt.Sprintf("Bearer %s", channel.Key)
+	req.Header.Add("Authorization", auth)
+	res, err := client.Do(req)
+	if err != nil {
+		return nil, err
+	}
+	body, err := io.ReadAll(res.Body)
+	if err != nil {
+		return nil, err
+	}
+	err = res.Body.Close()
+	if err != nil {
+		return nil, err
+	}
+	return body, nil
+}
+
+func updateChannelOpenAISBBalance(channel *model.Channel) (float64, error) {
+	url := fmt.Sprintf("https://api.openai-sb.com/sb-api/user/status?api_key=%s", channel.Key)
+	body, err := GetResponseBody("GET", url, channel)
+	if err != nil {
+		return 0, err
+	}
+	response := OpenAISBUsageResponse{}
+	err = json.Unmarshal(body, &response)
+	if err != nil {
+		return 0, err
+	}
+	if response.Data == nil {
+		return 0, errors.New(response.Msg)
+	}
+	balance, err := strconv.ParseFloat(response.Data.Credit, 64)
+	if err != nil {
+		return 0, err
+	}
+	channel.UpdateBalance(balance)
+	return balance, nil
+}
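OpenAI-SB reports the credit as a string inside an optional `data` object, which is why the handler unmarshals first and then uses `strconv.ParseFloat`, treating a missing `data` field as an API-level error whose reason lives in `msg`. A self-contained sketch of that parsing path, with assumed sample payloads:

```go
package main

import (
	"encoding/json"
	"errors"
	"fmt"
	"strconv"
)

// OpenAISBUsageResponse mirrors the struct added in the diff above.
type OpenAISBUsageResponse struct {
	Msg  string `json:"msg"`
	Data *struct {
		Credit string `json:"credit"`
	} `json:"data"`
}

// parseOpenAISBBalance extracts the numeric credit; a nil data field means
// the upstream rejected the request and msg carries the reason.
func parseOpenAISBBalance(body []byte) (float64, error) {
	response := OpenAISBUsageResponse{}
	if err := json.Unmarshal(body, &response); err != nil {
		return 0, err
	}
	if response.Data == nil {
		return 0, errors.New(response.Msg)
	}
	return strconv.ParseFloat(response.Data.Credit, 64)
}

func main() {
	// Assumed sample payloads; the real response may carry more fields.
	balance, err := parseOpenAISBBalance([]byte(`{"msg":"ok","data":{"credit":"12.34"}}`))
	fmt.Println(balance, err) // 12.34 <nil>
	_, err = parseOpenAISBBalance([]byte(`{"msg":"invalid key","data":null}`))
	fmt.Println(err) // invalid key
}
```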
 func updateChannelBalance(channel *model.Channel) (float64, error) {
 	baseURL := common.ChannelBaseURLs[channel.Type]
 	switch channel.Type {
@@ -48,27 +100,14 @@ func updateChannelBalance(channel *model.Channel) (float64, error) {
 		return 0, errors.New("尚未实现")
 	case common.ChannelTypeCustom:
 		baseURL = channel.BaseURL
+	case common.ChannelTypeOpenAISB:
+		return updateChannelOpenAISBBalance(channel)
 	default:
 		return 0, errors.New("尚未实现")
 	}
 	url := fmt.Sprintf("%s/v1/dashboard/billing/subscription", baseURL)
-	client := &http.Client{}
-	req, err := http.NewRequest("GET", url, nil)
-	if err != nil {
-		return 0, err
-	}
-	auth := fmt.Sprintf("Bearer %s", channel.Key)
-	req.Header.Add("Authorization", auth)
-	res, err := client.Do(req)
-	if err != nil {
-		return 0, err
-	}
-	body, err := io.ReadAll(res.Body)
-	if err != nil {
-		return 0, err
-	}
-	err = res.Body.Close()
+	body, err := GetResponseBody("GET", url, channel)
 	if err != nil {
 		return 0, err
 	}
@@ -84,20 +123,7 @@ func updateChannelBalance(channel *model.Channel) (float64, error) {
 		startDate = now.AddDate(0, 0, -100).Format("2006-01-02")
 	}
 	url = fmt.Sprintf("%s/v1/dashboard/billing/usage?start_date=%s&end_date=%s", baseURL, startDate, endDate)
-	req, err = http.NewRequest("GET", url, nil)
-	if err != nil {
-		return 0, err
-	}
-	req.Header.Add("Authorization", auth)
-	res, err = client.Do(req)
-	if err != nil {
-		return 0, err
-	}
-	body, err = io.ReadAll(res.Body)
-	if err != nil {
-		return 0, err
-	}
-	err = res.Body.Close()
+	body, err = GetResponseBody("GET", url, channel)
 	if err != nil {
 		return 0, err
 	}


@@ -59,7 +59,7 @@ func testChannel(channel *model.Channel, request *ChatRequest) error {
 		return err
 	}
 	if response.Usage.CompletionTokens == 0 {
-		return errors.New(fmt.Sprintf("type %s, code %s, message %s", response.Error.Type, response.Error.Code, response.Error.Message))
+		return errors.New(fmt.Sprintf("type %s, code %v, message %s", response.Error.Type, response.Error.Code, response.Error.Message))
 	}
 	return nil
 }

controller/group.go (new file, 19 lines)

@@ -0,0 +1,19 @@
package controller

import (
	"github.com/gin-gonic/gin"
	"net/http"
	"one-api/common"
)

func GetGroups(c *gin.Context) {
	groupNames := make([]string, 0)
	for groupName := range common.GroupRatio {
		groupNames = append(groupNames, groupName)
	}
	c.JSON(http.StatusOK, gin.H{
		"success": true,
		"message": "",
		"data":    groupNames,
	})
}


@@ -71,6 +71,33 @@ func init() {
 		Root:       "gpt-3.5-turbo-0301",
 		Parent:     nil,
 	},
+	{
+		Id:         "gpt-3.5-turbo-0613",
+		Object:     "model",
+		Created:    1677649963,
+		OwnedBy:    "openai",
+		Permission: permission,
+		Root:       "gpt-3.5-turbo-0613",
+		Parent:     nil,
+	},
+	{
+		Id:         "gpt-3.5-turbo-16k",
+		Object:     "model",
+		Created:    1677649963,
+		OwnedBy:    "openai",
+		Permission: permission,
+		Root:       "gpt-3.5-turbo-16k",
+		Parent:     nil,
+	},
+	{
+		Id:         "gpt-3.5-turbo-16k-0613",
+		Object:     "model",
+		Created:    1677649963,
+		OwnedBy:    "openai",
+		Permission: permission,
+		Root:       "gpt-3.5-turbo-16k-0613",
+		Parent:     nil,
+	},
 	{
 		Id:         "gpt-4",
 		Object:     "model",
@@ -89,6 +116,15 @@ func init() {
 		Root:       "gpt-4-0314",
 		Parent:     nil,
 	},
+	{
+		Id:         "gpt-4-0613",
+		Object:     "model",
+		Created:    1677649963,
+		OwnedBy:    "openai",
+		Permission: permission,
+		Root:       "gpt-4-0613",
+		Parent:     nil,
+	},
 	{
 		Id:         "gpt-4-32k",
 		Object:     "model",
@@ -107,6 +143,15 @@ func init() {
 		Root:       "gpt-4-32k-0314",
 		Parent:     nil,
 	},
+	{
+		Id:         "gpt-4-32k-0613",
+		Object:     "model",
+		Created:    1677649963,
+		OwnedBy:    "openai",
+		Permission: permission,
+		Root:       "gpt-4-32k-0613",
+		Parent:     nil,
+	},
 	{
 		Id:         "text-embedding-ada-002",
 		Object:     "model",
@@ -161,6 +206,24 @@ func init() {
 		Root:       "text-ada-001",
 		Parent:     nil,
 	},
+	{
+		Id:         "text-moderation-latest",
+		Object:     "model",
+		Created:    1677649963,
+		OwnedBy:    "openai",
+		Permission: permission,
+		Root:       "text-moderation-latest",
+		Parent:     nil,
+	},
+	{
+		Id:         "text-moderation-stable",
+		Object:     "model",
+		Created:    1677649963,
+		OwnedBy:    "openai",
+		Permission: permission,
+		Root:       "text-moderation-stable",
+		Parent:     nil,
+	},
 }
 openAIModelsMap = make(map[string]OpenAIModels)
 for _, model := range openAIModels {


@@ -58,6 +58,20 @@ func countTokenMessages(messages []Message, model string) int {
 	return tokenNum
 }

+func countTokenInput(input any, model string) int {
+	switch input.(type) {
+	case string:
+		return countTokenText(input.(string), model)
+	case []string:
+		text := ""
+		for _, s := range input.([]string) {
+			text += s
+		}
+		return countTokenText(text, model)
+	}
+	return 0
+}
+
 func countTokenText(text string, model string) int {
 	tokenEncoder := getTokenEncoder(model)
 	token := tokenEncoder.Encode(text, nil, nil)
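Because `/v1/moderations` accepts `input` as either a string or an array of strings, the new counter switches on the dynamic type. A standalone sketch, with a rune count standing in for the tiktoken encoder; note that a value decoded from JSON into an `any` field arrives as `[]any` rather than `[]string`, so a production version may need an extra case for that:

```go
package main

import "fmt"

// countText is a stand-in for the real tiktoken-based countTokenText.
func countText(text string) int {
	return len([]rune(text))
}

// countInput mirrors countTokenInput: accept a single string or a slice of strings,
// and fall back to 0 for any other type.
func countInput(input any) int {
	switch v := input.(type) {
	case string:
		return countText(v)
	case []string:
		text := ""
		for _, s := range v {
			text += s
		}
		return countText(text)
	}
	return 0
}

func main() {
	fmt.Println(countInput("hello"))              // 5
	fmt.Println(countInput([]string{"ab", "cd"})) // 4
	fmt.Println(countInput(42))                   // 0 (unsupported type)
}
```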


@@ -24,6 +24,7 @@ const (
 	RelayModeChatCompletions
 	RelayModeCompletions
 	RelayModeEmbeddings
+	RelayModeModeration
 )

 // https://platform.openai.com/docs/api-reference/chat
@@ -37,6 +38,7 @@ type GeneralOpenAIRequest struct {
 	Temperature float64 `json:"temperature"`
 	TopP        float64 `json:"top_p"`
 	N           int     `json:"n"`
+	Input       any     `json:"input"`
 }

 type ChatRequest struct {
@@ -63,7 +65,7 @@ type OpenAIError struct {
 	Message string `json:"message"`
 	Type    string `json:"type"`
 	Param   string `json:"param"`
-	Code    string `json:"code"`
+	Code    any    `json:"code"`
 }

 type OpenAIErrorWithStatusCode struct {
@@ -100,11 +102,13 @@ func Relay(c *gin.Context) {
 		relayMode = RelayModeCompletions
 	} else if strings.HasPrefix(c.Request.URL.Path, "/v1/embeddings") {
 		relayMode = RelayModeEmbeddings
+	} else if strings.HasPrefix(c.Request.URL.Path, "/v1/moderations") {
+		relayMode = RelayModeModeration
 	}
 	err := relayHelper(c, relayMode)
 	if err != nil {
 		if err.StatusCode == http.StatusTooManyRequests {
-			err.OpenAIError.Message = "负载已满,请稍后再试,或升级账户以提升服务质量。"
+			err.OpenAIError.Message = "当前分组负载已饱和,请稍后再试,或升级账户以提升服务质量。"
 		}
 		c.JSON(err.StatusCode, gin.H{
 			"error": err.OpenAIError,
@@ -136,6 +140,7 @@ func relayHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 	channelType := c.GetInt("channel")
 	tokenId := c.GetInt("token_id")
 	consumeQuota := c.GetBool("consume_quota")
+	group := c.GetString("group")
 	var textRequest GeneralOpenAIRequest
 	if consumeQuota || channelType == common.ChannelTypeAzure || channelType == common.ChannelTypePaLM {
 		err := common.UnmarshalBodyReusable(c, &textRequest)
@@ -143,6 +148,9 @@ func relayHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 			return errorWrapper(err, "bind_request_body_failed", http.StatusBadRequest)
 		}
 	}
+	if relayMode == RelayModeModeration && textRequest.Model == "" {
+		textRequest.Model = "text-moderation-latest"
+	}
 	baseURL := common.ChannelBaseURLs[channelType]
 	requestURL := c.Request.URL.String()
 	if channelType == common.ChannelTypeCustom {
@@ -169,6 +177,7 @@ func relayHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 		// https://github.com/songquanpeng/one-api/issues/67
 		model_ = strings.TrimSuffix(model_, "-0301")
 		model_ = strings.TrimSuffix(model_, "-0314")
+		model_ = strings.TrimSuffix(model_, "-0613")
 		fullRequestURL = fmt.Sprintf("%s/openai/deployments/%s/%s", baseURL, model_, task)
 	} else if channelType == common.ChannelTypePaLM {
 		err := relayPaLM(textRequest, c)
@@ -180,12 +189,16 @@ func relayHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 		promptTokens = countTokenMessages(textRequest.Messages, textRequest.Model)
 	case RelayModeCompletions:
 		promptTokens = countTokenText(textRequest.Prompt, textRequest.Model)
+	case RelayModeModeration:
+		promptTokens = countTokenInput(textRequest.Input, textRequest.Model)
 	}
 	preConsumedTokens := common.PreConsumedQuota
 	if textRequest.MaxTokens != 0 {
 		preConsumedTokens = promptTokens + textRequest.MaxTokens
 	}
-	ratio := common.GetModelRatio(textRequest.Model)
+	modelRatio := common.GetModelRatio(textRequest.Model)
+	groupRatio := common.GetGroupRatio(group)
+	ratio := modelRatio * groupRatio
 	preConsumedQuota := int(float64(preConsumedTokens) * ratio)
 	if consumeQuota {
 		err := model.PreConsumeTokenQuota(tokenId, preConsumedQuota)
@@ -227,25 +240,27 @@ func relayHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 	defer func() {
 		if consumeQuota {
 			quota := 0
-			usingGPT4 := strings.HasPrefix(textRequest.Model, "gpt-4")
-			completionRatio := 1
-			if usingGPT4 {
+			completionRatio := 1.34 // default for gpt-3
+			if strings.HasPrefix(textRequest.Model, "gpt-4") {
 				completionRatio = 2
 			}
 			if isStream {
 				responseTokens := countTokenText(streamResponseText, textRequest.Model)
-				quota = promptTokens + responseTokens*completionRatio
+				quota = promptTokens + int(float64(responseTokens)*completionRatio)
 			} else {
-				quota = textResponse.Usage.PromptTokens + textResponse.Usage.CompletionTokens*completionRatio
+				quota = textResponse.Usage.PromptTokens + int(float64(textResponse.Usage.CompletionTokens)*completionRatio)
 			}
 			quota = int(float64(quota) * ratio)
+			if ratio != 0 && quota <= 0 {
+				quota = 1
+			}
 			quotaDelta := quota - preConsumedQuota
 			err := model.PostConsumeTokenQuota(tokenId, quotaDelta)
 			if err != nil {
 				common.SysError("Error consuming token remain quota: " + err.Error())
 			}
 			userId := c.GetInt("id")
-			model.RecordLog(userId, model.LogTypeConsume, fmt.Sprintf("使用模型 %s 消耗 %d 点额度", textRequest.Model, quota))
+			model.RecordLog(userId, model.LogTypeConsume, fmt.Sprintf("使用模型 %s 消耗 %d 点额度(模型倍率 %.2f,分组倍率 %.2f,补全倍率 %.2f)", textRequest.Model, quota, modelRatio, groupRatio, completionRatio))
 		}
 	}()
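Taken together, the billed quota is now `(promptTokens + completionTokens × completionRatio) × modelRatio × groupRatio`, truncated to an integer and floored at 1 for non-free requests. A worked sketch of that arithmetic with assumed token counts:

```go
package main

import "fmt"

// computeQuota reproduces the billing arithmetic from relayHelper's defer block.
func computeQuota(promptTokens, completionTokens int, completionRatio, modelRatio, groupRatio float64) int {
	quota := promptTokens + int(float64(completionTokens)*completionRatio)
	ratio := modelRatio * groupRatio
	quota = int(float64(quota) * ratio)
	if ratio != 0 && quota <= 0 {
		quota = 1 // never bill zero for a non-free request
	}
	return quota
}

func main() {
	// Assumed example: gpt-3.5-turbo (model ratio 0.75, completion ratio 1.34)
	// in a group with ratio 1, 100 prompt tokens + 200 completion tokens:
	// (100 + int(200*1.34)) * 0.75 = 368 * 0.75 = 276 quota units.
	fmt.Println(computeQuota(100, 200, 1.34, 0.75, 1)) // 276
	// A tiny request still costs at least 1 unit.
	fmt.Println(computeQuota(0, 0, 1.34, 0.75, 1)) // 1
}
```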


@@ -7,6 +7,7 @@ import (
 	"one-api/common"
 	"one-api/model"
 	"strconv"
+	"strings"
 )

 type ModelRequest struct {
@@ -15,6 +16,9 @@ type ModelRequest struct {
 func Distribute() func(c *gin.Context) {
 	return func(c *gin.Context) {
+		userId := c.GetInt("id")
+		userGroup, _ := model.GetUserGroup(userId)
+		c.Set("group", userGroup)
 		var channel *model.Channel
 		channelId, ok := c.Get("channelId")
 		if ok {
@@ -64,8 +68,11 @@ func Distribute() func(c *gin.Context) {
 				c.Abort()
 				return
 			}
-			userId := c.GetInt("id")
-			userGroup, _ := model.GetUserGroup(userId)
+			if strings.HasPrefix(c.Request.URL.Path, "/v1/moderations") {
+				if modelRequest.Model == "" {
+					modelRequest.Model = "text-moderation-stable"
+				}
+			}
 			channel, err = model.GetRandomSatisfiedChannel(userGroup, modelRequest.Model)
 			if err != nil {
 				c.JSON(200, gin.H{


@@ -30,15 +30,18 @@ func GetRandomSatisfiedChannel(group string, model string) (*Channel, error) {

 func (channel *Channel) AddAbilities() error {
 	models_ := strings.Split(channel.Models, ",")
+	groups_ := strings.Split(channel.Group, ",")
 	abilities := make([]Ability, 0, len(models_))
 	for _, model := range models_ {
-		ability := Ability{
-			Group:     channel.Group,
-			Model:     model,
-			ChannelId: channel.Id,
-			Enabled:   channel.Status == common.ChannelStatusEnabled,
+		for _, group := range groups_ {
+			ability := Ability{
+				Group:     group,
+				Model:     model,
+				ChannelId: channel.Id,
+				Enabled:   channel.Status == common.ChannelStatusEnabled,
+			}
+			abilities = append(abilities, ability)
 		}
-		abilities = append(abilities, ability)
 	}
 	return DB.Create(&abilities).Error
 }
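With channels now belonging to multiple groups, the ability table gets one row per (group, model) pair instead of one per model. The cross-product can be sketched in isolation:

```go
package main

import (
	"fmt"
	"strings"
)

// Ability is a simplified stand-in for the model.Ability row.
type Ability struct {
	Group, Model string
	ChannelId    int
	Enabled      bool
}

// buildAbilities expands a channel's comma-separated model and group lists
// into one ability row per (group, model) pair, as AddAbilities now does.
func buildAbilities(channelId int, models, groups string, enabled bool) []Ability {
	models_ := strings.Split(models, ",")
	groups_ := strings.Split(groups, ",")
	abilities := make([]Ability, 0, len(models_)*len(groups_))
	for _, model := range models_ {
		for _, group := range groups_ {
			abilities = append(abilities, Ability{
				Group:     group,
				Model:     model,
				ChannelId: channelId,
				Enabled:   enabled,
			})
		}
	}
	return abilities
}

func main() {
	// Assumed example: two models served to two groups yields four rows.
	rows := buildAbilities(1, "gpt-3.5-turbo,gpt-4", "default,vip", true)
	fmt.Println(len(rows)) // 4
}
```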


@@ -58,6 +58,7 @@ func InitOptionMap() {
 	common.OptionMap["QuotaRemindThreshold"] = strconv.Itoa(common.QuotaRemindThreshold)
 	common.OptionMap["PreConsumedQuota"] = strconv.Itoa(common.PreConsumedQuota)
 	common.OptionMap["ModelRatio"] = common.ModelRatio2JSONString()
+	common.OptionMap["GroupRatio"] = common.GroupRatio2JSONString()
 	common.OptionMap["TopUpLink"] = common.TopUpLink
 	common.OptionMapRWMutex.Unlock()
 	loadOptionsFromDatabase()
@@ -177,6 +178,8 @@ func updateOptionMap(key string, value string) (err error) {
 		common.PreConsumedQuota, _ = strconv.Atoi(value)
 	case "ModelRatio":
 		err = common.UpdateModelRatioByJSONString(value)
+	case "GroupRatio":
+		err = common.UpdateGroupRatioByJSONString(value)
 	case "TopUpLink":
 		common.TopUpLink = value
 	case "ChannelDisableThreshold":


@@ -98,5 +98,10 @@ func SetApiRouter(router *gin.Engine) {
 		logRoute.GET("/search", middleware.AdminAuth(), controller.SearchAllLogs)
 		logRoute.GET("/self", middleware.UserAuth(), controller.GetUserLogs)
 		logRoute.GET("/self/search", middleware.UserAuth(), controller.SearchUserLogs)
+		groupRoute := apiRouter.Group("/group")
+		groupRoute.Use(middleware.AdminAuth())
+		{
+			groupRoute.GET("/", controller.GetGroups)
+		}
 	}
 }


@@ -37,6 +37,6 @@ func SetRelayRouter(router *gin.Engine) {
 	relayV1Router.POST("/fine-tunes/:id/cancel", controller.RelayNotImplemented)
 	relayV1Router.GET("/fine-tunes/:id/events", controller.RelayNotImplemented)
 	relayV1Router.DELETE("/models/:model", controller.RelayNotImplemented)
-	relayV1Router.POST("/moderations", controller.RelayNotImplemented)
+	relayV1Router.POST("/moderations", controller.Relay)
 }


@@ -27,6 +27,13 @@ function renderType(type) {
 	return <Label basic color={type2label[type].color}>{type2label[type].text}</Label>;
 }

+function renderBalance(type, balance) {
+	if (type === 5) {
+		return <span>¥{(balance / 10000).toFixed(2)}</span>
+	}
+	return <span>${balance.toFixed(2)}</span>
+}
+
 const ChannelsTable = () => {
 	const [channels, setChannels] = useState([]);
 	const [loading, setLoading] = useState(true);
@@ -336,7 +343,7 @@ const ChannelsTable = () => {
 	<Popup
 		content={channel.balance_updated_time ? renderTimestamp(channel.balance_updated_time) : '未更新'}
 		key={channel.id}
-		trigger={<span>${channel.balance.toFixed(2)}</span>}
+		trigger={renderBalance(channel.type, channel.balance)}
 		basic
 	/>
 </Table.Cell>


@@ -30,6 +30,7 @@ const SystemSetting = () => {
 	QuotaRemindThreshold: 0,
 	PreConsumedQuota: 0,
 	ModelRatio: '',
+	GroupRatio: '',
 	TopUpLink: '',
 	AutomaticDisableChannelEnabled: '',
 	ChannelDisableThreshold: 0,
@@ -101,6 +102,7 @@ const SystemSetting = () => {
 	name === 'QuotaRemindThreshold' ||
 	name === 'PreConsumedQuota' ||
 	name === 'ModelRatio' ||
+	name === 'GroupRatio' ||
 	name === 'TopUpLink'
 ) {
 	setInputs((inputs) => ({ ...inputs, [name]: value }));
@@ -131,6 +133,13 @@ const SystemSetting = () => {
 	}
 	await updateOption('ModelRatio', inputs.ModelRatio);
 }
+if (originInputs['GroupRatio'] !== inputs.GroupRatio) {
+	if (!verifyJSON(inputs.GroupRatio)) {
+		showError('分组倍率不是合法的 JSON 字符串');
+		return;
+	}
+	await updateOption('GroupRatio', inputs.GroupRatio);
+}
 if (originInputs['TopUpLink'] !== inputs.TopUpLink) {
 	await updateOption('TopUpLink', inputs.TopUpLink);
 }
@@ -329,6 +338,17 @@ const SystemSetting = () => {
 		placeholder='为一个 JSON 文本,键为模型名称,值为倍率'
 	/>
 </Form.Group>
+<Form.Group widths='equal'>
+	<Form.TextArea
+		label='分组倍率'
+		name='GroupRatio'
+		onChange={handleInputChange}
+		style={{ minHeight: 250, fontFamily: 'JetBrains Mono, Consolas' }}
+		autoComplete='new-password'
+		value={inputs.GroupRatio}
+		placeholder='为一个 JSON 文本,键为分组名称,值为倍率'
+	/>
+</Form.Group>
 <Form.Button onClick={submitOperationConfig}>保存运营设置</Form.Button>
 <Divider />
 <Header as='h3'>


@@ -11,5 +11,16 @@ export function renderGroup(group) {
 	if (group === "") {
 		return <Label>default</Label>
 	}
-	return <Label>{group}</Label>
+	let groups = group.split(",");
+	groups.sort();
+	return <>
+		{groups.map((group) => {
+			if (group === "vip" || group === "pro") {
+				return <Label color='yellow'>{group}</Label>
+			} else if (group === "svip" || group === "premium") {
+				return <Label color='red'>{group}</Label>
+			}
+			return <Label>{group}</Label>
+		})}
+	</>
 }


@@ -15,12 +15,13 @@ const EditChannel = () => {
 	key: '',
 	base_url: '',
 	other: '',
-	group: 'default',
 	models: [],
+	groups: ['default']
 };
 const [batch, setBatch] = useState(false);
 const [inputs, setInputs] = useState(originInputs);
 const [modelOptions, setModelOptions] = useState([]);
+const [groupOptions, setGroupOptions] = useState([]);
 const [basicModels, setBasicModels] = useState([]);
 const [fullModels, setFullModels] = useState([]);
 const handleInputChange = (e, { name, value }) => {
@@ -36,6 +37,11 @@ const EditChannel = () => {
 	} else {
 		data.models = data.models.split(",")
 	}
+	if (data.group === "") {
+		data.groups = []
+	} else {
+		data.groups = data.group.split(",")
+	}
 	setInputs(data);
 } else {
 	showError(message);
@@ -58,11 +64,25 @@ const EditChannel = () => {
 	}
 };

+const fetchGroups = async () => {
+	try {
+		let res = await API.get(`/api/group/`);
+		setGroupOptions(res.data.data.map((group) => ({
+			key: group,
+			text: group,
+			value: group,
+		})));
+	} catch (error) {
+		showError(error.message);
+	}
+};
+
 useEffect(() => {
 	if (isEdit) {
 		loadChannel().then();
 	}
 	fetchModels().then();
+	fetchGroups().then();
 }, []);

 const submit = async () => {
@@ -79,6 +99,7 @@ const EditChannel = () => {
 	}
 	let res;
 	localInputs.models = localInputs.models.join(",")
+	localInputs.group = localInputs.groups.join(",")
 	if (isEdit) {
 		res = await API.put(`/api/channel/`, { ...localInputs, id: parseInt(channelId) });
 	} else {
@@ -167,13 +188,19 @@ const EditChannel = () => {
 	/>
 </Form.Field>
 <Form.Field>
-	<Form.Input
+	<Form.Dropdown
 		label='分组'
-		name='group'
-		placeholder={'请输入分组'}
+		placeholder={'请选择分组'}
+		name='groups'
+		fluid
+		multiple
+		selection
+		allowAdditions
+		additionLabel={'请在系统设置页面编辑分组倍率以添加新的分组:'}
 		onChange={handleInputChange}
-		value={inputs.group}
+		value={inputs.groups}
 		autoComplete='new-password'
+		options={groupOptions}
 	/>
 </Form.Field>
 <Form.Field>


@@ -17,11 +17,24 @@ const EditUser = () => {
 	quota: 0,
 	group: 'default'
 });
+const [groupOptions, setGroupOptions] = useState([]);
 const { username, display_name, password, github_id, wechat_id, email, quota, group } =
 	inputs;
 const handleInputChange = (e, { name, value }) => {
 	setInputs((inputs) => ({ ...inputs, [name]: value }));
 };
+const fetchGroups = async () => {
+	try {
+		let res = await API.get(`/api/group/`);
+		setGroupOptions(res.data.data.map((group) => ({
+			key: group,
+			text: group,
+			value: group,
+		})));
+	} catch (error) {
+		showError(error.message);
+	}
+};

 const loadUser = async () => {
 	let res = undefined;
@@ -41,6 +54,9 @@ const EditUser = () => {
 };

 useEffect(() => {
 	loadUser().then();
+	if (userId) {
+		fetchGroups().then();
+	}
 }, []);

 const submit = async () => {
@@ -101,13 +117,19 @@ const EditUser = () => {
 {
 	userId && <>
 		<Form.Field>
-			<Form.Input
+			<Form.Dropdown
 				label='分组'
+				placeholder={'请选择分组'}
 				name='group'
-				placeholder={'请输入用户分组'}
+				fluid
+				search
+				selection
+				allowAdditions
+				additionLabel={'请在系统设置页面编辑分组倍率以添加新的分组:'}
 				onChange={handleInputChange}
-				value={group}
+				value={inputs.group}
 				autoComplete='new-password'
+				options={groupOptions}
 			/>
 		</Form.Field>
 		<Form.Field>