diff --git a/.gitignore b/.gitignore index 1b2cf07..16cbc67 100644 --- a/.gitignore +++ b/.gitignore @@ -5,4 +5,5 @@ upload *.db build *.db-journal -logs \ No newline at end of file +logs +web/dist \ No newline at end of file diff --git a/common/email-outlook-auth.go b/common/email-outlook-auth.go new file mode 100644 index 0000000..723a10b --- /dev/null +++ b/common/email-outlook-auth.go @@ -0,0 +1,32 @@ +package common + +import ( + "errors" + "net/smtp" +) + +type outlookAuth struct { + username, password string +} + +func LoginAuth(username, password string) smtp.Auth { + return &outlookAuth{username, password} +} + +func (a *outlookAuth) Start(_ *smtp.ServerInfo) (string, []byte, error) { + return "LOGIN", []byte{}, nil +} + +func (a *outlookAuth) Next(fromServer []byte, more bool) ([]byte, error) { + if more { + switch string(fromServer) { + case "Username:": + return []byte(a.username), nil + case "Password:": + return []byte(a.password), nil + default: + return nil, errors.New("unknown fromServer") + } + } + return nil, nil +} diff --git a/common/email.go b/common/email.go index 13345d8..62c9048 100644 --- a/common/email.go +++ b/common/email.go @@ -62,6 +62,9 @@ func SendEmail(subject string, receiver string, content string) error { if err != nil { return err } + } else if strings.HasSuffix(SMTPAccount, "outlook.com") { + auth = LoginAuth(SMTPAccount, SMTPToken) + err = smtp.SendMail(addr, auth, SMTPAccount, to, mail) } else { err = smtp.SendMail(addr, auth, SMTPAccount, to, mail) } diff --git a/common/model-ratio.go b/common/model-ratio.go index 67ae69a..242ccd7 100644 --- a/common/model-ratio.go +++ b/common/model-ratio.go @@ -3,6 +3,7 @@ package common import ( "encoding/json" "strings" + "sync" ) // from songquanpeng/one-api @@ -183,8 +184,14 @@ var defaultModelPrice = map[string]float64{ "swap_face": 0.05, } -var modelPrice map[string]float64 = nil -var modelRatio map[string]float64 = nil +var ( + modelPriceMap = make(map[string]float64) + modelPriceMapMutex = sync.RWMutex{} +) +var ( + modelRatioMap map[string]float64 = nil + modelRatioMapMutex = sync.RWMutex{} +) var CompletionRatio map[string]float64 = nil var defaultCompletionRatio = map[string]float64{ @@ -194,11 +201,18 @@ var defaultCompletionRatio = map[string]float64{ "gpt-4o-all": 2, } -func ModelPrice2JSONString() string { - if modelPrice == nil { - modelPrice = defaultModelPrice +func GetModelPriceMap() map[string]float64 { + modelPriceMapMutex.Lock() + defer modelPriceMapMutex.Unlock() + if modelPriceMap == nil { + modelPriceMap = defaultModelPrice } - jsonBytes, err := json.Marshal(modelPrice) + return modelPriceMap +} + +func ModelPrice2JSONString() string { + GetModelPriceMap() + jsonBytes, err := json.Marshal(modelPriceMap) if err != nil { SysError("error marshalling model price: " + err.Error()) } @@ -206,21 +220,21 @@ func ModelPrice2JSONString() string { } func UpdateModelPriceByJSONString(jsonStr string) error { - modelPrice = make(map[string]float64) - return json.Unmarshal([]byte(jsonStr), &modelPrice) + modelPriceMapMutex.Lock() + defer modelPriceMapMutex.Unlock() + modelPriceMap = make(map[string]float64) + return json.Unmarshal([]byte(jsonStr), &modelPriceMap) } // GetModelPrice 返回模型的价格,如果模型不存在则返回-1,false func GetModelPrice(name string, printErr bool) (float64, bool) { - if modelPrice == nil { - modelPrice = defaultModelPrice - } + GetModelPriceMap() if strings.HasPrefix(name, "gpt-4-gizmo") { name = "gpt-4-gizmo-*" } else if strings.HasPrefix(name, "g-") { name = "g-*" } - price, ok := modelPrice[name] + price, ok := modelPriceMap[name] if !ok { if printErr { SysError("model price not found: " + name) @@ -230,18 +244,18 @@ func GetModelPrice(name string, printErr bool) (float64, bool) { return price, true } -func GetModelPriceMap() map[string]float64 { - if modelPrice == nil { - modelPrice = defaultModelPrice +func GetModelRatioMap() map[string]float64 { + modelRatioMapMutex.Lock() + defer modelRatioMapMutex.Unlock() + if modelRatioMap == nil { + modelRatioMap = defaultModelRatio } - return modelPrice + return modelRatioMap } func ModelRatio2JSONString() string { - if modelRatio == nil { - modelRatio = defaultModelRatio - } - jsonBytes, err := json.Marshal(modelRatio) + GetModelRatioMap() + jsonBytes, err := json.Marshal(modelRatioMap) if err != nil { SysError("error marshalling model ratio: " + err.Error()) } @@ -249,20 +263,20 @@ func ModelRatio2JSONString() string { } func UpdateModelRatioByJSONString(jsonStr string) error { - modelRatio = make(map[string]float64) - return json.Unmarshal([]byte(jsonStr), &modelRatio) + modelRatioMapMutex.Lock() + defer modelRatioMapMutex.Unlock() + modelRatioMap = make(map[string]float64) + return json.Unmarshal([]byte(jsonStr), &modelRatioMap) } func GetModelRatio(name string) float64 { - if modelRatio == nil { - modelRatio = defaultModelRatio - } + GetModelRatioMap() if strings.HasPrefix(name, "gpt-4-gizmo") { name = "gpt-4-gizmo-*" } else if strings.HasPrefix(name, "g-") { name = "g-*" } - ratio, ok := modelRatio[name] + ratio, ok := modelRatioMap[name] if !ok { SysError("model ratio not found: " + name) return 30 diff --git a/controller/relay.go b/controller/relay.go index bc951f7..0c79015 100644 --- a/controller/relay.go +++ b/controller/relay.go @@ -43,12 +43,13 @@ func Relay(c *gin.Context) { requestId := c.GetString(common.RequestIdKey) channelId := c.GetInt("channel_id") channelType := c.GetInt("channel_type") + channelName := c.GetString("channel_name") group := c.GetString("group") originalModel := c.GetString("original_model") openaiErr := relayHandler(c, relayMode) c.Set("use_channel", []string{fmt.Sprintf("%d", channelId)}) if openaiErr != nil { - go processChannelError(c, channelId, channelType, openaiErr) + go processChannelError(c, channelId, channelType, channelName, openaiErr) } else { retryTimes = 0 } @@ -60,7 +61,7 @@ func Relay(c *gin.Context) { } channelId = channel.Id useChannel := c.GetStringSlice("use_channel") - useChannel = append(useChannel, fmt.Sprintf("%d", channelId)) + useChannel = append(useChannel, fmt.Sprintf("%d", channel.Id)) c.Set("use_channel", useChannel) common.LogInfo(c.Request.Context(), fmt.Sprintf("using channel #%d to retry (remain times %d)", channel.Id, i)) middleware.SetupContextForSelectedChannel(c, channel, originalModel) @@ -69,7 +70,7 @@ func Relay(c *gin.Context) { c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody)) openaiErr = relayHandler(c, relayMode) if openaiErr != nil { - go processChannelError(c, channelId, channel.Type, openaiErr) + go processChannelError(c, channel.Id, channel.Type, channel.Name, openaiErr) } } useChannel := c.GetStringSlice("use_channel") @@ -128,11 +129,10 @@ func shouldRetry(c *gin.Context, channelId int, openaiErr *dto.OpenAIErrorWithSt return true } -func processChannelError(c *gin.Context, channelId int, channelType int, err *dto.OpenAIErrorWithStatusCode) { +func processChannelError(c *gin.Context, channelId int, channelType int, channelName string, err *dto.OpenAIErrorWithStatusCode) { autoBan := c.GetBool("auto_ban") common.LogError(c.Request.Context(), fmt.Sprintf("relay error (channel #%d, status code: %d): %s", channelId, err.StatusCode, err.Error.Message)) if service.ShouldDisableChannel(channelType, err) && autoBan { - channelName := c.GetString("channel_name") service.DisableChannel(channelId, channelName, err.Error.Message) } } diff --git a/middleware/utils.go b/middleware/utils.go index 43801c1..082f565 100644 --- a/middleware/utils.go +++ b/middleware/utils.go @@ -1,11 +1,13 @@ package middleware import ( + "fmt" "github.com/gin-gonic/gin" "one-api/common" ) func abortWithOpenAiMessage(c *gin.Context, statusCode int, message string) { + userId := c.GetInt("id") c.JSON(statusCode, gin.H{ "error": gin.H{ "message": common.MessageWithRequestId(message, c.GetString(common.RequestIdKey)), @@ -13,7 +15,7 @@ func abortWithOpenAiMessage(c *gin.Context, statusCode int, message string) { }, }) c.Abort() - common.LogError(c.Request.Context(), message) + common.LogError(c.Request.Context(), fmt.Sprintf("user %d | %s", userId, message)) } func abortWithMidjourneyMessage(c *gin.Context, statusCode int, code int, description string) { diff --git a/model/channel.go b/model/channel.go index 8603033..7db3f07 100644 --- a/model/channel.go +++ b/model/channel.go @@ -100,8 +100,8 @@ func SearchChannels(keyword string, group string, model string) ([]*Channel, err var whereClause string var args []interface{} if group != "" { - whereClause = "(id = ? OR name LIKE ? OR " + keyCol + " = ?) AND " + groupCol + " LIKE ? AND " + modelsCol + " LIKE ?" - args = append(args, common.String2Int(keyword), "%"+keyword+"%", keyword, "%"+group+"%", "%"+model+"%") + whereClause = "(id = ? OR name LIKE ? OR " + keyCol + " = ?) AND " + groupCol + " = ? AND " + modelsCol + " LIKE ?" + args = append(args, common.String2Int(keyword), "%"+keyword+"%", keyword, group, "%"+model+"%") } else { whereClause = "(id = ? OR name LIKE ? OR " + keyCol + " = ?) AND " + modelsCol + " LIKE ?" args = append(args, common.String2Int(keyword), "%"+keyword+"%", keyword, "%"+model+"%") diff --git a/relay/channel/aws/relay-aws.go b/relay/channel/aws/relay-aws.go index ce97755..f452d56 100644 --- a/relay/channel/aws/relay-aws.go +++ b/relay/channel/aws/relay-aws.go @@ -222,9 +222,11 @@ func awsStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel } } service.Done(c) - err = resp.Body.Close() - if err != nil { - return service.OpenAIErrorWrapperLocal(err, "close_response_body_failed", http.StatusInternalServerError), nil + if resp != nil { + err = resp.Body.Close() + if err != nil { + return service.OpenAIErrorWrapperLocal(err, "close_response_body_failed", http.StatusInternalServerError), nil + } } return nil, &usage } diff --git a/relay/channel/cloudflare/constant.go b/relay/channel/cloudflare/constant.go index a874685..0e2aec2 100644 --- a/relay/channel/cloudflare/constant.go +++ b/relay/channel/cloudflare/constant.go @@ -1,6 +1,7 @@ package cloudflare var ModelList = []string{ + "@cf/meta/llama-3.1-8b-instruct", "@cf/meta/llama-2-7b-chat-fp16", "@cf/meta/llama-2-7b-chat-int8", "@cf/mistral/mistral-7b-instruct-v0.1", diff --git a/relay/channel/dify/relay-dify.go b/relay/channel/dify/relay-dify.go index c674ba1..66ba839 100644 --- a/relay/channel/dify/relay-dify.go +++ b/relay/channel/dify/relay-dify.go @@ -53,7 +53,7 @@ func streamResponseDify2OpenAI(difyResponse DifyChunkChatCompletionResponse) *dt choice.Delta.SetContentString("Workflow: " + difyResponse.Data.WorkflowId + "\n") } else if constant.DifyDebug && difyResponse.Event == "node_started" { choice.Delta.SetContentString("Node: " + difyResponse.Data.NodeId + "\n") - } else if difyResponse.Event == "message" { + } else if difyResponse.Event == "message" || difyResponse.Event == "agent_message" { choice.Delta.SetContentString(difyResponse.Answer) } response.Choices = append(response.Choices, choice) diff --git a/relay/channel/gemini/relay-gemini.go b/relay/channel/gemini/relay-gemini.go index 498e0a0..79dc44b 100644 --- a/relay/channel/gemini/relay-gemini.go +++ b/relay/channel/gemini/relay-gemini.go @@ -83,13 +83,28 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest) *GeminiChatReques if imageNum > GeminiVisionMaxImageNum { continue } - mimeType, data, _ := common.GetImageFromUrl(part.ImageUrl.(dto.MessageImageUrl).Url) - parts = append(parts, GeminiPart{ - InlineData: &GeminiInlineData{ - MimeType: mimeType, - Data: data, - }, - }) + // 判断是否是url + if strings.HasPrefix(part.ImageUrl.(dto.MessageImageUrl).Url, "http") { + // 是url,获取图片的类型和base64编码的数据 + mimeType, data, _ := common.GetImageFromUrl(part.ImageUrl.(dto.MessageImageUrl).Url) + parts = append(parts, GeminiPart{ + InlineData: &GeminiInlineData{ + MimeType: mimeType, + Data: data, + }, + }) + } else { + _, format, base64String, err := common.DecodeBase64ImageData(part.ImageUrl.(dto.MessageImageUrl).Url) + if err != nil { + continue + } + parts = append(parts, GeminiPart{ + InlineData: &GeminiInlineData{ + MimeType: "image/" + format, + Data: base64String, + }, + }) + } } } content.Parts = parts diff --git a/relay/channel/ollama/dto.go b/relay/channel/ollama/dto.go index a6d6238..4f99a24 100644 --- a/relay/channel/ollama/dto.go +++ b/relay/channel/ollama/dto.go @@ -3,14 +3,18 @@ package ollama import "one-api/dto" type OllamaRequest struct { - Model string `json:"model,omitempty"` - Messages []dto.Message `json:"messages,omitempty"` - Stream bool `json:"stream,omitempty"` - Temperature float64 `json:"temperature,omitempty"` - Seed float64 `json:"seed,omitempty"` - Topp float64 `json:"top_p,omitempty"` - TopK int `json:"top_k,omitempty"` - Stop any `json:"stop,omitempty"` + Model string `json:"model,omitempty"` + Messages []dto.Message `json:"messages,omitempty"` + Stream bool `json:"stream,omitempty"` + Temperature float64 `json:"temperature,omitempty"` + Seed float64 `json:"seed,omitempty"` + Topp float64 `json:"top_p,omitempty"` + TopK int `json:"top_k,omitempty"` + Stop any `json:"stop,omitempty"` + Tools []dto.ToolCall `json:"tools,omitempty"` + ResponseFormat *dto.ResponseFormat `json:"response_format,omitempty"` + FrequencyPenalty float64 `json:"frequency_penalty,omitempty"` + PresencePenalty float64 `json:"presence_penalty,omitempty"` } type OllamaEmbeddingRequest struct { @@ -21,6 +25,3 @@ type OllamaEmbeddingRequest struct { type OllamaEmbeddingResponse struct { Embedding []float64 `json:"embedding,omitempty"` } - -//type OllamaOptions struct { -//} diff --git a/relay/channel/ollama/relay-ollama.go b/relay/channel/ollama/relay-ollama.go index f63fe57..6bf395a 100644 --- a/relay/channel/ollama/relay-ollama.go +++ b/relay/channel/ollama/relay-ollama.go @@ -28,14 +28,18 @@ func requestOpenAI2Ollama(request dto.GeneralOpenAIRequest) *OllamaRequest { Stop, _ = request.Stop.([]string) } return &OllamaRequest{ - Model: request.Model, - Messages: messages, - Stream: request.Stream, - Temperature: request.Temperature, - Seed: request.Seed, - Topp: request.TopP, - TopK: request.TopK, - Stop: Stop, + Model: request.Model, + Messages: messages, + Stream: request.Stream, + Temperature: request.Temperature, + Seed: request.Seed, + Topp: request.TopP, + TopK: request.TopK, + Stop: Stop, + Tools: request.Tools, + ResponseFormat: request.ResponseFormat, + FrequencyPenalty: request.FrequencyPenalty, + PresencePenalty: request.PresencePenalty, } } diff --git a/relay/channel/openai/relay-openai.go b/relay/channel/openai/relay-openai.go index 424f183..87ad7d3 100644 --- a/relay/channel/openai/relay-openai.go +++ b/relay/channel/openai/relay-openai.go @@ -22,7 +22,7 @@ import ( func OaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { containStreamUsage := false - responseId := "" + var responseId string var createAt int64 = 0 var systemFingerprint string model := info.UpstreamModelName @@ -86,7 +86,13 @@ func OaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel var lastStreamResponse dto.ChatCompletionsStreamResponse err := json.Unmarshal(common.StringToByteSlice(lastStreamData), &lastStreamResponse) if err == nil { - if lastStreamResponse.Usage != nil && service.ValidUsage(lastStreamResponse.Usage) { + responseId = lastStreamResponse.Id + createAt = lastStreamResponse.Created + systemFingerprint = lastStreamResponse.GetSystemFingerprint() + model = lastStreamResponse.Model + if service.ValidUsage(lastStreamResponse.Usage) { + containStreamUsage = true + usage = lastStreamResponse.Usage if !info.ShouldIncludeUsage { shouldSendLastResp = false } @@ -109,14 +115,9 @@ func OaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel var streamResponse dto.ChatCompletionsStreamResponse err := json.Unmarshal(common.StringToByteSlice(item), &streamResponse) if err == nil { - responseId = streamResponse.Id - createAt = streamResponse.Created - systemFingerprint = streamResponse.GetSystemFingerprint() - model = streamResponse.Model - if service.ValidUsage(streamResponse.Usage) { - usage = streamResponse.Usage - containStreamUsage = true - } + //if service.ValidUsage(streamResponse.Usage) { + // usage = streamResponse.Usage + //} for _, choice := range streamResponse.Choices { responseTextBuilder.WriteString(choice.Delta.GetContentString()) if choice.Delta.ToolCalls != nil { @@ -133,14 +134,10 @@ func OaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel } } else { for _, streamResponse := range streamResponses { - responseId = streamResponse.Id - createAt = streamResponse.Created - systemFingerprint = streamResponse.GetSystemFingerprint() - model = streamResponse.Model - if service.ValidUsage(streamResponse.Usage) { - usage = streamResponse.Usage - containStreamUsage = true - } + //if service.ValidUsage(streamResponse.Usage) { + // usage = streamResponse.Usage + // containStreamUsage = true + //} for _, choice := range streamResponse.Choices { responseTextBuilder.WriteString(choice.Delta.GetContentString()) if choice.Delta.ToolCalls != nil { diff --git a/relay/relay-image.go b/relay/relay-image.go index f6a2641..83c7538 100644 --- a/relay/relay-image.go +++ b/relay/relay-image.go @@ -121,7 +121,8 @@ func ImageHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode { } } - quota := int(modelPrice*groupRatio*common.QuotaPerUnit*sizeRatio*qualityRatio) * imageRequest.N + imageRatio := modelPrice * sizeRatio * qualityRatio * float64(imageRequest.N) + quota := int(imageRatio * groupRatio * common.QuotaPerUnit) if userQuota-quota < 0 { return service.OpenAIErrorWrapperLocal(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden) @@ -180,7 +181,7 @@ func ImageHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode { } logContent := fmt.Sprintf("大小 %s, 品质 %s", imageRequest.Size, quality) - postConsumeQuota(c, relayInfo, imageRequest.Model, usage, 0, 0, userQuota, 0, groupRatio, modelPrice, true, logContent) + postConsumeQuota(c, relayInfo, imageRequest.Model, usage, 0, 0, userQuota, 0, groupRatio, imageRatio, true, logContent) return nil } diff --git a/web/src/components/LoginForm.js b/web/src/components/LoginForm.js index 5711f60..75d9f00 100644 --- a/web/src/components/LoginForm.js +++ b/web/src/components/LoginForm.js @@ -1,7 +1,14 @@ import React, { useContext, useEffect, useState } from 'react'; import { Link, useNavigate, useSearchParams } from 'react-router-dom'; import { UserContext } from '../context/User'; -import { API, getLogo, showError, showInfo, showSuccess } from '../helpers'; +import { + API, + getLogo, + showError, + showInfo, + showSuccess, + updateAPI, +} from '../helpers'; import { onGitHubOAuthClicked, onLinuxDoOAuthClicked } from './utils'; import Turnstile from 'react-turnstile'; import { @@ -102,6 +109,7 @@ const LoginForm = () => { if (success) { userDispatch({ type: 'login', payload: data }); setUserData(data); + updateAPI(); showSuccess('登录成功!'); if (username === 'root' && password === '123456') { Modal.error({ diff --git a/web/src/helpers/api.js b/web/src/helpers/api.js index 4c464db..19f902e 100644 --- a/web/src/helpers/api.js +++ b/web/src/helpers/api.js @@ -1,7 +1,7 @@ import { getUserIdFromLocalStorage, showError } from './utils'; import axios from 'axios'; -export const API = axios.create({ +export let API = axios.create({ baseURL: import.meta.env.VITE_REACT_APP_SERVER_URL ? import.meta.env.VITE_REACT_APP_SERVER_URL : '', @@ -10,6 +10,17 @@ export const API = axios.create({ }, }); +export function updateAPI() { + API = axios.create({ + baseURL: import.meta.env.VITE_REACT_APP_SERVER_URL + ? import.meta.env.VITE_REACT_APP_SERVER_URL + : '', + headers: { + 'New-API-User': getUserIdFromLocalStorage(), + }, + }); +} + API.interceptors.response.use( (response) => response, (error) => { diff --git a/web/src/pages/Token/index.js b/web/src/pages/Token/index.js index d85dacf..ca3c2a4 100644 --- a/web/src/pages/Token/index.js +++ b/web/src/pages/Token/index.js @@ -1,11 +1,14 @@ import React from 'react'; import TokensTable from '../../components/TokensTable'; -import { Layout } from '@douyinfe/semi-ui'; +import { Banner, Layout } from '@douyinfe/semi-ui'; const Token = () => ( <> -

我的令牌

+