diff --git a/common/image/image.go b/common/image/image.go index cbb656ad..a602936a 100644 --- a/common/image/image.go +++ b/common/image/image.go @@ -1,6 +1,8 @@ package image import ( + "bytes" + "encoding/base64" "image" _ "image/gif" _ "image/jpeg" @@ -8,11 +10,27 @@ import ( "net/http" "regexp" "strings" + "sync" _ "golang.org/x/image/webp" ) +func IsImageUrl(url string) (bool, error) { + resp, err := http.Head(url) + if err != nil { + return false, err + } + if !strings.HasPrefix(resp.Header.Get("Content-Type"), "image/") { + return false, nil + } + return true, nil +} + func GetImageSizeFromUrl(url string) (width int, height int, err error) { + isImage, err := IsImageUrl(url) + if !isImage { + return + } resp, err := http.Get(url) if err != nil { return @@ -25,17 +43,51 @@ func GetImageSizeFromUrl(url string) (width int, height int, err error) { return img.Width, img.Height, nil } +func GetImageFromUrl(url string) (mimeType string, data string, err error) { + isImage, err := IsImageUrl(url) + if !isImage { + return + } + resp, err := http.Get(url) + if err != nil { + return + } + defer resp.Body.Close() + buffer := bytes.NewBuffer(nil) + _, err = buffer.ReadFrom(resp.Body) + if err != nil { + return + } + mimeType = resp.Header.Get("Content-Type") + data = base64.StdEncoding.EncodeToString(buffer.Bytes()) + return +} + var ( reg = regexp.MustCompile(`data:image/([^;]+);base64,`) ) +var readerPool = sync.Pool{ + New: func() interface{} { + return &bytes.Reader{} + }, +} + func GetImageSizeFromBase64(encoded string) (width int, height int, err error) { - encoded = strings.TrimPrefix(encoded, "data:image/png;base64,") - base64 := strings.NewReader(reg.ReplaceAllString(encoded, "")) - img, _, err := image.DecodeConfig(base64) + decoded, err := base64.StdEncoding.DecodeString(reg.ReplaceAllString(encoded, "")) if err != nil { - return + return 0, 0, err } + + reader := readerPool.Get().(*bytes.Reader) + defer readerPool.Put(reader) + reader.Reset(decoded) + + img, _, err := image.DecodeConfig(reader) + if err != nil { + return 0, 0, err + } + return img.Width, img.Height, nil } diff --git a/common/image/image_test.go b/common/image/image_test.go index 366eda6e..8e47b109 100644 --- a/common/image/image_test.go +++ b/common/image/image_test.go @@ -152,3 +152,20 @@ func TestGetImageSize(t *testing.T) { }) } } + +func TestGetImageSizeFromBase64(t *testing.T) { + for i, c := range cases { + t.Run("Decode:"+strconv.Itoa(i), func(t *testing.T) { + resp, err := http.Get(c.url) + assert.NoError(t, err) + defer resp.Body.Close() + data, err := io.ReadAll(resp.Body) + assert.NoError(t, err) + encoded := base64.StdEncoding.EncodeToString(data) + width, height, err := img.GetImageSizeFromBase64(encoded) + assert.NoError(t, err) + assert.Equal(t, c.width, width) + assert.Equal(t, c.height, height) + }) + } +} diff --git a/common/model-ratio.go b/common/model-ratio.go index d1c96d96..fa2adaa1 100644 --- a/common/model-ratio.go +++ b/common/model-ratio.go @@ -84,6 +84,7 @@ var ModelRatio = map[string]float64{ "Embedding-V1": 0.1429, // ¥0.002 / 1k tokens "PaLM-2": 1, "gemini-pro": 1, // $0.00025 / 1k characters -> $0.001 / 1k tokens + "gemini-pro-vision": 1, // $0.00025 / 1k characters -> $0.001 / 1k tokens "chatglm_turbo": 0.3572, // ¥0.005 / 1k tokens "chatglm_pro": 0.7143, // ¥0.01 / 1k tokens "chatglm_std": 0.3572, // ¥0.005 / 1k tokens @@ -115,6 +116,9 @@ func UpdateModelRatioByJSONString(jsonStr string) error { } func GetModelRatio(name string) float64 { + if strings.HasPrefix(name, "qwen-") && strings.HasSuffix(name, "-internet") { + name = strings.TrimSuffix(name, "-internet") + } ratio, ok := ModelRatio[name] if !ok { SysError("model ratio not found: " + name) diff --git a/controller/model.go b/controller/model.go index 9ae40f5c..6a759b63 100644 --- a/controller/model.go +++ b/controller/model.go @@ -432,6 +432,15 @@ func init() { Root: "gemini-pro", Parent: nil, }, + { + Id: "gemini-pro-vision", + Object: "model", + Created: 1677649963, + OwnedBy: "google", + Permission: permission, + Root: "gemini-pro-vision", + Parent: nil, + }, { Id: "chatglm_turbo", Object: "model", diff --git a/controller/relay-gemini.go b/controller/relay-gemini.go index 523018de..ec55d4b6 100644 --- a/controller/relay-gemini.go +++ b/controller/relay-gemini.go @@ -7,11 +7,18 @@ import ( "io" "net/http" "one-api/common" + "one-api/common/image" "strings" "github.com/gin-gonic/gin" ) +// https://ai.google.dev/docs/gemini_api_overview?hl=zh-cn + +const ( + GeminiVisionMaxImageNum = 16 +) + type GeminiChatRequest struct { Contents []GeminiChatContent `json:"contents"` SafetySettings []GeminiChatSafetySettings `json:"safety_settings,omitempty"` @@ -97,6 +104,30 @@ func requestOpenAI2Gemini(textRequest GeneralOpenAIRequest) *GeminiChatRequest { }, }, } + openaiContent := message.ParseContent() + var parts []GeminiPart + imageNum := 0 + for _, part := range openaiContent { + if part.Type == ContentTypeText { + parts = append(parts, GeminiPart{ + Text: part.Text, + }) + } else if part.Type == ContentTypeImageURL { + imageNum += 1 + if imageNum > GeminiVisionMaxImageNum { + continue + } + mimeType, data, _ := image.GetImageFromUrl(part.ImageURL.Url) + parts = append(parts, GeminiPart{ + InlineData: &GeminiInlineData{ + MimeType: mimeType, + Data: data, + }, + }) + } + } + content.Parts = parts + // there's no assistant role in gemini and API shall vomit if Role is not user or model if content.Role == "assistant" { content.Role = "model" diff --git a/controller/relay-palm.go b/controller/relay-palm.go index 2bd0bcd8..0c1c8af6 100644 --- a/controller/relay-palm.go +++ b/controller/relay-palm.go @@ -187,6 +187,7 @@ func palmHandler(c *gin.Context, resp *http.Response, promptTokens int, model st }, nil } fullTextResponse := responsePaLM2OpenAI(&palmResponse) + fullTextResponse.Model = model completionTokens := countTokenText(palmResponse.Candidates[0].Content, model) usage := Usage{ PromptTokens: promptTokens, diff --git a/controller/relay-text.go b/controller/relay-text.go index 29b55dae..38290dab 100644 --- a/controller/relay-text.go +++ b/controller/relay-text.go @@ -180,9 +180,6 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { if baseURL != "" { fullRequestURL = fmt.Sprintf("%s/v1beta2/models/chat-bison-001:generateMessage", baseURL) } - apiKey := c.Request.Header.Get("Authorization") - apiKey = strings.TrimPrefix(apiKey, "Bearer ") - fullRequestURL += "?key=" + apiKey case APITypeGemini: requestBaseURL := "https://generativelanguage.googleapis.com" if baseURL != "" { @@ -200,21 +197,21 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { apiKey := c.Request.Header.Get("Authorization") apiKey = strings.TrimPrefix(apiKey, "Bearer ") fullRequestURL += "?key=" + apiKey - // case APITypeZhipu: - // method := "invoke" - // if textRequest.Stream { - // method = "sse-invoke" - // } - // fullRequestURL = fmt.Sprintf("https://open.bigmodel.cn/api/paas/v3/model-api/%s/%s", textRequest.Model, method) - // case APITypeAli: - // fullRequestURL = "https://dashscope.aliyuncs.com/api/v1/services/aigc/text-generation/generation" - // if relayMode == RelayModeEmbeddings { - // fullRequestURL = "https://dashscope.aliyuncs.com/api/v1/services/embeddings/text-embedding/text-embedding" - // } - // case APITypeTencent: - // fullRequestURL = "https://hunyuan.cloud.tencent.com/hyllm/v1/chat/completions" - // case APITypeAIProxyLibrary: - // fullRequestURL = fmt.Sprintf("%s/api/library/ask", baseURL) + case APITypeZhipu: + method := "invoke" + if textRequest.Stream { + method = "sse-invoke" + } + fullRequestURL = fmt.Sprintf("https://open.bigmodel.cn/api/paas/v3/model-api/%s/%s", textRequest.Model, method) + case APITypeAli: + fullRequestURL = "https://dashscope.aliyuncs.com/api/v1/services/aigc/text-generation/generation" + if relayMode == RelayModeEmbeddings { + fullRequestURL = "https://dashscope.aliyuncs.com/api/v1/services/embeddings/text-embedding/text-embedding" + } + case APITypeTencent: + fullRequestURL = "https://hunyuan.cloud.tencent.com/hyllm/v1/chat/completions" + case APITypeAIProxyLibrary: + fullRequestURL = fmt.Sprintf("%s/api/library/ask", baseURL) } var promptTokens int var completionTokens int @@ -410,9 +407,9 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { // case APITypeTencent: // req.Header.Set("Authorization", apiKey) case APITypePaLM: - // do not set Authorization header + req.Header.Set("x-goog-api-key", apiKey) case APITypeGemini: - // do not set Authorization header + req.Header.Set("x-goog-api-key", apiKey) default: req.Header.Set("Authorization", "Bearer "+apiKey) } diff --git a/controller/relay.go b/controller/relay.go index 72d02237..0ba02edc 100644 --- a/controller/relay.go +++ b/controller/relay.go @@ -69,6 +69,22 @@ type ImageContent struct { ImageURL *ImageURL `json:"image_url,omitempty"` } +const ( + ContentTypeText = "text" + ContentTypeImageURL = "image_url" +) + +type OpenAIMessageContent struct { + Type string `json:"type,omitempty"` + Text string `json:"text"` + ImageURL *ImageURL `json:"image_url,omitempty"` +} + +func (m Message) IsStringContent() bool { + _, ok := m.Content.(string) + return ok +} + func (m Message) StringContent() string { content, ok := m.Content.(string) if ok { @@ -82,7 +98,7 @@ func (m Message) StringContent() string { if !ok { continue } - if contentMap["type"] == "text" { + if contentMap["type"] == ContentTypeText { if subStr, ok := contentMap["text"].(string); ok { contentStr += subStr } @@ -93,6 +109,47 @@ func (m Message) StringContent() string { return "" } +func (m Message) ParseContent() []OpenAIMessageContent { + var contentList []OpenAIMessageContent + content, ok := m.Content.(string) + if ok { + contentList = append(contentList, OpenAIMessageContent{ + Type: ContentTypeText, + Text: content, + }) + return contentList + } + anyList, ok := m.Content.([]any) + if ok { + for _, contentItem := range anyList { + contentMap, ok := contentItem.(map[string]any) + if !ok { + continue + } + switch contentMap["type"] { + case ContentTypeText: + if subStr, ok := contentMap["text"].(string); ok { + contentList = append(contentList, OpenAIMessageContent{ + Type: ContentTypeText, + Text: subStr, + }) + } + case ContentTypeImageURL: + if subObj, ok := contentMap["image_url"].(map[string]any); ok { + contentList = append(contentList, OpenAIMessageContent{ + Type: ContentTypeImageURL, + ImageURL: &ImageURL{ + Url: subObj["url"].(string), + }, + }) + } + } + } + return contentList + } + return nil +} + const ( RelayModeUnknown = iota RelayModeChatCompletions @@ -281,7 +338,7 @@ type OpenAITextResponseChoice struct { type OpenAITextResponse struct { Id string `json:"id"` - Model string `json:"model"` + Model string `json:"model,omitempty"` Object string `json:"object"` Created int64 `json:"created"` Choices []OpenAITextResponseChoice `json:"choices"` diff --git a/go.mod b/go.mod index cf225f0d..f96b5477 100644 --- a/go.mod +++ b/go.mod @@ -11,7 +11,7 @@ require ( github.com/gin-gonic/gin v1.9.1 github.com/go-playground/validator/v10 v10.16.0 github.com/go-redis/redis/v8 v8.11.5 - github.com/google/uuid v1.4.0 + github.com/google/uuid v1.5.0 github.com/pkoukk/tiktoken-go v0.1.6 github.com/stretchr/testify v1.8.4 golang.org/x/crypto v0.17.0 diff --git a/go.sum b/go.sum index 821173c9..2474132d 100644 --- a/go.sum +++ b/go.sum @@ -65,8 +65,8 @@ github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaS github.com/google/go-cmp v0.5.5 h1:Khx7svrCpmxxtHBq5j2mp/xVjsi8hQMfNLvJFAlrGgU= github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= -github.com/google/uuid v1.4.0 h1:MtMxsa51/r9yyhkyLsVeVt0B+BGQZzpQiTQ4eHZ8bc4= -github.com/google/uuid v1.4.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/google/uuid v1.5.0 h1:1p67kYwdtXjb0gL0BPiP1Av9wiZPo5A8z2cWkTZ+eyU= +github.com/google/uuid v1.5.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/gorilla/context v1.1.1 h1:AWwleXJkX/nhcU9bZSnZoi3h/qGYqQAGhq6zZe/aQW8= github.com/gorilla/context v1.1.1/go.mod h1:kBGZzfjB9CEq2AlWe17Uuf7NDRt0dE0s8S51q0aT7Yg= github.com/gorilla/securecookie v1.1.1 h1:miw7JPhV+b/lAHSXz4qd/nN9jRiAFV5FwjeKyCS8BvQ= diff --git a/middleware/recover.go b/middleware/recover.go index c3a3d748..8338a514 100644 --- a/middleware/recover.go +++ b/middleware/recover.go @@ -5,6 +5,7 @@ import ( "github.com/gin-gonic/gin" "net/http" "one-api/common" + "runtime/debug" ) func RelayPanicRecover() gin.HandlerFunc { @@ -12,6 +13,7 @@ func RelayPanicRecover() gin.HandlerFunc { defer func() { if err := recover(); err != nil { common.SysError(fmt.Sprintf("panic detected: %v", err)) + common.SysError(fmt.Sprintf("stacktrace from panic: %s", string(debug.Stack()))) c.JSON(http.StatusInternalServerError, gin.H{ "error": gin.H{ "message": fmt.Sprintf("Panic detected, error: %v. Please submit a issue here: https://github.com/songquanpeng/one-api", err), diff --git a/model/user.go b/model/user.go index a006fa68..f5778e6e 100644 --- a/model/user.go +++ b/model/user.go @@ -42,7 +42,11 @@ func GetAllUsers(startIdx int, num int) (users []*User, err error) { } func SearchUsers(keyword string) (users []*User, err error) { - err = DB.Omit("password").Where("id = ? or username LIKE ? or email LIKE ? or display_name LIKE ?", keyword, keyword+"%", keyword+"%", keyword+"%").Find(&users).Error + if !common.UsingPostgreSQL { + err = DB.Omit("password").Where("id = ? or username LIKE ? or email LIKE ? or display_name LIKE ?", keyword, keyword+"%", keyword+"%", keyword+"%").Find(&users).Error + } else { + err = DB.Omit("password").Where("username LIKE ? or email LIKE ? or display_name LIKE ?", keyword+"%", keyword+"%", keyword+"%").Find(&users).Error + } return users, err } diff --git a/pull_request_template.md b/pull_request_template.md index bbcd969c..a313004f 100644 --- a/pull_request_template.md +++ b/pull_request_template.md @@ -1,3 +1,9 @@ +[//]: # (请按照以下格式关联 issue) +[//]: # (请在提交 PR 前确认所提交的功能可用,附上截图即可,这将有助于项目维护者 review & merge 该 PR,谢谢) +[//]: # (项目维护者一般仅在周末处理 PR,因此如若未能及时回复希望能理解) +[//]: # (开发者交流群:910657413) +[//]: # (请在提交 PR 之前删除上面的注释) + close #issue_number 我已确认该 PR 已自测通过,相关截图如下: \ No newline at end of file diff --git a/web/src/pages/Channel/EditChannel.js b/web/src/pages/Channel/EditChannel.js index 364da69d..0d4e114d 100644 --- a/web/src/pages/Channel/EditChannel.js +++ b/web/src/pages/Channel/EditChannel.js @@ -70,6 +70,13 @@ const EditChannel = () => { break; case 17: localModels = ['qwen-turbo', 'qwen-plus', 'qwen-max', 'qwen-max-longcontext', 'text-embedding-v1']; + let withInternetVersion = []; + for (let i = 0; i < localModels.length; i++) { + if (localModels[i].startsWith('qwen-')) { + withInternetVersion.push(localModels[i] + '-internet'); + } + } + localModels = [...localModels, ...withInternetVersion]; break; case 16: localModels = ['chatglm_turbo', 'chatglm_pro', 'chatglm_std', 'chatglm_lite']; @@ -84,7 +91,7 @@ const EditChannel = () => { localModels = ['hunyuan']; break; case 24: - localModels = ['gemini-pro']; + localModels = ['gemini-pro', 'gemini-pro-vision']; break; } setInputs((inputs) => ({ ...inputs, models: localModels }));