diff --git a/common/model-ratio.go b/common/model-ratio.go
index b418aa7d..74871880 100644
--- a/common/model-ratio.go
+++ b/common/model-ratio.go
@@ -82,9 +82,12 @@ var ModelRatio = map[string]float64{
"bge-large-en": 0.002 * RMB,
"bge-large-8k": 0.002 * RMB,
// https://ai.google.dev/pricing
- "PaLM-2": 1,
- "gemini-pro": 1, // $0.00025 / 1k characters -> $0.001 / 1k tokens
- "gemini-pro-vision": 1, // $0.00025 / 1k characters -> $0.001 / 1k tokens
+ "PaLM-2": 1,
+ "gemini-pro": 1, // $0.00025 / 1k characters -> $0.001 / 1k tokens
+ "gemini-pro-vision": 1, // $0.00025 / 1k characters -> $0.001 / 1k tokens
+ "gemini-1.0-pro-vision-001": 1,
+ "gemini-1.0-pro-001": 1,
+ "gemini-1.5-pro": 1,
// https://open.bigmodel.cn/pricing
"glm-4": 0.1 * RMB,
"glm-4v": 0.1 * RMB,
@@ -253,6 +256,9 @@ func GetCompletionRatio(name string) float64 {
if strings.HasPrefix(name, "mistral-") {
return 3
}
+ if strings.HasPrefix(name, "gemini-") {
+ return 3
+ }
switch name {
case "llama2-70b-4096":
return 0.8 / 0.7
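
Note: the new "gemini-" branch works together with the ModelRatio entries above: the
model ratio prices prompt tokens, and the completion ratio up-weights output tokens.
A minimal sketch of how the two combine, assuming the usual one-api quota formula
(the names below are simplified stand-ins, not the actual helpers):

    package main

    import "fmt"

    // Simplified stand-ins for common.ModelRatio and common.GetCompletionRatio.
    var modelRatio = map[string]float64{"gemini-1.5-pro": 1}

    func completionRatio(name string) float64 {
    	return 3 // mirrors the new strings.HasPrefix(name, "gemini-") branch
    }

    func main() {
    	prompt, completion := 1000.0, 500.0
    	ratio := modelRatio["gemini-1.5-pro"]
    	// assumed formula: completion tokens are weighted by the completion ratio
    	quota := (prompt + completion*completionRatio("gemini-1.5-pro")) * ratio
    	fmt.Println(quota) // 2500
    }
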
diff --git a/controller/token.go b/controller/token.go
index ff1333ee..9b52b053 100644
--- a/controller/token.go
+++ b/controller/token.go
@@ -144,6 +144,7 @@ func AddToken(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
+ "data": cleanToken,
})
return
}
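
Note: with the added "data" field, a successful AddToken call now returns the created
token to the caller instead of a bare OK. A sketch of decoding the new shape on the
client side; the contents of "data" come from cleanToken, and the example payload is
illustrative, not the exact model.Token fields:

    package main

    import (
    	"encoding/json"
    	"fmt"
    )

    type addTokenResponse struct {
    	Success bool            `json:"success"`
    	Message string          `json:"message"`
    	Data    json.RawMessage `json:"data"` // the created token object
    }

    func main() {
    	body := []byte(`{"success":true,"message":"","data":{"name":"demo"}}`)
    	var r addTokenResponse
    	if err := json.Unmarshal(body, &r); err != nil {
    		panic(err)
    	}
    	fmt.Println(r.Success, string(r.Data)) // true {"name":"demo"}
    }
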
diff --git a/relay/channel/gemini/constants.go b/relay/channel/gemini/constants.go
index e8d3a155..32e7c240 100644
--- a/relay/channel/gemini/constants.go
+++ b/relay/channel/gemini/constants.go
@@ -3,6 +3,6 @@ package gemini
// https://ai.google.dev/models/gemini
var ModelList = []string{
- "gemini-pro", "gemini-1.0-pro-001",
+ "gemini-pro", "gemini-1.0-pro-001", "gemini-1.5-pro",
"gemini-pro-vision", "gemini-1.0-pro-vision-001",
}
diff --git a/relay/channel/ollama/adaptor.go b/relay/channel/ollama/adaptor.go
index 06c66101..e2ae7d2b 100644
--- a/relay/channel/ollama/adaptor.go
+++ b/relay/channel/ollama/adaptor.go
@@ -3,13 +3,14 @@ package ollama
import (
"errors"
"fmt"
+ "io"
+ "net/http"
+
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/relay/channel"
"github.com/songquanpeng/one-api/relay/constant"
"github.com/songquanpeng/one-api/relay/model"
"github.com/songquanpeng/one-api/relay/util"
- "io"
- "net/http"
)
type Adaptor struct {
@@ -22,6 +23,9 @@ func (a *Adaptor) Init(meta *util.RelayMeta) {
func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) {
// https://github.com/ollama/ollama/blob/main/docs/api.md
fullRequestURL := fmt.Sprintf("%s/api/chat", meta.BaseURL)
+ if meta.Mode == constant.RelayModeEmbeddings {
+ fullRequestURL = fmt.Sprintf("%s/api/embeddings", meta.BaseURL)
+ }
return fullRequestURL, nil
}
@@ -37,7 +41,8 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G
}
switch relayMode {
case constant.RelayModeEmbeddings:
- return nil, errors.New("not supported")
+ ollamaEmbeddingRequest := ConvertEmbeddingRequest(*request)
+ return ollamaEmbeddingRequest, nil
default:
return ConvertRequest(*request), nil
}
@@ -51,7 +56,12 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.Rel
if meta.IsStream {
err, usage = StreamHandler(c, resp)
} else {
- err, usage = Handler(c, resp)
+ switch meta.Mode {
+ case constant.RelayModeEmbeddings:
+ err, usage = EmbeddingHandler(c, resp)
+ default:
+ err, usage = Handler(c, resp)
+ }
}
return
}
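
Note: the adaptor now routes by relay mode: chat completions keep hitting Ollama's
/api/chat while embeddings go to /api/embeddings, and DoResponse picks the matching
handler. A reduced sketch of the URL selection (RelayMeta trimmed to the two fields
used here; the mode constant's value is illustrative):

    package main

    import "fmt"

    const relayModeEmbeddings = 2 // illustrative; the real constant lives in relay/constant

    type relayMeta struct {
    	BaseURL string
    	Mode    int
    }

    func requestURL(meta relayMeta) string {
    	if meta.Mode == relayModeEmbeddings {
    		return fmt.Sprintf("%s/api/embeddings", meta.BaseURL)
    	}
    	return fmt.Sprintf("%s/api/chat", meta.BaseURL)
    }

    func main() {
    	meta := relayMeta{BaseURL: "http://localhost:11434", Mode: relayModeEmbeddings}
    	fmt.Println(requestURL(meta)) // http://localhost:11434/api/embeddings
    }
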
diff --git a/relay/channel/ollama/main.go b/relay/channel/ollama/main.go
index 7ec646a3..821a335b 100644
--- a/relay/channel/ollama/main.go
+++ b/relay/channel/ollama/main.go
@@ -5,6 +5,10 @@ import (
"context"
"encoding/json"
"fmt"
+ "io"
+ "net/http"
+ "strings"
+
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/helper"
@@ -12,9 +16,6 @@ import (
"github.com/songquanpeng/one-api/relay/channel/openai"
"github.com/songquanpeng/one-api/relay/constant"
"github.com/songquanpeng/one-api/relay/model"
- "io"
- "net/http"
- "strings"
)
func ConvertRequest(request model.GeneralOpenAIRequest) *ChatRequest {
@@ -139,6 +140,67 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC
return nil, &usage
}
+func ConvertEmbeddingRequest(request model.GeneralOpenAIRequest) *EmbeddingRequest {
+ return &EmbeddingRequest{
+ Model: request.Model,
+ Prompt: strings.Join(request.ParseInput(), " "),
+ }
+}
+
+func EmbeddingHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
+ var ollamaResponse EmbeddingResponse
+ err := json.NewDecoder(resp.Body).Decode(&ollamaResponse)
+ if err != nil {
+ return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
+ }
+
+ err = resp.Body.Close()
+ if err != nil {
+ return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
+ }
+
+ if ollamaResponse.Error != "" {
+ return &model.ErrorWithStatusCode{
+ Error: model.Error{
+ Message: ollamaResponse.Error,
+ Type: "ollama_error",
+ Param: "",
+ Code: "ollama_error",
+ },
+ StatusCode: resp.StatusCode,
+ }, nil
+ }
+
+ fullTextResponse := embeddingResponseOllama2OpenAI(&ollamaResponse)
+ jsonResponse, err := json.Marshal(fullTextResponse)
+ if err != nil {
+ return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
+ }
+ c.Writer.Header().Set("Content-Type", "application/json")
+ c.Writer.WriteHeader(resp.StatusCode)
+ _, err = c.Writer.Write(jsonResponse)
+ if err != nil {
+ return openai.ErrorWrapper(err, "write_response_body_failed", http.StatusInternalServerError), nil
+ }
+ return nil, &fullTextResponse.Usage
+}
+
+func embeddingResponseOllama2OpenAI(response *EmbeddingResponse) *openai.EmbeddingResponse {
+ openAIEmbeddingResponse := openai.EmbeddingResponse{
+ Object: "list",
+ Data: make([]openai.EmbeddingResponseItem, 0, 1),
+ Model: "text-embedding-v1",
+ Usage: model.Usage{TotalTokens: 0},
+ }
+
+ openAIEmbeddingResponse.Data = append(openAIEmbeddingResponse.Data, openai.EmbeddingResponseItem{
+ Object: "embedding",
+ Index: 0,
+ Embedding: response.Embedding,
+ })
+ return &openAIEmbeddingResponse
+}
+
func Handler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
ctx := context.TODO()
var ollamaResponse ChatResponse
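
Note: one behavior worth flagging in ConvertEmbeddingRequest: Ollama's /api/embeddings
takes a single prompt string, so a multi-input OpenAI request is flattened with a space
join, yielding one embedding for the concatenation rather than one vector per input.
A sketch of the mapping, assuming ParseInput returns the request inputs as a []string
(model name illustrative):

    package main

    import (
    	"fmt"
    	"strings"
    )

    func main() {
    	inputs := []string{"hello", "world"} // stand-in for request.ParseInput()
    	prompt := strings.Join(inputs, " ")
    	// the resulting Ollama request body
    	fmt.Printf(`{"model":"nomic-embed-text","prompt":%q}`+"\n", prompt)
    	// {"model":"nomic-embed-text","prompt":"hello world"}
    }
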
diff --git a/relay/channel/ollama/model.go b/relay/channel/ollama/model.go
index a8ef1ffc..8baf56a0 100644
--- a/relay/channel/ollama/model.go
+++ b/relay/channel/ollama/model.go
@@ -35,3 +35,13 @@ type ChatResponse struct {
EvalDuration int `json:"eval_duration,omitempty"`
Error string `json:"error,omitempty"`
}
+
+type EmbeddingRequest struct {
+ Model string `json:"model"`
+ Prompt string `json:"prompt"`
+}
+
+type EmbeddingResponse struct {
+ Error string `json:"error,omitempty"`
+ Embedding []float64 `json:"embedding,omitempty"`
+}
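
Note: these two structs mirror Ollama's embeddings wire format: the request carries a
model and a prompt, and the response holds either a flat vector or an error string,
which EmbeddingHandler above surfaces as an "ollama_error". A sketch of both decode
paths, with illustrative payloads:

    package main

    import (
    	"encoding/json"
    	"fmt"
    )

    type embeddingResponse struct {
    	Error     string    `json:"error,omitempty"`
    	Embedding []float64 `json:"embedding,omitempty"`
    }

    func main() {
    	for _, body := range []string{
    		`{"embedding":[0.5,-0.2,0.1]}`,    // success: vector only
    		`{"error":"model 'x' not found"}`, // failure: error string only
    	} {
    		var r embeddingResponse
    		if err := json.Unmarshal([]byte(body), &r); err != nil {
    			panic(err)
    		}
    		fmt.Println(len(r.Embedding), r.Error)
    	}
    }
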
diff --git a/relay/channel/openai/adaptor.go b/relay/channel/openai/adaptor.go
index 5b7c639d..8fd4db2c 100644
--- a/relay/channel/openai/adaptor.go
+++ b/relay/channel/openai/adaptor.go
@@ -31,11 +31,8 @@ func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) {
task := strings.TrimPrefix(requestURL, "/v1/")
model_ := meta.ActualModelName
model_ = strings.Replace(model_, ".", "", -1)
- // https://github.com/songquanpeng/one-api/issues/67
- model_ = strings.TrimSuffix(model_, "-0301")
- model_ = strings.TrimSuffix(model_, "-0314")
- model_ = strings.TrimSuffix(model_, "-0613")
-
+ // https://github.com/songquanpeng/one-api/issues/1191
+ // {your endpoint}/openai/deployments/{your azure_model}/chat/completions?api-version={api_version}
requestURL = fmt.Sprintf("/openai/deployments/%s/%s", model_, task)
return util.GetFullRequestURL(meta.BaseURL, requestURL, meta.ChannelType), nil
case common.ChannelTypeMinimax:
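
Note: concretely, the Azure branch stops trimming version suffixes, so the deployment
segment is now the requested model name with only the dots removed; per issue #1191,
Azure deployment names can carry the dated suffix. A sketch of the resulting path
(the api-version value is a placeholder):

    package main

    import (
    	"fmt"
    	"strings"
    )

    func main() {
    	model := "gpt-3.5-turbo-0613"
    	deployment := strings.Replace(model, ".", "", -1) // "gpt-35-turbo-0613": dots gone, "-0613" kept
    	fmt.Printf("/openai/deployments/%s/chat/completions?api-version=2024-02-01\n", deployment)
    }
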
diff --git a/relay/controller/audio.go b/relay/controller/audio.go
index 2e97d62b..85599b1f 100644
--- a/relay/controller/audio.go
+++ b/relay/controller/audio.go
@@ -85,6 +85,23 @@ func RelayAudioHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus
return openai.ErrorWrapper(err, "pre_consume_token_quota_failed", http.StatusForbidden)
}
}
+ succeed := false
+ defer func() {
+ if succeed {
+ return
+ }
+ if preConsumedQuota > 0 {
+ // we need to roll back the pre-consumed quota
+ ctx := c.Request.Context()
+ go func() {
+ // negative means add quota back for token & user
+ err := model.PostConsumeTokenQuota(tokenId, -preConsumedQuota)
+ if err != nil {
+ logger.Error(ctx, fmt.Sprintf("error rollback pre-consumed quota: %s", err.Error()))
+ }
+ }()
+ }
+ }()
// map model name
modelMapping := c.GetString("model_mapping")
@@ -195,20 +213,9 @@ func RelayAudioHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus
resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
}
if resp.StatusCode != http.StatusOK {
- if preConsumedQuota > 0 {
- // we need to roll back the pre-consumed quota
- defer func(ctx context.Context) {
- go func() {
- // negative means add quota back for token & user
- err := model.PostConsumeTokenQuota(tokenId, -preConsumedQuota)
- if err != nil {
- logger.Error(ctx, fmt.Sprintf("error rollback pre-consumed quota: %s", err.Error()))
- }
- }()
- }(c.Request.Context())
- }
return util.RelayErrorHandler(resp)
}
+ succeed = true
quotaDelta := quota - preConsumedQuota
defer func(ctx context.Context) {
go util.PostConsumeQuota(ctx, tokenId, quotaDelta, quota, userId, channelId, modelRatio, groupRatio, audioModel, tokenName)
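
Note: the audio refactor hoists the refund into a single deferred guard, so any error
return after pre-consumption (model mapping, upstream failure, non-200) rolls the
quota back, and succeed = true disarms the guard on the happy path. The shape of the
pattern, reduced to a self-contained sketch:

    package main

    import (
    	"errors"
    	"fmt"
    )

    func relay(fail bool) error {
    	preConsumed := 100
    	succeed := false
    	defer func() {
    		if succeed || preConsumed <= 0 {
    			return
    		}
    		// in the real handler this is PostConsumeTokenQuota(tokenId, -preConsumedQuota)
    		fmt.Println("refunding", preConsumed)
    	}()

    	if fail {
    		return errors.New("upstream error") // guard fires on this path
    	}
    	succeed = true // disarm the guard
    	return nil
    }

    func main() {
    	_ = relay(true)  // prints: refunding 100
    	_ = relay(false) // prints nothing
    }
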
diff --git a/web/berry/src/views/Authentication/Auth/Register.js b/web/berry/src/views/Authentication/Auth/Register.js
index 4489e560..8027649d 100644
--- a/web/berry/src/views/Authentication/Auth/Register.js
+++ b/web/berry/src/views/Authentication/Auth/Register.js
@@ -51,7 +51,7 @@ const Register = () => {
- 已经有帐号了?点击登录
+ 已经有帐号了?点击登录
diff --git a/web/berry/src/views/Authentication/AuthForms/AuthRegister.js b/web/berry/src/views/Authentication/AuthForms/AuthRegister.js
index c286faad..8d588696 100644
--- a/web/berry/src/views/Authentication/AuthForms/AuthRegister.js
+++ b/web/berry/src/views/Authentication/AuthForms/AuthRegister.js
@@ -296,7 +296,7 @@ const RegisterForm = ({ ...others }) => {