mirror of
https://github.com/songquanpeng/one-api.git
synced 2025-10-24 10:23:41 +08:00
Compare commits
20 Commits
v0.4.9
...
v0.5.0-alp
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
e6938bd236 | ||
|
|
8f721d67a5 | ||
|
|
fcc1e2d568 | ||
|
|
9a1db61675 | ||
|
|
3c940113ab | ||
|
|
0495b9a0d7 | ||
|
|
12a0e7105e | ||
|
|
e628b643cd | ||
|
|
675847bf98 | ||
|
|
2ff15baf66 | ||
|
|
4139a7036f | ||
|
|
02da0b51f8 | ||
|
|
35cfebee12 | ||
|
|
0e088f7c3e | ||
|
|
f61d326721 | ||
|
|
74b06b643a | ||
|
|
ccf7709e23 | ||
|
|
d592e2c8b8 | ||
|
|
b520b54625 | ||
|
|
81c5901123 |
16
README.md
16
README.md
@@ -61,6 +61,9 @@ _✨ All in one 的 OpenAI 接口,整合各种 API 访问方式,开箱即用
|
|||||||
1. 支持多种 API 访问渠道:
|
1. 支持多种 API 访问渠道:
|
||||||
+ [x] OpenAI 官方通道(支持配置镜像)
|
+ [x] OpenAI 官方通道(支持配置镜像)
|
||||||
+ [x] **Azure OpenAI API**
|
+ [x] **Azure OpenAI API**
|
||||||
|
+ [x] [Anthropic Claude 系列模型](https://anthropic.com)
|
||||||
|
+ [x] [Google PaLM2 系列模型](https://developers.generativeai.google)
|
||||||
|
+ [x] [百度文心一言系列模型](https://cloud.baidu.com/doc/WENXINWORKSHOP/index.html)
|
||||||
+ [x] [API Distribute](https://api.gptjk.top/register?aff=QGxj)
|
+ [x] [API Distribute](https://api.gptjk.top/register?aff=QGxj)
|
||||||
+ [x] [OpenAI-SB](https://openai-sb.com)
|
+ [x] [OpenAI-SB](https://openai-sb.com)
|
||||||
+ [x] [API2D](https://api2d.com/r/197971)
|
+ [x] [API2D](https://api2d.com/r/197971)
|
||||||
@@ -81,16 +84,19 @@ _✨ All in one 的 OpenAI 接口,整合各种 API 访问方式,开箱即用
|
|||||||
12. 支持以美元为单位显示额度。
|
12. 支持以美元为单位显示额度。
|
||||||
13. 支持发布公告,设置充值链接,设置新用户初始额度。
|
13. 支持发布公告,设置充值链接,设置新用户初始额度。
|
||||||
14. 支持模型映射,重定向用户的请求模型。
|
14. 支持模型映射,重定向用户的请求模型。
|
||||||
15. 支持丰富的**自定义**设置,
|
15. 支持失败自动重试。
|
||||||
|
16. 支持绘图接口。
|
||||||
|
17. 支持丰富的**自定义**设置,
|
||||||
1. 支持自定义系统名称,logo 以及页脚。
|
1. 支持自定义系统名称,logo 以及页脚。
|
||||||
2. 支持自定义首页和关于页面,可以选择使用 HTML & Markdown 代码进行自定义,或者使用一个单独的网页通过 iframe 嵌入。
|
2. 支持自定义首页和关于页面,可以选择使用 HTML & Markdown 代码进行自定义,或者使用一个单独的网页通过 iframe 嵌入。
|
||||||
16. 支持通过系统访问令牌访问管理 API。
|
18. 支持通过系统访问令牌访问管理 API。
|
||||||
17. 支持 Cloudflare Turnstile 用户校验。
|
19. 支持 Cloudflare Turnstile 用户校验。
|
||||||
18. 支持用户管理,支持**多种用户登录注册方式**:
|
20. 支持用户管理,支持**多种用户登录注册方式**:
|
||||||
+ 邮箱登录注册以及通过邮箱进行密码重置。
|
+ 邮箱登录注册以及通过邮箱进行密码重置。
|
||||||
+ [GitHub 开放授权](https://github.com/settings/applications/new)。
|
+ [GitHub 开放授权](https://github.com/settings/applications/new)。
|
||||||
+ 微信公众号授权(需要额外部署 [WeChat Server](https://github.com/songquanpeng/wechat-server))。
|
+ 微信公众号授权(需要额外部署 [WeChat Server](https://github.com/songquanpeng/wechat-server))。
|
||||||
19. 未来其他大模型开放 API 后,将第一时间支持,并将其封装成同样的 API 访问方式。
|
21. 支持 [ChatGLM](https://github.com/THUDM/ChatGLM2-6B)。
|
||||||
|
22. 未来其他大模型开放 API 后,将第一时间支持,并将其封装成同样的 API 访问方式。
|
||||||
|
|
||||||
## 部署
|
## 部署
|
||||||
### 基于 Docker 进行部署
|
### 基于 Docker 进行部署
|
||||||
|
|||||||
@@ -68,6 +68,7 @@ var AutomaticDisableChannelEnabled = false
|
|||||||
var QuotaRemindThreshold = 1000
|
var QuotaRemindThreshold = 1000
|
||||||
var PreConsumedQuota = 500
|
var PreConsumedQuota = 500
|
||||||
var ApproximateTokenEnabled = false
|
var ApproximateTokenEnabled = false
|
||||||
|
var RetryTimes = 0
|
||||||
|
|
||||||
var RootUserEmail = ""
|
var RootUserEmail = ""
|
||||||
|
|
||||||
@@ -150,6 +151,8 @@ const (
|
|||||||
ChannelTypePaLM = 11
|
ChannelTypePaLM = 11
|
||||||
ChannelTypeAPI2GPT = 12
|
ChannelTypeAPI2GPT = 12
|
||||||
ChannelTypeAIGC2D = 13
|
ChannelTypeAIGC2D = 13
|
||||||
|
ChannelTypeAnthropic = 14
|
||||||
|
ChannelTypeBaidu = 15
|
||||||
)
|
)
|
||||||
|
|
||||||
var ChannelBaseURLs = []string{
|
var ChannelBaseURLs = []string{
|
||||||
@@ -167,4 +170,6 @@ var ChannelBaseURLs = []string{
|
|||||||
"", // 11
|
"", // 11
|
||||||
"https://api.api2gpt.com", // 12
|
"https://api.api2gpt.com", // 12
|
||||||
"https://api.aigc2d.com", // 13
|
"https://api.aigc2d.com", // 13
|
||||||
|
"https://api.anthropic.com", // 14
|
||||||
|
"https://aip.baidubce.com", // 15
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ import "encoding/json"
|
|||||||
|
|
||||||
// ModelRatio
|
// ModelRatio
|
||||||
// https://platform.openai.com/docs/models/model-endpoint-compatibility
|
// https://platform.openai.com/docs/models/model-endpoint-compatibility
|
||||||
|
// https://cloud.baidu.com/doc/WENXINWORKSHOP/s/Blfmc9dlf
|
||||||
// https://openai.com/pricing
|
// https://openai.com/pricing
|
||||||
// TODO: when a new api is enabled, check the pricing here
|
// TODO: when a new api is enabled, check the pricing here
|
||||||
// 1 === $0.002 / 1K tokens
|
// 1 === $0.002 / 1K tokens
|
||||||
@@ -35,6 +36,12 @@ var ModelRatio = map[string]float64{
|
|||||||
"text-search-ada-doc-001": 10,
|
"text-search-ada-doc-001": 10,
|
||||||
"text-moderation-stable": 0.1,
|
"text-moderation-stable": 0.1,
|
||||||
"text-moderation-latest": 0.1,
|
"text-moderation-latest": 0.1,
|
||||||
|
"dall-e": 8,
|
||||||
|
"claude-instant-1": 0.75,
|
||||||
|
"claude-2": 30,
|
||||||
|
"ERNIE-Bot": 1, // 0.012元/千tokens
|
||||||
|
"ERNIE-Bot-turbo": 0.67, // 0.008元/千tokens
|
||||||
|
"PaLM-2": 1,
|
||||||
}
|
}
|
||||||
|
|
||||||
func ModelRatio2JSONString() string {
|
func ModelRatio2JSONString() string {
|
||||||
|
|||||||
@@ -7,16 +7,19 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func GetSubscription(c *gin.Context) {
|
func GetSubscription(c *gin.Context) {
|
||||||
var quota int
|
var remainQuota int
|
||||||
|
var usedQuota int
|
||||||
var err error
|
var err error
|
||||||
var token *model.Token
|
var token *model.Token
|
||||||
if common.DisplayTokenStatEnabled {
|
if common.DisplayTokenStatEnabled {
|
||||||
tokenId := c.GetInt("token_id")
|
tokenId := c.GetInt("token_id")
|
||||||
token, err = model.GetTokenById(tokenId)
|
token, err = model.GetTokenById(tokenId)
|
||||||
quota = token.RemainQuota
|
remainQuota = token.RemainQuota
|
||||||
|
usedQuota = token.UsedQuota
|
||||||
} else {
|
} else {
|
||||||
userId := c.GetInt("id")
|
userId := c.GetInt("id")
|
||||||
quota, err = model.GetUserQuota(userId)
|
remainQuota, err = model.GetUserQuota(userId)
|
||||||
|
usedQuota, err = model.GetUserUsedQuota(userId)
|
||||||
}
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
openAIError := OpenAIError{
|
openAIError := OpenAIError{
|
||||||
@@ -28,6 +31,7 @@ func GetSubscription(c *gin.Context) {
|
|||||||
})
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
quota := remainQuota + usedQuota
|
||||||
amount := float64(quota)
|
amount := float64(quota)
|
||||||
if common.DisplayInCurrencyEnabled {
|
if common.DisplayInCurrencyEnabled {
|
||||||
amount /= common.QuotaPerUnit
|
amount /= common.QuotaPerUnit
|
||||||
|
|||||||
@@ -14,7 +14,7 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
func testChannel(channel *model.Channel, request ChatRequest) error {
|
func testChannel(channel *model.Channel, request ChatRequest) (error, *OpenAIError) {
|
||||||
switch channel.Type {
|
switch channel.Type {
|
||||||
case common.ChannelTypeAzure:
|
case common.ChannelTypeAzure:
|
||||||
request.Model = "gpt-35-turbo"
|
request.Model = "gpt-35-turbo"
|
||||||
@@ -33,11 +33,11 @@ func testChannel(channel *model.Channel, request ChatRequest) error {
|
|||||||
|
|
||||||
jsonData, err := json.Marshal(request)
|
jsonData, err := json.Marshal(request)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err, nil
|
||||||
}
|
}
|
||||||
req, err := http.NewRequest("POST", requestURL, bytes.NewBuffer(jsonData))
|
req, err := http.NewRequest("POST", requestURL, bytes.NewBuffer(jsonData))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err, nil
|
||||||
}
|
}
|
||||||
if channel.Type == common.ChannelTypeAzure {
|
if channel.Type == common.ChannelTypeAzure {
|
||||||
req.Header.Set("api-key", channel.Key)
|
req.Header.Set("api-key", channel.Key)
|
||||||
@@ -48,18 +48,18 @@ func testChannel(channel *model.Channel, request ChatRequest) error {
|
|||||||
client := &http.Client{}
|
client := &http.Client{}
|
||||||
resp, err := client.Do(req)
|
resp, err := client.Do(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err, nil
|
||||||
}
|
}
|
||||||
defer resp.Body.Close()
|
defer resp.Body.Close()
|
||||||
var response TextResponse
|
var response TextResponse
|
||||||
err = json.NewDecoder(resp.Body).Decode(&response)
|
err = json.NewDecoder(resp.Body).Decode(&response)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err, nil
|
||||||
}
|
}
|
||||||
if response.Usage.CompletionTokens == 0 {
|
if response.Usage.CompletionTokens == 0 {
|
||||||
return errors.New(fmt.Sprintf("type %s, code %v, message %s", response.Error.Type, response.Error.Code, response.Error.Message))
|
return errors.New(fmt.Sprintf("type %s, code %v, message %s", response.Error.Type, response.Error.Code, response.Error.Message)), &response.Error
|
||||||
}
|
}
|
||||||
return nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func buildTestRequest() *ChatRequest {
|
func buildTestRequest() *ChatRequest {
|
||||||
@@ -94,7 +94,7 @@ func TestChannel(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
testRequest := buildTestRequest()
|
testRequest := buildTestRequest()
|
||||||
tik := time.Now()
|
tik := time.Now()
|
||||||
err = testChannel(channel, *testRequest)
|
err, _ = testChannel(channel, *testRequest)
|
||||||
tok := time.Now()
|
tok := time.Now()
|
||||||
milliseconds := tok.Sub(tik).Milliseconds()
|
milliseconds := tok.Sub(tik).Milliseconds()
|
||||||
go channel.UpdateResponseTime(milliseconds)
|
go channel.UpdateResponseTime(milliseconds)
|
||||||
@@ -158,13 +158,14 @@ func testAllChannels(notify bool) error {
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
tik := time.Now()
|
tik := time.Now()
|
||||||
err := testChannel(channel, *testRequest)
|
err, openaiErr := testChannel(channel, *testRequest)
|
||||||
tok := time.Now()
|
tok := time.Now()
|
||||||
milliseconds := tok.Sub(tik).Milliseconds()
|
milliseconds := tok.Sub(tik).Milliseconds()
|
||||||
if err != nil || milliseconds > disableThreshold {
|
if milliseconds > disableThreshold {
|
||||||
if milliseconds > disableThreshold {
|
err = errors.New(fmt.Sprintf("响应时间 %.2fs 超过阈值 %.2fs", float64(milliseconds)/1000.0, float64(disableThreshold)/1000.0))
|
||||||
err = errors.New(fmt.Sprintf("响应时间 %.2fs 超过阈值 %.2fs", float64(milliseconds)/1000.0, float64(disableThreshold)/1000.0))
|
disableChannel(channel.Id, channel.Name, err.Error())
|
||||||
}
|
}
|
||||||
|
if shouldDisableChannel(openaiErr) {
|
||||||
disableChannel(channel.Id, channel.Name, err.Error())
|
disableChannel(channel.Id, channel.Name, err.Error())
|
||||||
}
|
}
|
||||||
channel.UpdateResponseTime(milliseconds)
|
channel.UpdateResponseTime(milliseconds)
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ package controller
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -53,6 +54,15 @@ func init() {
|
|||||||
})
|
})
|
||||||
// https://platform.openai.com/docs/models/model-endpoint-compatibility
|
// https://platform.openai.com/docs/models/model-endpoint-compatibility
|
||||||
openAIModels = []OpenAIModels{
|
openAIModels = []OpenAIModels{
|
||||||
|
{
|
||||||
|
Id: "dall-e",
|
||||||
|
Object: "model",
|
||||||
|
Created: 1677649963,
|
||||||
|
OwnedBy: "openai",
|
||||||
|
Permission: permission,
|
||||||
|
Root: "dall-e",
|
||||||
|
Parent: nil,
|
||||||
|
},
|
||||||
{
|
{
|
||||||
Id: "gpt-3.5-turbo",
|
Id: "gpt-3.5-turbo",
|
||||||
Object: "model",
|
Object: "model",
|
||||||
@@ -242,6 +252,69 @@ func init() {
|
|||||||
Root: "code-davinci-edit-001",
|
Root: "code-davinci-edit-001",
|
||||||
Parent: nil,
|
Parent: nil,
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
Id: "ChatGLM",
|
||||||
|
Object: "model",
|
||||||
|
Created: 1677649963,
|
||||||
|
OwnedBy: "thudm",
|
||||||
|
Permission: permission,
|
||||||
|
Root: "ChatGLM",
|
||||||
|
Parent: nil,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Id: "ChatGLM2",
|
||||||
|
Object: "model",
|
||||||
|
Created: 1677649963,
|
||||||
|
OwnedBy: "thudm",
|
||||||
|
Permission: permission,
|
||||||
|
Root: "ChatGLM2",
|
||||||
|
Parent: nil,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Id: "claude-instant-1",
|
||||||
|
Object: "model",
|
||||||
|
Created: 1677649963,
|
||||||
|
OwnedBy: "anturopic",
|
||||||
|
Permission: permission,
|
||||||
|
Root: "claude-instant-1",
|
||||||
|
Parent: nil,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Id: "claude-2",
|
||||||
|
Object: "model",
|
||||||
|
Created: 1677649963,
|
||||||
|
OwnedBy: "anturopic",
|
||||||
|
Permission: permission,
|
||||||
|
Root: "claude-2",
|
||||||
|
Parent: nil,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Id: "ERNIE-Bot",
|
||||||
|
Object: "model",
|
||||||
|
Created: 1677649963,
|
||||||
|
OwnedBy: "baidu",
|
||||||
|
Permission: permission,
|
||||||
|
Root: "ERNIE-Bot",
|
||||||
|
Parent: nil,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Id: "ERNIE-Bot-turbo",
|
||||||
|
Object: "model",
|
||||||
|
Created: 1677649963,
|
||||||
|
OwnedBy: "baidu",
|
||||||
|
Permission: permission,
|
||||||
|
Root: "ERNIE-Bot-turbo",
|
||||||
|
Parent: nil,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Id: "PaLM-2",
|
||||||
|
Object: "model",
|
||||||
|
Created: 1677649963,
|
||||||
|
OwnedBy: "google",
|
||||||
|
Permission: permission,
|
||||||
|
Root: "PaLM-2",
|
||||||
|
Parent: nil,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
openAIModelsMap = make(map[string]OpenAIModels)
|
openAIModelsMap = make(map[string]OpenAIModels)
|
||||||
for _, model := range openAIModels {
|
for _, model := range openAIModels {
|
||||||
|
|||||||
203
controller/relay-baidu.go
Normal file
203
controller/relay-baidu.go
Normal file
@@ -0,0 +1,203 @@
|
|||||||
|
package controller
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bufio"
|
||||||
|
"encoding/json"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"one-api/common"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
// https://cloud.baidu.com/doc/WENXINWORKSHOP/s/flfmc9do2
|
||||||
|
|
||||||
|
type BaiduTokenResponse struct {
|
||||||
|
RefreshToken string `json:"refresh_token"`
|
||||||
|
ExpiresIn int `json:"expires_in"`
|
||||||
|
SessionKey string `json:"session_key"`
|
||||||
|
AccessToken string `json:"access_token"`
|
||||||
|
Scope string `json:"scope"`
|
||||||
|
SessionSecret string `json:"session_secret"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type BaiduMessage struct {
|
||||||
|
Role string `json:"role"`
|
||||||
|
Content string `json:"content"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type BaiduChatRequest struct {
|
||||||
|
Messages []BaiduMessage `json:"messages"`
|
||||||
|
Stream bool `json:"stream"`
|
||||||
|
UserId string `json:"user_id,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type BaiduError struct {
|
||||||
|
ErrorCode int `json:"error_code"`
|
||||||
|
ErrorMsg string `json:"error_msg"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type BaiduChatResponse struct {
|
||||||
|
Id string `json:"id"`
|
||||||
|
Object string `json:"object"`
|
||||||
|
Created int64 `json:"created"`
|
||||||
|
Result string `json:"result"`
|
||||||
|
IsTruncated bool `json:"is_truncated"`
|
||||||
|
NeedClearHistory bool `json:"need_clear_history"`
|
||||||
|
Usage Usage `json:"usage"`
|
||||||
|
BaiduError
|
||||||
|
}
|
||||||
|
|
||||||
|
type BaiduChatStreamResponse struct {
|
||||||
|
BaiduChatResponse
|
||||||
|
SentenceId int `json:"sentence_id"`
|
||||||
|
IsEnd bool `json:"is_end"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func requestOpenAI2Baidu(request GeneralOpenAIRequest) *BaiduChatRequest {
|
||||||
|
messages := make([]BaiduMessage, 0, len(request.Messages))
|
||||||
|
for _, message := range request.Messages {
|
||||||
|
messages = append(messages, BaiduMessage{
|
||||||
|
Role: message.Role,
|
||||||
|
Content: message.Content,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
return &BaiduChatRequest{
|
||||||
|
Messages: messages,
|
||||||
|
Stream: request.Stream,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func responseBaidu2OpenAI(response *BaiduChatResponse) *OpenAITextResponse {
|
||||||
|
choice := OpenAITextResponseChoice{
|
||||||
|
Index: 0,
|
||||||
|
Message: Message{
|
||||||
|
Role: "assistant",
|
||||||
|
Content: response.Result,
|
||||||
|
},
|
||||||
|
FinishReason: "stop",
|
||||||
|
}
|
||||||
|
fullTextResponse := OpenAITextResponse{
|
||||||
|
Id: response.Id,
|
||||||
|
Object: "chat.completion",
|
||||||
|
Created: response.Created,
|
||||||
|
Choices: []OpenAITextResponseChoice{choice},
|
||||||
|
Usage: response.Usage,
|
||||||
|
}
|
||||||
|
return &fullTextResponse
|
||||||
|
}
|
||||||
|
|
||||||
|
func streamResponseBaidu2OpenAI(baiduResponse *BaiduChatStreamResponse) *ChatCompletionsStreamResponse {
|
||||||
|
var choice ChatCompletionsStreamResponseChoice
|
||||||
|
choice.Delta.Content = baiduResponse.Result
|
||||||
|
choice.FinishReason = "stop"
|
||||||
|
response := ChatCompletionsStreamResponse{
|
||||||
|
Id: baiduResponse.Id,
|
||||||
|
Object: "chat.completion.chunk",
|
||||||
|
Created: baiduResponse.Created,
|
||||||
|
Model: "ernie-bot",
|
||||||
|
Choices: []ChatCompletionsStreamResponseChoice{choice},
|
||||||
|
}
|
||||||
|
return &response
|
||||||
|
}
|
||||||
|
|
||||||
|
func baiduStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) {
|
||||||
|
var usage Usage
|
||||||
|
scanner := bufio.NewScanner(resp.Body)
|
||||||
|
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
|
||||||
|
if atEOF && len(data) == 0 {
|
||||||
|
return 0, nil, nil
|
||||||
|
}
|
||||||
|
if i := strings.Index(string(data), "\n"); i >= 0 {
|
||||||
|
return i + 1, data[0:i], nil
|
||||||
|
}
|
||||||
|
if atEOF {
|
||||||
|
return len(data), data, nil
|
||||||
|
}
|
||||||
|
return 0, nil, nil
|
||||||
|
})
|
||||||
|
dataChan := make(chan string)
|
||||||
|
stopChan := make(chan bool)
|
||||||
|
go func() {
|
||||||
|
for scanner.Scan() {
|
||||||
|
data := scanner.Text()
|
||||||
|
if len(data) < 6 { // ignore blank line or wrong format
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
data = data[6:]
|
||||||
|
dataChan <- data
|
||||||
|
}
|
||||||
|
stopChan <- true
|
||||||
|
}()
|
||||||
|
c.Writer.Header().Set("Content-Type", "text/event-stream")
|
||||||
|
c.Writer.Header().Set("Cache-Control", "no-cache")
|
||||||
|
c.Writer.Header().Set("Connection", "keep-alive")
|
||||||
|
c.Writer.Header().Set("Transfer-Encoding", "chunked")
|
||||||
|
c.Writer.Header().Set("X-Accel-Buffering", "no")
|
||||||
|
c.Stream(func(w io.Writer) bool {
|
||||||
|
select {
|
||||||
|
case data := <-dataChan:
|
||||||
|
var baiduResponse BaiduChatStreamResponse
|
||||||
|
err := json.Unmarshal([]byte(data), &baiduResponse)
|
||||||
|
if err != nil {
|
||||||
|
common.SysError("error unmarshalling stream response: " + err.Error())
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
usage.PromptTokens += baiduResponse.Usage.PromptTokens
|
||||||
|
usage.CompletionTokens += baiduResponse.Usage.CompletionTokens
|
||||||
|
usage.TotalTokens += baiduResponse.Usage.TotalTokens
|
||||||
|
response := streamResponseBaidu2OpenAI(&baiduResponse)
|
||||||
|
jsonResponse, err := json.Marshal(response)
|
||||||
|
if err != nil {
|
||||||
|
common.SysError("error marshalling stream response: " + err.Error())
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
|
||||||
|
return true
|
||||||
|
case <-stopChan:
|
||||||
|
c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
})
|
||||||
|
err := resp.Body.Close()
|
||||||
|
if err != nil {
|
||||||
|
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
||||||
|
}
|
||||||
|
return nil, &usage
|
||||||
|
}
|
||||||
|
|
||||||
|
func baiduHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) {
|
||||||
|
var baiduResponse BaiduChatResponse
|
||||||
|
responseBody, err := io.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
||||||
|
}
|
||||||
|
err = resp.Body.Close()
|
||||||
|
if err != nil {
|
||||||
|
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
||||||
|
}
|
||||||
|
err = json.Unmarshal(responseBody, &baiduResponse)
|
||||||
|
if err != nil {
|
||||||
|
return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
||||||
|
}
|
||||||
|
if baiduResponse.ErrorMsg != "" {
|
||||||
|
return &OpenAIErrorWithStatusCode{
|
||||||
|
OpenAIError: OpenAIError{
|
||||||
|
Message: baiduResponse.ErrorMsg,
|
||||||
|
Type: "baidu_error",
|
||||||
|
Param: "",
|
||||||
|
Code: baiduResponse.ErrorCode,
|
||||||
|
},
|
||||||
|
StatusCode: resp.StatusCode,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
fullTextResponse := responseBaidu2OpenAI(&baiduResponse)
|
||||||
|
jsonResponse, err := json.Marshal(fullTextResponse)
|
||||||
|
if err != nil {
|
||||||
|
return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
|
||||||
|
}
|
||||||
|
c.Writer.Header().Set("Content-Type", "application/json")
|
||||||
|
c.Writer.WriteHeader(resp.StatusCode)
|
||||||
|
_, err = c.Writer.Write(jsonResponse)
|
||||||
|
return nil, &fullTextResponse.Usage
|
||||||
|
}
|
||||||
221
controller/relay-claude.go
Normal file
221
controller/relay-claude.go
Normal file
@@ -0,0 +1,221 @@
|
|||||||
|
package controller
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bufio"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"one-api/common"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
type ClaudeMetadata struct {
|
||||||
|
UserId string `json:"user_id"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type ClaudeRequest struct {
|
||||||
|
Model string `json:"model"`
|
||||||
|
Prompt string `json:"prompt"`
|
||||||
|
MaxTokensToSample int `json:"max_tokens_to_sample"`
|
||||||
|
StopSequences []string `json:"stop_sequences,omitempty"`
|
||||||
|
Temperature float64 `json:"temperature,omitempty"`
|
||||||
|
TopP float64 `json:"top_p,omitempty"`
|
||||||
|
TopK int `json:"top_k,omitempty"`
|
||||||
|
//ClaudeMetadata `json:"metadata,omitempty"`
|
||||||
|
Stream bool `json:"stream,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type ClaudeError struct {
|
||||||
|
Type string `json:"type"`
|
||||||
|
Message string `json:"message"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type ClaudeResponse struct {
|
||||||
|
Completion string `json:"completion"`
|
||||||
|
StopReason string `json:"stop_reason"`
|
||||||
|
Model string `json:"model"`
|
||||||
|
Error ClaudeError `json:"error"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func stopReasonClaude2OpenAI(reason string) string {
|
||||||
|
switch reason {
|
||||||
|
case "stop_sequence":
|
||||||
|
return "stop"
|
||||||
|
case "max_tokens":
|
||||||
|
return "length"
|
||||||
|
default:
|
||||||
|
return reason
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func requestOpenAI2Claude(textRequest GeneralOpenAIRequest) *ClaudeRequest {
|
||||||
|
claudeRequest := ClaudeRequest{
|
||||||
|
Model: textRequest.Model,
|
||||||
|
Prompt: "",
|
||||||
|
MaxTokensToSample: textRequest.MaxTokens,
|
||||||
|
StopSequences: nil,
|
||||||
|
Temperature: textRequest.Temperature,
|
||||||
|
TopP: textRequest.TopP,
|
||||||
|
Stream: textRequest.Stream,
|
||||||
|
}
|
||||||
|
if claudeRequest.MaxTokensToSample == 0 {
|
||||||
|
claudeRequest.MaxTokensToSample = 1000000
|
||||||
|
}
|
||||||
|
prompt := ""
|
||||||
|
for _, message := range textRequest.Messages {
|
||||||
|
if message.Role == "user" {
|
||||||
|
prompt += fmt.Sprintf("\n\nHuman: %s", message.Content)
|
||||||
|
} else if message.Role == "assistant" {
|
||||||
|
prompt += fmt.Sprintf("\n\nAssistant: %s", message.Content)
|
||||||
|
} else {
|
||||||
|
// ignore other roles
|
||||||
|
}
|
||||||
|
prompt += "\n\nAssistant:"
|
||||||
|
}
|
||||||
|
claudeRequest.Prompt = prompt
|
||||||
|
return &claudeRequest
|
||||||
|
}
|
||||||
|
|
||||||
|
func streamResponseClaude2OpenAI(claudeResponse *ClaudeResponse) *ChatCompletionsStreamResponse {
|
||||||
|
var choice ChatCompletionsStreamResponseChoice
|
||||||
|
choice.Delta.Content = claudeResponse.Completion
|
||||||
|
choice.FinishReason = stopReasonClaude2OpenAI(claudeResponse.StopReason)
|
||||||
|
var response ChatCompletionsStreamResponse
|
||||||
|
response.Object = "chat.completion.chunk"
|
||||||
|
response.Model = claudeResponse.Model
|
||||||
|
response.Choices = []ChatCompletionsStreamResponseChoice{choice}
|
||||||
|
return &response
|
||||||
|
}
|
||||||
|
|
||||||
|
func responseClaude2OpenAI(claudeResponse *ClaudeResponse) *OpenAITextResponse {
|
||||||
|
choice := OpenAITextResponseChoice{
|
||||||
|
Index: 0,
|
||||||
|
Message: Message{
|
||||||
|
Role: "assistant",
|
||||||
|
Content: strings.TrimPrefix(claudeResponse.Completion, " "),
|
||||||
|
Name: nil,
|
||||||
|
},
|
||||||
|
FinishReason: stopReasonClaude2OpenAI(claudeResponse.StopReason),
|
||||||
|
}
|
||||||
|
fullTextResponse := OpenAITextResponse{
|
||||||
|
Id: fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
|
||||||
|
Object: "chat.completion",
|
||||||
|
Created: common.GetTimestamp(),
|
||||||
|
Choices: []OpenAITextResponseChoice{choice},
|
||||||
|
}
|
||||||
|
return &fullTextResponse
|
||||||
|
}
|
||||||
|
|
||||||
|
func claudeStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, string) {
|
||||||
|
responseText := ""
|
||||||
|
responseId := fmt.Sprintf("chatcmpl-%s", common.GetUUID())
|
||||||
|
createdTime := common.GetTimestamp()
|
||||||
|
scanner := bufio.NewScanner(resp.Body)
|
||||||
|
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
|
||||||
|
if atEOF && len(data) == 0 {
|
||||||
|
return 0, nil, nil
|
||||||
|
}
|
||||||
|
if i := strings.Index(string(data), "\r\n\r\n"); i >= 0 {
|
||||||
|
return i + 4, data[0:i], nil
|
||||||
|
}
|
||||||
|
if atEOF {
|
||||||
|
return len(data), data, nil
|
||||||
|
}
|
||||||
|
return 0, nil, nil
|
||||||
|
})
|
||||||
|
dataChan := make(chan string)
|
||||||
|
stopChan := make(chan bool)
|
||||||
|
go func() {
|
||||||
|
for scanner.Scan() {
|
||||||
|
data := scanner.Text()
|
||||||
|
if !strings.HasPrefix(data, "event: completion") {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
data = strings.TrimPrefix(data, "event: completion\r\ndata: ")
|
||||||
|
dataChan <- data
|
||||||
|
}
|
||||||
|
stopChan <- true
|
||||||
|
}()
|
||||||
|
c.Writer.Header().Set("Content-Type", "text/event-stream")
|
||||||
|
c.Writer.Header().Set("Cache-Control", "no-cache")
|
||||||
|
c.Writer.Header().Set("Connection", "keep-alive")
|
||||||
|
c.Writer.Header().Set("Transfer-Encoding", "chunked")
|
||||||
|
c.Writer.Header().Set("X-Accel-Buffering", "no")
|
||||||
|
c.Stream(func(w io.Writer) bool {
|
||||||
|
select {
|
||||||
|
case data := <-dataChan:
|
||||||
|
// some implementations may add \r at the end of data
|
||||||
|
data = strings.TrimSuffix(data, "\r")
|
||||||
|
var claudeResponse ClaudeResponse
|
||||||
|
err := json.Unmarshal([]byte(data), &claudeResponse)
|
||||||
|
if err != nil {
|
||||||
|
common.SysError("error unmarshalling stream response: " + err.Error())
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
responseText += claudeResponse.Completion
|
||||||
|
response := streamResponseClaude2OpenAI(&claudeResponse)
|
||||||
|
response.Id = responseId
|
||||||
|
response.Created = createdTime
|
||||||
|
jsonStr, err := json.Marshal(response)
|
||||||
|
if err != nil {
|
||||||
|
common.SysError("error marshalling stream response: " + err.Error())
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonStr)})
|
||||||
|
return true
|
||||||
|
case <-stopChan:
|
||||||
|
c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
})
|
||||||
|
err := resp.Body.Close()
|
||||||
|
if err != nil {
|
||||||
|
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), ""
|
||||||
|
}
|
||||||
|
return nil, responseText
|
||||||
|
}
|
||||||
|
|
||||||
|
func claudeHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*OpenAIErrorWithStatusCode, *Usage) {
|
||||||
|
responseBody, err := io.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
||||||
|
}
|
||||||
|
err = resp.Body.Close()
|
||||||
|
if err != nil {
|
||||||
|
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
||||||
|
}
|
||||||
|
var claudeResponse ClaudeResponse
|
||||||
|
err = json.Unmarshal(responseBody, &claudeResponse)
|
||||||
|
if err != nil {
|
||||||
|
return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
||||||
|
}
|
||||||
|
if claudeResponse.Error.Type != "" {
|
||||||
|
return &OpenAIErrorWithStatusCode{
|
||||||
|
OpenAIError: OpenAIError{
|
||||||
|
Message: claudeResponse.Error.Message,
|
||||||
|
Type: claudeResponse.Error.Type,
|
||||||
|
Param: "",
|
||||||
|
Code: claudeResponse.Error.Type,
|
||||||
|
},
|
||||||
|
StatusCode: resp.StatusCode,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
fullTextResponse := responseClaude2OpenAI(&claudeResponse)
|
||||||
|
completionTokens := countTokenText(claudeResponse.Completion, model)
|
||||||
|
usage := Usage{
|
||||||
|
PromptTokens: promptTokens,
|
||||||
|
CompletionTokens: completionTokens,
|
||||||
|
TotalTokens: promptTokens + completionTokens,
|
||||||
|
}
|
||||||
|
fullTextResponse.Usage = usage
|
||||||
|
jsonResponse, err := json.Marshal(fullTextResponse)
|
||||||
|
if err != nil {
|
||||||
|
return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
|
||||||
|
}
|
||||||
|
c.Writer.Header().Set("Content-Type", "application/json")
|
||||||
|
c.Writer.WriteHeader(resp.StatusCode)
|
||||||
|
_, err = c.Writer.Write(jsonResponse)
|
||||||
|
return nil, &usage
|
||||||
|
}
|
||||||
@@ -1,34 +1,181 @@
|
|||||||
package controller
|
package controller
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"github.com/gin-gonic/gin"
|
"bytes"
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"one-api/common"
|
||||||
|
"one-api/model"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
)
|
)
|
||||||
|
|
||||||
func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
|
func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
|
||||||
// TODO: this part is not finished
|
imageModel := "dall-e"
|
||||||
req, err := http.NewRequest(c.Request.Method, c.Request.RequestURI, c.Request.Body)
|
|
||||||
|
tokenId := c.GetInt("token_id")
|
||||||
|
channelType := c.GetInt("channel")
|
||||||
|
userId := c.GetInt("id")
|
||||||
|
consumeQuota := c.GetBool("consume_quota")
|
||||||
|
group := c.GetString("group")
|
||||||
|
|
||||||
|
var imageRequest ImageRequest
|
||||||
|
if consumeQuota {
|
||||||
|
err := common.UnmarshalBodyReusable(c, &imageRequest)
|
||||||
|
if err != nil {
|
||||||
|
return errorWrapper(err, "bind_request_body_failed", http.StatusBadRequest)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Prompt validation
|
||||||
|
if imageRequest.Prompt == "" {
|
||||||
|
return errorWrapper(errors.New("prompt is required"), "required_field_missing", http.StatusBadRequest)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Not "256x256", "512x512", or "1024x1024"
|
||||||
|
if imageRequest.Size != "" && imageRequest.Size != "256x256" && imageRequest.Size != "512x512" && imageRequest.Size != "1024x1024" {
|
||||||
|
return errorWrapper(errors.New("size must be one of 256x256, 512x512, or 1024x1024"), "invalid_field_value", http.StatusBadRequest)
|
||||||
|
}
|
||||||
|
|
||||||
|
// N should between 1 and 10
|
||||||
|
if imageRequest.N != 0 && (imageRequest.N < 1 || imageRequest.N > 10) {
|
||||||
|
return errorWrapper(errors.New("n must be between 1 and 10"), "invalid_field_value", http.StatusBadRequest)
|
||||||
|
}
|
||||||
|
|
||||||
|
// map model name
|
||||||
|
modelMapping := c.GetString("model_mapping")
|
||||||
|
isModelMapped := false
|
||||||
|
if modelMapping != "" {
|
||||||
|
modelMap := make(map[string]string)
|
||||||
|
err := json.Unmarshal([]byte(modelMapping), &modelMap)
|
||||||
|
if err != nil {
|
||||||
|
return errorWrapper(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
if modelMap[imageModel] != "" {
|
||||||
|
imageModel = modelMap[imageModel]
|
||||||
|
isModelMapped = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
baseURL := common.ChannelBaseURLs[channelType]
|
||||||
|
requestURL := c.Request.URL.String()
|
||||||
|
|
||||||
|
if c.GetString("base_url") != "" {
|
||||||
|
baseURL = c.GetString("base_url")
|
||||||
|
}
|
||||||
|
|
||||||
|
fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL)
|
||||||
|
|
||||||
|
var requestBody io.Reader
|
||||||
|
if isModelMapped {
|
||||||
|
jsonStr, err := json.Marshal(imageRequest)
|
||||||
|
if err != nil {
|
||||||
|
return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
requestBody = bytes.NewBuffer(jsonStr)
|
||||||
|
} else {
|
||||||
|
requestBody = c.Request.Body
|
||||||
|
}
|
||||||
|
|
||||||
|
modelRatio := common.GetModelRatio(imageModel)
|
||||||
|
groupRatio := common.GetGroupRatio(group)
|
||||||
|
ratio := modelRatio * groupRatio
|
||||||
|
userQuota, err := model.CacheGetUserQuota(userId)
|
||||||
|
|
||||||
|
sizeRatio := 1.0
|
||||||
|
// Size
|
||||||
|
if imageRequest.Size == "256x256" {
|
||||||
|
sizeRatio = 1
|
||||||
|
} else if imageRequest.Size == "512x512" {
|
||||||
|
sizeRatio = 1.125
|
||||||
|
} else if imageRequest.Size == "1024x1024" {
|
||||||
|
sizeRatio = 1.25
|
||||||
|
}
|
||||||
|
quota := int(ratio*sizeRatio*1000) * imageRequest.N
|
||||||
|
|
||||||
|
if consumeQuota && userQuota-quota < 0 {
|
||||||
|
return errorWrapper(err, "insufficient_user_quota", http.StatusForbidden)
|
||||||
|
}
|
||||||
|
|
||||||
|
req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
|
||||||
|
if err != nil {
|
||||||
|
return errorWrapper(err, "new_request_failed", http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
req.Header.Set("Authorization", c.Request.Header.Get("Authorization"))
|
||||||
|
|
||||||
|
req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
|
||||||
|
req.Header.Set("Accept", c.Request.Header.Get("Accept"))
|
||||||
|
|
||||||
client := &http.Client{}
|
client := &http.Client{}
|
||||||
resp, err := client.Do(req)
|
resp, err := client.Do(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errorWrapper(err, "do_request_failed", http.StatusOK)
|
return errorWrapper(err, "do_request_failed", http.StatusInternalServerError)
|
||||||
}
|
}
|
||||||
|
|
||||||
err = req.Body.Close()
|
err = req.Body.Close()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errorWrapper(err, "close_request_body_failed", http.StatusOK)
|
return errorWrapper(err, "close_request_body_failed", http.StatusInternalServerError)
|
||||||
}
|
}
|
||||||
|
err = c.Request.Body.Close()
|
||||||
|
if err != nil {
|
||||||
|
return errorWrapper(err, "close_request_body_failed", http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
var textResponse ImageResponse
|
||||||
|
|
||||||
|
defer func() {
|
||||||
|
if consumeQuota {
|
||||||
|
err := model.PostConsumeTokenQuota(tokenId, quota)
|
||||||
|
if err != nil {
|
||||||
|
common.SysError("error consuming token remain quota: " + err.Error())
|
||||||
|
}
|
||||||
|
err = model.CacheUpdateUserQuota(userId)
|
||||||
|
if err != nil {
|
||||||
|
common.SysError("error update user quota cache: " + err.Error())
|
||||||
|
}
|
||||||
|
if quota != 0 {
|
||||||
|
tokenName := c.GetString("token_name")
|
||||||
|
logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
|
||||||
|
model.RecordConsumeLog(userId, 0, 0, imageModel, tokenName, quota, logContent)
|
||||||
|
model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
|
||||||
|
channelId := c.GetInt("channel_id")
|
||||||
|
model.UpdateChannelUsedQuota(channelId, quota)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
if consumeQuota {
|
||||||
|
responseBody, err := io.ReadAll(resp.Body)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
err = resp.Body.Close()
|
||||||
|
if err != nil {
|
||||||
|
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
err = json.Unmarshal(responseBody, &textResponse)
|
||||||
|
if err != nil {
|
||||||
|
return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
|
||||||
|
resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
|
||||||
|
}
|
||||||
|
|
||||||
for k, v := range resp.Header {
|
for k, v := range resp.Header {
|
||||||
c.Writer.Header().Set(k, v[0])
|
c.Writer.Header().Set(k, v[0])
|
||||||
}
|
}
|
||||||
c.Writer.WriteHeader(resp.StatusCode)
|
c.Writer.WriteHeader(resp.StatusCode)
|
||||||
|
|
||||||
_, err = io.Copy(c.Writer, resp.Body)
|
_, err = io.Copy(c.Writer, resp.Body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errorWrapper(err, "copy_response_body_failed", http.StatusOK)
|
return errorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError)
|
||||||
}
|
}
|
||||||
err = resp.Body.Close()
|
err = resp.Body.Close()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errorWrapper(err, "close_response_body_failed", http.StatusOK)
|
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
133
controller/relay-openai.go
Normal file
133
controller/relay-openai.go
Normal file
@@ -0,0 +1,133 @@
|
|||||||
|
package controller
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bufio"
|
||||||
|
"bytes"
|
||||||
|
"encoding/json"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"one-api/common"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
func openaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*OpenAIErrorWithStatusCode, string) {
|
||||||
|
responseText := ""
|
||||||
|
scanner := bufio.NewScanner(resp.Body)
|
||||||
|
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
|
||||||
|
if atEOF && len(data) == 0 {
|
||||||
|
return 0, nil, nil
|
||||||
|
}
|
||||||
|
if i := strings.Index(string(data), "\n"); i >= 0 {
|
||||||
|
return i + 1, data[0:i], nil
|
||||||
|
}
|
||||||
|
if atEOF {
|
||||||
|
return len(data), data, nil
|
||||||
|
}
|
||||||
|
return 0, nil, nil
|
||||||
|
})
|
||||||
|
dataChan := make(chan string)
|
||||||
|
stopChan := make(chan bool)
|
||||||
|
go func() {
|
||||||
|
for scanner.Scan() {
|
||||||
|
data := scanner.Text()
|
||||||
|
if len(data) < 6 { // ignore blank line or wrong format
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
dataChan <- data
|
||||||
|
data = data[6:]
|
||||||
|
if !strings.HasPrefix(data, "[DONE]") {
|
||||||
|
switch relayMode {
|
||||||
|
case RelayModeChatCompletions:
|
||||||
|
var streamResponse ChatCompletionsStreamResponse
|
||||||
|
err := json.Unmarshal([]byte(data), &streamResponse)
|
||||||
|
if err != nil {
|
||||||
|
common.SysError("error unmarshalling stream response: " + err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
for _, choice := range streamResponse.Choices {
|
||||||
|
responseText += choice.Delta.Content
|
||||||
|
}
|
||||||
|
case RelayModeCompletions:
|
||||||
|
var streamResponse CompletionsStreamResponse
|
||||||
|
err := json.Unmarshal([]byte(data), &streamResponse)
|
||||||
|
if err != nil {
|
||||||
|
common.SysError("error unmarshalling stream response: " + err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
for _, choice := range streamResponse.Choices {
|
||||||
|
responseText += choice.Text
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
stopChan <- true
|
||||||
|
}()
|
||||||
|
c.Writer.Header().Set("Content-Type", "text/event-stream")
|
||||||
|
c.Writer.Header().Set("Cache-Control", "no-cache")
|
||||||
|
c.Writer.Header().Set("Connection", "keep-alive")
|
||||||
|
c.Writer.Header().Set("Transfer-Encoding", "chunked")
|
||||||
|
c.Writer.Header().Set("X-Accel-Buffering", "no")
|
||||||
|
c.Stream(func(w io.Writer) bool {
|
||||||
|
select {
|
||||||
|
case data := <-dataChan:
|
||||||
|
if strings.HasPrefix(data, "data: [DONE]") {
|
||||||
|
data = data[:12]
|
||||||
|
}
|
||||||
|
// some implementations may add \r at the end of data
|
||||||
|
data = strings.TrimSuffix(data, "\r")
|
||||||
|
c.Render(-1, common.CustomEvent{Data: data})
|
||||||
|
return true
|
||||||
|
case <-stopChan:
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
})
|
||||||
|
err := resp.Body.Close()
|
||||||
|
if err != nil {
|
||||||
|
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), ""
|
||||||
|
}
|
||||||
|
return nil, responseText
|
||||||
|
}
|
||||||
|
|
||||||
|
func openaiHandler(c *gin.Context, resp *http.Response, consumeQuota bool) (*OpenAIErrorWithStatusCode, *Usage) {
|
||||||
|
var textResponse TextResponse
|
||||||
|
if consumeQuota {
|
||||||
|
responseBody, err := io.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
||||||
|
}
|
||||||
|
err = resp.Body.Close()
|
||||||
|
if err != nil {
|
||||||
|
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
||||||
|
}
|
||||||
|
err = json.Unmarshal(responseBody, &textResponse)
|
||||||
|
if err != nil {
|
||||||
|
return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
||||||
|
}
|
||||||
|
if textResponse.Error.Type != "" {
|
||||||
|
return &OpenAIErrorWithStatusCode{
|
||||||
|
OpenAIError: textResponse.Error,
|
||||||
|
StatusCode: resp.StatusCode,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
// Reset response body
|
||||||
|
resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
|
||||||
|
}
|
||||||
|
// We shouldn't set the header before we parse the response body, because the parse part may fail.
|
||||||
|
// And then we will have to send an error response, but in this case, the header has already been set.
|
||||||
|
// So the client will be confused by the response.
|
||||||
|
// For example, Postman will report error, and we cannot check the response at all.
|
||||||
|
for k, v := range resp.Header {
|
||||||
|
c.Writer.Header().Set(k, v[0])
|
||||||
|
}
|
||||||
|
c.Writer.WriteHeader(resp.StatusCode)
|
||||||
|
_, err := io.Copy(c.Writer, resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
return errorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil
|
||||||
|
}
|
||||||
|
err = resp.Body.Close()
|
||||||
|
if err != nil {
|
||||||
|
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
||||||
|
}
|
||||||
|
return nil, &textResponse.Usage
|
||||||
|
}
|
||||||
@@ -1,10 +1,17 @@
|
|||||||
package controller
|
package controller
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"one-api/common"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// https://developers.generativeai.google/api/rest/generativelanguage/models/generateMessage#request-body
|
||||||
|
// https://developers.generativeai.google/api/rest/generativelanguage/models/generateMessage#response-body
|
||||||
|
|
||||||
type PaLMChatMessage struct {
|
type PaLMChatMessage struct {
|
||||||
Author string `json:"author"`
|
Author string `json:"author"`
|
||||||
Content string `json:"content"`
|
Content string `json:"content"`
|
||||||
@@ -15,45 +22,188 @@ type PaLMFilter struct {
|
|||||||
Message string `json:"message"`
|
Message string `json:"message"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// https://developers.generativeai.google/api/rest/generativelanguage/models/generateMessage#request-body
|
type PaLMPrompt struct {
|
||||||
|
Messages []PaLMChatMessage `json:"messages"`
|
||||||
|
}
|
||||||
|
|
||||||
type PaLMChatRequest struct {
|
type PaLMChatRequest struct {
|
||||||
Prompt []Message `json:"prompt"`
|
Prompt PaLMPrompt `json:"prompt"`
|
||||||
Temperature float64 `json:"temperature"`
|
Temperature float64 `json:"temperature,omitempty"`
|
||||||
CandidateCount int `json:"candidateCount"`
|
CandidateCount int `json:"candidateCount,omitempty"`
|
||||||
TopP float64 `json:"topP"`
|
TopP float64 `json:"topP,omitempty"`
|
||||||
TopK int `json:"topK"`
|
TopK int `json:"topK,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type PaLMError struct {
|
||||||
|
Code int `json:"code"`
|
||||||
|
Message string `json:"message"`
|
||||||
|
Status string `json:"status"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// https://developers.generativeai.google/api/rest/generativelanguage/models/generateMessage#response-body
|
|
||||||
type PaLMChatResponse struct {
|
type PaLMChatResponse struct {
|
||||||
Candidates []Message `json:"candidates"`
|
Candidates []PaLMChatMessage `json:"candidates"`
|
||||||
Messages []Message `json:"messages"`
|
Messages []Message `json:"messages"`
|
||||||
Filters []PaLMFilter `json:"filters"`
|
Filters []PaLMFilter `json:"filters"`
|
||||||
|
Error PaLMError `json:"error"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func relayPaLM(openAIRequest GeneralOpenAIRequest, c *gin.Context) *OpenAIErrorWithStatusCode {
|
func requestOpenAI2PaLM(textRequest GeneralOpenAIRequest) *PaLMChatRequest {
|
||||||
// https://developers.generativeai.google/api/rest/generativelanguage/models/generateMessage
|
palmRequest := PaLMChatRequest{
|
||||||
messages := make([]PaLMChatMessage, 0, len(openAIRequest.Messages))
|
Prompt: PaLMPrompt{
|
||||||
for _, message := range openAIRequest.Messages {
|
Messages: make([]PaLMChatMessage, 0, len(textRequest.Messages)),
|
||||||
var author string
|
},
|
||||||
if message.Role == "user" {
|
Temperature: textRequest.Temperature,
|
||||||
author = "0"
|
CandidateCount: textRequest.N,
|
||||||
} else {
|
TopP: textRequest.TopP,
|
||||||
author = "1"
|
TopK: textRequest.MaxTokens,
|
||||||
}
|
}
|
||||||
messages = append(messages, PaLMChatMessage{
|
for _, message := range textRequest.Messages {
|
||||||
Author: author,
|
palmMessage := PaLMChatMessage{
|
||||||
Content: message.Content,
|
Content: message.Content,
|
||||||
})
|
}
|
||||||
|
if message.Role == "user" {
|
||||||
|
palmMessage.Author = "0"
|
||||||
|
} else {
|
||||||
|
palmMessage.Author = "1"
|
||||||
|
}
|
||||||
|
palmRequest.Prompt.Messages = append(palmRequest.Prompt.Messages, palmMessage)
|
||||||
}
|
}
|
||||||
request := PaLMChatRequest{
|
return &palmRequest
|
||||||
Prompt: nil,
|
}
|
||||||
Temperature: openAIRequest.Temperature,
|
|
||||||
CandidateCount: openAIRequest.N,
|
func responsePaLM2OpenAI(response *PaLMChatResponse) *OpenAITextResponse {
|
||||||
TopP: openAIRequest.TopP,
|
fullTextResponse := OpenAITextResponse{
|
||||||
TopK: openAIRequest.MaxTokens,
|
Choices: make([]OpenAITextResponseChoice, 0, len(response.Candidates)),
|
||||||
}
|
}
|
||||||
// TODO: forward request to PaLM & convert response
|
for i, candidate := range response.Candidates {
|
||||||
fmt.Print(request)
|
choice := OpenAITextResponseChoice{
|
||||||
return nil
|
Index: i,
|
||||||
|
Message: Message{
|
||||||
|
Role: "assistant",
|
||||||
|
Content: candidate.Content,
|
||||||
|
},
|
||||||
|
FinishReason: "stop",
|
||||||
|
}
|
||||||
|
fullTextResponse.Choices = append(fullTextResponse.Choices, choice)
|
||||||
|
}
|
||||||
|
return &fullTextResponse
|
||||||
|
}
|
||||||
|
|
||||||
|
func streamResponsePaLM2OpenAI(palmResponse *PaLMChatResponse) *ChatCompletionsStreamResponse {
|
||||||
|
var choice ChatCompletionsStreamResponseChoice
|
||||||
|
if len(palmResponse.Candidates) > 0 {
|
||||||
|
choice.Delta.Content = palmResponse.Candidates[0].Content
|
||||||
|
}
|
||||||
|
choice.FinishReason = "stop"
|
||||||
|
var response ChatCompletionsStreamResponse
|
||||||
|
response.Object = "chat.completion.chunk"
|
||||||
|
response.Model = "palm2"
|
||||||
|
response.Choices = []ChatCompletionsStreamResponseChoice{choice}
|
||||||
|
return &response
|
||||||
|
}
|
||||||
|
|
||||||
|
func palmStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, string) {
|
||||||
|
responseText := ""
|
||||||
|
responseId := fmt.Sprintf("chatcmpl-%s", common.GetUUID())
|
||||||
|
createdTime := common.GetTimestamp()
|
||||||
|
dataChan := make(chan string)
|
||||||
|
stopChan := make(chan bool)
|
||||||
|
go func() {
|
||||||
|
responseBody, err := io.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
common.SysError("error reading stream response: " + err.Error())
|
||||||
|
stopChan <- true
|
||||||
|
return
|
||||||
|
}
|
||||||
|
err = resp.Body.Close()
|
||||||
|
if err != nil {
|
||||||
|
common.SysError("error closing stream response: " + err.Error())
|
||||||
|
stopChan <- true
|
||||||
|
return
|
||||||
|
}
|
||||||
|
var palmResponse PaLMChatResponse
|
||||||
|
err = json.Unmarshal(responseBody, &palmResponse)
|
||||||
|
if err != nil {
|
||||||
|
common.SysError("error unmarshalling stream response: " + err.Error())
|
||||||
|
stopChan <- true
|
||||||
|
return
|
||||||
|
}
|
||||||
|
fullTextResponse := streamResponsePaLM2OpenAI(&palmResponse)
|
||||||
|
fullTextResponse.Id = responseId
|
||||||
|
fullTextResponse.Created = createdTime
|
||||||
|
if len(palmResponse.Candidates) > 0 {
|
||||||
|
responseText = palmResponse.Candidates[0].Content
|
||||||
|
}
|
||||||
|
jsonResponse, err := json.Marshal(fullTextResponse)
|
||||||
|
if err != nil {
|
||||||
|
common.SysError("error marshalling stream response: " + err.Error())
|
||||||
|
stopChan <- true
|
||||||
|
return
|
||||||
|
}
|
||||||
|
dataChan <- string(jsonResponse)
|
||||||
|
stopChan <- true
|
||||||
|
}()
|
||||||
|
c.Writer.Header().Set("Content-Type", "text/event-stream")
|
||||||
|
c.Writer.Header().Set("Cache-Control", "no-cache")
|
||||||
|
c.Writer.Header().Set("Connection", "keep-alive")
|
||||||
|
c.Writer.Header().Set("Transfer-Encoding", "chunked")
|
||||||
|
c.Writer.Header().Set("X-Accel-Buffering", "no")
|
||||||
|
c.Stream(func(w io.Writer) bool {
|
||||||
|
select {
|
||||||
|
case data := <-dataChan:
|
||||||
|
c.Render(-1, common.CustomEvent{Data: "data: " + data})
|
||||||
|
return true
|
||||||
|
case <-stopChan:
|
||||||
|
c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
})
|
||||||
|
err := resp.Body.Close()
|
||||||
|
if err != nil {
|
||||||
|
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), ""
|
||||||
|
}
|
||||||
|
return nil, responseText
|
||||||
|
}
|
||||||
|
|
||||||
|
func palmHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*OpenAIErrorWithStatusCode, *Usage) {
|
||||||
|
responseBody, err := io.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
||||||
|
}
|
||||||
|
err = resp.Body.Close()
|
||||||
|
if err != nil {
|
||||||
|
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
||||||
|
}
|
||||||
|
var palmResponse PaLMChatResponse
|
||||||
|
err = json.Unmarshal(responseBody, &palmResponse)
|
||||||
|
if err != nil {
|
||||||
|
return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
||||||
|
}
|
||||||
|
if palmResponse.Error.Code != 0 || len(palmResponse.Candidates) == 0 {
|
||||||
|
return &OpenAIErrorWithStatusCode{
|
||||||
|
OpenAIError: OpenAIError{
|
||||||
|
Message: palmResponse.Error.Message,
|
||||||
|
Type: palmResponse.Error.Status,
|
||||||
|
Param: "",
|
||||||
|
Code: palmResponse.Error.Code,
|
||||||
|
},
|
||||||
|
StatusCode: resp.StatusCode,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
fullTextResponse := responsePaLM2OpenAI(&palmResponse)
|
||||||
|
completionTokens := countTokenText(palmResponse.Candidates[0].Content, model)
|
||||||
|
usage := Usage{
|
||||||
|
PromptTokens: promptTokens,
|
||||||
|
CompletionTokens: completionTokens,
|
||||||
|
TotalTokens: promptTokens + completionTokens,
|
||||||
|
}
|
||||||
|
fullTextResponse.Usage = usage
|
||||||
|
jsonResponse, err := json.Marshal(fullTextResponse)
|
||||||
|
if err != nil {
|
||||||
|
return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
|
||||||
|
}
|
||||||
|
c.Writer.Header().Set("Content-Type", "application/json")
|
||||||
|
c.Writer.WriteHeader(resp.StatusCode)
|
||||||
|
_, err = c.Writer.Write(jsonResponse)
|
||||||
|
return nil, &usage
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,17 +1,24 @@
|
|||||||
package controller
|
package controller
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bufio"
|
|
||||||
"bytes"
|
"bytes"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/gin-gonic/gin"
|
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
"one-api/model"
|
"one-api/model"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
APITypeOpenAI = iota
|
||||||
|
APITypeClaude
|
||||||
|
APITypePaLM
|
||||||
|
APITypeBaidu
|
||||||
)
|
)
|
||||||
|
|
||||||
func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
|
func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
|
||||||
@@ -30,6 +37,9 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
|
|||||||
if relayMode == RelayModeModerations && textRequest.Model == "" {
|
if relayMode == RelayModeModerations && textRequest.Model == "" {
|
||||||
textRequest.Model = "text-moderation-latest"
|
textRequest.Model = "text-moderation-latest"
|
||||||
}
|
}
|
||||||
|
if relayMode == RelayModeEmbeddings && textRequest.Model == "" {
|
||||||
|
textRequest.Model = c.Param("model")
|
||||||
|
}
|
||||||
// request validation
|
// request validation
|
||||||
if textRequest.Model == "" {
|
if textRequest.Model == "" {
|
||||||
return errorWrapper(errors.New("model is required"), "required_field_missing", http.StatusBadRequest)
|
return errorWrapper(errors.New("model is required"), "required_field_missing", http.StatusBadRequest)
|
||||||
@@ -67,33 +77,63 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
|
|||||||
isModelMapped = true
|
isModelMapped = true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
apiType := APITypeOpenAI
|
||||||
|
if strings.HasPrefix(textRequest.Model, "claude") {
|
||||||
|
apiType = APITypeClaude
|
||||||
|
} else if strings.HasPrefix(textRequest.Model, "ERNIE") {
|
||||||
|
apiType = APITypeBaidu
|
||||||
|
} else if strings.HasPrefix(textRequest.Model, "PaLM") {
|
||||||
|
apiType = APITypePaLM
|
||||||
|
}
|
||||||
baseURL := common.ChannelBaseURLs[channelType]
|
baseURL := common.ChannelBaseURLs[channelType]
|
||||||
requestURL := c.Request.URL.String()
|
requestURL := c.Request.URL.String()
|
||||||
if c.GetString("base_url") != "" {
|
if c.GetString("base_url") != "" {
|
||||||
baseURL = c.GetString("base_url")
|
baseURL = c.GetString("base_url")
|
||||||
}
|
}
|
||||||
fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL)
|
fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL)
|
||||||
if channelType == common.ChannelTypeAzure {
|
switch apiType {
|
||||||
// https://learn.microsoft.com/en-us/azure/cognitive-services/openai/chatgpt-quickstart?pivots=rest-api&tabs=command-line#rest-api
|
case APITypeOpenAI:
|
||||||
query := c.Request.URL.Query()
|
if channelType == common.ChannelTypeAzure {
|
||||||
apiVersion := query.Get("api-version")
|
// https://learn.microsoft.com/en-us/azure/cognitive-services/openai/chatgpt-quickstart?pivots=rest-api&tabs=command-line#rest-api
|
||||||
if apiVersion == "" {
|
query := c.Request.URL.Query()
|
||||||
apiVersion = c.GetString("api_version")
|
apiVersion := query.Get("api-version")
|
||||||
|
if apiVersion == "" {
|
||||||
|
apiVersion = c.GetString("api_version")
|
||||||
|
}
|
||||||
|
requestURL := strings.Split(requestURL, "?")[0]
|
||||||
|
requestURL = fmt.Sprintf("%s?api-version=%s", requestURL, apiVersion)
|
||||||
|
baseURL = c.GetString("base_url")
|
||||||
|
task := strings.TrimPrefix(requestURL, "/v1/")
|
||||||
|
model_ := textRequest.Model
|
||||||
|
model_ = strings.Replace(model_, ".", "", -1)
|
||||||
|
// https://github.com/songquanpeng/one-api/issues/67
|
||||||
|
model_ = strings.TrimSuffix(model_, "-0301")
|
||||||
|
model_ = strings.TrimSuffix(model_, "-0314")
|
||||||
|
model_ = strings.TrimSuffix(model_, "-0613")
|
||||||
|
fullRequestURL = fmt.Sprintf("%s/openai/deployments/%s/%s", baseURL, model_, task)
|
||||||
}
|
}
|
||||||
requestURL := strings.Split(requestURL, "?")[0]
|
case APITypeClaude:
|
||||||
requestURL = fmt.Sprintf("%s?api-version=%s", requestURL, apiVersion)
|
fullRequestURL = "https://api.anthropic.com/v1/complete"
|
||||||
baseURL = c.GetString("base_url")
|
if baseURL != "" {
|
||||||
task := strings.TrimPrefix(requestURL, "/v1/")
|
fullRequestURL = fmt.Sprintf("%s/v1/complete", baseURL)
|
||||||
model_ := textRequest.Model
|
}
|
||||||
model_ = strings.Replace(model_, ".", "", -1)
|
case APITypeBaidu:
|
||||||
// https://github.com/songquanpeng/one-api/issues/67
|
switch textRequest.Model {
|
||||||
model_ = strings.TrimSuffix(model_, "-0301")
|
case "ERNIE-Bot":
|
||||||
model_ = strings.TrimSuffix(model_, "-0314")
|
fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions"
|
||||||
model_ = strings.TrimSuffix(model_, "-0613")
|
case "ERNIE-Bot-turbo":
|
||||||
fullRequestURL = fmt.Sprintf("%s/openai/deployments/%s/%s", baseURL, model_, task)
|
fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/eb-instant"
|
||||||
} else if channelType == common.ChannelTypePaLM {
|
case "BLOOMZ-7B":
|
||||||
err := relayPaLM(textRequest, c)
|
fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/bloomz_7b1"
|
||||||
return err
|
}
|
||||||
|
apiKey := c.Request.Header.Get("Authorization")
|
||||||
|
apiKey = strings.TrimPrefix(apiKey, "Bearer ")
|
||||||
|
fullRequestURL += "?access_token=" + apiKey // TODO: access token expire in 30 days
|
||||||
|
case APITypePaLM:
|
||||||
|
fullRequestURL = "https://generativelanguage.googleapis.com/v1beta2/models/chat-bison-001:generateMessage"
|
||||||
|
apiKey := c.Request.Header.Get("Authorization")
|
||||||
|
apiKey = strings.TrimPrefix(apiKey, "Bearer ")
|
||||||
|
fullRequestURL += "?key=" + apiKey
|
||||||
}
|
}
|
||||||
var promptTokens int
|
var promptTokens int
|
||||||
var completionTokens int
|
var completionTokens int
|
||||||
@@ -138,16 +178,49 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
|
|||||||
} else {
|
} else {
|
||||||
requestBody = c.Request.Body
|
requestBody = c.Request.Body
|
||||||
}
|
}
|
||||||
|
switch apiType {
|
||||||
|
case APITypeClaude:
|
||||||
|
claudeRequest := requestOpenAI2Claude(textRequest)
|
||||||
|
jsonStr, err := json.Marshal(claudeRequest)
|
||||||
|
if err != nil {
|
||||||
|
return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
requestBody = bytes.NewBuffer(jsonStr)
|
||||||
|
case APITypeBaidu:
|
||||||
|
baiduRequest := requestOpenAI2Baidu(textRequest)
|
||||||
|
jsonStr, err := json.Marshal(baiduRequest)
|
||||||
|
if err != nil {
|
||||||
|
return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
requestBody = bytes.NewBuffer(jsonStr)
|
||||||
|
case APITypePaLM:
|
||||||
|
palmRequest := requestOpenAI2PaLM(textRequest)
|
||||||
|
jsonStr, err := json.Marshal(palmRequest)
|
||||||
|
if err != nil {
|
||||||
|
return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
requestBody = bytes.NewBuffer(jsonStr)
|
||||||
|
}
|
||||||
req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
|
req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errorWrapper(err, "new_request_failed", http.StatusInternalServerError)
|
return errorWrapper(err, "new_request_failed", http.StatusInternalServerError)
|
||||||
}
|
}
|
||||||
if channelType == common.ChannelTypeAzure {
|
apiKey := c.Request.Header.Get("Authorization")
|
||||||
key := c.Request.Header.Get("Authorization")
|
apiKey = strings.TrimPrefix(apiKey, "Bearer ")
|
||||||
key = strings.TrimPrefix(key, "Bearer ")
|
switch apiType {
|
||||||
req.Header.Set("api-key", key)
|
case APITypeOpenAI:
|
||||||
} else {
|
if channelType == common.ChannelTypeAzure {
|
||||||
req.Header.Set("Authorization", c.Request.Header.Get("Authorization"))
|
req.Header.Set("api-key", apiKey)
|
||||||
|
} else {
|
||||||
|
req.Header.Set("Authorization", c.Request.Header.Get("Authorization"))
|
||||||
|
}
|
||||||
|
case APITypeClaude:
|
||||||
|
req.Header.Set("x-api-key", apiKey)
|
||||||
|
anthropicVersion := c.Request.Header.Get("anthropic-version")
|
||||||
|
if anthropicVersion == "" {
|
||||||
|
anthropicVersion = "2023-06-01"
|
||||||
|
}
|
||||||
|
req.Header.Set("anthropic-version", anthropicVersion)
|
||||||
}
|
}
|
||||||
req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
|
req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
|
||||||
req.Header.Set("Accept", c.Request.Header.Get("Accept"))
|
req.Header.Set("Accept", c.Request.Header.Get("Accept"))
|
||||||
@@ -179,7 +252,7 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
|
|||||||
if strings.HasPrefix(textRequest.Model, "gpt-4") {
|
if strings.HasPrefix(textRequest.Model, "gpt-4") {
|
||||||
completionRatio = 2
|
completionRatio = 2
|
||||||
}
|
}
|
||||||
if isStream {
|
if isStream && apiType != APITypeBaidu {
|
||||||
completionTokens = countTokenText(streamResponseText, textRequest.Model)
|
completionTokens = countTokenText(streamResponseText, textRequest.Model)
|
||||||
} else {
|
} else {
|
||||||
promptTokens = textResponse.Usage.PromptTokens
|
promptTokens = textResponse.Usage.PromptTokens
|
||||||
@@ -215,123 +288,72 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
switch apiType {
|
||||||
if isStream {
|
case APITypeOpenAI:
|
||||||
scanner := bufio.NewScanner(resp.Body)
|
if isStream {
|
||||||
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
|
err, responseText := openaiStreamHandler(c, resp, relayMode)
|
||||||
if atEOF && len(data) == 0 {
|
|
||||||
return 0, nil, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
if i := strings.Index(string(data), "\n\n"); i >= 0 {
|
|
||||||
return i + 2, data[0:i], nil
|
|
||||||
}
|
|
||||||
|
|
||||||
if atEOF {
|
|
||||||
return len(data), data, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
return 0, nil, nil
|
|
||||||
})
|
|
||||||
dataChan := make(chan string)
|
|
||||||
stopChan := make(chan bool)
|
|
||||||
go func() {
|
|
||||||
for scanner.Scan() {
|
|
||||||
data := scanner.Text()
|
|
||||||
if len(data) < 6 { // must be something wrong!
|
|
||||||
common.SysError("invalid stream response: " + data)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
dataChan <- data
|
|
||||||
data = data[6:]
|
|
||||||
if !strings.HasPrefix(data, "[DONE]") {
|
|
||||||
switch relayMode {
|
|
||||||
case RelayModeChatCompletions:
|
|
||||||
var streamResponse ChatCompletionsStreamResponse
|
|
||||||
err = json.Unmarshal([]byte(data), &streamResponse)
|
|
||||||
if err != nil {
|
|
||||||
common.SysError("error unmarshalling stream response: " + err.Error())
|
|
||||||
return
|
|
||||||
}
|
|
||||||
for _, choice := range streamResponse.Choices {
|
|
||||||
streamResponseText += choice.Delta.Content
|
|
||||||
}
|
|
||||||
case RelayModeCompletions:
|
|
||||||
var streamResponse CompletionsStreamResponse
|
|
||||||
err = json.Unmarshal([]byte(data), &streamResponse)
|
|
||||||
if err != nil {
|
|
||||||
common.SysError("error unmarshalling stream response: " + err.Error())
|
|
||||||
return
|
|
||||||
}
|
|
||||||
for _, choice := range streamResponse.Choices {
|
|
||||||
streamResponseText += choice.Text
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
stopChan <- true
|
|
||||||
}()
|
|
||||||
c.Writer.Header().Set("Content-Type", "text/event-stream")
|
|
||||||
c.Writer.Header().Set("Cache-Control", "no-cache")
|
|
||||||
c.Writer.Header().Set("Connection", "keep-alive")
|
|
||||||
c.Writer.Header().Set("Transfer-Encoding", "chunked")
|
|
||||||
c.Writer.Header().Set("X-Accel-Buffering", "no")
|
|
||||||
c.Stream(func(w io.Writer) bool {
|
|
||||||
select {
|
|
||||||
case data := <-dataChan:
|
|
||||||
if strings.HasPrefix(data, "data: [DONE]") {
|
|
||||||
data = data[:12]
|
|
||||||
}
|
|
||||||
c.Render(-1, common.CustomEvent{Data: data})
|
|
||||||
return true
|
|
||||||
case <-stopChan:
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
})
|
|
||||||
err = resp.Body.Close()
|
|
||||||
if err != nil {
|
|
||||||
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
} else {
|
|
||||||
if consumeQuota {
|
|
||||||
responseBody, err := io.ReadAll(resp.Body)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
|
return err
|
||||||
}
|
}
|
||||||
err = resp.Body.Close()
|
streamResponseText = responseText
|
||||||
|
return nil
|
||||||
|
} else {
|
||||||
|
err, usage := openaiHandler(c, resp, consumeQuota)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError)
|
return err
|
||||||
}
|
}
|
||||||
err = json.Unmarshal(responseBody, &textResponse)
|
textResponse.Usage = *usage
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
case APITypeClaude:
|
||||||
|
if isStream {
|
||||||
|
err, responseText := claudeStreamHandler(c, resp)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError)
|
return err
|
||||||
}
|
}
|
||||||
if textResponse.Error.Type != "" {
|
streamResponseText = responseText
|
||||||
return &OpenAIErrorWithStatusCode{
|
return nil
|
||||||
OpenAIError: textResponse.Error,
|
} else {
|
||||||
StatusCode: resp.StatusCode,
|
err, usage := claudeHandler(c, resp, promptTokens, textRequest.Model)
|
||||||
}
|
if err != nil {
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
// Reset response body
|
textResponse.Usage = *usage
|
||||||
resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
|
return nil
|
||||||
}
|
}
|
||||||
// We shouldn't set the header before we parse the response body, because the parse part may fail.
|
case APITypeBaidu:
|
||||||
// And then we will have to send an error response, but in this case, the header has already been set.
|
if isStream {
|
||||||
// So the client will be confused by the response.
|
err, usage := baiduStreamHandler(c, resp)
|
||||||
// For example, Postman will report error, and we cannot check the response at all.
|
if err != nil {
|
||||||
for k, v := range resp.Header {
|
return err
|
||||||
c.Writer.Header().Set(k, v[0])
|
}
|
||||||
|
textResponse.Usage = *usage
|
||||||
|
return nil
|
||||||
|
} else {
|
||||||
|
err, usage := baiduHandler(c, resp)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
textResponse.Usage = *usage
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
c.Writer.WriteHeader(resp.StatusCode)
|
case APITypePaLM:
|
||||||
_, err = io.Copy(c.Writer, resp.Body)
|
if textRequest.Stream { // PaLM2 API does not support stream
|
||||||
if err != nil {
|
err, responseText := palmStreamHandler(c, resp)
|
||||||
return errorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError)
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
streamResponseText = responseText
|
||||||
|
return nil
|
||||||
|
} else {
|
||||||
|
err, usage := palmHandler(c, resp, promptTokens, textRequest.Model)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
textResponse.Usage = *usage
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
err = resp.Body.Close()
|
default:
|
||||||
if err != nil {
|
return errorWrapper(errors.New("unknown api type"), "unknown_api_type", http.StatusInternalServerError)
|
||||||
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -91,3 +91,16 @@ func errorWrapper(err error, code string, statusCode int) *OpenAIErrorWithStatus
|
|||||||
StatusCode: statusCode,
|
StatusCode: statusCode,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func shouldDisableChannel(err *OpenAIError) bool {
|
||||||
|
if !common.AutomaticDisableChannelEnabled {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if err == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if err.Type == "insufficient_quota" || err.Code == "invalid_api_key" || err.Code == "account_deactivated" {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|||||||
@@ -2,10 +2,12 @@ package controller
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/gin-gonic/gin"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Message struct {
|
type Message struct {
|
||||||
@@ -37,6 +39,7 @@ type GeneralOpenAIRequest struct {
|
|||||||
N int `json:"n,omitempty"`
|
N int `json:"n,omitempty"`
|
||||||
Input any `json:"input,omitempty"`
|
Input any `json:"input,omitempty"`
|
||||||
Instruction string `json:"instruction,omitempty"`
|
Instruction string `json:"instruction,omitempty"`
|
||||||
|
Size string `json:"size,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type ChatRequest struct {
|
type ChatRequest struct {
|
||||||
@@ -53,6 +56,12 @@ type TextRequest struct {
|
|||||||
//Stream bool `json:"stream"`
|
//Stream bool `json:"stream"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type ImageRequest struct {
|
||||||
|
Prompt string `json:"prompt"`
|
||||||
|
N int `json:"n"`
|
||||||
|
Size string `json:"size"`
|
||||||
|
}
|
||||||
|
|
||||||
type Usage struct {
|
type Usage struct {
|
||||||
PromptTokens int `json:"prompt_tokens"`
|
PromptTokens int `json:"prompt_tokens"`
|
||||||
CompletionTokens int `json:"completion_tokens"`
|
CompletionTokens int `json:"completion_tokens"`
|
||||||
@@ -76,13 +85,40 @@ type TextResponse struct {
|
|||||||
Error OpenAIError `json:"error"`
|
Error OpenAIError `json:"error"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type OpenAITextResponseChoice struct {
|
||||||
|
Index int `json:"index"`
|
||||||
|
Message `json:"message"`
|
||||||
|
FinishReason string `json:"finish_reason"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type OpenAITextResponse struct {
|
||||||
|
Id string `json:"id"`
|
||||||
|
Object string `json:"object"`
|
||||||
|
Created int64 `json:"created"`
|
||||||
|
Choices []OpenAITextResponseChoice `json:"choices"`
|
||||||
|
Usage `json:"usage"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type ImageResponse struct {
|
||||||
|
Created int `json:"created"`
|
||||||
|
Data []struct {
|
||||||
|
Url string `json:"url"`
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type ChatCompletionsStreamResponseChoice struct {
|
||||||
|
Delta struct {
|
||||||
|
Content string `json:"content"`
|
||||||
|
} `json:"delta"`
|
||||||
|
FinishReason string `json:"finish_reason,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
type ChatCompletionsStreamResponse struct {
|
type ChatCompletionsStreamResponse struct {
|
||||||
Choices []struct {
|
Id string `json:"id"`
|
||||||
Delta struct {
|
Object string `json:"object"`
|
||||||
Content string `json:"content"`
|
Created int64 `json:"created"`
|
||||||
} `json:"delta"`
|
Model string `json:"model"`
|
||||||
FinishReason string `json:"finish_reason"`
|
Choices []ChatCompletionsStreamResponseChoice `json:"choices"`
|
||||||
} `json:"choices"`
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type CompletionsStreamResponse struct {
|
type CompletionsStreamResponse struct {
|
||||||
@@ -100,6 +136,8 @@ func Relay(c *gin.Context) {
|
|||||||
relayMode = RelayModeCompletions
|
relayMode = RelayModeCompletions
|
||||||
} else if strings.HasPrefix(c.Request.URL.Path, "/v1/embeddings") {
|
} else if strings.HasPrefix(c.Request.URL.Path, "/v1/embeddings") {
|
||||||
relayMode = RelayModeEmbeddings
|
relayMode = RelayModeEmbeddings
|
||||||
|
} else if strings.HasSuffix(c.Request.URL.Path, "embeddings") {
|
||||||
|
relayMode = RelayModeEmbeddings
|
||||||
} else if strings.HasPrefix(c.Request.URL.Path, "/v1/moderations") {
|
} else if strings.HasPrefix(c.Request.URL.Path, "/v1/moderations") {
|
||||||
relayMode = RelayModeModerations
|
relayMode = RelayModeModerations
|
||||||
} else if strings.HasPrefix(c.Request.URL.Path, "/v1/images/generations") {
|
} else if strings.HasPrefix(c.Request.URL.Path, "/v1/images/generations") {
|
||||||
@@ -115,16 +153,25 @@ func Relay(c *gin.Context) {
|
|||||||
err = relayTextHelper(c, relayMode)
|
err = relayTextHelper(c, relayMode)
|
||||||
}
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if err.StatusCode == http.StatusTooManyRequests {
|
retryTimesStr := c.Query("retry")
|
||||||
err.OpenAIError.Message = "当前分组负载已饱和,请稍后再试,或升级账户以提升服务质量。"
|
retryTimes, _ := strconv.Atoi(retryTimesStr)
|
||||||
|
if retryTimesStr == "" {
|
||||||
|
retryTimes = common.RetryTimes
|
||||||
|
}
|
||||||
|
if retryTimes > 0 {
|
||||||
|
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s?retry=%d", c.Request.URL.Path, retryTimes-1))
|
||||||
|
} else {
|
||||||
|
if err.StatusCode == http.StatusTooManyRequests {
|
||||||
|
err.OpenAIError.Message = "当前分组负载已饱和,请稍后再试,或升级账户以提升服务质量。"
|
||||||
|
}
|
||||||
|
c.JSON(err.StatusCode, gin.H{
|
||||||
|
"error": err.OpenAIError,
|
||||||
|
})
|
||||||
}
|
}
|
||||||
c.JSON(err.StatusCode, gin.H{
|
|
||||||
"error": err.OpenAIError,
|
|
||||||
})
|
|
||||||
channelId := c.GetInt("channel_id")
|
channelId := c.GetInt("channel_id")
|
||||||
common.SysError(fmt.Sprintf("relay error (channel #%d): %s", channelId, err.Message))
|
common.SysError(fmt.Sprintf("relay error (channel #%d): %s", channelId, err.Message))
|
||||||
// https://platform.openai.com/docs/guides/error-codes/api-errors
|
// https://platform.openai.com/docs/guides/error-codes/api-errors
|
||||||
if common.AutomaticDisableChannelEnabled && (err.Type == "insufficient_quota" || err.Code == "invalid_api_key" || err.Code == "account_deactivated") {
|
if shouldDisableChannel(&err.OpenAIError) {
|
||||||
channelId := c.GetInt("channel_id")
|
channelId := c.GetInt("channel_id")
|
||||||
channelName := c.GetString("channel_name")
|
channelName := c.GetString("channel_name")
|
||||||
disableChannel(channelId, channelName, err.Message)
|
disableChannel(channelId, channelName, err.Message)
|
||||||
|
|||||||
@@ -2,12 +2,13 @@ package middleware
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/gin-gonic/gin"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
"one-api/model"
|
"one-api/model"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
)
|
)
|
||||||
|
|
||||||
type ModelRequest struct {
|
type ModelRequest struct {
|
||||||
@@ -73,6 +74,16 @@ func Distribute() func(c *gin.Context) {
|
|||||||
modelRequest.Model = "text-moderation-stable"
|
modelRequest.Model = "text-moderation-stable"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if strings.HasSuffix(c.Request.URL.Path, "embeddings") {
|
||||||
|
if modelRequest.Model == "" {
|
||||||
|
modelRequest.Model = c.Param("model")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if strings.HasPrefix(c.Request.URL.Path, "/v1/images/generations") {
|
||||||
|
if modelRequest.Model == "" {
|
||||||
|
modelRequest.Model = "dall-e"
|
||||||
|
}
|
||||||
|
}
|
||||||
channel, err = model.CacheGetRandomSatisfiedChannel(userGroup, modelRequest.Model)
|
channel, err = model.CacheGetRandomSatisfiedChannel(userGroup, modelRequest.Model)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
message := "无可用渠道"
|
message := "无可用渠道"
|
||||||
|
|||||||
@@ -68,6 +68,7 @@ func InitOptionMap() {
|
|||||||
common.OptionMap["TopUpLink"] = common.TopUpLink
|
common.OptionMap["TopUpLink"] = common.TopUpLink
|
||||||
common.OptionMap["ChatLink"] = common.ChatLink
|
common.OptionMap["ChatLink"] = common.ChatLink
|
||||||
common.OptionMap["QuotaPerUnit"] = strconv.FormatFloat(common.QuotaPerUnit, 'f', -1, 64)
|
common.OptionMap["QuotaPerUnit"] = strconv.FormatFloat(common.QuotaPerUnit, 'f', -1, 64)
|
||||||
|
common.OptionMap["RetryTimes"] = strconv.Itoa(common.RetryTimes)
|
||||||
common.OptionMapRWMutex.Unlock()
|
common.OptionMapRWMutex.Unlock()
|
||||||
loadOptionsFromDatabase()
|
loadOptionsFromDatabase()
|
||||||
}
|
}
|
||||||
@@ -196,6 +197,8 @@ func updateOptionMap(key string, value string) (err error) {
|
|||||||
common.QuotaRemindThreshold, _ = strconv.Atoi(value)
|
common.QuotaRemindThreshold, _ = strconv.Atoi(value)
|
||||||
case "PreConsumedQuota":
|
case "PreConsumedQuota":
|
||||||
common.PreConsumedQuota, _ = strconv.Atoi(value)
|
common.PreConsumedQuota, _ = strconv.Atoi(value)
|
||||||
|
case "RetryTimes":
|
||||||
|
common.RetryTimes, _ = strconv.Atoi(value)
|
||||||
case "ModelRatio":
|
case "ModelRatio":
|
||||||
err = common.UpdateModelRatioByJSONString(value)
|
err = common.UpdateModelRatioByJSONString(value)
|
||||||
case "GroupRatio":
|
case "GroupRatio":
|
||||||
|
|||||||
@@ -1,10 +1,11 @@
|
|||||||
package router
|
package router
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"github.com/gin-contrib/gzip"
|
|
||||||
"github.com/gin-gonic/gin"
|
|
||||||
"one-api/controller"
|
"one-api/controller"
|
||||||
"one-api/middleware"
|
"one-api/middleware"
|
||||||
|
|
||||||
|
"github.com/gin-contrib/gzip"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
)
|
)
|
||||||
|
|
||||||
func SetApiRouter(router *gin.Engine) {
|
func SetApiRouter(router *gin.Engine) {
|
||||||
|
|||||||
@@ -1,9 +1,10 @@
|
|||||||
package router
|
package router
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"github.com/gin-gonic/gin"
|
|
||||||
"one-api/controller"
|
"one-api/controller"
|
||||||
"one-api/middleware"
|
"one-api/middleware"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
)
|
)
|
||||||
|
|
||||||
func SetRelayRouter(router *gin.Engine) {
|
func SetRelayRouter(router *gin.Engine) {
|
||||||
@@ -20,10 +21,11 @@ func SetRelayRouter(router *gin.Engine) {
|
|||||||
relayV1Router.POST("/completions", controller.Relay)
|
relayV1Router.POST("/completions", controller.Relay)
|
||||||
relayV1Router.POST("/chat/completions", controller.Relay)
|
relayV1Router.POST("/chat/completions", controller.Relay)
|
||||||
relayV1Router.POST("/edits", controller.Relay)
|
relayV1Router.POST("/edits", controller.Relay)
|
||||||
relayV1Router.POST("/images/generations", controller.RelayNotImplemented)
|
relayV1Router.POST("/images/generations", controller.Relay)
|
||||||
relayV1Router.POST("/images/edits", controller.RelayNotImplemented)
|
relayV1Router.POST("/images/edits", controller.RelayNotImplemented)
|
||||||
relayV1Router.POST("/images/variations", controller.RelayNotImplemented)
|
relayV1Router.POST("/images/variations", controller.RelayNotImplemented)
|
||||||
relayV1Router.POST("/embeddings", controller.Relay)
|
relayV1Router.POST("/embeddings", controller.Relay)
|
||||||
|
relayV1Router.POST("/engines/:model/embeddings", controller.Relay)
|
||||||
relayV1Router.POST("/audio/transcriptions", controller.RelayNotImplemented)
|
relayV1Router.POST("/audio/transcriptions", controller.RelayNotImplemented)
|
||||||
relayV1Router.POST("/audio/translations", controller.RelayNotImplemented)
|
relayV1Router.POST("/audio/translations", controller.RelayNotImplemented)
|
||||||
relayV1Router.GET("/files", controller.RelayNotImplemented)
|
relayV1Router.GET("/files", controller.RelayNotImplemented)
|
||||||
|
|||||||
@@ -12,7 +12,7 @@ import {
|
|||||||
} from 'semantic-ui-react';
|
} from 'semantic-ui-react';
|
||||||
import { Link, useNavigate, useSearchParams } from 'react-router-dom';
|
import { Link, useNavigate, useSearchParams } from 'react-router-dom';
|
||||||
import { UserContext } from '../context/User';
|
import { UserContext } from '../context/User';
|
||||||
import { API, getLogo, showError, showSuccess } from '../helpers';
|
import { API, getLogo, showError, showSuccess, showInfo } from '../helpers';
|
||||||
|
|
||||||
const LoginForm = () => {
|
const LoginForm = () => {
|
||||||
const [inputs, setInputs] = useState({
|
const [inputs, setInputs] = useState({
|
||||||
@@ -76,7 +76,7 @@ const LoginForm = () => {
|
|||||||
async function handleSubmit(e) {
|
async function handleSubmit(e) {
|
||||||
setSubmitted(true);
|
setSubmitted(true);
|
||||||
if (username && password) {
|
if (username && password) {
|
||||||
const res = await API.post('/api/user/login', {
|
const res = await API.post(`/api/user/login`, {
|
||||||
username,
|
username,
|
||||||
password,
|
password,
|
||||||
});
|
});
|
||||||
|
|||||||
@@ -20,6 +20,7 @@ const OperationSetting = () => {
|
|||||||
DisplayInCurrencyEnabled: '',
|
DisplayInCurrencyEnabled: '',
|
||||||
DisplayTokenStatEnabled: '',
|
DisplayTokenStatEnabled: '',
|
||||||
ApproximateTokenEnabled: '',
|
ApproximateTokenEnabled: '',
|
||||||
|
RetryTimes: 0,
|
||||||
});
|
});
|
||||||
const [originInputs, setOriginInputs] = useState({});
|
const [originInputs, setOriginInputs] = useState({});
|
||||||
let [loading, setLoading] = useState(false);
|
let [loading, setLoading] = useState(false);
|
||||||
@@ -122,6 +123,9 @@ const OperationSetting = () => {
|
|||||||
if (originInputs['QuotaPerUnit'] !== inputs.QuotaPerUnit) {
|
if (originInputs['QuotaPerUnit'] !== inputs.QuotaPerUnit) {
|
||||||
await updateOption('QuotaPerUnit', inputs.QuotaPerUnit);
|
await updateOption('QuotaPerUnit', inputs.QuotaPerUnit);
|
||||||
}
|
}
|
||||||
|
if (originInputs['RetryTimes'] !== inputs.RetryTimes) {
|
||||||
|
await updateOption('RetryTimes', inputs.RetryTimes);
|
||||||
|
}
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
@@ -133,7 +137,7 @@ const OperationSetting = () => {
|
|||||||
<Header as='h3'>
|
<Header as='h3'>
|
||||||
通用设置
|
通用设置
|
||||||
</Header>
|
</Header>
|
||||||
<Form.Group widths={3}>
|
<Form.Group widths={4}>
|
||||||
<Form.Input
|
<Form.Input
|
||||||
label='充值链接'
|
label='充值链接'
|
||||||
name='TopUpLink'
|
name='TopUpLink'
|
||||||
@@ -162,6 +166,17 @@ const OperationSetting = () => {
|
|||||||
step='0.01'
|
step='0.01'
|
||||||
placeholder='一单位货币能兑换的额度'
|
placeholder='一单位货币能兑换的额度'
|
||||||
/>
|
/>
|
||||||
|
<Form.Input
|
||||||
|
label='失败重试次数'
|
||||||
|
name='RetryTimes'
|
||||||
|
type={'number'}
|
||||||
|
step='1'
|
||||||
|
min='0'
|
||||||
|
onChange={handleInputChange}
|
||||||
|
autoComplete='new-password'
|
||||||
|
value={inputs.RetryTimes}
|
||||||
|
placeholder='失败重试次数'
|
||||||
|
/>
|
||||||
</Form.Group>
|
</Form.Group>
|
||||||
<Form.Group inline>
|
<Form.Group inline>
|
||||||
<Form.Checkbox
|
<Form.Checkbox
|
||||||
|
|||||||
@@ -1,7 +1,10 @@
|
|||||||
export const CHANNEL_OPTIONS = [
|
export const CHANNEL_OPTIONS = [
|
||||||
{ key: 1, text: 'OpenAI', value: 1, color: 'green' },
|
{ key: 1, text: 'OpenAI', value: 1, color: 'green' },
|
||||||
|
{ key: 14, text: 'Anthropic', value: 14, color: 'black' },
|
||||||
{ key: 8, text: '自定义', value: 8, color: 'pink' },
|
{ key: 8, text: '自定义', value: 8, color: 'pink' },
|
||||||
{ key: 3, text: 'Azure', value: 3, color: 'olive' },
|
{ key: 3, text: 'Azure', value: 3, color: 'olive' },
|
||||||
|
{ key: 11, text: 'PaLM', value: 11, color: 'orange' },
|
||||||
|
{ key: 15, text: 'Baidu', value: 15, color: 'blue' },
|
||||||
{ key: 2, text: 'API2D', value: 2, color: 'blue' },
|
{ key: 2, text: 'API2D', value: 2, color: 'blue' },
|
||||||
{ key: 4, text: 'CloseAI', value: 4, color: 'teal' },
|
{ key: 4, text: 'CloseAI', value: 4, color: 'teal' },
|
||||||
{ key: 5, text: 'OpenAI-SB', value: 5, color: 'brown' },
|
{ key: 5, text: 'OpenAI-SB', value: 5, color: 'brown' },
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
import React, { useEffect, useState } from 'react';
|
import React, { useEffect, useState } from 'react';
|
||||||
import { Button, Form, Header, Message, Segment } from 'semantic-ui-react';
|
import { Button, Form, Header, Input, Message, Segment } from 'semantic-ui-react';
|
||||||
import { useParams } from 'react-router-dom';
|
import { useParams } from 'react-router-dom';
|
||||||
import { API, showError, showInfo, showSuccess, verifyJSON } from '../../helpers';
|
import { API, showError, showInfo, showSuccess, verifyJSON } from '../../helpers';
|
||||||
import { CHANNEL_OPTIONS } from '../../constants';
|
import { CHANNEL_OPTIONS } from '../../constants';
|
||||||
@@ -27,10 +27,12 @@ const EditChannel = () => {
|
|||||||
};
|
};
|
||||||
const [batch, setBatch] = useState(false);
|
const [batch, setBatch] = useState(false);
|
||||||
const [inputs, setInputs] = useState(originInputs);
|
const [inputs, setInputs] = useState(originInputs);
|
||||||
|
const [originModelOptions, setOriginModelOptions] = useState([]);
|
||||||
const [modelOptions, setModelOptions] = useState([]);
|
const [modelOptions, setModelOptions] = useState([]);
|
||||||
const [groupOptions, setGroupOptions] = useState([]);
|
const [groupOptions, setGroupOptions] = useState([]);
|
||||||
const [basicModels, setBasicModels] = useState([]);
|
const [basicModels, setBasicModels] = useState([]);
|
||||||
const [fullModels, setFullModels] = useState([]);
|
const [fullModels, setFullModels] = useState([]);
|
||||||
|
const [customModel, setCustomModel] = useState('');
|
||||||
const handleInputChange = (e, { name, value }) => {
|
const handleInputChange = (e, { name, value }) => {
|
||||||
setInputs((inputs) => ({ ...inputs, [name]: value }));
|
setInputs((inputs) => ({ ...inputs, [name]: value }));
|
||||||
};
|
};
|
||||||
@@ -62,13 +64,16 @@ const EditChannel = () => {
|
|||||||
const fetchModels = async () => {
|
const fetchModels = async () => {
|
||||||
try {
|
try {
|
||||||
let res = await API.get(`/api/channel/models`);
|
let res = await API.get(`/api/channel/models`);
|
||||||
setModelOptions(res.data.data.map((model) => ({
|
let localModelOptions = res.data.data.map((model) => ({
|
||||||
key: model.id,
|
key: model.id,
|
||||||
text: model.id,
|
text: model.id,
|
||||||
value: model.id
|
value: model.id
|
||||||
})));
|
}));
|
||||||
|
setOriginModelOptions(localModelOptions);
|
||||||
setFullModels(res.data.data.map((model) => model.id));
|
setFullModels(res.data.data.map((model) => model.id));
|
||||||
setBasicModels(res.data.data.filter((model) => !model.id.startsWith('gpt-4')).map((model) => model.id));
|
setBasicModels(res.data.data.filter((model) => {
|
||||||
|
return model.id.startsWith('gpt-3') || model.id.startsWith('text-');
|
||||||
|
}).map((model) => model.id));
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
showError(error.message);
|
showError(error.message);
|
||||||
}
|
}
|
||||||
@@ -87,6 +92,20 @@ const EditChannel = () => {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
useEffect(() => {
|
||||||
|
let localModelOptions = [...originModelOptions];
|
||||||
|
inputs.models.forEach((model) => {
|
||||||
|
if (!localModelOptions.find((option) => option.key === model)) {
|
||||||
|
localModelOptions.push({
|
||||||
|
key: model,
|
||||||
|
text: model,
|
||||||
|
value: model
|
||||||
|
});
|
||||||
|
}
|
||||||
|
});
|
||||||
|
setModelOptions(localModelOptions);
|
||||||
|
}, [originModelOptions, inputs.models]);
|
||||||
|
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
if (isEdit) {
|
if (isEdit) {
|
||||||
loadChannel().then();
|
loadChannel().then();
|
||||||
@@ -263,6 +282,32 @@ const EditChannel = () => {
|
|||||||
<Button type={'button'} onClick={() => {
|
<Button type={'button'} onClick={() => {
|
||||||
handleInputChange(null, { name: 'models', value: [] });
|
handleInputChange(null, { name: 'models', value: [] });
|
||||||
}}>清除所有模型</Button>
|
}}>清除所有模型</Button>
|
||||||
|
<Input
|
||||||
|
action={
|
||||||
|
<Button type={'button'} onClick={()=>{
|
||||||
|
if (customModel.trim() === "") return;
|
||||||
|
if (inputs.models.includes(customModel)) return;
|
||||||
|
let localModels = [...inputs.models];
|
||||||
|
localModels.push(customModel);
|
||||||
|
let localModelOptions = [];
|
||||||
|
localModelOptions.push({
|
||||||
|
key: customModel,
|
||||||
|
text: customModel,
|
||||||
|
value: customModel,
|
||||||
|
});
|
||||||
|
setModelOptions(modelOptions=>{
|
||||||
|
return [...modelOptions, ...localModelOptions];
|
||||||
|
});
|
||||||
|
setCustomModel('');
|
||||||
|
handleInputChange(null, { name: 'models', value: localModels });
|
||||||
|
}}>填入</Button>
|
||||||
|
}
|
||||||
|
placeholder='输入自定义模型名称'
|
||||||
|
value={customModel}
|
||||||
|
onChange={(e, { value }) => {
|
||||||
|
setCustomModel(value);
|
||||||
|
}}
|
||||||
|
/>
|
||||||
</div>
|
</div>
|
||||||
<Form.Field>
|
<Form.Field>
|
||||||
<Form.TextArea
|
<Form.TextArea
|
||||||
@@ -292,7 +337,7 @@ const EditChannel = () => {
|
|||||||
label='密钥'
|
label='密钥'
|
||||||
name='key'
|
name='key'
|
||||||
required
|
required
|
||||||
placeholder={'请输入密钥'}
|
placeholder={inputs.type === 15 ? "请输入 access token,当前版本暂不支持自动刷新,请每 30 天更新一次" : '请输入密钥'}
|
||||||
onChange={handleInputChange}
|
onChange={handleInputChange}
|
||||||
value={inputs.key}
|
value={inputs.key}
|
||||||
autoComplete='new-password'
|
autoComplete='new-password'
|
||||||
@@ -309,7 +354,7 @@ const EditChannel = () => {
|
|||||||
/>
|
/>
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
<Button positive onClick={submit}>提交</Button>
|
<Button type={isEdit ? "button" : "submit"} positive onClick={submit}>提交</Button>
|
||||||
</Form>
|
</Form>
|
||||||
</Segment>
|
</Segment>
|
||||||
</>
|
</>
|
||||||
|
|||||||
Reference in New Issue
Block a user