Compare commits

...

10 Commits

Author SHA1 Message Date
JustSong
7201bd1c97 fix: update api2d's base url (#83) 2023-05-17 18:47:25 +08:00
JustSong
73d5e0f283 feat: support dummy sk- prefix for token (#82) 2023-05-17 17:04:06 +08:00
JustSong
efc744ca35 feat: API /models & /models/:model implemented (close #68) 2023-05-17 10:42:52 +08:00
JustSong
e8da98139f fix: limit the shown text's length (close #80) 2023-05-16 21:33:59 +08:00
JustSong
519cb030f7 chore: update input label 2023-05-16 16:23:07 +08:00
JustSong
58fe923c85 perf: use max_tokens to reduce token consuming 2023-05-16 16:22:25 +08:00
JustSong
c9ac5e391f feat: support max_tokens now (#52) 2023-05-16 16:18:35 +08:00
JustSong
69cf1de7bd feat: disable operations for root user (close #76) 2023-05-16 15:38:03 +08:00
JustSong
4d6172a242 feat: able to set pre consumed quota now 2023-05-16 13:57:01 +08:00
JustSong
8afdc56b11 fix: fix quota not consuming 2023-05-16 13:29:22 +08:00
12 changed files with 292 additions and 37 deletions

View File

@@ -54,6 +54,7 @@ var QuotaForNewUser = 0
var ChannelDisableThreshold = 5.0
var AutomaticDisableChannelEnabled = false
var QuotaRemindThreshold = 1000
var PreConsumedQuota = 500
var RootUserEmail = ""
@@ -131,7 +132,7 @@ const (
var ChannelBaseURLs = []string{
"", // 0
"https://api.openai.com", // 1
"https://openai.api2d.net", // 2
"https://oa.api2d.net", // 2
"", // 3
"https://api.openai-asia.com", // 4
"https://api.openai-sb.com", // 5

View File

@@ -210,11 +210,12 @@ func testChannel(channel *model.Channel, request *ChatRequest) error {
func buildTestRequest(c *gin.Context) *ChatRequest {
model_ := c.Query("model")
testRequest := &ChatRequest{
Model: model_,
Model: model_,
MaxTokens: 1,
}
testMessage := Message{
Role: "user",
Content: "echo hi",
Content: "hi",
}
testRequest.Messages = append(testRequest.Messages, testMessage)
return testRequest

153
controller/model.go Normal file
View File

@@ -0,0 +1,153 @@
package controller
import (
"fmt"
"github.com/gin-gonic/gin"
)
// https://platform.openai.com/docs/api-reference/models/list
type OpenAIModelPermission struct {
Id string `json:"id"`
Object string `json:"object"`
Created int `json:"created"`
AllowCreateEngine bool `json:"allow_create_engine"`
AllowSampling bool `json:"allow_sampling"`
AllowLogprobs bool `json:"allow_logprobs"`
AllowSearchIndices bool `json:"allow_search_indices"`
AllowView bool `json:"allow_view"`
AllowFineTuning bool `json:"allow_fine_tuning"`
Organization string `json:"organization"`
Group *string `json:"group"`
IsBlocking bool `json:"is_blocking"`
}
type OpenAIModels struct {
Id string `json:"id"`
Object string `json:"object"`
Created int `json:"created"`
OwnedBy string `json:"owned_by"`
Permission OpenAIModelPermission `json:"permission"`
Root string `json:"root"`
Parent *string `json:"parent"`
}
var openAIModels []OpenAIModels
var openAIModelsMap map[string]OpenAIModels
func init() {
permission := OpenAIModelPermission{
Id: "modelperm-LwHkVFn8AcMItP432fKKDIKJ",
Object: "model_permission",
Created: 1626777600,
AllowCreateEngine: true,
AllowSampling: true,
AllowLogprobs: true,
AllowSearchIndices: false,
AllowView: true,
AllowFineTuning: false,
Organization: "*",
Group: nil,
IsBlocking: false,
}
// https://platform.openai.com/docs/models/model-endpoint-compatibility
openAIModels = []OpenAIModels{
{
Id: "gpt-3.5-turbo",
Object: "model",
Created: 1677649963,
OwnedBy: "openai",
Permission: permission,
Root: "gpt-3.5-turbo",
Parent: nil,
},
{
Id: "gpt-3.5-turbo-0301",
Object: "model",
Created: 1677649963,
OwnedBy: "openai",
Permission: permission,
Root: "gpt-3.5-turbo-0301",
Parent: nil,
},
{
Id: "gpt-4",
Object: "model",
Created: 1677649963,
OwnedBy: "openai",
Permission: permission,
Root: "gpt-4",
Parent: nil,
},
{
Id: "gpt-4-0314",
Object: "model",
Created: 1677649963,
OwnedBy: "openai",
Permission: permission,
Root: "gpt-4-0314",
Parent: nil,
},
{
Id: "gpt-4-32k",
Object: "model",
Created: 1677649963,
OwnedBy: "openai",
Permission: permission,
Root: "gpt-4-32k",
Parent: nil,
},
{
Id: "gpt-4-32k-0314",
Object: "model",
Created: 1677649963,
OwnedBy: "openai",
Permission: permission,
Root: "gpt-4-32k-0314",
Parent: nil,
},
{
Id: "gpt-3.5-turbo",
Object: "model",
Created: 1677649963,
OwnedBy: "openai",
Permission: permission,
Root: "gpt-3.5-turbo",
Parent: nil,
},
{
Id: "text-embedding-ada-002",
Object: "model",
Created: 1677649963,
OwnedBy: "openai",
Permission: permission,
Root: "text-embedding-ada-002",
Parent: nil,
},
}
openAIModelsMap = make(map[string]OpenAIModels)
for _, model := range openAIModels {
openAIModelsMap[model.Id] = model
}
}
func ListModels(c *gin.Context) {
c.JSON(200, openAIModels)
}
func RetrieveModel(c *gin.Context) {
modelId := c.Param("model")
if model, ok := openAIModelsMap[modelId]; ok {
c.JSON(200, model)
} else {
openAIError := OpenAIError{
Message: fmt.Sprintf("The model '%s' does not exist", modelId),
Type: "invalid_request_error",
Param: "model",
Code: "model_not_found",
}
c.JSON(200, gin.H{
"error": openAIError,
})
}
}

View File

@@ -21,14 +21,16 @@ type Message struct {
}
type ChatRequest struct {
Model string `json:"model"`
Messages []Message `json:"messages"`
Model string `json:"model"`
Messages []Message `json:"messages"`
MaxTokens int `json:"max_tokens"`
}
type TextRequest struct {
Model string `json:"model"`
Messages []Message `json:"messages"`
Prompt string `json:"prompt"`
Model string `json:"model"`
Messages []Message `json:"messages"`
Prompt string `json:"prompt"`
MaxTokens int `json:"max_tokens"`
//Stream bool `json:"stream"`
}
@@ -128,6 +130,23 @@ func relayHelper(c *gin.Context) error {
model_ = strings.TrimSuffix(model_, "-0314")
fullRequestURL = fmt.Sprintf("%s/openai/deployments/%s/%s", baseURL, model_, task)
}
var promptText string
for _, message := range textRequest.Messages {
promptText += fmt.Sprintf("%s: %s\n", message.Role, message.Content)
}
promptTokens := countToken(promptText) + 3
preConsumedTokens := common.PreConsumedQuota
if textRequest.MaxTokens != 0 {
preConsumedTokens = promptTokens + textRequest.MaxTokens
}
ratio := common.GetModelRatio(textRequest.Model)
preConsumedQuota := int(float64(preConsumedTokens) * ratio)
if consumeQuota {
err := model.PreConsumeTokenQuota(tokenId, preConsumedQuota)
if err != nil {
return err
}
}
req, err := http.NewRequest(c.Request.Method, fullRequestURL, c.Request.Body)
if err != nil {
return err
@@ -168,18 +187,14 @@ func relayHelper(c *gin.Context) error {
completionRatio = 2
}
if isStream {
var promptText string
for _, message := range textRequest.Messages {
promptText += fmt.Sprintf("%s: %s\n", message.Role, message.Content)
}
completionText := fmt.Sprintf("%s: %s\n", "assistant", streamResponseText)
quota = countToken(promptText) + countToken(completionText)*completionRatio + 3
quota = promptTokens + countToken(completionText)*completionRatio
} else {
quota = textResponse.Usage.PromptTokens + textResponse.Usage.CompletionTokens*completionRatio
}
ratio := common.GetModelRatio(textRequest.Model)
quota = int(float64(quota) * ratio)
err := model.DecreaseTokenQuota(tokenId, quota)
quotaDelta := quota - preConsumedQuota
err := model.PostConsumeTokenQuota(tokenId, quotaDelta)
if err != nil {
common.SysError("Error consuming token remain quota: " + err.Error())
}

View File

@@ -467,6 +467,13 @@ func CreateUser(c *gin.Context) {
})
return
}
if err := common.Validate.Struct(&user); err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "输入不合法 " + err.Error(),
})
return
}
if user.DisplayName == "" {
user.DisplayName = user.Username
}

View File

@@ -85,6 +85,8 @@ func RootAuth() func(c *gin.Context) {
func TokenAuth() func(c *gin.Context) {
return func(c *gin.Context) {
key := c.Request.Header.Get("Authorization")
key = strings.TrimPrefix(key, "Bearer ")
key = strings.TrimPrefix(key, "sk-")
parts := strings.Split(key, "-")
key = parts[0]
token, err := model.ValidateUserToken(key)
@@ -111,7 +113,7 @@ func TokenAuth() func(c *gin.Context) {
c.Set("id", token.UserId)
c.Set("token_id", token.Id)
requestURL := c.Request.URL.String()
consumeQuota := !token.UnlimitedQuota
consumeQuota := true
if strings.HasPrefix(requestURL, "/v1/models") {
consumeQuota = false
}

View File

@@ -55,6 +55,7 @@ func InitOptionMap() {
common.OptionMap["TurnstileSecretKey"] = ""
common.OptionMap["QuotaForNewUser"] = strconv.Itoa(common.QuotaForNewUser)
common.OptionMap["QuotaRemindThreshold"] = strconv.Itoa(common.QuotaRemindThreshold)
common.OptionMap["PreConsumedQuota"] = strconv.Itoa(common.PreConsumedQuota)
common.OptionMap["ModelRatio"] = common.ModelRatio2JSONString()
common.OptionMap["TopUpLink"] = common.TopUpLink
common.OptionMapRWMutex.Unlock()
@@ -159,6 +160,8 @@ func updateOptionMap(key string, value string) (err error) {
common.QuotaForNewUser, _ = strconv.Atoi(value)
case "QuotaRemindThreshold":
common.QuotaRemindThreshold, _ = strconv.Atoi(value)
case "PreConsumedQuota":
common.PreConsumedQuota, _ = strconv.Atoi(value)
case "ModelRatio":
err = common.UpdateModelRatioByJSONString(value)
case "TopUpLink":

View File

@@ -6,7 +6,6 @@ import (
_ "gorm.io/driver/sqlite"
"gorm.io/gorm"
"one-api/common"
"strings"
)
type Token struct {
@@ -38,7 +37,6 @@ func ValidateUserToken(key string) (token *Token, err error) {
if key == "" {
return nil, errors.New("未提供 token")
}
key = strings.Replace(key, "Bearer ", "", 1)
token = &Token{}
err = DB.Where("`key` = ?", key).First(token).Error
if err == nil {
@@ -130,7 +128,23 @@ func DeleteTokenById(id int, userId int) (err error) {
return token.Delete()
}
func DecreaseTokenQuota(tokenId int, quota int) (err error) {
func IncreaseTokenQuota(id int, quota int) (err error) {
if quota < 0 {
return errors.New("quota 不能为负数!")
}
err = DB.Model(&Token{}).Where("id = ?", id).Update("remain_quota", gorm.Expr("remain_quota + ?", quota)).Error
return err
}
func DecreaseTokenQuota(id int, quota int) (err error) {
if quota < 0 {
return errors.New("quota 不能为负数!")
}
err = DB.Model(&Token{}).Where("id = ?", id).Update("remain_quota", gorm.Expr("remain_quota - ?", quota)).Error
return err
}
func PreConsumeTokenQuota(tokenId int, quota int) (err error) {
if quota < 0 {
return errors.New("quota 不能为负数!")
}
@@ -138,7 +152,7 @@ func DecreaseTokenQuota(tokenId int, quota int) (err error) {
if err != nil {
return err
}
if token.RemainQuota < quota {
if !token.UnlimitedQuota && token.RemainQuota < quota {
return errors.New("令牌额度不足")
}
userQuota, err := GetUserQuota(token.UserId)
@@ -163,17 +177,42 @@ func DecreaseTokenQuota(tokenId int, quota int) (err error) {
if email != "" {
topUpLink := fmt.Sprintf("%s/topup", common.ServerAddress)
err = common.SendEmail(prompt, email,
fmt.Sprintf("%s剩余额度为 %d为了不影响您的使用请及时充值。<br/>充值链接:<a href='%s'>%s</a>", prompt, userQuota-quota, topUpLink, topUpLink))
fmt.Sprintf("%s当前剩余额度为 %d为了不影响您的使用请及时充值。<br/>充值链接:<a href='%s'>%s</a>", prompt, userQuota, topUpLink, topUpLink))
if err != nil {
common.SysError("发送邮件失败:" + err.Error())
}
}
}()
}
err = DB.Model(&Token{}).Where("id = ?", tokenId).Update("remain_quota", gorm.Expr("remain_quota - ?", quota)).Error
if err != nil {
return err
if !token.UnlimitedQuota {
err = DecreaseTokenQuota(tokenId, quota)
if err != nil {
return err
}
}
err = DecreaseUserQuota(token.UserId, quota)
return err
}
func PostConsumeTokenQuota(tokenId int, quota int) (err error) {
token, err := GetTokenById(tokenId)
if quota > 0 {
err = DecreaseUserQuota(token.UserId, quota)
} else {
err = IncreaseUserQuota(token.UserId, -quota)
}
if err != nil {
return err
}
if !token.UnlimitedQuota {
if quota > 0 {
err = DecreaseTokenQuota(tokenId, quota)
} else {
err = IncreaseTokenQuota(tokenId, -quota)
}
if err != nil {
return err
}
}
return nil
}

View File

@@ -11,8 +11,8 @@ func SetRelayRouter(router *gin.Engine) {
relayV1Router := router.Group("/v1")
relayV1Router.Use(middleware.TokenAuth(), middleware.Distribute())
{
relayV1Router.GET("/models", controller.Relay)
relayV1Router.GET("/models/:model", controller.Relay)
relayV1Router.GET("/models", controller.ListModels)
relayV1Router.GET("/models/:model", controller.RetrieveModel)
relayV1Router.POST("/completions", controller.RelayNotImplemented)
relayV1Router.POST("/chat/completions", controller.Relay)
relayV1Router.POST("/edits", controller.RelayNotImplemented)

View File

@@ -28,6 +28,7 @@ const SystemSetting = () => {
RegisterEnabled: '',
QuotaForNewUser: 0,
QuotaRemindThreshold: 0,
PreConsumedQuota: 0,
ModelRatio: '',
TopUpLink: '',
AutomaticDisableChannelEnabled: '',
@@ -98,6 +99,7 @@ const SystemSetting = () => {
name === 'TurnstileSecretKey' ||
name === 'QuotaForNewUser' ||
name === 'QuotaRemindThreshold' ||
name === 'PreConsumedQuota' ||
name === 'ModelRatio' ||
name === 'TopUpLink'
) {
@@ -119,6 +121,9 @@ const SystemSetting = () => {
if (originInputs['QuotaRemindThreshold'] !== inputs.QuotaRemindThreshold) {
await updateOption('QuotaRemindThreshold', inputs.QuotaRemindThreshold);
}
if (originInputs['PreConsumedQuota'] !== inputs.PreConsumedQuota) {
await updateOption('PreConsumedQuota', inputs.PreConsumedQuota);
}
if (originInputs['ModelRatio'] !== inputs.ModelRatio) {
if (!verifyJSON(inputs.ModelRatio)) {
showError('模型倍率不是合法的 JSON 字符串');
@@ -272,7 +277,7 @@ const SystemSetting = () => {
<Header as='h3'>
运营设置
</Header>
<Form.Group widths={3}>
<Form.Group widths={4}>
<Form.Input
label='新用户初始配额'
name='QuotaForNewUser'
@@ -302,6 +307,16 @@ const SystemSetting = () => {
min='0'
placeholder='低于此额度时将发送邮件提醒用户'
/>
<Form.Input
label='请求预扣费额度'
name='PreConsumedQuota'
onChange={handleInputChange}
autoComplete='new-password'
value={inputs.PreConsumedQuota}
type='number'
min='0'
placeholder='请求结束后多退少补'
/>
</Form.Group>
<Form.Group widths='equal'>
<Form.TextArea
@@ -321,7 +336,7 @@ const SystemSetting = () => {
</Header>
<Form.Group widths={3}>
<Form.Input
label='最长应时间'
label='最长应时间'
name='ChannelDisableThreshold'
onChange={handleInputChange}
autoComplete='new-password'

View File

@@ -4,6 +4,7 @@ import { Link } from 'react-router-dom';
import { API, showError, showSuccess } from '../helpers';
import { ITEMS_PER_PAGE } from '../constants';
import { renderText } from '../helpers/render';
function renderRole(role) {
switch (role) {
@@ -64,7 +65,7 @@ const UsersTable = () => {
(async () => {
const res = await API.post('/api/user/manage', {
username,
action,
action
});
const { success, message } = res.data;
if (success) {
@@ -161,18 +162,18 @@ const UsersTable = () => {
<Table.HeaderCell
style={{ cursor: 'pointer' }}
onClick={() => {
sortUser('username');
sortUser('id');
}}
>
用户名
ID
</Table.HeaderCell>
<Table.HeaderCell
style={{ cursor: 'pointer' }}
onClick={() => {
sortUser('display_name');
sortUser('username');
}}
>
显示名称
用户名
</Table.HeaderCell>
<Table.HeaderCell
style={{ cursor: 'pointer' }}
@@ -220,9 +221,17 @@ const UsersTable = () => {
if (user.deleted) return <></>;
return (
<Table.Row key={user.id}>
<Table.Cell>{user.username}</Table.Cell>
<Table.Cell>{user.display_name}</Table.Cell>
<Table.Cell>{user.email ? user.email : '无'}</Table.Cell>
<Table.Cell>{user.id}</Table.Cell>
<Table.Cell>
<Popup
content={user.email ? user.email : '未绑定邮箱地址'}
key={user.display_name}
header={user.display_name ? user.display_name : user.username}
trigger={<span>{renderText(user.username, 10)}</span>}
hoverable
/>
</Table.Cell>
<Table.Cell>{user.email ? renderText(user.email, 30) : '无'}</Table.Cell>
<Table.Cell>{user.quota}</Table.Cell>
<Table.Cell>{renderRole(user.role)}</Table.Cell>
<Table.Cell>{renderStatus(user.status)}</Table.Cell>
@@ -234,6 +243,7 @@ const UsersTable = () => {
onClick={() => {
manageUser(user.username, 'promote', idx);
}}
disabled={user.role === 100}
>
提升
</Button>
@@ -243,12 +253,13 @@ const UsersTable = () => {
onClick={() => {
manageUser(user.username, 'demote', idx);
}}
disabled={user.role === 100}
>
降级
</Button>
<Popup
trigger={
<Button size='small' negative>
<Button size='small' negative disabled={user.role === 100}>
删除
</Button>
}
@@ -274,6 +285,7 @@ const UsersTable = () => {
idx
);
}}
disabled={user.role === 100}
>
{user.status === 1 ? '禁用' : '启用'}
</Button>
@@ -281,6 +293,7 @@ const UsersTable = () => {
size={'small'}
as={Link}
to={'/user/edit/' + user.id}
disabled={user.role === 100}
>
编辑
</Button>

View File

@@ -0,0 +1,6 @@
export function renderText(text, limit) {
if (text.length > limit) {
return text.slice(0, limit - 3) + '...';
}
return text;
}