feat: support vision

- Added new dependencies: `github.com/fsnotify/fsnotify v1.4.9`, `github.com/go-playground/assert/v2 v2.2.0`, `github.com/nxadm/tail v1.4.8`, `github.com/onsi/ginkgo v1.16.5`, `github.com/onsi/gomega v1.18.1`, `golang.org/x/net v0.10.0`, `gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7`
- Updated dependencies: `github.com/gin-gonic/gin` from v1.9.0 to v2.0.0, `golang.org/x/net` from v0.17.0 to v0.10.0
- Removed dependencies: `github.com/golang-jwt/jwt v3.2.2+incompatible`, `github.com/gorilla/websocket v1.5.1`
- Updated Go version from `1.18` to `1.21`
- Refactored the existing code to handle the new message format:
  - Added a new struct `VisionMessage` with fields `Role`, `Content`, and `Name`
  - Added constants for the new message content types
  - Added helper methods and error handling for the different message formats (see the illustrative sketch below the list)
  - Adjusted the existing request struct and methods to accommodate the change
  - Removed unused imports
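
The diff below calls `textRequest.MessagesLen()` and `textRequest.TextMessages()`, whose definitions live in other changed files that are not shown on this page. As a rough orientation, here is a minimal Go sketch of how such helpers could be shaped, assuming the OpenAI-style `content` field arrives either as a plain string or as an array of typed parts. Apart from the identifiers visible in this diff (`VisionMessage`, `MessagesLen`, `TextMessages`), every name, constant, and field shape below is an assumption for illustration, not the project's actual code.

```go
// Illustrative sketch only - not the project's actual implementation.
// Assumes the OpenAI-style "content" field is either a plain string or an
// array of typed parts such as {"type":"text",...} or {"type":"image_url",...}.
package main

import (
	"encoding/json"
	"errors"
	"fmt"
)

// Assumed content-type constants (the "constants" mentioned in the commit message).
const (
	ContentTypeText     = "text"
	ContentTypeImageURL = "image_url"
)

// VisionMessage mirrors the struct named in the commit message.
type VisionMessage struct {
	Role    string          `json:"role"`
	Content json.RawMessage `json:"content"` // plain string or array of parts
	Name    string          `json:"name,omitempty"`
}

// contentPart is an assumed shape for one element of an array-style content field.
type contentPart struct {
	Type string `json:"type"`
	Text string `json:"text,omitempty"`
}

// TextMessage is a plain role/content pair, the shape a string-based token counter expects.
type TextMessage struct {
	Role    string `json:"role"`
	Content string `json:"content"`
	Name    string `json:"name,omitempty"`
}

// GeneralOpenAIRequest is a stand-in for textRequest; the real struct has many more fields.
type GeneralOpenAIRequest struct {
	Messages []VisionMessage `json:"messages"`
}

// MessagesLen matches the call site in the first hunk below.
func (r *GeneralOpenAIRequest) MessagesLen() int { return len(r.Messages) }

// TextMessages flattens every message to plain text so an existing
// countTokenMessages-style helper can keep operating on strings.
// Image parts are skipped here; the real code may account for them differently.
func (r *GeneralOpenAIRequest) TextMessages() ([]TextMessage, error) {
	out := make([]TextMessage, 0, len(r.Messages))
	for _, m := range r.Messages {
		var text string
		if err := json.Unmarshal(m.Content, &text); err != nil {
			var parts []contentPart
			if err := json.Unmarshal(m.Content, &parts); err != nil {
				return nil, errors.New("content is neither a string nor an array of parts")
			}
			for _, p := range parts {
				if p.Type == ContentTypeText {
					text += p.Text
				}
			}
		}
		out = append(out, TextMessage{Role: m.Role, Content: text, Name: m.Name})
	}
	return out, nil
}

func main() {
	raw := []byte(`{"messages":[{"role":"user","content":[
		{"type":"text","text":"What is in this image?"},
		{"type":"image_url","image_url":{"url":"https://example.com/cat.png"}}]}]}`)
	var req GeneralOpenAIRequest
	if err := json.Unmarshal(raw, &req); err != nil {
		panic(err)
	}
	msgs, err := req.TextMessages()
	if err != nil {
		panic(err)
	}
	fmt.Printf("messages=%d, first text=%q\n", req.MessagesLen(), msgs[0].Content)
}
```

Under this reading, the relay only needs the message count and the flattened text for validation and token counting, while the original multimodal content can still be forwarded upstream unchanged.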
Author: Laisky.Cai
Date: 2023-11-17 01:59:11 +00:00
Parent: b2e46a33ac
Commit: b58bd7e3ab
13 changed files with 2051 additions and 1916 deletions


@@ -76,7 +76,7 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
return errorWrapper(errors.New("field prompt is required"), "required_field_missing", http.StatusBadRequest)
}
case RelayModeChatCompletions:
if textRequest.Messages == nil || len(textRequest.Messages) == 0 {
if textRequest.Messages == nil || textRequest.MessagesLen() == 0 {
return errorWrapper(errors.New("field messages is required"), "required_field_missing", http.StatusBadRequest)
}
case RelayModeEmbeddings:
@@ -154,26 +154,26 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
if baseURL != "" {
fullRequestURL = fmt.Sprintf("%s/v1/complete", baseURL)
}
case APITypeBaidu:
switch textRequest.Model {
case "ERNIE-Bot":
fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions"
case "ERNIE-Bot-turbo":
fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/eb-instant"
case "ERNIE-Bot-4":
fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro"
case "BLOOMZ-7B":
fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/bloomz_7b1"
case "Embedding-V1":
fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/embeddings/embedding-v1"
}
apiKey := c.Request.Header.Get("Authorization")
apiKey = strings.TrimPrefix(apiKey, "Bearer ")
var err error
if apiKey, err = getBaiduAccessToken(apiKey); err != nil {
return errorWrapper(err, "invalid_baidu_config", http.StatusInternalServerError)
}
fullRequestURL += "?access_token=" + apiKey
// case APITypeBaidu:
// switch textRequest.Model {
// case "ERNIE-Bot":
// fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions"
// case "ERNIE-Bot-turbo":
// fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/eb-instant"
// case "ERNIE-Bot-4":
// fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro"
// case "BLOOMZ-7B":
// fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/bloomz_7b1"
// case "Embedding-V1":
// fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/embeddings/embedding-v1"
// }
// apiKey := c.Request.Header.Get("Authorization")
// apiKey = strings.TrimPrefix(apiKey, "Bearer ")
// var err error
// if apiKey, err = getBaiduAccessToken(apiKey); err != nil {
// return errorWrapper(err, "invalid_baidu_config", http.StatusInternalServerError)
// }
// fullRequestURL += "?access_token=" + apiKey
case APITypePaLM:
fullRequestURL = "https://generativelanguage.googleapis.com/v1beta2/models/chat-bison-001:generateMessage"
if baseURL != "" {
@@ -202,7 +202,12 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
var completionTokens int
switch relayMode {
case RelayModeChatCompletions:
promptTokens = countTokenMessages(textRequest.Messages, textRequest.Model)
messages, err := textRequest.TextMessages()
if err != nil {
return errorWrapper(err, "parse_text_messages_failed", http.StatusBadRequest)
}
promptTokens = countTokenMessages(messages, textRequest.Model)
case RelayModeCompletions:
promptTokens = countTokenInput(textRequest.Prompt, textRequest.Model)
case RelayModeModerations:
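
The hunk above routes prompt-token counting through `TextMessages()` and returns a `parse_text_messages_failed` error when a message's content cannot be flattened to text. To illustrate what the counting step then consumes, here is a crude, self-contained stand-in for `countTokenMessages`: the real code would encode with the model's tokenizer rather than split on whitespace, and the overhead constants are assumptions, not values from this diff.

```go
// Crude illustration only - not the project's countTokenMessages.
package main

import (
	"fmt"
	"strings"
)

// flatMessage stands in for the flattened messages returned by TextMessages().
type flatMessage struct {
	Role    string
	Content string
}

// approxCountTokenMessages approximates prompt tokens with a whitespace split.
// A real implementation would use the model's tokenizer; the per-message and
// reply-priming overheads below are illustrative guesses.
func approxCountTokenMessages(messages []flatMessage, model string) int {
	const perMessageOverhead = 4
	total := 3 // assumed reply-priming overhead
	for _, m := range messages {
		total += perMessageOverhead
		total += len(strings.Fields(m.Role)) + len(strings.Fields(m.Content))
	}
	return total
}

func main() {
	msgs := []flatMessage{{Role: "user", Content: "What is in this image?"}}
	fmt.Println(approxCountTokenMessages(msgs, "gpt-4-vision-preview"))
}
```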
@@ -257,67 +262,67 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
}
requestBody = bytes.NewBuffer(jsonStr)
case APITypeBaidu:
var jsonData []byte
var err error
switch relayMode {
case RelayModeEmbeddings:
baiduEmbeddingRequest := embeddingRequestOpenAI2Baidu(textRequest)
jsonData, err = json.Marshal(baiduEmbeddingRequest)
default:
baiduRequest := requestOpenAI2Baidu(textRequest)
jsonData, err = json.Marshal(baiduRequest)
}
if err != nil {
return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
}
requestBody = bytes.NewBuffer(jsonData)
case APITypePaLM:
palmRequest := requestOpenAI2PaLM(textRequest)
jsonStr, err := json.Marshal(palmRequest)
if err != nil {
return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
}
requestBody = bytes.NewBuffer(jsonStr)
case APITypeZhipu:
zhipuRequest := requestOpenAI2Zhipu(textRequest)
jsonStr, err := json.Marshal(zhipuRequest)
if err != nil {
return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
}
requestBody = bytes.NewBuffer(jsonStr)
case APITypeAli:
var jsonStr []byte
var err error
switch relayMode {
case RelayModeEmbeddings:
aliEmbeddingRequest := embeddingRequestOpenAI2Ali(textRequest)
jsonStr, err = json.Marshal(aliEmbeddingRequest)
default:
aliRequest := requestOpenAI2Ali(textRequest)
jsonStr, err = json.Marshal(aliRequest)
}
if err != nil {
return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
}
requestBody = bytes.NewBuffer(jsonStr)
case APITypeTencent:
apiKey := c.Request.Header.Get("Authorization")
apiKey = strings.TrimPrefix(apiKey, "Bearer ")
appId, secretId, secretKey, err := parseTencentConfig(apiKey)
if err != nil {
return errorWrapper(err, "invalid_tencent_config", http.StatusInternalServerError)
}
tencentRequest := requestOpenAI2Tencent(textRequest)
tencentRequest.AppId = appId
tencentRequest.SecretId = secretId
jsonStr, err := json.Marshal(tencentRequest)
if err != nil {
return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
}
sign := getTencentSign(*tencentRequest, secretKey)
c.Request.Header.Set("Authorization", sign)
requestBody = bytes.NewBuffer(jsonStr)
// case APITypeBaidu:
// var jsonData []byte
// var err error
// switch relayMode {
// case RelayModeEmbeddings:
// baiduEmbeddingRequest := embeddingRequestOpenAI2Baidu(textRequest)
// jsonData, err = json.Marshal(baiduEmbeddingRequest)
// default:
// baiduRequest := requestOpenAI2Baidu(textRequest)
// jsonData, err = json.Marshal(baiduRequest)
// }
// if err != nil {
// return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
// }
// requestBody = bytes.NewBuffer(jsonData)
// case APITypePaLM:
// palmRequest := requestOpenAI2PaLM(textRequest)
// jsonStr, err := json.Marshal(palmRequest)
// if err != nil {
// return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
// }
// requestBody = bytes.NewBuffer(jsonStr)
// case APITypeZhipu:
// zhipuRequest := requestOpenAI2Zhipu(textRequest)
// jsonStr, err := json.Marshal(zhipuRequest)
// if err != nil {
// return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
// }
// requestBody = bytes.NewBuffer(jsonStr)
// case APITypeAli:
// var jsonStr []byte
// var err error
// switch relayMode {
// case RelayModeEmbeddings:
// aliEmbeddingRequest := embeddingRequestOpenAI2Ali(textRequest)
// jsonStr, err = json.Marshal(aliEmbeddingRequest)
// default:
// aliRequest := requestOpenAI2Ali(textRequest)
// jsonStr, err = json.Marshal(aliRequest)
// }
// if err != nil {
// return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
// }
// requestBody = bytes.NewBuffer(jsonStr)
// case APITypeTencent:
// apiKey := c.Request.Header.Get("Authorization")
// apiKey = strings.TrimPrefix(apiKey, "Bearer ")
// appId, secretId, secretKey, err := parseTencentConfig(apiKey)
// if err != nil {
// return errorWrapper(err, "invalid_tencent_config", http.StatusInternalServerError)
// }
// tencentRequest := requestOpenAI2Tencent(textRequest)
// tencentRequest.AppId = appId
// tencentRequest.SecretId = secretId
// jsonStr, err := json.Marshal(tencentRequest)
// if err != nil {
// return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
// }
// sign := getTencentSign(*tencentRequest, secretKey)
// c.Request.Header.Set("Authorization", sign)
// requestBody = bytes.NewBuffer(jsonStr)
case APITypeAIProxyLibrary:
aiProxyLibraryRequest := requestOpenAI2AIProxyLibrary(textRequest)
aiProxyLibraryRequest.LibraryId = c.GetString("library_id")
@@ -357,16 +362,16 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
anthropicVersion = "2023-06-01"
}
req.Header.Set("anthropic-version", anthropicVersion)
case APITypeZhipu:
token := getZhipuToken(apiKey)
req.Header.Set("Authorization", token)
case APITypeAli:
req.Header.Set("Authorization", "Bearer "+apiKey)
if textRequest.Stream {
req.Header.Set("X-DashScope-SSE", "enable")
}
case APITypeTencent:
req.Header.Set("Authorization", apiKey)
// case APITypeZhipu:
// token := getZhipuToken(apiKey)
// req.Header.Set("Authorization", token)
// case APITypeAli:
// req.Header.Set("Authorization", "Bearer "+apiKey)
// if textRequest.Stream {
// req.Header.Set("X-DashScope-SSE", "enable")
// }
// case APITypeTencent:
// req.Header.Set("Authorization", apiKey)
default:
req.Header.Set("Authorization", "Bearer "+apiKey)
}
@@ -482,124 +487,124 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
}
return nil
}
case APITypeBaidu:
if isStream {
err, usage := baiduStreamHandler(c, resp)
if err != nil {
return err
}
if usage != nil {
textResponse.Usage = *usage
}
return nil
} else {
var err *OpenAIErrorWithStatusCode
var usage *Usage
switch relayMode {
case RelayModeEmbeddings:
err, usage = baiduEmbeddingHandler(c, resp)
default:
err, usage = baiduHandler(c, resp)
}
if err != nil {
return err
}
if usage != nil {
textResponse.Usage = *usage
}
return nil
}
case APITypePaLM:
if textRequest.Stream { // PaLM2 API does not support stream
err, responseText := palmStreamHandler(c, resp)
if err != nil {
return err
}
textResponse.Usage.PromptTokens = promptTokens
textResponse.Usage.CompletionTokens = countTokenText(responseText, textRequest.Model)
return nil
} else {
err, usage := palmHandler(c, resp, promptTokens, textRequest.Model)
if err != nil {
return err
}
if usage != nil {
textResponse.Usage = *usage
}
return nil
}
case APITypeZhipu:
if isStream {
err, usage := zhipuStreamHandler(c, resp)
if err != nil {
return err
}
if usage != nil {
textResponse.Usage = *usage
}
// zhipu's API does not return prompt tokens & completion tokens
textResponse.Usage.PromptTokens = textResponse.Usage.TotalTokens
return nil
} else {
err, usage := zhipuHandler(c, resp)
if err != nil {
return err
}
if usage != nil {
textResponse.Usage = *usage
}
// zhipu's API does not return prompt tokens & completion tokens
textResponse.Usage.PromptTokens = textResponse.Usage.TotalTokens
return nil
}
case APITypeAli:
if isStream {
err, usage := aliStreamHandler(c, resp)
if err != nil {
return err
}
if usage != nil {
textResponse.Usage = *usage
}
return nil
} else {
var err *OpenAIErrorWithStatusCode
var usage *Usage
switch relayMode {
case RelayModeEmbeddings:
err, usage = aliEmbeddingHandler(c, resp)
default:
err, usage = aliHandler(c, resp)
}
if err != nil {
return err
}
if usage != nil {
textResponse.Usage = *usage
}
return nil
}
case APITypeXunfei:
auth := c.Request.Header.Get("Authorization")
auth = strings.TrimPrefix(auth, "Bearer ")
splits := strings.Split(auth, "|")
if len(splits) != 3 {
return errorWrapper(errors.New("invalid auth"), "invalid_auth", http.StatusBadRequest)
}
var err *OpenAIErrorWithStatusCode
var usage *Usage
if isStream {
err, usage = xunfeiStreamHandler(c, textRequest, splits[0], splits[1], splits[2])
} else {
err, usage = xunfeiHandler(c, textRequest, splits[0], splits[1], splits[2])
}
if err != nil {
return err
}
if usage != nil {
textResponse.Usage = *usage
}
return nil
// case APITypeBaidu:
// if isStream {
// err, usage := baiduStreamHandler(c, resp)
// if err != nil {
// return err
// }
// if usage != nil {
// textResponse.Usage = *usage
// }
// return nil
// } else {
// var err *OpenAIErrorWithStatusCode
// var usage *Usage
// switch relayMode {
// case RelayModeEmbeddings:
// err, usage = baiduEmbeddingHandler(c, resp)
// default:
// err, usage = baiduHandler(c, resp)
// }
// if err != nil {
// return err
// }
// if usage != nil {
// textResponse.Usage = *usage
// }
// return nil
// }
// case APITypePaLM:
// if textRequest.Stream { // PaLM2 API does not support stream
// err, responseText := palmStreamHandler(c, resp)
// if err != nil {
// return err
// }
// textResponse.Usage.PromptTokens = promptTokens
// textResponse.Usage.CompletionTokens = countTokenText(responseText, textRequest.Model)
// return nil
// } else {
// err, usage := palmHandler(c, resp, promptTokens, textRequest.Model)
// if err != nil {
// return err
// }
// if usage != nil {
// textResponse.Usage = *usage
// }
// return nil
// }
// case APITypeZhipu:
// if isStream {
// err, usage := zhipuStreamHandler(c, resp)
// if err != nil {
// return err
// }
// if usage != nil {
// textResponse.Usage = *usage
// }
// // zhipu's API does not return prompt tokens & completion tokens
// textResponse.Usage.PromptTokens = textResponse.Usage.TotalTokens
// return nil
// } else {
// err, usage := zhipuHandler(c, resp)
// if err != nil {
// return err
// }
// if usage != nil {
// textResponse.Usage = *usage
// }
// // zhipu's API does not return prompt tokens & completion tokens
// textResponse.Usage.PromptTokens = textResponse.Usage.TotalTokens
// return nil
// }
// case APITypeAli:
// if isStream {
// err, usage := aliStreamHandler(c, resp)
// if err != nil {
// return err
// }
// if usage != nil {
// textResponse.Usage = *usage
// }
// return nil
// } else {
// var err *OpenAIErrorWithStatusCode
// var usage *Usage
// switch relayMode {
// case RelayModeEmbeddings:
// err, usage = aliEmbeddingHandler(c, resp)
// default:
// err, usage = aliHandler(c, resp)
// }
// if err != nil {
// return err
// }
// if usage != nil {
// textResponse.Usage = *usage
// }
// return nil
// }
// case APITypeXunfei:
// auth := c.Request.Header.Get("Authorization")
// auth = strings.TrimPrefix(auth, "Bearer ")
// splits := strings.Split(auth, "|")
// if len(splits) != 3 {
// return errorWrapper(errors.New("invalid auth"), "invalid_auth", http.StatusBadRequest)
// }
// var err *OpenAIErrorWithStatusCode
// var usage *Usage
// if isStream {
// err, usage = xunfeiStreamHandler(c, textRequest, splits[0], splits[1], splits[2])
// } else {
// err, usage = xunfeiHandler(c, textRequest, splits[0], splits[1], splits[2])
// }
// if err != nil {
// return err
// }
// if usage != nil {
// textResponse.Usage = *usage
// }
// return nil
case APITypeAIProxyLibrary:
if isStream {
err, usage := aiProxyLibraryStreamHandler(c, resp)
@@ -620,25 +625,25 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
}
return nil
}
case APITypeTencent:
if isStream {
err, responseText := tencentStreamHandler(c, resp)
if err != nil {
return err
}
textResponse.Usage.PromptTokens = promptTokens
textResponse.Usage.CompletionTokens = countTokenText(responseText, textRequest.Model)
return nil
} else {
err, usage := tencentHandler(c, resp)
if err != nil {
return err
}
if usage != nil {
textResponse.Usage = *usage
}
return nil
}
// case APITypeTencent:
// if isStream {
// err, responseText := tencentStreamHandler(c, resp)
// if err != nil {
// return err
// }
// textResponse.Usage.PromptTokens = promptTokens
// textResponse.Usage.CompletionTokens = countTokenText(responseText, textRequest.Model)
// return nil
// } else {
// err, usage := tencentHandler(c, resp)
// if err != nil {
// return err
// }
// if usage != nil {
// textResponse.Usage = *usage
// }
// return nil
// }
default:
return errorWrapper(errors.New("unknown api type"), "unknown_api_type", http.StatusInternalServerError)
}