diff --git a/common/constants.go b/common/constants.go index 10bfca9..d72dc2c 100644 --- a/common/constants.go +++ b/common/constants.go @@ -233,6 +233,7 @@ const ( ChannelTypeMiniMax = 35 ChannelTypeScholarAI = 36 ChannelTypeDoubao = 37 + ChannelTypeVertexClaude = 38 ChannelTypeDummy // this one is only for count, do not add any channel after this ) @@ -276,4 +277,5 @@ var ChannelBaseURLs = []string{ "https://api.minimax.chat", //35 "https://api.scholarai.io", //36 "https://ark.cn-beijing.volces.com", //37 + "", //38 } diff --git a/go.mod b/go.mod index 2061262..3f7ece4 100644 --- a/go.mod +++ b/go.mod @@ -33,6 +33,7 @@ require ( ) require ( + cloud.google.com/go/compute/metadata v0.3.0 // indirect github.com/anknown/darts v0.0.0-20151216065714-83ff685239e6 // indirect github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.2 // indirect github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.5 // indirect @@ -75,6 +76,7 @@ require ( golang.org/x/arch v0.3.0 // indirect golang.org/x/exp v0.0.0-20240404231335-c0f41cb1a7a0 // indirect golang.org/x/net v0.21.0 // indirect + golang.org/x/oauth2 v0.21.0 // indirect golang.org/x/sync v0.7.0 // indirect golang.org/x/sys v0.18.0 // indirect golang.org/x/text v0.14.0 // indirect diff --git a/go.sum b/go.sum index 1cfb36f..a5a451c 100644 --- a/go.sum +++ b/go.sum @@ -1,3 +1,5 @@ +cloud.google.com/go/compute/metadata v0.3.0 h1:Tz+eQXMEqDIKRsmY3cHTL6FVaynIjX2QxYC4trgAKZc= +cloud.google.com/go/compute/metadata v0.3.0/go.mod h1:zFmK7XCadkQkj6TtorcaGlCW1hT1fIilQDwofLpJ20k= github.com/anknown/ahocorasick v0.0.0-20190904063843-d75dbd5169c0 h1:onfun1RA+KcxaMk1lfrRnwCd1UUuOjJM/lri5eM1qMs= github.com/anknown/ahocorasick v0.0.0-20190904063843-d75dbd5169c0/go.mod h1:4yg+jNTYlDEzBjhGS96v+zjyA3lfXlFd5CiTLIkPBLI= github.com/anknown/darts v0.0.0-20151216065714-83ff685239e6 h1:HblK3eJHq54yET63qPCTJnks3loDse5xRmmqHgHzwoI= @@ -80,6 +82,7 @@ github.com/golang/protobuf v1.3.3/go.mod h1:vzj43D7+SQXF/4pzW/hwtAqwc6iTitCiVSaW github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.8 h1:e6P7q2lk1O+qJJb4BtCQXlK8vWEO8V1ZeuEdJNOqZyg= +github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38= github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I= github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= @@ -203,6 +206,8 @@ golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v golang.org/x/net v0.0.0-20210520170846-37e1c6afe023/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= golang.org/x/net v0.21.0 h1:AQyQV4dYCvJ7vGmJyKki9+PBdyvhkSd8EIx/qb0AYv4= golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44= +golang.org/x/oauth2 v0.21.0 h1:tsimM75w1tF/uws5rbeHzIWxEqElMehnc+iW793zsZs= +golang.org/x/oauth2 v0.21.0/go.mod h1:XYTD2NtWslqkgxebSiOHnXEap4TF09sJSc7H1sXbhtI= golang.org/x/sync v0.7.0 h1:YsImfSBoP9QPYL0xyKJPq0gcaJdG3rInoqxTWbfQu9M= golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= diff --git a/relay/channel/scholarai/relay-scholar.go b/relay/channel/scholarai/relay-scholarai.go similarity index 100% rename from relay/channel/scholarai/relay-scholar.go rename to relay/channel/scholarai/relay-scholarai.go diff --git a/relay/channel/vertex-claude/adaptor.go b/relay/channel/vertex-claude/adaptor.go new file mode 100644 index 0000000..720b279 --- /dev/null +++ b/relay/channel/vertex-claude/adaptor.go @@ -0,0 +1,83 @@ +package vertex_claude + +import ( + "errors" + "fmt" + "github.com/gin-gonic/gin" + "io" + "net/http" + "one-api/dto" + "one-api/relay/channel" + relaycommon "one-api/relay/common" + "strings" +) + +const ( + // LOCATION europe-west1 or us-east5 + LOCATION = "us-east5" +) + +type Adaptor struct { +} + +func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) { +} + +func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { + parts := strings.SplitN(info.ApiKey, "|", 2) + if len(parts) != 2 { + return "", fmt.Errorf("invalid api key: %s", info.ApiKey) + } + projectId := strings.TrimSpace(parts[0]) + model, err := getRedirectModel(info.UpstreamModelName) + if err != nil { + return "", err + } + return fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/anthropic/models/%s:streamRawPredict", LOCATION, projectId, LOCATION, model), nil +} + +func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error { + channel.SetupApiRequestHeader(info, c, req) + parts := strings.SplitN(info.ApiKey, "|", 2) + if len(parts) != 2 { + return fmt.Errorf("invalid api key: %s", info.ApiKey) + } + json := strings.TrimSpace(parts[1]) + accessToken, err := getAccessToken(json) + if err != nil { + return err + } + req.Header.Set("Authorization", "Bearer "+accessToken) + return nil +} + +func (a *Adaptor) ConvertRequest(c *gin.Context, _ int, request *dto.GeneralOpenAIRequest) (any, error) { + if request == nil { + return nil, errors.New("request is nil") + } + return requestOpenAI2VertexClaude(*request) +} + +func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) { + return channel.DoApiRequest(a, c, info, requestBody) +} + +func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) { + if info.IsStream { + err, usage = vertexClaudeStreamHandler(c, resp) + } else { + err, usage = vertexClaudeHandler(c, resp) + } + return +} + +func (a *Adaptor) GetModelList() (models []string) { + for n := range modelIdMap { + models = append(models, n) + } + return +} + +func (a *Adaptor) GetChannelName() string { + return ChannelName +} diff --git a/relay/channel/vertex-claude/constants.go b/relay/channel/vertex-claude/constants.go new file mode 100644 index 0000000..45fcbad --- /dev/null +++ b/relay/channel/vertex-claude/constants.go @@ -0,0 +1,7 @@ +package vertex_claude + +var modelIdMap = map[string]string{ + "claude-3-5-sonnet-20240620": "claude-3-5-sonnet@20240620", +} + +var ChannelName = "vertex-claude" diff --git a/relay/channel/vertex-claude/dto.go b/relay/channel/vertex-claude/dto.go new file mode 100644 index 0000000..2298c51 --- /dev/null +++ b/relay/channel/vertex-claude/dto.go @@ -0,0 +1,12 @@ +package vertex_claude + +import "one-api/relay/channel/claude" + +type VertexClaudeRequest struct { + // vertex-2023-10-16 + AnthropicVersion string `json:"anthropic_version"` + System string `json:"system,omitempty"` + Messages []claude.ClaudeMessage `json:"messages"` + MaxTokens int `json:"max_tokens,omitempty"` + Stream bool `json:"stream,omitempty"` +} diff --git a/relay/channel/vertex-claude/relay-vertex-claude.go b/relay/channel/vertex-claude/relay-vertex-claude.go new file mode 100644 index 0000000..52653f0 --- /dev/null +++ b/relay/channel/vertex-claude/relay-vertex-claude.go @@ -0,0 +1,254 @@ +package vertex_claude + +import ( + "bufio" + "context" + "encoding/json" + "fmt" + "github.com/gin-gonic/gin" + "github.com/pkg/errors" + "golang.org/x/oauth2" + "golang.org/x/oauth2/google" + "io" + "net/http" + "one-api/common" + "one-api/dto" + relaymodel "one-api/dto" + "one-api/relay/channel/claude" + "one-api/service" + "strings" + "sync" + "time" +) + +var accessTokenMap sync.Map + +func getAccessToken(json string) (string, error) { + data, ok := accessTokenMap.Load(json) + if ok { + token := data.(oauth2.Token) + if time.Now().Before(token.Expiry) { + return token.AccessToken, nil + } + } + creds, err := google.CredentialsFromJSON(context.Background(), []byte(json), "https://www.googleapis.com/auth/cloud-platform") + if err != nil { + return "", err + } + token, err := creds.TokenSource.Token() + if err != nil { + return "", err + } + accessTokenMap.Store(json, *token) + return token.AccessToken, nil +} + +func getRedirectModel(requestModel string) (string, error) { + if model, ok := modelIdMap[requestModel]; ok { + return model, nil + } + return "", errors.Errorf("model %s not found", requestModel) +} + +func requestOpenAI2VertexClaude(request dto.GeneralOpenAIRequest) (*VertexClaudeRequest, error) { + vertexClaudeRequest := VertexClaudeRequest{ + AnthropicVersion: "vertex-2023-10-16", + Stream: request.Stream, + } + if vertexClaudeRequest.MaxTokens == 0 { + vertexClaudeRequest.MaxTokens = 4096 + } + formatMessages := make([]dto.Message, 0) + var lastMessage *dto.Message + for i, message := range request.Messages { + if message.Role == "" { + request.Messages[i].Role = "user" + } + fmtMessage := dto.Message{ + Role: message.Role, + Content: message.Content, + } + if lastMessage != nil && lastMessage.Role == message.Role { + if lastMessage.IsStringContent() && message.IsStringContent() { + content, _ := json.Marshal(strings.Trim(fmt.Sprintf("%s %s", lastMessage.StringContent(), message.StringContent()), "\"")) + fmtMessage.Content = content + // delete last message + formatMessages = formatMessages[:len(formatMessages)-1] + } + } + if fmtMessage.Content == nil { + content, _ := json.Marshal("...") + fmtMessage.Content = content + } + formatMessages = append(formatMessages, fmtMessage) + lastMessage = &request.Messages[i] + } + claudeMessages := make([]claude.ClaudeMessage, 0) + for _, message := range formatMessages { + if message.Role == "system" { + if message.IsStringContent() { + vertexClaudeRequest.System = message.StringContent() + } else { + contents := message.ParseContent() + content := "" + for _, ctx := range contents { + if ctx.Type == "text" { + content += ctx.Text + } + } + vertexClaudeRequest.System = content + } + } else { + claudeMessage := claude.ClaudeMessage{ + Role: message.Role, + } + if message.IsStringContent() { + claudeMessage.Content = message.StringContent() + } else { + claudeMediaMessages := make([]claude.ClaudeMediaMessage, 0) + for _, mediaMessage := range message.ParseContent() { + claudeMediaMessage := claude.ClaudeMediaMessage{ + Type: mediaMessage.Type, + } + if mediaMessage.Type == "text" { + claudeMediaMessage.Text = mediaMessage.Text + } else { + imageUrl := mediaMessage.ImageUrl.(dto.MessageImageUrl) + claudeMediaMessage.Type = "image" + claudeMediaMessage.Source = &claude.ClaudeMessageSource{ + Type: "base64", + } + // 判断是否是url + if strings.HasPrefix(imageUrl.Url, "http") { + // 是url,获取图片的类型和base64编码的数据 + mimeType, data, _ := common.GetImageFromUrl(imageUrl.Url) + claudeMediaMessage.Source.MediaType = mimeType + claudeMediaMessage.Source.Data = data + } else { + _, format, base64String, err := common.DecodeBase64ImageData(imageUrl.Url) + if err != nil { + return nil, err + } + claudeMediaMessage.Source.MediaType = "image/" + format + claudeMediaMessage.Source.Data = base64String + } + } + claudeMediaMessages = append(claudeMediaMessages, claudeMediaMessage) + } + claudeMessage.Content = claudeMediaMessages + } + claudeMessages = append(claudeMessages, claudeMessage) + } + } + vertexClaudeRequest.Messages = claudeMessages + return &vertexClaudeRequest, nil +} + +func vertexClaudeHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { + var claudeResponse claude.ClaudeResponse + responseBody, err := io.ReadAll(resp.Body) + if err != nil { + return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil + } + err = resp.Body.Close() + if err != nil { + return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil + } + err = json.Unmarshal(responseBody, &claudeResponse) + if err != nil { + return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil + } + openaiResp := claude.ResponseClaude2OpenAI(claude.RequestModeMessage, &claudeResponse) + usage := relaymodel.Usage{ + PromptTokens: claudeResponse.Usage.InputTokens, + CompletionTokens: claudeResponse.Usage.OutputTokens, + TotalTokens: claudeResponse.Usage.InputTokens + claudeResponse.Usage.OutputTokens, + } + openaiResp.Usage = usage + c.JSON(http.StatusOK, openaiResp) + return nil, &usage +} + +func vertexClaudeStreamHandler(c *gin.Context, resp *http.Response) (*relaymodel.OpenAIErrorWithStatusCode, *relaymodel.Usage) { + scanner := bufio.NewScanner(resp.Body) + scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { + if atEOF && len(data) == 0 { + return 0, nil, nil + } + if i := strings.Index(string(data), "\n"); i >= 0 { + return i + 1, data[0:i], nil + } + if atEOF { + return len(data), data, nil + } + return 0, nil, nil + }) + dataChan := make(chan string) + stopChan := make(chan bool) + go func() { + for scanner.Scan() { + data := scanner.Text() + if len(data) < 5 { // ignore blank line or wrong format + continue + } + if data[:5] != "data:" { + continue + } + data = data[5:] + dataChan <- data + } + stopChan <- true + }() + var id string + var model string + createdTime := common.GetTimestamp() + var usage relaymodel.Usage + service.SetEventStreamHeaders(c) + c.Stream(func(w io.Writer) bool { + select { + case data := <-dataChan: + claudeResp := new(claude.ClaudeResponse) + err := json.Unmarshal([]byte(data), &claudeResp) + if err != nil { + common.SysError("error unmarshalling stream response: " + err.Error()) + return true + } + response, claudeUsage := claude.StreamResponseClaude2OpenAI(claude.RequestModeMessage, claudeResp) + + if claudeUsage != nil { + usage.PromptTokens += claudeUsage.InputTokens + usage.CompletionTokens += claudeUsage.OutputTokens + } + + if response == nil { + return true + } + + if response.Id != "" { + id = response.Id + } + if response.Model != "" { + model = response.Model + } + response.Created = createdTime + response.Id = id + response.Model = model + + jsonStr, err := json.Marshal(response) + if err != nil { + common.SysError("error marshalling stream response: " + err.Error()) + return true + } + c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonStr)}) + return true + case <-stopChan: + c.Render(-1, common.CustomEvent{Data: "data: [DONE]"}) + return false + } + }) + err := resp.Body.Close() + if err != nil { + return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil + } + return nil, &usage +} diff --git a/relay/constant/api_type.go b/relay/constant/api_type.go index 2355ccf..3e375d9 100644 --- a/relay/constant/api_type.go +++ b/relay/constant/api_type.go @@ -21,6 +21,7 @@ const ( APITypeAws APITypeCohere APITypeScholarAI + APITypeVertexClaude APITypeDummy // this one is only for count, do not add any channel after this ) @@ -59,6 +60,8 @@ func ChannelType2APIType(channelType int) (int, bool) { apiType = APITypeCohere case common.ChannelTypeScholarAI: apiType = APITypeScholarAI + case common.ChannelTypeVertexClaude: + apiType = APITypeVertexClaude } if apiType == -1 { return APITypeOpenAI, false diff --git a/relay/relay_adaptor.go b/relay/relay_adaptor.go index 9188dd7..2634a5d 100644 --- a/relay/relay_adaptor.go +++ b/relay/relay_adaptor.go @@ -14,6 +14,7 @@ import ( "one-api/relay/channel/perplexity" "one-api/relay/channel/scholarai" "one-api/relay/channel/tencent" + vertex_claude "one-api/relay/channel/vertex-claude" "one-api/relay/channel/xunfei" "one-api/relay/channel/zhipu" "one-api/relay/channel/zhipu_4v" @@ -54,6 +55,8 @@ func GetAdaptor(apiType int) channel.Adaptor { return &cohere.Adaptor{} case constant.APITypeScholarAI: return &scholarai.Adaptor{} + case constant.APITypeVertexClaude: + return &vertex_claude.Adaptor{} } return nil } diff --git a/web/src/constants/channel.constants.js b/web/src/constants/channel.constants.js index 4b0f041..1fd1d43 100644 --- a/web/src/constants/channel.constants.js +++ b/web/src/constants/channel.constants.js @@ -112,6 +112,13 @@ export const CHANNEL_OPTIONS = [ color: 'blue', label: '豆包' }, + { + key: 38, + text: 'Vertex Claude', + value: 38, + color: 'blue', + label: 'Vertex Claude' + }, { key: 8, text: '自定义渠道', value: 8, color: 'pink', label: '自定义渠道' }, { key: 22,