merge upstream

Signed-off-by: wozulong <>
This commit is contained in:
wozulong 2024-03-27 18:01:43 +08:00
commit 923c2dee32
18 changed files with 172 additions and 43 deletions

View File

@ -4,7 +4,6 @@ on:
push: push:
tags: tags:
- '*' - '*'
- '!*-alpha*'
workflow_dispatch: workflow_dispatch:
inputs: inputs:
name: name:

View File

@ -4,7 +4,6 @@ on:
push: push:
tags: tags:
- '*' - '*'
- '!*-alpha*'
workflow_dispatch: workflow_dispatch:
inputs: inputs:
name: name:

View File

@ -55,6 +55,7 @@ var WeChatAuthEnabled = false
var TelegramOAuthEnabled = false var TelegramOAuthEnabled = false
var TurnstileCheckEnabled = false var TurnstileCheckEnabled = false
var RegisterEnabled = true var RegisterEnabled = true
var UserSelfDeletionEnabled = false
var EmailDomainRestrictionEnabled = false var EmailDomainRestrictionEnabled = false
var EmailDomainWhitelist = []string{ var EmailDomainWhitelist = []string{

View File

@ -71,8 +71,11 @@ var DefaultModelRatio = map[string]float64{
"ERNIE-Bot-4": 8.572, // ¥0.12 / 1k tokens "ERNIE-Bot-4": 8.572, // ¥0.12 / 1k tokens
"Embedding-V1": 0.1429, // ¥0.002 / 1k tokens "Embedding-V1": 0.1429, // ¥0.002 / 1k tokens
"PaLM-2": 1, "PaLM-2": 1,
"gemini-pro": 1, // $0.00025 / 1k characters -> $0.001 / 1k tokens "gemini-pro": 1, // $0.00025 / 1k characters -> $0.001 / 1k tokens
"gemini-pro-vision": 1, // $0.00025 / 1k characters -> $0.001 / 1k tokens "gemini-pro-vision": 1, // $0.00025 / 1k characters -> $0.001 / 1k tokens
"gemini-1.0-pro-vision-001": 1,
"gemini-1.0-pro-001": 1,
"gemini-1.5-pro": 1,
"chatglm_turbo": 0.3572, // ¥0.005 / 1k tokens "chatglm_turbo": 0.3572, // ¥0.005 / 1k tokens
"chatglm_pro": 0.7143, // ¥0.01 / 1k tokens "chatglm_pro": 0.7143, // ¥0.01 / 1k tokens
"chatglm_std": 0.3572, // ¥0.005 / 1k tokens "chatglm_std": 0.3572, // ¥0.005 / 1k tokens
@ -212,5 +215,15 @@ func GetCompletionRatio(name string) float64 {
} else if strings.HasPrefix(name, "claude-3") { } else if strings.HasPrefix(name, "claude-3") {
return 5 return 5
} }
if strings.HasPrefix(name, "mistral-") {
return 3
}
if strings.HasPrefix(name, "gemini-") {
return 3
}
switch name {
case "llama2-70b-4096":
return 0.8 / 0.7
}
return 1 return 1
} }

View File

@ -18,9 +18,8 @@ func InitRedisClient() (err error) {
return nil return nil
} }
if os.Getenv("SYNC_FREQUENCY") == "" { if os.Getenv("SYNC_FREQUENCY") == "" {
RedisEnabled = false SysLog("SYNC_FREQUENCY not set, use default value 60")
SysLog("SYNC_FREQUENCY not set, Redis is disabled") SyncFrequency = 60
return nil
} }
SysLog("Redis is enabled") SysLog("Redis is enabled")
opt, err := redis.ParseURL(os.Getenv("REDIS_CONN_STRING")) opt, err := redis.ParseURL(os.Getenv("REDIS_CONN_STRING"))

View File

@ -108,6 +108,7 @@ func buildTestRequest() *dto.GeneralOpenAIRequest {
testRequest := &dto.GeneralOpenAIRequest{ testRequest := &dto.GeneralOpenAIRequest{
Model: "", // this will be set later Model: "", // this will be set later
MaxTokens: 1, MaxTokens: 1,
Stream: false,
} }
content, _ := json.Marshal("hi") content, _ := json.Marshal("hi")
testMessage := dto.Message{ testMessage := dto.Message{

View File

@ -558,6 +558,14 @@ func HardDeleteUser(c *gin.Context) {
} }
func DeleteSelf(c *gin.Context) { func DeleteSelf(c *gin.Context) {
if !common.UserSelfDeletionEnabled {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "当前设置不允许用户自我删除账号",
})
return
}
id := c.GetInt("id") id := c.GetInt("id")
user, _ := model.GetUserById(id, false) user, _ := model.GetUserById(id, false)

View File

@ -36,6 +36,7 @@ func InitOptionMap() {
common.OptionMap["WeChatAuthEnabled"] = strconv.FormatBool(common.WeChatAuthEnabled) common.OptionMap["WeChatAuthEnabled"] = strconv.FormatBool(common.WeChatAuthEnabled)
common.OptionMap["TurnstileCheckEnabled"] = strconv.FormatBool(common.TurnstileCheckEnabled) common.OptionMap["TurnstileCheckEnabled"] = strconv.FormatBool(common.TurnstileCheckEnabled)
common.OptionMap["RegisterEnabled"] = strconv.FormatBool(common.RegisterEnabled) common.OptionMap["RegisterEnabled"] = strconv.FormatBool(common.RegisterEnabled)
common.OptionMap["UserSelfDeletionEnabled"] = strconv.FormatBool(common.UserSelfDeletionEnabled)
common.OptionMap["AutomaticDisableChannelEnabled"] = strconv.FormatBool(common.AutomaticDisableChannelEnabled) common.OptionMap["AutomaticDisableChannelEnabled"] = strconv.FormatBool(common.AutomaticDisableChannelEnabled)
common.OptionMap["AutomaticEnableChannelEnabled"] = strconv.FormatBool(common.AutomaticEnableChannelEnabled) common.OptionMap["AutomaticEnableChannelEnabled"] = strconv.FormatBool(common.AutomaticEnableChannelEnabled)
common.OptionMap["LogConsumeEnabled"] = strconv.FormatBool(common.LogConsumeEnabled) common.OptionMap["LogConsumeEnabled"] = strconv.FormatBool(common.LogConsumeEnabled)
@ -177,6 +178,8 @@ func updateOptionMap(key string, value string) (err error) {
common.TurnstileCheckEnabled = boolValue common.TurnstileCheckEnabled = boolValue
case "RegisterEnabled": case "RegisterEnabled":
common.RegisterEnabled = boolValue common.RegisterEnabled = boolValue
case "UserSelfDeletionEnabled":
common.UserSelfDeletionEnabled = boolValue
case "EmailDomainRestrictionEnabled": case "EmailDomainRestrictionEnabled":
common.EmailDomainRestrictionEnabled = boolValue common.EmailDomainRestrictionEnabled = boolValue
case "AutomaticDisableChannelEnabled": case "AutomaticDisableChannelEnabled":

View File

@ -5,8 +5,8 @@ const (
) )
var ModelList = []string{ var ModelList = []string{
"gemini-pro", "gemini-pro", "gemini-1.0-pro-001", "gemini-1.5-pro",
"gemini-pro-vision", "gemini-pro-vision", "gemini-1.0-pro-vision-001",
} }
var ChannelName = "google gemini" var ChannelName = "google gemini"

View File

@ -2,7 +2,6 @@ package ollama
import ( import (
"errors" "errors"
"fmt"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"io" "io"
"net/http" "net/http"
@ -10,6 +9,7 @@ import (
"one-api/relay/channel" "one-api/relay/channel"
"one-api/relay/channel/openai" "one-api/relay/channel/openai"
relaycommon "one-api/relay/common" relaycommon "one-api/relay/common"
relayconstant "one-api/relay/constant"
"one-api/service" "one-api/service"
) )
@ -20,7 +20,12 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIReq
} }
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
return fmt.Sprintf("%s/api/chat", info.BaseUrl), nil switch info.RelayMode {
case relayconstant.RelayModeEmbeddings:
return info.BaseUrl + "/api/embeddings", nil
default:
return relaycommon.GetFullRequestURL(info.BaseUrl, info.RequestURLPath, info.ChannelType), nil
}
} }
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error { func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
@ -32,7 +37,12 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.Gen
if request == nil { if request == nil {
return nil, errors.New("request is nil") return nil, errors.New("request is nil")
} }
return requestOpenAI2Ollama(*request), nil switch relayMode {
case relayconstant.RelayModeEmbeddings:
return requestOpenAI2Embeddings(*request), nil
default:
return requestOpenAI2Ollama(*request), nil
}
} }
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) { func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
@ -45,7 +55,11 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
err, responseText = openai.OpenaiStreamHandler(c, resp, info.RelayMode) err, responseText = openai.OpenaiStreamHandler(c, resp, info.RelayMode)
usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens) usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
} else { } else {
err, usage, sensitiveResp = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName, info.RelayMode) if info.RelayMode == relayconstant.RelayModeEmbeddings {
err, usage, sensitiveResp = ollamaEmbeddingHandler(c, resp, info.PromptTokens, info.UpstreamModelName, info.RelayMode)
} else {
err, usage, sensitiveResp = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName, info.RelayMode)
}
} }
return return
} }

View File

@ -3,16 +3,24 @@ package ollama
import "one-api/dto" import "one-api/dto"
type OllamaRequest struct { type OllamaRequest struct {
Model string `json:"model,omitempty"` Model string `json:"model,omitempty"`
Messages []dto.Message `json:"messages,omitempty"` Messages []dto.Message `json:"messages,omitempty"`
Stream bool `json:"stream,omitempty"` Stream bool `json:"stream,omitempty"`
Options *OllamaOptions `json:"options,omitempty"` Temperature float64 `json:"temperature,omitempty"`
Seed float64 `json:"seed,omitempty"`
Topp float64 `json:"top_p,omitempty"`
TopK int `json:"top_k,omitempty"`
Stop any `json:"stop,omitempty"`
} }
type OllamaOptions struct { type OllamaEmbeddingRequest struct {
Temperature float64 `json:"temperature,omitempty"` Model string `json:"model,omitempty"`
Seed float64 `json:"seed,omitempty"` Prompt any `json:"prompt,omitempty"`
Topp float64 `json:"top_p,omitempty"`
TopK int `json:"top_k,omitempty"`
Stop any `json:"stop,omitempty"`
} }
type OllamaEmbeddingResponse struct {
Embedding []float64 `json:"embedding,omitempty"`
}
//type OllamaOptions struct {
//}

View File

@ -1,6 +1,15 @@
package ollama package ollama
import "one-api/dto" import (
"bytes"
"encoding/json"
"fmt"
"github.com/gin-gonic/gin"
"io"
"net/http"
"one-api/dto"
"one-api/service"
)
func requestOpenAI2Ollama(request dto.GeneralOpenAIRequest) *OllamaRequest { func requestOpenAI2Ollama(request dto.GeneralOpenAIRequest) *OllamaRequest {
messages := make([]dto.Message, 0, len(request.Messages)) messages := make([]dto.Message, 0, len(request.Messages))
@ -18,15 +27,82 @@ func requestOpenAI2Ollama(request dto.GeneralOpenAIRequest) *OllamaRequest {
Stop, _ = request.Stop.([]string) Stop, _ = request.Stop.([]string)
} }
return &OllamaRequest{ return &OllamaRequest{
Model: request.Model, Model: request.Model,
Messages: messages, Messages: messages,
Stream: request.Stream, Stream: request.Stream,
Options: &OllamaOptions{ Temperature: request.Temperature,
Temperature: request.Temperature, Seed: request.Seed,
Seed: request.Seed, Topp: request.TopP,
Topp: request.TopP, TopK: request.TopK,
TopK: request.TopK, Stop: Stop,
Stop: Stop,
},
} }
} }
func requestOpenAI2Embeddings(request dto.GeneralOpenAIRequest) *OllamaEmbeddingRequest {
return &OllamaEmbeddingRequest{
Model: request.Model,
Prompt: request.Input,
}
}
func ollamaEmbeddingHandler(c *gin.Context, resp *http.Response, promptTokens int, model string, relayMode int) (*dto.OpenAIErrorWithStatusCode, *dto.Usage, *dto.SensitiveResponse) {
var ollamaEmbeddingResponse OllamaEmbeddingResponse
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil, nil
}
err = resp.Body.Close()
if err != nil {
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil, nil
}
err = json.Unmarshal(responseBody, &ollamaEmbeddingResponse)
if err != nil {
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil, nil
}
data := make([]dto.OpenAIEmbeddingResponseItem, 0, 1)
data = append(data, dto.OpenAIEmbeddingResponseItem{
Embedding: ollamaEmbeddingResponse.Embedding,
Object: "embedding",
})
usage := &dto.Usage{
TotalTokens: promptTokens,
CompletionTokens: 0,
PromptTokens: promptTokens,
}
embeddingResponse := &dto.OpenAIEmbeddingResponse{
Object: "list",
Data: data,
Model: model,
Usage: *usage,
}
doResponseBody, err := json.Marshal(embeddingResponse)
if err != nil {
return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil, nil
}
resp.Body = io.NopCloser(bytes.NewBuffer(doResponseBody))
// We shouldn't set the header before we parse the response body, because the parse part may fail.
// And then we will have to send an error response, but in this case, the header has already been set.
// So the httpClient will be confused by the response.
// For example, Postman will report error, and we cannot check the response at all.
// Copy headers
for k, v := range resp.Header {
// 删除任何现有的相同头部,以防止重复添加头部
c.Writer.Header().Del(k)
for _, vv := range v {
c.Writer.Header().Add(k, vv)
}
}
// reset content length
c.Writer.Header().Del("Content-Length")
c.Writer.Header().Set("Content-Length", fmt.Sprintf("%d", len(doResponseBody)))
c.Writer.WriteHeader(resp.StatusCode)
_, err = io.Copy(c.Writer, resp.Body)
if err != nil {
return service.OpenAIErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil, nil
}
err = resp.Body.Close()
if err != nil {
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil, nil
}
return nil, usage, nil
}

View File

@ -10,8 +10,8 @@ import (
"one-api/dto" "one-api/dto"
"one-api/relay/channel" "one-api/relay/channel"
"one-api/relay/channel/ai360" "one-api/relay/channel/ai360"
"one-api/relay/channel/moonshot"
"one-api/relay/channel/lingyiwanwu" "one-api/relay/channel/lingyiwanwu"
"one-api/relay/channel/moonshot"
relaycommon "one-api/relay/common" relaycommon "one-api/relay/common"
"one-api/service" "one-api/service"
"strings" "strings"
@ -34,9 +34,6 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
model_ := info.UpstreamModelName model_ := info.UpstreamModelName
model_ = strings.Replace(model_, ".", "", -1) model_ = strings.Replace(model_, ".", "", -1)
// https://github.com/songquanpeng/one-api/issues/67 // https://github.com/songquanpeng/one-api/issues/67
model_ = strings.TrimSuffix(model_, "-0301")
model_ = strings.TrimSuffix(model_, "-0314")
model_ = strings.TrimSuffix(model_, "-0613")
requestURL = fmt.Sprintf("/openai/deployments/%s/%s", model_, task) requestURL = fmt.Sprintf("/openai/deployments/%s/%s", model_, task)
return relaycommon.GetFullRequestURL(info.BaseUrl, requestURL, info.ChannelType), nil return relaycommon.GetFullRequestURL(info.BaseUrl, requestURL, info.ChannelType), nil

View File

@ -162,7 +162,7 @@ func TextHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
if resp.StatusCode != http.StatusOK { if resp.StatusCode != http.StatusOK {
returnPreConsumedQuota(c, relayInfo.TokenId, userQuota, preConsumedQuota) returnPreConsumedQuota(c, relayInfo.TokenId, userQuota, preConsumedQuota)
return service.RelayErrorHandler(resp) return service.OpenAIErrorWrapper(fmt.Errorf("bad response status code: %d", resp.StatusCode), "bad_response_status_code", resp.StatusCode)
} }
usage, openaiErr, sensitiveResp := adaptor.DoResponse(c, resp, relayInfo) usage, openaiErr, sensitiveResp := adaptor.DoResponse(c, resp, relayInfo)

View File

@ -29,7 +29,7 @@ func MidjourneyErrorWithStatusCodeWrapper(code int, desc string, statusCode int)
func OpenAIErrorWrapper(err error, code string, statusCode int) *dto.OpenAIErrorWithStatusCode { func OpenAIErrorWrapper(err error, code string, statusCode int) *dto.OpenAIErrorWithStatusCode {
text := err.Error() text := err.Error()
// 定义一个正则表达式匹配URL // 定义一个正则表达式匹配URL
if strings.Contains(text, "Post") { if strings.Contains(text, "Post") || strings.Contains(text, "dial") {
common.SysLog(fmt.Sprintf("error: %s", text)) common.SysLog(fmt.Sprintf("error: %s", text))
text = "请求上游地址失败" text = "请求上游地址失败"
} }

View File

@ -208,7 +208,7 @@ func CountTokenInput(input any, model string, check bool) (int, error, bool) {
} }
return CountTokenText(text, model, check) return CountTokenText(text, model, check)
} }
return 0, errors.New("unsupported input type"), false return CountTokenInput(fmt.Sprintf("%v", input), model, check)
} }
func CountAudioToken(text string, model string, check bool) (int, error, bool) { func CountAudioToken(text string, model string, check bool) (int, error, bool) {

View File

@ -45,6 +45,7 @@ const SystemSetting = () => {
TurnstileSiteKey: '', TurnstileSiteKey: '',
TurnstileSecretKey: '', TurnstileSecretKey: '',
RegisterEnabled: '', RegisterEnabled: '',
UserSelfDeletionEnabled: false,
EmailDomainRestrictionEnabled: '', EmailDomainRestrictionEnabled: '',
EmailDomainWhitelist: [], EmailDomainWhitelist: [],
// telegram login // telegram login
@ -104,6 +105,7 @@ const SystemSetting = () => {
case 'TurnstileCheckEnabled': case 'TurnstileCheckEnabled':
case 'EmailDomainRestrictionEnabled': case 'EmailDomainRestrictionEnabled':
case 'RegisterEnabled': case 'RegisterEnabled':
case 'UserSelfDeletionEnabled':
case 'PaymentEnabled': case 'PaymentEnabled':
value = inputs[key] === 'true' ? 'false' : 'true'; value = inputs[key] === 'true' ? 'false' : 'true';
break; break;
@ -536,6 +538,12 @@ const SystemSetting = () => {
name='TurnstileCheckEnabled' name='TurnstileCheckEnabled'
onChange={handleInputChange} onChange={handleInputChange}
/> />
<Form.Checkbox
checked={inputs.UserSelfDeletionEnabled === 'true'}
label='允许用户自行删除账户'
name='UserSelfDeletionEnabled'
onChange={handleInputChange}
/>
</Form.Group> </Form.Group>
<Divider /> <Divider />
<Header as='h3'> <Header as='h3'>

View File

@ -426,8 +426,11 @@ const TokensTable = () => {
if (await copy(text)) { if (await copy(text)) {
showSuccess('已复制到剪贴板!'); showSuccess('已复制到剪贴板!');
} else { } else {
// setSearchKeyword(text); Modal.error({
Modal.error({ title: '无法复制到剪贴板,请手动复制', content: text }); title: '无法复制到剪贴板,请手动复制',
content: text,
size: 'large',
});
} }
}; };