mirror of https://github.com/linux-do/new-api.git
synced 2025-09-17 16:06:38 +08:00

feat: 初步重构 (preliminary refactoring)

This commit is contained in:
parent 9b421478c1
commit 5b18cd6b0a
@@ -230,7 +230,7 @@ func StringsContains(strs []string, str string) bool {
     return false
 }
 
-// []byte only read, panic on append
+// StringToByteSlice []byte only read, panic on append
 func StringToByteSlice(s string) []byte {
     tmp1 := (*[2]uintptr)(unsafe.Pointer(&s))
     tmp2 := [3]uintptr{tmp1[0], tmp1[1], tmp1[1]}
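The renamed comment flags a real hazard, so it is worth seeing what the function actually does: it reinterprets the string header as a slice header, so the returned []byte aliases the string's immutable backing memory. A minimal, runnable sketch (the final return statement is inferred, since the hunk cuts off after the tmp2 line):

    package main

    import (
        "fmt"
        "unsafe"
    )

    // Same shape as StringToByteSlice above; the return line is inferred.
    func StringToByteSlice(s string) []byte {
        tmp1 := (*[2]uintptr)(unsafe.Pointer(&s))     // string header: {data, len}
        tmp2 := [3]uintptr{tmp1[0], tmp1[1], tmp1[1]} // slice header: {data, len, cap}
        return *(*[]byte)(unsafe.Pointer(&tmp2))
    }

    func main() {
        b := StringToByteSlice("hello")
        fmt.Println(b[0]) // reading is fine: prints 104
        // b[0] = 'H'     // writing would fault: the bytes belong to an immutable string
    }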
@@ -8,6 +8,7 @@ import (
     "net/http"
     "one-api/common"
     "one-api/model"
+    "one-api/service"
     "strconv"
     "time"
@@ -92,7 +93,7 @@ func GetResponseBody(method, url string, channel *model.Channel, headers http.He
     for k := range headers {
         req.Header.Add(k, headers.Get(k))
     }
-    res, err := httpClient.Do(req)
+    res, err := service.GetHttpClient().Do(req)
     if err != nil {
         return nil, err
     }
@@ -310,7 +311,7 @@ func updateAllChannelsBalance() error {
         } else {
             // err is nil & balance <= 0 means quota is used up
             if balance <= 0 {
-                disableChannel(channel.Id, channel.Name, "余额不足")
+                service.DisableChannel(channel.Id, channel.Name, "余额不足")
             }
         }
         time.Sleep(common.RequestInterval)
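GetResponseBody now borrows the shared HTTP client from the service package instead of a controller-local one. The accessor itself is not part of this diff; a minimal sketch of what it plausibly looks like, reusing the RelayTimeout logic from the init() of the relay file deleted later in this commit:

    package service

    import (
        "net/http"
        "time"

        "one-api/common"
    )

    var httpClient *http.Client

    func init() {
        if common.RelayTimeout == 0 {
            httpClient = &http.Client{}
        } else {
            httpClient = &http.Client{Timeout: time.Duration(common.RelayTimeout) * time.Second}
        }
    }

    // GetHttpClient hands every caller the same client, so all relays share one
    // connection pool. (Sketch only; the real file lives in one-api/service.)
    func GetHttpClient() *http.Client {
        return httpClient
    }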
@@ -5,9 +5,17 @@ import (
     "encoding/json"
     "errors"
     "fmt"
+    "io"
     "net/http"
+    "net/http/httptest"
+    "net/url"
     "one-api/common"
+    "one-api/dto"
     "one-api/model"
+    relaychannel "one-api/relay/channel"
+    relaycommon "one-api/relay/common"
+    "one-api/relay/constant"
+    "one-api/service"
     "strconv"
     "sync"
     "time"
@@ -15,89 +23,77 @@ import (
     "github.com/gin-gonic/gin"
 )
 
-func testChannel(channel *model.Channel, request ChatRequest) (err error, openaiErr *OpenAIError) {
-    common.SysLog(fmt.Sprintf("testing channel %d with model %s", channel.Id, request.Model))
-    switch channel.Type {
-    case common.ChannelTypePaLM:
-        fallthrough
-    case common.ChannelTypeAnthropic:
-        fallthrough
-    case common.ChannelTypeBaidu:
-        fallthrough
-    case common.ChannelTypeZhipu:
-        fallthrough
-    case common.ChannelTypeAli:
-        fallthrough
-    case common.ChannelType360:
-        fallthrough
-    case common.ChannelTypeGemini:
-        fallthrough
-    case common.ChannelTypeXunfei:
-        return errors.New("该渠道类型当前版本不支持测试,请手动测试"), nil
-    case common.ChannelTypeAzure:
-        if request.Model == "" {
-            request.Model = "gpt-35-turbo"
-        }
-        defer func() {
-            if err != nil {
-                err = errors.New("请确保已在 Azure 上创建了 gpt-35-turbo 模型,并且 apiVersion 已正确填写!")
-            }
-        }()
-    default:
-        if request.Model == "" {
-            request.Model = "gpt-3.5-turbo"
-        }
+func testChannel(channel *model.Channel, testModel string) (err error, openaiErr *dto.OpenAIError) {
+    common.SysLog(fmt.Sprintf("testing channel %d with model %s", channel.Id, testModel))
+    w := httptest.NewRecorder()
+    c, _ := gin.CreateTestContext(w)
+    c.Request = &http.Request{
+        Method: "POST",
+        URL:    &url.URL{Path: "/v1/chat/completions"},
+        Body:   nil,
+        Header: make(http.Header),
     }
-    baseUrl := common.ChannelBaseURLs[channel.Type]
-    if channel.GetBaseURL() != "" {
-        baseUrl = channel.GetBaseURL()
+    c.Request.Header.Set("Authorization", "Bearer "+channel.Key)
+    c.Request.Header.Set("Content-Type", "application/json")
+    c.Set("channel", channel.Type)
+    c.Set("base_url", channel.GetBaseURL())
+    meta := relaycommon.GenRelayInfo(c)
+    apiType := constant.ChannelType2APIType(channel.Type)
+    adaptor := relaychannel.GetAdaptor(apiType)
+    if adaptor == nil {
+        return fmt.Errorf("invalid api type: %d, adaptor is nil", apiType), nil
     }
-    requestURL := getFullRequestURL(baseUrl, "/v1/chat/completions", channel.Type)
+    if testModel == "" {
+        testModel = adaptor.GetModelList()[0]
+    }
+    request := buildTestRequest()
 
-    if channel.Type == common.ChannelTypeAzure {
-        requestURL = getFullRequestURL(channel.GetBaseURL(), fmt.Sprintf("/openai/deployments/%s/chat/completions?api-version=2023-03-15-preview", request.Model), channel.Type)
-    }
-
-    jsonData, err := json.Marshal(request)
+    adaptor.Init(meta, *request)
+
+    request.Model = testModel
+    meta.UpstreamModelName = testModel
+    convertedRequest, err := adaptor.ConvertRequest(c, constant.RelayModeChatCompletions, request)
     if err != nil {
         return err, nil
     }
-    req, err := http.NewRequest("POST", requestURL, bytes.NewBuffer(jsonData))
+    jsonData, err := json.Marshal(convertedRequest)
     if err != nil {
         return err, nil
     }
-    if channel.Type == common.ChannelTypeAzure {
-        req.Header.Set("api-key", channel.Key)
-    } else {
-        req.Header.Set("Authorization", "Bearer "+channel.Key)
-    }
-    req.Header.Set("Content-Type", "application/json")
-    resp, err := httpClient.Do(req)
+    requestBody := bytes.NewBuffer(jsonData)
+    c.Request.Body = io.NopCloser(requestBody)
+    resp, err := adaptor.DoRequest(c, meta, requestBody)
     if err != nil {
         return err, nil
     }
-    defer resp.Body.Close()
-    var response TextResponse
-    err = json.NewDecoder(resp.Body).Decode(&response)
+    if resp.StatusCode != http.StatusOK {
+        err := relaycommon.RelayErrorHandler(resp)
+        return fmt.Errorf("status code %d: %s", resp.StatusCode, err.OpenAIError.Message), &err.OpenAIError
+    }
+    usage, respErr := adaptor.DoResponse(c, resp, meta)
+    if respErr != nil {
+        return fmt.Errorf("%s", respErr.OpenAIError.Message), &respErr.OpenAIError
+    }
+    if usage == nil {
+        return errors.New("usage is nil"), nil
+    }
+    result := w.Result()
+    // print result.Body
+    respBody, err := io.ReadAll(result.Body)
     if err != nil {
         return err, nil
     }
-    if response.Usage.CompletionTokens == 0 {
-        if response.Error.Message == "" {
-            response.Error.Message = "补全 tokens 非预期返回 0"
-        }
-        return errors.New(fmt.Sprintf("type %s, code %v, message %s", response.Error.Type, response.Error.Code, response.Error.Message)), &response.Error
-    }
+    common.SysLog(fmt.Sprintf("testing channel #%d, response: \n%s", channel.Id, string(respBody)))
     return nil, nil
 }
 
-func buildTestRequest() *ChatRequest {
-    testRequest := &ChatRequest{
+func buildTestRequest() *dto.GeneralOpenAIRequest {
+    testRequest := &dto.GeneralOpenAIRequest{
         Model:     "", // this will be set later
         MaxTokens: 1,
     }
     content, _ := json.Marshal("hi")
-    testMessage := Message{
+    testMessage := dto.Message{
         Role:    "user",
         Content: content,
     }
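The rewritten testChannel no longer special-cases providers: it fakes a gin request with httptest, resolves an adaptor from the channel type, and lets the adaptor handle conversion, the HTTP call, and response parsing. The adaptor contract implied by the call sites reads roughly like this (method names come from this diff; parameter and return types are approximations, the real definitions live under one-api/relay/channel):

    package relaychannel

    import (
        "io"
        "net/http"

        "github.com/gin-gonic/gin"

        "one-api/dto"
        relaycommon "one-api/relay/common"
    )

    // Adaptor is the per-provider seam the new testChannel drives.
    // Signatures are inferred from call sites, not copied from the repo.
    type Adaptor interface {
        Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest)
        GetModelList() []string
        ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error)
        DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error)
        DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.Usage, *dto.OpenAIErrorWithStatusCode)
    }

With this seam in place, supporting a new provider means adding an adaptor implementation rather than growing a switch statement in every caller.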
@@ -114,7 +110,6 @@ func TestChannel(c *gin.Context) {
         })
         return
     }
-    testModel := c.Query("model")
     channel, err := model.GetChannelById(id, true)
     if err != nil {
         c.JSON(http.StatusOK, gin.H{
@@ -123,12 +118,9 @@ func TestChannel(c *gin.Context) {
         })
         return
     }
-    testRequest := buildTestRequest()
-    if testModel != "" {
-        testRequest.Model = testModel
-    }
+    testModel := c.Query("model")
     tik := time.Now()
-    err, _ = testChannel(channel, *testRequest)
+    err, _ = testChannel(channel, testModel)
     tok := time.Now()
     milliseconds := tok.Sub(tik).Milliseconds()
     go channel.UpdateResponseTime(milliseconds)
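Net effect in TestChannel: the model query parameter is now read once, after the channel lookup, and passed straight through; buildTestRequest has moved inside testChannel, which falls back to the adaptor's first advertised model (adaptor.GetModelList()[0]) when the parameter is empty.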
@@ -152,31 +144,6 @@ func TestChannel(c *gin.Context) {
 var testAllChannelsLock sync.Mutex
 var testAllChannelsRunning bool = false
 
-// disable & notify
-func disableChannel(channelId int, channelName string, reason string) {
-    model.UpdateChannelStatusById(channelId, common.ChannelStatusAutoDisabled)
-    subject := fmt.Sprintf("通道「%s」(#%d)已被禁用", channelName, channelId)
-    content := fmt.Sprintf("通道「%s」(#%d)已被禁用,原因:%s", channelName, channelId, reason)
-    notifyRootUser(subject, content)
-}
-
-func enableChannel(channelId int, channelName string) {
-    model.UpdateChannelStatusById(channelId, common.ChannelStatusEnabled)
-    subject := fmt.Sprintf("通道「%s」(#%d)已被启用", channelName, channelId)
-    content := fmt.Sprintf("通道「%s」(#%d)已被启用", channelName, channelId)
-    notifyRootUser(subject, content)
-}
-
-func notifyRootUser(subject string, content string) {
-    if common.RootUserEmail == "" {
-        common.RootUserEmail = model.GetRootUserEmail()
-    }
-    err := common.SendEmail(subject, common.RootUserEmail, content)
-    if err != nil {
-        common.SysError(fmt.Sprintf("failed to send email: %s", err.Error()))
-    }
-}
-
 func testAllChannels(notify bool) error {
     if common.RootUserEmail == "" {
         common.RootUserEmail = model.GetRootUserEmail()
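The three helpers deleted above are relocated rather than lost: the call sites later in this commit switch to service.DisableChannel, service.EnableChannel, service.ShouldDisableChannel, and service.ShouldEnableChannel, so the disable/enable/notify logic evidently now lives in the one-api/service package.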
@@ -192,7 +159,6 @@ func testAllChannels(notify bool) error {
     if err != nil {
         return err
     }
-    testRequest := buildTestRequest()
    var disableThreshold = int64(common.ChannelDisableThreshold * 1000)
     if disableThreshold == 0 {
         disableThreshold = 10000000 // a impossible value
@@ -201,7 +167,7 @@ func testAllChannels(notify bool) error {
     for _, channel := range channels {
         isChannelEnabled := channel.Status == common.ChannelStatusEnabled
         tik := time.Now()
-        err, openaiErr := testChannel(channel, *testRequest)
+        err, openaiErr := testChannel(channel, "")
         tok := time.Now()
         milliseconds := tok.Sub(tik).Milliseconds()
 
@@ -218,11 +184,11 @@ func testAllChannels(notify bool) error {
         if channel.AutoBan != nil && *channel.AutoBan == 0 {
             ban = false
         }
-        if isChannelEnabled && shouldDisableChannel(openaiErr, -1) && ban {
-            disableChannel(channel.Id, channel.Name, err.Error())
+        if isChannelEnabled && service.ShouldDisableChannel(openaiErr, -1) && ban {
+            service.DisableChannel(channel.Id, channel.Name, err.Error())
         }
-        if !isChannelEnabled && shouldEnableChannel(err, openaiErr) {
-            enableChannel(channel.Id, channel.Name)
+        if !isChannelEnabled && service.ShouldEnableChannel(err, openaiErr) {
+            service.EnableChannel(channel.Id, channel.Name)
         }
         channel.UpdateResponseTime(milliseconds)
         time.Sleep(common.RequestInterval)
|
@ -10,7 +10,9 @@ import (
|
|||||||
"log"
|
"log"
|
||||||
"net/http"
|
"net/http"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
|
"one-api/controller/relay"
|
||||||
"one-api/model"
|
"one-api/model"
|
||||||
|
relay2 "one-api/relay"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
@@ -63,7 +65,7 @@ import (
     req = req.WithContext(ctx)
 
     req.Header.Set("Content-Type", "application/json")
-    //req.Header.Set("Authorization", "Bearer midjourney-proxy")
+    //req.Header.Set("ApiKey", "Bearer midjourney-proxy")
     req.Header.Set("mj-api-secret", midjourneyChannel.Key)
     resp, err := httpClient.Do(req)
     if err != nil {
@@ -221,7 +223,7 @@ func UpdateMidjourneyTaskBulk() {
         req = req.WithContext(ctx)
         req.Header.Set("Content-Type", "application/json")
         req.Header.Set("mj-api-secret", midjourneyChannel.Key)
-        resp, err := httpClient.Do(req)
+        resp, err := relay.httpClient.Do(req)
         if err != nil {
             common.LogError(ctx, fmt.Sprintf("Get Task Do req error: %v", err))
             continue
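Worth noting: relay.httpClient refers to an unexported identifier of the imported one-api/controller/relay package, which Go does not allow across package boundaries, so this line should not compile as written — in keeping with the commit's "preliminary refactoring" title, it presumably gets cleaned up in a follow-up commit.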
@@ -231,7 +233,7 @@ func UpdateMidjourneyTaskBulk() {
             common.LogError(ctx, fmt.Sprintf("Get Task parse body error: %v", err))
             continue
         }
-        var responseItems []Midjourney
+        var responseItems []relay2.Midjourney
         err = json.Unmarshal(responseBody, &responseItems)
         if err != nil {
             common.LogError(ctx, fmt.Sprintf("Get Task parse body error2: %v, body: %s", err, string(responseBody)))
@@ -284,7 +286,7 @@ func UpdateMidjourneyTaskBulk() {
     }
 }
 
-func checkMjTaskNeedUpdate(oldTask *model.Midjourney, newTask Midjourney) bool {
+func checkMjTaskNeedUpdate(oldTask *model.Midjourney, newTask relay2.Midjourney) bool {
     if oldTask.Code != 1 {
         return true
     }
@@ -2,7 +2,6 @@ package controller
 
 import (
     "fmt"
 
     "github.com/gin-gonic/gin"
 )
 
@@ -1,220 +0,0 @@
-package controller
-
-import (
-    "bufio"
-    "encoding/json"
-    "fmt"
-    "github.com/gin-gonic/gin"
-    "io"
-    "net/http"
-    "one-api/common"
-    "strconv"
-    "strings"
-)
-
-// https://docs.aiproxy.io/dev/library#使用已经定制好的知识库进行对话问答
-
-type AIProxyLibraryRequest struct {
-    Model     string `json:"model"`
-    Query     string `json:"query"`
-    LibraryId string `json:"libraryId"`
-    Stream    bool   `json:"stream"`
-}
-
-type AIProxyLibraryError struct {
-    ErrCode int    `json:"errCode"`
-    Message string `json:"message"`
-}
-
-type AIProxyLibraryDocument struct {
-    Title string `json:"title"`
-    URL   string `json:"url"`
-}
-
-type AIProxyLibraryResponse struct {
-    Success   bool                     `json:"success"`
-    Answer    string                   `json:"answer"`
-    Documents []AIProxyLibraryDocument `json:"documents"`
-    AIProxyLibraryError
-}
-
-type AIProxyLibraryStreamResponse struct {
-    Content   string                   `json:"content"`
-    Finish    bool                     `json:"finish"`
-    Model     string                   `json:"model"`
-    Documents []AIProxyLibraryDocument `json:"documents"`
-}
-
-func requestOpenAI2AIProxyLibrary(request GeneralOpenAIRequest) *AIProxyLibraryRequest {
-    query := ""
-    if len(request.Messages) != 0 {
-        query = string(request.Messages[len(request.Messages)-1].Content)
-    }
-    return &AIProxyLibraryRequest{
-        Model:  request.Model,
-        Stream: request.Stream,
-        Query:  query,
-    }
-}
-
-func aiProxyDocuments2Markdown(documents []AIProxyLibraryDocument) string {
-    if len(documents) == 0 {
-        return ""
-    }
-    content := "\n\n参考文档:\n"
-    for i, document := range documents {
-        content += fmt.Sprintf("%d. [%s](%s)\n", i+1, document.Title, document.URL)
-    }
-    return content
-}
-
-func responseAIProxyLibrary2OpenAI(response *AIProxyLibraryResponse) *OpenAITextResponse {
-    content, _ := json.Marshal(response.Answer + aiProxyDocuments2Markdown(response.Documents))
-    choice := OpenAITextResponseChoice{
-        Index: 0,
-        Message: Message{
-            Role:    "assistant",
-            Content: content,
-        },
-        FinishReason: "stop",
-    }
-    fullTextResponse := OpenAITextResponse{
-        Id:      common.GetUUID(),
-        Object:  "chat.completion",
-        Created: common.GetTimestamp(),
-        Choices: []OpenAITextResponseChoice{choice},
-    }
-    return &fullTextResponse
-}
-
-func documentsAIProxyLibrary(documents []AIProxyLibraryDocument) *ChatCompletionsStreamResponse {
-    var choice ChatCompletionsStreamResponseChoice
-    choice.Delta.Content = aiProxyDocuments2Markdown(documents)
-    choice.FinishReason = &stopFinishReason
-    return &ChatCompletionsStreamResponse{
-        Id:      common.GetUUID(),
-        Object:  "chat.completion.chunk",
-        Created: common.GetTimestamp(),
-        Model:   "",
-        Choices: []ChatCompletionsStreamResponseChoice{choice},
-    }
-}
-
-func streamResponseAIProxyLibrary2OpenAI(response *AIProxyLibraryStreamResponse) *ChatCompletionsStreamResponse {
-    var choice ChatCompletionsStreamResponseChoice
-    choice.Delta.Content = response.Content
-    return &ChatCompletionsStreamResponse{
-        Id:      common.GetUUID(),
-        Object:  "chat.completion.chunk",
-        Created: common.GetTimestamp(),
-        Model:   response.Model,
-        Choices: []ChatCompletionsStreamResponseChoice{choice},
-    }
-}
-
-func aiProxyLibraryStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) {
-    var usage Usage
-    scanner := bufio.NewScanner(resp.Body)
-    scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
-        if atEOF && len(data) == 0 {
-            return 0, nil, nil
-        }
-        if i := strings.Index(string(data), "\n"); i >= 0 {
-            return i + 1, data[0:i], nil
-        }
-        if atEOF {
-            return len(data), data, nil
-        }
-        return 0, nil, nil
-    })
-    dataChan := make(chan string)
-    stopChan := make(chan bool)
-    go func() {
-        for scanner.Scan() {
-            data := scanner.Text()
-            if len(data) < 5 { // ignore blank line or wrong format
-                continue
-            }
-            if data[:5] != "data:" {
-                continue
-            }
-            data = data[5:]
-            dataChan <- data
-        }
-        stopChan <- true
-    }()
-    setEventStreamHeaders(c)
-    var documents []AIProxyLibraryDocument
-    c.Stream(func(w io.Writer) bool {
-        select {
-        case data := <-dataChan:
-            var AIProxyLibraryResponse AIProxyLibraryStreamResponse
-            err := json.Unmarshal([]byte(data), &AIProxyLibraryResponse)
-            if err != nil {
-                common.SysError("error unmarshalling stream response: " + err.Error())
-                return true
-            }
-            if len(AIProxyLibraryResponse.Documents) != 0 {
-                documents = AIProxyLibraryResponse.Documents
-            }
-            response := streamResponseAIProxyLibrary2OpenAI(&AIProxyLibraryResponse)
-            jsonResponse, err := json.Marshal(response)
-            if err != nil {
-                common.SysError("error marshalling stream response: " + err.Error())
-                return true
-            }
-            c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
-            return true
-        case <-stopChan:
-            response := documentsAIProxyLibrary(documents)
-            jsonResponse, err := json.Marshal(response)
-            if err != nil {
-                common.SysError("error marshalling stream response: " + err.Error())
-                return true
-            }
-            c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
-            c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
-            return false
-        }
-    })
-    err := resp.Body.Close()
-    if err != nil {
-        return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
-    }
-    return nil, &usage
-}
-
-func aiProxyLibraryHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) {
-    var AIProxyLibraryResponse AIProxyLibraryResponse
-    responseBody, err := io.ReadAll(resp.Body)
-    if err != nil {
-        return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
-    }
-    err = resp.Body.Close()
-    if err != nil {
-        return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
-    }
-    err = json.Unmarshal(responseBody, &AIProxyLibraryResponse)
-    if err != nil {
-        return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
-    }
-    if AIProxyLibraryResponse.ErrCode != 0 {
-        return &OpenAIErrorWithStatusCode{
-            OpenAIError: OpenAIError{
-                Message: AIProxyLibraryResponse.Message,
-                Type:    strconv.Itoa(AIProxyLibraryResponse.ErrCode),
-                Code:    AIProxyLibraryResponse.ErrCode,
-            },
-            StatusCode: resp.StatusCode,
-        }, nil
-    }
-    fullTextResponse := responseAIProxyLibrary2OpenAI(&AIProxyLibraryResponse)
-    jsonResponse, err := json.Marshal(fullTextResponse)
-    if err != nil {
-        return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
-    }
-    c.Writer.Header().Set("Content-Type", "application/json")
-    c.Writer.WriteHeader(resp.StatusCode)
-    _, err = c.Writer.Write(jsonResponse)
-    return nil, &fullTextResponse.Usage
-}
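The deleted aiProxyLibraryStreamHandler above used a hand-rolled bufio.Scanner split function to cut the SSE stream into newline-terminated events before stripping the data: prefix. The pattern is generic; here is a self-contained, runnable sketch of the same technique (the sample input is made up for illustration):

    package main

    import (
        "bufio"
        "fmt"
        "strings"
    )

    func main() {
        body := strings.NewReader("data: {\"content\":\"hi\"}\ndata: [DONE]\n")
        scanner := bufio.NewScanner(body)
        // The same split function as the deleted handler: emit one token per line.
        scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
            if atEOF && len(data) == 0 {
                return 0, nil, nil
            }
            if i := strings.Index(string(data), "\n"); i >= 0 {
                return i + 1, data[0:i], nil
            }
            if atEOF {
                return len(data), data, nil
            }
            return 0, nil, nil
        })
        for scanner.Scan() {
            line := scanner.Text()
            if len(line) < 5 || line[:5] != "data:" {
                continue // skip blank or malformed lines, as the handler did
            }
            fmt.Println(strings.TrimSpace(line[5:])) // {"content":"hi"} then [DONE]
        }
    }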
@@ -1,752 +0,0 @@
-package controller
-
-import (
-    "bytes"
-    "context"
-    "encoding/json"
-    "errors"
-    "fmt"
-    "io"
-    "net/http"
-    "one-api/common"
-    "one-api/model"
-    "strings"
-    "time"
-
-    "github.com/gin-gonic/gin"
-)
-
-const (
-    APITypeOpenAI = iota
-    APITypeClaude
-    APITypePaLM
-    APITypeBaidu
-    APITypeZhipu
-    APITypeAli
-    APITypeXunfei
-    APITypeAIProxyLibrary
-    APITypeTencent
-    APITypeGemini
-)
-
-var httpClient *http.Client
-var impatientHTTPClient *http.Client
-
-func init() {
-    if common.RelayTimeout == 0 {
-        httpClient = &http.Client{}
-    } else {
-        httpClient = &http.Client{
-            Timeout: time.Duration(common.RelayTimeout) * time.Second,
-        }
-    }
-
-    impatientHTTPClient = &http.Client{
-        Timeout: 5 * time.Second,
-    }
-}
-
-func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
-    channelType := c.GetInt("channel")
-    channelId := c.GetInt("channel_id")
-    tokenId := c.GetInt("token_id")
-    userId := c.GetInt("id")
-    group := c.GetString("group")
-    tokenUnlimited := c.GetBool("token_unlimited_quota")
-    startTime := time.Now()
-    var textRequest GeneralOpenAIRequest
-
-    err := common.UnmarshalBodyReusable(c, &textRequest)
-    if err != nil {
-        return errorWrapper(err, "bind_request_body_failed", http.StatusBadRequest)
-    }
-    if relayMode == RelayModeModerations && textRequest.Model == "" {
-        textRequest.Model = "text-moderation-latest"
-    }
-    if relayMode == RelayModeEmbeddings && textRequest.Model == "" {
-        textRequest.Model = c.Param("model")
-    }
-    // request validation
-    if textRequest.Model == "" {
-        return errorWrapper(errors.New("model is required"), "required_field_missing", http.StatusBadRequest)
-    }
-    switch relayMode {
-    case RelayModeCompletions:
-        if textRequest.Prompt == "" {
-            return errorWrapper(errors.New("field prompt is required"), "required_field_missing", http.StatusBadRequest)
-        }
-    case RelayModeChatCompletions:
-        if textRequest.Messages == nil || len(textRequest.Messages) == 0 {
-            return errorWrapper(errors.New("field messages is required"), "required_field_missing", http.StatusBadRequest)
-        }
-    case RelayModeEmbeddings:
-    case RelayModeModerations:
-        if textRequest.Input == "" {
-            return errorWrapper(errors.New("field input is required"), "required_field_missing", http.StatusBadRequest)
-        }
-    case RelayModeEdits:
-        if textRequest.Instruction == "" {
-            return errorWrapper(errors.New("field instruction is required"), "required_field_missing", http.StatusBadRequest)
-        }
-    }
-    // map model name
-    modelMapping := c.GetString("model_mapping")
-    isModelMapped := false
-    if modelMapping != "" && modelMapping != "{}" {
-        modelMap := make(map[string]string)
-        err := json.Unmarshal([]byte(modelMapping), &modelMap)
-        if err != nil {
-            return errorWrapper(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError)
-        }
-        if modelMap[textRequest.Model] != "" {
-            textRequest.Model = modelMap[textRequest.Model]
-            isModelMapped = true
-        }
-    }
-    apiType := APITypeOpenAI
-    switch channelType {
-    case common.ChannelTypeAnthropic:
-        apiType = APITypeClaude
-    case common.ChannelTypeBaidu:
-        apiType = APITypeBaidu
-    case common.ChannelTypePaLM:
-        apiType = APITypePaLM
-    case common.ChannelTypeZhipu:
-        apiType = APITypeZhipu
-    case common.ChannelTypeAli:
-        apiType = APITypeAli
-    case common.ChannelTypeXunfei:
-        apiType = APITypeXunfei
-    case common.ChannelTypeAIProxyLibrary:
-        apiType = APITypeAIProxyLibrary
-    case common.ChannelTypeTencent:
-        apiType = APITypeTencent
-    case common.ChannelTypeGemini:
-        apiType = APITypeGemini
-    }
-    baseURL := common.ChannelBaseURLs[channelType]
-    requestURL := c.Request.URL.String()
-    if c.GetString("base_url") != "" {
-        baseURL = c.GetString("base_url")
-    }
-    fullRequestURL := getFullRequestURL(baseURL, requestURL, channelType)
-    switch apiType {
-    case APITypeOpenAI:
-        if channelType == common.ChannelTypeAzure {
-            // https://learn.microsoft.com/en-us/azure/cognitive-services/openai/chatgpt-quickstart?pivots=rest-api&tabs=command-line#rest-api
-            query := c.Request.URL.Query()
-            apiVersion := query.Get("api-version")
-            if apiVersion == "" {
-                apiVersion = c.GetString("api_version")
-            }
-            requestURL := strings.Split(requestURL, "?")[0]
-            requestURL = fmt.Sprintf("%s?api-version=%s", requestURL, apiVersion)
-            baseURL = c.GetString("base_url")
-            task := strings.TrimPrefix(requestURL, "/v1/")
-            model_ := textRequest.Model
-            model_ = strings.Replace(model_, ".", "", -1)
-            // https://github.com/songquanpeng/one-api/issues/67
-            model_ = strings.TrimSuffix(model_, "-0301")
-            model_ = strings.TrimSuffix(model_, "-0314")
-            model_ = strings.TrimSuffix(model_, "-0613")
-            requestURL = fmt.Sprintf("/openai/deployments/%s/%s", model_, task)
-            fullRequestURL = getFullRequestURL(baseURL, requestURL, channelType)
-        }
-    case APITypeClaude:
-        fullRequestURL = "https://api.anthropic.com/v1/complete"
-        if baseURL != "" {
-            fullRequestURL = fmt.Sprintf("%s/v1/complete", baseURL)
-        }
-    case APITypeBaidu:
-        switch textRequest.Model {
-        case "ERNIE-Bot":
-            fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions"
-        case "ERNIE-Bot-turbo":
-            fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/eb-instant"
-        case "ERNIE-Bot-4":
-            fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro"
-        case "BLOOMZ-7B":
-            fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/bloomz_7b1"
-        case "Embedding-V1":
-            fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/embeddings/embedding-v1"
-        }
-        apiKey := c.Request.Header.Get("Authorization")
-        apiKey = strings.TrimPrefix(apiKey, "Bearer ")
-        var err error
-        if apiKey, err = getBaiduAccessToken(apiKey); err != nil {
-            return errorWrapper(err, "invalid_baidu_config", http.StatusInternalServerError)
-        }
-        fullRequestURL += "?access_token=" + apiKey
-    case APITypePaLM:
-        fullRequestURL = "https://generativelanguage.googleapis.com/v1beta2/models/chat-bison-001:generateMessage"
-        if baseURL != "" {
-            fullRequestURL = fmt.Sprintf("%s/v1beta2/models/chat-bison-001:generateMessage", baseURL)
-        }
-        apiKey := c.Request.Header.Get("Authorization")
-        apiKey = strings.TrimPrefix(apiKey, "Bearer ")
-        fullRequestURL += "?key=" + apiKey
-    case APITypeGemini:
-        requestBaseURL := "https://generativelanguage.googleapis.com"
-        if baseURL != "" {
-            requestBaseURL = baseURL
-        }
-        version := "v1beta"
-        if c.GetString("api_version") != "" {
-            version = c.GetString("api_version")
-        }
-        action := "generateContent"
-        if textRequest.Stream {
-            action = "streamGenerateContent"
-        }
-        fullRequestURL = fmt.Sprintf("%s/%s/models/%s:%s", requestBaseURL, version, textRequest.Model, action)
-        apiKey := c.Request.Header.Get("Authorization")
-        apiKey = strings.TrimPrefix(apiKey, "Bearer ")
-        fullRequestURL += "?key=" + apiKey
-        //log.Println(fullRequestURL)
-
-    case APITypeZhipu:
-        method := "invoke"
-        if textRequest.Stream {
-            method = "sse-invoke"
-        }
-        fullRequestURL = fmt.Sprintf("https://open.bigmodel.cn/api/paas/v3/model-api/%s/%s", textRequest.Model, method)
-    case APITypeAli:
-        fullRequestURL = "https://dashscope.aliyuncs.com/api/v1/services/aigc/text-generation/generation"
-        if relayMode == RelayModeEmbeddings {
-            fullRequestURL = "https://dashscope.aliyuncs.com/api/v1/services/embeddings/text-embedding/text-embedding"
-        }
-    case APITypeTencent:
-        fullRequestURL = "https://hunyuan.cloud.tencent.com/hyllm/v1/chat/completions"
-    case APITypeAIProxyLibrary:
-        fullRequestURL = fmt.Sprintf("%s/api/library/ask", baseURL)
-    }
-    var promptTokens int
-    var completionTokens int
-    switch relayMode {
-    case RelayModeChatCompletions:
-        promptTokens, err = countTokenMessages(textRequest.Messages, textRequest.Model)
-        if err != nil {
-            return errorWrapper(err, "count_token_messages_failed", http.StatusInternalServerError)
-        }
-    case RelayModeCompletions:
-        promptTokens = countTokenInput(textRequest.Prompt, textRequest.Model)
-    case RelayModeModerations:
-        promptTokens = countTokenInput(textRequest.Input, textRequest.Model)
-    }
-    modelPrice := common.GetModelPrice(textRequest.Model, false)
-    groupRatio := common.GetGroupRatio(group)
-
-    var preConsumedQuota int
-    var ratio float64
-    var modelRatio float64
-    if modelPrice == -1 {
-        preConsumedTokens := common.PreConsumedQuota
-        if textRequest.MaxTokens != 0 {
-            preConsumedTokens = promptTokens + int(textRequest.MaxTokens)
-        }
-        modelRatio = common.GetModelRatio(textRequest.Model)
-        ratio = modelRatio * groupRatio
-        preConsumedQuota = int(float64(preConsumedTokens) * ratio)
-    } else {
-        preConsumedQuota = int(modelPrice * common.QuotaPerUnit * groupRatio)
-    }
-
-    userQuota, err := model.CacheGetUserQuota(userId)
-    if err != nil {
-        return errorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError)
-    }
-    if userQuota < 0 || userQuota-preConsumedQuota < 0 {
-        return errorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden)
-    }
-    err = model.CacheDecreaseUserQuota(userId, preConsumedQuota)
-    if err != nil {
-        return errorWrapper(err, "decrease_user_quota_failed", http.StatusInternalServerError)
-    }
-    if userQuota > 100*preConsumedQuota {
-        // the user quota is ample; next, check whether the token quota is too
-        if !tokenUnlimited {
-            // not an unlimited token: check the token quota
-            tokenQuota := c.GetInt("token_quota")
-            if tokenQuota > 100*preConsumedQuota {
-                // the token quota is ample as well, so trust the token
-                preConsumedQuota = 0
-                common.LogInfo(c.Request.Context(), fmt.Sprintf("user %d quota %d and token %d quota %d are enough, trusted and no need to pre-consume", userId, userQuota, tokenId, tokenQuota))
-            }
-        } else {
-            // in this case, we do not pre-consume quota
-            // because the user has enough quota
-            preConsumedQuota = 0
-            common.LogInfo(c.Request.Context(), fmt.Sprintf("user %d with unlimited token has enough quota %d, trusted and no need to pre-consume", userId, userQuota))
-        }
-    }
-    if preConsumedQuota > 0 {
-        userQuota, err = model.PreConsumeTokenQuota(tokenId, preConsumedQuota)
-        if err != nil {
-            return errorWrapper(err, "pre_consume_token_quota_failed", http.StatusForbidden)
-        }
-    }
-    var requestBody io.Reader
-    if isModelMapped {
-        jsonStr, err := json.Marshal(textRequest)
-        if err != nil {
-            return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
-        }
-        requestBody = bytes.NewBuffer(jsonStr)
-    } else {
-        requestBody = c.Request.Body
-    }
-    switch apiType {
-    case APITypeClaude:
-        claudeRequest := requestOpenAI2Claude(textRequest)
-        jsonStr, err := json.Marshal(claudeRequest)
-        if err != nil {
-            return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
-        }
-        requestBody = bytes.NewBuffer(jsonStr)
-    case APITypeBaidu:
-        var jsonData []byte
-        var err error
-        switch relayMode {
-        case RelayModeEmbeddings:
-            baiduEmbeddingRequest := embeddingRequestOpenAI2Baidu(textRequest)
-            jsonData, err = json.Marshal(baiduEmbeddingRequest)
-        default:
-            baiduRequest := requestOpenAI2Baidu(textRequest)
-            jsonData, err = json.Marshal(baiduRequest)
-        }
-        if err != nil {
-            return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
-        }
-        requestBody = bytes.NewBuffer(jsonData)
-    case APITypePaLM:
-        palmRequest := requestOpenAI2PaLM(textRequest)
-        jsonStr, err := json.Marshal(palmRequest)
-        if err != nil {
-            return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
-        }
-        requestBody = bytes.NewBuffer(jsonStr)
-    case APITypeGemini:
-        geminiChatRequest := requestOpenAI2Gemini(textRequest)
-        jsonStr, err := json.Marshal(geminiChatRequest)
-        if err != nil {
-            return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
-        }
-        requestBody = bytes.NewBuffer(jsonStr)
-    case APITypeZhipu:
-        zhipuRequest := requestOpenAI2Zhipu(textRequest)
-        jsonStr, err := json.Marshal(zhipuRequest)
-        if err != nil {
-            return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
-        }
-        requestBody = bytes.NewBuffer(jsonStr)
-    case APITypeAli:
-        var jsonStr []byte
-        var err error
-        switch relayMode {
-        case RelayModeEmbeddings:
-            aliEmbeddingRequest := embeddingRequestOpenAI2Ali(textRequest)
-            jsonStr, err = json.Marshal(aliEmbeddingRequest)
-        default:
-            aliRequest := requestOpenAI2Ali(textRequest)
-            jsonStr, err = json.Marshal(aliRequest)
-        }
-        if err != nil {
-            return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
-        }
-        requestBody = bytes.NewBuffer(jsonStr)
-    case APITypeTencent:
-        apiKey := c.Request.Header.Get("Authorization")
-        apiKey = strings.TrimPrefix(apiKey, "Bearer ")
-        appId, secretId, secretKey, err := parseTencentConfig(apiKey)
-        if err != nil {
-            return errorWrapper(err, "invalid_tencent_config", http.StatusInternalServerError)
-        }
-        tencentRequest := requestOpenAI2Tencent(textRequest)
-        tencentRequest.AppId = appId
-        tencentRequest.SecretId = secretId
-        jsonStr, err := json.Marshal(tencentRequest)
-        if err != nil {
-            return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
-        }
-        sign := getTencentSign(*tencentRequest, secretKey)
-        c.Request.Header.Set("Authorization", sign)
-        requestBody = bytes.NewBuffer(jsonStr)
-    case APITypeAIProxyLibrary:
-        aiProxyLibraryRequest := requestOpenAI2AIProxyLibrary(textRequest)
-        aiProxyLibraryRequest.LibraryId = c.GetString("library_id")
-        jsonStr, err := json.Marshal(aiProxyLibraryRequest)
-        if err != nil {
-            return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
-        }
-        requestBody = bytes.NewBuffer(jsonStr)
-    }
-
-    var req *http.Request
-    var resp *http.Response
-    isStream := textRequest.Stream
-
-    if apiType != APITypeXunfei { // cause xunfei use websocket
-        req, err = http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
-        // set GetBody so the body can be re-read: it returns a fresh io.ReadCloser yielding the same data as the original request body
-        req.GetBody = func() (io.ReadCloser, error) {
-            return io.NopCloser(requestBody), nil
-        }
-        if err != nil {
-            return errorWrapper(err, "new_request_failed", http.StatusInternalServerError)
-        }
-        apiKey := c.Request.Header.Get("Authorization")
-        apiKey = strings.TrimPrefix(apiKey, "Bearer ")
-        switch apiType {
-        case APITypeOpenAI:
-            if channelType == common.ChannelTypeAzure {
-                req.Header.Set("api-key", apiKey)
-            } else {
-                req.Header.Set("Authorization", c.Request.Header.Get("Authorization"))
-                if c.Request.Header.Get("OpenAI-Organization") != "" {
-                    req.Header.Set("OpenAI-Organization", c.Request.Header.Get("OpenAI-Organization"))
-                }
-                if channelType == common.ChannelTypeOpenRouter {
-                    req.Header.Set("HTTP-Referer", "https://github.com/songquanpeng/one-api")
-                    req.Header.Set("X-Title", "One API")
-                }
-            }
-        case APITypeClaude:
-            req.Header.Set("x-api-key", apiKey)
-            anthropicVersion := c.Request.Header.Get("anthropic-version")
-            if anthropicVersion == "" {
-                anthropicVersion = "2023-06-01"
-            }
-            req.Header.Set("anthropic-version", anthropicVersion)
-        case APITypeZhipu:
-            token := getZhipuToken(apiKey)
-            req.Header.Set("Authorization", token)
-        case APITypeAli:
-            req.Header.Set("Authorization", "Bearer "+apiKey)
-            if textRequest.Stream {
-                req.Header.Set("X-DashScope-SSE", "enable")
-            }
-        case APITypeTencent:
-            req.Header.Set("Authorization", apiKey)
-        case APITypeGemini:
-            req.Header.Set("Content-Type", "application/json")
-        default:
-            req.Header.Set("Authorization", "Bearer "+apiKey)
-        }
-        if apiType != APITypeGemini {
-            // set the common headers...
-            req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
-            req.Header.Set("Accept", c.Request.Header.Get("Accept"))
-            if isStream && c.Request.Header.Get("Accept") == "" {
-                req.Header.Set("Accept", "text/event-stream")
-            }
-        }
-        //req.HeaderBar.Set("Connection", c.Request.HeaderBar.Get("Connection"))
-        resp, err = httpClient.Do(req)
-        if err != nil {
-            return errorWrapper(err, "do_request_failed", http.StatusInternalServerError)
-        }
-        err = req.Body.Close()
-        if err != nil {
-            return errorWrapper(err, "close_request_body_failed", http.StatusInternalServerError)
-        }
-        err = c.Request.Body.Close()
-        if err != nil {
-            return errorWrapper(err, "close_request_body_failed", http.StatusInternalServerError)
-        }
-        isStream = isStream || strings.HasPrefix(resp.Header.Get("Content-Type"), "text/event-stream")
-
-        if resp.StatusCode != http.StatusOK {
-            if preConsumedQuota != 0 {
-                go func(ctx context.Context) {
-                    // return pre-consumed quota
-                    err := model.PostConsumeTokenQuota(tokenId, userQuota, -preConsumedQuota, 0, false)
-                    if err != nil {
-                        common.LogError(ctx, "error return pre-consumed quota: "+err.Error())
-                    }
-                }(c.Request.Context())
-            }
-            return relayErrorHandler(resp)
-        }
-    }
-
-    var textResponse TextResponse
-    tokenName := c.GetString("token_name")
-
-    defer func(ctx context.Context) {
-        // c.Writer.Flush()
-        go func() {
-            useTimeSeconds := time.Now().Unix() - startTime.Unix()
-            promptTokens = textResponse.Usage.PromptTokens
-            completionTokens = textResponse.Usage.CompletionTokens
-
-            quota := 0
-            if modelPrice == -1 {
-                completionRatio := common.GetCompletionRatio(textRequest.Model)
-                quota = promptTokens + int(float64(completionTokens)*completionRatio)
-                quota = int(float64(quota) * ratio)
-                if ratio != 0 && quota <= 0 {
-                    quota = 1
-                }
-            } else {
-                quota = int(modelPrice * common.QuotaPerUnit * groupRatio)
-            }
-            totalTokens := promptTokens + completionTokens
-            var logContent string
-            if modelPrice == -1 {
-                logContent = fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
-            } else {
-                logContent = fmt.Sprintf("模型价格 %.2f,分组倍率 %.2f", modelPrice, groupRatio)
-            }
-
-            // record all the consume log even if quota is 0
-            if totalTokens == 0 {
-                // in this case, must be some error happened
-                // we cannot just return, because we may have to return the pre-consumed quota
-                quota = 0
-                logContent += fmt.Sprintf("(有疑问请联系管理员)")
-                common.LogError(ctx, fmt.Sprintf("total tokens is 0, cannot consume quota, userId %d, channelId %d, tokenId %d, model %s, pre-consumed quota %d", userId, channelId, tokenId, textRequest.Model, preConsumedQuota))
-            } else {
-                quotaDelta := quota - preConsumedQuota
-                err := model.PostConsumeTokenQuota(tokenId, userQuota, quotaDelta, preConsumedQuota, true)
-                if err != nil {
-                    common.LogError(ctx, "error consuming token remain quota: "+err.Error())
-                }
-                err = model.CacheUpdateUserQuota(userId)
-                if err != nil {
-                    common.LogError(ctx, "error update user quota cache: "+err.Error())
-                }
-                model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
-                model.UpdateChannelUsedQuota(channelId, quota)
-            }
-
-            logModel := textRequest.Model
-            if strings.HasPrefix(logModel, "gpt-4-gizmo") {
-                logModel = "gpt-4-gizmo-*"
-                logContent += fmt.Sprintf(",模型 %s", textRequest.Model)
-            }
-            model.RecordConsumeLog(ctx, userId, channelId, promptTokens, completionTokens, logModel, tokenName, quota, logContent, tokenId, userQuota, int(useTimeSeconds), isStream)
-
-            //if quota != 0 {
-            //
-            //}
-        }()
-    }(c.Request.Context())
-    switch apiType {
-    case APITypeOpenAI:
-        if isStream {
-            err, responseText := openaiStreamHandler(c, resp, relayMode)
-            if err != nil {
-                return err
-            }
-            textResponse.Usage.PromptTokens = promptTokens
-            textResponse.Usage.CompletionTokens = countTokenText(responseText, textRequest.Model)
-            return nil
-        } else {
-            err, usage := openaiHandler(c, resp, promptTokens, textRequest.Model)
-            if err != nil {
-                return err
-            }
-            if usage != nil {
-                textResponse.Usage = *usage
-            }
-            return nil
-        }
-    case APITypeClaude:
-        if isStream {
-            err, responseText := claudeStreamHandler(c, resp)
-            if err != nil {
-                return err
-            }
-            textResponse.Usage.PromptTokens = promptTokens
-            textResponse.Usage.CompletionTokens = countTokenText(responseText, textRequest.Model)
-            return nil
-        } else {
-            err, usage := claudeHandler(c, resp, promptTokens, textRequest.Model)
-            if err != nil {
-                return err
-            }
-            if usage != nil {
-                textResponse.Usage = *usage
-            }
-            return nil
-        }
-    case APITypeBaidu:
-        if isStream {
-            err, usage := baiduStreamHandler(c, resp)
-            if err != nil {
-                return err
-            }
-            if usage != nil {
-                textResponse.Usage = *usage
-            }
-            return nil
-        } else {
-            var err *OpenAIErrorWithStatusCode
-            var usage *Usage
-            switch relayMode {
-            case RelayModeEmbeddings:
-                err, usage = baiduEmbeddingHandler(c, resp)
-            default:
-                err, usage = baiduHandler(c, resp)
-            }
-            if err != nil {
-                return err
-            }
-            if usage != nil {
-                textResponse.Usage = *usage
-            }
-            return nil
-        }
-    case APITypePaLM:
-        if textRequest.Stream { // PaLM2 API does not support stream
-            err, responseText := palmStreamHandler(c, resp)
-            if err != nil {
-                return err
-            }
-            textResponse.Usage.PromptTokens = promptTokens
-            textResponse.Usage.CompletionTokens = countTokenText(responseText, textRequest.Model)
-            return nil
-        } else {
-            err, usage := palmHandler(c, resp, promptTokens, textRequest.Model)
-            if err != nil {
-                return err
-            }
-            if usage != nil {
-                textResponse.Usage = *usage
-            }
-            return nil
-        }
-    case APITypeGemini:
-        if textRequest.Stream {
-            err, responseText := geminiChatStreamHandler(c, resp)
-            if err != nil {
-                return err
-            }
-            textResponse.Usage.PromptTokens = promptTokens
-            textResponse.Usage.CompletionTokens = countTokenText(responseText, textRequest.Model)
-            return nil
-        } else {
-            err, usage := geminiChatHandler(c, resp, promptTokens, textRequest.Model)
-            if err != nil {
-                return err
-            }
-            if usage != nil {
-                textResponse.Usage = *usage
-            }
-            return nil
-        }
-    case APITypeZhipu:
-        if isStream {
-            err, usage := zhipuStreamHandler(c, resp)
-            if err != nil {
-                return err
-            }
-            if usage != nil {
-                textResponse.Usage = *usage
-            }
-            // zhipu's API does not return prompt tokens & completion tokens
-            textResponse.Usage.PromptTokens = textResponse.Usage.TotalTokens
-            return nil
-        } else {
-            err, usage := zhipuHandler(c, resp)
-            if err != nil {
-                return err
-            }
-            if usage != nil {
-                textResponse.Usage = *usage
-            }
-            // zhipu's API does not return prompt tokens & completion tokens
-            textResponse.Usage.PromptTokens = textResponse.Usage.TotalTokens
-            return nil
-        }
-    case APITypeAli:
-        if isStream {
-            err, usage := aliStreamHandler(c, resp)
-            if err != nil {
-                return err
-            }
-            if usage != nil {
-                textResponse.Usage = *usage
-            }
-            return nil
-        } else {
-            var err *OpenAIErrorWithStatusCode
-            var usage *Usage
-            switch relayMode {
-            case RelayModeEmbeddings:
-                err, usage = aliEmbeddingHandler(c, resp)
-            default:
-                err, usage = aliHandler(c, resp)
-            }
-            if err != nil {
-                return err
-            }
-            if usage != nil {
-                textResponse.Usage = *usage
-            }
-            return nil
-        }
-    case APITypeXunfei:
-        auth := c.Request.Header.Get("Authorization")
-        auth = strings.TrimPrefix(auth, "Bearer ")
-        splits := strings.Split(auth, "|")
-        if len(splits) != 3 {
-            return errorWrapper(errors.New("invalid auth"), "invalid_auth", http.StatusBadRequest)
-        }
-        var err *OpenAIErrorWithStatusCode
-        var usage *Usage
-        if isStream {
-            err, usage = xunfeiStreamHandler(c, textRequest, splits[0], splits[1], splits[2])
-        } else {
-            err, usage = xunfeiHandler(c, textRequest, splits[0], splits[1], splits[2])
-        }
-        if err != nil {
-            return err
-        }
-        if usage != nil {
-            textResponse.Usage = *usage
-        }
-        return nil
-    case APITypeAIProxyLibrary:
-        if isStream {
-            err, usage := aiProxyLibraryStreamHandler(c, resp)
-            if err != nil {
-                return err
-            }
-            if usage != nil {
-                textResponse.Usage = *usage
-            }
-            return nil
-        } else {
-            err, usage := aiProxyLibraryHandler(c, resp)
-            if err != nil {
-                return err
-            }
-            if usage != nil {
-                textResponse.Usage = *usage
-            }
-            return nil
-        }
-    case APITypeTencent:
-        if isStream {
-            err, responseText := tencentStreamHandler(c, resp)
-            if err != nil {
-                return err
-            }
-            textResponse.Usage.PromptTokens = promptTokens
-            textResponse.Usage.CompletionTokens = countTokenText(responseText, textRequest.Model)
-            return nil
-        } else {
-            err, usage := tencentHandler(c, resp)
-            if err != nil {
-                return err
-            }
-            if usage != nil {
-                textResponse.Usage = *usage
-            }
-            return nil
-        }
-    default:
-        return errorWrapper(errors.New("unknown api type"), "unknown_api_type", http.StatusInternalServerError)
-    }
-}
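With this deletion the 752-line relayTextHelper — URL construction, per-provider request conversion, header wiring, quota pre-consumption, and response dispatch in one function — leaves the controller package. Its responsibilities are presumably redistributed across the new one-api/relay packages that the reworked testChannel already exercises (relaycommon.GenRelayInfo, constant.ChannelType2APIType, relaychannel.GetAdaptor), with the shared request/response types landing in one-api/dto, as the import changes in the next hunk suggest.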
@@ -1,340 +1,34 @@
 package controller
 
 import (
-	"encoding/json"
 	"fmt"
+	"github.com/gin-gonic/gin"
 	"log"
 	"net/http"
 	"one-api/common"
+	"one-api/dto"
+	"one-api/relay"
+	"one-api/relay/constant"
+	relayconstant "one-api/relay/constant"
+	"one-api/service"
 	"strconv"
 	"strings"
-
-	"github.com/gin-gonic/gin"
 )
 
-type Message struct {
-	Role       string          `json:"role"`
-	Content    json.RawMessage `json:"content"`
-	Name       *string         `json:"name,omitempty"`
-	ToolCalls  any             `json:"tool_calls,omitempty"`
-	ToolCallId string          `json:"tool_call_id,omitempty"`
-}
-
-type MediaMessage struct {
-	Type     string `json:"type"`
-	Text     string `json:"text"`
-	ImageUrl any    `json:"image_url,omitempty"`
-}
-
-type MessageImageUrl struct {
-	Url    string `json:"url"`
-	Detail string `json:"detail"`
-}
-
-const (
-	ContentTypeText     = "text"
-	ContentTypeImageURL = "image_url"
-)
-
-func (m Message) StringContent() string {
-	var stringContent string
-	if err := json.Unmarshal(m.Content, &stringContent); err == nil {
-		return stringContent
-	}
-	return string(m.Content)
-}
-
-func (m Message) ParseContent() []MediaMessage {
-	var contentList []MediaMessage
-	var stringContent string
-	if err := json.Unmarshal(m.Content, &stringContent); err == nil {
-		contentList = append(contentList, MediaMessage{
-			Type: ContentTypeText,
-			Text: stringContent,
-		})
-		return contentList
-	}
-	var arrayContent []json.RawMessage
-	if err := json.Unmarshal(m.Content, &arrayContent); err == nil {
-		for _, contentItem := range arrayContent {
-			var contentMap map[string]any
-			if err := json.Unmarshal(contentItem, &contentMap); err != nil {
-				continue
-			}
-			switch contentMap["type"] {
-			case ContentTypeText:
-				if subStr, ok := contentMap["text"].(string); ok {
-					contentList = append(contentList, MediaMessage{
-						Type: ContentTypeText,
-						Text: subStr,
-					})
-				}
-			case ContentTypeImageURL:
-				if subObj, ok := contentMap["image_url"].(map[string]any); ok {
-					detail, ok := subObj["detail"]
-					if ok {
-						subObj["detail"] = detail.(string)
-					} else {
-						subObj["detail"] = "auto"
-					}
-					contentList = append(contentList, MediaMessage{
-						Type: ContentTypeImageURL,
-						ImageUrl: MessageImageUrl{
-							Url:    subObj["url"].(string),
-							Detail: subObj["detail"].(string),
-						},
-					})
-				}
-			}
-		}
-		return contentList
-	}
-	return nil
-}
-
-const (
-	RelayModeUnknown = iota
-	RelayModeChatCompletions
-	RelayModeCompletions
-	RelayModeEmbeddings
-	RelayModeModerations
-	RelayModeImagesGenerations
-	RelayModeEdits
-	RelayModeMidjourneyImagine
-	RelayModeMidjourneyDescribe
-	RelayModeMidjourneyBlend
-	RelayModeMidjourneyChange
-	RelayModeMidjourneySimpleChange
-	RelayModeMidjourneyNotify
-	RelayModeMidjourneyTaskFetch
-	RelayModeMidjourneyTaskFetchByCondition
-	RelayModeAudioSpeech
-	RelayModeAudioTranscription
-	RelayModeAudioTranslation
-)
-
-// https://platform.openai.com/docs/api-reference/chat
-
-type ResponseFormat struct {
-	Type string `json:"type,omitempty"`
-}
-
-type GeneralOpenAIRequest struct {
-	Model            string          `json:"model,omitempty"`
-	Messages         []Message       `json:"messages,omitempty"`
-	Prompt           any             `json:"prompt,omitempty"`
-	Stream           bool            `json:"stream,omitempty"`
-	MaxTokens        uint            `json:"max_tokens,omitempty"`
-	Temperature      float64         `json:"temperature,omitempty"`
-	TopP             float64         `json:"top_p,omitempty"`
-	N                int             `json:"n,omitempty"`
-	Input            any             `json:"input,omitempty"`
-	Instruction      string          `json:"instruction,omitempty"`
-	Size             string          `json:"size,omitempty"`
-	Functions        any             `json:"functions,omitempty"`
-	FrequencyPenalty float64         `json:"frequency_penalty,omitempty"`
-	PresencePenalty  float64         `json:"presence_penalty,omitempty"`
-	ResponseFormat   *ResponseFormat `json:"response_format,omitempty"`
-	Seed             float64         `json:"seed,omitempty"`
-	Tools            any             `json:"tools,omitempty"`
-	ToolChoice       any             `json:"tool_choice,omitempty"`
-	User             string          `json:"user,omitempty"`
-	LogProbs         bool            `json:"logprobs,omitempty"`
-	TopLogProbs      int             `json:"top_logprobs,omitempty"`
-}
-
-func (r GeneralOpenAIRequest) ParseInput() []string {
-	if r.Input == nil {
-		return nil
-	}
-	var input []string
-	switch r.Input.(type) {
-	case string:
-		input = []string{r.Input.(string)}
-	case []any:
-		input = make([]string, 0, len(r.Input.([]any)))
-		for _, item := range r.Input.([]any) {
-			if str, ok := item.(string); ok {
-				input = append(input, str)
-			}
-		}
-	}
-	return input
-}
-
-type AudioRequest struct {
-	Model string `json:"model"`
-	Voice string `json:"voice"`
-	Input string `json:"input"`
-}
-
-type ChatRequest struct {
-	Model     string    `json:"model"`
-	Messages  []Message `json:"messages"`
-	MaxTokens uint      `json:"max_tokens"`
-}
-
-type TextRequest struct {
-	Model     string    `json:"model"`
-	Messages  []Message `json:"messages"`
-	Prompt    string    `json:"prompt"`
-	MaxTokens uint      `json:"max_tokens"`
-	//Stream   bool      `json:"stream"`
-}
-
-type ImageRequest struct {
-	Model          string `json:"model"`
-	Prompt         string `json:"prompt"`
-	N              int    `json:"n"`
-	Size           string `json:"size"`
-	Quality        string `json:"quality,omitempty"`
-	ResponseFormat string `json:"response_format,omitempty"`
-	Style          string `json:"style,omitempty"`
-}
-
-type AudioResponse struct {
-	Text string `json:"text,omitempty"`
-}
-
-type Usage struct {
-	PromptTokens     int `json:"prompt_tokens"`
-	CompletionTokens int `json:"completion_tokens"`
-	TotalTokens      int `json:"total_tokens"`
-}
-
-type OpenAIError struct {
-	Message string `json:"message"`
-	Type    string `json:"type"`
-	Param   string `json:"param"`
-	Code    any    `json:"code"`
-}
-
-type OpenAIErrorWithStatusCode struct {
-	OpenAIError
-	StatusCode int `json:"status_code"`
-}
-
-type TextResponse struct {
-	Choices []OpenAITextResponseChoice `json:"choices"`
-	Usage   `json:"usage"`
-	Error   OpenAIError `json:"error"`
-}
-
-type OpenAITextResponseChoice struct {
-	Index        int `json:"index"`
-	Message      `json:"message"`
-	FinishReason string `json:"finish_reason"`
-}
-
-type OpenAITextResponse struct {
-	Id      string                     `json:"id"`
-	Object  string                     `json:"object"`
-	Created int64                      `json:"created"`
-	Choices []OpenAITextResponseChoice `json:"choices"`
-	Usage   `json:"usage"`
-}
-
-type OpenAIEmbeddingResponseItem struct {
-	Object    string    `json:"object"`
-	Index     int       `json:"index"`
-	Embedding []float64 `json:"embedding"`
-}
-
-type OpenAIEmbeddingResponse struct {
-	Object string                        `json:"object"`
-	Data   []OpenAIEmbeddingResponseItem `json:"data"`
-	Model  string                        `json:"model"`
-	Usage  `json:"usage"`
-}
-
-type ImageResponse struct {
-	Created int `json:"created"`
-	Data    []struct {
-		Url     string `json:"url"`
-		B64Json string `json:"b64_json"`
-	}
-}
-
-type ChatCompletionsStreamResponseChoice struct {
-	Delta struct {
-		Content string `json:"content"`
-	} `json:"delta"`
-	FinishReason *string `json:"finish_reason,omitempty"`
-}
-
-type ChatCompletionsStreamResponse struct {
-	Id      string                                `json:"id"`
-	Object  string                                `json:"object"`
-	Created int64                                 `json:"created"`
-	Model   string                                `json:"model"`
-	Choices []ChatCompletionsStreamResponseChoice `json:"choices"`
-}
-
-type ChatCompletionsStreamResponseSimple struct {
-	Choices []ChatCompletionsStreamResponseChoice `json:"choices"`
-}
-
-type CompletionsStreamResponse struct {
-	Choices []struct {
-		Text         string `json:"text"`
-		FinishReason string `json:"finish_reason"`
-	} `json:"choices"`
-}
-
-type MidjourneyRequest struct {
-	Prompt      string   `json:"prompt"`
-	NotifyHook  string   `json:"notifyHook"`
-	Action      string   `json:"action"`
-	Index       int      `json:"index"`
-	State       string   `json:"state"`
-	TaskId      string   `json:"taskId"`
-	Base64Array []string `json:"base64Array"`
-	Content     string   `json:"content"`
-}
-
-type MidjourneyResponse struct {
-	Code        int         `json:"code"`
-	Description string      `json:"description"`
-	Properties  interface{} `json:"properties"`
-	Result      string      `json:"result"`
-}
-
 func Relay(c *gin.Context) {
-	relayMode := RelayModeUnknown
-	if strings.HasPrefix(c.Request.URL.Path, "/v1/chat/completions") {
-		relayMode = RelayModeChatCompletions
-	} else if strings.HasPrefix(c.Request.URL.Path, "/v1/completions") {
-		relayMode = RelayModeCompletions
-	} else if strings.HasPrefix(c.Request.URL.Path, "/v1/embeddings") {
-		relayMode = RelayModeEmbeddings
-	} else if strings.HasSuffix(c.Request.URL.Path, "embeddings") {
-		relayMode = RelayModeEmbeddings
-	} else if strings.HasPrefix(c.Request.URL.Path, "/v1/moderations") {
-		relayMode = RelayModeModerations
-	} else if strings.HasPrefix(c.Request.URL.Path, "/v1/images/generations") {
-		relayMode = RelayModeImagesGenerations
-	} else if strings.HasPrefix(c.Request.URL.Path, "/v1/edits") {
-		relayMode = RelayModeEdits
-	} else if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/speech") {
-		relayMode = RelayModeAudioSpeech
-	} else if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") {
-		relayMode = RelayModeAudioTranscription
-	} else if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/translations") {
-		relayMode = RelayModeAudioTranslation
-	}
-	var err *OpenAIErrorWithStatusCode
+	relayMode := constant.Path2RelayMode(c.Request.URL.Path)
+	var err *dto.OpenAIErrorWithStatusCode
 	switch relayMode {
-	case RelayModeImagesGenerations:
-		err = relayImageHelper(c, relayMode)
-	case RelayModeAudioSpeech:
+	case relayconstant.RelayModeImagesGenerations:
+		err = relay.RelayImageHelper(c, relayMode)
+	case relayconstant.RelayModeAudioSpeech:
 		fallthrough
-	case RelayModeAudioTranslation:
+	case relayconstant.RelayModeAudioTranslation:
 		fallthrough
-	case RelayModeAudioTranscription:
-		err = relayAudioHelper(c, relayMode)
+	case relayconstant.RelayModeAudioTranscription:
+		err = relay.RelayAudioHelper(c, relayMode)
 	default:
-		err = relayTextHelper(c, relayMode)
+		err = relay.TextHelper(c)
 	}
 	if err != nil {
 		requestId := c.GetString(common.RequestIdKey)
@@ -358,42 +52,42 @@ func Relay(c *gin.Context) {
 		autoBan := c.GetBool("auto_ban")
 		common.LogError(c.Request.Context(), fmt.Sprintf("relay error (channel #%d): %s", channelId, err.Message))
 		// https://platform.openai.com/docs/guides/error-codes/api-errors
-		if shouldDisableChannel(&err.OpenAIError, err.StatusCode) && autoBan {
+		if service.ShouldDisableChannel(&err.OpenAIError, err.StatusCode) && autoBan {
 			channelId := c.GetInt("channel_id")
 			channelName := c.GetString("channel_name")
-			disableChannel(channelId, channelName, err.Message)
+			service.DisableChannel(channelId, channelName, err.Message)
 		}
 	}
 }
 
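The path-prefix dispatch deleted above now lives behind constant.Path2RelayMode. Its body is outside this diff; a plausible reconstruction from the deleted if/else chain (a sketch, not the committed code) is:

    package constant

    import "strings"

    // Sketch of Path2RelayMode, reconstructed from the dispatch chain
    // removed from Relay; the real body in relay/constant may differ.
    func Path2RelayMode(path string) int {
        relayMode := RelayModeUnknown
        if strings.HasPrefix(path, "/v1/chat/completions") {
            relayMode = RelayModeChatCompletions
        } else if strings.HasPrefix(path, "/v1/completions") {
            relayMode = RelayModeCompletions
        } else if strings.HasPrefix(path, "/v1/embeddings") || strings.HasSuffix(path, "embeddings") {
            relayMode = RelayModeEmbeddings
        } else if strings.HasPrefix(path, "/v1/moderations") {
            relayMode = RelayModeModerations
        } else if strings.HasPrefix(path, "/v1/images/generations") {
            relayMode = RelayModeImagesGenerations
        } else if strings.HasPrefix(path, "/v1/edits") {
            relayMode = RelayModeEdits
        } else if strings.HasPrefix(path, "/v1/audio/speech") {
            relayMode = RelayModeAudioSpeech
        } else if strings.HasPrefix(path, "/v1/audio/transcriptions") {
            relayMode = RelayModeAudioTranscription
        } else if strings.HasPrefix(path, "/v1/audio/translations") {
            relayMode = RelayModeAudioTranslation
        }
        return relayMode
    }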
 func RelayMidjourney(c *gin.Context) {
-	relayMode := RelayModeUnknown
+	relayMode := relayconstant.RelayModeUnknown
 	if strings.HasPrefix(c.Request.URL.Path, "/mj/submit/imagine") {
-		relayMode = RelayModeMidjourneyImagine
+		relayMode = relayconstant.RelayModeMidjourneyImagine
 	} else if strings.HasPrefix(c.Request.URL.Path, "/mj/submit/blend") {
-		relayMode = RelayModeMidjourneyBlend
+		relayMode = relayconstant.RelayModeMidjourneyBlend
 	} else if strings.HasPrefix(c.Request.URL.Path, "/mj/submit/describe") {
-		relayMode = RelayModeMidjourneyDescribe
+		relayMode = relayconstant.RelayModeMidjourneyDescribe
 	} else if strings.HasPrefix(c.Request.URL.Path, "/mj/notify") {
-		relayMode = RelayModeMidjourneyNotify
+		relayMode = relayconstant.RelayModeMidjourneyNotify
 	} else if strings.HasPrefix(c.Request.URL.Path, "/mj/submit/change") {
-		relayMode = RelayModeMidjourneyChange
+		relayMode = relayconstant.RelayModeMidjourneyChange
 	} else if strings.HasPrefix(c.Request.URL.Path, "/mj/submit/simple-change") {
-		relayMode = RelayModeMidjourneyChange
+		relayMode = relayconstant.RelayModeMidjourneyChange
 	} else if strings.HasSuffix(c.Request.URL.Path, "/fetch") {
-		relayMode = RelayModeMidjourneyTaskFetch
+		relayMode = relayconstant.RelayModeMidjourneyTaskFetch
 	} else if strings.HasSuffix(c.Request.URL.Path, "/list-by-condition") {
-		relayMode = RelayModeMidjourneyTaskFetchByCondition
+		relayMode = relayconstant.RelayModeMidjourneyTaskFetchByCondition
 	}
 
-	var err *MidjourneyResponse
+	var err *dto.MidjourneyResponse
 	switch relayMode {
-	case RelayModeMidjourneyNotify:
-		err = relayMidjourneyNotify(c)
-	case RelayModeMidjourneyTaskFetch, RelayModeMidjourneyTaskFetchByCondition:
-		err = relayMidjourneyTask(c, relayMode)
+	case relayconstant.RelayModeMidjourneyNotify:
+		err = relay.RelayMidjourneyNotify(c)
+	case relayconstant.RelayModeMidjourneyTaskFetch, relayconstant.RelayModeMidjourneyTaskFetchByCondition:
+		err = relay.RelayMidjourneyTask(c, relayMode)
 	default:
-		err = relayMidjourneySubmit(c, relayMode)
+		err = relay.RelayMidjourneySubmit(c, relayMode)
 	}
 	//err = relayMidjourneySubmit(c, relayMode)
 	log.Println(err)
@@ -425,7 +119,7 @@ func RelayMidjourney(c *gin.Context) {
 }
 
 func RelayNotImplemented(c *gin.Context) {
-	err := OpenAIError{
+	err := dto.OpenAIError{
 		Message: "API not implemented",
 		Type:    "new_api_error",
 		Param:   "",
@@ -437,7 +131,7 @@ func RelayNotImplemented(c *gin.Context) {
 }
 
 func RelayNotFound(c *gin.Context) {
-	err := OpenAIError{
+	err := dto.OpenAIError{
 		Message: fmt.Sprintf("Invalid URL (%s %s)", c.Request.Method, c.Request.URL.Path),
 		Type:    "invalid_request_error",
 		Param:   "",
dto/error.go (new file, 13 lines)
@@ -0,0 +1,13 @@
+package dto
+
+type OpenAIError struct {
+	Message string `json:"message"`
+	Type    string `json:"type"`
+	Param   string `json:"param"`
+	Code    any    `json:"code"`
+}
+
+type OpenAIErrorWithStatusCode struct {
+	OpenAIError
+	StatusCode int `json:"status_code"`
+}
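service.OpenAIErrorWrapper, which replaces the old errorWrapper helper throughout this commit, is the usual producer of these types. Its body is not part of this diff; a minimal sketch consistent with its call sites (error, code string, HTTP status), where the Type value is an assumption:

    package service

    import "one-api/dto"

    // Sketch of OpenAIErrorWrapper, inferred from its call sites in this
    // commit; the actual implementation may add logging or redaction.
    func OpenAIErrorWrapper(err error, code string, statusCode int) *dto.OpenAIErrorWithStatusCode {
        return &dto.OpenAIErrorWithStatusCode{
            OpenAIError: dto.OpenAIError{
                Message: err.Error(),
                Type:    "new_api_error", // assumed default type
                Code:    code,
            },
            StatusCode: statusCode,
        }
    }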
dto/request.go (new file, 137 lines)
@@ -0,0 +1,137 @@
+package dto
+
+import "encoding/json"
+
+type ResponseFormat struct {
+	Type string `json:"type,omitempty"`
+}
+
+type GeneralOpenAIRequest struct {
+	Model            string          `json:"model,omitempty"`
+	Messages         []Message       `json:"messages,omitempty"`
+	Prompt           any             `json:"prompt,omitempty"`
+	Stream           bool            `json:"stream,omitempty"`
+	MaxTokens        uint            `json:"max_tokens,omitempty"`
+	Temperature      float64         `json:"temperature,omitempty"`
+	TopP             float64         `json:"top_p,omitempty"`
+	N                int             `json:"n,omitempty"`
+	Input            any             `json:"input,omitempty"`
+	Instruction      string          `json:"instruction,omitempty"`
+	Size             string          `json:"size,omitempty"`
+	Functions        any             `json:"functions,omitempty"`
+	FrequencyPenalty float64         `json:"frequency_penalty,omitempty"`
+	PresencePenalty  float64         `json:"presence_penalty,omitempty"`
+	ResponseFormat   *ResponseFormat `json:"response_format,omitempty"`
+	Seed             float64         `json:"seed,omitempty"`
+	Tools            any             `json:"tools,omitempty"`
+	ToolChoice       any             `json:"tool_choice,omitempty"`
+	User             string          `json:"user,omitempty"`
+	LogProbs         bool            `json:"logprobs,omitempty"`
+	TopLogProbs      int             `json:"top_logprobs,omitempty"`
+}
+
+func (r GeneralOpenAIRequest) ParseInput() []string {
+	if r.Input == nil {
+		return nil
+	}
+	var input []string
+	switch r.Input.(type) {
+	case string:
+		input = []string{r.Input.(string)}
+	case []any:
+		input = make([]string, 0, len(r.Input.([]any)))
+		for _, item := range r.Input.([]any) {
+			if str, ok := item.(string); ok {
+				input = append(input, str)
+			}
+		}
+	}
+	return input
+}
+
+type Message struct {
+	Role       string          `json:"role"`
+	Content    json.RawMessage `json:"content"`
+	Name       *string         `json:"name,omitempty"`
+	ToolCalls  any             `json:"tool_calls,omitempty"`
+	ToolCallId string          `json:"tool_call_id,omitempty"`
+}
+
+type MediaMessage struct {
+	Type     string `json:"type"`
+	Text     string `json:"text"`
+	ImageUrl any    `json:"image_url,omitempty"`
+}
+
+type MessageImageUrl struct {
+	Url    string `json:"url"`
+	Detail string `json:"detail"`
+}
+
+const (
+	ContentTypeText     = "text"
+	ContentTypeImageURL = "image_url"
+)
+
+func (m Message) StringContent() string {
+	var stringContent string
+	if err := json.Unmarshal(m.Content, &stringContent); err == nil {
+		return stringContent
+	}
+	return string(m.Content)
+}
+
+func (m Message) ParseContent() []MediaMessage {
+	var contentList []MediaMessage
+	var stringContent string
+	if err := json.Unmarshal(m.Content, &stringContent); err == nil {
+		contentList = append(contentList, MediaMessage{
+			Type: ContentTypeText,
+			Text: stringContent,
+		})
+		return contentList
+	}
+	var arrayContent []json.RawMessage
+	if err := json.Unmarshal(m.Content, &arrayContent); err == nil {
+		for _, contentItem := range arrayContent {
+			var contentMap map[string]any
+			if err := json.Unmarshal(contentItem, &contentMap); err != nil {
+				continue
+			}
+			switch contentMap["type"] {
+			case ContentTypeText:
+				if subStr, ok := contentMap["text"].(string); ok {
+					contentList = append(contentList, MediaMessage{
+						Type: ContentTypeText,
+						Text: subStr,
+					})
+				}
+			case ContentTypeImageURL:
+				if subObj, ok := contentMap["image_url"].(map[string]any); ok {
+					detail, ok := subObj["detail"]
+					if ok {
+						subObj["detail"] = detail.(string)
+					} else {
+						subObj["detail"] = "auto"
+					}
+					contentList = append(contentList, MediaMessage{
+						Type: ContentTypeImageURL,
+						ImageUrl: MessageImageUrl{
+							Url:    subObj["url"].(string),
+							Detail: subObj["detail"].(string),
+						},
+					})
+				}
+			}
+		}
+		return contentList
+	}
+	return nil
+}
+
+type Usage struct {
+	PromptTokens     int `json:"prompt_tokens"`
+	CompletionTokens int `json:"completion_tokens"`
+	TotalTokens      int `json:"total_tokens"`
+}
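Content is deliberately json.RawMessage so one field can carry either a plain string or an OpenAI-style multimodal array; StringContent and ParseContent normalize the two shapes. A small, illustrative usage (the sample payloads are made up):

    package main

    import (
        "encoding/json"
        "fmt"

        "one-api/dto"
    )

    func main() {
        // Plain string content.
        plain := dto.Message{Role: "user", Content: json.RawMessage(`"hello"`)}
        fmt.Println(plain.StringContent()) // hello

        // Multimodal array; a missing "detail" is defaulted to "auto".
        multi := dto.Message{Role: "user", Content: json.RawMessage(
            `[{"type":"text","text":"what is this?"},
              {"type":"image_url","image_url":{"url":"https://example.com/cat.png"}}]`)}
        for _, part := range multi.ParseContent() {
            fmt.Println(part.Type) // text, then image_url
        }
    }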
dto/response.go (new file, 86 lines)
@@ -0,0 +1,86 @@
+package dto
+
+type TextResponse struct {
+	Choices []OpenAITextResponseChoice `json:"choices"`
+	Usage   `json:"usage"`
+	Error   OpenAIError `json:"error"`
+}
+
+type OpenAITextResponseChoice struct {
+	Index        int `json:"index"`
+	Message      `json:"message"`
+	FinishReason string `json:"finish_reason"`
+}
+
+type OpenAITextResponse struct {
+	Id      string                     `json:"id"`
+	Object  string                     `json:"object"`
+	Created int64                      `json:"created"`
+	Choices []OpenAITextResponseChoice `json:"choices"`
+	Usage   `json:"usage"`
+}
+
+type OpenAIEmbeddingResponseItem struct {
+	Object    string    `json:"object"`
+	Index     int       `json:"index"`
+	Embedding []float64 `json:"embedding"`
+}
+
+type OpenAIEmbeddingResponse struct {
+	Object string                        `json:"object"`
+	Data   []OpenAIEmbeddingResponseItem `json:"data"`
+	Model  string                        `json:"model"`
+	Usage  `json:"usage"`
+}
+
+type ImageResponse struct {
+	Created int `json:"created"`
+	Data    []struct {
+		Url     string `json:"url"`
+		B64Json string `json:"b64_json"`
+	}
+}
+
+type ChatCompletionsStreamResponseChoice struct {
+	Delta struct {
+		Content string `json:"content"`
+	} `json:"delta"`
+	FinishReason *string `json:"finish_reason,omitempty"`
+}
+
+type ChatCompletionsStreamResponse struct {
+	Id      string                                `json:"id"`
+	Object  string                                `json:"object"`
+	Created int64                                 `json:"created"`
+	Model   string                                `json:"model"`
+	Choices []ChatCompletionsStreamResponseChoice `json:"choices"`
+}
+
+type ChatCompletionsStreamResponseSimple struct {
+	Choices []ChatCompletionsStreamResponseChoice `json:"choices"`
+}
+
+type CompletionsStreamResponse struct {
+	Choices []struct {
+		Text         string `json:"text"`
+		FinishReason string `json:"finish_reason"`
+	} `json:"choices"`
+}
+
+type MidjourneyRequest struct {
+	Prompt      string   `json:"prompt"`
+	NotifyHook  string   `json:"notifyHook"`
+	Action      string   `json:"action"`
+	Index       int      `json:"index"`
+	State       string   `json:"state"`
+	TaskId      string   `json:"taskId"`
+	Base64Array []string `json:"base64Array"`
+	Content     string   `json:"content"`
+}
+
+type MidjourneyResponse struct {
+	Code        int         `json:"code"`
+	Description string      `json:"description"`
+	Properties  interface{} `json:"properties"`
+	Result      string      `json:"result"`
+}
main.go
@@ -12,6 +12,7 @@ import (
 	"one-api/controller"
 	"one-api/middleware"
 	"one-api/model"
+	"one-api/relay/common"
 	"one-api/router"
 	"os"
 	"strconv"
@@ -105,7 +106,7 @@ func main() {
 		common.SysLog("pprof enabled")
 	}
 
-	controller.InitTokenEncoders()
+	common.InitTokenEncoders()
 
 	// Initialize HTTP server
 	server := gin.New()
@@ -129,15 +129,18 @@ func Distribute() func(c *gin.Context) {
 			c.Set("model_mapping", channel.GetModelMapping())
 			c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key))
 			c.Set("base_url", channel.GetBaseURL())
+			// TODO: unify api_version handling
 			switch channel.Type {
 			case common.ChannelTypeAzure:
 				c.Set("api_version", channel.Other)
 			case common.ChannelTypeXunfei:
 				c.Set("api_version", channel.Other)
-			case common.ChannelTypeAIProxyLibrary:
-				c.Set("library_id", channel.Other)
+			//case common.ChannelTypeAIProxyLibrary:
+			//	c.Set("library_id", channel.Other)
 			case common.ChannelTypeGemini:
 				c.Set("api_version", channel.Other)
+			case common.ChannelTypeAli:
+				c.Set("plugin", channel.Other)
 			}
 			c.Next()
 		}
relay/channel/adapter.go (new file, 57 lines)
@@ -0,0 +1,57 @@
+package channel
+
+import (
+	"github.com/gin-gonic/gin"
+	"io"
+	"net/http"
+	"one-api/dto"
+	"one-api/relay/channel/ali"
+	"one-api/relay/channel/baidu"
+	"one-api/relay/channel/claude"
+	"one-api/relay/channel/gemini"
+	"one-api/relay/channel/openai"
+	"one-api/relay/channel/palm"
+	"one-api/relay/channel/tencent"
+	"one-api/relay/channel/xunfei"
+	"one-api/relay/channel/zhipu"
+	relaycommon "one-api/relay/common"
+	"one-api/relay/constant"
+)
+
+type Adaptor interface {
+	// Init IsStream bool
+	Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest)
+	GetRequestURL(info *relaycommon.RelayInfo) (string, error)
+	SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error
+	ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error)
+	DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error)
+	DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode)
+	GetModelList() []string
+	GetChannelName() string
+}
+
+func GetAdaptor(apiType int) Adaptor {
+	switch apiType {
+	//case constant.APITypeAIProxyLibrary:
+	//	return &aiproxy.Adaptor{}
+	case constant.APITypeAli:
+		return &ali.Adaptor{}
+	case constant.APITypeAnthropic:
+		return &claude.Adaptor{}
+	case constant.APITypeBaidu:
+		return &baidu.Adaptor{}
+	case constant.APITypeGemini:
+		return &gemini.Adaptor{}
+	case constant.APITypeOpenAI:
+		return &openai.Adaptor{}
+	case constant.APITypePaLM:
+		return &palm.Adaptor{}
+	case constant.APITypeTencent:
+		return &tencent.Adaptor{}
+	case constant.APITypeXunfei:
+		return &xunfei.Adaptor{}
+	case constant.APITypeZhipu:
+		return &zhipu.Adaptor{}
+	}
+	return nil
+}
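GetAdaptor replaces the per-provider switch that the old relay helpers carried. The text-relay helper that drives this interface is not shown in this hunk; a sketch of the expected call sequence, with error handling and quota logic elided and the surrounding plumbing (body marshalling, relayInfo construction) assumed:

    // Sketch of driving an Adaptor inside a relay helper.
    adaptor := channel.GetAdaptor(apiType)
    if adaptor == nil {
        return service.OpenAIErrorWrapper(errors.New("invalid api type"), "invalid_api_type", http.StatusBadRequest)
    }
    adaptor.Init(relayInfo, *request)
    convertedRequest, _ := adaptor.ConvertRequest(c, relayInfo.RelayMode, request)
    jsonData, _ := json.Marshal(convertedRequest)
    resp, _ := adaptor.DoRequest(c, relayInfo, bytes.NewBuffer(jsonData))
    usage, openaiErr := adaptor.DoResponse(c, resp, relayInfo)
    // ...bill usage / propagate openaiErr from here.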
relay/channel/ali/adaptor.go (new file, 80 lines)
@@ -0,0 +1,80 @@
+package ali
+
+import (
+	"errors"
+	"fmt"
+	"github.com/gin-gonic/gin"
+	"io"
+	"net/http"
+	"one-api/dto"
+	relaychannel "one-api/relay/channel"
+	relaycommon "one-api/relay/common"
+	"one-api/relay/constant"
+)
+
+type Adaptor struct {
+}
+
+func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
+
+}
+
+func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
+	fullRequestURL := fmt.Sprintf("%s/api/v1/services/aigc/text-generation/generation", info.BaseUrl)
+	if info.RelayMode == constant.RelayModeEmbeddings {
+		fullRequestURL = fmt.Sprintf("%s/api/v1/services/embeddings/text-embedding/text-embedding", info.BaseUrl)
+	}
+	return fullRequestURL, nil
+}
+
+func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
+	relaychannel.SetupApiRequestHeader(info, c, req)
+	req.Header.Set("Authorization", "Bearer "+info.ApiKey)
+	if info.IsStream {
+		req.Header.Set("X-DashScope-SSE", "enable")
+	}
+	if c.GetString("plugin") != "" {
+		req.Header.Set("X-DashScope-Plugin", c.GetString("plugin"))
+	}
+	return nil
+}
+
+func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) {
+	if request == nil {
+		return nil, errors.New("request is nil")
+	}
+	switch relayMode {
+	case constant.RelayModeEmbeddings:
+		baiduEmbeddingRequest := embeddingRequestOpenAI2Ali(*request)
+		return baiduEmbeddingRequest, nil
+	default:
+		baiduRequest := requestOpenAI2Ali(*request)
+		return baiduRequest, nil
+	}
+}
+
+func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
+	return relaychannel.DoApiRequest(a, c, info, requestBody)
+}
+
+func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
+	if info.IsStream {
+		err, usage = aliStreamHandler(c, resp)
+	} else {
+		switch info.RelayMode {
+		case constant.RelayModeEmbeddings:
+			err, usage = aliEmbeddingHandler(c, resp)
+		default:
+			err, usage = aliHandler(c, resp)
+		}
+	}
+	return
+}
+
+func (a *Adaptor) GetModelList() []string {
+	return ModelList
+}
+
+func (a *Adaptor) GetChannelName() string {
+	return ChannelName
+}
relay/channel/ali/constants.go (new file, 8 lines)
@@ -0,0 +1,8 @@
+package ali
+
+var ModelList = []string{
+	"qwen-turbo", "qwen-plus", "qwen-max", "qwen-max-longcontext",
+	"text-embedding-v1",
+}
+
+var ChannelName = "ali"
relay/channel/ali/dto.go (new file, 70 lines)
@@ -0,0 +1,70 @@
+package ali
+
+type AliMessage struct {
+	User string `json:"user"`
+	Bot  string `json:"bot"`
+}
+
+type AliInput struct {
+	Prompt  string       `json:"prompt"`
+	History []AliMessage `json:"history"`
+}
+
+type AliParameters struct {
+	TopP         float64 `json:"top_p,omitempty"`
+	TopK         int     `json:"top_k,omitempty"`
+	Seed         uint64  `json:"seed,omitempty"`
+	EnableSearch bool    `json:"enable_search,omitempty"`
+}
+
+type AliChatRequest struct {
+	Model      string        `json:"model"`
+	Input      AliInput      `json:"input"`
+	Parameters AliParameters `json:"parameters,omitempty"`
+}
+
+type AliEmbeddingRequest struct {
+	Model string `json:"model"`
+	Input struct {
+		Texts []string `json:"texts"`
+	} `json:"input"`
+	Parameters *struct {
+		TextType string `json:"text_type,omitempty"`
+	} `json:"parameters,omitempty"`
+}
+
+type AliEmbedding struct {
+	Embedding []float64 `json:"embedding"`
+	TextIndex int       `json:"text_index"`
+}
+
+type AliEmbeddingResponse struct {
+	Output struct {
+		Embeddings []AliEmbedding `json:"embeddings"`
+	} `json:"output"`
+	Usage AliUsage `json:"usage"`
+	AliError
+}
+
+type AliError struct {
+	Code      string `json:"code"`
+	Message   string `json:"message"`
+	RequestId string `json:"request_id"`
+}
+
+type AliUsage struct {
+	InputTokens  int `json:"input_tokens"`
+	OutputTokens int `json:"output_tokens"`
+	TotalTokens  int `json:"total_tokens"`
+}
+
+type AliOutput struct {
+	Text         string `json:"text"`
+	FinishReason string `json:"finish_reason"`
+}
+
+type AliChatResponse struct {
+	Output AliOutput `json:"output"`
+	Usage  AliUsage  `json:"usage"`
+	AliError
+}
@@ -1,4 +1,4 @@
-package controller
+package ali
 
 import (
 	"bufio"
@@ -7,81 +7,14 @@ import (
 	"io"
 	"net/http"
 	"one-api/common"
+	"one-api/dto"
+	"one-api/service"
 	"strings"
 )
 
 // https://help.aliyun.com/document_detail/613695.html?spm=a2c4g.2399480.0.0.1adb778fAdzP9w#341800c0f8w0r
 
-type AliMessage struct {
-	User string `json:"user"`
-	Bot  string `json:"bot"`
-}
-
-type AliInput struct {
-	Prompt  string       `json:"prompt"`
-	History []AliMessage `json:"history"`
-}
-
-type AliParameters struct {
-	TopP         float64 `json:"top_p,omitempty"`
-	TopK         int     `json:"top_k,omitempty"`
-	Seed         uint64  `json:"seed,omitempty"`
-	EnableSearch bool    `json:"enable_search,omitempty"`
-}
-
-type AliChatRequest struct {
-	Model      string        `json:"model"`
-	Input      AliInput      `json:"input"`
-	Parameters AliParameters `json:"parameters,omitempty"`
-}
-
-type AliEmbeddingRequest struct {
-	Model string `json:"model"`
-	Input struct {
-		Texts []string `json:"texts"`
-	} `json:"input"`
-	Parameters *struct {
-		TextType string `json:"text_type,omitempty"`
-	} `json:"parameters,omitempty"`
-}
-
-type AliEmbedding struct {
-	Embedding []float64 `json:"embedding"`
-	TextIndex int       `json:"text_index"`
-}
-
-type AliEmbeddingResponse struct {
-	Output struct {
-		Embeddings []AliEmbedding `json:"embeddings"`
-	} `json:"output"`
-	Usage AliUsage `json:"usage"`
-	AliError
-}
-
-type AliError struct {
-	Code      string `json:"code"`
-	Message   string `json:"message"`
-	RequestId string `json:"request_id"`
-}
-
-type AliUsage struct {
-	InputTokens  int `json:"input_tokens"`
-	OutputTokens int `json:"output_tokens"`
-	TotalTokens  int `json:"total_tokens"`
-}
-
-type AliOutput struct {
-	Text         string `json:"text"`
-	FinishReason string `json:"finish_reason"`
-}
-
-type AliChatResponse struct {
-	Output AliOutput `json:"output"`
-	Usage  AliUsage  `json:"usage"`
-	AliError
-}
-
-func requestOpenAI2Ali(request GeneralOpenAIRequest) *AliChatRequest {
+func requestOpenAI2Ali(request dto.GeneralOpenAIRequest) *AliChatRequest {
 	messages := make([]AliMessage, 0, len(request.Messages))
 	prompt := ""
 	for i := 0; i < len(request.Messages); i++ {
@@ -119,7 +52,7 @@ func requestOpenAI2Ali(request GeneralOpenAIRequest) *AliChatRequest {
 		}
 	}
 
-func embeddingRequestOpenAI2Ali(request GeneralOpenAIRequest) *AliEmbeddingRequest {
+func embeddingRequestOpenAI2Ali(request dto.GeneralOpenAIRequest) *AliEmbeddingRequest {
 	return &AliEmbeddingRequest{
 		Model: "text-embedding-v1",
 		Input: struct {
@@ -130,21 +63,21 @@ func embeddingRequestOpenAI2Ali(request GeneralOpenAIRequest) *AliEmbeddingReque
 		}
 	}
 
-func aliEmbeddingHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) {
+func aliEmbeddingHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
 	var aliResponse AliEmbeddingResponse
 	err := json.NewDecoder(resp.Body).Decode(&aliResponse)
 	if err != nil {
-		return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
+		return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
 	}
 
 	err = resp.Body.Close()
 	if err != nil {
-		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
+		return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
 	}
 
 	if aliResponse.Code != "" {
-		return &OpenAIErrorWithStatusCode{
-			OpenAIError: OpenAIError{
+		return &dto.OpenAIErrorWithStatusCode{
+			OpenAIError: dto.OpenAIError{
 				Message: aliResponse.Message,
 				Type:    aliResponse.Code,
 				Param:   aliResponse.RequestId,
@@ -157,7 +90,7 @@ func aliEmbeddingHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithS
 	fullTextResponse := embeddingResponseAli2OpenAI(&aliResponse)
 	jsonResponse, err := json.Marshal(fullTextResponse)
 	if err != nil {
-		return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
+		return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
 	}
 	c.Writer.Header().Set("Content-Type", "application/json")
 	c.Writer.WriteHeader(resp.StatusCode)
@@ -165,16 +98,16 @@ func aliEmbeddingHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithS
 	return nil, &fullTextResponse.Usage
 }
 
-func embeddingResponseAli2OpenAI(response *AliEmbeddingResponse) *OpenAIEmbeddingResponse {
-	openAIEmbeddingResponse := OpenAIEmbeddingResponse{
+func embeddingResponseAli2OpenAI(response *AliEmbeddingResponse) *dto.OpenAIEmbeddingResponse {
+	openAIEmbeddingResponse := dto.OpenAIEmbeddingResponse{
 		Object: "list",
-		Data:   make([]OpenAIEmbeddingResponseItem, 0, len(response.Output.Embeddings)),
+		Data:   make([]dto.OpenAIEmbeddingResponseItem, 0, len(response.Output.Embeddings)),
 		Model:  "text-embedding-v1",
-		Usage:  Usage{TotalTokens: response.Usage.TotalTokens},
+		Usage:  dto.Usage{TotalTokens: response.Usage.TotalTokens},
 	}
 
 	for _, item := range response.Output.Embeddings {
-		openAIEmbeddingResponse.Data = append(openAIEmbeddingResponse.Data, OpenAIEmbeddingResponseItem{
+		openAIEmbeddingResponse.Data = append(openAIEmbeddingResponse.Data, dto.OpenAIEmbeddingResponseItem{
 			Object:    `embedding`,
 			Index:     item.TextIndex,
 			Embedding: item.Embedding,
@@ -183,22 +116,22 @@ func embeddingResponseAli2OpenAI(response *AliEmbeddingResponse) *OpenAIEmbeddin
 	return &openAIEmbeddingResponse
 }
 
-func responseAli2OpenAI(response *AliChatResponse) *OpenAITextResponse {
+func responseAli2OpenAI(response *AliChatResponse) *dto.OpenAITextResponse {
 	content, _ := json.Marshal(response.Output.Text)
-	choice := OpenAITextResponseChoice{
+	choice := dto.OpenAITextResponseChoice{
 		Index: 0,
-		Message: Message{
+		Message: dto.Message{
 			Role:    "assistant",
 			Content: content,
 		},
 		FinishReason: response.Output.FinishReason,
 	}
-	fullTextResponse := OpenAITextResponse{
+	fullTextResponse := dto.OpenAITextResponse{
 		Id:      response.RequestId,
 		Object:  "chat.completion",
 		Created: common.GetTimestamp(),
-		Choices: []OpenAITextResponseChoice{choice},
-		Usage: Usage{
+		Choices: []dto.OpenAITextResponseChoice{choice},
+		Usage: dto.Usage{
 			PromptTokens:     response.Usage.InputTokens,
 			CompletionTokens: response.Usage.OutputTokens,
 			TotalTokens:      response.Usage.InputTokens + response.Usage.OutputTokens,
@@ -207,25 +140,25 @@ func responseAli2OpenAI(response *AliChatResponse) *OpenAITextResponse {
 	return &fullTextResponse
 }
 
-func streamResponseAli2OpenAI(aliResponse *AliChatResponse) *ChatCompletionsStreamResponse {
-	var choice ChatCompletionsStreamResponseChoice
+func streamResponseAli2OpenAI(aliResponse *AliChatResponse) *dto.ChatCompletionsStreamResponse {
+	var choice dto.ChatCompletionsStreamResponseChoice
 	choice.Delta.Content = aliResponse.Output.Text
 	if aliResponse.Output.FinishReason != "null" {
 		finishReason := aliResponse.Output.FinishReason
 		choice.FinishReason = &finishReason
 	}
-	response := ChatCompletionsStreamResponse{
+	response := dto.ChatCompletionsStreamResponse{
 		Id:      aliResponse.RequestId,
 		Object:  "chat.completion.chunk",
 		Created: common.GetTimestamp(),
 		Model:   "ernie-bot",
-		Choices: []ChatCompletionsStreamResponseChoice{choice},
+		Choices: []dto.ChatCompletionsStreamResponseChoice{choice},
 	}
 	return &response
 }
 
-func aliStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) {
-	var usage Usage
+func aliStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
+	var usage dto.Usage
 	scanner := bufio.NewScanner(resp.Body)
 	scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
 		if atEOF && len(data) == 0 {
@@ -255,7 +188,7 @@ func aliStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStat
 		}
 		stopChan <- true
 	}()
-	setEventStreamHeaders(c)
+	service.SetEventStreamHeaders(c)
 	lastResponseText := ""
 	c.Stream(func(w io.Writer) bool {
 		select {
@@ -288,28 +221,28 @@ func aliStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStat
 	})
 	err := resp.Body.Close()
 	if err != nil {
-		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
+		return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
 	}
 	return nil, &usage
 }
 
-func aliHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) {
+func aliHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
 	var aliResponse AliChatResponse
 	responseBody, err := io.ReadAll(resp.Body)
 	if err != nil {
-		return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
+		return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
 	}
 	err = resp.Body.Close()
 	if err != nil {
-		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
+		return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
 	}
 	err = json.Unmarshal(responseBody, &aliResponse)
 	if err != nil {
-		return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
+		return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
 	}
 	if aliResponse.Code != "" {
-		return &OpenAIErrorWithStatusCode{
-			OpenAIError: OpenAIError{
+		return &dto.OpenAIErrorWithStatusCode{
+			OpenAIError: dto.OpenAIError{
 				Message: aliResponse.Message,
 				Type:    aliResponse.Code,
 				Param:   aliResponse.RequestId,
@@ -321,7 +254,7 @@ func aliHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode
 	fullTextResponse := responseAli2OpenAI(&aliResponse)
 	jsonResponse, err := json.Marshal(fullTextResponse)
 	if err != nil {
-		return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
+		return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
 	}
 	c.Writer.Header().Set("Content-Type", "application/json")
 	c.Writer.WriteHeader(resp.StatusCode)
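DashScope streams the cumulative text generated so far rather than per-chunk deltas, which is why aliStreamHandler keeps lastResponseText. A sketch of the delta computation inside the elided c.Stream body (variable names mirror the handler; the exact committed code may differ):

    // Inside the data branch of c.Stream (sketch):
    var aliResponse AliChatResponse
    if err := json.Unmarshal([]byte(data), &aliResponse); err != nil {
        return true // skip malformed chunks
    }
    usage.PromptTokens = aliResponse.Usage.InputTokens
    usage.CompletionTokens = aliResponse.Usage.OutputTokens
    usage.TotalTokens = aliResponse.Usage.TotalTokens
    response := streamResponseAli2OpenAI(&aliResponse)
    // Ali returns the full text so far; emit only the new suffix.
    response.Choices[0].Delta.Content = strings.TrimPrefix(aliResponse.Output.Text, lastResponseText)
    lastResponseText = aliResponse.Output.Text
    // ...then marshal response and write it out as an SSE "data:" line.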
relay/channel/api_request.go (new file, 52 lines)
@@ -0,0 +1,52 @@
+package channel
+
+import (
+	"errors"
+	"fmt"
+	"github.com/gin-gonic/gin"
+	"io"
+	"net/http"
+	relaycommon "one-api/relay/common"
+	"one-api/service"
+)
+
+func SetupApiRequestHeader(info *relaycommon.RelayInfo, c *gin.Context, req *http.Request) {
+	req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
+	req.Header.Set("Accept", c.Request.Header.Get("Accept"))
+	if info.IsStream && c.Request.Header.Get("Accept") == "" {
+		req.Header.Set("Accept", "text/event-stream")
+	}
+}
+
+func DoApiRequest(a Adaptor, c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
+	fullRequestURL, err := a.GetRequestURL(info)
+	if err != nil {
+		return nil, fmt.Errorf("get request url failed: %w", err)
+	}
+	req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
+	if err != nil {
+		return nil, fmt.Errorf("new request failed: %w", err)
+	}
+	err = a.SetupRequestHeader(c, req, info)
+	if err != nil {
+		return nil, fmt.Errorf("setup request header failed: %w", err)
+	}
+	resp, err := doRequest(c, req)
+	if err != nil {
+		return nil, fmt.Errorf("do request failed: %w", err)
+	}
+	return resp, nil
+}
+
+func doRequest(c *gin.Context, req *http.Request) (*http.Response, error) {
+	resp, err := service.GetHttpClient().Do(req)
+	if err != nil {
+		return nil, err
+	}
+	if resp == nil {
+		return nil, errors.New("resp is nil")
+	}
+	_ = req.Body.Close()
+	_ = c.Request.Body.Close()
+	return resp, nil
+}
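doRequest goes through service.GetHttpClient rather than http.DefaultClient, presumably so every channel shares one pooled client and any relay-wide timeout. The service side is not in this hunk; a minimal sketch under that assumption:

    package service

    import "net/http"

    // Sketch only: a single shared client; the real implementation may
    // configure a timeout or proxy from common config.
    var httpClient = &http.Client{}

    func GetHttpClient() *http.Client {
        return httpClient
    }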
relay/channel/baidu/adaptor.go (new file, 92 lines)
@@ -0,0 +1,92 @@
+package baidu
+
+import (
+	"errors"
+	"github.com/gin-gonic/gin"
+	"io"
+	"net/http"
+	"one-api/dto"
+	relaychannel "one-api/relay/channel"
+	relaycommon "one-api/relay/common"
+	"one-api/relay/constant"
+)
+
+type Adaptor struct {
+}
+
+func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
+
+}
+
+func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
+	var fullRequestURL string
+	switch info.UpstreamModelName {
+	case "ERNIE-Bot-4":
+		fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro"
+	case "ERNIE-Bot-8K":
+		fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie_bot_8k"
+	case "ERNIE-Bot":
+		fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions"
+	case "ERNIE-Speed":
+		fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie_speed"
+	case "ERNIE-Bot-turbo":
+		fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/eb-instant"
+	case "BLOOMZ-7B":
+		fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/bloomz_7b1"
+	case "Embedding-V1":
+		fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/embeddings/embedding-v1"
+	}
+	var accessToken string
+	var err error
+	if accessToken, err = getBaiduAccessToken(info.ApiKey); err != nil {
+		return "", err
+	}
+	fullRequestURL += "?access_token=" + accessToken
+	return fullRequestURL, nil
+}
+
+func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
+	relaychannel.SetupApiRequestHeader(info, c, req)
+	req.Header.Set("Authorization", "Bearer "+info.ApiKey)
+	return nil
+}
+
+func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) {
+	if request == nil {
+		return nil, errors.New("request is nil")
+	}
+	switch relayMode {
+	case constant.RelayModeEmbeddings:
+		baiduEmbeddingRequest := embeddingRequestOpenAI2Baidu(*request)
+		return baiduEmbeddingRequest, nil
+	default:
+		baiduRequest := requestOpenAI2Baidu(*request)
+		return baiduRequest, nil
+	}
+}
+
+func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
+	return relaychannel.DoApiRequest(a, c, info, requestBody)
+}
+
+func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
+	if info.IsStream {
+		err, usage = baiduStreamHandler(c, resp)
+	} else {
+		switch info.RelayMode {
+		case constant.RelayModeEmbeddings:
+			err, usage = baiduEmbeddingHandler(c, resp)
+		default:
+			err, usage = baiduHandler(c, resp)
+		}
+	}
+	return
+}
+
+func (a *Adaptor) GetModelList() []string {
+	return ModelList
+}
+
+func (a *Adaptor) GetChannelName() string {
+	return ChannelName
+}
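GetRequestURL authenticates by appending ?access_token=..., obtained through getBaiduAccessToken from the channel key. That helper sits outside this hunk; given that BaiduAccessToken (in the dto file further below) carries ExpiresAt, and that the relay file imports sync and time, it is presumably a cached token exchange, roughly:

    // Sketch of a cached Baidu token lookup keyed by the channel API key.
    // fetchBaiduAccessToken is a hypothetical stand-in for the real OAuth
    // call to aip.baidubce.com.
    var baiduTokenStore sync.Map

    func getBaiduAccessToken(apiKey string) (string, error) {
        if v, ok := baiduTokenStore.Load(apiKey); ok {
            token := v.(BaiduAccessToken)
            if time.Now().Before(token.ExpiresAt) {
                return token.AccessToken, nil
            }
        }
        token, err := fetchBaiduAccessToken(apiKey) // hypothetical helper
        if err != nil {
            return "", err
        }
        baiduTokenStore.Store(apiKey, *token)
        return token.AccessToken, nil
    }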
relay/channel/baidu/constants.go (new file, 12 lines)
@@ -0,0 +1,12 @@
+package baidu
+
+var ModelList = []string{
+	"ERNIE-Bot-4",
+	"ERNIE-Bot-8K",
+	"ERNIE-Bot",
+	"ERNIE-Speed",
+	"ERNIE-Bot-turbo",
+	"Embedding-V1",
+}
+
+var ChannelName = "baidu"
relay/channel/baidu/dto.go (new file, 71 lines)
@@ -0,0 +1,71 @@
+package baidu
+
+import (
+	"one-api/dto"
+	"time"
+)
+
+type BaiduMessage struct {
+	Role    string `json:"role"`
+	Content string `json:"content"`
+}
+
+type BaiduChatRequest struct {
+	Messages []BaiduMessage `json:"messages"`
+	Stream   bool           `json:"stream"`
+	UserId   string         `json:"user_id,omitempty"`
+}
+
+type Error struct {
+	ErrorCode int    `json:"error_code"`
+	ErrorMsg  string `json:"error_msg"`
+}
+
+type BaiduChatResponse struct {
+	Id               string    `json:"id"`
+	Object           string    `json:"object"`
+	Created          int64     `json:"created"`
+	Result           string    `json:"result"`
+	IsTruncated      bool      `json:"is_truncated"`
+	NeedClearHistory bool      `json:"need_clear_history"`
+	Usage            dto.Usage `json:"usage"`
+	Error
+}
+
+type BaiduChatStreamResponse struct {
+	BaiduChatResponse
+	SentenceId int  `json:"sentence_id"`
+	IsEnd      bool `json:"is_end"`
+}
+
+type BaiduEmbeddingRequest struct {
+	Input []string `json:"input"`
+}
+
+type BaiduEmbeddingData struct {
+	Object    string    `json:"object"`
+	Embedding []float64 `json:"embedding"`
+	Index     int       `json:"index"`
+}
+
+type BaiduEmbeddingResponse struct {
+	Id      string               `json:"id"`
+	Object  string               `json:"object"`
+	Created int64                `json:"created"`
+	Data    []BaiduEmbeddingData `json:"data"`
+	Usage   dto.Usage            `json:"usage"`
+	Error
+}
+
+type BaiduAccessToken struct {
+	AccessToken      string    `json:"access_token"`
+	Error            string    `json:"error,omitempty"`
+	ErrorDescription string    `json:"error_description,omitempty"`
+	ExpiresIn        int64     `json:"expires_in,omitempty"`
+	ExpiresAt        time.Time `json:"-"`
+}
+
+type BaiduTokenResponse struct {
+	ExpiresIn   int    `json:"expires_in"`
+	AccessToken string `json:"access_token"`
+}
@@ -1,4 +1,4 @@
-package controller
+package baidu
 
 import (
 	"bufio"
@@ -9,6 +9,9 @@ import (
 	"io"
 	"net/http"
 	"one-api/common"
+	"one-api/dto"
+	relaycommon "one-api/relay/common"
+	"one-api/service"
 	"strings"
 	"sync"
 	"time"
@@ -16,74 +19,9 @@ import (
 
 // https://cloud.baidu.com/doc/WENXINWORKSHOP/s/flfmc9do2
 
-type BaiduTokenResponse struct {
-	ExpiresIn   int    `json:"expires_in"`
-	AccessToken string `json:"access_token"`
-}
-
-type BaiduMessage struct {
-	Role    string `json:"role"`
-	Content string `json:"content"`
-}
-
-type BaiduChatRequest struct {
-	Messages []BaiduMessage `json:"messages"`
-	Stream   bool           `json:"stream"`
-	UserId   string         `json:"user_id,omitempty"`
-}
-
-type BaiduError struct {
-	ErrorCode int    `json:"error_code"`
-	ErrorMsg  string `json:"error_msg"`
-}
-
-type BaiduChatResponse struct {
-	Id               string `json:"id"`
-	Object           string `json:"object"`
-	Created          int64  `json:"created"`
-	Result           string `json:"result"`
-	IsTruncated      bool   `json:"is_truncated"`
-	NeedClearHistory bool   `json:"need_clear_history"`
-	Usage            Usage  `json:"usage"`
-	BaiduError
-}
-
-type BaiduChatStreamResponse struct {
-	BaiduChatResponse
-	SentenceId int  `json:"sentence_id"`
-	IsEnd      bool `json:"is_end"`
-}
-
-type BaiduEmbeddingRequest struct {
-	Input []string `json:"input"`
-}
-
-type BaiduEmbeddingData struct {
-	Object    string    `json:"object"`
-	Embedding []float64 `json:"embedding"`
-	Index     int       `json:"index"`
-}
-
-type BaiduEmbeddingResponse struct {
-	Id      string               `json:"id"`
-	Object  string               `json:"object"`
-	Created int64                `json:"created"`
-	Data    []BaiduEmbeddingData `json:"data"`
-	Usage   Usage                `json:"usage"`
-	BaiduError
-}
-
-type BaiduAccessToken struct {
-	AccessToken      string    `json:"access_token"`
-	Error            string    `json:"error,omitempty"`
-	ErrorDescription string    `json:"error_description,omitempty"`
-	ExpiresIn        int64     `json:"expires_in,omitempty"`
-	ExpiresAt        time.Time `json:"-"`
-}
-
 var baiduTokenStore sync.Map
 
-func requestOpenAI2Baidu(request GeneralOpenAIRequest) *BaiduChatRequest {
+func requestOpenAI2Baidu(request dto.GeneralOpenAIRequest) *BaiduChatRequest {
 	messages := make([]BaiduMessage, 0, len(request.Messages))
 	for _, message := range request.Messages {
 		if message.Role == "system" {
@@ -108,57 +46,57 @@ func requestOpenAI2Baidu(request GeneralOpenAIRequest) *BaiduChatRequest {
 		}
 	}
 
-func responseBaidu2OpenAI(response *BaiduChatResponse) *OpenAITextResponse {
+func responseBaidu2OpenAI(response *BaiduChatResponse) *dto.OpenAITextResponse {
 	content, _ := json.Marshal(response.Result)
-	choice := OpenAITextResponseChoice{
+	choice := dto.OpenAITextResponseChoice{
 		Index: 0,
-		Message: Message{
+		Message: dto.Message{
 			Role:    "assistant",
 			Content: content,
 		},
 		FinishReason: "stop",
 	}
-	fullTextResponse := OpenAITextResponse{
+	fullTextResponse := dto.OpenAITextResponse{
 		Id:      response.Id,
 		Object:  "chat.completion",
 		Created: response.Created,
-		Choices: []OpenAITextResponseChoice{choice},
+		Choices: []dto.OpenAITextResponseChoice{choice},
 		Usage:   response.Usage,
 	}
 	return &fullTextResponse
 }
 
-func streamResponseBaidu2OpenAI(baiduResponse *BaiduChatStreamResponse) *ChatCompletionsStreamResponse {
-	var choice ChatCompletionsStreamResponseChoice
+func streamResponseBaidu2OpenAI(baiduResponse *BaiduChatStreamResponse) *dto.ChatCompletionsStreamResponse {
+	var choice dto.ChatCompletionsStreamResponseChoice
 	choice.Delta.Content = baiduResponse.Result
 	if baiduResponse.IsEnd {
-		choice.FinishReason = &stopFinishReason
+		choice.FinishReason = &relaycommon.StopFinishReason
 	}
-	response := ChatCompletionsStreamResponse{
+	response := dto.ChatCompletionsStreamResponse{
 		Id:      baiduResponse.Id,
 		Object:  "chat.completion.chunk",
 		Created: baiduResponse.Created,
 		Model:   "ernie-bot",
-		Choices: []ChatCompletionsStreamResponseChoice{choice},
+		Choices: []dto.ChatCompletionsStreamResponseChoice{choice},
 	}
 	return &response
 }
 
-func embeddingRequestOpenAI2Baidu(request GeneralOpenAIRequest) *BaiduEmbeddingRequest {
+func embeddingRequestOpenAI2Baidu(request dto.GeneralOpenAIRequest) *BaiduEmbeddingRequest {
 	return &BaiduEmbeddingRequest{
 		Input: request.ParseInput(),
 	}
 }
 
-func embeddingResponseBaidu2OpenAI(response *BaiduEmbeddingResponse) *OpenAIEmbeddingResponse {
-	openAIEmbeddingResponse := OpenAIEmbeddingResponse{
+func embeddingResponseBaidu2OpenAI(response *BaiduEmbeddingResponse) *dto.OpenAIEmbeddingResponse {
+	openAIEmbeddingResponse := dto.OpenAIEmbeddingResponse{
 		Object: "list",
-		Data:   make([]OpenAIEmbeddingResponseItem, 0, len(response.Data)),
+		Data:   make([]dto.OpenAIEmbeddingResponseItem, 0, len(response.Data)),
 		Model:  "baidu-embedding",
 		Usage:  response.Usage,
 	}
 	for _, item := range response.Data {
-		openAIEmbeddingResponse.Data = append(openAIEmbeddingResponse.Data, OpenAIEmbeddingResponseItem{
+		openAIEmbeddingResponse.Data = append(openAIEmbeddingResponse.Data, dto.OpenAIEmbeddingResponseItem{
 			Object:    item.Object,
 			Index:     item.Index,
 			Embedding: item.Embedding,
@@ -167,8 +105,8 @@ func embeddingResponseBaidu2OpenAI(response *BaiduEmbeddingResponse) *OpenAIEmbe
 	return &openAIEmbeddingResponse
 }
 
-func baiduStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) {
-	var usage Usage
+func baiduStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
+	var usage dto.Usage
 	scanner := bufio.NewScanner(resp.Body)
 	scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
 		if atEOF && len(data) == 0 {
@@ -195,7 +133,7 @@ func baiduStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithSt
 		}
 		stopChan <- true
 	}()
-	setEventStreamHeaders(c)
+	service.SetEventStreamHeaders(c)
 	c.Stream(func(w io.Writer) bool {
 		select {
 		case data := <-dataChan:
@@ -225,28 +163,28 @@ func baiduStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithSt
 	})
 	err := resp.Body.Close()
 	if err != nil {
-		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
+		return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
 	}
 	return nil, &usage
 }
 
-func baiduHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) {
+func baiduHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
 	var baiduResponse BaiduChatResponse
 	responseBody, err := io.ReadAll(resp.Body)
 	if err != nil {
-		return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
+		return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
 	}
 	err = resp.Body.Close()
 	if err != nil {
-		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
+		return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
 	}
 	err = json.Unmarshal(responseBody, &baiduResponse)
 	if err != nil {
-		return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
+		return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
 	}
 	if baiduResponse.ErrorMsg != "" {
-		return &OpenAIErrorWithStatusCode{
-			OpenAIError: OpenAIError{
+		return &dto.OpenAIErrorWithStatusCode{
+			OpenAIError: dto.OpenAIError{
 				Message: baiduResponse.ErrorMsg,
 				Type:    "baidu_error",
 				Param:   "",
@@ -258,7 +196,7 @@ func baiduHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCo
 	fullTextResponse := responseBaidu2OpenAI(&baiduResponse)
 	jsonResponse, err := json.Marshal(fullTextResponse)
 	if err != nil {
-		return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
+		return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
 	}
 	c.Writer.Header().Set("Content-Type", "application/json")
 	c.Writer.WriteHeader(resp.StatusCode)
@@ -266,23 +204,23 @@ func baiduHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCo
 	return nil, &fullTextResponse.Usage
 }
 
-func baiduEmbeddingHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) {
+func baiduEmbeddingHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
 	var baiduResponse BaiduEmbeddingResponse
 	responseBody, err := io.ReadAll(resp.Body)
 	if err != nil {
-		return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
+		return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
 	}
 	err = resp.Body.Close()
 	if err != nil {
-		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
+		return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
 	}
 	err = json.Unmarshal(responseBody, &baiduResponse)
 	if err != nil {
-		return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
+		return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
 	}
 	if baiduResponse.ErrorMsg != "" {
-		return &OpenAIErrorWithStatusCode{
-			OpenAIError: OpenAIError{
+		return &dto.OpenAIErrorWithStatusCode{
+			OpenAIError: dto.OpenAIError{
 				Message: baiduResponse.ErrorMsg,
 				Type:    "baidu_error",
 				Param:   "",
@@ -294,7 +232,7 @@ func baiduEmbeddingHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWit
 	fullTextResponse := embeddingResponseBaidu2OpenAI(&baiduResponse)
 	jsonResponse, err := json.Marshal(fullTextResponse)
 	if err != nil {
-		return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
+		return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
 	}
 	c.Writer.Header().Set("Content-Type", "application/json")
 	c.Writer.WriteHeader(resp.StatusCode)
@@ -337,7 +275,7 @@ func getBaiduAccessTokenHelper(apiKey string) (*BaiduAccessToken, error) {
 	}
 	req.Header.Add("Content-Type", "application/json")
 	req.Header.Add("Accept", "application/json")
-	res, err := impatientHTTPClient.Do(req)
+	res, err := service.GetImpatientHttpClient().Do(req)
 	if err != nil {
 		return nil, err
 	}
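
The baiduTokenStore above is a sync.Map keyed by API key; a minimal sketch of the caching pattern it supports (the fetch callback and the one-hour early-renewal margin are illustrative assumptions, not taken from this commit):

package main

import (
	"fmt"
	"sync"
	"time"
)

type cachedToken struct {
	AccessToken string
	ExpiresAt   time.Time
}

var tokenStore sync.Map

// getToken returns a cached token for apiKey, refreshing it via fetch when
// the cached copy is missing or within an hour of expiry.
func getToken(apiKey string, fetch func(string) (cachedToken, error)) (string, error) {
	if v, ok := tokenStore.Load(apiKey); ok {
		t := v.(cachedToken)
		if time.Now().Add(time.Hour).Before(t.ExpiresAt) {
			return t.AccessToken, nil
		}
	}
	t, err := fetch(apiKey)
	if err != nil {
		return "", err
	}
	tokenStore.Store(apiKey, t)
	return t.AccessToken, nil
}

func main() {
	fake := func(string) (cachedToken, error) {
		return cachedToken{AccessToken: "24.abc123", ExpiresAt: time.Now().Add(720 * time.Hour)}, nil
	}
	tok, _ := getToken("client_id|client_secret", fake)
	fmt.Println(tok)
	tok, _ = getToken("client_id|client_secret", fake) // second call hits the cache
	fmt.Println(tok)
}
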
relay/channel/claude/adaptor.go (new file, 65 lines)
@@ -0,0 +1,65 @@
+package claude
+
+import (
+	"errors"
+	"fmt"
+	"github.com/gin-gonic/gin"
+	"io"
+	"net/http"
+	"one-api/dto"
+	relaychannel "one-api/relay/channel"
+	relaycommon "one-api/relay/common"
+	"one-api/service"
+)
+
+type Adaptor struct {
+}
+
+func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
+
+}
+
+func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
+	return fmt.Sprintf("%s/v1/complete", info.BaseUrl), nil
+}
+
+func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
+	relaychannel.SetupApiRequestHeader(info, c, req)
+	req.Header.Set("x-api-key", info.ApiKey)
+	anthropicVersion := c.Request.Header.Get("anthropic-version")
+	if anthropicVersion == "" {
+		anthropicVersion = "2023-06-01"
+	}
+	req.Header.Set("anthropic-version", anthropicVersion)
+	return nil
+}
+
+func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) {
+	if request == nil {
+		return nil, errors.New("request is nil")
+	}
+	return request, nil
+}
+
+func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
+	return relaychannel.DoApiRequest(a, c, info, requestBody)
+}
+
+func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
+	if info.IsStream {
+		var responseText string
+		err, responseText = claudeStreamHandler(c, resp)
+		usage = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
+	} else {
+		err, usage = claudeHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
+	}
+	return
+}
+
+func (a *Adaptor) GetModelList() []string {
+	return ModelList
+}
+
+func (a *Adaptor) GetChannelName() string {
+	return ChannelName
+}
relay/channel/claude/constants.go (new file, 7 lines)
@@ -0,0 +1,7 @@
+package claude
+
+var ModelList = []string{
+	"claude-instant-1", "claude-2", "claude-2.0", "claude-2.1",
+}
+
+var ChannelName = "claude"
relay/channel/claude/dto.go (new file, 29 lines)
@@ -0,0 +1,29 @@
+package claude
+
+type ClaudeMetadata struct {
+	UserId string `json:"user_id"`
+}
+
+type ClaudeRequest struct {
+	Model             string   `json:"model"`
+	Prompt            string   `json:"prompt"`
+	MaxTokensToSample uint     `json:"max_tokens_to_sample"`
+	StopSequences     []string `json:"stop_sequences,omitempty"`
+	Temperature       float64  `json:"temperature,omitempty"`
+	TopP              float64  `json:"top_p,omitempty"`
+	TopK              int      `json:"top_k,omitempty"`
+	//ClaudeMetadata    `json:"metadata,omitempty"`
+	Stream bool `json:"stream,omitempty"`
+}
+
+type ClaudeError struct {
+	Type    string `json:"type"`
+	Message string `json:"message"`
+}
+
+type ClaudeResponse struct {
+	Completion string      `json:"completion"`
+	StopReason string      `json:"stop_reason"`
+	Model      string      `json:"model"`
+	Error      ClaudeError `json:"error"`
+}
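
ClaudeRequest.Prompt is a single string rather than a message array, so the converter in the diff below has to flatten OpenAI-style messages into Anthropic's Human/Assistant transcript. A hedged sketch of that flattening (the role prefixes are an assumption based on the legacy /v1/complete API, since the hunk that builds the prompt is elided below):

package main

import "fmt"

type message struct{ Role, Content string }

// flattenPrompt is a hypothetical stand-in for the prompt-building loop in
// requestOpenAI2Claude; it is not code from this commit.
func flattenPrompt(messages []message) string {
	prompt := ""
	for _, m := range messages {
		switch m.Role {
		case "user", "system":
			prompt += fmt.Sprintf("\n\nHuman: %s", m.Content)
		case "assistant":
			prompt += fmt.Sprintf("\n\nAssistant: %s", m.Content)
		}
	}
	return prompt + "\n\nAssistant:" // the model completes from here
}

func main() {
	fmt.Println(flattenPrompt([]message{{"user", "Hello"}}))
}
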
@@ -1,4 +1,4 @@
-package controller
+package claude
 
 import (
 	"bufio"
@@ -8,37 +8,11 @@ import (
 	"io"
 	"net/http"
 	"one-api/common"
+	"one-api/dto"
+	"one-api/service"
 	"strings"
 )
 
-type ClaudeMetadata struct {
-	UserId string `json:"user_id"`
-}
-
-type ClaudeRequest struct {
-	Model             string   `json:"model"`
-	Prompt            string   `json:"prompt"`
-	MaxTokensToSample uint     `json:"max_tokens_to_sample"`
-	StopSequences     []string `json:"stop_sequences,omitempty"`
-	Temperature       float64  `json:"temperature,omitempty"`
-	TopP              float64  `json:"top_p,omitempty"`
-	TopK              int      `json:"top_k,omitempty"`
-	//ClaudeMetadata    `json:"metadata,omitempty"`
-	Stream bool `json:"stream,omitempty"`
-}
-
-type ClaudeError struct {
-	Type    string `json:"type"`
-	Message string `json:"message"`
-}
-
-type ClaudeResponse struct {
-	Completion string      `json:"completion"`
-	StopReason string      `json:"stop_reason"`
-	Model      string      `json:"model"`
-	Error      ClaudeError `json:"error"`
-}
-
 func stopReasonClaude2OpenAI(reason string) string {
 	switch reason {
 	case "stop_sequence":
@@ -50,7 +24,7 @@ func stopReasonClaude2OpenAI(reason string) string {
 	}
 }
 
-func requestOpenAI2Claude(textRequest GeneralOpenAIRequest) *ClaudeRequest {
+func requestOpenAI2Claude(textRequest dto.GeneralOpenAIRequest) *ClaudeRequest {
 	claudeRequest := ClaudeRequest{
 		Model:  textRequest.Model,
 		Prompt: "",
@@ -78,41 +52,41 @@ func requestOpenAI2Claude(textRequest GeneralOpenAIRequest) *ClaudeRequest {
 	return &claudeRequest
 }
 
-func streamResponseClaude2OpenAI(claudeResponse *ClaudeResponse) *ChatCompletionsStreamResponse {
-	var choice ChatCompletionsStreamResponseChoice
+func streamResponseClaude2OpenAI(claudeResponse *ClaudeResponse) *dto.ChatCompletionsStreamResponse {
+	var choice dto.ChatCompletionsStreamResponseChoice
 	choice.Delta.Content = claudeResponse.Completion
 	finishReason := stopReasonClaude2OpenAI(claudeResponse.StopReason)
 	if finishReason != "null" {
 		choice.FinishReason = &finishReason
 	}
-	var response ChatCompletionsStreamResponse
+	var response dto.ChatCompletionsStreamResponse
 	response.Object = "chat.completion.chunk"
 	response.Model = claudeResponse.Model
-	response.Choices = []ChatCompletionsStreamResponseChoice{choice}
+	response.Choices = []dto.ChatCompletionsStreamResponseChoice{choice}
 	return &response
 }
 
-func responseClaude2OpenAI(claudeResponse *ClaudeResponse) *OpenAITextResponse {
+func responseClaude2OpenAI(claudeResponse *ClaudeResponse) *dto.OpenAITextResponse {
 	content, _ := json.Marshal(strings.TrimPrefix(claudeResponse.Completion, " "))
-	choice := OpenAITextResponseChoice{
+	choice := dto.OpenAITextResponseChoice{
 		Index: 0,
-		Message: Message{
+		Message: dto.Message{
 			Role:    "assistant",
 			Content: content,
 			Name:    nil,
 		},
 		FinishReason: stopReasonClaude2OpenAI(claudeResponse.StopReason),
 	}
-	fullTextResponse := OpenAITextResponse{
+	fullTextResponse := dto.OpenAITextResponse{
 		Id:      fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
 		Object:  "chat.completion",
 		Created: common.GetTimestamp(),
-		Choices: []OpenAITextResponseChoice{choice},
+		Choices: []dto.OpenAITextResponseChoice{choice},
 	}
 	return &fullTextResponse
 }
 
-func claudeStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, string) {
+func claudeStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, string) {
 	responseText := ""
 	responseId := fmt.Sprintf("chatcmpl-%s", common.GetUUID())
 	createdTime := common.GetTimestamp()
@@ -142,7 +116,7 @@ func claudeStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithS
 		}
 		stopChan <- true
 	}()
-	setEventStreamHeaders(c)
+	service.SetEventStreamHeaders(c)
 	c.Stream(func(w io.Writer) bool {
 		select {
 		case data := <-dataChan:
@@ -172,28 +146,28 @@ func claudeStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithS
 	})
 	err := resp.Body.Close()
 	if err != nil {
-		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), ""
+		return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), ""
 	}
 	return nil, responseText
 }
 
-func claudeHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*OpenAIErrorWithStatusCode, *Usage) {
+func claudeHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
 	responseBody, err := io.ReadAll(resp.Body)
 	if err != nil {
-		return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
+		return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
 	}
 	err = resp.Body.Close()
 	if err != nil {
-		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
+		return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
 	}
 	var claudeResponse ClaudeResponse
 	err = json.Unmarshal(responseBody, &claudeResponse)
 	if err != nil {
-		return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
+		return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
 	}
 	if claudeResponse.Error.Type != "" {
-		return &OpenAIErrorWithStatusCode{
-			OpenAIError: OpenAIError{
+		return &dto.OpenAIErrorWithStatusCode{
+			OpenAIError: dto.OpenAIError{
 				Message: claudeResponse.Error.Message,
 				Type:    claudeResponse.Error.Type,
 				Param:   "",
@@ -203,8 +177,8 @@ func claudeHandler(c *gin.Context, resp *http.Response, promptTokens int, model
 		}, nil
 	}
 	fullTextResponse := responseClaude2OpenAI(&claudeResponse)
-	completionTokens := countTokenText(claudeResponse.Completion, model)
-	usage := Usage{
+	completionTokens := service.CountTokenText(claudeResponse.Completion, model)
+	usage := dto.Usage{
 		PromptTokens:     promptTokens,
 		CompletionTokens: completionTokens,
 		TotalTokens:      promptTokens + completionTokens,
@@ -212,7 +186,7 @@ func claudeHandler(c *gin.Context, resp *http.Response, promptTokens int, model
 	fullTextResponse.Usage = usage
 	jsonResponse, err := json.Marshal(fullTextResponse)
 	if err != nil {
-		return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
+		return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
 	}
 	c.Writer.Header().Set("Content-Type", "application/json")
 	c.Writer.WriteHeader(resp.StatusCode)
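
For reference, stopReasonClaude2OpenAI translates Anthropic's stop-reason vocabulary to OpenAI finish_reason values; a sketch under that assumption (only the "stop_sequence" case is visible in the hunks above, the "max_tokens" -> "length" case is inferred from the OpenAI finish_reason vocabulary):

package claudesketch

// stopReasonClaude2OpenAISketch mirrors what a mapping like
// stopReasonClaude2OpenAI plausibly does; it is illustrative, not this
// commit's code.
func stopReasonClaude2OpenAISketch(reason string) string {
	switch reason {
	case "stop_sequence":
		return "stop"
	case "max_tokens":
		return "length"
	default:
		return reason
	}
}
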
relay/channel/gemini/adaptor.go (new file, 64 lines)
@@ -0,0 +1,64 @@
+package gemini
+
+import (
+	"errors"
+	"fmt"
+	"github.com/gin-gonic/gin"
+	"io"
+	"net/http"
+	"one-api/dto"
+	relaychannel "one-api/relay/channel"
+	relaycommon "one-api/relay/common"
+	"one-api/service"
+)
+
+type Adaptor struct {
+}
+
+func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
+}
+
+func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
+	version := "v1"
+	action := "generateContent"
+	if info.IsStream {
+		action = "streamGenerateContent"
+	}
+	return fmt.Sprintf("%s/%s/models/%s:%s", info.BaseUrl, version, info.UpstreamModelName, action), nil
+}
+
+func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
+	relaychannel.SetupApiRequestHeader(info, c, req)
+	req.Header.Set("x-goog-api-key", info.ApiKey)
+	return nil
+}
+
+func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) {
+	if request == nil {
+		return nil, errors.New("request is nil")
+	}
+	return CovertGemini2OpenAI(*request), nil
+}
+
+func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
+	return relaychannel.DoApiRequest(a, c, info, requestBody)
+}
+
+func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
+	if info.IsStream {
+		var responseText string
+		err, responseText = geminiChatStreamHandler(c, resp)
+		usage = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
+	} else {
+		err, usage = geminiChatHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
+	}
+	return
+}
+
+func (a *Adaptor) GetModelList() []string {
+	return ModelList
+}
+
+func (a *Adaptor) GetChannelName() string {
+	return ChannelName
+}
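
GetRequestURL above switches the action segment on streaming; with a hypothetical base URL the two URL shapes look like this (the base value is a placeholder, only the format string comes from the code above):

package main

import "fmt"

func main() {
	base, version, model := "https://generativelanguage.googleapis.com", "v1", "gemini-pro"
	for _, action := range []string{"generateContent", "streamGenerateContent"} {
		fmt.Printf("%s/%s/models/%s:%s\n", base, version, model, action)
	}
}
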
relay/channel/gemini/constant.go (new file, 12 lines)
@@ -0,0 +1,12 @@
+package gemini
+
+const (
+	GeminiVisionMaxImageNum = 16
+)
+
+var ModelList = []string{
+	"gemini-pro",
+	"gemini-pro-vision",
+}
+
+var ChannelName = "google gemini"
relay/channel/gemini/dto.go (new file, 62 lines)
@@ -0,0 +1,62 @@
+package gemini
+
+type GeminiChatRequest struct {
+	Contents         []GeminiChatContent        `json:"contents"`
+	SafetySettings   []GeminiChatSafetySettings `json:"safety_settings,omitempty"`
+	GenerationConfig GeminiChatGenerationConfig `json:"generation_config,omitempty"`
+	Tools            []GeminiChatTools          `json:"tools,omitempty"`
+}
+
+type GeminiInlineData struct {
+	MimeType string `json:"mimeType"`
+	Data     string `json:"data"`
+}
+
+type GeminiPart struct {
+	Text       string            `json:"text,omitempty"`
+	InlineData *GeminiInlineData `json:"inlineData,omitempty"`
+}
+
+type GeminiChatContent struct {
+	Role  string       `json:"role,omitempty"`
+	Parts []GeminiPart `json:"parts"`
+}
+
+type GeminiChatSafetySettings struct {
+	Category  string `json:"category"`
+	Threshold string `json:"threshold"`
+}
+
+type GeminiChatTools struct {
+	FunctionDeclarations any `json:"functionDeclarations,omitempty"`
+}
+
+type GeminiChatGenerationConfig struct {
+	Temperature     float64  `json:"temperature,omitempty"`
+	TopP            float64  `json:"topP,omitempty"`
+	TopK            float64  `json:"topK,omitempty"`
+	MaxOutputTokens uint     `json:"maxOutputTokens,omitempty"`
+	CandidateCount  int      `json:"candidateCount,omitempty"`
+	StopSequences   []string `json:"stopSequences,omitempty"`
+}
+
+type GeminiChatCandidate struct {
+	Content       GeminiChatContent        `json:"content"`
+	FinishReason  string                   `json:"finishReason"`
+	Index         int64                    `json:"index"`
+	SafetyRatings []GeminiChatSafetyRating `json:"safetyRatings"`
+}
+
+type GeminiChatSafetyRating struct {
+	Category    string `json:"category"`
+	Probability string `json:"probability"`
+}
+
+type GeminiChatPromptFeedback struct {
+	SafetyRatings []GeminiChatSafetyRating `json:"safetyRatings"`
+}
+
+type GeminiChatResponse struct {
+	Candidates     []GeminiChatCandidate    `json:"candidates"`
+	PromptFeedback GeminiChatPromptFeedback `json:"promptFeedback"`
+}
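
GeminiPart carries either plain text or base64 inline data, which is how the converter in the diff below attaches images. A minimal sketch of building the two variants (values are placeholders; the lowercase local types mirror the DTOs above):

package main

import (
	"encoding/base64"
	"encoding/json"
	"fmt"
)

type geminiInlineData struct {
	MimeType string `json:"mimeType"`
	Data     string `json:"data"`
}

type geminiPart struct {
	Text       string            `json:"text,omitempty"`
	InlineData *geminiInlineData `json:"inlineData,omitempty"`
}

func main() {
	parts := []geminiPart{
		{Text: "Describe this image."},
		{InlineData: &geminiInlineData{
			MimeType: "image/png",
			Data:     base64.StdEncoding.EncodeToString([]byte("fake image bytes")),
		}},
	}
	out, _ := json.MarshalIndent(parts, "", "  ")
	fmt.Println(string(out))
}
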
@@ -1,4 +1,4 @@
-package controller
+package gemini
 
 import (
 	"bufio"
@@ -7,57 +7,16 @@ import (
 	"io"
 	"net/http"
 	"one-api/common"
+	"one-api/dto"
+	relaycommon "one-api/relay/common"
+	"one-api/service"
 	"strings"
 
 	"github.com/gin-gonic/gin"
 )
 
-const (
-	GeminiVisionMaxImageNum = 16
-)
-
-type GeminiChatRequest struct {
-	Contents         []GeminiChatContent        `json:"contents"`
-	SafetySettings   []GeminiChatSafetySettings `json:"safety_settings,omitempty"`
-	GenerationConfig GeminiChatGenerationConfig `json:"generation_config,omitempty"`
-	Tools            []GeminiChatTools          `json:"tools,omitempty"`
-}
-
-type GeminiInlineData struct {
-	MimeType string `json:"mimeType"`
-	Data     string `json:"data"`
-}
-
-type GeminiPart struct {
-	Text       string            `json:"text,omitempty"`
-	InlineData *GeminiInlineData `json:"inlineData,omitempty"`
-}
-
-type GeminiChatContent struct {
-	Role  string       `json:"role,omitempty"`
-	Parts []GeminiPart `json:"parts"`
-}
-
-type GeminiChatSafetySettings struct {
-	Category  string `json:"category"`
-	Threshold string `json:"threshold"`
-}
-
-type GeminiChatTools struct {
-	FunctionDeclarations any `json:"functionDeclarations,omitempty"`
-}
-
-type GeminiChatGenerationConfig struct {
-	Temperature     float64  `json:"temperature,omitempty"`
-	TopP            float64  `json:"topP,omitempty"`
-	TopK            float64  `json:"topK,omitempty"`
-	MaxOutputTokens uint     `json:"maxOutputTokens,omitempty"`
-	CandidateCount  int      `json:"candidateCount,omitempty"`
-	StopSequences   []string `json:"stopSequences,omitempty"`
-}
-
 // Setting safety to the lowest possible values since Gemini is already powerless enough
-func requestOpenAI2Gemini(textRequest GeneralOpenAIRequest) *GeminiChatRequest {
+func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest) *GeminiChatRequest {
 	geminiRequest := GeminiChatRequest{
 		Contents: make([]GeminiChatContent, 0, len(textRequest.Messages)),
 		SafetySettings: []GeminiChatSafetySettings{
@@ -106,16 +65,16 @@ func requestOpenAI2Gemini(textRequest GeneralOpenAIRequest) *GeminiChatRequest {
 			imageNum := 0
 			for _, part := range openaiContent {
 
-				if part.Type == ContentTypeText {
+				if part.Type == dto.ContentTypeText {
 					parts = append(parts, GeminiPart{
 						Text: part.Text,
 					})
-				} else if part.Type == ContentTypeImageURL {
+				} else if part.Type == dto.ContentTypeImageURL {
 					imageNum += 1
 					if imageNum > GeminiVisionMaxImageNum {
 						continue
 					}
-					mimeType, data, _ := common.GetImageFromUrl(part.ImageUrl.(MessageImageUrl).Url)
+					mimeType, data, _ := common.GetImageFromUrl(part.ImageUrl.(dto.MessageImageUrl).Url)
 					parts = append(parts, GeminiPart{
 						InlineData: &GeminiInlineData{
 							MimeType: mimeType,
@@ -154,11 +113,6 @@ func requestOpenAI2Gemini(textRequest GeneralOpenAIRequest) *GeminiChatRequest {
 	return &geminiRequest
 }
 
-type GeminiChatResponse struct {
-	Candidates     []GeminiChatCandidate    `json:"candidates"`
-	PromptFeedback GeminiChatPromptFeedback `json:"promptFeedback"`
-}
-
 func (g *GeminiChatResponse) GetResponseText() string {
 	if g == nil {
 		return ""
@@ -169,38 +123,22 @@ func (g *GeminiChatResponse) GetResponseText() string {
 	return ""
 }
 
-type GeminiChatCandidate struct {
-	Content       GeminiChatContent        `json:"content"`
-	FinishReason  string                   `json:"finishReason"`
-	Index         int64                    `json:"index"`
-	SafetyRatings []GeminiChatSafetyRating `json:"safetyRatings"`
-}
-
-type GeminiChatSafetyRating struct {
-	Category    string `json:"category"`
-	Probability string `json:"probability"`
-}
-
-type GeminiChatPromptFeedback struct {
-	SafetyRatings []GeminiChatSafetyRating `json:"safetyRatings"`
-}
-
-func responseGeminiChat2OpenAI(response *GeminiChatResponse) *OpenAITextResponse {
-	fullTextResponse := OpenAITextResponse{
+func responseGeminiChat2OpenAI(response *GeminiChatResponse) *dto.OpenAITextResponse {
+	fullTextResponse := dto.OpenAITextResponse{
 		Id:      fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
 		Object:  "chat.completion",
 		Created: common.GetTimestamp(),
-		Choices: make([]OpenAITextResponseChoice, 0, len(response.Candidates)),
+		Choices: make([]dto.OpenAITextResponseChoice, 0, len(response.Candidates)),
 	}
 	content, _ := json.Marshal("")
 	for i, candidate := range response.Candidates {
-		choice := OpenAITextResponseChoice{
+		choice := dto.OpenAITextResponseChoice{
 			Index: i,
-			Message: Message{
+			Message: dto.Message{
 				Role:    "assistant",
 				Content: content,
 			},
-			FinishReason: stopFinishReason,
+			FinishReason: relaycommon.StopFinishReason,
 		}
 		content, _ = json.Marshal(candidate.Content.Parts[0].Text)
 		if len(candidate.Content.Parts) > 0 {
@@ -211,18 +149,18 @@ func responseGeminiChat2OpenAI(response *GeminiChatResponse) *OpenAITextResponse
 	return &fullTextResponse
 }
 
-func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) *ChatCompletionsStreamResponse {
-	var choice ChatCompletionsStreamResponseChoice
+func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) *dto.ChatCompletionsStreamResponse {
+	var choice dto.ChatCompletionsStreamResponseChoice
 	choice.Delta.Content = geminiResponse.GetResponseText()
-	choice.FinishReason = &stopFinishReason
-	var response ChatCompletionsStreamResponse
+	choice.FinishReason = &relaycommon.StopFinishReason
+	var response dto.ChatCompletionsStreamResponse
 	response.Object = "chat.completion.chunk"
 	response.Model = "gemini"
-	response.Choices = []ChatCompletionsStreamResponseChoice{choice}
+	response.Choices = []dto.ChatCompletionsStreamResponseChoice{choice}
 	return &response
 }
 
-func geminiChatStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, string) {
+func geminiChatStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, string) {
 	responseText := ""
 	dataChan := make(chan string)
 	stopChan := make(chan bool)
@@ -252,7 +190,7 @@ func geminiChatStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorW
 		}
 		stopChan <- true
 	}()
-	setEventStreamHeaders(c)
+	service.SetEventStreamHeaders(c)
 	c.Stream(func(w io.Writer) bool {
 		select {
 		case data := <-dataChan:
@@ -264,14 +202,14 @@ func geminiChatStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorW
 			var dummy dummyStruct
 			err := json.Unmarshal([]byte(data), &dummy)
 			responseText += dummy.Content
-			var choice ChatCompletionsStreamResponseChoice
+			var choice dto.ChatCompletionsStreamResponseChoice
 			choice.Delta.Content = dummy.Content
-			response := ChatCompletionsStreamResponse{
+			response := dto.ChatCompletionsStreamResponse{
 				Id:      fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
 				Object:  "chat.completion.chunk",
 				Created: common.GetTimestamp(),
 				Model:   "gemini-pro",
-				Choices: []ChatCompletionsStreamResponseChoice{choice},
+				Choices: []dto.ChatCompletionsStreamResponseChoice{choice},
 			}
 			jsonResponse, err := json.Marshal(response)
 			if err != nil {
@@ -287,28 +225,28 @@ func geminiChatStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorW
 	})
 	err := resp.Body.Close()
 	if err != nil {
-		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), ""
+		return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), ""
 	}
 	return nil, responseText
 }
 
-func geminiChatHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*OpenAIErrorWithStatusCode, *Usage) {
+func geminiChatHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
 	responseBody, err := io.ReadAll(resp.Body)
 	if err != nil {
-		return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
+		return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
 	}
 	err = resp.Body.Close()
 	if err != nil {
-		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
+		return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
 	}
 	var geminiResponse GeminiChatResponse
 	err = json.Unmarshal(responseBody, &geminiResponse)
 	if err != nil {
-		return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
+		return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
 	}
 	if len(geminiResponse.Candidates) == 0 {
-		return &OpenAIErrorWithStatusCode{
-			OpenAIError: OpenAIError{
+		return &dto.OpenAIErrorWithStatusCode{
+			OpenAIError: dto.OpenAIError{
 				Message: "No candidates returned",
 				Type:    "server_error",
 				Param:   "",
@@ -318,8 +256,8 @@ func geminiChatHandler(c *gin.Context, resp *http.Response, promptTokens int, mo
 		}, nil
 	}
 	fullTextResponse := responseGeminiChat2OpenAI(&geminiResponse)
-	completionTokens := countTokenText(geminiResponse.GetResponseText(), model)
-	usage := Usage{
+	completionTokens := service.CountTokenText(geminiResponse.GetResponseText(), model)
+	usage := dto.Usage{
 		PromptTokens:     promptTokens,
 		CompletionTokens: completionTokens,
 		TotalTokens:      promptTokens + completionTokens,
@@ -327,7 +265,7 @@ func geminiChatHandler(c *gin.Context, resp *http.Response, promptTokens int, mo
 	fullTextResponse.Usage = usage
 	jsonResponse, err := json.Marshal(fullTextResponse)
 	if err != nil {
-		return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
+		return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
 	}
 	c.Writer.Header().Set("Content-Type", "application/json")
 	c.Writer.WriteHeader(resp.StatusCode)
relay/channel/moonshot/constants.go (new file, 7 lines)
@@ -0,0 +1,7 @@
+package moonshot
+
+var ModelList = []string{
+	"moonshot-v1-8k",
+	"moonshot-v1-32k",
+	"moonshot-v1-128k",
+}
relay/channel/openai/adaptor.go (new file, 84 lines)
@@ -0,0 +1,84 @@
+package openai
+
+import (
+	"errors"
+	"fmt"
+	"github.com/gin-gonic/gin"
+	"io"
+	"net/http"
+	"one-api/common"
+	"one-api/dto"
+	relaychannel "one-api/relay/channel"
+	relaycommon "one-api/relay/common"
+	"one-api/service"
+	"strings"
+)
+
+type Adaptor struct {
+}
+
+func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
+}
+
+func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
+	if info.ChannelType == common.ChannelTypeAzure {
+		// https://learn.microsoft.com/en-us/azure/cognitive-services/openai/chatgpt-quickstart?pivots=rest-api&tabs=command-line#rest-api
+		requestURL := strings.Split(info.RequestURLPath, "?")[0]
+		requestURL = fmt.Sprintf("%s?api-version=%s", requestURL, info.ApiVersion)
+		task := strings.TrimPrefix(requestURL, "/v1/")
+		model_ := info.UpstreamModelName
+		model_ = strings.Replace(model_, ".", "", -1)
+		// https://github.com/songquanpeng/one-api/issues/67
+		model_ = strings.TrimSuffix(model_, "-0301")
+		model_ = strings.TrimSuffix(model_, "-0314")
+		model_ = strings.TrimSuffix(model_, "-0613")
+
+		requestURL = fmt.Sprintf("/openai/deployments/%s/%s", model_, task)
+		return relaycommon.GetFullRequestURL(info.BaseUrl, requestURL, info.ChannelType), nil
+	}
+	return relaycommon.GetFullRequestURL(info.BaseUrl, info.RequestURLPath, info.ChannelType), nil
+}
+
+func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
+	relaychannel.SetupApiRequestHeader(info, c, req)
+	if info.ChannelType == common.ChannelTypeAzure {
+		req.Header.Set("api-key", info.ApiKey)
+		return nil
+	}
+	req.Header.Set("Authorization", "Bearer "+info.ApiKey)
+	if info.ChannelType == common.ChannelTypeOpenRouter {
+		req.Header.Set("HTTP-Referer", "https://github.com/songquanpeng/one-api")
+		req.Header.Set("X-Title", "One API")
+	}
+	return nil
+}
+
+func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) {
+	if request == nil {
+		return nil, errors.New("request is nil")
+	}
+	return request, nil
+}
+
+func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
+	return relaychannel.DoApiRequest(a, c, info, requestBody)
+}
+
+func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
+	if info.IsStream {
+		var responseText string
+		err, responseText = openaiStreamHandler(c, resp, info.RelayMode)
+		usage = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
+	} else {
+		err, usage = openaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
+	}
+	return
+}
+
+func (a *Adaptor) GetModelList() []string {
+	return ModelList
+}
+
+func (a *Adaptor) GetChannelName() string {
+	return ChannelName
+}
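
The Azure branch of GetRequestURL above rewrites the OpenAI path around a deployment name; a self-contained sketch of the same transformation (the api-version value is a placeholder):

package main

import (
	"fmt"
	"strings"
)

// azureURL reproduces the rewrite in GetRequestURL above: strip any query,
// append api-version, drop the /v1/ prefix, and key the path on the model
// name with dots removed (gpt-3.5-turbo -> gpt-35-turbo).
func azureURL(requestPath, apiVersion, model string) string {
	requestURL := strings.Split(requestPath, "?")[0]
	requestURL = fmt.Sprintf("%s?api-version=%s", requestURL, apiVersion)
	task := strings.TrimPrefix(requestURL, "/v1/")
	deployment := strings.ReplaceAll(model, ".", "")
	return fmt.Sprintf("/openai/deployments/%s/%s", deployment, task)
}

func main() {
	// -> /openai/deployments/gpt-35-turbo/chat/completions?api-version=2023-05-15
	fmt.Println(azureURL("/v1/chat/completions", "2023-05-15", "gpt-3.5-turbo"))
}
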
relay/channel/openai/constant.go (new file, 21 lines)
@@ -0,0 +1,21 @@
+package openai
+
+var ModelList = []string{
+	"gpt-3.5-turbo", "gpt-3.5-turbo-0301", "gpt-3.5-turbo-0613", "gpt-3.5-turbo-1106", "gpt-3.5-turbo-0125",
+	"gpt-3.5-turbo-16k", "gpt-3.5-turbo-16k-0613",
+	"gpt-3.5-turbo-instruct",
+	"gpt-4", "gpt-4-0314", "gpt-4-0613", "gpt-4-1106-preview", "gpt-4-0125-preview",
+	"gpt-4-32k", "gpt-4-32k-0314", "gpt-4-32k-0613",
+	"gpt-4-turbo-preview",
+	"gpt-4-vision-preview",
+	"text-embedding-ada-002", "text-embedding-3-small", "text-embedding-3-large",
+	"text-curie-001", "text-babbage-001", "text-ada-001", "text-davinci-002", "text-davinci-003",
+	"text-moderation-latest", "text-moderation-stable",
+	"text-davinci-edit-001",
+	"davinci-002", "babbage-002",
+	"dall-e-2", "dall-e-3",
+	"whisper-1",
+	"tts-1", "tts-1-1106", "tts-1-hd", "tts-1-hd-1106",
+}
+
+var ChannelName = "openai"
@@ -1,4 +1,4 @@
-package controller
+package openai

 import (
 	"bufio"
@@ -8,12 +8,15 @@ import (
 	"io"
 	"net/http"
 	"one-api/common"
+	"one-api/dto"
+	relayconstant "one-api/relay/constant"
+	"one-api/service"
 	"strings"
 	"sync"
 	"time"
 )

-func openaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*OpenAIErrorWithStatusCode, string) {
+func openaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*dto.OpenAIErrorWithStatusCode, string) {
 	var responseTextBuilder strings.Builder
 	scanner := bufio.NewScanner(resp.Body)
 	scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
@@ -54,8 +57,8 @@ func openaiStreamHandler
 	}
 	streamResp := "[" + strings.Join(streamItems, ",") + "]"
 	switch relayMode {
-	case RelayModeChatCompletions:
-		var streamResponses []ChatCompletionsStreamResponseSimple
+	case relayconstant.RelayModeChatCompletions:
+		var streamResponses []dto.ChatCompletionsStreamResponseSimple
 		err := json.Unmarshal(common.StringToByteSlice(streamResp), &streamResponses)
 		if err != nil {
 			common.SysError("error unmarshalling stream response: " + err.Error())
@@ -66,8 +69,8 @@ func openaiStreamHandler
 				responseTextBuilder.WriteString(choice.Delta.Content)
 			}
 		}
-	case RelayModeCompletions:
-		var streamResponses []CompletionsStreamResponse
+	case relayconstant.RelayModeCompletions:
+		var streamResponses []dto.CompletionsStreamResponse
 		err := json.Unmarshal(common.StringToByteSlice(streamResp), &streamResponses)
 		if err != nil {
 			common.SysError("error unmarshalling stream response: " + err.Error())
@@ -85,7 +88,7 @@ func openaiStreamHandler
 		}
 		common.SafeSend(stopChan, true)
 	}()
-	setEventStreamHeaders(c)
+	service.SetEventStreamHeaders(c)
 	c.Stream(func(w io.Writer) bool {
 		select {
 		case data := <-dataChan:
@@ -102,28 +105,28 @@ func openaiStreamHandler
 	})
 	err := resp.Body.Close()
 	if err != nil {
-		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), ""
+		return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), ""
 	}
 	wg.Wait()
 	return nil, responseTextBuilder.String()
 }

-func openaiHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*OpenAIErrorWithStatusCode, *Usage) {
-	var textResponse TextResponse
+func openaiHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
+	var textResponse dto.TextResponse
 	responseBody, err := io.ReadAll(resp.Body)
 	if err != nil {
-		return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
+		return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
 	}
 	err = resp.Body.Close()
 	if err != nil {
-		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
+		return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
 	}
 	err = json.Unmarshal(responseBody, &textResponse)
 	if err != nil {
-		return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
+		return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
 	}
 	if textResponse.Error.Type != "" {
-		return &OpenAIErrorWithStatusCode{
+		return &dto.OpenAIErrorWithStatusCode{
 			OpenAIError: textResponse.Error,
 			StatusCode:  resp.StatusCode,
 		}, nil
@@ -140,19 +143,19 @@ func openaiHandler
 	c.Writer.WriteHeader(resp.StatusCode)
 	_, err = io.Copy(c.Writer, resp.Body)
 	if err != nil {
-		return errorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil
+		return service.OpenAIErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil
 	}
 	err = resp.Body.Close()
 	if err != nil {
-		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
+		return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
 	}

 	if textResponse.Usage.TotalTokens == 0 {
 		completionTokens := 0
 		for _, choice := range textResponse.Choices {
-			completionTokens += countTokenText(string(choice.Message.Content), model)
+			completionTokens += service.CountTokenText(string(choice.Message.Content), model)
 		}
-		textResponse.Usage = Usage{
+		textResponse.Usage = dto.Usage{
 			PromptTokens:     promptTokens,
 			CompletionTokens: completionTokens,
 			TotalTokens:      promptTokens + completionTokens,
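For streaming responses the upstream usually omits a usage object, so the stream handlers return the accumulated text and DoResponse rebuilds a dto.Usage from it. A minimal sketch of what service.ResponseText2Usage presumably does, assuming it delegates to the same CountTokenText used in the non-stream path above:

package service // hypothetical sketch, not the actual implementation

import "one-api/dto"

// ResponseText2Usage presumably rebuilds a usage object from the
// streamed text with the same tokenizer as the non-stream path.
func ResponseText2Usage(responseText string, modelName string, promptTokens int) *dto.Usage {
	// CountTokenText appears elsewhere in this commit as the
	// per-model token counter.
	completionTokens := CountTokenText(responseText, modelName)
	return &dto.Usage{
		PromptTokens:     promptTokens,
		CompletionTokens: completionTokens,
		TotalTokens:      promptTokens + completionTokens,
	}
}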
relay/channel/palm/adaptor.go (new file, 59 lines)
@@ -0,0 +1,59 @@
package palm

import (
	"errors"
	"fmt"
	"github.com/gin-gonic/gin"
	"io"
	"net/http"
	"one-api/dto"
	relaychannel "one-api/relay/channel"
	relaycommon "one-api/relay/common"
	"one-api/service"
)

type Adaptor struct {
}

func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
}

func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
	return fmt.Sprintf("%s/v1beta2/models/chat-bison-001:generateMessage", info.BaseUrl), nil
}

func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
	relaychannel.SetupApiRequestHeader(info, c, req)
	req.Header.Set("x-goog-api-key", info.ApiKey)
	return nil
}

func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) {
	if request == nil {
		return nil, errors.New("request is nil")
	}
	return request, nil
}

func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
	return relaychannel.DoApiRequest(a, c, info, requestBody)
}

func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
	if info.IsStream {
		var responseText string
		err, responseText = palmStreamHandler(c, resp)
		usage = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
	} else {
		err, usage = palmHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
	}
	return
}

func (a *Adaptor) GetModelList() []string {
	return ModelList
}

func (a *Adaptor) GetChannelName() string {
	return ChannelName
}
relay/channel/palm/constants.go (new file, 7 lines)
@@ -0,0 +1,7 @@
package palm

var ModelList = []string{
	"PaLM-2",
}

var ChannelName = "google palm"
relay/channel/palm/dto.go (new file, 38 lines)
@@ -0,0 +1,38 @@
package palm

import "one-api/dto"

type PaLMChatMessage struct {
	Author  string `json:"author"`
	Content string `json:"content"`
}

type PaLMFilter struct {
	Reason  string `json:"reason"`
	Message string `json:"message"`
}

type PaLMPrompt struct {
	Messages []PaLMChatMessage `json:"messages"`
}

type PaLMChatRequest struct {
	Prompt         PaLMPrompt `json:"prompt"`
	Temperature    float64    `json:"temperature,omitempty"`
	CandidateCount int        `json:"candidateCount,omitempty"`
	TopP           float64    `json:"topP,omitempty"`
	TopK           uint       `json:"topK,omitempty"`
}

type PaLMError struct {
	Code    int    `json:"code"`
	Message string `json:"message"`
	Status  string `json:"status"`
}

type PaLMChatResponse struct {
	Candidates []PaLMChatMessage `json:"candidates"`
	Messages   []dto.Message     `json:"messages"`
	Filters    []PaLMFilter      `json:"filters"`
	Error      PaLMError         `json:"error"`
}
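For reference, a hand-built payload for the chat-bison-001 endpoint wired up in the adaptor above would look like this. Illustrative only; the author value "0" for the user role is an assumption based on PaLM's numbered-author convention, and the parameter values are arbitrary:

package palm

import "encoding/json"

// examplePaLMBody builds a minimal payload for
// /v1beta2/models/chat-bison-001:generateMessage (illustration only).
func examplePaLMBody() []byte {
	req := PaLMChatRequest{
		Prompt: PaLMPrompt{
			Messages: []PaLMChatMessage{
				{Author: "0", Content: "Hello"}, // "0" = user is an assumption
			},
		},
		Temperature:    0.7,
		CandidateCount: 1,
	}
	body, _ := json.Marshal(req)
	// body: {"prompt":{"messages":[{"author":"0","content":"Hello"}]},"temperature":0.7,"candidateCount":1}
	return body
}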
@@ -1,4 +1,4 @@
-package controller
+package palm

 import (
 	"encoding/json"
@@ -7,47 +7,15 @@ import (
 	"io"
 	"net/http"
 	"one-api/common"
+	"one-api/dto"
+	relaycommon "one-api/relay/common"
+	"one-api/service"
 )

 // https://developers.generativeai.google/api/rest/generativelanguage/models/generateMessage#request-body
 // https://developers.generativeai.google/api/rest/generativelanguage/models/generateMessage#response-body

-type PaLMChatMessage struct {
-	Author  string `json:"author"`
-	Content string `json:"content"`
-}
-
-type PaLMFilter struct {
-	Reason  string `json:"reason"`
-	Message string `json:"message"`
-}
-
-type PaLMPrompt struct {
-	Messages []PaLMChatMessage `json:"messages"`
-}
-
-type PaLMChatRequest struct {
-	Prompt         PaLMPrompt `json:"prompt"`
-	Temperature    float64    `json:"temperature,omitempty"`
-	CandidateCount int        `json:"candidateCount,omitempty"`
-	TopP           float64    `json:"topP,omitempty"`
-	TopK           uint       `json:"topK,omitempty"`
-}
-
-type PaLMError struct {
-	Code    int    `json:"code"`
-	Message string `json:"message"`
-	Status  string `json:"status"`
-}
-
-type PaLMChatResponse struct {
-	Candidates []PaLMChatMessage `json:"candidates"`
-	Messages   []Message         `json:"messages"`
-	Filters    []PaLMFilter      `json:"filters"`
-	Error      PaLMError         `json:"error"`
-}
-
-func requestOpenAI2PaLM(textRequest GeneralOpenAIRequest) *PaLMChatRequest {
+func requestOpenAI2PaLM(textRequest dto.GeneralOpenAIRequest) *PaLMChatRequest {
 	palmRequest := PaLMChatRequest{
 		Prompt: PaLMPrompt{
 			Messages: make([]PaLMChatMessage, 0, len(textRequest.Messages)),
@@ -71,15 +39,15 @@ func requestOpenAI2PaLM
 	return &palmRequest
 }

-func responsePaLM2OpenAI(response *PaLMChatResponse) *OpenAITextResponse {
-	fullTextResponse := OpenAITextResponse{
-		Choices: make([]OpenAITextResponseChoice, 0, len(response.Candidates)),
+func responsePaLM2OpenAI(response *PaLMChatResponse) *dto.OpenAITextResponse {
+	fullTextResponse := dto.OpenAITextResponse{
+		Choices: make([]dto.OpenAITextResponseChoice, 0, len(response.Candidates)),
 	}
 	for i, candidate := range response.Candidates {
 		content, _ := json.Marshal(candidate.Content)
-		choice := OpenAITextResponseChoice{
+		choice := dto.OpenAITextResponseChoice{
 			Index: i,
-			Message: Message{
+			Message: dto.Message{
 				Role:    "assistant",
 				Content: content,
 			},
@@ -90,20 +58,20 @@ func responsePaLM2OpenAI
 	return &fullTextResponse
 }

-func streamResponsePaLM2OpenAI(palmResponse *PaLMChatResponse) *ChatCompletionsStreamResponse {
-	var choice ChatCompletionsStreamResponseChoice
+func streamResponsePaLM2OpenAI(palmResponse *PaLMChatResponse) *dto.ChatCompletionsStreamResponse {
+	var choice dto.ChatCompletionsStreamResponseChoice
 	if len(palmResponse.Candidates) > 0 {
 		choice.Delta.Content = palmResponse.Candidates[0].Content
 	}
-	choice.FinishReason = &stopFinishReason
-	var response ChatCompletionsStreamResponse
+	choice.FinishReason = &relaycommon.StopFinishReason
+	var response dto.ChatCompletionsStreamResponse
 	response.Object = "chat.completion.chunk"
 	response.Model = "palm2"
-	response.Choices = []ChatCompletionsStreamResponseChoice{choice}
+	response.Choices = []dto.ChatCompletionsStreamResponseChoice{choice}
 	return &response
 }

-func palmStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, string) {
+func palmStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, string) {
 	responseText := ""
 	responseId := fmt.Sprintf("chatcmpl-%s", common.GetUUID())
 	createdTime := common.GetTimestamp()
@@ -144,7 +112,7 @@ func palmStreamHandler
 		dataChan <- string(jsonResponse)
 		stopChan <- true
 	}()
-	setEventStreamHeaders(c)
+	service.SetEventStreamHeaders(c)
 	c.Stream(func(w io.Writer) bool {
 		select {
 		case data := <-dataChan:
@@ -157,28 +125,28 @@ func palmStreamHandler
 	})
 	err := resp.Body.Close()
 	if err != nil {
-		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), ""
+		return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), ""
 	}
 	return nil, responseText
 }

-func palmHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*OpenAIErrorWithStatusCode, *Usage) {
+func palmHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
 	responseBody, err := io.ReadAll(resp.Body)
 	if err != nil {
-		return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
+		return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
 	}
 	err = resp.Body.Close()
 	if err != nil {
-		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
+		return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
 	}
 	var palmResponse PaLMChatResponse
 	err = json.Unmarshal(responseBody, &palmResponse)
 	if err != nil {
-		return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
+		return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
 	}
 	if palmResponse.Error.Code != 0 || len(palmResponse.Candidates) == 0 {
-		return &OpenAIErrorWithStatusCode{
-			OpenAIError: OpenAIError{
+		return &dto.OpenAIErrorWithStatusCode{
+			OpenAIError: dto.OpenAIError{
 				Message: palmResponse.Error.Message,
 				Type:    palmResponse.Error.Status,
 				Param:   "",
@@ -188,8 +156,8 @@ func palmHandler
 		}, nil
 	}
 	fullTextResponse := responsePaLM2OpenAI(&palmResponse)
-	completionTokens := countTokenText(palmResponse.Candidates[0].Content, model)
-	usage := Usage{
+	completionTokens := service.CountTokenText(palmResponse.Candidates[0].Content, model)
+	usage := dto.Usage{
 		PromptTokens:     promptTokens,
 		CompletionTokens: completionTokens,
 		TotalTokens:      promptTokens + completionTokens,
@@ -197,7 +165,7 @@ func palmHandler
 	fullTextResponse.Usage = usage
 	jsonResponse, err := json.Marshal(fullTextResponse)
 	if err != nil {
-		return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
+		return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
 	}
 	c.Writer.Header().Set("Content-Type", "application/json")
 	c.Writer.WriteHeader(resp.StatusCode)
relay/channel/tencent/adaptor.go (new file, 73 lines)
@@ -0,0 +1,73 @@
package tencent

import (
	"errors"
	"fmt"
	"github.com/gin-gonic/gin"
	"io"
	"net/http"
	"one-api/dto"
	relaychannel "one-api/relay/channel"
	relaycommon "one-api/relay/common"
	"one-api/service"
	"strings"
)

type Adaptor struct {
	Sign string
}

func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
}

func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
	return fmt.Sprintf("%s/hyllm/v1/chat/completions", info.BaseUrl), nil
}

func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
	relaychannel.SetupApiRequestHeader(info, c, req)
	req.Header.Set("Authorization", a.Sign)
	req.Header.Set("X-TC-Action", info.UpstreamModelName)
	return nil
}

func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) {
	if request == nil {
		return nil, errors.New("request is nil")
	}
	apiKey := c.Request.Header.Get("Authorization")
	apiKey = strings.TrimPrefix(apiKey, "Bearer ")
	appId, secretId, secretKey, err := parseTencentConfig(apiKey)
	if err != nil {
		return nil, err
	}
	tencentRequest := requestOpenAI2Tencent(*request)
	tencentRequest.AppId = appId
	tencentRequest.SecretId = secretId
	// we have to calculate the sign here
	a.Sign = getTencentSign(*tencentRequest, secretKey)
	return tencentRequest, nil
}

func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
	return relaychannel.DoApiRequest(a, c, info, requestBody)
}

func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
	if info.IsStream {
		var responseText string
		err, responseText = tencentStreamHandler(c, resp)
		usage = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
	} else {
		err, usage = tencentHandler(c, resp)
	}
	return
}

func (a *Adaptor) GetModelList() []string {
	return ModelList
}

func (a *Adaptor) GetChannelName() string {
	return ChannelName
}
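parseTencentConfig is referenced above but defined outside the shown hunks. Given that a single channel key must carry the app id, secret id and secret key, and that AppId is an int64 on the request, a plausible sketch looks like this; the "appId|secretId|secretKey" layout and the "|" separator are assumptions:

package tencent

import (
	"errors"
	"strconv"
	"strings"
)

// Hypothetical sketch: split "appId|secretId|secretKey" into its parts.
func parseTencentConfigSketch(config string) (appId int64, secretId string, secretKey string, err error) {
	parts := strings.Split(config, "|") // separator is an assumption
	if len(parts) != 3 {
		err = errors.New("invalid tencent config")
		return
	}
	appId, err = strconv.ParseInt(parts[0], 10, 64)
	secretId, secretKey = parts[1], parts[2]
	return
}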
relay/channel/tencent/constants.go (new file, 9 lines)
@@ -0,0 +1,9 @@
package tencent

var ModelList = []string{
	"ChatPro",
	"ChatStd",
	"hunyuan",
}

var ChannelName = "tencent"
relay/channel/tencent/dto.go (new file, 61 lines)
@@ -0,0 +1,61 @@
package tencent

import "one-api/dto"

type TencentMessage struct {
	Role    string `json:"role"`
	Content string `json:"content"`
}

type TencentChatRequest struct {
	AppId    int64  `json:"app_id"`    // APPID of the Tencent Cloud account
	SecretId string `json:"secret_id"` // SecretId from the console
	// Timestamp is the current UNIX timestamp in seconds, recording when
	// the API request was made, e.g. 1529223702; too large a gap from the
	// current time triggers a signature-expired error.
	Timestamp int64 `json:"timestamp"`
	// Expired is the signature's expiry, a UNIX-epoch value in seconds;
	// Expired must be greater than Timestamp, and Expired-Timestamp must
	// be less than 90 days.
	Expired int64  `json:"expired"`
	QueryID string `json:"query_id"` // request id, used for troubleshooting
	// Temperature: higher values make output more random, lower values
	// more focused and deterministic. Default 1.0, range [0.0, 2.0]; not
	// recommended unless necessary, and set only one of this and top_p.
	Temperature float64 `json:"temperature"`
	// TopP controls output diversity; larger values yield more diverse
	// text. Default 1.0, range [0.0, 1.0]; not recommended unless
	// necessary, and set only one of this and temperature.
	TopP float64 `json:"top_p"`
	// Stream: 0 for synchronous, 1 for streaming (default; protocol: SSE).
	// Synchronous requests time out after 60s; prefer streaming for long content.
	Stream int `json:"stream"`
	// Messages holds the conversation, at most 40 entries ordered oldest
	// to newest; total input content is capped at 3000 tokens.
	Messages []TencentMessage `json:"messages"`
}

type TencentError struct {
	Code    int    `json:"code"`
	Message string `json:"message"`
}

type TencentUsage struct {
	InputTokens  int `json:"input_tokens"`
	OutputTokens int `json:"output_tokens"`
	TotalTokens  int `json:"total_tokens"`
}

type TencentResponseChoices struct {
	FinishReason string         `json:"finish_reason,omitempty"` // streaming end flag; "stop" marks the final packet
	Messages     TencentMessage `json:"messages,omitempty"`      // content in synchronous mode, null in stream mode; capped at 1024 tokens
	Delta        TencentMessage `json:"delta,omitempty"`         // content in stream mode, null in synchronous mode; capped at 1024 tokens
}

type TencentChatResponse struct {
	Choices []TencentResponseChoices `json:"choices,omitempty"` // results
	Created string                   `json:"created,omitempty"` // unix timestamp as a string
	Id      string                   `json:"id,omitempty"`      // conversation id
	Usage   dto.Usage                `json:"usage,omitempty"`   // token counts
	Error   TencentError             `json:"error,omitempty"`   // error info; may be null if unavailable
	Note    string                   `json:"note,omitempty"`    // note
	ReqID   string                   `json:"req_id,omitempty"`  // unique request id returned with every request, for feedback
}
@@ -1,4 +1,4 @@
-package controller
+package tencent

 import (
 	"bufio"
@@ -12,6 +12,9 @@ import (
 	"io"
 	"net/http"
 	"one-api/common"
+	"one-api/dto"
+	relaycommon "one-api/relay/common"
+	"one-api/service"
 	"sort"
 	"strconv"
 	"strings"
@@ -19,65 +22,7 @@ import (

 // https://cloud.tencent.com/document/product/1729/97732

-type TencentMessage struct {
-	Role    string `json:"role"`
-	Content string `json:"content"`
-}
-
-type TencentChatRequest struct {
-	AppId    int64  `json:"app_id"`    // APPID of the Tencent Cloud account
-	SecretId string `json:"secret_id"` // SecretId from the console
-	// Timestamp is the current UNIX timestamp in seconds, e.g. 1529223702;
-	// too large a gap from the current time triggers a signature-expired error.
-	Timestamp int64 `json:"timestamp"`
-	// Expired is the signature's expiry in UNIX-epoch seconds; Expired must be
-	// greater than Timestamp and Expired-Timestamp less than 90 days.
-	Expired int64  `json:"expired"`
-	QueryID string `json:"query_id"` // request id, used for troubleshooting
-	// Temperature: higher values make output more random, lower values more
-	// focused. Default 1.0, range [0.0, 2.0]; set only one of this and top_p.
-	Temperature float64 `json:"temperature"`
-	// TopP controls output diversity. Default 1.0, range [0.0, 1.0]; set only
-	// one of this and temperature.
-	TopP float64 `json:"top_p"`
-	// Stream: 0 for synchronous, 1 for streaming (default; protocol: SSE).
-	// Synchronous requests time out after 60s; prefer streaming for long content.
-	Stream int `json:"stream"`
-	// Messages holds the conversation, at most 40 entries ordered oldest to
-	// newest; total input content is capped at 3000 tokens.
-	Messages []TencentMessage `json:"messages"`
-}
-
-type TencentError struct {
-	Code    int    `json:"code"`
-	Message string `json:"message"`
-}
-
-type TencentUsage struct {
-	InputTokens  int `json:"input_tokens"`
-	OutputTokens int `json:"output_tokens"`
-	TotalTokens  int `json:"total_tokens"`
-}
-
-type TencentResponseChoices struct {
-	FinishReason string         `json:"finish_reason,omitempty"` // streaming end flag; "stop" marks the final packet
-	Messages     TencentMessage `json:"messages,omitempty"`      // content in synchronous mode, null in stream mode; capped at 1024 tokens
-	Delta        TencentMessage `json:"delta,omitempty"`         // content in stream mode, null in synchronous mode; capped at 1024 tokens
-}
-
-type TencentChatResponse struct {
-	Choices []TencentResponseChoices `json:"choices,omitempty"` // results
-	Created string                   `json:"created,omitempty"` // unix timestamp as a string
-	Id      string                   `json:"id,omitempty"`      // conversation id
-	Usage   Usage                    `json:"usage,omitempty"`   // token counts
-	Error   TencentError             `json:"error,omitempty"`   // error info; may be null if unavailable
-	Note    string                   `json:"note,omitempty"`    // note
-	ReqID   string                   `json:"req_id,omitempty"`  // unique request id returned with every request, for feedback
-}
-
-func requestOpenAI2Tencent(request GeneralOpenAIRequest) *TencentChatRequest {
+func requestOpenAI2Tencent(request dto.GeneralOpenAIRequest) *TencentChatRequest {
 	messages := make([]TencentMessage, 0, len(request.Messages))
 	for i := 0; i < len(request.Messages); i++ {
 		message := request.Messages[i]
@@ -112,17 +57,17 @@ func requestOpenAI2Tencent
 	}
 }

-func responseTencent2OpenAI(response *TencentChatResponse) *OpenAITextResponse {
-	fullTextResponse := OpenAITextResponse{
+func responseTencent2OpenAI(response *TencentChatResponse) *dto.OpenAITextResponse {
+	fullTextResponse := dto.OpenAITextResponse{
 		Object:  "chat.completion",
 		Created: common.GetTimestamp(),
 		Usage:   response.Usage,
 	}
 	if len(response.Choices) > 0 {
 		content, _ := json.Marshal(response.Choices[0].Messages.Content)
-		choice := OpenAITextResponseChoice{
+		choice := dto.OpenAITextResponseChoice{
 			Index: 0,
-			Message: Message{
+			Message: dto.Message{
 				Role:    "assistant",
 				Content: content,
 			},
@@ -133,24 +78,24 @@ func responseTencent2OpenAI
 	return &fullTextResponse
 }

-func streamResponseTencent2OpenAI(TencentResponse *TencentChatResponse) *ChatCompletionsStreamResponse {
-	response := ChatCompletionsStreamResponse{
+func streamResponseTencent2OpenAI(TencentResponse *TencentChatResponse) *dto.ChatCompletionsStreamResponse {
+	response := dto.ChatCompletionsStreamResponse{
 		Object:  "chat.completion.chunk",
 		Created: common.GetTimestamp(),
 		Model:   "tencent-hunyuan",
 	}
 	if len(TencentResponse.Choices) > 0 {
-		var choice ChatCompletionsStreamResponseChoice
+		var choice dto.ChatCompletionsStreamResponseChoice
 		choice.Delta.Content = TencentResponse.Choices[0].Delta.Content
 		if TencentResponse.Choices[0].FinishReason == "stop" {
-			choice.FinishReason = &stopFinishReason
+			choice.FinishReason = &relaycommon.StopFinishReason
 		}
 		response.Choices = append(response.Choices, choice)
 	}
 	return &response
 }

-func tencentStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, string) {
+func tencentStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, string) {
 	var responseText string
 	scanner := bufio.NewScanner(resp.Body)
 	scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
@@ -181,7 +126,7 @@ func tencentStreamHandler
 		}
 		stopChan <- true
 	}()
-	setEventStreamHeaders(c)
+	service.SetEventStreamHeaders(c)
 	c.Stream(func(w io.Writer) bool {
 		select {
 		case data := <-dataChan:
@@ -209,28 +154,28 @@ func tencentStreamHandler
 	})
 	err := resp.Body.Close()
 	if err != nil {
-		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), ""
+		return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), ""
 	}
 	return nil, responseText
 }

-func tencentHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) {
+func tencentHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
 	var TencentResponse TencentChatResponse
 	responseBody, err := io.ReadAll(resp.Body)
 	if err != nil {
-		return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
+		return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
 	}
 	err = resp.Body.Close()
 	if err != nil {
-		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
+		return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
 	}
 	err = json.Unmarshal(responseBody, &TencentResponse)
 	if err != nil {
-		return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
+		return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
 	}
 	if TencentResponse.Error.Code != 0 {
-		return &OpenAIErrorWithStatusCode{
-			OpenAIError: OpenAIError{
+		return &dto.OpenAIErrorWithStatusCode{
+			OpenAIError: dto.OpenAIError{
 				Message: TencentResponse.Error.Message,
 				Code:    TencentResponse.Error.Code,
 			},
@@ -240,7 +185,7 @@ func tencentHandler
 	fullTextResponse := responseTencent2OpenAI(&TencentResponse)
 	jsonResponse, err := json.Marshal(fullTextResponse)
 	if err != nil {
-		return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
+		return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
 	}
 	c.Writer.Header().Set("Content-Type", "application/json")
 	c.Writer.WriteHeader(resp.StatusCode)
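getTencentSign is referenced from the adaptor but its body falls outside the changed hunks. Judging by the sort/strconv/strings imports kept above and Tencent's documented HMAC-SHA1 signature for hyllm, it presumably builds a sorted query string over the request fields and signs it; the following is a hypothetical sketch, not the commit's actual code, and the parameter list shown is incomplete:

package tencent

import (
	"crypto/hmac"
	"crypto/sha1"
	"encoding/base64"
	"sort"
	"strconv"
	"strings"
)

// Hypothetical sketch of the hyllm signing scheme: sort the request
// parameters, join them into a query string under the fixed endpoint
// host, HMAC-SHA1 with the secret key, then base64-encode the digest.
func getTencentSignSketch(req TencentChatRequest, secretKey string) string {
	params := []string{
		"app_id=" + strconv.FormatInt(req.AppId, 10),
		"secret_id=" + req.SecretId,
		"timestamp=" + strconv.FormatInt(req.Timestamp, 10),
		"expired=" + strconv.FormatInt(req.Expired, 10),
		"stream=" + strconv.Itoa(req.Stream),
		// ...remaining fields, including the serialized messages, omitted here
	}
	sort.Strings(params)
	payload := "hunyuan.cloud.tencent.com/hyllm/v1/chat/completions?" + strings.Join(params, "&")
	mac := hmac.New(sha1.New, []byte(secretKey))
	mac.Write([]byte(payload))
	return base64.StdEncoding.EncodeToString(mac.Sum(nil))
}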
relay/channel/xunfei/adaptor.go (new file, 68 lines)
@@ -0,0 +1,68 @@
package xunfei

import (
	"errors"
	"github.com/gin-gonic/gin"
	"io"
	"net/http"
	"one-api/dto"
	relaychannel "one-api/relay/channel"
	relaycommon "one-api/relay/common"
	"one-api/service"
	"strings"
)

type Adaptor struct {
	request *dto.GeneralOpenAIRequest
}

func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
}

func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
	return "", nil
}

func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
	relaychannel.SetupApiRequestHeader(info, c, req)
	return nil
}

func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) {
	if request == nil {
		return nil, errors.New("request is nil")
	}
	a.request = request
	return request, nil
}

func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
	// xunfei's API is not an HTTP request (it speaks WebSocket), so
	// nothing needs to happen here; return a dummy OK response.
	dummyResp := &http.Response{}
	dummyResp.StatusCode = http.StatusOK
	return dummyResp, nil
}

func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
	splits := strings.Split(info.ApiKey, "|")
	if len(splits) != 3 {
		return nil, service.OpenAIErrorWrapper(errors.New("invalid auth"), "invalid_auth", http.StatusBadRequest)
	}
	if a.request == nil {
		return nil, service.OpenAIErrorWrapper(errors.New("request is nil"), "request_is_nil", http.StatusBadRequest)
	}
	if info.IsStream {
		err, usage = xunfeiStreamHandler(c, *a.request, splits[0], splits[1], splits[2])
	} else {
		err, usage = xunfeiHandler(c, *a.request, splits[0], splits[1], splits[2])
	}
	return
}

func (a *Adaptor) GetModelList() []string {
	return ModelList
}

func (a *Adaptor) GetChannelName() string {
	return ChannelName
}
relay/channel/xunfei/constants.go (new file, 11 lines)
@@ -0,0 +1,11 @@
package xunfei

var ModelList = []string{
	"SparkDesk",
	"SparkDesk-v1.1",
	"SparkDesk-v2.1",
	"SparkDesk-v3.1",
	"SparkDesk-v3.5",
}

var ChannelName = "xunfei"
relay/channel/xunfei/dto.go (new file, 59 lines)
@@ -0,0 +1,59 @@
package xunfei

import "one-api/dto"

type XunfeiMessage struct {
	Role    string `json:"role"`
	Content string `json:"content"`
}

type XunfeiChatRequest struct {
	Header struct {
		AppId string `json:"app_id"`
	} `json:"header"`
	Parameter struct {
		Chat struct {
			Domain      string  `json:"domain,omitempty"`
			Temperature float64 `json:"temperature,omitempty"`
			TopK        int     `json:"top_k,omitempty"`
			MaxTokens   uint    `json:"max_tokens,omitempty"`
			Auditing    bool    `json:"auditing,omitempty"`
		} `json:"chat"`
	} `json:"parameter"`
	Payload struct {
		Message struct {
			Text []XunfeiMessage `json:"text"`
		} `json:"message"`
	} `json:"payload"`
}

type XunfeiChatResponseTextItem struct {
	Content string `json:"content"`
	Role    string `json:"role"`
	Index   int    `json:"index"`
}

type XunfeiChatResponse struct {
	Header struct {
		Code    int    `json:"code"`
		Message string `json:"message"`
		Sid     string `json:"sid"`
		Status  int    `json:"status"`
	} `json:"header"`
	Payload struct {
		Choices struct {
			Status int                          `json:"status"`
			Seq    int                          `json:"seq"`
			Text   []XunfeiChatResponseTextItem `json:"text"`
		} `json:"choices"`
		Usage struct {
			//Text struct {
			//	QuestionTokens   string `json:"question_tokens"`
			//	PromptTokens     string `json:"prompt_tokens"`
			//	CompletionTokens string `json:"completion_tokens"`
			//	TotalTokens      string `json:"total_tokens"`
			//} `json:"text"`
			Text dto.Usage `json:"text"`
		} `json:"usage"`
	} `json:"payload"`
}
@@ -1,4 +1,4 @@
-package controller
+package xunfei

 import (
 	"crypto/hmac"
@@ -12,6 +12,9 @@ import (
 	"net/http"
 	"net/url"
 	"one-api/common"
+	"one-api/dto"
+	relaycommon "one-api/relay/common"
+	"one-api/service"
 	"strings"
 	"time"
 )
@@ -19,63 +22,7 @@ import (
 // https://console.xfyun.cn/services/cbm
 // https://www.xfyun.cn/doc/spark/Web.html

-type XunfeiMessage struct {
-	Role    string `json:"role"`
-	Content string `json:"content"`
-}
-
-type XunfeiChatRequest struct {
-	Header struct {
-		AppId string `json:"app_id"`
-	} `json:"header"`
-	Parameter struct {
-		Chat struct {
-			Domain      string  `json:"domain,omitempty"`
-			Temperature float64 `json:"temperature,omitempty"`
-			TopK        int     `json:"top_k,omitempty"`
-			MaxTokens   uint    `json:"max_tokens,omitempty"`
-			Auditing    bool    `json:"auditing,omitempty"`
-		} `json:"chat"`
-	} `json:"parameter"`
-	Payload struct {
-		Message struct {
-			Text []XunfeiMessage `json:"text"`
-		} `json:"message"`
-	} `json:"payload"`
-}
-
-type XunfeiChatResponseTextItem struct {
-	Content string `json:"content"`
-	Role    string `json:"role"`
-	Index   int    `json:"index"`
-}
-
-type XunfeiChatResponse struct {
-	Header struct {
-		Code    int    `json:"code"`
-		Message string `json:"message"`
-		Sid     string `json:"sid"`
-		Status  int    `json:"status"`
-	} `json:"header"`
-	Payload struct {
-		Choices struct {
-			Status int                          `json:"status"`
-			Seq    int                          `json:"seq"`
-			Text   []XunfeiChatResponseTextItem `json:"text"`
-		} `json:"choices"`
-		Usage struct {
-			//Text struct {
-			//	QuestionTokens   string `json:"question_tokens"`
-			//	PromptTokens     string `json:"prompt_tokens"`
-			//	CompletionTokens string `json:"completion_tokens"`
-			//	TotalTokens      string `json:"total_tokens"`
-			//} `json:"text"`
-			Text Usage `json:"text"`
-		} `json:"usage"`
-	} `json:"payload"`
-}
-
-func requestOpenAI2Xunfei(request GeneralOpenAIRequest, xunfeiAppId string, domain string) *XunfeiChatRequest {
+func requestOpenAI2Xunfei(request dto.GeneralOpenAIRequest, xunfeiAppId string, domain string) *XunfeiChatRequest {
 	messages := make([]XunfeiMessage, 0, len(request.Messages))
 	for _, message := range request.Messages {
 		if message.Role == "system" {
@@ -104,7 +51,7 @@ func requestOpenAI2Xunfei
 	return &xunfeiRequest
 }

-func responseXunfei2OpenAI(response *XunfeiChatResponse) *OpenAITextResponse {
+func responseXunfei2OpenAI(response *XunfeiChatResponse) *dto.OpenAITextResponse {
 	if len(response.Payload.Choices.Text) == 0 {
 		response.Payload.Choices.Text = []XunfeiChatResponseTextItem{
 			{
@@ -113,24 +60,24 @@ func responseXunfei2OpenAI
 		}
 	}
 	content, _ := json.Marshal(response.Payload.Choices.Text[0].Content)
-	choice := OpenAITextResponseChoice{
+	choice := dto.OpenAITextResponseChoice{
 		Index: 0,
-		Message: Message{
+		Message: dto.Message{
 			Role:    "assistant",
 			Content: content,
 		},
-		FinishReason: stopFinishReason,
+		FinishReason: relaycommon.StopFinishReason,
 	}
-	fullTextResponse := OpenAITextResponse{
+	fullTextResponse := dto.OpenAITextResponse{
 		Object:  "chat.completion",
 		Created: common.GetTimestamp(),
-		Choices: []OpenAITextResponseChoice{choice},
+		Choices: []dto.OpenAITextResponseChoice{choice},
 		Usage:   response.Payload.Usage.Text,
 	}
 	return &fullTextResponse
 }

-func streamResponseXunfei2OpenAI(xunfeiResponse *XunfeiChatResponse) *ChatCompletionsStreamResponse {
+func streamResponseXunfei2OpenAI(xunfeiResponse *XunfeiChatResponse) *dto.ChatCompletionsStreamResponse {
 	if len(xunfeiResponse.Payload.Choices.Text) == 0 {
 		xunfeiResponse.Payload.Choices.Text = []XunfeiChatResponseTextItem{
 			{
@@ -138,16 +85,16 @@ func streamResponseXunfei2OpenAI
 			},
 		}
 	}
-	var choice ChatCompletionsStreamResponseChoice
+	var choice dto.ChatCompletionsStreamResponseChoice
 	choice.Delta.Content = xunfeiResponse.Payload.Choices.Text[0].Content
 	if xunfeiResponse.Payload.Choices.Status == 2 {
-		choice.FinishReason = &stopFinishReason
+		choice.FinishReason = &relaycommon.StopFinishReason
 	}
-	response := ChatCompletionsStreamResponse{
+	response := dto.ChatCompletionsStreamResponse{
 		Object:  "chat.completion.chunk",
 		Created: common.GetTimestamp(),
 		Model:   "SparkDesk",
-		Choices: []ChatCompletionsStreamResponseChoice{choice},
+		Choices: []dto.ChatCompletionsStreamResponseChoice{choice},
 	}
 	return &response
 }
@@ -178,14 +125,14 @@ func buildXunfeiAuthUrl
 	return callUrl
 }

-func xunfeiStreamHandler(c *gin.Context, textRequest GeneralOpenAIRequest, appId string, apiSecret string, apiKey string) (*OpenAIErrorWithStatusCode, *Usage) {
+func xunfeiStreamHandler(c *gin.Context, textRequest dto.GeneralOpenAIRequest, appId string, apiSecret string, apiKey string) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
 	domain, authUrl := getXunfeiAuthUrl(c, apiKey, apiSecret)
 	dataChan, stopChan, err := xunfeiMakeRequest(textRequest, domain, authUrl, appId)
 	if err != nil {
-		return errorWrapper(err, "make xunfei request err", http.StatusInternalServerError), nil
+		return service.OpenAIErrorWrapper(err, "make xunfei request err", http.StatusInternalServerError), nil
 	}
-	setEventStreamHeaders(c)
-	var usage Usage
+	service.SetEventStreamHeaders(c)
+	var usage dto.Usage
 	c.Stream(func(w io.Writer) bool {
 		select {
 		case xunfeiResponse := <-dataChan:
@@ -208,13 +155,13 @@ func xunfeiStreamHandler
 	return nil, &usage
 }

-func xunfeiHandler(c *gin.Context, textRequest GeneralOpenAIRequest, appId string, apiSecret string, apiKey string) (*OpenAIErrorWithStatusCode, *Usage) {
+func xunfeiHandler(c *gin.Context, textRequest dto.GeneralOpenAIRequest, appId string, apiSecret string, apiKey string) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
 	domain, authUrl := getXunfeiAuthUrl(c, apiKey, apiSecret)
 	dataChan, stopChan, err := xunfeiMakeRequest(textRequest, domain, authUrl, appId)
 	if err != nil {
-		return errorWrapper(err, "make xunfei request err", http.StatusInternalServerError), nil
+		return service.OpenAIErrorWrapper(err, "make xunfei request err", http.StatusInternalServerError), nil
 	}
-	var usage Usage
+	var usage dto.Usage
 	var content string
 	var xunfeiResponse XunfeiChatResponse
 	stop := false
@@ -237,14 +184,14 @@ func xunfeiHandler
 	response := responseXunfei2OpenAI(&xunfeiResponse)
 	jsonResponse, err := json.Marshal(response)
 	if err != nil {
-		return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
+		return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
 	}
 	c.Writer.Header().Set("Content-Type", "application/json")
 	_, _ = c.Writer.Write(jsonResponse)
 	return nil, &usage
 }

-func xunfeiMakeRequest(textRequest GeneralOpenAIRequest, domain, authUrl, appId string) (chan XunfeiChatResponse, chan bool, error) {
+func xunfeiMakeRequest(textRequest dto.GeneralOpenAIRequest, domain, authUrl, appId string) (chan XunfeiChatResponse, chan bool, error) {
 	d := websocket.Dialer{
 		HandshakeTimeout: 5 * time.Second,
 	}
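buildXunfeiAuthUrl only appears as hunk context above, so its body is not part of this commit's changes. Assuming the scheme from the Web API doc linked in the file (https://www.xfyun.cn/doc/spark/Web.html), the Spark WebSocket auth URL is signed with HMAC-SHA256 over a host/date/request-line string; the following is a hypothetical sketch, not the repository's code:

package xunfei

import (
	"crypto/hmac"
	"crypto/sha256"
	"encoding/base64"
	"fmt"
	"net/http"
	"net/url"
	"time"
)

// buildAuthUrlSketch signs "host: ...\ndate: ...\nGET <path> HTTP/1.1"
// with HMAC-SHA256, base64-encodes the result, and wraps it into an
// authorization query parameter (assumed scheme, per the Spark docs).
func buildAuthUrlSketch(hostUrl string, apiKey, apiSecret string) string {
	u, err := url.Parse(hostUrl)
	if err != nil {
		return hostUrl
	}
	date := time.Now().UTC().Format(http.TimeFormat)
	signString := fmt.Sprintf("host: %s\ndate: %s\nGET %s HTTP/1.1", u.Host, date, u.Path)
	mac := hmac.New(sha256.New, []byte(apiSecret))
	mac.Write([]byte(signString))
	signature := base64.StdEncoding.EncodeToString(mac.Sum(nil))
	authOrigin := fmt.Sprintf(`api_key="%s", algorithm="hmac-sha256", headers="host date request-line", signature="%s"`,
		apiKey, signature)
	v := url.Values{}
	v.Add("host", u.Host)
	v.Add("date", date)
	v.Add("authorization", base64.StdEncoding.EncodeToString([]byte(authOrigin)))
	return hostUrl + "?" + v.Encode()
}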
relay/channel/zhipu/adaptor.go (new file, 61 lines)
@@ -0,0 +1,61 @@
package zhipu

import (
	"errors"
	"fmt"
	"github.com/gin-gonic/gin"
	"io"
	"net/http"
	"one-api/dto"
	relaychannel "one-api/relay/channel"
	relaycommon "one-api/relay/common"
)

type Adaptor struct {
}

func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
}

func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
	method := "invoke"
	if info.IsStream {
		method = "sse-invoke"
	}
	return fmt.Sprintf("%s/api/paas/v3/model-api/%s/%s", info.BaseUrl, info.UpstreamModelName, method), nil
}

func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
	relaychannel.SetupApiRequestHeader(info, c, req)
	token := getZhipuToken(info.ApiKey)
	req.Header.Set("Authorization", token)
	return nil
}

func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) {
	if request == nil {
		return nil, errors.New("request is nil")
	}
	return requestOpenAI2Zhipu(*request), nil
}

func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
	return relaychannel.DoApiRequest(a, c, info, requestBody)
}

func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
	if info.IsStream {
		err, usage = zhipuStreamHandler(c, resp)
	} else {
		err, usage = zhipuHandler(c, resp)
	}
	return
}

func (a *Adaptor) GetModelList() []string {
	return ModelList
}

func (a *Adaptor) GetChannelName() string {
	return ChannelName
}
relay/channel/zhipu/constants.go (new file, 7 lines)
@@ -0,0 +1,7 @@
package zhipu

var ModelList = []string{
	"chatglm_turbo", "chatglm_pro", "chatglm_std", "chatglm_lite",
}

var ChannelName = "zhipu"
relay/channel/zhipu/dto.go (new file, 46 lines)
@@ -0,0 +1,46 @@
package zhipu

import (
	"one-api/dto"
	"time"
)

type ZhipuMessage struct {
	Role    string `json:"role"`
	Content string `json:"content"`
}

type ZhipuRequest struct {
	Prompt      []ZhipuMessage `json:"prompt"`
	Temperature float64        `json:"temperature,omitempty"`
	TopP        float64        `json:"top_p,omitempty"`
	RequestId   string         `json:"request_id,omitempty"`
	Incremental bool           `json:"incremental,omitempty"`
}

type ZhipuResponseData struct {
	TaskId     string         `json:"task_id"`
	RequestId  string         `json:"request_id"`
	TaskStatus string         `json:"task_status"`
	Choices    []ZhipuMessage `json:"choices"`
	dto.Usage  `json:"usage"`
}

type ZhipuResponse struct {
	Code    int               `json:"code"`
	Msg     string            `json:"msg"`
	Success bool              `json:"success"`
	Data    ZhipuResponseData `json:"data"`
}

type ZhipuStreamMetaResponse struct {
	RequestId  string `json:"request_id"`
	TaskId     string `json:"task_id"`
	TaskStatus string `json:"task_status"`
	dto.Usage  `json:"usage"`
}

type zhipuTokenData struct {
	Token      string
	ExpiryTime time.Time
}
@@ -1,4 +1,4 @@
-package controller
+package zhipu
 
 import (
 	"bufio"
@@ -8,6 +8,9 @@ import (
 	"io"
 	"net/http"
 	"one-api/common"
+	"one-api/dto"
+	relaycommon "one-api/relay/common"
+	"one-api/service"
 	"strings"
 	"sync"
 	"time"
@@ -18,46 +21,6 @@ import (
 // https://open.bigmodel.cn/api/paas/v3/model-api/chatglm_std/invoke
 // https://open.bigmodel.cn/api/paas/v3/model-api/chatglm_std/sse-invoke
 
-type ZhipuMessage struct {
-	Role    string `json:"role"`
-	Content string `json:"content"`
-}
-
-type ZhipuRequest struct {
-	Prompt      []ZhipuMessage `json:"prompt"`
-	Temperature float64        `json:"temperature,omitempty"`
-	TopP        float64        `json:"top_p,omitempty"`
-	RequestId   string         `json:"request_id,omitempty"`
-	Incremental bool           `json:"incremental,omitempty"`
-}
-
-type ZhipuResponseData struct {
-	TaskId     string         `json:"task_id"`
-	RequestId  string         `json:"request_id"`
-	TaskStatus string         `json:"task_status"`
-	Choices    []ZhipuMessage `json:"choices"`
-	Usage      `json:"usage"`
-}
-
-type ZhipuResponse struct {
-	Code    int               `json:"code"`
-	Msg     string            `json:"msg"`
-	Success bool              `json:"success"`
-	Data    ZhipuResponseData `json:"data"`
-}
-
-type ZhipuStreamMetaResponse struct {
-	RequestId  string `json:"request_id"`
-	TaskId     string `json:"task_id"`
-	TaskStatus string `json:"task_status"`
-	Usage      `json:"usage"`
-}
-
-type zhipuTokenData struct {
-	Token      string
-	ExpiryTime time.Time
-}
-
 var zhipuTokens sync.Map
 var expSeconds int64 = 24 * 3600
 
@@ -108,7 +71,7 @@ func getZhipuToken(apikey string) string {
 	return tokenString
 }
 
-func requestOpenAI2Zhipu(request GeneralOpenAIRequest) *ZhipuRequest {
+func requestOpenAI2Zhipu(request dto.GeneralOpenAIRequest) *ZhipuRequest {
 	messages := make([]ZhipuMessage, 0, len(request.Messages))
 	for _, message := range request.Messages {
 		if message.Role == "system" {
@@ -135,19 +98,19 @@ func requestOpenAI2Zhipu(request GeneralOpenAIRequest) *ZhipuRequest {
 		}
 	}
 
-func responseZhipu2OpenAI(response *ZhipuResponse) *OpenAITextResponse {
-	fullTextResponse := OpenAITextResponse{
+func responseZhipu2OpenAI(response *ZhipuResponse) *dto.OpenAITextResponse {
+	fullTextResponse := dto.OpenAITextResponse{
 		Id:      response.Data.TaskId,
 		Object:  "chat.completion",
 		Created: common.GetTimestamp(),
-		Choices: make([]OpenAITextResponseChoice, 0, len(response.Data.Choices)),
+		Choices: make([]dto.OpenAITextResponseChoice, 0, len(response.Data.Choices)),
 		Usage:   response.Data.Usage,
 	}
 	for i, choice := range response.Data.Choices {
 		content, _ := json.Marshal(strings.Trim(choice.Content, "\""))
-		openaiChoice := OpenAITextResponseChoice{
+		openaiChoice := dto.OpenAITextResponseChoice{
 			Index: i,
-			Message: Message{
+			Message: dto.Message{
 				Role:    choice.Role,
 				Content: content,
 			},
@@ -161,34 +124,34 @@ func responseZhipu2OpenAI(response *ZhipuResponse) *OpenAITextResponse {
 	return &fullTextResponse
 }
 
-func streamResponseZhipu2OpenAI(zhipuResponse string) *ChatCompletionsStreamResponse {
-	var choice ChatCompletionsStreamResponseChoice
+func streamResponseZhipu2OpenAI(zhipuResponse string) *dto.ChatCompletionsStreamResponse {
+	var choice dto.ChatCompletionsStreamResponseChoice
 	choice.Delta.Content = zhipuResponse
-	response := ChatCompletionsStreamResponse{
+	response := dto.ChatCompletionsStreamResponse{
 		Object:  "chat.completion.chunk",
 		Created: common.GetTimestamp(),
 		Model:   "chatglm",
-		Choices: []ChatCompletionsStreamResponseChoice{choice},
+		Choices: []dto.ChatCompletionsStreamResponseChoice{choice},
 	}
 	return &response
 }
 
-func streamMetaResponseZhipu2OpenAI(zhipuResponse *ZhipuStreamMetaResponse) (*ChatCompletionsStreamResponse, *Usage) {
-	var choice ChatCompletionsStreamResponseChoice
+func streamMetaResponseZhipu2OpenAI(zhipuResponse *ZhipuStreamMetaResponse) (*dto.ChatCompletionsStreamResponse, *dto.Usage) {
+	var choice dto.ChatCompletionsStreamResponseChoice
 	choice.Delta.Content = ""
-	choice.FinishReason = &stopFinishReason
-	response := ChatCompletionsStreamResponse{
+	choice.FinishReason = &relaycommon.StopFinishReason
+	response := dto.ChatCompletionsStreamResponse{
 		Id:      zhipuResponse.RequestId,
 		Object:  "chat.completion.chunk",
 		Created: common.GetTimestamp(),
 		Model:   "chatglm",
-		Choices: []ChatCompletionsStreamResponseChoice{choice},
+		Choices: []dto.ChatCompletionsStreamResponseChoice{choice},
 	}
 	return &response, &zhipuResponse.Usage
 }
 
-func zhipuStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) {
-	var usage *Usage
+func zhipuStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
+	var usage *dto.Usage
 	scanner := bufio.NewScanner(resp.Body)
 	scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
 		if atEOF && len(data) == 0 {
@@ -225,7 +188,7 @@ func zhipuStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) {
 		}
 		stopChan <- true
 	}()
-	setEventStreamHeaders(c)
+	service.SetEventStreamHeaders(c)
 	c.Stream(func(w io.Writer) bool {
 		select {
 		case data := <-dataChan:
@@ -260,28 +223,28 @@ func zhipuStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) {
 	})
 	err := resp.Body.Close()
 	if err != nil {
-		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
+		return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
 	}
 	return nil, usage
 }
 
-func zhipuHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) {
+func zhipuHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
 	var zhipuResponse ZhipuResponse
 	responseBody, err := io.ReadAll(resp.Body)
 	if err != nil {
-		return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
+		return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
 	}
 	err = resp.Body.Close()
 	if err != nil {
-		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
+		return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
 	}
 	err = json.Unmarshal(responseBody, &zhipuResponse)
 	if err != nil {
-		return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
+		return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
 	}
 	if !zhipuResponse.Success {
-		return &OpenAIErrorWithStatusCode{
-			OpenAIError: OpenAIError{
+		return &dto.OpenAIErrorWithStatusCode{
+			OpenAIError: dto.OpenAIError{
 				Message: zhipuResponse.Msg,
 				Type:    "zhipu_error",
 				Param:   "",
@@ -293,7 +256,7 @@ func zhipuHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) {
 	fullTextResponse := responseZhipu2OpenAI(&zhipuResponse)
 	jsonResponse, err := json.Marshal(fullTextResponse)
 	if err != nil {
-		return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
+		return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
 	}
 	c.Writer.Header().Set("Content-Type", "application/json")
 	c.Writer.WriteHeader(resp.StatusCode)

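The body of getZhipuToken is elided by this view, but the zhipuTokens sync.Map and expSeconds variables it keeps point to a cache-then-sign pattern. A hedged sketch of that pattern only; the expiry handling is inferred from the variable names, and signZhipuToken is a hypothetical stand-in for the real JWT signing code:

func getZhipuTokenSketch(apikey string) string {
    if data, ok := zhipuTokens.Load(apikey); ok {
        tokenData := data.(zhipuTokenData)
        if time.Now().Before(tokenData.ExpiryTime) {
            return tokenData.Token // cached token still valid, reuse it
        }
    }
    expiry := time.Now().Add(time.Duration(expSeconds) * time.Second)
    token := signZhipuToken(apikey, expiry) // hypothetical helper; the real code signs a short-lived JWT from the API key
    zhipuTokens.Store(apikey, zhipuTokenData{Token: token, ExpiryTime: expiry})
    return token
}
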
relay/common/relay_info.go (new file, 71 lines)
@@ -0,0 +1,71 @@
package common

import (
    "github.com/gin-gonic/gin"
    "one-api/common"
    "one-api/relay/constant"
    "strings"
    "time"
)

type RelayInfo struct {
    ChannelType       int
    ChannelId         int
    TokenId           int
    UserId            int
    Group             string
    TokenUnlimited    bool
    StartTime         time.Time
    ApiType           int
    IsStream          bool
    RelayMode         int
    UpstreamModelName string
    RequestURLPath    string
    ApiVersion        string
    PromptTokens      int
    ApiKey            string
    BaseUrl           string
}

func GenRelayInfo(c *gin.Context) *RelayInfo {
    channelType := c.GetInt("channel")
    channelId := c.GetInt("channel_id")
    tokenId := c.GetInt("token_id")
    userId := c.GetInt("id")
    group := c.GetString("group")
    tokenUnlimited := c.GetBool("token_unlimited_quota")
    startTime := time.Now()

    apiType := constant.ChannelType2APIType(channelType)

    info := &RelayInfo{
        RelayMode:      constant.Path2RelayMode(c.Request.URL.Path),
        BaseUrl:        c.GetString("base_url"),
        RequestURLPath: c.Request.URL.String(),
        ChannelType:    channelType,
        ChannelId:      channelId,
        TokenId:        tokenId,
        UserId:         userId,
        Group:          group,
        TokenUnlimited: tokenUnlimited,
        StartTime:      startTime,
        ApiType:        apiType,
        ApiVersion:     c.GetString("api_version"),
        ApiKey:         strings.TrimPrefix(c.Request.Header.Get("Authorization"), "Bearer "),
    }
    if info.BaseUrl == "" {
        info.BaseUrl = common.ChannelBaseURLs[channelType]
    }
    //if info.ChannelType == common.ChannelTypeAzure {
    //	info.ApiVersion = GetAzureAPIVersion(c)
    //}
    return info
}

func (info *RelayInfo) SetPromptTokens(promptTokens int) {
    info.PromptTokens = promptTokens
}

func (info *RelayInfo) SetIsStream(isStream bool) {
    info.IsStream = isStream
}

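GenRelayInfo only reads values that earlier middleware stored on the gin context, so exercising it standalone means populating those keys first. A sketch using gin's test context (assuming fmt and net/http/httptest are imported); the key names are exactly the ones read above, while the idea that the auth/distribute middleware sets them is an assumption about code outside this diff:

func genRelayInfoDemo() {
    w := httptest.NewRecorder()
    c, _ := gin.CreateTestContext(w)
    c.Request = httptest.NewRequest("POST", "/v1/chat/completions", nil)
    c.Request.Header.Set("Authorization", "Bearer sk-demo")
    c.Set("channel", 1) // channel type
    c.Set("channel_id", 42)
    c.Set("token_id", 7)
    c.Set("id", 1001) // user id
    c.Set("group", "default")
    c.Set("token_unlimited_quota", false)

    info := GenRelayInfo(c)
    fmt.Println(info.RelayMode == constant.RelayModeChatCompletions) // true
    fmt.Println(info.ApiKey)                                         // "sk-demo"
}
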
relay/common/relay_utils.go (new file, 68 lines)
@@ -0,0 +1,68 @@
package common

import (
    "encoding/json"
    "fmt"
    "github.com/gin-gonic/gin"
    _ "image/gif"
    _ "image/jpeg"
    _ "image/png"
    "io"
    "net/http"
    "one-api/common"
    "one-api/dto"
    "strconv"
    "strings"
)

var StopFinishReason = "stop"

func RelayErrorHandler(resp *http.Response) (openAIErrorWithStatusCode *dto.OpenAIErrorWithStatusCode) {
    openAIErrorWithStatusCode = &dto.OpenAIErrorWithStatusCode{
        StatusCode: resp.StatusCode,
        OpenAIError: dto.OpenAIError{
            Message: fmt.Sprintf("bad response status code %d", resp.StatusCode),
            Type:    "upstream_error",
            Code:    "bad_response_status_code",
            Param:   strconv.Itoa(resp.StatusCode),
        },
    }
    responseBody, err := io.ReadAll(resp.Body)
    if err != nil {
        return
    }
    err = resp.Body.Close()
    if err != nil {
        return
    }
    var textResponse dto.TextResponse
    err = json.Unmarshal(responseBody, &textResponse)
    if err != nil {
        return
    }
    openAIErrorWithStatusCode.OpenAIError = textResponse.Error
    return
}

func GetFullRequestURL(baseURL string, requestURL string, channelType int) string {
    fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL)

    if strings.HasPrefix(baseURL, "https://gateway.ai.cloudflare.com") {
        switch channelType {
        case common.ChannelTypeOpenAI:
            fullRequestURL = fmt.Sprintf("%s%s", baseURL, strings.TrimPrefix(requestURL, "/v1"))
        case common.ChannelTypeAzure:
            fullRequestURL = fmt.Sprintf("%s%s", baseURL, strings.TrimPrefix(requestURL, "/openai/deployments"))
        }
    }
    return fullRequestURL
}

func GetAPIVersion(c *gin.Context) string {
    query := c.Request.URL.Query()
    apiVersion := query.Get("api-version")
    if apiVersion == "" {
        apiVersion = c.GetString("api_version")
    }
    return apiVersion
}

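GetFullRequestURL normally just concatenates, but for Cloudflare's AI Gateway it strips the provider-specific path prefix because the gateway URL already ends in a provider slug. A sketch of the mapping (the account/gateway segments in the base URL are illustrative):

base := "https://gateway.ai.cloudflare.com/v1/ACCOUNT/GATEWAY/openai" // illustrative gateway URL
u1 := GetFullRequestURL(base, "/v1/chat/completions", common.ChannelTypeOpenAI)
// u1 == base + "/chat/completions" (the "/v1" prefix is trimmed)

u2 := GetFullRequestURL("https://api.openai.com", "/v1/chat/completions", common.ChannelTypeOpenAI)
// u2 == "https://api.openai.com/v1/chat/completions" (non-gateway bases pass through untouched)
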
relay/constant/api_type.go (new file, 45 lines)
@@ -0,0 +1,45 @@
package constant

import (
    "one-api/common"
)

const (
    APITypeOpenAI = iota
    APITypeAnthropic
    APITypePaLM
    APITypeBaidu
    APITypeZhipu
    APITypeAli
    APITypeXunfei
    APITypeAIProxyLibrary
    APITypeTencent
    APITypeGemini

    APITypeDummy // this one is only for count, do not add any channel after this
)

func ChannelType2APIType(channelType int) int {
    apiType := APITypeOpenAI
    switch channelType {
    case common.ChannelTypeAnthropic:
        apiType = APITypeAnthropic
    case common.ChannelTypeBaidu:
        apiType = APITypeBaidu
    case common.ChannelTypePaLM:
        apiType = APITypePaLM
    case common.ChannelTypeZhipu:
        apiType = APITypeZhipu
    case common.ChannelTypeAli:
        apiType = APITypeAli
    case common.ChannelTypeXunfei:
        apiType = APITypeXunfei
    case common.ChannelTypeAIProxyLibrary:
        apiType = APITypeAIProxyLibrary
    case common.ChannelTypeTencent:
        apiType = APITypeTencent
    case common.ChannelTypeGemini:
        apiType = APITypeGemini
    }
    return apiType
}

relay/constant/relay_mode.go (new file, 50 lines)
@@ -0,0 +1,50 @@
package constant

import "strings"

const (
    RelayModeUnknown = iota
    RelayModeChatCompletions
    RelayModeCompletions
    RelayModeEmbeddings
    RelayModeModerations
    RelayModeImagesGenerations
    RelayModeEdits
    RelayModeMidjourneyImagine
    RelayModeMidjourneyDescribe
    RelayModeMidjourneyBlend
    RelayModeMidjourneyChange
    RelayModeMidjourneySimpleChange
    RelayModeMidjourneyNotify
    RelayModeMidjourneyTaskFetch
    RelayModeMidjourneyTaskFetchByCondition
    RelayModeAudioSpeech
    RelayModeAudioTranscription
    RelayModeAudioTranslation
)

func Path2RelayMode(path string) int {
    relayMode := RelayModeUnknown
    if strings.HasPrefix(path, "/v1/chat/completions") {
        relayMode = RelayModeChatCompletions
    } else if strings.HasPrefix(path, "/v1/completions") {
        relayMode = RelayModeCompletions
    } else if strings.HasPrefix(path, "/v1/embeddings") {
        relayMode = RelayModeEmbeddings
    } else if strings.HasSuffix(path, "embeddings") {
        relayMode = RelayModeEmbeddings
    } else if strings.HasPrefix(path, "/v1/moderations") {
        relayMode = RelayModeModerations
    } else if strings.HasPrefix(path, "/v1/images/generations") {
        relayMode = RelayModeImagesGenerations
    } else if strings.HasPrefix(path, "/v1/edits") {
        relayMode = RelayModeEdits
    } else if strings.HasPrefix(path, "/v1/audio/speech") {
        relayMode = RelayModeAudioSpeech
    } else if strings.HasPrefix(path, "/v1/audio/transcriptions") {
        relayMode = RelayModeAudioTranscription
    } else if strings.HasPrefix(path, "/v1/audio/translations") {
        relayMode = RelayModeAudioTranslation
    }
    return relayMode
}

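Dispatch is ordered prefix matching on the request path, with one suffix rule to catch Azure-style embedding paths. For example:

Path2RelayMode("/v1/chat/completions")                 // RelayModeChatCompletions
Path2RelayMode("/v1/audio/transcriptions")             // RelayModeAudioTranscription
Path2RelayMode("/openai/deployments/gpt-4/embeddings") // RelayModeEmbeddings, via the HasSuffix fallback
Path2RelayMode("/mj/submit/imagine")                   // RelayModeUnknown: the Midjourney modes exist above but are not mapped by path here
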
@@ -1,4 +1,4 @@
-package controller
+package relay
 
 import (
 	"bytes"
@@ -10,7 +10,11 @@ import (
 	"io"
 	"net/http"
 	"one-api/common"
+	"one-api/controller"
+	"one-api/dto"
 	"one-api/model"
+	relaycommon "one-api/relay/common"
+	"one-api/service"
 	"strings"
 	"time"
 )
@@ -24,7 +27,7 @@ var availableVoices = []string{
 	"shimmer",
 }
 
-func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
+func RelayAudioHelper(c *gin.Context, relayMode int) *controller.OpenAIErrorWithStatusCode {
 	tokenId := c.GetInt("token_id")
 	channelType := c.GetInt("channel")
 	channelId := c.GetInt("channel_id")
@@ -36,7 +39,7 @@ func RelayAudioHelper(c *gin.Context, relayMode int) *controller.OpenAIErrorWithStatusCode {
 	if !strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") {
 		err := common.UnmarshalBodyReusable(c, &audioRequest)
 		if err != nil {
-			return errorWrapper(err, "bind_request_body_failed", http.StatusBadRequest)
+			return service.OpenAIErrorWrapper(err, "bind_request_body_failed", http.StatusBadRequest)
 		}
 	} else {
 		audioRequest = AudioRequest{
@@ -47,15 +50,15 @@ func RelayAudioHelper(c *gin.Context, relayMode int) *controller.OpenAIErrorWithStatusCode {
 
 	// request validation
 	if audioRequest.Model == "" {
-		return errorWrapper(errors.New("model is required"), "required_field_missing", http.StatusBadRequest)
+		return service.OpenAIErrorWrapper(errors.New("model is required"), "required_field_missing", http.StatusBadRequest)
 	}
 
 	if strings.HasPrefix(audioRequest.Model, "tts-1") {
 		if audioRequest.Voice == "" {
-			return errorWrapper(errors.New("voice is required"), "required_field_missing", http.StatusBadRequest)
+			return service.OpenAIErrorWrapper(errors.New("voice is required"), "required_field_missing", http.StatusBadRequest)
 		}
 		if !common.StringsContains(availableVoices, audioRequest.Voice) {
-			return errorWrapper(errors.New("voice must be one of "+strings.Join(availableVoices, ", ")), "invalid_field_value", http.StatusBadRequest)
+			return service.OpenAIErrorWrapper(errors.New("voice must be one of "+strings.Join(availableVoices, ", ")), "invalid_field_value", http.StatusBadRequest)
 		}
 	}
 
@@ -66,14 +69,14 @@ func RelayAudioHelper(c *gin.Context, relayMode int) *controller.OpenAIErrorWithStatusCode {
 	preConsumedQuota := int(float64(preConsumedTokens) * ratio)
 	userQuota, err := model.CacheGetUserQuota(userId)
 	if err != nil {
-		return errorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError)
+		return service.OpenAIErrorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError)
 	}
 	if userQuota-preConsumedQuota < 0 {
-		return errorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden)
+		return service.OpenAIErrorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden)
 	}
 	err = model.CacheDecreaseUserQuota(userId, preConsumedQuota)
 	if err != nil {
-		return errorWrapper(err, "decrease_user_quota_failed", http.StatusInternalServerError)
+		return service.OpenAIErrorWrapper(err, "decrease_user_quota_failed", http.StatusInternalServerError)
 	}
 	if userQuota > 100*preConsumedQuota {
 		// in this case, we do not pre-consume quota
@@ -83,7 +86,7 @@ func RelayAudioHelper(c *gin.Context, relayMode int) *controller.OpenAIErrorWithStatusCode {
 	if preConsumedQuota > 0 {
 		userQuota, err = model.PreConsumeTokenQuota(tokenId, preConsumedQuota)
 		if err != nil {
-			return errorWrapper(err, "pre_consume_token_quota_failed", http.StatusForbidden)
+			return service.OpenAIErrorWrapper(err, "pre_consume_token_quota_failed", http.StatusForbidden)
 		}
 	}
 
@@ -93,7 +96,7 @@ func RelayAudioHelper(c *gin.Context, relayMode int) *controller.OpenAIErrorWithStatusCode {
 		modelMap := make(map[string]string)
 		err := json.Unmarshal([]byte(modelMapping), &modelMap)
 		if err != nil {
-			return errorWrapper(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError)
+			return service.OpenAIErrorWrapper(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError)
 		}
 		if modelMap[audioRequest.Model] != "" {
 			audioRequest.Model = modelMap[audioRequest.Model]
@@ -106,10 +109,10 @@ func RelayAudioHelper(c *gin.Context, relayMode int) *controller.OpenAIErrorWithStatusCode {
 		baseURL = c.GetString("base_url")
 	}
 
-	fullRequestURL := getFullRequestURL(baseURL, requestURL, channelType)
+	fullRequestURL := relaycommon.GetFullRequestURL(baseURL, requestURL, channelType)
 	if relayMode == RelayModeAudioTranscription && channelType == common.ChannelTypeAzure {
 		// https://learn.microsoft.com/en-us/azure/ai-services/openai/whisper-quickstart?tabs=command-line#rest-api
-		apiVersion := GetAPIVersion(c)
+		apiVersion := relaycommon.GetAPIVersion(c)
 		fullRequestURL = fmt.Sprintf("%s/openai/deployments/%s/audio/transcriptions?api-version=%s", baseURL, audioRequest.Model, apiVersion)
 	}
 
@@ -117,7 +120,7 @@ func RelayAudioHelper(c *gin.Context, relayMode int) *controller.OpenAIErrorWithStatusCode {
 
 	req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
 	if err != nil {
-		return errorWrapper(err, "new_request_failed", http.StatusInternalServerError)
+		return service.OpenAIErrorWrapper(err, "new_request_failed", http.StatusInternalServerError)
 	}
 
 	if relayMode == RelayModeAudioTranscription && channelType == common.ChannelTypeAzure {
@@ -133,25 +136,25 @@ func RelayAudioHelper(c *gin.Context, relayMode int) *controller.OpenAIErrorWithStatusCode {
 	req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
 	req.Header.Set("Accept", c.Request.Header.Get("Accept"))
 
-	resp, err := httpClient.Do(req)
+	resp, err := service.GetHttpClient().Do(req)
 	if err != nil {
-		return errorWrapper(err, "do_request_failed", http.StatusInternalServerError)
+		return service.OpenAIErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)
 	}
 
 	err = req.Body.Close()
 	if err != nil {
-		return errorWrapper(err, "close_request_body_failed", http.StatusInternalServerError)
+		return service.OpenAIErrorWrapper(err, "close_request_body_failed", http.StatusInternalServerError)
 	}
 	err = c.Request.Body.Close()
 	if err != nil {
-		return errorWrapper(err, "close_request_body_failed", http.StatusInternalServerError)
+		return service.OpenAIErrorWrapper(err, "close_request_body_failed", http.StatusInternalServerError)
 	}
 
 	if resp.StatusCode != http.StatusOK {
-		return relayErrorHandler(resp)
+		return relaycommon.RelayErrorHandler(resp)
 	}
 
-	var audioResponse AudioResponse
+	var audioResponse dto.AudioResponse
 
 	defer func(ctx context.Context) {
 		go func() {
@@ -159,10 +162,10 @@ func RelayAudioHelper(c *gin.Context, relayMode int) *controller.OpenAIErrorWithStatusCode {
 			quota := 0
 			var promptTokens = 0
 			if strings.HasPrefix(audioRequest.Model, "tts-1") {
-				quota = countAudioToken(audioRequest.Input, audioRequest.Model)
+				quota = service.CountAudioToken(audioRequest.Input, audioRequest.Model)
 				promptTokens = quota
 			} else {
-				quota = countAudioToken(audioResponse.Text, audioRequest.Model)
+				quota = service.CountAudioToken(audioResponse.Text, audioRequest.Model)
 			}
 			quota = int(float64(quota) * ratio)
 			if ratio != 0 && quota <= 0 {
@@ -191,18 +194,18 @@ func RelayAudioHelper(c *gin.Context, relayMode int) *controller.OpenAIErrorWithStatusCode {
 	responseBody, err := io.ReadAll(resp.Body)
 
 	if err != nil {
-		return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
+		return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
 	}
 	err = resp.Body.Close()
 	if err != nil {
-		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError)
+		return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError)
 	}
 	if strings.HasPrefix(audioRequest.Model, "tts-1") {
 
 	} else {
 		err = json.Unmarshal(responseBody, &audioResponse)
 		if err != nil {
-			return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError)
+			return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError)
 		}
 	}
 
@@ -215,11 +218,11 @@ func RelayAudioHelper(c *gin.Context, relayMode int) *controller.OpenAIErrorWithStatusCode {
 
 	_, err = io.Copy(c.Writer, resp.Body)
 	if err != nil {
-		return errorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError)
+		return service.OpenAIErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError)
 	}
 	err = resp.Body.Close()
 	if err != nil {
-		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError)
+		return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError)
 	}
 	return nil
 }

@@ -1,4 +1,4 @@
-package controller
+package relay
 
 import (
 	"bytes"
@@ -10,12 +10,16 @@ import (
 	"io"
 	"net/http"
 	"one-api/common"
+	"one-api/controller"
+	"one-api/dto"
 	"one-api/model"
+	relaycommon "one-api/relay/common"
+	"one-api/service"
 	"strings"
 	"time"
 )
 
-func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
+func RelayImageHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode {
 	tokenId := c.GetInt("token_id")
 	channelType := c.GetInt("channel")
 	channelId := c.GetInt("channel_id")
@@ -24,7 +27,7 @@ func RelayImageHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode {
 	group := c.GetString("group")
 	startTime := time.Now()
 
-	var imageRequest ImageRequest
+	var imageRequest dto.ImageRequest
 	if consumeQuota {
 		err := common.UnmarshalBodyReusable(c, &imageRequest)
 		if err != nil {
@@ -90,7 +93,7 @@ func RelayImageHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode {
 	fullRequestURL := getFullRequestURL(baseURL, requestURL, channelType)
 	if channelType == common.ChannelTypeAzure && relayMode == RelayModeImagesGenerations {
 		// https://learn.microsoft.com/en-us/azure/ai-services/openai/dall-e-quickstart?tabs=dalle3%2Ccommand-line&pivots=rest-api
-		apiVersion := GetAPIVersion(c)
+		apiVersion := relaycommon.GetAPIVersion(c)
 		// https://{resource_name}.openai.azure.com/openai/deployments/dall-e-3/images/generations?api-version=2023-06-01-preview
 		fullRequestURL = fmt.Sprintf("%s/openai/deployments/%s/images/generations?api-version=%s", baseURL, imageRequest.Model, apiVersion)
 	}
@@ -151,7 +154,7 @@ func RelayImageHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode {
 	req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
 	req.Header.Set("Accept", c.Request.Header.Get("Accept"))
 
-	resp, err := httpClient.Do(req)
+	resp, err := service.GetHttpClient().Do(req)
 	if err != nil {
 		return errorWrapper(err, "do_request_failed", http.StatusInternalServerError)
 	}

@@ -1,4 +1,4 @@
-package controller
+package relay
 
 import (
 	"bytes"
@@ -9,6 +9,8 @@ import (
 	"log"
 	"net/http"
 	"one-api/common"
+	"one-api/controller"
 	"one-api/model"
+	"one-api/service"
 	"strconv"
 	"strings"
@@ -104,7 +105,7 @@ func RelayMidjourneyImage(c *gin.Context) {
 	return
 }
 
-func relayMidjourneyNotify(c *gin.Context) *MidjourneyResponse {
+func RelayMidjourneyNotify(c *gin.Context) *MidjourneyResponse {
 	var midjRequest Midjourney
 	err := common.UnmarshalBodyReusable(c, &midjRequest)
 	if err != nil {
@@ -167,7 +168,7 @@ func getMidjourneyTaskModel(c *gin.Context, originTask *model.Midjourney) (midjourneyTask ...) {
 	return
 }
 
-func relayMidjourneyTask(c *gin.Context, relayMode int) *MidjourneyResponse {
+func RelayMidjourneyTask(c *gin.Context, relayMode int) *MidjourneyResponse {
 	userId := c.GetInt("id")
 	var err error
 	var respBody []byte
@@ -244,7 +245,7 @@ const (
 	MJSubmitActionUpscale = "UPSCALE" // upscale
 )
 
-func relayMidjourneySubmit(c *gin.Context, relayMode int) *MidjourneyResponse {
+func RelayMidjourneySubmit(c *gin.Context, relayMode int) *MidjourneyResponse {
 	imageModel := "midjourney"
 
 	tokenId := c.GetInt("token_id")
@@ -427,21 +428,21 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *MidjourneyResponse {
 			Description: "create_request_failed",
 		}
 	}
-	//req.Header.Set("Authorization", c.Request.Header.Get("Authorization"))
+	//req.Header.Set("ApiKey", c.Request.Header.Get("ApiKey"))
 
 	req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
 	req.Header.Set("Accept", c.Request.Header.Get("Accept"))
 	//mjToken := ""
-	//if c.Request.Header.Get("Authorization") != "" {
-	//	mjToken = strings.Split(c.Request.Header.Get("Authorization"), " ")[1]
+	//if c.Request.Header.Get("ApiKey") != "" {
+	//	mjToken = strings.Split(c.Request.Header.Get("ApiKey"), " ")[1]
 	//}
-	//req.Header.Set("Authorization", "Bearer midjourney-proxy")
+	//req.Header.Set("ApiKey", "Bearer midjourney-proxy")
 	req.Header.Set("mj-api-secret", strings.Split(c.Request.Header.Get("Authorization"), " ")[1])
 	// print request header
 	log.Printf("request header: %s", req.Header)
 	log.Printf("request body: %s", midjRequest.Prompt)
 
-	resp, err := httpClient.Do(req)
+	resp, err := service.GetHttpClient().Do(req)
 	if err != nil {
 		return &MidjourneyResponse{
 			Code: 4,

relay/relay-text.go (new file, 277 lines)
@@ -0,0 +1,277 @@
package relay

import (
    "bytes"
    "encoding/json"
    "errors"
    "fmt"
    "io"
    "math"
    "net/http"
    "one-api/common"
    "one-api/dto"
    "one-api/model"
    relaychannel "one-api/relay/channel"
    relaycommon "one-api/relay/common"
    relayconstant "one-api/relay/constant"
    "one-api/service"
    "strings"
    "time"

    "github.com/gin-gonic/gin"
)

func getAndValidateTextRequest(c *gin.Context, relayInfo *relaycommon.RelayInfo) (*dto.GeneralOpenAIRequest, error) {
    textRequest := &dto.GeneralOpenAIRequest{}
    err := common.UnmarshalBodyReusable(c, textRequest)
    if err != nil {
        return nil, err
    }
    if relayInfo.RelayMode == relayconstant.RelayModeModerations && textRequest.Model == "" {
        textRequest.Model = "text-moderation-latest"
    }
    if relayInfo.RelayMode == relayconstant.RelayModeEmbeddings && textRequest.Model == "" {
        textRequest.Model = c.Param("model")
    }

    if textRequest.MaxTokens < 0 || textRequest.MaxTokens > math.MaxInt32/2 {
        return nil, errors.New("max_tokens is invalid")
    }
    if textRequest.Model == "" {
        return nil, errors.New("model is required")
    }
    switch relayInfo.RelayMode {
    case relayconstant.RelayModeCompletions:
        if textRequest.Prompt == "" {
            return nil, errors.New("field prompt is required")
        }
    case relayconstant.RelayModeChatCompletions:
        if textRequest.Messages == nil || len(textRequest.Messages) == 0 {
            return nil, errors.New("field messages is required")
        }
    case relayconstant.RelayModeEmbeddings:
    case relayconstant.RelayModeModerations:
        if textRequest.Input == "" {
            return nil, errors.New("field input is required")
        }
    case relayconstant.RelayModeEdits:
        if textRequest.Instruction == "" {
            return nil, errors.New("field instruction is required")
        }
    }
    relayInfo.IsStream = textRequest.Stream
    return textRequest, nil
}

func TextHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {

    relayInfo := relaycommon.GenRelayInfo(c)

    // get & validate the text request
    textRequest, err := getAndValidateTextRequest(c, relayInfo)
    if err != nil {
        common.LogError(c, fmt.Sprintf("getAndValidateTextRequest failed: %s", err.Error()))
        return service.OpenAIErrorWrapper(err, "invalid_text_request", http.StatusBadRequest)
    }

    // map model name
    modelMapping := c.GetString("model_mapping")
    isModelMapped := false
    if modelMapping != "" && modelMapping != "{}" {
        modelMap := make(map[string]string)
        err := json.Unmarshal([]byte(modelMapping), &modelMap)
        if err != nil {
            return service.OpenAIErrorWrapper(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError)
        }
        if modelMap[textRequest.Model] != "" {
            textRequest.Model = modelMap[textRequest.Model]
            isModelMapped = true
        }
    }
    modelPrice := common.GetModelPrice(textRequest.Model, false)
    groupRatio := common.GetGroupRatio(relayInfo.Group)

    var preConsumedQuota int
    var ratio float64
    var modelRatio float64
    promptTokens, err := getPromptTokens(textRequest, relayInfo)

    // counting prompt tokens failed
    if err != nil {
        return service.OpenAIErrorWrapper(err, "count_token_messages_failed", http.StatusInternalServerError)
    }

    if modelPrice == -1 {
        preConsumedTokens := common.PreConsumedQuota
        if textRequest.MaxTokens != 0 {
            preConsumedTokens = promptTokens + int(textRequest.MaxTokens)
        }
        modelRatio = common.GetModelRatio(textRequest.Model)
        ratio = modelRatio * groupRatio
        preConsumedQuota = int(float64(preConsumedTokens) * ratio)
    } else {
        preConsumedQuota = int(modelPrice * common.QuotaPerUnit * groupRatio)
    }

    // pre-consume quota
    userQuota, openaiErr := preConsumeQuota(c, preConsumedQuota, relayInfo)
    if openaiErr != nil {
        return openaiErr
    }

    adaptor := relaychannel.GetAdaptor(relayInfo.ApiType)
    if adaptor == nil {
        return service.OpenAIErrorWrapper(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), "invalid_api_type", http.StatusBadRequest)
    }
    adaptor.Init(relayInfo, *textRequest)
    var requestBody io.Reader
    if relayInfo.ApiType == relayconstant.APITypeOpenAI {
        if isModelMapped {
            jsonStr, err := json.Marshal(textRequest)
            if err != nil {
                return service.OpenAIErrorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
            }
            requestBody = bytes.NewBuffer(jsonStr)
        } else {
            requestBody = c.Request.Body
        }
    } else {
        convertedRequest, err := adaptor.ConvertRequest(c, relayInfo.RelayMode, textRequest)
        if err != nil {
            return service.OpenAIErrorWrapper(err, "convert_request_failed", http.StatusInternalServerError)
        }
        jsonData, err := json.Marshal(convertedRequest)
        if err != nil {
            return service.OpenAIErrorWrapper(err, "json_marshal_failed", http.StatusInternalServerError)
        }
        requestBody = bytes.NewBuffer(jsonData)
    }

    resp, err := adaptor.DoRequest(c, relayInfo, requestBody)
    if err != nil {
        return service.OpenAIErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)
    }
    relayInfo.IsStream = relayInfo.IsStream || strings.HasPrefix(resp.Header.Get("Content-Type"), "text/event-stream")

    usage, openaiErr := adaptor.DoResponse(c, resp, relayInfo)
    if openaiErr != nil {
        return openaiErr
    }

    postConsumeQuota(c, relayInfo, *textRequest, usage, ratio, preConsumedQuota, userQuota, modelRatio, groupRatio, modelPrice)
    return nil
}

func getPromptTokens(textRequest *dto.GeneralOpenAIRequest, info *relaycommon.RelayInfo) (int, error) {
    var promptTokens int
    var err error

    switch info.RelayMode {
    case relayconstant.RelayModeChatCompletions:
        promptTokens, err = service.CountTokenMessages(textRequest.Messages, textRequest.Model)
    case relayconstant.RelayModeCompletions:
        promptTokens, err = service.CountTokenInput(textRequest.Prompt, textRequest.Model), nil
    case relayconstant.RelayModeModerations:
        promptTokens, err = service.CountTokenInput(textRequest.Input, textRequest.Model), nil
    default:
        err = errors.New("unknown relay mode")
        promptTokens = 0
    }
    info.PromptTokens = promptTokens
    return promptTokens, err
}

// Pre-consume quota and return the user's remaining quota.
func preConsumeQuota(c *gin.Context, preConsumedQuota int, relayInfo *relaycommon.RelayInfo) (int, *dto.OpenAIErrorWithStatusCode) {
    userQuota, err := model.CacheGetUserQuota(relayInfo.UserId)
    if err != nil {
        return 0, service.OpenAIErrorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError)
    }
    if userQuota < 0 || userQuota-preConsumedQuota < 0 {
        return 0, service.OpenAIErrorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden)
    }
    err = model.CacheDecreaseUserQuota(relayInfo.UserId, preConsumedQuota)
    if err != nil {
        return 0, service.OpenAIErrorWrapper(err, "decrease_user_quota_failed", http.StatusInternalServerError)
    }
    if userQuota > 100*preConsumedQuota {
        // the user has enough quota; now check whether the token does too
        if !relayInfo.TokenUnlimited {
            // not an unlimited token: check the token's own quota
            tokenQuota := c.GetInt("token_quota")
            if tokenQuota > 100*preConsumedQuota {
                // token quota is sufficient, trust the token
                preConsumedQuota = 0
                common.LogInfo(c.Request.Context(), fmt.Sprintf("user %d quota %d and token %d quota %d are enough, trusted and no need to pre-consume", relayInfo.UserId, userQuota, relayInfo.TokenId, tokenQuota))
            }
        } else {
            // in this case, we do not pre-consume quota
            // because the user has enough quota
            preConsumedQuota = 0
            common.LogInfo(c.Request.Context(), fmt.Sprintf("user %d with unlimited token has enough quota %d, trusted and no need to pre-consume", relayInfo.UserId, userQuota))
        }
    }
    if preConsumedQuota > 0 {
        userQuota, err = model.PreConsumeTokenQuota(relayInfo.TokenId, preConsumedQuota)
        if err != nil {
            return 0, service.OpenAIErrorWrapper(err, "pre_consume_token_quota_failed", http.StatusForbidden)
        }
    }
    return userQuota, nil
}

func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, textRequest dto.GeneralOpenAIRequest, usage *dto.Usage, ratio float64, preConsumedQuota int, userQuota int, modelRatio float64, groupRatio float64, modelPrice float64) {
    useTimeSeconds := time.Now().Unix() - relayInfo.StartTime.Unix()
    promptTokens := usage.PromptTokens
    completionTokens := usage.CompletionTokens

    tokenName := ctx.GetString("token_name")

    quota := 0
    if modelPrice == -1 {
        completionRatio := common.GetCompletionRatio(textRequest.Model)
        quota = promptTokens + int(float64(completionTokens)*completionRatio)
        quota = int(float64(quota) * ratio)
        if ratio != 0 && quota <= 0 {
            quota = 1
        }
    } else {
        quota = int(modelPrice * common.QuotaPerUnit * groupRatio)
    }
    totalTokens := promptTokens + completionTokens
    var logContent string
    if modelPrice == -1 {
        logContent = fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
    } else {
        logContent = fmt.Sprintf("模型价格 %.2f,分组倍率 %.2f", modelPrice, groupRatio)
    }

    // record all the consume log even if quota is 0
    if totalTokens == 0 {
        // in this case, must be some error happened
        // we cannot just return, because we may have to return the pre-consumed quota
        quota = 0
        logContent += "(可能是上游超时)"
        common.LogError(ctx, fmt.Sprintf("total tokens is 0, cannot consume quota, userId %d, channelId %d, tokenId %d, model %s, pre-consumed quota %d", relayInfo.UserId, relayInfo.ChannelId, relayInfo.TokenId, textRequest.Model, preConsumedQuota))
    } else {
        quotaDelta := quota - preConsumedQuota
        err := model.PostConsumeTokenQuota(relayInfo.TokenId, userQuota, quotaDelta, preConsumedQuota, true)
        if err != nil {
            common.LogError(ctx, "error consuming token remain quota: "+err.Error())
        }
        err = model.CacheUpdateUserQuota(relayInfo.UserId)
        if err != nil {
            common.LogError(ctx, "error update user quota cache: "+err.Error())
        }
        model.UpdateUserUsedQuotaAndRequestCount(relayInfo.UserId, quota)
        model.UpdateChannelUsedQuota(relayInfo.ChannelId, quota)
    }

    logModel := textRequest.Model
    if strings.HasPrefix(logModel, "gpt-4-gizmo") {
        logModel = "gpt-4-gizmo-*"
        logContent += fmt.Sprintf(",模型 %s", textRequest.Model)
    }
    model.RecordConsumeLog(ctx, relayInfo.UserId, relayInfo.ChannelId, promptTokens, completionTokens, logModel, tokenName, quota, logContent, relayInfo.TokenId, userQuota, int(useTimeSeconds), relayInfo.IsStream)

    //if quota != 0 {
    //
    //}
}

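To make the billing arithmetic in postConsumeQuota concrete, a worked example with illustrative numbers (not project defaults): say promptTokens = 100, completionTokens = 200, completionRatio = 2.0, modelRatio = 15 and groupRatio = 1.0, so ratio = 15.

quota := 100 + int(float64(200)*2.0) // 500 completion-weighted tokens
quota = int(float64(quota) * 15.0)   // 7500 quota units
quotaDelta := quota - 6000           // assuming preConsumedQuota was 6000, the remaining 1500 is settled via PostConsumeTokenQuota
_ = quotaDelta
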
@@ -1,10 +1,10 @@
 package router
 
 import (
+	"github.com/gin-gonic/gin"
 	"one-api/controller"
 	"one-api/middleware"
-	"github.com/gin-gonic/gin"
+	"one-api/relay"
 )
 
 func SetRelayRouter(router *gin.Engine) {
@@ -44,7 +44,7 @@ func SetRelayRouter(router *gin.Engine) {
 		relayV1Router.POST("/moderations", controller.Relay)
 	}
 	relayMjRouter := router.Group("/mj")
-	relayMjRouter.GET("/image/:id", controller.RelayMidjourneyImage)
+	relayMjRouter.GET("/image/:id", relay.RelayMidjourneyImage)
 	relayMjRouter.Use(middleware.TokenAuth(), middleware.Distribute())
 	{
 		relayMjRouter.POST("/submit/imagine", controller.RelayMidjourney)

service/channel.go (new file, 53 lines)
@@ -0,0 +1,53 @@
package service

import (
    "fmt"
    "net/http"
    "one-api/common"
    relaymodel "one-api/dto"
    "one-api/model"
)

// disable & notify
func DisableChannel(channelId int, channelName string, reason string) {
    model.UpdateChannelStatusById(channelId, common.ChannelStatusAutoDisabled)
    subject := fmt.Sprintf("通道「%s」(#%d)已被禁用", channelName, channelId)
    content := fmt.Sprintf("通道「%s」(#%d)已被禁用,原因:%s", channelName, channelId, reason)
    notifyRootUser(subject, content)
}

func EnableChannel(channelId int, channelName string) {
    model.UpdateChannelStatusById(channelId, common.ChannelStatusEnabled)
    subject := fmt.Sprintf("通道「%s」(#%d)已被启用", channelName, channelId)
    content := fmt.Sprintf("通道「%s」(#%d)已被启用", channelName, channelId)
    notifyRootUser(subject, content)
}

func ShouldDisableChannel(err *relaymodel.OpenAIError, statusCode int) bool {
    if !common.AutomaticDisableChannelEnabled {
        return false
    }
    if err == nil {
        return false
    }
    if statusCode == http.StatusUnauthorized {
        return true
    }
    if err.Type == "insufficient_quota" || err.Code == "invalid_api_key" || err.Code == "account_deactivated" || err.Code == "billing_not_active" {
        return true
    }
    return false
}

func ShouldEnableChannel(err error, openAIErr *relaymodel.OpenAIError) bool {
    if !common.AutomaticEnableChannelEnabled {
        return false
    }
    if err != nil {
        return false
    }
    if openAIErr != nil {
        return false
    }
    return true
}

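The channel-test loop elsewhere in this commit is expected to pair these predicates with DisableChannel/EnableChannel. A minimal sketch, where the error literal and the channelId/channelName variables stand in for the probed channel's actual fields:

openaiErr := &relaymodel.OpenAIError{Code: "invalid_api_key"}
if ShouldDisableChannel(openaiErr, http.StatusUnauthorized) {
    // a 401 alone is sufficient; so is the "invalid_api_key" error code
    DisableChannel(channelId, channelName, "invalid api key")
}
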
service/error.go (new file, 29 lines)
@@ -0,0 +1,29 @@
package service

import (
    "fmt"
    "one-api/common"
    "one-api/dto"
    "strings"
)

// OpenAIErrorWrapper wraps an error into an OpenAIErrorWithStatusCode
func OpenAIErrorWrapper(err error, code string, statusCode int) *dto.OpenAIErrorWithStatusCode {
    text := err.Error()
    // mask upstream request failures so internal addresses and errors are not exposed to clients
    if strings.Contains(text, "Post") {
        common.SysLog(fmt.Sprintf("error: %s", text))
        text = "请求上游地址失败"
    }

    openAIError := dto.OpenAIError{
        Message: text,
        Type:    "new_api_error",
        Code:    code,
    }
    return &dto.OpenAIErrorWithStatusCode{
        OpenAIError: openAIError,
        StatusCode:  statusCode,
    }
}

32
service/http_client.go
Normal file
32
service/http_client.go
Normal file
@ -0,0 +1,32 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"one-api/common"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
var httpClient *http.Client
|
||||||
|
var impatientHTTPClient *http.Client
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
if common.RelayTimeout == 0 {
|
||||||
|
httpClient = &http.Client{}
|
||||||
|
} else {
|
||||||
|
httpClient = &http.Client{
|
||||||
|
Timeout: time.Duration(common.RelayTimeout) * time.Second,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impatientHTTPClient = &http.Client{
|
||||||
|
Timeout: 5 * time.Second,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetHttpClient() *http.Client {
|
||||||
|
return httpClient
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetImpatientHttpClient() *http.Client {
|
||||||
|
return impatientHTTPClient
|
||||||
|
}
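The split gives callers a deliberate choice: the default client honors the configurable RelayTimeout (or has no timeout at all), while the impatient client caps side-channel probes at five seconds. A sketch of the intended division of labor; probeUpstream is illustrative:

func probeUpstream(url string) error {
	req, err := http.NewRequest(http.MethodGet, url, nil)
	if err != nil {
		return err
	}
	// balance checks and health probes should fail fast, not tie up relay workers
	res, err := GetImpatientHttpClient().Do(req)
	if err != nil {
		return err
	}
	return res.Body.Close()
}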
11 service/sse.go Normal file
@ -0,0 +1,11 @@
package service

import "github.com/gin-gonic/gin"

func SetEventStreamHeaders(c *gin.Context) {
	c.Writer.Header().Set("Content-Type", "text/event-stream")
	c.Writer.Header().Set("Cache-Control", "no-cache")
	c.Writer.Header().Set("Connection", "keep-alive")
	c.Writer.Header().Set("Transfer-Encoding", "chunked")
	c.Writer.Header().Set("X-Accel-Buffering", "no")
}
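A streaming relay handler would call this once before writing any events; the X-Accel-Buffering header additionally tells nginx not to buffer the stream. A minimal illustrative handler, not part of this commit:

func streamDemo(c *gin.Context) {
	SetEventStreamHeaders(c)
	// each SSE event is a "data:" line terminated by a blank line
	c.Writer.WriteString("data: {\"choices\":[{\"delta\":{\"content\":\"hi\"}}]}\n\n")
	c.Writer.WriteString("data: [DONE]\n\n")
	c.Writer.Flush()
}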
@ -1,27 +1,19 @@
-package controller
+package service

 import (
 	"encoding/json"
 	"errors"
 	"fmt"
-	"github.com/gin-gonic/gin"
 	"github.com/pkoukk/tiktoken-go"
 	"image"
-	_ "image/gif"
-	_ "image/jpeg"
-	_ "image/png"
-	"io"
 	"log"
 	"math"
-	"net/http"
 	"one-api/common"
-	"strconv"
+	"one-api/dto"
 	"strings"
 	"unicode/utf8"
 )

-var stopFinishReason = "stop"
-
 // tokenEncoderMap won't grow after initialization
 var tokenEncoderMap = map[string]*tiktoken.Tiktoken{}
 var defaultTokenEncoder *tiktoken.Tiktoken
@ -70,7 +62,7 @@ func getTokenNum(tokenEncoder *tiktoken.Tiktoken, text string) int {
 	return len(tokenEncoder.Encode(text, nil, nil))
 }

-func getImageToken(imageUrl *MessageImageUrl) (int, error) {
+func getImageToken(imageUrl *dto.MessageImageUrl) (int, error) {
 	if imageUrl.Detail == "low" {
 		return 85, nil
 	}
@ -124,7 +116,7 @@ func getImageToken(imageUrl *MessageImageUrl) (int, error) {
 	return tiles*170 + 85, nil
 }

-func countTokenMessages(messages []Message, model string) (int, error) {
+func CountTokenMessages(messages []dto.Message, model string) (int, error) {
 	//recover when panic
 	tokenEncoder := getTokenEncoder(model)
 	// Reference:
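The tile formula above matches OpenAI's published vision pricing: a low-detail image costs a flat 85 tokens, while high detail costs 170 tokens per 512-px tile on top of the 85-token base. A worked example, assuming the standard scaling rules (which sit in the elided part of getImageToken):

func exampleImageTokens() (high, low int) {
	tiles := 4            // a 1024x1024 "high"-detail image scales to 768x768 = 2x2 tiles of 512px
	high = tiles*170 + 85 // 4*170 + 85 = 765 tokens
	low = 85              // "low" detail is a flat 85 tokens regardless of size
	return high, low
}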
@ -146,7 +138,7 @@ func countTokenMessages(messages []Message, model string) (int, error) {
 		tokenNum += tokensPerMessage
 		tokenNum += getTokenNum(tokenEncoder, message.Role)
 		if len(message.Content) > 0 {
-			var arrayContent []MediaMessage
+			var arrayContent []dto.MediaMessage
 			if err := json.Unmarshal(message.Content, &arrayContent); err != nil {
 				var stringContent string
 				if err := json.Unmarshal(message.Content, &stringContent); err != nil {
@ -163,7 +155,7 @@ func countTokenMessages(messages []Message, model string) (int, error) {
 				if m.Type == "image_url" {
 					var imageTokenNum int
 					if str, ok := m.ImageUrl.(string); ok {
-						imageTokenNum, err = getImageToken(&MessageImageUrl{Url: str, Detail: "auto"})
+						imageTokenNum, err = getImageToken(&dto.MessageImageUrl{Url: str, Detail: "auto"})
 					} else {
 						imageUrlMap := m.ImageUrl.(map[string]interface{})
 						detail, ok := imageUrlMap["detail"]
@ -172,7 +164,7 @@ func countTokenMessages(messages []Message, model string) (int, error) {
 					} else {
 						imageUrlMap["detail"] = "auto"
 					}
-					imageUrl := MessageImageUrl{
+					imageUrl := dto.MessageImageUrl{
 						Url:    imageUrlMap["url"].(string),
 						Detail: imageUrlMap["detail"].(string),
 					}
@ -195,16 +187,16 @@ func countTokenMessages(messages []Message, model string) (int, error) {
 	return tokenNum, nil
 }

-func countTokenInput(input any, model string) int {
+func CountTokenInput(input any, model string) int {
 	switch v := input.(type) {
 	case string:
-		return countTokenText(v, model)
+		return CountTokenText(v, model)
 	case []string:
 		text := ""
 		for _, s := range v {
 			text += s
 		}
-		return countTokenText(text, model)
+		return CountTokenText(text, model)
 	}
 	return 0
 }
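Renaming countTokenText and friends to the exported CountTokenText/CountTokenInput/CountTokenMessages is what lets the relay and controller packages share one tokenizer. A hypothetical caller from outside the package; names and model string are illustrative:

func estimateUsage() int {
	prompt := service.CountTokenInput([]string{"Hello, ", "world"}, "gpt-3.5-turbo")
	completion := service.CountTokenText("Hi there!", "gpt-3.5-turbo")
	return prompt + completion // ad-hoc estimate when the upstream reports no usage
}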
@ -213,118 +205,11 @@ func countAudioToken(text string, model string) int {
 	if strings.HasPrefix(model, "tts") {
 		return utf8.RuneCountInString(text)
 	} else {
-		return countTokenText(text, model)
+		return CountTokenText(text, model)
 	}
 }

-func countTokenText(text string, model string) int {
+func CountTokenText(text string, model string) int {
 	tokenEncoder := getTokenEncoder(model)
 	return getTokenNum(tokenEncoder, text)
 }

-func errorWrapper(err error, code string, statusCode int) *OpenAIErrorWithStatusCode {
-	text := err.Error()
-	// match the upstream URL in the error text
-	if strings.Contains(text, "Post") {
-		common.SysLog(fmt.Sprintf("error: %s", text))
-		text = "请求上游地址失败"
-	}
-	// avoid exposing internal errors
-
-	openAIError := OpenAIError{
-		Message: text,
-		Type:    "new_api_error",
-		Code:    code,
-	}
-	return &OpenAIErrorWithStatusCode{
-		OpenAIError: openAIError,
-		StatusCode:  statusCode,
-	}
-}
-
-func shouldDisableChannel(err *OpenAIError, statusCode int) bool {
-	if !common.AutomaticDisableChannelEnabled {
-		return false
-	}
-	if err == nil {
-		return false
-	}
-	if statusCode == http.StatusUnauthorized {
-		return true
-	}
-	if err.Type == "insufficient_quota" || err.Code == "invalid_api_key" || err.Code == "account_deactivated" || err.Code == "billing_not_active" {
-		return true
-	}
-	return false
-}
-
-func shouldEnableChannel(err error, openAIErr *OpenAIError) bool {
-	if !common.AutomaticEnableChannelEnabled {
-		return false
-	}
-	if err != nil {
-		return false
-	}
-	if openAIErr != nil {
-		return false
-	}
-	return true
-}
-
-func setEventStreamHeaders(c *gin.Context) {
-	c.Writer.Header().Set("Content-Type", "text/event-stream")
-	c.Writer.Header().Set("Cache-Control", "no-cache")
-	c.Writer.Header().Set("Connection", "keep-alive")
-	c.Writer.Header().Set("Transfer-Encoding", "chunked")
-	c.Writer.Header().Set("X-Accel-Buffering", "no")
-}
-
-func relayErrorHandler(resp *http.Response) (openAIErrorWithStatusCode *OpenAIErrorWithStatusCode) {
-	openAIErrorWithStatusCode = &OpenAIErrorWithStatusCode{
-		StatusCode: resp.StatusCode,
-		OpenAIError: OpenAIError{
-			Message: fmt.Sprintf("bad response status code %d", resp.StatusCode),
-			Type:    "upstream_error",
-			Code:    "bad_response_status_code",
-			Param:   strconv.Itoa(resp.StatusCode),
-		},
-	}
-	responseBody, err := io.ReadAll(resp.Body)
-	if err != nil {
-		return
-	}
-	err = resp.Body.Close()
-	if err != nil {
-		return
-	}
-	var textResponse TextResponse
-	err = json.Unmarshal(responseBody, &textResponse)
-	if err != nil {
-		return
-	}
-	openAIErrorWithStatusCode.OpenAIError = textResponse.Error
-	return
-}
-
-func getFullRequestURL(baseURL string, requestURL string, channelType int) string {
-	fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL)
-
-	if strings.HasPrefix(baseURL, "https://gateway.ai.cloudflare.com") {
-		switch channelType {
-		case common.ChannelTypeOpenAI:
-			fullRequestURL = fmt.Sprintf("%s%s", baseURL, strings.TrimPrefix(requestURL, "/v1"))
-		case common.ChannelTypeAzure:
-			fullRequestURL = fmt.Sprintf("%s%s", baseURL, strings.TrimPrefix(requestURL, "/openai/deployments"))
-		}
-	}
-	return fullRequestURL
-}
-
-func GetAPIVersion(c *gin.Context) string {
-	query := c.Request.URL.Query()
-	apiVersion := query.Get("api-version")
-	if apiVersion == "" {
-		apiVersion = c.GetString("api_version")
-	}
-	return apiVersion
-}
27 service/usage_helpr.go Normal file
@ -0,0 +1,27 @@
package service

import (
	"errors"
	"one-api/dto"
	"one-api/relay/constant"
)

func GetPromptTokens(textRequest dto.GeneralOpenAIRequest, relayMode int) (int, error) {
	switch relayMode {
	case constant.RelayModeChatCompletions:
		return CountTokenMessages(textRequest.Messages, textRequest.Model)
	case constant.RelayModeCompletions:
		return CountTokenInput(textRequest.Prompt, textRequest.Model), nil
	case constant.RelayModeModerations:
		return CountTokenInput(textRequest.Input, textRequest.Model), nil
	}
	return 0, errors.New("unknown relay mode")
}

func ResponseText2Usage(responseText string, modeName string, promptTokens int) *dto.Usage {
	usage := &dto.Usage{}
	usage.PromptTokens = promptTokens
	usage.CompletionTokens = CountTokenText(responseText, modeName)
	usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
	return usage
}
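These two helpers bracket a relayed request: GetPromptTokens prices the prompt before dispatch, and ResponseText2Usage reconstructs a usage block afterwards for upstreams that stream text without one. A sketch of that flow; accountUsage and the collected responseText are illustrative:

func accountUsage(req dto.GeneralOpenAIRequest, responseText string) (*dto.Usage, error) {
	promptTokens, err := GetPromptTokens(req, constant.RelayModeChatCompletions)
	if err != nil {
		return nil, err
	}
	// upstream sent no usage object, so derive one from the streamed text
	return ResponseText2Usage(responseText, req.Model, promptTokens), nil
}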
17 service/user_notify.go Normal file
@ -0,0 +1,17 @@
package service

import (
	"fmt"
	"one-api/common"
	"one-api/model"
)

func notifyRootUser(subject string, content string) {
	if common.RootUserEmail == "" {
		common.RootUserEmail = model.GetRootUserEmail()
	}
	err := common.SendEmail(subject, common.RootUserEmail, content)
	if err != nil {
		common.SysError(fmt.Sprintf("failed to send email: %s", err.Error()))
	}
}