Compare commits

..

18 Commits

Author SHA1 Message Date
JustSong
3af4649b52 fix: only check model when request path in whitelist 2024-04-06 20:42:35 +08:00
GAI Group
52c32c0b4a chore: resolve the issue of onclick event scope for custom Lark button (#1281)
chore: Resolve the issue of onclick event scope for custom Lark button
2024-04-06 20:08:05 +08:00
Buer
3fe2863ff7 feat: berry theme update & bug fix (#1282)
* ️ improve: delete google fonts

* ️ improve: Optimized priority input handling in TableRow component.

* 🔖 chore: channel batch add

*  feat: add dark mod

*  feat: support token limit ip range and models

*  feat: add MessagePusher

*  feat: add lark login
2024-04-06 19:44:23 +08:00
JustSong
acf8cb6248 chore: update default nextweb link 2024-04-06 11:47:31 +08:00
JustSong
572fc9ffb8 fix: fix stepfun model ratio & id 2024-04-06 10:43:54 +08:00
GAI Group
569c04acb0 fix: fix Lark icon button style (#1279) 2024-04-06 10:18:59 +08:00
JustSong
961b4108e6 chore: fix refactor caused typo 2024-04-06 02:12:50 +08:00
JustSong
0b8ccb94eb chore: reorganize common package 2024-04-06 02:03:59 +08:00
JustSong
f586ae0ad8 chore: remove helper & util subpackage for relay 2024-04-06 01:50:12 +08:00
JustSong
24ed170e7b chore: reorganize adaptor related package 2024-04-06 01:36:48 +08:00
JustSong
f70506eac1 chore: reorganize relay related package 2024-04-06 01:31:44 +08:00
JustSong
8f4d78e24d chore: reorganize billing related package 2024-04-06 01:26:48 +08:00
JustSong
cd2707692f chore: reorganize billing related package 2024-04-06 01:09:23 +08:00
JustSong
2ab7d25a80 chore: reorganize helper related package 2024-04-06 01:02:35 +08:00
JustSong
f9d914873f chore: reorganize constant related package 2024-04-06 00:44:33 +08:00
JustSong
880e12c855 feat: support cogview-3 2024-04-06 00:30:08 +08:00
JustSong
0cb224e62e chore: fix typo 2024-04-05 23:55:25 +08:00
JustSong
a44fb5d482 fix: fix channel model list is empty 2024-04-05 23:44:57 +08:00
171 changed files with 2279 additions and 1510 deletions

9
common/config/key.go Normal file
View File

@@ -0,0 +1,9 @@
package config
const (
KeyPrefix = "cfg_"
KeyAPIVersion = KeyPrefix + "api_version"
KeyLibraryID = KeyPrefix + "library_id"
KeyPlugin = KeyPrefix + "plugin"
)

View File

@@ -4,118 +4,3 @@ import "time"
var StartTime = time.Now().Unix() // unit: second var StartTime = time.Now().Unix() // unit: second
var Version = "v0.0.0" // this hard coding will be replaced automatically when building, no need to manually change var Version = "v0.0.0" // this hard coding will be replaced automatically when building, no need to manually change
const (
RoleGuestUser = 0
RoleCommonUser = 1
RoleAdminUser = 10
RoleRootUser = 100
)
const (
UserStatusEnabled = 1 // don't use 0, 0 is the default value!
UserStatusDisabled = 2 // also don't use 0
UserStatusDeleted = 3
)
const (
TokenStatusEnabled = 1 // don't use 0, 0 is the default value!
TokenStatusDisabled = 2 // also don't use 0
TokenStatusExpired = 3
TokenStatusExhausted = 4
)
const (
RedemptionCodeStatusEnabled = 1 // don't use 0, 0 is the default value!
RedemptionCodeStatusDisabled = 2 // also don't use 0
RedemptionCodeStatusUsed = 3 // also don't use 0
)
const (
ChannelStatusUnknown = 0
ChannelStatusEnabled = 1 // don't use 0, 0 is the default value!
ChannelStatusManuallyDisabled = 2 // also don't use 0
ChannelStatusAutoDisabled = 3
)
const (
ChannelTypeUnknown = iota
ChannelTypeOpenAI
ChannelTypeAPI2D
ChannelTypeAzure
ChannelTypeCloseAI
ChannelTypeOpenAISB
ChannelTypeOpenAIMax
ChannelTypeOhMyGPT
ChannelTypeCustom
ChannelTypeAILS
ChannelTypeAIProxy
ChannelTypePaLM
ChannelTypeAPI2GPT
ChannelTypeAIGC2D
ChannelTypeAnthropic
ChannelTypeBaidu
ChannelTypeZhipu
ChannelTypeAli
ChannelTypeXunfei
ChannelType360
ChannelTypeOpenRouter
ChannelTypeAIProxyLibrary
ChannelTypeFastGPT
ChannelTypeTencent
ChannelTypeGemini
ChannelTypeMoonshot
ChannelTypeBaichuan
ChannelTypeMinimax
ChannelTypeMistral
ChannelTypeGroq
ChannelTypeOllama
ChannelTypeLingYiWanWu
ChannelTypeStepFun
ChannelTypeDummy
)
var ChannelBaseURLs = []string{
"", // 0
"https://api.openai.com", // 1
"https://oa.api2d.net", // 2
"", // 3
"https://api.closeai-proxy.xyz", // 4
"https://api.openai-sb.com", // 5
"https://api.openaimax.com", // 6
"https://api.ohmygpt.com", // 7
"", // 8
"https://api.caipacity.com", // 9
"https://api.aiproxy.io", // 10
"https://generativelanguage.googleapis.com", // 11
"https://api.api2gpt.com", // 12
"https://api.aigc2d.com", // 13
"https://api.anthropic.com", // 14
"https://aip.baidubce.com", // 15
"https://open.bigmodel.cn", // 16
"https://dashscope.aliyuncs.com", // 17
"", // 18
"https://ai.360.cn", // 19
"https://openrouter.ai/api", // 20
"https://api.aiproxy.io", // 21
"https://fastgpt.run/api/openapi", // 22
"https://hunyuan.cloud.tencent.com", // 23
"https://generativelanguage.googleapis.com", // 24
"https://api.moonshot.cn", // 25
"https://api.baichuan-ai.com", // 26
"https://api.minimax.chat", // 27
"https://api.mistral.ai", // 28
"https://api.groq.com/openai", // 29
"http://localhost:11434", // 30
"https://api.lingyiwanwu.com", // 31
"https://api.stepfun.com", // 32
}
const (
ConfigKeyPrefix = "cfg_"
ConfigKeyAPIVersion = ConfigKeyPrefix + "api_version"
ConfigKeyLibraryID = ConfigKeyPrefix + "library_id"
ConfigKeyPlugin = ConfigKeyPrefix + "plugin"
)

View File

@@ -2,16 +2,14 @@ package helper
import ( import (
"fmt" "fmt"
"github.com/google/uuid" "github.com/songquanpeng/one-api/common/random"
"html/template" "html/template"
"log" "log"
"math/rand"
"net" "net"
"os/exec" "os/exec"
"runtime" "runtime"
"strconv" "strconv"
"strings" "strings"
"time"
) )
func OpenBrowser(url string) { func OpenBrowser(url string) {
@@ -79,31 +77,6 @@ func Bytes2Size(num int64) string {
return numStr + " " + unit return numStr + " " + unit
} }
func Seconds2Time(num int) (time string) {
if num/31104000 > 0 {
time += strconv.Itoa(num/31104000) + " 年 "
num %= 31104000
}
if num/2592000 > 0 {
time += strconv.Itoa(num/2592000) + " 个月 "
num %= 2592000
}
if num/86400 > 0 {
time += strconv.Itoa(num/86400) + " 天 "
num %= 86400
}
if num/3600 > 0 {
time += strconv.Itoa(num/3600) + " 小时 "
num %= 3600
}
if num/60 > 0 {
time += strconv.Itoa(num/60) + " 分钟 "
num %= 60
}
time += strconv.Itoa(num) + " 秒"
return
}
func Interface2String(inter interface{}) string { func Interface2String(inter interface{}) string {
switch inter := inter.(type) { switch inter := inter.(type) {
case string: case string:
@@ -128,65 +101,8 @@ func IntMax(a int, b int) int {
} }
} }
func GetUUID() string {
code := uuid.New().String()
code = strings.Replace(code, "-", "", -1)
return code
}
const keyChars = "0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
const keyNumbers = "0123456789"
func init() {
rand.Seed(time.Now().UnixNano())
}
func GenerateKey() string {
rand.Seed(time.Now().UnixNano())
key := make([]byte, 48)
for i := 0; i < 16; i++ {
key[i] = keyChars[rand.Intn(len(keyChars))]
}
uuid_ := GetUUID()
for i := 0; i < 32; i++ {
c := uuid_[i]
if i%2 == 0 && c >= 'a' && c <= 'z' {
c = c - 'a' + 'A'
}
key[i+16] = c
}
return string(key)
}
func GetRandomString(length int) string {
rand.Seed(time.Now().UnixNano())
key := make([]byte, length)
for i := 0; i < length; i++ {
key[i] = keyChars[rand.Intn(len(keyChars))]
}
return string(key)
}
func GetRandomNumberString(length int) string {
rand.Seed(time.Now().UnixNano())
key := make([]byte, length)
for i := 0; i < length; i++ {
key[i] = keyNumbers[rand.Intn(len(keyNumbers))]
}
return string(key)
}
func GetTimestamp() int64 {
return time.Now().Unix()
}
func GetTimeString() string {
now := time.Now()
return fmt.Sprintf("%s%d", now.Format("20060102150405"), now.UnixNano()%1e9)
}
func GenRequestID() string { func GenRequestID() string {
return GetTimeString() + GetRandomNumberString(8) return GetTimeString() + random.GetRandomNumberString(8)
} }
func Max(a int, b int) int { func Max(a int, b int) int {

15
common/helper/time.go Normal file
View File

@@ -0,0 +1,15 @@
package helper
import (
"fmt"
"time"
)
func GetTimestamp() int64 {
return time.Now().Unix()
}
func GetTimeString() string {
now := time.Now()
return fmt.Sprintf("%s%d", now.Format("20060102150405"), now.UnixNano()%1e9)
}

View File

@@ -1,8 +0,0 @@
package common
import "math/rand"
// RandRange returns a random number between min and max (max is not included)
func RandRange(min, max int) int {
return min + rand.Intn(max-min)
}

61
common/random/main.go Normal file
View File

@@ -0,0 +1,61 @@
package random
import (
"github.com/google/uuid"
"math/rand"
"strings"
"time"
)
func GetUUID() string {
code := uuid.New().String()
code = strings.Replace(code, "-", "", -1)
return code
}
const keyChars = "0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
const keyNumbers = "0123456789"
func init() {
rand.Seed(time.Now().UnixNano())
}
func GenerateKey() string {
rand.Seed(time.Now().UnixNano())
key := make([]byte, 48)
for i := 0; i < 16; i++ {
key[i] = keyChars[rand.Intn(len(keyChars))]
}
uuid_ := GetUUID()
for i := 0; i < 32; i++ {
c := uuid_[i]
if i%2 == 0 && c >= 'a' && c <= 'z' {
c = c - 'a' + 'A'
}
key[i+16] = c
}
return string(key)
}
func GetRandomString(length int) string {
rand.Seed(time.Now().UnixNano())
key := make([]byte, length)
for i := 0; i < length; i++ {
key[i] = keyChars[rand.Intn(len(keyChars))]
}
return string(key)
}
func GetRandomNumberString(length int) string {
rand.Seed(time.Now().UnixNano())
key := make([]byte, length)
for i := 0; i < length; i++ {
key[i] = keyNumbers[rand.Intn(len(keyNumbers))]
}
return string(key)
}
// RandRange returns a random number between min and max (max is not included)
func RandRange(min, max int) int {
return min + rand.Intn(max-min)
}

View File

@@ -7,10 +7,9 @@ import (
"fmt" "fmt"
"github.com/gin-contrib/sessions" "github.com/gin-contrib/sessions"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/config" "github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/helper"
"github.com/songquanpeng/one-api/common/logger" "github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/common/random"
"github.com/songquanpeng/one-api/controller" "github.com/songquanpeng/one-api/controller"
"github.com/songquanpeng/one-api/model" "github.com/songquanpeng/one-api/model"
"net/http" "net/http"
@@ -134,8 +133,8 @@ func GitHubOAuth(c *gin.Context) {
user.DisplayName = "GitHub User" user.DisplayName = "GitHub User"
} }
user.Email = githubUser.Email user.Email = githubUser.Email
user.Role = common.RoleCommonUser user.Role = model.RoleCommonUser
user.Status = common.UserStatusEnabled user.Status = model.UserStatusEnabled
if err := user.Insert(0); err != nil { if err := user.Insert(0); err != nil {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
@@ -153,7 +152,7 @@ func GitHubOAuth(c *gin.Context) {
} }
} }
if user.Status != common.UserStatusEnabled { if user.Status != model.UserStatusEnabled {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"message": "用户已被封禁", "message": "用户已被封禁",
"success": false, "success": false,
@@ -220,7 +219,7 @@ func GitHubBind(c *gin.Context) {
func GenerateOAuthCode(c *gin.Context) { func GenerateOAuthCode(c *gin.Context) {
session := sessions.Default(c) session := sessions.Default(c)
state := helper.GetRandomString(12) state := random.GetRandomString(12)
session.Set("oauth_state", state) session.Set("oauth_state", state)
err := session.Save() err := session.Save()
if err != nil { if err != nil {

View File

@@ -7,7 +7,6 @@ import (
"fmt" "fmt"
"github.com/gin-contrib/sessions" "github.com/gin-contrib/sessions"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/config" "github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/logger" "github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/controller" "github.com/songquanpeng/one-api/controller"
@@ -123,8 +122,8 @@ func LarkOAuth(c *gin.Context) {
} else { } else {
user.DisplayName = "Lark User" user.DisplayName = "Lark User"
} }
user.Role = common.RoleCommonUser user.Role = model.RoleCommonUser
user.Status = common.UserStatusEnabled user.Status = model.UserStatusEnabled
if err := user.Insert(0); err != nil { if err := user.Insert(0); err != nil {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
@@ -142,7 +141,7 @@ func LarkOAuth(c *gin.Context) {
} }
} }
if user.Status != common.UserStatusEnabled { if user.Status != model.UserStatusEnabled {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"message": "用户已被封禁", "message": "用户已被封禁",
"success": false, "success": false,

View File

@@ -5,7 +5,6 @@ import (
"errors" "errors"
"fmt" "fmt"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/config" "github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/controller" "github.com/songquanpeng/one-api/controller"
"github.com/songquanpeng/one-api/model" "github.com/songquanpeng/one-api/model"
@@ -84,8 +83,8 @@ func WeChatAuth(c *gin.Context) {
if config.RegisterEnabled { if config.RegisterEnabled {
user.Username = "wechat_" + strconv.Itoa(model.GetMaxUserId()+1) user.Username = "wechat_" + strconv.Itoa(model.GetMaxUserId()+1)
user.DisplayName = "WeChat User" user.DisplayName = "WeChat User"
user.Role = common.RoleCommonUser user.Role = model.RoleCommonUser
user.Status = common.UserStatusEnabled user.Status = model.UserStatusEnabled
if err := user.Insert(0); err != nil { if err := user.Insert(0); err != nil {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
@@ -103,7 +102,7 @@ func WeChatAuth(c *gin.Context) {
} }
} }
if user.Status != common.UserStatusEnabled { if user.Status != model.UserStatusEnabled {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"message": "用户已被封禁", "message": "用户已被封禁",
"success": false, "success": false,

View File

@@ -4,12 +4,12 @@ import (
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/config" "github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/logger" "github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/model" "github.com/songquanpeng/one-api/model"
"github.com/songquanpeng/one-api/monitor" "github.com/songquanpeng/one-api/monitor"
"github.com/songquanpeng/one-api/relay/util" "github.com/songquanpeng/one-api/relay/channeltype"
"github.com/songquanpeng/one-api/relay/client"
"io" "io"
"net/http" "net/http"
"strconv" "strconv"
@@ -96,7 +96,7 @@ func GetResponseBody(method, url string, channel *model.Channel, headers http.He
for k := range headers { for k := range headers {
req.Header.Add(k, headers.Get(k)) req.Header.Add(k, headers.Get(k))
} }
res, err := util.HTTPClient.Do(req) res, err := client.HTTPClient.Do(req)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -204,28 +204,28 @@ func updateChannelAIGC2DBalance(channel *model.Channel) (float64, error) {
} }
func updateChannelBalance(channel *model.Channel) (float64, error) { func updateChannelBalance(channel *model.Channel) (float64, error) {
baseURL := common.ChannelBaseURLs[channel.Type] baseURL := channeltype.ChannelBaseURLs[channel.Type]
if channel.GetBaseURL() == "" { if channel.GetBaseURL() == "" {
channel.BaseURL = &baseURL channel.BaseURL = &baseURL
} }
switch channel.Type { switch channel.Type {
case common.ChannelTypeOpenAI: case channeltype.OpenAI:
if channel.GetBaseURL() != "" { if channel.GetBaseURL() != "" {
baseURL = channel.GetBaseURL() baseURL = channel.GetBaseURL()
} }
case common.ChannelTypeAzure: case channeltype.Azure:
return 0, errors.New("尚未实现") return 0, errors.New("尚未实现")
case common.ChannelTypeCustom: case channeltype.Custom:
baseURL = channel.GetBaseURL() baseURL = channel.GetBaseURL()
case common.ChannelTypeCloseAI: case channeltype.CloseAI:
return updateChannelCloseAIBalance(channel) return updateChannelCloseAIBalance(channel)
case common.ChannelTypeOpenAISB: case channeltype.OpenAISB:
return updateChannelOpenAISBBalance(channel) return updateChannelOpenAISBBalance(channel)
case common.ChannelTypeAIProxy: case channeltype.AIProxy:
return updateChannelAIProxyBalance(channel) return updateChannelAIProxyBalance(channel)
case common.ChannelTypeAPI2GPT: case channeltype.API2GPT:
return updateChannelAPI2GPTBalance(channel) return updateChannelAPI2GPTBalance(channel)
case common.ChannelTypeAIGC2D: case channeltype.AIGC2D:
return updateChannelAIGC2DBalance(channel) return updateChannelAIGC2DBalance(channel)
default: default:
return 0, errors.New("尚未实现") return 0, errors.New("尚未实现")
@@ -301,11 +301,11 @@ func updateAllChannelsBalance() error {
return err return err
} }
for _, channel := range channels { for _, channel := range channels {
if channel.Status != common.ChannelStatusEnabled { if channel.Status != model.ChannelStatusEnabled {
continue continue
} }
// TODO: support Azure // TODO: support Azure
if channel.Type != common.ChannelTypeOpenAI && channel.Type != common.ChannelTypeCustom { if channel.Type != channeltype.OpenAI && channel.Type != channeltype.Custom {
continue continue
} }
balance, err := updateChannelBalance(channel) balance, err := updateChannelBalance(channel)

View File

@@ -5,17 +5,18 @@ import (
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/config" "github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/logger" "github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/common/message" "github.com/songquanpeng/one-api/common/message"
"github.com/songquanpeng/one-api/middleware" "github.com/songquanpeng/one-api/middleware"
"github.com/songquanpeng/one-api/model" "github.com/songquanpeng/one-api/model"
"github.com/songquanpeng/one-api/monitor" "github.com/songquanpeng/one-api/monitor"
"github.com/songquanpeng/one-api/relay/constant" relay "github.com/songquanpeng/one-api/relay"
"github.com/songquanpeng/one-api/relay/helper" "github.com/songquanpeng/one-api/relay/channeltype"
"github.com/songquanpeng/one-api/relay/controller"
"github.com/songquanpeng/one-api/relay/meta"
relaymodel "github.com/songquanpeng/one-api/relay/model" relaymodel "github.com/songquanpeng/one-api/relay/model"
"github.com/songquanpeng/one-api/relay/util" "github.com/songquanpeng/one-api/relay/relaymode"
"io" "io"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
@@ -56,9 +57,9 @@ func testChannel(channel *model.Channel) (err error, openaiErr *relaymodel.Error
c.Set("channel", channel.Type) c.Set("channel", channel.Type)
c.Set("base_url", channel.GetBaseURL()) c.Set("base_url", channel.GetBaseURL())
middleware.SetupContextForSelectedChannel(c, channel, "") middleware.SetupContextForSelectedChannel(c, channel, "")
meta := util.GetRelayMeta(c) meta := meta.GetByContext(c)
apiType := constant.ChannelType2APIType(channel.Type) apiType := channeltype.ToAPIType(channel.Type)
adaptor := helper.GetAdaptor(apiType) adaptor := relay.GetAdaptor(apiType)
if adaptor == nil { if adaptor == nil {
return fmt.Errorf("invalid api type: %d, adaptor is nil", apiType), nil return fmt.Errorf("invalid api type: %d, adaptor is nil", apiType), nil
} }
@@ -73,7 +74,7 @@ func testChannel(channel *model.Channel) (err error, openaiErr *relaymodel.Error
request := buildTestRequest() request := buildTestRequest()
request.Model = modelName request.Model = modelName
meta.OriginModelName, meta.ActualModelName = modelName, modelName meta.OriginModelName, meta.ActualModelName = modelName, modelName
convertedRequest, err := adaptor.ConvertRequest(c, constant.RelayModeChatCompletions, request) convertedRequest, err := adaptor.ConvertRequest(c, relaymode.ChatCompletions, request)
if err != nil { if err != nil {
return err, nil return err, nil
} }
@@ -88,7 +89,7 @@ func testChannel(channel *model.Channel) (err error, openaiErr *relaymodel.Error
return err, nil return err, nil
} }
if resp.StatusCode != http.StatusOK { if resp.StatusCode != http.StatusOK {
err := util.RelayErrorHandler(resp) err := controller.RelayErrorHandler(resp)
return fmt.Errorf("status code %d: %s", resp.StatusCode, err.Error.Message), &err.Error return fmt.Errorf("status code %d: %s", resp.StatusCode, err.Error.Message), &err.Error
} }
usage, respErr := adaptor.DoResponse(c, resp, meta) usage, respErr := adaptor.DoResponse(c, resp, meta)
@@ -171,7 +172,7 @@ func testChannels(notify bool, scope string) error {
} }
go func() { go func() {
for _, channel := range channels { for _, channel := range channels {
isChannelEnabled := channel.Status == common.ChannelStatusEnabled isChannelEnabled := channel.Status == model.ChannelStatusEnabled
tik := time.Now() tik := time.Now()
err, openaiErr := testChannel(channel) err, openaiErr := testChannel(channel)
tok := time.Now() tok := time.Now()
@@ -184,10 +185,10 @@ func testChannels(notify bool, scope string) error {
_ = message.Notify(message.ByAll, fmt.Sprintf("渠道 %s %d测试超时", channel.Name, channel.Id), "", err.Error()) _ = message.Notify(message.ByAll, fmt.Sprintf("渠道 %s %d测试超时", channel.Name, channel.Id), "", err.Error())
} }
} }
if isChannelEnabled && util.ShouldDisableChannel(openaiErr, -1) { if isChannelEnabled && monitor.ShouldDisableChannel(openaiErr, -1) {
monitor.DisableChannel(channel.Id, channel.Name, err.Error()) monitor.DisableChannel(channel.Id, channel.Name, err.Error())
} }
if !isChannelEnabled && util.ShouldEnableChannel(err, openaiErr) { if !isChannelEnabled && monitor.ShouldEnableChannel(err, openaiErr) {
monitor.EnableChannel(channel.Id, channel.Name) monitor.EnableChannel(channel.Id, channel.Name)
} }
channel.UpdateResponseTime(milliseconds) channel.UpdateResponseTime(milliseconds)

View File

@@ -2,13 +2,13 @@ package controller
import ( import (
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common" billingratio "github.com/songquanpeng/one-api/relay/billing/ratio"
"net/http" "net/http"
) )
func GetGroups(c *gin.Context) { func GetGroups(c *gin.Context) {
groupNames := make([]string, 0) groupNames := make([]string, 0)
for groupName := range common.GroupRatio { for groupName := range billingratio.GroupRatio {
groupNames = append(groupNames, groupName) groupNames = append(groupNames, groupName)
} }
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{

View File

@@ -3,13 +3,13 @@ package controller
import ( import (
"fmt" "fmt"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/model" "github.com/songquanpeng/one-api/model"
"github.com/songquanpeng/one-api/relay/channel/openai" relay "github.com/songquanpeng/one-api/relay"
"github.com/songquanpeng/one-api/relay/constant" "github.com/songquanpeng/one-api/relay/adaptor/openai"
"github.com/songquanpeng/one-api/relay/helper" "github.com/songquanpeng/one-api/relay/apitype"
"github.com/songquanpeng/one-api/relay/channeltype"
"github.com/songquanpeng/one-api/relay/meta"
relaymodel "github.com/songquanpeng/one-api/relay/model" relaymodel "github.com/songquanpeng/one-api/relay/model"
"github.com/songquanpeng/one-api/relay/util"
"net/http" "net/http"
"strings" "strings"
) )
@@ -41,8 +41,8 @@ type OpenAIModels struct {
Parent *string `json:"parent"` Parent *string `json:"parent"`
} }
var openAIModels []OpenAIModels var models []OpenAIModels
var openAIModelsMap map[string]OpenAIModels var modelsMap map[string]OpenAIModels
var channelId2Models map[int][]string var channelId2Models map[int][]string
func init() { func init() {
@@ -62,15 +62,15 @@ func init() {
IsBlocking: false, IsBlocking: false,
}) })
// https://platform.openai.com/docs/models/model-endpoint-compatibility // https://platform.openai.com/docs/models/model-endpoint-compatibility
for i := 0; i < constant.APITypeDummy; i++ { for i := 0; i < apitype.Dummy; i++ {
if i == constant.APITypeAIProxyLibrary { if i == apitype.AIProxyLibrary {
continue continue
} }
adaptor := helper.GetAdaptor(i) adaptor := relay.GetAdaptor(i)
channelName := adaptor.GetChannelName() channelName := adaptor.GetChannelName()
modelNames := adaptor.GetModelList() modelNames := adaptor.GetModelList()
for _, modelName := range modelNames { for _, modelName := range modelNames {
openAIModels = append(openAIModels, OpenAIModels{ models = append(models, OpenAIModels{
Id: modelName, Id: modelName,
Object: "model", Object: "model",
Created: 1626777600, Created: 1626777600,
@@ -82,12 +82,12 @@ func init() {
} }
} }
for _, channelType := range openai.CompatibleChannels { for _, channelType := range openai.CompatibleChannels {
if channelType == common.ChannelTypeAzure { if channelType == channeltype.Azure {
continue continue
} }
channelName, channelModelList := openai.GetCompatibleChannelMeta(channelType) channelName, channelModelList := openai.GetCompatibleChannelMeta(channelType)
for _, modelName := range channelModelList { for _, modelName := range channelModelList {
openAIModels = append(openAIModels, OpenAIModels{ models = append(models, OpenAIModels{
Id: modelName, Id: modelName,
Object: "model", Object: "model",
Created: 1626777600, Created: 1626777600,
@@ -98,14 +98,14 @@ func init() {
}) })
} }
} }
openAIModelsMap = make(map[string]OpenAIModels) modelsMap = make(map[string]OpenAIModels)
for _, model := range openAIModels { for _, model := range models {
openAIModelsMap[model.Id] = model modelsMap[model.Id] = model
} }
channelId2Models = make(map[int][]string) channelId2Models = make(map[int][]string)
for i := 1; i < common.ChannelTypeDummy; i++ { for i := 1; i < channeltype.Dummy; i++ {
adaptor := helper.GetAdaptor(constant.ChannelType2APIType(i)) adaptor := relay.GetAdaptor(channeltype.ToAPIType(i))
meta := &util.RelayMeta{ meta := &meta.Meta{
ChannelType: i, ChannelType: i,
} }
adaptor.Init(meta) adaptor.Init(meta)
@@ -121,6 +121,13 @@ func DashboardListModels(c *gin.Context) {
}) })
} }
func ListAllModels(c *gin.Context) {
c.JSON(200, gin.H{
"object": "list",
"data": models,
})
}
func ListModels(c *gin.Context) { func ListModels(c *gin.Context) {
ctx := c.Request.Context() ctx := c.Request.Context()
var availableModels []string var availableModels []string
@@ -136,7 +143,7 @@ func ListModels(c *gin.Context) {
modelSet[availableModel] = true modelSet[availableModel] = true
} }
availableOpenAIModels := make([]OpenAIModels, 0) availableOpenAIModels := make([]OpenAIModels, 0)
for _, model := range openAIModels { for _, model := range models {
if _, ok := modelSet[model.Id]; ok { if _, ok := modelSet[model.Id]; ok {
modelSet[model.Id] = false modelSet[model.Id] = false
availableOpenAIModels = append(availableOpenAIModels, model) availableOpenAIModels = append(availableOpenAIModels, model)
@@ -162,7 +169,7 @@ func ListModels(c *gin.Context) {
func RetrieveModel(c *gin.Context) { func RetrieveModel(c *gin.Context) {
modelId := c.Param("model") modelId := c.Param("model")
if model, ok := openAIModelsMap[modelId]; ok { if model, ok := modelsMap[modelId]; ok {
c.JSON(200, model) c.JSON(200, model)
} else { } else {
Error := relaymodel.Error{ Error := relaymodel.Error{

View File

@@ -4,6 +4,7 @@ import (
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common/config" "github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/helper" "github.com/songquanpeng/one-api/common/helper"
"github.com/songquanpeng/one-api/common/random"
"github.com/songquanpeng/one-api/model" "github.com/songquanpeng/one-api/model"
"net/http" "net/http"
"strconv" "strconv"
@@ -106,7 +107,7 @@ func AddRedemption(c *gin.Context) {
} }
var keys []string var keys []string
for i := 0; i < redemption.Count; i++ { for i := 0; i < redemption.Count; i++ {
key := helper.GetUUID() key := random.GetUUID()
cleanRedemption := model.Redemption{ cleanRedemption := model.Redemption{
UserId: c.GetInt("id"), UserId: c.GetInt("id"),
Name: redemption.Name, Name: redemption.Name,

View File

@@ -12,26 +12,25 @@ import (
"github.com/songquanpeng/one-api/middleware" "github.com/songquanpeng/one-api/middleware"
dbmodel "github.com/songquanpeng/one-api/model" dbmodel "github.com/songquanpeng/one-api/model"
"github.com/songquanpeng/one-api/monitor" "github.com/songquanpeng/one-api/monitor"
"github.com/songquanpeng/one-api/relay/constant"
"github.com/songquanpeng/one-api/relay/controller" "github.com/songquanpeng/one-api/relay/controller"
"github.com/songquanpeng/one-api/relay/model" "github.com/songquanpeng/one-api/relay/model"
"github.com/songquanpeng/one-api/relay/util" "github.com/songquanpeng/one-api/relay/relaymode"
"io" "io"
"net/http" "net/http"
) )
// https://platform.openai.com/docs/api-reference/chat // https://platform.openai.com/docs/api-reference/chat
func relay(c *gin.Context, relayMode int) *model.ErrorWithStatusCode { func relayHelper(c *gin.Context, relayMode int) *model.ErrorWithStatusCode {
var err *model.ErrorWithStatusCode var err *model.ErrorWithStatusCode
switch relayMode { switch relayMode {
case constant.RelayModeImagesGenerations: case relaymode.ImagesGenerations:
err = controller.RelayImageHelper(c, relayMode) err = controller.RelayImageHelper(c, relayMode)
case constant.RelayModeAudioSpeech: case relaymode.AudioSpeech:
fallthrough fallthrough
case constant.RelayModeAudioTranslation: case relaymode.AudioTranslation:
fallthrough fallthrough
case constant.RelayModeAudioTranscription: case relaymode.AudioTranscription:
err = controller.RelayAudioHelper(c, relayMode) err = controller.RelayAudioHelper(c, relayMode)
default: default:
err = controller.RelayTextHelper(c) err = controller.RelayTextHelper(c)
@@ -41,13 +40,13 @@ func relay(c *gin.Context, relayMode int) *model.ErrorWithStatusCode {
func Relay(c *gin.Context) { func Relay(c *gin.Context) {
ctx := c.Request.Context() ctx := c.Request.Context()
relayMode := constant.Path2RelayMode(c.Request.URL.Path) relayMode := relaymode.GetByPath(c.Request.URL.Path)
if config.DebugEnabled { if config.DebugEnabled {
requestBody, _ := common.GetRequestBody(c) requestBody, _ := common.GetRequestBody(c)
logger.Debugf(ctx, "request body: %s", string(requestBody)) logger.Debugf(ctx, "request body: %s", string(requestBody))
} }
channelId := c.GetInt("channel_id") channelId := c.GetInt("channel_id")
bizErr := relay(c, relayMode) bizErr := relayHelper(c, relayMode)
if bizErr == nil { if bizErr == nil {
monitor.Emit(channelId, true) monitor.Emit(channelId, true)
return return
@@ -76,7 +75,7 @@ func Relay(c *gin.Context) {
middleware.SetupContextForSelectedChannel(c, channel, originalModel) middleware.SetupContextForSelectedChannel(c, channel, originalModel)
requestBody, err := common.GetRequestBody(c) requestBody, err := common.GetRequestBody(c)
c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody)) c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
bizErr = relay(c, relayMode) bizErr = relayHelper(c, relayMode)
if bizErr == nil { if bizErr == nil {
return return
} }
@@ -118,7 +117,7 @@ func shouldRetry(c *gin.Context, statusCode int) bool {
func processChannelRelayError(ctx context.Context, channelId int, channelName string, err *model.ErrorWithStatusCode) { func processChannelRelayError(ctx context.Context, channelId int, channelName string, err *model.ErrorWithStatusCode) {
logger.Errorf(ctx, "relay error (channel #%d): %s", channelId, err.Message) logger.Errorf(ctx, "relay error (channel #%d): %s", channelId, err.Message)
// https://platform.openai.com/docs/guides/error-codes/api-errors // https://platform.openai.com/docs/guides/error-codes/api-errors
if util.ShouldDisableChannel(&err.Error, err.StatusCode) { if monitor.ShouldDisableChannel(&err.Error, err.StatusCode) {
monitor.DisableChannel(channelId, channelName, err.Message) monitor.DisableChannel(channelId, channelName, err.Message)
} else { } else {
monitor.Emit(channelId, false) monitor.Emit(channelId, false)

View File

@@ -3,10 +3,10 @@ package controller
import ( import (
"fmt" "fmt"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/config" "github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/helper" "github.com/songquanpeng/one-api/common/helper"
"github.com/songquanpeng/one-api/common/network" "github.com/songquanpeng/one-api/common/network"
"github.com/songquanpeng/one-api/common/random"
"github.com/songquanpeng/one-api/model" "github.com/songquanpeng/one-api/model"
"net/http" "net/http"
"strconv" "strconv"
@@ -141,7 +141,7 @@ func AddToken(c *gin.Context) {
cleanToken := model.Token{ cleanToken := model.Token{
UserId: c.GetInt("id"), UserId: c.GetInt("id"),
Name: token.Name, Name: token.Name,
Key: helper.GenerateKey(), Key: random.GenerateKey(),
CreatedTime: helper.GetTimestamp(), CreatedTime: helper.GetTimestamp(),
AccessedTime: helper.GetTimestamp(), AccessedTime: helper.GetTimestamp(),
ExpiredTime: token.ExpiredTime, ExpiredTime: token.ExpiredTime,
@@ -212,15 +212,15 @@ func UpdateToken(c *gin.Context) {
}) })
return return
} }
if token.Status == common.TokenStatusEnabled { if token.Status == model.TokenStatusEnabled {
if cleanToken.Status == common.TokenStatusExpired && cleanToken.ExpiredTime <= helper.GetTimestamp() && cleanToken.ExpiredTime != -1 { if cleanToken.Status == model.TokenStatusExpired && cleanToken.ExpiredTime <= helper.GetTimestamp() && cleanToken.ExpiredTime != -1 {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": false, "success": false,
"message": "令牌已过期,无法启用,请先修改令牌过期时间,或者设置为永不过期", "message": "令牌已过期,无法启用,请先修改令牌过期时间,或者设置为永不过期",
}) })
return return
} }
if cleanToken.Status == common.TokenStatusExhausted && cleanToken.RemainQuota <= 0 && !cleanToken.UnlimitedQuota { if cleanToken.Status == model.TokenStatusExhausted && cleanToken.RemainQuota <= 0 && !cleanToken.UnlimitedQuota {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": false, "success": false,
"message": "令牌可用额度已用尽,无法启用,请先修改令牌剩余额度,或者设置为无限额度", "message": "令牌可用额度已用尽,无法启用,请先修改令牌剩余额度,或者设置为无限额度",

View File

@@ -5,7 +5,7 @@ import (
"fmt" "fmt"
"github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/config" "github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/helper" "github.com/songquanpeng/one-api/common/random"
"github.com/songquanpeng/one-api/model" "github.com/songquanpeng/one-api/model"
"net/http" "net/http"
"strconv" "strconv"
@@ -239,7 +239,7 @@ func GetUser(c *gin.Context) {
return return
} }
myRole := c.GetInt("role") myRole := c.GetInt("role")
if myRole <= user.Role && myRole != common.RoleRootUser { if myRole <= user.Role && myRole != model.RoleRootUser {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": false, "success": false,
"message": "无权获取同级或更高等级用户的信息", "message": "无权获取同级或更高等级用户的信息",
@@ -287,7 +287,7 @@ func GenerateAccessToken(c *gin.Context) {
}) })
return return
} }
user.AccessToken = helper.GetUUID() user.AccessToken = random.GetUUID()
if model.DB.Where("access_token = ?", user.AccessToken).First(user).RowsAffected != 0 { if model.DB.Where("access_token = ?", user.AccessToken).First(user).RowsAffected != 0 {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
@@ -324,7 +324,7 @@ func GetAffCode(c *gin.Context) {
return return
} }
if user.AffCode == "" { if user.AffCode == "" {
user.AffCode = helper.GetRandomString(4) user.AffCode = random.GetRandomString(4)
if err := user.Update(false); err != nil { if err := user.Update(false); err != nil {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": false, "success": false,
@@ -388,14 +388,14 @@ func UpdateUser(c *gin.Context) {
return return
} }
myRole := c.GetInt("role") myRole := c.GetInt("role")
if myRole <= originUser.Role && myRole != common.RoleRootUser { if myRole <= originUser.Role && myRole != model.RoleRootUser {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": false, "success": false,
"message": "无权更新同权限等级或更高权限等级的用户信息", "message": "无权更新同权限等级或更高权限等级的用户信息",
}) })
return return
} }
if myRole <= updatedUser.Role && myRole != common.RoleRootUser { if myRole <= updatedUser.Role && myRole != model.RoleRootUser {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": false, "success": false,
"message": "无权将其他用户权限等级提升到大于等于自己的权限等级", "message": "无权将其他用户权限等级提升到大于等于自己的权限等级",
@@ -509,7 +509,7 @@ func DeleteSelf(c *gin.Context) {
id := c.GetInt("id") id := c.GetInt("id")
user, _ := model.GetUserById(id, false) user, _ := model.GetUserById(id, false)
if user.Role == common.RoleRootUser { if user.Role == model.RoleRootUser {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": false, "success": false,
"message": "不能删除超级管理员账户", "message": "不能删除超级管理员账户",
@@ -611,7 +611,7 @@ func ManageUser(c *gin.Context) {
return return
} }
myRole := c.GetInt("role") myRole := c.GetInt("role")
if myRole <= user.Role && myRole != common.RoleRootUser { if myRole <= user.Role && myRole != model.RoleRootUser {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": false, "success": false,
"message": "无权更新同权限等级或更高权限等级的用户信息", "message": "无权更新同权限等级或更高权限等级的用户信息",
@@ -620,8 +620,8 @@ func ManageUser(c *gin.Context) {
} }
switch req.Action { switch req.Action {
case "disable": case "disable":
user.Status = common.UserStatusDisabled user.Status = model.UserStatusDisabled
if user.Role == common.RoleRootUser { if user.Role == model.RoleRootUser {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": false, "success": false,
"message": "无法禁用超级管理员用户", "message": "无法禁用超级管理员用户",
@@ -629,9 +629,9 @@ func ManageUser(c *gin.Context) {
return return
} }
case "enable": case "enable":
user.Status = common.UserStatusEnabled user.Status = model.UserStatusEnabled
case "delete": case "delete":
if user.Role == common.RoleRootUser { if user.Role == model.RoleRootUser {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": false, "success": false,
"message": "无法删除超级管理员用户", "message": "无法删除超级管理员用户",
@@ -646,37 +646,37 @@ func ManageUser(c *gin.Context) {
return return
} }
case "promote": case "promote":
if myRole != common.RoleRootUser { if myRole != model.RoleRootUser {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": false, "success": false,
"message": "普通管理员用户无法提升其他用户为管理员", "message": "普通管理员用户无法提升其他用户为管理员",
}) })
return return
} }
if user.Role >= common.RoleAdminUser { if user.Role >= model.RoleAdminUser {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": false, "success": false,
"message": "该用户已经是管理员", "message": "该用户已经是管理员",
}) })
return return
} }
user.Role = common.RoleAdminUser user.Role = model.RoleAdminUser
case "demote": case "demote":
if user.Role == common.RoleRootUser { if user.Role == model.RoleRootUser {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": false, "success": false,
"message": "无法降级超级管理员用户", "message": "无法降级超级管理员用户",
}) })
return return
} }
if user.Role == common.RoleCommonUser { if user.Role == model.RoleCommonUser {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": false, "success": false,
"message": "该用户已经是普通用户", "message": "该用户已经是普通用户",
}) })
return return
} }
user.Role = common.RoleCommonUser user.Role = model.RoleCommonUser
} }
if err := user.Update(false); err != nil { if err := user.Update(false); err != nil {
@@ -730,7 +730,7 @@ func EmailBind(c *gin.Context) {
}) })
return return
} }
if user.Role == common.RoleRootUser { if user.Role == model.RoleRootUser {
config.RootUserEmail = email config.RootUserEmail = email
} }
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{

View File

@@ -12,7 +12,7 @@ import (
"github.com/songquanpeng/one-api/controller" "github.com/songquanpeng/one-api/controller"
"github.com/songquanpeng/one-api/middleware" "github.com/songquanpeng/one-api/middleware"
"github.com/songquanpeng/one-api/model" "github.com/songquanpeng/one-api/model"
"github.com/songquanpeng/one-api/relay/channel/openai" "github.com/songquanpeng/one-api/relay/adaptor/openai"
"github.com/songquanpeng/one-api/router" "github.com/songquanpeng/one-api/router"
"os" "os"
"strconv" "strconv"

View File

@@ -4,7 +4,6 @@ import (
"fmt" "fmt"
"github.com/gin-contrib/sessions" "github.com/gin-contrib/sessions"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/blacklist" "github.com/songquanpeng/one-api/common/blacklist"
"github.com/songquanpeng/one-api/common/network" "github.com/songquanpeng/one-api/common/network"
"github.com/songquanpeng/one-api/model" "github.com/songquanpeng/one-api/model"
@@ -45,7 +44,7 @@ func authHelper(c *gin.Context, minRole int) {
return return
} }
} }
if status.(int) == common.UserStatusDisabled || blacklist.IsUserBanned(id.(int)) { if status.(int) == model.UserStatusDisabled || blacklist.IsUserBanned(id.(int)) {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": false, "success": false,
"message": "用户已被封禁", "message": "用户已被封禁",
@@ -72,19 +71,19 @@ func authHelper(c *gin.Context, minRole int) {
func UserAuth() func(c *gin.Context) { func UserAuth() func(c *gin.Context) {
return func(c *gin.Context) { return func(c *gin.Context) {
authHelper(c, common.RoleCommonUser) authHelper(c, model.RoleCommonUser)
} }
} }
func AdminAuth() func(c *gin.Context) { func AdminAuth() func(c *gin.Context) {
return func(c *gin.Context) { return func(c *gin.Context) {
authHelper(c, common.RoleAdminUser) authHelper(c, model.RoleAdminUser)
} }
} }
func RootAuth() func(c *gin.Context) { func RootAuth() func(c *gin.Context) {
return func(c *gin.Context) { return func(c *gin.Context) {
authHelper(c, common.RoleRootUser) authHelper(c, model.RoleRootUser)
} }
} }
@@ -117,7 +116,7 @@ func TokenAuth() func(c *gin.Context) {
return return
} }
requestModel, err := getRequestModel(c) requestModel, err := getRequestModel(c)
if err != nil && !strings.HasPrefix(c.Request.URL.Path, "/v1/models") { if err != nil && shouldCheckModel(c) {
abortWithMessage(c, http.StatusBadRequest, err.Error()) abortWithMessage(c, http.StatusBadRequest, err.Error())
return return
} }
@@ -143,3 +142,19 @@ func TokenAuth() func(c *gin.Context) {
c.Next() c.Next()
} }
} }
func shouldCheckModel(c *gin.Context) bool {
if strings.HasPrefix(c.Request.URL.Path, "/v1/completions") {
return true
}
if strings.HasPrefix(c.Request.URL.Path, "/v1/chat/completions") {
return true
}
if strings.HasPrefix(c.Request.URL.Path, "/v1/images") {
return true
}
if strings.HasPrefix(c.Request.URL.Path, "/v1/audio") {
return true
}
return false
}

View File

@@ -3,9 +3,10 @@ package middleware
import ( import (
"fmt" "fmt"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/logger" "github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/model" "github.com/songquanpeng/one-api/model"
"github.com/songquanpeng/one-api/relay/channeltype"
"net/http" "net/http"
"strconv" "strconv"
) )
@@ -33,7 +34,7 @@ func Distribute() func(c *gin.Context) {
abortWithMessage(c, http.StatusBadRequest, "无效的渠道 Id") abortWithMessage(c, http.StatusBadRequest, "无效的渠道 Id")
return return
} }
if channel.Status != common.ChannelStatusEnabled { if channel.Status != model.ChannelStatusEnabled {
abortWithMessage(c, http.StatusForbidden, "该渠道已被禁用") abortWithMessage(c, http.StatusForbidden, "该渠道已被禁用")
return return
} }
@@ -66,19 +67,19 @@ func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, mode
c.Set("base_url", channel.GetBaseURL()) c.Set("base_url", channel.GetBaseURL())
// this is for backward compatibility // this is for backward compatibility
switch channel.Type { switch channel.Type {
case common.ChannelTypeAzure: case channeltype.Azure:
c.Set(common.ConfigKeyAPIVersion, channel.Other) c.Set(config.KeyAPIVersion, channel.Other)
case common.ChannelTypeXunfei: case channeltype.Xunfei:
c.Set(common.ConfigKeyAPIVersion, channel.Other) c.Set(config.KeyAPIVersion, channel.Other)
case common.ChannelTypeGemini: case channeltype.Gemini:
c.Set(common.ConfigKeyAPIVersion, channel.Other) c.Set(config.KeyAPIVersion, channel.Other)
case common.ChannelTypeAIProxyLibrary: case channeltype.AIProxyLibrary:
c.Set(common.ConfigKeyLibraryID, channel.Other) c.Set(config.KeyLibraryID, channel.Other)
case common.ChannelTypeAli: case channeltype.Ali:
c.Set(common.ConfigKeyPlugin, channel.Other) c.Set(config.KeyPlugin, channel.Other)
} }
cfg, _ := channel.LoadConfig() cfg, _ := channel.LoadConfig()
for k, v := range cfg { for k, v := range cfg {
c.Set(common.ConfigKeyPrefix+k, v) c.Set(config.KeyPrefix+k, v)
} }
} }

View File

@@ -57,7 +57,7 @@ func (channel *Channel) AddAbilities() error {
Group: group, Group: group,
Model: model, Model: model,
ChannelId: channel.Id, ChannelId: channel.Id,
Enabled: channel.Status == common.ChannelStatusEnabled, Enabled: channel.Status == ChannelStatusEnabled,
Priority: channel.Priority, Priority: channel.Priority,
} }
abilities = append(abilities, ability) abilities = append(abilities, ability)

View File

@@ -8,6 +8,7 @@ import (
"github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/config" "github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/logger" "github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/common/random"
"math/rand" "math/rand"
"sort" "sort"
"strconv" "strconv"
@@ -172,7 +173,7 @@ var channelSyncLock sync.RWMutex
func InitChannelCache() { func InitChannelCache() {
newChannelId2channel := make(map[int]*Channel) newChannelId2channel := make(map[int]*Channel)
var channels []*Channel var channels []*Channel
DB.Where("status = ?", common.ChannelStatusEnabled).Find(&channels) DB.Where("status = ?", ChannelStatusEnabled).Find(&channels)
for _, channel := range channels { for _, channel := range channels {
newChannelId2channel[channel.Id] = channel newChannelId2channel[channel.Id] = channel
} }
@@ -247,7 +248,7 @@ func CacheGetRandomSatisfiedChannel(group string, model string, ignoreFirstPrior
idx := rand.Intn(endIdx) idx := rand.Intn(endIdx)
if ignoreFirstPriority { if ignoreFirstPriority {
if endIdx < len(channels) { // which means there are more than one priority if endIdx < len(channels) { // which means there are more than one priority
idx = common.RandRange(endIdx, len(channels)) idx = random.RandRange(endIdx, len(channels))
} }
} }
return channels[idx], nil return channels[idx], nil

View File

@@ -3,13 +3,19 @@ package model
import ( import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/config" "github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/helper" "github.com/songquanpeng/one-api/common/helper"
"github.com/songquanpeng/one-api/common/logger" "github.com/songquanpeng/one-api/common/logger"
"gorm.io/gorm" "gorm.io/gorm"
) )
const (
ChannelStatusUnknown = 0
ChannelStatusEnabled = 1 // don't use 0, 0 is the default value!
ChannelStatusManuallyDisabled = 2 // also don't use 0
ChannelStatusAutoDisabled = 3
)
type Channel struct { type Channel struct {
Id int `json:"id"` Id int `json:"id"`
Type int `json:"type" gorm:"default:0"` Type int `json:"type" gorm:"default:0"`
@@ -39,7 +45,7 @@ func GetAllChannels(startIdx int, num int, scope string) ([]*Channel, error) {
case "all": case "all":
err = DB.Order("id desc").Find(&channels).Error err = DB.Order("id desc").Find(&channels).Error
case "disabled": case "disabled":
err = DB.Order("id desc").Where("status = ? or status = ?", common.ChannelStatusAutoDisabled, common.ChannelStatusManuallyDisabled).Find(&channels).Error err = DB.Order("id desc").Where("status = ? or status = ?", ChannelStatusAutoDisabled, ChannelStatusManuallyDisabled).Find(&channels).Error
default: default:
err = DB.Order("id desc").Limit(num).Offset(startIdx).Omit("key").Find(&channels).Error err = DB.Order("id desc").Limit(num).Offset(startIdx).Omit("key").Find(&channels).Error
} }
@@ -168,7 +174,7 @@ func (channel *Channel) LoadConfig() (map[string]string, error) {
} }
func UpdateChannelStatusById(id int, status int) { func UpdateChannelStatusById(id int, status int) {
err := UpdateAbilityStatus(id, status == common.ChannelStatusEnabled) err := UpdateAbilityStatus(id, status == ChannelStatusEnabled)
if err != nil { if err != nil {
logger.SysError("failed to update ability status: " + err.Error()) logger.SysError("failed to update ability status: " + err.Error())
} }
@@ -199,6 +205,6 @@ func DeleteChannelByStatus(status int64) (int64, error) {
} }
func DeleteDisabledChannel() (int64, error) { func DeleteDisabledChannel() (int64, error) {
result := DB.Where("status = ? or status = ?", common.ChannelStatusAutoDisabled, common.ChannelStatusManuallyDisabled).Delete(&Channel{}) result := DB.Where("status = ? or status = ?", ChannelStatusAutoDisabled, ChannelStatusManuallyDisabled).Delete(&Channel{})
return result.RowsAffected, result.Error return result.RowsAffected, result.Error
} }

View File

@@ -7,7 +7,6 @@ import (
"github.com/songquanpeng/one-api/common/config" "github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/helper" "github.com/songquanpeng/one-api/common/helper"
"github.com/songquanpeng/one-api/common/logger" "github.com/songquanpeng/one-api/common/logger"
"gorm.io/gorm" "gorm.io/gorm"
) )

View File

@@ -7,6 +7,7 @@ import (
"github.com/songquanpeng/one-api/common/env" "github.com/songquanpeng/one-api/common/env"
"github.com/songquanpeng/one-api/common/helper" "github.com/songquanpeng/one-api/common/helper"
"github.com/songquanpeng/one-api/common/logger" "github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/common/random"
"gorm.io/driver/mysql" "gorm.io/driver/mysql"
"gorm.io/driver/postgres" "gorm.io/driver/postgres"
"gorm.io/driver/sqlite" "gorm.io/driver/sqlite"
@@ -31,10 +32,10 @@ func CreateRootAccountIfNeed() error {
rootUser := User{ rootUser := User{
Username: "root", Username: "root",
Password: hashedPassword, Password: hashedPassword,
Role: common.RoleRootUser, Role: RoleRootUser,
Status: common.UserStatusEnabled, Status: UserStatusEnabled,
DisplayName: "Root User", DisplayName: "Root User",
AccessToken: helper.GetUUID(), AccessToken: random.GetUUID(),
Quota: 500000000000000, Quota: 500000000000000,
} }
DB.Create(&rootUser) DB.Create(&rootUser)
@@ -44,7 +45,7 @@ func CreateRootAccountIfNeed() error {
Id: 1, Id: 1,
UserId: rootUser.Id, UserId: rootUser.Id,
Key: config.InitialRootToken, Key: config.InitialRootToken,
Status: common.TokenStatusEnabled, Status: TokenStatusEnabled,
Name: "Initial Root Token", Name: "Initial Root Token",
CreatedTime: helper.GetTimestamp(), CreatedTime: helper.GetTimestamp(),
AccessedTime: helper.GetTimestamp(), AccessedTime: helper.GetTimestamp(),

View File

@@ -1,9 +1,9 @@
package model package model
import ( import (
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/config" "github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/logger" "github.com/songquanpeng/one-api/common/logger"
billingratio "github.com/songquanpeng/one-api/relay/billing/ratio"
"strconv" "strconv"
"strings" "strings"
"time" "time"
@@ -66,9 +66,9 @@ func InitOptionMap() {
config.OptionMap["QuotaForInvitee"] = strconv.FormatInt(config.QuotaForInvitee, 10) config.OptionMap["QuotaForInvitee"] = strconv.FormatInt(config.QuotaForInvitee, 10)
config.OptionMap["QuotaRemindThreshold"] = strconv.FormatInt(config.QuotaRemindThreshold, 10) config.OptionMap["QuotaRemindThreshold"] = strconv.FormatInt(config.QuotaRemindThreshold, 10)
config.OptionMap["PreConsumedQuota"] = strconv.FormatInt(config.PreConsumedQuota, 10) config.OptionMap["PreConsumedQuota"] = strconv.FormatInt(config.PreConsumedQuota, 10)
config.OptionMap["ModelRatio"] = common.ModelRatio2JSONString() config.OptionMap["ModelRatio"] = billingratio.ModelRatio2JSONString()
config.OptionMap["GroupRatio"] = common.GroupRatio2JSONString() config.OptionMap["GroupRatio"] = billingratio.GroupRatio2JSONString()
config.OptionMap["CompletionRatio"] = common.CompletionRatio2JSONString() config.OptionMap["CompletionRatio"] = billingratio.CompletionRatio2JSONString()
config.OptionMap["TopUpLink"] = config.TopUpLink config.OptionMap["TopUpLink"] = config.TopUpLink
config.OptionMap["ChatLink"] = config.ChatLink config.OptionMap["ChatLink"] = config.ChatLink
config.OptionMap["QuotaPerUnit"] = strconv.FormatFloat(config.QuotaPerUnit, 'f', -1, 64) config.OptionMap["QuotaPerUnit"] = strconv.FormatFloat(config.QuotaPerUnit, 'f', -1, 64)
@@ -82,7 +82,7 @@ func loadOptionsFromDatabase() {
options, _ := AllOption() options, _ := AllOption()
for _, option := range options { for _, option := range options {
if option.Key == "ModelRatio" { if option.Key == "ModelRatio" {
option.Value = common.AddNewMissingRatio(option.Value) option.Value = billingratio.AddNewMissingRatio(option.Value)
} }
err := updateOptionMap(option.Key, option.Value) err := updateOptionMap(option.Key, option.Value)
if err != nil { if err != nil {
@@ -209,11 +209,11 @@ func updateOptionMap(key string, value string) (err error) {
case "RetryTimes": case "RetryTimes":
config.RetryTimes, _ = strconv.Atoi(value) config.RetryTimes, _ = strconv.Atoi(value)
case "ModelRatio": case "ModelRatio":
err = common.UpdateModelRatioByJSONString(value) err = billingratio.UpdateModelRatioByJSONString(value)
case "GroupRatio": case "GroupRatio":
err = common.UpdateGroupRatioByJSONString(value) err = billingratio.UpdateGroupRatioByJSONString(value)
case "CompletionRatio": case "CompletionRatio":
err = common.UpdateCompletionRatioByJSONString(value) err = billingratio.UpdateCompletionRatioByJSONString(value)
case "TopUpLink": case "TopUpLink":
config.TopUpLink = value config.TopUpLink = value
case "ChatLink": case "ChatLink":

View File

@@ -8,6 +8,12 @@ import (
"gorm.io/gorm" "gorm.io/gorm"
) )
const (
RedemptionCodeStatusEnabled = 1 // don't use 0, 0 is the default value!
RedemptionCodeStatusDisabled = 2 // also don't use 0
RedemptionCodeStatusUsed = 3 // also don't use 0
)
type Redemption struct { type Redemption struct {
Id int `json:"id"` Id int `json:"id"`
UserId int `json:"user_id"` UserId int `json:"user_id"`
@@ -61,7 +67,7 @@ func Redeem(key string, userId int) (quota int64, err error) {
if err != nil { if err != nil {
return errors.New("无效的兑换码") return errors.New("无效的兑换码")
} }
if redemption.Status != common.RedemptionCodeStatusEnabled { if redemption.Status != RedemptionCodeStatusEnabled {
return errors.New("该兑换码已被使用") return errors.New("该兑换码已被使用")
} }
err = tx.Model(&User{}).Where("id = ?", userId).Update("quota", gorm.Expr("quota + ?", redemption.Quota)).Error err = tx.Model(&User{}).Where("id = ?", userId).Update("quota", gorm.Expr("quota + ?", redemption.Quota)).Error
@@ -69,7 +75,7 @@ func Redeem(key string, userId int) (quota int64, err error) {
return err return err
} }
redemption.RedeemedTime = helper.GetTimestamp() redemption.RedeemedTime = helper.GetTimestamp()
redemption.Status = common.RedemptionCodeStatusUsed redemption.Status = RedemptionCodeStatusUsed
err = tx.Save(redemption).Error err = tx.Save(redemption).Error
return err return err
}) })

View File

@@ -11,6 +11,13 @@ import (
"gorm.io/gorm" "gorm.io/gorm"
) )
const (
TokenStatusEnabled = 1 // don't use 0, 0 is the default value!
TokenStatusDisabled = 2 // also don't use 0
TokenStatusExpired = 3
TokenStatusExhausted = 4
)
type Token struct { type Token struct {
Id int `json:"id"` Id int `json:"id"`
UserId int `json:"user_id"` UserId int `json:"user_id"`
@@ -62,17 +69,17 @@ func ValidateUserToken(key string) (token *Token, err error) {
} }
return nil, errors.New("令牌验证失败") return nil, errors.New("令牌验证失败")
} }
if token.Status == common.TokenStatusExhausted { if token.Status == TokenStatusExhausted {
return nil, fmt.Errorf("令牌 %s#%d额度已用尽", token.Name, token.Id) return nil, fmt.Errorf("令牌 %s#%d额度已用尽", token.Name, token.Id)
} else if token.Status == common.TokenStatusExpired { } else if token.Status == TokenStatusExpired {
return nil, errors.New("该令牌已过期") return nil, errors.New("该令牌已过期")
} }
if token.Status != common.TokenStatusEnabled { if token.Status != TokenStatusEnabled {
return nil, errors.New("该令牌状态不可用") return nil, errors.New("该令牌状态不可用")
} }
if token.ExpiredTime != -1 && token.ExpiredTime < helper.GetTimestamp() { if token.ExpiredTime != -1 && token.ExpiredTime < helper.GetTimestamp() {
if !common.RedisEnabled { if !common.RedisEnabled {
token.Status = common.TokenStatusExpired token.Status = TokenStatusExpired
err := token.SelectUpdate() err := token.SelectUpdate()
if err != nil { if err != nil {
logger.SysError("failed to update token status" + err.Error()) logger.SysError("failed to update token status" + err.Error())
@@ -83,7 +90,7 @@ func ValidateUserToken(key string) (token *Token, err error) {
if !token.UnlimitedQuota && token.RemainQuota <= 0 { if !token.UnlimitedQuota && token.RemainQuota <= 0 {
if !common.RedisEnabled { if !common.RedisEnabled {
// in this case, we can make sure the token is exhausted // in this case, we can make sure the token is exhausted
token.Status = common.TokenStatusExhausted token.Status = TokenStatusExhausted
err := token.SelectUpdate() err := token.SelectUpdate()
if err != nil { if err != nil {
logger.SysError("failed to update token status" + err.Error()) logger.SysError("failed to update token status" + err.Error())

View File

@@ -6,12 +6,25 @@ import (
"github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/blacklist" "github.com/songquanpeng/one-api/common/blacklist"
"github.com/songquanpeng/one-api/common/config" "github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/helper"
"github.com/songquanpeng/one-api/common/logger" "github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/common/random"
"gorm.io/gorm" "gorm.io/gorm"
"strings" "strings"
) )
const (
RoleGuestUser = 0
RoleCommonUser = 1
RoleAdminUser = 10
RoleRootUser = 100
)
const (
UserStatusEnabled = 1 // don't use 0, 0 is the default value!
UserStatusDisabled = 2 // also don't use 0
UserStatusDeleted = 3
)
// User if you add sensitive fields, don't forget to clean them in setupLogin function. // User if you add sensitive fields, don't forget to clean them in setupLogin function.
// Otherwise, the sensitive information will be saved on local storage in plain text! // Otherwise, the sensitive information will be saved on local storage in plain text!
type User struct { type User struct {
@@ -42,7 +55,7 @@ func GetMaxUserId() int {
} }
func GetAllUsers(startIdx int, num int, order string) (users []*User, err error) { func GetAllUsers(startIdx int, num int, order string) (users []*User, err error) {
query := DB.Limit(num).Offset(startIdx).Omit("password").Where("status != ?", common.UserStatusDeleted) query := DB.Limit(num).Offset(startIdx).Omit("password").Where("status != ?", UserStatusDeleted)
switch order { switch order {
case "quota": case "quota":
@@ -108,8 +121,8 @@ func (user *User) Insert(inviterId int) error {
} }
} }
user.Quota = config.QuotaForNewUser user.Quota = config.QuotaForNewUser
user.AccessToken = helper.GetUUID() user.AccessToken = random.GetUUID()
user.AffCode = helper.GetRandomString(4) user.AffCode = random.GetRandomString(4)
result := DB.Create(user) result := DB.Create(user)
if result.Error != nil { if result.Error != nil {
return result.Error return result.Error
@@ -138,9 +151,9 @@ func (user *User) Update(updatePassword bool) error {
return err return err
} }
} }
if user.Status == common.UserStatusDisabled { if user.Status == UserStatusDisabled {
blacklist.BanUser(user.Id) blacklist.BanUser(user.Id)
} else if user.Status == common.UserStatusEnabled { } else if user.Status == UserStatusEnabled {
blacklist.UnbanUser(user.Id) blacklist.UnbanUser(user.Id)
} }
err = DB.Model(user).Updates(user).Error err = DB.Model(user).Updates(user).Error
@@ -152,8 +165,8 @@ func (user *User) Delete() error {
return errors.New("id 为空!") return errors.New("id 为空!")
} }
blacklist.BanUser(user.Id) blacklist.BanUser(user.Id)
user.Username = fmt.Sprintf("deleted_%s", helper.GetUUID()) user.Username = fmt.Sprintf("deleted_%s", random.GetUUID())
user.Status = common.UserStatusDeleted user.Status = UserStatusDeleted
err := DB.Model(user).Updates(user).Error err := DB.Model(user).Updates(user).Error
return err return err
} }
@@ -177,7 +190,7 @@ func (user *User) ValidateAndFill() (err error) {
} }
} }
okay := common.ValidatePasswordAndHash(password, user.Password) okay := common.ValidatePasswordAndHash(password, user.Password)
if !okay || user.Status != common.UserStatusEnabled { if !okay || user.Status != UserStatusEnabled {
return errors.New("用户名或密码错误,或用户已被封禁") return errors.New("用户名或密码错误,或用户已被封禁")
} }
return nil return nil
@@ -273,7 +286,7 @@ func IsAdmin(userId int) bool {
logger.SysError("no such user " + err.Error()) logger.SysError("no such user " + err.Error())
return false return false
} }
return user.Role >= common.RoleAdminUser return user.Role >= RoleAdminUser
} }
func IsUserEnabled(userId int) (bool, error) { func IsUserEnabled(userId int) (bool, error) {
@@ -285,7 +298,7 @@ func IsUserEnabled(userId int) (bool, error) {
if err != nil { if err != nil {
return false, err return false, err
} }
return user.Status == common.UserStatusEnabled, nil return user.Status == UserStatusEnabled, nil
} }
func ValidateAccessToken(token string) (user *User) { func ValidateAccessToken(token string) (user *User) {
@@ -358,7 +371,7 @@ func decreaseUserQuota(id int, quota int64) (err error) {
} }
func GetRootUserEmail() (email string) { func GetRootUserEmail() (email string) {
DB.Model(&User{}).Where("role = ?", common.RoleRootUser).Select("email").Find(&email) DB.Model(&User{}).Where("role = ?", RoleRootUser).Select("email").Find(&email)
return email return email
} }

View File

@@ -2,7 +2,6 @@ package monitor
import ( import (
"fmt" "fmt"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/config" "github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/logger" "github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/common/message" "github.com/songquanpeng/one-api/common/message"
@@ -29,7 +28,7 @@ func notifyRootUser(subject string, content string) {
// DisableChannel disable & notify // DisableChannel disable & notify
func DisableChannel(channelId int, channelName string, reason string) { func DisableChannel(channelId int, channelName string, reason string) {
model.UpdateChannelStatusById(channelId, common.ChannelStatusAutoDisabled) model.UpdateChannelStatusById(channelId, model.ChannelStatusAutoDisabled)
logger.SysLog(fmt.Sprintf("channel #%d has been disabled: %s", channelId, reason)) logger.SysLog(fmt.Sprintf("channel #%d has been disabled: %s", channelId, reason))
subject := fmt.Sprintf("渠道「%s」#%d已被禁用", channelName, channelId) subject := fmt.Sprintf("渠道「%s」#%d已被禁用", channelName, channelId)
content := fmt.Sprintf("渠道「%s」#%d已被禁用原因%s", channelName, channelId, reason) content := fmt.Sprintf("渠道「%s」#%d已被禁用原因%s", channelName, channelId, reason)
@@ -37,7 +36,7 @@ func DisableChannel(channelId int, channelName string, reason string) {
} }
func MetricDisableChannel(channelId int, successRate float64) { func MetricDisableChannel(channelId int, successRate float64) {
model.UpdateChannelStatusById(channelId, common.ChannelStatusAutoDisabled) model.UpdateChannelStatusById(channelId, model.ChannelStatusAutoDisabled)
logger.SysLog(fmt.Sprintf("channel #%d has been disabled due to low success rate: %.2f", channelId, successRate*100)) logger.SysLog(fmt.Sprintf("channel #%d has been disabled due to low success rate: %.2f", channelId, successRate*100))
subject := fmt.Sprintf("渠道 #%d 已被禁用", channelId) subject := fmt.Sprintf("渠道 #%d 已被禁用", channelId)
content := fmt.Sprintf("该渠道(#%d在最近 %d 次调用中成功率为 %.2f%%,低于阈值 %.2f%%,因此被系统自动禁用。", content := fmt.Sprintf("该渠道(#%d在最近 %d 次调用中成功率为 %.2f%%,低于阈值 %.2f%%,因此被系统自动禁用。",
@@ -47,7 +46,7 @@ func MetricDisableChannel(channelId int, successRate float64) {
// EnableChannel enable & notify // EnableChannel enable & notify
func EnableChannel(channelId int, channelName string) { func EnableChannel(channelId int, channelName string) {
model.UpdateChannelStatusById(channelId, common.ChannelStatusEnabled) model.UpdateChannelStatusById(channelId, model.ChannelStatusEnabled)
logger.SysLog(fmt.Sprintf("channel #%d has been enabled", channelId)) logger.SysLog(fmt.Sprintf("channel #%d has been enabled", channelId))
subject := fmt.Sprintf("渠道「%s」#%d已被启用", channelName, channelId) subject := fmt.Sprintf("渠道「%s」#%d已被启用", channelName, channelId)
content := fmt.Sprintf("渠道「%s」#%d已被启用", channelName, channelId) content := fmt.Sprintf("渠道「%s」#%d已被启用", channelName, channelId)

62
monitor/manage.go Normal file
View File

@@ -0,0 +1,62 @@
package monitor
import (
"github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/relay/model"
"net/http"
"strings"
)
func ShouldDisableChannel(err *model.Error, statusCode int) bool {
if !config.AutomaticDisableChannelEnabled {
return false
}
if err == nil {
return false
}
if statusCode == http.StatusUnauthorized {
return true
}
switch err.Type {
case "insufficient_quota":
return true
// https://docs.anthropic.com/claude/reference/errors
case "authentication_error":
return true
case "permission_error":
return true
case "forbidden":
return true
}
if err.Code == "invalid_api_key" || err.Code == "account_deactivated" {
return true
}
if strings.HasPrefix(err.Message, "Your credit balance is too low") { // anthropic
return true
} else if strings.HasPrefix(err.Message, "This organization has been disabled.") {
return true
}
//if strings.Contains(err.Message, "quota") {
// return true
//}
if strings.Contains(err.Message, "credit") {
return true
}
if strings.Contains(err.Message, "balance") {
return true
}
return false
}
func ShouldEnableChannel(err error, openAIErr *model.Error) bool {
if !config.AutomaticEnableChannelEnabled {
return false
}
if err != nil {
return false
}
if openAIErr != nil {
return false
}
return true
}

45
relay/adaptor.go Normal file
View File

@@ -0,0 +1,45 @@
package relay
import (
"github.com/songquanpeng/one-api/relay/adaptor"
"github.com/songquanpeng/one-api/relay/adaptor/aiproxy"
"github.com/songquanpeng/one-api/relay/adaptor/ali"
"github.com/songquanpeng/one-api/relay/adaptor/anthropic"
"github.com/songquanpeng/one-api/relay/adaptor/baidu"
"github.com/songquanpeng/one-api/relay/adaptor/gemini"
"github.com/songquanpeng/one-api/relay/adaptor/ollama"
"github.com/songquanpeng/one-api/relay/adaptor/openai"
"github.com/songquanpeng/one-api/relay/adaptor/palm"
"github.com/songquanpeng/one-api/relay/adaptor/tencent"
"github.com/songquanpeng/one-api/relay/adaptor/xunfei"
"github.com/songquanpeng/one-api/relay/adaptor/zhipu"
"github.com/songquanpeng/one-api/relay/apitype"
)
func GetAdaptor(apiType int) adaptor.Adaptor {
switch apiType {
case apitype.AIProxyLibrary:
return &aiproxy.Adaptor{}
case apitype.Ali:
return &ali.Adaptor{}
case apitype.Anthropic:
return &anthropic.Adaptor{}
case apitype.Baidu:
return &baidu.Adaptor{}
case apitype.Gemini:
return &gemini.Adaptor{}
case apitype.OpenAI:
return &openai.Adaptor{}
case apitype.PaLM:
return &palm.Adaptor{}
case apitype.Tencent:
return &tencent.Adaptor{}
case apitype.Xunfei:
return &xunfei.Adaptor{}
case apitype.Zhipu:
return &zhipu.Adaptor{}
case apitype.Ollama:
return &ollama.Adaptor{}
}
return nil
}

View File

@@ -4,10 +4,10 @@ import (
"errors" "errors"
"fmt" "fmt"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/relay/channel" "github.com/songquanpeng/one-api/relay/adaptor"
"github.com/songquanpeng/one-api/relay/meta"
"github.com/songquanpeng/one-api/relay/model" "github.com/songquanpeng/one-api/relay/model"
"github.com/songquanpeng/one-api/relay/util"
"io" "io"
"net/http" "net/http"
) )
@@ -15,16 +15,16 @@ import (
type Adaptor struct { type Adaptor struct {
} }
func (a *Adaptor) Init(meta *util.RelayMeta) { func (a *Adaptor) Init(meta *meta.Meta) {
} }
func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) { func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) {
return fmt.Sprintf("%s/api/library/ask", meta.BaseURL), nil return fmt.Sprintf("%s/api/library/ask", meta.BaseURL), nil
} }
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *util.RelayMeta) error { func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error {
channel.SetupCommonRequestHeader(c, req, meta) adaptor.SetupCommonRequestHeader(c, req, meta)
req.Header.Set("Authorization", "Bearer "+meta.APIKey) req.Header.Set("Authorization", "Bearer "+meta.APIKey)
return nil return nil
} }
@@ -34,7 +34,7 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G
return nil, errors.New("request is nil") return nil, errors.New("request is nil")
} }
aiProxyLibraryRequest := ConvertRequest(*request) aiProxyLibraryRequest := ConvertRequest(*request)
aiProxyLibraryRequest.LibraryId = c.GetString(common.ConfigKeyLibraryID) aiProxyLibraryRequest.LibraryId = c.GetString(config.KeyLibraryID)
return aiProxyLibraryRequest, nil return aiProxyLibraryRequest, nil
} }
@@ -45,11 +45,11 @@ func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error)
return request, nil return request, nil
} }
func (a *Adaptor) DoRequest(c *gin.Context, meta *util.RelayMeta, requestBody io.Reader) (*http.Response, error) { func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) {
return channel.DoRequestHelper(a, c, meta, requestBody) return adaptor.DoRequestHelper(a, c, meta, requestBody)
} }
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.RelayMeta) (usage *model.Usage, err *model.ErrorWithStatusCode) { func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
if meta.IsStream { if meta.IsStream {
err, usage = StreamHandler(c, resp) err, usage = StreamHandler(c, resp)
} else { } else {

View File

@@ -1,6 +1,6 @@
package aiproxy package aiproxy
import "github.com/songquanpeng/one-api/relay/channel/openai" import "github.com/songquanpeng/one-api/relay/adaptor/openai"
var ModelList = []string{""} var ModelList = []string{""}

View File

@@ -8,7 +8,8 @@ import (
"github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/helper" "github.com/songquanpeng/one-api/common/helper"
"github.com/songquanpeng/one-api/common/logger" "github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/relay/channel/openai" "github.com/songquanpeng/one-api/common/random"
"github.com/songquanpeng/one-api/relay/adaptor/openai"
"github.com/songquanpeng/one-api/relay/constant" "github.com/songquanpeng/one-api/relay/constant"
"github.com/songquanpeng/one-api/relay/model" "github.com/songquanpeng/one-api/relay/model"
"io" "io"
@@ -53,7 +54,7 @@ func responseAIProxyLibrary2OpenAI(response *LibraryResponse) *openai.TextRespon
FinishReason: "stop", FinishReason: "stop",
} }
fullTextResponse := openai.TextResponse{ fullTextResponse := openai.TextResponse{
Id: fmt.Sprintf("chatcmpl-%s", helper.GetUUID()), Id: fmt.Sprintf("chatcmpl-%s", random.GetUUID()),
Object: "chat.completion", Object: "chat.completion",
Created: helper.GetTimestamp(), Created: helper.GetTimestamp(),
Choices: []openai.TextResponseChoice{choice}, Choices: []openai.TextResponseChoice{choice},
@@ -66,7 +67,7 @@ func documentsAIProxyLibrary(documents []LibraryDocument) *openai.ChatCompletion
choice.Delta.Content = aiProxyDocuments2Markdown(documents) choice.Delta.Content = aiProxyDocuments2Markdown(documents)
choice.FinishReason = &constant.StopFinishReason choice.FinishReason = &constant.StopFinishReason
return &openai.ChatCompletionsStreamResponse{ return &openai.ChatCompletionsStreamResponse{
Id: fmt.Sprintf("chatcmpl-%s", helper.GetUUID()), Id: fmt.Sprintf("chatcmpl-%s", random.GetUUID()),
Object: "chat.completion.chunk", Object: "chat.completion.chunk",
Created: helper.GetTimestamp(), Created: helper.GetTimestamp(),
Model: "", Model: "",
@@ -78,7 +79,7 @@ func streamResponseAIProxyLibrary2OpenAI(response *LibraryStreamResponse) *opena
var choice openai.ChatCompletionsStreamResponseChoice var choice openai.ChatCompletionsStreamResponseChoice
choice.Delta.Content = response.Content choice.Delta.Content = response.Content
return &openai.ChatCompletionsStreamResponse{ return &openai.ChatCompletionsStreamResponse{
Id: fmt.Sprintf("chatcmpl-%s", helper.GetUUID()), Id: fmt.Sprintf("chatcmpl-%s", random.GetUUID()),
Object: "chat.completion.chunk", Object: "chat.completion.chunk",
Created: helper.GetTimestamp(), Created: helper.GetTimestamp(),
Model: response.Model, Model: response.Model,

View File

@@ -4,11 +4,11 @@ import (
"errors" "errors"
"fmt" "fmt"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/relay/channel" "github.com/songquanpeng/one-api/relay/adaptor"
"github.com/songquanpeng/one-api/relay/constant" "github.com/songquanpeng/one-api/relay/meta"
"github.com/songquanpeng/one-api/relay/model" "github.com/songquanpeng/one-api/relay/model"
"github.com/songquanpeng/one-api/relay/util" "github.com/songquanpeng/one-api/relay/relaymode"
"io" "io"
"net/http" "net/http"
) )
@@ -18,16 +18,16 @@ import (
type Adaptor struct { type Adaptor struct {
} }
func (a *Adaptor) Init(meta *util.RelayMeta) { func (a *Adaptor) Init(meta *meta.Meta) {
} }
func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) { func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) {
fullRequestURL := "" fullRequestURL := ""
switch meta.Mode { switch meta.Mode {
case constant.RelayModeEmbeddings: case relaymode.Embeddings:
fullRequestURL = fmt.Sprintf("%s/api/v1/services/embeddings/text-embedding/text-embedding", meta.BaseURL) fullRequestURL = fmt.Sprintf("%s/api/v1/services/embeddings/text-embedding/text-embedding", meta.BaseURL)
case constant.RelayModeImagesGenerations: case relaymode.ImagesGenerations:
fullRequestURL = fmt.Sprintf("%s/api/v1/services/aigc/text2image/image-synthesis", meta.BaseURL) fullRequestURL = fmt.Sprintf("%s/api/v1/services/aigc/text2image/image-synthesis", meta.BaseURL)
default: default:
fullRequestURL = fmt.Sprintf("%s/api/v1/services/aigc/text-generation/generation", meta.BaseURL) fullRequestURL = fmt.Sprintf("%s/api/v1/services/aigc/text-generation/generation", meta.BaseURL)
@@ -36,19 +36,19 @@ func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) {
return fullRequestURL, nil return fullRequestURL, nil
} }
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *util.RelayMeta) error { func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error {
channel.SetupCommonRequestHeader(c, req, meta) adaptor.SetupCommonRequestHeader(c, req, meta)
if meta.IsStream { if meta.IsStream {
req.Header.Set("Accept", "text/event-stream") req.Header.Set("Accept", "text/event-stream")
req.Header.Set("X-DashScope-SSE", "enable") req.Header.Set("X-DashScope-SSE", "enable")
} }
req.Header.Set("Authorization", "Bearer "+meta.APIKey) req.Header.Set("Authorization", "Bearer "+meta.APIKey)
if meta.Mode == constant.RelayModeImagesGenerations { if meta.Mode == relaymode.ImagesGenerations {
req.Header.Set("X-DashScope-Async", "enable") req.Header.Set("X-DashScope-Async", "enable")
} }
if c.GetString(common.ConfigKeyPlugin) != "" { if c.GetString(config.KeyPlugin) != "" {
req.Header.Set("X-DashScope-Plugin", c.GetString(common.ConfigKeyPlugin)) req.Header.Set("X-DashScope-Plugin", c.GetString(config.KeyPlugin))
} }
return nil return nil
} }
@@ -58,7 +58,7 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G
return nil, errors.New("request is nil") return nil, errors.New("request is nil")
} }
switch relayMode { switch relayMode {
case constant.RelayModeEmbeddings: case relaymode.Embeddings:
aliEmbeddingRequest := ConvertEmbeddingRequest(*request) aliEmbeddingRequest := ConvertEmbeddingRequest(*request)
return aliEmbeddingRequest, nil return aliEmbeddingRequest, nil
default: default:
@@ -76,18 +76,18 @@ func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error)
return aliRequest, nil return aliRequest, nil
} }
func (a *Adaptor) DoRequest(c *gin.Context, meta *util.RelayMeta, requestBody io.Reader) (*http.Response, error) { func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) {
return channel.DoRequestHelper(a, c, meta, requestBody) return adaptor.DoRequestHelper(a, c, meta, requestBody)
} }
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.RelayMeta) (usage *model.Usage, err *model.ErrorWithStatusCode) { func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
if meta.IsStream { if meta.IsStream {
err, usage = StreamHandler(c, resp) err, usage = StreamHandler(c, resp)
} else { } else {
switch meta.Mode { switch meta.Mode {
case constant.RelayModeEmbeddings: case relaymode.Embeddings:
err, usage = EmbeddingHandler(c, resp) err, usage = EmbeddingHandler(c, resp)
case constant.RelayModeImagesGenerations: case relaymode.ImagesGenerations:
err, usage = ImageHandler(c, resp) err, usage = ImageHandler(c, resp)
default: default:
err, usage = Handler(c, resp) err, usage = Handler(c, resp)

View File

@@ -8,7 +8,7 @@ import (
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common/helper" "github.com/songquanpeng/one-api/common/helper"
"github.com/songquanpeng/one-api/common/logger" "github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/relay/channel/openai" "github.com/songquanpeng/one-api/relay/adaptor/openai"
"github.com/songquanpeng/one-api/relay/model" "github.com/songquanpeng/one-api/relay/model"
"io" "io"
"net/http" "net/http"

View File

@@ -7,7 +7,7 @@ import (
"github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/helper" "github.com/songquanpeng/one-api/common/helper"
"github.com/songquanpeng/one-api/common/logger" "github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/relay/channel/openai" "github.com/songquanpeng/one-api/relay/adaptor/openai"
"github.com/songquanpeng/one-api/relay/model" "github.com/songquanpeng/one-api/relay/model"
"io" "io"
"net/http" "net/http"

View File

@@ -1,7 +1,7 @@
package ali package ali
import ( import (
"github.com/songquanpeng/one-api/relay/channel/openai" "github.com/songquanpeng/one-api/relay/adaptor/openai"
"github.com/songquanpeng/one-api/relay/model" "github.com/songquanpeng/one-api/relay/model"
) )

View File

@@ -4,9 +4,9 @@ import (
"errors" "errors"
"fmt" "fmt"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/relay/channel" "github.com/songquanpeng/one-api/relay/adaptor"
"github.com/songquanpeng/one-api/relay/meta"
"github.com/songquanpeng/one-api/relay/model" "github.com/songquanpeng/one-api/relay/model"
"github.com/songquanpeng/one-api/relay/util"
"io" "io"
"net/http" "net/http"
) )
@@ -14,16 +14,16 @@ import (
type Adaptor struct { type Adaptor struct {
} }
func (a *Adaptor) Init(meta *util.RelayMeta) { func (a *Adaptor) Init(meta *meta.Meta) {
} }
func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) { func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) {
return fmt.Sprintf("%s/v1/messages", meta.BaseURL), nil return fmt.Sprintf("%s/v1/messages", meta.BaseURL), nil
} }
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *util.RelayMeta) error { func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error {
channel.SetupCommonRequestHeader(c, req, meta) adaptor.SetupCommonRequestHeader(c, req, meta)
req.Header.Set("x-api-key", meta.APIKey) req.Header.Set("x-api-key", meta.APIKey)
anthropicVersion := c.Request.Header.Get("anthropic-version") anthropicVersion := c.Request.Header.Get("anthropic-version")
if anthropicVersion == "" { if anthropicVersion == "" {
@@ -48,11 +48,11 @@ func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error)
return request, nil return request, nil
} }
func (a *Adaptor) DoRequest(c *gin.Context, meta *util.RelayMeta, requestBody io.Reader) (*http.Response, error) { func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) {
return channel.DoRequestHelper(a, c, meta, requestBody) return adaptor.DoRequestHelper(a, c, meta, requestBody)
} }
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.RelayMeta) (usage *model.Usage, err *model.ErrorWithStatusCode) { func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
if meta.IsStream { if meta.IsStream {
err, usage = StreamHandler(c, resp) err, usage = StreamHandler(c, resp)
} else { } else {

View File

@@ -9,7 +9,7 @@ import (
"github.com/songquanpeng/one-api/common/helper" "github.com/songquanpeng/one-api/common/helper"
"github.com/songquanpeng/one-api/common/image" "github.com/songquanpeng/one-api/common/image"
"github.com/songquanpeng/one-api/common/logger" "github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/relay/channel/openai" "github.com/songquanpeng/one-api/relay/adaptor/openai"
"github.com/songquanpeng/one-api/relay/model" "github.com/songquanpeng/one-api/relay/model"
"io" "io"
"net/http" "net/http"

View File

@@ -0,0 +1,15 @@
package azure
import (
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common/config"
)
func GetAPIVersion(c *gin.Context) string {
query := c.Request.URL.Query()
apiVersion := query.Get("api-version")
if apiVersion == "" {
apiVersion = c.GetString(config.KeyAPIVersion)
}
return apiVersion
}

View File

@@ -3,25 +3,25 @@ package baidu
import ( import (
"errors" "errors"
"fmt" "fmt"
"github.com/songquanpeng/one-api/relay/meta"
"github.com/songquanpeng/one-api/relay/relaymode"
"io" "io"
"net/http" "net/http"
"strings" "strings"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/relay/channel" "github.com/songquanpeng/one-api/relay/adaptor"
"github.com/songquanpeng/one-api/relay/constant"
"github.com/songquanpeng/one-api/relay/model" "github.com/songquanpeng/one-api/relay/model"
"github.com/songquanpeng/one-api/relay/util"
) )
type Adaptor struct { type Adaptor struct {
} }
func (a *Adaptor) Init(meta *util.RelayMeta) { func (a *Adaptor) Init(meta *meta.Meta) {
} }
func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) { func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) {
// https://cloud.baidu.com/doc/WENXINWORKSHOP/s/clntwmv7t // https://cloud.baidu.com/doc/WENXINWORKSHOP/s/clntwmv7t
suffix := "chat/" suffix := "chat/"
if strings.HasPrefix(meta.ActualModelName, "Embedding") { if strings.HasPrefix(meta.ActualModelName, "Embedding") {
@@ -89,8 +89,8 @@ func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) {
return fullRequestURL, nil return fullRequestURL, nil
} }
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *util.RelayMeta) error { func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error {
channel.SetupCommonRequestHeader(c, req, meta) adaptor.SetupCommonRequestHeader(c, req, meta)
req.Header.Set("Authorization", "Bearer "+meta.APIKey) req.Header.Set("Authorization", "Bearer "+meta.APIKey)
return nil return nil
} }
@@ -100,7 +100,7 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G
return nil, errors.New("request is nil") return nil, errors.New("request is nil")
} }
switch relayMode { switch relayMode {
case constant.RelayModeEmbeddings: case relaymode.Embeddings:
baiduEmbeddingRequest := ConvertEmbeddingRequest(*request) baiduEmbeddingRequest := ConvertEmbeddingRequest(*request)
return baiduEmbeddingRequest, nil return baiduEmbeddingRequest, nil
default: default:
@@ -116,16 +116,16 @@ func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error)
return request, nil return request, nil
} }
func (a *Adaptor) DoRequest(c *gin.Context, meta *util.RelayMeta, requestBody io.Reader) (*http.Response, error) { func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) {
return channel.DoRequestHelper(a, c, meta, requestBody) return adaptor.DoRequestHelper(a, c, meta, requestBody)
} }
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.RelayMeta) (usage *model.Usage, err *model.ErrorWithStatusCode) { func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
if meta.IsStream { if meta.IsStream {
err, usage = StreamHandler(c, resp) err, usage = StreamHandler(c, resp)
} else { } else {
switch meta.Mode { switch meta.Mode {
case constant.RelayModeEmbeddings: case relaymode.Embeddings:
err, usage = EmbeddingHandler(c, resp) err, usage = EmbeddingHandler(c, resp)
default: default:
err, usage = Handler(c, resp) err, usage = Handler(c, resp)

View File

@@ -8,10 +8,10 @@ import (
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/logger" "github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/relay/channel/openai" "github.com/songquanpeng/one-api/relay/adaptor/openai"
"github.com/songquanpeng/one-api/relay/client"
"github.com/songquanpeng/one-api/relay/constant" "github.com/songquanpeng/one-api/relay/constant"
"github.com/songquanpeng/one-api/relay/model" "github.com/songquanpeng/one-api/relay/model"
"github.com/songquanpeng/one-api/relay/util"
"io" "io"
"net/http" "net/http"
"strings" "strings"
@@ -305,7 +305,7 @@ func getBaiduAccessTokenHelper(apiKey string) (*AccessToken, error) {
} }
req.Header.Add("Content-Type", "application/json") req.Header.Add("Content-Type", "application/json")
req.Header.Add("Accept", "application/json") req.Header.Add("Accept", "application/json")
res, err := util.ImpatientHTTPClient.Do(req) res, err := client.ImpatientHTTPClient.Do(req)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@@ -1,15 +1,16 @@
package channel package adaptor
import ( import (
"errors" "errors"
"fmt" "fmt"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/relay/util" "github.com/songquanpeng/one-api/relay/client"
"github.com/songquanpeng/one-api/relay/meta"
"io" "io"
"net/http" "net/http"
) )
func SetupCommonRequestHeader(c *gin.Context, req *http.Request, meta *util.RelayMeta) { func SetupCommonRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) {
req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type")) req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
req.Header.Set("Accept", c.Request.Header.Get("Accept")) req.Header.Set("Accept", c.Request.Header.Get("Accept"))
if meta.IsStream && c.Request.Header.Get("Accept") == "" { if meta.IsStream && c.Request.Header.Get("Accept") == "" {
@@ -17,7 +18,7 @@ func SetupCommonRequestHeader(c *gin.Context, req *http.Request, meta *util.Rela
} }
} }
func DoRequestHelper(a Adaptor, c *gin.Context, meta *util.RelayMeta, requestBody io.Reader) (*http.Response, error) { func DoRequestHelper(a Adaptor, c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) {
fullRequestURL, err := a.GetRequestURL(meta) fullRequestURL, err := a.GetRequestURL(meta)
if err != nil { if err != nil {
return nil, fmt.Errorf("get request url failed: %w", err) return nil, fmt.Errorf("get request url failed: %w", err)
@@ -38,7 +39,7 @@ func DoRequestHelper(a Adaptor, c *gin.Context, meta *util.RelayMeta, requestBod
} }
func DoRequest(c *gin.Context, req *http.Request) (*http.Response, error) { func DoRequest(c *gin.Context, req *http.Request) (*http.Response, error) {
resp, err := util.HTTPClient.Do(req) resp, err := client.HTTPClient.Do(req)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@@ -5,10 +5,10 @@ import (
"fmt" "fmt"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common/helper" "github.com/songquanpeng/one-api/common/helper"
channelhelper "github.com/songquanpeng/one-api/relay/channel" channelhelper "github.com/songquanpeng/one-api/relay/adaptor"
"github.com/songquanpeng/one-api/relay/channel/openai" "github.com/songquanpeng/one-api/relay/adaptor/openai"
"github.com/songquanpeng/one-api/relay/meta"
"github.com/songquanpeng/one-api/relay/model" "github.com/songquanpeng/one-api/relay/model"
"github.com/songquanpeng/one-api/relay/util"
"io" "io"
"net/http" "net/http"
) )
@@ -16,11 +16,11 @@ import (
type Adaptor struct { type Adaptor struct {
} }
func (a *Adaptor) Init(meta *util.RelayMeta) { func (a *Adaptor) Init(meta *meta.Meta) {
} }
func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) { func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) {
version := helper.AssignOrDefault(meta.APIVersion, "v1") version := helper.AssignOrDefault(meta.APIVersion, "v1")
action := "generateContent" action := "generateContent"
if meta.IsStream { if meta.IsStream {
@@ -29,7 +29,7 @@ func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) {
return fmt.Sprintf("%s/%s/models/%s:%s", meta.BaseURL, version, meta.ActualModelName, action), nil return fmt.Sprintf("%s/%s/models/%s:%s", meta.BaseURL, version, meta.ActualModelName, action), nil
} }
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *util.RelayMeta) error { func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error {
channelhelper.SetupCommonRequestHeader(c, req, meta) channelhelper.SetupCommonRequestHeader(c, req, meta)
req.Header.Set("x-goog-api-key", meta.APIKey) req.Header.Set("x-goog-api-key", meta.APIKey)
return nil return nil
@@ -49,11 +49,11 @@ func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error)
return request, nil return request, nil
} }
func (a *Adaptor) DoRequest(c *gin.Context, meta *util.RelayMeta, requestBody io.Reader) (*http.Response, error) { func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) {
return channelhelper.DoRequestHelper(a, c, meta, requestBody) return channelhelper.DoRequestHelper(a, c, meta, requestBody)
} }
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.RelayMeta) (usage *model.Usage, err *model.ErrorWithStatusCode) { func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
if meta.IsStream { if meta.IsStream {
var responseText string var responseText string
err, responseText = StreamHandler(c, resp) err, responseText = StreamHandler(c, resp)

View File

@@ -9,7 +9,8 @@ import (
"github.com/songquanpeng/one-api/common/helper" "github.com/songquanpeng/one-api/common/helper"
"github.com/songquanpeng/one-api/common/image" "github.com/songquanpeng/one-api/common/image"
"github.com/songquanpeng/one-api/common/logger" "github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/relay/channel/openai" "github.com/songquanpeng/one-api/common/random"
"github.com/songquanpeng/one-api/relay/adaptor/openai"
"github.com/songquanpeng/one-api/relay/constant" "github.com/songquanpeng/one-api/relay/constant"
"github.com/songquanpeng/one-api/relay/model" "github.com/songquanpeng/one-api/relay/model"
"io" "io"
@@ -155,7 +156,7 @@ type ChatPromptFeedback struct {
func responseGeminiChat2OpenAI(response *ChatResponse) *openai.TextResponse { func responseGeminiChat2OpenAI(response *ChatResponse) *openai.TextResponse {
fullTextResponse := openai.TextResponse{ fullTextResponse := openai.TextResponse{
Id: fmt.Sprintf("chatcmpl-%s", helper.GetUUID()), Id: fmt.Sprintf("chatcmpl-%s", random.GetUUID()),
Object: "chat.completion", Object: "chat.completion",
Created: helper.GetTimestamp(), Created: helper.GetTimestamp(),
Choices: make([]openai.TextResponseChoice, 0, len(response.Candidates)), Choices: make([]openai.TextResponseChoice, 0, len(response.Candidates)),
@@ -233,7 +234,7 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC
var choice openai.ChatCompletionsStreamResponseChoice var choice openai.ChatCompletionsStreamResponseChoice
choice.Delta.Content = dummy.Content choice.Delta.Content = dummy.Content
response := openai.ChatCompletionsStreamResponse{ response := openai.ChatCompletionsStreamResponse{
Id: fmt.Sprintf("chatcmpl-%s", helper.GetUUID()), Id: fmt.Sprintf("chatcmpl-%s", random.GetUUID()),
Object: "chat.completion.chunk", Object: "chat.completion.chunk",
Created: helper.GetTimestamp(), Created: helper.GetTimestamp(),
Model: "gemini-pro", Model: "gemini-pro",

View File

@@ -0,0 +1,21 @@
package adaptor
import (
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/relay/meta"
"github.com/songquanpeng/one-api/relay/model"
"io"
"net/http"
)
type Adaptor interface {
Init(meta *meta.Meta)
GetRequestURL(meta *meta.Meta) (string, error)
SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error
ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error)
ConvertImageRequest(request *model.ImageRequest) (any, error)
DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error)
DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode)
GetModelList() []string
GetChannelName() string
}

View File

@@ -0,0 +1,14 @@
package minimax
import (
"fmt"
"github.com/songquanpeng/one-api/relay/meta"
"github.com/songquanpeng/one-api/relay/relaymode"
)
func GetRequestURL(meta *meta.Meta) (string, error) {
if meta.Mode == relaymode.ChatCompletions {
return fmt.Sprintf("%s/v1/text/chatcompletion_v2", meta.BaseURL), nil
}
return "", fmt.Errorf("unsupported relay mode %d for minimax", meta.Mode)
}

View File

@@ -3,34 +3,34 @@ package ollama
import ( import (
"errors" "errors"
"fmt" "fmt"
"github.com/songquanpeng/one-api/relay/meta"
"github.com/songquanpeng/one-api/relay/relaymode"
"io" "io"
"net/http" "net/http"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/relay/channel" "github.com/songquanpeng/one-api/relay/adaptor"
"github.com/songquanpeng/one-api/relay/constant"
"github.com/songquanpeng/one-api/relay/model" "github.com/songquanpeng/one-api/relay/model"
"github.com/songquanpeng/one-api/relay/util"
) )
type Adaptor struct { type Adaptor struct {
} }
func (a *Adaptor) Init(meta *util.RelayMeta) { func (a *Adaptor) Init(meta *meta.Meta) {
} }
func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) { func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) {
// https://github.com/ollama/ollama/blob/main/docs/api.md // https://github.com/ollama/ollama/blob/main/docs/api.md
fullRequestURL := fmt.Sprintf("%s/api/chat", meta.BaseURL) fullRequestURL := fmt.Sprintf("%s/api/chat", meta.BaseURL)
if meta.Mode == constant.RelayModeEmbeddings { if meta.Mode == relaymode.Embeddings {
fullRequestURL = fmt.Sprintf("%s/api/embeddings", meta.BaseURL) fullRequestURL = fmt.Sprintf("%s/api/embeddings", meta.BaseURL)
} }
return fullRequestURL, nil return fullRequestURL, nil
} }
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *util.RelayMeta) error { func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error {
channel.SetupCommonRequestHeader(c, req, meta) adaptor.SetupCommonRequestHeader(c, req, meta)
req.Header.Set("Authorization", "Bearer "+meta.APIKey) req.Header.Set("Authorization", "Bearer "+meta.APIKey)
return nil return nil
} }
@@ -40,7 +40,7 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G
return nil, errors.New("request is nil") return nil, errors.New("request is nil")
} }
switch relayMode { switch relayMode {
case constant.RelayModeEmbeddings: case relaymode.Embeddings:
ollamaEmbeddingRequest := ConvertEmbeddingRequest(*request) ollamaEmbeddingRequest := ConvertEmbeddingRequest(*request)
return ollamaEmbeddingRequest, nil return ollamaEmbeddingRequest, nil
default: default:
@@ -55,16 +55,16 @@ func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error)
return request, nil return request, nil
} }
func (a *Adaptor) DoRequest(c *gin.Context, meta *util.RelayMeta, requestBody io.Reader) (*http.Response, error) { func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) {
return channel.DoRequestHelper(a, c, meta, requestBody) return adaptor.DoRequestHelper(a, c, meta, requestBody)
} }
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.RelayMeta) (usage *model.Usage, err *model.ErrorWithStatusCode) { func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
if meta.IsStream { if meta.IsStream {
err, usage = StreamHandler(c, resp) err, usage = StreamHandler(c, resp)
} else { } else {
switch meta.Mode { switch meta.Mode {
case constant.RelayModeEmbeddings: case relaymode.Embeddings:
err, usage = EmbeddingHandler(c, resp) err, usage = EmbeddingHandler(c, resp)
default: default:
err, usage = Handler(c, resp) err, usage = Handler(c, resp)

View File

@@ -5,15 +5,16 @@ import (
"context" "context"
"encoding/json" "encoding/json"
"fmt" "fmt"
"github.com/songquanpeng/one-api/common/helper"
"github.com/songquanpeng/one-api/common/random"
"io" "io"
"net/http" "net/http"
"strings" "strings"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/helper"
"github.com/songquanpeng/one-api/common/logger" "github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/relay/channel/openai" "github.com/songquanpeng/one-api/relay/adaptor/openai"
"github.com/songquanpeng/one-api/relay/constant" "github.com/songquanpeng/one-api/relay/constant"
"github.com/songquanpeng/one-api/relay/model" "github.com/songquanpeng/one-api/relay/model"
) )
@@ -51,7 +52,7 @@ func responseOllama2OpenAI(response *ChatResponse) *openai.TextResponse {
choice.FinishReason = "stop" choice.FinishReason = "stop"
} }
fullTextResponse := openai.TextResponse{ fullTextResponse := openai.TextResponse{
Id: fmt.Sprintf("chatcmpl-%s", helper.GetUUID()), Id: fmt.Sprintf("chatcmpl-%s", random.GetUUID()),
Object: "chat.completion", Object: "chat.completion",
Created: helper.GetTimestamp(), Created: helper.GetTimestamp(),
Choices: []openai.TextResponseChoice{choice}, Choices: []openai.TextResponseChoice{choice},
@@ -72,7 +73,7 @@ func streamResponseOllama2OpenAI(ollamaResponse *ChatResponse) *openai.ChatCompl
choice.FinishReason = &constant.StopFinishReason choice.FinishReason = &constant.StopFinishReason
} }
response := openai.ChatCompletionsStreamResponse{ response := openai.ChatCompletionsStreamResponse{
Id: fmt.Sprintf("chatcmpl-%s", helper.GetUUID()), Id: fmt.Sprintf("chatcmpl-%s", random.GetUUID()),
Object: "chat.completion.chunk", Object: "chat.completion.chunk",
Created: helper.GetTimestamp(), Created: helper.GetTimestamp(),
Model: ollamaResponse.Model, Model: ollamaResponse.Model,

View File

@@ -4,12 +4,12 @@ import (
"errors" "errors"
"fmt" "fmt"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/relay/adaptor"
"github.com/songquanpeng/one-api/relay/channel" "github.com/songquanpeng/one-api/relay/adaptor/minimax"
"github.com/songquanpeng/one-api/relay/channel/minimax" "github.com/songquanpeng/one-api/relay/channeltype"
"github.com/songquanpeng/one-api/relay/constant" "github.com/songquanpeng/one-api/relay/meta"
"github.com/songquanpeng/one-api/relay/model" "github.com/songquanpeng/one-api/relay/model"
"github.com/songquanpeng/one-api/relay/util" "github.com/songquanpeng/one-api/relay/relaymode"
"io" "io"
"net/http" "net/http"
"strings" "strings"
@@ -19,14 +19,14 @@ type Adaptor struct {
ChannelType int ChannelType int
} }
func (a *Adaptor) Init(meta *util.RelayMeta) { func (a *Adaptor) Init(meta *meta.Meta) {
a.ChannelType = meta.ChannelType a.ChannelType = meta.ChannelType
} }
func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) { func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) {
switch meta.ChannelType { switch meta.ChannelType {
case common.ChannelTypeAzure: case channeltype.Azure:
if meta.Mode == constant.RelayModeImagesGenerations { if meta.Mode == relaymode.ImagesGenerations {
// https://learn.microsoft.com/en-us/azure/ai-services/openai/dall-e-quickstart?tabs=dalle3%2Ccommand-line&pivots=rest-api // https://learn.microsoft.com/en-us/azure/ai-services/openai/dall-e-quickstart?tabs=dalle3%2Ccommand-line&pivots=rest-api
// https://{resource_name}.openai.azure.com/openai/deployments/dall-e-3/images/generations?api-version=2024-03-01-preview // https://{resource_name}.openai.azure.com/openai/deployments/dall-e-3/images/generations?api-version=2024-03-01-preview
fullRequestURL := fmt.Sprintf("%s/openai/deployments/%s/images/generations?api-version=%s", meta.BaseURL, meta.ActualModelName, meta.APIVersion) fullRequestURL := fmt.Sprintf("%s/openai/deployments/%s/images/generations?api-version=%s", meta.BaseURL, meta.ActualModelName, meta.APIVersion)
@@ -42,22 +42,22 @@ func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) {
//https://github.com/songquanpeng/one-api/issues/1191 //https://github.com/songquanpeng/one-api/issues/1191
// {your endpoint}/openai/deployments/{your azure_model}/chat/completions?api-version={api_version} // {your endpoint}/openai/deployments/{your azure_model}/chat/completions?api-version={api_version}
requestURL = fmt.Sprintf("/openai/deployments/%s/%s", model_, task) requestURL = fmt.Sprintf("/openai/deployments/%s/%s", model_, task)
return util.GetFullRequestURL(meta.BaseURL, requestURL, meta.ChannelType), nil return GetFullRequestURL(meta.BaseURL, requestURL, meta.ChannelType), nil
case common.ChannelTypeMinimax: case channeltype.Minimax:
return minimax.GetRequestURL(meta) return minimax.GetRequestURL(meta)
default: default:
return util.GetFullRequestURL(meta.BaseURL, meta.RequestURLPath, meta.ChannelType), nil return GetFullRequestURL(meta.BaseURL, meta.RequestURLPath, meta.ChannelType), nil
} }
} }
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *util.RelayMeta) error { func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error {
channel.SetupCommonRequestHeader(c, req, meta) adaptor.SetupCommonRequestHeader(c, req, meta)
if meta.ChannelType == common.ChannelTypeAzure { if meta.ChannelType == channeltype.Azure {
req.Header.Set("api-key", meta.APIKey) req.Header.Set("api-key", meta.APIKey)
return nil return nil
} }
req.Header.Set("Authorization", "Bearer "+meta.APIKey) req.Header.Set("Authorization", "Bearer "+meta.APIKey)
if meta.ChannelType == common.ChannelTypeOpenRouter { if meta.ChannelType == channeltype.OpenRouter {
req.Header.Set("HTTP-Referer", "https://github.com/songquanpeng/one-api") req.Header.Set("HTTP-Referer", "https://github.com/songquanpeng/one-api")
req.Header.Set("X-Title", "One API") req.Header.Set("X-Title", "One API")
} }
@@ -78,11 +78,11 @@ func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error)
return request, nil return request, nil
} }
func (a *Adaptor) DoRequest(c *gin.Context, meta *util.RelayMeta, requestBody io.Reader) (*http.Response, error) { func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) {
return channel.DoRequestHelper(a, c, meta, requestBody) return adaptor.DoRequestHelper(a, c, meta, requestBody)
} }
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.RelayMeta) (usage *model.Usage, err *model.ErrorWithStatusCode) { func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
if meta.IsStream { if meta.IsStream {
var responseText string var responseText string
err, responseText, usage = StreamHandler(c, resp, meta.Mode) err, responseText, usage = StreamHandler(c, resp, meta.Mode)
@@ -91,7 +91,7 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.Rel
} }
} else { } else {
switch meta.Mode { switch meta.Mode {
case constant.RelayModeImagesGenerations: case relaymode.ImagesGenerations:
err, _ = ImageHandler(c, resp) err, _ = ImageHandler(c, resp)
default: default:
err, usage = Handler(c, resp, meta.PromptTokens, meta.ActualModelName) err, usage = Handler(c, resp, meta.PromptTokens, meta.ActualModelName)

View File

@@ -0,0 +1,50 @@
package openai
import (
"github.com/songquanpeng/one-api/relay/adaptor/ai360"
"github.com/songquanpeng/one-api/relay/adaptor/baichuan"
"github.com/songquanpeng/one-api/relay/adaptor/groq"
"github.com/songquanpeng/one-api/relay/adaptor/lingyiwanwu"
"github.com/songquanpeng/one-api/relay/adaptor/minimax"
"github.com/songquanpeng/one-api/relay/adaptor/mistral"
"github.com/songquanpeng/one-api/relay/adaptor/moonshot"
"github.com/songquanpeng/one-api/relay/adaptor/stepfun"
"github.com/songquanpeng/one-api/relay/channeltype"
)
var CompatibleChannels = []int{
channeltype.Azure,
channeltype.AI360,
channeltype.Moonshot,
channeltype.Baichuan,
channeltype.Minimax,
channeltype.Mistral,
channeltype.Groq,
channeltype.LingYiWanWu,
channeltype.StepFun,
}
func GetCompatibleChannelMeta(channelType int) (string, []string) {
switch channelType {
case channeltype.Azure:
return "azure", ModelList
case channeltype.AI360:
return "360", ai360.ModelList
case channeltype.Moonshot:
return "moonshot", moonshot.ModelList
case channeltype.Baichuan:
return "baichuan", baichuan.ModelList
case channeltype.Minimax:
return "minimax", minimax.ModelList
case channeltype.Mistral:
return "mistralai", mistral.ModelList
case channeltype.Groq:
return "groq", groq.ModelList
case channeltype.LingYiWanWu:
return "lingyiwanwu", lingyiwanwu.ModelList
case channeltype.StepFun:
return "stepfun", stepfun.ModelList
default:
return "openai", ModelList
}
}

View File

@@ -0,0 +1,30 @@
package openai
import (
"fmt"
"github.com/songquanpeng/one-api/relay/channeltype"
"github.com/songquanpeng/one-api/relay/model"
"strings"
)
func ResponseText2Usage(responseText string, modeName string, promptTokens int) *model.Usage {
usage := &model.Usage{}
usage.PromptTokens = promptTokens
usage.CompletionTokens = CountTokenText(responseText, modeName)
usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
return usage
}
func GetFullRequestURL(baseURL string, requestURL string, channelType int) string {
fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL)
if strings.HasPrefix(baseURL, "https://gateway.ai.cloudflare.com") {
switch channelType {
case channeltype.OpenAI:
fullRequestURL = fmt.Sprintf("%s%s", baseURL, strings.TrimPrefix(requestURL, "/v1"))
case channeltype.Azure:
fullRequestURL = fmt.Sprintf("%s%s", baseURL, strings.TrimPrefix(requestURL, "/openai/deployments"))
}
}
return fullRequestURL
}

View File

@@ -0,0 +1,44 @@
package openai
import (
"bytes"
"encoding/json"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/relay/model"
"io"
"net/http"
)
func ImageHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
var imageResponse ImageResponse
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
}
err = resp.Body.Close()
if err != nil {
return ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
err = json.Unmarshal(responseBody, &imageResponse)
if err != nil {
return ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
}
resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
for k, v := range resp.Header {
c.Writer.Header().Set(k, v[0])
}
c.Writer.WriteHeader(resp.StatusCode)
_, err = io.Copy(c.Writer, resp.Body)
if err != nil {
return ErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil
}
err = resp.Body.Close()
if err != nil {
return ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
return nil, nil
}

View File

@@ -8,8 +8,8 @@ import (
"github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/conv" "github.com/songquanpeng/one-api/common/conv"
"github.com/songquanpeng/one-api/common/logger" "github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/relay/constant"
"github.com/songquanpeng/one-api/relay/model" "github.com/songquanpeng/one-api/relay/model"
"github.com/songquanpeng/one-api/relay/relaymode"
"io" "io"
"net/http" "net/http"
"strings" "strings"
@@ -46,7 +46,7 @@ func StreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*model.E
data = data[6:] data = data[6:]
if !strings.HasPrefix(data, "[DONE]") { if !strings.HasPrefix(data, "[DONE]") {
switch relayMode { switch relayMode {
case constant.RelayModeChatCompletions: case relaymode.ChatCompletions:
var streamResponse ChatCompletionsStreamResponse var streamResponse ChatCompletionsStreamResponse
err := json.Unmarshal([]byte(data), &streamResponse) err := json.Unmarshal([]byte(data), &streamResponse)
if err != nil { if err != nil {
@@ -59,7 +59,7 @@ func StreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*model.E
if streamResponse.Usage != nil { if streamResponse.Usage != nil {
usage = streamResponse.Usage usage = streamResponse.Usage
} }
case constant.RelayModeCompletions: case relaymode.Completions:
var streamResponse CompletionsStreamResponse var streamResponse CompletionsStreamResponse
err := json.Unmarshal([]byte(data), &streamResponse) err := json.Unmarshal([]byte(data), &streamResponse)
if err != nil { if err != nil {
@@ -149,37 +149,3 @@ func Handler(c *gin.Context, resp *http.Response, promptTokens int, modelName st
} }
return nil, &textResponse.Usage return nil, &textResponse.Usage
} }
func ImageHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
var imageResponse ImageResponse
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
}
err = resp.Body.Close()
if err != nil {
return ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
err = json.Unmarshal(responseBody, &imageResponse)
if err != nil {
return ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
}
resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
for k, v := range resp.Header {
c.Writer.Header().Set(k, v[0])
}
c.Writer.WriteHeader(resp.StatusCode)
_, err = io.Copy(c.Writer, resp.Body)
if err != nil {
return ErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil
}
err = resp.Body.Close()
if err != nil {
return ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
return nil, nil
}

View File

@@ -4,10 +4,10 @@ import (
"errors" "errors"
"fmt" "fmt"
"github.com/pkoukk/tiktoken-go" "github.com/pkoukk/tiktoken-go"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/config" "github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/image" "github.com/songquanpeng/one-api/common/image"
"github.com/songquanpeng/one-api/common/logger" "github.com/songquanpeng/one-api/common/logger"
billingratio "github.com/songquanpeng/one-api/relay/billing/ratio"
"github.com/songquanpeng/one-api/relay/model" "github.com/songquanpeng/one-api/relay/model"
"math" "math"
"strings" "strings"
@@ -28,7 +28,7 @@ func InitTokenEncoders() {
if err != nil { if err != nil {
logger.FatalLog(fmt.Sprintf("failed to get gpt-4 token encoder: %s", err.Error())) logger.FatalLog(fmt.Sprintf("failed to get gpt-4 token encoder: %s", err.Error()))
} }
for model := range common.ModelRatio { for model := range billingratio.ModelRatio {
if strings.HasPrefix(model, "gpt-3.5") { if strings.HasPrefix(model, "gpt-3.5") {
tokenEncoderMap[model] = gpt35TokenEncoder tokenEncoderMap[model] = gpt35TokenEncoder
} else if strings.HasPrefix(model, "gpt-4") { } else if strings.HasPrefix(model, "gpt-4") {

View File

@@ -4,10 +4,10 @@ import (
"errors" "errors"
"fmt" "fmt"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/relay/channel" "github.com/songquanpeng/one-api/relay/adaptor"
"github.com/songquanpeng/one-api/relay/channel/openai" "github.com/songquanpeng/one-api/relay/adaptor/openai"
"github.com/songquanpeng/one-api/relay/meta"
"github.com/songquanpeng/one-api/relay/model" "github.com/songquanpeng/one-api/relay/model"
"github.com/songquanpeng/one-api/relay/util"
"io" "io"
"net/http" "net/http"
) )
@@ -15,16 +15,16 @@ import (
type Adaptor struct { type Adaptor struct {
} }
func (a *Adaptor) Init(meta *util.RelayMeta) { func (a *Adaptor) Init(meta *meta.Meta) {
} }
func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) { func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) {
return fmt.Sprintf("%s/v1beta2/models/chat-bison-001:generateMessage", meta.BaseURL), nil return fmt.Sprintf("%s/v1beta2/models/chat-bison-001:generateMessage", meta.BaseURL), nil
} }
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *util.RelayMeta) error { func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error {
channel.SetupCommonRequestHeader(c, req, meta) adaptor.SetupCommonRequestHeader(c, req, meta)
req.Header.Set("x-goog-api-key", meta.APIKey) req.Header.Set("x-goog-api-key", meta.APIKey)
return nil return nil
} }
@@ -43,11 +43,11 @@ func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error)
return request, nil return request, nil
} }
func (a *Adaptor) DoRequest(c *gin.Context, meta *util.RelayMeta, requestBody io.Reader) (*http.Response, error) { func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) {
return channel.DoRequestHelper(a, c, meta, requestBody) return adaptor.DoRequestHelper(a, c, meta, requestBody)
} }
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.RelayMeta) (usage *model.Usage, err *model.ErrorWithStatusCode) { func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
if meta.IsStream { if meta.IsStream {
var responseText string var responseText string
err, responseText = StreamHandler(c, resp) err, responseText = StreamHandler(c, resp)

View File

@@ -7,7 +7,8 @@ import (
"github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/helper" "github.com/songquanpeng/one-api/common/helper"
"github.com/songquanpeng/one-api/common/logger" "github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/relay/channel/openai" "github.com/songquanpeng/one-api/common/random"
"github.com/songquanpeng/one-api/relay/adaptor/openai"
"github.com/songquanpeng/one-api/relay/constant" "github.com/songquanpeng/one-api/relay/constant"
"github.com/songquanpeng/one-api/relay/model" "github.com/songquanpeng/one-api/relay/model"
"io" "io"
@@ -74,7 +75,7 @@ func streamResponsePaLM2OpenAI(palmResponse *ChatResponse) *openai.ChatCompletio
func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, string) { func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, string) {
responseText := "" responseText := ""
responseId := fmt.Sprintf("chatcmpl-%s", helper.GetUUID()) responseId := fmt.Sprintf("chatcmpl-%s", random.GetUUID())
createdTime := helper.GetTimestamp() createdTime := helper.GetTimestamp()
dataChan := make(chan string) dataChan := make(chan string)
stopChan := make(chan bool) stopChan := make(chan bool)

View File

@@ -4,10 +4,10 @@ import (
"errors" "errors"
"fmt" "fmt"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/relay/channel" "github.com/songquanpeng/one-api/relay/adaptor"
"github.com/songquanpeng/one-api/relay/channel/openai" "github.com/songquanpeng/one-api/relay/adaptor/openai"
"github.com/songquanpeng/one-api/relay/meta"
"github.com/songquanpeng/one-api/relay/model" "github.com/songquanpeng/one-api/relay/model"
"github.com/songquanpeng/one-api/relay/util"
"io" "io"
"net/http" "net/http"
"strings" "strings"
@@ -19,16 +19,16 @@ type Adaptor struct {
Sign string Sign string
} }
func (a *Adaptor) Init(meta *util.RelayMeta) { func (a *Adaptor) Init(meta *meta.Meta) {
} }
func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) { func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) {
return fmt.Sprintf("%s/hyllm/v1/chat/completions", meta.BaseURL), nil return fmt.Sprintf("%s/hyllm/v1/chat/completions", meta.BaseURL), nil
} }
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *util.RelayMeta) error { func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error {
channel.SetupCommonRequestHeader(c, req, meta) adaptor.SetupCommonRequestHeader(c, req, meta)
req.Header.Set("Authorization", a.Sign) req.Header.Set("Authorization", a.Sign)
req.Header.Set("X-TC-Action", meta.ActualModelName) req.Header.Set("X-TC-Action", meta.ActualModelName)
return nil return nil
@@ -59,11 +59,11 @@ func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error)
return request, nil return request, nil
} }
func (a *Adaptor) DoRequest(c *gin.Context, meta *util.RelayMeta, requestBody io.Reader) (*http.Response, error) { func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) {
return channel.DoRequestHelper(a, c, meta, requestBody) return adaptor.DoRequestHelper(a, c, meta, requestBody)
} }
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.RelayMeta) (usage *model.Usage, err *model.ErrorWithStatusCode) { func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
if meta.IsStream { if meta.IsStream {
var responseText string var responseText string
err, responseText = StreamHandler(c, resp) err, responseText = StreamHandler(c, resp)

View File

@@ -13,7 +13,8 @@ import (
"github.com/songquanpeng/one-api/common/conv" "github.com/songquanpeng/one-api/common/conv"
"github.com/songquanpeng/one-api/common/helper" "github.com/songquanpeng/one-api/common/helper"
"github.com/songquanpeng/one-api/common/logger" "github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/relay/channel/openai" "github.com/songquanpeng/one-api/common/random"
"github.com/songquanpeng/one-api/relay/adaptor/openai"
"github.com/songquanpeng/one-api/relay/constant" "github.com/songquanpeng/one-api/relay/constant"
"github.com/songquanpeng/one-api/relay/model" "github.com/songquanpeng/one-api/relay/model"
"io" "io"
@@ -41,7 +42,7 @@ func ConvertRequest(request model.GeneralOpenAIRequest) *ChatRequest {
return &ChatRequest{ return &ChatRequest{
Timestamp: helper.GetTimestamp(), Timestamp: helper.GetTimestamp(),
Expired: helper.GetTimestamp() + 24*60*60, Expired: helper.GetTimestamp() + 24*60*60,
QueryID: helper.GetUUID(), QueryID: random.GetUUID(),
Temperature: request.Temperature, Temperature: request.Temperature,
TopP: request.TopP, TopP: request.TopP,
Stream: stream, Stream: stream,
@@ -71,7 +72,7 @@ func responseTencent2OpenAI(response *ChatResponse) *openai.TextResponse {
func streamResponseTencent2OpenAI(TencentResponse *ChatResponse) *openai.ChatCompletionsStreamResponse { func streamResponseTencent2OpenAI(TencentResponse *ChatResponse) *openai.ChatCompletionsStreamResponse {
response := openai.ChatCompletionsStreamResponse{ response := openai.ChatCompletionsStreamResponse{
Id: fmt.Sprintf("chatcmpl-%s", helper.GetUUID()), Id: fmt.Sprintf("chatcmpl-%s", random.GetUUID()),
Object: "chat.completion.chunk", Object: "chat.completion.chunk",
Created: helper.GetTimestamp(), Created: helper.GetTimestamp(),
Model: "tencent-hunyuan", Model: "tencent-hunyuan",

View File

@@ -3,10 +3,10 @@ package xunfei
import ( import (
"errors" "errors"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/relay/channel" "github.com/songquanpeng/one-api/relay/adaptor"
"github.com/songquanpeng/one-api/relay/channel/openai" "github.com/songquanpeng/one-api/relay/adaptor/openai"
"github.com/songquanpeng/one-api/relay/meta"
"github.com/songquanpeng/one-api/relay/model" "github.com/songquanpeng/one-api/relay/model"
"github.com/songquanpeng/one-api/relay/util"
"io" "io"
"net/http" "net/http"
"strings" "strings"
@@ -16,16 +16,16 @@ type Adaptor struct {
request *model.GeneralOpenAIRequest request *model.GeneralOpenAIRequest
} }
func (a *Adaptor) Init(meta *util.RelayMeta) { func (a *Adaptor) Init(meta *meta.Meta) {
} }
func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) { func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) {
return "", nil return "", nil
} }
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *util.RelayMeta) error { func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error {
channel.SetupCommonRequestHeader(c, req, meta) adaptor.SetupCommonRequestHeader(c, req, meta)
// check DoResponse for auth part // check DoResponse for auth part
return nil return nil
} }
@@ -45,14 +45,14 @@ func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error)
return request, nil return request, nil
} }
func (a *Adaptor) DoRequest(c *gin.Context, meta *util.RelayMeta, requestBody io.Reader) (*http.Response, error) { func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) {
// xunfei's request is not http request, so we don't need to do anything here // xunfei's request is not http request, so we don't need to do anything here
dummyResp := &http.Response{} dummyResp := &http.Response{}
dummyResp.StatusCode = http.StatusOK dummyResp.StatusCode = http.StatusOK
return dummyResp, nil return dummyResp, nil
} }
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.RelayMeta) (usage *model.Usage, err *model.ErrorWithStatusCode) { func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
splits := strings.Split(meta.APIKey, "|") splits := strings.Split(meta.APIKey, "|")
if len(splits) != 3 { if len(splits) != 3 {
return nil, openai.ErrorWrapper(errors.New("invalid auth"), "invalid_auth", http.StatusBadRequest) return nil, openai.ErrorWrapper(errors.New("invalid auth"), "invalid_auth", http.StatusBadRequest)

View File

@@ -9,9 +9,11 @@ import (
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/gorilla/websocket" "github.com/gorilla/websocket"
"github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/helper" "github.com/songquanpeng/one-api/common/helper"
"github.com/songquanpeng/one-api/common/logger" "github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/relay/channel/openai" "github.com/songquanpeng/one-api/common/random"
"github.com/songquanpeng/one-api/relay/adaptor/openai"
"github.com/songquanpeng/one-api/relay/constant" "github.com/songquanpeng/one-api/relay/constant"
"github.com/songquanpeng/one-api/relay/model" "github.com/songquanpeng/one-api/relay/model"
"io" "io"
@@ -62,7 +64,7 @@ func getToolCalls(response *ChatResponse) []model.Tool {
return toolCalls return toolCalls
} }
toolCall := model.Tool{ toolCall := model.Tool{
Id: fmt.Sprintf("call_%s", helper.GetUUID()), Id: fmt.Sprintf("call_%s", random.GetUUID()),
Type: "function", Type: "function",
Function: *item.FunctionCall, Function: *item.FunctionCall,
} }
@@ -88,7 +90,7 @@ func responseXunfei2OpenAI(response *ChatResponse) *openai.TextResponse {
FinishReason: constant.StopFinishReason, FinishReason: constant.StopFinishReason,
} }
fullTextResponse := openai.TextResponse{ fullTextResponse := openai.TextResponse{
Id: fmt.Sprintf("chatcmpl-%s", helper.GetUUID()), Id: fmt.Sprintf("chatcmpl-%s", random.GetUUID()),
Object: "chat.completion", Object: "chat.completion",
Created: helper.GetTimestamp(), Created: helper.GetTimestamp(),
Choices: []openai.TextResponseChoice{choice}, Choices: []openai.TextResponseChoice{choice},
@@ -112,7 +114,7 @@ func streamResponseXunfei2OpenAI(xunfeiResponse *ChatResponse) *openai.ChatCompl
choice.FinishReason = &constant.StopFinishReason choice.FinishReason = &constant.StopFinishReason
} }
response := openai.ChatCompletionsStreamResponse{ response := openai.ChatCompletionsStreamResponse{
Id: fmt.Sprintf("chatcmpl-%s", helper.GetUUID()), Id: fmt.Sprintf("chatcmpl-%s", random.GetUUID()),
Object: "chat.completion.chunk", Object: "chat.completion.chunk",
Created: helper.GetTimestamp(), Created: helper.GetTimestamp(),
Model: "SparkDesk", Model: "SparkDesk",
@@ -278,7 +280,7 @@ func getAPIVersion(c *gin.Context, modelName string) string {
return apiVersion return apiVersion
} }
apiVersion = c.GetString(common.ConfigKeyAPIVersion) apiVersion = c.GetString(config.KeyAPIVersion)
if apiVersion != "" { if apiVersion != "" {
return apiVersion return apiVersion
} }

View File

@@ -4,11 +4,11 @@ import (
"errors" "errors"
"fmt" "fmt"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/relay/channel" "github.com/songquanpeng/one-api/relay/adaptor"
"github.com/songquanpeng/one-api/relay/channel/openai" "github.com/songquanpeng/one-api/relay/adaptor/openai"
"github.com/songquanpeng/one-api/relay/constant" "github.com/songquanpeng/one-api/relay/meta"
"github.com/songquanpeng/one-api/relay/model" "github.com/songquanpeng/one-api/relay/model"
"github.com/songquanpeng/one-api/relay/util" "github.com/songquanpeng/one-api/relay/relaymode"
"io" "io"
"math" "math"
"net/http" "net/http"
@@ -19,7 +19,7 @@ type Adaptor struct {
APIVersion string APIVersion string
} }
func (a *Adaptor) Init(meta *util.RelayMeta) { func (a *Adaptor) Init(meta *meta.Meta) {
} }
@@ -31,14 +31,17 @@ func (a *Adaptor) SetVersionByModeName(modelName string) {
} }
} }
func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) { func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) {
switch meta.Mode {
case relaymode.ImagesGenerations:
return fmt.Sprintf("%s/api/paas/v4/images/generations", meta.BaseURL), nil
case relaymode.Embeddings:
return fmt.Sprintf("%s/api/paas/v4/embeddings", meta.BaseURL), nil
}
a.SetVersionByModeName(meta.ActualModelName) a.SetVersionByModeName(meta.ActualModelName)
if a.APIVersion == "v4" { if a.APIVersion == "v4" {
return fmt.Sprintf("%s/api/paas/v4/chat/completions", meta.BaseURL), nil return fmt.Sprintf("%s/api/paas/v4/chat/completions", meta.BaseURL), nil
} }
if meta.Mode == constant.RelayModeEmbeddings {
return fmt.Sprintf("%s/api/paas/v4/embeddings", meta.BaseURL), nil
}
method := "invoke" method := "invoke"
if meta.IsStream { if meta.IsStream {
method = "sse-invoke" method = "sse-invoke"
@@ -46,8 +49,8 @@ func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) {
return fmt.Sprintf("%s/api/paas/v3/model-api/%s/%s", meta.BaseURL, meta.ActualModelName, method), nil return fmt.Sprintf("%s/api/paas/v3/model-api/%s/%s", meta.BaseURL, meta.ActualModelName, method), nil
} }
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *util.RelayMeta) error { func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error {
channel.SetupCommonRequestHeader(c, req, meta) adaptor.SetupCommonRequestHeader(c, req, meta)
token := GetToken(meta.APIKey) token := GetToken(meta.APIKey)
req.Header.Set("Authorization", token) req.Header.Set("Authorization", token)
return nil return nil
@@ -58,7 +61,7 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G
return nil, errors.New("request is nil") return nil, errors.New("request is nil")
} }
switch relayMode { switch relayMode {
case constant.RelayModeEmbeddings: case relaymode.Embeddings:
baiduEmbeddingRequest := ConvertEmbeddingRequest(*request) baiduEmbeddingRequest := ConvertEmbeddingRequest(*request)
return baiduEmbeddingRequest, nil return baiduEmbeddingRequest, nil
default: default:
@@ -81,14 +84,19 @@ func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error)
if request == nil { if request == nil {
return nil, errors.New("request is nil") return nil, errors.New("request is nil")
} }
return request, nil newRequest := ImageRequest{
Model: request.Model,
Prompt: request.Prompt,
UserId: request.User,
}
return newRequest, nil
} }
func (a *Adaptor) DoRequest(c *gin.Context, meta *util.RelayMeta, requestBody io.Reader) (*http.Response, error) { func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) {
return channel.DoRequestHelper(a, c, meta, requestBody) return adaptor.DoRequestHelper(a, c, meta, requestBody)
} }
func (a *Adaptor) DoResponseV4(c *gin.Context, resp *http.Response, meta *util.RelayMeta) (usage *model.Usage, err *model.ErrorWithStatusCode) { func (a *Adaptor) DoResponseV4(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
if meta.IsStream { if meta.IsStream {
err, _, usage = openai.StreamHandler(c, resp, meta.Mode) err, _, usage = openai.StreamHandler(c, resp, meta.Mode)
} else { } else {
@@ -97,15 +105,22 @@ func (a *Adaptor) DoResponseV4(c *gin.Context, resp *http.Response, meta *util.R
return return
} }
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.RelayMeta) (usage *model.Usage, err *model.ErrorWithStatusCode) { func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
switch meta.Mode {
case relaymode.Embeddings:
err, usage = EmbeddingsHandler(c, resp)
return
case relaymode.ImagesGenerations:
err, usage = openai.ImageHandler(c, resp)
return
}
if a.APIVersion == "v4" { if a.APIVersion == "v4" {
return a.DoResponseV4(c, resp, meta) return a.DoResponseV4(c, resp, meta)
} }
if meta.IsStream { if meta.IsStream {
err, usage = StreamHandler(c, resp) err, usage = StreamHandler(c, resp)
} else { } else {
if meta.Mode == constant.RelayModeEmbeddings { if meta.Mode == relaymode.Embeddings {
err, usage = EmbeddingsHandler(c, resp) err, usage = EmbeddingsHandler(c, resp)
} else { } else {
err, usage = Handler(c, resp) err, usage = Handler(c, resp)

View File

@@ -3,4 +3,5 @@ package zhipu
var ModelList = []string{ var ModelList = []string{
"chatglm_turbo", "chatglm_pro", "chatglm_std", "chatglm_lite", "chatglm_turbo", "chatglm_pro", "chatglm_std", "chatglm_lite",
"glm-4", "glm-4v", "glm-3-turbo", "embedding-2", "glm-4", "glm-4v", "glm-3-turbo", "embedding-2",
"cogview-3",
} }

View File

@@ -8,7 +8,7 @@ import (
"github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/helper" "github.com/songquanpeng/one-api/common/helper"
"github.com/songquanpeng/one-api/common/logger" "github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/relay/channel/openai" "github.com/songquanpeng/one-api/relay/adaptor/openai"
"github.com/songquanpeng/one-api/relay/constant" "github.com/songquanpeng/one-api/relay/constant"
"github.com/songquanpeng/one-api/relay/model" "github.com/songquanpeng/one-api/relay/model"
"io" "io"
@@ -256,7 +256,7 @@ func Handler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *
} }
func EmbeddingsHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) { func EmbeddingsHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
var zhipuResponse EmbeddingRespone var zhipuResponse EmbeddingResponse
responseBody, err := io.ReadAll(resp.Body) responseBody, err := io.ReadAll(resp.Body)
if err != nil { if err != nil {
return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
@@ -280,7 +280,7 @@ func EmbeddingsHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithSta
return nil, &fullTextResponse.Usage return nil, &fullTextResponse.Usage
} }
func embeddingResponseZhipu2OpenAI(response *EmbeddingRespone) *openai.EmbeddingResponse { func embeddingResponseZhipu2OpenAI(response *EmbeddingResponse) *openai.EmbeddingResponse {
openAIEmbeddingResponse := openai.EmbeddingResponse{ openAIEmbeddingResponse := openai.EmbeddingResponse{
Object: "list", Object: "list",
Data: make([]openai.EmbeddingResponseItem, 0, len(response.Embeddings)), Data: make([]openai.EmbeddingResponseItem, 0, len(response.Embeddings)),

View File

@@ -50,7 +50,7 @@ type EmbeddingRequest struct {
Input string `json:"input"` Input string `json:"input"`
} }
type EmbeddingRespone struct { type EmbeddingResponse struct {
Model string `json:"model"` Model string `json:"model"`
Object string `json:"object"` Object string `json:"object"`
Embeddings []EmbeddingData `json:"data"` Embeddings []EmbeddingData `json:"data"`
@@ -62,3 +62,9 @@ type EmbeddingData struct {
Object string `json:"object"` Object string `json:"object"`
Embedding []float64 `json:"embedding"` Embedding []float64 `json:"embedding"`
} }
type ImageRequest struct {
Model string `json:"model"`
Prompt string `json:"prompt"`
UserId string `json:"user_id,omitempty"`
}

17
relay/apitype/define.go Normal file
View File

@@ -0,0 +1,17 @@
package apitype
const (
OpenAI = iota
Anthropic
PaLM
Baidu
Zhipu
Ali
Xunfei
AIProxyLibrary
Tencent
Gemini
Ollama
Dummy // this one is only for count, do not add any channel after this
)

42
relay/billing/billing.go Normal file
View File

@@ -0,0 +1,42 @@
package billing
import (
"context"
"fmt"
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/model"
)
func ReturnPreConsumedQuota(ctx context.Context, preConsumedQuota int64, tokenId int) {
if preConsumedQuota != 0 {
go func(ctx context.Context) {
// return pre-consumed quota
err := model.PostConsumeTokenQuota(tokenId, -preConsumedQuota)
if err != nil {
logger.Error(ctx, "error return pre-consumed quota: "+err.Error())
}
}(ctx)
}
}
func PostConsumeQuota(ctx context.Context, tokenId int, quotaDelta int64, totalQuota int64, userId int, channelId int, modelRatio float64, groupRatio float64, modelName string, tokenName string) {
// quotaDelta is remaining quota to be consumed
err := model.PostConsumeTokenQuota(tokenId, quotaDelta)
if err != nil {
logger.SysError("error consuming token remain quota: " + err.Error())
}
err = model.CacheUpdateUserQuota(ctx, userId)
if err != nil {
logger.SysError("error update user quota cache: " + err.Error())
}
// totalQuota is total quota consumed
if totalQuota != 0 {
logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
model.RecordConsumeLog(ctx, userId, channelId, int(totalQuota), 0, modelName, tokenName, totalQuota, logContent)
model.UpdateUserUsedQuotaAndRequestCount(userId, totalQuota)
model.UpdateChannelUsedQuota(channelId, totalQuota)
}
if totalQuota <= 0 {
logger.Error(ctx, fmt.Sprintf("totalQuota consumed is %d, something is wrong", totalQuota))
}
}

View File

@@ -1,4 +1,4 @@
package common package ratio
import ( import (
"encoding/json" "encoding/json"

View File

@@ -0,0 +1,51 @@
package ratio
var ImageSizeRatios = map[string]map[string]float64{
"dall-e-2": {
"256x256": 1,
"512x512": 1.125,
"1024x1024": 1.25,
},
"dall-e-3": {
"1024x1024": 1,
"1024x1792": 2,
"1792x1024": 2,
},
"ali-stable-diffusion-xl": {
"512x1024": 1,
"1024x768": 1,
"1024x1024": 1,
"576x1024": 1,
"1024x576": 1,
},
"ali-stable-diffusion-v1.5": {
"512x1024": 1,
"1024x768": 1,
"1024x1024": 1,
"576x1024": 1,
"1024x576": 1,
},
"wanx-v1": {
"1024x1024": 1,
"720x1280": 1,
"1280x720": 1,
},
}
var ImageGenerationAmounts = map[string][2]int{
"dall-e-2": {1, 10},
"dall-e-3": {1, 1}, // OpenAI allows n=1 currently.
"ali-stable-diffusion-xl": {1, 4}, // Ali
"ali-stable-diffusion-v1.5": {1, 4}, // Ali
"wanx-v1": {1, 4}, // Ali
"cogview-3": {1, 1},
}
var ImagePromptLengthLimitations = map[string]int{
"dall-e-2": 1000,
"dall-e-3": 4000,
"ali-stable-diffusion-xl": 4000,
"ali-stable-diffusion-v1.5": 4000,
"wanx-v1": 4000,
"cogview-3": 833,
}

View File

@@ -1,4 +1,4 @@
package common package ratio
import ( import (
"encoding/json" "encoding/json"
@@ -62,8 +62,8 @@ var ModelRatio = map[string]float64{
"text-search-ada-doc-001": 10, "text-search-ada-doc-001": 10,
"text-moderation-stable": 0.1, "text-moderation-stable": 0.1,
"text-moderation-latest": 0.1, "text-moderation-latest": 0.1,
"dall-e-2": 8, // $0.016 - $0.020 / image "dall-e-2": 0.02 * USD, // $0.016 - $0.020 / image
"dall-e-3": 20, // $0.040 - $0.120 / image "dall-e-3": 0.04 * USD, // $0.040 - $0.120 / image
// https://www.anthropic.com/api#pricing // https://www.anthropic.com/api#pricing
"claude-instant-1.2": 0.8 / 1000 * USD, "claude-instant-1.2": 0.8 / 1000 * USD,
"claude-2.0": 8.0 / 1000 * USD, "claude-2.0": 8.0 / 1000 * USD,
@@ -104,6 +104,7 @@ var ModelRatio = map[string]float64{
"chatglm_pro": 0.7143, // ¥0.01 / 1k tokens "chatglm_pro": 0.7143, // ¥0.01 / 1k tokens
"chatglm_std": 0.3572, // ¥0.005 / 1k tokens "chatglm_std": 0.3572, // ¥0.005 / 1k tokens
"chatglm_lite": 0.1429, // ¥0.002 / 1k tokens "chatglm_lite": 0.1429, // ¥0.002 / 1k tokens
"cogview-3": 0.25 * RMB,
// https://help.aliyun.com/zh/dashscope/developer-reference/tongyi-thousand-questions-metering-and-billing // https://help.aliyun.com/zh/dashscope/developer-reference/tongyi-thousand-questions-metering-and-billing
"qwen-turbo": 0.5715, // ¥0.008 / 1k tokens "qwen-turbo": 0.5715, // ¥0.008 / 1k tokens
"qwen-plus": 1.4286, // ¥0.02 / 1k tokens "qwen-plus": 1.4286, // ¥0.02 / 1k tokens
@@ -153,6 +154,10 @@ var ModelRatio = map[string]float64{
"yi-34b-chat-0205": 2.5 / 1000 * RMB, "yi-34b-chat-0205": 2.5 / 1000 * RMB,
"yi-34b-chat-200k": 12.0 / 1000 * RMB, "yi-34b-chat-200k": 12.0 / 1000 * RMB,
"yi-vl-plus": 6.0 / 1000 * RMB, "yi-vl-plus": 6.0 / 1000 * RMB,
// stepfun todo
"step-1v-32k": 0.024 * RMB,
"step-1-32k": 0.024 * RMB,
"step-1-200k": 0.15 * RMB,
} }
var CompletionRatio = map[string]float64{} var CompletionRatio = map[string]float64{}

View File

@@ -1,21 +0,0 @@
package channel
import (
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/relay/model"
"github.com/songquanpeng/one-api/relay/util"
"io"
"net/http"
)
type Adaptor interface {
Init(meta *util.RelayMeta)
GetRequestURL(meta *util.RelayMeta) (string, error)
SetupRequestHeader(c *gin.Context, req *http.Request, meta *util.RelayMeta) error
ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error)
ConvertImageRequest(request *model.ImageRequest) (any, error)
DoRequest(c *gin.Context, meta *util.RelayMeta, requestBody io.Reader) (*http.Response, error)
DoResponse(c *gin.Context, resp *http.Response, meta *util.RelayMeta) (usage *model.Usage, err *model.ErrorWithStatusCode)
GetModelList() []string
GetChannelName() string
}

Some files were not shown because too many files have changed in this diff Show More