Compare commits

...

9 Commits

Author SHA1 Message Date
JustSong
fa71daa8a7 fix: fix wrong implementation for /v1/models (close #128) 2023-05-31 14:43:29 +08:00
JustSong
54215dc303 chore: make channel test related code separated 2023-05-23 10:01:09 +08:00
JustSong
f9f42997b2 chore: only check OpenAI channel & custom channel 2023-05-23 10:00:36 +08:00
JustSong
25eab0b224 style: fix UI related problems 2023-05-22 22:41:39 +08:00
JustSong
34bce5b464 style: add positive attribute to submit buttons (close #113) 2023-05-22 22:30:11 +08:00
JustSong
d4794fc051 feat: return user's quota with billing api (close #92) 2023-05-22 17:10:31 +08:00
JustSong
8b43e0dd3f fix: add no-cache for index.html 2023-05-22 00:54:53 +08:00
JustSong
92c88fa273 fix: remove no-store for index.html 2023-05-22 00:44:27 +08:00
JustSong
38191d55be fix: do not cache index.html 2023-05-22 00:39:24 +08:00
14 changed files with 321 additions and 247 deletions

41
controller/billing.go Normal file
View File

@@ -0,0 +1,41 @@
package controller
import (
"github.com/gin-gonic/gin"
"one-api/model"
)
func GetSubscription(c *gin.Context) {
userId := c.GetInt("id")
quota, err := model.GetUserQuota(userId)
if err != nil {
openAIError := OpenAIError{
Message: err.Error(),
Type: "one_api_error",
}
c.JSON(200, gin.H{
"error": openAIError,
})
return
}
subscription := OpenAISubscriptionResponse{
Object: "billing_subscription",
HasPaymentMethod: true,
SoftLimitUSD: float64(quota),
HardLimitUSD: float64(quota),
SystemHardLimitUSD: float64(quota),
}
c.JSON(200, subscription)
return
}
func GetUsage(c *gin.Context) {
//userId := c.GetInt("id")
// TODO: get usage from database
usage := OpenAIUsageResponse{
Object: "list",
TotalUsage: 0,
}
c.JSON(200, usage)
return
}

View File

@@ -13,12 +13,27 @@ import (
"time"
)
// https://github.com/songquanpeng/one-api/issues/79
type OpenAISubscriptionResponse struct {
HasPaymentMethod bool `json:"has_payment_method"`
HardLimitUSD float64 `json:"hard_limit_usd"`
Object string `json:"object"`
HasPaymentMethod bool `json:"has_payment_method"`
SoftLimitUSD float64 `json:"soft_limit_usd"`
HardLimitUSD float64 `json:"hard_limit_usd"`
SystemHardLimitUSD float64 `json:"system_hard_limit_usd"`
}
type OpenAIUsageDailyCost struct {
Timestamp float64 `json:"timestamp"`
LineItems []struct {
Name string `json:"name"`
Cost float64 `json:"cost"`
}
}
type OpenAIUsageResponse struct {
Object string `json:"object"`
//DailyCosts []OpenAIUsageDailyCost `json:"daily_costs"`
TotalUsage float64 `json:"total_usage"` // unit: 0.01 dollar
}
@@ -129,6 +144,10 @@ func updateAllChannelsBalance() error {
if channel.Status != common.ChannelStatusEnabled {
continue
}
// TODO: support Azure
if channel.Type != common.ChannelTypeOpenAI && channel.Type != common.ChannelTypeCustom {
continue
}
balance, err := updateChannelBalance(channel)
if err != nil {
continue

199
controller/channel-test.go Normal file
View File

@@ -0,0 +1,199 @@
package controller
import (
"bytes"
"encoding/json"
"errors"
"fmt"
"github.com/gin-gonic/gin"
"net/http"
"one-api/common"
"one-api/model"
"strconv"
"sync"
"time"
)
func testChannel(channel *model.Channel, request *ChatRequest) error {
if request.Model == "" {
request.Model = "gpt-3.5-turbo"
if channel.Type == common.ChannelTypeAzure {
request.Model = "gpt-35-turbo"
}
}
requestURL := common.ChannelBaseURLs[channel.Type]
if channel.Type == common.ChannelTypeAzure {
requestURL = fmt.Sprintf("%s/openai/deployments/%s/chat/completions?api-version=2023-03-15-preview", channel.BaseURL, request.Model)
} else {
if channel.Type == common.ChannelTypeCustom {
requestURL = channel.BaseURL
}
requestURL += "/v1/chat/completions"
}
jsonData, err := json.Marshal(request)
if err != nil {
return err
}
req, err := http.NewRequest("POST", requestURL, bytes.NewBuffer(jsonData))
if err != nil {
return err
}
if channel.Type == common.ChannelTypeAzure {
req.Header.Set("api-key", channel.Key)
} else {
req.Header.Set("Authorization", "Bearer "+channel.Key)
}
req.Header.Set("Content-Type", "application/json")
client := &http.Client{}
resp, err := client.Do(req)
if err != nil {
return err
}
defer resp.Body.Close()
var response TextResponse
err = json.NewDecoder(resp.Body).Decode(&response)
if err != nil {
return err
}
if response.Error.Message != "" || response.Error.Code != "" {
return errors.New(fmt.Sprintf("type %s, code %s, message %s", response.Error.Type, response.Error.Code, response.Error.Message))
}
return nil
}
func buildTestRequest(c *gin.Context) *ChatRequest {
model_ := c.Query("model")
testRequest := &ChatRequest{
Model: model_,
MaxTokens: 1,
}
testMessage := Message{
Role: "user",
Content: "hi",
}
testRequest.Messages = append(testRequest.Messages, testMessage)
return testRequest
}
func TestChannel(c *gin.Context) {
id, err := strconv.Atoi(c.Param("id"))
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
channel, err := model.GetChannelById(id, true)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
testRequest := buildTestRequest(c)
tik := time.Now()
err = testChannel(channel, testRequest)
tok := time.Now()
milliseconds := tok.Sub(tik).Milliseconds()
go channel.UpdateResponseTime(milliseconds)
consumedTime := float64(milliseconds) / 1000.0
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
"time": consumedTime,
})
return
}
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
"time": consumedTime,
})
return
}
var testAllChannelsLock sync.Mutex
var testAllChannelsRunning bool = false
// disable & notify
func disableChannel(channelId int, channelName string, reason string) {
if common.RootUserEmail == "" {
common.RootUserEmail = model.GetRootUserEmail()
}
model.UpdateChannelStatusById(channelId, common.ChannelStatusDisabled)
subject := fmt.Sprintf("通道「%s」#%d已被禁用", channelName, channelId)
content := fmt.Sprintf("通道「%s」#%d已被禁用原因%s", channelName, channelId, reason)
err := common.SendEmail(subject, common.RootUserEmail, content)
if err != nil {
common.SysError(fmt.Sprintf("发送邮件失败:%s", err.Error()))
}
}
func testAllChannels(c *gin.Context) error {
testAllChannelsLock.Lock()
if testAllChannelsRunning {
testAllChannelsLock.Unlock()
return errors.New("测试已在运行中")
}
testAllChannelsRunning = true
testAllChannelsLock.Unlock()
channels, err := model.GetAllChannels(0, 0, true)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return err
}
testRequest := buildTestRequest(c)
var disableThreshold = int64(common.ChannelDisableThreshold * 1000)
if disableThreshold == 0 {
disableThreshold = 10000000 // a impossible value
}
go func() {
for _, channel := range channels {
if channel.Status != common.ChannelStatusEnabled {
continue
}
tik := time.Now()
err := testChannel(channel, testRequest)
tok := time.Now()
milliseconds := tok.Sub(tik).Milliseconds()
if err != nil || milliseconds > disableThreshold {
if milliseconds > disableThreshold {
err = errors.New(fmt.Sprintf("响应时间 %.2fs 超过阈值 %.2fs", float64(milliseconds)/1000.0, float64(disableThreshold)/1000.0))
}
disableChannel(channel.Id, channel.Name, err.Error())
}
channel.UpdateResponseTime(milliseconds)
}
err := common.SendEmail("通道测试完成", common.RootUserEmail, "通道测试完成,如果没有收到禁用通知,说明所有通道都正常")
if err != nil {
common.SysError(fmt.Sprintf("发送邮件失败:%s", err.Error()))
}
testAllChannelsLock.Lock()
testAllChannelsRunning = false
testAllChannelsLock.Unlock()
}()
return nil
}
func TestAllChannels(c *gin.Context) {
err := testAllChannels(c)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
})
return
}

View File

@@ -1,18 +1,12 @@
package controller
import (
"bytes"
"encoding/json"
"errors"
"fmt"
"github.com/gin-gonic/gin"
"net/http"
"one-api/common"
"one-api/model"
"strconv"
"strings"
"sync"
"time"
)
func GetAllChannels(c *gin.Context) {
@@ -158,187 +152,3 @@ func UpdateChannel(c *gin.Context) {
})
return
}
func testChannel(channel *model.Channel, request *ChatRequest) error {
if request.Model == "" {
request.Model = "gpt-3.5-turbo"
if channel.Type == common.ChannelTypeAzure {
request.Model = "gpt-35-turbo"
}
}
requestURL := common.ChannelBaseURLs[channel.Type]
if channel.Type == common.ChannelTypeAzure {
requestURL = fmt.Sprintf("%s/openai/deployments/%s/chat/completions?api-version=2023-03-15-preview", channel.BaseURL, request.Model)
} else {
if channel.Type == common.ChannelTypeCustom {
requestURL = channel.BaseURL
}
requestURL += "/v1/chat/completions"
}
jsonData, err := json.Marshal(request)
if err != nil {
return err
}
req, err := http.NewRequest("POST", requestURL, bytes.NewBuffer(jsonData))
if err != nil {
return err
}
if channel.Type == common.ChannelTypeAzure {
req.Header.Set("api-key", channel.Key)
} else {
req.Header.Set("Authorization", "Bearer "+channel.Key)
}
req.Header.Set("Content-Type", "application/json")
client := &http.Client{}
resp, err := client.Do(req)
if err != nil {
return err
}
defer resp.Body.Close()
var response TextResponse
err = json.NewDecoder(resp.Body).Decode(&response)
if err != nil {
return err
}
if response.Error.Message != "" || response.Error.Code != "" {
return errors.New(fmt.Sprintf("type %s, code %s, message %s", response.Error.Type, response.Error.Code, response.Error.Message))
}
return nil
}
func buildTestRequest(c *gin.Context) *ChatRequest {
model_ := c.Query("model")
testRequest := &ChatRequest{
Model: model_,
MaxTokens: 1,
}
testMessage := Message{
Role: "user",
Content: "hi",
}
testRequest.Messages = append(testRequest.Messages, testMessage)
return testRequest
}
func TestChannel(c *gin.Context) {
id, err := strconv.Atoi(c.Param("id"))
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
channel, err := model.GetChannelById(id, true)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
testRequest := buildTestRequest(c)
tik := time.Now()
err = testChannel(channel, testRequest)
tok := time.Now()
milliseconds := tok.Sub(tik).Milliseconds()
go channel.UpdateResponseTime(milliseconds)
consumedTime := float64(milliseconds) / 1000.0
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
"time": consumedTime,
})
return
}
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
"time": consumedTime,
})
return
}
var testAllChannelsLock sync.Mutex
var testAllChannelsRunning bool = false
// disable & notify
func disableChannel(channelId int, channelName string, reason string) {
if common.RootUserEmail == "" {
common.RootUserEmail = model.GetRootUserEmail()
}
model.UpdateChannelStatusById(channelId, common.ChannelStatusDisabled)
subject := fmt.Sprintf("通道「%s」#%d已被禁用", channelName, channelId)
content := fmt.Sprintf("通道「%s」#%d已被禁用原因%s", channelName, channelId, reason)
err := common.SendEmail(subject, common.RootUserEmail, content)
if err != nil {
common.SysError(fmt.Sprintf("发送邮件失败:%s", err.Error()))
}
}
func testAllChannels(c *gin.Context) error {
testAllChannelsLock.Lock()
if testAllChannelsRunning {
testAllChannelsLock.Unlock()
return errors.New("测试已在运行中")
}
testAllChannelsRunning = true
testAllChannelsLock.Unlock()
channels, err := model.GetAllChannels(0, 0, true)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return err
}
testRequest := buildTestRequest(c)
var disableThreshold = int64(common.ChannelDisableThreshold * 1000)
if disableThreshold == 0 {
disableThreshold = 10000000 // a impossible value
}
go func() {
for _, channel := range channels {
if channel.Status != common.ChannelStatusEnabled {
continue
}
tik := time.Now()
err := testChannel(channel, testRequest)
tok := time.Now()
milliseconds := tok.Sub(tik).Milliseconds()
if err != nil || milliseconds > disableThreshold {
if milliseconds > disableThreshold {
err = errors.New(fmt.Sprintf("响应时间 %.2fs 超过阈值 %.2fs", float64(milliseconds)/1000.0, float64(disableThreshold)/1000.0))
}
disableChannel(channel.Id, channel.Name, err.Error())
}
channel.UpdateResponseTime(milliseconds)
}
err := common.SendEmail("通道测试完成", common.RootUserEmail, "通道测试完成,如果没有收到禁用通知,说明所有通道都正常")
if err != nil {
common.SysError(fmt.Sprintf("发送邮件失败:%s", err.Error()))
}
testAllChannelsLock.Lock()
testAllChannelsRunning = false
testAllChannelsLock.Unlock()
}()
return nil
}
func TestAllChannels(c *gin.Context) {
err := testAllChannels(c)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
})
return
}

View File

@@ -23,20 +23,21 @@ type OpenAIModelPermission struct {
}
type OpenAIModels struct {
Id string `json:"id"`
Object string `json:"object"`
Created int `json:"created"`
OwnedBy string `json:"owned_by"`
Permission OpenAIModelPermission `json:"permission"`
Root string `json:"root"`
Parent *string `json:"parent"`
Id string `json:"id"`
Object string `json:"object"`
Created int `json:"created"`
OwnedBy string `json:"owned_by"`
Permission []OpenAIModelPermission `json:"permission"`
Root string `json:"root"`
Parent *string `json:"parent"`
}
var openAIModels []OpenAIModels
var openAIModelsMap map[string]OpenAIModels
func init() {
permission := OpenAIModelPermission{
var permission []OpenAIModelPermission
permission = append(permission, OpenAIModelPermission{
Id: "modelperm-LwHkVFn8AcMItP432fKKDIKJ",
Object: "model_permission",
Created: 1626777600,
@@ -49,7 +50,7 @@ func init() {
Organization: "*",
Group: nil,
IsBlocking: false,
}
})
// https://platform.openai.com/docs/models/model-endpoint-compatibility
openAIModels = []OpenAIModels{
{
@@ -106,15 +107,6 @@ func init() {
Root: "gpt-4-32k-0314",
Parent: nil,
},
{
Id: "gpt-3.5-turbo",
Object: "model",
Created: 1677649963,
OwnedBy: "openai",
Permission: permission,
Root: "gpt-3.5-turbo",
Parent: nil,
},
{
Id: "text-embedding-ada-002",
Object: "model",
@@ -132,7 +124,10 @@ func init() {
}
func ListModels(c *gin.Context) {
c.JSON(200, openAIModels)
c.JSON(200, gin.H{
"object": "list",
"data": openAIModels,
})
}
func RetrieveModel(c *gin.Context) {

View File

@@ -6,7 +6,11 @@ import (
func Cache() func(c *gin.Context) {
return func(c *gin.Context) {
c.Header("Cache-Control", "max-age=604800") // one week
if c.Request.RequestURI == "/" {
c.Header("Cache-Control", "no-cache")
} else {
c.Header("Cache-Control", "max-age=604800") // one week
}
c.Next()
}
}

View File

@@ -8,11 +8,14 @@ import (
)
func SetDashboardRouter(router *gin.Engine) {
apiRouter := router.Group("/dashboard")
apiRouter := router.Group("/")
apiRouter.Use(gzip.Gzip(gzip.DefaultCompression))
apiRouter.Use(middleware.GlobalAPIRateLimit())
apiRouter.Use(middleware.TokenAuth())
{
apiRouter.GET("/billing/credit_grants", controller.GetTokenStatus)
apiRouter.GET("/dashboard/billing/subscription", controller.GetSubscription)
apiRouter.GET("/v1/dashboard/billing/subscription", controller.GetSubscription)
apiRouter.GET("/dashboard/billing/usage", controller.GetUsage)
apiRouter.GET("/v1/dashboard/billing/usage", controller.GetUsage)
}
}

View File

@@ -16,6 +16,7 @@ func SetWebRouter(router *gin.Engine, buildFS embed.FS, indexPage []byte) {
router.Use(middleware.Cache())
router.Use(static.Serve("/", common.EmbedFolder(buildFS, "web/build")))
router.NoRoute(func(c *gin.Context) {
c.Header("Cache-Control", "no-cache")
c.Data(http.StatusOK, "text/html; charset=utf-8", indexPage)
})
}

View File

@@ -405,7 +405,7 @@ const ChannelsTable = () => {
<Button size='small' loading={loading} onClick={testAllChannels}>
测试所有已启用通道
</Button>
<Button size='small' onClick={updateAllChannelsBalance} loading={updatingBalance}>更新所有已启用通道余额</Button>
<Button size='small' onClick={updateAllChannelsBalance} loading={loading || updatingBalance}>更新所有已启用通道余额</Button>
<Pagination
floated='right'
activePage={activePage}

View File

@@ -167,7 +167,7 @@ const EditChannel = () => {
/>
)
}
<Button onClick={submit}>提交</Button>
<Button positive onClick={submit}>提交</Button>
</Form>
</Segment>
</>

View File

@@ -111,7 +111,7 @@ const EditRedemption = () => {
</Form.Field>
</>
}
<Button onClick={submit}>提交</Button>
<Button positive onClick={submit}>提交</Button>
</Form>
</Segment>
</>

View File

@@ -106,6 +106,34 @@ const EditToken = () => {
required={!isEdit}
/>
</Form.Field>
<Form.Field>
<Form.Input
label='过期时间'
name='expired_time'
placeholder={'请输入过期时间,格式为 yyyy-MM-dd HH:mm:ss-1 表示无限制'}
onChange={handleInputChange}
value={expired_time}
autoComplete='new-password'
type='datetime-local'
/>
</Form.Field>
<div style={{ lineHeight: '40px' }}>
<Button type={'button'} onClick={() => {
setExpiredTime(0, 0, 0, 0);
}}>永不过期</Button>
<Button type={'button'} onClick={() => {
setExpiredTime(1, 0, 0, 0);
}}>一个月后过期</Button>
<Button type={'button'} onClick={() => {
setExpiredTime(0, 1, 0, 0);
}}>一天后过期</Button>
<Button type={'button'} onClick={() => {
setExpiredTime(0, 0, 1, 0);
}}>一小时后过期</Button>
<Button type={'button'} onClick={() => {
setExpiredTime(0, 0, 0, 1);
}}>一分钟后过期</Button>
</div>
<Message>注意令牌的额度仅用于限制令牌本身的最大额度使用量实际的使用受到账户的剩余额度限制</Message>
<Form.Field>
<Form.Input
@@ -119,36 +147,10 @@ const EditToken = () => {
disabled={unlimited_quota}
/>
</Form.Field>
<Button type={'button'} style={{ marginBottom: '14px' }} onClick={() => {
<Button type={'button'} onClick={() => {
setUnlimitedQuota();
}}>{unlimited_quota ? '取消无限额度' : '设置为无限额度'}</Button>
<Form.Field>
<Form.Input
label='过期时间'
name='expired_time'
placeholder={'请输入过期时间,格式为 yyyy-MM-dd HH:mm:ss-1 表示无限制'}
onChange={handleInputChange}
value={expired_time}
autoComplete='new-password'
type='datetime-local'
/>
</Form.Field>
<Button type={'button'} onClick={() => {
setExpiredTime(0, 0, 0, 0);
}}>永不过期</Button>
<Button type={'button'} onClick={() => {
setExpiredTime(1, 0, 0, 0);
}}>一个月后过期</Button>
<Button type={'button'} onClick={() => {
setExpiredTime(0, 1, 0, 0);
}}>一天后过期</Button>
<Button type={'button'} onClick={() => {
setExpiredTime(0, 0, 1, 0);
}}>一小时后过期</Button>
<Button type={'button'} onClick={() => {
setExpiredTime(0, 0, 0, 1);
}}>一分钟后过期</Button>
<Button onClick={submit}>提交</Button>
<Button positive onClick={submit}>提交</Button>
</Form>
</Segment>
</>

View File

@@ -65,7 +65,7 @@ const AddUser = () => {
required
/>
</Form.Field>
<Button type={'submit'} onClick={submit}>
<Button positive type={'submit'} onClick={submit}>
提交
</Button>
</Form>

View File

@@ -142,7 +142,7 @@ const EditUser = () => {
readOnly
/>
</Form.Field>
<Button onClick={submit}>提交</Button>
<Button positive onClick={submit}>提交</Button>
</Form>
</Segment>
</>