diff --git a/common/config/config.go b/common/config/config.go index 86fd62b9..98d10a50 100644 --- a/common/config/config.go +++ b/common/config/config.go @@ -6,6 +6,7 @@ import ( "fmt" "os" "strconv" + "strings" "sync" "time" @@ -65,9 +66,9 @@ var EmailDomainWhitelist = []string{ "foxmail.com", } -var DebugEnabled = os.Getenv("DEBUG") == "true" -var DebugSQLEnabled = os.Getenv("DEBUG_SQL") == "true" -var MemoryCacheEnabled = os.Getenv("MEMORY_CACHE_ENABLED") == "true" +var DebugEnabled = strings.ToLower(os.Getenv("DEBUG")) == "true" +var DebugSQLEnabled = strings.ToLower(os.Getenv("DEBUG_SQL")) == "true" +var MemoryCacheEnabled = strings.ToLower(os.Getenv("MEMORY_CACHE_ENABLED")) == "true" var LogConsumeEnabled = true diff --git a/common/constants.go b/common/constants.go index 87221b61..e4466a57 100644 --- a/common/constants.go +++ b/common/constants.go @@ -4,3 +4,11 @@ import "time" var StartTime = time.Now().Unix() // unit: second var Version = "v0.0.0" // this hard coding will be replaced automatically when building, no need to manually change + +var ( + // CtxKeyChannel is the key to store the channel in the context + CtxKeyChannel string = "channel_docu" + CtxKeyRequestModel string = "request_model" + CtxKeyRawRequest string = "raw_request" + CtxKeyConvertedRequest string = "converted_request" +) diff --git a/go.mod b/go.mod index 39f9e295..eae5b749 100644 --- a/go.mod +++ b/go.mod @@ -5,6 +5,9 @@ go 1.21 require ( github.com/Laisky/errors/v2 v2.0.1 github.com/Laisky/go-utils/v4 v4.9.1 + github.com/aws/aws-sdk-go-v2 v1.26.1 + github.com/aws/aws-sdk-go-v2/config v1.27.11 + github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.7.4 github.com/gin-contrib/cors v1.7.0 github.com/gin-contrib/gzip v0.0.6 github.com/gin-contrib/sessions v0.0.5 @@ -33,6 +36,18 @@ require ( github.com/Laisky/go-chaining v0.0.0-20180507092046-43dcdc5a21be // indirect github.com/Laisky/graphql v1.0.6 // indirect github.com/Laisky/zap v1.27.0 // indirect + 
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.2 // indirect + github.com/aws/aws-sdk-go-v2/credentials v1.17.11 // indirect + github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.1 // indirect + github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.5 // indirect + github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.5 // indirect + github.com/aws/aws-sdk-go-v2/internal/ini v1.8.0 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.11.2 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.11.7 // indirect + github.com/aws/aws-sdk-go-v2/service/sso v1.20.5 // indirect + github.com/aws/aws-sdk-go-v2/service/ssooidc v1.23.4 // indirect + github.com/aws/aws-sdk-go-v2/service/sts v1.28.6 // indirect + github.com/aws/smithy-go v1.20.2 // indirect github.com/bytedance/sonic v1.11.2 // indirect github.com/cespare/xxhash v1.1.0 // indirect github.com/cespare/xxhash/v2 v2.1.2 // indirect diff --git a/go.sum b/go.sum index bd2decbc..70890f9d 100644 --- a/go.sum +++ b/go.sum @@ -16,6 +16,36 @@ github.com/Laisky/zap v1.27.0 h1:NEtFniRXOKUkEf9//FUiBkSIdgbB4V1H3khA/jmPZx4= github.com/Laisky/zap v1.27.0/go.mod h1:HABqM5YDQlPq8w+Pmp9h/x9F6Vy+3oHBLP+2+pBoaJw= github.com/OneOfOne/xxhash v1.2.2 h1:KMrpdQIwFcEqXDklaen+P1axHaj9BSKzvpUUfnHldSE= github.com/OneOfOne/xxhash v1.2.2/go.mod h1:HSdplMjZKSmBqAxg5vPj2TmRDmfkzw+cTzAElWljhcU= +github.com/aws/aws-sdk-go-v2 v1.26.1 h1:5554eUqIYVWpU0YmeeYZ0wU64H2VLBs8TlhRB2L+EkA= +github.com/aws/aws-sdk-go-v2 v1.26.1/go.mod h1:ffIFB97e2yNsv4aTSGkqtHnppsIJzw7G7BReUZ3jCXM= +github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.2 h1:x6xsQXGSmW6frevwDA+vi/wqhp1ct18mVXYN08/93to= +github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.2/go.mod h1:lPprDr1e6cJdyYeGXnRaJoP4Md+cDBvi2eOj00BlGmg= +github.com/aws/aws-sdk-go-v2/config v1.27.11 h1:f47rANd2LQEYHda2ddSCKYId18/8BhSRM4BULGmfgNA= +github.com/aws/aws-sdk-go-v2/config v1.27.11/go.mod h1:SMsV78RIOYdve1vf36z8LmnszlRWkwMQtomCAI0/mIE= 
+github.com/aws/aws-sdk-go-v2/credentials v1.17.11 h1:YuIB1dJNf1Re822rriUOTxopaHHvIq0l/pX3fwO+Tzs= +github.com/aws/aws-sdk-go-v2/credentials v1.17.11/go.mod h1:AQtFPsDH9bI2O+71anW6EKL+NcD7LG3dpKGMV4SShgo= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.1 h1:FVJ0r5XTHSmIHJV6KuDmdYhEpvlHpiSd38RQWhut5J4= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.1/go.mod h1:zusuAeqezXzAB24LGuzuekqMAEgWkVYukBec3kr3jUg= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.5 h1:aw39xVGeRWlWx9EzGVnhOR4yOjQDHPQ6o6NmBlscyQg= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.5/go.mod h1:FSaRudD0dXiMPK2UjknVwwTYyZMRsHv3TtkabsZih5I= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.5 h1:PG1F3OD1szkuQPzDw3CIQsRIrtTlUC3lP84taWzHlq0= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.5/go.mod h1:jU1li6RFryMz+so64PpKtudI+QzbKoIEivqdf6LNpOc= +github.com/aws/aws-sdk-go-v2/internal/ini v1.8.0 h1:hT8rVHwugYE2lEfdFE0QWVo81lF7jMrYJVDWI+f+VxU= +github.com/aws/aws-sdk-go-v2/internal/ini v1.8.0/go.mod h1:8tu/lYfQfFe6IGnaOdrpVgEL2IrrDOf6/m9RQum4NkY= +github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.7.4 h1:JgHnonzbnA3pbqj76wYsSZIZZQYBxkmMEjvL6GHy8XU= +github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.7.4/go.mod h1:nZspkhg+9p8iApLFoyAqfyuMP0F38acy2Hm3r5r95Cg= +github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.11.2 h1:Ji0DY1xUsUr3I8cHps0G+XM3WWU16lP6yG8qu1GAZAs= +github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.11.2/go.mod h1:5CsjAbs3NlGQyZNFACh+zztPDI7fU6eW9QsxjfnuBKg= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.11.7 h1:ogRAwT1/gxJBcSWDMZlgyFUM962F51A5CRhDLbxLdmo= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.11.7/go.mod h1:YCsIZhXfRPLFFCl5xxY+1T9RKzOKjCut+28JSX2DnAk= +github.com/aws/aws-sdk-go-v2/service/sso v1.20.5 h1:vN8hEbpRnL7+Hopy9dzmRle1xmDc7o8tmY0klsr175w= +github.com/aws/aws-sdk-go-v2/service/sso v1.20.5/go.mod h1:qGzynb/msuZIE8I75DVRCUXw3o3ZyBmUvMwQ2t/BrGM= 
+github.com/aws/aws-sdk-go-v2/service/ssooidc v1.23.4 h1:Jux+gDDyi1Lruk+KHF91tK2KCuY61kzoCpvtvJJBtOE= +github.com/aws/aws-sdk-go-v2/service/ssooidc v1.23.4/go.mod h1:mUYPBhaF2lGiukDEjJX2BLRRKTmoUSitGDUgM4tRxak= +github.com/aws/aws-sdk-go-v2/service/sts v1.28.6 h1:cwIxeBttqPN3qkaAjcEcsh8NYr8n2HZPkcKgPAi1phU= +github.com/aws/aws-sdk-go-v2/service/sts v1.28.6/go.mod h1:FZf1/nKNEkHdGGJP/cI2MoIMquumuRK6ol3QQJNDxmw= +github.com/aws/smithy-go v1.20.2 h1:tbp628ireGtzcHDDmLT/6ADHidqnwgF57XOXZe6tp4Q= +github.com/aws/smithy-go v1.20.2/go.mod h1:krry+ya/rV9RDcV/Q16kpu6ypI4K2czasz0NC3qS14E= github.com/brianvoe/gofakeit/v6 v6.23.2 h1:lVde18uhad5wII/f5RMVFLtdQNE0HaGFuBUXmYKk8i8= github.com/brianvoe/gofakeit/v6 v6.23.2/go.mod h1:Ow6qC71xtwm79anlwKRlWZW6zVq9D2XHE4QSSMP/rU8= github.com/bytedance/sonic v1.5.0/go.mod h1:ED5hyg4y6t3/9Ku1R6dU/4KyJ48DZ4jPhfY1O2AihPM= diff --git a/middleware/distributor.go b/middleware/distributor.go index 2d34d9f0..7abb0be2 100644 --- a/middleware/distributor.go +++ b/middleware/distributor.go @@ -7,6 +7,7 @@ import ( "strings" "github.com/gin-gonic/gin" + "github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common/config" "github.com/songquanpeng/one-api/common/logger" "github.com/songquanpeng/one-api/model" @@ -73,6 +74,7 @@ func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, mode } logger.Info(c.Request.Context(), fmt.Sprintf("set channel %s ratio to %f", channel.Name, minimalRatio)) c.Set("channel_ratio", minimalRatio) + c.Set(common.CtxKeyChannel, channel) c.Set("channel", channel.Type) c.Set("channel_id", channel.Id) c.Set("channel_name", channel.Name) diff --git a/relay/adaptor.go b/relay/adaptor.go index ef549b5b..588cb4c2 100644 --- a/relay/adaptor.go +++ b/relay/adaptor.go @@ -4,6 +4,7 @@ import ( "github.com/songquanpeng/one-api/relay/adaptor" "github.com/songquanpeng/one-api/relay/adaptor/aiproxy" "github.com/songquanpeng/one-api/relay/adaptor/anthropic" + 
"github.com/songquanpeng/one-api/relay/adaptor/aws" "github.com/songquanpeng/one-api/relay/adaptor/gemini" "github.com/songquanpeng/one-api/relay/adaptor/ollama" "github.com/songquanpeng/one-api/relay/adaptor/openai" @@ -19,6 +20,8 @@ func GetAdaptor(apiType int) adaptor.Adaptor { // return &ali.Adaptor{} case apitype.Anthropic: return &anthropic.Adaptor{} + case apitype.AwsClaude: + return &aws.Adaptor{} // case apitype.Baidu: // return &baidu.Adaptor{} case apitype.Gemini: diff --git a/relay/adaptor/anthropic/main.go b/relay/adaptor/anthropic/main.go index aec327fe..79e55437 100644 --- a/relay/adaptor/anthropic/main.go +++ b/relay/adaptor/anthropic/main.go @@ -92,7 +92,7 @@ func ConvertRequest(textRequest model.GeneralOpenAIRequest) *Request { } // https://docs.anthropic.com/claude/reference/messages-streaming -func streamResponseClaude2OpenAI(claudeResponse *StreamResponse) (*openai.ChatCompletionsStreamResponse, *Response) { +func StreamResponseClaude2OpenAI(claudeResponse *StreamResponse) (*openai.ChatCompletionsStreamResponse, *Response) { var response *Response var responseText string var stopReason string @@ -130,7 +130,7 @@ func streamResponseClaude2OpenAI(claudeResponse *StreamResponse) (*openai.ChatCo return &openaiResponse, response } -func responseClaude2OpenAI(claudeResponse *Response) *openai.TextResponse { +func ResponseClaude2OpenAI(claudeResponse *Response) *openai.TextResponse { var responseText string if len(claudeResponse.Content) > 0 { responseText = claudeResponse.Content[0].Text @@ -201,7 +201,7 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC logger.SysError("error unmarshalling stream response: " + err.Error()) return true } - response, meta := streamResponseClaude2OpenAI(&claudeResponse) + response, meta := StreamResponseClaude2OpenAI(&claudeResponse) if meta != nil { usage.PromptTokens += meta.Usage.InputTokens usage.CompletionTokens += meta.Usage.OutputTokens @@ -256,7 +256,7 @@ func Handler(c 
*gin.Context, resp *http.Response, promptTokens int, modelName st StatusCode: resp.StatusCode, }, nil } - fullTextResponse := responseClaude2OpenAI(&claudeResponse) + fullTextResponse := ResponseClaude2OpenAI(&claudeResponse) fullTextResponse.Model = modelName usage := model.Usage{ PromptTokens: claudeResponse.Usage.InputTokens, diff --git a/relay/adaptor/aws/adapter.go b/relay/adaptor/aws/adapter.go new file mode 100644 index 00000000..a09a40ce --- /dev/null +++ b/relay/adaptor/aws/adapter.go @@ -0,0 +1,78 @@ +package aws + +import ( + "io" + "net/http" + + "github.com/Laisky/errors/v2" + "github.com/gin-gonic/gin" + "github.com/songquanpeng/one-api/common" + "github.com/songquanpeng/one-api/relay/adaptor" + "github.com/songquanpeng/one-api/relay/adaptor/anthropic" + "github.com/songquanpeng/one-api/relay/meta" + "github.com/songquanpeng/one-api/relay/model" +) + +var _ adaptor.Adaptor = new(Adaptor) + +type Adaptor struct { +} + +func (a *Adaptor) Init(meta *meta.Meta) { + +} + +func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) { + return "", nil +} + +func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error { + return nil +} + +func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) { + if request == nil { + return nil, errors.New("request is nil") + } + + claudeReq := anthropic.ConvertRequest(*request) + c.Set(common.CtxKeyRequestModel, request.Model) + c.Set(common.CtxKeyRawRequest, request) + c.Set(common.CtxKeyConvertedRequest, claudeReq) + return claudeReq, nil +} + +func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) { + if request == nil { + return nil, errors.New("request is nil") + } + return request, nil +} + +func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) { + return nil, nil +} + +func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta 
*meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) { + if meta.IsStream { + err, usage = StreamHandler(c, resp) + } else { + err, usage = Handler(c, resp, meta.PromptTokens, meta.ActualModelName) + } + return +} + +// https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids.html +var ModelList = []string{ + "claude-3-haiku-20240307", + "claude-3-sonnet-20240229", + "claude-3-opus-20240229", +} + +func (a *Adaptor) GetModelList() []string { + return ModelList +} + +func (a *Adaptor) GetChannelName() string { + return "aws" +} diff --git a/relay/adaptor/aws/main.go b/relay/adaptor/aws/main.go new file mode 100644 index 00000000..0c0643ed --- /dev/null +++ b/relay/adaptor/aws/main.go @@ -0,0 +1,233 @@ +// Package aws provides the AWS adaptor for the relay service. +package aws + +import ( + "bytes" + "encoding/json" + "fmt" + "io" + "net/http" + "strings" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/credentials" + "github.com/aws/aws-sdk-go-v2/service/bedrockruntime" + "github.com/aws/aws-sdk-go-v2/service/bedrockruntime/types" + "github.com/gin-gonic/gin" + "github.com/jinzhu/copier" + "github.com/pkg/errors" + "github.com/songquanpeng/one-api/common" + "github.com/songquanpeng/one-api/common/helper" + "github.com/songquanpeng/one-api/common/logger" + "github.com/songquanpeng/one-api/model" + "github.com/songquanpeng/one-api/relay/adaptor/anthropic" + relaymodel "github.com/songquanpeng/one-api/relay/model" +) + +func newAwsClient(channel *model.Channel) (*bedrockruntime.Client, error) { + ks := strings.Split(channel.Key, "\n") + if len(ks) != 2 { + return nil, errors.New("invalid key") + } + ak, sk := ks[0], ks[1] + + client := bedrockruntime.New(bedrockruntime.Options{ + Region: *channel.BaseURL, + Credentials: aws.NewCredentialsCache(credentials.NewStaticCredentialsProvider(ak, sk, "")), + }) + + return client, nil +} + +func wrapErr(err error) *relaymodel.ErrorWithStatusCode { + return &relaymodel.ErrorWithStatusCode{ 
+ StatusCode: http.StatusInternalServerError, + Error: relaymodel.Error{ + Message: fmt.Sprintf("%+v", err), + }, + } +} + +// https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids.html +func awsModelID(requestModel string) (string, error) { + switch requestModel { + case "claude-instant-1.2": + return "anthropic.claude-instant-v1", nil + case "claude-2.0": + return "anthropic.claude-v2", nil + case "claude-2.1": + return "anthropic.claude-v2:1", nil + case "claude-3-sonnet-20240229": + return "anthropic.claude-3-sonnet-20240229-v1:0", nil + case "claude-3-opus-20240229": + return "anthropic.claude-3-opus-20240229-v1:0", nil + case "claude-3-haiku-20240307": + return "anthropic.claude-3-haiku-20240307-v1:0", nil + default: + return "", errors.Errorf("unknown model: %s", requestModel) + } +} + +func Handler(c *gin.Context, resp *http.Response, promptTokens int, modelName string) (*relaymodel.ErrorWithStatusCode, *relaymodel.Usage) { + var channel *model.Channel + if channeli, ok := c.Get(common.CtxKeyChannel); !ok { + return wrapErr(errors.New("channel not found")), nil + } else { + channel = channeli.(*model.Channel) + } + + awsCli, err := newAwsClient(channel) + if err != nil { + return wrapErr(errors.Wrap(err, "newAwsClient")), nil + } + + awsModelId, err := awsModelID(c.GetString(common.CtxKeyRequestModel)) // use the model ConvertRequest stored in the context, not the channel-wide Models field + if err != nil { + return wrapErr(errors.Wrap(err, "awsModelID")), nil + } + + awsReq := &bedrockruntime.InvokeModelInput{ + ModelId: aws.String(awsModelId), + Accept: aws.String("application/json"), + ContentType: aws.String("application/json"), + } + + claudeReqi, ok := c.Get(common.CtxKeyConvertedRequest) + if !ok { + return wrapErr(errors.New("request not found")), nil + } + claudeReq := claudeReqi.(*anthropic.Request) + awsClaudeReq := &Request{ + AnthropicVersion: "bedrock-2023-05-31", + } + if err = copier.Copy(awsClaudeReq, claudeReq); err != nil { + return wrapErr(errors.Wrap(err, "copy request")), nil + } + + awsReq.Body, err = json.Marshal(awsClaudeReq) + if err 
!= nil { + return wrapErr(errors.Wrap(err, "marshal request")), nil + } + + awsResp, err := awsCli.InvokeModel(c.Request.Context(), awsReq) + if err != nil { + return wrapErr(errors.Wrap(err, "InvokeModel")), nil + } + + claudeResponse := new(anthropic.Response) + err = json.Unmarshal(awsResp.Body, claudeResponse) + if err != nil { + return wrapErr(errors.Wrap(err, "unmarshal response")), nil + } + + openaiResp := anthropic.ResponseClaude2OpenAI(claudeResponse) + openaiResp.Model = modelName + usage := relaymodel.Usage{ + PromptTokens: claudeResponse.Usage.InputTokens, + CompletionTokens: claudeResponse.Usage.OutputTokens, + TotalTokens: claudeResponse.Usage.InputTokens + claudeResponse.Usage.OutputTokens, + } + openaiResp.Usage = usage + + c.JSON(http.StatusOK, openaiResp) + return nil, &usage +} + +func StreamHandler(c *gin.Context, resp *http.Response) (*relaymodel.ErrorWithStatusCode, *relaymodel.Usage) { + createdTime := helper.GetTimestamp() + + var channel *model.Channel + if channeli, ok := c.Get(common.CtxKeyChannel); !ok { + return wrapErr(errors.New("channel not found")), nil + } else { + channel = channeli.(*model.Channel) + } + + awsCli, err := newAwsClient(channel) + if err != nil { + return wrapErr(errors.Wrap(err, "newAwsClient")), nil + } + + awsModelId, err := awsModelID(c.GetString(common.CtxKeyRequestModel)) // use the model ConvertRequest stored in the context, not the channel-wide Models field + if err != nil { + return wrapErr(errors.Wrap(err, "awsModelID")), nil + } + + awsReq := &bedrockruntime.InvokeModelWithResponseStreamInput{ + ModelId: aws.String(awsModelId), + Accept: aws.String("application/json"), + ContentType: aws.String("application/json"), + } + + claudeReqi, ok := c.Get(common.CtxKeyConvertedRequest) + if !ok { + return wrapErr(errors.New("request not found")), nil + } + claudeReq := claudeReqi.(*anthropic.Request) + + awsClaudeReq := &Request{ + AnthropicVersion: "bedrock-2023-05-31", + } + if err = copier.Copy(awsClaudeReq, claudeReq); err != nil { + return wrapErr(errors.Wrap(err, "copy request")), nil + } + awsReq.Body, err = 
json.Marshal(awsClaudeReq) + if err != nil { + return wrapErr(errors.Wrap(err, "marshal request")), nil + } + + awsResp, err := awsCli.InvokeModelWithResponseStream(c.Request.Context(), awsReq) + if err != nil { + return wrapErr(errors.Wrap(err, "InvokeModelWithResponseStream")), nil + } + stream := awsResp.GetStream() + defer stream.Close() + + var usage relaymodel.Usage + var id string + c.Stream(func(w io.Writer) bool { + event, ok := <-stream.Events() + if !ok { + return false + } + + switch v := event.(type) { + case *types.ResponseStreamMemberChunk: + claudeResp := new(anthropic.StreamResponse) + err := json.NewDecoder(bytes.NewReader(v.Value.Bytes)).Decode(claudeResp) + if err != nil { + logger.SysError("error unmarshalling stream response: " + err.Error()) + return false + } + + response, meta := anthropic.StreamResponseClaude2OpenAI(claudeResp) + if meta != nil { + usage.PromptTokens += meta.Usage.InputTokens + usage.CompletionTokens += meta.Usage.OutputTokens + id = fmt.Sprintf("chatcmpl-%s", meta.Id) + return true + } + if response == nil { + return true + } + response.Id = id + response.Model = c.GetString("original_model") + response.Created = createdTime + jsonStr, err := json.Marshal(response) + if err != nil { + logger.SysError("error marshalling stream response: " + err.Error()) + return true + } + c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonStr)}) + return true + case *types.UnknownUnionMember: + fmt.Println("unknown tag:", v.Tag) + return false + default: + fmt.Println("union is nil or unknown type") + return false + } + }) + + return nil, &usage +} diff --git a/relay/adaptor/aws/model.go b/relay/adaptor/aws/model.go new file mode 100644 index 00000000..bcbfb584 --- /dev/null +++ b/relay/adaptor/aws/model.go @@ -0,0 +1,17 @@ +package aws + +import "github.com/songquanpeng/one-api/relay/adaptor/anthropic" + +// Request is the request to AWS Claude +// +// 
https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-anthropic-claude-messages.html +type Request struct { + // AnthropicVersion should be "bedrock-2023-05-31" + AnthropicVersion string `json:"anthropic_version"` + Messages []anthropic.Message `json:"messages"` + MaxTokens int `json:"max_tokens,omitempty"` + Temperature float64 `json:"temperature,omitempty"` + TopP float64 `json:"top_p,omitempty"` + TopK int `json:"top_k,omitempty"` + StopSequences []string `json:"stop_sequences,omitempty"` +} diff --git a/relay/apitype/define.go b/relay/apitype/define.go index 82d32a50..3760ba00 100644 --- a/relay/apitype/define.go +++ b/relay/apitype/define.go @@ -12,6 +12,7 @@ const ( Tencent Gemini Ollama + AwsClaude Dummy // this one is only for count, do not add any channel after this ) diff --git a/relay/channeltype/define.go b/relay/channeltype/define.go index 80027a80..faa0d443 100644 --- a/relay/channeltype/define.go +++ b/relay/channeltype/define.go @@ -34,6 +34,7 @@ const ( Ollama LingYiWanWu StepFun + AwsClaude Dummy ) diff --git a/relay/channeltype/helper.go b/relay/channeltype/helper.go index 01c2918c..89e40142 100644 --- a/relay/channeltype/helper.go +++ b/relay/channeltype/helper.go @@ -25,6 +25,9 @@ func ToAPIType(channelType int) int { apiType = apitype.Gemini case Ollama: apiType = apitype.Ollama + case AwsClaude: + apiType = apitype.AwsClaude } + return apiType } diff --git a/relay/channeltype/url.go b/relay/channeltype/url.go index eec59116..9ac29f30 100644 --- a/relay/channeltype/url.go +++ b/relay/channeltype/url.go @@ -34,6 +34,7 @@ var ChannelBaseURLs = []string{ "http://localhost:11434", // 30 "https://api.lingyiwanwu.com", // 31 "https://api.stepfun.com", // 32 + "", // 33 } func init() { diff --git a/relay/controller/image.go b/relay/controller/image.go index ea3e32a0..4079e450 100644 --- a/relay/controller/image.go +++ b/relay/controller/image.go @@ -109,9 +109,10 @@ func RelayImageHelper(c *gin.Context, relayMode int) 
*relaymodel.ErrorWithStatus } defer func(ctx context.Context) { - if resp.StatusCode != http.StatusOK { + if resp != nil && resp.StatusCode != http.StatusOK { return } + err := model.PostConsumeTokenQuota(meta.TokenId, quota) if err != nil { logger.SysError("error consuming token remain quota: " + err.Error()) diff --git a/relay/controller/text.go b/relay/controller/text.go index beda2822..d3ae6644 100644 --- a/relay/controller/text.go +++ b/relay/controller/text.go @@ -94,14 +94,16 @@ func RelayTextHelper(c *gin.Context) *model.ErrorWithStatusCode { logger.Errorf(ctx, "DoRequest failed: %s", err.Error()) return openai.ErrorWrapper(err, "do_request_failed", http.StatusInternalServerError) } - errorHappened := (resp.StatusCode != http.StatusOK) || (meta.IsStream && resp.Header.Get("Content-Type") == "application/json") - if errorHappened { - billing.ReturnPreConsumedQuota(ctx, preConsumedQuota, meta.TokenId) - logger.Error(ctx, fmt.Sprintf("relay text [%d] <- %q %q", - resp.StatusCode, resp.Request.URL.String(), string(requestBodyBytes))) - return RelayErrorHandler(resp) + if resp != nil { + errorHappened := (resp.StatusCode != http.StatusOK) || (meta.IsStream && resp.Header.Get("Content-Type") == "application/json") + if errorHappened { + billing.ReturnPreConsumedQuota(ctx, preConsumedQuota, meta.TokenId) + logger.Error(ctx, fmt.Sprintf("relay text [%d] <- %q %q", + resp.StatusCode, resp.Request.URL.String(), string(requestBodyBytes))) + return RelayErrorHandler(resp) + } + meta.IsStream = meta.IsStream || strings.HasPrefix(resp.Header.Get("Content-Type"), "text/event-stream") } - meta.IsStream = meta.IsStream || strings.HasPrefix(resp.Header.Get("Content-Type"), "text/event-stream") // do response usage, respErr := adaptor.DoResponse(c, resp, meta) diff --git a/web/default/package.json b/web/default/package.json index 438f020c..ba45011f 100644 --- a/web/default/package.json +++ b/web/default/package.json @@ -18,7 +18,7 @@ }, "scripts": { "start": "react-scripts 
start", - "build": "react-scripts build && mv -f build ../build/default", + "build": "react-scripts build && rm -rf ../build/default && mv -f build ../build/default", "test": "react-scripts test", "eject": "react-scripts eject" }, diff --git a/web/default/src/constants/channel.constants.js b/web/default/src/constants/channel.constants.js index 7535b666..82fc7d44 100644 --- a/web/default/src/constants/channel.constants.js +++ b/web/default/src/constants/channel.constants.js @@ -1,6 +1,7 @@ export const CHANNEL_OPTIONS = [ { key: 1, text: 'OpenAI', value: 1, color: 'green' }, { key: 14, text: 'Anthropic Claude', value: 14, color: 'black' }, + { key: 33, text: 'AWS Claude', value: 33, color: 'black' }, { key: 3, text: 'Azure OpenAI', value: 3, color: 'olive' }, { key: 11, text: 'Google PaLM2', value: 11, color: 'orange' }, { key: 24, text: 'Google Gemini', value: 24, color: 'orange' }, @@ -31,4 +32,4 @@ export const CHANNEL_OPTIONS = [ { key: 9, text: '代理:AI.LS', value: 9, color: 'yellow' }, { key: 12, text: '代理:API2GPT', value: 12, color: 'blue' }, { key: 13, text: '代理:AIGC2D', value: 13, color: 'purple' } -]; \ No newline at end of file +]; diff --git a/web/default/src/pages/Channel/EditChannel.js b/web/default/src/pages/Channel/EditChannel.js index 203cd714..2880ac98 100644 --- a/web/default/src/pages/Channel/EditChannel.js +++ b/web/default/src/pages/Channel/EditChannel.js @@ -144,6 +144,13 @@ const EditChannel = () => { }, []); const submit = async () => { + // Some providers, such as AWS, need both an access key (AK) and a secret key (SK) rather than a single key, + // so we join them with a newline into the single `key` field to stay compatible with other channels. + if (inputs.ak && inputs.sk) { + // NOTE: never log inputs.ak / inputs.sk here; they are credentials. + inputs.key = `${inputs.ak}\n${inputs.sk}`; + } + if (!isEdit && (inputs.name === '' || inputs.key === '')) { showInfo('请填写渠道名称和渠道密钥!'); return; @@ -392,7 +399,40 @@ const EditChannel = () => { /> { - batch ? 
+ inputs.type === 33 && ( + + + + + + ) + } + { + inputs.type !== 33 && (batch ? { value={inputs.key} autoComplete='new-password' /> - + ) } { - !isEdit && ( + inputs.type !== 33 && !isEdit && ( { ) } { - inputs.type !== 3 && inputs.type !== 8 && inputs.type !== 22 && ( + inputs.type !== 3 && inputs.type !== 33 && inputs.type !== 8 && inputs.type !== 22 && (