Refactor codebase: introduce `relaymode` package, update constants, and improve consistency

- Refactor constant definitions and their organization
- Clean up package-level variables and functions
- Introduce new `relaymode` and `apitype` packages for constant definitions
- Refactor and simplify code in several packages, including `openai`, `relay/channel/baidu`, `relay/util`, `relay/controller`, and `relay/channeltype`
- Add helper functions in the `relay/channeltype` package to convert channel type constants to the corresponding API type constants (see the sketch below)
- Remove deprecated functions such as `ResponseText2Usage` from `relay/channel/openai/helper.go`
- Modify code in `relay/util/validation.go` and related files to use the new `validator.ValidateTextRequest` function
- Rename the `util` package to `relaymode` and update related imports in several packages
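As context for the diffs below, here is a minimal sketch of the channel-type to API-type conversion this commit describes. It is illustrative only: the helper name `ToAPIType` and the abbreviated case list are assumptions, not the exact code in this commit; the real definitions live in `relay/channeltype` and `relay/apitype`.

package channeltype

import "github.com/songquanpeng/one-api/relay/apitype"

// ToAPIType maps a channel type constant to the API type constant consumed by
// relay.GetAdaptor. The case list here is abbreviated and hypothetical; most
// channels (Azure, Moonshot, Groq, ...) speak the OpenAI-compatible API and
// therefore fall through to apitype.OpenAI.
func ToAPIType(channelType int) int {
	apiType := apitype.OpenAI
	switch channelType {
	case Anthropic:
		apiType = apitype.Anthropic
	case Gemini:
		apiType = apitype.Gemini
	case PaLM:
		apiType = apitype.PaLM
	case Ollama:
		apiType = apitype.Ollama
	case AIProxyLibrary:
		apiType = apitype.AIProxyLibrary
	}
	return apiType
}

A caller would then presumably select an adaptor with something like `relay.GetAdaptor(channeltype.ToAPIType(channel.Type))`, matching the new `relay/adaptor.go` shown below.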
Laisky.Cai
2024-04-06 05:18:04 +00:00
parent 7c59a97a6b
commit 50bab08496
157 changed files with 3217 additions and 19160 deletions

41
relay/adaptor.go Normal file
View File

@@ -0,0 +1,41 @@
package relay
import (
"github.com/songquanpeng/one-api/relay/adaptor"
"github.com/songquanpeng/one-api/relay/adaptor/aiproxy"
"github.com/songquanpeng/one-api/relay/adaptor/anthropic"
"github.com/songquanpeng/one-api/relay/adaptor/gemini"
"github.com/songquanpeng/one-api/relay/adaptor/ollama"
"github.com/songquanpeng/one-api/relay/adaptor/openai"
"github.com/songquanpeng/one-api/relay/adaptor/palm"
"github.com/songquanpeng/one-api/relay/apitype"
)
func GetAdaptor(apiType int) adaptor.Adaptor {
switch apiType {
case apitype.AIProxyLibrary:
return &aiproxy.Adaptor{}
// case apitype.Ali:
// return &ali.Adaptor{}
case apitype.Anthropic:
return &anthropic.Adaptor{}
// case apitype.Baidu:
// return &baidu.Adaptor{}
case apitype.Gemini:
return &gemini.Adaptor{}
case apitype.OpenAI:
return &openai.Adaptor{}
case apitype.PaLM:
return &palm.Adaptor{}
// case apitype.Tencent:
// return &tencent.Adaptor{}
// case apitype.Xunfei:
// return &xunfei.Adaptor{}
// case apitype.Zhipu:
// return &zhipu.Adaptor{}
case apitype.Ollama:
return &ollama.Adaptor{}
}
return nil
}

View File

@@ -4,10 +4,10 @@ import (
"fmt"
"github.com/Laisky/errors/v2"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/relay/channel"
"github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/relay/adaptor"
"github.com/songquanpeng/one-api/relay/meta"
"github.com/songquanpeng/one-api/relay/model"
"github.com/songquanpeng/one-api/relay/util"
"io"
"net/http"
)
@@ -15,16 +15,16 @@ import (
type Adaptor struct {
}
func (a *Adaptor) Init(meta *util.RelayMeta) {
func (a *Adaptor) Init(meta *meta.Meta) {
}
func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) {
func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) {
return fmt.Sprintf("%s/api/library/ask", meta.BaseURL), nil
}
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *util.RelayMeta) error {
channel.SetupCommonRequestHeader(c, req, meta)
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error {
adaptor.SetupCommonRequestHeader(c, req, meta)
req.Header.Set("Authorization", "Bearer "+meta.APIKey)
return nil
}
@@ -34,15 +34,22 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G
return nil, errors.New("request is nil")
}
aiProxyLibraryRequest := ConvertRequest(*request)
aiProxyLibraryRequest.LibraryId = c.GetString(common.ConfigKeyLibraryID)
aiProxyLibraryRequest.LibraryId = c.GetString(config.KeyLibraryID)
return aiProxyLibraryRequest, nil
}
func (a *Adaptor) DoRequest(c *gin.Context, meta *util.RelayMeta, requestBody io.Reader) (*http.Response, error) {
return channel.DoRequestHelper(a, c, meta, requestBody)
func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) {
if request == nil {
return nil, errors.New("request is nil")
}
return request, nil
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.RelayMeta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) {
return adaptor.DoRequestHelper(a, c, meta, requestBody)
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
if meta.IsStream {
err, usage = StreamHandler(c, resp)
} else {

View File

@@ -1,6 +1,6 @@
package aiproxy
import "github.com/songquanpeng/one-api/relay/channel/openai"
import "github.com/songquanpeng/one-api/relay/adaptor/openai"
var ModelList = []string{""}

View File

@@ -13,7 +13,8 @@ import (
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/helper"
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/relay/channel/openai"
"github.com/songquanpeng/one-api/common/random"
"github.com/songquanpeng/one-api/relay/adaptor/openai"
"github.com/songquanpeng/one-api/relay/constant"
"github.com/songquanpeng/one-api/relay/model"
)
@@ -54,7 +55,7 @@ func responseAIProxyLibrary2OpenAI(response *LibraryResponse) *openai.TextRespon
FinishReason: "stop",
}
fullTextResponse := openai.TextResponse{
Id: fmt.Sprintf("chatcmpl-%s", helper.GetUUID()),
Id: fmt.Sprintf("chatcmpl-%s", random.GetUUID()),
Object: "chat.completion",
Created: helper.GetTimestamp(),
Choices: []openai.TextResponseChoice{choice},
@@ -67,7 +68,7 @@ func documentsAIProxyLibrary(documents []LibraryDocument) *openai.ChatCompletion
choice.Delta.Content = aiProxyDocuments2Markdown(documents)
choice.FinishReason = &constant.StopFinishReason
return &openai.ChatCompletionsStreamResponse{
Id: fmt.Sprintf("chatcmpl-%s", helper.GetUUID()),
Id: fmt.Sprintf("chatcmpl-%s", random.GetUUID()),
Object: "chat.completion.chunk",
Created: helper.GetTimestamp(),
Model: "",
@@ -79,7 +80,7 @@ func streamResponseAIProxyLibrary2OpenAI(response *LibraryStreamResponse) *opena
var choice openai.ChatCompletionsStreamResponseChoice
choice.Delta.Content = response.Content
return &openai.ChatCompletionsStreamResponse{
Id: fmt.Sprintf("chatcmpl-%s", helper.GetUUID()),
Id: fmt.Sprintf("chatcmpl-%s", random.GetUUID()),
Object: "chat.completion.chunk",
Created: helper.GetTimestamp(),
Model: response.Model,

View File

@@ -0,0 +1,105 @@
package ali
// import (
// "github.com/Laisky/errors/v2"
// "fmt"
// "github.com/gin-gonic/gin"
// "github.com/songquanpeng/one-api/common/config"
// "github.com/songquanpeng/one-api/relay/adaptor"
// "github.com/songquanpeng/one-api/relay/meta"
// "github.com/songquanpeng/one-api/relay/model"
// "github.com/songquanpeng/one-api/relay/relaymode"
// "io"
// "net/http"
// )
// // https://help.aliyun.com/zh/dashscope/developer-reference/api-details
// type Adaptor struct {
// }
// func (a *Adaptor) Init(meta *meta.Meta) {
// }
// func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) {
// fullRequestURL := ""
// switch meta.Mode {
// case relaymode.Embeddings:
// fullRequestURL = fmt.Sprintf("%s/api/v1/services/embeddings/text-embedding/text-embedding", meta.BaseURL)
// case relaymode.ImagesGenerations:
// fullRequestURL = fmt.Sprintf("%s/api/v1/services/aigc/text2image/image-synthesis", meta.BaseURL)
// default:
// fullRequestURL = fmt.Sprintf("%s/api/v1/services/aigc/text-generation/generation", meta.BaseURL)
// }
// return fullRequestURL, nil
// }
// func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error {
// adaptor.SetupCommonRequestHeader(c, req, meta)
// if meta.IsStream {
// req.Header.Set("Accept", "text/event-stream")
// req.Header.Set("X-DashScope-SSE", "enable")
// }
// req.Header.Set("Authorization", "Bearer "+meta.APIKey)
// if meta.Mode == relaymode.ImagesGenerations {
// req.Header.Set("X-DashScope-Async", "enable")
// }
// if c.GetString(config.KeyPlugin) != "" {
// req.Header.Set("X-DashScope-Plugin", c.GetString(config.KeyPlugin))
// }
// return nil
// }
// func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) {
// if request == nil {
// return nil, errors.New("request is nil")
// }
// switch relayMode {
// case relaymode.Embeddings:
// aliEmbeddingRequest := ConvertEmbeddingRequest(*request)
// return aliEmbeddingRequest, nil
// default:
// aliRequest := ConvertRequest(*request)
// return aliRequest, nil
// }
// }
// func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) {
// if request == nil {
// return nil, errors.New("request is nil")
// }
// aliRequest := ConvertImageRequest(*request)
// return aliRequest, nil
// }
// func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) {
// return adaptor.DoRequestHelper(a, c, meta, requestBody)
// }
// func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
// if meta.IsStream {
// err, usage = StreamHandler(c, resp)
// } else {
// switch meta.Mode {
// case relaymode.Embeddings:
// err, usage = EmbeddingHandler(c, resp)
// case relaymode.ImagesGenerations:
// err, usage = ImageHandler(c, resp)
// default:
// err, usage = Handler(c, resp)
// }
// }
// return
// }
// func (a *Adaptor) GetModelList() []string {
// return ModelList
// }
// func (a *Adaptor) GetChannelName() string {
// return "ali"
// }

View File

@@ -3,4 +3,5 @@ package ali
var ModelList = []string{
"qwen-turbo", "qwen-plus", "qwen-max", "qwen-max-longcontext",
"text-embedding-v1",
"ali-stable-diffusion-xl", "ali-stable-diffusion-v1.5", "wanx-v1",
}

192
relay/adaptor/ali/image.go Normal file
View File

@@ -0,0 +1,192 @@
package ali
import (
"encoding/base64"
"encoding/json"
"fmt"
"github.com/Laisky/errors/v2"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common/helper"
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/relay/adaptor/openai"
"github.com/songquanpeng/one-api/relay/model"
"io"
"net/http"
"strings"
"time"
)
func ImageHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
apiKey := c.Request.Header.Get("Authorization")
apiKey = strings.TrimPrefix(apiKey, "Bearer ")
responseFormat := c.GetString("response_format")
var aliTaskResponse TaskResponse
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
}
err = resp.Body.Close()
if err != nil {
return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
err = json.Unmarshal(responseBody, &aliTaskResponse)
if err != nil {
return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
}
if aliTaskResponse.Message != "" {
logger.SysError("aliAsyncTask err: " + string(responseBody))
return openai.ErrorWrapper(errors.New(aliTaskResponse.Message), "ali_async_task_failed", http.StatusInternalServerError), nil
}
aliResponse, _, err := asyncTaskWait(aliTaskResponse.Output.TaskId, apiKey)
if err != nil {
return openai.ErrorWrapper(err, "ali_async_task_wait_failed", http.StatusInternalServerError), nil
}
if aliResponse.Output.TaskStatus != "SUCCEEDED" {
return &model.ErrorWithStatusCode{
Error: model.Error{
Message: aliResponse.Output.Message,
Type: "ali_error",
Param: "",
Code: aliResponse.Output.Code,
},
StatusCode: resp.StatusCode,
}, nil
}
fullTextResponse := responseAli2OpenAIImage(aliResponse, responseFormat)
jsonResponse, err := json.Marshal(fullTextResponse)
if err != nil {
return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
}
c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode)
_, err = c.Writer.Write(jsonResponse)
return nil, nil
}
func asyncTask(taskID string, key string) (*TaskResponse, error, []byte) {
url := fmt.Sprintf("https://dashscope.aliyuncs.com/api/v1/tasks/%s", taskID)
var aliResponse TaskResponse
req, err := http.NewRequest("GET", url, nil)
if err != nil {
return &aliResponse, err, nil
}
req.Header.Set("Authorization", "Bearer "+key)
client := &http.Client{}
resp, err := client.Do(req)
if err != nil {
logger.SysError("aliAsyncTask client.Do err: " + err.Error())
return &aliResponse, err, nil
}
defer resp.Body.Close()
responseBody, err := io.ReadAll(resp.Body)
var response TaskResponse
err = json.Unmarshal(responseBody, &response)
if err != nil {
logger.SysError("aliAsyncTask NewDecoder err: " + err.Error())
return &aliResponse, err, nil
}
return &response, nil, responseBody
}
func asyncTaskWait(taskID string, key string) (*TaskResponse, []byte, error) {
waitSeconds := 2
step := 0
maxStep := 20
var taskResponse TaskResponse
var responseBody []byte
for {
step++
rsp, err, body := asyncTask(taskID, key)
responseBody = body
if err != nil {
return &taskResponse, responseBody, err
}
if rsp.Output.TaskStatus == "" {
return &taskResponse, responseBody, nil
}
switch rsp.Output.TaskStatus {
case "FAILED":
fallthrough
case "CANCELED":
fallthrough
case "SUCCEEDED":
fallthrough
case "UNKNOWN":
return rsp, responseBody, nil
}
if step >= maxStep {
break
}
time.Sleep(time.Duration(waitSeconds) * time.Second)
}
return nil, nil, fmt.Errorf("aliAsyncTaskWait timeout")
}
func responseAli2OpenAIImage(response *TaskResponse, responseFormat string) *openai.ImageResponse {
imageResponse := openai.ImageResponse{
Created: helper.GetTimestamp(),
}
for _, data := range response.Output.Results {
var b64Json string
if responseFormat == "b64_json" {
// read the image data from data.Url and store it as b64Json
imageData, err := getImageData(data.Url)
if err != nil {
// handle the case where fetching the image data failed
logger.SysError("getImageData Error getting image data: " + err.Error())
continue
}
// encode the image data as a Base64 string
b64Json = Base64Encode(imageData)
} else {
// if responseFormat is not "b64_json", use data.B64Image directly
b64Json = data.B64Image
}
imageResponse.Data = append(imageResponse.Data, openai.ImageData{
Url: data.Url,
B64Json: b64Json,
RevisedPrompt: "",
})
}
return &imageResponse
}
func getImageData(url string) ([]byte, error) {
response, err := http.Get(url)
if err != nil {
return nil, err
}
defer response.Body.Close()
imageData, err := io.ReadAll(response.Body)
if err != nil {
return nil, err
}
return imageData, nil
}
func Base64Encode(data []byte) string {
b64Json := base64.StdEncoding.EncodeToString(data)
return b64Json
}

154
relay/adaptor/ali/model.go Normal file
View File

@@ -0,0 +1,154 @@
package ali
import (
"github.com/songquanpeng/one-api/relay/adaptor/openai"
"github.com/songquanpeng/one-api/relay/model"
)
type Message struct {
Content string `json:"content"`
Role string `json:"role"`
}
type Input struct {
//Prompt string `json:"prompt"`
Messages []Message `json:"messages"`
}
type Parameters struct {
TopP float64 `json:"top_p,omitempty"`
TopK int `json:"top_k,omitempty"`
Seed uint64 `json:"seed,omitempty"`
EnableSearch bool `json:"enable_search,omitempty"`
IncrementalOutput bool `json:"incremental_output,omitempty"`
MaxTokens int `json:"max_tokens,omitempty"`
Temperature float64 `json:"temperature,omitempty"`
ResultFormat string `json:"result_format,omitempty"`
Tools []model.Tool `json:"tools,omitempty"`
}
type ChatRequest struct {
Model string `json:"model"`
Input Input `json:"input"`
Parameters Parameters `json:"parameters,omitempty"`
}
type ImageRequest struct {
Model string `json:"model"`
Input struct {
Prompt string `json:"prompt"`
NegativePrompt string `json:"negative_prompt,omitempty"`
} `json:"input"`
Parameters struct {
Size string `json:"size,omitempty"`
N int `json:"n,omitempty"`
Steps string `json:"steps,omitempty"`
Scale string `json:"scale,omitempty"`
} `json:"parameters,omitempty"`
ResponseFormat string `json:"response_format,omitempty"`
}
type TaskResponse struct {
StatusCode int `json:"status_code,omitempty"`
RequestId string `json:"request_id,omitempty"`
Code string `json:"code,omitempty"`
Message string `json:"message,omitempty"`
Output struct {
TaskId string `json:"task_id,omitempty"`
TaskStatus string `json:"task_status,omitempty"`
Code string `json:"code,omitempty"`
Message string `json:"message,omitempty"`
Results []struct {
B64Image string `json:"b64_image,omitempty"`
Url string `json:"url,omitempty"`
Code string `json:"code,omitempty"`
Message string `json:"message,omitempty"`
} `json:"results,omitempty"`
TaskMetrics struct {
Total int `json:"TOTAL,omitempty"`
Succeeded int `json:"SUCCEEDED,omitempty"`
Failed int `json:"FAILED,omitempty"`
} `json:"task_metrics,omitempty"`
} `json:"output,omitempty"`
Usage Usage `json:"usage"`
}
type Header struct {
Action string `json:"action,omitempty"`
Streaming string `json:"streaming,omitempty"`
TaskID string `json:"task_id,omitempty"`
Event string `json:"event,omitempty"`
ErrorCode string `json:"error_code,omitempty"`
ErrorMessage string `json:"error_message,omitempty"`
Attributes any `json:"attributes,omitempty"`
}
type Payload struct {
Model string `json:"model,omitempty"`
Task string `json:"task,omitempty"`
TaskGroup string `json:"task_group,omitempty"`
Function string `json:"function,omitempty"`
Parameters struct {
SampleRate int `json:"sample_rate,omitempty"`
Rate float64 `json:"rate,omitempty"`
Format string `json:"format,omitempty"`
} `json:"parameters,omitempty"`
Input struct {
Text string `json:"text,omitempty"`
} `json:"input,omitempty"`
Usage struct {
Characters int `json:"characters,omitempty"`
} `json:"usage,omitempty"`
}
type WSSMessage struct {
Header Header `json:"header,omitempty"`
Payload Payload `json:"payload,omitempty"`
}
type EmbeddingRequest struct {
Model string `json:"model"`
Input struct {
Texts []string `json:"texts"`
} `json:"input"`
Parameters *struct {
TextType string `json:"text_type,omitempty"`
} `json:"parameters,omitempty"`
}
type Embedding struct {
Embedding []float64 `json:"embedding"`
TextIndex int `json:"text_index"`
}
type EmbeddingResponse struct {
Output struct {
Embeddings []Embedding `json:"embeddings"`
} `json:"output"`
Usage Usage `json:"usage"`
Error
}
type Error struct {
Code string `json:"code"`
Message string `json:"message"`
RequestId string `json:"request_id"`
}
type Usage struct {
InputTokens int `json:"input_tokens"`
OutputTokens int `json:"output_tokens"`
TotalTokens int `json:"total_tokens"`
}
type Output struct {
//Text string `json:"text"`
//FinishReason string `json:"finish_reason"`
Choices []openai.TextResponseChoice `json:"choices"`
}
type ChatResponse struct {
Output Output `json:"output"`
Usage Usage `json:"usage"`
Error
}

View File

@@ -2,31 +2,31 @@ package anthropic
import (
"fmt"
"github.com/Laisky/errors/v2"
"io"
"net/http"
"github.com/Laisky/errors/v2"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/relay/channel"
"github.com/songquanpeng/one-api/relay/adaptor"
"github.com/songquanpeng/one-api/relay/meta"
"github.com/songquanpeng/one-api/relay/model"
"github.com/songquanpeng/one-api/relay/util"
)
type Adaptor struct {
}
func (a *Adaptor) Init(meta *util.RelayMeta) {
func (a *Adaptor) Init(meta *meta.Meta) {
}
func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) {
// https://docs.anthropic.com/claude/reference/messages_post
// Anthropic migrated to the Messages API
// https://docs.anthropic.com/claude/reference/messages_post
// Anthropic migrated to the Messages API
func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) {
return fmt.Sprintf("%s/v1/messages", meta.BaseURL), nil
}
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *util.RelayMeta) error {
channel.SetupCommonRequestHeader(c, req, meta)
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error {
adaptor.SetupCommonRequestHeader(c, req, meta)
req.Header.Set("x-api-key", meta.APIKey)
anthropicVersion := c.Request.Header.Get("anthropic-version")
if anthropicVersion == "" {
@@ -46,11 +46,18 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G
return ConvertRequest(*request), nil
}
func (a *Adaptor) DoRequest(c *gin.Context, meta *util.RelayMeta, requestBody io.Reader) (*http.Response, error) {
return channel.DoRequestHelper(a, c, meta, requestBody)
func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) {
if request == nil {
return nil, errors.New("request is nil")
}
return request, nil
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.RelayMeta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) {
return adaptor.DoRequestHelper(a, c, meta, requestBody)
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
if meta.IsStream {
err, usage = StreamHandler(c, resp)
} else {

View File

@@ -13,7 +13,7 @@ import (
"github.com/songquanpeng/one-api/common/helper"
"github.com/songquanpeng/one-api/common/image"
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/relay/channel/openai"
"github.com/songquanpeng/one-api/relay/adaptor/openai"
"github.com/songquanpeng/one-api/relay/model"
)

View File

@@ -0,0 +1,15 @@
package azure
import (
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common/config"
)
func GetAPIVersion(c *gin.Context) string {
query := c.Request.URL.Query()
apiVersion := query.Get("api-version")
if apiVersion == "" {
apiVersion = c.GetString(config.KeyAPIVersion)
}
return apiVersion
}

View File

@@ -0,0 +1,20 @@
package baidu
var ModelList = []string{
"ERNIE-4.0-8K",
"ERNIE-3.5-8K",
"ERNIE-3.5-8K-0205",
"ERNIE-3.5-8K-1222",
"ERNIE-Bot-8K",
"ERNIE-3.5-4K-0205",
"ERNIE-Speed-8K",
"ERNIE-Speed-128K",
"ERNIE-Lite-8K-0922",
"ERNIE-Lite-8K-0308",
"ERNIE-Tiny-8K",
"BLOOMZ-7B",
"Embedding-V1",
"bge-large-zh",
"bge-large-en",
"tao-8k",
}

View File

@@ -1,4 +1,4 @@
package channel
package adaptor
import (
"io"
@@ -6,10 +6,11 @@ import (
"github.com/Laisky/errors/v2"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/relay/util"
"github.com/songquanpeng/one-api/relay/client"
"github.com/songquanpeng/one-api/relay/meta"
)
func SetupCommonRequestHeader(c *gin.Context, req *http.Request, meta *util.RelayMeta) {
func SetupCommonRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) {
req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
req.Header.Set("Accept", c.Request.Header.Get("Accept"))
if meta.IsStream && c.Request.Header.Get("Accept") == "" {
@@ -17,7 +18,7 @@ func SetupCommonRequestHeader(c *gin.Context, req *http.Request, meta *util.Rela
}
}
func DoRequestHelper(a Adaptor, c *gin.Context, meta *util.RelayMeta, requestBody io.Reader) (*http.Response, error) {
func DoRequestHelper(a Adaptor, c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) {
fullRequestURL, err := a.GetRequestURL(meta)
if err != nil {
return nil, errors.Wrap(err, "get request url failed")
@@ -43,7 +44,7 @@ func DoRequestHelper(a Adaptor, c *gin.Context, meta *util.RelayMeta, requestBod
}
func DoRequest(c *gin.Context, req *http.Request) (*http.Response, error) {
resp, err := util.HTTPClient.Do(req)
resp, err := client.HTTPClient.Do(req)
if err != nil {
return nil, err
}

View File

@@ -8,21 +8,21 @@ import (
"github.com/Laisky/errors/v2"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common/helper"
channelhelper "github.com/songquanpeng/one-api/relay/channel"
"github.com/songquanpeng/one-api/relay/channel/openai"
channelhelper "github.com/songquanpeng/one-api/relay/adaptor"
"github.com/songquanpeng/one-api/relay/adaptor/openai"
"github.com/songquanpeng/one-api/relay/meta"
"github.com/songquanpeng/one-api/relay/model"
"github.com/songquanpeng/one-api/relay/util"
)
type Adaptor struct {
}
func (a *Adaptor) Init(meta *util.RelayMeta) {
func (a *Adaptor) Init(meta *meta.Meta) {
}
func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) {
version := helper.AssignOrDefault(meta.APIVersion, "v1beta")
func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) {
version := helper.AssignOrDefault(meta.APIVersion, "v1")
action := "generateContent"
if meta.IsStream {
action = "streamGenerateContent"
@@ -30,7 +30,7 @@ func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) {
return fmt.Sprintf("%s/%s/models/%s:%s?key=%s", meta.BaseURL, version, meta.ActualModelName, action, meta.APIKey), nil
}
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *util.RelayMeta) error {
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error {
channelhelper.SetupCommonRequestHeader(c, req, meta)
req.Header.Set("x-goog-api-key", meta.APIKey)
req.URL.Query().Add("key", meta.APIKey)
@@ -44,11 +44,18 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G
return ConvertRequest(*request), nil
}
func (a *Adaptor) DoRequest(c *gin.Context, meta *util.RelayMeta, requestBody io.Reader) (*http.Response, error) {
func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) {
if request == nil {
return nil, errors.New("request is nil")
}
return request, nil
}
func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) {
return channelhelper.DoRequestHelper(a, c, meta, requestBody)
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.RelayMeta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
if meta.IsStream {
var responseText string
err, responseText = StreamHandler(c, resp)

View File

@@ -1,21 +1,24 @@
package gemini
import (
"context"
"bufio"
"encoding/json"
"fmt"
"io"
"net/http"
"strings"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/helper"
"github.com/songquanpeng/one-api/common/image"
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/relay/channel/openai"
"github.com/songquanpeng/one-api/common/random"
"github.com/songquanpeng/one-api/relay/adaptor/openai"
"github.com/songquanpeng/one-api/relay/constant"
"github.com/songquanpeng/one-api/relay/model"
"github.com/gin-gonic/gin"
)
// https://ai.google.dev/docs/gemini_api_overview?hl=zh-cn
@@ -82,13 +85,7 @@ func ConvertRequest(textRequest model.GeneralOpenAIRequest) *ChatRequest {
if imageNum > VisionMaxImageNum {
continue
}
mimeType, data, err := image.GetImageFromUrl(part.ImageURL.Url)
if err != nil {
logger.Warn(context.TODO(),
fmt.Sprintf("get image from url %s got %+v", part.ImageURL.Url, err))
continue
}
mimeType, data, _ := image.GetImageFromUrl(part.ImageURL.Url)
parts = append(parts, Part{
InlineData: &InlineData{
MimeType: mimeType,
@@ -97,9 +94,6 @@ func ConvertRequest(textRequest model.GeneralOpenAIRequest) *ChatRequest {
})
}
}
logger.Info(context.TODO(),
fmt.Sprintf("send %d messages to gemini with %d images", len(parts), imageNum))
content.Parts = parts
// there's no assistant role in gemini and API shall vomit if Role is not user or model
@@ -163,7 +157,7 @@ type ChatPromptFeedback struct {
func responseGeminiChat2OpenAI(response *ChatResponse) *openai.TextResponse {
fullTextResponse := openai.TextResponse{
Id: fmt.Sprintf("chatcmpl-%s", helper.GetUUID()),
Id: fmt.Sprintf("chatcmpl-%s", random.GetUUID()),
Object: "chat.completion",
Created: helper.GetTimestamp(),
Choices: make([]openai.TextResponseChoice, 0, len(response.Candidates)),
@@ -196,182 +190,73 @@ func streamResponseGeminiChat2OpenAI(geminiResponse *ChatResponse) *openai.ChatC
return &response
}
// [{
// "candidates": [
// {
// "content": {
// "parts": [
// {
// "text": "```go \n\n// Package ratelimit implements tokens bucket algorithm.\npackage rate"
// }
// ],
// "role": "model"
// },
// "finishReason": "STOP",
// "index": 0,
// "safetyRatings": [
// {
// "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
// "probability": "NEGLIGIBLE"
// },
// {
// "category": "HARM_CATEGORY_HATE_SPEECH",
// "probability": "NEGLIGIBLE"
// },
// {
// "category": "HARM_CATEGORY_HARASSMENT",
// "probability": "NEGLIGIBLE"
// },
// {
// "category": "HARM_CATEGORY_DANGEROUS_CONTENT",
// "probability": "NEGLIGIBLE"
// }
// ]
// }
// ],
// "promptFeedback": {
// "safetyRatings": [
// {
// "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
// "probability": "NEGLIGIBLE"
// },
// {
// "category": "HARM_CATEGORY_HATE_SPEECH",
// "probability": "NEGLIGIBLE"
// },
// {
// "category": "HARM_CATEGORY_HARASSMENT",
// "probability": "NEGLIGIBLE"
// },
// {
// "category": "HARM_CATEGORY_DANGEROUS_CONTENT",
// "probability": "NEGLIGIBLE"
// }
// ]
// }
// }]
type GeminiStreamResp struct {
Candidates []struct {
Content struct {
Parts []struct {
Text string `json:"text"`
} `json:"parts"`
Role string `json:"role"`
} `json:"content"`
FinishReason string `json:"finishReason"`
Index int64 `json:"index"`
} `json:"candidates"`
}
func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, string) {
responseText := ""
respBody, err := io.ReadAll(resp.Body)
if err != nil {
return openai.ErrorWrapper(err, "read upstream's body", http.StatusInternalServerError), responseText
}
var respData []GeminiStreamResp
if err = json.Unmarshal(respBody, &respData); err != nil {
return openai.ErrorWrapper(err, "unmarshal upstream's body", http.StatusInternalServerError), responseText
}
for _, chunk := range respData {
for _, cad := range chunk.Candidates {
for _, part := range cad.Content.Parts {
responseText += part.Text
}
}
}
var choice openai.ChatCompletionsStreamResponseChoice
choice.Delta.Content = responseText
resp2cli, err := json.Marshal(&openai.ChatCompletionsStreamResponse{
Id: fmt.Sprintf("chatcmpl-%s", helper.GetUUID()),
Object: "chat.completion.chunk",
Created: helper.GetTimestamp(),
Model: "gemini-pro",
Choices: []openai.ChatCompletionsStreamResponseChoice{choice},
})
err := resp.Body.Close()
if err != nil {
return openai.ErrorWrapper(err, "marshal upstream's body", http.StatusInternalServerError), responseText
}
c.Render(-1, common.CustomEvent{Data: "data: " + string(resp2cli)})
c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
dataChan := make(chan string)
stopChan := make(chan bool)
scanner := bufio.NewScanner(resp.Body)
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
if atEOF && len(data) == 0 {
return 0, nil, nil
}
if i := strings.Index(string(data), "\n"); i >= 0 {
return i + 1, data[0:i], nil
}
if atEOF {
return len(data), data, nil
}
return 0, nil, nil
})
go func() {
for scanner.Scan() {
data := scanner.Text()
data = strings.TrimSpace(data)
if !strings.HasPrefix(data, "\"text\": \"") {
continue
}
data = strings.TrimPrefix(data, "\"text\": \"")
data = strings.TrimSuffix(data, "\"")
dataChan <- data
}
stopChan <- true
}()
common.SetEventStreamHeaders(c)
c.Stream(func(w io.Writer) bool {
select {
case data := <-dataChan:
// this is used to prevent annoying \ related format bug
data = fmt.Sprintf("{\"content\": \"%s\"}", data)
type dummyStruct struct {
Content string `json:"content"`
}
var dummy dummyStruct
err := json.Unmarshal([]byte(data), &dummy)
responseText += dummy.Content
var choice openai.ChatCompletionsStreamResponseChoice
choice.Delta.Content = dummy.Content
response := openai.ChatCompletionsStreamResponse{
Id: fmt.Sprintf("chatcmpl-%s", random.GetUUID()),
Object: "chat.completion.chunk",
Created: helper.GetTimestamp(),
Model: "gemini-pro",
Choices: []openai.ChatCompletionsStreamResponseChoice{choice},
}
jsonResponse, err := json.Marshal(response)
if err != nil {
logger.SysError("error marshalling stream response: " + err.Error())
return true
}
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
return true
case <-stopChan:
c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
return false
}
})
// dataChan := make(chan string)
// stopChan := make(chan bool)
// scanner := bufio.NewScanner(resp.Body)
// scanner.Split(bufio.ScanLines)
// // scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
// // if atEOF && len(data) == 0 {
// // return 0, nil, nil
// // }
// // if i := strings.Index(string(data), "\n"); i >= 0 {
// // return i + 1, data[0:i], nil
// // }
// // if atEOF {
// // return len(data), data, nil
// // }
// // return 0, nil, nil
// // })
// go func() {
// var content string
// for scanner.Scan() {
// line := strings.TrimSpace(scanner.Text())
// fmt.Printf("> gemini got line: %s\n", line)
// content += line
// // if !strings.HasPrefix(data, "\"text\": \"") {
// // continue
// // }
// // data = strings.TrimPrefix(data, "\"text\": \"")
// // data = strings.TrimSuffix(data, "\"")
// // dataChan <- data
// }
// dataChan <- content
// stopChan <- true
// }()
// common.SetEventStreamHeaders(c)
// c.Stream(func(w io.Writer) bool {
// select {
// case data := <-dataChan:
// // this is used to prevent annoying \ related format bug
// data = fmt.Sprintf("{\"content\": \"%s\"}", data)
// type dummyStruct struct {
// Content string `json:"content"`
// }
// var dummy dummyStruct
// err := json.Unmarshal([]byte(data), &dummy)
// responseText += dummy.Content
// var choice openai.ChatCompletionsStreamResponseChoice
// choice.Delta.Content = dummy.Content
// response := openai.ChatCompletionsStreamResponse{
// Id: fmt.Sprintf("chatcmpl-%s", helper.GetUUID()),
// Object: "chat.completion.chunk",
// Created: helper.GetTimestamp(),
// Model: "gemini-pro",
// Choices: []openai.ChatCompletionsStreamResponseChoice{choice},
// }
// jsonResponse, err := json.Marshal(response)
// if err != nil {
// logger.SysError("error marshalling stream response: " + err.Error())
// return true
// }
// c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
// return true
// case <-stopChan:
// c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
// return false
// }
// })
if err := resp.Body.Close(); err != nil {
return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), ""
}
return nil, responseText
}

View File

@@ -0,0 +1,21 @@
package adaptor
import (
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/relay/meta"
"github.com/songquanpeng/one-api/relay/model"
"io"
"net/http"
)
type Adaptor interface {
Init(meta *meta.Meta)
GetRequestURL(meta *meta.Meta) (string, error)
SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error
ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error)
ConvertImageRequest(request *model.ImageRequest) (any, error)
DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error)
DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode)
GetModelList() []string
GetChannelName() string
}

View File

@@ -0,0 +1,14 @@
package minimax
import (
"fmt"
"github.com/songquanpeng/one-api/relay/meta"
"github.com/songquanpeng/one-api/relay/relaymode"
)
func GetRequestURL(meta *meta.Meta) (string, error) {
if meta.Mode == relaymode.ChatCompletions {
return fmt.Sprintf("%s/v1/text/chatcompletion_v2", meta.BaseURL), nil
}
return "", fmt.Errorf("unsupported relay mode %d for minimax", meta.Mode)
}

View File

@@ -1,36 +1,36 @@
package ollama
import (
"errors"
"fmt"
"io"
"net/http"
"github.com/Laisky/errors/v2"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/relay/channel"
"github.com/songquanpeng/one-api/relay/constant"
"github.com/songquanpeng/one-api/relay/adaptor"
"github.com/songquanpeng/one-api/relay/meta"
"github.com/songquanpeng/one-api/relay/model"
"github.com/songquanpeng/one-api/relay/util"
"github.com/songquanpeng/one-api/relay/relaymode"
)
type Adaptor struct {
}
func (a *Adaptor) Init(meta *util.RelayMeta) {
func (a *Adaptor) Init(meta *meta.Meta) {
}
func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) {
func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) {
// https://github.com/ollama/ollama/blob/main/docs/api.md
fullRequestURL := fmt.Sprintf("%s/api/chat", meta.BaseURL)
if meta.Mode == constant.RelayModeEmbeddings {
if meta.Mode == relaymode.Embeddings {
fullRequestURL = fmt.Sprintf("%s/api/embeddings", meta.BaseURL)
}
return fullRequestURL, nil
}
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *util.RelayMeta) error {
channel.SetupCommonRequestHeader(c, req, meta)
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error {
adaptor.SetupCommonRequestHeader(c, req, meta)
req.Header.Set("Authorization", "Bearer "+meta.APIKey)
return nil
}
@@ -40,7 +40,7 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G
return nil, errors.New("request is nil")
}
switch relayMode {
case constant.RelayModeEmbeddings:
case relaymode.Embeddings:
ollamaEmbeddingRequest := ConvertEmbeddingRequest(*request)
return ollamaEmbeddingRequest, nil
default:
@@ -48,16 +48,23 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G
}
}
func (a *Adaptor) DoRequest(c *gin.Context, meta *util.RelayMeta, requestBody io.Reader) (*http.Response, error) {
return channel.DoRequestHelper(a, c, meta, requestBody)
func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) {
if request == nil {
return nil, errors.New("request is nil")
}
return request, nil
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.RelayMeta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) {
return adaptor.DoRequestHelper(a, c, meta, requestBody)
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
if meta.IsStream {
err, usage = StreamHandler(c, resp)
} else {
switch meta.Mode {
case constant.RelayModeEmbeddings:
case relaymode.Embeddings:
err, usage = EmbeddingHandler(c, resp)
default:
err, usage = Handler(c, resp)

View File

@@ -5,15 +5,16 @@ import (
"context"
"encoding/json"
"fmt"
"github.com/songquanpeng/one-api/common/helper"
"github.com/songquanpeng/one-api/common/random"
"io"
"net/http"
"strings"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/helper"
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/relay/channel/openai"
"github.com/songquanpeng/one-api/relay/adaptor/openai"
"github.com/songquanpeng/one-api/relay/constant"
"github.com/songquanpeng/one-api/relay/model"
)
@@ -51,7 +52,7 @@ func responseOllama2OpenAI(response *ChatResponse) *openai.TextResponse {
choice.FinishReason = "stop"
}
fullTextResponse := openai.TextResponse{
Id: fmt.Sprintf("chatcmpl-%s", helper.GetUUID()),
Id: fmt.Sprintf("chatcmpl-%s", random.GetUUID()),
Object: "chat.completion",
Created: helper.GetTimestamp(),
Choices: []openai.TextResponseChoice{choice},
@@ -72,7 +73,7 @@ func streamResponseOllama2OpenAI(ollamaResponse *ChatResponse) *openai.ChatCompl
choice.FinishReason = &constant.StopFinishReason
}
response := openai.ChatCompletionsStreamResponse{
Id: fmt.Sprintf("chatcmpl-%s", helper.GetUUID()),
Id: fmt.Sprintf("chatcmpl-%s", random.GetUUID()),
Object: "chat.completion.chunk",
Created: helper.GetTimestamp(),
Model: ollamaResponse.Model,

View File

@@ -4,11 +4,12 @@ import (
"fmt"
"github.com/Laisky/errors/v2"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/relay/channel"
"github.com/songquanpeng/one-api/relay/channel/minimax"
"github.com/songquanpeng/one-api/relay/adaptor"
"github.com/songquanpeng/one-api/relay/adaptor/minimax"
"github.com/songquanpeng/one-api/relay/channeltype"
"github.com/songquanpeng/one-api/relay/meta"
"github.com/songquanpeng/one-api/relay/model"
"github.com/songquanpeng/one-api/relay/util"
"github.com/songquanpeng/one-api/relay/relaymode"
"io"
"net/http"
"strings"
@@ -18,13 +19,20 @@ type Adaptor struct {
ChannelType int
}
func (a *Adaptor) Init(meta *util.RelayMeta) {
func (a *Adaptor) Init(meta *meta.Meta) {
a.ChannelType = meta.ChannelType
}
func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) {
func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) {
switch meta.ChannelType {
case common.ChannelTypeAzure:
case channeltype.Azure:
if meta.Mode == relaymode.ImagesGenerations {
// https://learn.microsoft.com/en-us/azure/ai-services/openai/dall-e-quickstart?tabs=dalle3%2Ccommand-line&pivots=rest-api
// https://{resource_name}.openai.azure.com/openai/deployments/dall-e-3/images/generations?api-version=2024-03-01-preview
fullRequestURL := fmt.Sprintf("%s/openai/deployments/%s/images/generations?api-version=%s", meta.BaseURL, meta.ActualModelName, meta.APIVersion)
return fullRequestURL, nil
}
// https://learn.microsoft.com/en-us/azure/cognitive-services/openai/chatgpt-quickstart?pivots=rest-api&tabs=command-line#rest-api
requestURL := strings.Split(meta.RequestURLPath, "?")[0]
requestURL = fmt.Sprintf("%s?api-version=%s", requestURL, meta.APIVersion)
@@ -34,22 +42,22 @@ func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) {
//https://github.com/songquanpeng/one-api/issues/1191
// {your endpoint}/openai/deployments/{your azure_model}/chat/completions?api-version={api_version}
requestURL = fmt.Sprintf("/openai/deployments/%s/%s", model_, task)
return util.GetFullRequestURL(meta.BaseURL, requestURL, meta.ChannelType), nil
case common.ChannelTypeMinimax:
return GetFullRequestURL(meta.BaseURL, requestURL, meta.ChannelType), nil
case channeltype.Minimax:
return minimax.GetRequestURL(meta)
default:
return util.GetFullRequestURL(meta.BaseURL, meta.RequestURLPath, meta.ChannelType), nil
return GetFullRequestURL(meta.BaseURL, meta.RequestURLPath, meta.ChannelType), nil
}
}
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *util.RelayMeta) error {
channel.SetupCommonRequestHeader(c, req, meta)
if meta.ChannelType == common.ChannelTypeAzure {
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error {
adaptor.SetupCommonRequestHeader(c, req, meta)
if meta.ChannelType == channeltype.Azure {
req.Header.Set("api-key", meta.APIKey)
return nil
}
req.Header.Set("Authorization", "Bearer "+meta.APIKey)
if meta.ChannelType == common.ChannelTypeOpenRouter {
if meta.ChannelType == channeltype.OpenRouter {
req.Header.Set("HTTP-Referer", "https://github.com/songquanpeng/one-api")
req.Header.Set("X-Title", "One API")
}
@@ -63,11 +71,18 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G
return request, nil
}
func (a *Adaptor) DoRequest(c *gin.Context, meta *util.RelayMeta, requestBody io.Reader) (*http.Response, error) {
return channel.DoRequestHelper(a, c, meta, requestBody)
func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) {
if request == nil {
return nil, errors.New("request is nil")
}
return request, nil
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.RelayMeta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) {
return adaptor.DoRequestHelper(a, c, meta, requestBody)
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
if meta.IsStream {
var responseText string
err, responseText, usage = StreamHandler(c, resp, meta.Mode)
@@ -75,7 +90,12 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.Rel
usage = ResponseText2Usage(responseText, meta.ActualModelName, meta.PromptTokens)
}
} else {
err, usage = Handler(c, resp, meta.PromptTokens, meta.ActualModelName)
switch meta.Mode {
case relaymode.ImagesGenerations:
err, _ = ImageHandler(c, resp)
default:
err, usage = Handler(c, resp, meta.PromptTokens, meta.ActualModelName)
}
}
return
}

View File

@@ -0,0 +1,50 @@
package openai
import (
"github.com/songquanpeng/one-api/relay/adaptor/ai360"
"github.com/songquanpeng/one-api/relay/adaptor/baichuan"
"github.com/songquanpeng/one-api/relay/adaptor/groq"
"github.com/songquanpeng/one-api/relay/adaptor/lingyiwanwu"
"github.com/songquanpeng/one-api/relay/adaptor/minimax"
"github.com/songquanpeng/one-api/relay/adaptor/mistral"
"github.com/songquanpeng/one-api/relay/adaptor/moonshot"
"github.com/songquanpeng/one-api/relay/adaptor/stepfun"
"github.com/songquanpeng/one-api/relay/channeltype"
)
var CompatibleChannels = []int{
channeltype.Azure,
channeltype.AI360,
channeltype.Moonshot,
channeltype.Baichuan,
channeltype.Minimax,
channeltype.Mistral,
channeltype.Groq,
channeltype.LingYiWanWu,
channeltype.StepFun,
}
func GetCompatibleChannelMeta(channelType int) (string, []string) {
switch channelType {
case channeltype.Azure:
return "azure", ModelList
case channeltype.AI360:
return "360", ai360.ModelList
case channeltype.Moonshot:
return "moonshot", moonshot.ModelList
case channeltype.Baichuan:
return "baichuan", baichuan.ModelList
case channeltype.Minimax:
return "minimax", minimax.ModelList
case channeltype.Mistral:
return "mistralai", mistral.ModelList
case channeltype.Groq:
return "groq", groq.ModelList
case channeltype.LingYiWanWu:
return "lingyiwanwu", lingyiwanwu.ModelList
case channeltype.StepFun:
return "stepfun", stepfun.ModelList
default:
return "openai", ModelList
}
}

View File

@@ -0,0 +1,30 @@
package openai
import (
"fmt"
"github.com/songquanpeng/one-api/relay/channeltype"
"github.com/songquanpeng/one-api/relay/model"
"strings"
)
func ResponseText2Usage(responseText string, modeName string, promptTokens int) *model.Usage {
usage := &model.Usage{}
usage.PromptTokens = promptTokens
usage.CompletionTokens = CountTokenText(responseText, modeName)
usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
return usage
}
func GetFullRequestURL(baseURL string, requestURL string, channelType int) string {
fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL)
if strings.HasPrefix(baseURL, "https://gateway.ai.cloudflare.com") {
switch channelType {
case channeltype.OpenAI:
fullRequestURL = fmt.Sprintf("%s%s", baseURL, strings.TrimPrefix(requestURL, "/v1"))
case channeltype.Azure:
fullRequestURL = fmt.Sprintf("%s%s", baseURL, strings.TrimPrefix(requestURL, "/openai/deployments"))
}
}
return fullRequestURL
}

View File

@@ -0,0 +1,44 @@
package openai
import (
"bytes"
"encoding/json"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/relay/model"
"io"
"net/http"
)
func ImageHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
var imageResponse ImageResponse
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
}
err = resp.Body.Close()
if err != nil {
return ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
err = json.Unmarshal(responseBody, &imageResponse)
if err != nil {
return ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
}
resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
for k, v := range resp.Header {
c.Writer.Header().Set(k, v[0])
}
c.Writer.WriteHeader(resp.StatusCode)
_, err = io.Copy(c.Writer, resp.Body)
if err != nil {
return ErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil
}
err = resp.Body.Close()
if err != nil {
return ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
return nil, nil
}

View File

@@ -8,8 +8,8 @@ import (
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/conv"
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/relay/constant"
"github.com/songquanpeng/one-api/relay/model"
"github.com/songquanpeng/one-api/relay/relaymode"
"io"
"net/http"
"strings"
@@ -46,7 +46,7 @@ func StreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*model.E
data = data[6:]
if !strings.HasPrefix(data, "[DONE]") {
switch relayMode {
case constant.RelayModeChatCompletions:
case relaymode.ChatCompletions:
var streamResponse ChatCompletionsStreamResponse
err := json.Unmarshal([]byte(data), &streamResponse)
if err != nil {
@@ -59,7 +59,7 @@ func StreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*model.E
if streamResponse.Usage != nil {
usage = streamResponse.Usage
}
case constant.RelayModeCompletions:
case relaymode.Completions:
var streamResponse CompletionsStreamResponse
err := json.Unmarshal([]byte(data), &streamResponse)
if err != nil {

View File

@@ -110,11 +110,16 @@ type EmbeddingResponse struct {
model.Usage `json:"usage"`
}
type ImageData struct {
Url string `json:"url,omitempty"`
B64Json string `json:"b64_json,omitempty"`
RevisedPrompt string `json:"revised_prompt,omitempty"`
}
type ImageResponse struct {
Created int `json:"created"`
Data []struct {
Url string `json:"url"`
}
Created int64 `json:"created"`
Data []ImageData `json:"data"`
//model.Usage `json:"usage"`
}
type ChatCompletionsStreamResponseChoice struct {

View File

@@ -4,10 +4,10 @@ import (
"fmt"
"github.com/Laisky/errors/v2"
"github.com/pkoukk/tiktoken-go"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/image"
"github.com/songquanpeng/one-api/common/logger"
billingratio "github.com/songquanpeng/one-api/relay/billing/ratio"
"github.com/songquanpeng/one-api/relay/model"
"math"
"strings"
@@ -28,7 +28,7 @@ func InitTokenEncoders() {
if err != nil {
logger.FatalLog(fmt.Sprintf("failed to get gpt-4 token encoder: %s", err.Error()))
}
for model := range common.ModelRatio {
for model := range billingratio.ModelRatio {
if strings.HasPrefix(model, "gpt-3.5") {
tokenEncoderMap[model] = gpt35TokenEncoder
} else if strings.HasPrefix(model, "gpt-4") {

View File

@@ -4,10 +4,10 @@ import (
"fmt"
"github.com/Laisky/errors/v2"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/relay/channel"
"github.com/songquanpeng/one-api/relay/channel/openai"
"github.com/songquanpeng/one-api/relay/adaptor"
"github.com/songquanpeng/one-api/relay/adaptor/openai"
"github.com/songquanpeng/one-api/relay/meta"
"github.com/songquanpeng/one-api/relay/model"
"github.com/songquanpeng/one-api/relay/util"
"io"
"net/http"
)
@@ -15,16 +15,16 @@ import (
type Adaptor struct {
}
func (a *Adaptor) Init(meta *util.RelayMeta) {
func (a *Adaptor) Init(meta *meta.Meta) {
}
func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) {
func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) {
return fmt.Sprintf("%s/v1beta2/models/chat-bison-001:generateMessage", meta.BaseURL), nil
}
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *util.RelayMeta) error {
channel.SetupCommonRequestHeader(c, req, meta)
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error {
adaptor.SetupCommonRequestHeader(c, req, meta)
req.Header.Set("x-goog-api-key", meta.APIKey)
return nil
}
@@ -36,11 +36,18 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G
return ConvertRequest(*request), nil
}
func (a *Adaptor) DoRequest(c *gin.Context, meta *util.RelayMeta, requestBody io.Reader) (*http.Response, error) {
return channel.DoRequestHelper(a, c, meta, requestBody)
func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) {
if request == nil {
return nil, errors.New("request is nil")
}
return request, nil
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.RelayMeta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) {
return adaptor.DoRequestHelper(a, c, meta, requestBody)
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
if meta.IsStream {
var responseText string
err, responseText = StreamHandler(c, resp)

View File

@@ -7,7 +7,8 @@ import (
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/helper"
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/relay/channel/openai"
"github.com/songquanpeng/one-api/common/random"
"github.com/songquanpeng/one-api/relay/adaptor/openai"
"github.com/songquanpeng/one-api/relay/constant"
"github.com/songquanpeng/one-api/relay/model"
"io"
@@ -74,7 +75,7 @@ func streamResponsePaLM2OpenAI(palmResponse *ChatResponse) *openai.ChatCompletio
func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, string) {
responseText := ""
responseId := fmt.Sprintf("chatcmpl-%s", helper.GetUUID())
responseId := fmt.Sprintf("chatcmpl-%s", random.GetUUID())
createdTime := helper.GetTimestamp()
dataChan := make(chan string)
stopChan := make(chan bool)

View File

@@ -0,0 +1,7 @@
package stepfun
var ModelList = []string{
"step-1-32k",
"step-1v-32k",
"step-1-200k",
}

View File

@@ -0,0 +1,238 @@
package tencent
// import (
// "bufio"
// "crypto/hmac"
// "crypto/sha1"
// "encoding/base64"
// "encoding/json"
// "github.com/Laisky/errors/v2"
// "fmt"
// "github.com/gin-gonic/gin"
// "github.com/songquanpeng/one-api/common"
// "github.com/songquanpeng/one-api/common/helper"
// "github.com/songquanpeng/one-api/common/logger"
// "github.com/songquanpeng/one-api/relay/channel/openai"
// "github.com/songquanpeng/one-api/relay/constant"
// "github.com/songquanpeng/one-api/relay/model"
// "io"
// "net/http"
// "sort"
// "strconv"
// "strings"
// )
// // https://cloud.tencent.com/document/product/1729/97732
// func ConvertRequest(request model.GeneralOpenAIRequest) *ChatRequest {
// messages := make([]Message, 0, len(request.Messages))
// for i := 0; i < len(request.Messages); i++ {
// message := request.Messages[i]
// if message.Role == "system" {
// messages = append(messages, Message{
// Role: "user",
// Content: message.StringContent(),
// })
// messages = append(messages, Message{
// Role: "assistant",
// Content: "Okay",
// })
// continue
// }
// messages = append(messages, Message{
// Content: message.StringContent(),
// Role: message.Role,
// })
// }
// stream := 0
// if request.Stream {
// stream = 1
// }
// return &ChatRequest{
// Timestamp: helper.GetTimestamp(),
// Expired: helper.GetTimestamp() + 24*60*60,
// QueryID: helper.GetUUID(),
// Temperature: request.Temperature,
// TopP: request.TopP,
// Stream: stream,
// Messages: messages,
// }
// }
// func responseTencent2OpenAI(response *ChatResponse) *openai.TextResponse {
// fullTextResponse := openai.TextResponse{
// Object: "chat.completion",
// Created: helper.GetTimestamp(),
// Usage: response.Usage,
// }
// if len(response.Choices) > 0 {
// choice := openai.TextResponseChoice{
// Index: 0,
// Message: model.Message{
// Role: "assistant",
// Content: response.Choices[0].Messages.Content,
// },
// FinishReason: response.Choices[0].FinishReason,
// }
// fullTextResponse.Choices = append(fullTextResponse.Choices, choice)
// }
// return &fullTextResponse
// }
// func streamResponseTencent2OpenAI(TencentResponse *ChatResponse) *openai.ChatCompletionsStreamResponse {
// response := openai.ChatCompletionsStreamResponse{
// Object: "chat.completion.chunk",
// Created: helper.GetTimestamp(),
// Model: "tencent-hunyuan",
// }
// if len(TencentResponse.Choices) > 0 {
// var choice openai.ChatCompletionsStreamResponseChoice
// choice.Delta.Content = TencentResponse.Choices[0].Delta.Content
// if TencentResponse.Choices[0].FinishReason == "stop" {
// choice.FinishReason = &constant.StopFinishReason
// }
// response.Choices = append(response.Choices, choice)
// }
// return &response
// }
// func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, string) {
// var responseText string
// scanner := bufio.NewScanner(resp.Body)
// scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
// if atEOF && len(data) == 0 {
// return 0, nil, nil
// }
// if i := strings.Index(string(data), "\n"); i >= 0 {
// return i + 1, data[0:i], nil
// }
// if atEOF {
// return len(data), data, nil
// }
// return 0, nil, nil
// })
// dataChan := make(chan string)
// stopChan := make(chan bool)
// go func() {
// for scanner.Scan() {
// data := scanner.Text()
// if len(data) < 5 { // ignore blank line or wrong format
// continue
// }
// if data[:5] != "data:" {
// continue
// }
// data = data[5:]
// dataChan <- data
// }
// stopChan <- true
// }()
// common.SetEventStreamHeaders(c)
// c.Stream(func(w io.Writer) bool {
// select {
// case data := <-dataChan:
// var TencentResponse ChatResponse
// err := json.Unmarshal([]byte(data), &TencentResponse)
// if err != nil {
// logger.SysError("error unmarshalling stream response: " + err.Error())
// return true
// }
// response := streamResponseTencent2OpenAI(&TencentResponse)
// if len(response.Choices) != 0 {
// responseText += response.Choices[0].Delta.Content
// }
// jsonResponse, err := json.Marshal(response)
// if err != nil {
// logger.SysError("error marshalling stream response: " + err.Error())
// return true
// }
// c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
// return true
// case <-stopChan:
// c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
// return false
// }
// })
// err := resp.Body.Close()
// if err != nil {
// return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), ""
// }
// return nil, responseText
// }
// func Handler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
// var TencentResponse ChatResponse
// responseBody, err := io.ReadAll(resp.Body)
// if err != nil {
// return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
// }
// err = resp.Body.Close()
// if err != nil {
// return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
// }
// err = json.Unmarshal(responseBody, &TencentResponse)
// if err != nil {
// return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
// }
// if TencentResponse.Error.Code != 0 {
// return &model.ErrorWithStatusCode{
// Error: model.Error{
// Message: TencentResponse.Error.Message,
// Code: TencentResponse.Error.Code,
// },
// StatusCode: resp.StatusCode,
// }, nil
// }
// fullTextResponse := responseTencent2OpenAI(&TencentResponse)
// fullTextResponse.Model = "hunyuan"
// jsonResponse, err := json.Marshal(fullTextResponse)
// if err != nil {
// return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
// }
// c.Writer.Header().Set("Content-Type", "application/json")
// c.Writer.WriteHeader(resp.StatusCode)
// _, err = c.Writer.Write(jsonResponse)
// if err != nil {
// return openai.ErrorWrapper(err, "write_response_body_failed", http.StatusInternalServerError), nil
// }
// return nil, &fullTextResponse.Usage
// }
// func ParseConfig(config string) (appId int64, secretId string, secretKey string, err error) {
// parts := strings.Split(config, "|")
// if len(parts) != 3 {
// err = errors.New("invalid tencent config")
// return
// }
// appId, err = strconv.ParseInt(parts[0], 10, 64)
// secretId = parts[1]
// secretKey = parts[2]
// return
// }
// func GetSign(req ChatRequest, secretKey string) string {
// params := make([]string, 0)
// params = append(params, "app_id="+strconv.FormatInt(req.AppId, 10))
// params = append(params, "secret_id="+req.SecretId)
// params = append(params, "timestamp="+strconv.FormatInt(req.Timestamp, 10))
// params = append(params, "query_id="+req.QueryID)
// params = append(params, "temperature="+strconv.FormatFloat(req.Temperature, 'f', -1, 64))
// params = append(params, "top_p="+strconv.FormatFloat(req.TopP, 'f', -1, 64))
// params = append(params, "stream="+strconv.Itoa(req.Stream))
// params = append(params, "expired="+strconv.FormatInt(req.Expired, 10))
// var messageStr string
// for _, msg := range req.Messages {
// messageStr += fmt.Sprintf(`{"role":"%s","content":"%s"},`, msg.Role, msg.Content)
// }
// messageStr = strings.TrimSuffix(messageStr, ",")
// params = append(params, "messages=["+messageStr+"]")
// sort.Strings(params)
// url := "hunyuan.cloud.tencent.com/hyllm/v1/chat/completions?" + strings.Join(params, "&")
// mac := hmac.New(sha1.New, []byte(secretKey))
// signURL := url
// mac.Write([]byte(signURL))
// sign := mac.Sum([]byte(nil))
// return base64.StdEncoding.EncodeToString(sign)
// }
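The commented-out `GetSign` above documents the legacy Hunyuan v1 signing scheme: sort the query parameters, HMAC-SHA1 the full request URL with the secret key, and base64-encode the digest. A minimal standalone sketch of that scheme (hostname and parameter names taken from the commented code; the key and values are placeholders):

```go
package main

import (
	"crypto/hmac"
	"crypto/sha1"
	"encoding/base64"
	"fmt"
	"sort"
	"strings"
)

// sign reproduces the commented-out scheme: sort the query parameters,
// HMAC-SHA1 the full URL with the secret key, then base64-encode the MAC.
func sign(params []string, secretKey string) string {
	sort.Strings(params)
	url := "hunyuan.cloud.tencent.com/hyllm/v1/chat/completions?" + strings.Join(params, "&")
	mac := hmac.New(sha1.New, []byte(secretKey))
	mac.Write([]byte(url))
	return base64.StdEncoding.EncodeToString(mac.Sum(nil))
}

func main() {
	// Placeholder values, not real credentials.
	fmt.Println(sign([]string{"app_id=1", "timestamp=1700000000"}, "demo-secret"))
}
```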


@@ -0,0 +1,145 @@
package zhipu
// import (
// "github.com/Laisky/errors/v2"
// "fmt"
// "github.com/gin-gonic/gin"
// "github.com/songquanpeng/one-api/relay/adaptor"
// "github.com/songquanpeng/one-api/relay/adaptor/openai"
// "github.com/songquanpeng/one-api/relay/meta"
// "github.com/songquanpeng/one-api/relay/model"
// "github.com/songquanpeng/one-api/relay/relaymode"
// "io"
// "math"
// "net/http"
// "strings"
// )
// type Adaptor struct {
// APIVersion string
// }
// func (a *Adaptor) Init(meta *meta.Meta) {
// }
// func (a *Adaptor) SetVersionByModeName(modelName string) {
// if strings.HasPrefix(modelName, "glm-") {
// a.APIVersion = "v4"
// } else {
// a.APIVersion = "v3"
// }
// }
// func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) {
// switch meta.Mode {
// case relaymode.ImagesGenerations:
// return fmt.Sprintf("%s/api/paas/v4/images/generations", meta.BaseURL), nil
// case relaymode.Embeddings:
// return fmt.Sprintf("%s/api/paas/v4/embeddings", meta.BaseURL), nil
// }
// a.SetVersionByModeName(meta.ActualModelName)
// if a.APIVersion == "v4" {
// return fmt.Sprintf("%s/api/paas/v4/chat/completions", meta.BaseURL), nil
// }
// method := "invoke"
// if meta.IsStream {
// method = "sse-invoke"
// }
// return fmt.Sprintf("%s/api/paas/v3/model-api/%s/%s", meta.BaseURL, meta.ActualModelName, method), nil
// }
// func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error {
// adaptor.SetupCommonRequestHeader(c, req, meta)
// token := GetToken(meta.APIKey)
// req.Header.Set("Authorization", token)
// return nil
// }
// func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) {
// if request == nil {
// return nil, errors.New("request is nil")
// }
// switch relayMode {
// case relaymode.Embeddings:
// baiduEmbeddingRequest := ConvertEmbeddingRequest(*request)
// return baiduEmbeddingRequest, nil
// default:
// // TopP (0.0, 1.0)
// request.TopP = math.Min(0.99, request.TopP)
// request.TopP = math.Max(0.01, request.TopP)
// // Temperature (0.0, 1.0)
// request.Temperature = math.Min(0.99, request.Temperature)
// request.Temperature = math.Max(0.01, request.Temperature)
// a.SetVersionByModeName(request.Model)
// if a.APIVersion == "v4" {
// return request, nil
// }
// return ConvertRequest(*request), nil
// }
// }
// func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) {
// if request == nil {
// return nil, errors.New("request is nil")
// }
// newRequest := ImageRequest{
// Model: request.Model,
// Prompt: request.Prompt,
// UserId: request.User,
// }
// return newRequest, nil
// }
// func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) {
// return adaptor.DoRequestHelper(a, c, meta, requestBody)
// }
// func (a *Adaptor) DoResponseV4(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
// if meta.IsStream {
// err, _, usage = openai.StreamHandler(c, resp, meta.Mode)
// } else {
// err, usage = openai.Handler(c, resp, meta.PromptTokens, meta.ActualModelName)
// }
// return
// }
// func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
// switch meta.Mode {
// case relaymode.Embeddings:
// err, usage = EmbeddingsHandler(c, resp)
// return
// case relaymode.ImagesGenerations:
// err, usage = openai.ImageHandler(c, resp)
// return
// }
// if a.APIVersion == "v4" {
// return a.DoResponseV4(c, resp, meta)
// }
// if meta.IsStream {
// err, usage = StreamHandler(c, resp)
// } else {
// if meta.Mode == relaymode.Embeddings {
// err, usage = EmbeddingsHandler(c, resp)
// } else {
// err, usage = Handler(c, resp)
// }
// }
// return
// }
// func ConvertEmbeddingRequest(request model.GeneralOpenAIRequest) *EmbeddingRequest {
// return &EmbeddingRequest{
// Model: "embedding-2",
// Input: request.Input.(string),
// }
// }
// func (a *Adaptor) GetModelList() []string {
// return ModelList
// }
// func (a *Adaptor) GetChannelName() string {
// return "zhipu"
// }

relay/apitype/define.go Normal file

@@ -0,0 +1,17 @@
package apitype
const (
OpenAI = iota
Anthropic
PaLM
Baidu
Zhipu
Ali
Xunfei
AIProxyLibrary
Tencent
Gemini
Ollama
Dummy // this one is only for count, do not add any channel after this
)
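Because these are bare iota constants, callers resolve behavior by switching on the value; the new `relay.GetAdaptor` is the canonical consumer and returns nil for unmapped API types. A short usage sketch:

```go
package main

import (
	"fmt"
	"log"

	"github.com/songquanpeng/one-api/relay"
	"github.com/songquanpeng/one-api/relay/apitype"
)

func main() {
	// GetAdaptor returns nil when no adaptor is registered for the
	// given API type, so callers must check before use.
	a := relay.GetAdaptor(apitype.Anthropic)
	if a == nil {
		log.Fatalf("no adaptor for api type %d", apitype.Anthropic)
	}
	fmt.Println(a.GetChannelName())
}
```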

relay/billing/billing.go Normal file

@@ -0,0 +1,42 @@
package billing
import (
"context"
"fmt"
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/model"
)
func ReturnPreConsumedQuota(ctx context.Context, preConsumedQuota int64, tokenId int) {
if preConsumedQuota != 0 {
go func(ctx context.Context) {
// return pre-consumed quota
err := model.PostConsumeTokenQuota(tokenId, -preConsumedQuota)
if err != nil {
logger.Error(ctx, "error returning pre-consumed quota: "+err.Error())
}
}(ctx)
}
}
func PostConsumeQuota(ctx context.Context, tokenId int, quotaDelta int64, totalQuota int64, userId int, channelId int, modelRatio float64, groupRatio float64, modelName string, tokenName string) {
// quotaDelta is remaining quota to be consumed
err := model.PostConsumeTokenQuota(tokenId, quotaDelta)
if err != nil {
logger.SysError("error consuming token remain quota: " + err.Error())
}
err = model.CacheUpdateUserQuota(ctx, userId)
if err != nil {
logger.SysError("error update user quota cache: " + err.Error())
}
// totalQuota is total quota consumed
if totalQuota != 0 {
logContent := fmt.Sprintf("model ratio %.2f, group ratio %.2f", modelRatio, groupRatio)
model.RecordConsumeLog(ctx, userId, channelId, int(totalQuota), 0, modelName, tokenName, totalQuota, logContent)
model.UpdateUserUsedQuotaAndRequestCount(userId, totalQuota)
model.UpdateChannelUsedQuota(channelId, totalQuota)
}
if totalQuota <= 0 {
logger.Error(ctx, fmt.Sprintf("totalQuota consumed is %d, something is wrong", totalQuota))
}
}
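A sketch of the intended calling convention, mirroring how the relay controllers below use this package (IDs and ratios are placeholders):

```go
package main

import (
	"context"

	"github.com/songquanpeng/one-api/relay/billing"
)

// settle refunds the pre-consumed quota when the upstream call failed,
// otherwise settles the remainder (quotaDelta) and records total usage.
func settle(ctx context.Context, succeeded bool, preConsumed, total int64) {
	const tokenId, userId, channelId = 1, 2, 3 // placeholder IDs
	if !succeeded {
		billing.ReturnPreConsumedQuota(ctx, preConsumed, tokenId)
		return
	}
	billing.PostConsumeQuota(ctx, tokenId, total-preConsumed, total,
		userId, channelId, 1.0, 1.0, "gpt-3.5-turbo", "demo-token")
}

func main() {
	settle(context.Background(), true, 500, 1200)
}
```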


@@ -0,0 +1,34 @@
package ratio
import (
"encoding/json"
"github.com/songquanpeng/one-api/common/logger"
)
var GroupRatio = map[string]float64{
"default": 1,
"vip": 1,
"svip": 1,
}
func GroupRatio2JSONString() string {
jsonBytes, err := json.Marshal(GroupRatio)
if err != nil {
logger.SysError("error marshalling model ratio: " + err.Error())
}
return string(jsonBytes)
}
func UpdateGroupRatioByJSONString(jsonStr string) error {
GroupRatio = make(map[string]float64)
return json.Unmarshal([]byte(jsonStr), &GroupRatio)
}
func GetGroupRatio(name string) float64 {
ratio, ok := GroupRatio[name]
if !ok {
logger.SysError("group ratio not found: " + name)
return 1
}
return ratio
}
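Group ratios round-trip through JSON so they can be stored as an option string; a short sketch of the update path (values are illustrative):

```go
package main

import (
	"fmt"
	"log"

	"github.com/songquanpeng/one-api/relay/billing/ratio"
)

func main() {
	// Replace the in-memory table from a stored JSON option; unknown
	// groups fall back to 1 inside GetGroupRatio.
	if err := ratio.UpdateGroupRatioByJSONString(`{"default":1,"vip":0.8}`); err != nil {
		log.Fatal(err)
	}
	fmt.Println(ratio.GetGroupRatio("vip"))     // 0.8
	fmt.Println(ratio.GetGroupRatio("unknown")) // 1 (logged as missing)
}
```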


@@ -0,0 +1,51 @@
package ratio
var ImageSizeRatios = map[string]map[string]float64{
"dall-e-2": {
"256x256": 1,
"512x512": 1.125,
"1024x1024": 1.25,
},
"dall-e-3": {
"1024x1024": 1,
"1024x1792": 2,
"1792x1024": 2,
},
"ali-stable-diffusion-xl": {
"512x1024": 1,
"1024x768": 1,
"1024x1024": 1,
"576x1024": 1,
"1024x576": 1,
},
"ali-stable-diffusion-v1.5": {
"512x1024": 1,
"1024x768": 1,
"1024x1024": 1,
"576x1024": 1,
"1024x576": 1,
},
"wanx-v1": {
"1024x1024": 1,
"720x1280": 1,
"1280x720": 1,
},
}
var ImageGenerationAmounts = map[string][2]int{
"dall-e-2": {1, 10},
"dall-e-3": {1, 1}, // OpenAI allows n=1 currently.
"ali-stable-diffusion-xl": {1, 4}, // Ali
"ali-stable-diffusion-v1.5": {1, 4}, // Ali
"wanx-v1": {1, 4}, // Ali
"cogview-3": {1, 1},
}
var ImagePromptLengthLimitations = map[string]int{
"dall-e-2": 1000,
"dall-e-3": 4000,
"ali-stable-diffusion-xl": 4000,
"ali-stable-diffusion-v1.5": 4000,
"wanx-v1": 4000,
"cogview-3": 833,
}
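These three tables drive image billing: a per-size price multiplier, the allowed range of `n`, and a prompt-length cap. A worked lookup, assuming the `relay/billing/ratio` import path used elsewhere in this commit:

```go
package main

import (
	"fmt"

	"github.com/songquanpeng/one-api/relay/billing/ratio"
)

func main() {
	// A 1792x1024 dall-e-3 image bills at twice the 1024x1024 base price.
	fmt.Println(ratio.ImageSizeRatios["dall-e-3"]["1792x1024"]) // 2
	// dall-e-3 accepts only n=1; dall-e-2 accepts 1..10.
	fmt.Println(ratio.ImageGenerationAmounts["dall-e-3"]) // [1 1]
	// Prompts longer than this are rejected before the upstream call.
	fmt.Println(ratio.ImagePromptLengthLimitations["dall-e-3"]) // 4000
}
```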


@@ -0,0 +1,281 @@
package ratio
import (
"encoding/json"
"strings"
"github.com/songquanpeng/one-api/common/logger"
)
const (
USD2RMB = 7
USD = 500 // $0.002 = 1 -> $1 = 500
RMB = USD / USD2RMB
)
// ModelRatio
// https://platform.openai.com/docs/models/model-endpoint-compatibility
// https://cloud.baidu.com/doc/WENXINWORKSHOP/s/Blfmc9dlf
// https://openai.com/pricing
// 1 === $0.002 / 1K tokens
// 1 === ¥0.014 / 1k tokens
var ModelRatio = map[string]float64{
// https://openai.com/pricing
"gpt-4": 15,
"gpt-4-0314": 15,
"gpt-4-0613": 15,
"gpt-4-32k": 30,
"gpt-4-32k-0314": 30,
"gpt-4-32k-0613": 30,
"gpt-4-1106-preview": 5, // $0.01 / 1K tokens
"gpt-4-0125-preview": 5, // $0.01 / 1K tokens
"gpt-4-turbo-preview": 5, // $0.01 / 1K tokens
"gpt-4-vision-preview": 5, // $0.01 / 1K tokens
"gpt-3.5-turbo": 0.25, // $0.0005 / 1K tokens
"gpt-3.5-turbo-0301": 0.75,
"gpt-3.5-turbo-0613": 0.75,
"gpt-3.5-turbo-16k": 1.5, // $0.003 / 1K tokens
"gpt-3.5-turbo-16k-0613": 1.5,
"gpt-3.5-turbo-instruct": 0.75, // $0.0015 / 1K tokens
"gpt-3.5-turbo-1106": 0.5, // $0.001 / 1K tokens
"gpt-3.5-turbo-0125": 0.25, // $0.0005 / 1K tokens
"davinci-002": 1, // $0.002 / 1K tokens
"babbage-002": 0.2, // $0.0004 / 1K tokens
"text-ada-001": 0.2,
"text-babbage-001": 0.25,
"text-curie-001": 1,
"text-davinci-002": 10,
"text-davinci-003": 10,
"text-davinci-edit-001": 10,
"code-davinci-edit-001": 10,
"whisper-1": 15, // $0.006 / minute -> $0.006 / 150 words -> $0.006 / 200 tokens -> $0.03 / 1k tokens
"tts-1": 7.5, // $0.015 / 1K characters
"tts-1-1106": 7.5,
"tts-1-hd": 15, // $0.030 / 1K characters
"tts-1-hd-1106": 15,
"davinci": 10,
"curie": 10,
"babbage": 10,
"ada": 10,
"text-embedding-ada-002": 0.05,
"text-embedding-3-small": 0.01,
"text-embedding-3-large": 0.065,
"text-search-ada-doc-001": 10,
"text-moderation-stable": 0.1,
"text-moderation-latest": 0.1,
"dall-e-2": 0.02 * USD, // $0.016 - $0.020 / image
"dall-e-3": 0.04 * USD, // $0.040 - $0.120 / image
// https://www.anthropic.com/api#pricing
"claude-instant-1.2": 0.8 / 1000 * USD,
"claude-2.0": 8.0 / 1000 * USD,
"claude-2.1": 8.0 / 1000 * USD,
"claude-3-haiku-20240307": 0.25 / 1000 * USD,
"claude-3-sonnet-20240229": 3.0 / 1000 * USD,
"claude-3-opus-20240229": 15.0 / 1000 * USD,
// https://cloud.baidu.com/doc/WENXINWORKSHOP/s/hlrk4akp7
"ERNIE-4.0-8K": 0.120 * RMB,
"ERNIE-3.5-8K": 0.012 * RMB,
"ERNIE-3.5-8K-0205": 0.024 * RMB,
"ERNIE-3.5-8K-1222": 0.012 * RMB,
"ERNIE-Bot-8K": 0.024 * RMB,
"ERNIE-3.5-4K-0205": 0.012 * RMB,
"ERNIE-Speed-8K": 0.004 * RMB,
"ERNIE-Speed-128K": 0.004 * RMB,
"ERNIE-Lite-8K-0922": 0.008 * RMB,
"ERNIE-Lite-8K-0308": 0.003 * RMB,
"ERNIE-Tiny-8K": 0.001 * RMB,
"BLOOMZ-7B": 0.004 * RMB,
"Embedding-V1": 0.002 * RMB,
"bge-large-zh": 0.002 * RMB,
"bge-large-en": 0.002 * RMB,
"tao-8k": 0.002 * RMB,
// https://ai.google.dev/pricing
"PaLM-2": 1,
"gemini-pro": 1, // $0.00025 / 1k characters -> $0.001 / 1k tokens
"gemini-pro-vision": 1, // $0.00025 / 1k characters -> $0.001 / 1k tokens
"gemini-1.0-pro-vision-001": 1,
"gemini-1.0-pro-001": 1,
"gemini-1.5-pro": 1,
// https://open.bigmodel.cn/pricing
"glm-4": 0.1 * RMB,
"glm-4v": 0.1 * RMB,
"glm-3-turbo": 0.005 * RMB,
"embedding-2": 0.0005 * RMB,
"chatglm_turbo": 0.3572, // ¥0.005 / 1k tokens
"chatglm_pro": 0.7143, // ¥0.01 / 1k tokens
"chatglm_std": 0.3572, // ¥0.005 / 1k tokens
"chatglm_lite": 0.1429, // ¥0.002 / 1k tokens
"cogview-3": 0.25 * RMB,
// https://help.aliyun.com/zh/dashscope/developer-reference/tongyi-thousand-questions-metering-and-billing
"qwen-turbo": 0.5715, // ¥0.008 / 1k tokens
"qwen-plus": 1.4286, // ¥0.02 / 1k tokens
"qwen-max": 1.4286, // ¥0.02 / 1k tokens
"qwen-max-longcontext": 1.4286, // ¥0.02 / 1k tokens
"text-embedding-v1": 0.05, // ¥0.0007 / 1k tokens
"ali-stable-diffusion-xl": 8,
"ali-stable-diffusion-v1.5": 8,
"wanx-v1": 8,
"SparkDesk": 1.2858, // ¥0.018 / 1k tokens
"SparkDesk-v1.1": 1.2858, // ¥0.018 / 1k tokens
"SparkDesk-v2.1": 1.2858, // ¥0.018 / 1k tokens
"SparkDesk-v3.1": 1.2858, // ¥0.018 / 1k tokens
"SparkDesk-v3.5": 1.2858, // ¥0.018 / 1k tokens
"360GPT_S2_V9": 0.8572, // ¥0.012 / 1k tokens
"embedding-bert-512-v1": 0.0715, // ¥0.001 / 1k tokens
"embedding_s1_v1": 0.0715, // ¥0.001 / 1k tokens
"semantic_similarity_s1_v1": 0.0715, // ¥0.001 / 1k tokens
"hunyuan": 7.143, // ¥0.1 / 1k tokens // https://cloud.tencent.com/document/product/1729/97731#e0e6be58-60c8-469f-bdeb-6c264ce3b4d0
"ChatStd": 0.01 * RMB,
"ChatPro": 0.1 * RMB,
// https://platform.moonshot.cn/pricing
"moonshot-v1-8k": 0.012 * RMB,
"moonshot-v1-32k": 0.024 * RMB,
"moonshot-v1-128k": 0.06 * RMB,
// https://platform.baichuan-ai.com/price
"Baichuan2-Turbo": 0.008 * RMB,
"Baichuan2-Turbo-192k": 0.016 * RMB,
"Baichuan2-53B": 0.02 * RMB,
// https://api.minimax.chat/document/price
"abab6-chat": 0.1 * RMB,
"abab5.5-chat": 0.015 * RMB,
"abab5.5s-chat": 0.005 * RMB,
// https://docs.mistral.ai/platform/pricing/
"open-mistral-7b": 0.25 / 1000 * USD,
"open-mixtral-8x7b": 0.7 / 1000 * USD,
"mistral-small-latest": 2.0 / 1000 * USD,
"mistral-medium-latest": 2.7 / 1000 * USD,
"mistral-large-latest": 8.0 / 1000 * USD,
"mistral-embed": 0.1 / 1000 * USD,
// https://wow.groq.com/
"llama2-70b-4096": 0.7 / 1000 * USD,
"llama2-7b-2048": 0.1 / 1000 * USD,
"mixtral-8x7b-32768": 0.27 / 1000 * USD,
"gemma-7b-it": 0.1 / 1000 * USD,
// https://platform.lingyiwanwu.com/docs#-计费单元
"yi-34b-chat-0205": 2.5 / 1000 * RMB,
"yi-34b-chat-200k": 12.0 / 1000 * RMB,
"yi-vl-plus": 6.0 / 1000 * RMB,
}
var CompletionRatio = map[string]float64{}
var DefaultModelRatio map[string]float64
var DefaultCompletionRatio map[string]float64
func init() {
DefaultModelRatio = make(map[string]float64)
for k, v := range ModelRatio {
DefaultModelRatio[k] = v
}
DefaultCompletionRatio = make(map[string]float64)
for k, v := range CompletionRatio {
DefaultCompletionRatio[k] = v
}
}
func AddNewMissingRatio(oldRatio string) string {
newRatio := make(map[string]float64)
err := json.Unmarshal([]byte(oldRatio), &newRatio)
if err != nil {
logger.SysError("error unmarshalling old ratio: " + err.Error())
return oldRatio
}
for k, v := range DefaultModelRatio {
if _, ok := newRatio[k]; !ok {
newRatio[k] = v
}
}
jsonBytes, err := json.Marshal(newRatio)
if err != nil {
logger.SysError("error marshalling new ratio: " + err.Error())
return oldRatio
}
return string(jsonBytes)
}
func ModelRatio2JSONString() string {
jsonBytes, err := json.Marshal(ModelRatio)
if err != nil {
logger.SysError("error marshalling model ratio: " + err.Error())
}
return string(jsonBytes)
}
func UpdateModelRatioByJSONString(jsonStr string) error {
ModelRatio = make(map[string]float64)
return json.Unmarshal([]byte(jsonStr), &ModelRatio)
}
func GetModelRatio(name string) float64 {
if strings.HasPrefix(name, "qwen-") && strings.HasSuffix(name, "-internet") {
name = strings.TrimSuffix(name, "-internet")
}
ratio, ok := ModelRatio[name]
if !ok {
ratio, ok = DefaultModelRatio[name]
}
if !ok {
logger.SysError("model ratio not found: " + name)
return 30
}
return ratio
}
func CompletionRatio2JSONString() string {
jsonBytes, err := json.Marshal(CompletionRatio)
if err != nil {
logger.SysError("error marshalling completion ratio: " + err.Error())
}
return string(jsonBytes)
}
func UpdateCompletionRatioByJSONString(jsonStr string) error {
CompletionRatio = make(map[string]float64)
return json.Unmarshal([]byte(jsonStr), &CompletionRatio)
}
// GetCompletionRatio returns the completion ratio of a model.
//
// The completion ratio is the price of completion tokens relative to prompt tokens.
func GetCompletionRatio(name string) float64 {
if ratio, ok := CompletionRatio[name]; ok {
return ratio
}
if ratio, ok := DefaultCompletionRatio[name]; ok {
return ratio
}
if strings.HasPrefix(name, "gpt-3.5") {
if name == "gpt-3.5-turbo" || strings.HasSuffix(name, "0125") {
// https://openai.com/blog/new-embedding-models-and-api-updates
// Updated GPT-3.5 Turbo model and lower pricing
return 3
}
if strings.HasSuffix(name, "1106") {
return 2
}
return 4.0 / 3.0
}
if strings.HasPrefix(name, "gpt-4") {
if strings.HasSuffix(name, "preview") {
return 3
}
return 2
}
if strings.HasPrefix(name, "claude-3") {
return 5
}
if strings.HasPrefix(name, "claude-") {
return 3
}
if strings.HasPrefix(name, "mistral-") {
return 3
}
if strings.HasPrefix(name, "gemini-") {
return 3
}
switch name {
case "llama2-70b-4096":
return 0.8 / 0.7
}
return 1
}
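Putting the two ratio kinds together, the controllers below compute quota as ceil((promptTokens + completionTokens × completionRatio) × modelRatio × groupRatio). A worked example for gpt-4 (model ratio 15, completion ratio 2, default group ratio 1):

```go
package main

import (
	"fmt"
	"math"

	"github.com/songquanpeng/one-api/relay/billing/ratio"
)

func main() {
	promptTokens, completionTokens := 1000.0, 500.0
	groupRatio := 1.0
	mr := ratio.GetModelRatio("gpt-4")      // 15
	cr := ratio.GetCompletionRatio("gpt-4") // 2 (gpt-4, non-preview)
	quota := int64(math.Ceil((promptTokens + completionTokens*cr) * mr * groupRatio))
	// (1000 + 500*2) * 15 = 30000 quota; at 1 ratio unit == $0.002 / 1K
	// tokens that is $0.03 prompt + $0.03 completion = $0.06.
	fmt.Println(quota)
}
```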


@@ -1,83 +0,0 @@
package ali
// import (
// "github.com/Laisky/errors/v2"
// "fmt"
// "github.com/gin-gonic/gin"
// "github.com/songquanpeng/one-api/common"
// "github.com/songquanpeng/one-api/relay/channel"
// "github.com/songquanpeng/one-api/relay/constant"
// "github.com/songquanpeng/one-api/relay/model"
// "github.com/songquanpeng/one-api/relay/util"
// "io"
// "net/http"
// )
// // https://help.aliyun.com/zh/dashscope/developer-reference/api-details
// type Adaptor struct {
// }
// func (a *Adaptor) Init(meta *util.RelayMeta) {
// }
// func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) {
// fullRequestURL := fmt.Sprintf("%s/api/v1/services/aigc/text-generation/generation", meta.BaseURL)
// if meta.Mode == constant.RelayModeEmbeddings {
// fullRequestURL = fmt.Sprintf("%s/api/v1/services/embeddings/text-embedding/text-embedding", meta.BaseURL)
// }
// return fullRequestURL, nil
// }
// func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *util.RelayMeta) error {
// channel.SetupCommonRequestHeader(c, req, meta)
// req.Header.Set("Authorization", "Bearer "+meta.APIKey)
// if meta.IsStream {
// req.Header.Set("X-DashScope-SSE", "enable")
// }
// if c.GetString(common.ConfigKeyPlugin) != "" {
// req.Header.Set("X-DashScope-Plugin", c.GetString(common.ConfigKeyPlugin))
// }
// return nil
// }
// func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) {
// if request == nil {
// return nil, errors.New("request is nil")
// }
// switch relayMode {
// case constant.RelayModeEmbeddings:
// baiduEmbeddingRequest := ConvertEmbeddingRequest(*request)
// return baiduEmbeddingRequest, nil
// default:
// baiduRequest := ConvertRequest(*request)
// return baiduRequest, nil
// }
// }
// func (a *Adaptor) DoRequest(c *gin.Context, meta *util.RelayMeta, requestBody io.Reader) (*http.Response, error) {
// return channel.DoRequestHelper(a, c, meta, requestBody)
// }
// func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.RelayMeta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
// if meta.IsStream {
// err, usage = StreamHandler(c, resp)
// } else {
// switch meta.Mode {
// case constant.RelayModeEmbeddings:
// err, usage = EmbeddingHandler(c, resp)
// default:
// err, usage = Handler(c, resp)
// }
// }
// return
// }
// func (a *Adaptor) GetModelList() []string {
// return ModelList
// }
// func (a *Adaptor) GetChannelName() string {
// return "ali"
// }


@@ -1,71 +0,0 @@
package ali
// type Message struct {
// Content string `json:"content"`
// Role string `json:"role"`
// }
// type Input struct {
// //Prompt string `json:"prompt"`
// Messages []Message `json:"messages"`
// }
// type Parameters struct {
// TopP float64 `json:"top_p,omitempty"`
// TopK int `json:"top_k,omitempty"`
// Seed uint64 `json:"seed,omitempty"`
// EnableSearch bool `json:"enable_search,omitempty"`
// IncrementalOutput bool `json:"incremental_output,omitempty"`
// }
// type ChatRequest struct {
// Model string `json:"model"`
// Input Input `json:"input"`
// Parameters Parameters `json:"parameters,omitempty"`
// }
// type EmbeddingRequest struct {
// Model string `json:"model"`
// Input struct {
// Texts []string `json:"texts"`
// } `json:"input"`
// Parameters *struct {
// TextType string `json:"text_type,omitempty"`
// } `json:"parameters,omitempty"`
// }
// type Embedding struct {
// Embedding []float64 `json:"embedding"`
// TextIndex int `json:"text_index"`
// }
// type EmbeddingResponse struct {
// Output struct {
// Embeddings []Embedding `json:"embeddings"`
// } `json:"output"`
// Usage Usage `json:"usage"`
// Error
// }
// type Error struct {
// Code string `json:"code"`
// Message string `json:"message"`
// RequestId string `json:"request_id"`
// }
// type Usage struct {
// InputTokens int `json:"input_tokens"`
// OutputTokens int `json:"output_tokens"`
// TotalTokens int `json:"total_tokens"`
// }
// type Output struct {
// Text string `json:"text"`
// FinishReason string `json:"finish_reason"`
// }
// type ChatResponse struct {
// Output Output `json:"output"`
// Usage Usage `json:"usage"`
// Error
// }


@@ -1,13 +0,0 @@
package baidu
var ModelList = []string{
"ERNIE-Bot-4",
"ERNIE-Bot-8K",
"ERNIE-Bot",
"ERNIE-Speed",
"ERNIE-Bot-turbo",
"Embedding-V1",
"bge-large-zh",
"bge-large-en",
"tao-8k",
}


@@ -1,20 +0,0 @@
package channel
import (
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/relay/model"
"github.com/songquanpeng/one-api/relay/util"
"io"
"net/http"
)
type Adaptor interface {
Init(meta *util.RelayMeta)
GetRequestURL(meta *util.RelayMeta) (string, error)
SetupRequestHeader(c *gin.Context, req *http.Request, meta *util.RelayMeta) error
ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error)
DoRequest(c *gin.Context, meta *util.RelayMeta, requestBody io.Reader) (*http.Response, error)
DoResponse(c *gin.Context, resp *http.Response, meta *util.RelayMeta) (usage *model.Usage, err *model.ErrorWithStatusCode)
GetModelList() []string
GetChannelName() string
}


@@ -1,16 +0,0 @@
package minimax
import (
"fmt"
"github.com/Laisky/errors/v2"
"github.com/songquanpeng/one-api/relay/constant"
"github.com/songquanpeng/one-api/relay/util"
)
func GetRequestURL(meta *util.RelayMeta) (string, error) {
if meta.Mode == constant.RelayModeChatCompletions {
return fmt.Sprintf("%s/v1/text/chatcompletion_v2", meta.BaseURL), nil
}
return "", errors.Errorf("unsupported relay mode %d for minimax", meta.Mode)
}


@@ -1,46 +0,0 @@
package openai
import (
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/relay/channel/ai360"
"github.com/songquanpeng/one-api/relay/channel/baichuan"
"github.com/songquanpeng/one-api/relay/channel/groq"
"github.com/songquanpeng/one-api/relay/channel/lingyiwanwu"
"github.com/songquanpeng/one-api/relay/channel/minimax"
"github.com/songquanpeng/one-api/relay/channel/mistral"
"github.com/songquanpeng/one-api/relay/channel/moonshot"
)
var CompatibleChannels = []int{
common.ChannelTypeAzure,
common.ChannelType360,
common.ChannelTypeMoonshot,
common.ChannelTypeBaichuan,
common.ChannelTypeMinimax,
common.ChannelTypeMistral,
common.ChannelTypeGroq,
common.ChannelTypeLingYiWanWu,
}
func GetCompatibleChannelMeta(channelType int) (string, []string) {
switch channelType {
case common.ChannelTypeAzure:
return "azure", ModelList
case common.ChannelType360:
return "360", ai360.ModelList
case common.ChannelTypeMoonshot:
return "moonshot", moonshot.ModelList
case common.ChannelTypeBaichuan:
return "baichuan", baichuan.ModelList
case common.ChannelTypeMinimax:
return "minimax", minimax.ModelList
case common.ChannelTypeMistral:
return "mistralai", mistral.ModelList
case common.ChannelTypeGroq:
return "groq", groq.ModelList
case common.ChannelTypeLingYiWanWu:
return "lingyiwanwu", lingyiwanwu.ModelList
default:
return "openai", ModelList
}
}


@@ -1,11 +0,0 @@
package openai
import "github.com/songquanpeng/one-api/relay/model"
func ResponseText2Usage(responseText string, modeName string, promptTokens int) *model.Usage {
usage := &model.Usage{}
usage.PromptTokens = promptTokens
usage.CompletionTokens = CountTokenText(responseText, modeName)
usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
return usage
}


@@ -1,62 +0,0 @@
package zhipu
// import (
// "github.com/Laisky/errors/v2"
// "fmt"
// "github.com/gin-gonic/gin"
// "github.com/songquanpeng/one-api/relay/channel"
// "github.com/songquanpeng/one-api/relay/model"
// "github.com/songquanpeng/one-api/relay/util"
// "io"
// "net/http"
// )
// type Adaptor struct {
// }
// func (a *Adaptor) Init(meta *util.RelayMeta) {
// }
// func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) {
// method := "invoke"
// if meta.IsStream {
// method = "sse-invoke"
// }
// return fmt.Sprintf("%s/api/paas/v3/model-api/%s/%s", meta.BaseURL, meta.ActualModelName, method), nil
// }
// func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *util.RelayMeta) error {
// channel.SetupCommonRequestHeader(c, req, meta)
// token := GetToken(meta.APIKey)
// req.Header.Set("Authorization", token)
// return nil
// }
// func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) {
// if request == nil {
// return nil, errors.New("request is nil")
// }
// return ConvertRequest(*request), nil
// }
// func (a *Adaptor) DoRequest(c *gin.Context, meta *util.RelayMeta, requestBody io.Reader) (*http.Response, error) {
// return channel.DoRequestHelper(a, c, meta, requestBody)
// }
// func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.RelayMeta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
// if meta.IsStream {
// err, usage = StreamHandler(c, resp)
// } else {
// err, usage = Handler(c, resp)
// }
// return
// }
// func (a *Adaptor) GetModelList() []string {
// return ModelList
// }
// func (a *Adaptor) GetChannelName() string {
// return "zhipu"
// }


@@ -0,0 +1,39 @@
package channeltype
const (
Unknown = iota
OpenAI
API2D
Azure
CloseAI
OpenAISB
OpenAIMax
OhMyGPT
Custom
Ails
AIProxy
PaLM
API2GPT
AIGC2D
Anthropic
Baidu
Zhipu
Ali
Xunfei
AI360
OpenRouter
AIProxyLibrary
FastGPT
Tencent
Gemini
Moonshot
Baichuan
Minimax
Mistral
Groq
Ollama
LingYiWanWu
StepFun
Dummy
)


@@ -0,0 +1,30 @@
package channeltype
import "github.com/songquanpeng/one-api/relay/apitype"
func ToAPIType(channelType int) int {
apiType := apitype.OpenAI
switch channelType {
case Anthropic:
apiType = apitype.Anthropic
case Baidu:
apiType = apitype.Baidu
case PaLM:
apiType = apitype.PaLM
case Zhipu:
apiType = apitype.Zhipu
case Ali:
apiType = apitype.Ali
case Xunfei:
apiType = apitype.Xunfei
case AIProxyLibrary:
apiType = apitype.AIProxyLibrary
case Tencent:
apiType = apitype.Tencent
case Gemini:
apiType = apitype.Gemini
case Ollama:
apiType = apitype.Ollama
}
return apiType
}
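Note that channel types and API types are distinct enums: every OpenAI-compatible channel (Azure, Moonshot, Groq, and so on) falls through to the default `apitype.OpenAI`, while vendors with dedicated adaptors get their own API type. A quick check:

```go
package main

import (
	"fmt"

	"github.com/songquanpeng/one-api/relay/apitype"
	"github.com/songquanpeng/one-api/relay/channeltype"
)

func main() {
	// Azure is OpenAI-compatible, so it maps to the default API type.
	fmt.Println(channeltype.ToAPIType(channeltype.Azure) == apitype.OpenAI) // true
	// Vendors with dedicated adaptors map to their own API types.
	fmt.Println(channeltype.ToAPIType(channeltype.Baidu) == apitype.Baidu) // true
}
```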

relay/channeltype/url.go Normal file

@@ -0,0 +1,43 @@
package channeltype
var ChannelBaseURLs = []string{
"", // 0
"https://api.openai.com", // 1
"https://oa.api2d.net", // 2
"", // 3
"https://api.closeai-proxy.xyz", // 4
"https://api.openai-sb.com", // 5
"https://api.openaimax.com", // 6
"https://api.ohmygpt.com", // 7
"", // 8
"https://api.caipacity.com", // 9
"https://api.aiproxy.io", // 10
"https://generativelanguage.googleapis.com", // 11
"https://api.api2gpt.com", // 12
"https://api.aigc2d.com", // 13
"https://api.anthropic.com", // 14
"https://aip.baidubce.com", // 15
"https://open.bigmodel.cn", // 16
"https://dashscope.aliyuncs.com", // 17
"", // 18
"https://ai.360.cn", // 19
"https://openrouter.ai/api", // 20
"https://api.aiproxy.io", // 21
"https://fastgpt.run/api/openapi", // 22
"https://hunyuan.cloud.tencent.com", // 23
"https://generativelanguage.googleapis.com", // 24
"https://api.moonshot.cn", // 25
"https://api.baichuan-ai.com", // 26
"https://api.minimax.chat", // 27
"https://api.mistral.ai", // 28
"https://api.groq.com/openai", // 29
"http://localhost:11434", // 30
"https://api.lingyiwanwu.com", // 31
"https://api.stepfun.com", // 32
}
func init() {
if len(ChannelBaseURLs) != Dummy {
panic("channel base URLs length does not match the number of channel types")
}
}
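The init check pins the URL table to the channel enum: adding a channel type without a base URL (or vice versa) fails at startup rather than at request time. A hypothetical test giving the same guarantee with a clearer failure message:

```go
package channeltype_test

import (
	"testing"

	"github.com/songquanpeng/one-api/relay/channeltype"
)

func TestChannelBaseURLsAligned(t *testing.T) {
	if got, want := len(channeltype.ChannelBaseURLs), channeltype.Dummy; got != want {
		t.Fatalf("ChannelBaseURLs has %d entries, want %d (one per channel type)", got, want)
	}
}
```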


@@ -1,4 +1,4 @@
package util
package client
import (
"net/http"


@@ -1,48 +0,0 @@
package constant
import (
"github.com/songquanpeng/one-api/common"
)
const (
APITypeOpenAI = iota
APITypeAnthropic
APITypePaLM
APITypeBaidu
APITypeZhipu
APITypeAli
APITypeXunfei
APITypeAIProxyLibrary
APITypeTencent
APITypeGemini
APITypeOllama
APITypeDummy // this one is only for count, do not add any channel after this
)
func ChannelType2APIType(channelType int) int {
apiType := APITypeOpenAI
switch channelType {
case common.ChannelTypeAnthropic:
apiType = APITypeAnthropic
case common.ChannelTypeBaidu:
apiType = APITypeBaidu
case common.ChannelTypePaLM:
apiType = APITypePaLM
case common.ChannelTypeZhipu:
apiType = APITypeZhipu
case common.ChannelTypeAli:
apiType = APITypeAli
case common.ChannelTypeXunfei:
apiType = APITypeXunfei
case common.ChannelTypeAIProxyLibrary:
apiType = APITypeAIProxyLibrary
case common.ChannelTypeTencent:
apiType = APITypeTencent
case common.ChannelTypeGemini:
apiType = APITypeGemini
case common.ChannelTypeOllama:
apiType = APITypeOllama
}
return apiType
}


@@ -1,24 +0,0 @@
package constant
var DalleSizeRatios = map[string]map[string]float64{
"dall-e-2": {
"256x256": 1,
"512x512": 1.125,
"1024x1024": 1.25,
},
"dall-e-3": {
"1024x1024": 1,
"1024x1792": 2,
"1792x1024": 2,
},
}
var DalleGenerationImageAmounts = map[string][2]int{
"dall-e-2": {1, 10},
"dall-e-3": {1, 1}, // OpenAI allows n=1 currently.
}
var DalleImagePromptLengthLimitations = map[string]int{
"dall-e-2": 1000,
"dall-e-3": 4000,
}


@@ -1,42 +0,0 @@
package constant
import "strings"
const (
RelayModeUnknown = iota
RelayModeChatCompletions
RelayModeCompletions
RelayModeEmbeddings
RelayModeModerations
RelayModeImagesGenerations
RelayModeEdits
RelayModeAudioSpeech
RelayModeAudioTranscription
RelayModeAudioTranslation
)
func Path2RelayMode(path string) int {
relayMode := RelayModeUnknown
if strings.HasPrefix(path, "/v1/chat/completions") {
relayMode = RelayModeChatCompletions
} else if strings.HasPrefix(path, "/v1/completions") {
relayMode = RelayModeCompletions
} else if strings.HasPrefix(path, "/v1/embeddings") {
relayMode = RelayModeEmbeddings
} else if strings.HasSuffix(path, "embeddings") {
relayMode = RelayModeEmbeddings
} else if strings.HasPrefix(path, "/v1/moderations") {
relayMode = RelayModeModerations
} else if strings.HasPrefix(path, "/v1/images/generations") {
relayMode = RelayModeImagesGenerations
} else if strings.HasPrefix(path, "/v1/edits") {
relayMode = RelayModeEdits
} else if strings.HasPrefix(path, "/v1/audio/speech") {
relayMode = RelayModeAudioSpeech
} else if strings.HasPrefix(path, "/v1/audio/transcriptions") {
relayMode = RelayModeAudioTranscription
} else if strings.HasPrefix(path, "/v1/audio/translations") {
relayMode = RelayModeAudioTranslation
}
return relayMode
}


@@ -6,20 +6,23 @@ import (
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"strings"
"github.com/Laisky/errors/v2"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/model"
"github.com/songquanpeng/one-api/relay/channel/openai"
"github.com/songquanpeng/one-api/relay/constant"
"github.com/songquanpeng/one-api/relay/adaptor/azure"
"github.com/songquanpeng/one-api/relay/adaptor/openai"
"github.com/songquanpeng/one-api/relay/billing"
billingratio "github.com/songquanpeng/one-api/relay/billing/ratio"
"github.com/songquanpeng/one-api/relay/channeltype"
"github.com/songquanpeng/one-api/relay/client"
relaymodel "github.com/songquanpeng/one-api/relay/model"
"github.com/songquanpeng/one-api/relay/util"
"github.com/songquanpeng/one-api/relay/relaymode"
"io"
"net/http"
"strings"
)
func RelayAudioHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatusCode {
@@ -34,7 +37,7 @@ func RelayAudioHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus
tokenName := c.GetString("token_name")
var ttsRequest openai.TextToSpeechRequest
if relayMode == constant.RelayModeAudioSpeech {
if relayMode == relaymode.AudioSpeech {
// Read JSON
err := common.UnmarshalBodyReusable(c, &ttsRequest)
// Check if JSON is valid
@@ -48,14 +51,15 @@ func RelayAudioHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus
}
}
modelRatio := common.GetModelRatio(audioModel)
// groupRatio := common.GetGroupRatio(group)
groupRatio := c.GetFloat64("channel_ratio")
modelRatio := billingratio.GetModelRatio(audioModel)
// groupRatio := billingratio.GetGroupRatio(group)
groupRatio := c.GetFloat64("channel_ratio") // get minimal ratio from multiple groups
ratio := modelRatio * groupRatio
var quota int64
var preConsumedQuota int64
switch relayMode {
case constant.RelayModeAudioSpeech:
case relaymode.AudioSpeech:
preConsumedQuota = int64(float64(len(ttsRequest.Input)) * ratio)
quota = preConsumedQuota
default:
@@ -117,19 +121,19 @@ func RelayAudioHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus
}
}
baseURL := common.ChannelBaseURLs[channelType]
baseURL := channeltype.ChannelBaseURLs[channelType]
requestURL := c.Request.URL.String()
if c.GetString("base_url") != "" {
baseURL = c.GetString("base_url")
}
fullRequestURL := util.GetFullRequestURL(baseURL, requestURL, channelType)
if channelType == common.ChannelTypeAzure {
apiVersion := util.GetAzureAPIVersion(c)
if relayMode == constant.RelayModeAudioTranscription {
fullRequestURL := openai.GetFullRequestURL(baseURL, requestURL, channelType)
if channelType == channeltype.Azure {
apiVersion := azure.GetAPIVersion(c)
if relayMode == relaymode.AudioTranscription {
// https://learn.microsoft.com/en-us/azure/ai-services/openai/whisper-quickstart?tabs=command-line#rest-api
fullRequestURL = fmt.Sprintf("%s/openai/deployments/%s/audio/transcriptions?api-version=%s", baseURL, audioModel, apiVersion)
} else if relayMode == constant.RelayModeAudioSpeech {
} else if relayMode == relaymode.AudioSpeech {
// https://learn.microsoft.com/en-us/azure/ai-services/openai/text-to-speech-quickstart?tabs=command-line#rest-api
fullRequestURL = fmt.Sprintf("%s/openai/deployments/%s/audio/speech?api-version=%s", baseURL, audioModel, apiVersion)
}
@@ -148,7 +152,7 @@ func RelayAudioHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus
return openai.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError)
}
if (relayMode == constant.RelayModeAudioTranscription || relayMode == constant.RelayModeAudioSpeech) && channelType == common.ChannelTypeAzure {
if (relayMode == relaymode.AudioTranscription || relayMode == relaymode.AudioSpeech) && channelType == channeltype.Azure {
// https://learn.microsoft.com/en-us/azure/ai-services/openai/whisper-quickstart?tabs=command-line#rest-api
apiKey := c.Request.Header.Get("Authorization")
apiKey = strings.TrimPrefix(apiKey, "Bearer ")
@@ -160,7 +164,7 @@ func RelayAudioHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus
req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
req.Header.Set("Accept", c.Request.Header.Get("Accept"))
resp, err := util.HTTPClient.Do(req)
resp, err := client.HTTPClient.Do(req)
if err != nil {
return openai.ErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)
}
@@ -174,7 +178,7 @@ func RelayAudioHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus
return openai.ErrorWrapper(err, "close_request_body_failed", http.StatusInternalServerError)
}
if relayMode != constant.RelayModeAudioSpeech {
if relayMode != relaymode.AudioSpeech {
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
@@ -213,12 +217,12 @@ func RelayAudioHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus
resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
}
if resp.StatusCode != http.StatusOK {
return util.RelayErrorHandler(resp)
return RelayErrorHandler(resp)
}
succeed = true
quotaDelta := quota - preConsumedQuota
defer func(ctx context.Context) {
go util.PostConsumeQuota(ctx, tokenId, quotaDelta, quota, userId, channelId, modelRatio, groupRatio, audioModel, tokenName)
go billing.PostConsumeQuota(ctx, tokenId, quotaDelta, quota, userId, channelId, modelRatio, groupRatio, audioModel, tokenName)
}(c.Request.Context())
for k, v := range resp.Header {

relay/controller/error.go Normal file

@@ -0,0 +1,91 @@
package controller
import (
"encoding/json"
"fmt"
"github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/relay/model"
"io"
"net/http"
"strconv"
)
type GeneralErrorResponse struct {
Error model.Error `json:"error"`
Message string `json:"message"`
Msg string `json:"msg"`
Err string `json:"err"`
ErrorMsg string `json:"error_msg"`
Header struct {
Message string `json:"message"`
} `json:"header"`
Response struct {
Error struct {
Message string `json:"message"`
} `json:"error"`
} `json:"response"`
}
func (e GeneralErrorResponse) ToMessage() string {
if e.Error.Message != "" {
return e.Error.Message
}
if e.Message != "" {
return e.Message
}
if e.Msg != "" {
return e.Msg
}
if e.Err != "" {
return e.Err
}
if e.ErrorMsg != "" {
return e.ErrorMsg
}
if e.Header.Message != "" {
return e.Header.Message
}
if e.Response.Error.Message != "" {
return e.Response.Error.Message
}
return ""
}
func RelayErrorHandler(resp *http.Response) (ErrorWithStatusCode *model.ErrorWithStatusCode) {
ErrorWithStatusCode = &model.ErrorWithStatusCode{
StatusCode: resp.StatusCode,
Error: model.Error{
Message: "",
Type: "upstream_error",
Code: "bad_response_status_code",
Param: strconv.Itoa(resp.StatusCode),
},
}
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return
}
if config.DebugEnabled {
logger.SysLog(fmt.Sprintf("error happened, status code: %d, response: \n%s", resp.StatusCode, string(responseBody)))
}
err = resp.Body.Close()
if err != nil {
return
}
var errResponse GeneralErrorResponse
err = json.Unmarshal(responseBody, &errResponse)
if err != nil {
return
}
if errResponse.Error.Message != "" {
// OpenAI format error, so we override the default one
ErrorWithStatusCode.Error = errResponse.Error
} else {
ErrorWithStatusCode.Error.Message = errResponse.ToMessage()
}
if ErrorWithStatusCode.Error.Message == "" {
ErrorWithStatusCode.Error.Message = fmt.Sprintf("bad response status code %d", resp.StatusCode)
}
return
}
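`GeneralErrorResponse` is a union of the error body shapes seen across upstream vendors, and `ToMessage` probes each field in a fixed order. A small illustration with a hypothetical input body:

```go
package main

import (
	"encoding/json"
	"fmt"

	"github.com/songquanpeng/one-api/relay/controller"
)

func main() {
	// A Baidu-style body carries no OpenAI "error" object, so
	// ToMessage falls through to the error_msg field.
	var e controller.GeneralErrorResponse
	_ = json.Unmarshal([]byte(`{"error_msg":"rate limit exceeded"}`), &e)
	fmt.Println(e.ToMessage()) // rate limit exceeded
}
```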


@@ -9,10 +9,13 @@ import (
"github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/model"
"github.com/songquanpeng/one-api/relay/channel/openai"
"github.com/songquanpeng/one-api/relay/constant"
"github.com/songquanpeng/one-api/relay/adaptor/openai"
billingratio "github.com/songquanpeng/one-api/relay/billing/ratio"
"github.com/songquanpeng/one-api/relay/channeltype"
"github.com/songquanpeng/one-api/relay/controller/validator"
"github.com/songquanpeng/one-api/relay/meta"
relaymodel "github.com/songquanpeng/one-api/relay/model"
"github.com/songquanpeng/one-api/relay/util"
"github.com/songquanpeng/one-api/relay/relaymode"
"math"
"net/http"
)
@@ -23,21 +26,21 @@ func getAndValidateTextRequest(c *gin.Context, relayMode int) (*relaymodel.Gener
if err != nil {
return nil, err
}
if relayMode == constant.RelayModeModerations && textRequest.Model == "" {
if relayMode == relaymode.Moderations && textRequest.Model == "" {
textRequest.Model = "text-moderation-latest"
}
if relayMode == constant.RelayModeEmbeddings && textRequest.Model == "" {
if relayMode == relaymode.Embeddings && textRequest.Model == "" {
textRequest.Model = c.Param("model")
}
err = util.ValidateTextRequest(textRequest, relayMode)
err = validator.ValidateTextRequest(textRequest, relayMode)
if err != nil {
return nil, err
}
return textRequest, nil
}
func getImageRequest(c *gin.Context, relayMode int) (*openai.ImageRequest, error) {
imageRequest := &openai.ImageRequest{}
func getImageRequest(c *gin.Context, relayMode int) (*relaymodel.ImageRequest, error) {
imageRequest := &relaymodel.ImageRequest{}
err := common.UnmarshalBodyReusable(c, imageRequest)
if err != nil {
return nil, err
@@ -54,9 +57,25 @@ func getImageRequest(c *gin.Context, relayMode int) (*openai.ImageRequest, error
return imageRequest, nil
}
func validateImageRequest(imageRequest *openai.ImageRequest, meta *util.RelayMeta) *relaymodel.ErrorWithStatusCode {
func isValidImageSize(model string, size string) bool {
if model == "cogview-3" {
return true
}
_, ok := billingratio.ImageSizeRatios[model][size]
return ok
}
func getImageSizeRatio(model string, size string) float64 {
ratio, ok := billingratio.ImageSizeRatios[model][size]
if !ok {
return 1
}
return ratio
}
func validateImageRequest(imageRequest *relaymodel.ImageRequest, meta *meta.Meta) *relaymodel.ErrorWithStatusCode {
// model validation
_, hasValidSize := constant.DalleSizeRatios[imageRequest.Model][imageRequest.Size]
hasValidSize := isValidImageSize(imageRequest.Model, imageRequest.Size)
if !hasValidSize {
return openai.ErrorWrapper(errors.New("size not supported for this image model"), "size_not_supported", http.StatusBadRequest)
}
@@ -64,27 +83,24 @@ func validateImageRequest(imageRequest *openai.ImageRequest, meta *util.RelayMet
if imageRequest.Prompt == "" {
return openai.ErrorWrapper(errors.New("prompt is required"), "prompt_missing", http.StatusBadRequest)
}
if len(imageRequest.Prompt) > constant.DalleImagePromptLengthLimitations[imageRequest.Model] {
if len(imageRequest.Prompt) > billingratio.ImagePromptLengthLimitations[imageRequest.Model] {
return openai.ErrorWrapper(errors.New("prompt is too long"), "prompt_too_long", http.StatusBadRequest)
}
// Number of generated images validation
if !isWithinRange(imageRequest.Model, imageRequest.N) {
// channel not azure
if meta.ChannelType != common.ChannelTypeAzure {
if meta.ChannelType != channeltype.Azure {
return openai.ErrorWrapper(errors.New("invalid value of n"), "n_not_within_range", http.StatusBadRequest)
}
}
return nil
}
func getImageCostRatio(imageRequest *openai.ImageRequest) (float64, error) {
func getImageCostRatio(imageRequest *relaymodel.ImageRequest) (float64, error) {
if imageRequest == nil {
return 0, errors.New("imageRequest is nil")
}
imageCostRatio, hasValidSize := constant.DalleSizeRatios[imageRequest.Model][imageRequest.Size]
if !hasValidSize {
return 0, errors.Errorf("size not supported for this image model: %s", imageRequest.Size)
}
imageCostRatio := getImageSizeRatio(imageRequest.Model, imageRequest.Size)
if imageRequest.Quality == "hd" && imageRequest.Model == "dall-e-3" {
if imageRequest.Size == "1024x1024" {
imageCostRatio *= 2
@@ -97,11 +113,11 @@ func getImageCostRatio(imageRequest *openai.ImageRequest) (float64, error) {
func getPromptTokens(textRequest *relaymodel.GeneralOpenAIRequest, relayMode int) int {
switch relayMode {
case constant.RelayModeChatCompletions:
case relaymode.ChatCompletions:
return openai.CountTokenMessages(textRequest.Messages, textRequest.Model)
case constant.RelayModeCompletions:
case relaymode.Completions:
return openai.CountTokenInput(textRequest.Prompt, textRequest.Model)
case constant.RelayModeModerations:
case relaymode.Moderations:
return openai.CountTokenInput(textRequest.Input, textRequest.Model)
}
return 0
@@ -115,7 +131,7 @@ func getPreConsumedQuota(textRequest *relaymodel.GeneralOpenAIRequest, promptTok
return int64(float64(preConsumedTokens) * ratio)
}
func preConsumeQuota(ctx context.Context, textRequest *relaymodel.GeneralOpenAIRequest, promptTokens int, ratio float64, meta *util.RelayMeta) (int64, *relaymodel.ErrorWithStatusCode) {
func preConsumeQuota(ctx context.Context, textRequest *relaymodel.GeneralOpenAIRequest, promptTokens int, ratio float64, meta *meta.Meta) (int64, *relaymodel.ErrorWithStatusCode) {
preConsumedQuota := getPreConsumedQuota(textRequest, promptTokens, ratio)
userQuota, err := model.CacheGetUserQuota(ctx, meta.UserId)
@@ -144,13 +160,13 @@ func preConsumeQuota(ctx context.Context, textRequest *relaymodel.GeneralOpenAIR
return preConsumedQuota, nil
}
func postConsumeQuota(ctx context.Context, usage *relaymodel.Usage, meta *util.RelayMeta, textRequest *relaymodel.GeneralOpenAIRequest, ratio float64, preConsumedQuota int64, modelRatio float64, groupRatio float64) {
func postConsumeQuota(ctx context.Context, usage *relaymodel.Usage, meta *meta.Meta, textRequest *relaymodel.GeneralOpenAIRequest, ratio float64, preConsumedQuota int64, modelRatio float64, groupRatio float64) {
if usage == nil {
logger.Error(ctx, "usage is nil, which is unexpected")
return
}
var quota int64
completionRatio := common.GetCompletionRatio(textRequest.Model)
completionRatio := billingratio.GetCompletionRatio(textRequest.Model)
promptTokens := usage.PromptTokens
completionTokens := usage.CompletionTokens
quota = int64(math.Ceil((float64(promptTokens) + float64(completionTokens)*completionRatio) * ratio))
@@ -178,3 +194,14 @@ func postConsumeQuota(ctx context.Context, usage *relaymodel.Usage, meta *util.R
model.UpdateUserUsedQuotaAndRequestCount(meta.UserId, quota)
model.UpdateChannelUsedQuota(meta.ChannelId, quota)
}
func getMappedModelName(modelName string, mapping map[string]string) (string, bool) {
if mapping == nil {
return modelName, false
}
mappedModelName := mapping[modelName]
if mappedModelName != "" {
return mappedModelName, true
}
return modelName, false
}
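`getMappedModelName` is unexported, so the sketch below lives in the same package: a hit returns the mapped name with true, a miss echoes the input with false.

```go
package controller

import "fmt"

// Hypothetical in-package example of the model-mapping lookup.
func exampleGetMappedModelName() {
	mapping := map[string]string{"gpt-4": "gpt-4-0613"}
	name, mapped := getMappedModelName("gpt-4", mapping)
	fmt.Println(name, mapped) // gpt-4-0613 true
	name, mapped = getMappedModelName("gpt-3.5-turbo", mapping)
	fmt.Println(name, mapped) // gpt-3.5-turbo false
}
```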


@@ -5,35 +5,33 @@ import (
"context"
"encoding/json"
"fmt"
"github.com/Laisky/errors/v2"
"io"
"net/http"
"strings"
"github.com/songquanpeng/one-api/common"
"github.com/Laisky/errors/v2"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/model"
"github.com/songquanpeng/one-api/relay/channel/openai"
"github.com/songquanpeng/one-api/relay/constant"
"github.com/songquanpeng/one-api/relay"
"github.com/songquanpeng/one-api/relay/adaptor/openai"
billingratio "github.com/songquanpeng/one-api/relay/billing/ratio"
"github.com/songquanpeng/one-api/relay/channeltype"
"github.com/songquanpeng/one-api/relay/meta"
relaymodel "github.com/songquanpeng/one-api/relay/model"
"github.com/songquanpeng/one-api/relay/util"
"github.com/gin-gonic/gin"
)
func isWithinRange(element string, value int) bool {
if _, ok := constant.DalleGenerationImageAmounts[element]; !ok {
if _, ok := billingratio.ImageGenerationAmounts[element]; !ok {
return false
}
min := constant.DalleGenerationImageAmounts[element][0]
max := constant.DalleGenerationImageAmounts[element][1]
min := billingratio.ImageGenerationAmounts[element][0]
max := billingratio.ImageGenerationAmounts[element][1]
return value >= min && value <= max
}
func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatusCode {
ctx := c.Request.Context()
meta := util.GetRelayMeta(c)
meta := meta.GetByContext(c)
imageRequest, err := getImageRequest(c, meta.Mode)
if err != nil {
logger.Errorf(ctx, "getImageRequest failed: %s", err.Error())
@@ -43,7 +41,7 @@ func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus
// map model name
var isModelMapped bool
meta.OriginModelName = imageRequest.Model
imageRequest.Model, isModelMapped = util.GetMappedModelName(imageRequest.Model, meta.ModelMapping)
imageRequest.Model, isModelMapped = getMappedModelName(imageRequest.Model, meta.ModelMapping)
meta.ActualModelName = imageRequest.Model
// model validation
@@ -57,17 +55,8 @@ func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus
return openai.ErrorWrapper(err, "get_image_cost_ratio_failed", http.StatusInternalServerError)
}
requestURL := c.Request.URL.String()
fullRequestURL := util.GetFullRequestURL(meta.BaseURL, requestURL, meta.ChannelType)
if meta.ChannelType == common.ChannelTypeAzure {
// https://learn.microsoft.com/en-us/azure/ai-services/openai/dall-e-quickstart?tabs=dalle3%2Ccommand-line&pivots=rest-api
apiVersion := util.GetAzureAPIVersion(c)
// https://{resource_name}.openai.azure.com/openai/deployments/dall-e-3/images/generations?api-version=2024-03-01-preview
fullRequestURL = fmt.Sprintf("%s/openai/deployments/%s/images/generations?api-version=%s", meta.BaseURL, imageRequest.Model, apiVersion)
}
var requestBody io.Reader
if isModelMapped || meta.ChannelType == common.ChannelTypeAzure { // make Azure channel request body
if isModelMapped || meta.ChannelType == channeltype.Azure { // make Azure channel request body
jsonStr, err := json.Marshal(imageRequest)
if err != nil {
return openai.ErrorWrapper(err, "marshal_image_request_failed", http.StatusInternalServerError)
@@ -77,9 +66,32 @@ func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus
requestBody = c.Request.Body
}
modelRatio := common.GetModelRatio(imageRequest.Model)
// groupRatio := common.GetGroupRatio(meta.Group)
adaptor := relay.GetAdaptor(meta.APIType)
if adaptor == nil {
return openai.ErrorWrapper(fmt.Errorf("invalid api type: %d", meta.APIType), "invalid_api_type", http.StatusBadRequest)
}
switch meta.ChannelType {
case channeltype.Ali:
fallthrough
case channeltype.Baidu:
fallthrough
case channeltype.Zhipu:
finalRequest, err := adaptor.ConvertImageRequest(imageRequest)
if err != nil {
return openai.ErrorWrapper(err, "convert_image_request_failed", http.StatusInternalServerError)
}
jsonStr, err := json.Marshal(finalRequest)
if err != nil {
return openai.ErrorWrapper(err, "marshal_image_request_failed", http.StatusInternalServerError)
}
requestBody = bytes.NewBuffer(jsonStr)
}
modelRatio := billingratio.GetModelRatio(imageRequest.Model)
// groupRatio := billingratio.GetGroupRatio(meta.Group)
groupRatio := c.GetFloat64("channel_ratio") // pre-selected cheapest channel ratio
ratio := modelRatio * groupRatio
userQuota, err := model.CacheGetUserQuota(ctx, meta.UserId)
@@ -89,36 +101,13 @@ func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus
return openai.ErrorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden)
}
req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
if err != nil {
return openai.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError)
}
token := c.Request.Header.Get("Authorization")
if meta.ChannelType == common.ChannelTypeAzure { // Azure authentication
token = strings.TrimPrefix(token, "Bearer ")
req.Header.Set("api-key", token)
} else {
req.Header.Set("Authorization", token)
}
req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
req.Header.Set("Accept", c.Request.Header.Get("Accept"))
resp, err := util.HTTPClient.Do(req)
// do request
resp, err := adaptor.DoRequest(c, meta, requestBody)
if err != nil {
logger.Errorf(ctx, "DoRequest failed: %s", err.Error())
return openai.ErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)
}
err = req.Body.Close()
if err != nil {
return openai.ErrorWrapper(err, "close_request_body_failed", http.StatusInternalServerError)
}
err = c.Request.Body.Close()
if err != nil {
return openai.ErrorWrapper(err, "close_request_body_failed", http.StatusInternalServerError)
}
var imageResponse openai.ImageResponse
defer func(ctx context.Context) {
if resp.StatusCode != http.StatusOK {
return
@@ -141,34 +130,12 @@ func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus
}
}(c.Request.Context())
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
}
err = resp.Body.Close()
if err != nil {
return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError)
}
err = json.Unmarshal(responseBody, &imageResponse)
if err != nil {
return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError)
// do response
_, respErr := adaptor.DoResponse(c, resp, meta)
if respErr != nil {
logger.Errorf(ctx, "respErr is not nil: %+v", respErr)
return respErr
}
resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
for k, v := range resp.Header {
c.Writer.Header().Set(k, v[0])
}
c.Writer.WriteHeader(resp.StatusCode)
_, err = io.Copy(c.Writer, resp.Body)
if err != nil {
return openai.ErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError)
}
err = resp.Body.Close()
if err != nil {
return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError)
}
return nil
}


@@ -10,18 +10,20 @@ import (
"github.com/Laisky/errors/v2"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/relay/channel/openai"
"github.com/songquanpeng/one-api/relay/constant"
"github.com/songquanpeng/one-api/relay/helper"
"github.com/songquanpeng/one-api/relay"
"github.com/songquanpeng/one-api/relay/adaptor/openai"
"github.com/songquanpeng/one-api/relay/apitype"
"github.com/songquanpeng/one-api/relay/billing"
billingratio "github.com/songquanpeng/one-api/relay/billing/ratio"
"github.com/songquanpeng/one-api/relay/channeltype"
"github.com/songquanpeng/one-api/relay/meta"
"github.com/songquanpeng/one-api/relay/model"
"github.com/songquanpeng/one-api/relay/util"
)
func RelayTextHelper(c *gin.Context) *model.ErrorWithStatusCode {
ctx := c.Request.Context()
meta := util.GetRelayMeta(c)
meta := meta.GetByContext(c)
// get & validate textRequest
textRequest, err := getAndValidateTextRequest(c, meta.Mode)
if err != nil {
@@ -33,12 +35,13 @@ func RelayTextHelper(c *gin.Context) *model.ErrorWithStatusCode {
// map model name
var isModelMapped bool
meta.OriginModelName = textRequest.Model
textRequest.Model, isModelMapped = util.GetMappedModelName(textRequest.Model, meta.ModelMapping)
textRequest.Model, isModelMapped = getMappedModelName(textRequest.Model, meta.ModelMapping)
meta.ActualModelName = textRequest.Model
// get model ratio & group ratio
modelRatio := common.GetModelRatio(textRequest.Model)
// groupRatio := common.GetGroupRatio(meta.Group)
modelRatio := billingratio.GetModelRatio(textRequest.Model)
// groupRatio := billingratio.GetGroupRatio(meta.Group)
groupRatio := meta.ChannelRatio
ratio := modelRatio * groupRatio
// pre-consume quota
promptTokens := getPromptTokens(textRequest, meta.Mode)
@@ -49,16 +52,16 @@ func RelayTextHelper(c *gin.Context) *model.ErrorWithStatusCode {
return bizErr
}
adaptor := helper.GetAdaptor(meta.APIType)
adaptor := relay.GetAdaptor(meta.APIType)
if adaptor == nil {
return openai.ErrorWrapper(errors.Errorf("invalid api type: %d", meta.APIType), "invalid_api_type", http.StatusBadRequest)
}
// get request body
var requestBody io.Reader
if meta.APIType == constant.APITypeOpenAI {
if meta.APIType == apitype.OpenAI {
// no need to convert request for openai
shouldResetRequestBody := isModelMapped || meta.ChannelType == common.ChannelTypeBaichuan // frequency_penalty 0 is not acceptable for baichuan
shouldResetRequestBody := isModelMapped || meta.ChannelType == channeltype.Baichuan // frequency_penalty 0 is not acceptable for baichuan
if shouldResetRequestBody {
jsonStr, err := json.Marshal(textRequest)
if err != nil {
@@ -93,10 +96,10 @@ func RelayTextHelper(c *gin.Context) *model.ErrorWithStatusCode {
}
errorHappened := (resp.StatusCode != http.StatusOK) || (meta.IsStream && resp.Header.Get("Content-Type") == "application/json")
if errorHappened {
util.ReturnPreConsumedQuota(ctx, preConsumedQuota, meta.TokenId)
billing.ReturnPreConsumedQuota(ctx, preConsumedQuota, meta.TokenId)
logger.Error(ctx, fmt.Sprintf("relay text [%d] <- %q %q",
resp.StatusCode, resp.Request.URL.String(), string(requestBodyBytes)))
return util.RelayErrorHandler(resp)
return RelayErrorHandler(resp)
}
meta.IsStream = meta.IsStream || strings.HasPrefix(resp.Header.Get("Content-Type"), "text/event-stream")
@@ -104,7 +107,7 @@ func RelayTextHelper(c *gin.Context) *model.ErrorWithStatusCode {
usage, respErr := adaptor.DoResponse(c, resp, meta)
if respErr != nil {
logger.Errorf(ctx, "respErr is not nil: %+v", respErr)
util.ReturnPreConsumedQuota(ctx, preConsumedQuota, meta.TokenId)
billing.ReturnPreConsumedQuota(ctx, preConsumedQuota, meta.TokenId)
return respErr
}
// post-consume quota
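For orientation, the quota arithmetic above in round numbers — a toy calculation with invented figures, assuming quota is counted in token-equivalents:

package main

import "fmt"

func main() {
	modelRatio := 15.0 // billingratio.GetModelRatio(model), invented value
	groupRatio := 0.5  // meta.ChannelRatio in the new code
	ratio := modelRatio * groupRatio

	promptTokens := 1200
	preConsumedQuota := int64(float64(promptTokens) * ratio)
	fmt.Println(preConsumedQuota) // 9000, refunded via billing.ReturnPreConsumedQuota on error
}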

View File

@@ -1,10 +1,11 @@
package util
package validator
import (
"github.com/Laisky/errors/v2"
"github.com/songquanpeng/one-api/relay/constant"
"github.com/songquanpeng/one-api/relay/model"
"math"
"github.com/Laisky/errors/v2"
"github.com/songquanpeng/one-api/relay/model"
"github.com/songquanpeng/one-api/relay/relaymode"
)
func ValidateTextRequest(textRequest *model.GeneralOpenAIRequest, relayMode int) error {
@@ -15,20 +16,20 @@ func ValidateTextRequest(textRequest *model.GeneralOpenAIRequest, relayMode int) error {
return errors.New("model is required")
}
switch relayMode {
case constant.RelayModeCompletions:
case relaymode.Completions:
if textRequest.Prompt == "" {
return errors.New("field prompt is required")
}
case constant.RelayModeChatCompletions:
case relaymode.ChatCompletions:
if textRequest.Messages == nil || len(textRequest.Messages) == 0 {
return errors.New("field messages is required")
}
case constant.RelayModeEmbeddings:
case constant.RelayModeModerations:
case relaymode.Embeddings:
case relaymode.Moderations:
if textRequest.Input == "" {
return errors.New("field input is required")
}
case constant.RelayModeEdits:
case relaymode.Edits:
if textRequest.Instruction == "" {
return errors.New("field instruction is required")
}
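One behavior worth noting in the switch above: an empty case in Go does not fall through, so the bare relaymode.Embeddings case matches and returns without running any field check — embeddings requests skip the Input validation entirely. If the intent were for Embeddings to share that check (an assumption about intent, not what the hunk does), the cases would need to be grouped:

package main

import (
	"errors"
	"fmt"
)

const (
	embeddings  = 3 // stand-ins for the relaymode constants
	moderations = 4
)

func validateInput(mode int, input string) error {
	switch mode {
	case embeddings, moderations: // grouped cases share one body
		if input == "" {
			return errors.New("field input is required")
		}
	}
	return nil
}

func main() {
	fmt.Println(validateInput(embeddings, "")) // field input is required
}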

View File

@@ -1,40 +0,0 @@
package helper
import (
"github.com/songquanpeng/one-api/relay/channel"
"github.com/songquanpeng/one-api/relay/channel/aiproxy"
"github.com/songquanpeng/one-api/relay/channel/anthropic"
"github.com/songquanpeng/one-api/relay/channel/gemini"
"github.com/songquanpeng/one-api/relay/channel/ollama"
"github.com/songquanpeng/one-api/relay/channel/openai"
"github.com/songquanpeng/one-api/relay/channel/palm"
"github.com/songquanpeng/one-api/relay/constant"
)
func GetAdaptor(apiType int) channel.Adaptor {
switch apiType {
case constant.APITypeAIProxyLibrary:
return &aiproxy.Adaptor{}
// case constant.APITypeAli:
// return &ali.Adaptor{}
case constant.APITypeAnthropic:
return &anthropic.Adaptor{}
// case constant.APITypeBaidu:
// return &baidu.Adaptor{}
case constant.APITypeGemini:
return &gemini.Adaptor{}
case constant.APITypeOpenAI:
return &openai.Adaptor{}
case constant.APITypePaLM:
return &palm.Adaptor{}
// case constant.APITypeTencent:
// return &tencent.Adaptor{}
// case constant.APITypeXunfei:
// return &xunfei.Adaptor{}
// case constant.APITypeZhipu:
// return &zhipu.Adaptor{}
case constant.APITypeOllama:
return &ollama.Adaptor{}
}
return nil
}

View File

@@ -1,13 +1,15 @@
package util
package meta
import (
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/relay/constant"
"github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/relay/adaptor/azure"
"github.com/songquanpeng/one-api/relay/channeltype"
"github.com/songquanpeng/one-api/relay/relaymode"
"strings"
)
type RelayMeta struct {
type Meta struct {
Mode int
ChannelType int
ChannelId int
@@ -29,9 +31,9 @@ type RelayMeta struct {
ChannelRatio float64
}
func GetRelayMeta(c *gin.Context) *RelayMeta {
meta := RelayMeta{
Mode: constant.Path2RelayMode(c.Request.URL.Path),
func GetByContext(c *gin.Context) *Meta {
meta := Meta{
Mode: relaymode.GetByPath(c.Request.URL.Path),
ChannelType: c.GetInt("channel"),
ChannelId: c.GetInt("channel_id"),
TokenId: c.GetInt("token_id"),
@@ -40,18 +42,18 @@ func GetRelayMeta(c *gin.Context) *RelayMeta {
Group: c.GetString("group"),
ModelMapping: c.GetStringMapString("model_mapping"),
BaseURL: c.GetString("base_url"),
APIVersion: c.GetString(common.ConfigKeyAPIVersion),
APIVersion: c.GetString(config.KeyAPIVersion),
APIKey: strings.TrimPrefix(c.Request.Header.Get("Authorization"), "Bearer "),
Config: nil,
RequestURLPath: c.Request.URL.String(),
ChannelRatio: c.GetFloat64("channel_ratio"),
}
if meta.ChannelType == common.ChannelTypeAzure {
meta.APIVersion = GetAzureAPIVersion(c)
if meta.ChannelType == channeltype.Azure {
meta.APIVersion = azure.GetAPIVersion(c)
}
if meta.BaseURL == "" {
meta.BaseURL = common.ChannelBaseURLs[meta.ChannelType]
meta.BaseURL = channeltype.ChannelBaseURLs[meta.ChannelType]
}
meta.APIType = constant.ChannelType2APIType(meta.ChannelType)
meta.APIType = channeltype.ToAPIType(meta.ChannelType)
return &meta
}
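The construction above leans on gin's typed getters, which return zero values for missing keys — that is what makes the BaseURL fallback and the Azure special case necessary. A small demonstration of the pattern using gin's test helpers:

package main

import (
	"fmt"
	"net/http/httptest"

	"github.com/gin-gonic/gin"
)

func main() {
	c, _ := gin.CreateTestContext(httptest.NewRecorder())
	c.Set("channel_id", 42) // set by auth/distribute middleware upstream
	c.Set("channel_ratio", 0.8)

	fmt.Println(c.GetInt("channel_id"), c.GetFloat64("channel_ratio")) // 42 0.8
	fmt.Println(c.GetString("base_url") == "")                         // true: zero value, hence the fallback
}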

12
relay/model/image.go Normal file
View File

@@ -0,0 +1,12 @@
package model
type ImageRequest struct {
Model string `json:"model"`
Prompt string `json:"prompt" binding:"required"`
N int `json:"n,omitempty"`
Size string `json:"size,omitempty"`
Quality string `json:"quality,omitempty"`
ResponseFormat string `json:"response_format,omitempty"`
Style string `json:"style,omitempty"`
User string `json:"user,omitempty"`
}
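The only validated field is Prompt; everything else is optional and dropped from the JSON when empty. A quick check of the binding tag with gin's ShouldBindJSON, with the struct trimmed to three fields:

package main

import (
	"fmt"
	"net/http/httptest"
	"strings"

	"github.com/gin-gonic/gin"
)

type ImageRequest struct {
	Model  string `json:"model"`
	Prompt string `json:"prompt" binding:"required"`
	N      int    `json:"n,omitempty"`
}

func main() {
	req := httptest.NewRequest("POST", "/v1/images/generations", strings.NewReader(`{"model":"dall-e-3"}`))
	req.Header.Set("Content-Type", "application/json")

	c, _ := gin.CreateTestContext(httptest.NewRecorder())
	c.Request = req

	var r ImageRequest
	err := c.ShouldBindJSON(&r)
	fmt.Println(err != nil) // true: Prompt fails the 'required' validation
}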

14
relay/relaymode/define.go Normal file
View File

@@ -0,0 +1,14 @@
package relaymode
const (
Unknown = iota
ChatCompletions
Completions
Embeddings
Moderations
ImagesGenerations
Edits
AudioSpeech
AudioTranscription
AudioTranslation
)
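With iota these constants take consecutive values — Unknown is 0, ChatCompletions is 1, and so on through AudioTranslation at 9 — so Unknown doubles as the int zero value, which GetByPath below relies on when it initializes relayMode before matching.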

29
relay/relaymode/helper.go Normal file
View File

@@ -0,0 +1,29 @@
package relaymode
import "strings"
func GetByPath(path string) int {
relayMode := Unknown
if strings.HasPrefix(path, "/v1/chat/completions") {
relayMode = ChatCompletions
} else if strings.HasPrefix(path, "/v1/completions") {
relayMode = Completions
} else if strings.HasPrefix(path, "/v1/embeddings") {
relayMode = Embeddings
} else if strings.HasSuffix(path, "embeddings") {
relayMode = Embeddings
} else if strings.HasPrefix(path, "/v1/moderations") {
relayMode = Moderations
} else if strings.HasPrefix(path, "/v1/images/generations") {
relayMode = ImagesGenerations
} else if strings.HasPrefix(path, "/v1/edits") {
relayMode = Edits
} else if strings.HasPrefix(path, "/v1/audio/speech") {
relayMode = AudioSpeech
} else if strings.HasPrefix(path, "/v1/audio/transcriptions") {
relayMode = AudioTranscription
} else if strings.HasPrefix(path, "/v1/audio/translations") {
relayMode = AudioTranslation
}
return relayMode
}
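The extra HasSuffix("embeddings") branch presumably catches provider-style paths that do not start with /v1 (Azure deployment routes, for instance — an assumption, the hunk itself does not say). A spot check with a hypothetical path:

package main

import (
	"fmt"
	"strings"
)

func main() {
	p := "/openai/deployments/gpt-4/embeddings" // hypothetical Azure-style path
	fmt.Println(strings.HasPrefix(p, "/v1/embeddings")) // false
	fmt.Println(strings.HasSuffix(p, "embeddings"))     // true: still mapped to Embeddings
}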

View File

@@ -1,19 +0,0 @@
package util
import (
"context"
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/model"
)
func ReturnPreConsumedQuota(ctx context.Context, preConsumedQuota int64, tokenId int) {
if preConsumedQuota != 0 {
go func(ctx context.Context) {
// return pre-consumed quota
err := model.PostConsumeTokenQuota(tokenId, -preConsumedQuota)
if err != nil {
logger.Error(ctx, "error return pre-consumed quota: "+err.Error())
}
}(ctx)
}
}

View File

@@ -1,188 +0,0 @@
package util
import (
"context"
"encoding/json"
"fmt"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/model"
relaymodel "github.com/songquanpeng/one-api/relay/model"
"io"
"net/http"
"strconv"
"strings"
"github.com/gin-gonic/gin"
)
func ShouldDisableChannel(err *relaymodel.Error, statusCode int) bool {
if !config.AutomaticDisableChannelEnabled {
return false
}
if err == nil {
return false
}
if statusCode == http.StatusUnauthorized {
return true
}
switch err.Type {
case "insufficient_quota":
return true
// https://docs.anthropic.com/claude/reference/errors
case "authentication_error":
return true
case "permission_error":
return true
case "forbidden":
return true
}
if err.Code == "invalid_api_key" || err.Code == "account_deactivated" {
return true
}
if strings.HasPrefix(err.Message, "Your credit balance is too low") { // anthropic
return true
} else if strings.HasPrefix(err.Message, "This organization has been disabled.") {
return true
}
return false
}
func ShouldEnableChannel(err error, openAIErr *relaymodel.Error) bool {
if !config.AutomaticEnableChannelEnabled {
return false
}
if err != nil {
return false
}
if openAIErr != nil {
return false
}
return true
}
type GeneralErrorResponse struct {
Error relaymodel.Error `json:"error"`
Message string `json:"message"`
Msg string `json:"msg"`
Err string `json:"err"`
ErrorMsg string `json:"error_msg"`
Header struct {
Message string `json:"message"`
} `json:"header"`
Response struct {
Error struct {
Message string `json:"message"`
} `json:"error"`
} `json:"response"`
}
func (e GeneralErrorResponse) ToMessage() string {
if e.Error.Message != "" {
return e.Error.Message
}
if e.Message != "" {
return e.Message
}
if e.Msg != "" {
return e.Msg
}
if e.Err != "" {
return e.Err
}
if e.ErrorMsg != "" {
return e.ErrorMsg
}
if e.Header.Message != "" {
return e.Header.Message
}
if e.Response.Error.Message != "" {
return e.Response.Error.Message
}
return ""
}
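ToMessage gives non-OpenAI upstreams (the msg/err/error_msg/header shapes above) a single extraction path. A worked example, with the struct trimmed to the fields exercised:

package main

import (
	"encoding/json"
	"fmt"
)

type generalError struct {
	Error struct {
		Message string `json:"message"`
	} `json:"error"`
	Message  string `json:"message"`
	ErrorMsg string `json:"error_msg"`
}

func (e generalError) toMessage() string {
	if e.Error.Message != "" {
		return e.Error.Message // OpenAI shape wins
	}
	if e.Message != "" {
		return e.Message
	}
	return e.ErrorMsg // e.g. Baidu-style error_msg as a late fallback
}

func main() {
	var e generalError
	_ = json.Unmarshal([]byte(`{"error_msg":"quota exhausted"}`), &e)
	fmt.Println(e.toMessage()) // quota exhausted
}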
func RelayErrorHandler(resp *http.Response) (ErrorWithStatusCode *relaymodel.ErrorWithStatusCode) {
ErrorWithStatusCode = &relaymodel.ErrorWithStatusCode{
StatusCode: resp.StatusCode,
Error: relaymodel.Error{
Message: "",
Type: "upstream_error",
Code: "bad_response_status_code",
Param: strconv.Itoa(resp.StatusCode),
},
}
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return
}
if config.DebugEnabled {
logger.SysLog(fmt.Sprintf("error happened, status code: %d, response: \n%s", resp.StatusCode, string(responseBody)))
}
err = resp.Body.Close()
if err != nil {
return
}
var errResponse GeneralErrorResponse
err = json.Unmarshal(responseBody, &errResponse)
if err != nil {
return
}
if errResponse.Error.Message != "" {
// OpenAI format error, so we override the default one
ErrorWithStatusCode.Error = errResponse.Error
} else {
ErrorWithStatusCode.Error.Message = errResponse.ToMessage()
}
if ErrorWithStatusCode.Error.Message == "" {
ErrorWithStatusCode.Error.Message = fmt.Sprintf("bad response status code %d", resp.StatusCode)
}
return
}
func GetFullRequestURL(baseURL string, requestURL string, channelType int) string {
fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL)
if strings.HasPrefix(baseURL, "https://gateway.ai.cloudflare.com") {
switch channelType {
case common.ChannelTypeOpenAI:
fullRequestURL = fmt.Sprintf("%s%s", baseURL, strings.TrimPrefix(requestURL, "/v1"))
case common.ChannelTypeAzure:
fullRequestURL = fmt.Sprintf("%s%s", baseURL, strings.TrimPrefix(requestURL, "/openai/deployments"))
}
}
return fullRequestURL
}
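The Cloudflare branch exists because an AI Gateway base URL already encodes the provider segment, so the provider-specific prefix must be stripped before concatenation. With invented account and gateway IDs:

package main

import (
	"fmt"
	"strings"
)

func main() {
	base := "https://gateway.ai.cloudflare.com/v1/ACCOUNT/GATEWAY/openai" // invented IDs
	full := base + strings.TrimPrefix("/v1/chat/completions", "/v1")
	fmt.Println(full) // .../openai/chat/completions rather than .../openai/v1/chat/completions
}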
func PostConsumeQuota(ctx context.Context, tokenId int, quotaDelta int64, totalQuota int64, userId int, channelId int, modelRatio float64, groupRatio float64, modelName string, tokenName string) {
// quotaDelta is remaining quota to be consumed
err := model.PostConsumeTokenQuota(tokenId, quotaDelta)
if err != nil {
logger.SysError("error consuming token remain quota: " + err.Error())
}
err = model.CacheUpdateUserQuota(ctx, userId)
if err != nil {
logger.SysError("error update user quota cache: " + err.Error())
}
// totalQuota is total quota consumed
if totalQuota >= 0 {
logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
model.RecordConsumeLog(ctx, userId, channelId, int(totalQuota), 0, modelName, tokenName, totalQuota, logContent)
model.UpdateUserUsedQuotaAndRequestCount(userId, totalQuota)
model.UpdateChannelUsedQuota(channelId, totalQuota)
}
if totalQuota < 0 {
logger.Error(ctx, fmt.Sprintf("totalQuota consumed is %d, something is wrong", totalQuota))
}
}
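The two figures passed in are related but distinct: totalQuota is the request's final cost, while quotaDelta settles the difference against what was pre-consumed. In round, invented numbers, assuming the caller computes the delta as final cost minus pre-charge:

package main

import "fmt"

func main() {
	var preConsumedQuota int64 = 9000 // charged before calling upstream
	var totalQuota int64 = 11250      // actual cost once usage is known
	quotaDelta := totalQuota - preConsumedQuota

	fmt.Println(quotaDelta) // 2250 still to deduct; a negative delta refunds
}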
func GetAzureAPIVersion(c *gin.Context) string {
query := c.Request.URL.Query()
apiVersion := query.Get("api-version")
if apiVersion == "" {
apiVersion = c.GetString(common.ConfigKeyAPIVersion)
}
return apiVersion
}

View File

@@ -1,12 +0,0 @@
package util
func GetMappedModelName(modelName string, mapping map[string]string) (string, bool) {
if mapping == nil {
return modelName, false
}
mappedModelName := mapping[modelName]
if mappedModelName != "" {
return mappedModelName, true
}
return modelName, false
}
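Finally, a usage sketch for GetMappedModelName as defined above — the map mirrors the shape of a channel's model_mapping setting, and the alias is an invented example:

package main

import "fmt"

func GetMappedModelName(modelName string, mapping map[string]string) (string, bool) {
	if mapping == nil {
		return modelName, false
	}
	if mapped := mapping[modelName]; mapped != "" {
		return mapped, true
	}
	return modelName, false
}

func main() {
	mapping := map[string]string{"gpt-4": "gpt-4-0125-preview"} // invented alias
	fmt.Println(GetMappedModelName("gpt-4", mapping))         // gpt-4-0125-preview true
	fmt.Println(GetMappedModelName("gpt-3.5-turbo", mapping)) // gpt-3.5-turbo false
}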