Mirror of https://github.com/songquanpeng/one-api.git, synced 2025-11-12 03:13:41 +08:00
Refactor codebase, introduce relaymode package, update constants and improve consistency
- Refactor constant definitions and organization
- Clean up package-level variables and functions
- Introduce new `relaymode` and `apitype` packages for constant definitions
- Refactor and simplify code in several packages, including `openai`, `relay/channel/baidu`, `relay/util`, `relay/controller`, and `relay/channeltype`
- Add helper functions in the `relay/channeltype` package to convert channel type constants to the corresponding API type constants
- Remove deprecated functions such as `ResponseText2Usage` from `relay/channel/openai/helper.go`
- Modify `relay/util/validation.go` and related files to use the new `validator.ValidateTextRequest` function
- Rename the `util` package to `relaymode` and update related imports in several packages
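Note: the new `relaymode` and `apitype` packages named above, and the channel-type-to-API-type helpers added to `relay/channeltype`, are not shown in the excerpt below. As a rough sketch of the pattern the commit message describes (the function and constant names here are assumptions inferred from the switch cases in relay/adaptor.go, not the committed code), such a helper might look like:

package channeltype

import "github.com/songquanpeng/one-api/relay/apitype"

// ToAPIType is a hypothetical sketch of the conversion helper the commit
// message describes: it maps a channel type constant to the API type
// constant whose adaptor should handle the request.
func ToAPIType(channelType int) int {
	apiType := apitype.OpenAI // most channels speak the OpenAI-compatible API
	switch channelType {
	case Anthropic:
		apiType = apitype.Anthropic
	case PaLM:
		apiType = apitype.PaLM
	case Gemini:
		apiType = apitype.Gemini
	case Ollama:
		apiType = apitype.Ollama
	case AIProxyLibrary:
		apiType = apitype.AIProxyLibrary
	}
	return apiType
}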
relay/adaptor.go (new file, 41 lines)
@@ -0,0 +1,41 @@
package relay

import (
	"github.com/songquanpeng/one-api/relay/adaptor"
	"github.com/songquanpeng/one-api/relay/adaptor/aiproxy"
	"github.com/songquanpeng/one-api/relay/adaptor/anthropic"
	"github.com/songquanpeng/one-api/relay/adaptor/gemini"
	"github.com/songquanpeng/one-api/relay/adaptor/ollama"
	"github.com/songquanpeng/one-api/relay/adaptor/openai"
	"github.com/songquanpeng/one-api/relay/adaptor/palm"
	"github.com/songquanpeng/one-api/relay/apitype"
)

func GetAdaptor(apiType int) adaptor.Adaptor {
	switch apiType {
	case apitype.AIProxyLibrary:
		return &aiproxy.Adaptor{}
	// case apitype.Ali:
	//	return &ali.Adaptor{}
	case apitype.Anthropic:
		return &anthropic.Adaptor{}
	// case apitype.Baidu:
	//	return &baidu.Adaptor{}
	case apitype.Gemini:
		return &gemini.Adaptor{}
	case apitype.OpenAI:
		return &openai.Adaptor{}
	case apitype.PaLM:
		return &palm.Adaptor{}
	// case apitype.Tencent:
	//	return &tencent.Adaptor{}
	// case apitype.Xunfei:
	//	return &xunfei.Adaptor{}
	// case apitype.Zhipu:
	//	return &zhipu.Adaptor{}
	case apitype.Ollama:
		return &ollama.Adaptor{}
	}

	return nil
}
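For orientation, a call site would resolve the adaptor from the API type and treat a nil result as an unsupported type. A minimal sketch, assuming the hypothetical channeltype.ToAPIType helper sketched above (getAdaptorForChannel is illustrative, not part of this commit):

// Illustrative only: resolve and initialize an adaptor for a request's channel.
func getAdaptorForChannel(m *meta.Meta) (adaptor.Adaptor, error) {
	apiType := channeltype.ToAPIType(m.ChannelType) // assumed helper, see sketch above
	a := GetAdaptor(apiType)
	if a == nil {
		return nil, fmt.Errorf("invalid api type: %d", apiType)
	}
	a.Init(m) // adaptors may cache per-request metadata here
	return a, nil
}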
@@ -4,10 +4,10 @@ import (
 	"fmt"
 	"github.com/Laisky/errors/v2"
 	"github.com/gin-gonic/gin"
-	"github.com/songquanpeng/one-api/common"
-	"github.com/songquanpeng/one-api/relay/channel"
+	"github.com/songquanpeng/one-api/common/config"
+	"github.com/songquanpeng/one-api/relay/adaptor"
+	"github.com/songquanpeng/one-api/relay/meta"
 	"github.com/songquanpeng/one-api/relay/model"
-	"github.com/songquanpeng/one-api/relay/util"
 	"io"
 	"net/http"
 )
@@ -15,16 +15,16 @@ import (
 type Adaptor struct {
 }
 
-func (a *Adaptor) Init(meta *util.RelayMeta) {
+func (a *Adaptor) Init(meta *meta.Meta) {
 
 }
 
-func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) {
+func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) {
 	return fmt.Sprintf("%s/api/library/ask", meta.BaseURL), nil
 }
 
-func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *util.RelayMeta) error {
-	channel.SetupCommonRequestHeader(c, req, meta)
+func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error {
+	adaptor.SetupCommonRequestHeader(c, req, meta)
 	req.Header.Set("Authorization", "Bearer "+meta.APIKey)
 	return nil
 }
@@ -34,15 +34,22 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G
 		return nil, errors.New("request is nil")
 	}
 	aiProxyLibraryRequest := ConvertRequest(*request)
-	aiProxyLibraryRequest.LibraryId = c.GetString(common.ConfigKeyLibraryID)
+	aiProxyLibraryRequest.LibraryId = c.GetString(config.KeyLibraryID)
 	return aiProxyLibraryRequest, nil
 }
 
-func (a *Adaptor) DoRequest(c *gin.Context, meta *util.RelayMeta, requestBody io.Reader) (*http.Response, error) {
-	return channel.DoRequestHelper(a, c, meta, requestBody)
+func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) {
+	if request == nil {
+		return nil, errors.New("request is nil")
+	}
+	return request, nil
+}
+
+func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) {
+	return adaptor.DoRequestHelper(a, c, meta, requestBody)
 }
 
-func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.RelayMeta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
+func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
 	if meta.IsStream {
 		err, usage = StreamHandler(c, resp)
 	} else {
@@ -1,6 +1,6 @@
 package aiproxy
 
-import "github.com/songquanpeng/one-api/relay/channel/openai"
+import "github.com/songquanpeng/one-api/relay/adaptor/openai"
 
 var ModelList = []string{""}
 
@@ -13,7 +13,8 @@ import (
 	"github.com/songquanpeng/one-api/common"
 	"github.com/songquanpeng/one-api/common/helper"
 	"github.com/songquanpeng/one-api/common/logger"
-	"github.com/songquanpeng/one-api/relay/channel/openai"
+	"github.com/songquanpeng/one-api/common/random"
+	"github.com/songquanpeng/one-api/relay/adaptor/openai"
 	"github.com/songquanpeng/one-api/relay/constant"
 	"github.com/songquanpeng/one-api/relay/model"
 )
@@ -54,7 +55,7 @@ func responseAIProxyLibrary2OpenAI(response *LibraryResponse) *openai.TextRespon
 		FinishReason: "stop",
 	}
 	fullTextResponse := openai.TextResponse{
-		Id:      fmt.Sprintf("chatcmpl-%s", helper.GetUUID()),
+		Id:      fmt.Sprintf("chatcmpl-%s", random.GetUUID()),
 		Object:  "chat.completion",
 		Created: helper.GetTimestamp(),
 		Choices: []openai.TextResponseChoice{choice},
@@ -67,7 +68,7 @@ func documentsAIProxyLibrary(documents []LibraryDocument) *openai.ChatCompletion
 	choice.Delta.Content = aiProxyDocuments2Markdown(documents)
 	choice.FinishReason = &constant.StopFinishReason
 	return &openai.ChatCompletionsStreamResponse{
-		Id:      fmt.Sprintf("chatcmpl-%s", helper.GetUUID()),
+		Id:      fmt.Sprintf("chatcmpl-%s", random.GetUUID()),
 		Object:  "chat.completion.chunk",
 		Created: helper.GetTimestamp(),
 		Model:   "",
@@ -79,7 +80,7 @@ func streamResponseAIProxyLibrary2OpenAI(response *LibraryStreamResponse) *opena
 	var choice openai.ChatCompletionsStreamResponseChoice
 	choice.Delta.Content = response.Content
 	return &openai.ChatCompletionsStreamResponse{
-		Id:      fmt.Sprintf("chatcmpl-%s", helper.GetUUID()),
+		Id:      fmt.Sprintf("chatcmpl-%s", random.GetUUID()),
 		Object:  "chat.completion.chunk",
 		Created: helper.GetTimestamp(),
 		Model:   response.Model,
relay/adaptor/ali/adaptor.go (new file, 105 lines)
@@ -0,0 +1,105 @@
package ali

// import (
//	"github.com/Laisky/errors/v2"
//	"fmt"
//	"github.com/gin-gonic/gin"
//	"github.com/songquanpeng/one-api/common/config"
//	"github.com/songquanpeng/one-api/relay/adaptor"
//	"github.com/songquanpeng/one-api/relay/meta"
//	"github.com/songquanpeng/one-api/relay/model"
//	"github.com/songquanpeng/one-api/relay/relaymode"
//	"io"
//	"net/http"
// )

// // https://help.aliyun.com/zh/dashscope/developer-reference/api-details

// type Adaptor struct {
// }

// func (a *Adaptor) Init(meta *meta.Meta) {

// }

// func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) {
//	fullRequestURL := ""
//	switch meta.Mode {
//	case relaymode.Embeddings:
//		fullRequestURL = fmt.Sprintf("%s/api/v1/services/embeddings/text-embedding/text-embedding", meta.BaseURL)
//	case relaymode.ImagesGenerations:
//		fullRequestURL = fmt.Sprintf("%s/api/v1/services/aigc/text2image/image-synthesis", meta.BaseURL)
//	default:
//		fullRequestURL = fmt.Sprintf("%s/api/v1/services/aigc/text-generation/generation", meta.BaseURL)
//	}

//	return fullRequestURL, nil
// }

// func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error {
//	adaptor.SetupCommonRequestHeader(c, req, meta)
//	if meta.IsStream {
//		req.Header.Set("Accept", "text/event-stream")
//		req.Header.Set("X-DashScope-SSE", "enable")
//	}
//	req.Header.Set("Authorization", "Bearer "+meta.APIKey)

//	if meta.Mode == relaymode.ImagesGenerations {
//		req.Header.Set("X-DashScope-Async", "enable")
//	}
//	if c.GetString(config.KeyPlugin) != "" {
//		req.Header.Set("X-DashScope-Plugin", c.GetString(config.KeyPlugin))
//	}
//	return nil
// }

// func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) {
//	if request == nil {
//		return nil, errors.New("request is nil")
//	}
//	switch relayMode {
//	case relaymode.Embeddings:
//		aliEmbeddingRequest := ConvertEmbeddingRequest(*request)
//		return aliEmbeddingRequest, nil
//	default:
//		aliRequest := ConvertRequest(*request)
//		return aliRequest, nil
//	}
// }

// func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) {
//	if request == nil {
//		return nil, errors.New("request is nil")
//	}

//	aliRequest := ConvertImageRequest(*request)
//	return aliRequest, nil
// }

// func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) {
//	return adaptor.DoRequestHelper(a, c, meta, requestBody)
// }

// func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
//	if meta.IsStream {
//		err, usage = StreamHandler(c, resp)
//	} else {
//		switch meta.Mode {
//		case relaymode.Embeddings:
//			err, usage = EmbeddingHandler(c, resp)
//		case relaymode.ImagesGenerations:
//			err, usage = ImageHandler(c, resp)
//		default:
//			err, usage = Handler(c, resp)
//		}
//	}
//	return
// }

// func (a *Adaptor) GetModelList() []string {
//	return ModelList
// }

// func (a *Adaptor) GetChannelName() string {
//	return "ali"
// }
@@ -3,4 +3,5 @@ package ali
 var ModelList = []string{
 	"qwen-turbo", "qwen-plus", "qwen-max", "qwen-max-longcontext",
 	"text-embedding-v1",
+	"ali-stable-diffusion-xl", "ali-stable-diffusion-v1.5", "wanx-v1",
 }
relay/adaptor/ali/image.go (new file, 192 lines)
@@ -0,0 +1,192 @@
package ali

import (
	"encoding/base64"
	"encoding/json"
	"fmt"
	"github.com/Laisky/errors/v2"
	"github.com/gin-gonic/gin"
	"github.com/songquanpeng/one-api/common/helper"
	"github.com/songquanpeng/one-api/common/logger"
	"github.com/songquanpeng/one-api/relay/adaptor/openai"
	"github.com/songquanpeng/one-api/relay/model"
	"io"
	"net/http"
	"strings"
	"time"
)

func ImageHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
	apiKey := c.Request.Header.Get("Authorization")
	apiKey = strings.TrimPrefix(apiKey, "Bearer ")
	responseFormat := c.GetString("response_format")

	var aliTaskResponse TaskResponse
	responseBody, err := io.ReadAll(resp.Body)
	if err != nil {
		return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
	}
	err = resp.Body.Close()
	if err != nil {
		return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
	}
	err = json.Unmarshal(responseBody, &aliTaskResponse)
	if err != nil {
		return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
	}

	if aliTaskResponse.Message != "" {
		logger.SysError("aliAsyncTask err: " + string(responseBody))
		return openai.ErrorWrapper(errors.New(aliTaskResponse.Message), "ali_async_task_failed", http.StatusInternalServerError), nil
	}

	aliResponse, _, err := asyncTaskWait(aliTaskResponse.Output.TaskId, apiKey)
	if err != nil {
		return openai.ErrorWrapper(err, "ali_async_task_wait_failed", http.StatusInternalServerError), nil
	}

	if aliResponse.Output.TaskStatus != "SUCCEEDED" {
		return &model.ErrorWithStatusCode{
			Error: model.Error{
				Message: aliResponse.Output.Message,
				Type:    "ali_error",
				Param:   "",
				Code:    aliResponse.Output.Code,
			},
			StatusCode: resp.StatusCode,
		}, nil
	}

	fullTextResponse := responseAli2OpenAIImage(aliResponse, responseFormat)
	jsonResponse, err := json.Marshal(fullTextResponse)
	if err != nil {
		return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
	}
	c.Writer.Header().Set("Content-Type", "application/json")
	c.Writer.WriteHeader(resp.StatusCode)
	_, err = c.Writer.Write(jsonResponse)
	return nil, nil
}

func asyncTask(taskID string, key string) (*TaskResponse, error, []byte) {
	url := fmt.Sprintf("https://dashscope.aliyuncs.com/api/v1/tasks/%s", taskID)

	var aliResponse TaskResponse

	req, err := http.NewRequest("GET", url, nil)
	if err != nil {
		return &aliResponse, err, nil
	}

	req.Header.Set("Authorization", "Bearer "+key)

	client := &http.Client{}
	resp, err := client.Do(req)
	if err != nil {
		logger.SysError("aliAsyncTask client.Do err: " + err.Error())
		return &aliResponse, err, nil
	}
	defer resp.Body.Close()

	responseBody, err := io.ReadAll(resp.Body)

	var response TaskResponse
	err = json.Unmarshal(responseBody, &response)
	if err != nil {
		logger.SysError("aliAsyncTask NewDecoder err: " + err.Error())
		return &aliResponse, err, nil
	}

	return &response, nil, responseBody
}

func asyncTaskWait(taskID string, key string) (*TaskResponse, []byte, error) {
	waitSeconds := 2
	step := 0
	maxStep := 20

	var taskResponse TaskResponse
	var responseBody []byte

	for {
		step++
		rsp, err, body := asyncTask(taskID, key)
		responseBody = body
		if err != nil {
			return &taskResponse, responseBody, err
		}

		if rsp.Output.TaskStatus == "" {
			return &taskResponse, responseBody, nil
		}

		switch rsp.Output.TaskStatus {
		case "FAILED":
			fallthrough
		case "CANCELED":
			fallthrough
		case "SUCCEEDED":
			fallthrough
		case "UNKNOWN":
			return rsp, responseBody, nil
		}
		if step >= maxStep {
			break
		}
		time.Sleep(time.Duration(waitSeconds) * time.Second)
	}

	return nil, nil, fmt.Errorf("aliAsyncTaskWait timeout")
}

func responseAli2OpenAIImage(response *TaskResponse, responseFormat string) *openai.ImageResponse {
	imageResponse := openai.ImageResponse{
		Created: helper.GetTimestamp(),
	}

	for _, data := range response.Output.Results {
		var b64Json string
		if responseFormat == "b64_json" {
			// fetch the image bytes from data.Url and re-encode them into b64Json
			imageData, err := getImageData(data.Url)
			if err != nil {
				// failed to fetch the image data; skip this result
				logger.SysError("getImageData Error getting image data: " + err.Error())
				continue
			}

			// encode the image bytes as a base64 string
			b64Json = Base64Encode(imageData)
		} else {
			// if responseFormat is not "b64_json", use data.B64Image directly
			b64Json = data.B64Image
		}

		imageResponse.Data = append(imageResponse.Data, openai.ImageData{
			Url:           data.Url,
			B64Json:       b64Json,
			RevisedPrompt: "",
		})
	}
	return &imageResponse
}

func getImageData(url string) ([]byte, error) {
	response, err := http.Get(url)
	if err != nil {
		return nil, err
	}
	defer response.Body.Close()

	imageData, err := io.ReadAll(response.Body)
	if err != nil {
		return nil, err
	}

	return imageData, nil
}

func Base64Encode(data []byte) string {
	b64Json := base64.StdEncoding.EncodeToString(data)
	return b64Json
}
relay/adaptor/ali/model.go (new file, 154 lines)
@@ -0,0 +1,154 @@
package ali

import (
	"github.com/songquanpeng/one-api/relay/adaptor/openai"
	"github.com/songquanpeng/one-api/relay/model"
)

type Message struct {
	Content string `json:"content"`
	Role    string `json:"role"`
}

type Input struct {
	//Prompt   string    `json:"prompt"`
	Messages []Message `json:"messages"`
}

type Parameters struct {
	TopP              float64      `json:"top_p,omitempty"`
	TopK              int          `json:"top_k,omitempty"`
	Seed              uint64       `json:"seed,omitempty"`
	EnableSearch      bool         `json:"enable_search,omitempty"`
	IncrementalOutput bool         `json:"incremental_output,omitempty"`
	MaxTokens         int          `json:"max_tokens,omitempty"`
	Temperature       float64      `json:"temperature,omitempty"`
	ResultFormat      string       `json:"result_format,omitempty"`
	Tools             []model.Tool `json:"tools,omitempty"`
}

type ChatRequest struct {
	Model      string     `json:"model"`
	Input      Input      `json:"input"`
	Parameters Parameters `json:"parameters,omitempty"`
}

type ImageRequest struct {
	Model string `json:"model"`
	Input struct {
		Prompt         string `json:"prompt"`
		NegativePrompt string `json:"negative_prompt,omitempty"`
	} `json:"input"`
	Parameters struct {
		Size  string `json:"size,omitempty"`
		N     int    `json:"n,omitempty"`
		Steps string `json:"steps,omitempty"`
		Scale string `json:"scale,omitempty"`
	} `json:"parameters,omitempty"`
	ResponseFormat string `json:"response_format,omitempty"`
}

type TaskResponse struct {
	StatusCode int    `json:"status_code,omitempty"`
	RequestId  string `json:"request_id,omitempty"`
	Code       string `json:"code,omitempty"`
	Message    string `json:"message,omitempty"`
	Output     struct {
		TaskId     string `json:"task_id,omitempty"`
		TaskStatus string `json:"task_status,omitempty"`
		Code       string `json:"code,omitempty"`
		Message    string `json:"message,omitempty"`
		Results    []struct {
			B64Image string `json:"b64_image,omitempty"`
			Url      string `json:"url,omitempty"`
			Code     string `json:"code,omitempty"`
			Message  string `json:"message,omitempty"`
		} `json:"results,omitempty"`
		TaskMetrics struct {
			Total     int `json:"TOTAL,omitempty"`
			Succeeded int `json:"SUCCEEDED,omitempty"`
			Failed    int `json:"FAILED,omitempty"`
		} `json:"task_metrics,omitempty"`
	} `json:"output,omitempty"`
	Usage Usage `json:"usage"`
}

type Header struct {
	Action       string `json:"action,omitempty"`
	Streaming    string `json:"streaming,omitempty"`
	TaskID       string `json:"task_id,omitempty"`
	Event        string `json:"event,omitempty"`
	ErrorCode    string `json:"error_code,omitempty"`
	ErrorMessage string `json:"error_message,omitempty"`
	Attributes   any    `json:"attributes,omitempty"`
}

type Payload struct {
	Model      string `json:"model,omitempty"`
	Task       string `json:"task,omitempty"`
	TaskGroup  string `json:"task_group,omitempty"`
	Function   string `json:"function,omitempty"`
	Parameters struct {
		SampleRate int     `json:"sample_rate,omitempty"`
		Rate       float64 `json:"rate,omitempty"`
		Format     string  `json:"format,omitempty"`
	} `json:"parameters,omitempty"`
	Input struct {
		Text string `json:"text,omitempty"`
	} `json:"input,omitempty"`
	Usage struct {
		Characters int `json:"characters,omitempty"`
	} `json:"usage,omitempty"`
}

type WSSMessage struct {
	Header  Header  `json:"header,omitempty"`
	Payload Payload `json:"payload,omitempty"`
}

type EmbeddingRequest struct {
	Model string `json:"model"`
	Input struct {
		Texts []string `json:"texts"`
	} `json:"input"`
	Parameters *struct {
		TextType string `json:"text_type,omitempty"`
	} `json:"parameters,omitempty"`
}

type Embedding struct {
	Embedding []float64 `json:"embedding"`
	TextIndex int       `json:"text_index"`
}

type EmbeddingResponse struct {
	Output struct {
		Embeddings []Embedding `json:"embeddings"`
	} `json:"output"`
	Usage Usage `json:"usage"`
	Error
}

type Error struct {
	Code      string `json:"code"`
	Message   string `json:"message"`
	RequestId string `json:"request_id"`
}

type Usage struct {
	InputTokens  int `json:"input_tokens"`
	OutputTokens int `json:"output_tokens"`
	TotalTokens  int `json:"total_tokens"`
}

type Output struct {
	//Text         string                      `json:"text"`
	//FinishReason string                      `json:"finish_reason"`
	Choices []openai.TextResponseChoice `json:"choices"`
}

type ChatResponse struct {
	Output Output `json:"output"`
	Usage  Usage  `json:"usage"`
	Error
}
@@ -2,31 +2,31 @@ package anthropic
 
 import (
 	"fmt"
-	"github.com/Laisky/errors/v2"
 	"io"
 	"net/http"
 
+	"github.com/Laisky/errors/v2"
 	"github.com/gin-gonic/gin"
-	"github.com/songquanpeng/one-api/relay/channel"
+	"github.com/songquanpeng/one-api/relay/adaptor"
+	"github.com/songquanpeng/one-api/relay/meta"
 	"github.com/songquanpeng/one-api/relay/model"
-	"github.com/songquanpeng/one-api/relay/util"
 )
 
 type Adaptor struct {
 }
 
-func (a *Adaptor) Init(meta *util.RelayMeta) {
+func (a *Adaptor) Init(meta *meta.Meta) {
 
 }
 
-func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) {
-	// https://docs.anthropic.com/claude/reference/messages_post
-	// anthopic migrate to Message API
+// https://docs.anthropic.com/claude/reference/messages_post
+// anthopic migrate to Message API
+func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) {
 	return fmt.Sprintf("%s/v1/messages", meta.BaseURL), nil
 }
 
-func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *util.RelayMeta) error {
-	channel.SetupCommonRequestHeader(c, req, meta)
+func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error {
+	adaptor.SetupCommonRequestHeader(c, req, meta)
 	req.Header.Set("x-api-key", meta.APIKey)
 	anthropicVersion := c.Request.Header.Get("anthropic-version")
 	if anthropicVersion == "" {
@@ -46,11 +46,18 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G
 	return ConvertRequest(*request), nil
 }
 
-func (a *Adaptor) DoRequest(c *gin.Context, meta *util.RelayMeta, requestBody io.Reader) (*http.Response, error) {
-	return channel.DoRequestHelper(a, c, meta, requestBody)
+func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) {
+	if request == nil {
+		return nil, errors.New("request is nil")
+	}
+	return request, nil
+}
+
+func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) {
+	return adaptor.DoRequestHelper(a, c, meta, requestBody)
 }
 
-func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.RelayMeta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
+func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
 	if meta.IsStream {
 		err, usage = StreamHandler(c, resp)
 	} else {
@@ -13,7 +13,7 @@ import (
 	"github.com/songquanpeng/one-api/common/helper"
 	"github.com/songquanpeng/one-api/common/image"
 	"github.com/songquanpeng/one-api/common/logger"
-	"github.com/songquanpeng/one-api/relay/channel/openai"
+	"github.com/songquanpeng/one-api/relay/adaptor/openai"
 	"github.com/songquanpeng/one-api/relay/model"
 )
 
relay/adaptor/azure/helper.go (new file, 15 lines)
@@ -0,0 +1,15 @@
package azure

import (
	"github.com/gin-gonic/gin"
	"github.com/songquanpeng/one-api/common/config"
)

func GetAPIVersion(c *gin.Context) string {
	query := c.Request.URL.Query()
	apiVersion := query.Get("api-version")
	if apiVersion == "" {
		apiVersion = c.GetString(config.KeyAPIVersion)
	}
	return apiVersion
}
relay/adaptor/baidu/constants.go (new file, 20 lines)
@@ -0,0 +1,20 @@
package baidu

var ModelList = []string{
	"ERNIE-4.0-8K",
	"ERNIE-3.5-8K",
	"ERNIE-3.5-8K-0205",
	"ERNIE-3.5-8K-1222",
	"ERNIE-Bot-8K",
	"ERNIE-3.5-4K-0205",
	"ERNIE-Speed-8K",
	"ERNIE-Speed-128K",
	"ERNIE-Lite-8K-0922",
	"ERNIE-Lite-8K-0308",
	"ERNIE-Tiny-8K",
	"BLOOMZ-7B",
	"Embedding-V1",
	"bge-large-zh",
	"bge-large-en",
	"tao-8k",
}
@@ -1,4 +1,4 @@
-package channel
+package adaptor
 
 import (
 	"io"
@@ -6,10 +6,11 @@ import (
 
 	"github.com/Laisky/errors/v2"
 	"github.com/gin-gonic/gin"
-	"github.com/songquanpeng/one-api/relay/util"
+	"github.com/songquanpeng/one-api/relay/client"
+	"github.com/songquanpeng/one-api/relay/meta"
 )
 
-func SetupCommonRequestHeader(c *gin.Context, req *http.Request, meta *util.RelayMeta) {
+func SetupCommonRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) {
 	req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
 	req.Header.Set("Accept", c.Request.Header.Get("Accept"))
 	if meta.IsStream && c.Request.Header.Get("Accept") == "" {
@@ -17,7 +18,7 @@ func SetupCommonRequestHeader(c *gin.Context, req *http.Request, meta *util.Rela
 	}
 }
 
-func DoRequestHelper(a Adaptor, c *gin.Context, meta *util.RelayMeta, requestBody io.Reader) (*http.Response, error) {
+func DoRequestHelper(a Adaptor, c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) {
 	fullRequestURL, err := a.GetRequestURL(meta)
 	if err != nil {
 		return nil, errors.Wrap(err, "get request url failed")
@@ -43,7 +44,7 @@ func DoRequestHelper(a Adaptor, c *gin.Context, meta *util.RelayMeta, requestBod
 }
 
 func DoRequest(c *gin.Context, req *http.Request) (*http.Response, error) {
-	resp, err := util.HTTPClient.Do(req)
+	resp, err := client.HTTPClient.Do(req)
 	if err != nil {
 		return nil, err
 	}
@@ -8,21 +8,21 @@ import (
 	"github.com/Laisky/errors/v2"
 	"github.com/gin-gonic/gin"
 	"github.com/songquanpeng/one-api/common/helper"
-	channelhelper "github.com/songquanpeng/one-api/relay/channel"
-	"github.com/songquanpeng/one-api/relay/channel/openai"
+	channelhelper "github.com/songquanpeng/one-api/relay/adaptor"
+	"github.com/songquanpeng/one-api/relay/adaptor/openai"
+	"github.com/songquanpeng/one-api/relay/meta"
 	"github.com/songquanpeng/one-api/relay/model"
-	"github.com/songquanpeng/one-api/relay/util"
 )
 
 type Adaptor struct {
 }
 
-func (a *Adaptor) Init(meta *util.RelayMeta) {
+func (a *Adaptor) Init(meta *meta.Meta) {
 
 }
 
-func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) {
-	version := helper.AssignOrDefault(meta.APIVersion, "v1beta")
+func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) {
+	version := helper.AssignOrDefault(meta.APIVersion, "v1")
 	action := "generateContent"
 	if meta.IsStream {
 		action = "streamGenerateContent"
@@ -30,7 +30,7 @@ func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) {
 	return fmt.Sprintf("%s/%s/models/%s:%s?key=%s", meta.BaseURL, version, meta.ActualModelName, action, meta.APIKey), nil
 }
 
-func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *util.RelayMeta) error {
+func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error {
 	channelhelper.SetupCommonRequestHeader(c, req, meta)
 	req.Header.Set("x-goog-api-key", meta.APIKey)
 	req.URL.Query().Add("key", meta.APIKey)
@@ -44,11 +44,18 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G
 	return ConvertRequest(*request), nil
 }
 
-func (a *Adaptor) DoRequest(c *gin.Context, meta *util.RelayMeta, requestBody io.Reader) (*http.Response, error) {
+func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) {
+	if request == nil {
+		return nil, errors.New("request is nil")
+	}
+	return request, nil
+}
+
+func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) {
 	return channelhelper.DoRequestHelper(a, c, meta, requestBody)
 }
 
-func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.RelayMeta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
+func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
 	if meta.IsStream {
 		var responseText string
 		err, responseText = StreamHandler(c, resp)
@@ -1,21 +1,24 @@
 package gemini
 
 import (
-	"context"
+	"bufio"
 	"encoding/json"
 	"fmt"
 	"io"
 	"net/http"
 	"strings"
 
-	"github.com/gin-gonic/gin"
 	"github.com/songquanpeng/one-api/common"
+	"github.com/songquanpeng/one-api/common/config"
 	"github.com/songquanpeng/one-api/common/helper"
 	"github.com/songquanpeng/one-api/common/image"
 	"github.com/songquanpeng/one-api/common/logger"
-	"github.com/songquanpeng/one-api/relay/channel/openai"
+	"github.com/songquanpeng/one-api/common/random"
+	"github.com/songquanpeng/one-api/relay/adaptor/openai"
 	"github.com/songquanpeng/one-api/relay/constant"
 	"github.com/songquanpeng/one-api/relay/model"
+
+	"github.com/gin-gonic/gin"
 )
 
 // https://ai.google.dev/docs/gemini_api_overview?hl=zh-cn
@@ -82,13 +85,7 @@ func ConvertRequest(textRequest model.GeneralOpenAIRequest) *ChatRequest {
 				if imageNum > VisionMaxImageNum {
 					continue
 				}
-				mimeType, data, err := image.GetImageFromUrl(part.ImageURL.Url)
-				if err != nil {
-					logger.Warn(context.TODO(),
-						fmt.Sprintf("get image from url %s got %+v", part.ImageURL.Url, err))
-					continue
-				}
-
+				mimeType, data, _ := image.GetImageFromUrl(part.ImageURL.Url)
 				parts = append(parts, Part{
 					InlineData: &InlineData{
 						MimeType: mimeType,
@@ -97,9 +94,6 @@ func ConvertRequest(textRequest model.GeneralOpenAIRequest) *ChatRequest {
 			})
 		}
 	}
-
-	logger.Info(context.TODO(),
-		fmt.Sprintf("send %d messages to gemini with %d images", len(parts), imageNum))
 	content.Parts = parts
 
 	// there's no assistant role in gemini and API shall vomit if Role is not user or model
@@ -163,7 +157,7 @@ type ChatPromptFeedback struct {
 
 func responseGeminiChat2OpenAI(response *ChatResponse) *openai.TextResponse {
 	fullTextResponse := openai.TextResponse{
-		Id:      fmt.Sprintf("chatcmpl-%s", helper.GetUUID()),
+		Id:      fmt.Sprintf("chatcmpl-%s", random.GetUUID()),
 		Object:  "chat.completion",
 		Created: helper.GetTimestamp(),
 		Choices: make([]openai.TextResponseChoice, 0, len(response.Candidates)),
@@ -196,182 +190,73 @@ func streamResponseGeminiChat2OpenAI(geminiResponse *ChatResponse) *openai.ChatC
 	return &response
 }
 
-// [{
-//	"candidates": [
-//		{
-//			"content": {
-//				"parts": [
-//					{
-//						"text": "```go \n\n// Package ratelimit implements tokens bucket algorithm.\npackage rate"
-//					}
-//				],
-//				"role": "model"
-//			},
-//			"finishReason": "STOP",
-//			"index": 0,
-//			"safetyRatings": [
-//				{
-//					"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
-//					"probability": "NEGLIGIBLE"
-//				},
-//				{
-//					"category": "HARM_CATEGORY_HATE_SPEECH",
-//					"probability": "NEGLIGIBLE"
-//				},
-//				{
-//					"category": "HARM_CATEGORY_HARASSMENT",
-//					"probability": "NEGLIGIBLE"
-//				},
-//				{
-//					"category": "HARM_CATEGORY_DANGEROUS_CONTENT",
-//					"probability": "NEGLIGIBLE"
-//				}
-//			]
-//		}
-//	],
-//	"promptFeedback": {
-//		"safetyRatings": [
-//			{
-//				"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
-//				"probability": "NEGLIGIBLE"
-//			},
-//			{
-//				"category": "HARM_CATEGORY_HATE_SPEECH",
-//				"probability": "NEGLIGIBLE"
-//			},
-//			{
-//				"category": "HARM_CATEGORY_HARASSMENT",
-//				"probability": "NEGLIGIBLE"
-//			},
-//			{
-//				"category": "HARM_CATEGORY_DANGEROUS_CONTENT",
-//				"probability": "NEGLIGIBLE"
-//			}
-//		]
-//	}
-// }]
-type GeminiStreamResp struct {
-	Candidates []struct {
-		Content struct {
-			Parts []struct {
-				Text string `json:"text"`
-			} `json:"parts"`
-			Role string `json:"role"`
-		} `json:"content"`
-		FinishReason string `json:"finishReason"`
-		Index        int64  `json:"index"`
-	} `json:"candidates"`
-}
-
 func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, string) {
 	responseText := ""
-
-	respBody, err := io.ReadAll(resp.Body)
-	if err != nil {
-		return openai.ErrorWrapper(err, "read upstream's body", http.StatusInternalServerError), responseText
-	}
-
-	var respData []GeminiStreamResp
-	if err = json.Unmarshal(respBody, &respData); err != nil {
-		return openai.ErrorWrapper(err, "unmarshal upstream's body", http.StatusInternalServerError), responseText
-	}
-
-	for _, chunk := range respData {
-		for _, cad := range chunk.Candidates {
-			for _, part := range cad.Content.Parts {
-				responseText += part.Text
-			}
-		}
-	}
-
-	var choice openai.ChatCompletionsStreamResponseChoice
-	choice.Delta.Content = responseText
-	resp2cli, err := json.Marshal(&openai.ChatCompletionsStreamResponse{
-		Id:      fmt.Sprintf("chatcmpl-%s", helper.GetUUID()),
-		Object:  "chat.completion.chunk",
-		Created: helper.GetTimestamp(),
-		Model:   "gemini-pro",
-		Choices: []openai.ChatCompletionsStreamResponseChoice{choice},
-	})
-	err := resp.Body.Close()
-	if err != nil {
-		return openai.ErrorWrapper(err, "marshal upstream's body", http.StatusInternalServerError), responseText
-	}
-
-	c.Render(-1, common.CustomEvent{Data: "data: " + string(resp2cli)})
-	c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
-
-	// dataChan := make(chan string)
-	// stopChan := make(chan bool)
-	// scanner := bufio.NewScanner(resp.Body)
-	// scanner.Split(bufio.ScanLines)
-	// // scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
-	// // 	if atEOF && len(data) == 0 {
-	// // 		return 0, nil, nil
-	// // 	}
-	// // 	if i := strings.Index(string(data), "\n"); i >= 0 {
-	// // 		return i + 1, data[0:i], nil
-	// // 	}
-	// // 	if atEOF {
-	// // 		return len(data), data, nil
-	// // 	}
-	// // 	return 0, nil, nil
-	// // })
-	// go func() {
-	// 	var content string
-	// 	for scanner.Scan() {
-	// 		line := strings.TrimSpace(scanner.Text())
-	// 		fmt.Printf("> gemini got line: %s\n", line)
-	// 		content += line
-	// 		// if !strings.HasPrefix(data, "\"text\": \"") {
-	// 		// 	continue
-	// 		// }
-
-	// 		// data = strings.TrimPrefix(data, "\"text\": \"")
-	// 		// data = strings.TrimSuffix(data, "\"")
-	// 		// dataChan <- data
-	// 	}
-
-	// 	dataChan <- content
-	// 	stopChan <- true
-	// }()
-	// common.SetEventStreamHeaders(c)
-	// c.Stream(func(w io.Writer) bool {
-	// 	select {
-	// 	case data := <-dataChan:
-	// 		// this is used to prevent annoying \ related format bug
-	// 		data = fmt.Sprintf("{\"content\": \"%s\"}", data)
-	// 		type dummyStruct struct {
-	// 			Content string `json:"content"`
-	// 		}
-	// 		var dummy dummyStruct
-	// 		err := json.Unmarshal([]byte(data), &dummy)
-	// 		responseText += dummy.Content
-	// 		var choice openai.ChatCompletionsStreamResponseChoice
-	// 		choice.Delta.Content = dummy.Content
-	// 		response := openai.ChatCompletionsStreamResponse{
-	// 			Id:      fmt.Sprintf("chatcmpl-%s", helper.GetUUID()),
-	// 			Object:  "chat.completion.chunk",
-	// 			Created: helper.GetTimestamp(),
-	// 			Model:   "gemini-pro",
-	// 			Choices: []openai.ChatCompletionsStreamResponseChoice{choice},
-	// 		}
-	// 		jsonResponse, err := json.Marshal(response)
-	// 		if err != nil {
-	// 			logger.SysError("error marshalling stream response: " + err.Error())
-	// 			return true
-	// 		}
-	// 		c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
-	// 		return true
-	// 	case <-stopChan:
-	// 		c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
-	// 		return false
-	// 	}
-	// })
-
+	dataChan := make(chan string)
+	stopChan := make(chan bool)
+	scanner := bufio.NewScanner(resp.Body)
+	scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
+		if atEOF && len(data) == 0 {
+			return 0, nil, nil
+		}
+		if i := strings.Index(string(data), "\n"); i >= 0 {
+			return i + 1, data[0:i], nil
+		}
+		if atEOF {
+			return len(data), data, nil
+		}
+		return 0, nil, nil
+	})
+	go func() {
+		for scanner.Scan() {
+			data := scanner.Text()
+			data = strings.TrimSpace(data)
+			if !strings.HasPrefix(data, "\"text\": \"") {
+				continue
+			}
+			data = strings.TrimPrefix(data, "\"text\": \"")
+			data = strings.TrimSuffix(data, "\"")
+			dataChan <- data
+		}
+		stopChan <- true
+	}()
+	common.SetEventStreamHeaders(c)
+	c.Stream(func(w io.Writer) bool {
+		select {
+		case data := <-dataChan:
+			// this is used to prevent annoying \ related format bug
+			data = fmt.Sprintf("{\"content\": \"%s\"}", data)
+			type dummyStruct struct {
+				Content string `json:"content"`
+			}
+			var dummy dummyStruct
+			err := json.Unmarshal([]byte(data), &dummy)
+			responseText += dummy.Content
+			var choice openai.ChatCompletionsStreamResponseChoice
+			choice.Delta.Content = dummy.Content
+			response := openai.ChatCompletionsStreamResponse{
+				Id:      fmt.Sprintf("chatcmpl-%s", random.GetUUID()),
+				Object:  "chat.completion.chunk",
+				Created: helper.GetTimestamp(),
+				Model:   "gemini-pro",
+				Choices: []openai.ChatCompletionsStreamResponseChoice{choice},
+			}
+			jsonResponse, err := json.Marshal(response)
+			if err != nil {
+				logger.SysError("error marshalling stream response: " + err.Error())
+				return true
+			}
+			c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
+			return true
+		case <-stopChan:
+			c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
+			return false
+		}
+	})
+
+	if err := resp.Body.Close(); err != nil {
+		return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), ""
+	}
+
 	return nil, responseText
 }
relay/adaptor/interface.go (new file, 21 lines)
@@ -0,0 +1,21 @@
package adaptor

import (
	"github.com/gin-gonic/gin"
	"github.com/songquanpeng/one-api/relay/meta"
	"github.com/songquanpeng/one-api/relay/model"
	"io"
	"net/http"
)

type Adaptor interface {
	Init(meta *meta.Meta)
	GetRequestURL(meta *meta.Meta) (string, error)
	SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error
	ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error)
	ConvertImageRequest(request *model.ImageRequest) (any, error)
	DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error)
	DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode)
	GetModelList() []string
	GetChannelName() string
}
relay/adaptor/minimax/main.go (new file, 14 lines)
@@ -0,0 +1,14 @@
package minimax

import (
	"fmt"
	"github.com/songquanpeng/one-api/relay/meta"
	"github.com/songquanpeng/one-api/relay/relaymode"
)

func GetRequestURL(meta *meta.Meta) (string, error) {
	if meta.Mode == relaymode.ChatCompletions {
		return fmt.Sprintf("%s/v1/text/chatcompletion_v2", meta.BaseURL), nil
	}
	return "", fmt.Errorf("unsupported relay mode %d for minimax", meta.Mode)
}
@@ -1,36 +1,36 @@
 package ollama
 
 import (
 	"errors"
 	"fmt"
 	"io"
 	"net/http"
 
 	"github.com/Laisky/errors/v2"
 	"github.com/gin-gonic/gin"
-	"github.com/songquanpeng/one-api/relay/channel"
-	"github.com/songquanpeng/one-api/relay/constant"
+	"github.com/songquanpeng/one-api/relay/adaptor"
+	"github.com/songquanpeng/one-api/relay/meta"
 	"github.com/songquanpeng/one-api/relay/model"
-	"github.com/songquanpeng/one-api/relay/util"
+	"github.com/songquanpeng/one-api/relay/relaymode"
 )
 
 type Adaptor struct {
 }
 
-func (a *Adaptor) Init(meta *util.RelayMeta) {
+func (a *Adaptor) Init(meta *meta.Meta) {
 
 }
 
-func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) {
+func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) {
 	// https://github.com/ollama/ollama/blob/main/docs/api.md
 	fullRequestURL := fmt.Sprintf("%s/api/chat", meta.BaseURL)
-	if meta.Mode == constant.RelayModeEmbeddings {
+	if meta.Mode == relaymode.Embeddings {
 		fullRequestURL = fmt.Sprintf("%s/api/embeddings", meta.BaseURL)
 	}
 	return fullRequestURL, nil
 }
 
-func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *util.RelayMeta) error {
-	channel.SetupCommonRequestHeader(c, req, meta)
+func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error {
+	adaptor.SetupCommonRequestHeader(c, req, meta)
 	req.Header.Set("Authorization", "Bearer "+meta.APIKey)
 	return nil
 }
@@ -40,7 +40,7 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G
 		return nil, errors.New("request is nil")
 	}
 	switch relayMode {
-	case constant.RelayModeEmbeddings:
+	case relaymode.Embeddings:
 		ollamaEmbeddingRequest := ConvertEmbeddingRequest(*request)
 		return ollamaEmbeddingRequest, nil
 	default:
@@ -48,16 +48,23 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G
 	}
 }
 
-func (a *Adaptor) DoRequest(c *gin.Context, meta *util.RelayMeta, requestBody io.Reader) (*http.Response, error) {
-	return channel.DoRequestHelper(a, c, meta, requestBody)
+func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) {
+	if request == nil {
+		return nil, errors.New("request is nil")
+	}
+	return request, nil
+}
+
+func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) {
+	return adaptor.DoRequestHelper(a, c, meta, requestBody)
 }
 
-func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.RelayMeta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
+func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
 	if meta.IsStream {
 		err, usage = StreamHandler(c, resp)
 	} else {
 		switch meta.Mode {
-		case constant.RelayModeEmbeddings:
+		case relaymode.Embeddings:
 			err, usage = EmbeddingHandler(c, resp)
 		default:
 			err, usage = Handler(c, resp)
@@ -5,15 +5,16 @@ import (
 	"context"
 	"encoding/json"
 	"fmt"
-	"github.com/songquanpeng/one-api/common/helper"
+	"github.com/songquanpeng/one-api/common/random"
 	"io"
 	"net/http"
 	"strings"
 
 	"github.com/gin-gonic/gin"
 	"github.com/songquanpeng/one-api/common"
+	"github.com/songquanpeng/one-api/common/helper"
 	"github.com/songquanpeng/one-api/common/logger"
-	"github.com/songquanpeng/one-api/relay/channel/openai"
+	"github.com/songquanpeng/one-api/relay/adaptor/openai"
 	"github.com/songquanpeng/one-api/relay/constant"
 	"github.com/songquanpeng/one-api/relay/model"
 )
@@ -51,7 +52,7 @@ func responseOllama2OpenAI(response *ChatResponse) *openai.TextResponse {
 		choice.FinishReason = "stop"
 	}
 	fullTextResponse := openai.TextResponse{
-		Id:      fmt.Sprintf("chatcmpl-%s", helper.GetUUID()),
+		Id:      fmt.Sprintf("chatcmpl-%s", random.GetUUID()),
 		Object:  "chat.completion",
 		Created: helper.GetTimestamp(),
 		Choices: []openai.TextResponseChoice{choice},
@@ -72,7 +73,7 @@ func streamResponseOllama2OpenAI(ollamaResponse *ChatResponse) *openai.ChatCompl
 		choice.FinishReason = &constant.StopFinishReason
 	}
 	response := openai.ChatCompletionsStreamResponse{
-		Id:      fmt.Sprintf("chatcmpl-%s", helper.GetUUID()),
+		Id:      fmt.Sprintf("chatcmpl-%s", random.GetUUID()),
 		Object:  "chat.completion.chunk",
 		Created: helper.GetTimestamp(),
 		Model:   ollamaResponse.Model,
@@ -4,11 +4,12 @@ import (
 	"fmt"
 	"github.com/Laisky/errors/v2"
 	"github.com/gin-gonic/gin"
-	"github.com/songquanpeng/one-api/common"
-	"github.com/songquanpeng/one-api/relay/channel"
-	"github.com/songquanpeng/one-api/relay/channel/minimax"
+	"github.com/songquanpeng/one-api/relay/adaptor"
+	"github.com/songquanpeng/one-api/relay/adaptor/minimax"
+	"github.com/songquanpeng/one-api/relay/channeltype"
+	"github.com/songquanpeng/one-api/relay/meta"
 	"github.com/songquanpeng/one-api/relay/model"
-	"github.com/songquanpeng/one-api/relay/util"
+	"github.com/songquanpeng/one-api/relay/relaymode"
 	"io"
 	"net/http"
 	"strings"
@@ -18,13 +19,20 @@ type Adaptor struct {
 	ChannelType int
 }
 
-func (a *Adaptor) Init(meta *util.RelayMeta) {
+func (a *Adaptor) Init(meta *meta.Meta) {
 	a.ChannelType = meta.ChannelType
 }
 
-func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) {
+func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) {
 	switch meta.ChannelType {
-	case common.ChannelTypeAzure:
+	case channeltype.Azure:
+		if meta.Mode == relaymode.ImagesGenerations {
+			// https://learn.microsoft.com/en-us/azure/ai-services/openai/dall-e-quickstart?tabs=dalle3%2Ccommand-line&pivots=rest-api
+			// https://{resource_name}.openai.azure.com/openai/deployments/dall-e-3/images/generations?api-version=2024-03-01-preview
+			fullRequestURL := fmt.Sprintf("%s/openai/deployments/%s/images/generations?api-version=%s", meta.BaseURL, meta.ActualModelName, meta.APIVersion)
+			return fullRequestURL, nil
+		}
+
 		// https://learn.microsoft.com/en-us/azure/cognitive-services/openai/chatgpt-quickstart?pivots=rest-api&tabs=command-line#rest-api
 		requestURL := strings.Split(meta.RequestURLPath, "?")[0]
 		requestURL = fmt.Sprintf("%s?api-version=%s", requestURL, meta.APIVersion)
@@ -34,22 +42,22 @@ func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) {
 		//https://github.com/songquanpeng/one-api/issues/1191
 		// {your endpoint}/openai/deployments/{your azure_model}/chat/completions?api-version={api_version}
 		requestURL = fmt.Sprintf("/openai/deployments/%s/%s", model_, task)
-		return util.GetFullRequestURL(meta.BaseURL, requestURL, meta.ChannelType), nil
-	case common.ChannelTypeMinimax:
+		return GetFullRequestURL(meta.BaseURL, requestURL, meta.ChannelType), nil
+	case channeltype.Minimax:
 		return minimax.GetRequestURL(meta)
 	default:
-		return util.GetFullRequestURL(meta.BaseURL, meta.RequestURLPath, meta.ChannelType), nil
+		return GetFullRequestURL(meta.BaseURL, meta.RequestURLPath, meta.ChannelType), nil
 	}
 }
 
-func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *util.RelayMeta) error {
-	channel.SetupCommonRequestHeader(c, req, meta)
-	if meta.ChannelType == common.ChannelTypeAzure {
+func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error {
+	adaptor.SetupCommonRequestHeader(c, req, meta)
+	if meta.ChannelType == channeltype.Azure {
 		req.Header.Set("api-key", meta.APIKey)
 		return nil
 	}
 	req.Header.Set("Authorization", "Bearer "+meta.APIKey)
-	if meta.ChannelType == common.ChannelTypeOpenRouter {
+	if meta.ChannelType == channeltype.OpenRouter {
 		req.Header.Set("HTTP-Referer", "https://github.com/songquanpeng/one-api")
 		req.Header.Set("X-Title", "One API")
 	}
@@ -63,11 +71,18 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G
 	return request, nil
 }
 
-func (a *Adaptor) DoRequest(c *gin.Context, meta *util.RelayMeta, requestBody io.Reader) (*http.Response, error) {
-	return channel.DoRequestHelper(a, c, meta, requestBody)
+func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) {
+	if request == nil {
+		return nil, errors.New("request is nil")
+	}
+	return request, nil
+}
+
+func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) {
+	return adaptor.DoRequestHelper(a, c, meta, requestBody)
 }
 
-func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.RelayMeta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
+func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
 	if meta.IsStream {
 		var responseText string
 		err, responseText, usage = StreamHandler(c, resp, meta.Mode)
@@ -75,7 +90,12 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.Rel
 			usage = ResponseText2Usage(responseText, meta.ActualModelName, meta.PromptTokens)
 		}
 	} else {
-		err, usage = Handler(c, resp, meta.PromptTokens, meta.ActualModelName)
+		switch meta.Mode {
+		case relaymode.ImagesGenerations:
+			err, _ = ImageHandler(c, resp)
+		default:
+			err, usage = Handler(c, resp, meta.PromptTokens, meta.ActualModelName)
+		}
 	}
 	return
 }
relay/adaptor/openai/compatible.go (new file, 50 lines)
@@ -0,0 +1,50 @@
package openai

import (
	"github.com/songquanpeng/one-api/relay/adaptor/ai360"
	"github.com/songquanpeng/one-api/relay/adaptor/baichuan"
	"github.com/songquanpeng/one-api/relay/adaptor/groq"
	"github.com/songquanpeng/one-api/relay/adaptor/lingyiwanwu"
	"github.com/songquanpeng/one-api/relay/adaptor/minimax"
	"github.com/songquanpeng/one-api/relay/adaptor/mistral"
	"github.com/songquanpeng/one-api/relay/adaptor/moonshot"
	"github.com/songquanpeng/one-api/relay/adaptor/stepfun"
	"github.com/songquanpeng/one-api/relay/channeltype"
)

var CompatibleChannels = []int{
	channeltype.Azure,
	channeltype.AI360,
	channeltype.Moonshot,
	channeltype.Baichuan,
	channeltype.Minimax,
	channeltype.Mistral,
	channeltype.Groq,
	channeltype.LingYiWanWu,
	channeltype.StepFun,
}

func GetCompatibleChannelMeta(channelType int) (string, []string) {
	switch channelType {
	case channeltype.Azure:
		return "azure", ModelList
	case channeltype.AI360:
		return "360", ai360.ModelList
	case channeltype.Moonshot:
		return "moonshot", moonshot.ModelList
	case channeltype.Baichuan:
		return "baichuan", baichuan.ModelList
	case channeltype.Minimax:
		return "minimax", minimax.ModelList
	case channeltype.Mistral:
		return "mistralai", mistral.ModelList
	case channeltype.Groq:
		return "groq", groq.ModelList
	case channeltype.LingYiWanWu:
		return "lingyiwanwu", lingyiwanwu.ModelList
	case channeltype.StepFun:
		return "stepfun", stepfun.ModelList
	default:
		return "openai", ModelList
	}
}
relay/adaptor/openai/helper.go (new file, 30 lines)
@@ -0,0 +1,30 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/songquanpeng/one-api/relay/channeltype"
|
||||
"github.com/songquanpeng/one-api/relay/model"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func ResponseText2Usage(responseText string, modeName string, promptTokens int) *model.Usage {
|
||||
usage := &model.Usage{}
|
||||
usage.PromptTokens = promptTokens
|
||||
usage.CompletionTokens = CountTokenText(responseText, modeName)
|
||||
usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
|
||||
return usage
|
||||
}

func GetFullRequestURL(baseURL string, requestURL string, channelType int) string {
fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL)

if strings.HasPrefix(baseURL, "https://gateway.ai.cloudflare.com") {
switch channelType {
case channeltype.OpenAI:
fullRequestURL = fmt.Sprintf("%s%s", baseURL, strings.TrimPrefix(requestURL, "/v1"))
case channeltype.Azure:
fullRequestURL = fmt.Sprintf("%s%s", baseURL, strings.TrimPrefix(requestURL, "/openai/deployments"))
}
}
return fullRequestURL
}
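A worked sketch of the Cloudflare AI Gateway rewrite above; the account and gateway path segments are placeholders:

base := "https://gateway.ai.cloudflare.com/v1/ACCOUNT_TAG/GATEWAY/openai" // placeholder gateway base
full := openai.GetFullRequestURL(base, "/v1/chat/completions", channeltype.OpenAI)
// "/v1" is trimmed from the request path, so full becomes:
// https://gateway.ai.cloudflare.com/v1/ACCOUNT_TAG/GATEWAY/openai/chat/completions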
44
relay/adaptor/openai/image.go
Normal file
@@ -0,0 +1,44 @@
package openai

import (
"bytes"
"encoding/json"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/relay/model"
"io"
"net/http"
)

func ImageHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
var imageResponse ImageResponse
responseBody, err := io.ReadAll(resp.Body)

if err != nil {
return ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
}
err = resp.Body.Close()
if err != nil {
return ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
err = json.Unmarshal(responseBody, &imageResponse)
if err != nil {
return ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
}

resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))

for k, v := range resp.Header {
c.Writer.Header().Set(k, v[0])
}
c.Writer.WriteHeader(resp.StatusCode)

_, err = io.Copy(c.Writer, resp.Body)
if err != nil {
return ErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil
}
err = resp.Body.Close()
if err != nil {
return ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
return nil, nil
}
@@ -8,8 +8,8 @@ import (
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/conv"
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/relay/constant"
"github.com/songquanpeng/one-api/relay/model"
"github.com/songquanpeng/one-api/relay/relaymode"
"io"
"net/http"
"strings"
@@ -46,7 +46,7 @@ func StreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*model.E
data = data[6:]
if !strings.HasPrefix(data, "[DONE]") {
switch relayMode {
case constant.RelayModeChatCompletions:
case relaymode.ChatCompletions:
var streamResponse ChatCompletionsStreamResponse
err := json.Unmarshal([]byte(data), &streamResponse)
if err != nil {
@@ -59,7 +59,7 @@ func StreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*model.E
if streamResponse.Usage != nil {
usage = streamResponse.Usage
}
case constant.RelayModeCompletions:
case relaymode.Completions:
var streamResponse CompletionsStreamResponse
err := json.Unmarshal([]byte(data), &streamResponse)
if err != nil {
@@ -110,11 +110,16 @@ type EmbeddingResponse struct {
model.Usage `json:"usage"`
}

type ImageData struct {
Url string `json:"url,omitempty"`
B64Json string `json:"b64_json,omitempty"`
RevisedPrompt string `json:"revised_prompt,omitempty"`
}

type ImageResponse struct {
Created int `json:"created"`
Data []struct {
Url string `json:"url"`
}
Created int64 `json:"created"`
Data []ImageData `json:"data"`
//model.Usage `json:"usage"`
}

type ChatCompletionsStreamResponseChoice struct {
@@ -4,10 +4,10 @@ import (
"fmt"
"github.com/Laisky/errors/v2"
"github.com/pkoukk/tiktoken-go"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/image"
"github.com/songquanpeng/one-api/common/logger"
billingratio "github.com/songquanpeng/one-api/relay/billing/ratio"
"github.com/songquanpeng/one-api/relay/model"
"math"
"strings"
@@ -28,7 +28,7 @@ func InitTokenEncoders() {
if err != nil {
logger.FatalLog(fmt.Sprintf("failed to get gpt-4 token encoder: %s", err.Error()))
}
for model := range common.ModelRatio {
for model := range billingratio.ModelRatio {
if strings.HasPrefix(model, "gpt-3.5") {
tokenEncoderMap[model] = gpt35TokenEncoder
} else if strings.HasPrefix(model, "gpt-4") {
@@ -4,10 +4,10 @@ import (
"fmt"
"github.com/Laisky/errors/v2"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/relay/channel"
"github.com/songquanpeng/one-api/relay/channel/openai"
"github.com/songquanpeng/one-api/relay/adaptor"
"github.com/songquanpeng/one-api/relay/adaptor/openai"
"github.com/songquanpeng/one-api/relay/meta"
"github.com/songquanpeng/one-api/relay/model"
"github.com/songquanpeng/one-api/relay/util"
"io"
"net/http"
)
@@ -15,16 +15,16 @@ import (
type Adaptor struct {
}

func (a *Adaptor) Init(meta *util.RelayMeta) {
func (a *Adaptor) Init(meta *meta.Meta) {

}

func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) {
func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) {
return fmt.Sprintf("%s/v1beta2/models/chat-bison-001:generateMessage", meta.BaseURL), nil
}

func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *util.RelayMeta) error {
channel.SetupCommonRequestHeader(c, req, meta)
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error {
adaptor.SetupCommonRequestHeader(c, req, meta)
req.Header.Set("x-goog-api-key", meta.APIKey)
return nil
}
@@ -36,11 +36,18 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G
return ConvertRequest(*request), nil
}

func (a *Adaptor) DoRequest(c *gin.Context, meta *util.RelayMeta, requestBody io.Reader) (*http.Response, error) {
return channel.DoRequestHelper(a, c, meta, requestBody)
func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) {
if request == nil {
return nil, errors.New("request is nil")
}
return request, nil
}

func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.RelayMeta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) {
return adaptor.DoRequestHelper(a, c, meta, requestBody)
}

func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
if meta.IsStream {
var responseText string
err, responseText = StreamHandler(c, resp)
@@ -7,7 +7,8 @@ import (
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/helper"
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/relay/channel/openai"
"github.com/songquanpeng/one-api/common/random"
"github.com/songquanpeng/one-api/relay/adaptor/openai"
"github.com/songquanpeng/one-api/relay/constant"
"github.com/songquanpeng/one-api/relay/model"
"io"
@@ -74,7 +75,7 @@ func streamResponsePaLM2OpenAI(palmResponse *ChatResponse) *openai.ChatCompletio

func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, string) {
responseText := ""
responseId := fmt.Sprintf("chatcmpl-%s", helper.GetUUID())
responseId := fmt.Sprintf("chatcmpl-%s", random.GetUUID())
createdTime := helper.GetTimestamp()
dataChan := make(chan string)
stopChan := make(chan bool)
7
relay/adaptor/stepfun/constants.go
Normal file
@@ -0,0 +1,7 @@
package stepfun

var ModelList = []string{
"step-1-32k",
"step-1v-32k",
"step-1-200k",
}
238
relay/adaptor/tencent/main.go
Normal file
@@ -0,0 +1,238 @@
package tencent

// import (
// "bufio"
// "crypto/hmac"
// "crypto/sha1"
// "encoding/base64"
// "encoding/json"
// "github.com/Laisky/errors/v2"
// "fmt"
// "github.com/gin-gonic/gin"
// "github.com/songquanpeng/one-api/common"
// "github.com/songquanpeng/one-api/common/helper"
// "github.com/songquanpeng/one-api/common/logger"
// "github.com/songquanpeng/one-api/relay/channel/openai"
// "github.com/songquanpeng/one-api/relay/constant"
// "github.com/songquanpeng/one-api/relay/model"
// "io"
// "net/http"
// "sort"
// "strconv"
// "strings"
// )

// // https://cloud.tencent.com/document/product/1729/97732

// func ConvertRequest(request model.GeneralOpenAIRequest) *ChatRequest {
// messages := make([]Message, 0, len(request.Messages))
// for i := 0; i < len(request.Messages); i++ {
// message := request.Messages[i]
// if message.Role == "system" {
// messages = append(messages, Message{
// Role: "user",
// Content: message.StringContent(),
// })
// messages = append(messages, Message{
// Role: "assistant",
// Content: "Okay",
// })
// continue
// }
// messages = append(messages, Message{
// Content: message.StringContent(),
// Role: message.Role,
// })
// }
// stream := 0
// if request.Stream {
// stream = 1
// }
// return &ChatRequest{
// Timestamp: helper.GetTimestamp(),
// Expired: helper.GetTimestamp() + 24*60*60,
// QueryID: helper.GetUUID(),
// Temperature: request.Temperature,
// TopP: request.TopP,
// Stream: stream,
// Messages: messages,
// }
// }

// func responseTencent2OpenAI(response *ChatResponse) *openai.TextResponse {
// fullTextResponse := openai.TextResponse{
// Object: "chat.completion",
// Created: helper.GetTimestamp(),
// Usage: response.Usage,
// }
// if len(response.Choices) > 0 {
// choice := openai.TextResponseChoice{
// Index: 0,
// Message: model.Message{
// Role: "assistant",
// Content: response.Choices[0].Messages.Content,
// },
// FinishReason: response.Choices[0].FinishReason,
// }
// fullTextResponse.Choices = append(fullTextResponse.Choices, choice)
// }
// return &fullTextResponse
// }

// func streamResponseTencent2OpenAI(TencentResponse *ChatResponse) *openai.ChatCompletionsStreamResponse {
// response := openai.ChatCompletionsStreamResponse{
// Object: "chat.completion.chunk",
// Created: helper.GetTimestamp(),
// Model: "tencent-hunyuan",
// }
// if len(TencentResponse.Choices) > 0 {
// var choice openai.ChatCompletionsStreamResponseChoice
// choice.Delta.Content = TencentResponse.Choices[0].Delta.Content
// if TencentResponse.Choices[0].FinishReason == "stop" {
// choice.FinishReason = &constant.StopFinishReason
// }
// response.Choices = append(response.Choices, choice)
// }
// return &response
// }

// func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, string) {
// var responseText string
// scanner := bufio.NewScanner(resp.Body)
// scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
// if atEOF && len(data) == 0 {
// return 0, nil, nil
// }
// if i := strings.Index(string(data), "\n"); i >= 0 {
// return i + 1, data[0:i], nil
// }
// if atEOF {
// return len(data), data, nil
// }
// return 0, nil, nil
// })
// dataChan := make(chan string)
// stopChan := make(chan bool)
// go func() {
// for scanner.Scan() {
// data := scanner.Text()
// if len(data) < 5 { // ignore blank line or wrong format
// continue
// }
// if data[:5] != "data:" {
// continue
// }
// data = data[5:]
// dataChan <- data
// }
// stopChan <- true
// }()
// common.SetEventStreamHeaders(c)
// c.Stream(func(w io.Writer) bool {
// select {
// case data := <-dataChan:
// var TencentResponse ChatResponse
// err := json.Unmarshal([]byte(data), &TencentResponse)
// if err != nil {
// logger.SysError("error unmarshalling stream response: " + err.Error())
// return true
// }
// response := streamResponseTencent2OpenAI(&TencentResponse)
// if len(response.Choices) != 0 {
// responseText += response.Choices[0].Delta.Content
// }
// jsonResponse, err := json.Marshal(response)
// if err != nil {
// logger.SysError("error marshalling stream response: " + err.Error())
// return true
// }
// c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
// return true
// case <-stopChan:
// c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
// return false
// }
// })
// err := resp.Body.Close()
// if err != nil {
// return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), ""
// }
// return nil, responseText
// }

// func Handler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
// var TencentResponse ChatResponse
// responseBody, err := io.ReadAll(resp.Body)
// if err != nil {
// return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
// }
// err = resp.Body.Close()
// if err != nil {
// return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
// }
// err = json.Unmarshal(responseBody, &TencentResponse)
// if err != nil {
// return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
// }
// if TencentResponse.Error.Code != 0 {
// return &model.ErrorWithStatusCode{
// Error: model.Error{
// Message: TencentResponse.Error.Message,
// Code: TencentResponse.Error.Code,
// },
// StatusCode: resp.StatusCode,
// }, nil
// }
// fullTextResponse := responseTencent2OpenAI(&TencentResponse)
// fullTextResponse.Model = "hunyuan"
// jsonResponse, err := json.Marshal(fullTextResponse)
// if err != nil {
// return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
// }
// c.Writer.Header().Set("Content-Type", "application/json")
// c.Writer.WriteHeader(resp.StatusCode)
// _, err = c.Writer.Write(jsonResponse)
// if err != nil {
// return openai.ErrorWrapper(err, "write_response_body_failed", http.StatusInternalServerError), nil
// }
// return nil, &fullTextResponse.Usage
// }

// func ParseConfig(config string) (appId int64, secretId string, secretKey string, err error) {
// parts := strings.Split(config, "|")
// if len(parts) != 3 {
// err = errors.New("invalid tencent config")
// return
// }
// appId, err = strconv.ParseInt(parts[0], 10, 64)
// secretId = parts[1]
// secretKey = parts[2]
// return
// }

// func GetSign(req ChatRequest, secretKey string) string {
// params := make([]string, 0)
// params = append(params, "app_id="+strconv.FormatInt(req.AppId, 10))
// params = append(params, "secret_id="+req.SecretId)
// params = append(params, "timestamp="+strconv.FormatInt(req.Timestamp, 10))
// params = append(params, "query_id="+req.QueryID)
// params = append(params, "temperature="+strconv.FormatFloat(req.Temperature, 'f', -1, 64))
// params = append(params, "top_p="+strconv.FormatFloat(req.TopP, 'f', -1, 64))
// params = append(params, "stream="+strconv.Itoa(req.Stream))
// params = append(params, "expired="+strconv.FormatInt(req.Expired, 10))

// var messageStr string
// for _, msg := range req.Messages {
// messageStr += fmt.Sprintf(`{"role":"%s","content":"%s"},`, msg.Role, msg.Content)
// }
// messageStr = strings.TrimSuffix(messageStr, ",")
// params = append(params, "messages=["+messageStr+"]")

// sort.Strings(params)
// url := "hunyuan.cloud.tencent.com/hyllm/v1/chat/completions?" + strings.Join(params, "&")
// mac := hmac.New(sha1.New, []byte(secretKey))
// signURL := url
// mac.Write([]byte(signURL))
// sign := mac.Sum([]byte(nil))
// return base64.StdEncoding.EncodeToString(sign)
// }
145
relay/adaptor/zhipu/adaptor.go
Normal file
@@ -0,0 +1,145 @@
package zhipu

// import (
// "github.com/Laisky/errors/v2"
// "fmt"
// "github.com/gin-gonic/gin"
// "github.com/songquanpeng/one-api/relay/adaptor"
// "github.com/songquanpeng/one-api/relay/adaptor/openai"
// "github.com/songquanpeng/one-api/relay/meta"
// "github.com/songquanpeng/one-api/relay/model"
// "github.com/songquanpeng/one-api/relay/relaymode"
// "io"
// "math"
// "net/http"
// "strings"
// )

// type Adaptor struct {
// APIVersion string
// }

// func (a *Adaptor) Init(meta *meta.Meta) {

// }

// func (a *Adaptor) SetVersionByModeName(modelName string) {
// if strings.HasPrefix(modelName, "glm-") {
// a.APIVersion = "v4"
// } else {
// a.APIVersion = "v3"
// }
// }

// func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) {
// switch meta.Mode {
// case relaymode.ImagesGenerations:
// return fmt.Sprintf("%s/api/paas/v4/images/generations", meta.BaseURL), nil
// case relaymode.Embeddings:
// return fmt.Sprintf("%s/api/paas/v4/embeddings", meta.BaseURL), nil
// }
// a.SetVersionByModeName(meta.ActualModelName)
// if a.APIVersion == "v4" {
// return fmt.Sprintf("%s/api/paas/v4/chat/completions", meta.BaseURL), nil
// }
// method := "invoke"
// if meta.IsStream {
// method = "sse-invoke"
// }
// return fmt.Sprintf("%s/api/paas/v3/model-api/%s/%s", meta.BaseURL, meta.ActualModelName, method), nil
// }

// func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error {
// adaptor.SetupCommonRequestHeader(c, req, meta)
// token := GetToken(meta.APIKey)
// req.Header.Set("Authorization", token)
// return nil
// }

// func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) {
// if request == nil {
// return nil, errors.New("request is nil")
// }
// switch relayMode {
// case relaymode.Embeddings:
// baiduEmbeddingRequest := ConvertEmbeddingRequest(*request)
// return baiduEmbeddingRequest, nil
// default:
// // TopP (0.0, 1.0)
// request.TopP = math.Min(0.99, request.TopP)
// request.TopP = math.Max(0.01, request.TopP)

// // Temperature (0.0, 1.0)
// request.Temperature = math.Min(0.99, request.Temperature)
// request.Temperature = math.Max(0.01, request.Temperature)
// a.SetVersionByModeName(request.Model)
// if a.APIVersion == "v4" {
// return request, nil
// }
// return ConvertRequest(*request), nil
// }
// }

// func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) {
// if request == nil {
// return nil, errors.New("request is nil")
// }
// newRequest := ImageRequest{
// Model: request.Model,
// Prompt: request.Prompt,
// UserId: request.User,
// }
// return newRequest, nil
// }

// func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) {
// return adaptor.DoRequestHelper(a, c, meta, requestBody)
// }

// func (a *Adaptor) DoResponseV4(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
// if meta.IsStream {
// err, _, usage = openai.StreamHandler(c, resp, meta.Mode)
// } else {
// err, usage = openai.Handler(c, resp, meta.PromptTokens, meta.ActualModelName)
// }
// return
// }

// func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
// switch meta.Mode {
// case relaymode.Embeddings:
// err, usage = EmbeddingsHandler(c, resp)
// return
// case relaymode.ImagesGenerations:
// err, usage = openai.ImageHandler(c, resp)
// return
// }
// if a.APIVersion == "v4" {
// return a.DoResponseV4(c, resp, meta)
// }
// if meta.IsStream {
// err, usage = StreamHandler(c, resp)
// } else {
// if meta.Mode == relaymode.Embeddings {
// err, usage = EmbeddingsHandler(c, resp)
// } else {
// err, usage = Handler(c, resp)
// }
// }
// return
// }

// func ConvertEmbeddingRequest(request model.GeneralOpenAIRequest) *EmbeddingRequest {
// return &EmbeddingRequest{
// Model: "embedding-2",
// Input: request.Input.(string),
// }
// }

// func (a *Adaptor) GetModelList() []string {
// return ModelList
// }

// func (a *Adaptor) GetChannelName() string {
// return "zhipu"
// }
17
relay/apitype/define.go
Normal file
@@ -0,0 +1,17 @@
package apitype

const (
OpenAI = iota
Anthropic
PaLM
Baidu
Zhipu
Ali
Xunfei
AIProxyLibrary
Tencent
Gemini
Ollama

Dummy // this one is only for count, do not add any channel after this
)
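These constants are iota-based, so their order is part of the contract (hence the Dummy sentinel). A sketch of the usual lookup path from a channel type to an adaptor; the Gemini value is illustrative:

apiType := channeltype.ToAPIType(channeltype.Gemini) // apitype.Gemini
a := relay.GetAdaptor(apiType)
if a == nil {
	// API type exists but its adaptor is disabled or not yet migrated.
}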
42
relay/billing/billing.go
Normal file
@@ -0,0 +1,42 @@
package billing

import (
"context"
"fmt"
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/model"
)

func ReturnPreConsumedQuota(ctx context.Context, preConsumedQuota int64, tokenId int) {
if preConsumedQuota != 0 {
go func(ctx context.Context) {
// return pre-consumed quota
err := model.PostConsumeTokenQuota(tokenId, -preConsumedQuota)
if err != nil {
logger.Error(ctx, "error return pre-consumed quota: "+err.Error())
}
}(ctx)
}
}

func PostConsumeQuota(ctx context.Context, tokenId int, quotaDelta int64, totalQuota int64, userId int, channelId int, modelRatio float64, groupRatio float64, modelName string, tokenName string) {
// quotaDelta is remaining quota to be consumed
err := model.PostConsumeTokenQuota(tokenId, quotaDelta)
if err != nil {
logger.SysError("error consuming token remain quota: " + err.Error())
}
err = model.CacheUpdateUserQuota(ctx, userId)
if err != nil {
logger.SysError("error update user quota cache: " + err.Error())
}
// totalQuota is total quota consumed
if totalQuota != 0 {
logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
model.RecordConsumeLog(ctx, userId, channelId, int(totalQuota), 0, modelName, tokenName, totalQuota, logContent)
model.UpdateUserUsedQuotaAndRequestCount(userId, totalQuota)
model.UpdateChannelUsedQuota(channelId, totalQuota)
}
if totalQuota <= 0 {
logger.Error(ctx, fmt.Sprintf("totalQuota consumed is %d, something is wrong", totalQuota))
}
}
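A sketch of how a relay controller might drive these two helpers around one request (a fragment; ctx, the IDs, ratios, and quota figures are assumed to come from the surrounding handler):

// On upstream failure: refund whatever was pre-consumed.
billing.ReturnPreConsumedQuota(ctx, preConsumedQuota, tokenId)

// On success: settle the difference and record the consumption log.
quotaDelta := quota - preConsumedQuota
billing.PostConsumeQuota(ctx, tokenId, quotaDelta, quota, userId, channelId,
	modelRatio, groupRatio, modelName, tokenName)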
34
relay/billing/ratio/group.go
Normal file
@@ -0,0 +1,34 @@
package ratio

import (
"encoding/json"
"github.com/songquanpeng/one-api/common/logger"
)

var GroupRatio = map[string]float64{
"default": 1,
"vip": 1,
"svip": 1,
}

func GroupRatio2JSONString() string {
jsonBytes, err := json.Marshal(GroupRatio)
if err != nil {
logger.SysError("error marshalling model ratio: " + err.Error())
}
return string(jsonBytes)
}

func UpdateGroupRatioByJSONString(jsonStr string) error {
GroupRatio = make(map[string]float64)
return json.Unmarshal([]byte(jsonStr), &GroupRatio)
}

func GetGroupRatio(name string) float64 {
ratio, ok := GroupRatio[name]
if !ok {
logger.SysError("group ratio not found: " + name)
return 1
}
return ratio
}
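GetGroupRatio logs and falls back to 1 for unknown groups, so a missing entry never blocks billing; a short sketch:

r := ratio.GetGroupRatio("vip")    // 1, from the table above
r = ratio.GetGroupRatio("unknown") // also 1, after a "group ratio not found" log
_ = r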
51
relay/billing/ratio/image.go
Normal file
@@ -0,0 +1,51 @@
package ratio

var ImageSizeRatios = map[string]map[string]float64{
"dall-e-2": {
"256x256": 1,
"512x512": 1.125,
"1024x1024": 1.25,
},
"dall-e-3": {
"1024x1024": 1,
"1024x1792": 2,
"1792x1024": 2,
},
"ali-stable-diffusion-xl": {
"512x1024": 1,
"1024x768": 1,
"1024x1024": 1,
"576x1024": 1,
"1024x576": 1,
},
"ali-stable-diffusion-v1.5": {
"512x1024": 1,
"1024x768": 1,
"1024x1024": 1,
"576x1024": 1,
"1024x576": 1,
},
"wanx-v1": {
"1024x1024": 1,
"720x1280": 1,
"1280x720": 1,
},
}

var ImageGenerationAmounts = map[string][2]int{
"dall-e-2": {1, 10},
"dall-e-3": {1, 1}, // OpenAI allows n=1 currently.
"ali-stable-diffusion-xl": {1, 4}, // Ali
"ali-stable-diffusion-v1.5": {1, 4}, // Ali
"wanx-v1": {1, 4}, // Ali
"cogview-3": {1, 1},
}

var ImagePromptLengthLimitations = map[string]int{
"dall-e-2": 1000,
"dall-e-3": 4000,
"ali-stable-diffusion-xl": 4000,
"ali-stable-diffusion-v1.5": 4000,
"wanx-v1": 4000,
"cogview-3": 833,
}
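A worked sketch of how these tables might combine when pricing an image request; the final multiplication against a per-image token base is an assumption about the calling controller, not code from this commit:

sizeRatio := ratio.ImageSizeRatios["dall-e-3"]["1024x1792"] // 2
modelRatio := ratio.GetModelRatio("dall-e-3")               // 0.04 * USD = 20
n := 1                                                      // dall-e-3 allows n=1 only
quota := int64(float64(n) * sizeRatio * modelRatio * 1000)  // hypothetical base of 1000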
281
relay/billing/ratio/model.go
Normal file
@@ -0,0 +1,281 @@
package ratio

import (
"encoding/json"
"strings"

"github.com/songquanpeng/one-api/common/logger"
)

const (
USD2RMB = 7
USD = 500 // $0.002 = 1 -> $1 = 500
RMB = USD / USD2RMB
)

// ModelRatio
// https://platform.openai.com/docs/models/model-endpoint-compatibility
// https://cloud.baidu.com/doc/WENXINWORKSHOP/s/Blfmc9dlf
// https://openai.com/pricing
// 1 === $0.002 / 1K tokens
// 1 === ¥0.014 / 1k tokens
var ModelRatio = map[string]float64{
// https://openai.com/pricing
"gpt-4": 15,
"gpt-4-0314": 15,
"gpt-4-0613": 15,
"gpt-4-32k": 30,
"gpt-4-32k-0314": 30,
"gpt-4-32k-0613": 30,
"gpt-4-1106-preview": 5, // $0.01 / 1K tokens
"gpt-4-0125-preview": 5, // $0.01 / 1K tokens
"gpt-4-turbo-preview": 5, // $0.01 / 1K tokens
"gpt-4-vision-preview": 5, // $0.01 / 1K tokens
"gpt-3.5-turbo": 0.25, // $0.0005 / 1K tokens
"gpt-3.5-turbo-0301": 0.75,
"gpt-3.5-turbo-0613": 0.75,
"gpt-3.5-turbo-16k": 1.5, // $0.003 / 1K tokens
"gpt-3.5-turbo-16k-0613": 1.5,
"gpt-3.5-turbo-instruct": 0.75, // $0.0015 / 1K tokens
"gpt-3.5-turbo-1106": 0.5, // $0.001 / 1K tokens
"gpt-3.5-turbo-0125": 0.25, // $0.0005 / 1K tokens
"davinci-002": 1, // $0.002 / 1K tokens
"babbage-002": 0.2, // $0.0004 / 1K tokens
"text-ada-001": 0.2,
"text-babbage-001": 0.25,
"text-curie-001": 1,
"text-davinci-002": 10,
"text-davinci-003": 10,
"text-davinci-edit-001": 10,
"code-davinci-edit-001": 10,
"whisper-1": 15, // $0.006 / minute -> $0.006 / 150 words -> $0.006 / 200 tokens -> $0.03 / 1k tokens
"tts-1": 7.5, // $0.015 / 1K characters
"tts-1-1106": 7.5,
"tts-1-hd": 15, // $0.030 / 1K characters
"tts-1-hd-1106": 15,
"davinci": 10,
"curie": 10,
"babbage": 10,
"ada": 10,
"text-embedding-ada-002": 0.05,
"text-embedding-3-small": 0.01,
"text-embedding-3-large": 0.065,
"text-search-ada-doc-001": 10,
"text-moderation-stable": 0.1,
"text-moderation-latest": 0.1,
"dall-e-2": 0.02 * USD, // $0.016 - $0.020 / image
"dall-e-3": 0.04 * USD, // $0.040 - $0.120 / image
// https://www.anthropic.com/api#pricing
"claude-instant-1.2": 0.8 / 1000 * USD,
"claude-2.0": 8.0 / 1000 * USD,
"claude-2.1": 8.0 / 1000 * USD,
"claude-3-haiku-20240307": 0.25 / 1000 * USD,
"claude-3-sonnet-20240229": 3.0 / 1000 * USD,
"claude-3-opus-20240229": 15.0 / 1000 * USD,
// https://cloud.baidu.com/doc/WENXINWORKSHOP/s/hlrk4akp7
"ERNIE-4.0-8K": 0.120 * RMB,
"ERNIE-3.5-8K": 0.012 * RMB,
"ERNIE-3.5-8K-0205": 0.024 * RMB,
"ERNIE-3.5-8K-1222": 0.012 * RMB,
"ERNIE-Bot-8K": 0.024 * RMB,
"ERNIE-3.5-4K-0205": 0.012 * RMB,
"ERNIE-Speed-8K": 0.004 * RMB,
"ERNIE-Speed-128K": 0.004 * RMB,
"ERNIE-Lite-8K-0922": 0.008 * RMB,
"ERNIE-Lite-8K-0308": 0.003 * RMB,
"ERNIE-Tiny-8K": 0.001 * RMB,
"BLOOMZ-7B": 0.004 * RMB,
"Embedding-V1": 0.002 * RMB,
"bge-large-zh": 0.002 * RMB,
"bge-large-en": 0.002 * RMB,
"tao-8k": 0.002 * RMB,
// https://ai.google.dev/pricing
"PaLM-2": 1,
"gemini-pro": 1, // $0.00025 / 1k characters -> $0.001 / 1k tokens
"gemini-pro-vision": 1, // $0.00025 / 1k characters -> $0.001 / 1k tokens
"gemini-1.0-pro-vision-001": 1,
"gemini-1.0-pro-001": 1,
"gemini-1.5-pro": 1,
// https://open.bigmodel.cn/pricing
"glm-4": 0.1 * RMB,
"glm-4v": 0.1 * RMB,
"glm-3-turbo": 0.005 * RMB,
"embedding-2": 0.0005 * RMB,
"chatglm_turbo": 0.3572, // ¥0.005 / 1k tokens
"chatglm_pro": 0.7143, // ¥0.01 / 1k tokens
"chatglm_std": 0.3572, // ¥0.005 / 1k tokens
"chatglm_lite": 0.1429, // ¥0.002 / 1k tokens
"cogview-3": 0.25 * RMB,
// https://help.aliyun.com/zh/dashscope/developer-reference/tongyi-thousand-questions-metering-and-billing
"qwen-turbo": 0.5715, // ¥0.008 / 1k tokens
"qwen-plus": 1.4286, // ¥0.02 / 1k tokens
"qwen-max": 1.4286, // ¥0.02 / 1k tokens
"qwen-max-longcontext": 1.4286, // ¥0.02 / 1k tokens
"text-embedding-v1": 0.05, // ¥0.0007 / 1k tokens
"ali-stable-diffusion-xl": 8,
"ali-stable-diffusion-v1.5": 8,
"wanx-v1": 8,
"SparkDesk": 1.2858, // ¥0.018 / 1k tokens
"SparkDesk-v1.1": 1.2858, // ¥0.018 / 1k tokens
"SparkDesk-v2.1": 1.2858, // ¥0.018 / 1k tokens
"SparkDesk-v3.1": 1.2858, // ¥0.018 / 1k tokens
"SparkDesk-v3.5": 1.2858, // ¥0.018 / 1k tokens
"360GPT_S2_V9": 0.8572, // ¥0.012 / 1k tokens
"embedding-bert-512-v1": 0.0715, // ¥0.001 / 1k tokens
"embedding_s1_v1": 0.0715, // ¥0.001 / 1k tokens
"semantic_similarity_s1_v1": 0.0715, // ¥0.001 / 1k tokens
"hunyuan": 7.143, // ¥0.1 / 1k tokens // https://cloud.tencent.com/document/product/1729/97731#e0e6be58-60c8-469f-bdeb-6c264ce3b4d0
"ChatStd": 0.01 * RMB,
"ChatPro": 0.1 * RMB,
// https://platform.moonshot.cn/pricing
"moonshot-v1-8k": 0.012 * RMB,
"moonshot-v1-32k": 0.024 * RMB,
"moonshot-v1-128k": 0.06 * RMB,
// https://platform.baichuan-ai.com/price
"Baichuan2-Turbo": 0.008 * RMB,
"Baichuan2-Turbo-192k": 0.016 * RMB,
"Baichuan2-53B": 0.02 * RMB,
// https://api.minimax.chat/document/price
"abab6-chat": 0.1 * RMB,
"abab5.5-chat": 0.015 * RMB,
"abab5.5s-chat": 0.005 * RMB,
// https://docs.mistral.ai/platform/pricing/
"open-mistral-7b": 0.25 / 1000 * USD,
"open-mixtral-8x7b": 0.7 / 1000 * USD,
"mistral-small-latest": 2.0 / 1000 * USD,
"mistral-medium-latest": 2.7 / 1000 * USD,
"mistral-large-latest": 8.0 / 1000 * USD,
"mistral-embed": 0.1 / 1000 * USD,
// https://wow.groq.com/
"llama2-70b-4096": 0.7 / 1000 * USD,
"llama2-7b-2048": 0.1 / 1000 * USD,
"mixtral-8x7b-32768": 0.27 / 1000 * USD,
"gemma-7b-it": 0.1 / 1000 * USD,
// https://platform.lingyiwanwu.com/docs#-计费单元
"yi-34b-chat-0205": 2.5 / 1000 * RMB,
"yi-34b-chat-200k": 12.0 / 1000 * RMB,
"yi-vl-plus": 6.0 / 1000 * RMB,
}

var CompletionRatio = map[string]float64{}

var DefaultModelRatio map[string]float64
var DefaultCompletionRatio map[string]float64

func init() {
DefaultModelRatio = make(map[string]float64)
for k, v := range ModelRatio {
DefaultModelRatio[k] = v
}
DefaultCompletionRatio = make(map[string]float64)
for k, v := range CompletionRatio {
DefaultCompletionRatio[k] = v
}
}

func AddNewMissingRatio(oldRatio string) string {
newRatio := make(map[string]float64)
err := json.Unmarshal([]byte(oldRatio), &newRatio)
if err != nil {
logger.SysError("error unmarshalling old ratio: " + err.Error())
return oldRatio
}
for k, v := range DefaultModelRatio {
if _, ok := newRatio[k]; !ok {
newRatio[k] = v
}
}
jsonBytes, err := json.Marshal(newRatio)
if err != nil {
logger.SysError("error marshalling new ratio: " + err.Error())
return oldRatio
}
return string(jsonBytes)
}

func ModelRatio2JSONString() string {
jsonBytes, err := json.Marshal(ModelRatio)
if err != nil {
logger.SysError("error marshalling model ratio: " + err.Error())
}
return string(jsonBytes)
}

func UpdateModelRatioByJSONString(jsonStr string) error {
ModelRatio = make(map[string]float64)
return json.Unmarshal([]byte(jsonStr), &ModelRatio)
}

func GetModelRatio(name string) float64 {
if strings.HasPrefix(name, "qwen-") && strings.HasSuffix(name, "-internet") {
name = strings.TrimSuffix(name, "-internet")
}
ratio, ok := ModelRatio[name]
if !ok {
ratio, ok = DefaultModelRatio[name]
}
if !ok {
logger.SysError("model ratio not found: " + name)
return 30
}
return ratio
}

func CompletionRatio2JSONString() string {
jsonBytes, err := json.Marshal(CompletionRatio)
if err != nil {
logger.SysError("error marshalling completion ratio: " + err.Error())
}
return string(jsonBytes)
}

func UpdateCompletionRatioByJSONString(jsonStr string) error {
CompletionRatio = make(map[string]float64)
return json.Unmarshal([]byte(jsonStr), &CompletionRatio)
}

// GetCompletionRatio returns the completion ratio of a model
//
// completion ratio is the ratio comparing to the ratio of prompt
func GetCompletionRatio(name string) float64 {
if ratio, ok := CompletionRatio[name]; ok {
return ratio
}
if ratio, ok := DefaultCompletionRatio[name]; ok {
return ratio
}
if strings.HasPrefix(name, "gpt-3.5") {
if name == "gpt-3.5-turbo" || strings.HasSuffix(name, "0125") {
// https://openai.com/blog/new-embedding-models-and-api-updates
// Updated GPT-3.5 Turbo model and lower pricing
return 3
}
if strings.HasSuffix(name, "1106") {
return 2
}
return 4.0 / 3.0
}
if strings.HasPrefix(name, "gpt-4") {
if strings.HasSuffix(name, "preview") {
return 3
}
return 2
}
if strings.HasPrefix(name, "claude-3") {
return 5
}
if strings.HasPrefix(name, "claude-") {
return 3
}
if strings.HasPrefix(name, "mistral-") {
return 3
}
if strings.HasPrefix(name, "gemini-") {
return 3
}
switch name {
case "llama2-70b-4096":
return 0.8 / 0.7
}
return 1
}
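A worked example combining GetModelRatio and GetCompletionRatio into a per-request quota, in the style of the text controller; the token counts are illustrative:

promptTokens, completionTokens := 1000.0, 1000.0
modelRatio := ratio.GetModelRatio("gpt-4")           // 15
completionRatio := ratio.GetCompletionRatio("gpt-4") // 2 (gpt-4, non-preview)
quota := (promptTokens + completionTokens*completionRatio) * modelRatio
// (1000 + 1000*2) * 15 = 45000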
@@ -1,83 +0,0 @@
package ali

// import (
// "github.com/Laisky/errors/v2"
// "fmt"
// "github.com/gin-gonic/gin"
// "github.com/songquanpeng/one-api/common"
// "github.com/songquanpeng/one-api/relay/channel"
// "github.com/songquanpeng/one-api/relay/constant"
// "github.com/songquanpeng/one-api/relay/model"
// "github.com/songquanpeng/one-api/relay/util"
// "io"
// "net/http"
// )

// // https://help.aliyun.com/zh/dashscope/developer-reference/api-details

// type Adaptor struct {
// }

// func (a *Adaptor) Init(meta *util.RelayMeta) {

// }

// func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) {
// fullRequestURL := fmt.Sprintf("%s/api/v1/services/aigc/text-generation/generation", meta.BaseURL)
// if meta.Mode == constant.RelayModeEmbeddings {
// fullRequestURL = fmt.Sprintf("%s/api/v1/services/embeddings/text-embedding/text-embedding", meta.BaseURL)
// }
// return fullRequestURL, nil
// }

// func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *util.RelayMeta) error {
// channel.SetupCommonRequestHeader(c, req, meta)
// req.Header.Set("Authorization", "Bearer "+meta.APIKey)
// if meta.IsStream {
// req.Header.Set("X-DashScope-SSE", "enable")
// }
// if c.GetString(common.ConfigKeyPlugin) != "" {
// req.Header.Set("X-DashScope-Plugin", c.GetString(common.ConfigKeyPlugin))
// }
// return nil
// }

// func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) {
// if request == nil {
// return nil, errors.New("request is nil")
// }
// switch relayMode {
// case constant.RelayModeEmbeddings:
// baiduEmbeddingRequest := ConvertEmbeddingRequest(*request)
// return baiduEmbeddingRequest, nil
// default:
// baiduRequest := ConvertRequest(*request)
// return baiduRequest, nil
// }
// }

// func (a *Adaptor) DoRequest(c *gin.Context, meta *util.RelayMeta, requestBody io.Reader) (*http.Response, error) {
// return channel.DoRequestHelper(a, c, meta, requestBody)
// }

// func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.RelayMeta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
// if meta.IsStream {
// err, usage = StreamHandler(c, resp)
// } else {
// switch meta.Mode {
// case constant.RelayModeEmbeddings:
// err, usage = EmbeddingHandler(c, resp)
// default:
// err, usage = Handler(c, resp)
// }
// }
// return
// }

// func (a *Adaptor) GetModelList() []string {
// return ModelList
// }

// func (a *Adaptor) GetChannelName() string {
// return "ali"
// }
@@ -1,71 +0,0 @@
package ali

// type Message struct {
// Content string `json:"content"`
// Role string `json:"role"`
// }

// type Input struct {
// //Prompt string `json:"prompt"`
// Messages []Message `json:"messages"`
// }

// type Parameters struct {
// TopP float64 `json:"top_p,omitempty"`
// TopK int `json:"top_k,omitempty"`
// Seed uint64 `json:"seed,omitempty"`
// EnableSearch bool `json:"enable_search,omitempty"`
// IncrementalOutput bool `json:"incremental_output,omitempty"`
// }

// type ChatRequest struct {
// Model string `json:"model"`
// Input Input `json:"input"`
// Parameters Parameters `json:"parameters,omitempty"`
// }

// type EmbeddingRequest struct {
// Model string `json:"model"`
// Input struct {
// Texts []string `json:"texts"`
// } `json:"input"`
// Parameters *struct {
// TextType string `json:"text_type,omitempty"`
// } `json:"parameters,omitempty"`
// }

// type Embedding struct {
// Embedding []float64 `json:"embedding"`
// TextIndex int `json:"text_index"`
// }

// type EmbeddingResponse struct {
// Output struct {
// Embeddings []Embedding `json:"embeddings"`
// } `json:"output"`
// Usage Usage `json:"usage"`
// Error
// }

// type Error struct {
// Code string `json:"code"`
// Message string `json:"message"`
// RequestId string `json:"request_id"`
// }

// type Usage struct {
// InputTokens int `json:"input_tokens"`
// OutputTokens int `json:"output_tokens"`
// TotalTokens int `json:"total_tokens"`
// }

// type Output struct {
// Text string `json:"text"`
// FinishReason string `json:"finish_reason"`
// }

// type ChatResponse struct {
// Output Output `json:"output"`
// Usage Usage `json:"usage"`
// Error
// }
@@ -1,13 +0,0 @@
package baidu

var ModelList = []string{
"ERNIE-Bot-4",
"ERNIE-Bot-8K",
"ERNIE-Bot",
"ERNIE-Speed",
"ERNIE-Bot-turbo",
"Embedding-V1",
"bge-large-zh",
"bge-large-en",
"tao-8k",
}
@@ -1,20 +0,0 @@
package channel

import (
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/relay/model"
"github.com/songquanpeng/one-api/relay/util"
"io"
"net/http"
)

type Adaptor interface {
Init(meta *util.RelayMeta)
GetRequestURL(meta *util.RelayMeta) (string, error)
SetupRequestHeader(c *gin.Context, req *http.Request, meta *util.RelayMeta) error
ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error)
DoRequest(c *gin.Context, meta *util.RelayMeta, requestBody io.Reader) (*http.Response, error)
DoResponse(c *gin.Context, resp *http.Response, meta *util.RelayMeta) (usage *model.Usage, err *model.ErrorWithStatusCode)
GetModelList() []string
GetChannelName() string
}
@@ -1,16 +0,0 @@
package minimax

import (
"fmt"

"github.com/Laisky/errors/v2"
"github.com/songquanpeng/one-api/relay/constant"
"github.com/songquanpeng/one-api/relay/util"
)

func GetRequestURL(meta *util.RelayMeta) (string, error) {
if meta.Mode == constant.RelayModeChatCompletions {
return fmt.Sprintf("%s/v1/text/chatcompletion_v2", meta.BaseURL), nil
}
return "", errors.Errorf("unsupported relay mode %d for minimax", meta.Mode)
}
@@ -1,46 +0,0 @@
package openai

import (
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/relay/channel/ai360"
"github.com/songquanpeng/one-api/relay/channel/baichuan"
"github.com/songquanpeng/one-api/relay/channel/groq"
"github.com/songquanpeng/one-api/relay/channel/lingyiwanwu"
"github.com/songquanpeng/one-api/relay/channel/minimax"
"github.com/songquanpeng/one-api/relay/channel/mistral"
"github.com/songquanpeng/one-api/relay/channel/moonshot"
)

var CompatibleChannels = []int{
common.ChannelTypeAzure,
common.ChannelType360,
common.ChannelTypeMoonshot,
common.ChannelTypeBaichuan,
common.ChannelTypeMinimax,
common.ChannelTypeMistral,
common.ChannelTypeGroq,
common.ChannelTypeLingYiWanWu,
}

func GetCompatibleChannelMeta(channelType int) (string, []string) {
switch channelType {
case common.ChannelTypeAzure:
return "azure", ModelList
case common.ChannelType360:
return "360", ai360.ModelList
case common.ChannelTypeMoonshot:
return "moonshot", moonshot.ModelList
case common.ChannelTypeBaichuan:
return "baichuan", baichuan.ModelList
case common.ChannelTypeMinimax:
return "minimax", minimax.ModelList
case common.ChannelTypeMistral:
return "mistralai", mistral.ModelList
case common.ChannelTypeGroq:
return "groq", groq.ModelList
case common.ChannelTypeLingYiWanWu:
return "lingyiwanwu", lingyiwanwu.ModelList
default:
return "openai", ModelList
}
}
@@ -1,11 +0,0 @@
package openai

import "github.com/songquanpeng/one-api/relay/model"

func ResponseText2Usage(responseText string, modeName string, promptTokens int) *model.Usage {
usage := &model.Usage{}
usage.PromptTokens = promptTokens
usage.CompletionTokens = CountTokenText(responseText, modeName)
usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
return usage
}
@@ -1,62 +0,0 @@
package zhipu

// import (
// "github.com/Laisky/errors/v2"
// "fmt"
// "github.com/gin-gonic/gin"
// "github.com/songquanpeng/one-api/relay/channel"
// "github.com/songquanpeng/one-api/relay/model"
// "github.com/songquanpeng/one-api/relay/util"
// "io"
// "net/http"
// )

// type Adaptor struct {
// }

// func (a *Adaptor) Init(meta *util.RelayMeta) {

// }

// func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) {
// method := "invoke"
// if meta.IsStream {
// method = "sse-invoke"
// }
// return fmt.Sprintf("%s/api/paas/v3/model-api/%s/%s", meta.BaseURL, meta.ActualModelName, method), nil
// }

// func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *util.RelayMeta) error {
// channel.SetupCommonRequestHeader(c, req, meta)
// token := GetToken(meta.APIKey)
// req.Header.Set("Authorization", token)
// return nil
// }

// func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) {
// if request == nil {
// return nil, errors.New("request is nil")
// }
// return ConvertRequest(*request), nil
// }

// func (a *Adaptor) DoRequest(c *gin.Context, meta *util.RelayMeta, requestBody io.Reader) (*http.Response, error) {
// return channel.DoRequestHelper(a, c, meta, requestBody)
// }

// func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.RelayMeta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
// if meta.IsStream {
// err, usage = StreamHandler(c, resp)
// } else {
// err, usage = Handler(c, resp)
// }
// return
// }

// func (a *Adaptor) GetModelList() []string {
// return ModelList
// }

// func (a *Adaptor) GetChannelName() string {
// return "zhipu"
// }
39
relay/channeltype/define.go
Normal file
@@ -0,0 +1,39 @@
package channeltype

const (
Unknown = iota
OpenAI
API2D
Azure
CloseAI
OpenAISB
OpenAIMax
OhMyGPT
Custom
Ails
AIProxy
PaLM
API2GPT
AIGC2D
Anthropic
Baidu
Zhipu
Ali
Xunfei
AI360
OpenRouter
AIProxyLibrary
FastGPT
Tencent
Gemini
Moonshot
Baichuan
Minimax
Mistral
Groq
Ollama
LingYiWanWu
StepFun

Dummy
)
30
relay/channeltype/helper.go
Normal file
@@ -0,0 +1,30 @@
package channeltype

import "github.com/songquanpeng/one-api/relay/apitype"

func ToAPIType(channelType int) int {
apiType := apitype.OpenAI
switch channelType {
case Anthropic:
apiType = apitype.Anthropic
case Baidu:
apiType = apitype.Baidu
case PaLM:
apiType = apitype.PaLM
case Zhipu:
apiType = apitype.Zhipu
case Ali:
apiType = apitype.Ali
case Xunfei:
apiType = apitype.Xunfei
case AIProxyLibrary:
apiType = apitype.AIProxyLibrary
case Tencent:
apiType = apitype.Tencent
case Gemini:
apiType = apitype.Gemini
case Ollama:
apiType = apitype.Ollama
}
return apiType
}
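Every channel not listed in the switch falls through to apitype.OpenAI, which is what routes the OpenAI-compatible channels (Azure, Groq, Moonshot, and the rest of CompatibleChannels) into the openai adaptor; a short sketch:

_ = channeltype.ToAPIType(channeltype.Groq)      // apitype.OpenAI: compatible channel, openai adaptor
_ = channeltype.ToAPIType(channeltype.Anthropic) // apitype.Anthropic: dedicated adaptor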
43
relay/channeltype/url.go
Normal file
@@ -0,0 +1,43 @@
package channeltype

var ChannelBaseURLs = []string{
"", // 0
"https://api.openai.com", // 1
"https://oa.api2d.net", // 2
"", // 3
"https://api.closeai-proxy.xyz", // 4
"https://api.openai-sb.com", // 5
"https://api.openaimax.com", // 6
"https://api.ohmygpt.com", // 7
"", // 8
"https://api.caipacity.com", // 9
"https://api.aiproxy.io", // 10
"https://generativelanguage.googleapis.com", // 11
"https://api.api2gpt.com", // 12
"https://api.aigc2d.com", // 13
"https://api.anthropic.com", // 14
"https://aip.baidubce.com", // 15
"https://open.bigmodel.cn", // 16
"https://dashscope.aliyuncs.com", // 17
"", // 18
"https://ai.360.cn", // 19
"https://openrouter.ai/api", // 20
"https://api.aiproxy.io", // 21
"https://fastgpt.run/api/openapi", // 22
"https://hunyuan.cloud.tencent.com", // 23
"https://generativelanguage.googleapis.com", // 24
"https://api.moonshot.cn", // 25
"https://api.baichuan-ai.com", // 26
"https://api.minimax.chat", // 27
"https://api.mistral.ai", // 28
"https://api.groq.com/openai", // 29
"http://localhost:11434", // 30
"https://api.lingyiwanwu.com", // 31
"https://api.stepfun.com", // 32
}

func init() {
if len(ChannelBaseURLs) != Dummy {
panic("channel base urls length not match")
}
}
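The init guard ties this slice to the Dummy sentinel in define.go, so appending a channel constant without a matching base URL fails at startup instead of at request time. A test sketch of the same invariant (hypothetical, not part of the commit):

func TestChannelBaseURLsLength(t *testing.T) {
	// Mirrors the init check: every channel type up to Dummy needs a URL slot.
	if len(channeltype.ChannelBaseURLs) != channeltype.Dummy {
		t.Fatalf("got %d base urls, want %d", len(channeltype.ChannelBaseURLs), channeltype.Dummy)
	}
}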
@@ -1,4 +1,4 @@
package util
package client

import (
"net/http"
@@ -1,48 +0,0 @@
package constant

import (
"github.com/songquanpeng/one-api/common"
)

const (
APITypeOpenAI = iota
APITypeAnthropic
APITypePaLM
APITypeBaidu
APITypeZhipu
APITypeAli
APITypeXunfei
APITypeAIProxyLibrary
APITypeTencent
APITypeGemini
APITypeOllama

APITypeDummy // this one is only for count, do not add any channel after this
)

func ChannelType2APIType(channelType int) int {
apiType := APITypeOpenAI
switch channelType {
case common.ChannelTypeAnthropic:
apiType = APITypeAnthropic
case common.ChannelTypeBaidu:
apiType = APITypeBaidu
case common.ChannelTypePaLM:
apiType = APITypePaLM
case common.ChannelTypeZhipu:
apiType = APITypeZhipu
case common.ChannelTypeAli:
apiType = APITypeAli
case common.ChannelTypeXunfei:
apiType = APITypeXunfei
case common.ChannelTypeAIProxyLibrary:
apiType = APITypeAIProxyLibrary
case common.ChannelTypeTencent:
apiType = APITypeTencent
case common.ChannelTypeGemini:
apiType = APITypeGemini
case common.ChannelTypeOllama:
apiType = APITypeOllama
}
return apiType
}
@@ -1,24 +0,0 @@
package constant

var DalleSizeRatios = map[string]map[string]float64{
"dall-e-2": {
"256x256": 1,
"512x512": 1.125,
"1024x1024": 1.25,
},
"dall-e-3": {
"1024x1024": 1,
"1024x1792": 2,
"1792x1024": 2,
},
}

var DalleGenerationImageAmounts = map[string][2]int{
"dall-e-2": {1, 10},
"dall-e-3": {1, 1}, // OpenAI allows n=1 currently.
}

var DalleImagePromptLengthLimitations = map[string]int{
"dall-e-2": 1000,
"dall-e-3": 4000,
}
@@ -1,42 +0,0 @@
package constant

import "strings"

const (
RelayModeUnknown = iota
RelayModeChatCompletions
RelayModeCompletions
RelayModeEmbeddings
RelayModeModerations
RelayModeImagesGenerations
RelayModeEdits
RelayModeAudioSpeech
RelayModeAudioTranscription
RelayModeAudioTranslation
)

func Path2RelayMode(path string) int {
relayMode := RelayModeUnknown
if strings.HasPrefix(path, "/v1/chat/completions") {
relayMode = RelayModeChatCompletions
} else if strings.HasPrefix(path, "/v1/completions") {
relayMode = RelayModeCompletions
} else if strings.HasPrefix(path, "/v1/embeddings") {
relayMode = RelayModeEmbeddings
} else if strings.HasSuffix(path, "embeddings") {
relayMode = RelayModeEmbeddings
} else if strings.HasPrefix(path, "/v1/moderations") {
relayMode = RelayModeModerations
} else if strings.HasPrefix(path, "/v1/images/generations") {
relayMode = RelayModeImagesGenerations
} else if strings.HasPrefix(path, "/v1/edits") {
relayMode = RelayModeEdits
} else if strings.HasPrefix(path, "/v1/audio/speech") {
relayMode = RelayModeAudioSpeech
} else if strings.HasPrefix(path, "/v1/audio/transcriptions") {
relayMode = RelayModeAudioTranscription
} else if strings.HasPrefix(path, "/v1/audio/translations") {
relayMode = RelayModeAudioTranslation
}
return relayMode
}
@@ -6,20 +6,23 @@ import (
	"context"
	"encoding/json"
	"fmt"
	"io"
	"net/http"
	"strings"

	"github.com/Laisky/errors/v2"
	"github.com/gin-gonic/gin"
	"github.com/songquanpeng/one-api/common"
	"github.com/songquanpeng/one-api/common/config"
	"github.com/songquanpeng/one-api/common/logger"
	"github.com/songquanpeng/one-api/model"
	"github.com/songquanpeng/one-api/relay/channel/openai"
	"github.com/songquanpeng/one-api/relay/constant"
	"github.com/songquanpeng/one-api/relay/adaptor/azure"
	"github.com/songquanpeng/one-api/relay/adaptor/openai"
	"github.com/songquanpeng/one-api/relay/billing"
	billingratio "github.com/songquanpeng/one-api/relay/billing/ratio"
	"github.com/songquanpeng/one-api/relay/channeltype"
	"github.com/songquanpeng/one-api/relay/client"
	relaymodel "github.com/songquanpeng/one-api/relay/model"
	"github.com/songquanpeng/one-api/relay/util"
	"github.com/songquanpeng/one-api/relay/relaymode"
	"io"
	"net/http"
	"strings"
)

func RelayAudioHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatusCode {
@@ -34,7 +37,7 @@ func RelayAudioHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus
	tokenName := c.GetString("token_name")

	var ttsRequest openai.TextToSpeechRequest
	if relayMode == constant.RelayModeAudioSpeech {
	if relayMode == relaymode.AudioSpeech {
		// Read JSON
		err := common.UnmarshalBodyReusable(c, &ttsRequest)
		// Check if JSON is valid
@@ -48,14 +51,15 @@ func RelayAudioHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus
		}
	}

	modelRatio := common.GetModelRatio(audioModel)
	// groupRatio := common.GetGroupRatio(group)
	groupRatio := c.GetFloat64("channel_ratio")
	modelRatio := billingratio.GetModelRatio(audioModel)
	// groupRatio := billingratio.GetGroupRatio(group)
	groupRatio := c.GetFloat64("channel_ratio") // get minimal ratio from multiple groups

	ratio := modelRatio * groupRatio
	var quota int64
	var preConsumedQuota int64
	switch relayMode {
	case constant.RelayModeAudioSpeech:
	case relaymode.AudioSpeech:
		preConsumedQuota = int64(float64(len(ttsRequest.Input)) * ratio)
		quota = preConsumedQuota
	default:
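For the speech branch above, pre-consumed quota is proportional to the TTS input length. A worked example with hypothetical numbers, not values from the commit:

	// len(ttsRequest.Input) = 200 characters, modelRatio = 5, groupRatio = 0.8
	ratio := 5.0 * 0.8                              // 4.0
	preConsumedQuota := int64(float64(200) * ratio) // 800 quota units charged up front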
@@ -117,19 +121,19 @@ func RelayAudioHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus
		}
	}

	baseURL := common.ChannelBaseURLs[channelType]
	baseURL := channeltype.ChannelBaseURLs[channelType]
	requestURL := c.Request.URL.String()
	if c.GetString("base_url") != "" {
		baseURL = c.GetString("base_url")
	}

	fullRequestURL := util.GetFullRequestURL(baseURL, requestURL, channelType)
	if channelType == common.ChannelTypeAzure {
		apiVersion := util.GetAzureAPIVersion(c)
		if relayMode == constant.RelayModeAudioTranscription {
	fullRequestURL := openai.GetFullRequestURL(baseURL, requestURL, channelType)
	if channelType == channeltype.Azure {
		apiVersion := azure.GetAPIVersion(c)
		if relayMode == relaymode.AudioTranscription {
			// https://learn.microsoft.com/en-us/azure/ai-services/openai/whisper-quickstart?tabs=command-line#rest-api
			fullRequestURL = fmt.Sprintf("%s/openai/deployments/%s/audio/transcriptions?api-version=%s", baseURL, audioModel, apiVersion)
		} else if relayMode == constant.RelayModeAudioSpeech {
		} else if relayMode == relaymode.AudioSpeech {
			// https://learn.microsoft.com/en-us/azure/ai-services/openai/text-to-speech-quickstart?tabs=command-line#rest-api
			fullRequestURL = fmt.Sprintf("%s/openai/deployments/%s/audio/speech?api-version=%s", baseURL, audioModel, apiVersion)
		}
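With hypothetical values, the speech branch above builds an Azure deployment URL like this:

	fullRequestURL := fmt.Sprintf("%s/openai/deployments/%s/audio/speech?api-version=%s",
		"https://myresource.openai.azure.com", "tts-1", "2024-02-01")
	// => https://myresource.openai.azure.com/openai/deployments/tts-1/audio/speech?api-version=2024-02-01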
@@ -148,7 +152,7 @@ func RelayAudioHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus
		return openai.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError)
	}

	if (relayMode == constant.RelayModeAudioTranscription || relayMode == constant.RelayModeAudioSpeech) && channelType == common.ChannelTypeAzure {
	if (relayMode == relaymode.AudioTranscription || relayMode == relaymode.AudioSpeech) && channelType == channeltype.Azure {
		// https://learn.microsoft.com/en-us/azure/ai-services/openai/whisper-quickstart?tabs=command-line#rest-api
		apiKey := c.Request.Header.Get("Authorization")
		apiKey = strings.TrimPrefix(apiKey, "Bearer ")
@@ -160,7 +164,7 @@ func RelayAudioHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus
	req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
	req.Header.Set("Accept", c.Request.Header.Get("Accept"))

	resp, err := util.HTTPClient.Do(req)
	resp, err := client.HTTPClient.Do(req)
	if err != nil {
		return openai.ErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)
	}
@@ -174,7 +178,7 @@ func RelayAudioHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus
		return openai.ErrorWrapper(err, "close_request_body_failed", http.StatusInternalServerError)
	}

	if relayMode != constant.RelayModeAudioSpeech {
	if relayMode != relaymode.AudioSpeech {
		responseBody, err := io.ReadAll(resp.Body)
		if err != nil {
			return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
@@ -213,12 +217,12 @@ func RelayAudioHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus
		resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
	}
	if resp.StatusCode != http.StatusOK {
		return util.RelayErrorHandler(resp)
		return RelayErrorHandler(resp)
	}
	succeed = true
	quotaDelta := quota - preConsumedQuota
	defer func(ctx context.Context) {
		go util.PostConsumeQuota(ctx, tokenId, quotaDelta, quota, userId, channelId, modelRatio, groupRatio, audioModel, tokenName)
		go billing.PostConsumeQuota(ctx, tokenId, quotaDelta, quota, userId, channelId, modelRatio, groupRatio, audioModel, tokenName)
	}(c.Request.Context())

	for k, v := range resp.Header {
91
relay/controller/error.go
Normal file
@@ -0,0 +1,91 @@
package controller

import (
	"encoding/json"
	"fmt"
	"github.com/songquanpeng/one-api/common/config"
	"github.com/songquanpeng/one-api/common/logger"
	"github.com/songquanpeng/one-api/relay/model"
	"io"
	"net/http"
	"strconv"
)

type GeneralErrorResponse struct {
	Error    model.Error `json:"error"`
	Message  string      `json:"message"`
	Msg      string      `json:"msg"`
	Err      string      `json:"err"`
	ErrorMsg string      `json:"error_msg"`
	Header   struct {
		Message string `json:"message"`
	} `json:"header"`
	Response struct {
		Error struct {
			Message string `json:"message"`
		} `json:"error"`
	} `json:"response"`
}

func (e GeneralErrorResponse) ToMessage() string {
	if e.Error.Message != "" {
		return e.Error.Message
	}
	if e.Message != "" {
		return e.Message
	}
	if e.Msg != "" {
		return e.Msg
	}
	if e.Err != "" {
		return e.Err
	}
	if e.ErrorMsg != "" {
		return e.ErrorMsg
	}
	if e.Header.Message != "" {
		return e.Header.Message
	}
	if e.Response.Error.Message != "" {
		return e.Response.Error.Message
	}
	return ""
}

func RelayErrorHandler(resp *http.Response) (ErrorWithStatusCode *model.ErrorWithStatusCode) {
	ErrorWithStatusCode = &model.ErrorWithStatusCode{
		StatusCode: resp.StatusCode,
		Error: model.Error{
			Message: "",
			Type:    "upstream_error",
			Code:    "bad_response_status_code",
			Param:   strconv.Itoa(resp.StatusCode),
		},
	}
	responseBody, err := io.ReadAll(resp.Body)
	if err != nil {
		return
	}
	if config.DebugEnabled {
		logger.SysLog(fmt.Sprintf("error happened, status code: %d, response: \n%s", resp.StatusCode, string(responseBody)))
	}
	err = resp.Body.Close()
	if err != nil {
		return
	}
	var errResponse GeneralErrorResponse
	err = json.Unmarshal(responseBody, &errResponse)
	if err != nil {
		return
	}
	if errResponse.Error.Message != "" {
		// OpenAI format error, so we override the default one
		ErrorWithStatusCode.Error = errResponse.Error
	} else {
		ErrorWithStatusCode.Error.Message = errResponse.ToMessage()
	}
	if ErrorWithStatusCode.Error.Message == "" {
		ErrorWithStatusCode.Error.Message = fmt.Sprintf("bad response status code %d", resp.StatusCode)
	}
	return
}
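GeneralErrorResponse normalizes several upstream error shapes into one message. A quick illustration with a hypothetical non-OpenAI payload:

	var e GeneralErrorResponse
	_ = json.Unmarshal([]byte(`{"error_msg": "invalid access token"}`), &e)
	fmt.Println(e.ToMessage()) // "invalid access token" (falls through to the error_msg field)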
@@ -9,10 +9,13 @@ import (
	"github.com/songquanpeng/one-api/common/config"
	"github.com/songquanpeng/one-api/common/logger"
	"github.com/songquanpeng/one-api/model"
	"github.com/songquanpeng/one-api/relay/channel/openai"
	"github.com/songquanpeng/one-api/relay/constant"
	"github.com/songquanpeng/one-api/relay/adaptor/openai"
	billingratio "github.com/songquanpeng/one-api/relay/billing/ratio"
	"github.com/songquanpeng/one-api/relay/channeltype"
	"github.com/songquanpeng/one-api/relay/controller/validator"
	"github.com/songquanpeng/one-api/relay/meta"
	relaymodel "github.com/songquanpeng/one-api/relay/model"
	"github.com/songquanpeng/one-api/relay/util"
	"github.com/songquanpeng/one-api/relay/relaymode"
	"math"
	"net/http"
)
@@ -23,21 +26,21 @@ func getAndValidateTextRequest(c *gin.Context, relayMode int) (*relaymodel.Gener
	if err != nil {
		return nil, err
	}
	if relayMode == constant.RelayModeModerations && textRequest.Model == "" {
	if relayMode == relaymode.Moderations && textRequest.Model == "" {
		textRequest.Model = "text-moderation-latest"
	}
	if relayMode == constant.RelayModeEmbeddings && textRequest.Model == "" {
	if relayMode == relaymode.Embeddings && textRequest.Model == "" {
		textRequest.Model = c.Param("model")
	}
	err = util.ValidateTextRequest(textRequest, relayMode)
	err = validator.ValidateTextRequest(textRequest, relayMode)
	if err != nil {
		return nil, err
	}
	return textRequest, nil
}

func getImageRequest(c *gin.Context, relayMode int) (*openai.ImageRequest, error) {
	imageRequest := &openai.ImageRequest{}
func getImageRequest(c *gin.Context, relayMode int) (*relaymodel.ImageRequest, error) {
	imageRequest := &relaymodel.ImageRequest{}
	err := common.UnmarshalBodyReusable(c, imageRequest)
	if err != nil {
		return nil, err
@@ -54,9 +57,25 @@ func getImageRequest(c *gin.Context, relayMode int) (*openai.ImageRequest, error
	return imageRequest, nil
}

func validateImageRequest(imageRequest *openai.ImageRequest, meta *util.RelayMeta) *relaymodel.ErrorWithStatusCode {
func isValidImageSize(model string, size string) bool {
	if model == "cogview-3" {
		return true
	}
	_, ok := billingratio.ImageSizeRatios[model][size]
	return ok
}

func getImageSizeRatio(model string, size string) float64 {
	ratio, ok := billingratio.ImageSizeRatios[model][size]
	if !ok {
		return 1
	}
	return ratio
}

func validateImageRequest(imageRequest *relaymodel.ImageRequest, meta *meta.Meta) *relaymodel.ErrorWithStatusCode {
	// model validation
	_, hasValidSize := constant.DalleSizeRatios[imageRequest.Model][imageRequest.Size]
	hasValidSize := isValidImageSize(imageRequest.Model, imageRequest.Size)
	if !hasValidSize {
		return openai.ErrorWrapper(errors.New("size not supported for this image model"), "size_not_supported", http.StatusBadRequest)
	}
@@ -64,27 +83,24 @@ func validateImageRequest(imageRequest *openai.ImageRequest, meta *util.RelayMet
	if imageRequest.Prompt == "" {
		return openai.ErrorWrapper(errors.New("prompt is required"), "prompt_missing", http.StatusBadRequest)
	}
	if len(imageRequest.Prompt) > constant.DalleImagePromptLengthLimitations[imageRequest.Model] {
	if len(imageRequest.Prompt) > billingratio.ImagePromptLengthLimitations[imageRequest.Model] {
		return openai.ErrorWrapper(errors.New("prompt is too long"), "prompt_too_long", http.StatusBadRequest)
	}
	// Number of generated images validation
	if !isWithinRange(imageRequest.Model, imageRequest.N) {
		// channel not azure
		if meta.ChannelType != common.ChannelTypeAzure {
		if meta.ChannelType != channeltype.Azure {
			return openai.ErrorWrapper(errors.New("invalid value of n"), "n_not_within_range", http.StatusBadRequest)
		}
	}
	return nil
}

func getImageCostRatio(imageRequest *openai.ImageRequest) (float64, error) {
func getImageCostRatio(imageRequest *relaymodel.ImageRequest) (float64, error) {
	if imageRequest == nil {
		return 0, errors.New("imageRequest is nil")
	}
	imageCostRatio, hasValidSize := constant.DalleSizeRatios[imageRequest.Model][imageRequest.Size]
	if !hasValidSize {
		return 0, errors.Errorf("size not supported for this image model: %s", imageRequest.Size)
	}
	imageCostRatio := getImageSizeRatio(imageRequest.Model, imageRequest.Size)
	if imageRequest.Quality == "hd" && imageRequest.Model == "dall-e-3" {
		if imageRequest.Size == "1024x1024" {
			imageCostRatio *= 2
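Worked cost example for the branch above, assuming billingratio.ImageSizeRatios carries the same values as the deleted DalleSizeRatios table:

	ratio := getImageSizeRatio("dall-e-3", "1024x1024") // 1.0
	ratio *= 2                                          // hd quality at 1024x1024 doubles the cost
	// ratio == 2.0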
@@ -97,11 +113,11 @@ func getImageCostRatio(imageRequest *openai.ImageRequest) (float64, error) {

func getPromptTokens(textRequest *relaymodel.GeneralOpenAIRequest, relayMode int) int {
	switch relayMode {
	case constant.RelayModeChatCompletions:
	case relaymode.ChatCompletions:
		return openai.CountTokenMessages(textRequest.Messages, textRequest.Model)
	case constant.RelayModeCompletions:
	case relaymode.Completions:
		return openai.CountTokenInput(textRequest.Prompt, textRequest.Model)
	case constant.RelayModeModerations:
	case relaymode.Moderations:
		return openai.CountTokenInput(textRequest.Input, textRequest.Model)
	}
	return 0
@@ -115,7 +131,7 @@ func getPreConsumedQuota(textRequest *relaymodel.GeneralOpenAIRequest, promptTok
	return int64(float64(preConsumedTokens) * ratio)
}

func preConsumeQuota(ctx context.Context, textRequest *relaymodel.GeneralOpenAIRequest, promptTokens int, ratio float64, meta *util.RelayMeta) (int64, *relaymodel.ErrorWithStatusCode) {
func preConsumeQuota(ctx context.Context, textRequest *relaymodel.GeneralOpenAIRequest, promptTokens int, ratio float64, meta *meta.Meta) (int64, *relaymodel.ErrorWithStatusCode) {
	preConsumedQuota := getPreConsumedQuota(textRequest, promptTokens, ratio)

	userQuota, err := model.CacheGetUserQuota(ctx, meta.UserId)
@@ -144,13 +160,13 @@ func preConsumeQuota(ctx context.Context, textRequest *relaymodel.GeneralOpenAIR
	return preConsumedQuota, nil
}

func postConsumeQuota(ctx context.Context, usage *relaymodel.Usage, meta *util.RelayMeta, textRequest *relaymodel.GeneralOpenAIRequest, ratio float64, preConsumedQuota int64, modelRatio float64, groupRatio float64) {
func postConsumeQuota(ctx context.Context, usage *relaymodel.Usage, meta *meta.Meta, textRequest *relaymodel.GeneralOpenAIRequest, ratio float64, preConsumedQuota int64, modelRatio float64, groupRatio float64) {
	if usage == nil {
		logger.Error(ctx, "usage is nil, which is unexpected")
		return
	}
	var quota int64
	completionRatio := common.GetCompletionRatio(textRequest.Model)
	completionRatio := billingratio.GetCompletionRatio(textRequest.Model)
	promptTokens := usage.PromptTokens
	completionTokens := usage.CompletionTokens
	quota = int64(math.Ceil((float64(promptTokens) + float64(completionTokens)*completionRatio) * ratio))
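Worked example for the quota formula above, with hypothetical numbers:

	// promptTokens = 1000, completionTokens = 500, completionRatio = 3, ratio = 2.5
	quota := int64(math.Ceil((1000 + 500*3.0) * 2.5)) // ceil(2500 * 2.5) = 6250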
@@ -178,3 +194,14 @@ func postConsumeQuota(ctx context.Context, usage *relaymodel.Usage, meta *util.R
	model.UpdateUserUsedQuotaAndRequestCount(meta.UserId, quota)
	model.UpdateChannelUsedQuota(meta.ChannelId, quota)
}

func getMappedModelName(modelName string, mapping map[string]string) (string, bool) {
	if mapping == nil {
		return modelName, false
	}
	mappedModelName := mapping[modelName]
	if mappedModelName != "" {
		return mappedModelName, true
	}
	return modelName, false
}
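Usage of the mapping helper, with a hypothetical channel model_mapping:

	mapping := map[string]string{"gpt-4": "gpt-4-0613"}
	name, ok := getMappedModelName("gpt-4", mapping)        // "gpt-4-0613", true
	name, ok = getMappedModelName("gpt-3.5-turbo", mapping) // unchanged: "gpt-3.5-turbo", false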
@@ -5,35 +5,33 @@ import (
	"context"
	"encoding/json"
	"fmt"
	"github.com/Laisky/errors/v2"
	"io"
	"net/http"
	"strings"

	"github.com/songquanpeng/one-api/common"
	"github.com/Laisky/errors/v2"
	"github.com/gin-gonic/gin"
	"github.com/songquanpeng/one-api/common/logger"
	"github.com/songquanpeng/one-api/model"
	"github.com/songquanpeng/one-api/relay/channel/openai"
	"github.com/songquanpeng/one-api/relay/constant"
	"github.com/songquanpeng/one-api/relay"
	"github.com/songquanpeng/one-api/relay/adaptor/openai"
	billingratio "github.com/songquanpeng/one-api/relay/billing/ratio"
	"github.com/songquanpeng/one-api/relay/channeltype"
	"github.com/songquanpeng/one-api/relay/meta"
	relaymodel "github.com/songquanpeng/one-api/relay/model"
	"github.com/songquanpeng/one-api/relay/util"

	"github.com/gin-gonic/gin"
)

func isWithinRange(element string, value int) bool {
	if _, ok := constant.DalleGenerationImageAmounts[element]; !ok {
	if _, ok := billingratio.ImageGenerationAmounts[element]; !ok {
		return false
	}
	min := constant.DalleGenerationImageAmounts[element][0]
	max := constant.DalleGenerationImageAmounts[element][1]

	min := billingratio.ImageGenerationAmounts[element][0]
	max := billingratio.ImageGenerationAmounts[element][1]
	return value >= min && value <= max
}
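Behavior of isWithinRange, assuming billingratio.ImageGenerationAmounts mirrors the deleted DalleGenerationImageAmounts table ({1, 10} for dall-e-2, {1, 1} for dall-e-3):

	isWithinRange("dall-e-2", 5) // true
	isWithinRange("dall-e-3", 2) // false, only n=1 is allowed
	isWithinRange("unknown", 1)  // false, model not in the table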
func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatusCode {
	ctx := c.Request.Context()
	meta := util.GetRelayMeta(c)
	meta := meta.GetByContext(c)
	imageRequest, err := getImageRequest(c, meta.Mode)
	if err != nil {
		logger.Errorf(ctx, "getImageRequest failed: %s", err.Error())
@@ -43,7 +41,7 @@ func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus
	// map model name
	var isModelMapped bool
	meta.OriginModelName = imageRequest.Model
	imageRequest.Model, isModelMapped = util.GetMappedModelName(imageRequest.Model, meta.ModelMapping)
	imageRequest.Model, isModelMapped = getMappedModelName(imageRequest.Model, meta.ModelMapping)
	meta.ActualModelName = imageRequest.Model

	// model validation
@@ -57,17 +55,8 @@ func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus
		return openai.ErrorWrapper(err, "get_image_cost_ratio_failed", http.StatusInternalServerError)
	}

	requestURL := c.Request.URL.String()
	fullRequestURL := util.GetFullRequestURL(meta.BaseURL, requestURL, meta.ChannelType)
	if meta.ChannelType == common.ChannelTypeAzure {
		// https://learn.microsoft.com/en-us/azure/ai-services/openai/dall-e-quickstart?tabs=dalle3%2Ccommand-line&pivots=rest-api
		apiVersion := util.GetAzureAPIVersion(c)
		// https://{resource_name}.openai.azure.com/openai/deployments/dall-e-3/images/generations?api-version=2024-03-01-preview
		fullRequestURL = fmt.Sprintf("%s/openai/deployments/%s/images/generations?api-version=%s", meta.BaseURL, imageRequest.Model, apiVersion)
	}

	var requestBody io.Reader
	if isModelMapped || meta.ChannelType == common.ChannelTypeAzure { // make Azure channel request body
	if isModelMapped || meta.ChannelType == channeltype.Azure { // make Azure channel request body
		jsonStr, err := json.Marshal(imageRequest)
		if err != nil {
			return openai.ErrorWrapper(err, "marshal_image_request_failed", http.StatusInternalServerError)
@@ -77,9 +66,32 @@ func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus
		requestBody = c.Request.Body
	}

	modelRatio := common.GetModelRatio(imageRequest.Model)
	// groupRatio := common.GetGroupRatio(meta.Group)
	adaptor := relay.GetAdaptor(meta.APIType)
	if adaptor == nil {
		return openai.ErrorWrapper(fmt.Errorf("invalid api type: %d", meta.APIType), "invalid_api_type", http.StatusBadRequest)
	}

	switch meta.ChannelType {
	case channeltype.Ali:
		fallthrough
	case channeltype.Baidu:
		fallthrough
	case channeltype.Zhipu:
		finalRequest, err := adaptor.ConvertImageRequest(imageRequest)
		if err != nil {
			return openai.ErrorWrapper(err, "convert_image_request_failed", http.StatusInternalServerError)
		}
		jsonStr, err := json.Marshal(finalRequest)
		if err != nil {
			return openai.ErrorWrapper(err, "marshal_image_request_failed", http.StatusInternalServerError)
		}
		requestBody = bytes.NewBuffer(jsonStr)
	}

	modelRatio := billingratio.GetModelRatio(imageRequest.Model)
	// groupRatio := billingratio.GetGroupRatio(meta.Group)
	groupRatio := c.GetFloat64("channel_ratio") // pre-selected cheapest channel ratio

	ratio := modelRatio * groupRatio
	userQuota, err := model.CacheGetUserQuota(ctx, meta.UserId)

@@ -89,36 +101,13 @@ func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus
		return openai.ErrorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden)
	}

	req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
	if err != nil {
		return openai.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError)
	}
	token := c.Request.Header.Get("Authorization")
	if meta.ChannelType == common.ChannelTypeAzure { // Azure authentication
		token = strings.TrimPrefix(token, "Bearer ")
		req.Header.Set("api-key", token)
	} else {
		req.Header.Set("Authorization", token)
	}

	req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
	req.Header.Set("Accept", c.Request.Header.Get("Accept"))

	resp, err := util.HTTPClient.Do(req)
	// do request
	resp, err := adaptor.DoRequest(c, meta, requestBody)
	if err != nil {
		logger.Errorf(ctx, "DoRequest failed: %s", err.Error())
		return openai.ErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)
	}

	err = req.Body.Close()
	if err != nil {
		return openai.ErrorWrapper(err, "close_request_body_failed", http.StatusInternalServerError)
	}
	err = c.Request.Body.Close()
	if err != nil {
		return openai.ErrorWrapper(err, "close_request_body_failed", http.StatusInternalServerError)
	}
	var imageResponse openai.ImageResponse

	defer func(ctx context.Context) {
		if resp.StatusCode != http.StatusOK {
			return
@@ -141,34 +130,12 @@ func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus
		}
	}(c.Request.Context())

	responseBody, err := io.ReadAll(resp.Body)

	if err != nil {
		return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
	}
	err = resp.Body.Close()
	if err != nil {
		return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError)
	}
	err = json.Unmarshal(responseBody, &imageResponse)
	if err != nil {
		return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError)
	// do response
	_, respErr := adaptor.DoResponse(c, resp, meta)
	if respErr != nil {
		logger.Errorf(ctx, "respErr is not nil: %+v", respErr)
		return respErr
	}

	resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))

	for k, v := range resp.Header {
		c.Writer.Header().Set(k, v[0])
	}
	c.Writer.WriteHeader(resp.StatusCode)

	_, err = io.Copy(c.Writer, resp.Body)
	if err != nil {
		return openai.ErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError)
	}
	err = resp.Body.Close()
	if err != nil {
		return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError)
	}
	return nil
}
@@ -10,18 +10,20 @@ import (

	"github.com/Laisky/errors/v2"
	"github.com/gin-gonic/gin"
	"github.com/songquanpeng/one-api/common"
	"github.com/songquanpeng/one-api/common/logger"
	"github.com/songquanpeng/one-api/relay/channel/openai"
	"github.com/songquanpeng/one-api/relay/constant"
	"github.com/songquanpeng/one-api/relay/helper"
	"github.com/songquanpeng/one-api/relay"
	"github.com/songquanpeng/one-api/relay/adaptor/openai"
	"github.com/songquanpeng/one-api/relay/apitype"
	"github.com/songquanpeng/one-api/relay/billing"
	billingratio "github.com/songquanpeng/one-api/relay/billing/ratio"
	"github.com/songquanpeng/one-api/relay/channeltype"
	"github.com/songquanpeng/one-api/relay/meta"
	"github.com/songquanpeng/one-api/relay/model"
	"github.com/songquanpeng/one-api/relay/util"
)

func RelayTextHelper(c *gin.Context) *model.ErrorWithStatusCode {
	ctx := c.Request.Context()
	meta := util.GetRelayMeta(c)
	meta := meta.GetByContext(c)
	// get & validate textRequest
	textRequest, err := getAndValidateTextRequest(c, meta.Mode)
	if err != nil {
@@ -33,12 +35,13 @@ func RelayTextHelper(c *gin.Context) *model.ErrorWithStatusCode {
	// map model name
	var isModelMapped bool
	meta.OriginModelName = textRequest.Model
	textRequest.Model, isModelMapped = util.GetMappedModelName(textRequest.Model, meta.ModelMapping)
	textRequest.Model, isModelMapped = getMappedModelName(textRequest.Model, meta.ModelMapping)
	meta.ActualModelName = textRequest.Model
	// get model ratio & group ratio
	modelRatio := common.GetModelRatio(textRequest.Model)
	// groupRatio := common.GetGroupRatio(meta.Group)
	modelRatio := billingratio.GetModelRatio(textRequest.Model)
	// groupRatio := billingratio.GetGroupRatio(meta.Group)
	groupRatio := meta.ChannelRatio

	ratio := modelRatio * groupRatio
	// pre-consume quota
	promptTokens := getPromptTokens(textRequest, meta.Mode)
@@ -49,16 +52,16 @@ func RelayTextHelper(c *gin.Context) *model.ErrorWithStatusCode {
		return bizErr
	}

	adaptor := helper.GetAdaptor(meta.APIType)
	adaptor := relay.GetAdaptor(meta.APIType)
	if adaptor == nil {
		return openai.ErrorWrapper(errors.Errorf("invalid api type: %d", meta.APIType), "invalid_api_type", http.StatusBadRequest)
	}

	// get request body
	var requestBody io.Reader
	if meta.APIType == constant.APITypeOpenAI {
	if meta.APIType == apitype.OpenAI {
		// no need to convert request for openai
		shouldResetRequestBody := isModelMapped || meta.ChannelType == common.ChannelTypeBaichuan // frequency_penalty 0 is not acceptable for baichuan
		shouldResetRequestBody := isModelMapped || meta.ChannelType == channeltype.Baichuan // frequency_penalty 0 is not acceptable for baichuan
		if shouldResetRequestBody {
			jsonStr, err := json.Marshal(textRequest)
			if err != nil {
@@ -93,10 +96,10 @@ func RelayTextHelper(c *gin.Context) *model.ErrorWithStatusCode {
	}
	errorHappened := (resp.StatusCode != http.StatusOK) || (meta.IsStream && resp.Header.Get("Content-Type") == "application/json")
	if errorHappened {
		util.ReturnPreConsumedQuota(ctx, preConsumedQuota, meta.TokenId)
		billing.ReturnPreConsumedQuota(ctx, preConsumedQuota, meta.TokenId)
		logger.Error(ctx, fmt.Sprintf("relay text [%d] <- %q %q",
			resp.StatusCode, resp.Request.URL.String(), string(requestBodyBytes)))
		return util.RelayErrorHandler(resp)
		return RelayErrorHandler(resp)
	}
	meta.IsStream = meta.IsStream || strings.HasPrefix(resp.Header.Get("Content-Type"), "text/event-stream")

@@ -104,7 +107,7 @@ func RelayTextHelper(c *gin.Context) *model.ErrorWithStatusCode {
	usage, respErr := adaptor.DoResponse(c, resp, meta)
	if respErr != nil {
		logger.Errorf(ctx, "respErr is not nil: %+v", respErr)
		util.ReturnPreConsumedQuota(ctx, preConsumedQuota, meta.TokenId)
		billing.ReturnPreConsumedQuota(ctx, preConsumedQuota, meta.TokenId)
		return respErr
	}
	// post-consume quota
@@ -1,10 +1,11 @@
package util
package validator

import (
	"github.com/Laisky/errors/v2"
	"github.com/songquanpeng/one-api/relay/constant"
	"github.com/songquanpeng/one-api/relay/model"
	"math"

	"github.com/Laisky/errors/v2"
	"github.com/songquanpeng/one-api/relay/model"
	"github.com/songquanpeng/one-api/relay/relaymode"
)

func ValidateTextRequest(textRequest *model.GeneralOpenAIRequest, relayMode int) error {
@@ -15,20 +16,20 @@ func ValidateTextRequest(textRequest *model.GeneralOpenAIRequest, relayMode int)
		return errors.New("model is required")
	}
	switch relayMode {
	case constant.RelayModeCompletions:
	case relaymode.Completions:
		if textRequest.Prompt == "" {
			return errors.New("field prompt is required")
		}
	case constant.RelayModeChatCompletions:
	case relaymode.ChatCompletions:
		if textRequest.Messages == nil || len(textRequest.Messages) == 0 {
			return errors.New("field messages is required")
		}
	case constant.RelayModeEmbeddings:
	case constant.RelayModeModerations:
	case relaymode.Embeddings:
	case relaymode.Moderations:
		if textRequest.Input == "" {
			return errors.New("field input is required")
		}
	case constant.RelayModeEdits:
	case relaymode.Edits:
		if textRequest.Instruction == "" {
			return errors.New("field instruction is required")
		}
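Example rejection from the validator above, with a hypothetical request:

	req := &model.GeneralOpenAIRequest{Model: "gpt-3.5-turbo"} // no messages set
	err := validator.ValidateTextRequest(req, relaymode.ChatCompletions)
	// err != nil: "field messages is required"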
@@ -1,40 +0,0 @@
package helper

import (
	"github.com/songquanpeng/one-api/relay/channel"
	"github.com/songquanpeng/one-api/relay/channel/aiproxy"
	"github.com/songquanpeng/one-api/relay/channel/anthropic"
	"github.com/songquanpeng/one-api/relay/channel/gemini"
	"github.com/songquanpeng/one-api/relay/channel/ollama"
	"github.com/songquanpeng/one-api/relay/channel/openai"
	"github.com/songquanpeng/one-api/relay/channel/palm"
	"github.com/songquanpeng/one-api/relay/constant"
)

func GetAdaptor(apiType int) channel.Adaptor {
	switch apiType {
	case constant.APITypeAIProxyLibrary:
		return &aiproxy.Adaptor{}
	// case constant.APITypeAli:
	// return &ali.Adaptor{}
	case constant.APITypeAnthropic:
		return &anthropic.Adaptor{}
	// case constant.APITypeBaidu:
	// return &baidu.Adaptor{}
	case constant.APITypeGemini:
		return &gemini.Adaptor{}
	case constant.APITypeOpenAI:
		return &openai.Adaptor{}
	case constant.APITypePaLM:
		return &palm.Adaptor{}
	// case constant.APITypeTencent:
	// return &tencent.Adaptor{}
	// case constant.APITypeXunfei:
	// return &xunfei.Adaptor{}
	// case constant.APITypeZhipu:
	// return &zhipu.Adaptor{}
	case constant.APITypeOllama:
		return &ollama.Adaptor{}
	}
	return nil
}
@@ -1,13 +1,15 @@
package util
package meta

import (
	"github.com/gin-gonic/gin"
	"github.com/songquanpeng/one-api/common"
	"github.com/songquanpeng/one-api/relay/constant"
	"github.com/songquanpeng/one-api/common/config"
	"github.com/songquanpeng/one-api/relay/adaptor/azure"
	"github.com/songquanpeng/one-api/relay/channeltype"
	"github.com/songquanpeng/one-api/relay/relaymode"
	"strings"
)

type RelayMeta struct {
type Meta struct {
	Mode        int
	ChannelType int
	ChannelId   int
@@ -29,9 +31,9 @@ type RelayMeta struct {
	ChannelRatio float64
}

func GetRelayMeta(c *gin.Context) *RelayMeta {
	meta := RelayMeta{
		Mode: constant.Path2RelayMode(c.Request.URL.Path),
func GetByContext(c *gin.Context) *Meta {
	meta := Meta{
		Mode:        relaymode.GetByPath(c.Request.URL.Path),
		ChannelType: c.GetInt("channel"),
		ChannelId:   c.GetInt("channel_id"),
		TokenId:     c.GetInt("token_id"),
@@ -40,18 +42,18 @@ func GetRelayMeta(c *gin.Context) *RelayMeta {
		Group:          c.GetString("group"),
		ModelMapping:   c.GetStringMapString("model_mapping"),
		BaseURL:        c.GetString("base_url"),
		APIVersion:     c.GetString(common.ConfigKeyAPIVersion),
		APIVersion:     c.GetString(config.KeyAPIVersion),
		APIKey:         strings.TrimPrefix(c.Request.Header.Get("Authorization"), "Bearer "),
		Config:         nil,
		RequestURLPath: c.Request.URL.String(),
		ChannelRatio:   c.GetFloat64("channel_ratio"),
	}
	if meta.ChannelType == common.ChannelTypeAzure {
		meta.APIVersion = GetAzureAPIVersion(c)
	if meta.ChannelType == channeltype.Azure {
		meta.APIVersion = azure.GetAPIVersion(c)
	}
	if meta.BaseURL == "" {
		meta.BaseURL = common.ChannelBaseURLs[meta.ChannelType]
		meta.BaseURL = channeltype.ChannelBaseURLs[meta.ChannelType]
	}
	meta.APIType = constant.ChannelType2APIType(meta.ChannelType)
	meta.APIType = channeltype.ToAPIType(meta.ChannelType)
	return &meta
}
12
relay/model/image.go
Normal file
@@ -0,0 +1,12 @@
package model

type ImageRequest struct {
	Model          string `json:"model"`
	Prompt         string `json:"prompt" binding:"required"`
	N              int    `json:"n,omitempty"`
	Size           string `json:"size,omitempty"`
	Quality        string `json:"quality,omitempty"`
	ResponseFormat string `json:"response_format,omitempty"`
	Style          string `json:"style,omitempty"`
	User           string `json:"user,omitempty"`
}
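A hypothetical request body that unmarshals into this struct:

	// {"model": "dall-e-3", "prompt": "a watercolor fox", "n": 1, "size": "1024x1024", "quality": "hd"}
	// -> ImageRequest{Model: "dall-e-3", Prompt: "a watercolor fox", N: 1, Size: "1024x1024", Quality: "hd"}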
14
relay/relaymode/define.go
Normal file
@@ -0,0 +1,14 @@
package relaymode

const (
	Unknown = iota
	ChatCompletions
	Completions
	Embeddings
	Moderations
	ImagesGenerations
	Edits
	AudioSpeech
	AudioTranscription
	AudioTranslation
)
29
relay/relaymode/helper.go
Normal file
@@ -0,0 +1,29 @@
package relaymode

import "strings"

func GetByPath(path string) int {
	relayMode := Unknown
	if strings.HasPrefix(path, "/v1/chat/completions") {
		relayMode = ChatCompletions
	} else if strings.HasPrefix(path, "/v1/completions") {
		relayMode = Completions
	} else if strings.HasPrefix(path, "/v1/embeddings") {
		relayMode = Embeddings
	} else if strings.HasSuffix(path, "embeddings") {
		relayMode = Embeddings
	} else if strings.HasPrefix(path, "/v1/moderations") {
		relayMode = Moderations
	} else if strings.HasPrefix(path, "/v1/images/generations") {
		relayMode = ImagesGenerations
	} else if strings.HasPrefix(path, "/v1/edits") {
		relayMode = Edits
	} else if strings.HasPrefix(path, "/v1/audio/speech") {
		relayMode = AudioSpeech
	} else if strings.HasPrefix(path, "/v1/audio/transcriptions") {
		relayMode = AudioTranscription
	} else if strings.HasPrefix(path, "/v1/audio/translations") {
		relayMode = AudioTranslation
	}
	return relayMode
}
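Example resolutions from GetByPath:

	relaymode.GetByPath("/v1/chat/completions")     // ChatCompletions
	relaymode.GetByPath("/v1/audio/transcriptions") // AudioTranscription
	relaymode.GetByPath("/api/unknown")             // Unknown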
@@ -1,19 +0,0 @@
package util

import (
	"context"
	"github.com/songquanpeng/one-api/common/logger"
	"github.com/songquanpeng/one-api/model"
)

func ReturnPreConsumedQuota(ctx context.Context, preConsumedQuota int64, tokenId int) {
	if preConsumedQuota != 0 {
		go func(ctx context.Context) {
			// return pre-consumed quota
			err := model.PostConsumeTokenQuota(tokenId, -preConsumedQuota)
			if err != nil {
				logger.Error(ctx, "error return pre-consumed quota: "+err.Error())
			}
		}(ctx)
	}
}
@@ -1,188 +0,0 @@
package util

import (
	"context"
	"encoding/json"
	"fmt"
	"github.com/songquanpeng/one-api/common"
	"github.com/songquanpeng/one-api/common/config"
	"github.com/songquanpeng/one-api/common/logger"
	"github.com/songquanpeng/one-api/model"
	relaymodel "github.com/songquanpeng/one-api/relay/model"
	"io"
	"net/http"
	"strconv"
	"strings"

	"github.com/gin-gonic/gin"
)

func ShouldDisableChannel(err *relaymodel.Error, statusCode int) bool {
	if !config.AutomaticDisableChannelEnabled {
		return false
	}
	if err == nil {
		return false
	}
	if statusCode == http.StatusUnauthorized {
		return true
	}
	switch err.Type {
	case "insufficient_quota":
		return true
	// https://docs.anthropic.com/claude/reference/errors
	case "authentication_error":
		return true
	case "permission_error":
		return true
	case "forbidden":
		return true
	}
	if err.Code == "invalid_api_key" || err.Code == "account_deactivated" {
		return true
	}
	if strings.HasPrefix(err.Message, "Your credit balance is too low") { // anthropic
		return true
	} else if strings.HasPrefix(err.Message, "This organization has been disabled.") {
		return true
	}
	return false
}

func ShouldEnableChannel(err error, openAIErr *relaymodel.Error) bool {
	if !config.AutomaticEnableChannelEnabled {
		return false
	}
	if err != nil {
		return false
	}
	if openAIErr != nil {
		return false
	}
	return true
}

type GeneralErrorResponse struct {
	Error    relaymodel.Error `json:"error"`
	Message  string           `json:"message"`
	Msg      string           `json:"msg"`
	Err      string           `json:"err"`
	ErrorMsg string           `json:"error_msg"`
	Header   struct {
		Message string `json:"message"`
	} `json:"header"`
	Response struct {
		Error struct {
			Message string `json:"message"`
		} `json:"error"`
	} `json:"response"`
}

func (e GeneralErrorResponse) ToMessage() string {
	if e.Error.Message != "" {
		return e.Error.Message
	}
	if e.Message != "" {
		return e.Message
	}
	if e.Msg != "" {
		return e.Msg
	}
	if e.Err != "" {
		return e.Err
	}
	if e.ErrorMsg != "" {
		return e.ErrorMsg
	}
	if e.Header.Message != "" {
		return e.Header.Message
	}
	if e.Response.Error.Message != "" {
		return e.Response.Error.Message
	}
	return ""
}

func RelayErrorHandler(resp *http.Response) (ErrorWithStatusCode *relaymodel.ErrorWithStatusCode) {
	ErrorWithStatusCode = &relaymodel.ErrorWithStatusCode{
		StatusCode: resp.StatusCode,
		Error: relaymodel.Error{
			Message: "",
			Type:    "upstream_error",
			Code:    "bad_response_status_code",
			Param:   strconv.Itoa(resp.StatusCode),
		},
	}
	responseBody, err := io.ReadAll(resp.Body)
	if err != nil {
		return
	}
	if config.DebugEnabled {
		logger.SysLog(fmt.Sprintf("error happened, status code: %d, response: \n%s", resp.StatusCode, string(responseBody)))
	}
	err = resp.Body.Close()
	if err != nil {
		return
	}
	var errResponse GeneralErrorResponse
	err = json.Unmarshal(responseBody, &errResponse)
	if err != nil {
		return
	}
	if errResponse.Error.Message != "" {
		// OpenAI format error, so we override the default one
		ErrorWithStatusCode.Error = errResponse.Error
	} else {
		ErrorWithStatusCode.Error.Message = errResponse.ToMessage()
	}
	if ErrorWithStatusCode.Error.Message == "" {
		ErrorWithStatusCode.Error.Message = fmt.Sprintf("bad response status code %d", resp.StatusCode)
	}
	return
}

func GetFullRequestURL(baseURL string, requestURL string, channelType int) string {
	fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL)

	if strings.HasPrefix(baseURL, "https://gateway.ai.cloudflare.com") {
		switch channelType {
		case common.ChannelTypeOpenAI:
			fullRequestURL = fmt.Sprintf("%s%s", baseURL, strings.TrimPrefix(requestURL, "/v1"))
		case common.ChannelTypeAzure:
			fullRequestURL = fmt.Sprintf("%s%s", baseURL, strings.TrimPrefix(requestURL, "/openai/deployments"))
		}
	}
	return fullRequestURL
}
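Effect of the Cloudflare gateway rewrite above, with a hypothetical gateway base URL:

	base := "https://gateway.ai.cloudflare.com/v1/{account}/{gateway}/openai"
	GetFullRequestURL(base, "/v1/chat/completions", common.ChannelTypeOpenAI)
	// => base + "/chat/completions" (the request path's "/v1" prefix is trimmed)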
func PostConsumeQuota(ctx context.Context, tokenId int, quotaDelta int64, totalQuota int64, userId int, channelId int, modelRatio float64, groupRatio float64, modelName string, tokenName string) {
	// quotaDelta is remaining quota to be consumed
	err := model.PostConsumeTokenQuota(tokenId, quotaDelta)
	if err != nil {
		logger.SysError("error consuming token remain quota: " + err.Error())
	}
	err = model.CacheUpdateUserQuota(ctx, userId)
	if err != nil {
		logger.SysError("error update user quota cache: " + err.Error())
	}
	// totalQuota is total quota consumed
	if totalQuota >= 0 {
		logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio) // i.e. "model ratio %.2f, group ratio %.2f"
		model.RecordConsumeLog(ctx, userId, channelId, int(totalQuota), 0, modelName, tokenName, totalQuota, logContent)
		model.UpdateUserUsedQuotaAndRequestCount(userId, totalQuota)
		model.UpdateChannelUsedQuota(channelId, totalQuota)
	}

	if totalQuota < 0 {
		logger.Error(ctx, fmt.Sprintf("totalQuota consumed is %d, something is wrong", totalQuota))
	}
}

func GetAzureAPIVersion(c *gin.Context) string {
	query := c.Request.URL.Query()
	apiVersion := query.Get("api-version")
	if apiVersion == "" {
		apiVersion = c.GetString(common.ConfigKeyAPIVersion)
	}
	return apiVersion
}
@@ -1,12 +0,0 @@
package util

func GetMappedModelName(modelName string, mapping map[string]string) (string, bool) {
	if mapping == nil {
		return modelName, false
	}
	mappedModelName := mapping[modelName]
	if mappedModelName != "" {
		return mappedModelName, true
	}
	return modelName, false
}