feat: support configuration file (#117)

* ♻️ refactor: move file directory

* ♻️ refactor: move file directory

* ♻️ refactor: support multiple config methods

* 🔥 del: remove unused code

* 💩 refactor: Refactor channel management and synchronization

* 💄 improve: add channel website

* feat: allow recording 0 consumption
This commit is contained in:
Buer
2024-03-20 14:12:47 +08:00
committed by GitHub
parent 0409de0ea9
commit 71171c63f5
50 changed files with 581 additions and 481 deletions

View File

@@ -147,6 +147,10 @@ func UpdateAllChannelsBalance(c *gin.Context) {
}
func AutomaticallyUpdateChannels(frequency int) {
if frequency <= 0 {
return
}
for {
time.Sleep(time.Duration(frequency) * time.Minute)
common.SysLog("updating all channels")

View File

@@ -217,6 +217,10 @@ func TestAllChannels(c *gin.Context) {
}
func AutomaticallyTestChannels(frequency int) {
if frequency <= 0 {
return
}
for {
time.Sleep(time.Duration(frequency) * time.Minute)
common.SysLog("testing all channels")

View File

@@ -1,180 +0,0 @@
package controller
import (
"fmt"
"net/http"
"one-api/common"
"one-api/model"
"one-api/types"
"sort"
"github.com/gin-gonic/gin"
)
// https://platform.openai.com/docs/api-reference/models/list
// unknownOwnedBy is the fallback owner label ("unknown") returned for
// models that cannot be mapped to a known vendor.
var unknownOwnedBy = "未知"

// OpenAIModelPermission mirrors the permission entries of the OpenAI
// list-models API response.
type OpenAIModelPermission struct {
	Id                 string  `json:"id"`
	Object             string  `json:"object"`
	Created            int     `json:"created"`
	AllowCreateEngine  bool    `json:"allow_create_engine"`
	AllowSampling      bool    `json:"allow_sampling"`
	AllowLogprobs      bool    `json:"allow_logprobs"`
	AllowSearchIndices bool    `json:"allow_search_indices"`
	AllowView          bool    `json:"allow_view"`
	AllowFineTuning    bool    `json:"allow_fine_tuning"`
	Organization       string  `json:"organization"`
	Group              *string `json:"group"`
	IsBlocking         bool    `json:"is_blocking"`
}

// OpenAIModels is the OpenAI-compatible model object returned by the
// model listing and retrieval endpoints.
type OpenAIModels struct {
	Id         string                   `json:"id"`
	Object     string                   `json:"object"`
	Created    int                      `json:"created"`
	OwnedBy    *string                  `json:"owned_by"`
	Permission *[]OpenAIModelPermission `json:"permission"`
	Root       *string                  `json:"root"`
	Parent     *string                  `json:"parent"`
}

// modelOwnedBy maps a channel-type constant to the vendor display name
// reported in the owned_by field.
var modelOwnedBy map[int]string

// init populates the static channel-type → vendor-name table.
func init() {
	modelOwnedBy = map[int]string{
		common.ChannelTypeOpenAI:    "OpenAI",
		common.ChannelTypeAnthropic: "Anthropic",
		common.ChannelTypeBaidu:     "Baidu",
		common.ChannelTypePaLM:      "Google PaLM",
		common.ChannelTypeGemini:    "Google Gemini",
		common.ChannelTypeZhipu:     "Zhipu",
		common.ChannelTypeAli:       "Ali",
		common.ChannelTypeXunfei:    "Xunfei",
		common.ChannelType360:       "360",
		common.ChannelTypeTencent:   "Tencent",
		common.ChannelTypeBaichuan:  "Baichuan",
		common.ChannelTypeMiniMax:   "MiniMax",
		common.ChannelTypeDeepseek:  "Deepseek",
		common.ChannelTypeMoonshot:  "Moonshot",
		common.ChannelTypeMistral:   "Mistral",
		common.ChannelTypeGroq:      "Groq",
		common.ChannelTypeLingyi:    "Lingyiwanwu",
	}
}
// ListModels returns the models available to the requesting user's group,
// formatted as an OpenAI-compatible model list.
func ListModels(c *gin.Context) {
	groupName := c.GetString("group")
	if groupName == "" {
		// No group in context: resolve it from the authenticated user.
		user, err := model.GetUserById(c.GetInt("id"), false)
		if err != nil {
			common.AbortWithMessage(c, http.StatusServiceUnavailable, err.Error())
			return
		}
		groupName = user.Group
	}
	models, err := model.ChannelGroup.GetGroupModels(groupName)
	if err != nil {
		common.AbortWithMessage(c, http.StatusServiceUnavailable, err.Error())
		return
	}
	sort.Strings(models)
	list := make([]OpenAIModels, 0, len(models))
	for _, id := range models {
		list = append(list, OpenAIModels{
			Id:      id,
			Object:  "model",
			Created: 1677649963,
			OwnedBy: getModelOwnedBy(id),
		})
	}
	// Order by owner name; nil owners sort first.
	sort.Slice(list, func(i, j int) bool {
		a, b := list[i].OwnedBy, list[j].OwnedBy
		if a == nil {
			return true
		}
		if b == nil {
			return false
		}
		return *a < *b
	})
	c.JSON(200, gin.H{
		"object": "list",
		"data":   list,
	})
}
// ListModelsForAdmin lists every model present in the pricing table,
// regardless of group availability, for administrative use.
func ListModelsForAdmin(c *gin.Context) {
	list := make([]OpenAIModels, 0, len(common.ModelRatio))
	for modelId := range common.ModelRatio {
		list = append(list, OpenAIModels{
			Id:      modelId,
			Object:  "model",
			Created: 1677649963,
			OwnedBy: getModelOwnedBy(modelId),
		})
	}
	// Order by owner name; nil owners sort first.
	sort.Slice(list, func(i, j int) bool {
		a, b := list[i].OwnedBy, list[j].OwnedBy
		if a == nil {
			return true
		}
		if b == nil {
			return false
		}
		return *a < *b
	})
	c.JSON(200, gin.H{
		"object": "list",
		"data":   list,
	})
}
// RetrieveModel implements the OpenAI "retrieve model" endpoint. Known
// models are returned as an OpenAIModels object; unknown models yield an
// OpenAI-style model_not_found error.
func RetrieveModel(c *gin.Context) {
	modelId := c.Param("model")
	ownedByName := getModelOwnedBy(modelId)
	if *ownedByName != unknownOwnedBy {
		c.JSON(200, OpenAIModels{
			Id:         modelId,
			Object:     "model",
			Created:    1677649963,
			OwnedBy:    ownedByName,
			Permission: nil,
			Root:       nil,
			Parent:     nil,
		})
		return
	}
	openAIError := types.OpenAIError{
		Message: fmt.Sprintf("The model '%s' does not exist", modelId),
		Type:    "invalid_request_error",
		Param:   "model",
		Code:    "model_not_found",
	}
	// Bug fix: the upstream OpenAI API answers 404 (not 200) for an
	// unknown model so OpenAI-compatible clients can detect the error.
	c.JSON(http.StatusNotFound, gin.H{
		"error": openAIError,
	})
}
// getModelOwnedBy resolves the display name of the vendor owning the
// given model, falling back to the "unknown" label.
func getModelOwnedBy(modelId string) *string {
	modelType, ok := common.ModelTypes[modelId]
	if !ok {
		return &unknownOwnedBy
	}
	name, ok := modelOwnedBy[modelType.Type]
	if !ok {
		return &unknownOwnedBy
	}
	return &name
}

View File

@@ -1,53 +0,0 @@
package relay
import (
"one-api/types"
providersBase "one-api/providers/base"
"github.com/gin-gonic/gin"
)
// relayBase carries the state shared by every relay implementation: the
// gin request context, the selected provider, and the original/mapped
// model names.
type relayBase struct {
	c             *gin.Context
	provider      providersBase.ProviderInterface
	originalModel string
	modelName     string
}

// RelayBaseInterface is the contract each relay endpoint implements;
// the Relay entry point drives requests exclusively through it.
type RelayBaseInterface interface {
	// send performs the upstream call; done=true means the failure is
	// final and must not be retried.
	send() (err *types.OpenAIErrorWithStatusCode, done bool)
	getPromptTokens() (int, error)
	setRequest() error
	setProvider(modelName string) error
	getProvider() providersBase.ProviderInterface
	getOriginalModel() string
	getModelName() string
	getContext() *gin.Context
}

// setProvider resolves a provider for modelName (applying channel model
// mapping) and stores both the provider and the mapped model name.
func (r *relayBase) setProvider(modelName string) error {
	provider, modelName, fail := getProvider(r.c, modelName)
	if fail != nil {
		return fail
	}
	r.provider = provider
	r.modelName = modelName
	return nil
}

// getContext returns the bound gin request context.
func (r *relayBase) getContext() *gin.Context {
	return r.c
}

// getProvider returns the provider selected by setProvider.
func (r *relayBase) getProvider() providersBase.ProviderInterface {
	return r.provider
}

// getOriginalModel returns the model name as requested by the client.
func (r *relayBase) getOriginalModel() string {
	return r.originalModel
}

// getModelName returns the model name after channel model mapping.
func (r *relayBase) getModelName() string {
	return r.modelName
}

View File

@@ -1,76 +0,0 @@
package relay
import (
"errors"
"math"
"net/http"
"one-api/common"
"one-api/common/requester"
providersBase "one-api/providers/base"
"one-api/types"
"github.com/gin-gonic/gin"
)
// relayChat handles OpenAI chat completion relaying (/v1/chat/completions)
// in both streaming and non-streaming modes.
type relayChat struct {
	relayBase
	chatRequest types.ChatCompletionRequest
}

// NewRelayChat returns a chat relay bound to the given request context.
func NewRelayChat(c *gin.Context) *relayChat {
	return &relayChat{relayBase: relayBase{c: c}}
}

// setRequest parses the body and validates max_tokens, remembering the
// model the client originally asked for.
func (rc *relayChat) setRequest() error {
	if err := common.UnmarshalBodyReusable(rc.c, &rc.chatRequest); err != nil {
		return err
	}
	maxTokens := rc.chatRequest.MaxTokens
	if maxTokens < 0 || maxTokens > math.MaxInt32/2 {
		return errors.New("max_tokens is invalid")
	}
	rc.originalModel = rc.chatRequest.Model
	return nil
}

// getPromptTokens counts the tokens of all chat messages.
func (rc *relayChat) getPromptTokens() (int, error) {
	return common.CountTokenMessages(rc.chatRequest.Messages, rc.modelName), nil
}

// send forwards the request to a chat-capable provider. done=true marks
// failures that must not be retried.
func (rc *relayChat) send() (*types.OpenAIErrorWithStatusCode, bool) {
	chatProvider, ok := rc.provider.(providersBase.ChatInterface)
	if !ok {
		return common.StringErrorWrapper("channel not implemented", "channel_error", http.StatusServiceUnavailable), true
	}
	rc.chatRequest.Model = rc.modelName

	var apiErr *types.OpenAIErrorWithStatusCode
	if rc.chatRequest.Stream {
		var stream requester.StreamReaderInterface[string]
		stream, apiErr = chatProvider.CreateChatCompletionStream(&rc.chatRequest)
		if apiErr != nil {
			return apiErr, false
		}
		apiErr = responseStreamClient(rc.c, stream)
	} else {
		var resp *types.ChatCompletionResponse
		resp, apiErr = chatProvider.CreateChatCompletion(&rc.chatRequest)
		if apiErr != nil {
			return apiErr, false
		}
		apiErr = responseJsonClient(rc.c, resp)
	}
	return apiErr, apiErr != nil
}

View File

@@ -1,76 +0,0 @@
package relay
import (
"errors"
"math"
"net/http"
"one-api/common"
"one-api/common/requester"
providersBase "one-api/providers/base"
"one-api/types"
"github.com/gin-gonic/gin"
)
// relayCompletions handles legacy text completion relaying
// (/v1/completions) in both streaming and non-streaming modes.
type relayCompletions struct {
	relayBase
	request types.CompletionRequest
}

// NewRelayCompletions returns a completions relay bound to the context.
func NewRelayCompletions(c *gin.Context) *relayCompletions {
	return &relayCompletions{relayBase: relayBase{c: c}}
}

// setRequest parses the body and validates max_tokens, remembering the
// model the client originally asked for.
func (rc *relayCompletions) setRequest() error {
	if err := common.UnmarshalBodyReusable(rc.c, &rc.request); err != nil {
		return err
	}
	maxTokens := rc.request.MaxTokens
	if maxTokens < 0 || maxTokens > math.MaxInt32/2 {
		return errors.New("max_tokens is invalid")
	}
	rc.originalModel = rc.request.Model
	return nil
}

// getPromptTokens counts the tokens of the prompt.
func (rc *relayCompletions) getPromptTokens() (int, error) {
	return common.CountTokenInput(rc.request.Prompt, rc.modelName), nil
}

// send forwards the request to a completion-capable provider. done=true
// marks failures that must not be retried.
func (rc *relayCompletions) send() (*types.OpenAIErrorWithStatusCode, bool) {
	completionProvider, ok := rc.provider.(providersBase.CompletionInterface)
	if !ok {
		return common.StringErrorWrapper("channel not implemented", "channel_error", http.StatusServiceUnavailable), true
	}
	rc.request.Model = rc.modelName

	var apiErr *types.OpenAIErrorWithStatusCode
	if rc.request.Stream {
		var stream requester.StreamReaderInterface[string]
		stream, apiErr = completionProvider.CreateCompletionStream(&rc.request)
		if apiErr != nil {
			return apiErr, false
		}
		apiErr = responseStreamClient(rc.c, stream)
	} else {
		var resp *types.CompletionResponse
		resp, apiErr = completionProvider.CreateCompletion(&rc.request)
		if apiErr != nil {
			return apiErr, false
		}
		apiErr = responseJsonClient(rc.c, resp)
	}
	return apiErr, apiErr != nil
}

View File

@@ -1,63 +0,0 @@
package relay
import (
"net/http"
"one-api/common"
providersBase "one-api/providers/base"
"one-api/types"
"strings"
"github.com/gin-gonic/gin"
)
// relayEmbeddings relays embedding requests (/v1/embeddings and the
// engine-style variant with the model in the URL path).
type relayEmbeddings struct {
	relayBase
	request types.EmbeddingRequest
}

// NewRelayEmbeddings returns an embeddings relay bound to the context.
func NewRelayEmbeddings(c *gin.Context) *relayEmbeddings {
	relay := &relayEmbeddings{}
	relay.c = c
	return relay
}

// setRequest parses the request and records the originally requested model.
func (r *relayEmbeddings) setRequest() error {
	// For engine-style routes the model is carried in the URL path.
	// NOTE(review): a "model" field in the JSON body will overwrite the
	// path value because the body is unmarshalled afterwards — confirm
	// that precedence is intended.
	if strings.HasSuffix(r.c.Request.URL.Path, "embeddings") {
		r.request.Model = r.c.Param("model")
	}
	if err := common.UnmarshalBodyReusable(r.c, &r.request); err != nil {
		return err
	}
	r.originalModel = r.request.Model
	return nil
}

// getPromptTokens counts the tokens of the embedding input.
func (r *relayEmbeddings) getPromptTokens() (int, error) {
	return common.CountTokenInput(r.request.Input, r.modelName), nil
}

// send forwards to a provider implementing the embeddings interface;
// done=true marks failures that must not be retried.
func (r *relayEmbeddings) send() (err *types.OpenAIErrorWithStatusCode, done bool) {
	provider, ok := r.provider.(providersBase.EmbeddingsInterface)
	if !ok {
		err = common.StringErrorWrapper("channel not implemented", "channel_error", http.StatusServiceUnavailable)
		done = true
		return
	}
	r.request.Model = r.modelName
	response, err := provider.CreateEmbeddings(&r.request)
	if err != nil {
		return
	}
	err = responseJsonClient(r.c, response)
	if err != nil {
		done = true
	}
	return
}

View File

@@ -1,71 +0,0 @@
package relay
import (
"errors"
"net/http"
"one-api/common"
providersBase "one-api/providers/base"
"one-api/types"
"github.com/gin-gonic/gin"
)
// relayImageEdits relays image edit requests (/v1/images/edits).
type relayImageEdits struct {
	relayBase
	request types.ImageEditRequest
}

// NewRelayImageEdits returns an image-edit relay bound to the context.
func NewRelayImageEdits(c *gin.Context) *relayImageEdits {
	return &relayImageEdits{relayBase: relayBase{c: c}}
}

// setRequest parses the body, requires a prompt, and fills in the
// default model and size when omitted.
func (re *relayImageEdits) setRequest() error {
	if err := common.UnmarshalBodyReusable(re.c, &re.request); err != nil {
		return err
	}
	if re.request.Prompt == "" {
		return errors.New("field prompt is required")
	}
	if re.request.Model == "" {
		re.request.Model = "dall-e-2"
	}
	if re.request.Size == "" {
		re.request.Size = "1024x1024"
	}
	re.originalModel = re.request.Model
	return nil
}

// getPromptTokens prices the request from its image parameters.
func (re *relayImageEdits) getPromptTokens() (int, error) {
	return common.CountTokenImage(re.request)
}

// send forwards to an image-edit-capable provider. done=true marks
// failures that must not be retried.
func (re *relayImageEdits) send() (*types.OpenAIErrorWithStatusCode, bool) {
	imageProvider, ok := re.provider.(providersBase.ImageEditsInterface)
	if !ok {
		return common.StringErrorWrapper("channel not implemented", "channel_error", http.StatusServiceUnavailable), true
	}
	re.request.Model = re.modelName
	resp, apiErr := imageProvider.CreateImageEdits(&re.request)
	if apiErr != nil {
		return apiErr, false
	}
	apiErr = responseJsonClient(re.c, resp)
	return apiErr, apiErr != nil
}

View File

@@ -1,74 +0,0 @@
package relay
import (
"net/http"
"one-api/common"
providersBase "one-api/providers/base"
"one-api/types"
"github.com/gin-gonic/gin"
)
// relayImageGenerations relays image generation requests
// (/v1/images/generations).
type relayImageGenerations struct {
	relayBase
	request types.ImageRequest
}

// NewRelayImageGenerations returns an image-generation relay bound to c.
func NewRelayImageGenerations(c *gin.Context) *relayImageGenerations {
	return &relayImageGenerations{relayBase: relayBase{c: c}}
}

// setRequest parses the body and fills in defaults for any omitted
// optional fields.
func (rg *relayImageGenerations) setRequest() error {
	if err := common.UnmarshalBodyReusable(rg.c, &rg.request); err != nil {
		return err
	}
	if rg.request.Model == "" {
		rg.request.Model = "dall-e-2"
	}
	if rg.request.N == 0 {
		rg.request.N = 1
	}
	if rg.request.Size == "" {
		rg.request.Size = "1024x1024"
	}
	if rg.request.Quality == "" {
		rg.request.Quality = "standard"
	}
	rg.originalModel = rg.request.Model
	return nil
}

// getPromptTokens prices the request from its image parameters.
func (rg *relayImageGenerations) getPromptTokens() (int, error) {
	return common.CountTokenImage(rg.request)
}

// send forwards to an image-generation-capable provider. done=true marks
// failures that must not be retried.
func (rg *relayImageGenerations) send() (*types.OpenAIErrorWithStatusCode, bool) {
	imageProvider, ok := rg.provider.(providersBase.ImageGenerationsInterface)
	if !ok {
		return common.StringErrorWrapper("channel not implemented", "channel_error", http.StatusServiceUnavailable), true
	}
	rg.request.Model = rg.modelName
	resp, apiErr := imageProvider.CreateImageGenerations(&rg.request)
	if apiErr != nil {
		return apiErr, false
	}
	apiErr = responseJsonClient(rg.c, resp)
	return apiErr, apiErr != nil
}

View File

@@ -1,66 +0,0 @@
package relay
import (
"net/http"
"one-api/common"
providersBase "one-api/providers/base"
"one-api/types"
"github.com/gin-gonic/gin"
)
// relayImageVariations relays image variation requests
// (/v1/images/variations).
type relayImageVariations struct {
	relayBase
	request types.ImageEditRequest
}

// NewRelayImageVariations returns an image-variation relay bound to c.
func NewRelayImageVariations(c *gin.Context) *relayImageVariations {
	return &relayImageVariations{relayBase: relayBase{c: c}}
}

// setRequest parses the body and fills in the default model and size
// when omitted.
func (rv *relayImageVariations) setRequest() error {
	if err := common.UnmarshalBodyReusable(rv.c, &rv.request); err != nil {
		return err
	}
	if rv.request.Model == "" {
		rv.request.Model = "dall-e-2"
	}
	if rv.request.Size == "" {
		rv.request.Size = "1024x1024"
	}
	rv.originalModel = rv.request.Model
	return nil
}

// getPromptTokens prices the request from its image parameters.
func (rv *relayImageVariations) getPromptTokens() (int, error) {
	return common.CountTokenImage(rv.request)
}

// send forwards to an image-variation-capable provider. done=true marks
// failures that must not be retried.
func (rv *relayImageVariations) send() (*types.OpenAIErrorWithStatusCode, bool) {
	imageProvider, ok := rv.provider.(providersBase.ImageVariationsInterface)
	if !ok {
		return common.StringErrorWrapper("channel not implemented", "channel_error", http.StatusServiceUnavailable), true
	}
	rv.request.Model = rv.modelName
	resp, apiErr := imageProvider.CreateImageVariations(&rv.request)
	if apiErr != nil {
		return apiErr, false
	}
	apiErr = responseJsonClient(rv.c, resp)
	return apiErr, apiErr != nil
}

View File

@@ -1,106 +0,0 @@
package relay
import (
"fmt"
"net/http"
"one-api/common"
"one-api/model"
"one-api/types"
"github.com/gin-gonic/gin"
)
// Relay is the main HTTP entry point for OpenAI-compatible requests. It
// resolves the concrete relay for the URL path, selects a provider, and
// dispatches the request, retrying on retriable failures while cooling
// down each failing channel.
func Relay(c *gin.Context) {
	relay := Path2Relay(c, c.Request.URL.Path)
	if relay == nil {
		common.AbortWithMessage(c, http.StatusNotFound, "Not Found")
		return
	}
	if err := relay.setRequest(); err != nil {
		common.AbortWithMessage(c, http.StatusBadRequest, err.Error())
		return
	}
	if err := relay.setProvider(relay.getOriginalModel()); err != nil {
		common.AbortWithMessage(c, http.StatusServiceUnavailable, err.Error())
		return
	}
	apiErr, done := RelayHandler(relay)
	if apiErr == nil {
		return
	}
	// First attempt failed: report the channel error asynchronously and
	// decide whether retrying on another channel makes sense.
	channel := relay.getProvider().GetChannel()
	go processChannelRelayError(c.Request.Context(), channel.Id, channel.Name, apiErr)
	retryTimes := common.RetryTimes
	if done || !shouldRetry(c, apiErr.StatusCode) {
		common.LogError(c.Request.Context(), fmt.Sprintf("relay error happen, status code is %d, won't retry in this case", apiErr.StatusCode))
		retryTimes = 0
	}
	for i := retryTimes; i > 0; i-- {
		// Cool down the channel that just failed so the balancer avoids
		// it for a while before selecting the next one.
		model.ChannelGroup.Cooldowns(channel.Id)
		if err := relay.setProvider(relay.getOriginalModel()); err != nil {
			continue
		}
		channel = relay.getProvider().GetChannel()
		common.LogError(c.Request.Context(), fmt.Sprintf("using channel #%d(%s) to retry (remain times %d)", channel.Id, channel.Name, i))
		apiErr, done = RelayHandler(relay)
		if apiErr == nil {
			return
		}
		go processChannelRelayError(c.Request.Context(), channel.Id, channel.Name, apiErr)
		if done || !shouldRetry(c, apiErr.StatusCode) {
			break
		}
	}
	// All attempts exhausted: return the last error to the client, with
	// the request id appended for traceability.
	if apiErr != nil {
		requestId := c.GetString(common.RequestIdKey)
		if apiErr.StatusCode == http.StatusTooManyRequests {
			apiErr.OpenAIError.Message = "当前分组上游负载已饱和,请稍后再试"
		}
		apiErr.OpenAIError.Message = common.MessageWithRequestId(apiErr.OpenAIError.Message, requestId)
		c.JSON(apiErr.StatusCode, gin.H{
			"error": apiErr.OpenAIError,
		})
	}
}
// RelayHandler runs the common relay pipeline: count prompt tokens,
// pre-reserve quota, dispatch the request, then settle or refund quota.
// done reports whether a failure is final (retrying cannot help).
func RelayHandler(relay RelayBaseInterface) (err *types.OpenAIErrorWithStatusCode, done bool) {
	// Fix: local variable was misspelled "tonkeErr".
	promptTokens, tokenErr := relay.getPromptTokens()
	if tokenErr != nil {
		err = common.ErrorWrapper(tokenErr, "token_error", http.StatusBadRequest)
		done = true
		return
	}
	// Usage is shared with the provider so it can fill in completion
	// tokens during/after the upstream call.
	usage := &types.Usage{
		PromptTokens: promptTokens,
	}
	relay.getProvider().SetUsage(usage)

	var quotaInfo *QuotaInfo
	quotaInfo, err = generateQuotaInfo(relay.getContext(), relay.getModelName(), promptTokens)
	if err != nil {
		done = true
		return
	}

	err, done = relay.send()
	if err != nil {
		// The request failed: return any pre-consumed quota.
		quotaInfo.undo(relay.getContext())
		return
	}
	quotaInfo.consume(relay.getContext(), usage)
	return
}

View File

@@ -1,62 +0,0 @@
package relay
import (
"net/http"
"one-api/common"
providersBase "one-api/providers/base"
"one-api/types"
"github.com/gin-gonic/gin"
)
type relayModerations struct {
relayBase
request types.ModerationRequest
}
func NewRelayModerations(c *gin.Context) *relayModerations {
relay := &relayModerations{}
relay.c = c
return relay
}
func (r *relayModerations) setRequest() error {
if err := common.UnmarshalBodyReusable(r.c, &r.request); err != nil {
return err
}
if r.request.Model == "" {
r.request.Model = "text-moderation-stable"
}
r.originalModel = r.request.Model
return nil
}
func (r *relayModerations) getPromptTokens() (int, error) {
return common.CountTokenInput(r.request.Input, r.modelName), nil
}
func (r *relayModerations) send() (err *types.OpenAIErrorWithStatusCode, done bool) {
provider, ok := r.provider.(providersBase.ModerationInterface)
if !ok {
err = common.StringErrorWrapper("channel not implemented", "channel_error", http.StatusServiceUnavailable)
done = true
return
}
r.request.Model = r.modelName
response, err := provider.CreateModeration(&r.request)
if err != nil {
return
}
err = responseJsonClient(r.c, response)
if err != nil {
done = true
}
return
}

View File

@@ -1,169 +0,0 @@
package relay
import (
"context"
"errors"
"fmt"
"math"
"net/http"
"one-api/common"
"one-api/model"
"one-api/types"
"time"
"github.com/gin-gonic/gin"
)
// QuotaInfo tracks the billing state of a single relay request: the
// pricing ratios, the quota reserved up front, and the ids needed to
// settle the bill afterwards.
type QuotaInfo struct {
	modelName         string
	promptTokens      int
	preConsumedTokens int
	// modelRatio holds [0]=prompt ratio, [1]=completion ratio.
	modelRatio       []float64
	groupRatio       float64
	ratio            float64
	preConsumedQuota int
	userId           int
	channelId        int
	tokenId          int
	// HandelStatus is true once quota was actually pre-consumed from the
	// token, i.e. there is something to refund on failure. (The name is a
	// historical misspelling of "HandleStatus", kept for compatibility.)
	HandelStatus bool
}

// generateQuotaInfo builds the QuotaInfo for the current request and
// immediately reserves (pre-consumes) quota; on reservation failure it
// returns the OpenAI-style error to send to the client.
func generateQuotaInfo(c *gin.Context, modelName string, promptTokens int) (*QuotaInfo, *types.OpenAIErrorWithStatusCode) {
	quotaInfo := &QuotaInfo{
		modelName:    modelName,
		promptTokens: promptTokens,
		userId:       c.GetInt("id"),
		channelId:    c.GetInt("channel_id"),
		tokenId:      c.GetInt("token_id"),
		HandelStatus: false,
	}
	quotaInfo.initQuotaInfo(c.GetString("group"))
	errWithCode := quotaInfo.preQuotaConsumption()
	if errWithCode != nil {
		return nil, errWithCode
	}
	return quotaInfo, nil
}

// initQuotaInfo derives the pricing ratios for the model/group pair and
// the amount of quota to reserve up front.
func (q *QuotaInfo) initQuotaInfo(groupName string) {
	modelRatio := common.GetModelRatio(q.modelName)
	groupRatio := common.GetGroupRatio(groupName)
	preConsumedTokens := common.PreConsumedQuota
	// Effective prompt price = model prompt ratio × group ratio.
	ratio := modelRatio[0] * groupRatio
	preConsumedQuota := int(float64(q.promptTokens+preConsumedTokens) * ratio)
	q.preConsumedTokens = preConsumedTokens
	q.modelRatio = modelRatio
	q.groupRatio = groupRatio
	q.ratio = ratio
	q.preConsumedQuota = preConsumedQuota
}
// preQuotaConsumption reserves quota before the upstream call. It checks
// the cached user balance, optimistically decreases the cache, and —
// unless the balance dwarfs the reservation — pre-consumes quota from the
// token, setting HandelStatus so the reservation can be refunded later.
func (q *QuotaInfo) preQuotaConsumption() *types.OpenAIErrorWithStatusCode {
	userQuota, err := model.CacheGetUserQuota(q.userId)
	if err != nil {
		return common.ErrorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError)
	}
	if userQuota < q.preConsumedQuota {
		return common.ErrorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden)
	}
	// NOTE(review): the cache is decreased even when preConsumedQuota is
	// zeroed just below — confirm this asymmetry is intentional.
	err = model.CacheDecreaseUserQuota(q.userId, q.preConsumedQuota)
	if err != nil {
		return common.ErrorWrapper(err, "decrease_user_quota_failed", http.StatusInternalServerError)
	}
	if userQuota > 100*q.preConsumedQuota {
		// in this case, we do not pre-consume quota
		// because the user has enough quota
		q.preConsumedQuota = 0
		// common.LogInfo(c.Request.Context(), fmt.Sprintf("user %d has enough quota %d, trusted and no need to pre-consume", userId, userQuota))
	}
	if q.preConsumedQuota > 0 {
		err := model.PreConsumeTokenQuota(q.tokenId, q.preConsumedQuota)
		if err != nil {
			return common.ErrorWrapper(err, "pre_consume_token_quota_failed", http.StatusForbidden)
		}
		// Mark that a real reservation exists and must be refunded on failure.
		q.HandelStatus = true
	}
	return nil
}
// completedQuotaConsumption settles the final bill for a finished request:
// it computes the real quota from prompt/completion tokens, charges (or
// refunds) the difference relative to the pre-consumed amount, refreshes
// the user quota cache, and records the consumption log.
func (q *QuotaInfo) completedQuotaConsumption(usage *types.Usage, tokenName string, ctx context.Context) error {
	quota := 0
	// Completion tokens are priced with their own ratio (modelRatio[1]).
	completionRatio := q.modelRatio[1] * q.groupRatio
	promptTokens := usage.PromptTokens
	completionTokens := usage.CompletionTokens
	quota = int(math.Ceil(((float64(promptTokens) * q.ratio) + (float64(completionTokens) * completionRatio))))
	// Charge at least 1 unit when the model is not free.
	if q.ratio != 0 && quota <= 0 {
		quota = 1
	}
	totalTokens := promptTokens + completionTokens
	if totalTokens == 0 {
		// in this case, must be some error happened
		// we cannot just return, because we may have to return the pre-consumed quota
		quota = 0
	}
	// Only the difference to the reservation is charged (or refunded) here.
	quotaDelta := quota - q.preConsumedQuota
	err := model.PostConsumeTokenQuota(q.tokenId, quotaDelta)
	if err != nil {
		return errors.New("error consuming token remain quota: " + err.Error())
	}
	err = model.CacheUpdateUserQuota(q.userId)
	if err != nil {
		return errors.New("error consuming token remain quota: " + err.Error())
	}
	if quota != 0 {
		requestTime := 0
		requestStartTimeValue := ctx.Value("requestStartTime")
		if requestStartTimeValue != nil {
			requestStartTime, ok := requestStartTimeValue.(time.Time)
			if ok {
				requestTime = int(time.Since(requestStartTime).Milliseconds())
			}
		}
		var modelRatioStr string
		if q.modelRatio[0] == q.modelRatio[1] {
			modelRatioStr = fmt.Sprintf("%.2f", q.modelRatio[0])
		} else {
			modelRatioStr = fmt.Sprintf("%.2f (输入)/%.2f (输出)", q.modelRatio[0], q.modelRatio[1])
		}
		// NOTE(review): the format string below has no separator between
		// the two ratios — confirm whether a comma was intended; the
		// string is kept verbatim here.
		logContent := fmt.Sprintf("模型倍率 %s分组倍率 %.2f", modelRatioStr, q.groupRatio)
		model.RecordConsumeLog(ctx, q.userId, q.channelId, promptTokens, completionTokens, q.modelName, tokenName, quota, logContent, requestTime)
		model.UpdateUserUsedQuotaAndRequestCount(q.userId, quota)
		model.UpdateChannelUsedQuota(q.channelId, quota)
	}
	return nil
}
// undo refunds the pre-consumed quota after a failed request. It is a
// no-op unless preQuotaConsumption actually reserved quota (HandelStatus).
// The refund runs asynchronously.
func (q *QuotaInfo) undo(c *gin.Context) {
	tokenId := c.GetInt("token_id")
	if q.HandelStatus {
		go func(ctx context.Context) {
			// return pre-consumed quota
			err := model.PostConsumeTokenQuota(tokenId, -q.preConsumedQuota)
			if err != nil {
				common.LogError(ctx, "error return pre-consumed quota: "+err.Error())
			}
		}(c.Request.Context())
	}
}

// consume settles the final quota for a successful request in the
// background so the response is not delayed by billing; errors are logged.
func (q *QuotaInfo) consume(c *gin.Context, usage *types.Usage) {
	tokenName := c.GetString("token_name")
	// No error occurred, so consume the quota.
	go func(ctx context.Context) {
		err := q.completedQuotaConsumption(usage, tokenName, ctx)
		if err != nil {
			common.LogError(ctx, err.Error())
		}
	}(c.Request.Context())
}

View File

@@ -1,58 +0,0 @@
package relay
import (
"net/http"
"one-api/common"
providersBase "one-api/providers/base"
"one-api/types"
"github.com/gin-gonic/gin"
)
// relaySpeech relays text-to-speech requests (/v1/audio/speech).
type relaySpeech struct {
	relayBase
	request types.SpeechAudioRequest
}

// NewRelaySpeech returns a speech relay bound to the given context.
func NewRelaySpeech(c *gin.Context) *relaySpeech {
	return &relaySpeech{relayBase: relayBase{c: c}}
}

// setRequest parses the body and records the requested model.
func (rs *relaySpeech) setRequest() error {
	if err := common.UnmarshalBodyReusable(rs.c, &rs.request); err != nil {
		return err
	}
	rs.originalModel = rs.request.Model
	return nil
}

// getPromptTokens bills speech by the byte length of the input text.
func (rs *relaySpeech) getPromptTokens() (int, error) {
	return len(rs.request.Input), nil
}

// send forwards to a speech-capable provider and replays the audio
// response to the client. done=true marks non-retriable failures.
func (rs *relaySpeech) send() (*types.OpenAIErrorWithStatusCode, bool) {
	speechProvider, ok := rs.provider.(providersBase.SpeechInterface)
	if !ok {
		return common.StringErrorWrapper("channel not implemented", "channel_error", http.StatusServiceUnavailable), true
	}
	rs.request.Model = rs.modelName
	resp, apiErr := speechProvider.CreateSpeech(&rs.request)
	if apiErr != nil {
		return apiErr, false
	}
	apiErr = responseMultipart(rs.c, resp)
	return apiErr, apiErr != nil
}

View File

@@ -1,58 +0,0 @@
package relay
import (
"net/http"
"one-api/common"
providersBase "one-api/providers/base"
"one-api/types"
"github.com/gin-gonic/gin"
)
// relayTranscriptions relays audio transcription requests
// (/v1/audio/transcriptions).
type relayTranscriptions struct {
	relayBase
	request types.AudioRequest
}

// NewRelayTranscriptions returns a transcription relay bound to the context.
func NewRelayTranscriptions(c *gin.Context) *relayTranscriptions {
	return &relayTranscriptions{relayBase: relayBase{c: c}}
}

// setRequest parses the body and records the requested model.
func (rt *relayTranscriptions) setRequest() error {
	if err := common.UnmarshalBodyReusable(rt.c, &rt.request); err != nil {
		return err
	}
	rt.originalModel = rt.request.Model
	return nil
}

// getPromptTokens: transcription requests are not billed by prompt tokens.
func (rt *relayTranscriptions) getPromptTokens() (int, error) {
	return 0, nil
}

// send forwards to a transcription-capable provider and replays its
// response verbatim. done=true marks non-retriable failures.
func (rt *relayTranscriptions) send() (*types.OpenAIErrorWithStatusCode, bool) {
	transProvider, ok := rt.provider.(providersBase.TranscriptionsInterface)
	if !ok {
		return common.StringErrorWrapper("channel not implemented", "channel_error", http.StatusServiceUnavailable), true
	}
	rt.request.Model = rt.modelName
	resp, apiErr := transProvider.CreateTranscriptions(&rt.request)
	if apiErr != nil {
		return apiErr, false
	}
	apiErr = responseCustom(rt.c, resp)
	return apiErr, apiErr != nil
}

View File

@@ -1,58 +0,0 @@
package relay
import (
"net/http"
"one-api/common"
providersBase "one-api/providers/base"
"one-api/types"
"github.com/gin-gonic/gin"
)
// relayTranslations relays audio translation requests
// (/v1/audio/translations).
type relayTranslations struct {
	relayBase
	request types.AudioRequest
}

// NewRelayTranslations returns a translation relay bound to the context.
func NewRelayTranslations(c *gin.Context) *relayTranslations {
	return &relayTranslations{relayBase: relayBase{c: c}}
}

// setRequest parses the body and records the requested model.
func (rt *relayTranslations) setRequest() error {
	if err := common.UnmarshalBodyReusable(rt.c, &rt.request); err != nil {
		return err
	}
	rt.originalModel = rt.request.Model
	return nil
}

// getPromptTokens: translation requests are not billed by prompt tokens.
func (rt *relayTranslations) getPromptTokens() (int, error) {
	return 0, nil
}

// send forwards to a translation-capable provider and replays its
// response verbatim. done=true marks non-retriable failures.
func (rt *relayTranslations) send() (*types.OpenAIErrorWithStatusCode, bool) {
	transProvider, ok := rt.provider.(providersBase.TranslationInterface)
	if !ok {
		return common.StringErrorWrapper("channel not implemented", "channel_error", http.StatusServiceUnavailable), true
	}
	rt.request.Model = rt.modelName
	resp, apiErr := transProvider.CreateTranslation(&rt.request)
	if apiErr != nil {
		return apiErr, false
	}
	apiErr = responseCustom(rt.c, resp)
	return apiErr, apiErr != nil
}

View File

@@ -1,202 +0,0 @@
package relay
import (
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"one-api/common"
"one-api/common/requester"
"one-api/controller"
"one-api/model"
"one-api/providers"
providersBase "one-api/providers/base"
"one-api/types"
"strings"
"github.com/gin-gonic/gin"
)
// Path2Relay maps an inbound API path to the relay implementation that
// handles it, or nil when the path is not a supported relay endpoint.
// (Rewritten from a long if/else-if chain to an idiomatic bare switch;
// prefix order is preserved.)
func Path2Relay(c *gin.Context, path string) RelayBaseInterface {
	switch {
	case strings.HasPrefix(path, "/v1/chat/completions"):
		return NewRelayChat(c)
	case strings.HasPrefix(path, "/v1/completions"):
		return NewRelayCompletions(c)
	case strings.HasPrefix(path, "/v1/embeddings"):
		return NewRelayEmbeddings(c)
	case strings.HasPrefix(path, "/v1/moderations"):
		return NewRelayModerations(c)
	case strings.HasPrefix(path, "/v1/images/generations"):
		return NewRelayImageGenerations(c)
	case strings.HasPrefix(path, "/v1/images/edits"):
		return NewRelayImageEdits(c)
	case strings.HasPrefix(path, "/v1/images/variations"):
		return NewRelayImageVariations(c)
	case strings.HasPrefix(path, "/v1/audio/speech"):
		return NewRelaySpeech(c)
	case strings.HasPrefix(path, "/v1/audio/transcriptions"):
		return NewRelayTranscriptions(c)
	case strings.HasPrefix(path, "/v1/audio/translations"):
		return NewRelayTranslations(c)
	}
	return nil
}
// getProvider selects a channel for modelName, builds its provider, and
// applies the channel's model mapping. It returns the provider and the
// mapped model name to use upstream.
// (Fixes the misspelled parameter "modeName" and replaces naked returns
// with explicit ones; the signature remains positionally identical.)
func getProvider(c *gin.Context, modelName string) (providersBase.ProviderInterface, string, error) {
	channel, err := fetchChannel(c, modelName)
	if err != nil {
		return nil, "", err
	}
	// Expose the chosen channel to downstream billing/logging.
	c.Set("channel_id", channel.Id)

	provider := providers.GetProvider(channel, c)
	if provider == nil {
		return nil, "", errors.New("channel not found")
	}
	provider.SetOriginalModel(modelName)

	newModelName, err := provider.ModelMappingHandler(modelName)
	if err != nil {
		return nil, "", err
	}
	return provider, newModelName, nil
}
// fetchChannel picks the channel for this request: a pinned channel when
// specific_channel_id is set, otherwise one selected by group and model.
func fetchChannel(c *gin.Context, modelName string) (*model.Channel, error) {
	if id := c.GetInt("specific_channel_id"); id > 0 {
		return fetchChannelById(id)
	}
	return fetchChannelByModel(c, modelName)
}

// fetchChannelById loads a specific channel and ensures it is enabled.
func fetchChannelById(channelId int) (*model.Channel, error) {
	ch, err := model.GetChannelById(channelId, true)
	if err != nil {
		return nil, errors.New("无效的渠道 Id")
	}
	if ch.Status != common.ChannelStatusEnabled {
		return nil, errors.New("该渠道已被禁用")
	}
	return ch, nil
}

// fetchChannelByModel asks the channel-group balancer for the next
// channel serving the given model within the caller's group.
func fetchChannelByModel(c *gin.Context, modelName string) (*model.Channel, error) {
	group := c.GetString("group")
	ch, err := model.ChannelGroup.Next(group, modelName)
	if err == nil {
		return ch, nil
	}
	msg := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", group, modelName)
	if ch != nil {
		common.SysError(fmt.Sprintf("渠道不存在:%d", ch.Id))
		msg = "数据库一致性已被破坏,请联系管理员"
	}
	return nil, errors.New(msg)
}
// responseJsonClient serializes data as JSON and writes it to the client
// with a 200 status and a JSON content type.
func responseJsonClient(c *gin.Context, data interface{}) *types.OpenAIErrorWithStatusCode {
	body, err := json.Marshal(data)
	if err != nil {
		return common.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError)
	}
	c.Writer.Header().Set("Content-Type", "application/json")
	c.Writer.WriteHeader(http.StatusOK)
	if _, err := c.Writer.Write(body); err != nil {
		return common.ErrorWrapper(err, "write_response_body_failed", http.StatusInternalServerError)
	}
	return nil
}
// responseStreamClient forwards a provider SSE stream to the client,
// emitting each chunk as a "data:" event and terminating with
// "data: [DONE]". It always returns nil: transport-level errors are
// surfaced to the client inside the stream itself.
func responseStreamClient(c *gin.Context, stream requester.StreamReaderInterface[string]) *types.OpenAIErrorWithStatusCode {
	requester.SetEventStreamHeaders(c)
	dataChan, errChan := stream.Recv()
	defer stream.Close()
	c.Stream(func(w io.Writer) bool {
		select {
		case data := <-dataChan:
			// Fprintln plus the embedded "\n" produces the blank line that
			// terminates an SSE event.
			fmt.Fprintln(w, "data: "+data+"\n")
			return true
		case err := <-errChan:
			// io.EOF signals normal end of stream; anything else is sent
			// to the client as a data event before closing.
			if !errors.Is(err, io.EOF) {
				fmt.Fprintln(w, "data: "+err.Error()+"\n")
			}
			fmt.Fprintln(w, "data: [DONE]")
			return false
		}
	})
	return nil
}
// responseMultipart streams an upstream HTTP response (headers, status,
// body) back to the client unchanged.
func responseMultipart(c *gin.Context, resp *http.Response) *types.OpenAIErrorWithStatusCode {
	defer resp.Body.Close()
	// Bug fix: copy every value of every upstream header, not just the
	// first, so multi-valued headers (e.g. Set-Cookie) are not dropped.
	for k, values := range resp.Header {
		for _, v := range values {
			c.Writer.Header().Add(k, v)
		}
	}
	c.Writer.WriteHeader(resp.StatusCode)
	if _, err := io.Copy(c.Writer, resp.Body); err != nil {
		return common.ErrorWrapper(err, "write_response_body_failed", http.StatusInternalServerError)
	}
	return nil
}
// responseCustom replays a provider's wrapped audio response to the
// client: its headers first, then the raw body with a 200 status.
func responseCustom(c *gin.Context, response *types.AudioResponseWrapper) *types.OpenAIErrorWithStatusCode {
	for name, value := range response.Headers {
		c.Writer.Header().Set(name, value)
	}
	c.Writer.WriteHeader(http.StatusOK)
	if _, err := c.Writer.Write(response.Body); err != nil {
		return common.ErrorWrapper(err, "write_response_body_failed", http.StatusInternalServerError)
	}
	return nil
}
// shouldRetry reports whether a failed relay attempt may be retried on
// another channel, based on the upstream status code. Requests pinned to
// a specific channel are never retried.
func shouldRetry(c *gin.Context, statusCode int) bool {
	if c.GetInt("specific_channel_id") > 0 {
		return false
	}
	switch {
	case statusCode == http.StatusTooManyRequests:
		return true
	case statusCode/100 == 5:
		// Server-side errors are assumed transient.
		return true
	case statusCode == http.StatusBadRequest:
		return false
	case statusCode/100 == 2:
		return false
	default:
		return true
	}
}
// processChannelRelayError logs a relay failure for the given channel and
// disables the channel when ShouldDisableChannel deems the error fatal.
func processChannelRelayError(ctx context.Context, channelId int, channelName string, err *types.OpenAIErrorWithStatusCode) {
	common.LogError(ctx, fmt.Sprintf("relay error (channel #%d(%s)): %s", channelId, channelName, err.Message))
	if controller.ShouldDisableChannel(&err.OpenAIError, err.StatusCode) {
		controller.DisableChannel(channelId, channelName, err.Message)
	}
}