Merge commit '5b349efff9906d6db9f31644008959250b2c30f9'

Laisky.Cai 2024-03-27 22:09:47 +00:00
commit 36353a1d96
10 changed files with 120 additions and 30 deletions

View File

@@ -85,6 +85,9 @@ var ModelRatio = map[string]float64{
     "PaLM-2":            1,
     "gemini-pro":        1, // $0.00025 / 1k characters -> $0.001 / 1k tokens
     "gemini-pro-vision": 1, // $0.00025 / 1k characters -> $0.001 / 1k tokens
+    "gemini-1.0-pro-vision-001": 1,
+    "gemini-1.0-pro-001":        1,
+    "gemini-1.5-pro":            1,
     // https://open.bigmodel.cn/pricing
     "glm-4":  0.1 * RMB,
     "glm-4v": 0.1 * RMB,

@@ -253,6 +256,9 @@ func GetCompletionRatio(name string) float64 {
     if strings.HasPrefix(name, "mistral-") {
         return 3
     }
+    if strings.HasPrefix(name, "gemini-") {
+        return 3
+    }
     switch name {
     case "llama2-70b-4096":
         return 0.8 / 0.7
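
Note: for reference, a self-contained sketch of how the new gemini- prefix check sits in GetCompletionRatio, trimmed to the cases visible in this hunk (the package main wrapper and the trailing default of 1 are assumptions added for runnability; the real function handles many more models):

package main

import (
    "fmt"
    "strings"
)

func GetCompletionRatio(name string) float64 {
    if strings.HasPrefix(name, "mistral-") {
        return 3
    }
    // new in this commit: every gemini-* model bills completions at 3x the prompt rate
    if strings.HasPrefix(name, "gemini-") {
        return 3
    }
    switch name {
    case "llama2-70b-4096":
        return 0.8 / 0.7
    }
    return 1 // assumed fallback; not part of this diff
}

func main() {
    fmt.Println(GetCompletionRatio("gemini-1.5-pro")) // 3
    fmt.Println(GetCompletionRatio("glm-4"))          // 1
}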

View File

@@ -144,6 +144,7 @@ func AddToken(c *gin.Context) {
     c.JSON(http.StatusOK, gin.H{
         "success": true,
         "message": "",
+        "data":    cleanToken,
     })
     return
 }

View File

@@ -3,6 +3,6 @@ package gemini
 // https://ai.google.dev/models/gemini
 var ModelList = []string{
-    "gemini-pro", "gemini-1.0-pro-001",
+    "gemini-pro", "gemini-1.0-pro-001", "gemini-1.5-pro",
     "gemini-pro-vision", "gemini-1.0-pro-vision-001",
 }

View File

@@ -3,13 +3,14 @@ package ollama

 import (
     "errors"
     "fmt"
+    "io"
+    "net/http"
+
     "github.com/gin-gonic/gin"
     "github.com/songquanpeng/one-api/relay/channel"
     "github.com/songquanpeng/one-api/relay/constant"
     "github.com/songquanpeng/one-api/relay/model"
     "github.com/songquanpeng/one-api/relay/util"
-    "io"
-    "net/http"
 )

 type Adaptor struct {

@@ -22,6 +23,9 @@ func (a *Adaptor) Init(meta *util.RelayMeta) {
 func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) {
     // https://github.com/ollama/ollama/blob/main/docs/api.md
     fullRequestURL := fmt.Sprintf("%s/api/chat", meta.BaseURL)
+    if meta.Mode == constant.RelayModeEmbeddings {
+        fullRequestURL = fmt.Sprintf("%s/api/embeddings", meta.BaseURL)
+    }
     return fullRequestURL, nil
 }

@@ -37,7 +41,8 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G
     }
     switch relayMode {
     case constant.RelayModeEmbeddings:
-        return nil, errors.New("not supported")
+        ollamaEmbeddingRequest := ConvertEmbeddingRequest(*request)
+        return ollamaEmbeddingRequest, nil
     default:
         return ConvertRequest(*request), nil
     }

@@ -51,8 +56,13 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.Rel
     if meta.IsStream {
         err, usage = StreamHandler(c, resp)
     } else {
+        switch meta.Mode {
+        case constant.RelayModeEmbeddings:
+            err, usage = EmbeddingHandler(c, resp)
+        default:
             err, usage = Handler(c, resp)
+        }
     }
     return
 }
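
Note: a minimal sketch of the URL routing this adaptor change introduces, with RelayMeta stubbed down to the two fields used here (the struct shape and the constant's value are assumptions; the real definitions live in relay/util and relay/constant):

package main

import "fmt"

// stubs for the real types (assumed shapes)
type RelayMeta struct {
    Mode    int
    BaseURL string
}

const RelayModeEmbeddings = 2 // placeholder value, not the real constant

func GetRequestURL(meta *RelayMeta) (string, error) {
    // default: chat completions
    fullRequestURL := fmt.Sprintf("%s/api/chat", meta.BaseURL)
    // embeddings requests are rerouted to ollama's embeddings endpoint
    if meta.Mode == RelayModeEmbeddings {
        fullRequestURL = fmt.Sprintf("%s/api/embeddings", meta.BaseURL)
    }
    return fullRequestURL, nil
}

func main() {
    url, _ := GetRequestURL(&RelayMeta{Mode: RelayModeEmbeddings, BaseURL: "http://localhost:11434"})
    fmt.Println(url) // http://localhost:11434/api/embeddings
}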

View File

@@ -5,6 +5,10 @@ import (
     "context"
     "encoding/json"
     "fmt"
+    "io"
+    "net/http"
+    "strings"
+
     "github.com/gin-gonic/gin"
     "github.com/songquanpeng/one-api/common"
     "github.com/songquanpeng/one-api/common/helper"

@@ -12,9 +16,6 @@ import (
     "github.com/songquanpeng/one-api/relay/channel/openai"
     "github.com/songquanpeng/one-api/relay/constant"
     "github.com/songquanpeng/one-api/relay/model"
-    "io"
-    "net/http"
-    "strings"
 )

 func ConvertRequest(request model.GeneralOpenAIRequest) *ChatRequest {

@@ -139,6 +140,64 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC
     return nil, &usage
 }

+func ConvertEmbeddingRequest(request model.GeneralOpenAIRequest) *EmbeddingRequest {
+    return &EmbeddingRequest{
+        Model:  request.Model,
+        Prompt: strings.Join(request.ParseInput(), " "),
+    }
+}
+
+func EmbeddingHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
+    var ollamaResponse EmbeddingResponse
+    err := json.NewDecoder(resp.Body).Decode(&ollamaResponse)
+    if err != nil {
+        return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
+    }
+
+    err = resp.Body.Close()
+    if err != nil {
+        return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
+    }
+
+    if ollamaResponse.Error != "" {
+        return &model.ErrorWithStatusCode{
+            Error: model.Error{
+                Message: ollamaResponse.Error,
+                Type:    "ollama_error",
+                Param:   "",
+                Code:    "ollama_error",
+            },
+            StatusCode: resp.StatusCode,
+        }, nil
+    }
+
+    fullTextResponse := embeddingResponseOllama2OpenAI(&ollamaResponse)
+    jsonResponse, err := json.Marshal(fullTextResponse)
+    if err != nil {
+        return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
+    }
+    c.Writer.Header().Set("Content-Type", "application/json")
+    c.Writer.WriteHeader(resp.StatusCode)
+    _, err = c.Writer.Write(jsonResponse)
+    return nil, &fullTextResponse.Usage
+}
+
+func embeddingResponseOllama2OpenAI(response *EmbeddingResponse) *openai.EmbeddingResponse {
+    openAIEmbeddingResponse := openai.EmbeddingResponse{
+        Object: "list",
+        Data:   make([]openai.EmbeddingResponseItem, 0, 1),
+        Model:  "text-embedding-v1",
+        Usage:  model.Usage{TotalTokens: 0},
+    }
+    openAIEmbeddingResponse.Data = append(openAIEmbeddingResponse.Data, openai.EmbeddingResponseItem{
+        Object:    `embedding`,
+        Index:     0,
+        Embedding: response.Embedding,
+    })
+    return &openAIEmbeddingResponse
+}
+
 func Handler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
     ctx := context.TODO()
     var ollamaResponse ChatResponse
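
Note: one design consequence worth flagging here: ollama's /api/embeddings accepts a single prompt, so ConvertEmbeddingRequest joins all OpenAI inputs with spaces and the response carries one embedding rather than one per input. A runnable sketch under that reading, with parseInput standing in for model.GeneralOpenAIRequest.ParseInput (assumed to normalize OpenAI's string-or-array "input" field into []string):

package main

import (
    "fmt"
    "strings"
)

// parseInput stands in for request.ParseInput(): OpenAI's "input" may be a
// string or an array of strings; either way it becomes []string
func parseInput(input any) []string {
    switch v := input.(type) {
    case string:
        return []string{v}
    case []string:
        return v
    }
    return nil
}

func main() {
    // two OpenAI inputs collapse into a single ollama prompt, so the
    // upstream response will contain one embedding for the joined text
    prompt := strings.Join(parseInput([]string{"hello", "world"}), " ")
    fmt.Println(prompt) // hello world
}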

View File

@@ -35,3 +35,13 @@ type ChatResponse struct {
     EvalDuration int    `json:"eval_duration,omitempty"`
     Error        string `json:"error,omitempty"`
 }
+
+type EmbeddingRequest struct {
+    Model  string `json:"model"`
+    Prompt string `json:"prompt"`
+}
+
+type EmbeddingResponse struct {
+    Error     string    `json:"error,omitempty"`
+    Embedding []float64 `json:"embedding,omitempty"`
+}
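
Note: a quick round-trip sketch of what these two structs put on the wire against ollama's /api/embeddings endpoint (the model name nomic-embed-text and the sample vector are made up for illustration):

package main

import (
    "encoding/json"
    "fmt"
)

type EmbeddingRequest struct {
    Model  string `json:"model"`
    Prompt string `json:"prompt"`
}

type EmbeddingResponse struct {
    Error     string    `json:"error,omitempty"`
    Embedding []float64 `json:"embedding,omitempty"`
}

func main() {
    // request body sent to POST {base}/api/embeddings
    req, _ := json.Marshal(EmbeddingRequest{Model: "nomic-embed-text", Prompt: "hello world"})
    fmt.Println(string(req)) // {"model":"nomic-embed-text","prompt":"hello world"}

    // successful responses carry only the vector; failures set "error"
    var resp EmbeddingResponse
    _ = json.Unmarshal([]byte(`{"embedding":[0.1,-0.2,0.3]}`), &resp)
    fmt.Println(resp.Embedding) // [0.1 -0.2 0.3]
}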

View File

@@ -31,11 +31,8 @@ func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) {
     task := strings.TrimPrefix(requestURL, "/v1/")
     model_ := meta.ActualModelName
     model_ = strings.Replace(model_, ".", "", -1)
-    // https://github.com/songquanpeng/one-api/issues/67
-    model_ = strings.TrimSuffix(model_, "-0301")
-    model_ = strings.TrimSuffix(model_, "-0314")
-    model_ = strings.TrimSuffix(model_, "-0613")
+    //https://github.com/songquanpeng/one-api/issues/1191
+    // {your endpoint}/openai/deployments/{your azure_model}/chat/completions?api-version={api_version}
     requestURL = fmt.Sprintf("/openai/deployments/%s/%s", model_, task)
     return util.GetFullRequestURL(meta.BaseURL, requestURL, meta.ChannelType), nil
 case common.ChannelTypeMinimax:
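
Note: the practical effect of dropping the TrimSuffix calls is that version-suffixed model names now pass through to the Azure deployment path unchanged, so deployments must be named after the full model name. A minimal sketch (deploymentURL is a hypothetical helper mirroring the two kept lines above):

package main

import (
    "fmt"
    "strings"
)

// deploymentURL mirrors the dot-stripping and path formatting kept by this hunk
func deploymentURL(model, task string) string {
    model = strings.Replace(model, ".", "", -1)
    // before this commit, "-0301", "-0314" and "-0613" were trimmed here,
    // so gpt-3.5-turbo-0613 and gpt-3.5-turbo shared one deployment
    return fmt.Sprintf("/openai/deployments/%s/%s", model, task)
}

func main() {
    fmt.Println(deploymentURL("gpt-3.5-turbo-0613", "chat/completions"))
    // /openai/deployments/gpt-35-turbo-0613/chat/completions
}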

View File

@@ -85,6 +85,24 @@ func RelayAudioHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus
             return openai.ErrorWrapper(err, "pre_consume_token_quota_failed", http.StatusForbidden)
         }
     }
+    succeed := false
+    defer func() {
+        if succeed {
+            return
+        }
+        if preConsumedQuota > 0 {
+            // we need to roll back the pre-consumed quota
+            defer func(ctx context.Context) {
+                go func() {
+                    // negative means add quota back for token & user
+                    err := model.PostConsumeTokenQuota(tokenId, -preConsumedQuota)
+                    if err != nil {
+                        logger.Error(ctx, fmt.Sprintf("error rollback pre-consumed quota: %s", err.Error()))
+                    }
+                }()
+            }(c.Request.Context())
+        }
+    }()

     // map model name
     modelMapping := c.GetString("model_mapping")

@@ -195,20 +213,9 @@ func RelayAudioHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus
         resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
     }
     if resp.StatusCode != http.StatusOK {
-        if preConsumedQuota > 0 {
-            // we need to roll back the pre-consumed quota
-            defer func(ctx context.Context) {
-                go func() {
-                    // negative means add quota back for token & user
-                    err := model.PostConsumeTokenQuota(tokenId, -preConsumedQuota)
-                    if err != nil {
-                        logger.Error(ctx, fmt.Sprintf("error rollback pre-consumed quota: %s", err.Error()))
-                    }
-                }()
-            }(c.Request.Context())
-        }
         return util.RelayErrorHandler(resp)
     }
+    succeed = true
     quotaDelta := quota - preConsumedQuota
     defer func(ctx context.Context) {
         go util.PostConsumeQuota(ctx, tokenId, quotaDelta, quota, userId, channelId, modelRatio, groupRatio, audioModel, tokenName)
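
Note: this refactor replaces the inline rollback with a succeed-flag-plus-defer pattern: every early return from RelayAudioHelper now triggers the refund unless the happy path explicitly sets succeed = true, so error paths added later get rollback for free instead of each needing its own copy. A minimal control-flow sketch, with the refund reduced to a print (the real code refunds via model.PostConsumeTokenQuota in a goroutine):

package main

import "fmt"

func relay(fail bool) {
    succeed := false
    defer func() {
        if succeed {
            return
        }
        // stands in for the PostConsumeTokenQuota rollback above
        fmt.Println("rolling back pre-consumed quota")
    }()

    if fail {
        return // any early exit leaves succeed == false, so the refund fires
    }
    succeed = true // happy path: skip the rollback
}

func main() {
    relay(true)  // prints the rollback message
    relay(false) // prints nothing
}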

View File

@@ -51,7 +51,7 @@ const Register = () => {
 <Grid item xs={12}>
   <Grid item container direction="column" alignItems="center" xs={12}>
     <Typography component={Link} to="/login" variant="subtitle1" sx={{ textDecoration: 'none' }}>
-      已经有帐号了?点击登录
+      已经有帐号了?点击登录
     </Typography>
   </Grid>
 </Grid>

View File

@@ -296,7 +296,7 @@ const RegisterForm = ({ ...others }) => {
 <Box sx={{ mt: 2 }}>
   <AnimateButton>
     <Button disableElevation disabled={isSubmitting} fullWidth size="large" type="submit" variant="contained" color="primary">
-      Sign up
+      注册
     </Button>
   </AnimateButton>
 </Box>
</Box> </Box>