Compare commits


18 Commits

Author SHA1 Message Date
JustSong
1c8922153d feat: support gemini-vision-pro 2023-12-24 18:54:32 +08:00
Laisky.Cai
f3c07e1451 fix: openai response should contain model (#841)
* fix: openai response should contain `model`

- Update model attributes in `claudeHandler` for `relay-claude.go`
- Implement model type for fullTextResponse in `relay-gemini.go`
- Add new `Model` field to `OpenAITextResponse` struct in `relay.go`

* chore: set model name response for models

---------

Co-authored-by: JustSong <songquanpeng@foxmail.com>
2023-12-24 16:58:31 +08:00
Bryan
40ceb29e54 fix: fix SearchUsers not working if using PostgreSQL (#778)
* fix SearchUsers

* refactor: using UsingPostgreSQL as condition

---------

Co-authored-by: JustSong <songquanpeng@foxmail.com>
2023-12-24 16:42:00 +08:00
dependabot[bot]
0699ecd0af chore(deps): bump golang.org/x/crypto from 0.14.0 to 0.17.0 (#840)
Bumps [golang.org/x/crypto](https://github.com/golang/crypto) from 0.14.0 to 0.17.0.
- [Commits](https://github.com/golang/crypto/compare/v0.14.0...v0.17.0)

---
updated-dependencies:
- dependency-name: golang.org/x/crypto
  dependency-type: direct:production
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2023-12-24 16:29:48 +08:00
moondie
ee9e746520 feat: update ali stream implementation & enable internet search (#856)
* Update relay-ali.go: improve stream mode and add internet search capability

Tongyi Qianwen supports an incremental stream mode, so there is no need to strip the previous chunk's prefix from every response; in testing, qwen-max works well with internet search enabled, so that mode has been added. If other models have problems with it, this can be changed to enable it for qwen-max only.

* Remove the "stream" parameter

Just discovered that the Ali API does not actually have this parameter; it was added by mistake earlier.

* refactor: only enable search when specified

* fix: remove custom suffix when get model ratio

---------

Co-authored-by: JustSong <songquanpeng@foxmail.com>
2023-12-24 16:17:21 +08:00
Buer
a763681c2e fix: fix base64 image parse error (#858) 2023-12-24 15:35:56 +08:00
JustSong
b7fcb319da chore: check if SESSION_SECRET equals to random_string 2023-12-20 22:50:50 +08:00
JustSong
67c64e71c8 fix: fix max_tokens check 2023-12-20 21:45:33 +08:00
JustSong
97030e27f8 fix: fix gemini panic (close #833) 2023-12-17 23:30:45 +08:00
JustSong
461f5dab56 docs: update readme 2023-12-17 22:25:03 +08:00
JustSong
af378c59af docs: update readme 2023-12-17 22:19:16 +08:00
ShinChven ✨
bc6769826b feat: add condition to validate n value for non-Azure channels (#775)
- Add a condition to validate the n value only for non-Azure channels, ensuring it falls within the acceptable range.
- Fix Azure compatibility
2023-12-17 19:49:08 +08:00
Oliver Lee
0fe26cc4bd feat: update ali relay implementation (#830)
* Update to the latest Tongyi Qianwen API: 1. remove the `history` parameter in favor of the officially recommended `messages` parameter; 2. reorder the `messages` entries and add the necessary context; 3. debugged and tested with autogen

* chore: update impl

---------

Co-authored-by: JustSong <songquanpeng@foxmail.com>
2023-12-17 19:43:23 +08:00
Calcium-Ion
7d6a169669 feat: able to set sqlite busy_timeout (#818)
* add sqlite busy_timeout=3000

* chore: update impl

---------

Co-authored-by: JustSong <songquanpeng@foxmail.com>
2023-12-17 19:17:00 +08:00
Ghostz
66f06e5d6f feat: reset image num to 1 when not given (#821)
* Update relay-image.go

* fix: reset image num to 1 when not given

---------

Co-authored-by: JustSong <songquanpeng@foxmail.com>
2023-12-17 18:54:08 +08:00
JustSong
6acb9537a9 fix: try to return a more meaningful error message (close #817) 2023-12-17 18:33:27 +08:00
JustSong
7069c49bdf fix: fix xunfei panic error (close #820) 2023-12-17 18:06:37 +08:00
JustSong
58dee76bf7 fix: fix Gemini stream problem 2023-12-17 16:16:18 +08:00
25 changed files with 411 additions and 125 deletions

View File

@@ -73,13 +73,7 @@ _✨ Access all large models through the standard OpenAI API format, ready to use out of the box
+ [x] [Zhipu ChatGLM series models](https://bigmodel.cn)
+ [x] [360 Zhinao](https://ai.360.cn)
+ [x] [Tencent Hunyuan](https://cloud.tencent.com/document/product/1729)
2. Supports configuring mirrors and many third-party proxy services:
+ [x] [OpenAI-SB](https://openai-sb.com)
+ [x] [CloseAI](https://referer.shadowai.xyz/r/2412)
+ [x] [API2D](https://api2d.com/r/197971)
+ [x] [OhMyGPT](https://aigptx.top?aff=uFpUl2Kf)
+ [x] [AI Proxy](https://aiproxy.io/?i=OneAPI) (invitation code: `OneAPI`)
+ [x] Custom channels: e.g., various unlisted third-party proxy services
2. Supports configuring mirrors and many [third-party proxy services](https://iamazing.cn/page/openai-api-third-party-services).
3. Supports accessing multiple channels via **load balancing**.
4. Supports **stream mode**, enabling a typewriter effect through streaming responses.
5. Supports **multi-machine deployment** ([details here](#多机部署)).
@@ -371,6 +365,7 @@ graph LR
+ `TIKTOKEN_CACHE_DIR`: by default the program downloads some common token encodings (e.g., for `gpt-3.5-turbo`) over the network at startup; on unstable networks or offline deployments this can break startup, so this directory can be configured to cache the data and shipped to the offline environment.
+ `DATA_GYM_CACHE_DIR`: currently behaves the same as `TIKTOKEN_CACHE_DIR`, but with lower priority.
15. `RELAY_TIMEOUT`: relay timeout in seconds; no timeout is set by default.
16. `SQLITE_BUSY_TIMEOUT`: SQLite lock wait timeout in milliseconds; defaults to `3000`.
### Command-line arguments
1. `--port <port_number>`: specifies the port for the server to listen on; defaults to `3000`.

View File

@@ -4,3 +4,4 @@ var UsingSQLite = false
var UsingPostgreSQL = false
var SQLitePath = "one-api.db"
var SQLiteBusyTimeout = GetOrDefault("SQLITE_BUSY_TIMEOUT", 3000)

View File

@@ -1,6 +1,8 @@
package image
import (
"bytes"
"encoding/base64"
"image"
_ "image/gif"
_ "image/jpeg"
@@ -8,11 +10,27 @@ import (
"net/http"
"regexp"
"strings"
"sync"
_ "golang.org/x/image/webp"
)
func IsImageUrl(url string) (bool, error) {
resp, err := http.Head(url)
if err != nil {
return false, err
}
if !strings.HasPrefix(resp.Header.Get("Content-Type"), "image/") {
return false, nil
}
return true, nil
}
func GetImageSizeFromUrl(url string) (width int, height int, err error) {
isImage, err := IsImageUrl(url)
if !isImage {
return
}
resp, err := http.Get(url)
if err != nil {
return
@@ -25,17 +43,51 @@ func GetImageSizeFromUrl(url string) (width int, height int, err error) {
return img.Width, img.Height, nil
}
func GetImageFromUrl(url string) (mimeType string, data string, err error) {
isImage, err := IsImageUrl(url)
if !isImage {
return
}
resp, err := http.Get(url)
if err != nil {
return
}
defer resp.Body.Close()
buffer := bytes.NewBuffer(nil)
_, err = buffer.ReadFrom(resp.Body)
if err != nil {
return
}
mimeType = resp.Header.Get("Content-Type")
data = base64.StdEncoding.EncodeToString(buffer.Bytes())
return
}
var (
reg = regexp.MustCompile(`data:image/([^;]+);base64,`)
)
var readerPool = sync.Pool{
New: func() interface{} {
return &bytes.Reader{}
},
}
func GetImageSizeFromBase64(encoded string) (width int, height int, err error) {
encoded = strings.TrimPrefix(encoded, "data:image/png;base64,")
base64 := strings.NewReader(reg.ReplaceAllString(encoded, ""))
img, _, err := image.DecodeConfig(base64)
decoded, err := base64.StdEncoding.DecodeString(reg.ReplaceAllString(encoded, ""))
if err != nil {
return
return 0, 0, err
}
reader := readerPool.Get().(*bytes.Reader)
defer readerPool.Put(reader)
reader.Reset(decoded)
img, _, err := image.DecodeConfig(reader)
if err != nil {
return 0, 0, err
}
return img.Width, img.Height, nil
}
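The rewrite of `GetImageSizeFromBase64` decodes the base64 payload once and borrows `bytes.Reader` values from a `sync.Pool` instead of allocating a reader per call; `image.DecodeConfig` then reads only the image header. A self-contained sketch of the same decode path (the helper name and the generated test image are illustrative, not from the repository):

```go
package main

import (
	"bytes"
	"encoding/base64"
	"fmt"
	"image"
	"image/png" // registers the PNG decoder as a side effect
	"regexp"
	"sync"
)

var dataURLPrefix = regexp.MustCompile(`data:image/([^;]+);base64,`)

// readerPool mirrors the diff's pooled-reader idea: reuse bytes.Reader
// values across calls instead of allocating one per request.
var readerPool = sync.Pool{
	New: func() any { return &bytes.Reader{} },
}

func sizeFromBase64(encoded string) (int, int, error) {
	decoded, err := base64.StdEncoding.DecodeString(dataURLPrefix.ReplaceAllString(encoded, ""))
	if err != nil {
		return 0, 0, err
	}
	r := readerPool.Get().(*bytes.Reader)
	defer readerPool.Put(r)
	r.Reset(decoded)
	cfg, _, err := image.DecodeConfig(r) // decodes only the header, not the pixels
	if err != nil {
		return 0, 0, err
	}
	return cfg.Width, cfg.Height, nil
}

func main() {
	// Build a 3x2 PNG in memory so the example is self-contained.
	var buf bytes.Buffer
	_ = png.Encode(&buf, image.NewRGBA(image.Rect(0, 0, 3, 2)))
	encoded := "data:image/png;base64," + base64.StdEncoding.EncodeToString(buf.Bytes())
	fmt.Println(sizeFromBase64(encoded)) // 3 2 <nil>
}
```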

View File

@@ -152,3 +152,20 @@ func TestGetImageSize(t *testing.T) {
})
}
}
func TestGetImageSizeFromBase64(t *testing.T) {
for i, c := range cases {
t.Run("Decode:"+strconv.Itoa(i), func(t *testing.T) {
resp, err := http.Get(c.url)
assert.NoError(t, err)
defer resp.Body.Close()
data, err := io.ReadAll(resp.Body)
assert.NoError(t, err)
encoded := base64.StdEncoding.EncodeToString(data)
width, height, err := img.GetImageSizeFromBase64(encoded)
assert.NoError(t, err)
assert.Equal(t, c.width, width)
assert.Equal(t, c.height, height)
})
}
}

View File

@@ -36,7 +36,11 @@ func init() {
}
if os.Getenv("SESSION_SECRET") != "" {
SessionSecret = os.Getenv("SESSION_SECRET")
if os.Getenv("SESSION_SECRET") == "random_string" {
SysError("SESSION_SECRET is set to an example value, please change it to a random string.")
} else {
SessionSecret = os.Getenv("SESSION_SECRET")
}
}
if os.Getenv("SQLITE_PATH") != "" {
SQLitePath = os.Getenv("SQLITE_PATH")

View File

@@ -84,12 +84,15 @@ var ModelRatio = map[string]float64{
"Embedding-V1": 0.1429, // ¥0.002 / 1k tokens
"PaLM-2": 1,
"gemini-pro": 1, // $0.00025 / 1k characters -> $0.001 / 1k tokens
"gemini-pro-vision": 1, // $0.00025 / 1k characters -> $0.001 / 1k tokens
"chatglm_turbo": 0.3572, // ¥0.005 / 1k tokens
"chatglm_pro": 0.7143, // ¥0.01 / 1k tokens
"chatglm_std": 0.3572, // ¥0.005 / 1k tokens
"chatglm_lite": 0.1429, // ¥0.002 / 1k tokens
"qwen-turbo": 0.8572, // ¥0.012 / 1k tokens
"qwen-plus": 10, // ¥0.14 / 1k tokens
"qwen-turbo": 0.5715, // ¥0.008 / 1k tokens // https://help.aliyun.com/zh/dashscope/developer-reference/tongyi-thousand-questions-metering-and-billing
"qwen-plus": 1.4286, // ¥0.02 / 1k tokens
"qwen-max": 1.4286, // ¥0.02 / 1k tokens
"qwen-max-longcontext": 1.4286, // ¥0.02 / 1k tokens
"text-embedding-v1": 0.05, // ¥0.0007 / 1k tokens
"SparkDesk": 1.2858, // ¥0.018 / 1k tokens
"360GPT_S2_V9": 0.8572, // ¥0.012 / 1k tokens
@@ -113,6 +116,9 @@ func UpdateModelRatioByJSONString(jsonStr string) error {
}
func GetModelRatio(name string) float64 {
if strings.HasPrefix(name, "qwen-") && strings.HasSuffix(name, "-internet") {
name = strings.TrimSuffix(name, "-internet")
}
ratio, ok := ModelRatio[name]
if !ok {
SysError("model ratio not found: " + name)

View File

@@ -432,6 +432,15 @@ func init() {
Root: "gemini-pro",
Parent: nil,
},
{
Id: "gemini-pro-vision",
Object: "model",
Created: 1677649963,
OwnedBy: "google",
Permission: permission,
Root: "gemini-pro-vision",
Parent: nil,
},
{
Id: "chatglm_turbo",
Object: "model",
@@ -486,6 +495,24 @@ func init() {
Root: "qwen-plus",
Parent: nil,
},
{
Id: "qwen-max",
Object: "model",
Created: 1677649963,
OwnedBy: "ali",
Permission: permission,
Root: "qwen-max",
Parent: nil,
},
{
Id: "qwen-max-longcontext",
Object: "model",
Created: 1677649963,
OwnedBy: "ali",
Permission: permission,
Root: "qwen-max-longcontext",
Parent: nil,
},
{
Id: "text-embedding-v1",
Object: "model",

View File

@@ -13,20 +13,21 @@ import (
// https://help.aliyun.com/document_detail/613695.html?spm=a2c4g.2399480.0.0.1adb778fAdzP9w#341800c0f8w0r
type AliMessage struct {
User string `json:"user"`
Bot string `json:"bot"`
Content string `json:"content"`
Role string `json:"role"`
}
type AliInput struct {
Prompt string `json:"prompt"`
History []AliMessage `json:"history"`
//Prompt string `json:"prompt"`
Messages []AliMessage `json:"messages"`
}
type AliParameters struct {
TopP float64 `json:"top_p,omitempty"`
TopK int `json:"top_k,omitempty"`
Seed uint64 `json:"seed,omitempty"`
EnableSearch bool `json:"enable_search,omitempty"`
TopP float64 `json:"top_p,omitempty"`
TopK int `json:"top_k,omitempty"`
Seed uint64 `json:"seed,omitempty"`
EnableSearch bool `json:"enable_search,omitempty"`
IncrementalOutput bool `json:"incremental_output,omitempty"`
}
type AliChatRequest struct {
@@ -81,41 +82,32 @@ type AliChatResponse struct {
AliError
}
const AliEnableSearchModelSuffix = "-internet"
func requestOpenAI2Ali(request GeneralOpenAIRequest) *AliChatRequest {
messages := make([]AliMessage, 0, len(request.Messages))
prompt := ""
for i := 0; i < len(request.Messages); i++ {
message := request.Messages[i]
if message.Role == "system" {
messages = append(messages, AliMessage{
User: message.StringContent(),
Bot: "Okay",
})
continue
} else {
if i == len(request.Messages)-1 {
prompt = message.StringContent()
break
}
messages = append(messages, AliMessage{
User: message.StringContent(),
Bot: request.Messages[i+1].StringContent(),
})
i++
}
messages = append(messages, AliMessage{
Content: message.StringContent(),
Role: strings.ToLower(message.Role),
})
}
enableSearch := false
aliModel := request.Model
if strings.HasSuffix(aliModel, AliEnableSearchModelSuffix) {
enableSearch = true
aliModel = strings.TrimSuffix(aliModel, AliEnableSearchModelSuffix)
}
return &AliChatRequest{
Model: request.Model,
Model: aliModel,
Input: AliInput{
Prompt: prompt,
History: messages,
Messages: messages,
},
Parameters: AliParameters{
EnableSearch: enableSearch,
IncrementalOutput: request.Stream,
},
//Parameters: AliParameters{ // ChatGPT's parameters are not compatible with Ali's
// TopP: request.TopP,
// TopK: 50,
// //Seed: 0,
// //EnableSearch: false,
//},
}
}
@@ -217,7 +209,7 @@ func streamResponseAli2OpenAI(aliResponse *AliChatResponse) *ChatCompletionsStre
Id: aliResponse.RequestId,
Object: "chat.completion.chunk",
Created: common.GetTimestamp(),
Model: "ernie-bot",
Model: "qwen",
Choices: []ChatCompletionsStreamResponseChoice{choice},
}
return &response
@@ -255,7 +247,7 @@ func aliStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStat
stopChan <- true
}()
setEventStreamHeaders(c)
lastResponseText := ""
//lastResponseText := ""
c.Stream(func(w io.Writer) bool {
select {
case data := <-dataChan:
@@ -271,8 +263,8 @@ func aliStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStat
usage.TotalTokens = aliResponse.Usage.InputTokens + aliResponse.Usage.OutputTokens
}
response := streamResponseAli2OpenAI(&aliResponse)
response.Choices[0].Delta.Content = strings.TrimPrefix(response.Choices[0].Delta.Content, lastResponseText)
lastResponseText = aliResponse.Output.Text
//response.Choices[0].Delta.Content = strings.TrimPrefix(response.Choices[0].Delta.Content, lastResponseText)
//lastResponseText = aliResponse.Output.Text
jsonResponse, err := json.Marshal(response)
if err != nil {
common.SysError("error marshalling stream response: " + err.Error())
@@ -318,6 +310,7 @@ func aliHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode
}, nil
}
fullTextResponse := responseAli2OpenAI(&aliResponse)
fullTextResponse.Model = "qwen"
jsonResponse, err := json.Marshal(fullTextResponse)
if err != nil {
return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
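The net effect of the `requestOpenAI2Ali` rewrite: roles are passed through as `messages` instead of the old `user`/`bot` pairing, a trailing `-internet` on the model name turns into `enable_search`, and `incremental_output` tracks the stream flag. A sketch of the model/parameter mapping, with simplified stand-ins for the repository's structs:

```go
package main

import (
	"fmt"
	"strings"
)

const enableSearchSuffix = "-internet"

type aliParameters struct {
	EnableSearch      bool
	IncrementalOutput bool
}

type aliRequest struct {
	Model      string
	Parameters aliParameters
}

// mapModel mirrors the suffix handling in requestOpenAI2Ali.
func mapModel(model string, stream bool) aliRequest {
	enableSearch := false
	if strings.HasSuffix(model, enableSearchSuffix) {
		enableSearch = true
		model = strings.TrimSuffix(model, enableSearchSuffix)
	}
	return aliRequest{
		Model: model,
		Parameters: aliParameters{
			EnableSearch:      enableSearch,
			IncrementalOutput: stream, // incremental chunks only make sense when streaming
		},
	}
}

func main() {
	fmt.Printf("%+v\n", mapModel("qwen-max-internet", true))
	// {Model:qwen-max Parameters:{EnableSearch:true IncrementalOutput:true}}
}
```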

View File

@@ -255,6 +255,7 @@ func baiduHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCo
}, nil
}
fullTextResponse := responseBaidu2OpenAI(&baiduResponse)
fullTextResponse.Model = "ernie-bot"
jsonResponse, err := json.Marshal(fullTextResponse)
if err != nil {
return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil

View File

@@ -204,6 +204,7 @@ func claudeHandler(c *gin.Context, resp *http.Response, promptTokens int, model
}, nil
}
fullTextResponse := responseClaude2OpenAI(&claudeResponse)
fullTextResponse.Model = model
completionTokens := countTokenText(claudeResponse.Completion, model)
usage := Usage{
PromptTokens: promptTokens,

View File

@@ -1,15 +1,24 @@
package controller
import (
"bufio"
"encoding/json"
"fmt"
"io"
"net/http"
"one-api/common"
"one-api/common/image"
"strings"
"github.com/gin-gonic/gin"
)
// https://ai.google.dev/docs/gemini_api_overview?hl=zh-cn
const (
GeminiVisionMaxImageNum = 16
)
type GeminiChatRequest struct {
Contents []GeminiChatContent `json:"contents"`
SafetySettings []GeminiChatSafetySettings `json:"safety_settings,omitempty"`
@@ -95,6 +104,30 @@ func requestOpenAI2Gemini(textRequest GeneralOpenAIRequest) *GeminiChatRequest {
},
},
}
openaiContent := message.ParseContent()
var parts []GeminiPart
imageNum := 0
for _, part := range openaiContent {
if part.Type == ContentTypeText {
parts = append(parts, GeminiPart{
Text: part.Text,
})
} else if part.Type == ContentTypeImageURL {
imageNum += 1
if imageNum > GeminiVisionMaxImageNum {
continue
}
mimeType, data, _ := image.GetImageFromUrl(part.ImageURL.Url)
parts = append(parts, GeminiPart{
InlineData: &GeminiInlineData{
MimeType: mimeType,
Data: data,
},
})
}
}
content.Parts = parts
// there's no assistant role in gemini and API shall vomit if Role is not user or model
if content.Role == "assistant" {
content.Role = "model"
@@ -112,7 +145,7 @@ func requestOpenAI2Gemini(textRequest GeneralOpenAIRequest) *GeminiChatRequest {
Role: "model",
Parts: []GeminiPart{
{
Text: "ok",
Text: "Okay",
},
},
})
@@ -128,6 +161,16 @@ type GeminiChatResponse struct {
PromptFeedback GeminiChatPromptFeedback `json:"promptFeedback"`
}
func (g *GeminiChatResponse) GetResponseText() string {
if g == nil {
return ""
}
if len(g.Candidates) > 0 && len(g.Candidates[0].Content.Parts) > 0 {
return g.Candidates[0].Content.Parts[0].Text
}
return ""
}
type GeminiChatCandidate struct {
Content GeminiChatContent `json:"content"`
FinishReason string `json:"finishReason"`
@@ -156,10 +199,13 @@ func responseGeminiChat2OpenAI(response *GeminiChatResponse) *OpenAITextResponse
Index: i,
Message: Message{
Role: "assistant",
Content: candidate.Content.Parts[0].Text,
Content: "",
},
FinishReason: stopFinishReason,
}
if len(candidate.Content.Parts) > 0 {
choice.Message.Content = candidate.Content.Parts[0].Text
}
fullTextResponse.Choices = append(fullTextResponse.Choices, choice)
}
return &fullTextResponse
@@ -167,9 +213,7 @@ func responseGeminiChat2OpenAI(response *GeminiChatResponse) *OpenAITextResponse
func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) *ChatCompletionsStreamResponse {
var choice ChatCompletionsStreamResponseChoice
if len(geminiResponse.Candidates) > 0 && len(geminiResponse.Candidates[0].Content.Parts) > 0 {
choice.Delta.Content = geminiResponse.Candidates[0].Content.Parts[0].Text
}
choice.Delta.Content = geminiResponse.GetResponseText()
choice.FinishReason = &stopFinishReason
var response ChatCompletionsStreamResponse
response.Object = "chat.completion.chunk"
@@ -180,50 +224,61 @@ func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) *ChatCo
func geminiChatStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, string) {
responseText := ""
responseId := fmt.Sprintf("chatcmpl-%s", common.GetUUID())
createdTime := common.GetTimestamp()
dataChan := make(chan string)
stopChan := make(chan bool)
scanner := bufio.NewScanner(resp.Body)
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
if atEOF && len(data) == 0 {
return 0, nil, nil
}
if i := strings.Index(string(data), "\n"); i >= 0 {
return i + 1, data[0:i], nil
}
if atEOF {
return len(data), data, nil
}
return 0, nil, nil
})
go func() {
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
common.SysError("error reading stream response: " + err.Error())
stopChan <- true
return
for scanner.Scan() {
data := scanner.Text()
data = strings.TrimSpace(data)
if !strings.HasPrefix(data, "\"text\": \"") {
continue
}
data = strings.TrimPrefix(data, "\"text\": \"")
data = strings.TrimSuffix(data, "\"")
dataChan <- data
}
err = resp.Body.Close()
if err != nil {
common.SysError("error closing stream response: " + err.Error())
stopChan <- true
return
}
var geminiResponse GeminiChatResponse
err = json.Unmarshal(responseBody, &geminiResponse)
if err != nil {
common.SysError("error unmarshalling stream response: " + err.Error())
stopChan <- true
return
}
fullTextResponse := streamResponseGeminiChat2OpenAI(&geminiResponse)
fullTextResponse.Id = responseId
fullTextResponse.Created = createdTime
if len(geminiResponse.Candidates) > 0 && len(geminiResponse.Candidates[0].Content.Parts) > 0 {
responseText += geminiResponse.Candidates[0].Content.Parts[0].Text
}
jsonResponse, err := json.Marshal(fullTextResponse)
if err != nil {
common.SysError("error marshalling stream response: " + err.Error())
stopChan <- true
return
}
dataChan <- string(jsonResponse)
stopChan <- true
}()
setEventStreamHeaders(c)
c.Stream(func(w io.Writer) bool {
select {
case data := <-dataChan:
c.Render(-1, common.CustomEvent{Data: "data: " + data})
// this is used to prevent annoying \ related format bug
data = fmt.Sprintf("{\"content\": \"%s\"}", data)
type dummyStruct struct {
Content string `json:"content"`
}
var dummy dummyStruct
err := json.Unmarshal([]byte(data), &dummy)
responseText += dummy.Content
var choice ChatCompletionsStreamResponseChoice
choice.Delta.Content = dummy.Content
response := ChatCompletionsStreamResponse{
Id: fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
Object: "chat.completion.chunk",
Created: common.GetTimestamp(),
Model: "gemini-pro",
Choices: []ChatCompletionsStreamResponseChoice{choice},
}
jsonResponse, err := json.Marshal(response)
if err != nil {
common.SysError("error marshalling stream response: " + err.Error())
return true
}
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
return true
case <-stopChan:
c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
@@ -263,7 +318,8 @@ func geminiChatHandler(c *gin.Context, resp *http.Response, promptTokens int, mo
}, nil
}
fullTextResponse := responseGeminiChat2OpenAI(&geminiResponse)
completionTokens := countTokenText(geminiResponse.Candidates[0].Content.Parts[0].Text, model)
fullTextResponse.Model = model
completionTokens := countTokenText(geminiResponse.GetResponseText(), model)
usage := Usage{
PromptTokens: promptTokens,
CompletionTokens: completionTokens,
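The new Gemini stream handler scans the upstream body line by line, keeps only the `"text": "..."` fragments, and round-trips each one through a throwaway JSON struct so escape sequences decode correctly (the "annoying \ related format bug" noted in the comment). A small sketch of that unescaping trick:

```go
package main

import (
	"encoding/json"
	"fmt"
	"strings"
)

// extractText pulls the raw fragment out of a `"text": "..."` line and
// lets the JSON decoder handle escape sequences, as the handler above does.
func extractText(line string) (string, bool) {
	line = strings.TrimSpace(line)
	if !strings.HasPrefix(line, `"text": "`) {
		return "", false
	}
	line = strings.TrimPrefix(line, `"text": "`)
	line = strings.TrimSuffix(line, `"`)
	var dummy struct {
		Content string `json:"content"`
	}
	if err := json.Unmarshal([]byte(fmt.Sprintf(`{"content": "%s"}`, line)), &dummy); err != nil {
		return "", false
	}
	return dummy.Content, true
}

func main() {
	fmt.Println(extractText(`    "text": "Hello,\nworld"`))
	// Hello,
	// world true
}
```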

View File

@@ -19,7 +19,6 @@ func isWithinRange(element string, value int) bool {
if _, ok := common.DalleGenerationImageAmounts[element]; !ok {
return false
}
min := common.DalleGenerationImageAmounts[element][0]
max := common.DalleGenerationImageAmounts[element][1]
@@ -42,6 +41,10 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
return errorWrapper(err, "bind_request_body_failed", http.StatusBadRequest)
}
if imageRequest.N == 0 {
imageRequest.N = 1
}
// Size validation
if imageRequest.Size != "" {
imageSize = imageRequest.Size
@@ -79,7 +82,10 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
// Number of generated images validation
if isWithinRange(imageModel, imageRequest.N) == false {
return errorWrapper(errors.New("invalid value of n"), "n_not_within_range", http.StatusBadRequest)
// channel not azure
if channelType != common.ChannelTypeAzure {
return errorWrapper(errors.New("invalid value of n"), "n_not_within_range", http.StatusBadRequest)
}
}
// map model name
@@ -102,7 +108,7 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
baseURL = c.GetString("base_url")
}
fullRequestURL := getFullRequestURL(baseURL, requestURL, channelType)
if channelType == common.ChannelTypeAzure && relayMode == RelayModeImagesGenerations {
if channelType == common.ChannelTypeAzure {
// https://learn.microsoft.com/en-us/azure/ai-services/openai/dall-e-quickstart?tabs=dalle3%2Ccommand-line&pivots=rest-api
apiVersion := GetAPIVersion(c)
// https://{resource_name}.openai.azure.com/openai/deployments/dall-e-3/images/generations?api-version=2023-06-01-preview
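Two behavior changes land in this file: `n` defaults to `1` when the client omits it (the JSON zero value), and the range check is skipped for Azure channels. A sketch of the combined validation, with illustrative per-model bounds:

```go
package main

import (
	"errors"
	"fmt"
)

// dalleImageAmounts maps model -> [min, max] allowed n (illustrative bounds).
var dalleImageAmounts = map[string][2]int{
	"dall-e-2": {1, 10},
	"dall-e-3": {1, 1},
}

func validateN(model string, n int, isAzure bool) (int, error) {
	if n == 0 {
		n = 1 // omitted in JSON -> zero value -> default to one image
	}
	r, ok := dalleImageAmounts[model]
	withinRange := ok && n >= r[0] && n <= r[1]
	if !withinRange && !isAzure { // Azure channels skip the range check
		return 0, errors.New("invalid value of n")
	}
	return n, nil
}

func main() {
	fmt.Println(validateN("dall-e-3", 0, false)) // 1 <nil>
	fmt.Println(validateN("dall-e-3", 4, false)) // 0 invalid value of n
	fmt.Println(validateN("dall-e-3", 4, true))  // 4 <nil>
}
```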

View File

@@ -187,6 +187,7 @@ func palmHandler(c *gin.Context, resp *http.Response, promptTokens int, model st
}, nil
}
fullTextResponse := responsePaLM2OpenAI(&palmResponse)
fullTextResponse.Model = model
completionTokens := countTokenText(palmResponse.Candidates[0].Content, model)
usage := Usage{
PromptTokens: promptTokens,

View File

@@ -237,6 +237,7 @@ func tencentHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatus
}, nil
}
fullTextResponse := responseTencent2OpenAI(&TencentResponse)
fullTextResponse.Model = "hunyuan"
jsonResponse, err := json.Marshal(fullTextResponse)
if err != nil {
return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil

View File

@@ -58,6 +58,9 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
if err != nil {
return errorWrapper(err, "bind_request_body_failed", http.StatusBadRequest)
}
if textRequest.MaxTokens < 0 || textRequest.MaxTokens > math.MaxInt32/2 {
return errorWrapper(errors.New("max_tokens is invalid"), "invalid_max_tokens", http.StatusBadRequest)
}
if relayMode == RelayModeModerations && textRequest.Model == "" {
textRequest.Model = "text-moderation-latest"
}
@@ -177,9 +180,6 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
if baseURL != "" {
fullRequestURL = fmt.Sprintf("%s/v1beta2/models/chat-bison-001:generateMessage", baseURL)
}
apiKey := c.Request.Header.Get("Authorization")
apiKey = strings.TrimPrefix(apiKey, "Bearer ")
fullRequestURL += "?key=" + apiKey
case APITypeGemini:
requestBaseURL := "https://generativelanguage.googleapis.com"
if baseURL != "" {
@@ -190,14 +190,10 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
version = c.GetString("api_version")
}
action := "generateContent"
// actually gemini does not support stream, it's fake
//if textRequest.Stream {
// action = "streamGenerateContent"
//}
if textRequest.Stream {
action = "streamGenerateContent"
}
fullRequestURL = fmt.Sprintf("%s/%s/models/%s:%s", requestBaseURL, version, textRequest.Model, action)
apiKey := c.Request.Header.Get("Authorization")
apiKey = strings.TrimPrefix(apiKey, "Bearer ")
fullRequestURL += "?key=" + apiKey
case APITypeZhipu:
method := "invoke"
if textRequest.Stream {
@@ -394,9 +390,9 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
case APITypeTencent:
req.Header.Set("Authorization", apiKey)
case APITypePaLM:
// do not set Authorization header
req.Header.Set("x-goog-api-key", apiKey)
case APITypeGemini:
// do not set Authorization header
req.Header.Set("x-goog-api-key", apiKey)
default:
req.Header.Set("Authorization", "Bearer "+apiKey)
}
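The new `max_tokens` guard in `relayTextHelper` rejects negative values and values large enough to overflow later arithmetic, before any upstream call is made. A sketch of the check:

```go
package main

import (
	"errors"
	"fmt"
	"math"
)

// checkMaxTokens mirrors the new guard: reject negatives and values
// that could overflow downstream quota math.
func checkMaxTokens(maxTokens int) error {
	if maxTokens < 0 || maxTokens > math.MaxInt32/2 {
		return errors.New("max_tokens is invalid")
	}
	return nil
}

func main() {
	fmt.Println(checkMaxTokens(4096))          // <nil>
	fmt.Println(checkMaxTokens(-1))            // max_tokens is invalid
	fmt.Println(checkMaxTokens(math.MaxInt32)) // max_tokens is invalid
}
```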

View File

@@ -263,11 +263,52 @@ func setEventStreamHeaders(c *gin.Context) {
c.Writer.Header().Set("X-Accel-Buffering", "no")
}
type GeneralErrorResponse struct {
Error OpenAIError `json:"error"`
Message string `json:"message"`
Msg string `json:"msg"`
Err string `json:"err"`
ErrorMsg string `json:"error_msg"`
Header struct {
Message string `json:"message"`
} `json:"header"`
Response struct {
Error struct {
Message string `json:"message"`
} `json:"error"`
} `json:"response"`
}
func (e GeneralErrorResponse) ToMessage() string {
if e.Error.Message != "" {
return e.Error.Message
}
if e.Message != "" {
return e.Message
}
if e.Msg != "" {
return e.Msg
}
if e.Err != "" {
return e.Err
}
if e.ErrorMsg != "" {
return e.ErrorMsg
}
if e.Header.Message != "" {
return e.Header.Message
}
if e.Response.Error.Message != "" {
return e.Response.Error.Message
}
return ""
}
func relayErrorHandler(resp *http.Response) (openAIErrorWithStatusCode *OpenAIErrorWithStatusCode) {
openAIErrorWithStatusCode = &OpenAIErrorWithStatusCode{
StatusCode: resp.StatusCode,
OpenAIError: OpenAIError{
Message: fmt.Sprintf("bad response status code %d", resp.StatusCode),
Message: "",
Type: "upstream_error",
Code: "bad_response_status_code",
Param: strconv.Itoa(resp.StatusCode),
@@ -281,12 +322,20 @@ func relayErrorHandler(resp *http.Response) (openAIErrorWithStatusCode *OpenAIEr
if err != nil {
return
}
var textResponse TextResponse
err = json.Unmarshal(responseBody, &textResponse)
var errResponse GeneralErrorResponse
err = json.Unmarshal(responseBody, &errResponse)
if err != nil {
return
}
openAIErrorWithStatusCode.OpenAIError = textResponse.Error
if errResponse.Error.Message != "" {
// OpenAI format error, so we override the default one
openAIErrorWithStatusCode.OpenAIError = errResponse.Error
} else {
openAIErrorWithStatusCode.OpenAIError.Message = errResponse.ToMessage()
}
if openAIErrorWithStatusCode.OpenAIError.Message == "" {
openAIErrorWithStatusCode.OpenAIError.Message = fmt.Sprintf("bad response status code %d", resp.StatusCode)
}
return
}
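`GeneralErrorResponse` folds the error shapes of all supported upstreams into one struct, and `ToMessage` returns the first non-empty candidate field. A sketch with a trimmed copy of the struct showing how differently shaped payloads resolve:

```go
package main

import (
	"encoding/json"
	"fmt"
)

// errorEnvelope is a trimmed copy of GeneralErrorResponse: every upstream
// error format maps onto one struct, and unmatched fields stay empty.
type errorEnvelope struct {
	Message string `json:"message"`
	Msg     string `json:"msg"`
	Header  struct {
		Message string `json:"message"`
	} `json:"header"`
}

func (e errorEnvelope) toMessage() string {
	// First non-empty field wins, mirroring ToMessage's priority order.
	for _, s := range []string{e.Message, e.Msg, e.Header.Message} {
		if s != "" {
			return s
		}
	}
	return ""
}

func main() {
	for _, body := range []string{
		`{"message": "quota exceeded"}`,
		`{"header": {"message": "invalid app id"}}`,
	} {
		var e errorEnvelope
		_ = json.Unmarshal([]byte(body), &e)
		fmt.Println(e.toMessage())
	}
	// quota exceeded
	// invalid app id
}
```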

View File

@@ -230,7 +230,13 @@ func xunfeiHandler(c *gin.Context, textRequest GeneralOpenAIRequest, appId strin
case stop = <-stopChan:
}
}
if len(xunfeiResponse.Payload.Choices.Text) == 0 {
xunfeiResponse.Payload.Choices.Text = []XunfeiChatResponseTextItem{
{
Content: "",
},
}
}
xunfeiResponse.Payload.Choices.Text[0].Content = content
response := responseXunfei2OpenAI(&xunfeiResponse)
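The panic fixed here is an empty-slice index: if the upstream never produced a chunk, `Payload.Choices.Text` is empty and `Text[0]` panics, so the handler backfills a single empty item first. A minimal sketch of the guard:

```go
package main

import "fmt"

type textItem struct{ Content string }

func main() {
	var text []textItem
	// Guard against indexing an empty slice, as the fix above does.
	if len(text) == 0 {
		text = []textItem{{Content: ""}}
	}
	text[0].Content = "accumulated stream content"
	fmt.Println(text[0].Content)
}
```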

View File

@@ -290,6 +290,7 @@ func zhipuHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCo
}, nil
}
fullTextResponse := responseZhipu2OpenAI(&zhipuResponse)
fullTextResponse.Model = "chatglm"
jsonResponse, err := json.Marshal(fullTextResponse)
if err != nil {
return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil

View File

@@ -31,6 +31,22 @@ type ImageContent struct {
ImageURL *ImageURL `json:"image_url,omitempty"`
}
const (
ContentTypeText = "text"
ContentTypeImageURL = "image_url"
)
type OpenAIMessageContent struct {
Type string `json:"type,omitempty"`
Text string `json:"text"`
ImageURL *ImageURL `json:"image_url,omitempty"`
}
func (m Message) IsStringContent() bool {
_, ok := m.Content.(string)
return ok
}
func (m Message) StringContent() string {
content, ok := m.Content.(string)
if ok {
@@ -44,7 +60,7 @@ func (m Message) StringContent() string {
if !ok {
continue
}
if contentMap["type"] == "text" {
if contentMap["type"] == ContentTypeText {
if subStr, ok := contentMap["text"].(string); ok {
contentStr += subStr
}
@@ -55,6 +71,47 @@ func (m Message) StringContent() string {
return ""
}
func (m Message) ParseContent() []OpenAIMessageContent {
var contentList []OpenAIMessageContent
content, ok := m.Content.(string)
if ok {
contentList = append(contentList, OpenAIMessageContent{
Type: ContentTypeText,
Text: content,
})
return contentList
}
anyList, ok := m.Content.([]any)
if ok {
for _, contentItem := range anyList {
contentMap, ok := contentItem.(map[string]any)
if !ok {
continue
}
switch contentMap["type"] {
case ContentTypeText:
if subStr, ok := contentMap["text"].(string); ok {
contentList = append(contentList, OpenAIMessageContent{
Type: ContentTypeText,
Text: subStr,
})
}
case ContentTypeImageURL:
if subObj, ok := contentMap["image_url"].(map[string]any); ok {
contentList = append(contentList, OpenAIMessageContent{
Type: ContentTypeImageURL,
ImageURL: &ImageURL{
Url: subObj["url"].(string),
},
})
}
}
}
return contentList
}
return nil
}
const (
RelayModeUnknown = iota
RelayModeChatCompletions
@@ -206,6 +263,7 @@ type OpenAITextResponseChoice struct {
type OpenAITextResponse struct {
Id string `json:"id"`
Model string `json:"model,omitempty"`
Object string `json:"object"`
Created int64 `json:"created"`
Choices []OpenAITextResponseChoice `json:"choices"`
@@ -236,7 +294,7 @@ type ChatCompletionsStreamResponseChoice struct {
Delta struct {
Content string `json:"content"`
} `json:"delta"`
FinishReason *string `json:"finish_reason"`
FinishReason *string `json:"finish_reason,omitempty"`
}
type ChatCompletionsStreamResponse struct {
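`ParseContent`, added above, normalizes the two message shapes the OpenAI API permits: a bare string, or an array of typed parts. A standalone sketch of that normalization with simplified types:

```go
package main

import "fmt"

type contentPart struct {
	Type     string
	Text     string
	ImageURL string
}

// parseContent mirrors Message.ParseContent: a bare string becomes one
// text part; an array of typed maps becomes one part per element.
func parseContent(content any) []contentPart {
	if s, ok := content.(string); ok {
		return []contentPart{{Type: "text", Text: s}}
	}
	items, ok := content.([]any)
	if !ok {
		return nil
	}
	var parts []contentPart
	for _, item := range items {
		m, ok := item.(map[string]any)
		if !ok {
			continue
		}
		switch m["type"] {
		case "text":
			if s, ok := m["text"].(string); ok {
				parts = append(parts, contentPart{Type: "text", Text: s})
			}
		case "image_url":
			if obj, ok := m["image_url"].(map[string]any); ok {
				if url, ok := obj["url"].(string); ok {
					parts = append(parts, contentPart{Type: "image_url", ImageURL: url})
				}
			}
		}
	}
	return parts
}

func main() {
	fmt.Println(parseContent("hello"))
	fmt.Println(parseContent([]any{
		map[string]any{"type": "text", "text": "what is this?"},
		map[string]any{"type": "image_url", "image_url": map[string]any{"url": "https://example.com/a.png"}},
	}))
}
```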

go.mod
View File

@@ -16,7 +16,7 @@ require (
github.com/gorilla/websocket v1.5.0
github.com/pkoukk/tiktoken-go v0.1.5
github.com/stretchr/testify v1.8.3
golang.org/x/crypto v0.14.0
golang.org/x/crypto v0.17.0
golang.org/x/image v0.14.0
gorm.io/driver/mysql v1.4.3
gorm.io/driver/postgres v1.5.2
@@ -58,7 +58,7 @@ require (
github.com/ugorji/go/codec v1.2.11 // indirect
golang.org/x/arch v0.3.0 // indirect
golang.org/x/net v0.17.0 // indirect
golang.org/x/sys v0.13.0 // indirect
golang.org/x/sys v0.15.0 // indirect
golang.org/x/text v0.14.0 // indirect
google.golang.org/protobuf v1.30.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect

go.sum
View File

@@ -150,8 +150,8 @@ golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUu
golang.org/x/arch v0.3.0 h1:02VY4/ZcO/gBOH6PUaoiptASxtXU10jazRCP865E97k=
golang.org/x/arch v0.3.0/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
golang.org/x/crypto v0.14.0 h1:wBqGXzWJW6m1XrIKlAH0Hs1JJ7+9KBwnIO8v66Q9cHc=
golang.org/x/crypto v0.14.0/go.mod h1:MVFd36DqK4CsrnJYDkBA3VC4m2GkXAM0PvzMCn4JQf4=
golang.org/x/crypto v0.17.0 h1:r8bRNjWL3GshPW3gkd+RpvzWrZAwPS49OmTGZ/uhM4k=
golang.org/x/crypto v0.17.0/go.mod h1:gCAAfMLgwOJRpTjQ2zCCt2OcSfYMTeZVSRtQlPC7Nq4=
golang.org/x/image v0.14.0 h1:tNgSxAFe3jC4uYqvZdTr84SZoM1KfwdC9SKIFrLjFn4=
golang.org/x/image v0.14.0/go.mod h1:HUYqC05R2ZcZ3ejNQsIHQDQiwWM4JBqmm6MKANTp4LE=
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
@@ -164,8 +164,8 @@ golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBc
golang.org/x/sys v0.0.0-20210806184541-e5e7981a1069/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220704084225-05e143d24a9e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.13.0 h1:Af8nKPmuFypiUBjVoU9V20FiaFXOcuZI21p0ycVYYGE=
golang.org/x/sys v0.13.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.15.0 h1:h48lPFYpsTvQJZF4EKyI4aLHaev3CxivZmv7yZig9pc=
golang.org/x/sys v0.15.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=

View File

@@ -5,6 +5,7 @@ import (
"github.com/gin-gonic/gin"
"net/http"
"one-api/common"
"runtime/debug"
)
func RelayPanicRecover() gin.HandlerFunc {
@@ -12,6 +13,7 @@ func RelayPanicRecover() gin.HandlerFunc {
defer func() {
if err := recover(); err != nil {
common.SysError(fmt.Sprintf("panic detected: %v", err))
common.SysError(fmt.Sprintf("stacktrace from panic: %s", string(debug.Stack())))
c.JSON(http.StatusInternalServerError, gin.H{
"error": gin.H{
"message": fmt.Sprintf("Panic detected, error: %v. Please submit a issue here: https://github.com/songquanpeng/one-api", err),

View File

@@ -1,6 +1,7 @@
package model
import (
"fmt"
"gorm.io/driver/mysql"
"gorm.io/driver/postgres"
"gorm.io/driver/sqlite"
@@ -59,7 +60,8 @@ func chooseDB() (*gorm.DB, error) {
// Use SQLite
common.SysLog("SQL_DSN not set, using SQLite as database")
common.UsingSQLite = true
return gorm.Open(sqlite.Open(common.SQLitePath), &gorm.Config{
config := fmt.Sprintf("?_busy_timeout=%d", common.SQLiteBusyTimeout)
return gorm.Open(sqlite.Open(common.SQLitePath+config), &gorm.Config{
PrepareStmt: true, // precompile SQL
})
}
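`chooseDB` now appends the busy timeout as a DSN query parameter, so concurrent writers wait up to that long instead of failing immediately with `database is locked`. The resulting DSN (a sketch; the real value comes from `SQLITE_BUSY_TIMEOUT`):

```go
package main

import "fmt"

func main() {
	sqlitePath := "one-api.db"
	busyTimeout := 3000 // milliseconds, the new default
	dsn := fmt.Sprintf("%s?_busy_timeout=%d", sqlitePath, busyTimeout)
	fmt.Println(dsn) // one-api.db?_busy_timeout=3000
	// gorm.Open(sqlite.Open(dsn), &gorm.Config{PrepareStmt: true})
}
```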

View File

@@ -42,7 +42,11 @@ func GetAllUsers(startIdx int, num int) (users []*User, err error) {
}
func SearchUsers(keyword string) (users []*User, err error) {
err = DB.Omit("password").Where("id = ? or username LIKE ? or email LIKE ? or display_name LIKE ?", keyword, keyword+"%", keyword+"%", keyword+"%").Find(&users).Error
if !common.UsingPostgreSQL {
err = DB.Omit("password").Where("id = ? or username LIKE ? or email LIKE ? or display_name LIKE ?", keyword, keyword+"%", keyword+"%", keyword+"%").Find(&users).Error
} else {
err = DB.Omit("password").Where("username LIKE ? or email LIKE ? or display_name LIKE ?", keyword+"%", keyword+"%", keyword+"%").Find(&users).Error
}
return users, err
}
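The presumed failure mode behind this fix: PostgreSQL is strict about operand types, so `id = ?` with a non-numeric keyword raises a cast error rather than matching nothing as SQLite/MySQL do, which is why the `id` clause is dropped on PostgreSQL. A sketch of the clause construction:

```go
package main

import "fmt"

// buildUserSearch returns the WHERE clause and args for a keyword search.
// Under PostgreSQL the id clause is dropped: comparing the integer id
// column against a non-numeric keyword raises a type error there.
func buildUserSearch(usingPostgreSQL bool, keyword string) (string, []any) {
	like := keyword + "%"
	if usingPostgreSQL {
		return "username LIKE ? or email LIKE ? or display_name LIKE ?",
			[]any{like, like, like}
	}
	return "id = ? or username LIKE ? or email LIKE ? or display_name LIKE ?",
		[]any{keyword, like, like, like}
}

func main() {
	where, args := buildUserSearch(true, "alice")
	fmt.Println(where, args)
}
```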

View File

@@ -69,7 +69,14 @@ const EditChannel = () => {
localModels = ['ERNIE-Bot', 'ERNIE-Bot-turbo', 'ERNIE-Bot-4', 'Embedding-V1'];
break;
case 17:
localModels = ['qwen-turbo', 'qwen-plus', 'text-embedding-v1'];
localModels = ['qwen-turbo', 'qwen-plus', 'qwen-max', 'qwen-max-longcontext', 'text-embedding-v1'];
let withInternetVersion = [];
for (let i = 0; i < localModels.length; i++) {
if (localModels[i].startsWith('qwen-')) {
withInternetVersion.push(localModels[i] + '-internet');
}
}
localModels = [...localModels, ...withInternetVersion];
break;
case 16:
localModels = ['chatglm_turbo', 'chatglm_pro', 'chatglm_std', 'chatglm_lite'];
@@ -84,7 +91,7 @@ const EditChannel = () => {
localModels = ['hunyuan'];
break;
case 24:
localModels = ['gemini-pro'];
localModels = ['gemini-pro', 'gemini-pro-vision'];
break;
}
setInputs((inputs) => ({ ...inputs, models: localModels }));