merge upstream

Signed-off-by: wozulong <>
This commit is contained in:
wozulong
2024-05-31 11:14:25 +08:00
32 changed files with 4523 additions and 369 deletions

214
LICENSE
View File

@@ -1,21 +1,201 @@
MIT License
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
Copyright (c) 2024 Calcium-Ion
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
1. Definitions.
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

View File

@@ -57,11 +57,11 @@
2. 在渠道管理中添加渠道,渠道类型选择**Midjourney Proxy**如果是plus版本选择**Midjourney Proxy Plus**
,模型请参考上方模型列表
3. 地址填写midjourney-proxy部署的地址例如http://localhost:8080
3. **代理**填写midjourney-proxy部署的地址例如http://localhost:8080
4. 密钥填写midjourney-proxy的密钥如果没有设置密钥可以随便填
### 对接上游new api
1. 在渠道管理中添加渠道,渠道类型选择**Midjourney Proxy Plus**,模型请参考上方模型列表
2. 地址填写上游new api的地址例如http://localhost:3000
2. **代理**填写上游new api的地址例如http://localhost:3000
3. 密钥填写上游new api的密钥

View File

@@ -56,17 +56,28 @@
4. [Ollama](https://github.com/ollama/ollama?tab=readme-ov-file),添加渠道时,密钥可以随便填写,默认的请求地址是[http://localhost:11434](http://localhost:11434),如果需要修改请在渠道中修改
5. [Midjourney-Proxy(Plus)](https://github.com/novicezk/midjourney-proxy)接口,[对接文档](Midjourney.md)
6. [零一万物](https://platform.lingyiwanwu.com/)
7. 自定义渠道,支持填入完整调用地址
您可以在渠道中添加自定义模型gpt-4-gizmo-*或g-*此模型并非OpenAI官方模型而是第三方模型使用官方key无法调用。
## 渠道重试
渠道重试功能已经实现,可以在`设置->运营设置->通用设置`设置重试次数,建议开启缓存功能。
渠道重试功能已经实现,可以在`设置->运营设置->通用设置`设置重试次数,**建议开启缓存**功能。
如果开启了重试功能,第一次重试使用同优先级,第二次重试使用下一个优先级,以此类推。
### 缓存设置方法
1. `REDIS_CONN_STRING`:设置之后将使用 Redis 作为缓存使用。
+ 例子:`REDIS_CONN_STRING=redis://default:redispw@localhost:49153`
2. `MEMORY_CACHE_ENABLED`:启用内存缓存(如果设置了`REDIS_CONN_STRING`,则无需手动设置),会导致用户额度的更新存在一定的延迟,可选值为 `true``false`,未设置则默认为 `false`
+ 例子:`MEMORY_CACHE_ENABLED=true`
### 为什么有的时候没有重试
这些错误码不会重试400504524
### 我想让400也重试
`渠道->编辑`中,将`状态码复写`改为
```json
{
"400": "500"
}
```
可以实现400错误转为500错误从而重试
## 部署
@@ -88,6 +99,9 @@ docker run --name new-api -d --restart always -p 3000:3000 -e TZ=Asia/Shanghai -
docker run --name new-api -d --restart always -p 3000:3000 -e SQL_DSN="root:123456@tcp(宝塔的服务器地址:宝塔数据库端口)/宝塔数据库名称" -e TZ=Asia/Shanghai -v /www/wwwroot/new-api:/data calciumion/new-api:latest
# 注意数据库要开启远程访问并且只允许服务器IP访问
```
### 默认账号密码
默认账号root 密码123456
## Midjourney接口设置文档
[对接文档](Midjourney.md)

View File

@@ -51,6 +51,9 @@ func GetImageFromUrl(url string) (mimeType string, data string, err error) {
if err != nil {
return
}
if !strings.HasPrefix(resp.Header.Get("Content-Type"), "image/") {
return
}
defer resp.Body.Close()
buffer := bytes.NewBuffer(nil)
_, err = buffer.ReadFrom(resp.Body)
@@ -70,6 +73,11 @@ func DecodeUrlImageData(imageUrl string) (image.Config, string, error) {
}
defer response.Body.Close()
if response.StatusCode != 200 {
err = errors.New(fmt.Sprintf("fail to get image from url: %s", response.Status))
return image.Config{}, "", err
}
var readData []byte
for _, limit := range []int64{1024 * 8, 1024 * 24, 1024 * 64} {
SysLog(fmt.Sprintf("try to decode image config with limit: %d", limit))

View File

@@ -20,7 +20,7 @@ const (
// 1 === $0.002 / 1K tokens
// 1 === ¥0.014 / 1k tokens
var DefaultModelRatio = map[string]float64{
var defaultModelRatio = map[string]float64{
//"midjourney": 50,
"gpt-4-gizmo-*": 15,
"g-*": 15,
@@ -150,7 +150,7 @@ var DefaultModelRatio = map[string]float64{
"llama-3-sonar-large-32k-online": 1 / 1000 * USD,
}
var DefaultModelPrice = map[string]float64{
var defaultModelPrice = map[string]float64{
"dall-e-2": 0.02,
"dall-e-3": 0.04,
"gpt-4-gizmo-*": 0.1,
@@ -176,15 +176,16 @@ var modelPrice map[string]float64 = nil
var modelRatio map[string]float64 = nil
var CompletionRatio map[string]float64 = nil
var DefaultCompletionRatio = map[string]float64{
var defaultCompletionRatio = map[string]float64{
"gpt-4-gizmo-*": 2,
"g-*": 2,
"gpt-4-all": 2,
"gpt-4o-all": 2,
}
func ModelPrice2JSONString() string {
if modelPrice == nil {
modelPrice = DefaultModelPrice
modelPrice = defaultModelPrice
}
jsonBytes, err := json.Marshal(modelPrice)
if err != nil {
@@ -201,7 +202,7 @@ func UpdateModelPriceByJSONString(jsonStr string) error {
// GetModelPrice 返回模型的价格,如果模型不存在则返回-1false
func GetModelPrice(name string, printErr bool) (float64, bool) {
if modelPrice == nil {
modelPrice = DefaultModelPrice
modelPrice = defaultModelPrice
}
if strings.HasPrefix(name, "gpt-4-gizmo") {
name = "gpt-4-gizmo-*"
@@ -218,16 +219,16 @@ func GetModelPrice(name string, printErr bool) (float64, bool) {
return price, true
}
func GetModelPrices() map[string]float64 {
func GetModelPriceMap() map[string]float64 {
if modelPrice == nil {
modelPrice = DefaultModelPrice
modelPrice = defaultModelPrice
}
return modelPrice
}
func ModelRatio2JSONString() string {
if modelRatio == nil {
modelRatio = DefaultModelRatio
modelRatio = defaultModelRatio
}
jsonBytes, err := json.Marshal(modelRatio)
if err != nil {
@@ -243,7 +244,7 @@ func UpdateModelRatioByJSONString(jsonStr string) error {
func GetModelRatio(name string) float64 {
if modelRatio == nil {
modelRatio = DefaultModelRatio
modelRatio = defaultModelRatio
}
if strings.HasPrefix(name, "gpt-4-gizmo") {
name = "gpt-4-gizmo-*"
@@ -258,16 +259,21 @@ func GetModelRatio(name string) float64 {
return ratio
}
func GetModelRatios() map[string]float64 {
if modelRatio == nil {
modelRatio = DefaultModelRatio
func DefaultModelRatio2JSONString() string {
jsonBytes, err := json.Marshal(defaultModelRatio)
if err != nil {
SysError("error marshalling model ratio: " + err.Error())
}
return modelRatio
return string(jsonBytes)
}
func GetDefaultModelRatioMap() map[string]float64 {
return defaultModelRatio
}
func CompletionRatio2JSONString() string {
if CompletionRatio == nil {
CompletionRatio = DefaultCompletionRatio
CompletionRatio = defaultCompletionRatio
}
jsonBytes, err := json.Marshal(CompletionRatio)
if err != nil {
@@ -298,7 +304,7 @@ func GetCompletionRatio(name string) float64 {
return 3
}
return 1.333333
return 4.0 / 3.0
}
if strings.HasPrefix(name, "gpt-4") && name != "gpt-4-all" && name != "gpt-4-gizmo-*" {
if strings.HasSuffix(name, "preview") || strings.HasPrefix(name, "gpt-4-turbo") || strings.HasPrefix(name, "gpt-4o") {
@@ -355,9 +361,9 @@ func GetCompletionRatio(name string) float64 {
return 1
}
func GetCompletionRatios() map[string]float64 {
func GetCompletionRatioMap() map[string]float64 {
if CompletionRatio == nil {
CompletionRatio = DefaultCompletionRatio
CompletionRatio = defaultCompletionRatio
}
return CompletionRatio
}

View File

@@ -1,5 +1,13 @@
package common
import (
"bytes"
"fmt"
goahocorasick "github.com/anknown/ahocorasick"
"one-api/constant"
"strings"
)
func SundaySearch(text string, pattern string) bool {
// 计算偏移表
offset := make(map[rune]int)
@@ -48,3 +56,25 @@ func RemoveDuplicate(s []string) []string {
}
return result
}
func InitAc() *goahocorasick.Machine {
m := new(goahocorasick.Machine)
dict := readRunes()
if err := m.Build(dict); err != nil {
fmt.Println(err)
return nil
}
return m
}
func readRunes() [][]rune {
var dict [][]rune
for _, word := range constant.SensitiveWords {
word = strings.ToLower(word)
l := bytes.TrimSpace([]byte(word))
dict = append(dict, bytes.Runes(l))
}
return dict
}

View File

@@ -16,7 +16,7 @@ var StreamCacheQueueLength = 0
// SensitiveWords 敏感词
// var SensitiveWords []string
var SensitiveWords = []string{
"test",
"test_sensitive",
}
func SensitiveWordsToString() string {

View File

@@ -1,6 +1,8 @@
package controller
import (
"encoding/json"
"fmt"
"github.com/gin-gonic/gin"
"net/http"
"one-api/common"
@@ -9,6 +11,34 @@ import (
"strings"
)
type OpenAIModel struct {
ID string `json:"id"`
Object string `json:"object"`
Created int64 `json:"created"`
OwnedBy string `json:"owned_by"`
Permission []struct {
ID string `json:"id"`
Object string `json:"object"`
Created int64 `json:"created"`
AllowCreateEngine bool `json:"allow_create_engine"`
AllowSampling bool `json:"allow_sampling"`
AllowLogprobs bool `json:"allow_logprobs"`
AllowSearchIndices bool `json:"allow_search_indices"`
AllowView bool `json:"allow_view"`
AllowFineTuning bool `json:"allow_fine_tuning"`
Organization string `json:"organization"`
Group string `json:"group"`
IsBlocking bool `json:"is_blocking"`
} `json:"permission"`
Root string `json:"root"`
Parent string `json:"parent"`
}
type OpenAIModelsResponse struct {
Data []OpenAIModel `json:"data"`
Success bool `json:"success"`
}
func GetAllChannels(c *gin.Context) {
p, _ := strconv.Atoi(c.Query("p"))
pageSize, _ := strconv.Atoi(c.Query("page_size"))
@@ -35,6 +65,65 @@ func GetAllChannels(c *gin.Context) {
return
}
func FetchUpstreamModels(c *gin.Context) {
id, err := strconv.Atoi(c.Param("id"))
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
channel, err := model.GetChannelById(id, true)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
if channel.Type != common.ChannelTypeOpenAI {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "仅支持 OpenAI 类型渠道",
})
return
}
url := fmt.Sprintf("%s/v1/models", *channel.BaseURL)
body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key))
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
}
result := OpenAIModelsResponse{}
err = json.Unmarshal(body, &result)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
}
if !result.Success {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "上游返回错误",
})
}
var ids []string
for _, model := range result.Data {
ids = append(ids, model.ID)
}
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
"data": ids,
})
}
func FixChannelsAbilities(c *gin.Context) {
count, err := model.FixAbility()
if err != nil {

View File

@@ -200,18 +200,3 @@ func RetrieveModel(c *gin.Context) {
})
}
}
func GetPricing(c *gin.Context) {
userId := c.GetInt("id")
group, err := model.CacheGetUserGroup(userId)
groupRatio := common.GetGroupRatio("default")
if err != nil {
groupRatio = common.GetGroupRatio(group)
}
pricing := model.GetPricing(group)
c.JSON(200, gin.H{
"success": true,
"data": pricing,
"group_ratio": groupRatio,
})
}

47
controller/pricing.go Normal file
View File

@@ -0,0 +1,47 @@
package controller
import (
"github.com/gin-gonic/gin"
"one-api/common"
"one-api/model"
)
func GetPricing(c *gin.Context) {
userId := c.GetInt("id")
// if no login, get default group ratio
groupRatio := common.GetGroupRatio("default")
group, err := model.CacheGetUserGroup(userId)
if err == nil {
groupRatio = common.GetGroupRatio(group)
}
pricing := model.GetPricing(group)
c.JSON(200, gin.H{
"success": true,
"data": pricing,
"group_ratio": groupRatio,
})
}
func ResetModelRatio(c *gin.Context) {
defaultStr := common.DefaultModelRatio2JSONString()
err := model.UpdateOption("ModelRatio", defaultStr)
if err != nil {
c.JSON(200, gin.H{
"success": false,
"message": err.Error(),
})
return
}
err = common.UpdateModelRatioByJSONString(defaultStr)
if err != nil {
c.JSON(200, gin.H{
"success": false,
"message": err.Error(),
})
return
}
c.JSON(200, gin.H{
"success": true,
"message": "重置模型倍率成功",
})
}

View File

@@ -36,7 +36,7 @@ const (
)
func GetLogByKey(key string) (logs []*Log, err error) {
err = DB.Joins("left join tokens on tokens.id = logs.token_id").Where("tokens.key = ?", strings.Split(key, "-")[1]).Find(&logs).Error
err = DB.Joins("left join tokens on tokens.id = logs.token_id").Where("tokens.key = ?", strings.TrimPrefix(key, "sk-")).Find(&logs).Error
return logs, err
}

View File

@@ -370,7 +370,7 @@ func claudeHandler(requestMode int, c *gin.Context, resp *http.Response, promptT
}, nil
}
fullTextResponse := ResponseClaude2OpenAI(requestMode, &claudeResponse)
completionTokens, err, _ := service.CountTokenText(claudeResponse.Completion, model, false)
completionTokens, err := service.CountTokenText(claudeResponse.Completion, model)
if err != nil {
return service.OpenAIErrorWrapper(err, "count_token_text_failed", http.StatusInternalServerError), nil
}

View File

@@ -256,7 +256,7 @@ func geminiChatHandler(c *gin.Context, resp *http.Response, promptTokens int, mo
}, nil
}
fullTextResponse := responseGeminiChat2OpenAI(&geminiResponse)
completionTokens, _, _ := service.CountTokenText(geminiResponse.GetResponseText(), model, false)
completionTokens, _ := service.CountTokenText(geminiResponse.GetResponseText(), model)
usage := dto.Usage{
PromptTokens: promptTokens,
CompletionTokens: completionTokens,

View File

@@ -190,7 +190,7 @@ func OpenaiHandler(c *gin.Context, resp *http.Response, promptTokens int, model
if simpleResponse.Usage.TotalTokens == 0 {
completionTokens := 0
for _, choice := range simpleResponse.Choices {
ctkm, _, _ := service.CountTokenText(string(choice.Message.Content), model, false)
ctkm, _ := service.CountTokenText(string(choice.Message.Content), model)
completionTokens += ctkm
}
simpleResponse.Usage = dto.Usage{

View File

@@ -156,7 +156,7 @@ func palmHandler(c *gin.Context, resp *http.Response, promptTokens int, model st
}, nil
}
fullTextResponse := responsePaLM2OpenAI(&palmResponse)
completionTokens, _, _ := service.CountTokenText(palmResponse.Candidates[0].Content, model, false)
completionTokens, _ := service.CountTokenText(palmResponse.Candidates[0].Content, model)
usage := dto.Usage{
PromptTokens: promptTokens,
CompletionTokens: completionTokens,

View File

@@ -55,7 +55,13 @@ func AudioHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode {
promptTokens := 0
preConsumedTokens := common.PreConsumedQuota
if strings.HasPrefix(audioRequest.Model, "tts-1") {
promptTokens, err, _ = service.CountAudioToken(audioRequest.Input, audioRequest.Model, constant.ShouldCheckPromptSensitive())
if constant.ShouldCheckPromptSensitive() {
err = service.CheckSensitiveInput(audioRequest.Input)
if err != nil {
return service.OpenAIErrorWrapper(err, "sensitive_words_detected", http.StatusBadRequest)
}
}
promptTokens, err = service.CountAudioToken(audioRequest.Input, audioRequest.Model)
if err != nil {
return service.OpenAIErrorWrapper(err, "count_audio_token_failed", http.StatusInternalServerError)
}
@@ -178,7 +184,7 @@ func AudioHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode {
if strings.HasPrefix(audioRequest.Model, "tts-1") {
quota = promptTokens
} else {
quota, err, _ = service.CountAudioToken(audioResponse.Text, audioRequest.Model, false)
quota, err = service.CountAudioToken(audioResponse.Text, audioRequest.Model)
}
quota = int(float64(quota) * ratio)
if ratio != 0 && quota <= 0 {

View File

@@ -10,6 +10,7 @@ import (
"io"
"net/http"
"one-api/common"
"one-api/constant"
"one-api/dto"
"one-api/model"
relaycommon "one-api/relay/common"
@@ -47,6 +48,13 @@ func RelayImageHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusC
return service.OpenAIErrorWrapper(errors.New("prompt is required"), "required_field_missing", http.StatusBadRequest)
}
if constant.ShouldCheckPromptSensitive() {
err = service.CheckSensitiveInput(imageRequest.Prompt)
if err != nil {
return service.OpenAIErrorWrapper(err, "sensitive_words_detected", http.StatusBadRequest)
}
}
if strings.Contains(imageRequest.Size, "×") {
return service.OpenAIErrorWrapper(errors.New("size an unexpected error occurred in the parameter, please use 'x' instead of the multiplication sign '×'"), "invalid_field_value", http.StatusBadRequest)
}

View File

@@ -158,7 +158,7 @@ func RelaySwapFace(c *gin.Context) *dto.MidjourneyResponse {
modelPrice, success := common.GetModelPrice(modelName, true)
// 如果没有配置价格,则使用默认价格
if !success {
defaultPrice, ok := common.DefaultModelPrice[modelName]
defaultPrice, ok := common.GetDefaultModelRatioMap()[modelName]
if !ok {
modelPrice = 0.1
} else {
@@ -457,7 +457,7 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
modelPrice, success := common.GetModelPrice(modelName, true)
// 如果没有配置价格,则使用默认价格
if !success {
defaultPrice, ok := common.DefaultModelPrice[modelName]
defaultPrice, ok := common.GetDefaultModelRatioMap()[modelName]
if !ok {
modelPrice = 0.1
} else {

View File

@@ -98,13 +98,17 @@ func TextHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
var ratio float64
var modelRatio float64
//err := service.SensitiveWordsCheck(textRequest)
promptTokens, err, sensitiveTrigger := getPromptTokens(textRequest, relayInfo)
// count messages token error 计算promptTokens错误
if err != nil {
if sensitiveTrigger {
if constant.ShouldCheckPromptSensitive() {
err = checkRequestSensitive(textRequest, relayInfo)
if err != nil {
return service.OpenAIErrorWrapperLocal(err, "sensitive_words_detected", http.StatusBadRequest)
}
}
promptTokens, err := getPromptTokens(textRequest, relayInfo)
// count messages token error 计算promptTokens错误
if err != nil {
return service.OpenAIErrorWrapper(err, "count_token_messages_failed", http.StatusInternalServerError)
}
@@ -128,7 +132,7 @@ func TextHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
adaptor := GetAdaptor(relayInfo.ApiType)
if adaptor == nil {
return service.OpenAIErrorWrapper(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), "invalid_api_type", http.StatusBadRequest)
return service.OpenAIErrorWrapperLocal(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), "invalid_api_type", http.StatusBadRequest)
}
adaptor.Init(relayInfo, *textRequest)
var requestBody io.Reader
@@ -136,7 +140,7 @@ func TextHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
if isModelMapped {
jsonStr, err := json.Marshal(textRequest)
if err != nil {
return service.OpenAIErrorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
return service.OpenAIErrorWrapperLocal(err, "marshal_text_request_failed", http.StatusInternalServerError)
}
requestBody = bytes.NewBuffer(jsonStr)
} else {
@@ -145,11 +149,11 @@ func TextHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
} else {
convertedRequest, err := adaptor.ConvertRequest(c, relayInfo.RelayMode, textRequest)
if err != nil {
return service.OpenAIErrorWrapper(err, "convert_request_failed", http.StatusInternalServerError)
return service.OpenAIErrorWrapperLocal(err, "convert_request_failed", http.StatusInternalServerError)
}
jsonData, err := json.Marshal(convertedRequest)
if err != nil {
return service.OpenAIErrorWrapper(err, "json_marshal_failed", http.StatusInternalServerError)
return service.OpenAIErrorWrapperLocal(err, "json_marshal_failed", http.StatusInternalServerError)
}
requestBody = bytes.NewBuffer(jsonData)
}
@@ -182,15 +186,14 @@ func TextHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
return nil
}
func getPromptTokens(textRequest *dto.GeneralOpenAIRequest, info *relaycommon.RelayInfo) (int, error, bool) {
func getPromptTokens(textRequest *dto.GeneralOpenAIRequest, info *relaycommon.RelayInfo) (int, error) {
var promptTokens int
var err error
var sensitiveTrigger bool
checkSensitive := constant.ShouldCheckPromptSensitive()
switch info.RelayMode {
case relayconstant.RelayModeChatCompletions:
promptTokens, err, sensitiveTrigger = service.CountTokenChatRequest(*textRequest, textRequest.Model, checkSensitive)
promptTokens, err = service.CountTokenChatRequest(*textRequest, textRequest.Model)
case relayconstant.RelayModeCompletions:
promptTokens, err = service.CountTokenInput(textRequest.Prompt, textRequest.Model)
prompts := textRequest.Prompt
switch v := prompts.(type) {
case string:
@@ -199,17 +202,32 @@ func getPromptTokens(textRequest *dto.GeneralOpenAIRequest, info *relaycommon.Re
prompts = append(v, textRequest.Suffix)
}
promptTokens, err, sensitiveTrigger = service.CountTokenInput(prompts, textRequest.Model, checkSensitive)
promptTokens, err = service.CountTokenInput(prompts, textRequest.Model)
case relayconstant.RelayModeModerations:
promptTokens, err, sensitiveTrigger = service.CountTokenInput(textRequest.Input, textRequest.Model, checkSensitive)
promptTokens, err = service.CountTokenInput(textRequest.Input, textRequest.Model)
case relayconstant.RelayModeEmbeddings:
promptTokens, err, sensitiveTrigger = service.CountTokenInput(textRequest.Input, textRequest.Model, checkSensitive)
promptTokens, err = service.CountTokenInput(textRequest.Input, textRequest.Model)
default:
err = errors.New("unknown relay mode")
promptTokens = 0
}
info.PromptTokens = promptTokens
return promptTokens, err, sensitiveTrigger
return promptTokens, err
}
func checkRequestSensitive(textRequest *dto.GeneralOpenAIRequest, info *relaycommon.RelayInfo) error {
var err error
switch info.RelayMode {
case relayconstant.RelayModeChatCompletions:
err = service.CheckSensitiveMessages(textRequest.Messages)
case relayconstant.RelayModeCompletions:
err = service.CheckSensitiveInput(textRequest.Prompt)
case relayconstant.RelayModeModerations:
err = service.CheckSensitiveInput(textRequest.Input)
case relayconstant.RelayModeEmbeddings:
err = service.CheckSensitiveInput(textRequest.Input)
}
return err
}
// 预扣费并返回用户剩余配额

View File

@@ -74,6 +74,7 @@ func SetApiRouter(router *gin.Engine) {
{
optionRoute.GET("/", controller.GetOptions)
optionRoute.PUT("/", controller.UpdateOption)
optionRoute.POST("/rest_model_ratio", controller.ResetModelRatio)
}
channelRoute := apiRouter.Group("/channel")
channelRoute.Use(middleware.AdminAuth())
@@ -92,6 +93,8 @@ func SetApiRouter(router *gin.Engine) {
channelRoute.DELETE("/:id", controller.DeleteChannel)
channelRoute.POST("/batch", controller.DeleteChannelBatch)
channelRoute.POST("/fix", controller.FixChannelsAbilities)
channelRoute.GET("/fetch_models/:id", controller.FetchUpstreamModels)
}
tokenRoute := apiRouter.Group("/token")
tokenRoute.Use(middleware.UserAuth())

View File

@@ -1,13 +1,60 @@
package service
import (
"bytes"
"errors"
"fmt"
"github.com/anknown/ahocorasick"
"one-api/common"
"one-api/constant"
"one-api/dto"
"strings"
)
func CheckSensitiveMessages(messages []dto.Message) error {
for _, message := range messages {
if len(message.Content) > 0 {
if message.IsStringContent() {
stringContent := message.StringContent()
if ok, words := SensitiveWordContains(stringContent); ok {
return errors.New("sensitive words: " + strings.Join(words, ","))
}
}
} else {
arrayContent := message.ParseContent()
for _, m := range arrayContent {
if m.Type == "image_url" {
// TODO: check image url
} else {
if ok, words := SensitiveWordContains(m.Text); ok {
return errors.New("sensitive words: " + strings.Join(words, ","))
}
}
}
}
}
return nil
}
func CheckSensitiveText(text string) error {
if ok, words := SensitiveWordContains(text); ok {
return errors.New("sensitive words: " + strings.Join(words, ","))
}
return nil
}
func CheckSensitiveInput(input any) error {
switch v := input.(type) {
case string:
return CheckSensitiveText(v)
case []string:
text := ""
for _, s := range v {
text += s
}
return CheckSensitiveText(text)
}
return CheckSensitiveText(fmt.Sprintf("%v", input))
}
// SensitiveWordContains 是否包含敏感词,返回是否包含敏感词和敏感词列表
func SensitiveWordContains(text string) (bool, []string) {
if len(constant.SensitiveWords) == 0 {
@@ -15,7 +62,7 @@ func SensitiveWordContains(text string) (bool, []string) {
}
checkText := strings.ToLower(text)
// 构建一个AC自动机
m := initAc()
m := common.InitAc()
hits := m.MultiPatternSearch([]rune(checkText), false)
if len(hits) > 0 {
words := make([]string, 0)
@@ -33,7 +80,7 @@ func SensitiveWordReplace(text string, returnImmediately bool) (bool, []string,
return false, nil, text
}
checkText := strings.ToLower(text)
m := initAc()
m := common.InitAc()
hits := m.MultiPatternSearch([]rune(checkText), returnImmediately)
if len(hits) > 0 {
words := make([]string, 0)
@@ -47,25 +94,3 @@ func SensitiveWordReplace(text string, returnImmediately bool) (bool, []string,
}
return false, nil, text
}
func initAc() *goahocorasick.Machine {
m := new(goahocorasick.Machine)
dict := readRunes()
if err := m.Build(dict); err != nil {
fmt.Println(err)
return nil
}
return m
}
func readRunes() [][]rune {
var dict [][]rune
for _, word := range constant.SensitiveWords {
word = strings.ToLower(word)
l := bytes.TrimSpace([]byte(word))
dict = append(dict, bytes.Runes(l))
}
return dict
}

View File

@@ -34,7 +34,7 @@ func InitTokenEncoders() {
if err != nil {
common.FatalLog(fmt.Sprintf("failed to get gpt-4o token encoder: %s", err.Error()))
}
for model, _ := range common.DefaultModelRatio {
for model, _ := range common.GetDefaultModelRatioMap() {
if strings.HasPrefix(model, "gpt-3.5") {
tokenEncoderMap[model] = gpt35TokenEncoder
} else if strings.HasPrefix(model, "gpt-4o") {
@@ -127,11 +127,11 @@ func getImageToken(imageUrl *dto.MessageImageUrl, model string, stream bool) (in
return tiles*170 + 85, nil
}
func CountTokenChatRequest(request dto.GeneralOpenAIRequest, model string, checkSensitive bool) (int, error, bool) {
func CountTokenChatRequest(request dto.GeneralOpenAIRequest, model string) (int, error) {
tkm := 0
msgTokens, err, b := CountTokenMessages(request.Messages, model, request.Stream, checkSensitive)
msgTokens, err := CountTokenMessages(request.Messages, model, request.Stream)
if err != nil {
return 0, err, b
return 0, err
}
tkm += msgTokens
if request.Tools != nil {
@@ -139,7 +139,7 @@ func CountTokenChatRequest(request dto.GeneralOpenAIRequest, model string, check
var openaiTools []dto.OpenAITools
err := json.Unmarshal(toolsData, &openaiTools)
if err != nil {
return 0, errors.New(fmt.Sprintf("count_tools_token_fail: %s", err.Error())), false
return 0, errors.New(fmt.Sprintf("count_tools_token_fail: %s", err.Error()))
}
countStr := ""
for _, tool := range openaiTools {
@@ -151,18 +151,18 @@ func CountTokenChatRequest(request dto.GeneralOpenAIRequest, model string, check
countStr += fmt.Sprintf("%v", tool.Function.Parameters)
}
}
toolTokens, err, _ := CountTokenInput(countStr, model, false)
toolTokens, err := CountTokenInput(countStr, model)
if err != nil {
return 0, err, false
return 0, err
}
tkm += 8
tkm += toolTokens
}
return tkm, nil, false
return tkm, nil
}
func CountTokenMessages(messages []dto.Message, model string, stream bool, checkSensitive bool) (int, error, bool) {
func CountTokenMessages(messages []dto.Message, model string, stream bool) (int, error) {
//recover when panic
tokenEncoder := getTokenEncoder(model)
// Reference:
@@ -186,13 +186,6 @@ func CountTokenMessages(messages []dto.Message, model string, stream bool, check
if len(message.Content) > 0 {
if message.IsStringContent() {
stringContent := message.StringContent()
if checkSensitive {
contains, words := SensitiveWordContains(stringContent)
if contains {
err := fmt.Errorf("message contains sensitive words: [%s]", strings.Join(words, ", "))
return 0, err, true
}
}
tokenNum += getTokenNum(tokenEncoder, stringContent)
if message.Name != nil {
tokenNum += tokensPerName
@@ -205,7 +198,7 @@ func CountTokenMessages(messages []dto.Message, model string, stream bool, check
imageUrl := m.ImageUrl.(dto.MessageImageUrl)
imageTokenNum, err := getImageToken(&imageUrl, model, stream)
if err != nil {
return 0, err, false
return 0, err
}
tokenNum += imageTokenNum
log.Printf("image token num: %d", imageTokenNum)
@@ -217,33 +210,33 @@ func CountTokenMessages(messages []dto.Message, model string, stream bool, check
}
}
tokenNum += 3 // Every reply is primed with <|start|>assistant<|message|>
return tokenNum, nil, false
return tokenNum, nil
}
func CountTokenInput(input any, model string, check bool) (int, error, bool) {
func CountTokenInput(input any, model string) (int, error) {
switch v := input.(type) {
case string:
return CountTokenText(v, model, check)
return CountTokenText(v, model)
case []string:
text := ""
for _, s := range v {
text += s
}
return CountTokenText(text, model, check)
return CountTokenText(text, model)
}
return CountTokenInput(fmt.Sprintf("%v", input), model, check)
return CountTokenInput(fmt.Sprintf("%v", input), model)
}
func CountTokenStreamChoices(messages []dto.ChatCompletionsStreamResponseChoice, model string) int {
tokens := 0
for _, message := range messages {
tkm, _, _ := CountTokenInput(message.Delta.GetContentString(), model, false)
tkm, _ := CountTokenInput(message.Delta.GetContentString(), model)
tokens += tkm
if message.Delta.ToolCalls != nil {
for _, tool := range message.Delta.ToolCalls {
tkm, _, _ := CountTokenInput(tool.Function.Name, model, false)
tkm, _ := CountTokenInput(tool.Function.Name, model)
tokens += tkm
tkm, _, _ = CountTokenInput(tool.Function.Arguments, model, false)
tkm, _ = CountTokenInput(tool.Function.Arguments, model)
tokens += tkm
}
}
@@ -251,29 +244,17 @@ func CountTokenStreamChoices(messages []dto.ChatCompletionsStreamResponseChoice,
return tokens
}
func CountAudioToken(text string, model string, check bool) (int, error, bool) {
func CountAudioToken(text string, model string) (int, error) {
if strings.HasPrefix(model, "tts") {
contains, words := SensitiveWordContains(text)
if contains {
return utf8.RuneCountInString(text), fmt.Errorf("input contains sensitive words: [%s]", strings.Join(words, ",")), true
}
return utf8.RuneCountInString(text), nil, false
return utf8.RuneCountInString(text), nil
} else {
return CountTokenText(text, model, check)
return CountTokenText(text, model)
}
}
// CountTokenText 统计文本的token数量仅当文本包含敏感词返回错误同时返回token数量
func CountTokenText(text string, model string, check bool) (int, error, bool) {
func CountTokenText(text string, model string) (int, error) {
var err error
var trigger bool
if check {
contains, words := SensitiveWordContains(text)
if contains {
err = fmt.Errorf("input contains sensitive words: [%s]", strings.Join(words, ","))
trigger = true
}
}
tokenEncoder := getTokenEncoder(model)
return getTokenNum(tokenEncoder, text), err, trigger
return getTokenNum(tokenEncoder, text), err
}

View File

@@ -19,7 +19,7 @@ import (
func ResponseText2Usage(responseText string, modeName string, promptTokens int) (*dto.Usage, error) {
usage := &dto.Usage{}
usage.PromptTokens = promptTokens
ctkm, err, _ := CountTokenText(responseText, modeName, false)
ctkm, err := CountTokenText(responseText, modeName)
usage.CompletionTokens = ctkm
usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
return usage, err

3510
web/pnpm-lock.yaml generated Normal file

File diff suppressed because it is too large Load Diff

BIN
web/public/ratio.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 140 KiB

View File

@@ -1,7 +1,7 @@
import React, { useEffect, useState } from 'react';
import { getFooterHTML, getSystemName } from '../helpers';
import { Layout } from '@douyinfe/semi-ui';
import { Layout, Tooltip } from '@douyinfe/semi-ui';
const Footer = () => {
const systemName = getSystemName();
@@ -15,6 +15,30 @@ const Footer = () => {
}
};
const defaultFooter = (
<div className='custom-footer'>
<a
href='https://github.com/Calcium-Ion/new-api'
target='_blank'
rel='noreferrer'
>
New API {import.meta.env.VITE_REACT_APP_VERSION}{' '}
</a>
{' '}
<a href='https://github.com/Calcium-Ion' target='_blank' rel='noreferrer'>
Calcium-Ion
</a>{' '}
开发基于{' '}
<a
href='https://github.com/songquanpeng/one-api'
target='_blank'
rel='noreferrer'
>
One API
</a>
</div>
);
useEffect(() => {
const timer = setInterval(() => {
if (remainCheckTimes <= 0) {
@@ -31,41 +55,14 @@ const Footer = () => {
<Layout>
<Layout.Content style={{ textAlign: 'center' }}>
{footer ? (
<div
className='custom-footer'
dangerouslySetInnerHTML={{ __html: footer }}
></div>
<Tooltip content={defaultFooter}>
<div
className='custom-footer'
dangerouslySetInnerHTML={{ __html: footer }}
></div>
</Tooltip>
) : (
<div className='custom-footer'>
<a
href='https://github.com/Calcium-Ion/new-api'
target='_blank'
rel='noreferrer'
>
New API {import.meta.env.VITE_REACT_APP_VERSION}{' '}
</a>
{' '}
<a
href='https://github.com/Calcium-Ion'
target='_blank'
rel='noreferrer'
>
Calcium-Ion
</a>{' '}
开发基于{' '}
<a
href='https://github.com/songquanpeng/one-api'
target='_blank'
rel='noreferrer'
>
One API v0.5.4
</a>{' '}
本项目根据{' '}
<a href='https://opensource.org/licenses/mit-license.php'>
MIT 许可证
</a>{' '}
授权
</div>
defaultFooter
)}
</Layout.Content>
</Layout>

View File

@@ -1,4 +1,4 @@
import React, { useContext, useEffect, useRef, useState } from 'react';
import React, { useContext, useEffect, useRef, useMemo, useState } from 'react';
import { API, copy, showError, showSuccess } from '../helpers';
import {
@@ -10,8 +10,16 @@ import {
Table,
Tag,
Tooltip,
Popover,
ImagePreview,
Button,
} from '@douyinfe/semi-ui';
import { stringToColor } from '../helpers/render.js';
import {
IconMore,
IconVerify,
IconUploadError,
IconHelpCircle,
} from '@douyinfe/semi-icons';
import { UserContext } from '../context/User/index.js';
import Text from '@douyinfe/semi-ui/lib/es/typography/text';
@@ -20,42 +28,70 @@ function renderQuotaType(type) {
switch (type) {
case 1:
return (
<Tag color='green' size='large'>
<Tag color='teal' size='large'>
按次计费
</Tag>
);
case 0:
return (
<Tag color='blue' size='large'>
<Tag color='violet' size='large'>
按量计费
</Tag>
);
default:
return (
<Tag color='white' size='large'>
未知
</Tag>
);
return '未知';
}
}
function renderAvailable(available) {
return available ? (
<Tag color='green' size='large'>
可用
</Tag>
<Popover
content={<div style={{ padding: 8 }}>您的分组可以使用该模型</div>}
position='top'
key={available}
style={{
backgroundColor: 'rgba(var(--semi-blue-4),1)',
borderColor: 'rgba(var(--semi-blue-4),1)',
color: 'var(--semi-color-white)',
borderWidth: 1,
borderStyle: 'solid',
}}
>
<IconVerify style={{ color: 'green' }} size='large' />
</Popover>
) : (
<Tooltip content='您所在的分组不可用'>
<Tag color='red' size='large'>
不可用
</Tag>
</Tooltip>
<Popover
content={<div style={{ padding: 8 }}>您的分组无权使用该模型</div>}
position='top'
key={available}
style={{
backgroundColor: 'rgba(var(--semi-blue-4),1)',
borderColor: 'rgba(var(--semi-blue-4),1)',
color: 'var(--semi-color-white)',
borderWidth: 1,
borderStyle: 'solid',
}}
>
<IconUploadError style={{ color: '#FFA54F' }} size='large' />
</Popover>
);
}
const ModelPricing = () => {
const [filteredValue, setFilteredValue] = useState([]);
const compositionRef = useRef({ isComposition: false });
const [selectedRowKeys, setSelectedRowKeys] = useState([]);
const [modalImageUrl, setModalImageUrl] = useState('');
const [isModalOpenurl, setIsModalOpenurl] = useState(false);
const rowSelection = useMemo(
() => ({
onChange: (selectedRowKeys, selectedRows) => {
setSelectedRowKeys(selectedRowKeys);
},
}),
[],
);
const handleChange = (value) => {
if (compositionRef.current.isComposition) {
@@ -103,7 +139,7 @@ const ModelPricing = () => {
return (
<>
<Tag
color={stringToColor(text)}
color='green'
size='large'
onClick={() => {
copyText(text);
@@ -114,7 +150,8 @@ const ModelPricing = () => {
</>
);
},
onFilter: (value, record) => record.model_name.includes(value),
onFilter: (value, record) =>
record.model_name.toLowerCase().includes(value.toLowerCase()),
filteredValue,
},
{
@@ -126,18 +163,49 @@ const ModelPricing = () => {
sorter: (a, b) => a.quota_type - b.quota_type,
},
{
title: '模型倍率',
title: () => (
<span style={{ display: 'flex', alignItems: 'center' }}>
倍率
<Popover
content={
<div style={{ padding: 8 }}>
倍率是为了方便换算不同价格的模型
<br />
点击查看倍率说明
</div>
}
position='top'
style={{
backgroundColor: 'rgba(var(--semi-blue-4),1)',
borderColor: 'rgba(var(--semi-blue-4),1)',
color: 'var(--semi-color-white)',
borderWidth: 1,
borderStyle: 'solid',
}}
>
<IconHelpCircle
onClick={() => {
setModalImageUrl('/ratio.png');
setIsModalOpenurl(true);
}}
/>
</Popover>
</span>
),
dataIndex: 'model_ratio',
render: (text, record, index) => {
return <div>{record.quota_type === 0 ? text : 'N/A'}</div>;
},
},
{
title: '补全倍率',
dataIndex: 'completion_ratio',
render: (text, record, index) => {
let ratio = parseFloat(text.toFixed(3));
return <div>{record.quota_type === 0 ? ratio : 'N/A'}</div>;
let content = text;
let completionRatio = parseFloat(record.completion_ratio.toFixed(3));
content = (
<>
<Text>模型{record.quota_type === 0 ? text : '无'}</Text>
<br />
<Text>
补全{record.quota_type === 0 ? completionRatio : '无'}
</Text>
</>
);
return <div>{content}</div>;
},
},
{
@@ -176,7 +244,7 @@ const ModelPricing = () => {
const setModelsFormat = (models, groupRatio) => {
for (let i = 0; i < models.length; i++) {
models[i].key = i;
models[i].key = models[i].model_name;
models[i].group_ratio = groupRatio;
}
// sort by quota_type
@@ -239,15 +307,43 @@ const ModelPricing = () => {
<Layout>
{userState.user ? (
<Banner
type='info'
type='success'
fullMode={false}
closeIcon='null'
description={`您的分组为:${userState.user.group},分组倍率为:${groupRatio}`}
/>
) : (
<Banner
type='warning'
fullMode={false}
closeIcon='null'
description={`您还未登陆,显示的价格为默认分组倍率: ${groupRatio}`}
/>
)}
<br />
<Banner
type='info'
fullMode={false}
description={
<div>
按量计费费用 = 分组倍率 × 模型倍率 × 提示token数 + 补全token数 ×
补全倍率/ 500000 单位美元
</div>
}
closeIcon='null'
/>
<br />
<Button
theme='light'
type='tertiary'
style={{ width: 150 }}
onClick={() => {
copyText(selectedRowKeys);
}}
disabled={selectedRowKeys == ''}
>
复制选中模型
</Button>
<Table
style={{ marginTop: 5 }}
columns={columns}
@@ -257,6 +353,12 @@ const ModelPricing = () => {
pageSize: models.length,
showSizeChanger: false,
}}
rowSelection={rowSelection}
/>
<ImagePreview
src={modalImageUrl}
visible={isModalOpenurl}
onVisibleChange={(visible) => setIsModalOpenurl(visible)}
/>
</Layout>
</>

View File

@@ -216,6 +216,15 @@ export const verifyJSON = (str) => {
return true;
};
export function verifyJSONPromise(value) {
try {
JSON.parse(value);
return Promise.resolve();
} catch (e) {
return Promise.reject('不是合法的 JSON 字符串');
}
}
export function shouldShowPrompt(id) {
let prompt = localStorage.getItem(`prompt-${id}`);
return !prompt;

View File

@@ -39,13 +39,16 @@ const About = () => {
</Layout.Header>
<Layout.Content>
<p>可在设置页面设置关于内容支持 HTML & Markdown</p>
new-api项目仓库地址
New-API项目仓库地址
<a href='https://github.com/Calcium-Ion/new-api'>
https://github.com/Calcium-Ion/new-api
</a>
<p>
NewAPI © 2023 CalciumIon | 基于 One API v0.5.4 © 2023
JustSong本项目根据MIT许可证授权
JustSong
</p>
<p>
本项目根据MIT许可证授权需在遵守Apache-2.0协议的前提下使用
</p>
</Layout.Content>
</Layout>

View File

@@ -15,6 +15,7 @@ import {
Space,
Spin,
Button,
Tooltip,
Input,
Typography,
Select,
@@ -24,6 +25,7 @@ import {
} from '@douyinfe/semi-ui';
import { Divider } from 'semantic-ui-react';
import { getChannelModels, loadChannelModels } from '../../components/utils.js';
import axios from 'axios';
const MODEL_MAPPING_EXAMPLE = {
'gpt-3.5-turbo-0301': 'gpt-3.5-turbo',
@@ -35,6 +37,9 @@ const STATUS_CODE_MAPPING_EXAMPLE = {
400: '500',
};
const fetchButtonTips =
'1. 新建渠道时请求通过当前浏览器发出2. 编辑已有渠道,请求通过后端服务器发出';
function type2secretPrompt(type) {
// inputs.type === 15 ? '按照如下格式输入APIKey|SecretKey' : (inputs.type === 18 ? '按照如下格式输入APPID|APISecret|APIKey' : '请输入渠道对应的鉴权密钥')
switch (type) {
@@ -173,12 +178,58 @@ const EditChannel = (props) => {
setLoading(false);
};
const fetchUpstreamModelList = async (name) => {
if (inputs['type'] !== 1) {
showError('仅支持 OpenAI 接口格式');
return;
}
setLoading(true);
const models = inputs['models'] || [];
let err = false;
if (isEdit) {
const res = await API.get('/api/channel/fetch_models/' + channelId);
if (res.data && res.data?.success) {
models.push(...res.data.data);
} else {
err = true;
}
} else {
if (!inputs?.['key']) {
showError('请填写密钥');
err = true;
} else {
try {
const host = new URL(inputs['base_url'] || 'https://api.openai.com');
const url = `https://${host.hostname}/v1/models`;
const key = inputs['key'];
const res = await axios.get(url, {
headers: {
Authorization: `Bearer ${key}`,
},
});
if (res.data && res.data?.success) {
models.push(...es.data.data.map((model) => model.id));
} else {
err = true;
}
} catch (error) {
err = true;
}
}
}
if (!err) {
handleInputChange(name, Array.from(new Set(models)));
showSuccess('获取模型列表成功');
} else {
showError('获取模型列表失败');
}
setLoading(false);
};
const fetchModels = async () => {
try {
let res = await API.get(`/api/channel/models`);
if (res === undefined) {
return;
}
let localModelOptions = res.data.data.map((model) => ({
label: model.id,
value: model.id,
@@ -432,6 +483,17 @@ const EditChannel = (props) => {
)}
{inputs.type === 8 && (
<>
<div style={{ marginTop: 10 }}>
<Banner
type={'warning'}
description={
<>
如果你对接的是上游One API或者New
API等转发项目请使用OpenAI类型不要使用此类型除非你知道你在做什么
</>
}
></Banner>
</div>
<div style={{ marginTop: 10 }}>
<Typography.Text strong>
完整的 Base URL支持变量{'{model}'}
@@ -550,6 +612,16 @@ const EditChannel = (props) => {
>
填入所有模型
</Button>
<Tooltip content={fetchButtonTips}>
<Button
type='tertiary'
onClick={() => {
fetchUpstreamModelList('models');
}}
>
获取模型列表
</Button>
</Tooltip>
<Button
type='warning'
onClick={() => {

View File

@@ -86,11 +86,21 @@ const Home = () => {
<p>
源码
<a
href='https://github.com/songquanpeng/one-api'
href='https://github.com/Calcium-Ion/new-api'
target='_blank'
rel='noreferrer'
>
https://github.com/songquanpeng/one-api
https://github.com/Calcium-Ion/new-api
</a>
</p>
<p>
协议
<a
href='https://www.apache.org/licenses/LICENSE-2.0'
target='_blank'
rel='noreferrer'
>
Apache-2.0 License
</a>
</p>
<p>启动时间{getStartTimeString()}</p>

View File

@@ -1,5 +1,13 @@
import React, { useEffect, useState, useRef } from 'react';
import { Button, Col, Form, Row, Spin } from '@douyinfe/semi-ui';
import {
Button,
Col,
Form,
Popconfirm,
Row,
Space,
Spin,
} from '@douyinfe/semi-ui';
import {
compareObjects,
API,
@@ -7,6 +15,7 @@ import {
showSuccess,
showWarning,
verifyJSON,
verifyJSONPromise,
} from '../../../helpers';
export default function SettingsMagnification(props) {
@@ -22,43 +31,66 @@ export default function SettingsMagnification(props) {
async function onSubmit() {
try {
await refForm.current.validate();
const updateArray = compareObjects(inputs, inputsRow);
if (!updateArray.length) return showWarning('你似乎并没有修改什么');
const requestQueue = updateArray.map((item) => {
let value = '';
if (typeof inputs[item.key] === 'boolean') {
value = String(inputs[item.key]);
} else {
value = inputs[item.key];
}
return API.put('/api/option/', {
key: item.key,
value,
});
});
setLoading(true);
Promise.all(requestQueue)
.then((res) => {
if (requestQueue.length === 1) {
if (res.includes(undefined)) return;
} else if (requestQueue.length > 1) {
if (res.includes(undefined))
return showError('部分保存失败,请重试');
}
showSuccess('保存成功');
props.refresh();
console.log('Starting validation...');
await refForm.current
.validate()
.then(() => {
console.log('Validation passed');
const updateArray = compareObjects(inputs, inputsRow);
if (!updateArray.length) return showWarning('你似乎并没有修改什么');
const requestQueue = updateArray.map((item) => {
let value = '';
if (typeof inputs[item.key] === 'boolean') {
value = String(inputs[item.key]);
} else {
value = inputs[item.key];
}
return API.put('/api/option/', {
key: item.key,
value,
});
});
setLoading(true);
Promise.all(requestQueue)
.then((res) => {
if (requestQueue.length === 1) {
if (res.includes(undefined)) return;
} else if (requestQueue.length > 1) {
if (res.includes(undefined))
return showError('部分保存失败,请重试');
}
showSuccess('保存成功');
props.refresh();
})
.catch(() => {
showError('保存失败,请重试');
})
.finally(() => {
setLoading(false);
});
})
.catch(() => {
showError('保存失败,请重试');
})
.finally(() => {
setLoading(false);
.catch((error) => {
console.error('Validation failed:', error);
showError('请检查输入');
});
} catch (error) {
showError('请检查输入');
console.error(error);
} finally {
}
}
async function resetModelRatio() {
try {
let res = await API.post(`/api/option/rest_model_ratio`);
// return {success, message}
if (res.data.success) {
showSuccess(res.data.message);
props.refresh();
} else {
showError(res.data.message);
}
} catch (error) {
showError(error);
}
}
@@ -73,122 +105,141 @@ export default function SettingsMagnification(props) {
setInputsRow(structuredClone(currentInputs));
refForm.current.setValues(currentInputs);
}, [props.options]);
return (
<>
<Spin spinning={loading}>
<Form
values={inputs}
getFormApi={(formAPI) => (refForm.current = formAPI)}
style={{ marginBottom: 15 }}
>
<Form.Section text={'倍率设置'}>
<Row gutter={16}>
<Col span={16}>
<Form.TextArea
label={'模型固定价格'}
extraText={'一次调用消耗多少刀,优先级大于模型倍率'}
placeholder={
'为一个 JSON 文本,键为模型名称,值为一次调用消耗多少刀,比如 "gpt-4-gizmo-*": 0.1一次消耗0.1刀'
}
field={'ModelPrice'}
autosize={{ minRows: 6, maxRows: 12 }}
trigger='blur'
rules={[
{
validator: (rule, value) => verifyJSON(value),
message: '不是合法的 JSON 字符串',
},
]}
onChange={(value) =>
setInputs({
...inputs,
ModelPrice: value,
})
}
/>
</Col>
</Row>
<Row gutter={16}>
<Col span={16}>
<Form.TextArea
label={'模型倍率'}
extraText={''}
placeholder={'为一个 JSON 文本,键为模型名称,值为倍率'}
field={'ModelRatio'}
autosize={{ minRows: 6, maxRows: 12 }}
trigger='blur'
rules={[
{
validator: (rule, value) => verifyJSON(value),
message: '不是合法的 JSON 字符串',
},
]}
onChange={(value) =>
setInputs({
...inputs,
ModelRatio: value,
})
}
/>
</Col>
</Row>
<Row gutter={16}>
<Col span={16}>
<Form.TextArea
label={'模型补全倍率(仅对自定义模型有效)'}
extraText={'仅对自定义模型有效'}
placeholder={'为一个 JSON 文本,键为模型名称,值为倍率'}
field={'CompletionRatio'}
autosize={{ minRows: 6, maxRows: 12 }}
trigger='blur'
rules={[
{
validator: (rule, value) => verifyJSON(value),
message: '不是合法的 JSON 字符串',
},
]}
onChange={(value) =>
setInputs({
...inputs,
CompletionRatio: value,
})
}
/>
</Col>
</Row>
<Row gutter={16}>
<Col span={16}>
<Form.TextArea
label={'分组倍率'}
extraText={''}
placeholder={'为一个 JSON 文本,键为分组名称,值为倍率'}
field={'GroupRatio'}
autosize={{ minRows: 6, maxRows: 12 }}
trigger='blur'
rules={[
{
validator: (rule, value) => verifyJSON(value),
message: '不是合法的 JSON 字符串',
},
]}
onChange={(value) =>
setInputs({
...inputs,
GroupRatio: value,
})
}
/>
</Col>
</Row>
<Row>
<Button size='large' onClick={onSubmit}>
保存倍率设置
</Button>
</Row>
</Form.Section>
</Form>
</Spin>
</>
return (
<Spin spinning={loading}>
<Form
values={inputs}
getFormApi={(formAPI) => (refForm.current = formAPI)}
style={{ marginBottom: 15 }}
>
<Form.Section text={'倍率设置'}>
<Row gutter={16}>
<Col span={16}>
<Form.TextArea
label={'模型固定价格'}
extraText={'一次调用消耗多少刀,优先级大于模型倍率'}
placeholder={
'为一个 JSON 文本,键为模型名称,值为一次调用消耗多少刀,比如 "gpt-4-gizmo-*": 0.1一次消耗0.1刀'
}
field={'ModelPrice'}
autosize={{ minRows: 6, maxRows: 12 }}
trigger='blur'
stopValidateWithError
rules={[
{
validator: (rule, value) => {
return verifyJSON(value);
},
message: '不是合法的 JSON 字符串',
},
]}
onChange={(value) =>
setInputs({
...inputs,
ModelPrice: value,
})
}
/>
</Col>
</Row>
<Row gutter={16}>
<Col span={16}>
<Form.TextArea
label={'模型倍率'}
extraText={''}
placeholder={'为一个 JSON 文本,键为模型名称,值为倍率'}
field={'ModelRatio'}
autosize={{ minRows: 6, maxRows: 12 }}
trigger='blur'
stopValidateWithError
rules={[
{
validator: (rule, value) => {
return verifyJSON(value);
},
message: '不是合法的 JSON 字符串',
},
]}
onChange={(value) =>
setInputs({
...inputs,
ModelRatio: value,
})
}
/>
</Col>
</Row>
<Row gutter={16}>
<Col span={16}>
<Form.TextArea
label={'模型补全倍率(仅对自定义模型有效)'}
extraText={'仅对自定义模型有效'}
placeholder={'为一个 JSON 文本,键为模型名称,值为倍率'}
field={'CompletionRatio'}
autosize={{ minRows: 6, maxRows: 12 }}
trigger='blur'
stopValidateWithError
rules={[
{
validator: (rule, value) => {
return verifyJSON(value);
},
message: '不是合法的 JSON 字符串',
},
]}
onChange={(value) =>
setInputs({
...inputs,
CompletionRatio: value,
})
}
/>
</Col>
</Row>
<Row gutter={16}>
<Col span={16}>
<Form.TextArea
label={'分组倍率'}
extraText={''}
placeholder={'为一个 JSON 文本,键为分组名称,值为倍率'}
field={'GroupRatio'}
autosize={{ minRows: 6, maxRows: 12 }}
trigger='blur'
stopValidateWithError
rules={[
{
validator: (rule, value) => {
return verifyJSON(value);
},
message: '不是合法的 JSON 字符串',
},
]}
onChange={(value) =>
setInputs({
...inputs,
GroupRatio: value,
})
}
/>
</Col>
</Row>
</Form.Section>
</Form>
<Space>
<Button onClick={onSubmit}>保存倍率设置</Button>
<Popconfirm
title='确定重置模型倍率吗?'
content='此修改将不可逆'
okType={'danger'}
position={'top'}
onConfirm={() => {
resetModelRatio();
}}
>
<Button type={'danger'}>重置模型倍率</Button>
</Popconfirm>
</Space>
</Spin>
);
}