feat: 初步重构 (initial refactoring)

Author: 1808837298@qq.com
Date: 2024-02-29 01:08:18 +08:00
parent 9b421478c1
commit 5b18cd6b0a
67 changed files with 2646 additions and 2243 deletions

View File

@@ -230,7 +230,7 @@ func StringsContains(strs []string, str string) bool {
 	return false
 }

-// []byte only read, panic on append
+// StringToByteSlice []byte only read, panic on append
 func StringToByteSlice(s string) []byte {
 	tmp1 := (*[2]uintptr)(unsafe.Pointer(&s))
 	tmp2 := [3]uintptr{tmp1[0], tmp1[1], tmp1[1]}
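For reference, a minimal runnable sketch of what this unsafe conversion does and why the result is read-only. The function body is taken from the hunk above; the final return line is the conventional completion of the pattern, since the hunk cuts off before it:

package main

import (
    "fmt"
    "unsafe"
)

// StringToByteSlice reinterprets the string header as a slice header
// (data, len, cap=len) without copying. The bytes live in the string's
// immutable backing memory, so mutating or appending to the result is
// undefined behavior and can fault at runtime.
func StringToByteSlice(s string) []byte {
    tmp1 := (*[2]uintptr)(unsafe.Pointer(&s))
    tmp2 := [3]uintptr{tmp1[0], tmp1[1], tmp1[1]}
    return *(*[]byte)(unsafe.Pointer(&tmp2))
}

func main() {
    b := StringToByteSlice("hello")
    fmt.Println(string(b)) // reading is fine: "hello"
    // b[0] = 'H'          // writing would touch immutable string memory
}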

View File

@@ -8,6 +8,7 @@ import (
 	"net/http"
 	"one-api/common"
 	"one-api/model"
+	"one-api/service"
 	"strconv"
 	"time"
@@ -92,7 +93,7 @@ func GetResponseBody(method, url string, channel *model.Channel, headers http.Header
 	for k := range headers {
 		req.Header.Add(k, headers.Get(k))
 	}
-	res, err := httpClient.Do(req)
+	res, err := service.GetHttpClient().Do(req)
 	if err != nil {
 		return nil, err
 	}
@@ -310,7 +311,7 @@ func updateAllChannelsBalance() error {
 		} else {
 			// err is nil & balance <= 0 means quota is used up
 			if balance <= 0 {
-				disableChannel(channel.Id, channel.Name, "余额不足")
+				service.DisableChannel(channel.Id, channel.Name, "余额不足")
 			}
 		}
 		time.Sleep(common.RequestInterval)
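The new service.GetHttpClient is not shown in this diff; here is a plausible sketch, assuming it simply relocates the timeout-aware client that init() in the deleted controller/relay-text.go (below) used to build:

package service

import (
    "net/http"
    "one-api/common"
    "time"
)

var httpClient *http.Client

func init() {
    // Mirrors the client setup removed from the controller package:
    // honor common.RelayTimeout (seconds) when it is configured.
    if common.RelayTimeout == 0 {
        httpClient = &http.Client{}
    } else {
        httpClient = &http.Client{
            Timeout: time.Duration(common.RelayTimeout) * time.Second,
        }
    }
}

// GetHttpClient exposes the shared client to other packages.
func GetHttpClient() *http.Client {
    return httpClient
}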

View File

@@ -5,9 +5,17 @@ import (
 	"encoding/json"
 	"errors"
 	"fmt"
+	"io"
 	"net/http"
+	"net/http/httptest"
+	"net/url"
 	"one-api/common"
+	"one-api/dto"
 	"one-api/model"
+	relaychannel "one-api/relay/channel"
+	relaycommon "one-api/relay/common"
+	"one-api/relay/constant"
+	"one-api/service"
 	"strconv"
 	"sync"
 	"time"
@@ -15,89 +23,77 @@ import (
 	"github.com/gin-gonic/gin"
 )
-func testChannel(channel *model.Channel, request ChatRequest) (err error, openaiErr *OpenAIError) {
-	common.SysLog(fmt.Sprintf("testing channel %d with model %s", channel.Id, request.Model))
-	switch channel.Type {
-	case common.ChannelTypePaLM:
-		fallthrough
-	case common.ChannelTypeAnthropic:
-		fallthrough
-	case common.ChannelTypeBaidu:
-		fallthrough
-	case common.ChannelTypeZhipu:
-		fallthrough
-	case common.ChannelTypeAli:
-		fallthrough
-	case common.ChannelType360:
-		fallthrough
-	case common.ChannelTypeGemini:
-		fallthrough
-	case common.ChannelTypeXunfei:
-		return errors.New("该渠道类型当前版本不支持测试,请手动测试"), nil
-	case common.ChannelTypeAzure:
-		if request.Model == "" {
-			request.Model = "gpt-35-turbo"
-		}
-		defer func() {
-			if err != nil {
-				err = errors.New("请确保已在 Azure 上创建了 gpt-35-turbo 模型,并且 apiVersion 已正确填写!")
-			}
-		}()
-	default:
-		if request.Model == "" {
-			request.Model = "gpt-3.5-turbo"
-		}
-	}
-	baseUrl := common.ChannelBaseURLs[channel.Type]
-	if channel.GetBaseURL() != "" {
-		baseUrl = channel.GetBaseURL()
-	}
-	requestURL := getFullRequestURL(baseUrl, "/v1/chat/completions", channel.Type)
-	if channel.Type == common.ChannelTypeAzure {
-		requestURL = getFullRequestURL(channel.GetBaseURL(), fmt.Sprintf("/openai/deployments/%s/chat/completions?api-version=2023-03-15-preview", request.Model), channel.Type)
-	}
-	jsonData, err := json.Marshal(request)
+func testChannel(channel *model.Channel, testModel string) (err error, openaiErr *dto.OpenAIError) {
+	common.SysLog(fmt.Sprintf("testing channel %d with model %s", channel.Id, testModel))
+	w := httptest.NewRecorder()
+	c, _ := gin.CreateTestContext(w)
+	c.Request = &http.Request{
+		Method: "POST",
+		URL:    &url.URL{Path: "/v1/chat/completions"},
+		Body:   nil,
+		Header: make(http.Header),
+	}
+	c.Request.Header.Set("Authorization", "Bearer "+channel.Key)
+	c.Request.Header.Set("Content-Type", "application/json")
+	c.Set("channel", channel.Type)
+	c.Set("base_url", channel.GetBaseURL())
+	meta := relaycommon.GenRelayInfo(c)
+	apiType := constant.ChannelType2APIType(channel.Type)
+	adaptor := relaychannel.GetAdaptor(apiType)
+	if adaptor == nil {
+		return fmt.Errorf("invalid api type: %d, adaptor is nil", apiType), nil
+	}
+	if testModel == "" {
+		testModel = adaptor.GetModelList()[0]
+	}
+	request := buildTestRequest()
+	adaptor.Init(meta, *request)
+	request.Model = testModel
+	meta.UpstreamModelName = testModel
+	convertedRequest, err := adaptor.ConvertRequest(c, constant.RelayModeChatCompletions, request)
 	if err != nil {
 		return err, nil
 	}
-	req, err := http.NewRequest("POST", requestURL, bytes.NewBuffer(jsonData))
+	jsonData, err := json.Marshal(convertedRequest)
 	if err != nil {
 		return err, nil
 	}
-	if channel.Type == common.ChannelTypeAzure {
-		req.Header.Set("api-key", channel.Key)
-	} else {
-		req.Header.Set("Authorization", "Bearer "+channel.Key)
-	}
-	req.Header.Set("Content-Type", "application/json")
-	resp, err := httpClient.Do(req)
+	requestBody := bytes.NewBuffer(jsonData)
+	c.Request.Body = io.NopCloser(requestBody)
+	resp, err := adaptor.DoRequest(c, meta, requestBody)
 	if err != nil {
 		return err, nil
 	}
-	defer resp.Body.Close()
-	var response TextResponse
-	err = json.NewDecoder(resp.Body).Decode(&response)
+	if resp.StatusCode != http.StatusOK {
+		err := relaycommon.RelayErrorHandler(resp)
+		return fmt.Errorf("status code %d: %s", resp.StatusCode, err.OpenAIError.Message), &err.OpenAIError
+	}
+	usage, respErr := adaptor.DoResponse(c, resp, meta)
+	if respErr != nil {
+		return fmt.Errorf("%s", respErr.OpenAIError.Message), &respErr.OpenAIError
+	}
+	if usage == nil {
+		return errors.New("usage is nil"), nil
+	}
+	result := w.Result()
+	// print result.Body
+	respBody, err := io.ReadAll(result.Body)
 	if err != nil {
 		return err, nil
 	}
-	if response.Usage.CompletionTokens == 0 {
-		if response.Error.Message == "" {
-			response.Error.Message = "补全 tokens 非预期返回 0"
-		}
-		return errors.New(fmt.Sprintf("type %s, code %v, message %s", response.Error.Type, response.Error.Code, response.Error.Message)), &response.Error
-	}
+	common.SysLog(fmt.Sprintf("testing channel #%d, response: \n%s", channel.Id, string(respBody)))
 	return nil, nil
 }

-func buildTestRequest() *ChatRequest {
-	testRequest := &ChatRequest{
+func buildTestRequest() *dto.GeneralOpenAIRequest {
+	testRequest := &dto.GeneralOpenAIRequest{
 		Model:     "", // this will be set later
 		MaxTokens: 1,
 	}
 	content, _ := json.Marshal("hi")
-	testMessage := Message{
+	testMessage := dto.Message{
 		Role:    "user",
 		Content: content,
 	}
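The refactored testChannel drives the relay adaptors through a synthetic gin context instead of hand-building an upstream HTTP request. A minimal, runnable illustration of that httptest pattern (the handler stand-in and context values here are made up, not from the repo):

package main

import (
    "fmt"
    "net/http"
    "net/http/httptest"
    "net/url"

    "github.com/gin-gonic/gin"
)

func main() {
    // Record everything the code under test writes.
    w := httptest.NewRecorder()
    c, _ := gin.CreateTestContext(w)
    c.Request = &http.Request{
        Method: "POST",
        URL:    &url.URL{Path: "/v1/chat/completions"},
        Header: make(http.Header),
    }
    c.Set("channel", 1) // later read back via c.GetInt("channel")

    // Stand-in for the code under test writing a response.
    c.String(http.StatusOK, "ok")

    resp := w.Result()
    fmt.Println(resp.StatusCode) // 200
}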
@@ -114,7 +110,6 @@ func TestChannel(c *gin.Context) {
 		})
 		return
 	}
-	testModel := c.Query("model")
 	channel, err := model.GetChannelById(id, true)
 	if err != nil {
 		c.JSON(http.StatusOK, gin.H{
@@ -123,12 +118,9 @@ func TestChannel(c *gin.Context) {
 		})
 		return
 	}
-	testRequest := buildTestRequest()
-	if testModel != "" {
-		testRequest.Model = testModel
-	}
+	testModel := c.Query("model")
 	tik := time.Now()
-	err, _ = testChannel(channel, *testRequest)
+	err, _ = testChannel(channel, testModel)
 	tok := time.Now()
 	milliseconds := tok.Sub(tik).Milliseconds()
 	go channel.UpdateResponseTime(milliseconds)
@@ -152,31 +144,6 @@ func TestChannel(c *gin.Context) {
 var testAllChannelsLock sync.Mutex
 var testAllChannelsRunning bool = false

-// disable & notify
-func disableChannel(channelId int, channelName string, reason string) {
-	model.UpdateChannelStatusById(channelId, common.ChannelStatusAutoDisabled)
-	subject := fmt.Sprintf("通道「%s」#%d已被禁用", channelName, channelId)
-	content := fmt.Sprintf("通道「%s」#%d已被禁用,原因:%s", channelName, channelId, reason)
-	notifyRootUser(subject, content)
-}
-
-func enableChannel(channelId int, channelName string) {
-	model.UpdateChannelStatusById(channelId, common.ChannelStatusEnabled)
-	subject := fmt.Sprintf("通道「%s」#%d已被启用", channelName, channelId)
-	content := fmt.Sprintf("通道「%s」#%d已被启用", channelName, channelId)
-	notifyRootUser(subject, content)
-}
-
-func notifyRootUser(subject string, content string) {
-	if common.RootUserEmail == "" {
-		common.RootUserEmail = model.GetRootUserEmail()
-	}
-	err := common.SendEmail(subject, common.RootUserEmail, content)
-	if err != nil {
-		common.SysError(fmt.Sprintf("failed to send email: %s", err.Error()))
-	}
-}
-
 func testAllChannels(notify bool) error {
 	if common.RootUserEmail == "" {
 		common.RootUserEmail = model.GetRootUserEmail()
@@ -192,7 +159,6 @@ func testAllChannels(notify bool) error {
 	if err != nil {
 		return err
 	}
-	testRequest := buildTestRequest()
 	var disableThreshold = int64(common.ChannelDisableThreshold * 1000)
 	if disableThreshold == 0 {
 		disableThreshold = 10000000 // a impossible value
@@ -201,7 +167,7 @@ func testAllChannels(notify bool) error {
 	for _, channel := range channels {
 		isChannelEnabled := channel.Status == common.ChannelStatusEnabled
 		tik := time.Now()
-		err, openaiErr := testChannel(channel, *testRequest)
+		err, openaiErr := testChannel(channel, "")
 		tok := time.Now()
 		milliseconds := tok.Sub(tik).Milliseconds()
@@ -218,11 +184,11 @@ func testAllChannels(notify bool) error {
 		if channel.AutoBan != nil && *channel.AutoBan == 0 {
 			ban = false
 		}
-		if isChannelEnabled && shouldDisableChannel(openaiErr, -1) && ban {
-			disableChannel(channel.Id, channel.Name, err.Error())
+		if isChannelEnabled && service.ShouldDisableChannel(openaiErr, -1) && ban {
+			service.DisableChannel(channel.Id, channel.Name, err.Error())
 		}
-		if !isChannelEnabled && shouldEnableChannel(err, openaiErr) {
-			enableChannel(channel.Id, channel.Name)
+		if !isChannelEnabled && service.ShouldEnableChannel(err, openaiErr) {
+			service.EnableChannel(channel.Id, channel.Name)
 		}
 		channel.UpdateResponseTime(milliseconds)
 		time.Sleep(common.RequestInterval)
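For scale, the auto-disable threshold arithmetic above with concrete numbers (a sketch; common.ChannelDisableThreshold is a configured value in seconds, and the measured response time is compared against the result elsewhere in this function):

channelDisableThreshold := 5.0                            // seconds, illustrative config value
disableThreshold := int64(channelDisableThreshold * 1000) // 5000 ms
if disableThreshold == 0 {
    disableThreshold = 10000000 // effectively never, as in the code above
}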

View File

@@ -10,7 +10,9 @@ import (
 	"log"
 	"net/http"
 	"one-api/common"
+	"one-api/controller/relay"
 	"one-api/model"
+	relay2 "one-api/relay"
 	"strconv"
 	"strings"
 	"time"
@@ -63,7 +65,7 @@ import (
 	req = req.WithContext(ctx)
 	req.Header.Set("Content-Type", "application/json")
-	//req.Header.Set("Authorization", "Bearer midjourney-proxy")
+	//req.Header.Set("ApiKey", "Bearer midjourney-proxy")
 	req.Header.Set("mj-api-secret", midjourneyChannel.Key)
 	resp, err := httpClient.Do(req)
 	if err != nil {
@@ -221,7 +223,7 @@ func UpdateMidjourneyTaskBulk() {
 		req = req.WithContext(ctx)
 		req.Header.Set("Content-Type", "application/json")
 		req.Header.Set("mj-api-secret", midjourneyChannel.Key)
-		resp, err := httpClient.Do(req)
+		resp, err := relay.httpClient.Do(req)
 		if err != nil {
 			common.LogError(ctx, fmt.Sprintf("Get Task Do req error: %v", err))
 			continue
@@ -231,7 +233,7 @@ func UpdateMidjourneyTaskBulk() {
 			common.LogError(ctx, fmt.Sprintf("Get Task parse body error: %v", err))
 			continue
 		}
-		var responseItems []Midjourney
+		var responseItems []relay2.Midjourney
 		err = json.Unmarshal(responseBody, &responseItems)
 		if err != nil {
 			common.LogError(ctx, fmt.Sprintf("Get Task parse body error2: %v, body: %s", err, string(responseBody)))
@@ -284,7 +286,7 @@ func UpdateMidjourneyTaskBulk() {
 	}
 }

-func checkMjTaskNeedUpdate(oldTask *model.Midjourney, newTask Midjourney) bool {
+func checkMjTaskNeedUpdate(oldTask *model.Midjourney, newTask relay2.Midjourney) bool {
 	if oldTask.Code != 1 {
 		return true
 	}

View File

@@ -2,7 +2,6 @@ package controller

 import (
 	"fmt"
 	"github.com/gin-gonic/gin"
 )

View File

@@ -1,220 +0,0 @@
package controller
import (
"bufio"
"encoding/json"
"fmt"
"github.com/gin-gonic/gin"
"io"
"net/http"
"one-api/common"
"strconv"
"strings"
)
// https://docs.aiproxy.io/dev/library#使用已经定制好的知识库进行对话问答
type AIProxyLibraryRequest struct {
Model string `json:"model"`
Query string `json:"query"`
LibraryId string `json:"libraryId"`
Stream bool `json:"stream"`
}
type AIProxyLibraryError struct {
ErrCode int `json:"errCode"`
Message string `json:"message"`
}
type AIProxyLibraryDocument struct {
Title string `json:"title"`
URL string `json:"url"`
}
type AIProxyLibraryResponse struct {
Success bool `json:"success"`
Answer string `json:"answer"`
Documents []AIProxyLibraryDocument `json:"documents"`
AIProxyLibraryError
}
type AIProxyLibraryStreamResponse struct {
Content string `json:"content"`
Finish bool `json:"finish"`
Model string `json:"model"`
Documents []AIProxyLibraryDocument `json:"documents"`
}
func requestOpenAI2AIProxyLibrary(request GeneralOpenAIRequest) *AIProxyLibraryRequest {
query := ""
if len(request.Messages) != 0 {
query = string(request.Messages[len(request.Messages)-1].Content)
}
return &AIProxyLibraryRequest{
Model: request.Model,
Stream: request.Stream,
Query: query,
}
}
func aiProxyDocuments2Markdown(documents []AIProxyLibraryDocument) string {
if len(documents) == 0 {
return ""
}
content := "\n\n参考文档\n"
for i, document := range documents {
content += fmt.Sprintf("%d. [%s](%s)\n", i+1, document.Title, document.URL)
}
return content
}
func responseAIProxyLibrary2OpenAI(response *AIProxyLibraryResponse) *OpenAITextResponse {
content, _ := json.Marshal(response.Answer + aiProxyDocuments2Markdown(response.Documents))
choice := OpenAITextResponseChoice{
Index: 0,
Message: Message{
Role: "assistant",
Content: content,
},
FinishReason: "stop",
}
fullTextResponse := OpenAITextResponse{
Id: common.GetUUID(),
Object: "chat.completion",
Created: common.GetTimestamp(),
Choices: []OpenAITextResponseChoice{choice},
}
return &fullTextResponse
}
func documentsAIProxyLibrary(documents []AIProxyLibraryDocument) *ChatCompletionsStreamResponse {
var choice ChatCompletionsStreamResponseChoice
choice.Delta.Content = aiProxyDocuments2Markdown(documents)
choice.FinishReason = &stopFinishReason
return &ChatCompletionsStreamResponse{
Id: common.GetUUID(),
Object: "chat.completion.chunk",
Created: common.GetTimestamp(),
Model: "",
Choices: []ChatCompletionsStreamResponseChoice{choice},
}
}
func streamResponseAIProxyLibrary2OpenAI(response *AIProxyLibraryStreamResponse) *ChatCompletionsStreamResponse {
var choice ChatCompletionsStreamResponseChoice
choice.Delta.Content = response.Content
return &ChatCompletionsStreamResponse{
Id: common.GetUUID(),
Object: "chat.completion.chunk",
Created: common.GetTimestamp(),
Model: response.Model,
Choices: []ChatCompletionsStreamResponseChoice{choice},
}
}
func aiProxyLibraryStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) {
var usage Usage
scanner := bufio.NewScanner(resp.Body)
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
if atEOF && len(data) == 0 {
return 0, nil, nil
}
if i := strings.Index(string(data), "\n"); i >= 0 {
return i + 1, data[0:i], nil
}
if atEOF {
return len(data), data, nil
}
return 0, nil, nil
})
dataChan := make(chan string)
stopChan := make(chan bool)
go func() {
for scanner.Scan() {
data := scanner.Text()
if len(data) < 5 { // ignore blank line or wrong format
continue
}
if data[:5] != "data:" {
continue
}
data = data[5:]
dataChan <- data
}
stopChan <- true
}()
setEventStreamHeaders(c)
var documents []AIProxyLibraryDocument
c.Stream(func(w io.Writer) bool {
select {
case data := <-dataChan:
var AIProxyLibraryResponse AIProxyLibraryStreamResponse
err := json.Unmarshal([]byte(data), &AIProxyLibraryResponse)
if err != nil {
common.SysError("error unmarshalling stream response: " + err.Error())
return true
}
if len(AIProxyLibraryResponse.Documents) != 0 {
documents = AIProxyLibraryResponse.Documents
}
response := streamResponseAIProxyLibrary2OpenAI(&AIProxyLibraryResponse)
jsonResponse, err := json.Marshal(response)
if err != nil {
common.SysError("error marshalling stream response: " + err.Error())
return true
}
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
return true
case <-stopChan:
response := documentsAIProxyLibrary(documents)
jsonResponse, err := json.Marshal(response)
if err != nil {
common.SysError("error marshalling stream response: " + err.Error())
return true
}
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
return false
}
})
err := resp.Body.Close()
if err != nil {
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
return nil, &usage
}
func aiProxyLibraryHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) {
var AIProxyLibraryResponse AIProxyLibraryResponse
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
}
err = resp.Body.Close()
if err != nil {
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
err = json.Unmarshal(responseBody, &AIProxyLibraryResponse)
if err != nil {
return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
}
if AIProxyLibraryResponse.ErrCode != 0 {
return &OpenAIErrorWithStatusCode{
OpenAIError: OpenAIError{
Message: AIProxyLibraryResponse.Message,
Type: strconv.Itoa(AIProxyLibraryResponse.ErrCode),
Code: AIProxyLibraryResponse.ErrCode,
},
StatusCode: resp.StatusCode,
}, nil
}
fullTextResponse := responseAIProxyLibrary2OpenAI(&AIProxyLibraryResponse)
jsonResponse, err := json.Marshal(fullTextResponse)
if err != nil {
return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
}
c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode)
_, err = c.Writer.Write(jsonResponse)
return nil, &fullTextResponse.Usage
}
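Two small sketches of how the deleted handlers above behaved. First, the markdown that aiProxyDocuments2Markdown produced for returned documents (titles and URLs are made up; this reuses the types and function defined above):

docs := []AIProxyLibraryDocument{
    {Title: "Guide", URL: "https://example.com/guide"},
    {Title: "FAQ", URL: "https://example.com/faq"},
}
fmt.Print(aiProxyDocuments2Markdown(docs))
// 参考文档
// 1. [Guide](https://example.com/guide)
// 2. [FAQ](https://example.com/faq)

Second, a standalone, runnable illustration of the line-oriented SSE parsing that the scanner loop in aiProxyLibraryStreamHandler performed (the payloads are invented):

package main

import (
    "bufio"
    "fmt"
    "strings"
)

func main() {
    body := "data: {\"content\":\"Hello\",\"finish\":false}\n\ndata: {\"content\":\" world\",\"finish\":true}\n"
    scanner := bufio.NewScanner(strings.NewReader(body))
    for scanner.Scan() {
        line := scanner.Text()
        // Mirror the handler: skip blank/short lines and non-"data:" lines.
        if len(line) < 5 || line[:5] != "data:" {
            continue
        }
        fmt.Println("chunk:", strings.TrimSpace(line[5:]))
    }
}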

View File

@@ -1,752 +0,0 @@
package controller
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"one-api/common"
"one-api/model"
"strings"
"time"
"github.com/gin-gonic/gin"
)
const (
APITypeOpenAI = iota
APITypeClaude
APITypePaLM
APITypeBaidu
APITypeZhipu
APITypeAli
APITypeXunfei
APITypeAIProxyLibrary
APITypeTencent
APITypeGemini
)
var httpClient *http.Client
var impatientHTTPClient *http.Client
func init() {
if common.RelayTimeout == 0 {
httpClient = &http.Client{}
} else {
httpClient = &http.Client{
Timeout: time.Duration(common.RelayTimeout) * time.Second,
}
}
impatientHTTPClient = &http.Client{
Timeout: 5 * time.Second,
}
}
func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
channelType := c.GetInt("channel")
channelId := c.GetInt("channel_id")
tokenId := c.GetInt("token_id")
userId := c.GetInt("id")
group := c.GetString("group")
tokenUnlimited := c.GetBool("token_unlimited_quota")
startTime := time.Now()
var textRequest GeneralOpenAIRequest
err := common.UnmarshalBodyReusable(c, &textRequest)
if err != nil {
return errorWrapper(err, "bind_request_body_failed", http.StatusBadRequest)
}
if relayMode == RelayModeModerations && textRequest.Model == "" {
textRequest.Model = "text-moderation-latest"
}
if relayMode == RelayModeEmbeddings && textRequest.Model == "" {
textRequest.Model = c.Param("model")
}
// request validation
if textRequest.Model == "" {
return errorWrapper(errors.New("model is required"), "required_field_missing", http.StatusBadRequest)
}
switch relayMode {
case RelayModeCompletions:
if textRequest.Prompt == "" {
return errorWrapper(errors.New("field prompt is required"), "required_field_missing", http.StatusBadRequest)
}
case RelayModeChatCompletions:
if textRequest.Messages == nil || len(textRequest.Messages) == 0 {
return errorWrapper(errors.New("field messages is required"), "required_field_missing", http.StatusBadRequest)
}
case RelayModeEmbeddings:
case RelayModeModerations:
if textRequest.Input == "" {
return errorWrapper(errors.New("field input is required"), "required_field_missing", http.StatusBadRequest)
}
case RelayModeEdits:
if textRequest.Instruction == "" {
return errorWrapper(errors.New("field instruction is required"), "required_field_missing", http.StatusBadRequest)
}
}
// map model name
modelMapping := c.GetString("model_mapping")
isModelMapped := false
if modelMapping != "" && modelMapping != "{}" {
modelMap := make(map[string]string)
err := json.Unmarshal([]byte(modelMapping), &modelMap)
if err != nil {
return errorWrapper(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError)
}
if modelMap[textRequest.Model] != "" {
textRequest.Model = modelMap[textRequest.Model]
isModelMapped = true
}
}
apiType := APITypeOpenAI
switch channelType {
case common.ChannelTypeAnthropic:
apiType = APITypeClaude
case common.ChannelTypeBaidu:
apiType = APITypeBaidu
case common.ChannelTypePaLM:
apiType = APITypePaLM
case common.ChannelTypeZhipu:
apiType = APITypeZhipu
case common.ChannelTypeAli:
apiType = APITypeAli
case common.ChannelTypeXunfei:
apiType = APITypeXunfei
case common.ChannelTypeAIProxyLibrary:
apiType = APITypeAIProxyLibrary
case common.ChannelTypeTencent:
apiType = APITypeTencent
case common.ChannelTypeGemini:
apiType = APITypeGemini
}
baseURL := common.ChannelBaseURLs[channelType]
requestURL := c.Request.URL.String()
if c.GetString("base_url") != "" {
baseURL = c.GetString("base_url")
}
fullRequestURL := getFullRequestURL(baseURL, requestURL, channelType)
switch apiType {
case APITypeOpenAI:
if channelType == common.ChannelTypeAzure {
// https://learn.microsoft.com/en-us/azure/cognitive-services/openai/chatgpt-quickstart?pivots=rest-api&tabs=command-line#rest-api
query := c.Request.URL.Query()
apiVersion := query.Get("api-version")
if apiVersion == "" {
apiVersion = c.GetString("api_version")
}
requestURL := strings.Split(requestURL, "?")[0]
requestURL = fmt.Sprintf("%s?api-version=%s", requestURL, apiVersion)
baseURL = c.GetString("base_url")
task := strings.TrimPrefix(requestURL, "/v1/")
model_ := textRequest.Model
model_ = strings.Replace(model_, ".", "", -1)
// https://github.com/songquanpeng/one-api/issues/67
model_ = strings.TrimSuffix(model_, "-0301")
model_ = strings.TrimSuffix(model_, "-0314")
model_ = strings.TrimSuffix(model_, "-0613")
requestURL = fmt.Sprintf("/openai/deployments/%s/%s", model_, task)
fullRequestURL = getFullRequestURL(baseURL, requestURL, channelType)
}
case APITypeClaude:
fullRequestURL = "https://api.anthropic.com/v1/complete"
if baseURL != "" {
fullRequestURL = fmt.Sprintf("%s/v1/complete", baseURL)
}
case APITypeBaidu:
switch textRequest.Model {
case "ERNIE-Bot":
fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions"
case "ERNIE-Bot-turbo":
fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/eb-instant"
case "ERNIE-Bot-4":
fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro"
case "BLOOMZ-7B":
fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/bloomz_7b1"
case "Embedding-V1":
fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/embeddings/embedding-v1"
}
apiKey := c.Request.Header.Get("Authorization")
apiKey = strings.TrimPrefix(apiKey, "Bearer ")
var err error
if apiKey, err = getBaiduAccessToken(apiKey); err != nil {
return errorWrapper(err, "invalid_baidu_config", http.StatusInternalServerError)
}
fullRequestURL += "?access_token=" + apiKey
case APITypePaLM:
fullRequestURL = "https://generativelanguage.googleapis.com/v1beta2/models/chat-bison-001:generateMessage"
if baseURL != "" {
fullRequestURL = fmt.Sprintf("%s/v1beta2/models/chat-bison-001:generateMessage", baseURL)
}
apiKey := c.Request.Header.Get("Authorization")
apiKey = strings.TrimPrefix(apiKey, "Bearer ")
fullRequestURL += "?key=" + apiKey
case APITypeGemini:
requestBaseURL := "https://generativelanguage.googleapis.com"
if baseURL != "" {
requestBaseURL = baseURL
}
version := "v1beta"
if c.GetString("api_version") != "" {
version = c.GetString("api_version")
}
action := "generateContent"
if textRequest.Stream {
action = "streamGenerateContent"
}
fullRequestURL = fmt.Sprintf("%s/%s/models/%s:%s", requestBaseURL, version, textRequest.Model, action)
apiKey := c.Request.Header.Get("Authorization")
apiKey = strings.TrimPrefix(apiKey, "Bearer ")
fullRequestURL += "?key=" + apiKey
//log.Println(fullRequestURL)
case APITypeZhipu:
method := "invoke"
if textRequest.Stream {
method = "sse-invoke"
}
fullRequestURL = fmt.Sprintf("https://open.bigmodel.cn/api/paas/v3/model-api/%s/%s", textRequest.Model, method)
case APITypeAli:
fullRequestURL = "https://dashscope.aliyuncs.com/api/v1/services/aigc/text-generation/generation"
if relayMode == RelayModeEmbeddings {
fullRequestURL = "https://dashscope.aliyuncs.com/api/v1/services/embeddings/text-embedding/text-embedding"
}
case APITypeTencent:
fullRequestURL = "https://hunyuan.cloud.tencent.com/hyllm/v1/chat/completions"
case APITypeAIProxyLibrary:
fullRequestURL = fmt.Sprintf("%s/api/library/ask", baseURL)
}
var promptTokens int
var completionTokens int
switch relayMode {
case RelayModeChatCompletions:
promptTokens, err = countTokenMessages(textRequest.Messages, textRequest.Model)
if err != nil {
return errorWrapper(err, "count_token_messages_failed", http.StatusInternalServerError)
}
case RelayModeCompletions:
promptTokens = countTokenInput(textRequest.Prompt, textRequest.Model)
case RelayModeModerations:
promptTokens = countTokenInput(textRequest.Input, textRequest.Model)
}
modelPrice := common.GetModelPrice(textRequest.Model, false)
groupRatio := common.GetGroupRatio(group)
var preConsumedQuota int
var ratio float64
var modelRatio float64
if modelPrice == -1 {
preConsumedTokens := common.PreConsumedQuota
if textRequest.MaxTokens != 0 {
preConsumedTokens = promptTokens + int(textRequest.MaxTokens)
}
modelRatio = common.GetModelRatio(textRequest.Model)
ratio = modelRatio * groupRatio
preConsumedQuota = int(float64(preConsumedTokens) * ratio)
} else {
preConsumedQuota = int(modelPrice * common.QuotaPerUnit * groupRatio)
}
userQuota, err := model.CacheGetUserQuota(userId)
if err != nil {
return errorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError)
}
if userQuota < 0 || userQuota-preConsumedQuota < 0 {
return errorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden)
}
err = model.CacheDecreaseUserQuota(userId, preConsumedQuota)
if err != nil {
return errorWrapper(err, "decrease_user_quota_failed", http.StatusInternalServerError)
}
if userQuota > 100*preConsumedQuota {
// user quota is sufficient; next, check whether the token quota is sufficient too
if !tokenUnlimited {
// not an unlimited token; check whether the token quota is sufficient
tokenQuota := c.GetInt("token_quota")
if tokenQuota > 100*preConsumedQuota {
// token quota is sufficient; trust the token
preConsumedQuota = 0
common.LogInfo(c.Request.Context(), fmt.Sprintf("user %d quota %d and token %d quota %d are enough, trusted and no need to pre-consume", userId, userQuota, tokenId, tokenQuota))
}
} else {
// in this case, we do not pre-consume quota
// because the user has enough quota
preConsumedQuota = 0
common.LogInfo(c.Request.Context(), fmt.Sprintf("user %d with unlimited token has enough quota %d, trusted and no need to pre-consume", userId, userQuota))
}
}
if preConsumedQuota > 0 {
userQuota, err = model.PreConsumeTokenQuota(tokenId, preConsumedQuota)
if err != nil {
return errorWrapper(err, "pre_consume_token_quota_failed", http.StatusForbidden)
}
}
var requestBody io.Reader
if isModelMapped {
jsonStr, err := json.Marshal(textRequest)
if err != nil {
return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
}
requestBody = bytes.NewBuffer(jsonStr)
} else {
requestBody = c.Request.Body
}
switch apiType {
case APITypeClaude:
claudeRequest := requestOpenAI2Claude(textRequest)
jsonStr, err := json.Marshal(claudeRequest)
if err != nil {
return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
}
requestBody = bytes.NewBuffer(jsonStr)
case APITypeBaidu:
var jsonData []byte
var err error
switch relayMode {
case RelayModeEmbeddings:
baiduEmbeddingRequest := embeddingRequestOpenAI2Baidu(textRequest)
jsonData, err = json.Marshal(baiduEmbeddingRequest)
default:
baiduRequest := requestOpenAI2Baidu(textRequest)
jsonData, err = json.Marshal(baiduRequest)
}
if err != nil {
return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
}
requestBody = bytes.NewBuffer(jsonData)
case APITypePaLM:
palmRequest := requestOpenAI2PaLM(textRequest)
jsonStr, err := json.Marshal(palmRequest)
if err != nil {
return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
}
requestBody = bytes.NewBuffer(jsonStr)
case APITypeGemini:
geminiChatRequest := requestOpenAI2Gemini(textRequest)
jsonStr, err := json.Marshal(geminiChatRequest)
if err != nil {
return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
}
requestBody = bytes.NewBuffer(jsonStr)
case APITypeZhipu:
zhipuRequest := requestOpenAI2Zhipu(textRequest)
jsonStr, err := json.Marshal(zhipuRequest)
if err != nil {
return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
}
requestBody = bytes.NewBuffer(jsonStr)
case APITypeAli:
var jsonStr []byte
var err error
switch relayMode {
case RelayModeEmbeddings:
aliEmbeddingRequest := embeddingRequestOpenAI2Ali(textRequest)
jsonStr, err = json.Marshal(aliEmbeddingRequest)
default:
aliRequest := requestOpenAI2Ali(textRequest)
jsonStr, err = json.Marshal(aliRequest)
}
if err != nil {
return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
}
requestBody = bytes.NewBuffer(jsonStr)
case APITypeTencent:
apiKey := c.Request.Header.Get("Authorization")
apiKey = strings.TrimPrefix(apiKey, "Bearer ")
appId, secretId, secretKey, err := parseTencentConfig(apiKey)
if err != nil {
return errorWrapper(err, "invalid_tencent_config", http.StatusInternalServerError)
}
tencentRequest := requestOpenAI2Tencent(textRequest)
tencentRequest.AppId = appId
tencentRequest.SecretId = secretId
jsonStr, err := json.Marshal(tencentRequest)
if err != nil {
return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
}
sign := getTencentSign(*tencentRequest, secretKey)
c.Request.Header.Set("Authorization", sign)
requestBody = bytes.NewBuffer(jsonStr)
case APITypeAIProxyLibrary:
aiProxyLibraryRequest := requestOpenAI2AIProxyLibrary(textRequest)
aiProxyLibraryRequest.LibraryId = c.GetString("library_id")
jsonStr, err := json.Marshal(aiProxyLibraryRequest)
if err != nil {
return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
}
requestBody = bytes.NewBuffer(jsonStr)
}
var req *http.Request
var resp *http.Response
isStream := textRequest.Stream
if apiType != APITypeXunfei { // cause xunfei use websocket
req, err = http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
// set GetBody: it returns a fresh io.ReadCloser that yields the same data as the original request body
req.GetBody = func() (io.ReadCloser, error) {
return io.NopCloser(requestBody), nil
}
if err != nil {
return errorWrapper(err, "new_request_failed", http.StatusInternalServerError)
}
apiKey := c.Request.Header.Get("Authorization")
apiKey = strings.TrimPrefix(apiKey, "Bearer ")
switch apiType {
case APITypeOpenAI:
if channelType == common.ChannelTypeAzure {
req.Header.Set("api-key", apiKey)
} else {
req.Header.Set("Authorization", c.Request.Header.Get("Authorization"))
if c.Request.Header.Get("OpenAI-Organization") != "" {
req.Header.Set("OpenAI-Organization", c.Request.Header.Get("OpenAI-Organization"))
}
if channelType == common.ChannelTypeOpenRouter {
req.Header.Set("HTTP-Referer", "https://github.com/songquanpeng/one-api")
req.Header.Set("X-Title", "One API")
}
}
case APITypeClaude:
req.Header.Set("x-api-key", apiKey)
anthropicVersion := c.Request.Header.Get("anthropic-version")
if anthropicVersion == "" {
anthropicVersion = "2023-06-01"
}
req.Header.Set("anthropic-version", anthropicVersion)
case APITypeZhipu:
token := getZhipuToken(apiKey)
req.Header.Set("Authorization", token)
case APITypeAli:
req.Header.Set("Authorization", "Bearer "+apiKey)
if textRequest.Stream {
req.Header.Set("X-DashScope-SSE", "enable")
}
case APITypeTencent:
req.Header.Set("Authorization", apiKey)
case APITypeGemini:
req.Header.Set("Content-Type", "application/json")
default:
req.Header.Set("Authorization", "Bearer "+apiKey)
}
if apiType != APITypeGemini {
// set the common headers...
req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
req.Header.Set("Accept", c.Request.Header.Get("Accept"))
if isStream && c.Request.Header.Get("Accept") == "" {
req.Header.Set("Accept", "text/event-stream")
}
}
//req.Header.Set("Connection", c.Request.Header.Get("Connection"))
resp, err = httpClient.Do(req)
if err != nil {
return errorWrapper(err, "do_request_failed", http.StatusInternalServerError)
}
err = req.Body.Close()
if err != nil {
return errorWrapper(err, "close_request_body_failed", http.StatusInternalServerError)
}
err = c.Request.Body.Close()
if err != nil {
return errorWrapper(err, "close_request_body_failed", http.StatusInternalServerError)
}
isStream = isStream || strings.HasPrefix(resp.Header.Get("Content-Type"), "text/event-stream")
if resp.StatusCode != http.StatusOK {
if preConsumedQuota != 0 {
go func(ctx context.Context) {
// return pre-consumed quota
err := model.PostConsumeTokenQuota(tokenId, userQuota, -preConsumedQuota, 0, false)
if err != nil {
common.LogError(ctx, "error return pre-consumed quota: "+err.Error())
}
}(c.Request.Context())
}
return relayErrorHandler(resp)
}
}
var textResponse TextResponse
tokenName := c.GetString("token_name")
defer func(ctx context.Context) {
// c.Writer.Flush()
go func() {
useTimeSeconds := time.Now().Unix() - startTime.Unix()
promptTokens = textResponse.Usage.PromptTokens
completionTokens = textResponse.Usage.CompletionTokens
quota := 0
if modelPrice == -1 {
completionRatio := common.GetCompletionRatio(textRequest.Model)
quota = promptTokens + int(float64(completionTokens)*completionRatio)
quota = int(float64(quota) * ratio)
if ratio != 0 && quota <= 0 {
quota = 1
}
} else {
quota = int(modelPrice * common.QuotaPerUnit * groupRatio)
}
totalTokens := promptTokens + completionTokens
var logContent string
if modelPrice == -1 {
logContent = fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
} else {
logContent = fmt.Sprintf("模型价格 %.2f,分组倍率 %.2f", modelPrice, groupRatio)
}
// record all the consume log even if quota is 0
if totalTokens == 0 {
// in this case, must be some error happened
// we cannot just return, because we may have to return the pre-consumed quota
quota = 0
logContent += fmt.Sprintf("(有疑问请联系管理员)")
common.LogError(ctx, fmt.Sprintf("total tokens is 0, cannot consume quota, userId %d, channelId %d, tokenId %d, model %s pre-consumed quota %d", userId, channelId, tokenId, textRequest.Model, preConsumedQuota))
} else {
quotaDelta := quota - preConsumedQuota
err := model.PostConsumeTokenQuota(tokenId, userQuota, quotaDelta, preConsumedQuota, true)
if err != nil {
common.LogError(ctx, "error consuming token remain quota: "+err.Error())
}
err = model.CacheUpdateUserQuota(userId)
if err != nil {
common.LogError(ctx, "error update user quota cache: "+err.Error())
}
model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
model.UpdateChannelUsedQuota(channelId, quota)
}
logModel := textRequest.Model
if strings.HasPrefix(logModel, "gpt-4-gizmo") {
logModel = "gpt-4-gizmo-*"
logContent += fmt.Sprintf(",模型 %s", textRequest.Model)
}
model.RecordConsumeLog(ctx, userId, channelId, promptTokens, completionTokens, logModel, tokenName, quota, logContent, tokenId, userQuota, int(useTimeSeconds), isStream)
//if quota != 0 {
//
//}
}()
}(c.Request.Context())
switch apiType {
case APITypeOpenAI:
if isStream {
err, responseText := openaiStreamHandler(c, resp, relayMode)
if err != nil {
return err
}
textResponse.Usage.PromptTokens = promptTokens
textResponse.Usage.CompletionTokens = countTokenText(responseText, textRequest.Model)
return nil
} else {
err, usage := openaiHandler(c, resp, promptTokens, textRequest.Model)
if err != nil {
return err
}
if usage != nil {
textResponse.Usage = *usage
}
return nil
}
case APITypeClaude:
if isStream {
err, responseText := claudeStreamHandler(c, resp)
if err != nil {
return err
}
textResponse.Usage.PromptTokens = promptTokens
textResponse.Usage.CompletionTokens = countTokenText(responseText, textRequest.Model)
return nil
} else {
err, usage := claudeHandler(c, resp, promptTokens, textRequest.Model)
if err != nil {
return err
}
if usage != nil {
textResponse.Usage = *usage
}
return nil
}
case APITypeBaidu:
if isStream {
err, usage := baiduStreamHandler(c, resp)
if err != nil {
return err
}
if usage != nil {
textResponse.Usage = *usage
}
return nil
} else {
var err *OpenAIErrorWithStatusCode
var usage *Usage
switch relayMode {
case RelayModeEmbeddings:
err, usage = baiduEmbeddingHandler(c, resp)
default:
err, usage = baiduHandler(c, resp)
}
if err != nil {
return err
}
if usage != nil {
textResponse.Usage = *usage
}
return nil
}
case APITypePaLM:
if textRequest.Stream { // PaLM2 API does not support stream
err, responseText := palmStreamHandler(c, resp)
if err != nil {
return err
}
textResponse.Usage.PromptTokens = promptTokens
textResponse.Usage.CompletionTokens = countTokenText(responseText, textRequest.Model)
return nil
} else {
err, usage := palmHandler(c, resp, promptTokens, textRequest.Model)
if err != nil {
return err
}
if usage != nil {
textResponse.Usage = *usage
}
return nil
}
case APITypeGemini:
if textRequest.Stream {
err, responseText := geminiChatStreamHandler(c, resp)
if err != nil {
return err
}
textResponse.Usage.PromptTokens = promptTokens
textResponse.Usage.CompletionTokens = countTokenText(responseText, textRequest.Model)
return nil
} else {
err, usage := geminiChatHandler(c, resp, promptTokens, textRequest.Model)
if err != nil {
return err
}
if usage != nil {
textResponse.Usage = *usage
}
return nil
}
case APITypeZhipu:
if isStream {
err, usage := zhipuStreamHandler(c, resp)
if err != nil {
return err
}
if usage != nil {
textResponse.Usage = *usage
}
// zhipu's API does not return prompt tokens & completion tokens
textResponse.Usage.PromptTokens = textResponse.Usage.TotalTokens
return nil
} else {
err, usage := zhipuHandler(c, resp)
if err != nil {
return err
}
if usage != nil {
textResponse.Usage = *usage
}
// zhipu's API does not return prompt tokens & completion tokens
textResponse.Usage.PromptTokens = textResponse.Usage.TotalTokens
return nil
}
case APITypeAli:
if isStream {
err, usage := aliStreamHandler(c, resp)
if err != nil {
return err
}
if usage != nil {
textResponse.Usage = *usage
}
return nil
} else {
var err *OpenAIErrorWithStatusCode
var usage *Usage
switch relayMode {
case RelayModeEmbeddings:
err, usage = aliEmbeddingHandler(c, resp)
default:
err, usage = aliHandler(c, resp)
}
if err != nil {
return err
}
if usage != nil {
textResponse.Usage = *usage
}
return nil
}
case APITypeXunfei:
auth := c.Request.Header.Get("Authorization")
auth = strings.TrimPrefix(auth, "Bearer ")
splits := strings.Split(auth, "|")
if len(splits) != 3 {
return errorWrapper(errors.New("invalid auth"), "invalid_auth", http.StatusBadRequest)
}
var err *OpenAIErrorWithStatusCode
var usage *Usage
if isStream {
err, usage = xunfeiStreamHandler(c, textRequest, splits[0], splits[1], splits[2])
} else {
err, usage = xunfeiHandler(c, textRequest, splits[0], splits[1], splits[2])
}
if err != nil {
return err
}
if usage != nil {
textResponse.Usage = *usage
}
return nil
case APITypeAIProxyLibrary:
if isStream {
err, usage := aiProxyLibraryStreamHandler(c, resp)
if err != nil {
return err
}
if usage != nil {
textResponse.Usage = *usage
}
return nil
} else {
err, usage := aiProxyLibraryHandler(c, resp)
if err != nil {
return err
}
if usage != nil {
textResponse.Usage = *usage
}
return nil
}
case APITypeTencent:
if isStream {
err, responseText := tencentStreamHandler(c, resp)
if err != nil {
return err
}
textResponse.Usage.PromptTokens = promptTokens
textResponse.Usage.CompletionTokens = countTokenText(responseText, textRequest.Model)
return nil
} else {
err, usage := tencentHandler(c, resp)
if err != nil {
return err
}
if usage != nil {
textResponse.Usage = *usage
}
return nil
}
default:
return errorWrapper(errors.New("unknown api type"), "unknown_api_type", http.StatusInternalServerError)
}
}
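Two worked examples for the deleted relayTextHelper above. First, the Azure deployment URL rewrite in the APITypeOpenAI branch, reproduced as a standalone program (the api-version and model values are illustrative):

package main

import (
    "fmt"
    "strings"
)

func main() {
    requestURL := "/v1/chat/completions?api-version=2023-03-15-preview"
    model := "gpt-3.5-turbo-0613"

    requestURL = strings.Split(requestURL, "?")[0] + "?api-version=2023-03-15-preview"
    task := strings.TrimPrefix(requestURL, "/v1/")
    model = strings.Replace(model, ".", "", -1) // gpt-35-turbo-0613
    model = strings.TrimSuffix(model, "-0613")  // gpt-35-turbo
    fmt.Printf("/openai/deployments/%s/%s\n", model, task)
    // /openai/deployments/gpt-35-turbo/chat/completions?api-version=2023-03-15-preview
}

Second, the quota pre-consumption arithmetic as a fragment with illustrative numbers (the real modelRatio and groupRatio come from common.GetModelRatio and common.GetGroupRatio):

promptTokens := 100
maxTokens := 400
preConsumedTokens := promptTokens + maxTokens // 500, since MaxTokens != 0
ratio := 15.0 * 1.0                           // modelRatio * groupRatio, illustrative
preConsumedQuota := int(float64(preConsumedTokens) * ratio) // 7500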

View File

@@ -1,340 +1,34 @@
 package controller

 import (
-	"encoding/json"
 	"fmt"
-	"github.com/gin-gonic/gin"
 	"log"
 	"net/http"
 	"one-api/common"
+	"one-api/dto"
+	"one-api/relay"
+	"one-api/relay/constant"
+	relayconstant "one-api/relay/constant"
+	"one-api/service"
 	"strconv"
 	"strings"
+
+	"github.com/gin-gonic/gin"
 )
type Message struct {
Role string `json:"role"`
Content json.RawMessage `json:"content"`
Name *string `json:"name,omitempty"`
ToolCalls any `json:"tool_calls,omitempty"`
ToolCallId string `json:"tool_call_id,omitempty"`
}
type MediaMessage struct {
Type string `json:"type"`
Text string `json:"text"`
ImageUrl any `json:"image_url,omitempty"`
}
type MessageImageUrl struct {
Url string `json:"url"`
Detail string `json:"detail"`
}
const (
ContentTypeText = "text"
ContentTypeImageURL = "image_url"
)
func (m Message) StringContent() string {
var stringContent string
if err := json.Unmarshal(m.Content, &stringContent); err == nil {
return stringContent
}
return string(m.Content)
}
func (m Message) ParseContent() []MediaMessage {
var contentList []MediaMessage
var stringContent string
if err := json.Unmarshal(m.Content, &stringContent); err == nil {
contentList = append(contentList, MediaMessage{
Type: ContentTypeText,
Text: stringContent,
})
return contentList
}
var arrayContent []json.RawMessage
if err := json.Unmarshal(m.Content, &arrayContent); err == nil {
for _, contentItem := range arrayContent {
var contentMap map[string]any
if err := json.Unmarshal(contentItem, &contentMap); err != nil {
continue
}
switch contentMap["type"] {
case ContentTypeText:
if subStr, ok := contentMap["text"].(string); ok {
contentList = append(contentList, MediaMessage{
Type: ContentTypeText,
Text: subStr,
})
}
case ContentTypeImageURL:
if subObj, ok := contentMap["image_url"].(map[string]any); ok {
detail, ok := subObj["detail"]
if ok {
subObj["detail"] = detail.(string)
} else {
subObj["detail"] = "auto"
}
contentList = append(contentList, MediaMessage{
Type: ContentTypeImageURL,
ImageUrl: MessageImageUrl{
Url: subObj["url"].(string),
Detail: subObj["detail"].(string),
},
})
}
}
}
return contentList
}
return nil
}
const (
RelayModeUnknown = iota
RelayModeChatCompletions
RelayModeCompletions
RelayModeEmbeddings
RelayModeModerations
RelayModeImagesGenerations
RelayModeEdits
RelayModeMidjourneyImagine
RelayModeMidjourneyDescribe
RelayModeMidjourneyBlend
RelayModeMidjourneyChange
RelayModeMidjourneySimpleChange
RelayModeMidjourneyNotify
RelayModeMidjourneyTaskFetch
RelayModeMidjourneyTaskFetchByCondition
RelayModeAudioSpeech
RelayModeAudioTranscription
RelayModeAudioTranslation
)
// https://platform.openai.com/docs/api-reference/chat
type ResponseFormat struct {
Type string `json:"type,omitempty"`
}
type GeneralOpenAIRequest struct {
Model string `json:"model,omitempty"`
Messages []Message `json:"messages,omitempty"`
Prompt any `json:"prompt,omitempty"`
Stream bool `json:"stream,omitempty"`
MaxTokens uint `json:"max_tokens,omitempty"`
Temperature float64 `json:"temperature,omitempty"`
TopP float64 `json:"top_p,omitempty"`
N int `json:"n,omitempty"`
Input any `json:"input,omitempty"`
Instruction string `json:"instruction,omitempty"`
Size string `json:"size,omitempty"`
Functions any `json:"functions,omitempty"`
FrequencyPenalty float64 `json:"frequency_penalty,omitempty"`
PresencePenalty float64 `json:"presence_penalty,omitempty"`
ResponseFormat *ResponseFormat `json:"response_format,omitempty"`
Seed float64 `json:"seed,omitempty"`
Tools any `json:"tools,omitempty"`
ToolChoice any `json:"tool_choice,omitempty"`
User string `json:"user,omitempty"`
LogProbs bool `json:"logprobs,omitempty"`
TopLogProbs int `json:"top_logprobs,omitempty"`
}
func (r GeneralOpenAIRequest) ParseInput() []string {
if r.Input == nil {
return nil
}
var input []string
switch r.Input.(type) {
case string:
input = []string{r.Input.(string)}
case []any:
input = make([]string, 0, len(r.Input.([]any)))
for _, item := range r.Input.([]any) {
if str, ok := item.(string); ok {
input = append(input, str)
}
}
}
return input
}
type AudioRequest struct {
Model string `json:"model"`
Voice string `json:"voice"`
Input string `json:"input"`
}
type ChatRequest struct {
Model string `json:"model"`
Messages []Message `json:"messages"`
MaxTokens uint `json:"max_tokens"`
}
type TextRequest struct {
Model string `json:"model"`
Messages []Message `json:"messages"`
Prompt string `json:"prompt"`
MaxTokens uint `json:"max_tokens"`
//Stream bool `json:"stream"`
}
type ImageRequest struct {
Model string `json:"model"`
Prompt string `json:"prompt"`
N int `json:"n"`
Size string `json:"size"`
Quality string `json:"quality,omitempty"`
ResponseFormat string `json:"response_format,omitempty"`
Style string `json:"style,omitempty"`
}
type AudioResponse struct {
Text string `json:"text,omitempty"`
}
type Usage struct {
PromptTokens int `json:"prompt_tokens"`
CompletionTokens int `json:"completion_tokens"`
TotalTokens int `json:"total_tokens"`
}
type OpenAIError struct {
Message string `json:"message"`
Type string `json:"type"`
Param string `json:"param"`
Code any `json:"code"`
}
type OpenAIErrorWithStatusCode struct {
OpenAIError
StatusCode int `json:"status_code"`
}
type TextResponse struct {
Choices []OpenAITextResponseChoice `json:"choices"`
Usage `json:"usage"`
Error OpenAIError `json:"error"`
}
type OpenAITextResponseChoice struct {
Index int `json:"index"`
Message `json:"message"`
FinishReason string `json:"finish_reason"`
}
type OpenAITextResponse struct {
Id string `json:"id"`
Object string `json:"object"`
Created int64 `json:"created"`
Choices []OpenAITextResponseChoice `json:"choices"`
Usage `json:"usage"`
}
type OpenAIEmbeddingResponseItem struct {
Object string `json:"object"`
Index int `json:"index"`
Embedding []float64 `json:"embedding"`
}
type OpenAIEmbeddingResponse struct {
Object string `json:"object"`
Data []OpenAIEmbeddingResponseItem `json:"data"`
Model string `json:"model"`
Usage `json:"usage"`
}
type ImageResponse struct {
Created int `json:"created"`
Data []struct {
Url string `json:"url"`
B64Json string `json:"b64_json"`
}
}
type ChatCompletionsStreamResponseChoice struct {
Delta struct {
Content string `json:"content"`
} `json:"delta"`
FinishReason *string `json:"finish_reason,omitempty"`
}
type ChatCompletionsStreamResponse struct {
Id string `json:"id"`
Object string `json:"object"`
Created int64 `json:"created"`
Model string `json:"model"`
Choices []ChatCompletionsStreamResponseChoice `json:"choices"`
}
type ChatCompletionsStreamResponseSimple struct {
Choices []ChatCompletionsStreamResponseChoice `json:"choices"`
}
type CompletionsStreamResponse struct {
Choices []struct {
Text string `json:"text"`
FinishReason string `json:"finish_reason"`
} `json:"choices"`
}
type MidjourneyRequest struct {
Prompt string `json:"prompt"`
NotifyHook string `json:"notifyHook"`
Action string `json:"action"`
Index int `json:"index"`
State string `json:"state"`
TaskId string `json:"taskId"`
Base64Array []string `json:"base64Array"`
Content string `json:"content"`
}
type MidjourneyResponse struct {
Code int `json:"code"`
Description string `json:"description"`
Properties interface{} `json:"properties"`
Result string `json:"result"`
}
 func Relay(c *gin.Context) {
-	relayMode := RelayModeUnknown
-	if strings.HasPrefix(c.Request.URL.Path, "/v1/chat/completions") {
-		relayMode = RelayModeChatCompletions
-	} else if strings.HasPrefix(c.Request.URL.Path, "/v1/completions") {
-		relayMode = RelayModeCompletions
-	} else if strings.HasPrefix(c.Request.URL.Path, "/v1/embeddings") {
-		relayMode = RelayModeEmbeddings
-	} else if strings.HasSuffix(c.Request.URL.Path, "embeddings") {
-		relayMode = RelayModeEmbeddings
-	} else if strings.HasPrefix(c.Request.URL.Path, "/v1/moderations") {
-		relayMode = RelayModeModerations
-	} else if strings.HasPrefix(c.Request.URL.Path, "/v1/images/generations") {
-		relayMode = RelayModeImagesGenerations
-	} else if strings.HasPrefix(c.Request.URL.Path, "/v1/edits") {
-		relayMode = RelayModeEdits
-	} else if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/speech") {
-		relayMode = RelayModeAudioSpeech
-	} else if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") {
-		relayMode = RelayModeAudioTranscription
-	} else if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/translations") {
-		relayMode = RelayModeAudioTranslation
-	}
-	var err *OpenAIErrorWithStatusCode
+	relayMode := constant.Path2RelayMode(c.Request.URL.Path)
+	var err *dto.OpenAIErrorWithStatusCode
 	switch relayMode {
-	case RelayModeImagesGenerations:
-		err = relayImageHelper(c, relayMode)
-	case RelayModeAudioSpeech:
+	case relayconstant.RelayModeImagesGenerations:
+		err = relay.RelayImageHelper(c, relayMode)
+	case relayconstant.RelayModeAudioSpeech:
 		fallthrough
-	case RelayModeAudioTranslation:
+	case relayconstant.RelayModeAudioTranslation:
 		fallthrough
-	case RelayModeAudioTranscription:
-		err = relayAudioHelper(c, relayMode)
+	case relayconstant.RelayModeAudioTranscription:
+		err = relay.RelayAudioHelper(c, relayMode)
 	default:
-		err = relayTextHelper(c, relayMode)
+		err = relay.TextHelper(c)
 	}
 	if err != nil {
 		requestId := c.GetString(common.RequestIdKey)
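The removed if/else chain is exactly what the new constant.Path2RelayMode centralizes. A sketch consistent with the deleted logic (the actual implementation lives in one-api/relay/constant and may differ):

package constant

import "strings"

func Path2RelayMode(path string) int {
    relayMode := RelayModeUnknown
    if strings.HasPrefix(path, "/v1/chat/completions") {
        relayMode = RelayModeChatCompletions
    } else if strings.HasPrefix(path, "/v1/completions") {
        relayMode = RelayModeCompletions
    } else if strings.HasPrefix(path, "/v1/embeddings") || strings.HasSuffix(path, "embeddings") {
        relayMode = RelayModeEmbeddings
    } else if strings.HasPrefix(path, "/v1/moderations") {
        relayMode = RelayModeModerations
    } else if strings.HasPrefix(path, "/v1/images/generations") {
        relayMode = RelayModeImagesGenerations
    } else if strings.HasPrefix(path, "/v1/edits") {
        relayMode = RelayModeEdits
    } else if strings.HasPrefix(path, "/v1/audio/speech") {
        relayMode = RelayModeAudioSpeech
    } else if strings.HasPrefix(path, "/v1/audio/transcriptions") {
        relayMode = RelayModeAudioTranscription
    } else if strings.HasPrefix(path, "/v1/audio/translations") {
        relayMode = RelayModeAudioTranslation
    }
    return relayMode
}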
@@ -358,42 +52,42 @@ func Relay(c *gin.Context) {
 		autoBan := c.GetBool("auto_ban")
 		common.LogError(c.Request.Context(), fmt.Sprintf("relay error (channel #%d): %s", channelId, err.Message))
 		// https://platform.openai.com/docs/guides/error-codes/api-errors
-		if shouldDisableChannel(&err.OpenAIError, err.StatusCode) && autoBan {
+		if service.ShouldDisableChannel(&err.OpenAIError, err.StatusCode) && autoBan {
 			channelId := c.GetInt("channel_id")
 			channelName := c.GetString("channel_name")
-			disableChannel(channelId, channelName, err.Message)
+			service.DisableChannel(channelId, channelName, err.Message)
 		}
 	}
 }

 func RelayMidjourney(c *gin.Context) {
-	relayMode := RelayModeUnknown
+	relayMode := relayconstant.RelayModeUnknown
 	if strings.HasPrefix(c.Request.URL.Path, "/mj/submit/imagine") {
-		relayMode = RelayModeMidjourneyImagine
+		relayMode = relayconstant.RelayModeMidjourneyImagine
 	} else if strings.HasPrefix(c.Request.URL.Path, "/mj/submit/blend") {
-		relayMode = RelayModeMidjourneyBlend
+		relayMode = relayconstant.RelayModeMidjourneyBlend
 	} else if strings.HasPrefix(c.Request.URL.Path, "/mj/submit/describe") {
-		relayMode = RelayModeMidjourneyDescribe
+		relayMode = relayconstant.RelayModeMidjourneyDescribe
 	} else if strings.HasPrefix(c.Request.URL.Path, "/mj/notify") {
-		relayMode = RelayModeMidjourneyNotify
+		relayMode = relayconstant.RelayModeMidjourneyNotify
 	} else if strings.HasPrefix(c.Request.URL.Path, "/mj/submit/change") {
-		relayMode = RelayModeMidjourneyChange
+		relayMode = relayconstant.RelayModeMidjourneyChange
 	} else if strings.HasPrefix(c.Request.URL.Path, "/mj/submit/simple-change") {
-		relayMode = RelayModeMidjourneyChange
+		relayMode = relayconstant.RelayModeMidjourneyChange
 	} else if strings.HasSuffix(c.Request.URL.Path, "/fetch") {
-		relayMode = RelayModeMidjourneyTaskFetch
+		relayMode = relayconstant.RelayModeMidjourneyTaskFetch
 	} else if strings.HasSuffix(c.Request.URL.Path, "/list-by-condition") {
-		relayMode = RelayModeMidjourneyTaskFetchByCondition
+		relayMode = relayconstant.RelayModeMidjourneyTaskFetchByCondition
 	}
-	var err *MidjourneyResponse
+	var err *dto.MidjourneyResponse
 	switch relayMode {
-	case RelayModeMidjourneyNotify:
-		err = relayMidjourneyNotify(c)
-	case RelayModeMidjourneyTaskFetch, RelayModeMidjourneyTaskFetchByCondition:
-		err = relayMidjourneyTask(c, relayMode)
+	case relayconstant.RelayModeMidjourneyNotify:
+		err = relay.RelayMidjourneyNotify(c)
+	case relayconstant.RelayModeMidjourneyTaskFetch, relayconstant.RelayModeMidjourneyTaskFetchByCondition:
+		err = relay.RelayMidjourneyTask(c, relayMode)
 	default:
-		err = relayMidjourneySubmit(c, relayMode)
+		err = relay.RelayMidjourneySubmit(c, relayMode)
 	}
 	//err = relayMidjourneySubmit(c, relayMode)
 	log.Println(err)
@@ -425,7 +119,7 @@ func RelayMidjourney(c *gin.Context) {
 }

 func RelayNotImplemented(c *gin.Context) {
-	err := OpenAIError{
+	err := dto.OpenAIError{
 		Message: "API not implemented",
 		Type:    "new_api_error",
 		Param:   "",
@@ -437,7 +131,7 @@ func RelayNotImplemented(c *gin.Context) {
 }

 func RelayNotFound(c *gin.Context) {
-	err := OpenAIError{
+	err := dto.OpenAIError{
 		Message: fmt.Sprintf("Invalid URL (%s %s)", c.Request.Method, c.Request.URL.Path),
 		Type:    "invalid_request_error",
 		Param:   "",

dto/error.go (new file, +13 lines)
View File

@@ -0,0 +1,13 @@
package dto
type OpenAIError struct {
Message string `json:"message"`
Type string `json:"type"`
Param string `json:"param"`
Code any `json:"code"`
}
type OpenAIErrorWithStatusCode struct {
OpenAIError
StatusCode int `json:"status_code"`
}
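A quick sketch of how these two types are used together; the embedded struct promotes OpenAIError's fields onto the wrapper (the values are made up):

e := dto.OpenAIErrorWithStatusCode{
    OpenAIError: dto.OpenAIError{
        Message: "invalid api key",
        Type:    "authentication_error",
        Code:    "invalid_api_key",
    },
    StatusCode: 401,
}
fmt.Println(e.StatusCode, e.Message) // 401 invalid api key (Message is promoted from the embedded struct)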

dto/request.go (new file, +137 lines)
View File

@@ -0,0 +1,137 @@
package dto
import "encoding/json"
type ResponseFormat struct {
Type string `json:"type,omitempty"`
}
type GeneralOpenAIRequest struct {
Model string `json:"model,omitempty"`
Messages []Message `json:"messages,omitempty"`
Prompt any `json:"prompt,omitempty"`
Stream bool `json:"stream,omitempty"`
MaxTokens uint `json:"max_tokens,omitempty"`
Temperature float64 `json:"temperature,omitempty"`
TopP float64 `json:"top_p,omitempty"`
N int `json:"n,omitempty"`
Input any `json:"input,omitempty"`
Instruction string `json:"instruction,omitempty"`
Size string `json:"size,omitempty"`
Functions any `json:"functions,omitempty"`
FrequencyPenalty float64 `json:"frequency_penalty,omitempty"`
PresencePenalty float64 `json:"presence_penalty,omitempty"`
ResponseFormat *ResponseFormat `json:"response_format,omitempty"`
Seed float64 `json:"seed,omitempty"`
Tools any `json:"tools,omitempty"`
ToolChoice any `json:"tool_choice,omitempty"`
User string `json:"user,omitempty"`
LogProbs bool `json:"logprobs,omitempty"`
TopLogProbs int `json:"top_logprobs,omitempty"`
}
func (r GeneralOpenAIRequest) ParseInput() []string {
if r.Input == nil {
return nil
}
var input []string
switch v := r.Input.(type) {
case string:
input = []string{v}
case []any:
input = make([]string, 0, len(v))
for _, item := range v {
if str, ok := item.(string); ok {
input = append(input, str)
}
}
}
return input
}
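ParseInput normalizes the two shapes the embeddings endpoint accepts for input, silently dropping non-string array items; a quick sketch of the behavior (assumes encoding/json and fmt are imported):

	var req GeneralOpenAIRequest
	_ = json.Unmarshal([]byte(`{"input": ["hello", "world", 42]}`), &req)
	fmt.Println(req.ParseInput()) // [hello world] (the non-string 42 is skipped)
	_ = json.Unmarshal([]byte(`{"input": "hello"}`), &req)
	fmt.Println(req.ParseInput()) // [hello]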
type Message struct {
Role string `json:"role"`
Content json.RawMessage `json:"content"`
Name *string `json:"name,omitempty"`
ToolCalls any `json:"tool_calls,omitempty"`
ToolCallId string `json:"tool_call_id,omitempty"`
}
type MediaMessage struct {
Type string `json:"type"`
Text string `json:"text"`
ImageUrl any `json:"image_url,omitempty"`
}
type MessageImageUrl struct {
Url string `json:"url"`
Detail string `json:"detail"`
}
const (
ContentTypeText = "text"
ContentTypeImageURL = "image_url"
)
func (m Message) StringContent() string {
var stringContent string
if err := json.Unmarshal(m.Content, &stringContent); err == nil {
return stringContent
}
return string(m.Content)
}
func (m Message) ParseContent() []MediaMessage {
var contentList []MediaMessage
var stringContent string
if err := json.Unmarshal(m.Content, &stringContent); err == nil {
contentList = append(contentList, MediaMessage{
Type: ContentTypeText,
Text: stringContent,
})
return contentList
}
var arrayContent []json.RawMessage
if err := json.Unmarshal(m.Content, &arrayContent); err == nil {
for _, contentItem := range arrayContent {
var contentMap map[string]any
if err := json.Unmarshal(contentItem, &contentMap); err != nil {
continue
}
switch contentMap["type"] {
case ContentTypeText:
if subStr, ok := contentMap["text"].(string); ok {
contentList = append(contentList, MediaMessage{
Type: ContentTypeText,
Text: subStr,
})
}
case ContentTypeImageURL:
if subObj, ok := contentMap["image_url"].(map[string]any); ok {
url, urlOk := subObj["url"].(string)
if !urlOk {
continue
}
detail, detailOk := subObj["detail"].(string)
if !detailOk {
detail = "auto"
}
contentList = append(contentList, MediaMessage{
Type: ContentTypeImageURL,
ImageUrl: MessageImageUrl{
Url: url,
Detail: detail,
},
})
}
}
}
return contentList
}
return nil
}
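StringContent and ParseContent are the two views over the raw JSON content: the former collapses everything to a string, the latter keeps vision parts, defaulting a missing image detail to "auto". A sketch:

	raw := json.RawMessage(`[
		{"type": "text", "text": "What is in this image?"},
		{"type": "image_url", "image_url": {"url": "https://example.com/cat.png"}}
	]`)
	m := Message{Role: "user", Content: raw}
	parts := m.ParseContent()
	// parts[0]: {Type: "text", Text: "What is in this image?"}
	// parts[1]: ImageUrl holds a MessageImageUrl with Detail defaulted to "auto"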
type Usage struct {
PromptTokens int `json:"prompt_tokens"`
CompletionTokens int `json:"completion_tokens"`
TotalTokens int `json:"total_tokens"`
}

dto/response.go Normal file

@ -0,0 +1,86 @@
package dto
type TextResponse struct {
Choices []OpenAITextResponseChoice `json:"choices"`
Usage `json:"usage"`
Error OpenAIError `json:"error"`
}
type OpenAITextResponseChoice struct {
Index int `json:"index"`
Message `json:"message"`
FinishReason string `json:"finish_reason"`
}
type OpenAITextResponse struct {
Id string `json:"id"`
Object string `json:"object"`
Created int64 `json:"created"`
Choices []OpenAITextResponseChoice `json:"choices"`
Usage `json:"usage"`
}
type OpenAIEmbeddingResponseItem struct {
Object string `json:"object"`
Index int `json:"index"`
Embedding []float64 `json:"embedding"`
}
type OpenAIEmbeddingResponse struct {
Object string `json:"object"`
Data []OpenAIEmbeddingResponseItem `json:"data"`
Model string `json:"model"`
Usage `json:"usage"`
}
type ImageResponse struct {
Created int `json:"created"`
Data []struct {
Url string `json:"url"`
B64Json string `json:"b64_json"`
}
}
type ChatCompletionsStreamResponseChoice struct {
Delta struct {
Content string `json:"content"`
} `json:"delta"`
FinishReason *string `json:"finish_reason,omitempty"`
}
type ChatCompletionsStreamResponse struct {
Id string `json:"id"`
Object string `json:"object"`
Created int64 `json:"created"`
Model string `json:"model"`
Choices []ChatCompletionsStreamResponseChoice `json:"choices"`
}
type ChatCompletionsStreamResponseSimple struct {
Choices []ChatCompletionsStreamResponseChoice `json:"choices"`
}
type CompletionsStreamResponse struct {
Choices []struct {
Text string `json:"text"`
FinishReason string `json:"finish_reason"`
} `json:"choices"`
}
type MidjourneyRequest struct {
Prompt string `json:"prompt"`
NotifyHook string `json:"notifyHook"`
Action string `json:"action"`
Index int `json:"index"`
State string `json:"state"`
TaskId string `json:"taskId"`
Base64Array []string `json:"base64Array"`
Content string `json:"content"`
}
type MidjourneyResponse struct {
Code int `json:"code"`
Description string `json:"description"`
Properties interface{} `json:"properties"`
Result string `json:"result"`
}
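The stream handlers later in this commit write these chunk structs onto an SSE stream after service.SetEventStreamHeaders(c); a sketch of the wire format, assuming we are inside gin's c.Stream callback (which supplies w io.Writer):

	chunk := ChatCompletionsStreamResponse{
		Id:      "chatcmpl-123", // illustrative id
		Object:  "chat.completion.chunk",
		Created: time.Now().Unix(),
		Model:   "gpt-3.5-turbo",
		Choices: make([]ChatCompletionsStreamResponseChoice, 1),
	}
	chunk.Choices[0].Delta.Content = "Hello"
	data, _ := json.Marshal(chunk)
	fmt.Fprintf(w, "data: %s\n\n", data) // one JSON chunk per "data:" line
	fmt.Fprint(w, "data: [DONE]\n\n")    // conventional end-of-stream marker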


@ -12,6 +12,7 @@ import (
"one-api/controller" "one-api/controller"
"one-api/middleware" "one-api/middleware"
"one-api/model" "one-api/model"
"one-api/relay/common"
"one-api/router" "one-api/router"
"os" "os"
"strconv" "strconv"
@ -105,7 +106,7 @@ func main() {
common.SysLog("pprof enabled") common.SysLog("pprof enabled")
} }
controller.InitTokenEncoders() common.InitTokenEncoders()
// Initialize HTTP server // Initialize HTTP server
server := gin.New() server := gin.New()


@ -129,15 +129,18 @@ func Distribute() func(c *gin.Context) {
c.Set("model_mapping", channel.GetModelMapping()) c.Set("model_mapping", channel.GetModelMapping())
c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key)) c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key))
c.Set("base_url", channel.GetBaseURL()) c.Set("base_url", channel.GetBaseURL())
// TODO: unify api_version handling across channel types
switch channel.Type { switch channel.Type {
case common.ChannelTypeAzure: case common.ChannelTypeAzure:
c.Set("api_version", channel.Other) c.Set("api_version", channel.Other)
case common.ChannelTypeXunfei: case common.ChannelTypeXunfei:
c.Set("api_version", channel.Other) c.Set("api_version", channel.Other)
case common.ChannelTypeAIProxyLibrary: //case common.ChannelTypeAIProxyLibrary:
c.Set("library_id", channel.Other) // c.Set("library_id", channel.Other)
case common.ChannelTypeGemini: case common.ChannelTypeGemini:
c.Set("api_version", channel.Other) c.Set("api_version", channel.Other)
case common.ChannelTypeAli:
c.Set("plugin", channel.Other)
} }
c.Next() c.Next()
} }
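The keys stashed here are read back on the way out: the Ali adaptor below forwards c.GetString("plugin") as the X-DashScope-Plugin header, and api_version is presumably consumed wherever the Azure, Xunfei, and Gemini request URLs are assembled.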

relay/channel/adapter.go Normal file

@ -0,0 +1,57 @@
package channel
import (
"github.com/gin-gonic/gin"
"io"
"net/http"
"one-api/dto"
"one-api/relay/channel/ali"
"one-api/relay/channel/baidu"
"one-api/relay/channel/claude"
"one-api/relay/channel/gemini"
"one-api/relay/channel/openai"
"one-api/relay/channel/palm"
"one-api/relay/channel/tencent"
"one-api/relay/channel/xunfei"
"one-api/relay/channel/zhipu"
relaycommon "one-api/relay/common"
"one-api/relay/constant"
)
type Adaptor interface {
// Init is called once per relay request, before ConvertRequest; per-request flags such as IsStream live on info.
Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest)
GetRequestURL(info *relaycommon.RelayInfo) (string, error)
SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error
ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error)
DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error)
DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode)
GetModelList() []string
GetChannelName() string
}
func GetAdaptor(apiType int) Adaptor {
switch apiType {
//case constant.APITypeAIProxyLibrary:
// return &aiproxy.Adaptor{}
case constant.APITypeAli:
return &ali.Adaptor{}
case constant.APITypeAnthropic:
return &claude.Adaptor{}
case constant.APITypeBaidu:
return &baidu.Adaptor{}
case constant.APITypeGemini:
return &gemini.Adaptor{}
case constant.APITypeOpenAI:
return &openai.Adaptor{}
case constant.APITypePaLM:
return &palm.Adaptor{}
case constant.APITypeTencent:
return &tencent.Adaptor{}
case constant.APITypeXunfei:
return &xunfei.Adaptor{}
case constant.APITypeZhipu:
return &zhipu.Adaptor{}
}
return nil
}
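GetAdaptor implies a fixed pipeline: Init, ConvertRequest, DoRequest, DoResponse. A minimal sketch of a driver in this package (the function name relayWithAdaptor is invented; assumes bytes, encoding/json, errors, net/http, and one-api/service on top of the imports above):

	func relayWithAdaptor(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest, apiType int) *dto.OpenAIErrorWithStatusCode {
		adaptor := GetAdaptor(apiType)
		if adaptor == nil {
			return service.OpenAIErrorWrapper(errors.New("invalid api type"), "invalid_api_type", http.StatusBadRequest)
		}
		adaptor.Init(info, *request)
		converted, err := adaptor.ConvertRequest(c, info.RelayMode, request)
		if err != nil {
			return service.OpenAIErrorWrapper(err, "convert_request_failed", http.StatusInternalServerError)
		}
		body, err := json.Marshal(converted)
		if err != nil {
			return service.OpenAIErrorWrapper(err, "marshal_request_failed", http.StatusInternalServerError)
		}
		resp, err := adaptor.DoRequest(c, info, bytes.NewReader(body))
		if err != nil {
			return service.OpenAIErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)
		}
		if _, openaiErr := adaptor.DoResponse(c, resp, info); openaiErr != nil {
			return openaiErr
		}
		return nil
	}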


@ -0,0 +1,80 @@
package ali
import (
"errors"
"fmt"
"github.com/gin-gonic/gin"
"io"
"net/http"
"one-api/dto"
relaychannel "one-api/relay/channel"
relaycommon "one-api/relay/common"
"one-api/relay/constant"
)
type Adaptor struct {
}
func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
}
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
fullRequestURL := fmt.Sprintf("%s/api/v1/services/aigc/text-generation/generation", info.BaseUrl)
if info.RelayMode == constant.RelayModeEmbeddings {
fullRequestURL = fmt.Sprintf("%s/api/v1/services/embeddings/text-embedding/text-embedding", info.BaseUrl)
}
return fullRequestURL, nil
}
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
relaychannel.SetupApiRequestHeader(info, c, req)
req.Header.Set("Authorization", "Bearer "+info.ApiKey)
if info.IsStream {
req.Header.Set("X-DashScope-SSE", "enable")
}
if c.GetString("plugin") != "" {
req.Header.Set("X-DashScope-Plugin", c.GetString("plugin"))
}
return nil
}
func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) {
if request == nil {
return nil, errors.New("request is nil")
}
switch relayMode {
case constant.RelayModeEmbeddings:
aliEmbeddingRequest := embeddingRequestOpenAI2Ali(*request)
return aliEmbeddingRequest, nil
default:
aliRequest := requestOpenAI2Ali(*request)
return aliRequest, nil
}
}
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
return relaychannel.DoApiRequest(a, c, info, requestBody)
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
if info.IsStream {
err, usage = aliStreamHandler(c, resp)
} else {
switch info.RelayMode {
case constant.RelayModeEmbeddings:
err, usage = aliEmbeddingHandler(c, resp)
default:
err, usage = aliHandler(c, resp)
}
}
return
}
func (a *Adaptor) GetModelList() []string {
return ModelList
}
func (a *Adaptor) GetChannelName() string {
return ChannelName
}


@ -0,0 +1,8 @@
package ali
var ModelList = []string{
"qwen-turbo", "qwen-plus", "qwen-max", "qwen-max-longcontext",
"text-embedding-v1",
}
var ChannelName = "ali"

relay/channel/ali/dto.go Normal file

@ -0,0 +1,70 @@
package ali
type AliMessage struct {
User string `json:"user"`
Bot string `json:"bot"`
}
type AliInput struct {
Prompt string `json:"prompt"`
History []AliMessage `json:"history"`
}
type AliParameters struct {
TopP float64 `json:"top_p,omitempty"`
TopK int `json:"top_k,omitempty"`
Seed uint64 `json:"seed,omitempty"`
EnableSearch bool `json:"enable_search,omitempty"`
}
type AliChatRequest struct {
Model string `json:"model"`
Input AliInput `json:"input"`
Parameters AliParameters `json:"parameters,omitempty"`
}
type AliEmbeddingRequest struct {
Model string `json:"model"`
Input struct {
Texts []string `json:"texts"`
} `json:"input"`
Parameters *struct {
TextType string `json:"text_type,omitempty"`
} `json:"parameters,omitempty"`
}
type AliEmbedding struct {
Embedding []float64 `json:"embedding"`
TextIndex int `json:"text_index"`
}
type AliEmbeddingResponse struct {
Output struct {
Embeddings []AliEmbedding `json:"embeddings"`
} `json:"output"`
Usage AliUsage `json:"usage"`
AliError
}
type AliError struct {
Code string `json:"code"`
Message string `json:"message"`
RequestId string `json:"request_id"`
}
type AliUsage struct {
InputTokens int `json:"input_tokens"`
OutputTokens int `json:"output_tokens"`
TotalTokens int `json:"total_tokens"`
}
type AliOutput struct {
Text string `json:"text"`
FinishReason string `json:"finish_reason"`
}
type AliChatResponse struct {
Output AliOutput `json:"output"`
Usage AliUsage `json:"usage"`
AliError
}


@ -1,4 +1,4 @@
package controller package ali
import ( import (
"bufio" "bufio"
@ -7,81 +7,14 @@ import (
"io" "io"
"net/http" "net/http"
"one-api/common" "one-api/common"
"one-api/dto"
"one-api/service"
"strings" "strings"
) )
// https://help.aliyun.com/document_detail/613695.html?spm=a2c4g.2399480.0.0.1adb778fAdzP9w#341800c0f8w0r // https://help.aliyun.com/document_detail/613695.html?spm=a2c4g.2399480.0.0.1adb778fAdzP9w#341800c0f8w0r
type AliMessage struct { func requestOpenAI2Ali(request dto.GeneralOpenAIRequest) *AliChatRequest {
User string `json:"user"`
Bot string `json:"bot"`
}
type AliInput struct {
Prompt string `json:"prompt"`
History []AliMessage `json:"history"`
}
type AliParameters struct {
TopP float64 `json:"top_p,omitempty"`
TopK int `json:"top_k,omitempty"`
Seed uint64 `json:"seed,omitempty"`
EnableSearch bool `json:"enable_search,omitempty"`
}
type AliChatRequest struct {
Model string `json:"model"`
Input AliInput `json:"input"`
Parameters AliParameters `json:"parameters,omitempty"`
}
type AliEmbeddingRequest struct {
Model string `json:"model"`
Input struct {
Texts []string `json:"texts"`
} `json:"input"`
Parameters *struct {
TextType string `json:"text_type,omitempty"`
} `json:"parameters,omitempty"`
}
type AliEmbedding struct {
Embedding []float64 `json:"embedding"`
TextIndex int `json:"text_index"`
}
type AliEmbeddingResponse struct {
Output struct {
Embeddings []AliEmbedding `json:"embeddings"`
} `json:"output"`
Usage AliUsage `json:"usage"`
AliError
}
type AliError struct {
Code string `json:"code"`
Message string `json:"message"`
RequestId string `json:"request_id"`
}
type AliUsage struct {
InputTokens int `json:"input_tokens"`
OutputTokens int `json:"output_tokens"`
TotalTokens int `json:"total_tokens"`
}
type AliOutput struct {
Text string `json:"text"`
FinishReason string `json:"finish_reason"`
}
type AliChatResponse struct {
Output AliOutput `json:"output"`
Usage AliUsage `json:"usage"`
AliError
}
func requestOpenAI2Ali(request GeneralOpenAIRequest) *AliChatRequest {
messages := make([]AliMessage, 0, len(request.Messages)) messages := make([]AliMessage, 0, len(request.Messages))
prompt := "" prompt := ""
for i := 0; i < len(request.Messages); i++ { for i := 0; i < len(request.Messages); i++ {
@ -119,7 +52,7 @@ func requestOpenAI2Ali(request GeneralOpenAIRequest) *AliChatRequest {
} }
} }
func embeddingRequestOpenAI2Ali(request GeneralOpenAIRequest) *AliEmbeddingRequest { func embeddingRequestOpenAI2Ali(request dto.GeneralOpenAIRequest) *AliEmbeddingRequest {
return &AliEmbeddingRequest{ return &AliEmbeddingRequest{
Model: "text-embedding-v1", Model: "text-embedding-v1",
Input: struct { Input: struct {
@ -130,21 +63,21 @@ func embeddingRequestOpenAI2Ali(request GeneralOpenAIRequest) *AliEmbeddingReque
} }
} }
func aliEmbeddingHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) { func aliEmbeddingHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
var aliResponse AliEmbeddingResponse var aliResponse AliEmbeddingResponse
err := json.NewDecoder(resp.Body).Decode(&aliResponse) err := json.NewDecoder(resp.Body).Decode(&aliResponse)
if err != nil { if err != nil {
return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
} }
err = resp.Body.Close() err = resp.Body.Close()
if err != nil { if err != nil {
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
} }
if aliResponse.Code != "" { if aliResponse.Code != "" {
return &OpenAIErrorWithStatusCode{ return &dto.OpenAIErrorWithStatusCode{
OpenAIError: OpenAIError{ OpenAIError: dto.OpenAIError{
Message: aliResponse.Message, Message: aliResponse.Message,
Type: aliResponse.Code, Type: aliResponse.Code,
Param: aliResponse.RequestId, Param: aliResponse.RequestId,
@ -157,7 +90,7 @@ func aliEmbeddingHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithS
fullTextResponse := embeddingResponseAli2OpenAI(&aliResponse) fullTextResponse := embeddingResponseAli2OpenAI(&aliResponse)
jsonResponse, err := json.Marshal(fullTextResponse) jsonResponse, err := json.Marshal(fullTextResponse)
if err != nil { if err != nil {
return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
} }
c.Writer.Header().Set("Content-Type", "application/json") c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode) c.Writer.WriteHeader(resp.StatusCode)
@ -165,16 +98,16 @@ func aliEmbeddingHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithS
return nil, &fullTextResponse.Usage return nil, &fullTextResponse.Usage
} }
func embeddingResponseAli2OpenAI(response *AliEmbeddingResponse) *OpenAIEmbeddingResponse { func embeddingResponseAli2OpenAI(response *AliEmbeddingResponse) *dto.OpenAIEmbeddingResponse {
openAIEmbeddingResponse := OpenAIEmbeddingResponse{ openAIEmbeddingResponse := dto.OpenAIEmbeddingResponse{
Object: "list", Object: "list",
Data: make([]OpenAIEmbeddingResponseItem, 0, len(response.Output.Embeddings)), Data: make([]dto.OpenAIEmbeddingResponseItem, 0, len(response.Output.Embeddings)),
Model: "text-embedding-v1", Model: "text-embedding-v1",
Usage: Usage{TotalTokens: response.Usage.TotalTokens}, Usage: dto.Usage{TotalTokens: response.Usage.TotalTokens},
} }
for _, item := range response.Output.Embeddings { for _, item := range response.Output.Embeddings {
openAIEmbeddingResponse.Data = append(openAIEmbeddingResponse.Data, OpenAIEmbeddingResponseItem{ openAIEmbeddingResponse.Data = append(openAIEmbeddingResponse.Data, dto.OpenAIEmbeddingResponseItem{
Object: `embedding`, Object: `embedding`,
Index: item.TextIndex, Index: item.TextIndex,
Embedding: item.Embedding, Embedding: item.Embedding,
@ -183,22 +116,22 @@ func embeddingResponseAli2OpenAI(response *AliEmbeddingResponse) *OpenAIEmbeddin
return &openAIEmbeddingResponse return &openAIEmbeddingResponse
} }
func responseAli2OpenAI(response *AliChatResponse) *OpenAITextResponse { func responseAli2OpenAI(response *AliChatResponse) *dto.OpenAITextResponse {
content, _ := json.Marshal(response.Output.Text) content, _ := json.Marshal(response.Output.Text)
choice := OpenAITextResponseChoice{ choice := dto.OpenAITextResponseChoice{
Index: 0, Index: 0,
Message: Message{ Message: dto.Message{
Role: "assistant", Role: "assistant",
Content: content, Content: content,
}, },
FinishReason: response.Output.FinishReason, FinishReason: response.Output.FinishReason,
} }
fullTextResponse := OpenAITextResponse{ fullTextResponse := dto.OpenAITextResponse{
Id: response.RequestId, Id: response.RequestId,
Object: "chat.completion", Object: "chat.completion",
Created: common.GetTimestamp(), Created: common.GetTimestamp(),
Choices: []OpenAITextResponseChoice{choice}, Choices: []dto.OpenAITextResponseChoice{choice},
Usage: Usage{ Usage: dto.Usage{
PromptTokens: response.Usage.InputTokens, PromptTokens: response.Usage.InputTokens,
CompletionTokens: response.Usage.OutputTokens, CompletionTokens: response.Usage.OutputTokens,
TotalTokens: response.Usage.InputTokens + response.Usage.OutputTokens, TotalTokens: response.Usage.InputTokens + response.Usage.OutputTokens,
@ -207,25 +140,25 @@ func responseAli2OpenAI(response *AliChatResponse) *OpenAITextResponse {
return &fullTextResponse return &fullTextResponse
} }
func streamResponseAli2OpenAI(aliResponse *AliChatResponse) *ChatCompletionsStreamResponse { func streamResponseAli2OpenAI(aliResponse *AliChatResponse) *dto.ChatCompletionsStreamResponse {
var choice ChatCompletionsStreamResponseChoice var choice dto.ChatCompletionsStreamResponseChoice
choice.Delta.Content = aliResponse.Output.Text choice.Delta.Content = aliResponse.Output.Text
if aliResponse.Output.FinishReason != "null" { if aliResponse.Output.FinishReason != "null" {
finishReason := aliResponse.Output.FinishReason finishReason := aliResponse.Output.FinishReason
choice.FinishReason = &finishReason choice.FinishReason = &finishReason
} }
response := ChatCompletionsStreamResponse{ response := dto.ChatCompletionsStreamResponse{
Id: aliResponse.RequestId, Id: aliResponse.RequestId,
Object: "chat.completion.chunk", Object: "chat.completion.chunk",
Created: common.GetTimestamp(), Created: common.GetTimestamp(),
Model: "ernie-bot", Model: "ernie-bot",
Choices: []ChatCompletionsStreamResponseChoice{choice}, Choices: []dto.ChatCompletionsStreamResponseChoice{choice},
} }
return &response return &response
} }
func aliStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) { func aliStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
var usage Usage var usage dto.Usage
scanner := bufio.NewScanner(resp.Body) scanner := bufio.NewScanner(resp.Body)
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
if atEOF && len(data) == 0 { if atEOF && len(data) == 0 {
@ -255,7 +188,7 @@ func aliStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStat
} }
stopChan <- true stopChan <- true
}() }()
setEventStreamHeaders(c) service.SetEventStreamHeaders(c)
lastResponseText := "" lastResponseText := ""
c.Stream(func(w io.Writer) bool { c.Stream(func(w io.Writer) bool {
select { select {
@ -288,28 +221,28 @@ func aliStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStat
}) })
err := resp.Body.Close() err := resp.Body.Close()
if err != nil { if err != nil {
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
} }
return nil, &usage return nil, &usage
} }
func aliHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) { func aliHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
var aliResponse AliChatResponse var aliResponse AliChatResponse
responseBody, err := io.ReadAll(resp.Body) responseBody, err := io.ReadAll(resp.Body)
if err != nil { if err != nil {
return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
} }
err = resp.Body.Close() err = resp.Body.Close()
if err != nil { if err != nil {
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
} }
err = json.Unmarshal(responseBody, &aliResponse) err = json.Unmarshal(responseBody, &aliResponse)
if err != nil { if err != nil {
return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
} }
if aliResponse.Code != "" { if aliResponse.Code != "" {
return &OpenAIErrorWithStatusCode{ return &dto.OpenAIErrorWithStatusCode{
OpenAIError: OpenAIError{ OpenAIError: dto.OpenAIError{
Message: aliResponse.Message, Message: aliResponse.Message,
Type: aliResponse.Code, Type: aliResponse.Code,
Param: aliResponse.RequestId, Param: aliResponse.RequestId,
@ -321,7 +254,7 @@ func aliHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode
fullTextResponse := responseAli2OpenAI(&aliResponse) fullTextResponse := responseAli2OpenAI(&aliResponse)
jsonResponse, err := json.Marshal(fullTextResponse) jsonResponse, err := json.Marshal(fullTextResponse)
if err != nil { if err != nil {
return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
} }
c.Writer.Header().Set("Content-Type", "application/json") c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode) c.Writer.WriteHeader(resp.StatusCode)


@ -0,0 +1,52 @@
package channel
import (
"errors"
"fmt"
"github.com/gin-gonic/gin"
"io"
"net/http"
relaycommon "one-api/relay/common"
"one-api/service"
)
func SetupApiRequestHeader(info *relaycommon.RelayInfo, c *gin.Context, req *http.Request) {
req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
req.Header.Set("Accept", c.Request.Header.Get("Accept"))
if info.IsStream && c.Request.Header.Get("Accept") == "" {
req.Header.Set("Accept", "text/event-stream")
}
}
func DoApiRequest(a Adaptor, c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
fullRequestURL, err := a.GetRequestURL(info)
if err != nil {
return nil, fmt.Errorf("get request url failed: %w", err)
}
req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
if err != nil {
return nil, fmt.Errorf("new request failed: %w", err)
}
err = a.SetupRequestHeader(c, req, info)
if err != nil {
return nil, fmt.Errorf("setup request header failed: %w", err)
}
resp, err := doRequest(c, req)
if err != nil {
return nil, fmt.Errorf("do request failed: %w", err)
}
return resp, nil
}
func doRequest(c *gin.Context, req *http.Request) (*http.Response, error) {
resp, err := service.GetHttpClient().Do(req)
if err != nil {
return nil, err
}
if resp == nil {
return nil, errors.New("resp is nil")
}
_ = req.Body.Close()
_ = c.Request.Body.Close()
return resp, nil
}
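Note that doRequest closes both the outbound req.Body and the inbound c.Request.Body as soon as the upstream call returns, so each adaptor's DoResponse only has resp.Body left to manage.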


@ -0,0 +1,92 @@
package baidu
import (
"errors"
"github.com/gin-gonic/gin"
"io"
"net/http"
"one-api/dto"
relaychannel "one-api/relay/channel"
relaycommon "one-api/relay/common"
"one-api/relay/constant"
)
type Adaptor struct {
}
func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
}
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
var fullRequestURL string
switch info.UpstreamModelName {
case "ERNIE-Bot-4":
fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro"
case "ERNIE-Bot-8K":
fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie_bot_8k"
case "ERNIE-Bot":
fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions"
case "ERNIE-Speed":
fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie_speed"
case "ERNIE-Bot-turbo":
fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/eb-instant"
case "BLOOMZ-7B":
fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/bloomz_7b1"
case "Embedding-V1":
fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/embeddings/embedding-v1"
}
var accessToken string
var err error
if accessToken, err = getBaiduAccessToken(info.ApiKey); err != nil {
return "", err
}
fullRequestURL += "?access_token=" + accessToken
return fullRequestURL, nil
}
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
relaychannel.SetupApiRequestHeader(info, c, req)
req.Header.Set("Authorization", "Bearer "+info.ApiKey)
return nil
}
func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) {
if request == nil {
return nil, errors.New("request is nil")
}
switch relayMode {
case constant.RelayModeEmbeddings:
baiduEmbeddingRequest := embeddingRequestOpenAI2Baidu(*request)
return baiduEmbeddingRequest, nil
default:
baiduRequest := requestOpenAI2Baidu(*request)
return baiduRequest, nil
}
}
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
return relaychannel.DoApiRequest(a, c, info, requestBody)
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
if info.IsStream {
err, usage = baiduStreamHandler(c, resp)
} else {
switch info.RelayMode {
case constant.RelayModeEmbeddings:
err, usage = baiduEmbeddingHandler(c, resp)
default:
err, usage = baiduHandler(c, resp)
}
}
return
}
func (a *Adaptor) GetModelList() []string {
return ModelList
}
func (a *Adaptor) GetChannelName() string {
return ChannelName
}


@ -0,0 +1,12 @@
package baidu
var ModelList = []string{
"ERNIE-Bot-4",
"ERNIE-Bot-8K",
"ERNIE-Bot",
"ERNIE-Speed",
"ERNIE-Bot-turbo",
"Embedding-V1",
}
var ChannelName = "baidu"


@ -0,0 +1,71 @@
package baidu
import (
"one-api/dto"
"time"
)
type BaiduMessage struct {
Role string `json:"role"`
Content string `json:"content"`
}
type BaiduChatRequest struct {
Messages []BaiduMessage `json:"messages"`
Stream bool `json:"stream"`
UserId string `json:"user_id,omitempty"`
}
type Error struct {
ErrorCode int `json:"error_code"`
ErrorMsg string `json:"error_msg"`
}
type BaiduChatResponse struct {
Id string `json:"id"`
Object string `json:"object"`
Created int64 `json:"created"`
Result string `json:"result"`
IsTruncated bool `json:"is_truncated"`
NeedClearHistory bool `json:"need_clear_history"`
Usage dto.Usage `json:"usage"`
Error
}
type BaiduChatStreamResponse struct {
BaiduChatResponse
SentenceId int `json:"sentence_id"`
IsEnd bool `json:"is_end"`
}
type BaiduEmbeddingRequest struct {
Input []string `json:"input"`
}
type BaiduEmbeddingData struct {
Object string `json:"object"`
Embedding []float64 `json:"embedding"`
Index int `json:"index"`
}
type BaiduEmbeddingResponse struct {
Id string `json:"id"`
Object string `json:"object"`
Created int64 `json:"created"`
Data []BaiduEmbeddingData `json:"data"`
Usage dto.Usage `json:"usage"`
Error
}
type BaiduAccessToken struct {
AccessToken string `json:"access_token"`
Error string `json:"error,omitempty"`
ErrorDescription string `json:"error_description,omitempty"`
ExpiresIn int64 `json:"expires_in,omitempty"`
ExpiresAt time.Time `json:"-"`
}
type BaiduTokenResponse struct {
ExpiresIn int `json:"expires_in"`
AccessToken string `json:"access_token"`
}


@ -1,4 +1,4 @@
package controller package baidu
import ( import (
"bufio" "bufio"
@ -9,6 +9,9 @@ import (
"io" "io"
"net/http" "net/http"
"one-api/common" "one-api/common"
"one-api/dto"
relaycommon "one-api/relay/common"
"one-api/service"
"strings" "strings"
"sync" "sync"
"time" "time"
@ -16,74 +19,9 @@ import (
// https://cloud.baidu.com/doc/WENXINWORKSHOP/s/flfmc9do2 // https://cloud.baidu.com/doc/WENXINWORKSHOP/s/flfmc9do2
type BaiduTokenResponse struct {
ExpiresIn int `json:"expires_in"`
AccessToken string `json:"access_token"`
}
type BaiduMessage struct {
Role string `json:"role"`
Content string `json:"content"`
}
type BaiduChatRequest struct {
Messages []BaiduMessage `json:"messages"`
Stream bool `json:"stream"`
UserId string `json:"user_id,omitempty"`
}
type BaiduError struct {
ErrorCode int `json:"error_code"`
ErrorMsg string `json:"error_msg"`
}
type BaiduChatResponse struct {
Id string `json:"id"`
Object string `json:"object"`
Created int64 `json:"created"`
Result string `json:"result"`
IsTruncated bool `json:"is_truncated"`
NeedClearHistory bool `json:"need_clear_history"`
Usage Usage `json:"usage"`
BaiduError
}
type BaiduChatStreamResponse struct {
BaiduChatResponse
SentenceId int `json:"sentence_id"`
IsEnd bool `json:"is_end"`
}
type BaiduEmbeddingRequest struct {
Input []string `json:"input"`
}
type BaiduEmbeddingData struct {
Object string `json:"object"`
Embedding []float64 `json:"embedding"`
Index int `json:"index"`
}
type BaiduEmbeddingResponse struct {
Id string `json:"id"`
Object string `json:"object"`
Created int64 `json:"created"`
Data []BaiduEmbeddingData `json:"data"`
Usage Usage `json:"usage"`
BaiduError
}
type BaiduAccessToken struct {
AccessToken string `json:"access_token"`
Error string `json:"error,omitempty"`
ErrorDescription string `json:"error_description,omitempty"`
ExpiresIn int64 `json:"expires_in,omitempty"`
ExpiresAt time.Time `json:"-"`
}
var baiduTokenStore sync.Map var baiduTokenStore sync.Map
func requestOpenAI2Baidu(request GeneralOpenAIRequest) *BaiduChatRequest { func requestOpenAI2Baidu(request dto.GeneralOpenAIRequest) *BaiduChatRequest {
messages := make([]BaiduMessage, 0, len(request.Messages)) messages := make([]BaiduMessage, 0, len(request.Messages))
for _, message := range request.Messages { for _, message := range request.Messages {
if message.Role == "system" { if message.Role == "system" {
@ -108,57 +46,57 @@ func requestOpenAI2Baidu(request GeneralOpenAIRequest) *BaiduChatRequest {
} }
} }
func responseBaidu2OpenAI(response *BaiduChatResponse) *OpenAITextResponse { func responseBaidu2OpenAI(response *BaiduChatResponse) *dto.OpenAITextResponse {
content, _ := json.Marshal(response.Result) content, _ := json.Marshal(response.Result)
choice := OpenAITextResponseChoice{ choice := dto.OpenAITextResponseChoice{
Index: 0, Index: 0,
Message: Message{ Message: dto.Message{
Role: "assistant", Role: "assistant",
Content: content, Content: content,
}, },
FinishReason: "stop", FinishReason: "stop",
} }
fullTextResponse := OpenAITextResponse{ fullTextResponse := dto.OpenAITextResponse{
Id: response.Id, Id: response.Id,
Object: "chat.completion", Object: "chat.completion",
Created: response.Created, Created: response.Created,
Choices: []OpenAITextResponseChoice{choice}, Choices: []dto.OpenAITextResponseChoice{choice},
Usage: response.Usage, Usage: response.Usage,
} }
return &fullTextResponse return &fullTextResponse
} }
func streamResponseBaidu2OpenAI(baiduResponse *BaiduChatStreamResponse) *ChatCompletionsStreamResponse { func streamResponseBaidu2OpenAI(baiduResponse *BaiduChatStreamResponse) *dto.ChatCompletionsStreamResponse {
var choice ChatCompletionsStreamResponseChoice var choice dto.ChatCompletionsStreamResponseChoice
choice.Delta.Content = baiduResponse.Result choice.Delta.Content = baiduResponse.Result
if baiduResponse.IsEnd { if baiduResponse.IsEnd {
choice.FinishReason = &stopFinishReason choice.FinishReason = &relaycommon.StopFinishReason
} }
response := ChatCompletionsStreamResponse{ response := dto.ChatCompletionsStreamResponse{
Id: baiduResponse.Id, Id: baiduResponse.Id,
Object: "chat.completion.chunk", Object: "chat.completion.chunk",
Created: baiduResponse.Created, Created: baiduResponse.Created,
Model: "ernie-bot", Model: "ernie-bot",
Choices: []ChatCompletionsStreamResponseChoice{choice}, Choices: []dto.ChatCompletionsStreamResponseChoice{choice},
} }
return &response return &response
} }
func embeddingRequestOpenAI2Baidu(request GeneralOpenAIRequest) *BaiduEmbeddingRequest { func embeddingRequestOpenAI2Baidu(request dto.GeneralOpenAIRequest) *BaiduEmbeddingRequest {
return &BaiduEmbeddingRequest{ return &BaiduEmbeddingRequest{
Input: request.ParseInput(), Input: request.ParseInput(),
} }
} }
func embeddingResponseBaidu2OpenAI(response *BaiduEmbeddingResponse) *OpenAIEmbeddingResponse { func embeddingResponseBaidu2OpenAI(response *BaiduEmbeddingResponse) *dto.OpenAIEmbeddingResponse {
openAIEmbeddingResponse := OpenAIEmbeddingResponse{ openAIEmbeddingResponse := dto.OpenAIEmbeddingResponse{
Object: "list", Object: "list",
Data: make([]OpenAIEmbeddingResponseItem, 0, len(response.Data)), Data: make([]dto.OpenAIEmbeddingResponseItem, 0, len(response.Data)),
Model: "baidu-embedding", Model: "baidu-embedding",
Usage: response.Usage, Usage: response.Usage,
} }
for _, item := range response.Data { for _, item := range response.Data {
openAIEmbeddingResponse.Data = append(openAIEmbeddingResponse.Data, OpenAIEmbeddingResponseItem{ openAIEmbeddingResponse.Data = append(openAIEmbeddingResponse.Data, dto.OpenAIEmbeddingResponseItem{
Object: item.Object, Object: item.Object,
Index: item.Index, Index: item.Index,
Embedding: item.Embedding, Embedding: item.Embedding,
@ -167,8 +105,8 @@ func embeddingResponseBaidu2OpenAI(response *BaiduEmbeddingResponse) *OpenAIEmbe
return &openAIEmbeddingResponse return &openAIEmbeddingResponse
} }
func baiduStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) { func baiduStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
var usage Usage var usage dto.Usage
scanner := bufio.NewScanner(resp.Body) scanner := bufio.NewScanner(resp.Body)
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
if atEOF && len(data) == 0 { if atEOF && len(data) == 0 {
@ -195,7 +133,7 @@ func baiduStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithSt
} }
stopChan <- true stopChan <- true
}() }()
setEventStreamHeaders(c) service.SetEventStreamHeaders(c)
c.Stream(func(w io.Writer) bool { c.Stream(func(w io.Writer) bool {
select { select {
case data := <-dataChan: case data := <-dataChan:
@ -225,28 +163,28 @@ func baiduStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithSt
}) })
err := resp.Body.Close() err := resp.Body.Close()
if err != nil { if err != nil {
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
} }
return nil, &usage return nil, &usage
} }
func baiduHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) { func baiduHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
var baiduResponse BaiduChatResponse var baiduResponse BaiduChatResponse
responseBody, err := io.ReadAll(resp.Body) responseBody, err := io.ReadAll(resp.Body)
if err != nil { if err != nil {
return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
} }
err = resp.Body.Close() err = resp.Body.Close()
if err != nil { if err != nil {
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
} }
err = json.Unmarshal(responseBody, &baiduResponse) err = json.Unmarshal(responseBody, &baiduResponse)
if err != nil { if err != nil {
return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
} }
if baiduResponse.ErrorMsg != "" { if baiduResponse.ErrorMsg != "" {
return &OpenAIErrorWithStatusCode{ return &dto.OpenAIErrorWithStatusCode{
OpenAIError: OpenAIError{ OpenAIError: dto.OpenAIError{
Message: baiduResponse.ErrorMsg, Message: baiduResponse.ErrorMsg,
Type: "baidu_error", Type: "baidu_error",
Param: "", Param: "",
@ -258,7 +196,7 @@ func baiduHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCo
fullTextResponse := responseBaidu2OpenAI(&baiduResponse) fullTextResponse := responseBaidu2OpenAI(&baiduResponse)
jsonResponse, err := json.Marshal(fullTextResponse) jsonResponse, err := json.Marshal(fullTextResponse)
if err != nil { if err != nil {
return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
} }
c.Writer.Header().Set("Content-Type", "application/json") c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode) c.Writer.WriteHeader(resp.StatusCode)
@ -266,23 +204,23 @@ func baiduHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCo
return nil, &fullTextResponse.Usage return nil, &fullTextResponse.Usage
} }
func baiduEmbeddingHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) { func baiduEmbeddingHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
var baiduResponse BaiduEmbeddingResponse var baiduResponse BaiduEmbeddingResponse
responseBody, err := io.ReadAll(resp.Body) responseBody, err := io.ReadAll(resp.Body)
if err != nil { if err != nil {
return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
} }
err = resp.Body.Close() err = resp.Body.Close()
if err != nil { if err != nil {
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
} }
err = json.Unmarshal(responseBody, &baiduResponse) err = json.Unmarshal(responseBody, &baiduResponse)
if err != nil { if err != nil {
return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
} }
if baiduResponse.ErrorMsg != "" { if baiduResponse.ErrorMsg != "" {
return &OpenAIErrorWithStatusCode{ return &dto.OpenAIErrorWithStatusCode{
OpenAIError: OpenAIError{ OpenAIError: dto.OpenAIError{
Message: baiduResponse.ErrorMsg, Message: baiduResponse.ErrorMsg,
Type: "baidu_error", Type: "baidu_error",
Param: "", Param: "",
@ -294,7 +232,7 @@ func baiduEmbeddingHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWit
fullTextResponse := embeddingResponseBaidu2OpenAI(&baiduResponse) fullTextResponse := embeddingResponseBaidu2OpenAI(&baiduResponse)
jsonResponse, err := json.Marshal(fullTextResponse) jsonResponse, err := json.Marshal(fullTextResponse)
if err != nil { if err != nil {
return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
} }
c.Writer.Header().Set("Content-Type", "application/json") c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode) c.Writer.WriteHeader(resp.StatusCode)
@ -337,7 +275,7 @@ func getBaiduAccessTokenHelper(apiKey string) (*BaiduAccessToken, error) {
} }
req.Header.Add("Content-Type", "application/json") req.Header.Add("Content-Type", "application/json")
req.Header.Add("Accept", "application/json") req.Header.Add("Accept", "application/json")
res, err := impatientHTTPClient.Do(req) res, err := service.GetImpatientHttpClient().Do(req)
if err != nil { if err != nil {
return nil, err return nil, err
} }
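The caching wrapper getBaiduAccessToken (called from GetRequestURL above, unchanged by this hunk) fronts this helper with the package-level baiduTokenStore sync.Map; a sketch of the shape it implies, with an invented name and an assumed expiry check:

	func getBaiduAccessTokenCached(apiKey string) (string, error) {
		if v, ok := baiduTokenStore.Load(apiKey); ok {
			if token, ok := v.(BaiduAccessToken); ok && time.Now().Before(token.ExpiresAt) {
				return token.AccessToken, nil
			}
		}
		token, err := getBaiduAccessTokenHelper(apiKey)
		if err != nil {
			return "", err
		}
		baiduTokenStore.Store(apiKey, *token)
		return token.AccessToken, nil
	}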


@ -0,0 +1,65 @@
package claude
import (
"errors"
"fmt"
"github.com/gin-gonic/gin"
"io"
"net/http"
"one-api/dto"
relaychannel "one-api/relay/channel"
relaycommon "one-api/relay/common"
"one-api/service"
)
type Adaptor struct {
}
func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
}
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
return fmt.Sprintf("%s/v1/complete", info.BaseUrl), nil
}
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
relaychannel.SetupApiRequestHeader(info, c, req)
req.Header.Set("x-api-key", info.ApiKey)
anthropicVersion := c.Request.Header.Get("anthropic-version")
if anthropicVersion == "" {
anthropicVersion = "2023-06-01"
}
req.Header.Set("anthropic-version", anthropicVersion)
return nil
}
func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) {
if request == nil {
return nil, errors.New("request is nil")
}
return request, nil
}
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
return relaychannel.DoApiRequest(a, c, info, requestBody)
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
if info.IsStream {
var responseText string
err, responseText = claudeStreamHandler(c, resp)
usage = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
} else {
err, usage = claudeHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
}
return
}
func (a *Adaptor) GetModelList() []string {
return ModelList
}
func (a *Adaptor) GetChannelName() string {
return ChannelName
}


@ -0,0 +1,7 @@
package claude
var ModelList = []string{
"claude-instant-1", "claude-2", "claude-2.0", "claude-2.1",
}
var ChannelName = "claude"


@ -0,0 +1,29 @@
package claude
type ClaudeMetadata struct {
UserId string `json:"user_id"`
}
type ClaudeRequest struct {
Model string `json:"model"`
Prompt string `json:"prompt"`
MaxTokensToSample uint `json:"max_tokens_to_sample"`
StopSequences []string `json:"stop_sequences,omitempty"`
Temperature float64 `json:"temperature,omitempty"`
TopP float64 `json:"top_p,omitempty"`
TopK int `json:"top_k,omitempty"`
//ClaudeMetadata `json:"metadata,omitempty"`
Stream bool `json:"stream,omitempty"`
}
type ClaudeError struct {
Type string `json:"type"`
Message string `json:"message"`
}
type ClaudeResponse struct {
Completion string `json:"completion"`
StopReason string `json:"stop_reason"`
Model string `json:"model"`
Error ClaudeError `json:"error"`
}


@ -1,4 +1,4 @@
package controller package claude
import ( import (
"bufio" "bufio"
@ -8,37 +8,11 @@ import (
"io" "io"
"net/http" "net/http"
"one-api/common" "one-api/common"
"one-api/dto"
"one-api/service"
"strings" "strings"
) )
type ClaudeMetadata struct {
UserId string `json:"user_id"`
}
type ClaudeRequest struct {
Model string `json:"model"`
Prompt string `json:"prompt"`
MaxTokensToSample uint `json:"max_tokens_to_sample"`
StopSequences []string `json:"stop_sequences,omitempty"`
Temperature float64 `json:"temperature,omitempty"`
TopP float64 `json:"top_p,omitempty"`
TopK int `json:"top_k,omitempty"`
//ClaudeMetadata `json:"metadata,omitempty"`
Stream bool `json:"stream,omitempty"`
}
type ClaudeError struct {
Type string `json:"type"`
Message string `json:"message"`
}
type ClaudeResponse struct {
Completion string `json:"completion"`
StopReason string `json:"stop_reason"`
Model string `json:"model"`
Error ClaudeError `json:"error"`
}
func stopReasonClaude2OpenAI(reason string) string { func stopReasonClaude2OpenAI(reason string) string {
switch reason { switch reason {
case "stop_sequence": case "stop_sequence":
@ -50,7 +24,7 @@ func stopReasonClaude2OpenAI(reason string) string {
} }
} }
func requestOpenAI2Claude(textRequest GeneralOpenAIRequest) *ClaudeRequest { func requestOpenAI2Claude(textRequest dto.GeneralOpenAIRequest) *ClaudeRequest {
claudeRequest := ClaudeRequest{ claudeRequest := ClaudeRequest{
Model: textRequest.Model, Model: textRequest.Model,
Prompt: "", Prompt: "",
@ -78,41 +52,41 @@ func requestOpenAI2Claude(textRequest GeneralOpenAIRequest) *ClaudeRequest {
return &claudeRequest return &claudeRequest
} }
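The elided middle of this function flattens the chat messages into Anthropic's legacy transcript format, roughly "\n\nHuman: ...\n\nAssistant: ...", ending on an open "\n\nAssistant:" turn, which is the shape the /v1/complete endpoint targeted by this adaptor expects.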
func streamResponseClaude2OpenAI(claudeResponse *ClaudeResponse) *ChatCompletionsStreamResponse { func streamResponseClaude2OpenAI(claudeResponse *ClaudeResponse) *dto.ChatCompletionsStreamResponse {
var choice ChatCompletionsStreamResponseChoice var choice dto.ChatCompletionsStreamResponseChoice
choice.Delta.Content = claudeResponse.Completion choice.Delta.Content = claudeResponse.Completion
finishReason := stopReasonClaude2OpenAI(claudeResponse.StopReason) finishReason := stopReasonClaude2OpenAI(claudeResponse.StopReason)
if finishReason != "null" { if finishReason != "null" {
choice.FinishReason = &finishReason choice.FinishReason = &finishReason
} }
var response ChatCompletionsStreamResponse var response dto.ChatCompletionsStreamResponse
response.Object = "chat.completion.chunk" response.Object = "chat.completion.chunk"
response.Model = claudeResponse.Model response.Model = claudeResponse.Model
response.Choices = []ChatCompletionsStreamResponseChoice{choice} response.Choices = []dto.ChatCompletionsStreamResponseChoice{choice}
return &response return &response
} }
func responseClaude2OpenAI(claudeResponse *ClaudeResponse) *OpenAITextResponse { func responseClaude2OpenAI(claudeResponse *ClaudeResponse) *dto.OpenAITextResponse {
content, _ := json.Marshal(strings.TrimPrefix(claudeResponse.Completion, " ")) content, _ := json.Marshal(strings.TrimPrefix(claudeResponse.Completion, " "))
choice := OpenAITextResponseChoice{ choice := dto.OpenAITextResponseChoice{
Index: 0, Index: 0,
Message: Message{ Message: dto.Message{
Role: "assistant", Role: "assistant",
Content: content, Content: content,
Name: nil, Name: nil,
}, },
FinishReason: stopReasonClaude2OpenAI(claudeResponse.StopReason), FinishReason: stopReasonClaude2OpenAI(claudeResponse.StopReason),
} }
fullTextResponse := OpenAITextResponse{ fullTextResponse := dto.OpenAITextResponse{
Id: fmt.Sprintf("chatcmpl-%s", common.GetUUID()), Id: fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
Object: "chat.completion", Object: "chat.completion",
Created: common.GetTimestamp(), Created: common.GetTimestamp(),
Choices: []OpenAITextResponseChoice{choice}, Choices: []dto.OpenAITextResponseChoice{choice},
} }
return &fullTextResponse return &fullTextResponse
} }
func claudeStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, string) { func claudeStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, string) {
responseText := "" responseText := ""
responseId := fmt.Sprintf("chatcmpl-%s", common.GetUUID()) responseId := fmt.Sprintf("chatcmpl-%s", common.GetUUID())
createdTime := common.GetTimestamp() createdTime := common.GetTimestamp()
@ -142,7 +116,7 @@ func claudeStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithS
} }
stopChan <- true stopChan <- true
}() }()
setEventStreamHeaders(c) service.SetEventStreamHeaders(c)
c.Stream(func(w io.Writer) bool { c.Stream(func(w io.Writer) bool {
select { select {
case data := <-dataChan: case data := <-dataChan:
@ -172,28 +146,28 @@ func claudeStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithS
}) })
err := resp.Body.Close() err := resp.Body.Close()
if err != nil { if err != nil {
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), "" return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), ""
} }
return nil, responseText return nil, responseText
} }
func claudeHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*OpenAIErrorWithStatusCode, *Usage) { func claudeHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
responseBody, err := io.ReadAll(resp.Body) responseBody, err := io.ReadAll(resp.Body)
if err != nil { if err != nil {
return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
} }
err = resp.Body.Close() err = resp.Body.Close()
if err != nil { if err != nil {
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
} }
var claudeResponse ClaudeResponse var claudeResponse ClaudeResponse
err = json.Unmarshal(responseBody, &claudeResponse) err = json.Unmarshal(responseBody, &claudeResponse)
if err != nil { if err != nil {
return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
} }
if claudeResponse.Error.Type != "" { if claudeResponse.Error.Type != "" {
return &OpenAIErrorWithStatusCode{ return &dto.OpenAIErrorWithStatusCode{
OpenAIError: OpenAIError{ OpenAIError: dto.OpenAIError{
Message: claudeResponse.Error.Message, Message: claudeResponse.Error.Message,
Type: claudeResponse.Error.Type, Type: claudeResponse.Error.Type,
Param: "", Param: "",
@ -203,8 +177,8 @@ func claudeHandler(c *gin.Context, resp *http.Response, promptTokens int, model
}, nil }, nil
} }
fullTextResponse := responseClaude2OpenAI(&claudeResponse) fullTextResponse := responseClaude2OpenAI(&claudeResponse)
completionTokens := countTokenText(claudeResponse.Completion, model) completionTokens := service.CountTokenText(claudeResponse.Completion, model)
usage := Usage{ usage := dto.Usage{
PromptTokens: promptTokens, PromptTokens: promptTokens,
CompletionTokens: completionTokens, CompletionTokens: completionTokens,
TotalTokens: promptTokens + completionTokens, TotalTokens: promptTokens + completionTokens,
@ -212,7 +186,7 @@ func claudeHandler(c *gin.Context, resp *http.Response, promptTokens int, model
fullTextResponse.Usage = usage fullTextResponse.Usage = usage
jsonResponse, err := json.Marshal(fullTextResponse) jsonResponse, err := json.Marshal(fullTextResponse)
if err != nil { if err != nil {
return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
} }
c.Writer.Header().Set("Content-Type", "application/json") c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode) c.Writer.WriteHeader(resp.StatusCode)


@ -0,0 +1,64 @@
package gemini
import (
"errors"
"fmt"
"github.com/gin-gonic/gin"
"io"
"net/http"
"one-api/dto"
relaychannel "one-api/relay/channel"
relaycommon "one-api/relay/common"
"one-api/service"
)
type Adaptor struct {
}
func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
}
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
version := "v1"
action := "generateContent"
if info.IsStream {
action = "streamGenerateContent"
}
return fmt.Sprintf("%s/%s/models/%s:%s", info.BaseUrl, version, info.UpstreamModelName, action), nil
}
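For example, with the standard generativelanguage.googleapis.com base URL and gemini-pro in stream mode, this resolves to .../v1/models/gemini-pro:streamGenerateContent; non-stream calls hit :generateContent instead.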
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
relaychannel.SetupApiRequestHeader(info, c, req)
req.Header.Set("x-goog-api-key", info.ApiKey)
return nil
}
func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) {
if request == nil {
return nil, errors.New("request is nil")
}
return CovertGemini2OpenAI(*request), nil
}
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
return relaychannel.DoApiRequest(a, c, info, requestBody)
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
if info.IsStream {
var responseText string
err, responseText = geminiChatStreamHandler(c, resp)
usage = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
} else {
err, usage = geminiChatHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
}
return
}
func (a *Adaptor) GetModelList() []string {
return ModelList
}
func (a *Adaptor) GetChannelName() string {
return ChannelName
}


@ -0,0 +1,12 @@
package gemini
const (
GeminiVisionMaxImageNum = 16
)
var ModelList = []string{
"gemini-pro",
"gemini-pro-vision",
}
var ChannelName = "google gemini"


@ -0,0 +1,62 @@
package gemini
type GeminiChatRequest struct {
Contents []GeminiChatContent `json:"contents"`
SafetySettings []GeminiChatSafetySettings `json:"safety_settings,omitempty"`
GenerationConfig GeminiChatGenerationConfig `json:"generation_config,omitempty"`
Tools []GeminiChatTools `json:"tools,omitempty"`
}
type GeminiInlineData struct {
MimeType string `json:"mimeType"`
Data string `json:"data"`
}
type GeminiPart struct {
Text string `json:"text,omitempty"`
InlineData *GeminiInlineData `json:"inlineData,omitempty"`
}
type GeminiChatContent struct {
Role string `json:"role,omitempty"`
Parts []GeminiPart `json:"parts"`
}
type GeminiChatSafetySettings struct {
Category string `json:"category"`
Threshold string `json:"threshold"`
}
type GeminiChatTools struct {
FunctionDeclarations any `json:"functionDeclarations,omitempty"`
}
type GeminiChatGenerationConfig struct {
Temperature float64 `json:"temperature,omitempty"`
TopP float64 `json:"topP,omitempty"`
TopK float64 `json:"topK,omitempty"`
MaxOutputTokens uint `json:"maxOutputTokens,omitempty"`
CandidateCount int `json:"candidateCount,omitempty"`
StopSequences []string `json:"stopSequences,omitempty"`
}
type GeminiChatCandidate struct {
Content GeminiChatContent `json:"content"`
FinishReason string `json:"finishReason"`
Index int64 `json:"index"`
SafetyRatings []GeminiChatSafetyRating `json:"safetyRatings"`
}
type GeminiChatSafetyRating struct {
Category string `json:"category"`
Probability string `json:"probability"`
}
type GeminiChatPromptFeedback struct {
SafetyRatings []GeminiChatSafetyRating `json:"safetyRatings"`
}
type GeminiChatResponse struct {
Candidates []GeminiChatCandidate `json:"candidates"`
PromptFeedback GeminiChatPromptFeedback `json:"promptFeedback"`
}
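
The tags above mix snake_case at the top level (safety_settings, generation_config) with camelCase in the nested parts (mimeType, topP), matching Gemini's wire format. A short sketch of the JSON a minimal request marshals to; note that omitempty does not drop the empty GenerationConfig struct, only the nil slices:

package gemini

import (
    "encoding/json"
    "fmt"
)

// Sketch: one user message, default generation config.
func ExampleGeminiChatRequest_marshal() {
    req := GeminiChatRequest{
        Contents: []GeminiChatContent{{
            Role:  "user",
            Parts: []GeminiPart{{Text: "hi"}},
        }},
    }
    b, _ := json.Marshal(req)
    fmt.Println(string(b))
    // Output:
    // {"contents":[{"role":"user","parts":[{"text":"hi"}]}],"generation_config":{}}
}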

View File

@@ -1,4 +1,4 @@
-package controller
+package gemini
import (
"bufio"
@@ -7,57 +7,16 @@ import (
"io"
"net/http"
"one-api/common"
+"one-api/dto"
+relaycommon "one-api/relay/common"
+"one-api/service"
"strings"
"github.com/gin-gonic/gin"
)
-const (
-GeminiVisionMaxImageNum = 16
-)
-type GeminiChatRequest struct {
-Contents []GeminiChatContent `json:"contents"`
-SafetySettings []GeminiChatSafetySettings `json:"safety_settings,omitempty"`
-GenerationConfig GeminiChatGenerationConfig `json:"generation_config,omitempty"`
-Tools []GeminiChatTools `json:"tools,omitempty"`
-}
-type GeminiInlineData struct {
-MimeType string `json:"mimeType"`
-Data string `json:"data"`
-}
-type GeminiPart struct {
-Text string `json:"text,omitempty"`
-InlineData *GeminiInlineData `json:"inlineData,omitempty"`
-}
-type GeminiChatContent struct {
-Role string `json:"role,omitempty"`
-Parts []GeminiPart `json:"parts"`
-}
-type GeminiChatSafetySettings struct {
-Category string `json:"category"`
-Threshold string `json:"threshold"`
-}
-type GeminiChatTools struct {
-FunctionDeclarations any `json:"functionDeclarations,omitempty"`
-}
-type GeminiChatGenerationConfig struct {
-Temperature float64 `json:"temperature,omitempty"`
-TopP float64 `json:"topP,omitempty"`
-TopK float64 `json:"topK,omitempty"`
-MaxOutputTokens uint `json:"maxOutputTokens,omitempty"`
-CandidateCount int `json:"candidateCount,omitempty"`
-StopSequences []string `json:"stopSequences,omitempty"`
-}
// Setting safety to the lowest possible values since Gemini is already powerless enough
-func requestOpenAI2Gemini(textRequest GeneralOpenAIRequest) *GeminiChatRequest {
+func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest) *GeminiChatRequest {
geminiRequest := GeminiChatRequest{
Contents: make([]GeminiChatContent, 0, len(textRequest.Messages)),
SafetySettings: []GeminiChatSafetySettings{
@@ -106,16 +65,16 @@ func requestOpenAI2Gemini(textRequest GeneralOpenAIRequest) *GeminiChatRequest {
imageNum := 0
for _, part := range openaiContent {
-if part.Type == ContentTypeText {
+if part.Type == dto.ContentTypeText {
parts = append(parts, GeminiPart{
Text: part.Text,
})
-} else if part.Type == ContentTypeImageURL {
+} else if part.Type == dto.ContentTypeImageURL {
imageNum += 1
if imageNum > GeminiVisionMaxImageNum {
continue
}
-mimeType, data, _ := common.GetImageFromUrl(part.ImageUrl.(MessageImageUrl).Url)
+mimeType, data, _ := common.GetImageFromUrl(part.ImageUrl.(dto.MessageImageUrl).Url)
parts = append(parts, GeminiPart{
InlineData: &GeminiInlineData{
MimeType: mimeType,
@@ -154,11 +113,6 @@ func requestOpenAI2Gemini(textRequest GeneralOpenAIRequest) *GeminiChatRequest {
return &geminiRequest
}
-type GeminiChatResponse struct {
-Candidates []GeminiChatCandidate `json:"candidates"`
-PromptFeedback GeminiChatPromptFeedback `json:"promptFeedback"`
-}
func (g *GeminiChatResponse) GetResponseText() string {
if g == nil {
return ""
@@ -169,38 +123,22 @@ func (g *GeminiChatResponse) GetResponseText() string {
return ""
}
-type GeminiChatCandidate struct {
-Content GeminiChatContent `json:"content"`
-FinishReason string `json:"finishReason"`
-Index int64 `json:"index"`
-SafetyRatings []GeminiChatSafetyRating `json:"safetyRatings"`
-}
-type GeminiChatSafetyRating struct {
-Category string `json:"category"`
-Probability string `json:"probability"`
-}
-type GeminiChatPromptFeedback struct {
-SafetyRatings []GeminiChatSafetyRating `json:"safetyRatings"`
-}
-func responseGeminiChat2OpenAI(response *GeminiChatResponse) *OpenAITextResponse {
-fullTextResponse := OpenAITextResponse{
+func responseGeminiChat2OpenAI(response *GeminiChatResponse) *dto.OpenAITextResponse {
+fullTextResponse := dto.OpenAITextResponse{
Id: fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
Object: "chat.completion",
Created: common.GetTimestamp(),
-Choices: make([]OpenAITextResponseChoice, 0, len(response.Candidates)),
+Choices: make([]dto.OpenAITextResponseChoice, 0, len(response.Candidates)),
}
content, _ := json.Marshal("")
for i, candidate := range response.Candidates {
-choice := OpenAITextResponseChoice{
+choice := dto.OpenAITextResponseChoice{
Index: i,
-Message: Message{
+Message: dto.Message{
Role: "assistant",
Content: content,
},
-FinishReason: stopFinishReason,
+FinishReason: relaycommon.StopFinishReason,
}
content, _ = json.Marshal(candidate.Content.Parts[0].Text)
if len(candidate.Content.Parts) > 0 {
@@ -211,18 +149,18 @@ func responseGeminiChat2OpenAI(response *GeminiChatResponse) *OpenAITextResponse
return &fullTextResponse
}
-func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) *ChatCompletionsStreamResponse {
-var choice ChatCompletionsStreamResponseChoice
+func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) *dto.ChatCompletionsStreamResponse {
+var choice dto.ChatCompletionsStreamResponseChoice
choice.Delta.Content = geminiResponse.GetResponseText()
-choice.FinishReason = &stopFinishReason
-var response ChatCompletionsStreamResponse
+choice.FinishReason = &relaycommon.StopFinishReason
+var response dto.ChatCompletionsStreamResponse
response.Object = "chat.completion.chunk"
response.Model = "gemini"
-response.Choices = []ChatCompletionsStreamResponseChoice{choice}
+response.Choices = []dto.ChatCompletionsStreamResponseChoice{choice}
return &response
}
-func geminiChatStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, string) {
+func geminiChatStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, string) {
responseText := ""
dataChan := make(chan string)
stopChan := make(chan bool)
@@ -252,7 +190,7 @@ func geminiChatStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorW
}
stopChan <- true
}()
-setEventStreamHeaders(c)
+service.SetEventStreamHeaders(c)
c.Stream(func(w io.Writer) bool {
select {
case data := <-dataChan:
@@ -264,14 +202,14 @@ func geminiChatStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorW
var dummy dummyStruct
err := json.Unmarshal([]byte(data), &dummy)
responseText += dummy.Content
-var choice ChatCompletionsStreamResponseChoice
+var choice dto.ChatCompletionsStreamResponseChoice
choice.Delta.Content = dummy.Content
-response := ChatCompletionsStreamResponse{
+response := dto.ChatCompletionsStreamResponse{
Id: fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
Object: "chat.completion.chunk",
Created: common.GetTimestamp(),
Model: "gemini-pro",
-Choices: []ChatCompletionsStreamResponseChoice{choice},
+Choices: []dto.ChatCompletionsStreamResponseChoice{choice},
}
jsonResponse, err := json.Marshal(response)
if err != nil {
@@ -287,28 +225,28 @@ func geminiChatStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorW
})
err := resp.Body.Close()
if err != nil {
-return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), ""
+return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), ""
}
return nil, responseText
}
-func geminiChatHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*OpenAIErrorWithStatusCode, *Usage) {
+func geminiChatHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
-return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
+return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
}
err = resp.Body.Close()
if err != nil {
-return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
+return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
var geminiResponse GeminiChatResponse
err = json.Unmarshal(responseBody, &geminiResponse)
if err != nil {
-return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
+return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
}
if len(geminiResponse.Candidates) == 0 {
-return &OpenAIErrorWithStatusCode{
-OpenAIError: OpenAIError{
+return &dto.OpenAIErrorWithStatusCode{
+OpenAIError: dto.OpenAIError{
Message: "No candidates returned",
Type: "server_error",
Param: "",
@@ -318,8 +256,8 @@ func geminiChatHandler(c *gin.Context, resp *http.Response, promptTokens int, mo
}, nil
}
fullTextResponse := responseGeminiChat2OpenAI(&geminiResponse)
-completionTokens := countTokenText(geminiResponse.GetResponseText(), model)
-usage := Usage{
+completionTokens := service.CountTokenText(geminiResponse.GetResponseText(), model)
+usage := dto.Usage{
PromptTokens: promptTokens,
CompletionTokens: completionTokens,
TotalTokens: promptTokens + completionTokens,
@@ -327,7 +265,7 @@ func geminiChatHandler(c *gin.Context, resp *http.Response, promptTokens int, mo
fullTextResponse.Usage = usage
jsonResponse, err := json.Marshal(fullTextResponse)
if err != nil {
-return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
+return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
}
c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode)
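
dummyStruct is referenced by the stream handler above but its definition is not shown in this diff; a plausible shape, inferred from the json.Unmarshal call and offered only as an assumption:

package gemini

// Inferred sketch: just enough structure to lift the streamed text
// out of each data line.
type dummyStructSketch struct {
    Content string `json:"content"`
}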

View File

@@ -0,0 +1,7 @@
package moonshot
var ModelList = []string{
"moonshot-v1-8k",
"moonshot-v1-32k",
"moonshot-v1-128k",
}

View File

@@ -0,0 +1,84 @@
package openai
import (
"errors"
"fmt"
"github.com/gin-gonic/gin"
"io"
"net/http"
"one-api/common"
"one-api/dto"
relaychannel "one-api/relay/channel"
relaycommon "one-api/relay/common"
"one-api/service"
"strings"
)
type Adaptor struct {
}
func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
}
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
if info.ChannelType == common.ChannelTypeAzure {
// https://learn.microsoft.com/en-us/azure/cognitive-services/openai/chatgpt-quickstart?pivots=rest-api&tabs=command-line#rest-api
requestURL := strings.Split(info.RequestURLPath, "?")[0]
requestURL = fmt.Sprintf("%s?api-version=%s", requestURL, info.ApiVersion)
task := strings.TrimPrefix(requestURL, "/v1/")
model_ := info.UpstreamModelName
model_ = strings.Replace(model_, ".", "", -1)
// https://github.com/songquanpeng/one-api/issues/67
model_ = strings.TrimSuffix(model_, "-0301")
model_ = strings.TrimSuffix(model_, "-0314")
model_ = strings.TrimSuffix(model_, "-0613")
requestURL = fmt.Sprintf("/openai/deployments/%s/%s", model_, task)
return relaycommon.GetFullRequestURL(info.BaseUrl, requestURL, info.ChannelType), nil
}
return relaycommon.GetFullRequestURL(info.BaseUrl, info.RequestURLPath, info.ChannelType), nil
}
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
relaychannel.SetupApiRequestHeader(info, c, req)
if info.ChannelType == common.ChannelTypeAzure {
req.Header.Set("api-key", info.ApiKey)
return nil
}
req.Header.Set("Authorization", "Bearer "+info.ApiKey)
if info.ChannelType == common.ChannelTypeOpenRouter {
req.Header.Set("HTTP-Referer", "https://github.com/songquanpeng/one-api")
req.Header.Set("X-Title", "One API")
}
return nil
}
func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) {
if request == nil {
return nil, errors.New("request is nil")
}
return request, nil
}
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
return relaychannel.DoApiRequest(a, c, info, requestBody)
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
if info.IsStream {
var responseText string
err, responseText = openaiStreamHandler(c, resp, info.RelayMode)
usage = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
} else {
err, usage = openaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
}
return
}
func (a *Adaptor) GetModelList() []string {
return ModelList
}
func (a *Adaptor) GetChannelName() string {
return ChannelName
}
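
The Azure branch above rewrites an OpenAI-style path into Azure's deployments form while normalizing the model name (dots stripped, date suffixes trimmed). A self-contained sketch of that mapping, with the URL join inlined on the assumption that relaycommon.GetFullRequestURL simply concatenates base URL and path for this channel type:

package openai

import (
    "fmt"
    "strings"
)

// Sketch of the Azure URL mapping; mirrors the logic in GetRequestURL.
func azureURLSketch(base, requestPath, apiVersion, model string) string {
    requestURL := strings.Split(requestPath, "?")[0]
    requestURL = fmt.Sprintf("%s?api-version=%s", requestURL, apiVersion)
    task := strings.TrimPrefix(requestURL, "/v1/")
    model = strings.ReplaceAll(model, ".", "")
    for _, suffix := range []string{"-0301", "-0314", "-0613"} {
        model = strings.TrimSuffix(model, suffix)
    }
    return base + fmt.Sprintf("/openai/deployments/%s/%s", model, task)
}

// azureURLSketch("https://my-rg.openai.azure.com", "/v1/chat/completions", "2023-05-15", "gpt-3.5-turbo-0613")
// yields "https://my-rg.openai.azure.com/openai/deployments/gpt-35-turbo/chat/completions?api-version=2023-05-15";
// the resource name "my-rg" and the api-version are hypothetical values.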

View File

@@ -0,0 +1,21 @@
package openai
var ModelList = []string{
"gpt-3.5-turbo", "gpt-3.5-turbo-0301", "gpt-3.5-turbo-0613", "gpt-3.5-turbo-1106", "gpt-3.5-turbo-0125",
"gpt-3.5-turbo-16k", "gpt-3.5-turbo-16k-0613",
"gpt-3.5-turbo-instruct",
"gpt-4", "gpt-4-0314", "gpt-4-0613", "gpt-4-1106-preview", "gpt-4-0125-preview",
"gpt-4-32k", "gpt-4-32k-0314", "gpt-4-32k-0613",
"gpt-4-turbo-preview",
"gpt-4-vision-preview",
"text-embedding-ada-002", "text-embedding-3-small", "text-embedding-3-large",
"text-curie-001", "text-babbage-001", "text-ada-001", "text-davinci-002", "text-davinci-003",
"text-moderation-latest", "text-moderation-stable",
"text-davinci-edit-001",
"davinci-002", "babbage-002",
"dall-e-2", "dall-e-3",
"whisper-1",
"tts-1", "tts-1-1106", "tts-1-hd", "tts-1-hd-1106",
}
var ChannelName = "openai"

View File

@@ -1,4 +1,4 @@
-package controller
+package openai
import (
"bufio"
@@ -8,12 +8,15 @@ import (
"io"
"net/http"
"one-api/common"
+"one-api/dto"
+relayconstant "one-api/relay/constant"
+"one-api/service"
"strings"
"sync"
"time"
)
-func openaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*OpenAIErrorWithStatusCode, string) {
+func openaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*dto.OpenAIErrorWithStatusCode, string) {
var responseTextBuilder strings.Builder
scanner := bufio.NewScanner(resp.Body)
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
@@ -54,8 +57,8 @@ func openaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*O
}
streamResp := "[" + strings.Join(streamItems, ",") + "]"
switch relayMode {
-case RelayModeChatCompletions:
-var streamResponses []ChatCompletionsStreamResponseSimple
+case relayconstant.RelayModeChatCompletions:
+var streamResponses []dto.ChatCompletionsStreamResponseSimple
err := json.Unmarshal(common.StringToByteSlice(streamResp), &streamResponses)
if err != nil {
common.SysError("error unmarshalling stream response: " + err.Error())
@@ -66,8 +69,8 @@ func openaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*O
responseTextBuilder.WriteString(choice.Delta.Content)
}
}
-case RelayModeCompletions:
-var streamResponses []CompletionsStreamResponse
+case relayconstant.RelayModeCompletions:
+var streamResponses []dto.CompletionsStreamResponse
err := json.Unmarshal(common.StringToByteSlice(streamResp), &streamResponses)
if err != nil {
common.SysError("error unmarshalling stream response: " + err.Error())
@@ -85,7 +88,7 @@ func openaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*O
}
common.SafeSend(stopChan, true)
}()
-setEventStreamHeaders(c)
+service.SetEventStreamHeaders(c)
c.Stream(func(w io.Writer) bool {
select {
case data := <-dataChan:
@@ -102,28 +105,28 @@ func openaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*O
})
err := resp.Body.Close()
if err != nil {
-return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), ""
+return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), ""
}
wg.Wait()
return nil, responseTextBuilder.String()
}
-func openaiHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*OpenAIErrorWithStatusCode, *Usage) {
-var textResponse TextResponse
+func openaiHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
+var textResponse dto.TextResponse
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
-return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
+return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
}
err = resp.Body.Close()
if err != nil {
-return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
+return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
err = json.Unmarshal(responseBody, &textResponse)
if err != nil {
-return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
+return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
}
if textResponse.Error.Type != "" {
-return &OpenAIErrorWithStatusCode{
+return &dto.OpenAIErrorWithStatusCode{
OpenAIError: textResponse.Error,
StatusCode: resp.StatusCode,
}, nil
@@ -140,19 +143,19 @@ func openaiHandler(c *gin.Context, resp *http.Response, promptTokens int, model
c.Writer.WriteHeader(resp.StatusCode)
_, err = io.Copy(c.Writer, resp.Body)
if err != nil {
-return errorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil
+return service.OpenAIErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil
}
err = resp.Body.Close()
if err != nil {
-return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
+return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
if textResponse.Usage.TotalTokens == 0 {
completionTokens := 0
for _, choice := range textResponse.Choices {
-completionTokens += countTokenText(string(choice.Message.Content), model)
+completionTokens += service.CountTokenText(string(choice.Message.Content), model)
}
-textResponse.Usage = Usage{
+textResponse.Usage = dto.Usage{
PromptTokens: promptTokens,
CompletionTokens: completionTokens,
TotalTokens: promptTokens + completionTokens,
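
The body of the custom bufio.Scanner split function is elided by this hunk. A common shape for it, offered as an assumption rather than the committed code, is a newline tokenizer that hands back one SSE line per scan:

package openai

import (
    "bufio"
    "bytes"
)

// sseSplitSketch yields one newline-terminated line per token and
// flushes the trailing partial line at EOF.
func sseSplitSketch(data []byte, atEOF bool) (advance int, token []byte, err error) {
    if atEOF && len(data) == 0 {
        return 0, nil, nil
    }
    if i := bytes.IndexByte(data, '\n'); i >= 0 {
        return i + 1, data[:i], nil
    }
    if atEOF {
        return len(data), data, nil
    }
    return 0, nil, nil // request more data
}

var _ bufio.SplitFunc = sseSplitSketch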

View File

@@ -0,0 +1,59 @@
package palm
import (
"errors"
"fmt"
"github.com/gin-gonic/gin"
"io"
"net/http"
"one-api/dto"
relaychannel "one-api/relay/channel"
relaycommon "one-api/relay/common"
"one-api/service"
)
type Adaptor struct {
}
func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
}
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
return fmt.Sprintf("%s/v1beta2/models/chat-bison-001:generateMessage", info.BaseUrl), nil
}
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
relaychannel.SetupApiRequestHeader(info, c, req)
req.Header.Set("x-goog-api-key", info.ApiKey)
return nil
}
func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) {
if request == nil {
return nil, errors.New("request is nil")
}
return request, nil
}
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
return relaychannel.DoApiRequest(a, c, info, requestBody)
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
if info.IsStream {
var responseText string
err, responseText = palmStreamHandler(c, resp)
usage = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
} else {
err, usage = palmHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
}
return
}
func (a *Adaptor) GetModelList() []string {
return ModelList
}
func (a *Adaptor) GetChannelName() string {
return ChannelName
}

View File

@@ -0,0 +1,7 @@
package palm
var ModelList = []string{
"PaLM-2",
}
var ChannelName = "google palm"

relay/channel/palm/dto.go Normal file
View File

@@ -0,0 +1,38 @@
package palm
import "one-api/dto"
type PaLMChatMessage struct {
Author string `json:"author"`
Content string `json:"content"`
}
type PaLMFilter struct {
Reason string `json:"reason"`
Message string `json:"message"`
}
type PaLMPrompt struct {
Messages []PaLMChatMessage `json:"messages"`
}
type PaLMChatRequest struct {
Prompt PaLMPrompt `json:"prompt"`
Temperature float64 `json:"temperature,omitempty"`
CandidateCount int `json:"candidateCount,omitempty"`
TopP float64 `json:"topP,omitempty"`
TopK uint `json:"topK,omitempty"`
}
type PaLMError struct {
Code int `json:"code"`
Message string `json:"message"`
Status string `json:"status"`
}
type PaLMChatResponse struct {
Candidates []PaLMChatMessage `json:"candidates"`
Messages []dto.Message `json:"messages"`
Filters []PaLMFilter `json:"filters"`
Error PaLMError `json:"error"`
}

View File

@@ -1,4 +1,4 @@
-package controller
+package palm
import (
"encoding/json"
@@ -7,47 +7,15 @@ import (
"io"
"net/http"
"one-api/common"
+"one-api/dto"
+relaycommon "one-api/relay/common"
+"one-api/service"
)
// https://developers.generativeai.google/api/rest/generativelanguage/models/generateMessage#request-body
// https://developers.generativeai.google/api/rest/generativelanguage/models/generateMessage#response-body
-type PaLMChatMessage struct {
-Author string `json:"author"`
-Content string `json:"content"`
-}
-type PaLMFilter struct {
-Reason string `json:"reason"`
-Message string `json:"message"`
-}
-type PaLMPrompt struct {
-Messages []PaLMChatMessage `json:"messages"`
-}
-type PaLMChatRequest struct {
-Prompt PaLMPrompt `json:"prompt"`
-Temperature float64 `json:"temperature,omitempty"`
-CandidateCount int `json:"candidateCount,omitempty"`
-TopP float64 `json:"topP,omitempty"`
-TopK uint `json:"topK,omitempty"`
-}
-type PaLMError struct {
-Code int `json:"code"`
-Message string `json:"message"`
-Status string `json:"status"`
-}
-type PaLMChatResponse struct {
-Candidates []PaLMChatMessage `json:"candidates"`
-Messages []Message `json:"messages"`
-Filters []PaLMFilter `json:"filters"`
-Error PaLMError `json:"error"`
-}
-func requestOpenAI2PaLM(textRequest GeneralOpenAIRequest) *PaLMChatRequest {
+func requestOpenAI2PaLM(textRequest dto.GeneralOpenAIRequest) *PaLMChatRequest {
palmRequest := PaLMChatRequest{
Prompt: PaLMPrompt{
Messages: make([]PaLMChatMessage, 0, len(textRequest.Messages)),
@@ -71,15 +39,15 @@ func requestOpenAI2PaLM(textRequest GeneralOpenAIRequest) *PaLMChatRequest {
return &palmRequest
}
-func responsePaLM2OpenAI(response *PaLMChatResponse) *OpenAITextResponse {
-fullTextResponse := OpenAITextResponse{
-Choices: make([]OpenAITextResponseChoice, 0, len(response.Candidates)),
+func responsePaLM2OpenAI(response *PaLMChatResponse) *dto.OpenAITextResponse {
+fullTextResponse := dto.OpenAITextResponse{
+Choices: make([]dto.OpenAITextResponseChoice, 0, len(response.Candidates)),
}
for i, candidate := range response.Candidates {
content, _ := json.Marshal(candidate.Content)
-choice := OpenAITextResponseChoice{
+choice := dto.OpenAITextResponseChoice{
Index: i,
-Message: Message{
+Message: dto.Message{
Role: "assistant",
Content: content,
},
@@ -90,20 +58,20 @@ func responsePaLM2OpenAI(response *PaLMChatResponse) *OpenAITextResponse {
return &fullTextResponse
}
-func streamResponsePaLM2OpenAI(palmResponse *PaLMChatResponse) *ChatCompletionsStreamResponse {
-var choice ChatCompletionsStreamResponseChoice
+func streamResponsePaLM2OpenAI(palmResponse *PaLMChatResponse) *dto.ChatCompletionsStreamResponse {
+var choice dto.ChatCompletionsStreamResponseChoice
if len(palmResponse.Candidates) > 0 {
choice.Delta.Content = palmResponse.Candidates[0].Content
}
-choice.FinishReason = &stopFinishReason
-var response ChatCompletionsStreamResponse
+choice.FinishReason = &relaycommon.StopFinishReason
+var response dto.ChatCompletionsStreamResponse
response.Object = "chat.completion.chunk"
response.Model = "palm2"
-response.Choices = []ChatCompletionsStreamResponseChoice{choice}
+response.Choices = []dto.ChatCompletionsStreamResponseChoice{choice}
return &response
}
-func palmStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, string) {
+func palmStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, string) {
responseText := ""
responseId := fmt.Sprintf("chatcmpl-%s", common.GetUUID())
createdTime := common.GetTimestamp()
@@ -144,7 +112,7 @@ func palmStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithSta
dataChan <- string(jsonResponse)
stopChan <- true
}()
-setEventStreamHeaders(c)
+service.SetEventStreamHeaders(c)
c.Stream(func(w io.Writer) bool {
select {
case data := <-dataChan:
@@ -157,28 +125,28 @@ func palmStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithSta
})
err := resp.Body.Close()
if err != nil {
-return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), ""
+return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), ""
}
return nil, responseText
}
-func palmHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*OpenAIErrorWithStatusCode, *Usage) {
+func palmHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
-return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
+return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
}
err = resp.Body.Close()
if err != nil {
-return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
+return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
var palmResponse PaLMChatResponse
err = json.Unmarshal(responseBody, &palmResponse)
if err != nil {
-return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
+return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
}
if palmResponse.Error.Code != 0 || len(palmResponse.Candidates) == 0 {
-return &OpenAIErrorWithStatusCode{
-OpenAIError: OpenAIError{
+return &dto.OpenAIErrorWithStatusCode{
+OpenAIError: dto.OpenAIError{
Message: palmResponse.Error.Message,
Type: palmResponse.Error.Status,
Param: "",
@@ -188,8 +156,8 @@ func palmHandler(c *gin.Context, resp *http.Response, promptTokens int, model st
}, nil
}
fullTextResponse := responsePaLM2OpenAI(&palmResponse)
-completionTokens := countTokenText(palmResponse.Candidates[0].Content, model)
-usage := Usage{
+completionTokens := service.CountTokenText(palmResponse.Candidates[0].Content, model)
+usage := dto.Usage{
PromptTokens: promptTokens,
CompletionTokens: completionTokens,
TotalTokens: promptTokens + completionTokens,
@@ -197,7 +165,7 @@ func palmHandler(c *gin.Context, resp *http.Response, promptTokens int, model st
fullTextResponse.Usage = usage
jsonResponse, err := json.Marshal(fullTextResponse)
if err != nil {
-return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
+return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
}
c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode)

View File

@@ -0,0 +1,73 @@
package tencent
import (
"errors"
"fmt"
"github.com/gin-gonic/gin"
"io"
"net/http"
"one-api/dto"
relaychannel "one-api/relay/channel"
relaycommon "one-api/relay/common"
"one-api/service"
"strings"
)
type Adaptor struct {
Sign string
}
func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
}
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
return fmt.Sprintf("%s/hyllm/v1/chat/completions", info.BaseUrl), nil
}
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
relaychannel.SetupApiRequestHeader(info, c, req)
req.Header.Set("Authorization", a.Sign)
req.Header.Set("X-TC-Action", info.UpstreamModelName)
return nil
}
func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) {
if request == nil {
return nil, errors.New("request is nil")
}
apiKey := c.Request.Header.Get("Authorization")
apiKey = strings.TrimPrefix(apiKey, "Bearer ")
appId, secretId, secretKey, err := parseTencentConfig(apiKey)
if err != nil {
return nil, err
}
tencentRequest := requestOpenAI2Tencent(*request)
tencentRequest.AppId = appId
tencentRequest.SecretId = secretId
// we have to calculate the sign here
a.Sign = getTencentSign(*tencentRequest, secretKey)
return tencentRequest, nil
}
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
return relaychannel.DoApiRequest(a, c, info, requestBody)
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
if info.IsStream {
var responseText string
err, responseText = tencentStreamHandler(c, resp)
usage = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
} else {
err, usage = tencentHandler(c, resp)
}
return
}
func (a *Adaptor) GetModelList() []string {
return ModelList
}
func (a *Adaptor) GetChannelName() string {
return ChannelName
}
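
parseTencentConfig and getTencentSign are defined elsewhere in this package. A sketch of the config parsing, under the assumption (consistent with how ConvertRequest consumes the key) that a Tencent channel key is stored as "appId|secretId|secretKey":

package tencent

import (
    "errors"
    "strconv"
    "strings"
)

// Sketch of parseTencentConfig: splits the packed key into its
// three credentials.
func parseTencentConfigSketch(config string) (appId int64, secretId string, secretKey string, err error) {
    parts := strings.Split(config, "|")
    if len(parts) != 3 {
        err = errors.New("invalid tencent config")
        return
    }
    appId, err = strconv.ParseInt(parts[0], 10, 64)
    secretId, secretKey = parts[1], parts[2]
    return
}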

View File

@@ -0,0 +1,9 @@
package tencent
var ModelList = []string{
"ChatPro",
"ChatStd",
"hunyuan",
}
var ChannelName = "tencent"

View File

@@ -0,0 +1,61 @@
package tencent
import "one-api/dto"
type TencentMessage struct {
Role string `json:"role"`
Content string `json:"content"`
}
type TencentChatRequest struct {
AppId int64 `json:"app_id"` // APPID of the Tencent Cloud account
SecretId string `json:"secret_id"` // SecretId from the console
// Timestamp: the current UNIX timestamp in seconds, recording when the API request was made.
// Example: 1529223702; a value too far from the current time triggers a signature-expired error.
Timestamp int64 `json:"timestamp"`
// Expired: when the signature expires, a UNIX-epoch timestamp in seconds.
// Expired must be greater than Timestamp, and Expired - Timestamp must be less than 90 days.
Expired int64 `json:"expired"`
QueryID string `json:"query_id"` // request id, used for troubleshooting
// Temperature: higher values make the output more random, lower values make it more focused and deterministic.
// Default 1.0, range [0.0, 2.0]; avoid setting it unless necessary, as unreasonable values degrade results.
// Change at most one of this parameter and top_p, not both.
Temperature float64 `json:"temperature"`
// TopP: controls the diversity of the output; larger values yield more diverse text.
// Default 1.0, range [0.0, 1.0]; avoid setting it unless necessary, as unreasonable values degrade results.
// Change at most one of this parameter and temperature, not both.
TopP float64 `json:"top_p"`
// Stream: 0 for synchronous, 1 for streaming (the default; SSE protocol).
// Synchronous requests time out after 60s; prefer streaming for long content.
Stream int `json:"stream"`
// Messages: the conversation, at most 40 entries, ordered from oldest to newest in the array.
// The total input content supports at most 3000 tokens.
Messages []TencentMessage `json:"messages"`
}
type TencentError struct {
Code int `json:"code"`
Message string `json:"message"`
}
type TencentUsage struct {
InputTokens int `json:"input_tokens"`
OutputTokens int `json:"output_tokens"`
TotalTokens int `json:"total_tokens"`
}
type TencentResponseChoices struct {
FinishReason string `json:"finish_reason,omitempty"` // stream end flag; "stop" marks the final packet
Messages TencentMessage `json:"messages,omitempty"` // content returned in synchronous mode, null in stream mode; output supports at most 1024 tokens
Delta TencentMessage `json:"delta,omitempty"` // content returned in stream mode, null in synchronous mode; output supports at most 1024 tokens
}
type TencentChatResponse struct {
Choices []TencentResponseChoices `json:"choices,omitempty"` // results
Created string `json:"created,omitempty"` // UNIX timestamp as a string
Id string `json:"id,omitempty"` // conversation id
Usage dto.Usage `json:"usage,omitempty"` // token counts
Error TencentError `json:"error,omitempty"` // error info; note: this field may be null, meaning no valid value was obtained
Note string `json:"note,omitempty"` // remark
ReqID string `json:"req_id,omitempty"` // unique request id, returned with every request; include it when reporting issues
}

View File

@@ -1,4 +1,4 @@
-package controller
+package tencent
import (
"bufio"
@@ -12,6 +12,9 @@ import (
"io"
"net/http"
"one-api/common"
+"one-api/dto"
+relaycommon "one-api/relay/common"
+"one-api/service"
"sort"
"strconv"
"strings"
@@ -19,65 +22,7 @@ import (
// https://cloud.tencent.com/document/product/1729/97732
-type TencentMessage struct {
-Role string `json:"role"`
-Content string `json:"content"`
-}
-type TencentChatRequest struct {
-AppId int64 `json:"app_id"` // APPID of the Tencent Cloud account
-SecretId string `json:"secret_id"` // SecretId from the console
-// Timestamp: the current UNIX timestamp in seconds, recording when the API request was made.
-// Example: 1529223702; a value too far from the current time triggers a signature-expired error.
-Timestamp int64 `json:"timestamp"`
-// Expired: when the signature expires, a UNIX-epoch timestamp in seconds.
-// Expired must be greater than Timestamp, and Expired - Timestamp must be less than 90 days.
-Expired int64 `json:"expired"`
-QueryID string `json:"query_id"` // request id, used for troubleshooting
-// Temperature: higher values make the output more random, lower values make it more focused and deterministic.
-// Default 1.0, range [0.0, 2.0]; avoid setting it unless necessary, as unreasonable values degrade results.
-// Change at most one of this parameter and top_p, not both.
-Temperature float64 `json:"temperature"`
-// TopP: controls the diversity of the output; larger values yield more diverse text.
-// Default 1.0, range [0.0, 1.0]; avoid setting it unless necessary, as unreasonable values degrade results.
-// Change at most one of this parameter and temperature, not both.
-TopP float64 `json:"top_p"`
-// Stream: 0 for synchronous, 1 for streaming (the default; SSE protocol).
-// Synchronous requests time out after 60s; prefer streaming for long content.
-Stream int `json:"stream"`
-// Messages: the conversation, at most 40 entries, ordered from oldest to newest in the array.
-// The total input content supports at most 3000 tokens.
-Messages []TencentMessage `json:"messages"`
-}
-type TencentError struct {
-Code int `json:"code"`
-Message string `json:"message"`
-}
-type TencentUsage struct {
-InputTokens int `json:"input_tokens"`
-OutputTokens int `json:"output_tokens"`
-TotalTokens int `json:"total_tokens"`
-}
-type TencentResponseChoices struct {
-FinishReason string `json:"finish_reason,omitempty"` // stream end flag; "stop" marks the final packet
-Messages TencentMessage `json:"messages,omitempty"` // content returned in synchronous mode, null in stream mode; output supports at most 1024 tokens
-Delta TencentMessage `json:"delta,omitempty"` // content returned in stream mode, null in synchronous mode; output supports at most 1024 tokens
-}
-type TencentChatResponse struct {
-Choices []TencentResponseChoices `json:"choices,omitempty"` // results
-Created string `json:"created,omitempty"` // UNIX timestamp as a string
-Id string `json:"id,omitempty"` // conversation id
-Usage Usage `json:"usage,omitempty"` // token counts
-Error TencentError `json:"error,omitempty"` // error info; note: this field may be null, meaning no valid value was obtained
-Note string `json:"note,omitempty"` // remark
-ReqID string `json:"req_id,omitempty"` // unique request id, returned with every request; include it when reporting issues
-}
-func requestOpenAI2Tencent(request GeneralOpenAIRequest) *TencentChatRequest {
+func requestOpenAI2Tencent(request dto.GeneralOpenAIRequest) *TencentChatRequest {
messages := make([]TencentMessage, 0, len(request.Messages))
for i := 0; i < len(request.Messages); i++ {
message := request.Messages[i]
@@ -112,17 +57,17 @@ func requestOpenAI2Tencent(request GeneralOpenAIRequest) *TencentChatRequest {
}
}
-func responseTencent2OpenAI(response *TencentChatResponse) *OpenAITextResponse {
-fullTextResponse := OpenAITextResponse{
+func responseTencent2OpenAI(response *TencentChatResponse) *dto.OpenAITextResponse {
+fullTextResponse := dto.OpenAITextResponse{
Object: "chat.completion",
Created: common.GetTimestamp(),
Usage: response.Usage,
}
if len(response.Choices) > 0 {
content, _ := json.Marshal(response.Choices[0].Messages.Content)
-choice := OpenAITextResponseChoice{
+choice := dto.OpenAITextResponseChoice{
Index: 0,
-Message: Message{
+Message: dto.Message{
Role: "assistant",
Content: content,
},
@@ -133,24 +78,24 @@ func responseTencent2OpenAI(response *TencentChatResponse) *OpenAITextResponse {
return &fullTextResponse
}
-func streamResponseTencent2OpenAI(TencentResponse *TencentChatResponse) *ChatCompletionsStreamResponse {
-response := ChatCompletionsStreamResponse{
+func streamResponseTencent2OpenAI(TencentResponse *TencentChatResponse) *dto.ChatCompletionsStreamResponse {
+response := dto.ChatCompletionsStreamResponse{
Object: "chat.completion.chunk",
Created: common.GetTimestamp(),
Model: "tencent-hunyuan",
}
if len(TencentResponse.Choices) > 0 {
-var choice ChatCompletionsStreamResponseChoice
+var choice dto.ChatCompletionsStreamResponseChoice
choice.Delta.Content = TencentResponse.Choices[0].Delta.Content
if TencentResponse.Choices[0].FinishReason == "stop" {
-choice.FinishReason = &stopFinishReason
+choice.FinishReason = &relaycommon.StopFinishReason
}
response.Choices = append(response.Choices, choice)
}
return &response
}
-func tencentStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, string) {
+func tencentStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, string) {
var responseText string
scanner := bufio.NewScanner(resp.Body)
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
@@ -181,7 +126,7 @@ func tencentStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWith
}
stopChan <- true
}()
-setEventStreamHeaders(c)
+service.SetEventStreamHeaders(c)
c.Stream(func(w io.Writer) bool {
select {
case data := <-dataChan:
@@ -209,28 +154,28 @@ func tencentStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWith
})
err := resp.Body.Close()
if err != nil {
-return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), ""
+return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), ""
}
return nil, responseText
}
-func tencentHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) {
+func tencentHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
var TencentResponse TencentChatResponse
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
-return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
+return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
}
err = resp.Body.Close()
if err != nil {
-return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
+return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
err = json.Unmarshal(responseBody, &TencentResponse)
if err != nil {
-return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
+return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
}
if TencentResponse.Error.Code != 0 {
-return &OpenAIErrorWithStatusCode{
-OpenAIError: OpenAIError{
+return &dto.OpenAIErrorWithStatusCode{
+OpenAIError: dto.OpenAIError{
Message: TencentResponse.Error.Message,
Code: TencentResponse.Error.Code,
},
@@ -240,7 +185,7 @@ func tencentHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatus
fullTextResponse := responseTencent2OpenAI(&TencentResponse)
jsonResponse, err := json.Marshal(fullTextResponse)
if err != nil {
-return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
+return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
}
c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode)
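
getTencentSign's body is not shown in this diff. Tencent's hyllm scheme signs a sorted query-string rendering of the request with HMAC-SHA1 (hence the sort, strconv and strings imports above), so the final step plausibly looks like this sketch, an assumption rather than the committed code:

package tencent

import (
    "crypto/hmac"
    "crypto/sha1"
    "encoding/base64"
)

// signSketch: HMAC-SHA1 over the sorted, flattened request fields,
// base64-encoded into the Authorization value.
func signSketch(secretKey, sortedQueryString string) string {
    mac := hmac.New(sha1.New, []byte(secretKey))
    mac.Write([]byte(sortedQueryString))
    return base64.StdEncoding.EncodeToString(mac.Sum(nil))
}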

View File

@@ -0,0 +1,68 @@
package xunfei
import (
"errors"
"github.com/gin-gonic/gin"
"io"
"net/http"
"one-api/dto"
relaychannel "one-api/relay/channel"
relaycommon "one-api/relay/common"
"one-api/service"
"strings"
)
type Adaptor struct {
request *dto.GeneralOpenAIRequest
}
func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
}
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
return "", nil
}
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
relaychannel.SetupApiRequestHeader(info, c, req)
return nil
}
func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) {
if request == nil {
return nil, errors.New("request is nil")
}
a.request = request
return request, nil
}
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
// Xunfei requests go over a WebSocket connection rather than plain HTTP, so there is nothing to send here; return a dummy 200 response to satisfy the interface.
dummyResp := &http.Response{}
dummyResp.StatusCode = http.StatusOK
return dummyResp, nil
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
splits := strings.Split(info.ApiKey, "|")
if len(splits) != 3 {
return nil, service.OpenAIErrorWrapper(errors.New("invalid auth"), "invalid_auth", http.StatusBadRequest)
}
if a.request == nil {
return nil, service.OpenAIErrorWrapper(errors.New("request is nil"), "request_is_nil", http.StatusBadRequest)
}
if info.IsStream {
err, usage = xunfeiStreamHandler(c, *a.request, splits[0], splits[1], splits[2])
} else {
err, usage = xunfeiHandler(c, *a.request, splits[0], splits[1], splits[2])
}
return
}
func (a *Adaptor) GetModelList() []string {
return ModelList
}
func (a *Adaptor) GetChannelName() string {
return ChannelName
}
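
DoRequest returns a placeholder response because the Spark channel talks WebSocket rather than HTTP; the real I/O happens in DoResponse, whose len(splits) != 3 check implies the channel key packs three credentials. A sketch of that assumed layout:

package xunfei

import "strings"

// Sketch: the assumed "appId|apiSecret|apiKey" key layout that
// DoResponse splits apart.
func splitSparkKeySketch(key string) (appId, apiSecret, apiKey string, ok bool) {
    parts := strings.Split(key, "|")
    if len(parts) != 3 {
        return "", "", "", false
    }
    return parts[0], parts[1], parts[2], true
}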

View File

@@ -0,0 +1,11 @@
package xunfei
var ModelList = []string{
"SparkDesk",
"SparkDesk-v1.1",
"SparkDesk-v2.1",
"SparkDesk-v3.1",
"SparkDesk-v3.5",
}
var ChannelName = "xunfei"

View File

@@ -0,0 +1,59 @@
package xunfei
import "one-api/dto"
type XunfeiMessage struct {
Role string `json:"role"`
Content string `json:"content"`
}
type XunfeiChatRequest struct {
Header struct {
AppId string `json:"app_id"`
} `json:"header"`
Parameter struct {
Chat struct {
Domain string `json:"domain,omitempty"`
Temperature float64 `json:"temperature,omitempty"`
TopK int `json:"top_k,omitempty"`
MaxTokens uint `json:"max_tokens,omitempty"`
Auditing bool `json:"auditing,omitempty"`
} `json:"chat"`
} `json:"parameter"`
Payload struct {
Message struct {
Text []XunfeiMessage `json:"text"`
} `json:"message"`
} `json:"payload"`
}
type XunfeiChatResponseTextItem struct {
Content string `json:"content"`
Role string `json:"role"`
Index int `json:"index"`
}
type XunfeiChatResponse struct {
Header struct {
Code int `json:"code"`
Message string `json:"message"`
Sid string `json:"sid"`
Status int `json:"status"`
} `json:"header"`
Payload struct {
Choices struct {
Status int `json:"status"`
Seq int `json:"seq"`
Text []XunfeiChatResponseTextItem `json:"text"`
} `json:"choices"`
Usage struct {
//Text struct {
// QuestionTokens string `json:"question_tokens"`
// PromptTokens string `json:"prompt_tokens"`
// CompletionTokens string `json:"completion_tokens"`
// TotalTokens string `json:"total_tokens"`
//} `json:"text"`
Text dto.Usage `json:"text"`
} `json:"usage"`
} `json:"payload"`
}

View File

@@ -1,4 +1,4 @@
-package controller
+package xunfei
import (
"crypto/hmac"
@@ -12,6 +12,9 @@ import (
"net/http"
"net/url"
"one-api/common"
+"one-api/dto"
+relaycommon "one-api/relay/common"
+"one-api/service"
"strings"
"time"
)
@@ -19,63 +22,7 @@ import (
// https://console.xfyun.cn/services/cbm
// https://www.xfyun.cn/doc/spark/Web.html
-type XunfeiMessage struct {
-Role string `json:"role"`
-Content string `json:"content"`
-}
-type XunfeiChatRequest struct {
-Header struct {
-AppId string `json:"app_id"`
-} `json:"header"`
-Parameter struct {
-Chat struct {
-Domain string `json:"domain,omitempty"`
-Temperature float64 `json:"temperature,omitempty"`
-TopK int `json:"top_k,omitempty"`
-MaxTokens uint `json:"max_tokens,omitempty"`
-Auditing bool `json:"auditing,omitempty"`
-} `json:"chat"`
-} `json:"parameter"`
-Payload struct {
-Message struct {
-Text []XunfeiMessage `json:"text"`
-} `json:"message"`
-} `json:"payload"`
-}
-type XunfeiChatResponseTextItem struct {
-Content string `json:"content"`
-Role string `json:"role"`
-Index int `json:"index"`
-}
-type XunfeiChatResponse struct {
-Header struct {
-Code int `json:"code"`
-Message string `json:"message"`
-Sid string `json:"sid"`
-Status int `json:"status"`
-} `json:"header"`
-Payload struct {
-Choices struct {
-Status int `json:"status"`
-Seq int `json:"seq"`
-Text []XunfeiChatResponseTextItem `json:"text"`
-} `json:"choices"`
-Usage struct {
-//Text struct {
-// QuestionTokens string `json:"question_tokens"`
-// PromptTokens string `json:"prompt_tokens"`
-// CompletionTokens string `json:"completion_tokens"`
-// TotalTokens string `json:"total_tokens"`
-//} `json:"text"`
-Text Usage `json:"text"`
-} `json:"usage"`
-} `json:"payload"`
-}
-func requestOpenAI2Xunfei(request GeneralOpenAIRequest, xunfeiAppId string, domain string) *XunfeiChatRequest {
+func requestOpenAI2Xunfei(request dto.GeneralOpenAIRequest, xunfeiAppId string, domain string) *XunfeiChatRequest {
messages := make([]XunfeiMessage, 0, len(request.Messages))
for _, message := range request.Messages {
if message.Role == "system" {
@@ -104,7 +51,7 @@ func requestOpenAI2Xunfei(request GeneralOpenAIRequest, xunfeiAppId string, doma
return &xunfeiRequest
}
-func responseXunfei2OpenAI(response *XunfeiChatResponse) *OpenAITextResponse {
+func responseXunfei2OpenAI(response *XunfeiChatResponse) *dto.OpenAITextResponse {
if len(response.Payload.Choices.Text) == 0 {
response.Payload.Choices.Text = []XunfeiChatResponseTextItem{
{
@@ -113,24 +60,24 @@ func responseXunfei2OpenAI(response *XunfeiChatResponse) *OpenAITextResponse {
}
}
content, _ := json.Marshal(response.Payload.Choices.Text[0].Content)
-choice := OpenAITextResponseChoice{
+choice := dto.OpenAITextResponseChoice{
Index: 0,
-Message: Message{
+Message: dto.Message{
Role: "assistant",
Content: content,
},
-FinishReason: stopFinishReason,
+FinishReason: relaycommon.StopFinishReason,
}
-fullTextResponse := OpenAITextResponse{
+fullTextResponse := dto.OpenAITextResponse{
Object: "chat.completion",
Created: common.GetTimestamp(),
-Choices: []OpenAITextResponseChoice{choice},
+Choices: []dto.OpenAITextResponseChoice{choice},
Usage: response.Payload.Usage.Text,
}
return &fullTextResponse
}
-func streamResponseXunfei2OpenAI(xunfeiResponse *XunfeiChatResponse) *ChatCompletionsStreamResponse {
+func streamResponseXunfei2OpenAI(xunfeiResponse *XunfeiChatResponse) *dto.ChatCompletionsStreamResponse {
if len(xunfeiResponse.Payload.Choices.Text) == 0 {
xunfeiResponse.Payload.Choices.Text = []XunfeiChatResponseTextItem{
{
@@ -138,16 +85,16 @@ func streamResponseXunfei2OpenAI(xunfeiResponse *XunfeiChatResponse) *ChatComple
},
}
}
-var choice ChatCompletionsStreamResponseChoice
+var choice dto.ChatCompletionsStreamResponseChoice
choice.Delta.Content = xunfeiResponse.Payload.Choices.Text[0].Content
if xunfeiResponse.Payload.Choices.Status == 2 {
-choice.FinishReason = &stopFinishReason
+choice.FinishReason = &relaycommon.StopFinishReason
}
-response := ChatCompletionsStreamResponse{
+response := dto.ChatCompletionsStreamResponse{
Object: "chat.completion.chunk",
Created: common.GetTimestamp(),
Model: "SparkDesk",
-Choices: []ChatCompletionsStreamResponseChoice{choice},
+Choices: []dto.ChatCompletionsStreamResponseChoice{choice},
}
return &response
}
@@ -178,14 +125,14 @@ func buildXunfeiAuthUrl(hostUrl string, apiKey, apiSecret string) string {
return callUrl
}
-func xunfeiStreamHandler(c *gin.Context, textRequest GeneralOpenAIRequest, appId string, apiSecret string, apiKey string) (*OpenAIErrorWithStatusCode, *Usage) {
+func xunfeiStreamHandler(c *gin.Context, textRequest dto.GeneralOpenAIRequest, appId string, apiSecret string, apiKey string) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
domain, authUrl := getXunfeiAuthUrl(c, apiKey, apiSecret)
dataChan, stopChan, err := xunfeiMakeRequest(textRequest, domain, authUrl, appId)
if err != nil {
-return errorWrapper(err, "make xunfei request err", http.StatusInternalServerError), nil
+return service.OpenAIErrorWrapper(err, "make xunfei request err", http.StatusInternalServerError), nil
}
-setEventStreamHeaders(c)
-var usage Usage
+service.SetEventStreamHeaders(c)
+var usage dto.Usage
c.Stream(func(w io.Writer) bool {
select {
case xunfeiResponse := <-dataChan:
@@ -208,13 +155,13 @@ func xunfeiStreamHandler(c *gin.Context, textRequest GeneralOpenAIRequest, appId
return nil, &usage
}
-func xunfeiHandler(c *gin.Context, textRequest GeneralOpenAIRequest, appId string, apiSecret string, apiKey string) (*OpenAIErrorWithStatusCode, *Usage) {
+func xunfeiHandler(c *gin.Context, textRequest dto.GeneralOpenAIRequest, appId string, apiSecret string, apiKey string) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
domain, authUrl := getXunfeiAuthUrl(c, apiKey, apiSecret)
dataChan, stopChan, err := xunfeiMakeRequest(textRequest, domain, authUrl, appId)
if err != nil {
-return errorWrapper(err, "make xunfei request err", http.StatusInternalServerError), nil
+return service.OpenAIErrorWrapper(err, "make xunfei request err", http.StatusInternalServerError), nil
}
-var usage Usage
+var usage dto.Usage
var content string
var xunfeiResponse XunfeiChatResponse
stop := false
@@ -237,14 +184,14 @@ func xunfeiHandler(c *gin.Context, textRequest GeneralOpenAIRequest, appId strin
response := responseXunfei2OpenAI(&xunfeiResponse)
jsonResponse, err := json.Marshal(response)
if err != nil {
-return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
+return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
}
c.Writer.Header().Set("Content-Type", "application/json")
_, _ = c.Writer.Write(jsonResponse)
return nil, &usage
}
-func xunfeiMakeRequest(textRequest GeneralOpenAIRequest, domain, authUrl, appId string) (chan XunfeiChatResponse, chan bool, error) {
+func xunfeiMakeRequest(textRequest dto.GeneralOpenAIRequest, domain, authUrl, appId string) (chan XunfeiChatResponse, chan bool, error) {
d := websocket.Dialer{
HandshakeTimeout: 5 * time.Second,
}

View File

@ -0,0 +1,61 @@
package zhipu
import (
"errors"
"fmt"
"github.com/gin-gonic/gin"
"io"
"net/http"
"one-api/dto"
relaychannel "one-api/relay/channel"
relaycommon "one-api/relay/common"
)
type Adaptor struct {
}
func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
}
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
method := "invoke"
if info.IsStream {
method = "sse-invoke"
}
return fmt.Sprintf("%s/api/paas/v3/model-api/%s/%s", info.BaseUrl, info.UpstreamModelName, method), nil
}
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
relaychannel.SetupApiRequestHeader(info, c, req)
token := getZhipuToken(info.ApiKey)
req.Header.Set("Authorization", token)
return nil
}
func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) {
if request == nil {
return nil, errors.New("request is nil")
}
return requestOpenAI2Zhipu(*request), nil
}
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
return relaychannel.DoApiRequest(a, c, info, requestBody)
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
if info.IsStream {
err, usage = zhipuStreamHandler(c, resp)
} else {
err, usage = zhipuHandler(c, resp)
}
return
}
func (a *Adaptor) GetModelList() []string {
return ModelList
}
func (a *Adaptor) GetChannelName() string {
return ChannelName
}
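Tracing GetRequestURL above against the bigmodel.cn endpoints documented later in this commit, the two modes resolve as follows; a minimal sketch (the struct literal is illustrative, the URLs follow from the format string):

a := &Adaptor{}
info := &relaycommon.RelayInfo{BaseUrl: "https://open.bigmodel.cn", UpstreamModelName: "chatglm_std"}
u, _ := a.GetRequestURL(info) // https://open.bigmodel.cn/api/paas/v3/model-api/chatglm_std/invoke
info.IsStream = true
u, _ = a.GetRequestURL(info) // .../api/paas/v3/model-api/chatglm_std/sse-invoke
_ = u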

View File

@ -0,0 +1,7 @@
package zhipu
var ModelList = []string{
"chatglm_turbo", "chatglm_pro", "chatglm_std", "chatglm_lite",
}
var ChannelName = "zhipu"

View File

@ -0,0 +1,46 @@
package zhipu
import (
"one-api/dto"
"time"
)
type ZhipuMessage struct {
Role string `json:"role"`
Content string `json:"content"`
}
type ZhipuRequest struct {
Prompt []ZhipuMessage `json:"prompt"`
Temperature float64 `json:"temperature,omitempty"`
TopP float64 `json:"top_p,omitempty"`
RequestId string `json:"request_id,omitempty"`
Incremental bool `json:"incremental,omitempty"`
}
type ZhipuResponseData struct {
TaskId string `json:"task_id"`
RequestId string `json:"request_id"`
TaskStatus string `json:"task_status"`
Choices []ZhipuMessage `json:"choices"`
dto.Usage `json:"usage"`
}
type ZhipuResponse struct {
Code int `json:"code"`
Msg string `json:"msg"`
Success bool `json:"success"`
Data ZhipuResponseData `json:"data"`
}
type ZhipuStreamMetaResponse struct {
RequestId string `json:"request_id"`
TaskId string `json:"task_id"`
TaskStatus string `json:"task_status"`
dto.Usage `json:"usage"`
}
type zhipuTokenData struct {
Token string
ExpiryTime time.Time
}
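Given the JSON tags above, a ZhipuRequest serializes as sketched below; values are illustrative, and RequestId drops out because of omitempty:

req := ZhipuRequest{
	Prompt:      []ZhipuMessage{{Role: "user", Content: "hello"}},
	Temperature: 0.7,
	TopP:        0.9,
	Incremental: true, // pairs with the sse-invoke endpoint
}
b, _ := json.Marshal(req)
// {"prompt":[{"role":"user","content":"hello"}],"temperature":0.7,"top_p":0.9,"incremental":true}
_ = b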

View File

@ -1,4 +1,4 @@
package controller package zhipu
import ( import (
"bufio" "bufio"
@ -8,6 +8,9 @@ import (
"io" "io"
"net/http" "net/http"
"one-api/common" "one-api/common"
"one-api/dto"
relaycommon "one-api/relay/common"
"one-api/service"
"strings" "strings"
"sync" "sync"
"time" "time"
@ -18,46 +21,6 @@ import (
// https://open.bigmodel.cn/api/paas/v3/model-api/chatglm_std/invoke // https://open.bigmodel.cn/api/paas/v3/model-api/chatglm_std/invoke
// https://open.bigmodel.cn/api/paas/v3/model-api/chatglm_std/sse-invoke // https://open.bigmodel.cn/api/paas/v3/model-api/chatglm_std/sse-invoke
type ZhipuMessage struct {
Role string `json:"role"`
Content string `json:"content"`
}
type ZhipuRequest struct {
Prompt []ZhipuMessage `json:"prompt"`
Temperature float64 `json:"temperature,omitempty"`
TopP float64 `json:"top_p,omitempty"`
RequestId string `json:"request_id,omitempty"`
Incremental bool `json:"incremental,omitempty"`
}
type ZhipuResponseData struct {
TaskId string `json:"task_id"`
RequestId string `json:"request_id"`
TaskStatus string `json:"task_status"`
Choices []ZhipuMessage `json:"choices"`
Usage `json:"usage"`
}
type ZhipuResponse struct {
Code int `json:"code"`
Msg string `json:"msg"`
Success bool `json:"success"`
Data ZhipuResponseData `json:"data"`
}
type ZhipuStreamMetaResponse struct {
RequestId string `json:"request_id"`
TaskId string `json:"task_id"`
TaskStatus string `json:"task_status"`
Usage `json:"usage"`
}
type zhipuTokenData struct {
Token string
ExpiryTime time.Time
}
var zhipuTokens sync.Map var zhipuTokens sync.Map
var expSeconds int64 = 24 * 3600 var expSeconds int64 = 24 * 3600
@ -108,7 +71,7 @@ func getZhipuToken(apikey string) string {
return tokenString return tokenString
} }
func requestOpenAI2Zhipu(request GeneralOpenAIRequest) *ZhipuRequest { func requestOpenAI2Zhipu(request dto.GeneralOpenAIRequest) *ZhipuRequest {
messages := make([]ZhipuMessage, 0, len(request.Messages)) messages := make([]ZhipuMessage, 0, len(request.Messages))
for _, message := range request.Messages { for _, message := range request.Messages {
if message.Role == "system" { if message.Role == "system" {
@ -135,19 +98,19 @@ func requestOpenAI2Zhipu(request GeneralOpenAIRequest) *ZhipuRequest {
} }
} }
func responseZhipu2OpenAI(response *ZhipuResponse) *OpenAITextResponse { func responseZhipu2OpenAI(response *ZhipuResponse) *dto.OpenAITextResponse {
fullTextResponse := OpenAITextResponse{ fullTextResponse := dto.OpenAITextResponse{
Id: response.Data.TaskId, Id: response.Data.TaskId,
Object: "chat.completion", Object: "chat.completion",
Created: common.GetTimestamp(), Created: common.GetTimestamp(),
Choices: make([]OpenAITextResponseChoice, 0, len(response.Data.Choices)), Choices: make([]dto.OpenAITextResponseChoice, 0, len(response.Data.Choices)),
Usage: response.Data.Usage, Usage: response.Data.Usage,
} }
for i, choice := range response.Data.Choices { for i, choice := range response.Data.Choices {
content, _ := json.Marshal(strings.Trim(choice.Content, "\"")) content, _ := json.Marshal(strings.Trim(choice.Content, "\""))
openaiChoice := OpenAITextResponseChoice{ openaiChoice := dto.OpenAITextResponseChoice{
Index: i, Index: i,
Message: Message{ Message: dto.Message{
Role: choice.Role, Role: choice.Role,
Content: content, Content: content,
}, },
@ -161,34 +124,34 @@ func responseZhipu2OpenAI(response *ZhipuResponse) *OpenAITextResponse {
return &fullTextResponse return &fullTextResponse
} }
func streamResponseZhipu2OpenAI(zhipuResponse string) *ChatCompletionsStreamResponse { func streamResponseZhipu2OpenAI(zhipuResponse string) *dto.ChatCompletionsStreamResponse {
var choice ChatCompletionsStreamResponseChoice var choice dto.ChatCompletionsStreamResponseChoice
choice.Delta.Content = zhipuResponse choice.Delta.Content = zhipuResponse
response := ChatCompletionsStreamResponse{ response := dto.ChatCompletionsStreamResponse{
Object: "chat.completion.chunk", Object: "chat.completion.chunk",
Created: common.GetTimestamp(), Created: common.GetTimestamp(),
Model: "chatglm", Model: "chatglm",
Choices: []ChatCompletionsStreamResponseChoice{choice}, Choices: []dto.ChatCompletionsStreamResponseChoice{choice},
} }
return &response return &response
} }
func streamMetaResponseZhipu2OpenAI(zhipuResponse *ZhipuStreamMetaResponse) (*ChatCompletionsStreamResponse, *Usage) { func streamMetaResponseZhipu2OpenAI(zhipuResponse *ZhipuStreamMetaResponse) (*dto.ChatCompletionsStreamResponse, *dto.Usage) {
var choice ChatCompletionsStreamResponseChoice var choice dto.ChatCompletionsStreamResponseChoice
choice.Delta.Content = "" choice.Delta.Content = ""
choice.FinishReason = &stopFinishReason choice.FinishReason = &relaycommon.StopFinishReason
response := ChatCompletionsStreamResponse{ response := dto.ChatCompletionsStreamResponse{
Id: zhipuResponse.RequestId, Id: zhipuResponse.RequestId,
Object: "chat.completion.chunk", Object: "chat.completion.chunk",
Created: common.GetTimestamp(), Created: common.GetTimestamp(),
Model: "chatglm", Model: "chatglm",
Choices: []ChatCompletionsStreamResponseChoice{choice}, Choices: []dto.ChatCompletionsStreamResponseChoice{choice},
} }
return &response, &zhipuResponse.Usage return &response, &zhipuResponse.Usage
} }
func zhipuStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) { func zhipuStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
var usage *Usage var usage *dto.Usage
scanner := bufio.NewScanner(resp.Body) scanner := bufio.NewScanner(resp.Body)
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
if atEOF && len(data) == 0 { if atEOF && len(data) == 0 {
@ -225,7 +188,7 @@ func zhipuStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithSt
} }
stopChan <- true stopChan <- true
}() }()
setEventStreamHeaders(c) service.SetEventStreamHeaders(c)
c.Stream(func(w io.Writer) bool { c.Stream(func(w io.Writer) bool {
select { select {
case data := <-dataChan: case data := <-dataChan:
@ -260,28 +223,28 @@ func zhipuStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithSt
}) })
err := resp.Body.Close() err := resp.Body.Close()
if err != nil { if err != nil {
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
} }
return nil, usage return nil, usage
} }
func zhipuHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) { func zhipuHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
var zhipuResponse ZhipuResponse var zhipuResponse ZhipuResponse
responseBody, err := io.ReadAll(resp.Body) responseBody, err := io.ReadAll(resp.Body)
if err != nil { if err != nil {
return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
} }
err = resp.Body.Close() err = resp.Body.Close()
if err != nil { if err != nil {
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
} }
err = json.Unmarshal(responseBody, &zhipuResponse) err = json.Unmarshal(responseBody, &zhipuResponse)
if err != nil { if err != nil {
return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
} }
if !zhipuResponse.Success { if !zhipuResponse.Success {
return &OpenAIErrorWithStatusCode{ return &dto.OpenAIErrorWithStatusCode{
OpenAIError: OpenAIError{ OpenAIError: dto.OpenAIError{
Message: zhipuResponse.Msg, Message: zhipuResponse.Msg,
Type: "zhipu_error", Type: "zhipu_error",
Param: "", Param: "",
@ -293,7 +256,7 @@ func zhipuHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCo
fullTextResponse := responseZhipu2OpenAI(&zhipuResponse) fullTextResponse := responseZhipu2OpenAI(&zhipuResponse)
jsonResponse, err := json.Marshal(fullTextResponse) jsonResponse, err := json.Marshal(fullTextResponse)
if err != nil { if err != nil {
return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
} }
c.Writer.Header().Set("Content-Type", "application/json") c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode) c.Writer.WriteHeader(resp.StatusCode)

View File

@ -0,0 +1,71 @@
package common
import (
"github.com/gin-gonic/gin"
"one-api/common"
"one-api/relay/constant"
"strings"
"time"
)
type RelayInfo struct {
ChannelType int
ChannelId int
TokenId int
UserId int
Group string
TokenUnlimited bool
StartTime time.Time
ApiType int
IsStream bool
RelayMode int
UpstreamModelName string
RequestURLPath string
ApiVersion string
PromptTokens int
ApiKey string
BaseUrl string
}
func GenRelayInfo(c *gin.Context) *RelayInfo {
channelType := c.GetInt("channel")
channelId := c.GetInt("channel_id")
tokenId := c.GetInt("token_id")
userId := c.GetInt("id")
group := c.GetString("group")
tokenUnlimited := c.GetBool("token_unlimited_quota")
startTime := time.Now()
apiType := constant.ChannelType2APIType(channelType)
info := &RelayInfo{
RelayMode: constant.Path2RelayMode(c.Request.URL.Path),
BaseUrl: c.GetString("base_url"),
RequestURLPath: c.Request.URL.String(),
ChannelType: channelType,
ChannelId: channelId,
TokenId: tokenId,
UserId: userId,
Group: group,
TokenUnlimited: tokenUnlimited,
StartTime: startTime,
ApiType: apiType,
ApiVersion: c.GetString("api_version"),
ApiKey: strings.TrimPrefix(c.Request.Header.Get("Authorization"), "Bearer "),
}
if info.BaseUrl == "" {
info.BaseUrl = common.ChannelBaseURLs[channelType]
}
//if info.ChannelType == common.ChannelTypeAzure {
// info.ApiVersion = GetAzureAPIVersion(c)
//}
return info
}
func (info *RelayInfo) SetPromptTokens(promptTokens int) {
info.PromptTokens = promptTokens
}
func (info *RelayInfo) SetIsStream(isStream bool) {
info.IsStream = isStream
}
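GenRelayInfo only reads values that earlier middleware (TokenAuth and Distribute in this codebase) has already stored on the gin context; a hedged sketch of that contract, with illustrative values for the keys read above:

func primeContextForRelay(c *gin.Context) {
	c.Set("channel", common.ChannelTypeOpenAI) // channel type
	c.Set("channel_id", 42)
	c.Set("token_id", 7)
	c.Set("id", 1001) // user id
	c.Set("group", "default")
	c.Set("token_unlimited_quota", false)
	c.Set("base_url", "") // empty: falls back to common.ChannelBaseURLs[channelType]
}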

View File

@ -0,0 +1,68 @@
package common
import (
"encoding/json"
"fmt"
"github.com/gin-gonic/gin"
_ "image/gif"
_ "image/jpeg"
_ "image/png"
"io"
"net/http"
"one-api/common"
"one-api/dto"
"strconv"
"strings"
)
var StopFinishReason = "stop"
func RelayErrorHandler(resp *http.Response) (openAIErrorWithStatusCode *dto.OpenAIErrorWithStatusCode) {
openAIErrorWithStatusCode = &dto.OpenAIErrorWithStatusCode{
StatusCode: resp.StatusCode,
OpenAIError: dto.OpenAIError{
Message: fmt.Sprintf("bad response status code %d", resp.StatusCode),
Type: "upstream_error",
Code: "bad_response_status_code",
Param: strconv.Itoa(resp.StatusCode),
},
}
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return
}
err = resp.Body.Close()
if err != nil {
return
}
var textResponse dto.TextResponse
err = json.Unmarshal(responseBody, &textResponse)
if err != nil {
return
}
openAIErrorWithStatusCode.OpenAIError = textResponse.Error
return
}
func GetFullRequestURL(baseURL string, requestURL string, channelType int) string {
fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL)
if strings.HasPrefix(baseURL, "https://gateway.ai.cloudflare.com") {
switch channelType {
case common.ChannelTypeOpenAI:
fullRequestURL = fmt.Sprintf("%s%s", baseURL, strings.TrimPrefix(requestURL, "/v1"))
case common.ChannelTypeAzure:
fullRequestURL = fmt.Sprintf("%s%s", baseURL, strings.TrimPrefix(requestURL, "/openai/deployments"))
}
}
return fullRequestURL
}
func GetAPIVersion(c *gin.Context) string {
query := c.Request.URL.Query()
apiVersion := query.Get("api-version")
if apiVersion == "" {
apiVersion = c.GetString("api_version")
}
return apiVersion
}
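Concretely, GetFullRequestURL passes the path through untouched unless the base URL is a Cloudflare AI Gateway, whose provider segment already carries the prefix being trimmed; for example (the account and gateway placeholders are illustrative):

GetFullRequestURL("https://api.openai.com", "/v1/chat/completions", common.ChannelTypeOpenAI)
// => "https://api.openai.com/v1/chat/completions"
GetFullRequestURL("https://gateway.ai.cloudflare.com/v1/<account>/<gateway>/openai", "/v1/chat/completions", common.ChannelTypeOpenAI)
// => "https://gateway.ai.cloudflare.com/v1/<account>/<gateway>/openai/chat/completions" ("/v1" trimmed)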

View File

@ -0,0 +1,45 @@
package constant
import (
"one-api/common"
)
const (
APITypeOpenAI = iota
APITypeAnthropic
APITypePaLM
APITypeBaidu
APITypeZhipu
APITypeAli
APITypeXunfei
APITypeAIProxyLibrary
APITypeTencent
APITypeGemini
APITypeDummy // this one is only for count, do not add any channel after this
)
func ChannelType2APIType(channelType int) int {
apiType := APITypeOpenAI
switch channelType {
case common.ChannelTypeAnthropic:
apiType = APITypeAnthropic
case common.ChannelTypeBaidu:
apiType = APITypeBaidu
case common.ChannelTypePaLM:
apiType = APITypePaLM
case common.ChannelTypeZhipu:
apiType = APITypeZhipu
case common.ChannelTypeAli:
apiType = APITypeAli
case common.ChannelTypeXunfei:
apiType = APITypeXunfei
case common.ChannelTypeAIProxyLibrary:
apiType = APITypeAIProxyLibrary
case common.ChannelTypeTencent:
apiType = APITypeTencent
case common.ChannelTypeGemini:
apiType = APITypeGemini
}
return apiType
}
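Any channel type without an explicit case above (Azure, custom OpenAI-compatible channels, and so on) deliberately falls through to APITypeOpenAI and is served by the OpenAI adaptor:

ChannelType2APIType(common.ChannelTypeZhipu) // APITypeZhipu
ChannelType2APIType(common.ChannelTypeAzure) // APITypeOpenAI (default; no dedicated case)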

View File

@ -0,0 +1,50 @@
package constant
import "strings"
const (
RelayModeUnknown = iota
RelayModeChatCompletions
RelayModeCompletions
RelayModeEmbeddings
RelayModeModerations
RelayModeImagesGenerations
RelayModeEdits
RelayModeMidjourneyImagine
RelayModeMidjourneyDescribe
RelayModeMidjourneyBlend
RelayModeMidjourneyChange
RelayModeMidjourneySimpleChange
RelayModeMidjourneyNotify
RelayModeMidjourneyTaskFetch
RelayModeMidjourneyTaskFetchByCondition
RelayModeAudioSpeech
RelayModeAudioTranscription
RelayModeAudioTranslation
)
func Path2RelayMode(path string) int {
relayMode := RelayModeUnknown
if strings.HasPrefix(path, "/v1/chat/completions") {
relayMode = RelayModeChatCompletions
} else if strings.HasPrefix(path, "/v1/completions") {
relayMode = RelayModeCompletions
} else if strings.HasPrefix(path, "/v1/embeddings") {
relayMode = RelayModeEmbeddings
} else if strings.HasSuffix(path, "embeddings") {
relayMode = RelayModeEmbeddings
} else if strings.HasPrefix(path, "/v1/moderations") {
relayMode = RelayModeModerations
} else if strings.HasPrefix(path, "/v1/images/generations") {
relayMode = RelayModeImagesGenerations
} else if strings.HasPrefix(path, "/v1/edits") {
relayMode = RelayModeEdits
} else if strings.HasPrefix(path, "/v1/audio/speech") {
relayMode = RelayModeAudioSpeech
} else if strings.HasPrefix(path, "/v1/audio/transcriptions") {
relayMode = RelayModeAudioTranscription
} else if strings.HasPrefix(path, "/v1/audio/translations") {
relayMode = RelayModeAudioTranslation
}
return relayMode
}
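Matching is by path prefix, plus one suffix rule so that engine-scoped embedding paths are still recognized; for example:

Path2RelayMode("/v1/chat/completions")       // RelayModeChatCompletions
Path2RelayMode("/v1/engines/ada/embeddings") // RelayModeEmbeddings (suffix match)
Path2RelayMode("/v2/unknown")                // RelayModeUnknown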

View File

@ -1,4 +1,4 @@
package controller package relay
import ( import (
"bytes" "bytes"
@ -10,7 +10,10 @@ import (
"io" "io"
"net/http" "net/http"
"one-api/common" "one-api/common"
"one-api/controller"
"one-api/dto"
"one-api/model" "one-api/model"
"one-api/service"
"strings" "strings"
"time" "time"
) )
@ -24,7 +27,7 @@ var availableVoices = []string{
"shimmer", "shimmer",
} }
func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { func RelayAudioHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode {
tokenId := c.GetInt("token_id") tokenId := c.GetInt("token_id")
channelType := c.GetInt("channel") channelType := c.GetInt("channel")
channelId := c.GetInt("channel_id") channelId := c.GetInt("channel_id")
@ -36,7 +39,7 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
if !strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") { if !strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") {
err := common.UnmarshalBodyReusable(c, &audioRequest) err := common.UnmarshalBodyReusable(c, &audioRequest)
if err != nil { if err != nil {
return errorWrapper(err, "bind_request_body_failed", http.StatusBadRequest) return service.OpenAIErrorWrapper(err, "bind_request_body_failed", http.StatusBadRequest)
} }
} else { } else {
audioRequest = AudioRequest{ audioRequest = AudioRequest{
@ -47,15 +50,15 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
// request validation // request validation
if audioRequest.Model == "" { if audioRequest.Model == "" {
return errorWrapper(errors.New("model is required"), "required_field_missing", http.StatusBadRequest) return service.OpenAIErrorWrapper(errors.New("model is required"), "required_field_missing", http.StatusBadRequest)
} }
if strings.HasPrefix(audioRequest.Model, "tts-1") { if strings.HasPrefix(audioRequest.Model, "tts-1") {
if audioRequest.Voice == "" { if audioRequest.Voice == "" {
return errorWrapper(errors.New("voice is required"), "required_field_missing", http.StatusBadRequest) return service.OpenAIErrorWrapper(errors.New("voice is required"), "required_field_missing", http.StatusBadRequest)
} }
if !common.StringsContains(availableVoices, audioRequest.Voice) { if !common.StringsContains(availableVoices, audioRequest.Voice) {
return errorWrapper(errors.New("voice must be one of "+strings.Join(availableVoices, ", ")), "invalid_field_value", http.StatusBadRequest) return service.OpenAIErrorWrapper(errors.New("voice must be one of "+strings.Join(availableVoices, ", ")), "invalid_field_value", http.StatusBadRequest)
} }
} }
@ -66,14 +69,14 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
preConsumedQuota := int(float64(preConsumedTokens) * ratio) preConsumedQuota := int(float64(preConsumedTokens) * ratio)
userQuota, err := model.CacheGetUserQuota(userId) userQuota, err := model.CacheGetUserQuota(userId)
if err != nil { if err != nil {
return errorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError) return service.OpenAIErrorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError)
} }
if userQuota-preConsumedQuota < 0 { if userQuota-preConsumedQuota < 0 {
return errorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden) return service.OpenAIErrorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden)
} }
err = model.CacheDecreaseUserQuota(userId, preConsumedQuota) err = model.CacheDecreaseUserQuota(userId, preConsumedQuota)
if err != nil { if err != nil {
return errorWrapper(err, "decrease_user_quota_failed", http.StatusInternalServerError) return service.OpenAIErrorWrapper(err, "decrease_user_quota_failed", http.StatusInternalServerError)
} }
if userQuota > 100*preConsumedQuota { if userQuota > 100*preConsumedQuota {
// in this case, we do not pre-consume quota // in this case, we do not pre-consume quota
@ -83,7 +86,7 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
if preConsumedQuota > 0 { if preConsumedQuota > 0 {
userQuota, err = model.PreConsumeTokenQuota(tokenId, preConsumedQuota) userQuota, err = model.PreConsumeTokenQuota(tokenId, preConsumedQuota)
if err != nil { if err != nil {
return errorWrapper(err, "pre_consume_token_quota_failed", http.StatusForbidden) return service.OpenAIErrorWrapper(err, "pre_consume_token_quota_failed", http.StatusForbidden)
} }
} }
@ -93,7 +96,7 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
modelMap := make(map[string]string) modelMap := make(map[string]string)
err := json.Unmarshal([]byte(modelMapping), &modelMap) err := json.Unmarshal([]byte(modelMapping), &modelMap)
if err != nil { if err != nil {
return errorWrapper(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError) return service.OpenAIErrorWrapper(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError)
} }
if modelMap[audioRequest.Model] != "" { if modelMap[audioRequest.Model] != "" {
audioRequest.Model = modelMap[audioRequest.Model] audioRequest.Model = modelMap[audioRequest.Model]
@ -106,10 +109,10 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
baseURL = c.GetString("base_url") baseURL = c.GetString("base_url")
} }
fullRequestURL := getFullRequestURL(baseURL, requestURL, channelType) fullRequestURL := relaycommon.GetFullRequestURL(baseURL, requestURL, channelType)
if relayMode == RelayModeAudioTranscription && channelType == common.ChannelTypeAzure { if relayMode == RelayModeAudioTranscription && channelType == common.ChannelTypeAzure {
// https://learn.microsoft.com/en-us/azure/ai-services/openai/whisper-quickstart?tabs=command-line#rest-api // https://learn.microsoft.com/en-us/azure/ai-services/openai/whisper-quickstart?tabs=command-line#rest-api
apiVersion := GetAPIVersion(c) apiVersion := relaycommon.GetAPIVersion(c)
fullRequestURL = fmt.Sprintf("%s/openai/deployments/%s/audio/transcriptions?api-version=%s", baseURL, audioRequest.Model, apiVersion) fullRequestURL = fmt.Sprintf("%s/openai/deployments/%s/audio/transcriptions?api-version=%s", baseURL, audioRequest.Model, apiVersion)
} }
@ -117,7 +120,7 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody) req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
if err != nil { if err != nil {
return errorWrapper(err, "new_request_failed", http.StatusInternalServerError) return service.OpenAIErrorWrapper(err, "new_request_failed", http.StatusInternalServerError)
} }
if relayMode == RelayModeAudioTranscription && channelType == common.ChannelTypeAzure { if relayMode == RelayModeAudioTranscription && channelType == common.ChannelTypeAzure {
@ -133,25 +136,25 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type")) req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
req.Header.Set("Accept", c.Request.Header.Get("Accept")) req.Header.Set("Accept", c.Request.Header.Get("Accept"))
resp, err := httpClient.Do(req) resp, err := service.GetHttpClient().Do(req)
if err != nil { if err != nil {
return errorWrapper(err, "do_request_failed", http.StatusInternalServerError) return service.OpenAIErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)
} }
err = req.Body.Close() err = req.Body.Close()
if err != nil { if err != nil {
return errorWrapper(err, "close_request_body_failed", http.StatusInternalServerError) return service.OpenAIErrorWrapper(err, "close_request_body_failed", http.StatusInternalServerError)
} }
err = c.Request.Body.Close() err = c.Request.Body.Close()
if err != nil { if err != nil {
return errorWrapper(err, "close_request_body_failed", http.StatusInternalServerError) return service.OpenAIErrorWrapper(err, "close_request_body_failed", http.StatusInternalServerError)
} }
if resp.StatusCode != http.StatusOK { if resp.StatusCode != http.StatusOK {
return relayErrorHandler(resp) return relaycommon.RelayErrorHandler(resp)
} }
var audioResponse AudioResponse var audioResponse dto.AudioResponse
defer func(ctx context.Context) { defer func(ctx context.Context) {
go func() { go func() {
@ -159,10 +162,10 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
quota := 0 quota := 0
var promptTokens = 0 var promptTokens = 0
if strings.HasPrefix(audioRequest.Model, "tts-1") { if strings.HasPrefix(audioRequest.Model, "tts-1") {
quota = countAudioToken(audioRequest.Input, audioRequest.Model) quota = service.CountAudioToken(audioRequest.Input, audioRequest.Model)
promptTokens = quota promptTokens = quota
} else { } else {
quota = countAudioToken(audioResponse.Text, audioRequest.Model) quota = service.CountAudioToken(audioResponse.Text, audioRequest.Model)
} }
quota = int(float64(quota) * ratio) quota = int(float64(quota) * ratio)
if ratio != 0 && quota <= 0 { if ratio != 0 && quota <= 0 {
@ -191,18 +194,18 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
responseBody, err := io.ReadAll(resp.Body) responseBody, err := io.ReadAll(resp.Body)
if err != nil { if err != nil {
return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError) return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
} }
err = resp.Body.Close() err = resp.Body.Close()
if err != nil { if err != nil {
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError) return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError)
} }
if strings.HasPrefix(audioRequest.Model, "tts-1") { if strings.HasPrefix(audioRequest.Model, "tts-1") {
} else { } else {
err = json.Unmarshal(responseBody, &audioResponse) err = json.Unmarshal(responseBody, &audioResponse)
if err != nil { if err != nil {
return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError) return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError)
} }
} }
@ -215,11 +218,11 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
_, err = io.Copy(c.Writer, resp.Body) _, err = io.Copy(c.Writer, resp.Body)
if err != nil { if err != nil {
return errorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError) return service.OpenAIErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError)
} }
err = resp.Body.Close() err = resp.Body.Close()
if err != nil { if err != nil {
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError) return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError)
} }
return nil return nil
} }

View File

@ -1,4 +1,4 @@
package controller package relay
import ( import (
"bytes" "bytes"
@ -10,12 +10,15 @@ import (
"io" "io"
"net/http" "net/http"
"one-api/common" "one-api/common"
"one-api/controller"
"one-api/dto"
"one-api/model" "one-api/model"
"one-api/relay/common"
"strings" "strings"
"time" "time"
) )
func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { func RelayImageHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode {
tokenId := c.GetInt("token_id") tokenId := c.GetInt("token_id")
channelType := c.GetInt("channel") channelType := c.GetInt("channel")
channelId := c.GetInt("channel_id") channelId := c.GetInt("channel_id")
@ -24,7 +27,7 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
group := c.GetString("group") group := c.GetString("group")
startTime := time.Now() startTime := time.Now()
var imageRequest ImageRequest var imageRequest dto.ImageRequest
if consumeQuota { if consumeQuota {
err := common.UnmarshalBodyReusable(c, &imageRequest) err := common.UnmarshalBodyReusable(c, &imageRequest)
if err != nil { if err != nil {
@ -90,7 +93,7 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
fullRequestURL := getFullRequestURL(baseURL, requestURL, channelType) fullRequestURL := relaycommon.GetFullRequestURL(baseURL, requestURL, channelType)
if channelType == common.ChannelTypeAzure && relayMode == RelayModeImagesGenerations { if channelType == common.ChannelTypeAzure && relayMode == RelayModeImagesGenerations {
// https://learn.microsoft.com/en-us/azure/ai-services/openai/dall-e-quickstart?tabs=dalle3%2Ccommand-line&pivots=rest-api // https://learn.microsoft.com/en-us/azure/ai-services/openai/dall-e-quickstart?tabs=dalle3%2Ccommand-line&pivots=rest-api
apiVersion := GetAPIVersion(c) apiVersion := relaycommon.GetAPIVersion(c)
// https://{resource_name}.openai.azure.com/openai/deployments/dall-e-3/images/generations?api-version=2023-06-01-preview // https://{resource_name}.openai.azure.com/openai/deployments/dall-e-3/images/generations?api-version=2023-06-01-preview
fullRequestURL = fmt.Sprintf("%s/openai/deployments/%s/images/generations?api-version=%s", baseURL, imageRequest.Model, apiVersion) fullRequestURL = fmt.Sprintf("%s/openai/deployments/%s/images/generations?api-version=%s", baseURL, imageRequest.Model, apiVersion)
} }
@ -151,7 +154,7 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type")) req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
req.Header.Set("Accept", c.Request.Header.Get("Accept")) req.Header.Set("Accept", c.Request.Header.Get("Accept"))
resp, err := httpClient.Do(req) resp, err := service.GetHttpClient().Do(req)
if err != nil { if err != nil {
return errorWrapper(err, "do_request_failed", http.StatusInternalServerError) return service.OpenAIErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)
} }

View File

@ -1,4 +1,4 @@
package controller package relay
import ( import (
"bytes" "bytes"
@ -9,6 +9,7 @@ import (
"log" "log"
"net/http" "net/http"
"one-api/common" "one-api/common"
"one-api/controller"
"one-api/model" "one-api/model"
"strconv" "strconv"
"strings" "strings"
@ -104,7 +105,7 @@ func RelayMidjourneyImage(c *gin.Context) {
return return
} }
func relayMidjourneyNotify(c *gin.Context) *MidjourneyResponse { func RelayMidjourneyNotify(c *gin.Context) *MidjourneyResponse {
var midjRequest Midjourney var midjRequest Midjourney
err := common.UnmarshalBodyReusable(c, &midjRequest) err := common.UnmarshalBodyReusable(c, &midjRequest)
if err != nil { if err != nil {
@ -167,7 +168,7 @@ func getMidjourneyTaskModel(c *gin.Context, originTask *model.Midjourney) (midjo
return return
} }
func relayMidjourneyTask(c *gin.Context, relayMode int) *MidjourneyResponse { func RelayMidjourneyTask(c *gin.Context, relayMode int) *MidjourneyResponse {
userId := c.GetInt("id") userId := c.GetInt("id")
var err error var err error
var respBody []byte var respBody []byte
@ -244,7 +245,7 @@ const (
MJSubmitActionUpscale = "UPSCALE" // 放大 MJSubmitActionUpscale = "UPSCALE" // 放大
) )
func relayMidjourneySubmit(c *gin.Context, relayMode int) *MidjourneyResponse { func RelayMidjourneySubmit(c *gin.Context, relayMode int) *MidjourneyResponse {
imageModel := "midjourney" imageModel := "midjourney"
tokenId := c.GetInt("token_id") tokenId := c.GetInt("token_id")
@ -427,21 +428,21 @@ func relayMidjourneySubmit(c *gin.Context, relayMode int) *MidjourneyResponse {
Description: "create_request_failed", Description: "create_request_failed",
} }
} }
//req.Header.Set("Authorization", c.Request.Header.Get("Authorization")) //req.Header.Set("ApiKey", c.Request.Header.Get("ApiKey"))
req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type")) req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
req.Header.Set("Accept", c.Request.Header.Get("Accept")) req.Header.Set("Accept", c.Request.Header.Get("Accept"))
//mjToken := "" //mjToken := ""
//if c.Request.Header.Get("Authorization") != "" { //if c.Request.Header.Get("ApiKey") != "" {
// mjToken = strings.Split(c.Request.Header.Get("Authorization"), " ")[1] // mjToken = strings.Split(c.Request.Header.Get("ApiKey"), " ")[1]
//} //}
//req.Header.Set("Authorization", "Bearer midjourney-proxy") //req.Header.Set("ApiKey", "Bearer midjourney-proxy")
req.Header.Set("mj-api-secret", strings.Split(c.Request.Header.Get("Authorization"), " ")[1]) req.Header.Set("mj-api-secret", strings.Split(c.Request.Header.Get("Authorization"), " ")[1])
// print request header // print request header
log.Printf("request header: %s", req.Header) log.Printf("request header: %s", req.Header)
log.Printf("request body: %s", midjRequest.Prompt) log.Printf("request body: %s", midjRequest.Prompt)
resp, err := httpClient.Do(req) resp, err := service.GetHttpClient().Do(req)
if err != nil { if err != nil {
return &MidjourneyResponse{ return &MidjourneyResponse{
Code: 4, Code: 4,

277
relay/relay-text.go Normal file
View File

@ -0,0 +1,277 @@
package relay
import (
"bytes"
"encoding/json"
"errors"
"fmt"
"io"
"math"
"net/http"
"one-api/common"
"one-api/dto"
"one-api/model"
relaychannel "one-api/relay/channel"
relaycommon "one-api/relay/common"
relayconstant "one-api/relay/constant"
"one-api/service"
"strings"
"time"
"github.com/gin-gonic/gin"
)
func getAndValidateTextRequest(c *gin.Context, relayInfo *relaycommon.RelayInfo) (*dto.GeneralOpenAIRequest, error) {
textRequest := &dto.GeneralOpenAIRequest{}
err := common.UnmarshalBodyReusable(c, textRequest)
if err != nil {
return nil, err
}
if relayInfo.RelayMode == relayconstant.RelayModeModerations && textRequest.Model == "" {
textRequest.Model = "text-moderation-latest"
}
if relayInfo.RelayMode == relayconstant.RelayModeEmbeddings && textRequest.Model == "" {
textRequest.Model = c.Param("model")
}
if textRequest.MaxTokens < 0 || textRequest.MaxTokens > math.MaxInt32/2 {
return nil, errors.New("max_tokens is invalid")
}
if textRequest.Model == "" {
return nil, errors.New("model is required")
}
switch relayInfo.RelayMode {
case relayconstant.RelayModeCompletions:
if textRequest.Prompt == "" {
return nil, errors.New("field prompt is required")
}
case relayconstant.RelayModeChatCompletions:
if textRequest.Messages == nil || len(textRequest.Messages) == 0 {
return nil, errors.New("field messages is required")
}
case relayconstant.RelayModeEmbeddings:
case relayconstant.RelayModeModerations:
if textRequest.Input == "" {
return nil, errors.New("field input is required")
}
case relayconstant.RelayModeEdits:
if textRequest.Instruction == "" {
return nil, errors.New("field instruction is required")
}
}
relayInfo.IsStream = textRequest.Stream
return textRequest, nil
}
func TextHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
relayInfo := relaycommon.GenRelayInfo(c)
// get & validate textRequest 获取并验证文本请求
textRequest, err := getAndValidateTextRequest(c, relayInfo)
if err != nil {
common.LogError(c, fmt.Sprintf("getAndValidateTextRequest failed: %s", err.Error()))
return service.OpenAIErrorWrapper(err, "invalid_text_request", http.StatusBadRequest)
}
// map model name
modelMapping := c.GetString("model_mapping")
isModelMapped := false
if modelMapping != "" && modelMapping != "{}" {
modelMap := make(map[string]string)
err := json.Unmarshal([]byte(modelMapping), &modelMap)
if err != nil {
return service.OpenAIErrorWrapper(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError)
}
if modelMap[textRequest.Model] != "" {
textRequest.Model = modelMap[textRequest.Model]
isModelMapped = true
}
}
modelPrice := common.GetModelPrice(textRequest.Model, false)
groupRatio := common.GetGroupRatio(relayInfo.Group)
var preConsumedQuota int
var ratio float64
var modelRatio float64
promptTokens, err := getPromptTokens(textRequest, relayInfo)
// counting prompt tokens failed
if err != nil {
return service.OpenAIErrorWrapper(err, "count_token_messages_failed", http.StatusInternalServerError)
}
if modelPrice == -1 {
preConsumedTokens := common.PreConsumedQuota
if textRequest.MaxTokens != 0 {
preConsumedTokens = promptTokens + int(textRequest.MaxTokens)
}
modelRatio = common.GetModelRatio(textRequest.Model)
ratio = modelRatio * groupRatio
preConsumedQuota = int(float64(preConsumedTokens) * ratio)
} else {
preConsumedQuota = int(modelPrice * common.QuotaPerUnit * groupRatio)
}
// pre-consume quota
userQuota, openaiErr := preConsumeQuota(c, preConsumedQuota, relayInfo)
if openaiErr != nil {
return openaiErr
}
adaptor := relaychannel.GetAdaptor(relayInfo.ApiType)
if adaptor == nil {
return service.OpenAIErrorWrapper(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), "invalid_api_type", http.StatusBadRequest)
}
adaptor.Init(relayInfo, *textRequest)
var requestBody io.Reader
if relayInfo.ApiType == relayconstant.APITypeOpenAI {
if isModelMapped {
jsonStr, err := json.Marshal(textRequest)
if err != nil {
return service.OpenAIErrorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
}
requestBody = bytes.NewBuffer(jsonStr)
} else {
requestBody = c.Request.Body
}
} else {
convertedRequest, err := adaptor.ConvertRequest(c, relayInfo.RelayMode, textRequest)
if err != nil {
return service.OpenAIErrorWrapper(err, "convert_request_failed", http.StatusInternalServerError)
}
jsonData, err := json.Marshal(convertedRequest)
if err != nil {
return service.OpenAIErrorWrapper(err, "json_marshal_failed", http.StatusInternalServerError)
}
requestBody = bytes.NewBuffer(jsonData)
}
resp, err := adaptor.DoRequest(c, relayInfo, requestBody)
if err != nil {
return service.OpenAIErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)
}
relayInfo.IsStream = relayInfo.IsStream || strings.HasPrefix(resp.Header.Get("Content-Type"), "text/event-stream")
usage, openaiErr := adaptor.DoResponse(c, resp, relayInfo)
if openaiErr != nil {
return openaiErr
}
postConsumeQuota(c, relayInfo, *textRequest, usage, ratio, preConsumedQuota, userQuota, modelRatio, groupRatio, modelPrice)
return nil
}
func getPromptTokens(textRequest *dto.GeneralOpenAIRequest, info *relaycommon.RelayInfo) (int, error) {
var promptTokens int
var err error
switch info.RelayMode {
case relayconstant.RelayModeChatCompletions:
promptTokens, err = service.CountTokenMessages(textRequest.Messages, textRequest.Model)
case relayconstant.RelayModeCompletions:
promptTokens, err = service.CountTokenInput(textRequest.Prompt, textRequest.Model), nil
case relayconstant.RelayModeModerations:
promptTokens, err = service.CountTokenInput(textRequest.Input, textRequest.Model), nil
default:
err = errors.New("unknown relay mode")
promptTokens = 0
}
info.PromptTokens = promptTokens
return promptTokens, err
}
// Pre-consume quota and return the user's remaining quota.
func preConsumeQuota(c *gin.Context, preConsumedQuota int, relayInfo *relaycommon.RelayInfo) (int, *dto.OpenAIErrorWithStatusCode) {
userQuota, err := model.CacheGetUserQuota(relayInfo.UserId)
if err != nil {
return 0, service.OpenAIErrorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError)
}
if userQuota < 0 || userQuota-preConsumedQuota < 0 {
return 0, service.OpenAIErrorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden)
}
err = model.CacheDecreaseUserQuota(relayInfo.UserId, preConsumedQuota)
if err != nil {
return 0, service.OpenAIErrorWrapper(err, "decrease_user_quota_failed", http.StatusInternalServerError)
}
if userQuota > 100*preConsumedQuota {
// the user has enough quota; check whether the token quota is also sufficient
if !relayInfo.TokenUnlimited {
// not an unlimited token, so check the token's own quota
tokenQuota := c.GetInt("token_quota")
if tokenQuota > 100*preConsumedQuota {
// token quota is sufficient too: trust the token and skip pre-consumption
preConsumedQuota = 0
common.LogInfo(c.Request.Context(), fmt.Sprintf("user %d quota %d and token %d quota %d are enough, trusted and no need to pre-consume", relayInfo.UserId, userQuota, relayInfo.TokenId, tokenQuota))
}
} else {
// in this case, we do not pre-consume quota
// because the user has enough quota
preConsumedQuota = 0
common.LogInfo(c.Request.Context(), fmt.Sprintf("user %d with unlimited token has enough quota %d, trusted and no need to pre-consume", relayInfo.UserId, userQuota))
}
}
if preConsumedQuota > 0 {
userQuota, err = model.PreConsumeTokenQuota(relayInfo.TokenId, preConsumedQuota)
if err != nil {
return 0, service.OpenAIErrorWrapper(err, "pre_consume_token_quota_failed", http.StatusForbidden)
}
}
return userQuota, nil
}
func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, textRequest dto.GeneralOpenAIRequest, usage *dto.Usage, ratio float64, preConsumedQuota int, userQuota int, modelRatio float64, groupRatio float64, modelPrice float64) {
useTimeSeconds := time.Now().Unix() - relayInfo.StartTime.Unix()
promptTokens := usage.PromptTokens
completionTokens := usage.CompletionTokens
tokenName := ctx.GetString("token_name")
quota := 0
if modelPrice == -1 {
completionRatio := common.GetCompletionRatio(textRequest.Model)
quota = promptTokens + int(float64(completionTokens)*completionRatio)
quota = int(float64(quota) * ratio)
if ratio != 0 && quota <= 0 {
quota = 1
}
} else {
quota = int(modelPrice * common.QuotaPerUnit * groupRatio)
}
totalTokens := promptTokens + completionTokens
var logContent string
if modelPrice == -1 {
logContent = fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
} else {
logContent = fmt.Sprintf("模型价格 %.2f,分组倍率 %.2f", modelPrice, groupRatio)
}
// record all the consume log even if quota is 0
if totalTokens == 0 {
// in this case, must be some error happened
// we cannot just return, because we may have to return the pre-consumed quota
quota = 0
logContent += "(可能是上游超时)"
common.LogError(ctx, fmt.Sprintf("total tokens is 0, cannot consume quota, userId %d, channelId %d, tokenId %d, model %s pre-consumed quota %d", relayInfo.UserId, relayInfo.ChannelId, relayInfo.TokenId, textRequest.Model, preConsumedQuota))
} else {
quotaDelta := quota - preConsumedQuota
err := model.PostConsumeTokenQuota(relayInfo.TokenId, userQuota, quotaDelta, preConsumedQuota, true)
if err != nil {
common.LogError(ctx, "error consuming token remain quota: "+err.Error())
}
err = model.CacheUpdateUserQuota(relayInfo.UserId)
if err != nil {
common.LogError(ctx, "error update user quota cache: "+err.Error())
}
model.UpdateUserUsedQuotaAndRequestCount(relayInfo.UserId, quota)
model.UpdateChannelUsedQuota(relayInfo.ChannelId, quota)
}
logModel := textRequest.Model
if strings.HasPrefix(logModel, "gpt-4-gizmo") {
logModel = "gpt-4-gizmo-*"
logContent += fmt.Sprintf(",模型 %s", textRequest.Model)
}
model.RecordConsumeLog(ctx, relayInfo.UserId, relayInfo.ChannelId, promptTokens, completionTokens, logModel, tokenName, quota, logContent, relayInfo.TokenId, userQuota, int(useTimeSeconds), relayInfo.IsStream)
//if quota != 0 {
//
//}
}
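To make the two billing paths in postConsumeQuota concrete, a worked example; the token counts and ratios here are illustrative assumptions, not shipped defaults:

// Per-token path (modelPrice == -1):
//   promptTokens = 100, completionTokens = 200
//   completionRatio = 2.0, ratio = modelRatio*groupRatio = 15.0*1.0
//   quota = int(float64(100 + int(200*2.0)) * 15.0) = 7500
// Per-call path (modelPrice != -1):
//   quota = int(modelPrice * common.QuotaPerUnit * groupRatio), independent of token counts
// Either way, the pre-consumed amount is settled afterwards:
//   quotaDelta := quota - preConsumedQuota // a negative delta refunds the overcharge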

View File

@ -1,10 +1,10 @@
package router package router
import ( import (
"github.com/gin-gonic/gin"
"one-api/controller" "one-api/controller"
"one-api/middleware" "one-api/middleware"
"one-api/relay"
"github.com/gin-gonic/gin"
) )
func SetRelayRouter(router *gin.Engine) { func SetRelayRouter(router *gin.Engine) {
@ -44,7 +44,7 @@ func SetRelayRouter(router *gin.Engine) {
relayV1Router.POST("/moderations", controller.Relay) relayV1Router.POST("/moderations", controller.Relay)
} }
relayMjRouter := router.Group("/mj") relayMjRouter := router.Group("/mj")
relayMjRouter.GET("/image/:id", controller.RelayMidjourneyImage) relayMjRouter.GET("/image/:id", relay.RelayMidjourneyImage)
relayMjRouter.Use(middleware.TokenAuth(), middleware.Distribute()) relayMjRouter.Use(middleware.TokenAuth(), middleware.Distribute())
{ {
relayMjRouter.POST("/submit/imagine", controller.RelayMidjourney) relayMjRouter.POST("/submit/imagine", controller.RelayMidjourney)

53
service/channel.go Normal file
View File

@ -0,0 +1,53 @@
package service
import (
"fmt"
"net/http"
"one-api/common"
relaymodel "one-api/dto"
"one-api/model"
)
// disable & notify
func DisableChannel(channelId int, channelName string, reason string) {
model.UpdateChannelStatusById(channelId, common.ChannelStatusAutoDisabled)
subject := fmt.Sprintf("通道「%s」#%d已被禁用", channelName, channelId)
content := fmt.Sprintf("通道「%s」#%d已被禁用原因%s", channelName, channelId, reason)
notifyRootUser(subject, content)
}
func EnableChannel(channelId int, channelName string) {
model.UpdateChannelStatusById(channelId, common.ChannelStatusEnabled)
subject := fmt.Sprintf("通道「%s」#%d已被启用", channelName, channelId)
content := fmt.Sprintf("通道「%s」#%d已被启用", channelName, channelId)
notifyRootUser(subject, content)
}
func ShouldDisableChannel(err *relaymodel.OpenAIError, statusCode int) bool {
if !common.AutomaticDisableChannelEnabled {
return false
}
if err == nil {
return false
}
if statusCode == http.StatusUnauthorized {
return true
}
if err.Type == "insufficient_quota" || err.Code == "invalid_api_key" || err.Code == "account_deactivated" || err.Code == "billing_not_active" {
return true
}
return false
}
func ShouldEnableChannel(err error, openAIErr *relaymodel.OpenAIError) bool {
if !common.AutomaticEnableChannelEnabled {
return false
}
if err != nil {
return false
}
if openAIErr != nil {
return false
}
return true
}
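A hedged sketch of a typical call site for these helpers after a failed relay attempt; the surrounding variables are illustrative, not part of this file:

if openaiErr != nil && ShouldDisableChannel(&openaiErr.OpenAIError, openaiErr.StatusCode) {
	DisableChannel(channelId, channelName, openaiErr.OpenAIError.Message)
}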

29
service/error.go Normal file
View File

@ -0,0 +1,29 @@
package service
import (
"fmt"
"one-api/common"
"one-api/dto"
"strings"
)
// OpenAIErrorWrapper wraps an error into an OpenAIErrorWithStatusCode
func OpenAIErrorWrapper(err error, code string, statusCode int) *dto.OpenAIErrorWithStatusCode {
text := err.Error()
// mask upstream connection errors so internal upstream addresses are not leaked
if strings.Contains(text, "Post") {
common.SysLog(fmt.Sprintf("error: %s", text))
text = "请求上游地址失败"
}
// avoid exposing internal error details
openAIError := dto.OpenAIError{
Message: text,
Type: "new_api_error",
Code: code,
}
return &dto.OpenAIErrorWithStatusCode{
OpenAIError: openAIError,
StatusCode: statusCode,
}
}

32
service/http_client.go Normal file
View File

@ -0,0 +1,32 @@
package service
import (
"net/http"
"one-api/common"
"time"
)
var httpClient *http.Client
var impatientHTTPClient *http.Client
func init() {
if common.RelayTimeout == 0 {
httpClient = &http.Client{}
} else {
httpClient = &http.Client{
Timeout: time.Duration(common.RelayTimeout) * time.Second,
}
}
impatientHTTPClient = &http.Client{
Timeout: 5 * time.Second,
}
}
func GetHttpClient() *http.Client {
return httpClient
}
func GetImpatientHttpClient() *http.Client {
return impatientHTTPClient
}

11
service/sse.go Normal file
View File

@ -0,0 +1,11 @@
package service
import "github.com/gin-gonic/gin"
func SetEventStreamHeaders(c *gin.Context) {
c.Writer.Header().Set("Content-Type", "text/event-stream")
c.Writer.Header().Set("Cache-Control", "no-cache")
c.Writer.Header().Set("Connection", "keep-alive")
c.Writer.Header().Set("Transfer-Encoding", "chunked")
c.Writer.Header().Set("X-Accel-Buffering", "no")
}

View File

@ -1,27 +1,19 @@
package controller package service
import ( import (
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
"github.com/gin-gonic/gin"
"github.com/pkoukk/tiktoken-go" "github.com/pkoukk/tiktoken-go"
"image" "image"
_ "image/gif"
_ "image/jpeg"
_ "image/png"
"io"
"log" "log"
"math" "math"
"net/http"
"one-api/common" "one-api/common"
"strconv" "one-api/dto"
"strings" "strings"
"unicode/utf8" "unicode/utf8"
) )
var stopFinishReason = "stop"
// tokenEncoderMap won't grow after initialization // tokenEncoderMap won't grow after initialization
var tokenEncoderMap = map[string]*tiktoken.Tiktoken{} var tokenEncoderMap = map[string]*tiktoken.Tiktoken{}
var defaultTokenEncoder *tiktoken.Tiktoken var defaultTokenEncoder *tiktoken.Tiktoken
@ -70,7 +62,7 @@ func getTokenNum(tokenEncoder *tiktoken.Tiktoken, text string) int {
return len(tokenEncoder.Encode(text, nil, nil)) return len(tokenEncoder.Encode(text, nil, nil))
} }
func getImageToken(imageUrl *MessageImageUrl) (int, error) { func getImageToken(imageUrl *dto.MessageImageUrl) (int, error) {
if imageUrl.Detail == "low" { if imageUrl.Detail == "low" {
return 85, nil return 85, nil
} }
@ -124,7 +116,7 @@ func getImageToken(imageUrl *MessageImageUrl) (int, error) {
return tiles*170 + 85, nil return tiles*170 + 85, nil
} }
func countTokenMessages(messages []Message, model string) (int, error) { func CountTokenMessages(messages []dto.Message, model string) (int, error) {
//recover when panic //recover when panic
tokenEncoder := getTokenEncoder(model) tokenEncoder := getTokenEncoder(model)
// Reference: // Reference:
@ -146,7 +138,7 @@ func countTokenMessages(messages []Message, model string) (int, error) {
tokenNum += tokensPerMessage tokenNum += tokensPerMessage
tokenNum += getTokenNum(tokenEncoder, message.Role) tokenNum += getTokenNum(tokenEncoder, message.Role)
if len(message.Content) > 0 { if len(message.Content) > 0 {
var arrayContent []MediaMessage var arrayContent []dto.MediaMessage
if err := json.Unmarshal(message.Content, &arrayContent); err != nil { if err := json.Unmarshal(message.Content, &arrayContent); err != nil {
var stringContent string var stringContent string
if err := json.Unmarshal(message.Content, &stringContent); err != nil { if err := json.Unmarshal(message.Content, &stringContent); err != nil {
@ -163,7 +155,7 @@ func countTokenMessages(messages []Message, model string) (int, error) {
if m.Type == "image_url" { if m.Type == "image_url" {
var imageTokenNum int var imageTokenNum int
if str, ok := m.ImageUrl.(string); ok { if str, ok := m.ImageUrl.(string); ok {
imageTokenNum, err = getImageToken(&MessageImageUrl{Url: str, Detail: "auto"}) imageTokenNum, err = getImageToken(&dto.MessageImageUrl{Url: str, Detail: "auto"})
} else { } else {
imageUrlMap := m.ImageUrl.(map[string]interface{}) imageUrlMap := m.ImageUrl.(map[string]interface{})
detail, ok := imageUrlMap["detail"] detail, ok := imageUrlMap["detail"]
@ -172,7 +164,7 @@ func countTokenMessages(messages []Message, model string) (int, error) {
} else { } else {
imageUrlMap["detail"] = "auto" imageUrlMap["detail"] = "auto"
} }
imageUrl := MessageImageUrl{ imageUrl := dto.MessageImageUrl{
Url: imageUrlMap["url"].(string), Url: imageUrlMap["url"].(string),
Detail: imageUrlMap["detail"].(string), Detail: imageUrlMap["detail"].(string),
} }
@ -195,16 +187,16 @@ func countTokenMessages(messages []Message, model string) (int, error) {
return tokenNum, nil return tokenNum, nil
} }
func countTokenInput(input any, model string) int { func CountTokenInput(input any, model string) int {
switch v := input.(type) { switch v := input.(type) {
case string: case string:
return countTokenText(v, model) return CountTokenText(v, model)
case []string: case []string:
text := "" text := ""
for _, s := range v { for _, s := range v {
text += s text += s
} }
return countTokenText(text, model) return CountTokenText(text, model)
} }
return 0 return 0
} }
@ -213,118 +205,11 @@ func countAudioToken(text string, model string) int {
if strings.HasPrefix(model, "tts") { if strings.HasPrefix(model, "tts") {
return utf8.RuneCountInString(text) return utf8.RuneCountInString(text)
} else { } else {
return countTokenText(text, model) return CountTokenText(text, model)
} }
} }
func countTokenText(text string, model string) int { func CountTokenText(text string, model string) int {
tokenEncoder := getTokenEncoder(model) tokenEncoder := getTokenEncoder(model)
return getTokenNum(tokenEncoder, text) return getTokenNum(tokenEncoder, text)
} }
func errorWrapper(err error, code string, statusCode int) *OpenAIErrorWithStatusCode {
text := err.Error()
// 定义一个正则表达式匹配URL
if strings.Contains(text, "Post") {
common.SysLog(fmt.Sprintf("error: %s", text))
text = "请求上游地址失败"
}
//避免暴露内部错误
openAIError := OpenAIError{
Message: text,
Type: "new_api_error",
Code: code,
}
return &OpenAIErrorWithStatusCode{
OpenAIError: openAIError,
StatusCode: statusCode,
}
}
func shouldDisableChannel(err *OpenAIError, statusCode int) bool {
if !common.AutomaticDisableChannelEnabled {
return false
}
if err == nil {
return false
}
if statusCode == http.StatusUnauthorized {
return true
}
if err.Type == "insufficient_quota" || err.Code == "invalid_api_key" || err.Code == "account_deactivated" || err.Code == "billing_not_active" {
return true
}
return false
}
func shouldEnableChannel(err error, openAIErr *OpenAIError) bool {
if !common.AutomaticEnableChannelEnabled {
return false
}
if err != nil {
return false
}
if openAIErr != nil {
return false
}
return true
}
func setEventStreamHeaders(c *gin.Context) {
c.Writer.Header().Set("Content-Type", "text/event-stream")
c.Writer.Header().Set("Cache-Control", "no-cache")
c.Writer.Header().Set("Connection", "keep-alive")
c.Writer.Header().Set("Transfer-Encoding", "chunked")
c.Writer.Header().Set("X-Accel-Buffering", "no")
}
func relayErrorHandler(resp *http.Response) (openAIErrorWithStatusCode *OpenAIErrorWithStatusCode) {
openAIErrorWithStatusCode = &OpenAIErrorWithStatusCode{
StatusCode: resp.StatusCode,
OpenAIError: OpenAIError{
Message: fmt.Sprintf("bad response status code %d", resp.StatusCode),
Type: "upstream_error",
Code: "bad_response_status_code",
Param: strconv.Itoa(resp.StatusCode),
},
}
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return
}
err = resp.Body.Close()
if err != nil {
return
}
var textResponse TextResponse
err = json.Unmarshal(responseBody, &textResponse)
if err != nil {
return
}
openAIErrorWithStatusCode.OpenAIError = textResponse.Error
return
}
func getFullRequestURL(baseURL string, requestURL string, channelType int) string {
fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL)
if strings.HasPrefix(baseURL, "https://gateway.ai.cloudflare.com") {
switch channelType {
case common.ChannelTypeOpenAI:
fullRequestURL = fmt.Sprintf("%s%s", baseURL, strings.TrimPrefix(requestURL, "/v1"))
case common.ChannelTypeAzure:
fullRequestURL = fmt.Sprintf("%s%s", baseURL, strings.TrimPrefix(requestURL, "/openai/deployments"))
}
}
return fullRequestURL
}
func GetAPIVersion(c *gin.Context) string {
query := c.Request.URL.Query()
apiVersion := query.Get("api-version")
if apiVersion == "" {
apiVersion = c.GetString("api_version")
}
return apiVersion
}
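The tile arithmetic feeding tiles*170 + 85 in getImageToken is elided from this hunk; assuming OpenAI's published vision rule (fit within 2048x2048, scale the short side to 768, count 512-px tiles), a worked example:

// 1024x1024 image at detail "high":
//   short side scaled to 768 -> effective 768x768
//   tiles = ceil(768/512) * ceil(768/512) = 2 * 2 = 4
//   tokens = 4*170 + 85 = 765
// Any image at detail "low" short-circuits to a flat 85 tokens.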

27
service/usage_helpr.go Normal file
View File

@ -0,0 +1,27 @@
package service
import (
"errors"
"one-api/dto"
"one-api/relay/constant"
)
func GetPromptTokens(textRequest dto.GeneralOpenAIRequest, relayMode int) (int, error) {
switch relayMode {
case constant.RelayModeChatCompletions:
return CountTokenMessages(textRequest.Messages, textRequest.Model)
case constant.RelayModeCompletions:
return CountTokenInput(textRequest.Prompt, textRequest.Model), nil
case constant.RelayModeModerations:
return CountTokenInput(textRequest.Input, textRequest.Model), nil
}
return 0, errors.New("unknown relay mode")
}
func ResponseText2Usage(responseText string, modelName string, promptTokens int) *dto.Usage {
usage := &dto.Usage{}
usage.PromptTokens = promptTokens
usage.CompletionTokens = CountTokenText(responseText, modelName)
usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
return usage
}
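ResponseText2Usage backfills a Usage for channels whose streams never report token counts; a typical call after collecting a streamed reply (variable names are illustrative):

usage := ResponseText2Usage(streamedText, info.UpstreamModelName, info.PromptTokens)
// usage.TotalTokens == info.PromptTokens + CountTokenText(streamedText, info.UpstreamModelName)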

17
service/user_notify.go Normal file
View File

@ -0,0 +1,17 @@
package service
import (
"fmt"
"one-api/common"
"one-api/model"
)
func notifyRootUser(subject string, content string) {
if common.RootUserEmail == "" {
common.RootUserEmail = model.GetRootUserEmail()
}
err := common.SendEmail(subject, common.RootUserEmail, content)
if err != nil {
common.SysError(fmt.Sprintf("failed to send email: %s", err.Error()))
}
}