mirror of https://github.com/songquanpeng/one-api.git (synced 2025-11-08 17:53:41 +08:00)

Merge branch 'main' into pr/Laisky/25
70
relay/adaptor.go
Normal file
@@ -0,0 +1,70 @@
package relay

import (
    "github.com/songquanpeng/one-api/relay/adaptor"
    "github.com/songquanpeng/one-api/relay/adaptor/aiproxy"
    "github.com/songquanpeng/one-api/relay/adaptor/ali"
    "github.com/songquanpeng/one-api/relay/adaptor/anthropic"
    "github.com/songquanpeng/one-api/relay/adaptor/aws"
    "github.com/songquanpeng/one-api/relay/adaptor/baidu"
    "github.com/songquanpeng/one-api/relay/adaptor/cloudflare"
    "github.com/songquanpeng/one-api/relay/adaptor/cohere"
    "github.com/songquanpeng/one-api/relay/adaptor/coze"
    "github.com/songquanpeng/one-api/relay/adaptor/deepl"
    "github.com/songquanpeng/one-api/relay/adaptor/gemini"
    "github.com/songquanpeng/one-api/relay/adaptor/ollama"
    "github.com/songquanpeng/one-api/relay/adaptor/openai"
    "github.com/songquanpeng/one-api/relay/adaptor/palm"
    "github.com/songquanpeng/one-api/relay/adaptor/proxy"
    "github.com/songquanpeng/one-api/relay/adaptor/replicate"
    "github.com/songquanpeng/one-api/relay/adaptor/tencent"
    "github.com/songquanpeng/one-api/relay/adaptor/vertexai"
    "github.com/songquanpeng/one-api/relay/adaptor/xunfei"
    "github.com/songquanpeng/one-api/relay/adaptor/zhipu"
    "github.com/songquanpeng/one-api/relay/apitype"
)

func GetAdaptor(apiType int) adaptor.Adaptor {
    switch apiType {
    case apitype.AIProxyLibrary:
        return &aiproxy.Adaptor{}
    case apitype.Ali:
        return &ali.Adaptor{}
    case apitype.Anthropic:
        return &anthropic.Adaptor{}
    case apitype.AwsClaude:
        return &aws.Adaptor{}
    case apitype.Baidu:
        return &baidu.Adaptor{}
    case apitype.Gemini:
        return &gemini.Adaptor{}
    case apitype.OpenAI:
        return &openai.Adaptor{}
    case apitype.PaLM:
        return &palm.Adaptor{}
    case apitype.Tencent:
        return &tencent.Adaptor{}
    case apitype.Xunfei:
        return &xunfei.Adaptor{}
    case apitype.Zhipu:
        return &zhipu.Adaptor{}
    case apitype.Ollama:
        return &ollama.Adaptor{}
    case apitype.Coze:
        return &coze.Adaptor{}
    case apitype.Cohere:
        return &cohere.Adaptor{}
    case apitype.Cloudflare:
        return &cloudflare.Adaptor{}
    case apitype.DeepL:
        return &deepl.Adaptor{}
    case apitype.VertexAI:
        return &vertexai.Adaptor{}
    case apitype.Proxy:
        return &proxy.Adaptor{}
    case apitype.Replicate:
        return &replicate.Adaptor{}
    }

    return nil
}
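For orientation, a minimal caller sketch, not part of this commit: how relay code might resolve an adaptor by API type and fail fast on the nil return for unmapped types. The function and package names used here are assumptions for illustration.

package sketch

import (
    "fmt"

    "github.com/songquanpeng/one-api/relay"
    "github.com/songquanpeng/one-api/relay/apitype"
    "github.com/songquanpeng/one-api/relay/meta"
)

func resolveAdaptor(m *meta.Meta) error {
    a := relay.GetAdaptor(apitype.Anthropic)
    if a == nil { // GetAdaptor returns nil for API types it does not map
        return fmt.Errorf("no adaptor registered for api type %d", apitype.Anthropic)
    }
    a.Init(m)
    fmt.Println("using channel:", a.GetChannelName())
    return nil
}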
8
relay/adaptor/ai360/constants.go
Normal file
@@ -0,0 +1,8 @@
package ai360

var ModelList = []string{
    "360GPT_S2_V9",
    "embedding-bert-512-v1",
    "embedding_s1_v1",
    "semantic_similarity_s1_v1",
}
68
relay/adaptor/aiproxy/adaptor.go
Normal file
@@ -0,0 +1,68 @@
package aiproxy

import (
    "fmt"
    "io"
    "net/http"

    "github.com/gin-gonic/gin"
    "github.com/pkg/errors"
    "github.com/songquanpeng/one-api/relay/adaptor"
    "github.com/songquanpeng/one-api/relay/meta"
    "github.com/songquanpeng/one-api/relay/model"
)

type Adaptor struct {
    meta *meta.Meta
}

func (a *Adaptor) Init(meta *meta.Meta) {
    a.meta = meta
}

func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) {
    return fmt.Sprintf("%s/api/library/ask", meta.BaseURL), nil
}

func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error {
    adaptor.SetupCommonRequestHeader(c, req, meta)
    req.Header.Set("Authorization", "Bearer "+meta.APIKey)
    return nil
}

func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) {
    if request == nil {
        return nil, errors.New("request is nil")
    }
    aiProxyLibraryRequest := ConvertRequest(*request)
    aiProxyLibraryRequest.LibraryId = a.meta.Config.LibraryID
    return aiProxyLibraryRequest, nil
}

func (a *Adaptor) ConvertImageRequest(_ *gin.Context, request *model.ImageRequest) (any, error) {
    if request == nil {
        return nil, errors.New("request is nil")
    }
    return request, nil
}

func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) {
    return adaptor.DoRequestHelper(a, c, meta, requestBody)
}

func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
    if meta.IsStream {
        err, usage = StreamHandler(c, resp)
    } else {
        err, usage = Handler(c, resp)
    }
    return
}

func (a *Adaptor) GetModelList() []string {
    return ModelList
}

func (a *Adaptor) GetChannelName() string {
    return "aiproxy"
}
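The method set above (and on every other channel package in this diff) implements the shared adaptor.Adaptor contract. As a sketch, the interface reconstructed from these method signatures; the real definition lives in relay/adaptor and may differ in detail:

package sketch

import (
    "io"
    "net/http"

    "github.com/gin-gonic/gin"
    "github.com/songquanpeng/one-api/relay/meta"
    "github.com/songquanpeng/one-api/relay/model"
)

// Adaptor is inferred from the methods each channel type implements in this diff.
type Adaptor interface {
    Init(meta *meta.Meta)
    GetRequestURL(meta *meta.Meta) (string, error)
    SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error
    ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error)
    ConvertImageRequest(c *gin.Context, request *model.ImageRequest) (any, error)
    DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error)
    DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (*model.Usage, *model.ErrorWithStatusCode)
    GetModelList() []string
    GetChannelName() string
}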
9
relay/adaptor/aiproxy/constants.go
Normal file
@@ -0,0 +1,9 @@
package aiproxy

import "github.com/songquanpeng/one-api/relay/adaptor/openai"

// placeholder; replaced with the OpenAI model list at init time
var ModelList = []string{""}

func init() {
    ModelList = openai.ModelList
}
189
relay/adaptor/aiproxy/main.go
Normal file
@@ -0,0 +1,189 @@
package aiproxy

import (
    "bufio"
    "encoding/json"
    "fmt"
    "io"
    "net/http"
    "strconv"
    "strings"

    "github.com/gin-gonic/gin"
    "github.com/songquanpeng/one-api/common"
    "github.com/songquanpeng/one-api/common/helper"
    "github.com/songquanpeng/one-api/common/logger"
    "github.com/songquanpeng/one-api/common/random"
    "github.com/songquanpeng/one-api/common/render"
    "github.com/songquanpeng/one-api/relay/adaptor/openai"
    "github.com/songquanpeng/one-api/relay/constant"
    "github.com/songquanpeng/one-api/relay/model"
)

// https://docs.aiproxy.io/dev/library#使用已经定制好的知识库进行对话问答

func ConvertRequest(request model.GeneralOpenAIRequest) *LibraryRequest {
    query := ""
    if len(request.Messages) != 0 {
        query = request.Messages[len(request.Messages)-1].StringContent()
    }
    return &LibraryRequest{
        Model:  request.Model,
        Stream: request.Stream,
        Query:  query,
    }
}

func aiProxyDocuments2Markdown(documents []LibraryDocument) string {
    if len(documents) == 0 {
        return ""
    }
    content := "\n\n参考文档:\n"
    for i, document := range documents {
        content += fmt.Sprintf("%d. [%s](%s)\n", i+1, document.Title, document.URL)
    }
    return content
}

func responseAIProxyLibrary2OpenAI(response *LibraryResponse) *openai.TextResponse {
    content := response.Answer + aiProxyDocuments2Markdown(response.Documents)
    choice := openai.TextResponseChoice{
        Index: 0,
        Message: model.Message{
            Role:    "assistant",
            Content: content,
        },
        FinishReason: "stop",
    }
    fullTextResponse := openai.TextResponse{
        Id:      fmt.Sprintf("chatcmpl-%s", random.GetUUID()),
        Object:  "chat.completion",
        Created: helper.GetTimestamp(),
        Choices: []openai.TextResponseChoice{choice},
    }
    return &fullTextResponse
}

func documentsAIProxyLibrary(documents []LibraryDocument) *openai.ChatCompletionsStreamResponse {
    var choice openai.ChatCompletionsStreamResponseChoice
    choice.Delta.Content = aiProxyDocuments2Markdown(documents)
    choice.FinishReason = &constant.StopFinishReason
    return &openai.ChatCompletionsStreamResponse{
        Id:      fmt.Sprintf("chatcmpl-%s", random.GetUUID()),
        Object:  "chat.completion.chunk",
        Created: helper.GetTimestamp(),
        Model:   "",
        Choices: []openai.ChatCompletionsStreamResponseChoice{choice},
    }
}

func streamResponseAIProxyLibrary2OpenAI(response *LibraryStreamResponse) *openai.ChatCompletionsStreamResponse {
    var choice openai.ChatCompletionsStreamResponseChoice
    choice.Delta.Content = response.Content
    return &openai.ChatCompletionsStreamResponse{
        Id:      fmt.Sprintf("chatcmpl-%s", random.GetUUID()),
        Object:  "chat.completion.chunk",
        Created: helper.GetTimestamp(),
        Model:   response.Model,
        Choices: []openai.ChatCompletionsStreamResponseChoice{choice},
    }
}

func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
    var usage model.Usage
    var documents []LibraryDocument
    scanner := bufio.NewScanner(resp.Body)
    scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
        if atEOF && len(data) == 0 {
            return 0, nil, nil
        }
        if i := strings.Index(string(data), "\n"); i >= 0 {
            return i + 1, data[0:i], nil
        }
        if atEOF {
            return len(data), data, nil
        }
        return 0, nil, nil
    })

    common.SetEventStreamHeaders(c)

    for scanner.Scan() {
        data := scanner.Text()
        if len(data) < 5 || data[:5] != "data:" {
            continue
        }
        data = data[5:]

        var AIProxyLibraryResponse LibraryStreamResponse
        err := json.Unmarshal([]byte(data), &AIProxyLibraryResponse)
        if err != nil {
            logger.SysError("error unmarshalling stream response: " + err.Error())
            continue
        }
        if len(AIProxyLibraryResponse.Documents) != 0 {
            documents = AIProxyLibraryResponse.Documents
        }
        response := streamResponseAIProxyLibrary2OpenAI(&AIProxyLibraryResponse)
        err = render.ObjectData(c, response)
        if err != nil {
            logger.SysError(err.Error())
        }
    }

    if err := scanner.Err(); err != nil {
        logger.SysError("error reading stream: " + err.Error())
    }

    response := documentsAIProxyLibrary(documents)
    err := render.ObjectData(c, response)
    if err != nil {
        logger.SysError(err.Error())
    }
    render.Done(c)

    err = resp.Body.Close()
    if err != nil {
        return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
    }

    return nil, &usage
}

func Handler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
    var AIProxyLibraryResponse LibraryResponse
    responseBody, err := io.ReadAll(resp.Body)
    if err != nil {
        return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
    }
    err = resp.Body.Close()
    if err != nil {
        return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
    }
    err = json.Unmarshal(responseBody, &AIProxyLibraryResponse)
    if err != nil {
        return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
    }
    if AIProxyLibraryResponse.ErrCode != 0 {
        return &model.ErrorWithStatusCode{
            Error: model.Error{
                Message: AIProxyLibraryResponse.Message,
                Type:    strconv.Itoa(AIProxyLibraryResponse.ErrCode),
                Code:    AIProxyLibraryResponse.ErrCode,
            },
            StatusCode: resp.StatusCode,
        }, nil
    }
    fullTextResponse := responseAIProxyLibrary2OpenAI(&AIProxyLibraryResponse)
    jsonResponse, err := json.Marshal(fullTextResponse)
    if err != nil {
        return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
    }
    c.Writer.Header().Set("Content-Type", "application/json")
    c.Writer.WriteHeader(resp.StatusCode)
    _, err = c.Writer.Write(jsonResponse)
    if err != nil {
        return openai.ErrorWrapper(err, "write_response_body_failed", http.StatusInternalServerError), nil
    }
    return nil, &fullTextResponse.Usage
}
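The custom bufio.SplitFunc in StreamHandler tokenizes the upstream body line by line so each `data:` SSE payload can be unmarshalled independently, while comment and keep-alive lines are skipped. A standalone sketch of the same filtering on canned input, for illustration only:

package main

import (
    "bufio"
    "fmt"
    "strings"
)

func main() {
    body := "data:{\"content\":\"hel\"}\ndata:{\"content\":\"lo\"}\n: keep-alive\n"
    scanner := bufio.NewScanner(strings.NewReader(body))
    scanner.Split(bufio.ScanLines) // same effect as the custom split above for \n-delimited input
    for scanner.Scan() {
        line := scanner.Text()
        if len(line) < 5 || line[:5] != "data:" {
            continue // skip comments and keep-alives, as the handler above does
        }
        fmt.Println("payload:", line[5:])
    }
}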
32
relay/adaptor/aiproxy/model.go
Normal file
@@ -0,0 +1,32 @@
package aiproxy

type LibraryRequest struct {
    Model     string `json:"model"`
    Query     string `json:"query"`
    LibraryId string `json:"libraryId"`
    Stream    bool   `json:"stream"`
}

type LibraryError struct {
    ErrCode int    `json:"errCode"`
    Message string `json:"message"`
}

type LibraryDocument struct {
    Title string `json:"title"`
    URL   string `json:"url"`
}

type LibraryResponse struct {
    Success   bool              `json:"success"`
    Answer    string            `json:"answer"`
    Documents []LibraryDocument `json:"documents"`
    LibraryError
}

type LibraryStreamResponse struct {
    Content   string            `json:"content"`
    Finish    bool              `json:"finish"`
    Model     string            `json:"model"`
    Documents []LibraryDocument `json:"documents"`
}
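For reference, the wire shape a LibraryRequest serializes to; a standalone sketch with the struct copied from above and invented field values:

package main

import (
    "encoding/json"
    "fmt"
)

type LibraryRequest struct { // copied from model.go above
    Model     string `json:"model"`
    Query     string `json:"query"`
    LibraryId string `json:"libraryId"`
    Stream    bool   `json:"stream"`
}

func main() {
    b, _ := json.Marshal(LibraryRequest{Model: "gpt-3.5-turbo", Query: "hi", LibraryId: "lib-123", Stream: true})
    fmt.Println(string(b)) // {"model":"gpt-3.5-turbo","query":"hi","libraryId":"lib-123","stream":true}
}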
106
relay/adaptor/ali/adaptor.go
Normal file
@@ -0,0 +1,106 @@
package ali

import (
    "errors"
    "fmt"
    "io"
    "net/http"

    "github.com/gin-gonic/gin"
    "github.com/songquanpeng/one-api/relay/adaptor"
    "github.com/songquanpeng/one-api/relay/meta"
    "github.com/songquanpeng/one-api/relay/model"
    "github.com/songquanpeng/one-api/relay/relaymode"
)

// https://help.aliyun.com/zh/dashscope/developer-reference/api-details

type Adaptor struct {
    meta *meta.Meta
}

func (a *Adaptor) Init(meta *meta.Meta) {
    a.meta = meta
}

func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) {
    fullRequestURL := ""
    switch meta.Mode {
    case relaymode.Embeddings:
        fullRequestURL = fmt.Sprintf("%s/api/v1/services/embeddings/text-embedding/text-embedding", meta.BaseURL)
    case relaymode.ImagesGenerations:
        fullRequestURL = fmt.Sprintf("%s/api/v1/services/aigc/text2image/image-synthesis", meta.BaseURL)
    default:
        fullRequestURL = fmt.Sprintf("%s/api/v1/services/aigc/text-generation/generation", meta.BaseURL)
    }

    return fullRequestURL, nil
}

func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error {
    adaptor.SetupCommonRequestHeader(c, req, meta)
    if meta.IsStream {
        req.Header.Set("Accept", "text/event-stream")
        req.Header.Set("X-DashScope-SSE", "enable")
    }
    req.Header.Set("Authorization", "Bearer "+meta.APIKey)

    if meta.Mode == relaymode.ImagesGenerations {
        req.Header.Set("X-DashScope-Async", "enable")
    }
    if a.meta.Config.Plugin != "" {
        req.Header.Set("X-DashScope-Plugin", a.meta.Config.Plugin)
    }
    return nil
}

func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) {
    if request == nil {
        return nil, errors.New("request is nil")
    }
    switch relayMode {
    case relaymode.Embeddings:
        aliEmbeddingRequest := ConvertEmbeddingRequest(*request)
        return aliEmbeddingRequest, nil
    default:
        aliRequest := ConvertRequest(*request)
        return aliRequest, nil
    }
}

func (a *Adaptor) ConvertImageRequest(_ *gin.Context, request *model.ImageRequest) (any, error) {
    if request == nil {
        return nil, errors.New("request is nil")
    }

    aliRequest := ConvertImageRequest(*request)
    return aliRequest, nil
}

func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) {
    return adaptor.DoRequestHelper(a, c, meta, requestBody)
}

func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
    if meta.IsStream {
        err, usage = StreamHandler(c, resp)
    } else {
        switch meta.Mode {
        case relaymode.Embeddings:
            err, usage = EmbeddingHandler(c, resp)
        case relaymode.ImagesGenerations:
            err, usage = ImageHandler(c, resp)
        default:
            err, usage = Handler(c, resp)
        }
    }
    return
}

func (a *Adaptor) GetModelList() []string {
    return ModelList
}

func (a *Adaptor) GetChannelName() string {
    return "ali"
}
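GetRequestURL above routes by relay mode onto three DashScope endpoints. The same mapping in isolation (plain strings stand in for the relaymode constants; base URL invented):

package main

import "fmt"

func dashScopeURL(base, mode string) string {
    // mirrors the switch in GetRequestURL above
    switch mode {
    case "embeddings":
        return fmt.Sprintf("%s/api/v1/services/embeddings/text-embedding/text-embedding", base)
    case "images":
        return fmt.Sprintf("%s/api/v1/services/aigc/text2image/image-synthesis", base)
    default:
        return fmt.Sprintf("%s/api/v1/services/aigc/text-generation/generation", base)
    }
}

func main() {
    fmt.Println(dashScopeURL("https://dashscope.aliyuncs.com", "embeddings"))
}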
23
relay/adaptor/ali/constants.go
Normal file
@@ -0,0 +1,23 @@
package ali

var ModelList = []string{
    "qwen-turbo", "qwen-turbo-latest",
    "qwen-plus", "qwen-plus-latest",
    "qwen-max", "qwen-max-latest",
    "qwen-max-longcontext",
    "qwen-vl-max", "qwen-vl-max-latest", "qwen-vl-plus", "qwen-vl-plus-latest",
    "qwen-vl-ocr", "qwen-vl-ocr-latest",
    "qwen-audio-turbo",
    "qwen-math-plus", "qwen-math-plus-latest", "qwen-math-turbo", "qwen-math-turbo-latest",
    "qwen-coder-plus", "qwen-coder-plus-latest", "qwen-coder-turbo", "qwen-coder-turbo-latest",
    "qwq-32b-preview", "qwen2.5-72b-instruct", "qwen2.5-32b-instruct", "qwen2.5-14b-instruct", "qwen2.5-7b-instruct", "qwen2.5-3b-instruct", "qwen2.5-1.5b-instruct", "qwen2.5-0.5b-instruct",
    "qwen2-72b-instruct", "qwen2-57b-a14b-instruct", "qwen2-7b-instruct", "qwen2-1.5b-instruct", "qwen2-0.5b-instruct",
    "qwen1.5-110b-chat", "qwen1.5-72b-chat", "qwen1.5-32b-chat", "qwen1.5-14b-chat", "qwen1.5-7b-chat", "qwen1.5-1.8b-chat", "qwen1.5-0.5b-chat",
    "qwen-72b-chat", "qwen-14b-chat", "qwen-7b-chat", "qwen-1.8b-chat", "qwen-1.8b-longcontext-chat",
    "qwen2-vl-7b-instruct", "qwen2-vl-2b-instruct", "qwen-vl-v1", "qwen-vl-chat-v1",
    "qwen2-audio-instruct", "qwen-audio-chat",
    "qwen2.5-math-72b-instruct", "qwen2.5-math-7b-instruct", "qwen2.5-math-1.5b-instruct", "qwen2-math-72b-instruct", "qwen2-math-7b-instruct", "qwen2-math-1.5b-instruct",
    "qwen2.5-coder-32b-instruct", "qwen2.5-coder-14b-instruct", "qwen2.5-coder-7b-instruct", "qwen2.5-coder-3b-instruct", "qwen2.5-coder-1.5b-instruct", "qwen2.5-coder-0.5b-instruct",
    "text-embedding-v1", "text-embedding-v3", "text-embedding-v2", "text-embedding-async-v2", "text-embedding-async-v1",
    "ali-stable-diffusion-xl", "ali-stable-diffusion-v1.5", "wanx-v1",
}
193
relay/adaptor/ali/image.go
Normal file
@@ -0,0 +1,193 @@
package ali

import (
    "encoding/base64"
    "encoding/json"
    "errors"
    "fmt"
    "io"
    "net/http"
    "strings"
    "time"

    "github.com/gin-gonic/gin"
    "github.com/songquanpeng/one-api/common/helper"
    "github.com/songquanpeng/one-api/common/logger"
    "github.com/songquanpeng/one-api/relay/adaptor/openai"
    "github.com/songquanpeng/one-api/relay/model"
)

func ImageHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
    apiKey := c.Request.Header.Get("Authorization")
    apiKey = strings.TrimPrefix(apiKey, "Bearer ")
    responseFormat := c.GetString("response_format")

    var aliTaskResponse TaskResponse
    responseBody, err := io.ReadAll(resp.Body)
    if err != nil {
        return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
    }
    err = resp.Body.Close()
    if err != nil {
        return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
    }
    err = json.Unmarshal(responseBody, &aliTaskResponse)
    if err != nil {
        return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
    }

    if aliTaskResponse.Message != "" {
        logger.SysError("aliAsyncTask err: " + string(responseBody))
        return openai.ErrorWrapper(errors.New(aliTaskResponse.Message), "ali_async_task_failed", http.StatusInternalServerError), nil
    }

    aliResponse, _, err := asyncTaskWait(aliTaskResponse.Output.TaskId, apiKey)
    if err != nil {
        return openai.ErrorWrapper(err, "ali_async_task_wait_failed", http.StatusInternalServerError), nil
    }

    if aliResponse.Output.TaskStatus != "SUCCEEDED" {
        return &model.ErrorWithStatusCode{
            Error: model.Error{
                Message: aliResponse.Output.Message,
                Type:    "ali_error",
                Param:   "",
                Code:    aliResponse.Output.Code,
            },
            StatusCode: resp.StatusCode,
        }, nil
    }

    fullTextResponse := responseAli2OpenAIImage(aliResponse, responseFormat)
    jsonResponse, err := json.Marshal(fullTextResponse)
    if err != nil {
        return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
    }
    c.Writer.Header().Set("Content-Type", "application/json")
    c.Writer.WriteHeader(resp.StatusCode)
    _, err = c.Writer.Write(jsonResponse)
    if err != nil {
        return openai.ErrorWrapper(err, "write_response_body_failed", http.StatusInternalServerError), nil
    }
    return nil, nil
}

func asyncTask(taskID string, key string) (*TaskResponse, error, []byte) {
    url := fmt.Sprintf("https://dashscope.aliyuncs.com/api/v1/tasks/%s", taskID)

    var aliResponse TaskResponse

    req, err := http.NewRequest("GET", url, nil)
    if err != nil {
        return &aliResponse, err, nil
    }

    req.Header.Set("Authorization", "Bearer "+key)

    client := &http.Client{}
    resp, err := client.Do(req)
    if err != nil {
        logger.SysError("aliAsyncTask client.Do err: " + err.Error())
        return &aliResponse, err, nil
    }
    defer resp.Body.Close()

    responseBody, err := io.ReadAll(resp.Body)
    if err != nil {
        logger.SysError("aliAsyncTask io.ReadAll err: " + err.Error())
        return &aliResponse, err, nil
    }

    var response TaskResponse
    err = json.Unmarshal(responseBody, &response)
    if err != nil {
        logger.SysError("aliAsyncTask json.Unmarshal err: " + err.Error())
        return &aliResponse, err, nil
    }

    return &response, nil, responseBody
}

func asyncTaskWait(taskID string, key string) (*TaskResponse, []byte, error) {
    waitSeconds := 2
    step := 0
    maxStep := 20

    var taskResponse TaskResponse
    var responseBody []byte

    for {
        step++
        rsp, err, body := asyncTask(taskID, key)
        responseBody = body
        if err != nil {
            return &taskResponse, responseBody, err
        }

        if rsp.Output.TaskStatus == "" {
            return &taskResponse, responseBody, nil
        }

        switch rsp.Output.TaskStatus {
        case "FAILED", "CANCELED", "SUCCEEDED", "UNKNOWN":
            return rsp, responseBody, nil
        }
        if step >= maxStep {
            break
        }
        time.Sleep(time.Duration(waitSeconds) * time.Second)
    }

    return nil, nil, fmt.Errorf("aliAsyncTaskWait timeout")
}

func responseAli2OpenAIImage(response *TaskResponse, responseFormat string) *openai.ImageResponse {
    imageResponse := openai.ImageResponse{
        Created: helper.GetTimestamp(),
    }

    for _, data := range response.Output.Results {
        var b64Json string
        if responseFormat == "b64_json" {
            // Fetch the image bytes from data.Url and store them in b64Json.
            imageData, err := getImageData(data.Url)
            if err != nil {
                // Skip this result if the image data could not be fetched.
                logger.SysError("getImageData Error getting image data: " + err.Error())
                continue
            }

            // Encode the image bytes as a Base64 string.
            b64Json = Base64Encode(imageData)
        } else {
            // If responseFormat is not "b64_json", use data.B64Image directly.
            b64Json = data.B64Image
        }

        imageResponse.Data = append(imageResponse.Data, openai.ImageData{
            Url:           data.Url,
            B64Json:       b64Json,
            RevisedPrompt: "",
        })
    }
    return &imageResponse
}

func getImageData(url string) ([]byte, error) {
    response, err := http.Get(url)
    if err != nil {
        return nil, err
    }
    defer response.Body.Close()

    imageData, err := io.ReadAll(response.Body)
    if err != nil {
        return nil, err
    }

    return imageData, nil
}

func Base64Encode(data []byte) string {
    b64Json := base64.StdEncoding.EncodeToString(data)
    return b64Json
}
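asyncTaskWait above polls the task endpoint every 2 seconds for at most 20 rounds (roughly 40 seconds) before timing out. The same backstop pattern in isolation, as a sketch rather than the repo's code:

package main

import (
    "errors"
    "fmt"
    "time"
)

func pollUntilDone(check func() (done bool, err error), interval time.Duration, maxSteps int) error {
    for step := 1; step <= maxSteps; step++ {
        done, err := check()
        if err != nil {
            return err
        }
        if done {
            return nil
        }
        time.Sleep(interval) // sleep between checks, as asyncTaskWait does
    }
    return errors.New("poll timeout")
}

func main() {
    n := 0
    _ = pollUntilDone(func() (bool, error) { n++; return n >= 3, nil }, 10*time.Millisecond, 20)
    fmt.Println("finished after", n, "checks")
}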
267
relay/adaptor/ali/main.go
Normal file
@@ -0,0 +1,267 @@
package ali

import (
    "bufio"
    "encoding/json"
    "io"
    "net/http"
    "strings"

    "github.com/gin-gonic/gin"
    "github.com/songquanpeng/one-api/common"
    "github.com/songquanpeng/one-api/common/ctxkey"
    "github.com/songquanpeng/one-api/common/helper"
    "github.com/songquanpeng/one-api/common/logger"
    "github.com/songquanpeng/one-api/common/render"
    "github.com/songquanpeng/one-api/relay/adaptor/openai"
    "github.com/songquanpeng/one-api/relay/model"
)

// https://help.aliyun.com/document_detail/613695.html?spm=a2c4g.2399480.0.0.1adb778fAdzP9w#341800c0f8w0r

const EnableSearchModelSuffix = "-internet"

func ConvertRequest(request model.GeneralOpenAIRequest) *ChatRequest {
    messages := make([]Message, 0, len(request.Messages))
    for i := 0; i < len(request.Messages); i++ {
        message := request.Messages[i]
        messages = append(messages, Message{
            Content: message.StringContent(),
            Role:    strings.ToLower(message.Role),
        })
    }
    enableSearch := false
    aliModel := request.Model
    if strings.HasSuffix(aliModel, EnableSearchModelSuffix) {
        enableSearch = true
        aliModel = strings.TrimSuffix(aliModel, EnableSearchModelSuffix)
    }
    request.TopP = helper.Float64PtrMax(request.TopP, 0.9999)
    return &ChatRequest{
        Model: aliModel,
        Input: Input{
            Messages: messages,
        },
        Parameters: Parameters{
            EnableSearch:      enableSearch,
            IncrementalOutput: request.Stream,
            Seed:              uint64(request.Seed),
            MaxTokens:         request.MaxTokens,
            Temperature:       request.Temperature,
            TopP:              request.TopP,
            TopK:              request.TopK,
            ResultFormat:      "message",
            Tools:             request.Tools,
        },
    }
}

func ConvertEmbeddingRequest(request model.GeneralOpenAIRequest) *EmbeddingRequest {
    return &EmbeddingRequest{
        Model: request.Model,
        Input: struct {
            Texts []string `json:"texts"`
        }{
            Texts: request.ParseInput(),
        },
    }
}

func ConvertImageRequest(request model.ImageRequest) *ImageRequest {
    var imageRequest ImageRequest
    imageRequest.Input.Prompt = request.Prompt
    imageRequest.Model = request.Model
    imageRequest.Parameters.Size = strings.Replace(request.Size, "x", "*", -1)
    imageRequest.Parameters.N = request.N
    imageRequest.ResponseFormat = request.ResponseFormat

    return &imageRequest
}

func EmbeddingHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
    var aliResponse EmbeddingResponse
    err := json.NewDecoder(resp.Body).Decode(&aliResponse)
    if err != nil {
        return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
    }

    err = resp.Body.Close()
    if err != nil {
        return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
    }

    if aliResponse.Code != "" {
        return &model.ErrorWithStatusCode{
            Error: model.Error{
                Message: aliResponse.Message,
                Type:    aliResponse.Code,
                Param:   aliResponse.RequestId,
                Code:    aliResponse.Code,
            },
            StatusCode: resp.StatusCode,
        }, nil
    }
    requestModel := c.GetString(ctxkey.RequestModel)
    fullTextResponse := embeddingResponseAli2OpenAI(&aliResponse)
    fullTextResponse.Model = requestModel
    jsonResponse, err := json.Marshal(fullTextResponse)
    if err != nil {
        return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
    }
    c.Writer.Header().Set("Content-Type", "application/json")
    c.Writer.WriteHeader(resp.StatusCode)
    _, err = c.Writer.Write(jsonResponse)
    return nil, &fullTextResponse.Usage
}

func embeddingResponseAli2OpenAI(response *EmbeddingResponse) *openai.EmbeddingResponse {
    openAIEmbeddingResponse := openai.EmbeddingResponse{
        Object: "list",
        Data:   make([]openai.EmbeddingResponseItem, 0, len(response.Output.Embeddings)),
        Model:  "text-embedding-v1",
        Usage:  model.Usage{TotalTokens: response.Usage.TotalTokens},
    }

    for _, item := range response.Output.Embeddings {
        openAIEmbeddingResponse.Data = append(openAIEmbeddingResponse.Data, openai.EmbeddingResponseItem{
            Object:    `embedding`,
            Index:     item.TextIndex,
            Embedding: item.Embedding,
        })
    }
    return &openAIEmbeddingResponse
}

func responseAli2OpenAI(response *ChatResponse) *openai.TextResponse {
    fullTextResponse := openai.TextResponse{
        Id:      response.RequestId,
        Object:  "chat.completion",
        Created: helper.GetTimestamp(),
        Choices: response.Output.Choices,
        Usage: model.Usage{
            PromptTokens:     response.Usage.InputTokens,
            CompletionTokens: response.Usage.OutputTokens,
            TotalTokens:      response.Usage.InputTokens + response.Usage.OutputTokens,
        },
    }
    return &fullTextResponse
}

func streamResponseAli2OpenAI(aliResponse *ChatResponse) *openai.ChatCompletionsStreamResponse {
    if len(aliResponse.Output.Choices) == 0 {
        return nil
    }
    aliChoice := aliResponse.Output.Choices[0]
    var choice openai.ChatCompletionsStreamResponseChoice
    choice.Delta = aliChoice.Message
    if aliChoice.FinishReason != "null" {
        finishReason := aliChoice.FinishReason
        choice.FinishReason = &finishReason
    }
    response := openai.ChatCompletionsStreamResponse{
        Id:      aliResponse.RequestId,
        Object:  "chat.completion.chunk",
        Created: helper.GetTimestamp(),
        Model:   "qwen",
        Choices: []openai.ChatCompletionsStreamResponseChoice{choice},
    }
    return &response
}

func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
    var usage model.Usage
    scanner := bufio.NewScanner(resp.Body)
    scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
        if atEOF && len(data) == 0 {
            return 0, nil, nil
        }
        if i := strings.Index(string(data), "\n"); i >= 0 {
            return i + 1, data[0:i], nil
        }
        if atEOF {
            return len(data), data, nil
        }
        return 0, nil, nil
    })

    common.SetEventStreamHeaders(c)

    for scanner.Scan() {
        data := scanner.Text()
        if len(data) < 5 || data[:5] != "data:" {
            continue
        }
        data = data[5:]

        var aliResponse ChatResponse
        err := json.Unmarshal([]byte(data), &aliResponse)
        if err != nil {
            logger.SysError("error unmarshalling stream response: " + err.Error())
            continue
        }
        if aliResponse.Usage.OutputTokens != 0 {
            usage.PromptTokens = aliResponse.Usage.InputTokens
            usage.CompletionTokens = aliResponse.Usage.OutputTokens
            usage.TotalTokens = aliResponse.Usage.InputTokens + aliResponse.Usage.OutputTokens
        }
        response := streamResponseAli2OpenAI(&aliResponse)
        if response == nil {
            continue
        }
        err = render.ObjectData(c, response)
        if err != nil {
            logger.SysError(err.Error())
        }
    }

    if err := scanner.Err(); err != nil {
        logger.SysError("error reading stream: " + err.Error())
    }

    render.Done(c)

    err := resp.Body.Close()
    if err != nil {
        return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
    }
    return nil, &usage
}

func Handler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
    ctx := c.Request.Context()
    var aliResponse ChatResponse
    responseBody, err := io.ReadAll(resp.Body)
    if err != nil {
        return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
    }
    err = resp.Body.Close()
    if err != nil {
        return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
    }
    logger.Debugf(ctx, "response body: %s\n", responseBody)
    err = json.Unmarshal(responseBody, &aliResponse)
    if err != nil {
        return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
    }
    if aliResponse.Code != "" {
        return &model.ErrorWithStatusCode{
            Error: model.Error{
                Message: aliResponse.Message,
                Type:    aliResponse.Code,
                Param:   aliResponse.RequestId,
                Code:    aliResponse.Code,
            },
            StatusCode: resp.StatusCode,
        }, nil
    }
    fullTextResponse := responseAli2OpenAI(&aliResponse)
    fullTextResponse.Model = "qwen"
    jsonResponse, err := json.Marshal(fullTextResponse)
    if err != nil {
        return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
    }
    c.Writer.Header().Set("Content-Type", "application/json")
    c.Writer.WriteHeader(resp.StatusCode)
    _, err = c.Writer.Write(jsonResponse)
    return nil, &fullTextResponse.Usage
}
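ConvertRequest above enables DashScope web search when the requested model name carries the "-internet" suffix, and strips the suffix before the upstream call. The same split in isolation:

package main

import (
    "fmt"
    "strings"
)

const EnableSearchModelSuffix = "-internet" // same constant as above

func splitSearchSuffix(model string) (string, bool) {
    if strings.HasSuffix(model, EnableSearchModelSuffix) {
        return strings.TrimSuffix(model, EnableSearchModelSuffix), true
    }
    return model, false
}

func main() {
    m, search := splitSearchSuffix("qwen-plus-internet")
    fmt.Println(m, search) // qwen-plus true
}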
154
relay/adaptor/ali/model.go
Normal file
@@ -0,0 +1,154 @@
package ali

import (
    "github.com/songquanpeng/one-api/relay/adaptor/openai"
    "github.com/songquanpeng/one-api/relay/model"
)

type Message struct {
    Content string `json:"content"`
    Role    string `json:"role"`
}

type Input struct {
    //Prompt string `json:"prompt"`
    Messages []Message `json:"messages"`
}

type Parameters struct {
    TopP              *float64     `json:"top_p,omitempty"`
    TopK              int          `json:"top_k,omitempty"`
    Seed              uint64       `json:"seed,omitempty"`
    EnableSearch      bool         `json:"enable_search,omitempty"`
    IncrementalOutput bool         `json:"incremental_output,omitempty"`
    MaxTokens         int          `json:"max_tokens,omitempty"`
    Temperature       *float64     `json:"temperature,omitempty"`
    ResultFormat      string       `json:"result_format,omitempty"`
    Tools             []model.Tool `json:"tools,omitempty"`
}

type ChatRequest struct {
    Model      string     `json:"model"`
    Input      Input      `json:"input"`
    Parameters Parameters `json:"parameters,omitempty"`
}

type ImageRequest struct {
    Model string `json:"model"`
    Input struct {
        Prompt         string `json:"prompt"`
        NegativePrompt string `json:"negative_prompt,omitempty"`
    } `json:"input"`
    Parameters struct {
        Size  string `json:"size,omitempty"`
        N     int    `json:"n,omitempty"`
        Steps string `json:"steps,omitempty"`
        Scale string `json:"scale,omitempty"`
    } `json:"parameters,omitempty"`
    ResponseFormat string `json:"response_format,omitempty"`
}

type TaskResponse struct {
    StatusCode int    `json:"status_code,omitempty"`
    RequestId  string `json:"request_id,omitempty"`
    Code       string `json:"code,omitempty"`
    Message    string `json:"message,omitempty"`
    Output     struct {
        TaskId     string `json:"task_id,omitempty"`
        TaskStatus string `json:"task_status,omitempty"`
        Code       string `json:"code,omitempty"`
        Message    string `json:"message,omitempty"`
        Results    []struct {
            B64Image string `json:"b64_image,omitempty"`
            Url      string `json:"url,omitempty"`
            Code     string `json:"code,omitempty"`
            Message  string `json:"message,omitempty"`
        } `json:"results,omitempty"`
        TaskMetrics struct {
            Total     int `json:"TOTAL,omitempty"`
            Succeeded int `json:"SUCCEEDED,omitempty"`
            Failed    int `json:"FAILED,omitempty"`
        } `json:"task_metrics,omitempty"`
    } `json:"output,omitempty"`
    Usage Usage `json:"usage"`
}

type Header struct {
    Action       string `json:"action,omitempty"`
    Streaming    string `json:"streaming,omitempty"`
    TaskID       string `json:"task_id,omitempty"`
    Event        string `json:"event,omitempty"`
    ErrorCode    string `json:"error_code,omitempty"`
    ErrorMessage string `json:"error_message,omitempty"`
    Attributes   any    `json:"attributes,omitempty"`
}

type Payload struct {
    Model      string `json:"model,omitempty"`
    Task       string `json:"task,omitempty"`
    TaskGroup  string `json:"task_group,omitempty"`
    Function   string `json:"function,omitempty"`
    Parameters struct {
        SampleRate int     `json:"sample_rate,omitempty"`
        Rate       float64 `json:"rate,omitempty"`
        Format     string  `json:"format,omitempty"`
    } `json:"parameters,omitempty"`
    Input struct {
        Text string `json:"text,omitempty"`
    } `json:"input,omitempty"`
    Usage struct {
        Characters int `json:"characters,omitempty"`
    } `json:"usage,omitempty"`
}

type WSSMessage struct {
    Header  Header  `json:"header,omitempty"`
    Payload Payload `json:"payload,omitempty"`
}

type EmbeddingRequest struct {
    Model string `json:"model"`
    Input struct {
        Texts []string `json:"texts"`
    } `json:"input"`
    Parameters *struct {
        TextType string `json:"text_type,omitempty"`
    } `json:"parameters,omitempty"`
}

type Embedding struct {
    Embedding []float64 `json:"embedding"`
    TextIndex int       `json:"text_index"`
}

type EmbeddingResponse struct {
    Output struct {
        Embeddings []Embedding `json:"embeddings"`
    } `json:"output"`
    Usage Usage `json:"usage"`
    Error
}

type Error struct {
    Code      string `json:"code"`
    Message   string `json:"message"`
    RequestId string `json:"request_id"`
}

type Usage struct {
    InputTokens  int `json:"input_tokens"`
    OutputTokens int `json:"output_tokens"`
    TotalTokens  int `json:"total_tokens"`
}

type Output struct {
    //Text string `json:"text"`
    //FinishReason string `json:"finish_reason"`
    Choices []openai.TextResponseChoice `json:"choices"`
}

type ChatResponse struct {
    Output Output `json:"output"`
    Usage  Usage  `json:"usage"`
    Error
}
83
relay/adaptor/anthropic/adaptor.go
Normal file
@@ -0,0 +1,83 @@
package anthropic

import (
    "fmt"
    "io"
    "net/http"
    "strings"

    "github.com/gin-gonic/gin"
    "github.com/pkg/errors"
    "github.com/songquanpeng/one-api/relay/adaptor"
    "github.com/songquanpeng/one-api/relay/meta"
    "github.com/songquanpeng/one-api/relay/model"
)

type Adaptor struct {
}

func (a *Adaptor) Init(meta *meta.Meta) {
}

// https://docs.anthropic.com/claude/reference/messages_post
// Anthropic migrated to the Messages API.
func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) {
    return fmt.Sprintf("%s/v1/messages", meta.BaseURL), nil
}

func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error {
    adaptor.SetupCommonRequestHeader(c, req, meta)
    req.Header.Set("x-api-key", meta.APIKey)
    anthropicVersion := c.Request.Header.Get("anthropic-version")
    if anthropicVersion == "" {
        anthropicVersion = "2023-06-01"
    }
    req.Header.Set("anthropic-version", anthropicVersion)
    req.Header.Set("anthropic-beta", "messages-2023-12-15")

    // https://x.com/alexalbert__/status/1812921642143900036
    // claude-3-5-sonnet supports 8k output tokens with this beta header
    if strings.HasPrefix(meta.ActualModelName, "claude-3-5-sonnet") {
        req.Header.Set("anthropic-beta", "max-tokens-3-5-sonnet-2024-07-15")
    }

    return nil
}

func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) {
    if request == nil {
        return nil, errors.New("request is nil")
    }

    c.Set("claude_model", request.Model)
    return ConvertRequest(*request), nil
}

func (a *Adaptor) ConvertImageRequest(_ *gin.Context, request *model.ImageRequest) (any, error) {
    if request == nil {
        return nil, errors.New("request is nil")
    }
    return request, nil
}

func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) {
    return adaptor.DoRequestHelper(a, c, meta, requestBody)
}

func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
    if meta.IsStream {
        err, usage = StreamHandler(c, resp)
    } else {
        err, usage = Handler(c, resp, meta.PromptTokens, meta.ActualModelName)
    }
    return
}

func (a *Adaptor) GetModelList() []string {
    return ModelList
}

func (a *Adaptor) GetChannelName() string {
    return "anthropic"
}
12
relay/adaptor/anthropic/constants.go
Normal file
@@ -0,0 +1,12 @@
package anthropic

var ModelList = []string{
    "claude-instant-1.2", "claude-2.0", "claude-2.1",
    "claude-3-haiku-20240307",
    "claude-3-5-haiku-20241022",
    "claude-3-sonnet-20240229",
    "claude-3-opus-20240229",
    "claude-3-5-sonnet-20240620",
    "claude-3-5-sonnet-20241022",
    "claude-3-5-sonnet-latest",
}
378
relay/adaptor/anthropic/main.go
Normal file
@@ -0,0 +1,378 @@
|
||||
package anthropic
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"github.com/songquanpeng/one-api/common/render"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/songquanpeng/one-api/common"
|
||||
"github.com/songquanpeng/one-api/common/helper"
|
||||
"github.com/songquanpeng/one-api/common/image"
|
||||
"github.com/songquanpeng/one-api/common/logger"
|
||||
"github.com/songquanpeng/one-api/relay/adaptor/openai"
|
||||
"github.com/songquanpeng/one-api/relay/model"
|
||||
)
|
||||
|
||||
func stopReasonClaude2OpenAI(reason *string) string {
|
||||
if reason == nil {
|
||||
return ""
|
||||
}
|
||||
switch *reason {
|
||||
case "end_turn":
|
||||
return "stop"
|
||||
case "stop_sequence":
|
||||
return "stop"
|
||||
case "max_tokens":
|
||||
return "length"
|
||||
case "tool_use":
|
||||
return "tool_calls"
|
||||
default:
|
||||
return *reason
|
||||
}
|
||||
}
|
||||
|
||||
func ConvertRequest(textRequest model.GeneralOpenAIRequest) *Request {
|
||||
claudeTools := make([]Tool, 0, len(textRequest.Tools))
|
||||
|
||||
for _, tool := range textRequest.Tools {
|
||||
if params, ok := tool.Function.Parameters.(map[string]any); ok {
|
||||
claudeTools = append(claudeTools, Tool{
|
||||
Name: tool.Function.Name,
|
||||
Description: tool.Function.Description,
|
||||
InputSchema: InputSchema{
|
||||
Type: params["type"].(string),
|
||||
Properties: params["properties"],
|
||||
Required: params["required"],
|
||||
},
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
claudeRequest := Request{
|
||||
Model: textRequest.Model,
|
||||
MaxTokens: textRequest.MaxTokens,
|
||||
Temperature: textRequest.Temperature,
|
||||
TopP: textRequest.TopP,
|
||||
TopK: textRequest.TopK,
|
||||
Stream: textRequest.Stream,
|
||||
Tools: claudeTools,
|
||||
}
|
||||
if len(claudeTools) > 0 {
|
||||
claudeToolChoice := struct {
|
||||
Type string `json:"type"`
|
||||
Name string `json:"name,omitempty"`
|
||||
}{Type: "auto"} // default value https://docs.anthropic.com/en/docs/build-with-claude/tool-use#controlling-claudes-output
|
||||
if choice, ok := textRequest.ToolChoice.(map[string]any); ok {
|
||||
if function, ok := choice["function"].(map[string]any); ok {
|
||||
claudeToolChoice.Type = "tool"
|
||||
claudeToolChoice.Name = function["name"].(string)
|
||||
}
|
||||
} else if toolChoiceType, ok := textRequest.ToolChoice.(string); ok {
|
||||
if toolChoiceType == "any" {
|
||||
claudeToolChoice.Type = toolChoiceType
|
||||
}
|
||||
}
|
||||
claudeRequest.ToolChoice = claudeToolChoice
|
||||
}
|
||||
if claudeRequest.MaxTokens == 0 {
|
||||
claudeRequest.MaxTokens = 4096
|
||||
}
|
||||
// legacy model name mapping
|
||||
if claudeRequest.Model == "claude-instant-1" {
|
||||
claudeRequest.Model = "claude-instant-1.1"
|
||||
} else if claudeRequest.Model == "claude-2" {
|
||||
claudeRequest.Model = "claude-2.1"
|
||||
}
|
||||
for _, message := range textRequest.Messages {
|
||||
if message.Role == "system" && claudeRequest.System == "" {
|
||||
claudeRequest.System = message.StringContent()
|
||||
continue
|
||||
}
|
||||
claudeMessage := Message{
|
||||
Role: message.Role,
|
||||
}
|
||||
var content Content
|
||||
if message.IsStringContent() {
|
||||
content.Type = "text"
|
||||
content.Text = message.StringContent()
|
||||
if message.Role == "tool" {
|
||||
claudeMessage.Role = "user"
|
||||
content.Type = "tool_result"
|
||||
content.Content = content.Text
|
||||
content.Text = ""
|
||||
content.ToolUseId = message.ToolCallId
|
||||
}
|
||||
claudeMessage.Content = append(claudeMessage.Content, content)
|
||||
for i := range message.ToolCalls {
|
||||
inputParam := make(map[string]any)
|
||||
_ = json.Unmarshal([]byte(message.ToolCalls[i].Function.Arguments.(string)), &inputParam)
|
||||
claudeMessage.Content = append(claudeMessage.Content, Content{
|
||||
Type: "tool_use",
|
||||
Id: message.ToolCalls[i].Id,
|
||||
Name: message.ToolCalls[i].Function.Name,
|
||||
Input: inputParam,
|
||||
})
|
||||
}
|
||||
claudeRequest.Messages = append(claudeRequest.Messages, claudeMessage)
|
||||
continue
|
||||
}
|
||||
var contents []Content
|
||||
openaiContent := message.ParseContent()
|
||||
for _, part := range openaiContent {
|
||||
var content Content
|
||||
if part.Type == model.ContentTypeText {
|
||||
content.Type = "text"
|
||||
content.Text = part.Text
|
||||
} else if part.Type == model.ContentTypeImageURL {
|
||||
content.Type = "image"
|
||||
content.Source = &ImageSource{
|
||||
Type: "base64",
|
||||
}
|
||||
mimeType, data, _ := image.GetImageFromUrl(part.ImageURL.Url)
|
||||
content.Source.MediaType = mimeType
|
||||
content.Source.Data = data
|
||||
}
|
||||
contents = append(contents, content)
|
||||
}
|
||||
claudeMessage.Content = contents
|
||||
claudeRequest.Messages = append(claudeRequest.Messages, claudeMessage)
|
||||
}
|
||||
return &claudeRequest
|
||||
}
|
||||
|
||||
// https://docs.anthropic.com/claude/reference/messages-streaming
|
||||
func StreamResponseClaude2OpenAI(claudeResponse *StreamResponse) (*openai.ChatCompletionsStreamResponse, *Response) {
|
||||
var response *Response
|
||||
var responseText string
|
||||
var stopReason string
|
||||
tools := make([]model.Tool, 0)
|
||||
|
||||
switch claudeResponse.Type {
|
||||
case "message_start":
|
||||
return nil, claudeResponse.Message
|
||||
case "content_block_start":
|
||||
if claudeResponse.ContentBlock != nil {
|
||||
responseText = claudeResponse.ContentBlock.Text
|
||||
if claudeResponse.ContentBlock.Type == "tool_use" {
|
||||
tools = append(tools, model.Tool{
|
||||
Id: claudeResponse.ContentBlock.Id,
|
||||
Type: "function",
|
||||
Function: model.Function{
|
||||
Name: claudeResponse.ContentBlock.Name,
|
||||
Arguments: "",
|
||||
},
|
||||
})
|
||||
}
|
||||
}
|
||||
case "content_block_delta":
|
||||
if claudeResponse.Delta != nil {
|
||||
responseText = claudeResponse.Delta.Text
|
||||
if claudeResponse.Delta.Type == "input_json_delta" {
|
||||
tools = append(tools, model.Tool{
|
||||
Function: model.Function{
|
||||
Arguments: claudeResponse.Delta.PartialJson,
|
||||
},
|
||||
})
|
||||
}
|
||||
}
|
||||
case "message_delta":
|
||||
if claudeResponse.Usage != nil {
|
||||
response = &Response{
|
||||
Usage: *claudeResponse.Usage,
|
||||
}
|
||||
}
|
||||
if claudeResponse.Delta != nil && claudeResponse.Delta.StopReason != nil {
|
||||
stopReason = *claudeResponse.Delta.StopReason
|
||||
}
|
||||
}
|
||||
var choice openai.ChatCompletionsStreamResponseChoice
|
||||
choice.Delta.Content = responseText
|
||||
if len(tools) > 0 {
|
||||
choice.Delta.Content = nil // compatible with other OpenAI derivative applications, like LobeOpenAICompatibleFactory ...
|
||||
choice.Delta.ToolCalls = tools
|
||||
}
|
||||
choice.Delta.Role = "assistant"
|
||||
finishReason := stopReasonClaude2OpenAI(&stopReason)
|
||||
if finishReason != "null" {
|
||||
choice.FinishReason = &finishReason
|
||||
}
|
||||
var openaiResponse openai.ChatCompletionsStreamResponse
|
||||
openaiResponse.Object = "chat.completion.chunk"
|
||||
openaiResponse.Choices = []openai.ChatCompletionsStreamResponseChoice{choice}
|
||||
return &openaiResponse, response
|
||||
}
|
||||
|
||||
func ResponseClaude2OpenAI(claudeResponse *Response) *openai.TextResponse {
|
||||
var responseText string
|
||||
if len(claudeResponse.Content) > 0 {
|
||||
responseText = claudeResponse.Content[0].Text
|
||||
}
|
||||
tools := make([]model.Tool, 0)
|
||||
for _, v := range claudeResponse.Content {
|
||||
if v.Type == "tool_use" {
|
||||
args, _ := json.Marshal(v.Input)
|
||||
tools = append(tools, model.Tool{
|
||||
Id: v.Id,
|
||||
Type: "function", // compatible with other OpenAI derivative applications
|
||||
Function: model.Function{
|
||||
Name: v.Name,
|
||||
Arguments: string(args),
|
||||
},
|
||||
})
|
||||
}
|
||||
}
|
||||
choice := openai.TextResponseChoice{
|
||||
Index: 0,
|
||||
Message: model.Message{
|
||||
Role: "assistant",
|
||||
Content: responseText,
|
||||
Name: nil,
|
||||
ToolCalls: tools,
|
||||
},
|
||||
FinishReason: stopReasonClaude2OpenAI(claudeResponse.StopReason),
|
||||
}
|
||||
fullTextResponse := openai.TextResponse{
|
||||
Id: fmt.Sprintf("chatcmpl-%s", claudeResponse.Id),
|
||||
Model: claudeResponse.Model,
|
||||
Object: "chat.completion",
|
||||
Created: helper.GetTimestamp(),
|
||||
Choices: []openai.TextResponseChoice{choice},
|
||||
}
|
||||
return &fullTextResponse
|
||||
}
|
||||
|
||||
func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
|
||||
createdTime := helper.GetTimestamp()
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
|
||||
if atEOF && len(data) == 0 {
|
||||
return 0, nil, nil
|
||||
}
|
||||
if i := strings.Index(string(data), "\n"); i >= 0 {
|
||||
return i + 1, data[0:i], nil
|
||||
}
|
||||
if atEOF {
|
||||
return len(data), data, nil
|
||||
}
|
||||
return 0, nil, nil
|
||||
})
|
||||
common.SetEventStreamHeaders(c)
|
||||
|
||||
var usage model.Usage
|
||||
var modelName string
|
||||
var id string
|
||||
var lastToolCallChoice openai.ChatCompletionsStreamResponseChoice
|
||||
|
||||
for scanner.Scan() {
|
||||
data := scanner.Text()
|
||||
if len(data) < 6 || !strings.HasPrefix(data, "data:") {
|
            continue
        }

        data = strings.TrimPrefix(data, "data:")
        data = strings.TrimSpace(data)

        var claudeResponse StreamResponse
        err := json.Unmarshal([]byte(data), &claudeResponse)
        if err != nil {
            logger.SysError("error unmarshalling stream response: " + err.Error())
            continue
        }

        response, meta := StreamResponseClaude2OpenAI(&claudeResponse)
        if meta != nil {
            usage.PromptTokens += meta.Usage.InputTokens
            usage.CompletionTokens += meta.Usage.OutputTokens
            if len(meta.Id) > 0 { // only message_start has an id, otherwise it's a finish_reason event.
                modelName = meta.Model
                id = fmt.Sprintf("chatcmpl-%s", meta.Id)
                continue
            } else { // finish_reason case
                if len(lastToolCallChoice.Delta.ToolCalls) > 0 {
                    lastArgs := &lastToolCallChoice.Delta.ToolCalls[len(lastToolCallChoice.Delta.ToolCalls)-1].Function
                    if len(lastArgs.Arguments.(string)) == 0 { // compatible with OpenAI sending an empty object `{}` when no arguments.
                        lastArgs.Arguments = "{}"
                        response.Choices[len(response.Choices)-1].Delta.Content = nil
                        response.Choices[len(response.Choices)-1].Delta.ToolCalls = lastToolCallChoice.Delta.ToolCalls
                    }
                }
            }
        }
        if response == nil {
            continue
        }

        response.Id = id
        response.Model = modelName
        response.Created = createdTime

        for _, choice := range response.Choices {
            if len(choice.Delta.ToolCalls) > 0 {
                lastToolCallChoice = choice
            }
        }
        err = render.ObjectData(c, response)
        if err != nil {
            logger.SysError(err.Error())
        }
    }

    if err := scanner.Err(); err != nil {
        logger.SysError("error reading stream: " + err.Error())
    }

    render.Done(c)

    err := resp.Body.Close()
    if err != nil {
        return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
    }
    return nil, &usage
}

func Handler(c *gin.Context, resp *http.Response, promptTokens int, modelName string) (*model.ErrorWithStatusCode, *model.Usage) {
    responseBody, err := io.ReadAll(resp.Body)
    if err != nil {
        return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
    }
    err = resp.Body.Close()
    if err != nil {
        return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
    }
    var claudeResponse Response
    err = json.Unmarshal(responseBody, &claudeResponse)
    if err != nil {
        return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
    }
    if claudeResponse.Error.Type != "" {
        return &model.ErrorWithStatusCode{
            Error: model.Error{
                Message: claudeResponse.Error.Message,
                Type:    claudeResponse.Error.Type,
                Param:   "",
                Code:    claudeResponse.Error.Type,
            },
            StatusCode: resp.StatusCode,
        }, nil
    }
    fullTextResponse := ResponseClaude2OpenAI(&claudeResponse)
    fullTextResponse.Model = modelName
    usage := model.Usage{
        PromptTokens:     claudeResponse.Usage.InputTokens,
        CompletionTokens: claudeResponse.Usage.OutputTokens,
        TotalTokens:      claudeResponse.Usage.InputTokens + claudeResponse.Usage.OutputTokens,
    }
    fullTextResponse.Usage = usage
    jsonResponse, err := json.Marshal(fullTextResponse)
    if err != nil {
        return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
    }
    c.Writer.Header().Set("Content-Type", "application/json")
    c.Writer.WriteHeader(resp.StatusCode)
    _, err = c.Writer.Write(jsonResponse)
    return nil, &usage
}
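For orientation, a hedged sketch (payload shapes per Anthropic's streaming docs; field values invented) of the two stream events the handler above distinguishes: message_start carries an id and the model name, while the finish event carries the stop reason and output token count.

// Illustrative only: assumed event payloads, values made up.
//
// message_start — has an id, so the handler records id/model and continues:
//   {"type":"message_start","message":{"id":"msg_013Zva...","model":"claude-3-haiku-20240307","usage":{"input_tokens":25,"output_tokens":1}}}
//
// message_delta — no id, so it is treated as the finish_reason event:
//   {"type":"message_delta","delta":{"stop_reason":"end_turn"},"usage":{"output_tokens":15}}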
110
relay/adaptor/anthropic/model.go
Normal file
@@ -0,0 +1,110 @@
package anthropic

// https://docs.anthropic.com/claude/reference/messages_post

type Metadata struct {
    UserId string `json:"user_id"`
}

type ImageSource struct {
    Type      string `json:"type"`
    MediaType string `json:"media_type"`
    Data      string `json:"data"`
}

type Content struct {
    Type   string       `json:"type"`
    Text   string       `json:"text,omitempty"`
    Source *ImageSource `json:"source,omitempty"`
    // tool_calls
    Id        string `json:"id,omitempty"`
    Name      string `json:"name,omitempty"`
    Input     any    `json:"input,omitempty"`
    Content   string `json:"content,omitempty"`
    ToolUseId string `json:"tool_use_id,omitempty"`
}

type Message struct {
    Role    string    `json:"role"`
    Content []Content `json:"content"`
}

type Tool struct {
    Name        string      `json:"name"`
    Description string      `json:"description,omitempty"`
    InputSchema InputSchema `json:"input_schema"`
}

type InputSchema struct {
    Type       string `json:"type"`
    Properties any    `json:"properties,omitempty"`
    Required   any    `json:"required,omitempty"`
}

type Request struct {
    Model         string    `json:"model"`
    Messages      []Message `json:"messages"`
    System        string    `json:"system,omitempty"`
    MaxTokens     int       `json:"max_tokens,omitempty"`
    StopSequences []string  `json:"stop_sequences,omitempty"`
    Stream        bool      `json:"stream,omitempty"`
    Temperature   *float64  `json:"temperature,omitempty"`
    TopP          *float64  `json:"top_p,omitempty"`
    TopK          int       `json:"top_k,omitempty"`
    Tools         []Tool    `json:"tools,omitempty"`
    ToolChoice    any       `json:"tool_choice,omitempty"`
    //Metadata `json:"metadata,omitempty"`
}

type Usage struct {
    InputTokens  int `json:"input_tokens"`
    OutputTokens int `json:"output_tokens"`
}

type Error struct {
    Type    string `json:"type"`
    Message string `json:"message"`
}

type ResponseType string

const (
    TypeError        ResponseType = "error"
    TypeStart        ResponseType = "message_start"
    TypeContentStart ResponseType = "content_block_start"
    TypeContent      ResponseType = "content_block_delta"
    TypePing         ResponseType = "ping"
    TypeContentStop  ResponseType = "content_block_stop"
    TypeMessageDelta ResponseType = "message_delta"
    TypeMessageStop  ResponseType = "message_stop"
)

// https://docs.anthropic.com/claude/reference/messages-streaming
type Response struct {
    Id           string    `json:"id"`
    Type         string    `json:"type"`
    Role         string    `json:"role"`
    Content      []Content `json:"content"`
    Model        string    `json:"model"`
    StopReason   *string   `json:"stop_reason"`
    StopSequence *string   `json:"stop_sequence"`
    Usage        Usage     `json:"usage"`
    Error        Error     `json:"error"`
}

type Delta struct {
    Type         string  `json:"type"`
    Text         string  `json:"text"`
    PartialJson  string  `json:"partial_json,omitempty"`
    StopReason   *string `json:"stop_reason"`
    StopSequence *string `json:"stop_sequence"`
}

type StreamResponse struct {
    Type         string    `json:"type"`
    Message      *Response `json:"message"`
    Index        int       `json:"index"`
    ContentBlock *Content  `json:"content_block"`
    Delta        *Delta    `json:"delta"`
    Usage        *Usage    `json:"usage"`
}
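A minimal usage sketch, assuming only the structs above (all values invented), showing how a text-plus-image user turn maps onto Request:

// Illustrative only: a multimodal user message built from the structs above.
req := Request{
    Model:     "claude-3-haiku-20240307",
    MaxTokens: 256,
    Messages: []Message{{
        Role: "user",
        Content: []Content{
            {Type: "text", Text: "Describe this image."},
            {Type: "image", Source: &ImageSource{
                Type:      "base64",
                MediaType: "image/png",
                Data:      "iVBORw0KGgo...", // truncated, invented base64
            }},
        },
    }},
}
_ = req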
84
relay/adaptor/aws/adaptor.go
Normal file
@@ -0,0 +1,84 @@
package aws

import (
    "io"
    "net/http"

    "github.com/aws/aws-sdk-go-v2/aws"
    "github.com/aws/aws-sdk-go-v2/credentials"
    "github.com/aws/aws-sdk-go-v2/service/bedrockruntime"
    "github.com/gin-gonic/gin"
    "github.com/pkg/errors"
    "github.com/songquanpeng/one-api/relay/adaptor"
    "github.com/songquanpeng/one-api/relay/adaptor/aws/utils"
    "github.com/songquanpeng/one-api/relay/meta"
    "github.com/songquanpeng/one-api/relay/model"
)

var _ adaptor.Adaptor = new(Adaptor)

type Adaptor struct {
    awsAdapter utils.AwsAdapter

    Meta      *meta.Meta
    AwsClient *bedrockruntime.Client
}

func (a *Adaptor) Init(meta *meta.Meta) {
    a.Meta = meta
    a.AwsClient = bedrockruntime.New(bedrockruntime.Options{
        Region:      meta.Config.Region,
        Credentials: aws.NewCredentialsCache(credentials.NewStaticCredentialsProvider(meta.Config.AK, meta.Config.SK, "")),
    })
}

func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) {
    if request == nil {
        return nil, errors.New("request is nil")
    }

    adaptor := GetAdaptor(request.Model)
    if adaptor == nil {
        return nil, errors.New("adaptor not found")
    }

    a.awsAdapter = adaptor
    return adaptor.ConvertRequest(c, relayMode, request)
}

func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
    if a.awsAdapter == nil {
        return nil, utils.WrapErr(errors.New("awsAdapter is nil"))
    }
    return a.awsAdapter.DoResponse(c, a.AwsClient, meta)
}

func (a *Adaptor) GetModelList() (models []string) {
    for model := range adaptors {
        models = append(models, model)
    }
    return
}

func (a *Adaptor) GetChannelName() string {
    return "aws"
}

func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) {
    return "", nil
}

func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error {
    return nil
}

func (a *Adaptor) ConvertImageRequest(_ *gin.Context, request *model.ImageRequest) (any, error) {
    if request == nil {
        return nil, errors.New("request is nil")
    }
    return request, nil
}

func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) {
    return nil, nil
}
37
relay/adaptor/aws/claude/adapter.go
Normal file
@@ -0,0 +1,37 @@
package aws

import (
    "github.com/aws/aws-sdk-go-v2/service/bedrockruntime"
    "github.com/gin-gonic/gin"
    "github.com/pkg/errors"
    "github.com/songquanpeng/one-api/common/ctxkey"
    "github.com/songquanpeng/one-api/relay/adaptor/anthropic"
    "github.com/songquanpeng/one-api/relay/adaptor/aws/utils"
    "github.com/songquanpeng/one-api/relay/meta"
    "github.com/songquanpeng/one-api/relay/model"
)

var _ utils.AwsAdapter = new(Adaptor)

type Adaptor struct {
}

func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) {
    if request == nil {
        return nil, errors.New("request is nil")
    }

    claudeReq := anthropic.ConvertRequest(*request)
    c.Set(ctxkey.RequestModel, request.Model)
    c.Set(ctxkey.ConvertedRequest, claudeReq)
    return claudeReq, nil
}

func (a *Adaptor) DoResponse(c *gin.Context, awsCli *bedrockruntime.Client, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
    if meta.IsStream {
        err, usage = StreamHandler(c, awsCli)
    } else {
        err, usage = Handler(c, awsCli, meta.ActualModelName)
    }
    return
}
207
relay/adaptor/aws/claude/main.go
Normal file
@@ -0,0 +1,207 @@
// Package aws provides the AWS adaptor for the relay service.
package aws

import (
    "bytes"
    "encoding/json"
    "fmt"
    "io"
    "net/http"

    "github.com/aws/aws-sdk-go-v2/aws"
    "github.com/aws/aws-sdk-go-v2/service/bedrockruntime"
    "github.com/aws/aws-sdk-go-v2/service/bedrockruntime/types"
    "github.com/gin-gonic/gin"
    "github.com/jinzhu/copier"
    "github.com/pkg/errors"
    "github.com/songquanpeng/one-api/common"
    "github.com/songquanpeng/one-api/common/ctxkey"
    "github.com/songquanpeng/one-api/common/helper"
    "github.com/songquanpeng/one-api/common/logger"
    "github.com/songquanpeng/one-api/relay/adaptor/anthropic"
    "github.com/songquanpeng/one-api/relay/adaptor/aws/utils"
    "github.com/songquanpeng/one-api/relay/adaptor/openai"
    relaymodel "github.com/songquanpeng/one-api/relay/model"
)

// https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids.html
var AwsModelIDMap = map[string]string{
    "claude-instant-1.2":         "anthropic.claude-instant-v1",
    "claude-2.0":                 "anthropic.claude-v2",
    "claude-2.1":                 "anthropic.claude-v2:1",
    "claude-3-haiku-20240307":    "anthropic.claude-3-haiku-20240307-v1:0",
    "claude-3-sonnet-20240229":   "anthropic.claude-3-sonnet-20240229-v1:0",
    "claude-3-opus-20240229":     "anthropic.claude-3-opus-20240229-v1:0",
    "claude-3-5-sonnet-20240620": "anthropic.claude-3-5-sonnet-20240620-v1:0",
    "claude-3-5-sonnet-20241022": "anthropic.claude-3-5-sonnet-20241022-v2:0",
    "claude-3-5-sonnet-latest":   "anthropic.claude-3-5-sonnet-20241022-v2:0",
    "claude-3-5-haiku-20241022":  "anthropic.claude-3-5-haiku-20241022-v1:0",
}

func awsModelID(requestModel string) (string, error) {
    if awsModelID, ok := AwsModelIDMap[requestModel]; ok {
        return awsModelID, nil
    }

    return "", errors.Errorf("model %s not found", requestModel)
}

func Handler(c *gin.Context, awsCli *bedrockruntime.Client, modelName string) (*relaymodel.ErrorWithStatusCode, *relaymodel.Usage) {
    awsModelId, err := awsModelID(c.GetString(ctxkey.RequestModel))
    if err != nil {
        return utils.WrapErr(errors.Wrap(err, "awsModelID")), nil
    }

    awsReq := &bedrockruntime.InvokeModelInput{
        ModelId:     aws.String(awsModelId),
        Accept:      aws.String("application/json"),
        ContentType: aws.String("application/json"),
    }

    claudeReq_, ok := c.Get(ctxkey.ConvertedRequest)
    if !ok {
        return utils.WrapErr(errors.New("request not found")), nil
    }
    claudeReq := claudeReq_.(*anthropic.Request)
    awsClaudeReq := &Request{
        AnthropicVersion: "bedrock-2023-05-31",
    }
    if err = copier.Copy(awsClaudeReq, claudeReq); err != nil {
        return utils.WrapErr(errors.Wrap(err, "copy request")), nil
    }

    awsReq.Body, err = json.Marshal(awsClaudeReq)
    if err != nil {
        return utils.WrapErr(errors.Wrap(err, "marshal request")), nil
    }

    awsResp, err := awsCli.InvokeModel(c.Request.Context(), awsReq)
    if err != nil {
        return utils.WrapErr(errors.Wrap(err, "InvokeModel")), nil
    }

    claudeResponse := new(anthropic.Response)
    err = json.Unmarshal(awsResp.Body, claudeResponse)
    if err != nil {
        return utils.WrapErr(errors.Wrap(err, "unmarshal response")), nil
    }

    openaiResp := anthropic.ResponseClaude2OpenAI(claudeResponse)
    openaiResp.Model = modelName
    usage := relaymodel.Usage{
        PromptTokens:     claudeResponse.Usage.InputTokens,
        CompletionTokens: claudeResponse.Usage.OutputTokens,
        TotalTokens:      claudeResponse.Usage.InputTokens + claudeResponse.Usage.OutputTokens,
    }
    openaiResp.Usage = usage

    c.JSON(http.StatusOK, openaiResp)
    return nil, &usage
}

func StreamHandler(c *gin.Context, awsCli *bedrockruntime.Client) (*relaymodel.ErrorWithStatusCode, *relaymodel.Usage) {
    createdTime := helper.GetTimestamp()
    awsModelId, err := awsModelID(c.GetString(ctxkey.RequestModel))
    if err != nil {
        return utils.WrapErr(errors.Wrap(err, "awsModelID")), nil
    }

    awsReq := &bedrockruntime.InvokeModelWithResponseStreamInput{
        ModelId:     aws.String(awsModelId),
        Accept:      aws.String("application/json"),
        ContentType: aws.String("application/json"),
    }

    claudeReq_, ok := c.Get(ctxkey.ConvertedRequest)
    if !ok {
        return utils.WrapErr(errors.New("request not found")), nil
    }
    claudeReq := claudeReq_.(*anthropic.Request)

    awsClaudeReq := &Request{
        AnthropicVersion: "bedrock-2023-05-31",
    }
    if err = copier.Copy(awsClaudeReq, claudeReq); err != nil {
        return utils.WrapErr(errors.Wrap(err, "copy request")), nil
    }
    awsReq.Body, err = json.Marshal(awsClaudeReq)
    if err != nil {
        return utils.WrapErr(errors.Wrap(err, "marshal request")), nil
    }

    awsResp, err := awsCli.InvokeModelWithResponseStream(c.Request.Context(), awsReq)
    if err != nil {
        return utils.WrapErr(errors.Wrap(err, "InvokeModelWithResponseStream")), nil
    }
    stream := awsResp.GetStream()
    defer stream.Close()

    c.Writer.Header().Set("Content-Type", "text/event-stream")
    var usage relaymodel.Usage
    var id string
    var lastToolCallChoice openai.ChatCompletionsStreamResponseChoice

    c.Stream(func(w io.Writer) bool {
        event, ok := <-stream.Events()
        if !ok {
            c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
            return false
        }

        switch v := event.(type) {
        case *types.ResponseStreamMemberChunk:
            claudeResp := new(anthropic.StreamResponse)
            err := json.NewDecoder(bytes.NewReader(v.Value.Bytes)).Decode(claudeResp)
            if err != nil {
                logger.SysError("error unmarshalling stream response: " + err.Error())
                return false
            }

            response, meta := anthropic.StreamResponseClaude2OpenAI(claudeResp)
            if meta != nil {
                usage.PromptTokens += meta.Usage.InputTokens
                usage.CompletionTokens += meta.Usage.OutputTokens
                if len(meta.Id) > 0 { // only message_start has an id, otherwise it's a finish_reason event.
                    id = fmt.Sprintf("chatcmpl-%s", meta.Id)
                    return true
                } else { // finish_reason case
                    if len(lastToolCallChoice.Delta.ToolCalls) > 0 {
                        lastArgs := &lastToolCallChoice.Delta.ToolCalls[len(lastToolCallChoice.Delta.ToolCalls)-1].Function
                        if len(lastArgs.Arguments.(string)) == 0 { // compatible with OpenAI sending an empty object `{}` when no arguments.
                            lastArgs.Arguments = "{}"
                            response.Choices[len(response.Choices)-1].Delta.Content = nil
                            response.Choices[len(response.Choices)-1].Delta.ToolCalls = lastToolCallChoice.Delta.ToolCalls
                        }
                    }
                }
            }
            if response == nil {
                return true
            }
            response.Id = id
            response.Model = c.GetString(ctxkey.OriginalModel)
            response.Created = createdTime

            for _, choice := range response.Choices {
                if len(choice.Delta.ToolCalls) > 0 {
                    lastToolCallChoice = choice
                }
            }
            jsonStr, err := json.Marshal(response)
            if err != nil {
                logger.SysError("error marshalling stream response: " + err.Error())
                return true
            }
            c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonStr)})
            return true
        case *types.UnknownUnionMember:
            fmt.Println("unknown tag:", v.Tag)
            return false
        default:
            fmt.Println("union is nil or unknown type")
            return false
        }
    })

    return nil, &usage
}
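A hedged sketch of the name-mapping step above: the client-facing model name is resolved through AwsModelIDMap before Bedrock's InvokeModel is called (the mapped name below comes from the map; the rest is illustrative).

// Illustrative only: resolving a public model name to a Bedrock model ID.
if id, err := awsModelID("claude-3-haiku-20240307"); err == nil {
    fmt.Println(id) // anthropic.claude-3-haiku-20240307-v1:0
}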
20
relay/adaptor/aws/claude/model.go
Normal file
@@ -0,0 +1,20 @@
package aws

import "github.com/songquanpeng/one-api/relay/adaptor/anthropic"

// Request is the request to AWS Claude
//
// https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-anthropic-claude-messages.html
type Request struct {
    // AnthropicVersion should be "bedrock-2023-05-31"
    AnthropicVersion string              `json:"anthropic_version"`
    Messages         []anthropic.Message `json:"messages"`
    System           string              `json:"system,omitempty"`
    MaxTokens        int                 `json:"max_tokens,omitempty"`
    Temperature      *float64            `json:"temperature,omitempty"`
    TopP             *float64            `json:"top_p,omitempty"`
    TopK             int                 `json:"top_k,omitempty"`
    StopSequences    []string            `json:"stop_sequences,omitempty"`
    Tools            []anthropic.Tool    `json:"tools,omitempty"`
    ToolChoice       any                 `json:"tool_choice,omitempty"`
}
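For reference, a hedged sketch (values invented) of the body this struct marshals to. Bedrock takes the model ID in the request URL, which is why the struct carries no model field and instead pins anthropic_version:

// Illustrative only: json.Marshal of a minimal Request.
// {
//   "anthropic_version": "bedrock-2023-05-31",
//   "messages": [{"role": "user", "content": [{"type": "text", "text": "Hi"}]}],
//   "max_tokens": 256
// }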
37
relay/adaptor/aws/llama3/adapter.go
Normal file
@@ -0,0 +1,37 @@
package aws

import (
    "github.com/aws/aws-sdk-go-v2/service/bedrockruntime"
    "github.com/songquanpeng/one-api/common/ctxkey"

    "github.com/gin-gonic/gin"
    "github.com/pkg/errors"
    "github.com/songquanpeng/one-api/relay/adaptor/aws/utils"
    "github.com/songquanpeng/one-api/relay/meta"
    "github.com/songquanpeng/one-api/relay/model"
)

var _ utils.AwsAdapter = new(Adaptor)

type Adaptor struct {
}

func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) {
    if request == nil {
        return nil, errors.New("request is nil")
    }

    llamaReq := ConvertRequest(*request)
    c.Set(ctxkey.RequestModel, request.Model)
    c.Set(ctxkey.ConvertedRequest, llamaReq)
    return llamaReq, nil
}

func (a *Adaptor) DoResponse(c *gin.Context, awsCli *bedrockruntime.Client, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
    if meta.IsStream {
        err, usage = StreamHandler(c, awsCli)
    } else {
        err, usage = Handler(c, awsCli, meta.ActualModelName)
    }
    return
}
231
relay/adaptor/aws/llama3/main.go
Normal file
@@ -0,0 +1,231 @@
// Package aws provides the AWS adaptor for the relay service.
package aws

import (
    "bytes"
    "encoding/json"
    "fmt"
    "io"
    "net/http"
    "text/template"

    "github.com/songquanpeng/one-api/common/ctxkey"
    "github.com/songquanpeng/one-api/common/random"

    "github.com/aws/aws-sdk-go-v2/aws"
    "github.com/aws/aws-sdk-go-v2/service/bedrockruntime"
    "github.com/aws/aws-sdk-go-v2/service/bedrockruntime/types"
    "github.com/gin-gonic/gin"
    "github.com/pkg/errors"
    "github.com/songquanpeng/one-api/common"
    "github.com/songquanpeng/one-api/common/helper"
    "github.com/songquanpeng/one-api/common/logger"
    "github.com/songquanpeng/one-api/relay/adaptor/aws/utils"
    "github.com/songquanpeng/one-api/relay/adaptor/openai"
    relaymodel "github.com/songquanpeng/one-api/relay/model"
)

// Only the llama-3-8b and llama-3-70b instruct models are supported.
// https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids.html
var AwsModelIDMap = map[string]string{
    "llama3-8b-8192":  "meta.llama3-8b-instruct-v1:0",
    "llama3-70b-8192": "meta.llama3-70b-instruct-v1:0",
}

func awsModelID(requestModel string) (string, error) {
    if awsModelID, ok := AwsModelIDMap[requestModel]; ok {
        return awsModelID, nil
    }

    return "", errors.Errorf("model %s not found", requestModel)
}

// promptTemplate renders the message history into the Llama 3 chat format,
// ranging over all messages and ending with an open assistant turn.
const promptTemplate = `<|begin_of_text|>{{range .Messages}}<|start_header_id|>{{.Role}}<|end_header_id|>{{.StringContent}}<|eot_id|>{{end}}<|start_header_id|>assistant<|end_header_id|>
`

var promptTpl = template.Must(template.New("llama3-chat").Parse(promptTemplate))

func RenderPrompt(messages []relaymodel.Message) string {
    var buf bytes.Buffer
    err := promptTpl.Execute(&buf, struct{ Messages []relaymodel.Message }{messages})
    if err != nil {
        logger.SysError("error rendering prompt messages: " + err.Error())
    }
    return buf.String()
}

func ConvertRequest(textRequest relaymodel.GeneralOpenAIRequest) *Request {
    llamaRequest := Request{
        MaxGenLen:   textRequest.MaxTokens,
        Temperature: textRequest.Temperature,
        TopP:        textRequest.TopP,
    }
    if llamaRequest.MaxGenLen == 0 {
        llamaRequest.MaxGenLen = 2048
    }
    prompt := RenderPrompt(textRequest.Messages)
    llamaRequest.Prompt = prompt
    return &llamaRequest
}

func Handler(c *gin.Context, awsCli *bedrockruntime.Client, modelName string) (*relaymodel.ErrorWithStatusCode, *relaymodel.Usage) {
    awsModelId, err := awsModelID(c.GetString(ctxkey.RequestModel))
    if err != nil {
        return utils.WrapErr(errors.Wrap(err, "awsModelID")), nil
    }

    awsReq := &bedrockruntime.InvokeModelInput{
        ModelId:     aws.String(awsModelId),
        Accept:      aws.String("application/json"),
        ContentType: aws.String("application/json"),
    }

    llamaReq, ok := c.Get(ctxkey.ConvertedRequest)
    if !ok {
        return utils.WrapErr(errors.New("request not found")), nil
    }

    awsReq.Body, err = json.Marshal(llamaReq)
    if err != nil {
        return utils.WrapErr(errors.Wrap(err, "marshal request")), nil
    }

    awsResp, err := awsCli.InvokeModel(c.Request.Context(), awsReq)
    if err != nil {
        return utils.WrapErr(errors.Wrap(err, "InvokeModel")), nil
    }

    var llamaResponse Response
    err = json.Unmarshal(awsResp.Body, &llamaResponse)
    if err != nil {
        return utils.WrapErr(errors.Wrap(err, "unmarshal response")), nil
    }

    openaiResp := ResponseLlama2OpenAI(&llamaResponse)
    openaiResp.Model = modelName
    usage := relaymodel.Usage{
        PromptTokens:     llamaResponse.PromptTokenCount,
        CompletionTokens: llamaResponse.GenerationTokenCount,
        TotalTokens:      llamaResponse.PromptTokenCount + llamaResponse.GenerationTokenCount,
    }
    openaiResp.Usage = usage

    c.JSON(http.StatusOK, openaiResp)
    return nil, &usage
}

func ResponseLlama2OpenAI(llamaResponse *Response) *openai.TextResponse {
    var responseText string
    if len(llamaResponse.Generation) > 0 {
        responseText = llamaResponse.Generation
    }
    choice := openai.TextResponseChoice{
        Index: 0,
        Message: relaymodel.Message{
            Role:    "assistant",
            Content: responseText,
            Name:    nil,
        },
        FinishReason: llamaResponse.StopReason,
    }
    fullTextResponse := openai.TextResponse{
        Id:      fmt.Sprintf("chatcmpl-%s", random.GetUUID()),
        Object:  "chat.completion",
        Created: helper.GetTimestamp(),
        Choices: []openai.TextResponseChoice{choice},
    }
    return &fullTextResponse
}

func StreamHandler(c *gin.Context, awsCli *bedrockruntime.Client) (*relaymodel.ErrorWithStatusCode, *relaymodel.Usage) {
    createdTime := helper.GetTimestamp()
    awsModelId, err := awsModelID(c.GetString(ctxkey.RequestModel))
    if err != nil {
        return utils.WrapErr(errors.Wrap(err, "awsModelID")), nil
    }

    awsReq := &bedrockruntime.InvokeModelWithResponseStreamInput{
        ModelId:     aws.String(awsModelId),
        Accept:      aws.String("application/json"),
        ContentType: aws.String("application/json"),
    }

    llamaReq, ok := c.Get(ctxkey.ConvertedRequest)
    if !ok {
        return utils.WrapErr(errors.New("request not found")), nil
    }

    awsReq.Body, err = json.Marshal(llamaReq)
    if err != nil {
        return utils.WrapErr(errors.Wrap(err, "marshal request")), nil
    }

    awsResp, err := awsCli.InvokeModelWithResponseStream(c.Request.Context(), awsReq)
    if err != nil {
        return utils.WrapErr(errors.Wrap(err, "InvokeModelWithResponseStream")), nil
    }
    stream := awsResp.GetStream()
    defer stream.Close()

    c.Writer.Header().Set("Content-Type", "text/event-stream")
    var usage relaymodel.Usage
    c.Stream(func(w io.Writer) bool {
        event, ok := <-stream.Events()
        if !ok {
            c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
            return false
        }

        switch v := event.(type) {
        case *types.ResponseStreamMemberChunk:
            var llamaResp StreamResponse
            err := json.NewDecoder(bytes.NewReader(v.Value.Bytes)).Decode(&llamaResp)
            if err != nil {
                logger.SysError("error unmarshalling stream response: " + err.Error())
                return false
            }

            if llamaResp.PromptTokenCount > 0 {
                usage.PromptTokens = llamaResp.PromptTokenCount
            }
            if llamaResp.StopReason == "stop" {
                usage.CompletionTokens = llamaResp.GenerationTokenCount
                usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
            }
            response := StreamResponseLlama2OpenAI(&llamaResp)
            response.Id = fmt.Sprintf("chatcmpl-%s", random.GetUUID())
            response.Model = c.GetString(ctxkey.OriginalModel)
            response.Created = createdTime
            jsonStr, err := json.Marshal(response)
            if err != nil {
                logger.SysError("error marshalling stream response: " + err.Error())
                return true
            }
            c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonStr)})
            return true
        case *types.UnknownUnionMember:
            fmt.Println("unknown tag:", v.Tag)
            return false
        default:
            fmt.Println("union is nil or unknown type")
            return false
        }
    })

    return nil, &usage
}

func StreamResponseLlama2OpenAI(llamaResponse *StreamResponse) *openai.ChatCompletionsStreamResponse {
    var choice openai.ChatCompletionsStreamResponseChoice
    choice.Delta.Content = llamaResponse.Generation
    choice.Delta.Role = "assistant"
    finishReason := llamaResponse.StopReason
    if finishReason != "null" {
        choice.FinishReason = &finishReason
    }
    var openaiResponse openai.ChatCompletionsStreamResponse
    openaiResponse.Object = "chat.completion.chunk"
    openaiResponse.Choices = []openai.ChatCompletionsStreamResponseChoice{choice}
    return &openaiResponse
}
45
relay/adaptor/aws/llama3/main_test.go
Normal file
@@ -0,0 +1,45 @@
package aws_test

import (
    "testing"

    aws "github.com/songquanpeng/one-api/relay/adaptor/aws/llama3"
    relaymodel "github.com/songquanpeng/one-api/relay/model"
    "github.com/stretchr/testify/assert"
)

func TestRenderPrompt(t *testing.T) {
    messages := []relaymodel.Message{
        {
            Role:    "user",
            Content: "What's your name?",
        },
    }
    prompt := aws.RenderPrompt(messages)
    expected := `<|begin_of_text|><|start_header_id|>user<|end_header_id|>What's your name?<|eot_id|><|start_header_id|>assistant<|end_header_id|>
`
    assert.Equal(t, expected, prompt)

    messages = []relaymodel.Message{
        {
            Role:    "system",
            Content: "Your name is Kat. You are a detective.",
        },
        {
            Role:    "user",
            Content: "What's your name?",
        },
        {
            Role:    "assistant",
            Content: "Kat",
        },
        {
            Role:    "user",
            Content: "What's your job?",
        },
    }
    prompt = aws.RenderPrompt(messages)
    expected = `<|begin_of_text|><|start_header_id|>system<|end_header_id|>Your name is Kat. You are a detective.<|eot_id|><|start_header_id|>user<|end_header_id|>What's your name?<|eot_id|><|start_header_id|>assistant<|end_header_id|>Kat<|eot_id|><|start_header_id|>user<|end_header_id|>What's your job?<|eot_id|><|start_header_id|>assistant<|end_header_id|>
`
    assert.Equal(t, expected, prompt)
}
29
relay/adaptor/aws/llama3/model.go
Normal file
@@ -0,0 +1,29 @@
package aws

// Request is the request to AWS Llama3
//
// https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-meta.html
type Request struct {
    Prompt      string   `json:"prompt"`
    MaxGenLen   int      `json:"max_gen_len,omitempty"`
    Temperature *float64 `json:"temperature,omitempty"`
    TopP        *float64 `json:"top_p,omitempty"`
}

// Response is the response from AWS Llama3
//
// https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-meta.html
type Response struct {
    Generation           string `json:"generation"`
    PromptTokenCount     int    `json:"prompt_token_count"`
    GenerationTokenCount int    `json:"generation_token_count"`
    StopReason           string `json:"stop_reason"`
}

// {'generation': 'Hi', 'prompt_token_count': 15, 'generation_token_count': 1, 'stop_reason': None}
type StreamResponse struct {
    Generation           string `json:"generation"`
    PromptTokenCount     int    `json:"prompt_token_count"`
    GenerationTokenCount int    `json:"generation_token_count"`
    StopReason           string `json:"stop_reason"`
}
39
relay/adaptor/aws/registry.go
Normal file
@@ -0,0 +1,39 @@
package aws

import (
    claude "github.com/songquanpeng/one-api/relay/adaptor/aws/claude"
    llama3 "github.com/songquanpeng/one-api/relay/adaptor/aws/llama3"
    "github.com/songquanpeng/one-api/relay/adaptor/aws/utils"
)

type AwsModelType int

const (
    AwsClaude AwsModelType = iota + 1
    AwsLlama3
)

var (
    adaptors = map[string]AwsModelType{}
)

func init() {
    for model := range claude.AwsModelIDMap {
        adaptors[model] = AwsClaude
    }
    for model := range llama3.AwsModelIDMap {
        adaptors[model] = AwsLlama3
    }
}

func GetAdaptor(model string) utils.AwsAdapter {
    adaptorType := adaptors[model]
    switch adaptorType {
    case AwsClaude:
        return &claude.Adaptor{}
    case AwsLlama3:
        return &llama3.Adaptor{}
    default:
        return nil
    }
}
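A brief usage sketch, assuming only the registry above: model names are routed to the matching sub-adaptor, and unmapped names fall through to nil, which aws/adaptor.go surfaces as "adaptor not found".

// Illustrative only: routing by model name.
if sub := GetAdaptor("llama3-8b-8192"); sub != nil {
    // sub is a *llama3.Adaptor; Claude model names yield a *claude.Adaptor.
}
if GetAdaptor("unmapped-model") == nil {
    // nil signals the model is not served by any AWS sub-adaptor.
}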
51
relay/adaptor/aws/utils/adaptor.go
Normal file
@@ -0,0 +1,51 @@
package utils

import (
    "errors"
    "io"
    "net/http"

    "github.com/aws/aws-sdk-go-v2/aws"
    "github.com/aws/aws-sdk-go-v2/credentials"
    "github.com/aws/aws-sdk-go-v2/service/bedrockruntime"
    "github.com/gin-gonic/gin"
    "github.com/songquanpeng/one-api/relay/meta"
    "github.com/songquanpeng/one-api/relay/model"
)

type AwsAdapter interface {
    ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error)
    DoResponse(c *gin.Context, awsCli *bedrockruntime.Client, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode)
}

type Adaptor struct {
    Meta      *meta.Meta
    AwsClient *bedrockruntime.Client
}

func (a *Adaptor) Init(meta *meta.Meta) {
    a.Meta = meta
    a.AwsClient = bedrockruntime.New(bedrockruntime.Options{
        Region:      meta.Config.Region,
        Credentials: aws.NewCredentialsCache(credentials.NewStaticCredentialsProvider(meta.Config.AK, meta.Config.SK, "")),
    })
}

func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) {
    return "", nil
}

func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error {
    return nil
}

func (a *Adaptor) ConvertImageRequest(_ *gin.Context, request *model.ImageRequest) (any, error) {
    if request == nil {
        return nil, errors.New("request is nil")
    }
    return request, nil
}

func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) {
    return nil, nil
}
16
relay/adaptor/aws/utils/utils.go
Normal file
@@ -0,0 +1,16 @@
package utils

import (
    "net/http"

    relaymodel "github.com/songquanpeng/one-api/relay/model"
)

func WrapErr(err error) *relaymodel.ErrorWithStatusCode {
    return &relaymodel.ErrorWithStatusCode{
        StatusCode: http.StatusInternalServerError,
        Error: relaymodel.Error{
            Message: err.Error(),
        },
    }
}
7
relay/adaptor/baichuan/constants.go
Normal file
@@ -0,0 +1,7 @@
package baichuan

var ModelList = []string{
    "Baichuan2-Turbo",
    "Baichuan2-Turbo-192k",
    "Baichuan-Text-Embedding",
}
143
relay/adaptor/baidu/adaptor.go
Normal file
@@ -0,0 +1,143 @@
package baidu

import (
    "errors"
    "fmt"
    "io"
    "net/http"
    "strings"

    "github.com/gin-gonic/gin"
    "github.com/songquanpeng/one-api/relay/adaptor"
    "github.com/songquanpeng/one-api/relay/meta"
    "github.com/songquanpeng/one-api/relay/model"
    "github.com/songquanpeng/one-api/relay/relaymode"
)

type Adaptor struct {
}

func (a *Adaptor) Init(meta *meta.Meta) {

}

func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) {
    // https://cloud.baidu.com/doc/WENXINWORKSHOP/s/clntwmv7t
    suffix := "chat/"
    if strings.HasPrefix(meta.ActualModelName, "Embedding") {
        suffix = "embeddings/"
    }
    if strings.HasPrefix(meta.ActualModelName, "bge-large") {
        suffix = "embeddings/"
    }
    if strings.HasPrefix(meta.ActualModelName, "tao-8k") {
        suffix = "embeddings/"
    }
    switch meta.ActualModelName {
    case "ERNIE-4.0":
        suffix += "completions_pro"
    case "ERNIE-Bot-4":
        suffix += "completions_pro"
    case "ERNIE-Bot":
        suffix += "completions"
    case "ERNIE-Bot-turbo":
        suffix += "eb-instant"
    case "ERNIE-Speed":
        suffix += "ernie_speed"
    case "ERNIE-4.0-8K":
        suffix += "completions_pro"
    case "ERNIE-3.5-8K":
        suffix += "completions"
    case "ERNIE-3.5-8K-0205":
        suffix += "ernie-3.5-8k-0205"
    case "ERNIE-3.5-8K-1222":
        suffix += "ernie-3.5-8k-1222"
    case "ERNIE-Bot-8K":
        suffix += "ernie_bot_8k"
    case "ERNIE-3.5-4K-0205":
        suffix += "ernie-3.5-4k-0205"
    case "ERNIE-Speed-8K":
        suffix += "ernie_speed"
    case "ERNIE-Speed-128K":
        suffix += "ernie-speed-128k"
    case "ERNIE-Lite-8K-0922":
        suffix += "eb-instant"
    case "ERNIE-Lite-8K-0308":
        suffix += "ernie-lite-8k"
    case "ERNIE-Tiny-8K":
        suffix += "ernie-tiny-8k"
    case "BLOOMZ-7B":
        suffix += "bloomz_7b1"
    case "Embedding-V1":
        suffix += "embedding-v1"
    case "bge-large-zh":
        suffix += "bge_large_zh"
    case "bge-large-en":
        suffix += "bge_large_en"
    case "tao-8k":
        suffix += "tao_8k"
    default:
        suffix += strings.ToLower(meta.ActualModelName)
    }
    fullRequestURL := fmt.Sprintf("%s/rpc/2.0/ai_custom/v1/wenxinworkshop/%s", meta.BaseURL, suffix)
    var accessToken string
    var err error
    if accessToken, err = GetAccessToken(meta.APIKey); err != nil {
        return "", err
    }
    fullRequestURL += "?access_token=" + accessToken
    return fullRequestURL, nil
}

func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error {
    adaptor.SetupCommonRequestHeader(c, req, meta)
    req.Header.Set("Authorization", "Bearer "+meta.APIKey)
    return nil
}

func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) {
    if request == nil {
        return nil, errors.New("request is nil")
    }
    switch relayMode {
    case relaymode.Embeddings:
        baiduEmbeddingRequest := ConvertEmbeddingRequest(*request)
        return baiduEmbeddingRequest, nil
    default:
        baiduRequest := ConvertRequest(*request)
        return baiduRequest, nil
    }
}

func (a *Adaptor) ConvertImageRequest(_ *gin.Context, request *model.ImageRequest) (any, error) {
    if request == nil {
        return nil, errors.New("request is nil")
    }
    return request, nil
}

func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) {
    return adaptor.DoRequestHelper(a, c, meta, requestBody)
}

func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
    if meta.IsStream {
        err, usage = StreamHandler(c, resp)
    } else {
        switch meta.Mode {
        case relaymode.Embeddings:
            err, usage = EmbeddingHandler(c, resp)
        default:
            err, usage = Handler(c, resp)
        }
    }
    return
}

func (a *Adaptor) GetModelList() []string {
    return ModelList
}

func (a *Adaptor) GetChannelName() string {
    return "baidu"
}
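For orientation, a hedged example of the URLs GetRequestURL builds (assuming the default base URL https://aip.baidubce.com; the token is invented):

// Illustrative only:
//   ERNIE-4.0-8K -> https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro?access_token=24.abc...
//   Embedding-V1 -> https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/embeddings/embedding-v1?access_token=24.abc...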
20
relay/adaptor/baidu/constants.go
Normal file
@@ -0,0 +1,20 @@
package baidu

var ModelList = []string{
    "ERNIE-4.0-8K",
    "ERNIE-3.5-8K",
    "ERNIE-3.5-8K-0205",
    "ERNIE-3.5-8K-1222",
    "ERNIE-Bot-8K",
    "ERNIE-3.5-4K-0205",
    "ERNIE-Speed-8K",
    "ERNIE-Speed-128K",
    "ERNIE-Lite-8K-0922",
    "ERNIE-Lite-8K-0308",
    "ERNIE-Tiny-8K",
    "BLOOMZ-7B",
    "Embedding-V1",
    "bge-large-zh",
    "bge-large-en",
    "tao-8k",
}
312
relay/adaptor/baidu/main.go
Normal file
@@ -0,0 +1,312 @@
package baidu

import (
    "bufio"
    "encoding/json"
    "errors"
    "fmt"
    "io"
    "net/http"
    "strings"
    "sync"
    "time"

    "github.com/gin-gonic/gin"
    "github.com/songquanpeng/one-api/common"
    "github.com/songquanpeng/one-api/common/client"
    "github.com/songquanpeng/one-api/common/logger"
    "github.com/songquanpeng/one-api/common/render"
    "github.com/songquanpeng/one-api/relay/adaptor/openai"
    "github.com/songquanpeng/one-api/relay/constant"
    "github.com/songquanpeng/one-api/relay/model"
)

// https://cloud.baidu.com/doc/WENXINWORKSHOP/s/flfmc9do2

type TokenResponse struct {
    ExpiresIn   int    `json:"expires_in"`
    AccessToken string `json:"access_token"`
}

type Message struct {
    Role    string `json:"role"`
    Content string `json:"content"`
}

type ChatRequest struct {
    Messages        []Message `json:"messages"`
    Temperature     *float64  `json:"temperature,omitempty"`
    TopP            *float64  `json:"top_p,omitempty"`
    PenaltyScore    *float64  `json:"penalty_score,omitempty"`
    Stream          bool      `json:"stream,omitempty"`
    System          string    `json:"system,omitempty"`
    DisableSearch   bool      `json:"disable_search,omitempty"`
    EnableCitation  bool      `json:"enable_citation,omitempty"`
    MaxOutputTokens int       `json:"max_output_tokens,omitempty"`
    UserId          string    `json:"user_id,omitempty"`
}

type Error struct {
    ErrorCode int    `json:"error_code"`
    ErrorMsg  string `json:"error_msg"`
}

var baiduTokenStore sync.Map

func ConvertRequest(request model.GeneralOpenAIRequest) *ChatRequest {
    baiduRequest := ChatRequest{
        Messages:        make([]Message, 0, len(request.Messages)),
        Temperature:     request.Temperature,
        TopP:            request.TopP,
        PenaltyScore:    request.FrequencyPenalty,
        Stream:          request.Stream,
        DisableSearch:   false,
        EnableCitation:  false,
        MaxOutputTokens: request.MaxTokens,
        UserId:          request.User,
    }
    for _, message := range request.Messages {
        if message.Role == "system" {
            baiduRequest.System = message.StringContent()
        } else {
            baiduRequest.Messages = append(baiduRequest.Messages, Message{
                Role:    message.Role,
                Content: message.StringContent(),
            })
        }
    }
    return &baiduRequest
}

func responseBaidu2OpenAI(response *ChatResponse) *openai.TextResponse {
    choice := openai.TextResponseChoice{
        Index: 0,
        Message: model.Message{
            Role:    "assistant",
            Content: response.Result,
        },
        FinishReason: "stop",
    }
    fullTextResponse := openai.TextResponse{
        Id:      response.Id,
        Object:  "chat.completion",
        Created: response.Created,
        Choices: []openai.TextResponseChoice{choice},
        Usage:   response.Usage,
    }
    return &fullTextResponse
}

func streamResponseBaidu2OpenAI(baiduResponse *ChatStreamResponse) *openai.ChatCompletionsStreamResponse {
    var choice openai.ChatCompletionsStreamResponseChoice
    choice.Delta.Content = baiduResponse.Result
    if baiduResponse.IsEnd {
        choice.FinishReason = &constant.StopFinishReason
    }
    response := openai.ChatCompletionsStreamResponse{
        Id:      baiduResponse.Id,
        Object:  "chat.completion.chunk",
        Created: baiduResponse.Created,
        Model:   "ernie-bot",
        Choices: []openai.ChatCompletionsStreamResponseChoice{choice},
    }
    return &response
}

func ConvertEmbeddingRequest(request model.GeneralOpenAIRequest) *EmbeddingRequest {
    return &EmbeddingRequest{
        Input: request.ParseInput(),
    }
}

func embeddingResponseBaidu2OpenAI(response *EmbeddingResponse) *openai.EmbeddingResponse {
    openAIEmbeddingResponse := openai.EmbeddingResponse{
        Object: "list",
        Data:   make([]openai.EmbeddingResponseItem, 0, len(response.Data)),
        Model:  "baidu-embedding",
        Usage:  response.Usage,
    }
    for _, item := range response.Data {
        openAIEmbeddingResponse.Data = append(openAIEmbeddingResponse.Data, openai.EmbeddingResponseItem{
            Object:    item.Object,
            Index:     item.Index,
            Embedding: item.Embedding,
        })
    }
    return &openAIEmbeddingResponse
}

func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
    var usage model.Usage
    scanner := bufio.NewScanner(resp.Body)
    scanner.Split(bufio.ScanLines)

    common.SetEventStreamHeaders(c)

    for scanner.Scan() {
        data := scanner.Text()
        if len(data) < 6 {
            continue
        }
        data = data[6:]

        var baiduResponse ChatStreamResponse
        err := json.Unmarshal([]byte(data), &baiduResponse)
        if err != nil {
            logger.SysError("error unmarshalling stream response: " + err.Error())
            continue
        }
        if baiduResponse.Usage.TotalTokens != 0 {
            usage.TotalTokens = baiduResponse.Usage.TotalTokens
            usage.PromptTokens = baiduResponse.Usage.PromptTokens
            usage.CompletionTokens = baiduResponse.Usage.TotalTokens - baiduResponse.Usage.PromptTokens
        }
        response := streamResponseBaidu2OpenAI(&baiduResponse)
        err = render.ObjectData(c, response)
        if err != nil {
            logger.SysError(err.Error())
        }
    }

    if err := scanner.Err(); err != nil {
        logger.SysError("error reading stream: " + err.Error())
    }

    render.Done(c)

    err := resp.Body.Close()
    if err != nil {
        return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
    }
    return nil, &usage
}

func Handler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
    var baiduResponse ChatResponse
    responseBody, err := io.ReadAll(resp.Body)
    if err != nil {
        return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
    }
    err = resp.Body.Close()
    if err != nil {
        return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
    }
    err = json.Unmarshal(responseBody, &baiduResponse)
    if err != nil {
        return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
    }
    if baiduResponse.ErrorMsg != "" {
        return &model.ErrorWithStatusCode{
            Error: model.Error{
                Message: baiduResponse.ErrorMsg,
                Type:    "baidu_error",
                Param:   "",
                Code:    baiduResponse.ErrorCode,
            },
            StatusCode: resp.StatusCode,
        }, nil
    }
    fullTextResponse := responseBaidu2OpenAI(&baiduResponse)
    fullTextResponse.Model = "ernie-bot"
    jsonResponse, err := json.Marshal(fullTextResponse)
    if err != nil {
        return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
    }
    c.Writer.Header().Set("Content-Type", "application/json")
    c.Writer.WriteHeader(resp.StatusCode)
    _, err = c.Writer.Write(jsonResponse)
    return nil, &fullTextResponse.Usage
}

func EmbeddingHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
    var baiduResponse EmbeddingResponse
    responseBody, err := io.ReadAll(resp.Body)
    if err != nil {
        return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
    }
    err = resp.Body.Close()
    if err != nil {
        return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
    }
    err = json.Unmarshal(responseBody, &baiduResponse)
    if err != nil {
        return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
    }
    if baiduResponse.ErrorMsg != "" {
        return &model.ErrorWithStatusCode{
            Error: model.Error{
                Message: baiduResponse.ErrorMsg,
                Type:    "baidu_error",
                Param:   "",
                Code:    baiduResponse.ErrorCode,
            },
            StatusCode: resp.StatusCode,
        }, nil
    }
    fullTextResponse := embeddingResponseBaidu2OpenAI(&baiduResponse)
    jsonResponse, err := json.Marshal(fullTextResponse)
    if err != nil {
        return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
    }
    c.Writer.Header().Set("Content-Type", "application/json")
    c.Writer.WriteHeader(resp.StatusCode)
    _, err = c.Writer.Write(jsonResponse)
    return nil, &fullTextResponse.Usage
}

func GetAccessToken(apiKey string) (string, error) {
    if val, ok := baiduTokenStore.Load(apiKey); ok {
        var accessToken AccessToken
        if accessToken, ok = val.(AccessToken); ok {
            // the cached token will expire within the hour: refresh it in the background
            if time.Now().Add(time.Hour).After(accessToken.ExpiresAt) {
                go func() {
                    _, _ = getBaiduAccessTokenHelper(apiKey)
                }()
            }
            return accessToken.AccessToken, nil
        }
    }
    accessToken, err := getBaiduAccessTokenHelper(apiKey)
    if err != nil {
        return "", err
    }
    if accessToken == nil {
        return "", errors.New("GetAccessToken return a nil token")
    }
    return (*accessToken).AccessToken, nil
}

func getBaiduAccessTokenHelper(apiKey string) (*AccessToken, error) {
    parts := strings.Split(apiKey, "|")
    if len(parts) != 2 {
        return nil, errors.New("invalid baidu apikey")
    }
    req, err := http.NewRequest("POST", fmt.Sprintf("https://aip.baidubce.com/oauth/2.0/token?grant_type=client_credentials&client_id=%s&client_secret=%s",
        parts[0], parts[1]), nil)
    if err != nil {
        return nil, err
    }
    req.Header.Add("Content-Type", "application/json")
    req.Header.Add("Accept", "application/json")
    res, err := client.ImpatientHTTPClient.Do(req)
    if err != nil {
        return nil, err
    }
    defer res.Body.Close()

    var accessToken AccessToken
    err = json.NewDecoder(res.Body).Decode(&accessToken)
    if err != nil {
        return nil, err
    }
    if accessToken.Error != "" {
        return nil, errors.New(accessToken.Error + ": " + accessToken.ErrorDescription)
    }
    if accessToken.AccessToken == "" {
        return nil, errors.New("getBaiduAccessTokenHelper get empty access token")
    }
    accessToken.ExpiresAt = time.Now().Add(time.Duration(accessToken.ExpiresIn) * time.Second)
    baiduTokenStore.Store(apiKey, accessToken)
    return &accessToken, nil
}
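A hedged usage sketch of the token cache above (credentials invented): the channel API key is `client_id|client_secret`, the first call fetches and stores the token, later calls return the cached value, and a background refresh fires once a token is within an hour of expiry.

// Illustrative only: callers just ask for a token; caching is transparent.
token, err := GetAccessToken("my-client-id|my-client-secret")
if err != nil {
    logger.SysError("baidu access token: " + err.Error())
    return
}
_ = token // appended to the request URL as ?access_token=...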
51
relay/adaptor/baidu/model.go
Normal file
@@ -0,0 +1,51 @@
package baidu

import (
    "time"

    "github.com/songquanpeng/one-api/relay/model"
)

type ChatResponse struct {
    Id               string      `json:"id"`
    Object           string      `json:"object"`
    Created          int64       `json:"created"`
    Result           string      `json:"result"`
    IsTruncated      bool        `json:"is_truncated"`
    NeedClearHistory bool        `json:"need_clear_history"`
    Usage            model.Usage `json:"usage"`
    Error
}

type ChatStreamResponse struct {
    ChatResponse
    SentenceId int  `json:"sentence_id"`
    IsEnd      bool `json:"is_end"`
}

type EmbeddingRequest struct {
    Input []string `json:"input"`
}

type EmbeddingData struct {
    Object    string    `json:"object"`
    Embedding []float64 `json:"embedding"`
    Index     int       `json:"index"`
}

type EmbeddingResponse struct {
    Id      string          `json:"id"`
    Object  string          `json:"object"`
    Created int64           `json:"created"`
    Data    []EmbeddingData `json:"data"`
    Usage   model.Usage     `json:"usage"`
    Error
}

type AccessToken struct {
    AccessToken      string    `json:"access_token"`
    Error            string    `json:"error,omitempty"`
    ErrorDescription string    `json:"error_description,omitempty"`
    ExpiresIn        int64     `json:"expires_in,omitempty"`
    ExpiresAt        time.Time `json:"-"`
}
100
relay/adaptor/cloudflare/adaptor.go
Normal file
@@ -0,0 +1,100 @@
package cloudflare

import (
    "errors"
    "fmt"
    "io"
    "net/http"
    "strings"

    "github.com/gin-gonic/gin"
    "github.com/songquanpeng/one-api/relay/adaptor"
    "github.com/songquanpeng/one-api/relay/meta"
    "github.com/songquanpeng/one-api/relay/model"
    "github.com/songquanpeng/one-api/relay/relaymode"
)

type Adaptor struct {
    meta *meta.Meta
}

// ConvertImageRequest implements adaptor.Adaptor.
func (*Adaptor) ConvertImageRequest(_ *gin.Context, request *model.ImageRequest) (any, error) {
    return nil, errors.New("not implemented")
}

func (a *Adaptor) Init(meta *meta.Meta) {
    a.meta = meta
}

// WorkerAI cannot be used across accounts with AIGateWay
// https://developers.cloudflare.com/ai-gateway/providers/workersai/#openai-compatible-endpoints
// https://gateway.ai.cloudflare.com/v1/{account_id}/{gateway_id}/workers-ai
func (a *Adaptor) isAIGateWay(baseURL string) bool {
    return strings.HasPrefix(baseURL, "https://gateway.ai.cloudflare.com") && strings.HasSuffix(baseURL, "/workers-ai")
}

func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) {
    isAIGateWay := a.isAIGateWay(meta.BaseURL)
    var urlPrefix string
    if isAIGateWay {
        urlPrefix = meta.BaseURL
    } else {
        urlPrefix = fmt.Sprintf("%s/client/v4/accounts/%s/ai", meta.BaseURL, meta.Config.UserID)
    }

    switch meta.Mode {
    case relaymode.ChatCompletions:
        return fmt.Sprintf("%s/v1/chat/completions", urlPrefix), nil
    case relaymode.Embeddings:
        return fmt.Sprintf("%s/v1/embeddings", urlPrefix), nil
    default:
        if isAIGateWay {
            return fmt.Sprintf("%s/%s", urlPrefix, meta.ActualModelName), nil
        }
        return fmt.Sprintf("%s/run/%s", urlPrefix, meta.ActualModelName), nil
    }
}

func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error {
    adaptor.SetupCommonRequestHeader(c, req, meta)
    req.Header.Set("Authorization", "Bearer "+meta.APIKey)
    return nil
}

func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) {
    if request == nil {
        return nil, errors.New("request is nil")
    }
    switch relayMode {
    case relaymode.Completions:
        return ConvertCompletionsRequest(*request), nil
    case relaymode.ChatCompletions, relaymode.Embeddings:
        return request, nil
    default:
        return nil, errors.New("not implemented")
    }
}

func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) {
    return adaptor.DoRequestHelper(a, c, meta, requestBody)
}

func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
    if meta.IsStream {
        err, usage = StreamHandler(c, resp, meta.PromptTokens, meta.ActualModelName)
    } else {
        err, usage = Handler(c, resp, meta.PromptTokens, meta.ActualModelName)
    }
    return
}

func (a *Adaptor) GetModelList() []string {
    return ModelList
}

func (a *Adaptor) GetChannelName() string {
    return "cloudflare"
}
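A hedged illustration of the two URL shapes GetRequestURL can produce (account, gateway, and base URL values invented):

// Illustrative only:
//   direct Workers AI, chat mode:
//     https://api.cloudflare.com/client/v4/accounts/<account_id>/ai/v1/chat/completions
//   via AI Gateway, non-chat mode:
//     https://gateway.ai.cloudflare.com/v1/<account_id>/<gateway_id>/workers-ai/@cf/meta/llama-3-8b-instruct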
37
relay/adaptor/cloudflare/constant.go
Normal file
@@ -0,0 +1,37 @@
package cloudflare

var ModelList = []string{
    "@cf/meta/llama-3.1-8b-instruct",
    "@cf/meta/llama-2-7b-chat-fp16",
    "@cf/meta/llama-2-7b-chat-int8",
    "@cf/mistral/mistral-7b-instruct-v0.1",
    "@hf/thebloke/deepseek-coder-6.7b-base-awq",
    "@hf/thebloke/deepseek-coder-6.7b-instruct-awq",
    "@cf/deepseek-ai/deepseek-math-7b-base",
    "@cf/deepseek-ai/deepseek-math-7b-instruct",
    "@cf/thebloke/discolm-german-7b-v1-awq",
    "@cf/tiiuae/falcon-7b-instruct",
    "@cf/google/gemma-2b-it-lora",
    "@hf/google/gemma-7b-it",
    "@cf/google/gemma-7b-it-lora",
    "@hf/nousresearch/hermes-2-pro-mistral-7b",
    "@hf/thebloke/llama-2-13b-chat-awq",
    "@cf/meta-llama/llama-2-7b-chat-hf-lora",
    "@cf/meta/llama-3-8b-instruct",
    "@hf/thebloke/llamaguard-7b-awq",
    "@hf/thebloke/mistral-7b-instruct-v0.1-awq",
    "@hf/mistralai/mistral-7b-instruct-v0.2",
    "@cf/mistral/mistral-7b-instruct-v0.2-lora",
    "@hf/thebloke/neural-chat-7b-v3-1-awq",
    "@cf/openchat/openchat-3.5-0106",
    "@hf/thebloke/openhermes-2.5-mistral-7b-awq",
    "@cf/microsoft/phi-2",
    "@cf/qwen/qwen1.5-0.5b-chat",
    "@cf/qwen/qwen1.5-1.8b-chat",
    "@cf/qwen/qwen1.5-14b-chat-awq",
    "@cf/qwen/qwen1.5-7b-chat-awq",
    "@cf/defog/sqlcoder-7b-2",
    "@hf/nexusflow/starling-lm-7b-beta",
    "@cf/tinyllama/tinyllama-1.1b-chat-v1.0",
    "@hf/thebloke/zephyr-7b-beta-awq",
}
115
relay/adaptor/cloudflare/main.go
Normal file
@@ -0,0 +1,115 @@
package cloudflare

import (
	"bufio"
	"encoding/json"
	"io"
	"net/http"
	"strings"

	"github.com/gin-gonic/gin"

	"github.com/songquanpeng/one-api/common"
	"github.com/songquanpeng/one-api/common/ctxkey"
	"github.com/songquanpeng/one-api/common/helper"
	"github.com/songquanpeng/one-api/common/logger"
	"github.com/songquanpeng/one-api/common/render"
	"github.com/songquanpeng/one-api/relay/adaptor/openai"
	"github.com/songquanpeng/one-api/relay/model"
)

func ConvertCompletionsRequest(textRequest model.GeneralOpenAIRequest) *Request {
	p, _ := textRequest.Prompt.(string)
	return &Request{
		Prompt:      p,
		MaxTokens:   textRequest.MaxTokens,
		Stream:      textRequest.Stream,
		Temperature: textRequest.Temperature,
	}
}

func StreamHandler(c *gin.Context, resp *http.Response, promptTokens int, modelName string) (*model.ErrorWithStatusCode, *model.Usage) {
	scanner := bufio.NewScanner(resp.Body)
	scanner.Split(bufio.ScanLines)

	common.SetEventStreamHeaders(c)
	id := helper.GetResponseID(c)
	responseModel := c.GetString(ctxkey.OriginalModel)
	var responseText string

	for scanner.Scan() {
		data := scanner.Text()
		if len(data) < len("data: ") {
			continue
		}
		data = strings.TrimPrefix(data, "data: ")
		data = strings.TrimSuffix(data, "\r")

		if data == "[DONE]" {
			break
		}

		var response openai.ChatCompletionsStreamResponse
		err := json.Unmarshal([]byte(data), &response)
		if err != nil {
			logger.SysError("error unmarshalling stream response: " + err.Error())
			continue
		}
		for _, v := range response.Choices {
			v.Delta.Role = "assistant"
			responseText += v.Delta.StringContent()
		}
		response.Id = id
		response.Model = modelName
		err = render.ObjectData(c, response)
		if err != nil {
			logger.SysError(err.Error())
		}
	}

	if err := scanner.Err(); err != nil {
		logger.SysError("error reading stream: " + err.Error())
	}

	render.Done(c)

	err := resp.Body.Close()
	if err != nil {
		return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
	}

	usage := openai.ResponseText2Usage(responseText, responseModel, promptTokens)
	return nil, usage
}

func Handler(c *gin.Context, resp *http.Response, promptTokens int, modelName string) (*model.ErrorWithStatusCode, *model.Usage) {
	responseBody, err := io.ReadAll(resp.Body)
	if err != nil {
		return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
	}
	err = resp.Body.Close()
	if err != nil {
		return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
	}
	var response openai.TextResponse
	err = json.Unmarshal(responseBody, &response)
	if err != nil {
		return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
	}
	response.Model = modelName
	var responseText string
	for _, v := range response.Choices {
		// StringContent avoids panicking when the upstream content is not a plain string
		responseText += v.Message.StringContent()
	}
	usage := openai.ResponseText2Usage(responseText, modelName, promptTokens)
	response.Usage = *usage
	response.Id = helper.GetResponseID(c)
	jsonResponse, err := json.Marshal(response)
	if err != nil {
		return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
	}
	c.Writer.Header().Set("Content-Type", "application/json")
	c.Writer.WriteHeader(resp.StatusCode)
	_, _ = c.Writer.Write(jsonResponse)
	return nil, usage
}
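// Hypothetical stream fragment (illustration only, not from the commit):
// Workers AI's OpenAI-compatible endpoint frames chunks as server-sent events,
//   data: {"choices":[{"delta":{"content":"Hello"}}]}
//   data: {"choices":[{"delta":{"content":" world"}}]}
//   data: [DONE]
// StreamHandler strips the "data: " prefix, unmarshals each chunk, re-tags it
// with the request's response ID and model name, and re-emits it downstream.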
13
relay/adaptor/cloudflare/model.go
Normal file
@@ -0,0 +1,13 @@
package cloudflare

import "github.com/songquanpeng/one-api/relay/model"

type Request struct {
	Messages    []model.Message `json:"messages,omitempty"`
	Lora        string          `json:"lora,omitempty"`
	MaxTokens   int             `json:"max_tokens,omitempty"`
	Prompt      string          `json:"prompt,omitempty"`
	Raw         bool            `json:"raw,omitempty"`
	Stream      bool            `json:"stream,omitempty"`
	Temperature *float64        `json:"temperature,omitempty"`
}
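// Illustrative payload (assumed values, not from the commit): a converted
// completions request marshals to roughly
//   {"prompt":"Say hi","max_tokens":64,"stream":true,"temperature":0.7}
// fields left at their zero values are dropped by the omitempty tags above.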
64
relay/adaptor/cohere/adaptor.go
Normal file
@@ -0,0 +1,64 @@
package cohere

import (
	"errors"
	"fmt"
	"io"
	"net/http"

	"github.com/gin-gonic/gin"
	"github.com/songquanpeng/one-api/relay/adaptor"
	"github.com/songquanpeng/one-api/relay/meta"
	"github.com/songquanpeng/one-api/relay/model"
)

type Adaptor struct{}

// ConvertImageRequest implements adaptor.Adaptor.
func (*Adaptor) ConvertImageRequest(_ *gin.Context, request *model.ImageRequest) (any, error) {
	return nil, errors.New("not implemented")
}

func (a *Adaptor) Init(meta *meta.Meta) {
}

func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) {
	return fmt.Sprintf("%s/v1/chat", meta.BaseURL), nil
}

func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error {
	adaptor.SetupCommonRequestHeader(c, req, meta)
	req.Header.Set("Authorization", "Bearer "+meta.APIKey)
	return nil
}

func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) {
	if request == nil {
		return nil, errors.New("request is nil")
	}
	return ConvertRequest(*request), nil
}

func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) {
	return adaptor.DoRequestHelper(a, c, meta, requestBody)
}

func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
	if meta.IsStream {
		err, usage = StreamHandler(c, resp)
	} else {
		err, usage = Handler(c, resp, meta.PromptTokens, meta.ActualModelName)
	}
	return
}

func (a *Adaptor) GetModelList() []string {
	return ModelList
}

func (a *Adaptor) GetChannelName() string {
	return "Cohere"
}
14
relay/adaptor/cohere/constant.go
Normal file
@@ -0,0 +1,14 @@
package cohere

var ModelList = []string{
	"command", "command-nightly",
	"command-light", "command-light-nightly",
	"command-r", "command-r-plus",
}

func init() {
	num := len(ModelList)
	for i := 0; i < num; i++ {
		ModelList = append(ModelList, ModelList[i]+"-internet")
	}
}
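// Worked example (illustrative): after init runs, ModelList holds each base
// name plus a connector-enabled twin, e.g. "command-r" and
// "command-r-internet"; the "-internet" suffix is stripped again in
// ConvertRequest below and swapped for the web-search connector.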
228
relay/adaptor/cohere/main.go
Normal file
@@ -0,0 +1,228 @@
package cohere

import (
	"bufio"
	"encoding/json"
	"fmt"
	"io"
	"net/http"
	"strings"

	"github.com/gin-gonic/gin"
	"github.com/songquanpeng/one-api/common"
	"github.com/songquanpeng/one-api/common/ctxkey"
	"github.com/songquanpeng/one-api/common/helper"
	"github.com/songquanpeng/one-api/common/logger"
	"github.com/songquanpeng/one-api/common/render"
	"github.com/songquanpeng/one-api/relay/adaptor/openai"
	"github.com/songquanpeng/one-api/relay/model"
)

var (
	WebSearchConnector = Connector{ID: "web-search"}
)

func stopReasonCohere2OpenAI(reason *string) string {
	if reason == nil {
		return ""
	}
	switch *reason {
	case "COMPLETE":
		return "stop"
	default:
		return *reason
	}
}

func ConvertRequest(textRequest model.GeneralOpenAIRequest) *Request {
	cohereRequest := Request{
		Model:            textRequest.Model,
		Message:          "",
		MaxTokens:        textRequest.MaxTokens,
		Temperature:      textRequest.Temperature,
		P:                textRequest.TopP,
		K:                textRequest.TopK,
		Stream:           textRequest.Stream,
		FrequencyPenalty: textRequest.FrequencyPenalty,
		PresencePenalty:  textRequest.PresencePenalty,
		Seed:             int(textRequest.Seed),
	}
	if cohereRequest.Model == "" {
		cohereRequest.Model = "command-r"
	}
	if strings.HasSuffix(cohereRequest.Model, "-internet") {
		cohereRequest.Model = strings.TrimSuffix(cohereRequest.Model, "-internet")
		cohereRequest.Connectors = append(cohereRequest.Connectors, WebSearchConnector)
	}
	for _, message := range textRequest.Messages {
		if message.Role == "user" {
			// each user turn overwrites the top-level message, so the most recent
			// user message becomes the query; StringContent avoids panicking on
			// structured (non-string) content
			cohereRequest.Message = message.StringContent()
		} else {
			var role string
			if message.Role == "assistant" {
				role = "CHATBOT"
			} else if message.Role == "system" {
				role = "SYSTEM"
			} else {
				role = "USER"
			}
			cohereRequest.ChatHistory = append(cohereRequest.ChatHistory, ChatMessage{
				Role:    role,
				Message: message.StringContent(),
			})
		}
	}
	return &cohereRequest
}

func StreamResponseCohere2OpenAI(cohereResponse *StreamResponse) (*openai.ChatCompletionsStreamResponse, *Response) {
	var response *Response
	var responseText string
	var finishReason string

	switch cohereResponse.EventType {
	case "stream-start":
		return nil, nil
	case "text-generation":
		responseText += cohereResponse.Text
	case "stream-end":
		usage := cohereResponse.Response.Meta.Tokens
		response = &Response{
			Meta: Meta{
				Tokens: Usage{
					InputTokens:  usage.InputTokens,
					OutputTokens: usage.OutputTokens,
				},
			},
		}
		// FinishReason may be nil on malformed payloads; the helper handles that
		// and maps Cohere's COMPLETE to OpenAI's "stop"
		finishReason = stopReasonCohere2OpenAI(cohereResponse.Response.FinishReason)
	default:
		return nil, nil
	}

	var choice openai.ChatCompletionsStreamResponseChoice
	choice.Delta.Content = responseText
	choice.Delta.Role = "assistant"
	if finishReason != "" {
		choice.FinishReason = &finishReason
	}
	var openaiResponse openai.ChatCompletionsStreamResponse
	openaiResponse.Object = "chat.completion.chunk"
	openaiResponse.Choices = []openai.ChatCompletionsStreamResponseChoice{choice}
	return &openaiResponse, response
}

func ResponseCohere2OpenAI(cohereResponse *Response) *openai.TextResponse {
	choice := openai.TextResponseChoice{
		Index: 0,
		Message: model.Message{
			Role:    "assistant",
			Content: cohereResponse.Text,
			Name:    nil,
		},
		FinishReason: stopReasonCohere2OpenAI(cohereResponse.FinishReason),
	}
	fullTextResponse := openai.TextResponse{
		Id:      fmt.Sprintf("chatcmpl-%s", cohereResponse.ResponseID),
		Model:   "model",
		Object:  "chat.completion",
		Created: helper.GetTimestamp(),
		Choices: []openai.TextResponseChoice{choice},
	}
	return &fullTextResponse
}

func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
	createdTime := helper.GetTimestamp()
	scanner := bufio.NewScanner(resp.Body)
	scanner.Split(bufio.ScanLines)

	common.SetEventStreamHeaders(c)
	var usage model.Usage

	for scanner.Scan() {
		data := scanner.Text()
		data = strings.TrimSuffix(data, "\r")
		// Cohere streams newline-delimited JSON; skip blank keep-alive lines
		if data == "" {
			continue
		}

		var cohereResponse StreamResponse
		err := json.Unmarshal([]byte(data), &cohereResponse)
		if err != nil {
			logger.SysError("error unmarshalling stream response: " + err.Error())
			continue
		}

		response, meta := StreamResponseCohere2OpenAI(&cohereResponse)
		if meta != nil {
			usage.PromptTokens += meta.Meta.Tokens.InputTokens
			usage.CompletionTokens += meta.Meta.Tokens.OutputTokens
			continue
		}
		if response == nil {
			continue
		}

		response.Id = fmt.Sprintf("chatcmpl-%d", createdTime)
		response.Model = c.GetString(ctxkey.OriginalModel)
		response.Created = createdTime

		err = render.ObjectData(c, response)
		if err != nil {
			logger.SysError(err.Error())
		}
	}

	if err := scanner.Err(); err != nil {
		logger.SysError("error reading stream: " + err.Error())
	}

	render.Done(c)

	err := resp.Body.Close()
	if err != nil {
		return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
	}

	return nil, &usage
}

func Handler(c *gin.Context, resp *http.Response, promptTokens int, modelName string) (*model.ErrorWithStatusCode, *model.Usage) {
	responseBody, err := io.ReadAll(resp.Body)
	if err != nil {
		return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
	}
	err = resp.Body.Close()
	if err != nil {
		return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
	}
	var cohereResponse Response
	err = json.Unmarshal(responseBody, &cohereResponse)
	if err != nil {
		return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
	}
	if cohereResponse.ResponseID == "" {
		return &model.ErrorWithStatusCode{
			Error: model.Error{
				Message: cohereResponse.Message,
				Type:    cohereResponse.Message,
				Param:   "",
				Code:    resp.StatusCode,
			},
			StatusCode: resp.StatusCode,
		}, nil
	}
	fullTextResponse := ResponseCohere2OpenAI(&cohereResponse)
	fullTextResponse.Model = modelName
	usage := model.Usage{
		PromptTokens:     cohereResponse.Meta.Tokens.InputTokens,
		CompletionTokens: cohereResponse.Meta.Tokens.OutputTokens,
		TotalTokens:      cohereResponse.Meta.Tokens.InputTokens + cohereResponse.Meta.Tokens.OutputTokens,
	}
	fullTextResponse.Usage = usage
	jsonResponse, err := json.Marshal(fullTextResponse)
	if err != nil {
		return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
	}
	c.Writer.Header().Set("Content-Type", "application/json")
	c.Writer.WriteHeader(resp.StatusCode)
	_, _ = c.Writer.Write(jsonResponse)
	return nil, &usage
}
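// Illustrative conversion (assumed values, not from the commit): ConvertRequest
// flattens an OpenAI conversation into Cohere's shape; non-user turns become
// chat_history entries (assistant -> CHATBOT, system -> SYSTEM), and each user
// turn overwrites the top-level "message" field, so only the most recent user
// message is sent as the query:
//   {"message":"And tomorrow?","model":"command-r",
//    "chat_history":[{"role":"SYSTEM","message":"Be brief"},
//                    {"role":"CHATBOT","message":"Sunny."}]}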
147
relay/adaptor/cohere/model.go
Normal file
@@ -0,0 +1,147 @@
package cohere

type Request struct {
	Message          string        `json:"message" required:"true"`
	Model            string        `json:"model,omitempty"`             // defaults to "command-r"
	Stream           bool          `json:"stream,omitempty"`            // defaults to false
	Preamble         string        `json:"preamble,omitempty"`
	ChatHistory      []ChatMessage `json:"chat_history,omitempty"`
	ConversationID   string        `json:"conversation_id,omitempty"`
	PromptTruncation string        `json:"prompt_truncation,omitempty"` // defaults to "AUTO"
	Connectors       []Connector   `json:"connectors,omitempty"`
	Documents        []Document    `json:"documents,omitempty"`
	Temperature      *float64      `json:"temperature,omitempty"`       // defaults to 0.3
	MaxTokens        int           `json:"max_tokens,omitempty"`
	MaxInputTokens   int           `json:"max_input_tokens,omitempty"`
	K                int           `json:"k,omitempty"`                 // defaults to 0
	P                *float64      `json:"p,omitempty"`                 // defaults to 0.75
	Seed             int           `json:"seed,omitempty"`
	StopSequences    []string      `json:"stop_sequences,omitempty"`
	FrequencyPenalty *float64      `json:"frequency_penalty,omitempty"` // defaults to 0.0
	PresencePenalty  *float64      `json:"presence_penalty,omitempty"`  // defaults to 0.0
	Tools            []Tool        `json:"tools,omitempty"`
	ToolResults      []ToolResult  `json:"tool_results,omitempty"`
}

type ChatMessage struct {
	Role    string `json:"role" required:"true"`
	Message string `json:"message" required:"true"`
}

type Tool struct {
	Name                 string                   `json:"name" required:"true"`
	Description          string                   `json:"description" required:"true"`
	ParameterDefinitions map[string]ParameterSpec `json:"parameter_definitions"`
}

type ParameterSpec struct {
	Description string `json:"description"`
	Type        string `json:"type" required:"true"`
	Required    bool   `json:"required"`
}

type ToolResult struct {
	Call    ToolCall                 `json:"call"`
	Outputs []map[string]interface{} `json:"outputs"`
}

type ToolCall struct {
	Name       string                 `json:"name" required:"true"`
	Parameters map[string]interface{} `json:"parameters" required:"true"`
}

type StreamResponse struct {
	IsFinished    bool            `json:"is_finished"`
	EventType     string          `json:"event_type"`
	GenerationID  string          `json:"generation_id,omitempty"`
	SearchQueries []*SearchQuery  `json:"search_queries,omitempty"`
	SearchResults []*SearchResult `json:"search_results,omitempty"`
	Documents     []*Document     `json:"documents,omitempty"`
	Text          string          `json:"text,omitempty"`
	Citations     []*Citation     `json:"citations,omitempty"`
	Response      *Response       `json:"response,omitempty"`
	FinishReason  string          `json:"finish_reason,omitempty"`
}

type SearchQuery struct {
	Text         string `json:"text"`
	GenerationID string `json:"generation_id"`
}

type SearchResult struct {
	SearchQuery *SearchQuery `json:"search_query"`
	DocumentIDs []string     `json:"document_ids"`
	Connector   *Connector   `json:"connector"`
}

type Connector struct {
	ID string `json:"id"`
}

type Document struct {
	ID        string `json:"id"`
	Snippet   string `json:"snippet"`
	Timestamp string `json:"timestamp"`
	Title     string `json:"title"`
	URL       string `json:"url"`
}

type Citation struct {
	Start       int      `json:"start"`
	End         int      `json:"end"`
	Text        string   `json:"text"`
	DocumentIDs []string `json:"document_ids"`
}

type Response struct {
	ResponseID    string          `json:"response_id"`
	Text          string          `json:"text"`
	GenerationID  string          `json:"generation_id"`
	ChatHistory   []*Message      `json:"chat_history"`
	FinishReason  *string         `json:"finish_reason"`
	Meta          Meta            `json:"meta"`
	Citations     []*Citation     `json:"citations"`
	Documents     []*Document     `json:"documents"`
	SearchResults []*SearchResult `json:"search_results"`
	SearchQueries []*SearchQuery  `json:"search_queries"`
	Message       string          `json:"message"`
}

type Message struct {
	Role    string `json:"role"`
	Message string `json:"message"`
}

type Version struct {
	Version string `json:"version"`
}

type Units struct {
	InputTokens  int `json:"input_tokens"`
	OutputTokens int `json:"output_tokens"`
}

type ChatEntry struct {
	Role    string `json:"role"`
	Message string `json:"message"`
}

type Meta struct {
	APIVersion  APIVersion  `json:"api_version"`
	BilledUnits BilledUnits `json:"billed_units"`
	Tokens      Usage       `json:"tokens"`
}

type APIVersion struct {
	Version string `json:"version"`
}

type BilledUnits struct {
	InputTokens  int `json:"input_tokens"`
	OutputTokens int `json:"output_tokens"`
}

type Usage struct {
	InputTokens  int `json:"input_tokens"`
	OutputTokens int `json:"output_tokens"`
}
59
relay/adaptor/common.go
Normal file
@@ -0,0 +1,59 @@
package adaptor

import (
	"io"
	"net/http"

	"github.com/gin-gonic/gin"
	"github.com/pkg/errors"
	"github.com/songquanpeng/one-api/common/client"
	"github.com/songquanpeng/one-api/common/ctxkey"
	"github.com/songquanpeng/one-api/relay/meta"
)

func SetupCommonRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) {
	req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
	req.Header.Set("Accept", c.Request.Header.Get("Accept"))
	if meta.IsStream && c.Request.Header.Get("Accept") == "" {
		req.Header.Set("Accept", "text/event-stream")
	}
}

func DoRequestHelper(a Adaptor, c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) {
	fullRequestURL, err := a.GetRequestURL(meta)
	if err != nil {
		return nil, errors.Wrap(err, "get request url failed")
	}

	req, err := http.NewRequestWithContext(c.Request.Context(),
		c.Request.Method, fullRequestURL, requestBody)
	if err != nil {
		return nil, errors.Wrap(err, "new request failed")
	}

	req.Header.Set("Content-Type", c.GetString(ctxkey.ContentType))

	err = a.SetupRequestHeader(c, req, meta)
	if err != nil {
		return nil, errors.Wrap(err, "setup request header failed")
	}
	resp, err := DoRequest(c, req)
	if err != nil {
		return nil, errors.Wrap(err, "do request failed")
	}
	return resp, nil
}

func DoRequest(c *gin.Context, req *http.Request) (*http.Response, error) {
	resp, err := client.HTTPClient.Do(req)
	if err != nil {
		return nil, err
	}
	if resp == nil {
		return nil, errors.New("resp is nil")
	}
	_ = req.Body.Close()
	_ = c.Request.Body.Close()

	return resp, nil
}
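// Typical call flow (sketch, assuming an Adaptor implementation named a and a
// prepared request body): DoRequestHelper stitches the pieces together so each
// adaptor only has to describe its endpoint and auth scheme:
//   resp, err := adaptor.DoRequestHelper(a, c, meta, body) // GetRequestURL + SetupRequestHeader + DoRequest
//   usage, relayErr := a.DoResponse(c, resp, meta)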
76
relay/adaptor/coze/adaptor.go
Normal file
@@ -0,0 +1,76 @@
package coze

import (
	"errors"
	"fmt"
	"io"
	"net/http"

	"github.com/gin-gonic/gin"
	"github.com/songquanpeng/one-api/relay/adaptor"
	"github.com/songquanpeng/one-api/relay/adaptor/openai"
	"github.com/songquanpeng/one-api/relay/meta"
	"github.com/songquanpeng/one-api/relay/model"
)

type Adaptor struct {
	meta *meta.Meta
}

func (a *Adaptor) Init(meta *meta.Meta) {
	a.meta = meta
}

func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) {
	return fmt.Sprintf("%s/open_api/v2/chat", meta.BaseURL), nil
}

func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error {
	adaptor.SetupCommonRequestHeader(c, req, meta)
	req.Header.Set("Authorization", "Bearer "+meta.APIKey)
	return nil
}

func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) {
	if request == nil {
		return nil, errors.New("request is nil")
	}
	request.User = a.meta.Config.UserID
	return ConvertRequest(*request), nil
}

func (a *Adaptor) ConvertImageRequest(_ *gin.Context, request *model.ImageRequest) (any, error) {
	if request == nil {
		return nil, errors.New("request is nil")
	}
	return request, nil
}

func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) {
	return adaptor.DoRequestHelper(a, c, meta, requestBody)
}

func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
	var responseText *string
	if meta.IsStream {
		err, responseText = StreamHandler(c, resp)
	} else {
		err, responseText = Handler(c, resp, meta.PromptTokens, meta.ActualModelName)
	}
	if responseText != nil {
		usage = openai.ResponseText2Usage(*responseText, meta.ActualModelName, meta.PromptTokens)
	} else {
		usage = &model.Usage{}
	}
	usage.PromptTokens = meta.PromptTokens
	usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
	return
}

func (a *Adaptor) GetModelList() []string {
	return ModelList
}

func (a *Adaptor) GetChannelName() string {
	return "coze"
}
5
relay/adaptor/coze/constant/contenttype/define.go
Normal file
@@ -0,0 +1,5 @@
package contenttype

const (
	Text = "text"
)
7
relay/adaptor/coze/constant/event/define.go
Normal file
@@ -0,0 +1,7 @@
package event

const (
	Message = "message"
	Done    = "done"
	Error   = "error"
)
6
relay/adaptor/coze/constant/messagetype/define.go
Normal file
@@ -0,0 +1,6 @@
package messagetype

const (
	Answer   = "answer"
	FollowUp = "follow_up"
)
3
relay/adaptor/coze/constants.go
Normal file
@@ -0,0 +1,3 @@
package coze

var ModelList = []string{}
10
relay/adaptor/coze/helper.go
Normal file
@@ -0,0 +1,10 @@
package coze

import "github.com/songquanpeng/one-api/relay/adaptor/coze/constant/event"

func event2StopReason(e *string) string {
	if e == nil || *e == event.Message {
		return ""
	}
	return "stop"
}
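// Behavior sketch (illustrative): only a terminal event carries a stop signal,
//   event2StopReason(nil) == ""
//   e := event.Message; event2StopReason(&e) == ""
//   e = event.Done;     event2StopReason(&e) == "stop"
//   e = event.Error;    event2StopReason(&e) == "stop"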
202
relay/adaptor/coze/main.go
Normal file
@@ -0,0 +1,202 @@
package coze

import (
	"bufio"
	"encoding/json"
	"fmt"
	"io"
	"net/http"
	"strings"

	"github.com/gin-gonic/gin"
	"github.com/songquanpeng/one-api/common"
	"github.com/songquanpeng/one-api/common/conv"
	"github.com/songquanpeng/one-api/common/ctxkey"
	"github.com/songquanpeng/one-api/common/helper"
	"github.com/songquanpeng/one-api/common/logger"
	"github.com/songquanpeng/one-api/common/render"
	"github.com/songquanpeng/one-api/relay/adaptor/coze/constant/messagetype"
	"github.com/songquanpeng/one-api/relay/adaptor/openai"
	"github.com/songquanpeng/one-api/relay/model"
)

// https://www.coze.com/open

func stopReasonCoze2OpenAI(reason *string) string {
	if reason == nil {
		return ""
	}
	switch *reason {
	case "end_turn":
		return "stop"
	case "stop_sequence":
		return "stop"
	case "max_tokens":
		return "length"
	default:
		return *reason
	}
}

func ConvertRequest(textRequest model.GeneralOpenAIRequest) *Request {
	cozeRequest := Request{
		Stream: textRequest.Stream,
		User:   textRequest.User,
		BotId:  strings.TrimPrefix(textRequest.Model, "bot-"),
	}
	for i, message := range textRequest.Messages {
		if i == len(textRequest.Messages)-1 {
			cozeRequest.Query = message.StringContent()
			continue
		}
		cozeMessage := Message{
			Role:    message.Role,
			Content: message.StringContent(),
		}
		cozeRequest.ChatHistory = append(cozeRequest.ChatHistory, cozeMessage)
	}
	return &cozeRequest
}

func StreamResponseCoze2OpenAI(cozeResponse *StreamResponse) (*openai.ChatCompletionsStreamResponse, *Response) {
	var response *Response
	var choice openai.ChatCompletionsStreamResponseChoice

	if cozeResponse.Message != nil {
		if cozeResponse.Message.Type != messagetype.Answer {
			return nil, nil
		}
		choice.Delta.Content = cozeResponse.Message.Content
	}
	choice.Delta.Role = "assistant"
	// derive the finish reason from the stream event; only terminal events
	// ("done"/"error") signal a stop
	finishReason := event2StopReason(&cozeResponse.Event)
	if finishReason != "" {
		choice.FinishReason = &finishReason
	}
	var openaiResponse openai.ChatCompletionsStreamResponse
	openaiResponse.Object = "chat.completion.chunk"
	openaiResponse.Choices = []openai.ChatCompletionsStreamResponseChoice{choice}
	openaiResponse.Id = cozeResponse.ConversationId
	return &openaiResponse, response
}

func ResponseCoze2OpenAI(cozeResponse *Response) *openai.TextResponse {
	var responseText string
	for _, message := range cozeResponse.Messages {
		if message.Type == messagetype.Answer {
			responseText = message.Content
			break
		}
	}
	choice := openai.TextResponseChoice{
		Index: 0,
		Message: model.Message{
			Role:    "assistant",
			Content: responseText,
			Name:    nil,
		},
		FinishReason: "stop",
	}
	fullTextResponse := openai.TextResponse{
		Id:      fmt.Sprintf("chatcmpl-%s", cozeResponse.ConversationId),
		Model:   "coze-bot",
		Object:  "chat.completion",
		Created: helper.GetTimestamp(),
		Choices: []openai.TextResponseChoice{choice},
	}
	return &fullTextResponse
}

func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *string) {
	var responseText string
	createdTime := helper.GetTimestamp()
	scanner := bufio.NewScanner(resp.Body)
	scanner.Split(bufio.ScanLines)

	common.SetEventStreamHeaders(c)
	// report the caller's model name instead of leaving it empty
	modelName := c.GetString(ctxkey.OriginalModel)

	for scanner.Scan() {
		data := scanner.Text()
		if len(data) < 5 || !strings.HasPrefix(data, "data:") {
			continue
		}
		data = strings.TrimPrefix(data, "data:")
		data = strings.TrimSuffix(data, "\r")

		var cozeResponse StreamResponse
		err := json.Unmarshal([]byte(data), &cozeResponse)
		if err != nil {
			logger.SysError("error unmarshalling stream response: " + err.Error())
			continue
		}

		response, _ := StreamResponseCoze2OpenAI(&cozeResponse)
		if response == nil {
			continue
		}

		for _, choice := range response.Choices {
			responseText += conv.AsString(choice.Delta.Content)
		}
		response.Model = modelName
		response.Created = createdTime

		err = render.ObjectData(c, response)
		if err != nil {
			logger.SysError(err.Error())
		}
	}

	if err := scanner.Err(); err != nil {
		logger.SysError("error reading stream: " + err.Error())
	}

	render.Done(c)

	err := resp.Body.Close()
	if err != nil {
		return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
	}

	return nil, &responseText
}

func Handler(c *gin.Context, resp *http.Response, promptTokens int, modelName string) (*model.ErrorWithStatusCode, *string) {
	responseBody, err := io.ReadAll(resp.Body)
	if err != nil {
		return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
	}
	err = resp.Body.Close()
	if err != nil {
		return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
	}
	var cozeResponse Response
	err = json.Unmarshal(responseBody, &cozeResponse)
	if err != nil {
		return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
	}
	if cozeResponse.Code != 0 {
		return &model.ErrorWithStatusCode{
			Error: model.Error{
				Message: cozeResponse.Msg,
				Code:    cozeResponse.Code,
			},
			StatusCode: resp.StatusCode,
		}, nil
	}
	fullTextResponse := ResponseCoze2OpenAI(&cozeResponse)
	fullTextResponse.Model = modelName
	jsonResponse, err := json.Marshal(fullTextResponse)
	if err != nil {
		return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
	}
	c.Writer.Header().Set("Content-Type", "application/json")
	c.Writer.WriteHeader(resp.StatusCode)
	_, _ = c.Writer.Write(jsonResponse)
	var responseText string
	if len(fullTextResponse.Choices) > 0 {
		responseText = fullTextResponse.Choices[0].Message.StringContent()
	}
	return nil, &responseText
}
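// Illustrative conversion (assumed values, not from the commit): with model
// "bot-7354xxxx" the "bot-" prefix is stripped into bot_id, the last message
// becomes the query, and everything before it becomes chat_history, roughly:
//   {"bot_id":"7354xxxx","user":"u-1","query":"And in winter?",
//    "chat_history":[{"role":"user","content":"Best time to visit Oslo?"},
//                    {"role":"assistant","content":"Summer."}],
//    "stream":true}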
38
relay/adaptor/coze/model.go
Normal file
@@ -0,0 +1,38 @@
package coze

type Message struct {
	Role        string `json:"role"`
	Type        string `json:"type"`
	Content     string `json:"content"`
	ContentType string `json:"content_type"`
}

type ErrorInformation struct {
	Code int    `json:"code"`
	Msg  string `json:"msg"`
}

type Request struct {
	ConversationId string    `json:"conversation_id,omitempty"`
	BotId          string    `json:"bot_id"`
	User           string    `json:"user"`
	Query          string    `json:"query"`
	ChatHistory    []Message `json:"chat_history,omitempty"`
	Stream         bool      `json:"stream"`
}

type Response struct {
	ConversationId string    `json:"conversation_id,omitempty"`
	Messages       []Message `json:"messages,omitempty"`
	Code           int       `json:"code,omitempty"`
	Msg            string    `json:"msg,omitempty"`
}

type StreamResponse struct {
	Event            string            `json:"event,omitempty"`
	Message          *Message          `json:"message,omitempty"`
	IsFinish         bool              `json:"is_finish,omitempty"`
	Index            int               `json:"index,omitempty"`
	ConversationId   string            `json:"conversation_id,omitempty"`
	ErrorInformation *ErrorInformation `json:"error_information,omitempty"`
}
73
relay/adaptor/deepl/adaptor.go
Normal file
@@ -0,0 +1,73 @@
package deepl

import (
	"errors"
	"fmt"
	"io"
	"net/http"

	"github.com/gin-gonic/gin"
	"github.com/songquanpeng/one-api/relay/adaptor"
	"github.com/songquanpeng/one-api/relay/meta"
	"github.com/songquanpeng/one-api/relay/model"
)

type Adaptor struct {
	meta       *meta.Meta
	promptText string
}

func (a *Adaptor) Init(meta *meta.Meta) {
	a.meta = meta
}

func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) {
	return fmt.Sprintf("%s/v2/translate", meta.BaseURL), nil
}

func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error {
	adaptor.SetupCommonRequestHeader(c, req, meta)
	req.Header.Set("Authorization", "DeepL-Auth-Key "+meta.APIKey)
	return nil
}

func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) {
	if request == nil {
		return nil, errors.New("request is nil")
	}
	convertedRequest, text := ConvertRequest(*request)
	a.promptText = text
	return convertedRequest, nil
}

func (a *Adaptor) ConvertImageRequest(_ *gin.Context, request *model.ImageRequest) (any, error) {
	if request == nil {
		return nil, errors.New("request is nil")
	}
	return request, nil
}

func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) {
	return adaptor.DoRequestHelper(a, c, meta, requestBody)
}

func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
	if meta.IsStream {
		err = StreamHandler(c, resp, meta.ActualModelName)
	} else {
		err = Handler(c, resp, meta.ActualModelName)
	}
	// DeepL has no token accounting; approximate usage with the length of the
	// source text
	promptTokens := len(a.promptText)
	usage = &model.Usage{
		PromptTokens: promptTokens,
		TotalTokens:  promptTokens,
	}
	return
}

func (a *Adaptor) GetModelList() []string {
	return ModelList
}

func (a *Adaptor) GetChannelName() string {
	return "deepl"
}
9
relay/adaptor/deepl/constants.go
Normal file
@@ -0,0 +1,9 @@
package deepl

// https://developers.deepl.com/docs/api-reference/glossaries

var ModelList = []string{
	"deepl-zh",
	"deepl-en",
	"deepl-ja",
}
11
relay/adaptor/deepl/helper.go
Normal file
@@ -0,0 +1,11 @@
package deepl

import "strings"

func parseLangFromModelName(modelName string) string {
	parts := strings.Split(modelName, "-")
	if len(parts) == 1 {
		return "ZH"
	}
	return parts[1]
}
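// Behavior sketch (illustrative): parseLangFromModelName splits on "-" and
// returns the second segment, defaulting to "ZH" when there is no suffix:
//   parseLangFromModelName("deepl")    == "ZH"
//   parseLangFromModelName("deepl-en") == "en"
//   parseLangFromModelName("deepl-ja") == "ja"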
137
relay/adaptor/deepl/main.go
Normal file
@@ -0,0 +1,137 @@
package deepl

import (
	"encoding/json"
	"io"
	"net/http"

	"github.com/gin-gonic/gin"
	"github.com/songquanpeng/one-api/common"
	"github.com/songquanpeng/one-api/common/helper"
	"github.com/songquanpeng/one-api/relay/adaptor/openai"
	"github.com/songquanpeng/one-api/relay/constant"
	"github.com/songquanpeng/one-api/relay/constant/finishreason"
	"github.com/songquanpeng/one-api/relay/constant/role"
	"github.com/songquanpeng/one-api/relay/model"
)

// https://developers.deepl.com/docs/getting-started/your-first-api-request

func ConvertRequest(textRequest model.GeneralOpenAIRequest) (*Request, string) {
	var text string
	if len(textRequest.Messages) != 0 {
		text = textRequest.Messages[len(textRequest.Messages)-1].StringContent()
	}
	deeplRequest := Request{
		TargetLang: parseLangFromModelName(textRequest.Model),
		Text:       []string{text},
	}
	return &deeplRequest, text
}

func StreamResponseDeepL2OpenAI(deeplResponse *Response) *openai.ChatCompletionsStreamResponse {
	var choice openai.ChatCompletionsStreamResponseChoice
	if len(deeplResponse.Translations) != 0 {
		choice.Delta.Content = deeplResponse.Translations[0].Text
	}
	choice.Delta.Role = role.Assistant
	choice.FinishReason = &constant.StopFinishReason
	openaiResponse := openai.ChatCompletionsStreamResponse{
		Object:  constant.StreamObject,
		Created: helper.GetTimestamp(),
		Choices: []openai.ChatCompletionsStreamResponseChoice{choice},
	}
	return &openaiResponse
}

func ResponseDeepL2OpenAI(deeplResponse *Response) *openai.TextResponse {
	var responseText string
	if len(deeplResponse.Translations) != 0 {
		responseText = deeplResponse.Translations[0].Text
	}
	choice := openai.TextResponseChoice{
		Index: 0,
		Message: model.Message{
			Role:    role.Assistant,
			Content: responseText,
			Name:    nil,
		},
		FinishReason: finishreason.Stop,
	}
	fullTextResponse := openai.TextResponse{
		Object:  constant.NonStreamObject,
		Created: helper.GetTimestamp(),
		Choices: []openai.TextResponseChoice{choice},
	}
	return &fullTextResponse
}

// StreamHandler adapts DeepL's single-shot translate response to the SSE
// contract: the whole translation goes out as one chunk, followed by [DONE].
func StreamHandler(c *gin.Context, resp *http.Response, modelName string) *model.ErrorWithStatusCode {
	responseBody, err := io.ReadAll(resp.Body)
	if err != nil {
		return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
	}
	err = resp.Body.Close()
	if err != nil {
		return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError)
	}
	var deeplResponse Response
	err = json.Unmarshal(responseBody, &deeplResponse)
	if err != nil {
		return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError)
	}
	fullTextResponse := StreamResponseDeepL2OpenAI(&deeplResponse)
	fullTextResponse.Model = modelName
	fullTextResponse.Id = helper.GetResponseID(c)
	jsonData, err := json.Marshal(fullTextResponse)
	if err != nil {
		return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError)
	}
	common.SetEventStreamHeaders(c)
	c.Stream(func(w io.Writer) bool {
		if jsonData != nil {
			c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonData)})
			jsonData = nil
			return true
		}
		c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
		return false
	})
	return nil
}

func Handler(c *gin.Context, resp *http.Response, modelName string) *model.ErrorWithStatusCode {
	responseBody, err := io.ReadAll(resp.Body)
	if err != nil {
		return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
	}
	err = resp.Body.Close()
	if err != nil {
		return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError)
	}
	var deeplResponse Response
	err = json.Unmarshal(responseBody, &deeplResponse)
	if err != nil {
		return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError)
	}
	if deeplResponse.Message != "" {
		return &model.ErrorWithStatusCode{
			Error: model.Error{
				Message: deeplResponse.Message,
				Code:    "deepl_error",
			},
			StatusCode: resp.StatusCode,
		}
	}
	fullTextResponse := ResponseDeepL2OpenAI(&deeplResponse)
	fullTextResponse.Model = modelName
	fullTextResponse.Id = helper.GetResponseID(c)
	jsonResponse, err := json.Marshal(fullTextResponse)
	if err != nil {
		return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError)
	}
	c.Writer.Header().Set("Content-Type", "application/json")
	c.Writer.WriteHeader(resp.StatusCode)
	_, _ = c.Writer.Write(jsonResponse)
	return nil
}
16
relay/adaptor/deepl/model.go
Normal file
@@ -0,0 +1,16 @@
package deepl

type Request struct {
	Text       []string `json:"text"`
	TargetLang string   `json:"target_lang"`
}

type Translation struct {
	DetectedSourceLanguage string `json:"detected_source_language,omitempty"`
	Text                   string `json:"text,omitempty"`
}

type Response struct {
	Translations []Translation `json:"translations,omitempty"`
	Message      string        `json:"message,omitempty"`
}
6
relay/adaptor/deepseek/constants.go
Normal file
@@ -0,0 +1,6 @@
package deepseek

var ModelList = []string{
	"deepseek-chat",
	"deepseek-coder",
}
13
relay/adaptor/doubao/constants.go
Normal file
@@ -0,0 +1,13 @@
package doubao

// https://console.volcengine.com/ark/region:ark+cn-beijing/model

var ModelList = []string{
	"Doubao-pro-128k",
	"Doubao-pro-32k",
	"Doubao-pro-4k",
	"Doubao-lite-128k",
	"Doubao-lite-32k",
	"Doubao-lite-4k",
	"Doubao-embedding",
}
18
relay/adaptor/doubao/main.go
Normal file
@@ -0,0 +1,18 @@
package doubao

import (
	"fmt"

	"github.com/songquanpeng/one-api/relay/meta"
	"github.com/songquanpeng/one-api/relay/relaymode"
)

func GetRequestURL(meta *meta.Meta) (string, error) {
	switch meta.Mode {
	case relaymode.ChatCompletions:
		return fmt.Sprintf("%s/api/v3/chat/completions", meta.BaseURL), nil
	case relaymode.Embeddings:
		return fmt.Sprintf("%s/api/v3/embeddings", meta.BaseURL), nil
	default:
		return "", fmt.Errorf("unsupported relay mode %d for doubao", meta.Mode)
	}
}
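// Illustrative resolution (assumed base URL, not from the commit): with the
// Volcano Engine Ark endpoint https://ark.cn-beijing.volces.com, a chat
// request maps to
//   https://ark.cn-beijing.volces.com/api/v3/chat/completions
// and an embeddings request to .../api/v3/embeddings; any other mode errors.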
106
relay/adaptor/gemini/adaptor.go
Normal file
@@ -0,0 +1,106 @@
package gemini

import (
	"fmt"
	"io"
	"net/http"

	"github.com/gin-gonic/gin"
	"github.com/pkg/errors"
	"github.com/songquanpeng/one-api/common/config"
	"github.com/songquanpeng/one-api/common/helper"
	channelhelper "github.com/songquanpeng/one-api/relay/adaptor"
	"github.com/songquanpeng/one-api/relay/adaptor/openai"
	"github.com/songquanpeng/one-api/relay/meta"
	"github.com/songquanpeng/one-api/relay/model"
	"github.com/songquanpeng/one-api/relay/relaymode"
)

type Adaptor struct{}

func (a *Adaptor) Init(meta *meta.Meta) {
}

func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) {
	var defaultVersion string
	switch meta.ActualModelName {
	case "gemini-2.0-flash-exp",
		"gemini-2.0-flash-thinking-exp":
		defaultVersion = "v1beta"
	default:
		defaultVersion = config.GeminiVersion
	}

	version := helper.AssignOrDefault(meta.Config.APIVersion, defaultVersion)
	action := ""
	switch meta.Mode {
	case relaymode.Embeddings:
		action = "batchEmbedContents"
	default:
		action = "generateContent"
	}

	if meta.IsStream {
		action = "streamGenerateContent?alt=sse"
	}

	return fmt.Sprintf("%s/%s/models/%s:%s", meta.BaseURL, version, meta.ActualModelName, action), nil
}

func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error {
	channelhelper.SetupCommonRequestHeader(c, req, meta)
	req.Header.Set("x-goog-api-key", meta.APIKey)
	// req.URL.Query() returns a copy, so mutate it and write it back;
	// calling Add on the copy alone would silently drop the key parameter
	query := req.URL.Query()
	query.Add("key", meta.APIKey)
	req.URL.RawQuery = query.Encode()
	return nil
}

func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) {
	if request == nil {
		return nil, errors.New("request is nil")
	}
	switch relayMode {
	case relaymode.Embeddings:
		geminiEmbeddingRequest := ConvertEmbeddingRequest(*request)
		return geminiEmbeddingRequest, nil
	default:
		geminiRequest := ConvertRequest(*request)
		return geminiRequest, nil
	}
}

func (a *Adaptor) ConvertImageRequest(_ *gin.Context, request *model.ImageRequest) (any, error) {
	if request == nil {
		return nil, errors.New("request is nil")
	}
	return request, nil
}

func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) {
	return channelhelper.DoRequestHelper(a, c, meta, requestBody)
}

func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
	if meta.IsStream {
		var responseText string
		err, responseText = StreamHandler(c, resp)
		usage = openai.ResponseText2Usage(responseText, meta.ActualModelName, meta.PromptTokens)
	} else {
		switch meta.Mode {
		case relaymode.Embeddings:
			err, usage = EmbeddingHandler(c, resp)
		default:
			err, usage = Handler(c, resp, meta.PromptTokens, meta.ActualModelName)
		}
	}
	return
}

func (a *Adaptor) GetModelList() []string {
	return ModelList
}

func (a *Adaptor) GetChannelName() string {
	return "google gemini"
}
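// Illustrative resolution (assumed values, not from the commit): a streaming
// chat request for "gemini-1.5-pro" under API version "v1beta" resolves to
//   {baseURL}/v1beta/models/gemini-1.5-pro:streamGenerateContent?alt=sse
// while an embeddings request resolves to
//   {baseURL}/{version}/models/{model}:batchEmbedContents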
11
relay/adaptor/gemini/constants.go
Normal file
@@ -0,0 +1,11 @@
package gemini

// https://ai.google.dev/models/gemini

var ModelList = []string{
	"gemini-pro", "gemini-1.0-pro",
	"gemini-1.5-flash", "gemini-1.5-pro",
	"text-embedding-004", "aqa",
	"gemini-2.0-flash-exp",
	"gemini-2.0-flash-thinking-exp",
}
445
relay/adaptor/gemini/main.go
Normal file
@@ -0,0 +1,445 @@
package gemini

import (
	"bufio"
	"encoding/json"
	"fmt"
	"io"
	"net/http"
	"strings"

	"github.com/gin-gonic/gin"

	"github.com/songquanpeng/one-api/common"
	"github.com/songquanpeng/one-api/common/config"
	"github.com/songquanpeng/one-api/common/helper"
	"github.com/songquanpeng/one-api/common/image"
	"github.com/songquanpeng/one-api/common/logger"
	"github.com/songquanpeng/one-api/common/random"
	"github.com/songquanpeng/one-api/common/render"
	"github.com/songquanpeng/one-api/relay/adaptor/openai"
	"github.com/songquanpeng/one-api/relay/constant"
	"github.com/songquanpeng/one-api/relay/model"
)

// https://ai.google.dev/docs/gemini_api_overview?hl=zh-cn

const (
	VisionMaxImageNum = 16
)

var mimeTypeMap = map[string]string{
	"json_object": "application/json",
	"text":        "text/plain",
}

var toolChoiceTypeMap = map[string]string{
	"none":     "NONE",
	"auto":     "AUTO",
	"required": "ANY",
}

// Setting safety to the lowest possible values since Gemini is already powerless enough
func ConvertRequest(textRequest model.GeneralOpenAIRequest) *ChatRequest {
	geminiRequest := ChatRequest{
		Contents: make([]ChatContent, 0, len(textRequest.Messages)),
		SafetySettings: []ChatSafetySettings{
			{
				Category:  "HARM_CATEGORY_HARASSMENT",
				Threshold: config.GeminiSafetySetting,
			},
			{
				Category:  "HARM_CATEGORY_HATE_SPEECH",
				Threshold: config.GeminiSafetySetting,
			},
			{
				Category:  "HARM_CATEGORY_SEXUALLY_EXPLICIT",
				Threshold: config.GeminiSafetySetting,
			},
			{
				Category:  "HARM_CATEGORY_DANGEROUS_CONTENT",
				Threshold: config.GeminiSafetySetting,
			},
			{
				Category:  "HARM_CATEGORY_CIVIC_INTEGRITY",
				Threshold: config.GeminiSafetySetting,
			},
		},
		GenerationConfig: ChatGenerationConfig{
			Temperature:     textRequest.Temperature,
			TopP:            textRequest.TopP,
			MaxOutputTokens: textRequest.MaxTokens,
		},
	}
	if textRequest.ResponseFormat != nil {
		if mimeType, ok := mimeTypeMap[textRequest.ResponseFormat.Type]; ok {
			geminiRequest.GenerationConfig.ResponseMimeType = mimeType
		}
		if textRequest.ResponseFormat.JsonSchema != nil {
			geminiRequest.GenerationConfig.ResponseSchema = textRequest.ResponseFormat.JsonSchema.Schema
			geminiRequest.GenerationConfig.ResponseMimeType = mimeTypeMap["json_object"]
		}
	}
	if textRequest.Tools != nil {
		functions := make([]model.Function, 0, len(textRequest.Tools))
		for _, tool := range textRequest.Tools {
			functions = append(functions, tool.Function)
		}
		geminiRequest.Tools = []ChatTools{
			{
				FunctionDeclarations: functions,
			},
		}
	} else if textRequest.Functions != nil {
		geminiRequest.Tools = []ChatTools{
			{
				FunctionDeclarations: textRequest.Functions,
			},
		}
	}
	if textRequest.ToolChoice != nil {
		geminiRequest.ToolConfig = &ToolConfig{
			FunctionCallingConfig: FunctionCallingConfig{
				Mode: "auto",
			},
		}
		switch mode := textRequest.ToolChoice.(type) {
		case string:
			geminiRequest.ToolConfig.FunctionCallingConfig.Mode = toolChoiceTypeMap[mode]
		case map[string]interface{}:
			geminiRequest.ToolConfig.FunctionCallingConfig.Mode = "ANY"
			if fn, ok := mode["function"].(map[string]interface{}); ok {
				if name, ok := fn["name"].(string); ok {
					geminiRequest.ToolConfig.FunctionCallingConfig.AllowedFunctionNames = []string{name}
				}
			}
		}
	}
	for _, message := range textRequest.Messages {
		content := ChatContent{
			Role: message.Role,
			Parts: []Part{
				{
					Text: message.StringContent(),
				},
			},
		}
		openaiContent := message.ParseContent()
		var parts []Part
		imageNum := 0
		for _, part := range openaiContent {
			if part.Type == model.ContentTypeText {
				parts = append(parts, Part{
					Text: part.Text,
				})
			} else if part.Type == model.ContentTypeImageURL {
				imageNum += 1
				if imageNum > VisionMaxImageNum {
					continue
				}
				mimeType, data, _ := image.GetImageFromUrl(part.ImageURL.Url)
				parts = append(parts, Part{
					InlineData: &InlineData{
						MimeType: mimeType,
						Data:     data,
					},
				})
			}
		}
		content.Parts = parts

		// there's no assistant role in gemini and API shall vomit if Role is not user or model
		if content.Role == "assistant" {
			content.Role = "model"
		}
		// Converting system prompt to SystemInstructions
		if content.Role == "system" {
			geminiRequest.SystemInstruction = &content
			continue
		}
		geminiRequest.Contents = append(geminiRequest.Contents, content)
	}

	return &geminiRequest
}

func ConvertEmbeddingRequest(request model.GeneralOpenAIRequest) *BatchEmbeddingRequest {
	inputs := request.ParseInput()
	requests := make([]EmbeddingRequest, len(inputs))
	// named modelName to avoid shadowing the imported model package
	modelName := fmt.Sprintf("models/%s", request.Model)

	for i, input := range inputs {
		requests[i] = EmbeddingRequest{
			Model: modelName,
			Content: ChatContent{
				Parts: []Part{
					{
						Text: input,
					},
				},
			},
		}
	}

	return &BatchEmbeddingRequest{
		Requests: requests,
	}
}
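// Illustrative request body (assumed input, not from the commit): an OpenAI
// embedding call {"model":"text-embedding-004","input":["a","b"]} converts to
// roughly
//   {"requests":[
//     {"model":"models/text-embedding-004","content":{"parts":[{"text":"a"}]}},
//     {"model":"models/text-embedding-004","content":{"parts":[{"text":"b"}]}}]}
// matching Gemini's batchEmbedContents payload.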
|
||||
type ChatResponse struct {
|
||||
Candidates []ChatCandidate `json:"candidates"`
|
||||
PromptFeedback ChatPromptFeedback `json:"promptFeedback"`
|
||||
}
|
||||
|
||||
func (g *ChatResponse) GetResponseText() string {
|
||||
if g == nil {
|
||||
return ""
|
||||
}
|
||||
var builder strings.Builder
|
||||
for _, candidate := range g.Candidates {
|
||||
for idx, part := range candidate.Content.Parts {
|
||||
if idx > 0 {
|
||||
builder.WriteString("\n")
|
||||
}
|
||||
builder.WriteString(part.Text)
|
||||
}
|
||||
}
|
||||
return builder.String()
|
||||
}
|
||||
|
||||
type ChatCandidate struct {
|
||||
Content ChatContent `json:"content"`
|
||||
FinishReason string `json:"finishReason"`
|
||||
Index int64 `json:"index"`
|
||||
SafetyRatings []ChatSafetyRating `json:"safetyRatings"`
|
||||
}
|
||||
|
||||
type ChatSafetyRating struct {
|
||||
Category string `json:"category"`
|
||||
Probability string `json:"probability"`
|
||||
}
|
||||
|
||||
type ChatPromptFeedback struct {
|
||||
SafetyRatings []ChatSafetyRating `json:"safetyRatings"`
|
||||
}
|
||||
|
||||
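// getToolCalls maps a Gemini functionCall part onto OpenAI's tool-call shape.
// Only the first part of the candidate is inspected, so a response that
// carried several function calls would surface just one of them.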
func getToolCalls(candidate *ChatCandidate) []model.Tool {
	var toolCalls []model.Tool

	item := candidate.Content.Parts[0]
	if item.FunctionCall == nil {
		return toolCalls
	}
	argsBytes, err := json.Marshal(item.FunctionCall.Arguments)
	if err != nil {
		logger.FatalLog("getToolCalls failed: " + err.Error())
		return toolCalls
	}
	toolCall := model.Tool{
		Id:   fmt.Sprintf("call_%s", random.GetUUID()),
		Type: "function",
		Function: model.Function{
			Arguments: string(argsBytes),
			Name:      item.FunctionCall.FunctionName,
		},
	}
	toolCalls = append(toolCalls, toolCall)
	return toolCalls
}

func responseGeminiChat2OpenAI(response *ChatResponse) *openai.TextResponse {
	fullTextResponse := openai.TextResponse{
		Id:      fmt.Sprintf("chatcmpl-%s", random.GetUUID()),
		Object:  "chat.completion",
		Created: helper.GetTimestamp(),
		Choices: make([]openai.TextResponseChoice, 0, len(response.Candidates)),
	}
	for i, candidate := range response.Candidates {
		choice := openai.TextResponseChoice{
			Index: i,
			Message: model.Message{
				Role: "assistant",
			},
			FinishReason: constant.StopFinishReason,
		}
		if len(candidate.Content.Parts) > 0 {
			if candidate.Content.Parts[0].FunctionCall != nil {
				choice.Message.ToolCalls = getToolCalls(&candidate)
			} else {
				var builder strings.Builder
				for idx, part := range candidate.Content.Parts {
					if idx > 0 {
						builder.WriteString("\n")
					}
					builder.WriteString(part.Text)
				}
				choice.Message.Content = builder.String()
			}
		} else {
			choice.Message.Content = ""
			choice.FinishReason = candidate.FinishReason
		}
		fullTextResponse.Choices = append(fullTextResponse.Choices, choice)
	}
	return &fullTextResponse
}

func streamResponseGeminiChat2OpenAI(geminiResponse *ChatResponse) *openai.ChatCompletionsStreamResponse {
	var choice openai.ChatCompletionsStreamResponseChoice
	choice.Delta.Content = geminiResponse.GetResponseText()
	//choice.FinishReason = &constant.StopFinishReason
	var response openai.ChatCompletionsStreamResponse
	response.Id = fmt.Sprintf("chatcmpl-%s", random.GetUUID())
	response.Created = helper.GetTimestamp()
	response.Object = "chat.completion.chunk"
	response.Model = "gemini"
	response.Choices = []openai.ChatCompletionsStreamResponseChoice{choice}
	return &response
}

func embeddingResponseGemini2OpenAI(response *EmbeddingResponse) *openai.EmbeddingResponse {
	openAIEmbeddingResponse := openai.EmbeddingResponse{
		Object: "list",
		Data:   make([]openai.EmbeddingResponseItem, 0, len(response.Embeddings)),
		Model:  "gemini-embedding",
		Usage:  model.Usage{TotalTokens: 0},
	}
	for _, item := range response.Embeddings {
		openAIEmbeddingResponse.Data = append(openAIEmbeddingResponse.Data, openai.EmbeddingResponseItem{
			Object:    `embedding`,
			Index:     0,
			Embedding: item.Values,
		})
	}
	return &openAIEmbeddingResponse
}
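// StreamHandler relays a Gemini SSE stream to the client as OpenAI-style
// chunks. Each useful line looks like `data: {...}`; the prefix is stripped
// (plus a stray trailing quote, if any) before the JSON payload is decoded,
// and undecodable lines are logged and skipped rather than aborting the stream.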
func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, string) {
	responseText := ""
	scanner := bufio.NewScanner(resp.Body)
	scanner.Split(bufio.ScanLines)

	common.SetEventStreamHeaders(c)

	for scanner.Scan() {
		data := scanner.Text()
		data = strings.TrimSpace(data)
		if !strings.HasPrefix(data, "data: ") {
			continue
		}
		data = strings.TrimPrefix(data, "data: ")
		data = strings.TrimSuffix(data, "\"")

		var geminiResponse ChatResponse
		err := json.Unmarshal([]byte(data), &geminiResponse)
		if err != nil {
			logger.SysError("error unmarshalling stream response: " + err.Error())
			continue
		}

		response := streamResponseGeminiChat2OpenAI(&geminiResponse)
		if response == nil {
			continue
		}

		responseText += response.Choices[0].Delta.StringContent()

		err = render.ObjectData(c, response)
		if err != nil {
			logger.SysError(err.Error())
		}
	}

	if err := scanner.Err(); err != nil {
		logger.SysError("error reading stream: " + err.Error())
	}

	render.Done(c)

	err := resp.Body.Close()
	if err != nil {
		return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), ""
	}

	return nil, responseText
}

func Handler(c *gin.Context, resp *http.Response, promptTokens int, modelName string) (*model.ErrorWithStatusCode, *model.Usage) {
	responseBody, err := io.ReadAll(resp.Body)
	if err != nil {
		return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
	}
	err = resp.Body.Close()
	if err != nil {
		return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
	}
	var geminiResponse ChatResponse
	err = json.Unmarshal(responseBody, &geminiResponse)
	if err != nil {
		return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
	}
	if len(geminiResponse.Candidates) == 0 {
		return &model.ErrorWithStatusCode{
			Error: model.Error{
				Message: "No candidates returned",
				Type:    "server_error",
				Param:   "",
				Code:    500,
			},
			StatusCode: resp.StatusCode,
		}, nil
	}
	fullTextResponse := responseGeminiChat2OpenAI(&geminiResponse)
	fullTextResponse.Model = modelName
	completionTokens := openai.CountTokenText(geminiResponse.GetResponseText(), modelName)
	usage := model.Usage{
		PromptTokens:     promptTokens,
		CompletionTokens: completionTokens,
		TotalTokens:      promptTokens + completionTokens,
	}
	fullTextResponse.Usage = usage
	jsonResponse, err := json.Marshal(fullTextResponse)
	if err != nil {
		return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
	}
	c.Writer.Header().Set("Content-Type", "application/json")
	c.Writer.WriteHeader(resp.StatusCode)
	_, err = c.Writer.Write(jsonResponse)
	return nil, &usage
}

func EmbeddingHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
	var geminiEmbeddingResponse EmbeddingResponse
	responseBody, err := io.ReadAll(resp.Body)
	if err != nil {
		return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
	}
	err = resp.Body.Close()
	if err != nil {
		return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
	}
	err = json.Unmarshal(responseBody, &geminiEmbeddingResponse)
	if err != nil {
		return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
	}
	if geminiEmbeddingResponse.Error != nil {
		return &model.ErrorWithStatusCode{
			Error: model.Error{
				Message: geminiEmbeddingResponse.Error.Message,
				Type:    "gemini_error",
				Param:   "",
				Code:    geminiEmbeddingResponse.Error.Code,
			},
			StatusCode: resp.StatusCode,
		}, nil
	}
	fullTextResponse := embeddingResponseGemini2OpenAI(&geminiEmbeddingResponse)
	jsonResponse, err := json.Marshal(fullTextResponse)
	if err != nil {
		return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
	}
	c.Writer.Header().Set("Content-Type", "application/json")
	c.Writer.WriteHeader(resp.StatusCode)
	_, err = c.Writer.Write(jsonResponse)
	return nil, &fullTextResponse.Usage
}
87
relay/adaptor/gemini/model.go
Normal file
@@ -0,0 +1,87 @@
package gemini

type ChatRequest struct {
	Contents          []ChatContent        `json:"contents"`
	SystemInstruction *ChatContent         `json:"system_instruction,omitempty"`
	SafetySettings    []ChatSafetySettings `json:"safety_settings,omitempty"`
	GenerationConfig  ChatGenerationConfig `json:"generation_config,omitempty"`
	Tools             []ChatTools          `json:"tools,omitempty"`
	ToolConfig        *ToolConfig          `json:"tool_config,omitempty"`
}

type EmbeddingRequest struct {
	Model                string      `json:"model"`
	Content              ChatContent `json:"content"`
	TaskType             string      `json:"taskType,omitempty"`
	Title                string      `json:"title,omitempty"`
	OutputDimensionality int         `json:"outputDimensionality,omitempty"`
}

type BatchEmbeddingRequest struct {
	Requests []EmbeddingRequest `json:"requests"`
}

type EmbeddingData struct {
	Values []float64 `json:"values"`
}

type EmbeddingResponse struct {
	Embeddings []EmbeddingData `json:"embeddings"`
	Error      *Error          `json:"error,omitempty"`
}

type Error struct {
	Code    int    `json:"code,omitempty"`
	Message string `json:"message,omitempty"`
	Status  string `json:"status,omitempty"`
}

type InlineData struct {
	MimeType string `json:"mimeType"`
	Data     string `json:"data"`
}

type FunctionCall struct {
	FunctionName string `json:"name"`
	Arguments    any    `json:"args"`
}

type Part struct {
	Text         string        `json:"text,omitempty"`
	InlineData   *InlineData   `json:"inlineData,omitempty"`
	FunctionCall *FunctionCall `json:"functionCall,omitempty"`
}

type ChatContent struct {
	Role  string `json:"role,omitempty"`
	Parts []Part `json:"parts"`
}

type ChatSafetySettings struct {
	Category  string `json:"category"`
	Threshold string `json:"threshold"`
}

type ChatTools struct {
	FunctionDeclarations any `json:"function_declarations,omitempty"`
}

type ChatGenerationConfig struct {
	ResponseMimeType string   `json:"responseMimeType,omitempty"`
	ResponseSchema   any      `json:"responseSchema,omitempty"`
	Temperature      *float64 `json:"temperature,omitempty"`
	TopP             *float64 `json:"topP,omitempty"`
	TopK             float64  `json:"topK,omitempty"`
	MaxOutputTokens  int      `json:"maxOutputTokens,omitempty"`
	CandidateCount   int      `json:"candidateCount,omitempty"`
	StopSequences    []string `json:"stopSequences,omitempty"`
}

type FunctionCallingConfig struct {
	Mode                 string   `json:"mode,omitempty"`
	AllowedFunctionNames []string `json:"allowed_function_names,omitempty"`
}

type ToolConfig struct {
	FunctionCallingConfig FunctionCallingConfig `json:"function_calling_config"`
}
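// For orientation, a minimal ChatRequest built from these types marshals to
// JSON of roughly this shape (values are illustrative, not taken from a real
// call):
//
//	{
//	  "contents": [
//	    {"role": "user", "parts": [{"text": "Hello"}]}
//	  ],
//	  "tool_config": {"function_calling_config": {"mode": "auto"}}
//	}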
27
relay/adaptor/groq/constants.go
Normal file
@@ -0,0 +1,27 @@
package groq

// https://console.groq.com/docs/models

var ModelList = []string{
	"gemma-7b-it",
	"gemma2-9b-it",
	"llama-3.1-70b-versatile",
	"llama-3.1-8b-instant",
	"llama-3.2-11b-text-preview",
	"llama-3.2-11b-vision-preview",
	"llama-3.2-1b-preview",
	"llama-3.2-3b-preview",
	"llama-3.2-11b-vision-preview",
	"llama-3.2-90b-text-preview",
	"llama-3.2-90b-vision-preview",
	"llama-guard-3-8b",
	"llama3-70b-8192",
	"llama3-8b-8192",
	"llama3-groq-70b-8192-tool-use-preview",
	"llama3-groq-8b-8192-tool-use-preview",
	"llava-v1.5-7b-4096-preview",
	"mixtral-8x7b-32768",
	"distil-whisper-large-v3-en",
	"whisper-large-v3",
	"whisper-large-v3-turbo",
}
21
relay/adaptor/interface.go
Normal file
@@ -0,0 +1,21 @@
package adaptor

import (
	"github.com/gin-gonic/gin"
	"github.com/songquanpeng/one-api/relay/meta"
	"github.com/songquanpeng/one-api/relay/model"
	"io"
	"net/http"
)

type Adaptor interface {
	Init(meta *meta.Meta)
	GetRequestURL(meta *meta.Meta) (string, error)
	SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error
	ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error)
	ConvertImageRequest(c *gin.Context, request *model.ImageRequest) (any, error)
	DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error)
	DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode)
	GetModelList() []string
	GetChannelName() string
}
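// The relay controller is expected to drive an Adaptor roughly in this order:
// Init -> GetRequestURL -> ConvertRequest -> SetupRequestHeader -> DoRequest
// -> DoResponse. (The ordering is inferred from the handlers in this change,
// not something the interface itself enforces.)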
9
relay/adaptor/lingyiwanwu/constants.go
Normal file
@@ -0,0 +1,9 @@
package lingyiwanwu

// https://platform.lingyiwanwu.com/docs

var ModelList = []string{
	"yi-34b-chat-0205",
	"yi-34b-chat-200k",
	"yi-vl-plus",
}
11
relay/adaptor/minimax/constants.go
Normal file
@@ -0,0 +1,11 @@
package minimax

// https://www.minimaxi.com/document/guides/chat-model/V2?id=65e0736ab2845de20908e2dd

var ModelList = []string{
	"abab6.5-chat",
	"abab6.5s-chat",
	"abab6-chat",
	"abab5.5-chat",
	"abab5.5s-chat",
}
14
relay/adaptor/minimax/main.go
Normal file
@@ -0,0 +1,14 @@
package minimax

import (
	"fmt"
	"github.com/songquanpeng/one-api/relay/meta"
	"github.com/songquanpeng/one-api/relay/relaymode"
)

func GetRequestURL(meta *meta.Meta) (string, error) {
	if meta.Mode == relaymode.ChatCompletions {
		return fmt.Sprintf("%s/v1/text/chatcompletion_v2", meta.BaseURL), nil
	}
	return "", fmt.Errorf("unsupported relay mode %d for minimax", meta.Mode)
}
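// Example (base URL invented): with BaseURL https://api.minimax.example, a
// chat completion is routed to
// https://api.minimax.example/v1/text/chatcompletion_v2; every other relay
// mode is rejected with an error.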
10
relay/adaptor/mistral/constants.go
Normal file
@@ -0,0 +1,10 @@
package mistral

var ModelList = []string{
	"open-mistral-7b",
	"open-mixtral-8x7b",
	"mistral-small-latest",
	"mistral-medium-latest",
	"mistral-large-latest",
	"mistral-embed",
}
7
relay/adaptor/moonshot/constants.go
Normal file
@@ -0,0 +1,7 @@
package moonshot

var ModelList = []string{
	"moonshot-v1-8k",
	"moonshot-v1-32k",
	"moonshot-v1-128k",
}
19
relay/adaptor/novita/constants.go
Normal file
@@ -0,0 +1,19 @@
package novita

// https://novita.ai/llm-api

var ModelList = []string{
	"meta-llama/llama-3-8b-instruct",
	"meta-llama/llama-3-70b-instruct",
	"nousresearch/hermes-2-pro-llama-3-8b",
	"nousresearch/nous-hermes-llama2-13b",
	"mistralai/mistral-7b-instruct",
	"cognitivecomputations/dolphin-mixtral-8x22b",
	"sao10k/l3-70b-euryale-v2.1",
	"sophosympatheia/midnight-rose-70b",
	"gryphe/mythomax-l2-13b",
	"Nous-Hermes-2-Mixtral-8x7B-DPO",
	"lzlv_70b",
	"teknium/openhermes-2.5-mistral-7b",
	"microsoft/wizardlm-2-8x22b",
}
15
relay/adaptor/novita/main.go
Normal file
@@ -0,0 +1,15 @@
package novita

import (
	"fmt"

	"github.com/songquanpeng/one-api/relay/meta"
	"github.com/songquanpeng/one-api/relay/relaymode"
)

func GetRequestURL(meta *meta.Meta) (string, error) {
	if meta.Mode == relaymode.ChatCompletions {
		return fmt.Sprintf("%s/chat/completions", meta.BaseURL), nil
	}
	return "", fmt.Errorf("unsupported relay mode %d for novita", meta.Mode)
}
82
relay/adaptor/ollama/adaptor.go
Normal file
@@ -0,0 +1,82 @@
package ollama

import (
	"fmt"
	"io"
	"net/http"

	"github.com/gin-gonic/gin"
	"github.com/pkg/errors"
	"github.com/songquanpeng/one-api/relay/adaptor"
	"github.com/songquanpeng/one-api/relay/meta"
	"github.com/songquanpeng/one-api/relay/model"
	"github.com/songquanpeng/one-api/relay/relaymode"
)

type Adaptor struct {
}

func (a *Adaptor) Init(meta *meta.Meta) {
}

func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) {
	// https://github.com/ollama/ollama/blob/main/docs/api.md
	fullRequestURL := fmt.Sprintf("%s/api/chat", meta.BaseURL)
	if meta.Mode == relaymode.Embeddings {
		fullRequestURL = fmt.Sprintf("%s/api/embed", meta.BaseURL)
	}
	return fullRequestURL, nil
}
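// Note: /api/chat and /api/embed are the only two Ollama endpoints this
// adaptor targets; the older single-input /api/embeddings endpoint is not
// used here.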
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error {
	adaptor.SetupCommonRequestHeader(c, req, meta)
	req.Header.Set("Authorization", "Bearer "+meta.APIKey)
	return nil
}

func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) {
	if request == nil {
		return nil, errors.New("request is nil")
	}
	switch relayMode {
	case relaymode.Embeddings:
		ollamaEmbeddingRequest := ConvertEmbeddingRequest(*request)
		return ollamaEmbeddingRequest, nil
	default:
		return ConvertRequest(*request), nil
	}
}

func (a *Adaptor) ConvertImageRequest(_ *gin.Context, request *model.ImageRequest) (any, error) {
	if request == nil {
		return nil, errors.New("request is nil")
	}
	return request, nil
}

func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) {
	return adaptor.DoRequestHelper(a, c, meta, requestBody)
}

func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
	if meta.IsStream {
		err, usage = StreamHandler(c, resp)
	} else {
		switch meta.Mode {
		case relaymode.Embeddings:
			err, usage = EmbeddingHandler(c, resp)
		default:
			err, usage = Handler(c, resp)
		}
	}
	return
}

func (a *Adaptor) GetModelList() []string {
	return ModelList
}

func (a *Adaptor) GetChannelName() string {
	return "ollama"
}
11
relay/adaptor/ollama/constants.go
Normal file
@@ -0,0 +1,11 @@
package ollama

var ModelList = []string{
	"codellama:7b-instruct",
	"llama2:7b",
	"llama2:latest",
	"llama3:latest",
	"phi3:latest",
	"qwen:0.5b-chat",
	"qwen:7b",
}
264
relay/adaptor/ollama/main.go
Normal file
@@ -0,0 +1,264 @@
package ollama

import (
	"bufio"
	"context"
	"encoding/json"
	"fmt"
	"io"
	"net/http"
	"strings"

	"github.com/gin-gonic/gin"
	"github.com/songquanpeng/one-api/common"
	"github.com/songquanpeng/one-api/common/helper"
	"github.com/songquanpeng/one-api/common/image"
	"github.com/songquanpeng/one-api/common/logger"
	"github.com/songquanpeng/one-api/common/random"
	"github.com/songquanpeng/one-api/common/render"
	"github.com/songquanpeng/one-api/relay/adaptor/openai"
	"github.com/songquanpeng/one-api/relay/constant"
	"github.com/songquanpeng/one-api/relay/model"
)

func ConvertRequest(request model.GeneralOpenAIRequest) *ChatRequest {
	ollamaRequest := ChatRequest{
		Model: request.Model,
		Options: &Options{
			Seed:             int(request.Seed),
			Temperature:      request.Temperature,
			TopP:             request.TopP,
			FrequencyPenalty: request.FrequencyPenalty,
			PresencePenalty:  request.PresencePenalty,
			NumPredict:       request.MaxTokens,
			NumCtx:           request.NumCtx,
		},
		Stream: request.Stream,
	}
	for _, message := range request.Messages {
		openaiContent := message.ParseContent()
		var imageUrls []string
		var contentText string
		for _, part := range openaiContent {
			switch part.Type {
			case model.ContentTypeText:
				contentText = part.Text
			case model.ContentTypeImageURL:
				_, data, _ := image.GetImageFromUrl(part.ImageURL.Url)
				imageUrls = append(imageUrls, data)
			}
		}
		ollamaRequest.Messages = append(ollamaRequest.Messages, Message{
			Role:    message.Role,
			Content: contentText,
			Images:  imageUrls,
		})
	}
	return &ollamaRequest
}

func responseOllama2OpenAI(response *ChatResponse) *openai.TextResponse {
	choice := openai.TextResponseChoice{
		Index: 0,
		Message: model.Message{
			Role:    response.Message.Role,
			Content: response.Message.Content,
		},
	}
	if response.Done {
		choice.FinishReason = "stop"
	}
	fullTextResponse := openai.TextResponse{
		Id:      fmt.Sprintf("chatcmpl-%s", random.GetUUID()),
		Model:   response.Model,
		Object:  "chat.completion",
		Created: helper.GetTimestamp(),
		Choices: []openai.TextResponseChoice{choice},
		Usage: model.Usage{
			PromptTokens:     response.PromptEvalCount,
			CompletionTokens: response.EvalCount,
			TotalTokens:      response.PromptEvalCount + response.EvalCount,
		},
	}
	return &fullTextResponse
}

func streamResponseOllama2OpenAI(ollamaResponse *ChatResponse) *openai.ChatCompletionsStreamResponse {
	var choice openai.ChatCompletionsStreamResponseChoice
	choice.Delta.Role = ollamaResponse.Message.Role
	choice.Delta.Content = ollamaResponse.Message.Content
	if ollamaResponse.Done {
		choice.FinishReason = &constant.StopFinishReason
	}
	response := openai.ChatCompletionsStreamResponse{
		Id:      fmt.Sprintf("chatcmpl-%s", random.GetUUID()),
		Object:  "chat.completion.chunk",
		Created: helper.GetTimestamp(),
		Model:   ollamaResponse.Model,
		Choices: []openai.ChatCompletionsStreamResponseChoice{choice},
	}
	return &response
}
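// StreamHandler reads Ollama's streaming output, which is a sequence of JSON
// objects separated by "}\n" rather than SSE "data:" lines. The custom split
// function below cuts the body at each "}\n" boundary, and the prefix fix-up
// in the loop moves a stray leading brace back to the end of its chunk. Note
// this is a naive scan: a JSON string value that itself contained "}\n" would
// be mis-split.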
func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
	var usage model.Usage
	scanner := bufio.NewScanner(resp.Body)
	scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
		if atEOF && len(data) == 0 {
			return 0, nil, nil
		}
		if i := strings.Index(string(data), "}\n"); i >= 0 {
			return i + 2, data[0 : i+1], nil
		}
		if atEOF {
			return len(data), data, nil
		}
		return 0, nil, nil
	})

	common.SetEventStreamHeaders(c)

	for scanner.Scan() {
		data := scanner.Text()
		if strings.HasPrefix(data, "}") {
			data = strings.TrimPrefix(data, "}") + "}"
		}

		var ollamaResponse ChatResponse
		err := json.Unmarshal([]byte(data), &ollamaResponse)
		if err != nil {
			logger.SysError("error unmarshalling stream response: " + err.Error())
			continue
		}

		if ollamaResponse.EvalCount != 0 {
			usage.PromptTokens = ollamaResponse.PromptEvalCount
			usage.CompletionTokens = ollamaResponse.EvalCount
			usage.TotalTokens = ollamaResponse.PromptEvalCount + ollamaResponse.EvalCount
		}

		response := streamResponseOllama2OpenAI(&ollamaResponse)
		err = render.ObjectData(c, response)
		if err != nil {
			logger.SysError(err.Error())
		}
	}

	if err := scanner.Err(); err != nil {
		logger.SysError("error reading stream: " + err.Error())
	}

	render.Done(c)

	err := resp.Body.Close()
	if err != nil {
		return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
	}

	return nil, &usage
}
func ConvertEmbeddingRequest(request model.GeneralOpenAIRequest) *EmbeddingRequest {
	return &EmbeddingRequest{
		Model: request.Model,
		Input: request.ParseInput(),
		Options: &Options{
			Seed:             int(request.Seed),
			Temperature:      request.Temperature,
			TopP:             request.TopP,
			FrequencyPenalty: request.FrequencyPenalty,
			PresencePenalty:  request.PresencePenalty,
		},
	}
}

func EmbeddingHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
	var ollamaResponse EmbeddingResponse
	err := json.NewDecoder(resp.Body).Decode(&ollamaResponse)
	if err != nil {
		return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
	}

	err = resp.Body.Close()
	if err != nil {
		return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
	}

	if ollamaResponse.Error != "" {
		return &model.ErrorWithStatusCode{
			Error: model.Error{
				Message: ollamaResponse.Error,
				Type:    "ollama_error",
				Param:   "",
				Code:    "ollama_error",
			},
			StatusCode: resp.StatusCode,
		}, nil
	}

	fullTextResponse := embeddingResponseOllama2OpenAI(&ollamaResponse)
	jsonResponse, err := json.Marshal(fullTextResponse)
	if err != nil {
		return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
	}
	c.Writer.Header().Set("Content-Type", "application/json")
	c.Writer.WriteHeader(resp.StatusCode)
	_, err = c.Writer.Write(jsonResponse)
	return nil, &fullTextResponse.Usage
}

func embeddingResponseOllama2OpenAI(response *EmbeddingResponse) *openai.EmbeddingResponse {
	openAIEmbeddingResponse := openai.EmbeddingResponse{
		Object: "list",
		Data:   make([]openai.EmbeddingResponseItem, 0, 1),
		Model:  response.Model,
		Usage:  model.Usage{TotalTokens: 0},
	}

	for i, embedding := range response.Embeddings {
		openAIEmbeddingResponse.Data = append(openAIEmbeddingResponse.Data, openai.EmbeddingResponseItem{
			Object:    `embedding`,
			Index:     i,
			Embedding: embedding,
		})
	}
	return &openAIEmbeddingResponse
}

func Handler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
	ctx := context.TODO()
	var ollamaResponse ChatResponse
	responseBody, err := io.ReadAll(resp.Body)
	if err != nil {
		return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
	}
	logger.Debugf(ctx, "ollama response: %s", string(responseBody))
	err = resp.Body.Close()
	if err != nil {
		return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
	}
	err = json.Unmarshal(responseBody, &ollamaResponse)
	if err != nil {
		return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
	}
	if ollamaResponse.Error != "" {
		return &model.ErrorWithStatusCode{
			Error: model.Error{
				Message: ollamaResponse.Error,
				Type:    "ollama_error",
				Param:   "",
				Code:    "ollama_error",
			},
			StatusCode: resp.StatusCode,
		}, nil
	}
	fullTextResponse := responseOllama2OpenAI(&ollamaResponse)
	jsonResponse, err := json.Marshal(fullTextResponse)
	if err != nil {
		return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
	}
	c.Writer.Header().Set("Content-Type", "application/json")
	c.Writer.WriteHeader(resp.StatusCode)
	_, err = c.Writer.Write(jsonResponse)
	return nil, &fullTextResponse.Usage
}
53
relay/adaptor/ollama/model.go
Normal file
@@ -0,0 +1,53 @@
package ollama

type Options struct {
	Seed             int      `json:"seed,omitempty"`
	Temperature      *float64 `json:"temperature,omitempty"`
	TopK             int      `json:"top_k,omitempty"`
	TopP             *float64 `json:"top_p,omitempty"`
	FrequencyPenalty *float64 `json:"frequency_penalty,omitempty"`
	PresencePenalty  *float64 `json:"presence_penalty,omitempty"`
	NumPredict       int      `json:"num_predict,omitempty"`
	NumCtx           int      `json:"num_ctx,omitempty"`
}

type Message struct {
	Role    string   `json:"role,omitempty"`
	Content string   `json:"content,omitempty"`
	Images  []string `json:"images,omitempty"`
}

type ChatRequest struct {
	Model    string    `json:"model,omitempty"`
	Messages []Message `json:"messages,omitempty"`
	Stream   bool      `json:"stream"`
	Options  *Options  `json:"options,omitempty"`
}

type ChatResponse struct {
	Model           string  `json:"model,omitempty"`
	CreatedAt       string  `json:"created_at,omitempty"`
	Message         Message `json:"message,omitempty"`
	Response        string  `json:"response,omitempty"` // for stream response
	Done            bool    `json:"done,omitempty"`
	TotalDuration   int     `json:"total_duration,omitempty"`
	LoadDuration    int     `json:"load_duration,omitempty"`
	PromptEvalCount int     `json:"prompt_eval_count,omitempty"`
	EvalCount       int     `json:"eval_count,omitempty"`
	EvalDuration    int     `json:"eval_duration,omitempty"`
	Error           string  `json:"error,omitempty"`
}

type EmbeddingRequest struct {
	Model string   `json:"model"`
	Input []string `json:"input"`
	// Truncate bool `json:"truncate,omitempty"`
	Options *Options `json:"options,omitempty"`
	// KeepAlive string `json:"keep_alive,omitempty"`
}

type EmbeddingResponse struct {
	Error      string      `json:"error,omitempty"`
	Model      string      `json:"model"`
	Embeddings [][]float64 `json:"embeddings"`
}
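// EmbeddingResponse mirrors Ollama's batched /api/embed reply, where
// "embeddings" holds one vector per input; EmbeddingHandler in main.go maps
// each vector straight onto an OpenAI embedding item.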
155
relay/adaptor/openai/adaptor.go
Normal file
@@ -0,0 +1,155 @@
package openai

import (
	"fmt"
	"io"
	"net/http"
	"strings"

	"github.com/gin-gonic/gin"
	"github.com/pkg/errors"
	"github.com/songquanpeng/one-api/common/config"
	"github.com/songquanpeng/one-api/relay/adaptor"
	"github.com/songquanpeng/one-api/relay/adaptor/doubao"
	"github.com/songquanpeng/one-api/relay/adaptor/minimax"
	"github.com/songquanpeng/one-api/relay/adaptor/novita"
	"github.com/songquanpeng/one-api/relay/channeltype"
	"github.com/songquanpeng/one-api/relay/meta"
	"github.com/songquanpeng/one-api/relay/model"
	"github.com/songquanpeng/one-api/relay/relaymode"
)

type Adaptor struct {
	ChannelType int
}

func (a *Adaptor) Init(meta *meta.Meta) {
	a.ChannelType = meta.ChannelType
}

func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) {
	switch meta.ChannelType {
	case channeltype.Azure:
		if meta.Mode == relaymode.ImagesGenerations {
			// https://learn.microsoft.com/en-us/azure/ai-services/openai/dall-e-quickstart?tabs=dalle3%2Ccommand-line&pivots=rest-api
			// https://{resource_name}.openai.azure.com/openai/deployments/dall-e-3/images/generations?api-version=2024-03-01-preview
			fullRequestURL := fmt.Sprintf("%s/openai/deployments/%s/images/generations?api-version=%s", meta.BaseURL, meta.ActualModelName, meta.Config.APIVersion)
			return fullRequestURL, nil
		}

		// https://learn.microsoft.com/en-us/azure/cognitive-services/openai/chatgpt-quickstart?pivots=rest-api&tabs=command-line#rest-api
		requestURL := strings.Split(meta.RequestURLPath, "?")[0]
		requestURL = fmt.Sprintf("%s?api-version=%s", requestURL, meta.Config.APIVersion)
		task := strings.TrimPrefix(requestURL, "/v1/")
		model_ := meta.ActualModelName
		model_ = strings.Replace(model_, ".", "", -1)
		// https://github.com/songquanpeng/one-api/issues/1191
		// {your endpoint}/openai/deployments/{your azure_model}/chat/completions?api-version={api_version}
		requestURL = fmt.Sprintf("/openai/deployments/%s/%s", model_, task)
		return GetFullRequestURL(meta.BaseURL, requestURL, meta.ChannelType), nil
	case channeltype.Minimax:
		return minimax.GetRequestURL(meta)
	case channeltype.Doubao:
		return doubao.GetRequestURL(meta)
	case channeltype.Novita:
		return novita.GetRequestURL(meta)
	default:
		return GetFullRequestURL(meta.BaseURL, meta.RequestURLPath, meta.ChannelType), nil
	}
}
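// Worked Azure example (resource name and version invented): with BaseURL
// https://myres.openai.azure.com, ActualModelName gpt-3.5-turbo and
// APIVersion 2024-02-01, a request to /v1/chat/completions becomes
// https://myres.openai.azure.com/openai/deployments/gpt-35-turbo/chat/completions?api-version=2024-02-01
// (the dot is stripped from the deployment name, per issue #1191 above).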
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error {
	adaptor.SetupCommonRequestHeader(c, req, meta)
	if meta.ChannelType == channeltype.Azure {
		req.Header.Set("api-key", meta.APIKey)
		return nil
	}
	req.Header.Set("Authorization", "Bearer "+meta.APIKey)
	if meta.ChannelType == channeltype.OpenRouter {
		req.Header.Set("HTTP-Referer", "https://github.com/songquanpeng/one-api")
		req.Header.Set("X-Title", "One API")
	}
	return nil
}
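// ConvertRequest applies OpenAI-specific fix-ups before forwarding: it forces
// stream_options.include_usage when usage enforcement is enabled, and for
// o1-family models it zeroes max_tokens and strips system messages, since
// those models accept neither.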
func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) {
	if request == nil {
		return nil, errors.New("request is nil")
	}

	if config.EnforceIncludeUsage && request.Stream {
		// always return usage in stream mode
		if request.StreamOptions == nil {
			request.StreamOptions = &model.StreamOptions{}
		}
		request.StreamOptions.IncludeUsage = true
	}

	// o1/o1-mini/o1-preview do not support system prompt and max_tokens
	if strings.HasPrefix(request.Model, "o1") {
		request.MaxTokens = 0
		request.Messages = func(raw []model.Message) (filtered []model.Message) {
			for i := range raw {
				if raw[i].Role != "system" {
					filtered = append(filtered, raw[i])
				}
			}

			return
		}(request.Messages)
	}

	if request.Stream && strings.HasPrefix(request.Model, "gpt-4o-audio") {
		// TODO: Since it is not clear how to implement billing in stream mode,
		// it is temporarily not supported
		return nil, errors.New("stream mode is not supported for gpt-4o-audio")
	}

	return request, nil
}

func (a *Adaptor) ConvertImageRequest(_ *gin.Context, request *model.ImageRequest) (any, error) {
	if request == nil {
		return nil, errors.New("request is nil")
	}
	return request, nil
}

func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) {
	return adaptor.DoRequestHelper(a, c, meta, requestBody)
}

func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
	if meta.IsStream {
		var responseText string
		err, responseText, usage = StreamHandler(c, resp, meta.Mode)
		if usage == nil || usage.TotalTokens == 0 {
			usage = ResponseText2Usage(responseText, meta.ActualModelName, meta.PromptTokens)
		}
		if usage.TotalTokens != 0 && usage.PromptTokens == 0 { // some channels don't return prompt tokens & completion tokens
			usage.PromptTokens = meta.PromptTokens
			usage.CompletionTokens = usage.TotalTokens - meta.PromptTokens
		}
	} else {
		switch meta.Mode {
		case relaymode.ImagesGenerations:
			err, _ = ImageHandler(c, resp)
		case relaymode.ImagesEdits:
			err, _ = ImagesEditsHandler(c, resp)
		default:
			err, usage = Handler(c, resp, meta.PromptTokens, meta.ActualModelName)
		}
	}

	return
}

func (a *Adaptor) GetModelList() []string {
	_, modelList := GetCompatibleChannelMeta(a.ChannelType)
	return modelList
}

func (a *Adaptor) GetChannelName() string {
	channelName, _ := GetCompatibleChannelMeta(a.ChannelType)
	return channelName
}
74
relay/adaptor/openai/compatible.go
Normal file
@@ -0,0 +1,74 @@
package openai

import (
	"github.com/songquanpeng/one-api/relay/adaptor/ai360"
	"github.com/songquanpeng/one-api/relay/adaptor/baichuan"
	"github.com/songquanpeng/one-api/relay/adaptor/deepseek"
	"github.com/songquanpeng/one-api/relay/adaptor/doubao"
	"github.com/songquanpeng/one-api/relay/adaptor/groq"
	"github.com/songquanpeng/one-api/relay/adaptor/lingyiwanwu"
	"github.com/songquanpeng/one-api/relay/adaptor/minimax"
	"github.com/songquanpeng/one-api/relay/adaptor/mistral"
	"github.com/songquanpeng/one-api/relay/adaptor/moonshot"
	"github.com/songquanpeng/one-api/relay/adaptor/novita"
	"github.com/songquanpeng/one-api/relay/adaptor/siliconflow"
	"github.com/songquanpeng/one-api/relay/adaptor/stepfun"
	"github.com/songquanpeng/one-api/relay/adaptor/togetherai"
	"github.com/songquanpeng/one-api/relay/adaptor/xai"
	"github.com/songquanpeng/one-api/relay/channeltype"
)

var CompatibleChannels = []int{
	channeltype.Azure,
	channeltype.AI360,
	channeltype.Moonshot,
	channeltype.Baichuan,
	channeltype.Minimax,
	channeltype.Doubao,
	channeltype.Mistral,
	channeltype.Groq,
	channeltype.LingYiWanWu,
	channeltype.StepFun,
	channeltype.DeepSeek,
	channeltype.TogetherAI,
	channeltype.Novita,
	channeltype.SiliconFlow,
	channeltype.XAI,
}

func GetCompatibleChannelMeta(channelType int) (string, []string) {
	switch channelType {
	case channeltype.Azure:
		return "azure", ModelList
	case channeltype.AI360:
		return "360", ai360.ModelList
	case channeltype.Moonshot:
		return "moonshot", moonshot.ModelList
	case channeltype.Baichuan:
		return "baichuan", baichuan.ModelList
	case channeltype.Minimax:
		return "minimax", minimax.ModelList
	case channeltype.Mistral:
		return "mistralai", mistral.ModelList
	case channeltype.Groq:
		return "groq", groq.ModelList
	case channeltype.LingYiWanWu:
		return "lingyiwanwu", lingyiwanwu.ModelList
	case channeltype.StepFun:
		return "stepfun", stepfun.ModelList
	case channeltype.DeepSeek:
		return "deepseek", deepseek.ModelList
	case channeltype.TogetherAI:
		return "together.ai", togetherai.ModelList
	case channeltype.Doubao:
		return "doubao", doubao.ModelList
	case channeltype.Novita:
		return "novita", novita.ModelList
	case channeltype.SiliconFlow:
		return "siliconflow", siliconflow.ModelList
	case channeltype.XAI:
		return "xai", xai.ModelList
	default:
		return "openai", ModelList
	}
}
25
relay/adaptor/openai/constants.go
Normal file
@@ -0,0 +1,25 @@
package openai

var ModelList = []string{
	"gpt-3.5-turbo", "gpt-3.5-turbo-0301", "gpt-3.5-turbo-0613", "gpt-3.5-turbo-1106", "gpt-3.5-turbo-0125",
	"gpt-3.5-turbo-16k", "gpt-3.5-turbo-16k-0613",
	"gpt-3.5-turbo-instruct",
	"gpt-4", "gpt-4-0314", "gpt-4-0613", "gpt-4-1106-preview", "gpt-4-0125-preview",
	"gpt-4-32k", "gpt-4-32k-0314", "gpt-4-32k-0613",
	"gpt-4-turbo-preview", "gpt-4-turbo", "gpt-4-turbo-2024-04-09",
	"gpt-4o", "gpt-4o-2024-05-13", "gpt-4o-2024-08-06", "gpt-4o-2024-11-20", "chatgpt-4o-latest",
	"gpt-4o-mini", "gpt-4o-mini-2024-07-18",
	"gpt-4o-audio-preview", "gpt-4o-audio-preview-2024-12-17", "gpt-4o-audio-preview-2024-10-01",
	"gpt-4-vision-preview",
	"text-embedding-ada-002", "text-embedding-3-small", "text-embedding-3-large",
	"text-curie-001", "text-babbage-001", "text-ada-001", "text-davinci-002", "text-davinci-003",
	"text-moderation-latest", "text-moderation-stable",
	"text-davinci-edit-001",
	"davinci-002", "babbage-002",
	"dall-e-2", "dall-e-3",
	"whisper-1",
	"tts-1", "tts-1-1106", "tts-1-hd", "tts-1-hd-1106",
	"o1", "o1-2024-12-17",
	"o1-preview", "o1-preview-2024-09-12",
	"o1-mini", "o1-mini-2024-09-12",
}
31
relay/adaptor/openai/helper.go
Normal file
@@ -0,0 +1,31 @@
package openai

import (
	"fmt"
	"strings"

	"github.com/songquanpeng/one-api/relay/channeltype"
	"github.com/songquanpeng/one-api/relay/model"
)

func ResponseText2Usage(responseText string, modelName string, promptTokens int) *model.Usage {
	usage := &model.Usage{}
	usage.PromptTokens = promptTokens
	usage.CompletionTokens = CountTokenText(responseText, modelName)
	usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
	return usage
}

func GetFullRequestURL(baseURL string, requestURL string, channelType int) string {
	fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL)

	if strings.HasPrefix(baseURL, "https://gateway.ai.cloudflare.com") {
		switch channelType {
		case channeltype.OpenAI:
			fullRequestURL = fmt.Sprintf("%s%s", baseURL, strings.TrimPrefix(requestURL, "/v1"))
		case channeltype.Azure:
			fullRequestURL = fmt.Sprintf("%s%s", baseURL, strings.TrimPrefix(requestURL, "/openai/deployments"))
		}
	}
	return fullRequestURL
}
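// Illustrative Cloudflare example (account and gateway IDs invented): with a
// BaseURL of https://gateway.ai.cloudflare.com/v1/acct/gw/openai, a request
// to /v1/chat/completions becomes
// https://gateway.ai.cloudflare.com/v1/acct/gw/openai/chat/completions,
// i.e. the duplicated /v1 segment is dropped for the OpenAI channel type.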
62
relay/adaptor/openai/image.go
Normal file
@@ -0,0 +1,62 @@
package openai

import (
	"bytes"
	"encoding/json"
	"io"
	"net/http"

	"github.com/gin-gonic/gin"
	"github.com/songquanpeng/one-api/relay/model"
)

// ImagesEditsHandler just copies the upstream response body to the client.
//
// https://platform.openai.com/docs/api-reference/images/createEdit
func ImagesEditsHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
	// Copy headers before WriteHeader: anything set afterwards is ignored.
	for k, v := range resp.Header {
		c.Writer.Header().Set(k, v[0])
	}
	c.Writer.WriteHeader(resp.StatusCode)

	if _, err := io.Copy(c.Writer, resp.Body); err != nil {
		return ErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil
	}
	defer resp.Body.Close()

	return nil, nil
}

func ImageHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
	var imageResponse ImageResponse
	responseBody, err := io.ReadAll(resp.Body)

	if err != nil {
		return ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
	}
	err = resp.Body.Close()
	if err != nil {
		return ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
	}
	err = json.Unmarshal(responseBody, &imageResponse)
	if err != nil {
		return ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
	}

	resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))

	for k, v := range resp.Header {
		c.Writer.Header().Set(k, v[0])
	}
	c.Writer.WriteHeader(resp.StatusCode)

	_, err = io.Copy(c.Writer, resp.Body)
	if err != nil {
		return ErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil
	}
	err = resp.Body.Close()
	if err != nil {
		return ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
	}
	return nil, nil
}
173
relay/adaptor/openai/main.go
Normal file
@@ -0,0 +1,173 @@
package openai

import (
	"bufio"
	"bytes"
	"encoding/json"
	"io"
	"math"
	"net/http"
	"strings"

	"github.com/gin-gonic/gin"
	"github.com/songquanpeng/one-api/common"
	"github.com/songquanpeng/one-api/common/conv"
	"github.com/songquanpeng/one-api/common/logger"
	"github.com/songquanpeng/one-api/common/render"
	"github.com/songquanpeng/one-api/relay/billing/ratio"
	"github.com/songquanpeng/one-api/relay/model"
	"github.com/songquanpeng/one-api/relay/relaymode"
)

const (
	dataPrefix       = "data: "
	done             = "[DONE]"
	dataPrefixLength = len(dataPrefix)
)
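// StreamHandler forwards OpenAI-style SSE to the client while accumulating
// the response text (for fallback token counting) and the usage block, if the
// upstream sends one. Lines shorter than len("data: ") are skipped outright,
// as are chat chunks that carry neither choices nor usage.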
func StreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*model.ErrorWithStatusCode, string, *model.Usage) {
	responseText := ""
	scanner := bufio.NewScanner(resp.Body)
	scanner.Split(bufio.ScanLines)
	var usage *model.Usage

	common.SetEventStreamHeaders(c)

	doneRendered := false
	for scanner.Scan() {
		data := scanner.Text()
		if len(data) < dataPrefixLength { // ignore blank lines and malformed chunks
			continue
		}
		if data[:dataPrefixLength] != dataPrefix && data[:dataPrefixLength] != done {
			continue
		}
		if strings.HasPrefix(data[dataPrefixLength:], done) {
			render.StringData(c, data)
			doneRendered = true
			continue
		}
		switch relayMode {
		case relaymode.ChatCompletions:
			var streamResponse ChatCompletionsStreamResponse
			err := json.Unmarshal([]byte(data[dataPrefixLength:]), &streamResponse)
			if err != nil {
				logger.SysError("error unmarshalling stream response: " + err.Error())
				render.StringData(c, data) // on a parse error, forward the raw data to the client
				continue                   // and otherwise ignore it
			}
			if len(streamResponse.Choices) == 0 && streamResponse.Usage == nil {
				// chunks with no choices and no usage should not be forwarded
				// to the client; Azure emits such empty chunks
				continue
			}
			render.StringData(c, data)
			for _, choice := range streamResponse.Choices {
				responseText += conv.AsString(choice.Delta.Content)
			}
			if streamResponse.Usage != nil {
				usage = streamResponse.Usage
			}
		case relaymode.Completions:
			render.StringData(c, data)
			var streamResponse CompletionsStreamResponse
			err := json.Unmarshal([]byte(data[dataPrefixLength:]), &streamResponse)
			if err != nil {
				logger.SysError("error unmarshalling stream response: " + err.Error())
				continue
			}
			for _, choice := range streamResponse.Choices {
				responseText += choice.Text
			}
		}
	}

	if err := scanner.Err(); err != nil {
		logger.SysError("error reading stream: " + err.Error())
	}

	if !doneRendered {
		render.Done(c)
	}

	err := resp.Body.Close()
	if err != nil {
		return ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), "", nil
	}

	return nil, responseText, usage
}
// Handler handles the non-stream response from the OpenAI API
func Handler(c *gin.Context, resp *http.Response, promptTokens int, modelName string) (*model.ErrorWithStatusCode, *model.Usage) {
	var textResponse SlimTextResponse
	responseBody, err := io.ReadAll(resp.Body)
	if err != nil {
		return ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
	}
	err = resp.Body.Close()
	if err != nil {
		return ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
	}
	err = json.Unmarshal(responseBody, &textResponse)
	if err != nil {
		return ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
	}
	if textResponse.Error.Type != "" {
		return &model.ErrorWithStatusCode{
			Error:      textResponse.Error,
			StatusCode: resp.StatusCode,
		}, nil
	}

	// Reset response body
	resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
	logger.Debugf(c.Request.Context(), "handler response: %s", string(responseBody))

	// We shouldn't set the header before we parse the response body, because parsing may fail.
	// We would then have to send an error response, but the header would already be set,
	// which confuses the HTTP client on the other end.
	// For example, Postman will report an error, and we cannot check the response at all.
	for k, v := range resp.Header {
		c.Writer.Header().Set(k, v[0])
	}
	c.Writer.WriteHeader(resp.StatusCode)
	_, err = io.Copy(c.Writer, resp.Body)
	if err != nil {
		return ErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil
	}
	err = resp.Body.Close()
	if err != nil {
		return ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
	}

	if textResponse.Usage.TotalTokens == 0 || (textResponse.Usage.PromptTokens == 0 && textResponse.Usage.CompletionTokens == 0) {
		completionTokens := 0
		for _, choice := range textResponse.Choices {
			completionTokens += CountTokenText(choice.Message.StringContent(), modelName)
		}
		textResponse.Usage = model.Usage{
			PromptTokens:     promptTokens,
			CompletionTokens: completionTokens,
			TotalTokens:      promptTokens + completionTokens,
		}
	} else if textResponse.PromptTokensDetails.AudioTokens+textResponse.CompletionTokensDetails.AudioTokens > 0 {
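		// Worked example (numbers invented, not from the source): with 40 text
		// and 100 audio prompt tokens at an audio prompt ratio of 2.5, the
		// billed PromptTokens become 40 + ceil(100*2.5) = 290; completion
		// audio tokens are scaled by both the prompt ratio and the completion
		// ratio before being added to the text completion tokens.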
		// Convert the more expensive audio tokens to uniformly priced text tokens.
		// Note that when there are no audio tokens in prompt and completion,
		// OpenAI will return empty PromptTokensDetails and CompletionTokensDetails, which can be misleading.
		textResponse.Usage.PromptTokens = textResponse.PromptTokensDetails.TextTokens +
			int(math.Ceil(
				float64(textResponse.PromptTokensDetails.AudioTokens)*
					ratio.GetAudioPromptRatio(modelName),
			))
		textResponse.Usage.CompletionTokens = textResponse.CompletionTokensDetails.TextTokens +
			int(math.Ceil(
				float64(textResponse.CompletionTokensDetails.AudioTokens)*
					ratio.GetAudioPromptRatio(modelName)*ratio.GetAudioCompletionRatio(modelName),
			))

		textResponse.Usage.TotalTokens = textResponse.Usage.PromptTokens +
			textResponse.Usage.CompletionTokens
	}

	return nil, &textResponse.Usage
}
167
relay/adaptor/openai/model.go
Normal file
167
relay/adaptor/openai/model.go
Normal file
@@ -0,0 +1,167 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"mime/multipart"
|
||||
|
||||
"github.com/songquanpeng/one-api/relay/model"
|
||||
)
|
||||
|
||||
type TextContent struct {
|
||||
Type string `json:"type,omitempty"`
|
||||
Text string `json:"text,omitempty"`
|
||||
}
|
||||
|
||||
type ImageContent struct {
|
||||
Type string `json:"type,omitempty"`
|
||||
ImageURL *model.ImageURL `json:"image_url,omitempty"`
|
||||
}
|
||||
|
||||
type ChatRequest struct {
|
||||
Model string `json:"model"`
|
||||
Messages []model.Message `json:"messages"`
|
||||
MaxTokens int `json:"max_tokens"`
|
||||
}
|
type TextRequest struct {
	Model     string          `json:"model"`
	Messages  []model.Message `json:"messages"`
	Prompt    string          `json:"prompt"`
	MaxTokens int             `json:"max_tokens"`
	// Stream bool `json:"stream"`
}

// ImageRequest docs: https://platform.openai.com/docs/api-reference/images/create
type ImageRequest struct {
	Model          string `json:"model"`
	Prompt         string `json:"prompt" binding:"required"`
	N              int    `json:"n,omitempty"`
	Size           string `json:"size,omitempty"`
	Quality        string `json:"quality,omitempty"`
	ResponseFormat string `json:"response_format,omitempty"`
	Style          string `json:"style,omitempty"`
	User           string `json:"user,omitempty"`
}

type WhisperJSONResponse struct {
	Text string `json:"text,omitempty"`
}

type WhisperVerboseJSONResponse struct {
	Task     string    `json:"task,omitempty"`
	Language string    `json:"language,omitempty"`
	Duration float64   `json:"duration,omitempty"`
	Text     string    `json:"text,omitempty"`
	Segments []Segment `json:"segments,omitempty"`
}

type Segment struct {
	Id               int     `json:"id"`
	Seek             int     `json:"seek"`
	Start            float64 `json:"start"`
	End              float64 `json:"end"`
	Text             string  `json:"text"`
	Tokens           []int   `json:"tokens"`
	Temperature      float64 `json:"temperature"`
	AvgLogprob       float64 `json:"avg_logprob"`
	CompressionRatio float64 `json:"compression_ratio"`
	NoSpeechProb     float64 `json:"no_speech_prob"`
}

type TextToSpeechRequest struct {
	Model          string  `json:"model" binding:"required"`
	Input          string  `json:"input" binding:"required"`
	Voice          string  `json:"voice" binding:"required"`
	Speed          float64 `json:"speed"`
	ResponseFormat string  `json:"response_format"`
}

type AudioTranscriptionRequest struct {
	File                 *multipart.FileHeader `form:"file" binding:"required"`
	Model                string                `form:"model" binding:"required"`
	Language             string                `form:"language"`
	Prompt               string                `form:"prompt"`
	ResponseFormat       string                `form:"response_format" binding:"oneof=json text srt verbose_json vtt"`
	Temperature          float64               `form:"temperature"`
	TimestampGranularity []string              `form:"timestamp_granularity"`
}

type AudioTranslationRequest struct {
	File           *multipart.FileHeader `form:"file" binding:"required"`
	Model          string                `form:"model" binding:"required"`
	Prompt         string                `form:"prompt"`
	ResponseFormat string                `form:"response_format" binding:"oneof=json text srt verbose_json vtt"`
	Temperature    float64               `form:"temperature"`
}

type UsageOrResponseText struct {
	*model.Usage
	ResponseText string
}

type SlimTextResponse struct {
	Choices     []TextResponseChoice `json:"choices"`
	model.Usage `json:"usage"`
	Error       model.Error `json:"error"`
}

type TextResponseChoice struct {
	Index         int `json:"index"`
	model.Message `json:"message"`
	FinishReason  string `json:"finish_reason"`
}

type TextResponse struct {
	Id          string               `json:"id"`
	Model       string               `json:"model,omitempty"`
	Object      string               `json:"object"`
	Created     int64                `json:"created"`
	Choices     []TextResponseChoice `json:"choices"`
	model.Usage `json:"usage"`
}

type EmbeddingResponseItem struct {
	Object    string    `json:"object"`
	Index     int       `json:"index"`
	Embedding []float64 `json:"embedding"`
}

type EmbeddingResponse struct {
	Object      string                  `json:"object"`
	Data        []EmbeddingResponseItem `json:"data"`
	Model       string                  `json:"model"`
	model.Usage `json:"usage"`
}

type ImageData struct {
	Url           string `json:"url,omitempty"`
	B64Json       string `json:"b64_json,omitempty"`
	RevisedPrompt string `json:"revised_prompt,omitempty"`
}

type ImageResponse struct {
	Created int64       `json:"created"`
	Data    []ImageData `json:"data"`
	// model.Usage `json:"usage"`
}

type ChatCompletionsStreamResponseChoice struct {
	Index        int           `json:"index"`
	Delta        model.Message `json:"delta"`
	FinishReason *string       `json:"finish_reason,omitempty"`
}

type ChatCompletionsStreamResponse struct {
	Id      string                                `json:"id"`
	Object  string                                `json:"object"`
	Created int64                                 `json:"created"`
	Model   string                                `json:"model"`
	Choices []ChatCompletionsStreamResponseChoice `json:"choices"`
	Usage   *model.Usage                          `json:"usage,omitempty"`
}

type CompletionsStreamResponse struct {
	Choices []struct {
		Text         string `json:"text"`
		FinishReason string `json:"finish_reason"`
	} `json:"choices"`
}
287
relay/adaptor/openai/token.go
Normal file
@@ -0,0 +1,287 @@
package openai

import (
	"bytes"
	"context"
	"encoding/base64"
	"fmt"
	"math"
	"strings"

	"github.com/pkg/errors"
	"github.com/pkoukk/tiktoken-go"
	"github.com/songquanpeng/one-api/common/config"
	"github.com/songquanpeng/one-api/common/helper"
	"github.com/songquanpeng/one-api/common/image"
	"github.com/songquanpeng/one-api/common/logger"
	billingratio "github.com/songquanpeng/one-api/relay/billing/ratio"
	"github.com/songquanpeng/one-api/relay/model"
)

// tokenEncoderMap won't grow after initialization
var tokenEncoderMap = map[string]*tiktoken.Tiktoken{}
var defaultTokenEncoder *tiktoken.Tiktoken

func InitTokenEncoders() {
	logger.SysLog("initializing token encoders")
	gpt35TokenEncoder, err := tiktoken.EncodingForModel("gpt-3.5-turbo")
	if err != nil {
		logger.FatalLog(fmt.Sprintf("failed to get gpt-3.5-turbo token encoder: %s", err.Error()))
	}
	defaultTokenEncoder = gpt35TokenEncoder
	gpt4oTokenEncoder, err := tiktoken.EncodingForModel("gpt-4o")
	if err != nil {
		logger.FatalLog(fmt.Sprintf("failed to get gpt-4o token encoder: %s", err.Error()))
	}
	gpt4TokenEncoder, err := tiktoken.EncodingForModel("gpt-4")
	if err != nil {
		logger.FatalLog(fmt.Sprintf("failed to get gpt-4 token encoder: %s", err.Error()))
	}
	for modelName := range billingratio.ModelRatio {
		if strings.HasPrefix(modelName, "gpt-3.5") {
			tokenEncoderMap[modelName] = gpt35TokenEncoder
		} else if strings.HasPrefix(modelName, "gpt-4o") {
			tokenEncoderMap[modelName] = gpt4oTokenEncoder
		} else if strings.HasPrefix(modelName, "gpt-4") {
			tokenEncoderMap[modelName] = gpt4TokenEncoder
		} else {
			tokenEncoderMap[modelName] = nil
		}
	}
	logger.SysLog("token encoders initialized")
}

func getTokenEncoder(model string) *tiktoken.Tiktoken {
	tokenEncoder, ok := tokenEncoderMap[model]
	if ok && tokenEncoder != nil {
		return tokenEncoder
	}
	if ok {
		tokenEncoder, err := tiktoken.EncodingForModel(model)
		if err != nil {
			logger.SysError(fmt.Sprintf("failed to get token encoder for model %s: %s, using encoder for gpt-3.5-turbo", model, err.Error()))
			tokenEncoder = defaultTokenEncoder
		}
		tokenEncoderMap[model] = tokenEncoder
		return tokenEncoder
	}
	return defaultTokenEncoder
}

func getTokenNum(tokenEncoder *tiktoken.Tiktoken, text string) int {
	if config.ApproximateTokenEnabled {
		// rough estimate: roughly 0.38 tokens per character, i.e. about 2.6 characters per token
		return int(float64(len(text)) * 0.38)
	}
	return len(tokenEncoder.Encode(text, nil, nil))
}

func CountTokenMessages(ctx context.Context,
	messages []model.Message, actualModel string) int {
	tokenEncoder := getTokenEncoder(actualModel)
	// Reference:
	// https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
	// https://github.com/pkoukk/tiktoken-go/issues/6
	//
	// Every message follows <|start|>{role/name}\n{content}<|end|>\n
	var tokensPerMessage int
	var tokensPerName int
	if actualModel == "gpt-3.5-turbo-0301" {
		tokensPerMessage = 4
		tokensPerName = -1 // If there's a name, the role is omitted
	} else {
		tokensPerMessage = 3
		tokensPerName = 1
	}
	tokenNum := 0
	for _, message := range messages {
		tokenNum += tokensPerMessage
		contents := message.ParseContent()
		for _, content := range contents {
			switch content.Type {
			case model.ContentTypeText:
				tokenNum += getTokenNum(tokenEncoder, content.Text)
			case model.ContentTypeImageURL:
				imageTokens, err := countImageTokens(
					content.ImageURL.Url,
					content.ImageURL.Detail,
					actualModel)
				if err != nil {
					logger.SysError("error counting image tokens: " + err.Error())
				} else {
					tokenNum += imageTokens
				}
			case model.ContentTypeInputAudio:
				audioData, err := base64.StdEncoding.DecodeString(content.InputAudio.Data)
				if err != nil {
					logger.SysError("error decoding audio data: " + err.Error())
					continue // skip counting tokens for audio that failed to decode
				}

				tokens, err := helper.GetAudioTokens(ctx,
					bytes.NewReader(audioData),
					billingratio.GetAudioPromptTokensPerSecond(actualModel))
				if err != nil {
					logger.SysError("error counting audio tokens: " + err.Error())
				} else {
					tokenNum += tokens
				}
			}
		}

		tokenNum += getTokenNum(tokenEncoder, message.Role)
		if message.Name != nil {
			tokenNum += tokensPerName
			tokenNum += getTokenNum(tokenEncoder, *message.Name)
		}
	}
	tokenNum += 3 // Every reply is primed with <|start|>assistant<|message|>
	return tokenNum
}
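
// For intuition, a hedged sketch (not part of the original file) of how the
// accounting above adds up for a gpt-4-class model (tokensPerMessage = 3,
// tokensPerName = 1, reply primer = 3):
//
//	messages := []model.Message{
//		{Role: "system", Content: "You are helpful."},
//		{Role: "user", Content: "Hi"},
//	}
//	// total = 3 + tokens("system") + tokens("You are helpful.")
//	//       + 3 + tokens("user") + tokens("Hi")
//	//       + 3 // <|start|>assistant<|message|>
//	n := CountTokenMessages(context.Background(), messages, "gpt-4")
//	_ = n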

// func countVisionTokenMessages(messages []VisionMessage, model string) (int, error) {
// 	tokenEncoder := getTokenEncoder(model)
// 	// Reference:
// 	// https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
// 	// https://github.com/pkoukk/tiktoken-go/issues/6
// 	//
// 	// Every message follows <|start|>{role/name}\n{content}<|end|>\n
// 	var tokensPerMessage int
// 	var tokensPerName int
// 	if model == "gpt-3.5-turbo-0301" {
// 		tokensPerMessage = 4
// 		tokensPerName = -1 // If there's a name, the role is omitted
// 	} else {
// 		tokensPerMessage = 3
// 		tokensPerName = 1
// 	}
// 	tokenNum := 0
// 	for _, message := range messages {
// 		tokenNum += tokensPerMessage
// 		for _, cnt := range message.Content {
// 			switch cnt.Type {
// 			case OpenaiVisionMessageContentTypeText:
// 				tokenNum += getTokenNum(tokenEncoder, cnt.Text)
// 			case OpenaiVisionMessageContentTypeImageUrl:
// 				imgblob, err := base64.StdEncoding.DecodeString(strings.TrimPrefix(cnt.ImageUrl.URL, "data:image/jpeg;base64,"))
// 				if err != nil {
// 					return 0, errors.Wrap(err, "failed to decode base64 image")
// 				}
//
// 				if imgtoken, err := CountVisionImageToken(imgblob, cnt.ImageUrl.Detail); err != nil {
// 					return 0, errors.Wrap(err, "failed to count vision image token")
// 				} else {
// 					tokenNum += imgtoken
// 				}
// 			}
// 		}
//
// 		tokenNum += getTokenNum(tokenEncoder, message.Role)
// 		if message.Name != nil {
// 			tokenNum += tokensPerName
// 			tokenNum += getTokenNum(tokenEncoder, *message.Name)
// 		}
// 	}
// 	tokenNum += 3 // Every reply is primed with <|start|>assistant<|message|>
// 	return tokenNum, nil
// }

const (
	lowDetailCost         = 85
	highDetailCostPerTile = 170
	additionalCost        = 85
	// gpt-4o-mini costs more per image than other models
	gpt4oMiniLowDetailCost  = 2833
	gpt4oMiniHighDetailCost = 5667
	gpt4oMiniAdditionalCost = 2833
)

// https://platform.openai.com/docs/guides/vision/calculating-costs
// https://github.com/openai/openai-cookbook/blob/05e3f9be4c7a2ae7ecf029a7c32065b024730ebe/examples/How_to_count_tokens_with_tiktoken.ipynb
func countImageTokens(url string, detail string, model string) (_ int, err error) {
	var fetchSize = true
	var width, height int
	// Reference: https://platform.openai.com/docs/guides/vision/low-or-high-fidelity-image-understanding
	// How detail == "auto" works is undocumented; the docs only say the model
	// looks at the image input size and decides whether to use the low or high setting.
	// According to the official guide, "low" disables the high-res model and
	// only receives a low-res 512px x 512px version of the image, indicating
	// that an image is treated as low-res when its size is within 512px x 512px.
	// We could then assume that an image larger than 512px x 512px is treated
	// as high-res, which would give the following logic:
	// if detail == "" || detail == "auto" {
	// 	width, height, err = image.GetImageSize(url)
	// 	if err != nil {
	// 		return 0, err
	// 	}
	// 	fetchSize = false
	// 	// not sure if this is correct
	// 	if width > 512 || height > 512 {
	// 		detail = "high"
	// 	} else {
	// 		detail = "low"
	// 	}
	// }

	// However, in my test, it seems to always behave like "high".
	// The following image, which is 125x50, is still treated as high-res, taking
	// 255 tokens in the response of the non-stream chat completion api.
	// https://upload.wikimedia.org/wikipedia/commons/1/10/18_Infantry_Division_Messina.jpg
	if detail == "" || detail == "auto" {
		// assumed from testing, not sure if this is correct
		detail = "high"
	}
	switch detail {
	case "low":
		if strings.HasPrefix(model, "gpt-4o-mini") {
			return gpt4oMiniLowDetailCost, nil
		}
		return lowDetailCost, nil
	case "high":
		if fetchSize {
			width, height, err = image.GetImageSize(url)
			if err != nil {
				return 0, err
			}
		}
		if width > 2048 || height > 2048 { // max(width, height) > 2048
			ratio := float64(2048) / math.Max(float64(width), float64(height))
			width = int(float64(width) * ratio)
			height = int(float64(height) * ratio)
		}
		if width > 768 && height > 768 { // min(width, height) > 768
			ratio := float64(768) / math.Min(float64(width), float64(height))
			width = int(float64(width) * ratio)
			height = int(float64(height) * ratio)
		}
		numSquares := int(math.Ceil(float64(width)/512) * math.Ceil(float64(height)/512))
		if strings.HasPrefix(model, "gpt-4o-mini") {
			return numSquares*gpt4oMiniHighDetailCost + gpt4oMiniAdditionalCost, nil
		}
		result := numSquares*highDetailCostPerTile + additionalCost
		return result, nil
	default:
		return 0, errors.New("invalid detail option")
	}
}
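
// Worked example for the high-detail path above (a sketch, not from the
// original file): a 2048x4096 input is first scaled to 1024x2048 (long side
// capped at 2048), then to 768x1536 (short side capped at 768), which gives
// ceil(768/512) * ceil(1536/512) = 2 * 3 = 6 tiles, i.e.
// 6*170 + 85 = 1105 tokens on non-mini models.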

func CountTokenInput(input any, model string) int {
	switch v := input.(type) {
	case string:
		return CountTokenText(v, model)
	case []string:
		text := ""
		for _, s := range v {
			text += s
		}
		return CountTokenText(text, model)
	}
	return 0
}

func CountTokenText(text string, model string) int {
	tokenEncoder := getTokenEncoder(model)
	return getTokenNum(tokenEncoder, text)
}

func CountToken(text string) int {
	return CountTokenInput(text, "gpt-3.5-turbo")
}
23
relay/adaptor/openai/util.go
Normal file
@@ -0,0 +1,23 @@
package openai

import (
	"context"
	"fmt"

	"github.com/songquanpeng/one-api/common/logger"
	"github.com/songquanpeng/one-api/relay/model"
)

func ErrorWrapper(err error, code string, statusCode int) *model.ErrorWithStatusCode {
	logger.Error(context.TODO(), fmt.Sprintf("[%s]%+v", code, err))

	apiError := model.Error{
		Message: err.Error(),
		Type:    "one_api_error",
		Code:    code,
	}
	return &model.ErrorWithStatusCode{
		Error:      apiError,
		StatusCode: statusCode,
	}
}
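
// Usage sketch (hypothetical caller, not part of the original file): handlers
// return the wrapped error straight to the relay layer, e.g.
//
//	if err := json.Unmarshal(body, &req); err != nil {
//		return ErrorWrapper(err, "unmarshal_request_failed", http.StatusBadRequest), nil
//	}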
67
relay/adaptor/palm/adaptor.go
Normal file
@@ -0,0 +1,67 @@
package palm

import (
	"fmt"
	"io"
	"net/http"

	"github.com/gin-gonic/gin"
	"github.com/pkg/errors"
	"github.com/songquanpeng/one-api/relay/adaptor"
	"github.com/songquanpeng/one-api/relay/adaptor/openai"
	"github.com/songquanpeng/one-api/relay/meta"
	"github.com/songquanpeng/one-api/relay/model"
)

type Adaptor struct {
}

func (a *Adaptor) Init(meta *meta.Meta) {
}

func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) {
	return fmt.Sprintf("%s/v1beta2/models/chat-bison-001:generateMessage", meta.BaseURL), nil
}

func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error {
	adaptor.SetupCommonRequestHeader(c, req, meta)
	req.Header.Set("x-goog-api-key", meta.APIKey)
	return nil
}

func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) {
	if request == nil {
		return nil, errors.New("request is nil")
	}
	return ConvertRequest(*request), nil
}

func (a *Adaptor) ConvertImageRequest(_ *gin.Context, request *model.ImageRequest) (any, error) {
	if request == nil {
		return nil, errors.New("request is nil")
	}
	return request, nil
}

func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) {
	return adaptor.DoRequestHelper(a, c, meta, requestBody)
}

func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
	if meta.IsStream {
		var responseText string
		err, responseText = StreamHandler(c, resp)
		usage = openai.ResponseText2Usage(responseText, meta.ActualModelName, meta.PromptTokens)
	} else {
		err, usage = Handler(c, resp, meta.PromptTokens, meta.ActualModelName)
	}
	return
}

func (a *Adaptor) GetModelList() []string {
	return ModelList
}

func (a *Adaptor) GetChannelName() string {
	return "google palm"
}
5
relay/adaptor/palm/constants.go
Normal file
@@ -0,0 +1,5 @@
package palm

var ModelList = []string{
	"PaLM-2",
}
40
relay/adaptor/palm/model.go
Normal file
@@ -0,0 +1,40 @@
package palm

import (
	"github.com/songquanpeng/one-api/relay/model"
)

type ChatMessage struct {
	Author  string `json:"author"`
	Content string `json:"content"`
}

type Filter struct {
	Reason  string `json:"reason"`
	Message string `json:"message"`
}

type Prompt struct {
	Messages []ChatMessage `json:"messages"`
}

type ChatRequest struct {
	Prompt         Prompt   `json:"prompt"`
	Temperature    *float64 `json:"temperature,omitempty"`
	CandidateCount int      `json:"candidateCount,omitempty"`
	TopP           *float64 `json:"topP,omitempty"`
	TopK           int      `json:"topK,omitempty"`
}

type Error struct {
	Code    int    `json:"code"`
	Message string `json:"message"`
	Status  string `json:"status"`
}

type ChatResponse struct {
	Candidates []ChatMessage   `json:"candidates"`
	Messages   []model.Message `json:"messages"`
	Filters    []Filter        `json:"filters"`
	Error      Error           `json:"error"`
}
172
relay/adaptor/palm/palm.go
Normal file
@@ -0,0 +1,172 @@
package palm

import (
	"encoding/json"
	"fmt"
	"io"
	"net/http"

	"github.com/gin-gonic/gin"
	"github.com/songquanpeng/one-api/common"
	"github.com/songquanpeng/one-api/common/helper"
	"github.com/songquanpeng/one-api/common/logger"
	"github.com/songquanpeng/one-api/common/random"
	"github.com/songquanpeng/one-api/common/render"
	"github.com/songquanpeng/one-api/relay/adaptor/openai"
	"github.com/songquanpeng/one-api/relay/constant"
	"github.com/songquanpeng/one-api/relay/model"
)

// https://developers.generativeai.google/api/rest/generativelanguage/models/generateMessage#request-body
// https://developers.generativeai.google/api/rest/generativelanguage/models/generateMessage#response-body

func ConvertRequest(textRequest model.GeneralOpenAIRequest) *ChatRequest {
	palmRequest := ChatRequest{
		Prompt: Prompt{
			Messages: make([]ChatMessage, 0, len(textRequest.Messages)),
		},
		Temperature:    textRequest.Temperature,
		CandidateCount: textRequest.N,
		TopP:           textRequest.TopP,
		TopK:           textRequest.MaxTokens, // note: MaxTokens is reused as TopK here
	}
	for _, message := range textRequest.Messages {
		palmMessage := ChatMessage{
			Content: message.StringContent(),
		}
		if message.Role == "user" {
			palmMessage.Author = "0"
		} else {
			palmMessage.Author = "1"
		}
		palmRequest.Prompt.Messages = append(palmRequest.Prompt.Messages, palmMessage)
	}
	return &palmRequest
}
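
// E.g. (sketch, not part of the original file): an OpenAI conversation
// [{role: "user", content: "Hello"}, {role: "assistant", content: "Hi"}]
// becomes prompt.messages = [{author: "0", content: "Hello"},
// {author: "1", content: "Hi"}], since PaLM identifies speakers by
// author id rather than by role name.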

func responsePaLM2OpenAI(response *ChatResponse) *openai.TextResponse {
	fullTextResponse := openai.TextResponse{
		Choices: make([]openai.TextResponseChoice, 0, len(response.Candidates)),
	}
	for i, candidate := range response.Candidates {
		choice := openai.TextResponseChoice{
			Index: i,
			Message: model.Message{
				Role:    "assistant",
				Content: candidate.Content,
			},
			FinishReason: "stop",
		}
		fullTextResponse.Choices = append(fullTextResponse.Choices, choice)
	}
	return &fullTextResponse
}

func streamResponsePaLM2OpenAI(palmResponse *ChatResponse) *openai.ChatCompletionsStreamResponse {
	var choice openai.ChatCompletionsStreamResponseChoice
	if len(palmResponse.Candidates) > 0 {
		choice.Delta.Content = palmResponse.Candidates[0].Content
	}
	choice.FinishReason = &constant.StopFinishReason
	var response openai.ChatCompletionsStreamResponse
	response.Object = "chat.completion.chunk"
	response.Model = "palm2"
	response.Choices = []openai.ChatCompletionsStreamResponseChoice{choice}
	return &response
}

// StreamHandler reads the whole PaLM response and replays it to the client as
// a single SSE chunk; the upstream generateMessage call itself is not streamed.
func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, string) {
	responseText := ""
	responseId := fmt.Sprintf("chatcmpl-%s", random.GetUUID())
	createdTime := helper.GetTimestamp()

	common.SetEventStreamHeaders(c)

	responseBody, err := io.ReadAll(resp.Body)
	if err != nil {
		logger.SysError("error reading stream response: " + err.Error())
		if closeErr := resp.Body.Close(); closeErr != nil {
			return openai.ErrorWrapper(closeErr, "close_response_body_failed", http.StatusInternalServerError), ""
		}
		return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), ""
	}

	err = resp.Body.Close()
	if err != nil {
		return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), ""
	}

	var palmResponse ChatResponse
	err = json.Unmarshal(responseBody, &palmResponse)
	if err != nil {
		logger.SysError("error unmarshalling stream response: " + err.Error())
		return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), ""
	}

	fullTextResponse := streamResponsePaLM2OpenAI(&palmResponse)
	fullTextResponse.Id = responseId
	fullTextResponse.Created = createdTime
	if len(palmResponse.Candidates) > 0 {
		responseText = palmResponse.Candidates[0].Content
	}

	jsonResponse, err := json.Marshal(fullTextResponse)
	if err != nil {
		logger.SysError("error marshalling stream response: " + err.Error())
		return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), ""
	}

	err = render.ObjectData(c, string(jsonResponse))
	if err != nil {
		logger.SysError(err.Error())
	}

	render.Done(c)

	return nil, responseText
}

func Handler(c *gin.Context, resp *http.Response, promptTokens int, modelName string) (*model.ErrorWithStatusCode, *model.Usage) {
	responseBody, err := io.ReadAll(resp.Body)
	if err != nil {
		return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
	}
	err = resp.Body.Close()
	if err != nil {
		return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
	}
	var palmResponse ChatResponse
	err = json.Unmarshal(responseBody, &palmResponse)
	if err != nil {
		return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
	}
	if palmResponse.Error.Code != 0 || len(palmResponse.Candidates) == 0 {
		return &model.ErrorWithStatusCode{
			Error: model.Error{
				Message: palmResponse.Error.Message,
				Type:    palmResponse.Error.Status,
				Param:   "",
				Code:    palmResponse.Error.Code,
			},
			StatusCode: resp.StatusCode,
		}, nil
	}
	fullTextResponse := responsePaLM2OpenAI(&palmResponse)
	fullTextResponse.Model = modelName
	completionTokens := openai.CountTokenText(palmResponse.Candidates[0].Content, modelName)
	usage := model.Usage{
		PromptTokens:     promptTokens,
		CompletionTokens: completionTokens,
		TotalTokens:      promptTokens + completionTokens,
	}
	fullTextResponse.Usage = usage
	jsonResponse, err := json.Marshal(fullTextResponse)
	if err != nil {
		return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
	}
	c.Writer.Header().Set("Content-Type", "application/json")
	c.Writer.WriteHeader(resp.StatusCode)
	_, err = c.Writer.Write(jsonResponse)
	return nil, &usage
}
89
relay/adaptor/proxy/adaptor.go
Normal file
@@ -0,0 +1,89 @@
package proxy

import (
	"fmt"
	"io"
	"net/http"
	"strings"

	"github.com/gin-gonic/gin"
	"github.com/pkg/errors"
	"github.com/songquanpeng/one-api/relay/adaptor"
	"github.com/songquanpeng/one-api/relay/meta"
	"github.com/songquanpeng/one-api/relay/model"
)

var _ adaptor.Adaptor = new(Adaptor)

const channelName = "proxy"

type Adaptor struct{}

func (a *Adaptor) Init(meta *meta.Meta) {
}

func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) {
	return nil, errors.New("not implemented")
}

func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
	for k, v := range resp.Header {
		for _, vv := range v {
			c.Writer.Header().Add(k, vv) // Add, not Set, so multi-value headers survive
		}
	}

	c.Writer.WriteHeader(resp.StatusCode)
	if _, gerr := io.Copy(c.Writer, resp.Body); gerr != nil {
		return nil, &model.ErrorWithStatusCode{
			StatusCode: http.StatusInternalServerError,
			Error: model.Error{
				Message: gerr.Error(),
			},
		}
	}

	return nil, nil
}

func (a *Adaptor) GetModelList() (models []string) {
	return nil
}

func (a *Adaptor) GetChannelName() string {
	return channelName
}

// GetRequestURL removes the static prefix and returns the real request url for
// the upstream service.
func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) {
	prefix := fmt.Sprintf("/v1/oneapi/proxy/%d", meta.ChannelId)
	return meta.BaseURL + strings.TrimPrefix(meta.RequestURLPath, prefix), nil
}
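
// For example (a sketch, assuming channel id 42): an incoming request to
// /v1/oneapi/proxy/42/v1/chat/completions is forwarded to
// meta.BaseURL + "/v1/chat/completions".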

func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error {
	for k, v := range c.Request.Header {
		req.Header.Set(k, v[0])
	}

	// remove unnecessary headers
	req.Header.Del("Host")
	req.Header.Del("Content-Length")
	req.Header.Del("Accept-Encoding")
	req.Header.Del("Connection")

	// replace the caller's authorization header with the channel's key
	req.Header.Set("Authorization", meta.APIKey)

	return nil
}

func (a *Adaptor) ConvertImageRequest(_ *gin.Context, request *model.ImageRequest) (any, error) {
	return nil, errors.New("not implemented")
}

func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) {
	return adaptor.DoRequestHelper(a, c, meta, requestBody)
}
180
relay/adaptor/replicate/adaptor.go
Normal file
@@ -0,0 +1,180 @@
package replicate

import (
	"bytes"
	"fmt"
	"io"
	"net/http"
	"slices"
	"strings"
	"time"

	"github.com/gin-gonic/gin"
	"github.com/pkg/errors"
	"github.com/songquanpeng/one-api/common"
	"github.com/songquanpeng/one-api/common/logger"
	"github.com/songquanpeng/one-api/relay/adaptor"
	"github.com/songquanpeng/one-api/relay/adaptor/openai"
	"github.com/songquanpeng/one-api/relay/meta"
	"github.com/songquanpeng/one-api/relay/model"
	"github.com/songquanpeng/one-api/relay/relaymode"
)

type Adaptor struct {
	meta *meta.Meta
}

// ConvertImageRequest implements adaptor.Adaptor.
func (a *Adaptor) ConvertImageRequest(_ *gin.Context, request *model.ImageRequest) (any, error) {
	return nil, errors.New("should call replicate.ConvertImageRequest instead")
}

func ConvertImageRequest(c *gin.Context, request *model.ImageRequest) (any, error) {
	meta := meta.GetByContext(c)

	if request.ResponseFormat != "b64_json" {
		return nil, errors.New("only the b64_json response format is supported")
	}
	if request.N != 1 && request.N != 0 {
		return nil, errors.New("only N=1 is supported")
	}

	switch meta.Mode {
	case relaymode.ImagesGenerations:
		return convertImageCreateRequest(request)
	case relaymode.ImagesEdits:
		return convertImageRemixRequest(c)
	default:
		return nil, errors.New("not implemented")
	}
}

func convertImageCreateRequest(request *model.ImageRequest) (any, error) {
	return DrawImageRequest{
		Input: ImageInput{
			Steps:           25,
			Prompt:          request.Prompt,
			Guidance:        3,
			Seed:            int(time.Now().UnixNano()),
			SafetyTolerance: 5,
			NImages:         1, // replicate will always return 1 image
			Width:           1440,
			Height:          1440,
			AspectRatio:     "1:1",
		},
	}, nil
}

func convertImageRemixRequest(c *gin.Context) (any, error) {
	// recover the request body so it can be bound again
	requestBody, err := common.GetRequestBody(c)
	if err != nil {
		return nil, errors.Wrap(err, "get request body")
	}
	c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))

	rawReq := new(OpenaiImageEditRequest)
	if err := c.ShouldBind(rawReq); err != nil {
		return nil, errors.Wrap(err, "parse image edit form")
	}

	return rawReq.toFluxRemixRequest()
}

// ConvertRequest converts the request to the format that the target API expects.
func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) {
	if !request.Stream {
		// TODO: support non-stream mode
		return nil, errors.Errorf("replicate models only support stream mode for now, please set stream=true")
	}

	// Build the prompt from OpenAI messages
	var promptBuilder strings.Builder
	for _, message := range request.Messages {
		switch msgCnt := message.Content.(type) {
		case string:
			promptBuilder.WriteString(message.Role)
			promptBuilder.WriteString(": ")
			promptBuilder.WriteString(msgCnt)
			promptBuilder.WriteString("\n")
		default:
		}
	}

	replicateRequest := ReplicateChatRequest{
		Input: ChatInput{
			Prompt:           promptBuilder.String(),
			MaxTokens:        request.MaxTokens,
			Temperature:      1.0,
			TopP:             1.0,
			PresencePenalty:  0.0,
			FrequencyPenalty: 0.0,
		},
	}

	// Map optional fields
	if request.Temperature != nil {
		replicateRequest.Input.Temperature = *request.Temperature
	}
	if request.TopP != nil {
		replicateRequest.Input.TopP = *request.TopP
	}
	if request.PresencePenalty != nil {
		replicateRequest.Input.PresencePenalty = *request.PresencePenalty
	}
	if request.FrequencyPenalty != nil {
		replicateRequest.Input.FrequencyPenalty = *request.FrequencyPenalty
	}
	if request.MaxTokens > 0 {
		replicateRequest.Input.MaxTokens = request.MaxTokens
	} else if request.MaxTokens == 0 {
		replicateRequest.Input.MaxTokens = 500 // default when max_tokens is unset
	}

	return replicateRequest, nil
}

func (a *Adaptor) Init(meta *meta.Meta) {
	a.meta = meta
}

func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) {
	if !slices.Contains(ModelList, meta.OriginModelName) {
		return "", errors.Errorf("model %s not supported", meta.OriginModelName)
	}

	return fmt.Sprintf("https://api.replicate.com/v1/models/%s/predictions", meta.OriginModelName), nil
}
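
// E.g. (sketch): meta.OriginModelName = "meta/meta-llama-3-8b-instruct" yields
// https://api.replicate.com/v1/models/meta/meta-llama-3-8b-instruct/predictions.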

func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error {
	adaptor.SetupCommonRequestHeader(c, req, meta)
	req.Header.Set("Authorization", "Bearer "+meta.APIKey)
	return nil
}

func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) {
	logger.Info(c, "send request to replicate")
	return adaptor.DoRequestHelper(a, c, meta, requestBody)
}

func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
	switch meta.Mode {
	case relaymode.ImagesGenerations,
		relaymode.ImagesEdits:
		err, usage = ImageHandler(c, resp)
	case relaymode.ChatCompletions:
		err, usage = ChatHandler(c, resp)
	default:
		err = openai.ErrorWrapper(errors.New("not implemented"), "not_implemented", http.StatusInternalServerError)
	}

	return
}

func (a *Adaptor) GetModelList() []string {
	return ModelList
}

func (a *Adaptor) GetChannelName() string {
	return "replicate"
}
191
relay/adaptor/replicate/chat.go
Normal file
@@ -0,0 +1,191 @@
package replicate

import (
	"bufio"
	"encoding/json"
	"io"
	"net/http"
	"strings"
	"time"

	"github.com/gin-gonic/gin"
	"github.com/pkg/errors"
	"github.com/songquanpeng/one-api/common"
	"github.com/songquanpeng/one-api/common/render"
	"github.com/songquanpeng/one-api/relay/adaptor/openai"
	"github.com/songquanpeng/one-api/relay/meta"
	"github.com/songquanpeng/one-api/relay/model"
)

func ChatHandler(c *gin.Context, resp *http.Response) (
	srvErr *model.ErrorWithStatusCode, usage *model.Usage) {
	if resp.StatusCode != http.StatusCreated {
		payload, _ := io.ReadAll(resp.Body)
		return openai.ErrorWrapper(
				errors.Errorf("bad_status_code [%d]%s", resp.StatusCode, string(payload)),
				"bad_status_code", http.StatusInternalServerError),
			nil
	}

	respBody, err := io.ReadAll(resp.Body)
	if err != nil {
		return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
	}

	respData := new(ChatResponse)
	if err = json.Unmarshal(respBody, respData); err != nil {
		return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
	}

	for {
		err = func() error {
			// get task
			taskReq, err := http.NewRequestWithContext(c.Request.Context(),
				http.MethodGet, respData.URLs.Get, nil)
			if err != nil {
				return errors.Wrap(err, "new request")
			}

			taskReq.Header.Set("Authorization", "Bearer "+meta.GetByContext(c).APIKey)
			taskResp, err := http.DefaultClient.Do(taskReq)
			if err != nil {
				return errors.Wrap(err, "get task")
			}
			defer taskResp.Body.Close()

			if taskResp.StatusCode != http.StatusOK {
				payload, _ := io.ReadAll(taskResp.Body)
				return errors.Errorf("bad status code [%d]%s",
					taskResp.StatusCode, string(payload))
			}

			taskBody, err := io.ReadAll(taskResp.Body)
			if err != nil {
				return errors.Wrap(err, "read task response")
			}

			taskData := new(ChatResponse)
			if err = json.Unmarshal(taskBody, taskData); err != nil {
				return errors.Wrap(err, "decode task response")
			}

			switch taskData.Status {
			case "succeeded":
			case "failed", "canceled":
				return errors.Errorf("task failed, [%s]%s", taskData.Status, taskData.Error)
			default:
				time.Sleep(time.Second * 3)
				return errNextLoop
			}

			if taskData.URLs.Stream == "" {
				return errors.New("stream url is empty")
			}

			// request stream url
			responseText, err := chatStreamHandler(c, taskData.URLs.Stream)
			if err != nil {
				return errors.Wrap(err, "chat stream handler")
			}

			ctxMeta := meta.GetByContext(c)
			usage = openai.ResponseText2Usage(responseText,
				ctxMeta.ActualModelName, ctxMeta.PromptTokens)
			return nil
		}()
		if err != nil {
			if errors.Is(err, errNextLoop) {
				continue
			}

			return openai.ErrorWrapper(err, "chat_task_failed", http.StatusInternalServerError), nil
		}

		break
	}

	return nil, usage
}
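
// The loop above follows replicate's async prediction flow (a sketch of the
// protocol, not new behavior): the initial POST returns 201 Created with a
// urls.get polling endpoint; the handler polls it every 3 seconds until the
// status reaches "succeeded", then relays urls.stream back to the client.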

const (
	eventPrefix = "event: "
	dataPrefix  = "data: "
	done        = "[DONE]"
)

func chatStreamHandler(c *gin.Context, streamUrl string) (responseText string, err error) {
	// request stream endpoint
	streamReq, err := http.NewRequestWithContext(c.Request.Context(), http.MethodGet, streamUrl, nil)
	if err != nil {
		return "", errors.Wrap(err, "new request to stream")
	}

	streamReq.Header.Set("Authorization", "Bearer "+meta.GetByContext(c).APIKey)
	streamReq.Header.Set("Accept", "text/event-stream")
	streamReq.Header.Set("Cache-Control", "no-store")

	resp, err := http.DefaultClient.Do(streamReq)
	if err != nil {
		return "", errors.Wrap(err, "do request to stream")
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		payload, _ := io.ReadAll(resp.Body)
		return "", errors.Errorf("bad status code [%d]%s", resp.StatusCode, string(payload))
	}

	scanner := bufio.NewScanner(resp.Body)
	scanner.Split(bufio.ScanLines)

	common.SetEventStreamHeaders(c)
	doneRendered := false
	for scanner.Scan() {
		line := strings.TrimSpace(scanner.Text())
		if line == "" {
			continue
		}

		// Skip SSE comments, which start with ':'
		if strings.HasPrefix(line, ":") {
			continue
		}

		// Parse SSE fields
		if strings.HasPrefix(line, eventPrefix) {
			event := strings.TrimSpace(line[len(eventPrefix):])
			var data string
			// Read the following lines of the frame to get data and id
			for scanner.Scan() {
				nextLine := scanner.Text()
				if nextLine == "" {
					break
				}
				if strings.HasPrefix(nextLine, dataPrefix) {
					data = nextLine[len(dataPrefix):]
				} else if strings.HasPrefix(nextLine, "id:") {
					// id = strings.TrimSpace(nextLine[len("id:"):])
				}
			}

			if event == "output" {
				render.StringData(c, data)
				responseText += data
			} else if event == "done" {
				render.Done(c)
				doneRendered = true
				break
			}
		}
	}

	if err := scanner.Err(); err != nil {
		return "", errors.Wrap(err, "scan stream")
	}

	if !doneRendered {
		render.Done(c)
	}

	return responseText, nil
}
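
// The replicate stream parsed above emits SSE frames like (sketch):
//
//	event: output
//	data: Hello
//
//	event: done
//	data: {}
//
// "output" frames are forwarded to the client and concatenated into
// responseText; a "done" frame terminates the stream.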
58
relay/adaptor/replicate/constant.go
Normal file
@@ -0,0 +1,58 @@
package replicate

// ModelList is a list of models that can be used with Replicate.
//
// https://replicate.com/pricing
var ModelList = []string{
	// -------------------------------------
	// image models
	// -------------------------------------
	"black-forest-labs/flux-1.1-pro",
	"black-forest-labs/flux-1.1-pro-ultra",
	"black-forest-labs/flux-canny-dev",
	"black-forest-labs/flux-canny-pro",
	"black-forest-labs/flux-depth-dev",
	"black-forest-labs/flux-depth-pro",
	"black-forest-labs/flux-dev",
	"black-forest-labs/flux-dev-lora",
	"black-forest-labs/flux-fill-dev",
	"black-forest-labs/flux-fill-pro",
	"black-forest-labs/flux-pro",
	"black-forest-labs/flux-redux-dev",
	"black-forest-labs/flux-redux-schnell",
	"black-forest-labs/flux-schnell",
	"black-forest-labs/flux-schnell-lora",
	"ideogram-ai/ideogram-v2",
	"ideogram-ai/ideogram-v2-turbo",
	"recraft-ai/recraft-v3",
	"recraft-ai/recraft-v3-svg",
	"stability-ai/stable-diffusion-3",
	"stability-ai/stable-diffusion-3.5-large",
	"stability-ai/stable-diffusion-3.5-large-turbo",
	"stability-ai/stable-diffusion-3.5-medium",
	// -------------------------------------
	// language models
	// -------------------------------------
	"ibm-granite/granite-20b-code-instruct-8k",
	"ibm-granite/granite-3.0-2b-instruct",
	"ibm-granite/granite-3.0-8b-instruct",
	"ibm-granite/granite-8b-code-instruct-128k",
	"meta/llama-2-13b",
	"meta/llama-2-13b-chat",
	"meta/llama-2-70b",
	"meta/llama-2-70b-chat",
	"meta/llama-2-7b",
	"meta/llama-2-7b-chat",
	"meta/meta-llama-3.1-405b-instruct",
	"meta/meta-llama-3-70b",
	"meta/meta-llama-3-70b-instruct",
	"meta/meta-llama-3-8b",
	"meta/meta-llama-3-8b-instruct",
	"mistralai/mistral-7b-instruct-v0.2",
	"mistralai/mistral-7b-v0.1",
	"mistralai/mixtral-8x7b-instruct-v0.1",
	// -------------------------------------
	// video models
	// -------------------------------------
	// "minimax/video-01", // TODO: implement the adaptor
}
207
relay/adaptor/replicate/image.go
Normal file
@@ -0,0 +1,207 @@
package replicate

import (
	"bytes"
	"encoding/base64"
	"encoding/json"
	"fmt"
	"image"
	_ "image/jpeg" // register the JPEG decoder for image.Decode in ConvertImageToPNG
	"image/png"
	"io"
	"net/http"
	"sync"
	"time"

	"github.com/gin-gonic/gin"
	"github.com/pkg/errors"
	"github.com/songquanpeng/one-api/common/logger"
	"github.com/songquanpeng/one-api/relay/adaptor/openai"
	"github.com/songquanpeng/one-api/relay/meta"
	"github.com/songquanpeng/one-api/relay/model"
	"golang.org/x/image/webp"
	"golang.org/x/sync/errgroup"
)

var errNextLoop = errors.New("next_loop")

// ImageHandler handles the response from the image creation or remix request
func ImageHandler(c *gin.Context, resp *http.Response) (
	*model.ErrorWithStatusCode, *model.Usage) {
	if resp.StatusCode != http.StatusCreated {
		payload, _ := io.ReadAll(resp.Body)
		return openai.ErrorWrapper(
				errors.Errorf("bad_status_code [%d]%s", resp.StatusCode, string(payload)),
				"bad_status_code", http.StatusInternalServerError),
			nil
	}

	respBody, err := io.ReadAll(resp.Body)
	if err != nil {
		return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
	}

	respData := new(ImageResponse)
	if err = json.Unmarshal(respBody, respData); err != nil {
		return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
	}

	for {
		err = func() error {
			// get task
			taskReq, err := http.NewRequestWithContext(c.Request.Context(),
				http.MethodGet, respData.URLs.Get, nil)
			if err != nil {
				return errors.Wrap(err, "new request")
			}

			taskReq.Header.Set("Authorization", "Bearer "+meta.GetByContext(c).APIKey)
			taskResp, err := http.DefaultClient.Do(taskReq)
			if err != nil {
				return errors.Wrap(err, "get task")
			}
			defer taskResp.Body.Close()

			if taskResp.StatusCode != http.StatusOK {
				payload, _ := io.ReadAll(taskResp.Body)
				return errors.Errorf("bad status code [%d]%s",
					taskResp.StatusCode, string(payload))
			}

			taskBody, err := io.ReadAll(taskResp.Body)
			if err != nil {
				return errors.Wrap(err, "read task response")
			}

			taskData := new(ImageResponse)
			if err = json.Unmarshal(taskBody, taskData); err != nil {
				return errors.Wrap(err, "decode task response")
			}

			switch taskData.Status {
			case "succeeded":
			case "failed", "canceled":
				return errors.Errorf("task failed, [%s]%s", taskData.Status, taskData.Error)
			default:
				time.Sleep(time.Second * 3)
				return errNextLoop
			}

			output, err := taskData.GetOutput()
			if err != nil {
				return errors.Wrap(err, "get output")
			}
			if len(output) == 0 {
				return errors.New("response output is empty")
			}

			var mu sync.Mutex
			var pool errgroup.Group
			imageResp := &openai.ImageResponse{
				Created: taskData.CompletedAt.Unix(),
				Data:    []openai.ImageData{},
			}

			for _, imgOut := range output {
				imgOut := imgOut
				pool.Go(func() error {
					// download image
					downloadReq, err := http.NewRequestWithContext(c.Request.Context(),
						http.MethodGet, imgOut, nil)
					if err != nil {
						return errors.Wrap(err, "new request")
					}

					imgResp, err := http.DefaultClient.Do(downloadReq)
					if err != nil {
						return errors.Wrap(err, "download image")
					}
					defer imgResp.Body.Close()

					if imgResp.StatusCode != http.StatusOK {
						payload, _ := io.ReadAll(imgResp.Body)
						return errors.Errorf("bad status code [%d]%s",
							imgResp.StatusCode, string(payload))
					}

					imgData, err := io.ReadAll(imgResp.Body)
					if err != nil {
						return errors.Wrap(err, "read image")
					}

					imgData, err = ConvertImageToPNG(imgData)
					if err != nil {
						return errors.Wrap(err, "convert image")
					}

					mu.Lock()
					imageResp.Data = append(imageResp.Data, openai.ImageData{
						B64Json: fmt.Sprintf("data:image/png;base64,%s",
							base64.StdEncoding.EncodeToString(imgData)),
					})
					mu.Unlock()

					return nil
				})
			}

			if err := pool.Wait(); err != nil {
				if len(imageResp.Data) == 0 {
					return errors.WithStack(err)
				}

				logger.Error(c, fmt.Sprintf("some images failed to download: %+v", err))
			}

			c.JSON(http.StatusOK, imageResp)
			return nil
		}()
		if err != nil {
			if errors.Is(err, errNextLoop) {
				continue
			}

			return openai.ErrorWrapper(err, "image_task_failed", http.StatusInternalServerError), nil
		}

		break
	}

	return nil, nil
}

// ConvertImageToPNG converts a WebP or JPEG image to PNG format; PNG input is
// passed through unchanged.
func ConvertImageToPNG(webpData []byte) ([]byte, error) {
	// bypass if it's already a PNG image
	if bytes.HasPrefix(webpData, []byte("\x89PNG")) {
		return webpData, nil
	}

	// if it's a JPEG, convert it to PNG
	if bytes.HasPrefix(webpData, []byte("\xff\xd8\xff")) {
		img, _, err := image.Decode(bytes.NewReader(webpData))
		if err != nil {
			return nil, errors.Wrap(err, "decode jpeg")
		}

		var pngBuffer bytes.Buffer
		if err := png.Encode(&pngBuffer, img); err != nil {
			return nil, errors.Wrap(err, "encode png")
		}

		return pngBuffer.Bytes(), nil
	}

	// Decode the WebP image
	img, err := webp.Decode(bytes.NewReader(webpData))
	if err != nil {
		return nil, errors.Wrap(err, "decode webp")
	}

	// Encode the image as PNG
	var pngBuffer bytes.Buffer
	if err := png.Encode(&pngBuffer, img); err != nil {
		return nil, errors.Wrap(err, "encode png")
	}

	return pngBuffer.Bytes(), nil
}
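
// Usage sketch (not part of the original file): the function dispatches on
// magic bytes, so callers can pass PNG ("\x89PNG"), JPEG ("\xff\xd8\xff"), or
// WebP data without knowing the format in advance:
//
//	pngData, err := ConvertImageToPNG(downloadedBytes)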
277
relay/adaptor/replicate/model.go
Normal file
@@ -0,0 +1,277 @@
package replicate

import (
	"bytes"
	"encoding/base64"
	"image"
	"image/png"
	"io"
	"mime/multipart"
	"time"

	"github.com/pkg/errors"
)

type OpenaiImageEditRequest struct {
	Image          *multipart.FileHeader `json:"image" form:"image" binding:"required"`
	Prompt         string                `json:"prompt" form:"prompt" binding:"required"`
	Mask           *multipart.FileHeader `json:"mask" form:"mask" binding:"required"`
	Model          string                `json:"model" form:"model" binding:"required"`
	N              int                   `json:"n" form:"n" binding:"min=0,max=10"`
	Size           string                `json:"size" form:"size"`
	ResponseFormat string                `json:"response_format" form:"response_format"`
}

// toFluxRemixRequest converts OpenAI's image edit request to Flux's remix request.
//
// Note that the mask formats of OpenAI and Flux are different:
// OpenAI's mask marks the parts to be modified as transparent (0, 0, 0, 0),
// while Flux marks the parts to be modified as white (255, 255, 255, 255),
// so we need to convert the format here.
//
// Both OpenAI's Image and Mask are browser-native ImageData,
// which need to be converted to base64 dataURI format.
func (r *OpenaiImageEditRequest) toFluxRemixRequest() (*InpaintingImageByFlusReplicateRequest, error) {
	if r.ResponseFormat != "b64_json" {
		return nil, errors.New("response_format must be b64_json for replicate models")
	}

	fluxReq := &InpaintingImageByFlusReplicateRequest{
		Input: FluxInpaintingInput{
			Prompt:           r.Prompt,
			Seed:             int(time.Now().UnixNano()),
			Steps:            30,
			Guidance:         3,
			SafetyTolerance:  5,
			PromptUnsampling: false,
			OutputFormat:     "png",
		},
	}

	imgFile, err := r.Image.Open()
	if err != nil {
		return nil, errors.Wrap(err, "open image file")
	}
	defer imgFile.Close()
	imgData, err := io.ReadAll(imgFile)
	if err != nil {
		return nil, errors.Wrap(err, "read image file")
	}

	maskFile, err := r.Mask.Open()
	if err != nil {
		return nil, errors.Wrap(err, "open mask file")
	}
	defer maskFile.Close()

	// Convert image to base64
	imageBase64 := "data:image/png;base64," + base64.StdEncoding.EncodeToString(imgData)
	fluxReq.Input.Image = imageBase64

	// Decode the mask so its pixels can be inspected
	maskPNG, err := png.Decode(maskFile)
	if err != nil {
		return nil, errors.Wrap(err, "decode mask file")
	}

	// convert mask to RGBA
	var maskRGBA *image.RGBA
	switch converted := maskPNG.(type) {
	case *image.RGBA:
		maskRGBA = converted
	default:
		// Convert to RGBA
		bounds := maskPNG.Bounds()
		maskRGBA = image.NewRGBA(bounds)
		for y := bounds.Min.Y; y < bounds.Max.Y; y++ {
			for x := bounds.Min.X; x < bounds.Max.X; x++ {
				maskRGBA.Set(x, y, maskPNG.At(x, y))
			}
		}
	}

	maskData := maskRGBA.Pix
	invertedMask := make([]byte, len(maskData))
	for i := 0; i+4 <= len(maskData); i += 4 {
		// If the pixel is transparent (alpha = 0), make it white (255, 255, 255, 255)
		if maskData[i+3] == 0 {
			invertedMask[i] = 255   // R
			invertedMask[i+1] = 255 // G
			invertedMask[i+2] = 255 // B
			invertedMask[i+3] = 255 // A
		} else {
			// Copy original pixel
			copy(invertedMask[i:i+4], maskData[i:i+4])
		}
	}

	// Convert the inverted mask to a base64-encoded png image
	invertedMaskRGBA := &image.RGBA{
		Pix:    invertedMask,
		Stride: maskRGBA.Stride,
		Rect:   maskRGBA.Rect,
	}

	var buf bytes.Buffer
	err = png.Encode(&buf, invertedMaskRGBA)
	if err != nil {
		return nil, errors.Wrap(err, "encode inverted mask to png")
	}

	invertedMaskBase64 := "data:image/png;base64," + base64.StdEncoding.EncodeToString(buf.Bytes())
	fluxReq.Input.Mask = invertedMaskBase64

	return fluxReq, nil
}
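
// Concretely (sketch): an OpenAI mask pixel of (0, 0, 0, 0), meaning "edit
// this region", becomes (255, 255, 255, 255) in the Flux mask, while every
// opaque pixel is copied through unchanged.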

// DrawImageRequest draws an image with flux-pro
//
// https://replicate.com/black-forest-labs/flux-pro?prediction=kg1krwsdf9rg80ch1sgsrgq7h8&output=json
type DrawImageRequest struct {
	Input ImageInput `json:"input"`
}

// ImageInput is the input of DrawImageRequest
//
// https://replicate.com/black-forest-labs/flux-1.1-pro/api/schema
type ImageInput struct {
	Steps           int    `json:"steps" binding:"required,min=1"`
	Prompt          string `json:"prompt" binding:"required,min=5"`
	ImagePrompt     string `json:"image_prompt"`
	Guidance        int    `json:"guidance" binding:"required,min=2,max=5"`
	Interval        int    `json:"interval" binding:"required,min=1,max=4"`
	AspectRatio     string `json:"aspect_ratio" binding:"required,oneof=1:1 16:9 2:3 3:2 4:5 5:4 9:16"`
	SafetyTolerance int    `json:"safety_tolerance" binding:"required,min=1,max=5"`
	Seed            int    `json:"seed"`
	NImages         int    `json:"n_images" binding:"required,min=1,max=8"`
	Width           int    `json:"width" binding:"required,min=256,max=1440"`
	Height          int    `json:"height" binding:"required,min=256,max=1440"`
}
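
// A request built by convertImageCreateRequest above serializes to roughly
// (sketch, some zero-value fields omitted):
//
//	{"input": {"steps": 25, "prompt": "...", "guidance": 3, "seed": ...,
//	           "safety_tolerance": 5, "n_images": 1, "width": 1440,
//	           "height": 1440, "aspect_ratio": "1:1"}}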
|
||||
|
||||
// InpaintingImageByFlusReplicateRequest is request to inpainting image by flux pro
|
||||
//
|
||||
// https://replicate.com/black-forest-labs/flux-fill-pro/api/schema
|
||||
type InpaintingImageByFlusReplicateRequest struct {
|
||||
Input FluxInpaintingInput `json:"input"`
|
||||
}
|
||||
|
||||
// FluxInpaintingInput is input of DrawImageByFluxProRequest
|
||||
//
|
||||
// https://replicate.com/black-forest-labs/flux-fill-pro/api/schema
|
||||
type FluxInpaintingInput struct {
|
||||
Mask string `json:"mask" binding:"required"`
|
||||
Image string `json:"image" binding:"required"`
|
||||
Seed int `json:"seed"`
|
||||
Steps int `json:"steps" binding:"required,min=1"`
|
||||
Prompt string `json:"prompt" binding:"required,min=5"`
|
||||
Guidance int `json:"guidance" binding:"required,min=2,max=5"`
|
||||
OutputFormat string `json:"output_format"`
|
||||
SafetyTolerance int `json:"safety_tolerance" binding:"required,min=1,max=5"`
|
||||
PromptUnsampling bool `json:"prompt_unsampling"`
|
||||
}
|
||||
|
||||
// ImageResponse is response of DrawImageByFluxProRequest
|
||||
//
|
||||
// https://replicate.com/black-forest-labs/flux-pro?prediction=kg1krwsdf9rg80ch1sgsrgq7h8&output=json
|
||||
type ImageResponse struct {
|
||||
CompletedAt time.Time `json:"completed_at"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
DataRemoved bool `json:"data_removed"`
|
||||
Error string `json:"error"`
|
||||
ID string `json:"id"`
|
||||
Input DrawImageRequest `json:"input"`
|
||||
Logs string `json:"logs"`
|
||||
Metrics FluxMetrics `json:"metrics"`
|
||||
// Output could be `string` or `[]string`
|
||||
Output any `json:"output"`
|
||||
StartedAt time.Time `json:"started_at"`
|
||||
Status string `json:"status"`
|
||||
URLs FluxURLs `json:"urls"`
|
||||
Version string `json:"version"`
|
||||
}
|
||||
|
||||
func (r *ImageResponse) GetOutput() ([]string, error) {
|
||||
switch v := r.Output.(type) {
|
||||
case string:
|
||||
return []string{v}, nil
|
||||
case []string:
|
||||
return v, nil
|
||||
case nil:
|
||||
return nil, nil
|
||||
case []interface{}:
|
||||
// convert []interface{} to []string
|
||||
ret := make([]string, len(v))
|
||||
for idx, vv := range v {
|
||||
if vvv, ok := vv.(string); ok {
|
||||
ret[idx] = vvv
|
||||
} else {
|
||||
return nil, errors.Errorf("unknown output type: [%T]%v", vv, vv)
|
||||
}
|
||||
}
|
||||
|
||||
return ret, nil
|
||||
default:
|
||||
return nil, errors.Errorf("unknown output type: [%T]%v", r.Output, r.Output)
|
||||
}
|
||||
}
|
||||
|
||||
// FluxMetrics is metrics of ImageResponse
|
||||
type FluxMetrics struct {
|
||||
ImageCount int `json:"image_count"`
|
||||
PredictTime float64 `json:"predict_time"`
|
||||
TotalTime float64 `json:"total_time"`
|
||||
}
|
||||
|
||||
// FluxURLs is urls of ImageResponse
|
||||
type FluxURLs struct {
|
||||
Get string `json:"get"`
|
||||
Cancel string `json:"cancel"`
|
||||
}

type ReplicateChatRequest struct {
	Input ChatInput `json:"input" form:"input" binding:"required"`
}

// ChatInput is the input of ReplicateChatRequest
//
// https://replicate.com/meta/meta-llama-3.1-405b-instruct/api/schema
type ChatInput struct {
	TopK             int     `json:"top_k"`
	TopP             float64 `json:"top_p"`
	Prompt           string  `json:"prompt"`
	MaxTokens        int     `json:"max_tokens"`
	MinTokens        int     `json:"min_tokens"`
	Temperature      float64 `json:"temperature"`
	SystemPrompt     string  `json:"system_prompt"`
	StopSequences    string  `json:"stop_sequences"`
	PromptTemplate   string  `json:"prompt_template"`
	PresencePenalty  float64 `json:"presence_penalty"`
	FrequencyPenalty float64 `json:"frequency_penalty"`
}

// ChatResponse is the response of ReplicateChatRequest
//
// https://replicate.com/meta/meta-llama-3.1-405b-instruct/examples?input=http&output=json
type ChatResponse struct {
	CompletedAt time.Time   `json:"completed_at"`
	CreatedAt   time.Time   `json:"created_at"`
	DataRemoved bool        `json:"data_removed"`
	Error       string      `json:"error"`
	ID          string      `json:"id"`
	Input       ChatInput   `json:"input"`
	Logs        string      `json:"logs"`
	Metrics     FluxMetrics `json:"metrics"`
	// Output is the generated text, returned as a list of string chunks
	Output    []string        `json:"output"`
	StartedAt time.Time       `json:"started_at"`
	Status    string          `json:"status"`
	URLs      ChatResponseUrl `json:"urls"`
	Version   string          `json:"version"`
}

// ChatResponseUrl is the task urls of ChatResponse
type ChatResponseUrl struct {
	Stream string `json:"stream"`
	Get    string `json:"get"`
	Cancel string `json:"cancel"`
}
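
// Editor's sketch (not part of the original file): assembling a chat request
// for meta-llama-3.1-405b-instruct. The parameter values are hypothetical
// defaults, not values taken from this adaptor.
func exampleChatRequest(prompt string) ReplicateChatRequest {
	return ReplicateChatRequest{
		Input: ChatInput{
			Prompt:      prompt,
			MaxTokens:   1024,
			Temperature: 0.7,
			TopP:        0.9,
			TopK:        50,
		},
	}
}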
106
relay/adaptor/replicate/model_test.go
Normal file
@@ -0,0 +1,106 @@
package replicate

import (
	"bytes"
	"image"
	"image/draw"
	"image/png"
	"io"
	"mime/multipart"
	"net/http"
	"testing"

	"github.com/stretchr/testify/require"
)

type nopCloser struct {
	io.Reader
}

func (n nopCloser) Close() error { return nil }

// Custom FileHeader to override Open method
type customFileHeader struct {
	*multipart.FileHeader
	openFunc func() (multipart.File, error)
}

func (c *customFileHeader) Open() (multipart.File, error) {
	return c.openFunc()
}

func TestOpenaiImageEditRequest_toFluxRemixRequest(t *testing.T) {
	// Create a simple image for testing
	img := image.NewRGBA(image.Rect(0, 0, 10, 10))
	draw.Draw(img, img.Bounds(), image.Black, image.Point{}, draw.Src)
	var imgBuf bytes.Buffer
	err := png.Encode(&imgBuf, img)
	require.NoError(t, err)

	// Create a simple mask for testing
	mask := image.NewRGBA(image.Rect(0, 0, 10, 10))
	draw.Draw(mask, mask.Bounds(), image.Black, image.Point{}, draw.Src)
	var maskBuf bytes.Buffer
	err = png.Encode(&maskBuf, mask)
	require.NoError(t, err)

	// Create a multipart.FileHeader from the image and mask bytes
	imgFileHeader, err := createFileHeader("image", "test.png", imgBuf.Bytes())
	require.NoError(t, err)
	maskFileHeader, err := createFileHeader("mask", "test.png", maskBuf.Bytes())
	require.NoError(t, err)

	req := &OpenaiImageEditRequest{
		Image:          imgFileHeader,
		Mask:           maskFileHeader,
		Prompt:         "Test prompt",
		Model:          "test-model",
		ResponseFormat: "b64_json",
	}

	fluxReq, err := req.toFluxRemixRequest()
	require.NoError(t, err)
	require.NotNil(t, fluxReq)
	require.Equal(t, req.Prompt, fluxReq.Input.Prompt)
	require.NotEmpty(t, fluxReq.Input.Image)
	require.NotEmpty(t, fluxReq.Input.Mask)
}

// createFileHeader creates a multipart.FileHeader from file bytes
func createFileHeader(fieldname, filename string, fileBytes []byte) (*multipart.FileHeader, error) {
	body := &bytes.Buffer{}
	writer := multipart.NewWriter(body)

	// Create a form file field
	part, err := writer.CreateFormFile(fieldname, filename)
	if err != nil {
		return nil, err
	}

	// Write the file bytes to the form file field
	_, err = part.Write(fileBytes)
	if err != nil {
		return nil, err
	}

	// Close the writer to finalize the form
	err = writer.Close()
	if err != nil {
		return nil, err
	}

	// Parse the multipart form
	req := &http.Request{
		Header: http.Header{},
		Body:   io.NopCloser(body),
	}
	req.Header.Set("Content-Type", writer.FormDataContentType())
	err = req.ParseMultipartForm(int64(body.Len()))
	if err != nil {
		return nil, err
	}

	// Retrieve the file header from the parsed form
	fileHeader := req.MultipartForm.File[fieldname][0]
	return fileHeader, nil
}
36
relay/adaptor/siliconflow/constants.go
Normal file
@@ -0,0 +1,36 @@
package siliconflow

// https://docs.siliconflow.cn/docs/getting-started

var ModelList = []string{
	"deepseek-ai/deepseek-llm-67b-chat",
	"Qwen/Qwen1.5-14B-Chat",
	"Qwen/Qwen1.5-7B-Chat",
	"Qwen/Qwen1.5-110B-Chat",
	"Qwen/Qwen1.5-32B-Chat",
	"01-ai/Yi-1.5-6B-Chat",
	"01-ai/Yi-1.5-9B-Chat-16K",
	"01-ai/Yi-1.5-34B-Chat-16K",
	"THUDM/chatglm3-6b",
	"deepseek-ai/DeepSeek-V2-Chat",
	"THUDM/glm-4-9b-chat",
	"Qwen/Qwen2-72B-Instruct",
	"Qwen/Qwen2-7B-Instruct",
	"Qwen/Qwen2-57B-A14B-Instruct",
	"deepseek-ai/DeepSeek-Coder-V2-Instruct",
	"Qwen/Qwen2-1.5B-Instruct",
	"internlm/internlm2_5-7b-chat",
	"BAAI/bge-large-en-v1.5",
	"BAAI/bge-large-zh-v1.5",
	"Pro/Qwen/Qwen2-7B-Instruct",
	"Pro/Qwen/Qwen2-1.5B-Instruct",
	"Pro/Qwen/Qwen1.5-7B-Chat",
	"Pro/THUDM/glm-4-9b-chat",
	"Pro/THUDM/chatglm3-6b",
	"Pro/01-ai/Yi-1.5-9B-Chat-16K",
	"Pro/01-ai/Yi-1.5-6B-Chat",
	"Pro/google/gemma-2-9b-it",
	"Pro/internlm/internlm2_5-7b-chat",
	"Pro/meta-llama/Meta-Llama-3-8B-Instruct",
	"Pro/mistralai/Mistral-7B-Instruct-v0.2",
}
13
relay/adaptor/stepfun/constants.go
Normal file
@@ -0,0 +1,13 @@
package stepfun

var ModelList = []string{
	"step-1-8k",
	"step-1-32k",
	"step-1-128k",
	"step-1-256k",
	"step-1-flash",
	"step-2-16k",
	"step-1v-8k",
	"step-1v-32k",
	"step-1x-medium",
}
90
relay/adaptor/tencent/adaptor.go
Normal file
@@ -0,0 +1,90 @@
package tencent

import (
	"errors"
	"io"
	"net/http"
	"strconv"
	"strings"

	"github.com/gin-gonic/gin"
	"github.com/songquanpeng/one-api/common/helper"
	"github.com/songquanpeng/one-api/relay/adaptor"
	"github.com/songquanpeng/one-api/relay/adaptor/openai"
	"github.com/songquanpeng/one-api/relay/meta"
	"github.com/songquanpeng/one-api/relay/model"
)

// https://cloud.tencent.com/document/api/1729/101837

type Adaptor struct {
	Sign      string
	Action    string
	Version   string
	Timestamp int64
}

func (a *Adaptor) Init(meta *meta.Meta) {
	a.Action = "ChatCompletions"
	a.Version = "2023-09-01"
	a.Timestamp = helper.GetTimestamp()
}

func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) {
	return meta.BaseURL + "/", nil
}

func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error {
	adaptor.SetupCommonRequestHeader(c, req, meta)
	req.Header.Set("Authorization", a.Sign)
	req.Header.Set("X-TC-Action", a.Action)
	req.Header.Set("X-TC-Version", a.Version)
	req.Header.Set("X-TC-Timestamp", strconv.FormatInt(a.Timestamp, 10))
	return nil
}

func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) {
	if request == nil {
		return nil, errors.New("request is nil")
	}
	apiKey := c.Request.Header.Get("Authorization")
	apiKey = strings.TrimPrefix(apiKey, "Bearer ")
	_, secretId, secretKey, err := ParseConfig(apiKey)
	if err != nil {
		return nil, err
	}
	tencentRequest := ConvertRequest(*request)
	// we have to calculate the sign here
	a.Sign = GetSign(*tencentRequest, a, secretId, secretKey)
	return tencentRequest, nil
}

func (a *Adaptor) ConvertImageRequest(_ *gin.Context, request *model.ImageRequest) (any, error) {
	if request == nil {
		return nil, errors.New("request is nil")
	}
	return request, nil
}

func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) {
	return adaptor.DoRequestHelper(a, c, meta, requestBody)
}

func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
	if meta.IsStream {
		var responseText string
		err, responseText = StreamHandler(c, resp)
		usage = openai.ResponseText2Usage(responseText, meta.ActualModelName, meta.PromptTokens)
	} else {
		err, usage = Handler(c, resp)
	}
	return
}

func (a *Adaptor) GetModelList() []string {
	return ModelList
}

func (a *Adaptor) GetChannelName() string {
	return "tencent"
}
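
// Editor's note (not part of the original file): a lifecycle sketch of this
// adaptor. The ordering is an assumption based on the method set above, not
// documented framework behavior.
//
//	a := &Adaptor{}
//	a.Init(meta)                                         // sets Action, Version, Timestamp
//	converted, err := a.ConvertRequest(c, mode, request) // also computes a.Sign via GetSign
//	err = a.SetupRequestHeader(c, req, meta)             // attaches Authorization and X-TC-* headers
//	resp, err := a.DoRequest(c, meta, body)
//	usage, relayErr := a.DoResponse(c, resp, meta)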
9
relay/adaptor/tencent/constants.go
Normal file
@@ -0,0 +1,9 @@
package tencent

var ModelList = []string{
	"hunyuan-lite",
	"hunyuan-standard",
	"hunyuan-standard-256K",
	"hunyuan-pro",
	"hunyuan-vision",
}
245
relay/adaptor/tencent/main.go
Normal file
@@ -0,0 +1,245 @@
package tencent

import (
	"bufio"
	"crypto/hmac"
	"crypto/sha256"
	"encoding/hex"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"net/http"
	"strconv"
	"strings"
	"time"

	"github.com/gin-gonic/gin"
	"github.com/songquanpeng/one-api/common"
	"github.com/songquanpeng/one-api/common/conv"
	"github.com/songquanpeng/one-api/common/helper"
	"github.com/songquanpeng/one-api/common/logger"
	"github.com/songquanpeng/one-api/common/random"
	"github.com/songquanpeng/one-api/common/render"
	"github.com/songquanpeng/one-api/relay/adaptor/openai"
	"github.com/songquanpeng/one-api/relay/constant"
	"github.com/songquanpeng/one-api/relay/model"
)

func ConvertRequest(request model.GeneralOpenAIRequest) *ChatRequest {
	messages := make([]*Message, 0, len(request.Messages))
	for i := 0; i < len(request.Messages); i++ {
		message := request.Messages[i]
		messages = append(messages, &Message{
			Content: message.StringContent(),
			Role:    message.Role,
		})
	}
	return &ChatRequest{
		Model:       &request.Model,
		Stream:      &request.Stream,
		Messages:    messages,
		TopP:        request.TopP,
		Temperature: request.Temperature,
	}
}

func responseTencent2OpenAI(response *ChatResponse) *openai.TextResponse {
	fullTextResponse := openai.TextResponse{
		Object:  "chat.completion",
		Created: helper.GetTimestamp(),
		Usage: model.Usage{
			PromptTokens:     response.Usage.PromptTokens,
			CompletionTokens: response.Usage.CompletionTokens,
			TotalTokens:      response.Usage.TotalTokens,
		},
	}
	if len(response.Choices) > 0 {
		choice := openai.TextResponseChoice{
			Index: 0,
			Message: model.Message{
				Role:    "assistant",
				Content: response.Choices[0].Messages.Content,
			},
			FinishReason: response.Choices[0].FinishReason,
		}
		fullTextResponse.Choices = append(fullTextResponse.Choices, choice)
	}
	return &fullTextResponse
}

func streamResponseTencent2OpenAI(tencentResponse *ChatResponse) *openai.ChatCompletionsStreamResponse {
	response := openai.ChatCompletionsStreamResponse{
		Id:      fmt.Sprintf("chatcmpl-%s", random.GetUUID()),
		Object:  "chat.completion.chunk",
		Created: helper.GetTimestamp(),
		Model:   "tencent-hunyuan",
	}
	if len(tencentResponse.Choices) > 0 {
		var choice openai.ChatCompletionsStreamResponseChoice
		choice.Delta.Content = tencentResponse.Choices[0].Delta.Content
		if tencentResponse.Choices[0].FinishReason == "stop" {
			choice.FinishReason = &constant.StopFinishReason
		}
		response.Choices = append(response.Choices, choice)
	}
	return &response
}

func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, string) {
	var responseText string
	scanner := bufio.NewScanner(resp.Body)
	scanner.Split(bufio.ScanLines)

	common.SetEventStreamHeaders(c)

	for scanner.Scan() {
		data := scanner.Text()
		if len(data) < 5 || !strings.HasPrefix(data, "data:") {
			continue
		}
		data = strings.TrimPrefix(data, "data:")

		var tencentResponse ChatResponse
		err := json.Unmarshal([]byte(data), &tencentResponse)
		if err != nil {
			logger.SysError("error unmarshalling stream response: " + err.Error())
			continue
		}

		response := streamResponseTencent2OpenAI(&tencentResponse)
		if len(response.Choices) != 0 {
			responseText += conv.AsString(response.Choices[0].Delta.Content)
		}

		err = render.ObjectData(c, response)
		if err != nil {
			logger.SysError(err.Error())
		}
	}

	if err := scanner.Err(); err != nil {
		logger.SysError("error reading stream: " + err.Error())
	}

	render.Done(c)

	err := resp.Body.Close()
	if err != nil {
		return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), ""
	}

	return nil, responseText
}
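
// Editor's note (not part of the original file): StreamHandler expects SSE
// lines of the shape below; everything after the "data:" prefix is
// unmarshalled into ChatResponse. The payload is hypothetical.
//
//	data:{"Choices":[{"Delta":{"Role":"assistant","Content":"Hello"},"FinishReason":""}]}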

func Handler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
	var tencentResponse ChatResponse
	var responseP ChatResponseP
	responseBody, err := io.ReadAll(resp.Body)
	if err != nil {
		return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
	}
	err = resp.Body.Close()
	if err != nil {
		return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
	}
	err = json.Unmarshal(responseBody, &responseP)
	if err != nil {
		return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
	}
	tencentResponse = responseP.Response
	if tencentResponse.Error.Code != 0 {
		return &model.ErrorWithStatusCode{
			Error: model.Error{
				Message: tencentResponse.Error.Message,
				Code:    tencentResponse.Error.Code,
			},
			StatusCode: resp.StatusCode,
		}, nil
	}
	fullTextResponse := responseTencent2OpenAI(&tencentResponse)
	fullTextResponse.Model = "hunyuan"
	jsonResponse, err := json.Marshal(fullTextResponse)
	if err != nil {
		return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
	}
	c.Writer.Header().Set("Content-Type", "application/json")
	c.Writer.WriteHeader(resp.StatusCode)
	_, err = c.Writer.Write(jsonResponse)
	if err != nil {
		return openai.ErrorWrapper(err, "write_response_body_failed", http.StatusInternalServerError), nil
	}
	return nil, &fullTextResponse.Usage
}

func ParseConfig(config string) (appId int64, secretId string, secretKey string, err error) {
	parts := strings.Split(config, "|")
	if len(parts) != 3 {
		err = errors.New("invalid tencent config")
		return
	}
	appId, err = strconv.ParseInt(parts[0], 10, 64)
	secretId = parts[1]
	secretKey = parts[2]
	return
}
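
// Editor's sketch (not part of the original file): the channel key is three
// fields joined by "|"; the values below are hypothetical.
//
//	appId, secretId, secretKey, err := ParseConfig("1300000000|AKIDexample|secretExample")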

func sha256hex(s string) string {
	b := sha256.Sum256([]byte(s))
	return hex.EncodeToString(b[:])
}

func hmacSha256(s, key string) string {
	hashed := hmac.New(sha256.New, []byte(key))
	hashed.Write([]byte(s))
	return string(hashed.Sum(nil))
}

func GetSign(req ChatRequest, adaptor *Adaptor, secId, secKey string) string {
	// build canonical request string
	host := "hunyuan.tencentcloudapi.com"
	httpRequestMethod := "POST"
	canonicalURI := "/"
	canonicalQueryString := ""
	canonicalHeaders := fmt.Sprintf("content-type:%s\nhost:%s\nx-tc-action:%s\n",
		"application/json", host, strings.ToLower(adaptor.Action))
	signedHeaders := "content-type;host;x-tc-action"
	payload, _ := json.Marshal(req)
	hashedRequestPayload := sha256hex(string(payload))
	canonicalRequest := fmt.Sprintf("%s\n%s\n%s\n%s\n%s\n%s",
		httpRequestMethod,
		canonicalURI,
		canonicalQueryString,
		canonicalHeaders,
		signedHeaders,
		hashedRequestPayload)
	// build string to sign
	algorithm := "TC3-HMAC-SHA256"
	requestTimestamp := strconv.FormatInt(adaptor.Timestamp, 10)
	t := time.Unix(adaptor.Timestamp, 0).UTC()
	// must be the format 2006-01-02, ref to package time for more info
	date := t.Format("2006-01-02")
	credentialScope := fmt.Sprintf("%s/%s/tc3_request", date, "hunyuan")
	hashedCanonicalRequest := sha256hex(canonicalRequest)
	string2sign := fmt.Sprintf("%s\n%s\n%s\n%s",
		algorithm,
		requestTimestamp,
		credentialScope,
		hashedCanonicalRequest)

	// sign string
	secretDate := hmacSha256(date, "TC3"+secKey)
	secretService := hmacSha256("hunyuan", secretDate)
	secretKey := hmacSha256("tc3_request", secretService)
	signature := hex.EncodeToString([]byte(hmacSha256(string2sign, secretKey)))

	// build authorization
	authorization := fmt.Sprintf("%s Credential=%s/%s, SignedHeaders=%s, Signature=%s",
		algorithm,
		secId,
		credentialScope,
		signedHeaders,
		signature)
	return authorization
}
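
// Editor's note (not part of the original file): GetSign follows Tencent
// Cloud's TC3-HMAC-SHA256 scheme; schematically, the key-derivation chain is:
//
//	secretDate    = HMAC-SHA256(key="TC3"+secretKey, msg=date)
//	secretService = HMAC-SHA256(key=secretDate,      msg="hunyuan")
//	secretSigning = HMAC-SHA256(key=secretService,   msg="tc3_request")
//	signature     = hex(HMAC-SHA256(key=secretSigning, msg=stringToSign))
//
// hmacSha256 deliberately returns the raw digest bytes as a string so each
// digest can feed in as the next key; only the final signature is hex-encoded.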
75
relay/adaptor/tencent/model.go
Normal file
@@ -0,0 +1,75 @@
package tencent

type Message struct {
	Role    string `json:"Role"`
	Content string `json:"Content"`
}

type ChatRequest struct {
	// Model name; valid values include hunyuan-lite, hunyuan-standard,
	// hunyuan-standard-256K, and hunyuan-pro.
	// See the product overview (https://cloud.tencent.com/document/product/1729/104753) for a description of each model.
	//
	// Note:
	// Models are billed differently; see the pricing guide (https://cloud.tencent.com/document/product/1729/97731) and call them as needed.
	Model *string `json:"Model"`
	// Chat context.
	// Notes:
	// 1. At most 40 entries, ordered oldest to newest by conversation time.
	// 2. Message.Role may be system, user, or assistant. The system role is optional and, if present, must come first. user and assistant must alternate (one question, one answer), starting and ending with a user question, and Content must not be empty. Example role order: [system (optional) user assistant user assistant user ...].
	// 3. The total length of Content across Messages must not exceed the model's input limit (see the product overview at https://cloud.tencent.com/document/product/1729/104753); if exceeded, the oldest content is truncated and only the tail is kept.
	Messages []*Message `json:"Messages"`
	// Streaming switch.
	// Notes:
	// 1. Defaults to non-streaming (false) when unset.
	// 2. In streaming mode, results are returned incrementally over SSE (read Choices[n].Delta and concatenate the increments to get the full result).
	// 3. In non-streaming mode:
	//    The call behaves like an ordinary HTTP request.
	//    The response takes longer to return; **set this to true if you need lower latency**.
	//    The final result is returned exactly once (read Choices[n].Message).
	//
	// Note:
	// When calling through the SDK, streaming and non-streaming calls retrieve results in **different ways**; see the comments or examples in the SDK (under the examples/hunyuan/v20230901/ directory of each language's SDK repository).
	Stream *bool `json:"Stream"`
	// Notes:
	// 1. Controls output diversity; larger values produce more diverse text.
	// 2. Range is [0.0, 1.0]; each model's recommended value is used when unset.
	// 3. Not recommended unless necessary; unreasonable values degrade results.
	TopP *float64 `json:"TopP"`
	// Notes:
	// 1. Higher values make the output more random; lower values make it more focused and deterministic.
	// 2. Range is [0.0, 2.0]; each model's recommended value is used when unset.
	// 3. Not recommended unless necessary; unreasonable values degrade results.
	Temperature *float64 `json:"Temperature"`
}

type Error struct {
	Code    int    `json:"Code"`
	Message string `json:"Message"`
}

type Usage struct {
	PromptTokens     int `json:"PromptTokens"`
	CompletionTokens int `json:"CompletionTokens"`
	TotalTokens      int `json:"TotalTokens"`
}

type ResponseChoices struct {
	FinishReason string  `json:"FinishReason,omitempty"` // streaming end flag; "stop" marks the final packet
	Messages     Message `json:"Message,omitempty"`      // content returned in synchronous mode, null in streaming mode; output is capped at 1024 tokens
	Delta        Message `json:"Delta,omitempty"`        // content returned in streaming mode, null in synchronous mode; output is capped at 1024 tokens
}

type ChatResponse struct {
	Choices []ResponseChoices `json:"Choices,omitempty"` // results
	Created int64             `json:"Created,omitempty"` // unix timestamp
	Id      string            `json:"Id,omitempty"`      // session id
	Usage   Usage             `json:"Usage,omitempty"`   // token counts
	Error   Error             `json:"Error,omitempty"`   // error info; may be null, meaning no valid value was obtained
	Note    string            `json:"Note,omitempty"`    // note
	ReqID   string            `json:"Req_id,omitempty"`  // unique request id, returned with every request; used when reporting request issues
}

type ChatResponseP struct {
	Response ChatResponse `json:"Response,omitempty"`
}
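
// Editor's sketch (not part of the original file): the API wraps the payload
// in a top-level "Response" envelope, which is why Handler unmarshals into
// ChatResponseP first. Assumes encoding/json is imported:
//
//	var envelope ChatResponseP
//	if err := json.Unmarshal(body, &envelope); err != nil {
//		return err
//	}
//	chatResp := envelope.Response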
10
relay/adaptor/togetherai/constants.go
Normal file
@@ -0,0 +1,10 @@
package togetherai

// https://docs.together.ai/docs/inference-models

var ModelList = []string{
	"meta-llama/Llama-3-70b-chat-hf",
	"deepseek-ai/deepseek-coder-33b-instruct",
	"mistralai/Mixtral-8x22B-Instruct-v0.1",
	"Qwen/Qwen1.5-72B-Chat",
}
Some files were not shown because too many files have changed in this diff