Merge remote-tracking branch 'origin/upstream/main'

This commit is contained in:
Laisky.Cai
2024-05-29 06:12:37 +00:00
88 changed files with 790 additions and 579 deletions

View File

@@ -8,6 +8,7 @@ import (
"net/http"
"strings"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/helper"
@@ -17,7 +18,6 @@ import (
"github.com/songquanpeng/one-api/relay/adaptor/openai"
"github.com/songquanpeng/one-api/relay/constant"
"github.com/songquanpeng/one-api/relay/model"
"github.com/gin-gonic/gin"
)
// https://ai.google.dev/docs/gemini_api_overview?hl=zh-cn
@@ -133,6 +133,29 @@ func ConvertRequest(textRequest model.GeneralOpenAIRequest) *ChatRequest {
return &geminiRequest
}
func ConvertEmbeddingRequest(request model.GeneralOpenAIRequest) *BatchEmbeddingRequest {
inputs := request.ParseInput()
requests := make([]EmbeddingRequest, len(inputs))
model := fmt.Sprintf("models/%s", request.Model)
for i, input := range inputs {
requests[i] = EmbeddingRequest{
Model: model,
Content: ChatContent{
Parts: []Part{
{
Text: input,
},
},
},
}
}
return &BatchEmbeddingRequest{
Requests: requests,
}
}
type ChatResponse struct {
Candidates []ChatCandidate `json:"candidates"`
PromptFeedback ChatPromptFeedback `json:"promptFeedback"`
@@ -229,6 +252,23 @@ func streamResponseGeminiChat2OpenAI(geminiResponse *ChatResponse) *openai.ChatC
return &response
}
func embeddingResponseGemini2OpenAI(response *EmbeddingResponse) *openai.EmbeddingResponse {
openAIEmbeddingResponse := openai.EmbeddingResponse{
Object: "list",
Data: make([]openai.EmbeddingResponseItem, 0, len(response.Embeddings)),
Model: "gemini-embedding",
Usage: model.Usage{TotalTokens: 0},
}
for _, item := range response.Embeddings {
openAIEmbeddingResponse.Data = append(openAIEmbeddingResponse.Data, openai.EmbeddingResponseItem{
Object: `embedding`,
Index: 0,
Embedding: item.Values,
})
}
return &openAIEmbeddingResponse
}
func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, string) {
responseText := ""
scanner := bufio.NewScanner(resp.Body)
@@ -336,3 +376,39 @@ func Handler(c *gin.Context, resp *http.Response, promptTokens int, modelName st
_, err = c.Writer.Write(jsonResponse)
return nil, &usage
}
func EmbeddingHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
var geminiEmbeddingResponse EmbeddingResponse
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
}
err = resp.Body.Close()
if err != nil {
return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
err = json.Unmarshal(responseBody, &geminiEmbeddingResponse)
if err != nil {
return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
}
if geminiEmbeddingResponse.Error != nil {
return &model.ErrorWithStatusCode{
Error: model.Error{
Message: geminiEmbeddingResponse.Error.Message,
Type: "gemini_error",
Param: "",
Code: geminiEmbeddingResponse.Error.Code,
},
StatusCode: resp.StatusCode,
}, nil
}
fullTextResponse := embeddingResponseGemini2OpenAI(&geminiEmbeddingResponse)
jsonResponse, err := json.Marshal(fullTextResponse)
if err != nil {
return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
}
c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode)
_, err = c.Writer.Write(jsonResponse)
return nil, &fullTextResponse.Usage
}