Compare commits

...

8 Commits

Author SHA1 Message Date
JustSong 6ebc99460e fix: add user to blacklist when it's banned or deleted, and make deletion soft (close #473, close #791) 2024-03-10 15:56:19 +08:00
JustSong 27ad8bfb98 feat: able to search channel type now 2024-03-10 15:00:33 +08:00
JustSong 8388aa537f chore: able to search channel now 2024-03-10 14:59:57 +08:00
JustSong 2346bf70af fix: check response type when expect stream response 2024-03-10 14:59:40 +08:00
JustSong f05b403ca5 feat: use real system prompt now (close #1079) 2024-03-10 14:32:30 +08:00
JustSong b33616df44 feat: support groq now (close #1087) 2024-03-10 14:09:44 +08:00
JustSong cf16f44970 feat: load channel models from server 2024-03-09 02:28:23 +08:00
JustSong bf2e26a48f feat: support claude-3 (close #1080, close #1094) 2024-03-09 01:12:47 +08:00
28 changed files with 491 additions and 327 deletions

View File

@@ -78,6 +78,7 @@ _✨ 通过标准的 OpenAI API 格式访问所有的大模型,开箱即用
+ [x] [百川大模型](https://platform.baichuan-ai.com)
+ [ ] [字节云雀大模型](https://www.volcengine.com/product/ark) (WIP)
+ [x] [MINIMAX](https://api.minimax.chat/)
+ [x] [Groq](https://wow.groq.com/)
2. 支持配置镜像以及众多[第三方代理服务](https://iamazing.cn/page/openai-api-third-party-services)。
3. 支持通过**负载均衡**的方式访问多个渠道。
4. 支持 **stream 模式**,可以通过流式传输实现打字机效果。

common/blacklist/main.go (new file, 29 lines added)
View File

@@ -0,0 +1,29 @@
package blacklist
import (
"fmt"
"sync"
)
var blackList sync.Map
func init() {
blackList = sync.Map{}
}
func userId2Key(id int) string {
return fmt.Sprintf("userid_%d", id)
}
func BanUser(id int) {
blackList.Store(userId2Key(id), true)
}
func UnbanUser(id int) {
blackList.Delete(userId2Key(id))
}
func IsUserBanned(id int) bool {
_, ok := blackList.Load(userId2Key(id))
return ok
}
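
The new blacklist package keeps banned user IDs in an in-memory sync.Map, so the auth middleware and token auth can reject requests without a database lookup. A minimal usage sketch of the exported helpers (the main function below is illustrative only):

package main

import (
	"fmt"

	"github.com/songquanpeng/one-api/common/blacklist"
)

func main() {
	blacklist.BanUser(42)
	fmt.Println(blacklist.IsUserBanned(42)) // true
	blacklist.UnbanUser(42)
	fmt.Println(blacklist.IsUserBanned(42)) // false
}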

View File

@@ -52,6 +52,7 @@ var EmailDomainWhitelist = []string{
}
var DebugEnabled = os.Getenv("DEBUG") == "true"
var DebugSQLEnabled = os.Getenv("DEBUG_SQL") == "true"
var MemoryCacheEnabled = os.Getenv("MEMORY_CACHE_ENABLED") == "true"
var LogConsumeEnabled = true

View File

@@ -15,6 +15,7 @@ const (
const (
UserStatusEnabled = 1 // don't use 0, 0 is the default value!
UserStatusDisabled = 2 // also don't use 0
UserStatusDeleted = 3
)
const (
@@ -38,35 +39,38 @@ const (
)
const (
ChannelTypeUnknown = 0
ChannelTypeOpenAI = 1
ChannelTypeAPI2D = 2
ChannelTypeAzure = 3
ChannelTypeCloseAI = 4
ChannelTypeOpenAISB = 5
ChannelTypeOpenAIMax = 6
ChannelTypeOhMyGPT = 7
ChannelTypeCustom = 8
ChannelTypeAILS = 9
ChannelTypeAIProxy = 10
ChannelTypePaLM = 11
ChannelTypeAPI2GPT = 12
ChannelTypeAIGC2D = 13
ChannelTypeAnthropic = 14
ChannelTypeBaidu = 15
ChannelTypeZhipu = 16
ChannelTypeAli = 17
ChannelTypeXunfei = 18
ChannelType360 = 19
ChannelTypeOpenRouter = 20
ChannelTypeAIProxyLibrary = 21
ChannelTypeFastGPT = 22
ChannelTypeTencent = 23
ChannelTypeGemini = 24
ChannelTypeMoonshot = 25
ChannelTypeBaichuan = 26
ChannelTypeMinimax = 27
ChannelTypeMistral = 28
ChannelTypeUnknown = iota
ChannelTypeOpenAI
ChannelTypeAPI2D
ChannelTypeAzure
ChannelTypeCloseAI
ChannelTypeOpenAISB
ChannelTypeOpenAIMax
ChannelTypeOhMyGPT
ChannelTypeCustom
ChannelTypeAILS
ChannelTypeAIProxy
ChannelTypePaLM
ChannelTypeAPI2GPT
ChannelTypeAIGC2D
ChannelTypeAnthropic
ChannelTypeBaidu
ChannelTypeZhipu
ChannelTypeAli
ChannelTypeXunfei
ChannelType360
ChannelTypeOpenRouter
ChannelTypeAIProxyLibrary
ChannelTypeFastGPT
ChannelTypeTencent
ChannelTypeGemini
ChannelTypeMoonshot
ChannelTypeBaichuan
ChannelTypeMinimax
ChannelTypeMistral
ChannelTypeGroq
ChannelTypeDummy
)
var ChannelBaseURLs = []string{
@@ -99,6 +103,7 @@ var ChannelBaseURLs = []string{
"https://api.baichuan-ai.com", // 26
"https://api.minimax.chat", // 27
"https://api.mistral.ai", // 28
"https://api.groq.com/openai", // 29
}
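
Rewriting the channel type constants with iota keeps the existing numeric values as long as new entries are appended before ChannelTypeDummy; the value also doubles as an index into ChannelBaseURLs, so the two lists must grow together. A small illustrative sketch of that lookup (not the production resolution logic):

package main

import (
	"fmt"

	"github.com/songquanpeng/one-api/common"
)

func main() {
	// ChannelTypeGroq is 29 under the iota scheme, matching the "// 29" entry above.
	fmt.Println(common.ChannelBaseURLs[common.ChannelTypeGroq]) // https://api.groq.com/openai
}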
const (

View File

@@ -63,12 +63,15 @@ var ModelRatio = map[string]float64{
"text-search-ada-doc-001": 10,
"text-moderation-stable": 0.1,
"text-moderation-latest": 0.1,
"dall-e-2": 8, // $0.016 - $0.020 / image
"dall-e-3": 20, // $0.040 - $0.120 / image
"claude-instant-1": 0.815, // $1.63 / 1M tokens
"claude-2": 5.51, // $11.02 / 1M tokens
"claude-2.0": 5.51, // $11.02 / 1M tokens
"claude-2.1": 5.51, // $11.02 / 1M tokens
"dall-e-2": 8, // $0.016 - $0.020 / image
"dall-e-3": 20, // $0.040 - $0.120 / image
// https://www.anthropic.com/api#pricing
"claude-instant-1.2": 0.8 / 1000 * USD,
"claude-2.0": 8.0 / 1000 * USD,
"claude-2.1": 8.0 / 1000 * USD,
"claude-3-haiku-20240229": 0.25 / 1000 * USD,
"claude-3-sonnet-20240229": 3.0 / 1000 * USD,
"claude-3-opus-20240229": 15.0 / 1000 * USD,
// https://cloud.baidu.com/doc/WENXINWORKSHOP/s/hlrk4akp7
"ERNIE-Bot": 0.8572, // ¥0.012 / 1k tokens
"ERNIE-Bot-turbo": 0.5715, // ¥0.008 / 1k tokens
@@ -122,6 +125,11 @@ var ModelRatio = map[string]float64{
"mistral-medium-latest": 2.7 / 1000 * USD,
"mistral-large-latest": 8.0 / 1000 * USD,
"mistral-embed": 0.1 / 1000 * USD,
// https://wow.groq.com/
"llama2-70b-4096": 0.7 / 1000 * USD,
"llama2-7b-2048": 0.1 / 1000 * USD,
"mixtral-8x7b-32768": 0.27 / 1000 * USD,
"gemma-7b-it": 0.1 / 1000 * USD,
}
var CompletionRatio = map[string]float64{}
@@ -206,7 +214,7 @@ func GetCompletionRatio(name string) float64 {
return 2
}
}
return 1.333333
return 4.0 / 3.0
}
if strings.HasPrefix(name, "gpt-4") {
if strings.HasSuffix(name, "preview") {
@@ -214,14 +222,18 @@ func GetCompletionRatio(name string) float64 {
}
return 2
}
if strings.HasPrefix(name, "claude-instant-1") {
return 3.38
if strings.HasPrefix(name, "claude-3") {
return 5
}
if strings.HasPrefix(name, "claude-2") {
return 2.965517
if strings.HasPrefix(name, "claude-") {
return 3
}
if strings.HasPrefix(name, "mistral-") {
return 3
}
switch name {
case "llama2-70b-4096":
return 0.8 / 0.7
}
return 1
}
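
The new Claude ratios line up with Anthropic's published per-token pricing: the model ratio encodes the input price and GetCompletionRatio the output/input multiplier, so claude-3-opus at $15 per 1M input tokens with a completion ratio of 5 implies $75 per 1M output tokens, and claude-2.x at $8 with a ratio of 3 implies $24. A tiny sanity-check sketch (standalone arithmetic, not one-api's actual quota code):

package main

import "fmt"

func main() {
	// Input price in USD per 1M tokens, and the completion (output/input) ratio from GetCompletionRatio.
	inputPerM := map[string]float64{"claude-3-opus-20240229": 15, "claude-2.1": 8}
	completionRatio := map[string]float64{"claude-3-opus-20240229": 5, "claude-2.1": 3}
	for model, in := range inputPerM {
		fmt.Printf("%s: $%.0f in / $%.0f out per 1M tokens\n", model, in, in*completionRatio[model])
	}
}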

View File

@@ -3,14 +3,13 @@ package controller
import (
"fmt"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/relay/channel/ai360"
"github.com/songquanpeng/one-api/relay/channel/baichuan"
"github.com/songquanpeng/one-api/relay/channel/minimax"
"github.com/songquanpeng/one-api/relay/channel/mistral"
"github.com/songquanpeng/one-api/relay/channel/moonshot"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/relay/channel/openai"
"github.com/songquanpeng/one-api/relay/constant"
"github.com/songquanpeng/one-api/relay/helper"
relaymodel "github.com/songquanpeng/one-api/relay/model"
"github.com/songquanpeng/one-api/relay/util"
"net/http"
)
// https://platform.openai.com/docs/api-reference/models/list
@@ -42,6 +41,7 @@ type OpenAIModels struct {
var openAIModels []OpenAIModels
var openAIModelsMap map[string]OpenAIModels
var channelId2Models map[int][]string
func init() {
var permission []OpenAIModelPermission
@@ -79,65 +79,44 @@ func init() {
})
}
}
for _, modelName := range ai360.ModelList {
openAIModels = append(openAIModels, OpenAIModels{
Id: modelName,
Object: "model",
Created: 1626777600,
OwnedBy: "360",
Permission: permission,
Root: modelName,
Parent: nil,
})
}
for _, modelName := range moonshot.ModelList {
openAIModels = append(openAIModels, OpenAIModels{
Id: modelName,
Object: "model",
Created: 1626777600,
OwnedBy: "moonshot",
Permission: permission,
Root: modelName,
Parent: nil,
})
}
for _, modelName := range baichuan.ModelList {
openAIModels = append(openAIModels, OpenAIModels{
Id: modelName,
Object: "model",
Created: 1626777600,
OwnedBy: "baichuan",
Permission: permission,
Root: modelName,
Parent: nil,
})
}
for _, modelName := range minimax.ModelList {
openAIModels = append(openAIModels, OpenAIModels{
Id: modelName,
Object: "model",
Created: 1626777600,
OwnedBy: "minimax",
Permission: permission,
Root: modelName,
Parent: nil,
})
}
for _, modelName := range mistral.ModelList {
openAIModels = append(openAIModels, OpenAIModels{
Id: modelName,
Object: "model",
Created: 1626777600,
OwnedBy: "mistralai",
Permission: permission,
Root: modelName,
Parent: nil,
})
for _, channelType := range openai.CompatibleChannels {
if channelType == common.ChannelTypeAzure {
continue
}
channelName, channelModelList := openai.GetCompatibleChannelMeta(channelType)
for _, modelName := range channelModelList {
openAIModels = append(openAIModels, OpenAIModels{
Id: modelName,
Object: "model",
Created: 1626777600,
OwnedBy: channelName,
Permission: permission,
Root: modelName,
Parent: nil,
})
}
}
openAIModelsMap = make(map[string]OpenAIModels)
for _, model := range openAIModels {
openAIModelsMap[model.Id] = model
}
channelId2Models = make(map[int][]string)
for i := 1; i < common.ChannelTypeDummy; i++ {
adaptor := helper.GetAdaptor(constant.ChannelType2APIType(i))
meta := &util.RelayMeta{
ChannelType: i,
}
adaptor.Init(meta)
channelId2Models[i] = adaptor.GetModelList()
}
}
func DashboardListModels(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
"data": channelId2Models,
})
}
func ListModels(c *gin.Context) {

View File

@@ -4,6 +4,7 @@ import (
"github.com/gin-contrib/sessions"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/blacklist"
"github.com/songquanpeng/one-api/model"
"net/http"
"strings"
@@ -42,11 +43,14 @@ func authHelper(c *gin.Context, minRole int) {
return
}
}
if status.(int) == common.UserStatusDisabled {
if status.(int) == common.UserStatusDisabled || blacklist.IsUserBanned(id.(int)) {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "用户已被封禁",
})
session := sessions.Default(c)
session.Clear()
_ = session.Save()
c.Abort()
return
}
@@ -99,7 +103,7 @@ func TokenAuth() func(c *gin.Context) {
abortWithMessage(c, http.StatusInternalServerError, err.Error())
return
}
if !userEnabled {
if !userEnabled || blacklist.IsUserBanned(token.UserId) {
abortWithMessage(c, http.StatusForbidden, "用户已被封禁")
return
}

View File

@@ -72,7 +72,7 @@ func chooseDB() (*gorm.DB, error) {
func InitDB() (err error) {
db, err := chooseDB()
if err == nil {
if config.DebugEnabled {
if config.DebugSQLEnabled {
db = db.Debug()
}
DB = db

View File

@@ -4,6 +4,7 @@ import (
"errors"
"fmt"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/blacklist"
"github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/helper"
"github.com/songquanpeng/one-api/common/logger"
@@ -40,7 +41,7 @@ func GetMaxUserId() int {
}
func GetAllUsers(startIdx int, num int) (users []*User, err error) {
err = DB.Order("id desc").Limit(num).Offset(startIdx).Omit("password").Find(&users).Error
err = DB.Order("id desc").Limit(num).Offset(startIdx).Omit("password").Where("status != ?", common.UserStatusDeleted).Find(&users).Error
return users, err
}
@@ -123,6 +124,11 @@ func (user *User) Update(updatePassword bool) error {
return err
}
}
if user.Status == common.UserStatusDisabled {
blacklist.BanUser(user.Id)
} else if user.Status == common.UserStatusEnabled {
blacklist.UnbanUser(user.Id)
}
err = DB.Model(user).Updates(user).Error
return err
}
@@ -131,7 +137,10 @@ func (user *User) Delete() error {
if user.Id == 0 {
return errors.New("id 为空!")
}
err := DB.Delete(user).Error
blacklist.BanUser(user.Id)
user.Username = fmt.Sprintf("deleted_%s", helper.GetUUID())
user.Status = common.UserStatusDeleted
err := DB.Model(user).Updates(user).Error
return err
}

View File

@@ -5,7 +5,6 @@ import (
"fmt"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/relay/channel"
"github.com/songquanpeng/one-api/relay/channel/openai"
"github.com/songquanpeng/one-api/relay/model"
"github.com/songquanpeng/one-api/relay/util"
"io"
@@ -20,7 +19,7 @@ func (a *Adaptor) Init(meta *util.RelayMeta) {
}
func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) {
return fmt.Sprintf("%s/v1/complete", meta.BaseURL), nil
return fmt.Sprintf("%s/v1/messages", meta.BaseURL), nil
}
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *util.RelayMeta) error {
@@ -31,6 +30,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *ut
anthropicVersion = "2023-06-01"
}
req.Header.Set("anthropic-version", anthropicVersion)
req.Header.Set("anthropic-beta", "messages-2023-12-15")
return nil
}
@@ -47,9 +47,7 @@ func (a *Adaptor) DoRequest(c *gin.Context, meta *util.RelayMeta, requestBody io
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.RelayMeta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
if meta.IsStream {
var responseText string
err, responseText = StreamHandler(c, resp)
usage = openai.ResponseText2Usage(responseText, meta.ActualModelName, meta.PromptTokens)
err, usage = StreamHandler(c, resp)
} else {
err, usage = Handler(c, resp, meta.PromptTokens, meta.ActualModelName)
}

View File

@@ -1,5 +1,8 @@
package anthropic
var ModelList = []string{
"claude-instant-1", "claude-2", "claude-2.0", "claude-2.1",
"claude-instant-1.2", "claude-2.0", "claude-2.1",
"claude-3-haiku-20240229",
"claude-3-sonnet-20240229",
"claude-3-opus-20240229",
}

View File

@@ -7,6 +7,7 @@ import (
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/helper"
"github.com/songquanpeng/one-api/common/image"
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/relay/channel/openai"
"github.com/songquanpeng/one-api/relay/model"
@@ -15,73 +16,135 @@ import (
"strings"
)
func stopReasonClaude2OpenAI(reason string) string {
switch reason {
func stopReasonClaude2OpenAI(reason *string) string {
if reason == nil {
return ""
}
switch *reason {
case "end_turn":
return "stop"
case "stop_sequence":
return "stop"
case "max_tokens":
return "length"
default:
return reason
return *reason
}
}
func ConvertRequest(textRequest model.GeneralOpenAIRequest) *Request {
claudeRequest := Request{
Model: textRequest.Model,
Prompt: "",
MaxTokensToSample: textRequest.MaxTokens,
StopSequences: nil,
Temperature: textRequest.Temperature,
TopP: textRequest.TopP,
Stream: textRequest.Stream,
Model: textRequest.Model,
MaxTokens: textRequest.MaxTokens,
Temperature: textRequest.Temperature,
TopP: textRequest.TopP,
Stream: textRequest.Stream,
}
if claudeRequest.MaxTokensToSample == 0 {
claudeRequest.MaxTokensToSample = 1000000
if claudeRequest.MaxTokens == 0 {
claudeRequest.MaxTokens = 4096
}
// legacy model name mapping
if claudeRequest.Model == "claude-instant-1" {
claudeRequest.Model = "claude-instant-1.1"
} else if claudeRequest.Model == "claude-2" {
claudeRequest.Model = "claude-2.1"
}
prompt := ""
for _, message := range textRequest.Messages {
if message.Role == "user" {
prompt += fmt.Sprintf("\n\nHuman: %s", message.Content)
} else if message.Role == "assistant" {
prompt += fmt.Sprintf("\n\nAssistant: %s", message.Content)
} else if message.Role == "system" {
if prompt == "" {
prompt = message.StringContent()
}
if message.Role == "system" && claudeRequest.System == "" {
claudeRequest.System = message.StringContent()
continue
}
claudeMessage := Message{
Role: message.Role,
}
var content Content
if message.IsStringContent() {
content.Type = "text"
content.Text = message.StringContent()
claudeMessage.Content = append(claudeMessage.Content, content)
claudeRequest.Messages = append(claudeRequest.Messages, claudeMessage)
continue
}
var contents []Content
openaiContent := message.ParseContent()
for _, part := range openaiContent {
var content Content
if part.Type == model.ContentTypeText {
content.Type = "text"
content.Text = part.Text
} else if part.Type == model.ContentTypeImageURL {
content.Type = "image"
content.Source = &ImageSource{
Type: "base64",
}
mimeType, data, _ := image.GetImageFromUrl(part.ImageURL.Url)
content.Source.MediaType = mimeType
content.Source.Data = data
}
contents = append(contents, content)
}
claudeMessage.Content = contents
claudeRequest.Messages = append(claudeRequest.Messages, claudeMessage)
}
prompt += "\n\nAssistant:"
claudeRequest.Prompt = prompt
return &claudeRequest
}
func streamResponseClaude2OpenAI(claudeResponse *Response) *openai.ChatCompletionsStreamResponse {
// https://docs.anthropic.com/claude/reference/messages-streaming
func streamResponseClaude2OpenAI(claudeResponse *StreamResponse) (*openai.ChatCompletionsStreamResponse, *Response) {
var response *Response
var responseText string
var stopReason string
switch claudeResponse.Type {
case "message_start":
return nil, claudeResponse.Message
case "content_block_start":
if claudeResponse.ContentBlock != nil {
responseText = claudeResponse.ContentBlock.Text
}
case "content_block_delta":
if claudeResponse.Delta != nil {
responseText = claudeResponse.Delta.Text
}
case "message_delta":
if claudeResponse.Usage != nil {
response = &Response{
Usage: *claudeResponse.Usage,
}
}
if claudeResponse.Delta != nil && claudeResponse.Delta.StopReason != nil {
stopReason = *claudeResponse.Delta.StopReason
}
}
var choice openai.ChatCompletionsStreamResponseChoice
choice.Delta.Content = claudeResponse.Completion
finishReason := stopReasonClaude2OpenAI(claudeResponse.StopReason)
choice.Delta.Content = responseText
choice.Delta.Role = "assistant"
finishReason := stopReasonClaude2OpenAI(&stopReason)
if finishReason != "null" {
choice.FinishReason = &finishReason
}
var response openai.ChatCompletionsStreamResponse
response.Object = "chat.completion.chunk"
response.Model = claudeResponse.Model
response.Choices = []openai.ChatCompletionsStreamResponseChoice{choice}
return &response
var openaiResponse openai.ChatCompletionsStreamResponse
openaiResponse.Object = "chat.completion.chunk"
openaiResponse.Choices = []openai.ChatCompletionsStreamResponseChoice{choice}
return &openaiResponse, response
}
func responseClaude2OpenAI(claudeResponse *Response) *openai.TextResponse {
var responseText string
if len(claudeResponse.Content) > 0 {
responseText = claudeResponse.Content[0].Text
}
choice := openai.TextResponseChoice{
Index: 0,
Message: model.Message{
Role: "assistant",
Content: strings.TrimPrefix(claudeResponse.Completion, " "),
Content: responseText,
Name: nil,
},
FinishReason: stopReasonClaude2OpenAI(claudeResponse.StopReason),
}
fullTextResponse := openai.TextResponse{
Id: fmt.Sprintf("chatcmpl-%s", helper.GetUUID()),
Id: fmt.Sprintf("chatcmpl-%s", claudeResponse.Id),
Model: claudeResponse.Model,
Object: "chat.completion",
Created: helper.GetTimestamp(),
Choices: []openai.TextResponseChoice{choice},
@@ -89,17 +152,15 @@ func responseClaude2OpenAI(claudeResponse *Response) *openai.TextResponse {
return &fullTextResponse
}
func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, string) {
responseText := ""
responseId := fmt.Sprintf("chatcmpl-%s", helper.GetUUID())
func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
createdTime := helper.GetTimestamp()
scanner := bufio.NewScanner(resp.Body)
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
if atEOF && len(data) == 0 {
return 0, nil, nil
}
if i := strings.Index(string(data), "\r\n\r\n"); i >= 0 {
return i + 4, data[0:i], nil
if i := strings.Index(string(data), "\n"); i >= 0 {
return i + 1, data[0:i], nil
}
if atEOF {
return len(data), data, nil
@@ -111,29 +172,45 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC
go func() {
for scanner.Scan() {
data := scanner.Text()
if !strings.HasPrefix(data, "event: completion") {
if len(data) < 6 {
continue
}
data = strings.TrimPrefix(data, "event: completion\r\ndata: ")
if !strings.HasPrefix(data, "data: ") {
continue
}
data = strings.TrimPrefix(data, "data: ")
dataChan <- data
}
stopChan <- true
}()
common.SetEventStreamHeaders(c)
var usage model.Usage
var modelName string
var id string
c.Stream(func(w io.Writer) bool {
select {
case data := <-dataChan:
// some implementations may add \r at the end of data
data = strings.TrimSuffix(data, "\r")
var claudeResponse Response
var claudeResponse StreamResponse
err := json.Unmarshal([]byte(data), &claudeResponse)
if err != nil {
logger.SysError("error unmarshalling stream response: " + err.Error())
return true
}
responseText += claudeResponse.Completion
response := streamResponseClaude2OpenAI(&claudeResponse)
response.Id = responseId
response, meta := streamResponseClaude2OpenAI(&claudeResponse)
if meta != nil {
usage.PromptTokens += meta.Usage.InputTokens
usage.CompletionTokens += meta.Usage.OutputTokens
modelName = meta.Model
id = fmt.Sprintf("chatcmpl-%s", meta.Id)
return true
}
if response == nil {
return true
}
response.Id = id
response.Model = modelName
response.Created = createdTime
jsonStr, err := json.Marshal(response)
if err != nil {
@@ -147,11 +224,8 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC
return false
}
})
err := resp.Body.Close()
if err != nil {
return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), ""
}
return nil, responseText
_ = resp.Body.Close()
return nil, &usage
}
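
The stream handler now splits the body on single newlines and only forwards lines carrying a "data: " prefix, which matches the event/data line pairs of the Messages streaming format. A self-contained sketch of that parsing step over a made-up event (the payload below is illustrative, not captured API output):

package main

import (
	"bufio"
	"fmt"
	"strings"
)

func main() {
	// Hypothetical sample of a messages-streaming body (shape only).
	body := "event: content_block_delta\n" +
		"data: {\"type\":\"content_block_delta\",\"delta\":{\"type\":\"text_delta\",\"text\":\"Hi\"}}\n" +
		"\n"
	scanner := bufio.NewScanner(strings.NewReader(body))
	scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
		if atEOF && len(data) == 0 {
			return 0, nil, nil
		}
		if i := strings.Index(string(data), "\n"); i >= 0 {
			return i + 1, data[0:i], nil // one SSE line per token
		}
		if atEOF {
			return len(data), data, nil
		}
		return 0, nil, nil
	})
	for scanner.Scan() {
		line := scanner.Text()
		if strings.HasPrefix(line, "data: ") {
			fmt.Println(strings.TrimPrefix(line, "data: ")) // only the JSON payloads are processed
		}
	}
}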
func Handler(c *gin.Context, resp *http.Response, promptTokens int, modelName string) (*model.ErrorWithStatusCode, *model.Usage) {
@@ -181,11 +255,10 @@ func Handler(c *gin.Context, resp *http.Response, promptTokens int, modelName st
}
fullTextResponse := responseClaude2OpenAI(&claudeResponse)
fullTextResponse.Model = modelName
completionTokens := openai.CountTokenText(claudeResponse.Completion, modelName)
usage := model.Usage{
PromptTokens: promptTokens,
CompletionTokens: completionTokens,
TotalTokens: promptTokens + completionTokens,
PromptTokens: claudeResponse.Usage.InputTokens,
CompletionTokens: claudeResponse.Usage.OutputTokens,
TotalTokens: claudeResponse.Usage.InputTokens + claudeResponse.Usage.OutputTokens,
}
fullTextResponse.Usage = usage
jsonResponse, err := json.Marshal(fullTextResponse)

View File

@@ -1,19 +1,44 @@
package anthropic
// https://docs.anthropic.com/claude/reference/messages_post
type Metadata struct {
UserId string `json:"user_id"`
}
type ImageSource struct {
Type string `json:"type"`
MediaType string `json:"media_type"`
Data string `json:"data"`
}
type Content struct {
Type string `json:"type"`
Text string `json:"text,omitempty"`
Source *ImageSource `json:"source,omitempty"`
}
type Message struct {
Role string `json:"role"`
Content []Content `json:"content"`
}
type Request struct {
Model string `json:"model"`
Prompt string `json:"prompt"`
MaxTokensToSample int `json:"max_tokens_to_sample"`
StopSequences []string `json:"stop_sequences,omitempty"`
Temperature float64 `json:"temperature,omitempty"`
TopP float64 `json:"top_p,omitempty"`
TopK int `json:"top_k,omitempty"`
Model string `json:"model"`
Messages []Message `json:"messages"`
System string `json:"system,omitempty"`
MaxTokens int `json:"max_tokens,omitempty"`
StopSequences []string `json:"stop_sequences,omitempty"`
Stream bool `json:"stream,omitempty"`
Temperature float64 `json:"temperature,omitempty"`
TopP float64 `json:"top_p,omitempty"`
TopK int `json:"top_k,omitempty"`
//Metadata `json:"metadata,omitempty"`
Stream bool `json:"stream,omitempty"`
}
type Usage struct {
InputTokens int `json:"input_tokens"`
OutputTokens int `json:"output_tokens"`
}
type Error struct {
@@ -22,8 +47,29 @@ type Error struct {
}
type Response struct {
Completion string `json:"completion"`
StopReason string `json:"stop_reason"`
Model string `json:"model"`
Error Error `json:"error"`
Id string `json:"id"`
Type string `json:"type"`
Role string `json:"role"`
Content []Content `json:"content"`
Model string `json:"model"`
StopReason *string `json:"stop_reason"`
StopSequence *string `json:"stop_sequence"`
Usage Usage `json:"usage"`
Error Error `json:"error"`
}
type Delta struct {
Type string `json:"type"`
Text string `json:"text"`
StopReason *string `json:"stop_reason"`
StopSequence *string `json:"stop_sequence"`
}
type StreamResponse struct {
Type string `json:"type"`
Message *Response `json:"message"`
Index int `json:"index"`
ContentBlock *Content `json:"content_block"`
Delta *Delta `json:"delta"`
Usage *Usage `json:"usage"`
}
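
With these types, a converted chat request serializes into the Messages API shape: a top-level system string plus role/content-block messages. A rough sketch of the payload the new structs produce for a plain text prompt, assuming the package sits at relay/channel/anthropic like the other channel packages (field values are made up):

package main

import (
	"encoding/json"
	"fmt"

	"github.com/songquanpeng/one-api/relay/channel/anthropic"
)

func main() {
	req := anthropic.Request{
		Model:     "claude-3-opus-20240229",
		MaxTokens: 4096,
		System:    "You are a helpful assistant.",
		Messages: []anthropic.Message{
			{Role: "user", Content: []anthropic.Content{{Type: "text", Text: "Hello"}}},
		},
	}
	out, _ := json.MarshalIndent(req, "", "  ")
	fmt.Println(string(out)) // model, system, max_tokens and messages with typed content blocks
}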

View File

@@ -2,6 +2,7 @@ package baidu
import (
"errors"
"fmt"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/relay/channel"
"github.com/songquanpeng/one-api/relay/constant"
@@ -9,6 +10,7 @@ import (
"github.com/songquanpeng/one-api/relay/util"
"io"
"net/http"
"strings"
)
type Adaptor struct {
@@ -20,25 +22,33 @@ func (a *Adaptor) Init(meta *util.RelayMeta) {
func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) {
// https://cloud.baidu.com/doc/WENXINWORKSHOP/s/clntwmv7t
var fullRequestURL string
switch meta.ActualModelName {
case "ERNIE-Bot-4":
fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro"
case "ERNIE-Bot-8K":
fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie_bot_8k"
case "ERNIE-Bot":
fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions"
case "ERNIE-Speed":
fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie_speed"
case "ERNIE-Bot-turbo":
fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/eb-instant"
case "BLOOMZ-7B":
fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/bloomz_7b1"
case "Embedding-V1":
fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/embeddings/embedding-v1"
default:
fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/" + meta.ActualModelName
suffix := "chat/"
if strings.HasPrefix("Embedding", meta.ActualModelName) {
suffix = "embeddings/"
}
switch meta.ActualModelName {
case "ERNIE-4.0":
suffix += "completions_pro"
case "ERNIE-Bot-4":
suffix += "completions_pro"
case "ERNIE-3.5-8K":
suffix += "completions"
case "ERNIE-Bot-8K":
suffix += "ernie_bot_8k"
case "ERNIE-Bot":
suffix += "completions"
case "ERNIE-Speed":
suffix += "ernie_speed"
case "ERNIE-Bot-turbo":
suffix += "eb-instant"
case "BLOOMZ-7B":
suffix += "bloomz_7b1"
case "Embedding-V1":
suffix += "embedding-v1"
default:
suffix += meta.ActualModelName
}
fullRequestURL := fmt.Sprintf("%s/rpc/2.0/ai_custom/v1/wenxinworkshop/%s", meta.BaseURL, suffix)
var accessToken string
var err error
if accessToken, err = GetAccessToken(meta.APIKey); err != nil {

View File

@@ -0,0 +1,10 @@
package groq
// https://console.groq.com/docs/models
var ModelList = []string{
"gemma-7b-it",
"llama2-7b-2048",
"llama2-70b-4096",
"mixtral-8x7b-32768",
}

View File

@@ -6,11 +6,7 @@ import (
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/relay/channel"
"github.com/songquanpeng/one-api/relay/channel/ai360"
"github.com/songquanpeng/one-api/relay/channel/baichuan"
"github.com/songquanpeng/one-api/relay/channel/minimax"
"github.com/songquanpeng/one-api/relay/channel/mistral"
"github.com/songquanpeng/one-api/relay/channel/moonshot"
"github.com/songquanpeng/one-api/relay/model"
"github.com/songquanpeng/one-api/relay/util"
"io"
@@ -86,37 +82,11 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.Rel
}
func (a *Adaptor) GetModelList() []string {
switch a.ChannelType {
case common.ChannelType360:
return ai360.ModelList
case common.ChannelTypeMoonshot:
return moonshot.ModelList
case common.ChannelTypeBaichuan:
return baichuan.ModelList
case common.ChannelTypeMinimax:
return minimax.ModelList
case common.ChannelTypeMistral:
return mistral.ModelList
default:
return ModelList
}
_, modelList := GetCompatibleChannelMeta(a.ChannelType)
return modelList
}
func (a *Adaptor) GetChannelName() string {
switch a.ChannelType {
case common.ChannelTypeAzure:
return "azure"
case common.ChannelType360:
return "360"
case common.ChannelTypeMoonshot:
return "moonshot"
case common.ChannelTypeBaichuan:
return "baichuan"
case common.ChannelTypeMinimax:
return "minimax"
case common.ChannelTypeMistral:
return "mistralai"
default:
return "openai"
}
channelName, _ := GetCompatibleChannelMeta(a.ChannelType)
return channelName
}

View File

@@ -0,0 +1,42 @@
package openai
import (
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/relay/channel/ai360"
"github.com/songquanpeng/one-api/relay/channel/baichuan"
"github.com/songquanpeng/one-api/relay/channel/groq"
"github.com/songquanpeng/one-api/relay/channel/minimax"
"github.com/songquanpeng/one-api/relay/channel/mistral"
"github.com/songquanpeng/one-api/relay/channel/moonshot"
)
var CompatibleChannels = []int{
common.ChannelTypeAzure,
common.ChannelType360,
common.ChannelTypeMoonshot,
common.ChannelTypeBaichuan,
common.ChannelTypeMinimax,
common.ChannelTypeMistral,
common.ChannelTypeGroq,
}
func GetCompatibleChannelMeta(channelType int) (string, []string) {
switch channelType {
case common.ChannelTypeAzure:
return "azure", ModelList
case common.ChannelType360:
return "360", ai360.ModelList
case common.ChannelTypeMoonshot:
return "moonshot", moonshot.ModelList
case common.ChannelTypeBaichuan:
return "baichuan", baichuan.ModelList
case common.ChannelTypeMinimax:
return "minimax", minimax.ModelList
case common.ChannelTypeMistral:
return "mistralai", mistral.ModelList
case common.ChannelTypeGroq:
return "groq", groq.ModelList
default:
return "openai", ModelList
}
}
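
The per-channel switches in the OpenAI adaptor and the models controller now both go through this single table. A short usage sketch (illustrative only):

package main

import (
	"fmt"

	"github.com/songquanpeng/one-api/common"
	"github.com/songquanpeng/one-api/relay/channel/openai"
)

func main() {
	name, models := openai.GetCompatibleChannelMeta(common.ChannelTypeGroq)
	fmt.Println(name, len(models)) // "groq" and the 4 models from the groq ModelList
}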

View File

@@ -28,17 +28,6 @@ func ConvertRequest(request model.GeneralOpenAIRequest) *ChatRequest {
messages := make([]Message, 0, len(request.Messages))
for i := 0; i < len(request.Messages); i++ {
message := request.Messages[i]
if message.Role == "system" {
messages = append(messages, Message{
Role: "user",
Content: message.StringContent(),
})
messages = append(messages, Message{
Role: "assistant",
Content: "Okay",
})
continue
}
messages = append(messages, Message{
Content: message.StringContent(),
Role: message.Role,

View File

@@ -27,21 +27,10 @@ import (
func requestOpenAI2Xunfei(request model.GeneralOpenAIRequest, xunfeiAppId string, domain string) *ChatRequest {
messages := make([]Message, 0, len(request.Messages))
for _, message := range request.Messages {
if message.Role == "system" {
messages = append(messages, Message{
Role: "user",
Content: message.StringContent(),
})
messages = append(messages, Message{
Role: "assistant",
Content: "Okay",
})
} else {
messages = append(messages, Message{
Role: message.Role,
Content: message.StringContent(),
})
}
messages = append(messages, Message{
Role: message.Role,
Content: message.StringContent(),
})
}
xunfeiRequest := ChatRequest{}
xunfeiRequest.Header.AppId = xunfeiAppId

View File

@@ -76,21 +76,10 @@ func GetToken(apikey string) string {
func ConvertRequest(request model.GeneralOpenAIRequest) *Request {
messages := make([]Message, 0, len(request.Messages))
for _, message := range request.Messages {
if message.Role == "system" {
messages = append(messages, Message{
Role: "system",
Content: message.StringContent(),
})
messages = append(messages, Message{
Role: "user",
Content: "Okay",
})
} else {
messages = append(messages, Message{
Role: message.Role,
Content: message.StringContent(),
})
}
messages = append(messages, Message{
Role: message.Role,
Content: message.StringContent(),
})
}
return &Request{
Prompt: messages,

View File

@@ -83,11 +83,12 @@ func RelayTextHelper(c *gin.Context) *model.ErrorWithStatusCode {
logger.Errorf(ctx, "DoRequest failed: %s", err.Error())
return openai.ErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)
}
meta.IsStream = meta.IsStream || strings.HasPrefix(resp.Header.Get("Content-Type"), "text/event-stream")
if resp.StatusCode != http.StatusOK {
errorHappened := (resp.StatusCode != http.StatusOK) || (meta.IsStream && resp.Header.Get("Content-Type") == "application/json")
if errorHappened {
util.ReturnPreConsumedQuota(ctx, preConsumedQuota, meta.TokenId)
return util.RelayErrorHandler(resp)
}
meta.IsStream = meta.IsStream || strings.HasPrefix(resp.Header.Get("Content-Type"), "text/event-stream")
// do response
usage, respErr := adaptor.DoResponse(c, resp, meta)

View File

@@ -14,6 +14,7 @@ func SetApiRouter(router *gin.Engine) {
apiRouter.Use(middleware.GlobalAPIRateLimit())
{
apiRouter.GET("/status", controller.GetStatus)
apiRouter.GET("/models", middleware.UserAuth(), controller.DashboardListModels)
apiRouter.GET("/notice", controller.GetNotice)
apiRouter.GET("/about", controller.GetAbout)
apiRouter.GET("/home_page_content", controller.GetHomePageContent)

View File

@@ -15,7 +15,7 @@ export const CHANNEL_OPTIONS = {
key: 3,
text: 'Azure OpenAI',
value: 3,
color: 'orange'
color: 'secondary'
},
11: {
key: 11,
@@ -89,6 +89,12 @@ export const CHANNEL_OPTIONS = {
value: 27,
color: 'default'
},
29: {
key: 29,
text: 'Groq',
value: 29,
color: 'default'
},
8: {
key: 8,
text: '自定义渠道',

View File

@@ -163,6 +163,9 @@ const typeConfig = {
},
modelGroup: "minimax",
},
29: {
modelGroup: "groq",
},
};
export { defaultConfig, typeConfig };

View File

@@ -1,7 +1,16 @@
import React, { useEffect, useState } from 'react';
import { Button, Form, Input, Label, Message, Pagination, Popup, Table } from 'semantic-ui-react';
import { Link } from 'react-router-dom';
import { API, setPromptShown, shouldShowPrompt, showError, showInfo, showSuccess, timestamp2string } from '../helpers';
import {
API,
loadChannelModels,
setPromptShown,
shouldShowPrompt,
showError,
showInfo,
showSuccess,
timestamp2string
} from '../helpers';
import { CHANNEL_OPTIONS, ITEMS_PER_PAGE } from '../constants';
import { renderGroup, renderNumber } from '../helpers/render';
@@ -95,6 +104,7 @@ const ChannelsTable = () => {
.catch((reason) => {
showError(reason);
});
loadChannelModels().then();
}, []);
const manageChannel = async (id, action, idx, value) => {

View File

@@ -14,6 +14,7 @@ export const CHANNEL_OPTIONS = [
{ key: 23, text: '腾讯混元', value: 23, color: 'teal' },
{ key: 26, text: '百川大模型', value: 26, color: 'orange' },
{ key: 27, text: 'MiniMax', value: 27, color: 'red' },
{ key: 29, text: 'Groq', value: 29, color: 'orange' },
{ key: 8, text: '自定义渠道', value: 8, color: 'pink' },
{ key: 22, text: '知识库FastGPT', value: 22, color: 'blue' },
{ key: 21, text: '知识库AI Proxy', value: 21, color: 'purple' },

View File

@@ -1,11 +1,13 @@
import { toast } from 'react-toastify';
import { toastConstants } from '../constants';
import React from 'react';
import { API } from './api';
const HTMLToastContent = ({ htmlContent }) => {
return <div dangerouslySetInnerHTML={{ __html: htmlContent }} />;
};
export default HTMLToastContent;
export function isAdmin() {
let user = localStorage.getItem('user');
if (!user) return false;
@@ -29,7 +31,7 @@ export function getSystemName() {
export function getLogo() {
let logo = localStorage.getItem('logo');
if (!logo) return '/logo.png';
return logo
return logo;
}
export function getFooterHTML() {
@@ -196,4 +198,30 @@ export function shouldShowPrompt(id) {
export function setPromptShown(id) {
localStorage.setItem(`prompt-${id}`, 'true');
}
let channelModels = undefined;
export async function loadChannelModels() {
const res = await API.get('/api/models');
const { success, data } = res.data;
if (!success) {
return;
}
channelModels = data;
localStorage.setItem('channel_models', JSON.stringify(data));
}
export function getChannelModels(type) {
if (channelModels !== undefined && type in channelModels) {
return channelModels[type];
}
let models = localStorage.getItem('channel_models');
if (!models) {
return [];
}
channelModels = JSON.parse(models);
if (type in channelModels) {
return channelModels[type];
}
return [];
}

View File

@@ -1,7 +1,7 @@
import React, { useEffect, useState } from 'react';
import { Button, Form, Header, Input, Message, Segment } from 'semantic-ui-react';
import { useNavigate, useParams } from 'react-router-dom';
import { API, showError, showInfo, showSuccess, verifyJSON } from '../../helpers';
import { API, copy, getChannelModels, showError, showInfo, showSuccess, verifyJSON } from '../../helpers';
import { CHANNEL_OPTIONS } from '../../constants';
const MODEL_MAPPING_EXAMPLE = {
@@ -56,60 +56,12 @@ const EditChannel = () => {
const [customModel, setCustomModel] = useState('');
const handleInputChange = (e, { name, value }) => {
setInputs((inputs) => ({ ...inputs, [name]: value }));
if (name === 'type' && inputs.models.length === 0) {
let localModels = [];
switch (value) {
case 14:
localModels = ['claude-instant-1', 'claude-2', 'claude-2.0', 'claude-2.1'];
break;
case 11:
localModels = ['PaLM-2'];
break;
case 15:
localModels = ['ERNIE-Bot', 'ERNIE-Bot-turbo', 'ERNIE-Bot-4', 'Embedding-V1'];
break;
case 17:
localModels = ['qwen-turbo', 'qwen-plus', 'qwen-max', 'qwen-max-longcontext', 'text-embedding-v1'];
let withInternetVersion = [];
for (let i = 0; i < localModels.length; i++) {
if (localModels[i].startsWith('qwen-')) {
withInternetVersion.push(localModels[i] + '-internet');
}
}
localModels = [...localModels, ...withInternetVersion];
break;
case 16:
localModels = ["glm-4", "glm-4v", "glm-3-turbo",'chatglm_turbo', 'chatglm_pro', 'chatglm_std', 'chatglm_lite'];
break;
case 18:
localModels = [
'SparkDesk',
'SparkDesk-v1.1',
'SparkDesk-v2.1',
'SparkDesk-v3.1',
'SparkDesk-v3.5'
];
break;
case 19:
localModels = ['360GPT_S2_V9', 'embedding-bert-512-v1', 'embedding_s1_v1', 'semantic_similarity_s1_v1'];
break;
case 23:
localModels = ['hunyuan'];
break;
case 24:
localModels = ['gemini-pro', 'gemini-pro-vision'];
break;
case 25:
localModels = ['moonshot-v1-8k', 'moonshot-v1-32k', 'moonshot-v1-128k'];
break;
case 26:
localModels = ['Baichuan2-Turbo', 'Baichuan2-Turbo-192k', 'Baichuan-Text-Embedding'];
break;
case 27:
localModels = ['abab5.5s-chat', 'abab5.5-chat', 'abab6-chat'];
break;
if (name === 'type') {
let localModels = getChannelModels(value);
if (inputs.models.length === 0) {
setInputs((inputs) => ({ ...inputs, models: localModels }));
}
setInputs((inputs) => ({ ...inputs, models: localModels }));
setBasicModels(localModels);
}
};
@@ -262,6 +214,7 @@ const EditChannel = () => {
label='类型'
name='type'
required
search
options={CHANNEL_OPTIONS}
value={inputs.type}
onChange={handleInputChange}
@@ -390,6 +343,8 @@ const EditChannel = () => {
required
fluid
multiple
search
onLabelClick={(e, { value }) => {copy(value).then()}}
selection
onChange={handleInputChange}
value={inputs.models}