diff --git a/common/model-ratio.go b/common/model-ratio.go
index b418aa7d..74871880 100644
--- a/common/model-ratio.go
+++ b/common/model-ratio.go
@@ -82,9 +82,12 @@ var ModelRatio = map[string]float64{
 	"bge-large-en": 0.002 * RMB,
 	"bge-large-8k": 0.002 * RMB,
 	// https://ai.google.dev/pricing
-	"PaLM-2":            1,
-	"gemini-pro":        1, // $0.00025 / 1k characters -> $0.001 / 1k tokens
-	"gemini-pro-vision": 1, // $0.00025 / 1k characters -> $0.001 / 1k tokens
+	"PaLM-2":                    1,
+	"gemini-pro":                1, // $0.00025 / 1k characters -> $0.001 / 1k tokens
+	"gemini-pro-vision":         1, // $0.00025 / 1k characters -> $0.001 / 1k tokens
+	"gemini-1.0-pro-vision-001": 1,
+	"gemini-1.0-pro-001":        1,
+	"gemini-1.5-pro":            1,
 	// https://open.bigmodel.cn/pricing
 	"glm-4":   0.1 * RMB,
 	"glm-4v":  0.1 * RMB,
@@ -253,6 +256,9 @@ func GetCompletionRatio(name string) float64 {
 	if strings.HasPrefix(name, "mistral-") {
 		return 3
 	}
+	if strings.HasPrefix(name, "gemini-") {
+		return 3
+	}
 	switch name {
 	case "llama2-70b-4096":
 		return 0.8 / 0.7
diff --git a/controller/token.go b/controller/token.go
index ff1333ee..9b52b053 100644
--- a/controller/token.go
+++ b/controller/token.go
@@ -144,6 +144,7 @@ func AddToken(c *gin.Context) {
 	c.JSON(http.StatusOK, gin.H{
 		"success": true,
 		"message": "",
+		"data":    cleanToken,
 	})
 	return
 }
diff --git a/relay/channel/gemini/constants.go b/relay/channel/gemini/constants.go
index e8d3a155..32e7c240 100644
--- a/relay/channel/gemini/constants.go
+++ b/relay/channel/gemini/constants.go
@@ -3,6 +3,6 @@ package gemini
 // https://ai.google.dev/models/gemini
 var ModelList = []string{
-	"gemini-pro", "gemini-1.0-pro-001",
+	"gemini-pro", "gemini-1.0-pro-001", "gemini-1.5-pro",
 	"gemini-pro-vision", "gemini-1.0-pro-vision-001",
 }
diff --git a/relay/channel/ollama/adaptor.go b/relay/channel/ollama/adaptor.go
index 06c66101..e2ae7d2b 100644
--- a/relay/channel/ollama/adaptor.go
+++ b/relay/channel/ollama/adaptor.go
@@ -3,13 +3,14 @@ package ollama
 import (
 	"errors"
 	"fmt"
+	"io"
+	"net/http"
+
 	"github.com/gin-gonic/gin"
 	"github.com/songquanpeng/one-api/relay/channel"
 	"github.com/songquanpeng/one-api/relay/constant"
 	"github.com/songquanpeng/one-api/relay/model"
 	"github.com/songquanpeng/one-api/relay/util"
-	"io"
-	"net/http"
 )
 
 type Adaptor struct {
@@ -22,6 +23,9 @@ func (a *Adaptor) Init(meta *util.RelayMeta) {
 func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) {
 	// https://github.com/ollama/ollama/blob/main/docs/api.md
 	fullRequestURL := fmt.Sprintf("%s/api/chat", meta.BaseURL)
+	if meta.Mode == constant.RelayModeEmbeddings {
+		fullRequestURL = fmt.Sprintf("%s/api/embeddings", meta.BaseURL)
+	}
 	return fullRequestURL, nil
 }
 
@@ -37,7 +41,8 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G
 	}
 	switch relayMode {
 	case constant.RelayModeEmbeddings:
-		return nil, errors.New("not supported")
+		ollamaEmbeddingRequest := ConvertEmbeddingRequest(*request)
+		return ollamaEmbeddingRequest, nil
 	default:
 		return ConvertRequest(*request), nil
 	}
@@ -51,7 +56,12 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.Rel
 	if meta.IsStream {
 		err, usage = StreamHandler(c, resp)
 	} else {
-		err, usage = Handler(c, resp)
+		switch meta.Mode {
+		case constant.RelayModeEmbeddings:
+			err, usage = EmbeddingHandler(c, resp)
+		default:
+			err, usage = Handler(c, resp)
+		}
 	}
 	return
 }
diff --git a/relay/channel/ollama/main.go b/relay/channel/ollama/main.go
index 7ec646a3..821a335b 100644
--- a/relay/channel/ollama/main.go
+++ b/relay/channel/ollama/main.go
@@ -5,6 +5,10 @@ import (
 	"context"
 	"encoding/json"
 	"fmt"
+	"io"
+	"net/http"
+	"strings"
+
 	"github.com/gin-gonic/gin"
 	"github.com/songquanpeng/one-api/common"
 	"github.com/songquanpeng/one-api/common/helper"
@@ -12,9 +16,6 @@ import (
 	"github.com/songquanpeng/one-api/relay/channel/openai"
 	"github.com/songquanpeng/one-api/relay/constant"
 	"github.com/songquanpeng/one-api/relay/model"
-	"io"
-	"net/http"
-	"strings"
 )
 
 func ConvertRequest(request model.GeneralOpenAIRequest) *ChatRequest {
@@ -139,6 +140,64 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC
 	return nil, &usage
 }
 
+func ConvertEmbeddingRequest(request model.GeneralOpenAIRequest) *EmbeddingRequest {
+	return &EmbeddingRequest{
+		Model:  request.Model,
+		Prompt: strings.Join(request.ParseInput(), " "),
+	}
+}
+
+func EmbeddingHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
+	var ollamaResponse EmbeddingResponse
+	err := json.NewDecoder(resp.Body).Decode(&ollamaResponse)
+	if err != nil {
+		return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
+	}
+
+	err = resp.Body.Close()
+	if err != nil {
+		return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
+	}
+
+	if ollamaResponse.Error != "" {
+		return &model.ErrorWithStatusCode{
+			Error: model.Error{
+				Message: ollamaResponse.Error,
+				Type:    "ollama_error",
+				Param:   "",
+				Code:    "ollama_error",
+			},
+			StatusCode: resp.StatusCode,
+		}, nil
+	}
+
+	fullTextResponse := embeddingResponseOllama2OpenAI(&ollamaResponse)
+	jsonResponse, err := json.Marshal(fullTextResponse)
+	if err != nil {
+		return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
+	}
+	c.Writer.Header().Set("Content-Type", "application/json")
+	c.Writer.WriteHeader(resp.StatusCode)
+	_, err = c.Writer.Write(jsonResponse)
+	return nil, &fullTextResponse.Usage
+}
+
+func embeddingResponseOllama2OpenAI(response *EmbeddingResponse) *openai.EmbeddingResponse {
+	openAIEmbeddingResponse := openai.EmbeddingResponse{
+		Object: "list",
+		Data:   make([]openai.EmbeddingResponseItem, 0, 1),
+		Model:  "text-embedding-v1",
+		Usage:  model.Usage{TotalTokens: 0},
+	}
+
+	openAIEmbeddingResponse.Data = append(openAIEmbeddingResponse.Data, openai.EmbeddingResponseItem{
+		Object:    `embedding`,
+		Index:     0,
+		Embedding: response.Embedding,
+	})
+	return &openAIEmbeddingResponse
+}
+
 func Handler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
 	ctx := context.TODO()
 	var ollamaResponse ChatResponse
diff --git a/relay/channel/ollama/model.go b/relay/channel/ollama/model.go
index a8ef1ffc..8baf56a0 100644
--- a/relay/channel/ollama/model.go
+++ b/relay/channel/ollama/model.go
@@ -35,3 +35,13 @@ type ChatResponse struct {
 	EvalDuration int    `json:"eval_duration,omitempty"`
 	Error        string `json:"error,omitempty"`
 }
+
+type EmbeddingRequest struct {
+	Model  string `json:"model"`
+	Prompt string `json:"prompt"`
+}
+
+type EmbeddingResponse struct {
+	Error     string    `json:"error,omitempty"`
+	Embedding []float64 `json:"embedding,omitempty"`
+}
diff --git a/relay/channel/openai/adaptor.go b/relay/channel/openai/adaptor.go
index 5b7c639d..8fd4db2c 100644
--- a/relay/channel/openai/adaptor.go
+++ b/relay/channel/openai/adaptor.go
@@ -31,11 +31,8 @@ func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) {
 		task := strings.TrimPrefix(requestURL, "/v1/")
 		model_ := meta.ActualModelName
 		model_ = strings.Replace(model_, ".", "", -1)
-		// https://github.com/songquanpeng/one-api/issues/67
-		model_ = strings.TrimSuffix(model_, "-0301")
-		model_ = strings.TrimSuffix(model_, "-0314")
-		model_ = strings.TrimSuffix(model_, "-0613")
-
+		// https://github.com/songquanpeng/one-api/issues/1191
+		// {your endpoint}/openai/deployments/{your azure_model}/chat/completions?api-version={api_version}
 		requestURL = fmt.Sprintf("/openai/deployments/%s/%s", model_, task)
 		return util.GetFullRequestURL(meta.BaseURL, requestURL, meta.ChannelType), nil
 	case common.ChannelTypeMinimax:
diff --git a/relay/controller/audio.go b/relay/controller/audio.go
index 2e97d62b..85599b1f 100644
--- a/relay/controller/audio.go
+++ b/relay/controller/audio.go
@@ -85,6 +85,24 @@ func RelayAudioHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus
 			return openai.ErrorWrapper(err, "pre_consume_token_quota_failed", http.StatusForbidden)
 		}
 	}
+	succeed := false
+	defer func() {
+		if succeed {
+			return
+		}
+		if preConsumedQuota > 0 {
+			// we need to roll back the pre-consumed quota
+			defer func(ctx context.Context) {
+				go func() {
+					// negative means add quota back for token & user
+					err := model.PostConsumeTokenQuota(tokenId, -preConsumedQuota)
+					if err != nil {
+						logger.Error(ctx, fmt.Sprintf("error rollback pre-consumed quota: %s", err.Error()))
+					}
+				}()
+			}(c.Request.Context())
+		}
+	}()
 
 	// map model name
 	modelMapping := c.GetString("model_mapping")
@@ -195,20 +213,9 @@ func RelayAudioHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus
 		resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
 	}
 	if resp.StatusCode != http.StatusOK {
-		if preConsumedQuota > 0 {
-			// we need to roll back the pre-consumed quota
-			defer func(ctx context.Context) {
-				go func() {
-					// negative means add quota back for token & user
-					err := model.PostConsumeTokenQuota(tokenId, -preConsumedQuota)
-					if err != nil {
-						logger.Error(ctx, fmt.Sprintf("error rollback pre-consumed quota: %s", err.Error()))
-					}
-				}()
-			}(c.Request.Context())
-		}
 		return util.RelayErrorHandler(resp)
 	}
+	succeed = true
 	quotaDelta := quota - preConsumedQuota
 	defer func(ctx context.Context) {
 		go util.PostConsumeQuota(ctx, tokenId, quotaDelta, quota, userId, channelId, modelRatio, groupRatio, audioModel, tokenName)
diff --git a/web/berry/src/views/Authentication/Auth/Register.js b/web/berry/src/views/Authentication/Auth/Register.js
index 4489e560..8027649d 100644
--- a/web/berry/src/views/Authentication/Auth/Register.js
+++ b/web/berry/src/views/Authentication/Auth/Register.js
@@ -51,7 +51,7 @@ const Register = () => {
-                  已经有帐号了?点击登录
+                  已经有帐号了?点击登录
diff --git a/web/berry/src/views/Authentication/AuthForms/AuthRegister.js b/web/berry/src/views/Authentication/AuthForms/AuthRegister.js
index c286faad..8d588696 100644
--- a/web/berry/src/views/Authentication/AuthForms/AuthRegister.js
+++ b/web/berry/src/views/Authentication/AuthForms/AuthRegister.js
@@ -296,7 +296,7 @@ const RegisterForm = ({ ...others }) => {
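A note on the new Ollama embedding path: it boils down to a prompt-in, vector-out round trip against Ollama's `/api/embeddings` endpoint, which `embeddingResponseOllama2OpenAI` then reshapes into an OpenAI-style single-item list. A minimal standalone sketch of that round trip, assuming a local Ollama server on the default port; the `nomic-embed-text` model name and the local struct names are illustrative assumptions, not taken from this diff:

```go
package main

import (
	"bytes"
	"encoding/json"
	"fmt"
	"net/http"
)

// Local mirrors of the EmbeddingRequest/EmbeddingResponse shapes added
// in relay/channel/ollama/model.go above.
type ollamaEmbeddingRequest struct {
	Model  string `json:"model"`
	Prompt string `json:"prompt"`
}

type ollamaEmbeddingResponse struct {
	Error     string    `json:"error,omitempty"`
	Embedding []float64 `json:"embedding,omitempty"`
}

func main() {
	// Multiple OpenAI-style inputs are joined into one prompt, as
	// ConvertEmbeddingRequest does with strings.Join(request.ParseInput(), " ").
	reqBody, _ := json.Marshal(ollamaEmbeddingRequest{
		Model:  "nomic-embed-text", // illustrative model name
		Prompt: "hello world",
	})
	resp, err := http.Post("http://localhost:11434/api/embeddings",
		"application/json", bytes.NewReader(reqBody))
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()

	var out ollamaEmbeddingResponse
	if err := json.NewDecoder(resp.Body).Decode(&out); err != nil {
		panic(err)
	}
	if out.Error != "" {
		panic(out.Error)
	}
	// The relay wraps this single vector in an OpenAI-style "list" response.
	fmt.Printf("got %d-dimensional embedding\n", len(out.Embedding))
}
```

One caveat visible in the sketch: Ollama's embeddings API takes a single prompt and returns a single vector, which is why the adaptor joins multi-input requests into one string rather than returning one embedding per input.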
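On the openai/adaptor.go hunk: the dot-stripping survives the removal of the `-0301`/`-0314`/`-0613` suffix trimming because Azure deployment names typically cannot contain dots, so `gpt-3.5-turbo` maps to a `gpt-35-turbo` deployment. A quick illustration of the URL shape the new comment documents (the `api-version` value here is an assumption for the example, not part of the diff):

```go
package main

import (
	"fmt"
	"strings"
)

func main() {
	// Mirrors GetRequestURL above: dots are stripped from the model name
	// to match Azure deployment naming, e.g. "gpt-3.5-turbo" -> "gpt-35-turbo".
	model := strings.Replace("gpt-3.5-turbo", ".", "", -1)
	task := "chat/completions"
	requestURL := fmt.Sprintf("/openai/deployments/%s/%s", model, task)
	fmt.Println(requestURL + "?api-version=2024-02-01")
	// Output: /openai/deployments/gpt-35-turbo/chat/completions?api-version=2024-02-01
}
```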
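The audio.go restructuring moves the quota rollback out of the single non-OK branch and into a deferred closure guarded by a `succeed` flag, so every return path before `succeed = true` now refunds the pre-consumed quota, not just the one that previously carried the rollback code. A stripped-down sketch of the pattern, with a hypothetical `refund` standing in for `model.PostConsumeTokenQuota`:

```go
package main

import "fmt"

// refund is a stand-in for the real quota-rollback call.
func refund(quota int64) { fmt.Printf("refunding %d\n", quota) }

func relay(preConsumed int64, fail bool) error {
	succeed := false
	defer func() {
		// Runs on every return path; refunds unless the flag was flipped.
		if !succeed && preConsumed > 0 {
			refund(preConsumed)
		}
	}()

	if fail {
		return fmt.Errorf("upstream error") // rollback fires via the defer
	}
	succeed = true // past the failure window; keep the quota
	return nil
}

func main() {
	_ = relay(100, true)  // prints "refunding 100"
	_ = relay(100, false) // no refund
}
```

The design point is that the closure captures `succeed` by reference, so a single flag assignment after the success check disarms the rollback for the rest of the function.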